Skip to main content

roam_core/
driver.rs

1use std::{
2    collections::BTreeMap,
3    pin::Pin,
4    sync::{Arc, Weak},
5};
6
7use moire::sync::SyncMutex;
8use tokio::sync::Semaphore;
9
10use moire::task::FutureExt as _;
11use roam_types::{
12    Caller, ChannelBinder, ChannelBody, ChannelClose, ChannelCreditReplenisher,
13    ChannelCreditReplenisherHandle, ChannelId, ChannelItem, ChannelLivenessHandle, ChannelMessage,
14    ChannelSink, CreditSink, Handler, IdAllocator, IncomingChannelMessage, MaybeSend, Payload,
15    ReplySink, RequestBody, RequestCall, RequestId, RequestMessage, RequestResponse, RoamError,
16    SelfRef, TxError,
17};
18
19use crate::session::{ConnectionHandle, ConnectionMessage, ConnectionSender, DropControlRequest};
20use moire::sync::mpsc;
21
/// One-shot sender that delivers a peer's response to the `DriverCaller::call`
/// waiting on it.
///
/// `call` inserts one of these into `DriverShared::pending_responses` before
/// sending its request; when a `RequestBody::Response` arrives, the driver loop
/// removes the slot and sends the whole `RequestMessage` through it.
type ResponseSlot = moire::sync::oneshot::Sender<SelfRef<RequestMessage<'static>>>;
23
/// State shared between the driver loop and any DriverCaller/DriverChannelBinder handles.
struct DriverShared {
    /// Response slots for outgoing calls still awaiting a response, keyed by the
    /// request ID we allocated. The driver removes a slot when its response
    /// arrives; the whole table is cleared when the driver loop exits.
    pending_responses: SyncMutex<BTreeMap<RequestId, ResponseSlot>>,
    /// Allocator for the request IDs of calls we originate (seeded with the
    /// connection's parity in `Driver::new`).
    request_ids: SyncMutex<IdAllocator<RequestId>>,
    /// Allocator for the channel IDs we originate (seeded with the connection's
    /// parity in `Driver::new`).
    channel_ids: SyncMutex<IdAllocator<ChannelId>>,
    /// Registry mapping inbound channel IDs to the sender that feeds the Rx handle.
    channel_senders:
        SyncMutex<BTreeMap<ChannelId, tokio::sync::mpsc::Sender<IncomingChannelMessage>>>,
    /// Buffer for channel messages that arrive before the channel is registered.
    ///
    /// This handles the race between the caller sending items immediately after
    /// `bind_channels_caller_args` creates the sink, and the callee's handler task
    /// calling `register_rx` via `bind_channels_callee_args`. Items arriving in
    /// that window are buffered here and drained when the channel is registered.
    ///
    /// NOTE(review): registration drains this map and only afterwards inserts
    /// into `channel_senders`, taking the two locks separately. A message that
    /// the driver routes in between appears to be pushed here after the drain
    /// and never delivered — confirm whether that window matters.
    channel_buffers: SyncMutex<BTreeMap<ChannelId, Vec<IncomingChannelMessage>>>,
    /// Credit semaphores for outbound channels (Tx on our side).
    /// The driver's GrantCredit handler adds permits to these.
    channel_credits: SyncMutex<BTreeMap<ChannelId, Arc<Semaphore>>>,
}
43
/// Sends `request` on `control_tx` when dropped.
///
/// The driver shares a single instance behind an `Arc` and keeps only a `Weak`
/// to it, so the request goes out once the last strong holder (a
/// `DriverCaller`, a `DriverChannelBinder`, or a channel liveness handle cloned
/// from either) is dropped.
struct CallerDropGuard {
    /// Channel the drop request is sent on.
    control_tx: mpsc::UnboundedSender<DropControlRequest>,
    /// Request sent on drop (the driver uses `DropControlRequest::Close` for
    /// its connection).
    request: DropControlRequest,
}
48
49impl Drop for CallerDropGuard {
50    fn drop(&mut self) {
51        let _ = self.control_tx.send(self.request);
52    }
53}
54
#[cfg(test)]
mod tests {
    use super::{DriverChannelCreditReplenisher, DriverLocalControl};
    use roam_types::{ChannelCreditReplenisher, ChannelId};
    use tokio::sync::mpsc::error::TryRecvError;

    /// Pops the next queued control message and unpacks it as a credit grant,
    /// panicking with `context` if the queue is empty or holds anything else.
    fn expect_grant(
        rx: &mut moire::sync::mpsc::UnboundedReceiver<DriverLocalControl>,
        context: &str,
    ) -> (ChannelId, u32) {
        match rx.try_recv() {
            Ok(DriverLocalControl::GrantCredit {
                channel_id,
                additional,
            }) => (channel_id, additional),
            _ => panic!("{context}"),
        }
    }

    #[test]
    fn replenisher_batches_at_half_the_initial_window() {
        let (tx, mut rx) = moire::sync::mpsc::unbounded_channel("test.replenisher");
        let replenisher = DriverChannelCreditReplenisher::new(ChannelId(7), 16, tx);

        // One short of the 16 / 2 = 8 item threshold.
        (0..7).for_each(|_| replenisher.on_item_consumed());
        assert!(
            matches!(rx.try_recv(), Err(TryRecvError::Empty)),
            "should not emit credit before reaching the batch threshold"
        );

        replenisher.on_item_consumed();
        let (channel_id, additional) = expect_grant(&mut rx, "expected batched credit grant");
        assert_eq!(channel_id, ChannelId(7));
        assert_eq!(additional, 8);
    }

    #[test]
    fn replenisher_grants_one_by_one_for_single_credit_windows() {
        let (tx, mut rx) = moire::sync::mpsc::unbounded_channel("test.replenisher.single");
        let replenisher = DriverChannelCreditReplenisher::new(ChannelId(9), 1, tx);

        replenisher.on_item_consumed();
        let (channel_id, additional) = expect_grant(&mut rx, "expected immediate credit grant");
        assert_eq!(channel_id, ChannelId(9));
        assert_eq!(additional, 1);
    }
}
103
/// Concrete `ReplySink` implementation for the driver.
///
/// If dropped without `send_reply` being called, automatically sends
/// `RoamError::Cancelled` to the caller. This guarantees that every
/// request receives exactly one response (`rpc.response.one-per-request`),
/// even if the handler panics or forgets to reply.
pub struct DriverReplySink {
    /// Connection used to send the reply; `None` once it has been taken for a
    /// reply. While it is still `Some`, dropping the sink reports the request
    /// via `mark_failure(request_id, "no reply sent")`.
    sender: Option<ConnectionSender>,
    /// ID of the incoming request this sink answers.
    request_id: RequestId,
    /// Binder returned from `ReplySink::channel_binder`.
    binder: DriverChannelBinder,
}
116impl ReplySink for DriverReplySink {
117    async fn send_reply(mut self, response: RequestResponse<'_>) {
118        let sender = self
119            .sender
120            .take()
121            .expect("unreachable: send_reply takes self by value");
122        if let Err(_e) = sender.send_response(self.request_id, response).await {
123            sender.mark_failure(self.request_id, "send_response failed");
124        }
125    }
126
127    fn channel_binder(&self) -> Option<&dyn ChannelBinder> {
128        Some(&self.binder)
129    }
130}
131
132// r[impl rpc.response.one-per-request]
133impl Drop for DriverReplySink {
134    fn drop(&mut self) {
135        if let Some(sender) = self.sender.take() {
136            sender.mark_failure(self.request_id, "no reply sent")
137        }
138    }
139}
140
// r[impl rpc.channel.item]
// r[impl rpc.channel.close]
/// Concrete [`ChannelSink`] backed by a `ConnectionSender`.
///
/// Created by the driver when setting up outbound channels (Tx handles).
/// Sends `ChannelItem` and `ChannelClose` messages through the connection.
/// Wrapped with [`CreditSink`] to enforce credit-based flow control.
pub struct DriverChannelSink {
    /// Connection the channel messages are written to.
    sender: ConnectionSender,
    /// Channel this sink writes to.
    channel_id: ChannelId,
    /// Used by `close_channel_on_drop`, which cannot await, to ask the driver
    /// loop to send the `Close` message instead.
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
}
153
154impl ChannelSink for DriverChannelSink {
155    fn send_payload<'payload>(
156        &self,
157        payload: Payload<'payload>,
158    ) -> Pin<Box<dyn std::future::Future<Output = Result<(), TxError>> + Send + 'payload>> {
159        let sender = self.sender.clone();
160        let channel_id = self.channel_id;
161        Box::pin(async move {
162            sender
163                .send(ConnectionMessage::Channel(ChannelMessage {
164                    id: channel_id,
165                    body: ChannelBody::Item(ChannelItem { item: payload }),
166                }))
167                .await
168                .map_err(|()| TxError::Transport("connection closed".into()))
169        })
170    }
171
172    fn close_channel(
173        &self,
174        _metadata: roam_types::Metadata,
175    ) -> Pin<Box<dyn std::future::Future<Output = Result<(), TxError>> + Send + 'static>> {
176        // [FIXME] ChannelSink::close_channel takes borrowed Metadata but returns 'static future.
177        // We drop the borrowed metadata and send an empty one. This matches the [FIXME] in the
178        // trait definition — the signature needs to be fixed to take owned metadata.
179        let sender = self.sender.clone();
180        let channel_id = self.channel_id;
181        Box::pin(async move {
182            sender
183                .send(ConnectionMessage::Channel(ChannelMessage {
184                    id: channel_id,
185                    body: ChannelBody::Close(ChannelClose {
186                        metadata: Default::default(),
187                    }),
188                }))
189                .await
190                .map_err(|()| TxError::Transport("connection closed".into()))
191        })
192    }
193
194    fn close_channel_on_drop(&self) {
195        let _ = self
196            .local_control_tx
197            .send(DriverLocalControl::CloseChannel {
198                channel_id: self.channel_id,
199            });
200    }
201}
202
/// Liveness-only handle for a connection root.
///
/// Keeps the root connection alive but intentionally exposes no outbound RPC API.
#[must_use = "Dropping NoopCaller may close the connection if it is the last caller."]
#[derive(Clone)]
// The wrapped caller is never read (hence `dead_code`); it is held only so its
// shared drop guard stays alive as long as this handle does.
pub struct NoopCaller(#[allow(dead_code)] DriverCaller);
209
210impl From<DriverCaller> for NoopCaller {
211    fn from(caller: DriverCaller) -> Self {
212        Self(caller)
213    }
214}
215
/// [`ChannelBinder`] handed to incoming-call handlers through
/// `DriverReplySink::channel_binder`.
///
/// It binds channels against the same [`DriverShared`] registries that
/// [`DriverCaller`] uses.
#[derive(Clone)]
struct DriverChannelBinder {
    /// Outbound half of the connection, cloned into created sinks.
    sender: ConnectionSender,
    /// Registries shared with the driver loop.
    shared: Arc<DriverShared>,
    /// Cloned into created sinks and credit replenishers.
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
    /// Caller drop guard that was already alive when this binder was created
    /// (if any); cloned into the liveness handle of receivers it registers.
    drop_guard: Option<Arc<CallerDropGuard>>,
}
223
224impl DriverChannelBinder {
225    fn create_tx_channel(
226        &self,
227        initial_credit: u32,
228    ) -> (ChannelId, Arc<CreditSink<DriverChannelSink>>) {
229        let channel_id = self.shared.channel_ids.lock().alloc();
230        let inner = DriverChannelSink {
231            sender: self.sender.clone(),
232            channel_id,
233            local_control_tx: self.local_control_tx.clone(),
234        };
235        let sink = Arc::new(CreditSink::new(inner, initial_credit));
236        self.shared
237            .channel_credits
238            .lock()
239            .insert(channel_id, Arc::clone(sink.credit()));
240        (channel_id, sink)
241    }
242
243    fn register_rx_channel(
244        &self,
245        channel_id: ChannelId,
246        initial_credit: u32,
247    ) -> roam_types::BoundChannelReceiver {
248        let (tx, rx) = tokio::sync::mpsc::channel(64);
249        let mut terminal_buffered = false;
250        if let Some(buffered) = self.shared.channel_buffers.lock().remove(&channel_id) {
251            for msg in buffered {
252                let is_terminal = matches!(
253                    msg,
254                    IncomingChannelMessage::Close(_) | IncomingChannelMessage::Reset(_)
255                );
256                let _ = tx.try_send(msg);
257                if is_terminal {
258                    terminal_buffered = true;
259                    break;
260                }
261            }
262        }
263        if terminal_buffered {
264            self.shared.channel_credits.lock().remove(&channel_id);
265            return roam_types::BoundChannelReceiver {
266                receiver: rx,
267                liveness: self.channel_liveness(),
268                replenisher: None,
269            };
270        }
271
272        self.shared.channel_senders.lock().insert(channel_id, tx);
273        roam_types::BoundChannelReceiver {
274            receiver: rx,
275            liveness: self.channel_liveness(),
276            replenisher: Some(Arc::new(DriverChannelCreditReplenisher::new(
277                channel_id,
278                initial_credit,
279                self.local_control_tx.clone(),
280            )) as ChannelCreditReplenisherHandle),
281        }
282    }
283}
284
285impl ChannelBinder for DriverChannelBinder {
286    fn create_tx(&self, initial_credit: u32) -> (ChannelId, Arc<dyn ChannelSink>) {
287        let (id, sink) = self.create_tx_channel(initial_credit);
288        (id, sink as Arc<dyn ChannelSink>)
289    }
290
291    fn create_rx(&self, initial_credit: u32) -> (ChannelId, roam_types::BoundChannelReceiver) {
292        let channel_id = self.shared.channel_ids.lock().alloc();
293        let rx = self.register_rx_channel(channel_id, initial_credit);
294        (channel_id, rx)
295    }
296
297    fn bind_tx(&self, channel_id: ChannelId, initial_credit: u32) -> Arc<dyn ChannelSink> {
298        let inner = DriverChannelSink {
299            sender: self.sender.clone(),
300            channel_id,
301            local_control_tx: self.local_control_tx.clone(),
302        };
303        let sink = Arc::new(CreditSink::new(inner, initial_credit));
304        self.shared
305            .channel_credits
306            .lock()
307            .insert(channel_id, Arc::clone(sink.credit()));
308        sink
309    }
310
311    fn register_rx(
312        &self,
313        channel_id: ChannelId,
314        initial_credit: u32,
315    ) -> roam_types::BoundChannelReceiver {
316        self.register_rx_channel(channel_id, initial_credit)
317    }
318
319    fn channel_liveness(&self) -> Option<ChannelLivenessHandle> {
320        self.drop_guard
321            .as_ref()
322            .map(|guard| guard.clone() as ChannelLivenessHandle)
323    }
324}
325
/// Implements [`Caller`]: allocates a request ID, registers a response slot,
/// sends the call through the connection, and awaits the response.
#[derive(Clone)]
pub struct DriverCaller {
    /// Outbound half of the connection.
    sender: ConnectionSender,
    /// Registries shared with the driver loop (pending responses, ID
    /// allocators, channel state).
    shared: Arc<DriverShared>,
    /// Cloned into the channel sinks and replenishers this caller creates.
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
    /// Shared guard that sends `DropControlRequest::Close` for the connection
    /// once its last holder is dropped; `None` when the driver had no
    /// drop-control sender. Also cloned into channel liveness handles.
    _drop_guard: Option<Arc<CallerDropGuard>>,
}
335
336impl DriverCaller {
337    /// Allocate a channel ID and create a credit-controlled sink for outbound items.
338    ///
339    /// `initial_credit` is the const generic `N` from `Tx<T, N>`.
340    /// The returned sink enforces credit; the semaphore is registered so
341    /// `GrantCredit` messages can add permits.
342    pub fn create_tx_channel(
343        &self,
344        initial_credit: u32,
345    ) -> (ChannelId, Arc<CreditSink<DriverChannelSink>>) {
346        let channel_id = self.shared.channel_ids.lock().alloc();
347        let inner = DriverChannelSink {
348            sender: self.sender.clone(),
349            channel_id,
350            local_control_tx: self.local_control_tx.clone(),
351        };
352        let sink = Arc::new(CreditSink::new(inner, initial_credit));
353        self.shared
354            .channel_credits
355            .lock()
356            .insert(channel_id, Arc::clone(sink.credit()));
357        (channel_id, sink)
358    }
359
360    /// Returns the underlying connection sender.
361    ///
362    /// Used by in-crate tests that need to inject raw messages for cancellation
363    /// and channel protocol testing.
364    #[cfg(test)]
365    pub(crate) fn connection_sender(&self) -> &ConnectionSender {
366        &self.sender
367    }
368
369    /// Register an inbound channel (Rx on our side) and return the receiver.
370    ///
371    /// The channel ID comes from the peer (e.g. from `RequestCall.channels`).
372    /// The returned receiver should be bound to an `Rx` handle via `Rx::bind()`.
373    pub fn register_rx_channel(
374        &self,
375        channel_id: ChannelId,
376        initial_credit: u32,
377    ) -> roam_types::BoundChannelReceiver {
378        let (tx, rx) = tokio::sync::mpsc::channel(64);
379        let mut terminal_buffered = false;
380        // Drain any buffered messages that arrived before registration.
381        if let Some(buffered) = self.shared.channel_buffers.lock().remove(&channel_id) {
382            for msg in buffered {
383                let is_terminal = matches!(
384                    msg,
385                    IncomingChannelMessage::Close(_) | IncomingChannelMessage::Reset(_)
386                );
387                let _ = tx.try_send(msg);
388                if is_terminal {
389                    terminal_buffered = true;
390                    break;
391                }
392            }
393        }
394        if terminal_buffered {
395            self.shared.channel_credits.lock().remove(&channel_id);
396            return roam_types::BoundChannelReceiver {
397                receiver: rx,
398                liveness: self.channel_liveness(),
399                replenisher: None,
400            };
401        }
402
403        self.shared.channel_senders.lock().insert(channel_id, tx);
404        roam_types::BoundChannelReceiver {
405            receiver: rx,
406            liveness: self.channel_liveness(),
407            replenisher: Some(Arc::new(DriverChannelCreditReplenisher::new(
408                channel_id,
409                initial_credit,
410                self.local_control_tx.clone(),
411            )) as ChannelCreditReplenisherHandle),
412        }
413    }
414}
415
416impl ChannelBinder for DriverCaller {
417    fn create_tx(&self, initial_credit: u32) -> (ChannelId, Arc<dyn ChannelSink>) {
418        let (id, sink) = self.create_tx_channel(initial_credit);
419        (id, sink as Arc<dyn ChannelSink>)
420    }
421
422    fn create_rx(&self, initial_credit: u32) -> (ChannelId, roam_types::BoundChannelReceiver) {
423        let channel_id = self.shared.channel_ids.lock().alloc();
424        let rx = self.register_rx_channel(channel_id, initial_credit);
425        (channel_id, rx)
426    }
427
428    fn bind_tx(&self, channel_id: ChannelId, initial_credit: u32) -> Arc<dyn ChannelSink> {
429        let inner = DriverChannelSink {
430            sender: self.sender.clone(),
431            channel_id,
432            local_control_tx: self.local_control_tx.clone(),
433        };
434        let sink = Arc::new(CreditSink::new(inner, initial_credit));
435        self.shared
436            .channel_credits
437            .lock()
438            .insert(channel_id, Arc::clone(sink.credit()));
439        sink
440    }
441
442    fn register_rx(
443        &self,
444        channel_id: ChannelId,
445        initial_credit: u32,
446    ) -> roam_types::BoundChannelReceiver {
447        self.register_rx_channel(channel_id, initial_credit)
448    }
449
450    fn channel_liveness(&self) -> Option<ChannelLivenessHandle> {
451        self._drop_guard
452            .as_ref()
453            .map(|guard| guard.clone() as ChannelLivenessHandle)
454    }
455}
456
457impl Caller for DriverCaller {
458    fn call<'a>(
459        &'a self,
460        call: RequestCall<'a>,
461    ) -> impl std::future::Future<Output = Result<SelfRef<RequestResponse<'static>>, RoamError>>
462    + MaybeSend
463    + 'a {
464        async {
465            // Allocate a request ID.
466            let req_id = self.shared.request_ids.lock().alloc();
467
468            // Register the response slot before sending, so the driver can
469            // route the response even if it arrives before we start awaiting.
470            let (tx, rx) = moire::sync::oneshot::channel("driver.response");
471            self.shared.pending_responses.lock().insert(req_id, tx);
472
473            // Send the call. This awaits the conduit permit and serializes
474            // the borrowed payload all the way to the link's write buffer.
475            let send_result = self
476                .sender
477                .send(ConnectionMessage::Request(RequestMessage {
478                    id: req_id,
479                    body: RequestBody::Call(call),
480                }))
481                .await;
482
483            if send_result.is_err() {
484                // Clean up the pending slot.
485                self.shared.pending_responses.lock().remove(&req_id);
486                return Err(RoamError::Cancelled);
487            }
488
489            // Await the response from the driver loop.
490            let response_msg: SelfRef<RequestMessage<'static>> = rx
491                .named("awaiting_response")
492                .await
493                .map_err(|_| RoamError::Cancelled)?;
494
495            // Extract the Response variant from the RequestMessage.
496            let response = response_msg.map(|m| match m.body {
497                RequestBody::Response(r) => r,
498                _ => unreachable!("pending_responses only gets Response variants"),
499            });
500
501            Ok(response)
502        }
503        .named("Caller::call")
504    }
505
506    fn channel_binder(&self) -> Option<&dyn ChannelBinder> {
507        Some(self)
508    }
509}
510
// r[impl rpc.handler]
// r[impl rpc.request]
// r[impl rpc.response]
// r[impl rpc.pipelining]
/// Per-connection driver. Handles in-flight request tracking, dispatches
/// incoming calls to a Handler, and manages channel state/flow control.
pub struct Driver<H: Handler<DriverReplySink>> {
    /// Outbound half of the connection, cloned into callers, binders and reply sinks.
    sender: ConnectionSender,
    /// Inbound messages for this connection; `run` exits when this yields `None`.
    rx: mpsc::Receiver<SelfRef<ConnectionMessage<'static>>>,
    /// Requests reported as failed, with a reason string.
    /// NOTE(review): presumably fed by `ConnectionSender::mark_failure` — confirm
    /// in the session module.
    failures_rx: mpsc::UnboundedReceiver<(RequestId, &'static str)>,
    /// Channel messages that synchronous code asks the driver loop to send.
    local_control_rx: mpsc::UnboundedReceiver<DriverLocalControl>,
    /// Handler invoked on a spawned task for each incoming call.
    handler: Arc<H>,
    /// Registries shared with callers and binders.
    shared: Arc<DriverShared>,
    /// In-flight server-side handler tasks, keyed by request ID.
    /// Used to abort handlers on cancel.
    in_flight_handlers: BTreeMap<RequestId, moire::task::JoinHandle<()>>,
    /// Sending half of `local_control_rx`, cloned into sinks and replenishers.
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
    /// Control sender used to build the shared `CallerDropGuard`; `None` means
    /// callers get no drop guard.
    drop_control_seed: Option<mpsc::UnboundedSender<DropControlRequest>>,
    /// Request the drop guard sends (`DropControlRequest::Close` for this connection).
    drop_control_request: DropControlRequest,
    /// Weak reference to the currently shared drop guard, so callers created
    /// while it is alive reuse it instead of creating another.
    drop_guard: SyncMutex<Option<Weak<CallerDropGuard>>>,
}
532
/// Channel messages that synchronous code (drop paths, credit replenishers)
/// cannot send itself; the driver loop sends them in `handle_local_control`.
enum DriverLocalControl {
    /// Send a `ChannelBody::Close` with empty metadata for `channel_id`.
    /// Emitted by `DriverChannelSink::close_channel_on_drop`.
    CloseChannel {
        channel_id: ChannelId,
    },
    /// Send a `ChannelBody::GrantCredit` of `additional` items for `channel_id`.
    /// Emitted by `DriverChannelCreditReplenisher` as items are consumed.
    GrantCredit {
        channel_id: ChannelId,
        additional: u32,
    },
}
542
/// Batches consumed-item notifications for one inbound channel into
/// `GrantCredit` requests for the driver loop.
struct DriverChannelCreditReplenisher {
    /// Channel the granted credit applies to.
    channel_id: ChannelId,
    /// Consumed-item count that triggers a grant (`initial_credit / 2`, at least 1).
    threshold: u32,
    /// Where `GrantCredit` requests are sent.
    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
    /// Items consumed since the last grant.
    pending: std::sync::Mutex<u32>,
}
549
550impl DriverChannelCreditReplenisher {
551    fn new(
552        channel_id: ChannelId,
553        initial_credit: u32,
554        local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
555    ) -> Self {
556        Self {
557            channel_id,
558            threshold: (initial_credit / 2).max(1),
559            local_control_tx,
560            pending: std::sync::Mutex::new(0),
561        }
562    }
563}
564
565impl ChannelCreditReplenisher for DriverChannelCreditReplenisher {
566    fn on_item_consumed(&self) {
567        let mut pending = self.pending.lock().expect("pending credit mutex poisoned");
568        *pending += 1;
569        if *pending < self.threshold {
570            return;
571        }
572
573        let additional = *pending;
574        *pending = 0;
575        let _ = self.local_control_tx.send(DriverLocalControl::GrantCredit {
576            channel_id: self.channel_id,
577            additional,
578        });
579    }
580}
581
582impl<H: Handler<DriverReplySink>> Driver<H> {
583    pub fn new(handle: ConnectionHandle, handler: H) -> Self {
584        let conn_id = handle.connection_id();
585        let ConnectionHandle {
586            sender,
587            rx,
588            failures_rx,
589            control_tx,
590            parity,
591        } = handle;
592        let drop_control_request = DropControlRequest::Close(conn_id);
593        let (local_control_tx, local_control_rx) = mpsc::unbounded_channel("driver.local_control");
594        Self {
595            sender,
596            rx,
597            failures_rx,
598            local_control_rx,
599            handler: Arc::new(handler),
600            shared: Arc::new(DriverShared {
601                pending_responses: SyncMutex::new("driver.pending_responses", BTreeMap::new()),
602                request_ids: SyncMutex::new("driver.request_ids", IdAllocator::new(parity)),
603                channel_ids: SyncMutex::new("driver.channel_ids", IdAllocator::new(parity)),
604                channel_senders: SyncMutex::new("driver.channel_senders", BTreeMap::new()),
605                channel_buffers: SyncMutex::new("driver.channel_buffers", BTreeMap::new()),
606                channel_credits: SyncMutex::new("driver.channel_credits", BTreeMap::new()),
607            }),
608            in_flight_handlers: BTreeMap::new(),
609            local_control_tx,
610            drop_control_seed: control_tx,
611            drop_control_request,
612            drop_guard: SyncMutex::new("driver.drop_guard", None),
613        }
614    }
615
616    /// Get a cloneable caller handle for making outgoing calls.
617    // r[impl rpc.caller.liveness.refcounted]
618    // r[impl rpc.caller.liveness.last-drop-closes-connection]
619    // r[impl rpc.caller.liveness.root-internal-close]
620    // r[impl rpc.caller.liveness.root-teardown-condition]
621    fn existing_drop_guard(&self) -> Option<Arc<CallerDropGuard>> {
622        self.drop_guard.lock().as_ref().and_then(Weak::upgrade)
623    }
624
625    fn connection_drop_guard(&self) -> Option<Arc<CallerDropGuard>> {
626        let drop_guard = if let Some(existing) = self.existing_drop_guard() {
627            Some(existing)
628        } else if let Some(seed) = &self.drop_control_seed {
629            let mut guard = self.drop_guard.lock();
630            if let Some(existing) = guard.as_ref().and_then(Weak::upgrade) {
631                Some(existing)
632            } else {
633                let arc = Arc::new(CallerDropGuard {
634                    control_tx: seed.clone(),
635                    request: self.drop_control_request,
636                });
637                *guard = Some(Arc::downgrade(&arc));
638                Some(arc)
639            }
640        } else {
641            None
642        };
643        drop_guard
644    }
645
646    pub fn caller(&self) -> DriverCaller {
647        let drop_guard = self.connection_drop_guard();
648        DriverCaller {
649            sender: self.sender.clone(),
650            shared: Arc::clone(&self.shared),
651            local_control_tx: self.local_control_tx.clone(),
652            _drop_guard: drop_guard,
653        }
654    }
655
656    fn internal_binder(&self) -> DriverChannelBinder {
657        DriverChannelBinder {
658            sender: self.sender.clone(),
659            shared: Arc::clone(&self.shared),
660            local_control_tx: self.local_control_tx.clone(),
661            drop_guard: self.existing_drop_guard(),
662        }
663    }
664
665    // r[impl rpc.pipelining]
666    /// Main loop: receive messages from the session and dispatch them.
667    /// Handler calls run as spawned tasks — we don't block the driver
668    /// loop waiting for a handler to finish.
669    pub async fn run(&mut self) {
670        loop {
671            tokio::select! {
672                msg = self.rx.recv() => {
673                    match msg {
674                        Some(msg) => self.handle_msg(msg),
675                        None => break,
676                    }
677                }
678                Some((req_id, _reason)) = self.failures_rx.recv() => {
679                    // Clean up the handler tracking entry.
680                    self.in_flight_handlers.remove(&req_id);
681                    if self.shared.pending_responses.lock().remove(&req_id).is_none() {
682                        // Incoming call — handler failed to reply.
683                        // Wire format is always Result<T, RoamError<E>>, so encode
684                        // Cancelled as Err(...) in that envelope.
685                        let error: Result<(), RoamError<core::convert::Infallible>> =
686                            Err(RoamError::Cancelled);
687                        let _ = self.sender.send_response(req_id, RequestResponse {
688                            ret: Payload::outgoing(&error),
689                            channels: vec![],
690                            metadata: Default::default(),
691                        }).await;
692                    }
693                }
694                Some(ctrl) = self.local_control_rx.recv() => {
695                    self.handle_local_control(ctrl).await;
696                }
697            }
698        }
699
700        for (_, handle) in std::mem::take(&mut self.in_flight_handlers) {
701            handle.abort();
702        }
703        self.shared.pending_responses.lock().clear();
704
705        // Connection is gone: drop channel runtime state so any registered Rx
706        // receivers observe closure instead of hanging on recv().
707        self.shared.channel_senders.lock().clear();
708        self.shared.channel_buffers.lock().clear();
709        self.shared.channel_credits.lock().clear();
710    }
711
712    async fn handle_local_control(&mut self, control: DriverLocalControl) {
713        match control {
714            DriverLocalControl::CloseChannel { channel_id } => {
715                let _ = self
716                    .sender
717                    .send(ConnectionMessage::Channel(ChannelMessage {
718                        id: channel_id,
719                        body: ChannelBody::Close(ChannelClose {
720                            metadata: Default::default(),
721                        }),
722                    }))
723                    .await;
724            }
725            DriverLocalControl::GrantCredit {
726                channel_id,
727                additional,
728            } => {
729                let _ = self
730                    .sender
731                    .send(ConnectionMessage::Channel(ChannelMessage {
732                        id: channel_id,
733                        body: ChannelBody::GrantCredit(roam_types::ChannelGrantCredit {
734                            additional,
735                        }),
736                    }))
737                    .await;
738            }
739        }
740    }
741
742    fn handle_msg(&mut self, msg: SelfRef<ConnectionMessage<'static>>) {
743        let is_request = matches!(&*msg, ConnectionMessage::Request(_));
744        if is_request {
745            let msg = msg.map(|m| match m {
746                ConnectionMessage::Request(r) => r,
747                _ => unreachable!(),
748            });
749            self.handle_request(msg);
750        } else {
751            let msg = msg.map(|m| match m {
752                ConnectionMessage::Channel(c) => c,
753                _ => unreachable!(),
754            });
755            self.handle_channel(msg);
756        }
757    }
758
759    fn handle_request(&mut self, msg: SelfRef<RequestMessage<'static>>) {
760        let req_id = msg.id;
761        let is_call = matches!(&msg.body, RequestBody::Call(_));
762        let is_response = matches!(&msg.body, RequestBody::Response(_));
763        let is_cancel = matches!(&msg.body, RequestBody::Cancel(_));
764
765        if is_call {
766            // r[impl rpc.request]
767            // r[impl rpc.error.scope]
768            let reply = DriverReplySink {
769                sender: Some(self.sender.clone()),
770                request_id: req_id,
771                binder: self.internal_binder(),
772            };
773            let call = msg.map(|m| match m.body {
774                RequestBody::Call(c) => c,
775                _ => unreachable!(),
776            });
777            let handler = Arc::clone(&self.handler);
778            let join_handle = moire::task::spawn(
779                async move {
780                    handler.handle(call, reply).await;
781                }
782                .named("handler"),
783            );
784            self.in_flight_handlers.insert(req_id, join_handle);
785        } else if is_response {
786            // r[impl rpc.response.one-per-request]
787            if let Some(tx) = self.shared.pending_responses.lock().remove(&req_id) {
788                let _: Result<(), _> = tx.send(msg);
789            }
790        } else if is_cancel {
791            // r[impl rpc.cancel]
792            // r[impl rpc.cancel.channels]
793            // Abort the in-flight handler task. Channels are intentionally left
794            // intact — they have independent lifecycles per spec.
795            if let Some(handle) = self.in_flight_handlers.remove(&req_id) {
796                handle.abort();
797            }
798            // The response is sent automatically: aborting drops DriverReplySink →
799            // mark_failure fires → failures_rx arm sends RoamError::Cancelled.
800        }
801    }
802
803    fn handle_channel(&mut self, msg: SelfRef<ChannelMessage<'static>>) {
804        let chan_id = msg.id;
805
806        // Look up the channel sender from the shared registry (handles registered
807        // by both the driver and any DriverCaller that set up channels).
808        let sender = self.shared.channel_senders.lock().get(&chan_id).cloned();
809
810        match &msg.body {
811            // r[impl rpc.channel.item]
812            ChannelBody::Item(_item) => {
813                if let Some(tx) = &sender {
814                    let item = msg.map(|m| match m.body {
815                        ChannelBody::Item(item) => item,
816                        _ => unreachable!(),
817                    });
818                    // try_send: if the Rx has been dropped or the buffer is full, drop the item.
819                    let _ = tx.try_send(IncomingChannelMessage::Item(item));
820                } else {
821                    // Channel not yet registered — buffer until register_rx_channel is called.
822                    let item = msg.map(|m| match m.body {
823                        ChannelBody::Item(item) => item,
824                        _ => unreachable!(),
825                    });
826                    self.shared
827                        .channel_buffers
828                        .lock()
829                        .entry(chan_id)
830                        .or_default()
831                        .push(IncomingChannelMessage::Item(item));
832                }
833            }
834            // r[impl rpc.channel.close]
835            ChannelBody::Close(_close) => {
836                if let Some(tx) = &sender {
837                    let close = msg.map(|m| match m.body {
838                        ChannelBody::Close(close) => close,
839                        _ => unreachable!(),
840                    });
841                    let _ = tx.try_send(IncomingChannelMessage::Close(close));
842                } else {
843                    // Channel not yet registered — buffer the close.
844                    let close = msg.map(|m| match m.body {
845                        ChannelBody::Close(close) => close,
846                        _ => unreachable!(),
847                    });
848                    self.shared
849                        .channel_buffers
850                        .lock()
851                        .entry(chan_id)
852                        .or_default()
853                        .push(IncomingChannelMessage::Close(close));
854                }
855                self.shared.channel_senders.lock().remove(&chan_id);
856                self.shared.channel_credits.lock().remove(&chan_id);
857            }
858            // r[impl rpc.channel.reset]
859            ChannelBody::Reset(_reset) => {
860                if let Some(tx) = &sender {
861                    let reset = msg.map(|m| match m.body {
862                        ChannelBody::Reset(reset) => reset,
863                        _ => unreachable!(),
864                    });
865                    let _ = tx.try_send(IncomingChannelMessage::Reset(reset));
866                } else {
867                    // Channel not yet registered — buffer the reset.
868                    let reset = msg.map(|m| match m.body {
869                        ChannelBody::Reset(reset) => reset,
870                        _ => unreachable!(),
871                    });
872                    self.shared
873                        .channel_buffers
874                        .lock()
875                        .entry(chan_id)
876                        .or_default()
877                        .push(IncomingChannelMessage::Reset(reset));
878                }
879                self.shared.channel_senders.lock().remove(&chan_id);
880                self.shared.channel_credits.lock().remove(&chan_id);
881            }
882            // r[impl rpc.flow-control.credit.grant]
883            // r[impl rpc.flow-control.credit.grant.additive]
884            ChannelBody::GrantCredit(grant) => {
885                if let Some(semaphore) = self.shared.channel_credits.lock().get(&chan_id) {
886                    semaphore.add_permits(grant.additional as usize);
887                }
888            }
889        }
890    }
891}