use bytemuck::must_cast_slice;
use bytes::{BufMut, BytesMut};
use snafu::ResultExt;
use crate::{
error::{IoSnafu, OutOfSpecSnafu, Result},
memory::EstimateMemory,
};
use std::io::Read;
use super::{
rle::GenericRle,
util::{read_u8, try_read_u8},
PrimitiveValueEncoder,
};
/// Largest number of values a single Literals sequence may carry
/// (upper bound of the range check in `write_literals`).
const MAX_LITERAL_LENGTH: usize = 128;
/// Smallest number of repeated values that is encoded as a Run sequence.
/// It is also the offset subtracted from / added to the Run header byte.
const MIN_REPEAT_LENGTH: usize = 3;
/// Largest number of repeated values a single Run sequence may carry
/// (upper bound of the range check in `write_run`).
const MAX_REPEAT_LENGTH: usize = 130;
/// Incremental byte run-length encoder.
///
/// Values are buffered one at a time and emitted to `writer` either as a
/// Run sequence (one value repeated 3..=130 times) or as a Literals sequence
/// (1..=128 verbatim values).
pub struct ByteRleEncoder {
/// Encoded output accumulated so far.
writer: BytesMut,
/// Values buffered for the Literals sequence currently being built.
literals: [u8; MAX_LITERAL_LENGTH],
/// Number of values currently buffered. While a run is being tracked
/// (`run_value` is `Some`) this is the length of that run instead.
num_literals: usize,
/// Length of the streak of identical values at the tail of the current
/// Literals sequence, counting the value being processed.
tail_run_length: usize,
/// `Some(v)` while the encoder is tracking a run of `v`; `None` while it is
/// collecting literals.
run_value: Option<u8>,
}
impl ByteRleEncoder {
    /// Feeds a single byte into the encoder.
    ///
    /// The encoder is always in one of three states: empty, tracking a run
    /// (`run_value` is `Some`), or collecting literals. Completed sequences
    /// are appended to `self.writer` as soon as they can no longer grow.
    fn process_value(&mut self, value: u8) {
        // Nothing buffered: this value opens a brand new sequence.
        if self.num_literals == 0 {
            self.begin_sequence(value);
            return;
        }

        match self.run_value {
            // The run keeps going; emit it once it hits the maximum length.
            Some(repeated) if repeated == value => {
                self.num_literals += 1;
                if self.num_literals == MAX_REPEAT_LENGTH {
                    write_run(&mut self.writer, repeated, MAX_REPEAT_LENGTH);
                    self.clear_state();
                }
            }
            // The run is broken: emit it and restart with the new value.
            Some(repeated) => {
                write_run(&mut self.writer, repeated, self.num_literals);
                self.begin_sequence(value);
            }
            // No run is being tracked, so we are collecting literals.
            None => self.push_literal(value),
        }
    }

    /// Resets the buffer so that `value` is the sole member of a fresh
    /// Literals sequence.
    fn begin_sequence(&mut self, value: u8) {
        self.run_value = None;
        self.literals[0] = value;
        self.num_literals = 1;
        self.tail_run_length = 1;
    }

    /// Handles `value` while in the literal-collecting state (at least one
    /// value is buffered and no run is being tracked).
    fn push_literal(&mut self, value: u8) {
        // Track how many identical values sit at the end of the sequence.
        let previous = self.literals[self.num_literals - 1];
        self.tail_run_length = if previous == value {
            self.tail_run_length + 1
        } else {
            1
        };

        if self.tail_run_length != MIN_REPEAT_LENGTH {
            // Still plain literals: buffer the value, emitting when full.
            self.literals[self.num_literals] = value;
            self.num_literals += 1;
            if self.num_literals == MAX_LITERAL_LENGTH {
                write_literals(&mut self.writer, &self.literals);
                self.clear_state();
            }
            return;
        }

        // The tail is now long enough to become a run. Any literals that
        // precede the repeated tail are emitted first; the tail itself is
        // carried over into the run.
        if self.num_literals + 1 != MIN_REPEAT_LENGTH {
            let prefix_len = self.num_literals - (MIN_REPEAT_LENGTH - 1);
            write_literals(&mut self.writer, &self.literals[..prefix_len]);
        }
        self.run_value = Some(value);
        self.num_literals = MIN_REPEAT_LENGTH;
    }

    /// Returns the encoder to the empty state without emitting anything.
    fn clear_state(&mut self) {
        self.num_literals = 0;
        self.tail_run_length = 0;
        self.run_value = None;
    }

    /// Emits whatever is currently buffered (a run or literals) and resets
    /// the state. Does nothing when the buffer is empty.
    fn flush(&mut self) {
        if self.num_literals == 0 {
            return;
        }
        match self.run_value {
            Some(repeated) => write_run(&mut self.writer, repeated, self.num_literals),
            None => write_literals(&mut self.writer, &self.literals[..self.num_literals]),
        }
        self.clear_state();
    }
}
impl EstimateMemory for ByteRleEncoder {
    /// Estimate is the number of bytes already encoded plus the number of
    /// values still waiting in the buffer.
    fn estimate_memory_size(&self) -> usize {
        let encoded_bytes = self.writer.len();
        let pending_values = self.num_literals;
        encoded_bytes + pending_values
    }
}
impl PrimitiveValueEncoder<i8> for ByteRleEncoder {
    /// Creates an encoder with an empty output buffer and nothing pending.
    fn new() -> Self {
        let writer = BytesMut::new();
        let literals = [0u8; MAX_LITERAL_LENGTH];
        Self {
            writer,
            literals,
            num_literals: 0,
            tail_run_length: 0,
            run_value: None,
        }
    }

    /// Encodes one value; the `i8` is reinterpreted as its `u8` bit pattern.
    fn write_one(&mut self, value: i8) {
        let byte = value as u8;
        self.process_value(byte);
    }

    /// Flushes any pending values and hands back everything encoded so far,
    /// leaving a fresh empty buffer behind.
    fn take_inner(&mut self) -> bytes::Bytes {
        self.flush();
        let encoded = std::mem::take(&mut self.writer);
        encoded.into()
    }
}
/// Appends a Run sequence to `writer`: a header byte holding
/// `run_length - MIN_REPEAT_LENGTH`, followed by the repeated `value`.
///
/// In debug builds, asserts that `run_length` lies in `3..=130`.
fn write_run(writer: &mut BytesMut, value: u8, run_length: usize) {
    let valid_lengths = MIN_REPEAT_LENGTH..=MAX_REPEAT_LENGTH;
    debug_assert!(
        valid_lengths.contains(&run_length),
        "Byte RLE Run sequence must be in range 3..=130"
    );
    let header = (run_length - MIN_REPEAT_LENGTH) as u8;
    writer.put_u8(header);
    writer.put_u8(value);
}
/// Appends a Literals sequence to `writer`: a header byte holding the
/// negated count (as a two's-complement byte), followed by the values
/// themselves.
///
/// In debug builds, asserts that `literals.len()` lies in `1..=128`.
fn write_literals(writer: &mut BytesMut, literals: &[u8]) {
    let count = literals.len();
    debug_assert!(
        (1..=MAX_LITERAL_LENGTH).contains(&count),
        "Byte RLE Literal sequence must be in range 1..=128"
    );
    // A count of 1..=128 becomes a header of 0xff..=0x80 after truncation.
    let negated_count = -(count as i32);
    writer.put_u8(negated_count as u8);
    writer.put_slice(literals);
}
/// Decoder for byte run-length encoded data read from `R`.
pub struct ByteRleDecoder<R> {
/// Source of the encoded bytes.
reader: R,
/// Values of the most recently decoded Run or Literals sequence.
leftovers: Vec<u8>,
/// Offset into `leftovers` of the first value not yet consumed.
index: usize,
}
impl<R: Read> ByteRleDecoder<R> {
    /// Wraps `reader` in a decoder with no values buffered yet.
    ///
    /// The internal buffer is pre-sized to `MAX_REPEAT_LENGTH`.
    pub fn new(reader: R) -> Self {
        let leftovers = Vec::with_capacity(MAX_REPEAT_LENGTH);
        Self {
            reader,
            leftovers,
            index: 0,
        }
    }
}
impl<R: Read> GenericRle<i8> for ByteRleDecoder<R> {
fn advance(&mut self, n: usize) {
self.index += n
}
fn available(&self) -> &[i8] {
let bytes = &self.leftovers[self.index..];
must_cast_slice(bytes)
}
fn decode_batch(&mut self) -> Result<()> {
self.index = 0;
self.leftovers.clear();
let header = read_u8(&mut self.reader)?;
if header < 0x80 {
let length = header as usize + MIN_REPEAT_LENGTH;
let value = read_u8(&mut self.reader)?;
self.leftovers.extend(std::iter::repeat(value).take(length));
} else {
let length = 0x100 - header as usize;
self.leftovers.resize(length, 0);
self.reader
.read_exact(&mut self.leftovers)
.context(IoSnafu)?;
}
Ok(())
}
fn skip_values(&mut self, n: usize) -> Result<()> {
let mut remaining = n;
let available_count = self.available().len();
if available_count >= remaining {
self.advance(remaining);
return Ok(());
}
self.advance(available_count);
remaining -= available_count;
while remaining > 0 {
let header = match try_read_u8(&mut self.reader)? {
Some(byte) => byte,
None => {
return OutOfSpecSnafu {
msg: "not enough values to skip in Byte RLE",
}
.fail();
}
};
if header < 0x80 {
let length = header as usize + MIN_REPEAT_LENGTH;
if length <= remaining {
read_u8(&mut self.reader)?;
remaining -= length;
} else {
let value = read_u8(&mut self.reader)?;
self.leftovers.clear();
self.index = 0;
self.leftovers.extend(std::iter::repeat(value).take(length));
self.advance(remaining);
remaining = 0;
}
} else {
let length = 0x100 - header as usize;
if length <= remaining {
let mut discard_buffer = vec![0u8; length];
self.reader
.read_exact(&mut discard_buffer)
.context(IoSnafu)?;
remaining -= length;
} else {
self.leftovers.clear();
self.index = 0;
self.leftovers.resize(length, 0);
self.reader
.read_exact(&mut self.leftovers)
.context(IoSnafu)?;
self.advance(remaining);
remaining = 0;
}
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use std::io::Cursor;
use crate::encoding::PrimitiveValueDecoder;
use super::*;
use proptest::prelude::*;
// Decodes `data` into a buffer of `expected.len()` values and checks the
// result against `expected`.
fn test_helper(data: &[u8], expected: &[i8]) {
let mut reader = ByteRleDecoder::new(Cursor::new(data));
let mut actual = vec![0; expected.len()];
reader.decode(&mut actual).unwrap();
assert_eq!(actual, expected);
}
#[test]
fn reader_test() {
// Run header 0x61 = 97, so 97 + 3 = 100 copies of 0x00.
let data = [0x61u8, 0x00];
let expected = [0; 100];
test_helper(&data, &expected);
// Run header 0x01, so 1 + 3 = 4 copies of 0x01.
let data = [0x01, 0x01];
let expected = [1; 4];
test_helper(&data, &expected);
// Literals header 0xfe, so 0x100 - 0xfe = 2 verbatim bytes follow.
let data = [0xfe, 0x44, 0x45];
let expected = [0x44, 0x45];
test_helper(&data, &expected);
}
#[test]
fn test_skip_values() -> Result<()> {
// Run of 100 sevens: decode 10, skip 5, then decode 5 more from the same run.
let data = [0x61u8, 0x07]; let mut decoder = ByteRleDecoder::new(Cursor::new(&data));
let mut batch = vec![0; 10];
decoder.decode(&mut batch)?;
assert_eq!(batch, vec![7; 10]);
decoder.skip(5)?;
let mut batch = vec![0; 5];
decoder.decode(&mut batch)?;
assert_eq!(batch, vec![7; 5]);
// Skipping the entire run of 100 leaves nothing to decode.
let data = [0x61u8, 0x07]; let mut decoder = ByteRleDecoder::new(Cursor::new(&data));
decoder.skip(100)?;
let mut batch = vec![0; 1];
let result = decoder.decode(&mut batch);
assert!(result.is_err());
// Skipping part of a run before anything was decoded keeps the remainder readable.
let data = [0x61u8, 0x07]; let mut decoder = ByteRleDecoder::new(Cursor::new(&data));
decoder.skip(50)?;
let mut batch = vec![0; 10];
decoder.decode(&mut batch)?;
assert_eq!(batch, vec![7; 10]);
// Skipping all of a 2-value Literals sequence leaves nothing to decode.
let data = [0xfeu8, 0x44, 0x45]; let mut decoder = ByteRleDecoder::new(Cursor::new(&data));
decoder.skip(2)?;
let mut batch = vec![0; 1];
let result = decoder.decode(&mut batch);
assert!(result.is_err());
// Literals header 0xfb = 5 values; skipping 2 leaves the last 3.
let data = [0xfbu8, 0x01, 0x02, 0x03, 0x04, 0x05]; let mut decoder = ByteRleDecoder::new(Cursor::new(&data));
decoder.skip(2)?;
let mut batch = vec![0; 3];
decoder.decode(&mut batch)?;
assert_eq!(batch, vec![3, 4, 5]);
// Three sequences back to back: a run of 10 zeros (0x07, 0x00), 3 literals
// 11, 12, 13 (0xfd ...), and a run of 20 fives (0x11, 0x05). Skipping 12
// crosses the first run and lands two values into the literals.
let data = [
0x07, 0x00, 0xfdu8, 0x0b, 0x0c, 0x0d, 0x11, 0x05, ];
let mut decoder = ByteRleDecoder::new(Cursor::new(&data));
decoder.skip(12)?;
let mut batch = vec![0; 1];
decoder.decode(&mut batch)?;
assert_eq!(batch, vec![13]);
let mut batch = vec![0; 5];
decoder.decode(&mut batch)?;
assert_eq!(batch, vec![5; 5]);
Ok(())
}
// Encodes `values` with `ByteRleEncoder`, decodes the result with
// `ByteRleDecoder`, and returns the decoded values.
fn roundtrip_byte_rle_helper(values: &[i8]) -> Result<Vec<i8>> {
let mut writer = ByteRleEncoder::new();
writer.write_slice(values);
writer.flush();
let buf = writer.take_inner();
let mut cursor = Cursor::new(&buf);
let mut reader = ByteRleDecoder::new(&mut cursor);
let mut actual = vec![0; values.len()];
reader.decode(&mut actual)?;
Ok(actual)
}
// Building block for the biased property test: either one value repeated a
// number of times, or an arbitrary list of values.
#[derive(Debug, Clone)]
enum ByteSequence {
Run(i8, usize),
Literals(Vec<i8>),
}
// Generates runs and literal lists of length 1..140, which goes past both
// MAX_REPEAT_LENGTH (130) and MAX_LITERAL_LENGTH (128).
fn byte_sequence_strategy() -> impl Strategy<Value = ByteSequence> {
prop_oneof![
(any::<i8>(), 1..140_usize).prop_map(|(a, b)| ByteSequence::Run(a, b)),
prop::collection::vec(any::<i8>(), 1..140).prop_map(ByteSequence::Literals)
]
}
// Flattens the generated sequences into one list of values to encode.
fn generate_bytes_from_sequences(sequences: Vec<ByteSequence>) -> Vec<i8> {
let mut bytes = vec![];
for sequence in sequences {
match sequence {
ByteSequence::Run(value, length) => {
bytes.extend(std::iter::repeat(value).take(length))
}
ByteSequence::Literals(literals) => bytes.extend(literals),
}
}
bytes
}
proptest! {
// Round trip over arbitrary values.
#[test]
fn roundtrip_byte_rle_pure_random(values: Vec<i8>) {
let out = roundtrip_byte_rle_helper(&values).unwrap();
prop_assert_eq!(out, values);
}
// Round trip over input built from a mix of runs and literal lists.
#[test]
fn roundtrip_byte_rle_biased(
sequences in prop::collection::vec(byte_sequence_strategy(), 1..200)
) {
let values = generate_bytes_from_sequences(sequences);
let out = roundtrip_byte_rle_helper(&values).unwrap();
prop_assert_eq!(out, values);
}
}
}