use core::marker::PhantomData;
use std::collections::HashMap;
use std::collections::VecDeque;
use std::sync::Arc;
use atomic_waker::AtomicWaker;
use serde::{Deserialize, Serialize};
use crate::io::{AsyncBytesRead, AsyncBytesWrite, LocalLock};
use crate::stream::codec::{Decoder, Encoder};
use crate::stream::framing::{FrameReader, FrameWriter};
use crate::stream::routing::{MuxedReplyToken, MuxedSlots, ReplyRouter, RouterSlotHandle};
use crate::stream::transport::StreamTransportError;
use crate::transport::{ClientTransport, ServerTransport};
/// Value of the first header byte that marks a frame as a request.
pub const KIND_REQUEST: u8 = 0;
/// Value of the first header byte that marks a frame as a response.
pub const KIND_RESPONSE: u8 = 1;
/// Size in bytes of the duplex header: kind (1) + api_id (2, little-endian) + slot_id (1).
pub const DUPLEX_HEADER_LEN: usize = 4;
/// Serializes a duplex frame: the 4-byte header followed by `payload`.
///
/// Header layout is `[kind, api_id low byte, api_id high byte, slot_id]`;
/// `api_id` is written little-endian. The returned vector is allocated with
/// exactly `DUPLEX_HEADER_LEN + payload.len()` bytes of capacity.
pub fn encode_duplex_frame(kind: u8, api_id: u16, slot_id: u8, payload: &[u8]) -> Vec<u8> {
    let [api_lo, api_hi] = api_id.to_le_bytes();
    let header = [kind, api_lo, api_hi, slot_id];

    let mut frame = Vec::with_capacity(DUPLEX_HEADER_LEN + payload.len());
    frame.extend_from_slice(&header);
    frame.extend_from_slice(payload);
    frame
}
/// Decoded form of the fixed-size header that precedes every duplex frame payload.
#[derive(Debug, Clone, Copy)]
pub struct DuplexHeader {
/// Frame kind tag; `parse_duplex_frame` only yields `KIND_REQUEST` or `KIND_RESPONSE`.
pub kind: u8,
/// API identifier, carried on the wire as two little-endian bytes.
pub api_id: u16,
/// Slot identifier, copied verbatim from the fourth header byte.
pub slot_id: u8,
}
/// Reasons a raw frame cannot be interpreted as a duplex frame.
#[derive(Debug)]
pub enum DuplexFrameError {
    /// The frame holds fewer bytes than the fixed duplex header.
    TooShort,
    /// The first header byte is neither the request nor the response tag.
    UnknownKind(u8),
}

impl core::fmt::Display for DuplexFrameError {
    /// Writes a short human-readable description of the error.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match *self {
            Self::TooShort => f.write_str("duplex frame too short for header"),
            Self::UnknownKind(kind) => write!(f, "unknown duplex frame kind: {kind}"),
        }
    }
}
pub fn parse_duplex_frame(frame: &[u8]) -> Result<(DuplexHeader, &[u8]), DuplexFrameError> {
if frame.len() < DUPLEX_HEADER_LEN {
return Err(DuplexFrameError::TooShort);
}
let kind = frame[0];
if kind != KIND_REQUEST && kind != KIND_RESPONSE {
return Err(DuplexFrameError::UnknownKind(kind));
}
let api_id = u16::from_le_bytes([frame[1], frame[2]]);
let slot_id = frame[3];
Ok((
DuplexHeader {
kind,
api_id,
slot_id,
},
&frame[DUPLEX_HEADER_LEN..],
))
}
/// State shared by every clone of a [`ServerInbox`].
struct InboxInner {
    /// Pending requests as `(payload, slot_id)` pairs, oldest first.
    queue: std::sync::Mutex<VecDeque<(Vec<u8>, u8)>>,
    /// Waker registered by the task currently parked in `recv`.
    waker: AtomicWaker,
    /// Set to `true` by `close`; never reset.
    closed: std::sync::Mutex<bool>,
}

/// Queue of incoming request payloads for one API id.
///
/// Clones share the same underlying queue. A single waker is stored, so the
/// inbox is presumably meant to have one task awaiting `recv` at a time.
#[derive(Clone)]
pub struct ServerInbox {
    inner: Arc<InboxInner>,
}

impl ServerInbox {
    /// Creates an empty, open inbox.
    fn new() -> Self {
        Self {
            inner: Arc::new(InboxInner {
                queue: std::sync::Mutex::new(VecDeque::new()),
                waker: AtomicWaker::new(),
                closed: std::sync::Mutex::new(false),
            }),
        }
    }

    /// Appends a request to the queue and wakes the registered receiver, if any.
    fn push(&self, payload: Vec<u8>, slot_id: u8) {
        self.inner
            .queue
            .lock()
            .unwrap()
            .push_back((payload, slot_id));
        self.inner.waker.wake();
    }

    /// Marks the inbox closed and wakes the registered receiver, if any.
    ///
    /// Requests that are already queued remain available to `recv`.
    fn close(&self) {
        *self.inner.closed.lock().unwrap() = true;
        self.inner.waker.wake();
    }

    /// Removes and returns the oldest queued request, if there is one.
    fn pop(&self) -> Option<(Vec<u8>, u8)> {
        self.inner.queue.lock().unwrap().pop_front()
    }

    /// Waits for the next request.
    ///
    /// Returns `Some((payload, slot_id))` for each queued request in FIFO
    /// order, and `None` only once the inbox is closed and the queue is empty.
    async fn recv(&self) -> Option<(Vec<u8>, u8)> {
        core::future::poll_fn(|cx| {
            // Fast path: something is already queued.
            if let Some(item) = self.pop() {
                return core::task::Poll::Ready(Some(item));
            }

            // Register before the final checks so a concurrent `push`/`close`
            // that happens after them still wakes this task.
            self.inner.waker.register(cx.waker());

            // Sample `closed` BEFORE the final queue check. A producer pushes
            // and only then closes, so observing `closed == true` here
            // guarantees every pushed item is visible to the `pop` below.
            // Checking in the opposite order could report "closed" while a
            // request that raced in between the two checks is still queued.
            let closed = *self.inner.closed.lock().unwrap();
            match self.pop() {
                Some(item) => core::task::Poll::Ready(Some(item)),
                None if closed => core::task::Poll::Ready(None),
                None => core::task::Poll::Pending,
            }
        })
        .await
    }
}
/// State shared by the pump and by every client/server half of one duplex stream.
pub struct DuplexShared<W, Framer, Codec, const N: usize, const BUF: usize> {
    /// Write side of the stream, locked around each frame write.
    writer: LocalLock<W>,
    /// Heap-allocated slot table used to pair responses with pending calls.
    slots: Box<MuxedSlots<N, BUF>>,
    /// Request inboxes keyed by API id.
    inboxes: std::sync::Mutex<HashMap<u16, ServerInbox>>,
    /// Framing layer used for both reading and writing frames.
    framer: Framer,
    /// Codec used to encode and decode message bodies.
    codec: Codec,
}

impl<W, Framer, Codec, const N: usize, const BUF: usize> DuplexShared<W, Framer, Codec, N, BUF> {
    /// Returns the inbox for `api_id`, creating and storing it on first use.
    ///
    /// Repeated calls with the same id yield clones of the same inbox.
    fn register_inbox(&self, api_id: u16) -> ServerInbox {
        self.inboxes
            .lock()
            .unwrap()
            .entry(api_id)
            .or_insert_with(ServerInbox::new)
            .clone()
    }

    /// Closes every registered inbox.
    fn close_inboxes(&self) {
        self.inboxes
            .lock()
            .unwrap()
            .values()
            .for_each(ServerInbox::close);
    }
}
/// Un-split duplex transport: the read side of the stream plus the state
/// shared with every half created from it.
pub struct DuplexStreamTransport<R, W, Framer, Codec, const N: usize, const BUF: usize> {
    reader: R,
    shared: Arc<DuplexShared<W, Framer, Codec, N, BUF>>,
}

impl<R, W, Framer, Codec, const N: usize, const BUF: usize>
    DuplexStreamTransport<R, W, Framer, Codec, N, BUF>
where
    Framer: Default,
    Codec: Default,
{
    /// Builds a transport using the default framer and the default codec.
    pub fn new(reader: R, writer: W) -> Self {
        let framer = Framer::default();
        let codec = Codec::default();
        Self::with_layers(reader, writer, framer, codec)
    }
}

impl<R, W, Framer, Codec, const N: usize, const BUF: usize>
    DuplexStreamTransport<R, W, Framer, Codec, N, BUF>
{
    /// Builds a transport from explicit framer and codec instances.
    ///
    /// The writer is wrapped in a lock, the slot table is created boxed, and
    /// the inbox map starts out empty.
    pub fn with_layers(reader: R, writer: W, framer: Framer, codec: Codec) -> Self {
        let shared = DuplexShared {
            writer: LocalLock::new(writer),
            slots: MuxedSlots::new_boxed(),
            inboxes: std::sync::Mutex::new(HashMap::new()),
            framer,
            codec,
        };
        Self {
            reader,
            shared: Arc::new(shared),
        }
    }

    /// Creates a server half for `api_id`, registering its inbox (or reusing
    /// the one already registered for that id).
    pub fn server_half<Req, Resp>(
        &self,
        api_id: u16,
    ) -> DuplexServerHalf<W, Framer, Codec, N, BUF, Req, Resp> {
        DuplexServerHalf {
            inbox: self.shared.register_inbox(api_id),
            shared: Arc::clone(&self.shared),
            api_id,
            _phantom: PhantomData,
        }
    }

    /// Creates a client half that issues requests tagged with `api_id`.
    ///
    /// No inbox is registered for a client half.
    pub fn client_half<Req, Resp>(
        &self,
        api_id: u16,
    ) -> DuplexClientHalf<W, Framer, Codec, N, BUF, Req, Resp> {
        DuplexClientHalf {
            shared: Arc::clone(&self.shared),
            api_id,
            _phantom: PhantomData,
        }
    }

    /// Consumes the transport, yielding the pump that owns the reader and a
    /// handle that can keep creating halves on the same shared state.
    #[allow(clippy::type_complexity)]
    pub fn split(
        self,
    ) -> (
        DuplexPump<R, W, Framer, Codec, N, BUF>,
        DuplexHandle<W, Framer, Codec, N, BUF>,
    ) {
        let Self { reader, shared } = self;
        let handle = DuplexHandle {
            shared: Arc::clone(&shared),
        };
        (DuplexPump { reader, shared }, handle)
    }
}
/// Cloneable handle to the shared duplex state; creates client and server halves.
pub struct DuplexHandle<W, Framer, Codec, const N: usize, const BUF: usize> {
    shared: Arc<DuplexShared<W, Framer, Codec, N, BUF>>,
}

// Implemented by hand so that cloning only requires cloning the `Arc`, without
// placing `Clone` bounds on `W`, `Framer` or `Codec`.
impl<W, Framer, Codec, const N: usize, const BUF: usize> Clone
    for DuplexHandle<W, Framer, Codec, N, BUF>
{
    fn clone(&self) -> Self {
        let shared = Arc::clone(&self.shared);
        Self { shared }
    }
}

impl<W, Framer, Codec, const N: usize, const BUF: usize> DuplexHandle<W, Framer, Codec, N, BUF> {
    /// Creates a server half for `api_id`, registering its inbox (or reusing
    /// the one already registered for that id).
    pub fn server_half<Req, Resp>(
        &self,
        api_id: u16,
    ) -> DuplexServerHalf<W, Framer, Codec, N, BUF, Req, Resp> {
        DuplexServerHalf {
            inbox: self.shared.register_inbox(api_id),
            shared: Arc::clone(&self.shared),
            api_id,
            _phantom: PhantomData,
        }
    }

    /// Creates a client half that issues requests tagged with `api_id`.
    pub fn client_half<Req, Resp>(
        &self,
        api_id: u16,
    ) -> DuplexClientHalf<W, Framer, Codec, N, BUF, Req, Resp> {
        DuplexClientHalf {
            shared: Arc::clone(&self.shared),
            api_id,
            _phantom: PhantomData,
        }
    }
}
/// Read-side driver produced by `DuplexStreamTransport::split`.
///
/// Owns the reader and routes incoming frames: responses go to the shared slot
/// table, requests go to the inbox registered for their API id.
pub struct DuplexPump<R, W, Framer, Codec, const N: usize, const BUF: usize> {
// Read side of the stream; only the pump reads from it.
reader: R,
// State shared with every client/server half and handle.
shared: Arc<DuplexShared<W, Framer, Codec, N, BUF>>,
}
/// Error that terminates the pump loop.
#[derive(Debug)]
pub enum DuplexPumpError<F> {
    /// The framing layer failed to read a frame.
    Framing(F),
    /// A frame was read but its duplex header was invalid.
    BadFrame(DuplexFrameError),
}

impl<F: core::fmt::Display> core::fmt::Display for DuplexPumpError<F> {
    /// Prints the wrapped error's own message with no added prefix.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let inner: &dyn core::fmt::Display = match self {
            Self::Framing(err) => err,
            Self::BadFrame(err) => err,
        };
        write!(f, "{inner}")
    }
}
impl<R, W, Framer, Codec, const N: usize, const BUF: usize> DuplexPump<R, W, Framer, Codec, N, BUF>
where
    R: AsyncBytesRead,
    Framer: FrameReader,
{
    /// Drives the read loop until it fails, then closes every registered inbox.
    ///
    /// The inboxes are closed whatever the outcome, so server halves waiting
    /// in `recv` observe the end of the stream. The loop's result is returned
    /// unchanged.
    pub async fn run(
        mut self,
    ) -> Result<(), DuplexPumpError<<Framer as FrameReader>::Error<R::Error>>> {
        let outcome = self.run_inner().await;
        self.shared.close_inboxes();
        outcome
    }

    /// Reads frames forever, routing each by its header kind.
    ///
    /// Only returns by propagating a framing error or an invalid duplex header.
    async fn run_inner(
        &mut self,
    ) -> Result<(), DuplexPumpError<<Framer as FrameReader>::Error<R::Error>>> {
        loop {
            let frame = self
                .shared
                .framer
                .read_frame(&mut self.reader)
                .await
                .map_err(DuplexPumpError::Framing)?;
            let (hdr, payload) = parse_duplex_frame(&frame).map_err(DuplexPumpError::BadFrame)?;

            match hdr.kind {
                KIND_RESPONSE => {
                    // Hand the reply to whichever call is waiting on this slot.
                    self.shared.slots.deliver(hdr.slot_id, payload);
                }
                KIND_REQUEST => {
                    // The map guard is a temporary of this statement, so the
                    // lock is released before the inbox is touched.
                    let target = self
                        .shared
                        .inboxes
                        .lock()
                        .unwrap()
                        .get(&hdr.api_id)
                        .cloned();
                    if let Some(inbox) = target {
                        inbox.push(payload.to_vec(), hdr.slot_id);
                    } else {
                        // No server half registered for this API: drop the request.
                        #[cfg(debug_assertions)]
                        eprintln!(
                            "duplex pump: dropping request for unknown api_id 0x{:04x}",
                            hdr.api_id
                        );
                    }
                }
                _ => unreachable!("parse_duplex_frame validates kind"),
            }
        }
    }
}
/// Client endpoint for one API id on a shared duplex stream.
pub struct DuplexClientHalf<W, Framer, Codec, const N: usize, const BUF: usize, Req, Resp> {
    shared: Arc<DuplexShared<W, Framer, Codec, N, BUF>>,
    api_id: u16,
    _phantom: PhantomData<(Req, Resp)>,
}

impl<W, Framer, Codec, const N: usize, const BUF: usize, Req, Resp> ClientTransport<Req, Resp>
    for DuplexClientHalf<W, Framer, Codec, N, BUF, Req, Resp>
where
    W: AsyncBytesWrite,
    Framer: FrameWriter,
    Codec: Encoder + Decoder<Error = <Codec as Encoder>::Error>,
    Req: Serialize,
    Resp: for<'de> Deserialize<'de>,
    <Framer as FrameWriter>::Error<W::Error>: core::fmt::Debug,
    <Codec as Encoder>::Error: core::fmt::Debug,
{
    type Error =
        StreamTransportError<<Framer as FrameWriter>::Error<W::Error>, <Codec as Encoder>::Error>;

    /// Sends `req` as a request frame and waits for the matching response.
    ///
    /// Steps, in order: acquire a reply slot, encode the request, write one
    /// request frame tagged with this half's API id and the slot id, then
    /// await the reply on that slot and decode it.
    ///
    /// # Errors
    ///
    /// Encoding and decoding failures map to `StreamTransportError::Codec`;
    /// a failed frame write maps to `StreamTransportError::Framing`.
    async fn call(&self, req: Req) -> Result<Resp, Self::Error> {
        let shared = &*self.shared;

        // The acquire error type is uninhabited, hence the empty match.
        let slot = shared.slots.acquire().await.map_err(|e| match e {})?;

        let body = shared
            .codec
            .encode_to_vec(&req)
            .map_err(StreamTransportError::Codec)?;
        let frame = encode_duplex_frame(KIND_REQUEST, self.api_id, slot.slot_id(), &body);

        {
            // Scope the writer guard so it is dropped before waiting for the
            // reply rather than being held across that await.
            let mut writer = shared.writer.lock().await;
            shared
                .framer
                .write_frame(&mut *writer, &frame)
                .await
                .map_err(StreamTransportError::Framing)?;
        }

        let reply = slot.recv_reply().await;
        shared
            .codec
            .decode(reply)
            .map_err(StreamTransportError::Codec)
    }
}
/// Server endpoint for one API id on a shared duplex stream.
///
/// Receives requests from the inbox registered for `api_id` and writes
/// response frames through the shared writer.
pub struct DuplexServerHalf<W, Framer, Codec, const N: usize, const BUF: usize, Req, Resp> {
// State shared with the pump and the other halves.
shared: Arc<DuplexShared<W, Framer, Codec, N, BUF>>,
// API id this half serves; also written into the header of its response frames.
api_id: u16,
// Queue the pump pushes this API's request payloads into.
inbox: ServerInbox,
// Ties the half to its request/response types without storing values of them.
_phantom: PhantomData<(Req, Resp)>,
}
/// Marker error stating that the duplex stream has been closed.
#[derive(Debug)]
pub struct DuplexStreamClosed;

impl core::fmt::Display for DuplexStreamClosed {
    /// Writes the fixed message `duplex stream closed`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(f, "duplex stream closed")
    }
}
impl<W, Framer, Codec, const N: usize, const BUF: usize, Req, Resp> ServerTransport<Req, Resp>
    for DuplexServerHalf<W, Framer, Codec, N, BUF, Req, Resp>
where
    W: AsyncBytesWrite,
    Framer: FrameWriter,
    Codec: Encoder + Decoder<Error = <Codec as Encoder>::Error>,
    Req: for<'de> Deserialize<'de>,
    Resp: Serialize,
    <Framer as FrameWriter>::Error<W::Error>: core::fmt::Debug,
    <Codec as Encoder>::Error: core::fmt::Debug,
{
    type Error =
        DuplexServerError<<Framer as FrameWriter>::Error<W::Error>, <Codec as Encoder>::Error>;
    type ReplyToken = MuxedReplyToken;

    /// Waits for the next request addressed to this half's API id.
    ///
    /// Returns the decoded request together with a token carrying the slot id
    /// the reply must be sent back on.
    ///
    /// # Errors
    ///
    /// `DuplexServerError::Closed` once the inbox reports closure, or
    /// `DuplexServerError::Codec` when the payload cannot be decoded.
    async fn recv(&mut self) -> Result<(Req, Self::ReplyToken), Self::Error> {
        let Some((payload, slot_id)) = self.inbox.recv().await else {
            return Err(DuplexServerError::Closed);
        };
        let request = self
            .shared
            .codec
            .decode(&payload)
            .map_err(DuplexServerError::Codec)?;
        Ok((request, MuxedReplyToken::new(slot_id)))
    }

    /// Encodes `resp` and writes it as a response frame for the token's slot.
    ///
    /// # Errors
    ///
    /// `DuplexServerError::Codec` when encoding fails, or
    /// `DuplexServerError::Framing` when the frame write fails.
    async fn reply(&self, token: Self::ReplyToken, resp: Resp) -> Result<(), Self::Error> {
        let body = self
            .shared
            .codec
            .encode_to_vec(&resp)
            .map_err(DuplexServerError::Codec)?;
        let frame = encode_duplex_frame(KIND_RESPONSE, self.api_id, token.slot_id(), &body);

        // The writer guard is held until the frame write completes.
        let mut writer = self.shared.writer.lock().await;
        self.shared
            .framer
            .write_frame(&mut *writer, &frame)
            .await
            .map_err(DuplexServerError::Framing)
    }
}
/// Errors reported by a duplex server half.
#[derive(Debug)]
pub enum DuplexServerError<F, C> {
    /// Writing a response frame failed in the framing layer.
    Framing(F),
    /// Encoding a response or decoding a request failed.
    Codec(C),
    /// The inbox was closed and holds no more requests.
    Closed,
}

impl<F, C> core::fmt::Display for DuplexServerError<F, C>
where
    F: core::fmt::Display,
    C: core::fmt::Display,
{
    /// Framing errors print as-is, codec errors get a `codec error: ` prefix,
    /// and `Closed` prints a fixed message.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            Self::Framing(inner) => f.write_fmt(format_args!("{inner}")),
            Self::Codec(inner) => f.write_fmt(format_args!("codec error: {inner}")),
            Self::Closed => f.write_str("duplex stream closed"),
        }
    }
}
// Unit tests; compiled only for test builds with the `postcard` feature enabled.
#[cfg(all(test, feature = "postcard"))]
mod tests {
use super::*;
use crate::io::mem_pipe::duplex;
use crate::stream::{LengthPrefixed, PostcardCodec};
// Thin wrapper so the tests can drive futures to completion synchronously.
fn block_on<F: core::future::Future>(fut: F) -> F::Output {
futures_lite::future::block_on(fut)
}
// Transport under test: length-prefixed framing, postcard codec, N = 4, BUF = 256.
type DxT<R, W> = DuplexStreamTransport<R, W, LengthPrefixed, PostcardCodec, 4, 256>;
// API ids used by the end-to-end test; each is served on one side and called from the other.
const PING_API: u16 = 0x0001;
const ECHO_API: u16 = 0x0002;
// Encoding a frame and parsing it back yields the same header fields and payload.
#[test]
fn duplex_header_round_trip() {
let f = encode_duplex_frame(KIND_REQUEST, 0xBEEF, 7, b"hi");
let (hdr, payload) = parse_duplex_frame(&f).unwrap();
assert_eq!(hdr.kind, KIND_REQUEST);
assert_eq!(hdr.api_id, 0xBEEF);
assert_eq!(hdr.slot_id, 7);
assert_eq!(payload, b"hi");
}
// Two endpoints joined by an in-memory pipe: side A serves PING and calls ECHO,
// side B serves ECHO and calls PING, so requests flow in both directions at once.
#[test]
fn duplex_end_to_end_two_apis_both_directions() {
let ((r_a, w_a), (r_b, w_b)) = duplex();
let dx_a: DxT<_, _> = DxT::new(r_a, w_a);
let dx_b: DxT<_, _> = DxT::new(r_b, w_b);
// Halves are created before `split`, which consumes the transports.
let ping_server_a = dx_a.server_half::<u32, u32>(PING_API);
let echo_client_a = dx_a.client_half::<String, String>(ECHO_API);
let echo_server_b = dx_b.server_half::<String, String>(ECHO_API);
let ping_client_b = dx_b.client_half::<u32, u32>(PING_API);
let (pump_a, _h_a) = dx_a.split();
let (pump_b, _h_b) = dx_b.split();
block_on(async {
// Each server handles exactly one request and then finishes.
let mut ping_srv = ping_server_a;
let server_a = async move {
let (req, token) = ping_srv.recv().await.unwrap();
ping_srv.reply(token, req + 1).await.unwrap();
};
let mut echo_srv = echo_server_b;
let server_b = async move {
let (req, token) = echo_srv.recv().await.unwrap();
echo_srv.reply(token, format!("echo: {req}")).await.unwrap();
};
let client_a = async { echo_client_a.call("hello".to_string()).await.unwrap() };
let client_b = async { ping_client_b.call(41u32).await.unwrap() };
let pump_a_fut = pump_a.run();
let pump_b_fut = pump_b.run();
// `work` completes once both clients have their replies and both servers have answered.
let work = async {
let ((echo_resp, ping_resp), _) = futures_lite::future::zip(
futures_lite::future::zip(client_a, client_b),
futures_lite::future::zip(server_a, server_b),
)
.await;
assert_eq!(echo_resp, "echo: hello");
assert_eq!(ping_resp, 42);
};
// The pumps are raced against `work` so the test ends when the work is
// done rather than waiting for the pumps to finish.
futures_lite::future::or(work, async {
let _ = futures_lite::future::zip(pump_a_fut, pump_b_fut).await;
})
.await;
});
}
// Constructing a transport with large N/BUF parameters must succeed on a thread
// limited to a 1 MiB stack (1 << 20 bytes).
// NOTE(review): presumably guards against building the slot table on the stack
// before boxing it; confirm against `MuxedSlots::new_boxed`.
#[test]
fn duplex_construction_on_restricted_stack() {
type BigDx<R, W> = DuplexStreamTransport<R, W, LengthPrefixed, PostcardCodec, 32, 131_072>;
std::thread::Builder::new()
.stack_size(1 << 20) .spawn(|| {
let ((r_a, w_a), (_r_b, _w_b)) = duplex();
let _dx: BigDx<_, _> = BigDx::new(r_a, w_a);
})
.expect("spawn restricted-stack thread")
.join()
.expect("DuplexStreamTransport construction overflowed a 1 MiB stack");
}
}