// moq_transport/coding/varint.rs

// Based on quinn-proto
// https://github.com/quinn-rs/quinn/blob/main/quinn-proto/src/varint.rs
// Licensed under Apache 2.0 and MIT
4
5use std::convert::{TryFrom, TryInto};
6use std::fmt;
7
8use thiserror::Error;
9
10use super::{Decode, DecodeError, Encode, EncodeError};
11
12#[derive(Debug, Copy, Clone, Eq, PartialEq, Error)]
13#[error("value out of range")]
14pub struct BoundsExceeded;
15
/// An unsigned integer strictly below `2^62`.
///
/// This is exactly the range a QUIC variable-length integer can carry: the two most significant
/// bits of the first encoded byte hold the length tag, leaving 62 bits for the value.
/// It would be neat if we could express to Rust that the top two bits are available for use as
/// enum discriminants.
#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct VarInt(u64);

impl VarInt {
    /// The smallest representable value, `0`.
    pub const ZERO: Self = Self(0);

    /// The largest representable value, `2^62 - 1` (a `u64` with its top two bits cleared).
    pub const MAX: Self = Self(u64::MAX >> 2);

    /// Builds a `VarInt` from a `u32` without any range check, since every `u32` fits.
    ///
    /// Wider inputs have to go through `TryFrom` instead.
    pub const fn from_u32(x: u32) -> Self {
        Self(x as u64)
    }

    /// Returns the wrapped integer.
    pub const fn into_inner(self) -> u64 {
        self.0
    }
}

43impl From<VarInt> for u64 {
44    fn from(x: VarInt) -> Self {
45        x.0
46    }
47}
48
49impl From<VarInt> for usize {
50    fn from(x: VarInt) -> Self {
51        x.0 as usize
52    }
53}
54
55impl From<VarInt> for u128 {
56    fn from(x: VarInt) -> Self {
57        x.0 as u128
58    }
59}
60
61impl From<u8> for VarInt {
62    fn from(x: u8) -> Self {
63        Self(x.into())
64    }
65}
66
67impl From<u16> for VarInt {
68    fn from(x: u16) -> Self {
69        Self(x.into())
70    }
71}
72
73impl From<u32> for VarInt {
74    fn from(x: u32) -> Self {
75        Self(x.into())
76    }
77}
78
79impl TryFrom<u64> for VarInt {
80    type Error = BoundsExceeded;
81
82    /// Succeeds iff `x` < 2^62
83    fn try_from(x: u64) -> Result<Self, BoundsExceeded> {
84        let x = Self(x);
85        if x <= Self::MAX {
86            Ok(x)
87        } else {
88            Err(BoundsExceeded)
89        }
90    }
91}
92
93impl TryFrom<u128> for VarInt {
94    type Error = BoundsExceeded;
95
96    /// Succeeds iff `x` < 2^62
97    fn try_from(x: u128) -> Result<Self, BoundsExceeded> {
98        if x <= Self::MAX.into() {
99            Ok(Self(x as u64))
100        } else {
101            Err(BoundsExceeded)
102        }
103    }
104}
105
106impl TryFrom<usize> for VarInt {
107    type Error = BoundsExceeded;
108
109    /// Succeeds iff `x` < 2^62
110    fn try_from(x: usize) -> Result<Self, BoundsExceeded> {
111        Self::try_from(x as u64)
112    }
113}
114
115impl TryFrom<VarInt> for u32 {
116    type Error = BoundsExceeded;
117
118    /// Succeeds iff `x` < 2^32
119    fn try_from(x: VarInt) -> Result<Self, BoundsExceeded> {
120        if x.0 <= u32::MAX.into() {
121            Ok(x.0 as u32)
122        } else {
123            Err(BoundsExceeded)
124        }
125    }
126}
127
128impl TryFrom<VarInt> for u16 {
129    type Error = BoundsExceeded;
130
131    /// Succeeds iff `x` < 2^16
132    fn try_from(x: VarInt) -> Result<Self, BoundsExceeded> {
133        if x.0 <= u16::MAX.into() {
134            Ok(x.0 as u16)
135        } else {
136            Err(BoundsExceeded)
137        }
138    }
139}
140
141impl TryFrom<VarInt> for u8 {
142    type Error = BoundsExceeded;
143
144    /// Succeeds iff `x` < 2^8
145    fn try_from(x: VarInt) -> Result<Self, BoundsExceeded> {
146        if x.0 <= u8::MAX.into() {
147            Ok(x.0 as u8)
148        } else {
149            Err(BoundsExceeded)
150        }
151    }
152}
153
154impl fmt::Debug for VarInt {
155    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
156        self.0.fmt(f)
157    }
158}
159
160impl fmt::Display for VarInt {
161    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
162        self.0.fmt(f)
163    }
164}
165
166impl Decode for VarInt {
167    /// Decode a varint from the given reader.
168    fn decode<R: bytes::Buf>(r: &mut R) -> Result<Self, DecodeError> {
169        Self::decode_remaining(r, 1)?;
170
171        let b = r.get_u8();
172        let tag = b >> 6;
173
174        let mut buf = [0u8; 8];
175        buf[0] = b & 0b0011_1111;
176
177        let x = match tag {
178            0b00 => u64::from(buf[0]),
179            0b01 => {
180                Self::decode_remaining(r, 1)?;
181                r.copy_to_slice(buf[1..2].as_mut());
182                u64::from(u16::from_be_bytes(buf[..2].try_into().unwrap()))
183            }
184            0b10 => {
185                Self::decode_remaining(r, 3)?;
186                r.copy_to_slice(buf[1..4].as_mut());
187                u64::from(u32::from_be_bytes(buf[..4].try_into().unwrap()))
188            }
189            0b11 => {
190                Self::decode_remaining(r, 7)?;
191                r.copy_to_slice(buf[1..8].as_mut());
192                u64::from_be_bytes(buf)
193            }
194            _ => unreachable!(),
195        };
196
197        Ok(Self(x))
198    }
199}
200
201impl Encode for VarInt {
202    /// Encode a varint to the given writer.
203    fn encode<W: bytes::BufMut>(&self, w: &mut W) -> Result<(), EncodeError> {
204        let x = self.0;
205        if x < 2u64.pow(6) {
206            Self::encode_remaining(w, 1)?;
207            w.put_u8(x as u8)
208        } else if x < 2u64.pow(14) {
209            Self::encode_remaining(w, 2)?;
210            w.put_u16((0b01 << 14) | x as u16)
211        } else if x < 2u64.pow(30) {
212            Self::encode_remaining(w, 4)?;
213            w.put_u32((0b10 << 30) | x as u32)
214        } else if x < 2u64.pow(62) {
215            Self::encode_remaining(w, 8)?;
216            w.put_u64((0b11 << 62) | x)
217        } else {
218            return Err(BoundsExceeded.into());
219        }
220
221        Ok(())
222    }
223}
224
225// It is doubtful the MOQ specs would ever ask us to encode/decode a u64 to the wire directly without
226// VarInt encoding. These encode/decode methods offer some nice syntactic sugar.
227impl Encode for u64 {
228    /// Encode a varint to the given writer.
229    fn encode<W: bytes::BufMut>(&self, w: &mut W) -> Result<(), EncodeError> {
230        VarInt::try_from(*self)?.encode(w)
231    }
232}
233
234impl Decode for u64 {
235    fn decode<R: bytes::Buf>(r: &mut R) -> Result<Self, DecodeError> {
236        VarInt::decode(r).map(|v| v.into_inner())
237    }
238}
239
240// The MOQ specs would never ask us to encode/decode a usize to the wire directly without VarInt
241// encoding, since it's actual size is depended on 32bit vs 64bit compilations.  These encode/decode
242// methods offer some nice syntactic sugar.
243impl Encode for usize {
244    /// Encode a varint to the given writer.
245    fn encode<W: bytes::BufMut>(&self, w: &mut W) -> Result<(), EncodeError> {
246        let var = VarInt::try_from(*self)?;
247        var.encode(w)
248    }
249}
250
251impl Decode for usize {
252    fn decode<R: bytes::Buf>(r: &mut R) -> Result<Self, DecodeError> {
253        let var = VarInt::decode(r)?;
254        // Note: If 32-bit system, then VarInt may not fit into usize
255        #[allow(clippy::unnecessary_fallible_conversions)]
256        usize::try_from(var).map_err(|_| DecodeError::BoundsExceeded(BoundsExceeded))
257    }
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;

    #[test]
    fn encode_decode_usize() {
        let mut buf = BytesMut::new();

        let i: usize = 123;
        i.encode(&mut buf).unwrap();
        assert_eq!(buf.to_vec(), vec![0x40, 0x7b]); // first 2 bits are 01
        let decoded = usize::decode(&mut buf).unwrap();
        assert_eq!(decoded, i);
    }

    #[test]
    fn encode_usize_overflow() {
        let i: u64 = 4611686018427387904;
        // This test is only applicable on 64-bit systems
        if i < usize::MAX as u64 {
            let i = i as usize;
            let mut buf = BytesMut::new();
            let encoded = i.encode(&mut buf);
            assert!(matches!(
                encoded.unwrap_err(),
                EncodeError::BoundsExceeded(_)
            ));
        }
    }

    #[test]
    fn encode_decode_u64() {
        let mut buf = BytesMut::new();

        let i: u64 = 123;
        i.encode(&mut buf).unwrap();
        assert_eq!(buf.to_vec(), vec![0x40, 0x7b]); // first 2 bits are 01
        let decoded = u64::decode(&mut buf).unwrap();
        assert_eq!(decoded, i);
    }

    #[test]
    fn encode_u64_overflow() {
        let mut buf = BytesMut::new();

        let i: u64 = 4611686018427387904;
        let encoded = i.encode(&mut buf);
        assert!(matches!(
            encoded.unwrap_err(),
            EncodeError::BoundsExceeded(_)
        ));
    }

    #[test]
    fn encode_decode_varint() {
        let mut buf = BytesMut::new();

        // 0 -> 1 byte
        let i = 0;
        let vi = VarInt(i);
        vi.encode(&mut buf).unwrap();
        assert_eq!(buf.to_vec(), vec![0b0000_0000]); // first 2 bits are 00
        let decoded = VarInt::decode(&mut buf).unwrap();
        assert_eq!(decoded, vi);
        assert_eq!(u64::from(decoded), i);

        // 63 -> 1 byte
        let i = 63;
        let vi = VarInt(i);
        vi.encode(&mut buf).unwrap();
        assert_eq!(buf.to_vec(), vec![0b0011_1111]); // first 2 bits are 00
        let decoded = VarInt::decode(&mut buf).unwrap();
        assert_eq!(decoded, vi);
        assert_eq!(u64::from(decoded), i);

        // 64 -> 2 bytes
        let i = 64;
        let vi = VarInt(i);
        vi.encode(&mut buf).unwrap();
        assert_eq!(buf.to_vec(), vec![0b0100_0000, 0b0100_0000]); // first 2 bits are 01
        let decoded = VarInt::decode(&mut buf).unwrap();
        assert_eq!(decoded, vi);
        assert_eq!(u64::from(decoded), i);

        // 16383 -> 2 bytes
        let i = 16383;
        let vi = VarInt(i);
        vi.encode(&mut buf).unwrap();
        assert_eq!(buf.to_vec(), vec![0b0111_1111, 0xff]); // first 2 bits are 01
        let decoded = VarInt::decode(&mut buf).unwrap();
        assert_eq!(decoded, vi);
        assert_eq!(u64::from(decoded), i);

        // 16384 -> 4 bytes
        let i = 16384;
        let vi = VarInt(i);
        vi.encode(&mut buf).unwrap();
        assert_eq!(buf.to_vec(), vec![0b1000_0000, 0x00, 0x40, 0x00]); // first 2 bits are 10
        let decoded = VarInt::decode(&mut buf).unwrap();
        assert_eq!(decoded, vi);
        assert_eq!(u64::from(decoded), i);

        // 1073741823 -> 4 bytes
        let i = 1073741823;
        let vi = VarInt(i);
        vi.encode(&mut buf).unwrap();
        assert_eq!(buf.to_vec(), vec![0b1011_1111, 0xff, 0xff, 0xff]); // first 2 bits are 10
        let decoded = VarInt::decode(&mut buf).unwrap();
        assert_eq!(decoded, vi);
        assert_eq!(u64::from(decoded), i);

        // 1073741824 -> 8 bytes
        let i = 1073741824;
        let vi = VarInt(i);
        vi.encode(&mut buf).unwrap();
        assert_eq!(
            buf.to_vec(),
            // first 2 bits are 11
            vec![0b1100_0000, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00]
        );
        let decoded = VarInt::decode(&mut buf).unwrap();
        assert_eq!(decoded, vi);
        assert_eq!(u64::from(decoded), i);

        // 4611686018427387903 -> 8 bytes
        let i = 4611686018427387903;
        let vi = VarInt(i);
        vi.encode(&mut buf).unwrap();
        assert_eq!(
            buf.to_vec(),
            // first 2 bits are 11
            vec![0b1111_1111, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff]
        );
        let decoded = VarInt::decode(&mut buf).unwrap();
        assert_eq!(decoded, vi);
        assert_eq!(u64::from(decoded), i);
    }

    #[test]
    fn overflow() {
        let mut buf = BytesMut::new();

        let i = 4611686018427387904;
        let vi = VarInt(i);
        let encoded = vi.encode(&mut buf);
        assert!(matches!(
            encoded.unwrap_err(),
            EncodeError::BoundsExceeded(_)
        ));
    }

    #[test]
    fn decode_truncated_input_is_error() {
        // Empty input, then each length tag followed by one byte fewer than it promises.
        let cases: [&[u8]; 4] = [
            &[],
            &[0x40],
            &[0x80, 0x00, 0x00],
            &[0xc0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00],
        ];
        for case in cases {
            let mut r = case;
            assert!(VarInt::decode(&mut r).is_err(), "input {case:?} should not decode");
        }
    }

    #[test]
    fn narrowing_conversions() {
        assert_eq!(u8::try_from(VarInt::from(255u8)), Ok(255u8));
        assert_eq!(u8::try_from(VarInt::from(256u16)), Err(BoundsExceeded));
        assert_eq!(u16::try_from(VarInt::from(65_535u16)), Ok(65_535u16));
        assert_eq!(u16::try_from(VarInt::from(65_536u32)), Err(BoundsExceeded));
        assert_eq!(u32::try_from(VarInt::from(u32::MAX)), Ok(u32::MAX));
        assert_eq!(
            u32::try_from(VarInt::try_from(1u64 << 32).unwrap()),
            Err(BoundsExceeded)
        );
    }

    #[test]
    fn widening_conversions_respect_max() {
        assert_eq!(VarInt::try_from((1u64 << 62) - 1), Ok(VarInt::MAX));
        assert_eq!(VarInt::try_from(1u64 << 62), Err(BoundsExceeded));
        assert_eq!(VarInt::try_from(u128::from(u64::MAX)), Err(BoundsExceeded));
        assert_eq!(VarInt::try_from(42usize), Ok(VarInt::from(42u8)));
        assert_eq!(u64::from(VarInt::MAX), (1u64 << 62) - 1);
        assert_eq!(u128::from(VarInt::MAX), (1u128 << 62) - 1);
    }

    #[test]
    fn formatting_matches_inner_value() {
        assert_eq!(VarInt::from(123u8).to_string(), "123");
        assert_eq!(format!("{:?}", VarInt::MAX), "4611686018427387903");
    }
}