use core::marker::PhantomData;
use serde::{Deserialize, Serialize};
use crate::io::{AsyncBytesRead, AsyncBytesWrite, LocalLock};
use crate::stream::codec::{Decoder, Encoder};
use crate::stream::framing::{FrameReader, FrameWriter};
use crate::stream::routing::{
MuxedReplyToken, MuxedSlots, ReplyRouter, RouterSlotHandle, Sequential,
};
use crate::transport::{ClientTransport, ServerTransport};
/// Abstraction over how a [`StreamTransport`] holds its reply router.
///
/// The router may be stored inline or behind a `Box`; either way callers
/// only ever see a plain `&Self::Router`.
pub trait RouterStorage {
    /// The router type this storage hands out.
    type Router: ReplyRouter;

    /// Borrows the stored router.
    fn router(&self) -> &Self::Router;
}

// Inline storage: a `Sequential` router is its own storage.
impl RouterStorage for Sequential {
    type Router = Self;

    #[inline]
    fn router(&self) -> &Sequential {
        self
    }
}

// Inline storage: a `MuxedSlots` router is its own storage.
impl<const N: usize, const BUF: usize> RouterStorage for MuxedSlots<N, BUF> {
    type Router = Self;

    #[inline]
    fn router(&self) -> &MuxedSlots<N, BUF> {
        self
    }
}

// Heap storage: borrow through the box to reach the `Sequential` router.
impl RouterStorage for Box<Sequential> {
    type Router = Sequential;

    #[inline]
    fn router(&self) -> &Sequential {
        self.as_ref()
    }
}

// Heap storage: borrow through the box to reach the `MuxedSlots` router.
impl<const N: usize, const BUF: usize> RouterStorage for Box<MuxedSlots<N, BUF>> {
    type Router = MuxedSlots<N, BUF>;

    #[inline]
    fn router(&self) -> &MuxedSlots<N, BUF> {
        self.as_ref()
    }
}
/// Message transport over a split byte stream (`R` for reading, `W` for
/// writing).
///
/// Outgoing messages are serialized with `Codec` and written as frames by
/// `Framer`. Incoming frames are read by `Framer` and deserialized with
/// `Codec`. `Router` selects the reply-routing strategy. The impls in this
/// module cover two routers:
/// - [`Sequential`]: no per-frame header.
/// - [`MuxedSlots`]: a one-byte slot-id header in front of every payload.
///
/// `S` is the storage for the router. It defaults to `Router` itself (stored
/// inline); `with_boxed_router` uses `Box<Router>`.
///
/// NOTE(review): the two impl families bind the `Req`/`Resp` parameters in
/// opposite orders. The [`ServerTransport`] impls use `(Req, Resp)`; the
/// [`ClientTransport`] impls use `(Resp, Req)`. Confirm which convention
/// callers are expected to use.
pub struct StreamTransport<R, W, Framer, Codec, Router, Req, Resp, S = Router>
where
    S: RouterStorage<Router = Router>,
{
    /// Read half of the stream. It is wrapped in `LocalLock` so methods taking
    /// `&self` can obtain `&mut R` while reading a frame.
    reader: LocalLock<R>,
    /// Write half of the stream, locked per frame write for the same reason.
    writer: LocalLock<W>,
    /// Frame reader/writer (e.g. length-prefixing in the tests).
    framer: Framer,
    /// Message encoder/decoder.
    codec: Codec,
    /// Router storage; see [`RouterStorage`].
    router: S,
    /// Marks `Router`, `Req` and `Resp` as used without storing values of them.
    _phantom: PhantomData<(Router, Req, Resp)>,
}
impl<R, W, Framer, Codec, Router, Req, Resp> StreamTransport<R, W, Framer, Codec, Router, Req, Resp>
where
    Framer: Default,
    Codec: Default,
    Router: Default + RouterStorage<Router = Router>,
{
    /// Builds a transport over `reader`/`writer`.
    ///
    /// The framer, codec and router are all created with
    /// `Default::default()`, and the router is stored inline (the `S = Router`
    /// default).
    pub fn new(reader: R, writer: W) -> Self {
        // Construct the layers in the same order the original argument list
        // evaluated them: framer, codec, router.
        let framer = Framer::default();
        let codec = Codec::default();
        let router = Router::default();
        Self::with_layers(reader, writer, framer, codec, router)
    }
}
impl<R, W, Framer, Codec, Router, Req, Resp> StreamTransport<R, W, Framer, Codec, Router, Req, Resp>
where
    Router: RouterStorage<Router = Router>,
{
    /// Builds a transport from explicitly supplied layers.
    ///
    /// `reader` and `writer` are each wrapped in a `LocalLock`. The router is
    /// stored inline.
    pub fn with_layers(reader: R, writer: W, framer: Framer, codec: Codec, router: Router) -> Self {
        let reader = LocalLock::new(reader);
        let writer = LocalLock::new(writer);
        Self {
            reader,
            writer,
            framer,
            codec,
            router,
            _phantom: PhantomData,
        }
    }
}
impl<R, W, Framer, Codec, Router, Req, Resp>
    StreamTransport<R, W, Framer, Codec, Router, Req, Resp, Box<Router>>
where
    Router: ReplyRouter,
    Box<Router>: RouterStorage<Router = Router>,
{
    /// Builds a transport whose router lives on the heap (`S = Box<Router>`).
    ///
    /// Apart from the router storage, this is the same as `with_layers`:
    /// `reader` and `writer` are each wrapped in a `LocalLock`.
    pub fn with_boxed_router(
        reader: R,
        writer: W,
        framer: Framer,
        codec: Codec,
        router: Box<Router>,
    ) -> Self {
        let (reader, writer) = (LocalLock::new(reader), LocalLock::new(writer));
        Self {
            reader,
            writer,
            framer,
            codec,
            router,
            _phantom: PhantomData,
        }
    }
}
/// Error returned by [`StreamTransport`] operations.
///
/// `F` is the framing-layer error type and `C` is the codec error type.
#[derive(Debug)]
pub enum StreamTransportError<F, C> {
    /// Reading or writing a frame failed.
    Framing(F),
    /// Encoding or decoding a message with the codec failed.
    Codec(C),
}
impl<F: core::fmt::Display, C: core::fmt::Display> core::fmt::Display
for StreamTransportError<F, C>
{
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
StreamTransportError::Framing(e) => write!(f, "{e}"),
StreamTransportError::Codec(e) => write!(f, "codec error: {e}"),
}
}
}
/// Surfaces a `Result<T, StreamTransportError<F, C>>` to callers as
/// `Result<T, crate::CallError<StreamTransportError<F, C>>>`.
///
/// Errors are wrapped in `CallError::Transport`; success values pass through
/// unchanged.
impl<T, F: core::fmt::Debug, C: core::fmt::Debug> crate::TransportResult<T>
    for StreamTransportError<F, C>
{
    type Output = Result<T, crate::CallError<StreamTransportError<F, C>>>;

    fn into_output(result: Result<T, Self>) -> Self::Output {
        match result {
            Ok(value) => Ok(value),
            Err(transport_err) => Err(crate::CallError::Transport(transport_err)),
        }
    }
}
/// Error produced when `Framer` reads a frame from a reader `R`.
///
/// This is `FrameReader::Error`, parameterised by `R`'s `AsyncBytesRead::Error`.
type FramingRE<Framer, R> = <Framer as FrameReader>::Error<<R as AsyncBytesRead>::Error>;
/// Error produced when `Framer` writes a frame to a writer `W`.
///
/// This is `FrameWriter::Error`, parameterised by `W`'s `AsyncBytesWrite::Error`.
type FramingWE<Framer, W> = <Framer as FrameWriter>::Error<<W as AsyncBytesWrite>::Error>;
/// Reply token handed out by the sequential server transport.
///
/// It carries no data. With sequential routing a reply needs no addressing,
/// so the token only records the router type at the type level.
pub struct StreamReplyToken<Router> {
    _phantom: PhantomData<Router>,
}

impl<Router> StreamReplyToken<Router> {
    /// Creates an (empty) token; only this module can mint one.
    fn new() -> Self {
        StreamReplyToken { _phantom: PhantomData }
    }
}
/// [`ClientTransport`] for the [`Sequential`] router.
///
/// Each call writes one request frame, then reads and decodes one reply frame.
///
/// NOTE(review): this impl instantiates the struct's `Req`/`Resp` type
/// parameters as `(Resp, Req)`, i.e. swapped relative to the
/// [`ServerTransport`] impls. Confirm this is the intended convention.
impl<R, W, Framer, Codec, Req, Resp, S> ClientTransport<Req, Resp>
    for StreamTransport<R, W, Framer, Codec, Sequential, Resp, Req, S>
where
    R: AsyncBytesRead,
    W: AsyncBytesWrite,
    Framer: FrameWriter + FrameReader,
    Codec: Encoder + Decoder<Error = <Codec as Encoder>::Error>,
    Req: Serialize,
    Resp: for<'de> Deserialize<'de>,
    FramingWE<Framer, W>: core::fmt::Debug,
    FramingRE<Framer, R>: core::fmt::Debug + From<FramingWE<Framer, W>>,
    <Codec as Encoder>::Error: core::fmt::Debug,
    S: RouterStorage<Router = Sequential>,
{
    type Error = StreamTransportError<FramingRE<Framer, R>, <Codec as Encoder>::Error>;

    /// Encodes `req`, writes it as one frame, then reads and decodes one
    /// reply frame.
    ///
    /// Frames carry no request id, so replies are paired with requests purely
    /// by order. To keep that pairing correct when several `call`s run
    /// concurrently on the same transport, the reader lock is acquired
    /// *before* the writer lock is released. A caller can therefore only write
    /// after the previous writer already owns the read side, so replies are
    /// read in the same order the requests were written.
    ///
    /// # Errors
    /// - `Codec` if encoding the request or decoding the reply fails.
    /// - `Framing` if writing or reading a frame fails. Write errors are
    ///   converted into the read-side framing error type.
    async fn call(&self, req: Req) -> Result<Resp, Self::Error> {
        let bytes = self
            .codec
            .encode_to_vec(&req)
            .map_err(StreamTransportError::Codec)?;
        let mut reader = {
            let mut writer = self.writer.lock().await;
            self.framer
                .write_frame(&mut *writer, &bytes)
                .await
                .map_err(|e| {
                    StreamTransportError::<FramingRE<Framer, R>, _>::Framing(
                        <FramingRE<Framer, R> as From<FramingWE<Framer, W>>>::from(e),
                    )
                })?;
            // Hand-over-hand: take the reader before `writer` is dropped at
            // the end of this block. No other caller can then write a request
            // and overtake us on the read side.
            self.reader.lock().await
        };
        let buf = self
            .framer
            .read_frame(&mut *reader)
            .await
            .map_err(StreamTransportError::Framing)?;
        // Release the read side before decoding; decoding needs no stream access.
        drop(reader);
        self.codec.decode(&buf).map_err(StreamTransportError::Codec)
    }
}
/// [`ClientTransport`] for the [`MuxedSlots`] router.
///
/// Several calls may be in flight at once. Each one is tagged with a one-byte
/// slot id in front of the codec-encoded payload, and replies are matched by
/// that id.
impl<R, W, Framer, Codec, const N: usize, const BUF: usize, Req, Resp, S> ClientTransport<Req, Resp>
    for StreamTransport<R, W, Framer, Codec, MuxedSlots<N, BUF>, Resp, Req, S>
where
    R: AsyncBytesRead,
    W: AsyncBytesWrite,
    Framer: FrameWriter + FrameReader,
    Codec: Encoder + Decoder<Error = <Codec as Encoder>::Error>,
    Req: Serialize,
    Resp: for<'de> Deserialize<'de>,
    FramingWE<Framer, W>: core::fmt::Debug,
    FramingRE<Framer, R>: core::fmt::Debug + From<FramingWE<Framer, W>>,
    <Codec as Encoder>::Error: core::fmt::Debug,
    S: RouterStorage<Router = MuxedSlots<N, BUF>>,
{
    type Error = StreamTransportError<FramingRE<Framer, R>, <Codec as Encoder>::Error>;

    /// Sends `req` under a freshly acquired slot and waits for the reply that
    /// carries the same slot id.
    ///
    /// Whichever concurrent caller holds the reader lock pulls frames off the
    /// wire. Frames addressed to other slots are handed to the router with
    /// `deliver`, and every caller polls `try_recv_slot` for its own slot.
    /// Empty frames have no slot header and are discarded.
    ///
    /// # Errors
    /// - `Codec` if encoding the request or decoding the reply fails.
    /// - `Framing` if writing or reading a frame fails.
    async fn call(&self, req: Req) -> Result<Resp, Self::Error> {
        // `match e {}` only type-checks for an uninhabited error, so slot
        // acquisition cannot fail here.
        let slot = self
            .router
            .router()
            .acquire()
            .await
            .map_err(|e| match e {})?;
        let payload = self
            .codec
            .encode_to_vec(&req)
            .map_err(StreamTransportError::Codec)?;
        // Wire format: [slot id byte][encoded payload].
        let mut frame = Vec::with_capacity(1 + payload.len());
        frame.push(slot.slot_id());
        frame.extend_from_slice(&payload);
        {
            let mut writer = self.writer.lock().await;
            self.framer
                .write_frame(&mut *writer, &frame)
                .await
                .map_err(|e| {
                    StreamTransportError::<FramingRE<Framer, R>, _>::Framing(
                        <FramingRE<Framer, R> as From<FramingWE<Framer, W>>>::from(e),
                    )
                })?;
        }
        loop {
            // Another caller may already have received and delivered our reply.
            if let Some(data) = self.router.router().try_recv_slot(slot.slot_id()) {
                return self.codec.decode(data).map_err(StreamTransportError::Codec);
            }
            let frame = {
                let mut reader = self.reader.lock().await;
                // Re-check under the lock: the previous lock holder may have
                // delivered our reply while we were waiting for the lock.
                if let Some(data) = self.router.router().try_recv_slot(slot.slot_id()) {
                    return self.codec.decode(data).map_err(StreamTransportError::Codec);
                }
                self.framer
                    .read_frame(&mut *reader)
                    .await
                    .map_err(StreamTransportError::Framing)?
            };
            // The peer controls frame contents. A frame without the one-byte
            // slot header cannot be routed to anyone, so skip it instead of
            // panicking on `frame[..1]`.
            if frame.is_empty() {
                continue;
            }
            let reply_slot_id = MuxedSlots::<N, BUF>::parse_header(&frame[..1]);
            let reply_payload = &frame[1..];
            if reply_slot_id == slot.slot_id() {
                return self
                    .codec
                    .decode(reply_payload)
                    .map_err(StreamTransportError::Codec);
            }
            // Not ours: park it for the caller that owns `reply_slot_id`.
            self.router.router().deliver(reply_slot_id, reply_payload);
        }
    }
}
/// [`ServerTransport`] for the [`Sequential`] router.
///
/// Each `recv` reads one request frame and each `reply` writes one response
/// frame. Frames carry no routing header.
impl<R, W, Framer, Codec, Req, Resp, S> ServerTransport<Req, Resp>
    for StreamTransport<R, W, Framer, Codec, Sequential, Req, Resp, S>
where
    R: AsyncBytesRead,
    W: AsyncBytesWrite,
    Framer: FrameWriter + FrameReader,
    Codec: Encoder + Decoder<Error = <Codec as Encoder>::Error>,
    Req: for<'de> Deserialize<'de>,
    Resp: Serialize,
    FramingWE<Framer, W>: core::fmt::Debug,
    FramingRE<Framer, R>: core::fmt::Debug + From<FramingWE<Framer, W>>,
    <Codec as Encoder>::Error: core::fmt::Debug,
    S: RouterStorage<Router = Sequential>,
{
    type Error = StreamTransportError<FramingRE<Framer, R>, <Codec as Encoder>::Error>;
    type ReplyToken = StreamReplyToken<Sequential>;

    /// Reads one frame and decodes it as a request.
    ///
    /// The returned token is empty because sequential replies need no address.
    async fn recv(&mut self) -> Result<(Req, Self::ReplyToken), Self::Error> {
        let mut guard = self.reader.lock().await;
        let read = self.framer.read_frame(&mut *guard).await;
        // Give the read half back before doing any decoding work.
        drop(guard);
        let frame = read.map_err(StreamTransportError::Framing)?;
        let request = self
            .codec
            .decode(&frame)
            .map_err(StreamTransportError::Codec)?;
        Ok((request, StreamReplyToken::new()))
    }

    /// Encodes `resp` and writes it as a single frame.
    async fn reply(&self, _token: Self::ReplyToken, resp: Resp) -> Result<(), Self::Error> {
        let encoded = self
            .codec
            .encode_to_vec(&resp)
            .map_err(StreamTransportError::Codec)?;
        // Write-side framing errors are surfaced as the read-side framing
        // error type, which is what `Self::Error` carries.
        let to_framing_error = |err: FramingWE<Framer, W>| {
            StreamTransportError::<FramingRE<Framer, R>, _>::Framing(
                <FramingRE<Framer, R> as From<FramingWE<Framer, W>>>::from(err),
            )
        };
        let mut guard = self.writer.lock().await;
        self.framer
            .write_frame(&mut *guard, &encoded)
            .await
            .map_err(to_framing_error)
    }
}
/// [`ServerTransport`] for the [`MuxedSlots`] router.
///
/// Every frame in both directions is `[slot id byte][encoded payload]`. The
/// slot id received with a request is echoed back on its reply via
/// [`MuxedReplyToken`].
impl<R, W, Framer, Codec, const N: usize, const BUF: usize, Req, Resp, S> ServerTransport<Req, Resp>
    for StreamTransport<R, W, Framer, Codec, MuxedSlots<N, BUF>, Req, Resp, S>
where
    R: AsyncBytesRead,
    W: AsyncBytesWrite,
    Framer: FrameWriter + FrameReader,
    Codec: Encoder + Decoder<Error = <Codec as Encoder>::Error>,
    Req: for<'de> Deserialize<'de>,
    Resp: Serialize,
    FramingWE<Framer, W>: core::fmt::Debug,
    FramingRE<Framer, R>: core::fmt::Debug + From<FramingWE<Framer, W>>,
    <Codec as Encoder>::Error: core::fmt::Debug,
    S: RouterStorage<Router = MuxedSlots<N, BUF>>,
{
    type Error = StreamTransportError<FramingRE<Framer, R>, <Codec as Encoder>::Error>;
    type ReplyToken = MuxedReplyToken;

    /// Reads the next request frame, splits off its slot-id header and
    /// decodes the payload.
    ///
    /// Empty frames have no slot header and are skipped rather than causing a
    /// panic; the peer controls frame contents.
    ///
    /// # Errors
    /// - `Framing` if reading a frame fails.
    /// - `Codec` if decoding the payload fails.
    async fn recv(&mut self) -> Result<(Req, Self::ReplyToken), Self::Error> {
        let frame = loop {
            let frame = {
                let mut reader = self.reader.lock().await;
                self.framer
                    .read_frame(&mut *reader)
                    .await
                    .map_err(StreamTransportError::Framing)?
            };
            if !frame.is_empty() {
                break frame;
            }
        };
        let slot_id = MuxedSlots::<N, BUF>::parse_header(&frame[..1]);
        let payload = &frame[1..];
        let req = self
            .codec
            .decode(payload)
            .map_err(StreamTransportError::Codec)?;
        Ok((req, MuxedReplyToken::new(slot_id)))
    }

    /// Encodes `resp` and writes it as one frame prefixed with the token's
    /// slot id.
    async fn reply(&self, token: Self::ReplyToken, resp: Resp) -> Result<(), Self::Error> {
        let payload = self
            .codec
            .encode_to_vec(&resp)
            .map_err(StreamTransportError::Codec)?;
        let mut frame = Vec::with_capacity(1 + payload.len());
        frame.push(token.slot_id());
        frame.extend_from_slice(&payload);
        let mut writer = self.writer.lock().await;
        self.framer
            .write_frame(&mut *writer, &frame)
            .await
            .map_err(|e| {
                StreamTransportError::<FramingRE<Framer, R>, _>::Framing(
                    <FramingRE<Framer, R> as From<FramingWE<Framer, W>>>::from(e),
                )
            })
    }
}
// Round-trip tests for `StreamTransport`. They are compiled only under
// `cargo test` with the `postcard` feature, because they use `PostcardCodec`.
#[cfg(all(test, feature = "postcard"))]
mod tests {
    use super::*;
    use crate::io::cursor::{cursor_read, cursor_write};
    use crate::io::mem_pipe::duplex;
    use crate::stream::routing::MuxedSlots;
    use crate::stream::{LengthPrefixed, PostcardCodec, Sequential};

    // Drives a future to completion on the current thread via futures-lite.
    fn block_on<F: core::future::Future>(fut: F) -> F::Output {
        futures_lite::future::block_on(fut)
    }

    // Sequential server round trip over in-memory cursors. The client side is
    // driven by hand (frames are written and read directly) rather than via
    // `call`; presumably `call` would wait on a reply the static cursor
    // cannot yet contain.
    #[test]
    fn sequential_round_trip_via_cursors() {
        type ServerT = StreamTransport<
            crate::io::cursor::CursorRead,
            crate::io::cursor::CursorWrite,
            LengthPrefixed,
            PostcardCodec,
            Sequential,
            u32, u32, >;
        type ClientT = StreamTransport<
            crate::io::cursor::CursorRead,
            crate::io::cursor::CursorWrite,
            LengthPrefixed,
            PostcardCodec,
            Sequential,
            u32, u32, >;
        let client: ClientT = ClientT::new(cursor_read(Vec::new()), cursor_write());
        // Write the request frame (42) straight into the client's writer.
        let mut w = block_on(client.writer.lock());
        let bytes = PostcardCodec.encode_to_vec(&42u32).unwrap();
        block_on(LengthPrefixed.write_frame(&mut *w, &bytes)).unwrap();
        // `.0` of the cursor writer appears to hold the bytes written so far;
        // it is fed to the server as its input.
        let c2s = w.0.clone();
        drop(w);
        let mut server: ServerT = ServerT::new(cursor_read(c2s), cursor_write());
        let (req, token) = block_on(server.recv()).unwrap();
        assert_eq!(req, 42u32);
        block_on(server.reply(token, 100u32)).unwrap();
        // Move the server's output into the client's reader and read the reply.
        let s2c = block_on(server.writer.lock()).0.clone();
        *block_on(client.reader.lock()) = cursor_read(s2c);
        let buf = {
            let mut r = block_on(client.reader.lock());
            block_on(LengthPrefixed.read_frame(&mut *r)).unwrap()
        };
        let reply: u32 = PostcardCodec.decode(&buf).unwrap();
        assert_eq!(reply, 100);
    }

    // Muxed server: the slot-id header of the request (3) must be echoed as
    // the first byte of the reply frame.
    #[test]
    fn muxed_server_round_trip() {
        type T = StreamTransport<
            crate::io::cursor::CursorRead,
            crate::io::cursor::CursorWrite,
            LengthPrefixed,
            PostcardCodec,
            MuxedSlots<4, 128>,
            u32,
            u32,
        >;
        let slot_id: u8 = 3;
        // Hand-build a request frame: [slot id][postcard(42u32)].
        let payload = PostcardCodec.encode_to_vec(&42u32).unwrap();
        let mut frame = Vec::with_capacity(1 + payload.len());
        frame.push(slot_id);
        frame.extend_from_slice(&payload);
        let mut c2s_writer = cursor_write();
        block_on(LengthPrefixed.write_frame(&mut c2s_writer, &frame)).unwrap();
        let c2s_bytes = c2s_writer.0;
        let mut server: T = T::new(cursor_read(c2s_bytes), cursor_write());
        let (req, token) = block_on(server.recv()).unwrap();
        assert_eq!(req, 42u32);
        assert_eq!(token.slot_id(), 3);
        block_on(server.reply(token, 100u32)).unwrap();
        // Decode the reply frame by hand and check header + payload.
        let s2c_bytes = block_on(server.writer.lock()).0.clone();
        let mut s2c_reader = cursor_read(s2c_bytes);
        let reply_frame = block_on(LengthPrefixed.read_frame(&mut s2c_reader)).unwrap();
        assert_eq!(reply_frame[0], 3u8);
        let reply: u32 = PostcardCodec.decode(&reply_frame[1..]).unwrap();
        assert_eq!(reply, 100u32);
    }

    // Muxed client `call` against a muxed server over an in-memory duplex
    // pipe. Both futures are polled concurrently with `zip`.
    #[test]
    fn muxed_client_call_end_to_end_over_mem_pipe() {
        use crate::io::mem_pipe::{PipeReader, PipeWriter};
        type ClientT = StreamTransport<
            PipeReader,
            PipeWriter,
            LengthPrefixed,
            PostcardCodec,
            MuxedSlots<4, 128>,
            u32,
            u32,
        >;
        type ServerT = StreamTransport<
            PipeReader,
            PipeWriter,
            LengthPrefixed,
            PostcardCodec,
            MuxedSlots<4, 128>,
            u32,
            u32,
        >;
        let ((r_a, w_a), (r_b, w_b)) = duplex();
        let client: ClientT = ClientT::new(r_a, w_a);
        let mut server: ServerT = ServerT::new(r_b, w_b);
        let result = block_on(async {
            // The server doubles whatever it receives.
            let server_fut = async {
                let (req, token) = server.recv().await.unwrap();
                server.reply(token, req * 2).await.unwrap();
            };
            let client_fut = async { client.call(21u32).await.unwrap() };
            let (resp, _) = futures_lite::future::zip(client_fut, server_fut).await;
            resp
        });
        assert_eq!(result, 42);
    }
}