commonware_codec/
varint.rs

1//! Variable-length integer encoding and decoding
2//!
3//! # Overview
4//!
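//! This module implements the variable-length integer ("varint") encoding used by Google's
//! Protocol Buffers. Each encoded byte uses:
//! - 7 bits for the value (least-significant group first)
//! - 1 "continuation" bit (the most-significant bit) to indicate if more bytes follow
//!
//! Signed integers are first mapped to unsigned integers with ZigZag encoding. Decoding is
//! strict: overlong encodings and values that overflow the target type are rejected, so every
//! value has exactly one valid encoding.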
9//!
10//! `u8` and `i8` are omitted since those types do not benefit from varint encoding.
11//!
12//! `usize` and `isize` are omitted to prevent behavior from depending on the target architecture.
13//!
14//! # Usage Example
15//!
16//! ```rust
17//! use commonware_codec::{Encode, DecodeExt, varint::{UInt, SInt}};
18//!
19//! // Unsigned example
20//! let one = UInt(42u128).encode();
21//! assert_eq!(one.len(), 1); // 42 fits in a single byte
22//! let decoded: u128 = UInt::decode(one).unwrap().into();
23//! assert_eq!(decoded, 42);
24//!
25//! // Signed example (ZigZag)
26//! let neg = SInt(-3i32).encode();
27//! assert_eq!(neg.len(), 1);
28//! let decoded: i32 = SInt::decode(neg).unwrap().into();
29//! assert_eq!(decoded, -3);
30//! ```
31
32use crate::{EncodeSize, Error, FixedSize, Read, Write};
33use bytes::{Buf, BufMut};
34use sealed::{SPrim, UPrim};
35use std::fmt::Debug;
36
37// ---------- Constants ----------
38
/// Width of a single byte, in bits.
const BITS_PER_BYTE: usize = 8;

/// Payload bits carried by each encoded byte: every bit except the continuation flag.
const DATA_BITS_PER_BYTE: usize = BITS_PER_BYTE - 1;

/// The continuation flag: the most-significant bit of an encoded byte, set when another byte
/// follows.
const CONTINUATION_BIT_MASK: u8 = 1 << DATA_BITS_PER_BYTE;

/// Selects the payload bits of an encoded byte (everything except the continuation flag).
const DATA_BITS_MASK: u8 = !CONTINUATION_BIT_MASK;
51
52// ---------- Traits ----------
53mod sealed {
54    use super::*;
55    use std::ops::{BitOrAssign, Shl, ShrAssign};
56
57    /// A trait for unsigned integer primitives that can be varint encoded.
58    pub trait UPrim:
59        Copy
60        + From<u8>
61        + Sized
62        + FixedSize
63        + ShrAssign<usize>
64        + Shl<usize, Output = Self>
65        + BitOrAssign<Self>
66        + PartialOrd
67        + Debug
68    {
69        /// Returns the number of leading zeros in the integer.
70        fn leading_zeros(self) -> u32;
71
72        /// Returns the least significant byte of the integer.
73        fn as_u8(self) -> u8;
74    }
75
76    // Implements the `UPrim` trait for all unsigned integer types.
77    macro_rules! impl_uint {
78        ($type:ty) => {
79            impl UPrim for $type {
80                #[inline(always)]
81                fn leading_zeros(self) -> u32 {
82                    self.leading_zeros()
83                }
84
85                #[inline(always)]
86                fn as_u8(self) -> u8 {
87                    self as u8
88                }
89            }
90        };
91    }
92    impl_uint!(u16);
93    impl_uint!(u32);
94    impl_uint!(u64);
95    impl_uint!(u128);
96
97    /// A trait for signed integer primitives that can be converted to and from unsigned integer
98    /// primitives of the equivalent size.
99    ///
100    /// When converted to unsigned integers, the encoding is done using ZigZag encoding, which moves the
101    /// sign bit to the least significant bit (shifting all other bits to the left by one). This allows
102    /// for more efficient encoding of numbers that are close to zero, even if they are negative.
103    pub trait SPrim: Copy + Sized + FixedSize + PartialOrd + Debug {
104        /// The unsigned equivalent type of the signed integer.
105        /// This type must be the same size as the signed integer type.
106        type UnsignedEquivalent: UPrim;
107
108        /// Compile-time assertion to ensure that the size of the signed integer is equal to the size of
109        /// the unsigned integer.
110        #[doc(hidden)]
111        const _COMMIT_OP_ASSERT: () =
112            assert!(std::mem::size_of::<Self>() == std::mem::size_of::<Self::UnsignedEquivalent>());
113
114        /// Converts the signed integer to an unsigned integer using ZigZag encoding.
115        fn as_zigzag(&self) -> Self::UnsignedEquivalent;
116
117        /// Converts a (ZigZag'ed) unsigned integer back to a signed integer.
118        fn un_zigzag(value: Self::UnsignedEquivalent) -> Self;
119    }
120
121    // Implements the `SPrim` trait for all signed integer types.
122    macro_rules! impl_sint {
123        ($type:ty, $utype:ty) => {
124            impl SPrim for $type {
125                type UnsignedEquivalent = $utype;
126
127                #[inline]
128                fn as_zigzag(&self) -> $utype {
129                    let shr = std::mem::size_of::<$utype>() * 8 - 1;
130                    ((self << 1) ^ (self >> shr)) as $utype
131                }
132                #[inline]
133                fn un_zigzag(value: $utype) -> Self {
134                    ((value >> 1) as $type) ^ (-((value & 1) as $type))
135                }
136            }
137        };
138    }
139    impl_sint!(i16, u16);
140    impl_sint!(i32, u32);
141    impl_sint!(i64, u64);
142    impl_sint!(i128, u128);
143}
144
145// ---------- Structs ----------
146
147/// An ergonomic wrapper to allow for encoding and decoding of primitive unsigned integers as
148/// varints rather than the default fixed-width integers.
149#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
150pub struct UInt<U: UPrim>(pub U);
151
152// Implements `Into<U>` for `UInt<U>` for all unsigned integer types.
153// This allows for easy conversion from `UInt<U>` to `U` using `.into()`.
154macro_rules! impl_varuint_into {
155    ($($type:ty),+) => {
156        $(
157            impl From<UInt<$type>> for $type {
158                fn from(val: UInt<$type>) -> Self {
159                    val.0
160                }
161            }
162        )+
163    };
164}
165impl_varuint_into!(u16, u32, u64, u128);
166
167impl<U: UPrim> Write for UInt<U> {
168    fn write(&self, buf: &mut impl BufMut) {
169        write(self.0, buf);
170    }
171}
172
173impl<U: UPrim> Read for UInt<U> {
174    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
175        read(buf).map(UInt)
176    }
177}
178
179impl<U: UPrim> EncodeSize for UInt<U> {
180    fn encode_size(&self) -> usize {
181        size(self.0)
182    }
183}
184
185/// An ergonomic wrapper to allow for encoding and decoding of primitive signed integers as
186/// varints rather than the default fixed-width integers.
187#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
188pub struct SInt<S: SPrim>(pub S);
189
190// Implements `Into<U>` for `SInt<U>` for all signed integer types.
191// This allows for easy conversion from `SInt<S>` to `S` using `.into()`.
192macro_rules! impl_varsint_into {
193    ($($type:ty),+) => {
194        $(
195            impl From<SInt<$type>> for $type {
196                fn from(val: SInt<$type>) -> Self {
197                    val.0
198                }
199            }
200        )+
201    };
202}
203impl_varsint_into!(i16, i32, i64, i128);
204
205impl<S: SPrim> Write for SInt<S> {
206    fn write(&self, buf: &mut impl BufMut) {
207        write_signed::<S>(self.0, buf);
208    }
209}
210
211impl<S: SPrim> Read for SInt<S> {
212    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
213        read_signed::<S>(buf).map(SInt)
214    }
215}
216
217impl<S: SPrim> EncodeSize for SInt<S> {
218    fn encode_size(&self) -> usize {
219        size_signed::<S>(self.0)
220    }
221}
222
223// ---------- Helper Functions ----------
224
225/// Encodes an unsigned integer as a varint
226fn write<T: UPrim>(value: T, buf: &mut impl BufMut) {
227    let continuation_threshold = T::from(CONTINUATION_BIT_MASK);
228    if value < continuation_threshold {
229        // Fast path for small values (common case for lengths).
230        // `as_u8()` does not truncate the value or leave a continuation bit.
231        buf.put_u8(value.as_u8());
232        return;
233    }
234
235    let mut val = value;
236    while val >= continuation_threshold {
237        buf.put_u8((val.as_u8()) | CONTINUATION_BIT_MASK);
238        val >>= 7;
239    }
240    buf.put_u8(val.as_u8());
241}
242
243/// Decodes a unsigned integer from a varint.
244///
245/// Returns an error if:
246/// - The varint is invalid (too long or malformed)
247/// - The buffer ends while reading
248fn read<T: UPrim>(buf: &mut impl Buf) -> Result<T, Error> {
249    let max_bits = T::SIZE * BITS_PER_BYTE;
250    let mut result: T = T::from(0);
251    let mut bits_read = 0;
252
253    // Loop over all the bytes.
254    loop {
255        // Read the next byte.
256        if !buf.has_remaining() {
257            return Err(Error::EndOfBuffer);
258        }
259        let byte = buf.get_u8();
260
261        // If this is not the first byte, but the byte is completely zero, we have an invalid
262        // varint. This is because this byte has no data bits and no continuation, so there was no
263        // point in continuing to this byte in the first place. While the output could still result
264        // in a valid value, we ensure that every value has exactly one unique, valid encoding.
265        if byte == 0 && bits_read > 0 {
266            return Err(Error::InvalidVarint(T::SIZE));
267        }
268
269        // If this must be the last byte, check for overflow (i.e. set bits beyond the size of T).
270        // Because the continuation bit is the most-significant bit, this check also happens to
271        // check for an invalid continuation bit.
272        //
273        // If we have reached what must be the last byte, this check prevents continuing to read
274        // from the buffer by ensuring that the conditional (`if byte & CONTINUATION_BIT_MASK == 0`)
275        // always evaluates to true.
276        let remaining_bits = max_bits.checked_sub(bits_read).unwrap();
277        if remaining_bits <= DATA_BITS_PER_BYTE {
278            let relevant_bits = BITS_PER_BYTE - byte.leading_zeros() as usize;
279            if relevant_bits > remaining_bits {
280                return Err(Error::InvalidVarint(T::SIZE));
281            }
282        }
283
284        // Write the 7 bits of data to the result.
285        result |= T::from(byte & DATA_BITS_MASK) << bits_read;
286
287        // If the continuation bit is not set, return.
288        if byte & CONTINUATION_BIT_MASK == 0 {
289            return Ok(result);
290        }
291
292        bits_read += DATA_BITS_PER_BYTE;
293    }
294}
295
296/// Calculates the number of bytes needed to encode an unsigned integer as a varint.
297fn size<T: UPrim>(value: T) -> usize {
298    let total_bits = std::mem::size_of::<T>() * 8;
299    let leading_zeros = value.leading_zeros() as usize;
300    let data_bits = total_bits - leading_zeros;
301    usize::max(1, data_bits.div_ceil(DATA_BITS_PER_BYTE))
302}
303
304/// Encodes a signed integer as a varint using ZigZag encoding.
305fn write_signed<S: SPrim>(value: S, buf: &mut impl BufMut) {
306    write(value.as_zigzag(), buf);
307}
308
309/// Decodes a signed integer from varint ZigZag encoding.
310fn read_signed<S: SPrim>(buf: &mut impl Buf) -> Result<S, Error> {
311    Ok(S::un_zigzag(read(buf)?))
312}
313
314/// Calculates the number of bytes needed to encode a signed integer as a varint.
315fn size_signed<S: SPrim>(value: S) -> usize {
316    size(value.as_zigzag())
317}
318
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{error::Error, DecodeExt, Encode};
    use bytes::Bytes;

    #[test]
    fn test_end_of_buffer() {
        let mut buf: Bytes = Bytes::from_static(&[]);
        assert!(matches!(read::<u32>(&mut buf), Err(Error::EndOfBuffer)));

        let mut buf: Bytes = Bytes::from_static(&[0x80, 0x8F]);
        assert!(matches!(read::<u32>(&mut buf), Err(Error::EndOfBuffer)));

        let mut buf: Bytes = Bytes::from_static(&[0xFF, 0x8F]);
        assert!(matches!(read::<u32>(&mut buf), Err(Error::EndOfBuffer)));
    }

    #[test]
    fn test_overflow() {
        let mut buf: Bytes = Bytes::from_static(&[0xFF, 0xFF, 0xFF, 0xFF, 0x0F]);
        assert_eq!(read::<u32>(&mut buf).unwrap(), u32::MAX);

        let mut buf: Bytes = Bytes::from_static(&[0xFF, 0xFF, 0xFF, 0xFF, 0x1F]);
        assert!(matches!(
            read::<u32>(&mut buf),
            Err(Error::InvalidVarint(u32::SIZE))
        ));

        let mut buf =
            Bytes::from_static(&[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x02]);
        assert!(matches!(
            read::<u64>(&mut buf),
            Err(Error::InvalidVarint(u64::SIZE))
        ));

        // u16: two continuation bytes carry 14 bits, leaving exactly 2 bits for the last byte.
        let mut buf = Bytes::from_static(&[0xFF, 0xFF, 0x03]);
        assert_eq!(read::<u16>(&mut buf).unwrap(), u16::MAX);

        let mut buf = Bytes::from_static(&[0xFF, 0xFF, 0x04]);
        assert!(matches!(
            read::<u16>(&mut buf),
            Err(Error::InvalidVarint(u16::SIZE))
        ));

        // u128: 18 continuation bytes carry 126 bits, leaving exactly 2 bits for the last byte.
        let mut max: Vec<u8> = vec![0xFF; 18];
        max.push(0x03);
        assert_eq!(read::<u128>(&mut max.as_slice()).unwrap(), u128::MAX);

        let mut too_big: Vec<u8> = vec![0xFF; 18];
        too_big.push(0x07);
        assert!(matches!(
            read::<u128>(&mut too_big.as_slice()),
            Err(Error::InvalidVarint(u128::SIZE))
        ));
    }

    #[test]
    fn test_overcontinuation() {
        let mut buf: Bytes = Bytes::from_static(&[0x80, 0x80, 0x80, 0x80, 0x80]);
        let result = read::<u32>(&mut buf);
        assert!(matches!(result, Err(Error::InvalidVarint(u32::SIZE))));
    }

    #[test]
    fn test_zeroed_byte() {
        let mut buf = Bytes::from_static(&[0xFF, 0x00]);
        let result = read::<u64>(&mut buf);
        assert!(matches!(result, Err(Error::InvalidVarint(u64::SIZE))));
    }

    #[test]
    fn test_non_canonical() {
        // Zero padded with a redundant continuation byte.
        let mut buf = Bytes::from_static(&[0x80, 0x00]);
        assert!(matches!(
            read::<u32>(&mut buf),
            Err(Error::InvalidVarint(u32::SIZE))
        ));

        // One encoded in three bytes instead of one.
        let mut buf = Bytes::from_static(&[0x81, 0x80, 0x00]);
        assert!(matches!(
            read::<u64>(&mut buf),
            Err(Error::InvalidVarint(u64::SIZE))
        ));

        // A lone zero byte is the only valid encoding of zero.
        let mut buf = Bytes::from_static(&[0x00]);
        assert_eq!(read::<u16>(&mut buf).unwrap(), 0);
    }

    #[test]
    fn test_exact_encodings() {
        const ENCODINGS: &[(u64, &[u8])] = &[
            (0, &[0x00]),
            (1, &[0x01]),
            (127, &[0x7F]),
            (128, &[0x80, 0x01]),
            (300, &[0xAC, 0x02]),
            (16_383, &[0xFF, 0x7F]),
            (16_384, &[0x80, 0x80, 0x01]),
        ];
        for &(value, expected) in ENCODINGS {
            let mut buf = Vec::new();
            write(value, &mut buf);
            assert_eq!(buf, expected);
            assert_eq!(size(value), expected.len());
        }
    }

    #[test]
    fn test_exact_zigzag_encodings() {
        const ENCODINGS: &[(i32, &[u8])] = &[
            (0, &[0x00]),
            (-1, &[0x01]),
            (1, &[0x02]),
            (-2, &[0x03]),
            (63, &[0x7E]),
            (-64, &[0x7F]),
            (64, &[0x80, 0x01]),
            (-65, &[0x81, 0x01]),
        ];
        for &(value, expected) in ENCODINGS {
            let mut buf = Vec::new();
            write_signed(value, &mut buf);
            assert_eq!(buf, expected);
            assert_eq!(size_signed(value), expected.len());
        }
    }

    #[test]
    fn test_size_boundaries() {
        assert_eq!(size(0u16), 1);
        assert_eq!(size(u16::MAX), 3);
        assert_eq!(size(u32::MAX), 5);
        assert_eq!(size(u64::MAX), 10);
        assert_eq!(size(u128::MAX), 19);
        assert_eq!(size_signed(i128::MIN), 19);
    }

    /// Core round-trip check, generic over any UPrim.
    fn varuint_round_trip<T: UPrim + TryFrom<u128>>() {
        const CASES: &[u128] = &[
            0,
            1,
            127,
            128,
            129,
            0xFF,
            0x100,
            0x3FFF,
            0x4000,
            0x1_FFFF,
            0xFF_FFFF,
            0x1_FF_FF_FF_FF,
            0xFF_FF_FF_FF_FF_FF,
            0x1_FF_FF_FF_FF_FF_FF_FF_FF_FF_FF_FF_FF,
            u16::MAX as u128,
            u32::MAX as u128,
            u64::MAX as u128,
            u128::MAX,
        ];

        for &raw in CASES {
            // skip values that don't fit into T
            let Ok(value) = raw.try_into() else { continue };
            let value: T = value;

            // size matches encoding length
            let mut buf = Vec::new();
            write(value, &mut buf);
            assert_eq!(buf.len(), size(value));

            // decode matches original value
            let mut slice = &buf[..];
            let decoded: T = read(&mut slice).unwrap();
            assert_eq!(decoded, value);
            assert!(slice.is_empty());

            // UInt wrapper
            let encoded = UInt(value).encode();
            assert_eq!(UInt::<T>::decode(encoded).unwrap(), UInt(value));
        }
    }

    #[test]
    fn test_varuint() {
        varuint_round_trip::<u16>();
        varuint_round_trip::<u32>();
        varuint_round_trip::<u64>();
        varuint_round_trip::<u128>();
    }

    fn varsint_round_trip<T: SPrim + TryFrom<i128>>() {
        const CASES: &[i128] = &[
            0,
            1,
            -1,
            2,
            -2,
            127,
            -127,
            128,
            -128,
            129,
            -129,
            0x7FFFFFFF,
            -0x7FFFFFFF,
            0x7FFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF,
            -0x7FFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF,
            i16::MIN as i128,
            i16::MAX as i128,
            i32::MIN as i128,
            i32::MAX as i128,
            i64::MIN as i128,
            i64::MAX as i128,
            i128::MIN,
            i128::MAX,
        ];

        for &raw in CASES {
            // skip values that don't fit into T
            let Ok(value) = raw.try_into() else { continue };
            let value: T = value;

            // size matches encoding length
            let mut buf = Vec::new();
            write_signed(value, &mut buf);
            assert_eq!(buf.len(), size_signed(value));

            // decode matches original value
            let mut slice = &buf[..];
            let decoded: T = read_signed(&mut slice).unwrap();
            assert_eq!(decoded, value);
            assert!(slice.is_empty());

            // SInt wrapper
            let encoded = SInt(value).encode();
            assert_eq!(SInt::<T>::decode(encoded).unwrap(), SInt(value));
        }
    }

    #[test]
    fn test_varsint() {
        varsint_round_trip::<i16>();
        varsint_round_trip::<i32>();
        varsint_round_trip::<i64>();
        varsint_round_trip::<i128>();
    }

    #[test]
    fn test_varuint_into() {
        let v32: u32 = 0x1_FFFF;
        let out32: u32 = UInt(v32).into();
        assert_eq!(v32, out32);

        let v64: u64 = 0x1_FF_FF_FF_FF;
        let out64: u64 = UInt(v64).into();
        assert_eq!(v64, out64);
    }

    #[test]
    fn test_varsint_into() {
        let s32: i32 = -123_456;
        let out32: i32 = SInt(s32).into();
        assert_eq!(s32, out32);

        let s64: i64 = 987_654_321;
        let out64: i64 = SInt(s64).into();
        assert_eq!(s64, out64);
    }
}