use alloc::vec::Vec;
use crate::error::Error;
// Accepted range of the compression-level byte that opens every block.
const LIZARD_MIN_CLEVEL: u8 = 10;
const LIZARD_MAX_CLEVEL: u8 = 49;

// Bits of a sub-block header byte. A header equal to FLAG_UNCOMPRESSED marks a
// stored sub-block. FLAG_LITERALS / FLAG_FLAGS / FLAG_OFFSET16 / FLAG_OFFSET24
// mark entropy-coded streams, which this decoder rejects as unsupported.
// FLAG_LEN is rejected as corrupt.
const FLAG_LITERALS: u8 = 0b0000_0001;
const FLAG_FLAGS: u8 = 0b0000_0010;
const FLAG_OFFSET16: u8 = 0b0000_0100;
const FLAG_OFFSET24: u8 = 0b0000_1000;
const FLAG_LEN: u8 = 0b0001_0000;
const FLAG_UNCOMPRESSED: u8 = 0b1000_0000;

/// Shortest match; a token's match field stores `match_len - MINMATCH`.
const MINMATCH: usize = 4;
/// Width of the literal-run field in the low bits of a sequence token.
const RUN_BITS_LZ4: u32 = 4;
/// Literal-run field mask; a saturated field means a length extension follows.
const RUN_MASK_LZ4: u8 = u8::MAX >> (8 - RUN_BITS_LZ4);
/// Match-length field mask for the bits above the literal run.
const ML_MASK_LZ4: u8 = u8::MAX >> RUN_BITS_LZ4;

/// Unit marker for the LZ4-mode Lizard decoder. No methods are defined on it in
/// this module; decoding is performed by the free function `decode_compressed_block`.
pub struct Lz4ModeDecoder;
pub fn decode_compressed_block(input: &[u8], out: &mut Vec<u8>, cap: usize) -> Result<(), Error> {
if input.is_empty() {
return Err(Error::UnexpectedEnd);
}
let mut ip = 0usize;
let clevel = input[ip];
ip += 1;
if !(LIZARD_MIN_CLEVEL..=LIZARD_MAX_CLEVEL).contains(&clevel) {
return Err(Error::Corrupt);
}
let is_lz4_mode = matches!(clevel, 10..=19 | 30..=39);
if !is_lz4_mode {
return Err(Error::Unsupported);
}
if ip >= input.len() {
return Err(Error::UnexpectedEnd);
}
let res = input[ip];
ip += 1;
if res == FLAG_UNCOMPRESSED {
if ip + 3 > input.len() {
return Err(Error::UnexpectedEnd);
}
let length = read_u24_le(&input[ip..]);
ip += 3;
if ip + length > input.len() {
return Err(Error::UnexpectedEnd);
}
if out.len() + length > cap {
return Err(Error::Corrupt);
}
out.extend_from_slice(&input[ip..ip + length]);
return Ok(());
}
if res & FLAG_LEN != 0 {
return Err(Error::Corrupt);
}
let huffman_bits = res & (FLAG_LITERALS | FLAG_FLAGS | FLAG_OFFSET16 | FLAG_OFFSET24);
if huffman_bits != 0 {
return Err(Error::Unsupported);
}
let lengths = read_raw_stream(input, &mut ip)?;
let offset16 = read_raw_stream(input, &mut ip)?;
let offset24 = read_raw_stream(input, &mut ip)?;
let flags = read_raw_stream(input, &mut ip)?;
let literals = read_raw_stream(input, &mut ip)?;
if ip != input.len() {
return Err(Error::Corrupt);
}
if !lengths.is_empty() || !offset16.is_empty() || !offset24.is_empty() {
return Err(Error::Corrupt);
}
decode_lz4_sequences(flags, literals, out, cap)
}
fn read_raw_stream<'a>(input: &'a [u8], ip: &mut usize) -> Result<&'a [u8], Error> {
if *ip + 3 > input.len() {
return Err(Error::UnexpectedEnd);
}
let len = read_u24_le(&input[*ip..]);
*ip += 3;
if *ip + len > input.len() {
return Err(Error::UnexpectedEnd);
}
let slice = &input[*ip..*ip + len];
*ip += len;
Ok(slice)
}
/// Decodes the first three bytes of `s` as an unsigned 24-bit little-endian integer.
///
/// # Panics
/// Panics if `s` holds fewer than three bytes.
#[inline]
fn read_u24_le(s: &[u8]) -> usize {
    let widened = [s[0], s[1], s[2], 0];
    // A 24-bit value always fits, so the u32 -> usize cast is lossless on 32/64-bit targets.
    u32::from_le_bytes(widened) as usize
}
fn decode_lz4_sequences(
flags: &[u8],
literals: &[u8],
out: &mut Vec<u8>,
cap: usize,
) -> Result<(), Error> {
let mut lp = 0usize;
for &token in flags {
let mut lit_len = (token & RUN_MASK_LZ4) as usize;
if lit_len == RUN_MASK_LZ4 as usize {
lit_len = read_length_ext(literals, &mut lp)?;
lit_len = lit_len
.checked_add(RUN_MASK_LZ4 as usize)
.ok_or(Error::Corrupt)?;
}
if lit_len > 0 {
if lp + lit_len > literals.len() {
return Err(Error::UnexpectedEnd);
}
if out.len() + lit_len > cap {
return Err(Error::Corrupt);
}
out.extend_from_slice(&literals[lp..lp + lit_len]);
lp += lit_len;
}
if lp + 2 > literals.len() {
return Err(Error::UnexpectedEnd);
}
let offset = (literals[lp] as usize) | ((literals[lp + 1] as usize) << 8);
lp += 2;
if offset == 0 {
return Err(Error::InvalidDistance);
}
if offset > out.len() {
return Err(Error::InvalidDistance);
}
let mut match_excess = ((token >> RUN_BITS_LZ4) & ML_MASK_LZ4) as usize;
if match_excess == ML_MASK_LZ4 as usize {
match_excess = read_length_ext(literals, &mut lp)?;
match_excess = match_excess
.checked_add(ML_MASK_LZ4 as usize)
.ok_or(Error::Corrupt)?;
}
let match_len = match_excess.checked_add(MINMATCH).ok_or(Error::Corrupt)?;
copy_match(out, offset, match_len, cap)?;
}
if lp < literals.len() {
let tail = literals.len() - lp;
if out.len() + tail > cap {
return Err(Error::Corrupt);
}
out.extend_from_slice(&literals[lp..]);
}
Ok(())
}
fn read_length_ext(literals: &[u8], lp: &mut usize) -> Result<usize, Error> {
if *lp >= literals.len() {
return Err(Error::UnexpectedEnd);
}
let first = literals[*lp];
if first < 254 {
*lp += 1;
return Ok(first as usize);
}
if first == 254 {
if *lp + 3 > literals.len() {
return Err(Error::UnexpectedEnd);
}
let v = (literals[*lp + 1] as usize) | ((literals[*lp + 2] as usize) << 8);
*lp += 3;
return Ok(v);
}
if *lp + 4 > literals.len() {
return Err(Error::UnexpectedEnd);
}
let v = (literals[*lp + 1] as usize)
| ((literals[*lp + 2] as usize) << 8)
| ((literals[*lp + 3] as usize) << 16);
*lp += 4;
Ok(v)
}
fn copy_match(out: &mut Vec<u8>, offset: usize, match_len: usize, cap: usize) -> Result<(), Error> {
if offset > out.len() {
return Err(Error::InvalidDistance);
}
if out.len() + match_len > cap {
return Err(Error::Corrupt);
}
let start = out.len() - offset;
if offset >= match_len {
out.extend_from_within(start..start + match_len);
} else if offset == 1 {
let b = out[start];
out.resize(out.len() + match_len, b);
} else {
for i in 0..match_len {
let b = out[start + i];
out.push(b);
}
}
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a compressed LZ4-mode sub-block (level 10) with empty
    /// lengths/offset streams and the given flags and literals streams.
    fn lz4_block(flags: &[u8], literals: &[u8]) -> Vec<u8> {
        let mut block = alloc::vec![10u8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
        block.extend_from_slice(&[flags.len() as u8, 0, 0]);
        block.extend_from_slice(flags);
        block.extend_from_slice(&[literals.len() as u8, 0, 0]);
        block.extend_from_slice(literals);
        block
    }

    #[test]
    fn empty_flags_just_literals() {
        let mut block = alloc::vec![
            10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0,
        ];
        block.extend_from_slice(b"hi");
        let mut out = Vec::new();
        decode_compressed_block(&block, &mut out, usize::MAX).unwrap();
        assert_eq!(out, b"hi");
    }

    #[test]
    fn in_block_uncompressed() {
        let mut block = alloc::vec![10, FLAG_UNCOMPRESSED, 5, 0, 0];
        block.extend_from_slice(b"hello");
        let mut out = Vec::new();
        decode_compressed_block(&block, &mut out, usize::MAX).unwrap();
        assert_eq!(out, b"hello");
    }

    #[test]
    fn uncompressed_respects_cap() {
        let mut block = alloc::vec![10, FLAG_UNCOMPRESSED, 5, 0, 0];
        block.extend_from_slice(b"hello");
        let mut out = Vec::new();
        assert_eq!(
            decode_compressed_block(&block, &mut out, 4),
            Err(Error::Corrupt)
        );
    }

    #[test]
    fn literal_run_then_overlapping_match() {
        // Token 0x02: two literals, then a match of MINMATCH bytes at offset 2.
        let block = lz4_block(&[0x02], &[b'a', b'b', 2, 0]);
        let mut out = Vec::new();
        decode_compressed_block(&block, &mut out, usize::MAX).unwrap();
        assert_eq!(out, b"ababab");
    }

    #[test]
    fn match_can_reference_existing_output() {
        // Token 0x00: no literals, then a 4-byte match at offset 2 into prior output.
        let block = lz4_block(&[0x00], &[2, 0]);
        let mut out = b"xy".to_vec();
        decode_compressed_block(&block, &mut out, usize::MAX).unwrap();
        assert_eq!(out, b"xyxyxy");
    }

    #[test]
    fn rejects_zero_offset() {
        let block = lz4_block(&[0x00], &[0, 0]);
        let mut out = Vec::new();
        assert_eq!(
            decode_compressed_block(&block, &mut out, usize::MAX),
            Err(Error::InvalidDistance)
        );
    }

    #[test]
    fn rejects_offset_before_output_start() {
        let block = lz4_block(&[0x00], &[1, 0]);
        let mut out = Vec::new();
        assert_eq!(
            decode_compressed_block(&block, &mut out, usize::MAX),
            Err(Error::InvalidDistance)
        );
    }

    #[test]
    fn rejects_level_only_block() {
        let block = alloc::vec![10u8];
        let mut out = Vec::new();
        assert_eq!(
            decode_compressed_block(&block, &mut out, usize::MAX),
            Err(Error::UnexpectedEnd)
        );
    }

    #[test]
    fn rejects_lizv1_mode() {
        let block = alloc::vec![20u8, 0u8];
        let mut out = Vec::new();
        assert_eq!(
            decode_compressed_block(&block, &mut out, usize::MAX),
            Err(Error::Unsupported)
        );
    }

    #[test]
    fn rejects_huffman_flag() {
        let block = alloc::vec![10u8, FLAG_LITERALS];
        let mut out = Vec::new();
        assert_eq!(
            decode_compressed_block(&block, &mut out, usize::MAX),
            Err(Error::Unsupported)
        );
    }

    #[test]
    fn rejects_bad_clevel() {
        let block = alloc::vec![9u8, 0u8];
        let mut out = Vec::new();
        assert_eq!(
            decode_compressed_block(&block, &mut out, usize::MAX),
            Err(Error::Corrupt)
        );
    }
}