/// Longest possible varint encoding of a `u64`: ceil(64 / 7) = 10 bytes,
/// since every encoded byte carries 7 payload bits.
const MAX_U64_VARINT_LEN: usize = (u64::BITS as usize + 6) / 7;
/// Longest possible varint encoding of a `u32`: ceil(32 / 7) = 5 bytes.
const MAX_U32_VARINT_LEN: usize = (u32::BITS as usize + 6) / 7;
/// Reasons a varint could not be decoded from a byte buffer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum VarintError {
    /// The buffer ended (or was empty) before a byte without the
    /// continuation bit was seen.
    UnexpectedEof,
    /// The encoding is longer than the target integer type allows, or the
    /// value it carries does not fit in that type.
    Overflow,
}

impl std::fmt::Display for VarintError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::UnexpectedEof => "buffer ended before varint was complete",
            Self::Overflow => "varint overflows target integer type",
        };
        f.write_str(message)
    }
}

impl std::error::Error for VarintError {}
/// Encodes `v` as an unsigned LEB128 varint.
///
/// The value is split into 7-bit groups, least significant group first.
/// Every byte except the last has its high (continuation) bit set. The result
/// is 1 to `MAX_U64_VARINT_LEN` bytes long; zero encodes as a single `0x00`.
pub fn encode_u64(v: u64) -> Vec<u8> {
    let mut out = Vec::with_capacity(MAX_U64_VARINT_LEN);
    let mut rest = v;
    // While more than 7 bits remain, emit the low group with the continuation
    // bit set. The `as u8` truncation is intended: bit 7 is overwritten by the
    // continuation flag anyway.
    while rest >= 0x80 {
        out.push(rest as u8 | 0x80);
        rest >>= 7;
    }
    // The final group fits in 7 bits and is emitted without the flag.
    out.push(rest as u8);
    out
}
/// Encodes `v` as an unsigned LEB128 varint.
///
/// The `u32` is widened losslessly to `u64` and passed to [`encode_u64`], so
/// the bytes are exactly what `encode_u64` produces for the same numeric
/// value.
pub fn encode_u32(v: u32) -> Vec<u8> {
    let widened = u64::from(v);
    encode_u64(widened)
}
/// Decodes an unsigned LEB128 varint from the front of `buf`.
///
/// On success, returns the decoded value and the number of bytes consumed.
/// Bytes after the terminating byte (the first one without the continuation
/// bit) are not examined.
///
/// Non-minimal encodings (e.g. `[0x80, 0x00]` for 0) are accepted.
///
/// # Errors
///
/// * [`VarintError::Overflow`] if the value cannot fit in a `u64`. This
///   means the `MAX_U64_VARINT_LEN`-th byte has any bit other than bit 0 set,
///   including its continuation bit. It is reported as soon as that byte is
///   seen, regardless of what (if anything) follows it.
/// * [`VarintError::UnexpectedEof`] if `buf` is empty or ends while the last
///   byte read still has its continuation bit set.
pub fn decode_u64(buf: &[u8]) -> Result<(u64, usize), VarintError> {
    const LAST: usize = MAX_U64_VARINT_LEN - 1;
    let mut result: u64 = 0;
    // `take` bounds the loop: the LAST byte either errors or terminates, so
    // no input can run past MAX_U64_VARINT_LEN bytes.
    for (i, &byte) in buf.iter().enumerate().take(MAX_U64_VARINT_LEN) {
        // Earlier bytes supply 9 * 7 = 63 bits, so the last byte may only hold
        // bit 63 (value 0 or 1). It must also end the varint. Checking the
        // whole byte, not just its low 7 bits, means a continuation bit here
        // is reported as Overflow rather than EOF when the buffer stops here.
        if i == LAST && byte > 0x01 {
            return Err(VarintError::Overflow);
        }
        result |= u64::from(byte & 0x7f) << (7 * i);
        if byte & 0x80 == 0 {
            return Ok((result, i + 1));
        }
    }
    Err(VarintError::UnexpectedEof)
}
/// Decodes an unsigned LEB128 varint that must fit in a `u32` from the front
/// of `buf`.
///
/// On success, returns the decoded value and the number of bytes consumed.
/// Bytes after the terminating byte (the first one without the continuation
/// bit) are not examined.
///
/// # Errors
///
/// * [`VarintError::Overflow`] if the value cannot fit in a `u32`. This
///   means the `MAX_U32_VARINT_LEN`-th byte is greater than `0x0f`,
///   including when its continuation bit is set. It is reported as soon as
///   that byte is seen, regardless of what follows it.
/// * [`VarintError::UnexpectedEof`] if `buf` is empty or ends while the last
///   byte read still has its continuation bit set.
pub fn decode_u32(buf: &[u8]) -> Result<(u32, usize), VarintError> {
    const LAST: usize = MAX_U32_VARINT_LEN - 1;
    let mut result: u32 = 0;
    // `take` bounds the loop: the LAST byte either errors or terminates, so
    // no input can run past MAX_U32_VARINT_LEN bytes.
    for (i, &byte) in buf.iter().enumerate().take(MAX_U32_VARINT_LEN) {
        // Earlier bytes supply 4 * 7 = 28 bits, so the last byte may carry
        // only 4 more (<= 0x0f) and must end the varint. A set continuation
        // bit also exceeds 0x0f and is treated as Overflow.
        if i == LAST && byte > 0x0f {
            return Err(VarintError::Overflow);
        }
        result |= u32::from(byte & 0x7f) << (7 * i);
        if byte & 0x80 == 0 {
            return Ok((result, i + 1));
        }
    }
    Err(VarintError::UnexpectedEof)
}
/// Returns the number of bytes the varint encoding of `v` occupies, without
/// building it.
///
/// Each encoded byte carries 7 payload bits. The length is therefore the
/// bit length of `v` divided by 7, rounded up. Zero still takes one byte.
pub fn length_varint_u64(v: u64) -> usize {
    // OR-ing in bit 0 gives zero a bit length of 1 (one byte). Every other
    // value keeps its bit length, since its highest set bit is unaffected.
    let significant_bits = u64::BITS - (v | 1).leading_zeros();
    ((significant_bits + 6) / 7) as usize
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{self, Write};

    fn write_varint_u64(w: &mut impl Write, v: u64) -> io::Result<()> {
        let encoded = encode_u64(v);
        w.write_all(&encoded)
    }

    fn write_varint_u32(w: &mut impl Write, v: u32) -> io::Result<()> {
        let encoded = encode_u32(v);
        w.write_all(&encoded)
    }

    /// Values on either side of every 7-bit group boundary (where the encoded
    /// length changes), plus 0 and `u64::MAX`.
    fn u64_group_boundaries() -> Vec<u64> {
        let mut values = vec![0, u64::MAX];
        for bits in (7..64).step_by(7) {
            let edge = 1u64 << bits;
            values.push(edge - 1);
            values.push(edge);
        }
        values
    }

    #[test]
    fn test_encode_u64_zero() {
        assert_eq!(encode_u64(0), vec![0x00]);
    }

    #[test]
    fn test_encode_u64_one() {
        assert_eq!(encode_u64(1), vec![0x01]);
    }

    #[test]
    fn test_encode_u64_127() {
        assert_eq!(encode_u64(127), vec![0x7f]);
    }

    #[test]
    fn test_encode_u64_128() {
        assert_eq!(encode_u64(128), vec![0x80, 0x01]);
    }

    #[test]
    fn test_encode_u64_16383() {
        assert_eq!(encode_u64(16383), vec![0xff, 0x7f]);
    }

    #[test]
    fn test_encode_u64_16384() {
        assert_eq!(encode_u64(16384), vec![0x80, 0x80, 0x01]);
    }

    #[test]
    fn test_encode_u64_max() {
        assert_eq!(
            encode_u64(u64::MAX),
            vec![0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x01]
        );
    }

    fn roundtrip_u64(v: u64) {
        let encoded = encode_u64(v);
        let (decoded, consumed) = decode_u64(&encoded).unwrap();
        assert_eq!(decoded, v, "round-trip failed for {v}");
        assert_eq!(consumed, encoded.len(), "wrong consumed for {v}");
    }

    #[test]
    fn test_decode_u64_roundtrip_representative_set() {
        for &v in &[0u64, 1, 127, 128, 16383, 16384, u64::MAX] {
            roundtrip_u64(v);
        }
    }

    #[test]
    fn test_decode_u64_roundtrip_group_boundaries() {
        for v in u64_group_boundaries() {
            roundtrip_u64(v);
        }
    }

    #[test]
    fn test_decode_u64_partial_buffer() {
        let mut buf = encode_u64(300);
        // A trailing byte after the varint must be left unconsumed.
        buf.push(0xab);
        let (val, consumed) = decode_u64(&buf).unwrap();
        assert_eq!(val, 300);
        assert_eq!(consumed, 2);
    }

    #[test]
    fn test_decode_u64_unexpected_eof() {
        assert_eq!(decode_u64(&[0x80]), Err(VarintError::UnexpectedEof));
    }

    #[test]
    fn test_decode_empty_buffer() {
        assert_eq!(decode_u64(&[]), Err(VarintError::UnexpectedEof));
        assert_eq!(decode_u32(&[]), Err(VarintError::UnexpectedEof));
    }

    #[test]
    fn test_decode_u64_truncated_multibyte() {
        let encoded = encode_u64(16384);
        assert_eq!(encoded.len(), 3);
        assert_eq!(
            decode_u64(&encoded[..2]),
            Err(VarintError::UnexpectedEof)
        );
    }

    #[test]
    fn test_decode_u64_overflow() {
        assert_eq!(decode_u64(&[0xff; 11]), Err(VarintError::Overflow));
    }

    #[test]
    fn test_decode_u64_tenth_byte_limit() {
        // The tenth byte can only supply bit 63, so 0x01 is the largest
        // terminating value it may hold.
        let mut buf = [0xff_u8; 10];
        buf[9] = 0x01;
        assert_eq!(decode_u64(&buf), Ok((u64::MAX, 10)));
        buf[9] = 0x02;
        assert_eq!(decode_u64(&buf), Err(VarintError::Overflow));
    }

    fn roundtrip_u32(v: u32) {
        let encoded = encode_u32(v);
        let (decoded, consumed) = decode_u32(&encoded).unwrap();
        assert_eq!(decoded, v, "round-trip failed for {v}");
        assert_eq!(consumed, encoded.len(), "wrong consumed for {v}");
    }

    #[test]
    fn test_decode_u32_roundtrip() {
        for &v in &[0u32, 1, 127, 128, u32::MAX] {
            roundtrip_u32(v);
        }
    }

    #[test]
    fn test_decode_u32_roundtrip_group_boundaries() {
        for bits in (7..32).step_by(7) {
            let edge = 1u32 << bits;
            roundtrip_u32(edge - 1);
            roundtrip_u32(edge);
        }
    }

    #[test]
    fn test_decode_u32_overflow() {
        let buf = encode_u64(u64::MAX);
        assert_eq!(decode_u32(&buf), Err(VarintError::Overflow));
    }

    #[test]
    fn test_decode_u32_fifth_byte_limit() {
        // Four full bytes carry 28 bits; the fifth may add at most 4 more.
        let mut buf = [0xff_u8; 5];
        buf[4] = 0x0f;
        assert_eq!(decode_u32(&buf), Ok((u32::MAX, 5)));
        buf[4] = 0x10;
        assert_eq!(decode_u32(&buf), Err(VarintError::Overflow));
    }

    #[test]
    fn test_length_varint_matches_encode_len() {
        for &v in &[0u64, 1, 127, 128, 16383, 16384, u64::MAX] {
            let expected = encode_u64(v).len();
            let got = length_varint_u64(v);
            assert_eq!(
                got, expected,
                "length_varint_u64({v}) = {got}, expected {expected}"
            );
        }
    }

    #[test]
    fn test_length_varint_matches_encode_len_at_group_boundaries() {
        for v in u64_group_boundaries() {
            let expected = encode_u64(v).len();
            let got = length_varint_u64(v);
            assert_eq!(
                got, expected,
                "length_varint_u64({v}) = {got}, expected {expected}"
            );
        }
    }

    #[test]
    fn test_error_display() {
        assert_eq!(
            VarintError::UnexpectedEof.to_string(),
            "buffer ended before varint was complete"
        );
        assert_eq!(
            VarintError::Overflow.to_string(),
            "varint overflows target integer type"
        );
    }

    #[test]
    fn test_write_varint_u64() {
        let mut buf = Vec::new();
        write_varint_u64(&mut buf, 128).unwrap();
        assert_eq!(buf, vec![0x80, 0x01]);
    }

    #[test]
    fn test_write_varint_u32() {
        let mut buf = Vec::new();
        write_varint_u32(&mut buf, 0).unwrap();
        assert_eq!(buf, vec![0x00]);
    }
}