use {
bytes::{Bytes, BytesMut},
futures::{prelude::*, ready},
pin_project::pin_project,
std::{
marker::PhantomData,
pin::Pin,
task::{Context, Poll},
},
};
/// Encodes values of type `T` into a contiguous byte buffer.
///
/// The receiver is taken as `Pin<&mut Self>`, which matches how `FramedWrite`
/// reaches its structurally pinned `serializer` field through a projection.
pub trait Serializer<T> {
    /// Encodes `item`, returning the encoded bytes or an implementor-defined error.
    fn serialize(self: Pin<&mut Self>, item: &T) -> Result<Bytes, Self::Error>;

    /// Error reported when `item` cannot be encoded.
    type Error;
}
/// Decodes a value of type `T` from a byte buffer.
///
/// The receiver is taken as `Pin<&mut Self>`, which matches how `FramedRead`
/// reaches its structurally pinned `deserializer` field through a projection.
pub trait Deserializer<T> {
    /// Decodes one value from the whole of `src`, or reports an implementor-defined error.
    fn deserialize(self: Pin<&mut Self>, src: &BytesMut) -> Result<T, Self::Error>;

    /// Error reported when `src` cannot be decoded.
    type Error;
}
/// Wraps a transport `T` and decodes each buffer it yields into a `U` with the
/// deserializer `S`.
///
/// Decoding is provided by the `Stream` impl, which requires
/// `T: TryStream<Ok = BytesMut>`. When `T` is also a `Sink`, writes are
/// forwarded to it unchanged.
#[pin_project]
pub struct FramedRead<T, U, S> {
    /// The wrapped transport (structurally pinned via `#[pin]`).
    #[pin]
    inner: T,
    /// Decoder applied to every buffer yielded by `inner` (structurally pinned via `#[pin]`).
    #[pin]
    deserializer: S,
    // Records the decoded item type `U` without storing a value of it.
    item: PhantomData<U>,
}
/// Wraps a transport `T` and encodes every `U` sent into it with the
/// serializer `S`, passing the resulting `Bytes` on to `T`.
///
/// Encoding is provided by the `Sink<U>` impl, which requires `T: Sink<Bytes>`.
/// When `T` is also a `Stream`, reads are forwarded to it unchanged.
#[pin_project]
pub struct FramedWrite<T, U, S> {
    /// The wrapped transport (structurally pinned via `#[pin]`).
    #[pin]
    inner: T,
    /// Encoder applied to every item sent (structurally pinned via `#[pin]`).
    #[pin]
    serializer: S,
    // Records the encoded item type `U` without storing a value of it.
    item: PhantomData<U>,
}
impl<T, U, S> FramedRead<T, U, S> {
pub fn new(inner: T, deserializer: S) -> FramedRead<T, U, S> {
FramedRead {
inner,
deserializer,
item: PhantomData,
}
}
pub fn get_ref(&self) -> &T {
&self.inner
}
pub fn get_mut(&mut self) -> &mut T {
&mut self.inner
}
pub fn into_inner(self) -> T {
self.inner
}
}
/// Decodes every buffer produced by `inner` into a `U` with the deserializer.
///
/// A transport error from `inner` is yielded as an `Err` item. A decode error
/// is converted through `From<S::Error>` and also yielded as an `Err` item.
/// The stream ends when `inner` ends.
impl<T, U, S> Stream for FramedRead<T, U, S>
where
    T: TryStream<Ok = BytesMut>,
    T::Error: From<S::Error>,
    S: Deserializer<U>,
{
    type Item = Result<U, T::Error>;

    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let frame = match ready!(self.as_mut().project().inner.try_poll_next(cx)) {
            // A transport error is surfaced as `Poll::Ready(Some(Err(..)))` by `?`.
            Some(frame) => frame?,
            None => return Poll::Ready(None),
        };
        // Decode errors are likewise surfaced as `Some(Err(..))` via `From<S::Error>`.
        let item = self.project().deserializer.deserialize(&frame)?;
        Poll::Ready(Some(Ok(item)))
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Every inner item yields exactly one output item (a decoded value or an
        // error), so the inner stream's bounds carry over unchanged.
        self.inner.size_hint()
    }
}
/// Forwards every `Sink` operation to the wrapped transport untouched, so a
/// `FramedRead` over a transport that is also a `Sink` can be written to directly.
impl<T, U, S, SinkItem> Sink<SinkItem> for FramedRead<T, U, S>
where
    T: Sink<SinkItem>,
{
    type Error = T::Error;

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = self.project();
        this.inner.poll_ready(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
        let this = self.project();
        this.inner.start_send(item)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = self.project();
        this.inner.poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = self.project();
        this.inner.poll_close(cx)
    }
}
impl<T, U, S> FramedWrite<T, U, S> {
    /// Builds a writer that encodes every item sent through it with `serializer`.
    pub fn new(inner: T, serializer: S) -> Self {
        let item = PhantomData;
        FramedWrite {
            inner,
            serializer,
            item,
        }
    }

    /// Borrows the wrapped transport.
    pub fn get_ref(&self) -> &T {
        let Self { inner, .. } = self;
        inner
    }

    /// Mutably borrows the wrapped transport.
    pub fn get_mut(&mut self) -> &mut T {
        let Self { inner, .. } = self;
        inner
    }

    /// Consumes the writer and returns the wrapped transport; the serializer is dropped.
    pub fn into_inner(self) -> T {
        let Self { inner, .. } = self;
        inner
    }
}
/// Passes the read half of the wrapped transport through untouched, so a
/// `FramedWrite` over a transport that is also a `Stream` can still be read from.
impl<T, U, S> Stream for FramedWrite<T, U, S>
where
    T: Stream,
{
    type Item = T::Item;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        self.project().inner.poll_next(cx)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Items are forwarded one-to-one, so the inner stream's bounds apply unchanged
        // (the default implementation would report only `(0, None)`).
        self.inner.size_hint()
    }
}
/// Encodes each `U` with the serializer and sends the resulting `Bytes` into
/// the wrapped transport. Serializer errors are converted into `T::Error` via `Into`.
impl<T, U, S> Sink<U> for FramedWrite<T, U, S>
where
    T: Sink<Bytes>,
    S: Serializer<U>,
    S::Error: Into<T::Error>,
{
    type Error = T::Error;

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = self.project();
        this.inner.poll_ready(cx)
    }

    fn start_send(self: Pin<&mut Self>, item: U) -> Result<(), Self::Error> {
        // One projection gives disjoint access to both the serializer and the transport.
        let this = self.project();
        let frame = this.serializer.serialize(&item).map_err(Into::into)?;
        this.inner.start_send(frame)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let this = self.project();
        this.inner.poll_flush(cx)
    }

    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        // Flush the transport first; only close it once the flush has completed successfully.
        if let Err(e) = ready!(self.as_mut().poll_flush(cx)) {
            return Poll::Ready(Err(e));
        }
        self.project().inner.poll_close(cx)
    }
}
/// Ready-made serde-backed implementations of the `Serializer` and
/// `Deserializer` traits. Each format is compiled only when its cargo feature
/// is enabled.
#[cfg(any(feature = "json", feature = "bincode", feature = "messagepack"))]
pub mod formats {
    #[cfg(feature = "bincode")]
    pub use self::bincode::Bincode;
    #[cfg(feature = "json")]
    pub use self::json::Json;
    #[cfg(feature = "messagepack")]
    pub use self::messagepack::MessagePack;

    use {
        super::{Deserializer, Serializer},
        bytes::{Bytes, BytesMut},
        derivative::Derivative,
        serde::{Deserialize, Serialize},
        std::{io, marker::PhantomData, pin::Pin},
    };

    #[cfg(feature = "bincode")]
    mod bincode {
        use super::*;

        /// Encodes and decodes values of type `T` with bincode.
        ///
        /// `Default` is derived without requiring `T: Default`. Failures in both
        /// directions are reported as `io::Error` with kind `InvalidData`.
        #[derive(Derivative)]
        #[derivative(Default(bound = ""))]
        pub struct Bincode<T> {
            ghost: PhantomData<T>,
        }

        impl<T> Deserializer<T> for Bincode<T>
        where
            T: for<'de> Deserialize<'de>,
        {
            type Error = io::Error;

            fn deserialize(self: Pin<&mut Self>, src: &BytesMut) -> Result<T, Self::Error> {
                serde_bincode::deserialize(src)
                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
            }
        }

        impl<T> Serializer<T> for Bincode<T>
        where
            T: Serialize,
        {
            type Error = io::Error;

            fn serialize(self: Pin<&mut Self>, item: &T) -> Result<Bytes, Self::Error> {
                serde_bincode::serialize(item)
                    .map(Bytes::from)
                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
            }
        }
    }

    #[cfg(feature = "json")]
    mod json {
        use super::*;

        /// Encodes and decodes values of type `T` as JSON via serde_json.
        ///
        /// `Default` is derived without requiring `T: Default`. Errors are
        /// reported as `serde_json::Error`.
        #[derive(Derivative)]
        #[derivative(Default(bound = ""))]
        pub struct Json<T> {
            ghost: PhantomData<T>,
        }

        impl<T> Deserializer<T> for Json<T>
        where
            for<'a> T: Deserialize<'a>,
        {
            type Error = serde_json::Error;

            fn deserialize(self: Pin<&mut Self>, src: &BytesMut) -> Result<T, Self::Error> {
                // The buffer is already fully in memory. serde_json documents
                // `from_slice` as faster than streaming through `from_reader`.
                serde_json::from_slice(src)
            }
        }

        impl<T> Serializer<T> for Json<T>
        where
            T: Serialize,
        {
            type Error = serde_json::Error;

            fn serialize(self: Pin<&mut Self>, item: &T) -> Result<Bytes, Self::Error> {
                serde_json::to_vec(item).map(Bytes::from)
            }
        }
    }

    #[cfg(feature = "messagepack")]
    mod messagepack {
        use super::*;

        /// Encodes and decodes values of type `T` as MessagePack.
        ///
        /// `Default` is derived without requiring `T: Default`. Failures in both
        /// directions are reported as `io::Error` with kind `InvalidData`.
        #[derive(Derivative)]
        #[derivative(Default(bound = ""))]
        pub struct MessagePack<T> {
            ghost: PhantomData<T>,
        }

        impl<T> Deserializer<T> for MessagePack<T>
        where
            for<'a> T: Deserialize<'a>,
        {
            type Error = io::Error;

            fn deserialize(self: Pin<&mut Self>, src: &BytesMut) -> Result<T, Self::Error> {
                // `&[u8]` implements `io::Read`, so no cursor or reader adapter is needed.
                serde_messagepack::from_read(&src[..])
                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
            }
        }

        impl<T> Serializer<T> for MessagePack<T>
        where
            T: Serialize,
        {
            type Error = io::Error;

            fn serialize(self: Pin<&mut Self>, item: &T) -> Result<Bytes, Self::Error> {
                serde_messagepack::to_vec(item)
                    .map(Bytes::from)
                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
            }
        }
    }
}