1#![allow(clippy::integer_arithmetic)]
4use {
5 serde::{
6 de::{Error as _, SeqAccess, Visitor},
7 ser::SerializeTuple,
8 Deserializer, Serializer,
9 },
10 std::{fmt, marker::PhantomData},
11};
12
13pub trait VarInt: Sized {
14 fn visit_seq<'de, A>(seq: A) -> Result<Self, A::Error>
15 where
16 A: SeqAccess<'de>;
17
18 fn serialize<S>(self, serializer: S) -> Result<S::Ok, S::Error>
19 where
20 S: Serializer;
21}
22
23struct VarIntVisitor<T> {
24 phantom: PhantomData<T>,
25}
26
27impl<'de, T> Visitor<'de> for VarIntVisitor<T>
28where
29 T: VarInt,
30{
31 type Value = T;
32
33 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
34 formatter.write_str("a VarInt")
35 }
36
37 fn visit_seq<A>(self, seq: A) -> Result<Self::Value, A::Error>
38 where
39 A: SeqAccess<'de>,
40 {
41 T::visit_seq(seq)
42 }
43}
44
45pub fn serialize<S, T>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
46where
47 T: Copy + VarInt,
48 S: Serializer,
49{
50 (*value).serialize(serializer)
51}
52
53pub fn deserialize<'de, D, T>(deserializer: D) -> Result<T, D::Error>
54where
55 D: Deserializer<'de>,
56 T: VarInt,
57{
58 deserializer.deserialize_tuple(
59 (std::mem::size_of::<T>() * 8 + 6) / 7,
60 VarIntVisitor {
61 phantom: PhantomData::default(),
62 },
63 )
64}
65
66macro_rules! impl_var_int {
67 ($type:ty) => {
68 impl VarInt for $type {
69 fn visit_seq<'de, A>(mut seq: A) -> Result<Self, A::Error>
70 where
71 A: SeqAccess<'de>,
72 {
73 let mut out = 0;
74 let mut shift = 0u32;
75 while shift < <$type>::BITS {
76 let byte = match seq.next_element::<u8>()? {
77 None => return Err(A::Error::custom("Invalid Sequence")),
78 Some(byte) => byte,
79 };
80 out |= ((byte & 0x7F) as Self) << shift;
81 if byte & 0x80 == 0 {
82 if (out >> shift) as u8 != byte {
85 return Err(A::Error::custom("Last Byte Truncated"));
86 }
87 if byte == 0u8 && (shift != 0 || out != 0) {
90 return Err(A::Error::custom("Invalid Trailing Zeros"));
91 }
92 return Ok(out);
93 }
94 shift += 7;
95 }
96 Err(A::Error::custom("Left Shift Overflows"))
97 }
98
99 fn serialize<S>(mut self, serializer: S) -> Result<S::Ok, S::Error>
100 where
101 S: Serializer,
102 {
103 let bits = <$type>::BITS - self.leading_zeros();
104 let num_bytes = ((bits + 6) / 7).max(1) as usize;
105 let mut seq = serializer.serialize_tuple(num_bytes)?;
106 while self >= 0x80 {
107 let byte = ((self & 0x7F) | 0x80) as u8;
108 seq.serialize_element(&byte)?;
109 self >>= 7;
110 }
111 seq.serialize_element(&(self as u8))?;
112 seq.end()
113 }
114 }
115 };
116}
117
118impl_var_int!(u32);
119impl_var_int!(u64);
120
#[cfg(test)]
mod tests {
    use rand::Rng;

    #[derive(Debug, Eq, PartialEq, Serialize, Deserialize)]
    struct Dummy {
        #[serde(with = "super")]
        a: u32,
        b: u64,
        #[serde(with = "super")]
        c: u64,
        d: u32,
    }

    #[test]
    fn test_serde_varint() {
        assert_eq!((std::mem::size_of::<u32>() * 8 + 6) / 7, 5);
        assert_eq!((std::mem::size_of::<u64>() * 8 + 6) / 7, 10);
        let dummy = Dummy {
            a: 698,
            b: 370,
            c: 146,
            d: 796,
        };
        let bytes = bincode::serialize(&dummy).unwrap();
        assert_eq!(bytes.len(), 16);
        let other: Dummy = bincode::deserialize(&bytes).unwrap();
        assert_eq!(other, dummy);
    }

    #[test]
    fn test_serde_varint_zero() {
        let dummy = Dummy {
            a: 0,
            b: 0,
            c: 0,
            d: 0,
        };
        let bytes = bincode::serialize(&dummy).unwrap();
        assert_eq!(bytes.len(), 14);
        let other: Dummy = bincode::deserialize(&bytes).unwrap();
        assert_eq!(other, dummy);
    }

    #[test]
    fn test_serde_varint_max() {
        let dummy = Dummy {
            a: u32::MAX,
            b: u64::MAX,
            c: u64::MAX,
            d: u32::MAX,
        };
        let bytes = bincode::serialize(&dummy).unwrap();
        assert_eq!(bytes.len(), 27);
        let other: Dummy = bincode::deserialize(&bytes).unwrap();
        assert_eq!(other, dummy);
    }

    /// Pins the exact bytes on the wire rather than only round-tripping, so
    /// a change to the group order or continuation-bit layout that kept the
    /// encoder and decoder mutually consistent is still caught.
    #[test]
    fn test_serde_varint_wire_format() {
        let dummy = Dummy {
            a: 300,
            b: 1,
            c: 16_384,
            d: 2,
        };
        let bytes = bincode::serialize(&dummy).unwrap();
        let expected: [u8; 17] = [
            0xac, 0x02, // a = 300: low group 0x2c | 0x80, then 0x02
            0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // b = 1 (plain u64)
            0x80, 0x80, 0x01, // c = 1 << 14: two empty groups, then 0x01
            0x02, 0x00, 0x00, 0x00, // d = 2 (plain u32)
        ];
        assert_eq!(bytes, expected);
        assert_eq!(bincode::deserialize::<Dummy>(&bytes).unwrap(), dummy);

        let dummy = Dummy {
            a: u32::MAX,
            b: u64::MAX,
            c: u64::MAX,
            d: u32::MAX,
        };
        // u32::MAX: four full groups, then the remaining 4 bits.
        let mut expected: Vec<u8> = vec![0xff, 0xff, 0xff, 0xff, 0x0f];
        expected.extend_from_slice(&[0xff; 8]);
        // u64::MAX: nine full groups, then the single remaining bit.
        expected.extend_from_slice(&[0xff; 9]);
        expected.push(0x01);
        expected.extend_from_slice(&[0xff; 4]);
        let bytes = bincode::serialize(&dummy).unwrap();
        assert_eq!(bytes, expected);
        assert_eq!(bincode::deserialize::<Dummy>(&bytes).unwrap(), dummy);
    }

    #[test]
    fn test_serde_varint_rand() {
        let mut rng = rand::thread_rng();
        for _ in 0..100_000 {
            let dummy = Dummy {
                a: rng.gen::<u32>() >> rng.gen_range(0, u32::BITS),
                b: rng.gen::<u64>() >> rng.gen_range(0, u64::BITS),
                c: rng.gen::<u64>() >> rng.gen_range(0, u64::BITS),
                d: rng.gen::<u32>() >> rng.gen_range(0, u32::BITS),
            };
            let bytes = bincode::serialize(&dummy).unwrap();
            let other: Dummy = bincode::deserialize(&bytes).unwrap();
            assert_eq!(other, dummy);
        }
    }

    #[test]
    fn test_serde_varint_trailing_zeros() {
        let buffer = [0x93, 0xc2, 0xa9, 0x8d, 0x0];
        let out = bincode::deserialize::<Dummy>(&buffer);
        assert!(out.is_err());
        assert_eq!(
            format!("{out:?}"),
            r#"Err(Custom("Invalid Trailing Zeros"))"#
        );
        let buffer = [0x80, 0x0];
        let out = bincode::deserialize::<Dummy>(&buffer);
        assert!(out.is_err());
        assert_eq!(
            format!("{out:?}"),
            r#"Err(Custom("Invalid Trailing Zeros"))"#
        );
    }

    #[test]
    fn test_serde_varint_last_byte_truncated() {
        let buffer = [0xe4, 0xd7, 0x88, 0xf6, 0x6f, 0xd4, 0xb9, 0x59];
        let out = bincode::deserialize::<Dummy>(&buffer);
        assert!(out.is_err());
        assert_eq!(format!("{out:?}"), r#"Err(Custom("Last Byte Truncated"))"#);
    }

    #[test]
    fn test_serde_varint_shift_overflow() {
        let buffer = [0x84, 0xdf, 0x96, 0xfa, 0xef];
        let out = bincode::deserialize::<Dummy>(&buffer);
        assert!(out.is_err());
        assert_eq!(format!("{out:?}"), r#"Err(Custom("Left Shift Overflows"))"#);
    }

    #[test]
    fn test_serde_varint_short_buffer() {
        let buffer = [0x84, 0xdf, 0x96, 0xfa];
        let out = bincode::deserialize::<Dummy>(&buffer);
        assert!(out.is_err());
        assert_eq!(format!("{out:?}"), r#"Err(Io(Kind(UnexpectedEof)))"#);
    }

    #[test]
    fn test_serde_varint_fuzz() {
        let mut rng = rand::thread_rng();
        let mut buffer = [0u8; 36];
        let mut num_errors = 0;
        for _ in 0..200_000 {
            rng.fill(&mut buffer[..]);
            match bincode::deserialize::<Dummy>(&buffer) {
                Err(_) => {
                    num_errors += 1;
                }
                Ok(dummy) => {
                    let bytes = bincode::serialize(&dummy).unwrap();
                    assert_eq!(bytes, &buffer[..bytes.len()]);
                }
            }
        }
        assert!(num_errors > 2_000);
    }
}