#![deny(unsafe_code)]
#[macro_use]
mod macros;
pub mod diagnostic;
pub mod driver;
pub mod runtime;
pub mod transport;
pub use driver::{
ConnectError, ConnectionError, Driver, FramedClient, HandshakeConfig, IncomingConnection,
IncomingConnections, MessageConnector, Negotiated, NoDispatcher, RetryPolicy, accept_framed,
connect_framed, connect_framed_with_policy, initiate_framed,
};
pub use transport::MessageTransport;
use std::marker::PhantomData;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use crate::runtime::{OneshotSender, Receiver, Sender, oneshot};
use facet::Facet;
use std::convert::Infallible;
pub use roam_frame::{Frame, MsgDesc, OwnedMessage, Payload};
/// Capacity handed to `crate::runtime::channel` when [`channel`] creates a `Tx`/`Rx` pair.
const CHANNEL_SIZE: usize = 1024;
/// Capacity of the per-channel buffer created when an incoming `Rx` is bound
/// (used by `ChannelRegistry::bind_rx_stream`).
const RX_STREAM_BUFFER_SIZE: usize = 1024;
/// Identifier of a channel; a plain `u64`.
pub type ChannelId = u64;
/// Which side of a connection this endpoint is.
///
/// The role selects the first ID handed out by [`ChannelIdAllocator::new`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Role {
/// Channel IDs allocated for this role start at 1.
Initiator,
/// Channel IDs allocated for this role start at 2.
Acceptor,
}
/// Hands out channel IDs from an atomic counter.
pub struct ChannelIdAllocator {
// Next ID to return; advanced by 2 on every allocation.
next: AtomicU64,
}
impl ChannelIdAllocator {
    /// Creates an allocator whose first ID depends on `role`:
    /// 1 for [`Role::Initiator`], 2 for [`Role::Acceptor`].
    pub fn new(role: Role) -> Self {
        let next = AtomicU64::new(match role {
            Role::Initiator => 1,
            Role::Acceptor => 2,
        });
        Self { next }
    }

    /// Returns the next channel ID and advances the counter by 2, so the
    /// parity of the starting value is kept for every ID handed out.
    pub fn next(&self) -> ChannelId {
        const STRIDE: u64 = 2;
        self.next.fetch_add(STRIDE, Ordering::Relaxed)
    }
}
/// Optional holder for a byte-payload `Sender`, marked opaque to `Facet`.
#[derive(Facet)]
#[facet(opaque)]
pub struct SenderSlot {
// The wrapped sender; `None` when the slot is empty or its sender was taken.
pub(crate) inner: Option<Sender<Vec<u8>>>,
}
impl SenderSlot {
    /// Creates a slot that holds `tx`.
    pub fn new(tx: Sender<Vec<u8>>) -> Self {
        let inner = Some(tx);
        Self { inner }
    }

    /// Creates a slot that holds nothing.
    pub fn empty() -> Self {
        let inner = None;
        Self { inner }
    }

    /// Moves the sender out of the slot, leaving it empty.
    pub fn take(&mut self) -> Option<Sender<Vec<u8>>> {
        std::mem::take(&mut self.inner)
    }

    /// Returns `true` when the slot currently holds a sender.
    pub fn is_some(&self) -> bool {
        matches!(self.inner, Some(_))
    }

    /// Returns `true` when the slot is empty.
    pub fn is_none(&self) -> bool {
        matches!(self.inner, None)
    }

    /// Stores `tx` in the slot, dropping any sender that was there before.
    pub fn set(&mut self, tx: Sender<Vec<u8>>) {
        let _previous = self.inner.replace(tx);
    }
}
/// Optional holder for a `Sender<DriverMessage>`, marked opaque to `Facet`.
#[derive(Facet)]
#[facet(opaque)]
pub struct DriverTxSlot {
// The wrapped driver sender; `None` when the slot is empty or was taken.
pub(crate) inner: Option<Sender<DriverMessage>>,
}
impl DriverTxSlot {
    /// Creates a slot that holds `tx`.
    pub fn new(tx: Sender<DriverMessage>) -> Self {
        let inner = Some(tx);
        Self { inner }
    }

    /// Creates a slot that holds nothing.
    pub fn empty() -> Self {
        let inner = None;
        Self { inner }
    }

    /// Moves the driver sender out of the slot, leaving it empty.
    pub fn take(&mut self) -> Option<Sender<DriverMessage>> {
        std::mem::take(&mut self.inner)
    }

    /// Returns `true` when the slot currently holds a sender.
    pub fn is_some(&self) -> bool {
        matches!(self.inner, Some(_))
    }

    /// Returns `true` when the slot is empty.
    pub fn is_none(&self) -> bool {
        matches!(self.inner, None)
    }

    /// Stores `tx` in the slot, dropping any sender that was there before.
    pub fn set(&mut self, tx: Sender<DriverMessage>) {
        let _previous = self.inner.replace(tx);
    }

    /// Returns a clone of the held sender without emptying the slot.
    pub fn clone_inner(&self) -> Option<Sender<DriverMessage>> {
        self.inner.as_ref().cloned()
    }
}
/// Typed sending half of a channel.
///
/// Declared with `#[facet(proxy = u64)]`; the `TryFrom` impls in this file
/// convert a `Tx` to its `channel_id` and build an empty-slotted `Tx` from a `u64`.
#[derive(Facet)]
#[facet(proxy = u64)]
pub struct Tx<T: 'static> {
/// Connection this handle is associated with.
pub conn_id: roam_wire::ConnectionId,
/// Channel this handle writes to.
pub channel_id: ChannelId,
/// Direct byte sender; `send` prefers it when present.
pub sender: SenderSlot,
/// Driver sender used by `send` when `sender` is empty, and by `Drop` to emit `Close`.
pub driver_tx: DriverTxSlot,
// Ties the handle to the element type without storing a `T`.
#[facet(opaque)]
_marker: PhantomData<T>,
}
// The conversion cannot fail; `TryFrom` is implemented anyway (hence the lint allow).
#[allow(clippy::infallible_try_from)]
impl<T: 'static> TryFrom<&Tx<T>> for u64 {
    type Error = Infallible;

    /// Yields the channel ID of `tx`.
    fn try_from(tx: &Tx<T>) -> Result<Self, Self::Error> {
        let id = tx.channel_id;
        Ok(id)
    }
}
// The conversion cannot fail; `TryFrom` is implemented anyway (hence the lint allow).
#[allow(clippy::infallible_try_from)]
impl<T: 'static> TryFrom<u64> for Tx<T> {
    type Error = Infallible;

    /// Builds a `Tx` for `channel_id` on the root connection with both the
    /// direct sender slot and the driver slot empty.
    fn try_from(channel_id: u64) -> Result<Self, Self::Error> {
        let hollow = Self {
            conn_id: roam_wire::ConnectionId::ROOT,
            channel_id,
            sender: SenderSlot::empty(),
            driver_tx: DriverTxSlot::empty(),
            _marker: PhantomData,
        };
        Ok(hollow)
    }
}
impl<T: 'static> Tx<T> {
    /// Creates a sender for `channel_id` on the root connection that writes
    /// into `tx`; the driver slot starts empty.
    pub fn new(channel_id: ChannelId, tx: Sender<Vec<u8>>) -> Self {
        let sender = SenderSlot::new(tx);
        let driver_tx = DriverTxSlot::empty();
        Self {
            conn_id: roam_wire::ConnectionId::ROOT,
            channel_id,
            sender,
            driver_tx,
            _marker: PhantomData,
        }
    }

    /// Creates a sender with channel ID 0 on the root connection that writes
    /// into `tx`; the driver slot starts empty.
    pub fn unbound(tx: Sender<Vec<u8>>) -> Self {
        Self::new(0, tx)
    }

    /// Creates a sender for `channel_id` on `conn_id` that holds both the
    /// direct sender `tx` and the driver sender `driver_tx`.
    pub fn bound(
        conn_id: roam_wire::ConnectionId,
        channel_id: ChannelId,
        tx: Sender<Vec<u8>>,
        driver_tx: Sender<DriverMessage>,
    ) -> Self {
        let sender = SenderSlot::new(tx);
        let driver_tx = DriverTxSlot::new(driver_tx);
        Self {
            conn_id,
            channel_id,
            sender,
            driver_tx,
            _marker: PhantomData,
        }
    }

    /// Returns the channel ID stored in this handle.
    pub fn channel_id(&self) -> ChannelId {
        self.channel_id
    }

    /// Serializes `value` with postcard and sends the bytes.
    ///
    /// The direct sender is used when present; otherwise the bytes are wrapped
    /// in `DriverMessage::Data` and sent through the driver sender.
    ///
    /// # Errors
    ///
    /// * [`TxError::Serialize`] if encoding fails (checked before anything else).
    /// * [`TxError::Closed`] if the chosen sender reports a send failure.
    /// * [`TxError::Taken`] if both slots are empty.
    pub async fn send(&self, value: &T) -> Result<(), TxError>
    where
        T: Facet<'static>,
    {
        let encoded = facet_postcard::to_vec(value).map_err(TxError::Serialize)?;
        let direct = self.sender.inner.as_ref();
        let via_driver = self.driver_tx.inner.as_ref();
        match (direct, via_driver) {
            (Some(direct), _) => direct.send(encoded).await.map_err(|_| TxError::Closed),
            (None, Some(driver)) => {
                let message = DriverMessage::Data {
                    conn_id: self.conn_id,
                    channel_id: self.channel_id,
                    payload: encoded,
                };
                driver.send(message).await.map_err(|_| TxError::Closed)
            }
            (None, None) => Err(TxError::Taken),
        }
    }
}
impl<T: 'static> Drop for Tx<T> {
    /// Sends `DriverMessage::Close` for this channel when the handle only
    /// holds a driver sender.
    ///
    /// Nothing is sent while the direct sender slot is still populated, or
    /// when the driver slot is empty.
    fn drop(&mut self) {
        if self.sender.inner.is_some() {
            return;
        }
        let Some(driver) = self.driver_tx.inner.take() else {
            return;
        };
        let (conn_id, channel_id) = (self.conn_id, self.channel_id);

        // First attempt is non-blocking because `drop` cannot await.
        let delivered = driver
            .try_send(DriverMessage::Close {
                conn_id,
                channel_id,
            })
            .is_ok();
        if delivered {
            return;
        }

        // `try_send` failed: retry with an awaiting send on a spawned task.
        crate::runtime::spawn(async move {
            let _ = driver
                .send(DriverMessage::Close {
                    conn_id,
                    channel_id,
                })
                .await;
        });
    }
}
/// Errors returned by `Tx::send`.
#[derive(Debug)]
pub enum TxError {
/// The value could not be serialized.
Serialize(facet_postcard::SerializeError),
/// The underlying sender reported a send failure.
Closed,
/// Neither a direct sender nor a driver sender is present in the handle.
Taken,
}
impl std::fmt::Display for TxError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TxError::Serialize(e) => write!(f, "serialize error: {e}"),
TxError::Closed => write!(f, "stream closed"),
TxError::Taken => write!(f, "sender was taken"),
}
}
}
// Marker impl: relies on the `Debug` and `Display` impls; no `source()` is provided.
impl std::error::Error for TxError {}
/// Optional holder for a byte-payload `Receiver`, marked opaque to `Facet`.
#[derive(Facet)]
#[facet(opaque)]
pub struct ReceiverSlot {
// The wrapped receiver; `None` when the slot is empty or its receiver was taken.
pub(crate) inner: Option<Receiver<Vec<u8>>>,
}
impl ReceiverSlot {
    /// Creates a slot that holds `rx`.
    pub fn new(rx: Receiver<Vec<u8>>) -> Self {
        let inner = Some(rx);
        Self { inner }
    }

    /// Creates a slot that holds nothing.
    pub fn empty() -> Self {
        let inner = None;
        Self { inner }
    }

    /// Moves the receiver out of the slot, leaving it empty.
    pub fn take(&mut self) -> Option<Receiver<Vec<u8>>> {
        std::mem::take(&mut self.inner)
    }

    /// Returns `true` when the slot currently holds a receiver.
    pub fn is_some(&self) -> bool {
        matches!(self.inner, Some(_))
    }

    /// Returns `true` when the slot is empty.
    pub fn is_none(&self) -> bool {
        matches!(self.inner, None)
    }

    /// Stores `rx` in the slot, dropping any receiver that was there before.
    pub fn set(&mut self, rx: Receiver<Vec<u8>>) {
        let _previous = self.inner.replace(rx);
    }
}
/// Typed receiving half of a channel.
///
/// Declared with `#[facet(proxy = u64)]`; the `TryFrom` impls in this file
/// convert an `Rx` to its `channel_id` and build an empty-slotted `Rx` from a `u64`.
#[derive(Facet)]
#[facet(proxy = u64)]
pub struct Rx<T: 'static> {
/// Channel this handle reads from.
pub channel_id: ChannelId,
/// Byte receiver that `recv` reads from; may be empty.
pub receiver: ReceiverSlot,
// Ties the handle to the element type without storing a `T`.
#[facet(opaque)]
_marker: PhantomData<T>,
}
// The conversion cannot fail; `TryFrom` is implemented anyway (hence the lint allow).
#[allow(clippy::infallible_try_from)]
impl<T: 'static> TryFrom<&Rx<T>> for u64 {
    type Error = Infallible;

    /// Yields the channel ID of `rx`.
    fn try_from(rx: &Rx<T>) -> Result<Self, Self::Error> {
        let id = rx.channel_id;
        Ok(id)
    }
}
// The conversion cannot fail; `TryFrom` is implemented anyway (hence the lint allow).
#[allow(clippy::infallible_try_from)]
impl<T: 'static> TryFrom<u64> for Rx<T> {
    type Error = Infallible;

    /// Builds an `Rx` for `channel_id` whose receiver slot is empty.
    fn try_from(channel_id: u64) -> Result<Self, Self::Error> {
        let hollow = Self {
            channel_id,
            receiver: ReceiverSlot::empty(),
            _marker: PhantomData,
        };
        Ok(hollow)
    }
}
impl<T: 'static> Rx<T> {
    /// Creates a receiver for `channel_id` that reads from `rx`.
    pub fn new(channel_id: ChannelId, rx: Receiver<Vec<u8>>) -> Self {
        let receiver = ReceiverSlot::new(rx);
        Self {
            channel_id,
            receiver,
            _marker: PhantomData,
        }
    }

    /// Creates a receiver with channel ID 0 that reads from `rx`.
    pub fn unbound(rx: Receiver<Vec<u8>>) -> Self {
        Self::new(0, rx)
    }

    /// Creates a receiver for `channel_id` that reads from `rx`
    /// (constructs exactly what [`Rx::new`] constructs).
    pub fn bound(channel_id: ChannelId, rx: Receiver<Vec<u8>>) -> Self {
        Self::new(channel_id, rx)
    }

    /// Returns the channel ID stored in this handle.
    pub fn channel_id(&self) -> ChannelId {
        self.channel_id
    }

    /// Waits for the next payload and decodes it with postcard.
    ///
    /// Returns `Ok(None)` when the underlying receiver yields `None`.
    ///
    /// # Errors
    ///
    /// * [`RxError::Taken`] if the receiver slot is empty.
    /// * [`RxError::Deserialize`] if the received bytes cannot be decoded as `T`.
    pub async fn recv(&mut self) -> Result<Option<T>, RxError>
    where
        T: Facet<'static>,
    {
        let Some(source) = self.receiver.inner.as_mut() else {
            return Err(RxError::Taken);
        };
        let Some(bytes) = source.recv().await else {
            return Ok(None);
        };
        let decoded: T = facet_postcard::from_slice(&bytes).map_err(RxError::Deserialize)?;
        Ok(Some(decoded))
    }
}
/// Errors returned by `Rx::recv`.
#[derive(Debug)]
pub enum RxError {
/// The received bytes could not be deserialized.
Deserialize(facet_postcard::DeserializeError<facet_postcard::PostcardError>),
/// The receiver slot of the handle is empty.
Taken,
}
impl std::fmt::Display for RxError {
    /// Writes a short human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if let RxError::Deserialize(cause) = self {
            return write!(f, "deserialize error: {cause}");
        }
        // Only `Taken` remains once `Deserialize` has been handled.
        f.write_str("receiver was taken")
    }
}
// Marker impl: relies on the `Debug` and `Display` impls; no `source()` is provided.
impl std::error::Error for RxError {}
/// Creates a connected `Tx`/`Rx` pair backed by a runtime channel of
/// capacity `CHANNEL_SIZE`.
///
/// When a dispatch context is available for the current task, the pair is
/// bound: a fresh channel ID is drawn from the context's allocator and the
/// `Tx` also receives the context's connection ID and driver sender.
/// Otherwise both halves are created unbound (channel ID 0).
pub fn channel<T: 'static>() -> (Tx<T>, Rx<T>) {
    let (sender, receiver) = crate::runtime::channel(CHANNEL_SIZE);
    match get_dispatch_context() {
        Some(ctx) => {
            let channel_id = ctx.channel_ids.next();
            debug!(channel_id, "roam::channel() creating bound channel pair");
            let tx_half = Tx::bound(ctx.conn_id, channel_id, sender, ctx.driver_tx.clone());
            let rx_half = Rx::bound(channel_id, receiver);
            (tx_half, rx_half)
        }
        None => {
            trace!("roam::channel() creating unbound channel pair (no dispatch context)");
            let tx_half = Tx::unbound(sender);
            let rx_half = Rx::unbound(receiver);
            (tx_half, rx_half)
        }
    }
}
// Per-dispatch state made available to handler code through `DISPATCH_CONTEXT`;
// read by `channel()` to create bound channel pairs.
#[derive(Clone)]
struct DispatchContext {
// Connection the dispatched call belongs to.
conn_id: roam_wire::ConnectionId,
// Allocator used for channel IDs created during the dispatch.
channel_ids: Arc<ChannelIdAllocator>,
// Sender for messages to the driver.
driver_tx: Sender<DriverMessage>,
}
// Task-local holding the `DispatchContext`; it is set via `DISPATCH_CONTEXT.scope(..)`
// in the dispatch functions and read by `get_dispatch_context`.
roam_task_local::task_local! {
static DISPATCH_CONTEXT: DispatchContext;
}
/// Returns a clone of the current task's dispatch context, or `None` when
/// the task-local cannot be accessed.
fn get_dispatch_context() -> Option<DispatchContext> {
    match DISPATCH_CONTEXT.try_with(|current| current.clone()) {
        Ok(snapshot) => Some(snapshot),
        Err(_) => None,
    }
}
use std::collections::{HashMap, HashSet};
/// Payload and channel IDs carried by a call response.
#[derive(Debug)]
pub struct ResponseData {
/// Raw response bytes.
pub payload: Vec<u8>,
/// Channel IDs that accompany the response.
pub channels: Vec<u64>,
}
/// Messages sent to the driver over a `Sender<DriverMessage>`.
pub enum DriverMessage {
/// Request to perform a call; the outcome is delivered through `response_tx`.
Call {
conn_id: roam_wire::ConnectionId,
request_id: u64,
method_id: u64,
metadata: Vec<(String, roam_wire::MetadataValue)>,
channels: Vec<u64>,
payload: Vec<u8>,
response_tx: OneshotSender<Result<ResponseData, TransportError>>,
},
/// Payload bytes for a channel (sent by `Tx::send` and by the forwarding tasks).
Data {
conn_id: roam_wire::ConnectionId,
channel_id: ChannelId,
payload: Vec<u8>,
},
/// Close notification for a channel (sent when a driver-bound `Tx` is dropped).
Close {
conn_id: roam_wire::ConnectionId,
channel_id: ChannelId,
},
/// Response to a previously dispatched request.
Response {
conn_id: roam_wire::ConnectionId,
request_id: u64,
channels: Vec<u64>,
payload: Vec<u8>,
},
/// Request whose result is a `ConnectionHandle` or a `ConnectError`, delivered
/// through `response_tx`. NOTE(review): presumably opens a virtual connection
/// served by `dispatcher` — confirm against the driver.
Connect {
request_id: u64,
metadata: roam_wire::Metadata,
response_tx: OneshotSender<Result<ConnectionHandle, crate::ConnectError>>,
dispatcher: Option<Box<dyn ServiceDispatcher>>,
},
}
/// Per-connection bookkeeping for channels: incoming routes, closed IDs and credit counters.
pub struct ChannelRegistry {
// Connection this registry belongs to.
conn_id: roam_wire::ConnectionId,
// Senders that incoming data for each channel is routed to.
incoming: HashMap<ChannelId, Sender<Vec<u8>>>,
// Channels that have been closed or reset.
closed: HashSet<ChannelId>,
// Remaining credit (in bytes) for data arriving on each incoming channel.
incoming_credit: HashMap<ChannelId, u32>,
// Credit tracked for each outgoing channel; raised by `receive_credit`.
outgoing_credit: HashMap<ChannelId, u32>,
// Credit assigned to a channel when it is registered.
initial_credit: u32,
// Sender for messages to the driver.
driver_tx: Sender<DriverMessage>,
// Allocator shared with dispatch contexts for channel IDs created during dispatch.
response_channel_ids: Arc<ChannelIdAllocator>,
}
impl ChannelRegistry {
    /// Creates an empty registry for `conn_id`.
    ///
    /// `initial_credit` is the credit given to every channel registered
    /// later, `driver_tx` is the sender cloned into bound `Tx` handles and
    /// dispatch contexts, and `role` seeds the channel-ID allocator.
    pub fn new_with_credit_and_role(
        conn_id: roam_wire::ConnectionId,
        initial_credit: u32,
        driver_tx: Sender<DriverMessage>,
        role: Role,
    ) -> Self {
        Self {
            conn_id,
            incoming: HashMap::new(),
            closed: HashSet::new(),
            incoming_credit: HashMap::new(),
            outgoing_credit: HashMap::new(),
            initial_credit,
            driver_tx,
            response_channel_ids: Arc::new(ChannelIdAllocator::new(role)),
        }
    }

    /// Creates a registry for the root connection with `Role::Acceptor`.
    pub fn new_with_credit(initial_credit: u32, driver_tx: Sender<DriverMessage>) -> Self {
        Self::new_with_credit_and_role(
            roam_wire::ConnectionId::ROOT,
            initial_credit,
            driver_tx,
            Role::Acceptor,
        )
    }

    /// Creates a registry for the root connection with `u32::MAX` initial credit.
    pub fn new(driver_tx: Sender<DriverMessage>) -> Self {
        Self::new_with_credit(u32::MAX, driver_tx)
    }

    /// Returns the connection ID this registry was created for.
    pub fn conn_id(&self) -> roam_wire::ConnectionId {
        self.conn_id
    }

    /// Builds the context that is scoped around handler futures so that
    /// `channel()` can create bound pairs.
    pub(crate) fn dispatch_context(&self) -> DispatchContext {
        DispatchContext {
            conn_id: self.conn_id,
            channel_ids: self.response_channel_ids.clone(),
            driver_tx: self.driver_tx.clone(),
        }
    }

    /// Returns a clone of the driver sender.
    pub fn driver_tx(&self) -> Sender<DriverMessage> {
        self.driver_tx.clone()
    }

    /// Returns a shared handle to the channel-ID allocator.
    pub fn response_channel_ids(&self) -> Arc<ChannelIdAllocator> {
        self.response_channel_ids.clone()
    }

    /// Registers `tx` as the destination for data arriving on `channel_id`
    /// and gives the channel the initial incoming credit.
    pub fn register_incoming(&mut self, channel_id: ChannelId, tx: Sender<Vec<u8>>) {
        self.incoming.insert(channel_id, tx);
        self.incoming_credit.insert(channel_id, self.initial_credit);
    }

    /// Starts tracking outgoing credit for `channel_id` at the initial credit.
    pub fn register_outgoing_credit(&mut self, channel_id: ChannelId) {
        self.outgoing_credit.insert(channel_id, self.initial_credit);
    }

    /// Validates an incoming payload and returns the sender it should be
    /// forwarded to, together with the payload itself.
    ///
    /// The payload length is deducted from the channel's incoming credit.
    ///
    /// # Errors
    ///
    /// * [`ChannelError::DataAfterClose`] if the channel was closed or reset.
    /// * [`ChannelError::CreditOverrun`] if the payload is larger than the
    ///   remaining credit.
    /// * [`ChannelError::Unknown`] if no incoming route is registered.
    pub fn prepare_route_data(
        &mut self,
        channel_id: ChannelId,
        payload: Vec<u8>,
    ) -> Result<(Sender<Vec<u8>>, Vec<u8>), ChannelError> {
        if self.closed.contains(&channel_id) {
            return Err(ChannelError::DataAfterClose);
        }
        if let Some(credit) = self.incoming_credit.get_mut(&channel_id) {
            // Checked conversion instead of `as u32`: a length that does not
            // fit in `u32` exceeds every possible credit, so it must be an
            // overrun rather than silently wrapping to a small number.
            let remaining = u32::try_from(payload.len())
                .ok()
                .and_then(|payload_len| credit.checked_sub(payload_len))
                .ok_or(ChannelError::CreditOverrun)?;
            *credit = remaining;
        }
        if let Some(tx) = self.incoming.get(&channel_id) {
            Ok((tx.clone(), payload))
        } else {
            Err(ChannelError::Unknown)
        }
    }

    /// Validates `payload` via [`Self::prepare_route_data`] and sends it to
    /// the registered receiver. The result of the send itself is discarded.
    pub async fn route_data(
        &mut self,
        channel_id: ChannelId,
        payload: Vec<u8>,
    ) -> Result<(), ChannelError> {
        let (tx, payload) = self.prepare_route_data(channel_id, payload)?;
        let _ = tx.send(payload).await;
        Ok(())
    }

    /// Removes all state for `channel_id` and records it as closed.
    pub fn close(&mut self, channel_id: ChannelId) {
        self.incoming.remove(&channel_id);
        self.incoming_credit.remove(&channel_id);
        self.outgoing_credit.remove(&channel_id);
        self.closed.insert(channel_id);
    }

    /// Removes all state for `channel_id` and records it as closed.
    ///
    /// Performs exactly the same bookkeeping as [`Self::close`].
    pub fn reset(&mut self, channel_id: ChannelId) {
        self.close(channel_id);
    }

    /// Adds `bytes` to the outgoing credit of `channel_id` (saturating).
    /// Does nothing for channels without tracked outgoing credit.
    pub fn receive_credit(&mut self, channel_id: ChannelId, bytes: u32) {
        if let Some(credit) = self.outgoing_credit.get_mut(&channel_id) {
            *credit = credit.saturating_add(bytes);
        }
    }

    /// Returns `true` if `channel_id` is registered as incoming or outgoing.
    pub fn contains(&self, channel_id: ChannelId) -> bool {
        self.incoming.contains_key(&channel_id) || self.outgoing_credit.contains_key(&channel_id)
    }

    /// Returns `true` if an incoming route exists for `channel_id`.
    pub fn contains_incoming(&self, channel_id: ChannelId) -> bool {
        self.incoming.contains_key(&channel_id)
    }

    /// Returns `true` if outgoing credit is tracked for `channel_id`.
    pub fn contains_outgoing(&self, channel_id: ChannelId) -> bool {
        self.outgoing_credit.contains_key(&channel_id)
    }

    /// Returns `true` if `channel_id` has been closed or reset.
    pub fn is_closed(&self, channel_id: ChannelId) -> bool {
        self.closed.contains(&channel_id)
    }

    /// Returns the number of channels with tracked outgoing credit.
    pub fn outgoing_count(&self) -> usize {
        self.outgoing_credit.len()
    }

    /// Returns the outgoing credit of `channel_id`, if tracked.
    pub fn outgoing_credit(&self, channel_id: ChannelId) -> Option<u32> {
        self.outgoing_credit.get(&channel_id).copied()
    }

    /// Returns the remaining incoming credit of `channel_id`, if tracked.
    pub fn incoming_credit(&self, channel_id: ChannelId) -> Option<u32> {
        self.incoming_credit.get(&channel_id).copied()
    }

    /// Walks `args` by reflection and binds every `Rx`/`Tx` found in it to
    /// this registry.
    pub fn bind_streams<T: Facet<'static>>(&mut self, args: &mut T) {
        let poke = facet::Poke::new(args);
        self.bind_streams_recursive(poke);
    }

    /// Recursive worker for [`Self::bind_streams`].
    ///
    /// `Rx`/`Tx` values are recognised by module path `roam_session` plus
    /// type identifier. Structs are traversed field by field, lists element
    /// by element; for options and enums only `field(0)` is followed.
    // `unsafe` is needed for the list vtable accessor; the crate otherwise denies it.
    #[allow(unsafe_code)]
    fn bind_streams_recursive(&mut self, mut poke: facet::Poke<'_, '_>) {
        use facet::Def;
        let shape = poke.shape();
        trace!(
            module_path = ?shape.module_path,
            type_identifier = shape.type_identifier,
            "bind_streams_recursive: visiting type"
        );
        if shape.module_path == Some("roam_session") {
            if shape.type_identifier == "Rx" {
                debug!("bind_streams_recursive: found Rx, binding");
                self.bind_rx_stream(poke);
                return;
            } else if shape.type_identifier == "Tx" {
                debug!("bind_streams_recursive: found Tx, binding");
                self.bind_tx_stream(poke);
                return;
            }
        }
        match shape.def {
            Def::Scalar => {}
            _ if poke.is_struct() => {
                let mut ps = poke.into_struct().expect("is_struct was true");
                let field_count = ps.field_count();
                trace!(field_count, "bind_streams_recursive: recursing into struct");
                for i in 0..field_count {
                    if let Ok(field_poke) = ps.field(i) {
                        self.bind_streams_recursive(field_poke);
                    }
                }
            }
            Def::Option(_) => {
                if let Ok(mut pe) = poke.into_enum()
                    && let Ok(Some(inner_poke)) = pe.field(0)
                {
                    self.bind_streams_recursive(inner_poke);
                }
            }
            Def::List(list_def) => {
                let len = {
                    let peek = poke.as_peek();
                    peek.into_list().map(|pl| pl.len()).unwrap_or(0)
                };
                if let Some(get_mut_fn) = list_def.vtable.get_mut {
                    let element_shape = list_def.t;
                    let data_ptr = poke.data_mut();
                    for i in 0..len {
                        // SAFETY (NOTE(review): confirm against facet's contract):
                        // `get_mut_fn` comes from this list's own vtable, `data_ptr`
                        // points at that list, and `i < len`.
                        let element_ptr = unsafe { (get_mut_fn)(data_ptr, i, element_shape) };
                        if let Some(ptr) = element_ptr {
                            // SAFETY (NOTE(review): confirm): `ptr` was returned by the
                            // vtable for an element of shape `element_shape`.
                            let element_poke =
                                unsafe { facet::Poke::from_raw_parts(ptr, element_shape) };
                            self.bind_streams_recursive(element_poke);
                        }
                    }
                }
            }
            _ if poke.is_enum() => {
                if let Ok(mut pe) = poke.into_enum()
                    && let Ok(Some(variant_poke)) = pe.field(0)
                {
                    self.bind_streams_recursive(variant_poke);
                }
            }
            _ => {}
        }
    }

    /// Binds an `Rx` found in the arguments: creates a fresh channel of
    /// capacity `RX_STREAM_BUFFER_SIZE`, stores its receiver in the `Rx`'s
    /// `receiver` slot and registers its sender as the incoming route for
    /// the `Rx`'s channel ID.
    fn bind_rx_stream(&mut self, poke: facet::Poke<'_, '_>) {
        if let Ok(mut ps) = poke.into_struct() {
            let channel_id = if let Ok(channel_id_field) = ps.field_by_name("channel_id")
                && let Ok(id_ref) = channel_id_field.get::<ChannelId>()
            {
                *id_ref
            } else {
                warn!("bind_rx_stream: could not get channel_id field");
                return;
            };
            debug!(channel_id, "bind_rx_stream: registering incoming channel");
            let (tx, rx) = crate::runtime::channel(RX_STREAM_BUFFER_SIZE);
            if let Ok(mut receiver_field) = ps.field_by_name("receiver")
                && let Ok(slot) = receiver_field.get_mut::<ReceiverSlot>()
            {
                slot.set(rx);
            }
            // The route is registered even if the receiver slot could not be filled.
            self.register_incoming(channel_id, tx);
            debug!(channel_id, "bind_rx_stream: channel registered");
        } else {
            warn!("bind_rx_stream: could not convert poke to struct");
        }
    }

    /// Binds a `Tx` found in the arguments: overwrites its `conn_id` with
    /// this registry's connection ID and stores a clone of the driver sender
    /// in its `driver_tx` slot.
    fn bind_tx_stream(&mut self, poke: facet::Poke<'_, '_>) {
        if let Ok(mut ps) = poke.into_struct() {
            if let Ok(mut conn_id_field) = ps.field_by_name("conn_id")
                && let Ok(id_ref) = conn_id_field.get_mut::<roam_wire::ConnectionId>()
            {
                *id_ref = self.conn_id;
            }
            if let Ok(mut driver_tx_field) = ps.field_by_name("driver_tx")
                && let Ok(slot) = driver_tx_field.get_mut::<DriverTxSlot>()
            {
                slot.set(self.driver_tx.clone());
            }
        }
    }
}
/// Reasons `ChannelRegistry::prepare_route_data` rejects incoming data.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChannelError {
/// No incoming route is registered for the channel.
Unknown,
/// The channel was already closed or reset.
DataAfterClose,
/// The payload is larger than the channel's remaining incoming credit.
CreditOverrun,
}
/// Hooks for per-channel flow control.
///
/// NOTE(review): the exact calling protocol is defined by the driver, which is
/// not visible here; the descriptions below follow the method names and signatures.
pub trait FlowControl: Send {
/// Notification that `bytes` of data arrived on `channel_id`.
fn on_data_received(&mut self, channel_id: ChannelId, bytes: u32);
/// Returns a future that resolves once `bytes` may be sent on `channel_id`,
/// or with an I/O error.
fn wait_for_send_credit(
&mut self,
channel_id: ChannelId,
bytes: u32,
) -> impl std::future::Future<Output = std::io::Result<()>> + Send;
/// Records that `bytes` of send credit were used on `channel_id`.
fn consume_send_credit(&mut self, channel_id: ChannelId, bytes: u32);
}
/// `FlowControl` implementation that never limits anything: every hook is a no-op
/// and waiting for credit succeeds immediately.
#[derive(Debug, Clone, Copy, Default)]
pub struct InfiniteCredit;
impl FlowControl for InfiniteCredit {
    /// No-op: received data is not accounted for.
    fn on_data_received(&mut self, _channel_id: ChannelId, _bytes: u32) {}

    /// Resolves immediately with `Ok(())`; credit is never exhausted.
    fn wait_for_send_credit(
        &mut self,
        _channel_id: ChannelId,
        _bytes: u32,
    ) -> impl std::future::Future<Output = std::io::Result<()>> + Send {
        std::future::ready(Ok(()))
    }

    /// No-op: sent data is not accounted for.
    fn consume_send_credit(&mut self, _channel_id: ChannelId, _bytes: u32) {}
}
/// Hands out request IDs from an atomic counter.
pub struct RequestIdGenerator {
// Next ID to return; advanced by 1 on every call to `next`.
next: AtomicU64,
}
impl RequestIdGenerator {
    /// Creates a generator whose first ID is 1.
    pub fn new() -> Self {
        let next = AtomicU64::new(1);
        Self { next }
    }

    /// Returns the next request ID and advances the counter by one.
    pub fn next(&self) -> u64 {
        const STEP: u64 = 1;
        self.next.fetch_add(STEP, Ordering::Relaxed)
    }
}
impl Default for RequestIdGenerator {
fn default() -> Self {
Self::new()
}
}
/// Decodes the arguments of a fallible call, binds its streams and returns a
/// future that runs `handler` and sends the response to the driver.
///
/// Argument decoding, channel-ID patching and stream binding happen
/// synchronously, before the future is returned; only the handler and the
/// response send run inside the future (within a `DISPATCH_CONTEXT` scope).
///
/// Response payloads produced here:
/// * decode failure: `[1, 2]`, no channels. NOTE(review): presumably the wire
///   form of `Err(RoamError::InvalidPayload)` — confirm.
/// * handler `Ok`: byte `0` followed by the postcard encoding of the value,
///   with the channel IDs collected from the value.
/// * handler `Err`: bytes `1, 0` followed by the postcard encoding of the
///   error, no channels.
///
/// If encoding the handler result fails, the future returns without sending
/// any response.
pub fn dispatch_call<A, R, E, F, Fut>(
cx: &Context,
payload: Vec<u8>,
registry: &mut ChannelRegistry,
handler: F,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send + 'static>>
where
A: Facet<'static> + Send,
R: Facet<'static> + Send,
E: Facet<'static> + Send,
F: FnOnce(A) -> Fut + Send + 'static,
Fut: std::future::Future<Output = Result<R, E>> + Send + 'static,
{
let conn_id = cx.conn_id;
let request_id = cx.request_id.raw();
let channels = &cx.channels;
let mut args: A = match facet_postcard::from_slice(&payload) {
Ok(args) => args,
Err(_) => {
// Arguments could not be decoded: reply with the fixed error payload.
let task_tx = registry.driver_tx();
return Box::pin(async move {
let _ = task_tx
.send(DriverMessage::Response {
conn_id,
request_id,
channels: Vec::new(),
payload: vec![1, 2],
})
.await;
});
}
};
debug!(channels = ?channels, "dispatch_call: patching channel IDs");
// Overwrite the channel IDs inside the decoded args with the ones from the request.
patch_channel_ids(&mut args, channels);
debug!("dispatch_call: binding streams SYNC");
// Must run before the future is returned so the channels are registered up front.
registry.bind_streams(&mut args);
debug!("dispatch_call: streams bound SYNC - channels should now be registered");
let task_tx = registry.driver_tx();
let dispatch_ctx = registry.dispatch_context();
Box::pin(DISPATCH_CONTEXT.scope(dispatch_ctx, async move {
debug!("dispatch_call: handler ASYNC starting");
let result = handler(args).await;
debug!("dispatch_call: handler ASYNC finished");
let (payload, response_channels) = match result {
Ok(ref ok_result) => {
let channels = collect_channel_ids(ok_result);
let mut out = vec![0u8];
match facet_postcard::to_vec(ok_result) {
Ok(bytes) => out.extend(bytes),
// Encoding failed: no response is sent.
Err(_) => return,
}
(out, channels)
}
Err(user_error) => {
let mut out = vec![1u8, 0u8];
match facet_postcard::to_vec(&user_error) {
Ok(bytes) => out.extend(bytes),
// Encoding failed: no response is sent.
Err(_) => return,
}
(out, Vec::new())
}
};
let _ = task_tx
.send(DriverMessage::Response {
conn_id,
request_id,
channels: response_channels,
payload,
})
.await;
}))
}
/// Decodes the arguments of an infallible call, binds its streams and returns
/// a future that runs `handler` and sends the response to the driver.
///
/// Argument decoding, channel-ID patching and stream binding happen
/// synchronously, before the future is returned; only the handler and the
/// response send run inside the future (within a `DISPATCH_CONTEXT` scope).
///
/// Response payloads produced here:
/// * decode failure: `[1, 2]`, no channels. NOTE(review): presumably the wire
///   form of `Err(RoamError::InvalidPayload)` — confirm.
/// * success: byte `0` followed by the postcard encoding of the result, with
///   the channel IDs collected from the result.
///
/// If encoding the result fails, the future returns without sending any response.
pub fn dispatch_call_infallible<A, R, F, Fut>(
cx: &Context,
payload: Vec<u8>,
registry: &mut ChannelRegistry,
handler: F,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send + 'static>>
where
A: Facet<'static> + Send,
R: Facet<'static> + Send,
F: FnOnce(A) -> Fut + Send + 'static,
Fut: std::future::Future<Output = R> + Send + 'static,
{
let conn_id = cx.conn_id;
let request_id = cx.request_id.raw();
let channels = &cx.channels;
let mut args: A = match facet_postcard::from_slice(&payload) {
Ok(args) => args,
Err(_) => {
// Arguments could not be decoded: reply with the fixed error payload.
let task_tx = registry.driver_tx();
return Box::pin(async move {
let _ = task_tx
.send(DriverMessage::Response {
conn_id,
request_id,
channels: Vec::new(),
payload: vec![1, 2],
})
.await;
});
}
};
// Overwrite the channel IDs inside the decoded args with the ones from the request.
patch_channel_ids(&mut args, channels);
// Must run before the future is returned so the channels are registered up front.
registry.bind_streams(&mut args);
let task_tx = registry.driver_tx();
let dispatch_ctx = registry.dispatch_context();
Box::pin(DISPATCH_CONTEXT.scope(dispatch_ctx, async move {
let result = handler(args).await;
let response_channels = collect_channel_ids(&result);
if !response_channels.is_empty() {
debug!(
channels = ?response_channels,
"dispatch_call_infallible: collected response channels"
);
}
let mut payload = vec![0u8];
match facet_postcard::to_vec(&result) {
Ok(bytes) => payload.extend(bytes),
// Encoding failed: no response is sent.
Err(_) => return,
}
let _ = task_tx
.send(DriverMessage::Response {
conn_id,
request_id,
channels: response_channels,
payload,
})
.await;
}))
}
/// Returns a future that answers the request in `cx` with the fixed payload
/// `[1, 1]` and no channels.
///
/// NOTE(review): `[1, 1]` is presumably the wire form of
/// `Err(RoamError::UnknownMethod)` — confirm.
pub fn dispatch_unknown_method(
    cx: &Context,
    registry: &mut ChannelRegistry,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send + 'static>> {
    let (conn_id, request_id) = (cx.conn_id, cx.request_id.raw());
    let driver = registry.driver_tx();
    let reply = async move {
        let rejection = DriverMessage::Response {
            conn_id,
            request_id,
            channels: Vec::new(),
            payload: vec![1, 1],
        };
        let _ = driver.send(rejection).await;
    };
    Box::pin(reply)
}
/// Walks `args` by reflection and returns the channel IDs of every `Rx`/`Tx`
/// found, in traversal order.
pub fn collect_channel_ids<T: Facet<'static>>(args: &T) -> Vec<u64> {
    let mut found = Vec::new();
    collect_channel_ids_recursive(facet::Peek::new(args), &mut found);
    found
}
// Recursive worker for `collect_channel_ids`.
//
// `Rx`/`Tx` values (module path `roam_session`) contribute their `channel_id`
// and are not descended into. Otherwise the value is tried, in this order, as
// a struct (all fields), an option (inner value if present), an enum
// (`field(0)` only) and a list (all elements).
fn collect_channel_ids_recursive(peek: facet::Peek<'_, '_>, ids: &mut Vec<u64>) {
let shape = peek.shape();
if shape.module_path == Some("roam_session")
&& (shape.type_identifier == "Rx" || shape.type_identifier == "Tx")
{
if let Ok(ps) = peek.into_struct()
&& let Ok(channel_id_field) = ps.field_by_name("channel_id")
&& let Ok(&channel_id) = channel_id_field.get::<ChannelId>()
{
ids.push(channel_id);
}
// Never recurse into the stream handle itself.
return;
}
if let Ok(ps) = peek.into_struct() {
let field_count = ps.field_count();
for i in 0..field_count {
if let Ok(field_peek) = ps.field(i) {
collect_channel_ids_recursive(field_peek, ids);
}
}
return;
}
if let Ok(po) = peek.into_option() {
if let Some(inner) = po.value() {
collect_channel_ids_recursive(inner, ids);
}
return;
}
if let Ok(pe) = peek.into_enum() {
// Only `field(0)` is followed; any further variant fields are not visited.
if let Ok(Some(variant_peek)) = pe.field(0) {
collect_channel_ids_recursive(variant_peek, ids);
}
return;
}
if let Ok(pl) = peek.into_list() {
for element in pl.iter() {
collect_channel_ids_recursive(element, ids);
}
}
}
/// Walks `args` by reflection and overwrites the `channel_id` of each
/// `Rx`/`Tx` found with the next entry of `channels`, in traversal order.
/// Handles found after `channels` is exhausted are left untouched.
pub fn patch_channel_ids<T: Facet<'static>>(args: &mut T, channels: &[u64]) {
    debug!(channels = ?channels, "patch_channel_ids: patching channels from wire");
    let mut cursor = 0;
    patch_channel_ids_recursive(facet::Poke::new(args), channels, &mut cursor);
}
// Recursive worker for `patch_channel_ids`.
//
// `idx` is the position of the next unused entry in `channels`; it advances
// each time an `Rx`/`Tx` is patched. Structs are traversed field by field,
// lists element by element; for options and enums only `field(0)` is followed.
// `unsafe` is needed for the list vtable accessor; the crate otherwise denies it.
#[allow(unsafe_code)]
fn patch_channel_ids_recursive(mut poke: facet::Poke<'_, '_>, channels: &[u64], idx: &mut usize) {
use facet::Def;
let shape = poke.shape();
if shape.module_path == Some("roam_session")
&& (shape.type_identifier == "Rx" || shape.type_identifier == "Tx")
{
if let Ok(mut ps) = poke.into_struct()
&& let Ok(mut channel_id_field) = ps.field_by_name("channel_id")
&& let Ok(channel_id_ref) = channel_id_field.get_mut::<ChannelId>()
&& *idx < channels.len()
{
*channel_id_ref = channels[*idx];
*idx += 1;
}
// Never recurse into the stream handle itself.
return;
}
match shape.def {
Def::Scalar => {}
_ if poke.is_struct() => {
let mut ps = poke.into_struct().expect("is_struct was true");
let field_count = ps.field_count();
for i in 0..field_count {
if let Ok(field_poke) = ps.field(i) {
patch_channel_ids_recursive(field_poke, channels, idx);
}
}
}
Def::Option(_) => {
if let Ok(mut pe) = poke.into_enum()
&& let Ok(Some(inner_poke)) = pe.field(0)
{
patch_channel_ids_recursive(inner_poke, channels, idx);
}
}
Def::List(list_def) => {
let len = {
let peek = poke.as_peek();
peek.into_list().map(|pl| pl.len()).unwrap_or(0)
};
if let Some(get_mut_fn) = list_def.vtable.get_mut {
let element_shape = list_def.t;
let data_ptr = poke.data_mut();
for i in 0..len {
// NOTE(review): soundness relies on `get_mut_fn` belonging to this list's
// vtable and on `i < len` — confirm against facet's contract.
let element_ptr = unsafe { (get_mut_fn)(data_ptr, i, element_shape) };
if let Some(ptr) = element_ptr {
// NOTE(review): relies on `ptr` pointing at a value of `element_shape` — confirm.
let element_poke =
unsafe { facet::Poke::from_raw_parts(ptr, element_shape) };
patch_channel_ids_recursive(element_poke, channels, idx);
}
}
}
}
_ if poke.is_enum() => {
// Only `field(0)` is followed; any further variant fields are not visited.
if let Ok(mut pe) = poke.into_enum()
&& let Ok(Some(variant_poke)) = pe.field(0)
{
patch_channel_ids_recursive(variant_poke, channels, idx);
}
}
_ => {}
}
}
/// Per-request information handed to dispatchers.
#[derive(Debug, Clone)]
pub struct Context {
/// Connection the request arrived on.
pub conn_id: roam_wire::ConnectionId,
/// Identifier of the request.
pub request_id: roam_wire::RequestId,
/// Identifier of the method being called.
pub method_id: roam_wire::MethodId,
/// Metadata attached to the request.
pub metadata: roam_wire::Metadata,
/// Channel IDs that accompany the request.
pub channels: Vec<u64>,
}
impl Context {
pub fn new(
conn_id: roam_wire::ConnectionId,
request_id: roam_wire::RequestId,
method_id: roam_wire::MethodId,
metadata: roam_wire::Metadata,
channels: Vec<u64>,
) -> Self {
Self {
conn_id,
request_id,
method_id,
metadata,
channels,
}
}
pub fn conn_id(&self) -> roam_wire::ConnectionId {
self.conn_id
}
pub fn request_id(&self) -> roam_wire::RequestId {
self.request_id
}
pub fn method_id(&self) -> roam_wire::MethodId {
self.method_id
}
pub fn metadata(&self) -> &roam_wire::Metadata {
&self.metadata
}
pub fn channels(&self) -> &[u64] {
&self.channels
}
}
/// A service that can handle incoming calls.
pub trait ServiceDispatcher: Send + Sync {
/// Returns the method IDs this dispatcher handles.
fn method_ids(&self) -> Vec<u64>;
/// Starts handling the request described by `cx` with the raw `payload`.
///
/// `registry` is available only during this synchronous call; the returned
/// future is `'static` and must not borrow from it.
fn dispatch(
&self,
cx: &Context,
payload: Vec<u8>,
registry: &mut ChannelRegistry,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send + 'static>>;
}
/// Dispatcher that routes each call to `primary` when its method ID is one of
/// `primary`'s, and to `fallback` otherwise.
pub struct RoutedDispatcher<A, B> {
// Handles the methods listed in `primary_methods`.
primary: A,
// Handles every other method.
fallback: B,
// Method IDs reported by `primary` when the router was constructed.
primary_methods: Vec<u64>,
}
impl<A, B> RoutedDispatcher<A, B>
where
    A: ServiceDispatcher,
{
    /// Creates a router, capturing `primary.method_ids()` once at construction.
    pub fn new(primary: A, fallback: B) -> Self {
        Self {
            primary_methods: primary.method_ids(),
            primary,
            fallback,
        }
    }
}
impl<A, B> ServiceDispatcher for RoutedDispatcher<A, B>
where
    A: ServiceDispatcher,
    B: ServiceDispatcher,
{
    /// Returns the primary's cached method IDs followed by the fallback's
    /// current method IDs.
    fn method_ids(&self) -> Vec<u64> {
        let mut combined = Vec::new();
        combined.extend_from_slice(&self.primary_methods);
        combined.extend(self.fallback.method_ids());
        combined
    }

    /// Forwards the call to `primary` if the request's method ID is among
    /// the primary's cached IDs, and to `fallback` otherwise.
    fn dispatch(
        &self,
        cx: &Context,
        payload: Vec<u8>,
        registry: &mut ChannelRegistry,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send + 'static>> {
        let wanted = cx.method_id().raw();
        let owned_by_primary = self.primary_methods.iter().any(|id| *id == wanted);
        if !owned_by_primary {
            return self.fallback.dispatch(cx, payload, registry);
        }
        self.primary.dispatch(cx, payload, registry)
    }
}
/// Dispatcher that forwards every call to an upstream connection.
pub struct ForwardingDispatcher {
// Connection that calls are forwarded to.
upstream: ConnectionHandle,
}
impl ForwardingDispatcher {
pub fn new(upstream: ConnectionHandle) -> Self {
Self { upstream }
}
}
impl Clone for ForwardingDispatcher {
    /// Clones the dispatcher by cloning its upstream handle.
    fn clone(&self) -> Self {
        let upstream = self.upstream.clone();
        Self { upstream }
    }
}
impl ServiceDispatcher for ForwardingDispatcher {
/// Reports no method IDs of its own.
fn method_ids(&self) -> Vec<u64> {
vec![]
}
/// Forwards the call to the upstream connection and relays the response
/// back to the driver of the connection the request came from.
///
/// Two paths exist:
/// * the request carries no channels: the upstream call is made first, and
///   any channels in the upstream response get fresh downstream IDs plus one
///   forwarding task each (upstream data -> downstream `Data`/`Close`);
/// * the request carries channels: each downstream channel is paired with a
///   freshly allocated upstream channel and forwarding tasks are spawned in
///   both directions before the upstream response is awaited.
///
/// Transport failures are turned into fixed payloads: `[1, 2]` for
/// `TransportError::Encode` and `[1, 3]` for `ConnectionClosed`/`DriverGone`.
/// NOTE(review): presumably the wire forms of `Err(RoamError::InvalidPayload)`
/// and `Err(RoamError::Cancelled)` — confirm.
fn dispatch(
&self,
cx: &Context,
payload: Vec<u8>,
registry: &mut ChannelRegistry,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send + 'static>> {
let task_tx = registry.driver_tx();
let upstream = self.upstream.clone();
let conn_id = cx.conn_id;
let method_id = cx.method_id.raw();
let request_id = cx.request_id.raw();
let channels = cx.channels.clone();
if channels.is_empty() {
// Path 1: no request channels; response channels (if any) are wired up after the call.
let downstream_channel_ids = registry.response_channel_ids();
Box::pin(async move {
let response = upstream
.call_raw_with_channels(method_id, vec![], payload, None)
.await;
let (response_payload, upstream_response_channels) = match response {
Ok(data) => (data.payload, data.channels),
Err(TransportError::Encode(_)) => {
(vec![1, 2], Vec::new()) }
Err(TransportError::ConnectionClosed) | Err(TransportError::DriverGone) => {
(vec![1, 3], Vec::new()) }
};
let mut downstream_channels = Vec::new();
if !upstream_response_channels.is_empty() {
debug!(
upstream_channels = ?upstream_response_channels,
"ForwardingDispatcher: setting up response channel forwarding"
);
for &upstream_id in &upstream_response_channels {
// Each upstream response channel gets a fresh downstream ID.
let downstream_id = downstream_channel_ids.next();
downstream_channels.push(downstream_id);
debug!(
upstream_id,
downstream_id, "ForwardingDispatcher: mapping channel IDs"
);
let (tx, mut rx) = crate::runtime::channel::<Vec<u8>>(64);
upstream.register_incoming(upstream_id, tx);
let task_tx_clone = task_tx.clone();
// Relay upstream data to the downstream channel until the upstream side ends,
// then send `Close` for the downstream channel.
crate::runtime::spawn(async move {
debug!(
upstream_id,
downstream_id, "ForwardingDispatcher: forwarding task started"
);
while let Some(data) = rx.recv().await {
debug!(
upstream_id,
downstream_id,
data_len = data.len(),
"ForwardingDispatcher: forwarding data"
);
let _ = task_tx_clone
.send(DriverMessage::Data {
conn_id,
channel_id: downstream_id,
payload: data,
})
.await;
}
debug!(
upstream_id,
downstream_id,
"ForwardingDispatcher: forwarding task ended, sending Close"
);
let _ = task_tx_clone
.send(DriverMessage::Close {
conn_id,
channel_id: downstream_id,
})
.await;
});
}
}
let _ = task_tx
.send(DriverMessage::Response {
conn_id,
request_id,
channels: downstream_channels,
payload: response_payload,
})
.await;
})
} else {
// Path 2: the request carries channels; pair every downstream channel with a
// new upstream channel. Registration happens synchronously, before the future runs.
let mut upstream_channels = Vec::with_capacity(channels.len());
let mut ds_to_us_rxs = Vec::with_capacity(channels.len());
let mut us_to_ds_rxs = Vec::with_capacity(channels.len());
// Pairs of (downstream_id, upstream_id), in request order.
let mut channel_map = Vec::with_capacity(channels.len());
let upstream_task_tx = upstream.driver_tx();
for &downstream_id in &channels {
let upstream_id = upstream.alloc_channel_id();
upstream_channels.push(upstream_id);
channel_map.push((downstream_id, upstream_id));
let (ds_to_us_tx, ds_to_us_rx) = crate::runtime::channel(64);
registry.register_incoming(downstream_id, ds_to_us_tx);
ds_to_us_rxs.push(ds_to_us_rx);
let (us_to_ds_tx, us_to_ds_rx) = crate::runtime::channel(64);
upstream.register_incoming(upstream_id, us_to_ds_tx);
us_to_ds_rxs.push(us_to_ds_rx);
}
Box::pin(async move {
// The call future is created here but only awaited after the forwarding tasks are spawned.
let response_future =
upstream.call_raw_with_channels(method_id, upstream_channels, payload, None);
let upstream_conn_id = upstream.conn_id();
// Downstream -> upstream forwarding, one task per channel.
for (i, mut rx) in ds_to_us_rxs.into_iter().enumerate() {
let upstream_id = channel_map[i].1;
let upstream_task_tx = upstream_task_tx.clone();
crate::runtime::spawn(async move {
while let Some(data) = rx.recv().await {
let _ = upstream_task_tx
.send(DriverMessage::Data {
conn_id: upstream_conn_id,
channel_id: upstream_id,
payload: data,
})
.await;
}
let _ = upstream_task_tx
.send(DriverMessage::Close {
conn_id: upstream_conn_id,
channel_id: upstream_id,
})
.await;
});
}
// Upstream -> downstream forwarding, one task per channel.
for (i, mut rx) in us_to_ds_rxs.into_iter().enumerate() {
let downstream_id = channel_map[i].0;
let task_tx = task_tx.clone();
crate::runtime::spawn(async move {
while let Some(data) = rx.recv().await {
let _ = task_tx
.send(DriverMessage::Data {
conn_id,
channel_id: downstream_id,
payload: data,
})
.await;
}
let _ = task_tx
.send(DriverMessage::Close {
conn_id,
channel_id: downstream_id,
})
.await;
});
}
let response = response_future.await;
let (response_payload, upstream_response_channels) = match response {
Ok(data) => (data.payload, data.channels),
Err(TransportError::Encode(_)) => {
(vec![1, 2], Vec::new()) }
Err(TransportError::ConnectionClosed) | Err(TransportError::DriverGone) => {
(vec![1, 3], Vec::new()) }
};
// Translate upstream channel IDs in the response back to downstream IDs;
// IDs without a mapping are dropped.
let downstream_response_channels: Vec<u64> = upstream_response_channels
.iter()
.filter_map(|&upstream_id| {
channel_map
.iter()
.find(|(_, us)| *us == upstream_id)
.map(|(ds, _)| *ds)
})
.collect();
let _ = task_tx
.send(DriverMessage::Response {
conn_id,
request_id,
channels: downstream_response_channels,
payload: response_payload,
})
.await;
})
}
}
}
/// Shared, write-once slot for a `ConnectionHandle`; clones share the same slot.
#[derive(Clone)]
pub struct LateBoundHandle {
// Write-once cell shared between all clones.
inner: Arc<std::sync::OnceLock<ConnectionHandle>>,
}
impl LateBoundHandle {
    /// Creates a handle with nothing bound yet.
    pub fn new() -> Self {
        let inner = Arc::new(std::sync::OnceLock::new());
        Self { inner }
    }

    /// Binds `handle`.
    ///
    /// # Panics
    ///
    /// Panics if a handle has already been bound (on this value or any clone).
    pub fn set(&self, handle: ConnectionHandle) {
        let stored = self.inner.set(handle).is_ok();
        assert!(stored, "LateBoundHandle::set called more than once");
    }

    /// Returns the bound handle, or `None` if nothing has been bound yet.
    pub fn get(&self) -> Option<&ConnectionHandle> {
        self.inner.get()
    }
}
impl Default for LateBoundHandle {
fn default() -> Self {
Self::new()
}
}
/// Dispatcher that forwards calls to an upstream connection that may be bound
/// after the dispatcher is created.
pub struct LateBoundForwarder {
// Write-once slot holding the upstream connection, once known.
upstream: LateBoundHandle,
}
impl LateBoundForwarder {
pub fn new(upstream: LateBoundHandle) -> Self {
Self { upstream }
}
}
impl Clone for LateBoundForwarder {
    /// Clones the forwarder by cloning its late-bound handle.
    fn clone(&self) -> Self {
        let upstream = self.upstream.clone();
        Self { upstream }
    }
}
impl ServiceDispatcher for LateBoundForwarder {
    /// Reports no method IDs of its own.
    fn method_ids(&self) -> Vec<u64> {
        Vec::new()
    }

    /// Forwards the call through a [`ForwardingDispatcher`] when the upstream
    /// is bound; otherwise answers with the fixed payload `[1, 3]` and no
    /// channels (logged as "Cancelled").
    fn dispatch(
        &self,
        cx: &Context,
        payload: Vec<u8>,
        registry: &mut ChannelRegistry,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send + 'static>> {
        let task_tx = registry.driver_tx();
        let conn_id = cx.conn_id;
        let request_id = cx.request_id.raw();
        match self.upstream.get() {
            Some(bound) => {
                let forwarder = ForwardingDispatcher::new(bound.clone());
                forwarder.dispatch(cx, payload, registry)
            }
            None => {
                debug!(
                    method_id = cx.method_id.raw(),
                    "LateBoundForwarder: upstream not bound, returning Cancelled"
                );
                Box::pin(async move {
                    let rejection = DriverMessage::Response {
                        conn_id,
                        request_id,
                        channels: vec![],
                        payload: vec![1, 3],
                    };
                    let _ = task_tx.send(rejection).await;
                })
            }
        }
    }
}
/// Unit type used as the default user-error parameter of [`CallError`]
/// (`CallError<E = Never>`).
#[derive(Debug, Clone, PartialEq, Eq, Facet)]
pub struct Never;
/// Error half of a call result, generic over the user-defined error type `E`.
///
/// The explicit `u8` discriminants are the same values `decode_response`
/// matches on in the second byte of an error payload.
#[repr(u8)]
#[derive(Debug, Clone, PartialEq, Eq, Facet)]
pub enum RoamError<E> {
// Error value produced by the user's handler.
User(E) = 0,
// The requested method is not known to the peer.
UnknownMethod = 1,
// The request payload was rejected as invalid.
InvalidPayload = 2,
// The call was cancelled before producing a result.
Cancelled = 3,
}
impl<E> RoamError<E> {
    /// Converts the user error type with `f`, leaving protocol-level variants untouched.
    ///
    /// `f` is only invoked for [`RoamError::User`].
    pub fn map_user<F, E2>(self, f: F) -> RoamError<E2>
    where
        F: FnOnce(E) -> E2,
    {
        match self {
            Self::User(user_err) => {
                let mapped = f(user_err);
                RoamError::User(mapped)
            }
            Self::UnknownMethod => RoamError::UnknownMethod,
            Self::InvalidPayload => RoamError::InvalidPayload,
            Self::Cancelled => RoamError::Cancelled,
        }
    }
}
/// Result of a call: `T` on success, [`RoamError<E>`] on failure.
pub type CallResult<T, E> = ::core::result::Result<T, RoamError<E>>;
/// A [`CallResult`] wrapped in `roam_frame`'s `OwnedMessage`.
pub type BorrowedCallResult<T, E> = OwnedMessage<CallResult<T, E>>;
/// Everything that can go wrong with a typed call, from the caller's point of view.
///
/// `E` is the user-defined error type; it defaults to [`Never`].
#[derive(Debug)]
pub enum CallError<E = Never> {
// The peer answered with an error result.
Roam(RoamError<E>),
// Serializing the request arguments failed.
Encode(facet_postcard::SerializeError),
// Deserializing the response body (or user error) failed.
Decode(facet_postcard::DeserializeError<facet_postcard::PostcardError>),
// The response bytes did not follow the expected result framing.
Protocol(DecodeError),
// Converted from `TransportError::ConnectionClosed`.
ConnectionClosed,
// Converted from `TransportError::DriverGone`.
DriverGone,
}
impl<E> CallError<E> {
    /// Converts the user error type with `f`.
    ///
    /// Only the [`CallError::Roam`] variant can carry a user error; every other
    /// variant is passed through unchanged.
    pub fn map_user<F, E2>(self, f: F) -> CallError<E2>
    where
        F: FnOnce(E) -> E2,
    {
        match self {
            Self::Roam(inner) => CallError::Roam(inner.map_user(f)),
            Self::Encode(err) => CallError::Encode(err),
            Self::Decode(err) => CallError::Decode(err),
            Self::Protocol(err) => CallError::Protocol(err),
            Self::ConnectionClosed => CallError::ConnectionClosed,
            Self::DriverGone => CallError::DriverGone,
        }
    }
}
impl<E: std::fmt::Debug> std::fmt::Display for CallError<E> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
CallError::Roam(e) => write!(f, "roam error: {e:?}"),
CallError::Encode(e) => write!(f, "encode error: {e}"),
CallError::Decode(e) => write!(f, "decode error: {e}"),
CallError::Protocol(e) => write!(f, "protocol error: {e}"),
CallError::ConnectionClosed => write!(f, "connection closed"),
CallError::DriverGone => write!(f, "driver task stopped"),
}
}
}
// Marker impl so `CallError` can be used as a standard error; no `source()` is provided.
impl<E: std::fmt::Debug> std::error::Error for CallError<E> {}
/// Failure to get a raw response back for a call, before any response decoding.
#[derive(Debug)]
pub enum TransportError {
// Serializing the request arguments failed.
Encode(facet_postcard::SerializeError),
// The driver delivered an error instead of a response.
ConnectionClosed,
// The driver channel was closed or the response sender was dropped.
DriverGone,
}
impl<E> From<TransportError> for CallError<E> {
    /// Maps each transport failure onto the `CallError` variant of the same name.
    fn from(transport_err: TransportError) -> Self {
        match transport_err {
            TransportError::Encode(inner) => Self::Encode(inner),
            TransportError::ConnectionClosed => Self::ConnectionClosed,
            TransportError::DriverGone => Self::DriverGone,
        }
    }
}
impl std::fmt::Display for TransportError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TransportError::Encode(e) => write!(f, "encode error: {e}"),
TransportError::ConnectionClosed => write!(f, "connection closed"),
TransportError::DriverGone => write!(f, "driver task stopped"),
}
}
}
// Marker impl so `TransportError` can be used as a standard error; no `source()` is provided.
impl std::error::Error for TransportError {}
/// Ways a response payload can fail to follow the result framing read by `decode_response`.
#[derive(Debug)]
pub enum DecodeError {
// The payload had no bytes at all.
EmptyPayload,
// The payload announced an error but had no error-discriminant byte.
TruncatedError,
// The error-discriminant byte is not one of the `RoamError` discriminants.
UnknownRoamErrorDiscriminant(u8),
// The first byte is neither 0 (ok) nor 1 (error).
InvalidResultDiscriminant(u8),
// Deserialization failure reported by `facet_postcard`.
Postcard(facet_postcard::DeserializeError<facet_postcard::PostcardError>),
}
impl std::fmt::Display for DecodeError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
DecodeError::EmptyPayload => write!(f, "empty response payload"),
DecodeError::TruncatedError => write!(f, "truncated error response"),
DecodeError::UnknownRoamErrorDiscriminant(d) => {
write!(f, "unknown RoamError discriminant: {d}")
}
DecodeError::InvalidResultDiscriminant(d) => {
write!(f, "invalid Result discriminant: {d}")
}
DecodeError::Postcard(e) => write!(f, "postcard: {e}"),
}
}
}
// Marker impl so `DecodeError` can be used as a standard error; no `source()` is provided.
impl std::error::Error for DecodeError {}
impl<E> From<DecodeError> for CallError<E> {
    /// Postcard failures become [`CallError::Decode`]; every other framing
    /// failure is wrapped in [`CallError::Protocol`].
    fn from(decode_err: DecodeError) -> Self {
        match decode_err {
            DecodeError::Postcard(inner) => Self::Decode(inner),
            framing_err => Self::Protocol(framing_err),
        }
    }
}
/// Decodes a raw response payload into `T` or a [`CallError`].
///
/// Framing, as read here:
/// - byte 0 is `0`: success, the remaining bytes are the postcard-encoded `T`;
/// - byte 0 is `1`: error, byte 1 is the [`RoamError`] discriminant, and for the
///   `User` discriminant (`0`) the remaining bytes are the postcard-encoded `E`.
///
/// # Errors
///
/// Returns `CallError::Roam` for a well-formed error response,
/// `CallError::Decode` when postcard deserialization fails, and
/// `CallError::Protocol` for an empty, truncated, or otherwise malformed payload.
pub fn decode_response<T: Facet<'static>, E: Facet<'static>>(
    payload: &[u8],
) -> Result<T, CallError<E>> {
    match payload {
        [] => Err(DecodeError::EmptyPayload.into()),
        [0, body @ ..] => facet_postcard::from_slice(body).map_err(CallError::Decode),
        // An error marker with no discriminant byte after it.
        [1] => Err(DecodeError::TruncatedError.into()),
        [1, 0, body @ ..] => {
            let user_error: E = facet_postcard::from_slice(body).map_err(CallError::Decode)?;
            Err(CallError::Roam(RoamError::User(user_error)))
        }
        // Any bytes after a payload-less error discriminant are ignored.
        [1, 1, ..] => Err(CallError::Roam(RoamError::UnknownMethod)),
        [1, 2, ..] => Err(CallError::Roam(RoamError::InvalidPayload)),
        [1, 3, ..] => Err(CallError::Roam(RoamError::Cancelled)),
        [1, unknown, ..] => Err(DecodeError::UnknownRoamErrorDiscriminant(*unknown).into()),
        [unknown, ..] => Err(DecodeError::InvalidResultDiscriminant(*unknown).into()),
    }
}
/// Abstraction over anything that can issue calls and bind response streams.
///
/// Each calling method is declared twice: the non-wasm version requires the
/// returned future to be `Send`, the `wasm32` version does not.
#[allow(async_fn_in_trait)]
pub trait Caller: Clone + Send + Sync + 'static {
/// Issues a call with default (empty) metadata. Non-wasm: the future is `Send`.
#[cfg(not(target_arch = "wasm32"))]
fn call<T: Facet<'static> + Send>(
&self,
method_id: u64,
args: &mut T,
) -> impl std::future::Future<Output = Result<ResponseData, TransportError>> + Send {
self.call_with_metadata(method_id, args, roam_wire::Metadata::default())
}
/// Issues a call with default (empty) metadata. wasm32: no `Send` bound on the future.
#[cfg(target_arch = "wasm32")]
fn call<T: Facet<'static> + Send>(
&self,
method_id: u64,
args: &mut T,
) -> impl std::future::Future<Output = Result<ResponseData, TransportError>> {
self.call_with_metadata(method_id, args, roam_wire::Metadata::default())
}
/// Issues a call to `method_id` with `args` and the given metadata, returning
/// the raw response data. Non-wasm: the future is `Send`.
#[cfg(not(target_arch = "wasm32"))]
fn call_with_metadata<T: Facet<'static> + Send>(
&self,
method_id: u64,
args: &mut T,
metadata: roam_wire::Metadata,
) -> impl std::future::Future<Output = Result<ResponseData, TransportError>> + Send;
/// Issues a call to `method_id` with `args` and the given metadata, returning
/// the raw response data. wasm32: no `Send` bound on the future.
#[cfg(target_arch = "wasm32")]
fn call_with_metadata<T: Facet<'static> + Send>(
&self,
method_id: u64,
args: &mut T,
metadata: roam_wire::Metadata,
) -> impl std::future::Future<Output = Result<ResponseData, TransportError>>;
/// Binds any streams found in a decoded `response` using the channel IDs
/// that accompanied the response.
fn bind_response_streams<T: Facet<'static>>(&self, response: &mut T, channels: &[u64]);
}
/// `Caller` implemented by delegating to the inherent methods of the same name.
impl Caller for ConnectionHandle {
async fn call_with_metadata<T: Facet<'static> + Send>(
&self,
method_id: u64,
args: &mut T,
metadata: roam_wire::Metadata,
) -> Result<ResponseData, TransportError> {
// The `ConnectionHandle::` path resolves to the inherent method (inherent
// items take priority over trait items), so this does not recurse.
ConnectionHandle::call_with_metadata(self, method_id, args, metadata).await
}
fn bind_response_streams<T: Facet<'static>>(&self, response: &mut T, channels: &[u64]) {
// Same as above: delegates to the inherent method.
ConnectionHandle::bind_response_streams(self, response, channels)
}
}
/// A prepared call that runs when awaited (via `IntoFuture`).
///
/// `Ok` and `Err` are the success and user-error types the response is decoded
/// into; they appear only in `PhantomData`.
pub struct CallFuture<C, Args, Ok, Err>
where
C: Caller,
Args: Facet<'static>,
{
// Caller used to send the request and bind response streams.
caller: C,
// Method to invoke.
method_id: u64,
// Call arguments, owned until the call is issued.
args: Args,
// Metadata sent with the call; starts as the default and is replaced by `with_metadata`.
metadata: roam_wire::Metadata,
// Carries the `Ok`/`Err` type parameters without storing values of them.
_phantom: PhantomData<fn() -> (Ok, Err)>,
}
impl<C, Args, Ok, Err> CallFuture<C, Args, Ok, Err>
where
    C: Caller,
    Args: Facet<'static>,
{
    /// Prepares a call to `method_id` with `args` and default metadata.
    pub fn new(caller: C, method_id: u64, args: Args) -> Self {
        let metadata = roam_wire::Metadata::default();
        Self {
            caller,
            method_id,
            args,
            metadata,
            _phantom: PhantomData,
        }
    }

    /// Replaces the metadata that will be sent with the call.
    pub fn with_metadata(self, metadata: roam_wire::Metadata) -> Self {
        Self { metadata, ..self }
    }
}
/// Awaiting a `CallFuture` sends the call, decodes the response into
/// `Result<Ok, CallError<Err>>`, and binds any response streams (native targets:
/// the boxed future is `Send`).
#[cfg(not(target_arch = "wasm32"))]
impl<C, Args, Ok, Err> std::future::IntoFuture for CallFuture<C, Args, Ok, Err>
where
    C: Caller,
    Args: Facet<'static> + Send + 'static,
    Ok: Facet<'static> + Send + 'static,
    Err: Facet<'static> + Send + 'static,
{
    type Output = Result<Ok, CallError<Err>>;
    type IntoFuture = std::pin::Pin<Box<dyn std::future::Future<Output = Self::Output> + Send>>;

    fn into_future(self) -> Self::IntoFuture {
        let Self {
            caller,
            method_id,
            args: mut call_args,
            metadata,
            _phantom: _,
        } = self;
        Box::pin(async move {
            // Transport failures are converted into the matching `CallError` variants.
            let transport_result = caller
                .call_with_metadata(method_id, &mut call_args, metadata)
                .await;
            let response = transport_result.map_err(CallError::from)?;
            let mut decoded = decode_response::<Ok, Err>(&response.payload)?;
            caller.bind_response_streams(&mut decoded, &response.channels);
            Ok(decoded)
        })
    }
}
/// Awaiting a `CallFuture` sends the call, decodes the response into
/// `Result<Ok, CallError<Err>>`, and binds any response streams (wasm32: the
/// boxed future carries no `Send` bound).
#[cfg(target_arch = "wasm32")]
impl<C, Args, Ok, Err> std::future::IntoFuture for CallFuture<C, Args, Ok, Err>
where
    C: Caller,
    Args: Facet<'static> + Send + 'static,
    Ok: Facet<'static> + Send + 'static,
    Err: Facet<'static> + Send + 'static,
{
    type Output = Result<Ok, CallError<Err>>;
    type IntoFuture = std::pin::Pin<Box<dyn std::future::Future<Output = Self::Output>>>;

    fn into_future(self) -> Self::IntoFuture {
        let Self {
            caller,
            method_id,
            args: mut call_args,
            metadata,
            _phantom: _,
        } = self;
        Box::pin(async move {
            // Transport failures are converted into the matching `CallError` variants.
            let transport_result = caller
                .call_with_metadata(method_id, &mut call_args, metadata)
                .await;
            let response = transport_result.map_err(CallError::from)?;
            let mut decoded = decode_response::<Ok, Err>(&response.payload)?;
            caller.bind_response_streams(&mut decoded, &response.channels);
            Ok(decoded)
        })
    }
}
/// State shared by every clone of a [`ConnectionHandle`].
struct HandleShared {
// Connection this handle issues calls on.
conn_id: roam_wire::ConnectionId,
// Channel used to send messages to the driver.
driver_tx: Sender<DriverMessage>,
// Source of request IDs for outgoing calls.
request_ids: RequestIdGenerator,
// Source of channel IDs; the starting value depends on the connection role.
channel_ids: ChannelIdAllocator,
// Registry of stream channels, guarded by a blocking (std) mutex.
channel_registry: std::sync::Mutex<ChannelRegistry>,
// Optional diagnostics sink for requests and channels.
diagnostic_state: Option<Arc<crate::diagnostic::DiagnosticState>>,
}
/// Cloneable handle for issuing calls and managing stream channels on a connection.
///
/// Cloning only clones the `Arc`; all clones share the same [`HandleShared`] state.
#[derive(Clone)]
pub struct ConnectionHandle {
shared: Arc<HandleShared>,
}
impl ConnectionHandle {
/// Creates a handle for the root connection (`ConnectionId::ROOT`) with no diagnostics.
///
/// `role` selects the channel-ID sequence and `initial_credit` is passed to the
/// channel registry.
pub fn new(driver_tx: Sender<DriverMessage>, role: Role, initial_credit: u32) -> Self {
Self::new_with_diagnostics(
roam_wire::ConnectionId::ROOT,
driver_tx,
role,
initial_credit,
None,
)
}
/// Creates a handle for `conn_id`, optionally reporting to `diagnostic_state`.
///
/// The channel registry receives its own clone of `driver_tx` together with
/// `initial_credit`; `role` selects the channel-ID sequence.
pub fn new_with_diagnostics(
    conn_id: roam_wire::ConnectionId,
    driver_tx: Sender<DriverMessage>,
    role: Role,
    initial_credit: u32,
    diagnostic_state: Option<Arc<crate::diagnostic::DiagnosticState>>,
) -> Self {
    let registry = ChannelRegistry::new_with_credit(initial_credit, driver_tx.clone());
    let shared = HandleShared {
        conn_id,
        driver_tx,
        request_ids: RequestIdGenerator::new(),
        channel_ids: ChannelIdAllocator::new(role),
        channel_registry: std::sync::Mutex::new(registry),
        diagnostic_state,
    };
    Self {
        shared: Arc::new(shared),
    }
}
/// Returns the ID of the connection this handle belongs to.
pub fn conn_id(&self) -> roam_wire::ConnectionId {
self.shared.conn_id
}
/// Returns the diagnostics sink, if one was supplied at construction.
pub fn diagnostic_state(&self) -> Option<&Arc<crate::diagnostic::DiagnosticState>> {
self.shared.diagnostic_state.as_ref()
}
/// Calls `method_id` with `args` and default metadata.
///
/// See `call_with_metadata` for how streams inside `args` are handled.
pub async fn call<T: Facet<'static>>(
&self,
method_id: u64,
args: &mut T,
) -> Result<ResponseData, TransportError> {
self.call_with_metadata(method_id, args, roam_wire::Metadata::default())
.await
}
/// Calls `method_id` with `args` and `metadata`, returning the raw response.
///
/// Before encoding, every `Rx`/`Tx` found in `args` is bound to a freshly
/// allocated channel ID. Receivers taken out of `Rx` values are "drains": once
/// the call message has been handed to the driver, one task per drain is spawned
/// that forwards each received payload as `DriverMessage::Data` and sends
/// `DriverMessage::Close` when the receiver ends.
///
/// # Errors
///
/// - `TransportError::Encode` if `args` cannot be serialized.
/// - `TransportError::DriverGone` if the driver channel is closed or the
///   response sender is dropped.
/// - `TransportError::ConnectionClosed` if the driver reports an error instead
///   of a response.
///
/// When diagnostics are enabled, a recorded request is marked complete on every
/// exit path after it was recorded, including the driver-gone paths.
pub async fn call_with_metadata<T: Facet<'static>>(
    &self,
    method_id: u64,
    args: &mut T,
    metadata: roam_wire::Metadata,
) -> Result<ResponseData, TransportError> {
    let mut drains = Vec::new();
    debug!("ConnectionHandle::call: binding streams");
    self.bind_streams(args, &mut drains);
    // Channel IDs are collected after binding so they reflect the new IDs.
    let channels = collect_channel_ids(args);
    debug!(
        channels = ?channels,
        drain_count = drains.len(),
        "ConnectionHandle::call: collected channels after bind_streams"
    );
    let payload = facet_postcard::to_vec(args).map_err(TransportError::Encode)?;
    // Pretty-printing the arguments is only done when diagnostics debugging is on.
    let args_debug = if diagnostic::debug_enabled() {
        Some(
            facet_pretty::PrettyPrinter::new()
                .with_colors(facet_pretty::ColorMode::Never)
                .with_max_content_len(64)
                .format(args),
        )
    } else {
        None
    };
    if drains.is_empty() {
        self.call_raw_with_channels_and_metadata(
            method_id, channels, payload, args_debug, metadata,
        )
        .await
    } else {
        let request_id = self.shared.request_ids.next();
        let (response_tx, response_rx) = oneshot();
        if let Some(diag) = &self.shared.diagnostic_state {
            let args = args_debug.map(|s| {
                let mut map = std::collections::HashMap::new();
                map.insert("args".to_string(), s);
                map
            });
            diag.record_outgoing_request(request_id, method_id, args);
            diag.associate_channels_with_request(&channels, request_id);
        }
        let msg = DriverMessage::Call {
            conn_id: self.shared.conn_id,
            request_id,
            method_id,
            metadata,
            channels,
            payload,
            response_tx,
        };
        if self.shared.driver_tx.send(msg).await.is_err() {
            // The request was already recorded above; close it out so it does
            // not linger as in-flight in the diagnostics.
            if let Some(diag) = &self.shared.diagnostic_state {
                diag.complete_request(request_id);
            }
            return Err(TransportError::DriverGone);
        }
        // Drain tasks are started only after the call message has been queued.
        let task_tx = self.shared.channel_registry.lock().unwrap().driver_tx();
        let conn_id = self.shared.conn_id;
        for (channel_id, mut rx) in drains {
            let task_tx = task_tx.clone();
            crate::runtime::spawn(async move {
                loop {
                    match rx.recv().await {
                        Some(payload) => {
                            debug!(
                                "drain task: received {} bytes on channel {}",
                                payload.len(),
                                channel_id
                            );
                            let _ = task_tx
                                .send(DriverMessage::Data {
                                    conn_id,
                                    channel_id,
                                    payload,
                                })
                                .await;
                            debug!(
                                "drain task: sent DriverMessage::Data for channel {}",
                                channel_id
                            );
                        }
                        None => {
                            debug!("drain task: channel {} closed", channel_id);
                            let _ = task_tx
                                .send(DriverMessage::Close {
                                    conn_id,
                                    channel_id,
                                })
                                .await;
                            debug!(
                                "drain task: sent DriverMessage::Close for channel {}",
                                channel_id
                            );
                            break;
                        }
                    }
                }
            });
        }
        // Resolve both failure layers without an early return so the
        // diagnostics bookkeeping below always runs.
        let result = match response_rx.await {
            Ok(outcome) => outcome.map_err(|_| TransportError::ConnectionClosed),
            Err(_) => Err(TransportError::DriverGone),
        };
        if let Some(diag) = &self.shared.diagnostic_state {
            diag.complete_request(request_id);
        }
        result
    }
}
/// Binds every `Rx`/`Tx` found in `args` to a new channel ID.
///
/// Receivers taken from `Rx` values are appended to `drains` together with
/// their channel ID.
fn bind_streams<T: Facet<'static>>(
&self,
args: &mut T,
drains: &mut Vec<(ChannelId, Receiver<Vec<u8>>)>,
) {
let poke = facet::Poke::new(args);
self.bind_streams_recursive(poke, drains);
}
/// Walks a value through `facet` reflection and binds each `Rx`/`Tx` it finds.
///
/// Stream types are recognised by shape: module path `roam_session` and type
/// identifier `Rx` or `Tx`. Other values are descended into: all struct fields,
/// the payload of an `Option`, each list element, and field 0 of an enum's
/// active variant.
// `unsafe` is needed for the list branch, which goes through the raw list vtable.
#[allow(unsafe_code)]
fn bind_streams_recursive(
&self,
mut poke: facet::Poke<'_, '_>,
drains: &mut Vec<(ChannelId, Receiver<Vec<u8>>)>,
) {
use facet::Def;
let shape = poke.shape();
// Stream types are handled here and never descended into.
if shape.module_path == Some("roam_session") {
if shape.type_identifier == "Rx" {
self.bind_rx_stream(poke, drains);
return;
} else if shape.type_identifier == "Tx" {
self.bind_tx_stream(poke);
return;
}
}
// Arm order matters: the guarded struct arm is tried before the
// `Option`/`List` definition arms, and the guarded enum arm after them.
match shape.def {
Def::Scalar => {}
_ if poke.is_struct() => {
let mut ps = poke.into_struct().expect("is_struct was true");
let field_count = ps.field_count();
for i in 0..field_count {
// Fields that cannot be accessed are skipped silently.
if let Ok(field_poke) = ps.field(i) {
self.bind_streams_recursive(field_poke, drains);
}
}
}
// `Option` is walked as an enum; field 0 is only visited when `field(0)` yields a value.
Def::Option(_) => {
if let Ok(mut pe) = poke.into_enum()
&& let Ok(Some(inner_poke)) = pe.field(0)
{
self.bind_streams_recursive(inner_poke, drains);
}
}
Def::List(list_def) => {
// The length is read through a shared peek before taking the mutable data pointer.
let len = {
let peek = poke.as_peek();
peek.into_list().map(|pl| pl.len()).unwrap_or(0)
};
// Lists without a `get_mut` vtable entry are not descended into.
if let Some(get_mut_fn) = list_def.vtable.get_mut {
let element_shape = list_def.t;
let data_ptr = poke.data_mut();
for i in 0..len {
// SAFETY (assumed, TODO confirm against facet's list vtable contract):
// `data_ptr` points at a list of `element_shape` elements and `i < len`.
let element_ptr = unsafe { (get_mut_fn)(data_ptr, i, element_shape) };
if let Some(ptr) = element_ptr {
// SAFETY (assumed, TODO confirm): `ptr` was returned by the vtable for
// an element of shape `element_shape`.
let element_poke =
unsafe { facet::Poke::from_raw_parts(ptr, element_shape) };
self.bind_streams_recursive(element_poke, drains);
}
}
}
}
// NOTE(review): only field 0 of the active variant is visited; streams in
// later fields of a multi-field variant would not be bound — confirm intent.
_ if poke.is_enum() => {
if let Ok(mut pe) = poke.into_enum()
&& let Ok(Some(variant_poke)) = pe.field(0)
{
self.bind_streams_recursive(variant_poke, drains);
}
}
_ => {}
}
}
/// Binds an `Rx` found in outgoing arguments.
///
/// Allocates a new channel ID, writes it into the value's `channel_id` field,
/// and moves the receiver out of its `receiver` slot into `drains`. If the slot
/// is empty or a field cannot be accessed, that step is skipped silently.
fn bind_rx_stream(
&self,
poke: facet::Poke<'_, '_>,
drains: &mut Vec<(ChannelId, Receiver<Vec<u8>>)>,
) {
// The ID is allocated even if the fields below turn out to be inaccessible.
let channel_id = self.alloc_channel_id();
debug!(
channel_id,
"OutgoingBinder::bind_rx_stream: allocated channel_id for Rx"
);
if let Ok(mut ps) = poke.into_struct() {
// Overwrite whatever channel ID the value carried before.
if let Ok(mut channel_id_field) = ps.field_by_name("channel_id")
&& let Ok(id_ref) = channel_id_field.get_mut::<ChannelId>()
{
debug!(
old_id = *id_ref,
new_id = channel_id,
"OutgoingBinder::bind_rx_stream: overwriting channel_id"
);
*id_ref = channel_id;
}
// Take the receiver so the caller can forward its data to the driver.
if let Ok(mut receiver_field) = ps.field_by_name("receiver")
&& let Ok(slot) = receiver_field.get_mut::<ReceiverSlot>()
&& let Some(rx) = slot.take()
{
debug!(
channel_id,
"OutgoingBinder::bind_rx_stream: took receiver, adding to drains"
);
drains.push((channel_id, rx));
}
}
}
/// Binds a `Tx` found in outgoing arguments.
///
/// Allocates a new channel ID, writes it into the value's `channel_id` field,
/// and moves the sender out of its `sender` slot, registering it so incoming
/// data for that channel is routed to it. If the slot is empty or a field cannot
/// be accessed, that step is skipped silently.
fn bind_tx_stream(&self, poke: facet::Poke<'_, '_>) {
// The ID is allocated even if the fields below turn out to be inaccessible.
let channel_id = self.alloc_channel_id();
debug!(
channel_id,
"OutgoingBinder::bind_tx_stream: allocated channel_id for Tx"
);
if let Ok(mut ps) = poke.into_struct() {
// Overwrite whatever channel ID the value carried before.
if let Ok(mut channel_id_field) = ps.field_by_name("channel_id")
&& let Ok(id_ref) = channel_id_field.get_mut::<ChannelId>()
{
debug!(
old_id = *id_ref,
new_id = channel_id,
"OutgoingBinder::bind_tx_stream: overwriting channel_id"
);
*id_ref = channel_id;
}
// Take the sender and register it as the destination for incoming data.
if let Ok(mut sender_field) = ps.field_by_name("sender")
&& let Ok(slot) = sender_field.get_mut::<SenderSlot>()
&& let Some(tx) = slot.take()
{
debug!(
channel_id,
"OutgoingBinder::bind_tx_stream: took sender, registering for incoming"
);
self.register_incoming(channel_id, tx);
}
}
}
/// Calls `method_id` with an already-encoded `payload`, no metadata and no
/// channels, returning only the response payload bytes.
pub async fn call_raw(
&self,
method_id: u64,
payload: Vec<u8>,
) -> Result<Vec<u8>, TransportError> {
self.call_raw_full(method_id, Vec::new(), Vec::new(), payload, None)
.await
.map(|r| r.payload)
}
/// Calls `method_id` with an already-encoded `payload` and the given channel
/// IDs, using empty metadata. `args_debug` is only used for diagnostics.
// NOTE(review): no caller is visible in this part of the file.
async fn call_raw_with_channels(
&self,
method_id: u64,
channels: Vec<u64>,
payload: Vec<u8>,
args_debug: Option<String>,
) -> Result<ResponseData, TransportError> {
self.call_raw_full(method_id, Vec::new(), channels, payload, args_debug)
.await
}
/// Calls `method_id` with an already-encoded `payload`, the given channel IDs
/// and `metadata`. `args_debug` is only used for diagnostics.
///
/// This only reorders the arguments for `call_raw_full`.
async fn call_raw_with_channels_and_metadata(
&self,
method_id: u64,
channels: Vec<u64>,
payload: Vec<u8>,
args_debug: Option<String>,
metadata: roam_wire::Metadata,
) -> Result<ResponseData, TransportError> {
self.call_raw_full(method_id, metadata, channels, payload, args_debug)
.await
}
/// Calls `method_id` with an already-encoded `payload` and `metadata`, without
/// channels, returning only the response payload bytes.
pub async fn call_raw_with_metadata(
&self,
method_id: u64,
payload: Vec<u8>,
metadata: Vec<(String, roam_wire::MetadataValue)>,
) -> Result<Vec<u8>, TransportError> {
self.call_raw_full(method_id, metadata, Vec::new(), payload, None)
.await
.map(|r| r.payload)
}
/// Sends a fully specified call to the driver and waits for its response.
///
/// `args_debug`, when present, is recorded in diagnostics under the key `"args"`.
///
/// # Errors
///
/// - `TransportError::DriverGone` if the driver channel is closed or the
///   response sender is dropped.
/// - `TransportError::ConnectionClosed` if the driver reports an error instead
///   of a response.
///
/// When diagnostics are enabled, the recorded request is marked complete on
/// every exit path, including the driver-gone paths.
async fn call_raw_full(
    &self,
    method_id: u64,
    metadata: Vec<(String, roam_wire::MetadataValue)>,
    channels: Vec<u64>,
    payload: Vec<u8>,
    args_debug: Option<String>,
) -> Result<ResponseData, TransportError> {
    let request_id = self.shared.request_ids.next();
    let (response_tx, response_rx) = oneshot();
    if let Some(diag) = &self.shared.diagnostic_state {
        let args = args_debug.map(|s| {
            let mut map = std::collections::HashMap::new();
            map.insert("args".to_string(), s);
            map
        });
        diag.record_outgoing_request(request_id, method_id, args);
        diag.associate_channels_with_request(&channels, request_id);
    }
    let msg = DriverMessage::Call {
        conn_id: self.shared.conn_id,
        request_id,
        method_id,
        metadata,
        channels,
        payload,
        response_tx,
    };
    // No `?` below: every outcome must reach the diagnostics bookkeeping so a
    // recorded request is never left in-flight.
    let result = if self.shared.driver_tx.send(msg).await.is_err() {
        Err(TransportError::DriverGone)
    } else {
        match response_rx.await {
            Ok(outcome) => outcome.map_err(|_| TransportError::ConnectionClosed),
            Err(_) => Err(TransportError::DriverGone),
        }
    };
    if let Some(diag) = &self.shared.diagnostic_state {
        diag.complete_request(request_id);
    }
    result
}
pub async fn connect(
&self,
metadata: roam_wire::Metadata,
dispatcher: Option<Box<dyn ServiceDispatcher>>,
) -> Result<ConnectionHandle, crate::ConnectError> {
let request_id = self.shared.request_ids.next();
let (response_tx, response_rx) = oneshot();
let msg = DriverMessage::Connect {
request_id,
metadata,
response_tx,
dispatcher,
};
self.shared.driver_tx.send(msg).await.map_err(|_| {
crate::ConnectError::ConnectFailed(std::io::Error::other("driver gone"))
})?;
response_rx
.await
.map_err(|_| crate::ConnectError::ConnectFailed(std::io::Error::other("driver gone")))?
}
/// Allocates the next channel ID from this handle's allocator.
pub fn alloc_channel_id(&self) -> ChannelId {
self.shared.channel_ids.next()
}
/// Allocates the next request ID from this handle's generator.
pub fn alloc_request_id(&self) -> u64 {
self.shared.request_ids.next()
}
/// Registers `tx` as the destination for data arriving on `channel_id`.
///
/// With diagnostics enabled, the channel is first recorded as opened in the
/// `Rx` direction.
///
/// # Panics
///
/// Panics if the channel-registry mutex is poisoned.
pub fn register_incoming(&self, channel_id: ChannelId, tx: Sender<Vec<u8>>) {
    if let Some(diag) = self.diagnostic_state() {
        let direction = crate::diagnostic::ChannelDirection::Rx;
        diag.record_channel_open(channel_id, direction, None);
    }
    let registry = &self.shared.channel_registry;
    registry.lock().unwrap().register_incoming(channel_id, tx);
}
/// Registers outgoing credit tracking for `channel_id` in the channel registry.
///
/// With diagnostics enabled, the channel is first recorded as opened in the
/// `Tx` direction.
///
/// # Panics
///
/// Panics if the channel-registry mutex is poisoned.
pub fn register_outgoing_credit(&self, channel_id: ChannelId) {
    if let Some(diag) = self.diagnostic_state() {
        let direction = crate::diagnostic::ChannelDirection::Tx;
        diag.record_channel_open(channel_id, direction, None);
    }
    let registry = &self.shared.channel_registry;
    registry.lock().unwrap().register_outgoing_credit(channel_id);
}
/// Routes `payload` to the sender registered for `channel_id`.
///
/// # Errors
///
/// Returns the `ChannelError` reported by the registry's `prepare_route_data`.
/// A failed send to the registered sender is ignored.
///
/// # Panics
///
/// Panics if the channel-registry mutex is poisoned.
pub async fn route_data(
&self,
channel_id: ChannelId,
payload: Vec<u8>,
) -> Result<(), ChannelError> {
// The mutex guard is a temporary of this `let` statement, so the lock is
// released before the `.await` below.
let (tx, payload) = self
.shared
.channel_registry
.lock()
.unwrap()
.prepare_route_data(channel_id, payload)?;
// Best effort: the send result is discarded.
let _ = tx.send(payload).await;
Ok(())
}
/// Closes `channel_id` in the channel registry.
///
/// With diagnostics enabled, the close is recorded first.
///
/// # Panics
///
/// Panics if the channel-registry mutex is poisoned.
pub fn close_channel(&self, channel_id: ChannelId) {
    if let Some(diag) = self.diagnostic_state() {
        diag.record_channel_close(channel_id);
    }
    let registry = &self.shared.channel_registry;
    registry.lock().unwrap().close(channel_id);
}
/// Resets `channel_id` in the channel registry.
///
/// With diagnostics enabled, the channel is recorded as closed first.
///
/// # Panics
///
/// Panics if the channel-registry mutex is poisoned.
pub fn reset_channel(&self, channel_id: ChannelId) {
    if let Some(diag) = self.diagnostic_state() {
        diag.record_channel_close(channel_id);
    }
    let registry = &self.shared.channel_registry;
    registry.lock().unwrap().reset(channel_id);
}
/// Reports whether the channel registry contains `channel_id`.
///
/// # Panics
///
/// Panics if the channel-registry mutex is poisoned.
pub fn contains_channel(&self, channel_id: ChannelId) -> bool {
    let registry = &self.shared.channel_registry;
    registry.lock().unwrap().contains(channel_id)
}
/// Passes a credit grant of `bytes` for `channel_id` to the channel registry.
///
/// # Panics
///
/// Panics if the channel-registry mutex is poisoned.
pub fn receive_credit(&self, channel_id: ChannelId, bytes: u32) {
    let registry = &self.shared.channel_registry;
    registry.lock().unwrap().receive_credit(channel_id, bytes);
}
/// Returns the driver sender held by the channel registry.
///
/// # Panics
///
/// Panics if the channel-registry mutex is poisoned.
pub fn driver_tx(&self) -> Sender<DriverMessage> {
self.shared.channel_registry.lock().unwrap().driver_tx()
}
/// Binds the streams in a decoded `response`.
///
/// First patches the response's channel IDs from `channels`, then walks the
/// value and wires up every `Rx` it contains.
pub fn bind_response_streams<T: Facet<'static>>(&self, response: &mut T, channels: &[u64]) {
patch_channel_ids(response, channels);
let poke = facet::Poke::new(response);
self.bind_response_streams_recursive(poke);
}
/// Walks a response value through `facet` reflection and binds each `Rx` it finds.
///
/// Only `Rx` (module path `roam_session`) is handled here; `Tx` values get no
/// special treatment and are descended into like any other value. Traversal
/// covers all struct fields, the payload of an `Option`, each list element, and
/// field 0 of an enum's active variant.
// `unsafe` is needed for the list branch, which goes through the raw list vtable.
#[allow(unsafe_code)]
fn bind_response_streams_recursive(&self, mut poke: facet::Poke<'_, '_>) {
use facet::Def;
let shape = poke.shape();
if shape.module_path == Some("roam_session") && shape.type_identifier == "Rx" {
self.bind_rx_response_stream(poke);
return;
}
// Arm order matters: the guarded struct arm is tried before the
// `Option`/`List` definition arms, and the guarded enum arm after them.
match shape.def {
Def::Scalar => {}
_ if poke.is_struct() => {
let mut ps = poke.into_struct().expect("is_struct was true");
let field_count = ps.field_count();
for i in 0..field_count {
// Fields that cannot be accessed are skipped silently.
if let Ok(field_poke) = ps.field(i) {
self.bind_response_streams_recursive(field_poke);
}
}
}
// `Option` is walked as an enum; field 0 is only visited when `field(0)` yields a value.
Def::Option(_) => {
if let Ok(mut pe) = poke.into_enum()
&& let Ok(Some(inner_poke)) = pe.field(0)
{
self.bind_response_streams_recursive(inner_poke);
}
}
Def::List(list_def) => {
// The length is read through a shared peek before taking the mutable data pointer.
let len = {
let peek = poke.as_peek();
peek.into_list().map(|pl| pl.len()).unwrap_or(0)
};
// Lists without a `get_mut` vtable entry are not descended into.
if let Some(get_mut_fn) = list_def.vtable.get_mut {
let element_shape = list_def.t;
let data_ptr = poke.data_mut();
for i in 0..len {
// SAFETY (assumed, TODO confirm against facet's list vtable contract):
// `data_ptr` points at a list of `element_shape` elements and `i < len`.
let element_ptr = unsafe { (get_mut_fn)(data_ptr, i, element_shape) };
if let Some(ptr) = element_ptr {
// SAFETY (assumed, TODO confirm): `ptr` was returned by the vtable for
// an element of shape `element_shape`.
let element_poke =
unsafe { facet::Poke::from_raw_parts(ptr, element_shape) };
self.bind_response_streams_recursive(element_poke);
}
}
}
}
// NOTE(review): only field 0 of the active variant is visited — confirm intent.
_ if poke.is_enum() => {
if let Ok(mut pe) = poke.into_enum()
&& let Ok(Some(variant_poke)) = pe.field(0)
{
self.bind_response_streams_recursive(variant_poke);
}
}
_ => {}
}
}
/// Wires up an `Rx` found in a response.
///
/// Reads the value's `channel_id`, creates a fresh channel of capacity
/// `RX_STREAM_BUFFER_SIZE`, stores the receiving end in the value's `receiver`
/// slot, and registers the sending end for incoming data on that channel ID.
/// Does nothing if the value is not a struct or its `channel_id` cannot be read.
fn bind_rx_response_stream(&self, poke: facet::Poke<'_, '_>) {
if let Ok(mut ps) = poke.into_struct() {
// Unlike the outgoing path, the existing channel ID is kept, not replaced.
let channel_id = if let Ok(channel_id_field) = ps.field_by_name("channel_id")
&& let Ok(id_ref) = channel_id_field.get::<ChannelId>()
{
*id_ref
} else {
return;
};
let (tx, rx) = crate::runtime::channel(RX_STREAM_BUFFER_SIZE);
if let Ok(mut receiver_field) = ps.field_by_name("receiver")
&& let Ok(slot) = receiver_field.get_mut::<ReceiverSlot>()
{
slot.set(rx);
}
// The sender is registered even if the receiver slot could not be set.
self.register_incoming(channel_id, tx);
}
}
}
/// Client-side error, generic over the transport's own error type `T`.
#[derive(Debug)]
pub enum ClientError<T> {
    /// Error reported by the underlying transport.
    Transport(T),
    /// Serialization of a value failed.
    Encode(facet_postcard::SerializeError),
    /// Deserialization of a value failed.
    Decode(facet_postcard::DeserializeError<facet_postcard::PostcardError>),
}
impl<TransportError> From<TransportError> for ClientError<TransportError> {
fn from(value: TransportError) -> Self {
Self::Transport(value)
}
}
/// Error raised while dispatching; currently only serialization failures.
#[derive(Debug)]
pub enum DispatchError {
// Serialization of a value failed.
Encode(facet_postcard::SerializeError),
}
#[cfg(not(target_arch = "wasm32"))]
use std::io;
#[cfg(not(target_arch = "wasm32"))]
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
#[cfg(not(target_arch = "wasm32"))]
use tokio::task::JoinHandle;
/// Default read-chunk size for tunnel pumps: 32 KiB.
#[cfg(not(target_arch = "wasm32"))]
pub const DEFAULT_TUNNEL_CHUNK_SIZE: usize = 32 * 1024;
/// One end of a bidirectional byte tunnel built from a `Tx`/`Rx` pair.
#[derive(Facet)]
pub struct Tunnel {
// Sends byte chunks towards the other end.
pub tx: Tx<Vec<u8>>,
// Receives byte chunks from the other end.
pub rx: Rx<Vec<u8>>,
}
/// Creates two connected tunnel ends.
///
/// Whatever is sent on one end's `tx` is received on the other end's `rx`.
pub fn tunnel_pair() -> (Tunnel, Tunnel) {
    let (first_tx, second_rx) = channel::<Vec<u8>>();
    let (second_tx, first_rx) = channel::<Vec<u8>>();
    let first = Tunnel {
        tx: first_tx,
        rx: first_rx,
    };
    let second = Tunnel {
        tx: second_tx,
        rx: second_rx,
    };
    (first, second)
}
/// Reads from `reader` in chunks of at most `chunk_size` bytes and sends each
/// chunk on `tx`.
///
/// Stops with `Ok(())` when the reader reaches end of input or when sending
/// fails.
///
/// # Errors
///
/// Returns `io::ErrorKind::InvalidInput` if `chunk_size` is zero, and any error
/// reported by `reader`.
#[cfg(not(target_arch = "wasm32"))]
pub async fn pump_read_to_tx<R: AsyncRead + Unpin>(
    mut reader: R,
    tx: Tx<Vec<u8>>,
    chunk_size: usize,
) -> io::Result<()> {
    // A zero-length buffer makes `read` return 0, which is indistinguishable
    // from end of input: the pump would stop silently without moving any data.
    if chunk_size == 0 {
        return Err(io::Error::new(
            io::ErrorKind::InvalidInput,
            "chunk_size must be greater than zero",
        ));
    }
    let mut buf = vec![0u8; chunk_size];
    loop {
        let n = reader.read(&mut buf).await?;
        if n == 0 {
            // End of input.
            break;
        }
        // Only the bytes actually read are sent; a send failure ends the pump
        // without reporting an error.
        if tx.send(&buf[..n].to_vec()).await.is_err() {
            break;
        }
    }
    Ok(())
}
/// Writes every chunk received on `rx` to `writer`, flushing once the stream ends.
///
/// # Errors
///
/// Returns any write or flush error from `writer`. A receive error is reported
/// as `io::ErrorKind::InvalidData`; in that case the writer is not flushed.
#[cfg(not(target_arch = "wasm32"))]
pub async fn pump_rx_to_write<W: AsyncWrite + Unpin>(
    mut rx: Rx<Vec<u8>>,
    mut writer: W,
) -> io::Result<()> {
    while let Some(chunk) = rx.recv().await.map_err(|e| {
        io::Error::new(
            io::ErrorKind::InvalidData,
            format!("tunnel receive error: {e}"),
        )
    })? {
        writer.write_all(&chunk).await?;
    }
    // The stream ended cleanly: push out anything still buffered.
    writer.flush().await?;
    Ok(())
}
/// Bridges an async byte stream and a [`Tunnel`] using two spawned tasks.
///
/// The first returned handle belongs to the task that reads from `stream` (in
/// chunks of `chunk_size`) and sends on the tunnel; the second belongs to the
/// task that receives from the tunnel and writes to `stream`.
#[cfg(not(target_arch = "wasm32"))]
pub fn tunnel_stream<S>(
    stream: S,
    tunnel: Tunnel,
    chunk_size: usize,
) -> (JoinHandle<io::Result<()>>, JoinHandle<io::Result<()>>)
where
    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    let (read_half, write_half) = tokio::io::split(stream);
    let Tunnel {
        tx: to_tunnel,
        rx: from_tunnel,
    } = tunnel;
    let reader_task =
        tokio::spawn(async move { pump_read_to_tx(read_half, to_tunnel, chunk_size).await });
    let writer_task =
        tokio::spawn(async move { pump_rx_to_write(from_tunnel, write_half).await });
    (reader_task, writer_task)
}
// Unit tests for channel-ID allocation, Tx/Rx plumbing, the channel registry,
// tunnels, and channel-ID collection.
#[cfg(test)]
mod tests {
use super::*;
// Initiators start at 1 and step by 2.
#[test]
fn channel_id_allocator_initiator_uses_odd_ids() {
let alloc = ChannelIdAllocator::new(Role::Initiator);
assert_eq!(alloc.next(), 1);
assert_eq!(alloc.next(), 3);
assert_eq!(alloc.next(), 5);
assert_eq!(alloc.next(), 7);
}
// Acceptors start at 2 and step by 2.
#[test]
fn channel_id_allocator_acceptor_uses_even_ids() {
let alloc = ChannelIdAllocator::new(Role::Acceptor);
assert_eq!(alloc.next(), 2);
assert_eq!(alloc.next(), 4);
assert_eq!(alloc.next(), 6);
assert_eq!(alloc.next(), 8);
}
// Values sent on a Tx arrive on the raw receiver as postcard-encoded bytes.
#[tokio::test]
async fn tx_serializes_and_rx_deserializes() {
let (tx, mut rx) = channel::<i32>();
let mut taken_rx = rx.receiver.take().expect("receiver should be present");
tx.send(&100).await.unwrap();
tx.send(&200).await.unwrap();
let bytes1 = taken_rx.recv().await.unwrap();
let val1: i32 = facet_postcard::from_slice(&bytes1).unwrap();
assert_eq!(val1, 100);
let bytes2 = taken_rx.recv().await.unwrap();
let val2: i32 = facet_postcard::from_slice(&bytes2).unwrap();
assert_eq!(val2, 200);
}
// Helper: a registry whose driver channel receiver is dropped immediately.
fn test_registry() -> ChannelRegistry {
let (task_tx, _task_rx) = crate::runtime::channel(10);
ChannelRegistry::new(task_tx)
}
// Routing data to a channel after it was closed yields `DataAfterClose`.
#[tokio::test]
async fn data_after_close_is_rejected() {
let mut registry = test_registry();
let (tx, _rx) = crate::runtime::channel(10);
registry.register_incoming(42, tx);
registry.close(42);
let result = registry.route_data(42, b"data".to_vec()).await;
assert_eq!(result, Err(ChannelError::DataAfterClose));
}
// Data reaches the registered receiver; an unregistered channel ID is an error.
#[tokio::test]
async fn channel_registry_routes_data_to_registered_stream() {
let mut registry = test_registry();
let (tx, mut rx) = crate::runtime::channel(10);
registry.register_incoming(42, tx);
assert!(registry.route_data(42, b"hello".to_vec()).await.is_ok());
assert_eq!(rx.recv().await, Some(b"hello".to_vec()));
assert!(registry.route_data(999, b"nope".to_vec()).await.is_err());
}
// Closing delivers already-routed data, then ends the stream and forgets the channel.
#[tokio::test]
async fn channel_registry_close_terminates_stream() {
let mut registry = test_registry();
let (tx, mut rx) = crate::runtime::channel(10);
registry.register_incoming(42, tx);
registry.route_data(42, b"data1".to_vec()).await.unwrap();
registry.close(42);
assert_eq!(rx.recv().await, Some(b"data1".to_vec()));
assert_eq!(rx.recv().await, None);
assert!(!registry.contains(42));
}
// The stream binders identify Tx/Rx by exactly this shape metadata.
#[test]
fn tx_rx_shape_metadata() {
use facet::Facet;
let tx_shape = <Tx<i32> as Facet>::SHAPE;
let rx_shape = <Rx<i32> as Facet>::SHAPE;
assert_eq!(tx_shape.module_path, Some("roam_session"));
assert_eq!(tx_shape.type_identifier, "Tx");
assert_eq!(rx_shape.module_path, Some("roam_session"));
assert_eq!(rx_shape.type_identifier, "Rx");
assert_eq!(tx_shape.type_params.len(), 1);
assert_eq!(rx_shape.type_params.len(), 1);
}
// Each end of a tunnel pair receives what the other end sends.
#[tokio::test]
async fn tunnel_pair_connects_bidirectionally() {
let (local, remote) = tunnel_pair();
local.tx.send(&b"hello".to_vec()).await.unwrap();
let mut remote_rx = remote.rx;
let received = remote_rx.recv().await.unwrap().unwrap();
assert_eq!(received, b"hello".to_vec());
remote.tx.send(&b"world".to_vec()).await.unwrap();
let mut local_rx = local.rx;
let received = local_rx.recv().await.unwrap().unwrap();
assert_eq!(received, b"world".to_vec());
}
// With a chunk size of 10, the concatenated chunks equal the input.
#[tokio::test]
async fn pump_read_to_tx_sends_chunks() {
use std::io::Cursor;
let data = b"hello world this is a test message";
let reader = Cursor::new(data.to_vec());
let (tx, mut rx) = channel::<Vec<u8>>();
let handle = tokio::spawn(async move { pump_read_to_tx(reader, tx, 10).await });
let mut received = Vec::new();
while let Ok(Some(chunk)) = rx.recv().await {
received.extend(chunk);
}
assert_eq!(received, data.to_vec());
handle.await.unwrap().unwrap();
}
// Chunks are written in order; dropping the Tx ends the pump.
#[tokio::test]
async fn pump_rx_to_write_writes_chunks() {
use std::io::Cursor;
let (tx, rx) = channel::<Vec<u8>>();
let writer = Cursor::new(Vec::new());
let handle = tokio::spawn(async move {
let mut writer = writer;
pump_rx_to_write(rx, &mut writer).await?;
Ok::<_, io::Error>(writer)
});
tx.send(&b"hello ".to_vec()).await.unwrap();
tx.send(&b"world".to_vec()).await.unwrap();
drop(tx);
let writer = handle.await.unwrap().unwrap();
assert_eq!(writer.into_inner(), b"hello world".to_vec());
}
// End-to-end: a duplex stream bridged to a tunnel in both directions.
#[tokio::test]
async fn tunnel_stream_bidirectional() {
let (client, server) = tokio::io::duplex(1024);
let (local, remote) = tunnel_pair();
let (client_read_handle, client_write_handle) =
tunnel_stream(client, local, DEFAULT_TUNNEL_CHUNK_SIZE);
tokio::spawn(async move {
remote.tx.send(&b"from tunnel".to_vec()).await.unwrap();
});
let mut server = server;
let mut buf = vec![0u8; 1024];
let n = tokio::io::AsyncReadExt::read(&mut server, &mut buf)
.await
.unwrap();
assert!(n > 0);
tokio::io::AsyncWriteExt::write_all(&mut server, b"to tunnel")
.await
.unwrap();
// Dropping the server side lets both pump tasks finish.
drop(server);
client_read_handle.await.unwrap().unwrap();
client_write_handle.await.unwrap().unwrap();
}
// An empty chunk is delivered as an empty chunk, not as end of stream.
#[tokio::test]
async fn tunnel_handles_empty_data() {
let (tx, mut rx) = channel::<Vec<u8>>();
tx.send(&Vec::new()).await.unwrap();
let received = rx.recv().await.unwrap().unwrap();
assert!(received.is_empty());
}
// Dropping one end's Tx makes the other end's Rx report end of stream.
#[tokio::test]
async fn tunnel_close_propagates() {
let (local, remote) = tunnel_pair();
drop(local.tx);
let mut rx = remote.rx;
let result = rx.recv().await;
assert!(matches!(result, Ok(None)));
}
// A bare Tx yields its own channel ID.
#[test]
fn collect_channel_ids_simple_tx() {
let tx: Tx<i32> = Tx::try_from(42u64).unwrap();
let ids = collect_channel_ids(&tx);
assert_eq!(ids, vec![42]);
}
// A bare Rx yields its own channel ID.
#[test]
fn collect_channel_ids_simple_rx() {
let rx: Rx<i32> = Rx::try_from(99u64).unwrap();
let ids = collect_channel_ids(&rx);
assert_eq!(ids, vec![99]);
}
// Tuple elements are visited in order.
#[test]
fn collect_channel_ids_tuple() {
let rx: Rx<String> = Rx::try_from(10u64).unwrap();
let tx: Tx<String> = Tx::try_from(20u64).unwrap();
let args = (rx, tx);
let ids = collect_channel_ids(&args);
assert_eq!(ids, vec![10, 20]);
}
// Struct fields are visited in declaration order; non-stream fields are ignored.
#[test]
fn collect_channel_ids_nested_in_struct() {
#[derive(facet::Facet)]
struct StreamArgs {
input: Rx<i32>,
output: Tx<i32>,
count: u32,
}
let args = StreamArgs {
input: Rx::try_from(100u64).unwrap(),
output: Tx::try_from(200u64).unwrap(),
count: 5,
};
let ids = collect_channel_ids(&args);
assert_eq!(ids, vec![100, 200]);
}
// `Some(stream)` contributes the stream's ID.
#[test]
fn collect_channel_ids_option_some() {
let tx: Tx<i32> = Tx::try_from(55u64).unwrap();
let args: Option<Tx<i32>> = Some(tx);
let ids = collect_channel_ids(&args);
assert_eq!(ids, vec![55]);
}
// `None` contributes nothing.
#[test]
fn collect_channel_ids_option_none() {
let args: Option<Tx<i32>> = None;
let ids = collect_channel_ids(&args);
assert!(ids.is_empty());
}
// List elements are visited in order.
#[test]
fn collect_channel_ids_vec() {
let tx1: Tx<i32> = Tx::try_from(1u64).unwrap();
let tx2: Tx<i32> = Tx::try_from(2u64).unwrap();
let tx3: Tx<i32> = Tx::try_from(3u64).unwrap();
let args: Vec<Tx<i32>> = vec![tx1, tx2, tx3];
let ids = collect_channel_ids(&args);
assert_eq!(ids, vec![1, 2, 3]);
}
// Streams nested more than one struct deep are still found.
#[test]
fn collect_channel_ids_deeply_nested() {
#[derive(facet::Facet)]
struct Outer {
inner: Inner,
}
#[derive(facet::Facet)]
struct Inner {
stream: Tx<u8>,
}
let args = Outer {
inner: Inner {
stream: Tx::try_from(777u64).unwrap(),
},
};
let ids = collect_channel_ids(&args);
assert_eq!(ids, vec![777]);
}
}