commonware_codec/
varint.rs

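//! Variable-length integer encoding and decoding.
//!
//! This module implements the variable-length integer ("varint") encoding used by Google's
//! Protocol Buffers. Each byte carries 7 bits of the value, least-significant group first,
//! and uses its high bit to indicate whether more bytes follow. Signed integers are first
//! mapped to unsigned ones with ZigZag encoding so that values of small magnitude stay short.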
5
6use crate::error::Error;
7use bytes::{Buf, BufMut};
8
/// Converts `value` to `u64`, panicking if it does not fit.
///
/// # Panics
///
/// Panics with "Failed to convert to u64" when the conversion fails (for example, a
/// negative signed integer). The conversion error itself is discarded because `T::Error`
/// is not required to implement `Debug`.
fn must_u64<T: TryInto<u64>>(value: T) -> u64 {
    match value.try_into() {
        Ok(converted) => converted,
        Err(_) => panic!("Failed to convert to u64"),
    }
}
14
/// Converts `value` to `i64`, panicking if it does not fit.
///
/// # Panics
///
/// Panics with "Failed to convert to i64" when the conversion fails (for example, a
/// `u64` above `i64::MAX`). The conversion error itself is discarded because `T::Error`
/// is not required to implement `Debug`.
fn must_i64<T: TryInto<i64>>(value: T) -> i64 {
    match value.try_into() {
        Ok(converted) => converted,
        Err(_) => panic!("Failed to convert to i64"),
    }
}
20
21/// Encodes a unsigned 64-bit integer as a varint
22pub fn write<T: TryInto<u64>>(value: T, buf: &mut impl BufMut) {
23    let value = must_u64(value);
24
25    if value < 0x80 {
26        // Fast path for small values (common case for lengths)
27        buf.put_u8(value as u8);
28        return;
29    }
30
31    let mut val = value;
32    while val >= 0x80 {
33        buf.put_u8((val as u8) | 0x80);
34        val >>= 7;
35    }
36    buf.put_u8(val as u8);
37}
38
39/// Decodes a unsigned 64-bit integer from a varint
40pub fn read<T: TryFrom<u64>>(buf: &mut impl Buf) -> Result<T, Error> {
41    let mut result = 0u64;
42    let mut shift = 0;
43
44    // Loop over all the bytes.
45    loop {
46        // Read the next byte.
47        if !buf.has_remaining() {
48            return Err(Error::EndOfBuffer);
49        }
50        let byte = buf.get_u8();
51
52        // If we have read more than 9 bytes, the next byte must be 0 or 1.
53        if shift >= (9 * 7) && byte > 1 {
54            return Err(Error::InvalidVarint);
55        }
56
57        // Write the 7 bits of data to the result.
58        result |= ((byte & 0x7F) as u64) << shift;
59
60        // If the continuation bit is not set, return.
61        if byte & 0x80 == 0 {
62            return result.try_into().map_err(|_| Error::InvalidVarint);
63        }
64
65        // Each byte has 7 bits of data.
66        shift += 7;
67    }
68}
69
/// Calculates the number of bytes needed to encode a value as a varint.
///
/// This is one byte per started group of 7 significant bits, with a minimum of one byte
/// (zero still occupies a single byte). It matches the number of bytes [`write`] emits.
///
/// # Panics
///
/// Panics if `value` cannot be converted to `u64` (for example, a negative signed integer).
pub fn size<T: TryInto<u64>>(value: T) -> usize {
    let value: u64 = value
        .try_into()
        .unwrap_or_else(|_| panic!("Failed to convert to u64"));

    // Number of significant bits; treat zero as needing one bit so it still takes a byte.
    let bits = (u64::BITS - value.leading_zeros()).max(1);

    // Each byte carries 7 payload bits, so round the bit count up to whole bytes.
    ((bits + 6) / 7) as usize
}
86
/// Maps a signed integer onto an unsigned one using ZigZag encoding.
///
/// Values of small magnitude map to small results (0 → 0, -1 → 1, 1 → 2, -2 → 3, ...),
/// so they stay short once varint-encoded.
fn to_u64(value: i64) -> u64 {
    // Arithmetic shift: all ones for negative inputs, all zeros otherwise.
    let sign_mask = (value >> 63) as u64;
    ((value as u64) << 1) ^ sign_mask
}
91
/// Maps a ZigZag-encoded unsigned integer back to the signed integer it represents.
///
/// This is the inverse of the ZigZag mapping: 0 → 0, 1 → -1, 2 → 1, 3 → -2, ...
fn to_i64(value: u64) -> i64 {
    // The low bit carries the sign; turn it into an all-ones mask when set.
    let sign_mask = (value & 1).wrapping_neg();
    ((value >> 1) ^ sign_mask) as i64
}
96
97/// Encodes a signed 64-bit integer as a varint using ZigZag encoding
98pub fn write_i64<T: TryInto<i64>>(value: T, buf: &mut impl BufMut) {
99    let value = must_i64(value);
100    write(to_u64(value), buf);
101}
102
103/// Decodes a signed 64-bit integer from a varint using ZigZag encoding
104pub fn read_i64<T: TryFrom<i64>>(buf: &mut impl Buf) -> Result<T, Error> {
105    let zigzag = read(buf)?;
106    to_i64(zigzag).try_into().map_err(|_| Error::InvalidVarint)
107}
108
109/// Calculates the number of bytes needed to encode a signed integer as a varint
110pub fn size_i64(value: i64) -> usize {
111    size(to_u64(value))
112}
113
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::Error;
    use bytes::Bytes;

    /// Encodes `value`, checks the length against `size`, then decodes it and
    /// asserts the decoded value matches and the whole buffer was consumed.
    fn assert_u64_roundtrip(value: u64) {
        let mut encoded = Vec::new();
        write(value, &mut encoded);
        assert_eq!(encoded.len(), size(value));

        let mut cursor = encoded.as_slice();
        let decoded: u64 = read(&mut cursor).unwrap();
        assert_eq!(decoded, value);
        assert!(cursor.is_empty());
    }

    /// Same as `assert_u64_roundtrip`, but through the ZigZag signed API.
    fn assert_i64_roundtrip(value: i64) {
        let mut encoded = Vec::new();
        write_i64(value, &mut encoded);
        assert_eq!(encoded.len(), size_i64(value));

        let mut cursor = encoded.as_slice();
        let decoded = read_i64::<i64>(&mut cursor).unwrap();
        assert_eq!(decoded, value);
        assert!(cursor.is_empty());
    }

    #[test]
    fn test_varint_encoding() {
        let values: [u64; 16] = [
            0,
            1,
            127,
            128,
            129,
            0xFF,
            0x100,
            0x3FFF,
            0x4000,
            0x1FFFFF,
            0xFFFFFF,
            0x1FFFFFFF,
            0xFFFFFFFF,
            0x1FFFFFFFFFF,
            0xFFFFFFFFFFFFFF,
            u64::MAX,
        ];
        values.iter().copied().for_each(assert_u64_roundtrip);
    }

    #[test]
    fn test_zigzag_encoding() {
        let values: [i64; 15] = [
            0,
            1,
            -1,
            2,
            -2,
            127,
            -127,
            128,
            -128,
            129,
            -129,
            0x7FFFFFFF,
            -0x7FFFFFFF,
            i64::MIN,
            i64::MAX,
        ];
        values.iter().copied().for_each(assert_i64_roundtrip);
    }

    #[test]
    fn test_varint_insufficient_buffer() {
        // A lone continuation byte promises more data that never arrives.
        let mut truncated = Bytes::from_static(&[0x80]);
        let result = read::<u64>(&mut truncated);
        assert!(matches!(result, Err(Error::EndOfBuffer)));
    }

    #[test]
    fn test_varint_invalid() {
        // Nine continuation bytes followed by a tenth byte carrying more than one bit:
        // the value would not fit in 64 bits.
        let mut overflowing =
            Bytes::from_static(&[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x02]);
        let result = read::<u64>(&mut overflowing);
        assert!(matches!(result, Err(Error::InvalidVarint)));
    }
}
201}