use crate::group::{DataSink, DataSource};
use crate::util::BasicHeader;
use crate::{MessageId, Versioned};
use serde::de::DeserializeOwned;
use serde::Serialize;
use std::convert::TryInto;
use std::io::{self, Cursor, Read, Write};
use thiserror::Error;
/// Failure modes reported by [`CborData`] when reading or writing messages.
#[derive(Debug, Error)]
pub enum CborDataError {
    /// An I/O failure.
    ///
    /// Holds `Some(err)` when built from an [`io::Error`], and `None` when the
    /// failure was reported by `serde_cbor` (the conversion from
    /// `serde_cbor::Error` does not carry the underlying error over).
    #[error("IO Error")]
    Io(Option<io::Error>),

    /// The CBOR encoder/decoder rejected the data (syntax or data category).
    ///
    /// The `DataSource` implementation for [`CborData`] also returns this
    /// variant for unknown or unexpected message ids and versions.
    #[error("Serialize/Deserialize Error")]
    Serializer,

    /// The input ended before a complete value could be decoded.
    #[error("Premature EOF")]
    Eof,
}
/// Maps a `serde_cbor` error onto the coarser [`CborDataError`] categories.
impl From<serde_cbor::Error> for CborDataError {
    /// Classifies `e` and picks the matching variant.
    ///
    /// Only the category is kept: the resulting `Io` variant holds `None`, and
    /// syntax and data errors both collapse into `Serializer`.
    fn from(e: serde_cbor::Error) -> Self {
        use serde_cbor::error::Category;

        match e.classify() {
            Category::Io => Self::Io(None),
            Category::Syntax | Category::Data => Self::Serializer,
            Category::Eof => Self::Eof,
        }
    }
}
impl From<io::Error> for CborDataError {
fn from(e: io::Error) -> Self {
CborDataError::Io(Some(e))
}
}
/// Adapter that owns a reader or writer and moves CBOR-encoded messages through it.
///
/// `RW` is left unconstrained here; the trait implementations add the
/// `Read` or `Write` bound they need.
pub struct CborData<RW> {
    /// The wrapped reader or writer.
    inner: RW,
}
impl<RW> CborData<RW> {
    /// Wraps `reader` (a reader or a writer, despite the parameter name).
    pub fn new(reader: RW) -> Self {
        Self { inner: reader }
    }

    /// Consumes the adapter and hands back the wrapped value.
    pub fn into_inner(self) -> RW {
        let CborData { inner } = self;
        inner
    }
}
/// Reads header-framed CBOR messages from any [`Read`] implementation.
impl<R> DataSource for CborData<R>
where
    R: Read,
{
    type Error = CborDataError;
    type Header = BasicHeader;

    /// Decodes the next header from the wrapped reader.
    ///
    /// # Errors
    ///
    /// Any failure reported by `BasicHeader::deserialize_from` is converted
    /// into a [`CborDataError`].
    fn read_header(&mut self) -> Result<BasicHeader, CborDataError> {
        let header = BasicHeader::deserialize_from(&mut self.inner)?;
        Ok(header)
    }

    /// Decodes one CBOR value of type `T` from the next `header.msg_len` bytes.
    ///
    /// # Errors
    ///
    /// Returns the [`CborDataError`] corresponding to the `serde_cbor` failure.
    /// When decoding fails, this method does not skip the rest of the message,
    /// so the reader may be left positioned inside it.
    fn read_message<T>(&mut self, header: &BasicHeader) -> Result<T, CborDataError>
    where
        T: DeserializeOwned,
    {
        // Cap the decoder at `msg_len` bytes so it cannot read past this message.
        let frame = self.inner.by_ref().take(header.msg_len.into());
        Ok(serde_cbor::from_reader(frame)?)
    }

    /// Error reported for a message id that is not recognised.
    fn unknown_message(&self, _msg_id: u16) -> CborDataError {
        CborDataError::Serializer
    }

    /// Error reported for a version of `T` that is not recognised.
    fn unknown_version<T>(&self, _ver: u16) -> CborDataError {
        CborDataError::Serializer
    }

    /// Error reported for a message id that was not expected at this point.
    fn unexpected_message<T>(&self, _msg_id: u16) -> CborDataError {
        CborDataError::Serializer
    }
}
impl<W> DataSink for CborData<W>
where
W: Write,
{
type Error = CborDataError;
fn write_message<T>(&mut self, msg: &T) -> Result<(), CborDataError>
where
T: Serialize + Versioned,
T::Base: MessageId,
{
let msg_buf = Vec::<u8>::new();
let mut cursor = Cursor::new(msg_buf);
serde_cbor::to_writer(&mut cursor, msg)?;
let msg_buf = cursor.into_inner();
let msg_len: u32 = msg_buf.len().try_into().expect("usize to u32");
let header = BasicHeader::for_msg(msg, msg_len);
header.serialize_into(&mut self.inner)?;
self.inner.write_all(&msg_buf)?;
Ok(())
}
}