iroh_quinn_proto/
varint.rs

1use std::{convert::TryInto, fmt};
2
3use bytes::{Buf, BufMut};
4use thiserror::Error;
5
6use crate::coding::{self, Codec, UnexpectedEnd};
7
8#[cfg(feature = "arbitrary")]
9use arbitrary::Arbitrary;
10
/// An unsigned integer strictly below 2^62
///
/// This is the value range that a QUIC variable-length integer can carry, so any `VarInt` can be
/// written to the wire without a further range check.
// The two most significant bits of the wrapped `u64` are always zero. It would be neat if that
// could be expressed to Rust so they became available as enum discriminants.
#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct VarInt(pub(crate) u64);
18
19impl VarInt {
20    /// The largest representable value
21    pub const MAX: Self = Self((1 << 62) - 1);
22    /// The largest encoded value length
23    pub const MAX_SIZE: usize = 8;
24
25    /// Construct a `VarInt` infallibly
26    pub const fn from_u32(x: u32) -> Self {
27        Self(x as u64)
28    }
29
30    /// Succeeds iff `x` < 2^62
31    pub fn from_u64(x: u64) -> Result<Self, VarIntBoundsExceeded> {
32        if x < 2u64.pow(62) {
33            Ok(Self(x))
34        } else {
35            Err(VarIntBoundsExceeded)
36        }
37    }
38
39    /// Create a VarInt without ensuring it's in range
40    ///
41    /// # Safety
42    ///
43    /// `x` must be less than 2^62.
44    pub const unsafe fn from_u64_unchecked(x: u64) -> Self {
45        Self(x)
46    }
47
48    /// Extract the integer value
49    pub const fn into_inner(self) -> u64 {
50        self.0
51    }
52
53    /// Saturating integer addition. Computes self + rhs, saturating at the numeric bounds instead
54    /// of overflowing.
55    pub fn saturating_add(self, rhs: impl Into<Self>) -> Self {
56        let rhs = rhs.into();
57        let inner = self.0.saturating_add(rhs.0).min(Self::MAX.0);
58        Self(inner)
59    }
60
61    /// Compute the number of bytes needed to encode this value
62    pub(crate) const fn size(self) -> usize {
63        let x = self.0;
64        if x < 2u64.pow(6) {
65            1
66        } else if x < 2u64.pow(14) {
67            2
68        } else if x < 2u64.pow(30) {
69            4
70        } else if x < 2u64.pow(62) {
71            8
72        } else {
73            panic!("malformed VarInt");
74        }
75    }
76}
77
78impl From<VarInt> for u64 {
79    fn from(x: VarInt) -> Self {
80        x.0
81    }
82}
83
84impl From<u8> for VarInt {
85    fn from(x: u8) -> Self {
86        Self(x.into())
87    }
88}
89
90impl From<u16> for VarInt {
91    fn from(x: u16) -> Self {
92        Self(x.into())
93    }
94}
95
96impl From<u32> for VarInt {
97    fn from(x: u32) -> Self {
98        Self(x.into())
99    }
100}
101
102impl std::convert::TryFrom<u64> for VarInt {
103    type Error = VarIntBoundsExceeded;
104    /// Succeeds iff `x` < 2^62
105    fn try_from(x: u64) -> Result<Self, VarIntBoundsExceeded> {
106        Self::from_u64(x)
107    }
108}
109
110impl std::convert::TryFrom<u128> for VarInt {
111    type Error = VarIntBoundsExceeded;
112    /// Succeeds iff `x` < 2^62
113    fn try_from(x: u128) -> Result<Self, VarIntBoundsExceeded> {
114        Self::from_u64(x.try_into().map_err(|_| VarIntBoundsExceeded)?)
115    }
116}
117
118impl std::convert::TryFrom<usize> for VarInt {
119    type Error = VarIntBoundsExceeded;
120    /// Succeeds iff `x` < 2^62
121    fn try_from(x: usize) -> Result<Self, VarIntBoundsExceeded> {
122        Self::try_from(x as u64)
123    }
124}
125
126impl fmt::Debug for VarInt {
127    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
128        self.0.fmt(f)
129    }
130}
131
132impl fmt::Display for VarInt {
133    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
134        self.0.fmt(f)
135    }
136}
137
/// Generates only in-range values: the raw integer is drawn from `0..=VarInt::MAX`.
#[cfg(feature = "arbitrary")]
impl<'arbitrary> Arbitrary<'arbitrary> for VarInt {
    fn arbitrary(u: &mut arbitrary::Unstructured<'arbitrary>) -> arbitrary::Result<Self> {
        let upper = Self::MAX.0;
        u.int_in_range(0..=upper).map(Self)
    }
}
144
145/// Error returned when constructing a `VarInt` from a value >= 2^62
146#[derive(Debug, Copy, Clone, Eq, PartialEq, Error)]
147#[error("value too large for varint encoding")]
148pub struct VarIntBoundsExceeded;
149
150impl Codec for VarInt {
151    fn decode<B: Buf>(r: &mut B) -> coding::Result<Self> {
152        if !r.has_remaining() {
153            return Err(UnexpectedEnd);
154        }
155        let mut buf = [0; 8];
156        buf[0] = r.get_u8();
157        let tag = buf[0] >> 6;
158        buf[0] &= 0b0011_1111;
159        let x = match tag {
160            0b00 => u64::from(buf[0]),
161            0b01 => {
162                if r.remaining() < 1 {
163                    return Err(UnexpectedEnd);
164                }
165                r.copy_to_slice(&mut buf[1..2]);
166                u64::from(u16::from_be_bytes(buf[..2].try_into().unwrap()))
167            }
168            0b10 => {
169                if r.remaining() < 3 {
170                    return Err(UnexpectedEnd);
171                }
172                r.copy_to_slice(&mut buf[1..4]);
173                u64::from(u32::from_be_bytes(buf[..4].try_into().unwrap()))
174            }
175            0b11 => {
176                if r.remaining() < 7 {
177                    return Err(UnexpectedEnd);
178                }
179                r.copy_to_slice(&mut buf[1..8]);
180                u64::from_be_bytes(buf)
181            }
182            _ => unreachable!(),
183        };
184        Ok(Self(x))
185    }
186
187    fn encode<B: BufMut>(&self, w: &mut B) {
188        let x = self.0;
189        if x < 2u64.pow(6) {
190            w.put_u8(x as u8);
191        } else if x < 2u64.pow(14) {
192            w.put_u16(0b01 << 14 | x as u16);
193        } else if x < 2u64.pow(30) {
194            w.put_u32(0b10 << 30 | x as u32);
195        } else if x < 2u64.pow(62) {
196            w.put_u64(0b11 << 62 | x);
197        } else {
198            unreachable!("malformed VarInt")
199        }
200    }
201}
202
#[cfg(test)]
mod tests {
    use super::*;

    /// `saturating_add` adds normally inside the range and clamps at `VarInt::MAX`.
    #[test]
    fn test_saturating_add() {
        // A sum that stays below 2^62 is an ordinary addition.
        let start = VarInt::from_u32(u32::MAX);
        let expected = VarInt::from_u64(u64::from(u32::MAX) + 1).unwrap();
        assert_eq!(start.saturating_add(1u8), expected);

        // A sum that would leave the range is clamped to the maximum.
        let clamped = VarInt::MAX.saturating_add(1u8);
        assert_eq!(clamped, VarInt::MAX);
    }
}
217}