use crate::chunk_header::{ChunkHeader, ChunkType};
use crate::compression::{CompressOptions, CompressionType, decompress_data};
use crate::error::RiegeliError;
use crate::varint::{decode_u64, encode_u64};
/// A chunk as produced by [`SimpleChunkEncoder::encode`] and consumed by
/// [`SimpleChunkDecoder::new`]: a header together with the data bytes it
/// describes.
#[derive(Debug, Clone)]
pub struct Chunk {
/// Header built from `data` (see `ChunkHeader::from_parts` in `encode`); the
/// decoder uses it to check `data` and to learn the record count.
pub header: ChunkHeader,
/// Encoded chunk data. The first byte is the compression-type byte; the
/// rest is the sizes and values payload in the layout chosen by that byte.
pub data: Vec<u8>,
}
/// Accumulates records in memory and serializes them into a single simple
/// [`Chunk`] when `encode` is called.
pub struct SimpleChunkEncoder {
// Records in insertion order; each `add_record` call appends one owned copy.
records: Vec<Vec<u8>>,
// Selects the payload layout and compressor used by `encode`.
compression: CompressionType,
// Options handed to the Brotli and Zstd compressors; the Snappy and
// uncompressed paths in `encode` do not read them.
compress_opts: CompressOptions,
}
impl SimpleChunkEncoder {
pub fn new() -> Self {
Self {
records: Vec::new(),
compression: CompressionType::None,
compress_opts: CompressOptions::default(),
}
}
pub fn with_compression(compression: CompressionType) -> Self {
Self {
records: Vec::new(),
compression,
compress_opts: CompressOptions::default(),
}
}
pub fn with_options(compression: CompressionType, compress_opts: CompressOptions) -> Self {
Self {
records: Vec::new(),
compression,
compress_opts,
}
}
pub fn add_record(&mut self, data: &[u8]) {
self.records.push(data.to_vec());
}
pub fn encode(self) -> Result<Chunk, RiegeliError> {
let num_records = self.records.len() as u64;
let decoded_data_size: u64 = self.records.iter().map(|r| r.len() as u64).sum();
let data = match self.compression {
CompressionType::None => {
let mut sizes_section: Vec<u8> = Vec::new();
for record in &self.records {
sizes_section.extend_from_slice(&encode_u64(record.len() as u64));
}
let mut data: Vec<u8> = Vec::new();
data.push(0x00);
data.extend_from_slice(&encode_u64(sizes_section.len() as u64));
data.extend_from_slice(&sizes_section);
for record in &self.records {
data.extend_from_slice(record);
}
data
}
CompressionType::Brotli => {
#[cfg(feature = "brotli")]
{
let opts = self.compress_opts;
encode_compressed(&self.records, CompressionType::Brotli, |b| {
crate::compression::compress_brotli(b, opts)
})?
}
#[cfg(not(feature = "brotli"))]
{
return Err(RiegeliError::UnsupportedCompression(
CompressionType::Brotli as u8,
));
}
}
CompressionType::Zstd => {
#[cfg(feature = "zstd")]
{
let opts = self.compress_opts;
encode_compressed(&self.records, CompressionType::Zstd, |b| {
crate::compression::compress_zstd(b, opts)
})?
}
#[cfg(not(feature = "zstd"))]
{
return Err(RiegeliError::UnsupportedCompression(
CompressionType::Zstd as u8,
));
}
}
CompressionType::Snappy => {
#[cfg(feature = "snappy")]
{
encode_compressed(&self.records, CompressionType::Snappy, |b| {
crate::compression::compress_snappy(b)
})?
}
#[cfg(not(feature = "snappy"))]
{
return Err(RiegeliError::UnsupportedCompression(
CompressionType::Snappy as u8,
));
}
}
};
let header =
ChunkHeader::from_parts(&data, ChunkType::Simple, num_records, decoded_data_size);
Ok(Chunk { header, data })
}
}
/// Builds the data of a compressed simple chunk.
///
/// Layout produced, in order:
/// 1. the compression-type byte (`compression as u8`);
/// 2. varint: length of the sizes blob, i.e. item 3 plus item 4;
/// 3. varint: uncompressed length of the sizes section;
/// 4. the compressed sizes section (one varint length per record);
/// 5. varint: uncompressed length of the concatenated record bytes;
/// 6. the compressed record bytes.
///
/// `compress` is invoked twice, first on the sizes section and then on the
/// values section; the first error it returns is propagated.
fn encode_compressed<F>(
    records: &[Vec<u8>],
    compression: CompressionType,
    compress: F,
) -> Result<Vec<u8>, RiegeliError>
where
    F: Fn(&[u8]) -> Result<Vec<u8>, RiegeliError>,
{
    use crate::varint::length_varint_u64;

    // One varint per record length, in record order.
    let mut sizes_section: Vec<u8> = Vec::new();
    for len in records.iter().map(Vec::len) {
        sizes_section.extend_from_slice(&encode_u64(len as u64));
    }
    // All record bytes back to back.
    let values_section: Vec<u8> = records.concat();

    let compressed_sizes = compress(&sizes_section)?;
    let compressed_values = compress(&values_section)?;

    let sizes_len = sizes_section.len() as u64;
    let values_len = values_section.len() as u64;
    // The blob length covers the uncompressed-length varint as well as the
    // compressed bytes that follow it.
    let sizes_blob_len = length_varint_u64(sizes_len) as u64 + compressed_sizes.len() as u64;

    let mut data: Vec<u8> = vec![compression as u8];
    data.extend_from_slice(&encode_u64(sizes_blob_len));
    data.extend_from_slice(&encode_u64(sizes_len));
    data.extend_from_slice(&compressed_sizes);
    data.extend_from_slice(&encode_u64(values_len));
    data.extend_from_slice(&compressed_values);
    Ok(data)
}
impl Default for SimpleChunkEncoder {
fn default() -> Self {
Self::new()
}
}
/// Sequential reader over the records of a decoded simple chunk.
///
/// All record bytes are held in one buffer; individual records are handed out
/// as copies by `read_record`.
#[derive(Debug)]
pub struct SimpleChunkDecoder {
// `(offset, len)` of each record within `values`, in record order.
record_ranges: Vec<(usize, usize)>,
// Concatenated (already decompressed) bytes of all records.
values: Vec<u8>,
// Index into `record_ranges` of the record the next `read_record` returns.
next_record: usize,
}
impl SimpleChunkDecoder {
    /// Validates `chunk` and decodes its record boundaries and values.
    ///
    /// The first data byte selects the layout: `0x00` is uncompressed, `b'b'`
    /// Brotli, `b'z'` Zstd and `b's'` Snappy. The rest of the data is handed
    /// to the matching decode routine together with the record count taken
    /// from the header.
    ///
    /// Note that the header's decoded data size is not checked here.
    ///
    /// # Errors
    ///
    /// * `RiegeliError::DataHashMismatch` if the header rejects the data
    ///   (`is_data_valid` returns `false`).
    /// * `RiegeliError::MalformedData` if the data is empty, the record count
    ///   does not fit in `usize`, or the payload is inconsistent.
    /// * `RiegeliError::UnsupportedCompression` for an unknown compression
    ///   byte, or a known one whose cargo feature is disabled.
    pub fn new(chunk: Chunk) -> Result<Self, RiegeliError> {
        use std::convert::TryFrom;

        if !chunk.header.is_data_valid(&chunk.data) {
            return Err(RiegeliError::DataHashMismatch);
        }

        // The count comes from the (untrusted) header; refuse to silently
        // truncate it on targets where `usize` is narrower than the field.
        let declared_records = chunk.header.num_records();
        let num_records = usize::try_from(declared_records).map_err(|_| {
            RiegeliError::MalformedData(format!(
                "record count {declared_records} does not fit in usize"
            ))
        })?;

        let (compression_byte, payload) = match chunk.data.split_first() {
            Some((&byte, rest)) => (byte, rest),
            None => {
                return Err(RiegeliError::MalformedData(
                    "chunk data is empty".to_string(),
                ));
            }
        };

        match compression_byte {
            0x00 => decode_uncompressed(payload, num_records),
            b'b' => {
                #[cfg(feature = "brotli")]
                {
                    decode_compressed(payload, num_records, CompressionType::Brotli)
                }
                #[cfg(not(feature = "brotli"))]
                {
                    Err(RiegeliError::UnsupportedCompression(b'b'))
                }
            }
            b'z' => {
                #[cfg(feature = "zstd")]
                {
                    decode_compressed(payload, num_records, CompressionType::Zstd)
                }
                #[cfg(not(feature = "zstd"))]
                {
                    Err(RiegeliError::UnsupportedCompression(b'z'))
                }
            }
            b's' => {
                #[cfg(feature = "snappy")]
                {
                    decode_compressed(payload, num_records, CompressionType::Snappy)
                }
                #[cfg(not(feature = "snappy"))]
                {
                    Err(RiegeliError::UnsupportedCompression(b's'))
                }
            }
            other => Err(RiegeliError::UnsupportedCompression(other)),
        }
    }

    /// Returns a copy of the next record, or `Ok(None)` once every record has
    /// been read. Calling it again after exhaustion keeps returning
    /// `Ok(None)`.
    pub fn read_record(&mut self) -> Result<Option<Vec<u8>>, RiegeliError> {
        match self.record_ranges.get(self.next_record) {
            None => Ok(None),
            Some(&(offset, len)) => {
                self.next_record += 1;
                Ok(Some(self.values[offset..offset + len].to_vec()))
            }
        }
    }
}
/// Decodes the payload of an uncompressed simple chunk, i.e. everything after
/// the `0x00` compression byte.
///
/// Expected layout: `varint(sizes section length) | sizes section | values`,
/// where the sizes section holds one varint length per record and the values
/// are the record bytes back to back. Bytes after the last record are
/// ignored.
///
/// `payload` and `num_records` are untrusted, so every offset computation is
/// overflow-checked and the up-front allocation is bounded by the payload.
///
/// # Errors
///
/// Returns `RiegeliError::MalformedData` when a varint cannot be decoded, the
/// sizes section has fewer entries than `num_records`, a length overflows
/// `usize`, or either section extends past the end of `payload`.
fn decode_uncompressed(
    payload: &[u8],
    num_records: usize,
) -> Result<SimpleChunkDecoder, RiegeliError> {
    use std::convert::TryFrom;

    let (sizes_byte_len, sizes_start) = decode_u64(payload).map_err(|e| {
        RiegeliError::MalformedData(format!("failed to read sizes_byte_length: {e}"))
    })?;

    // End of the sizes section == start of the values. A length that does not
    // fit in `usize`, overflows the offset, or runs past the payload is
    // reported as truncation instead of panicking or wrapping.
    let values_start = usize::try_from(sizes_byte_len)
        .ok()
        .and_then(|len| sizes_start.checked_add(len))
        .filter(|&end| end <= payload.len())
        .ok_or_else(|| {
            RiegeliError::MalformedData(format!(
                "sizes section truncated: need {sizes_byte_len} bytes starting at offset {sizes_start}, \
                 but payload is only {} bytes",
                payload.len()
            ))
        })?;
    let sizes_data = &payload[sizes_start..values_start];
    let available_values = payload.len() - values_start;

    // Each record size takes at least one byte of the sizes section, so the
    // section length bounds how many entries can exist. Capping the
    // reservation keeps a bogus header count from forcing a huge allocation.
    let mut record_ranges: Vec<(usize, usize)> =
        Vec::with_capacity(num_records.min(sizes_data.len()));
    let mut pos = 0usize;
    let mut offset = 0usize;
    for i in 0..num_records {
        if pos >= sizes_data.len() {
            return Err(RiegeliError::MalformedData(format!(
                "unexpected end of sizes section at record {i}"
            )));
        }
        let (size, consumed) = decode_u64(&sizes_data[pos..]).map_err(|e| {
            RiegeliError::MalformedData(format!("varint decode error at record {i}: {e}"))
        })?;
        pos += consumed;

        let len = usize::try_from(size).map_err(|_| {
            RiegeliError::MalformedData(format!(
                "record {i} size {size} does not fit in usize"
            ))
        })?;
        let end = offset.checked_add(len).ok_or_else(|| {
            RiegeliError::MalformedData(format!(
                "record {i} size {size} overflows the total values length"
            ))
        })?;
        record_ranges.push((offset, len));
        offset = end;
    }

    let total_values_len = offset;
    if total_values_len > available_values {
        return Err(RiegeliError::MalformedData(format!(
            "values section truncated: need {total_values_len} bytes but only {} available",
            available_values
        )));
    }

    // Cannot overflow: `total_values_len <= payload.len() - values_start`.
    let values = payload[values_start..values_start + total_values_len].to_vec();
    Ok(SimpleChunkDecoder {
        record_ranges,
        values,
        next_record: 0,
    })
}
/// Decodes the payload of a compressed simple chunk, i.e. everything after
/// the compression-type byte.
///
/// Expected layout:
/// `varint(sizes blob length) | sizes blob | values blob`, where each blob is
/// `varint(uncompressed length) | compressed bytes`. The decompressed sizes
/// hold one varint length per record; the decompressed values are the record
/// bytes back to back.
///
/// `payload` and `num_records` are untrusted, so offsets are overflow-checked,
/// both declared uncompressed lengths are verified against what the
/// decompressor produced, and the up-front allocation is bounded by the
/// decompressed sizes.
///
/// # Errors
///
/// Returns `RiegeliError::MalformedData` when the payload is empty, a varint
/// cannot be decoded, the sizes blob extends past the payload, a declared
/// uncompressed length disagrees with the decompressed data, there are fewer
/// sizes than `num_records`, or the sizes do not add up to the decompressed
/// values length. Errors from `decompress_data` are propagated unchanged.
fn decode_compressed(
    payload: &[u8],
    num_records: usize,
    compression: CompressionType,
) -> Result<SimpleChunkDecoder, RiegeliError> {
    use std::convert::TryFrom;

    if payload.is_empty() {
        return Err(RiegeliError::MalformedData(
            "compressed payload is empty".to_string(),
        ));
    }

    let (sizes_blob_len, blob_start) = decode_u64(payload)
        .map_err(|e| RiegeliError::MalformedData(format!("failed to read sizes_blob_len: {e}")))?;

    // End of the sizes blob == start of the values blob; checked so a hostile
    // length cannot overflow the offset or index past the payload.
    let blob_end = usize::try_from(sizes_blob_len)
        .ok()
        .and_then(|len| blob_start.checked_add(len))
        .filter(|&end| end <= payload.len())
        .ok_or_else(|| {
            RiegeliError::MalformedData(format!(
                "sizes blob truncated: need {sizes_blob_len} bytes at offset {blob_start}, \
                 payload is {} bytes",
                payload.len()
            ))
        })?;

    // Sizes blob: varint(uncompressed length) followed by compressed sizes.
    let sizes_blob = &payload[blob_start..blob_end];
    let (uncompressed_sizes_len, sizes_prefix_len) = decode_u64(sizes_blob).map_err(|e| {
        RiegeliError::MalformedData(format!("failed to read uncompressed_sizes_len: {e}"))
    })?;
    let sizes_bytes = decompress_data(&sizes_blob[sizes_prefix_len..], compression)?;
    if sizes_bytes.len() as u64 != uncompressed_sizes_len {
        return Err(RiegeliError::MalformedData(format!(
            "decompressed sizes length {} != expected {}",
            sizes_bytes.len(),
            uncompressed_sizes_len
        )));
    }

    // Values blob: same framing; verify its declared length the same way the
    // sizes blob is verified.
    let values_blob = &payload[blob_end..];
    let (uncompressed_values_len, values_prefix_len) = decode_u64(values_blob).map_err(|e| {
        RiegeliError::MalformedData(format!("failed to read uncompressed_values_len: {e}"))
    })?;
    let values_bytes = decompress_data(&values_blob[values_prefix_len..], compression)?;
    if values_bytes.len() as u64 != uncompressed_values_len {
        return Err(RiegeliError::MalformedData(format!(
            "decompressed values length {} != expected {}",
            values_bytes.len(),
            uncompressed_values_len
        )));
    }

    // Each record size takes at least one byte of the decompressed sizes, so
    // cap the reservation instead of trusting the header's record count.
    let mut record_ranges: Vec<(usize, usize)> =
        Vec::with_capacity(num_records.min(sizes_bytes.len()));
    let mut spos = 0usize;
    let mut offset = 0usize;
    for i in 0..num_records {
        if spos >= sizes_bytes.len() {
            return Err(RiegeliError::MalformedData(format!(
                "unexpected end of decompressed sizes at record {i}"
            )));
        }
        let (size, consumed) = decode_u64(&sizes_bytes[spos..]).map_err(|e| {
            RiegeliError::MalformedData(format!("varint decode in sizes at record {i}: {e}"))
        })?;
        spos += consumed;

        let len = usize::try_from(size).map_err(|_| {
            RiegeliError::MalformedData(format!(
                "record {i} size {size} does not fit in usize"
            ))
        })?;
        let end = offset.checked_add(len).ok_or_else(|| {
            RiegeliError::MalformedData(format!(
                "record {i} size {size} overflows the total values length"
            ))
        })?;
        record_ranges.push((offset, len));
        offset = end;
    }

    let total_values_len = offset;
    if total_values_len != values_bytes.len() {
        return Err(RiegeliError::MalformedData(format!(
            "values length mismatch: sizes sum {total_values_len} != decompressed values {}",
            values_bytes.len()
        )));
    }

    Ok(SimpleChunkDecoder {
        record_ranges,
        values: values_bytes,
        next_record: 0,
    })
}
// Unit tests for the simple chunk encoder and decoder. Tests that depend on a
// compressor are gated on the matching cargo feature (`brotli`, `zstd`,
// `snappy`); the remaining tests exercise the uncompressed layout.
#[cfg(test)]
mod tests {
use super::*;
use crate::hash::highway_hash_64;
// An empty uncompressed chunk still carries data (at least the compression
// byte), while reporting zero records and zero decoded bytes.
#[test]
fn encode_zero_records() {
let encoder = SimpleChunkEncoder::new();
let chunk = encoder.encode().expect("encode ok");
assert_eq!(chunk.header.num_records(), 0);
assert!(
chunk.header.data_size() > 0,
"data_size must be > 0 (has compression byte)"
);
assert_eq!(chunk.header.decoded_data_size(), 0);
}
// Single-record round trip through the uncompressed layout.
#[test]
fn encode_decode_hello() {
let mut encoder = SimpleChunkEncoder::new();
encoder.add_record(b"hello");
let chunk = encoder.encode().expect("encode ok");
let mut decoder = SimpleChunkDecoder::new(chunk).expect("valid chunk");
let record = decoder
.read_record()
.expect("no error")
.expect("has record");
assert_eq!(record, b"hello");
assert!(decoder.read_record().expect("no error").is_none());
}
// Records of different lengths come back in insertion order.
#[test]
fn encode_decode_three_records() {
let records: &[&[u8]] = &[b"alpha", b"bb", b"ccccc"];
let mut encoder = SimpleChunkEncoder::new();
for r in records {
encoder.add_record(r);
}
let chunk = encoder.encode().expect("encode ok");
let mut decoder = SimpleChunkDecoder::new(chunk).expect("valid chunk");
for expected in records {
let got = decoder
.read_record()
.expect("no error")
.expect("has record");
assert_eq!(got.as_slice(), *expected);
}
assert!(decoder.read_record().expect("no error").is_none());
}
// The header's data hash equals `highway_hash_64` of the encoded data.
#[test]
fn data_hash_matches() {
let mut encoder = SimpleChunkEncoder::new();
encoder.add_record(b"hello");
let chunk = encoder.encode().expect("encode ok");
assert_eq!(chunk.header.data_hash(), highway_hash_64(&chunk.data));
}
// A freshly encoded chunk has a header that passes its own validity check.
#[test]
fn header_hash_valid() {
let mut encoder = SimpleChunkEncoder::new();
encoder.add_record(b"test record");
let chunk = encoder.encode().expect("encode ok");
assert!(chunk.header.is_header_valid());
}
// An unmodified chunk is accepted by the decoder.
#[test]
fn round_trip_no_corruption() {
let mut encoder = SimpleChunkEncoder::new();
encoder.add_record(b"foo");
encoder.add_record(b"bar");
let chunk = encoder.encode().expect("encode ok");
assert!(SimpleChunkDecoder::new(chunk).is_ok());
}
// Corrupts one byte of the serialized header and expects the decoder to
// report a data hash mismatch.
// NOTE(review): presumably byte 16 lies inside the stored data hash — confirm
// against the `ChunkHeader` byte layout.
#[test]
fn corrupted_data_hash_returns_err() {
let mut encoder = SimpleChunkEncoder::new();
encoder.add_record(b"hello");
let mut chunk = encoder.encode().expect("encode ok");
let mut bytes = chunk.header.to_bytes();
bytes[16] ^= 0xff;
chunk.header = ChunkHeader::from_bytes(bytes);
let result = SimpleChunkDecoder::new(chunk);
assert!(matches!(result, Err(RiegeliError::DataHashMismatch)));
}
// Pins the exact uncompressed layout: compression byte 0x00, sizes section
// length 1, record length 5, then the record bytes.
#[test]
fn exact_byte_layout_hello() {
let mut encoder = SimpleChunkEncoder::new();
encoder.add_record(b"hello");
let chunk = encoder.encode().expect("encode ok");
let expected: &[u8] = &[0x00, 0x01, 0x05, b'h', b'e', b'l', b'l', b'o'];
assert_eq!(chunk.data, expected);
}
// Brotli chunks are tagged with the compression byte b'b'.
#[test]
#[cfg(feature = "brotli")]
fn brotli_first_byte_is_b() {
let mut encoder = SimpleChunkEncoder::with_compression(CompressionType::Brotli);
encoder.add_record(b"hello brotli");
let chunk = encoder.encode().expect("encode ok");
assert_eq!(chunk.data[0], b'b', "first data byte must be b'b' (0x62)");
}
// Single-record round trip through Brotli.
#[test]
#[cfg(feature = "brotli")]
fn brotli_round_trip_single_record() {
let input = b"hello compressed world";
let mut encoder = SimpleChunkEncoder::with_compression(CompressionType::Brotli);
encoder.add_record(input);
let chunk = encoder.encode().expect("encode ok");
let mut decoder = SimpleChunkDecoder::new(chunk).expect("valid chunk");
let got = decoder
.read_record()
.expect("no error")
.expect("has record");
assert_eq!(got, input);
assert!(decoder.read_record().expect("no error").is_none());
}
// Single-record round trip through Zstd.
#[test]
#[cfg(feature = "zstd")]
fn zstd_round_trip_single_record() {
let input = b"hello zstd world";
let mut encoder = SimpleChunkEncoder::with_compression(CompressionType::Zstd);
encoder.add_record(input);
let chunk = encoder.encode().expect("encode ok");
let mut decoder = SimpleChunkDecoder::new(chunk).expect("valid chunk");
let got = decoder
.read_record()
.expect("no error")
.expect("has record");
assert_eq!(got, input);
assert!(decoder.read_record().expect("no error").is_none());
}
// Single-record round trip through Snappy.
#[test]
#[cfg(feature = "snappy")]
fn snappy_round_trip_single_record() {
let input = b"hello snappy world";
let mut encoder = SimpleChunkEncoder::with_compression(CompressionType::Snappy);
encoder.add_record(input);
let chunk = encoder.encode().expect("encode ok");
let mut decoder = SimpleChunkDecoder::new(chunk).expect("valid chunk");
let got = decoder
.read_record()
.expect("no error")
.expect("has record");
assert_eq!(got, input);
assert!(decoder.read_record().expect("no error").is_none());
}
// The sizes blob starts with the uncompressed length of the sizes section:
// one 5-byte record needs a single varint byte, so the prefix is 1.
#[test]
#[cfg(feature = "brotli")]
fn brotli_sizes_section_has_varint_prefix() {
let mut encoder = SimpleChunkEncoder::with_compression(CompressionType::Brotli);
encoder.add_record(b"hello");
let chunk = encoder.encode().expect("encode ok");
let data = &chunk.data;
assert_eq!(data[0], b'b');
let (sizes_blob_len, vlen1) = decode_u64(&data[1..]).expect("varint decode ok");
let blob = &data[1 + vlen1..1 + vlen1 + sizes_blob_len as usize];
let (uncompressed_sizes_len, _) = decode_u64(blob).expect("varint decode ok");
assert_eq!(
uncompressed_sizes_len, 1u64,
"sizes_section for 1 record of any size is 1 varint byte for 5 bytes"
);
}
// Same prefix check with three short records: three one-byte varints.
#[test]
#[cfg(feature = "brotli")]
fn brotli_sizes_section_prefix_three_records() {
let mut encoder = SimpleChunkEncoder::with_compression(CompressionType::Brotli);
encoder.add_record(b"hello");
encoder.add_record(b"bb");
encoder.add_record(b"world");
let chunk = encoder.encode().expect("encode ok");
let data = &chunk.data;
let (sizes_blob_len, vlen1) = decode_u64(&data[1..]).expect("varint decode ok");
let blob = &data[1 + vlen1..1 + vlen1 + sizes_blob_len as usize];
let (uncompressed_sizes_len, _) = decode_u64(blob).expect("varint decode ok");
assert_eq!(
uncompressed_sizes_len, 3u64,
"three records each needing 1 varint byte = 3 bytes"
);
}
// Highly repetitive input must encode smaller with Brotli than uncompressed.
#[test]
#[cfg(feature = "brotli")]
fn brotli_compression_actually_compresses() {
let record: Vec<u8> = b"AAAAAAAAAA".iter().cycle().take(1024).cloned().collect();
let mut enc_compressed = SimpleChunkEncoder::with_compression(CompressionType::Brotli);
let mut enc_none = SimpleChunkEncoder::new();
for _ in 0..1000 {
enc_compressed.add_record(&record);
enc_none.add_record(&record);
}
let compressed_chunk = enc_compressed.encode().expect("encode ok");
let uncompressed_chunk = enc_none.encode().expect("encode ok");
assert!(
compressed_chunk.data.len() < uncompressed_chunk.data.len(),
"compressed={} should be < uncompressed={}",
compressed_chunk.data.len(),
uncompressed_chunk.data.len()
);
}
// An unknown compression byte (0xFF) is rejected with that byte in the error.
#[test]
fn unsupported_compression_byte_returns_err() {
let data: Vec<u8> = vec![0xFF, 0x00]; let header = ChunkHeader::from_parts(&data, ChunkType::Simple, 0, 0);
let chunk = Chunk { header, data };
let result = SimpleChunkDecoder::new(chunk);
assert!(
matches!(result, Err(RiegeliError::UnsupportedCompression(0xFF))),
"expected UnsupportedCompression(0xFF), got: {result:?}"
);
}
// The header's decoded data size is the sum of record lengths (Brotli).
#[test]
#[cfg(feature = "brotli")]
fn decoded_data_size_brotli() {
let records: &[&[u8]] = &[b"hello", b"world", b"foo"];
let expected_sum: u64 = records.iter().map(|r| r.len() as u64).sum();
let mut encoder = SimpleChunkEncoder::with_compression(CompressionType::Brotli);
for r in records {
encoder.add_record(r);
}
let chunk = encoder.encode().expect("encode ok");
assert_eq!(chunk.header.decoded_data_size(), expected_sum);
}
// The header's decoded data size is the sum of record lengths (Zstd).
#[test]
#[cfg(feature = "zstd")]
fn decoded_data_size_zstd() {
let records: &[&[u8]] = &[b"hello", b"world", b"foo"];
let expected_sum: u64 = records.iter().map(|r| r.len() as u64).sum();
let mut encoder = SimpleChunkEncoder::with_compression(CompressionType::Zstd);
for r in records {
encoder.add_record(r);
}
let chunk = encoder.encode().expect("encode ok");
assert_eq!(chunk.header.decoded_data_size(), expected_sum);
}
// The header's decoded data size is the sum of record lengths (uncompressed).
#[test]
fn decoded_data_size_none() {
let records: &[&[u8]] = &[b"hello", b"world", b"foo"];
let expected_sum: u64 = records.iter().map(|r| r.len() as u64).sum();
let mut encoder = SimpleChunkEncoder::new();
for r in records {
encoder.add_record(r);
}
let chunk = encoder.encode().expect("encode ok");
assert_eq!(chunk.header.decoded_data_size(), expected_sum);
}
// Multi-record round trip through Brotli, preserving order.
#[test]
#[cfg(feature = "brotli")]
fn brotli_round_trip_multiple_records() {
let records: &[&[u8]] = &[b"alpha", b"beta", b"gamma delta epsilon"];
let mut encoder = SimpleChunkEncoder::with_compression(CompressionType::Brotli);
for r in records {
encoder.add_record(r);
}
let chunk = encoder.encode().expect("encode ok");
let mut decoder = SimpleChunkDecoder::new(chunk).expect("valid chunk");
for expected in records {
let got = decoder
.read_record()
.expect("no error")
.expect("has record");
assert_eq!(got.as_slice(), *expected);
}
assert!(decoder.read_record().expect("no error").is_none());
}
// Multi-record round trip through Zstd, preserving order.
#[test]
#[cfg(feature = "zstd")]
fn zstd_round_trip_multiple_records() {
let records: &[&[u8]] = &[b"alpha", b"beta", b"gamma delta epsilon"];
let mut encoder = SimpleChunkEncoder::with_compression(CompressionType::Zstd);
for r in records {
encoder.add_record(r);
}
let chunk = encoder.encode().expect("encode ok");
let mut decoder = SimpleChunkDecoder::new(chunk).expect("valid chunk");
for expected in records {
let got = decoder
.read_record()
.expect("no error")
.expect("has record");
assert_eq!(got.as_slice(), *expected);
}
assert!(decoder.read_record().expect("no error").is_none());
}
// An empty Brotli chunk encodes, decodes, and yields no records.
#[test]
#[cfg(feature = "brotli")]
fn brotli_zero_records() {
let encoder = SimpleChunkEncoder::with_compression(CompressionType::Brotli);
let chunk = encoder.encode().expect("encode ok");
assert_eq!(chunk.header.num_records(), 0);
assert_eq!(chunk.header.decoded_data_size(), 0);
let mut decoder = SimpleChunkDecoder::new(chunk).expect("valid chunk");
assert!(decoder.read_record().expect("no error").is_none());
}
// An empty Zstd chunk encodes, decodes, and yields no records.
#[test]
#[cfg(feature = "zstd")]
fn zstd_zero_records() {
let encoder = SimpleChunkEncoder::with_compression(CompressionType::Zstd);
let chunk = encoder.encode().expect("encode ok");
assert_eq!(chunk.header.num_records(), 0);
assert_eq!(chunk.header.decoded_data_size(), 0);
let mut decoder = SimpleChunkDecoder::new(chunk).expect("valid chunk");
assert!(decoder.read_record().expect("no error").is_none());
}
// Data hash and header validity also hold for Brotli-encoded chunks.
#[test]
#[cfg(feature = "brotli")]
fn brotli_chunk_hashes_valid() {
let mut encoder = SimpleChunkEncoder::with_compression(CompressionType::Brotli);
encoder.add_record(b"test");
let chunk = encoder.encode().expect("encode ok");
assert_eq!(chunk.header.data_hash(), highway_hash_64(&chunk.data));
assert!(chunk.header.is_header_valid());
}
// 1000 fixed-size (8-byte little-endian counter) records round trip in order.
#[test]
fn adversarial_1000_records() {
let mut encoder = SimpleChunkEncoder::new();
for i in 0u64..1000 {
encoder.add_record(&i.to_le_bytes());
}
let chunk = encoder.encode().expect("encode ok");
let mut decoder = SimpleChunkDecoder::new(chunk).expect("valid");
for i in 0u64..1000 {
let got = decoder.read_record().expect("ok").expect("record");
assert_eq!(got, i.to_le_bytes());
}
assert!(decoder.read_record().expect("ok").is_none());
}
// Decoding an empty uncompressed chunk yields no records.
#[test]
fn test_decode_zero_records() {
let encoder = SimpleChunkEncoder::new();
let chunk = encoder.encode().expect("encode ok");
let mut decoder = SimpleChunkDecoder::new(chunk).expect("valid chunk");
assert!(decoder.read_record().unwrap().is_none());
}
// A zero-length record is still a record: it is counted and returned as b"".
#[test]
fn test_empty_record_roundtrip() {
let mut encoder = SimpleChunkEncoder::new();
encoder.add_record(b"");
let chunk = encoder.encode().expect("encode ok");
assert_eq!(chunk.header.num_records(), 1);
assert_eq!(chunk.header.decoded_data_size(), 0);
let mut decoder = SimpleChunkDecoder::new(chunk).expect("valid chunk");
let record = decoder
.read_record()
.unwrap()
.expect("should have one record");
assert_eq!(record, b"");
assert!(decoder.read_record().unwrap().is_none());
}
// Several zero-length records are each counted and returned.
#[test]
fn test_multiple_empty_records() {
let mut encoder = SimpleChunkEncoder::new();
for _ in 0..5 {
encoder.add_record(b"");
}
let chunk = encoder.encode().expect("encode ok");
assert_eq!(chunk.header.num_records(), 5);
assert_eq!(chunk.header.decoded_data_size(), 0);
let mut decoder = SimpleChunkDecoder::new(chunk).expect("valid chunk");
for _ in 0..5 {
let record = decoder.read_record().unwrap().expect("should have record");
assert_eq!(record, b"");
}
assert!(decoder.read_record().unwrap().is_none());
}
// A 128-byte record needs a two-byte varint length (0x80 0x01). It sits at
// data[2..4], after the compression byte and the sizes-section length byte.
#[test]
fn test_varint_boundary_128_bytes() {
let record_128 = vec![0xAB_u8; 128];
let mut encoder = SimpleChunkEncoder::new();
encoder.add_record(&record_128);
let chunk = encoder.encode().expect("encode ok");
assert_eq!(chunk.data[2], 0x80);
assert_eq!(chunk.data[3], 0x01);
let mut decoder = SimpleChunkDecoder::new(chunk).expect("valid chunk");
let result = decoder.read_record().unwrap().expect("should have record");
assert_eq!(result, record_128);
assert!(decoder.read_record().unwrap().is_none());
}
// Mixes a one-byte-varint length (127) with a two-byte-varint length (128).
#[test]
fn test_varint_boundary_mixed_sizes() {
let record_127 = vec![0x01_u8; 127];
let record_128 = vec![0x02_u8; 128];
let mut encoder = SimpleChunkEncoder::new();
encoder.add_record(&record_127);
encoder.add_record(&record_128);
let chunk = encoder.encode().expect("encode ok");
let mut decoder = SimpleChunkDecoder::new(chunk).expect("valid chunk");
assert_eq!(
decoder.read_record().unwrap().expect("record 1"),
record_127
);
assert_eq!(
decoder.read_record().unwrap().expect("record 2"),
record_128
);
assert!(decoder.read_record().unwrap().is_none());
}
// Flipping a bit in the last data byte (header untouched) is reported as a
// data hash mismatch.
#[test]
fn test_bit_flip_in_data_detected() {
let mut encoder = SimpleChunkEncoder::new();
encoder.add_record(b"important data");
let mut chunk = encoder.encode().expect("encode ok");
let last = chunk.data.len() - 1;
chunk.data[last] ^= 0x01;
assert!(matches!(
SimpleChunkDecoder::new(chunk),
Err(RiegeliError::DataHashMismatch)
));
}
// Flipping a bit in the compression byte is caught by the hash check, which
// the decoder runs before it looks at the compression byte.
#[test]
fn test_bit_flip_in_compression_byte_detected() {
let mut encoder = SimpleChunkEncoder::new();
encoder.add_record(b"test");
let mut chunk = encoder.encode().expect("encode ok");
chunk.data[0] ^= 0x01;
assert!(matches!(
SimpleChunkDecoder::new(chunk),
Err(RiegeliError::DataHashMismatch)
));
}
// Truncating the data while keeping the original header must fail.
// NOTE(review): only `is_err()` is asserted; presumably the failure comes from
// the data validity check rather than from payload parsing — confirm.
#[test]
fn test_truncated_data_returns_error() {
let mut encoder = SimpleChunkEncoder::new();
encoder.add_record(b"hello world, this is a longer record");
let mut chunk = encoder.encode().expect("encode ok");
let original_len = chunk.data.len();
chunk.data.truncate(original_len / 2);
assert!(SimpleChunkDecoder::new(chunk).is_err());
}
// Completely empty data (with a header built for it) is rejected.
#[test]
fn test_empty_data_returns_error() {
use crate::chunk_header::ChunkType;
let data: Vec<u8> = vec![];
let header = crate::chunk_header::ChunkHeader::from_parts(&data, ChunkType::Simple, 0, 0);
let chunk = Chunk { header, data };
assert!(SimpleChunkDecoder::new(chunk).is_err());
}
// The header claims 3 records but the sizes section holds only one entry.
#[test]
fn test_record_count_exceeds_sizes() {
use crate::chunk_header::ChunkType;
let data: Vec<u8> = vec![0x00, 0x01, 0x05, b'h', b'e', b'l', b'l', b'o'];
let header = crate::chunk_header::ChunkHeader::from_parts(&data, ChunkType::Simple, 3, 5);
let chunk = Chunk { header, data };
assert!(SimpleChunkDecoder::new(chunk).is_err());
}
// The single size entry says 10 bytes but only 3 value bytes are present.
#[test]
fn test_values_section_truncated() {
use crate::chunk_header::ChunkType;
let data: Vec<u8> = vec![0x00, 0x01, 0x0A, b'a', b'b', b'c'];
let header = crate::chunk_header::ChunkHeader::from_parts(&data, ChunkType::Simple, 1, 10);
let chunk = Chunk { header, data };
assert!(SimpleChunkDecoder::new(chunk).is_err());
}
// Every byte other than 0x00, b'b', b'z', b's' must be rejected as an
// unsupported compression type carrying that exact byte.
#[test]
fn test_invalid_compression_bytes_rejected() {
use crate::chunk_header::ChunkType;
let valid: &[u8] = &[0x00, b'b', b'z', b's'];
for byte in 0u8..=255 {
if valid.contains(&byte) {
continue;
}
let data: Vec<u8> = vec![byte, 0x00];
let header =
crate::chunk_header::ChunkHeader::from_parts(&data, ChunkType::Simple, 0, 0);
let chunk = Chunk {
header,
data: data.clone(),
};
assert!(
matches!(
SimpleChunkDecoder::new(chunk),
Err(RiegeliError::UnsupportedCompression(b)) if b == byte
),
"byte {byte:#04x} should return UnsupportedCompression"
);
}
}
// Records of lengths 0..=999 round trip; their total length is
// 0 + 1 + ... + 999 = 499_500.
#[test]
fn test_1000_varying_length_records() {
let mut encoder = SimpleChunkEncoder::new();
let mut expected: Vec<Vec<u8>> = Vec::with_capacity(1000);
for i in 0u32..1000 {
let record: Vec<u8> = (0..i).map(|b| (b % 256) as u8).collect();
encoder.add_record(&record);
expected.push(record);
}
let chunk = encoder.encode().expect("encode ok");
assert_eq!(chunk.header.num_records(), 1000);
assert_eq!(chunk.header.decoded_data_size(), 499_500);
let mut decoder = SimpleChunkDecoder::new(chunk).expect("valid chunk");
for (i, exp) in expected.iter().enumerate() {
let got = decoder.read_record().unwrap().unwrap_or_else(|| {
panic!("expected record {i} but got None");
});
assert_eq!(got, *exp, "mismatch at record {i}");
}
assert!(decoder.read_record().unwrap().is_none());
}
// Hash checks hold for empty, tiny, short and large (10000-byte) records.
#[test]
fn test_data_hash_various_inputs() {
for input in [b"" as &[u8], b"a", b"hello world", &[0xFF; 10000]] {
let mut encoder = SimpleChunkEncoder::new();
encoder.add_record(input);
let chunk = encoder.encode().expect("encode ok");
assert_eq!(
chunk.header.data_hash(),
crate::hash::highway_hash_64(&chunk.data),
"data_hash mismatch for input of length {}",
input.len()
);
assert!(
chunk.header.is_header_valid(),
"header_hash invalid for input of length {}",
input.len()
);
}
}
// Once exhausted, `read_record` keeps returning `Ok(None)`.
#[test]
fn test_repeated_none_after_exhaustion() {
let mut encoder = SimpleChunkEncoder::new();
encoder.add_record(b"only");
let chunk = encoder.encode().expect("encode ok");
let mut decoder = SimpleChunkDecoder::new(chunk).expect("valid chunk");
let _ = decoder.read_record().unwrap().expect("record");
for _ in 0..5 {
assert!(decoder.read_record().unwrap().is_none());
}
}
// Zero-length records survive the Brotli path (empty values section).
#[test]
#[cfg(feature = "brotli")]
fn test_brotli_five_empty_records() {
let mut encoder = SimpleChunkEncoder::with_compression(CompressionType::Brotli);
for _ in 0..5 {
encoder.add_record(b"");
}
let chunk = encoder.encode().expect("encode ok");
assert_eq!(chunk.header.num_records(), 5);
assert_eq!(chunk.header.decoded_data_size(), 0);
let mut decoder = SimpleChunkDecoder::new(chunk).expect("valid chunk");
for _ in 0..5 {
let record = decoder.read_record().unwrap().expect("should have record");
assert_eq!(record, b"");
}
assert!(decoder.read_record().unwrap().is_none());
}
// A record whose length needs a two-byte varint round trips through Brotli.
#[test]
#[cfg(feature = "brotli")]
fn test_brotli_128_byte_varint_boundary() {
let record = vec![0xAB_u8; 128];
let mut encoder = SimpleChunkEncoder::with_compression(CompressionType::Brotli);
encoder.add_record(&record);
let chunk = encoder.encode().expect("encode ok");
let mut decoder = SimpleChunkDecoder::new(chunk).expect("valid chunk");
let got = decoder.read_record().unwrap().expect("has record");
assert_eq!(got, record);
assert!(decoder.read_record().unwrap().is_none());
}
// Zstd round trip over record lengths 0, 1, 127, 128 and 16384.
#[test]
#[cfg(feature = "zstd")]
fn test_zstd_mixed_record_sizes() {
let records: Vec<Vec<u8>> = vec![
vec![], vec![0x42; 1], vec![0x43; 127], vec![0x44; 128], vec![0x45; 16384], ];
let mut encoder = SimpleChunkEncoder::with_compression(CompressionType::Zstd);
for r in &records {
encoder.add_record(r);
}
let chunk = encoder.encode().expect("encode ok");
let mut decoder = SimpleChunkDecoder::new(chunk).expect("valid chunk");
for (i, expected) in records.iter().enumerate() {
let got = decoder.read_record().unwrap().unwrap_or_else(|| {
panic!("expected record {i} but got None");
});
assert_eq!(got, *expected, "mismatch at record {i}");
}
assert!(decoder.read_record().unwrap().is_none());
}
// Zstd round trip over 1000 records of lengths 0..=999.
#[test]
#[cfg(feature = "zstd")]
fn test_zstd_1000_records() {
let mut encoder = SimpleChunkEncoder::with_compression(CompressionType::Zstd);
let mut expected: Vec<Vec<u8>> = Vec::with_capacity(1000);
for i in 0u32..1000 {
let record: Vec<u8> = (0..i).map(|b| (b % 256) as u8).collect();
encoder.add_record(&record);
expected.push(record);
}
let chunk = encoder.encode().expect("encode ok");
let mut decoder = SimpleChunkDecoder::new(chunk).expect("valid chunk");
for (i, exp) in expected.iter().enumerate() {
let got = decoder.read_record().unwrap().unwrap_or_else(|| {
panic!("expected record {i} but got None");
});
assert_eq!(got, *exp, "mismatch at record {i}");
}
assert!(decoder.read_record().unwrap().is_none());
}
// Ten 200-byte records: each length needs a two-byte varint, so the
// uncompressed sizes section is 20 bytes.
#[test]
#[cfg(feature = "brotli")]
fn test_brotli_varint_prefix_value_10_records() {
let mut encoder = SimpleChunkEncoder::with_compression(CompressionType::Brotli);
for _ in 0..10 {
encoder.add_record(&vec![0xAA; 200]);
}
let chunk = encoder.encode().expect("encode ok");
assert_eq!(chunk.data[0], b'b');
let (blob_len, blob_len_consumed) = decode_u64(&chunk.data[1..]).expect("varint decode ok");
let blob_start = 1 + blob_len_consumed;
let blob_data = &chunk.data[blob_start..blob_start + blob_len as usize];
let (uncompressed_sizes_len, _) = decode_u64(blob_data).expect("varint decode ok");
assert_eq!(
uncompressed_sizes_len, 20,
"10 records of 200 bytes each: sizes_section should be 20 bytes"
);
}
// Rewrites the first varint after the compression byte to a larger value and
// rebuilds the header so the hash check passes; decoding must still fail.
// NOTE(review): per `encode_compressed`, that first varint is the sizes blob
// length, not `uncompressed_sizes_len` as the assertion message says.
#[test]
#[cfg(feature = "brotli")]
fn test_brotli_corrupted_varint_prefix_fails_decode() {
let mut encoder = SimpleChunkEncoder::with_compression(CompressionType::Brotli);
encoder.add_record(b"hello");
encoder.add_record(b"world");
let chunk = encoder.encode().expect("encode ok");
let original_data = chunk.data.clone();
assert_eq!(original_data[0], b'b');
let (original_prefix, prefix_len) =
decode_u64(&original_data[1..]).expect("varint decode ok");
let wrong_prefix = original_prefix + 100;
let wrong_prefix_bytes = encode_u64(wrong_prefix);
let mut corrupted_data = Vec::new();
corrupted_data.push(b'b');
corrupted_data.extend_from_slice(&wrong_prefix_bytes);
corrupted_data.extend_from_slice(&original_data[1 + prefix_len..]);
use crate::chunk_header::ChunkType;
let header =
crate::chunk_header::ChunkHeader::from_parts(&corrupted_data, ChunkType::Simple, 2, 10);
let corrupted_chunk = Chunk {
header,
data: corrupted_data,
};
let result = SimpleChunkDecoder::new(corrupted_chunk);
assert!(
result.is_err(),
"Decoder should fail when varint64(uncompressed_sizes_len) prefix is corrupted"
);
}
// Five records with lengths 50..=90 (all < 128): a 5-byte sizes section, and
// some values data must follow the sizes blob.
#[test]
#[cfg(feature = "brotli")]
fn test_brotli_varint_prefix_determines_split() {
let mut encoder = SimpleChunkEncoder::with_compression(CompressionType::Brotli);
for i in 0..5 {
encoder.add_record(&vec![0x42; 50 + i * 10]); }
let chunk = encoder.encode().expect("encode ok");
let (blob_len, blob_len_consumed) = decode_u64(&chunk.data[1..]).expect("varint decode ok");
let blob_start = 1 + blob_len_consumed;
let blob_data = &chunk.data[blob_start..blob_start + blob_len as usize];
let (uncompressed_sizes_len, _) = decode_u64(blob_data).expect("varint decode ok");
assert_eq!(
uncompressed_sizes_len, 5,
"5 records with sizes < 128 should produce 5-byte sizes section"
);
assert_eq!(chunk.data[0], b'b');
let rest_len = chunk.data.len() - blob_start - blob_len as usize;
assert!(
rest_len > 0,
"should have compressed values data after sizes blob"
);
}
// Once exhausted, a Brotli-backed decoder keeps returning `Ok(None)`.
#[test]
#[cfg(feature = "brotli")]
fn test_brotli_repeated_none_after_exhaustion() {
let mut encoder = SimpleChunkEncoder::with_compression(CompressionType::Brotli);
encoder.add_record(b"only");
let chunk = encoder.encode().expect("encode ok");
let mut decoder = SimpleChunkDecoder::new(chunk).expect("valid chunk");
let _ = decoder.read_record().unwrap().expect("record");
for _ in 0..5 {
assert!(decoder.read_record().unwrap().is_none());
}
}
// Flips a bit at data index 5 and expects a hash mismatch. The flip is
// skipped when the data is 5 bytes or shorter, in which case the assertion
// below would fail because nothing was corrupted.
#[test]
#[cfg(feature = "brotli")]
fn test_brotli_bit_flip_in_compressed_data() {
let mut encoder = SimpleChunkEncoder::with_compression(CompressionType::Brotli);
encoder.add_record(b"important data that should be protected");
let mut chunk = encoder.encode().expect("encode ok");
if chunk.data.len() > 5 {
chunk.data[5] ^= 0x01;
}
let result = SimpleChunkDecoder::new(chunk);
assert!(
matches!(result, Err(RiegeliError::DataHashMismatch)),
"bit flip in compressed data should cause DataHashMismatch"
);
}
}