use crate::bytes::{BufMut, BytesMut};
use crate::net::atp::protocol::outcome::{AtpOutcome, ProtocolError};
use crate::types::outcome::Outcome;
use std::io;
/// Largest value a [`VarInt`] may hold: 2^62 - 1.
///
/// The two most significant bits of an encoded varint carry the length tag,
/// which leaves 62 bits of payload in the widest (8-byte) form.
pub const VARINT_MAX: u64 = u64::MAX >> 2;
/// A variable-length integer, stored as its plain `u64` value.
///
/// The wrapped field is public, so a `VarInt` can be built directly without
/// any range validation; values built that way may exceed the 62-bit limit
/// and will then be rejected when encoded.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct VarInt(pub u64);
impl VarInt {
    /// Builds a `VarInt` from `value`, rejecting anything above `VARINT_MAX`.
    ///
    /// # Errors
    ///
    /// Returns a protocol error carrying `ProtocolError::InvalidVarInt` when
    /// `value > VARINT_MAX`.
    pub fn new(value: u64) -> AtpOutcome<Self> {
        if value > VARINT_MAX {
            return AtpOutcome::protocol_error(ProtocolError::InvalidVarInt);
        }
        Outcome::Ok(VarInt(value))
    }

    /// Returns the wrapped integer.
    #[inline]
    pub fn value(self) -> u64 {
        self.0
    }

    /// Builds a `VarInt` from a value the caller expects to be in range.
    ///
    /// Despite the `_unchecked` suffix the range IS checked: an out-of-range
    /// value aborts with a panic instead of being returned as an error.
    ///
    /// # Panics
    ///
    /// Panics if `value` is greater than `VARINT_MAX`.
    #[inline]
    pub fn from_u64_unchecked(value: u64) -> Self {
        assert!(
            value <= VARINT_MAX,
            "varint value exceeds maximum allowed value"
        );
        VarInt(value)
    }

    /// Returns how many bytes [`VarInt::encode`] writes for this value.
    ///
    /// The result is 1, 2, 4 or 8. A value above `VARINT_MAX` (only possible
    /// when the public field was set directly) also reports 8, even though
    /// encoding it fails.
    #[inline]
    pub fn encoded_len(self) -> usize {
        match self.0 {
            0..=0x3F => 1,
            0x40..=0x3FFF => 2,
            0x4000..=0x3FFF_FFFF => 4,
            _ => 8,
        }
    }

    /// Appends the big-endian encoding of this value to `buf`.
    ///
    /// The two most significant bits of the first byte carry the length tag
    /// (`00` = 1 byte, `01` = 2 bytes, `10` = 4 bytes, `11` = 8 bytes); the
    /// remaining bits hold the value.
    ///
    /// # Errors
    ///
    /// Returns a protocol error carrying `ProtocolError::InvalidVarInt` when
    /// the value is above `VARINT_MAX`. Nothing is written in that case.
    pub fn encode(self, buf: &mut BytesMut) -> AtpOutcome<()> {
        // Each arm's range guarantees the narrowing cast is lossless.
        match self.0 {
            value @ 0..=0x3F => buf.put_u8(value as u8),
            value @ 0x40..=0x3FFF => buf.put_u16(0x4000 | value as u16),
            value @ 0x4000..=0x3FFF_FFFF => buf.put_u32(0x8000_0000 | value as u32),
            value @ 0x4000_0000..=VARINT_MAX => buf.put_u64(0xC000_0000_0000_0000 | value),
            _ => return AtpOutcome::protocol_error(ProtocolError::InvalidVarInt),
        }
        Outcome::Ok(())
    }

    /// Decodes one varint from the front of `buf`.
    ///
    /// Returns `Ok(None)` and leaves `buf` untouched when it is empty or holds
    /// fewer bytes than the length tag announces. Otherwise the encoded bytes
    /// are removed from `buf` and the value is returned.
    ///
    /// Non-minimal encodings (a small value in a wider form) are accepted.
    /// This implementation has no failing path: stripping the two tag bits
    /// always leaves a value of at most `VARINT_MAX`.
    pub fn decode(buf: &mut BytesMut) -> AtpOutcome<Option<Self>> {
        let length = match Self::peek_len(buf) {
            Some(length) => length,
            None => return Outcome::Ok(None),
        };
        if buf.len() < length {
            return Outcome::Ok(None);
        }

        let encoded = buf.split_to(length);
        // The first byte contributes its low six bits (the tag is dropped);
        // every following byte is appended in big-endian order.
        let value = (1..length).fold(u64::from(encoded[0] & 0x3F), |acc, i| {
            (acc << 8) | u64::from(encoded[i])
        });
        Outcome::Ok(Some(VarInt(value)))
    }

    /// Reports the total encoded length announced by the first byte of `buf`
    /// without consuming anything.
    ///
    /// Returns `None` when `buf` is empty. The buffer is not required to hold
    /// the full encoding.
    pub fn peek_len(buf: &BytesMut) -> Option<usize> {
        if buf.is_empty() {
            return None;
        }
        // The 2-bit tag is the log2 of the length: 0 -> 1, 1 -> 2, 2 -> 4, 3 -> 8.
        Some(1usize << (buf[0] >> 6))
    }
}
impl TryFrom<u64> for VarInt {
    type Error = VarIntError;

    /// Converts a raw `u64` through [`VarInt::new`].
    ///
    /// Any non-`Ok` outcome from the constructor is reported as
    /// `VarIntError::ValueTooLarge` carrying the rejected input.
    fn try_from(value: u64) -> Result<Self, Self::Error> {
        if let Outcome::Ok(varint) = VarInt::new(value) {
            Ok(varint)
        } else {
            Err(VarIntError::ValueTooLarge(value))
        }
    }
}
impl From<VarInt> for u64 {
fn from(varint: VarInt) -> Self {
varint.0
}
}
/// Errors describing why a varint could not be produced or read.
///
/// The `#[error(...)]` attributes supply the `Display` text for each variant.
#[derive(Debug, Clone, PartialEq, Eq, thiserror::Error)]
pub enum VarIntError {
/// The supplied value (carried in the variant) is greater than `VARINT_MAX`.
/// This is the error returned by `VarInt`'s `TryFrom<u64>` conversion.
#[error("varint value too large: {value} > {max}", value = .0, max = VARINT_MAX)]
ValueTooLarge(u64),
/// Input ended before a complete varint was available.
/// NOTE(review): not constructed anywhere in the visible part of this file.
#[error("unexpected end of input while reading varint")]
UnexpectedEof,
/// The bytes do not form a valid varint encoding.
/// NOTE(review): not constructed anywhere in the visible part of this file.
#[error("invalid varint encoding")]
InvalidEncoding,
}
impl From<VarIntError> for io::Error {
    /// Wraps the varint error in an `io::Error`, keeping it as the source.
    ///
    /// `UnexpectedEof` maps to `io::ErrorKind::UnexpectedEof`; the other
    /// variants map to `io::ErrorKind::InvalidData`.
    fn from(err: VarIntError) -> Self {
        let kind = match &err {
            VarIntError::UnexpectedEof => io::ErrorKind::UnexpectedEof,
            VarIntError::ValueTooLarge(_) | VarIntError::InvalidEncoding => {
                io::ErrorKind::InvalidData
            }
        };
        io::Error::new(kind, err)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a validated `VarInt`, failing the test if the value is rejected.
    fn valid(value: u64) -> VarInt {
        VarInt::new(value).expect("test value should be valid")
    }

    /// Encodes `varint` into a fresh buffer, failing the test on error.
    fn encoded(varint: VarInt) -> BytesMut {
        let mut buf = BytesMut::new();
        varint.encode(&mut buf).expect("encoding should succeed");
        buf
    }

    /// Boundary values on each side of every width change report the right length.
    #[test]
    fn test_varint_encoding_lengths() {
        for (value, expected_len) in [
            (0, 1),
            (63, 1),
            (64, 2),
            (16383, 2),
            (16384, 4),
            (1073741823, 4),
            (1073741824, 8),
            (VARINT_MAX, 8),
        ] {
            assert_eq!(valid(value).encoded_len(), expected_len, "value={value}");
        }
    }

    /// Encoding followed by decoding returns the original value.
    #[test]
    fn test_varint_roundtrip() {
        let samples = [
            0, 1, 63, 64, 127, 128, 16383, 16384, 65535, 65536, 1073741823, 1073741824, VARINT_MAX,
        ];
        for value in samples {
            let mut wire = encoded(valid(value));
            let decoded = VarInt::decode(&mut wire)
                .expect("decoding should succeed")
                .expect("should have decoded a varint");
            assert_eq!(decoded.value(), value, "value={value}");
        }
    }

    /// A truncated encoding yields `None`, and decodes once the rest arrives.
    #[test]
    fn test_varint_partial_decode() {
        // 16384 needs the 4-byte form; keep only the first two bytes at first.
        let mut tail = encoded(valid(16384));
        let mut head = tail.split_to(2);

        let first_attempt = VarInt::decode(&mut head).expect("partial decode should not error");
        assert!(first_attempt.is_none());

        head.put_slice(&tail);
        let decoded = VarInt::decode(&mut head)
            .expect("decoding should succeed")
            .expect("should have decoded a varint");
        assert_eq!(decoded.value(), 16384);
    }

    /// Values above the 62-bit limit are rejected by the constructor.
    #[test]
    fn test_varint_too_large() {
        for value in [VARINT_MAX + 1, u64::MAX] {
            assert!(VarInt::new(value).is_err());
        }
    }

    /// `peek_len` reports the width announced by the first encoded byte.
    #[test]
    fn test_varint_peek_len() {
        for (value, expected_len) in [(0, 1), (64, 2), (16384, 4), (1073741824, 8)] {
            let wire = encoded(valid(value));
            assert_eq!(
                VarInt::peek_len(&wire).expect("buffer should have varint"),
                expected_len
            );
        }
    }
}