cbe_program/
serde_varint.rs

1//! Integers that serialize to variable size.
2
3#![allow(clippy::integer_arithmetic)]
4use {
5    serde::{
6        de::{Error as _, SeqAccess, Visitor},
7        ser::SerializeTuple,
8        Deserializer, Serializer,
9    },
10    std::{fmt, marker::PhantomData},
11};
12
13pub trait VarInt: Sized {
14    fn visit_seq<'de, A>(seq: A) -> Result<Self, A::Error>
15    where
16        A: SeqAccess<'de>;
17
18    fn serialize<S>(self, serializer: S) -> Result<S::Ok, S::Error>
19    where
20        S: Serializer;
21}
22
23struct VarIntVisitor<T> {
24    phantom: PhantomData<T>,
25}
26
27impl<'de, T> Visitor<'de> for VarIntVisitor<T>
28where
29    T: VarInt,
30{
31    type Value = T;
32
33    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
34        formatter.write_str("a VarInt")
35    }
36
37    fn visit_seq<A>(self, seq: A) -> Result<Self::Value, A::Error>
38    where
39        A: SeqAccess<'de>,
40    {
41        T::visit_seq(seq)
42    }
43}
44
45pub fn serialize<S, T>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
46where
47    T: Copy + VarInt,
48    S: Serializer,
49{
50    (*value).serialize(serializer)
51}
52
53pub fn deserialize<'de, D, T>(deserializer: D) -> Result<T, D::Error>
54where
55    D: Deserializer<'de>,
56    T: VarInt,
57{
58    deserializer.deserialize_tuple(
59        (std::mem::size_of::<T>() * 8 + 6) / 7,
60        VarIntVisitor {
61            phantom: PhantomData::default(),
62        },
63    )
64}
65
66macro_rules! impl_var_int {
67    ($type:ty) => {
68        impl VarInt for $type {
69            fn visit_seq<'de, A>(mut seq: A) -> Result<Self, A::Error>
70            where
71                A: SeqAccess<'de>,
72            {
73                let mut out = 0;
74                let mut shift = 0u32;
75                while shift < <$type>::BITS {
76                    let byte = match seq.next_element::<u8>()? {
77                        None => return Err(A::Error::custom("Invalid Sequence")),
78                        Some(byte) => byte,
79                    };
80                    out |= ((byte & 0x7F) as Self) << shift;
81                    if byte & 0x80 == 0 {
82                        // Last byte should not have been truncated when it was
83                        // shifted to the left above.
84                        if (out >> shift) as u8 != byte {
85                            return Err(A::Error::custom("Last Byte Truncated"));
86                        }
87                        // Last byte can be zero only if there was only one
88                        // byte and the output is also zero.
89                        if byte == 0u8 && (shift != 0 || out != 0) {
90                            return Err(A::Error::custom("Invalid Trailing Zeros"));
91                        }
92                        return Ok(out);
93                    }
94                    shift += 7;
95                }
96                Err(A::Error::custom("Left Shift Overflows"))
97            }
98
99            fn serialize<S>(mut self, serializer: S) -> Result<S::Ok, S::Error>
100            where
101                S: Serializer,
102            {
103                let bits = <$type>::BITS - self.leading_zeros();
104                let num_bytes = ((bits + 6) / 7).max(1) as usize;
105                let mut seq = serializer.serialize_tuple(num_bytes)?;
106                while self >= 0x80 {
107                    let byte = ((self & 0x7F) | 0x80) as u8;
108                    seq.serialize_element(&byte)?;
109                    self >>= 7;
110                }
111                seq.serialize_element(&(self as u8))?;
112                seq.end()
113            }
114        }
115    };
116}
117
118impl_var_int!(u32);
119impl_var_int!(u64);
120
#[cfg(test)]
mod tests {
    use rand::Rng;

    /// Mixes varint-encoded fields (`a`, `c`) with default-encoded fields
    /// (`b`, `d`) so both encodings are exercised side by side.
    #[derive(Debug, Eq, PartialEq, Serialize, Deserialize)]
    struct Dummy {
        #[serde(with = "super")]
        a: u32,
        b: u64,
        #[serde(with = "super")]
        c: u64,
        d: u32,
    }

    /// Serializes `value`, optionally checks the encoded length, then
    /// deserializes and asserts the result equals the input.
    fn check_round_trip(value: &Dummy, expected_len: Option<usize>) {
        let encoded = bincode::serialize(value).unwrap();
        if let Some(len) = expected_len {
            assert_eq!(encoded.len(), len);
        }
        let decoded: Dummy = bincode::deserialize(&encoded).unwrap();
        assert_eq!(&decoded, value);
    }

    /// Asserts that decoding `input` fails and that the debug rendering of
    /// the result equals `expected`.
    fn check_decode_failure(input: &[u8], expected: &str) {
        let out = bincode::deserialize::<Dummy>(input);
        assert!(out.is_err());
        assert_eq!(format!("{out:?}"), expected);
    }

    #[test]
    fn test_serde_varint() {
        // Upper bounds on the encoded length of each supported width.
        assert_eq!((std::mem::size_of::<u32>() * 8 + 6) / 7, 5);
        assert_eq!((std::mem::size_of::<u64>() * 8 + 6) / 7, 10);
        let sample = Dummy {
            a: 698,
            b: 370,
            c: 146,
            d: 796,
        };
        check_round_trip(&sample, Some(16));
    }

    #[test]
    fn test_serde_varint_zero() {
        let zeros = Dummy {
            a: 0,
            b: 0,
            c: 0,
            d: 0,
        };
        check_round_trip(&zeros, Some(14));
    }

    #[test]
    fn test_serde_varint_max() {
        let maxima = Dummy {
            a: u32::MAX,
            b: u64::MAX,
            c: u64::MAX,
            d: u32::MAX,
        };
        check_round_trip(&maxima, Some(27));
    }

    #[test]
    fn test_serde_varint_rand() {
        let mut rng = rand::thread_rng();
        for _ in 0..100_000 {
            // The random right shift spreads samples over every magnitude
            // instead of clustering near the top of the range.
            let sample = Dummy {
                a: rng.gen::<u32>() >> rng.gen_range(0, u32::BITS),
                b: rng.gen::<u64>() >> rng.gen_range(0, u64::BITS),
                c: rng.gen::<u64>() >> rng.gen_range(0, u64::BITS),
                d: rng.gen::<u32>() >> rng.gen_range(0, u32::BITS),
            };
            check_round_trip(&sample, None);
        }
    }

    #[test]
    fn test_serde_varint_trailing_zeros() {
        check_decode_failure(
            &[0x93, 0xc2, 0xa9, 0x8d, 0x0],
            r#"Err(Custom("Invalid Trailing Zeros"))"#,
        );
        check_decode_failure(&[0x80, 0x0], r#"Err(Custom("Invalid Trailing Zeros"))"#);
    }

    #[test]
    fn test_serde_varint_last_byte_truncated() {
        check_decode_failure(
            &[0xe4, 0xd7, 0x88, 0xf6, 0x6f, 0xd4, 0xb9, 0x59],
            r#"Err(Custom("Last Byte Truncated"))"#,
        );
    }

    #[test]
    fn test_serde_varint_shift_overflow() {
        check_decode_failure(
            &[0x84, 0xdf, 0x96, 0xfa, 0xef],
            r#"Err(Custom("Left Shift Overflows"))"#,
        );
    }

    #[test]
    fn test_serde_varint_short_buffer() {
        check_decode_failure(
            &[0x84, 0xdf, 0x96, 0xfa],
            r#"Err(Io(Kind(UnexpectedEof)))"#,
        );
    }

    #[test]
    fn test_serde_varint_fuzz() {
        let mut rng = rand::thread_rng();
        let mut buffer = [0u8; 36];
        let mut num_errors = 0;
        for _ in 0..200_000 {
            rng.fill(&mut buffer[..]);
            if let Ok(dummy) = bincode::deserialize::<Dummy>(&buffer) {
                // Anything that decodes must re-encode to exactly the bytes
                // that were consumed from the buffer.
                let bytes = bincode::serialize(&dummy).unwrap();
                assert_eq!(bytes, &buffer[..bytes.len()]);
            } else {
                num_errors += 1;
            }
        }
        assert!(num_errors > 2_000);
    }
}