1use crate::{EncodeSize, Error, FixedSize, Read, ReadExt, Write};
33use bytes::{Buf, BufMut};
34use core::{fmt::Debug, mem::size_of};
35use sealed::{SPrim, UPrim};
36
/// Width of a byte in bits.
const BITS_PER_BYTE: usize = 8;

/// Value-carrying bits in every encoded byte: all of them except the continuation flag.
const DATA_BITS_PER_BYTE: usize = BITS_PER_BYTE - 1;

/// High bit of an encoded byte; when set, another byte of the same varint follows.
const CONTINUATION_BIT_MASK: u8 = 1 << DATA_BITS_PER_BYTE;

/// Complement of the continuation flag: selects the 7 data bits of an encoded byte.
const DATA_BITS_MASK: u8 = !CONTINUATION_BIT_MASK;
55
56pub const MAX_U16_VARINT_SIZE: usize =
58 (size_of::<u16>() * BITS_PER_BYTE).div_ceil(DATA_BITS_PER_BYTE);
59
60pub const MAX_U32_VARINT_SIZE: usize =
62 (size_of::<u32>() * BITS_PER_BYTE).div_ceil(DATA_BITS_PER_BYTE);
63
64pub const MAX_U64_VARINT_SIZE: usize =
66 (size_of::<u64>() * BITS_PER_BYTE).div_ceil(DATA_BITS_PER_BYTE);
67
68pub const MAX_U128_VARINT_SIZE: usize =
70 (size_of::<u128>() * BITS_PER_BYTE).div_ceil(DATA_BITS_PER_BYTE);
71
72#[derive(Debug, Clone)]
90pub struct Decoder<U: UPrim> {
91 result: U,
92 bits_read: usize,
93}
94
95impl<U: UPrim> Default for Decoder<U> {
96 fn default() -> Self {
97 Self::new()
98 }
99}
100
101impl<U: UPrim> Decoder<U> {
102 #[inline]
104 pub fn new() -> Self {
105 Self {
106 result: U::from(0),
107 bits_read: 0,
108 }
109 }
110
111 #[inline]
118 pub fn feed(&mut self, byte: u8) -> Result<Option<U>, Error> {
119 let max_bits = U::SIZE * BITS_PER_BYTE;
120
121 if byte == 0 && self.bits_read > 0 {
126 return Err(Error::InvalidVarint(U::SIZE));
127 }
128
129 let remaining_bits = max_bits.checked_sub(self.bits_read).unwrap();
137 if remaining_bits <= DATA_BITS_PER_BYTE {
138 let relevant_bits = BITS_PER_BYTE - byte.leading_zeros() as usize;
139 if relevant_bits > remaining_bits {
140 return Err(Error::InvalidVarint(U::SIZE));
141 }
142 }
143
144 self.result |= U::from(byte & DATA_BITS_MASK) << self.bits_read;
146
147 if byte & CONTINUATION_BIT_MASK == 0 {
149 return Ok(Some(self.result));
150 }
151
152 self.bits_read += DATA_BITS_PER_BYTE;
153 Ok(None)
154 }
155}
156
157#[doc(hidden)]
160mod sealed {
161 use super::*;
162 use core::ops::{BitOrAssign, Shl, ShrAssign};
163
164 pub trait UPrim:
166 Copy
167 + From<u8>
168 + Sized
169 + FixedSize
170 + ShrAssign<usize>
171 + Shl<usize, Output = Self>
172 + BitOrAssign<Self>
173 + PartialOrd
174 + Debug
175 {
176 fn leading_zeros(self) -> u32;
178
179 fn as_u8(self) -> u8;
181 }
182
183 macro_rules! impl_uint {
185 ($type:ty) => {
186 impl UPrim for $type {
187 #[inline(always)]
188 fn leading_zeros(self) -> u32 {
189 self.leading_zeros()
190 }
191
192 #[inline(always)]
193 fn as_u8(self) -> u8 {
194 self as u8
195 }
196 }
197 };
198 }
199 impl_uint!(u16);
200 impl_uint!(u32);
201 impl_uint!(u64);
202 impl_uint!(u128);
203
204 pub trait SPrim: Copy + Sized + FixedSize + PartialOrd + Debug {
211 type UnsignedEquivalent: UPrim;
214
215 fn as_zigzag(&self) -> Self::UnsignedEquivalent;
217
218 fn un_zigzag(value: Self::UnsignedEquivalent) -> Self;
220 }
221
222 #[inline(always)]
224 const fn assert_equal_size<T: Sized, U: Sized>() {
225 assert!(
226 size_of::<T>() == size_of::<U>(),
227 "Unsigned integer must be the same size as the signed integer"
228 );
229 }
230
231 macro_rules! impl_sint {
233 ($type:ty, $utype:ty) => {
234 impl SPrim for $type {
235 type UnsignedEquivalent = $utype;
236
237 #[inline]
238 fn as_zigzag(&self) -> $utype {
239 const {
241 assert_equal_size::<$type, $utype>();
242 }
243
244 let shr = size_of::<$utype>() * 8 - 1;
245 ((self << 1) ^ (self >> shr)) as $utype
246 }
247 #[inline]
248 fn un_zigzag(value: $utype) -> Self {
249 const {
251 assert_equal_size::<$type, $utype>();
252 }
253
254 ((value >> 1) as $type) ^ (-((value & 1) as $type))
255 }
256 }
257 };
258 }
259 impl_sint!(i16, u16);
260 impl_sint!(i32, u32);
261 impl_sint!(i64, u64);
262 impl_sint!(i128, u128);
263}
264
265#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
270pub struct UInt<U: UPrim>(pub U);
271
272macro_rules! impl_varuint_into {
275 ($($type:ty),+) => {
276 $(
277 impl From<UInt<$type>> for $type {
278 fn from(val: UInt<$type>) -> Self {
279 val.0
280 }
281 }
282 )+
283 };
284}
285impl_varuint_into!(u16, u32, u64, u128);
286
287impl<U: UPrim> Write for UInt<U> {
288 fn write(&self, buf: &mut impl BufMut) {
289 write(self.0, buf);
290 }
291}
292
293impl<U: UPrim> Read for UInt<U> {
294 type Cfg = ();
295 fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
296 read(buf).map(UInt)
297 }
298}
299
300impl<U: UPrim> EncodeSize for UInt<U> {
301 fn encode_size(&self) -> usize {
302 size(self.0)
303 }
304}
305
#[cfg(feature = "arbitrary")]
impl<U: UPrim> arbitrary::Arbitrary<'_> for UInt<U>
where
    U: for<'a> arbitrary::Arbitrary<'a>,
{
    /// Wraps an arbitrarily generated `U`.
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        U::arbitrary(u).map(UInt)
    }
}
316
317#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
320pub struct SInt<S: SPrim>(pub S);
321
322macro_rules! impl_varsint_into {
325 ($($type:ty),+) => {
326 $(
327 impl From<SInt<$type>> for $type {
328 fn from(val: SInt<$type>) -> Self {
329 val.0
330 }
331 }
332 )+
333 };
334}
335impl_varsint_into!(i16, i32, i64, i128);
336
337impl<S: SPrim> Write for SInt<S> {
338 fn write(&self, buf: &mut impl BufMut) {
339 write_signed::<S>(self.0, buf);
340 }
341}
342
343impl<S: SPrim> Read for SInt<S> {
344 type Cfg = ();
345 fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
346 read_signed::<S>(buf).map(SInt)
347 }
348}
349
350impl<S: SPrim> EncodeSize for SInt<S> {
351 fn encode_size(&self) -> usize {
352 size_signed::<S>(self.0)
353 }
354}
355
#[cfg(feature = "arbitrary")]
impl<S: SPrim> arbitrary::Arbitrary<'_> for SInt<S>
where
    S: for<'a> arbitrary::Arbitrary<'a>,
{
    /// Wraps an arbitrarily generated `S`.
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        S::arbitrary(u).map(SInt)
    }
}
366
367fn write<T: UPrim>(value: T, buf: &mut impl BufMut) {
371 let continuation_threshold = T::from(CONTINUATION_BIT_MASK);
372 if value < continuation_threshold {
373 buf.put_u8(value.as_u8());
376 return;
377 }
378
379 let mut val = value;
380 while val >= continuation_threshold {
381 buf.put_u8((val.as_u8()) | CONTINUATION_BIT_MASK);
382 val >>= 7;
383 }
384 buf.put_u8(val.as_u8());
385}
386
387fn read<T: UPrim>(buf: &mut impl Buf) -> Result<T, Error> {
393 let mut decoder = Decoder::<T>::new();
394 loop {
395 let byte = u8::read(buf)?;
397 if let Some(value) = decoder.feed(byte)? {
398 return Ok(value);
399 }
400 }
401}
402
403fn size<T: UPrim>(value: T) -> usize {
405 let total_bits = size_of::<T>() * 8;
406 let leading_zeros = value.leading_zeros() as usize;
407 let data_bits = total_bits - leading_zeros;
408 usize::max(1, data_bits.div_ceil(DATA_BITS_PER_BYTE))
409}
410
411fn write_signed<S: SPrim>(value: S, buf: &mut impl BufMut) {
413 write(value.as_zigzag(), buf);
414}
415
416fn read_signed<S: SPrim>(buf: &mut impl Buf) -> Result<S, Error> {
418 Ok(S::un_zigzag(read(buf)?))
419}
420
421fn size_signed<S: SPrim>(value: S) -> usize {
423 size(value.as_zigzag())
424}
425
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{error::Error, DecodeExt, Encode};
    use bytes::Bytes;

    /// Truncated input, either empty or ending on a continuation byte, reports
    /// `EndOfBuffer`.
    #[test]
    fn test_end_of_buffer() {
        let mut buf: Bytes = Bytes::from_static(&[]);
        assert!(matches!(read::<u32>(&mut buf), Err(Error::EndOfBuffer)));

        let mut buf: Bytes = Bytes::from_static(&[0x80, 0x8F]);
        assert!(matches!(read::<u32>(&mut buf), Err(Error::EndOfBuffer)));

        let mut buf: Bytes = Bytes::from_static(&[0xFF, 0x8F]);
        assert!(matches!(read::<u32>(&mut buf), Err(Error::EndOfBuffer)));
    }

    /// The final byte may only carry as many bits as remain in the target type.
    #[test]
    fn test_overflow() {
        let mut buf: Bytes = Bytes::from_static(&[0xFF, 0xFF, 0xFF, 0xFF, 0x0F]);
        assert_eq!(read::<u32>(&mut buf).unwrap(), u32::MAX);

        let mut buf: Bytes = Bytes::from_static(&[0xFF, 0xFF, 0xFF, 0xFF, 0x1F]);
        assert!(matches!(
            read::<u32>(&mut buf),
            Err(Error::InvalidVarint(u32::SIZE))
        ));

        let mut buf =
            Bytes::from_static(&[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x02]);
        assert!(matches!(
            read::<u64>(&mut buf),
            Err(Error::InvalidVarint(u64::SIZE))
        ));
    }

    /// A continuation bit on the last possible byte is rejected instead of reading on.
    #[test]
    fn test_overcontinuation() {
        let mut buf: Bytes = Bytes::from_static(&[0x80, 0x80, 0x80, 0x80, 0x80]);
        let result = read::<u32>(&mut buf);
        assert!(matches!(result, Err(Error::InvalidVarint(u32::SIZE))));
    }

    /// A zero byte after the first one is a non-canonical encoding and is rejected.
    #[test]
    fn test_zeroed_byte() {
        let mut buf = Bytes::from_static(&[0xFF, 0x00]);
        let result = read::<u64>(&mut buf);
        assert!(matches!(result, Err(Error::InvalidVarint(u64::SIZE))));
    }

    /// Feeding bytes one at a time yields `None` until the terminating byte arrives.
    #[test]
    fn test_decoder_incremental() {
        // 300 = 0x12C encodes as [0xAC, 0x02]: low 7 bits 0x2C with the flag, then 0x02.
        let mut decoder = Decoder::<u32>::new();
        assert_eq!(decoder.feed(0xAC).unwrap(), None);
        assert_eq!(decoder.feed(0x02).unwrap(), Some(300));

        // A single byte below 0x80 completes immediately.
        let mut decoder = Decoder::<u64>::default();
        assert_eq!(decoder.feed(0x7F).unwrap(), Some(127));
    }

    /// For `u16`, the third byte may carry at most 2 bits: 0x03 is the maximum.
    #[test]
    fn test_decoder_u16_boundary() {
        let mut decoder = Decoder::<u16>::new();
        assert_eq!(decoder.feed(0xFF).unwrap(), None);
        assert_eq!(decoder.feed(0xFF).unwrap(), None);
        assert_eq!(decoder.feed(0x03).unwrap(), Some(u16::MAX));

        let mut decoder = Decoder::<u16>::new();
        assert_eq!(decoder.feed(0xFF).unwrap(), None);
        assert_eq!(decoder.feed(0xFF).unwrap(), None);
        assert!(matches!(
            decoder.feed(0x04),
            Err(Error::InvalidVarint(u16::SIZE))
        ));
    }

    /// Round-trips a fixed set of values (skipping those that do not fit in `T`) through
    /// the raw functions and the `UInt` wrapper, checking `size` against the bytes written.
    fn varuint_round_trip<T: Copy + UPrim + TryFrom<u128>>() {
        const CASES: &[u128] = &[
            0,
            1,
            127,
            128,
            129,
            0xFF,
            0x100,
            0x3FFF,
            0x4000,
            0x1_FFFF,
            0xFF_FFFF,
            0x1_FF_FF_FF_FF,
            0xFF_FF_FF_FF_FF_FF,
            0x1_FF_FF_FF_FF_FF_FF_FF_FF_FF_FF_FF_FF,
            u16::MAX as u128,
            u32::MAX as u128,
            u64::MAX as u128,
            u128::MAX,
        ];

        for &raw in CASES {
            let Ok(value) = raw.try_into() else { continue };
            let value: T = value;

            let mut buf = Vec::new();
            write(value, &mut buf);
            assert_eq!(buf.len(), size(value));

            let mut slice = &buf[..];
            let decoded: T = read(&mut slice).unwrap();
            assert_eq!(decoded, value);
            assert!(slice.is_empty());

            let encoded = UInt(value).encode();
            assert_eq!(UInt::<T>::decode(encoded).unwrap(), UInt(value));
        }
    }

    #[test]
    fn test_varuint() {
        varuint_round_trip::<u16>();
        varuint_round_trip::<u32>();
        varuint_round_trip::<u64>();
        varuint_round_trip::<u128>();
    }

    /// Signed counterpart of `varuint_round_trip`, exercising the zigzag mapping.
    fn varsint_round_trip<T: Copy + SPrim + TryFrom<i128>>() {
        const CASES: &[i128] = &[
            0,
            1,
            -1,
            2,
            -2,
            127,
            -127,
            128,
            -128,
            129,
            -129,
            0x7FFFFFFF,
            -0x7FFFFFFF,
            0x7FFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF,
            -0x7FFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF,
            i16::MIN as i128,
            i16::MAX as i128,
            i32::MIN as i128,
            i32::MAX as i128,
            i64::MIN as i128,
            i64::MAX as i128,
        ];

        for &raw in CASES {
            let Ok(value) = raw.try_into() else { continue };
            let value: T = value;

            let mut buf = Vec::new();
            write_signed(value, &mut buf);
            assert_eq!(buf.len(), size_signed(value));

            let mut slice = &buf[..];
            let decoded: T = read_signed(&mut slice).unwrap();
            assert_eq!(decoded, value);
            assert!(slice.is_empty());

            let encoded = SInt(value).encode();
            assert_eq!(SInt::<T>::decode(encoded).unwrap(), SInt(value));
        }
    }

    #[test]
    fn test_varsint() {
        varsint_round_trip::<i16>();
        varsint_round_trip::<i32>();
        varsint_round_trip::<i64>();
        varsint_round_trip::<i128>();
    }

    #[test]
    fn test_varuint_into() {
        let v32: u32 = 0x1_FFFF;
        let out32: u32 = UInt(v32).into();
        assert_eq!(v32, out32);

        let v64: u64 = 0x1_FF_FF_FF_FF;
        let out64: u64 = UInt(v64).into();
        assert_eq!(v64, out64);
    }

    #[test]
    fn test_varsint_into() {
        let s32: i32 = -123_456;
        let out32: i32 = SInt(s32).into();
        assert_eq!(s32, out32);

        let s64: i64 = 987_654_321;
        let out64: i64 = SInt(s64).into();
        assert_eq!(s64, out64);
    }

    /// Pins the exact byte layout at each 7-bit group boundary.
    #[test]
    fn test_conformity() {
        assert_eq!(0usize.encode(), &[0x00][..]);
        assert_eq!(1usize.encode(), &[0x01][..]);
        assert_eq!(127usize.encode(), &[0x7F][..]);
        assert_eq!(128usize.encode(), &[0x80, 0x01][..]);
        assert_eq!(16383usize.encode(), &[0xFF, 0x7F][..]);
        assert_eq!(16384usize.encode(), &[0x80, 0x80, 0x01][..]);
        assert_eq!(2097151usize.encode(), &[0xFF, 0xFF, 0x7F][..]);
        assert_eq!(2097152usize.encode(), &[0x80, 0x80, 0x80, 0x01][..]);
        assert_eq!(
            (u32::MAX as usize).encode(),
            &[0xFF, 0xFF, 0xFF, 0xFF, 0x0F][..]
        );
    }

    /// Exhaustively checks that `size` matches the written length for every `u16`.
    #[test]
    fn test_all_u16_values() {
        for i in 0..=u16::MAX {
            let value = i;
            let calculated_size = size(value);

            let mut buf = Vec::new();
            write(value, &mut buf);

            assert_eq!(
                buf.len(),
                calculated_size,
                "Size mismatch for u16 value {value}",
            );

            let uint = UInt(value);
            assert_eq!(
                uint.encode_size(),
                buf.len(),
                "UInt encode_size mismatch for value {value}",
            );
        }
    }

    /// Exhaustively checks size and round-trip for every `i16`.
    #[test]
    fn test_all_i16_values() {
        for i in i16::MIN..=i16::MAX {
            let value = i;
            let calculated_size = size_signed(value);

            let mut buf = Vec::new();
            write_signed(value, &mut buf);

            assert_eq!(
                buf.len(),
                calculated_size,
                "Size mismatch for i16 value {value}",
            );

            let sint = SInt(value);
            assert_eq!(
                sint.encode_size(),
                buf.len(),
                "SInt encode_size mismatch for value {value}",
            );

            let mut slice = &buf[..];
            let decoded: i16 = read_signed(&mut slice).unwrap();
            assert_eq!(decoded, value, "Decode mismatch for value {value}");
            assert!(
                slice.is_empty(),
                "Buffer not fully consumed for value {value}",
            );
        }
    }

    /// Values made of exactly `bits` set bits must take ceil(bits / 7) bytes.
    #[test]
    fn test_exact_bit_boundaries() {
        fn test_exact_bits<T: UPrim + TryFrom<u128> + core::fmt::Display>() {
            for bits in 1..=128 {
                let val = if bits == 128 {
                    u128::MAX
                } else {
                    (1u128 << bits) - 1
                };
                let Ok(value) = T::try_from(val) else {
                    continue;
                };

                let expected_size = (bits as usize).div_ceil(DATA_BITS_PER_BYTE);
                let calculated_size = size(value);
                assert_eq!(
                    calculated_size, expected_size,
                    "Size calculation wrong for {val} with {bits} bits",
                );

                let mut buf = Vec::new();
                write(value, &mut buf);
                assert_eq!(
                    buf.len(),
                    expected_size,
                    "Encoded size wrong for {val} with {bits} bits",
                );
            }
        }

        test_exact_bits::<u16>();
        test_exact_bits::<u32>();
        test_exact_bits::<u64>();
        test_exact_bits::<u128>();
    }

    /// A single set bit at position `p` must take ceil((p + 1) / 7) bytes.
    #[test]
    fn test_single_bit_boundaries() {
        fn test_single_bits<T: UPrim + TryFrom<u128> + core::fmt::Display>() {
            for bit_pos in 0..128 {
                let val = 1u128 << bit_pos;
                let Ok(value) = T::try_from(val) else {
                    continue;
                };

                let expected_size = ((bit_pos + 1) as usize).div_ceil(DATA_BITS_PER_BYTE);
                let calculated_size = size(value);
                assert_eq!(
                    calculated_size, expected_size,
                    "Size wrong for 1<<{bit_pos} = {val}",
                );

                let mut buf = Vec::new();
                write(value, &mut buf);
                assert_eq!(
                    buf.len(),
                    expected_size,
                    "Encoded size wrong for 1<<{bit_pos} = {val}",
                );
            }
        }

        test_single_bits::<u16>();
        test_single_bits::<u32>();
        test_single_bits::<u64>();
        test_single_bits::<u128>();
    }

    /// Each type's maximum value must encode to exactly its `MAX_*_VARINT_SIZE` bytes.
    #[test]
    fn test_max_varint_size_constants() {
        let mut buf = Vec::new();

        write(u16::MAX, &mut buf);
        assert_eq!(buf.len(), MAX_U16_VARINT_SIZE);
        buf.clear();

        write(u32::MAX, &mut buf);
        assert_eq!(buf.len(), MAX_U32_VARINT_SIZE);
        buf.clear();

        write(u64::MAX, &mut buf);
        assert_eq!(buf.len(), MAX_U64_VARINT_SIZE);
        buf.clear();

        write(u128::MAX, &mut buf);
        assert_eq!(buf.len(), MAX_U128_VARINT_SIZE);
    }

    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use crate::conformance::CodecConformance;

        commonware_conformance::conformance_tests! {
            CodecConformance<UInt<u16>>,
            CodecConformance<UInt<u32>>,
            CodecConformance<UInt<u64>>,
            CodecConformance<UInt<u128>>,
            CodecConformance<SInt<i16>>,
            CodecConformance<SInt<i32>>,
            CodecConformance<SInt<i64>>,
            CodecConformance<SInt<i128>>,
        }
    }
}