#![allow(clippy::unwrap_used)]
use thiserror::Error;
/// Errors reported by the binary encoding routines in this file.
///
/// `Display` messages are generated by `thiserror` from the `#[error(...)]`
/// attribute on each variant.
#[derive(Error, Debug)]
pub enum EncodingError {
/// An underlying I/O operation failed. Any `std::io::Error` converts into
/// this variant automatically through the generated `From` impl.
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
/// A field number was outside the accepted range `[1, 32]`; the payload is
/// the rejected number.
#[error("Field number out of range [1, 32]: {0}")]
InvalidFieldNumber(u32),
/// A value was larger than the maximum safe integer.
// NOTE(review): the exact limit is not defined in the code visible here — confirm with callers.
#[error("Value exceeds maximum safe integer")]
ValueTooLarge,
/// A hash was not exactly 32 bytes long; the payload is the length that
/// was actually supplied.
#[error("Hash must be exactly 32 bytes, got {0}")]
InvalidHashLength(usize),
/// A negative big integer was passed where only non-negative values can be
/// marshalled.
#[error("Cannot marshal negative bigint")]
NegativeBigInt,
/// String data was not valid UTF-8.
#[error("Invalid UTF-8 string")]
InvalidUtf8,
}
/// Accumulates encoded output in a growable, in-memory byte buffer.
///
/// Cloning the writer clones the buffered bytes.
#[derive(Debug, Clone)]
pub struct BinaryWriter {
/// Bytes written so far, in the order they were written.
buffer: Vec<u8>,
}
impl BinaryWriter {
    /// Creates an empty writer without reserving any capacity.
    pub fn new() -> Self {
        Self { buffer: Vec::new() }
    }

    /// Creates an empty writer whose buffer is pre-allocated with
    /// `Vec::with_capacity(capacity)`.
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            buffer: Vec::with_capacity(capacity),
        }
    }

    /// Consumes the writer and returns every byte written so far.
    pub fn into_bytes(self) -> Vec<u8> {
        self.buffer
    }

    /// Borrows the bytes written so far.
    pub fn bytes(&self) -> &[u8] {
        &self.buffer
    }

    /// Discards all bytes written so far.
    pub fn clear(&mut self) {
        self.buffer.clear();
    }

    /// Appends `bytes` verbatim, with no length prefix.
    ///
    /// Always returns `Ok(())`; the `Result` keeps the signature uniform with
    /// the other write methods.
    pub fn write_bytes(&mut self, bytes: &[u8]) -> Result<(), EncodingError> {
        self.buffer.extend_from_slice(bytes);
        Ok(())
    }

    /// Validates `field` and appends it as an unsigned varint tag.
    ///
    /// Nothing is written when the field number is rejected, so callers can
    /// emit the tag first and then encode the payload straight into `self`.
    fn write_field_tag(&mut self, field: u32) -> Result<(), EncodingError> {
        if !(1..=32).contains(&field) {
            return Err(EncodingError::InvalidFieldNumber(field));
        }
        self.write_uvarint(u64::from(field))
    }

    /// Appends the field number (as an unsigned varint) followed by `value`
    /// verbatim.
    ///
    /// # Errors
    ///
    /// Returns [`EncodingError::InvalidFieldNumber`] when `field` is outside
    /// `[1, 32]`; the buffer is left untouched in that case.
    pub fn write_field(&mut self, field: u32, value: &[u8]) -> Result<(), EncodingError> {
        self.write_field_tag(field)?;
        self.write_bytes(value)
    }

    /// Appends `value` as an unsigned varint: seven bits per byte, least
    /// significant group first, with the high bit set on every byte except
    /// the last.
    ///
    /// Always returns `Ok(())`.
    pub fn write_uvarint(&mut self, mut value: u64) -> Result<(), EncodingError> {
        while value >= 0x80 {
            // The cast deliberately keeps only the low byte; 0x80 marks
            // "more bytes follow".
            self.buffer.push((value as u8) | 0x80);
            value >>= 7;
        }
        self.buffer.push(value as u8);
        Ok(())
    }

    /// Appends the tag for `field` followed by `value` as an unsigned varint.
    ///
    /// # Errors
    ///
    /// Returns [`EncodingError::InvalidFieldNumber`] when `field` is outside
    /// `[1, 32]`; nothing is written in that case.
    pub fn write_uvarint_field(&mut self, value: u64, field: u32) -> Result<(), EncodingError> {
        self.write_field_tag(field)?;
        self.write_uvarint(value)
    }

    /// Appends `value` as a zigzag-encoded varint, so that values of small
    /// magnitude (positive or negative) take few bytes:
    /// 0 -> 0, -1 -> 1, 1 -> 2, -2 -> 3, ...
    ///
    /// Always returns `Ok(())`.
    pub fn write_varint(&mut self, value: i64) -> Result<(), EncodingError> {
        // `value >> 63` is an arithmetic shift: all ones for negatives, zero
        // otherwise, which flips the shifted bits for negative inputs.
        let zigzag = ((value as u64) << 1) ^ ((value >> 63) as u64);
        self.write_uvarint(zigzag)
    }

    /// Appends the tag for `field` followed by `value` as a zigzag varint.
    ///
    /// # Errors
    ///
    /// Returns [`EncodingError::InvalidFieldNumber`] when `field` is outside
    /// `[1, 32]`; nothing is written in that case.
    pub fn write_varint_field(&mut self, value: i64, field: u32) -> Result<(), EncodingError> {
        self.write_field_tag(field)?;
        self.write_varint(value)
    }

    /// Appends `value` as a length-prefixed, big-endian byte string with no
    /// leading zero bytes. Zero is written as the single byte `0x00`.
    ///
    /// Always returns `Ok(())`.
    pub fn write_big_number(&mut self, value: &num_bigint::BigUint) -> Result<(), EncodingError> {
        // `to_bytes_be` yields exactly the bytes that formatting the number
        // as even-length hex and parsing it back would produce, without the
        // intermediate string or a parse step that cannot fail.
        self.write_bytes_with_length(&value.to_bytes_be())
    }

    /// Appends the tag for `field` followed by `value` encoded as in
    /// [`write_big_number`](Self::write_big_number).
    ///
    /// # Errors
    ///
    /// Returns [`EncodingError::InvalidFieldNumber`] when `field` is outside
    /// `[1, 32]`; nothing is written in that case.
    pub fn write_big_number_field(
        &mut self,
        value: &num_bigint::BigUint,
        field: u32,
    ) -> Result<(), EncodingError> {
        self.write_field_tag(field)?;
        self.write_big_number(value)
    }

    /// Appends a single byte: `1` for `true`, `0` for `false`.
    ///
    /// Always returns `Ok(())`.
    pub fn write_bool(&mut self, value: bool) -> Result<(), EncodingError> {
        self.buffer.push(u8::from(value));
        Ok(())
    }

    /// Appends the tag for `field` followed by the one-byte boolean.
    ///
    /// # Errors
    ///
    /// Returns [`EncodingError::InvalidFieldNumber`] when `field` is outside
    /// `[1, 32]`; nothing is written in that case.
    pub fn write_bool_field(&mut self, value: bool, field: u32) -> Result<(), EncodingError> {
        self.write_field_tag(field)?;
        self.write_bool(value)
    }

    /// Appends the UTF-8 bytes of `value`, prefixed with their byte length as
    /// an unsigned varint.
    ///
    /// Always returns `Ok(())`.
    pub fn write_string(&mut self, value: &str) -> Result<(), EncodingError> {
        self.write_bytes_with_length(value.as_bytes())
    }

    /// Appends the tag for `field` followed by the length-prefixed string.
    ///
    /// # Errors
    ///
    /// Returns [`EncodingError::InvalidFieldNumber`] when `field` is outside
    /// `[1, 32]`; nothing is written in that case.
    pub fn write_string_field(&mut self, value: &str, field: u32) -> Result<(), EncodingError> {
        self.write_field_tag(field)?;
        self.write_string(value)
    }

    /// Appends `bytes.len()` as an unsigned varint, then `bytes` verbatim.
    ///
    /// Always returns `Ok(())`.
    pub fn write_bytes_with_length(&mut self, bytes: &[u8]) -> Result<(), EncodingError> {
        self.write_uvarint(bytes.len() as u64)?;
        self.write_bytes(bytes)
    }

    /// Appends the tag for `field` followed by the length-prefixed `bytes`.
    ///
    /// # Errors
    ///
    /// Returns [`EncodingError::InvalidFieldNumber`] when `field` is outside
    /// `[1, 32]`; nothing is written in that case.
    pub fn write_bytes_field(&mut self, bytes: &[u8], field: u32) -> Result<(), EncodingError> {
        self.write_field_tag(field)?;
        self.write_bytes_with_length(bytes)
    }

    /// Appends the 32 hash bytes verbatim, with no length prefix.
    ///
    /// Always returns `Ok(())`; the length is enforced by the parameter type.
    pub fn write_hash(&mut self, hash: &[u8; 32]) -> Result<(), EncodingError> {
        self.write_bytes(hash)
    }

    /// Appends the tag for `field` followed by the 32 hash bytes verbatim.
    ///
    /// # Errors
    ///
    /// Returns [`EncodingError::InvalidFieldNumber`] when `field` is outside
    /// `[1, 32]`; nothing is written in that case.
    pub fn write_hash_field(&mut self, hash: &[u8; 32], field: u32) -> Result<(), EncodingError> {
        self.write_field(field, hash)
    }

    /// Appends `hash` verbatim after checking that it is exactly 32 bytes.
    ///
    /// # Errors
    ///
    /// Returns [`EncodingError::InvalidHashLength`] carrying the actual
    /// length when `hash.len() != 32`; nothing is written in that case.
    pub fn write_hash_bytes(&mut self, hash: &[u8]) -> Result<(), EncodingError> {
        if hash.len() != 32 {
            return Err(EncodingError::InvalidHashLength(hash.len()));
        }
        self.write_bytes(hash)
    }

    /// Appends the tag for `field` followed by `hash` verbatim, after
    /// checking that `hash` is exactly 32 bytes.
    ///
    /// # Errors
    ///
    /// The hash length is checked before the field number:
    /// [`EncodingError::InvalidHashLength`] when `hash.len() != 32`, otherwise
    /// [`EncodingError::InvalidFieldNumber`] when `field` is outside
    /// `[1, 32]`. Nothing is written in either case.
    pub fn write_hash_bytes_field(&mut self, hash: &[u8], field: u32) -> Result<(), EncodingError> {
        if hash.len() != 32 {
            return Err(EncodingError::InvalidHashLength(hash.len()));
        }
        self.write_field(field, hash)
    }

    /// Invokes `writer_fn` with the contained value when `value` is `Some`,
    /// and writes nothing when it is `None`.
    ///
    /// `_field` is accepted but not used by this method: if a field tag has
    /// to be emitted, `writer_fn` must do it.
    ///
    /// # Errors
    ///
    /// Propagates whatever error `writer_fn` returns.
    pub fn write_optional<T, F>(
        &mut self,
        value: Option<&T>,
        _field: u32,
        writer_fn: F,
    ) -> Result<(), EncodingError>
    where
        F: FnOnce(&mut Self, &T) -> Result<(), EncodingError>,
    {
        match value {
            Some(inner) => writer_fn(self, inner),
            None => Ok(()),
        }
    }

    /// Invokes `writer_fn` once per element of `items`, in order.
    ///
    /// This method itself writes no element count and no tag; `_field` is
    /// accepted but not used, so `writer_fn` is responsible for any framing.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first error produced by `writer_fn`. Bytes
    /// written for earlier elements remain in the buffer.
    pub fn write_array<T, F>(
        &mut self,
        items: &[T],
        _field: u32,
        writer_fn: F,
    ) -> Result<(), EncodingError>
    where
        F: Fn(&mut Self, &T) -> Result<(), EncodingError>,
    {
        for item in items {
            writer_fn(self, item)?;
        }
        Ok(())
    }
}
impl Default for BinaryWriter {
fn default() -> Self {
Self::new()
}
}
impl BinaryWriter {
pub fn with_field_number(data: &[u8], field: Option<u32>) -> Result<Vec<u8>, EncodingError> {
match field {
Some(field_num) => {
let mut writer = BinaryWriter::new();
writer.write_field(field_num, data)?;
Ok(writer.into_bytes())
}
None => Ok(data.to_vec()),
}
}
pub fn encode_uvarint(value: u64) -> Vec<u8> {
let mut writer = BinaryWriter::new();
writer.write_uvarint(value).unwrap(); writer.into_bytes()
}
pub fn encode_varint(value: i64) -> Vec<u8> {
let mut writer = BinaryWriter::new();
writer.write_varint(value).unwrap(); writer.into_bytes()
}
pub fn encode_string(value: &str) -> Vec<u8> {
let mut writer = BinaryWriter::new();
writer.write_string(value).unwrap(); writer.into_bytes()
}
pub fn encode_bytes(bytes: &[u8]) -> Vec<u8> {
let mut writer = BinaryWriter::new();
writer.write_bytes_with_length(bytes).unwrap(); writer.into_bytes()
}
pub fn encode_bool(value: bool) -> Vec<u8> {
vec![if value { 1 } else { 0 }]
}
pub fn encode_hash(hash: &[u8; 32]) -> Vec<u8> {
hash.to_vec()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Unsigned varints: one byte below 128, continuation bytes above.
    #[test]
    fn test_uvarint_encoding() {
        let cases: [(u64, &[u8]); 6] = [
            (0, &[0]),
            (1, &[1]),
            (127, &[127]),
            (128, &[128, 1]),
            (256, &[128, 2]),
            (16384, &[128, 128, 1]),
        ];
        for &(input, expected) in cases.iter() {
            let encoded = BinaryWriter::encode_uvarint(input);
            assert_eq!(encoded, expected, "uvarint({}) failed", input);
        }
    }

    /// Signed varints alternate between non-negative and negative inputs.
    #[test]
    fn test_varint_encoding() {
        let cases: [(i64, &[u8]); 5] = [
            (0, &[0]),
            (-1, &[1]),
            (1, &[2]),
            (-2, &[3]),
            (2, &[4]),
        ];
        for &(input, expected) in cases.iter() {
            let encoded = BinaryWriter::encode_varint(input);
            assert_eq!(encoded, expected, "varint({}) failed", input);
        }
    }

    /// Strings are written as a length byte followed by their UTF-8 bytes.
    #[test]
    fn test_string_encoding() {
        let encoded = BinaryWriter::encode_string("hello");
        assert_eq!(encoded, b"\x05hello".to_vec());
    }

    /// Byte strings are written as a length byte followed by the payload.
    #[test]
    fn test_bytes_encoding() {
        let payload = [1u8, 2, 3, 4];
        let encoded = BinaryWriter::encode_bytes(&payload);
        assert_eq!(encoded, [4u8, 1, 2, 3, 4]);
    }

    /// Booleans encode to a single 0/1 byte.
    #[test]
    fn test_bool_encoding() {
        for &(flag, byte) in [(true, 1u8), (false, 0u8)].iter() {
            assert_eq!(BinaryWriter::encode_bool(flag), vec![byte]);
        }
    }

    /// A field is its number followed by the raw value bytes.
    #[test]
    fn test_field_encoding() {
        let mut writer = BinaryWriter::new();
        writer.write_field(1, &[42]).unwrap();
        assert_eq!(writer.bytes(), &[1u8, 42][..]);
    }

    /// A 32-byte hash is accepted; a 31-byte slice is rejected.
    #[test]
    fn test_hash_validation() {
        let mut writer = BinaryWriter::new();
        assert!(writer.write_hash(&[0u8; 32]).is_ok());
        let too_short = [0u8; 31];
        assert!(writer.write_hash_bytes(&too_short).is_err());
    }
}