1use super::table::FseTable;
6use haagenti_core::{Error, Result};
7
8#[derive(Debug)]
13pub struct FseDecoder<'a> {
14 table: &'a FseTable,
16 state: usize,
18}
19
20impl<'a> FseDecoder<'a> {
21 pub fn new(table: &'a FseTable) -> Self {
23 Self { table, state: 0 }
24 }
25
26 pub fn init_state(&mut self, bits: &mut BitReader) -> Result<()> {
30 let accuracy_log = self.table.accuracy_log();
31 self.state = bits.read_bits(accuracy_log as usize)? as usize;
32 Ok(())
33 }
34
35 pub fn decode_symbol(&mut self, bits: &mut BitReader) -> Result<u8> {
39 let entry = self.table.decode(self.state);
40 let symbol = entry.symbol;
41
42 let add_bits = bits.read_bits(entry.num_bits as usize)? as u16;
44 self.state = (entry.baseline + add_bits) as usize;
45
46 Ok(symbol)
47 }
48
49 pub fn peek_symbol(&self) -> u8 {
51 self.table.decode(self.state).symbol
52 }
53
54 pub fn peek_num_bits(&self) -> u8 {
56 self.table.decode(self.state).num_bits
57 }
58
59 pub fn peek_seq_base(&self) -> u32 {
62 self.table.decode(self.state).seq_base
63 }
64
65 pub fn peek_seq_extra_bits(&self) -> u8 {
68 self.table.decode(self.state).seq_extra_bits
69 }
70
71 pub fn update_state(&mut self, bits: &mut BitReader) -> Result<()> {
76 let entry = self.table.decode(self.state);
77 let add_bits = bits.read_bits(entry.num_bits as usize)? as u16;
78 self.state = (entry.baseline + add_bits) as usize;
79 Ok(())
80 }
81
82 pub fn state(&self) -> usize {
84 self.state
85 }
86
87 #[cfg(test)]
89 pub fn set_state(&mut self, state: usize) {
90 self.state = state;
91 }
92}
93
/// Bit-level reader over a borrowed byte slice.
///
/// Three reading modes share this state:
/// * forward (default): LSB-first from the start of `data`;
/// * reversed: starting just below the sentinel bit of the last byte and
///   moving towards the start of `data`, assembling values MSB-first;
/// * FSE: LSB-first through a 64-bit container, bounded by a bit budget.
///
/// `fse_mode` takes precedence over `reversed` when reads are dispatched.
#[derive(Debug, Clone)]
pub struct BitReader<'a> {
    /// Underlying bytes (shared borrow, never modified).
    data: &'a [u8],
    /// Forward mode: index of the byte currently being read.
    byte_pos: usize,
    /// Forward mode: bits already consumed from `data[byte_pos]`.
    bit_pos: u8,
    /// Whether reversed (end-to-start) reading is active.
    reversed: bool,
    /// Reversed mode: index of the byte currently being read.
    rev_byte_idx: usize,
    /// Reversed mode: highest unread bit of `data[rev_byte_idx]`; -1 means
    /// that byte is exhausted.
    rev_bit_pos: i8,
    /// Reversed mode: payload bits still unread.
    rev_total_bits: usize,
    /// FSE mode: up to 8 bytes of `data` starting at `fse_byte_pos`,
    /// packed little-endian.
    fse_container: u64,
    /// FSE mode: bits consumed from the low end of `fse_container`.
    fse_bits_consumed: usize,
    /// FSE mode: total readable bits of the stream.
    fse_total_bits: usize,
    /// FSE mode: bits consumed since FSE mode was entered.
    fse_stream_bits_consumed: usize,
    /// FSE mode: index in `data` of the byte at the low end of
    /// `fse_container`.
    fse_byte_pos: usize,
    /// Whether FSE mode is active.
    fse_mode: bool,
}
127
128impl<'a> BitReader<'a> {
129 pub fn new(data: &'a [u8]) -> Self {
131 Self {
132 data,
133 byte_pos: 0,
134 bit_pos: 0,
135 reversed: false,
136 rev_byte_idx: 0,
137 rev_bit_pos: 0,
138 rev_total_bits: 0,
139 fse_container: 0,
140 fse_bits_consumed: 0,
141 fse_total_bits: 0,
142 fse_stream_bits_consumed: 0,
143 fse_byte_pos: 0,
144 fse_mode: false,
145 }
146 }
147
148 pub fn new_reversed(data: &'a [u8]) -> Result<BitReader<'a>> {
157 if data.is_empty() {
158 return Err(Error::corrupted("Empty bitstream"));
159 }
160
161 let last_byte = data[data.len() - 1];
163 if last_byte == 0 {
164 return Err(Error::corrupted("Invalid bitstream: no sentinel bit"));
165 }
166
167 let sentinel_pos = 7 - last_byte.leading_zeros() as i8;
169
170 let prev_bytes_bits = (data.len() - 1) * 8;
172 let last_byte_bits = sentinel_pos as usize; let total_bits = prev_bytes_bits + last_byte_bits;
174
175 let start_byte_idx = data.len() - 1;
177 let start_bit_pos = sentinel_pos - 1; Ok(Self {
180 data,
181 byte_pos: 0,
182 bit_pos: 0,
183 reversed: true,
184 rev_byte_idx: start_byte_idx,
185 rev_bit_pos: start_bit_pos,
186 rev_total_bits: total_bits,
187 fse_container: 0,
188 fse_bits_consumed: 0,
189 fse_total_bits: 0,
190 fse_stream_bits_consumed: 0,
191 fse_byte_pos: 0,
192 fse_mode: false,
193 })
194 }
195
196 pub fn init_from_end(&mut self) -> Result<()> {
201 if self.data.is_empty() {
202 return Err(Error::corrupted("Empty bitstream"));
203 }
204
205 let last_byte = self.data[self.data.len() - 1];
207 if last_byte == 0 {
208 return Err(Error::corrupted("Invalid bitstream: no sentinel bit"));
209 }
210
211 let sentinel_pos = 7 - last_byte.leading_zeros() as i8;
213
214 let prev_bytes_bits = (self.data.len() - 1) * 8;
216 let last_byte_bits = sentinel_pos as usize; let total_bits = prev_bytes_bits + last_byte_bits;
218
219 self.reversed = true;
221 self.rev_byte_idx = self.data.len() - 1;
222 self.rev_bit_pos = sentinel_pos - 1; self.rev_total_bits = total_bits;
224
225 Ok(())
226 }
227
228 pub fn init_fse(&mut self) -> Result<()> {
237 if self.data.is_empty() {
238 return Err(Error::corrupted("Empty bitstream"));
239 }
240
241 let mut container: u64 = 0;
243 for (i, &byte) in self.data.iter().enumerate() {
244 if i >= 8 {
245 break; }
247 container |= (byte as u64) << (i * 8);
248 }
249
250 if container == 0 {
252 return Err(Error::corrupted("Invalid bitstream: no sentinel bit"));
253 }
254 let sentinel_pos = 63 - container.leading_zeros() as usize;
255
256 let total_bits = sentinel_pos;
258
259 self.fse_mode = true;
260 self.fse_container = container;
261 self.fse_bits_consumed = 0;
262 self.fse_total_bits = total_bits;
263 self.fse_stream_bits_consumed = 0;
264 self.fse_byte_pos = 0;
265
266 Ok(())
267 }
268
269 pub fn switch_to_lsb_mode(&mut self) -> Result<()> {
276 if !self.reversed {
277 return Err(Error::corrupted(
278 "switch_to_lsb_mode requires reversed mode",
279 ));
280 }
281
282 let remaining_bits = self.rev_total_bits;
284 if remaining_bits == 0 {
285 self.fse_mode = true;
286 self.fse_container = 0;
287 self.fse_bits_consumed = 0;
288 self.fse_total_bits = 0;
289 self.fse_stream_bits_consumed = 0;
290 self.fse_byte_pos = 0;
291 return Ok(());
292 }
293
294 let mut container: u64 = 0;
297 for (i, &byte) in self.data.iter().enumerate() {
298 if i >= 8 {
299 break;
300 }
301 container |= (byte as u64) << (i * 8);
302 }
303
304 self.fse_mode = true;
307 self.fse_container = container;
308 self.fse_bits_consumed = 0;
309 self.fse_total_bits = remaining_bits;
310 self.fse_stream_bits_consumed = 0;
311 self.fse_byte_pos = 0;
312 self.reversed = false;
313
314 Ok(())
315 }
316
317 fn fse_refill(&mut self) {
319 if self.fse_bits_consumed < 32 {
322 return;
323 }
324
325 let bytes_consumed = (self.fse_bits_consumed / 8).min(7);
327 if bytes_consumed == 0 {
328 return;
329 }
330
331 let shift_bits = bytes_consumed * 8;
333 self.fse_container >>= shift_bits;
334 self.fse_bits_consumed -= shift_bits;
335 self.fse_byte_pos += bytes_consumed;
336
337 for i in 0..bytes_consumed {
339 let byte_idx = self.fse_byte_pos + 8 - bytes_consumed + i;
340 if byte_idx < self.data.len() {
341 let byte = self.data[byte_idx] as u64;
342 let shift = (8 - bytes_consumed + i) * 8;
343 self.fse_container |= byte << shift;
344 }
345 }
346 }
347
348 fn read_bits_fse(&mut self, n: usize) -> Result<u32> {
350 if self.fse_stream_bits_consumed + n > self.fse_total_bits {
351 return Err(Error::unexpected_eof(self.fse_stream_bits_consumed + n));
352 }
353
354 self.fse_refill();
356
357 let mask = if n >= 32 { u32::MAX } else { (1u32 << n) - 1 };
358 let value = ((self.fse_container >> self.fse_bits_consumed) as u32) & mask;
359 self.fse_bits_consumed += n;
360 self.fse_stream_bits_consumed += n;
361
362 Ok(value)
363 }
364
365 pub fn read_bits(&mut self, n: usize) -> Result<u32> {
371 if n == 0 {
372 return Ok(0);
373 }
374 if n > 32 {
375 return Err(Error::corrupted("Cannot read more than 32 bits at once"));
376 }
377
378 if self.fse_mode {
379 self.read_bits_fse(n)
380 } else if self.reversed {
381 self.read_bits_reversed(n)
382 } else {
383 self.read_bits_forward(n)
384 }
385 }
386
387 fn read_bits_forward(&mut self, n: usize) -> Result<u32> {
389 if !self.has_bits(n) {
390 return Err(Error::unexpected_eof(
391 self.byte_pos * 8 + self.bit_pos as usize,
392 ));
393 }
394
395 let mut result = 0u32;
396 let mut bits_read = 0;
397
398 while bits_read < n {
399 let byte = self.data[self.byte_pos];
400 let available = 8 - self.bit_pos as usize;
401 let to_read = (n - bits_read).min(available);
402
403 let mask = ((1u32 << to_read) - 1) as u8;
405 let bits = (byte >> self.bit_pos) & mask;
406
407 result |= (bits as u32) << bits_read;
408 bits_read += to_read;
409
410 self.bit_pos += to_read as u8;
411 if self.bit_pos >= 8 {
412 self.bit_pos = 0;
413 self.byte_pos += 1;
414 }
415 }
416
417 Ok(result)
418 }
419
420 fn read_bits_reversed(&mut self, n: usize) -> Result<u32> {
426 if self.rev_total_bits < n {
427 return Err(Error::unexpected_eof(n));
428 }
429
430 let mut result = 0u32;
431 let mut bits_read = 0;
432
433 while bits_read < n {
434 if self.rev_bit_pos < 0 {
436 if self.rev_byte_idx == 0 {
437 return Err(Error::unexpected_eof(bits_read));
438 }
439 self.rev_byte_idx -= 1;
440 self.rev_bit_pos = 7;
441 }
442
443 let byte = self.data[self.rev_byte_idx];
444 let bits_to_read = (n - bits_read).min((self.rev_bit_pos + 1) as usize);
445
446 let shift = (self.rev_bit_pos + 1) as usize - bits_to_read;
449 let mask = ((1u32 << bits_to_read) - 1) as u8;
450 let extracted = (byte >> shift) & mask;
451
452 result = (result << bits_to_read) | (extracted as u32);
454 bits_read += bits_to_read;
455
456 self.rev_bit_pos -= bits_to_read as i8;
457 }
458
459 self.rev_total_bits -= n;
460 Ok(result)
461 }
462
463 fn has_bits(&self, n: usize) -> bool {
465 if self.fse_mode {
466 self.fse_stream_bits_consumed + n <= self.fse_total_bits
467 } else if self.reversed {
468 self.rev_total_bits >= n
469 } else {
470 let total_bits = self.data.len() * 8;
471 let consumed = self.byte_pos * 8 + self.bit_pos as usize;
472 consumed + n <= total_bits
473 }
474 }
475
476 pub fn is_empty(&self) -> bool {
478 if self.fse_mode {
479 self.fse_stream_bits_consumed >= self.fse_total_bits
480 } else if self.reversed {
481 self.rev_total_bits == 0
482 } else {
483 self.byte_pos >= self.data.len()
484 }
485 }
486
487 pub fn bits_remaining(&self) -> usize {
489 if self.fse_mode {
490 self.fse_total_bits
491 .saturating_sub(self.fse_stream_bits_consumed)
492 } else if self.reversed {
493 self.rev_total_bits
494 } else if self.byte_pos >= self.data.len() {
495 0
496 } else {
497 (self.data.len() - self.byte_pos) * 8 - self.bit_pos as usize
498 }
499 }
500
501 pub fn peek_bits(&self, n: usize) -> Result<u32> {
503 let mut clone = self.clone();
504 clone.read_bits(n)
505 }
506
507 pub fn peek_bits_padded(&self, n: usize) -> Result<u32> {
513 if !self.reversed {
514 return self.peek_bits(n);
516 }
517
518 let available = self.rev_total_bits;
519 if available >= n {
520 return self.peek_bits(n);
522 }
523
524 if available == 0 {
525 return Err(Error::unexpected_eof(0));
527 }
528
529 let mut clone = self.clone();
531 let bits = clone.read_bits(available)?;
532 Ok(bits << (n - available))
534 }
535}
536
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fse::table::FseTable;

    #[test]
    fn test_bit_reader_empty() {
        let data = [];
        let reader = BitReader::new(&data);
        assert!(reader.is_empty());
        assert_eq!(reader.bits_remaining(), 0);
    }

    #[test]
    fn test_bit_reader_single_byte() {
        let data = [0b10110100];
        let mut reader = BitReader::new(&data);

        // Forward mode is LSB-first: the low nibble comes out first.
        let low4 = reader.read_bits(4).unwrap();
        let high4 = reader.read_bits(4).unwrap();

        assert_eq!(low4, 0b0100);
        assert_eq!(high4, 0b1011);
    }

    #[test]
    fn test_bit_reader_multiple_bytes() {
        let data = [0xAB, 0xCD];
        let mut reader = BitReader::new(&data);

        let first = reader.read_bits(8).unwrap();
        let second = reader.read_bits(8).unwrap();

        assert_eq!(first, 0xAB);
        assert_eq!(second, 0xCD);
    }

    #[test]
    fn test_bit_reader_cross_byte() {
        let data = [0xFF, 0x00];
        let mut reader = BitReader::new(&data);

        let first = reader.read_bits(4).unwrap();
        assert_eq!(first, 0x0F);
        // High nibble of 0xFF followed by low nibble of 0x00.
        let cross = reader.read_bits(8).unwrap();
        assert_eq!(cross, 0x0F);
    }

    #[test]
    fn test_bit_reader_init_from_end() {
        // Sentinel is bit 7 of the last byte: 8 + 7 payload bits.
        let data = [0x00, 0x80];
        let mut reader = BitReader::new(&data);
        reader.init_from_end().unwrap();

        assert_eq!(reader.bits_remaining(), 15);
    }

    #[test]
    fn test_bit_reader_init_from_end_lower_sentinel() {
        // Sentinel is bit 2 of the last byte: 8 + 2 payload bits.
        let data = [0xFF, 0x04];
        let mut reader = BitReader::new(&data);
        reader.init_from_end().unwrap();

        assert_eq!(reader.bits_remaining(), 10);
    }

    #[test]
    fn test_bit_reader_eof() {
        let data = [0xFF];
        let mut reader = BitReader::new(&data);

        reader.read_bits(8).unwrap();

        let result = reader.read_bits(1);
        assert!(result.is_err());
    }

    #[test]
    fn test_bit_reader_peek() {
        let data = [0b11110000];
        let reader = BitReader::new(&data);

        let peek1 = reader.peek_bits(4).unwrap();
        let peek2 = reader.peek_bits(4).unwrap();

        // Peeking never advances the reader.
        assert_eq!(peek1, peek2);
        assert_eq!(peek1, 0b0000);
        assert_eq!(reader.bits_remaining(), 8);
    }

    #[test]
    fn test_bit_reader_read_zero() {
        let data = [0xFF];
        let mut reader = BitReader::new(&data);

        let zero = reader.read_bits(0).unwrap();
        assert_eq!(zero, 0);
        assert_eq!(reader.bits_remaining(), 8);
    }

    #[test]
    fn test_bit_reader_reversed_reads_msb_first() {
        // Last byte 0b11: sentinel at bit 1, one payload bit below it.
        let data = [0xA5, 0x03];
        let mut reader = BitReader::new_reversed(&data).unwrap();
        assert_eq!(reader.bits_remaining(), 9);

        assert_eq!(reader.read_bits(1).unwrap(), 1);
        assert_eq!(reader.read_bits(4).unwrap(), 0xA);
        assert_eq!(reader.read_bits(4).unwrap(), 0x5);
        assert!(reader.is_empty());
        assert!(reader.read_bits(1).is_err());
    }

    #[test]
    fn test_bit_reader_new_reversed_rejects_bad_input() {
        let empty: [u8; 0] = [];
        assert!(BitReader::new_reversed(&empty).is_err());
        assert!(BitReader::new_reversed(&[0x00]).is_err());
    }

    #[test]
    fn test_bit_reader_peek_bits_padded() {
        let data = [0xA5, 0x03];
        let mut reader = BitReader::new_reversed(&data).unwrap();
        assert_eq!(reader.read_bits(5).unwrap(), 0x1A);

        // Four bits (0b0101) remain; asking for six pads two zeros on the right.
        assert_eq!(reader.peek_bits_padded(6).unwrap(), 0b010100);
        assert_eq!(reader.peek_bits_padded(4).unwrap(), 0b0101);
        assert_eq!(reader.bits_remaining(), 4);
    }

    #[test]
    fn test_bit_reader_switch_to_lsb_mode() {
        let data = [0xA5, 0x03];
        let mut reader = BitReader::new_reversed(&data).unwrap();
        reader.read_bits(1).unwrap();

        reader.switch_to_lsb_mode().unwrap();
        assert_eq!(reader.bits_remaining(), 8);
        assert_eq!(reader.read_bits(4).unwrap(), 0x5);
        assert_eq!(reader.read_bits(4).unwrap(), 0xA);
        assert!(reader.is_empty());

        let mut forward = BitReader::new(&data);
        assert!(forward.switch_to_lsb_mode().is_err());
    }

    #[test]
    fn test_bit_reader_init_fse_short_stream() {
        let data = [0x05, 0x01];
        let mut reader = BitReader::new(&data);
        reader.init_fse().unwrap();

        assert_eq!(reader.bits_remaining(), 8);
        assert_eq!(reader.read_bits(8).unwrap(), 0x05);
        assert!(reader.is_empty());

        assert!(BitReader::new(&[0u8, 0]).init_fse().is_err());
    }

    #[test]
    fn test_fse_decoder_creation() {
        let distribution = [4i16, 4];
        let table = FseTable::build(&distribution, 3, 1).unwrap();
        let decoder = FseDecoder::new(&table);

        assert_eq!(decoder.state(), 0);
    }

    #[test]
    fn test_fse_decoder_init_state() {
        let distribution = [4i16, 4];
        let table = FseTable::build(&distribution, 3, 1).unwrap();
        let mut decoder = FseDecoder::new(&table);

        // accuracy_log is 3, so the low three bits (0b101) seed the state.
        let data = [0b00000101];
        let mut bits = BitReader::new(&data);

        decoder.init_state(&mut bits).unwrap();
        assert_eq!(decoder.state(), 5);
    }

    #[test]
    fn test_fse_decoder_decode_sequence() {
        let distribution = [4i16, 4];
        let table = FseTable::build(&distribution, 3, 1).unwrap();
        let mut decoder = FseDecoder::new(&table);

        decoder.set_state(0);
        let sym0 = decoder.peek_symbol();

        assert!(sym0 <= 1);
    }

    #[test]
    fn test_fse_decoder_state_transitions() {
        let distribution = [6i16, 2];
        let table = FseTable::build(&distribution, 3, 1).unwrap();

        for start_state in 0..8 {
            let mut decoder = FseDecoder::new(&table);
            decoder.set_state(start_state);
            let entry = table.decode(start_state);

            // Peeks must reflect the table entry for the current state.
            assert_eq!(decoder.peek_symbol(), entry.symbol);
            assert_eq!(decoder.peek_num_bits(), entry.num_bits);

            assert!(entry.symbol <= 1);
            assert!(entry.num_bits <= 3);
        }
    }
}