miraland_program/
serde_varint.rs

1//! Integers that serialize to variable size.
2
3#![allow(clippy::arithmetic_side_effects)]
4use {
5    serde::{
6        de::{Error as _, SeqAccess, Visitor},
7        ser::SerializeTuple,
8        Deserializer, Serializer,
9    },
10    std::{fmt, marker::PhantomData},
11};
12
13pub trait VarInt: Sized {
14    fn visit_seq<'de, A>(seq: A) -> Result<Self, A::Error>
15    where
16        A: SeqAccess<'de>;
17
18    fn serialize<S>(self, serializer: S) -> Result<S::Ok, S::Error>
19    where
20        S: Serializer;
21}
22
23struct VarIntVisitor<T> {
24    phantom: PhantomData<T>,
25}
26
27impl<'de, T> Visitor<'de> for VarIntVisitor<T>
28where
29    T: VarInt,
30{
31    type Value = T;
32
33    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
34        formatter.write_str("a VarInt")
35    }
36
37    fn visit_seq<A>(self, seq: A) -> Result<Self::Value, A::Error>
38    where
39        A: SeqAccess<'de>,
40    {
41        T::visit_seq(seq)
42    }
43}
44
45pub fn serialize<S, T>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
46where
47    T: Copy + VarInt,
48    S: Serializer,
49{
50    (*value).serialize(serializer)
51}
52
53pub fn deserialize<'de, D, T>(deserializer: D) -> Result<T, D::Error>
54where
55    D: Deserializer<'de>,
56    T: VarInt,
57{
58    deserializer.deserialize_tuple(
59        (std::mem::size_of::<T>() * 8 + 6) / 7,
60        VarIntVisitor {
61            phantom: PhantomData,
62        },
63    )
64}
65
66macro_rules! impl_var_int {
67    ($type:ty) => {
68        impl VarInt for $type {
69            fn visit_seq<'de, A>(mut seq: A) -> Result<Self, A::Error>
70            where
71                A: SeqAccess<'de>,
72            {
73                let mut out = 0;
74                let mut shift = 0u32;
75                while shift < <$type>::BITS {
76                    let Some(byte) = seq.next_element::<u8>()? else {
77                        return Err(A::Error::custom("Invalid Sequence"));
78                    };
79                    out |= ((byte & 0x7F) as Self) << shift;
80                    if byte & 0x80 == 0 {
81                        // Last byte should not have been truncated when it was
82                        // shifted to the left above.
83                        if (out >> shift) as u8 != byte {
84                            return Err(A::Error::custom("Last Byte Truncated"));
85                        }
86                        // Last byte can be zero only if there was only one
87                        // byte and the output is also zero.
88                        if byte == 0u8 && (shift != 0 || out != 0) {
89                            return Err(A::Error::custom("Invalid Trailing Zeros"));
90                        }
91                        return Ok(out);
92                    }
93                    shift += 7;
94                }
95                Err(A::Error::custom("Left Shift Overflows"))
96            }
97
98            fn serialize<S>(mut self, serializer: S) -> Result<S::Ok, S::Error>
99            where
100                S: Serializer,
101            {
102                let bits = <$type>::BITS - self.leading_zeros();
103                let num_bytes = ((bits + 6) / 7).max(1) as usize;
104                let mut seq = serializer.serialize_tuple(num_bytes)?;
105                while self >= 0x80 {
106                    let byte = ((self & 0x7F) | 0x80) as u8;
107                    seq.serialize_element(&byte)?;
108                    self >>= 7;
109                }
110                seq.serialize_element(&(self as u8))?;
111                seq.end()
112            }
113        }
114    };
115}
116
117impl_var_int!(u16);
118impl_var_int!(u32);
119impl_var_int!(u64);
120
121#[cfg(test)]
122mod tests {
123    use {crate::short_vec::ShortU16, rand::Rng};
124
125    #[derive(Debug, Eq, PartialEq, Serialize, Deserialize)]
126    struct Dummy {
127        #[serde(with = "super")]
128        a: u32,
129        b: u64,
130        #[serde(with = "super")]
131        c: u64,
132        d: u32,
133    }
134
135    #[test]
136    fn test_serde_varint() {
137        assert_eq!((std::mem::size_of::<u32>() * 8 + 6) / 7, 5);
138        assert_eq!((std::mem::size_of::<u64>() * 8 + 6) / 7, 10);
139        let dummy = Dummy {
140            a: 698,
141            b: 370,
142            c: 146,
143            d: 796,
144        };
145        let bytes = bincode::serialize(&dummy).unwrap();
146        assert_eq!(bytes.len(), 16);
147        let other: Dummy = bincode::deserialize(&bytes).unwrap();
148        assert_eq!(other, dummy);
149    }
150
151    #[test]
152    fn test_serde_varint_zero() {
153        let dummy = Dummy {
154            a: 0,
155            b: 0,
156            c: 0,
157            d: 0,
158        };
159        let bytes = bincode::serialize(&dummy).unwrap();
160        assert_eq!(bytes.len(), 14);
161        let other: Dummy = bincode::deserialize(&bytes).unwrap();
162        assert_eq!(other, dummy);
163    }
164
165    #[test]
166    fn test_serde_varint_max() {
167        let dummy = Dummy {
168            a: u32::MAX,
169            b: u64::MAX,
170            c: u64::MAX,
171            d: u32::MAX,
172        };
173        let bytes = bincode::serialize(&dummy).unwrap();
174        assert_eq!(bytes.len(), 27);
175        let other: Dummy = bincode::deserialize(&bytes).unwrap();
176        assert_eq!(other, dummy);
177    }
178
179    #[test]
180    fn test_serde_varint_rand() {
181        let mut rng = rand::thread_rng();
182        for _ in 0..100_000 {
183            let dummy = Dummy {
184                a: rng.gen::<u32>() >> rng.gen_range(0..u32::BITS),
185                b: rng.gen::<u64>() >> rng.gen_range(0..u64::BITS),
186                c: rng.gen::<u64>() >> rng.gen_range(0..u64::BITS),
187                d: rng.gen::<u32>() >> rng.gen_range(0..u32::BITS),
188            };
189            let bytes = bincode::serialize(&dummy).unwrap();
190            let other: Dummy = bincode::deserialize(&bytes).unwrap();
191            assert_eq!(other, dummy);
192        }
193    }
194
195    #[test]
196    fn test_serde_varint_trailing_zeros() {
197        let buffer = [0x93, 0xc2, 0xa9, 0x8d, 0x0];
198        let out = bincode::deserialize::<Dummy>(&buffer);
199        assert!(out.is_err());
200        assert_eq!(
201            format!("{out:?}"),
202            r#"Err(Custom("Invalid Trailing Zeros"))"#
203        );
204        let buffer = [0x80, 0x0];
205        let out = bincode::deserialize::<Dummy>(&buffer);
206        assert!(out.is_err());
207        assert_eq!(
208            format!("{out:?}"),
209            r#"Err(Custom("Invalid Trailing Zeros"))"#
210        );
211    }
212
213    #[test]
214    fn test_serde_varint_last_byte_truncated() {
215        let buffer = [0xe4, 0xd7, 0x88, 0xf6, 0x6f, 0xd4, 0xb9, 0x59];
216        let out = bincode::deserialize::<Dummy>(&buffer);
217        assert!(out.is_err());
218        assert_eq!(format!("{out:?}"), r#"Err(Custom("Last Byte Truncated"))"#);
219    }
220
221    #[test]
222    fn test_serde_varint_shift_overflow() {
223        let buffer = [0x84, 0xdf, 0x96, 0xfa, 0xef];
224        let out = bincode::deserialize::<Dummy>(&buffer);
225        assert!(out.is_err());
226        assert_eq!(format!("{out:?}"), r#"Err(Custom("Left Shift Overflows"))"#);
227    }
228
229    #[test]
230    fn test_serde_varint_short_buffer() {
231        let buffer = [0x84, 0xdf, 0x96, 0xfa];
232        let out = bincode::deserialize::<Dummy>(&buffer);
233        assert!(out.is_err());
234        assert_eq!(format!("{out:?}"), r#"Err(Io(Kind(UnexpectedEof)))"#);
235    }
236
237    #[test]
238    fn test_serde_varint_fuzz() {
239        let mut rng = rand::thread_rng();
240        let mut buffer = [0u8; 36];
241        let mut num_errors = 0;
242        for _ in 0..200_000 {
243            rng.fill(&mut buffer[..]);
244            match bincode::deserialize::<Dummy>(&buffer) {
245                Err(_) => {
246                    num_errors += 1;
247                }
248                Ok(dummy) => {
249                    let bytes = bincode::serialize(&dummy).unwrap();
250                    assert_eq!(bytes, &buffer[..bytes.len()]);
251                }
252            }
253        }
254        assert!(
255            (3_000..23_000).contains(&num_errors),
256            "num errors: {num_errors}"
257        );
258    }
259
260    #[test]
261    fn test_serde_varint_cross_fuzz() {
262        #[derive(Serialize, Deserialize)]
263        struct U16(#[serde(with = "super")] u16);
264        let mut rng = rand::thread_rng();
265        let mut buffer = [0u8; 16];
266        let mut num_errors = 0;
267        for _ in 0..200_000 {
268            rng.fill(&mut buffer[..]);
269            match bincode::deserialize::<U16>(&buffer) {
270                Err(_) => {
271                    assert!(bincode::deserialize::<ShortU16>(&buffer).is_err());
272                    num_errors += 1;
273                }
274                Ok(k) => {
275                    let bytes = bincode::serialize(&k).unwrap();
276                    assert_eq!(bytes, &buffer[..bytes.len()]);
277                    assert_eq!(bytes, bincode::serialize(&ShortU16(k.0)).unwrap());
278                    assert_eq!(bincode::deserialize::<ShortU16>(&buffer).unwrap().0, k.0);
279                }
280            }
281        }
282        assert!(
283            (30_000..70_000).contains(&num_errors),
284            "num errors: {num_errors}"
285        );
286    }
287}