1use alloc::vec::Vec;
35use core::{
36 borrow::Borrow,
37 fmt::{Debug, Display},
38 marker::PhantomData,
39 num::NonZeroUsize,
40 ops::Deref,
41};
42
43use num_traits::AsPrimitive;
44
45use super::{
46 model::{DecoderModel, EncoderModel},
47 Code, Decode, Encode, IntoDecoder,
48};
49use crate::{
50 backends::{AsReadWords, BoundedReadWords, Cursor, IntoReadWords, ReadWords, WriteWords},
51 generic_static_asserts, BitArray, CoderError, DefaultEncoderError, DefaultEncoderFrontendError,
52 NonZeroBitArray, Pos, PosSeek, Queue, Seek, UnwrapInfallible,
53};
54
/// Coder state shared by [`RangeEncoder`] and [`RangeDecoder`]: a `lower` bound and a
/// nonzero `range`, both stored as `State`.
///
/// The `Word` type parameter is only tracked at the type level (via `PhantomData`); it
/// determines the threshold that [`RangeCoderState::new`] enforces on `range`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RangeCoderState<Word, State: BitArray> {
    /// Lower bound; exposed through [`RangeCoderState::lower`].
    lower: State,

    /// Range; stored as `State::NonZero`, so it can never be zero. Exposed through
    /// [`RangeCoderState::range`].
    range: State::NonZero,

    /// Zero-sized marker that ties this state to a `Word` type without storing one.
    phantom: PhantomData<Word>,
}
70
71impl<Word: BitArray, State: BitArray> RangeCoderState<Word, State> {
72 #[allow(clippy::result_unit_err)]
73 pub fn new(lower: State, range: State) -> Result<Self, ()> {
74 if range >> (State::BITS - Word::BITS) == State::zero() {
75 Err(())
76 } else {
77 Ok(Self {
78 lower,
79 range: range.into_nonzero().expect("We checked above."),
80 phantom: PhantomData,
81 })
82 }
83 }
84
85 pub fn lower(&self) -> State {
87 self.lower
88 }
89
90 pub fn range(&self) -> State::NonZero {
92 self.range
93 }
94}
95
96impl<Word: BitArray, State: BitArray> Default for RangeCoderState<Word, State> {
97 fn default() -> Self {
98 Self {
99 lower: State::zero(),
100 range: State::max_value().into_nonzero().expect("max_value() != 0"),
101 phantom: PhantomData,
102 }
103 }
104}
105
/// Encoder half of the range coder defined in this file.
///
/// Type parameters:
/// - `Word`: the unit in which compressed data is written to the backend;
/// - `State`: the integer type used for the coder's `lower` and `range` (see
///   [`RangeCoderState`]);
/// - `Backend`: the sink that receives compressed words; defaults to `Vec<Word>`.
#[derive(Debug, Clone)]
pub struct RangeEncoder<Word, State, Backend = Vec<Word>>
where
    Word: BitArray,
    State: BitArray,
    Backend: WriteWords<Word>,
{
    /// Backend holding the compressed words that have been written so far.
    bulk: Backend,
    /// Current `lower`/`range` pair of the coder.
    state: RangeCoderState<Word, State>,
    /// Tracks whether words are currently being held back instead of written to `bulk`;
    /// see [`EncoderSituation`].
    situation: EncoderSituation<Word>,
}
117
/// Bookkeeping of a [`RangeEncoder`] for compressed words whose final value is not yet
/// known and that therefore have not been written to the backend yet.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EncoderSituation<Word> {
    /// No words are being held back.
    Normal,

    /// At least one word is being held back.
    ///
    /// The first field counts the held-back words. The second field is the first
    /// held-back word as it was computed when the encoder entered this situation; when
    /// the words are finally written, it is emitted either unchanged or incremented by
    /// one.
    Inverted(NonZeroUsize, Word),
}

impl<Word> Default for EncoderSituation<Word> {
    /// The default situation is [`EncoderSituation::Normal`].
    fn default() -> Self {
        EncoderSituation::Normal
    }
}
146
/// A [`RangeEncoder`] with `u32` words and a `u64` state; the backend defaults to
/// `Vec<u32>`.
pub type DefaultRangeEncoder<Backend = Vec<u32>> = RangeEncoder<u32, u64, Backend>;
149
/// A [`RangeEncoder`] with `u16` words and a `u32` state; the backend defaults to
/// `Vec<u16>`.
pub type SmallRangeEncoder<Backend = Vec<u16>> = RangeEncoder<u16, u32, Backend>;
161
/// Exposes the encoder's word type and its coder state through the `Code` trait.
impl<Word, State, Backend> Code for RangeEncoder<Word, State, Backend>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
    Backend: WriteWords<Word>,
{
    type State = RangeCoderState<Word, State>;
    type Word = Word;

    /// Returns a copy of the current coder state (`RangeCoderState` is `Copy`).
    fn state(&self) -> Self::State {
        self.state
    }
}
175
/// Declares the position type of a `RangeEncoder`: a pair of the backend's position
/// and the coder state.
impl<Word, State, Backend> PosSeek for RangeEncoder<Word, State, Backend>
where
    Word: BitArray,
    State: BitArray,
    Backend: WriteWords<Word> + PosSeek,
    Self: Code,
{
    type Position = (Backend::Position, <Self as Code>::State);
}
185
186impl<Word, State, Backend> Pos for RangeEncoder<Word, State, Backend>
187where
188 Word: BitArray + Into<State>,
189 State: BitArray + AsPrimitive<Word>,
190 Backend: WriteWords<Word> + Pos<Position = usize>,
191{
192 fn pos(&self) -> Self::Position {
193 let num_inverted = if let EncoderSituation::Inverted(num_inverted, _) = self.situation {
194 num_inverted.get()
195 } else {
196 0
197 };
198 (self.bulk.pos() + num_inverted, self.state())
199 }
200}
201
impl<Word, State, Backend> Default for RangeEncoder<Word, State, Backend>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
    Backend: WriteWords<Word> + Default,
{
    /// Creates an encoder on top of `Backend::default()`; equivalent to
    /// `Self::with_backend(Backend::default())`.
    fn default() -> Self {
        Self::with_backend(Backend::default())
    }
}
214
215impl<Word, State> RangeEncoder<Word, State>
216where
217 Word: BitArray + Into<State>,
218 State: BitArray + AsPrimitive<Word>,
219{
220 pub fn new() -> Self {
222 generic_static_asserts!(
223 (Word: BitArray, State:BitArray);
224 STATE_SUPPORTS_AT_LEAST_TWO_WORDS: State::BITS >= 2 * Word::BITS;
225 STATE_SIZE_IS_MULTIPLE_OF_WORD_SIZE: State::BITS % Word::BITS == 0;
226 );
227
228 Self {
229 bulk: Vec::new(),
230 state: RangeCoderState::default(),
231 situation: EncoderSituation::Normal,
232 }
233 }
234}
235
236impl<Word, State> From<RangeEncoder<Word, State>> for Vec<Word>
237where
238 Word: BitArray + Into<State>,
239 State: BitArray + AsPrimitive<Word>,
240{
241 fn from(val: RangeEncoder<Word, State>) -> Self {
242 val.into_compressed().unwrap_infallible()
243 }
244}
245
246impl<Word, State, Backend> RangeEncoder<Word, State, Backend>
247where
248 Word: BitArray + Into<State>,
249 State: BitArray + AsPrimitive<Word>,
250 Backend: WriteWords<Word>,
251{
252 pub fn with_backend(backend: Backend) -> Self {
269 generic_static_asserts!(
270 (Word: BitArray, State:BitArray);
271 STATE_SUPPORTS_AT_LEAST_TWO_WORDS: State::BITS >= 2 * Word::BITS;
272 STATE_SIZE_IS_MULTIPLE_OF_WORD_SIZE: State::BITS % Word::BITS == 0;
273 );
274
275 Self {
276 bulk: backend,
277 state: RangeCoderState::default(),
278 situation: EncoderSituation::Normal,
279 }
280 }
281
282 pub fn is_empty<'a>(&'a self) -> bool
284 where
285 Backend: AsReadWords<'a, Word, Queue>,
286 Backend::AsReadWords: BoundedReadWords<Word, Queue>,
287 {
288 self.state.range.get() == State::max_value() && self.bulk.as_read_words().is_exhausted()
289 }
290
291 pub fn maybe_full(&self) -> bool {
294 self.bulk.maybe_full()
295 }
296
297 #[allow(clippy::result_unit_err)]
302 pub fn into_decoder(self) -> Result<RangeDecoder<Word, State, Backend::IntoReadWords>, ()>
303 where
304 Backend: IntoReadWords<Word, Queue>,
305 {
306 RangeDecoder::from_compressed(self.into_compressed().map_err(|_| ())?).map_err(|_| ())
308 }
309
310 pub fn into_compressed(mut self) -> Result<Backend, Backend::WriteError> {
311 self.seal()?;
312 Ok(self.bulk)
313 }
314
315 fn seal(&mut self) -> Result<(), Backend::WriteError> {
323 if self.state.range.get() == State::max_value() {
324 return Ok(());
329 }
330
331 let point = self
332 .state
333 .lower
334 .wrapping_add(&((State::one() << (State::BITS - Word::BITS)) - State::one()));
335
336 if let EncoderSituation::Inverted(num_inverted, first_inverted_lower_word) = self.situation
337 {
338 let (first_word, consecutive_words) = if point < self.state.lower {
339 (first_inverted_lower_word + Word::one(), Word::zero())
341 } else {
342 (first_inverted_lower_word, Word::max_value())
344 };
345
346 self.bulk.write(first_word)?;
347 for _ in 1..num_inverted.get() {
348 self.bulk.write(consecutive_words)?;
349 }
350 }
351
352 let point_word = (point >> (State::BITS - Word::BITS)).as_();
353 self.bulk.write(point_word)?;
354
355 let upper_word = (self.state.lower.wrapping_add(&self.state.range.get())
356 >> (State::BITS - Word::BITS))
357 .as_();
358 if upper_word == point_word {
359 self.bulk.write(Word::zero())?;
360 }
361
362 Ok(())
363 }
364
365 fn num_seal_words(&self) -> usize {
366 if self.state.range.get() == State::max_value() {
367 return 0;
368 }
369
370 let point = self
371 .state
372 .lower
373 .wrapping_add(&((State::one() << (State::BITS - Word::BITS)) - State::one()));
374 let point_word = (point >> (State::BITS - Word::BITS)).as_();
375 let upper_word = (self.state.lower.wrapping_add(&self.state.range.get())
376 >> (State::BITS - Word::BITS))
377 .as_();
378 let mut count = if upper_word == point_word { 2 } else { 1 };
379
380 if let EncoderSituation::Inverted(num_inverted, _) = self.situation {
381 count += num_inverted.get();
382 }
383 count
384 }
385
386 pub fn num_words<'a>(&'a self) -> usize
402 where
403 Backend: AsReadWords<'a, Word, Queue>,
404 Backend::AsReadWords: BoundedReadWords<Word, Queue>,
405 {
406 self.bulk.as_read_words().remaining() + self.num_seal_words()
407 }
408
409 pub fn num_bits<'a>(&'a self) -> usize
417 where
418 Backend: AsReadWords<'a, Word, Queue>,
419 Backend::AsReadWords: BoundedReadWords<Word, Queue>,
420 {
421 Word::BITS * self.num_words()
422 }
423
424 pub fn bulk(&self) -> &Backend {
425 &self.bulk
426 }
427
428 pub fn from_raw_parts(
433 bulk: Backend,
434 state: RangeCoderState<Word, State>,
435 situation: EncoderSituation<Word>,
436 ) -> Self {
437 generic_static_asserts!(
438 (Word: BitArray, State:BitArray);
439 STATE_SUPPORTS_AT_LEAST_TWO_WORDS: State::BITS >= 2 * Word::BITS;
440 STATE_SIZE_IS_MULTIPLE_OF_WORD_SIZE: State::BITS % Word::BITS == 0;
441 );
442
443 Self {
446 bulk,
447 state,
448 situation,
449 }
450 }
451
452 pub fn into_raw_parts(
456 self,
457 ) -> (
458 Backend,
459 RangeCoderState<Word, State>,
460 EncoderSituation<Word>,
461 ) {
462 (self.bulk, self.state, self.situation)
463 }
464}
465
466impl<Word, State> RangeEncoder<Word, State>
467where
468 Word: BitArray + Into<State>,
469 State: BitArray + AsPrimitive<Word>,
470{
471 pub fn clear(&mut self) {
474 self.bulk.clear();
475 self.state = RangeCoderState::default();
476 }
477
478 pub fn get_compressed(&mut self) -> EncoderGuard<'_, Word, State> {
489 EncoderGuard::new(self)
490 }
491
492 pub fn decoder(
507 &mut self,
508 ) -> RangeDecoder<Word, State, Cursor<Word, EncoderGuard<'_, Word, State>>> {
509 RangeDecoder::from_compressed(self.get_compressed()).unwrap_infallible()
510 }
511
512 fn unseal(&mut self) {
513 for _ in 0..self.num_seal_words() {
514 let word = self.bulk.pop();
515 debug_assert!(word.is_some());
516 }
517 }
518}
519
/// Converts a `RangeEncoder` into a [`RangeDecoder`] that reads from the encoder's
/// backend.
impl<Word, State, Backend, const PRECISION: usize> IntoDecoder<PRECISION>
    for RangeEncoder<Word, State, Backend>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
    Backend: WriteWords<Word> + IntoReadWords<Word, Queue>,
{
    type IntoDecoder = RangeDecoder<Word, State, Backend::IntoReadWords>;

    /// Performs the conversion via the `From<RangeEncoder<..>>` implementation of
    /// `RangeDecoder`, which panics if the conversion fails.
    fn into_decoder(self) -> Self::IntoDecoder {
        self.into()
    }
}
533
/// Symbol-wise encoding with entropy models of fixed-point precision `PRECISION`.
impl<Word, State, Backend, const PRECISION: usize> Encode<PRECISION>
    for RangeEncoder<Word, State, Backend>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
    Backend: WriteWords<Word>,
{
    type FrontendError = DefaultEncoderFrontendError;
    type BackendError = Backend::WriteError;

    /// Encodes a single `symbol` under the entropy model `model`.
    ///
    /// # Errors
    ///
    /// - Returns an `ImpossibleSymbol` frontend error if the model does not provide a
    ///   cumulative and probability for `symbol`, or if the updated `range` would be
    ///   zero.
    /// - Propagates errors from writing to the backend.
    ///
    /// NOTE(review): `self.state.range` is updated before any word is written, so the
    /// encoder is presumably not in a usable state after a backend error — confirm.
    ///
    /// # Panics
    ///
    /// Panics if the counter of held-back words wraps around to zero.
    fn encode_symbol<D>(
        &mut self,
        symbol: impl Borrow<D::Symbol>,
        model: D,
    ) -> Result<(), DefaultEncoderError<Self::BackendError>>
    where
        D: EncoderModel<PRECISION>,
        D::Probability: Into<Self::Word>,
        Self::Word: AsPrimitive<D::Probability>,
    {
        generic_static_asserts!(
            (Word: BitArray, State:BitArray; const PRECISION: usize);
            PROBABILITY_SUPPORTS_PRECISION: State::BITS >= Word::BITS + PRECISION;
            NON_ZERO_PRECISION: PRECISION > 0;
            STATE_SUPPORTS_AT_LEAST_TWO_WORDS: State::BITS >= 2 * Word::BITS;
            STATE_SIZE_IS_MULTIPLE_OF_WORD_SIZE: State::BITS % Word::BITS == 0;
        );

        // Invariant on method entry and exit (established by `RangeCoderState::new` and
        // `RangeCoderState::default`, restored by the rescaling step below):
        //   range >= 1 << (State::BITS - Word::BITS)

        let (left_sided_cumulative, probability) = model
            .left_cumulative_and_probability(symbol)
            .ok_or_else(|| DefaultEncoderFrontendError::ImpossibleSymbol.into_coder_error())?;

        // Narrow the range to the part that corresponds to `symbol`.
        let scale = self.state.range.get() >> PRECISION;
        self.state.range = (scale * probability.get().into().into())
            .into_nonzero()
            .ok_or_else(|| DefaultEncoderFrontendError::ImpossibleSymbol.into_coder_error())?;
        let new_lower = self
            .state
            .lower
            .wrapping_add(&(scale * left_sided_cumulative.into().into()));

        if let EncoderSituation::Inverted(num_inverted, first_inverted_lower_word) = self.situation
        {
            // The situation stops being inverted once `new_lower + range` no longer
            // wraps around.
            if new_lower.wrapping_add(&self.state.range.get()) > new_lower {
                // The held-back words can now be resolved. If `lower` wrapped around
                // (`new_lower < lower`), a carry goes into the first held-back word and
                // all following held-back words become zero; otherwise the first word is
                // written unchanged and the following ones are all-ones.
                let (first_word, consecutive_words) = if new_lower < self.state.lower {
                    (first_inverted_lower_word + Word::one(), Word::zero())
                } else {
                    (first_inverted_lower_word, Word::max_value())
                };

                self.bulk.write(first_word)?;
                for _ in 1..num_inverted.get() {
                    self.bulk.write(consecutive_words)?;
                }

                self.situation = EncoderSituation::Normal;
            }
        }

        self.state.lower = new_lower;

        // Rescale if `range` no longer reaches into the most significant word.
        if self.state.range.get() < State::one() << (State::BITS - Word::BITS) {
            self.state.range = unsafe {
                // SAFETY: `range` is nonzero (it is a `State::NonZero`) and, by the
                // condition above, smaller than `1 << (State::BITS - Word::BITS)`.
                // Shifting it left by `Word::BITS` therefore shifts out no set bit, so
                // the result is nonzero.
                (self.state.range.get() << Word::BITS).into_nonzero_unchecked()
            };

            // Split off the most significant word of `lower`.
            let lower_word = (self.state.lower >> (State::BITS - Word::BITS)).as_();
            self.state.lower = self.state.lower << Word::BITS;

            if let EncoderSituation::Inverted(num_inverted, _) = &mut self.situation {
                // Still inverted: hold back one more word (its value is determined when
                // the inverted situation gets resolved).
                *num_inverted = NonZeroUsize::new(num_inverted.get().wrapping_add(1))
                    .expect("Cannot encode more symbols than what's addressable with usize.");
            } else if self.state.lower.wrapping_add(&self.state.range.get()) > self.state.lower {
                // Normal situation and `lower + range` does not wrap: the word is final.
                self.bulk.write(lower_word)?;
            } else {
                // `lower + range` wraps around: start holding back words.
                self.situation =
                    EncoderSituation::Inverted(NonZeroUsize::new(1).expect("1 != 0"), lower_word);
            }
        }

        Ok(())
    }

    /// Forwards to the inherent `RangeEncoder::maybe_full`.
    fn maybe_full(&self) -> bool {
        RangeEncoder::maybe_full(self)
    }
}
643
/// Decoder half of the range coder defined in this file.
///
/// Type parameters:
/// - `Word`: the unit in which compressed data is read from the backend;
/// - `State`: the integer type used for `lower`, `range`, and `point`;
/// - `Backend`: the source of compressed words, read with `Queue` semantics.
#[derive(Debug, Clone)]
pub struct RangeDecoder<Word, State, Backend>
where
    Word: BitArray,
    State: BitArray,
    Backend: ReadWords<Word, Queue>,
{
    /// Source of the compressed words that have not been consumed yet.
    bulk: Backend,

    /// Current `lower`/`range` pair of the coder.
    state: RangeCoderState<Word, State>,

    /// Window into the compressed data: assembled from the first words of the stream
    /// (see `read_point`) and advanced by one word whenever the decoder rescales.
    point: State,
}
658
/// A [`RangeDecoder`] with `u32` words and a `u64` state; the backend defaults to
/// `Cursor<u32, Vec<u32>>`.
pub type DefaultRangeDecoder<Backend = Cursor<u32, Vec<u32>>> = RangeDecoder<u32, u64, Backend>;
661
/// A [`RangeDecoder`] with `u16` words and a `u32` state. Unlike
/// [`DefaultRangeDecoder`], this alias has no default backend.
pub type SmallRangeDecoder<Backend> = RangeDecoder<u16, u32, Backend>;
680
681impl<Word, State, Backend> RangeDecoder<Word, State, Backend>
682where
683 Word: BitArray + Into<State>,
684 State: BitArray + AsPrimitive<Word>,
685 Backend: ReadWords<Word, Queue>,
686{
687 pub fn from_compressed<Buf>(compressed: Buf) -> Result<Self, Backend::ReadError>
688 where
689 Buf: IntoReadWords<Word, Queue, IntoReadWords = Backend>,
690 {
691 generic_static_asserts!(
692 (Word: BitArray, State:BitArray);
693 STATE_SUPPORTS_AT_LEAST_TWO_WORDS: State::BITS >= 2 * Word::BITS;
694 STATE_SIZE_IS_MULTIPLE_OF_WORD_SIZE: State::BITS % Word::BITS == 0;
695 );
696
697 let mut bulk = compressed.into_read_words();
698 let point = Self::read_point(&mut bulk)?;
699
700 Ok(RangeDecoder {
701 bulk,
702 state: RangeCoderState::default(),
703 point,
704 })
705 }
706
707 pub fn with_backend(backend: Backend) -> Result<Self, Backend::ReadError> {
708 generic_static_asserts!(
709 (Word: BitArray, State:BitArray);
710 STATE_SUPPORTS_AT_LEAST_TWO_WORDS: State::BITS >= 2 * Word::BITS;
711 STATE_SIZE_IS_MULTIPLE_OF_WORD_SIZE: State::BITS % Word::BITS == 0;
712 );
713
714 let mut bulk = backend;
715 let point = Self::read_point(&mut bulk)?;
716
717 Ok(RangeDecoder {
718 bulk,
719 state: RangeCoderState::default(),
720 point,
721 })
722 }
723
724 pub fn for_compressed<'a, Buf>(compressed: &'a Buf) -> Result<Self, Backend::ReadError>
725 where
726 Buf: AsReadWords<'a, Word, Queue, AsReadWords = Backend>,
727 {
728 generic_static_asserts!(
729 (Word: BitArray, State:BitArray);
730 STATE_SUPPORTS_AT_LEAST_TWO_WORDS: State::BITS >= 2 * Word::BITS;
731 STATE_SIZE_IS_MULTIPLE_OF_WORD_SIZE: State::BITS % Word::BITS == 0;
732 );
733
734 let mut bulk = compressed.as_read_words();
735 let point = Self::read_point(&mut bulk)?;
736
737 Ok(RangeDecoder {
738 bulk,
739 state: RangeCoderState::default(),
740 point,
741 })
742 }
743
744 pub fn from_raw_parts(
753 bulk: Backend,
754 state: RangeCoderState<Word, State>,
755 point: State,
756 ) -> Result<Self, Backend> {
757 generic_static_asserts!(
758 (Word: BitArray, State:BitArray);
759 STATE_SUPPORTS_AT_LEAST_TWO_WORDS: State::BITS >= 2 * Word::BITS;
760 STATE_SIZE_IS_MULTIPLE_OF_WORD_SIZE: State::BITS % Word::BITS == 0;
761 );
762
763 if point.wrapping_sub(&state.lower) >= state.range.get() {
766 Err(bulk)
767 } else {
768 Ok(Self { bulk, state, point })
769 }
770 }
771
772 pub fn into_raw_parts(self) -> (Backend, RangeCoderState<Word, State>, State) {
776 (self.bulk, self.state, self.point)
777 }
778
779 fn read_point<B: ReadWords<Word, Queue>>(bulk: &mut B) -> Result<State, B::ReadError> {
780 let mut num_read = 0;
781 let mut point = State::zero();
782 while let Some(word) = bulk.read()? {
783 point = point << Word::BITS | word.into();
784 num_read += 1;
785 if num_read == State::BITS / Word::BITS {
786 break;
787 }
788 }
789
790 #[allow(clippy::collapsible_if)]
791 if num_read < State::BITS / Word::BITS {
792 if num_read != 0 {
793 point = point << (State::BITS - num_read * Word::BITS);
794 }
795 }
798
799 Ok(point)
800 }
801
802 pub fn maybe_exhausted(&self) -> bool {
805 let max_difference =
808 ((State::one() << (State::BITS - Word::BITS)) << 1).wrapping_sub(&State::one());
809
810 self.bulk.maybe_exhausted()
813 && (self.state.range.get() == State::max_value()
814 || self.point.wrapping_sub(&self.state.lower) < max_difference)
815 }
816}
817
/// Exposes the decoder's word type and its coder state through the `Code` trait.
impl<Word, State, Backend> Code for RangeDecoder<Word, State, Backend>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
    Backend: ReadWords<Word, Queue>,
{
    type State = RangeCoderState<Word, State>;
    type Word = Word;

    /// Returns a copy of the current coder state; `point` is not part of it.
    fn state(&self) -> Self::State {
        self.state
    }
}
831
/// Declares the position type of a `RangeDecoder`: a pair of the backend's position
/// and the coder state (the same shape as for `RangeEncoder`).
impl<Word, State, Backend> PosSeek for RangeDecoder<Word, State, Backend>
where
    Word: BitArray,
    State: BitArray,
    Backend: ReadWords<Word, Queue>,
    Backend: PosSeek,
    Self: Code,
{
    type Position = (Backend::Position, <Self as Code>::State);
}
842
843impl<Word, State, Backend> Seek for RangeDecoder<Word, State, Backend>
844where
845 Word: BitArray + Into<State>,
846 State: BitArray + AsPrimitive<Word>,
847 Backend: ReadWords<Word, Queue> + Seek,
848{
849 fn seek(&mut self, pos_and_state: Self::Position) -> Result<(), ()> {
850 let (pos, state) = pos_and_state;
851
852 self.bulk.seek(pos)?;
853 self.point = Self::read_point(&mut self.bulk).map_err(|_| ())?;
854 self.state = state;
855
856 Ok(())
859 }
860}
861
impl<Word, State, Backend> From<RangeEncoder<Word, State, Backend>>
    for RangeDecoder<Word, State, Backend::IntoReadWords>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
    Backend: WriteWords<Word> + IntoReadWords<Word, Queue>,
{
    /// Seals `encoder` and converts it into a decoder over the same backend.
    ///
    /// # Panics
    ///
    /// Panics if `RangeEncoder::into_decoder` returns an error, i.e., if the backend
    /// reports a write error while sealing or a read error while the decoder reads its
    /// initial words. Call `RangeEncoder::into_decoder` directly to handle these cases
    /// without panicking.
    fn from(encoder: RangeEncoder<Word, State, Backend>) -> Self {
        encoder.into_decoder().unwrap()
    }
}
876
877impl<Word, State, Backend, const PRECISION: usize> Decode<PRECISION>
891 for RangeDecoder<Word, State, Backend>
892where
893 Word: BitArray + Into<State>,
894 State: BitArray + AsPrimitive<Word>,
895 Backend: ReadWords<Word, Queue>,
896{
897 type FrontendError = DecoderFrontendError;
898
899 type BackendError = Backend::ReadError;
900
901 fn decode_symbol<D>(
917 &mut self,
918 model: D,
919 ) -> Result<D::Symbol, CoderError<Self::FrontendError, Self::BackendError>>
920 where
921 D: DecoderModel<PRECISION>,
922 D::Probability: Into<Self::Word>,
923 Self::Word: AsPrimitive<D::Probability>,
924 {
925 generic_static_asserts!(
926 (Word: BitArray, State:BitArray; const PRECISION: usize);
927 PROBABILITY_SUPPORTS_PRECISION: State::BITS >= Word::BITS + PRECISION;
928 NON_ZERO_PRECISION: PRECISION > 0;
929 STATE_SUPPORTS_AT_LEAST_TWO_WORDS: State::BITS >= 2 * Word::BITS;
930 STATE_SIZE_IS_MULTIPLE_OF_WORD_SIZE: State::BITS % Word::BITS == 0;
931 );
932
933 let scale = self.state.range.get() >> PRECISION;
938 let quantile = self.point.wrapping_sub(&self.state.lower) / scale;
939 if quantile >= State::one() << PRECISION {
940 return Err(CoderError::Frontend(DecoderFrontendError::InvalidData));
941 }
942
943 let (symbol, left_sided_cumulative, probability) =
944 model.quantile_function(quantile.as_().as_());
945
946 self.state.lower = self
948 .state
949 .lower
950 .wrapping_add(&(scale * left_sided_cumulative.into().into()));
951 self.state.range = (scale * probability.get().into().into())
952 .into_nonzero()
953 .expect("TODO");
954
955 if self.state.range.get() < State::one() << (State::BITS - Word::BITS) {
963 self.state.lower = self.state.lower << Word::BITS;
965 self.state.range = unsafe {
966 (self.state.range.get() << Word::BITS).into_nonzero_unchecked()
971 };
972
973 self.point = self.point << Word::BITS;
975 if let Some(word) = self.bulk.read()? {
976 self.point = self.point | word.into();
977 }
978
979 }
981
982 Ok(symbol)
983 }
984
985 fn maybe_exhausted(&self) -> bool {
986 RangeDecoder::maybe_exhausted(self)
987 }
988}
989
/// Guard returned by `RangeEncoder::get_compressed`; dereferences to the compressed
/// words as a `[Word]` slice.
///
/// Constructing the guard seals the borrowed encoder (unless it is empty), and dropping
/// the guard removes the sealing words again.
pub struct EncoderGuard<'a, Word, State>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
{
    /// The encoder that is kept sealed for as long as this guard lives.
    inner: &'a mut RangeEncoder<Word, State>,
}
1001
1002impl<Word, State> Debug for EncoderGuard<'_, Word, State>
1003where
1004 Word: BitArray + Into<State>,
1005 State: BitArray + AsPrimitive<Word>,
1006{
1007 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
1008 Debug::fmt(&**self, f)
1009 }
1010}
1011
1012impl<'a, Word, State> EncoderGuard<'a, Word, State>
1013where
1014 Word: BitArray + Into<State>,
1015 State: BitArray + AsPrimitive<Word>,
1016{
1017 fn new(encoder: &'a mut RangeEncoder<Word, State>) -> Self {
1018 if !encoder.is_empty() {
1020 encoder.seal().unwrap_infallible();
1021 }
1022 Self { inner: encoder }
1023 }
1024}
1025
impl<'a, Word, State> Drop for EncoderGuard<'a, Word, State>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
{
    /// Removes the sealing words from the borrowed encoder's buffer again.
    fn drop(&mut self) {
        self.inner.unseal();
    }
}
1035
impl<'a, Word, State> Deref for EncoderGuard<'a, Word, State>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
{
    type Target = [Word];

    /// Returns the borrowed encoder's buffer as a slice of words.
    fn deref(&self) -> &Self::Target {
        &self.inner.bulk
    }
}
1047
impl<'a, Word, State> AsRef<[Word]> for EncoderGuard<'a, Word, State>
where
    Word: BitArray + Into<State>,
    State: BitArray + AsPrimitive<Word>,
{
    /// Returns the same slice as `Deref::deref` (via deref coercion of `self`).
    fn as_ref(&self) -> &[Word] {
        self
    }
}
1057
// Round-trip tests for `RangeEncoder` / `RangeDecoder`: encoding nothing, encoding a few
// symbols, encoding many symbols for several `Word`/`State`/`Probability`/`PRECISION`
// combinations, and seeking within compressed data.
#[cfg(test)]
mod tests {
    extern crate std;
    use std::dbg;

    use super::super::model::{
        ContiguousCategoricalEntropyModel, IterableEntropyModel, LeakyQuantizer,
    };
    use super::*;

    use probability::distribution::{Gaussian, Inverse};
    use rand_xoshiro::{
        rand_core::{RngCore, SeedableRng},
        Xoshiro256StarStar,
    };

    // An encoder that never encoded anything must produce no compressed words, and a
    // decoder constructed from these (zero) words must report that it may be exhausted.
    #[test]
    fn compress_none() {
        let encoder = DefaultRangeEncoder::new();
        assert!(encoder.is_empty());
        let compressed = encoder.into_compressed().unwrap();
        assert!(compressed.is_empty());

        let decoder = DefaultRangeDecoder::from_compressed(compressed).unwrap();
        assert!(decoder.maybe_exhausted());
    }

    // The following tests round-trip short symbol sequences; the second argument is the
    // expected number of compressed words.
    #[test]
    fn compress_one() {
        generic_compress_few(core::iter::once(5), 1)
    }

    #[test]
    fn compress_two() {
        generic_compress_few([2, 8].iter().cloned(), 1)
    }

    #[test]
    fn compress_ten() {
        generic_compress_few(0..10, 2)
    }

    #[test]
    fn compress_twenty() {
        generic_compress_few(-10..10, 4)
    }

    // Encodes `symbols` with a fixed quantized Gaussian model, checks that the
    // compressed data consists of exactly `expected_size` words, and checks that
    // decoding reproduces `symbols` and leaves the decoder (possibly) exhausted.
    fn generic_compress_few<I>(symbols: I, expected_size: usize)
    where
        I: IntoIterator<Item = i32>,
        I::IntoIter: Clone,
    {
        let symbols = symbols.into_iter();

        let mut encoder = DefaultRangeEncoder::new();
        let quantizer = LeakyQuantizer::<_, _, u32, 24>::new(-127..=127);
        let model = quantizer.quantize(Gaussian::new(3.2, 5.1));

        encoder.encode_iid_symbols(symbols.clone(), model).unwrap();
        let compressed = encoder.into_compressed().unwrap();
        assert_eq!(compressed.len(), expected_size);

        let mut decoder = DefaultRangeDecoder::from_compressed(&compressed).unwrap();
        for symbol in symbols {
            assert_eq!(decoder.decode_symbol(model).unwrap(), symbol);
        }
        assert!(decoder.maybe_exhausted());
    }

    // The following tests run `generic_compress_many` for various combinations of
    // `<Word, State, Probability, PRECISION>`; the test names spell out the word type,
    // the state type, and the precision.
    #[test]
    fn compress_many_u32_u64_32() {
        generic_compress_many::<u32, u64, u32, 32>();
    }

    #[test]
    fn compress_many_u32_u64_24() {
        generic_compress_many::<u32, u64, u32, 24>();
    }

    #[test]
    fn compress_many_u32_u64_16() {
        generic_compress_many::<u32, u64, u16, 16>();
    }

    #[test]
    fn compress_many_u32_u64_8() {
        generic_compress_many::<u32, u64, u8, 8>();
    }

    #[test]
    fn compress_many_u16_u64_16() {
        generic_compress_many::<u16, u64, u16, 16>();
    }

    #[test]
    fn compress_many_u16_u64_12() {
        generic_compress_many::<u16, u64, u16, 12>();
    }

    #[test]
    fn compress_many_u16_u64_8() {
        generic_compress_many::<u16, u64, u8, 8>();
    }

    #[test]
    fn compress_many_u8_u64_8() {
        generic_compress_many::<u8, u64, u8, 8>();
    }

    #[test]
    fn compress_many_u16_u32_16() {
        generic_compress_many::<u16, u32, u16, 16>();
    }

    #[test]
    fn compress_many_u16_u32_12() {
        generic_compress_many::<u16, u32, u16, 12>();
    }

    #[test]
    fn compress_many_u16_u32_8() {
        generic_compress_many::<u16, u32, u8, 8>();
    }

    #[test]
    fn compress_many_u8_u32_8() {
        generic_compress_many::<u8, u32, u8, 8>();
    }

    #[test]
    fn compress_many_u8_u16_8() {
        generic_compress_many::<u8, u16, u8, 8>();
    }

    // Encodes `AMT` symbols from a fixed categorical model followed by `AMT` symbols
    // from per-symbol quantized Gaussian models into a single encoder, then decodes
    // both parts in the same order and checks that all symbols are reconstructed.
    fn generic_compress_many<Word, State, Probability, const PRECISION: usize>()
    where
        State: BitArray + AsPrimitive<Word>,
        Word: BitArray + Into<State> + AsPrimitive<Probability>,
        Probability: BitArray + Into<Word> + AsPrimitive<usize> + Into<f64>,
        u32: AsPrimitive<Probability>,
        usize: AsPrimitive<Probability>,
        f64: AsPrimitive<Probability>,
        i32: AsPrimitive<Probability>,
    {
        // Use fewer symbols under miri.
        #[cfg(not(miri))]
        const AMT: usize = 1000;

        #[cfg(miri)]
        const AMT: usize = 100;

        let mut symbols_gaussian = Vec::with_capacity(AMT);
        let mut means = Vec::with_capacity(AMT);
        let mut stds = Vec::with_capacity(AMT);

        // Draw a random mean and standard deviation for each symbol, and then draw the
        // symbol itself from the corresponding Gaussian (rounded and clamped to the
        // support `-127..=127` of the quantizer used below). The seed is fixed, so the
        // test is deterministic.
        let mut rng = Xoshiro256StarStar::seed_from_u64(1234);
        for _ in 0..AMT {
            let mean = (200.0 / u32::MAX as f64) * rng.next_u32() as f64 - 100.0;
            let std_dev = (10.0 / u32::MAX as f64) * rng.next_u32() as f64 + 0.001;
            let quantile = (rng.next_u32() as f64 + 0.5) / (1u64 << 32) as f64;
            let dist = Gaussian::new(mean, std_dev);
            let symbol = (dist.inverse(quantile).round() as i32).clamp(-127, 127);

            symbols_gaussian.push(symbol);
            means.push(mean);
            stds.push(std_dev);
        }

        // Unnormalized weights of the categorical model; they include zeros and very
        // small entries.
        let hist = [
            1u32, 186545, 237403, 295700, 361445, 433686, 509456, 586943, 663946, 737772, 1657269,
            896675, 922197, 930672, 916665, 0, 0, 0, 0, 0, 723031, 650522, 572300, 494702, 418703,
            347600, 1, 283500, 226158, 178194, 136301, 103158, 76823, 55540, 39258, 27988, 54269,
        ];
        let categorical_probabilities = hist.iter().map(|&x| x as f64).collect::<Vec<_>>();
        let categorical =
            ContiguousCategoricalEntropyModel::<Probability, _, PRECISION>::from_floating_point_probabilities_fast::<f64>(
                &categorical_probabilities,None
            )
            .unwrap();
        // Draw symbols from the categorical model by feeding random quantiles (masked
        // to `PRECISION` bits) to its quantile function.
        let mut symbols_categorical = Vec::with_capacity(AMT);
        let max_probability = Probability::max_value() >> (Probability::BITS - PRECISION);
        for _ in 0..AMT {
            let quantile = rng.next_u32().as_() & max_probability;
            let symbol = categorical.quantile_function(quantile).0;
            symbols_categorical.push(symbol);
        }

        let mut encoder = RangeEncoder::<Word, State>::new();

        encoder
            .encode_iid_symbols(&symbols_categorical, &categorical)
            .unwrap();
        // Print the compressed size next to the entropy of the categorical model for
        // manual inspection (nothing is asserted about the bit rate).
        dbg!(
            encoder.num_bits(),
            AMT as f64 * categorical.entropy_base2::<f64>()
        );

        // NOTE(review): the closure parameter named `core` below holds a standard
        // deviation; the name shadows the `core` crate within the closure only.
        let quantizer = LeakyQuantizer::<_, _, Probability, PRECISION>::new(-127..=127);
        encoder
            .encode_symbols(symbols_gaussian.iter().zip(&means).zip(&stds).map(
                |((&symbol, &mean), &core)| (symbol, quantizer.quantize(Gaussian::new(mean, core))),
            ))
            .unwrap();
        dbg!(encoder.num_bits());

        let mut decoder = encoder.into_decoder().unwrap();

        // A range coder operates as a queue: decode in the same order as encoded.
        let reconstructed_categorical = decoder
            .decode_iid_symbols(AMT, &categorical)
            .collect::<Result<Vec<_>, _>>()
            .unwrap();
        let reconstructed_gaussian = decoder
            .decode_symbols(
                means
                    .iter()
                    .zip(&stds)
                    .map(|(&mean, &core)| quantizer.quantize(Gaussian::new(mean, core))),
            )
            .collect::<Result<Vec<_>, _>>()
            .unwrap();

        assert!(decoder.maybe_exhausted());

        assert_eq!(symbols_categorical, reconstructed_categorical);
        assert_eq!(symbols_gaussian, reconstructed_gaussian);
    }

    // Encodes several chunks of symbols while recording the encoder's position before
    // each chunk, then checks that a decoder can decode all chunks in sequence and can
    // also jump to the recorded positions in arbitrary order.
    #[test]
    fn seek() {
        // Use a smaller problem size under miri.
        #[cfg(not(miri))]
        let (num_chunks, symbols_per_chunk) = (100, 100);

        #[cfg(miri)]
        let (num_chunks, symbols_per_chunk) = (10, 10);

        let quantizer = LeakyQuantizer::<_, _, u32, 24>::new(-100..=100);
        let model = quantizer.quantize(Gaussian::new(0.0, 10.0));

        let mut encoder = DefaultRangeEncoder::new();

        let mut rng = Xoshiro256StarStar::seed_from_u64(123);
        let mut symbols = Vec::with_capacity(num_chunks);
        let mut jump_table = Vec::with_capacity(num_chunks);

        for _ in 0..num_chunks {
            // Record the position *before* encoding the chunk.
            jump_table.push(encoder.pos());
            let chunk = (0..symbols_per_chunk)
                .map(|_| model.quantile_function(rng.next_u32() % (1 << 24)).0)
                .collect::<Vec<_>>();
            encoder.encode_iid_symbols(&chunk, &model).unwrap();
            symbols.push(chunk);
        }
        let final_pos_and_state = encoder.pos();

        let mut decoder = encoder.decoder();

        // Decode all chunks in sequence, without seeking.
        for (chunk, _) in symbols.iter().zip(&jump_table) {
            let decoded = decoder
                .decode_iid_symbols(symbols_per_chunk, &model)
                .collect::<Result<Vec<_>, _>>()
                .unwrap();
            assert_eq!(&decoded, chunk);
        }
        assert!(decoder.maybe_exhausted());

        // Seek to chunks in random order and decode them.
        for i in 0..100 {
            // Make sure that seeking to the very first chunk is tested at least once.
            let chunk_index = if i == 3 {
                0
            } else {
                rng.next_u32() as usize % num_chunks
            };

            let pos_and_state = jump_table[chunk_index];
            decoder.seek(pos_and_state).unwrap();
            let decoded = decoder
                .decode_iid_symbols(symbols_per_chunk, &model)
                .collect::<Result<Vec<_>, _>>()
                .unwrap();
            assert_eq!(&decoded, &symbols[chunk_index])
        }

        // Seeking to the beginning or to the end must be reflected by
        // `maybe_exhausted`.
        decoder.seek(jump_table[0]).unwrap();
        assert!(!decoder.maybe_exhausted());
        decoder.seek(final_pos_and_state).unwrap();
        assert!(decoder.maybe_exhausted());
    }
}
1350
/// Frontend error type of [`RangeDecoder`]'s `Decode` implementation.
///
/// The enum is marked `#[non_exhaustive]`, so code outside of this crate has to
/// include a wildcard arm when matching on it.
#[derive(Debug)]
#[non_exhaustive]
pub enum DecoderFrontendError {
    /// The compressed data is invalid for the employed entropy model. The decoder
    /// reports this when the quantile it derives from the compressed data does not fit
    /// into the model's precision.
    InvalidData,
}

impl Display for DecoderFrontendError {
    /// Writes a human-readable description of the error.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let message = match self {
            Self::InvalidData => {
                "Tried to decode from compressed data that is invalid for the employed entropy model."
            }
        };

        f.write_str(message)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for DecoderFrontendError {}