h3/proto/
varint.rs

1use std::{convert::TryInto, fmt, ops::Div};
2
3use bytes::{Buf, BufMut};
4
5pub use super::coding::UnexpectedEnd;
6
/// A QUIC variable-length integer: an unsigned value strictly below 2^62.
///
/// The limit comes from the wire format. The two most significant bits of the first encoded
/// byte hold a length tag, which leaves at most 62 bits for the value itself.
// Ideally Rust would let us lend those two spare high bits to enum discriminants.
#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct VarInt(pub(crate) u64);
14
15impl Div<u64> for VarInt {
16    type Output = Self;
17
18    fn div(self, rhs: u64) -> Self::Output {
19        Self(self.0 / rhs)
20    }
21}
22
23impl VarInt {
24    /// The largest representable value
25    pub const MAX: VarInt = VarInt((1 << 62) - 1);
26    /// The largest encoded value length
27    pub const MAX_SIZE: usize = 8;
28
29    /// Construct a `VarInt` infallibly
30    pub const fn from_u32(x: u32) -> Self {
31        VarInt(x as u64)
32    }
33
34    /// Succeeds iff `x` < 2^62
35    pub fn from_u64(x: u64) -> Result<Self, VarIntBoundsExceeded> {
36        if x < 2u64.pow(62) {
37            Ok(VarInt(x))
38        } else {
39            Err(VarIntBoundsExceeded(x))
40        }
41    }
42
43    /// Create a VarInt without ensuring it's in range
44    ///
45    /// # Safety
46    ///
47    /// `x` must be less than 2^62.
48    pub const unsafe fn from_u64_unchecked(x: u64) -> Self {
49        VarInt(x)
50    }
51
52    /// Extract the integer value
53    pub const fn into_inner(self) -> u64 {
54        self.0
55    }
56
57    /// Compute the number of bytes needed to encode this value
58    pub fn size(self) -> usize {
59        let x = self.0;
60        if x < 2u64.pow(6) {
61            1
62        } else if x < 2u64.pow(14) {
63            2
64        } else if x < 2u64.pow(30) {
65            4
66        } else if x < 2u64.pow(62) {
67            8
68        } else {
69            unreachable!("malformed VarInt");
70        }
71    }
72
73    /// Length of an encoded value from its first byte
74    pub fn encoded_size(first: u8) -> usize {
75        2usize.pow((first >> 6) as u32)
76    }
77
78    pub fn decode<B: Buf>(r: &mut B) -> Result<Self, UnexpectedEnd> {
79        if !r.has_remaining() {
80            return Err(UnexpectedEnd(0));
81        }
82        let mut buf = [0; 8];
83        buf[0] = r.get_u8();
84        let tag = buf[0] >> 6;
85        buf[0] &= 0b0011_1111;
86        let x = match tag {
87            0b00 => u64::from(buf[0]),
88            0b01 => {
89                if r.remaining() < 1 {
90                    return Err(UnexpectedEnd(1));
91                }
92                r.copy_to_slice(&mut buf[1..2]);
93                u64::from(u16::from_be_bytes(buf[..2].try_into().unwrap()))
94            }
95            0b10 => {
96                if r.remaining() < 3 {
97                    return Err(UnexpectedEnd(2));
98                }
99                r.copy_to_slice(&mut buf[1..4]);
100                u64::from(u32::from_be_bytes(buf[..4].try_into().unwrap()))
101            }
102            0b11 => {
103                if r.remaining() < 7 {
104                    return Err(UnexpectedEnd(3));
105                }
106                r.copy_to_slice(&mut buf[1..8]);
107                u64::from_be_bytes(buf)
108            }
109            _ => unreachable!(),
110        };
111        Ok(VarInt(x))
112    }
113
114    pub fn encode<B: BufMut>(&self, w: &mut B) {
115        let x = self.0;
116        if x < 2u64.pow(6) {
117            w.put_u8(x as u8);
118        } else if x < 2u64.pow(14) {
119            w.put_u16(0b01 << 14 | x as u16);
120        } else if x < 2u64.pow(30) {
121            w.put_u32(0b10 << 30 | x as u32);
122        } else if x < 2u64.pow(62) {
123            w.put_u64(0b11 << 62 | x);
124        } else {
125            unreachable!("malformed VarInt")
126        }
127    }
128}
129
130impl From<VarInt> for u64 {
131    fn from(x: VarInt) -> u64 {
132        x.0
133    }
134}
135
136impl From<u8> for VarInt {
137    fn from(x: u8) -> Self {
138        VarInt(x.into())
139    }
140}
141
142impl From<u16> for VarInt {
143    fn from(x: u16) -> Self {
144        VarInt(x.into())
145    }
146}
147
148impl From<u32> for VarInt {
149    fn from(x: u32) -> Self {
150        VarInt(x.into())
151    }
152}
153
154impl std::convert::TryFrom<u64> for VarInt {
155    type Error = VarIntBoundsExceeded;
156    /// Succeeds iff `x` < 2^62
157    fn try_from(x: u64) -> Result<Self, VarIntBoundsExceeded> {
158        VarInt::from_u64(x)
159    }
160}
161
162impl std::convert::TryFrom<usize> for VarInt {
163    type Error = VarIntBoundsExceeded;
164    /// Succeeds iff `x` < 2^62
165    fn try_from(x: usize) -> Result<Self, VarIntBoundsExceeded> {
166        VarInt::try_from(x as u64)
167    }
168}
169
170impl fmt::Debug for VarInt {
171    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
172        self.0.fmt(f)
173    }
174}
175
176impl fmt::Display for VarInt {
177    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
178        self.0.fmt(f)
179    }
180}
181
182pub trait BufExt {
183    fn get_var(&mut self) -> Result<u64, UnexpectedEnd>;
184}
185
186impl<T: Buf> BufExt for T {
187    fn get_var(&mut self) -> Result<u64, UnexpectedEnd> {
188        Ok(VarInt::decode(self)?.into_inner())
189    }
190}
191
192pub trait BufMutExt {
193    fn write_var(&mut self, x: u64);
194}
195
196impl<T: BufMut> BufMutExt for T {
197    fn write_var(&mut self, x: u64) {
198        VarInt::from_u64(x).unwrap().encode(self);
199    }
200}
/// Error produced when building a `VarInt` from a value of 2^62 or more.
///
/// The rejected value is kept in the single field.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct VarIntBoundsExceeded(pub(crate) u64);