use std::io;
use bitstream_io::{BigEndian, BitRead, BitReader};
use crc::Digest;
/// CRC-32 calculator built from the `crc` crate's `CRC_32_ISO_HDLC` parameters;
/// `ArsenicReader` uses it to checksum every byte it produces.
const CRC_32_ISO_HDLC: crc::Crc<u32> = crc::Crc::<u32>::new(&crc::CRC_32_ISO_HDLC);
/// Step sizes for the optional per-block "randomization" pass.
///
/// When a block is flagged as randomized, the low bit of a decoded byte is
/// flipped each time the block's byte counter reaches a running total. That
/// total starts at entry 0 and is advanced by the next entry (index wrapping
/// at 256) after every flip. Entries are `u16` because some exceed `0xff`.
const RANDOMIZATION_TABLE: [u16; 256] = [
0xee, 0x56, 0xf8, 0xc3, 0x9d, 0x9f, 0xae, 0x2c, 0xad, 0xcd, 0x24, 0x9d, 0xa6, 0x101, 0x18,
0xb9, 0xa1, 0x82, 0x75, 0xe9, 0x9f, 0x55, 0x66, 0x6a, 0x86, 0x71, 0xdc, 0x84, 0x56, 0x96, 0x56,
0xa1, 0x84, 0x78, 0xb7, 0x32, 0x6a, 0x3, 0xe3, 0x2, 0x11, 0x101, 0x8, 0x44, 0x83, 0x100, 0x43,
0xe3, 0x1c, 0xf0, 0x86, 0x6a, 0x6b, 0xf, 0x3, 0x2d, 0x86, 0x17, 0x7b, 0x10, 0xf6, 0x80, 0x78,
0x7a, 0xa1, 0xe1, 0xef, 0x8c, 0xf6, 0x87, 0x4b, 0xa7, 0xe2, 0x77, 0xfa, 0xb8, 0x81, 0xee, 0x77,
0xc0, 0x9d, 0x29, 0x20, 0x27, 0x71, 0x12, 0xe0, 0x6b, 0xd1, 0x7c, 0xa, 0x89, 0x7d, 0x87, 0xc4,
0x101, 0xc1, 0x31, 0xaf, 0x38, 0x3, 0x68, 0x1b, 0x76, 0x79, 0x3f, 0xdb, 0xc7, 0x1b, 0x36, 0x7b,
0xe2, 0x63, 0x81, 0xee, 0xc, 0x63, 0x8b, 0x78, 0x38, 0x97, 0x9b, 0xd7, 0x8f, 0xdd, 0xf2, 0xa3,
0x77, 0x8c, 0xc3, 0x39, 0x20, 0xb3, 0x12, 0x11, 0xe, 0x17, 0x42, 0x80, 0x2c, 0xc4, 0x92, 0x59,
0xc8, 0xdb, 0x40, 0x76, 0x64, 0xb4, 0x55, 0x1a, 0x9e, 0xfe, 0x5f, 0x6, 0x3c, 0x41, 0xef, 0xd4,
0xaa, 0x98, 0x29, 0xcd, 0x1f, 0x2, 0xa8, 0x87, 0xd2, 0xa0, 0x93, 0x98, 0xef, 0xc, 0x43, 0xed,
0x9d, 0xc2, 0xeb, 0x81, 0xe9, 0x64, 0x23, 0x68, 0x1e, 0x25, 0x57, 0xde, 0x9a, 0xcf, 0x7f, 0xe5,
0xba, 0x41, 0xea, 0xea, 0x36, 0x1a, 0x28, 0x79, 0x20, 0x5e, 0x18, 0x4e, 0x7c, 0x8e, 0x58, 0x7a,
0xef, 0x91, 0x2, 0x93, 0xbb, 0x56, 0xa1, 0x49, 0x1b, 0x79, 0x92, 0xf3, 0x58, 0x4f, 0x52, 0x9c,
0x2, 0x77, 0xaf, 0x2a, 0x8f, 0x49, 0xd0, 0x99, 0x4d, 0x98, 0x101, 0x60, 0x93, 0x100, 0x75,
0x31, 0xce, 0x49, 0x20, 0x56, 0x57, 0xe2, 0xf5, 0x26, 0x2b, 0x8a, 0xbf, 0xde, 0xd0, 0x83, 0x34,
0xf4, 0x17,
];
/// Errors reported while decoding a compressed stream.
#[derive(Debug, thiserror::Error)]
pub enum Error {
/// An I/O failure from the underlying reader, passed through unchanged.
#[error(transparent)]
Io(#[from] io::Error),
/// A `binrw` error, passed through unchanged.
// NOTE(review): nothing in the visible code produces this variant — presumably
// used by callers elsewhere; confirm before removing.
#[error(transparent)]
BinRw(#[from] binrw::Error),
/// The stream does not start with the expected signature, or a decoded block
/// violates its declared bounds.
#[error("invalid file")]
InvalidFile,
}
impl From<Error> for io::Error {
fn from(val: Error) -> Self {
match val {
Error::Io(e) => e,
error => io::Error::other(error),
}
}
}
/// One entry of an adaptive frequency model: a symbol value and its current weight.
#[derive(Default, Clone, Copy)]
struct Symbol {
// Value returned by the decoder when this entry is selected.
symbol: i32,
// Current adaptive frequency (weight) of this entry.
frequency: i32,
}
/// Adaptive frequency model consulted by the arithmetic decoder.
struct Model {
// Sum of the frequencies of the first `symbol_count` entries.
frequency: i32,
// Amount added to a symbol's frequency each time it is decoded.
increment: i32,
// When `frequency` exceeds this, all active frequencies are halved (rounding up).
limit: i32,
// Number of entries of `symbols` that are in use.
symbol_count: usize,
// Backing storage; only the first `symbol_count` entries are meaningful.
symbols: [Symbol; 128],
}
impl Model {
    /// Records one occurrence of the symbol stored at `symindex`.
    ///
    /// The entry and the running total both grow by the model's increment. If
    /// the total then exceeds the limit, every active frequency is halved
    /// (rounding up) and the total is recomputed from the halved values.
    fn increment(&mut self, symindex: i32) {
        let step = self.increment;
        self.symbols[symindex as usize].frequency += step;
        self.frequency += step;
        if self.frequency <= self.limit {
            return;
        }

        let mut rescaled_total = 0;
        for entry in self.symbols[..self.symbol_count].iter_mut() {
            entry.frequency = (entry.frequency + 1) >> 1;
            rescaled_total += entry.frequency;
        }
        self.frequency = rescaled_total;
    }
}
/// Move-to-front table: `decode` returns the value at a position and moves it to slot 0.
struct MtfState {
// Current ordering of the 256 values; slot 0 holds the most recently decoded one.
table: [i32; 256],
}
impl MtfState {
    /// Restores the identity ordering, so slot `i` holds value `i`.
    fn reset(&mut self) {
        for (index, slot) in self.table.iter_mut().enumerate() {
            *slot = index as i32;
        }
    }

    /// Returns the value stored at position `symbol` and moves it to the
    /// front, shifting every entry before it one slot towards the back.
    ///
    /// Panics if `symbol` is not a valid index into the 256-entry table.
    fn decode(&mut self, symbol: i32) -> i32 {
        let position = symbol as usize;
        let value = self.table[position];
        // Rotating the prefix right by one puts `value` in slot 0 and shifts the rest.
        self.table[..=position].rotate_right(1);
        value
    }
}
impl Model {
fn init(&mut self, first_symbol: i32, last_symbol: i32, increment: i32, limit: i32) {
self.increment = increment;
self.limit = limit;
self.symbol_count = (last_symbol - first_symbol + 1) as usize;
for i in 0..self.symbol_count {
self.symbols[i] = Symbol {
symbol: first_symbol + i as i32,
frequency: 0,
}
}
self.frequency = self.increment * self.symbol_count as i32;
for i in 0..self.symbol_count {
self.symbols[i].frequency = self.increment;
}
}
fn reset(&mut self) {
self.frequency = self.increment * self.symbol_count as i32;
for i in 0..self.symbol_count {
self.symbols[i].frequency = self.increment;
}
}
}
impl Default for Model {
fn default() -> Self {
Self {
frequency: Default::default(),
increment: Default::default(),
limit: Default::default(),
symbol_count: Default::default(),
symbols: [Default::default(); 128],
}
}
}
/// Number of bits read from the stream to prime the decoder's code register.
const NUM_BITS: usize = 26;
/// Initial value of the decoder's range (`1 << 25`).
const ONE: i32 = 1 << (NUM_BITS - 1);
/// Renormalization threshold: while the range is at or below this (`1 << 24`),
/// the decoder doubles the range and shifts one more input bit into the code.
const HALF: i32 = 1 << (NUM_BITS - 2);
/// State of the arithmetic decoder.
#[derive(Default)]
struct Decoder {
// Width of the current coding interval; kept above `HALF` by renormalization.
range: i32,
// Bits read from the stream, expressed relative to the low end of the interval.
code: i32,
}
impl Decoder {
    /// Primes the decoder: the range becomes `ONE` and the code register is
    /// loaded with the first `NUM_BITS` bits of `reader`.
    ///
    /// # Errors
    /// Fails if the bits cannot be read.
    pub fn try_init<R: io::Read + io::Seek>(
        &mut self,
        reader: &mut BitReader<R, BigEndian>,
    ) -> Result<(), Error> {
        self.range = ONE;
        self.code = reader.read_var(NUM_BITS as u32)?;
        Ok(())
    }

    /// Decodes `bits` symbols with `model` and packs them into an integer,
    /// least-significant bit first; any non-zero symbol counts as a set bit.
    fn next_bit_string<R: io::Read + io::Seek>(
        &mut self,
        reader: &mut BitReader<R, BigEndian>,
        model: &mut Model,
        bits: usize,
    ) -> Result<i32, Error> {
        let mut value: i32 = 0;
        for shift in 0..bits {
            let decoded = self.next_symbol(reader, model)?;
            if decoded != 0 {
                value |= 1 << shift;
            }
        }
        Ok(value)
    }

    /// Decodes one symbol using `model`, then updates the model's statistics.
    ///
    /// The active symbols are scanned in order, accumulating frequencies until
    /// the interval containing the current code is found. The last symbol acts
    /// as the catch-all when no earlier one matches.
    fn next_symbol<R: io::Read + io::Seek>(
        &mut self,
        reader: &mut BitReader<R, BigEndian>,
        model: &mut Model,
    ) -> Result<i32, Error> {
        let target: i32 = self.code / (self.range / model.frequency);
        let last_index = model.symbol_count - 1;

        let mut cumulative = 0;
        let mut index = 0;
        while index < last_index {
            let width = model.symbols[index].frequency;
            if cumulative + width > target {
                break;
            }
            cumulative += width;
            index += 1;
        }

        self.read_next_arithmetic_code(
            reader,
            cumulative,
            model.symbols[index].frequency,
            model.frequency,
        )?;
        model.increment(index as i32);
        Ok(model.symbols[index].symbol)
    }

    /// Narrows the coding interval to the chosen symbol and renormalizes.
    ///
    /// `symlow` is the cumulative frequency below the symbol, `symsize` its own
    /// frequency and `symtot` the model total. The topmost symbol takes the
    /// whole remaining range, so it absorbs the remainder of the division.
    fn read_next_arithmetic_code<R: io::Read + io::Seek>(
        &mut self,
        reader: &mut BitReader<R, BigEndian>,
        symlow: i32,
        symsize: i32,
        symtot: i32,
    ) -> Result<(), Error> {
        let scale = self.range / symtot;
        let low_offset = scale * symlow;
        self.code -= low_offset;

        let is_top_symbol = symlow + symsize == symtot;
        if is_top_symbol {
            self.range -= low_offset;
        } else {
            self.range = symsize * scale;
        }

        // Pull in input bits until the range is back above the threshold.
        while self.range <= HALF {
            self.range <<= 1;
            let bit = i32::from(reader.read_bit()?);
            self.code = (self.code << 1) | bit;
        }
        Ok(())
    }
}
/// Streaming decompressor that decodes from a bit reader over `R` and yields
/// the uncompressed bytes through `io::Read`.
pub struct ArsenicReader<'a, R: io::Read + io::Seek> {
// Bit-level, big-endian view of the compressed input.
inner: BitReader<R, BigEndian>,
// Two-symbol model used for the header, flags and raw bit strings.
initial_model: Model,
// Eleven-symbol model (0..=10) that selects how each block entry is decoded.
selector_model: Model,
// Models for the symbol ranges 2-3, 4-7, 8-15, 16-31, 32-63, 64-127 and 128-255.
mtf: [Model; 7],
// Arithmetic decoder state.
decoder: Decoder,
// Move-to-front table, reset at the start of every block.
mtf_state: MtfState,
// Base-2 logarithm of the block size (a 4-bit header value plus 9).
block_bits: i32,
// Capacity of `block`, i.e. `1 << block_bits`.
block_size: i32,
// Bytes of the current block after MTF decoding, before the inverse transform.
block: Vec<u8>,
// Set once the end marker has been seen; no further blocks are read.
end_of_block: bool,
// Number of valid bytes in `block`.
num_bytes: i32,
// Number of bytes of the current block already consumed.
byte_count: i32,
// Current position in the `transform` chain.
transform_index: i32,
// Index table built by `calcuate_inverse_bwt` for the current block.
transform: Vec<u32>,
// Non-zero when the current block has the randomization flag set.
randomized: i32,
// Value of `byte_count` at which the next bit flip happens.
randcount: i32,
// Current index into `RANDOMIZATION_TABLE`.
randindex: i32,
// Remaining copies of `last` still to be emitted by the run-length stage.
repeat: i32,
// Length of the current run of identical bytes (a run of 4 is followed by a count byte).
count: i32,
// Most recent byte seen by the run-length stage.
last: i32,
// Checksum read from the stream after the final block.
comp_crc: u32,
// Number of uncompressed bytes produced so far.
pos: usize,
// Total number of uncompressed bytes to produce, supplied by the caller.
uncompressed_size: u64,
// Running CRC over every byte produced.
crc3: Digest<'a, u32>,
}
impl<'a, R: io::Read + io::Seek> ArsenicReader<'a, R> {
pub fn try_from(inner: R, uncompressed_size: u64) -> Result<Self, Error> {
let mut me = Self {
inner: BitReader::new(inner),
initial_model: Default::default(),
selector_model: Default::default(),
mtf: Default::default(),
mtf_state: MtfState { table: [0i32; 256] },
decoder: Default::default(),
block_bits: 0,
block_size: 0,
block: Vec::new(),
end_of_block: false,
num_bytes: 0,
byte_count: 0,
transform_index: 0,
transform: Vec::new(),
randomized: 0,
randcount: 0,
randindex: 0,
repeat: 0,
count: 0,
last: 0,
comp_crc: 0,
pos: 0,
uncompressed_size,
crc3: CRC_32_ISO_HDLC.digest(),
};
me.reset()?;
Ok(me)
}
fn reset(&mut self) -> Result<(), Error> {
self.decoder.try_init(&mut self.inner)?;
self.initial_model.init(0, 1, 1, 256);
self.selector_model.init(0, 10, 8, 1024);
self.mtf[0].init(2, 3, 8, 1024);
self.mtf[1].init(4, 7, 4, 1024);
self.mtf[2].init(8, 15, 4, 1024);
self.mtf[3].init(16, 31, 4, 1024);
self.mtf[4].init(32, 63, 2, 1024);
self.mtf[5].init(64, 127, 2, 1024);
self.mtf[6].init(128, 255, 1, 1024);
if self
.decoder
.next_bit_string(&mut self.inner, &mut self.initial_model, 8)? as u8
!= b'A'
{
return Err(Error::InvalidFile);
}
if self
.decoder
.next_bit_string(&mut self.inner, &mut self.initial_model, 8)? as u8
!= b's'
{
return Err(Error::InvalidFile);
}
self.block_bits =
self.decoder
.next_bit_string(&mut self.inner, &mut self.initial_model, 4)?
+ 9;
self.block_size = 1 << self.block_bits;
self.num_bytes = 0;
self.byte_count = 0;
self.repeat = 0;
self.block = vec![0u8; self.block_size as usize];
self.comp_crc = 0;
self.end_of_block = self
.decoder
.next_symbol(&mut self.inner, &mut self.initial_model)?
!= 0;
Ok(())
}
fn read_block(&mut self) -> Result<(), Error> {
self.mtf_state.reset();
self.randomized = self
.decoder
.next_symbol(&mut self.inner, &mut self.initial_model)?;
self.transform_index = self.decoder.next_bit_string(
&mut self.inner,
&mut self.initial_model,
self.block_bits as usize,
)?;
self.num_bytes = 0;
loop {
let mut sel = self
.decoder
.next_symbol(&mut self.inner, &mut self.selector_model)?;
if sel == 0 || sel == 1 {
let mut zerostate = 1;
let mut zerocount = 0;
while sel < 2 {
if sel == 0 {
zerocount += zerostate;
} else if sel == 1 {
zerocount += 2 * zerostate;
}
zerostate *= 2;
sel = self
.decoder
.next_symbol(&mut self.inner, &mut self.selector_model)?;
}
if self.num_bytes + zerocount > self.block_size {
return Err(Error::InvalidFile);
}
let value = self.mtf_state.decode(0);
for j in 0..zerocount {
self.block[self.num_bytes as usize + j as usize] = value as u8;
}
self.num_bytes += zerocount;
}
let symbol;
if sel == 10 {
break;
} else if sel == 2 {
symbol = 1;
} else {
symbol = self
.decoder
.next_symbol(&mut self.inner, &mut self.mtf[sel as usize - 3])?;
}
if self.num_bytes > self.block_size {
return Err(Error::InvalidFile);
}
self.block[self.num_bytes as usize] = self.mtf_state.decode(symbol) as u8;
self.num_bytes += 1;
}
if self.transform_index > self.num_bytes {
return Err(Error::InvalidFile);
}
self.selector_model.reset();
for mtf in self.mtf.iter_mut() {
mtf.reset();
}
if self
.decoder
.next_symbol(&mut self.inner, &mut self.initial_model)?
!= 0
{
self.comp_crc =
self.decoder
.next_bit_string(&mut self.inner, &mut self.initial_model, 32)?
as u32;
self.end_of_block = true;
}
self.transform = vec![0u32; self.num_bytes as usize];
calcuate_inverse_bwt(
&mut self.transform,
&mut self.block,
self.num_bytes as usize,
);
Ok(())
}
#[inline]
fn produce_next_byte(&mut self) -> Result<Option<u8>, Error> {
if self.pos >= self.uncompressed_size as usize {
return Ok(None);
}
self.pos += 1;
if self.repeat > 0 {
self.repeat -= 1;
Ok(Some(self.track_crc(self.last as u8)))
} else {
loop {
if self.byte_count >= self.num_bytes {
if self.end_of_block {
return Err(io::Error::new(io::ErrorKind::UnexpectedEof, ""))?;
}
self.read_block()?;
self.byte_count = 0;
self.count = 0;
self.last = 0;
self.randindex = 0;
self.randcount = RANDOMIZATION_TABLE[0] as i32;
}
self.transform_index = self.transform[self.transform_index as usize] as i32;
let mut byte = self.block[self.transform_index as usize];
if self.randomized != 0 && self.randcount == self.byte_count {
byte ^= 1;
self.randindex = (self.randindex + 1) & 255;
self.randcount += RANDOMIZATION_TABLE[self.randindex as usize] as i32;
}
self.byte_count += 1;
if self.count == 4 {
self.count = 0;
if byte == 0 {
continue;
}
self.repeat = byte as i32 - 1;
return Ok(Some(self.track_crc(self.last as u8)));
} else {
if byte == self.last as u8 {
self.count += 1;
} else {
self.count = 1;
self.last = byte as i32;
}
return Ok(Some(self.track_crc(byte)));
}
}
}
}
fn track_crc(&mut self, data: u8) -> u8 {
self.crc3.update(&[data]);
data
}
pub fn is_checksum_valid(&mut self) -> bool {
self.comp_crc == self.crc3.clone().finalize()
}
}
/// Fills `buf` with uncompressed bytes, one decoded byte at a time.
impl<'a, R: io::Read + io::Seek> io::Read for ArsenicReader<'a, R> {
    /// Returns the number of bytes written. This is less than `buf.len()`
    /// only when the end of the uncompressed data is reached.
    #[inline]
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let mut written = 0;
        while written < buf.len() {
            match self.produce_next_byte()? {
                Some(byte) => {
                    buf[written] = byte;
                    written += 1;
                }
                None => break,
            }
        }
        Ok(written)
    }
}
/// Forward-only seeking over the uncompressed data.
impl<'a, R: io::Read + io::Seek> io::Seek for ArsenicReader<'a, R> {
    /// Moves to the requested position in the uncompressed stream.
    ///
    /// Seeking forward decodes and discards the skipped bytes, which are
    /// still fed to the checksum. A target past the end stops at the end.
    /// The new position is returned.
    ///
    /// # Errors
    /// * `InvalidInput` if the target is negative or overflows `u64`.
    /// * `Unsupported` if the target lies before the current position.
    /// * Any decoding error hit while skipping forward.
    fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> {
        let current = self.pos as u64;
        let target = match pos {
            io::SeekFrom::Start(offset) => Some(offset),
            io::SeekFrom::Current(delta) => current.checked_add_signed(delta),
            io::SeekFrom::End(delta) => self.uncompressed_size.checked_add_signed(delta),
        }
        .ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "invalid seek to a negative or overflowing position",
            )
        })?;

        if target < current {
            return Err(io::Error::new(
                io::ErrorKind::Unsupported,
                "cannot seek backwards in a compressed stream",
            ));
        }

        while (self.pos as u64) < target {
            if self.produce_next_byte()?.is_none() {
                // Reached the end of the uncompressed data.
                break;
            }
        }
        Ok(self.pos as u64)
    }

    /// Total length of the uncompressed data, as supplied at construction.
    #[inline]
    fn stream_len(&mut self) -> io::Result<u64> {
        Ok(self.uncompressed_size)
    }

    /// Number of uncompressed bytes produced so far.
    #[inline]
    fn stream_position(&mut self) -> io::Result<u64> {
        Ok(self.pos as u64)
    }
}
/// Builds the index table used to undo the block-sorting transform.
///
/// For the first `count` bytes of `block`, `transform` receives the positions
/// of those bytes stably sorted by byte value: afterwards `transform[k]` is
/// the position in `block` of the `k`-th smallest byte, with ties kept in
/// their original order. Entries of `transform` at or beyond `count` are left
/// untouched, and `block` is only read.
///
/// Panics if `count` exceeds the length of `block` or `transform`.
fn calcuate_inverse_bwt(transform: &mut [u32], block: &mut [u8], count: usize) {
    let data = &block[..count];

    // First pass: histogram of byte values.
    let mut next_slot = [0usize; 256];
    for &byte in data {
        next_slot[byte as usize] += 1;
    }

    // Turn the histogram into the first output slot of each byte value.
    let mut running = 0;
    for slot in next_slot.iter_mut() {
        let occurrences = *slot;
        *slot = running;
        running += occurrences;
    }

    // Second pass: drop each position into the next free slot of its value.
    for (position, &byte) in data.iter().enumerate() {
        let slot = &mut next_slot[byte as usize];
        transform[*slot] = position as u32;
        *slot += 1;
    }
}