use crate::bitwriter::BackwardBitWriter;
use crate::fse::{FseTable, FseTableEntry};
use crate::lz77::Lz77Sequence;
use oxiarc_core::error::{OxiArcError, Result};
/// One LZ77 sequence translated into Zstandard code / extra-bits form.
///
/// Each of the three quantities (literal length, match length, offset) is
/// split into a code (`*_code`, the symbol that is FSE- or RLE-coded) and
/// `*_extra_bits` raw bits carrying `*_extra_value`, as produced by
/// `encode_literal_length`, `encode_match_length` and `encode_offset`.
#[derive(Debug, Clone, Copy)]
struct ZstdSequence {
    /// Literal-length code (index into `LL_BASELINE` / `LL_EXTRA`).
    ll_code: u8,
    /// Number of raw bits following the literal-length code.
    ll_extra_bits: u8,
    /// Literal length minus `LL_BASELINE[ll_code]`.
    ll_extra_value: u32,
    /// Match-length code (index into `ML_BASELINE` / `ML_EXTRA`).
    ml_code: u8,
    /// Number of raw bits following the match-length code.
    ml_extra_bits: u8,
    /// Match length minus `ML_BASELINE[ml_code]`.
    ml_extra_value: u32,
    /// Offset code; `encode_offset` sets it equal to `of_extra_bits`.
    of_code: u8,
    /// Number of raw bits following the offset code.
    of_extra_bits: u8,
    /// `offset + 3` minus `1 << of_code` (see `encode_offset`).
    of_extra_value: u32,
}
/// Baseline literal length for each Zstandard literal-length code (index = code).
///
/// A literal length `v` is coded as the largest code `c` with
/// `LL_BASELINE[c] <= v`, followed by `LL_EXTRA[c]` raw bits holding
/// `v - LL_BASELINE[c]` (see `encode_literal_length`).
const LL_BASELINE: [u32; 36] = [
    0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 18, 20, 22, 24, 28, 32, 40, 48, 64,
    128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536,
];
/// Number of extra bits that follow each literal-length code (index = code).
const LL_EXTRA: [u8; 36] = [
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 3, 3, 4, 6, 7, 8, 9, 10, 11,
    12, 13, 14, 15, 16,
];
/// Baseline match length for each Zstandard match-length code (index = code).
///
/// Code 0 starts at 3, the minimum match length enforced by
/// `encode_match_length`. A match length `v` uses the largest code `c` with
/// `ML_BASELINE[c] <= v`, followed by `ML_EXTRA[c]` raw bits holding
/// `v - ML_BASELINE[c]`.
const ML_BASELINE: [u32; 53] = [
    3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,
    28, 29, 30, 31, 32, 33, 34, 35, 37, 39, 41, 43, 47, 51, 59, 67, 83, 99, 131, 259, 515, 1027,
    2051, 4099, 8195, 16387, 32771, 65539,
];
/// Number of extra bits that follow each match-length code (index = code).
const ML_EXTRA: [u8; 53] = [
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    1, 1, 1, 1, 2, 2, 3, 3, 4, 4, 5, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
];
/// Compression mode chosen for one of the three sequence code streams
/// (literal lengths, offsets, match lengths).
///
/// Only two modes are produced by this encoder; their 2-bit field values in
/// the symbol-compression-modes byte are given by `mode_to_bits`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SequenceCompressionMode {
    /// Codes are FSE-coded with the predefined default distribution (bits `0`).
    Predefined,
    /// Every sequence uses the same code, stored once as the payload (bits `1`).
    Rle(u8),
}
pub fn encode_compressed_block(sequences: &[Lz77Sequence]) -> Result<Vec<u8>> {
let literals: Vec<u8> = sequences
.iter()
.flat_map(|s| s.literals.iter().copied())
.collect();
let literals_section = encode_literals_section(&literals)?;
let ref_sequences: Vec<&Lz77Sequence> =
sequences.iter().filter(|s| s.match_length > 0).collect();
let zstd_sequences = convert_sequences(&ref_sequences)?;
let sequences_section = encode_sequences_section(&zstd_sequences)?;
let mut block = Vec::with_capacity(literals_section.len() + sequences_section.len());
block.extend_from_slice(&literals_section);
block.extend_from_slice(&sequences_section);
Ok(block)
}
/// Builds the Literals Section for a block.
///
/// Empty input yields the one-byte header of a zero-length Raw section. When
/// every byte is identical the RLE form is used; otherwise the bytes are
/// stored Raw.
fn encode_literals_section(literals: &[u8]) -> Result<Vec<u8>> {
    match literals.split_first() {
        None => Ok(vec![0]),
        Some((&head, rest)) if rest.iter().all(|&b| b == head) => encode_rle_literals(literals),
        Some(_) => encode_raw_literals(literals),
    }
}
fn encode_raw_literals(literals: &[u8]) -> Result<Vec<u8>> {
let size = literals.len();
let mut out = Vec::with_capacity(3 + size);
if size < 32 {
out.push((size as u8) << 3); } else if size < 4096 {
let header: u16 = (0b01 << 2) | ((size as u16) << 4);
out.push((header & 0xFF) as u8);
out.push((header >> 8) as u8);
} else {
let header: u32 = (0b11 << 2) | ((size as u32) << 4);
out.push((header & 0xFF) as u8);
out.push(((header >> 8) & 0xFF) as u8);
out.push(((header >> 16) & 0xFF) as u8);
}
out.extend_from_slice(literals);
Ok(out)
}
fn encode_rle_literals(literals: &[u8]) -> Result<Vec<u8>> {
let byte = literals[0];
let size = literals.len();
let mut out = Vec::with_capacity(4);
if size < 32 {
out.push(((size as u8) << 3) | 0b01); } else if size < 4096 {
let header: u16 = 0b01 | (0b01 << 2) | ((size as u16) << 4);
out.push((header & 0xFF) as u8);
out.push((header >> 8) as u8);
} else {
let header: u32 = 0b01 | (0b11 << 2) | ((size as u32) << 4);
out.push((header & 0xFF) as u8);
out.push(((header >> 8) & 0xFF) as u8);
out.push(((header >> 16) & 0xFF) as u8);
}
out.push(byte);
Ok(out)
}
#[allow(dead_code)]
fn encode_compressed_literals(regen_size: usize, table: &[u8], streams: &[u8]) -> Result<Vec<u8>> {
let compressed_size = table.len() + streams.len();
let mut out = Vec::with_capacity(5 + compressed_size);
if regen_size < 1024 && compressed_size < 1024 {
let header: u32 = 0b10 | ((regen_size as u32) << 4)
| ((compressed_size as u32) << 14);
out.push((header & 0xFF) as u8);
out.push(((header >> 8) & 0xFF) as u8);
out.push(((header >> 16) & 0xFF) as u8);
} else if regen_size < 16384 && compressed_size < 16384 {
let header: u32 = 0b10 | (0b10 << 2) | ((regen_size as u32) << 4)
| ((compressed_size as u32) << 18);
out.push((header & 0xFF) as u8);
out.push(((header >> 8) & 0xFF) as u8);
out.push(((header >> 16) & 0xFF) as u8);
out.push(((header >> 24) & 0xFF) as u8);
} else {
let header: u64 = 0b10 | (0b11 << 2) | ((regen_size as u64) << 4)
| ((compressed_size as u64) << 22);
out.push((header & 0xFF) as u8);
out.push(((header >> 8) & 0xFF) as u8);
out.push(((header >> 16) & 0xFF) as u8);
out.push(((header >> 24) & 0xFF) as u8);
out.push(((header >> 32) & 0xFF) as u8);
}
out.extend_from_slice(table);
out.extend_from_slice(streams);
Ok(out)
}
/// Translates match-bearing LZ77 sequences into Zstandard code form.
///
/// For each input, the literal length, match length and offset are converted
/// in that order. The first conversion error aborts the whole call.
fn convert_sequences(sequences: &[&Lz77Sequence]) -> Result<Vec<ZstdSequence>> {
    sequences
        .iter()
        .map(|lz| -> Result<ZstdSequence> {
            let (ll_code, ll_extra_bits, ll_extra_value) =
                encode_literal_length(lz.literals.len() as u32)?;
            let (ml_code, ml_extra_bits, ml_extra_value) =
                encode_match_length(lz.match_length as u32)?;
            let (of_code, of_extra_bits, of_extra_value) = encode_offset(lz.offset as u32)?;
            Ok(ZstdSequence {
                ll_code,
                ll_extra_bits,
                ll_extra_value,
                ml_code,
                ml_extra_bits,
                ml_extra_value,
                of_code,
                of_extra_bits,
                of_extra_value,
            })
        })
        .collect()
}
fn encode_literal_length(value: u32) -> Result<(u8, u8, u32)> {
for (code, (&baseline, &extra)) in LL_BASELINE.iter().zip(LL_EXTRA.iter()).enumerate().rev() {
if value >= baseline {
let extra_value = value - baseline;
return Ok((code as u8, extra, extra_value));
}
}
Ok((0, 0, value))
}
/// Maps a match length to its Zstandard match-length code.
///
/// Returns `(code, extra_bits, extra_value)`, where:
/// * `code` is the largest index with `ML_BASELINE[code] <= value`;
/// * `extra_bits = ML_EXTRA[code]`;
/// * `extra_value = value - ML_BASELINE[code]`.
///
/// # Errors
/// Returns an error when `value` is below the minimum match length 3, or above
/// the largest encodable match length (`65539 + 0xFFFF`).
fn encode_match_length(value: u32) -> Result<(u8, u8, u32)> {
    if value < 3 {
        return Err(OxiArcError::CorruptedData {
            offset: 0,
            message: format!("match length {} is less than minimum 3", value),
        });
    }
    // Baselines are ascending, so the last one not above `value` picks the code.
    let code = ML_BASELINE
        .iter()
        .rposition(|&baseline| baseline <= value)
        .ok_or_else(|| OxiArcError::CorruptedData {
            offset: 0,
            message: format!("could not encode match length {}", value),
        })?;
    let extra_bits = ML_EXTRA[code];
    let extra_value = value - ML_BASELINE[code];
    // Only the last code can overflow its extra-bits field.
    if u64::from(extra_value) >> extra_bits != 0 {
        return Err(OxiArcError::CorruptedData {
            offset: 0,
            message: format!("match length {} exceeds the maximum encodable value", value),
        });
    }
    Ok((code as u8, extra_bits, extra_value))
}
/// Converts a match offset into Zstandard offset-code form.
///
/// New offsets are stored as `offset_value = offset + 3`, because values 1-3
/// are reserved for repeat offsets. Then:
/// * the code is `floor(log2(offset_value))`;
/// * the number of extra bits equals the code;
/// * the extra value is `offset_value - (1 << code)`.
///
/// Returns `(code, extra_bits, extra_value)`.
///
/// # Errors
/// Returns an error if `offset` is 0, or if `offset + 3` does not fit in `u32`.
fn encode_offset(offset: u32) -> Result<(u8, u8, u32)> {
    if offset == 0 {
        return Err(OxiArcError::CorruptedData {
            offset: 0,
            message: "offset must be >= 1".to_string(),
        });
    }
    // Checked: an unchecked `offset + 3` panics in debug builds and wraps to a
    // tiny (wrong) code in release builds for offsets near u32::MAX.
    let offset_value = offset
        .checked_add(3)
        .ok_or_else(|| OxiArcError::CorruptedData {
            offset: 0,
            message: format!("offset {} is too large to encode", offset),
        })?;
    let code = 31 - offset_value.leading_zeros();
    let extra_value = offset_value - (1u32 << code);
    Ok((code as u8, code as u8, extra_value))
}
/// Builds the Sequences Section, in this order:
/// 1. the sequence count;
/// 2. the symbol-compression-modes byte;
/// 3. the per-stream table payloads;
/// 4. the bitstream.
///
/// The count uses a variable-length form:
/// * below 128: one byte;
/// * below 0x7F00: two bytes, `(count >> 8) + 128` then the low byte;
/// * otherwise: `255`, then `count - 0x7F00` as a little-endian `u16`.
///
/// An empty input yields the single byte `0`, which ends the section.
fn encode_sequences_section(sequences: &[ZstdSequence]) -> Result<Vec<u8>> {
    let count = sequences.len();
    if count == 0 {
        return Ok(vec![0]);
    }

    let mut section = match count {
        0..=127 => vec![count as u8],
        128..=0x7EFF => vec![(count >> 8) as u8 + 128, count as u8],
        _ => {
            let excess = count - 0x7F00;
            vec![255, excess as u8, (excess >> 8) as u8]
        }
    };

    let ll_mode = choose_mode_for_codes(sequences.iter().map(|s| s.ll_code));
    let of_mode = choose_mode_for_codes(sequences.iter().map(|s| s.of_code));
    let ml_mode = choose_mode_for_codes(sequences.iter().map(|s| s.ml_code));

    // Mode fields: literal lengths in bits 7-6, offsets in 5-4, match lengths
    // in 3-2; bits 1-0 stay zero.
    let modes_byte = [(&ll_mode, 6u8), (&of_mode, 4), (&ml_mode, 2)]
        .iter()
        .fold(0u8, |acc, &(mode, shift)| acc | (mode_to_bits(mode) << shift));
    section.push(modes_byte);

    for mode in [&ll_mode, &of_mode, &ml_mode].iter() {
        write_mode_table_data(&mut section, mode);
    }

    section.extend_from_slice(&encode_sequences_bitstream(
        sequences, &ll_mode, &of_mode, &ml_mode,
    )?);
    Ok(section)
}
/// Picks RLE when every code in `codes` is the same (including a single code),
/// and the predefined FSE table otherwise or when `codes` is empty.
fn choose_mode_for_codes(mut codes: impl Iterator<Item = u8>) -> SequenceCompressionMode {
    codes
        .next()
        .map_or(SequenceCompressionMode::Predefined, |head| {
            if codes.all(|code| code == head) {
                SequenceCompressionMode::Rle(head)
            } else {
                SequenceCompressionMode::Predefined
            }
        })
}
fn mode_to_bits(mode: &SequenceCompressionMode) -> u8 {
match mode {
SequenceCompressionMode::Predefined => 0,
SequenceCompressionMode::Rle(_) => 1,
}
}
/// Appends the table payload that follows the modes byte for one stream.
///
/// RLE writes its single repeated code; the predefined table writes nothing.
fn write_mode_table_data(out: &mut Vec<u8>, mode: &SequenceCompressionMode) {
    let payload = match *mode {
        SequenceCompressionMode::Predefined => None,
        SequenceCompressionMode::Rle(symbol) => Some(symbol),
    };
    out.extend(payload);
}
/// Encoder-side view of an FSE decoding table.
///
/// Rather than a dedicated compression table, the encoder keeps the decoding
/// table plus, for every symbol, the decoder states that emit it. It then
/// chooses states by inverting the decoder's transitions (see
/// `compute_fse_states_backward`).
struct FseEncodingTable {
    /// `symbol_states[s]` lists the decoder states whose entry decodes to
    /// symbol `s`, sorted by baseline.
    symbol_states: Vec<Vec<FseEncState>>,
    /// The wrapped decoding table (source of baselines and bit counts).
    decoding_table: FseTable,
}
/// One decoder state that emits a given symbol.
#[derive(Debug, Clone, Copy)]
struct FseEncState {
    /// Index of the state in the decoding table.
    state: usize,
    /// Baseline of that table entry, copied so a symbol's states can be
    /// sorted by it.
    baseline: u16,
}
impl FseEncodingTable {
fn from_decoding_table(table: FseTable) -> Self {
let table_size = 1usize << table.accuracy_log();
let mut max_symbol = 0u8;
for i in 0..table_size {
let entry = table.get(i);
if entry.symbol > max_symbol {
max_symbol = entry.symbol;
}
}
let mut symbol_states = vec![Vec::new(); max_symbol as usize + 1];
for i in 0..table_size {
let entry = table.get(i);
symbol_states[entry.symbol as usize].push(FseEncState {
state: i,
baseline: entry.baseline,
});
}
for states in &mut symbol_states {
states.sort_by_key(|s| s.baseline);
}
Self {
symbol_states,
decoding_table: table,
}
}
fn accuracy_log(&self) -> u8 {
self.decoding_table.accuracy_log()
}
fn get_entry(&self, state: usize) -> &FseTableEntry {
self.decoding_table.get(state)
}
}
fn encode_sequences_bitstream(
sequences: &[ZstdSequence],
ll_mode: &SequenceCompressionMode,
of_mode: &SequenceCompressionMode,
ml_mode: &SequenceCompressionMode,
) -> Result<Vec<u8>> {
let mut writer = BackwardBitWriter::new();
let ll_enc = build_predefined_enc_table(ll_mode, TableCategory::LiteralLength);
let of_enc = build_predefined_enc_table(of_mode, TableCategory::Offset);
let ml_enc = build_predefined_enc_table(ml_mode, TableCategory::MatchLength);
let n = sequences.len();
if n == 0 {
return Ok(writer.finish());
}
let ll_states =
compute_fse_states_backward(&ll_enc, sequences.iter().map(|s| s.ll_code).collect());
let of_states =
compute_fse_states_backward(&of_enc, sequences.iter().map(|s| s.of_code).collect());
let ml_states =
compute_fse_states_backward(&ml_enc, sequences.iter().map(|s| s.ml_code).collect());
if let Some(ref enc) = ll_enc {
writer.write_bits(ll_states[0] as u64, enc.accuracy_log());
}
if let Some(ref enc) = of_enc {
writer.write_bits(of_states[0] as u64, enc.accuracy_log());
}
if let Some(ref enc) = ml_enc {
writer.write_bits(ml_states[0] as u64, enc.accuracy_log());
}
for idx in 0..n {
let seq = &sequences[idx];
if let Some(ref enc) = of_enc {
let entry = enc.get_entry(of_states[idx]);
if entry.num_bits > 0 {
let target_next = if idx + 1 < n {
of_states[idx + 1]
} else {
entry.baseline as usize
};
let bits_val = target_next.wrapping_sub(entry.baseline as usize);
writer.write_bits(bits_val as u64, entry.num_bits);
}
}
if let Some(ref enc) = ml_enc {
let entry = enc.get_entry(ml_states[idx]);
if entry.num_bits > 0 {
let target_next = if idx + 1 < n {
ml_states[idx + 1]
} else {
entry.baseline as usize
};
let bits_val = target_next.wrapping_sub(entry.baseline as usize);
writer.write_bits(bits_val as u64, entry.num_bits);
}
}
if let Some(ref enc) = ll_enc {
let entry = enc.get_entry(ll_states[idx]);
if entry.num_bits > 0 {
let target_next = if idx + 1 < n {
ll_states[idx + 1]
} else {
entry.baseline as usize
};
let bits_val = target_next.wrapping_sub(entry.baseline as usize);
writer.write_bits(bits_val as u64, entry.num_bits);
}
}
if seq.ll_extra_bits > 0 {
writer.write_bits(seq.ll_extra_value as u64, seq.ll_extra_bits);
}
if seq.ml_extra_bits > 0 {
writer.write_bits(seq.ml_extra_value as u64, seq.ml_extra_bits);
}
if seq.of_extra_bits > 0 {
writer.write_bits(seq.of_extra_value as u64, seq.of_extra_bits);
}
}
Ok(writer.finish())
}
fn compute_fse_states_backward(enc: &Option<FseEncodingTable>, symbols: Vec<u8>) -> Vec<usize> {
let enc = match enc {
Some(e) => e,
None => return Vec::new(),
};
let n = symbols.len();
if n == 0 {
return Vec::new();
}
let mut states = vec![0usize; n];
let last_sym = symbols[n - 1] as usize;
states[n - 1] = if last_sym < enc.symbol_states.len() && !enc.symbol_states[last_sym].is_empty()
{
enc.symbol_states[last_sym][0].state
} else {
0
};
for i in (0..n.saturating_sub(1)).rev() {
let sym = symbols[i] as usize;
let target_next = states[i + 1];
if sym >= enc.symbol_states.len() || enc.symbol_states[sym].is_empty() {
states[i] = 0;
continue;
}
let mut found = false;
for enc_state in &enc.symbol_states[sym] {
let entry = enc.get_entry(enc_state.state);
let range_size = 1usize << entry.num_bits;
let baseline = entry.baseline as usize;
if target_next >= baseline && target_next < baseline + range_size {
states[i] = enc_state.state;
found = true;
break;
}
}
if !found {
states[i] = enc.symbol_states[sym][0].state;
}
}
states
}
/// Selects which predefined FSE distribution `build_predefined_enc_table` builds.
enum TableCategory {
    /// Literal-length codes (36 symbols, accuracy log 6).
    LiteralLength,
    /// Offset codes (29 symbols, accuracy log 5).
    Offset,
    /// Match-length codes (53 symbols, accuracy log 6).
    MatchLength,
}
/// Builds the encoder view of the predefined (default) FSE table for `category`.
///
/// Returns `None` when `mode` is RLE (no table is needed), and also when
/// `FseTable::new` rejects the distribution.
fn build_predefined_enc_table(
    mode: &SequenceCompressionMode,
    category: TableCategory,
) -> Option<FseEncodingTable> {
    // Default normalized distributions; `-1` entries are "less than one"
    // probabilities.
    const LL_DEFAULT: [i16; 36] = [
        4, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 1, 1, 1,
        1, 1, -1, -1, -1, -1,
    ];
    const OF_DEFAULT: [i16; 29] = [
        1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1,
    ];
    const ML_DEFAULT: [i16; 53] = [
        1, 4, 3, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1,
    ];

    match *mode {
        SequenceCompressionMode::Rle(_) => return None,
        SequenceCompressionMode::Predefined => {}
    }

    let decoding = match category {
        TableCategory::LiteralLength => FseTable::new(6, &LL_DEFAULT),
        TableCategory::Offset => FseTable::new(5, &OF_DEFAULT),
        TableCategory::MatchLength => FseTable::new(6, &ML_DEFAULT),
    }
    .ok()?;
    Some(FseEncodingTable::from_decoding_table(decoding))
}
/// Counts how often each literal-length, offset and match-length code occurs.
///
/// Returns `(ll_freqs, of_freqs, ml_freqs)`, each indexed by code. The
/// histograms start at the predefined alphabet sizes (36, 29 and 53 entries)
/// and grow when a larger code appears. This matters because `encode_offset`
/// can produce offset codes up to 31, beyond the 29 predefined offset symbols.
// Not called anywhere in this file, hence the allow; presumably intended for
// building FSE_Compressed tables.
#[allow(dead_code)]
fn count_symbol_frequencies(sequences: &[ZstdSequence]) -> (Vec<u32>, Vec<u32>, Vec<u32>) {
    // Increments `freqs[code]`, growing the histogram if `code` is out of range.
    fn bump(freqs: &mut Vec<u32>, code: u8) {
        let idx = usize::from(code);
        if idx >= freqs.len() {
            freqs.resize(idx + 1, 0);
        }
        freqs[idx] += 1;
    }

    let mut ll_freqs = vec![0u32; 36];
    let mut of_freqs = vec![0u32; 29];
    let mut ml_freqs = vec![0u32; 53];
    for seq in sequences {
        bump(&mut ll_freqs, seq.ll_code);
        bump(&mut of_freqs, seq.of_code);
        bump(&mut ml_freqs, seq.ml_code);
    }
    (ll_freqs, of_freqs, ml_freqs)
}
#[cfg(test)]
mod tests {
    // Unit tests for the block encoder: length/offset code mapping,
    // literals-section headers, mode selection, frequency counting, and
    // whole-block smoke tests.
    use super::*;

    // Literal lengths 0..=15 map directly to codes 0..=15 with no extra bits.
    #[test]
    fn test_encode_literal_length_small() {
        for val in 0..16u32 {
            let (code, extra, extra_val) = encode_literal_length(val).unwrap();
            assert_eq!(code, val as u8);
            assert_eq!(extra, 0);
            assert_eq!(extra_val, 0);
        }
    }

    // 18 and 19 share code 17 (baseline 18, one extra bit).
    #[test]
    fn test_encode_literal_length_large() {
        let (code, extra_bits, extra_val) = encode_literal_length(18).unwrap();
        assert_eq!(code, 17);
        assert_eq!(extra_bits, 1);
        assert_eq!(extra_val, 0);
        let (code, extra_bits, extra_val) = encode_literal_length(19).unwrap();
        assert_eq!(code, 17);
        assert_eq!(extra_bits, 1);
        assert_eq!(extra_val, 1);
    }

    // The minimum match length 3 is code 0 with no extra bits.
    #[test]
    fn test_encode_match_length_minimum() {
        let (code, extra, extra_val) = encode_match_length(3).unwrap();
        assert_eq!(code, 0);
        assert_eq!(extra, 0);
        assert_eq!(extra_val, 0);
    }

    // Match lengths below 3 are rejected.
    #[test]
    fn test_encode_match_length_too_small() {
        assert!(encode_match_length(2).is_err());
        assert!(encode_match_length(0).is_err());
    }

    // Offsets are coded as offset + 3: 1 -> 4 (code 2), 2 -> 5 (code 2,
    // extra 1), 5 -> 8 (code 3).
    #[test]
    fn test_encode_offset() {
        let (code, extra_bits, extra_val) = encode_offset(1).unwrap();
        assert_eq!(code, 2);
        assert_eq!(extra_bits, 2);
        assert_eq!(extra_val, 0);
        let (code, extra_bits, extra_val) = encode_offset(2).unwrap();
        assert_eq!(code, 2);
        assert_eq!(extra_bits, 2);
        assert_eq!(extra_val, 1);
        let (code, extra_bits, extra_val) = encode_offset(5).unwrap();
        assert_eq!(code, 3);
        assert_eq!(extra_bits, 3);
        assert_eq!(extra_val, 0);
    }

    // Offset 0 is invalid for a new offset.
    #[test]
    fn test_encode_offset_zero_fails() {
        assert!(encode_offset(0).is_err());
    }

    // Sizes below 32 use a single header byte holding size << 3.
    #[test]
    fn test_encode_raw_literals_small() {
        let literals = b"Hello";
        let encoded = encode_raw_literals(literals).unwrap();
        assert_eq!(encoded[0], (5u8) << 3);
        assert_eq!(&encoded[1..], b"Hello");
    }

    // Sizes 32..4096 use the 2-byte header (Size_Format 01).
    #[test]
    fn test_encode_raw_literals_medium() {
        let literals = vec![0xAB; 100];
        let encoded = encode_raw_literals(&literals).unwrap();
        let header: u16 = (0b01 << 2) | ((100u16) << 4);
        assert_eq!(encoded[0], (header & 0xFF) as u8);
        assert_eq!(encoded[1], (header >> 8) as u8);
        assert_eq!(encoded.len(), 2 + 100);
    }

    // RLE stores a one-byte header (type bits 01) plus the repeated byte.
    #[test]
    fn test_encode_rle_literals() {
        let literals = vec![0xCC; 10];
        let encoded = encode_rle_literals(&literals).unwrap();
        assert_eq!(encoded[0], (10u8 << 3) | 0b01);
        assert_eq!(encoded[1], 0xCC);
        assert_eq!(encoded.len(), 2);
    }

    // No literals: a single zero byte (Raw, size 0).
    #[test]
    fn test_encode_literals_section_empty() {
        let encoded = encode_literals_section(&[]).unwrap();
        assert_eq!(encoded, vec![0]);
    }

    // Identical bytes select the RLE form.
    #[test]
    fn test_encode_literals_section_rle() {
        let literals = vec![0xFF; 20];
        let encoded = encode_literals_section(&literals).unwrap();
        // Low two header bits are the literals block type; 0b01 is RLE.
        assert_eq!(encoded[0] & 0x03, 0x01);
    }

    // Zero sequences: a single zero count byte ends the section.
    #[test]
    fn test_encode_sequences_section_empty() {
        let encoded = encode_sequences_section(&[]).unwrap();
        assert_eq!(encoded, vec![0]);
    }

    // Uniform codes select RLE with that code.
    #[test]
    fn test_choose_mode_all_same() {
        let mode = choose_mode_for_codes([5u8, 5, 5, 5].iter().copied());
        assert_eq!(mode, SequenceCompressionMode::Rle(5));
    }

    // Mixed codes fall back to the predefined table.
    #[test]
    fn test_choose_mode_different() {
        let mode = choose_mode_for_codes([1u8, 2, 3].iter().copied());
        assert_eq!(mode, SequenceCompressionMode::Predefined);
    }

    // Histograms are indexed by code.
    #[test]
    fn test_count_symbol_frequencies() {
        let seqs = vec![
            ZstdSequence {
                ll_code: 0,
                ll_extra_bits: 0,
                ll_extra_value: 0,
                ml_code: 0,
                ml_extra_bits: 0,
                ml_extra_value: 0,
                of_code: 1,
                of_extra_bits: 1,
                of_extra_value: 0,
            },
            ZstdSequence {
                ll_code: 0,
                ll_extra_bits: 0,
                ll_extra_value: 0,
                ml_code: 1,
                ml_extra_bits: 0,
                ml_extra_value: 0,
                of_code: 1,
                of_extra_bits: 1,
                of_extra_value: 0,
            },
        ];
        let (ll, of, ml) = count_symbol_frequencies(&seqs);
        assert_eq!(ll[0], 2);
        assert_eq!(of[1], 2);
        assert_eq!(ml[0], 1);
        assert_eq!(ml[1], 1);
    }

    // Smoke test: one sequence with literals and a minimum-length match.
    #[test]
    fn test_encode_compressed_block_simple() {
        let sequences = vec![Lz77Sequence {
            literals: b"Hello".to_vec(),
            match_length: 3,
            offset: 1,
        }];
        let block = encode_compressed_block(&sequences).unwrap();
        assert!(!block.is_empty());
    }

    // Smoke test: only trailing literals, so no sequences are emitted.
    #[test]
    fn test_encode_compressed_block_literals_only() {
        let sequences = vec![Lz77Sequence {
            literals: b"Trailing literals".to_vec(),
            match_length: 0,
            offset: 0,
        }];
        let block = encode_compressed_block(&sequences).unwrap();
        assert!(!block.is_empty());
    }
}