1use crate::block::{Sequence, LITERAL_LENGTH_BASELINE};
25use crate::fse::{
26 cached_ll_table, cached_ml_table, cached_of_table, FseBitWriter, InterleavedTansEncoder,
27 TansEncoder,
28};
29use crate::CustomFseTables;
30use haagenti_core::Result;
31
/// Match-length lookup table for the encoder, indexed by match-length code.
///
/// Every entry is `(extra_bits, baseline, max_value)`: the code at that index
/// covers the match lengths `baseline..=max_value`, and the distance from
/// `baseline` is carried in `extra_bits` raw bits.
const ML_ENCODE_TABLE: [(u8, u32, u32); 53] = [
    // Codes 0-31: one code per length (3..=34), no extra bits.
    (0, 3, 3), (0, 4, 4), (0, 5, 5), (0, 6, 6),
    (0, 7, 7), (0, 8, 8), (0, 9, 9), (0, 10, 10),
    (0, 11, 11), (0, 12, 12), (0, 13, 13), (0, 14, 14),
    (0, 15, 15), (0, 16, 16), (0, 17, 17), (0, 18, 18),
    (0, 19, 19), (0, 20, 20), (0, 21, 21), (0, 22, 22),
    (0, 23, 23), (0, 24, 24), (0, 25, 25), (0, 26, 26),
    (0, 27, 27), (0, 28, 28), (0, 29, 29), (0, 30, 30),
    (0, 31, 31), (0, 32, 32), (0, 33, 33), (0, 34, 34),
    // Codes 32-35: 1 extra bit.
    (1, 35, 36), (1, 37, 38), (1, 39, 40), (1, 41, 42),
    // Codes 36-37: 2 extra bits.
    (2, 43, 46), (2, 47, 50),
    // Codes 38-39: 3 extra bits.
    (3, 51, 58), (3, 59, 66),
    // Codes 40-41: 4 extra bits.
    (4, 67, 82), (4, 83, 98),
    // Code 42: 5 extra bits.
    (5, 99, 130),
    // Codes 43-52: 7 through 16 extra bits, one code per width.
    (7, 131, 258),
    (8, 259, 514),
    (9, 515, 1026),
    (10, 1027, 2050),
    (11, 2051, 4098),
    (12, 4099, 8194),
    (13, 8195, 16386),
    (14, 16387, 32770),
    (15, 32771, 65538),
    (16, 65539, 131074),
];
119
120#[derive(Debug, Clone, Copy)]
122pub struct EncodedSequence {
123 pub ll_code: u8,
125 pub ll_extra: u32,
127 pub ll_bits: u8,
129 pub of_code: u8,
131 pub of_extra: u32,
133 pub of_bits: u8,
135 pub ml_code: u8,
137 pub ml_extra: u32,
139 pub ml_bits: u8,
141}
142
143impl EncodedSequence {
144 #[inline]
146 pub fn from_sequence(seq: &Sequence) -> Self {
147 let (ll_code, ll_extra, ll_bits) = encode_literal_length(seq.literal_length);
148 let (of_code, of_extra, of_bits) = encode_offset(seq.offset);
149 let (ml_code, ml_extra, ml_bits) = encode_match_length(seq.match_length);
150
151 Self {
152 ll_code,
153 ll_extra,
154 ll_bits,
155 of_code,
156 of_extra,
157 of_bits,
158 ml_code,
159 ml_extra,
160 ml_bits,
161 }
162 }
163}
164
165#[inline(always)]
169fn encode_literal_length(value: u32) -> (u8, u32, u8) {
170 if value < 16 {
172 return (value as u8, 0, 0);
173 }
174
175 if value < 18 {
177 return (16, value - 16, 1);
178 }
179
180 if value < 20 {
182 return (17, value - 18, 1);
183 }
184
185 if value < 24 {
187 return (18, value - 20, 2);
188 }
189
190 let log_estimate = if value < 64 {
193 4
194 } else if value < 256 {
195 6
196 } else if value < 1024 {
197 8
198 } else {
199 10
200 };
201
202 for (code, &(bits, baseline)) in LITERAL_LENGTH_BASELINE
204 .iter()
205 .enumerate()
206 .skip(log_estimate)
207 {
208 let max_value = if bits == 0 {
209 baseline
210 } else {
211 baseline + ((1u32 << bits) - 1)
212 };
213
214 if value >= baseline && value <= max_value {
215 return (code as u8, value - baseline, bits);
216 }
217 }
218
219 let last_idx = LITERAL_LENGTH_BASELINE.len() - 1;
221 let (bits, baseline) = LITERAL_LENGTH_BASELINE[last_idx];
222 (last_idx as u8, value.saturating_sub(baseline), bits)
223}
224
225#[inline(always)]
233fn encode_match_length(value: u32) -> (u8, u32, u8) {
234 if (3..=34).contains(&value) {
237 return ((value - 3) as u8, 0, 0);
238 }
239
240 if value < 3 {
242 return (0, 0, 0); }
244
245 if value <= 42 {
247 let code = 32 + ((value - 35) / 2) as u8;
248 let baseline = 35 + ((code - 32) as u32 * 2);
249 return (code, value - baseline, 1);
250 }
251
252 if value <= 50 {
254 let code = if value < 47 { 36 } else { 37 };
255 let baseline = if code == 36 { 43 } else { 47 };
256 return (code, value - baseline, 2);
257 }
258
259 if value <= 66 {
261 let code = if value < 59 { 38 } else { 39 };
262 let baseline = if code == 38 { 51 } else { 59 };
263 return (code, value - baseline, 3);
264 }
265
266 for (code, &(bits, baseline, max_value)) in ML_ENCODE_TABLE.iter().enumerate().skip(40) {
269 if value >= baseline && value <= max_value {
270 return (code as u8, value - baseline, bits);
271 }
272 }
273
274 let last_idx = ML_ENCODE_TABLE.len() - 1;
276 let (bits, baseline, _) = ML_ENCODE_TABLE[last_idx];
277 (last_idx as u8, value.saturating_sub(baseline), bits)
278}
279
/// Maps an offset value to `(code, extra_value, extra_bit_count)`.
///
/// The code is the position of the highest set bit, the extra value is the
/// remainder below that bit, and the number of extra bits equals the code.
/// An input of 0 has no set bit and is reported as `(0, 0, 0)`.
fn encode_offset(offset_value: u32) -> (u8, u32, u8) {
    match offset_value {
        0 => (0, 0, 0),
        raw => {
            let top_bit = 31 - raw.leading_zeros();
            let code = top_bit as u8;
            // Clearing the top bit is the same as subtracting `1 << top_bit`.
            (code, raw & !(1u32 << top_bit), code)
        }
    }
}
306
307pub fn analyze_for_rle(sequences: &[Sequence]) -> RleSuitability {
314 if sequences.is_empty() {
315 return RleSuitability::all_rle(0, 0, 0);
316 }
317
318 let mut encoded = Vec::with_capacity(sequences.len());
320
321 let first = EncodedSequence::from_sequence(&sequences[0]);
323 let (ll_code, of_code, ml_code) = (first.ll_code, first.of_code, first.ml_code);
324 encoded.push(first);
325
326 let mut ll_uniform = true;
328 let mut of_uniform = true;
329 let mut ml_uniform = true;
330
331 for seq in sequences.iter().skip(1) {
332 let enc = EncodedSequence::from_sequence(seq);
333
334 ll_uniform = ll_uniform && enc.ll_code == ll_code;
336 of_uniform = of_uniform && enc.of_code == of_code;
337 ml_uniform = ml_uniform && enc.ml_code == ml_code;
338
339 encoded.push(enc);
340 }
341
342 RleSuitability {
343 ll_uniform,
344 ll_code,
345 of_uniform,
346 of_code,
347 ml_uniform,
348 ml_code,
349 encoded,
350 }
351}
352
353#[derive(Debug)]
355pub struct RleSuitability {
356 pub ll_uniform: bool,
358 pub ll_code: u8,
360 pub of_uniform: bool,
362 pub of_code: u8,
364 pub ml_uniform: bool,
366 pub ml_code: u8,
368 pub encoded: Vec<EncodedSequence>,
370}
371
372impl RleSuitability {
373 fn all_rle(ll: u8, of: u8, ml: u8) -> Self {
374 Self {
375 ll_uniform: true,
376 ll_code: ll,
377 of_uniform: true,
378 of_code: of,
379 ml_uniform: true,
380 ml_code: ml,
381 encoded: Vec::new(),
382 }
383 }
384
385 pub fn all_uniform(&self) -> bool {
387 self.ll_uniform && self.of_uniform && self.ml_uniform
388 }
389}
390
391pub fn encode_sequences_rle(
396 sequences: &[Sequence],
397 suitability: &RleSuitability,
398 output: &mut Vec<u8>,
399) -> Result<()> {
400 if sequences.is_empty() {
401 output.push(0);
402 return Ok(());
403 }
404
405 let count = sequences.len();
406
407 if count < 128 {
409 output.push(count as u8);
410 } else if count < 0x7F00 {
411 output.push(((count >> 8) + 128) as u8);
412 output.push((count & 0xFF) as u8);
413 } else {
414 output.push(255);
415 let adjusted = count - 0x7F00;
416 output.push((adjusted & 0xFF) as u8);
417 output.push(((adjusted >> 8) & 0xFF) as u8);
418 }
419
420 output.push(0x15);
429
430 output.push(suitability.ll_code);
432 output.push(suitability.of_code);
433 output.push(suitability.ml_code);
434
435 let bitstream = build_rle_bitstream(&suitability.encoded);
437 output.extend_from_slice(&bitstream);
438
439 Ok(())
440}
441
442pub fn encode_sequences_fse(sequences: &[Sequence], output: &mut Vec<u8>) -> Result<()> {
455 if sequences.is_empty() {
456 output.push(0);
457 return Ok(());
458 }
459
460 let encoded: Vec<EncodedSequence> = sequences
462 .iter()
463 .map(EncodedSequence::from_sequence)
464 .collect();
465
466 encode_sequences_fse_with_encoded(&encoded, output)
467}
468
469pub fn encode_sequences_fse_with_encoded(
479 encoded: &[EncodedSequence],
480 output: &mut Vec<u8>,
481) -> Result<()> {
482 if encoded.is_empty() {
483 output.push(0);
484 return Ok(());
485 }
486
487 let count = encoded.len();
488
489 if count < 128 {
491 output.push(count as u8);
492 } else if count < 0x7F00 {
493 output.push(((count >> 8) + 128) as u8);
494 output.push((count & 0xFF) as u8);
495 } else {
496 output.push(255);
497 let adjusted = count - 0x7F00;
498 output.push((adjusted & 0xFF) as u8);
499 output.push(((adjusted >> 8) & 0xFF) as u8);
500 }
501
502 output.push(0x00);
505
506 let mut tans = InterleavedTansEncoder::new_predefined();
509
510 let bitstream = build_fse_bitstream(encoded, &mut tans);
512 output.extend_from_slice(&bitstream);
513
514 Ok(())
515}
516
517pub fn encode_sequences_with_custom_tables(
537 encoded: &[EncodedSequence],
538 custom_tables: &CustomFseTables,
539 output: &mut Vec<u8>,
540) -> Result<()> {
541 if encoded.is_empty() {
542 output.push(0);
543 return Ok(());
544 }
545
546 let count = encoded.len();
547
548 if count < 128 {
550 output.push(count as u8);
551 } else if count < 0x7F00 {
552 output.push(((count >> 8) + 128) as u8);
553 output.push((count & 0xFF) as u8);
554 } else {
555 output.push(255);
556 let adjusted = count - 0x7F00;
557 output.push((adjusted & 0xFF) as u8);
558 output.push(((adjusted >> 8) & 0xFF) as u8);
559 }
560
561 let mode_byte = 0x00; output.push(mode_byte);
572
573 let ll_table = custom_tables
575 .ll_table
576 .as_ref()
577 .map(|t| t.as_ref())
578 .unwrap_or_else(|| cached_ll_table());
579 let of_table = custom_tables
580 .of_table
581 .as_ref()
582 .map(|t| t.as_ref())
583 .unwrap_or_else(|| cached_of_table());
584 let ml_table = custom_tables
585 .ml_table
586 .as_ref()
587 .map(|t| t.as_ref())
588 .unwrap_or_else(|| cached_ml_table());
589
590 let ll_encoder = TansEncoder::from_decode_table(ll_table);
591 let of_encoder = TansEncoder::from_decode_table(of_table);
592 let ml_encoder = TansEncoder::from_decode_table(ml_table);
593
594 let mut tans = InterleavedTansEncoder::from_encoders(ll_encoder, of_encoder, ml_encoder);
595
596 let bitstream = build_fse_bitstream(encoded, &mut tans);
598 output.extend_from_slice(&bitstream);
599
600 Ok(())
601}
602
603#[allow(unused_variables)]
627fn build_fse_bitstream(encoded: &[EncodedSequence], tans: &mut InterleavedTansEncoder) -> Vec<u8> {
628 #[cfg(test)]
629 let debug = std::env::var("DEBUG_FSE").is_ok();
630 if encoded.is_empty() {
631 return vec![0x01]; }
633
634 let mut bits = FseBitWriter::new();
635
636 let (ll_log, of_log, ml_log) = tans.accuracy_logs();
638
639 let last_idx = encoded.len() - 1;
650 let last_seq = &encoded[last_idx];
651
652 tans.init_states(last_seq.ll_code, last_seq.of_code, last_seq.ml_code);
654
655 #[cfg(test)]
656 if std::env::var("DEBUG_FSE_DETAIL").is_ok() {
657 let (ll_s, of_s, ml_s) = tans.get_states();
658 eprintln!(
659 "FSE init with codes ({}, {}, {}): states = ({}, {}, {})",
660 last_seq.ll_code, last_seq.of_code, last_seq.ml_code, ll_s, of_s, ml_s
661 );
662 }
663
664 let mut fse_bits_per_seq: Vec<[(u32, u8); 3]> = vec![[(0, 0); 3]; last_idx];
667
668 for i in (0..last_idx).rev() {
669 let seq = &encoded[i];
670 let fse_bits = tans.encode_sequence(seq.ll_code, seq.of_code, seq.ml_code);
671
672 #[cfg(test)]
673 if std::env::var("DEBUG_FSE_DETAIL").is_ok() {
674 let (ll_s, of_s, ml_s) = tans.get_states();
675 eprintln!("FSE encode seq[{}] codes ({}, {}, {}): bits=LL({},{}) ML({},{}) OF({},{}), new states=({}, {}, {})",
676 i, seq.ll_code, seq.of_code, seq.ml_code,
677 fse_bits[0].0, fse_bits[0].1,
678 fse_bits[2].0, fse_bits[2].1,
679 fse_bits[1].0, fse_bits[1].1,
680 ll_s, of_s, ml_s);
681 }
682
683 fse_bits_per_seq[i] = fse_bits;
684 }
685
686 for i in 0..last_idx {
691 let seq = &encoded[i];
692
693 if seq.ll_bits > 0 {
695 bits.write_bits(seq.ll_extra, seq.ll_bits);
696 }
697 if seq.ml_bits > 0 {
698 bits.write_bits(seq.ml_extra, seq.ml_bits);
699 }
700 if seq.of_bits > 0 {
701 bits.write_bits(seq.of_extra, seq.of_bits);
702 }
703
704 let [ll_fse, of_fse, ml_fse] = fse_bits_per_seq[i];
706 bits.write_bits(ll_fse.0, ll_fse.1);
707 bits.write_bits(ml_fse.0, ml_fse.1);
708 bits.write_bits(of_fse.0, of_fse.1);
709 }
710
711 if last_seq.ll_bits > 0 {
713 bits.write_bits(last_seq.ll_extra, last_seq.ll_bits);
714 }
715 if last_seq.ml_bits > 0 {
716 bits.write_bits(last_seq.ml_extra, last_seq.ml_bits);
717 }
718 if last_seq.of_bits > 0 {
719 bits.write_bits(last_seq.of_extra, last_seq.of_bits);
720 }
721
722 let (ll_state, of_state, ml_state) = tans.get_states();
724
725 #[cfg(test)]
726 if std::env::var("DEBUG_FSE").is_ok() {
727 eprintln!("FSE encode: {} sequences", encoded.len());
728 eprintln!(
729 " Last seq: ll_code={}, of_code={}, ml_code={}",
730 last_seq.ll_code, last_seq.of_code, last_seq.ml_code
731 );
732 eprintln!(
733 " Last seq extras: ll={}({} bits), ml={}({} bits), of={}({} bits)",
734 last_seq.ll_extra,
735 last_seq.ll_bits,
736 last_seq.ml_extra,
737 last_seq.ml_bits,
738 last_seq.of_extra,
739 last_seq.of_bits
740 );
741 eprintln!(
742 " Final states: ll={}, of={}, ml={}",
743 ll_state, of_state, ml_state
744 );
745 }
746
747 bits.write_bits(ml_state, ml_log);
750 bits.write_bits(of_state, of_log);
751 bits.write_bits(ll_state, ll_log);
752
753 bits.finish()
754}
755
756fn build_rle_bitstream(encoded: &[EncodedSequence]) -> Vec<u8> {
762 if encoded.is_empty() {
763 return vec![0x01]; }
765
766 let mut bits = FseBitWriter::new();
767
768 for seq in encoded.iter().rev() {
771 if seq.ll_bits > 0 {
772 bits.write_bits(seq.ll_extra, seq.ll_bits);
773 }
774 if seq.ml_bits > 0 {
775 bits.write_bits(seq.ml_extra, seq.ml_bits);
776 }
777 if seq.of_bits > 0 {
778 bits.write_bits(seq.of_extra, seq.of_bits);
779 }
780 }
781
782 bits.finish()
783}
784
#[cfg(test)]
mod tests {
    use super::*;

    /// Lengths 0-15 are their own code and need no extra bits.
    #[test]
    fn test_encode_literal_length_small() {
        for length in 0..16u32 {
            assert_eq!(encode_literal_length(length), (length as u8, 0, 0));
        }
    }

    /// Lengths 16 and 17 share code 16 and differ in one extra bit.
    #[test]
    fn test_encode_literal_length_with_extra_bits() {
        assert_eq!(encode_literal_length(16), (16, 0, 1));
        assert_eq!(encode_literal_length(17), (16, 1, 1));
    }

    /// The shortest match lengths map to the first codes without extra bits.
    #[test]
    fn test_encode_match_length() {
        assert_eq!(encode_match_length(3), (0, 0, 0));
        assert_eq!(encode_match_length(4), (1, 0, 0));
    }

    /// Offset codes are the index of the highest set bit; the extra value is
    /// whatever remains below it.
    #[test]
    fn test_encode_offset() {
        // (input, (code, extra, bits))
        let cases: [(u32, (u8, u32, u8)); 8] = [
            (1, (0, 0, 0)),
            (2, (1, 0, 1)),
            (3, (1, 1, 1)),
            (4, (2, 0, 2)),
            (7, (2, 3, 2)),
            (8, (3, 0, 3)),
            (19, (4, 3, 4)),
            (100, (6, 36, 6)),
        ];

        for &(input, expected) in cases.iter() {
            assert_eq!(encode_offset(input), expected, "offset value {}", input);
        }
    }

    /// Identical sequences are uniform in all three codes.
    #[test]
    fn test_analyze_for_rle_uniform() {
        let sequences: Vec<Sequence> = (0..3).map(|_| Sequence::new(0, 4, 3)).collect();

        assert!(analyze_for_rle(&sequences).all_uniform());
    }

    /// Sequences with different codes are not uniform.
    #[test]
    fn test_analyze_for_rle_non_uniform() {
        let sequences = [Sequence::new(0, 4, 3), Sequence::new(10, 100, 20)];

        assert!(!analyze_for_rle(&sequences).all_uniform());
    }

    /// No sequences: the section is just a zero count.
    #[test]
    fn test_encode_sequences_rle_empty() {
        let sequences: Vec<Sequence> = Vec::new();
        let suitability = analyze_for_rle(&sequences);
        let mut output = Vec::new();

        encode_sequences_rle(&sequences, &suitability, &mut output).unwrap();

        assert_eq!(output, vec![0]);
    }

    /// One sequence: count byte, mode byte, three codes and a bitstream.
    #[test]
    fn test_encode_sequences_rle_single() {
        let sequences = vec![Sequence::new(0, 4, 3)];
        let suitability = analyze_for_rle(&sequences);
        let mut output = Vec::new();

        encode_sequences_rle(&sequences, &suitability, &mut output).unwrap();

        assert!(output.len() >= 5);
        assert_eq!(output[0], 1);
        assert_eq!(output[1], 0x15);
    }

    /// Literal length 5 is code 5; match length 10 is code 10 - 3 = 7.
    #[test]
    fn test_encoded_sequence_creation() {
        let encoded = EncodedSequence::from_sequence(&Sequence::new(5, 8, 10));

        assert_eq!(encoded.ll_code, 5);
        assert_eq!(encoded.ml_code, 7);
    }
}