use core::fmt;
/// A QUIC variable-length integer as defined by RFC 9000 Section 16.
///
/// The wrapped `u64` is intended never to exceed `2^62 - 1` ([`VarInt::MAX`]);
/// the public constructors in this module check that bound.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct VarInt(u64);
/// Error returned when a `u64` cannot be represented as a QUIC varint.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VarIntError {
    /// The input exceeded the varint maximum of `2^62 - 1`.
    OutOfRange {
        /// The rejected input value.
        value: u64,
    },
}

impl fmt::Display for VarIntError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::OutOfRange { value } => write!(
                f,
                "varint value {value} exceeds the maximum allowed by RFC 9000 Section 16 (2^62 - 1)"
            ),
        }
    }
}

impl std::error::Error for VarIntError {}
impl VarInt {
    /// The largest encodable value, `2^62 - 1`.
    pub const MAX: Self = Self((1u64 << 62) - 1);
    /// The varint holding zero.
    pub const ZERO: Self = Self(0);

    /// Builds a varint from `value`, checking the RFC 9000 upper bound.
    ///
    /// # Errors
    ///
    /// Returns [`VarIntError::OutOfRange`] carrying `value` when it exceeds `2^62 - 1`.
    pub const fn new(value: u64) -> Result<Self, VarIntError> {
        if value <= Self::MAX.0 {
            Ok(Self(value))
        } else {
            Err(VarIntError::OutOfRange { value })
        }
    }

    /// Builds a varint from a value known to be in range, usable in `const` items.
    ///
    /// # Panics
    ///
    /// Panics if `value > 2^62 - 1`. In a `const` context this surfaces as a
    /// compile-time error attributed to the caller.
    #[track_caller]
    pub const fn from_static(value: u64) -> Self {
        assert!(value <= Self::MAX.0, "VarInt value must be <= 2^62 - 1");
        Self(value)
    }

    /// Returns the wrapped integer.
    pub const fn get(self) -> u64 {
        self.0
    }

    /// Number of bytes (1, 2, 4 or 8) in the minimal encoding of this value.
    pub const fn encoded_len(self) -> usize {
        match self.0 {
            0..=63 => 1,
            64..=16_383 => 2,
            16_384..=1_073_741_823 => 4,
            _ => 8,
        }
    }

    /// Wraps `value` without a release-mode range check.
    ///
    /// The caller must guarantee `value <= 2^62 - 1`; this is asserted only in
    /// debug builds.
    pub(crate) const fn from_validated_parts_internal(value: u64) -> Self {
        debug_assert!(
            value <= Self::MAX.0,
            "VarInt invariant violated: value exceeds 2^62 - 1"
        );
        Self(value)
    }
}
/// Formats the varint exactly as its underlying `u64`, forwarding the caller's
/// formatter (and therefore its width/fill/alignment flags) unchanged.
impl fmt::Display for VarInt {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let inner: &u64 = &self.0;
        <u64 as fmt::Display>::fmt(inner, f)
    }
}
// Lossless conversions: every `u8`, `u16` and `u32` is far below the 2^62 - 1
// varint ceiling, so no range check is needed.
impl From<u8> for VarInt {
    fn from(value: u8) -> Self {
        Self(value.into())
    }
}

impl From<u16> for VarInt {
    fn from(value: u16) -> Self {
        Self(value.into())
    }
}

impl From<u32> for VarInt {
    fn from(value: u32) -> Self {
        Self(value.into())
    }
}
/// Checked conversion; equivalent to [`VarInt::new`].
impl TryFrom<u64> for VarInt {
    type Error = VarIntError;

    fn try_from(value: u64) -> Result<Self, Self::Error> {
        Self::new(value)
    }
}

/// Checked conversion from a length or index.
impl TryFrom<usize> for VarInt {
    type Error = VarIntError;

    /// Fails with [`VarIntError::OutOfRange`] when `value > 2^62 - 1`.
    ///
    /// On targets where `usize` is wider than 64 bits, values that do not fit in
    /// a `u64` are reported with `value: u64::MAX` (saturated) rather than
    /// their true magnitude.
    fn try_from(value: usize) -> Result<Self, Self::Error> {
        // An `as u64` cast would silently truncate on >64-bit targets and could
        // turn an oversized length into a small, *valid* varint. Saturating
        // guarantees such inputs are always rejected.
        let widened = u64::try_from(value).unwrap_or(u64::MAX);
        Self::new(widened)
    }
}
/// Error returned by [`decode`].
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DecodeError {
    /// The input was empty, or shorter than the length announced by the
    /// two-bit prefix of its first byte.
    BufferTooShort,
}

impl fmt::Display for DecodeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::BufferTooShort => f.write_str("buffer too short to decode varint"),
        }
    }
}

impl std::error::Error for DecodeError {}
/// Error returned by [`encode`].
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EncodeError {
    /// The destination slice was shorter than the value's encoded length.
    BufferTooShort,
}

impl fmt::Display for EncodeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::BufferTooShort => f.write_str("buffer too short to encode varint"),
        }
    }
}

impl std::error::Error for EncodeError {}
/// Appends the minimal encoding of `value` to the end of `buf`.
///
/// Exactly `value.encoded_len()` bytes are appended; existing contents of
/// `buf` are left as they were.
pub fn encode_into_vec(buf: &mut Vec<u8>, value: VarInt) {
    // 8 bytes is the longest possible encoding, so this scratch space always suffices.
    let mut scratch = [0u8; 8];
    let written = encode(&mut scratch, value).expect("buffer is correctly sized");
    buf.extend_from_slice(&scratch[..written]);
}
/// Writes the minimal RFC 9000 encoding of `value` to the front of `buf`.
///
/// Returns the number of bytes written (1, 2, 4 or 8). Bytes of `buf` beyond
/// that count are not touched.
///
/// # Errors
///
/// Returns [`EncodeError::BufferTooShort`] when `buf` is shorter than
/// `value.encoded_len()`; `buf` is left unmodified in that case.
pub fn encode(buf: &mut [u8], value: VarInt) -> Result<usize, EncodeError> {
    let len = value.encoded_len();
    let dst = buf.get_mut(..len).ok_or(EncodeError::BufferTooShort)?;

    // The top two bits of the first byte encode log2(len).
    let prefix: u64 = match len {
        1 => 0b00,
        2 => 0b01,
        4 => 0b10,
        8 => 0b11,
        _ => unreachable!(),
    };
    // Place the prefix in the two most significant bits of a `len`-byte word,
    // then emit the low `len` bytes of that word in big-endian order.
    let tagged = value.get() | (prefix << (8 * len - 2));
    dst.copy_from_slice(&tagged.to_be_bytes()[8 - len..]);
    Ok(len)
}
/// Decodes one varint from the front of `buf`.
///
/// On success returns the value and the number of bytes consumed (1, 2, 4 or
/// 8, as announced by the first byte's two-bit prefix). Non-minimal encodings
/// are accepted and decode to the same value as their minimal form.
///
/// # Errors
///
/// Returns [`DecodeError::BufferTooShort`] when `buf` is empty or holds fewer
/// bytes than the prefix announces.
pub fn decode(buf: &[u8]) -> Result<(VarInt, usize), DecodeError> {
    let first = *buf.first().ok_or(DecodeError::BufferTooShort)?;
    let len = 1usize << (first >> 6);
    let bytes = buf.get(..len).ok_or(DecodeError::BufferTooShort)?;

    // Drop the length prefix from the lead byte, then fold the remaining bytes
    // in big-endian order. At most 6 + 7*8 = 62 bits accumulate, so the result
    // always satisfies the VarInt bound.
    let value = bytes[1..]
        .iter()
        .fold(u64::from(first & 0x3f), |acc, &byte| (acc << 8) | u64::from(byte));

    Ok((VarInt::from_validated_parts_internal(value), len))
}
/// Reports how many bytes the varint starting at `buf[0]` occupies, judging
/// only by the two-bit prefix of the first byte.
///
/// Returns `None` for an empty slice. The slice is not required to actually
/// contain that many bytes.
#[inline]
pub fn peek_len(buf: &[u8]) -> Option<usize> {
    let lead = *buf.first()?;
    Some(1usize << (lead >> 6))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new_rejects_out_of_range() {
        assert_eq!(
            VarInt::new(VarInt::MAX.get() + 1),
            Err(VarIntError::OutOfRange {
                value: VarInt::MAX.get() + 1,
            })
        );
        assert_eq!(
            VarInt::new(u64::MAX),
            Err(VarIntError::OutOfRange { value: u64::MAX })
        );
    }

    #[test]
    fn test_from_static_const_context() {
        const ZERO: VarInt = VarInt::from_static(0);
        const ONE_BYTE_BOUNDARY: VarInt = VarInt::from_static(63);
        const TWO_BYTE_BOUNDARY: VarInt = VarInt::from_static(16_383);
        const FOUR_BYTE_BOUNDARY: VarInt = VarInt::from_static(1_073_741_823);
        const MAX: VarInt = VarInt::from_static((1u64 << 62) - 1);
        assert_eq!(ZERO, VarInt::ZERO);
        assert_eq!(ONE_BYTE_BOUNDARY.encoded_len(), 1);
        assert_eq!(TWO_BYTE_BOUNDARY.encoded_len(), 2);
        assert_eq!(FOUR_BYTE_BOUNDARY.encoded_len(), 4);
        assert_eq!(MAX, VarInt::MAX);
    }

    #[test]
    fn test_error_display_contains_rfc_reference() {
        let err = VarIntError::OutOfRange {
            value: VarInt::MAX.get() + 1,
        };
        let s = format!("{err}");
        assert!(s.contains(&format!("{}", VarInt::MAX.get() + 1)));
        assert!(s.contains("RFC 9000"));
        assert!(s.contains("exceeds"));
    }

    #[test]
    fn test_try_from_u64_max_is_error() {
        assert!(VarInt::try_from(u64::MAX).is_err());
        assert!(VarInt::try_from(VarInt::MAX.get()).is_ok());
    }

    #[cfg(target_pointer_width = "64")]
    #[test]
    fn test_try_from_usize_rejects_out_of_range_on_64bit() {
        let too_large = (VarInt::MAX.get() + 1) as usize;
        assert!(VarInt::try_from(too_large).is_err());
    }

    #[test]
    fn test_from_validated_parts_internal_boundaries() {
        assert_eq!(VarInt::from_validated_parts_internal(0).get(), 0);
        assert_eq!(
            VarInt::from_validated_parts_internal(VarInt::MAX.get()),
            VarInt::MAX,
        );
    }

    #[test]
    fn test_decode_buffer_too_short() {
        assert_eq!(decode(&[]), Err(DecodeError::BufferTooShort));
        assert_eq!(decode(&[0x40]), Err(DecodeError::BufferTooShort));
        // Multi-byte prefixes with truncated bodies are also rejected.
        assert_eq!(decode(&[0x80, 0, 0]), Err(DecodeError::BufferTooShort));
        assert_eq!(decode(&[0xc0, 0, 0, 0, 0, 0, 0]), Err(DecodeError::BufferTooShort));
    }

    /// Every length class boundary must survive encode -> decode unchanged.
    #[test]
    fn test_encode_decode_round_trip_at_boundaries() {
        let cases = [
            (0u64, 1usize),
            (63, 1),
            (64, 2),
            (16_383, 2),
            (16_384, 4),
            (1_073_741_823, 4),
            (1_073_741_824, 8),
            (VarInt::MAX.get(), 8),
        ];
        for (raw, expected_len) in cases {
            let v = VarInt::new(raw).unwrap();
            assert_eq!(v.encoded_len(), expected_len);
            let mut buf = [0u8; 8];
            let written = encode(&mut buf, v).unwrap();
            assert_eq!(written, expected_len);
            assert_eq!(peek_len(&buf[..written]), Some(expected_len));
            assert_eq!(decode(&buf[..written]), Ok((v, expected_len)));
        }
    }

    /// Sample encodings from RFC 9000 Appendix A.1.
    #[test]
    fn test_rfc9000_appendix_a1_examples() {
        let vectors: [(&[u8], u64); 4] = [
            (
                &[0xc2, 0x19, 0x7c, 0x5e, 0xff, 0x14, 0xe8, 0x8c],
                151_288_809_941_952_652,
            ),
            (&[0x9d, 0x7f, 0x3e, 0x7d], 494_878_333),
            (&[0x7b, 0xbd], 15_293),
            (&[0x25], 37),
        ];
        for (bytes, raw) in vectors {
            let v = VarInt::new(raw).unwrap();
            assert_eq!(decode(bytes), Ok((v, bytes.len())));
            let mut out = Vec::new();
            encode_into_vec(&mut out, v);
            assert_eq!(out, bytes);
        }
        // A non-minimal two-byte encoding of 37 is still accepted.
        assert_eq!(decode(&[0x40, 0x25]), Ok((VarInt::from(37u8), 2)));
    }

    #[test]
    fn test_encode_buffer_too_short_leaves_buffer_untouched() {
        let mut buf = [0xaa_u8; 3];
        assert_eq!(
            encode(&mut buf, VarInt::from(16_384u32)),
            Err(EncodeError::BufferTooShort)
        );
        assert_eq!(buf, [0xaa; 3]);
        assert_eq!(encode(&mut [], VarInt::ZERO), Err(EncodeError::BufferTooShort));
    }

    #[test]
    fn test_encode_into_vec_appends() {
        let mut out = vec![0xff];
        encode_into_vec(&mut out, VarInt::from(37u8));
        encode_into_vec(&mut out, VarInt::from(15_293u16));
        assert_eq!(out, [0xff, 0x25, 0x7b, 0xbd]);
    }

    #[test]
    fn test_peek_len() {
        assert_eq!(peek_len(&[]), None);
        assert_eq!(peek_len(&[0x3f]), Some(1));
        assert_eq!(peek_len(&[0x40]), Some(2));
        assert_eq!(peek_len(&[0x80]), Some(4));
        assert_eq!(peek_len(&[0xc0]), Some(8));
    }
}