Skip to main content

sit_algos/
arsenic.rs

1use std::io;
2
3use bitstream_io::{BigEndian, BitRead, BitReader};
4use crc::Digest;
5
6const CRC_32_ISO_HDLC: crc::Crc<u32> = crc::Crc::<u32>::new(&crc::CRC_32_ISO_HDLC);
7
/// Gap table used when a block is flagged as randomized.
///
/// The decoder starts with `RANDOMIZATION_TABLE[0]` as the first position whose
/// low bit is flipped; every time that position is reached, the next entry
/// (wrapping after 256 entries) is added to obtain the following position.
#[rustfmt::skip]
const RANDOMIZATION_TABLE: [u16; 256] = [
    0x0ee, 0x056, 0x0f8, 0x0c3, 0x09d, 0x09f, 0x0ae, 0x02c,
    0x0ad, 0x0cd, 0x024, 0x09d, 0x0a6, 0x101, 0x018, 0x0b9,
    0x0a1, 0x082, 0x075, 0x0e9, 0x09f, 0x055, 0x066, 0x06a,
    0x086, 0x071, 0x0dc, 0x084, 0x056, 0x096, 0x056, 0x0a1,
    0x084, 0x078, 0x0b7, 0x032, 0x06a, 0x003, 0x0e3, 0x002,
    0x011, 0x101, 0x008, 0x044, 0x083, 0x100, 0x043, 0x0e3,
    0x01c, 0x0f0, 0x086, 0x06a, 0x06b, 0x00f, 0x003, 0x02d,
    0x086, 0x017, 0x07b, 0x010, 0x0f6, 0x080, 0x078, 0x07a,
    0x0a1, 0x0e1, 0x0ef, 0x08c, 0x0f6, 0x087, 0x04b, 0x0a7,
    0x0e2, 0x077, 0x0fa, 0x0b8, 0x081, 0x0ee, 0x077, 0x0c0,
    0x09d, 0x029, 0x020, 0x027, 0x071, 0x012, 0x0e0, 0x06b,
    0x0d1, 0x07c, 0x00a, 0x089, 0x07d, 0x087, 0x0c4, 0x101,
    0x0c1, 0x031, 0x0af, 0x038, 0x003, 0x068, 0x01b, 0x076,
    0x079, 0x03f, 0x0db, 0x0c7, 0x01b, 0x036, 0x07b, 0x0e2,
    0x063, 0x081, 0x0ee, 0x00c, 0x063, 0x08b, 0x078, 0x038,
    0x097, 0x09b, 0x0d7, 0x08f, 0x0dd, 0x0f2, 0x0a3, 0x077,
    0x08c, 0x0c3, 0x039, 0x020, 0x0b3, 0x012, 0x011, 0x00e,
    0x017, 0x042, 0x080, 0x02c, 0x0c4, 0x092, 0x059, 0x0c8,
    0x0db, 0x040, 0x076, 0x064, 0x0b4, 0x055, 0x01a, 0x09e,
    0x0fe, 0x05f, 0x006, 0x03c, 0x041, 0x0ef, 0x0d4, 0x0aa,
    0x098, 0x029, 0x0cd, 0x01f, 0x002, 0x0a8, 0x087, 0x0d2,
    0x0a0, 0x093, 0x098, 0x0ef, 0x00c, 0x043, 0x0ed, 0x09d,
    0x0c2, 0x0eb, 0x081, 0x0e9, 0x064, 0x023, 0x068, 0x01e,
    0x025, 0x057, 0x0de, 0x09a, 0x0cf, 0x07f, 0x0e5, 0x0ba,
    0x041, 0x0ea, 0x0ea, 0x036, 0x01a, 0x028, 0x079, 0x020,
    0x05e, 0x018, 0x04e, 0x07c, 0x08e, 0x058, 0x07a, 0x0ef,
    0x091, 0x002, 0x093, 0x0bb, 0x056, 0x0a1, 0x049, 0x01b,
    0x079, 0x092, 0x0f3, 0x058, 0x04f, 0x052, 0x09c, 0x002,
    0x077, 0x0af, 0x02a, 0x08f, 0x049, 0x0d0, 0x099, 0x04d,
    0x098, 0x101, 0x060, 0x093, 0x100, 0x075, 0x031, 0x0ce,
    0x049, 0x020, 0x056, 0x057, 0x0e2, 0x0f5, 0x026, 0x02b,
    0x08a, 0x0bf, 0x0de, 0x0d0, 0x083, 0x034, 0x0f4, 0x017,
];
27
28#[derive(Debug, thiserror::Error)]
29pub enum Error {
30    #[error(transparent)]
31    Io(#[from] io::Error),
32    #[error(transparent)]
33    BinRw(#[from] binrw::Error),
34
35    #[error("invalid file")]
36    InvalidFile,
37}
38
39impl From<Error> for io::Error {
40    fn from(val: Error) -> Self {
41        match val {
42            Error::Io(e) => e,
43            error => io::Error::other(error),
44        }
45    }
46}
47
/// One entry of an adaptive frequency model: a symbol value and its current
/// frequency count.
#[derive(Default, Clone, Copy)]
struct Symbol {
    /// Value handed back to the caller when this entry is decoded.
    symbol: i32,
    /// Current weight of this entry inside its model.
    frequency: i32,
}
53
54struct Model {
55    frequency: i32,
56    increment: i32,
57    limit: i32,
58
59    symbol_count: usize,
60    // TODO: make symobl_count const generic
61    symbols: [Symbol; 128],
62}
63
64impl Model {
65    fn increment(&mut self, symindex: i32) {
66        self.symbols[symindex as usize].frequency += self.increment;
67        self.frequency += self.increment;
68
69        if self.frequency > self.limit {
70            self.frequency = 0;
71            for i in 0..self.symbol_count {
72                self.symbols[i].frequency += 1;
73                self.symbols[i].frequency >>= 1;
74                self.frequency += self.symbols[i].frequency;
75            }
76        }
77    }
78}
79
/// Move-to-front decoder state.
struct MtfState {
    /// Current ordering of the 256 values, most recently used first.
    table: [i32; 256],
}

impl MtfState {
    /// Restores the identity ordering (`table[i] == i`).
    fn reset(&mut self) {
        for (value, slot) in (0..).zip(self.table.iter_mut()) {
            *slot = value;
        }
    }

    /// Returns the value stored at position `symbol` and moves it to the front,
    /// shifting everything that was ahead of it back by one position.
    ///
    /// Panics if `symbol` is not a valid index into the table.
    fn decode(&mut self, symbol: i32) -> i32 {
        let index = symbol as usize;
        let front = self.table[index];
        // Rotating the prefix by one puts `table[index]` at position 0 and
        // moves every earlier entry one step towards the back.
        self.table[..=index].rotate_right(1);
        front
    }
}
101
102impl Model {
103    fn init(&mut self, first_symbol: i32, last_symbol: i32, increment: i32, limit: i32) {
104        self.increment = increment;
105        self.limit = limit;
106        self.symbol_count = (last_symbol - first_symbol + 1) as usize;
107
108        for i in 0..self.symbol_count {
109            self.symbols[i] = Symbol {
110                symbol: first_symbol + i as i32,
111                frequency: 0,
112            }
113        }
114
115        self.frequency = self.increment * self.symbol_count as i32;
116        for i in 0..self.symbol_count {
117            self.symbols[i].frequency = self.increment;
118        }
119    }
120
121    fn reset(&mut self) {
122        self.frequency = self.increment * self.symbol_count as i32;
123        for i in 0..self.symbol_count {
124            self.symbols[i].frequency = self.increment;
125        }
126    }
127}
128
129impl Default for Model {
130    fn default() -> Self {
131        Self {
132            frequency: Default::default(),
133            increment: Default::default(),
134            limit: Default::default(),
135            symbol_count: Default::default(),
136            symbols: [Default::default(); 128],
137        }
138    }
139}
140
/// Number of bits read to prime the arithmetic decoder's code register.
const NUM_BITS: usize = 26;
/// Initial width of the arithmetic decoder's range.
const ONE: i32 = 1 << (NUM_BITS - 1);
/// Renormalisation threshold: the range is doubled while it is at or below this.
const HALF: i32 = ONE >> 1;
144
/// Registers of the arithmetic decoder.
#[derive(Default)]
struct Decoder {
    /// Width of the current coding interval.
    range: i32,
    /// Bits consumed from the stream, relative to the low end of the interval.
    code: i32,
}
150
151impl Decoder {
152    pub fn try_init<R: io::Read + io::Seek>(
153        &mut self,
154        reader: &mut BitReader<R, BigEndian>,
155    ) -> Result<(), Error> {
156        self.range = ONE;
157        self.code = reader.read_var(NUM_BITS as u32)?;
158
159        Ok(())
160    }
161
162    fn next_bit_string<R: io::Read + io::Seek>(
163        &mut self,
164        reader: &mut BitReader<R, BigEndian>,
165        model: &mut Model,
166        bits: usize,
167    ) -> Result<i32, Error> {
168        let mut result: i32 = 0;
169        for i in 0..bits {
170            if self.next_symbol(reader, model)? != 0 {
171                result |= 1 << i;
172            }
173        }
174        Ok(result)
175    }
176
177    fn next_symbol<R: io::Read + io::Seek>(
178        &mut self,
179        reader: &mut BitReader<R, BigEndian>,
180        model: &mut Model,
181    ) -> Result<i32, Error> {
182        let frequency: i32 = self.code / (self.range / model.frequency);
183        let mut cumulative = 0;
184        for n in 0..(model.symbol_count - 1) {
185            if cumulative + model.symbols[n].frequency > frequency {
186                self.read_next_arithmetic_code(
187                    reader,
188                    cumulative,
189                    model.symbols[n].frequency,
190                    model.frequency,
191                )?;
192                model.increment(n as i32);
193                return Ok(model.symbols[n].symbol);
194            }
195
196            cumulative += model.symbols[n].frequency;
197        }
198
199        let n = model.symbol_count - 1;
200        self.read_next_arithmetic_code(
201            reader,
202            cumulative,
203            model.symbols[n].frequency,
204            model.frequency,
205        )?;
206        model.increment(n as i32);
207        Ok(model.symbols[n].symbol)
208    }
209
210    fn read_next_arithmetic_code<R: io::Read + io::Seek>(
211        &mut self,
212        reader: &mut BitReader<R, BigEndian>,
213        symlow: i32,
214        symsize: i32,
215        symtot: i32,
216    ) -> Result<(), Error> {
217        let renorm_factor = self.range / symtot;
218        let lowincr = renorm_factor * symlow;
219
220        self.code -= lowincr;
221        if symlow + symsize == symtot {
222            self.range -= lowincr;
223        } else {
224            self.range = symsize * renorm_factor;
225        }
226
227        while self.range <= HALF {
228            self.range <<= 1;
229            self.code = (self.code << 1) | if reader.read_bit()? { 1 } else { 0 };
230        }
231
232        Ok(())
233    }
234}
235
236pub struct ArsenicReader<'a, R: io::Read + io::Seek> {
237    inner: BitReader<R, BigEndian>,
238
239    initial_model: Model,
240    selector_model: Model,
241    mtf: [Model; 7],
242    decoder: Decoder,
243    mtf_state: MtfState,
244
245    block_bits: i32,
246    block_size: i32,
247    block: Vec<u8>,
248    end_of_block: bool,
249
250    num_bytes: i32,
251    byte_count: i32,
252    transform_index: i32,
253    transform: Vec<u32>,
254
255    randomized: i32,
256    randcount: i32,
257    randindex: i32,
258
259    repeat: i32,
260    count: i32,
261    last: i32,
262
263    comp_crc: u32,
264
265    pos: usize,
266    uncompressed_size: u64,
267
268    crc3: Digest<'a, u32>,
269}
270
271impl<'a, R: io::Read + io::Seek> ArsenicReader<'a, R> {
272    pub fn try_from(inner: R, uncompressed_size: u64) -> Result<Self, Error> {
273        let mut me = Self {
274            inner: BitReader::new(inner),
275
276            initial_model: Default::default(),
277            selector_model: Default::default(),
278            mtf: Default::default(),
279            mtf_state: MtfState { table: [0i32; 256] },
280            decoder: Default::default(),
281            block_bits: 0,
282            block_size: 0,
283            block: Vec::new(),
284            end_of_block: false,
285            num_bytes: 0,
286            byte_count: 0,
287            transform_index: 0,
288            transform: Vec::new(),
289            randomized: 0,
290            randcount: 0,
291            randindex: 0,
292            repeat: 0,
293            count: 0,
294            last: 0,
295            comp_crc: 0,
296
297            pos: 0,
298            uncompressed_size,
299
300            crc3: CRC_32_ISO_HDLC.digest(),
301        };
302        me.reset()?;
303        Ok(me)
304    }
305
306    fn reset(&mut self) -> Result<(), Error> {
307        self.decoder.try_init(&mut self.inner)?;
308        self.initial_model.init(0, 1, 1, 256);
309        self.selector_model.init(0, 10, 8, 1024);
310        self.mtf[0].init(2, 3, 8, 1024);
311        self.mtf[1].init(4, 7, 4, 1024);
312        self.mtf[2].init(8, 15, 4, 1024);
313        self.mtf[3].init(16, 31, 4, 1024);
314        self.mtf[4].init(32, 63, 2, 1024);
315        self.mtf[5].init(64, 127, 2, 1024);
316        self.mtf[6].init(128, 255, 1, 1024);
317
318        if self
319            .decoder
320            .next_bit_string(&mut self.inner, &mut self.initial_model, 8)? as u8
321            != b'A'
322        {
323            return Err(Error::InvalidFile);
324        }
325
326        if self
327            .decoder
328            .next_bit_string(&mut self.inner, &mut self.initial_model, 8)? as u8
329            != b's'
330        {
331            return Err(Error::InvalidFile);
332        }
333
334        self.block_bits =
335            self.decoder
336                .next_bit_string(&mut self.inner, &mut self.initial_model, 4)?
337                + 9;
338        self.block_size = 1 << self.block_bits;
339        self.num_bytes = 0;
340        self.byte_count = 0;
341        self.repeat = 0;
342        self.block = vec![0u8; self.block_size as usize];
343        self.comp_crc = 0;
344
345        self.end_of_block = self
346            .decoder
347            .next_symbol(&mut self.inner, &mut self.initial_model)?
348            != 0;
349
350        Ok(())
351    }
352
353    fn read_block(&mut self) -> Result<(), Error> {
354        self.mtf_state.reset();
355
356        self.randomized = self
357            .decoder
358            .next_symbol(&mut self.inner, &mut self.initial_model)?;
359        self.transform_index = self.decoder.next_bit_string(
360            &mut self.inner,
361            &mut self.initial_model,
362            self.block_bits as usize,
363        )?;
364        self.num_bytes = 0;
365
366        loop {
367            let mut sel = self
368                .decoder
369                .next_symbol(&mut self.inner, &mut self.selector_model)?;
370            if sel == 0 || sel == 1 {
371                let mut zerostate = 1;
372                let mut zerocount = 0;
373
374                while sel < 2 {
375                    if sel == 0 {
376                        zerocount += zerostate;
377                    } else if sel == 1 {
378                        zerocount += 2 * zerostate;
379                    }
380                    zerostate *= 2;
381                    sel = self
382                        .decoder
383                        .next_symbol(&mut self.inner, &mut self.selector_model)?;
384                }
385
386                if self.num_bytes + zerocount > self.block_size {
387                    return Err(Error::InvalidFile);
388                }
389
390                let value = self.mtf_state.decode(0);
391                for j in 0..zerocount {
392                    self.block[self.num_bytes as usize + j as usize] = value as u8;
393                }
394                self.num_bytes += zerocount;
395            }
396
397            let symbol;
398            if sel == 10 {
399                break;
400            } else if sel == 2 {
401                symbol = 1;
402            } else {
403                symbol = self
404                    .decoder
405                    .next_symbol(&mut self.inner, &mut self.mtf[sel as usize - 3])?;
406            }
407
408            if self.num_bytes > self.block_size {
409                return Err(Error::InvalidFile);
410            }
411
412            self.block[self.num_bytes as usize] = self.mtf_state.decode(symbol) as u8;
413            self.num_bytes += 1;
414        }
415
416        if self.transform_index > self.num_bytes {
417            return Err(Error::InvalidFile);
418        }
419
420        self.selector_model.reset();
421        for mtf in self.mtf.iter_mut() {
422            mtf.reset();
423        }
424
425        if self
426            .decoder
427            .next_symbol(&mut self.inner, &mut self.initial_model)?
428            != 0
429        {
430            self.comp_crc =
431                self.decoder
432                    .next_bit_string(&mut self.inner, &mut self.initial_model, 32)?
433                    as u32;
434            self.end_of_block = true;
435        }
436
437        self.transform = vec![0u32; self.num_bytes as usize];
438
439        calcuate_inverse_bwt(
440            &mut self.transform,
441            &mut self.block,
442            self.num_bytes as usize,
443        );
444
445        Ok(())
446    }
447
448    #[inline]
449    fn produce_next_byte(&mut self) -> Result<Option<u8>, Error> {
450        if self.pos >= self.uncompressed_size as usize {
451            return Ok(None);
452        }
453        self.pos += 1;
454
455        if self.repeat > 0 {
456            self.repeat -= 1;
457            Ok(Some(self.track_crc(self.last as u8)))
458        } else {
459            loop {
460                if self.byte_count >= self.num_bytes {
461                    if self.end_of_block {
462                        return Err(io::Error::new(io::ErrorKind::UnexpectedEof, ""))?;
463                    }
464
465                    self.read_block()?;
466                    self.byte_count = 0;
467                    self.count = 0;
468                    self.last = 0;
469
470                    self.randindex = 0;
471                    self.randcount = RANDOMIZATION_TABLE[0] as i32;
472                }
473
474                self.transform_index = self.transform[self.transform_index as usize] as i32;
475                let mut byte = self.block[self.transform_index as usize];
476
477                if self.randomized != 0 && self.randcount == self.byte_count {
478                    byte ^= 1;
479                    self.randindex = (self.randindex + 1) & 255;
480                    self.randcount += RANDOMIZATION_TABLE[self.randindex as usize] as i32;
481                }
482
483                self.byte_count += 1;
484
485                if self.count == 4 {
486                    self.count = 0;
487                    if byte == 0 {
488                        continue;
489                    }
490                    self.repeat = byte as i32 - 1;
491                    return Ok(Some(self.track_crc(self.last as u8)));
492                } else {
493                    if byte == self.last as u8 {
494                        self.count += 1;
495                    } else {
496                        self.count = 1;
497                        self.last = byte as i32;
498                    }
499
500                    return Ok(Some(self.track_crc(byte)));
501                }
502            }
503        }
504    }
505
506    fn track_crc(&mut self, data: u8) -> u8 {
507        self.crc3.update(&[data]);
508        data
509    }
510
511    pub fn is_checksum_valid(&mut self) -> bool {
512        self.comp_crc == self.crc3.clone().finalize()
513    }
514}
515
516impl<'a, R: io::Read + io::Seek> io::Read for ArsenicReader<'a, R> {
517    #[inline]
518    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
519        for (idx, b) in buf.iter_mut().enumerate() {
520            match self.produce_next_byte() {
521                Ok(None) => return Ok(idx),
522                Ok(Some(val)) => *b = val,
523                Err(e) => return Err(e.into()),
524            }
525        }
526
527        Ok(buf.len())
528    }
529}
530
531impl<'a, R: io::Read + io::Seek> io::Seek for ArsenicReader<'a, R> {
532    fn seek(&mut self, _: io::SeekFrom) -> io::Result<u64> {
533        todo!()
534    }
535
536    #[inline]
537    fn stream_len(&mut self) -> io::Result<u64> {
538        Ok(self.uncompressed_size)
539    }
540
541    #[inline]
542    fn stream_position(&mut self) -> io::Result<u64> {
543        Ok(self.pos as u64)
544    }
545}
546
/// Builds the inverse Burrows-Wheeler successor table for the first `count`
/// bytes of `block`.
///
/// After the call, `transform[..count]` lists the indices `0..count` ordered by
/// the byte value found at each index, with ties kept in index order (a stable
/// counting sort). Entries of `transform` beyond `count` are left untouched.
///
/// Panics if `count` exceeds the length of `block` or `transform`.
fn calcuate_inverse_bwt(transform: &mut [u32], block: &mut [u8], count: usize) {
    let data = &block[..count];

    // Histogram of byte values.
    let mut histogram = [0usize; 256];
    for &byte in data {
        histogram[byte as usize] += 1;
    }

    // Turn the histogram into the first output slot for each byte value.
    let mut next_slot = [0usize; 256];
    let mut running = 0;
    for (slot, &occurrences) in next_slot.iter_mut().zip(histogram.iter()) {
        *slot = running;
        running += occurrences;
    }

    // Place every index into the next free slot of its byte value.
    for (index, &byte) in data.iter().enumerate() {
        let slot = &mut next_slot[byte as usize];
        transform[*slot] = index as u32;
        *slot += 1;
    }
}