use std::io::{self, Seek};
use bitstream_io::BitRead;
use super::huffman_decoder::HuffmanDecoder;
/// Errors produced while decoding a compressed stream in this module.
#[derive(Debug, thiserror::Error)]
pub enum Error {
/// An I/O error from the underlying reader; its message is forwarded unchanged
/// (`transparent`), and `#[from]` lets `?` convert an `io::Error` into this variant.
#[error(transparent)]
Io(#[from] io::Error),
/// The bits read from the stream did not describe a valid path through the code tree.
#[error("Encountered an invalid tree path while reading next symbol")]
InvalidCode,
}
impl From<Error> for io::Error {
fn from(val: Error) -> Self {
match val {
Error::Io(error) => error,
err => io::Error::other(err),
}
}
}
/// One decoded token: either a single literal byte or a back-reference into the
/// sliding window.
pub(crate) enum LiteralOrOffset {
/// A byte that is stored in the window and emitted as-is.
Literal(u8),
/// A back-reference: `length` bytes are to be copied, starting `offset` positions
/// before the current output position.
Offset { length: u16, offset: u32 },
}
/// A node of the adaptive code tree kept by `LzahReader`.
#[derive(Debug, Copy, Clone)]
pub struct TreeNode {
// Id of the parent node; `None` for a node without a parent (the root).
parent_ptr: Option<NodeId>,
// Id of the left child; a node with neither child set is treated as a leaf.
left_ptr: Option<NodeId>,
// Id of the right child.
right_ptr: Option<NodeId>,
// Position of this node in the reader's frequency-ordered `tree` array.
index: TreeIdx,
// Symbol carried by the node; read by the decoder once a leaf is reached.
value: u16,
// Occurrence counter; a freshly defaulted node starts at `usize::MAX`.
frequency: usize,
}
/// Identifier of a tree node: the slot it occupies in the node storage array.
#[derive(Debug, Copy, Clone, Default, PartialEq, Eq)]
struct NodeId(u16);

impl NodeId {
    /// Returns the identifier widened to `usize` so it can be used as a slice index.
    #[inline]
    fn raw(&self) -> usize {
        usize::from(self.0)
    }
}
/// Position inside the frequency-ordered tree array; position [`ROOT`] is the root.
#[derive(Debug, Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord)]
struct TreeIdx(u16);

impl TreeIdx {
    /// Returns `true` when this position is [`ROOT`].
    #[inline]
    fn is_root(&self) -> bool {
        self.0 == ROOT.0
    }

    /// Returns the position immediately before this one, or `None` at the root.
    #[inline]
    fn previous(&self) -> Option<Self> {
        // The closure only runs for non-root positions, so the subtraction cannot underflow.
        (!self.is_root()).then(|| Self(self.0 - 1))
    }

    /// Returns the position widened to `usize` so it can be used as a slice index.
    #[inline]
    fn index(&self) -> usize {
        usize::from(self.0)
    }
}

/// Position of the root in the tree array.
const ROOT: TreeIdx = TreeIdx(0);
impl Default for TreeNode {
fn default() -> Self {
Self {
parent_ptr: None,
left_ptr: None,
right_ptr: None,
index: Default::default(),
value: Default::default(),
frequency: usize::MAX,
}
}
}
// Number of bytes in the sliding window used by the reader.
const WINDOW_SIZE: usize = 4096;
// Number of leaves (distinct symbols) in the code tree. The decoder treats symbol
// values below 0x100 as literal bytes and the remaining ones as match lengths.
const LEAF_COUNT: usize = 314;
// Total node count of a full binary tree with `LEAF_COUNT` leaves.
const NODE_COUNT: usize = LEAF_COUNT * 2 - 1;
/// Ring buffer of recently produced bytes plus the state of the back-reference that
/// is currently being copied out of it.
///
/// Note that the const parameter shadows the module-level `WINDOW_SIZE` constant
/// inside this definition.
pub(crate) struct LzSlidingWindow<const WINDOW_SIZE: usize> {
// Backing storage; every position is reduced with `window_mask` before indexing.
pub(crate) window: [u8; WINDOW_SIZE],
// Bit mask applied to positions before they index `window`.
pub(crate) window_mask: usize,
// Bytes of the current back-reference still to be produced; zero means none pending.
pub(crate) match_len: u16,
// Source position of the next byte to copy. Computed as `pos - offset`, so it can be
// negative; it is masked with `window_mask` before use.
pub(crate) match_offset: i32,
}
/// Builds the initial sliding window: pre-filled contents from
/// `default_window_contents`, a mask of `WINDOW_SIZE - 1`, and no pending match.
const fn default_window() -> LzSlidingWindow<WINDOW_SIZE> {
    let window = default_window_contents();
    LzSlidingWindow {
        window,
        window_mask: WINDOW_SIZE - 1,
        match_len: 0,
        match_offset: 0,
    }
}
const fn default_window_contents() -> [u8; WINDOW_SIZE] {
use cfor::cfor;
let mut window = [0u8; WINDOW_SIZE];
let mut cur = 0;
cur += 18;
cfor! {let mut i=0; i < 256; i += 1; {
cfor!{let mut j=0; j < 13; j+=1; {
window[cur + i * 13 + j] = i as u8;
}}
}}
cur += 13 * 256;
cfor! {let mut i=0; i < 256; i += 1; {
window[cur + i] = i as u8;
}}
cur += 256;
cfor! {let mut i=0; i < 256; i+=1; {
window[cur + i] = 255 - i as u8;
}}
cur += 256;
cur += 128;
cfor! {let mut i=0; i < 110; i += 1; {
window[cur + i] = b' ';
}}
cur += 110;
if cur != WINDOW_SIZE {
panic!("Something went wrong during window initialization");
}
window
}
impl Default for LzSlidingWindow<WINDOW_SIZE> {
fn default() -> Self {
Self {
window: default_window_contents(),
window_mask: WINDOW_SIZE - 1,
match_len: 0,
match_offset: 0,
}
}
}
impl<const WINDOW_SIZE: usize> LzSlidingWindow<WINDOW_SIZE> {
    /// Returns `true` when no back-reference bytes are waiting to be copied.
    #[inline]
    pub(crate) fn is_empty(&self) -> bool {
        self.match_len == 0
    }

    /// Feeds a freshly decoded token into the window and returns the byte that
    /// belongs at output position `pos`.
    ///
    /// A literal is stored at `pos` and returned. A back-reference records its
    /// length and source position and immediately yields its first byte via
    /// [`next`](Self::next); the remaining bytes are fetched by further `next` calls.
    #[inline]
    pub(crate) fn update(&mut self, pos: usize, val: LiteralOrOffset) -> u8 {
        match val {
            LiteralOrOffset::Literal(lit) => {
                self.window[pos & self.window_mask] = lit;
                lit
            }
            LiteralOrOffset::Offset { length, offset } => {
                self.match_len = length;
                // Only the bits kept by `window_mask` are ever used, so wrapping is
                // harmless. A plain subtraction would overflow `i32` (and panic in
                // overflow-checked builds) once `pos` reaches 2^31.
                self.match_offset = (pos as i32).wrapping_sub(offset as i32);
                self.next(pos)
            }
        }
    }

    /// Copies the next byte of the pending back-reference to output position `pos`
    /// and returns it.
    ///
    /// Callers must only invoke this while a match is pending (`!is_empty()`);
    /// with `match_len == 0` the length counter underflows.
    #[inline]
    pub(crate) fn next(&mut self, pos: usize) -> u8 {
        self.match_len -= 1;
        let byte = self.window[self.match_offset as usize & self.window_mask];
        // Wrapping for the same reason as in `update`: the value is masked before use.
        self.match_offset = self.match_offset.wrapping_add(1);
        self.window[pos & self.window_mask] = byte;
        byte
    }
}
/// Streaming decoder that exposes the decompressed bytes of `inner` through `io::Read`.
///
/// Symbols are decoded with an adaptive binary code tree (`nodes` / `tree`), match
/// offsets with a fixed `HuffmanDecoder`, and back-references are resolved against a
/// sliding window.
pub struct LzahReader<R: io::Read + io::Seek> {
// Big-endian bit reader over the compressed input.
inner: bitstream_io::BitReader<R, bitstream_io::BigEndian>,
// Total number of decompressed bytes; `read` stops once `pos` reaches it.
uncompressed_size: u64,
// Node storage, indexed by `NodeId`.
nodes: [TreeNode; NODE_COUNT],
// Ordering of the nodes, indexed by `TreeIdx`; entry `ROOT` is the tree root.
tree: [NodeId; NODE_COUNT],
// Number of decompressed bytes produced so far.
pos: usize,
// Decoder used for the upper bits of a match offset.
decoder: HuffmanDecoder,
// Sliding window holding recent output and the pending back-reference.
win: LzSlidingWindow<WINDOW_SIZE>,
}
impl<R: io::Read + io::Seek> LzahReader<R> {
pub fn new(inner: R, uncompressed_size: u64) -> Self {
let mut decoder = HuffmanDecoder::initialize(
&[
3, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7,
7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8,
8, 8, 8, 8, 8, 8, 8, 8,
],
8,
true,
)
.unwrap();
decoder.make_table(false);
let mut me = Self {
inner: bitstream_io::BitReader::<_, bitstream_io::BigEndian>::new(inner),
nodes: [Default::default(); NODE_COUNT],
tree: [NodeId(0); NODE_COUNT],
decoder,
pos: 0,
win: default_window(),
uncompressed_size,
};
me.reset();
me
}
fn reset(&mut self) {
self.nodes = [Default::default(); NODE_COUNT];
self.tree = [Default::default(); NODE_COUNT];
self.tree
.iter_mut()
.enumerate()
.for_each(|(idx, node)| *node = NodeId(idx as u16));
for i in 0..LEAF_COUNT {
let node = NodeId(NODE_COUNT as u16 - 1 - i as u16);
self.node_mut(node).index = TreeIdx(node.0);
self.node_mut(node).frequency = 1;
self.node_mut(node).value = i as u16;
}
for i in (0..(LEAF_COUNT - 1)).rev() {
let parent = NodeId(i as u16);
let left = NodeId(i as u16 * 2 + 1);
let right = NodeId(i as u16 * 2 + 2);
self.node_mut(parent).index = TreeIdx(i as u16);
self.node_mut(parent).left_ptr = Some(left);
self.node_mut(parent).right_ptr = Some(right);
self.node_mut(parent).frequency =
self.node(left).frequency + self.node(right).frequency;
self.node_mut(left).parent_ptr = Some(parent);
self.node_mut(right).parent_ptr = Some(parent);
}
}
pub fn into_inner(self) -> R {
self.inner.into_reader()
}
#[inline]
fn next(&mut self) -> Result<LiteralOrOffset, Error> {
let mut node = self.tree_lookup(ROOT);
while self.node(node).left_ptr.is_some() || self.node(node).right_ptr.is_some() {
node = match self.inner.read_bit() {
Ok(true) => self.node(node).left_ptr.unwrap(),
Ok(false) => self.node(node).right_ptr.unwrap(),
Err(e) => return Err(e)?,
}
}
if self.tree_node(ROOT).frequency == 0x8000 {
self.reconstruct_tree();
}
self.update_node(node);
let literal = self.node(node).value;
if literal < 0x100 {
return Ok(LiteralOrOffset::Literal(literal as u8));
}
let length = literal - 0x100 + 3;
let highbits = self.decoder.next_symbol(&mut self.inner)?;
let lowbits = self.inner.read::<6, u32>()?;
let offset = ((highbits as u32) << 6) + lowbits + 1;
Ok(LiteralOrOffset::Offset { length, offset })
}
/// Increments the frequency of `node` and of every node on the way up to the root,
/// re-ordering the tree after each increment below the root.
#[inline]
fn update_node(&mut self, node: NodeId) {
    let mut current = node;
    self.node_mut(current).frequency += 1;
    while self.node(current).parent_ptr.is_some() {
        self.rearrange_node(current);
        // Read the parent only after rearranging: a swap may have re-parented the node.
        current = self.node(current).parent_ptr.unwrap();
        self.node_mut(current).frequency += 1;
    }
}
fn rearrange_node(&mut self, node: NodeId) {
let TreeNode {
index: node_idx,
frequency,
..
} = self.node(node);
let mut ancestor = node_idx;
while let Some(prev) = ancestor.previous() {
if self.tree_node(prev).frequency < frequency {
ancestor = prev;
} else {
break;
}
}
if ancestor < node_idx {
self.swap_nodes(ancestor, node_idx);
}
}
/// Exchanges the nodes at two tree positions.
///
/// Each node takes over the other's parent, child slot, and tree position; their own
/// children stay attached to them.
fn swap_nodes(&mut self, node_1_idx: TreeIdx, node_2_idx: TreeIdx) {
    let first = self.tree_lookup(node_1_idx);
    let second = self.tree_lookup(node_2_idx);
    let first_parent = self.node(first).parent_ptr;
    let second_parent = self.node(second).parent_ptr;

    // Decide which side each node hangs on before any pointer is rewritten, so the
    // answer stays correct even when both nodes share a parent.
    let first_on_right = match first_parent {
        Some(parent) => self.node(parent).right_ptr == Some(first),
        None => false,
    };
    let second_on_right = match second_parent {
        Some(parent) => self.node(parent).right_ptr == Some(second),
        None => false,
    };

    // Put `second` into the child slot `first` used to occupy...
    if let Some(parent) = first_parent {
        let slot = self.node_mut(parent);
        if first_on_right {
            slot.right_ptr = Some(second);
        } else {
            slot.left_ptr = Some(second);
        }
    }
    // ...and `first` into the slot `second` used to occupy.
    if let Some(parent) = second_parent {
        let slot = self.node_mut(parent);
        if second_on_right {
            slot.right_ptr = Some(first);
        } else {
            slot.left_ptr = Some(first);
        }
    }

    // Exchange parents and tree positions.
    self.node_mut(first).parent_ptr = second_parent;
    self.node_mut(second).parent_ptr = first_parent;
    self.node_mut(first).index = node_2_idx;
    self.node_mut(second).index = node_1_idx;
    self.tree[node_1_idx.index()] = second;
    self.tree[node_2_idx.index()] = first;
}
/// Halves every leaf frequency (rounding up) and rebuilds the whole tree from the
/// leaves.
///
/// `next` calls this when the root frequency reaches `0x8000`.
///
/// The `tree` array is refilled from its last position towards position 0. Leaves
/// are taken in their previous relative order, and each new branch is formed from
/// two adjacent, already placed entries.
///
/// # Panics
///
/// Panics if the number of leaves found is not `LEAF_COUNT`.
fn reconstruct_tree(&mut self) {
// Collect the leaves in current tree order, halving their counters on the way.
let mut leaf_nodes = Vec::with_capacity(LEAF_COUNT);
for index in 0..NODE_COUNT {
let node = self.tree[index];
if self.is_leaf(node) {
self.node_mut(node).frequency = self.node(node).frequency.div_ceil(2);
leaf_nodes.push(node);
}
}
assert_eq!(leaf_nodes.len(), LEAF_COUNT);
// `leaf_index`: next leaf to place, walking `leaf_nodes` from its end.
let mut leaf_index = LEAF_COUNT as i32 - 1;
// `branch_index`: node id reused for the next branch, counting down from
// `LEAF_COUNT - 2` (assumes these ids are the branch nodes, as `reset` sets up).
let mut branch_index = LEAF_COUNT as i32 - 2;
// `node_index`: next free tree position, filled from the back.
let mut node_index: i32 = NODE_COUNT as i32 - 1;
// `pair_index`: tree position of the left entry of the next pair to join.
let mut pair_index: i32 = NODE_COUNT as i32 - 2;
while node_index >= 0 {
// Make sure both entries of the pair exist before joining them.
while node_index >= pair_index {
let leaf = leaf_nodes[leaf_index as usize];
self.tree[node_index as usize] = leaf;
self.node_mut(leaf).index = TreeIdx(node_index as u16);
node_index -= 1;
leaf_index -= 1;
}
// Join the pair under a branch whose frequency is the sum of both children.
let branch = NodeId(branch_index as u16);
let left_child = self.tree[pair_index as usize];
let right_child = self.tree[pair_index as usize + 1];
self.node_mut(branch).left_ptr = Some(left_child);
self.node_mut(branch).right_ptr = Some(right_child);
self.node_mut(left_child).parent_ptr = Some(branch);
self.node_mut(right_child).parent_ptr = Some(branch);
self.node_mut(branch).frequency =
self.node(left_child).frequency + self.node(right_child).frequency;
branch_index -= 1;
// Leaves whose frequency does not exceed the new branch's are placed before it.
while leaf_index >= 0
&& self.node(leaf_nodes[leaf_index as usize]).frequency
<= self.node(branch).frequency
{
let leaf = leaf_nodes[leaf_index as usize];
self.tree[node_index as usize] = leaf;
self.node_mut(leaf).index = TreeIdx(node_index as u16);
node_index -= 1;
leaf_index -= 1;
}
// Place the branch itself, then move on to the next pair.
self.tree[node_index as usize] = branch;
self.node_mut(branch).index = TreeIdx(node_index as u16);
node_index -= 1;
pair_index -= 2;
}
// The node that ended up at the root position has no parent.
self.node_mut(self.tree_lookup(ROOT)).parent_ptr = None;
}
/// Returns `true` when `node` has neither a left nor a right child.
#[inline]
fn is_leaf(&self, node: NodeId) -> bool {
    let has_any_child = self.has_left_child(node) || self.has_right_child(node);
    !has_any_child
}
#[inline]
fn has_left_child(&self, node: NodeId) -> bool {
self.node(node).left_ptr.is_some()
}
#[inline]
fn has_right_child(&self, node: NodeId) -> bool {
self.node(node).right_ptr.is_some()
}
/// Produces the next decompressed byte.
///
/// While a back-reference is pending its bytes are served from the window; otherwise
/// a new token is decoded and fed into the window first.
///
/// # Errors
///
/// Propagates any error from decoding the next token.
#[inline]
fn produce_next_byte(&mut self) -> Result<u8, Error> {
    if !self.win.is_empty() {
        return Ok(self.win.next(self.pos));
    }
    let token = self.next()?;
    Ok(self.win.update(self.pos, token))
}
/// Returns a copy of the node stored under `node`.
#[inline]
fn node(&self, node: NodeId) -> TreeNode {
    let slot = node.raw();
    self.nodes[slot]
}
/// Returns a mutable reference to the node stored under `node`.
#[inline]
fn node_mut(&mut self, node: NodeId) -> &mut TreeNode {
    let slot = node.raw();
    &mut self.nodes[slot]
}
/// Returns a copy of the node currently sitting at tree position `idx`.
#[inline]
fn tree_node(&self, idx: TreeIdx) -> TreeNode {
    let id = self.tree[idx.index()];
    self.node(id)
}
/// Returns the id of the node currently sitting at tree position `idx`.
#[inline]
fn tree_lookup(&self, idx: TreeIdx) -> NodeId {
    let position = idx.index();
    self.tree[position]
}
}
/// Serves decompressed bytes one at a time until `buf` is full or the declared
/// uncompressed size has been reached.
impl<R: io::Read + io::Seek> io::Read for LzahReader<R> {
    /// Fills `buf` with decompressed bytes and returns how many were written.
    ///
    /// Returns fewer bytes than `buf.len()` (possibly zero) once the stream position
    /// reaches the stream length.
    ///
    /// # Errors
    ///
    /// Returns the decoding error converted into an `io::Error`. Bytes already stored
    /// in `buf` during the same call are not reported in that case.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let mut written = 0;
        for slot in buf.iter_mut() {
            let position = self.stream_position()?;
            let end = self.stream_len()?;
            if position >= end {
                return Ok(written);
            }
            *slot = self.produce_next_byte()?;
            self.pos += 1;
            written += 1;
        }
        Ok(written)
    }
}
/// Position reporting for the decompressed stream.
///
/// The decoder can only move forward by decoding, so repositioning is not
/// implemented; only the position and length queries carry real information.
impl<R: io::Read + io::Seek> io::Seek for LzahReader<R> {
    /// Accepts seeks that resolve to the current position and rejects all others.
    ///
    /// # Errors
    ///
    /// Returns an error of kind [`io::ErrorKind::Unsupported`] when the target differs
    /// from the current position or cannot be represented as a `u64`.
    fn seek(&mut self, target: io::SeekFrom) -> io::Result<u64> {
        let current = self.pos as u64;
        let requested = match target {
            io::SeekFrom::Start(offset) => Some(offset),
            io::SeekFrom::Current(delta) => current.checked_add_signed(delta),
            io::SeekFrom::End(delta) => self.uncompressed_size.checked_add_signed(delta),
        };
        if requested == Some(current) {
            return Ok(current);
        }
        // TODO: real seeking would need to restart decoding from the beginning of
        // the compressed data; report the limitation instead of panicking.
        Err(io::Error::new(
            io::ErrorKind::Unsupported,
            "seeking within a compressed stream is not supported",
        ))
    }

    /// Returns the number of decompressed bytes produced so far.
    #[inline]
    fn stream_position(&mut self) -> io::Result<u64> {
        Ok(self.pos as u64)
    }

    /// Returns the total uncompressed size supplied at construction.
    #[inline]
    fn stream_len(&mut self) -> io::Result<u64> {
        Ok(self.uncompressed_size)
    }
}