roam_session/
lib.rs

1#![deny(unsafe_code)]
2
3//! Session/state machine and RPC-level utilities.
4//!
5//! Canonical definitions live in `docs/content/spec/_index.md`,
6//! `docs/content/rust-spec/_index.md`, and `docs/content/shm-spec/_index.md`.
7
8#[macro_use]
9mod macros;
10
11pub mod diagnostic;
12pub mod driver;
13pub mod runtime;
14pub mod transport;
15
16pub use driver::{
17    ConnectError, ConnectionError, Driver, FramedClient, HandshakeConfig, IncomingConnection,
18    IncomingConnections, MessageConnector, Negotiated, NoDispatcher, RetryPolicy, accept_framed,
19    connect_framed, connect_framed_with_policy, initiate_framed,
20};
21pub use transport::MessageTransport;
22
23use std::marker::PhantomData;
24use std::sync::Arc;
25use std::sync::atomic::{AtomicU64, Ordering};
26
27use crate::runtime::{OneshotSender, Receiver, Sender, oneshot};
28use facet::Facet;
29use std::convert::Infallible;
30
31pub use roam_frame::{Frame, MsgDesc, OwnedMessage, Payload};
32
/// Capacity (1024) of the in-process channel created for every `Tx`/`Rx` pair by `channel()`.
const CHANNEL_SIZE: usize = 1 << 10;
/// Buffer size of 1024; presumably used for incoming (Rx) stream buffers — it is not
/// referenced in this section of the file, so confirm at its use site.
const RX_STREAM_BUFFER_SIZE: usize = 1 << 10;
35
36// ============================================================================
37// Streaming types
38// ============================================================================
39
/// Identifier of a channel (stream) within a single connection.
///
/// On the wire a channel handle is encoded as this plain `u64`.
pub type ChannelId = u64;

/// Which side of the connection we are on; this decides the parity of the
/// channel IDs we allocate locally.
///
/// The initiator is the peer that opened the connection (for example, by connecting
/// to a TCP socket or opening an SHM channel). The acceptor is the peer that
/// accepted or received it.
///
/// r[impl channeling.id.parity]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Role {
    /// Opened the connection; allocates odd stream IDs (1, 3, 5, ...).
    Initiator,
    /// Accepted the connection; allocates even stream IDs (2, 4, 6, ...).
    Acceptor,
}

/// Hands out stream IDs that are unique within a connection and carry the
/// parity required by the local [`Role`].
///
/// r[impl channeling.id.uniqueness] - IDs are unique within a connection.
/// r[impl channeling.id.parity] - Initiator uses odd, Acceptor uses even.
pub struct ChannelIdAllocator {
    next: AtomicU64,
}

impl ChannelIdAllocator {
    /// Distance between consecutive IDs; stepping by two keeps the parity fixed.
    const STEP: u64 = 2;

    /// Build an allocator whose first ID is 1 for [`Role::Initiator`] and 2 for
    /// [`Role::Acceptor`].
    pub fn new(role: Role) -> Self {
        let first_id: ChannelId = if role == Role::Initiator { 1 } else { 2 };
        Self {
            next: AtomicU64::new(first_id),
        }
    }

    /// Return the next stream ID and advance the counter.
    ///
    /// Takes `&self` and uses an atomic increment, so it can be called through a
    /// shared reference (e.g. an `Arc<ChannelIdAllocator>`).
    pub fn next(&self) -> ChannelId {
        self.next.fetch_add(Self::STEP, Ordering::Relaxed)
    }
}
82
83// ============================================================================
84// SenderSlot - Wrapper for Option<Sender> that implements Facet
85// ============================================================================
86
87/// A wrapper around `Option<Sender<Vec<u8>>>` that implements Facet.
88///
89/// This allows `Poke::get_mut::<SenderSlot>()` to work, enabling `.take()`
90/// via reflection. Used by `ConnectionHandle::call` to extract senders from
91/// `Tx<T>` arguments and register them with the stream registry.
92#[derive(Facet)]
93#[facet(opaque)]
94pub struct SenderSlot {
95    /// The optional sender. Public within crate for `Tx::send()` access.
96    pub(crate) inner: Option<Sender<Vec<u8>>>,
97}
98
99impl SenderSlot {
100    /// Create a slot containing a sender.
101    pub fn new(tx: Sender<Vec<u8>>) -> Self {
102        Self { inner: Some(tx) }
103    }
104
105    /// Create an empty slot.
106    pub fn empty() -> Self {
107        Self { inner: None }
108    }
109
110    /// Take the sender out of the slot, leaving it empty.
111    pub fn take(&mut self) -> Option<Sender<Vec<u8>>> {
112        self.inner.take()
113    }
114
115    /// Check if the slot contains a sender.
116    pub fn is_some(&self) -> bool {
117        self.inner.is_some()
118    }
119
120    /// Check if the slot is empty.
121    pub fn is_none(&self) -> bool {
122        self.inner.is_none()
123    }
124
125    /// Set the sender in this slot.
126    ///
127    /// Used by `ChannelRegistry::bind_streams` to hydrate a deserialized `Tx<T>`
128    /// with an actual channel sender.
129    pub fn set(&mut self, tx: Sender<Vec<u8>>) {
130        self.inner = Some(tx);
131    }
132}
133
134// ============================================================================
135// DriverTxSlot - Wrapper for Option<Sender<DriverMessage>> that implements Facet
136// ============================================================================
137
138/// A wrapper around `Option<Sender<DriverMessage>>` that implements Facet.
139///
140/// This allows `Poke::get_mut::<DriverTxSlot>()` to work, enabling reflection-based
141/// hydration of `Tx<T>` handles on the server side. Sends Data/Close messages
142/// directly to the connection driver.
143#[derive(Facet)]
144#[facet(opaque)]
145pub struct DriverTxSlot {
146    /// The optional sender. Public within crate for `Tx::send()` access.
147    pub(crate) inner: Option<Sender<DriverMessage>>,
148}
149
150impl DriverTxSlot {
151    /// Create a slot containing a task sender.
152    pub fn new(tx: Sender<DriverMessage>) -> Self {
153        Self { inner: Some(tx) }
154    }
155
156    /// Create an empty slot.
157    pub fn empty() -> Self {
158        Self { inner: None }
159    }
160
161    /// Take the sender out of the slot, leaving it empty.
162    pub fn take(&mut self) -> Option<Sender<DriverMessage>> {
163        self.inner.take()
164    }
165
166    /// Check if the slot contains a sender.
167    pub fn is_some(&self) -> bool {
168        self.inner.is_some()
169    }
170
171    /// Check if the slot is empty.
172    pub fn is_none(&self) -> bool {
173        self.inner.is_none()
174    }
175
176    /// Set the task sender in this slot.
177    ///
178    /// Used by `ChannelRegistry::bind_streams` to hydrate a deserialized `Tx<T>`
179    /// with the connection's task message channel.
180    pub fn set(&mut self, tx: Sender<DriverMessage>) {
181        self.inner = Some(tx);
182    }
183
184    /// Clone the sender if present.
185    pub fn clone_inner(&self) -> Option<Sender<DriverMessage>> {
186        self.inner.clone()
187    }
188}
189
190/// Tx stream handle - caller sends data to callee.
191///
192/// r[impl channeling.caller-pov] - From caller's perspective, Tx means "I send".
193/// r[impl channeling.type] - Serializes as u64 stream ID on wire.
194/// r[impl channeling.holder-semantics] - The holder sends on this stream.
195/// r[impl channeling.channels-outlive-response] - Tx streams may outlive Response.
196/// r[impl channeling.lifecycle.immediate-data] - Can send Data before Response.
197/// r[impl channeling.lifecycle.speculative] - Early Data may be wasted on error.
198///
199/// # Facet Implementation
200///
201/// Uses `#[facet(proxy = u64)]` so that:
202/// - `channel_id` is pokeable (Connection can walk args and set stream IDs)
203/// - Serializes as just a `u64` on the wire
204/// - `T` is exposed as a type parameter for codegen introspection
205///
206/// # Two modes of operation
207///
208/// - **Client side**: `sender` holds a channel to an intermediate drain task.
209///   `ConnectionHandle::call` takes the receiver and drains it to wire.
210/// - **Server side**: `task_tx` holds a direct channel to the connection driver.
211///   `ChannelRegistry::bind_streams` sets this, and `send()` writes `DriverMessage::Data`.
212#[derive(Facet)]
213#[facet(proxy = u64)]
214pub struct Tx<T: 'static> {
215    /// The connection ID this stream belongs to.
216    pub conn_id: roam_wire::ConnectionId,
217    /// The unique stream ID for this stream.
218    /// Public so Connection can poke it when binding streams.
219    pub channel_id: ChannelId,
220    /// Channel sender for outgoing data (client-side mode).
221    /// Used when Tx is created via `roam::channel()`.
222    pub sender: SenderSlot,
223    /// Direct driver message sender (server-side mode).
224    /// Used when Tx is hydrated by `ChannelRegistry::bind_streams`.
225    pub driver_tx: DriverTxSlot,
226    /// Phantom data for the element type.
227    #[facet(opaque)]
228    _marker: PhantomData<T>,
229}
230
231/// Serialization: `&Tx<T>` -> u64 (extracts channel_id)
232///
233/// Uses TryFrom rather than From because facet's proxy mechanism requires TryFrom.
234#[allow(clippy::infallible_try_from)]
235impl<T: 'static> TryFrom<&Tx<T>> for u64 {
236    type Error = Infallible;
237    fn try_from(tx: &Tx<T>) -> Result<Self, Self::Error> {
238        Ok(tx.channel_id)
239    }
240}
241
242/// Deserialization: u64 -> `Tx<T>` (creates a "hollow" Tx)
243///
244/// Both sender slots are empty - the real sender gets set up by Connection
245/// after deserialization when it binds the stream.
246///
247/// Uses TryFrom rather than From because facet's proxy mechanism requires TryFrom.
248#[allow(clippy::infallible_try_from)]
249impl<T: 'static> TryFrom<u64> for Tx<T> {
250    type Error = Infallible;
251    fn try_from(channel_id: u64) -> Result<Self, Self::Error> {
252        // Create a hollow Tx - no actual sender, Connection will bind later
253        // conn_id will be set when binding
254        Ok(Tx {
255            conn_id: roam_wire::ConnectionId::ROOT,
256            channel_id,
257            sender: SenderSlot::empty(),
258            driver_tx: DriverTxSlot::empty(),
259            _marker: PhantomData,
260        })
261    }
262}
263
264impl<T: 'static> Tx<T> {
265    /// Create a new Tx stream with the given ID and sender channel (client-side mode).
266    pub fn new(channel_id: ChannelId, tx: Sender<Vec<u8>>) -> Self {
267        Self {
268            conn_id: roam_wire::ConnectionId::ROOT,
269            channel_id,
270            sender: SenderSlot::new(tx),
271            driver_tx: DriverTxSlot::empty(),
272            _marker: PhantomData,
273        }
274    }
275
276    /// Create an unbound Tx with a sender but channel_id 0.
277    ///
278    /// Used by `roam::channel()` to create a pair before binding.
279    /// Connection will poke the channel_id and conn_id when binding.
280    pub fn unbound(tx: Sender<Vec<u8>>) -> Self {
281        Self {
282            conn_id: roam_wire::ConnectionId::ROOT,
283            channel_id: 0,
284            sender: SenderSlot::new(tx),
285            driver_tx: DriverTxSlot::empty(),
286            _marker: PhantomData,
287        }
288    }
289
290    /// Create a bound Tx with conn_id, channel_id and driver_tx already set.
291    ///
292    /// Used by `roam::channel()` when called during dispatch to create
293    /// response channels that can send Data directly over the wire.
294    pub fn bound(
295        conn_id: roam_wire::ConnectionId,
296        channel_id: ChannelId,
297        tx: Sender<Vec<u8>>,
298        driver_tx: Sender<DriverMessage>,
299    ) -> Self {
300        Self {
301            conn_id,
302            channel_id,
303            sender: SenderSlot::new(tx),
304            driver_tx: DriverTxSlot::new(driver_tx),
305            _marker: PhantomData,
306        }
307    }
308
309    /// Get the stream ID.
310    pub fn channel_id(&self) -> ChannelId {
311        self.channel_id
312    }
313
314    /// Send a value on this stream.
315    ///
316    /// r[impl channeling.data] - Data messages carry serialized values.
317    ///
318    /// Works in two modes:
319    /// - Client-side (or passthrough): sends raw bytes to intermediate channel (drained by connection)
320    /// - Server-side: sends `DriverMessage::Data` directly to connection driver
321    ///
322    /// IMPORTANT: We prefer sender over driver_tx because when a channel created during
323    /// dispatch is passed to a callback, the rx gets a NEW channel_id allocated by the
324    /// caller's bind_streams. The drain task uses that new channel_id, while self.channel_id
325    /// still has the old dispatch-context channel_id. By using sender, data flows through
326    /// the drain task which uses the correct channel_id.
327    pub async fn send(&self, value: &T) -> Result<(), TxError>
328    where
329        T: Facet<'static>,
330    {
331        let bytes = facet_postcard::to_vec(value).map_err(TxError::Serialize)?;
332
333        // Prefer sender - data flows through drain task which has correct channel_id
334        if let Some(tx) = self.sender.inner.as_ref() {
335            tx.send(bytes).await.map_err(|_| TxError::Closed)
336        }
337        // Fallback to direct driver_tx (sender was taken or never set)
338        else if let Some(task_tx) = self.driver_tx.inner.as_ref() {
339            task_tx
340                .send(DriverMessage::Data {
341                    conn_id: self.conn_id,
342                    channel_id: self.channel_id,
343                    payload: bytes,
344                })
345                .await
346                .map_err(|_| TxError::Closed)
347        } else {
348            Err(TxError::Taken)
349        }
350    }
351}
352
353/// When a Tx is dropped, send a Close message.
354///
355/// r[impl channeling.close] - Close terminates the stream.
356///
357/// The Close path depends on how data was sent:
358/// - If sender is present: data went through drain task, drain task sends Close when channel closes
359/// - If only driver_tx is present: data went directly to driver, we send Close via driver_tx
360impl<T: 'static> Drop for Tx<T> {
361    fn drop(&mut self) {
362        // If sender is still present, the drain task will handle Close when
363        // the internal channel closes. Don't send Close via driver_tx because
364        // it would use the wrong channel_id (dispatch-context id vs caller-allocated id).
365        if self.sender.inner.is_some() {
366            // Just drop the sender - drain task handles Close
367            return;
368        }
369
370        // Sender was taken or never set - send Close via driver_tx if available
371        if let Some(task_tx) = self.driver_tx.inner.take() {
372            let conn_id = self.conn_id;
373            let channel_id = self.channel_id;
374            // Use try_send for synchronous Close delivery.
375            // This ensures Close is queued before Response in dispatch_call.
376            //
377            // WARNING: If try_send fails (channel full), we spawn as fallback.
378            // This creates a potential ordering issue where Close could arrive
379            // after Response. To mitigate: task_tx channels should be sized
380            // generously (256+) to make this unlikely. A proper fix would use
381            // unbounded channels for task messages.
382            if task_tx
383                .try_send(DriverMessage::Close {
384                    conn_id,
385                    channel_id,
386                })
387                .is_err()
388            {
389                // Channel full or closed - spawn as fallback (see warning above)
390                crate::runtime::spawn(async move {
391                    let _ = task_tx
392                        .send(DriverMessage::Close {
393                            conn_id,
394                            channel_id,
395                        })
396                        .await;
397                });
398            }
399        }
400    }
401}
402
403/// Error when sending on a Tx stream.
404#[derive(Debug)]
405pub enum TxError {
406    /// Failed to serialize the value.
407    Serialize(facet_postcard::SerializeError),
408    /// The stream channel is closed.
409    Closed,
410    /// The sender was already taken (e.g., by ConnectionHandle::call).
411    Taken,
412}
413
414impl std::fmt::Display for TxError {
415    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
416        match self {
417            TxError::Serialize(e) => write!(f, "serialize error: {e}"),
418            TxError::Closed => write!(f, "stream closed"),
419            TxError::Taken => write!(f, "sender was taken"),
420        }
421    }
422}
423
424impl std::error::Error for TxError {}
425
426// ============================================================================
427// ReceiverSlot - Wrapper for Option<Receiver> that implements Facet
428// ============================================================================
429
430/// A wrapper around `Option<Receiver<Vec<u8>>>` that implements Facet.
431///
432/// This allows `Poke::get_mut::<ReceiverSlot>()` to work, enabling `.take()`
433/// via reflection. Used by `ConnectionHandle::call` to extract receivers from
434/// `Rx<T>` arguments and register them with the stream registry.
435#[derive(Facet)]
436#[facet(opaque)]
437pub struct ReceiverSlot {
438    /// The optional receiver. Public within crate for `Rx::recv()` access.
439    pub(crate) inner: Option<Receiver<Vec<u8>>>,
440}
441
442impl ReceiverSlot {
443    /// Create a slot containing a receiver.
444    pub fn new(rx: Receiver<Vec<u8>>) -> Self {
445        Self { inner: Some(rx) }
446    }
447
448    /// Create an empty slot.
449    pub fn empty() -> Self {
450        Self { inner: None }
451    }
452
453    /// Take the receiver out of the slot, leaving it empty.
454    pub fn take(&mut self) -> Option<Receiver<Vec<u8>>> {
455        self.inner.take()
456    }
457
458    /// Check if the slot contains a receiver.
459    pub fn is_some(&self) -> bool {
460        self.inner.is_some()
461    }
462
463    /// Check if the slot is empty.
464    pub fn is_none(&self) -> bool {
465        self.inner.is_none()
466    }
467
468    /// Set the receiver in this slot.
469    ///
470    /// Used by `ChannelRegistry::bind_streams` to hydrate a deserialized `Rx<T>`
471    /// with an actual channel receiver.
472    pub fn set(&mut self, rx: Receiver<Vec<u8>>) {
473        self.inner = Some(rx);
474    }
475}
476
477/// Rx stream handle - caller receives data from callee.
478///
479/// r[impl channeling.caller-pov] - From caller's perspective, Rx means "I receive".
480/// r[impl channeling.type] - Serializes as u64 stream ID on wire.
481/// r[impl channeling.holder-semantics] - The holder receives from this stream.
482///
483/// # Facet Implementation
484///
485/// Uses `#[facet(proxy = u64)]` so that:
486/// - `channel_id` is pokeable (Connection can walk args and set stream IDs)
487/// - Serializes as just a `u64` on the wire
488/// - `T` is exposed as a type parameter for codegen introspection
489///
490/// The `receiver` field uses `ReceiverSlot` wrapper so that `ConnectionHandle::call`
491/// can use `Poke::get_mut::<ReceiverSlot>()` to `.take()` the receiver and register
492/// it with the stream registry.
493#[derive(Facet)]
494#[facet(proxy = u64)]
495pub struct Rx<T: 'static> {
496    /// The unique stream ID for this stream.
497    /// Public so Connection can poke it when binding streams.
498    pub channel_id: ChannelId,
499    /// Channel receiver for incoming data.
500    /// Uses ReceiverSlot so it's pokeable (can .take() via Poke).
501    pub receiver: ReceiverSlot,
502    /// Phantom data for the element type.
503    #[facet(opaque)]
504    _marker: PhantomData<T>,
505}
506
507/// Serialization: `&Rx<T>` -> u64 (extracts channel_id)
508///
509/// Uses TryFrom rather than From because facet's proxy mechanism requires TryFrom.
510#[allow(clippy::infallible_try_from)]
511impl<T: 'static> TryFrom<&Rx<T>> for u64 {
512    type Error = Infallible;
513    fn try_from(rx: &Rx<T>) -> Result<Self, Self::Error> {
514        Ok(rx.channel_id)
515    }
516}
517
518/// Deserialization: u64 -> `Rx<T>` (creates a "hollow" Rx)
519///
520/// The receiver is a placeholder - the real receiver gets set up by Connection
521/// after deserialization when it binds the stream.
522///
523/// Uses TryFrom rather than From because facet's proxy mechanism requires TryFrom.
524#[allow(clippy::infallible_try_from)]
525impl<T: 'static> TryFrom<u64> for Rx<T> {
526    type Error = Infallible;
527    fn try_from(channel_id: u64) -> Result<Self, Self::Error> {
528        // Create a hollow Rx - no actual receiver, Connection will bind later
529        Ok(Rx {
530            channel_id,
531            receiver: ReceiverSlot::empty(),
532            _marker: PhantomData,
533        })
534    }
535}
536
537impl<T: 'static> Rx<T> {
538    /// Create a new Rx stream with the given ID and receiver channel.
539    pub fn new(channel_id: ChannelId, rx: Receiver<Vec<u8>>) -> Self {
540        Self {
541            channel_id,
542            receiver: ReceiverSlot::new(rx),
543            _marker: PhantomData,
544        }
545    }
546
547    /// Create an unbound Rx with a receiver but channel_id 0.
548    ///
549    /// Used by `roam::channel()` to create a pair before binding.
550    /// Connection will poke the channel_id when binding.
551    pub fn unbound(rx: Receiver<Vec<u8>>) -> Self {
552        Self {
553            channel_id: 0,
554            receiver: ReceiverSlot::new(rx),
555            _marker: PhantomData,
556        }
557    }
558
559    /// Create a bound Rx with channel_id already set.
560    ///
561    /// Used by `roam::channel()` when called during dispatch to create
562    /// response channels. The channel_id will be serialized and sent to
563    /// the client, who will bind a receiver for incoming Data.
564    pub fn bound(channel_id: ChannelId, rx: Receiver<Vec<u8>>) -> Self {
565        Self {
566            channel_id,
567            receiver: ReceiverSlot::new(rx),
568            _marker: PhantomData,
569        }
570    }
571
572    /// Get the stream ID.
573    pub fn channel_id(&self) -> ChannelId {
574        self.channel_id
575    }
576
577    /// Receive the next value from this stream.
578    ///
579    /// Returns `Ok(Some(value))` for each received value,
580    /// `Ok(None)` when the stream is closed,
581    /// or `Err` if deserialization fails.
582    ///
583    /// r[impl channeling.data] - Deserialize Data message payloads.
584    /// r[impl channeling.data.invalid] - Caller must send Goodbye on deserialize error.
585    pub async fn recv(&mut self) -> Result<Option<T>, RxError>
586    where
587        T: Facet<'static>,
588    {
589        let rx = self.receiver.inner.as_mut().ok_or(RxError::Taken)?;
590        match rx.recv().await {
591            Some(bytes) => {
592                let value = facet_postcard::from_slice(&bytes).map_err(RxError::Deserialize)?;
593                Ok(Some(value))
594            }
595            None => Ok(None),
596        }
597    }
598}
599
600/// Error when receiving from a Rx stream.
601#[derive(Debug)]
602pub enum RxError {
603    /// Failed to deserialize the value.
604    Deserialize(facet_postcard::DeserializeError<facet_postcard::PostcardError>),
605    /// The receiver was already taken (e.g., by ConnectionHandle::call).
606    Taken,
607}
608
609impl std::fmt::Display for RxError {
610    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
611        match self {
612            RxError::Deserialize(e) => write!(f, "deserialize error: {e}"),
613            RxError::Taken => write!(f, "receiver was taken"),
614        }
615    }
616}
617
618impl std::error::Error for RxError {}
619
620// ============================================================================
621// Channel creation
622// ============================================================================
623
624/// Create an unbound channel pair for streaming RPC.
625///
626/// Returns `(Tx<T>, Rx<T>)` with `channel_id: 0`. The `ConnectionHandle::call`
627/// method will walk the args, find `Rx<T>` or `Tx<T>` fields, assign stream IDs,
628/// and take the internal channel handles to register with the stream registry.
629///
630/// # Channel semantics (like regular mpsc)
631///
632/// - If caller wants to **send** data: pass `rx`, keep `tx`
633/// - If caller wants to **receive** data: pass `tx`, keep `rx`
634///
635/// # Example
636///
637/// ```ignore
638/// // sum(numbers: Rx<i32>) -> i64
639/// let (tx, rx) = roam::channel::<i32>();
640/// let fut = client.sum(rx);  // pass rx, keep tx
641/// tx.send(1).await;
642/// tx.send(2).await;
643/// drop(tx);
644/// let sum = fut.await?;
645/// ```
646pub fn channel<T: 'static>() -> (Tx<T>, Rx<T>) {
647    let (sender, receiver) = crate::runtime::channel(CHANNEL_SIZE);
648
649    // Check if we're in a dispatch context - if so, create bound channels
650    if let Some(ctx) = get_dispatch_context() {
651        let channel_id = ctx.channel_ids.next();
652        debug!(channel_id, "roam::channel() creating bound channel pair");
653        (
654            Tx::bound(ctx.conn_id, channel_id, sender, ctx.driver_tx.clone()),
655            Rx::bound(channel_id, receiver),
656        )
657    } else {
658        trace!("roam::channel() creating unbound channel pair (no dispatch context)");
659        (Tx::unbound(sender), Rx::unbound(receiver))
660    }
661}
662
663// ============================================================================
664// Dispatch Context (task-local for response channel binding)
665// ============================================================================
666
667/// Context for binding response channels during dispatch.
668///
669/// When a service handler creates a channel with `roam::channel()` and returns
670/// the Rx, the Tx needs to be bound to send Data over the wire. This context
671/// provides the channel ID allocator and driver_tx needed for binding.
672#[derive(Clone)]
673struct DispatchContext {
674    conn_id: roam_wire::ConnectionId,
675    channel_ids: Arc<ChannelIdAllocator>,
676    driver_tx: Sender<DriverMessage>,
677}
678
679roam_task_local::task_local! {
680    /// Task-local dispatch context. Using task_local instead of thread_local
681    /// is critical: thread_local can leak across different async tasks that
682    /// happen to run on the same worker thread, causing channel binding bugs.
683    static DISPATCH_CONTEXT: DispatchContext;
684}
685
686/// Get the current dispatch context, if any.
687fn get_dispatch_context() -> Option<DispatchContext> {
688    DISPATCH_CONTEXT.try_with(|ctx| ctx.clone()).ok()
689}
690
691// ============================================================================
692// Stream Registry
693// ============================================================================
694
695use std::collections::{HashMap, HashSet};
696
/// Result of a completed call: the response payload plus the IDs of any stream
/// channels that travel with the response.
#[derive(Debug)]
pub struct ResponseData {
    /// The response payload bytes.
    pub payload: Vec<u8>,
    /// Channel IDs for streams in the response (`Rx<T>` returned by the method).
    /// The client must register receivers for these channels.
    pub channels: Vec<u64>,
}
706
/// All messages to the connection driver go through a single channel.
///
/// This unified channel ensures FIFO ordering: a Call followed by Data
/// will always be processed in that order, preventing race conditions
/// where Data could arrive before the Request is sent.
pub enum DriverMessage {
    /// Send a Request and expect a Response (client-side call).
    Call {
        /// Connection the request is sent on.
        conn_id: roam_wire::ConnectionId,
        /// Request identifier (presumably used to match the Response — confirm in the driver).
        request_id: u64,
        /// Identifier of the method being invoked.
        method_id: u64,
        /// Metadata entries attached to the request.
        metadata: Vec<(String, roam_wire::MetadataValue)>,
        /// Channel IDs used by this call (Tx/Rx), in declaration order.
        channels: Vec<u64>,
        /// Encoded request payload.
        payload: Vec<u8>,
        /// One-shot slot for the call's outcome: `ResponseData` or a `TransportError`.
        response_tx: OneshotSender<Result<ResponseData, TransportError>>,
    },
    /// Send a Data message on a stream.
    Data {
        /// Connection the stream belongs to.
        conn_id: roam_wire::ConnectionId,
        /// Stream the payload is sent on.
        channel_id: ChannelId,
        /// Serialized value bytes (`Tx::send` produces these with `facet_postcard::to_vec`).
        payload: Vec<u8>,
    },
    /// Send a Close message to end a stream.
    Close {
        /// Connection the stream belongs to.
        conn_id: roam_wire::ConnectionId,
        /// Stream being closed.
        channel_id: ChannelId,
    },
    /// Send a Response message (server-side call completed).
    Response {
        /// Connection the response is sent on.
        conn_id: roam_wire::ConnectionId,
        /// Identifier of the request being answered.
        request_id: u64,
        /// Channel IDs for streams in the response (Tx/Rx returned by the method).
        channels: Vec<u64>,
        /// Encoded response payload.
        payload: Vec<u8>,
    },
    /// Request to open a new virtual connection.
    Connect {
        /// Request identifier for this connect attempt.
        request_id: u64,
        /// Metadata sent with the connect request.
        metadata: roam_wire::Metadata,
        /// One-shot slot for the new `ConnectionHandle` or a `ConnectError`.
        response_tx: OneshotSender<Result<ConnectionHandle, crate::ConnectError>>,
        /// Dispatcher for handling incoming requests on the virtual connection.
        /// If None, the connection can only make calls, not receive them.
        dispatcher: Option<Box<dyn ServiceDispatcher>>,
    },
}
753
/// Registry of active streams for a connection.
///
/// Handles incoming streams (Data from wire → `Rx<T>` / `Tx<T>` handles).
/// For outgoing streams (server `Tx<T>` args), spawned tasks drain receivers
/// and send Data/Close messages via `driver_tx`.
///
/// r[impl channeling.unknown] - Unknown stream IDs cause Goodbye.
pub struct ChannelRegistry {
    /// Connection ID this registry belongs to.
    conn_id: roam_wire::ConnectionId,

    /// Streams where we receive Data messages (backing `Rx<T>` or `Tx<T>` handles on our side).
    /// Key: channel_id, Value: sender to route Data payloads to the handle.
    incoming: HashMap<ChannelId, Sender<Vec<u8>>>,

    /// Stream IDs that have been closed.
    /// Used to detect data-after-close violations.
    ///
    /// r[impl channeling.data-after-close] - Track closed streams.
    closed: HashSet<ChannelId>,

    // ========================================================================
    // Flow Control
    // ========================================================================
    /// r[impl flow.channel.credit-based] - Credit tracking for incoming streams.
    /// r[impl flow.channel.all-transports] - Flow control applies to all transports.
    /// This is the credit we've granted to the peer - bytes they can still send us.
    /// Decremented when we receive Data, incremented when we send Credit.
    incoming_credit: HashMap<ChannelId, u32>,

    /// r[impl flow.channel.credit-based] - Credit tracking for outgoing streams.
    /// r[impl flow.channel.all-transports] - Flow control applies to all transports.
    /// This is the credit peer granted us - bytes we can still send them.
    /// Decremented when we send Data, incremented when we receive Credit.
    outgoing_credit: HashMap<ChannelId, u32>,

    /// Initial credit to grant new streams.
    /// `ChannelRegistry::new` sets this to `u32::MAX`, i.e. effectively unlimited.
    /// r[impl flow.channel.initial-credit] - Each stream starts with this credit.
    initial_credit: u32,

    /// Unified channel for all messages to the driver.
    /// The driver owns the receiving end and sends these on the wire.
    /// Using a single channel ensures FIFO ordering.
    driver_tx: Sender<DriverMessage>,

    /// Channel ID allocator for response channels created during dispatch.
    /// These are channels returned by service methods (e.g., `subscribe() -> Rx<Event>`).
    /// Its parity follows the `Role` given at construction, and the `Arc` is shared with
    /// every `DispatchContext` this registry hands out, so they draw from one sequence.
    response_channel_ids: Arc<ChannelIdAllocator>,
}
803
804impl ChannelRegistry {
805    /// Create a new registry with the given conn_id, initial credit, driver channel, and role.
806    ///
807    /// The `driver_tx` is used to send all messages (Call/Data/Close/Response)
808    /// to the driver for transmission on the wire.
809    ///
810    /// The `role` determines channel ID parity for response channels:
811    /// - Acceptor (server) uses even IDs
812    /// - Initiator (client) uses odd IDs
813    ///
814    /// r[impl flow.channel.initial-credit] - Each stream starts with this credit.
815    pub fn new_with_credit_and_role(
816        conn_id: roam_wire::ConnectionId,
817        initial_credit: u32,
818        driver_tx: Sender<DriverMessage>,
819        role: Role,
820    ) -> Self {
821        Self {
822            conn_id,
823            incoming: HashMap::new(),
824            closed: HashSet::new(),
825            incoming_credit: HashMap::new(),
826            outgoing_credit: HashMap::new(),
827            initial_credit,
828            driver_tx,
829            response_channel_ids: Arc::new(ChannelIdAllocator::new(role)),
830        }
831    }
832
833    /// Create a new registry with the given initial credit and driver channel.
834    /// Uses ROOT conn_id and Acceptor role for backward compatibility (server-side usage).
835    ///
836    /// r[impl flow.channel.initial-credit] - Each stream starts with this credit.
837    pub fn new_with_credit(initial_credit: u32, driver_tx: Sender<DriverMessage>) -> Self {
838        Self::new_with_credit_and_role(
839            roam_wire::ConnectionId::ROOT,
840            initial_credit,
841            driver_tx,
842            Role::Acceptor,
843        )
844    }
845
846    /// Create a new registry with default infinite credit.
847    ///
848    /// r[impl flow.channel.infinite-credit] - Implementations MAY use very large credit.
849    /// r[impl flow.channel.zero-credit] - With infinite credit, zero-credit never occurs.
850    /// This disables backpressure but simplifies implementation.
851    pub fn new(driver_tx: Sender<DriverMessage>) -> Self {
852        Self::new_with_credit(u32::MAX, driver_tx)
853    }
854
855    /// Get the connection ID for this registry.
856    pub fn conn_id(&self) -> roam_wire::ConnectionId {
857        self.conn_id
858    }
859
860    /// Get the dispatch context for response channel binding.
861    ///
862    /// Used by `dispatch_call` and `dispatch_call_infallible` to set up
863    /// thread-local context so `roam::channel()` can create bound channels.
864    pub(crate) fn dispatch_context(&self) -> DispatchContext {
865        DispatchContext {
866            conn_id: self.conn_id,
867            channel_ids: self.response_channel_ids.clone(),
868            driver_tx: self.driver_tx.clone(),
869        }
870    }
871
872    /// Get a clone of the driver message sender.
873    ///
874    /// Used by codegen to spawn tasks that send Data/Close/Response messages.
875    pub fn driver_tx(&self) -> Sender<DriverMessage> {
876        self.driver_tx.clone()
877    }
878
879    /// Get the response channel ID allocator.
880    /// Used by ForwardingDispatcher to allocate downstream channel IDs for response channels.
881    pub fn response_channel_ids(&self) -> Arc<ChannelIdAllocator> {
882        self.response_channel_ids.clone()
883    }
884
885    /// Register an incoming stream.
886    ///
887    /// The connection layer will route Data messages for this channel_id to the sender.
888    /// Used for both `Rx<T>` (caller receives from callee) and `Tx<T>` (callee sends to caller).
889    ///
890    /// r[impl flow.channel.initial-credit] - Stream starts with initial credit.
891    pub fn register_incoming(&mut self, channel_id: ChannelId, tx: Sender<Vec<u8>>) {
892        self.incoming.insert(channel_id, tx);
893        // Grant initial credit - peer can send us this many bytes
894        self.incoming_credit.insert(channel_id, self.initial_credit);
895    }
896
897    /// Register credit tracking for an outgoing stream.
898    ///
899    /// The actual receiver is NOT stored here - the driver owns it directly.
900    /// This only sets up credit tracking for the stream.
901    ///
902    /// r[impl flow.channel.initial-credit] - Stream starts with initial credit.
903    pub fn register_outgoing_credit(&mut self, channel_id: ChannelId) {
904        // Assume peer grants us initial credit - we can send them this many bytes
905        self.outgoing_credit.insert(channel_id, self.initial_credit);
906    }
907
908    /// Route a Data message payload to the appropriate incoming stream.
909    ///
910    /// Returns Ok(()) if routed successfully, Err(ChannelError) otherwise.
911    ///
912    /// r[impl channeling.data] - Data messages routed by channel_id.
913    /// r[impl channeling.data-after-close] - Reject data on closed streams.
914    /// r[impl flow.channel.credit-overrun] - Reject if data exceeds remaining credit.
915    /// r[impl flow.channel.credit-consume] - Deduct bytes from remaining credit.
916    /// r[impl flow.channel.byte-accounting] - Credit measured in payload bytes.
917    ///
918    /// Returns a sender and payload if routing is allowed, or an error.
919    /// The actual send must be done by the caller to avoid holding locks across await.
920    pub fn prepare_route_data(
921        &mut self,
922        channel_id: ChannelId,
923        payload: Vec<u8>,
924    ) -> Result<(Sender<Vec<u8>>, Vec<u8>), ChannelError> {
925        // Check for data-after-close
926        if self.closed.contains(&channel_id) {
927            return Err(ChannelError::DataAfterClose);
928        }
929
930        // Check credit before routing
931        // r[impl flow.channel.credit-overrun] - Reject if exceeds credit
932        let payload_len = payload.len() as u32;
933        if let Some(credit) = self.incoming_credit.get_mut(&channel_id) {
934            if payload_len > *credit {
935                return Err(ChannelError::CreditOverrun);
936            }
937            // r[impl flow.channel.credit-consume] - Deduct from credit
938            *credit -= payload_len;
939        }
940        // Note: if no credit entry exists, the stream may not be registered yet
941        // (e.g., Rx stream created by callee). In that case, skip credit check.
942
943        if let Some(tx) = self.incoming.get(&channel_id) {
944            Ok((tx.clone(), payload))
945        } else {
946            Err(ChannelError::Unknown)
947        }
948    }
949
950    /// Route a Data message payload to the appropriate incoming stream.
951    ///
952    /// Returns Ok(()) if routed successfully, Err(ChannelError) otherwise.
953    ///
954    /// r[impl channeling.data] - Data messages routed by channel_id.
955    /// r[impl channeling.data-after-close] - Reject data on closed streams.
956    /// r[impl flow.channel.credit-overrun] - Reject if data exceeds remaining credit.
957    /// r[impl flow.channel.credit-consume] - Deduct bytes from remaining credit.
958    /// r[impl flow.channel.byte-accounting] - Credit measured in payload bytes.
959    pub async fn route_data(
960        &mut self,
961        channel_id: ChannelId,
962        payload: Vec<u8>,
963    ) -> Result<(), ChannelError> {
964        let (tx, payload) = self.prepare_route_data(channel_id, payload)?;
965        // If send fails, the Rx<T> was dropped - that's okay, just drop the data
966        let _ = tx.send(payload).await;
967        Ok(())
968    }
969
970    /// Close an incoming stream (remove from registry).
971    ///
972    /// Dropping the sender will cause the `Rx<T>`'s recv() to return None.
973    ///
974    /// r[impl channeling.close] - Close terminates the stream.
975    /// r[impl flow.channel.close-exempt] - Close doesn't consume credit.
976    pub fn close(&mut self, channel_id: ChannelId) {
977        self.incoming.remove(&channel_id);
978        self.incoming_credit.remove(&channel_id);
979        self.outgoing_credit.remove(&channel_id);
980        self.closed.insert(channel_id);
981    }
982
983    /// Reset a stream (remove from registry, discard credit).
984    ///
985    /// r[impl channeling.reset] - Reset terminates the stream abruptly.
986    /// r[impl channeling.reset.credit] - Outstanding credit is lost on reset.
987    pub fn reset(&mut self, channel_id: ChannelId) {
988        self.incoming.remove(&channel_id);
989        self.incoming_credit.remove(&channel_id);
990        self.outgoing_credit.remove(&channel_id);
991        self.closed.insert(channel_id);
992    }
993
994    /// Receive a Credit message - add credit for an outgoing stream.
995    ///
996    /// r[impl flow.channel.credit-grant] - Credit message adds to available credit.
997    /// r[impl flow.channel.credit-additive] - Credit accumulates additively.
998    pub fn receive_credit(&mut self, channel_id: ChannelId, bytes: u32) {
999        if let Some(credit) = self.outgoing_credit.get_mut(&channel_id) {
1000            // r[impl flow.channel.credit-additive] - Add to existing credit
1001            *credit = credit.saturating_add(bytes);
1002        }
1003        // If no entry, stream may be closed or unknown - ignore
1004    }
1005
1006    /// Check if a stream ID is registered (either incoming or outgoing credit).
1007    pub fn contains(&self, channel_id: ChannelId) -> bool {
1008        self.incoming.contains_key(&channel_id) || self.outgoing_credit.contains_key(&channel_id)
1009    }
1010
1011    /// Check if a stream ID is registered as incoming.
1012    pub fn contains_incoming(&self, channel_id: ChannelId) -> bool {
1013        self.incoming.contains_key(&channel_id)
1014    }
1015
1016    /// Check if a stream ID has outgoing credit registered.
1017    pub fn contains_outgoing(&self, channel_id: ChannelId) -> bool {
1018        self.outgoing_credit.contains_key(&channel_id)
1019    }
1020
1021    /// Check if a stream has been closed.
1022    pub fn is_closed(&self, channel_id: ChannelId) -> bool {
1023        self.closed.contains(&channel_id)
1024    }
1025
1026    /// Get the number of active outgoing streams (by credit tracking).
1027    pub fn outgoing_count(&self) -> usize {
1028        self.outgoing_credit.len()
1029    }
1030
1031    /// Get remaining credit for an outgoing stream.
1032    ///
1033    /// Returns None if stream is not registered.
1034    pub fn outgoing_credit(&self, channel_id: ChannelId) -> Option<u32> {
1035        self.outgoing_credit.get(&channel_id).copied()
1036    }
1037
1038    /// Get remaining credit we've granted for an incoming stream.
1039    ///
1040    /// Returns None if stream is not registered.
1041    pub fn incoming_credit(&self, channel_id: ChannelId) -> Option<u32> {
1042        self.incoming_credit.get(&channel_id).copied()
1043    }
1044
1045    /// Bind streams in deserialized args for server-side dispatch.
1046    ///
1047    /// Walks the args using Poke reflection to find any `Rx<T>` or `Tx<T>` fields.
1048    /// For each stream found:
1049    /// - For `Rx<T>`: creates a channel, sets the receiver slot, registers for incoming data
1050    /// - For `Tx<T>`: sets the task_tx so send() writes directly to the wire
1051    ///
1052    /// # Example
1053    ///
1054    /// ```ignore
1055    /// let mut args = facet_postcard::from_slice::<(Rx<i32>, Tx<String>)>(&payload)?;
1056    /// registry.bind_streams(&mut args);
1057    /// let (input, output) = args;
1058    /// // ... call handler with input, output ...
1059    /// // When handler returns and Tx is dropped, Close is sent automatically
1060    /// ```
1061    pub fn bind_streams<T: Facet<'static>>(&mut self, args: &mut T) {
1062        let poke = facet::Poke::new(args);
1063        self.bind_streams_recursive(poke);
1064    }
1065
1066    /// Recursively walk a Poke value looking for Rx/Tx streams to bind.
1067    #[allow(unsafe_code)]
1068    fn bind_streams_recursive(&mut self, mut poke: facet::Poke<'_, '_>) {
1069        use facet::Def;
1070
1071        let shape = poke.shape();
1072
1073        trace!(
1074            module_path = ?shape.module_path,
1075            type_identifier = shape.type_identifier,
1076            "bind_streams_recursive: visiting type"
1077        );
1078
1079        // Check if this is an Rx or Tx type
1080        if shape.module_path == Some("roam_session") {
1081            if shape.type_identifier == "Rx" {
1082                debug!("bind_streams_recursive: found Rx, binding");
1083                self.bind_rx_stream(poke);
1084                return;
1085            } else if shape.type_identifier == "Tx" {
1086                debug!("bind_streams_recursive: found Tx, binding");
1087                self.bind_tx_stream(poke);
1088                return;
1089            }
1090        }
1091
1092        // Dispatch based on the shape's definition
1093        match shape.def {
1094            Def::Scalar => {}
1095
1096            // Recurse into struct/tuple fields
1097            _ if poke.is_struct() => {
1098                let mut ps = poke.into_struct().expect("is_struct was true");
1099                let field_count = ps.field_count();
1100                trace!(field_count, "bind_streams_recursive: recursing into struct");
1101                for i in 0..field_count {
1102                    if let Ok(field_poke) = ps.field(i) {
1103                        self.bind_streams_recursive(field_poke);
1104                    }
1105                }
1106            }
1107
1108            // Recurse into Option<T>
1109            Def::Option(_) => {
1110                // Option is represented as an enum, use into_enum to access its value
1111                if let Ok(mut pe) = poke.into_enum()
1112                    && let Ok(Some(inner_poke)) = pe.field(0)
1113                {
1114                    self.bind_streams_recursive(inner_poke);
1115                }
1116            }
1117
1118            // Recurse into list elements (e.g., Vec<Tx<T>>)
1119            Def::List(list_def) => {
1120                let len = {
1121                    let peek = poke.as_peek();
1122                    peek.into_list().map(|pl| pl.len()).unwrap_or(0)
1123                };
1124                // Get mutable access to elements via VTable (no PokeList exists)
1125                if let Some(get_mut_fn) = list_def.vtable.get_mut {
1126                    let element_shape = list_def.t;
1127                    let data_ptr = poke.data_mut();
1128                    for i in 0..len {
1129                        // SAFETY: We have exclusive mutable access via poke, index < len, shape is correct
1130                        let element_ptr = unsafe { (get_mut_fn)(data_ptr, i, element_shape) };
1131                        if let Some(ptr) = element_ptr {
1132                            // SAFETY: ptr points to a valid element with the correct shape
1133                            let element_poke =
1134                                unsafe { facet::Poke::from_raw_parts(ptr, element_shape) };
1135                            self.bind_streams_recursive(element_poke);
1136                        }
1137                    }
1138                }
1139            }
1140
1141            // Other enum variants
1142            _ if poke.is_enum() => {
1143                if let Ok(mut pe) = poke.into_enum()
1144                    && let Ok(Some(variant_poke)) = pe.field(0)
1145                {
1146                    self.bind_streams_recursive(variant_poke);
1147                }
1148            }
1149
1150            _ => {}
1151        }
1152    }
1153
1154    /// Bind an Rx<T> stream for server-side dispatch.
1155    ///
1156    /// Server receives data from client on this stream.
1157    /// Creates a channel, sets the receiver slot, registers the sender for routing.
1158    fn bind_rx_stream(&mut self, poke: facet::Poke<'_, '_>) {
1159        if let Ok(mut ps) = poke.into_struct() {
1160            // Get the channel_id that was deserialized from the wire
1161            let channel_id = if let Ok(channel_id_field) = ps.field_by_name("channel_id")
1162                && let Ok(id_ref) = channel_id_field.get::<ChannelId>()
1163            {
1164                *id_ref
1165            } else {
1166                warn!("bind_rx_stream: could not get channel_id field");
1167                return;
1168            };
1169
1170            debug!(channel_id, "bind_rx_stream: registering incoming channel");
1171
1172            // Create channel and set receiver slot
1173            let (tx, rx) = crate::runtime::channel(RX_STREAM_BUFFER_SIZE);
1174
1175            if let Ok(mut receiver_field) = ps.field_by_name("receiver")
1176                && let Ok(slot) = receiver_field.get_mut::<ReceiverSlot>()
1177            {
1178                slot.set(rx);
1179            }
1180
1181            // Register for incoming data routing
1182            self.register_incoming(channel_id, tx);
1183            debug!(channel_id, "bind_rx_stream: channel registered");
1184        } else {
1185            warn!("bind_rx_stream: could not convert poke to struct");
1186        }
1187    }
1188
1189    /// Bind a Tx<T> stream for server-side dispatch.
1190    ///
1191    /// Server sends data to client on this stream.
1192    /// Sets the conn_id and driver_tx so Tx::send() writes DriverMessage::Data to the wire.
1193    /// When the Tx is dropped, it sends DriverMessage::Close automatically.
1194    fn bind_tx_stream(&mut self, poke: facet::Poke<'_, '_>) {
1195        if let Ok(mut ps) = poke.into_struct() {
1196            // Set conn_id so Data/Close messages go to the correct virtual connection
1197            // r[impl core.conn.independence]
1198            if let Ok(mut conn_id_field) = ps.field_by_name("conn_id")
1199                && let Ok(id_ref) = conn_id_field.get_mut::<roam_wire::ConnectionId>()
1200            {
1201                *id_ref = self.conn_id;
1202            }
1203
1204            // Set driver_tx so Tx::send() can write directly to the wire
1205            if let Ok(mut driver_tx_field) = ps.field_by_name("driver_tx")
1206                && let Ok(slot) = driver_tx_field.get_mut::<DriverTxSlot>()
1207            {
1208                slot.set(self.driver_tx.clone());
1209            }
1210        }
1211    }
1212}
1213
/// Reasons a Data payload could not be routed to a stream.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ChannelError {
    /// No stream is registered under the given ID.
    Unknown,
    /// The stream had already been closed when the Data arrived.
    DataAfterClose,
    /// r[impl flow.channel.credit-overrun] - The payload was larger than the remaining credit.
    CreditOverrun,
}
1224
1225// ============================================================================
1226// Flow Control
1227// ============================================================================
1228
1229/// Abstraction for stream flow control mechanism.
1230///
1231/// Different transports implement credit-based flow control differently:
1232/// - **Stream transports** (TCP, WebSocket): explicit `Message::Credit` on the wire
1233/// - **SHM**: shared atomic counters in the channel table (`ChannelEntry::granted_total`)
1234///
1235/// This trait abstracts the mechanism while `ChannelRegistry` remains the source
1236/// of truth for stream lifecycle (routing, ordering, existence).
1237///
1238/// r[impl flow.channel.credit-based]
1239/// r[impl flow.channel.all-transports]
1240pub trait FlowControl: Send {
1241    /// Called when we receive data on a channel (receiver side).
1242    ///
1243    /// The implementation may grant credit back to the sender:
1244    /// - Stream: queue a `Message::Credit` to send
1245    /// - SHM: increment `ChannelEntry::granted_total` atomically
1246    ///
1247    /// r[impl flow.channel.credit-grant]
1248    fn on_data_received(&mut self, channel_id: ChannelId, bytes: u32);
1249
1250    /// Wait until we have enough credit to send `bytes` on a channel (sender side).
1251    ///
1252    /// - Stream: check `ChannelRegistry::outgoing_credit`, wait on notify if insufficient
1253    /// - SHM: poll/futex wait on `granted_total - sent_total >= bytes`
1254    ///
1255    /// Returns `Ok(())` when credit is available, `Err` if the channel is closed/invalid.
1256    ///
1257    /// r[impl flow.channel.zero-credit]
1258    fn wait_for_send_credit(
1259        &mut self,
1260        channel_id: ChannelId,
1261        bytes: u32,
1262    ) -> impl std::future::Future<Output = std::io::Result<()>> + Send;
1263
1264    /// Consume credit after sending data (sender side).
1265    ///
1266    /// Called after successfully sending `bytes` on a channel.
1267    /// - Stream: decrement `ChannelRegistry::outgoing_credit`
1268    /// - SHM: increment local `sent_total`
1269    ///
1270    /// r[impl flow.channel.credit-consume]
1271    fn consume_send_credit(&mut self, channel_id: ChannelId, bytes: u32);
1272}
1273
1274/// No-op flow control for infinite credit mode.
1275///
1276/// r[impl flow.channel.infinite-credit]
1277///
1278/// Used when flow control is disabled or not yet implemented.
1279/// All operations succeed immediately without tracking.
1280#[derive(Debug, Clone, Copy, Default)]
1281pub struct InfiniteCredit;
1282
1283impl FlowControl for InfiniteCredit {
1284    fn on_data_received(&mut self, _channel_id: ChannelId, _bytes: u32) {
1285        // No credit tracking needed
1286    }
1287
1288    async fn wait_for_send_credit(
1289        &mut self,
1290        _channel_id: ChannelId,
1291        _bytes: u32,
1292    ) -> std::io::Result<()> {
1293        // Infinite credit - always available
1294        Ok(())
1295    }
1296
1297    fn consume_send_credit(&mut self, _channel_id: ChannelId, _bytes: u32) {
1298        // No credit tracking needed
1299    }
1300}
1301
1302// ============================================================================
1303// Request ID generation
1304// ============================================================================
1305
/// Hands out request IDs for one connection.
///
/// IDs come from an atomic counter that starts at 1 and advances by one on every
/// call to [`next`](Self::next), so each call on the same generator yields a distinct
/// value (until the `u64` counter wraps around).
///
/// r[impl call.request-id.uniqueness] - monotonically increasing counter starting at 1
pub struct RequestIdGenerator {
    next: AtomicU64,
}

impl RequestIdGenerator {
    /// The ID returned by the first call to `next` on a fresh generator.
    const FIRST_ID: u64 = 1;

    /// Build a generator whose first ID is 1.
    pub fn new() -> Self {
        let counter = AtomicU64::new(Self::FIRST_ID);
        Self { next: counter }
    }

    /// Return the current counter value and advance it by one.
    ///
    /// `Relaxed` suffices: uniqueness only depends on the increment being atomic.
    pub fn next(&self) -> u64 {
        self.next.fetch_add(1, Ordering::Relaxed)
    }
}

impl Default for RequestIdGenerator {
    /// Same as [`RequestIdGenerator::new`].
    fn default() -> Self {
        RequestIdGenerator::new()
    }
}
1332
1333// ============================================================================
1334// Dispatch Helper
1335// ============================================================================
1336
1337/// Helper for dispatching RPC methods with minimal generated code.
1338///
1339/// This function handles the common dispatch pattern:
1340/// 1. Deserialize args from payload
1341/// 2. Bind any Tx/Rx streams via registry
1342/// 3. Call the handler closure
1343/// 4. Encode the result and send Response
1344///
1345/// The generated code just needs to provide a closure that calls the handler method.
1346///
1347/// # Type Parameters
1348///
1349/// - `A`: Args tuple type (must implement Facet for deserialization)
1350/// - `R`: Result ok type (must implement Facet for serialization)
1351/// - `E`: User error type (must implement Facet for serialization)
1352/// - `F`: Handler closure type
1353/// - `Fut`: Future returned by handler
1354///
1355/// # Example
1356///
1357/// ```ignore
1358/// fn dispatch_echo(&self, payload: Vec<u8>, request_id: u64, registry: &mut ChannelRegistry)
1359///     -> Pin<Box<dyn Future<Output = ()> + Send + 'static>>
1360/// {
1361///     let handler = self.handler.clone();
1362///     dispatch_call(payload, request_id, registry, move |args: (String,)| async move {
1363///         handler.echo(args.0).await
1364///     })
1365/// }
1366/// ```
1367///
1368/// The handler returns `Result<R, E>` - user errors are automatically wrapped
1369/// in `RoamError::User(e)` for wire serialization.
1370///
1371/// The `channels` parameter contains channel IDs from the Request message framing.
1372/// These are patched into the deserialized args before binding streams.
1373pub fn dispatch_call<A, R, E, F, Fut>(
1374    cx: &Context,
1375    payload: Vec<u8>,
1376    registry: &mut ChannelRegistry,
1377    handler: F,
1378) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send + 'static>>
1379where
1380    A: Facet<'static> + Send,
1381    R: Facet<'static> + Send,
1382    E: Facet<'static> + Send,
1383    F: FnOnce(A) -> Fut + Send + 'static,
1384    Fut: std::future::Future<Output = Result<R, E>> + Send + 'static,
1385{
1386    let conn_id = cx.conn_id;
1387    let request_id = cx.request_id.raw();
1388    let channels = &cx.channels;
1389
1390    // Deserialize args
1391    let mut args: A = match facet_postcard::from_slice(&payload) {
1392        Ok(args) => args,
1393        Err(_) => {
1394            let task_tx = registry.driver_tx();
1395            return Box::pin(async move {
1396                // InvalidPayload error: Result::Err(1) + RoamError::InvalidPayload(2)
1397                let _ = task_tx
1398                    .send(DriverMessage::Response {
1399                        conn_id,
1400                        request_id,
1401                        channels: Vec::new(),
1402                        payload: vec![1, 2],
1403                    })
1404                    .await;
1405            });
1406        }
1407    };
1408
1409    // Patch channel IDs from Request framing into deserialized args
1410    debug!(channels = ?channels, "dispatch_call: patching channel IDs");
1411    patch_channel_ids(&mut args, channels);
1412
1413    // Bind streams via reflection - THIS MUST HAPPEN SYNCHRONOUSLY
1414    debug!("dispatch_call: binding streams SYNC");
1415    registry.bind_streams(&mut args);
1416    debug!("dispatch_call: streams bound SYNC - channels should now be registered");
1417
1418    let task_tx = registry.driver_tx();
1419    let dispatch_ctx = registry.dispatch_context();
1420
1421    // Use task_local scope so roam::channel() creates bound channels.
1422    // This is critical: unlike thread_local, task_local won't leak to other
1423    // tasks that happen to run on the same worker thread.
1424    Box::pin(DISPATCH_CONTEXT.scope(dispatch_ctx, async move {
1425        debug!("dispatch_call: handler ASYNC starting");
1426        let result = handler(args).await;
1427        debug!("dispatch_call: handler ASYNC finished");
1428        let (payload, response_channels) = match result {
1429            Ok(ref ok_result) => {
1430                // Collect channel IDs from the result (e.g., Rx<T> in return type)
1431                let channels = collect_channel_ids(ok_result);
1432                // Result::Ok(0) + serialized value
1433                let mut out = vec![0u8];
1434                match facet_postcard::to_vec(ok_result) {
1435                    Ok(bytes) => out.extend(bytes),
1436                    Err(_) => return,
1437                }
1438                (out, channels)
1439            }
1440            Err(user_error) => {
1441                // Result::Err(1) + RoamError::User(0) + serialized user error
1442                let mut out = vec![1u8, 0u8];
1443                match facet_postcard::to_vec(&user_error) {
1444                    Ok(bytes) => out.extend(bytes),
1445                    Err(_) => return,
1446                }
1447                (out, Vec::new())
1448            }
1449        };
1450
1451        // Send Response with channel IDs for any Rx<T> in the result.
1452        // ForwardingDispatcher uses these to set up Data forwarding.
1453        let _ = task_tx
1454            .send(DriverMessage::Response {
1455                conn_id,
1456                request_id,
1457                channels: response_channels,
1458                payload,
1459            })
1460            .await;
1461    }))
1462}
1463
1464/// Dispatch helper for infallible methods (those that return `T` instead of `Result<T, E>`).
1465///
1466/// Same as `dispatch_call` but for handlers that cannot fail at the application level.
1467pub fn dispatch_call_infallible<A, R, F, Fut>(
1468    cx: &Context,
1469    payload: Vec<u8>,
1470    registry: &mut ChannelRegistry,
1471    handler: F,
1472) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send + 'static>>
1473where
1474    A: Facet<'static> + Send,
1475    R: Facet<'static> + Send,
1476    F: FnOnce(A) -> Fut + Send + 'static,
1477    Fut: std::future::Future<Output = R> + Send + 'static,
1478{
1479    let conn_id = cx.conn_id;
1480    let request_id = cx.request_id.raw();
1481    let channels = &cx.channels;
1482
1483    // Deserialize args
1484    let mut args: A = match facet_postcard::from_slice(&payload) {
1485        Ok(args) => args,
1486        Err(_) => {
1487            let task_tx = registry.driver_tx();
1488            return Box::pin(async move {
1489                // InvalidPayload error: Result::Err(1) + RoamError::InvalidPayload(2)
1490                let _ = task_tx
1491                    .send(DriverMessage::Response {
1492                        conn_id,
1493                        request_id,
1494                        channels: Vec::new(),
1495                        payload: vec![1, 2],
1496                    })
1497                    .await;
1498            });
1499        }
1500    };
1501
1502    // Patch channel IDs from Request framing into deserialized args
1503    patch_channel_ids(&mut args, channels);
1504
1505    // Bind streams via reflection
1506    registry.bind_streams(&mut args);
1507
1508    let task_tx = registry.driver_tx();
1509    let dispatch_ctx = registry.dispatch_context();
1510
1511    // Use task_local scope so roam::channel() creates bound channels.
1512    Box::pin(DISPATCH_CONTEXT.scope(dispatch_ctx, async move {
1513        let result = handler(args).await;
1514
1515        // Collect channel IDs from the result (e.g., Rx<T> in return type)
1516        let response_channels = collect_channel_ids(&result);
1517        if !response_channels.is_empty() {
1518            debug!(
1519                channels = ?response_channels,
1520                "dispatch_call_infallible: collected response channels"
1521            );
1522        }
1523
1524        // Result::Ok(0) + serialized value
1525        let mut payload = vec![0u8];
1526        match facet_postcard::to_vec(&result) {
1527            Ok(bytes) => payload.extend(bytes),
1528            Err(_) => return,
1529        }
1530
1531        // Send Response with channel IDs for any Rx<T> in the result.
1532        // ForwardingDispatcher uses these to set up Data forwarding.
1533        let _ = task_tx
1534            .send(DriverMessage::Response {
1535                conn_id,
1536                request_id,
1537                channels: response_channels,
1538                payload,
1539            })
1540            .await;
1541    }))
1542}
1543
1544/// Send an "unknown method" error response.
1545///
1546/// Used by dispatchers when the method_id doesn't match any known method.
1547pub fn dispatch_unknown_method(
1548    cx: &Context,
1549    registry: &mut ChannelRegistry,
1550) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send + 'static>> {
1551    let conn_id = cx.conn_id;
1552    let request_id = cx.request_id.raw();
1553    let task_tx = registry.driver_tx();
1554    Box::pin(async move {
1555        // UnknownMethod error
1556        let _ = task_tx
1557            .send(DriverMessage::Response {
1558                conn_id,
1559                request_id,
1560                channels: Vec::new(),
1561                payload: vec![1, 1],
1562            })
1563            .await;
1564    })
1565}
1566
1567/// Collect channel IDs from args by walking with Peek.
1568///
1569/// Returns channel IDs in declaration order (depth-first traversal).
1570/// Used by the client to populate the `channels` vec in Request messages.
1571///
1572/// r[impl call.request.channels] - Collects channel IDs in declaration order for the Request.
1573pub fn collect_channel_ids<T: Facet<'static>>(args: &T) -> Vec<u64> {
1574    let mut ids = Vec::new();
1575    let poke = facet::Peek::new(args);
1576    collect_channel_ids_recursive(poke, &mut ids);
1577    ids
1578}
1579
/// Depth-first walk of `peek`, appending the `channel_id` of every `Rx`/`Tx` found to `ids`.
///
/// Shapes are tried in this order, and the first match wins:
/// 1. `Rx`/`Tx` (module path `roam_session` plus type identifier): push its `channel_id`.
///    If that field cannot be read, the stream is skipped silently.
/// 2. struct/tuple: visit every field in order.
/// 3. `Option`: visit the inner value if `Some`.
/// 4. other enums: visit only field 0 of the active variant.
/// 5. lists: visit every element.
///
/// NOTE(review): the resulting order must match the order in which `patch_channel_ids`
/// assigns IDs on the receiving side. Keep the two traversals in sync.
/// NOTE(review): enum variants with more than one field only have their first field
/// visited, so channels in later fields are not collected.
fn collect_channel_ids_recursive(peek: facet::Peek<'_, '_>, ids: &mut Vec<u64>) {
    let shape = peek.shape();

    // Check if this is an Rx or Tx type
    if shape.module_path == Some("roam_session")
        && (shape.type_identifier == "Rx" || shape.type_identifier == "Tx")
    {
        // Read the channel_id field
        if let Ok(ps) = peek.into_struct()
            && let Ok(channel_id_field) = ps.field_by_name("channel_id")
            && let Ok(&channel_id) = channel_id_field.get::<ChannelId>()
        {
            ids.push(channel_id);
        }
        // Never descend into an Rx/Tx itself, even if its channel_id was unreadable.
        return;
    }

    // Recurse into struct/tuple fields
    if let Ok(ps) = peek.into_struct() {
        let field_count = ps.field_count();
        for i in 0..field_count {
            if let Ok(field_peek) = ps.field(i) {
                collect_channel_ids_recursive(field_peek, ids);
            }
        }
        return;
    }

    // Recurse into Option<T> (specialized handling); checked before the generic enum case
    if let Ok(po) = peek.into_option() {
        if let Some(inner) = po.value() {
            collect_channel_ids_recursive(inner, ids);
        }
        return;
    }

    // Recurse into enum variants (for other enums with data)
    if let Ok(pe) = peek.into_enum() {
        // Try to get the first field of the active variant (e.g., Some(T) has one field)
        if let Ok(Some(variant_peek)) = pe.field(0) {
            collect_channel_ids_recursive(variant_peek, ids);
        }
        return;
    }

    // Recurse into sequences (e.g., Vec<Tx<T>>)
    if let Ok(pl) = peek.into_list() {
        for element in pl.iter() {
            collect_channel_ids_recursive(element, ids);
        }
    }
}
1632
1633/// Patch channel IDs into deserialized args by walking with Poke.
1634///
1635/// Overwrites channel_id fields in Rx/Tx in declaration order.
1636/// Used by the server to apply the authoritative `channels` vec from Request.
1637pub fn patch_channel_ids<T: Facet<'static>>(args: &mut T, channels: &[u64]) {
1638    debug!(channels = ?channels, "patch_channel_ids: patching channels from wire");
1639    let mut idx = 0;
1640    let poke = facet::Poke::new(args);
1641    patch_channel_ids_recursive(poke, channels, &mut idx);
1642}
1643
/// Depth-first helper for `patch_channel_ids`.
///
/// Writes `channels[*idx]` into the `channel_id` of each
/// `roam_session::Rx`/`Tx` it reaches and advances `*idx`. Once `*idx` reaches
/// `channels.len()`, further `Rx`/`Tx` values are left unchanged and `*idx`
/// stops advancing.
///
/// The traversal mirrors `collect_channel_ids_recursive`: struct/tuple fields,
/// `Option` payloads, list elements, and the first field of the active variant
/// of other enums.
// The crate root is `#![deny(unsafe_code)]`; this function opts out only for the
// raw list-element access below, because the code notes that facet offers no
// mutable list view ("no PokeList exists").
#[allow(unsafe_code)]
fn patch_channel_ids_recursive(mut poke: facet::Poke<'_, '_>, channels: &[u64], idx: &mut usize) {
    use facet::Def;

    let shape = poke.shape();

    // Check if this is an Rx or Tx type
    if shape.module_path == Some("roam_session")
        && (shape.type_identifier == "Rx" || shape.type_identifier == "Tx")
    {
        // Overwrite the channel_id field
        if let Ok(mut ps) = poke.into_struct()
            && let Ok(mut channel_id_field) = ps.field_by_name("channel_id")
            && let Ok(channel_id_ref) = channel_id_field.get_mut::<ChannelId>()
            && *idx < channels.len()
        {
            *channel_id_ref = channels[*idx];
            *idx += 1;
        }
        return;
    }

    // Dispatch based on the shape's definition
    match shape.def {
        Def::Scalar => {}

        // Recurse into struct/tuple fields
        _ if poke.is_struct() => {
            let mut ps = poke.into_struct().expect("is_struct was true");
            let field_count = ps.field_count();
            for i in 0..field_count {
                if let Ok(field_poke) = ps.field(i) {
                    patch_channel_ids_recursive(field_poke, channels, idx);
                }
            }
        }

        // Recurse into Option<T>
        Def::Option(_) => {
            // Option is represented as an enum, use into_enum to access its value
            if let Ok(mut pe) = poke.into_enum()
                && let Ok(Some(inner_poke)) = pe.field(0)
            {
                patch_channel_ids_recursive(inner_poke, channels, idx);
            }
        }

        // Recurse into list elements (e.g., Vec<Tx<T>>)
        Def::List(list_def) => {
            // Length is read through an immutable view before taking the raw data pointer.
            let len = {
                let peek = poke.as_peek();
                peek.into_list().map(|pl| pl.len()).unwrap_or(0)
            };
            // Get mutable access to elements via VTable (no PokeList exists)
            if let Some(get_mut_fn) = list_def.vtable.get_mut {
                let element_shape = list_def.t;
                let data_ptr = poke.data_mut();
                for i in 0..len {
                    // SAFETY: We have exclusive mutable access via poke, index < len, shape is correct
                    let element_ptr = unsafe { (get_mut_fn)(data_ptr, i, element_shape) };
                    if let Some(ptr) = element_ptr {
                        // SAFETY: ptr points to a valid element with the correct shape
                        let element_poke =
                            unsafe { facet::Poke::from_raw_parts(ptr, element_shape) };
                        patch_channel_ids_recursive(element_poke, channels, idx);
                    }
                }
            }
        }

        // Other enum variants
        _ if poke.is_enum() => {
            // NOTE(review): only field 0 of the active variant is patched, matching
            // `collect_channel_ids_recursive`; channels in later fields keep their
            // deserialized IDs.
            if let Ok(mut pe) = poke.into_enum()
                && let Ok(Some(variant_poke)) = pe.field(0)
            {
                patch_channel_ids_recursive(variant_poke, channels, idx);
            }
        }

        _ => {}
    }
}
1726
1727// ============================================================================
1728// Service Dispatcher
1729// ============================================================================
1730
/// Context passed to service method implementations.
///
/// Describes the request being handled:
/// - `conn_id`: which virtual connection the request came from
/// - `request_id`: the request's ID on that connection
/// - `method_id`: the method being called
/// - `metadata`: key-value pairs sent with the request
/// - `channels`: channel IDs from the Request, in argument declaration order
///
/// This enables services to identify callers and access per-request metadata.
#[derive(Debug, Clone)]
pub struct Context {
    /// The connection ID this request arrived on.
    ///
    /// For virtual connections, this identifies which specific connection
    /// the request came from, enabling bidirectional communication.
    pub conn_id: roam_wire::ConnectionId,

    /// The request ID for this call.
    ///
    /// Unique within the connection; used for response routing and cancellation.
    pub request_id: roam_wire::RequestId,

    /// The method ID being called.
    pub method_id: roam_wire::MethodId,

    /// Metadata sent with the request.
    ///
    /// This is the `metadata` field from the wire `Request` message.
    pub metadata: roam_wire::Metadata,

    /// Channel IDs from the request, in argument declaration order.
    ///
    /// Used for stream binding. Proxies can use this to remap channel IDs.
    pub channels: Vec<u64>,
}
1764
1765impl Context {
1766    /// Create a new context.
1767    pub fn new(
1768        conn_id: roam_wire::ConnectionId,
1769        request_id: roam_wire::RequestId,
1770        method_id: roam_wire::MethodId,
1771        metadata: roam_wire::Metadata,
1772        channels: Vec<u64>,
1773    ) -> Self {
1774        Self {
1775            conn_id,
1776            request_id,
1777            method_id,
1778            metadata,
1779            channels,
1780        }
1781    }
1782
1783    /// Get the connection ID.
1784    pub fn conn_id(&self) -> roam_wire::ConnectionId {
1785        self.conn_id
1786    }
1787
1788    /// Get the request ID.
1789    pub fn request_id(&self) -> roam_wire::RequestId {
1790        self.request_id
1791    }
1792
1793    /// Get the method ID.
1794    pub fn method_id(&self) -> roam_wire::MethodId {
1795        self.method_id
1796    }
1797
1798    /// Get the request metadata.
1799    pub fn metadata(&self) -> &roam_wire::Metadata {
1800        &self.metadata
1801    }
1802
1803    /// Get the channel IDs.
1804    pub fn channels(&self) -> &[u64] {
1805        &self.channels
1806    }
1807}
1808
/// Trait for dispatching requests to a service.
///
/// The dispatcher handles both simple and channeling methods uniformly.
/// Stream binding is done via reflection (Poke) on the deserialized args.
pub trait ServiceDispatcher: Send + Sync {
    /// Returns the method IDs this dispatcher handles.
    ///
    /// Used by [`RoutedDispatcher`] to determine which methods to route
    /// to which dispatcher. The forwarding dispatchers in this crate
    /// (`ForwardingDispatcher`, `LateBoundForwarder`) return an empty list even
    /// though they accept any method ID.
    fn method_ids(&self) -> Vec<u64>;

    /// Dispatch a request and send the response via the task channel.
    ///
    /// The dispatcher is responsible for:
    /// - Looking up the method by `cx.method_id()`
    /// - Deserializing arguments from payload
    /// - Patching channel IDs from `cx.channels()` into deserialized args via `patch_channel_ids()`
    /// - Binding any Tx/Rx streams via the registry
    /// - Calling the service method
    /// - Sending Data/Close messages for any Tx streams
    /// - Sending the Response message via DriverMessage::Response
    ///
    /// By using a single channel for Data/Close/Response, correct ordering is guaranteed:
    /// all stream Data and Close messages are sent before the Response.
    ///
    /// The `cx.channels()` contains channel IDs from the Request message framing,
    /// in declaration order. For a ForwardingDispatcher, this enables transparent proxying
    /// without parsing the payload.
    ///
    /// Returns a boxed future with `'static` lifetime so it can be spawned.
    /// Implementations should clone their service into the future to achieve this.
    ///
    /// r[impl channeling.allocation.caller] - Stream IDs are from Request.channels (caller allocated).
    fn dispatch(
        &self,
        cx: &Context,
        payload: Vec<u8>,
        registry: &mut ChannelRegistry,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send + 'static>>;
}
1849
/// A dispatcher that routes to one of two dispatchers based on method ID.
///
/// Methods handled by `primary` (via [`ServiceDispatcher::method_ids`]) are
/// routed to it; all other methods are routed to `fallback`. The primary's
/// method list is captured once, when the router is constructed.
pub struct RoutedDispatcher<A, B> {
    /// Receives every method listed in `primary_methods`.
    primary: A,
    /// Receives every method not listed in `primary_methods`.
    fallback: B,
    /// Snapshot of `primary.method_ids()` taken by `RoutedDispatcher::new`.
    primary_methods: Vec<u64>,
}
1859
1860impl<A, B> RoutedDispatcher<A, B>
1861where
1862    A: ServiceDispatcher,
1863{
1864    /// Create a new routed dispatcher.
1865    ///
1866    /// Methods declared by `primary.method_ids()` are routed to `primary`,
1867    /// all others to `fallback`.
1868    pub fn new(primary: A, fallback: B) -> Self {
1869        let primary_methods = primary.method_ids();
1870        Self {
1871            primary,
1872            fallback,
1873            primary_methods,
1874        }
1875    }
1876}
1877
1878impl<A, B> ServiceDispatcher for RoutedDispatcher<A, B>
1879where
1880    A: ServiceDispatcher,
1881    B: ServiceDispatcher,
1882{
1883    fn method_ids(&self) -> Vec<u64> {
1884        let mut ids = self.primary_methods.clone();
1885        ids.extend(self.fallback.method_ids());
1886        ids
1887    }
1888
1889    fn dispatch(
1890        &self,
1891        cx: &Context,
1892        payload: Vec<u8>,
1893        registry: &mut ChannelRegistry,
1894    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send + 'static>> {
1895        if self.primary_methods.contains(&cx.method_id().raw()) {
1896            self.primary.dispatch(cx, payload, registry)
1897        } else {
1898            self.fallback.dispatch(cx, payload, registry)
1899        }
1900    }
1901}
1902
1903// ============================================================================
1904// ForwardingDispatcher - Transparent RPC Proxy
1905// ============================================================================
1906
1907/// A dispatcher that forwards all requests to an upstream connection.
1908///
1909/// This enables transparent proxying without knowing the service schema.
1910/// Channel IDs are remapped automatically: the proxy allocates new channel IDs
1911/// for the upstream connection and maintains bidirectional forwarding.
1912///
1913/// # Example
1914///
1915/// ```ignore
1916/// use roam_session::{ForwardingDispatcher, ConnectionHandle};
1917///
1918/// // Upstream connection to the actual service
1919/// let upstream: ConnectionHandle = /* ... */;
1920///
1921/// // Create a forwarding dispatcher
1922/// let proxy = ForwardingDispatcher::new(upstream);
1923///
1924/// // Use with accept() - all calls will be forwarded to upstream
1925/// let (handle, driver) = accept(stream, config, proxy).await?;
1926/// ```
1927pub struct ForwardingDispatcher {
1928    upstream: ConnectionHandle,
1929}
1930
1931impl ForwardingDispatcher {
1932    /// Create a new forwarding dispatcher that proxies to the upstream connection.
1933    pub fn new(upstream: ConnectionHandle) -> Self {
1934        Self { upstream }
1935    }
1936}
1937
1938impl Clone for ForwardingDispatcher {
1939    fn clone(&self) -> Self {
1940        Self {
1941            upstream: self.upstream.clone(),
1942        }
1943    }
1944}
1945
impl ServiceDispatcher for ForwardingDispatcher {
    /// Returns empty - this dispatcher accepts all method IDs.
    ///
    /// Because the list is empty, a `RoutedDispatcher` with this as its
    /// `primary` would route nothing to it; it is meant to be the `fallback`.
    fn method_ids(&self) -> Vec<u64> {
        vec![]
    }

    /// Forward the request to `self.upstream` and relay its Response downstream.
    ///
    /// - No request channels (`cx.channels` empty): the call is forwarded as-is.
    ///   Each channel in the upstream Response gets a fresh downstream ID from
    ///   `registry.response_channel_ids()`. A spawned task per channel copies
    ///   Data upstream → downstream, then sends Close.
    /// - With request channels: each downstream ID is paired with a new upstream
    ///   ID from `upstream.alloc_channel_id()`. Two tasks per pair copy Data in
    ///   each direction, each ending with a Close. Response channel IDs are mapped
    ///   back through that pairing.
    ///
    /// Upstream transport errors are relayed as `Err(InvalidPayload)` (encode
    /// error) or `Err(Cancelled)` (connection closed / driver gone).
    fn dispatch(
        &self,
        cx: &Context,
        payload: Vec<u8>,
        registry: &mut ChannelRegistry,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send + 'static>> {
        let task_tx = registry.driver_tx();
        let upstream = self.upstream.clone();
        let conn_id = cx.conn_id;
        let method_id = cx.method_id.raw();
        let request_id = cx.request_id.raw();
        let channels = cx.channels.clone();

        if channels.is_empty() {
            // Unary call - but response may contain Rx<T> channels
            // We need to set up forwarding for any response channels.
            //
            // IMPORTANT: Upstream and downstream use different channel ID spaces.
            // The upstream channel IDs must be remapped to downstream channel IDs.
            let downstream_channel_ids = registry.response_channel_ids();

            Box::pin(async move {
                let response = upstream
                    .call_raw_with_channels(method_id, vec![], payload, None)
                    .await;

                let (response_payload, upstream_response_channels) = match response {
                    Ok(data) => (data.payload, data.channels),
                    Err(TransportError::Encode(_)) => {
                        // Should not happen for raw call
                        (vec![1, 2], Vec::new()) // Err(InvalidPayload)
                    }
                    Err(TransportError::ConnectionClosed) | Err(TransportError::DriverGone) => {
                        // Connection to upstream failed - return Cancelled
                        (vec![1, 3], Vec::new()) // Err(Cancelled)
                    }
                };

                // If response has channels (e.g., method returns Rx<T>),
                // set up forwarding for Data from upstream to downstream.
                // We allocate new downstream channel IDs and remap when forwarding.
                let mut downstream_channels = Vec::new();
                if !upstream_response_channels.is_empty() {
                    debug!(
                        upstream_channels = ?upstream_response_channels,
                        "ForwardingDispatcher: setting up response channel forwarding"
                    );
                    for &upstream_id in &upstream_response_channels {
                        // Allocate a downstream channel ID
                        let downstream_id = downstream_channel_ids.next();
                        downstream_channels.push(downstream_id);

                        debug!(
                            upstream_id,
                            downstream_id, "ForwardingDispatcher: mapping channel IDs"
                        );

                        // Set up forwarding: upstream → downstream
                        // NOTE(review): registration happens only after the upstream
                        // Response has been received; confirm the upstream driver does
                        // not drop Data that arrives for `upstream_id` before this.
                        let (tx, mut rx) = crate::runtime::channel::<Vec<u8>>(64);
                        upstream.register_incoming(upstream_id, tx);

                        let task_tx_clone = task_tx.clone();
                        // NOTE(review): this task is spawned before the downstream
                        // Response below is enqueued, so Data for `downstream_id` may
                        // reach `task_tx` ahead of the Response that announces the ID.
                        // Confirm the downstream peer tolerates that ordering.
                        crate::runtime::spawn(async move {
                            debug!(
                                upstream_id,
                                downstream_id, "ForwardingDispatcher: forwarding task started"
                            );
                            while let Some(data) = rx.recv().await {
                                debug!(
                                    upstream_id,
                                    downstream_id,
                                    data_len = data.len(),
                                    "ForwardingDispatcher: forwarding data"
                                );
                                let _ = task_tx_clone
                                    .send(DriverMessage::Data {
                                        conn_id,
                                        channel_id: downstream_id,
                                        payload: data,
                                    })
                                    .await;
                            }
                            debug!(
                                upstream_id,
                                downstream_id,
                                "ForwardingDispatcher: forwarding task ended, sending Close"
                            );
                            // Channel closed
                            let _ = task_tx_clone
                                .send(DriverMessage::Close {
                                    conn_id,
                                    channel_id: downstream_id,
                                })
                                .await;
                        });
                    }
                }

                let _ = task_tx
                    .send(DriverMessage::Response {
                        conn_id,
                        request_id,
                        channels: downstream_channels,
                        payload: response_payload,
                    })
                    .await;
            })
        } else {
            // Streaming call - set up bidirectional channel forwarding
            //
            // IMPORTANT: We must send the upstream Request BEFORE any Data is
            // forwarded, otherwise the backend will reject Data for unknown channels.
            //
            // Strategy:
            // 1. Register incoming handlers synchronously (buffers Data in mpsc channels)
            // 2. In the async block: send Request first, then spawn forwarding tasks
            //    (spawning AFTER Request is sent is safe - ordering is established)

            // Allocate upstream channel IDs and set up buffering channels
            let mut upstream_channels = Vec::with_capacity(channels.len());
            let mut ds_to_us_rxs = Vec::with_capacity(channels.len());
            let mut us_to_ds_rxs = Vec::with_capacity(channels.len());
            // (downstream_id, upstream_id) pairs, indexed in step with the rx vectors.
            let mut channel_map = Vec::with_capacity(channels.len());

            let upstream_task_tx = upstream.driver_tx();

            for &downstream_id in &channels {
                let upstream_id = upstream.alloc_channel_id();
                upstream_channels.push(upstream_id);
                channel_map.push((downstream_id, upstream_id));

                // Buffer for downstream → upstream (client sends Data)
                let (ds_to_us_tx, ds_to_us_rx) = crate::runtime::channel(64);
                registry.register_incoming(downstream_id, ds_to_us_tx);
                ds_to_us_rxs.push(ds_to_us_rx);

                // Buffer for upstream → downstream (server sends Data)
                let (us_to_ds_tx, us_to_ds_rx) = crate::runtime::channel(64);
                upstream.register_incoming(upstream_id, us_to_ds_tx);
                us_to_ds_rxs.push(us_to_ds_rx);
            }

            // Everything below runs in the async block
            Box::pin(async move {
                // Send the upstream Request - this queues the Request command
                // which will be sent before any Data we forward
                // NOTE(review): if `call_raw_with_channels` is an `async fn`, this
                // future is lazy and no Request is queued until `response_future.await`
                // below, which runs after the forwarding tasks are spawned. Confirm
                // the Request is enqueued eagerly when the future is created.
                let response_future =
                    upstream.call_raw_with_channels(method_id, upstream_channels, payload, None);

                // Now spawn forwarding tasks - safe because Request is queued first
                // and command_tx/task_tx are processed in order by the driver
                let upstream_conn_id = upstream.conn_id();
                for (i, mut rx) in ds_to_us_rxs.into_iter().enumerate() {
                    let upstream_id = channel_map[i].1;
                    let upstream_task_tx = upstream_task_tx.clone();
                    crate::runtime::spawn(async move {
                        while let Some(data) = rx.recv().await {
                            let _ = upstream_task_tx
                                .send(DriverMessage::Data {
                                    conn_id: upstream_conn_id,
                                    channel_id: upstream_id,
                                    payload: data,
                                })
                                .await;
                        }
                        // Channel closed
                        let _ = upstream_task_tx
                            .send(DriverMessage::Close {
                                conn_id: upstream_conn_id,
                                channel_id: upstream_id,
                            })
                            .await;
                    });
                }

                for (i, mut rx) in us_to_ds_rxs.into_iter().enumerate() {
                    let downstream_id = channel_map[i].0;
                    let task_tx = task_tx.clone();
                    crate::runtime::spawn(async move {
                        while let Some(data) = rx.recv().await {
                            let _ = task_tx
                                .send(DriverMessage::Data {
                                    conn_id,
                                    channel_id: downstream_id,
                                    payload: data,
                                })
                                .await;
                        }
                        // Channel closed
                        let _ = task_tx
                            .send(DriverMessage::Close {
                                conn_id,
                                channel_id: downstream_id,
                            })
                            .await;
                    });
                }

                // Wait for upstream response
                let response = response_future.await;

                let (response_payload, upstream_response_channels) = match response {
                    Ok(data) => (data.payload, data.channels),
                    Err(TransportError::Encode(_)) => {
                        (vec![1, 2], Vec::new()) // Err(InvalidPayload)
                    }
                    Err(TransportError::ConnectionClosed) | Err(TransportError::DriverGone) => {
                        (vec![1, 3], Vec::new()) // Err(Cancelled)
                    }
                };

                // Map upstream response channels back to downstream channel IDs.
                // The downstream client allocated the original IDs and expects them
                // in the Response, not the upstream IDs we allocated for forwarding.
                // NOTE(review): any upstream response channel with no entry in
                // `channel_map` is dropped here, and no forwarding is set up for it.
                let downstream_response_channels: Vec<u64> = upstream_response_channels
                    .iter()
                    .filter_map(|&upstream_id| {
                        channel_map
                            .iter()
                            .find(|(_, us)| *us == upstream_id)
                            .map(|(ds, _)| *ds)
                    })
                    .collect();

                let _ = task_tx
                    .send(DriverMessage::Response {
                        conn_id,
                        request_id,
                        channels: downstream_response_channels,
                        payload: response_payload,
                    })
                    .await;
            })
        }
    }
}
2188
2189// ============================================================================
2190// LateBoundForwarder - Forwarding with Deferred Handle Binding
2191// ============================================================================
2192
/// A handle that can be set once after creation.
///
/// This solves the chicken-and-egg problem in bidirectional proxying where:
/// 1. You need to pass a dispatcher to `connect()` for reverse-direction calls
/// 2. But the dispatcher needs a handle that's only available after `accept_framed()`
///
/// Clones share the same slot, so binding through any clone is visible to all.
///
/// # Example
///
/// ```ignore
/// // Create the late-bound handle (empty initially)
/// let late_bound = LateBoundHandle::new();
///
/// // Pass a forwarder using this handle to connect()
/// let virtual_conn = handle.connect(
///     metadata,
///     Some(Box::new(LateBoundForwarder::new(late_bound.clone()))),
/// ).await?;
///
/// // Accept the other connection to get its handle
/// let (browser_handle, driver) = accept_framed(transport, config, dispatcher).await?;
///
/// // NOW bind the handle - any incoming calls will be forwarded
/// late_bound.set(browser_handle);
/// ```
#[derive(Clone)]
pub struct LateBoundHandle {
    /// Shared write-once slot holding the connection once it is bound.
    inner: Arc<std::sync::OnceLock<ConnectionHandle>>,
}
2221
2222impl LateBoundHandle {
2223    /// Create a new unbound handle.
2224    pub fn new() -> Self {
2225        Self {
2226            inner: Arc::new(std::sync::OnceLock::new()),
2227        }
2228    }
2229
2230    /// Bind the handle to a connection. Can only be called once.
2231    ///
2232    /// # Panics
2233    ///
2234    /// Panics if called more than once.
2235    pub fn set(&self, handle: ConnectionHandle) {
2236        if self.inner.set(handle).is_err() {
2237            panic!("LateBoundHandle::set called more than once");
2238        }
2239    }
2240
2241    /// Try to get the bound handle, if set.
2242    pub fn get(&self) -> Option<&ConnectionHandle> {
2243        self.inner.get()
2244    }
2245}
2246
2247impl Default for LateBoundHandle {
2248    fn default() -> Self {
2249        Self::new()
2250    }
2251}
2252
2253/// A dispatcher that forwards all requests to a late-bound upstream connection.
2254///
2255/// Like [`ForwardingDispatcher`], but the upstream handle is provided after creation
2256/// via [`LateBoundHandle::set`]. This enables bidirectional proxying scenarios.
2257///
2258/// If a request arrives before the handle is bound, it returns `Cancelled`.
2259///
2260/// # Example
2261///
2262/// ```ignore
2263/// // Create late-bound handle and forwarder
2264/// let late_bound = LateBoundHandle::new();
2265/// let forwarder = LateBoundForwarder::new(late_bound.clone());
2266///
2267/// // Use forwarder with connect() for reverse-direction calls
2268/// let virtual_conn = handle.connect(metadata, Some(Box::new(forwarder))).await?;
2269///
2270/// // Later, bind the actual handle
2271/// let (browser_handle, driver) = accept_framed(...).await?;
2272/// late_bound.set(browser_handle);
2273/// ```
2274pub struct LateBoundForwarder {
2275    upstream: LateBoundHandle,
2276}
2277
2278impl LateBoundForwarder {
2279    /// Create a new late-bound forwarding dispatcher.
2280    pub fn new(upstream: LateBoundHandle) -> Self {
2281        Self { upstream }
2282    }
2283}
2284
2285impl Clone for LateBoundForwarder {
2286    fn clone(&self) -> Self {
2287        Self {
2288            upstream: self.upstream.clone(),
2289        }
2290    }
2291}
2292
2293impl ServiceDispatcher for LateBoundForwarder {
2294    fn method_ids(&self) -> Vec<u64> {
2295        vec![]
2296    }
2297
2298    fn dispatch(
2299        &self,
2300        cx: &Context,
2301        payload: Vec<u8>,
2302        registry: &mut ChannelRegistry,
2303    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send + 'static>> {
2304        let task_tx = registry.driver_tx();
2305        let conn_id = cx.conn_id;
2306        let request_id = cx.request_id.raw();
2307
2308        // Try to get the upstream handle
2309        let Some(upstream) = self.upstream.get().cloned() else {
2310            // Handle not bound yet - return Cancelled
2311            debug!(
2312                method_id = cx.method_id.raw(),
2313                "LateBoundForwarder: upstream not bound, returning Cancelled"
2314            );
2315            return Box::pin(async move {
2316                let _ = task_tx
2317                    .send(DriverMessage::Response {
2318                        conn_id,
2319                        request_id,
2320                        channels: vec![],
2321                        payload: vec![1, 3], // Err(Cancelled)
2322                    })
2323                    .await;
2324            });
2325        };
2326
2327        // Delegate to ForwardingDispatcher now that we have the handle
2328        ForwardingDispatcher::new(upstream).dispatch(cx, payload, registry)
2329    }
2330}
2331
// TODO: Remove this shim once facet implements `Facet` for `core::convert::Infallible`
// and for the never type `!` (facet-rs/facet#1668), then use `Infallible`.
/// Stand-in user error type for methods that have no application error.
///
/// It is the default `E` of [`CallError`]. Unlike `Infallible`, this is a unit
/// struct and can be constructed, so the type system does not rule out a
/// `RoamError::User(Never)` value.
#[derive(Debug, Clone, PartialEq, Eq, Facet)]
pub struct Never;
2336
/// Call error type encoded in RPC responses.
///
/// The explicit discriminants are the ones that appear on the wire; for example,
/// the forwarding dispatchers in this crate emit `Err(Cancelled)` as the bytes `[1, 3]`.
///
/// r[impl core.error.roam-error] - Wraps call results to distinguish app vs protocol errors
/// r[impl call.response.encoding] - Response is `Result<T, RoamError<E>>`
/// r[impl call.error.roam-error] - Protocol errors use RoamError variants
/// r[impl call.error.protocol] - Discriminants 1-3 are protocol-level errors
///
/// Spec: `docs/content/spec/_index.md` "RoamError".
#[repr(u8)]
#[derive(Debug, Clone, PartialEq, Eq, Facet)]
pub enum RoamError<E> {
    /// r[impl core.error.call-vs-connection] - User errors affect only this call
    /// r[impl call.error.user] - User(E) carries the application's error type
    User(E) = 0,
    /// r[impl call.error.unknown-method] - Method ID not recognized
    UnknownMethod = 1,
    /// r[impl call.error.invalid-payload] - Request payload deserialization failed
    InvalidPayload = 2,
    /// The call did not complete. In this crate the forwarding dispatchers
    /// return it when the upstream connection is closed, its driver is gone,
    /// or a late-bound upstream has not been bound yet.
    Cancelled = 3,
}
2357
2358impl<E> RoamError<E> {
2359    /// Map the user error type to a different type.
2360    pub fn map_user<F, E2>(self, f: F) -> RoamError<E2>
2361    where
2362        F: FnOnce(E) -> E2,
2363    {
2364        match self {
2365            RoamError::User(e) => RoamError::User(f(e)),
2366            RoamError::UnknownMethod => RoamError::UnknownMethod,
2367            RoamError::InvalidPayload => RoamError::InvalidPayload,
2368            RoamError::Cancelled => RoamError::Cancelled,
2369        }
2370    }
2371}
2372
/// Typed outcome of an RPC call as carried in a response: the method's value,
/// or a [`RoamError`] wrapping either the user error `E` or a protocol error.
pub type CallResult<T, E> = ::core::result::Result<T, RoamError<E>>;
/// A [`CallResult`] wrapped in [`OwnedMessage`] (re-exported from `roam_frame`).
pub type BorrowedCallResult<T, E> = OwnedMessage<CallResult<T, E>>;
2375
2376// ============================================================================
2377// Connection Handle (Client-side API)
2378// ============================================================================
2379
/// Error from making an outgoing call.
///
/// This flattens the nested `Result<Result<T, RoamError<E>>, CallError>` pattern
/// into a single `Result<T, CallError<E>>` for better ergonomics.
///
/// The type parameter `E` represents the user's error type from fallible methods.
/// For infallible methods, use `CallError<Never>`.
///
/// Transport failures are lifted in via `From<TransportError>`, and response
/// framing failures via `From<DecodeError>`.
#[derive(Debug)]
pub enum CallError<E = Never> {
    /// The remote returned a roam-level error (user error or protocol error).
    Roam(RoamError<E>),
    /// Failed to encode request payload.
    Encode(facet_postcard::SerializeError),
    /// Failed to decode response payload (the `Ok` value or the user error `E`).
    Decode(facet_postcard::DeserializeError<facet_postcard::PostcardError>),
    /// Protocol-level decode error (malformed response structure).
    Protocol(DecodeError),
    /// Connection was closed before response.
    ConnectionClosed,
    /// Driver task is gone.
    DriverGone,
}
2402
2403impl<E> CallError<E> {
2404    /// Map the user error type to a different type.
2405    pub fn map_user<F, E2>(self, f: F) -> CallError<E2>
2406    where
2407        F: FnOnce(E) -> E2,
2408    {
2409        match self {
2410            CallError::Roam(roam_err) => CallError::Roam(roam_err.map_user(f)),
2411            CallError::Encode(e) => CallError::Encode(e),
2412            CallError::Decode(e) => CallError::Decode(e),
2413            CallError::Protocol(e) => CallError::Protocol(e),
2414            CallError::ConnectionClosed => CallError::ConnectionClosed,
2415            CallError::DriverGone => CallError::DriverGone,
2416        }
2417    }
2418}
2419
2420impl<E: std::fmt::Debug> std::fmt::Display for CallError<E> {
2421    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2422        match self {
2423            CallError::Roam(e) => write!(f, "roam error: {e:?}"),
2424            CallError::Encode(e) => write!(f, "encode error: {e}"),
2425            CallError::Decode(e) => write!(f, "decode error: {e}"),
2426            CallError::Protocol(e) => write!(f, "protocol error: {e}"),
2427            CallError::ConnectionClosed => write!(f, "connection closed"),
2428            CallError::DriverGone => write!(f, "driver task stopped"),
2429        }
2430    }
2431}
2432
2433impl<E: std::fmt::Debug> std::error::Error for CallError<E> {}
2434
/// Transport-level call error (no user error type).
///
/// Used by the `Caller` trait which operates at the transport level
/// before response decoding.
///
/// Each variant maps onto the identically named `CallError` variant via
/// `From<TransportError> for CallError<E>`.
#[derive(Debug)]
pub enum TransportError {
    /// Failed to encode request payload.
    Encode(facet_postcard::SerializeError),
    /// Connection was closed before response.
    ConnectionClosed,
    /// Driver task is gone.
    DriverGone,
}
2448
2449impl<E> From<TransportError> for CallError<E> {
2450    fn from(e: TransportError) -> Self {
2451        match e {
2452            TransportError::Encode(e) => CallError::Encode(e),
2453            TransportError::ConnectionClosed => CallError::ConnectionClosed,
2454            TransportError::DriverGone => CallError::DriverGone,
2455        }
2456    }
2457}
2458
2459impl std::fmt::Display for TransportError {
2460    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2461        match self {
2462            TransportError::Encode(e) => write!(f, "encode error: {e}"),
2463            TransportError::ConnectionClosed => write!(f, "connection closed"),
2464            TransportError::DriverGone => write!(f, "driver task stopped"),
2465        }
2466    }
2467}
2468
2469impl std::error::Error for TransportError {}
2470
/// Error decoding a response payload.
///
/// Apart from `Postcard`, these describe a malformed `Result<T, RoamError<E>>`
/// envelope as checked by `decode_response`. `From<DecodeError> for CallError<E>`
/// turns `Postcard` into `CallError::Decode` and every other variant into
/// `CallError::Protocol`.
#[derive(Debug)]
pub enum DecodeError {
    /// Empty response payload (no `Result` tag byte at all).
    EmptyPayload,
    /// Truncated error response (`Err` tag present but no `RoamError` discriminant byte).
    TruncatedError,
    /// Unknown RoamError discriminant (the byte after the `Err` tag is not 0-3).
    UnknownRoamErrorDiscriminant(u8),
    /// Invalid Result discriminant (the first byte is neither 0/`Ok` nor 1/`Err`).
    InvalidResultDiscriminant(u8),
    /// Postcard deserialization error.
    Postcard(facet_postcard::DeserializeError<facet_postcard::PostcardError>),
}
2485
2486impl std::fmt::Display for DecodeError {
2487    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2488        match self {
2489            DecodeError::EmptyPayload => write!(f, "empty response payload"),
2490            DecodeError::TruncatedError => write!(f, "truncated error response"),
2491            DecodeError::UnknownRoamErrorDiscriminant(d) => {
2492                write!(f, "unknown RoamError discriminant: {d}")
2493            }
2494            DecodeError::InvalidResultDiscriminant(d) => {
2495                write!(f, "invalid Result discriminant: {d}")
2496            }
2497            DecodeError::Postcard(e) => write!(f, "postcard: {e}"),
2498        }
2499    }
2500}
2501
2502impl std::error::Error for DecodeError {}
2503
2504impl<E> From<DecodeError> for CallError<E> {
2505    fn from(e: DecodeError) -> Self {
2506        match e {
2507            DecodeError::Postcard(pe) => CallError::Decode(pe),
2508            other => CallError::Protocol(other),
2509        }
2510    }
2511}
2512
2513/// Decode a response payload into the expected type.
2514///
2515/// This is the core response decoding logic used by generated clients.
2516/// It handles the wire format: `[0] + value_bytes` for Ok, `[1, discriminant] + error_bytes` for Err.
2517///
2518/// Returns `Result<T, CallError<E>>` with the decoded value or error.
2519pub fn decode_response<T: Facet<'static>, E: Facet<'static>>(
2520    payload: &[u8],
2521) -> Result<T, CallError<E>> {
2522    if payload.is_empty() {
2523        return Err(DecodeError::EmptyPayload.into());
2524    }
2525
2526    match payload[0] {
2527        0 => {
2528            // Ok variant: deserialize the value
2529            facet_postcard::from_slice(&payload[1..]).map_err(CallError::Decode)
2530        }
2531        1 => {
2532            // Err variant: deserialize RoamError<E>
2533            if payload.len() < 2 {
2534                return Err(DecodeError::TruncatedError.into());
2535            }
2536            let roam_error = match payload[1] {
2537                0 => {
2538                    // User error
2539                    let user_error: E =
2540                        facet_postcard::from_slice(&payload[2..]).map_err(CallError::Decode)?;
2541                    RoamError::User(user_error)
2542                }
2543                1 => RoamError::UnknownMethod,
2544                2 => RoamError::InvalidPayload,
2545                3 => RoamError::Cancelled,
2546                d => return Err(DecodeError::UnknownRoamErrorDiscriminant(d).into()),
2547            };
2548            Err(CallError::Roam(roam_error))
2549        }
2550        d => Err(DecodeError::InvalidResultDiscriminant(d).into()),
2551    }
2552}
2553
/// Trait for making RPC calls.
///
/// This abstracts over different connection types (e.g., `ConnectionHandle`,
/// `ReconnectingClient`) so generated clients can work with any of them.
///
/// All callers return `TransportError` for transport-level failures.
/// Generated clients convert this to `CallError<E>` which also includes
/// response-level errors like `RoamError::User(E)`.
///
/// Each call method is declared twice under `cfg(target_arch = "wasm32")`: on
/// native targets the returned future must be `Send`; on wasm it need not be.
/// Implementors only have to provide `call_with_metadata` and
/// `bind_response_streams`; `call` forwards to `call_with_metadata` with
/// default metadata.
// NOTE(review): the trait's methods are declared with `-> impl Future` rather
// than `async fn`; confirm whether this allow is still required.
#[allow(async_fn_in_trait)]
pub trait Caller: Clone + Send + Sync + 'static {
    /// Make an RPC call with the given method ID and arguments.
    ///
    /// The arguments are mutable because stream bindings (Tx/Rx) need to be
    /// assigned channel IDs before serialization.
    ///
    /// Returns ResponseData containing the payload and any response channel IDs.
    #[cfg(not(target_arch = "wasm32"))]
    fn call<T: Facet<'static> + Send>(
        &self,
        method_id: u64,
        args: &mut T,
    ) -> impl std::future::Future<Output = Result<ResponseData, TransportError>> + Send {
        self.call_with_metadata(method_id, args, roam_wire::Metadata::default())
    }

    /// Make an RPC call with the given method ID and arguments.
    ///
    /// The arguments are mutable because stream bindings (Tx/Rx) need to be
    /// assigned channel IDs before serialization.
    ///
    /// Returns ResponseData containing the payload and any response channel IDs.
    #[cfg(target_arch = "wasm32")]
    fn call<T: Facet<'static> + Send>(
        &self,
        method_id: u64,
        args: &mut T,
    ) -> impl std::future::Future<Output = Result<ResponseData, TransportError>> {
        self.call_with_metadata(method_id, args, roam_wire::Metadata::default())
    }

    /// Make an RPC call with the given method ID, arguments, and metadata.
    ///
    /// The arguments are mutable because stream bindings (Tx/Rx) need to be
    /// assigned channel IDs before serialization.
    ///
    /// Returns ResponseData containing the payload and any response channel IDs.
    #[cfg(not(target_arch = "wasm32"))]
    fn call_with_metadata<T: Facet<'static> + Send>(
        &self,
        method_id: u64,
        args: &mut T,
        metadata: roam_wire::Metadata,
    ) -> impl std::future::Future<Output = Result<ResponseData, TransportError>> + Send;

    /// Make an RPC call with the given method ID, arguments, and metadata.
    ///
    /// The arguments are mutable because stream bindings (Tx/Rx) need to be
    /// assigned channel IDs before serialization.
    ///
    /// Returns ResponseData containing the payload and any response channel IDs.
    #[cfg(target_arch = "wasm32")]
    fn call_with_metadata<T: Facet<'static> + Send>(
        &self,
        method_id: u64,
        args: &mut T,
        metadata: roam_wire::Metadata,
    ) -> impl std::future::Future<Output = Result<ResponseData, TransportError>>;

    /// Bind receivers for `Rx<T>` streams in the response.
    ///
    /// After deserializing a response, any `Rx<T>` values in it are "hollow" -
    /// they have channel IDs but no actual receiver. This method walks the
    /// response and binds receivers for each Rx using the channel IDs from
    /// the Response message.
    fn bind_response_streams<T: Facet<'static>>(&self, response: &mut T, channels: &[u64]);
}
2630
/// `Caller` implementation backed directly by a [`ConnectionHandle`].
///
/// Both methods forward to the inherent `ConnectionHandle` methods of the same
/// name. The explicit `ConnectionHandle::...` path selects the inherent method,
/// because inherent associated items take precedence over trait items in path
/// lookup, so this does not recurse into itself.
impl Caller for ConnectionHandle {
    async fn call_with_metadata<T: Facet<'static> + Send>(
        &self,
        method_id: u64,
        args: &mut T,
        metadata: roam_wire::Metadata,
    ) -> Result<ResponseData, TransportError> {
        ConnectionHandle::call_with_metadata(self, method_id, args, metadata).await
    }

    fn bind_response_streams<T: Facet<'static>>(&self, response: &mut T, channels: &[u64]) {
        ConnectionHandle::bind_response_streams(self, response, channels)
    }
}
2645
2646// ============================================================================
2647// CallFuture - Builder pattern for RPC calls with optional metadata
2648// ============================================================================
2649
2650/// A future representing an RPC call that can be configured with metadata.
2651///
2652/// This provides a builder pattern for RPC calls:
2653/// - `client.method(args).await` - Simple call with default (empty) metadata
2654/// - `client.method(args).with_metadata(meta).await` - Call with custom metadata
2655///
2656/// The future is lazy - the RPC call is not made until `.await` is called.
2657///
2658/// # Example
2659///
2660/// ```ignore
2661/// // Simple call
2662/// let result = client.subscribe(route).await?;
2663///
2664/// // With metadata
2665/// let result = client.subscribe(route)
2666///     .with_metadata(vec![("trace-id".into(), MetadataValue::String("abc".into()))])
2667///     .await?;
2668/// ```
2669pub struct CallFuture<C, Args, Ok, Err>
2670where
2671    C: Caller,
2672    Args: Facet<'static>,
2673{
2674    caller: C,
2675    method_id: u64,
2676    args: Args,
2677    metadata: roam_wire::Metadata,
2678    _phantom: PhantomData<fn() -> (Ok, Err)>,
2679}
2680
2681impl<C, Args, Ok, Err> CallFuture<C, Args, Ok, Err>
2682where
2683    C: Caller,
2684    Args: Facet<'static>,
2685{
2686    /// Create a new CallFuture.
2687    pub fn new(caller: C, method_id: u64, args: Args) -> Self {
2688        Self {
2689            caller,
2690            method_id,
2691            args,
2692            metadata: roam_wire::Metadata::default(),
2693            _phantom: PhantomData,
2694        }
2695    }
2696
2697    /// Set metadata for this call.
2698    ///
2699    /// Metadata is a list of key-value pairs that will be sent with the request.
2700    /// The server can access this via `Context::metadata()`.
2701    pub fn with_metadata(mut self, metadata: roam_wire::Metadata) -> Self {
2702        self.metadata = metadata;
2703        self
2704    }
2705}
2706
2707// On native, the future must be Send so it can be spawned on tokio.
2708// On WASM, futures don't need Send since everything is single-threaded.
2709#[cfg(not(target_arch = "wasm32"))]
2710impl<C, Args, Ok, Err> std::future::IntoFuture for CallFuture<C, Args, Ok, Err>
2711where
2712    C: Caller,
2713    Args: Facet<'static> + Send + 'static,
2714    Ok: Facet<'static> + Send + 'static,
2715    Err: Facet<'static> + Send + 'static,
2716{
2717    type Output = Result<Ok, CallError<Err>>;
2718    type IntoFuture = std::pin::Pin<Box<dyn std::future::Future<Output = Self::Output> + Send>>;
2719
2720    fn into_future(self) -> Self::IntoFuture {
2721        let CallFuture {
2722            caller,
2723            method_id,
2724            mut args,
2725            metadata,
2726            _phantom,
2727        } = self;
2728
2729        Box::pin(async move {
2730            let response = caller
2731                .call_with_metadata(method_id, &mut args, metadata)
2732                .await
2733                .map_err(CallError::from)?;
2734            let mut result = decode_response::<Ok, Err>(&response.payload)?;
2735            caller.bind_response_streams(&mut result, &response.channels);
2736            Ok(result)
2737        })
2738    }
2739}
2740
#[cfg(target_arch = "wasm32")]
impl<C, Args, Ok, Err> std::future::IntoFuture for CallFuture<C, Args, Ok, Err>
where
    C: Caller,
    Args: Facet<'static> + Send + 'static,
    Ok: Facet<'static> + Send + 'static,
    Err: Facet<'static> + Send + 'static,
{
    type Output = Result<Ok, CallError<Err>>;
    type IntoFuture = std::pin::Pin<Box<dyn std::future::Future<Output = Self::Output>>>;

    /// wasm32 variant: identical flow to the native impl, but the boxed future
    /// is not required to be `Send`.
    fn into_future(self) -> Self::IntoFuture {
        // Move the builder's parts into the boxed future.
        let Self {
            caller,
            method_id,
            mut args,
            metadata,
            ..
        } = self;

        Box::pin(async move {
            let raw_response = caller
                .call_with_metadata(method_id, &mut args, metadata)
                .await
                .map_err(CallError::from)?;
            let mut decoded = decode_response::<Ok, Err>(&raw_response.payload)?;
            caller.bind_response_streams(&mut decoded, &raw_response.channels);
            Ok(decoded)
        })
    }
}
2772
/// Shared state between ConnectionHandle and Driver.
///
/// Held behind an `Arc` by every clone of [`ConnectionHandle`], so all clones
/// share one request-ID generator, one channel-ID allocator and one channel
/// registry.
struct HandleShared {
    /// Connection ID for this handle (0 = root connection).
    conn_id: roam_wire::ConnectionId,
    /// Unified channel to send all messages to the driver.
    driver_tx: Sender<DriverMessage>,
    /// Request ID generator.
    request_ids: RequestIdGenerator,
    /// Stream ID allocator.
    channel_ids: ChannelIdAllocator,
    /// Stream registry for routing incoming data.
    /// Protected by a mutex since handles may create streams concurrently.
    channel_registry: std::sync::Mutex<ChannelRegistry>,
    /// Optional diagnostic state for SIGUSR1 dumps.
    diagnostic_state: Option<Arc<crate::diagnostic::DiagnosticState>>,
}
2789
/// Handle for making outgoing RPC calls.
///
/// This is the client-side API. It can be cloned and used from multiple tasks.
/// The actual I/O is driven by the `Driver` future which must be spawned.
///
/// Cloning is cheap: every clone points at the same `Arc<HandleShared>`, so
/// request IDs and channel IDs stay unique across clones.
///
/// # Example
///
/// ```ignore
/// let (handle, driver) = establish_connection(transport, dispatcher).await?;
/// tokio::spawn(driver);
///
/// // Use handle to make calls
/// let response = handle.call_raw(method_id, payload).await?;
/// ```
#[derive(Clone)]
pub struct ConnectionHandle {
    shared: Arc<HandleShared>,
}
2808
2809impl ConnectionHandle {
2810    /// Create a new handle for the root connection (conn_id = 0).
2811    ///
2812    /// All messages (Call/Data/Close/Response) go through a single unified channel
2813    /// to ensure FIFO ordering.
2814    pub fn new(driver_tx: Sender<DriverMessage>, role: Role, initial_credit: u32) -> Self {
2815        Self::new_with_diagnostics(
2816            roam_wire::ConnectionId::ROOT,
2817            driver_tx,
2818            role,
2819            initial_credit,
2820            None,
2821        )
2822    }
2823
2824    /// Create a new handle with a specific connection ID and optional diagnostic state.
2825    ///
2826    /// If `diagnostic_state` is provided, all RPC calls and channels will be tracked
2827    /// for debugging purposes.
2828    pub fn new_with_diagnostics(
2829        conn_id: roam_wire::ConnectionId,
2830        driver_tx: Sender<DriverMessage>,
2831        role: Role,
2832        initial_credit: u32,
2833        diagnostic_state: Option<Arc<crate::diagnostic::DiagnosticState>>,
2834    ) -> Self {
2835        let channel_registry = ChannelRegistry::new_with_credit(initial_credit, driver_tx.clone());
2836        Self {
2837            shared: Arc::new(HandleShared {
2838                conn_id,
2839                driver_tx,
2840                request_ids: RequestIdGenerator::new(),
2841                channel_ids: ChannelIdAllocator::new(role),
2842                channel_registry: std::sync::Mutex::new(channel_registry),
2843                diagnostic_state,
2844            }),
2845        }
2846    }
2847
2848    /// Get the connection ID for this handle.
2849    pub fn conn_id(&self) -> roam_wire::ConnectionId {
2850        self.shared.conn_id
2851    }
2852
2853    /// Get the diagnostic state, if any.
2854    pub fn diagnostic_state(&self) -> Option<&Arc<crate::diagnostic::DiagnosticState>> {
2855        self.shared.diagnostic_state.as_ref()
2856    }
2857
2858    /// Make a typed RPC call with automatic serialization and stream binding.
2859    ///
2860    /// Walks the args using Poke reflection to find any `Rx<T>` or `Tx<T>` fields,
2861    /// binds stream IDs, and sets up the stream infrastructure before serialization.
2862    ///
2863    /// # Arguments
2864    ///
2865    /// * `method_id` - The method ID to call
2866    /// * `args` - Arguments to serialize (typically a tuple of all method args).
2867    ///   Must be mutable so stream IDs can be assigned.
2868    ///
2869    /// # Stream Binding
2870    ///
2871    /// For `Rx<T>` in args (caller passes receiver, keeps sender to push data):
2872    /// - Allocates a stream ID
2873    /// - Takes the receiver and spawns a task to drain it, sending Data messages
2874    /// - The caller keeps the `Tx<T>` from `roam::channel()` to send values
2875    ///
2876    /// For `Tx<T>` in args (caller passes sender, keeps receiver to pull data):
2877    /// - Allocates a stream ID
2878    /// - Takes the sender and registers for incoming Data routing
2879    /// - The caller keeps the `Rx<T>` from `roam::channel()` to receive values
2880    ///
2881    /// # Example
2882    ///
2883    /// ```ignore
2884    /// // For a streaming method sum(numbers: Rx<i32>) -> i64
2885    /// let (tx, rx) = roam::channel::<i32>();
2886    /// let response = handle.call(method_id::SUM, &mut (rx,)).await?;
2887    /// // tx.send(&42).await to push values
2888    /// ```
2889    /// Make an RPC call with default (empty) metadata.
2890    pub async fn call<T: Facet<'static>>(
2891        &self,
2892        method_id: u64,
2893        args: &mut T,
2894    ) -> Result<ResponseData, TransportError> {
2895        self.call_with_metadata(method_id, args, roam_wire::Metadata::default())
2896            .await
2897    }
2898
2899    /// Make an RPC call with custom metadata.
2900    pub async fn call_with_metadata<T: Facet<'static>>(
2901        &self,
2902        method_id: u64,
2903        args: &mut T,
2904        metadata: roam_wire::Metadata,
2905    ) -> Result<ResponseData, TransportError> {
2906        // Walk args and bind any streams (allocates channel IDs)
2907        // This collects receivers that need to be drained but does NOT spawn
2908        let mut drains = Vec::new();
2909        debug!("ConnectionHandle::call: binding streams");
2910        self.bind_streams(args, &mut drains);
2911
2912        // Collect channel IDs for the Request message
2913        let channels = collect_channel_ids(args);
2914        debug!(
2915            channels = ?channels,
2916            drain_count = drains.len(),
2917            "ConnectionHandle::call: collected channels after bind_streams"
2918        );
2919
2920        let payload = facet_postcard::to_vec(args).map_err(TransportError::Encode)?;
2921
2922        // Generate args debug info for diagnostics when enabled
2923        let args_debug = if diagnostic::debug_enabled() {
2924            Some(
2925                facet_pretty::PrettyPrinter::new()
2926                    .with_colors(facet_pretty::ColorMode::Never)
2927                    .with_max_content_len(64)
2928                    .format(args),
2929            )
2930        } else {
2931            None
2932        };
2933
2934        if drains.is_empty() {
2935            // No Rx streams - simple call
2936            self.call_raw_with_channels_and_metadata(
2937                method_id, channels, payload, args_debug, metadata,
2938            )
2939            .await
2940        } else {
2941            // Has Rx streams - spawn tasks to drain them
2942            // IMPORTANT: We must send Request BEFORE spawning drain tasks to ensure ordering.
2943            // We need to actually send the DriverMessage::Call to the driver's queue
2944            // before spawning drains, not just create the future.
2945            let request_id = self.shared.request_ids.next();
2946            let (response_tx, response_rx) = oneshot();
2947
2948            // Track outgoing request for diagnostics
2949            if let Some(diag) = &self.shared.diagnostic_state {
2950                let args = args_debug.map(|s| {
2951                    let mut map = std::collections::HashMap::new();
2952                    map.insert("args".to_string(), s);
2953                    map
2954                });
2955                diag.record_outgoing_request(request_id, method_id, args);
2956                // Associate channels with this request
2957                diag.associate_channels_with_request(&channels, request_id);
2958            }
2959
2960            let msg = DriverMessage::Call {
2961                conn_id: self.shared.conn_id,
2962                request_id,
2963                method_id,
2964                metadata,
2965                channels,
2966                payload,
2967                response_tx,
2968            };
2969
2970            // Send the Call message NOW, before spawning drain tasks
2971            if self.shared.driver_tx.send(msg).await.is_err() {
2972                return Err(TransportError::DriverGone);
2973            }
2974
2975            let task_tx = self.shared.channel_registry.lock().unwrap().driver_tx();
2976            let conn_id = self.shared.conn_id;
2977
2978            // Spawn a task for each drain to forward data to driver
2979            for (channel_id, mut rx) in drains {
2980                let task_tx = task_tx.clone();
2981                crate::runtime::spawn(async move {
2982                    loop {
2983                        match rx.recv().await {
2984                            Some(payload) => {
2985                                debug!(
2986                                    "drain task: received {} bytes on channel {}",
2987                                    payload.len(),
2988                                    channel_id
2989                                );
2990                                // Send data to driver
2991                                let _ = task_tx
2992                                    .send(DriverMessage::Data {
2993                                        conn_id,
2994                                        channel_id,
2995                                        payload,
2996                                    })
2997                                    .await;
2998                                debug!(
2999                                    "drain task: sent DriverMessage::Data for channel {}",
3000                                    channel_id
3001                                );
3002                            }
3003                            None => {
3004                                debug!("drain task: channel {} closed", channel_id);
3005                                // Channel closed, send Close and exit
3006                                let _ = task_tx
3007                                    .send(DriverMessage::Close {
3008                                        conn_id,
3009                                        channel_id,
3010                                    })
3011                                    .await;
3012                                debug!(
3013                                    "drain task: sent DriverMessage::Close for channel {}",
3014                                    channel_id
3015                                );
3016                                break;
3017                            }
3018                        }
3019                    }
3020                });
3021            }
3022
3023            // Just await the response - drain tasks run independently
3024            let result = response_rx
3025                .await
3026                .map_err(|_| TransportError::DriverGone)?
3027                .map_err(|_| TransportError::ConnectionClosed);
3028
3029            // Mark request as complete
3030            if let Some(diag) = &self.shared.diagnostic_state {
3031                diag.complete_request(request_id);
3032            }
3033
3034            result
3035        }
3036    }
3037
3038    /// Walk args and bind any Rx<T> or Tx<T> streams.
3039    /// Collects (channel_id, receiver) pairs for Rx streams that need draining.
3040    fn bind_streams<T: Facet<'static>>(
3041        &self,
3042        args: &mut T,
3043        drains: &mut Vec<(ChannelId, Receiver<Vec<u8>>)>,
3044    ) {
3045        let poke = facet::Poke::new(args);
3046        self.bind_streams_recursive(poke, drains);
3047    }
3048
    /// Recursively walk a Poke value looking for Rx/Tx streams to bind.
    ///
    /// `Rx` values go to `bind_rx_stream` (which pushes their receivers onto
    /// `drains`) and `Tx` values go to `bind_tx_stream`; neither is descended
    /// into further. Otherwise the walk descends into struct/tuple fields,
    /// `Option` payloads, list elements, and field 0 of the active enum
    /// variant. Scalars and any other shape kinds are not traversed.
    // The crate root is `#![deny(unsafe_code)]`; it is re-allowed here only
    // for the raw element access in the `Def::List` arm.
    #[allow(unsafe_code)]
    fn bind_streams_recursive(
        &self,
        mut poke: facet::Poke<'_, '_>,
        drains: &mut Vec<(ChannelId, Receiver<Vec<u8>>)>,
    ) {
        use facet::Def;

        let shape = poke.shape();

        // Check if this is an Rx or Tx type.
        // Identified by the shape's module path and type identifier.
        if shape.module_path == Some("roam_session") {
            if shape.type_identifier == "Rx" {
                self.bind_rx_stream(poke, drains);
                return;
            } else if shape.type_identifier == "Tx" {
                self.bind_tx_stream(poke);
                return;
            }
        }

        // Dispatch based on the shape's definition.
        // Arms are tried in order: the struct guard runs before the Option/List
        // arms, and the generic enum guard only after them.
        match shape.def {
            Def::Scalar => {}

            // Recurse into struct/tuple fields
            _ if poke.is_struct() => {
                let mut ps = poke.into_struct().expect("is_struct was true");
                let field_count = ps.field_count();
                for i in 0..field_count {
                    // Fields that cannot be accessed are skipped silently.
                    if let Ok(field_poke) = ps.field(i) {
                        self.bind_streams_recursive(field_poke, drains);
                    }
                }
            }

            // Recurse into Option<T>
            Def::Option(_) => {
                // Option is represented as an enum, use into_enum to access its value
                if let Ok(mut pe) = poke.into_enum()
                    && let Ok(Some(inner_poke)) = pe.field(0)
                {
                    self.bind_streams_recursive(inner_poke, drains);
                }
            }

            // Recurse into list elements (e.g., Vec<Tx<T>>)
            Def::List(list_def) => {
                // Length is read through an immutable Peek view; 0 if that fails.
                let len = {
                    let peek = poke.as_peek();
                    peek.into_list().map(|pl| pl.len()).unwrap_or(0)
                };
                // Get mutable access to elements via VTable (no PokeList exists)
                // If the vtable has no `get_mut`, list elements are not visited.
                if let Some(get_mut_fn) = list_def.vtable.get_mut {
                    let element_shape = list_def.t;
                    let data_ptr = poke.data_mut();
                    for i in 0..len {
                        // SAFETY: We have exclusive mutable access via poke, index < len, shape is correct
                        let element_ptr = unsafe { (get_mut_fn)(data_ptr, i, element_shape) };
                        if let Some(ptr) = element_ptr {
                            // SAFETY: ptr points to a valid element with the correct shape
                            let element_poke =
                                unsafe { facet::Poke::from_raw_parts(ptr, element_shape) };
                            self.bind_streams_recursive(element_poke, drains);
                        }
                    }
                }
            }

            // Other enum variants
            // NOTE(review): only field 0 of the active variant is visited, so
            // streams in later fields of multi-field variants are not bound.
            // Confirm this matches how `collect_channel_ids` walks enums.
            _ if poke.is_enum() => {
                if let Ok(mut pe) = poke.into_enum()
                    && let Ok(Some(variant_poke)) = pe.field(0)
                {
                    self.bind_streams_recursive(variant_poke, drains);
                }
            }

            // Any other definition kind is not traversed; streams nested
            // inside it are left unbound.
            _ => {}
        }
    }
3131
3132    /// Bind an Rx<T> stream - caller passes receiver, keeps sender.
3133    /// Collects the receiver for draining (no spawning).
3134    fn bind_rx_stream(
3135        &self,
3136        poke: facet::Poke<'_, '_>,
3137        drains: &mut Vec<(ChannelId, Receiver<Vec<u8>>)>,
3138    ) {
3139        let channel_id = self.alloc_channel_id();
3140        debug!(
3141            channel_id,
3142            "OutgoingBinder::bind_rx_stream: allocated channel_id for Rx"
3143        );
3144
3145        if let Ok(mut ps) = poke.into_struct() {
3146            // Set channel_id field by getting mutable access to the u64
3147            if let Ok(mut channel_id_field) = ps.field_by_name("channel_id")
3148                && let Ok(id_ref) = channel_id_field.get_mut::<ChannelId>()
3149            {
3150                debug!(
3151                    old_id = *id_ref,
3152                    new_id = channel_id,
3153                    "OutgoingBinder::bind_rx_stream: overwriting channel_id"
3154                );
3155                *id_ref = channel_id;
3156            }
3157
3158            // Take the receiver from ReceiverSlot - collect for draining later
3159            if let Ok(mut receiver_field) = ps.field_by_name("receiver")
3160                && let Ok(slot) = receiver_field.get_mut::<ReceiverSlot>()
3161                && let Some(rx) = slot.take()
3162            {
3163                debug!(
3164                    channel_id,
3165                    "OutgoingBinder::bind_rx_stream: took receiver, adding to drains"
3166                );
3167                drains.push((channel_id, rx));
3168            }
3169        }
3170    }
3171
3172    /// Bind a Tx<T> stream - caller passes sender, keeps receiver.
3173    /// We take the sender and register for incoming Data routing.
3174    fn bind_tx_stream(&self, poke: facet::Poke<'_, '_>) {
3175        let channel_id = self.alloc_channel_id();
3176        debug!(
3177            channel_id,
3178            "OutgoingBinder::bind_tx_stream: allocated channel_id for Tx"
3179        );
3180
3181        if let Ok(mut ps) = poke.into_struct() {
3182            // Set channel_id field by getting mutable access to the u64
3183            if let Ok(mut channel_id_field) = ps.field_by_name("channel_id")
3184                && let Ok(id_ref) = channel_id_field.get_mut::<ChannelId>()
3185            {
3186                debug!(
3187                    old_id = *id_ref,
3188                    new_id = channel_id,
3189                    "OutgoingBinder::bind_tx_stream: overwriting channel_id"
3190                );
3191                *id_ref = channel_id;
3192            }
3193
3194            // Take the sender from SenderSlot
3195            if let Ok(mut sender_field) = ps.field_by_name("sender")
3196                && let Ok(slot) = sender_field.get_mut::<SenderSlot>()
3197                && let Some(tx) = slot.take()
3198            {
3199                debug!(
3200                    channel_id,
3201                    "OutgoingBinder::bind_tx_stream: took sender, registering for incoming"
3202                );
3203                // Register for incoming Data routing
3204                self.register_incoming(channel_id, tx);
3205            }
3206        }
3207    }
3208
3209    /// Make a raw RPC call with pre-serialized payload.
3210    ///
3211    /// Returns the raw response payload bytes.
3212    /// Note: For streaming calls, use `call()` which handles channel binding.
3213    pub async fn call_raw(
3214        &self,
3215        method_id: u64,
3216        payload: Vec<u8>,
3217    ) -> Result<Vec<u8>, TransportError> {
3218        self.call_raw_full(method_id, Vec::new(), Vec::new(), payload, None)
3219            .await
3220            .map(|r| r.payload)
3221    }
3222
3223    /// Make a raw RPC call with pre-serialized payload and channel IDs.
3224    ///
3225    /// Used internally by `call()` after binding streams.
3226    /// Returns ResponseData so caller can handle response channels.
3227    async fn call_raw_with_channels(
3228        &self,
3229        method_id: u64,
3230        channels: Vec<u64>,
3231        payload: Vec<u8>,
3232        args_debug: Option<String>,
3233    ) -> Result<ResponseData, TransportError> {
3234        self.call_raw_full(method_id, Vec::new(), channels, payload, args_debug)
3235            .await
3236    }
3237
3238    async fn call_raw_with_channels_and_metadata(
3239        &self,
3240        method_id: u64,
3241        channels: Vec<u64>,
3242        payload: Vec<u8>,
3243        args_debug: Option<String>,
3244        metadata: roam_wire::Metadata,
3245    ) -> Result<ResponseData, TransportError> {
3246        self.call_raw_full(method_id, metadata, channels, payload, args_debug)
3247            .await
3248    }
3249
3250    /// Make a raw RPC call with pre-serialized payload and metadata.
3251    ///
3252    /// Returns the raw response payload bytes.
3253    pub async fn call_raw_with_metadata(
3254        &self,
3255        method_id: u64,
3256        payload: Vec<u8>,
3257        metadata: Vec<(String, roam_wire::MetadataValue)>,
3258    ) -> Result<Vec<u8>, TransportError> {
3259        self.call_raw_full(method_id, metadata, Vec::new(), payload, None)
3260            .await
3261            .map(|r| r.payload)
3262    }
3263
3264    /// Make a raw RPC call with all options.
3265    ///
3266    /// Returns ResponseData containing the payload and any response channel IDs.
3267    async fn call_raw_full(
3268        &self,
3269        method_id: u64,
3270        metadata: Vec<(String, roam_wire::MetadataValue)>,
3271        channels: Vec<u64>,
3272        payload: Vec<u8>,
3273        args_debug: Option<String>,
3274    ) -> Result<ResponseData, TransportError> {
3275        let request_id = self.shared.request_ids.next();
3276        let (response_tx, response_rx) = oneshot();
3277
3278        // Track outgoing request for diagnostics
3279        if let Some(diag) = &self.shared.diagnostic_state {
3280            let args = args_debug.map(|s| {
3281                let mut map = std::collections::HashMap::new();
3282                map.insert("args".to_string(), s);
3283                map
3284            });
3285            diag.record_outgoing_request(request_id, method_id, args);
3286            // Associate channels with this request
3287            diag.associate_channels_with_request(&channels, request_id);
3288        }
3289
3290        let msg = DriverMessage::Call {
3291            conn_id: self.shared.conn_id,
3292            request_id,
3293            method_id,
3294            metadata,
3295            channels,
3296            payload,
3297            response_tx,
3298        };
3299
3300        self.shared
3301            .driver_tx
3302            .send(msg)
3303            .await
3304            .map_err(|_| TransportError::DriverGone)?;
3305
3306        let result = response_rx
3307            .await
3308            .map_err(|_| TransportError::DriverGone)?
3309            .map_err(|_| TransportError::ConnectionClosed);
3310
3311        // Mark request as complete
3312        if let Some(diag) = &self.shared.diagnostic_state {
3313            diag.complete_request(request_id);
3314        }
3315
3316        result
3317    }
3318
3319    /// Open a new virtual connection on the link.
3320    ///
3321    /// Sends a `Connect` message to the remote peer and waits for an
3322    /// `Accept` or `Reject` response. Returns a new `ConnectionHandle`
3323    /// for the virtual connection if accepted.
3324    ///
3325    /// r[impl core.conn.open]
3326    ///
3327    /// # Arguments
3328    ///
3329    /// * `metadata` - Optional metadata to send with the Connect request
3330    ///   (e.g., authentication tokens, routing hints).
3331    /// * `dispatcher` - Optional dispatcher for handling incoming requests on the
3332    ///   virtual connection. If None, the connection can only make calls, not receive them.
3333    ///
3334    /// # Example
3335    ///
3336    /// ```ignore
3337    /// // Open a new virtual connection that can receive calls
3338    /// let dispatcher = Box::new(MyDispatcher::new());
3339    /// let virtual_conn = handle.connect(vec![], Some(dispatcher)).await?;
3340    ///
3341    /// // Use the new connection for calls
3342    /// let response = virtual_conn.call_raw(method_id, payload).await?;
3343    /// ```
3344    pub async fn connect(
3345        &self,
3346        metadata: roam_wire::Metadata,
3347        dispatcher: Option<Box<dyn ServiceDispatcher>>,
3348    ) -> Result<ConnectionHandle, crate::ConnectError> {
3349        let request_id = self.shared.request_ids.next();
3350        let (response_tx, response_rx) = oneshot();
3351
3352        let msg = DriverMessage::Connect {
3353            request_id,
3354            metadata,
3355            response_tx,
3356            dispatcher,
3357        };
3358
3359        self.shared.driver_tx.send(msg).await.map_err(|_| {
3360            crate::ConnectError::ConnectFailed(std::io::Error::other("driver gone"))
3361        })?;
3362
3363        response_rx
3364            .await
3365            .map_err(|_| crate::ConnectError::ConnectFailed(std::io::Error::other("driver gone")))?
3366    }
3367
3368    /// Allocate a stream ID for an outgoing stream.
3369    ///
3370    /// Used internally when binding streams during call().
3371    pub fn alloc_channel_id(&self) -> ChannelId {
3372        self.shared.channel_ids.next()
3373    }
3374
3375    /// Allocate a unique request ID for an outgoing call.
3376    ///
3377    /// Used when manually constructing DriverMessage::Call.
3378    pub fn alloc_request_id(&self) -> u64 {
3379        self.shared.request_ids.next()
3380    }
3381
3382    /// Register an incoming stream (we receive data from peer).
3383    ///
3384    /// Used when schema has `Tx<T>` (callee sends to caller) - we receive that data.
3385    pub fn register_incoming(&self, channel_id: ChannelId, tx: Sender<Vec<u8>>) {
3386        // Track channel for diagnostics (request_id not available here)
3387        if let Some(diag) = &self.shared.diagnostic_state {
3388            diag.record_channel_open(channel_id, crate::diagnostic::ChannelDirection::Rx, None);
3389        }
3390        self.shared
3391            .channel_registry
3392            .lock()
3393            .unwrap()
3394            .register_incoming(channel_id, tx);
3395    }
3396
3397    /// Register credit tracking for an outgoing stream.
3398    ///
3399    /// The actual receiver is owned by the driver, not the registry.
3400    pub fn register_outgoing_credit(&self, channel_id: ChannelId) {
3401        // Track channel for diagnostics (request_id not available here)
3402        if let Some(diag) = &self.shared.diagnostic_state {
3403            diag.record_channel_open(channel_id, crate::diagnostic::ChannelDirection::Tx, None);
3404        }
3405        self.shared
3406            .channel_registry
3407            .lock()
3408            .unwrap()
3409            .register_outgoing_credit(channel_id);
3410    }
3411
3412    /// Route incoming stream data to the appropriate Rx.
3413    pub async fn route_data(
3414        &self,
3415        channel_id: ChannelId,
3416        payload: Vec<u8>,
3417    ) -> Result<(), ChannelError> {
3418        // Get the sender while holding the lock, then release before await
3419        let (tx, payload) = self
3420            .shared
3421            .channel_registry
3422            .lock()
3423            .unwrap()
3424            .prepare_route_data(channel_id, payload)?;
3425        // Send without holding the lock
3426        let _ = tx.send(payload).await;
3427        Ok(())
3428    }
3429
3430    /// Close an incoming stream.
3431    pub fn close_channel(&self, channel_id: ChannelId) {
3432        // Track channel close for diagnostics
3433        if let Some(diag) = &self.shared.diagnostic_state {
3434            diag.record_channel_close(channel_id);
3435        }
3436        self.shared
3437            .channel_registry
3438            .lock()
3439            .unwrap()
3440            .close(channel_id);
3441    }
3442
3443    /// Reset a stream.
3444    pub fn reset_channel(&self, channel_id: ChannelId) {
3445        // Track channel close for diagnostics
3446        if let Some(diag) = &self.shared.diagnostic_state {
3447            diag.record_channel_close(channel_id);
3448        }
3449        self.shared
3450            .channel_registry
3451            .lock()
3452            .unwrap()
3453            .reset(channel_id);
3454    }
3455
3456    /// Check if a stream exists.
3457    pub fn contains_channel(&self, channel_id: ChannelId) -> bool {
3458        self.shared
3459            .channel_registry
3460            .lock()
3461            .unwrap()
3462            .contains(channel_id)
3463    }
3464
3465    /// Receive credit for an outgoing stream.
3466    pub fn receive_credit(&self, channel_id: ChannelId, bytes: u32) {
3467        self.shared
3468            .channel_registry
3469            .lock()
3470            .unwrap()
3471            .receive_credit(channel_id, bytes);
3472    }
3473
3474    /// Get a clone of the driver message sender.
3475    ///
3476    /// Used for forwarding/proxy scenarios where messages need to be sent
3477    /// on this connection's wire.
3478    pub fn driver_tx(&self) -> Sender<DriverMessage> {
3479        self.shared.channel_registry.lock().unwrap().driver_tx()
3480    }
3481
3482    /// Bind receivers for `Rx<T>` streams in a deserialized response.
3483    ///
3484    /// After deserializing a response, any `Rx<T>` values are "hollow" - they have
3485    /// channel IDs but no actual receiver. This method walks the response using
3486    /// reflection and binds receivers for each `Rx<T>` so data can be received.
3487    ///
3488    /// # How it works
3489    ///
3490    /// For each `Rx<T>` found in the response:
3491    /// 1. Read the channel_id that was set during deserialization
3492    /// 2. Create a new channel (tx, rx)
3493    /// 3. Set the receiver slot on the Rx
3494    /// 4. Register the sender with the channel registry for incoming data routing
3495    ///
3496    /// This mirrors server-side `ChannelRegistry::bind_streams` but for responses.
3497    ///
3498    /// IMPORTANT: The `channels` parameter contains the authoritative channel IDs
3499    /// from the Response framing. For forwarded connections (via ForwardingDispatcher),
3500    /// these IDs may differ from the IDs serialized in the payload. We patch them first.
3501    pub fn bind_response_streams<T: Facet<'static>>(&self, response: &mut T, channels: &[u64]) {
3502        // Patch channel IDs from Response.channels into the deserialized response.
3503        // This is critical for ForwardingDispatcher where the payload contains upstream
3504        // channel IDs but channels[] contains the remapped downstream IDs.
3505        patch_channel_ids(response, channels);
3506
3507        let poke = facet::Poke::new(response);
3508        self.bind_response_streams_recursive(poke);
3509    }
3510
    /// Recursively walk a Poke value looking for Rx streams to bind in responses.
    ///
    /// An `Rx` is recognized by shape identity (`module_path == "roam_session"`,
    /// `type_identifier == "Rx"`) and handed to `bind_rx_response_stream`. `Tx` is
    /// not treated specially and is descended into like any other struct.
    ///
    /// Traversal covers struct/tuple fields, the `Some` payload of `Option`, list
    /// elements (via the list vtable's `get_mut`), and field 0 of other enums.
    /// Every other `Def` kind falls through to `_ => {}` and is not traversed.
    ///
    /// NOTE(review): for enums only `field(0)` of the active variant is visited,
    /// so an `Rx` in a later field of a multi-field variant would not be bound —
    /// confirm whether such shapes can occur.
    #[allow(unsafe_code)]
    fn bind_response_streams_recursive(&self, mut poke: facet::Poke<'_, '_>) {
        use facet::Def;

        let shape = poke.shape();

        // Check if this is an Rx type - only Rx needs binding in responses
        // (Tx in responses would be outgoing, but that's uncommon for return types)
        if shape.module_path == Some("roam_session") && shape.type_identifier == "Rx" {
            self.bind_rx_response_stream(poke);
            return;
        }

        // Dispatch based on the shape's definition
        match shape.def {
            // Scalars cannot contain streams.
            Def::Scalar => {}

            // Recurse into struct/tuple fields
            _ if poke.is_struct() => {
                let mut ps = poke.into_struct().expect("is_struct was true");
                let field_count = ps.field_count();
                for i in 0..field_count {
                    if let Ok(field_poke) = ps.field(i) {
                        self.bind_response_streams_recursive(field_poke);
                    }
                }
            }

            // Recurse into Option<T>
            Def::Option(_) => {
                // Option is represented as an enum, use into_enum to access its value
                // (`Ok(None)` — no field 0 — means there is nothing to bind).
                if let Ok(mut pe) = poke.into_enum()
                    && let Ok(Some(inner_poke)) = pe.field(0)
                {
                    self.bind_response_streams_recursive(inner_poke);
                }
            }

            // Recurse into list elements (e.g., Vec<Rx<T>>)
            Def::List(list_def) => {
                // Length is read through a read-only peek; if the value cannot be
                // viewed as a list, treat it as empty.
                let len = {
                    let peek = poke.as_peek();
                    peek.into_list().map(|pl| pl.len()).unwrap_or(0)
                };
                // Get mutable access to elements via VTable (no PokeList exists)
                // Lists whose vtable lacks `get_mut` are skipped entirely.
                if let Some(get_mut_fn) = list_def.vtable.get_mut {
                    let element_shape = list_def.t;
                    let data_ptr = poke.data_mut();
                    for i in 0..len {
                        // SAFETY: We have exclusive mutable access via poke, index < len, shape is correct
                        let element_ptr = unsafe { (get_mut_fn)(data_ptr, i, element_shape) };
                        if let Some(ptr) = element_ptr {
                            // SAFETY: ptr points to a valid element with the correct shape
                            let element_poke =
                                unsafe { facet::Poke::from_raw_parts(ptr, element_shape) };
                            self.bind_response_streams_recursive(element_poke);
                        }
                    }
                }
            }

            // Other enum variants
            // Only the first field of the active variant is inspected.
            _ if poke.is_enum() => {
                if let Ok(mut pe) = poke.into_enum()
                    && let Ok(Some(variant_poke)) = pe.field(0)
                {
                    self.bind_response_streams_recursive(variant_poke);
                }
            }

            // Any other definition kind is not traversed.
            _ => {}
        }
    }
3585
3586    /// Bind a single Rx<T> stream from a response.
3587    ///
3588    /// Creates a channel, sets the receiver slot, and registers for incoming data.
3589    fn bind_rx_response_stream(&self, poke: facet::Poke<'_, '_>) {
3590        if let Ok(mut ps) = poke.into_struct() {
3591            // Get the channel_id that was deserialized from the wire
3592            let channel_id = if let Ok(channel_id_field) = ps.field_by_name("channel_id")
3593                && let Ok(id_ref) = channel_id_field.get::<ChannelId>()
3594            {
3595                *id_ref
3596            } else {
3597                return;
3598            };
3599
3600            // Create channel and set receiver slot
3601            let (tx, rx) = crate::runtime::channel(RX_STREAM_BUFFER_SIZE);
3602
3603            if let Ok(mut receiver_field) = ps.field_by_name("receiver")
3604                && let Ok(slot) = receiver_field.get_mut::<ReceiverSlot>()
3605            {
3606                slot.set(rx);
3607            }
3608
3609            // Register for incoming data routing
3610            self.register_incoming(channel_id, tx);
3611        }
3612    }
3613}
3614
3615#[derive(Debug)]
3616pub enum ClientError<TransportError> {
3617    Transport(TransportError),
3618    Encode(facet_postcard::SerializeError),
3619    Decode(facet_postcard::DeserializeError<facet_postcard::PostcardError>),
3620}
3621
3622impl<TransportError> From<TransportError> for ClientError<TransportError> {
3623    fn from(value: TransportError) -> Self {
3624        Self::Transport(value)
3625    }
3626}
3627
/// Error type for the dispatch side of an RPC.
///
/// Its only variant wraps a `facet_postcard` serialization failure
/// (presumably raised while encoding a response — confirm against callers).
#[derive(Debug)]
pub enum DispatchError {
    /// A value could not be serialized with `facet_postcard`.
    Encode(facet_postcard::SerializeError),
}
3632
3633// ============================================================================
// Tunnel Adapters for AsyncRead/AsyncWrite Streams (pump helpers are native only)
3635// ============================================================================
3636
3637#[cfg(not(target_arch = "wasm32"))]
3638use std::io;
3639#[cfg(not(target_arch = "wasm32"))]
3640use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
3641#[cfg(not(target_arch = "wasm32"))]
3642use tokio::task::JoinHandle;
3643
/// Default chunk size for tunnel pumps: 32 KiB (`32 * 1024` bytes).
///
/// Intended as the `chunk_size` argument of `tunnel_stream` / `pump_read_to_tx`.
/// Balances throughput with memory usage and slot consumption:
/// larger values improve throughput but use more memory per read, while
/// smaller values improve latency but increase syscall overhead.
#[cfg(not(target_arch = "wasm32"))]
pub const DEFAULT_TUNNEL_CHUNK_SIZE: usize = 32 * 1024;
3651
/// A bidirectional byte tunnel over roam channels.
///
/// From the perspective of whoever holds the tunnel:
/// - `tx`: Send bytes TO the remote end
/// - `rx`: Receive bytes FROM the remote end
///
/// Tunnels are typically used to bridge async byte streams (TCP, Unix sockets, etc.)
/// with roam's streaming channels. One side creates a tunnel pair with [`tunnel_pair()`],
/// passes one half to the remote via an RPC call, and uses the other half locally.
///
/// `Tunnel` and [`tunnel_pair()`] are available on every target. The stream pump
/// helpers (`tunnel_stream`, `pump_read_to_tx`, `pump_rx_to_write`) are not
/// compiled for `wasm32`.
///
/// # Example
///
/// ```ignore
/// // Host side: create tunnel and pump to/from a socket
/// let (local, remote) = roam_session::tunnel_pair();
/// let (read_handle, write_handle) = roam_session::tunnel_stream(socket, local, 32 * 1024);
///
/// // Pass `remote` to cell via RPC
/// cell.handle_connection(remote).await?;
/// ```
#[derive(Facet)]
pub struct Tunnel {
    // NOTE(review): field order is presumably part of the serialized form
    // (postcard encodes struct fields in declaration order) — do not reorder
    // without checking the wire format.
    /// Channel for sending bytes to the remote end.
    pub tx: Tx<Vec<u8>>,
    /// Channel for receiving bytes from the remote end.
    pub rx: Rx<Vec<u8>>,
}
3679
3680/// Create a pair of connected tunnels.
3681///
3682/// Returns `(local, remote)` where:
3683/// - Data sent on `local.tx` arrives at `remote.rx`
3684/// - Data sent on `remote.tx` arrives at `local.rx`
3685///
3686/// This is useful for creating a bidirectional channel that can be split
3687/// across an RPC boundary. One side keeps `local` and passes `remote` to
3688/// the other side via an RPC call.
3689///
3690/// # Example
3691///
3692/// ```ignore
3693/// let (local, remote) = tunnel_pair();
3694///
3695/// // Spawn tasks to pump data from local stream
3696/// tunnel_stream(tcp_stream, local, DEFAULT_TUNNEL_CHUNK_SIZE);
3697///
3698/// // Send remote to the other side via RPC
3699/// service.handle_tunnel(remote).await?;
3700/// ```
3701pub fn tunnel_pair() -> (Tunnel, Tunnel) {
3702    let (tx1, rx1) = channel::<Vec<u8>>();
3703    let (tx2, rx2) = channel::<Vec<u8>>();
3704    (Tunnel { tx: tx1, rx: rx2 }, Tunnel { tx: tx2, rx: rx1 })
3705}
3706
3707/// Pump bytes from an `AsyncRead` into a `Tx<Vec<u8>>`.
3708///
3709/// Reads chunks up to `chunk_size` bytes and sends them on the channel.
3710/// Returns when the reader reaches EOF or the channel closes.
3711///
3712/// # Arguments
3713///
3714/// * `reader` - Any type implementing `AsyncRead + Unpin`
3715/// * `tx` - The transmit channel to send bytes to
3716/// * `chunk_size` - Maximum bytes to read per chunk
3717///
3718/// # Returns
3719///
3720/// * `Ok(())` - Reader reached EOF, channel closed gracefully
3721/// * `Err(io::Error)` - Read error occurred
3722///
3723/// # Example
3724///
3725/// ```ignore
3726/// let (tx, rx) = roam::channel::<Vec<u8>>();
3727/// let result = pump_read_to_tx(reader, tx, 32 * 1024).await;
3728/// ```
3729#[cfg(not(target_arch = "wasm32"))]
3730pub async fn pump_read_to_tx<R: AsyncRead + Unpin>(
3731    mut reader: R,
3732    tx: Tx<Vec<u8>>,
3733    chunk_size: usize,
3734) -> io::Result<()> {
3735    let mut buf = vec![0u8; chunk_size];
3736    loop {
3737        let n = reader.read(&mut buf).await?;
3738        if n == 0 {
3739            // EOF - drop tx to close the channel
3740            break;
3741        }
3742        // Send the bytes we read
3743        if tx.send(&buf[..n].to_vec()).await.is_err() {
3744            // Channel closed by receiver - treat as graceful shutdown
3745            break;
3746        }
3747    }
3748    Ok(())
3749}
3750
3751/// Pump bytes from an `Rx<Vec<u8>>` into an `AsyncWrite`.
3752///
3753/// Receives chunks and writes them to the writer.
3754/// Returns when the channel closes or a write error occurs.
3755///
3756/// # Arguments
3757///
3758/// * `rx` - The receive channel to get bytes from
3759/// * `writer` - Any type implementing `AsyncWrite + Unpin`
3760///
3761/// # Returns
3762///
3763/// * `Ok(())` - Channel closed gracefully
3764/// * `Err(io::Error)` - Write error or deserialization error occurred
3765///
3766/// # Example
3767///
3768/// ```ignore
3769/// let (tx, rx) = roam::channel::<Vec<u8>>();
3770/// let result = pump_rx_to_write(rx, writer).await;
3771/// ```
3772#[cfg(not(target_arch = "wasm32"))]
3773pub async fn pump_rx_to_write<W: AsyncWrite + Unpin>(
3774    mut rx: Rx<Vec<u8>>,
3775    mut writer: W,
3776) -> io::Result<()> {
3777    loop {
3778        match rx.recv().await {
3779            Ok(Some(data)) => {
3780                writer.write_all(&data).await?;
3781            }
3782            Ok(None) => {
3783                // Channel closed - flush and exit
3784                writer.flush().await?;
3785                break;
3786            }
3787            Err(e) => {
3788                return Err(io::Error::new(
3789                    io::ErrorKind::InvalidData,
3790                    format!("tunnel receive error: {e}"),
3791                ));
3792            }
3793        }
3794    }
3795    Ok(())
3796}
3797
3798/// Tunnel a bidirectional stream through a roam Tunnel.
3799///
3800/// Spawns two tasks to pump data in both directions:
3801/// - One task reads from `stream` and sends to `tunnel.tx`
3802/// - One task receives from `tunnel.rx` and writes to `stream`
3803///
3804/// Returns handles to join on completion. Both tasks run until their
3805/// respective direction completes (EOF/close) or an error occurs.
3806///
3807/// # Arguments
3808///
3809/// * `stream` - Any type implementing `AsyncRead + AsyncWrite + Unpin + Send + 'static`
3810/// * `tunnel` - The tunnel to pump data through
3811/// * `chunk_size` - Maximum bytes to read per chunk (see [`DEFAULT_TUNNEL_CHUNK_SIZE`])
3812///
3813/// # Returns
3814///
3815/// A tuple of `(read_handle, write_handle)`:
3816/// - `read_handle` - Completes when the stream reaches EOF or tx closes
3817/// - `write_handle` - Completes when rx closes or stream write fails
3818///
3819/// # Example
3820///
3821/// ```ignore
3822/// let (local, remote) = tunnel_pair();
3823/// let (read_handle, write_handle) = tunnel_stream(tcp_stream, local, DEFAULT_TUNNEL_CHUNK_SIZE);
3824///
3825/// // Wait for both directions to complete
3826/// let _ = read_handle.await;
3827/// let _ = write_handle.await;
3828/// ```
3829#[cfg(not(target_arch = "wasm32"))]
3830pub fn tunnel_stream<S>(
3831    stream: S,
3832    tunnel: Tunnel,
3833    chunk_size: usize,
3834) -> (JoinHandle<io::Result<()>>, JoinHandle<io::Result<()>>)
3835where
3836    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
3837{
3838    let (reader, writer) = tokio::io::split(stream);
3839    let Tunnel { tx, rx } = tunnel;
3840
3841    let read_handle = tokio::spawn(async move { pump_read_to_tx(reader, tx, chunk_size).await });
3842
3843    let write_handle = tokio::spawn(async move { pump_rx_to_write(rx, writer).await });
3844
3845    (read_handle, write_handle)
3846}
3847
3848#[cfg(test)]
3849mod tests {
3850    use super::*;
3851
3852    // r[verify channeling.id.parity]
3853    #[test]
3854    fn channel_id_allocator_initiator_uses_odd_ids() {
3855        let alloc = ChannelIdAllocator::new(Role::Initiator);
3856        assert_eq!(alloc.next(), 1);
3857        assert_eq!(alloc.next(), 3);
3858        assert_eq!(alloc.next(), 5);
3859        assert_eq!(alloc.next(), 7);
3860    }
3861
3862    // r[verify channeling.id.parity]
3863    #[test]
3864    fn channel_id_allocator_acceptor_uses_even_ids() {
3865        let alloc = ChannelIdAllocator::new(Role::Acceptor);
3866        assert_eq!(alloc.next(), 2);
3867        assert_eq!(alloc.next(), 4);
3868        assert_eq!(alloc.next(), 6);
3869        assert_eq!(alloc.next(), 8);
3870    }
3871
3872    // r[verify channeling.holder-semantics]
3873    #[tokio::test]
3874    async fn tx_serializes_and_rx_deserializes() {
3875        // Create a channel pair using roam::channel
3876        let (tx, mut rx) = channel::<i32>();
3877
3878        // Simulate what ConnectionHandle::call would do: take the receiver
3879        let mut taken_rx = rx.receiver.take().expect("receiver should be present");
3880
3881        // Now tx can send and we can receive on the taken receiver
3882        tx.send(&100).await.unwrap();
3883        tx.send(&200).await.unwrap();
3884
3885        // Receive raw bytes and deserialize
3886        let bytes1 = taken_rx.recv().await.unwrap();
3887        let val1: i32 = facet_postcard::from_slice(&bytes1).unwrap();
3888        assert_eq!(val1, 100);
3889
3890        let bytes2 = taken_rx.recv().await.unwrap();
3891        let val2: i32 = facet_postcard::from_slice(&bytes2).unwrap();
3892        assert_eq!(val2, 200);
3893    }
3894
3895    /// Create a test registry with a dummy task channel.
3896    fn test_registry() -> ChannelRegistry {
3897        let (task_tx, _task_rx) = crate::runtime::channel(10);
3898        ChannelRegistry::new(task_tx)
3899    }
3900
3901    // r[verify channeling.data-after-close]
3902    #[tokio::test]
3903    async fn data_after_close_is_rejected() {
3904        let mut registry = test_registry();
3905        let (tx, _rx) = crate::runtime::channel(10);
3906        registry.register_incoming(42, tx);
3907
3908        // Close the stream
3909        registry.close(42);
3910
3911        // Data after close should fail
3912        let result = registry.route_data(42, b"data".to_vec()).await;
3913        assert_eq!(result, Err(ChannelError::DataAfterClose));
3914    }
3915
3916    // r[verify channeling.data]
3917    // r[verify channeling.unknown]
3918    #[tokio::test]
3919    async fn channel_registry_routes_data_to_registered_stream() {
3920        let mut registry = test_registry();
3921
3922        // Register a stream
3923        let (tx, mut rx) = crate::runtime::channel(10);
3924        registry.register_incoming(42, tx);
3925
3926        // Data to registered stream should succeed
3927        assert!(registry.route_data(42, b"hello".to_vec()).await.is_ok());
3928
3929        // Should receive the data
3930        assert_eq!(rx.recv().await, Some(b"hello".to_vec()));
3931
3932        // Data to unregistered stream should fail
3933        assert!(registry.route_data(999, b"nope".to_vec()).await.is_err());
3934    }
3935
3936    // r[verify channeling.close]
3937    #[tokio::test]
3938    async fn channel_registry_close_terminates_stream() {
3939        let mut registry = test_registry();
3940        let (tx, mut rx) = crate::runtime::channel(10);
3941        registry.register_incoming(42, tx);
3942
3943        // Send some data
3944        registry.route_data(42, b"data1".to_vec()).await.unwrap();
3945
3946        // Close the stream
3947        registry.close(42);
3948
3949        // Should still receive buffered data
3950        assert_eq!(rx.recv().await, Some(b"data1".to_vec()));
3951
3952        // Then channel closes (sender dropped)
3953        assert_eq!(rx.recv().await, None);
3954
3955        // Stream no longer registered
3956        assert!(!registry.contains(42));
3957    }
3958
3959    #[test]
3960    fn tx_rx_shape_metadata() {
3961        use facet::Facet;
3962
3963        let tx_shape = <Tx<i32> as Facet>::SHAPE;
3964        let rx_shape = <Rx<i32> as Facet>::SHAPE;
3965
3966        // Verify module_path and type_identifier are set correctly
3967        assert_eq!(tx_shape.module_path, Some("roam_session"));
3968        assert_eq!(tx_shape.type_identifier, "Tx");
3969        assert_eq!(rx_shape.module_path, Some("roam_session"));
3970        assert_eq!(rx_shape.type_identifier, "Rx");
3971
3972        // Verify type_params are populated
3973        assert_eq!(tx_shape.type_params.len(), 1);
3974        assert_eq!(rx_shape.type_params.len(), 1);
3975    }
3976
3977    // ========================================================================
3978    // Tunnel Tests
3979    // ========================================================================
3980
3981    #[tokio::test]
3982    async fn tunnel_pair_connects_bidirectionally() {
3983        let (local, remote) = tunnel_pair();
3984
3985        // Send from local to remote
3986        local.tx.send(&b"hello".to_vec()).await.unwrap();
3987
3988        // Receive on remote
3989        let mut remote_rx = remote.rx;
3990        let received = remote_rx.recv().await.unwrap().unwrap();
3991        assert_eq!(received, b"hello".to_vec());
3992
3993        // Send from remote to local
3994        remote.tx.send(&b"world".to_vec()).await.unwrap();
3995
3996        // Receive on local
3997        let mut local_rx = local.rx;
3998        let received = local_rx.recv().await.unwrap().unwrap();
3999        assert_eq!(received, b"world".to_vec());
4000    }
4001
4002    #[tokio::test]
4003    async fn pump_read_to_tx_sends_chunks() {
4004        use std::io::Cursor;
4005
4006        let data = b"hello world this is a test message";
4007        let reader = Cursor::new(data.to_vec());
4008        let (tx, mut rx) = channel::<Vec<u8>>();
4009
4010        // Pump with small chunk size to force multiple chunks
4011        let handle = tokio::spawn(async move { pump_read_to_tx(reader, tx, 10).await });
4012
4013        // Collect all received chunks
4014        let mut received = Vec::new();
4015        while let Ok(Some(chunk)) = rx.recv().await {
4016            received.extend(chunk);
4017        }
4018
4019        // Verify we got all the data
4020        assert_eq!(received, data.to_vec());
4021
4022        // Pump should complete successfully
4023        handle.await.unwrap().unwrap();
4024    }
4025
4026    #[tokio::test]
4027    async fn pump_rx_to_write_writes_chunks() {
4028        use std::io::Cursor;
4029
4030        let (tx, rx) = channel::<Vec<u8>>();
4031        let writer = Cursor::new(Vec::new());
4032
4033        // Spawn pump task
4034        let handle = tokio::spawn(async move {
4035            let mut writer = writer;
4036            pump_rx_to_write(rx, &mut writer).await?;
4037            Ok::<_, io::Error>(writer)
4038        });
4039
4040        // Send some chunks
4041        tx.send(&b"hello ".to_vec()).await.unwrap();
4042        tx.send(&b"world".to_vec()).await.unwrap();
4043        drop(tx); // Close the channel
4044
4045        // Wait for pump to complete and get the writer
4046        let writer = handle.await.unwrap().unwrap();
4047        assert_eq!(writer.into_inner(), b"hello world".to_vec());
4048    }
4049
4050    #[tokio::test]
4051    async fn tunnel_stream_bidirectional() {
4052        // Create a duplex stream (simulates a socket)
4053        let (client, server) = tokio::io::duplex(1024);
4054
4055        // Create tunnel pair
4056        let (local, remote) = tunnel_pair();
4057
4058        // Tunnel the client side
4059        let (client_read_handle, client_write_handle) =
4060            tunnel_stream(client, local, DEFAULT_TUNNEL_CHUNK_SIZE);
4061
4062        // Use remote tunnel to send/receive
4063        tokio::spawn(async move {
4064            // Send data through the tunnel (will go to server side of duplex)
4065            remote.tx.send(&b"from tunnel".to_vec()).await.unwrap();
4066        });
4067
4068        // Read from server side of duplex
4069        let mut server = server;
4070        let mut buf = vec![0u8; 1024];
4071        let n = tokio::io::AsyncReadExt::read(&mut server, &mut buf)
4072            .await
4073            .unwrap();
4074        assert!(n > 0);
4075
4076        // Write to server side
4077        tokio::io::AsyncWriteExt::write_all(&mut server, b"to tunnel")
4078            .await
4079            .unwrap();
4080        drop(server); // Close to signal EOF
4081
4082        // Wait for read task to complete
4083        client_read_handle.await.unwrap().unwrap();
4084        client_write_handle.await.unwrap().unwrap();
4085    }
4086
4087    #[tokio::test]
4088    async fn tunnel_handles_empty_data() {
4089        let (tx, mut rx) = channel::<Vec<u8>>();
4090
4091        // Sending empty vec should work
4092        tx.send(&Vec::new()).await.unwrap();
4093
4094        let received = rx.recv().await.unwrap().unwrap();
4095        assert!(received.is_empty());
4096    }
4097
4098    #[tokio::test]
4099    async fn tunnel_close_propagates() {
4100        let (local, remote) = tunnel_pair();
4101
4102        // Drop the sender
4103        drop(local.tx);
4104
4105        // Receiver should see channel closed
4106        let mut rx = remote.rx;
4107        let result = rx.recv().await;
4108        assert!(matches!(result, Ok(None)));
4109    }
4110
4111    // ========================================================================
4112    // Channel ID Collection Tests
4113    // ========================================================================
4114
4115    // r[verify call.request.channels]
4116    #[test]
4117    fn collect_channel_ids_simple_tx() {
4118        let tx: Tx<i32> = Tx::try_from(42u64).unwrap();
4119        let ids = collect_channel_ids(&tx);
4120        assert_eq!(ids, vec![42]);
4121    }
4122
4123    // r[verify call.request.channels]
4124    #[test]
4125    fn collect_channel_ids_simple_rx() {
4126        let rx: Rx<i32> = Rx::try_from(99u64).unwrap();
4127        let ids = collect_channel_ids(&rx);
4128        assert_eq!(ids, vec![99]);
4129    }
4130
4131    // r[verify call.request.channels]
4132    #[test]
4133    fn collect_channel_ids_tuple() {
4134        let rx: Rx<String> = Rx::try_from(10u64).unwrap();
4135        let tx: Tx<String> = Tx::try_from(20u64).unwrap();
4136        let args = (rx, tx);
4137        let ids = collect_channel_ids(&args);
4138        assert_eq!(ids, vec![10, 20]);
4139    }
4140
4141    // r[verify call.request.channels]
4142    #[test]
4143    fn collect_channel_ids_nested_in_struct() {
4144        #[derive(facet::Facet)]
4145        struct StreamArgs {
4146            input: Rx<i32>,
4147            output: Tx<i32>,
4148            count: u32,
4149        }
4150
4151        let args = StreamArgs {
4152            input: Rx::try_from(100u64).unwrap(),
4153            output: Tx::try_from(200u64).unwrap(),
4154            count: 5,
4155        };
4156        let ids = collect_channel_ids(&args);
4157        assert_eq!(ids, vec![100, 200]);
4158    }
4159
4160    // r[verify call.request.channels]
4161    #[test]
4162    fn collect_channel_ids_option_some() {
4163        let tx: Tx<i32> = Tx::try_from(55u64).unwrap();
4164        let args: Option<Tx<i32>> = Some(tx);
4165        let ids = collect_channel_ids(&args);
4166        assert_eq!(ids, vec![55]);
4167    }
4168
4169    // r[verify call.request.channels]
4170    #[test]
4171    fn collect_channel_ids_option_none() {
4172        let args: Option<Tx<i32>> = None;
4173        let ids = collect_channel_ids(&args);
4174        assert!(ids.is_empty());
4175    }
4176
4177    // r[verify call.request.channels]
4178    #[test]
4179    fn collect_channel_ids_vec() {
4180        let tx1: Tx<i32> = Tx::try_from(1u64).unwrap();
4181        let tx2: Tx<i32> = Tx::try_from(2u64).unwrap();
4182        let tx3: Tx<i32> = Tx::try_from(3u64).unwrap();
4183        let args: Vec<Tx<i32>> = vec![tx1, tx2, tx3];
4184        let ids = collect_channel_ids(&args);
4185        assert_eq!(ids, vec![1, 2, 3]);
4186    }
4187
4188    // r[verify call.request.channels]
4189    #[test]
4190    fn collect_channel_ids_deeply_nested() {
4191        #[derive(facet::Facet)]
4192        struct Outer {
4193            inner: Inner,
4194        }
4195
4196        #[derive(facet::Facet)]
4197        struct Inner {
4198            stream: Tx<u8>,
4199        }
4200
4201        let args = Outer {
4202            inner: Inner {
4203                stream: Tx::try_from(777u64).unwrap(),
4204            },
4205        };
4206        let ids = collect_channel_ids(&args);
4207        assert_eq!(ids, vec![777]);
4208    }
4209}