use crate::backend::native::NativeBackendError;
use crate::backend::native::NativeResult;
/// Longest possible encoding of a `u64`: 64 bits split into 7-bit groups, rounded up.
pub const MAX_VARINT_BYTES: usize = (64 + 6) / 7;
/// High bit of an encoded byte; the encoder sets it on every byte except the last one.
const CONTINUATION_BIT: u8 = 0x80;
/// Selects the seven payload bits of an encoded byte (everything but the continuation bit).
const DATA_MASK: u8 = !CONTINUATION_BIT;
/// Largest value that is encoded as a single byte (all seven payload bits set).
pub const MAX_SINGLE_BYTE_VALUE: u64 = DATA_MASK as u64;
/// Reasons a byte sequence can fail to decode as a varint.
///
/// The `Display` text of these variants is what ends up in user-facing error reasons.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum VarintError {
    /// The input slice contained no bytes at all.
    EmptyInput,
    /// The input ended while more bytes were still expected.
    IncompleteData {
        /// Number of bytes that would have been needed.
        expected_bytes: usize,
        /// Number of bytes that were actually available.
        actual_bytes: usize,
    },
    /// The encoding is longer than `MAX_VARINT_BYTES`.
    TooLong {
        /// Length of the offending encoding, in bytes.
        actual_length: usize,
    },
    /// The continuation bit was set on a byte that had to be the final one.
    Malformed {
        /// Position of the offending byte.
        position: usize,
    },
}
/// Human-readable messages for each decoding failure.
impl std::fmt::Display for VarintError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // No interpolation needed, so write the literal directly.
            Self::EmptyInput => f.write_str("Cannot decode varint from empty input"),
            Self::IncompleteData {
                expected_bytes: wanted,
                actual_bytes: available,
            } => write!(
                f,
                "Incomplete varint: expected {} bytes, found {}",
                wanted, available
            ),
            Self::TooLong {
                actual_length: length,
            } => write!(
                f,
                "Varint too long: maximum {} bytes, found {}",
                MAX_VARINT_BYTES, length
            ),
            Self::Malformed { position: offset } => write!(
                f,
                "Malformed varint: continuation bit set at final byte {}",
                offset
            ),
        }
    }
}
// Empty marker impl: `VarintError` relies on the default `std::error::Error` methods,
// with its `Debug` and `Display` implementations supplying the text.
impl std::error::Error for VarintError {}
/// Encodes `value` as a variable-length integer.
///
/// The value is emitted in 7-bit groups, least-significant group first. Every byte except
/// the last has its high (continuation) bit set. Zero is encoded as the single byte `0x00`.
///
/// # Returns
/// The encoded bytes; between 1 and `MAX_VARINT_BYTES` of them.
pub fn encode_varint(value: u64) -> Vec<u8> {
    if value == 0 {
        return vec![0];
    }
    let mut encoded = Vec::with_capacity(MAX_VARINT_BYTES);
    let mut pending = value;
    loop {
        let payload = (pending & u64::from(DATA_MASK)) as u8;
        pending >>= 7;
        if pending == 0 {
            // Final group: leave the continuation bit clear to terminate the varint.
            encoded.push(payload);
            return encoded;
        }
        encoded.push(payload | CONTINUATION_BIT);
    }
}
/// Decodes one variable-length integer from the front of `bytes`.
///
/// Each byte contributes its low seven bits, least-significant group first; a set high bit
/// means another byte follows. Any bytes after the terminating byte are ignored.
///
/// # Returns
/// `(value, bytes_consumed)` on success.
///
/// # Errors
/// Returns `NativeBackendError::InvalidHeader` with `field` set to `"varint"` when:
/// - `bytes` is empty,
/// - the input ends while the continuation bit is still set,
/// - the encoded value does not fit in a `u64`: either a byte carries payload bits above
///   bit 63, or the tenth byte still has its continuation bit set.
pub fn decode_varint(bytes: &[u8]) -> NativeResult<(u64, usize)> {
    // Every failure in this function is reported against the same header field.
    let invalid = |reason: String| NativeBackendError::InvalidHeader {
        field: "varint".to_string(),
        reason,
    };

    if bytes.is_empty() {
        return Err(invalid(VarintError::EmptyInput.to_string()));
    }

    let mut result: u64 = 0;
    let mut shift: usize = 0;
    for (i, &byte) in bytes.iter().enumerate() {
        // Defensive backstop: the shift check at the bottom of the loop already stops
        // decoding when the tenth byte asks for a continuation.
        if i >= MAX_VARINT_BYTES {
            return Err(invalid(
                VarintError::TooLong {
                    actual_length: i + 1,
                }
                .to_string(),
            ));
        }

        let data = u64::from(byte & DATA_MASK);
        // At shift 63 only the lowest payload bit fits in a u64. Without this check the
        // left shift below would silently discard the upper bits and yield a wrong value.
        if data > (u64::MAX >> shift) {
            return Err(invalid(format!(
                "Varint overflow: byte {} carries bits beyond u64 capacity",
                i + 1
            )));
        }
        result |= data << shift;

        if byte & CONTINUATION_BIT == 0 {
            return Ok((result, i + 1));
        }

        shift += 7;
        if shift >= 64 {
            return Err(invalid(format!(
                "Varint overflow: shift {} exceeds u64 capacity",
                shift
            )));
        }
    }

    // Every available byte had its continuation bit set, so at least one more was expected.
    Err(invalid(
        VarintError::IncompleteData {
            expected_bytes: bytes.len() + 1,
            actual_bytes: bytes.len(),
        }
        .to_string(),
    ))
}
/// Returns how many bytes the varint encoding of `value` occupies.
///
/// One byte is needed per started group of seven significant bits; zero still takes a
/// single byte.
pub fn varint_size(value: u64) -> usize {
    let significant_bits = (64 - value.leading_zeros()) as usize;
    // Treat zero as having one significant bit so that it maps to one byte,
    // then round the bit count up to whole 7-bit groups.
    (significant_bits.max(1) + 6) / 7
}
/// Encodes a `u32` as a varint by widening it to `u64` and delegating to `encode_varint`.
pub fn encode_varint_u32(value: u32) -> Vec<u8> {
    let widened = u64::from(value);
    encode_varint(widened)
}
/// Decodes a varint from the front of `bytes` and narrows it to `u32`.
///
/// # Returns
/// `(value, bytes_consumed)` on success.
///
/// # Errors
/// Propagates any error from `decode_varint`. Returns
/// `NativeBackendError::InvalidHeader` with `field` set to `"varint_u32"` when the decoded
/// value is larger than `u32::MAX`.
pub fn decode_varint_u32(bytes: &[u8]) -> NativeResult<(u32, usize)> {
    let (decoded, consumed) = decode_varint(bytes)?;
    if decoded <= u64::from(u32::MAX) {
        // Range-checked above, so the narrowing cast cannot truncate.
        Ok((decoded as u32, consumed))
    } else {
        Err(NativeBackendError::InvalidHeader {
            field: "varint_u32".to_string(),
            reason: format!("Value {} exceeds u32::MAX", decoded),
        })
    }
}
/// Encodes a `u16` as a varint by widening it to `u64` and delegating to `encode_varint`.
pub fn encode_varint_u16(value: u16) -> Vec<u8> {
    let widened = u64::from(value);
    encode_varint(widened)
}
/// Decodes a varint from the front of `bytes` and narrows it to `u16`.
///
/// # Returns
/// `(value, bytes_consumed)` on success.
///
/// # Errors
/// Propagates any error from `decode_varint`. Returns
/// `NativeBackendError::InvalidHeader` with `field` set to `"varint_u16"` when the decoded
/// value is larger than `u16::MAX`.
pub fn decode_varint_u16(bytes: &[u8]) -> NativeResult<(u16, usize)> {
    let (decoded, consumed) = decode_varint(bytes)?;
    if decoded <= u64::from(u16::MAX) {
        // Range-checked above, so the narrowing cast cannot truncate.
        Ok((decoded as u16, consumed))
    } else {
        Err(NativeBackendError::InvalidHeader {
            field: "varint_u16".to_string(),
            reason: format!("Value {} exceeds u16::MAX", decoded),
        })
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the varint codec: known encodings, decoding (including trailing
    //! data and error cases), round trips, and size calculation.

    use super::*;

    // ---- encoding -------------------------------------------------------------------

    #[test]
    fn test_encode_varint_zero() {
        let encoded = encode_varint(0);
        assert_eq!(encoded, vec![0x00]);
    }

    #[test]
    fn test_encode_varint_single_byte() {
        for (value, byte) in [(1u64, 0x01u8), (42, 0x2A), (127, 0x7F)] {
            assert_eq!(encode_varint(value), vec![byte]);
        }
    }

    #[test]
    fn test_encode_varint_two_bytes() {
        let cases = [
            (128u64, [0x80u8, 0x01]),
            (300, [0xAC, 0x02]),
            (16383, [0xFF, 0x7F]),
        ];
        for (value, bytes) in cases {
            assert_eq!(encode_varint(value), bytes);
        }
    }

    #[test]
    fn test_encode_varint_three_bytes() {
        let cases = [
            (16384u64, [0x80u8, 0x80, 0x01]),
            (2097151, [0xFF, 0xFF, 0x7F]),
        ];
        for (value, bytes) in cases {
            assert_eq!(encode_varint(value), bytes);
        }
    }

    #[test]
    fn test_encode_varint_max_u64() {
        let encoded = encode_varint(u64::MAX);
        assert_eq!(encoded.len(), 10);
        // The first nine bytes must announce a successor; the tenth terminates the varint.
        for (index, byte) in encoded.iter().copied().enumerate() {
            let has_msb = byte & CONTINUATION_BIT != 0;
            match index {
                0..=8 => assert!(has_msb, "Byte {} should have MSB set", index),
                _ => assert!(!has_msb, "Byte {} should not have MSB set", index),
            }
        }
    }

    #[test]
    fn test_encode_varint_u16() {
        assert_eq!(encode_varint_u16(u16::MAX), vec![0xFF, 0xFF, 0x03]);
        assert_eq!(encode_varint_u16(100), vec![100]);
    }

    #[test]
    fn test_encode_varint_u32() {
        let encoded = encode_varint_u32(u32::MAX);
        assert_eq!(encoded.len(), 5);
        assert_eq!(encoded, vec![0xFF, 0xFF, 0xFF, 0xFF, 0x0F]);
    }

    // ---- decoding -------------------------------------------------------------------

    #[test]
    fn test_decode_varint_zero() {
        let decoded = decode_varint(&[0x00]).unwrap();
        assert_eq!(decoded, (0, 1));
    }

    #[test]
    fn test_decode_varint_single_byte() {
        for (byte, value) in [(0x01u8, 1u64), (0x2A, 42), (0x7F, 127)] {
            assert_eq!(decode_varint(&[byte]).unwrap(), (value, 1));
        }
    }

    #[test]
    fn test_decode_varint_two_bytes() {
        let cases = [
            ([0x80u8, 0x01], 128u64),
            ([0xAC, 0x02], 300),
            ([0xFF, 0x7F], 16383),
        ];
        for (bytes, value) in cases {
            assert_eq!(decode_varint(&bytes).unwrap(), (value, 2));
        }
    }

    #[test]
    fn test_decode_varint_three_bytes() {
        let cases = [
            ([0x80u8, 0x80, 0x01], 16384u64),
            ([0xFF, 0xFF, 0x7F], 2097151),
        ];
        for (bytes, value) in cases {
            assert_eq!(decode_varint(&bytes).unwrap(), (value, 3));
        }
    }

    #[test]
    fn test_decode_varint_with_extra_data() {
        // Only the first two bytes belong to the varint; the rest must be left untouched.
        let data = [0x80u8, 0x01, 0xFF, 0xFF, 0xFF];
        assert_eq!(decode_varint(&data).unwrap(), (128, 2));
    }

    #[test]
    fn test_decode_varint_u16() {
        let cases: [(&[u8], u16); 3] = [
            (&[0x7F], 127),
            (&[0xFF, 0x01], 255),
            (&[0xFF, 0xFF, 0x03], u16::MAX),
        ];
        for (bytes, value) in cases {
            assert_eq!(decode_varint_u16(bytes).unwrap(), (value, bytes.len()));
        }
    }

    #[test]
    fn test_decode_varint_u16_incomplete() {
        assert!(decode_varint_u16(&[0x80]).is_err());
    }

    #[test]
    fn test_decode_varint_u32() {
        let cases: [(&[u8], u32); 3] = [
            (&[0x7F], 127),
            (&[0xFF, 0x01], 255),
            (&[0xFF, 0xFF, 0xFF, 0xFF, 0x0F], u32::MAX),
        ];
        for (bytes, value) in cases {
            assert_eq!(decode_varint_u32(bytes).unwrap(), (value, bytes.len()));
        }
    }

    #[test]
    fn test_decode_varint_u32_incomplete() {
        assert!(decode_varint_u32(&[0x80]).is_err());
    }

    #[test]
    fn test_decode_varint_u32_overflow() {
        let one_past_max = encode_varint(u64::from(u32::MAX) + 1);
        assert!(decode_varint_u32(&one_past_max).is_err());
    }

    #[test]
    fn test_decode_varint_u16_overflow() {
        let one_past_max = encode_varint(u64::from(u16::MAX) + 1);
        assert!(decode_varint_u16(&one_past_max).is_err());
    }

    #[test]
    fn test_decode_varint_empty_input() {
        assert!(decode_varint(&[]).is_err());
    }

    #[test]
    fn test_decode_varint_incomplete() {
        // Every byte promises a successor that never arrives.
        let truncated: [&[u8]; 2] = [&[0x80], &[0x80, 0x80]];
        for bytes in truncated {
            assert!(decode_varint(bytes).is_err());
        }
    }

    #[test]
    fn test_decode_varint_too_long() {
        let too_long = [0x80u8; 11];
        assert!(decode_varint(&too_long).is_err());
    }

    // ---- round trips ----------------------------------------------------------------

    #[test]
    fn test_round_trip_small_values() {
        for value in [0u64, 1, 42, 127, 128, 255, 256, 1000] {
            let encoded = encode_varint(value);
            let (decoded, consumed) = decode_varint(&encoded).unwrap();
            assert_eq!(decoded, value, "Round-trip failed for {}", value);
            assert_eq!(
                consumed,
                encoded.len(),
                "Byte count mismatch for {}",
                value
            );
        }
    }

    #[test]
    fn test_round_trip_edge_values() {
        let edge_values = [
            0u64,
            1,
            127,
            128,
            16383,
            16384,
            2097151,
            2097152,
            u64::from(u16::MAX),
            u64::from(u32::MAX),
            u64::MAX,
        ];
        for value in edge_values {
            let encoded = encode_varint(value);
            let (decoded, consumed) = decode_varint(&encoded).unwrap();
            assert_eq!(decoded, value, "Round-trip failed for {}", value);
            assert_eq!(consumed, encoded.len());
        }
    }

    #[test]
    fn test_round_trip_u16() {
        for value in [0u16, 1, 100, u16::MAX / 2, u16::MAX - 1, u16::MAX] {
            let encoded = encode_varint_u16(value);
            let (decoded, consumed) = decode_varint_u16(&encoded).unwrap();
            assert_eq!(decoded, value);
            assert_eq!(consumed, encoded.len());
        }
    }

    #[test]
    fn test_round_trip_u32() {
        for value in [0u32, 1, 1000, u32::from(u16::MAX), u32::MAX / 2, u32::MAX] {
            let encoded = encode_varint_u32(value);
            let (decoded, consumed) = decode_varint_u32(&encoded).unwrap();
            assert_eq!(decoded, value);
            assert_eq!(consumed, encoded.len());
        }
    }

    // ---- size and layout ------------------------------------------------------------

    #[test]
    fn test_varint_size_calculation() {
        let expected_sizes = [
            (0u64, 1usize),
            (127, 1),
            (128, 2),
            (16383, 2),
            (16384, 3),
            (2097151, 3),
            (2097152, 4),
        ];
        for (value, size) in expected_sizes {
            assert_eq!(varint_size(value), size);
        }
        // The computed size must always agree with what the encoder actually produces.
        for value in [0u64, 1, 127, 128, 16383, 16384, u64::from(u32::MAX), u64::MAX] {
            assert_eq!(
                varint_size(value),
                encode_varint(value).len(),
                "varint_size mismatch for {}",
                value
            );
        }
    }

    #[test]
    fn test_varint_7_bit_chunks() {
        // 23 one-bits: three full 7-bit groups plus a partial fourth.
        let value: u64 = (1 << 23) - 1;
        let encoded = encode_varint(value);
        assert_eq!(encoded.len(), 4);
        let (first, last) = (encoded[0], encoded[3]);
        assert_eq!(first & !DATA_MASK, CONTINUATION_BIT);
        assert_eq!(first & DATA_MASK, 0b1111111);
        assert_eq!(last & CONTINUATION_BIT, 0);
    }

    #[test]
    fn test_varint_efficiency_for_small_values() {
        for value in [0u64, 10, 42, 100, 127] {
            assert_eq!(encode_varint(value).len(), 1);
        }
        for value in [128u64, 255, 256, 1000] {
            assert_eq!(encode_varint(value).len(), 2);
        }
    }
}