Skip to main content

ant_quic/
varint.rs

1// Copyright 2024 Saorsa Labs Ltd.
2//
3// This Saorsa Network Software is licensed under the General Public License (GPL), version 3.
4// Please see the file LICENSE-GPL, or visit <http://www.gnu.org/licenses/> for the full text.
5//
6// Full details available at https://saorsalabs.com/licenses
7
8use std::{convert::TryInto, fmt};
9
10use bytes::{Buf, BufMut};
11use thiserror::Error;
12
13use crate::coding::{self, Codec, UnexpectedEnd};
14
15#[cfg(feature = "arbitrary")]
16use arbitrary::Arbitrary;
17
/// A non-negative integer strictly below 2^62
///
/// This is exactly the range that a QUIC variable-length integer can carry,
/// so every value of this type can be encoded as one.
// For every valid value the two most significant bits of the backing `u64` are
// zero. It would be nice to let the compiler reuse those bits as an enum niche,
// but Rust currently has no way to express that.
#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct VarInt(pub(crate) u64);
25
26impl VarInt {
27    /// The largest representable value
28    pub const MAX: Self = Self((1 << 62) - 1);
29    /// The largest encoded value length
30    pub const MAX_SIZE: usize = 8;
31
32    /// Create a VarInt from a value that is guaranteed to be in range
33    ///
34    /// This should only be used when the value is known at compile time or
35    /// has been validated to be less than 2^62.
36    #[inline]
37    pub(crate) fn from_u64_bounded(x: u64) -> Self {
38        debug_assert!(x < 2u64.pow(62), "VarInt value {} exceeds maximum", x);
39        // Safety: caller guarantees the bound.
40        unsafe { Self::from_u64_unchecked(x) }
41    }
42
43    /// Construct a `VarInt` infallibly
44    pub const fn from_u32(x: u32) -> Self {
45        Self(x as u64)
46    }
47
48    /// Succeeds iff `x` < 2^62
49    pub fn from_u64(x: u64) -> Result<Self, VarIntBoundsExceeded> {
50        if x < 2u64.pow(62) {
51            Ok(Self(x))
52        } else {
53            Err(VarIntBoundsExceeded)
54        }
55    }
56
57    /// Create a VarInt without ensuring it's in range
58    ///
59    /// # Safety
60    ///
61    /// `x` must be less than 2^62.
62    pub const unsafe fn from_u64_unchecked(x: u64) -> Self {
63        Self(x)
64    }
65
66    /// Extract the integer value
67    pub const fn into_inner(self) -> u64 {
68        self.0
69    }
70
71    /// Compute the number of bytes needed to encode this value
72    pub(crate) const fn size(self) -> usize {
73        let x = self.0;
74        if x < 2u64.pow(6) {
75            1
76        } else if x < 2u64.pow(14) {
77            2
78        } else if x < 2u64.pow(30) {
79            4
80        } else if x < 2u64.pow(62) {
81            8
82        } else {
83            Self::MAX_SIZE
84        }
85    }
86
87    pub(crate) fn encode_checked<B: BufMut>(x: u64, w: &mut B) -> Result<(), VarIntBoundsExceeded> {
88        if x < 2u64.pow(6) {
89            w.put_u8(x as u8);
90            Ok(())
91        } else if x < 2u64.pow(14) {
92            w.put_u16((0b01 << 14) | x as u16);
93            Ok(())
94        } else if x < 2u64.pow(30) {
95            w.put_u32((0b10 << 30) | x as u32);
96            Ok(())
97        } else if x < 2u64.pow(62) {
98            w.put_u64((0b11 << 62) | x);
99            Ok(())
100        } else {
101            Err(VarIntBoundsExceeded)
102        }
103    }
104}
105
106impl From<VarInt> for u64 {
107    fn from(x: VarInt) -> Self {
108        x.0
109    }
110}
111
112impl From<u8> for VarInt {
113    fn from(x: u8) -> Self {
114        Self(x.into())
115    }
116}
117
118impl From<u16> for VarInt {
119    fn from(x: u16) -> Self {
120        Self(x.into())
121    }
122}
123
124impl From<u32> for VarInt {
125    fn from(x: u32) -> Self {
126        Self(x.into())
127    }
128}
129
130impl std::convert::TryFrom<u64> for VarInt {
131    type Error = VarIntBoundsExceeded;
132    /// Succeeds iff `x` < 2^62
133    fn try_from(x: u64) -> Result<Self, VarIntBoundsExceeded> {
134        Self::from_u64(x)
135    }
136}
137
138impl std::convert::TryFrom<u128> for VarInt {
139    type Error = VarIntBoundsExceeded;
140    /// Succeeds iff `x` < 2^62
141    fn try_from(x: u128) -> Result<Self, VarIntBoundsExceeded> {
142        Self::from_u64(x.try_into().map_err(|_| VarIntBoundsExceeded)?)
143    }
144}
145
146impl std::convert::TryFrom<usize> for VarInt {
147    type Error = VarIntBoundsExceeded;
148    /// Succeeds iff `x` < 2^62
149    fn try_from(x: usize) -> Result<Self, VarIntBoundsExceeded> {
150        Self::try_from(x as u64)
151    }
152}
153
154impl fmt::Debug for VarInt {
155    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
156        self.0.fmt(f)
157    }
158}
159
160impl fmt::Display for VarInt {
161    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
162        self.0.fmt(f)
163    }
164}
165
#[cfg(feature = "arbitrary")]
impl<'a> Arbitrary<'a> for VarInt {
    /// Produce a `VarInt` from fuzzer input, constrained to `0..=VarInt::MAX`
    /// so the result always satisfies the type's range invariant.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        let upper = Self::MAX.0;
        u.int_in_range(0..=upper).map(VarInt)
    }
}
172
173/// Error returned when constructing a `VarInt` from a value >= 2^62
174#[derive(Debug, Copy, Clone, Eq, PartialEq, Error)]
175#[error("value too large for varint encoding")]
176pub struct VarIntBoundsExceeded;
177
178impl Codec for VarInt {
179    fn decode<B: Buf>(r: &mut B) -> coding::Result<Self> {
180        if !r.has_remaining() {
181            return Err(UnexpectedEnd);
182        }
183        let mut buf = [0; 8];
184        buf[0] = r.get_u8();
185        let tag = buf[0] >> 6;
186        buf[0] &= 0b0011_1111;
187        let x = match tag {
188            0b00 => u64::from(buf[0]),
189            0b01 => {
190                if r.remaining() < 1 {
191                    return Err(UnexpectedEnd);
192                }
193                r.copy_to_slice(&mut buf[1..2]);
194                // Safe: buf[..2] is exactly 2 bytes
195                u64::from(u16::from_be_bytes([buf[0], buf[1]]))
196            }
197            0b10 => {
198                if r.remaining() < 3 {
199                    return Err(UnexpectedEnd);
200                }
201                r.copy_to_slice(&mut buf[1..4]);
202                // Safe: buf[..4] is exactly 4 bytes
203                u64::from(u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]))
204            }
205            0b11 => {
206                if r.remaining() < 7 {
207                    return Err(UnexpectedEnd);
208                }
209                r.copy_to_slice(&mut buf[1..8]);
210                u64::from_be_bytes(buf)
211            }
212            _ => unreachable!(),
213        };
214        Ok(Self(x))
215    }
216
217    fn encode<B: BufMut>(&self, w: &mut B) {
218        if let Err(_) = Self::encode_checked(self.0, w) {
219            tracing::error!("VarInt overflow: {} exceeds maximum", self.0);
220            debug_assert!(false, "VarInt overflow: {}", self.0);
221        }
222    }
223}