/// Mask for the seven low-order bits of a byte (`0b0111_1111`).
const LSB_7: u8 = u8::MAX >> 1;
/// Mask for the most significant bit of a byte (`0b1000_0000`), the complement of [`LSB_7`].
const MSB: u8 = 1 << (u8::BITS - 1);
#[inline]
const fn msb_is_1(val: u8) -> bool {
val & MSB == MSB
}
/// A varint held inline: a fixed buffer of [`Self::MAX_LEN`] bytes plus the count of
/// leading bytes that are actually part of the encoding.
///
/// The derived equality and hash cover the whole buffer and the length, not just the
/// bytes returned by [`Self::as_slice`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub(crate) struct VarintEncoding {
    /// Backing storage; only the first `len` bytes are meaningful.
    buf: [u8; Self::MAX_LEN],
    /// Number of leading bytes of `buf` that make up the encoding.
    len: u8,
}

impl VarintEncoding {
    /// Size of the inline buffer in bytes.
    pub const MAX_LEN: usize = 3;

    /// Returns the first `len` bytes of the buffer.
    ///
    /// # Panics
    ///
    /// Panics if `len` is greater than [`Self::MAX_LEN`].
    #[allow(dead_code)]
    pub const fn as_slice(&self) -> &[u8] {
        let (encoded, _unused) = self.buf.split_at(self.len as usize);
        encoded
    }
}
/// Returns the number of bits needed to represent `value`: the 1-based position of its
/// highest set bit, or `0` when `value` is zero.
const fn bitlength(value: u16) -> u8 {
    // Width of the type minus the unused high bits; zero has 16 leading zeros, giving 0.
    (u16::BITS - value.leading_zeros()) as u8
}
#[cfg(test)]
mod bitlength_test {
use super::bitlength;
/// Spot-checks `bitlength` on the smallest inputs, including zero.
#[test]
fn test_bitlength() {
    // (input, expected bit length)
    let cases: [(u16, u8); 5] = [
        (0b000, 0),
        (0b001, 1),
        (0b010, 2),
        (0b011, 2),
        (0b100, 3),
    ];
    for (value, expected) in cases {
        assert_eq!(bitlength(value), expected);
    }
}
}
/// Encodes `value` as a base-128 varint, least-significant group first.
///
/// Each output byte carries seven bits of `value` in its low bits; the high bit is set on
/// every byte except the last. Zero is encoded as the single byte `0x00`. Bytes of the
/// returned buffer beyond the encoded length are left as zero.
#[allow(dead_code)]
pub(crate) const fn encode_varint(value: u16) -> VarintEncoding {
    // High bit of an output byte: signals that another byte follows.
    const CONTINUATION: u8 = !LSB_7;

    let mut buf = [0u8; VarintEncoding::MAX_LEN];
    let mut len = 0usize;
    let mut remaining = value;

    loop {
        // Peel off the next seven bits, lowest group first.
        let payload = (remaining & (LSB_7 as u16)) as u8;
        remaining >>= 7;

        if remaining == 0 {
            // Final group: written without the continuation bit.
            buf[len] = payload;
            len += 1;
            break;
        }

        buf[len] = payload | CONTINUATION;
        len += 1;
    }

    VarintEncoding {
        buf,
        len: len as u8,
    }
}
/// Decodes one varint from the front of `encoded`.
///
/// Bytes are consumed in order. The low seven bits of each byte are placed into the
/// result at bit offsets 0, 7 and 14, and a byte whose high bit is clear ends the varint.
/// On success, returns the decoded value together with the unconsumed rest of `encoded`.
///
/// Padded encodings are not rejected: `[0x80, 0x00]` decodes to `0` just like `[0x00]`.
///
/// # Errors
///
/// * [`DecodeError::MissingBytes`] if `encoded` is empty, or ends while the last byte
///   read still has its high bit set.
/// * [`DecodeError::WouldOverflow`] if a byte's payload would occupy bits above bit 15,
///   or if the third byte still has its high bit set.
pub(crate) const fn decode_varint(encoded: &[u8]) -> Result<(u16, &[u8]), DecodeError> {
    let mut value: u16 = 0;
    // Bit offset at which the next payload is placed.
    let mut shift: u32 = 0;
    // Number of input bytes read so far.
    let mut consumed: usize = 0;

    while shift < u16::BITS {
        // Also covers the empty-input case on the first iteration.
        if consumed >= encoded.len() {
            return Err(DecodeError::MissingBytes);
        }
        let byte = encoded[consumed];
        consumed += 1;

        let payload = byte & LSB_7;
        // Only the bits that are actually set matter: a small payload may still fit at
        // offset 14 even though a full seven-bit group would not.
        let payload_width = bitlength(payload as u16) as u32;
        if shift + payload_width > u16::BITS {
            return Err(DecodeError::WouldOverflow);
        }

        value |= (payload as u16) << shift;
        shift += 7;

        if !msb_is_1(byte) {
            let (_varint, rest) = encoded.split_at(consumed);
            return Ok((value, rest));
        }
    }

    // Every offset inside a u16 was used and the last byte still asked for more.
    Err(DecodeError::WouldOverflow)
}
/// Reasons `decode_varint` can reject its input.
#[derive(thiserror::Error, Debug, Eq, PartialEq)]
pub enum DecodeError {
/// The input was empty, or ended before a byte with a clear high bit (the final byte
/// of a varint) was reached.
#[error("expected more bytes than what were provided")]
MissingBytes,
/// The encoded value does not fit in a `u16`: a payload would set bits above bit 15,
/// or the varint continues past the last byte that could contribute to a `u16`.
#[error(
"the decoded number is too large to fit into the type without overflowing"
)]
WouldOverflow,
}
#[cfg(test)]
mod test {
use std::collections::HashMap;
use itertools::Itertools;
use super::*;
/// Asserts that `decoded` survives encode-then-decode, both when the encoding is the
/// entire input and when unrelated bytes follow it (those must come back as the tail).
fn test_roundtrip(decoded: u16) {
    const TRAILER: [u8; 4] = [1, 2, 3, 4];
    let no_tail: &[u8] = &[];

    let encoding = encode_varint(decoded);
    let exact = encoding.as_slice();
    assert_eq!(Ok((decoded, no_tail)), decode_varint(exact));

    let padded = [exact, TRAILER.as_slice()].concat();
    assert_eq!(Ok((decoded, TRAILER.as_slice())), decode_varint(&padded))
}
/// Checks fixed value/encoding pairs in both directions, with and without trailing
/// bytes after the varint.
#[test]
fn test_known_examples() {
    // (value, expected encoding)
    let examples: [(u16, &[u8]); 11] = [
        (0x01, &[0x01]),
        (0x02, &[0x02]),
        (0x7f, &[0x7f]),
        (0x80, &[0x80, 0x01]),
        (0x81, &[0x81, 0x01]),
        (0xff, &[0xff, 0x01]),
        (0x012c, &[0xac, 0x02]),
        (0xed, &[0xed, 0x01]),
        (0xec, &[0xec, 0x01]),
        (0x1200, &[0x80, 0x24]),
        (0xe7, &[0xe7, 0x01]),
    ];
    const TRAILER: [u8; 4] = [1, 2, 3, 4];
    let no_tail: &[u8] = &[];

    for (decoded, encoded) in examples {
        assert_eq!(encoded, encode_varint(decoded).as_slice());
        assert_eq!(Ok((decoded, no_tail)), decode_varint(encoded));

        let extended_encoded = [encoded, TRAILER.as_slice()].concat();
        assert_eq!(
            Ok((decoded, TRAILER.as_slice())),
            decode_varint(&extended_encoded)
        );

        test_roundtrip(decoded);
    }
}
/// Round-trips every `u16` value through the encoder and decoder.
#[test]
fn test_all_u16_roundtrip() {
    (0..=u16::MAX).for_each(test_roundtrip);
}
/// Decodes short hand-written inputs: empty input, a lone terminating byte with various
/// followers, and bytes that demand a continuation that never arrives.
#[test]
fn test_decode_boundary_conditions() {
    // (input, expected decode result)
    let examples: [(&[u8], Result<(u16, &[u8]), DecodeError>); 10] = [
        // Nothing to read.
        (&[], Err(DecodeError::MissingBytes)),
        // A zero byte terminates immediately; whatever follows is returned as the tail.
        (&[0], Ok((0, &[]))),
        (&[0, 0], Ok((0, &[0]))),
        (&[0, 1], Ok((0, &[1]))),
        (&[0, u8::MAX], Ok((0, &[u8::MAX]))),
        // High bit set on the only byte: a continuation is required but absent.
        (&[u8::MAX], Err(DecodeError::MissingBytes)),
        (&[1 << 7], Err(DecodeError::MissingBytes)),
        // Largest single-byte value.
        (&[LSB_7], Ok((LSB_7 as u16, &[]))),
        (&[LSB_7, 0], Ok((LSB_7 as u16, &[0]))),
        (&[LSB_7, u8::MAX], Ok((LSB_7 as u16, &[u8::MAX]))),
    ];

    for (encoded, expected) in examples {
        assert_eq!(decode_varint(encoded), expected, "decoded from {encoded:?}");
    }
}
/// Builds a lookup from the encoder's output bytes to the value that produced them,
/// covering every `u16`.
fn generate_all_valid_varints() -> HashMap<Vec<u8>, u16> {
    let mut by_encoding = HashMap::with_capacity(u16::MAX as usize);
    by_encoding.extend(
        (0..=u16::MAX).map(|value| (encode_varint(value).as_slice().to_vec(), value)),
    );
    by_encoding
}
/// Bytes that `generate_all_byte_patterns` appends as a fourth byte to each three-byte
/// pattern: the ends of the byte range (0, 1, 254, 255) and the values on either side of
/// the seven-bit boundary (126, 127, 128).
const SPECIAL_BYTES: &[u8] =
&[0, 1, u8::MAX, u8::MAX - 1, LSB_7, LSB_7 + 1, LSB_7 - 1];
/// Produces the candidate inputs for `test_all_varints`, in this order: the 1-, 2- and
/// 3-byte selections, the empty input, then every 3-byte selection extended by one of
/// `SPECIAL_BYTES`.
///
/// NOTE(review): despite the name, this is not every byte string of those lengths.
/// `combinations_with_replacement` yields each multiset once, so with an ascending source
/// range the bytes of a selection never decrease from left to right. Inputs such as
/// `[0x80, 0x00]` are therefore never generated — confirm whether that is intended.
fn generate_all_byte_patterns() -> impl Iterator<Item = Vec<u8>> {
    const LEN: usize = 3;

    let short_patterns =
        (1..=LEN).flat_map(|len| (0..=u8::MAX).combinations_with_replacement(len));

    let with_special_tail = (0..=u8::MAX)
        .combinations_with_replacement(LEN)
        .cartesian_product(SPECIAL_BYTES)
        .map(|(mut bytes, special)| {
            bytes.push(*special);
            bytes
        });

    short_patterns.chain([vec![]]).chain(with_special_tail)
}
/// Runs the decoder over every generated byte pattern and cross-checks it against the
/// encoder: a successful decode must have consumed exactly the encoder's output for that
/// value, and a failed decode must not be an input the encoder can produce.
#[test]
fn test_all_varints() {
    let valid_varints = generate_all_valid_varints();

    for bytes in generate_all_byte_patterns() {
        let (decoded, tail_bytes) = match decode_varint(&bytes) {
            Ok(success) => success,
            Err(_) => {
                assert!(!valid_varints.contains_key(&bytes));
                continue;
            }
        };

        // Everything the decoder consumed, i.e. the input minus the returned tail.
        let head_bytes = &bytes[..bytes.len() - tail_bytes.len()];
        assert_eq!(
            head_bytes,
            encode_varint(decoded).as_slice(),
            "expected head bytes {head_bytes:?} of bytes {bytes:?} to match encoding of decoded value {decoded}"
        );

        let Some(&expected_decoded) = valid_varints.get(head_bytes) else {
            panic!("expected all successful decodes to have a matching entry in the valid varint set. Offending bytes: {bytes:?}");
        };
        assert_eq!(
            decoded, expected_decoded,
            "expected successful decode of {bytes:?} to match expected decode value"
        );
    }
}
}