use std::io::{self, Write};
use super::fast_lzma2_encode::{
self, ChunkResetMode, DEFAULT_BLOCK_SIZE, Token, encode_greedy, write_compressed_chunk,
};
use super::lzma_context::LzmaEncoderState;
use super::lzma_rc::LzmaRangeEncoder;
use super::{Encoder, method};
pub use super::radix_mf::{Match, RadixMatchFinder};
/// Speed/ratio trade-off used when no explicit match-finder depth is given.
///
/// [`FastLzma2Options::with_level`] picks `Fast` for levels 1-3, `Balanced`
/// for levels 4-6 and `Best` for everything above.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Strategy {
    /// Shallowest match search (the encoder uses depth 16 by default).
    Fast,
    /// Middle ground (the encoder uses depth 64 by default).
    #[default]
    Balanced,
    /// Deepest match search (the encoder uses depth 256 by default).
    Best,
}

/// Tunable parameters for the fast LZMA2 encoder.
#[derive(Debug, Clone)]
pub struct FastLzma2Options {
    /// Compression level; `with_level` clamps it to `1..=10`.
    pub level: u32,
    /// Dictionary size in bytes.
    pub dict_size: u32,
    /// Strategy that selects the default match-finder depth.
    pub strategy: Strategy,
    /// Requested thread count, if any (stored only; nothing in this block reads it).
    pub threads: Option<usize>,
    /// Explicit match-finder depth; overrides the strategy default when set.
    pub depth: Option<u32>,
}

impl Default for FastLzma2Options {
    /// Level 6, 8 MiB dictionary, balanced strategy, no thread/depth override.
    fn default() -> Self {
        Self {
            level: 6,
            dict_size: 8 * 1024 * 1024,
            strategy: Strategy::Balanced,
            threads: None,
            depth: None,
        }
    }
}

impl FastLzma2Options {
    /// Builds the preset for `level`, clamping it into `1..=10`.
    ///
    /// The dictionary doubles from 64 KiB (level 1) up to 64 MiB (level 9 and
    /// above, with level 4 at 2 MiB).
    pub fn with_level(level: u32) -> Self {
        let level = level.clamp(1, 10);
        let strategy = match level {
            1..=3 => Strategy::Fast,
            4..=6 => Strategy::Balanced,
            _ => Strategy::Best,
        };
        let dict_size = match level {
            1 => 64 * 1024,
            2 => 256 * 1024,
            3 => 1024 * 1024,
            4 => 2 * 1024 * 1024,
            5 => 4 * 1024 * 1024,
            6 => 8 * 1024 * 1024,
            7 => 16 * 1024 * 1024,
            8 => 32 * 1024 * 1024,
            // Levels 9 and 10 share the largest preset dictionary.
            _ => 64 * 1024 * 1024,
        };
        Self {
            level,
            dict_size,
            strategy,
            threads: None,
            depth: None,
        }
    }

    /// Sets the dictionary size in bytes.
    pub fn dict_size(mut self, size: u32) -> Self {
        self.dict_size = size;
        self
    }

    /// Sets the compression strategy.
    pub fn strategy(mut self, strategy: Strategy) -> Self {
        self.strategy = strategy;
        self
    }

    /// Sets the requested thread count.
    pub fn threads(mut self, threads: usize) -> Self {
        self.threads = Some(threads);
        self
    }

    /// Returns the one-byte LZMA2 dictionary-size property.
    ///
    /// A property value `p` in `0..40` stands for a dictionary of
    /// `(2 | (p & 1)) << (p / 2 + 11)` bytes, i.e. `2^n` or `2^n + 2^(n-1)`;
    /// `40` stands for 4 GiB - 1. Sizes below 4096 are treated as 4096.
    ///
    /// The result is the smallest property whose dictionary is at least
    /// `dict_size`, so a decoder never allocates less than the encoder was
    /// configured with.
    pub fn properties(&self) -> u8 {
        let dict = self.dict_size.max(4096);
        let log2 = 31 - dict.leading_zeros();
        let power = 1u32 << log2;
        // Property of the exact power of two at or below `dict`.
        let base = ((log2 - 12) * 2) as u8;
        if dict == power {
            base
        } else if dict <= power + (power >> 1) {
            // Fits in the 1.5 * 2^log2 step.
            base + 1
        } else {
            // Needs the next power of two (or the 4 GiB - 1 cap when log2 == 31).
            base + 2
        }
    }
}
/// Streaming encoder that buffers written bytes and emits them to `output`
/// as a sequence of chunks, terminated by an end marker on `try_finish`.
pub struct FastLzma2Encoder<W: Write> {
/// Sink that receives the encoded chunks and the end marker.
output: W,
/// Copy of the options the encoder was created with.
options: FastLzma2Options,
/// Match finder rebuilt for, and reset after, every chunk.
match_finder: RadixMatchFinder,
/// LZMA coder state; reset after every chunk.
encoder_state: LzmaEncoderState,
/// Bytes accepted by `write` that have not been compressed yet.
buffer: Vec<u8>,
/// `true` until the first chunk has been written.
first_chunk: bool,
/// Set once the end marker has been written.
finished: bool,
}
impl<W: Write> std::fmt::Debug for FastLzma2Encoder<W> {
    /// Prints the options, the pending buffer length and the two status
    /// flags; the remaining fields are elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let pending = self.buffer.len();
        let mut view = f.debug_struct("FastLzma2Encoder");
        view.field("options", &self.options);
        view.field("buffer_len", &pending);
        view.field("first_chunk", &self.first_chunk);
        view.field("finished", &self.finished);
        view.finish_non_exhaustive()
    }
}
impl<W: Write + Send> FastLzma2Encoder<W> {
    /// Largest slice of input handed to a single `compress_chunk` call.
    const MAX_CHUNK_SIZE: usize = 64 * 1024;

    /// Creates an encoder writing to `output`.
    ///
    /// The match-finder depth is `options.depth` when set; otherwise it is
    /// derived from the strategy (16 / 64 / 256 for fast / balanced / best).
    pub fn new(output: W, options: &FastLzma2Options) -> Self {
        let strategy_depth = match options.strategy {
            Strategy::Fast => 16,
            Strategy::Balanced => 64,
            Strategy::Best => 256,
        };
        let depth = options.depth.unwrap_or(strategy_depth);
        Self {
            output,
            options: options.clone(),
            match_finder: RadixMatchFinder::new(options.dict_size as usize, depth),
            encoder_state: LzmaEncoderState::new(3, 0, 2),
            buffer: Vec::with_capacity(DEFAULT_BLOCK_SIZE),
            first_chunk: true,
            finished: false,
        }
    }

    /// Returns the coder properties for `options` as a one-byte vector.
    pub fn properties(options: &FastLzma2Options) -> Vec<u8> {
        let prop = options.properties();
        vec![prop]
    }

    /// Splits `data` into pieces of at most `MAX_CHUNK_SIZE` bytes and
    /// compresses each in order, stopping at the first error.
    fn compress_block(&mut self, data: &[u8]) -> io::Result<()> {
        // `chunks` yields nothing for empty input, so that case is a no-op.
        data.chunks(Self::MAX_CHUNK_SIZE)
            .try_for_each(|piece| self.compress_chunk(piece))
    }

    /// Encodes one chunk and writes it in whichever form is smaller.
    ///
    /// The compressed form is used only when it is strictly shorter than the
    /// input; otherwise the raw bytes are written. Encoder state and match
    /// finder are reset afterwards, so every chunk is coded independently.
    fn compress_chunk(&mut self, data: &[u8]) -> io::Result<()> {
        if data.is_empty() {
            return Ok(());
        }

        self.match_finder.build(data);
        let tokens = encode_greedy(data, &self.match_finder);
        let compressed = self.encode_tokens(&tokens, data);

        if compressed.len() < data.len() {
            // 0x5D presumably encodes the lc=3, lp=0, pb=2 configuration
            // passed to `LzmaEncoderState::new` - TODO confirm.
            write_compressed_chunk(
                &mut self.output,
                &compressed,
                data.len(),
                ChunkResetMode::AllReset,
                Some(0x5D),
            )?;
        } else {
            fast_lzma2_encode::write_uncompressed_data(&mut self.output, data, self.first_chunk)?;
        }

        self.encoder_state.reset();
        self.match_finder.reset();
        self.first_chunk = false;
        Ok(())
    }

    /// Range-codes `tokens` (which describe `data`) and returns the bytes
    /// produced by the range encoder.
    fn encode_tokens(&mut self, tokens: &[Token], data: &[u8]) -> Vec<u8> {
        let mut rc = LzmaRangeEncoder::new();
        let mut cursor = 0usize;

        for token in tokens.iter() {
            match token {
                Token::Match { distance, length } => {
                    self.encoder_state
                        .encode_match(&mut rc, *distance, *length, cursor);
                    cursor += *length as usize;
                }
                Token::Literal(byte) => {
                    // Byte preceding the literal, or 0 at the start of the chunk.
                    let prev_byte = cursor.checked_sub(1).map_or(0, |i| data[i]);
                    // States >= 7 presumably mean "previous symbol was a match";
                    // in that case the byte at the most recent distance is
                    // supplied as context when it lies inside the chunk.
                    let match_byte = if self.encoder_state.state() >= 7 && cursor > 0 {
                        let reps = self.encoder_state.reps();
                        let dist = reps[0] + 1;
                        cursor.checked_sub(dist as usize).map(|i| data[i])
                    } else {
                        None
                    };
                    self.encoder_state
                        .encode_literal(&mut rc, *byte, cursor, prev_byte, match_byte);
                    cursor += 1;
                }
            }
        }

        rc.finish()
    }

    /// Compresses any buffered bytes, writes the end marker and hands back
    /// the underlying writer.
    pub fn try_finish(mut self) -> io::Result<W> {
        if !self.finished {
            if !self.buffer.is_empty() {
                let pending = std::mem::take(&mut self.buffer);
                self.compress_block(&pending)?;
            }
            fast_lzma2_encode::write_end_marker(&mut self.output)?;
            self.finished = true;
        }
        Ok(self.output)
    }

    /// Same as [`Self::try_finish`], but discards the writer.
    pub fn finish(self) -> io::Result<()> {
        self.try_finish().map(|_| ())
    }
}
impl<W: Write + Send> Write for FastLzma2Encoder<W> {
    /// Buffers `buf` and compresses every complete `DEFAULT_BLOCK_SIZE`
    /// block that has accumulated. Always accepts the whole of `buf`.
    ///
    /// # Errors
    ///
    /// Fails if the encoder is already finished or if compressing/writing a
    /// block fails.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        if self.finished {
            return Err(io::Error::other("encoder already finished"));
        }
        self.buffer.extend_from_slice(buf);

        if self.buffer.len() >= DEFAULT_BLOCK_SIZE {
            // Compress blocks straight out of the buffer and drop the
            // consumed prefix once, instead of draining (and shifting the
            // tail) and allocating a fresh Vec for every block.
            let mut pending = std::mem::take(&mut self.buffer);
            let mut consumed = 0;
            let mut outcome = Ok(());
            for block in pending.chunks_exact(DEFAULT_BLOCK_SIZE) {
                // A block counts as consumed even when compressing it fails.
                consumed += DEFAULT_BLOCK_SIZE;
                outcome = self.compress_block(block);
                if outcome.is_err() {
                    break;
                }
            }
            pending.drain(..consumed);
            self.buffer = pending;
            outcome?;
        }
        Ok(buf.len())
    }

    /// Compresses whatever is buffered (even a partial block) and flushes
    /// the underlying writer. Does not write the end marker.
    fn flush(&mut self) -> io::Result<()> {
        if !self.buffer.is_empty() {
            let data = std::mem::take(&mut self.buffer);
            self.compress_block(&data)?;
        }
        self.output.flush()
    }
}
impl<W: Write + Send> Encoder for FastLzma2Encoder<W> {
    /// Identifies the output as LZMA2.
    fn method_id(&self) -> &'static [u8] {
        method::LZMA2
    }

    /// Unboxes the encoder and delegates to its inherent `finish`.
    fn finish(self: Box<Self>) -> io::Result<()> {
        let encoder = *self;
        encoder.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `input` through a level-3 encoder and returns the finished stream.
    fn compress_level3(input: &[u8]) -> Vec<u8> {
        let mut stream = Vec::new();
        let mut encoder = FastLzma2Encoder::new(&mut stream, &FastLzma2Options::with_level(3));
        encoder.write_all(input).unwrap();
        let _ = encoder.try_finish().unwrap();
        stream
    }

    /// Reads `reader` to the end, panicking on any I/O error.
    fn read_all<R: std::io::Read>(mut reader: R) -> Vec<u8> {
        let mut bytes = Vec::new();
        std::io::Read::read_to_end(&mut reader, &mut bytes).unwrap();
        bytes
    }

    /// `len` bytes counting 0, 1, ..., 255, 0, 1, ...
    fn cyclic_bytes(len: usize) -> Vec<u8> {
        (0..=u8::MAX).cycle().take(len).collect()
    }

    #[test]
    fn test_options_default() {
        let defaults = FastLzma2Options::default();
        assert_eq!(defaults.level, 6);
        assert_eq!(defaults.dict_size, 8 * 1024 * 1024);
        assert_eq!(defaults.strategy, Strategy::Balanced);
    }

    #[test]
    fn test_options_with_level() {
        let fast = FastLzma2Options::with_level(1);
        assert_eq!(fast.level, 1);
        assert_eq!(fast.strategy, Strategy::Fast);

        let balanced = FastLzma2Options::with_level(5);
        assert_eq!(balanced.level, 5);
        assert_eq!(balanced.strategy, Strategy::Balanced);

        let best = FastLzma2Options::with_level(9);
        assert_eq!(best.level, 9);
        assert_eq!(best.strategy, Strategy::Best);
    }

    #[test]
    fn test_properties_encoding() {
        assert_eq!(FastLzma2Options::default().properties(), 22);

        let smallest = FastLzma2Options::with_level(1).dict_size(4096);
        assert_eq!(smallest.properties(), 0);
    }

    #[test]
    fn test_fast_lzma2_encoder_roundtrip() {
        let data = b"Hello, Fast LZMA2 World! This is a test of compression.";
        let compressed = compress_level3(data);
        assert!(!compressed.is_empty());

        let decompressed = read_all(lzma_rust2::Lzma2Reader::new(
            std::io::Cursor::new(&compressed),
            1024 * 1024,
            None,
        ));
        assert_eq!(&decompressed[..], &data[..]);
    }

    #[test]
    fn test_fast_lzma2_encoder_incompressible() {
        // Pseudo-random bytes from a linear congruential generator.
        let mut lcg = 12345u32;
        let data: Vec<u8> = std::iter::repeat_with(|| {
            lcg = lcg.wrapping_mul(1103515245).wrapping_add(12345);
            (lcg >> 16) as u8
        })
        .take(60_000)
        .collect();

        let compressed = compress_level3(&data);
        assert!(
            matches!(compressed[0], 0x01 | 0x02),
            "Expected uncompressed format for random data"
        );

        let decompressed = read_all(lzma_rust2::Lzma2Reader::new(
            std::io::Cursor::new(&compressed),
            8 * 1024 * 1024,
            None,
        ));
        assert_eq!(decompressed, data);
    }

    #[test]
    fn test_fast_lzma2_encoder_single_chunk() {
        let data = cyclic_bytes(60_000);
        let compressed = compress_level3(&data);

        let decompressed = read_all(lzma_rust2::Lzma2Reader::new(
            std::io::Cursor::new(&compressed),
            8 * 1024 * 1024,
            None,
        ));
        assert_eq!(decompressed, data);
    }

    #[test]
    fn test_fast_lzma2_encoder_medium_data() {
        let data = cyclic_bytes(100_000);
        let compressed = compress_level3(&data);

        let decompressed = read_all(lzma_rust2::Lzma2Reader::new(
            std::io::Cursor::new(&compressed),
            8 * 1024 * 1024,
            None,
        ));
        assert_eq!(decompressed, data);
    }

    #[test]
    fn test_fast_lzma2_encoder_large_data() {
        let data = cyclic_bytes(1_500_000);
        let compressed = compress_level3(&data);

        let decompressed = read_all(lzma_rust2::Lzma2Reader::new(
            std::io::Cursor::new(&compressed),
            8 * 1024 * 1024,
            None,
        ));
        assert_eq!(decompressed, data);
    }

    #[test]
    fn test_fast_lzma2_encoder_empty_data() {
        // With no input the stream consists of a single 0x00 byte.
        let compressed = compress_level3(b"");
        assert_eq!(compressed, vec![0x00]);
    }

    #[test]
    fn test_match_struct() {
        let found = Match::new(100, 5);
        assert_eq!(found.offset, 100);
        assert_eq!(found.length, 5);
    }

    #[test]
    fn test_radix_match_finder_new() {
        let finder = RadixMatchFinder::new(1024 * 1024, 64);
        assert_eq!(finder.dict_size(), 1024 * 1024);
    }

    #[test]
    fn test_radix_match_finder_integration() {
        let opts = FastLzma2Options::with_level(6);
        let mut finder = RadixMatchFinder::new(opts.dict_size as usize, 32);
        let data = b"Hello, World! Hello, World! Hello, World!";
        finder.build(data);

        // Position 14 starts the second "Hello, World!".
        let candidate = finder.get_match(data, 14);
        assert!(candidate.is_some());
        assert!(candidate.unwrap().length >= 13);
    }
}