commonware_codec/
varint.rs

1//! Variable-length integer encoding and decoding
2//!
3//! # Overview
4//!
5//! This module implements Google's Protocol Buffers variable-length integer encoding.
6//! Each byte uses:
7//! - 7 bits for the value
8//! - 1 "continuation" bit to indicate if more bytes follow
9//!
10//! `u8` and `i8` are omitted since those types do not benefit from varint encoding.
11//!
12//! `usize` and `isize` are omitted to prevent behavior from depending on the target architecture.
13//!
14//! # Usage Example
15//!
16//! ```rust
17//! use commonware_codec::{Encode, DecodeExt, varint::{UInt, SInt}};
18//!
19//! // Unsigned example
20//! let one = UInt(42u128).encode();
21//! assert_eq!(one.len(), 1); // 42 fits in a single byte
22//! let decoded: u128 = UInt::decode(one).unwrap().into();
23//! assert_eq!(decoded, 42);
24//!
25//! // Signed example (ZigZag)
26//! let neg = SInt(-3i32).encode();
27//! assert_eq!(neg.len(), 1);
28//! let decoded: i32 = SInt::decode(neg).unwrap().into();
29//! assert_eq!(decoded, -3);
30//! ```
31
32use crate::{EncodeSize, Error, FixedSize, Read, ReadExt, Write};
33use bytes::{Buf, BufMut};
34use core::{fmt::Debug, mem::size_of};
35use sealed::{SPrim, UPrim};
36
37// ---------- Constants ----------
38
/// Number of bits in a single byte.
const BITS_PER_BYTE: usize = u8::BITS as usize;

/// Number of payload bits carried by each encoded byte: every bit except the continuation flag.
const DATA_BITS_PER_BYTE: usize = BITS_PER_BYTE - 1;

/// Selects the payload (low seven) bits of an encoded byte.
const DATA_BITS_MASK: u8 = 0b0111_1111;

/// Selects the continuation flag of an encoded byte, which is its most-significant bit.
const CONTINUATION_BIT_MASK: u8 = 0b1000_0000;
51
52// ---------- Traits ----------
53
54#[doc(hidden)]
55mod sealed {
56    use super::*;
57    use core::ops::{BitOrAssign, Shl, ShrAssign};
58
59    /// A trait for unsigned integer primitives that can be varint encoded.
60    pub trait UPrim:
61        Copy
62        + From<u8>
63        + Sized
64        + FixedSize
65        + ShrAssign<usize>
66        + Shl<usize, Output = Self>
67        + BitOrAssign<Self>
68        + PartialOrd
69        + Debug
70    {
71        /// Returns the number of leading zeros in the integer.
72        fn leading_zeros(self) -> u32;
73
74        /// Returns the least significant byte of the integer.
75        fn as_u8(self) -> u8;
76    }
77
78    // Implements the `UPrim` trait for all unsigned integer types.
79    macro_rules! impl_uint {
80        ($type:ty) => {
81            impl UPrim for $type {
82                #[inline(always)]
83                fn leading_zeros(self) -> u32 {
84                    self.leading_zeros()
85                }
86
87                #[inline(always)]
88                fn as_u8(self) -> u8 {
89                    self as u8
90                }
91            }
92        };
93    }
94    impl_uint!(u16);
95    impl_uint!(u32);
96    impl_uint!(u64);
97    impl_uint!(u128);
98
99    /// A trait for signed integer primitives that can be converted to and from unsigned integer
100    /// primitives of the equivalent size.
101    ///
102    /// When converted to unsigned integers, the encoding is done using ZigZag encoding, which moves the
103    /// sign bit to the least significant bit (shifting all other bits to the left by one). This allows
104    /// for more efficient encoding of numbers that are close to zero, even if they are negative.
105    pub trait SPrim: Copy + Sized + FixedSize + PartialOrd + Debug {
106        /// The unsigned equivalent type of the signed integer.
107        /// This type must be the same size as the signed integer type.
108        type UnsignedEquivalent: UPrim;
109
110        /// Converts the signed integer to an unsigned integer using ZigZag encoding.
111        fn as_zigzag(&self) -> Self::UnsignedEquivalent;
112
113        /// Converts a (ZigZag'ed) unsigned integer back to a signed integer.
114        fn un_zigzag(value: Self::UnsignedEquivalent) -> Self;
115    }
116
117    /// Asserts that two types are the same size at compile time.
118    #[inline(always)]
119    const fn assert_equal_size<T: Sized, U: Sized>() {
120        assert!(
121            size_of::<T>() == size_of::<U>(),
122            "Unsigned integer must be the same size as the signed integer"
123        );
124    }
125
126    // Implements the `SPrim` trait for all signed integer types.
127    macro_rules! impl_sint {
128        ($type:ty, $utype:ty) => {
129            impl SPrim for $type {
130                type UnsignedEquivalent = $utype;
131
132                #[inline]
133                fn as_zigzag(&self) -> $utype {
134                    // TODO: Re-evaluate assertion placement after `generic_const_exprs` is stable.
135                    const {
136                        assert_equal_size::<$type, $utype>();
137                    }
138
139                    let shr = size_of::<$utype>() * 8 - 1;
140                    ((self << 1) ^ (self >> shr)) as $utype
141                }
142                #[inline]
143                fn un_zigzag(value: $utype) -> Self {
144                    // TODO: Re-evaluate assertion placement after `generic_const_exprs` is stable.
145                    const {
146                        assert_equal_size::<$type, $utype>();
147                    }
148
149                    ((value >> 1) as $type) ^ (-((value & 1) as $type))
150                }
151            }
152        };
153    }
154    impl_sint!(i16, u16);
155    impl_sint!(i32, u32);
156    impl_sint!(i64, u64);
157    impl_sint!(i128, u128);
158}
159
160// ---------- Structs ----------
161
162/// An ergonomic wrapper to allow for encoding and decoding of primitive unsigned integers as
163/// varints rather than the default fixed-width integers.
164#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
165pub struct UInt<U: UPrim>(pub U);
166
167// Implements `Into<U>` for `UInt<U>` for all unsigned integer types.
168// This allows for easy conversion from `UInt<U>` to `U` using `.into()`.
169macro_rules! impl_varuint_into {
170    ($($type:ty),+) => {
171        $(
172            impl From<UInt<$type>> for $type {
173                fn from(val: UInt<$type>) -> Self {
174                    val.0
175                }
176            }
177        )+
178    };
179}
180impl_varuint_into!(u16, u32, u64, u128);
181
182impl<U: UPrim> Write for UInt<U> {
183    fn write(&self, buf: &mut impl BufMut) {
184        write(self.0, buf);
185    }
186}
187
188impl<U: UPrim> Read for UInt<U> {
189    type Cfg = ();
190    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
191        read(buf).map(UInt)
192    }
193}
194
195impl<U: UPrim> EncodeSize for UInt<U> {
196    fn encode_size(&self) -> usize {
197        size(self.0)
198    }
199}
200
201/// An ergonomic wrapper to allow for encoding and decoding of primitive signed integers as
202/// varints rather than the default fixed-width integers.
203#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
204pub struct SInt<S: SPrim>(pub S);
205
206// Implements `Into<U>` for `SInt<U>` for all signed integer types.
207// This allows for easy conversion from `SInt<S>` to `S` using `.into()`.
208macro_rules! impl_varsint_into {
209    ($($type:ty),+) => {
210        $(
211            impl From<SInt<$type>> for $type {
212                fn from(val: SInt<$type>) -> Self {
213                    val.0
214                }
215            }
216        )+
217    };
218}
219impl_varsint_into!(i16, i32, i64, i128);
220
221impl<S: SPrim> Write for SInt<S> {
222    fn write(&self, buf: &mut impl BufMut) {
223        write_signed::<S>(self.0, buf);
224    }
225}
226
227impl<S: SPrim> Read for SInt<S> {
228    type Cfg = ();
229    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
230        read_signed::<S>(buf).map(SInt)
231    }
232}
233
234impl<S: SPrim> EncodeSize for SInt<S> {
235    fn encode_size(&self) -> usize {
236        size_signed::<S>(self.0)
237    }
238}
239
240// ---------- Helper Functions ----------
241
242/// Encodes an unsigned integer as a varint
243fn write<T: UPrim>(value: T, buf: &mut impl BufMut) {
244    let continuation_threshold = T::from(CONTINUATION_BIT_MASK);
245    if value < continuation_threshold {
246        // Fast path for small values (common case for lengths).
247        // `as_u8()` does not truncate the value or leave a continuation bit.
248        buf.put_u8(value.as_u8());
249        return;
250    }
251
252    let mut val = value;
253    while val >= continuation_threshold {
254        buf.put_u8((val.as_u8()) | CONTINUATION_BIT_MASK);
255        val >>= 7;
256    }
257    buf.put_u8(val.as_u8());
258}
259
260/// Decodes a unsigned integer from a varint.
261///
262/// Returns an error if:
263/// - The varint is invalid (too long or malformed)
264/// - The buffer ends while reading
265fn read<T: UPrim>(buf: &mut impl Buf) -> Result<T, Error> {
266    let max_bits = T::SIZE * BITS_PER_BYTE;
267    let mut result: T = T::from(0);
268    let mut bits_read = 0;
269
270    // Loop over all the bytes.
271    loop {
272        // Read the next byte.
273        let byte = u8::read(buf)?;
274
275        // If this is not the first byte, but the byte is completely zero, we have an invalid
276        // varint. This is because this byte has no data bits and no continuation, so there was no
277        // point in continuing to this byte in the first place. While the output could still result
278        // in a valid value, we ensure that every value has exactly one unique, valid encoding.
279        if byte == 0 && bits_read > 0 {
280            return Err(Error::InvalidVarint(T::SIZE));
281        }
282
283        // If this must be the last byte, check for overflow (i.e. set bits beyond the size of T).
284        // Because the continuation bit is the most-significant bit, this check also happens to
285        // check for an invalid continuation bit.
286        //
287        // If we have reached what must be the last byte, this check prevents continuing to read
288        // from the buffer by ensuring that the conditional (`if byte & CONTINUATION_BIT_MASK == 0`)
289        // always evaluates to true.
290        let remaining_bits = max_bits.checked_sub(bits_read).unwrap();
291        if remaining_bits <= DATA_BITS_PER_BYTE {
292            let relevant_bits = BITS_PER_BYTE - byte.leading_zeros() as usize;
293            if relevant_bits > remaining_bits {
294                return Err(Error::InvalidVarint(T::SIZE));
295            }
296        }
297
298        // Write the 7 bits of data to the result.
299        result |= T::from(byte & DATA_BITS_MASK) << bits_read;
300
301        // If the continuation bit is not set, return.
302        if byte & CONTINUATION_BIT_MASK == 0 {
303            return Ok(result);
304        }
305
306        bits_read += DATA_BITS_PER_BYTE;
307    }
308}
309
310/// Calculates the number of bytes needed to encode an unsigned integer as a varint.
311fn size<T: UPrim>(value: T) -> usize {
312    let total_bits = size_of::<T>() * 8;
313    let leading_zeros = value.leading_zeros() as usize;
314    let data_bits = total_bits - leading_zeros;
315    usize::max(1, data_bits.div_ceil(DATA_BITS_PER_BYTE))
316}
317
318/// Encodes a signed integer as a varint using ZigZag encoding.
319fn write_signed<S: SPrim>(value: S, buf: &mut impl BufMut) {
320    write(value.as_zigzag(), buf);
321}
322
323/// Decodes a signed integer from varint ZigZag encoding.
324fn read_signed<S: SPrim>(buf: &mut impl Buf) -> Result<S, Error> {
325    Ok(S::un_zigzag(read(buf)?))
326}
327
328/// Calculates the number of bytes needed to encode a signed integer as a varint.
329fn size_signed<S: SPrim>(value: S) -> usize {
330    size(value.as_zigzag())
331}
332
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{error::Error, DecodeExt, Encode};
    #[cfg(not(feature = "std"))]
    use alloc::vec::Vec;
    use bytes::Bytes;

    #[test]
    fn test_end_of_buffer() {
        let mut buf: Bytes = Bytes::from_static(&[]);
        assert!(matches!(read::<u32>(&mut buf), Err(Error::EndOfBuffer)));

        let mut buf: Bytes = Bytes::from_static(&[0x80, 0x8F]);
        assert!(matches!(read::<u32>(&mut buf), Err(Error::EndOfBuffer)));

        let mut buf: Bytes = Bytes::from_static(&[0xFF, 0x8F]);
        assert!(matches!(read::<u32>(&mut buf), Err(Error::EndOfBuffer)));
    }

    #[test]
    fn test_overflow() {
        let mut buf: Bytes = Bytes::from_static(&[0xFF, 0xFF, 0xFF, 0xFF, 0x0F]);
        assert_eq!(read::<u32>(&mut buf).unwrap(), u32::MAX);

        let mut buf: Bytes = Bytes::from_static(&[0xFF, 0xFF, 0xFF, 0xFF, 0x1F]);
        assert!(matches!(
            read::<u32>(&mut buf),
            Err(Error::InvalidVarint(u32::SIZE))
        ));

        let mut buf =
            Bytes::from_static(&[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x02]);
        assert!(matches!(
            read::<u64>(&mut buf),
            Err(Error::InvalidVarint(u64::SIZE))
        ));

        // u16: the third byte may only carry the top two bits.
        let mut buf: Bytes = Bytes::from_static(&[0xFF, 0xFF, 0x03]);
        assert_eq!(read::<u16>(&mut buf).unwrap(), u16::MAX);

        let mut buf: Bytes = Bytes::from_static(&[0xFF, 0xFF, 0x04]);
        assert!(matches!(
            read::<u16>(&mut buf),
            Err(Error::InvalidVarint(u16::SIZE))
        ));
    }

    #[test]
    fn test_overcontinuation() {
        let mut buf: Bytes = Bytes::from_static(&[0x80, 0x80, 0x80, 0x80, 0x80]);
        let result = read::<u32>(&mut buf);
        assert!(matches!(result, Err(Error::InvalidVarint(u32::SIZE))));
    }

    #[test]
    fn test_zeroed_byte() {
        let mut buf = Bytes::from_static(&[0xFF, 0x00]);
        let result = read::<u64>(&mut buf);
        assert!(matches!(result, Err(Error::InvalidVarint(u64::SIZE))));

        // A lone zero byte is the canonical encoding of zero...
        let mut buf = Bytes::from_static(&[0x00]);
        assert_eq!(read::<u32>(&mut buf).unwrap(), 0);

        // ...but a zero byte after a continuation is a non-canonical (padded) encoding.
        let mut buf = Bytes::from_static(&[0x80, 0x00]);
        assert!(matches!(
            read::<u32>(&mut buf),
            Err(Error::InvalidVarint(u32::SIZE))
        ));

        let mut buf = Bytes::from_static(&[0x81, 0x00]);
        assert!(matches!(
            read::<u32>(&mut buf),
            Err(Error::InvalidVarint(u32::SIZE))
        ));
    }

    #[test]
    fn test_stops_at_final_byte() {
        // Decoding must not consume bytes beyond the one with the continuation bit cleared.
        let mut slice: &[u8] = &[0x01, 0xFF];
        assert_eq!(read::<u32>(&mut slice).unwrap(), 1);
        assert_eq!(slice, &[0xFF][..]);

        let mut slice: &[u8] = &[0x80, 0x01, 0x7F];
        assert_eq!(read::<u32>(&mut slice).unwrap(), 128);
        assert_eq!(slice, &[0x7F][..]);
    }

    /// Core round-trip check, generic over any UPrim.
    fn varuint_round_trip<T: Copy + UPrim + TryFrom<u128>>() {
        const CASES: &[u128] = &[
            0,
            1,
            127,
            128,
            129,
            0xFF,
            0x100,
            0x3FFF,
            0x4000,
            0x1_FFFF,
            0xFF_FFFF,
            0x1_FF_FF_FF_FF,
            0xFF_FF_FF_FF_FF_FF,
            0x1_FF_FF_FF_FF_FF_FF_FF_FF_FF_FF_FF_FF,
            u16::MAX as u128,
            u32::MAX as u128,
            u64::MAX as u128,
            u128::MAX,
        ];

        for &raw in CASES {
            // skip values that don't fit into T
            let Ok(value) = raw.try_into() else { continue };
            let value: T = value;

            // size matches encoding length
            let mut buf = Vec::new();
            write(value, &mut buf);
            assert_eq!(buf.len(), size(value));

            // decode matches original value
            let mut slice = &buf[..];
            let decoded: T = read(&mut slice).unwrap();
            assert_eq!(decoded, value);
            assert!(slice.is_empty());

            // UInt wrapper
            let encoded = UInt(value).encode();
            assert_eq!(UInt::<T>::decode(encoded).unwrap(), UInt(value));
        }
    }

    #[test]
    fn test_varuint() {
        varuint_round_trip::<u16>();
        varuint_round_trip::<u32>();
        varuint_round_trip::<u64>();
        varuint_round_trip::<u128>();
    }

    fn varsint_round_trip<T: Copy + SPrim + TryFrom<i128>>() {
        const CASES: &[i128] = &[
            0,
            1,
            -1,
            2,
            -2,
            127,
            -127,
            128,
            -128,
            129,
            -129,
            0x7FFFFFFF,
            -0x7FFFFFFF,
            0x7FFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF,
            -0x7FFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF,
            i16::MIN as i128,
            i16::MAX as i128,
            i32::MIN as i128,
            i32::MAX as i128,
            i64::MIN as i128,
            i64::MAX as i128,
            i128::MIN,
            i128::MAX,
        ];

        for &raw in CASES {
            // skip values that don't fit into T
            let Ok(value) = raw.try_into() else { continue };
            let value: T = value;

            // size matches encoding length
            let mut buf = Vec::new();
            write_signed(value, &mut buf);
            assert_eq!(buf.len(), size_signed(value));

            // decode matches original value
            let mut slice = &buf[..];
            let decoded: T = read_signed(&mut slice).unwrap();
            assert_eq!(decoded, value);
            assert!(slice.is_empty());

            // SInt wrapper
            let encoded = SInt(value).encode();
            assert_eq!(SInt::<T>::decode(encoded).unwrap(), SInt(value));
        }
    }

    #[test]
    fn test_varsint() {
        varsint_round_trip::<i16>();
        varsint_round_trip::<i32>();
        varsint_round_trip::<i64>();
        varsint_round_trip::<i128>();
    }

    #[test]
    fn test_zigzag_mapping() {
        // Round-trips alone cannot catch a mapping that is wrong in a self-consistent way, so pin
        // the exact ZigZag codes: small magnitudes alternate between non-negative and negative.
        assert_eq!(0i32.as_zigzag(), 0u32);
        assert_eq!((-1i32).as_zigzag(), 1u32);
        assert_eq!(1i32.as_zigzag(), 2u32);
        assert_eq!((-2i32).as_zigzag(), 3u32);
        assert_eq!(i16::MAX.as_zigzag(), u16::MAX - 1);
        assert_eq!(i16::MIN.as_zigzag(), u16::MAX);
        assert_eq!(i128::MIN.as_zigzag(), u128::MAX);

        assert_eq!(i32::un_zigzag(3), -2);
        assert_eq!(i16::un_zigzag(u16::MAX), i16::MIN);
        assert_eq!(i128::un_zigzag(u128::MAX - 1), i128::MAX);
    }

    #[test]
    fn test_signed_conformity() {
        assert_eq!(SInt(0i32).encode(), &[0x00][..]);
        assert_eq!(SInt(-1i32).encode(), &[0x01][..]);
        assert_eq!(SInt(1i32).encode(), &[0x02][..]);
        assert_eq!(SInt(-64i32).encode(), &[0x7F][..]);
        assert_eq!(SInt(64i32).encode(), &[0x80, 0x01][..]);
    }

    #[test]
    fn test_varuint_into() {
        let v32: u32 = 0x1_FFFF;
        let out32: u32 = UInt(v32).into();
        assert_eq!(v32, out32);

        let v64: u64 = 0x1_FF_FF_FF_FF;
        let out64: u64 = UInt(v64).into();
        assert_eq!(v64, out64);
    }

    #[test]
    fn test_varsint_into() {
        let s32: i32 = -123_456;
        let out32: i32 = SInt(s32).into();
        assert_eq!(s32, out32);

        let s64: i64 = 987_654_321;
        let out64: i64 = SInt(s64).into();
        assert_eq!(s64, out64);
    }

    #[test]
    fn test_conformity() {
        assert_eq!(0usize.encode(), &[0x00][..]);
        assert_eq!(1usize.encode(), &[0x01][..]);
        assert_eq!(127usize.encode(), &[0x7F][..]);
        assert_eq!(128usize.encode(), &[0x80, 0x01][..]);
        assert_eq!(16383usize.encode(), &[0xFF, 0x7F][..]);
        assert_eq!(16384usize.encode(), &[0x80, 0x80, 0x01][..]);
        assert_eq!(2097151usize.encode(), &[0xFF, 0xFF, 0x7F][..]);
        assert_eq!(2097152usize.encode(), &[0x80, 0x80, 0x80, 0x01][..]);
        assert_eq!(
            (u32::MAX as usize).encode(),
            &[0xFF, 0xFF, 0xFF, 0xFF, 0x0F][..]
        );
    }

    #[test]
    fn test_all_u16_values() {
        // Exhaustively test all u16 values to ensure size matches encoding
        for i in 0..=u16::MAX {
            let value = i;
            let calculated_size = size(value);

            let mut buf = Vec::new();
            write(value, &mut buf);

            assert_eq!(
                buf.len(),
                calculated_size,
                "Size mismatch for u16 value {value}",
            );

            // Also verify UInt wrapper
            let uint = UInt(value);
            assert_eq!(
                uint.encode_size(),
                buf.len(),
                "UInt encode_size mismatch for value {value}",
            );
        }
    }

    #[test]
    fn test_all_i16_values() {
        // Exhaustively test all i16 values to ensure size matches encoding
        for i in i16::MIN..=i16::MAX {
            let value = i;
            let calculated_size = size_signed(value);

            let mut buf = Vec::new();
            write_signed(value, &mut buf);

            assert_eq!(
                buf.len(),
                calculated_size,
                "Size mismatch for i16 value {value}",
            );

            // Also verify SInt wrapper
            let sint = SInt(value);
            assert_eq!(
                sint.encode_size(),
                buf.len(),
                "SInt encode_size mismatch for value {value}",
            );

            // Verify we can decode it back correctly
            let mut slice = &buf[..];
            let decoded: i16 = read_signed(&mut slice).unwrap();
            assert_eq!(decoded, value, "Decode mismatch for value {value}");
            assert!(
                slice.is_empty(),
                "Buffer not fully consumed for value {value}",
            );
        }
    }

    #[test]
    fn test_exact_bit_boundaries() {
        // Test values with exactly N bits set
        fn test_exact_bits<T: UPrim + TryFrom<u128> + core::fmt::Display>() {
            for bits in 1..=128 {
                // Create a value with exactly 'bits' bits
                // e.g., bits=3 -> 0b111 = 7
                let val = if bits == 128 {
                    u128::MAX
                } else {
                    (1u128 << bits) - 1
                };
                let Ok(value) = T::try_from(val) else {
                    continue;
                };

                // Compute expected size
                let expected_size = (bits as usize).div_ceil(DATA_BITS_PER_BYTE);
                let calculated_size = size(value);
                assert_eq!(
                    calculated_size, expected_size,
                    "Size calculation wrong for {val} with {bits} bits",
                );

                // Compare encoded size
                let mut buf = Vec::new();
                write(value, &mut buf);
                assert_eq!(
                    buf.len(),
                    expected_size,
                    "Encoded size wrong for {val} with {bits} bits",
                );
            }
        }

        test_exact_bits::<u16>();
        test_exact_bits::<u32>();
        test_exact_bits::<u64>();
        test_exact_bits::<u128>();
    }

    #[test]
    fn test_single_bit_boundaries() {
        // Test values with only a single bit set at different positions
        fn test_single_bits<T: UPrim + TryFrom<u128> + core::fmt::Display>() {
            for bit_pos in 0..128 {
                // Create a value with only a single bit set at the given position
                let val = 1u128 << bit_pos;
                let Ok(value) = T::try_from(val) else {
                    continue;
                };

                // Compute expected size
                let expected_size = ((bit_pos + 1) as usize).div_ceil(DATA_BITS_PER_BYTE);
                let calculated_size = size(value);
                assert_eq!(
                    calculated_size, expected_size,
                    "Size wrong for 1<<{bit_pos} = {val}",
                );

                // Compare encoded size
                let mut buf = Vec::new();
                write(value, &mut buf);
                assert_eq!(
                    buf.len(),
                    expected_size,
                    "Encoded size wrong for 1<<{bit_pos} = {val}",
                );
            }
        }

        test_single_bits::<u16>();
        test_single_bits::<u32>();
        test_single_bits::<u64>();
        test_single_bits::<u128>();
    }
}