extern crate alloc;
use alloc::vec;
use alloc::vec::Vec;
use crate::error::Error;
use crate::lha::bits::{BitReader, BitWriter};
use crate::lha::huffman::{HuffTable, assign_lengths, lengths_to_codes};
/// Size of the literal/length ("C") alphabet: codes below 256 are literal
/// bytes, codes from 256 upwards encode match lengths starting at `MIN_MATCH`.
const NC: usize = 510;
/// Width in bits of the count (and single-symbol) field that precedes the
/// literal/length code lengths.
const CBIT: u32 = 9;
/// Size of the temporary ("T") alphabet used to code the literal/length
/// code lengths.
const NT: usize = 19;
/// Width in bits of the count (and single-symbol) field of the temporary table.
const TBIT: u32 = 5;
/// Shortest match length that is emitted as a match token.
const MIN_MATCH: usize = 3;
/// Longest match length that can be encoded or accepted.
const MAX_MATCH: usize = 256;
/// Second argument handed to `HuffTable::build` for the literal/length and
/// position tables — presumably the lookup-table width; confirm in `huffman`.
const TABLE_BITS: u32 = 12;
/// Second argument handed to `HuffTable::build` for the temporary table —
/// presumably the lookup-table width; confirm in `huffman`.
const PT_TABLE_BITS: u32 = 8;
/// Number of temporary-table lengths after which a 2-bit count of skipped
/// (zero) lengths is stored in the stream.
const SPECIAL_INDEX: usize = 3;
/// Per-method tuning parameters for the static-Huffman LHA codec.
#[derive(Clone, Copy)]
pub struct Params {
    /// Size in bytes of the sliding-window ring buffer used by the decoder.
    pub ring_size: usize,
    /// Width in bits of the count field of the position table header.
    pub pbit: u32,
    /// Number of symbols in the position alphabet.
    pub np: usize,
}

impl Params {
    /// Returns the parameter set for the method called `name`.
    ///
    /// `"lh4"`, `"lh6"` and `"lh7"` have dedicated entries; every other name
    /// falls back to the default entry.
    // NOTE(review): the window sizes below are reproduced as found; confirm
    // them against the method definitions the rest of the crate relies on.
    pub const fn for_method(name: &str) -> Params {
        let (window_bits, pbit, np): (u32, u32, usize) = match name.as_bytes() {
            b"lh4" => (12, 4, 14),
            b"lh6" => (16, 5, 17),
            b"lh7" => (17, 5, 18),
            _ => (14, 4, 14),
        };
        Params {
            ring_size: 1 << window_bits,
            pbit,
            np,
        }
    }
}
/// Reads one code length from the bit stream.
///
/// The value is stored in 3 bits; a stored 7 is extended by a run of 1 bits
/// (each adding one) that is terminated by a 0 bit.
///
/// # Errors
/// Returns `Error::InvalidHuffmanTree` once the extended value exceeds
/// `huffman::MAX_BITS`.
fn read_length_value(br: &mut BitReader<'_>) -> Result<u8, Error> {
    let mut len = br.get_bits(3);
    if len != 7 {
        return Ok(len as u8);
    }
    // Escape value: keep counting continuation bits until the 0 terminator.
    while br.get_bits(1) != 0 {
        len += 1;
        if len > super::huffman::MAX_BITS {
            return Err(Error::InvalidHuffmanTree);
        }
    }
    Ok(len as u8)
}
/// Writes one code length in the format `read_length_value` expects.
///
/// Lengths below 7 take 3 bits. Longer ones are written as 7 followed by
/// `len - 7` one-bits and a terminating zero-bit.
fn write_length_value(bw: &mut BitWriter, len: u8) {
    let value = u32::from(len);
    if value < 7 {
        bw.put_bits(3, value);
        return;
    }
    bw.put_bits(3, 7);
    for _ in 7..value {
        bw.put_bits(1, 1);
    }
    bw.put_bits(1, 0);
}
fn read_temp_table(br: &mut BitReader<'_>) -> Result<HuffTable, Error> {
let n = br.get_bits(TBIT) as usize;
if n == 0 {
let sym = br.get_bits(TBIT) as u16;
return HuffTable::build_single(NT, sym, PT_TABLE_BITS);
}
if n > NT {
return Err(Error::Corrupt);
}
let mut lens = vec![0u8; NT];
let mut i = 0usize;
while i < n {
lens[i] = read_length_value(br)?;
i += 1;
if i == SPECIAL_INDEX {
let skip = br.get_bits(2) as usize;
for _ in 0..skip {
if i >= NT {
return Err(Error::Corrupt);
}
lens[i] = 0;
i += 1;
}
}
}
HuffTable::build(&lens, PT_TABLE_BITS)
}
/// Writes the temporary table: the count `NT` followed by the first `NT`
/// code lengths of `lens`.
///
/// A 2-bit skip count of zero is written after the first `SPECIAL_INDEX`
/// lengths, so no length is ever skipped.
///
/// # Panics
/// Panics if `lens` holds fewer than `NT` entries.
fn write_temp_table(bw: &mut BitWriter, lens: &[u8]) {
    bw.put_bits(TBIT, NT as u32);
    for (idx, &len) in lens[..NT].iter().enumerate() {
        write_length_value(bw, len);
        if idx + 1 == SPECIAL_INDEX {
            bw.put_bits(2, 0);
        }
    }
}
/// Reads the literal/length code lengths, decoding them with `temp`.
///
/// The header is a `CBIT`-bit count `n`. A count of zero means a
/// single-symbol table: the symbol index follows and its slot is set to
/// `SINGLE_MARKER`. Otherwise temporary symbols are decoded until `n` slots
/// are covered:
///
/// * `0` — one zero length,
/// * `1` — a run of `3 + <4 bits>` zero lengths,
/// * `2` — a run of `20 + <9 bits>` zero lengths,
/// * anything else — the length `symbol - 2`.
///
/// The returned vector always has `NC` entries.
///
/// # Errors
/// * `Error::InvalidHuffmanTree` if the single symbol is not below `NC` or a
///   length exceeds `huffman::MAX_BITS`.
/// * `Error::Corrupt` if `n` exceeds `NC` or a zero run passes the end of
///   the table.
fn read_c_lengths(br: &mut BitReader<'_>, temp: &HuffTable) -> Result<Vec<u8>, Error> {
    let n = br.get_bits(CBIT) as usize;
    if n > NC {
        return Err(Error::Corrupt);
    }
    let mut lens = vec![0u8; NC];
    if n == 0 {
        let sym = br.get_bits(CBIT) as usize;
        if sym >= NC {
            return Err(Error::InvalidHuffmanTree);
        }
        lens[sym] = SINGLE_MARKER;
        return Ok(lens);
    }
    let mut pos = 0usize;
    while pos < n {
        let sym = temp.decode(br)?;
        let zero_run = match sym {
            0 => 1,
            1 => br.get_bits(4) as usize + 3,
            2 => br.get_bits(9) as usize + 20,
            _ => {
                let len = (sym - 2) as u8;
                if len as u32 > super::huffman::MAX_BITS {
                    return Err(Error::InvalidHuffmanTree);
                }
                lens[pos] = len;
                pos += 1;
                continue;
            }
        };
        // `lens` starts zeroed, so a run only needs to move the cursor.
        // `pos < n <= NC` here, so the subtraction cannot underflow.
        if zero_run > NC - pos {
            return Err(Error::Corrupt);
        }
        pos += zero_run;
    }
    Ok(lens)
}
/// Sentinel written by `read_c_lengths` into the slot of the only symbol of
/// a single-symbol literal/length table; `build_c_table` looks for it.
const SINGLE_MARKER: u8 = 0xFF;
/// Converts a list of code lengths into the temporary-alphabet symbols that
/// `read_c_lengths` understands.
///
/// Non-zero lengths become `TempSym::Len`. Each maximal run of zeros is split
/// greedily: pieces of 20..=531 become `Run2`, pieces of 3..=18 become
/// `Run1`, and any remainder is emitted as individual `Zero` symbols.
fn c_lengths_to_temp_symbols(lens: &[u8]) -> Vec<TempSym> {
    const SHORT_RUN_MIN: usize = 3;
    const SHORT_RUN_MAX: usize = 3 + 15;
    const LONG_RUN_MIN: usize = 20;
    const LONG_RUN_MAX: usize = 20 + 511;

    let mut syms = Vec::new();
    let mut pos = 0usize;
    while pos < lens.len() {
        let len = lens[pos];
        if len != 0 {
            syms.push(TempSym::Len(len));
            pos += 1;
            continue;
        }
        // Length of the zero run starting here (at least 1).
        let run = lens[pos..].iter().take_while(|&&l| l == 0).count();
        pos += run;
        let mut left = run;
        while left > 0 {
            let used = if left >= LONG_RUN_MIN {
                let used = left.min(LONG_RUN_MAX);
                syms.push(TempSym::Run2((used - LONG_RUN_MIN) as u32));
                used
            } else if left >= SHORT_RUN_MIN {
                let used = left.min(SHORT_RUN_MAX);
                syms.push(TempSym::Run1((used - SHORT_RUN_MIN) as u32));
                used
            } else {
                syms.push(TempSym::Zero);
                1
            };
            left -= used;
        }
    }
    syms
}
/// One symbol of the temporary alphabet used to code the literal/length
/// code lengths.
#[derive(Clone, Copy)]
enum TempSym {
    /// A single zero length.
    Zero,
    /// A short run of zero lengths; the payload is the value of the 4-bit
    /// extra field.
    Run1(u32),
    /// A long run of zero lengths; the payload is the value of the 9-bit
    /// extra field.
    Run2(u32),
    /// An explicit non-zero code length.
    Len(u8),
}

impl TempSym {
    /// Index of this symbol in the temporary alphabet: 0, 1 and 2 for the
    /// three zero forms, `length + 2` for an explicit length.
    fn symbol(&self) -> usize {
        match *self {
            Self::Zero => 0,
            Self::Run1(_) => 1,
            Self::Run2(_) => 2,
            Self::Len(bits) => usize::from(bits) + 2,
        }
    }
}
fn read_position_table(br: &mut BitReader<'_>, np: usize, pbit: u32) -> Result<HuffTable, Error> {
let n = br.get_bits(pbit) as usize;
if n == 0 {
let sym = br.get_bits(pbit) as u16;
if sym as usize >= np {
return Err(Error::InvalidHuffmanTree);
}
return HuffTable::build_single(np, sym, TABLE_BITS);
}
if n > np {
return Err(Error::Corrupt);
}
let mut lens = vec![0u8; np];
let mut i = 0usize;
while i < n {
lens[i] = read_length_value(br)?;
i += 1;
}
HuffTable::build(&lens, TABLE_BITS)
}
/// Writes the position table: a `pbit`-bit count equal to `lens.len()`,
/// then every length in `lens`.
///
/// `np` is accepted for symmetry with the reader but is not used.
fn write_position_table(bw: &mut BitWriter, lens: &[u8], np: usize, pbit: u32) {
    let _ = np;
    bw.put_bits(pbit, lens.len() as u32);
    for &len in lens {
        write_length_value(bw, len);
    }
}
/// Splits a match offset into `(symbol, extra_bits, extra)`.
///
/// Offsets 0 and 1 map to symbols 0 and 1 with no extra bits. Any larger
/// offset maps to the symbol `s` with `2^(s-1) <= offset < 2^s`; it carries
/// `s - 1` extra bits holding `offset - 2^(s-1)`.
fn offset_to_symbol(offset: usize) -> (usize, u32, u32) {
    match offset {
        0 => (0, 0, 0),
        1 => (1, 0, 0),
        _ => {
            // `offset >= 2`, so the symbol is at least 2.
            let mut sym = 2usize;
            while (1usize << sym) <= offset {
                sym += 1;
            }
            let base = 1usize << (sym - 1);
            (sym, (sym - 1) as u32, (offset - base) as u32)
        }
    }
}
/// Decodes one match offset: a position symbol from `table`, followed for
/// symbols above 1 by `symbol - 1` extra bits added to `2^(symbol - 1)`.
///
/// # Errors
/// Propagates any error from `table.decode`.
fn read_offset_code(br: &mut BitReader<'_>, table: &HuffTable) -> Result<usize, Error> {
    let sym = table.decode(br)? as usize;
    match sym {
        // Symbols 0 and 1 are the offsets themselves.
        0 | 1 => Ok(sym),
        _ => {
            let extra_bits = (sym - 1) as u32;
            let extra = br.get_bits(extra_bits) as usize;
            Ok((1usize << extra_bits) + extra)
        }
    }
}
pub fn decode_payload(
payload: &[u8],
expected: Option<usize>,
params: Params,
) -> Result<Vec<u8>, Error> {
let mut out: Vec<u8> = Vec::with_capacity(expected.unwrap_or(0).min(1 << 20));
if expected == Some(0) {
return Ok(out);
}
let ring_size = params.ring_size;
let mut ring = vec![b' '; ring_size];
let mut ring_pos = 0usize;
let mut br = BitReader::new(payload);
loop {
if let Some(n) = expected
&& out.len() >= n
{
break;
}
let block_codes = br.get_bits(16) as usize;
if br.overran() {
return match expected {
Some(_) => Err(Error::UnexpectedEnd),
None => Ok(out),
};
}
if block_codes == 0 {
return Err(Error::Corrupt);
}
let temp = read_temp_table(&mut br)?;
let c_lens = read_c_lengths(&mut br, &temp)?;
let c_table = build_c_table(&c_lens)?;
let p_table = read_position_table(&mut br, params.np, params.pbit)?;
let mut remaining = block_codes;
while remaining > 0 {
if let Some(n) = expected
&& out.len() >= n
{
break;
}
let code = c_table.decode(&mut br)? as usize;
if br.overran() {
return Err(Error::UnexpectedEnd);
}
if code < 256 {
out.push(code as u8);
ring[ring_pos] = code as u8;
ring_pos = (ring_pos + 1) % ring_size;
} else {
let count = code - 256 + MIN_MATCH;
if count > MAX_MATCH {
return Err(Error::Corrupt);
}
let offset = read_offset_code(&mut br, &p_table)?;
if br.overran() {
return Err(Error::UnexpectedEnd);
}
if offset >= ring_size {
return Err(Error::InvalidDistance);
}
let start = (ring_pos + ring_size - offset - 1) % ring_size;
for k in 0..count {
if let Some(n) = expected
&& out.len() >= n
{
break;
}
let b = ring[(start + k) % ring_size];
out.push(b);
ring[ring_pos] = b;
ring_pos = (ring_pos + 1) % ring_size;
}
}
remaining -= 1;
}
}
Ok(out)
}
/// Builds the literal/length decoding table from the lengths returned by
/// `read_c_lengths`.
///
/// If a slot holds `SINGLE_MARKER`, a single-symbol table for that slot is
/// built (the last marked slot wins). Otherwise a regular table is built
/// from the lengths.
///
/// # Errors
/// Returns `Error::Corrupt` if a marker is mixed with ordinary non-zero
/// lengths, and forwards any error from the `HuffTable` constructors.
fn build_c_table(c_lens: &[u8]) -> Result<HuffTable, Error> {
    match c_lens.iter().rposition(|&len| len == SINGLE_MARKER) {
        None => HuffTable::build(c_lens, TABLE_BITS),
        Some(idx) => {
            let mixed = c_lens
                .iter()
                .any(|&len| len != 0 && len != SINGLE_MARKER);
            if mixed {
                return Err(Error::Corrupt);
            }
            HuffTable::build_single(NC, idx as u16, TABLE_BITS)
        }
    }
}
/// One step of the LZ parse produced by `lz_parse` and consumed by
/// `encode_payload`.
enum Token {
// A single literal byte.
Lit(u8),
// A back-reference of `len` bytes; `offset` is the distance to the source
// minus one, so 0 means the byte immediately before the current position.
Match { len: usize, offset: usize },
}
/// Returns the index of the only non-zero entry of `lens`, or `None` when
/// there is no non-zero entry or more than one.
fn single_symbol(lens: &[u8]) -> Option<usize> {
    let first = lens.iter().position(|&len| len != 0)?;
    let has_more = lens[first + 1..].iter().any(|&len| len != 0);
    if has_more {
        None
    } else {
        Some(first)
    }
}
pub fn encode_payload(data: &[u8], params: Params) -> Vec<u8> {
let mut bw = BitWriter::new();
if data.is_empty() {
return bw.finish();
}
let max_dist = (1usize << (params.np - 1)) - 1;
let tokens = lz_parse(data, max_dist);
let mut c_freq = vec![0u32; NC];
let mut p_freq = vec![0u32; params.np];
for t in &tokens {
match t {
Token::Lit(b) => c_freq[*b as usize] += 1,
Token::Match { len, offset } => {
let code = 256 + (len - MIN_MATCH);
c_freq[code] += 1;
let (sym, _, _) = offset_to_symbol(*offset);
p_freq[sym] += 1;
}
}
}
let c_lens = assign_lengths(&c_freq, super::huffman::MAX_BITS);
let p_lens = assign_lengths(&p_freq, super::huffman::MAX_BITS);
let c_codes = lengths_to_codes(&c_lens);
let p_codes = lengths_to_codes(&p_lens);
bw.put_bits(16, tokens.len() as u32);
let c_single = single_symbol(&c_lens);
if let Some(sym) = c_single {
bw.put_bits(TBIT, 0);
bw.put_bits(TBIT, 0);
bw.put_bits(CBIT, 0);
bw.put_bits(CBIT, sym as u32);
} else {
let temp_syms = c_lengths_to_temp_symbols(&c_lens);
let mut t_freq = vec![0u32; NT];
for ts in &temp_syms {
t_freq[ts.symbol()] += 1;
}
let t_lens = assign_lengths(&t_freq, super::huffman::MAX_BITS);
let t_codes = lengths_to_codes(&t_lens);
write_temp_table(&mut bw, &t_lens);
bw.put_bits(CBIT, NC as u32);
for ts in &temp_syms {
let sym = ts.symbol();
bw.put_bits(t_lens[sym] as u32, t_codes[sym]);
match ts {
TempSym::Zero | TempSym::Len(_) => {}
TempSym::Run1(extra) => bw.put_bits(4, *extra),
TempSym::Run2(extra) => bw.put_bits(9, *extra),
}
}
}
let p_single = single_symbol(&p_lens);
if let Some(sym) = p_single {
bw.put_bits(params.pbit, 0);
bw.put_bits(params.pbit, sym as u32);
} else {
write_position_table(&mut bw, &p_lens, params.np, params.pbit);
}
for t in &tokens {
match t {
Token::Lit(b) => {
let s = *b as usize;
if c_single.is_none() {
bw.put_bits(c_lens[s] as u32, c_codes[s]);
}
}
Token::Match { len, offset } => {
let code = 256 + (len - MIN_MATCH);
if c_single.is_none() {
bw.put_bits(c_lens[code] as u32, c_codes[code]);
}
let (sym, extra_bits, extra) = offset_to_symbol(*offset);
if p_single.is_none() {
bw.put_bits(p_lens[sym] as u32, p_codes[sym]);
}
bw.put_bits(extra_bits, extra);
}
}
}
bw.finish()
}
/// Greedy LZ77 parse of `data` into literals and matches.
///
/// Candidates are found through hash chains keyed on the next three bytes.
/// At most 128 candidates are examined per position and none lies more than
/// `window` bytes back. The longest candidate wins and ties go to the
/// nearest one. Matches are `MIN_MATCH..=MAX_MATCH` bytes long and may
/// overlap the position being encoded. The stored offset is the distance
/// minus one.
fn lz_parse(data: &[u8], window: usize) -> Vec<Token> {
    const HASH_BITS: u32 = 16;
    const HASH_SIZE: usize = 1 << HASH_BITS;
    const MAX_CHAIN: usize = 128;
    const NIL: usize = usize::MAX;

    /// Hash of the three bytes starting at `pos`; needs `pos + 2 < data.len()`.
    fn hash_at(data: &[u8], pos: usize) -> usize {
        let a = data[pos] as usize;
        let b = data[pos + 1] as usize;
        let c = data[pos + 2] as usize;
        ((a << 10) ^ (b << 5) ^ c).wrapping_mul(2654435761) >> (32 - HASH_BITS) & (HASH_SIZE - 1)
    }

    let total = data.len();
    // `head[h]` is the latest position hashing to `h`; `prev[p]` links to the
    // previous position with the same hash.
    let mut head = vec![NIL; HASH_SIZE];
    let mut prev = vec![NIL; total];
    let mut tokens = Vec::new();
    let mut pos = 0usize;

    while pos < total {
        let mut best_len = 0usize;
        let mut best_off = 0usize;
        if pos + MIN_MATCH <= total {
            let limit = MAX_MATCH.min(total - pos);
            let oldest = pos.saturating_sub(window);
            let mut cand = head[hash_at(data, pos)];
            let mut steps = 0usize;
            while cand != NIL && cand >= oldest && steps < MAX_CHAIN {
                let len = data[cand..]
                    .iter()
                    .zip(&data[pos..pos + limit])
                    .take_while(|(a, b)| a == b)
                    .count();
                if len > best_len {
                    best_len = len;
                    best_off = pos - cand - 1;
                    if len >= limit {
                        break;
                    }
                }
                cand = prev[cand];
                steps += 1;
            }
        }

        let advance = if best_len >= MIN_MATCH {
            tokens.push(Token::Match {
                len: best_len,
                offset: best_off,
            });
            best_len
        } else {
            tokens.push(Token::Lit(data[pos]));
            1
        };

        // Register every consumed position that still has three bytes ahead.
        for p in pos..pos + advance {
            if p + MIN_MATCH <= total {
                let h = hash_at(data, p);
                prev[p] = head[h];
                head[h] = p;
            }
        }
        pos += advance;
    }
    tokens
}