1use alloc::vec::Vec;
35use core::{
36 borrow::Borrow,
37 fmt::{Debug, Display},
38 marker::PhantomData,
39 num::NonZeroUsize,
40 ops::Deref,
41};
42
43use num_traits::AsPrimitive;
44
45use super::{
46 model::{DecoderModel, EncoderModel},
47 Code, Decode, Encode, IntoDecoder,
48};
49use crate::{
50 backends::{AsReadWords, BoundedReadWords, Cursor, IntoReadWords, ReadWords, WriteWords},
51 generic_static_asserts, BitArray, CoderError, DefaultEncoderError, DefaultEncoderFrontendError,
52 NonZeroBitArray, Pos, PosSeek, Queue, Seek, UnwrapInfallible,
53};
54
/// State shared by [`RangeEncoder`] and [`RangeDecoder`]: a `lower` value together with a
/// nonzero `range`.
///
/// The coders in this file update `lower` with wrapping arithmetic. The `Word` type
/// parameter is not stored; it only enters through [`RangeCoderState::new`], which rejects
/// a `range` whose `Word::BITS` most significant bits are all zero.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RangeCoderState<Word, State: BitArray> {
    /// Lower end of the current interval.
    lower: State,

    /// Size of the current interval; nonzero by type.
    range: State::NonZero,

    /// Marker that ties the otherwise unused `Word` parameter to this type.
    phantom: PhantomData<Word>,
}
70
71impl<Word: BitArray, State: BitArray> RangeCoderState<Word, State> {
72 #[allow(clippy::result_unit_err)]
73 pub fn new(lower: State, range: State) -> Result<Self, ()> {
74 if range >> (State::BITS - Word::BITS) == State::zero() {
75 Err(())
76 } else {
77 Ok(Self {
78 lower,
79 range: range.into_nonzero().expect("We checked above."),
80 phantom: PhantomData,
81 })
82 }
83 }
84
85 pub fn lower(&self) -> State {
87 self.lower
88 }
89
90 pub fn range(&self) -> State::NonZero {
92 self.range
93 }
94}
95
96impl<Word: BitArray, State: BitArray> Default for RangeCoderState<Word, State> {
97 fn default() -> Self {
98 Self {
99 lower: State::zero(),
100 range: State::max_value().into_nonzero().expect("max_value() != 0"),
101 phantom: PhantomData,
102 }
103 }
104}
105
/// A range encoder that emits compressed `Word`s into a `Backend`.
///
/// `State` is the integer type used for the coder's internal arithmetic; the constructors
/// in this file statically assert that it is at least two `Word`s wide and a whole number
/// of `Word`s. The backend defaults to `Vec<Word>`.
#[derive(Debug, Clone)]
pub struct RangeEncoder<Word, State, Backend = Vec<Word>>
where
    Word: BitArray,
    State: BitArray,
    Backend: WriteWords<Word>,
{
    /// Destination of all words that have been emitted so far.
    bulk: Backend,
    /// Current `lower`/`range` pair of the coder.
    state: RangeCoderState<Word, State>,
    /// Tracks whether words are currently being held back (see [`EncoderSituation`]).
    situation: EncoderSituation<Word>,
}
117
/// Bookkeeping of a [`RangeEncoder`] about words it has not yet written to its backend.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum EncoderSituation<Word> {
    /// No words are held back (this is the default).
    #[default]
    Normal,

    /// Words are being held back because `lower + range` wraps around in `State`.
    ///
    /// The first field counts the held-back words, the second field is the first of them.
    /// When the situation is resolved, the encoder writes the first word (incremented by
    /// one if a carry occurred) followed by `count - 1` filler words that are either all
    /// zero or all `Word::max_value()`.
    Inverted(NonZeroUsize, Word),
}
141
/// A [`RangeEncoder`] with 32-bit words and a 64-bit state; the backend defaults to `Vec<u32>`.
pub type DefaultRangeEncoder<Backend = Vec<u32>> = RangeEncoder<u32, u64, Backend>;
144
/// A [`RangeEncoder`] with 16-bit words and a 32-bit state; the backend defaults to `Vec<u16>`.
pub type SmallRangeEncoder<Backend = Vec<u16>> = RangeEncoder<u16, u32, Backend>;
155
/// Exposes the encoder's word type and its `lower`/`range` state through the `Code` trait.
impl<Word, State, Backend> Code for RangeEncoder<Word, State, Backend>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
    Backend: WriteWords<Word>,
{
    type State = RangeCoderState<Word, State>;
    type Word = Word;

    /// Returns a copy of the current coder state.
    fn state(&self) -> Self::State {
        self.state
    }
}
169
/// A position of the encoder is the pair of a backend position and a coder state.
impl<Word, State, Backend> PosSeek for RangeEncoder<Word, State, Backend>
where
    Word: BitArray,
    State: BitArray,
    Backend: WriteWords<Word> + PosSeek,
    Self: Code,
{
    type Position = (Backend::Position, <Self as Code>::State);
}
179
180impl<Word, State, Backend> Pos for RangeEncoder<Word, State, Backend>
181where
182 Word: BitArray + Into<State>,
183 State: BitArray + AsPrimitive<Word>,
184 Backend: WriteWords<Word> + Pos<Position = usize>,
185{
186 fn pos(&self) -> Self::Position {
187 let num_inverted = if let EncoderSituation::Inverted(num_inverted, _) = self.situation {
188 num_inverted.get()
189 } else {
190 0
191 };
192 (self.bulk.pos() + num_inverted, self.state())
193 }
194}
195
impl<Word, State, Backend> Default for RangeEncoder<Word, State, Backend>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
    Backend: WriteWords<Word> + Default,
{
    /// Creates an encoder on top of `Backend::default()` via [`RangeEncoder::with_backend`].
    fn default() -> Self {
        Self::with_backend(Backend::default())
    }
}
208
209impl<Word, State> RangeEncoder<Word, State>
210where
211 Word: BitArray + Into<State>,
212 State: BitArray + AsPrimitive<Word>,
213{
214 pub fn new() -> Self {
216 generic_static_asserts!(
217 (Word: BitArray, State:BitArray);
218 STATE_SUPPORTS_AT_LEAST_TWO_WORDS: State::BITS >= 2 * Word::BITS;
219 STATE_SIZE_IS_MULTIPLE_OF_WORD_SIZE: State::BITS % Word::BITS == 0;
220 );
221
222 Self {
223 bulk: Vec::new(),
224 state: RangeCoderState::default(),
225 situation: EncoderSituation::Normal,
226 }
227 }
228}
229
230impl<Word, State> From<RangeEncoder<Word, State>> for Vec<Word>
231where
232 Word: BitArray + Into<State>,
233 State: BitArray + AsPrimitive<Word>,
234{
235 fn from(val: RangeEncoder<Word, State>) -> Self {
236 val.into_compressed().unwrap_infallible()
237 }
238}
239
/// Methods that are available for every backend that can accept words.
impl<Word, State, Backend> RangeEncoder<Word, State, Backend>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
    Backend: WriteWords<Word>,
{
    /// Creates an encoder that writes to `backend`, starting from the default coder state
    /// and the `Normal` situation.
    ///
    /// Statically asserts that `State` is at least two `Word`s wide and that its size is a
    /// multiple of the size of `Word`.
    pub fn with_backend(backend: Backend) -> Self {
        generic_static_asserts!(
            (Word: BitArray, State:BitArray);
            STATE_SUPPORTS_AT_LEAST_TWO_WORDS: State::BITS >= 2 * Word::BITS;
            STATE_SIZE_IS_MULTIPLE_OF_WORD_SIZE: State::BITS % Word::BITS == 0;
        );

        Self {
            bulk: backend,
            state: RangeCoderState::default(),
            situation: EncoderSituation::Normal,
        }
    }

    /// Returns `true` if `range` still has its initial value `State::max_value()` and the
    /// backend, viewed as a bounded reader, has no words left.
    pub fn is_empty<'a>(&'a self) -> bool
    where
        Backend: AsReadWords<'a, Word, Queue>,
        Backend::AsReadWords: BoundedReadWords<Word, Queue>,
    {
        self.state.range.get() == State::max_value() && self.bulk.as_read_words().is_exhausted()
    }

    /// Forwards to the backend's `maybe_full`.
    pub fn maybe_full(&self) -> bool {
        self.bulk.maybe_full()
    }

    /// Seals the encoder and turns it into a decoder that reads the compressed words.
    ///
    /// # Errors
    ///
    /// Both a failure to write the sealing words and a failure to read the decoder's
    /// initial words are reported as `Err(())`; the underlying error is discarded.
    #[allow(clippy::result_unit_err)]
    pub fn into_decoder(self) -> Result<RangeDecoder<Word, State, Backend::IntoReadWords>, ()>
    where
        Backend: IntoReadWords<Word, Queue>,
    {
        RangeDecoder::from_compressed(self.into_compressed().map_err(|_| ())?).map_err(|_| ())
    }

    /// Seals the encoder and returns the backend holding the compressed words.
    ///
    /// # Errors
    ///
    /// Returns the backend's write error if writing the sealing words fails.
    pub fn into_compressed(mut self) -> Result<Backend, Backend::WriteError> {
        self.seal()?;
        Ok(self.bulk)
    }

    /// Writes the words that terminate the compressed stream.
    ///
    /// Must stay in sync with `num_seal_words`, which reports how many words this method
    /// writes. Neither `state` nor `situation` is modified.
    fn seal(&mut self) -> Result<(), Backend::WriteError> {
        if self.state.range.get() == State::max_value() {
            // `range` still has its initial value: emit nothing.
            // NOTE(review): presumably this only holds before any symbol was encoded — confirm.
            return Ok(());
        }

        // `lower` plus the largest value that fits below the most significant word.
        let point = self
            .state
            .lower
            .wrapping_add(&((State::one() << (State::BITS - Word::BITS)) - State::one()));

        if let EncoderSituation::Inverted(num_inverted, first_inverted_lower_word) = self.situation
        {
            // Flush the held-back words. `point < lower` means the addition above wrapped,
            // so a carry goes into the first held-back word and the fillers become zero.
            let (first_word, consecutive_words) = if point < self.state.lower {
                (first_inverted_lower_word + Word::one(), Word::zero())
            } else {
                (first_inverted_lower_word, Word::max_value())
            };

            self.bulk.write(first_word)?;
            for _ in 1..num_inverted.get() {
                self.bulk.write(consecutive_words)?;
            }
        }

        // The most significant word of `point` is always written.
        let point_word = (point >> (State::BITS - Word::BITS)).as_();
        self.bulk.write(point_word)?;

        // If the upper end of the interval has the same most significant word, one extra
        // zero word is appended.
        let upper_word = (self.state.lower.wrapping_add(&self.state.range.get())
            >> (State::BITS - Word::BITS))
            .as_();
        if upper_word == point_word {
            self.bulk.write(Word::zero())?;
        }

        Ok(())
    }

    /// Returns the number of words that `seal` would write in the current state.
    fn num_seal_words(&self) -> usize {
        if self.state.range.get() == State::max_value() {
            return 0;
        }

        // Same `point`/`upper_word` computation as in `seal`.
        let point = self
            .state
            .lower
            .wrapping_add(&((State::one() << (State::BITS - Word::BITS)) - State::one()));
        let point_word = (point >> (State::BITS - Word::BITS)).as_();
        let upper_word = (self.state.lower.wrapping_add(&self.state.range.get())
            >> (State::BITS - Word::BITS))
            .as_();
        let mut count = if upper_word == point_word { 2 } else { 1 };

        // `seal` additionally flushes every held-back word.
        if let EncoderSituation::Inverted(num_inverted, _) = self.situation {
            count += num_inverted.get();
        }
        count
    }

    /// Returns the number of words the compressed data would have if sealed now: the words
    /// remaining in the backend's read view plus the words `seal` would add.
    pub fn num_words<'a>(&'a self) -> usize
    where
        Backend: AsReadWords<'a, Word, Queue>,
        Backend::AsReadWords: BoundedReadWords<Word, Queue>,
    {
        self.bulk.as_read_words().remaining() + self.num_seal_words()
    }

    /// Returns `num_words()` multiplied by the number of bits per word.
    pub fn num_bits<'a>(&'a self) -> usize
    where
        Backend: AsReadWords<'a, Word, Queue>,
        Backend::AsReadWords: BoundedReadWords<Word, Queue>,
    {
        Word::BITS * self.num_words()
    }

    /// Returns a reference to the backend (without any sealing words).
    pub fn bulk(&self) -> &Backend {
        &self.bulk
    }

    /// Assembles an encoder from its three constituents without validating them.
    ///
    /// Performs the same static size checks on `State` and `Word` as `with_backend`.
    pub fn from_raw_parts(
        bulk: Backend,
        state: RangeCoderState<Word, State>,
        situation: EncoderSituation<Word>,
    ) -> Self {
        generic_static_asserts!(
            (Word: BitArray, State:BitArray);
            STATE_SUPPORTS_AT_LEAST_TWO_WORDS: State::BITS >= 2 * Word::BITS;
            STATE_SIZE_IS_MULTIPLE_OF_WORD_SIZE: State::BITS % Word::BITS == 0;
        );

        Self {
            bulk,
            state,
            situation,
        }
    }

    /// Splits the encoder into backend, coder state, and situation without sealing it.
    pub fn into_raw_parts(
        self,
    ) -> (
        Backend,
        RangeCoderState<Word, State>,
        EncoderSituation<Word>,
    ) {
        (self.bulk, self.state, self.situation)
    }
}
459
460impl<Word, State> RangeEncoder<Word, State>
461where
462 Word: BitArray + Into<State>,
463 State: BitArray + AsPrimitive<Word>,
464{
465 pub fn clear(&mut self) {
468 self.bulk.clear();
469 self.state = RangeCoderState::default();
470 }
471
472 pub fn get_compressed(&mut self) -> EncoderGuard<'_, Word, State> {
483 EncoderGuard::new(self)
484 }
485
486 pub fn decoder(
501 &mut self,
502 ) -> RangeDecoder<Word, State, Cursor<Word, EncoderGuard<'_, Word, State>>> {
503 RangeDecoder::from_compressed(self.get_compressed()).unwrap_infallible()
504 }
505
506 fn unseal(&mut self) {
507 for _ in 0..self.num_seal_words() {
508 let word = self.bulk.pop();
509 debug_assert!(word.is_some());
510 }
511 }
512}
513
impl<Word, State, Backend, const PRECISION: usize> IntoDecoder<PRECISION>
    for RangeEncoder<Word, State, Backend>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
    Backend: WriteWords<Word> + IntoReadWords<Word, Queue>,
{
    type IntoDecoder = RangeDecoder<Word, State, Backend::IntoReadWords>;

    /// Converts the encoder into a decoder through the `From<RangeEncoder>` implementation
    /// of `RangeDecoder`, which panics if the conversion fails.
    fn into_decoder(self) -> Self::IntoDecoder {
        self.into()
    }
}
527
impl<Word, State, Backend, const PRECISION: usize> Encode<PRECISION>
    for RangeEncoder<Word, State, Backend>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
    Backend: WriteWords<Word>,
{
    type FrontendError = DefaultEncoderFrontendError;
    type BackendError = Backend::WriteError;

    /// Encodes a single `symbol` under the entropy model `model`.
    ///
    /// # Errors
    ///
    /// - `ImpossibleSymbol` (frontend error) if the model has no entry for `symbol` or if
    ///   the resulting range would be zero.
    /// - The backend's write error if a word cannot be written. Note that `state.range`
    ///   has already been updated at that point.
    fn encode_symbol<D>(
        &mut self,
        symbol: impl Borrow<D::Symbol>,
        model: D,
    ) -> Result<(), DefaultEncoderError<Self::BackendError>>
    where
        D: EncoderModel<PRECISION>,
        D::Probability: Into<Self::Word>,
        Self::Word: AsPrimitive<D::Probability>,
    {
        generic_static_asserts!(
            (Word: BitArray, State:BitArray; const PRECISION: usize);
            PROBABILITY_SUPPORTS_PRECISION: State::BITS >= Word::BITS + PRECISION;
            NON_ZERO_PRECISION: PRECISION > 0;
            STATE_SUPPORTS_AT_LEAST_TWO_WORDS: State::BITS >= 2 * Word::BITS;
            STATE_SIZE_IS_MULTIPLE_OF_WORD_SIZE: State::BITS % Word::BITS == 0;
        );

        let (left_sided_cumulative, probability) = model
            .left_cumulative_and_probability(symbol)
            .ok_or_else(|| DefaultEncoderFrontendError::ImpossibleSymbol.into_coder_error())?;

        // Narrow the interval to the sub-interval that the model assigns to `symbol`.
        let scale = self.state.range.get() >> PRECISION;
        self.state.range = (scale * probability.get().into().into())
            .into_nonzero()
            .ok_or_else(|| DefaultEncoderFrontendError::ImpossibleSymbol.into_coder_error())?;
        let new_lower = self
            .state
            .lower
            .wrapping_add(&(scale * left_sided_cumulative.into().into()));

        if let EncoderSituation::Inverted(num_inverted, first_inverted_lower_word) = self.situation
        {
            // `new_lower + range` no longer wraps: the held-back words can be resolved.
            if new_lower.wrapping_add(&self.state.range.get()) > new_lower {
                // `new_lower < lower` means `lower` itself wrapped, i.e., a carry has to be
                // added to the first held-back word and the fillers become zero.
                let (first_word, consecutive_words) = if new_lower < self.state.lower {
                    (first_inverted_lower_word + Word::one(), Word::zero())
                } else {
                    (first_inverted_lower_word, Word::max_value())
                };

                self.bulk.write(first_word)?;
                for _ in 1..num_inverted.get() {
                    self.bulk.write(consecutive_words)?;
                }

                self.situation = EncoderSituation::Normal;
            }
        }

        self.state.lower = new_lower;

        // Rescale once `range` no longer reaches into the most significant word.
        if self.state.range.get() < State::one() << (State::BITS - Word::BITS) {
            self.state.range = unsafe {
                // SAFETY: `range` is nonzero and, by the condition above, all of its set
                // bits lie below bit `State::BITS - Word::BITS`; shifting left by
                // `Word::BITS` therefore shifts out no set bit and the result is nonzero.
                (self.state.range.get() << Word::BITS).into_nonzero_unchecked()
            };

            // The most significant word of `lower` leaves the state.
            let lower_word = (self.state.lower >> (State::BITS - Word::BITS)).as_();
            self.state.lower = self.state.lower << Word::BITS;

            if let EncoderSituation::Inverted(num_inverted, _) = &mut self.situation {
                // Still inverted: hold back one more word (only the count is stored).
                *num_inverted = NonZeroUsize::new(num_inverted.get().wrapping_add(1))
                    .expect("Cannot encode more symbols than what's addressable with usize.");
            } else if self.state.lower.wrapping_add(&self.state.range.get()) > self.state.lower {
                // Normal situation and the interval does not wrap: emit the word right away.
                self.bulk.write(lower_word)?;
            } else {
                // The interval wraps: start holding words back, beginning with this one.
                self.situation =
                    EncoderSituation::Inverted(NonZeroUsize::new(1).expect("1 != 0"), lower_word);
            }
        }

        Ok(())
    }

    /// Forwards to the inherent `RangeEncoder::maybe_full`.
    fn maybe_full(&self) -> bool {
        RangeEncoder::maybe_full(self)
    }
}
637
/// A range decoder that reads compressed `Word`s from a `Backend`.
#[derive(Debug, Clone)]
pub struct RangeDecoder<Word, State, Backend>
where
    Word: BitArray,
    State: BitArray,
    Backend: ReadWords<Word, Queue>,
{
    /// Source of the compressed words that have not been consumed yet.
    bulk: Backend,

    /// Current `lower`/`range` pair of the coder.
    state: RangeCoderState<Word, State>,

    /// Window over the compressed data: initialized from the first
    /// `State::BITS / Word::BITS` words (first word most significant) and extended by one
    /// word at the low end whenever the state is rescaled.
    point: State,
}
652
/// A [`RangeDecoder`] with 32-bit words and a 64-bit state; the backend defaults to
/// `Cursor<u32, Vec<u32>>`.
pub type DefaultRangeDecoder<Backend = Cursor<u32, Vec<u32>>> = RangeDecoder<u32, u64, Backend>;
655
/// A [`RangeDecoder`] with 16-bit words and a 32-bit state (no default backend).
pub type SmallRangeDecoder<Backend> = RangeDecoder<u16, u32, Backend>;
674
675impl<Word, State, Backend> RangeDecoder<Word, State, Backend>
676where
677 Word: BitArray + Into<State>,
678 State: BitArray + AsPrimitive<Word>,
679 Backend: ReadWords<Word, Queue>,
680{
681 pub fn from_compressed<Buf>(compressed: Buf) -> Result<Self, Backend::ReadError>
682 where
683 Buf: IntoReadWords<Word, Queue, IntoReadWords = Backend>,
684 {
685 generic_static_asserts!(
686 (Word: BitArray, State:BitArray);
687 STATE_SUPPORTS_AT_LEAST_TWO_WORDS: State::BITS >= 2 * Word::BITS;
688 STATE_SIZE_IS_MULTIPLE_OF_WORD_SIZE: State::BITS % Word::BITS == 0;
689 );
690
691 let mut bulk = compressed.into_read_words();
692 let point = Self::read_point(&mut bulk)?;
693
694 Ok(RangeDecoder {
695 bulk,
696 state: RangeCoderState::default(),
697 point,
698 })
699 }
700
701 pub fn with_backend(backend: Backend) -> Result<Self, Backend::ReadError> {
702 generic_static_asserts!(
703 (Word: BitArray, State:BitArray);
704 STATE_SUPPORTS_AT_LEAST_TWO_WORDS: State::BITS >= 2 * Word::BITS;
705 STATE_SIZE_IS_MULTIPLE_OF_WORD_SIZE: State::BITS % Word::BITS == 0;
706 );
707
708 let mut bulk = backend;
709 let point = Self::read_point(&mut bulk)?;
710
711 Ok(RangeDecoder {
712 bulk,
713 state: RangeCoderState::default(),
714 point,
715 })
716 }
717
718 pub fn for_compressed<'a, Buf>(compressed: &'a Buf) -> Result<Self, Backend::ReadError>
719 where
720 Buf: AsReadWords<'a, Word, Queue, AsReadWords = Backend>,
721 {
722 generic_static_asserts!(
723 (Word: BitArray, State:BitArray);
724 STATE_SUPPORTS_AT_LEAST_TWO_WORDS: State::BITS >= 2 * Word::BITS;
725 STATE_SIZE_IS_MULTIPLE_OF_WORD_SIZE: State::BITS % Word::BITS == 0;
726 );
727
728 let mut bulk = compressed.as_read_words();
729 let point = Self::read_point(&mut bulk)?;
730
731 Ok(RangeDecoder {
732 bulk,
733 state: RangeCoderState::default(),
734 point,
735 })
736 }
737
738 pub fn from_raw_parts(
747 bulk: Backend,
748 state: RangeCoderState<Word, State>,
749 point: State,
750 ) -> Result<Self, Backend> {
751 generic_static_asserts!(
752 (Word: BitArray, State:BitArray);
753 STATE_SUPPORTS_AT_LEAST_TWO_WORDS: State::BITS >= 2 * Word::BITS;
754 STATE_SIZE_IS_MULTIPLE_OF_WORD_SIZE: State::BITS % Word::BITS == 0;
755 );
756
757 if point.wrapping_sub(&state.lower) >= state.range.get() {
760 Err(bulk)
761 } else {
762 Ok(Self { bulk, state, point })
763 }
764 }
765
766 pub fn into_raw_parts(self) -> (Backend, RangeCoderState<Word, State>, State) {
770 (self.bulk, self.state, self.point)
771 }
772
773 fn read_point<B: ReadWords<Word, Queue>>(bulk: &mut B) -> Result<State, B::ReadError> {
774 let mut num_read = 0;
775 let mut point = State::zero();
776 while let Some(word) = bulk.read()? {
777 point = (point << Word::BITS) | word.into();
778 num_read += 1;
779 if num_read == State::BITS / Word::BITS {
780 break;
781 }
782 }
783
784 #[allow(clippy::collapsible_if)]
785 if num_read < State::BITS / Word::BITS {
786 if num_read != 0 {
787 point = point << (State::BITS - num_read * Word::BITS);
788 }
789 }
792
793 Ok(point)
794 }
795
796 pub fn maybe_exhausted(&self) -> bool {
799 let max_difference =
802 ((State::one() << (State::BITS - Word::BITS)) << 1).wrapping_sub(&State::one());
803
804 self.bulk.maybe_exhausted()
807 && (self.state.range.get() == State::max_value()
808 || self.point.wrapping_sub(&self.state.lower) < max_difference)
809 }
810}
811
/// Exposes the decoder's word type and its `lower`/`range` state through the `Code` trait.
impl<Word, State, Backend> Code for RangeDecoder<Word, State, Backend>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
    Backend: ReadWords<Word, Queue>,
{
    type State = RangeCoderState<Word, State>;
    type Word = Word;

    /// Returns a copy of the current coder state (this does not include `point`).
    fn state(&self) -> Self::State {
        self.state
    }
}
825
/// A position of the decoder is the pair of a backend position and a coder state.
impl<Word, State, Backend> PosSeek for RangeDecoder<Word, State, Backend>
where
    Word: BitArray,
    State: BitArray,
    Backend: ReadWords<Word, Queue>,
    Backend: PosSeek,
    Self: Code,
{
    type Position = (Backend::Position, <Self as Code>::State);
}
836
837impl<Word, State, Backend> Seek for RangeDecoder<Word, State, Backend>
838where
839 Word: BitArray + Into<State>,
840 State: BitArray + AsPrimitive<Word>,
841 Backend: ReadWords<Word, Queue> + Seek,
842{
843 fn seek(&mut self, pos_and_state: Self::Position) -> Result<(), ()> {
844 let (pos, state) = pos_and_state;
845
846 self.bulk.seek(pos)?;
847 self.point = Self::read_point(&mut self.bulk).map_err(|_| ())?;
848 self.state = state;
849
850 Ok(())
853 }
854}
855
impl<Word, State, Backend> From<RangeEncoder<Word, State, Backend>>
    for RangeDecoder<Word, State, Backend::IntoReadWords>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
    Backend: WriteWords<Word> + IntoReadWords<Word, Queue>,
{
    /// Seals `encoder` and converts it into a decoder.
    ///
    /// # Panics
    ///
    /// Panics if `RangeEncoder::into_decoder` returns an error, i.e., if writing the
    /// sealing words or reading the decoder's initial words fails.
    fn from(encoder: RangeEncoder<Word, State, Backend>) -> Self {
        encoder.into_decoder().unwrap()
    }
}
870
impl<Word, State, Backend, const PRECISION: usize> Decode<PRECISION>
    for RangeDecoder<Word, State, Backend>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
    Backend: ReadWords<Word, Queue>,
{
    type FrontendError = DecoderFrontendError;

    type BackendError = Backend::ReadError;

    /// Decodes a single symbol under the entropy model `model`.
    ///
    /// # Errors
    ///
    /// - `DecoderFrontendError::InvalidData` if the quantile derived from the compressed
    ///   data does not fit into `PRECISION` bits.
    /// - The backend's read error if reading the next word fails.
    fn decode_symbol<D>(
        &mut self,
        model: D,
    ) -> Result<D::Symbol, CoderError<Self::FrontendError, Self::BackendError>>
    where
        D: DecoderModel<PRECISION>,
        D::Probability: Into<Self::Word>,
        Self::Word: AsPrimitive<D::Probability>,
    {
        generic_static_asserts!(
            (Word: BitArray, State:BitArray; const PRECISION: usize);
            PROBABILITY_SUPPORTS_PRECISION: State::BITS >= Word::BITS + PRECISION;
            NON_ZERO_PRECISION: PRECISION > 0;
            STATE_SUPPORTS_AT_LEAST_TWO_WORDS: State::BITS >= 2 * Word::BITS;
            STATE_SIZE_IS_MULTIPLE_OF_WORD_SIZE: State::BITS % Word::BITS == 0;
        );

        // Locate `point` within the current interval, in units of `scale`.
        let scale = self.state.range.get() >> PRECISION;
        let quantile = self.point.wrapping_sub(&self.state.lower) / scale;
        if quantile >= State::one() << PRECISION {
            return Err(CoderError::Frontend(DecoderFrontendError::InvalidData));
        }

        // `quantile` is narrowed from `State` to `Word` and then to `D::Probability`.
        let (symbol, left_sided_cumulative, probability) =
            model.quantile_function(quantile.as_().as_());

        // Narrow the interval with the same arithmetic that `encode_symbol` uses.
        self.state.lower = self
            .state
            .lower
            .wrapping_add(&(scale * left_sided_cumulative.into().into()));
        // NOTE(review): this panics (with the message "TODO") if the product is zero;
        // presumably unreachable while `range` reaches into its most significant word —
        // confirm, and consider a descriptive message.
        self.state.range = (scale * probability.get().into().into())
            .into_nonzero()
            .expect("TODO");

        // Rescale once `range` no longer reaches into the most significant word.
        if self.state.range.get() < State::one() << (State::BITS - Word::BITS) {
            self.state.lower = self.state.lower << Word::BITS;
            self.state.range = unsafe {
                // SAFETY: `range` is nonzero and, by the condition above, all of its set
                // bits lie below bit `State::BITS - Word::BITS`; shifting left by
                // `Word::BITS` therefore shifts out no set bit and the result is nonzero.
                (self.state.range.get() << Word::BITS).into_nonzero_unchecked()
            };

            // Shift `point` by the same amount and refill its low word from the backend;
            // if the backend has no more words, the low word stays zero.
            self.point = self.point << Word::BITS;
            if let Some(word) = self.bulk.read()? {
                self.point = self.point | word.into();
            }

        }

        Ok(symbol)
    }

    /// Forwards to the inherent `RangeDecoder::maybe_exhausted`.
    fn maybe_exhausted(&self) -> bool {
        RangeDecoder::maybe_exhausted(self)
    }
}
968
/// Guard that gives read access to the sealed compressed words of a `Vec`-backed
/// [`RangeEncoder`] while mutably borrowing the encoder.
///
/// The guard is created by `RangeEncoder::get_compressed`; when it is dropped, the sealing
/// words are removed from the encoder again.
pub struct EncoderGuard<'a, Word, State>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
{
    /// The encoder whose buffer is temporarily sealed.
    inner: &'a mut RangeEncoder<Word, State>,
}
980
981impl<Word, State> Debug for EncoderGuard<'_, Word, State>
982where
983 Word: BitArray + Into<State>,
984 State: BitArray + AsPrimitive<Word>,
985{
986 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
987 Debug::fmt(&**self, f)
988 }
989}
990
991impl<'a, Word, State> EncoderGuard<'a, Word, State>
992where
993 Word: BitArray + Into<State>,
994 State: BitArray + AsPrimitive<Word>,
995{
996 fn new(encoder: &'a mut RangeEncoder<Word, State>) -> Self {
997 if !encoder.is_empty() {
999 encoder.seal().unwrap_infallible();
1000 }
1001 Self { inner: encoder }
1002 }
1003}
1004
impl<Word, State> Drop for EncoderGuard<'_, Word, State>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
{
    /// Removes the sealing words from the encoder so that encoding can continue.
    fn drop(&mut self) {
        self.inner.unseal();
    }
}
1014
impl<Word, State> Deref for EncoderGuard<'_, Word, State>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
{
    type Target = [Word];

    /// Returns the encoder's buffer as a slice of words.
    fn deref(&self) -> &Self::Target {
        &self.inner.bulk
    }
}
1026
impl<Word, State> AsRef<[Word]> for EncoderGuard<'_, Word, State>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
{
    /// Returns the same slice as `Deref::deref` (via deref coercion of `self`).
    fn as_ref(&self) -> &[Word] {
        self
    }
}
1036
// Round-trip and seeking tests for `RangeEncoder` and `RangeDecoder`.
#[cfg(test)]
mod tests {
    extern crate std;
    use std::dbg;

    use super::super::model::{
        ContiguousCategoricalEntropyModel, IterableEntropyModel, LeakyQuantizer,
    };
    use super::*;

    use probability::distribution::{Gaussian, Inverse};
    use rand_xoshiro::{
        rand_core::{RngCore, SeedableRng},
        Xoshiro256StarStar,
    };

    // Encoding nothing must produce no words, and the resulting decoder must report that it
    // may be exhausted.
    #[test]
    fn compress_none() {
        let encoder = DefaultRangeEncoder::new();
        assert!(encoder.is_empty());
        let compressed = encoder.into_compressed().unwrap();
        assert!(compressed.is_empty());

        let decoder = DefaultRangeDecoder::from_compressed(compressed).unwrap();
        assert!(decoder.maybe_exhausted());
    }

    #[test]
    fn compress_one() {
        generic_compress_few(core::iter::once(5), 1)
    }

    #[test]
    fn compress_two() {
        generic_compress_few([2, 8].iter().cloned(), 1)
    }

    #[test]
    fn compress_ten() {
        generic_compress_few(0..10, 2)
    }

    #[test]
    fn compress_twenty() {
        generic_compress_few(-10..10, 4)
    }

    // Encodes `symbols` with a fixed quantized Gaussian model, checks that the compressed
    // data consists of exactly `expected_size` words, and checks that decoding reproduces
    // the symbols.
    fn generic_compress_few<I>(symbols: I, expected_size: usize)
    where
        I: IntoIterator<Item = i32>,
        I::IntoIter: Clone,
    {
        let symbols = symbols.into_iter();

        let mut encoder = DefaultRangeEncoder::new();
        let quantizer = LeakyQuantizer::<_, _, u32, 24>::new(-127..=127);
        let model = quantizer.quantize(Gaussian::new(3.2, 5.1));

        encoder.encode_iid_symbols(symbols.clone(), model).unwrap();
        let compressed = encoder.into_compressed().unwrap();
        assert_eq!(compressed.len(), expected_size);

        let mut decoder = DefaultRangeDecoder::from_compressed(&compressed).unwrap();
        for symbol in symbols {
            assert_eq!(decoder.decode_symbol(model).unwrap(), symbol);
        }
        assert!(decoder.maybe_exhausted());
    }

    // The tests below run the same round trip for various combinations of
    // `<Word, State, Probability, PRECISION>`.

    #[test]
    fn compress_many_u32_u64_32() {
        generic_compress_many::<u32, u64, u32, 32>();
    }

    #[test]
    fn compress_many_u32_u64_24() {
        generic_compress_many::<u32, u64, u32, 24>();
    }

    #[test]
    fn compress_many_u32_u64_16() {
        generic_compress_many::<u32, u64, u16, 16>();
    }

    #[test]
    fn compress_many_u32_u64_8() {
        generic_compress_many::<u32, u64, u8, 8>();
    }

    #[test]
    fn compress_many_u16_u64_16() {
        generic_compress_many::<u16, u64, u16, 16>();
    }

    #[test]
    fn compress_many_u16_u64_12() {
        generic_compress_many::<u16, u64, u16, 12>();
    }

    #[test]
    fn compress_many_u16_u64_8() {
        generic_compress_many::<u16, u64, u8, 8>();
    }

    #[test]
    fn compress_many_u8_u64_8() {
        generic_compress_many::<u8, u64, u8, 8>();
    }

    #[test]
    fn compress_many_u16_u32_16() {
        generic_compress_many::<u16, u32, u16, 16>();
    }

    #[test]
    fn compress_many_u16_u32_12() {
        generic_compress_many::<u16, u32, u16, 12>();
    }

    #[test]
    fn compress_many_u16_u32_8() {
        generic_compress_many::<u16, u32, u8, 8>();
    }

    #[test]
    fn compress_many_u8_u32_8() {
        generic_compress_many::<u8, u32, u8, 8>();
    }

    #[test]
    fn compress_many_u8_u16_8() {
        generic_compress_many::<u8, u16, u8, 8>();
    }

    // Encodes `AMT` symbols from a categorical model followed by `AMT` symbols from
    // per-symbol Gaussian models, then decodes both sequences and compares them with the
    // originals.
    fn generic_compress_many<Word, State, Probability, const PRECISION: usize>()
    where
        State: BitArray + AsPrimitive<Word>,
        Word: BitArray + Into<State> + AsPrimitive<Probability>,
        Probability: BitArray + Into<Word> + AsPrimitive<usize> + Into<f64>,
        u32: AsPrimitive<Probability>,
        usize: AsPrimitive<Probability>,
        f64: AsPrimitive<Probability>,
        i32: AsPrimitive<Probability>,
    {
        // Fewer symbols under miri.
        #[cfg(not(miri))]
        const AMT: usize = 1000;

        #[cfg(miri)]
        const AMT: usize = 100;

        let mut symbols_gaussian = Vec::with_capacity(AMT);
        let mut means = Vec::with_capacity(AMT);
        let mut stds = Vec::with_capacity(AMT);

        // Fixed seed, so the test data is reproducible.
        let mut rng = Xoshiro256StarStar::seed_from_u64(1234);
        for _ in 0..AMT {
            let mean = (200.0 / u32::MAX as f64) * rng.next_u32() as f64 - 100.0;
            let std_dev = (10.0 / u32::MAX as f64) * rng.next_u32() as f64 + 0.001;
            let quantile = (rng.next_u32() as f64 + 0.5) / (1u64 << 32) as f64;
            let dist = Gaussian::new(mean, std_dev);
            let symbol = (dist.inverse(quantile).round() as i32).clamp(-127, 127);

            symbols_gaussian.push(symbol);
            means.push(mean);
            stds.push(std_dev);
        }

        // Unnormalized histogram for the categorical model; it includes zero entries.
        let hist = [
            1u32, 186545, 237403, 295700, 361445, 433686, 509456, 586943, 663946, 737772, 1657269,
            896675, 922197, 930672, 916665, 0, 0, 0, 0, 0, 723031, 650522, 572300, 494702, 418703,
            347600, 1, 283500, 226158, 178194, 136301, 103158, 76823, 55540, 39258, 27988, 54269,
        ];
        let categorical_probabilities = hist.iter().map(|&x| x as f64).collect::<Vec<_>>();
        let categorical =
            ContiguousCategoricalEntropyModel::<Probability, _, PRECISION>::from_floating_point_probabilities_fast::<f64>(
                &categorical_probabilities,None
            )
            .unwrap();
        let mut symbols_categorical = Vec::with_capacity(AMT);
        let max_probability = Probability::max_value() >> (Probability::BITS - PRECISION);
        for _ in 0..AMT {
            // Draw a random quantile of `PRECISION` bits and look up its symbol.
            let quantile = rng.next_u32().as_() & max_probability;
            let symbol = categorical.quantile_function(quantile).0;
            symbols_categorical.push(symbol);
        }

        let mut encoder = RangeEncoder::<Word, State>::new();

        encoder
            .encode_iid_symbols(&symbols_categorical, &categorical)
            .unwrap();
        dbg!(
            encoder.num_bits(),
            AMT as f64 * categorical.entropy_base2::<f64>()
        );

        let quantizer = LeakyQuantizer::<_, _, Probability, PRECISION>::new(-127..=127);
        encoder
            .encode_symbols(symbols_gaussian.iter().zip(&means).zip(&stds).map(
                // `core` is the standard deviation here (a local binding, not the crate).
                |((&symbol, &mean), &core)| (symbol, quantizer.quantize(Gaussian::new(mean, core))),
            ))
            .unwrap();
        dbg!(encoder.num_bits());

        let mut decoder = encoder.into_decoder().unwrap();

        // Decode in the same order in which the symbols were encoded.
        let reconstructed_categorical = decoder
            .decode_iid_symbols(AMT, &categorical)
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        let reconstructed_gaussian = decoder
            .decode_symbols(
                means
                    .iter()
                    .zip(&stds)
                    .map(|(&mean, &core)| quantizer.quantize(Gaussian::new(mean, core))),
            )
            .collect::<Result<Vec<_>, _>>()
            .unwrap();

        assert!(decoder.maybe_exhausted());

        assert_eq!(symbols_categorical, reconstructed_categorical);
        assert_eq!(symbols_gaussian, reconstructed_gaussian);
    }

    // Records the encoder position before every chunk and checks that a decoder can jump to
    // any of these positions and decode the corresponding chunk.
    #[test]
    fn seek() {
        #[cfg(not(miri))]
        let (num_chunks, symbols_per_chunk) = (100, 100);

        #[cfg(miri)]
        let (num_chunks, symbols_per_chunk) = (10, 10);

        let quantizer = LeakyQuantizer::<_, _, u32, 24>::new(-100..=100);
        let model = quantizer.quantize(Gaussian::new(0.0, 10.0));

        let mut encoder = DefaultRangeEncoder::new();

        let mut rng = Xoshiro256StarStar::seed_from_u64(123);
        let mut symbols = Vec::with_capacity(num_chunks);
        let mut jump_table = Vec::with_capacity(num_chunks);

        for _ in 0..num_chunks {
            jump_table.push(encoder.pos());
            let chunk = (0..symbols_per_chunk)
                .map(|_| model.quantile_function(rng.next_u32() % (1 << 24)).0)
                .collect::<Vec<_>>();
            encoder.encode_iid_symbols(&chunk, &model).unwrap();
            symbols.push(chunk);
        }
        let final_pos_and_state = encoder.pos();

        let mut decoder = encoder.decoder();

        // First decode all chunks sequentially, without seeking.
        for (chunk, _) in symbols.iter().zip(&jump_table) {
            let decoded = decoder
                .decode_iid_symbols(symbols_per_chunk, &model)
                .collect::<Result<Vec<_>, _>>()
                .unwrap();
            assert_eq!(&decoded, chunk);
        }
        assert!(decoder.maybe_exhausted());

        // Then seek to chunks in random order; iteration 3 always targets the first chunk.
        for i in 0..100 {
            let chunk_index = if i == 3 {
                0
            } else {
                rng.next_u32() as usize % num_chunks
            };

            let pos_and_state = jump_table[chunk_index];
            decoder.seek(pos_and_state).unwrap();
            let decoded = decoder
                .decode_iid_symbols(symbols_per_chunk, &model)
                .collect::<Result<Vec<_>, _>>()
                .unwrap();
            assert_eq!(&decoded, &symbols[chunk_index])
        }

        // Seeking to the start must not look exhausted; seeking to the end must.
        decoder.seek(jump_table[0]).unwrap();
        assert!(!decoder.maybe_exhausted());
        decoder.seek(final_pos_and_state).unwrap();
        assert!(decoder.maybe_exhausted());
    }
}
1329
/// Frontend error type of [`RangeDecoder`]'s `Decode` implementation.
#[derive(Debug)]
#[non_exhaustive]
pub enum DecoderFrontendError {
    /// Returned by `decode_symbol` when the quantile derived from the compressed data does
    /// not fit into `PRECISION` bits.
    InvalidData,
}
1359
1360impl Display for DecoderFrontendError {
1361 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1362 match self {
1363 Self::InvalidData => write!(
1364 f,
1365 "Tried to decode from compressed data that is invalid for the employed entropy model."
1366 ),
1367 }
1368 }
1369}
1370
// Only available with the `std` feature because the impl names `std::error::Error`.
#[cfg(feature = "std")]
impl std::error::Error for DecoderFrontendError {}