use std::convert::{TryFrom, TryInto};
use std::fmt;
use thiserror::Error;
use super::{Decode, DecodeError, Encode, EncodeError};
/// Error returned by the fallible integer conversions in this file when a value
/// does not fit in the destination type.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Error)]
#[error("value out of range")]
pub struct BoundsExceeded;
/// An unsigned integer backed by a `u64` that is written to the wire in a
/// variable-length form.
///
/// The checked constructors and `TryFrom` conversions only accept values up to
/// `VarInt::MAX` (2^62 - 1).
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct VarInt(u64);
impl VarInt {
pub const MAX: Self = Self((1 << 62) - 1);
pub const ZERO: Self = Self(0);
pub const fn from_u32(x: u32) -> Self {
Self(x as u64)
}
pub const fn from_u64(x: u64) -> Option<Self> {
if x <= Self::MAX.0 { Some(Self(x)) } else { None }
}
pub const fn from_u128(x: u128) -> Option<Self> {
if x <= Self::MAX.0 as u128 {
Some(Self(x as u64))
} else {
None
}
}
pub const fn into_inner(self) -> u64 {
self.0
}
}
impl From<VarInt> for u64 {
fn from(x: VarInt) -> Self {
x.0
}
}
impl From<VarInt> for usize {
fn from(x: VarInt) -> Self {
x.0 as usize
}
}
impl From<VarInt> for u128 {
fn from(x: VarInt) -> Self {
x.0 as u128
}
}
impl From<u8> for VarInt {
    /// Wraps a `u8`; every `u8` is in range, so this cannot fail.
    fn from(x: u8) -> Self {
        Self(u64::from(x))
    }
}
impl From<u16> for VarInt {
    /// Wraps a `u16`; every `u16` is in range, so this cannot fail.
    fn from(x: u16) -> Self {
        Self(u64::from(x))
    }
}
impl From<u32> for VarInt {
    /// Wraps a `u32`; every `u32` is in range, so this cannot fail.
    fn from(x: u32) -> Self {
        Self(u64::from(x))
    }
}
impl TryFrom<u64> for VarInt {
    type Error = BoundsExceeded;

    /// Wraps `x`, failing with [`BoundsExceeded`] when it is greater than
    /// `VarInt::MAX`.
    fn try_from(x: u64) -> Result<Self, BoundsExceeded> {
        if x > Self::MAX.0 {
            return Err(BoundsExceeded);
        }
        Ok(Self(x))
    }
}
impl TryFrom<u128> for VarInt {
    type Error = BoundsExceeded;

    /// Wraps `x`, failing with [`BoundsExceeded`] when it is greater than
    /// `VarInt::MAX`.
    fn try_from(x: u128) -> Result<Self, BoundsExceeded> {
        // Anything that does not even fit in a `u64` is necessarily above the limit.
        match u64::try_from(x) {
            Ok(narrow) if narrow <= Self::MAX.0 => Ok(Self(narrow)),
            _ => Err(BoundsExceeded),
        }
    }
}
impl TryFrom<usize> for VarInt {
    type Error = BoundsExceeded;

    /// Widens `x` to `u64` and applies the same range check as the `u64` conversion.
    fn try_from(x: usize) -> Result<Self, BoundsExceeded> {
        let widened = x as u64;
        <Self as TryFrom<u64>>::try_from(widened)
    }
}
impl TryFrom<VarInt> for u32 {
    type Error = BoundsExceeded;

    /// Narrows to `u32`, failing with [`BoundsExceeded`] when the value is larger
    /// than `u32::MAX`.
    fn try_from(x: VarInt) -> Result<Self, BoundsExceeded> {
        u32::try_from(x.0).map_err(|_| BoundsExceeded)
    }
}
impl TryFrom<VarInt> for u16 {
    type Error = BoundsExceeded;

    /// Narrows to `u16`, failing with [`BoundsExceeded`] when the value is larger
    /// than `u16::MAX`.
    fn try_from(x: VarInt) -> Result<Self, BoundsExceeded> {
        u16::try_from(x.0).map_err(|_| BoundsExceeded)
    }
}
impl TryFrom<VarInt> for u8 {
    type Error = BoundsExceeded;

    /// Narrows to `u8`, failing with [`BoundsExceeded`] when the value is larger
    /// than `u8::MAX`.
    fn try_from(x: VarInt) -> Result<Self, BoundsExceeded> {
        u8::try_from(x.0).map_err(|_| BoundsExceeded)
    }
}
impl fmt::Display for VarInt {
    /// Formats exactly like the wrapped `u64`, honouring the formatter's flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, f)
    }
}
impl VarInt {
fn decode_quic<R: bytes::Buf>(r: &mut R) -> Result<Self, DecodeError> {
if !r.has_remaining() {
return Err(DecodeError::Short);
}
let b = r.get_u8();
let tag = b >> 6;
let mut buf = [0u8; 8];
buf[0] = b & 0b0011_1111;
let x = match tag {
0b00 => u64::from(buf[0]),
0b01 => {
if !r.has_remaining() {
return Err(DecodeError::Short);
}
r.copy_to_slice(buf[1..2].as_mut());
u64::from(u16::from_be_bytes(buf[..2].try_into().unwrap()))
}
0b10 => {
if r.remaining() < 3 {
return Err(DecodeError::Short);
}
r.copy_to_slice(buf[1..4].as_mut());
u64::from(u32::from_be_bytes(buf[..4].try_into().unwrap()))
}
0b11 => {
if r.remaining() < 7 {
return Err(DecodeError::Short);
}
r.copy_to_slice(buf[1..8].as_mut());
u64::from_be_bytes(buf)
}
_ => unreachable!(),
};
Ok(Self(x))
}
fn encode_quic<W: bytes::BufMut>(&self, w: &mut W) -> Result<(), EncodeError> {
let remaining = w.remaining_mut();
if self.0 < (1u64 << 6) {
if remaining < 1 {
return Err(EncodeError::Short);
}
w.put_u8(self.0 as u8);
} else if self.0 < (1u64 << 14) {
if remaining < 2 {
return Err(EncodeError::Short);
}
w.put_u16((0b01 << 14) | self.0 as u16);
} else if self.0 < (1u64 << 30) {
if remaining < 4 {
return Err(EncodeError::Short);
}
w.put_u32((0b10 << 30) | self.0 as u32);
} else if self.0 < (1u64 << 62) {
if remaining < 8 {
return Err(EncodeError::Short);
}
w.put_u64((0b11 << 62) | self.0);
} else {
return Err(BoundsExceeded.into());
}
Ok(())
}
fn decode_leading_ones<R: bytes::Buf>(r: &mut R) -> Result<Self, DecodeError> {
if !r.has_remaining() {
return Err(DecodeError::Short);
}
let b = r.get_u8();
let ones = b.leading_ones() as usize;
match ones {
0 => {
Ok(Self(u64::from(b)))
}
1 => {
if !r.has_remaining() {
return Err(DecodeError::Short);
}
let hi = u64::from(b & 0x3F);
let lo = u64::from(r.get_u8());
Ok(Self((hi << 8) | lo))
}
2 => {
if r.remaining() < 2 {
return Err(DecodeError::Short);
}
let hi = u64::from(b & 0x1F);
let mut buf = [0u8; 2];
r.copy_to_slice(&mut buf);
Ok(Self((hi << 16) | u64::from(u16::from_be_bytes(buf))))
}
3 => {
if r.remaining() < 3 {
return Err(DecodeError::Short);
}
let hi = u64::from(b & 0x0F);
let mut buf = [0u8; 3];
r.copy_to_slice(&mut buf);
Ok(Self(
(hi << 24) | u64::from(buf[0]) << 16 | u64::from(buf[1]) << 8 | u64::from(buf[2]),
))
}
4 => {
if r.remaining() < 4 {
return Err(DecodeError::Short);
}
let hi = u64::from(b & 0x07);
let mut buf = [0u8; 4];
r.copy_to_slice(&mut buf);
Ok(Self((hi << 32) | u64::from(u32::from_be_bytes(buf))))
}
5 => {
if r.remaining() < 5 {
return Err(DecodeError::Short);
}
let hi = u64::from(b & 0x03);
let mut buf = [0u8; 5];
r.copy_to_slice(&mut buf);
let lo = u64::from(buf[0]) << 32
| u64::from(buf[1]) << 24
| u64::from(buf[2]) << 16
| u64::from(buf[3]) << 8
| u64::from(buf[4]);
Ok(Self((hi << 40) | lo))
}
6 => {
Err(DecodeError::InvalidValue)?
}
7 => {
if r.remaining() < 7 {
return Err(DecodeError::Short);
}
let mut buf = [0u8; 8];
buf[0] = 0;
r.copy_to_slice(&mut buf[1..]);
Ok(Self(u64::from_be_bytes(buf)))
}
8 => {
if r.remaining() < 8 {
return Err(DecodeError::Short);
}
let mut buf = [0u8; 8];
r.copy_to_slice(&mut buf);
Ok(Self(u64::from_be_bytes(buf)))
}
_ => unreachable!(),
}
}
fn encode_leading_ones<W: bytes::BufMut>(&self, w: &mut W) -> Result<(), EncodeError> {
let x = self.0;
let remaining = w.remaining_mut();
if x < (1 << 7) {
if remaining < 1 {
return Err(EncodeError::Short);
}
w.put_u8(x as u8);
} else if x < (1 << 14) {
if remaining < 2 {
return Err(EncodeError::Short);
}
w.put_u8(0x80 | (x >> 8) as u8);
w.put_u8(x as u8);
} else if x < (1 << 21) {
if remaining < 3 {
return Err(EncodeError::Short);
}
w.put_u8(0xC0 | (x >> 16) as u8);
w.put_u16(x as u16);
} else if x < (1 << 28) {
if remaining < 4 {
return Err(EncodeError::Short);
}
w.put_u8(0xE0 | (x >> 24) as u8);
w.put_u8((x >> 16) as u8);
w.put_u16(x as u16);
} else if x < (1 << 35) {
if remaining < 5 {
return Err(EncodeError::Short);
}
w.put_u8(0xF0 | (x >> 32) as u8);
w.put_u32(x as u32);
} else if x < (1 << 42) {
if remaining < 6 {
return Err(EncodeError::Short);
}
w.put_u8(0xF8 | (x >> 40) as u8);
w.put_u8((x >> 32) as u8);
w.put_u32(x as u32);
} else if x < (1 << 56) {
if remaining < 8 {
return Err(EncodeError::Short);
}
w.put_u8(0xFE);
w.put_u8((x >> 48) as u8);
w.put_u16((x >> 32) as u16);
w.put_u32(x as u32);
} else {
if remaining < 9 {
return Err(EncodeError::Short);
}
w.put_u8(0xFF);
w.put_u64(x);
}
Ok(())
}
}
use crate::{Version, ietf, lite};
impl Encode<lite::Version> for VarInt {
fn encode<W: bytes::BufMut>(&self, w: &mut W, _: lite::Version) -> Result<(), EncodeError> {
self.encode_quic(w)
}
}
impl Decode<lite::Version> for VarInt {
fn decode<R: bytes::Buf>(r: &mut R, _: lite::Version) -> Result<Self, DecodeError> {
Self::decode_quic(r)
}
}
impl Encode<ietf::Version> for VarInt {
    /// Writes the value in the layout used by `version`: drafts 14 to 16 use the
    /// QUIC layout, draft 17 uses the leading-ones layout.
    fn encode<W: bytes::BufMut>(&self, w: &mut W, version: ietf::Version) -> Result<(), EncodeError> {
        // Kept exhaustive (no wildcard) so a new draft forces a decision here.
        let leading_ones = match version {
            ietf::Version::Draft17 => true,
            ietf::Version::Draft14 | ietf::Version::Draft15 | ietf::Version::Draft16 => false,
        };

        if leading_ones {
            self.encode_leading_ones(w)
        } else {
            self.encode_quic(w)
        }
    }
}
impl Decode<ietf::Version> for VarInt {
    /// Reads a value in the layout used by `version`: drafts 14 to 16 use the
    /// QUIC layout, draft 17 uses the leading-ones layout.
    fn decode<R: bytes::Buf>(r: &mut R, version: ietf::Version) -> Result<Self, DecodeError> {
        // Kept exhaustive (no wildcard) so a new draft forces a decision here.
        let leading_ones = match version {
            ietf::Version::Draft17 => true,
            ietf::Version::Draft14 | ietf::Version::Draft15 | ietf::Version::Draft16 => false,
        };

        if leading_ones {
            Self::decode_leading_ones(r)
        } else {
            Self::decode_quic(r)
        }
    }
}
impl Encode<Version> for VarInt {
fn encode<W: bytes::BufMut>(&self, w: &mut W, version: Version) -> Result<(), EncodeError> {
match version {
Version::Lite(v) => self.encode(w, v),
Version::Ietf(v) => self.encode(w, v),
}
}
}
impl Decode<Version> for VarInt {
    /// Forwards to the decoder for whichever protocol family `version` wraps.
    fn decode<R: bytes::Buf>(r: &mut R, version: Version) -> Result<Self, DecodeError> {
        match version {
            Version::Ietf(inner) => <Self as Decode<_>>::decode(r, inner),
            Version::Lite(inner) => <Self as Decode<_>>::decode(r, inner),
        }
    }
}
impl<V: Copy> Encode<V> for u64
where
    VarInt: Encode<V>,
{
    /// Encodes the integer as a [`VarInt`], failing with the error converted
    /// from `BoundsExceeded` when it is out of the `VarInt` range.
    fn encode<W: bytes::BufMut>(&self, w: &mut W, version: V) -> Result<(), EncodeError> {
        let value = VarInt::try_from(*self)?;
        value.encode(w, version)
    }
}
impl<V: Copy> Decode<V> for u64
where
    VarInt: Decode<V>,
{
    /// Decodes a [`VarInt`] and returns its inner `u64`.
    fn decode<R: bytes::Buf>(r: &mut R, version: V) -> Result<Self, DecodeError> {
        let value = VarInt::decode(r, version)?;
        Ok(value.into_inner())
    }
}
impl<V: Copy> Encode<V> for usize
where
    VarInt: Encode<V>,
{
    /// Encodes the integer as a [`VarInt`], failing with the error converted
    /// from `BoundsExceeded` when it is out of the `VarInt` range.
    fn encode<W: bytes::BufMut>(&self, w: &mut W, version: V) -> Result<(), EncodeError> {
        let value = VarInt::try_from(*self)?;
        value.encode(w, version)
    }
}
impl<V: Copy> Decode<V> for usize
where
VarInt: Decode<V>,
{
fn decode<R: bytes::Buf>(r: &mut R, version: V) -> Result<Self, DecodeError> {
VarInt::decode(r, version).map(|v| v.into_inner() as usize)
}
}
impl<V: Copy> Encode<V> for u32
where
    VarInt: Encode<V>,
{
    /// Encodes the integer as a [`VarInt`]; every `u32` is in range, so the
    /// conversion itself cannot fail.
    fn encode<W: bytes::BufMut>(&self, w: &mut W, version: V) -> Result<(), EncodeError> {
        let value = VarInt::from_u32(*self);
        value.encode(w, version)
    }
}
impl<V: Copy> Decode<V> for u32
where
VarInt: Decode<V>,
{
fn decode<R: bytes::Buf>(r: &mut R, version: V) -> Result<Self, DecodeError> {
let v = VarInt::decode(r, version)?;
let v = v.try_into().map_err(|_| DecodeError::BoundsExceeded)?;
Ok(v)
}
}
#[cfg(test)]
mod tests {
    use super::{DecodeError, VarInt};
    use bytes::Bytes;

    /// Known byte sequences decode to the expected value and consume all input;
    /// minimal encodings that fit the checked `VarInt` range also re-encode to
    /// the same bytes.
    #[test]
    fn leading_ones_spec_examples() {
        let cases: &[(&[u8], u64)] = &[
            (&[0x25], 37),
            (&[0x80, 0x25], 37),
            (&[0xbb, 0xbd], 15_293),
            (&[0xfa, 0xa1, 0xa0, 0xe4, 0x03, 0xd8], 2_893_212_287_960),
            (
                &[0xfe, 0xfa, 0x31, 0x8f, 0xa8, 0xe3, 0xca, 0x11],
                70_423_237_261_249_041,
            ),
            (
                &[0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff],
                18_446_744_073_709_551_615,
            ),
        ];
        for (bytes, expected) in cases {
            let mut buf = Bytes::from(bytes.to_vec());
            let decoded = VarInt::decode_leading_ones(&mut buf).expect("decode should succeed");
            assert_eq!(
                decoded.into_inner(),
                *expected,
                "decode mismatch for bytes {bytes:02x?}"
            );
            assert_eq!(buf.len(), 0, "all bytes should be consumed for {bytes:02x?}");
            // Skip re-encoding for the non-minimal two-byte form of 37 and for
            // values that `from_u64` rejects.
            if let Some(varint) = VarInt::from_u64(*expected)
                && (bytes.len() == 1 || *expected != 37)
            {
                let mut encoded = Vec::new();
                varint.encode_leading_ones(&mut encoded).expect("encode should succeed");
                assert_eq!(&encoded, bytes, "encode mismatch for value {expected}");
            }
        }
    }

    /// A first byte with exactly six leading ones is rejected.
    #[test]
    fn leading_ones_invalid_0xfc() {
        let mut buf = Bytes::from_static(&[0xFC]);
        assert!(
            matches!(VarInt::decode_leading_ones(&mut buf), Err(DecodeError::InvalidValue)),
            "0xFC should be rejected as invalid"
        );
    }

    /// 0xFD also has exactly six leading ones and must be rejected as well.
    #[test]
    fn leading_ones_invalid_0xfd() {
        let mut buf = Bytes::from_static(&[0xFD]);
        assert!(
            matches!(VarInt::decode_leading_ones(&mut buf), Err(DecodeError::InvalidValue)),
            "0xFD should be rejected as invalid"
        );
    }

    /// Every multi-byte prefix reports `Short` when its trailing bytes are missing.
    #[test]
    fn leading_ones_truncated_input() {
        for prefix in [0x80u8, 0xC0, 0xE0, 0xF0, 0xF8, 0xFE, 0xFF] {
            let mut buf = Bytes::from(vec![prefix]);
            assert!(
                matches!(VarInt::decode_leading_ones(&mut buf), Err(DecodeError::Short)),
                "prefix {prefix:#04x} without trailing bytes should be Short"
            );
        }
        let mut empty = Bytes::new();
        assert!(
            matches!(VarInt::decode_leading_ones(&mut empty), Err(DecodeError::Short)),
            "empty input should be Short"
        );
    }

    /// Values on either side of every length threshold encode to the expected
    /// number of bytes and decode back unchanged.
    #[test]
    fn leading_ones_boundaries_round_trip() {
        let cases = [
            (0u64, 1usize),
            ((1u64 << 7) - 1, 1usize),
            (1u64 << 7, 2usize),
            ((1u64 << 14) - 1, 2usize),
            (1u64 << 14, 3usize),
            ((1u64 << 21) - 1, 3usize),
            (1u64 << 21, 4usize),
            ((1u64 << 28) - 1, 4usize),
            (1u64 << 28, 5usize),
            ((1u64 << 35) - 1, 5usize),
            (1u64 << 35, 6usize),
            ((1u64 << 42) - 1, 6usize),
            // There is no seven-byte form: the next size up is eight bytes.
            (1u64 << 42, 8usize),
            ((1u64 << 56) - 1, 8usize),
            (1u64 << 56, 9usize),
            ((1u64 << 62) - 1, 9usize),
        ];
        for (value, expected_len) in cases {
            let varint = VarInt::from_u64(value).expect("value should be representable as VarInt");
            let mut encoded = Vec::new();
            varint
                .encode_leading_ones(&mut encoded)
                .expect("leading-ones encode should succeed");
            assert_eq!(
                encoded.len(),
                expected_len,
                "unexpected encoded length for value {value}"
            );
            let mut bytes = Bytes::from(encoded);
            let decoded = VarInt::decode_leading_ones(&mut bytes).expect("leading-ones decode should succeed");
            assert_eq!(decoded.into_inner(), value, "round-trip mismatch for value {value}");
            assert_eq!(bytes.len(), 0, "all bytes should be consumed for value {value}");
        }
    }

    /// The QUIC layout uses 1, 2, 4 or 8 bytes; values on either side of each
    /// threshold encode to the expected length and decode back unchanged.
    #[test]
    fn quic_boundaries_round_trip() {
        let cases = [
            (0u64, 1usize),
            ((1u64 << 6) - 1, 1usize),
            (1u64 << 6, 2usize),
            ((1u64 << 14) - 1, 2usize),
            (1u64 << 14, 4usize),
            ((1u64 << 30) - 1, 4usize),
            (1u64 << 30, 8usize),
            ((1u64 << 62) - 1, 8usize),
        ];
        for (value, expected_len) in cases {
            let varint = VarInt::from_u64(value).expect("value should be representable as VarInt");
            let mut encoded = Vec::new();
            varint.encode_quic(&mut encoded).expect("quic encode should succeed");
            assert_eq!(
                encoded.len(),
                expected_len,
                "unexpected encoded length for value {value}"
            );
            let mut bytes = Bytes::from(encoded);
            let decoded = VarInt::decode_quic(&mut bytes).expect("quic decode should succeed");
            assert_eq!(decoded.into_inner(), value, "round-trip mismatch for value {value}");
            assert_eq!(bytes.len(), 0, "all bytes should be consumed for value {value}");
        }
    }

    /// QUIC decoding reports `Short` for empty input and for a prefix whose
    /// trailing bytes are missing.
    #[test]
    fn quic_truncated_input() {
        let mut empty = Bytes::new();
        assert!(
            matches!(VarInt::decode_quic(&mut empty), Err(DecodeError::Short)),
            "empty input should be Short"
        );
        for prefix in [0x40u8, 0x80, 0xC0] {
            let mut buf = Bytes::from(vec![prefix]);
            assert!(
                matches!(VarInt::decode_quic(&mut buf), Err(DecodeError::Short)),
                "prefix {prefix:#04x} without trailing bytes should be Short"
            );
        }
    }
}