use std::{
io::{Error, ErrorKind},
pin::Pin,
task::{Context, Poll, ready},
};
use bytes::{Bytes, BytesMut};
use futures_core::stream::TryStream;
use futures_io::{AsyncRead, AsyncWrite};
use crate::{Deserialize, IoDuplex, Serialize, Sink, Stream};
/// Factory that wraps a raw I/O object of type `Io` in a framed transport.
///
/// The produced transport has type [`Codec::Framed`], which is required to
/// implement the crate's `IoDuplex` trait.
pub trait Codec<Io> {
    /// Framed transport type returned by [`Codec::new_framed`].
    type Framed: IoDuplex;

    /// Consumes `io` and returns it wrapped in this codec's framed transport.
    fn new_framed(&self, io: Io) -> <Self as Codec<Io>>::Framed;
}
pub trait Serializer {
type Error;
fn serialize<T: Serialize>(&mut self, item: &T) -> Result<Bytes, Self::Error>;
}
pub trait Deserializer {
type Error;
fn deserialize<T: Deserialize>(&mut self, buf: &BytesMut) -> Result<T, Self::Error>;
}
#[cfg(feature = "bincode")]
mod bincode_impl {
    use super::*;
    use bincode::{deserialize, serialize};
    use tokio_util::{
        codec::{Framed as TokioFramed, LengthDelimitedCodec},
        compat::{Compat, FuturesAsyncReadCompatExt as _},
    };

    /// Codec that encodes values with `bincode::serialize` / `bincode::deserialize`
    /// and frames them with a `LengthDelimitedCodec` built from its default builder.
    ///
    /// It is a stateless unit struct, so it is `Copy` and can be handed to every
    /// connection for free.
    #[derive(Debug, Default, Clone, Copy)]
    pub struct Bincode;

    impl Serializer for Bincode {
        type Error = bincode::Error;

        /// Encodes `item` with `bincode::serialize`.
        fn serialize<T: Serialize>(&mut self, item: &T) -> Result<Bytes, Self::Error> {
            Ok(Bytes::from(serialize(item)?))
        }
    }

    impl Deserializer for Bincode {
        type Error = bincode::Error;

        /// Decodes a `T` from `buf` with `bincode::deserialize`.
        fn deserialize<T: Deserialize>(&mut self, buf: &BytesMut) -> Result<T, Self::Error> {
            deserialize(buf)
        }
    }

    impl<Io> Codec<Io> for Bincode
    where
        Io: AsyncRead + AsyncWrite + Unpin,
    {
        type Framed = Framed<TokioFramed<Compat<Io>, LengthDelimitedCodec>, Self>;

        /// Adapts the `futures_io` object to tokio's I/O traits via `compat()`,
        /// frames it with a default `LengthDelimitedCodec`, and pairs the result
        /// with this codec.
        fn new_framed(&self, io: Io) -> Self::Framed {
            Framed::new(
                LengthDelimitedCodec::builder().new_framed(io.compat()),
                // `Bincode` is `Copy`; copying avoids `clippy::clone_on_copy`.
                *self,
            )
        }
    }
}
#[cfg(feature = "bincode")]
pub use bincode_impl::Bincode;
/// Adapter that turns a byte-frame transport `T` into a value transport by
/// running every frame through the codec `C`.
///
/// The `Sink` and `Stream` implementations for this type list the bounds that
/// `T` and `C` must satisfy. `Debug` is derived, so the adapter is printable
/// whenever both parts are.
#[derive(Debug)]
pub struct Framed<T, C> {
    /// Underlying transport that carries raw byte frames.
    inner: T,
    /// Codec used to encode outgoing values and decode incoming frames.
    codec: C,
}

impl<T, C> Framed<T, C> {
    /// Creates an adapter from the byte transport `inner` and the `codec`.
    pub fn new(inner: T, codec: C) -> Self {
        Self { inner, codec }
    }

    /// Returns a shared reference to the underlying transport.
    pub fn inner(&self) -> &T {
        &self.inner
    }

    /// Returns a mutable reference to the underlying transport.
    ///
    /// Reading or writing through it directly bypasses the codec.
    pub fn inner_mut(&mut self) -> &mut T {
        &mut self.inner
    }

    /// Returns a shared reference to the codec.
    pub fn codec(&self) -> &C {
        &self.codec
    }

    /// Returns a mutable reference to the codec.
    pub fn codec_mut(&mut self) -> &mut C {
        &mut self.codec
    }

    /// Consumes the adapter and returns `(transport, codec)`.
    pub fn into_parts(self) -> (T, C) {
        let Self { inner, codec } = self;
        (inner, codec)
    }
}
impl<T, C> Sink for Framed<T, C>
where
T: futures_sink::Sink<Bytes, Error = Error> + Unpin,
C: Serializer + Unpin,
<C as Serializer>::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
type Error = Error;
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.inner).poll_ready(cx)
}
fn start_send<I: Serialize>(
mut self: std::pin::Pin<&mut Self>,
item: I,
) -> Result<(), Self::Error> {
let buf = self
.codec
.serialize(&item)
.map_err(|e| Error::new(ErrorKind::InvalidData, e))?;
Pin::new(&mut self.inner).start_send(buf)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.inner).poll_flush(cx)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.inner).poll_close(cx)
}
}
impl<T, C> Stream for Framed<T, C>
where
T: TryStream<Ok = BytesMut, Error = Error> + Unpin,
C: Deserializer + Unpin,
<C as Deserializer>::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
type Error = Error;
fn poll_next<Item: Deserialize>(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Item, Error>>> {
let Some(buf) = ready!(Pin::new(&mut self.inner).try_poll_next(cx)) else {
return Poll::Ready(None);
};
let item = self
.codec
.deserialize(&buf?)
.map_err(|e| Error::new(ErrorKind::InvalidData, e));
Poll::Ready(Some(item))
}
}
// These tests drive the `Bincode` codec, which only exists when the `bincode`
// feature is enabled; without this gate `cargo test` fails to compile.
#[cfg(all(test, feature = "bincode"))]
mod tests {
    use serde::{Deserialize, Serialize};
    use tokio::io::duplex;
    use tokio_util::compat::TokioAsyncReadCompatExt;

    use super::*;
    use crate::{SinkExt, StreamExt};

    #[derive(Serialize, Deserialize)]
    struct Ping;

    #[derive(Serialize, Deserialize)]
    struct Pong;

    /// Message with real field values, so the round trip is checked against
    /// actual data rather than unit markers.
    #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
    struct Payload {
        id: u32,
        name: String,
        data: Vec<u8>,
    }

    #[test]
    fn test_framed() {
        let (a, b) = duplex(1024);
        let mut a = Bincode.new_framed(a.compat());
        let mut b = Bincode.new_framed(b.compat());
        let a = async {
            a.send(Ping).await.unwrap();
            a.next::<Pong>().await.unwrap().unwrap();
        };
        let b = async {
            b.next::<Ping>().await.unwrap().unwrap();
            b.send(Pong).await.unwrap();
        };
        futures::executor::block_on(async {
            futures::join!(a, b);
        });
    }

    #[test]
    fn test_framed_payload_roundtrip() {
        let (a, b) = duplex(1024);
        let mut a = Bincode.new_framed(a.compat());
        let mut b = Bincode.new_framed(b.compat());
        let expected = Payload {
            id: 42,
            name: "hello".to_string(),
            data: vec![1, 2, 3],
        };
        let sent = expected.clone();
        let send = async move {
            a.send(sent).await.unwrap();
        };
        let recv = async move { b.next::<Payload>().await.unwrap().unwrap() };
        let received = futures::executor::block_on(async {
            let (_, received) = futures::join!(send, recv);
            received
        });
        assert_eq!(received, expected);
    }
}