use crate::error::DecodeError;
use crate::limits::MAX_VARINT_BYTES;
use crate::model::Id;
#[derive(Debug, Clone)]
pub struct Reader<'a> {
data: &'a [u8],
pos: usize,
}
impl<'a> Reader<'a> {
pub fn new(data: &'a [u8]) -> Self {
Self { data, pos: 0 }
}
pub fn position(&self) -> usize {
self.pos
}
pub fn remaining(&self) -> &'a [u8] {
&self.data[self.pos..]
}
pub fn remaining_len(&self) -> usize {
self.data.len() - self.pos
}
pub fn is_empty(&self) -> bool {
self.pos >= self.data.len()
}
#[inline]
pub fn read_byte(&mut self, context: &'static str) -> Result<u8, DecodeError> {
if self.pos >= self.data.len() {
return Err(DecodeError::UnexpectedEof { context });
}
let byte = self.data[self.pos];
self.pos += 1;
Ok(byte)
}
#[inline]
pub fn read_bytes(&mut self, n: usize, context: &'static str) -> Result<&'a [u8], DecodeError> {
if self.pos + n > self.data.len() {
return Err(DecodeError::UnexpectedEof { context });
}
let bytes = &self.data[self.pos..self.pos + n];
self.pos += n;
Ok(bytes)
}
#[inline]
pub fn read_id(&mut self, context: &'static str) -> Result<Id, DecodeError> {
let bytes = self.read_bytes(16, context)?;
Ok(bytes.try_into().unwrap())
}
#[inline]
pub fn read_varint(&mut self, context: &'static str) -> Result<u64, DecodeError> {
let mut result: u64 = 0;
let mut shift = 0;
for i in 0..MAX_VARINT_BYTES {
let byte = self.read_byte(context)?;
let value = (byte & 0x7F) as u64;
if shift >= 64 || (shift == 63 && value > 1) {
return Err(DecodeError::VarintOverflow);
}
result |= value << shift;
if byte & 0x80 == 0 {
return Ok(result);
}
shift += 7;
if i == MAX_VARINT_BYTES - 1 {
return Err(DecodeError::VarintTooLong);
}
}
Err(DecodeError::VarintTooLong)
}
pub fn read_signed_varint(&mut self, context: &'static str) -> Result<i64, DecodeError> {
let unsigned = self.read_varint(context)?;
Ok(zigzag_decode(unsigned))
}
#[inline]
pub fn read_string(
&mut self,
max_len: usize,
field: &'static str,
) -> Result<String, DecodeError> {
let len = self.read_varint(field)? as usize;
if len > max_len {
return Err(DecodeError::LengthExceedsLimit {
field,
len,
max: max_len,
});
}
let bytes = self.read_bytes(len, field)?;
std::str::from_utf8(bytes)
.map(|s| s.to_string())
.map_err(|_| DecodeError::InvalidUtf8 { field })
}
#[inline]
pub fn read_str(
&mut self,
max_len: usize,
field: &'static str,
) -> Result<&'a str, DecodeError> {
let len = self.read_varint(field)? as usize;
if len > max_len {
return Err(DecodeError::LengthExceedsLimit {
field,
len,
max: max_len,
});
}
let bytes = self.read_bytes(len, field)?;
std::str::from_utf8(bytes).map_err(|_| DecodeError::InvalidUtf8 { field })
}
pub fn read_bytes_prefixed(
&mut self,
max_len: usize,
field: &'static str,
) -> Result<Vec<u8>, DecodeError> {
let len = self.read_varint(field)? as usize;
if len > max_len {
return Err(DecodeError::LengthExceedsLimit {
field,
len,
max: max_len,
});
}
let bytes = self.read_bytes(len, field)?;
Ok(bytes.to_vec())
}
#[inline]
pub fn read_f64(&mut self, context: &'static str) -> Result<f64, DecodeError> {
let bytes = self.read_bytes(8, context)?;
let value = f64::from_le_bytes(bytes.try_into().unwrap());
if value.is_nan() {
return Err(DecodeError::FloatIsNan);
}
Ok(value)
}
#[inline]
pub fn read_f64_unchecked(&mut self, context: &'static str) -> Result<f64, DecodeError> {
let bytes = self.read_bytes(8, context)?;
Ok(f64::from_le_bytes(bytes.try_into().unwrap()))
}
pub fn read_id_vec(
&mut self,
max_len: usize,
field: &'static str,
) -> Result<Vec<Id>, DecodeError> {
let count = self.read_varint(field)? as usize;
if count > max_len {
return Err(DecodeError::LengthExceedsLimit {
field,
len: count,
max: max_len,
});
}
let mut ids = Vec::with_capacity(count);
for _ in 0..count {
ids.push(self.read_id(field)?);
}
Ok(ids)
}
}
/// Growable output buffer that appends encoded primitive values: raw bytes,
/// ids, varints, length-prefixed data and little-endian floats.
#[derive(Debug, Clone, Default)]
pub struct Writer {
    /// Bytes written so far, in write order.
    buf: Vec<u8>,
}

impl Writer {
    /// Creates a writer with an empty buffer.
    pub fn new() -> Self {
        Writer { buf: Vec::new() }
    }

    /// Creates a writer whose buffer has room for `capacity` bytes reserved.
    pub fn with_capacity(capacity: usize) -> Self {
        let buf = Vec::with_capacity(capacity);
        Writer { buf }
    }

    /// Consumes the writer and returns the accumulated bytes.
    pub fn into_bytes(self) -> Vec<u8> {
        let Writer { buf } = self;
        buf
    }

    /// Borrows the bytes written so far.
    pub fn as_bytes(&self) -> &[u8] {
        self.buf.as_slice()
    }

    /// Returns the number of bytes written so far.
    pub fn len(&self) -> usize {
        self.buf.len()
    }

    /// Returns `true` if nothing has been written yet.
    pub fn is_empty(&self) -> bool {
        self.buf.is_empty()
    }

    /// Appends a single byte.
    #[inline]
    pub fn write_byte(&mut self, byte: u8) {
        self.buf.push(byte);
    }

    /// Appends `bytes` verbatim, with no length prefix.
    #[inline]
    pub fn write_bytes(&mut self, bytes: &[u8]) {
        self.buf.extend_from_slice(bytes);
    }

    /// Appends the raw bytes of `id`.
    #[inline]
    pub fn write_id(&mut self, id: &Id) {
        self.write_bytes(id);
    }

    /// Appends `value` as an unsigned varint: 7 payload bits per byte,
    /// least-significant group first, high bit set on every byte but the last.
    #[inline]
    pub fn write_varint(&mut self, mut value: u64) {
        // A u64 needs at most ceil(64 / 7) = 10 groups.
        let mut encoded = [0u8; 10];
        let mut used = 0;
        // Emit continuation-flagged groups while more than 7 bits remain.
        while value >= 0x80 {
            encoded[used] = (value & 0x7F) as u8 | 0x80;
            used += 1;
            value >>= 7;
        }
        // Final group, continuation bit clear.
        encoded[used] = value as u8;
        used += 1;
        self.buf.extend_from_slice(&encoded[..used]);
    }

    /// Appends `value` as a zigzag-encoded varint.
    pub fn write_signed_varint(&mut self, value: i64) {
        let unsigned = zigzag_encode(value);
        self.write_varint(unsigned);
    }

    /// Appends the byte length of `s` as a varint, followed by its UTF-8 bytes.
    pub fn write_string(&mut self, s: &str) {
        self.write_bytes_prefixed(s.as_bytes());
    }

    /// Appends the length of `bytes` as a varint, followed by the bytes.
    pub fn write_bytes_prefixed(&mut self, bytes: &[u8]) {
        let len = bytes.len() as u64;
        self.write_varint(len);
        self.write_bytes(bytes);
    }

    /// Appends `value` as 8 little-endian bytes.
    pub fn write_f64(&mut self, value: f64) {
        let raw = value.to_le_bytes();
        self.write_bytes(&raw);
    }

    /// Appends the number of ids as a varint, followed by each id in order.
    pub fn write_id_vec(&mut self, ids: &[Id]) {
        self.write_varint(ids.len() as u64);
        ids.iter().for_each(|id| self.write_id(id));
    }
}
/// Maps a signed integer onto an unsigned one so that values of small
/// magnitude get small encodings: 0 -> 0, -1 -> 1, 1 -> 2, -2 -> 3, ...
#[inline]
pub fn zigzag_encode(n: i64) -> u64 {
    // Magnitude bits move up by one; the vacated low bit will hold the sign.
    let doubled = (n as u64) << 1;
    // Arithmetic shift yields all ones for negative input, all zeros otherwise.
    let sign_mask = (n >> 63) as u64;
    doubled ^ sign_mask
}
/// Inverse of `zigzag_encode`: even inputs decode to non-negative values,
/// odd inputs to negative ones (0 -> 0, 1 -> -1, 2 -> 1, 3 -> -2, ...).
#[inline]
pub fn zigzag_decode(n: u64) -> i64 {
    let magnitude = (n >> 1) as i64;
    if n & 1 == 0 {
        magnitude
    } else {
        // XOR with an all-ones mask is a bitwise NOT.
        !magnitude
    }
}
// Unit tests for the reader/writer pair and the zigzag helpers.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_zigzag_roundtrip() {
        for v in [0i64, 1, -1, 127, -128, i64::MAX, i64::MIN] {
            assert_eq!(zigzag_decode(zigzag_encode(v)), v);
        }
    }

    #[test]
    fn test_zigzag_values() {
        assert_eq!(zigzag_encode(0), 0);
        assert_eq!(zigzag_encode(-1), 1);
        assert_eq!(zigzag_encode(1), 2);
        assert_eq!(zigzag_encode(-2), 3);
        assert_eq!(zigzag_encode(2), 4);
    }

    #[test]
    fn test_varint_roundtrip() {
        let test_values = [0u64, 1, 127, 128, 255, 256, 16383, 16384, u64::MAX];
        for v in test_values {
            let mut writer = Writer::new();
            writer.write_varint(v);
            let mut reader = Reader::new(writer.as_bytes());
            let decoded = reader.read_varint("test").unwrap();
            assert_eq!(v, decoded, "failed for {}", v);
            assert!(reader.is_empty());
        }
    }

    #[test]
    fn test_signed_varint_roundtrip() {
        let test_values = [0i64, 1, -1, 127, -128, i64::MAX, i64::MIN];
        for v in test_values {
            let mut writer = Writer::new();
            writer.write_signed_varint(v);
            let mut reader = Reader::new(writer.as_bytes());
            let decoded = reader.read_signed_varint("test").unwrap();
            assert_eq!(v, decoded, "failed for {}", v);
        }
    }

    #[test]
    fn test_string_roundtrip() {
        let test_strings = ["", "hello", "hello world", "unicode: \u{1F600}"];
        for s in test_strings {
            let mut writer = Writer::new();
            writer.write_string(s);
            let mut reader = Reader::new(writer.as_bytes());
            let decoded = reader.read_string(1000, "test").unwrap();
            assert_eq!(s, decoded);
        }
    }

    #[test]
    fn test_str_roundtrip() {
        let test_strings = ["", "hello", "unicode: \u{1F600}"];
        for s in test_strings {
            let mut writer = Writer::new();
            writer.write_string(s);
            let mut reader = Reader::new(writer.as_bytes());
            let decoded = reader.read_str(1000, "test").unwrap();
            assert_eq!(s, decoded);
            assert!(reader.is_empty());
        }
    }

    #[test]
    fn test_bytes_prefixed_roundtrip() {
        let payloads: [&[u8]; 3] = [&[], &[0], &[1, 2, 3, 0xFF]];
        for payload in payloads {
            let mut writer = Writer::new();
            writer.write_bytes_prefixed(payload);
            let mut reader = Reader::new(writer.as_bytes());
            let decoded = reader.read_bytes_prefixed(1000, "test").unwrap();
            assert_eq!(payload, decoded.as_slice());
        }
    }

    #[test]
    fn test_id_roundtrip() {
        let id = [1u8, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
        let mut writer = Writer::new();
        writer.write_id(&id);
        let mut reader = Reader::new(writer.as_bytes());
        let decoded = reader.read_id("test").unwrap();
        assert_eq!(id, decoded);
    }

    #[test]
    fn test_id_vec_roundtrip() {
        let ids = [[1u8; 16], [2u8; 16], [3u8; 16]];
        let mut writer = Writer::new();
        writer.write_id_vec(&ids);
        let mut reader = Reader::new(writer.as_bytes());
        let decoded = reader.read_id_vec(10, "test").unwrap();
        assert_eq!(ids.to_vec(), decoded);
        assert!(reader.is_empty());
    }

    #[test]
    fn test_f64_roundtrip() {
        let test_values = [
            0.0,
            1.0,
            -1.0,
            f64::INFINITY,
            f64::NEG_INFINITY,
            std::f64::consts::PI,
        ];
        for v in test_values {
            let mut writer = Writer::new();
            writer.write_f64(v);
            let mut reader = Reader::new(writer.as_bytes());
            let decoded = reader.read_f64("test").unwrap();
            assert_eq!(v, decoded, "failed for {}", v);
        }
    }

    #[test]
    fn test_f64_nan_rejected() {
        let mut writer = Writer::new();
        writer.write_f64(f64::NAN);
        let mut reader = Reader::new(writer.as_bytes());
        let result = reader.read_f64("test");
        assert!(matches!(result, Err(DecodeError::FloatIsNan)));
    }

    #[test]
    fn test_f64_unchecked_allows_nan() {
        let mut writer = Writer::new();
        writer.write_f64(f64::NAN);
        let mut reader = Reader::new(writer.as_bytes());
        let decoded = reader.read_f64_unchecked("test").unwrap();
        assert!(decoded.is_nan());
    }

    #[test]
    fn test_varint_too_long() {
        let data = [0x80u8; 11];
        let mut reader = Reader::new(&data);
        let result = reader.read_varint("test");
        assert!(matches!(result, Err(DecodeError::VarintTooLong)));
    }

    #[test]
    fn test_varint_overflow() {
        // Nine full groups plus a tenth whose payload needs more than the
        // single bit still available in a u64.
        let data = [0xFFu8, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x02];
        let mut reader = Reader::new(&data);
        let result = reader.read_varint("test");
        assert!(matches!(result, Err(DecodeError::VarintOverflow)));
    }

    #[test]
    fn test_string_too_long() {
        let mut writer = Writer::new();
        writer.write_varint(1000);
        writer.write_bytes(&[0u8; 1000]);
        let mut reader = Reader::new(writer.as_bytes());
        let result = reader.read_string(100, "test");
        assert!(matches!(
            result,
            Err(DecodeError::LengthExceedsLimit { max: 100, .. })
        ));
    }

    #[test]
    fn test_invalid_utf8() {
        let mut writer = Writer::new();
        writer.write_bytes_prefixed(&[0xFF, 0xFE]);
        let mut reader = Reader::new(writer.as_bytes());
        let result = reader.read_string(10, "test");
        assert!(matches!(result, Err(DecodeError::InvalidUtf8 { .. })));
    }

    #[test]
    fn test_unexpected_eof() {
        let data = [0u8; 5];
        let mut reader = Reader::new(&data);
        let result = reader.read_bytes(10, "test");
        assert!(matches!(result, Err(DecodeError::UnexpectedEof { .. })));
    }
}