commonware_codec/
varint.rs

1//! Variable-length integer encoding and decoding
2//!
3//! # Overview
4//!
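//! This module implements Google's Protocol Buffers variable-length integer (varint) encoding.
//! A value is split into 7-bit groups, emitted least-significant group first, and each encoded
//! byte holds:
//! - 7 bits of the value
//! - 1 "continuation" bit (the most-significant bit), set when more bytes follow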
9//!
10//! `u8` and `i8` are omitted since those types do not benefit from varint encoding.
11//!
12//! `usize` and `isize` are omitted to prevent behavior from depending on the target architecture.
13//!
14//! # Usage Example
15//!
16//! ```rust
17//! use commonware_codec::{Encode, DecodeExt, varint::{UInt, SInt}};
18//!
19//! // Unsigned example
20//! let one = UInt(42u128).encode();
21//! assert_eq!(one.len(), 1); // 42 fits in a single byte
22//! let decoded: u128 = UInt::decode(one).unwrap().into();
23//! assert_eq!(decoded, 42);
24//!
25//! // Signed example (ZigZag)
26//! let neg = SInt(-3i32).encode();
27//! assert_eq!(neg.len(), 1);
28//! let decoded: i32 = SInt::decode(neg).unwrap().into();
29//! assert_eq!(decoded, -3);
30//! ```
31
32use crate::{EncodeSize, Error, FixedSize, Read, ReadExt, Write};
33use bytes::{Buf, BufMut};
34use sealed::{SPrim, UPrim};
35use std::fmt::Debug;
36
// ---------- Constants ----------

/// Number of bits in one byte.
const BITS_PER_BYTE: usize = 8;

/// Number of payload bits carried by each encoded byte: every bit except the continuation bit.
const DATA_BITS_PER_BYTE: usize = BITS_PER_BYTE - 1;

/// Mask selecting the continuation bit, which is the most-significant bit of an encoded byte.
const CONTINUATION_BIT_MASK: u8 = 1 << DATA_BITS_PER_BYTE;

/// Mask selecting the payload bits of an encoded byte (everything below the continuation bit).
const DATA_BITS_MASK: u8 = !CONTINUATION_BIT_MASK;
51
52// ---------- Traits ----------
53
54#[doc(hidden)]
55mod sealed {
56    use super::*;
57    use std::ops::{BitOrAssign, Shl, ShrAssign};
58
59    /// A trait for unsigned integer primitives that can be varint encoded.
60    pub trait UPrim:
61        Copy
62        + From<u8>
63        + Sized
64        + FixedSize
65        + ShrAssign<usize>
66        + Shl<usize, Output = Self>
67        + BitOrAssign<Self>
68        + PartialOrd
69        + Debug
70    {
71        /// Returns the number of leading zeros in the integer.
72        fn leading_zeros(self) -> u32;
73
74        /// Returns the least significant byte of the integer.
75        fn as_u8(self) -> u8;
76    }
77
78    // Implements the `UPrim` trait for all unsigned integer types.
79    macro_rules! impl_uint {
80        ($type:ty) => {
81            impl UPrim for $type {
82                #[inline(always)]
83                fn leading_zeros(self) -> u32 {
84                    self.leading_zeros()
85                }
86
87                #[inline(always)]
88                fn as_u8(self) -> u8 {
89                    self as u8
90                }
91            }
92        };
93    }
94    impl_uint!(u16);
95    impl_uint!(u32);
96    impl_uint!(u64);
97    impl_uint!(u128);
98
99    /// A trait for signed integer primitives that can be converted to and from unsigned integer
100    /// primitives of the equivalent size.
101    ///
102    /// When converted to unsigned integers, the encoding is done using ZigZag encoding, which moves the
103    /// sign bit to the least significant bit (shifting all other bits to the left by one). This allows
104    /// for more efficient encoding of numbers that are close to zero, even if they are negative.
105    pub trait SPrim: Copy + Sized + FixedSize + PartialOrd + Debug {
106        /// The unsigned equivalent type of the signed integer.
107        /// This type must be the same size as the signed integer type.
108        type UnsignedEquivalent: UPrim;
109
110        /// Compile-time assertion to ensure that the size of the signed integer is equal to the size of
111        /// the unsigned integer.
112        #[doc(hidden)]
113        const _COMMIT_OP_ASSERT: () =
114            assert!(std::mem::size_of::<Self>() == std::mem::size_of::<Self::UnsignedEquivalent>());
115
116        /// Converts the signed integer to an unsigned integer using ZigZag encoding.
117        fn as_zigzag(&self) -> Self::UnsignedEquivalent;
118
119        /// Converts a (ZigZag'ed) unsigned integer back to a signed integer.
120        fn un_zigzag(value: Self::UnsignedEquivalent) -> Self;
121    }
122
123    // Implements the `SPrim` trait for all signed integer types.
124    macro_rules! impl_sint {
125        ($type:ty, $utype:ty) => {
126            impl SPrim for $type {
127                type UnsignedEquivalent = $utype;
128
129                #[inline]
130                fn as_zigzag(&self) -> $utype {
131                    let shr = std::mem::size_of::<$utype>() * 8 - 1;
132                    ((self << 1) ^ (self >> shr)) as $utype
133                }
134                #[inline]
135                fn un_zigzag(value: $utype) -> Self {
136                    ((value >> 1) as $type) ^ (-((value & 1) as $type))
137                }
138            }
139        };
140    }
141    impl_sint!(i16, u16);
142    impl_sint!(i32, u32);
143    impl_sint!(i64, u64);
144    impl_sint!(i128, u128);
145}
146
147// ---------- Structs ----------
148
149/// An ergonomic wrapper to allow for encoding and decoding of primitive unsigned integers as
150/// varints rather than the default fixed-width integers.
151#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
152pub struct UInt<U: UPrim>(pub U);
153
154// Implements `Into<U>` for `UInt<U>` for all unsigned integer types.
155// This allows for easy conversion from `UInt<U>` to `U` using `.into()`.
156macro_rules! impl_varuint_into {
157    ($($type:ty),+) => {
158        $(
159            impl From<UInt<$type>> for $type {
160                fn from(val: UInt<$type>) -> Self {
161                    val.0
162                }
163            }
164        )+
165    };
166}
167impl_varuint_into!(u16, u32, u64, u128);
168
169impl<U: UPrim> Write for UInt<U> {
170    fn write(&self, buf: &mut impl BufMut) {
171        write(self.0, buf);
172    }
173}
174
175impl<U: UPrim> Read for UInt<U> {
176    type Cfg = ();
177    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
178        read(buf).map(UInt)
179    }
180}
181
182impl<U: UPrim> EncodeSize for UInt<U> {
183    fn encode_size(&self) -> usize {
184        size(self.0)
185    }
186}
187
188/// An ergonomic wrapper to allow for encoding and decoding of primitive signed integers as
189/// varints rather than the default fixed-width integers.
190#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
191pub struct SInt<S: SPrim>(pub S);
192
193// Implements `Into<U>` for `SInt<U>` for all signed integer types.
194// This allows for easy conversion from `SInt<S>` to `S` using `.into()`.
195macro_rules! impl_varsint_into {
196    ($($type:ty),+) => {
197        $(
198            impl From<SInt<$type>> for $type {
199                fn from(val: SInt<$type>) -> Self {
200                    val.0
201                }
202            }
203        )+
204    };
205}
206impl_varsint_into!(i16, i32, i64, i128);
207
208impl<S: SPrim> Write for SInt<S> {
209    fn write(&self, buf: &mut impl BufMut) {
210        write_signed::<S>(self.0, buf);
211    }
212}
213
214impl<S: SPrim> Read for SInt<S> {
215    type Cfg = ();
216    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
217        read_signed::<S>(buf).map(SInt)
218    }
219}
220
221impl<S: SPrim> EncodeSize for SInt<S> {
222    fn encode_size(&self) -> usize {
223        size_signed::<S>(self.0)
224    }
225}
226
227// ---------- Helper Functions ----------
228
229/// Encodes an unsigned integer as a varint
230fn write<T: UPrim>(value: T, buf: &mut impl BufMut) {
231    let continuation_threshold = T::from(CONTINUATION_BIT_MASK);
232    if value < continuation_threshold {
233        // Fast path for small values (common case for lengths).
234        // `as_u8()` does not truncate the value or leave a continuation bit.
235        buf.put_u8(value.as_u8());
236        return;
237    }
238
239    let mut val = value;
240    while val >= continuation_threshold {
241        buf.put_u8((val.as_u8()) | CONTINUATION_BIT_MASK);
242        val >>= 7;
243    }
244    buf.put_u8(val.as_u8());
245}
246
247/// Decodes a unsigned integer from a varint.
248///
249/// Returns an error if:
250/// - The varint is invalid (too long or malformed)
251/// - The buffer ends while reading
252fn read<T: UPrim>(buf: &mut impl Buf) -> Result<T, Error> {
253    let max_bits = T::SIZE * BITS_PER_BYTE;
254    let mut result: T = T::from(0);
255    let mut bits_read = 0;
256
257    // Loop over all the bytes.
258    loop {
259        // Read the next byte.
260        let byte = u8::read(buf)?;
261
262        // If this is not the first byte, but the byte is completely zero, we have an invalid
263        // varint. This is because this byte has no data bits and no continuation, so there was no
264        // point in continuing to this byte in the first place. While the output could still result
265        // in a valid value, we ensure that every value has exactly one unique, valid encoding.
266        if byte == 0 && bits_read > 0 {
267            return Err(Error::InvalidVarint(T::SIZE));
268        }
269
270        // If this must be the last byte, check for overflow (i.e. set bits beyond the size of T).
271        // Because the continuation bit is the most-significant bit, this check also happens to
272        // check for an invalid continuation bit.
273        //
274        // If we have reached what must be the last byte, this check prevents continuing to read
275        // from the buffer by ensuring that the conditional (`if byte & CONTINUATION_BIT_MASK == 0`)
276        // always evaluates to true.
277        let remaining_bits = max_bits.checked_sub(bits_read).unwrap();
278        if remaining_bits <= DATA_BITS_PER_BYTE {
279            let relevant_bits = BITS_PER_BYTE - byte.leading_zeros() as usize;
280            if relevant_bits > remaining_bits {
281                return Err(Error::InvalidVarint(T::SIZE));
282            }
283        }
284
285        // Write the 7 bits of data to the result.
286        result |= T::from(byte & DATA_BITS_MASK) << bits_read;
287
288        // If the continuation bit is not set, return.
289        if byte & CONTINUATION_BIT_MASK == 0 {
290            return Ok(result);
291        }
292
293        bits_read += DATA_BITS_PER_BYTE;
294    }
295}
296
297/// Calculates the number of bytes needed to encode an unsigned integer as a varint.
298fn size<T: UPrim>(value: T) -> usize {
299    let total_bits = std::mem::size_of::<T>() * 8;
300    let leading_zeros = value.leading_zeros() as usize;
301    let data_bits = total_bits - leading_zeros;
302    usize::max(1, data_bits.div_ceil(DATA_BITS_PER_BYTE))
303}
304
305/// Encodes a signed integer as a varint using ZigZag encoding.
306fn write_signed<S: SPrim>(value: S, buf: &mut impl BufMut) {
307    write(value.as_zigzag(), buf);
308}
309
310/// Decodes a signed integer from varint ZigZag encoding.
311fn read_signed<S: SPrim>(buf: &mut impl Buf) -> Result<S, Error> {
312    Ok(S::un_zigzag(read(buf)?))
313}
314
315/// Calculates the number of bytes needed to encode a signed integer as a varint.
316fn size_signed<S: SPrim>(value: S) -> usize {
317    size(value.as_zigzag())
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{error::Error, DecodeExt, Encode};
    use bytes::Bytes;

    /// Truncated input (empty, or ending on a byte with the continuation bit set) is reported as
    /// end-of-buffer.
    #[test]
    fn test_end_of_buffer() {
        let mut buf: Bytes = Bytes::from_static(&[]);
        assert!(matches!(read::<u32>(&mut buf), Err(Error::EndOfBuffer)));

        let mut buf: Bytes = Bytes::from_static(&[0x80, 0x8F]);
        assert!(matches!(read::<u32>(&mut buf), Err(Error::EndOfBuffer)));

        let mut buf: Bytes = Bytes::from_static(&[0xFF, 0x8F]);
        assert!(matches!(read::<u32>(&mut buf), Err(Error::EndOfBuffer)));
    }

    /// The final possible byte may only set bits that still fit in the target type.
    #[test]
    fn test_overflow() {
        let mut buf: Bytes = Bytes::from_static(&[0xFF, 0xFF, 0xFF, 0xFF, 0x0F]);
        assert_eq!(read::<u32>(&mut buf).unwrap(), u32::MAX);

        let mut buf: Bytes = Bytes::from_static(&[0xFF, 0xFF, 0xFF, 0xFF, 0x1F]);
        assert!(matches!(
            read::<u32>(&mut buf),
            Err(Error::InvalidVarint(u32::SIZE))
        ));

        let mut buf =
            Bytes::from_static(&[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x02]);
        assert!(matches!(
            read::<u64>(&mut buf),
            Err(Error::InvalidVarint(u64::SIZE))
        ));
    }

    /// A continuation bit on the final possible byte is rejected.
    #[test]
    fn test_overcontinuation() {
        let mut buf: Bytes = Bytes::from_static(&[0x80, 0x80, 0x80, 0x80, 0x80]);
        let result = read::<u32>(&mut buf);
        assert!(matches!(result, Err(Error::InvalidVarint(u32::SIZE))));
    }

    /// A trailing all-zero byte makes the encoding non-canonical.
    #[test]
    fn test_zeroed_byte() {
        let mut buf = Bytes::from_static(&[0xFF, 0x00]);
        let result = read::<u64>(&mut buf);
        assert!(matches!(result, Err(Error::InvalidVarint(u64::SIZE))));
    }

    /// Redundant trailing zero groups are rejected for every type and position, while the
    /// canonical single-byte encoding of zero is accepted.
    #[test]
    fn test_non_canonical_encodings() {
        // Zero padded with one redundant group.
        let mut buf = Bytes::from_static(&[0x80, 0x00]);
        let result = read::<u16>(&mut buf);
        assert!(matches!(result, Err(Error::InvalidVarint(u16::SIZE))));

        // Zero padded with several redundant groups.
        let mut buf = Bytes::from_static(&[0x80, 0x80, 0x00]);
        let result = read::<u32>(&mut buf);
        assert!(matches!(result, Err(Error::InvalidVarint(u32::SIZE))));

        // One padded with a redundant group.
        let mut buf = Bytes::from_static(&[0x81, 0x00]);
        let result = read::<u64>(&mut buf);
        assert!(matches!(result, Err(Error::InvalidVarint(u64::SIZE))));

        // The canonical encoding of zero is a single 0x00 byte.
        let mut buf = Bytes::from_static(&[0x00]);
        assert_eq!(read::<u32>(&mut buf).unwrap(), 0);
        assert!(buf.is_empty());
    }

    /// For `u16` the third byte holds only 2 data bits.
    #[test]
    fn test_u16_final_byte_boundary() {
        let mut buf = Bytes::from_static(&[0xFF, 0xFF, 0x03]);
        assert_eq!(read::<u16>(&mut buf).unwrap(), u16::MAX);

        let mut buf = Bytes::from_static(&[0xFF, 0xFF, 0x04]);
        let result = read::<u16>(&mut buf);
        assert!(matches!(result, Err(Error::InvalidVarint(u16::SIZE))));
    }

    /// Core round-trip check, generic over any UPrim.
    fn varuint_round_trip<T: Copy + UPrim + TryFrom<u128>>() {
        const CASES: &[u128] = &[
            0,
            1,
            127,
            128,
            129,
            0xFF,
            0x100,
            0x3FFF,
            0x4000,
            0x1_FFFF,
            0xFF_FFFF,
            0x1_FF_FF_FF_FF,
            0xFF_FF_FF_FF_FF_FF,
            0x1_FF_FF_FF_FF_FF_FF_FF_FF_FF_FF_FF_FF,
            u16::MAX as u128,
            u32::MAX as u128,
            u64::MAX as u128,
            u128::MAX,
        ];

        for &raw in CASES {
            // skip values that don't fit into T
            let Ok(value) = raw.try_into() else { continue };
            let value: T = value;

            // size matches encoding length
            let mut buf = Vec::new();
            write(value, &mut buf);
            assert_eq!(buf.len(), size(value));

            // decode matches original value
            let mut slice = &buf[..];
            let decoded: T = read(&mut slice).unwrap();
            assert_eq!(decoded, value);
            assert!(slice.is_empty());

            // UInt wrapper
            let encoded = UInt(value).encode();
            assert_eq!(UInt::<T>::decode(encoded).unwrap(), UInt(value));
        }
    }

    #[test]
    fn test_varuint() {
        varuint_round_trip::<u16>();
        varuint_round_trip::<u32>();
        varuint_round_trip::<u64>();
        varuint_round_trip::<u128>();
    }

    /// Core signed round-trip check, generic over any SPrim.
    fn varsint_round_trip<T: Copy + SPrim + TryFrom<i128>>() {
        const CASES: &[i128] = &[
            0,
            1,
            -1,
            2,
            -2,
            127,
            -127,
            128,
            -128,
            129,
            -129,
            0x7FFFFFFF,
            -0x7FFFFFFF,
            0x7FFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF,
            -0x7FFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF,
            i16::MIN as i128,
            i16::MAX as i128,
            i32::MIN as i128,
            i32::MAX as i128,
            i64::MIN as i128,
            i64::MAX as i128,
        ];

        for &raw in CASES {
            // skip values that don't fit into T
            let Ok(value) = raw.try_into() else { continue };
            let value: T = value;

            // size matches encoding length
            let mut buf = Vec::new();
            write_signed(value, &mut buf);
            assert_eq!(buf.len(), size_signed(value));

            // decode matches original value
            let mut slice = &buf[..];
            let decoded: T = read_signed(&mut slice).unwrap();
            assert_eq!(decoded, value);
            assert!(slice.is_empty());

            // SInt wrapper
            let encoded = SInt(value).encode();
            assert_eq!(SInt::<T>::decode(encoded).unwrap(), SInt(value));
        }
    }

    #[test]
    fn test_varsint() {
        varsint_round_trip::<i16>();
        varsint_round_trip::<i32>();
        varsint_round_trip::<i64>();
        varsint_round_trip::<i128>();
    }

    #[test]
    fn test_varuint_into() {
        let v32: u32 = 0x1_FFFF;
        let out32: u32 = UInt(v32).into();
        assert_eq!(v32, out32);

        let v64: u64 = 0x1_FF_FF_FF_FF;
        let out64: u64 = UInt(v64).into();
        assert_eq!(v64, out64);
    }

    #[test]
    fn test_varsint_into() {
        let s32: i32 = -123_456;
        let out32: i32 = SInt(s32).into();
        assert_eq!(s32, out32);

        let s64: i64 = 987_654_321;
        let out64: i64 = SInt(s64).into();
        assert_eq!(s64, out64);
    }

    /// Unsigned wire format matches the Protocol Buffers varint encoding.
    #[test]
    fn test_conformity() {
        assert_eq!(0usize.encode(), &[0x00][..]);
        assert_eq!(1usize.encode(), &[0x01][..]);
        assert_eq!(127usize.encode(), &[0x7F][..]);
        assert_eq!(128usize.encode(), &[0x80, 0x01][..]);
        assert_eq!(16383usize.encode(), &[0xFF, 0x7F][..]);
        assert_eq!(16384usize.encode(), &[0x80, 0x80, 0x01][..]);
        assert_eq!(2097151usize.encode(), &[0xFF, 0xFF, 0x7F][..]);
        assert_eq!(2097152usize.encode(), &[0x80, 0x80, 0x80, 0x01][..]);
        assert_eq!(
            (u32::MAX as usize).encode(),
            &[0xFF, 0xFF, 0xFF, 0xFF, 0x0F][..]
        );
    }

    /// ZigZag interleaves signs: 0, -1, 1, -2, 2, ... map to 0, 1, 2, 3, 4, ...
    #[test]
    fn test_zigzag_mapping() {
        let cases: &[(i32, u32)] = &[
            (0, 0),
            (-1, 1),
            (1, 2),
            (-2, 3),
            (2, 4),
            (i32::MAX, u32::MAX - 1),
            (i32::MIN, u32::MAX),
        ];
        for &(signed, unsigned) in cases {
            assert_eq!(signed.as_zigzag(), unsigned, "as_zigzag({signed})");
            assert_eq!(i32::un_zigzag(unsigned), signed, "un_zigzag({unsigned})");
        }
    }

    /// Signed wire format matches the Protocol Buffers `sint32` (ZigZag varint) encoding.
    #[test]
    fn test_signed_conformity() {
        assert_eq!(SInt(0i32).encode(), &[0x00][..]);
        assert_eq!(SInt(-1i32).encode(), &[0x01][..]);
        assert_eq!(SInt(1i32).encode(), &[0x02][..]);
        assert_eq!(SInt(-64i32).encode(), &[0x7F][..]);
        assert_eq!(SInt(64i32).encode(), &[0x80, 0x01][..]);
        assert_eq!(
            SInt(i32::MAX).encode(),
            &[0xFE, 0xFF, 0xFF, 0xFF, 0x0F][..]
        );
        assert_eq!(
            SInt(i32::MIN).encode(),
            &[0xFF, 0xFF, 0xFF, 0xFF, 0x0F][..]
        );
    }

    #[test]
    fn test_all_u16_values() {
        // Exhaustively test all u16 values to ensure size matches encoding
        for i in 0..=u16::MAX {
            let value = i;
            let calculated_size = size(value);

            let mut buf = Vec::new();
            write(value, &mut buf);

            assert_eq!(
                buf.len(),
                calculated_size,
                "Size mismatch for u16 value {value}",
            );

            // Also verify UInt wrapper
            let uint = UInt(value);
            assert_eq!(
                uint.encode_size(),
                buf.len(),
                "UInt encode_size mismatch for value {value}",
            );
        }
    }

    #[test]
    fn test_all_i16_values() {
        // Exhaustively test all i16 values to ensure size matches encoding
        for i in i16::MIN..=i16::MAX {
            let value = i;
            let calculated_size = size_signed(value);

            let mut buf = Vec::new();
            write_signed(value, &mut buf);

            assert_eq!(
                buf.len(),
                calculated_size,
                "Size mismatch for i16 value {value}",
            );

            // Also verify SInt wrapper
            let sint = SInt(value);
            assert_eq!(
                sint.encode_size(),
                buf.len(),
                "SInt encode_size mismatch for value {value}",
            );

            // Verify we can decode it back correctly
            let mut slice = &buf[..];
            let decoded: i16 = read_signed(&mut slice).unwrap();
            assert_eq!(decoded, value, "Decode mismatch for value {value}");
            assert!(
                slice.is_empty(),
                "Buffer not fully consumed for value {value}",
            );
        }
    }

    #[test]
    fn test_exact_bit_boundaries() {
        // Test values with exactly N bits set
        fn test_exact_bits<T: UPrim + TryFrom<u128> + std::fmt::Display>() {
            for bits in 1..=128 {
                // Create a value with exactly 'bits' bits
                // e.g., bits=3 -> 0b111 = 7
                let val = if bits == 128 {
                    u128::MAX
                } else {
                    (1u128 << bits) - 1
                };
                let Ok(value) = T::try_from(val) else {
                    continue;
                };

                // Compute expected size
                let expected_size = (bits as usize).div_ceil(DATA_BITS_PER_BYTE);
                let calculated_size = size(value);
                assert_eq!(
                    calculated_size, expected_size,
                    "Size calculation wrong for {val} with {bits} bits",
                );

                // Compare encoded size
                let mut buf = Vec::new();
                write(value, &mut buf);
                assert_eq!(
                    buf.len(),
                    expected_size,
                    "Encoded size wrong for {val} with {bits} bits",
                );
            }
        }

        test_exact_bits::<u16>();
        test_exact_bits::<u32>();
        test_exact_bits::<u64>();
        test_exact_bits::<u128>();
    }

    #[test]
    fn test_single_bit_boundaries() {
        // Test values with only a single bit set at different positions
        fn test_single_bits<T: UPrim + TryFrom<u128> + std::fmt::Display>() {
            for bit_pos in 0..128 {
                // Create a value with only a single bit set at the given position
                let val = 1u128 << bit_pos;
                let Ok(value) = T::try_from(val) else {
                    continue;
                };

                // Compute expected size
                let expected_size = ((bit_pos + 1) as usize).div_ceil(DATA_BITS_PER_BYTE);
                let calculated_size = size(value);
                assert_eq!(
                    calculated_size, expected_size,
                    "Size wrong for 1<<{bit_pos} = {val}",
                );

                // Compare encoded size
                let mut buf = Vec::new();
                write(value, &mut buf);
                assert_eq!(
                    buf.len(),
                    expected_size,
                    "Encoded size wrong for 1<<{bit_pos} = {val}",
                );
            }
        }

        test_single_bits::<u16>();
        test_single_bits::<u32>();
        test_single_bits::<u64>();
        test_single_bits::<u128>();
    }
}