ant_quic/
varint.rs

1// Copyright 2024 Saorsa Labs Ltd.
2//
3// This Saorsa Network Software is licensed under the General Public License (GPL), version 3.
4// Please see the file LICENSE-GPL, or visit <http://www.gnu.org/licenses/> for the full text.
5//
6// Full details available at https://saorsalabs.com/licenses
7
8use std::{convert::TryInto, fmt};
9
10use bytes::{Buf, BufMut};
11use thiserror::Error;
12
13use crate::coding::{self, Codec, UnexpectedEnd};
14
15#[cfg(feature = "arbitrary")]
16use arbitrary::Arbitrary;
17
/// An unsigned integer strictly less than 2^62
///
/// Any value of this type can be written using the QUIC variable-length
/// integer encoding.
// The two most significant bits of the inner `u64` are always zero for a valid value.
// Ideally Rust could be told about that so those bits could serve as enum discriminants,
// but there is currently no way to express it.
#[derive(Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct VarInt(pub(crate) u64);
25
26impl VarInt {
27    /// The largest representable value
28    pub const MAX: Self = Self((1 << 62) - 1);
29    /// The largest encoded value length
30    pub const MAX_SIZE: usize = 8;
31
32    /// Create a VarInt from a value that is guaranteed to be in range
33    ///
34    /// This should only be used when the value is known at compile time or
35    /// has been validated to be less than 2^62.
36    #[inline]
37    pub(crate) fn from_u64_bounded(x: u64) -> Self {
38        debug_assert!(x < 2u64.pow(62), "VarInt value {} exceeds maximum", x);
39        if x < 2u64.pow(62) {
40            Self(x)
41        } else {
42            // In production, clamp to MAX instead of panicking
43            tracing::error!("VarInt overflow: {} exceeds maximum, clamping to MAX", x);
44            Self::MAX
45        }
46    }
47
48    /// Construct a `VarInt` infallibly
49    pub const fn from_u32(x: u32) -> Self {
50        Self(x as u64)
51    }
52
53    /// Succeeds iff `x` < 2^62
54    pub fn from_u64(x: u64) -> Result<Self, VarIntBoundsExceeded> {
55        if x < 2u64.pow(62) {
56            Ok(Self(x))
57        } else {
58            Err(VarIntBoundsExceeded)
59        }
60    }
61
62    /// Create a VarInt without ensuring it's in range
63    ///
64    /// # Safety
65    ///
66    /// `x` must be less than 2^62.
67    pub const unsafe fn from_u64_unchecked(x: u64) -> Self {
68        Self(x)
69    }
70
71    /// Extract the integer value
72    pub const fn into_inner(self) -> u64 {
73        self.0
74    }
75
76    /// Compute the number of bytes needed to encode this value
77    pub(crate) const fn size(self) -> usize {
78        let x = self.0;
79        if x < 2u64.pow(6) {
80            1
81        } else if x < 2u64.pow(14) {
82            2
83        } else if x < 2u64.pow(30) {
84            4
85        } else if x < 2u64.pow(62) {
86            8
87        } else {
88            panic!("malformed VarInt");
89        }
90    }
91}
92
93impl From<VarInt> for u64 {
94    fn from(x: VarInt) -> Self {
95        x.0
96    }
97}
98
99impl From<u8> for VarInt {
100    fn from(x: u8) -> Self {
101        Self(x.into())
102    }
103}
104
105impl From<u16> for VarInt {
106    fn from(x: u16) -> Self {
107        Self(x.into())
108    }
109}
110
111impl From<u32> for VarInt {
112    fn from(x: u32) -> Self {
113        Self(x.into())
114    }
115}
116
117impl std::convert::TryFrom<u64> for VarInt {
118    type Error = VarIntBoundsExceeded;
119    /// Succeeds iff `x` < 2^62
120    fn try_from(x: u64) -> Result<Self, VarIntBoundsExceeded> {
121        Self::from_u64(x)
122    }
123}
124
125impl std::convert::TryFrom<u128> for VarInt {
126    type Error = VarIntBoundsExceeded;
127    /// Succeeds iff `x` < 2^62
128    fn try_from(x: u128) -> Result<Self, VarIntBoundsExceeded> {
129        Self::from_u64(x.try_into().map_err(|_| VarIntBoundsExceeded)?)
130    }
131}
132
133impl std::convert::TryFrom<usize> for VarInt {
134    type Error = VarIntBoundsExceeded;
135    /// Succeeds iff `x` < 2^62
136    fn try_from(x: usize) -> Result<Self, VarIntBoundsExceeded> {
137        Self::try_from(x as u64)
138    }
139}
140
141impl fmt::Debug for VarInt {
142    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
143        self.0.fmt(f)
144    }
145}
146
147impl fmt::Display for VarInt {
148    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
149        self.0.fmt(f)
150    }
151}
152
#[cfg(feature = "arbitrary")]
impl<'a> Arbitrary<'a> for VarInt {
    /// Draws a value uniformly from the full valid range `0..=VarInt::MAX`.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        let raw = u.int_in_range(0..=Self::MAX.0)?;
        Ok(Self(raw))
    }
}
159
160/// Error returned when constructing a `VarInt` from a value >= 2^62
161#[derive(Debug, Copy, Clone, Eq, PartialEq, Error)]
162#[error("value too large for varint encoding")]
163pub struct VarIntBoundsExceeded;
164
165impl Codec for VarInt {
166    fn decode<B: Buf>(r: &mut B) -> coding::Result<Self> {
167        if !r.has_remaining() {
168            return Err(UnexpectedEnd);
169        }
170        let mut buf = [0; 8];
171        buf[0] = r.get_u8();
172        let tag = buf[0] >> 6;
173        buf[0] &= 0b0011_1111;
174        let x = match tag {
175            0b00 => u64::from(buf[0]),
176            0b01 => {
177                if r.remaining() < 1 {
178                    return Err(UnexpectedEnd);
179                }
180                r.copy_to_slice(&mut buf[1..2]);
181                // Safe: buf[..2] is exactly 2 bytes
182                u64::from(u16::from_be_bytes([buf[0], buf[1]]))
183            }
184            0b10 => {
185                if r.remaining() < 3 {
186                    return Err(UnexpectedEnd);
187                }
188                r.copy_to_slice(&mut buf[1..4]);
189                // Safe: buf[..4] is exactly 4 bytes
190                u64::from(u32::from_be_bytes([buf[0], buf[1], buf[2], buf[3]]))
191            }
192            0b11 => {
193                if r.remaining() < 7 {
194                    return Err(UnexpectedEnd);
195                }
196                r.copy_to_slice(&mut buf[1..8]);
197                u64::from_be_bytes(buf)
198            }
199            _ => unreachable!(),
200        };
201        Ok(Self(x))
202    }
203
204    fn encode<B: BufMut>(&self, w: &mut B) {
205        let x = self.0;
206        if x < 2u64.pow(6) {
207            w.put_u8(x as u8);
208        } else if x < 2u64.pow(14) {
209            w.put_u16((0b01 << 14) | x as u16);
210        } else if x < 2u64.pow(30) {
211            w.put_u32((0b10 << 30) | x as u32);
212        } else if x < 2u64.pow(62) {
213            w.put_u64((0b11 << 62) | x);
214        } else {
215            unreachable!("malformed VarInt")
216        }
217    }
218}