use std::io::{self, Read, Write};
use lz4_flex::frame::{FrameDecoder, FrameEncoder};
use super::skippable_frame::{self, FrameReader, MAX_HEADER_SIZE, SKIPPABLE_FRAME_MAGIC};
use super::{Decoder, Encoder, method};
/// Length in bytes of the skippable-frame prefix recognised in front of LZ4 data:
/// a 4-byte magic number, a 4-byte frame-size field and a 4-byte compressed size,
/// all little-endian (see `validate_lz4_header`).
const LZ4_HEADER_SIZE: usize = 12;
/// Value the prefix's frame-size field must hold: the skippable frame's own payload is
/// exactly the 4-byte compressed size that follows it.
const LZ4_FRAME_SIZE: u32 = 4;
/// Streaming LZ4 decoder that accepts either plain LZ4 frame data or LZ4 frames wrapped in
/// skippable frames carrying their compressed size.
///
/// Decompression is done by an `lz4_flex` [`FrameDecoder`] reading through a [`FrameReader`];
/// when a frame's data runs out, the reader is asked for the next frame.
pub struct Lz4Decoder<R: Read> {
    // `None` when the input was empty from the start or no further frame follows;
    // reads then return `Ok(0)`.
    inner: Option<FrameDecoder<FrameReader<R, LZ4_HEADER_SIZE>>>,
}
impl<R: Read> std::fmt::Debug for Lz4Decoder<R> {
    /// Renders as `Lz4Decoder { .. }`.
    ///
    /// The impl only requires `R: Read`, so the wrapped decoder state cannot be printed;
    /// the `..` marks the output as deliberately incomplete.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("Lz4Decoder");
        builder.finish_non_exhaustive()
    }
}
impl<R: Read + Send> Lz4Decoder<R> {
    /// Creates a decoder over `input`, sniffing its first [`LZ4_HEADER_SIZE`] bytes to pick
    /// the framing.
    ///
    /// * No bytes at all: the decoder starts out exhausted and every read returns `Ok(0)`.
    /// * A complete header accepted by `validate_lz4_header`: the data is read as a
    ///   skippable-frame-wrapped LZ4 frame whose compressed size comes from the header.
    /// * Anything else (a short read, or a header that does not validate): the bytes already
    ///   read are handed to `FrameReader::new_standard` together with `input`, presumably to
    ///   be replayed as the start of plain LZ4 frame data.
    ///
    /// # Errors
    ///
    /// Propagates any I/O error returned while reading the header.
    pub fn new(mut input: R) -> io::Result<Self> {
        let mut header = [0u8; MAX_HEADER_SIZE];
        let filled = skippable_frame::read_full_or_eof(&mut input, &mut header[..LZ4_HEADER_SIZE])?;
        if filled == 0 {
            return Ok(Self { inner: None });
        }

        // Only a fully-read header can describe a skippable frame; a short read is always
        // treated as the beginning of standard LZ4 data.
        let skippable_size = if filled == LZ4_HEADER_SIZE {
            let mut prefix = [0u8; LZ4_HEADER_SIZE];
            prefix.copy_from_slice(&header[..LZ4_HEADER_SIZE]);
            validate_lz4_header(&prefix)
        } else {
            None
        };

        let frame_reader = match skippable_size {
            Some(compressed_size) => FrameReader::new_skippable(input, compressed_size),
            None => FrameReader::new_standard(input, header, filled),
        };
        Ok(Self {
            inner: Some(FrameDecoder::new(frame_reader)),
        })
    }
}
impl<R: Read + Send> Read for Lz4Decoder<R> {
    /// Reads decompressed bytes into `buf`, moving on to the next frame whenever the current
    /// one is exhausted.
    ///
    /// Returns `Ok(0)` only when `buf` is empty or when the frame reader reports that no
    /// further frame follows; frames that decode to no data are skipped rather than being
    /// reported as end-of-stream.
    ///
    /// # Errors
    ///
    /// Propagates I/O and decompression errors from the inner decoder and from reading the
    /// next frame header.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        // `Ok(0)` for an empty buffer is not EOF (see `Read::read`). The inner decoder also
        // returns 0 for an empty `buf` mid-frame, so without this guard a zero-length read
        // would be mistaken for end-of-frame and advance or drop the stream prematurely.
        if buf.is_empty() {
            return Ok(0);
        }
        loop {
            let Some(inner) = &mut self.inner else {
                return Ok(0);
            };
            let n = inner.read(buf)?;
            if n != 0 {
                return Ok(n);
            }
            // Current frame exhausted: ask the frame reader whether another frame follows.
            let frame_reader = inner.get_mut();
            match frame_reader.try_read_next_frame(validate_lz4_header)? {
                Some(_compressed_size) => {
                    // Start a fresh decoder on the next frame, then loop instead of returning
                    // its first read directly, so a frame that decodes to zero bytes does not
                    // end the stream while later frames remain.
                    let reader = std::mem::replace(frame_reader, FrameReader::Empty);
                    self.inner = Some(FrameDecoder::new(reader));
                }
                None => {
                    self.inner = None;
                    return Ok(0);
                }
            }
        }
    }
}
impl<R: Read + Send> Decoder for Lz4Decoder<R> {
    /// Returns the LZ4 method identifier, [`method::LZ4`].
    fn method_id(&self) -> &'static [u8] {
        method::LZ4
    }
}
/// Recognises the 12-byte skippable-frame prefix that may precede an LZ4 frame.
///
/// The header is read as three little-endian `u32` words. Bytes 0..4 must equal
/// [`SKIPPABLE_FRAME_MAGIC`] and bytes 4..8 must equal [`LZ4_FRAME_SIZE`]. If both match,
/// the third word (bytes 8..12, the compressed size) is returned. Any other header yields
/// `None`.
fn validate_lz4_header(header: &[u8; LZ4_HEADER_SIZE]) -> Option<u32> {
    let le_u32_at = |offset: usize| {
        u32::from_le_bytes([
            header[offset],
            header[offset + 1],
            header[offset + 2],
            header[offset + 3],
        ])
    };
    let is_lz4_skippable = le_u32_at(0) == SKIPPABLE_FRAME_MAGIC && le_u32_at(4) == LZ4_FRAME_SIZE;
    is_lz4_skippable.then(|| le_u32_at(8))
}
/// Options for [`Lz4Encoder`]; obtain an instance with [`Default::default`].
///
/// Currently carries no settings. The private field prevents construction with a struct
/// literal outside this module — presumably so settings can be added later without breaking
/// callers.
#[derive(Debug, Clone, Copy, Default)]
pub struct Lz4EncoderOptions {
    _reserved: (),
}
/// Streaming LZ4 compressor that writes LZ4 frame-format output to `W` through an
/// `lz4_flex` [`FrameEncoder`].
///
/// Call `try_finish` (or `Encoder::finish`) to complete the frame once all data is written.
pub struct Lz4Encoder<W: Write> {
    inner: FrameEncoder<W>,
}
impl<W: Write + Send> Lz4Encoder<W> {
    /// Creates an encoder that writes LZ4 frame data to `output`.
    ///
    /// `_options` is accepted for API stability; [`Lz4EncoderOptions`] has no settings yet.
    pub fn new(output: W, _options: &Lz4EncoderOptions) -> Self {
        Self {
            inner: FrameEncoder::new(output),
        }
    }

    /// Completes the LZ4 frame and returns the underlying writer.
    ///
    /// # Errors
    ///
    /// Returns an error if writing the end of the frame fails. An I/O error from the
    /// underlying writer is passed through with its original [`io::ErrorKind`]; other LZ4
    /// errors are wrapped.
    pub fn try_finish(self) -> io::Result<W> {
        // Use lz4_flex's `From<frame::Error> for io::Error` — the same conversion `?` applies
        // in `Encoder::finish` — which unwraps wrapped I/O errors. `io::Error::other` would
        // have relabelled every failure (e.g. `BrokenPipe`, `WriteZero`) as `Other`.
        self.inner.finish().map_err(io::Error::from)
    }
}
impl<W: Write + Send> Write for Lz4Encoder<W> {
    /// Hands `buf` to the wrapped [`FrameEncoder`] and reports how much it accepted.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let encoder = &mut self.inner;
        Write::write(encoder, buf)
    }

    /// Delegates flushing to the wrapped [`FrameEncoder`].
    fn flush(&mut self) -> io::Result<()> {
        Write::flush(&mut self.inner)
    }
}
impl<W: Write + Send> Encoder for Lz4Encoder<W> {
    /// Returns the LZ4 method identifier, [`method::LZ4`].
    fn method_id(&self) -> &'static [u8] {
        method::LZ4
    }

    /// Completes the LZ4 frame and drops the underlying writer.
    ///
    /// # Errors
    ///
    /// Returns the frame encoder's error converted with `io::Error::from`.
    fn finish(self: Box<Self>) -> io::Result<()> {
        // Explicit form of `?`: discard the returned writer, convert the lz4_flex error
        // through its `From` impl.
        self.inner.finish().map(drop).map_err(io::Error::from)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Compresses `data` into standard LZ4 frame data with default options.
    fn compress(data: &[u8]) -> Vec<u8> {
        let mut out = Vec::new();
        let mut encoder = Lz4Encoder::new(&mut out, &Lz4EncoderOptions::default());
        encoder.write_all(data).unwrap();
        encoder.try_finish().unwrap();
        out
    }

    /// Appends a 12-byte skippable-frame prefix announcing `payload.len()` as the compressed
    /// size, followed by `payload` itself.
    fn push_skippable(out: &mut Vec<u8>, payload: &[u8]) {
        out.extend_from_slice(&SKIPPABLE_FRAME_MAGIC.to_le_bytes());
        out.extend_from_slice(&LZ4_FRAME_SIZE.to_le_bytes());
        out.extend_from_slice(&(payload.len() as u32).to_le_bytes());
        out.extend_from_slice(payload);
    }

    /// Decodes `input` to the end, panicking on any error.
    fn decompress(input: Vec<u8>) -> Vec<u8> {
        let mut decoder = Lz4Decoder::new(Cursor::new(input)).unwrap();
        let mut out = Vec::new();
        decoder.read_to_end(&mut out).unwrap();
        out
    }

    #[test]
    fn round_trip() {
        let original = b"Hello, World! This is a test of LZ4 compression.";
        assert_eq!(decompress(compress(original)), original);
    }

    #[test]
    fn method_id() {
        let decoder = Lz4Decoder::new(Cursor::new(vec![0u8; 16])).unwrap();
        assert_eq!(decoder.method_id(), method::LZ4);
    }

    #[test]
    fn empty_input() {
        let mut decoder = Lz4Decoder::new(Cursor::new(Vec::new())).unwrap();
        let mut output = Vec::new();
        let n = decoder.read_to_end(&mut output).unwrap();
        assert_eq!(n, 0);
    }

    #[test]
    fn skippable_frame_single() {
        let original = b"test data for skippable frame";
        let mut framed = Vec::new();
        push_skippable(&mut framed, &compress(original));
        assert_eq!(decompress(framed), original);
    }

    #[test]
    fn skippable_frame_multi() {
        let data1 = b"first frame data";
        let data2 = b"second frame data";
        let mut framed = Vec::new();
        push_skippable(&mut framed, &compress(data1));
        push_skippable(&mut framed, &compress(data2));
        let expected = [&data1[..], &data2[..]].concat();
        assert_eq!(decompress(framed), expected);
    }

    #[test]
    fn skippable_frame_empty_payload() {
        let mut framed = Vec::new();
        push_skippable(&mut framed, &[]);
        let mut decoder = Lz4Decoder::new(Cursor::new(framed)).unwrap();
        let mut output = Vec::new();
        // Only checks that decoding an empty payload does not panic; the result is
        // deliberately not asserted.
        let _ = decoder.read_to_end(&mut output);
    }
}