1use crate::{EncodeSize, Error, FixedSize, Read, ReadExt, Write};
33use bytes::{Buf, BufMut};
34use core::{fmt::Debug, mem::size_of};
35use sealed::{SPrim, UPrim};
36
/// Width of a byte, in bits.
const BITS_PER_BYTE: usize = 8;

/// Payload bits carried by every varint byte: one bit of each byte is reserved
/// for the continuation flag.
const DATA_BITS_PER_BYTE: usize = BITS_PER_BYTE - 1;

/// Selects the payload (low seven) bits of a varint byte.
const DATA_BITS_MASK: u8 = !CONTINUATION_BIT_MASK;

/// Selects the continuation flag (the most significant bit) of a varint byte;
/// when set, another byte follows.
const CONTINUATION_BIT_MASK: u8 = 1 << DATA_BITS_PER_BYTE;
55
56#[derive(Debug, Clone)]
74pub struct Decoder<U: UPrim> {
75 result: U,
76 bits_read: usize,
77}
78
79impl<U: UPrim> Default for Decoder<U> {
80 fn default() -> Self {
81 Self::new()
82 }
83}
84
85impl<U: UPrim> Decoder<U> {
86 #[inline]
88 pub fn new() -> Self {
89 Self {
90 result: U::from(0),
91 bits_read: 0,
92 }
93 }
94
95 #[inline]
102 pub fn feed(&mut self, byte: u8) -> Result<Option<U>, Error> {
103 let max_bits = U::SIZE * BITS_PER_BYTE;
104
105 if byte == 0 && self.bits_read > 0 {
110 return Err(Error::InvalidVarint(U::SIZE));
111 }
112
113 let remaining_bits = max_bits.checked_sub(self.bits_read).unwrap();
121 if remaining_bits <= DATA_BITS_PER_BYTE {
122 let relevant_bits = BITS_PER_BYTE - byte.leading_zeros() as usize;
123 if relevant_bits > remaining_bits {
124 return Err(Error::InvalidVarint(U::SIZE));
125 }
126 }
127
128 self.result |= U::from(byte & DATA_BITS_MASK) << self.bits_read;
130
131 if byte & CONTINUATION_BIT_MASK == 0 {
133 return Ok(Some(self.result));
134 }
135
136 self.bits_read += DATA_BITS_PER_BYTE;
137 Ok(None)
138 }
139}
140
141#[doc(hidden)]
144mod sealed {
145 use super::*;
146 use core::ops::{BitOrAssign, Shl, ShrAssign};
147
148 pub trait UPrim:
150 Copy
151 + From<u8>
152 + Sized
153 + FixedSize
154 + ShrAssign<usize>
155 + Shl<usize, Output = Self>
156 + BitOrAssign<Self>
157 + PartialOrd
158 + Debug
159 {
160 fn leading_zeros(self) -> u32;
162
163 fn as_u8(self) -> u8;
165 }
166
167 macro_rules! impl_uint {
169 ($type:ty) => {
170 impl UPrim for $type {
171 #[inline(always)]
172 fn leading_zeros(self) -> u32 {
173 self.leading_zeros()
174 }
175
176 #[inline(always)]
177 fn as_u8(self) -> u8 {
178 self as u8
179 }
180 }
181 };
182 }
183 impl_uint!(u16);
184 impl_uint!(u32);
185 impl_uint!(u64);
186 impl_uint!(u128);
187
188 pub trait SPrim: Copy + Sized + FixedSize + PartialOrd + Debug {
195 type UnsignedEquivalent: UPrim;
198
199 fn as_zigzag(&self) -> Self::UnsignedEquivalent;
201
202 fn un_zigzag(value: Self::UnsignedEquivalent) -> Self;
204 }
205
206 #[inline(always)]
208 const fn assert_equal_size<T: Sized, U: Sized>() {
209 assert!(
210 size_of::<T>() == size_of::<U>(),
211 "Unsigned integer must be the same size as the signed integer"
212 );
213 }
214
215 macro_rules! impl_sint {
217 ($type:ty, $utype:ty) => {
218 impl SPrim for $type {
219 type UnsignedEquivalent = $utype;
220
221 #[inline]
222 fn as_zigzag(&self) -> $utype {
223 const {
225 assert_equal_size::<$type, $utype>();
226 }
227
228 let shr = size_of::<$utype>() * 8 - 1;
229 ((self << 1) ^ (self >> shr)) as $utype
230 }
231 #[inline]
232 fn un_zigzag(value: $utype) -> Self {
233 const {
235 assert_equal_size::<$type, $utype>();
236 }
237
238 ((value >> 1) as $type) ^ (-((value & 1) as $type))
239 }
240 }
241 };
242 }
243 impl_sint!(i16, u16);
244 impl_sint!(i32, u32);
245 impl_sint!(i64, u64);
246 impl_sint!(i128, u128);
247}
248
249#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
254pub struct UInt<U: UPrim>(pub U);
255
256macro_rules! impl_varuint_into {
259 ($($type:ty),+) => {
260 $(
261 impl From<UInt<$type>> for $type {
262 fn from(val: UInt<$type>) -> Self {
263 val.0
264 }
265 }
266 )+
267 };
268}
269impl_varuint_into!(u16, u32, u64, u128);
270
271impl<U: UPrim> Write for UInt<U> {
272 fn write(&self, buf: &mut impl BufMut) {
273 write(self.0, buf);
274 }
275}
276
277impl<U: UPrim> Read for UInt<U> {
278 type Cfg = ();
279 fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
280 read(buf).map(UInt)
281 }
282}
283
284impl<U: UPrim> EncodeSize for UInt<U> {
285 fn encode_size(&self) -> usize {
286 size(self.0)
287 }
288}
289
290#[cfg(feature = "arbitrary")]
291impl<U: UPrim> arbitrary::Arbitrary<'_> for UInt<U>
292where
293 U: for<'a> arbitrary::Arbitrary<'a>,
294{
295 fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
296 let value = U::arbitrary(u)?;
297 Ok(Self(value))
298 }
299}
300
301#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
304pub struct SInt<S: SPrim>(pub S);
305
306macro_rules! impl_varsint_into {
309 ($($type:ty),+) => {
310 $(
311 impl From<SInt<$type>> for $type {
312 fn from(val: SInt<$type>) -> Self {
313 val.0
314 }
315 }
316 )+
317 };
318}
319impl_varsint_into!(i16, i32, i64, i128);
320
321impl<S: SPrim> Write for SInt<S> {
322 fn write(&self, buf: &mut impl BufMut) {
323 write_signed::<S>(self.0, buf);
324 }
325}
326
327impl<S: SPrim> Read for SInt<S> {
328 type Cfg = ();
329 fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
330 read_signed::<S>(buf).map(SInt)
331 }
332}
333
334impl<S: SPrim> EncodeSize for SInt<S> {
335 fn encode_size(&self) -> usize {
336 size_signed::<S>(self.0)
337 }
338}
339
340#[cfg(feature = "arbitrary")]
341impl<S: SPrim> arbitrary::Arbitrary<'_> for SInt<S>
342where
343 S: for<'a> arbitrary::Arbitrary<'a>,
344{
345 fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
346 let value = S::arbitrary(u)?;
347 Ok(Self(value))
348 }
349}
350
351fn write<T: UPrim>(value: T, buf: &mut impl BufMut) {
355 let continuation_threshold = T::from(CONTINUATION_BIT_MASK);
356 if value < continuation_threshold {
357 buf.put_u8(value.as_u8());
360 return;
361 }
362
363 let mut val = value;
364 while val >= continuation_threshold {
365 buf.put_u8((val.as_u8()) | CONTINUATION_BIT_MASK);
366 val >>= 7;
367 }
368 buf.put_u8(val.as_u8());
369}
370
371fn read<T: UPrim>(buf: &mut impl Buf) -> Result<T, Error> {
377 let mut decoder = Decoder::<T>::new();
378 loop {
379 let byte = u8::read(buf)?;
381 if let Some(value) = decoder.feed(byte)? {
382 return Ok(value);
383 }
384 }
385}
386
387fn size<T: UPrim>(value: T) -> usize {
389 let total_bits = size_of::<T>() * 8;
390 let leading_zeros = value.leading_zeros() as usize;
391 let data_bits = total_bits - leading_zeros;
392 usize::max(1, data_bits.div_ceil(DATA_BITS_PER_BYTE))
393}
394
395fn write_signed<S: SPrim>(value: S, buf: &mut impl BufMut) {
397 write(value.as_zigzag(), buf);
398}
399
400fn read_signed<S: SPrim>(buf: &mut impl Buf) -> Result<S, Error> {
402 Ok(S::un_zigzag(read(buf)?))
403}
404
405fn size_signed<S: SPrim>(value: S) -> usize {
407 size(value.as_zigzag())
408}
409
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{error::Error, DecodeExt, Encode};
    #[cfg(not(feature = "std"))]
    use alloc::vec::Vec;
    use bytes::Bytes;

    /// Truncated input must surface `EndOfBuffer` rather than a partial value.
    #[test]
    fn test_end_of_buffer() {
        // No bytes at all.
        let mut buf: Bytes = Bytes::from_static(&[]);
        assert!(matches!(read::<u32>(&mut buf), Err(Error::EndOfBuffer)));

        // Every byte has the continuation bit set, but the buffer ends.
        let mut buf: Bytes = Bytes::from_static(&[0x80, 0x8F]);
        assert!(matches!(read::<u32>(&mut buf), Err(Error::EndOfBuffer)));

        let mut buf: Bytes = Bytes::from_static(&[0xFF, 0x8F]);
        assert!(matches!(read::<u32>(&mut buf), Err(Error::EndOfBuffer)));
    }

    /// Final bytes carrying more bits than the target type holds are rejected.
    #[test]
    fn test_overflow() {
        // 4 * 7 = 28 bits plus 4 bits from 0x0F is exactly 32 bits: u32::MAX.
        let mut buf: Bytes = Bytes::from_static(&[0xFF, 0xFF, 0xFF, 0xFF, 0x0F]);
        assert_eq!(read::<u32>(&mut buf).unwrap(), u32::MAX);

        // 0x1F in the last byte would set a 33rd bit.
        let mut buf: Bytes = Bytes::from_static(&[0xFF, 0xFF, 0xFF, 0xFF, 0x1F]);
        assert!(matches!(
            read::<u32>(&mut buf),
            Err(Error::InvalidVarint(u32::SIZE))
        ));

        // After 9 bytes (63 bits) only 1 bit remains for u64; 0x02 needs 2.
        let mut buf =
            Bytes::from_static(&[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x02]);
        assert!(matches!(
            read::<u64>(&mut buf),
            Err(Error::InvalidVarint(u64::SIZE))
        ));
    }

    /// A continuation bit on the last byte a type can hold is rejected.
    #[test]
    fn test_overcontinuation() {
        let mut buf: Bytes = Bytes::from_static(&[0x80, 0x80, 0x80, 0x80, 0x80]);
        let result = read::<u32>(&mut buf);
        assert!(matches!(result, Err(Error::InvalidVarint(u32::SIZE))));
    }

    /// A zero byte after a continuation byte is a non-canonical encoding.
    #[test]
    fn test_zeroed_byte() {
        let mut buf = Bytes::from_static(&[0xFF, 0x00]);
        let result = read::<u64>(&mut buf);
        assert!(matches!(result, Err(Error::InvalidVarint(u64::SIZE))));
    }

    /// Round-trips representative values through the free functions and through
    /// `UInt`'s codec impls. Values that do not fit in `T` are skipped.
    fn varuint_round_trip<T: Copy + UPrim + TryFrom<u128>>() {
        const CASES: &[u128] = &[
            0,
            1,
            127,
            128,
            129,
            0xFF,
            0x100,
            0x3FFF,
            0x4000,
            0x1_FFFF,
            0xFF_FFFF,
            0x1_FF_FF_FF_FF,
            0xFF_FF_FF_FF_FF_FF,
            0x1_FF_FF_FF_FF_FF_FF_FF_FF_FF_FF_FF_FF,
            u16::MAX as u128,
            u32::MAX as u128,
            u64::MAX as u128,
            u128::MAX,
        ];

        for &raw in CASES {
            // Skip cases that are out of range for `T`.
            let Ok(value) = raw.try_into() else { continue };
            let value: T = value;

            let mut buf = Vec::new();
            // The predicted size must match the bytes actually written.
            write(value, &mut buf);
            assert_eq!(buf.len(), size(value));

            let mut slice = &buf[..];
            // Decoding must recover the value and consume exactly the encoding.
            let decoded: T = read(&mut slice).unwrap();
            assert_eq!(decoded, value);
            assert!(slice.is_empty());

            let encoded = UInt(value).encode();
            // The same round trip through the public `Encode`/`Decode` path.
            assert_eq!(UInt::<T>::decode(encoded).unwrap(), UInt(value));
        }
    }

    #[test]
    fn test_varuint() {
        varuint_round_trip::<u16>();
        varuint_round_trip::<u32>();
        varuint_round_trip::<u64>();
        varuint_round_trip::<u128>();
    }

    /// Signed counterpart of `varuint_round_trip`, exercising ZigZag mapping.
    fn varsint_round_trip<T: Copy + SPrim + TryFrom<i128>>() {
        const CASES: &[i128] = &[
            0,
            1,
            -1,
            2,
            -2,
            127,
            -127,
            128,
            -128,
            129,
            -129,
            0x7FFFFFFF,
            -0x7FFFFFFF,
            0x7FFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF,
            -0x7FFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF,
            i16::MIN as i128,
            i16::MAX as i128,
            i32::MIN as i128,
            i32::MAX as i128,
            i64::MIN as i128,
            i64::MAX as i128,
        ];

        for &raw in CASES {
            // Skip cases that are out of range for `T`.
            let Ok(value) = raw.try_into() else { continue };
            let value: T = value;

            let mut buf = Vec::new();
            // The predicted size must match the bytes actually written.
            write_signed(value, &mut buf);
            assert_eq!(buf.len(), size_signed(value));

            let mut slice = &buf[..];
            // Decoding must recover the value and consume exactly the encoding.
            let decoded: T = read_signed(&mut slice).unwrap();
            assert_eq!(decoded, value);
            assert!(slice.is_empty());

            let encoded = SInt(value).encode();
            // The same round trip through the public `Encode`/`Decode` path.
            assert_eq!(SInt::<T>::decode(encoded).unwrap(), SInt(value));
        }
    }

    #[test]
    fn test_varsint() {
        varsint_round_trip::<i16>();
        varsint_round_trip::<i32>();
        varsint_round_trip::<i64>();
        varsint_round_trip::<i128>();
    }

    #[test]
    fn test_varuint_into() {
        let v32: u32 = 0x1_FFFF;
        let out32: u32 = UInt(v32).into();
        assert_eq!(v32, out32);

        let v64: u64 = 0x1_FF_FF_FF_FF;
        let out64: u64 = UInt(v64).into();
        assert_eq!(v64, out64);
    }

    #[test]
    fn test_varsint_into() {
        let s32: i32 = -123_456;
        let out32: i32 = SInt(s32).into();
        assert_eq!(s32, out32);

        let s64: i64 = 987_654_321;
        let out64: i64 = SInt(s64).into();
        assert_eq!(s64, out64);
    }

    /// Pins exact wire bytes at each 7-bit group boundary (via `usize`'s `Encode`).
    #[test]
    fn test_conformity() {
        assert_eq!(0usize.encode(), &[0x00][..]);
        assert_eq!(1usize.encode(), &[0x01][..]);
        assert_eq!(127usize.encode(), &[0x7F][..]);
        assert_eq!(128usize.encode(), &[0x80, 0x01][..]);
        assert_eq!(16383usize.encode(), &[0xFF, 0x7F][..]);
        assert_eq!(16384usize.encode(), &[0x80, 0x80, 0x01][..]);
        assert_eq!(2097151usize.encode(), &[0xFF, 0xFF, 0x7F][..]);
        assert_eq!(2097152usize.encode(), &[0x80, 0x80, 0x80, 0x01][..]);
        assert_eq!(
            (u32::MAX as usize).encode(),
            &[0xFF, 0xFF, 0xFF, 0xFF, 0x0F][..]
        );
    }

    /// Exhaustively checks `size` against `write` for every u16 value.
    #[test]
    fn test_all_u16_values() {
        for i in 0..=u16::MAX {
            let value = i;
            let calculated_size = size(value);

            let mut buf = Vec::new();
            write(value, &mut buf);

            assert_eq!(
                buf.len(),
                calculated_size,
                "Size mismatch for u16 value {value}",
            );

            let uint = UInt(value);
            // The `EncodeSize` impl must agree as well.
            assert_eq!(
                uint.encode_size(),
                buf.len(),
                "UInt encode_size mismatch for value {value}",
            );
        }
    }

    /// Exhaustively checks size and round-trip for every i16 value.
    #[test]
    fn test_all_i16_values() {
        for i in i16::MIN..=i16::MAX {
            let value = i;
            let calculated_size = size_signed(value);

            let mut buf = Vec::new();
            write_signed(value, &mut buf);

            assert_eq!(
                buf.len(),
                calculated_size,
                "Size mismatch for i16 value {value}",
            );

            let sint = SInt(value);
            // The `EncodeSize` impl must agree as well.
            assert_eq!(
                sint.encode_size(),
                buf.len(),
                "SInt encode_size mismatch for value {value}",
            );

            let mut slice = &buf[..];
            // Decode back and require that the whole encoding was consumed.
            let decoded: i16 = read_signed(&mut slice).unwrap();
            assert_eq!(decoded, value, "Decode mismatch for value {value}");
            assert!(
                slice.is_empty(),
                "Buffer not fully consumed for value {value}",
            );
        }
    }

    /// Values with exactly `bits` low bits set need `ceil(bits / 7)` bytes.
    #[test]
    fn test_exact_bit_boundaries() {
        fn test_exact_bits<T: UPrim + TryFrom<u128> + core::fmt::Display>() {
            // Values that do not fit in `T` are skipped below.
            for bits in 1..=128 {
                let val = if bits == 128 {
                    // `1u128 << 128` would overflow, so use the all-ones value directly.
                    u128::MAX
                } else {
                    (1u128 << bits) - 1
                };
                let Ok(value) = T::try_from(val) else {
                    continue;
                };

                let expected_size = (bits as usize).div_ceil(DATA_BITS_PER_BYTE);
                // The predicted size and the written size must both match.
                let calculated_size = size(value);
                assert_eq!(
                    calculated_size, expected_size,
                    "Size calculation wrong for {val} with {bits} bits",
                );

                let mut buf = Vec::new();
                // Encode and compare the byte count.
                write(value, &mut buf);
                assert_eq!(
                    buf.len(),
                    expected_size,
                    "Encoded size wrong for {val} with {bits} bits",
                );
            }
        }

        test_exact_bits::<u16>();
        test_exact_bits::<u32>();
        test_exact_bits::<u64>();
        test_exact_bits::<u128>();
    }

    /// A single bit at position `p` needs `ceil((p + 1) / 7)` bytes.
    #[test]
    fn test_single_bit_boundaries() {
        fn test_single_bits<T: UPrim + TryFrom<u128> + core::fmt::Display>() {
            // Bit positions beyond `T`'s width are skipped below.
            for bit_pos in 0..128 {
                let val = 1u128 << bit_pos;
                // Skip values that are out of range for `T`.
                let Ok(value) = T::try_from(val) else {
                    continue;
                };

                let expected_size = ((bit_pos + 1) as usize).div_ceil(DATA_BITS_PER_BYTE);
                // The predicted size and the written size must both match.
                let calculated_size = size(value);
                assert_eq!(
                    calculated_size, expected_size,
                    "Size wrong for 1<<{bit_pos} = {val}",
                );

                let mut buf = Vec::new();
                // Encode and compare the byte count.
                write(value, &mut buf);
                assert_eq!(
                    buf.len(),
                    expected_size,
                    "Encoded size wrong for 1<<{bit_pos} = {val}",
                );
            }
        }

        test_single_bits::<u16>();
        test_single_bits::<u32>();
        test_single_bits::<u64>();
        test_single_bits::<u128>();
    }

    // Shared codec conformance suite, run for every wrapper instantiation.
    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use crate::conformance::CodecConformance;

        commonware_conformance::conformance_tests! {
            CodecConformance<UInt<u16>>,
            CodecConformance<UInt<u32>>,
            CodecConformance<UInt<u64>>,
            CodecConformance<UInt<u128>>,
            CodecConformance<SInt<i16>>,
            CodecConformance<SInt<i32>>,
            CodecConformance<SInt<i64>>,
            CodecConformance<SInt<i128>>,
        }
    }
}
763}