use crate::webtransport::error::{WtError, WtErrorKind, WtResult};
/// Largest value accepted by [`encode`]: 2^62 - 1 (62 value bits in the 8-byte form).
pub const MAX_VALUE: u64 = (1 << 62) - 1;
/// Largest value that fits the 1-byte form (6 value bits): 2^6 - 1.
const MAX_1_BYTE: u64 = (1 << 6) - 1;
/// Largest value that fits the 2-byte form (14 value bits): 2^14 - 1.
const MAX_2_BYTES: u64 = (1 << 14) - 1;
/// Largest value that fits the 4-byte form (30 value bits): 2^30 - 1.
const MAX_4_BYTES: u64 = (1 << 30) - 1;
/// Returns how many bytes [`encode`] uses for `value`: 1, 2, 4 or 8.
///
/// Each form carries a 2-bit length prefix followed by a 6-, 14-, 30- or 62-bit
/// value, and the shortest form that can hold `value` is chosen. Values above
/// [`MAX_VALUE`] cannot be encoded at all; this function still answers 8 for
/// them, and it is [`encode`] that rejects them.
#[must_use]
pub const fn encoded_len(value: u64) -> usize {
    match value {
        0..=MAX_1_BYTE => 1,
        v if v <= MAX_2_BYTES => 2,
        v if v <= MAX_4_BYTES => 4,
        _ => 8,
    }
}
/// Encodes `value` as a variable-length integer into the front of `buf`.
///
/// The shortest form is always used (see [`encoded_len`]). Only the first
/// `len` bytes of `buf` are written; any bytes after them are left untouched.
///
/// # Returns
/// The number of bytes written: 1, 2, 4 or 8.
///
/// # Errors
/// - [`WtErrorKind::InvalidInput`] if `value` exceeds [`MAX_VALUE`].
/// - [`WtErrorKind::BufferTooShort`] if `buf` is shorter than the encoded length.
#[track_caller]
pub fn encode(value: u64, buf: &mut [u8]) -> WtResult<usize> {
    if value > MAX_VALUE {
        return Err(WtError::with_reason(
            WtErrorKind::InvalidInput,
            format!("varint value {value} exceeds maximum {MAX_VALUE}"),
        ));
    }
    let len = encoded_len(value);
    if buf.len() < len {
        return Err(WtError::with_reason(
            WtErrorKind::BufferTooShort,
            format!("need {len} bytes, got {}", buf.len()),
        ));
    }
    // The shortest form is simply the low `len` bytes of the big-endian value.
    let be = value.to_be_bytes();
    buf[..len].copy_from_slice(&be[be.len() - len..]);
    // The two high bits of the first byte hold log2(len): 0b00/01/10/11 for
    // 1/2/4/8 bytes. `value` fits in `len * 8 - 2` bits, so those bits are
    // still clear at this point and can simply be OR-ed in.
    let tag = (len.trailing_zeros() as u8) << 6;
    buf[0] |= tag;
    Ok(len)
}
/// Encodes `value` into a newly allocated `Vec` whose length is exactly
/// [`encoded_len`]`(value)`.
///
/// # Errors
/// Propagates any error from [`encode`]; in particular, values above
/// [`MAX_VALUE`] are rejected.
#[track_caller]
pub fn encode_to_vec(value: u64) -> WtResult<Vec<u8>> {
    let mut out = vec![0u8; encoded_len(value)];
    encode(value, &mut out).map(|_written| out)
}
/// Decodes a variable-length integer from the front of `buf`.
///
/// The two most significant bits of the first byte select the encoded length
/// (1, 2, 4 or 8 bytes); the remaining bits, read big-endian, are the value.
/// Any bytes after the encoded integer are ignored.
///
/// NOTE: RFC 9000 §16 allows non-minimal encodings in general (only frame
/// types must be minimal). This decoder is deliberately stricter and rejects
/// any value that is not in its shortest form.
///
/// # Returns
/// `(value, consumed)`: the decoded value and the number of bytes it occupied.
///
/// # Errors
/// - [`WtErrorKind::Incomplete`] if `buf` is empty or shorter than the length
///   announced by its first byte.
/// - [`WtErrorKind::InvalidInput`] if the value is not minimally encoded.
#[track_caller]
pub fn decode(buf: &[u8]) -> WtResult<(u64, usize)> {
    let first = match buf.first() {
        Some(&byte) => byte,
        None => return Err(WtError::new(WtErrorKind::Incomplete)),
    };
    // The 2-bit prefix (0..=3) is log2 of the encoded length.
    let len = 1usize << (first >> 6);
    let encoded = match buf.get(..len) {
        Some(bytes) => bytes,
        None => return Err(WtError::new(WtErrorKind::Incomplete)),
    };
    // Right-align the encoded bytes in an 8-byte big-endian buffer, then clear
    // the length prefix so that only the value bits remain.
    let mut be = [0u8; 8];
    be[8 - len..].copy_from_slice(encoded);
    be[8 - len] &= 0x3f;
    let value = u64::from_be_bytes(be);
    let minimal_len = encoded_len(value);
    if minimal_len != len {
        return Err(WtError::with_reason(
            WtErrorKind::InvalidInput,
            format!(
                "non-minimal varint encoding: value {value} encoded in {len} bytes, minimum is {minimal_len}"
            ),
        ));
    }
    Ok((value, len))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `value`, checks it used `expected_len` bytes, and decodes it back.
    fn assert_roundtrip(value: u64, expected_len: usize) {
        let mut buf = [0u8; 8];
        let len = encode(value, &mut buf).unwrap();
        assert_eq!(len, expected_len, "encoded length of {value}");
        let (decoded, consumed) = decode(&buf[..len]).unwrap();
        assert_eq!(decoded, value);
        assert_eq!(consumed, expected_len);
    }

    #[test]
    fn test_encoded_len() {
        assert_eq!(encoded_len(0), 1);
        assert_eq!(encoded_len(63), 1);
        assert_eq!(encoded_len(64), 2);
        assert_eq!(encoded_len(16383), 2);
        assert_eq!(encoded_len(16384), 4);
        assert_eq!(encoded_len(1073741823), 4);
        assert_eq!(encoded_len(1073741824), 8);
        assert_eq!(encoded_len(MAX_VALUE), 8);
    }

    #[test]
    fn test_encode_decode_1_byte() {
        for value in [0, 1, 37, 63] {
            assert_roundtrip(value, 1);
        }
    }

    #[test]
    fn test_encode_decode_2_bytes() {
        for value in [64, 100, 494, 16383] {
            assert_roundtrip(value, 2);
        }
    }

    #[test]
    fn test_encode_decode_4_bytes() {
        for value in [16384, 65535, 494878333, 1073741823] {
            assert_roundtrip(value, 4);
        }
    }

    #[test]
    fn test_encode_decode_8_bytes() {
        for value in [1073741824, 151288809941952652, MAX_VALUE] {
            assert_roundtrip(value, 8);
        }
    }

    #[test]
    fn test_encode_overflow() {
        let mut buf = [0u8; 8];
        assert!(encode(MAX_VALUE + 1, &mut buf).is_err());
    }

    #[test]
    fn test_encode_buffer_too_short() {
        let mut buf = [0u8; 1];
        assert!(encode(16384, &mut buf).is_err());
        // One byte short of the 4-byte form still fails; exactly 4 bytes succeeds.
        let mut buf = [0u8; 3];
        assert!(encode(16384, &mut buf).is_err());
        let mut buf = [0u8; 4];
        assert_eq!(encode(16384, &mut buf).unwrap(), 4);
    }

    #[test]
    fn test_encode_leaves_tail_untouched() {
        let mut buf = [0xaau8; 8];
        assert_eq!(encode(37, &mut buf).unwrap(), 1);
        assert_eq!(buf, [0x25, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa]);
    }

    #[test]
    fn test_decode_incomplete() {
        assert!(decode(&[]).is_err());
        assert!(decode(&[0x40]).is_err());
        assert!(decode(&[0x80, 0x00, 0x00]).is_err());
        assert!(decode(&[0xc0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]).is_err());
    }

    #[test]
    fn test_decode_non_minimal_encoding() {
        assert!(decode(&[0x40, 0x0a]).is_err());
        assert!(decode(&[0x40, 0x00]).is_err());
        assert!(decode(&[0x40, 0x3f]).is_err());
        assert!(decode(&[0x80, 0x00, 0x00, 0x40]).is_err());
        assert!(decode(&[0x80, 0x00, 0x3f, 0xff]).is_err());
        assert!(decode(&[0xc0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x40, 0x00]).is_err());
    }

    #[test]
    fn test_decode_ignores_trailing_bytes() {
        assert_eq!(decode(&[0x25, 0xff, 0xff]).unwrap(), (37, 1));
        assert_eq!(decode(&[0x7b, 0xbd, 0x00]).unwrap(), (15293, 2));
    }

    #[test]
    fn test_decode_max_value() {
        assert_eq!(decode(&[0xff; 8]).unwrap(), (MAX_VALUE, 8));
    }

    #[test]
    fn test_rfc_examples() {
        let mut buf = [0u8; 8];
        let len = encode(37, &mut buf).unwrap();
        assert_eq!(&buf[..len], &[0x25]);
        let len = encode(15293, &mut buf).unwrap();
        assert_eq!(&buf[..len], &[0x7b, 0xbd]);
        let len = encode(494878333, &mut buf).unwrap();
        assert_eq!(&buf[..len], &[0x9d, 0x7f, 0x3e, 0x7d]);
        let len = encode(151288809941952652, &mut buf).unwrap();
        assert_eq!(
            &buf[..len],
            &[0xc2, 0x19, 0x7c, 0x5e, 0xff, 0x14, 0xe8, 0x8c]
        );
    }

    #[test]
    fn test_encode_to_vec() {
        let buf = encode_to_vec(37).unwrap();
        assert_eq!(buf, vec![0x25]);
        let buf = encode_to_vec(15293).unwrap();
        assert_eq!(buf, vec![0x7b, 0xbd]);
    }

    #[test]
    fn test_encode_to_vec_overflow() {
        assert!(encode_to_vec(MAX_VALUE + 1).is_err());
    }
}