#![allow(clippy::let_unit_value)]
use crate::{CompressionError, Compressor};
use alloc::vec::Vec;
/// LZ4 compressor that splits its input into fixed-size chunks, each
/// compressed independently so a single chunk can be decompressed without
/// touching the rest of the payload (see `decompress_chunk`).
///
/// `CHUNK_SIZE` is the uncompressed size of every chunk except possibly the
/// last; it must be non-zero (enforced at compile time by
/// `CHUNK_SIZE_NONZERO`).
pub struct ChunkedLz4<const CHUNK_SIZE: usize = 4096>;
impl<const N: usize> ChunkedLz4<N> {
    // Evaluated by every entry point so that a zero chunk size is rejected
    // at compile time rather than at run time.
    const CHUNK_SIZE_NONZERO: () = assert!(N > 0, "CHUNK_SIZE must be greater than zero");

    /// Decompresses only the chunk at `chunk_idx` from a payload produced by
    /// `compress`. Fails with `DecompressFailed` when the index is out of
    /// range or the payload is malformed.
    pub fn decompress_chunk(data: &[u8], chunk_idx: usize) -> Result<Vec<u8>, CompressionError> {
        _ = Self::CHUNK_SIZE_NONZERO;
        let payload = strip_discriminant(data, <Self as Compressor>::DISCRIMINANT)?;
        let (chunk_count, _original_len) = parse_header(payload)?;
        if chunk_idx < chunk_count {
            read_chunk(payload, chunk_count, chunk_idx)
        } else {
            Err(CompressionError::DecompressFailed)
        }
    }

    /// Returns the number of chunks stored in a payload produced by
    /// `compress`, without decompressing anything.
    pub fn chunk_count(data: &[u8]) -> Result<usize, CompressionError> {
        _ = Self::CHUNK_SIZE_NONZERO;
        let payload = strip_discriminant(data, <Self as Compressor>::DISCRIMINANT)?;
        parse_header(payload).map(|(chunk_count, _original_len)| chunk_count)
    }
}
impl<const N: usize> Compressor for ChunkedLz4<N> {
    const NAME: &'static str = "chunked_lz4";
    const DISCRIMINANT: u8 = 0x02;

    /// Compresses `input` as independently decodable LZ4 blocks of at most
    /// `N` uncompressed bytes each.
    ///
    /// Output layout (after the optional discriminant byte):
    /// `[chunk_count: u32 LE][original_len: u32 LE]`, then one
    /// `[offset: u32 LE][compressed_len: u32 LE]` index entry per chunk,
    /// then the concatenated compressed blocks.
    fn compress(input: &[u8]) -> Result<Vec<u8>, CompressionError> {
        _ = Self::CHUNK_SIZE_NONZERO;
        let chunk_count = input.len().div_ceil(N);
        // All header fields are u32; larger values would silently truncate
        // when written below, so guard both in debug builds.
        debug_assert!(chunk_count <= u32::MAX as usize);
        debug_assert!(input.len() <= u32::MAX as usize);
        #[cfg(feature = "discriminant")]
        let index_base = 9usize;
        #[cfg(not(feature = "discriminant"))]
        let index_base = 8usize;
        let header_len = index_base + chunk_count * 8;
        // Reserve room for the header plus roughly the payload; LZ4 output is
        // usually no larger than the input, so this avoids most reallocation.
        let mut out = Vec::with_capacity(header_len + input.len());
        #[cfg(feature = "discriminant")]
        out.push(Self::DISCRIMINANT);
        out.extend_from_slice(&(chunk_count as u32).to_le_bytes());
        out.extend_from_slice(&(input.len() as u32).to_le_bytes());
        // Zero-fill the index table; entries are patched in as each block's
        // compressed size becomes known.
        out.resize(header_len, 0u8);
        let mut table = alloc::vec![0u16; LZ4_HASH_SIZE];
        let mut data_offset: u32 = 0;
        for (i, chunk) in input.chunks(N).enumerate() {
            // Reset the match table so every chunk stays independently
            // decompressible (no cross-chunk back-references).
            table.fill(0);
            let block_start = out.len();
            lz4_compress_chunk(chunk, &mut table, &mut out);
            let block_len = (out.len() - block_start) as u32;
            let entry = index_base + i * 8;
            out[entry..entry + 4].copy_from_slice(&data_offset.to_le_bytes());
            out[entry + 4..entry + 8].copy_from_slice(&block_len.to_le_bytes());
            data_offset = data_offset.wrapping_add(block_len);
        }
        Ok(out)
    }

    /// Decompresses a full payload by concatenating every chunk, verifying
    /// the result against the recorded original length.
    fn decompress(input: &[u8]) -> Result<Vec<u8>, CompressionError> {
        _ = Self::CHUNK_SIZE_NONZERO;
        let data = strip_discriminant(input, Self::DISCRIMINANT)?;
        let (chunk_count, original_len) = parse_header(data)?;
        let mut out = Vec::with_capacity(original_len);
        for i in 0..chunk_count {
            let chunk = read_chunk(data, chunk_count, i)?;
            out.extend_from_slice(&chunk);
        }
        // A length mismatch means the header lied or a chunk was corrupt.
        if out.len() != original_len {
            return Err(CompressionError::DecompressFailed);
        }
        Ok(out)
    }
}
/// Validates and removes the leading discriminant byte.
///
/// Fails with `DecompressFailed` when `input` is empty or its first byte is
/// not `expected`.
#[cfg(feature = "discriminant")]
fn strip_discriminant(input: &[u8], expected: u8) -> Result<&[u8], CompressionError> {
    let (&tag, rest) = input
        .split_first()
        .ok_or(CompressionError::DecompressFailed)?;
    if tag == expected {
        Ok(rest)
    } else {
        Err(CompressionError::DecompressFailed)
    }
}
/// Number of `u16` slots callers must allocate for the match table passed to
/// `lz4_compress_chunk`.
pub const LZ4_HASH_TABLE_WORDS: usize = 4096;
const LZ4_HASH_SIZE: usize = LZ4_HASH_TABLE_WORDS;
// 2^12 = 4096 hash buckets; must stay in sync with LZ4_HASH_SIZE above.
const LZ4_HASH_BITS: u32 = 12;
// Minimum match length the LZ4 block format can encode.
const LZ4_MINMATCH: usize = 4;
// LZ4 block-format invariants: matches may not start within the last
// LZ4_MFLIMIT bytes of input, and the last LZ4_LASTLITERALS bytes are always
// emitted as literals.
const LZ4_MFLIMIT: usize = 12;
const LZ4_LASTLITERALS: usize = 5;
/// Compresses `input` as a single LZ4 block, prefixed with the uncompressed
/// length as a `u32` LE so the result can be decoded with
/// `lz4_flex::block::decompress_size_prepended`.
///
/// `table` is a caller-provided hash table of `LZ4_HASH_SIZE` `u16` entries
/// mapping hashes of 4-byte sequences to input positions; callers must clear
/// it between independent blocks. Output is appended to `out`.
pub fn lz4_compress_chunk(input: &[u8], table: &mut [u16], out: &mut Vec<u8>) {
    debug_assert_eq!(table.len(), LZ4_HASH_SIZE);
    // Size-prepended framing expected by the decompressor.
    out.extend_from_slice(&(input.len() as u32).to_le_bytes());
    let n = input.len();
    if n < LZ4_MINMATCH {
        // Too short to contain any match; emit everything as literals.
        lz4_push_lits(out, input);
        return;
    }
    // Format invariants: matches may not start in the last LZ4_MFLIMIT bytes
    // and may not extend into the last LZ4_LASTLITERALS bytes (those are
    // always literals). saturating_sub keeps tiny inputs literal-only.
    let search_limit = n.saturating_sub(LZ4_MFLIMIT);
    let match_limit = n.saturating_sub(LZ4_LASTLITERALS);
    let mut ip = 0usize;
    let mut anchor = 0usize;
    loop {
        if ip >= search_limit {
            break;
        }
        // SAFETY: ip < search_limit <= n - LZ4_MFLIMIT here, so a 4-byte
        // read at ip stays in bounds.
        let seq = unsafe { lz4_read4_unc(input, ip) };
        let h = lz4_hash(seq);
        let candidate = table[h] as usize;
        table[h] = ip as u16;
        // NOTE(review): positions are stored truncated to u16, so for inputs
        // longer than 64 KiB stale entries can alias. The seq-equality and
        // offset checks below reject false candidates, so this only costs
        // missed matches, not correctness — confirm this is the intent.
        if candidate < ip {
            let offset = ip - candidate;
            if offset <= 0xFFFF && unsafe { lz4_read4_unc(input, candidate) } == seq {
                let mut ml = LZ4_MINMATCH;
                // Extend the match a word at a time. word_limit keeps the
                // 4-byte reads below inside match_limit.
                let word_limit = match_limit.saturating_sub(3);
                while ip + ml < word_limit {
                    let wi = unsafe { lz4_read4_unc(input, ip + ml) };
                    let wc = unsafe { lz4_read4_unc(input, candidate + ml) };
                    if wi != wc {
                        // Count matching leading bytes of the differing word.
                        // NOTE(review): this relies on the 4-byte loads being
                        // little-endian so trailing_zeros counts from the
                        // first byte — verify on big-endian targets.
                        ml += (wi ^ wc).trailing_zeros() as usize / 8;
                        break;
                    }
                    ml += 4;
                }
                // Finish byte-by-byte up to the literal-tail boundary.
                while ip + ml < match_limit && input[candidate + ml] == input[ip + ml] {
                    ml += 1;
                }
                lz4_emit_seq(out, &input[anchor..ip], offset as u16, ml);
                ip += ml;
                anchor = ip;
                continue;
            }
        }
        ip += 1;
    }
    // Trailing literals; the limits above guarantee at least
    // LZ4_LASTLITERALS bytes remain whenever a match was emitted.
    lz4_push_lits(out, &input[anchor..]);
}
/// Reads 4 bytes at `pos` as a little-endian `u32`.
///
/// The little-endian interpretation is load-order sensitive: the
/// `trailing_zeros() / 8` match-extension trick in `lz4_compress_chunk`
/// counts matching *leading* bytes only if byte 0 lands in the low-order
/// lane, which a native-endian `read_unaligned` does not guarantee on
/// big-endian targets. `from_le_bytes` makes the result identical on every
/// architecture (and adds a bounds check instead of UB on a bad `pos`).
///
/// # Safety
/// Callers must guarantee `pos + 4 <= data.len()`; a violation now panics
/// via the slice bounds check rather than reading out of bounds. The
/// `unsafe` signature is retained so existing call sites stay unchanged.
#[inline(always)]
unsafe fn lz4_read4_unc(data: &[u8], pos: usize) -> u32 {
    u32::from_le_bytes(data[pos..pos + 4].try_into().unwrap())
}
/// Hashes a 4-byte sequence into a table index in `0..LZ4_HASH_SIZE` using
/// Knuth's multiplicative constant, keeping the top `LZ4_HASH_BITS` bits.
#[inline(always)]
fn lz4_hash(seq: u32) -> usize {
    let mixed = seq.wrapping_mul(2654435761u32);
    (mixed >> (u32::BITS - LZ4_HASH_BITS)) as usize
}
/// Emits one LZ4 sequence: token, extended literal length, the literal
/// bytes, the 2-byte little-endian match offset, and the extended match
/// length. `match_len` is the full match length (at least `LZ4_MINMATCH`).
fn lz4_emit_seq(out: &mut Vec<u8>, lits: &[u8], offset: u16, match_len: usize) {
    let lit_len = lits.len();
    let ml_extra = match_len - LZ4_MINMATCH;
    // The token packs 4 bits of literal length and 4 bits of match length;
    // values of 15 signal an extension byte sequence follows.
    let token = ((lit_len.min(15) << 4) | ml_extra.min(15)) as u8;
    out.push(token);
    if lit_len >= 15 {
        lz4_push_extra(out, lit_len - 15);
    }
    out.extend_from_slice(lits);
    out.extend_from_slice(&offset.to_le_bytes());
    if ml_extra >= 15 {
        lz4_push_extra(out, ml_extra - 15);
    }
}
/// Emits a literals-only LZ4 sequence: a token with zero match length, the
/// extended literal length if needed, then the literal bytes themselves.
fn lz4_push_lits(out: &mut Vec<u8>, lits: &[u8]) {
    let count = lits.len();
    let token = (count.min(15) << 4) as u8;
    out.push(token);
    if count >= 15 {
        lz4_push_extra(out, count - 15);
    }
    out.extend_from_slice(lits);
}
/// Appends an LZ4 extended-length (LSIC) encoding of `total`: one 255 byte
/// per full 255 contained in the value, terminated by the remainder byte
/// (which is 0 when `total` is an exact multiple of 255).
fn lz4_push_extra(out: &mut Vec<u8>, total: usize) {
    let full = total / 255;
    out.extend(core::iter::repeat(255u8).take(full));
    out.push((total % 255) as u8);
}
/// With the `discriminant` feature disabled no tag byte is written, so the
/// payload passes through unchanged and this can never fail.
#[cfg(not(feature = "discriminant"))]
fn strip_discriminant(input: &[u8], _expected: u8) -> Result<&[u8], CompressionError> {
    Ok(&input[..])
}
/// Parses the fixed 8-byte header — `chunk_count` then `original_len`, both
/// `u32` LE — and verifies that the chunk index table that follows fits
/// inside `data`. Returns `(chunk_count, original_len)`.
fn parse_header(data: &[u8]) -> Result<(usize, usize), CompressionError> {
    let count_bytes = data.get(..4).ok_or(CompressionError::DecompressFailed)?;
    let len_bytes = data.get(4..8).ok_or(CompressionError::DecompressFailed)?;
    let chunk_count = u32::from_le_bytes(count_bytes.try_into().unwrap()) as usize;
    let original_len = u32::from_le_bytes(len_bytes.try_into().unwrap()) as usize;
    // One 8-byte index entry per chunk; guard the size arithmetic against
    // overflow on 32-bit targets.
    let index_bytes = chunk_count
        .checked_mul(8)
        .ok_or(CompressionError::DecompressFailed)?;
    let data_region_start = index_bytes
        .checked_add(8)
        .ok_or(CompressionError::DecompressFailed)?;
    if data.len() < data_region_start {
        return Err(CompressionError::DecompressFailed);
    }
    Ok((chunk_count, original_len))
}
/// Extracts and decompresses the chunk at `chunk_idx`.
///
/// `data` is the payload with any discriminant byte already stripped: an
/// 8-byte header, `chunk_count` index entries of
/// `[offset: u32 LE][compressed_len: u32 LE]`, then the packed blocks.
/// Every offset computation is checked, so a malformed header or an
/// out-of-range `chunk_idx` yields `DecompressFailed` instead of a panic
/// or arithmetic overflow — callers no longer have to pre-validate.
fn read_chunk(
    data: &[u8],
    chunk_count: usize,
    chunk_idx: usize,
) -> Result<Vec<u8>, CompressionError> {
    let entry_offset = chunk_idx
        .checked_mul(8)
        .and_then(|v| v.checked_add(8))
        .ok_or(CompressionError::DecompressFailed)?;
    let entry = data
        .get(entry_offset..)
        .and_then(|tail| tail.get(..8))
        .ok_or(CompressionError::DecompressFailed)?;
    let offset = u32::from_le_bytes(entry[0..4].try_into().unwrap()) as usize;
    let compressed_len = u32::from_le_bytes(entry[4..8].try_into().unwrap()) as usize;
    let data_region_start = chunk_count
        .checked_mul(8)
        .and_then(|v| v.checked_add(8))
        .ok_or(CompressionError::DecompressFailed)?;
    let data_region = data
        .get(data_region_start..)
        .ok_or(CompressionError::DecompressFailed)?;
    let end = offset
        .checked_add(compressed_len)
        .ok_or(CompressionError::DecompressFailed)?;
    let block = data_region
        .get(offset..end)
        .ok_or(CompressionError::DecompressFailed)?;
    lz4_flex::block::decompress_size_prepended(block)
        .map_err(|_| CompressionError::DecompressFailed)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{CompressionError, Compressor};
    use alloc::vec;
    use alloc::vec::Vec;

    // Aliases for the chunk sizes exercised below.
    type C = ChunkedLz4<4096>;
    type C512 = ChunkedLz4<512>;
    type C1024 = ChunkedLz4<1024>;

    #[test]
    fn roundtrip_empty() {
        let packed = C::compress(b"").unwrap();
        let restored = C::decompress(&packed).unwrap();
        assert_eq!(restored, b"");
    }

    #[test]
    fn roundtrip_single_chunk() {
        let input = b"Hello Solana!";
        let packed = C::compress(input).unwrap();
        let restored = C::decompress(&packed).unwrap();
        assert_eq!(restored, input);
    }

    #[test]
    fn roundtrip_exact_boundary() {
        // Exactly one full chunk.
        let input: Vec<u8> = (0u8..=255).cycle().take(4096).collect();
        let restored = C::decompress(&C::compress(&input).unwrap()).unwrap();
        assert_eq!(restored, input);
    }

    #[test]
    fn roundtrip_multi_chunk() {
        let input: Vec<u8> = (0u8..=255).cycle().take(10_000).collect();
        let restored = C::decompress(&C::compress(&input).unwrap()).unwrap();
        assert_eq!(restored, input);
    }

    #[test]
    fn roundtrip_repetitive() {
        let input: Vec<u8> = b"aaaa".repeat(2048);
        let packed = C::compress(&input).unwrap();
        assert!(
            packed.len() < input.len(),
            "repetitive data should compress"
        );
        assert_eq!(C::decompress(&packed).unwrap(), input);
    }

    #[test]
    fn chunk_count_empty() {
        let packed = C::compress(b"").unwrap();
        assert_eq!(C::chunk_count(&packed).unwrap(), 0);
    }

    #[test]
    fn chunk_count_single() {
        let packed = C::compress(b"hello").unwrap();
        assert_eq!(C::chunk_count(&packed).unwrap(), 1);
    }

    #[test]
    fn chunk_count_multi() {
        // 10_000 bytes in 4096-byte chunks -> 3 chunks.
        let packed = C::compress(&vec![0u8; 10_000]).unwrap();
        assert_eq!(C::chunk_count(&packed).unwrap(), 3);
    }

    #[test]
    fn chunk_count_exact_boundary() {
        let packed = C::compress(&vec![0u8; 4096]).unwrap();
        assert_eq!(C::chunk_count(&packed).unwrap(), 1);
    }

    #[test]
    fn decompress_chunk_single() {
        let input = b"Hello!";
        let packed = C::compress(input).unwrap();
        assert_eq!(C::decompress_chunk(&packed, 0).unwrap(), input);
    }

    #[test]
    fn decompress_chunk_all_multi() {
        let input: Vec<u8> = (0u8..=255).cycle().take(9_000).collect();
        let packed = C::compress(&input).unwrap();
        let mut reconstructed = Vec::new();
        for idx in 0..C::chunk_count(&packed).unwrap() {
            reconstructed.extend_from_slice(&C::decompress_chunk(&packed, idx).unwrap());
        }
        assert_eq!(reconstructed, input);
    }

    #[test]
    fn decompress_chunk_last_partial() {
        // One byte past a chunk boundary: the second chunk holds one byte.
        let input = vec![42u8; 4097];
        let packed = C::compress(&input).unwrap();
        assert_eq!(C::decompress_chunk(&packed, 1).unwrap(), vec![42u8]);
    }

    #[test]
    fn decompress_chunk_out_of_bounds() {
        let packed = C::compress(b"hello").unwrap();
        assert_eq!(
            C::decompress_chunk(&packed, 1),
            Err(CompressionError::DecompressFailed)
        );
    }

    #[test]
    fn header_too_short() {
        #[cfg(feature = "discriminant")]
        let data: &[u8] = &[0x02, 0x01, 0x00, 0x00, 0x00];
        #[cfg(not(feature = "discriminant"))]
        let data: &[u8] = &[0x01, 0x00, 0x00, 0x00];
        assert_eq!(C::decompress(data), Err(CompressionError::DecompressFailed));
    }

    #[test]
    fn corrupt_chunk_bytes() {
        let mut packed = C::compress(&vec![42u8; 4096]).unwrap();
        // The data region starts after the (optional) discriminant, the
        // 8-byte header, and the single 8-byte index entry.
        #[cfg(feature = "discriminant")]
        let dr_start = 17;
        #[cfg(not(feature = "discriminant"))]
        let dr_start = 16;
        for byte in packed[dr_start..].iter_mut() {
            *byte = 0;
        }
        assert_eq!(
            C::decompress(&packed),
            Err(CompressionError::DecompressFailed)
        );
    }

    #[test]
    fn compressed_len_overflow() {
        // Hand-built header claiming one chunk with a u32::MAX compressed
        // length; must be rejected rather than panic.
        let mut data: Vec<u8> = Vec::new();
        #[cfg(feature = "discriminant")]
        data.push(0x02);
        data.extend_from_slice(&1u32.to_le_bytes());
        data.extend_from_slice(&5u32.to_le_bytes());
        data.extend_from_slice(&0u32.to_le_bytes());
        data.extend_from_slice(&u32::MAX.to_le_bytes());
        assert_eq!(
            C::decompress(&data),
            Err(CompressionError::DecompressFailed)
        );
    }

    #[test]
    #[cfg(feature = "discriminant")]
    fn wrong_discriminant() {
        let mut packed = C::compress(b"hello").unwrap();
        packed[0] = 0xFF;
        assert_eq!(
            C::decompress(&packed),
            Err(CompressionError::DecompressFailed)
        );
    }

    #[test]
    fn discriminants_differ() {
        assert_ne!(crate::Lz4::DISCRIMINANT, C::DISCRIMINANT);
        assert_eq!(crate::Lz4::DISCRIMINANT, 0x01);
        assert_eq!(C::DISCRIMINANT, 0x02);
    }

    #[test]
    fn lz4_output_rejected_by_chunked() {
        let plain_lz4 = crate::Lz4::compress(b"hello").unwrap();
        assert_eq!(
            C::decompress(&plain_lz4),
            Err(CompressionError::DecompressFailed)
        );
    }

    #[test]
    fn roundtrip_chunk_512() {
        let input: Vec<u8> = (0u8..=255).cycle().take(2_000).collect();
        let restored = C512::decompress(&C512::compress(&input).unwrap()).unwrap();
        assert_eq!(restored, input);
    }

    #[test]
    fn roundtrip_chunk_1024() {
        let input: Vec<u8> = (0u8..=255).cycle().take(3_000).collect();
        let restored = C1024::decompress(&C1024::compress(&input).unwrap()).unwrap();
        assert_eq!(restored, input);
    }
}