use crate::{
cumulative_encoded_len,
decode::{decode_num_scalar, DecodeQuadSink, Decoder, SliceDecodeSink, WriteQuadToSlice},
encoded_shape,
scalar::Scalar,
EncodedShape,
};
/// Incremental decoder over an encoded buffer laid out as all control bytes followed by the
/// encoded numbers.
///
/// Tracks decoding progress so callers can decode in chunks, or skip whole quads, across
/// multiple calls.
#[derive(Debug)]
pub struct DecodeCursor<'a> {
    /// Control bytes for the whole input: one byte per quad (group of four numbers), followed
    /// by one more byte for the trailing partial quad when there is one.
    control_bytes: &'a [u8],
    /// The encoded numbers that follow the control bytes in the input.
    encoded_nums: &'a [u8],
    /// Layout of the encoding for `total_nums` numbers, as computed by `encoded_shape`.
    encoded_shape: EncodedShape,
    /// How many numbers the input encodes.
    total_nums: usize,
    /// How many numbers have been decoded or skipped so far.
    nums_decoded: usize,
    /// How many complete-quad control bytes have been consumed so far; the trailing partial
    /// quad's control byte is never counted here.
    control_bytes_read: usize,
    /// How many bytes of `encoded_nums` have been consumed so far.
    encoded_bytes_read: usize,
}
impl<'a> DecodeCursor<'a> {
    /// Creates a cursor over `input`, which holds the encoding of exactly `count` numbers.
    ///
    /// The first `encoded_shape(count).control_bytes_len` bytes of `input` are treated as
    /// control bytes; everything after them is treated as encoded numbers.
    ///
    /// # Panics
    ///
    /// Panics if `input` is shorter than the control-byte section for `count` numbers.
    pub fn new(input: &'a [u8], count: usize) -> DecodeCursor<'a> {
        let shape = encoded_shape(count);
        DecodeCursor {
            control_bytes: &input[0..shape.control_bytes_len],
            encoded_nums: &input[shape.control_bytes_len..],
            encoded_shape: shape,
            total_nums: count,
            nums_decoded: 0,
            control_bytes_read: 0,
            encoded_bytes_read: 0,
        }
    }

    /// Skips `to_skip` numbers without decoding them.
    ///
    /// Only whole quads (four numbers, one control byte each) can be skipped, and only within
    /// the complete quads. The trailing partial quad cannot be skipped.
    ///
    /// # Panics
    ///
    /// Panics if `to_skip` is not a multiple of 4, or if skipping would move past the last
    /// complete control byte.
    pub fn skip(&mut self, to_skip: usize) {
        assert_eq!(to_skip % 4, 0, "Must be a multiple of 4");
        let control_bytes_to_skip = to_skip / 4;
        let skip_end = self.control_bytes_read + control_bytes_to_skip;
        assert!(
            skip_end <= self.encoded_shape.complete_control_bytes_len,
            "Can't skip past the end of complete control bytes"
        );
        // The skipped quads' control bytes tell us how many encoded bytes to jump over.
        let skipped_encoded_len =
            cumulative_encoded_len(&self.control_bytes[self.control_bytes_read..skip_end]);
        self.control_bytes_read = skip_end;
        self.encoded_bytes_read += skipped_encoded_len;
        self.nums_decoded += to_skip;
    }

    /// Decodes up to `output.len()` numbers into `output`, returning how many were written.
    ///
    /// Writing starts at `output[0]`. See [`DecodeCursor::decode_sink`] for when a call may
    /// produce fewer numbers than requested.
    pub fn decode_slice<D: Decoder + WriteQuadToSlice>(&mut self, output: &mut [u32]) -> usize {
        let output_len = output.len();
        let mut sink = SliceDecodeSink::new(output);
        self.decode_sink::<D, SliceDecodeSink>(&mut sink, output_len)
    }

    /// Decodes up to `max_numbers_to_decode` numbers into `sink`, returning how many were
    /// decoded by this call.
    ///
    /// Decoding proceeds in three steps:
    ///
    /// 1. `D` decodes complete quads first.
    /// 2. Any complete quads `D` did not get to are decoded with [`Scalar`].
    /// 3. The trailing partial quad is decoded all-or-nothing. This happens only once every
    ///    complete quad has been consumed and the remaining budget can hold all of it.
    ///
    /// Because of these rules, a call can return fewer than `max_numbers_to_decode` even though
    /// more numbers remain.
    ///
    /// Positions reported to `sink` count from 0 at the start of each call.
    pub fn decode_sink<D, S>(&mut self, sink: &mut S, max_numbers_to_decode: usize) -> usize
    where
        D: Decoder,
        S: DecodeQuadSink<D> + DecodeQuadSink<Scalar>,
    {
        let start_nums_decoded = self.nums_decoded;
        let complete_control_bytes_to_decode = max_numbers_to_decode / 4;
        let complete_control_bytes_len = self.encoded_shape.complete_control_bytes_len;

        // Primary decoder: handles as many complete quads as it chooses to.
        let (primary_nums_decoded, primary_bytes_read) = D::decode_quads(
            &self.control_bytes[self.control_bytes_read..complete_control_bytes_len],
            &self.encoded_nums[self.encoded_bytes_read..],
            complete_control_bytes_to_decode,
            0,
            sink,
        );
        self.nums_decoded += primary_nums_decoded;
        self.encoded_bytes_read += primary_bytes_read;
        self.control_bytes_read += primary_nums_decoded / 4;

        // `D` may decode fewer quads than requested; finish the rest with the scalar decoder,
        // continuing at the output position where `D` stopped.
        let (scalar_nums_decoded, scalar_bytes_read) = Scalar::decode_quads(
            &self.control_bytes[self.control_bytes_read..complete_control_bytes_len],
            &self.encoded_nums[self.encoded_bytes_read..],
            complete_control_bytes_to_decode - primary_nums_decoded / 4,
            primary_nums_decoded,
            sink,
        );
        self.nums_decoded += scalar_nums_decoded;
        self.encoded_bytes_read += scalar_bytes_read;
        self.control_bytes_read += scalar_nums_decoded / 4;

        let complete_quad_nums_decoded = primary_nums_decoded + scalar_nums_decoded;
        let leftover_numbers = self.encoded_shape.leftover_numbers;

        // The trailing partial quad is only decoded once all complete quads are done, it fits in
        // the remaining budget, and it has not already been decoded by an earlier call.
        if max_numbers_to_decode - complete_quad_nums_decoded >= leftover_numbers
            && self.control_bytes_read == complete_control_bytes_len
            && leftover_numbers > 0
            && self.nums_decoded < self.total_nums
        {
            let control_byte = self.control_bytes[complete_control_bytes_len];
            for i in 0..leftover_numbers {
                // Each number has a 2-bit code (byte length - 1), lowest bits first.
                let len = ((control_byte >> (i * 2)) & 0x03) as usize + 1;
                sink.on_number(
                    decode_num_scalar(len, &self.encoded_nums[self.encoded_bytes_read..]),
                    complete_quad_nums_decoded + i,
                );
                self.nums_decoded += 1;
                self.encoded_bytes_read += len;
            }
        }

        self.nums_decoded - start_nums_decoded
    }

    /// Returns how many bytes of the original input have been consumed.
    ///
    /// The whole control-byte section is always counted, plus the encoded-number bytes read so
    /// far.
    pub fn input_consumed(&self) -> usize {
        self.encoded_shape.control_bytes_len + self.encoded_bytes_read
    }

    /// Returns `true` while fewer than `count` numbers (as passed to [`DecodeCursor::new`])
    /// have been decoded or skipped.
    pub fn has_more(&self) -> bool {
        self.nums_decoded < self.total_nums
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::encode::encode;

    /// Encodes `nums` with the scalar encoder and returns the buffer trimmed to the encoded
    /// length.
    fn encode_scalar(nums: &[u32]) -> Vec<u8> {
        // Up to 4 data bytes per number plus a quarter control byte each; 5 bytes per number
        // leaves room to spare.
        let mut encoded = vec![0; nums.len() * 5];
        let encoded_len = encode::<Scalar>(nums, &mut encoded);
        encoded.truncate(encoded_len);
        encoded
    }

    #[test]
    #[should_panic(expected = "Must be a multiple of 4")]
    fn skip_panics_on_not_multiple_of_4() {
        DecodeCursor::new(&[], 0).skip(3)
    }

    #[test]
    #[should_panic(expected = "Can't skip past the end of complete control bytes")]
    fn skip_panics_on_exceeding_full_quads() {
        let nums: Vec<u32> = (0..100).collect();
        let encoded = encode_scalar(&nums);
        DecodeCursor::new(&encoded, nums.len()).skip(104);
    }

    #[test]
    fn skip_entire_input_is_done() {
        let nums: Vec<u32> = (0..100).collect();
        let encoded = encode_scalar(&nums);
        let mut cursor = DecodeCursor::new(&encoded, nums.len());
        assert!(cursor.has_more());
        cursor.skip(100);
        assert!(!cursor.has_more());
        let mut decoded = vec![0_u32; 100];
        assert_eq!(0, cursor.decode_slice::<Scalar>(&mut decoded));
    }
}