use std::{io::Read, marker::PhantomData, ops::RangeInclusive};
use bytes::{BufMut, BytesMut};
use snafu::OptionExt;
use crate::{
encoding::{
rle::GenericRle,
util::{read_u8, try_read_u8},
PrimitiveValueEncoder,
},
error::{OutOfSpecSnafu, Result},
memory::EstimateMemory,
};
use super::{
util::{read_varint_zigzagged, write_varint_zigzagged},
EncodingSign, NInt,
};
/// Shortest run a run group can describe; run headers store `length - MIN_RUN_LENGTH`.
const MIN_RUN_LENGTH: usize = 3;
/// Longest run: the run header is a non-negative `i8`, so at most `i8::MAX` + `MIN_RUN_LENGTH`.
const MAX_RUN_LENGTH: usize = i8::MAX as usize + MIN_RUN_LENGTH;
/// Longest literal group: the literal header is a negated count held in an `i8`.
const MAX_LITERAL_LENGTH: usize = i8::MIN.unsigned_abs() as usize;
/// Per-step deltas a run can carry, i.e. every value of a single signed byte.
// NOTE(review): the name is a misspelling of `DELTA_RANGE`; kept because other code uses it.
const DELAT_RANGE: RangeInclusive<i64> = i8::MIN as i64..=i8::MAX as i64;
/// Kind and size of one encoded group, as described by its header byte(s).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
enum EncodingType {
    /// `length` values where each one differs from the previous by `delta`.
    Run { length: usize, delta: i8 },
    /// `length` values stored one after another.
    Literals { length: usize },
}
impl EncodingType {
fn from_header<R: Read>(reader: &mut R) -> Result<Option<Self>> {
let opt_encoding = match try_read_u8(reader)?.map(|b| b as i8) {
Some(header) if header < 0 => {
let length = header.unsigned_abs() as usize;
Some(Self::Literals { length })
}
Some(header) => {
let length = header as u8 as usize + 3;
let delta = read_u8(reader)? as i8;
Some(Self::Run { length, delta })
}
None => None,
};
Ok(opt_encoding)
}
}
/// Streaming decoder for run-length encoding version 1.
///
/// Groups are decoded one at a time from `reader` into `decoded_ints`;
/// values before `current_head` have already been consumed.
pub struct RleV1Decoder<N: NInt, R: Read, S: EncodingSign> {
    /// Source of encoded bytes.
    reader: R,
    /// Values of the most recently decoded group.
    decoded_ints: Vec<N>,
    /// Index of the first value in `decoded_ints` not yet consumed.
    current_head: usize,
    /// Type-level marker for the sign handling selected by `S`.
    sign: PhantomData<S>,
}

impl<N: NInt, R: Read, S: EncodingSign> RleV1Decoder<N, R, S> {
    /// Creates a decoder over `reader` whose value buffer is pre-sized for the
    /// largest group (a run of `MAX_RUN_LENGTH` values).
    pub fn new(reader: R) -> Self {
        let decoded_ints = Vec::with_capacity(MAX_RUN_LENGTH);
        Self {
            reader,
            decoded_ints,
            current_head: 0,
            sign: PhantomData,
        }
    }
}
/// Reads `length` individually stored values from `reader` and appends them
/// to `out_ints`, in stream order.
///
/// # Errors
/// Stops at, and propagates, the first error from `read_varint_zigzagged`.
fn read_literals<N: NInt, R: Read, S: EncodingSign>(
    reader: &mut R,
    out_ints: &mut Vec<N>,
    length: usize,
) -> Result<()> {
    (0..length).try_for_each(|_| -> Result<()> {
        let literal = read_varint_zigzagged::<_, _, S>(reader)?;
        out_ints.push(literal);
        Ok(())
    })
}
/// Decodes one run group and appends its `length` values to `out_ints`.
///
/// The run's first value is read from `reader` with `read_varint_zigzagged`;
/// each further value is the previous one stepped by `delta` (the delta's
/// magnitude is subtracted when it is negative and added otherwise). A
/// `length` of 0 or 1 yields only the base value; headers parsed by
/// `EncodingType::from_header` always carry at least `MIN_RUN_LENGTH`.
///
/// # Errors
/// Propagates read errors, and returns an out-of-spec error if stepping by
/// `delta` leaves the range of `N`.
fn read_run<N: NInt, R: Read, S: EncodingSign>(
    reader: &mut R,
    out_ints: &mut Vec<N>,
    length: usize,
    delta: i8,
) -> Result<()> {
    let mut value = read_varint_zigzagged::<_, _, S>(reader)?;
    out_ints.push(value);
    // Step by the magnitude so that a delta of -128 is representable as `N`.
    let step = N::from_u8(delta.unsigned_abs());
    for _ in 1..length {
        let next = if delta < 0 {
            value.checked_sub(&step)
        } else {
            value.checked_add(&step)
        };
        value = next.context(OutOfSpecSnafu {
            msg: "over/underflow when decoding RLE v1 run",
        })?;
        out_ints.push(value);
    }
    Ok(())
}
impl<N: NInt, R: Read, S: EncodingSign> GenericRle<N> for RleV1Decoder<N, R, S> {
    /// Marks the next `n` buffered values as consumed.
    fn advance(&mut self, n: usize) {
        self.current_head += n;
    }

    /// Decoded values that have not been consumed yet.
    fn available(&self) -> &[N] {
        &self.decoded_ints[self.current_head..]
    }

    /// Discards the buffer and decodes the next group from the reader into it.
    ///
    /// # Errors
    /// Out-of-spec error when the reader holds no further group header;
    /// otherwise propagates header/value decoding errors.
    fn decode_batch(&mut self) -> Result<()> {
        self.decoded_ints.clear();
        self.current_head = 0;
        let header = EncodingType::from_header(&mut self.reader)?;
        let reader = &mut self.reader;
        let out = &mut self.decoded_ints;
        match header {
            Some(EncodingType::Run { length, delta }) => {
                read_run::<_, _, S>(reader, out, length, delta)
            }
            Some(EncodingType::Literals { length }) => read_literals::<_, _, S>(reader, out, length),
            None => OutOfSpecSnafu {
                msg: "not enough values to decode",
            }
            .fail(),
        }
    }

    /// Skips `n` values: first from the buffer, then from the reader.
    ///
    /// Groups that fall entirely inside the skipped range are read past without
    /// being stored (a run only needs its base value consumed, because
    /// `EncodingType::from_header` has already read its delta byte). The group
    /// that straddles the end of the range is decoded into the buffer and its
    /// leading values are marked consumed.
    ///
    /// # Errors
    /// Out-of-spec error if the reader runs out of groups before `n` values
    /// were skipped; otherwise propagates decoding errors.
    fn skip_values(&mut self, n: usize) -> Result<()> {
        let buffered = self.available().len();
        if n <= buffered {
            self.advance(n);
            return Ok(());
        }
        self.advance(buffered);
        let mut to_skip = n - buffered;

        while to_skip > 0 {
            let encoding = match EncodingType::from_header(&mut self.reader)? {
                Some(encoding) => encoding,
                None => {
                    return OutOfSpecSnafu {
                        msg: "not enough values to skip in RLE v1",
                    }
                    .fail();
                }
            };
            match encoding {
                EncodingType::Literals { length } if length <= to_skip => {
                    for _ in 0..length {
                        read_varint_zigzagged::<N, _, S>(&mut self.reader)?;
                    }
                    to_skip -= length;
                }
                EncodingType::Run { length, .. } if length <= to_skip => {
                    read_varint_zigzagged::<N, _, S>(&mut self.reader)?;
                    to_skip -= length;
                }
                straddling => {
                    self.decoded_ints.clear();
                    self.current_head = 0;
                    match straddling {
                        EncodingType::Literals { length } => read_literals::<_, _, S>(
                            &mut self.reader,
                            &mut self.decoded_ints,
                            length,
                        )?,
                        EncodingType::Run { length, delta } => read_run::<_, _, S>(
                            &mut self.reader,
                            &mut self.decoded_ints,
                            length,
                            delta,
                        )?,
                    }
                    self.advance(to_skip);
                    to_skip = 0;
                }
            }
        }
        Ok(())
    }
}
/// State of the encoder's grouping state machine.
#[derive(Debug, Clone, Eq, PartialEq, Default)]
enum RleV1EncodingState<N: NInt> {
    /// No values are pending.
    #[default]
    Empty,
    /// Pending values are held in the encoder's `buffer` and will be written as literals
    /// unless a run is detected at the end of the buffer.
    Literal,
    /// A run in progress.
    Run {
        /// First value of the run.
        value: N,
        /// Step between consecutive values of the run.
        delta: i8,
        /// Number of values accepted into the run so far.
        length: usize,
    },
}
/// Streaming encoder for run-length encoding version 1.
///
/// Incoming values drive a small state machine (`state`); completed literal
/// and run groups are appended to `writer`.
pub struct RleV1Encoder<N: NInt, S: EncodingSign> {
    /// Encoded output accumulated so far.
    writer: BytesMut,
    /// Current state of the grouping state machine.
    state: RleV1EncodingState<N>,
    /// Values pending in the current literal group (only meaningful in the `Literal` state).
    buffer: Vec<N>,
    /// Type-level marker for the sign handling selected by `S`.
    sign: PhantomData<S>,
}
impl<N: NInt, S: EncodingSign> RleV1Encoder<N, S> {
    /// Feeds one value into the grouping state machine, appending any group
    /// that becomes complete to `self.writer`.
    ///
    /// - `Empty`: `value` starts a new literal group.
    /// - `Literal`: `value` is buffered. If the last `MIN_RUN_LENGTH` buffered
    ///   values form an arithmetic sequence with a step in `DELAT_RANGE`, the
    ///   values before them are written as literals and a run starts. Otherwise
    ///   a buffer of `MAX_LITERAL_LENGTH` values is written as literals.
    /// - `Run`: `value` extends the run if it is the next element of the
    ///   sequence; a run reaching `MAX_RUN_LENGTH` is written immediately. A
    ///   non-matching `value` closes the run and starts a new literal group.
    ///
    /// All differences are computed with checked `i64` arithmetic, so values
    /// near the limits of the type neither overflow (panicking in debug builds)
    /// nor wrap into a small delta that the decoder cannot reproduce.
    // NOTE(review): assumes `as_i64` is lossless for every `N` in use — confirm.
    fn process_value(&mut self, value: N) {
        match &mut self.state {
            RleV1EncodingState::Empty => {
                self.buffer.clear();
                self.buffer.push(value);
                self.state = RleV1EncodingState::Literal;
            }
            RleV1EncodingState::Literal => {
                self.buffer.push(value);
                let buf = &self.buffer;
                let length = buf.len();
                if let Some(delta) = Self::trailing_run_delta(buf) {
                    if length > MIN_RUN_LENGTH {
                        write_literals::<_, S>(&mut self.writer, &buf[..length - MIN_RUN_LENGTH]);
                    }
                    self.state = RleV1EncodingState::Run {
                        value: buf[length - MIN_RUN_LENGTH],
                        delta,
                        length: MIN_RUN_LENGTH,
                    };
                } else if length == MAX_LITERAL_LENGTH {
                    write_literals::<_, S>(&mut self.writer, buf);
                    self.state = RleV1EncodingState::Empty;
                }
            }
            RleV1EncodingState::Run {
                value: run_value,
                delta,
                length,
            } => {
                // The next run element is `run_value + delta * length`; if that
                // overflows i64 no input value can continue the run.
                let expected = run_value
                    .as_i64()
                    .checked_add(i64::from(*delta) * *length as i64);
                if expected == Some(value.as_i64()) {
                    *length += 1;
                    if *length == MAX_RUN_LENGTH {
                        write_run::<_, S>(&mut self.writer, *run_value, *delta, *length);
                        self.state = RleV1EncodingState::Empty;
                    }
                } else {
                    write_run::<_, S>(&mut self.writer, *run_value, *delta, *length);
                    self.buffer.clear();
                    self.buffer.push(value);
                    self.state = RleV1EncodingState::Literal;
                }
            }
        }
    }

    /// Returns the common step of the last `MIN_RUN_LENGTH` values of `buf`
    /// when they form an arithmetic sequence whose step lies in `DELAT_RANGE`.
    ///
    /// Returns `None` when `buf` is shorter than `MIN_RUN_LENGTH`, or when any
    /// adjacent difference does not fit in an `i64`.
    fn trailing_run_delta(buf: &[N]) -> Option<i8> {
        let start = buf.len().checked_sub(MIN_RUN_LENGTH)?;
        let tail = &buf[start..];
        let step = |pair: &[N]| pair[1].as_i64().checked_sub(pair[0].as_i64());
        let delta = step(&tail[..2])?;
        if DELAT_RANGE.contains(&delta) && tail.windows(2).all(|pair| step(pair) == Some(delta)) {
            Some(delta as i8)
        } else {
            None
        }
    }

    /// Writes out whatever group is pending and resets the state to `Empty`.
    fn flush(&mut self) {
        let state = std::mem::take(&mut self.state);
        match state {
            RleV1EncodingState::Empty => {}
            RleV1EncodingState::Literal => {
                write_literals::<_, S>(&mut self.writer, &self.buffer);
            }
            RleV1EncodingState::Run {
                value,
                delta,
                length,
            } => {
                write_run::<_, S>(&mut self.writer, value, delta, length);
            }
        }
    }
}
fn write_run<N: NInt, S: EncodingSign>(writer: &mut BytesMut, value: N, delta: i8, length: usize) {
writer.put_u8(length as u8 - 3);
writer.put_u8(delta as u8);
write_varint_zigzagged::<_, S>(writer, value);
}
fn write_literals<N: NInt, S: EncodingSign>(writer: &mut BytesMut, buffer: &[N]) {
writer.put_u8(-(buffer.len() as i8) as u8);
for literal in buffer {
write_varint_zigzagged::<_, S>(writer, *literal);
}
}
impl<N: NInt, S: EncodingSign> EstimateMemory for RleV1Encoder<N, S> {
fn estimate_memory_size(&self) -> usize {
self.writer.len()
}
}
impl<N: NInt, S: EncodingSign> PrimitiveValueEncoder<N> for RleV1Encoder<N, S> {
fn new() -> Self {
Self {
writer: BytesMut::new(),
state: Default::default(),
buffer: Vec::with_capacity(MAX_LITERAL_LENGTH),
sign: Default::default(),
}
}
fn write_one(&mut self, value: N) {
self.process_value(value);
}
fn take_inner(&mut self) -> bytes::Bytes {
self.flush();
std::mem::take(&mut self.writer).into()
}
}
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use crate::encoding::{integer::UnsignedEncoding, PrimitiveValueDecoder};

    use super::*;

    /// Builds an `i64` decoder (unsigned sign handling) over an in-memory byte slice.
    fn decoder_over(bytes: &[u8]) -> RleV1Decoder<i64, Cursor<&[u8]>, UnsignedEncoding> {
        RleV1Decoder::new(Cursor::new(bytes))
    }

    /// Checks that `original` encodes to exactly `encoded` and that `encoded`
    /// decodes back to `original`.
    fn test_helper(original: &[i64], encoded: &[u8]) {
        let mut encoder = RleV1Encoder::<i64, UnsignedEncoding>::new();
        encoder.write_slice(original);
        encoder.flush();
        assert_eq!(encoder.take_inner(), encoded);

        let mut decoded = vec![0; original.len()];
        decoder_over(encoded).decode(&mut decoded).unwrap();
        assert_eq!(decoded, original);
    }

    #[test]
    fn test_run() -> Result<()> {
        test_helper(&[7; 100], &[0x61, 0x00, 0x07]);

        let descending: Vec<i64> = (1..=100).rev().collect();
        test_helper(&descending, &[0x61, 0xff, 0x64]);

        // 150 values exceed MAX_RUN_LENGTH, so they split into runs of 130 and 20.
        let descending: Vec<i64> = (1..=150).rev().collect();
        test_helper(&descending, &[0x7f, 0xff, 0x96, 0x01, 0x11, 0xff, 0x14]);

        test_helper(
            &[2, 4, 6, 8, 1, 3, 5, 7, 255],
            &[0x01, 0x02, 0x02, 0x01, 0x02, 0x01, 0xff, 0xff, 0x01],
        );
        Ok(())
    }

    #[test]
    fn test_literal() -> Result<()> {
        test_helper(&[2, 3, 6, 7, 11], &[0xfb, 0x02, 0x03, 0x06, 0x07, 0x0b]);

        test_helper(
            &[2, 3, 6, 7, 11, 1, 2, 3, 0, 256],
            &[
                0xfb, 0x02, 0x03, 0x06, 0x07, 0x0b, 0x00, 0x01, 0x01, 0xfe, 0x00, 0x80, 0x02,
            ],
        );
        Ok(())
    }

    #[test]
    fn test_skip_values() -> Result<()> {
        // One run of one hundred 7s.
        let sevens: &[u8] = &[0x61, 0x00, 0x07];
        // One literal group: [2, 3, 6, 7, 11].
        let literals: &[u8] = &[0xfb, 0x02, 0x03, 0x06, 0x07, 0x0b];
        // 150 down to 1, as runs of 130 and 20 values.
        let descending: &[u8] = &[0x7f, 0xff, 0x96, 0x01, 0x11, 0xff, 0x14];

        // Skip after a partial decode.
        let mut decoder = decoder_over(sevens);
        let mut first = vec![0; 10];
        decoder.decode(&mut first)?;
        assert_eq!(first, vec![7; 10]);
        decoder.skip(5)?;
        let mut second = vec![0; 5];
        decoder.decode(&mut second)?;
        assert_eq!(second, vec![7; 5]);

        // Skipping the whole run leaves nothing to decode.
        let mut decoder = decoder_over(sevens);
        decoder.skip(100)?;
        let mut batch = vec![0; 1];
        assert!(decoder.decode(&mut batch).is_err());

        // Skipping into the middle of a run.
        let mut decoder = decoder_over(sevens);
        decoder.skip(50)?;
        let mut batch = vec![0; 10];
        decoder.decode(&mut batch)?;
        assert_eq!(batch, vec![7; 10]);

        // Skipping the whole literal group leaves nothing to decode.
        let mut decoder = decoder_over(literals);
        decoder.skip(5)?;
        let mut batch = vec![0; 1];
        assert!(decoder.decode(&mut batch).is_err());

        // Skipping into the middle of a literal group.
        let mut decoder = decoder_over(literals);
        decoder.skip(2)?;
        let mut batch = vec![0; 3];
        decoder.decode(&mut batch)?;
        assert_eq!(batch, vec![6, 7, 11]);

        // Skipping into the middle of the first of two runs.
        let mut decoder = decoder_over(descending);
        decoder.skip(100)?;
        let mut batch = vec![0; 10];
        decoder.decode(&mut batch)?;
        assert_eq!(batch, vec![50, 49, 48, 47, 46, 45, 44, 43, 42, 41]);
        Ok(())
    }
}