use std::marker::PhantomData;
use vox_types::{Conduit, ConduitRx, ConduitTx, Link, LinkTx, MaybeSend, MsgFamily, SelfRef};
use crate::MessagePlan;
/// A conduit that carries `F` messages over a single [`Link`] with no extra framing:
/// each outgoing message is serialized with `vox_phon` and the resulting bytes are
/// handed to the link as-is.
///
/// Call [`Conduit::split`] to obtain the send ([`BareConduitTx`]) and receive
/// ([`BareConduitRx`]) halves.
pub struct BareConduit<F: MsgFamily, L: Link> {
    /// The underlying transport; split into its tx/rx halves by `split`.
    link: L,
    /// Serialized writer schema taken from a [`MessagePlan`], if one was given.
    /// `None` means the receive half decodes using the local `F::Msg` schema.
    writer_schema: Option<Vec<u8>>,
    /// Ties the conduit to `F` (invariantly, via `fn(F) -> F`) without storing a value.
    _phantom: PhantomData<fn(F) -> F>,
}
impl<F: MsgFamily, L: Link> BareConduit<F, L> {
pub fn new(link: L) -> Self {
Self {
link,
writer_schema: None,
_phantom: PhantomData,
}
}
pub fn with_message_plan(link: L, message_plan: MessagePlan) -> Self {
Self {
link,
writer_schema: Some(message_plan.writer_schema),
_phantom: PhantomData,
}
}
}
impl<F: MsgFamily, L: Link> Conduit for BareConduit<F, L>
where
L::Tx: MaybeSend + 'static,
L::Rx: MaybeSend + 'static,
{
type Msg = F;
type Tx = BareConduitTx<F, L::Tx>;
type Rx = BareConduitRx<F, L::Rx>;
fn split(self) -> (Self::Tx, Self::Rx) {
let (tx, rx) = self.link.split();
(
BareConduitTx {
link_tx: tx,
_phantom: PhantomData,
},
BareConduitRx {
link_rx: rx,
pending_fds: vox_types::FrameFds::default(),
writer_schema: self.writer_schema,
program: None,
_phantom: PhantomData,
},
)
}
}
/// Send half of a [`BareConduit`]: encodes `F::Msg` values and writes them to `LTx`.
pub struct BareConduitTx<F: MsgFamily, LTx: LinkTx> {
    /// Sending half of the underlying link.
    link_tx: LTx,
    /// Marks `F` as consumed (contravariant, via `fn(F)`) without storing a value.
    _phantom: PhantomData<fn(F)>,
}
/// A message that has been encoded by `prepare_send` and is ready for `send_prepared`.
pub struct PreparedFrame {
    /// The `vox_phon` encoding of the message.
    pub bytes: Vec<u8>,
    /// File descriptors returned by `vox_types::collect_fds` while the message was
    /// being encoded; passed to the link together with `bytes`.
    pub fds: vox_types::FrameFds,
}
impl<F: MsgFamily, LTx: LinkTx + MaybeSend + 'static> ConduitTx for BareConduitTx<F, LTx> {
type Msg = F;
type Prepared = PreparedFrame;
type Error = BareConduitError;
fn prepare_send(&self, item: F::Msg<'_>) -> Result<Self::Prepared, Self::Error> {
let (encoded, fds) =
vox_types::collect_fds(|| vox_phon::to_vec(&item).map_err(BareConduitError::Encode));
Ok(PreparedFrame {
bytes: encoded?,
fds,
})
}
async fn send_prepared(&self, prepared: Self::Prepared) -> Result<(), Self::Error> {
let PreparedFrame { bytes, fds } = prepared;
if vox_types::frame_fds_len(&fds) > 0 && !self.link_tx.supports_fd_passing() {
return Err(BareConduitError::Io(std::io::Error::other(
"message carries file descriptors but the transport \
cannot pass them",
)));
}
self.link_tx
.send_with_fds(bytes, fds)
.await
.map_err(BareConduitError::Io)
}
async fn close(self) -> std::io::Result<()> {
self.link_tx.close().await
}
}
/// Receive half of a [`BareConduit`]: reads frames from `LRx` and decodes them as `F::Msg`.
pub struct BareConduitRx<F: MsgFamily, LRx> {
    /// Receiving half of the underlying link.
    link_rx: LRx,
    /// Descriptors taken from the link for the most recently received frame.
    /// Handed out (and reset) by `take_frame_fds`; replaced when the next frame arrives.
    pending_fds: vox_types::FrameFds,
    /// Writer schema from the message plan, if any. `None` means the local
    /// `F::Msg` schema is used as the writer schema.
    writer_schema: Option<Vec<u8>>,
    /// Decode program, built lazily by `ensure_program` and reused afterwards.
    program: Option<vox_phon::DecodeProgram>,
    /// Marks `F` as produced (covariant, via `fn() -> F`) without storing a value.
    _phantom: PhantomData<fn() -> F>,
}
impl<F: MsgFamily, LRx> BareConduitRx<F, LRx> {
    /// Returns the decode program for incoming frames, building and caching it on
    /// the first call.
    ///
    /// The program translates the writer's schema into the local `F::Msg` type.
    /// The writer schema is the one supplied through a [`MessagePlan`] when
    /// present; otherwise the local type's own schema is used, i.e. the peer is
    /// assumed to write exactly `F::Msg`.
    ///
    /// # Errors
    ///
    /// Returns [`BareConduitError::Decode`] if the writer schema cannot be
    /// produced or parsed, or if no decode program can be built from it. Nothing
    /// is cached on failure, so a later call will try again.
    fn ensure_program(&mut self) -> Result<&vox_phon::DecodeProgram, BareConduitError> {
        let program = match self.program.take() {
            Some(cached) => cached,
            None => {
                // Owns the locally generated schema for as long as `writer_bytes`
                // borrows it; only initialised when no plan schema exists.
                let own_schema: Vec<u8>;
                let writer_bytes: &[u8] = match &self.writer_schema {
                    Some(planned) => planned.as_slice(),
                    None => {
                        own_schema = vox_phon::schema_bytes::<F::Msg<'static>>()
                            .map_err(BareConduitError::Decode)?;
                        own_schema.as_slice()
                    }
                };
                let parsed_writer = vox_phon::parse_schema_bytes(writer_bytes)
                    .map_err(BareConduitError::Decode)?;
                vox_phon::build_decode_program::<F::Msg<'static>>(&parsed_writer)
                    .map_err(BareConduitError::Decode)?
            }
        };
        Ok(&*self.program.insert(program))
    }
}
impl<F: MsgFamily, LRx> ConduitRx for BareConduitRx<F, LRx>
where
    LRx: vox_types::LinkRx + MaybeSend + 'static,
{
    type Msg = F;
    type Error = BareConduitError;

    /// Receives the next frame from the link and decodes it as `F::Msg`.
    ///
    /// Returns `Ok(None)` once the link's `recv` yields `None`. On success, the
    /// file descriptors that arrived with the frame are kept and can be retrieved
    /// with [`ConduitRx::take_frame_fds`].
    ///
    /// # Errors
    ///
    /// - [`BareConduitError::Io`] if the link's `recv` fails.
    /// - [`BareConduitError::Decode`] if the decode program cannot be built or
    ///   the frame does not decode. The frame's descriptors are then dropped
    ///   rather than left pending, so a later `take_frame_fds` cannot return
    ///   descriptors for a message that was never delivered.
    #[vox_rt::instrument]
    async fn recv(&mut self) -> Result<Option<SelfRef<F::Msg<'static>>>, Self::Error> {
        let backing = match self.link_rx.recv().await.map_err(|error| {
            BareConduitError::Io(std::io::Error::other(format!("link recv failed: {error}")))
        })? {
            Some(b) => b,
            None => return Ok(None),
        };
        // Always drain the link's descriptors for this frame, even if decoding
        // fails below, so they can never be attributed to a later frame.
        let frame_fds = self.link_rx.take_frame_fds();
        let decoded = self.ensure_program().and_then(|program| {
            SelfRef::try_new(backing, |bytes| {
                vox_phon::decode_with_program::<F::Msg<'static>>(program, bytes)
                    .map_err(BareConduitError::Decode)
            })
        });
        match decoded {
            Ok(message) => {
                self.pending_fds = frame_fds;
                Ok(Some(message))
            }
            Err(error) => {
                // No message was delivered, so nothing may claim these descriptors.
                // Release them together with any unclaimed ones from an earlier frame.
                self.pending_fds = vox_types::FrameFds::default();
                drop(frame_fds);
                Err(error)
            }
        }
    }

    /// Hands out the descriptors of the most recently decoded frame, leaving an
    /// empty set in their place.
    fn take_frame_fds(&mut self) -> vox_types::FrameFds {
        std::mem::take(&mut self.pending_fds)
    }
}
/// Errors produced by [`BareConduit`] and its send/receive halves.
#[derive(Debug)]
pub enum BareConduitError {
    /// Serializing an outgoing message with `vox_phon` failed.
    Encode(vox_phon::Error),
    /// Producing or parsing the writer schema, building the decode program, or
    /// decoding an incoming frame failed.
    Decode(vox_phon::Error),
    /// A link send or receive failed, or a message carrying file descriptors was
    /// sent over a transport that cannot pass them.
    Io(std::io::Error),
    /// Displayed as "link dead"; not constructed in this file.
    /// NOTE(review): presumably signals that the link has gone away — confirm with its producers.
    LinkDead,
}
impl std::fmt::Display for BareConduitError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Encode(e) => write!(f, "encode error: {e}"),
Self::Decode(e) => write!(f, "decode error: {e}"),
Self::Io(e) => write!(f, "io error: {e}"),
Self::LinkDead => write!(f, "link dead"),
}
}
}
/// Marks `BareConduitError` as a standard error type. `source()` keeps its default
/// (`None`); the wrapped error's message is already part of the `Display` output.
impl std::error::Error for BareConduitError {}
#[cfg(test)]
mod tests {
    use vox_types::*;
    use super::*;
    use crate::memory_link_pair;

    /// A `LaneReject` carrying one metadata entry must survive
    /// encode → memory link → decode with that entry intact.
    #[test]
    fn connection_reject_with_nonempty_metadata_round_trips() {
        tokio::runtime::Builder::new_current_thread()
            .build()
            .unwrap()
            .block_on(connection_reject_with_nonempty_metadata_inner());
    }

    async fn connection_reject_with_nonempty_metadata_inner() {
        let (link_a, link_b) = memory_link_pair(64);
        // The unused halves are bound (not `_`) so they stay alive for the whole test.
        let (sender, _unused_rx) = BareConduit::<MessageFamily, _>::new(link_a).split();
        let (_unused_tx, mut receiver) = BareConduit::<MessageFamily, _>::new(link_b).split();

        let outgoing = Message {
            lane_id: LaneId(1),
            payload: MessagePayload::LaneReject(LaneReject {
                metadata: metadata()
                    .str("error", "missing required vox-service metadata")
                    .build(),
            }),
        };
        let prepared = sender.prepare_send(outgoing).unwrap();
        sender.send_prepared(prepared).await.unwrap();

        let received = receiver.recv().await.unwrap().unwrap();
        let msg = received.get();
        let MessagePayload::LaneReject(reject) = &msg.payload else {
            panic!("expected LaneReject, got {:?}", msg.payload);
        };
        assert_eq!(reject.metadata.meta_len(), 1, "expected 1 metadata entry");
        assert_eq!(
            reject.metadata.meta_str("error"),
            Some("missing required vox-service metadata"),
        );
    }
}