use alloc::vec;
use alloc::vec::Vec;
use crate::error::{Result, SerialError};
use crate::traits::{Deserialize, Serialize};
use crate::varint;
/// Limits applied while decoding untrusted input.
///
/// Build one with [`Config::new`] (or [`Default::default`]) and adjust it with
/// the `with_*` builder methods.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Config {
    /// Largest length prefix, in bytes, a decoder accepts before allocating.
    /// A zero value is rejected by [`Decoder::with_config`].
    pub max_alloc: usize,
}

impl Default for Config {
    /// Same as [`Config::new`].
    fn default() -> Self {
        Config::new()
    }
}

impl Config {
    /// Allocation cap used by [`Config::new`]: 1 GiB.
    const DEFAULT_MAX_ALLOC: usize = 1 << 30;

    /// Returns the default configuration (`max_alloc` = 1 GiB).
    #[must_use]
    pub const fn new() -> Self {
        Config {
            max_alloc: Self::DEFAULT_MAX_ALLOC,
        }
    }

    /// Returns a copy of `self` with the allocation cap replaced by `max_alloc`.
    ///
    /// The value is not checked here. A zero cap is rejected only when the
    /// configuration is validated (for example by [`Decoder::with_config`]).
    #[must_use]
    pub const fn with_max_alloc(self, max_alloc: usize) -> Self {
        Self { max_alloc, ..self }
    }

    /// Confirms the configuration is usable and hands it back unchanged.
    ///
    /// # Errors
    ///
    /// Returns [`SerialError::InvalidLength`] with `declared` and `remaining`
    /// both zero when `max_alloc` is zero.
    pub(crate) fn validate(self) -> Result<Self> {
        match self.max_alloc {
            0 => Err(SerialError::InvalidLength {
                declared: 0,
                remaining: 0,
            }),
            _ => Ok(self),
        }
    }
}
/// Byte sink that [`Serialize`] implementations write into.
///
/// Implementors provide the two raw write primitives. The varint helpers
/// are built on top of [`Encode::write_bytes`].
pub trait Encode {
    /// Appends one byte.
    ///
    /// # Errors
    ///
    /// Implementation-defined; the error is passed straight to the caller.
    fn write_byte(&mut self, byte: u8) -> Result<()>;

    /// Appends every byte of `bytes`, in order.
    ///
    /// # Errors
    ///
    /// Implementation-defined; the error is passed straight to the caller.
    fn write_bytes(&mut self, bytes: &[u8]) -> Result<()>;

    /// Hints that about `additional` more bytes are about to be written.
    ///
    /// The default does nothing. Sinks that can pre-size their storage may
    /// override it.
    #[inline]
    fn reserve(&mut self, additional: usize) {
        // Deliberately a no-op; binding the argument silences the unused lint.
        let _ = additional;
    }

    /// Encodes `value` with `varint::write_u64` into a stack scratch buffer,
    /// then forwards the used prefix to [`Encode::write_bytes`].
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Encode::write_bytes`].
    #[inline]
    fn write_varint_u64(&mut self, value: u64) -> Result<()> {
        let mut scratch = [0u8; varint::MAX_VARINT_LEN_U64];
        let used = varint::write_u64(value, &mut scratch);
        let encoded = &scratch[..used];
        self.write_bytes(encoded)
    }

    /// Encodes `value` with `varint::write_u128` into a stack scratch buffer,
    /// then forwards the used prefix to [`Encode::write_bytes`].
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Encode::write_bytes`].
    #[inline]
    fn write_varint_u128(&mut self, value: u128) -> Result<()> {
        let mut scratch = [0u8; varint::MAX_VARINT_LEN_U128];
        let used = varint::write_u128(value, &mut scratch);
        let encoded = &scratch[..used];
        self.write_bytes(encoded)
    }
}
/// Byte source that [`Deserialize`] implementations read from.
///
/// Implementors provide single-byte reads, exact-length reads and an
/// allocation cap. The varint and length-prefix helpers are built on those.
pub trait Decode {
    /// Reads and consumes one byte.
    ///
    /// # Errors
    ///
    /// Implementation-defined (for example, when the input is exhausted).
    fn read_byte(&mut self) -> Result<u8>;

    /// Fills all of `out` from the input.
    ///
    /// # Errors
    ///
    /// Implementation-defined (for example, when fewer than `out.len()` bytes
    /// are available).
    fn read_into(&mut self, out: &mut [u8]) -> Result<()>;

    /// Largest length prefix, in bytes, that [`Decode::read_length_prefixed`]
    /// accepts.
    fn max_alloc(&self) -> usize;

    /// Reads a little-endian base-128 varint as a `u64`. Each byte carries 7
    /// payload bits, and the high bit is set on every byte except the last.
    ///
    /// Over-long encodings, such as `[0x80, 0x00]` for `0`, are accepted.
    ///
    /// # Errors
    ///
    /// * [`SerialError::VarintOverflow`] if byte number
    ///   `varint::MAX_VARINT_LEN_U64` has any bit other than the lowest set.
    ///   That includes the continuation bit.
    /// * Any error from [`Decode::read_byte`].
    #[inline]
    fn read_varint_u64(&mut self) -> Result<u64> {
        let mut result: u64 = 0;
        let mut shift: u32 = 0;
        for consumed in 1..=varint::MAX_VARINT_LEN_U64 {
            let byte = self.read_byte()?;
            // On the last permitted byte only bit 0 may be set. This also
            // rejects a continuation bit. With the presumably 10-byte maximum,
            // `shift` is 63 here, so bit 0 is the only one that still fits.
            if consumed == varint::MAX_VARINT_LEN_U64 && (byte & 0xfe) != 0 {
                return Err(SerialError::VarintOverflow);
            }
            result |= u64::from(byte & 0x7f) << shift;
            if byte & 0x80 == 0 {
                return Ok(result);
            }
            shift += 7;
        }
        // Not reached in practice: the final iteration either returns `Ok`
        // or fails the mask check above.
        Err(SerialError::VarintOverflow)
    }

    /// Reads a little-endian base-128 varint as a `u128`. The encoding is the
    /// same as [`Decode::read_varint_u64`], and over-long encodings are
    /// accepted.
    ///
    /// # Errors
    ///
    /// * [`SerialError::VarintOverflow`] if byte number
    ///   `varint::MAX_VARINT_LEN_U128` has any bit above the lowest two set.
    ///   That includes the continuation bit.
    /// * Any error from [`Decode::read_byte`].
    #[inline]
    fn read_varint_u128(&mut self) -> Result<u128> {
        let mut result: u128 = 0;
        let mut shift: u32 = 0;
        for consumed in 1..=varint::MAX_VARINT_LEN_U128 {
            let byte = self.read_byte()?;
            // On the last permitted byte only bits 0-1 may be set. With the
            // presumably 19-byte maximum, `shift` is 126 here.
            if consumed == varint::MAX_VARINT_LEN_U128 && (byte & 0xfc) != 0 {
                return Err(SerialError::VarintOverflow);
            }
            result |= u128::from(byte & 0x7f) << shift;
            if byte & 0x80 == 0 {
                return Ok(result);
            }
            shift += 7;
        }
        // Not reached in practice: see `read_varint_u64`.
        Err(SerialError::VarintOverflow)
    }

    /// Reads a varint length `n` followed by exactly `n` payload bytes, and
    /// returns the payload.
    ///
    /// The declared length is attacker-controlled, so the buffer is not sized
    /// to it up front:
    ///
    /// * The first 64 KiB (or the whole payload, if smaller) is allocated and
    ///   read.
    /// * After that the buffer at most doubles per step, capped at `n`.
    ///
    /// A short or malicious input therefore commits at most about twice the
    /// bytes actually received (never less than 64 KiB), not up to
    /// `max_alloc()` bytes. Growth stays geometric, so the copying cost is
    /// amortised O(n).
    ///
    /// # Errors
    ///
    /// * [`SerialError::InvalidLength`] when the declared length exceeds
    ///   [`Decode::max_alloc`]. `remaining` is reported as 0, because a
    ///   generic source has no way to report how much input is left.
    /// * Any error from [`Decode::read_varint_u64`] or [`Decode::read_into`].
    ///   For payloads longer than 64 KiB, a short read is reported by the
    ///   `read_into` call for the failing step, so its error describes that
    ///   step rather than the whole payload.
    #[inline]
    fn read_length_prefixed(&mut self) -> Result<Vec<u8>> {
        // Bytes committed before any payload has been confirmed to exist.
        const FIRST_CHUNK: usize = 64 * 1024;

        let declared = self.read_varint_u64()?;
        // `usize` is at most 64 bits on every supported target, so this widening is lossless.
        let max = self.max_alloc() as u64;
        if declared > max {
            return Err(SerialError::InvalidLength {
                declared,
                remaining: 0,
            });
        }
        // Lossless: `declared <= max_alloc()`, which is itself a `usize`.
        let len = declared as usize;

        let mut buf = vec![0u8; len.min(FIRST_CHUNK)];
        self.read_into(&mut buf)?;
        while buf.len() < len {
            let filled = buf.len();
            // Grow by at most what has already arrived, never past `len`.
            let target = filled.saturating_mul(2).min(len);
            buf.reserve_exact(target - filled);
            buf.resize(target, 0);
            self.read_into(&mut buf[filled..])?;
        }
        Ok(buf)
    }
}
/// In-memory [`Encode`] sink that appends to a growable `Vec<u8>`.
#[derive(Debug, Default)]
pub struct Encoder {
    /// Everything written so far.
    out: Vec<u8>,
}

impl Encoder {
    /// Creates an encoder over an empty, unallocated buffer.
    #[must_use]
    pub fn new() -> Self {
        Self::into_buffer(Vec::new())
    }

    /// Creates an encoder that appends to `buffer`.
    ///
    /// The existing contents and capacity of `buffer` are kept, and new
    /// output goes after them. Despite the `into_` prefix, this is a
    /// constructor and not a conversion of `self`.
    #[must_use]
    pub fn into_buffer(buffer: Vec<u8>) -> Self {
        Encoder { out: buffer }
    }

    /// Borrows all bytes written so far.
    #[inline]
    #[must_use]
    pub fn as_bytes(&self) -> &[u8] {
        self.out.as_slice()
    }

    /// Consumes the encoder and returns its buffer.
    #[inline]
    #[must_use]
    pub fn into_inner(self) -> Vec<u8> {
        let Encoder { out } = self;
        out
    }

    /// Moves the buffer out and leaves this encoder empty and reusable.
    ///
    /// The returned `Vec` keeps the old allocation. The encoder continues with
    /// a fresh, unallocated buffer.
    #[must_use]
    pub fn take(&mut self) -> Vec<u8> {
        let buffer = &mut self.out;
        core::mem::take(buffer)
    }

    /// Serializes `value` into this encoder, after any existing output.
    ///
    /// # Errors
    ///
    /// Propagates whatever `value.serialize` returns.
    #[inline]
    pub fn write<T: Serialize + ?Sized>(&mut self, value: &T) -> Result<()> {
        let sink: &mut Self = self;
        value.serialize(sink)
    }
}
impl Encode for Encoder {
    /// Pushes `byte` onto the buffer. This never fails.
    #[inline]
    fn write_byte(&mut self, byte: u8) -> Result<()> {
        self.out.push(byte);
        Ok(())
    }

    /// Appends all of `bytes` to the buffer. This never fails.
    #[inline]
    fn write_bytes(&mut self, bytes: &[u8]) -> Result<()> {
        self.out.extend(bytes);
        Ok(())
    }

    /// Forwards the size hint to the underlying `Vec`.
    #[inline]
    fn reserve(&mut self, additional: usize) {
        Vec::reserve(&mut self.out, additional);
    }
}
/// Cursor-based [`Decode`] source over a borrowed byte slice.
#[derive(Debug)]
pub struct Decoder<'a> {
    /// The complete input being decoded.
    input: &'a [u8],
    /// Offset of the next unread byte.
    pos: usize,
    /// Limits applied while decoding.
    config: Config,
}

impl<'a> Decoder<'a> {
    /// Starts decoding `bytes` at offset 0 using [`Config::default`].
    #[inline]
    #[must_use]
    pub fn new(bytes: &'a [u8]) -> Self {
        Decoder::at_start(bytes, Config::default())
    }

    /// Starts decoding `bytes` with a caller-supplied `config`.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `Config::validate`. Currently that happens
    /// when `config.max_alloc` is zero.
    pub fn with_config(bytes: &'a [u8], config: Config) -> Result<Self> {
        config
            .validate()
            .map(|checked| Decoder::at_start(bytes, checked))
    }

    /// Builds a decoder positioned on the first byte of `input`.
    fn at_start(input: &'a [u8], config: Config) -> Self {
        Decoder {
            input,
            pos: 0,
            config,
        }
    }

    /// Number of bytes consumed so far.
    #[inline]
    #[must_use]
    pub fn position(&self) -> usize {
        self.pos
    }

    /// Number of unread bytes. This is zero once the cursor reaches the end.
    #[inline]
    #[must_use]
    pub fn remaining(&self) -> usize {
        self.input.get(self.pos..).map_or(0, <[u8]>::len)
    }

    /// `true` when no unread bytes are left.
    #[inline]
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.pos >= self.input.len()
    }

    /// Decodes a `T` starting at the current position.
    ///
    /// # Errors
    ///
    /// Propagates whatever `T::deserialize` returns.
    #[inline]
    pub fn read<T: Deserialize>(&mut self) -> Result<T> {
        <T as Deserialize>::deserialize(self)
    }
}
impl Decode for Decoder<'_> {
    /// Returns the byte under the cursor and advances past it.
    ///
    /// # Errors
    ///
    /// Returns [`SerialError::UnexpectedEof`] with `needed: 1, remaining: 0`
    /// once the input is exhausted.
    #[inline]
    fn read_byte(&mut self) -> Result<u8> {
        let byte = *self
            .input
            .get(self.pos)
            .ok_or(SerialError::UnexpectedEof {
                needed: 1,
                remaining: 0,
            })?;
        self.pos += 1;
        Ok(byte)
    }

    /// Copies the next `out.len()` bytes into `out` and advances past them.
    ///
    /// # Errors
    ///
    /// Returns [`SerialError::UnexpectedEof`] if fewer than `out.len()` bytes
    /// remain. The cursor is left unchanged in that case.
    #[inline]
    fn read_into(&mut self, out: &mut [u8]) -> Result<()> {
        let needed = out.len();
        let remaining = self.remaining();
        if needed <= remaining {
            let stop = self.pos + needed;
            out.copy_from_slice(&self.input[self.pos..stop]);
            self.pos = stop;
            Ok(())
        } else {
            Err(SerialError::UnexpectedEof { needed, remaining })
        }
    }

    /// The allocation cap taken from this decoder's [`Config`].
    #[inline]
    fn max_alloc(&self) -> usize {
        self.config.max_alloc
    }

    /// Reads a varint length followed by that many bytes, copied out of the
    /// input slice.
    ///
    /// Because the whole input is already in memory, the length is checked
    /// against the unread input before anything is allocated.
    ///
    /// # Errors
    ///
    /// * [`SerialError::InvalidLength`] if the declared length exceeds the
    ///   configured cap, or exceeds the bytes remaining after the prefix.
    ///   `remaining` is that unread count in both cases.
    /// * Any error from reading the varint prefix.
    #[inline]
    fn read_length_prefixed(&mut self) -> Result<Vec<u8>> {
        let declared = self.read_varint_u64()?;
        if declared > self.config.max_alloc as u64 {
            return Err(SerialError::InvalidLength {
                declared,
                remaining: self.remaining(),
            });
        }
        // Lossless: `declared` was just bounded by a `usize` value.
        let len = declared as usize;
        let remaining = self.remaining();
        if len > remaining {
            return Err(SerialError::InvalidLength {
                declared,
                remaining,
            });
        }
        let payload = self.input[self.pos..self.pos + len].to_vec();
        self.pos += len;
        Ok(payload)
    }
}
/// Serializes `value` into a newly allocated byte vector.
///
/// # Errors
///
/// Propagates any error returned by `value.serialize`.
#[inline]
pub fn encode<T: Serialize + ?Sized>(value: &T) -> Result<Vec<u8>> {
    let mut encoder = Encoder::new();
    encoder.write(value)?;
    Ok(encoder.into_inner())
}
/// Deserializes one `T` that must use up all of `bytes`.
///
/// # Errors
///
/// * Any error from `T::deserialize`.
/// * [`SerialError::TrailingBytes`] if input is left over after `T` has
///   been decoded.
#[inline]
pub fn decode<T: Deserialize>(bytes: &[u8]) -> Result<T> {
    let mut decoder = Decoder::new(bytes);
    let value = decoder.read::<T>()?;
    match decoder.remaining() {
        0 => Ok(value),
        remaining => Err(SerialError::TrailingBytes { remaining }),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn config_default_has_one_gib_cap() {
        let cfg = Config::default();
        assert_eq!(cfg.max_alloc, 1 << 30);
    }

    #[test]
    fn decoder_with_zero_cap_is_rejected() {
        let cfg = Config::new().with_max_alloc(0);
        let err = Decoder::with_config(&[], cfg).expect_err("zero cap is invalid");
        assert!(matches!(err, SerialError::InvalidLength { .. }));
    }

    #[test]
    fn encoder_into_buffer_reuses_caller_vec() {
        let mut buf = Vec::with_capacity(64);
        buf.push(0xff);
        let mut enc = Encoder::into_buffer(buf);
        enc.write(&7_u64).unwrap();
        let out = enc.into_inner();
        assert_eq!(out[0], 0xff);
        assert!(out.len() > 1);
    }

    #[test]
    fn encoder_take_returns_buffer_and_resets() {
        let mut enc = Encoder::new();
        enc.write(&1_u64).unwrap();
        let first = enc.take();
        assert!(!first.is_empty());
        assert!(enc.as_bytes().is_empty());
        enc.write(&2_u64).unwrap();
        let second = enc.take();
        assert_eq!(second, [0x02]);
    }

    #[test]
    fn decode_rejects_trailing_bytes() {
        let mut bytes = encode(&7_u8).unwrap();
        bytes.push(0xff);
        let err = decode::<u8>(&bytes).expect_err("trailing bytes should fail");
        assert!(matches!(err, SerialError::TrailingBytes { remaining: 1 }));
    }

    #[test]
    fn decoder_read_past_end_returns_unexpected_eof() {
        let mut dec = Decoder::new(&[0x01]);
        let _: u8 = dec.read().unwrap();
        let err = dec.read::<u8>().expect_err("past end should fail");
        assert!(matches!(err, SerialError::UnexpectedEof { .. }));
    }

    #[test]
    fn decoder_length_prefix_above_cap_is_rejected() {
        let cfg = Config::new().with_max_alloc(4);
        let bytes = [0x05, b'h', b'e', b'l', b'l', b'o'];
        let mut dec = Decoder::with_config(&bytes, cfg).expect("non-zero cap");
        let err = dec
            .read_length_prefixed()
            .expect_err("length > cap should fail");
        assert!(matches!(
            err,
            SerialError::InvalidLength { declared: 5, .. }
        ));
    }

    #[test]
    fn decoder_length_prefix_overflowing_remaining_is_rejected() {
        let bytes = [0x10, b'a', b'b'];
        let mut dec = Decoder::new(&bytes);
        let err = dec
            .read_length_prefixed()
            .expect_err("length > remaining should fail");
        assert!(matches!(err, SerialError::InvalidLength { .. }));
    }

    #[test]
    fn decoder_length_prefix_within_bounds_is_returned() {
        let bytes = [0x03, b'a', b'b', b'c', 0x07];
        let mut dec = Decoder::new(&bytes);
        assert_eq!(dec.read_length_prefixed().unwrap(), b"abc");
        assert_eq!(dec.position(), 4);
        assert_eq!(dec.remaining(), 1);
    }

    #[test]
    fn decoder_position_advances_with_reads() {
        let bytes = [0x01, 0x02, 0x03];
        let mut dec = Decoder::new(&bytes);
        assert_eq!(dec.position(), 0);
        let _ = dec.read_byte().unwrap();
        assert_eq!(dec.position(), 1);
        let mut buf = [0u8; 2];
        dec.read_into(&mut buf).unwrap();
        assert_eq!(dec.position(), 3);
        assert!(dec.is_empty());
    }

    #[test]
    fn read_into_short_read_is_rejected() {
        let mut dec = Decoder::new(&[0x01, 0x02]);
        let mut buf = [0u8; 4];
        let err = dec.read_into(&mut buf).expect_err("short read");
        assert!(matches!(err, SerialError::UnexpectedEof { .. }));
    }

    #[test]
    fn varint_u64_round_trips_through_encoder_and_decoder() {
        let samples = [0_u64, 1, 127, 128, 300, u64::from(u32::MAX), u64::MAX];
        let mut enc = Encoder::new();
        for &v in &samples {
            enc.write_varint_u64(v).unwrap();
        }
        let bytes = enc.into_inner();
        let mut dec = Decoder::new(&bytes);
        for &v in &samples {
            assert_eq!(dec.read_varint_u64().unwrap(), v);
        }
        assert!(dec.is_empty());
    }

    #[test]
    fn varint_u128_round_trips_through_encoder_and_decoder() {
        let samples = [0_u128, 1, 128, u128::from(u64::MAX), u128::MAX];
        let mut enc = Encoder::new();
        for &v in &samples {
            enc.write_varint_u128(v).unwrap();
        }
        let bytes = enc.into_inner();
        let mut dec = Decoder::new(&bytes);
        for &v in &samples {
            assert_eq!(dec.read_varint_u128().unwrap(), v);
        }
        assert!(dec.is_empty());
    }

    #[test]
    fn varint_u64_with_continuation_on_last_byte_overflows() {
        // The final permitted byte still has its continuation bit set.
        let bytes = [0xff_u8; varint::MAX_VARINT_LEN_U64];
        let err = Decoder::new(&bytes)
            .read_varint_u64()
            .expect_err("continuation on last byte must overflow");
        assert!(matches!(err, SerialError::VarintOverflow));
    }

    #[test]
    fn varint_u128_with_continuation_on_last_byte_overflows() {
        let bytes = [0xff_u8; varint::MAX_VARINT_LEN_U128];
        let err = Decoder::new(&bytes)
            .read_varint_u128()
            .expect_err("continuation on last byte must overflow");
        assert!(matches!(err, SerialError::VarintOverflow));
    }

    /// Minimal `Decode` source that relies on the trait's default helpers
    /// (in particular the default `read_length_prefixed`).
    struct SliceSource<'a> {
        bytes: &'a [u8],
        cap: usize,
    }

    impl Decode for SliceSource<'_> {
        fn read_byte(&mut self) -> Result<u8> {
            let (&first, rest) = self.bytes.split_first().ok_or(SerialError::UnexpectedEof {
                needed: 1,
                remaining: 0,
            })?;
            self.bytes = rest;
            Ok(first)
        }

        fn read_into(&mut self, out: &mut [u8]) -> Result<()> {
            if out.len() > self.bytes.len() {
                return Err(SerialError::UnexpectedEof {
                    needed: out.len(),
                    remaining: self.bytes.len(),
                });
            }
            let (head, rest) = self.bytes.split_at(out.len());
            out.copy_from_slice(head);
            self.bytes = rest;
            Ok(())
        }

        fn max_alloc(&self) -> usize {
            self.cap
        }
    }

    #[test]
    fn default_length_prefixed_enforces_cap_and_reads_payload() {
        let mut src = SliceSource {
            bytes: &[0x02, b'h', b'i'],
            cap: 8,
        };
        assert_eq!(src.read_length_prefixed().unwrap(), b"hi");

        let mut src = SliceSource {
            bytes: &[0x05, 1, 2, 3, 4, 5],
            cap: 4,
        };
        let err = src
            .read_length_prefixed()
            .expect_err("length > cap should fail");
        assert!(matches!(
            err,
            SerialError::InvalidLength { declared: 5, remaining: 0 }
        ));

        let mut src = SliceSource {
            bytes: &[0x05, 1, 2],
            cap: 8,
        };
        let err = src
            .read_length_prefixed()
            .expect_err("short payload should fail");
        assert!(matches!(err, SerialError::UnexpectedEof { .. }));
    }
}