roam_session/lib.rs
1#![deny(unsafe_code)]
2
3//! Session/state machine and RPC-level utilities.
4//!
5//! Canonical definitions live in `docs/content/spec/_index.md`,
6//! `docs/content/rust-spec/_index.md`, and `docs/content/shm-spec/_index.md`.
7
8#[macro_use]
9mod macros;
10
11pub mod diagnostic;
12pub mod driver;
13pub mod runtime;
14pub mod transport;
15
16pub use driver::{
17 ConnectError, ConnectionError, Driver, FramedClient, HandshakeConfig, IncomingConnection,
18 IncomingConnections, MessageConnector, Negotiated, NoDispatcher, RetryPolicy, accept_framed,
19 connect_framed, connect_framed_with_policy, initiate_framed,
20};
21pub use transport::MessageTransport;
22
23use std::marker::PhantomData;
24use std::sync::Arc;
25use std::sync::atomic::{AtomicU64, Ordering};
26
27use crate::runtime::{OneshotSender, Receiver, Sender, oneshot};
28use facet::Facet;
29use std::convert::Infallible;
30
31pub use roam_frame::{Frame, MsgDesc, OwnedMessage, Payload};
32
/// Capacity argument passed to `crate::runtime::channel` when `channel()` creates
/// the in-process pipe behind a `Tx`/`Rx` pair (presumably a bounded buffer size).
const CHANNEL_SIZE: usize = 1024;
/// Judging by the name, a buffer size for receive-side stream channels.
///
/// NOTE(review): not referenced in the visible part of this file; presumably used
/// by driver/registry code elsewhere — confirm before changing.
const RX_STREAM_BUFFER_SIZE: usize = 1024;
35
36// ============================================================================
37// Streaming types
38// ============================================================================
39
/// Identifier of a stream (channel) within a connection.
pub type ChannelId = u64;

/// Which side of the connection we are on; this decides the parity of the
/// channel IDs we allocate.
///
/// The initiator is the peer that opened the connection (connected to a TCP
/// socket, opened an SHM channel, ...). The acceptor is the peer that
/// accepted or received it.
///
/// r[impl channeling.id.parity]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Role {
    /// Allocates odd stream IDs: 1, 3, 5, ...
    Initiator,
    /// Allocates even stream IDs: 2, 4, 6, ...
    Acceptor,
}

/// Hands out stream IDs that never repeat and always carry the parity of a [`Role`].
///
/// The counter is atomic, so a single allocator can be shared (it is kept
/// behind an `Arc` elsewhere in this crate) and used through `&self`.
///
/// r[impl channeling.id.uniqueness] - IDs are unique within a connection.
/// r[impl channeling.id.parity] - Initiator uses odd, Acceptor uses even.
pub struct ChannelIdAllocator {
    next: AtomicU64,
}

impl ChannelIdAllocator {
    /// Builds an allocator whose first ID is `1` for [`Role::Initiator`]
    /// and `2` for [`Role::Acceptor`].
    pub fn new(role: Role) -> Self {
        let first: ChannelId = match role {
            Role::Initiator => 1,
            Role::Acceptor => 2,
        };
        ChannelIdAllocator {
            next: AtomicU64::new(first),
        }
    }

    /// Returns the current ID and advances the counter by two.
    ///
    /// Stepping by two preserves the parity chosen in [`ChannelIdAllocator::new`].
    pub fn next(&self) -> ChannelId {
        const STRIDE: u64 = 2;
        self.next.fetch_add(STRIDE, Ordering::Relaxed)
    }
}
82
83// ============================================================================
84// SenderSlot - Wrapper for Option<Sender> that implements Facet
85// ============================================================================
86
87/// A wrapper around `Option<Sender<Vec<u8>>>` that implements Facet.
88///
89/// This allows `Poke::get_mut::<SenderSlot>()` to work, enabling `.take()`
90/// via reflection. Used by `ConnectionHandle::call` to extract senders from
91/// `Tx<T>` arguments and register them with the stream registry.
92#[derive(Facet)]
93#[facet(opaque)]
94pub struct SenderSlot {
95 /// The optional sender. Public within crate for `Tx::send()` access.
96 pub(crate) inner: Option<Sender<Vec<u8>>>,
97}
98
99impl SenderSlot {
100 /// Create a slot containing a sender.
101 pub fn new(tx: Sender<Vec<u8>>) -> Self {
102 Self { inner: Some(tx) }
103 }
104
105 /// Create an empty slot.
106 pub fn empty() -> Self {
107 Self { inner: None }
108 }
109
110 /// Take the sender out of the slot, leaving it empty.
111 pub fn take(&mut self) -> Option<Sender<Vec<u8>>> {
112 self.inner.take()
113 }
114
115 /// Check if the slot contains a sender.
116 pub fn is_some(&self) -> bool {
117 self.inner.is_some()
118 }
119
120 /// Check if the slot is empty.
121 pub fn is_none(&self) -> bool {
122 self.inner.is_none()
123 }
124
125 /// Set the sender in this slot.
126 ///
127 /// Used by `ChannelRegistry::bind_streams` to hydrate a deserialized `Tx<T>`
128 /// with an actual channel sender.
129 pub fn set(&mut self, tx: Sender<Vec<u8>>) {
130 self.inner = Some(tx);
131 }
132}
133
134// ============================================================================
135// DriverTxSlot - Wrapper for Option<Sender<DriverMessage>> that implements Facet
136// ============================================================================
137
138/// A wrapper around `Option<Sender<DriverMessage>>` that implements Facet.
139///
140/// This allows `Poke::get_mut::<DriverTxSlot>()` to work, enabling reflection-based
141/// hydration of `Tx<T>` handles on the server side. Sends Data/Close messages
142/// directly to the connection driver.
143#[derive(Facet)]
144#[facet(opaque)]
145pub struct DriverTxSlot {
146 /// The optional sender. Public within crate for `Tx::send()` access.
147 pub(crate) inner: Option<Sender<DriverMessage>>,
148}
149
150impl DriverTxSlot {
151 /// Create a slot containing a task sender.
152 pub fn new(tx: Sender<DriverMessage>) -> Self {
153 Self { inner: Some(tx) }
154 }
155
156 /// Create an empty slot.
157 pub fn empty() -> Self {
158 Self { inner: None }
159 }
160
161 /// Take the sender out of the slot, leaving it empty.
162 pub fn take(&mut self) -> Option<Sender<DriverMessage>> {
163 self.inner.take()
164 }
165
166 /// Check if the slot contains a sender.
167 pub fn is_some(&self) -> bool {
168 self.inner.is_some()
169 }
170
171 /// Check if the slot is empty.
172 pub fn is_none(&self) -> bool {
173 self.inner.is_none()
174 }
175
176 /// Set the task sender in this slot.
177 ///
178 /// Used by `ChannelRegistry::bind_streams` to hydrate a deserialized `Tx<T>`
179 /// with the connection's task message channel.
180 pub fn set(&mut self, tx: Sender<DriverMessage>) {
181 self.inner = Some(tx);
182 }
183
184 /// Clone the sender if present.
185 pub fn clone_inner(&self) -> Option<Sender<DriverMessage>> {
186 self.inner.clone()
187 }
188}
189
190/// Tx stream handle - caller sends data to callee.
191///
192/// r[impl channeling.caller-pov] - From caller's perspective, Tx means "I send".
193/// r[impl channeling.type] - Serializes as u64 stream ID on wire.
194/// r[impl channeling.holder-semantics] - The holder sends on this stream.
195/// r[impl channeling.channels-outlive-response] - Tx streams may outlive Response.
196/// r[impl channeling.lifecycle.immediate-data] - Can send Data before Response.
197/// r[impl channeling.lifecycle.speculative] - Early Data may be wasted on error.
198///
199/// # Facet Implementation
200///
201/// Uses `#[facet(proxy = u64)]` so that:
202/// - `channel_id` is pokeable (Connection can walk args and set stream IDs)
203/// - Serializes as just a `u64` on the wire
204/// - `T` is exposed as a type parameter for codegen introspection
205///
206/// # Two modes of operation
207///
208/// - **Client side**: `sender` holds a channel to an intermediate drain task.
209/// `ConnectionHandle::call` takes the receiver and drains it to wire.
210/// - **Server side**: `task_tx` holds a direct channel to the connection driver.
211/// `ChannelRegistry::bind_streams` sets this, and `send()` writes `DriverMessage::Data`.
212#[derive(Facet)]
213#[facet(proxy = u64)]
214pub struct Tx<T: 'static> {
215 /// The connection ID this stream belongs to.
216 pub conn_id: roam_wire::ConnectionId,
217 /// The unique stream ID for this stream.
218 /// Public so Connection can poke it when binding streams.
219 pub channel_id: ChannelId,
220 /// Channel sender for outgoing data (client-side mode).
221 /// Used when Tx is created via `roam::channel()`.
222 pub sender: SenderSlot,
223 /// Direct driver message sender (server-side mode).
224 /// Used when Tx is hydrated by `ChannelRegistry::bind_streams`.
225 pub driver_tx: DriverTxSlot,
226 /// Phantom data for the element type.
227 #[facet(opaque)]
228 _marker: PhantomData<T>,
229}
230
231/// Serialization: `&Tx<T>` -> u64 (extracts channel_id)
232///
233/// Uses TryFrom rather than From because facet's proxy mechanism requires TryFrom.
234#[allow(clippy::infallible_try_from)]
235impl<T: 'static> TryFrom<&Tx<T>> for u64 {
236 type Error = Infallible;
237 fn try_from(tx: &Tx<T>) -> Result<Self, Self::Error> {
238 Ok(tx.channel_id)
239 }
240}
241
242/// Deserialization: u64 -> `Tx<T>` (creates a "hollow" Tx)
243///
244/// Both sender slots are empty - the real sender gets set up by Connection
245/// after deserialization when it binds the stream.
246///
247/// Uses TryFrom rather than From because facet's proxy mechanism requires TryFrom.
248#[allow(clippy::infallible_try_from)]
249impl<T: 'static> TryFrom<u64> for Tx<T> {
250 type Error = Infallible;
251 fn try_from(channel_id: u64) -> Result<Self, Self::Error> {
252 // Create a hollow Tx - no actual sender, Connection will bind later
253 // conn_id will be set when binding
254 Ok(Tx {
255 conn_id: roam_wire::ConnectionId::ROOT,
256 channel_id,
257 sender: SenderSlot::empty(),
258 driver_tx: DriverTxSlot::empty(),
259 _marker: PhantomData,
260 })
261 }
262}
263
264impl<T: 'static> Tx<T> {
265 /// Create a new Tx stream with the given ID and sender channel (client-side mode).
266 pub fn new(channel_id: ChannelId, tx: Sender<Vec<u8>>) -> Self {
267 Self {
268 conn_id: roam_wire::ConnectionId::ROOT,
269 channel_id,
270 sender: SenderSlot::new(tx),
271 driver_tx: DriverTxSlot::empty(),
272 _marker: PhantomData,
273 }
274 }
275
276 /// Create an unbound Tx with a sender but channel_id 0.
277 ///
278 /// Used by `roam::channel()` to create a pair before binding.
279 /// Connection will poke the channel_id and conn_id when binding.
280 pub fn unbound(tx: Sender<Vec<u8>>) -> Self {
281 Self {
282 conn_id: roam_wire::ConnectionId::ROOT,
283 channel_id: 0,
284 sender: SenderSlot::new(tx),
285 driver_tx: DriverTxSlot::empty(),
286 _marker: PhantomData,
287 }
288 }
289
290 /// Create a bound Tx with conn_id, channel_id and driver_tx already set.
291 ///
292 /// Used by `roam::channel()` when called during dispatch to create
293 /// response channels that can send Data directly over the wire.
294 pub fn bound(
295 conn_id: roam_wire::ConnectionId,
296 channel_id: ChannelId,
297 tx: Sender<Vec<u8>>,
298 driver_tx: Sender<DriverMessage>,
299 ) -> Self {
300 Self {
301 conn_id,
302 channel_id,
303 sender: SenderSlot::new(tx),
304 driver_tx: DriverTxSlot::new(driver_tx),
305 _marker: PhantomData,
306 }
307 }
308
309 /// Get the stream ID.
310 pub fn channel_id(&self) -> ChannelId {
311 self.channel_id
312 }
313
314 /// Send a value on this stream.
315 ///
316 /// r[impl channeling.data] - Data messages carry serialized values.
317 ///
318 /// Works in two modes:
319 /// - Client-side (or passthrough): sends raw bytes to intermediate channel (drained by connection)
320 /// - Server-side: sends `DriverMessage::Data` directly to connection driver
321 ///
322 /// IMPORTANT: We prefer sender over driver_tx because when a channel created during
323 /// dispatch is passed to a callback, the rx gets a NEW channel_id allocated by the
324 /// caller's bind_streams. The drain task uses that new channel_id, while self.channel_id
325 /// still has the old dispatch-context channel_id. By using sender, data flows through
326 /// the drain task which uses the correct channel_id.
327 pub async fn send(&self, value: &T) -> Result<(), TxError>
328 where
329 T: Facet<'static>,
330 {
331 let bytes = facet_postcard::to_vec(value).map_err(TxError::Serialize)?;
332
333 // Prefer sender - data flows through drain task which has correct channel_id
334 if let Some(tx) = self.sender.inner.as_ref() {
335 tx.send(bytes).await.map_err(|_| TxError::Closed)
336 }
337 // Fallback to direct driver_tx (sender was taken or never set)
338 else if let Some(task_tx) = self.driver_tx.inner.as_ref() {
339 task_tx
340 .send(DriverMessage::Data {
341 conn_id: self.conn_id,
342 channel_id: self.channel_id,
343 payload: bytes,
344 })
345 .await
346 .map_err(|_| TxError::Closed)
347 } else {
348 Err(TxError::Taken)
349 }
350 }
351}
352
353/// When a Tx is dropped, send a Close message.
354///
355/// r[impl channeling.close] - Close terminates the stream.
356///
357/// The Close path depends on how data was sent:
358/// - If sender is present: data went through drain task, drain task sends Close when channel closes
359/// - If only driver_tx is present: data went directly to driver, we send Close via driver_tx
360impl<T: 'static> Drop for Tx<T> {
361 fn drop(&mut self) {
362 // If sender is still present, the drain task will handle Close when
363 // the internal channel closes. Don't send Close via driver_tx because
364 // it would use the wrong channel_id (dispatch-context id vs caller-allocated id).
365 if self.sender.inner.is_some() {
366 // Just drop the sender - drain task handles Close
367 return;
368 }
369
370 // Sender was taken or never set - send Close via driver_tx if available
371 if let Some(task_tx) = self.driver_tx.inner.take() {
372 let conn_id = self.conn_id;
373 let channel_id = self.channel_id;
374 // Use try_send for synchronous Close delivery.
375 // This ensures Close is queued before Response in dispatch_call.
376 //
377 // WARNING: If try_send fails (channel full), we spawn as fallback.
378 // This creates a potential ordering issue where Close could arrive
379 // after Response. To mitigate: task_tx channels should be sized
380 // generously (256+) to make this unlikely. A proper fix would use
381 // unbounded channels for task messages.
382 if task_tx
383 .try_send(DriverMessage::Close {
384 conn_id,
385 channel_id,
386 })
387 .is_err()
388 {
389 // Channel full or closed - spawn as fallback (see warning above)
390 crate::runtime::spawn(async move {
391 let _ = task_tx
392 .send(DriverMessage::Close {
393 conn_id,
394 channel_id,
395 })
396 .await;
397 });
398 }
399 }
400 }
401}
402
403/// Error when sending on a Tx stream.
404#[derive(Debug)]
405pub enum TxError {
406 /// Failed to serialize the value.
407 Serialize(facet_postcard::SerializeError),
408 /// The stream channel is closed.
409 Closed,
410 /// The sender was already taken (e.g., by ConnectionHandle::call).
411 Taken,
412}
413
414impl std::fmt::Display for TxError {
415 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
416 match self {
417 TxError::Serialize(e) => write!(f, "serialize error: {e}"),
418 TxError::Closed => write!(f, "stream closed"),
419 TxError::Taken => write!(f, "sender was taken"),
420 }
421 }
422}
423
424impl std::error::Error for TxError {}
425
426// ============================================================================
427// ReceiverSlot - Wrapper for Option<Receiver> that implements Facet
428// ============================================================================
429
430/// A wrapper around `Option<Receiver<Vec<u8>>>` that implements Facet.
431///
432/// This allows `Poke::get_mut::<ReceiverSlot>()` to work, enabling `.take()`
433/// via reflection. Used by `ConnectionHandle::call` to extract receivers from
434/// `Rx<T>` arguments and register them with the stream registry.
435#[derive(Facet)]
436#[facet(opaque)]
437pub struct ReceiverSlot {
438 /// The optional receiver. Public within crate for `Rx::recv()` access.
439 pub(crate) inner: Option<Receiver<Vec<u8>>>,
440}
441
442impl ReceiverSlot {
443 /// Create a slot containing a receiver.
444 pub fn new(rx: Receiver<Vec<u8>>) -> Self {
445 Self { inner: Some(rx) }
446 }
447
448 /// Create an empty slot.
449 pub fn empty() -> Self {
450 Self { inner: None }
451 }
452
453 /// Take the receiver out of the slot, leaving it empty.
454 pub fn take(&mut self) -> Option<Receiver<Vec<u8>>> {
455 self.inner.take()
456 }
457
458 /// Check if the slot contains a receiver.
459 pub fn is_some(&self) -> bool {
460 self.inner.is_some()
461 }
462
463 /// Check if the slot is empty.
464 pub fn is_none(&self) -> bool {
465 self.inner.is_none()
466 }
467
468 /// Set the receiver in this slot.
469 ///
470 /// Used by `ChannelRegistry::bind_streams` to hydrate a deserialized `Rx<T>`
471 /// with an actual channel receiver.
472 pub fn set(&mut self, rx: Receiver<Vec<u8>>) {
473 self.inner = Some(rx);
474 }
475}
476
/// Rx stream handle - caller receives data from callee.
///
/// r[impl channeling.caller-pov] - From caller's perspective, Rx means "I receive".
/// r[impl channeling.type] - Serializes as u64 stream ID on wire.
/// r[impl channeling.holder-semantics] - The holder receives from this stream.
///
/// # Facet Implementation
///
/// Uses `#[facet(proxy = u64)]` so that:
/// - `channel_id` is pokeable (Connection can walk args and set stream IDs)
/// - Serializes as just a `u64` on the wire
/// - `T` is exposed as a type parameter for codegen introspection
///
/// The `receiver` field uses `ReceiverSlot` wrapper so that `ConnectionHandle::call`
/// can use `Poke::get_mut::<ReceiverSlot>()` to `.take()` the receiver and register
/// it with the stream registry.
///
/// Items arrive as postcard-encoded byte buffers and are decoded into `T` by
/// `Rx::recv`.
#[derive(Facet)]
#[facet(proxy = u64)]
pub struct Rx<T: 'static> {
    /// The unique stream ID for this stream; `0` while the handle is unbound.
    /// Public so Connection can poke it when binding streams.
    pub channel_id: ChannelId,
    /// Channel receiver for incoming data.
    /// Uses ReceiverSlot so it's pokeable (can .take() via Poke).
    pub receiver: ReceiverSlot,
    /// Phantom data for the element type.
    #[facet(opaque)]
    _marker: PhantomData<T>,
}
506
507/// Serialization: `&Rx<T>` -> u64 (extracts channel_id)
508///
509/// Uses TryFrom rather than From because facet's proxy mechanism requires TryFrom.
510#[allow(clippy::infallible_try_from)]
511impl<T: 'static> TryFrom<&Rx<T>> for u64 {
512 type Error = Infallible;
513 fn try_from(rx: &Rx<T>) -> Result<Self, Self::Error> {
514 Ok(rx.channel_id)
515 }
516}
517
518/// Deserialization: u64 -> `Rx<T>` (creates a "hollow" Rx)
519///
520/// The receiver is a placeholder - the real receiver gets set up by Connection
521/// after deserialization when it binds the stream.
522///
523/// Uses TryFrom rather than From because facet's proxy mechanism requires TryFrom.
524#[allow(clippy::infallible_try_from)]
525impl<T: 'static> TryFrom<u64> for Rx<T> {
526 type Error = Infallible;
527 fn try_from(channel_id: u64) -> Result<Self, Self::Error> {
528 // Create a hollow Rx - no actual receiver, Connection will bind later
529 Ok(Rx {
530 channel_id,
531 receiver: ReceiverSlot::empty(),
532 _marker: PhantomData,
533 })
534 }
535}
536
537impl<T: 'static> Rx<T> {
538 /// Create a new Rx stream with the given ID and receiver channel.
539 pub fn new(channel_id: ChannelId, rx: Receiver<Vec<u8>>) -> Self {
540 Self {
541 channel_id,
542 receiver: ReceiverSlot::new(rx),
543 _marker: PhantomData,
544 }
545 }
546
547 /// Create an unbound Rx with a receiver but channel_id 0.
548 ///
549 /// Used by `roam::channel()` to create a pair before binding.
550 /// Connection will poke the channel_id when binding.
551 pub fn unbound(rx: Receiver<Vec<u8>>) -> Self {
552 Self {
553 channel_id: 0,
554 receiver: ReceiverSlot::new(rx),
555 _marker: PhantomData,
556 }
557 }
558
559 /// Create a bound Rx with channel_id already set.
560 ///
561 /// Used by `roam::channel()` when called during dispatch to create
562 /// response channels. The channel_id will be serialized and sent to
563 /// the client, who will bind a receiver for incoming Data.
564 pub fn bound(channel_id: ChannelId, rx: Receiver<Vec<u8>>) -> Self {
565 Self {
566 channel_id,
567 receiver: ReceiverSlot::new(rx),
568 _marker: PhantomData,
569 }
570 }
571
572 /// Get the stream ID.
573 pub fn channel_id(&self) -> ChannelId {
574 self.channel_id
575 }
576
577 /// Receive the next value from this stream.
578 ///
579 /// Returns `Ok(Some(value))` for each received value,
580 /// `Ok(None)` when the stream is closed,
581 /// or `Err` if deserialization fails.
582 ///
583 /// r[impl channeling.data] - Deserialize Data message payloads.
584 /// r[impl channeling.data.invalid] - Caller must send Goodbye on deserialize error.
585 pub async fn recv(&mut self) -> Result<Option<T>, RxError>
586 where
587 T: Facet<'static>,
588 {
589 let rx = self.receiver.inner.as_mut().ok_or(RxError::Taken)?;
590 match rx.recv().await {
591 Some(bytes) => {
592 let value = facet_postcard::from_slice(&bytes).map_err(RxError::Deserialize)?;
593 Ok(Some(value))
594 }
595 None => Ok(None),
596 }
597 }
598}
599
600/// Error when receiving from a Rx stream.
601#[derive(Debug)]
602pub enum RxError {
603 /// Failed to deserialize the value.
604 Deserialize(facet_postcard::DeserializeError<facet_postcard::PostcardError>),
605 /// The receiver was already taken (e.g., by ConnectionHandle::call).
606 Taken,
607}
608
609impl std::fmt::Display for RxError {
610 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
611 match self {
612 RxError::Deserialize(e) => write!(f, "deserialize error: {e}"),
613 RxError::Taken => write!(f, "receiver was taken"),
614 }
615 }
616}
617
618impl std::error::Error for RxError {}
619
620// ============================================================================
621// Channel creation
622// ============================================================================
623
624/// Create an unbound channel pair for streaming RPC.
625///
626/// Returns `(Tx<T>, Rx<T>)` with `channel_id: 0`. The `ConnectionHandle::call`
627/// method will walk the args, find `Rx<T>` or `Tx<T>` fields, assign stream IDs,
628/// and take the internal channel handles to register with the stream registry.
629///
630/// # Channel semantics (like regular mpsc)
631///
632/// - If caller wants to **send** data: pass `rx`, keep `tx`
633/// - If caller wants to **receive** data: pass `tx`, keep `rx`
634///
635/// # Example
636///
637/// ```ignore
638/// // sum(numbers: Rx<i32>) -> i64
639/// let (tx, rx) = roam::channel::<i32>();
640/// let fut = client.sum(rx); // pass rx, keep tx
641/// tx.send(1).await;
642/// tx.send(2).await;
643/// drop(tx);
644/// let sum = fut.await?;
645/// ```
646pub fn channel<T: 'static>() -> (Tx<T>, Rx<T>) {
647 let (sender, receiver) = crate::runtime::channel(CHANNEL_SIZE);
648
649 // Check if we're in a dispatch context - if so, create bound channels
650 if let Some(ctx) = get_dispatch_context() {
651 let channel_id = ctx.channel_ids.next();
652 debug!(channel_id, "roam::channel() creating bound channel pair");
653 (
654 Tx::bound(ctx.conn_id, channel_id, sender, ctx.driver_tx.clone()),
655 Rx::bound(channel_id, receiver),
656 )
657 } else {
658 trace!("roam::channel() creating unbound channel pair (no dispatch context)");
659 (Tx::unbound(sender), Rx::unbound(receiver))
660 }
661}
662
663// ============================================================================
664// Dispatch Context (task-local for response channel binding)
665// ============================================================================
666
667/// Context for binding response channels during dispatch.
668///
669/// When a service handler creates a channel with `roam::channel()` and returns
670/// the Rx, the Tx needs to be bound to send Data over the wire. This context
671/// provides the channel ID allocator and driver_tx needed for binding.
672#[derive(Clone)]
673struct DispatchContext {
674 conn_id: roam_wire::ConnectionId,
675 channel_ids: Arc<ChannelIdAllocator>,
676 driver_tx: Sender<DriverMessage>,
677}
678
679roam_task_local::task_local! {
680 /// Task-local dispatch context. Using task_local instead of thread_local
681 /// is critical: thread_local can leak across different async tasks that
682 /// happen to run on the same worker thread, causing channel binding bugs.
683 static DISPATCH_CONTEXT: DispatchContext;
684}
685
686/// Get the current dispatch context, if any.
687fn get_dispatch_context() -> Option<DispatchContext> {
688 DISPATCH_CONTEXT.try_with(|ctx| ctx.clone()).ok()
689}
690
691// ============================================================================
692// Stream Registry
693// ============================================================================
694
695use std::collections::{HashMap, HashSet};
696
/// Response data returned from a call, including any response stream channels.
///
/// Delivered to the caller through the `response_tx` of `DriverMessage::Call`,
/// wrapped in a `Result` whose error side is `TransportError`.
#[derive(Debug)]
pub struct ResponseData {
    /// The response payload bytes.
    pub payload: Vec<u8>,
    /// Channel IDs for streams in the response (`Rx<T>` returned by the method).
    /// Client must register receivers for these channels.
    pub channels: Vec<u64>,
}
706
/// All messages to the connection driver go through a single channel.
///
/// This unified channel ensures FIFO ordering: a Call followed by Data
/// will always be processed in that order, preventing race conditions
/// where Data could arrive before the Request is sent.
pub enum DriverMessage {
    /// Send a Request and expect a Response (client-side call).
    Call {
        /// Connection this call is made on.
        conn_id: roam_wire::ConnectionId,
        /// ID of the outgoing request (presumably used to match its Response).
        request_id: u64,
        /// ID of the method being invoked.
        method_id: u64,
        /// Request metadata as key/value pairs.
        metadata: Vec<(String, roam_wire::MetadataValue)>,
        /// Channel IDs used by this call (Tx/Rx), in declaration order.
        channels: Vec<u64>,
        /// Request payload bytes (presumably the serialized arguments).
        payload: Vec<u8>,
        /// Completed with the call's `ResponseData` or a `TransportError`.
        response_tx: OneshotSender<Result<ResponseData, TransportError>>,
    },
    /// Send a Data message on a stream.
    Data {
        /// Connection the stream belongs to.
        conn_id: roam_wire::ConnectionId,
        /// Stream the data is sent on.
        channel_id: ChannelId,
        /// Serialized value bytes (as produced by `Tx::send`).
        payload: Vec<u8>,
    },
    /// Send a Close message to end a stream.
    Close {
        /// Connection the stream belongs to.
        conn_id: roam_wire::ConnectionId,
        /// Stream being closed.
        channel_id: ChannelId,
    },
    /// Send a Response message (server-side call completed).
    Response {
        /// Connection the original request arrived on.
        conn_id: roam_wire::ConnectionId,
        /// ID of the request being answered.
        request_id: u64,
        /// Channel IDs for streams in the response (Tx/Rx returned by the method).
        channels: Vec<u64>,
        /// Response payload bytes.
        payload: Vec<u8>,
    },
    /// Request to open a new virtual connection.
    Connect {
        /// ID of the connect request.
        request_id: u64,
        /// Metadata sent with the connect request.
        metadata: roam_wire::Metadata,
        /// Completed with a handle to the new connection, or the connect error.
        response_tx: OneshotSender<Result<ConnectionHandle, crate::ConnectError>>,
        /// Dispatcher for handling incoming requests on the virtual connection.
        /// If None, the connection can only make calls, not receive them.
        dispatcher: Option<Box<dyn ServiceDispatcher>>,
    },
}
753
/// Registry of active streams for a connection.
///
/// Handles incoming streams (Data from wire → `Rx<T>` / `Tx<T>` handles).
/// For outgoing streams (server `Tx<T>` args), spawned tasks drain receivers
/// and send Data/Close messages via `driver_tx`.
///
/// Besides routing, the registry:
/// - tracks per-stream flow-control credit in both directions;
/// - owns the allocator used for response-channel IDs during dispatch.
///
/// r[impl channeling.unknown] - Unknown stream IDs cause Goodbye.
pub struct ChannelRegistry {
    /// Connection ID this registry belongs to.
    conn_id: roam_wire::ConnectionId,

    /// Streams where we receive Data messages (backing `Rx<T>` or `Tx<T>` handles on our side).
    /// Key: channel_id, Value: sender to route Data payloads to the handle.
    /// Populated by `register_incoming`.
    incoming: HashMap<ChannelId, Sender<Vec<u8>>>,

    /// Stream IDs that have been closed.
    /// Used to detect data-after-close violations (checked in `prepare_route_data`).
    ///
    /// r[impl channeling.data-after-close] - Track closed streams.
    closed: HashSet<ChannelId>,

    // ========================================================================
    // Flow Control
    // ========================================================================
    /// r[impl flow.channel.credit-based] - Credit tracking for incoming streams.
    /// r[impl flow.channel.all-transports] - Flow control applies to all transports.
    /// This is the credit we've granted to the peer - bytes they can still send us.
    /// Decremented when we receive Data, incremented when we send Credit.
    incoming_credit: HashMap<ChannelId, u32>,

    /// r[impl flow.channel.credit-based] - Credit tracking for outgoing streams.
    /// r[impl flow.channel.all-transports] - Flow control applies to all transports.
    /// This is the credit peer granted us - bytes we can still send them.
    /// Decremented when we send Data, incremented when we receive Credit.
    outgoing_credit: HashMap<ChannelId, u32>,

    /// Initial credit (in payload bytes) to grant new streams.
    /// `ChannelRegistry::new` passes `u32::MAX`, i.e. effectively infinite credit.
    /// r[impl flow.channel.initial-credit] - Each stream starts with this credit.
    initial_credit: u32,

    /// Unified channel for all messages to the driver.
    /// The driver owns the receiving end and sends these on the wire.
    /// Using a single channel ensures FIFO ordering.
    driver_tx: Sender<DriverMessage>,

    /// Channel ID allocator for response channels created during dispatch.
    /// These are channels returned by service methods (e.g., `subscribe() -> Rx<Event>`).
    /// Shared (via `Arc`) with every `DispatchContext` built from this registry.
    response_channel_ids: Arc<ChannelIdAllocator>,
}
803
804impl ChannelRegistry {
805 /// Create a new registry with the given conn_id, initial credit, driver channel, and role.
806 ///
807 /// The `driver_tx` is used to send all messages (Call/Data/Close/Response)
808 /// to the driver for transmission on the wire.
809 ///
810 /// The `role` determines channel ID parity for response channels:
811 /// - Acceptor (server) uses even IDs
812 /// - Initiator (client) uses odd IDs
813 ///
814 /// r[impl flow.channel.initial-credit] - Each stream starts with this credit.
815 pub fn new_with_credit_and_role(
816 conn_id: roam_wire::ConnectionId,
817 initial_credit: u32,
818 driver_tx: Sender<DriverMessage>,
819 role: Role,
820 ) -> Self {
821 Self {
822 conn_id,
823 incoming: HashMap::new(),
824 closed: HashSet::new(),
825 incoming_credit: HashMap::new(),
826 outgoing_credit: HashMap::new(),
827 initial_credit,
828 driver_tx,
829 response_channel_ids: Arc::new(ChannelIdAllocator::new(role)),
830 }
831 }
832
833 /// Create a new registry with the given initial credit and driver channel.
834 /// Uses ROOT conn_id and Acceptor role for backward compatibility (server-side usage).
835 ///
836 /// r[impl flow.channel.initial-credit] - Each stream starts with this credit.
837 pub fn new_with_credit(initial_credit: u32, driver_tx: Sender<DriverMessage>) -> Self {
838 Self::new_with_credit_and_role(
839 roam_wire::ConnectionId::ROOT,
840 initial_credit,
841 driver_tx,
842 Role::Acceptor,
843 )
844 }
845
846 /// Create a new registry with default infinite credit.
847 ///
848 /// r[impl flow.channel.infinite-credit] - Implementations MAY use very large credit.
849 /// r[impl flow.channel.zero-credit] - With infinite credit, zero-credit never occurs.
850 /// This disables backpressure but simplifies implementation.
851 pub fn new(driver_tx: Sender<DriverMessage>) -> Self {
852 Self::new_with_credit(u32::MAX, driver_tx)
853 }
854
855 /// Get the connection ID for this registry.
856 pub fn conn_id(&self) -> roam_wire::ConnectionId {
857 self.conn_id
858 }
859
860 /// Get the dispatch context for response channel binding.
861 ///
862 /// Used by `dispatch_call` and `dispatch_call_infallible` to set up
863 /// thread-local context so `roam::channel()` can create bound channels.
864 pub(crate) fn dispatch_context(&self) -> DispatchContext {
865 DispatchContext {
866 conn_id: self.conn_id,
867 channel_ids: self.response_channel_ids.clone(),
868 driver_tx: self.driver_tx.clone(),
869 }
870 }
871
872 /// Get a clone of the driver message sender.
873 ///
874 /// Used by codegen to spawn tasks that send Data/Close/Response messages.
875 pub fn driver_tx(&self) -> Sender<DriverMessage> {
876 self.driver_tx.clone()
877 }
878
879 /// Get the response channel ID allocator.
880 /// Used by ForwardingDispatcher to allocate downstream channel IDs for response channels.
881 pub fn response_channel_ids(&self) -> Arc<ChannelIdAllocator> {
882 self.response_channel_ids.clone()
883 }
884
885 /// Register an incoming stream.
886 ///
887 /// The connection layer will route Data messages for this channel_id to the sender.
888 /// Used for both `Rx<T>` (caller receives from callee) and `Tx<T>` (callee sends to caller).
889 ///
890 /// r[impl flow.channel.initial-credit] - Stream starts with initial credit.
891 pub fn register_incoming(&mut self, channel_id: ChannelId, tx: Sender<Vec<u8>>) {
892 self.incoming.insert(channel_id, tx);
893 // Grant initial credit - peer can send us this many bytes
894 self.incoming_credit.insert(channel_id, self.initial_credit);
895 }
896
897 /// Register credit tracking for an outgoing stream.
898 ///
899 /// The actual receiver is NOT stored here - the driver owns it directly.
900 /// This only sets up credit tracking for the stream.
901 ///
902 /// r[impl flow.channel.initial-credit] - Stream starts with initial credit.
903 pub fn register_outgoing_credit(&mut self, channel_id: ChannelId) {
904 // Assume peer grants us initial credit - we can send them this many bytes
905 self.outgoing_credit.insert(channel_id, self.initial_credit);
906 }
907
908 /// Route a Data message payload to the appropriate incoming stream.
909 ///
910 /// Returns Ok(()) if routed successfully, Err(ChannelError) otherwise.
911 ///
912 /// r[impl channeling.data] - Data messages routed by channel_id.
913 /// r[impl channeling.data-after-close] - Reject data on closed streams.
914 /// r[impl flow.channel.credit-overrun] - Reject if data exceeds remaining credit.
915 /// r[impl flow.channel.credit-consume] - Deduct bytes from remaining credit.
916 /// r[impl flow.channel.byte-accounting] - Credit measured in payload bytes.
917 ///
918 /// Returns a sender and payload if routing is allowed, or an error.
919 /// The actual send must be done by the caller to avoid holding locks across await.
920 pub fn prepare_route_data(
921 &mut self,
922 channel_id: ChannelId,
923 payload: Vec<u8>,
924 ) -> Result<(Sender<Vec<u8>>, Vec<u8>), ChannelError> {
925 // Check for data-after-close
926 if self.closed.contains(&channel_id) {
927 return Err(ChannelError::DataAfterClose);
928 }
929
930 // Check credit before routing
931 // r[impl flow.channel.credit-overrun] - Reject if exceeds credit
932 let payload_len = payload.len() as u32;
933 if let Some(credit) = self.incoming_credit.get_mut(&channel_id) {
934 if payload_len > *credit {
935 return Err(ChannelError::CreditOverrun);
936 }
937 // r[impl flow.channel.credit-consume] - Deduct from credit
938 *credit -= payload_len;
939 }
940 // Note: if no credit entry exists, the stream may not be registered yet
941 // (e.g., Rx stream created by callee). In that case, skip credit check.
942
943 if let Some(tx) = self.incoming.get(&channel_id) {
944 Ok((tx.clone(), payload))
945 } else {
946 Err(ChannelError::Unknown)
947 }
948 }
949
950 /// Route a Data message payload to the appropriate incoming stream.
951 ///
952 /// Returns Ok(()) if routed successfully, Err(ChannelError) otherwise.
953 ///
954 /// r[impl channeling.data] - Data messages routed by channel_id.
955 /// r[impl channeling.data-after-close] - Reject data on closed streams.
956 /// r[impl flow.channel.credit-overrun] - Reject if data exceeds remaining credit.
957 /// r[impl flow.channel.credit-consume] - Deduct bytes from remaining credit.
958 /// r[impl flow.channel.byte-accounting] - Credit measured in payload bytes.
959 pub async fn route_data(
960 &mut self,
961 channel_id: ChannelId,
962 payload: Vec<u8>,
963 ) -> Result<(), ChannelError> {
964 let (tx, payload) = self.prepare_route_data(channel_id, payload)?;
965 // If send fails, the Rx<T> was dropped - that's okay, just drop the data
966 let _ = tx.send(payload).await;
967 Ok(())
968 }
969
970 /// Close an incoming stream (remove from registry).
971 ///
972 /// Dropping the sender will cause the `Rx<T>`'s recv() to return None.
973 ///
974 /// r[impl channeling.close] - Close terminates the stream.
975 /// r[impl flow.channel.close-exempt] - Close doesn't consume credit.
976 pub fn close(&mut self, channel_id: ChannelId) {
977 self.incoming.remove(&channel_id);
978 self.incoming_credit.remove(&channel_id);
979 self.outgoing_credit.remove(&channel_id);
980 self.closed.insert(channel_id);
981 }
982
983 /// Reset a stream (remove from registry, discard credit).
984 ///
985 /// r[impl channeling.reset] - Reset terminates the stream abruptly.
986 /// r[impl channeling.reset.credit] - Outstanding credit is lost on reset.
987 pub fn reset(&mut self, channel_id: ChannelId) {
988 self.incoming.remove(&channel_id);
989 self.incoming_credit.remove(&channel_id);
990 self.outgoing_credit.remove(&channel_id);
991 self.closed.insert(channel_id);
992 }
993
994 /// Receive a Credit message - add credit for an outgoing stream.
995 ///
996 /// r[impl flow.channel.credit-grant] - Credit message adds to available credit.
997 /// r[impl flow.channel.credit-additive] - Credit accumulates additively.
998 pub fn receive_credit(&mut self, channel_id: ChannelId, bytes: u32) {
999 if let Some(credit) = self.outgoing_credit.get_mut(&channel_id) {
1000 // r[impl flow.channel.credit-additive] - Add to existing credit
1001 *credit = credit.saturating_add(bytes);
1002 }
1003 // If no entry, stream may be closed or unknown - ignore
1004 }
1005
1006 /// Check if a stream ID is registered (either incoming or outgoing credit).
1007 pub fn contains(&self, channel_id: ChannelId) -> bool {
1008 self.incoming.contains_key(&channel_id) || self.outgoing_credit.contains_key(&channel_id)
1009 }
1010
1011 /// Check if a stream ID is registered as incoming.
1012 pub fn contains_incoming(&self, channel_id: ChannelId) -> bool {
1013 self.incoming.contains_key(&channel_id)
1014 }
1015
1016 /// Check if a stream ID has outgoing credit registered.
1017 pub fn contains_outgoing(&self, channel_id: ChannelId) -> bool {
1018 self.outgoing_credit.contains_key(&channel_id)
1019 }
1020
1021 /// Check if a stream has been closed.
1022 pub fn is_closed(&self, channel_id: ChannelId) -> bool {
1023 self.closed.contains(&channel_id)
1024 }
1025
1026 /// Get the number of active outgoing streams (by credit tracking).
1027 pub fn outgoing_count(&self) -> usize {
1028 self.outgoing_credit.len()
1029 }
1030
1031 /// Get remaining credit for an outgoing stream.
1032 ///
1033 /// Returns None if stream is not registered.
1034 pub fn outgoing_credit(&self, channel_id: ChannelId) -> Option<u32> {
1035 self.outgoing_credit.get(&channel_id).copied()
1036 }
1037
1038 /// Get remaining credit we've granted for an incoming stream.
1039 ///
1040 /// Returns None if stream is not registered.
1041 pub fn incoming_credit(&self, channel_id: ChannelId) -> Option<u32> {
1042 self.incoming_credit.get(&channel_id).copied()
1043 }
1044
1045 /// Bind streams in deserialized args for server-side dispatch.
1046 ///
1047 /// Walks the args using Poke reflection to find any `Rx<T>` or `Tx<T>` fields.
1048 /// For each stream found:
1049 /// - For `Rx<T>`: creates a channel, sets the receiver slot, registers for incoming data
1050 /// - For `Tx<T>`: sets the task_tx so send() writes directly to the wire
1051 ///
1052 /// # Example
1053 ///
1054 /// ```ignore
1055 /// let mut args = facet_postcard::from_slice::<(Rx<i32>, Tx<String>)>(&payload)?;
1056 /// registry.bind_streams(&mut args);
1057 /// let (input, output) = args;
1058 /// // ... call handler with input, output ...
1059 /// // When handler returns and Tx is dropped, Close is sent automatically
1060 /// ```
1061 pub fn bind_streams<T: Facet<'static>>(&mut self, args: &mut T) {
1062 let poke = facet::Poke::new(args);
1063 self.bind_streams_recursive(poke);
1064 }
1065
    /// Recursively walk a Poke value looking for Rx/Tx streams to bind.
    ///
    /// `Rx`/`Tx` are recognised by shape identity (`module_path ==
    /// Some("roam_session")` plus `type_identifier`), are bound via
    /// `bind_rx_stream` / `bind_tx_stream`, and are not descended into further.
    /// Otherwise the walk descends into:
    /// - struct/tuple fields,
    /// - the `Some` payload of an `Option`,
    /// - list elements, through the list vtable's `get_mut`,
    /// - field 0 of the active variant of any other enum.
    ///
    /// Scalars and any other shape are ignored.
    ///
    /// NOTE(review): a list whose vtable has no `get_mut` is skipped entirely,
    /// and only field 0 of an enum variant is visited. Confirm that no stream
    /// can hide in those positions.
    // `unsafe` is needed to index list elements via the vtable's raw `get_mut`
    // and to rebuild a `Poke` from the returned pointer.
    #[allow(unsafe_code)]
    fn bind_streams_recursive(&mut self, mut poke: facet::Poke<'_, '_>) {
        use facet::Def;

        let shape = poke.shape();

        trace!(
            module_path = ?shape.module_path,
            type_identifier = shape.type_identifier,
            "bind_streams_recursive: visiting type"
        );

        // Check if this is an Rx or Tx type
        if shape.module_path == Some("roam_session") {
            if shape.type_identifier == "Rx" {
                debug!("bind_streams_recursive: found Rx, binding");
                self.bind_rx_stream(poke);
                return;
            } else if shape.type_identifier == "Tx" {
                debug!("bind_streams_recursive: found Tx, binding");
                self.bind_tx_stream(poke);
                return;
            }
        }

        // Dispatch based on the shape's definition. Arm order matters: the
        // struct guard is tried before the Option/List defs, and the generic
        // enum guard only after them.
        match shape.def {
            Def::Scalar => {}

            // Recurse into struct/tuple fields
            _ if poke.is_struct() => {
                let mut ps = poke.into_struct().expect("is_struct was true");
                let field_count = ps.field_count();
                trace!(field_count, "bind_streams_recursive: recursing into struct");
                for i in 0..field_count {
                    if let Ok(field_poke) = ps.field(i) {
                        self.bind_streams_recursive(field_poke);
                    }
                }
            }

            // Recurse into Option<T>
            Def::Option(_) => {
                // Option is represented as an enum, use into_enum to access its value
                if let Ok(mut pe) = poke.into_enum()
                    && let Ok(Some(inner_poke)) = pe.field(0)
                {
                    self.bind_streams_recursive(inner_poke);
                }
            }

            // Recurse into list elements (e.g., Vec<Tx<T>>)
            Def::List(list_def) => {
                // Length is read through a shared Peek view; a failed list view counts as empty.
                let len = {
                    let peek = poke.as_peek();
                    peek.into_list().map(|pl| pl.len()).unwrap_or(0)
                };
                // Get mutable access to elements via VTable (no PokeList exists)
                if let Some(get_mut_fn) = list_def.vtable.get_mut {
                    let element_shape = list_def.t;
                    let data_ptr = poke.data_mut();
                    for i in 0..len {
                        // SAFETY: We have exclusive mutable access via poke, index < len, shape is correct
                        let element_ptr = unsafe { (get_mut_fn)(data_ptr, i, element_shape) };
                        if let Some(ptr) = element_ptr {
                            // SAFETY: ptr points to a valid element with the correct shape
                            let element_poke =
                                unsafe { facet::Poke::from_raw_parts(ptr, element_shape) };
                            self.bind_streams_recursive(element_poke);
                        }
                    }
                }
            }

            // Other enum variants
            _ if poke.is_enum() => {
                if let Ok(mut pe) = poke.into_enum()
                    && let Ok(Some(variant_poke)) = pe.field(0)
                {
                    self.bind_streams_recursive(variant_poke);
                }
            }

            _ => {}
        }
    }
1153
1154 /// Bind an Rx<T> stream for server-side dispatch.
1155 ///
1156 /// Server receives data from client on this stream.
1157 /// Creates a channel, sets the receiver slot, registers the sender for routing.
1158 fn bind_rx_stream(&mut self, poke: facet::Poke<'_, '_>) {
1159 if let Ok(mut ps) = poke.into_struct() {
1160 // Get the channel_id that was deserialized from the wire
1161 let channel_id = if let Ok(channel_id_field) = ps.field_by_name("channel_id")
1162 && let Ok(id_ref) = channel_id_field.get::<ChannelId>()
1163 {
1164 *id_ref
1165 } else {
1166 warn!("bind_rx_stream: could not get channel_id field");
1167 return;
1168 };
1169
1170 debug!(channel_id, "bind_rx_stream: registering incoming channel");
1171
1172 // Create channel and set receiver slot
1173 let (tx, rx) = crate::runtime::channel(RX_STREAM_BUFFER_SIZE);
1174
1175 if let Ok(mut receiver_field) = ps.field_by_name("receiver")
1176 && let Ok(slot) = receiver_field.get_mut::<ReceiverSlot>()
1177 {
1178 slot.set(rx);
1179 }
1180
1181 // Register for incoming data routing
1182 self.register_incoming(channel_id, tx);
1183 debug!(channel_id, "bind_rx_stream: channel registered");
1184 } else {
1185 warn!("bind_rx_stream: could not convert poke to struct");
1186 }
1187 }
1188
1189 /// Bind a Tx<T> stream for server-side dispatch.
1190 ///
1191 /// Server sends data to client on this stream.
1192 /// Sets the conn_id and driver_tx so Tx::send() writes DriverMessage::Data to the wire.
1193 /// When the Tx is dropped, it sends DriverMessage::Close automatically.
1194 fn bind_tx_stream(&mut self, poke: facet::Poke<'_, '_>) {
1195 if let Ok(mut ps) = poke.into_struct() {
1196 // Set conn_id so Data/Close messages go to the correct virtual connection
1197 // r[impl core.conn.independence]
1198 if let Ok(mut conn_id_field) = ps.field_by_name("conn_id")
1199 && let Ok(id_ref) = conn_id_field.get_mut::<roam_wire::ConnectionId>()
1200 {
1201 *id_ref = self.conn_id;
1202 }
1203
1204 // Set driver_tx so Tx::send() can write directly to the wire
1205 if let Ok(mut driver_tx_field) = ps.field_by_name("driver_tx")
1206 && let Ok(slot) = driver_tx_field.get_mut::<DriverTxSlot>()
1207 {
1208 slot.set(self.driver_tx.clone());
1209 }
1210 }
1211 }
1212}
1213
/// Error when routing stream data.
///
/// Returned by `ChannelRegistry::prepare_route_data` and `ChannelRegistry::route_data`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChannelError {
    /// Stream ID not found in registry (no incoming sender registered for it).
    Unknown,
    /// Data received after stream was closed (either via `close` or `reset`).
    DataAfterClose,
    /// r[impl flow.channel.credit-overrun] - Data exceeded remaining credit.
    CreditOverrun,
}
1224
1225// ============================================================================
1226// Flow Control
1227// ============================================================================
1228
/// Abstraction for stream flow control mechanism.
///
/// Different transports implement credit-based flow control differently:
/// - **Stream transports** (TCP, WebSocket): explicit `Message::Credit` on the wire
/// - **SHM**: shared atomic counters in the channel table (`ChannelEntry::granted_total`)
///
/// This trait abstracts the mechanism while `ChannelRegistry` remains the source
/// of truth for stream lifecycle (routing, ordering, existence).
///
/// On the sender side the two methods are meant to be used as a pair: await
/// `wait_for_send_credit` for a payload, send it, then call
/// `consume_send_credit` with the same byte count. `InfiniteCredit` is a no-op
/// implementation.
///
/// r[impl flow.channel.credit-based]
/// r[impl flow.channel.all-transports]
pub trait FlowControl: Send {
    /// Called when we receive data on a channel (receiver side).
    ///
    /// `bytes` is the size of the received payload. The implementation may
    /// grant credit back to the sender:
    /// - Stream: queue a `Message::Credit` to send
    /// - SHM: increment `ChannelEntry::granted_total` atomically
    ///
    /// r[impl flow.channel.credit-grant]
    fn on_data_received(&mut self, channel_id: ChannelId, bytes: u32);

    /// Wait until we have enough credit to send `bytes` on a channel (sender side).
    ///
    /// - Stream: check `ChannelRegistry::outgoing_credit`, wait on notify if insufficient
    /// - SHM: poll/futex wait on `granted_total - sent_total >= bytes`
    ///
    /// Returns `Ok(())` when credit is available, `Err` if the channel is closed/invalid.
    /// The returned future is required to be `Send`.
    ///
    /// r[impl flow.channel.zero-credit]
    fn wait_for_send_credit(
        &mut self,
        channel_id: ChannelId,
        bytes: u32,
    ) -> impl std::future::Future<Output = std::io::Result<()>> + Send;

    /// Consume credit after sending data (sender side).
    ///
    /// Called after successfully sending `bytes` on a channel.
    /// - Stream: decrement `ChannelRegistry::outgoing_credit`
    /// - SHM: increment local `sent_total`
    ///
    /// r[impl flow.channel.credit-consume]
    fn consume_send_credit(&mut self, channel_id: ChannelId, bytes: u32);
}
1273
1274/// No-op flow control for infinite credit mode.
1275///
1276/// r[impl flow.channel.infinite-credit]
1277///
1278/// Used when flow control is disabled or not yet implemented.
1279/// All operations succeed immediately without tracking.
1280#[derive(Debug, Clone, Copy, Default)]
1281pub struct InfiniteCredit;
1282
1283impl FlowControl for InfiniteCredit {
1284 fn on_data_received(&mut self, _channel_id: ChannelId, _bytes: u32) {
1285 // No credit tracking needed
1286 }
1287
1288 async fn wait_for_send_credit(
1289 &mut self,
1290 _channel_id: ChannelId,
1291 _bytes: u32,
1292 ) -> std::io::Result<()> {
1293 // Infinite credit - always available
1294 Ok(())
1295 }
1296
1297 fn consume_send_credit(&mut self, _channel_id: ChannelId, _bytes: u32) {
1298 // No credit tracking needed
1299 }
1300}
1301
1302// ============================================================================
1303// Request ID generation
1304// ============================================================================
1305
/// Generates unique request IDs for a connection.
///
/// IDs come from an atomic counter, so one generator can be shared and used
/// through `&self`.
///
/// r[impl call.request-id.uniqueness] - monotonically increasing counter starting at 1
pub struct RequestIdGenerator {
    next: AtomicU64,
}

impl RequestIdGenerator {
    /// First ID produced by a fresh generator.
    const FIRST_ID: u64 = 1;

    /// Create a new generator starting at 1.
    pub fn new() -> Self {
        let next = AtomicU64::new(Self::FIRST_ID);
        Self { next }
    }

    /// Generate the next unique request ID.
    ///
    /// `fetch_add` is an atomic read-modify-write, so concurrent callers each
    /// get a distinct value even with `Relaxed` ordering.
    pub fn next(&self) -> u64 {
        self.next.fetch_add(1, Ordering::Relaxed)
    }
}

impl Default for RequestIdGenerator {
    /// Same as [`RequestIdGenerator::new`]: the first ID is 1.
    fn default() -> Self {
        Self::new()
    }
}
1332
1333// ============================================================================
1334// Dispatch Helper
1335// ============================================================================
1336
1337/// Helper for dispatching RPC methods with minimal generated code.
1338///
1339/// This function handles the common dispatch pattern:
1340/// 1. Deserialize args from payload
1341/// 2. Bind any Tx/Rx streams via registry
1342/// 3. Call the handler closure
1343/// 4. Encode the result and send Response
1344///
1345/// The generated code just needs to provide a closure that calls the handler method.
1346///
1347/// # Type Parameters
1348///
1349/// - `A`: Args tuple type (must implement Facet for deserialization)
1350/// - `R`: Result ok type (must implement Facet for serialization)
1351/// - `E`: User error type (must implement Facet for serialization)
1352/// - `F`: Handler closure type
1353/// - `Fut`: Future returned by handler
1354///
1355/// # Example
1356///
1357/// ```ignore
1358/// fn dispatch_echo(&self, payload: Vec<u8>, request_id: u64, registry: &mut ChannelRegistry)
1359/// -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>
1360/// {
1361/// let handler = self.handler.clone();
1362/// dispatch_call(payload, request_id, registry, move |args: (String,)| async move {
1363/// handler.echo(args.0).await
1364/// })
1365/// }
1366/// ```
1367///
1368/// The handler returns `Result<R, E>` - user errors are automatically wrapped
1369/// in `RoamError::User(e)` for wire serialization.
1370///
1371/// The `channels` parameter contains channel IDs from the Request message framing.
1372/// These are patched into the deserialized args before binding streams.
1373pub fn dispatch_call<A, R, E, F, Fut>(
1374 cx: &Context,
1375 payload: Vec<u8>,
1376 registry: &mut ChannelRegistry,
1377 handler: F,
1378) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send + 'static>>
1379where
1380 A: Facet<'static> + Send,
1381 R: Facet<'static> + Send,
1382 E: Facet<'static> + Send,
1383 F: FnOnce(A) -> Fut + Send + 'static,
1384 Fut: std::future::Future<Output = Result<R, E>> + Send + 'static,
1385{
1386 let conn_id = cx.conn_id;
1387 let request_id = cx.request_id.raw();
1388 let channels = &cx.channels;
1389
1390 // Deserialize args
1391 let mut args: A = match facet_postcard::from_slice(&payload) {
1392 Ok(args) => args,
1393 Err(_) => {
1394 let task_tx = registry.driver_tx();
1395 return Box::pin(async move {
1396 // InvalidPayload error: Result::Err(1) + RoamError::InvalidPayload(2)
1397 let _ = task_tx
1398 .send(DriverMessage::Response {
1399 conn_id,
1400 request_id,
1401 channels: Vec::new(),
1402 payload: vec![1, 2],
1403 })
1404 .await;
1405 });
1406 }
1407 };
1408
1409 // Patch channel IDs from Request framing into deserialized args
1410 debug!(channels = ?channels, "dispatch_call: patching channel IDs");
1411 patch_channel_ids(&mut args, channels);
1412
1413 // Bind streams via reflection - THIS MUST HAPPEN SYNCHRONOUSLY
1414 debug!("dispatch_call: binding streams SYNC");
1415 registry.bind_streams(&mut args);
1416 debug!("dispatch_call: streams bound SYNC - channels should now be registered");
1417
1418 let task_tx = registry.driver_tx();
1419 let dispatch_ctx = registry.dispatch_context();
1420
1421 // Use task_local scope so roam::channel() creates bound channels.
1422 // This is critical: unlike thread_local, task_local won't leak to other
1423 // tasks that happen to run on the same worker thread.
1424 Box::pin(DISPATCH_CONTEXT.scope(dispatch_ctx, async move {
1425 debug!("dispatch_call: handler ASYNC starting");
1426 let result = handler(args).await;
1427 debug!("dispatch_call: handler ASYNC finished");
1428 let (payload, response_channels) = match result {
1429 Ok(ref ok_result) => {
1430 // Collect channel IDs from the result (e.g., Rx<T> in return type)
1431 let channels = collect_channel_ids(ok_result);
1432 // Result::Ok(0) + serialized value
1433 let mut out = vec![0u8];
1434 match facet_postcard::to_vec(ok_result) {
1435 Ok(bytes) => out.extend(bytes),
1436 Err(_) => return,
1437 }
1438 (out, channels)
1439 }
1440 Err(user_error) => {
1441 // Result::Err(1) + RoamError::User(0) + serialized user error
1442 let mut out = vec![1u8, 0u8];
1443 match facet_postcard::to_vec(&user_error) {
1444 Ok(bytes) => out.extend(bytes),
1445 Err(_) => return,
1446 }
1447 (out, Vec::new())
1448 }
1449 };
1450
1451 // Send Response with channel IDs for any Rx<T> in the result.
1452 // ForwardingDispatcher uses these to set up Data forwarding.
1453 let _ = task_tx
1454 .send(DriverMessage::Response {
1455 conn_id,
1456 request_id,
1457 channels: response_channels,
1458 payload,
1459 })
1460 .await;
1461 }))
1462}
1463
1464/// Dispatch helper for infallible methods (those that return `T` instead of `Result<T, E>`).
1465///
1466/// Same as `dispatch_call` but for handlers that cannot fail at the application level.
1467pub fn dispatch_call_infallible<A, R, F, Fut>(
1468 cx: &Context,
1469 payload: Vec<u8>,
1470 registry: &mut ChannelRegistry,
1471 handler: F,
1472) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send + 'static>>
1473where
1474 A: Facet<'static> + Send,
1475 R: Facet<'static> + Send,
1476 F: FnOnce(A) -> Fut + Send + 'static,
1477 Fut: std::future::Future<Output = R> + Send + 'static,
1478{
1479 let conn_id = cx.conn_id;
1480 let request_id = cx.request_id.raw();
1481 let channels = &cx.channels;
1482
1483 // Deserialize args
1484 let mut args: A = match facet_postcard::from_slice(&payload) {
1485 Ok(args) => args,
1486 Err(_) => {
1487 let task_tx = registry.driver_tx();
1488 return Box::pin(async move {
1489 // InvalidPayload error: Result::Err(1) + RoamError::InvalidPayload(2)
1490 let _ = task_tx
1491 .send(DriverMessage::Response {
1492 conn_id,
1493 request_id,
1494 channels: Vec::new(),
1495 payload: vec![1, 2],
1496 })
1497 .await;
1498 });
1499 }
1500 };
1501
1502 // Patch channel IDs from Request framing into deserialized args
1503 patch_channel_ids(&mut args, channels);
1504
1505 // Bind streams via reflection
1506 registry.bind_streams(&mut args);
1507
1508 let task_tx = registry.driver_tx();
1509 let dispatch_ctx = registry.dispatch_context();
1510
1511 // Use task_local scope so roam::channel() creates bound channels.
1512 Box::pin(DISPATCH_CONTEXT.scope(dispatch_ctx, async move {
1513 let result = handler(args).await;
1514
1515 // Collect channel IDs from the result (e.g., Rx<T> in return type)
1516 let response_channels = collect_channel_ids(&result);
1517 if !response_channels.is_empty() {
1518 debug!(
1519 channels = ?response_channels,
1520 "dispatch_call_infallible: collected response channels"
1521 );
1522 }
1523
1524 // Result::Ok(0) + serialized value
1525 let mut payload = vec![0u8];
1526 match facet_postcard::to_vec(&result) {
1527 Ok(bytes) => payload.extend(bytes),
1528 Err(_) => return,
1529 }
1530
1531 // Send Response with channel IDs for any Rx<T> in the result.
1532 // ForwardingDispatcher uses these to set up Data forwarding.
1533 let _ = task_tx
1534 .send(DriverMessage::Response {
1535 conn_id,
1536 request_id,
1537 channels: response_channels,
1538 payload,
1539 })
1540 .await;
1541 }))
1542}
1543
1544/// Send an "unknown method" error response.
1545///
1546/// Used by dispatchers when the method_id doesn't match any known method.
1547pub fn dispatch_unknown_method(
1548 cx: &Context,
1549 registry: &mut ChannelRegistry,
1550) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send + 'static>> {
1551 let conn_id = cx.conn_id;
1552 let request_id = cx.request_id.raw();
1553 let task_tx = registry.driver_tx();
1554 Box::pin(async move {
1555 // UnknownMethod error
1556 let _ = task_tx
1557 .send(DriverMessage::Response {
1558 conn_id,
1559 request_id,
1560 channels: Vec::new(),
1561 payload: vec![1, 1],
1562 })
1563 .await;
1564 })
1565}
1566
1567/// Collect channel IDs from args by walking with Peek.
1568///
1569/// Returns channel IDs in declaration order (depth-first traversal).
1570/// Used by the client to populate the `channels` vec in Request messages.
1571///
1572/// r[impl call.request.channels] - Collects channel IDs in declaration order for the Request.
1573pub fn collect_channel_ids<T: Facet<'static>>(args: &T) -> Vec<u64> {
1574 let mut ids = Vec::new();
1575 let poke = facet::Peek::new(args);
1576 collect_channel_ids_recursive(poke, &mut ids);
1577 ids
1578}
1579
/// Recursive worker for `collect_channel_ids`.
///
/// Pushes the `channel_id` of each `Rx`/`Tx` (identified by `module_path ==
/// Some("roam_session")` plus `type_identifier`) onto `ids`, depth-first. A
/// stream whose `channel_id` cannot be read is skipped. Other values are
/// probed in this order: struct/tuple, `Option`, other enum (field 0 of the
/// active variant), list.
///
/// NOTE(review): the server applies these IDs by position in
/// `patch_channel_ids_recursive`, so the two walkers must visit streams in the
/// same order. They detect `Option` differently (`into_option` here vs
/// `Def::Option` + `into_enum` there). Confirm they agree for every shape that
/// can carry a stream.
fn collect_channel_ids_recursive(peek: facet::Peek<'_, '_>, ids: &mut Vec<u64>) {
    let shape = peek.shape();

    // Check if this is an Rx or Tx type
    if shape.module_path == Some("roam_session")
        && (shape.type_identifier == "Rx" || shape.type_identifier == "Tx")
    {
        // Read the channel_id field
        if let Ok(ps) = peek.into_struct()
            && let Ok(channel_id_field) = ps.field_by_name("channel_id")
            && let Ok(&channel_id) = channel_id_field.get::<ChannelId>()
        {
            ids.push(channel_id);
        }
        return;
    }

    // Recurse into struct/tuple fields
    if let Ok(ps) = peek.into_struct() {
        let field_count = ps.field_count();
        for i in 0..field_count {
            if let Ok(field_peek) = ps.field(i) {
                collect_channel_ids_recursive(field_peek, ids);
            }
        }
        return;
    }

    // Recurse into Option<T> (specialized handling)
    if let Ok(po) = peek.into_option() {
        if let Some(inner) = po.value() {
            collect_channel_ids_recursive(inner, ids);
        }
        return;
    }

    // Recurse into enum variants (for other enums with data)
    if let Ok(pe) = peek.into_enum() {
        // Try to get the first field of the active variant (e.g., Some(T) has one field)
        if let Ok(Some(variant_peek)) = pe.field(0) {
            collect_channel_ids_recursive(variant_peek, ids);
        }
        return;
    }

    // Recurse into sequences (e.g., Vec<Tx<T>>)
    if let Ok(pl) = peek.into_list() {
        for element in pl.iter() {
            collect_channel_ids_recursive(element, ids);
        }
    }
}
1632
1633/// Patch channel IDs into deserialized args by walking with Poke.
1634///
1635/// Overwrites channel_id fields in Rx/Tx in declaration order.
1636/// Used by the server to apply the authoritative `channels` vec from Request.
1637pub fn patch_channel_ids<T: Facet<'static>>(args: &mut T, channels: &[u64]) {
1638 debug!(channels = ?channels, "patch_channel_ids: patching channels from wire");
1639 let mut idx = 0;
1640 let poke = facet::Poke::new(args);
1641 patch_channel_ids_recursive(poke, channels, &mut idx);
1642}
1643
/// Recursive worker for `patch_channel_ids`.
///
/// Walks `poke` depth-first. For every `Rx`/`Tx` found (identified by
/// `module_path == Some("roam_session")` plus `type_identifier`), it overwrites
/// `channel_id` with `channels[*idx]` and advances `*idx`. Once `channels` is
/// exhausted, or if the field cannot be accessed, the stream keeps its ID and
/// `*idx` is not advanced.
///
/// NOTE(review): IDs are applied by position, so this traversal must visit
/// streams in the same order as `collect_channel_ids_recursive` on the
/// client. `Option` is detected differently there (`into_option`), and a list
/// whose vtable lacks `get_mut` is skipped here but iterated there. Confirm
/// they agree.
// `unsafe` is needed to index list elements via the vtable's raw `get_mut`
// and to rebuild a `Poke` from the returned pointer.
#[allow(unsafe_code)]
fn patch_channel_ids_recursive(mut poke: facet::Poke<'_, '_>, channels: &[u64], idx: &mut usize) {
    use facet::Def;

    let shape = poke.shape();

    // Check if this is an Rx or Tx type
    if shape.module_path == Some("roam_session")
        && (shape.type_identifier == "Rx" || shape.type_identifier == "Tx")
    {
        // Overwrite the channel_id field
        if let Ok(mut ps) = poke.into_struct()
            && let Ok(mut channel_id_field) = ps.field_by_name("channel_id")
            && let Ok(channel_id_ref) = channel_id_field.get_mut::<ChannelId>()
            && *idx < channels.len()
        {
            *channel_id_ref = channels[*idx];
            *idx += 1;
        }
        return;
    }

    // Dispatch based on the shape's definition (same arm order as bind_streams_recursive)
    match shape.def {
        Def::Scalar => {}

        // Recurse into struct/tuple fields
        _ if poke.is_struct() => {
            let mut ps = poke.into_struct().expect("is_struct was true");
            let field_count = ps.field_count();
            for i in 0..field_count {
                if let Ok(field_poke) = ps.field(i) {
                    patch_channel_ids_recursive(field_poke, channels, idx);
                }
            }
        }

        // Recurse into Option<T>
        Def::Option(_) => {
            // Option is represented as an enum, use into_enum to access its value
            if let Ok(mut pe) = poke.into_enum()
                && let Ok(Some(inner_poke)) = pe.field(0)
            {
                patch_channel_ids_recursive(inner_poke, channels, idx);
            }
        }

        // Recurse into list elements (e.g., Vec<Tx<T>>)
        Def::List(list_def) => {
            // Length is read through a shared Peek view; a failed list view counts as empty.
            let len = {
                let peek = poke.as_peek();
                peek.into_list().map(|pl| pl.len()).unwrap_or(0)
            };
            // Get mutable access to elements via VTable (no PokeList exists)
            if let Some(get_mut_fn) = list_def.vtable.get_mut {
                let element_shape = list_def.t;
                let data_ptr = poke.data_mut();
                for i in 0..len {
                    // SAFETY: We have exclusive mutable access via poke, index < len, shape is correct
                    let element_ptr = unsafe { (get_mut_fn)(data_ptr, i, element_shape) };
                    if let Some(ptr) = element_ptr {
                        // SAFETY: ptr points to a valid element with the correct shape
                        let element_poke =
                            unsafe { facet::Poke::from_raw_parts(ptr, element_shape) };
                        patch_channel_ids_recursive(element_poke, channels, idx);
                    }
                }
            }
        }

        // Other enum variants
        _ if poke.is_enum() => {
            if let Ok(mut pe) = poke.into_enum()
                && let Ok(Some(variant_poke)) = pe.field(0)
            {
                patch_channel_ids_recursive(variant_poke, channels, idx);
            }
        }

        _ => {}
    }
}
1726
1727// ============================================================================
1728// Service Dispatcher
1729// ============================================================================
1730
/// Context passed to service method implementations.
///
/// Contains information about the request that may be useful to the handler:
/// - `conn_id`: Which virtual connection the request came from
/// - `request_id`: The request's ID, unique within the connection
/// - `method_id`: Which method is being called
/// - `metadata`: Key-value pairs sent with the request
/// - `channels`: Channel IDs from the Request framing, in argument order
///
/// This enables services to identify callers and access per-request metadata.
#[derive(Debug, Clone)]
pub struct Context {
    /// The connection ID this request arrived on.
    ///
    /// For virtual connections, this identifies which specific connection
    /// the request came from, enabling bidirectional communication.
    pub conn_id: roam_wire::ConnectionId,

    /// The request ID for this call.
    ///
    /// Unique within the connection; used for response routing and cancellation.
    pub request_id: roam_wire::RequestId,

    /// The method ID being called.
    pub method_id: roam_wire::MethodId,

    /// Metadata sent with the request.
    ///
    /// This is the `metadata` field from the wire `Request` message.
    pub metadata: roam_wire::Metadata,

    /// Channel IDs from the request, in argument declaration order.
    ///
    /// Used for stream binding (the dispatch helpers patch them into the
    /// deserialized args). Proxies can use this to remap channel IDs.
    pub channels: Vec<u64>,
}
1764
1765impl Context {
1766 /// Create a new context.
1767 pub fn new(
1768 conn_id: roam_wire::ConnectionId,
1769 request_id: roam_wire::RequestId,
1770 method_id: roam_wire::MethodId,
1771 metadata: roam_wire::Metadata,
1772 channels: Vec<u64>,
1773 ) -> Self {
1774 Self {
1775 conn_id,
1776 request_id,
1777 method_id,
1778 metadata,
1779 channels,
1780 }
1781 }
1782
1783 /// Get the connection ID.
1784 pub fn conn_id(&self) -> roam_wire::ConnectionId {
1785 self.conn_id
1786 }
1787
1788 /// Get the request ID.
1789 pub fn request_id(&self) -> roam_wire::RequestId {
1790 self.request_id
1791 }
1792
1793 /// Get the method ID.
1794 pub fn method_id(&self) -> roam_wire::MethodId {
1795 self.method_id
1796 }
1797
1798 /// Get the request metadata.
1799 pub fn metadata(&self) -> &roam_wire::Metadata {
1800 &self.metadata
1801 }
1802
1803 /// Get the channel IDs.
1804 pub fn channels(&self) -> &[u64] {
1805 &self.channels
1806 }
1807}
1808
/// Trait for dispatching requests to a service.
///
/// The dispatcher handles both simple and channeling methods uniformly.
/// Stream binding is done via reflection (Poke) on the deserialized args.
pub trait ServiceDispatcher: Send + Sync {
    /// Returns the method IDs this dispatcher handles.
    ///
    /// Used by [`RoutedDispatcher`] to determine which methods to route
    /// to which dispatcher.
    fn method_ids(&self) -> Vec<u64>;

    /// Dispatch a request and send the response via the task channel.
    ///
    /// The dispatcher is responsible for:
    /// - Looking up the method by `cx.method_id()`
    /// - Deserializing arguments from payload
    /// - Patching channel IDs from `cx.channels()` into deserialized args via `patch_channel_ids()`
    /// - Binding any Tx/Rx streams via the registry
    /// - Calling the service method
    /// - Sending Data/Close messages for any Tx streams
    /// - Sending the Response message via DriverMessage::Response
    ///
    /// By using a single channel for Data/Close/Response, correct ordering is guaranteed:
    /// all stream Data and Close messages are sent before the Response.
    ///
    /// The `cx.channels()` contains channel IDs from the Request message framing,
    /// in declaration order. For a ForwardingDispatcher, this enables transparent proxying
    /// without parsing the payload.
    ///
    /// Returns a boxed future with `'static` lifetime so it can be spawned.
    /// Implementations should clone their service into the future to achieve this.
    /// Because the future is `'static`, it cannot hold on to `registry`; any work
    /// that needs the registry (such as stream binding) has to happen before
    /// `dispatch` returns, as `dispatch_call` does.
    ///
    /// r[impl channeling.allocation.caller] - Stream IDs are from Request.channels (caller allocated).
    fn dispatch(
        &self,
        cx: &Context,
        payload: Vec<u8>,
        registry: &mut ChannelRegistry,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send + 'static>>;
}
1849
/// A dispatcher that routes to one of two dispatchers based on method ID.
///
/// Methods handled by `primary` (via [`ServiceDispatcher::method_ids`]) are
/// routed to it; all other methods are routed to `fallback`.
///
/// The primary's method IDs are read once, at construction time.
pub struct RoutedDispatcher<A, B> {
    /// Receives requests whose method ID is listed in `primary_methods`.
    primary: A,
    /// Receives every request not claimed by `primary`.
    fallback: B,
    /// Snapshot of `primary.method_ids()` taken in `RoutedDispatcher::new`.
    primary_methods: Vec<u64>,
}
1859
1860impl<A, B> RoutedDispatcher<A, B>
1861where
1862 A: ServiceDispatcher,
1863{
1864 /// Create a new routed dispatcher.
1865 ///
1866 /// Methods declared by `primary.method_ids()` are routed to `primary`,
1867 /// all others to `fallback`.
1868 pub fn new(primary: A, fallback: B) -> Self {
1869 let primary_methods = primary.method_ids();
1870 Self {
1871 primary,
1872 fallback,
1873 primary_methods,
1874 }
1875 }
1876}
1877
1878impl<A, B> ServiceDispatcher for RoutedDispatcher<A, B>
1879where
1880 A: ServiceDispatcher,
1881 B: ServiceDispatcher,
1882{
1883 fn method_ids(&self) -> Vec<u64> {
1884 let mut ids = self.primary_methods.clone();
1885 ids.extend(self.fallback.method_ids());
1886 ids
1887 }
1888
1889 fn dispatch(
1890 &self,
1891 cx: &Context,
1892 payload: Vec<u8>,
1893 registry: &mut ChannelRegistry,
1894 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send + 'static>> {
1895 if self.primary_methods.contains(&cx.method_id().raw()) {
1896 self.primary.dispatch(cx, payload, registry)
1897 } else {
1898 self.fallback.dispatch(cx, payload, registry)
1899 }
1900 }
1901}
1902
1903// ============================================================================
1904// ForwardingDispatcher - Transparent RPC Proxy
1905// ============================================================================
1906
/// A dispatcher that forwards all requests to an upstream connection.
///
/// This enables transparent proxying without knowing the service schema.
/// Channel IDs are remapped automatically: the proxy allocates new channel IDs
/// for the upstream connection and maintains bidirectional forwarding.
///
/// Each request is re-issued upstream with the same method ID and raw payload.
/// The upstream response payload is relayed back unchanged. Upstream transport
/// failures are reported to the caller as `RoamError::InvalidPayload` (encode
/// failure) or `RoamError::Cancelled` (connection closed or driver gone).
///
/// # Example
///
/// ```ignore
/// use roam_session::{ForwardingDispatcher, ConnectionHandle, accept_framed};
///
/// // Upstream connection to the actual service
/// let upstream: ConnectionHandle = /* ... */;
///
/// // Create a forwarding dispatcher
/// let proxy = ForwardingDispatcher::new(upstream);
///
/// // Use with accept_framed() - all calls will be forwarded to upstream
/// let (handle, driver) = accept_framed(transport, config, proxy).await?;
/// ```
pub struct ForwardingDispatcher {
    /// Connection that every dispatched request is re-issued on.
    upstream: ConnectionHandle,
}
1930
1931impl ForwardingDispatcher {
1932 /// Create a new forwarding dispatcher that proxies to the upstream connection.
1933 pub fn new(upstream: ConnectionHandle) -> Self {
1934 Self { upstream }
1935 }
1936}
1937
1938impl Clone for ForwardingDispatcher {
1939 fn clone(&self) -> Self {
1940 Self {
1941 upstream: self.upstream.clone(),
1942 }
1943 }
1944}
1945
impl ServiceDispatcher for ForwardingDispatcher {
    /// Returns empty - this dispatcher accepts all method IDs.
    ///
    /// When used as the `primary` of a [`RoutedDispatcher`], the empty list means
    /// nothing is routed to it. Use it standalone or as the `fallback`.
    fn method_ids(&self) -> Vec<u64> {
        vec![]
    }

    /// Forward one request to `self.upstream` and relay the outcome downstream.
    ///
    /// There are two paths:
    /// - **No request channels** (`cx.channels` is empty): the request is re-issued
    ///   upstream. Each channel ID in the upstream Response gets a fresh downstream
    ///   ID from `registry.response_channel_ids()`. A spawned task per channel then
    ///   relays upstream Data to the downstream connection, followed by a Close.
    /// - **Request channels present**: each downstream channel ID is paired with a
    ///   newly allocated upstream ID. Two relay tasks per channel forward Data, then
    ///   Close, in each direction.
    ///
    /// Upstream transport failures become a serialized error payload:
    /// - `[1, 2]`, i.e. `Err(RoamError::InvalidPayload)`, for encode errors;
    /// - `[1, 3]`, i.e. `Err(RoamError::Cancelled)`, when the upstream connection or
    ///   driver is gone.
    ///
    /// NOTE(review): `cx.metadata` is not read here. The upstream calls pass `None`
    /// as their last argument. If that parameter is request metadata, the proxy
    /// drops the caller's metadata. Confirm this is intended.
    fn dispatch(
        &self,
        cx: &Context,
        payload: Vec<u8>,
        registry: &mut ChannelRegistry,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send + 'static>> {
        // Copy everything the returned future needs out of `cx`/`self` now,
        // because the future must be `'static`.
        let task_tx = registry.driver_tx();
        let upstream = self.upstream.clone();
        let conn_id = cx.conn_id;
        let method_id = cx.method_id.raw();
        let request_id = cx.request_id.raw();
        let channels = cx.channels.clone();

        if channels.is_empty() {
            // Unary call - but response may contain Rx<T> channels
            // We need to set up forwarding for any response channels.
            //
            // IMPORTANT: Upstream and downstream use different channel ID spaces.
            // The upstream channel IDs must be remapped to downstream channel IDs.
            let downstream_channel_ids = registry.response_channel_ids();

            Box::pin(async move {
                let response = upstream
                    .call_raw_with_channels(method_id, vec![], payload, None)
                    .await;

                // Translate transport failures into the wire encoding of
                // `Err(RoamError::...)` so the downstream caller gets a well-formed reply.
                let (response_payload, upstream_response_channels) = match response {
                    Ok(data) => (data.payload, data.channels),
                    Err(TransportError::Encode(_)) => {
                        // Should not happen for raw call
                        (vec![1, 2], Vec::new()) // Err(InvalidPayload)
                    }
                    Err(TransportError::ConnectionClosed) | Err(TransportError::DriverGone) => {
                        // Connection to upstream failed - return Cancelled
                        (vec![1, 3], Vec::new()) // Err(Cancelled)
                    }
                };

                // If response has channels (e.g., method returns Rx<T>),
                // set up forwarding for Data from upstream to downstream.
                // We allocate new downstream channel IDs and remap when forwarding.
                let mut downstream_channels = Vec::new();
                if !upstream_response_channels.is_empty() {
                    debug!(
                        upstream_channels = ?upstream_response_channels,
                        "ForwardingDispatcher: setting up response channel forwarding"
                    );
                    for &upstream_id in &upstream_response_channels {
                        // Allocate a downstream channel ID
                        let downstream_id = downstream_channel_ids.next();
                        downstream_channels.push(downstream_id);

                        debug!(
                            upstream_id,
                            downstream_id, "ForwardingDispatcher: mapping channel IDs"
                        );

                        // Set up forwarding: upstream → downstream
                        let (tx, mut rx) = crate::runtime::channel::<Vec<u8>>(64);
                        upstream.register_incoming(upstream_id, tx);

                        let task_tx_clone = task_tx.clone();
                        // Relay task: runs until `rx.recv()` yields `None`, then emits
                        // a Close for the downstream channel. Send failures are ignored
                        // (best-effort), so the loop keeps draining even if downstream is gone.
                        crate::runtime::spawn(async move {
                            debug!(
                                upstream_id,
                                downstream_id, "ForwardingDispatcher: forwarding task started"
                            );
                            while let Some(data) = rx.recv().await {
                                debug!(
                                    upstream_id,
                                    downstream_id,
                                    data_len = data.len(),
                                    "ForwardingDispatcher: forwarding data"
                                );
                                let _ = task_tx_clone
                                    .send(DriverMessage::Data {
                                        conn_id,
                                        channel_id: downstream_id,
                                        payload: data,
                                    })
                                    .await;
                            }
                            debug!(
                                upstream_id,
                                downstream_id,
                                "ForwardingDispatcher: forwarding task ended, sending Close"
                            );
                            // Channel closed
                            let _ = task_tx_clone
                                .send(DriverMessage::Close {
                                    conn_id,
                                    channel_id: downstream_id,
                                })
                                .await;
                        });
                    }
                }

                // The Response carries the remapped (downstream) channel IDs.
                let _ = task_tx
                    .send(DriverMessage::Response {
                        conn_id,
                        request_id,
                        channels: downstream_channels,
                        payload: response_payload,
                    })
                    .await;
            })
        } else {
            // Streaming call - set up bidirectional channel forwarding
            //
            // IMPORTANT: We must send the upstream Request BEFORE any Data is
            // forwarded, otherwise the backend will reject Data for unknown channels.
            //
            // Strategy:
            // 1. Register incoming handlers synchronously (buffers Data in mpsc channels)
            // 2. In the async block: send Request first, then spawn forwarding tasks
            //    (spawning AFTER Request is sent is safe - ordering is established)

            // Allocate upstream channel IDs and set up buffering channels
            let mut upstream_channels = Vec::with_capacity(channels.len());
            let mut ds_to_us_rxs = Vec::with_capacity(channels.len());
            let mut us_to_ds_rxs = Vec::with_capacity(channels.len());
            // Pairs of (downstream_id, upstream_id), index-aligned with both rx vectors.
            let mut channel_map = Vec::with_capacity(channels.len());

            let upstream_task_tx = upstream.driver_tx();

            for &downstream_id in &channels {
                let upstream_id = upstream.alloc_channel_id();
                upstream_channels.push(upstream_id);
                channel_map.push((downstream_id, upstream_id));

                // Buffer for downstream → upstream (client sends Data)
                let (ds_to_us_tx, ds_to_us_rx) = crate::runtime::channel(64);
                registry.register_incoming(downstream_id, ds_to_us_tx);
                ds_to_us_rxs.push(ds_to_us_rx);

                // Buffer for upstream → downstream (server sends Data)
                let (us_to_ds_tx, us_to_ds_rx) = crate::runtime::channel(64);
                upstream.register_incoming(upstream_id, us_to_ds_tx);
                us_to_ds_rxs.push(us_to_ds_rx);
            }

            // Everything below runs in the async block
            Box::pin(async move {
                // Send the upstream Request - this queues the Request command
                // which will be sent before any Data we forward
                //
                // NOTE(review): this ordering holds only if `call_raw_with_channels`
                // enqueues the Request when it is *called*. If it is an `async fn`,
                // nothing runs until `response_future.await` below. That is after the
                // relay tasks are spawned, so forwarded Data could reach the driver
                // before the Request. Verify against its definition.
                let response_future =
                    upstream.call_raw_with_channels(method_id, upstream_channels, payload, None);

                // Now spawn forwarding tasks - safe because Request is queued first
                // and command_tx/task_tx are processed in order by the driver
                let upstream_conn_id = upstream.conn_id();
                // Downstream → upstream relays: Data re-addressed to the upstream ID,
                // then Close once the downstream side is done.
                for (i, mut rx) in ds_to_us_rxs.into_iter().enumerate() {
                    let upstream_id = channel_map[i].1;
                    let upstream_task_tx = upstream_task_tx.clone();
                    crate::runtime::spawn(async move {
                        while let Some(data) = rx.recv().await {
                            let _ = upstream_task_tx
                                .send(DriverMessage::Data {
                                    conn_id: upstream_conn_id,
                                    channel_id: upstream_id,
                                    payload: data,
                                })
                                .await;
                        }
                        // Channel closed
                        let _ = upstream_task_tx
                            .send(DriverMessage::Close {
                                conn_id: upstream_conn_id,
                                channel_id: upstream_id,
                            })
                            .await;
                    });
                }

                // Upstream → downstream relays: Data re-addressed to the original
                // downstream ID, then Close once the upstream side is done.
                for (i, mut rx) in us_to_ds_rxs.into_iter().enumerate() {
                    let downstream_id = channel_map[i].0;
                    let task_tx = task_tx.clone();
                    crate::runtime::spawn(async move {
                        while let Some(data) = rx.recv().await {
                            let _ = task_tx
                                .send(DriverMessage::Data {
                                    conn_id,
                                    channel_id: downstream_id,
                                    payload: data,
                                })
                                .await;
                        }
                        // Channel closed
                        let _ = task_tx
                            .send(DriverMessage::Close {
                                conn_id,
                                channel_id: downstream_id,
                            })
                            .await;
                    });
                }

                // Wait for upstream response
                let response = response_future.await;

                let (response_payload, upstream_response_channels) = match response {
                    Ok(data) => (data.payload, data.channels),
                    Err(TransportError::Encode(_)) => {
                        (vec![1, 2], Vec::new()) // Err(InvalidPayload)
                    }
                    Err(TransportError::ConnectionClosed) | Err(TransportError::DriverGone) => {
                        (vec![1, 3], Vec::new()) // Err(Cancelled)
                    }
                };

                // Map upstream response channels back to downstream channel IDs.
                // The downstream client allocated the original IDs and expects them
                // in the Response, not the upstream IDs we allocated for forwarding.
                //
                // NOTE(review): upstream IDs not found in `channel_map` are silently
                // dropped by `filter_map`. Unlike the unary path, no downstream
                // forwarding is set up for such channels. Confirm streaming methods
                // cannot return newly allocated response channels.
                let downstream_response_channels: Vec<u64> = upstream_response_channels
                    .iter()
                    .filter_map(|&upstream_id| {
                        channel_map
                            .iter()
                            .find(|(_, us)| *us == upstream_id)
                            .map(|(ds, _)| *ds)
                    })
                    .collect();

                let _ = task_tx
                    .send(DriverMessage::Response {
                        conn_id,
                        request_id,
                        channels: downstream_response_channels,
                        payload: response_payload,
                    })
                    .await;
            })
        }
    }
}
2188
2189// ============================================================================
2190// LateBoundForwarder - Forwarding with Deferred Handle Binding
2191// ============================================================================
2192
/// A handle that can be set once after creation.
///
/// This solves the chicken-and-egg problem in bidirectional proxying where:
/// 1. You need to pass a dispatcher to `connect()` for reverse-direction calls
/// 2. But the dispatcher needs a handle that's only available after `accept_framed()`
///
/// Cloning is cheap and all clones share one slot: binding through any clone
/// makes the connection visible through every other clone.
///
/// # Example
///
/// ```ignore
/// // Create the late-bound handle (empty initially)
/// let late_bound = LateBoundHandle::new();
///
/// // Pass a forwarder using this handle to connect()
/// let virtual_conn = handle.connect(
///     metadata,
///     Some(Box::new(LateBoundForwarder::new(late_bound.clone()))),
/// ).await?;
///
/// // Accept the other connection to get its handle
/// let (browser_handle, driver) = accept_framed(transport, config, dispatcher).await?;
///
/// // NOW bind the handle - any incoming calls will be forwarded
/// late_bound.set(browser_handle);
/// ```
#[derive(Clone)]
pub struct LateBoundHandle {
    /// Shared, write-once slot for the bound connection. `derive(Clone)` clones the
    /// `Arc`, so every clone points at the same `OnceLock`.
    inner: Arc<std::sync::OnceLock<ConnectionHandle>>,
}
2221
2222impl LateBoundHandle {
2223 /// Create a new unbound handle.
2224 pub fn new() -> Self {
2225 Self {
2226 inner: Arc::new(std::sync::OnceLock::new()),
2227 }
2228 }
2229
2230 /// Bind the handle to a connection. Can only be called once.
2231 ///
2232 /// # Panics
2233 ///
2234 /// Panics if called more than once.
2235 pub fn set(&self, handle: ConnectionHandle) {
2236 if self.inner.set(handle).is_err() {
2237 panic!("LateBoundHandle::set called more than once");
2238 }
2239 }
2240
2241 /// Try to get the bound handle, if set.
2242 pub fn get(&self) -> Option<&ConnectionHandle> {
2243 self.inner.get()
2244 }
2245}
2246
2247impl Default for LateBoundHandle {
2248 fn default() -> Self {
2249 Self::new()
2250 }
2251}
2252
/// A dispatcher that forwards all requests to a late-bound upstream connection.
///
/// Like [`ForwardingDispatcher`], but the upstream handle is provided after creation
/// via [`LateBoundHandle::set`]. This enables bidirectional proxying scenarios.
///
/// If a request arrives before the handle is bound, it returns `Cancelled`.
/// Once bound, each request is handled exactly as a `ForwardingDispatcher`
/// over the bound handle would handle it.
///
/// # Example
///
/// ```ignore
/// // Create late-bound handle and forwarder
/// let late_bound = LateBoundHandle::new();
/// let forwarder = LateBoundForwarder::new(late_bound.clone());
///
/// // Use forwarder with connect() for reverse-direction calls
/// let virtual_conn = handle.connect(metadata, Some(Box::new(forwarder))).await?;
///
/// // Later, bind the actual handle
/// let (browser_handle, driver) = accept_framed(...).await?;
/// late_bound.set(browser_handle);
/// ```
pub struct LateBoundForwarder {
    /// Shared slot that will eventually hold the upstream connection.
    upstream: LateBoundHandle,
}
2277
2278impl LateBoundForwarder {
2279 /// Create a new late-bound forwarding dispatcher.
2280 pub fn new(upstream: LateBoundHandle) -> Self {
2281 Self { upstream }
2282 }
2283}
2284
2285impl Clone for LateBoundForwarder {
2286 fn clone(&self) -> Self {
2287 Self {
2288 upstream: self.upstream.clone(),
2289 }
2290 }
2291}
2292
2293impl ServiceDispatcher for LateBoundForwarder {
2294 fn method_ids(&self) -> Vec<u64> {
2295 vec![]
2296 }
2297
2298 fn dispatch(
2299 &self,
2300 cx: &Context,
2301 payload: Vec<u8>,
2302 registry: &mut ChannelRegistry,
2303 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send + 'static>> {
2304 let task_tx = registry.driver_tx();
2305 let conn_id = cx.conn_id;
2306 let request_id = cx.request_id.raw();
2307
2308 // Try to get the upstream handle
2309 let Some(upstream) = self.upstream.get().cloned() else {
2310 // Handle not bound yet - return Cancelled
2311 debug!(
2312 method_id = cx.method_id.raw(),
2313 "LateBoundForwarder: upstream not bound, returning Cancelled"
2314 );
2315 return Box::pin(async move {
2316 let _ = task_tx
2317 .send(DriverMessage::Response {
2318 conn_id,
2319 request_id,
2320 channels: vec![],
2321 payload: vec![1, 3], // Err(Cancelled)
2322 })
2323 .await;
2324 });
2325 };
2326
2327 // Delegate to ForwardingDispatcher now that we have the handle
2328 ForwardingDispatcher::new(upstream).dispatch(cx, payload, registry)
2329 }
2330}
2331
// TODO: Remove this shim once facet implements `Facet` for `core::convert::Infallible`
// and for the never type `!` (facet-rs/facet#1668), then use `Infallible`.
/// Stand-in error type for methods that cannot fail.
///
/// It is the default user-error parameter of [`CallError`] (`CallError<E = Never>`).
/// Unlike `Infallible`, this is a unit struct, so a value of it *can* be constructed.
#[derive(Debug, Clone, PartialEq, Eq, Facet)]
pub struct Never;
2336
/// Call error type encoded in RPC responses.
///
/// r[impl core.error.roam-error] - Wraps call results to distinguish app vs protocol errors
/// r[impl call.response.encoding] - Response is `Result<T, RoamError<E>>`
/// r[impl call.error.roam-error] - Protocol errors use RoamError variants
/// r[impl call.error.protocol] - Discriminants 1-3 are protocol-level errors
///
/// The explicit `= N` discriminants are the variant bytes that follow the `Err`
/// tag on the wire (see [`decode_response`]).
///
/// Spec: `docs/content/spec/_index.md` "RoamError".
#[repr(u8)]
#[derive(Debug, Clone, PartialEq, Eq, Facet)]
pub enum RoamError<E> {
    /// r[impl core.error.call-vs-connection] - User errors affect only this call
    /// r[impl call.error.user] - User(E) carries the application's error type
    User(E) = 0,
    /// r[impl call.error.unknown-method] - Method ID not recognized
    UnknownMethod = 1,
    /// r[impl call.error.invalid-payload] - Request payload deserialization failed
    InvalidPayload = 2,
    /// The call did not complete.
    ///
    /// In this module the forwarding dispatchers produce it when the upstream
    /// connection is closed or its driver is gone. [`LateBoundForwarder`] also
    /// produces it while its upstream is not bound yet.
    Cancelled = 3,
}
2357
2358impl<E> RoamError<E> {
2359 /// Map the user error type to a different type.
2360 pub fn map_user<F, E2>(self, f: F) -> RoamError<E2>
2361 where
2362 F: FnOnce(E) -> E2,
2363 {
2364 match self {
2365 RoamError::User(e) => RoamError::User(f(e)),
2366 RoamError::UnknownMethod => RoamError::UnknownMethod,
2367 RoamError::InvalidPayload => RoamError::InvalidPayload,
2368 RoamError::Cancelled => RoamError::Cancelled,
2369 }
2370 }
2371}
2372
/// Result of an RPC call as carried in a response: the method's `T`, or a
/// [`RoamError`] holding either the user error `E` or a protocol error.
pub type CallResult<T, E> = ::core::result::Result<T, RoamError<E>>;
/// A [`CallResult`] wrapped in [`OwnedMessage`].
///
/// NOTE(review): presumably this keeps the backing message buffer alive alongside
/// the decoded value, as the name suggests. Confirm against `roam_frame::OwnedMessage`.
pub type BorrowedCallResult<T, E> = OwnedMessage<CallResult<T, E>>;
2375
2376// ============================================================================
2377// Connection Handle (Client-side API)
2378// ============================================================================
2379
/// Error from making an outgoing call.
///
/// This flattens the nested `Result<Result<T, RoamError<E>>, CallError>` pattern
/// into a single `Result<T, CallError<E>>` for better ergonomics.
///
/// The type parameter `E` represents the user's error type from fallible methods.
/// For infallible methods, use `CallError<Never>`.
///
/// `From` conversions exist from [`TransportError`] (variants map by name) and from
/// [`DecodeError`]. `DecodeError::Postcard` maps to `Decode`; all other
/// `DecodeError`s map to `Protocol`.
#[derive(Debug)]
pub enum CallError<E = Never> {
    /// The remote returned a roam-level error (user error or protocol error).
    Roam(RoamError<E>),
    /// Failed to encode request payload.
    Encode(facet_postcard::SerializeError),
    /// Failed to decode response payload.
    Decode(facet_postcard::DeserializeError<facet_postcard::PostcardError>),
    /// Protocol-level decode error (malformed response structure).
    Protocol(DecodeError),
    /// Connection was closed before response.
    ConnectionClosed,
    /// Driver task is gone.
    DriverGone,
}
2402
2403impl<E> CallError<E> {
2404 /// Map the user error type to a different type.
2405 pub fn map_user<F, E2>(self, f: F) -> CallError<E2>
2406 where
2407 F: FnOnce(E) -> E2,
2408 {
2409 match self {
2410 CallError::Roam(roam_err) => CallError::Roam(roam_err.map_user(f)),
2411 CallError::Encode(e) => CallError::Encode(e),
2412 CallError::Decode(e) => CallError::Decode(e),
2413 CallError::Protocol(e) => CallError::Protocol(e),
2414 CallError::ConnectionClosed => CallError::ConnectionClosed,
2415 CallError::DriverGone => CallError::DriverGone,
2416 }
2417 }
2418}
2419
2420impl<E: std::fmt::Debug> std::fmt::Display for CallError<E> {
2421 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2422 match self {
2423 CallError::Roam(e) => write!(f, "roam error: {e:?}"),
2424 CallError::Encode(e) => write!(f, "encode error: {e}"),
2425 CallError::Decode(e) => write!(f, "decode error: {e}"),
2426 CallError::Protocol(e) => write!(f, "protocol error: {e}"),
2427 CallError::ConnectionClosed => write!(f, "connection closed"),
2428 CallError::DriverGone => write!(f, "driver task stopped"),
2429 }
2430 }
2431}
2432
2433impl<E: std::fmt::Debug> std::error::Error for CallError<E> {}
2434
/// Transport-level call error (no user error type).
///
/// Used by the `Caller` trait which operates at the transport level
/// before response decoding.
///
/// Converts into [`CallError`] via `From`; each variant maps to the
/// `CallError` variant of the same name.
#[derive(Debug)]
pub enum TransportError {
    /// Failed to encode request payload.
    Encode(facet_postcard::SerializeError),
    /// Connection was closed before response.
    ConnectionClosed,
    /// Driver task is gone.
    DriverGone,
}
2448
2449impl<E> From<TransportError> for CallError<E> {
2450 fn from(e: TransportError) -> Self {
2451 match e {
2452 TransportError::Encode(e) => CallError::Encode(e),
2453 TransportError::ConnectionClosed => CallError::ConnectionClosed,
2454 TransportError::DriverGone => CallError::DriverGone,
2455 }
2456 }
2457}
2458
2459impl std::fmt::Display for TransportError {
2460 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2461 match self {
2462 TransportError::Encode(e) => write!(f, "encode error: {e}"),
2463 TransportError::ConnectionClosed => write!(f, "connection closed"),
2464 TransportError::DriverGone => write!(f, "driver task stopped"),
2465 }
2466 }
2467}
2468
2469impl std::error::Error for TransportError {}
2470
/// Error decoding a response payload.
///
/// Converts into [`CallError`] via `From`. `Postcard` becomes
/// `CallError::Decode`; every other variant becomes `CallError::Protocol`.
#[derive(Debug)]
pub enum DecodeError {
    /// Empty response payload.
    EmptyPayload,
    /// Truncated error response (the `Err` tag byte is present but no `RoamError`
    /// discriminant follows it).
    TruncatedError,
    /// Unknown RoamError discriminant (the byte after the `Err` tag is not 0-3).
    UnknownRoamErrorDiscriminant(u8),
    /// Invalid Result discriminant (the first payload byte is neither 0 nor 1).
    InvalidResultDiscriminant(u8),
    /// Postcard deserialization error.
    Postcard(facet_postcard::DeserializeError<facet_postcard::PostcardError>),
}
2485
2486impl std::fmt::Display for DecodeError {
2487 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2488 match self {
2489 DecodeError::EmptyPayload => write!(f, "empty response payload"),
2490 DecodeError::TruncatedError => write!(f, "truncated error response"),
2491 DecodeError::UnknownRoamErrorDiscriminant(d) => {
2492 write!(f, "unknown RoamError discriminant: {d}")
2493 }
2494 DecodeError::InvalidResultDiscriminant(d) => {
2495 write!(f, "invalid Result discriminant: {d}")
2496 }
2497 DecodeError::Postcard(e) => write!(f, "postcard: {e}"),
2498 }
2499 }
2500}
2501
2502impl std::error::Error for DecodeError {}
2503
2504impl<E> From<DecodeError> for CallError<E> {
2505 fn from(e: DecodeError) -> Self {
2506 match e {
2507 DecodeError::Postcard(pe) => CallError::Decode(pe),
2508 other => CallError::Protocol(other),
2509 }
2510 }
2511}
2512
2513/// Decode a response payload into the expected type.
2514///
2515/// This is the core response decoding logic used by generated clients.
2516/// It handles the wire format: `[0] + value_bytes` for Ok, `[1, discriminant] + error_bytes` for Err.
2517///
2518/// Returns `Result<T, CallError<E>>` with the decoded value or error.
2519pub fn decode_response<T: Facet<'static>, E: Facet<'static>>(
2520 payload: &[u8],
2521) -> Result<T, CallError<E>> {
2522 if payload.is_empty() {
2523 return Err(DecodeError::EmptyPayload.into());
2524 }
2525
2526 match payload[0] {
2527 0 => {
2528 // Ok variant: deserialize the value
2529 facet_postcard::from_slice(&payload[1..]).map_err(CallError::Decode)
2530 }
2531 1 => {
2532 // Err variant: deserialize RoamError<E>
2533 if payload.len() < 2 {
2534 return Err(DecodeError::TruncatedError.into());
2535 }
2536 let roam_error = match payload[1] {
2537 0 => {
2538 // User error
2539 let user_error: E =
2540 facet_postcard::from_slice(&payload[2..]).map_err(CallError::Decode)?;
2541 RoamError::User(user_error)
2542 }
2543 1 => RoamError::UnknownMethod,
2544 2 => RoamError::InvalidPayload,
2545 3 => RoamError::Cancelled,
2546 d => return Err(DecodeError::UnknownRoamErrorDiscriminant(d).into()),
2547 };
2548 Err(CallError::Roam(roam_error))
2549 }
2550 d => Err(DecodeError::InvalidResultDiscriminant(d).into()),
2551 }
2552}
2553
/// Trait for making RPC calls.
///
/// This abstracts over different connection types (e.g., `ConnectionHandle`,
/// `ReconnectingClient`) so generated clients can work with any of them.
///
/// All callers return `TransportError` for transport-level failures.
/// Generated clients convert this to `CallError<E>` which also includes
/// response-level errors like `RoamError::User(E)`.
///
/// Each method is declared twice under `cfg(target_arch = "wasm32")` /
/// `cfg(not(...))`. The two copies differ only in whether the returned future
/// must be `Send` (required on native targets, not on wasm32).
#[allow(async_fn_in_trait)]
pub trait Caller: Clone + Send + Sync + 'static {
    /// Make an RPC call with the given method ID and arguments.
    ///
    /// The arguments are mutable because stream bindings (Tx/Rx) need to be
    /// assigned channel IDs before serialization.
    ///
    /// Returns ResponseData containing the payload and any response channel IDs.
    ///
    /// The default implementation calls `call_with_metadata` with empty metadata.
    #[cfg(not(target_arch = "wasm32"))]
    fn call<T: Facet<'static> + Send>(
        &self,
        method_id: u64,
        args: &mut T,
    ) -> impl std::future::Future<Output = Result<ResponseData, TransportError>> + Send {
        self.call_with_metadata(method_id, args, roam_wire::Metadata::default())
    }

    /// Make an RPC call with the given method ID and arguments.
    ///
    /// The arguments are mutable because stream bindings (Tx/Rx) need to be
    /// assigned channel IDs before serialization.
    ///
    /// Returns ResponseData containing the payload and any response channel IDs.
    ///
    /// The default implementation calls `call_with_metadata` with empty metadata.
    #[cfg(target_arch = "wasm32")]
    fn call<T: Facet<'static> + Send>(
        &self,
        method_id: u64,
        args: &mut T,
    ) -> impl std::future::Future<Output = Result<ResponseData, TransportError>> {
        self.call_with_metadata(method_id, args, roam_wire::Metadata::default())
    }

    /// Make an RPC call with the given method ID, arguments, and metadata.
    ///
    /// The arguments are mutable because stream bindings (Tx/Rx) need to be
    /// assigned channel IDs before serialization.
    ///
    /// Returns ResponseData containing the payload and any response channel IDs.
    #[cfg(not(target_arch = "wasm32"))]
    fn call_with_metadata<T: Facet<'static> + Send>(
        &self,
        method_id: u64,
        args: &mut T,
        metadata: roam_wire::Metadata,
    ) -> impl std::future::Future<Output = Result<ResponseData, TransportError>> + Send;

    /// Make an RPC call with the given method ID, arguments, and metadata.
    ///
    /// The arguments are mutable because stream bindings (Tx/Rx) need to be
    /// assigned channel IDs before serialization.
    ///
    /// Returns ResponseData containing the payload and any response channel IDs.
    #[cfg(target_arch = "wasm32")]
    fn call_with_metadata<T: Facet<'static> + Send>(
        &self,
        method_id: u64,
        args: &mut T,
        metadata: roam_wire::Metadata,
    ) -> impl std::future::Future<Output = Result<ResponseData, TransportError>>;

    /// Bind receivers for `Rx<T>` streams in the response.
    ///
    /// After deserializing a response, any `Rx<T>` values in it are "hollow" -
    /// they have channel IDs but no actual receiver. This method walks the
    /// response and binds receivers for each Rx using the channel IDs from
    /// the Response message.
    fn bind_response_streams<T: Facet<'static>>(&self, response: &mut T, channels: &[u64]);
}
2630
/// `ConnectionHandle` is the basic [`Caller`]: both methods delegate to the
/// `ConnectionHandle::call_with_metadata` / `ConnectionHandle::bind_response_streams`
/// paths (presumably the inherent methods defined elsewhere in this file).
///
/// `call_with_metadata` is written as an `async fn`. On native targets the compiler
/// checks that the resulting future satisfies the trait's `Send` bound.
impl Caller for ConnectionHandle {
    async fn call_with_metadata<T: Facet<'static> + Send>(
        &self,
        method_id: u64,
        args: &mut T,
        metadata: roam_wire::Metadata,
    ) -> Result<ResponseData, TransportError> {
        ConnectionHandle::call_with_metadata(self, method_id, args, metadata).await
    }

    fn bind_response_streams<T: Facet<'static>>(&self, response: &mut T, channels: &[u64]) {
        ConnectionHandle::bind_response_streams(self, response, channels)
    }
}
2645
2646// ============================================================================
2647// CallFuture - Builder pattern for RPC calls with optional metadata
2648// ============================================================================
2649
2650/// A future representing an RPC call that can be configured with metadata.
2651///
2652/// This provides a builder pattern for RPC calls:
2653/// - `client.method(args).await` - Simple call with default (empty) metadata
2654/// - `client.method(args).with_metadata(meta).await` - Call with custom metadata
2655///
2656/// The future is lazy - the RPC call is not made until `.await` is called.
2657///
2658/// # Example
2659///
2660/// ```ignore
2661/// // Simple call
2662/// let result = client.subscribe(route).await?;
2663///
2664/// // With metadata
2665/// let result = client.subscribe(route)
2666/// .with_metadata(vec![("trace-id".into(), MetadataValue::String("abc".into()))])
2667/// .await?;
2668/// ```
2669pub struct CallFuture<C, Args, Ok, Err>
2670where
2671 C: Caller,
2672 Args: Facet<'static>,
2673{
2674 caller: C,
2675 method_id: u64,
2676 args: Args,
2677 metadata: roam_wire::Metadata,
2678 _phantom: PhantomData<fn() -> (Ok, Err)>,
2679}
2680
2681impl<C, Args, Ok, Err> CallFuture<C, Args, Ok, Err>
2682where
2683 C: Caller,
2684 Args: Facet<'static>,
2685{
2686 /// Create a new CallFuture.
2687 pub fn new(caller: C, method_id: u64, args: Args) -> Self {
2688 Self {
2689 caller,
2690 method_id,
2691 args,
2692 metadata: roam_wire::Metadata::default(),
2693 _phantom: PhantomData,
2694 }
2695 }
2696
2697 /// Set metadata for this call.
2698 ///
2699 /// Metadata is a list of key-value pairs that will be sent with the request.
2700 /// The server can access this via `Context::metadata()`.
2701 pub fn with_metadata(mut self, metadata: roam_wire::Metadata) -> Self {
2702 self.metadata = metadata;
2703 self
2704 }
2705}
2706
2707// On native, the future must be Send so it can be spawned on tokio.
2708// On WASM, futures don't need Send since everything is single-threaded.
2709#[cfg(not(target_arch = "wasm32"))]
2710impl<C, Args, Ok, Err> std::future::IntoFuture for CallFuture<C, Args, Ok, Err>
2711where
2712 C: Caller,
2713 Args: Facet<'static> + Send + 'static,
2714 Ok: Facet<'static> + Send + 'static,
2715 Err: Facet<'static> + Send + 'static,
2716{
2717 type Output = Result<Ok, CallError<Err>>;
2718 type IntoFuture = std::pin::Pin<Box<dyn std::future::Future<Output = Self::Output> + Send>>;
2719
2720 fn into_future(self) -> Self::IntoFuture {
2721 let CallFuture {
2722 caller,
2723 method_id,
2724 mut args,
2725 metadata,
2726 _phantom,
2727 } = self;
2728
2729 Box::pin(async move {
2730 let response = caller
2731 .call_with_metadata(method_id, &mut args, metadata)
2732 .await
2733 .map_err(CallError::from)?;
2734 let mut result = decode_response::<Ok, Err>(&response.payload)?;
2735 caller.bind_response_streams(&mut result, &response.channels);
2736 Ok(result)
2737 })
2738 }
2739}
2740
#[cfg(target_arch = "wasm32")]
impl<C, Args, Ok, Err> std::future::IntoFuture for CallFuture<C, Args, Ok, Err>
where
    C: Caller,
    Args: Facet<'static> + Send + 'static,
    Ok: Facet<'static> + Send + 'static,
    Err: Facet<'static> + Send + 'static,
{
    type Output = Result<Ok, CallError<Err>>;
    type IntoFuture = std::pin::Pin<Box<dyn std::future::Future<Output = Self::Output>>>;

    /// Issue the call, decode the response, then bind any `Rx<T>` channels
    /// in the decoded value. Same as the native version, minus the `Send` bound.
    fn into_future(self) -> Self::IntoFuture {
        // `args` must be mutable: channel IDs for Tx/Rx arguments are assigned
        // while the call is being made.
        let Self {
            caller,
            method_id,
            mut args,
            metadata,
            ..
        } = self;

        Box::pin(async move {
            let in_flight = caller.call_with_metadata(method_id, &mut args, metadata);
            let response = in_flight.await.map_err(CallError::from)?;
            let mut decoded = decode_response::<Ok, Err>(&response.payload)?;
            caller.bind_response_streams(&mut decoded, &response.channels);
            Ok(decoded)
        })
    }
}
2772
/// Shared state between ConnectionHandle and Driver.
///
/// Every clone of a `ConnectionHandle` points at the same `HandleShared` through an
/// `Arc`, so request IDs, channel IDs and the channel registry are common to all clones.
struct HandleShared {
    /// Connection ID for this handle (0 = root connection).
    conn_id: roam_wire::ConnectionId,
    /// Unified channel to send all messages to the driver.
    ///
    /// The `call_*` methods and `connect` send their `DriverMessage`s here.
    driver_tx: Sender<DriverMessage>,
    /// Request ID generator (used for both Call and Connect messages).
    request_ids: RequestIdGenerator,
    /// Stream ID allocator, created with this side's `Role` so IDs get the right parity.
    channel_ids: ChannelIdAllocator,
    /// Stream registry for routing incoming data.
    /// Protected by a mutex since handles may create streams concurrently.
    channel_registry: std::sync::Mutex<ChannelRegistry>,
    /// Optional diagnostic state for SIGUSR1 dumps.
    /// When present, outgoing requests and channel open/close events are recorded.
    diagnostic_state: Option<Arc<crate::diagnostic::DiagnosticState>>,
}
2789
/// Handle for making outgoing RPC calls.
///
/// This is the client-side API. It can be cloned and used from multiple tasks;
/// clones are cheap because they all share one `Arc<HandleShared>`.
/// The actual I/O is driven by the `Driver` future which must be spawned.
///
/// # Example
///
/// ```ignore
/// let (handle, driver) = establish_connection(transport, dispatcher).await?;
/// tokio::spawn(driver);
///
/// // Use handle to make calls
/// let response = handle.call_raw(method_id, payload).await?;
/// ```
#[derive(Clone)]
pub struct ConnectionHandle {
    // All mutable/shared state lives behind this Arc; see `HandleShared`.
    shared: Arc<HandleShared>,
}
2808
2809impl ConnectionHandle {
2810 /// Create a new handle for the root connection (conn_id = 0).
2811 ///
2812 /// All messages (Call/Data/Close/Response) go through a single unified channel
2813 /// to ensure FIFO ordering.
2814 pub fn new(driver_tx: Sender<DriverMessage>, role: Role, initial_credit: u32) -> Self {
2815 Self::new_with_diagnostics(
2816 roam_wire::ConnectionId::ROOT,
2817 driver_tx,
2818 role,
2819 initial_credit,
2820 None,
2821 )
2822 }
2823
2824 /// Create a new handle with a specific connection ID and optional diagnostic state.
2825 ///
2826 /// If `diagnostic_state` is provided, all RPC calls and channels will be tracked
2827 /// for debugging purposes.
2828 pub fn new_with_diagnostics(
2829 conn_id: roam_wire::ConnectionId,
2830 driver_tx: Sender<DriverMessage>,
2831 role: Role,
2832 initial_credit: u32,
2833 diagnostic_state: Option<Arc<crate::diagnostic::DiagnosticState>>,
2834 ) -> Self {
2835 let channel_registry = ChannelRegistry::new_with_credit(initial_credit, driver_tx.clone());
2836 Self {
2837 shared: Arc::new(HandleShared {
2838 conn_id,
2839 driver_tx,
2840 request_ids: RequestIdGenerator::new(),
2841 channel_ids: ChannelIdAllocator::new(role),
2842 channel_registry: std::sync::Mutex::new(channel_registry),
2843 diagnostic_state,
2844 }),
2845 }
2846 }
2847
2848 /// Get the connection ID for this handle.
2849 pub fn conn_id(&self) -> roam_wire::ConnectionId {
2850 self.shared.conn_id
2851 }
2852
2853 /// Get the diagnostic state, if any.
2854 pub fn diagnostic_state(&self) -> Option<&Arc<crate::diagnostic::DiagnosticState>> {
2855 self.shared.diagnostic_state.as_ref()
2856 }
2857
2858 /// Make a typed RPC call with automatic serialization and stream binding.
2859 ///
2860 /// Walks the args using Poke reflection to find any `Rx<T>` or `Tx<T>` fields,
2861 /// binds stream IDs, and sets up the stream infrastructure before serialization.
2862 ///
2863 /// # Arguments
2864 ///
2865 /// * `method_id` - The method ID to call
2866 /// * `args` - Arguments to serialize (typically a tuple of all method args).
2867 /// Must be mutable so stream IDs can be assigned.
2868 ///
2869 /// # Stream Binding
2870 ///
2871 /// For `Rx<T>` in args (caller passes receiver, keeps sender to push data):
2872 /// - Allocates a stream ID
2873 /// - Takes the receiver and spawns a task to drain it, sending Data messages
2874 /// - The caller keeps the `Tx<T>` from `roam::channel()` to send values
2875 ///
2876 /// For `Tx<T>` in args (caller passes sender, keeps receiver to pull data):
2877 /// - Allocates a stream ID
2878 /// - Takes the sender and registers for incoming Data routing
2879 /// - The caller keeps the `Rx<T>` from `roam::channel()` to receive values
2880 ///
2881 /// # Example
2882 ///
2883 /// ```ignore
2884 /// // For a streaming method sum(numbers: Rx<i32>) -> i64
2885 /// let (tx, rx) = roam::channel::<i32>();
2886 /// let response = handle.call(method_id::SUM, &mut (rx,)).await?;
2887 /// // tx.send(&42).await to push values
2888 /// ```
2889 /// Make an RPC call with default (empty) metadata.
2890 pub async fn call<T: Facet<'static>>(
2891 &self,
2892 method_id: u64,
2893 args: &mut T,
2894 ) -> Result<ResponseData, TransportError> {
2895 self.call_with_metadata(method_id, args, roam_wire::Metadata::default())
2896 .await
2897 }
2898
2899 /// Make an RPC call with custom metadata.
2900 pub async fn call_with_metadata<T: Facet<'static>>(
2901 &self,
2902 method_id: u64,
2903 args: &mut T,
2904 metadata: roam_wire::Metadata,
2905 ) -> Result<ResponseData, TransportError> {
2906 // Walk args and bind any streams (allocates channel IDs)
2907 // This collects receivers that need to be drained but does NOT spawn
2908 let mut drains = Vec::new();
2909 debug!("ConnectionHandle::call: binding streams");
2910 self.bind_streams(args, &mut drains);
2911
2912 // Collect channel IDs for the Request message
2913 let channels = collect_channel_ids(args);
2914 debug!(
2915 channels = ?channels,
2916 drain_count = drains.len(),
2917 "ConnectionHandle::call: collected channels after bind_streams"
2918 );
2919
2920 let payload = facet_postcard::to_vec(args).map_err(TransportError::Encode)?;
2921
2922 // Generate args debug info for diagnostics when enabled
2923 let args_debug = if diagnostic::debug_enabled() {
2924 Some(
2925 facet_pretty::PrettyPrinter::new()
2926 .with_colors(facet_pretty::ColorMode::Never)
2927 .with_max_content_len(64)
2928 .format(args),
2929 )
2930 } else {
2931 None
2932 };
2933
2934 if drains.is_empty() {
2935 // No Rx streams - simple call
2936 self.call_raw_with_channels_and_metadata(
2937 method_id, channels, payload, args_debug, metadata,
2938 )
2939 .await
2940 } else {
2941 // Has Rx streams - spawn tasks to drain them
2942 // IMPORTANT: We must send Request BEFORE spawning drain tasks to ensure ordering.
2943 // We need to actually send the DriverMessage::Call to the driver's queue
2944 // before spawning drains, not just create the future.
2945 let request_id = self.shared.request_ids.next();
2946 let (response_tx, response_rx) = oneshot();
2947
2948 // Track outgoing request for diagnostics
2949 if let Some(diag) = &self.shared.diagnostic_state {
2950 let args = args_debug.map(|s| {
2951 let mut map = std::collections::HashMap::new();
2952 map.insert("args".to_string(), s);
2953 map
2954 });
2955 diag.record_outgoing_request(request_id, method_id, args);
2956 // Associate channels with this request
2957 diag.associate_channels_with_request(&channels, request_id);
2958 }
2959
2960 let msg = DriverMessage::Call {
2961 conn_id: self.shared.conn_id,
2962 request_id,
2963 method_id,
2964 metadata,
2965 channels,
2966 payload,
2967 response_tx,
2968 };
2969
2970 // Send the Call message NOW, before spawning drain tasks
2971 if self.shared.driver_tx.send(msg).await.is_err() {
2972 return Err(TransportError::DriverGone);
2973 }
2974
2975 let task_tx = self.shared.channel_registry.lock().unwrap().driver_tx();
2976 let conn_id = self.shared.conn_id;
2977
2978 // Spawn a task for each drain to forward data to driver
2979 for (channel_id, mut rx) in drains {
2980 let task_tx = task_tx.clone();
2981 crate::runtime::spawn(async move {
2982 loop {
2983 match rx.recv().await {
2984 Some(payload) => {
2985 debug!(
2986 "drain task: received {} bytes on channel {}",
2987 payload.len(),
2988 channel_id
2989 );
2990 // Send data to driver
2991 let _ = task_tx
2992 .send(DriverMessage::Data {
2993 conn_id,
2994 channel_id,
2995 payload,
2996 })
2997 .await;
2998 debug!(
2999 "drain task: sent DriverMessage::Data for channel {}",
3000 channel_id
3001 );
3002 }
3003 None => {
3004 debug!("drain task: channel {} closed", channel_id);
3005 // Channel closed, send Close and exit
3006 let _ = task_tx
3007 .send(DriverMessage::Close {
3008 conn_id,
3009 channel_id,
3010 })
3011 .await;
3012 debug!(
3013 "drain task: sent DriverMessage::Close for channel {}",
3014 channel_id
3015 );
3016 break;
3017 }
3018 }
3019 }
3020 });
3021 }
3022
3023 // Just await the response - drain tasks run independently
3024 let result = response_rx
3025 .await
3026 .map_err(|_| TransportError::DriverGone)?
3027 .map_err(|_| TransportError::ConnectionClosed);
3028
3029 // Mark request as complete
3030 if let Some(diag) = &self.shared.diagnostic_state {
3031 diag.complete_request(request_id);
3032 }
3033
3034 result
3035 }
3036 }
3037
3038 /// Walk args and bind any Rx<T> or Tx<T> streams.
3039 /// Collects (channel_id, receiver) pairs for Rx streams that need draining.
3040 fn bind_streams<T: Facet<'static>>(
3041 &self,
3042 args: &mut T,
3043 drains: &mut Vec<(ChannelId, Receiver<Vec<u8>>)>,
3044 ) {
3045 let poke = facet::Poke::new(args);
3046 self.bind_streams_recursive(poke, drains);
3047 }
3048
    /// Recursively walk a Poke value looking for Rx/Tx streams to bind.
    ///
    /// A value whose shape is `roam_session::Rx` or `roam_session::Tx` is bound via
    /// `bind_rx_stream` / `bind_tx_stream` and not descended into. Otherwise the walk
    /// recurses into:
    /// - struct/tuple fields;
    /// - the payload of an `Option`;
    /// - list elements;
    /// - field 0 of other enum variants.
    ///
    /// Scalars and any shape not matched by these arms are skipped.
    ///
    /// NOTE(review): only field 0 of an enum variant is visited, and shapes that fall
    /// through to `_ => {}` are not traversed — confirm channels never appear there.
    // Opts out of the crate-level `deny(unsafe_code)` for the list-element access below.
    #[allow(unsafe_code)]
    fn bind_streams_recursive(
        &self,
        mut poke: facet::Poke<'_, '_>,
        drains: &mut Vec<(ChannelId, Receiver<Vec<u8>>)>,
    ) {
        use facet::Def;

        let shape = poke.shape();

        // Check if this is an Rx or Tx type
        if shape.module_path == Some("roam_session") {
            if shape.type_identifier == "Rx" {
                self.bind_rx_stream(poke, drains);
                return;
            } else if shape.type_identifier == "Tx" {
                self.bind_tx_stream(poke);
                return;
            }
        }

        // Dispatch based on the shape's definition.
        // Arm order matters: the `is_struct` guard is tried before the Option/List arms.
        match shape.def {
            Def::Scalar => {}

            // Recurse into struct/tuple fields
            _ if poke.is_struct() => {
                let mut ps = poke.into_struct().expect("is_struct was true");
                let field_count = ps.field_count();
                for i in 0..field_count {
                    // Fields that cannot be accessed mutably are silently skipped.
                    if let Ok(field_poke) = ps.field(i) {
                        self.bind_streams_recursive(field_poke, drains);
                    }
                }
            }

            // Recurse into Option<T>
            Def::Option(_) => {
                // Option is represented as an enum, use into_enum to access its value
                // (`Ok(None)` from field(0) means there is no payload, e.g. `None`).
                if let Ok(mut pe) = poke.into_enum()
                    && let Ok(Some(inner_poke)) = pe.field(0)
                {
                    self.bind_streams_recursive(inner_poke, drains);
                }
            }

            // Recurse into list elements (e.g., Vec<Tx<T>>)
            Def::List(list_def) => {
                // Length is read through an immutable Peek view; 0 if it isn't a list.
                let len = {
                    let peek = poke.as_peek();
                    peek.into_list().map(|pl| pl.len()).unwrap_or(0)
                };
                // Get mutable access to elements via VTable (no PokeList exists)
                // Lists whose vtable lacks `get_mut` are not traversed.
                if let Some(get_mut_fn) = list_def.vtable.get_mut {
                    let element_shape = list_def.t;
                    let data_ptr = poke.data_mut();
                    for i in 0..len {
                        // SAFETY: We have exclusive mutable access via poke, index < len, shape is correct
                        let element_ptr = unsafe { (get_mut_fn)(data_ptr, i, element_shape) };
                        if let Some(ptr) = element_ptr {
                            // SAFETY: ptr points to a valid element with the correct shape
                            let element_poke =
                                unsafe { facet::Poke::from_raw_parts(ptr, element_shape) };
                            self.bind_streams_recursive(element_poke, drains);
                        }
                    }
                }
            }

            // Other enum variants
            // Only the first field of the active variant is inspected.
            _ if poke.is_enum() => {
                if let Ok(mut pe) = poke.into_enum()
                    && let Ok(Some(variant_poke)) = pe.field(0)
                {
                    self.bind_streams_recursive(variant_poke, drains);
                }
            }

            _ => {}
        }
    }
3131
3132 /// Bind an Rx<T> stream - caller passes receiver, keeps sender.
3133 /// Collects the receiver for draining (no spawning).
3134 fn bind_rx_stream(
3135 &self,
3136 poke: facet::Poke<'_, '_>,
3137 drains: &mut Vec<(ChannelId, Receiver<Vec<u8>>)>,
3138 ) {
3139 let channel_id = self.alloc_channel_id();
3140 debug!(
3141 channel_id,
3142 "OutgoingBinder::bind_rx_stream: allocated channel_id for Rx"
3143 );
3144
3145 if let Ok(mut ps) = poke.into_struct() {
3146 // Set channel_id field by getting mutable access to the u64
3147 if let Ok(mut channel_id_field) = ps.field_by_name("channel_id")
3148 && let Ok(id_ref) = channel_id_field.get_mut::<ChannelId>()
3149 {
3150 debug!(
3151 old_id = *id_ref,
3152 new_id = channel_id,
3153 "OutgoingBinder::bind_rx_stream: overwriting channel_id"
3154 );
3155 *id_ref = channel_id;
3156 }
3157
3158 // Take the receiver from ReceiverSlot - collect for draining later
3159 if let Ok(mut receiver_field) = ps.field_by_name("receiver")
3160 && let Ok(slot) = receiver_field.get_mut::<ReceiverSlot>()
3161 && let Some(rx) = slot.take()
3162 {
3163 debug!(
3164 channel_id,
3165 "OutgoingBinder::bind_rx_stream: took receiver, adding to drains"
3166 );
3167 drains.push((channel_id, rx));
3168 }
3169 }
3170 }
3171
3172 /// Bind a Tx<T> stream - caller passes sender, keeps receiver.
3173 /// We take the sender and register for incoming Data routing.
3174 fn bind_tx_stream(&self, poke: facet::Poke<'_, '_>) {
3175 let channel_id = self.alloc_channel_id();
3176 debug!(
3177 channel_id,
3178 "OutgoingBinder::bind_tx_stream: allocated channel_id for Tx"
3179 );
3180
3181 if let Ok(mut ps) = poke.into_struct() {
3182 // Set channel_id field by getting mutable access to the u64
3183 if let Ok(mut channel_id_field) = ps.field_by_name("channel_id")
3184 && let Ok(id_ref) = channel_id_field.get_mut::<ChannelId>()
3185 {
3186 debug!(
3187 old_id = *id_ref,
3188 new_id = channel_id,
3189 "OutgoingBinder::bind_tx_stream: overwriting channel_id"
3190 );
3191 *id_ref = channel_id;
3192 }
3193
3194 // Take the sender from SenderSlot
3195 if let Ok(mut sender_field) = ps.field_by_name("sender")
3196 && let Ok(slot) = sender_field.get_mut::<SenderSlot>()
3197 && let Some(tx) = slot.take()
3198 {
3199 debug!(
3200 channel_id,
3201 "OutgoingBinder::bind_tx_stream: took sender, registering for incoming"
3202 );
3203 // Register for incoming Data routing
3204 self.register_incoming(channel_id, tx);
3205 }
3206 }
3207 }
3208
3209 /// Make a raw RPC call with pre-serialized payload.
3210 ///
3211 /// Returns the raw response payload bytes.
3212 /// Note: For streaming calls, use `call()` which handles channel binding.
3213 pub async fn call_raw(
3214 &self,
3215 method_id: u64,
3216 payload: Vec<u8>,
3217 ) -> Result<Vec<u8>, TransportError> {
3218 self.call_raw_full(method_id, Vec::new(), Vec::new(), payload, None)
3219 .await
3220 .map(|r| r.payload)
3221 }
3222
3223 /// Make a raw RPC call with pre-serialized payload and channel IDs.
3224 ///
3225 /// Used internally by `call()` after binding streams.
3226 /// Returns ResponseData so caller can handle response channels.
3227 async fn call_raw_with_channels(
3228 &self,
3229 method_id: u64,
3230 channels: Vec<u64>,
3231 payload: Vec<u8>,
3232 args_debug: Option<String>,
3233 ) -> Result<ResponseData, TransportError> {
3234 self.call_raw_full(method_id, Vec::new(), channels, payload, args_debug)
3235 .await
3236 }
3237
3238 async fn call_raw_with_channels_and_metadata(
3239 &self,
3240 method_id: u64,
3241 channels: Vec<u64>,
3242 payload: Vec<u8>,
3243 args_debug: Option<String>,
3244 metadata: roam_wire::Metadata,
3245 ) -> Result<ResponseData, TransportError> {
3246 self.call_raw_full(method_id, metadata, channels, payload, args_debug)
3247 .await
3248 }
3249
3250 /// Make a raw RPC call with pre-serialized payload and metadata.
3251 ///
3252 /// Returns the raw response payload bytes.
3253 pub async fn call_raw_with_metadata(
3254 &self,
3255 method_id: u64,
3256 payload: Vec<u8>,
3257 metadata: Vec<(String, roam_wire::MetadataValue)>,
3258 ) -> Result<Vec<u8>, TransportError> {
3259 self.call_raw_full(method_id, metadata, Vec::new(), payload, None)
3260 .await
3261 .map(|r| r.payload)
3262 }
3263
3264 /// Make a raw RPC call with all options.
3265 ///
3266 /// Returns ResponseData containing the payload and any response channel IDs.
3267 async fn call_raw_full(
3268 &self,
3269 method_id: u64,
3270 metadata: Vec<(String, roam_wire::MetadataValue)>,
3271 channels: Vec<u64>,
3272 payload: Vec<u8>,
3273 args_debug: Option<String>,
3274 ) -> Result<ResponseData, TransportError> {
3275 let request_id = self.shared.request_ids.next();
3276 let (response_tx, response_rx) = oneshot();
3277
3278 // Track outgoing request for diagnostics
3279 if let Some(diag) = &self.shared.diagnostic_state {
3280 let args = args_debug.map(|s| {
3281 let mut map = std::collections::HashMap::new();
3282 map.insert("args".to_string(), s);
3283 map
3284 });
3285 diag.record_outgoing_request(request_id, method_id, args);
3286 // Associate channels with this request
3287 diag.associate_channels_with_request(&channels, request_id);
3288 }
3289
3290 let msg = DriverMessage::Call {
3291 conn_id: self.shared.conn_id,
3292 request_id,
3293 method_id,
3294 metadata,
3295 channels,
3296 payload,
3297 response_tx,
3298 };
3299
3300 self.shared
3301 .driver_tx
3302 .send(msg)
3303 .await
3304 .map_err(|_| TransportError::DriverGone)?;
3305
3306 let result = response_rx
3307 .await
3308 .map_err(|_| TransportError::DriverGone)?
3309 .map_err(|_| TransportError::ConnectionClosed);
3310
3311 // Mark request as complete
3312 if let Some(diag) = &self.shared.diagnostic_state {
3313 diag.complete_request(request_id);
3314 }
3315
3316 result
3317 }
3318
3319 /// Open a new virtual connection on the link.
3320 ///
3321 /// Sends a `Connect` message to the remote peer and waits for an
3322 /// `Accept` or `Reject` response. Returns a new `ConnectionHandle`
3323 /// for the virtual connection if accepted.
3324 ///
3325 /// r[impl core.conn.open]
3326 ///
3327 /// # Arguments
3328 ///
3329 /// * `metadata` - Optional metadata to send with the Connect request
3330 /// (e.g., authentication tokens, routing hints).
3331 /// * `dispatcher` - Optional dispatcher for handling incoming requests on the
3332 /// virtual connection. If None, the connection can only make calls, not receive them.
3333 ///
3334 /// # Example
3335 ///
3336 /// ```ignore
3337 /// // Open a new virtual connection that can receive calls
3338 /// let dispatcher = Box::new(MyDispatcher::new());
3339 /// let virtual_conn = handle.connect(vec![], Some(dispatcher)).await?;
3340 ///
3341 /// // Use the new connection for calls
3342 /// let response = virtual_conn.call_raw(method_id, payload).await?;
3343 /// ```
3344 pub async fn connect(
3345 &self,
3346 metadata: roam_wire::Metadata,
3347 dispatcher: Option<Box<dyn ServiceDispatcher>>,
3348 ) -> Result<ConnectionHandle, crate::ConnectError> {
3349 let request_id = self.shared.request_ids.next();
3350 let (response_tx, response_rx) = oneshot();
3351
3352 let msg = DriverMessage::Connect {
3353 request_id,
3354 metadata,
3355 response_tx,
3356 dispatcher,
3357 };
3358
3359 self.shared.driver_tx.send(msg).await.map_err(|_| {
3360 crate::ConnectError::ConnectFailed(std::io::Error::other("driver gone"))
3361 })?;
3362
3363 response_rx
3364 .await
3365 .map_err(|_| crate::ConnectError::ConnectFailed(std::io::Error::other("driver gone")))?
3366 }
3367
3368 /// Allocate a stream ID for an outgoing stream.
3369 ///
3370 /// Used internally when binding streams during call().
3371 pub fn alloc_channel_id(&self) -> ChannelId {
3372 self.shared.channel_ids.next()
3373 }
3374
3375 /// Allocate a unique request ID for an outgoing call.
3376 ///
3377 /// Used when manually constructing DriverMessage::Call.
3378 pub fn alloc_request_id(&self) -> u64 {
3379 self.shared.request_ids.next()
3380 }
3381
3382 /// Register an incoming stream (we receive data from peer).
3383 ///
3384 /// Used when schema has `Tx<T>` (callee sends to caller) - we receive that data.
3385 pub fn register_incoming(&self, channel_id: ChannelId, tx: Sender<Vec<u8>>) {
3386 // Track channel for diagnostics (request_id not available here)
3387 if let Some(diag) = &self.shared.diagnostic_state {
3388 diag.record_channel_open(channel_id, crate::diagnostic::ChannelDirection::Rx, None);
3389 }
3390 self.shared
3391 .channel_registry
3392 .lock()
3393 .unwrap()
3394 .register_incoming(channel_id, tx);
3395 }
3396
3397 /// Register credit tracking for an outgoing stream.
3398 ///
3399 /// The actual receiver is owned by the driver, not the registry.
3400 pub fn register_outgoing_credit(&self, channel_id: ChannelId) {
3401 // Track channel for diagnostics (request_id not available here)
3402 if let Some(diag) = &self.shared.diagnostic_state {
3403 diag.record_channel_open(channel_id, crate::diagnostic::ChannelDirection::Tx, None);
3404 }
3405 self.shared
3406 .channel_registry
3407 .lock()
3408 .unwrap()
3409 .register_outgoing_credit(channel_id);
3410 }
3411
3412 /// Route incoming stream data to the appropriate Rx.
3413 pub async fn route_data(
3414 &self,
3415 channel_id: ChannelId,
3416 payload: Vec<u8>,
3417 ) -> Result<(), ChannelError> {
3418 // Get the sender while holding the lock, then release before await
3419 let (tx, payload) = self
3420 .shared
3421 .channel_registry
3422 .lock()
3423 .unwrap()
3424 .prepare_route_data(channel_id, payload)?;
3425 // Send without holding the lock
3426 let _ = tx.send(payload).await;
3427 Ok(())
3428 }
3429
3430 /// Close an incoming stream.
3431 pub fn close_channel(&self, channel_id: ChannelId) {
3432 // Track channel close for diagnostics
3433 if let Some(diag) = &self.shared.diagnostic_state {
3434 diag.record_channel_close(channel_id);
3435 }
3436 self.shared
3437 .channel_registry
3438 .lock()
3439 .unwrap()
3440 .close(channel_id);
3441 }
3442
3443 /// Reset a stream.
3444 pub fn reset_channel(&self, channel_id: ChannelId) {
3445 // Track channel close for diagnostics
3446 if let Some(diag) = &self.shared.diagnostic_state {
3447 diag.record_channel_close(channel_id);
3448 }
3449 self.shared
3450 .channel_registry
3451 .lock()
3452 .unwrap()
3453 .reset(channel_id);
3454 }
3455
3456 /// Check if a stream exists.
3457 pub fn contains_channel(&self, channel_id: ChannelId) -> bool {
3458 self.shared
3459 .channel_registry
3460 .lock()
3461 .unwrap()
3462 .contains(channel_id)
3463 }
3464
3465 /// Receive credit for an outgoing stream.
3466 pub fn receive_credit(&self, channel_id: ChannelId, bytes: u32) {
3467 self.shared
3468 .channel_registry
3469 .lock()
3470 .unwrap()
3471 .receive_credit(channel_id, bytes);
3472 }
3473
3474 /// Get a clone of the driver message sender.
3475 ///
3476 /// Used for forwarding/proxy scenarios where messages need to be sent
3477 /// on this connection's wire.
3478 pub fn driver_tx(&self) -> Sender<DriverMessage> {
3479 self.shared.channel_registry.lock().unwrap().driver_tx()
3480 }
3481
3482 /// Bind receivers for `Rx<T>` streams in a deserialized response.
3483 ///
3484 /// After deserializing a response, any `Rx<T>` values are "hollow" - they have
3485 /// channel IDs but no actual receiver. This method walks the response using
3486 /// reflection and binds receivers for each `Rx<T>` so data can be received.
3487 ///
3488 /// # How it works
3489 ///
3490 /// For each `Rx<T>` found in the response:
3491 /// 1. Read the channel_id that was set during deserialization
3492 /// 2. Create a new channel (tx, rx)
3493 /// 3. Set the receiver slot on the Rx
3494 /// 4. Register the sender with the channel registry for incoming data routing
3495 ///
3496 /// This mirrors server-side `ChannelRegistry::bind_streams` but for responses.
3497 ///
3498 /// IMPORTANT: The `channels` parameter contains the authoritative channel IDs
3499 /// from the Response framing. For forwarded connections (via ForwardingDispatcher),
3500 /// these IDs may differ from the IDs serialized in the payload. We patch them first.
3501 pub fn bind_response_streams<T: Facet<'static>>(&self, response: &mut T, channels: &[u64]) {
3502 // Patch channel IDs from Response.channels into the deserialized response.
3503 // This is critical for ForwardingDispatcher where the payload contains upstream
3504 // channel IDs but channels[] contains the remapped downstream IDs.
3505 patch_channel_ids(response, channels);
3506
3507 let poke = facet::Poke::new(response);
3508 self.bind_response_streams_recursive(poke);
3509 }
3510
    /// Recursively walk a Poke value looking for Rx streams to bind in responses.
    ///
    /// A value whose shape is `roam_session::Rx` is bound via
    /// `bind_rx_response_stream` and not descended into. Otherwise the walk recurses into:
    /// - struct/tuple fields;
    /// - the payload of an `Option`;
    /// - list elements;
    /// - field 0 of other enum variants.
    ///
    /// Scalars and any shape not matched by these arms are skipped. `Tx` values in
    /// responses are not bound here.
    ///
    /// NOTE(review): only field 0 of an enum variant is visited, and shapes that fall
    /// through to `_ => {}` are not traversed — confirm channels never appear there.
    // Opts out of the crate-level `deny(unsafe_code)` for the list-element access below.
    #[allow(unsafe_code)]
    fn bind_response_streams_recursive(&self, mut poke: facet::Poke<'_, '_>) {
        use facet::Def;

        let shape = poke.shape();

        // Check if this is an Rx type - only Rx needs binding in responses
        // (Tx in responses would be outgoing, but that's uncommon for return types)
        if shape.module_path == Some("roam_session") && shape.type_identifier == "Rx" {
            self.bind_rx_response_stream(poke);
            return;
        }

        // Dispatch based on the shape's definition.
        // Arm order matters: the `is_struct` guard is tried before the Option/List arms.
        match shape.def {
            Def::Scalar => {}

            // Recurse into struct/tuple fields
            _ if poke.is_struct() => {
                let mut ps = poke.into_struct().expect("is_struct was true");
                let field_count = ps.field_count();
                for i in 0..field_count {
                    // Fields that cannot be accessed mutably are silently skipped.
                    if let Ok(field_poke) = ps.field(i) {
                        self.bind_response_streams_recursive(field_poke);
                    }
                }
            }

            // Recurse into Option<T>
            Def::Option(_) => {
                // Option is represented as an enum, use into_enum to access its value
                // (`Ok(None)` from field(0) means there is no payload, e.g. `None`).
                if let Ok(mut pe) = poke.into_enum()
                    && let Ok(Some(inner_poke)) = pe.field(0)
                {
                    self.bind_response_streams_recursive(inner_poke);
                }
            }

            // Recurse into list elements (e.g., Vec<Rx<T>>)
            Def::List(list_def) => {
                // Length is read through an immutable Peek view; 0 if it isn't a list.
                let len = {
                    let peek = poke.as_peek();
                    peek.into_list().map(|pl| pl.len()).unwrap_or(0)
                };
                // Get mutable access to elements via VTable (no PokeList exists)
                // Lists whose vtable lacks `get_mut` are not traversed.
                if let Some(get_mut_fn) = list_def.vtable.get_mut {
                    let element_shape = list_def.t;
                    let data_ptr = poke.data_mut();
                    for i in 0..len {
                        // SAFETY: We have exclusive mutable access via poke, index < len, shape is correct
                        let element_ptr = unsafe { (get_mut_fn)(data_ptr, i, element_shape) };
                        if let Some(ptr) = element_ptr {
                            // SAFETY: ptr points to a valid element with the correct shape
                            let element_poke =
                                unsafe { facet::Poke::from_raw_parts(ptr, element_shape) };
                            self.bind_response_streams_recursive(element_poke);
                        }
                    }
                }
            }

            // Other enum variants
            // Only the first field of the active variant is inspected.
            _ if poke.is_enum() => {
                if let Ok(mut pe) = poke.into_enum()
                    && let Ok(Some(variant_poke)) = pe.field(0)
                {
                    self.bind_response_streams_recursive(variant_poke);
                }
            }

            _ => {}
        }
    }
3585
3586 /// Bind a single Rx<T> stream from a response.
3587 ///
3588 /// Creates a channel, sets the receiver slot, and registers for incoming data.
3589 fn bind_rx_response_stream(&self, poke: facet::Poke<'_, '_>) {
3590 if let Ok(mut ps) = poke.into_struct() {
3591 // Get the channel_id that was deserialized from the wire
3592 let channel_id = if let Ok(channel_id_field) = ps.field_by_name("channel_id")
3593 && let Ok(id_ref) = channel_id_field.get::<ChannelId>()
3594 {
3595 *id_ref
3596 } else {
3597 return;
3598 };
3599
3600 // Create channel and set receiver slot
3601 let (tx, rx) = crate::runtime::channel(RX_STREAM_BUFFER_SIZE);
3602
3603 if let Ok(mut receiver_field) = ps.field_by_name("receiver")
3604 && let Ok(slot) = receiver_field.get_mut::<ReceiverSlot>()
3605 {
3606 slot.set(rx);
3607 }
3608
3609 // Register for incoming data routing
3610 self.register_incoming(channel_id, tx);
3611 }
3612 }
3613}
3614
3615#[derive(Debug)]
3616pub enum ClientError<TransportError> {
3617 Transport(TransportError),
3618 Encode(facet_postcard::SerializeError),
3619 Decode(facet_postcard::DeserializeError<facet_postcard::PostcardError>),
3620}
3621
3622impl<TransportError> From<TransportError> for ClientError<TransportError> {
3623 fn from(value: TransportError) -> Self {
3624 Self::Transport(value)
3625 }
3626}
3627
/// Error produced while dispatching (presumably handling an incoming request — confirm
/// against callers).
///
/// Its only variant wraps a postcard serialization failure.
#[derive(Debug)]
pub enum DispatchError {
    /// Postcard serialization (`facet_postcard`) failed.
    Encode(facet_postcard::SerializeError),
}
3632
3633// ============================================================================
3634// Tunnel Adapters for AsyncRead/AsyncWrite Streams (native only)
3635// ============================================================================
3636
3637#[cfg(not(target_arch = "wasm32"))]
3638use std::io;
3639#[cfg(not(target_arch = "wasm32"))]
3640use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
3641#[cfg(not(target_arch = "wasm32"))]
3642use tokio::task::JoinHandle;
3643
/// Default chunk size for tunnel pumps: 32 KiB (32768 bytes).
///
/// A compromise between throughput and memory/slot usage:
/// - larger chunks raise throughput but use more memory per read;
/// - smaller chunks lower latency but add syscall overhead.
#[cfg(not(target_arch = "wasm32"))]
pub const DEFAULT_TUNNEL_CHUNK_SIZE: usize = 1 << 15;
3651
/// A bidirectional byte tunnel over roam channels.
///
/// From the perspective of whoever holds the tunnel:
/// - `tx`: Send bytes TO the remote end
/// - `rx`: Receive bytes FROM the remote end
///
/// Tunnels are typically used to bridge async byte streams (TCP, Unix sockets, etc.)
/// with roam's streaming channels. One side creates a tunnel pair with [`tunnel_pair()`],
/// passes one half to the remote via an RPC call, and uses the other half locally.
///
/// Because `Tunnel` derives `Facet`, it can appear in call arguments. Its `tx`/`rx`
/// fields are then found and bound like any other `Tx`/`Rx` in the args.
///
/// # Example
///
/// ```ignore
/// // Host side: create tunnel and pump to/from a socket
/// let (local, remote) = roam_session::tunnel_pair();
/// let (read_handle, write_handle) = roam_session::tunnel_stream(socket, local, 32 * 1024);
///
/// // Pass `remote` to cell via RPC
/// cell.handle_connection(remote).await?;
/// ```
#[derive(Facet)]
pub struct Tunnel {
    /// Channel for sending bytes to the remote end.
    pub tx: Tx<Vec<u8>>,
    /// Channel for receiving bytes from the remote end.
    pub rx: Rx<Vec<u8>>,
}
3679
3680/// Create a pair of connected tunnels.
3681///
3682/// Returns `(local, remote)` where:
3683/// - Data sent on `local.tx` arrives at `remote.rx`
3684/// - Data sent on `remote.tx` arrives at `local.rx`
3685///
3686/// This is useful for creating a bidirectional channel that can be split
3687/// across an RPC boundary. One side keeps `local` and passes `remote` to
3688/// the other side via an RPC call.
3689///
3690/// # Example
3691///
3692/// ```ignore
3693/// let (local, remote) = tunnel_pair();
3694///
3695/// // Spawn tasks to pump data from local stream
3696/// tunnel_stream(tcp_stream, local, DEFAULT_TUNNEL_CHUNK_SIZE);
3697///
3698/// // Send remote to the other side via RPC
3699/// service.handle_tunnel(remote).await?;
3700/// ```
3701pub fn tunnel_pair() -> (Tunnel, Tunnel) {
3702 let (tx1, rx1) = channel::<Vec<u8>>();
3703 let (tx2, rx2) = channel::<Vec<u8>>();
3704 (Tunnel { tx: tx1, rx: rx2 }, Tunnel { tx: tx2, rx: rx1 })
3705}
3706
3707/// Pump bytes from an `AsyncRead` into a `Tx<Vec<u8>>`.
3708///
3709/// Reads chunks up to `chunk_size` bytes and sends them on the channel.
3710/// Returns when the reader reaches EOF or the channel closes.
3711///
3712/// # Arguments
3713///
3714/// * `reader` - Any type implementing `AsyncRead + Unpin`
3715/// * `tx` - The transmit channel to send bytes to
3716/// * `chunk_size` - Maximum bytes to read per chunk
3717///
3718/// # Returns
3719///
3720/// * `Ok(())` - Reader reached EOF, channel closed gracefully
3721/// * `Err(io::Error)` - Read error occurred
3722///
3723/// # Example
3724///
3725/// ```ignore
3726/// let (tx, rx) = roam::channel::<Vec<u8>>();
3727/// let result = pump_read_to_tx(reader, tx, 32 * 1024).await;
3728/// ```
3729#[cfg(not(target_arch = "wasm32"))]
3730pub async fn pump_read_to_tx<R: AsyncRead + Unpin>(
3731 mut reader: R,
3732 tx: Tx<Vec<u8>>,
3733 chunk_size: usize,
3734) -> io::Result<()> {
3735 let mut buf = vec![0u8; chunk_size];
3736 loop {
3737 let n = reader.read(&mut buf).await?;
3738 if n == 0 {
3739 // EOF - drop tx to close the channel
3740 break;
3741 }
3742 // Send the bytes we read
3743 if tx.send(&buf[..n].to_vec()).await.is_err() {
3744 // Channel closed by receiver - treat as graceful shutdown
3745 break;
3746 }
3747 }
3748 Ok(())
3749}
3750
3751/// Pump bytes from an `Rx<Vec<u8>>` into an `AsyncWrite`.
3752///
3753/// Receives chunks and writes them to the writer.
3754/// Returns when the channel closes or a write error occurs.
3755///
3756/// # Arguments
3757///
3758/// * `rx` - The receive channel to get bytes from
3759/// * `writer` - Any type implementing `AsyncWrite + Unpin`
3760///
3761/// # Returns
3762///
3763/// * `Ok(())` - Channel closed gracefully
3764/// * `Err(io::Error)` - Write error or deserialization error occurred
3765///
3766/// # Example
3767///
3768/// ```ignore
3769/// let (tx, rx) = roam::channel::<Vec<u8>>();
3770/// let result = pump_rx_to_write(rx, writer).await;
3771/// ```
3772#[cfg(not(target_arch = "wasm32"))]
3773pub async fn pump_rx_to_write<W: AsyncWrite + Unpin>(
3774 mut rx: Rx<Vec<u8>>,
3775 mut writer: W,
3776) -> io::Result<()> {
3777 loop {
3778 match rx.recv().await {
3779 Ok(Some(data)) => {
3780 writer.write_all(&data).await?;
3781 }
3782 Ok(None) => {
3783 // Channel closed - flush and exit
3784 writer.flush().await?;
3785 break;
3786 }
3787 Err(e) => {
3788 return Err(io::Error::new(
3789 io::ErrorKind::InvalidData,
3790 format!("tunnel receive error: {e}"),
3791 ));
3792 }
3793 }
3794 }
3795 Ok(())
3796}
3797
3798/// Tunnel a bidirectional stream through a roam Tunnel.
3799///
3800/// Spawns two tasks to pump data in both directions:
3801/// - One task reads from `stream` and sends to `tunnel.tx`
3802/// - One task receives from `tunnel.rx` and writes to `stream`
3803///
3804/// Returns handles to join on completion. Both tasks run until their
3805/// respective direction completes (EOF/close) or an error occurs.
3806///
3807/// # Arguments
3808///
3809/// * `stream` - Any type implementing `AsyncRead + AsyncWrite + Unpin + Send + 'static`
3810/// * `tunnel` - The tunnel to pump data through
3811/// * `chunk_size` - Maximum bytes to read per chunk (see [`DEFAULT_TUNNEL_CHUNK_SIZE`])
3812///
3813/// # Returns
3814///
3815/// A tuple of `(read_handle, write_handle)`:
3816/// - `read_handle` - Completes when the stream reaches EOF or tx closes
3817/// - `write_handle` - Completes when rx closes or stream write fails
3818///
3819/// # Example
3820///
3821/// ```ignore
3822/// let (local, remote) = tunnel_pair();
3823/// let (read_handle, write_handle) = tunnel_stream(tcp_stream, local, DEFAULT_TUNNEL_CHUNK_SIZE);
3824///
3825/// // Wait for both directions to complete
3826/// let _ = read_handle.await;
3827/// let _ = write_handle.await;
3828/// ```
3829#[cfg(not(target_arch = "wasm32"))]
3830pub fn tunnel_stream<S>(
3831 stream: S,
3832 tunnel: Tunnel,
3833 chunk_size: usize,
3834) -> (JoinHandle<io::Result<()>>, JoinHandle<io::Result<()>>)
3835where
3836 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
3837{
3838 let (reader, writer) = tokio::io::split(stream);
3839 let Tunnel { tx, rx } = tunnel;
3840
3841 let read_handle = tokio::spawn(async move { pump_read_to_tx(reader, tx, chunk_size).await });
3842
3843 let write_handle = tokio::spawn(async move { pump_rx_to_write(rx, writer).await });
3844
3845 (read_handle, write_handle)
3846}
3847
#[cfg(test)]
mod tests {
    use super::*;

    // r[verify channeling.id.parity]
    #[test]
    fn channel_id_allocator_initiator_uses_odd_ids() {
        let alloc = ChannelIdAllocator::new(Role::Initiator);
        assert_eq!(alloc.next(), 1);
        assert_eq!(alloc.next(), 3);
        assert_eq!(alloc.next(), 5);
        assert_eq!(alloc.next(), 7);
    }

    // r[verify channeling.id.parity]
    #[test]
    fn channel_id_allocator_acceptor_uses_even_ids() {
        let alloc = ChannelIdAllocator::new(Role::Acceptor);
        assert_eq!(alloc.next(), 2);
        assert_eq!(alloc.next(), 4);
        assert_eq!(alloc.next(), 6);
        assert_eq!(alloc.next(), 8);
    }

    // r[verify channeling.holder-semantics]
    #[tokio::test]
    async fn tx_serializes_and_rx_deserializes() {
        // Create a channel pair using roam::channel
        let (tx, mut rx) = channel::<i32>();

        // Simulate what ConnectionHandle::call would do: take the receiver
        let mut taken_rx = rx.receiver.take().expect("receiver should be present");

        // Now tx can send and we can receive on the taken receiver
        tx.send(&100).await.unwrap();
        tx.send(&200).await.unwrap();

        // Receive raw bytes and deserialize
        let bytes1 = taken_rx.recv().await.unwrap();
        let val1: i32 = facet_postcard::from_slice(&bytes1).unwrap();
        assert_eq!(val1, 100);

        let bytes2 = taken_rx.recv().await.unwrap();
        let val2: i32 = facet_postcard::from_slice(&bytes2).unwrap();
        assert_eq!(val2, 200);
    }

    /// Create a test registry with a dummy task channel.
    fn test_registry() -> ChannelRegistry {
        let (task_tx, _task_rx) = crate::runtime::channel(10);
        ChannelRegistry::new(task_tx)
    }

    // r[verify channeling.data-after-close]
    #[tokio::test]
    async fn data_after_close_is_rejected() {
        let mut registry = test_registry();
        let (tx, _rx) = crate::runtime::channel(10);
        registry.register_incoming(42, tx);

        // Close the stream
        registry.close(42);

        // Data after close should fail
        let result = registry.route_data(42, b"data".to_vec()).await;
        assert_eq!(result, Err(ChannelError::DataAfterClose));
    }

    // r[verify channeling.data]
    // r[verify channeling.unknown]
    #[tokio::test]
    async fn channel_registry_routes_data_to_registered_stream() {
        let mut registry = test_registry();

        // Register a stream
        let (tx, mut rx) = crate::runtime::channel(10);
        registry.register_incoming(42, tx);

        // Data to registered stream should succeed
        assert!(registry.route_data(42, b"hello".to_vec()).await.is_ok());

        // Should receive the data
        assert_eq!(rx.recv().await, Some(b"hello".to_vec()));

        // Data to unregistered stream should fail
        assert!(registry.route_data(999, b"nope".to_vec()).await.is_err());
    }

    // r[verify channeling.close]
    #[tokio::test]
    async fn channel_registry_close_terminates_stream() {
        let mut registry = test_registry();
        let (tx, mut rx) = crate::runtime::channel(10);
        registry.register_incoming(42, tx);

        // Send some data
        registry.route_data(42, b"data1".to_vec()).await.unwrap();

        // Close the stream
        registry.close(42);

        // Should still receive buffered data
        assert_eq!(rx.recv().await, Some(b"data1".to_vec()));

        // Then channel closes (sender dropped)
        assert_eq!(rx.recv().await, None);

        // Stream no longer registered
        assert!(!registry.contains(42));
    }

    #[test]
    fn tx_rx_shape_metadata() {
        use facet::Facet;

        let tx_shape = <Tx<i32> as Facet>::SHAPE;
        let rx_shape = <Rx<i32> as Facet>::SHAPE;

        // Verify module_path and type_identifier are set correctly
        assert_eq!(tx_shape.module_path, Some("roam_session"));
        assert_eq!(tx_shape.type_identifier, "Tx");
        assert_eq!(rx_shape.module_path, Some("roam_session"));
        assert_eq!(rx_shape.type_identifier, "Rx");

        // Verify type_params are populated
        assert_eq!(tx_shape.type_params.len(), 1);
        assert_eq!(rx_shape.type_params.len(), 1);
    }

    // ========================================================================
    // Tunnel Tests
    // ========================================================================

    #[tokio::test]
    async fn tunnel_pair_connects_bidirectionally() {
        let (local, remote) = tunnel_pair();

        // Send from local to remote
        local.tx.send(&b"hello".to_vec()).await.unwrap();

        // Receive on remote
        let mut remote_rx = remote.rx;
        let received = remote_rx.recv().await.unwrap().unwrap();
        assert_eq!(received, b"hello".to_vec());

        // Send from remote to local
        remote.tx.send(&b"world".to_vec()).await.unwrap();

        // Receive on local
        let mut local_rx = local.rx;
        let received = local_rx.recv().await.unwrap().unwrap();
        assert_eq!(received, b"world".to_vec());
    }

    #[tokio::test]
    async fn pump_read_to_tx_sends_chunks() {
        use std::io::Cursor;

        let data = b"hello world this is a test message";
        let reader = Cursor::new(data.to_vec());
        let (tx, mut rx) = channel::<Vec<u8>>();

        // Pump with small chunk size to force multiple chunks
        let handle = tokio::spawn(async move { pump_read_to_tx(reader, tx, 10).await });

        // Collect all received chunks
        let mut received = Vec::new();
        while let Ok(Some(chunk)) = rx.recv().await {
            received.extend(chunk);
        }

        // Verify we got all the data
        assert_eq!(received, data.to_vec());

        // Pump should complete successfully
        handle.await.unwrap().unwrap();
    }

    #[tokio::test]
    async fn pump_rx_to_write_writes_chunks() {
        use std::io::Cursor;

        let (tx, rx) = channel::<Vec<u8>>();
        let writer = Cursor::new(Vec::new());

        // Spawn pump task
        let handle = tokio::spawn(async move {
            let mut writer = writer;
            pump_rx_to_write(rx, &mut writer).await?;
            Ok::<_, io::Error>(writer)
        });

        // Send some chunks
        tx.send(&b"hello ".to_vec()).await.unwrap();
        tx.send(&b"world".to_vec()).await.unwrap();
        drop(tx); // Close the channel

        // Wait for pump to complete and get the writer
        let writer = handle.await.unwrap().unwrap();
        assert_eq!(writer.into_inner(), b"hello world".to_vec());
    }

    #[tokio::test]
    async fn tunnel_stream_bidirectional() {
        // Create a duplex stream (simulates a socket)
        let (client, mut server) = tokio::io::duplex(1024);

        // Create tunnel pair and pump the client side through `local`
        let (local, remote) = tunnel_pair();
        let (client_read_handle, client_write_handle) =
            tunnel_stream(client, local, DEFAULT_TUNNEL_CHUNK_SIZE);

        // Keep both remote halves in this task so each direction can be checked
        let Tunnel {
            tx: remote_tx,
            rx: mut remote_rx,
        } = remote;

        // Tunnel -> socket: bytes sent on the remote tunnel appear on the server side
        let outbound = b"from tunnel".to_vec();
        remote_tx.send(&outbound).await.unwrap();
        let mut buf = vec![0u8; outbound.len()];
        tokio::io::AsyncReadExt::read_exact(&mut server, &mut buf)
            .await
            .unwrap();
        assert_eq!(buf, outbound);

        // Socket -> tunnel: bytes written on the server side arrive on the remote tunnel
        let inbound = b"to tunnel".to_vec();
        tokio::io::AsyncWriteExt::write_all(&mut server, &inbound)
            .await
            .unwrap();
        let mut received = Vec::new();
        while received.len() < inbound.len() {
            let chunk = remote_rx
                .recv()
                .await
                .unwrap()
                .expect("tunnel closed before all bytes arrived");
            received.extend(chunk);
        }
        assert_eq!(received, inbound);

        // Close both directions: EOF for the read pump, channel close for the write pump
        drop(server);
        drop(remote_tx);

        client_read_handle.await.unwrap().unwrap();
        client_write_handle.await.unwrap().unwrap();
    }

    #[tokio::test]
    async fn tunnel_handles_empty_data() {
        let (tx, mut rx) = channel::<Vec<u8>>();

        // Sending empty vec should work
        tx.send(&Vec::new()).await.unwrap();

        let received = rx.recv().await.unwrap().unwrap();
        assert!(received.is_empty());
    }

    #[tokio::test]
    async fn tunnel_close_propagates() {
        let (local, remote) = tunnel_pair();

        // Drop the sender
        drop(local.tx);

        // Receiver should see channel closed
        let mut rx = remote.rx;
        let result = rx.recv().await;
        assert!(matches!(result, Ok(None)));
    }

    // ========================================================================
    // Channel ID Collection Tests
    // ========================================================================

    // r[verify call.request.channels]
    #[test]
    fn collect_channel_ids_simple_tx() {
        let tx: Tx<i32> = Tx::try_from(42u64).unwrap();
        let ids = collect_channel_ids(&tx);
        assert_eq!(ids, vec![42]);
    }

    // r[verify call.request.channels]
    #[test]
    fn collect_channel_ids_simple_rx() {
        let rx: Rx<i32> = Rx::try_from(99u64).unwrap();
        let ids = collect_channel_ids(&rx);
        assert_eq!(ids, vec![99]);
    }

    // r[verify call.request.channels]
    #[test]
    fn collect_channel_ids_tuple() {
        let rx: Rx<String> = Rx::try_from(10u64).unwrap();
        let tx: Tx<String> = Tx::try_from(20u64).unwrap();
        let args = (rx, tx);
        let ids = collect_channel_ids(&args);
        assert_eq!(ids, vec![10, 20]);
    }

    // r[verify call.request.channels]
    #[test]
    fn collect_channel_ids_nested_in_struct() {
        #[derive(facet::Facet)]
        struct StreamArgs {
            input: Rx<i32>,
            output: Tx<i32>,
            count: u32,
        }

        let args = StreamArgs {
            input: Rx::try_from(100u64).unwrap(),
            output: Tx::try_from(200u64).unwrap(),
            count: 5,
        };
        let ids = collect_channel_ids(&args);
        assert_eq!(ids, vec![100, 200]);
    }

    // r[verify call.request.channels]
    #[test]
    fn collect_channel_ids_option_some() {
        let tx: Tx<i32> = Tx::try_from(55u64).unwrap();
        let args: Option<Tx<i32>> = Some(tx);
        let ids = collect_channel_ids(&args);
        assert_eq!(ids, vec![55]);
    }

    // r[verify call.request.channels]
    #[test]
    fn collect_channel_ids_option_none() {
        let args: Option<Tx<i32>> = None;
        let ids = collect_channel_ids(&args);
        assert!(ids.is_empty());
    }

    // r[verify call.request.channels]
    #[test]
    fn collect_channel_ids_vec() {
        let tx1: Tx<i32> = Tx::try_from(1u64).unwrap();
        let tx2: Tx<i32> = Tx::try_from(2u64).unwrap();
        let tx3: Tx<i32> = Tx::try_from(3u64).unwrap();
        let args: Vec<Tx<i32>> = vec![tx1, tx2, tx3];
        let ids = collect_channel_ids(&args);
        assert_eq!(ids, vec![1, 2, 3]);
    }

    // r[verify call.request.channels]
    #[test]
    fn collect_channel_ids_deeply_nested() {
        #[derive(facet::Facet)]
        struct Outer {
            inner: Inner,
        }

        #[derive(facet::Facet)]
        struct Inner {
            stream: Tx<u8>,
        }

        let args = Outer {
            inner: Inner {
                stream: Tx::try_from(777u64).unwrap(),
            },
        };
        let ids = collect_channel_ids(&args);
        assert_eq!(ids, vec![777]);
    }
}