use crate::types::Token;
use thiserror::Error;
/// Errors reported while decoding an encoded token byte stream.
#[derive(Debug, Error)]
pub enum DecodeError {
/// The input ended before a token's encoding was complete. The payload is the
/// byte offset, within the whole stream, at which that token's encoding starts.
#[error("truncated varint stream at byte {0}")]
Truncated(usize),
/// A decoded value was too large for a `u32`; the payload is the offending value.
/// NOTE(review): nothing visible in this file constructs this variant — confirm
/// whether another module still relies on it.
#[error("decoded token value {0} exceeds u32::MAX")]
Overflow(u64),
}
/// Encodes every token in `tokens` and returns the concatenated bytes.
///
/// Tokens are written in input order, each one through `encode_one`. The output
/// buffer is pre-sized at two bytes per token as an initial guess and grows as
/// needed.
pub fn encode_tokens(tokens: &[Token]) -> Vec<u8> {
    let initial_capacity = tokens.len() * 2;
    tokens
        .iter()
        .copied()
        .fold(Vec::with_capacity(initial_capacity), |mut bytes, token| {
            encode_one(token, &mut bytes);
            bytes
        })
}
/// Decodes a contiguous sequence of encoded tokens.
///
/// Tokens are read back to back from the start of `data` until every byte has
/// been consumed. An empty input yields an empty vector.
///
/// # Errors
///
/// Returns the error reported by `decode_one` for the first token that fails to
/// decode; tokens decoded before it are discarded.
pub fn decode_tokens(data: &[u8]) -> Result<Vec<Token>, DecodeError> {
    let mut decoded = Vec::new();
    let mut offset = 0;
    // `get(offset..)` is the not-yet-decoded tail; the loop ends once it is empty.
    while let Some(rest) = data.get(offset..).filter(|tail| !tail.is_empty()) {
        let (token, width) = decode_one(rest, offset)?;
        decoded.push(token);
        offset += width;
    }
    Ok(decoded)
}
/// Largest token value that `encode_one` writes in the 4-byte form.
///
/// The 4-byte lead byte is `0b1110_0000 | (tok >> 24)`. If the five payload bits
/// it carries were all ones, the lead byte would be `0xFF`, which is reserved as
/// the 5-byte escape marker and would make the stream undecodable. The 4-byte
/// range therefore stops just below `0x1F << 24`; larger values use the 5-byte form.
const MAX_4BYTE: u32 = (0x1F << 24) - 1;

/// Lead byte announcing the 5-byte form: the marker followed by the token as a
/// big-endian `u32`.
const ESCAPE_MARKER: u8 = 0xFF;

/// Appends the variable-width encoding of `tok` to `out`.
///
/// | token range               | bytes | layout                              |
/// |---------------------------|-------|-------------------------------------|
/// | `0..=0x7F`                | 1     | `0xxxxxxx`                          |
/// | `0x80..=0x3FFF`           | 2     | `10xxxxxx` + 1 byte                 |
/// | `0x4000..=0x1F_FFFF`      | 3     | `110xxxxx` + 2 bytes                |
/// | `0x20_0000..=MAX_4BYTE`   | 4     | `111xxxxx` + 3 bytes, lead ≠ `0xFF` |
/// | anything larger           | 5     | `0xFF` + 4 bytes                    |
///
/// All multi-byte payloads are big-endian.
fn encode_one(tok: Token, out: &mut Vec<u8>) {
    let [b3, b2, b1, b0] = tok.to_be_bytes();
    match tok {
        0..=0x7F => out.push(b0),
        // Each range bound guarantees the most significant used byte fits in the
        // payload bits left free by the prefix, so a plain OR is enough.
        0x80..=0x3FFF => out.extend_from_slice(&[0b1000_0000 | b1, b0]),
        0x4000..=0x1F_FFFF => out.extend_from_slice(&[0b1100_0000 | b2, b1, b0]),
        0x20_0000..=MAX_4BYTE => out.extend_from_slice(&[0b1110_0000 | b3, b2, b1, b0]),
        _ => {
            out.push(ESCAPE_MARKER);
            out.extend_from_slice(&[b3, b2, b1, b0]);
        }
    }
}

/// Decodes the single token that starts at `data[0]`.
///
/// `base_offset` is the position of `data[0]` within the caller's stream; it is
/// used only to label errors.
///
/// Returns the decoded token and the number of bytes its encoding occupied.
///
/// # Errors
///
/// Returns `DecodeError::Truncated(base_offset)` when `data` is empty or shorter
/// than the width announced by its lead byte.
fn decode_one(data: &[u8], base_offset: usize) -> Result<(Token, usize), DecodeError> {
    let first = *data
        .first()
        .ok_or(DecodeError::Truncated(base_offset))?;

    // The lead byte fixes the total width and contributes the top payload bits.
    let (width, lead_payload) = match first {
        0x00..=0x7F => (1, first),
        0x80..=0xBF => (2, first & 0x3F),
        0xC0..=0xDF => (3, first & 0x1F),
        ESCAPE_MARKER => (5, 0),
        // Remaining lead bytes are 0xE0..=0xFE: the 4-byte form.
        _ => (4, first & 0x1F),
    };

    let tail = data
        .get(1..width)
        .ok_or(DecodeError::Truncated(base_offset))?;

    // Big-endian accumulation; at most 32 payload bits are ever shifted in.
    let value = tail
        .iter()
        .fold(u32::from(lead_payload), |acc, &byte| (acc << 8) | u32::from(byte));

    Ok((value, width))
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `tokens`, decodes the result, and asserts the output equals the input.
    fn roundtrip(tokens: &[Token]) {
        let encoded = encode_tokens(tokens);
        let decoded = decode_tokens(&encoded).expect("decode failed");
        assert_eq!(tokens, decoded.as_slice(), "roundtrip mismatch for {tokens:?}");
    }

    #[test]
    fn boundary_values() {
        let boundaries = [
            0,
            1,
            127,
            128,
            16_383,
            16_384,
            2_097_151,
            2_097_152,
            MAX_4BYTE,
            MAX_4BYTE + 1,
            536_870_911,
            536_870_912,
            u32::MAX,
        ];
        for tok in boundaries {
            roundtrip(&[tok]);
        }
    }

    #[test]
    fn empty_sequence() {
        roundtrip(&[]);
    }

    #[test]
    fn mixed_sequence() {
        let tokens: Vec<Token> = vec![0, 64, 127, 128, 1000, 16_383, 16_384, 100_000, u32::MAX];
        roundtrip(&tokens);
    }

    #[test]
    fn long_sequence() {
        // Multiplicative scrambling plus a varying right shift spreads the values
        // over every magnitude, so all five encoded widths are exercised.
        let tokens: Vec<Token> = (0..10_000u32)
            .map(|i| i.wrapping_mul(2_654_435_761) >> (i % 32))
            .collect();
        roundtrip(&tokens);
    }

    #[test]
    fn encoding_sizes() {
        let cases: &[(Token, usize)] = &[
            (0, 1),
            (127, 1),
            (128, 2),
            (16_383, 2),
            (16_384, 3),
            (2_097_151, 3),
            (2_097_152, 4),
            (MAX_4BYTE, 4),
            (MAX_4BYTE + 1, 5),
            (536_870_911, 5),
            (536_870_912, 5),
            (u32::MAX, 5),
        ];
        for &(tok, expected) in cases {
            assert_eq!(encode_tokens(&[tok]).len(), expected, "encoded size of {tok}");
        }
    }

    #[test]
    fn four_byte_lead_never_collides_with_escape_marker() {
        // The largest 4-byte value must still have a lead byte below the marker.
        assert_eq!(encode_tokens(&[MAX_4BYTE]), [0xFE, 0xFF, 0xFF, 0xFF]);

        // Values whose top five payload bits are all ones would produce a 0xFF
        // lead byte in the 4-byte form, so they must take the 5-byte form.
        for tok in [0x1F00_0000, 0x1F00_0001, 0x1FFF_FFFF] {
            let encoded = encode_tokens(&[tok]);
            assert_eq!(encoded.len(), 5);
            assert_eq!(encoded[0], ESCAPE_MARKER);
            // Neighbouring tokens must stay in sync as well.
            roundtrip(&[tok, 7, tok]);
        }
    }

    #[test]
    fn truncated_stream_errors() {
        let partial = vec![0b1000_0000u8];
        assert!(matches!(
            decode_tokens(&partial),
            Err(DecodeError::Truncated(_))
        ));
    }

    #[test]
    fn truncation_is_reported_for_every_width() {
        for tok in [128, 16_384, 2_097_152, u32::MAX] {
            let mut bytes = encode_tokens(&[5, tok]);
            bytes.pop();
            // The 1-byte token at offset 0 is intact; the cut token starts at offset 1.
            assert!(
                matches!(decode_tokens(&bytes), Err(DecodeError::Truncated(1))),
                "token {} with its last byte removed",
                tok
            );
        }
    }

    #[test]
    fn high_freq_tokens_are_small() {
        let top_tokens: Vec<Token> = (0u32..128).collect();
        let encoded = encode_tokens(&top_tokens);
        assert_eq!(encoded.len(), 128, "each token 0–127 should cost 1 byte");
    }
}