vlqencoding/
lib.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 *
4 * This source code is licensed under the MIT license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7
8//! VLQ (Variable-length quantity) encoding.
9
10use std::io;
11use std::io::Read;
12use std::io::Write;
13use std::mem::size_of;
14
/// Serializes integers as VLQ (variable-length quantity) bytes.
///
/// This crate implements the trait for every [`std::io::Write`] type and for all
/// primitive integers up to 64 bits wide, including `usize` and `isize`.
pub trait VLQEncode<T> {
    /// Encodes `value` as VLQ and writes the resulting bytes to `self`.
    ///
    /// The value is emitted seven bits at a time, least-significant group first.
    /// Every byte except the last has its high bit set to signal that more bytes
    /// follow, so small values occupy a single byte.
    ///
    /// # Errors
    ///
    /// The implementations in this crate return whatever error the underlying
    /// writer reports.
    ///
    /// # Examples
    ///
    /// Unsigned values are appended to the stream one after another:
    ///
    /// ```
    /// use vlqencoding::VLQEncode;
    /// let mut out = vec![];
    ///
    /// out.write_vlq(120u8).expect("writing an encoded u8 to a vec should work");
    /// assert_eq!(out, vec![120]);
    ///
    /// out.write_vlq(22742734291u64).expect("writing an encoded u64 to a vec should work");
    /// assert_eq!(out, vec![120, 211, 171, 202, 220, 84]);
    /// ```
    ///
    /// Signed values are zig-zag mapped to unsigned ones before being encoded, which
    /// keeps negative numbers of small magnitude short:
    ///
    /// ```
    /// use vlqencoding::VLQEncode;
    /// let mut out = vec![];
    ///
    /// out.write_vlq(-3i8).expect("writing an encoded i8 to a vec should work");
    /// assert_eq!(out, vec![5]);
    ///
    /// out.write_vlq(1000i16).expect("writing an encoded i16 to a vec should work");
    /// assert_eq!(out, vec![5, 208, 15]);
    /// ```
    fn write_vlq(&mut self, value: T) -> io::Result<()>;
}
50
/// Deserializes VLQ (variable-length quantity) bytes from a stream into integers.
///
/// This crate implements the trait for every [`std::io::Read`] type and for all
/// primitive integers up to 64 bits wide, including `usize` and `isize`.
pub trait VLQDecode<T> {
    /// Reads one VLQ-encoded value from `self` and decodes it into `T`.
    ///
    /// Bytes are consumed up to and including the first byte whose high bit is clear.
    ///
    /// # Errors
    ///
    /// The implementations in this crate return an error of kind
    /// [`std::io::ErrorKind::InvalidData`] when the encoded value does not fit in `T`,
    /// and pass through any error from the underlying reader (for example
    /// `UnexpectedEof` when the stream ends in the middle of a value).
    ///
    /// A failed read still consumes bytes from the stream; the examples below seek
    /// back before retrying.
    ///
    /// # Examples
    ///
    /// ```
    /// use vlqencoding::VLQDecode;
    /// use std::io::{Cursor,Seek,SeekFrom,ErrorKind};
    ///
    /// let mut c = Cursor::new(vec![120u8, 211, 171, 202, 220, 84]);
    ///
    /// let x: Result<u8, _> = c.read_vlq();
    /// assert_eq!(x.unwrap(), 120u8);
    ///
    /// // The remaining bytes encode a value that is too large for a u16.
    /// let x: Result<u16, _> = c.read_vlq();
    /// assert_eq!(x.unwrap_err().kind(), ErrorKind::InvalidData);
    ///
    /// c.seek(SeekFrom::Start(1)).expect("seek should work");
    /// let x: Result<u64, _> = c.read_vlq();
    /// assert_eq!(x.unwrap(), 22742734291u64);
    /// ```
    ///
    /// Signed integers are decoded via zig-zag:
    ///
    /// ```
    /// use vlqencoding::VLQDecode;
    /// use std::io::{Cursor,Seek,SeekFrom,ErrorKind};
    ///
    /// let mut c = Cursor::new(vec![5u8, 208, 15]);
    ///
    /// let x: Result<i8, _> = c.read_vlq();
    /// assert_eq!(x.unwrap(), -3i8);
    ///
    /// // The next value (1000) does not fit in an i8.
    /// let x: Result<i8, _> = c.read_vlq();
    /// assert_eq!(x.unwrap_err().kind(), ErrorKind::InvalidData);
    ///
    /// c.seek(SeekFrom::Start(1)).expect("seek should work");
    /// let x: Result<i32, _> = c.read_vlq();
    /// assert_eq!(x.unwrap(), 1000i32);
    /// ```
    fn read_vlq(&mut self) -> io::Result<T>;
}
93
/// Decodes VLQ (variable-length quantity) integers from an in-memory byte buffer.
///
/// This crate implements the trait for every `AsRef<[u8]>` type and for all
/// primitive integers up to 64 bits wide, including `usize` and `isize`.
pub trait VLQDecodeAt<T> {
    /// Decodes the VLQ-encoded value that starts at byte `offset` of `self`.
    ///
    /// On success the result is `Ok((decoded_integer, bytes_read))`, so the next
    /// value starts at `offset + bytes_read`.
    ///
    /// This is the counterpart of `VLQDecode::read_vlq` for immutable
    /// `AsRef<[u8]>` buffers rather than mutable `io::Read` streams.
    ///
    /// # Errors
    ///
    /// The implementations in this crate return an error of kind
    /// [`std::io::ErrorKind::InvalidData`] when the encoded value does not fit in `T`,
    /// or when the buffer ends before the value is complete. The latter includes an
    /// `offset` at or beyond the end of the buffer.
    ///
    /// # Examples
    ///
    /// ```
    /// use vlqencoding::VLQDecodeAt;
    /// use std::io::ErrorKind;
    ///
    /// let c = &[120u8, 211, 171, 202, 220, 84, 255];
    ///
    /// let x: Result<(u8, _), _> = c.read_vlq_at(0);
    /// assert_eq!(x.unwrap(), (120u8, 1));
    ///
    /// let x: Result<(u64, _), _> = c.read_vlq_at(1);
    /// assert_eq!(x.unwrap(), (22742734291u64, 5));
    ///
    /// // The last byte announces a continuation that never comes.
    /// let x: Result<(u64, _), _> = c.read_vlq_at(6);
    /// assert_eq!(x.unwrap_err().kind(), ErrorKind::InvalidData);
    ///
    /// // The offset is past the end of the buffer.
    /// let x: Result<(u64, _), _> = c.read_vlq_at(7);
    /// assert_eq!(x.unwrap_err().kind(), ErrorKind::InvalidData);
    /// ```
    fn read_vlq_at(&self, offset: usize) -> io::Result<(T, usize)>;
}
124
// Implements `VLQEncode`, `VLQDecode` and `VLQDecodeAt` for one unsigned primitive type.
//
// Wire format: the value is split into 7-bit groups, least-significant group first.
// Every byte except the last has its high bit (128) set to signal that more bytes follow.
macro_rules! impl_unsigned_primitive {
    ($T: ident) => {
        impl<W: Write + ?Sized> VLQEncode<$T> for W {
            /// Encodes `value` and hands the bytes to the writer in a single `write_all`
            /// call. Any error from the writer is returned as-is.
            fn write_vlq(&mut self, value: $T) -> io::Result<()> {
                // Longest possible encoding: one byte per 7 bits of `$T`, rounded up.
                const MAX_LEN: usize = (size_of::<$T>() * 8 + 6) / 7;
                let mut encoded = [0u8; MAX_LEN];
                let mut len = 0;
                let mut rest = value;
                loop {
                    let low = (rest & 127) as u8;
                    rest >>= 7;
                    let more = rest != 0;
                    encoded[len] = if more { low | 128 } else { low };
                    len += 1;
                    if !more {
                        break;
                    }
                }
                // Writing the assembled bytes at once avoids one writer call per byte,
                // which matters for unbuffered writers.
                self.write_all(&encoded[..len])
            }
        }

        impl<R: Read + ?Sized> VLQDecode<$T> for R {
            /// Reads bytes one at a time until a byte without the continuation bit.
            ///
            /// Returns `InvalidData` if the value overflows `$T`; errors from the
            /// reader (such as `UnexpectedEof`) are passed through.
            fn read_vlq(&mut self) -> io::Result<$T> {
                let mut one = [0u8];
                let mut value: $T = 0;
                // Weight of the current 7-bit group: 1, 128, 128^2, ...
                let mut base: $T = 1;
                loop {
                    self.read_exact(&mut one)?;
                    let [byte] = one;
                    value = $T::from(byte & 127)
                        .checked_mul(base)
                        .and_then(|v| v.checked_add(value))
                        .ok_or(io::ErrorKind::InvalidData)?;
                    if byte & 128 == 0 {
                        return Ok(value);
                    }
                    // A continuation whose group weight cannot be represented in `$T`
                    // is rejected before the next byte is read.
                    base = base.checked_mul(128).ok_or(io::ErrorKind::InvalidData)?;
                }
            }
        }

        impl<R: AsRef<[u8]>> VLQDecodeAt<$T> for R {
            /// Decodes the value starting at `offset`, returning it together with the
            /// number of bytes it occupied.
            ///
            /// Returns `InvalidData` if the value overflows `$T` or the buffer ends
            /// before a byte without the continuation bit is found.
            fn read_vlq_at(&self, offset: usize) -> io::Result<($T, usize)> {
                let bytes: &[u8] = self.as_ref();
                // An offset past the end yields no slice; an offset exactly at the end
                // yields an empty one. Both end up as `InvalidData`.
                let tail = bytes.get(offset..).ok_or(io::ErrorKind::InvalidData)?;
                let mut value: $T = 0;
                let mut base: $T = 1;
                for (index, &byte) in tail.iter().enumerate() {
                    value = $T::from(byte & 127)
                        .checked_mul(base)
                        .and_then(|v| v.checked_add(value))
                        .ok_or(io::ErrorKind::InvalidData)?;
                    if byte & 128 == 0 {
                        return Ok((value, index + 1));
                    }
                    base = base.checked_mul(128).ok_or(io::ErrorKind::InvalidData)?;
                }
                // Ran out of input while the continuation bit was still set.
                Err(io::ErrorKind::InvalidData.into())
            }
        }
    };
}
201
202impl_unsigned_primitive!(usize);
203impl_unsigned_primitive!(u64);
204impl_unsigned_primitive!(u32);
205impl_unsigned_primitive!(u16);
206impl_unsigned_primitive!(u8);
207
// Implements the three VLQ traits for a signed type `$T` by mapping it onto the unsigned
// type `$U` of the same width with zig-zag encoding (0, -1, 1, -2, ... => 0, 1, 2, 3, ...)
// and delegating to the unsigned implementation.
macro_rules! impl_signed_primitive {
    ($T: ty, $U: ty) => {
        impl<W: Write + ?Sized> VLQEncode<$T> for W {
            fn write_vlq(&mut self, v: $T) -> io::Result<()> {
                // The arithmetic right shift spreads the sign bit over the whole word:
                // all ones for a negative `v`, all zeros otherwise.
                let sign_mask = v >> (size_of::<$U>() * 8 - 1);
                let zigzag = ((v << 1) ^ sign_mask) as $U;
                VLQEncode::<$U>::write_vlq(self, zigzag)
            }
        }

        impl<R: Read + ?Sized> VLQDecode<$T> for R {
            fn read_vlq(&mut self) -> io::Result<$T> {
                let n = VLQDecode::<$U>::read_vlq(self)?;
                // The lowest bit holds the sign, the remaining bits the magnitude.
                let magnitude = (n >> 1) as $T;
                let sign_mask = -((n & 1) as $T);
                Ok(magnitude ^ sign_mask)
            }
        }

        impl<R: AsRef<[u8]>> VLQDecodeAt<$T> for R {
            fn read_vlq_at(&self, offset: usize) -> io::Result<($T, usize)> {
                let (n, size) = VLQDecodeAt::<$U>::read_vlq_at(self, offset)?;
                let magnitude = (n >> 1) as $T;
                let sign_mask = -((n & 1) as $T);
                Ok((magnitude ^ sign_mask, size))
            }
        }
    };
}
230
231impl_signed_primitive!(isize, usize);
232impl_signed_primitive!(i64, u64);
233impl_signed_primitive!(i32, u32);
234impl_signed_primitive!(i16, u16);
235impl_signed_primitive!(i8, u8);
236
#[cfg(test)]
mod tests {
    use std::io;
    use std::io::Cursor;
    use std::io::Seek;
    use std::io::SeekFrom;

    use quickcheck::quickcheck;

    use super::*;

    /// Encodes `$N`, decodes it again through both `read_vlq_at` and `read_vlq`, and
    /// evaluates to `true` when both decoders return the input and agree on how many
    /// bytes were consumed.
    macro_rules! check_round_trip {
        ($N: expr) => {{
            let input = $N;
            let mut bytes = vec![];
            bytes.write_vlq(input).expect("write");

            // Seeding `via_slice` and `via_stream` with `input` tells the compiler which
            // integer type `read_vlq_at` and `read_vlq` should decode to.
            #[allow(unused_assignments)]
            let mut via_slice = input;
            let (decoded, consumed) = bytes.read_vlq_at(0).unwrap();
            via_slice = decoded;

            let mut cursor = Cursor::new(bytes);
            #[allow(unused_assignments)]
            let mut via_stream = input;
            via_stream = cursor.read_vlq().unwrap();

            via_stream == input && via_slice == input && consumed == cursor.position() as usize
        }};
    }

    #[test]
    fn test_round_trip_manual() {
        // Every power of two with its two neighbours, two arbitrary constants, and the
        // bitwise complement of each of those values.
        let around_powers =
            (0..64).flat_map(|bit| [1u64 << bit, (1 << bit) + 1, (1 << bit) - 1]);
        let arbitrary = [0xb3a73ce2ff2, 0xab54a98ceb1f0ad2];
        for i in around_powers.chain(arbitrary).flat_map(|n| [n, !n]) {
            assert!(check_round_trip!(i as i8));
            assert!(check_round_trip!(i as i16));
            assert!(check_round_trip!(i as i32));
            assert!(check_round_trip!(i as i64));
            assert!(check_round_trip!(i as isize));
            assert!(check_round_trip!(i as u8));
            assert!(check_round_trip!(i as u16));
            assert!(check_round_trip!(i as u32));
            assert!(check_round_trip!(i as u64));
            assert!(check_round_trip!(i as usize));
        }
    }

    #[test]
    fn test_read_errors() {
        // Nothing to read at all.
        let mut c = Cursor::new(vec![]);
        let empty: io::Result<u64> = c.read_vlq();
        assert_eq!(empty.unwrap_err().kind(), io::ErrorKind::UnexpectedEof);

        // Both bytes carry the continuation bit, so a u64 read runs off the end.
        let mut c = Cursor::new(vec![255, 129]);
        let truncated: io::Result<u64> = c.read_vlq();
        assert_eq!(truncated.unwrap_err().kind(), io::ErrorKind::UnexpectedEof);

        // The same bytes overflow a u8 before the end of input is reached.
        c.seek(SeekFrom::Start(0)).unwrap();
        let overflow: io::Result<u8> = c.read_vlq();
        assert_eq!(overflow.unwrap_err().kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_zig_zag() {
        let mut c = Cursor::new(vec![]);
        let cases = [
            (0, 0),
            (-1, 1),
            (1, 2),
            (-2, 3),
            (-127, 253),
            (127, 254),
            (-128i8, 255u8),
        ];
        for (signed, expected) in cases {
            // Write the signed value, then read the same bytes back as unsigned to
            // observe the zig-zag mapped form.
            c.seek(SeekFrom::Start(0)).expect("seek");
            c.write_vlq(signed).expect("write");
            c.seek(SeekFrom::Start(0)).expect("seek");
            let unsigned: u8 = c.read_vlq().unwrap();
            assert_eq!(unsigned, expected);
        }
    }

    quickcheck! {
        fn test_round_trip_u64_quickcheck(x: u64) -> bool {
            check_round_trip!(x)
        }

        fn test_round_trip_i64_quickcheck(x: i64) -> bool {
            check_round_trip!(x)
        }

        fn test_round_trip_u8_quickcheck(x: u8) -> bool {
            check_round_trip!(x)
        }

        fn test_round_trip_i8_quickcheck(x: i8) -> bool {
            check_round_trip!(x)
        }
    }
}