commonware_codec/
varint.rs

1//! Variable-length integer encoding and decoding
2//!
3//! # Overview
4//!
5//! This module implements Google's Protocol Buffers variable-length integer encoding.
6//! Each byte uses:
7//! - 7 bits for the value
8//! - 1 "continuation" bit to indicate if more bytes follow
9//!
10//! `u8` and `i8` are omitted since those types do not benefit from varint encoding.
11//!
12//! `usize` and `isize` are omitted to prevent behavior from depending on the target architecture.
13//!
14//! # Usage Example
15//!
16//! ```rust
17//! use commonware_codec::{Encode, DecodeExt, varint::{UInt, SInt}};
18//!
19//! // Unsigned example
20//! let one = UInt(42u128).encode();
21//! assert_eq!(one.len(), 1); // 42 fits in a single byte
22//! let decoded: u128 = UInt::decode(one).unwrap().into();
23//! assert_eq!(decoded, 42);
24//!
25//! // Signed example (ZigZag)
26//! let neg = SInt(-3i32).encode();
27//! assert_eq!(neg.len(), 1);
28//! let decoded: i32 = SInt::decode(neg).unwrap().into();
29//! assert_eq!(decoded, -3);
30//! ```
31
32use crate::{EncodeSize, Error, FixedSize, Read, ReadExt, Write};
33use bytes::{Buf, BufMut};
34use core::{fmt::Debug, mem::size_of};
35use sealed::{SPrim, UPrim};
36
37// ---------- Constants ----------
38
/// Width of a single byte, in bits.
const BITS_PER_BYTE: usize = 8;

/// How many bits of each encoded byte carry payload: every bit except the
/// continuation bit.
const DATA_BITS_PER_BYTE: usize = BITS_PER_BYTE - 1;

/// Selects the payload bits of an encoded byte (everything below the continuation bit).
const DATA_BITS_MASK: u8 = !CONTINUATION_BIT_MASK;

/// Selects the continuation bit of an encoded byte, i.e. its most-significant bit.
const CONTINUATION_BIT_MASK: u8 = 1 << DATA_BITS_PER_BYTE;
51
52// ---------- Traits ----------
53
54#[doc(hidden)]
55mod sealed {
56    use super::*;
57    use core::ops::{BitOrAssign, Shl, ShrAssign};
58
59    /// A trait for unsigned integer primitives that can be varint encoded.
60    pub trait UPrim:
61        Copy
62        + From<u8>
63        + Sized
64        + FixedSize
65        + ShrAssign<usize>
66        + Shl<usize, Output = Self>
67        + BitOrAssign<Self>
68        + PartialOrd
69        + Debug
70    {
71        /// Returns the number of leading zeros in the integer.
72        fn leading_zeros(self) -> u32;
73
74        /// Returns the least significant byte of the integer.
75        fn as_u8(self) -> u8;
76    }
77
78    // Implements the `UPrim` trait for all unsigned integer types.
79    macro_rules! impl_uint {
80        ($type:ty) => {
81            impl UPrim for $type {
82                #[inline(always)]
83                fn leading_zeros(self) -> u32 {
84                    self.leading_zeros()
85                }
86
87                #[inline(always)]
88                fn as_u8(self) -> u8 {
89                    self as u8
90                }
91            }
92        };
93    }
94    impl_uint!(u16);
95    impl_uint!(u32);
96    impl_uint!(u64);
97    impl_uint!(u128);
98
99    /// A trait for signed integer primitives that can be converted to and from unsigned integer
100    /// primitives of the equivalent size.
101    ///
102    /// When converted to unsigned integers, the encoding is done using ZigZag encoding, which moves the
103    /// sign bit to the least significant bit (shifting all other bits to the left by one). This allows
104    /// for more efficient encoding of numbers that are close to zero, even if they are negative.
105    pub trait SPrim: Copy + Sized + FixedSize + PartialOrd + Debug {
106        /// The unsigned equivalent type of the signed integer.
107        /// This type must be the same size as the signed integer type.
108        type UnsignedEquivalent: UPrim;
109
110        /// Compile-time assertion to ensure that the size of the signed integer is equal to the size of
111        /// the unsigned integer.
112        #[doc(hidden)]
113        const _COMMIT_OP_ASSERT: () =
114            assert!(size_of::<Self>() == size_of::<Self::UnsignedEquivalent>());
115
116        /// Converts the signed integer to an unsigned integer using ZigZag encoding.
117        fn as_zigzag(&self) -> Self::UnsignedEquivalent;
118
119        /// Converts a (ZigZag'ed) unsigned integer back to a signed integer.
120        fn un_zigzag(value: Self::UnsignedEquivalent) -> Self;
121    }
122
123    // Implements the `SPrim` trait for all signed integer types.
124    macro_rules! impl_sint {
125        ($type:ty, $utype:ty) => {
126            impl SPrim for $type {
127                type UnsignedEquivalent = $utype;
128
129                #[inline]
130                fn as_zigzag(&self) -> $utype {
131                    let shr = size_of::<$utype>() * 8 - 1;
132                    ((self << 1) ^ (self >> shr)) as $utype
133                }
134                #[inline]
135                fn un_zigzag(value: $utype) -> Self {
136                    ((value >> 1) as $type) ^ (-((value & 1) as $type))
137                }
138            }
139        };
140    }
141    impl_sint!(i16, u16);
142    impl_sint!(i32, u32);
143    impl_sint!(i64, u64);
144    impl_sint!(i128, u128);
145}
146
147// ---------- Structs ----------
148
149/// An ergonomic wrapper to allow for encoding and decoding of primitive unsigned integers as
150/// varints rather than the default fixed-width integers.
151#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
152pub struct UInt<U: UPrim>(pub U);
153
154// Implements `Into<U>` for `UInt<U>` for all unsigned integer types.
155// This allows for easy conversion from `UInt<U>` to `U` using `.into()`.
156macro_rules! impl_varuint_into {
157    ($($type:ty),+) => {
158        $(
159            impl From<UInt<$type>> for $type {
160                fn from(val: UInt<$type>) -> Self {
161                    val.0
162                }
163            }
164        )+
165    };
166}
167impl_varuint_into!(u16, u32, u64, u128);
168
169impl<U: UPrim> Write for UInt<U> {
170    fn write(&self, buf: &mut impl BufMut) {
171        write(self.0, buf);
172    }
173}
174
175impl<U: UPrim> Read for UInt<U> {
176    type Cfg = ();
177    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
178        read(buf).map(UInt)
179    }
180}
181
182impl<U: UPrim> EncodeSize for UInt<U> {
183    fn encode_size(&self) -> usize {
184        size(self.0)
185    }
186}
187
188/// An ergonomic wrapper to allow for encoding and decoding of primitive signed integers as
189/// varints rather than the default fixed-width integers.
190#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
191pub struct SInt<S: SPrim>(pub S);
192
193// Implements `Into<U>` for `SInt<U>` for all signed integer types.
194// This allows for easy conversion from `SInt<S>` to `S` using `.into()`.
195macro_rules! impl_varsint_into {
196    ($($type:ty),+) => {
197        $(
198            impl From<SInt<$type>> for $type {
199                fn from(val: SInt<$type>) -> Self {
200                    val.0
201                }
202            }
203        )+
204    };
205}
206impl_varsint_into!(i16, i32, i64, i128);
207
208impl<S: SPrim> Write for SInt<S> {
209    fn write(&self, buf: &mut impl BufMut) {
210        write_signed::<S>(self.0, buf);
211    }
212}
213
214impl<S: SPrim> Read for SInt<S> {
215    type Cfg = ();
216    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
217        read_signed::<S>(buf).map(SInt)
218    }
219}
220
221impl<S: SPrim> EncodeSize for SInt<S> {
222    fn encode_size(&self) -> usize {
223        size_signed::<S>(self.0)
224    }
225}
226
227// ---------- Helper Functions ----------
228
229/// Encodes an unsigned integer as a varint
230fn write<T: UPrim>(value: T, buf: &mut impl BufMut) {
231    let continuation_threshold = T::from(CONTINUATION_BIT_MASK);
232    if value < continuation_threshold {
233        // Fast path for small values (common case for lengths).
234        // `as_u8()` does not truncate the value or leave a continuation bit.
235        buf.put_u8(value.as_u8());
236        return;
237    }
238
239    let mut val = value;
240    while val >= continuation_threshold {
241        buf.put_u8((val.as_u8()) | CONTINUATION_BIT_MASK);
242        val >>= 7;
243    }
244    buf.put_u8(val.as_u8());
245}
246
247/// Decodes a unsigned integer from a varint.
248///
249/// Returns an error if:
250/// - The varint is invalid (too long or malformed)
251/// - The buffer ends while reading
252fn read<T: UPrim>(buf: &mut impl Buf) -> Result<T, Error> {
253    let max_bits = T::SIZE * BITS_PER_BYTE;
254    let mut result: T = T::from(0);
255    let mut bits_read = 0;
256
257    // Loop over all the bytes.
258    loop {
259        // Read the next byte.
260        let byte = u8::read(buf)?;
261
262        // If this is not the first byte, but the byte is completely zero, we have an invalid
263        // varint. This is because this byte has no data bits and no continuation, so there was no
264        // point in continuing to this byte in the first place. While the output could still result
265        // in a valid value, we ensure that every value has exactly one unique, valid encoding.
266        if byte == 0 && bits_read > 0 {
267            return Err(Error::InvalidVarint(T::SIZE));
268        }
269
270        // If this must be the last byte, check for overflow (i.e. set bits beyond the size of T).
271        // Because the continuation bit is the most-significant bit, this check also happens to
272        // check for an invalid continuation bit.
273        //
274        // If we have reached what must be the last byte, this check prevents continuing to read
275        // from the buffer by ensuring that the conditional (`if byte & CONTINUATION_BIT_MASK == 0`)
276        // always evaluates to true.
277        let remaining_bits = max_bits.checked_sub(bits_read).unwrap();
278        if remaining_bits <= DATA_BITS_PER_BYTE {
279            let relevant_bits = BITS_PER_BYTE - byte.leading_zeros() as usize;
280            if relevant_bits > remaining_bits {
281                return Err(Error::InvalidVarint(T::SIZE));
282            }
283        }
284
285        // Write the 7 bits of data to the result.
286        result |= T::from(byte & DATA_BITS_MASK) << bits_read;
287
288        // If the continuation bit is not set, return.
289        if byte & CONTINUATION_BIT_MASK == 0 {
290            return Ok(result);
291        }
292
293        bits_read += DATA_BITS_PER_BYTE;
294    }
295}
296
297/// Calculates the number of bytes needed to encode an unsigned integer as a varint.
298fn size<T: UPrim>(value: T) -> usize {
299    let total_bits = size_of::<T>() * 8;
300    let leading_zeros = value.leading_zeros() as usize;
301    let data_bits = total_bits - leading_zeros;
302    usize::max(1, data_bits.div_ceil(DATA_BITS_PER_BYTE))
303}
304
305/// Encodes a signed integer as a varint using ZigZag encoding.
306fn write_signed<S: SPrim>(value: S, buf: &mut impl BufMut) {
307    write(value.as_zigzag(), buf);
308}
309
310/// Decodes a signed integer from varint ZigZag encoding.
311fn read_signed<S: SPrim>(buf: &mut impl Buf) -> Result<S, Error> {
312    Ok(S::un_zigzag(read(buf)?))
313}
314
315/// Calculates the number of bytes needed to encode a signed integer as a varint.
316fn size_signed<S: SPrim>(value: S) -> usize {
317    size(value.as_zigzag())
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{error::Error, DecodeExt, Encode};
    #[cfg(not(feature = "std"))]
    use alloc::vec::Vec;
    use bytes::Bytes;

    /// A varint that is cut off (empty input, or a final byte that still asks for a
    /// continuation) must report `EndOfBuffer`.
    #[test]
    fn test_end_of_buffer() {
        let mut buf: Bytes = Bytes::from_static(&[]);
        assert!(matches!(read::<u32>(&mut buf), Err(Error::EndOfBuffer)));

        let mut buf: Bytes = Bytes::from_static(&[0x80, 0x8F]);
        assert!(matches!(read::<u32>(&mut buf), Err(Error::EndOfBuffer)));

        let mut buf: Bytes = Bytes::from_static(&[0xFF, 0x8F]);
        assert!(matches!(read::<u32>(&mut buf), Err(Error::EndOfBuffer)));
    }

    /// The maximum value decodes, while encodings carrying bits beyond the target type fail.
    #[test]
    fn test_overflow() {
        let mut buf: Bytes = Bytes::from_static(&[0xFF, 0xFF, 0xFF, 0xFF, 0x0F]);
        assert_eq!(read::<u32>(&mut buf).unwrap(), u32::MAX);

        let mut buf: Bytes = Bytes::from_static(&[0xFF, 0xFF, 0xFF, 0xFF, 0x1F]);
        assert!(matches!(
            read::<u32>(&mut buf),
            Err(Error::InvalidVarint(u32::SIZE))
        ));

        let mut buf =
            Bytes::from_static(&[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x02]);
        assert!(matches!(
            read::<u64>(&mut buf),
            Err(Error::InvalidVarint(u64::SIZE))
        ));
    }

    /// A continuation bit on the byte that must be the last one is rejected.
    #[test]
    fn test_overcontinuation() {
        let mut buf: Bytes = Bytes::from_static(&[0x80, 0x80, 0x80, 0x80, 0x80]);
        let result = read::<u32>(&mut buf);
        assert!(matches!(result, Err(Error::InvalidVarint(u32::SIZE))));
    }

    /// A trailing all-zero byte is a non-canonical encoding and is rejected.
    #[test]
    fn test_zeroed_byte() {
        let mut buf = Bytes::from_static(&[0xFF, 0x00]);
        let result = read::<u64>(&mut buf);
        assert!(matches!(result, Err(Error::InvalidVarint(u64::SIZE))));
    }

    /// Core round-trip check, generic over any UPrim.
    fn varuint_round_trip<T: Copy + UPrim + TryFrom<u128>>() {
        const CASES: &[u128] = &[
            0,
            1,
            127,
            128,
            129,
            0xFF,
            0x100,
            0x3FFF,
            0x4000,
            0x1_FFFF,
            0xFF_FFFF,
            0x1_FF_FF_FF_FF,
            0xFF_FF_FF_FF_FF_FF,
            0x1_FF_FF_FF_FF_FF_FF_FF_FF_FF_FF_FF_FF,
            u16::MAX as u128,
            u32::MAX as u128,
            u64::MAX as u128,
            u128::MAX,
        ];

        for &raw in CASES {
            // skip values that don't fit into T
            let Ok(value) = raw.try_into() else { continue };
            let value: T = value;

            // size matches encoding length
            let mut buf = Vec::new();
            write(value, &mut buf);
            assert_eq!(buf.len(), size(value));

            // decode matches original value
            let mut slice = &buf[..];
            let decoded: T = read(&mut slice).unwrap();
            assert_eq!(decoded, value);
            assert!(slice.is_empty());

            // UInt wrapper
            let encoded = UInt(value).encode();
            assert_eq!(UInt::<T>::decode(encoded).unwrap(), UInt(value));
        }
    }

    #[test]
    fn test_varuint() {
        varuint_round_trip::<u16>();
        varuint_round_trip::<u32>();
        varuint_round_trip::<u64>();
        varuint_round_trip::<u128>();
    }

    /// Core round-trip check, generic over any SPrim.
    fn varsint_round_trip<T: Copy + SPrim + TryFrom<i128>>() {
        const CASES: &[i128] = &[
            0,
            1,
            -1,
            2,
            -2,
            127,
            -127,
            128,
            -128,
            129,
            -129,
            0x7FFFFFFF,
            -0x7FFFFFFF,
            0x7FFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF,
            -0x7FFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF,
            i16::MIN as i128,
            i16::MAX as i128,
            i32::MIN as i128,
            i32::MAX as i128,
            i64::MIN as i128,
            i64::MAX as i128,
            // The most negative value of the widest type: its ZigZag image is `u128::MAX`,
            // i.e. the longest possible encoding.
            i128::MIN,
        ];

        for &raw in CASES {
            // skip values that don't fit into T
            let Ok(value) = raw.try_into() else { continue };
            let value: T = value;

            // size matches encoding length
            let mut buf = Vec::new();
            write_signed(value, &mut buf);
            assert_eq!(buf.len(), size_signed(value));

            // decode matches original value
            let mut slice = &buf[..];
            let decoded: T = read_signed(&mut slice).unwrap();
            assert_eq!(decoded, value);
            assert!(slice.is_empty());

            // SInt wrapper
            let encoded = SInt(value).encode();
            assert_eq!(SInt::<T>::decode(encoded).unwrap(), SInt(value));
        }
    }

    #[test]
    fn test_varsint() {
        varsint_round_trip::<i16>();
        varsint_round_trip::<i32>();
        varsint_round_trip::<i64>();
        varsint_round_trip::<i128>();
    }

    /// `UInt<T>` converts back into `T` via `.into()`.
    #[test]
    fn test_varuint_into() {
        let v32: u32 = 0x1_FFFF;
        let out32: u32 = UInt(v32).into();
        assert_eq!(v32, out32);

        let v64: u64 = 0x1_FF_FF_FF_FF;
        let out64: u64 = UInt(v64).into();
        assert_eq!(v64, out64);
    }

    /// `SInt<T>` converts back into `T` via `.into()`.
    #[test]
    fn test_varsint_into() {
        let s32: i32 = -123_456;
        let out32: i32 = SInt(s32).into();
        assert_eq!(s32, out32);

        let s64: i64 = 987_654_321;
        let out64: i64 = SInt(s64).into();
        assert_eq!(s64, out64);
    }

    /// Pins the exact byte layout produced for a set of well-known values.
    #[test]
    fn test_conformity() {
        assert_eq!(0usize.encode(), &[0x00][..]);
        assert_eq!(1usize.encode(), &[0x01][..]);
        assert_eq!(127usize.encode(), &[0x7F][..]);
        assert_eq!(128usize.encode(), &[0x80, 0x01][..]);
        assert_eq!(16383usize.encode(), &[0xFF, 0x7F][..]);
        assert_eq!(16384usize.encode(), &[0x80, 0x80, 0x01][..]);
        assert_eq!(2097151usize.encode(), &[0xFF, 0xFF, 0x7F][..]);
        assert_eq!(2097152usize.encode(), &[0x80, 0x80, 0x80, 0x01][..]);
        assert_eq!(
            (u32::MAX as usize).encode(),
            &[0xFF, 0xFF, 0xFF, 0xFF, 0x0F][..]
        );
    }

    #[test]
    fn test_all_u16_values() {
        // Exhaustively test all u16 values to ensure size matches encoding
        for i in 0..=u16::MAX {
            let value = i;
            let calculated_size = size(value);

            let mut buf = Vec::new();
            write(value, &mut buf);

            assert_eq!(
                buf.len(),
                calculated_size,
                "Size mismatch for u16 value {value}",
            );

            // Also verify UInt wrapper
            let uint = UInt(value);
            assert_eq!(
                uint.encode_size(),
                buf.len(),
                "UInt encode_size mismatch for value {value}",
            );

            // Verify we can decode it back correctly (mirrors the exhaustive i16 test)
            let mut slice = &buf[..];
            let decoded: u16 = read(&mut slice).unwrap();
            assert_eq!(decoded, value, "Decode mismatch for value {value}");
            assert!(
                slice.is_empty(),
                "Buffer not fully consumed for value {value}",
            );
        }
    }

    #[test]
    fn test_all_i16_values() {
        // Exhaustively test all i16 values to ensure size matches encoding
        for i in i16::MIN..=i16::MAX {
            let value = i;
            let calculated_size = size_signed(value);

            let mut buf = Vec::new();
            write_signed(value, &mut buf);

            assert_eq!(
                buf.len(),
                calculated_size,
                "Size mismatch for i16 value {value}",
            );

            // Also verify SInt wrapper
            let sint = SInt(value);
            assert_eq!(
                sint.encode_size(),
                buf.len(),
                "SInt encode_size mismatch for value {value}",
            );

            // Verify we can decode it back correctly
            let mut slice = &buf[..];
            let decoded: i16 = read_signed(&mut slice).unwrap();
            assert_eq!(decoded, value, "Decode mismatch for value {value}");
            assert!(
                slice.is_empty(),
                "Buffer not fully consumed for value {value}",
            );
        }
    }

    #[test]
    fn test_exact_bit_boundaries() {
        // Test values with exactly N bits set
        fn test_exact_bits<T: UPrim + TryFrom<u128> + core::fmt::Display>() {
            for bits in 1..=128 {
                // Create a value with exactly 'bits' bits
                // e.g., bits=3 -> 0b111 = 7
                let val = if bits == 128 {
                    u128::MAX
                } else {
                    (1u128 << bits) - 1
                };
                let Ok(value) = T::try_from(val) else {
                    continue;
                };

                // Compute expected size
                let expected_size = (bits as usize).div_ceil(DATA_BITS_PER_BYTE);
                let calculated_size = size(value);
                assert_eq!(
                    calculated_size, expected_size,
                    "Size calculation wrong for {val} with {bits} bits",
                );

                // Compare encoded size
                let mut buf = Vec::new();
                write(value, &mut buf);
                assert_eq!(
                    buf.len(),
                    expected_size,
                    "Encoded size wrong for {val} with {bits} bits",
                );
            }
        }

        test_exact_bits::<u16>();
        test_exact_bits::<u32>();
        test_exact_bits::<u64>();
        test_exact_bits::<u128>();
    }

    #[test]
    fn test_single_bit_boundaries() {
        // Test values with only a single bit set at different positions
        fn test_single_bits<T: UPrim + TryFrom<u128> + core::fmt::Display>() {
            for bit_pos in 0..128 {
                // Create a value with only a single bit set at the given position
                let val = 1u128 << bit_pos;
                let Ok(value) = T::try_from(val) else {
                    continue;
                };

                // Compute expected size
                let expected_size = ((bit_pos + 1) as usize).div_ceil(DATA_BITS_PER_BYTE);
                let calculated_size = size(value);
                assert_eq!(
                    calculated_size, expected_size,
                    "Size wrong for 1<<{bit_pos} = {val}",
                );

                // Compare encoded size
                let mut buf = Vec::new();
                write(value, &mut buf);
                assert_eq!(
                    buf.len(),
                    expected_size,
                    "Encoded size wrong for 1<<{bit_pos} = {val}",
                );
            }
        }

        test_single_bits::<u16>();
        test_single_bits::<u32>();
        test_single_bits::<u64>();
        test_single_bits::<u128>();
    }
}
654        test_single_bits::<u128>();
655    }
656}