ant_quic/
varint.rs

1use std::{convert::TryInto, fmt};
2
3use bytes::{Buf, BufMut};
4use thiserror::Error;
5
6use crate::coding::{self, Codec, UnexpectedEnd};
7
8#[cfg(feature = "arbitrary")]
9use arbitrary::Arbitrary;
10
/// A non-negative integer strictly below 2^62
///
/// Every value of this type can be encoded as a QUIC variable-length integer.
// The two most significant bits of the wrapped `u64` are expected to stay clear; it would be nice
// to be able to tell Rust they are free for use as enum discriminants.
#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct VarInt(pub(crate) u64);
18
19impl VarInt {
20    /// The largest representable value
21    pub const MAX: Self = Self((1 << 62) - 1);
22    /// The largest encoded value length
23    pub const MAX_SIZE: usize = 8;
24
25    /// Create a VarInt from a value that is guaranteed to be in range
26    ///
27    /// This should only be used when the value is known at compile time or
28    /// has been validated to be less than 2^62.
29    #[inline]
30    pub(crate) fn from_u64_bounded(x: u64) -> Self {
31        debug_assert!(x < 2u64.pow(62), "VarInt value {} exceeds maximum", x);
32        if x < 2u64.pow(62) {
33            Self(x)
34        } else {
35            // In production, clamp to MAX instead of panicking
36            tracing::error!("VarInt overflow: {} exceeds maximum, clamping to MAX", x);
37            Self::MAX
38        }
39    }
40
41    /// Construct a `VarInt` infallibly
42    pub const fn from_u32(x: u32) -> Self {
43        Self(x as u64)
44    }
45
46    /// Succeeds iff `x` < 2^62
47    pub fn from_u64(x: u64) -> Result<Self, VarIntBoundsExceeded> {
48        if x < 2u64.pow(62) {
49            Ok(Self(x))
50        } else {
51            Err(VarIntBoundsExceeded)
52        }
53    }
54
55    /// Create a VarInt without ensuring it's in range
56    ///
57    /// # Safety
58    ///
59    /// `x` must be less than 2^62.
60    pub const unsafe fn from_u64_unchecked(x: u64) -> Self {
61        Self(x)
62    }
63
64    /// Extract the integer value
65    pub const fn into_inner(self) -> u64 {
66        self.0
67    }
68
69    /// Compute the number of bytes needed to encode this value
70    pub(crate) const fn size(self) -> usize {
71        let x = self.0;
72        if x < 2u64.pow(6) {
73            1
74        } else if x < 2u64.pow(14) {
75            2
76        } else if x < 2u64.pow(30) {
77            4
78        } else if x < 2u64.pow(62) {
79            8
80        } else {
81            panic!("malformed VarInt");
82        }
83    }
84}
85
86impl From<VarInt> for u64 {
87    fn from(x: VarInt) -> Self {
88        x.0
89    }
90}
91
92impl From<u8> for VarInt {
93    fn from(x: u8) -> Self {
94        Self(x.into())
95    }
96}
97
98impl From<u16> for VarInt {
99    fn from(x: u16) -> Self {
100        Self(x.into())
101    }
102}
103
104impl From<u32> for VarInt {
105    fn from(x: u32) -> Self {
106        Self(x.into())
107    }
108}
109
110impl std::convert::TryFrom<u64> for VarInt {
111    type Error = VarIntBoundsExceeded;
112    /// Succeeds iff `x` < 2^62
113    fn try_from(x: u64) -> Result<Self, VarIntBoundsExceeded> {
114        Self::from_u64(x)
115    }
116}
117
118impl std::convert::TryFrom<u128> for VarInt {
119    type Error = VarIntBoundsExceeded;
120    /// Succeeds iff `x` < 2^62
121    fn try_from(x: u128) -> Result<Self, VarIntBoundsExceeded> {
122        Self::from_u64(x.try_into().map_err(|_| VarIntBoundsExceeded)?)
123    }
124}
125
126impl std::convert::TryFrom<usize> for VarInt {
127    type Error = VarIntBoundsExceeded;
128    /// Succeeds iff `x` < 2^62
129    fn try_from(x: usize) -> Result<Self, VarIntBoundsExceeded> {
130        Self::try_from(x as u64)
131    }
132}
133
134impl fmt::Debug for VarInt {
135    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
136        self.0.fmt(f)
137    }
138}
139
140impl fmt::Display for VarInt {
141    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
142        self.0.fmt(f)
143    }
144}
145
#[cfg(feature = "arbitrary")]
impl<'arbitrary> Arbitrary<'arbitrary> for VarInt {
    /// Draws a value from the inclusive range `0..=VarInt::MAX`.
    fn arbitrary(u: &mut arbitrary::Unstructured<'arbitrary>) -> arbitrary::Result<Self> {
        let raw = u.int_in_range(0..=Self::MAX.0)?;
        Ok(Self(raw))
    }
}
152
153/// Error returned when constructing a `VarInt` from a value >= 2^62
154#[derive(Debug, Copy, Clone, Eq, PartialEq, Error)]
155#[error("value too large for varint encoding")]
156pub struct VarIntBoundsExceeded;
157
158impl Codec for VarInt {
159    fn decode<B: Buf>(r: &mut B) -> coding::Result<Self> {
160        if !r.has_remaining() {
161            return Err(UnexpectedEnd);
162        }
163        let mut buf = [0; 8];
164        buf[0] = r.get_u8();
165        let tag = buf[0] >> 6;
166        buf[0] &= 0b0011_1111;
167        let x = match tag {
168            0b00 => u64::from(buf[0]),
169            0b01 => {
170                if r.remaining() < 1 {
171                    return Err(UnexpectedEnd);
172                }
173                r.copy_to_slice(&mut buf[1..2]);
174                u64::from(u16::from_be_bytes(buf[..2].try_into().unwrap()))
175            }
176            0b10 => {
177                if r.remaining() < 3 {
178                    return Err(UnexpectedEnd);
179                }
180                r.copy_to_slice(&mut buf[1..4]);
181                u64::from(u32::from_be_bytes(buf[..4].try_into().unwrap()))
182            }
183            0b11 => {
184                if r.remaining() < 7 {
185                    return Err(UnexpectedEnd);
186                }
187                r.copy_to_slice(&mut buf[1..8]);
188                u64::from_be_bytes(buf)
189            }
190            _ => unreachable!(),
191        };
192        Ok(Self(x))
193    }
194
195    fn encode<B: BufMut>(&self, w: &mut B) {
196        let x = self.0;
197        if x < 2u64.pow(6) {
198            w.put_u8(x as u8);
199        } else if x < 2u64.pow(14) {
200            w.put_u16((0b01 << 14) | x as u16);
201        } else if x < 2u64.pow(30) {
202            w.put_u32((0b10 << 30) | x as u32);
203        } else if x < 2u64.pow(62) {
204            w.put_u64((0b11 << 62) | x);
205        } else {
206            unreachable!("malformed VarInt")
207        }
208    }
209}