use std::io::{Read, Write};
use crate::codec::{Config, Decode, Encode};
use crate::error::{Result, SerialError};
use crate::traits::{Deserialize, Serialize};
/// Encoder that forwards the bytes it is given to an underlying
/// [`std::io::Write`] sink.
///
/// The `Encode` implementation for this type pushes every byte through
/// `Write::write_all` on the wrapped writer.
#[derive(Debug)]
pub struct IoEncoder<W: Write> {
// Destination that receives all encoded bytes.
writer: W,
}
impl<W: Write> IoEncoder<W> {
#[must_use]
pub fn new(writer: W) -> Self {
Self { writer }
}
#[must_use]
pub fn writer(&self) -> &W {
&self.writer
}
#[must_use]
pub fn writer_mut(&mut self) -> &mut W {
&mut self.writer
}
#[must_use]
pub fn into_inner(self) -> W {
self.writer
}
#[inline]
pub fn write<T: Serialize + ?Sized>(&mut self, value: &T) -> Result<()> {
value.serialize(self)
}
}
impl<W: Write> Encode for IoEncoder<W> {
#[inline]
fn write_byte(&mut self, byte: u8) -> Result<()> {
self.writer.write_all(&[byte]).map_err(map_io_error)
}
#[inline]
fn write_bytes(&mut self, bytes: &[u8]) -> Result<()> {
self.writer.write_all(bytes).map_err(map_io_error)
}
}
/// Decoder that pulls its input bytes from an underlying
/// [`std::io::Read`] source.
///
/// The `Decode` implementation for this type fills buffers with
/// `Read::read_exact` and reports `config.max_alloc` as its allocation cap.
#[derive(Debug)]
pub struct IoDecoder<R: Read> {
// Source the decoder reads bytes from.
reader: R,
// Decoding configuration; `max_alloc` is surfaced via `Decode::max_alloc`.
config: Config,
}
impl<R: Read> IoDecoder<R> {
#[must_use]
pub fn new(reader: R) -> Self {
Self {
reader,
config: Config::default(),
}
}
pub fn with_config(reader: R, config: Config) -> Result<Self> {
Ok(Self {
reader,
config: config.validate()?,
})
}
#[must_use]
pub fn reader(&self) -> &R {
&self.reader
}
#[must_use]
pub fn into_inner(self) -> R {
self.reader
}
#[inline]
pub fn read<T: Deserialize>(&mut self) -> Result<T> {
T::deserialize(self)
}
}
impl<R: Read> Decode for IoDecoder<R> {
fn read_byte(&mut self) -> Result<u8> {
let mut buf = [0u8; 1];
self.read_into(&mut buf)?;
Ok(buf[0])
}
fn read_into(&mut self, out: &mut [u8]) -> Result<()> {
self.reader.read_exact(out).map_err(|e| {
if e.kind() == std::io::ErrorKind::UnexpectedEof {
SerialError::UnexpectedEof {
needed: out.len(),
remaining: 0,
}
} else {
map_io_error(e)
}
})
}
fn max_alloc(&self) -> usize {
self.config.max_alloc
}
}
#[inline]
pub fn encode_into<T, W>(value: &T, writer: &mut W) -> Result<()>
where
T: Serialize + ?Sized,
W: Write,
{
let mut enc = IoEncoder::new(writer);
enc.write(value)
}
pub fn decode_from<T, R>(reader: &mut R) -> Result<T>
where
T: Deserialize,
R: Read,
{
let mut buf = alloc::vec::Vec::new();
let _ = reader.read_to_end(&mut buf).map_err(map_io_error)?;
crate::decode(&buf)
}
/// Converts a `std::io::Error` into `SerialError::Io`, keeping the error's
/// kind and its `to_string()` rendering as the message.
#[inline]
fn map_io_error(err: std::io::Error) -> SerialError {
    let kind = err.kind();
    let message = alloc::string::ToString::to_string(&err);
    SerialError::Io { kind, message }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::encode;
    use alloc::vec::Vec;
    use std::io::Cursor;

    /// Values written through `IoEncoder` come back unchanged through
    /// `IoDecoder`, in the same order.
    #[test]
    fn io_encoder_decoder_round_trip() {
        let mut sink: Vec<u8> = Vec::new();
        {
            // Scope the encoder so its mutable borrow of `sink` ends here.
            let mut encoder = IoEncoder::new(&mut sink);
            encoder.write(&42_u64).unwrap();
            encoder.write(&"hello").unwrap();
            encoder.write(&true).unwrap();
        }

        let mut decoder = IoDecoder::new(Cursor::new(sink));
        let number: u64 = decoder.read().unwrap();
        let text: String = decoder.read().unwrap();
        let flag: bool = decoder.read().unwrap();

        assert_eq!(number, 42);
        assert_eq!(text.as_str(), "hello");
        assert!(flag);
    }

    /// `encode_into` produces exactly the bytes that `encode` returns.
    #[test]
    fn encode_into_writes_same_bytes_as_encode() {
        let sample = (1u32, String::from("hi"), -2i32);
        let expected = encode(&sample).unwrap();

        let mut streamed: Vec<u8> = Vec::new();
        encode_into(&sample, &mut streamed).unwrap();

        assert_eq!(expected, streamed);
    }

    /// `decode_from` recovers the value that was passed to `encode`.
    #[test]
    fn decode_from_reads_same_value_as_decode() {
        let encoded = encode(&(7u64, true)).unwrap();
        let mut source = Cursor::new(encoded);

        let decoded: (u64, bool) = decode_from(&mut source).unwrap();

        assert_eq!(decoded, (7, true));
    }

    /// A configuration with `max_alloc` of zero is refused by `with_config`.
    #[test]
    fn io_decoder_with_zero_cap_is_rejected() {
        let zero_cap = Config::new().with_max_alloc(0);
        let source = Cursor::new(Vec::<u8>::new());

        let err = IoDecoder::with_config(source, zero_cap).expect_err("zero cap");

        assert!(matches!(err, SerialError::InvalidLength { .. }));
    }

    /// Decoding a `u64` from the lone byte `0x80` fails with `UnexpectedEof`.
    // NOTE(review): presumably 0x80 is an incomplete multi-byte integer prefix
    // in this format; confirm against the codec.
    #[test]
    fn io_decoder_short_read_surfaces_unexpected_eof() {
        let truncated = alloc::vec![0x80_u8];
        let mut decoder = IoDecoder::new(Cursor::new(truncated));

        let err = decoder.read::<u64>().expect_err("truncated");

        assert!(matches!(err, SerialError::UnexpectedEof { .. }));
    }
}