use std::default::Default;
use std::error::Error as ErrorTrait;
use std::fmt::{Display, Formatter, Result as FmtResult};
use std::io::{Error as IoError, Read, Result as IoResult};
use std::marker::PhantomData;
use bytes::{Buf, BufMut, BytesMut};
use serde::{Deserialize, Serialize};
use serde_cbor::de::{Deserializer, IoRead};
use serde_cbor::error::Error as CborError;
use serde_cbor::ser::{IoWrite, Serializer};
use tokio_util::codec::{Decoder as IoDecoder, Encoder as IoEncoder};
/// Error type shared by [`Decoder`], [`Encoder`] and [`Codec`].
///
/// Marked `#[non_exhaustive]` so that new failure kinds can be added without
/// breaking downstream `match` expressions.
#[derive(Debug)]
#[non_exhaustive]
pub enum Error {
    /// An I/O error. tokio-util requires codec error types to be
    /// constructible from `std::io::Error`.
    Io(IoError),
    /// A CBOR (de)serialization error reported by serde_cbor.
    Cbor(CborError),
}
impl From<IoError> for Error {
fn from(error: IoError) -> Self {
Error::Io(error)
}
}
impl From<CborError> for Error {
fn from(error: CborError) -> Self {
Error::Cbor(error)
}
}
impl Display for Error {
fn fmt(&self, fmt: &mut Formatter) -> FmtResult {
match self {
Error::Io(e) => e.fmt(fmt),
Error::Cbor(e) => e.fmt(fmt),
}
}
}
/// Exposes the wrapped I/O or CBOR error as this error's source.
///
/// This implements `source` rather than the deprecated `cause`. The standard
/// library's default `cause` forwards to `source`, so callers of `cause()` see
/// the same value as before, and `source()`-based chain walkers now see it too.
impl ErrorTrait for Error {
    fn source(&self) -> Option<&(dyn ErrorTrait + 'static)> {
        match self {
            Error::Io(e) => Some(e),
            Error::Cbor(e) => Some(e),
        }
    }
}
/// `Read` adapter that forwards to `r` and adds the number of bytes each
/// successful `read` returns to `*pos`.
///
/// Failed reads leave `*pos` unchanged.
struct Counted<'a, R: 'a> {
    r: &'a mut R,
    pos: &'a mut usize,
}

impl<'a, R: Read> Read for Counted<'a, R> {
    fn read(&mut self, buf: &mut [u8]) -> IoResult<usize> {
        let n = self.r.read(buf)?;
        *self.pos += n;
        Ok(n)
    }
}
/// Stateless CBOR decoder that yields values of type `Item`.
///
/// `Item` only appears inside `PhantomData<fn() -> Item>`, so no `Item` value
/// is ever stored.
pub struct Decoder<Item> {
    _data: PhantomData<fn() -> Item>,
}

// Written by hand rather than derived: `#[derive(Clone, Debug)]` would require
// `Item: Clone` / `Item: Debug` even though no `Item` is stored. The output is
// identical to what the derive would produce.
impl<Item> Clone for Decoder<Item> {
    fn clone(&self) -> Self {
        Decoder { _data: PhantomData }
    }
}

impl<Item> std::fmt::Debug for Decoder<Item> {
    fn fmt(&self, fmt: &mut Formatter) -> FmtResult {
        fmt.debug_struct("Decoder")
            .field("_data", &self._data)
            .finish()
    }
}
impl<'de, Item: Deserialize<'de>> Decoder<Item> {
pub fn new() -> Self {
Self { _data: PhantomData }
}
}
impl<'de, Item: Deserialize<'de>> Default for Decoder<Item> {
fn default() -> Self {
Self::new()
}
}
impl<'de, Item: Deserialize<'de>> IoDecoder for Decoder<Item> {
type Item = Item;
type Error = Error;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Item>, Error> {
let mut pos = 0;
let result = {
let mut slice: &[u8] = src;
let reader = Counted {
r: &mut slice,
pos: &mut pos,
};
let reader = IoRead::new(reader);
let mut deserializer = Deserializer::new(reader);
Item::deserialize(&mut deserializer)
};
match result {
Ok(item) => {
src.advance(pos);
Ok(Some(item))
}
Err(ref error) if error.is_eof() => Ok(None),
Err(e) => Err(e.into()),
}
}
}
/// Controls when an [`Encoder`] writes a CBOR self-describe tag in front of
/// an encoded item.
///
/// This is a plain fieldless enum, so it is `Copy` and `Hash` in addition to
/// the original `Clone`, `Debug`, `Eq` and `PartialEq`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum SdMode {
    /// Write the tag before every encoded item.
    Always,
    /// Write the tag before the first item encoded in this mode; the encoder
    /// then switches itself to [`SdMode::Never`].
    Once,
    /// Never write the tag.
    Never,
}
#[derive(Clone, Debug)]
pub struct Encoder<Item> {
_data: PhantomData<fn(Item)>,
sd: SdMode,
packed: bool,
}
impl<Item: Serialize> Encoder<Item> {
pub fn new() -> Self {
Self {
_data: PhantomData,
sd: SdMode::Never,
packed: false,
}
}
pub fn sd(self, sd: SdMode) -> Self {
Self { sd, ..self }
}
pub fn packed(self, packed: bool) -> Self {
Self { packed, ..self }
}
}
impl<Item: Serialize> Default for Encoder<Item> {
fn default() -> Self {
Self::new()
}
}
impl<Item: Serialize> IoEncoder<Item> for Encoder<Item> {
    type Error = Error;

    /// Serializes `item` as CBOR and appends it to `dst`.
    ///
    /// A self-describe tag is written first unless the mode is
    /// [`SdMode::Never`]. If serialization fails, `dst` is truncated back to
    /// its length on entry, so no partial frame (or orphaned tag) stays in the
    /// outgoing buffer. An [`SdMode::Once`] encoder only switches to
    /// [`SdMode::Never`] after a tagged frame has been written successfully.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Cbor`] when serde_cbor fails to write the tag or the
    /// item.
    fn encode(&mut self, item: Item, dst: &mut BytesMut) -> Result<(), Error> {
        let start = dst.len();
        let result = {
            // Reborrow so `dst` is usable again once the serializer is dropped.
            let writer = IoWrite::new((&mut *dst).writer());
            let mut serializer = if self.packed {
                Serializer::new(writer).packed_format()
            } else {
                Serializer::new(writer)
            };
            let tagged = if self.sd != SdMode::Never {
                serializer.self_describe()
            } else {
                Ok(())
            };
            tagged.and_then(|()| item.serialize(&mut serializer))
        };
        match result {
            Ok(()) => {
                if self.sd == SdMode::Once {
                    self.sd = SdMode::Never;
                }
                Ok(())
            }
            Err(error) => {
                // Drop whatever was partially written for this item.
                dst.truncate(start);
                Err(error.into())
            }
        }
    }
}
#[derive(Clone, Debug)]
pub struct Codec<Dec, Enc> {
dec: Decoder<Dec>,
enc: Encoder<Enc>,
}
impl<'de, Dec: Deserialize<'de>, Enc: Serialize> Codec<Dec, Enc> {
pub fn new() -> Self {
Self {
dec: Decoder::new(),
enc: Encoder::new(),
}
}
pub fn sd(self, sd: SdMode) -> Self {
Self {
dec: self.dec,
enc: Encoder { sd, ..self.enc },
}
}
pub fn packed(self, packed: bool) -> Self {
Self {
dec: self.dec,
enc: Encoder { packed, ..self.enc },
}
}
}
impl<'de, Dec: Deserialize<'de>, Enc: Serialize> Default for Codec<Dec, Enc> {
fn default() -> Self {
Self::new()
}
}
impl<'de, Dec: Deserialize<'de>, Enc: Serialize> IoDecoder for Codec<Dec, Enc> {
type Item = Dec;
type Error = Error;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Dec>, Error> {
self.dec.decode(src)
}
}
impl<'de, Dec: Deserialize<'de>, Enc: Serialize> IoEncoder<Enc> for Codec<Dec, Enc> {
type Error = Error;
fn encode(&mut self, item: Enc, dst: &mut BytesMut) -> Result<(), Error> {
self.enc.encode(item, dst)
}
}
#[cfg(test)]
mod tests {
use std::collections::HashMap;
use std::sync::Arc;
use super::*;
// Payload type used by every test: a small string -> usize map.
type TestData = HashMap<String, usize>;
/// Builds the two-entry map that all encode/decode tests round-trip.
fn test_data() -> TestData {
let mut data = HashMap::new();
data.insert("hello".to_owned(), 42usize);
data.insert("world".to_owned(), 0usize);
data
}
/// Drives any decoder for `TestData` through framing edge cases.
///
/// Checks that two complete frames decode one per call, that a partial frame
/// yields `None` without consuming its bytes, that completing it allows it to
/// decode, and that bytes which do not form a valid `TestData` map produce an
/// error without consuming the buffer.
fn decode<Dec: IoDecoder<Item = TestData, Error = Error>>(dec: Dec) {
let mut decoder = dec;
let data = test_data();
let encoded = serde_cbor::to_vec(&data).unwrap();
let mut all = BytesMut::with_capacity(128);
// Two full frames followed by only the first byte of a third.
all.extend(&encoded);
all.extend(&encoded);
all.extend(&encoded[..1]);
let decoded = decoder.decode(&mut all).unwrap().unwrap();
assert_eq!(data, decoded);
let decoded = decoder.decode(&mut all).unwrap().unwrap();
assert_eq!(data, decoded);
// Only the lone byte of the third frame remains.
assert_eq!(1, all.len());
// Incomplete frame: no item, and the byte must not be consumed.
assert!(decoder.decode(&mut all).unwrap().is_none());
assert_eq!(1, all.len());
// Supply the rest of the third frame; now it decodes and drains the buffer.
all.extend(&encoded[1..]);
let decoded = decoder.decode(&mut all).unwrap().unwrap();
assert_eq!(data, decoded);
assert!(all.is_empty());
// Bytes that are not a valid TestData map: error, buffer left intact.
all.extend(&[0, 1, 2, 3, 4]);
decoder.decode(&mut all).unwrap_err();
assert_eq!(5, all.len());
}
#[test]
fn decode_only() {
let decoder = Decoder::new();
decode(decoder);
}
#[test]
fn decode_codec() {
let decoder: Codec<_, ()> = Codec::new();
decode(decoder);
}
/// Drives any encoder for `TestData` configured with `SdMode::Once`.
///
/// Each frame must decode back to the input. The first frame carries the
/// self-describe tag, so it is longer than later frames, while the second and
/// third frames have identical lengths.
fn encode<Enc: IoEncoder<TestData, Error = Error>>(enc: Enc) {
let mut encoder = enc;
let data = test_data();
let mut buffer = BytesMut::with_capacity(0);
encoder.encode(data.clone(), &mut buffer).unwrap();
let pos1 = buffer.len();
let decoded = serde_cbor::from_slice::<TestData>(&buffer).unwrap();
assert_eq!(data, decoded);
encoder.encode(data.clone(), &mut buffer).unwrap();
let pos2 = buffer.len();
assert!(pos2 > pos1);
// Second frame is strictly shorter than the first (no tag this time).
assert!(pos1 * 2 > pos2);
let decoded = serde_cbor::from_slice::<TestData>(&buffer[pos1..]).unwrap();
assert_eq!(data, decoded);
encoder.encode(data, &mut buffer).unwrap();
let pos3 = buffer.len();
// Untagged frames for the same data have the same length.
assert_eq!(pos2 - pos1, pos3 - pos2);
}
#[test]
fn encode_only() {
let encoder = Encoder::new().sd(SdMode::Once);
encode(encoder);
}
#[test]
fn encode_packed() {
let encoder = Encoder::new().packed(true).sd(SdMode::Once);
encode(encoder);
}
#[test]
fn encode_codec() {
let encoder: Codec<(), _> = Codec::new().sd(SdMode::Once);
encode(encoder);
}
// Moving the codec into a spawned thread only compiles if Codec is Send.
#[test]
fn is_send() {
let codec: Codec<(), ()> = Codec::new();
std::thread::spawn(move || {
let _c = codec;
});
}
// Arc<T> is Send only when T is Send + Sync, so this checks Codec is Sync.
#[test]
fn is_sync() {
let codec: Arc<Codec<(), ()>> = Arc::new(Codec::new());
std::thread::spawn(move || {
let _c = codec;
});
}
}