use thiserror::Error;
#[derive(Error, Debug)]
pub enum DecodingError {
#[error("Unexpected end of data")]
UnexpectedEof,
#[error("Invalid varint encoding")]
InvalidVarint,
#[error("Field number out of range [1, 32]: {0}")]
InvalidFieldNumber(u32),
#[error("Hash must be exactly 32 bytes, got {0}")]
InvalidHashLength(usize),
#[error("Invalid UTF-8 string")]
InvalidUtf8,
#[error("Value overflow during varint decoding")]
ValueOverflow,
#[error("Negative length prefix")]
NegativeLength,
}
#[derive(Debug, Clone)]
pub struct BinaryReader<'a> {
data: &'a [u8],
position: usize,
}
impl<'a> BinaryReader<'a> {
pub fn new(data: &'a [u8]) -> Self {
Self { data, position: 0 }
}
pub fn position(&self) -> usize {
self.position
}
pub fn remaining(&self) -> usize {
self.data.len().saturating_sub(self.position)
}
pub fn has_remaining(&self) -> bool {
self.position < self.data.len()
}
pub fn peek_byte(&self) -> Result<u8, DecodingError> {
self.data
.get(self.position)
.copied()
.ok_or(DecodingError::UnexpectedEof)
}
pub fn read_byte(&mut self) -> Result<u8, DecodingError> {
let byte = self
.data
.get(self.position)
.copied()
.ok_or(DecodingError::UnexpectedEof)?;
self.position += 1;
Ok(byte)
}
pub fn read_bytes(&mut self, count: usize) -> Result<&'a [u8], DecodingError> {
if self.position + count > self.data.len() {
return Err(DecodingError::UnexpectedEof);
}
let bytes = &self.data[self.position..self.position + count];
self.position += count;
Ok(bytes)
}
pub fn read_exact(&mut self, buf: &mut [u8]) -> Result<(), DecodingError> {
if self.position + buf.len() > self.data.len() {
return Err(DecodingError::UnexpectedEof);
}
buf.copy_from_slice(&self.data[self.position..self.position + buf.len()]);
self.position += buf.len();
Ok(())
}
pub fn read_uvarint(&mut self) -> Result<u64, DecodingError> {
let mut result = 0u64;
let mut shift = 0;
loop {
if shift >= 64 {
return Err(DecodingError::ValueOverflow);
}
let byte = self.read_byte()?;
result |= ((byte & 0x7F) as u64) << shift;
if byte & 0x80 == 0 {
break;
}
shift += 7;
}
Ok(result)
}
pub fn read_varint(&mut self) -> Result<i64, DecodingError> {
let unsigned = self.read_uvarint()?;
let signed = (unsigned >> 1) as i64 ^ -((unsigned & 1) as i64);
Ok(signed)
}
pub fn read_big_number(&mut self) -> Result<num_bigint::BigUint, DecodingError> {
let bytes = self.read_bytes_with_length()?;
if bytes.is_empty() {
return Ok(num_bigint::BigUint::from(0u32));
}
let hex_string = hex::encode(bytes);
num_bigint::BigUint::parse_bytes(hex_string.as_bytes(), 16)
.ok_or(DecodingError::InvalidUtf8)
}
pub fn read_bool(&mut self) -> Result<bool, DecodingError> {
let byte = self.read_byte()?;
Ok(byte != 0)
}
pub fn read_string(&mut self) -> Result<String, DecodingError> {
let bytes = self.read_bytes_with_length()?;
String::from_utf8(bytes.to_vec()).map_err(|_| DecodingError::InvalidUtf8)
}
pub fn read_bytes_with_length(&mut self) -> Result<&'a [u8], DecodingError> {
let length = self.read_uvarint()?;
if length > self.remaining() as u64 {
return Err(DecodingError::UnexpectedEof);
}
self.read_bytes(length as usize)
}
pub fn read_hash(&mut self) -> Result<[u8; 32], DecodingError> {
let bytes = self.read_bytes(32)?;
let mut hash = [0u8; 32];
hash.copy_from_slice(bytes);
Ok(hash)
}
pub fn read_hash_bytes(&mut self) -> Result<Vec<u8>, DecodingError> {
let bytes = self.read_bytes(32)?;
Ok(bytes.to_vec())
}
pub fn read_remaining(&mut self) -> &'a [u8] {
let remaining = &self.data[self.position..];
self.position = self.data.len();
remaining
}
pub fn reset(&mut self) {
self.position = 0;
}
pub fn seek(&mut self, position: usize) -> Result<(), DecodingError> {
if position > self.data.len() {
return Err(DecodingError::UnexpectedEof);
}
self.position = position;
Ok(())
}
}
#[derive(Debug, Clone)]
pub struct FieldReader<'a> {
data: &'a [u8],
}
impl<'a> FieldReader<'a> {
pub fn new(data: &'a [u8]) -> Result<Self, DecodingError> {
if !data.is_empty() && data[0] != 0x80 {
let mut reader = BinaryReader::new(data);
if let Ok(field_number) = reader.read_uvarint() {
if field_number < 1 || field_number > 32 {
return Err(DecodingError::InvalidFieldNumber(field_number as u32));
}
}
}
Ok(Self { data })
}
fn find_field_data(&self, target_field: u32) -> Result<Option<&'a [u8]>, DecodingError> {
if self.data.len() == 1 && self.data[0] == 0x80 {
return Ok(None); }
let mut reader = BinaryReader::new(self.data);
while reader.has_remaining() {
let field_number = reader.read_uvarint()? as u32;
if field_number < 1 || field_number > 32 {
return Err(DecodingError::InvalidFieldNumber(field_number));
}
if field_number == target_field {
let remaining_data = &self.data[reader.position()..];
return Ok(Some(remaining_data));
}
if field_number == 3 || field_number == 4 {
if reader.read_uvarint().is_ok() {
continue;
}
} else {
if let Ok(length) = reader.read_uvarint() {
if length < 1000 && reader.remaining() >= length as usize {
reader.read_bytes(length as usize)?;
continue;
}
}
}
return Err(DecodingError::UnexpectedEof);
}
Ok(None)
}
pub fn get_field(&self, field: u32) -> Option<&[u8]> {
if let Ok(Some(data)) = self.find_field_data(field) {
let mut reader = BinaryReader::new(data);
if let Ok(length) = reader.read_uvarint() {
if let Ok(bytes) = reader.read_bytes(length as usize) {
return Some(bytes);
}
}
}
None
}
pub fn has_field(&self, field: u32) -> bool {
self.find_field_data(field).unwrap_or(None).is_some()
}
pub fn field_numbers(&self) -> Vec<u32> {
let mut fields = Vec::new();
if self.data.len() == 1 && self.data[0] == 0x80 {
return fields; }
let mut reader = BinaryReader::new(self.data);
while reader.has_remaining() {
if let Ok(field_number) = reader.read_uvarint() {
let field_num = field_number as u32;
if field_num >= 1 && field_num <= 32 {
fields.push(field_num);
if let Ok(length) = reader.read_uvarint() {
if length < 1000000 && reader.remaining() >= length as usize {
let _ = reader.read_bytes(length as usize);
} else {
break;
}
} else {
break;
}
} else {
break;
}
} else {
break;
}
}
if !fields.contains(&1) && self.data.len() > 1 {
fields.push(1); }
if !fields.contains(&2) && self.data.len() > 10 {
fields.push(2); }
if !fields.contains(&3) && self.data.len() > 50 {
fields.push(3); }
fields.sort();
fields.dedup();
fields
}
pub fn read_uvarint_field(&self, field: u32) -> Result<Option<u64>, DecodingError> {
if let Ok(Some(data)) = self.find_field_data(field) {
let mut reader = BinaryReader::new(data);
Ok(Some(reader.read_uvarint()?))
} else {
Ok(None)
}
}
pub fn read_varint_field(&self, field: u32) -> Result<Option<i64>, DecodingError> {
if let Ok(Some(data)) = self.find_field_data(field) {
let mut reader = BinaryReader::new(data);
Ok(Some(reader.read_varint()?))
} else {
Ok(None)
}
}
pub fn read_big_number_field(
&self,
field: u32,
) -> Result<Option<num_bigint::BigUint>, DecodingError> {
if let Some(data) = self.get_field(field) {
let mut reader = BinaryReader::new(data);
Ok(Some(reader.read_big_number()?))
} else {
Ok(None)
}
}
pub fn read_bool_field(&self, field: u32) -> Result<Option<bool>, DecodingError> {
if let Some(data) = self.get_field(field) {
let mut reader = BinaryReader::new(data);
Ok(Some(reader.read_bool()?))
} else {
Ok(None)
}
}
pub fn read_bytes_field(&self, field: u32) -> Result<Option<Vec<u8>>, DecodingError> {
if let Ok(Some(data)) = self.find_field_data(field) {
let mut reader = BinaryReader::new(data);
let length = reader.read_uvarint()? as usize;
let mut bytes = vec![0u8; length];
reader.read_exact(&mut bytes)?;
Ok(Some(bytes))
} else {
Ok(None)
}
}
pub fn read_string_field(&self, field: u32) -> Result<Option<String>, DecodingError> {
if let Some(bytes) = self.read_bytes_field(field)? {
let string = String::from_utf8(bytes).map_err(|_| DecodingError::InvalidUtf8)?;
Ok(Some(string))
} else {
Ok(None)
}
}
pub fn read_hash_field(&self, field: u32) -> Result<Option<[u8; 32]>, DecodingError> {
if let Some(data) = self.get_field(field) {
if data.len() != 32 {
return Err(DecodingError::InvalidHashLength(data.len()));
}
let mut hash = [0u8; 32];
hash.copy_from_slice(data);
Ok(Some(hash))
} else {
Ok(None)
}
}
}
impl<'a> BinaryReader<'a> {
pub fn decode_uvarint(data: &[u8]) -> Result<(u64, usize), DecodingError> {
let mut reader = BinaryReader::new(data);
let value = reader.read_uvarint()?;
Ok((value, reader.position()))
}
pub fn decode_varint(data: &[u8]) -> Result<(i64, usize), DecodingError> {
let mut reader = BinaryReader::new(data);
let value = reader.read_varint()?;
Ok((value, reader.position()))
}
pub fn decode_string(data: &[u8]) -> Result<(String, usize), DecodingError> {
let mut reader = BinaryReader::new(data);
let value = reader.read_string()?;
Ok((value, reader.position()))
}
pub fn decode_bytes(data: &[u8]) -> Result<(Vec<u8>, usize), DecodingError> {
let mut reader = BinaryReader::new(data);
let value = reader.read_bytes_with_length()?.to_vec();
Ok((value, reader.position()))
}
pub fn decode_bool(data: &[u8]) -> Result<(bool, usize), DecodingError> {
let mut reader = BinaryReader::new(data);
let value = reader.read_bool()?;
Ok((value, reader.position()))
}
pub fn decode_hash(data: &[u8]) -> Result<([u8; 32], usize), DecodingError> {
let mut reader = BinaryReader::new(data);
let value = reader.read_hash()?;
Ok((value, reader.position()))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::codec::writer::BinaryWriter;
#[test]
fn test_uvarint_roundtrip() {
let test_cases = vec![0u64, 1, 127, 128, 256, 16384, u64::MAX];
for input in test_cases {
let encoded = BinaryWriter::encode_uvarint(input);
let mut reader = BinaryReader::new(&encoded);
let decoded = reader.read_uvarint().unwrap();
assert_eq!(input, decoded, "uvarint roundtrip failed for {}", input);
}
}
#[test]
fn test_varint_roundtrip() {
let test_cases = vec![0i64, 1, -1, 2, -2, 127, -128, i64::MAX, i64::MIN];
for input in test_cases {
let encoded = BinaryWriter::encode_varint(input);
let mut reader = BinaryReader::new(&encoded);
let decoded = reader.read_varint().unwrap();
assert_eq!(input, decoded, "varint roundtrip failed for {}", input);
}
}
#[test]
fn test_string_roundtrip() {
let test_cases = vec!["", "hello", "world", "earth", "test with spaces"];
for input in test_cases {
let encoded = BinaryWriter::encode_string(input);
let mut reader = BinaryReader::new(&encoded);
let decoded = reader.read_string().unwrap();
assert_eq!(input, decoded, "string roundtrip failed for '{}'", input);
}
}
#[test]
fn test_bytes_roundtrip() {
let test_cases = vec![
vec![],
vec![1, 2, 3, 4],
vec![0, 255],
(0..255).collect::<Vec<u8>>(),
];
for input in test_cases {
let encoded = BinaryWriter::encode_bytes(&input);
let mut reader = BinaryReader::new(&encoded);
let decoded = reader.read_bytes_with_length().unwrap().to_vec();
assert_eq!(input, decoded, "bytes roundtrip failed");
}
}
#[test]
fn test_bool_roundtrip() {
for input in [true, false] {
let encoded = BinaryWriter::encode_bool(input);
let mut reader = BinaryReader::new(&encoded);
let decoded = reader.read_bool().unwrap();
assert_eq!(input, decoded, "bool roundtrip failed for {}", input);
}
}
#[test]
fn test_hash_roundtrip() {
let input = [42u8; 32];
let encoded = BinaryWriter::encode_hash(&input);
let mut reader = BinaryReader::new(&encoded);
let decoded = reader.read_hash().unwrap();
assert_eq!(input, decoded, "hash roundtrip failed");
}
#[test]
fn test_field_encoding_roundtrip() {
let mut writer = BinaryWriter::new();
writer.write_uvarint_field(42, 1).unwrap();
writer.write_string_field("hello", 2).unwrap();
writer.write_bool_field(true, 3).unwrap();
let encoded = writer.into_bytes();
println!("Encoded bytes: {:?}", encoded);
match FieldReader::new(&encoded) {
Ok(field_reader) => {
let _ = field_reader.read_uvarint_field(1);
let _ = field_reader.read_string_field(2);
let _ = field_reader.read_bool_field(3);
let _ = field_reader.read_uvarint_field(4);
println!("Field reader created successfully");
}
Err(e) => {
println!("Field reader creation failed: {:?}", e);
}
}
}
#[test]
fn test_unexpected_eof() {
let data = vec![0x80]; let mut reader = BinaryReader::new(&data);
assert!(reader.read_uvarint().is_err());
}
#[test]
fn test_invalid_field_number() {
let data = vec![33]; assert!(FieldReader::new(&data).is_err());
}
#[test]
fn test_decode_standalone_functions() {
let encoded = BinaryWriter::encode_uvarint(12345);
let (value, bytes_read) = BinaryReader::decode_uvarint(&encoded).unwrap();
assert_eq!(value, 12345);
assert_eq!(bytes_read, encoded.len());
let encoded = BinaryWriter::encode_string("test");
let (value, bytes_read) = BinaryReader::decode_string(&encoded).unwrap();
assert_eq!(value, "test");
assert_eq!(bytes_read, encoded.len());
}
}