use std::io::{self, Read, Write};
/// Largest number of bytes a varint-encoded `u64` can occupy: 64 payload bits
/// at 7 bits per byte, rounded up (10).
const MAX_VARINT_LEN: usize = (u64::BITS as usize).div_ceil(7);
/// The seven most significant bits of a `u32`. When any of them is set in a
/// decoder's accumulator, shifting in another 7-bit group would discard data.
const U32_OVERFLOW_MASK: u32 = !(u32::MAX >> 7);
/// The seven most significant bits of a `u64`. When any of them is set in a
/// decoder's accumulator, shifting in another 7-bit group would discard data.
const U64_OVERFLOW_MASK: u64 = !(u64::MAX >> 7);
/// Encodes `num` as a variable-length integer at the tail of `buf`.
///
/// The value is split into 7-bit groups stored most-significant group first.
/// Every byte except the final one has its high bit (`0x80`) set to signal
/// that more bytes follow.
///
/// Returns the encoded length `len` (between 1 and `MAX_VARINT_LEN`); the
/// encoding occupies `buf[MAX_VARINT_LEN - len..]`. Bytes in front of that
/// range are left untouched.
#[inline]
pub fn encode_u64(num: u64, buf: &mut [u8; MAX_VARINT_LEN]) -> usize {
    // The least significant group always lands in the last slot and is the
    // only byte written without the continuation bit.
    let mut start = MAX_VARINT_LEN - 1;
    buf[start] = (num & 0x7F) as u8;

    // Walk towards the front of the buffer, emitting the higher groups.
    let mut rest = num >> 7;
    while rest != 0 {
        start -= 1;
        buf[start] = (rest & 0x7F) as u8 | 0x80;
        rest >>= 7;
    }

    MAX_VARINT_LEN - start
}
/// Encodes a `u32` as a variable-length integer at the tail of `buf`.
///
/// The value is widened losslessly to `u64` and handed to [`encode_u64`], so
/// the byte layout and the meaning of the returned length are the same as
/// for that function.
#[inline]
pub fn encode_u32(num: u32, buf: &mut [u8; MAX_VARINT_LEN]) -> usize {
    let widened = u64::from(num);
    encode_u64(widened, buf)
}
/// Writes `num` to `w` as a variable-length integer.
///
/// The value is encoded into a stack buffer with [`encode_u64`] and only the
/// bytes that make up the encoding are written.
///
/// # Errors
///
/// Returns whatever error `w.write_all` reports.
pub fn write_usize<W: Write>(w: &mut W, num: usize) -> io::Result<()> {
    let mut scratch = [0u8; MAX_VARINT_LEN];
    let used = encode_u64(num as u64, &mut scratch);
    // The encoding is right-aligned, so skip the unused front of the buffer.
    let (_, encoded) = scratch.split_at(MAX_VARINT_LEN - used);
    w.write_all(encoded)
}
/// Writes `num` to `w` as a variable-length integer.
///
/// The value is encoded into a stack buffer with [`encode_u64`] and only the
/// bytes that make up the encoding are written.
///
/// # Errors
///
/// Returns whatever error `w.write_all` reports.
pub fn write_u64<W: Write>(w: &mut W, num: u64) -> io::Result<()> {
    let mut scratch = [0u8; MAX_VARINT_LEN];
    // `encode_u64` fills the buffer from the back; compute where it begins.
    let start = MAX_VARINT_LEN - encode_u64(num, &mut scratch);
    w.write_all(&scratch[start..])
}
/// Writes a `u32` to `w` as a variable-length integer.
///
/// The value is widened losslessly to `u64` and forwarded to [`write_u64`].
///
/// # Errors
///
/// Returns whatever error [`write_u64`] reports.
pub fn write_u32<W: Write>(w: &mut W, num: u32) -> io::Result<()> {
    let widened = u64::from(num);
    write_u64(w, widened)
}
/// Decodes a varint-encoded `u64` from the front of `data`.
///
/// Bytes are consumed until one without the continuation bit (`0x80`) is
/// found; anything after that byte is ignored.
///
/// Returns the decoded value together with the number of bytes consumed.
///
/// # Errors
///
/// * [`VarIntError::Overflow`] if another byte arrives while the accumulated
///   value already uses its top seven bits, i.e. it would not fit in a `u64`.
/// * [`VarIntError::Underflow`] if `data` is empty or ends before a
///   terminating byte is seen.
pub fn read_u64(data: &[u8]) -> Result<(u64, usize), VarIntError> {
    let mut acc = 0u64;
    let mut consumed = 0usize;
    for &byte in data {
        // The shift below must not push set bits out of the accumulator.
        if acc & U64_OVERFLOW_MASK != 0 {
            return Err(VarIntError::Overflow);
        }
        acc = (acc << 7) | u64::from(byte & 0x7F);
        consumed += 1;
        // A clear high bit marks the final byte of the encoding.
        if byte < 0x80 {
            return Ok((acc, consumed));
        }
    }
    Err(VarIntError::Underflow)
}
/// Decodes a varint-encoded `u32` from the front of `data`.
///
/// Bytes are consumed until one without the continuation bit (`0x80`) is
/// found; anything after that byte is ignored.
///
/// Returns the decoded value together with the number of bytes consumed.
///
/// # Errors
///
/// * [`VarIntError::Overflow`] if another byte arrives while the accumulated
///   value already uses its top seven bits, i.e. it would not fit in a `u32`.
/// * [`VarIntError::Underflow`] if `data` is empty or ends before a
///   terminating byte is seen.
pub fn read_u32(data: &[u8]) -> Result<(u32, usize), VarIntError> {
    let mut acc = 0u32;
    let mut consumed = 0usize;
    for &byte in data {
        // The shift below must not push set bits out of the accumulator.
        if acc & U32_OVERFLOW_MASK != 0 {
            return Err(VarIntError::Overflow);
        }
        acc = (acc << 7) | u32::from(byte & 0x7F);
        consumed += 1;
        // A clear high bit marks the final byte of the encoding.
        if byte < 0x80 {
            return Ok((acc, consumed));
        }
    }
    Err(VarIntError::Underflow)
}
/// Decodes a varint-encoded `usize` from the front of `data`.
///
/// The value is decoded as a `u64` via [`read_u64`] and then narrowed.
/// Returns the decoded value together with the number of bytes consumed.
///
/// # Errors
///
/// * Any error returned by [`read_u64`].
/// * [`VarIntError::Overflow`] if the decoded value does not fit in `usize`.
pub fn read_usize(data: &[u8]) -> Result<(usize, usize), VarIntError> {
    read_u64(data).and_then(|(wide, consumed)| match usize::try_from(wide) {
        Ok(narrow) => Ok((narrow, consumed)),
        Err(_) => Err(VarIntError::Overflow),
    })
}
/// Reads a varint-encoded `u64` from `r`, one byte at a time.
///
/// Reading stops right after the first byte whose continuation bit (`0x80`)
/// is clear, so nothing beyond the encoding is consumed.
///
/// # Errors
///
/// * Any error from `r.read_exact` is returned unchanged (the standard
///   implementation reports a premature end of stream as `UnexpectedEof`).
/// * An `InvalidData` error with the message `"varint overflow"` if another
///   byte is read while the accumulated value already uses its top seven
///   bits. The offending byte has been consumed at that point.
pub fn stream_read_u64<R: Read>(r: &mut R) -> io::Result<u64> {
    let mut acc: u64 = 0;
    loop {
        let mut next = [0u8; 1];
        r.read_exact(&mut next)?;
        let [byte] = next;

        // The shift below must not push set bits out of the accumulator.
        if acc & U64_OVERFLOW_MASK != 0 {
            return Err(io::Error::new(io::ErrorKind::InvalidData, "varint overflow"));
        }
        acc = (acc << 7) | u64::from(byte & 0x7F);

        // A clear high bit marks the final byte of the encoding.
        if byte < 0x80 {
            break Ok(acc);
        }
    }
}
/// Reads a varint-encoded `u32` from `r`, one byte at a time.
///
/// Reading stops right after the first byte whose continuation bit (`0x80`)
/// is clear, so nothing beyond the encoding is consumed.
///
/// # Errors
///
/// * Any error from `r.read_exact` is returned unchanged (the standard
///   implementation reports a premature end of stream as `UnexpectedEof`).
/// * An `InvalidData` error with the message `"varint overflow"` if another
///   byte is read while the accumulated value already uses its top seven
///   bits. The offending byte has been consumed at that point.
pub fn stream_read_u32<R: Read>(r: &mut R) -> io::Result<u32> {
    let mut acc: u32 = 0;
    loop {
        let mut next = [0u8; 1];
        r.read_exact(&mut next)?;
        let [byte] = next;

        // The shift below must not push set bits out of the accumulator.
        if acc & U32_OVERFLOW_MASK != 0 {
            return Err(io::Error::new(io::ErrorKind::InvalidData, "varint overflow"));
        }
        acc = (acc << 7) | u32::from(byte & 0x7F);

        // A clear high bit marks the final byte of the encoding.
        if byte < 0x80 {
            break Ok(acc);
        }
    }
}
/// Reads a varint-encoded `usize` from `r`.
///
/// The value is read as a `u64` via [`stream_read_u64`] and then narrowed.
///
/// # Errors
///
/// * Any error returned by [`stream_read_u64`].
/// * An `InvalidData` error with the message `"varint overflow"` if the
///   decoded value does not fit in `usize`.
pub fn stream_read_usize<R: Read>(r: &mut R) -> io::Result<usize> {
    let wide = stream_read_u64(r)?;
    match usize::try_from(wide) {
        Ok(narrow) => Ok(narrow),
        Err(_) => Err(io::Error::new(io::ErrorKind::InvalidData, "varint overflow")),
    }
}
/// Returns how many bytes the varint encoding of `num` occupies (1 to 10).
///
/// Each encoded byte carries 7 payload bits, so the result is the number of
/// significant bits divided by 7 and rounded up; zero still takes one byte.
#[inline]
pub fn sizeof_u64(num: u64) -> usize {
    // OR-ing in the lowest bit makes zero count as one significant bit without
    // changing the bit length of any other value.
    let significant_bits = u64::BITS - (num | 1).leading_zeros();
    // 1..=64 bits at 7 bits per byte gives 1..=10 bytes; no clamp is needed.
    significant_bits.div_ceil(7) as usize
}
/// Returns how many bytes the varint encoding of `num` occupies (1 to 5).
///
/// Each encoded byte carries 7 payload bits, so the result is the number of
/// significant bits divided by 7 and rounded up; zero still takes one byte.
#[inline]
pub fn sizeof_u32(num: u32) -> usize {
    // OR-ing in the lowest bit makes zero count as one significant bit without
    // changing the bit length of any other value.
    let significant_bits = u32::BITS - (num | 1).leading_zeros();
    // 1..=32 bits at 7 bits per byte gives 1..=5 bytes; no clamp is needed.
    significant_bits.div_ceil(7) as usize
}
/// Returns how many bytes the varint encoding of `num` occupies.
///
/// The value is widened to `u64` and measured with [`sizeof_u64`].
#[inline]
pub fn sizeof_usize(num: usize) -> usize {
    let widened = num as u64;
    sizeof_u64(widened)
}
/// Reasons decoding a variable-length integer can fail.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VarIntError {
    /// The input ended before a terminating byte was found.
    Underflow,
    /// The encoded value does not fit in the requested integer type.
    Overflow,
}

impl std::fmt::Display for VarIntError {
    /// Writes a fixed, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match *self {
            VarIntError::Underflow => "varint underflow (truncated input)",
            VarIntError::Overflow => "varint overflow",
        };
        f.write_str(message)
    }
}

impl std::error::Error for VarIntError {}

impl From<VarIntError> for io::Error {
    /// Wraps the decoding error in an [`io::Error`] of kind `InvalidData`,
    /// keeping the original error as its payload.
    fn from(err: VarIntError) -> Self {
        Self::new(io::ErrorKind::InvalidData, err)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encoding then slice-decoding must reproduce the value, and the encoded,
    /// consumed and predicted lengths must all agree.
    #[test]
    fn roundtrip_u64() {
        let samples = [
            0,
            1,
            127,
            128,
            255,
            256,
            16383,
            16384,
            u64::from(u32::MAX),
            u64::MAX,
        ];
        let mut buf = [0u8; MAX_VARINT_LEN];
        for val in samples {
            let len = encode_u64(val, &mut buf);
            let encoded = &buf[MAX_VARINT_LEN - len..];
            let (decoded, consumed) = read_u64(encoded).unwrap();
            assert_eq!(decoded, val, "roundtrip failed for {val}");
            assert_eq!(consumed, len, "length mismatch for {val}");
            assert_eq!(sizeof_u64(val), len, "sizeof mismatch for {val}");
        }
    }

    /// Same round-trip property for the 32-bit entry points.
    #[test]
    fn roundtrip_u32() {
        let mut buf = [0u8; MAX_VARINT_LEN];
        for val in [0, 1, 127, 128, 16383, 16384, u32::MAX] {
            let len = encode_u32(val, &mut buf);
            let encoded = &buf[MAX_VARINT_LEN - len..];
            let (decoded, consumed) = read_u32(encoded).unwrap();
            assert_eq!(decoded, val);
            assert_eq!(consumed, len);
            assert_eq!(sizeof_u32(val), len);
        }
    }

    /// 300 splits into the groups 2 and 44; the high group comes first and
    /// carries the continuation bit.
    #[test]
    fn encoding_is_big_endian() {
        let mut buf = [0u8; MAX_VARINT_LEN];
        assert_eq!(encode_u64(300, &mut buf), 2);
        assert_eq!(&buf[MAX_VARINT_LEN - 2..], &[0x82, 0x2C]);
    }

    /// Values below 128 encode as themselves in a single byte.
    #[test]
    fn single_byte_values() {
        let mut buf = [0u8; MAX_VARINT_LEN];
        for small in 0u8..=127 {
            let len = encode_u64(u64::from(small), &mut buf);
            assert_eq!(len, 1);
            assert_eq!(buf[MAX_VARINT_LEN - 1], small);
        }
    }

    /// A full-width 64-bit encoding must be rejected by the 32-bit decoder.
    #[test]
    fn overflow_detection_u32() {
        let mut buf = [0u8; MAX_VARINT_LEN];
        let len = encode_u64(u64::MAX, &mut buf);
        let encoded = &buf[MAX_VARINT_LEN - len..];
        assert_eq!(read_u32(encoded), Err(VarIntError::Overflow));
    }

    /// Input made only of continuation bytes never terminates.
    #[test]
    fn underflow_detection() {
        let truncated = [0x80u8; 3];
        assert_eq!(read_u64(&truncated), Err(VarIntError::Underflow));
    }

    /// The streaming decoder must agree with the encoder.
    #[test]
    fn streaming_roundtrip() {
        let mut buf = [0u8; MAX_VARINT_LEN];
        let len = encode_u64(123456789, &mut buf);
        let mut cursor = std::io::Cursor::new(&buf[MAX_VARINT_LEN - len..]);
        assert_eq!(stream_read_u64(&mut cursor).unwrap(), 123456789);
    }

    /// Bytes produced by the writer decode back to the value and are consumed
    /// in full.
    #[test]
    fn write_read_roundtrip() {
        let mut out = Vec::new();
        write_u64(&mut out, 999999).unwrap();
        let (val, len) = read_u64(&out).unwrap();
        assert_eq!((val, len), (999999, out.len()));
    }
}