Skip to main content

roam_core/
driver.rs

1use std::{
2    collections::BTreeMap,
3    pin::Pin,
4    sync::{Arc, Weak},
5};
6
7use moire::sync::SyncMutex;
8use tokio::sync::{Semaphore, watch};
9
10use moire::task::FutureExt as _;
11use roam_types::{
12    Caller, ChannelBinder, ChannelBody, ChannelClose, ChannelCreditReplenisher,
13    ChannelCreditReplenisherHandle, ChannelId, ChannelItem, ChannelLivenessHandle, ChannelMessage,
14    ChannelSink, CreditSink, Handler, IdAllocator, IncomingChannelMessage, MaybeSend, Payload,
15    ReplySink, RequestBody, RequestCall, RequestId, RequestMessage, RequestResponse, RoamError,
16    SelfRef, TxError,
17};
18
19use crate::session::{ConnectionHandle, ConnectionMessage, ConnectionSender, DropControlRequest};
20use moire::sync::mpsc;
21
22type ResponseSlot = moire::sync::oneshot::Sender<SelfRef<RequestMessage<'static>>>;
23
24/// State shared between the driver loop and any DriverCaller/DriverChannelSink handles.
25struct DriverShared {
26    pending_responses: SyncMutex<BTreeMap<RequestId, ResponseSlot>>,
27    request_ids: SyncMutex<IdAllocator<RequestId>>,
28    channel_ids: SyncMutex<IdAllocator<ChannelId>>,
29    /// Registry mapping inbound channel IDs to the sender that feeds the Rx handle.
30    channel_senders:
31        SyncMutex<BTreeMap<ChannelId, tokio::sync::mpsc::Sender<IncomingChannelMessage>>>,
32    /// Buffer for channel messages that arrive before the channel is registered.
33    ///
34    /// This handles the race between the caller sending items immediately after
35    /// `bind_channels_caller_args` creates the sink, and the callee's handler task
36    /// calling `register_rx` via `bind_channels_callee_args`. Items arriving in
37    /// that window are buffered here and drained when the channel is registered.
38    channel_buffers: SyncMutex<BTreeMap<ChannelId, Vec<IncomingChannelMessage>>>,
39    /// Credit semaphores for outbound channels (Tx on our side).
40    /// The driver's GrantCredit handler adds permits to these.
41    channel_credits: SyncMutex<BTreeMap<ChannelId, Arc<Semaphore>>>,
42}
43
44struct CallerDropGuard {
45    control_tx: mpsc::UnboundedSender<DropControlRequest>,
46    request: DropControlRequest,
47}
48
49impl Drop for CallerDropGuard {
50    fn drop(&mut self) {
51        let _ = self.control_tx.send(self.request);
52    }
53}
54
#[cfg(test)]
mod tests {
    use super::{DriverChannelCreditReplenisher, DriverLocalControl};
    use roam_types::{ChannelCreditReplenisher, ChannelId};
    use tokio::sync::mpsc::error::TryRecvError;

    /// With a window of 16, credit is returned in a single batch of 8 once the
    /// eighth item has been consumed, and not before.
    #[test]
    fn replenisher_batches_at_half_the_initial_window() {
        let (tx, mut rx) = moire::sync::mpsc::unbounded_channel("test.replenisher");
        let replenisher = DriverChannelCreditReplenisher::new(ChannelId(7), 16, tx);

        (0..7).for_each(|_| replenisher.on_item_consumed());
        assert!(
            matches!(rx.try_recv(), Err(TryRecvError::Empty)),
            "should not emit credit before reaching the batch threshold"
        );

        replenisher.on_item_consumed();
        match rx.try_recv() {
            Ok(DriverLocalControl::GrantCredit {
                channel_id,
                additional,
            }) => {
                assert_eq!(channel_id, ChannelId(7));
                assert_eq!(additional, 8);
            }
            _ => panic!("expected batched credit grant"),
        }
    }

    /// A window of 1 has a threshold of 1, so every consumed item is granted back immediately.
    #[test]
    fn replenisher_grants_one_by_one_for_single_credit_windows() {
        let (tx, mut rx) = moire::sync::mpsc::unbounded_channel("test.replenisher.single");
        let replenisher = DriverChannelCreditReplenisher::new(ChannelId(9), 1, tx);

        replenisher.on_item_consumed();
        match rx.try_recv() {
            Ok(DriverLocalControl::GrantCredit {
                channel_id,
                additional,
            }) => {
                assert_eq!(channel_id, ChannelId(9));
                assert_eq!(additional, 1);
            }
            _ => panic!("expected immediate credit grant"),
        }
    }
}
103
104/// Concrete `ReplySink` implementation for the driver.
105///
106/// If dropped without `send_reply` being called, automatically sends
107/// `RoamError::Cancelled` to the caller. This guarantees that every
108/// request receives exactly one response (`rpc.response.one-per-request`),
109/// even if the handler panics or forgets to reply.
110pub struct DriverReplySink {
111    sender: Option<ConnectionSender>,
112    request_id: RequestId,
113    binder: DriverChannelBinder,
114}
115
116impl ReplySink for DriverReplySink {
117    async fn send_reply(mut self, response: RequestResponse<'_>) {
118        let sender = self
119            .sender
120            .take()
121            .expect("unreachable: send_reply takes self by value");
122        if let Err(_e) = sender.send_response(self.request_id, response).await {
123            sender.mark_failure(self.request_id, "send_response failed");
124        }
125    }
126
127    fn channel_binder(&self) -> Option<&dyn ChannelBinder> {
128        Some(&self.binder)
129    }
130}
131
132// r[impl rpc.response.one-per-request]
133impl Drop for DriverReplySink {
134    fn drop(&mut self) {
135        if let Some(sender) = self.sender.take() {
136            sender.mark_failure(self.request_id, "no reply sent")
137        }
138    }
139}
140
141// r[impl rpc.channel.item]
142// r[impl rpc.channel.close]
143/// Concrete [`ChannelSink`] backed by a `ConnectionSender`.
144///
145/// Created by the driver when setting up outbound channels (Tx handles).
146/// Sends `ChannelItem` and `ChannelClose` messages through the connection.
147/// Wrapped with [`CreditSink`] to enforce credit-based flow control.
148pub struct DriverChannelSink {
149    sender: ConnectionSender,
150    channel_id: ChannelId,
151    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
152}
153
154impl ChannelSink for DriverChannelSink {
155    fn send_payload<'payload>(
156        &self,
157        payload: Payload<'payload>,
158    ) -> Pin<Box<dyn std::future::Future<Output = Result<(), TxError>> + Send + 'payload>> {
159        let sender = self.sender.clone();
160        let channel_id = self.channel_id;
161        Box::pin(async move {
162            sender
163                .send(ConnectionMessage::Channel(ChannelMessage {
164                    id: channel_id,
165                    body: ChannelBody::Item(ChannelItem { item: payload }),
166                }))
167                .await
168                .map_err(|()| TxError::Transport("connection closed".into()))
169        })
170    }
171
172    fn close_channel(
173        &self,
174        _metadata: roam_types::Metadata,
175    ) -> Pin<Box<dyn std::future::Future<Output = Result<(), TxError>> + Send + 'static>> {
176        // [FIXME] ChannelSink::close_channel takes borrowed Metadata but returns 'static future.
177        // We drop the borrowed metadata and send an empty one. This matches the [FIXME] in the
178        // trait definition — the signature needs to be fixed to take owned metadata.
179        let sender = self.sender.clone();
180        let channel_id = self.channel_id;
181        Box::pin(async move {
182            sender
183                .send(ConnectionMessage::Channel(ChannelMessage {
184                    id: channel_id,
185                    body: ChannelBody::Close(ChannelClose {
186                        metadata: Default::default(),
187                    }),
188                }))
189                .await
190                .map_err(|()| TxError::Transport("connection closed".into()))
191        })
192    }
193
194    fn close_channel_on_drop(&self) {
195        let _ = self
196            .local_control_tx
197            .send(DriverLocalControl::CloseChannel {
198                channel_id: self.channel_id,
199            });
200    }
201}
202
203/// Liveness-only handle for a connection root.
204///
205/// Keeps the root connection alive but intentionally exposes no outbound RPC API.
206#[must_use = "Dropping NoopCaller may close the connection if it is the last caller."]
207#[derive(Clone)]
208pub struct NoopCaller(#[allow(dead_code)] DriverCaller);
209
210impl From<DriverCaller> for NoopCaller {
211    fn from(caller: DriverCaller) -> Self {
212        Self(caller)
213    }
214}
215
216#[derive(Clone)]
217struct DriverChannelBinder {
218    sender: ConnectionSender,
219    shared: Arc<DriverShared>,
220    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
221    drop_guard: Option<Arc<CallerDropGuard>>,
222}
223
224impl DriverChannelBinder {
225    fn create_tx_channel(
226        &self,
227        initial_credit: u32,
228    ) -> (ChannelId, Arc<CreditSink<DriverChannelSink>>) {
229        let channel_id = self.shared.channel_ids.lock().alloc();
230        let inner = DriverChannelSink {
231            sender: self.sender.clone(),
232            channel_id,
233            local_control_tx: self.local_control_tx.clone(),
234        };
235        let sink = Arc::new(CreditSink::new(inner, initial_credit));
236        self.shared
237            .channel_credits
238            .lock()
239            .insert(channel_id, Arc::clone(sink.credit()));
240        (channel_id, sink)
241    }
242
243    fn register_rx_channel(
244        &self,
245        channel_id: ChannelId,
246        initial_credit: u32,
247    ) -> roam_types::BoundChannelReceiver {
248        let (tx, rx) = tokio::sync::mpsc::channel(64);
249        let mut terminal_buffered = false;
250        if let Some(buffered) = self.shared.channel_buffers.lock().remove(&channel_id) {
251            for msg in buffered {
252                let is_terminal = matches!(
253                    msg,
254                    IncomingChannelMessage::Close(_) | IncomingChannelMessage::Reset(_)
255                );
256                let _ = tx.try_send(msg);
257                if is_terminal {
258                    terminal_buffered = true;
259                    break;
260                }
261            }
262        }
263        if terminal_buffered {
264            self.shared.channel_credits.lock().remove(&channel_id);
265            return roam_types::BoundChannelReceiver {
266                receiver: rx,
267                liveness: self.channel_liveness(),
268                replenisher: None,
269            };
270        }
271
272        self.shared.channel_senders.lock().insert(channel_id, tx);
273        roam_types::BoundChannelReceiver {
274            receiver: rx,
275            liveness: self.channel_liveness(),
276            replenisher: Some(Arc::new(DriverChannelCreditReplenisher::new(
277                channel_id,
278                initial_credit,
279                self.local_control_tx.clone(),
280            )) as ChannelCreditReplenisherHandle),
281        }
282    }
283}
284
285impl ChannelBinder for DriverChannelBinder {
286    fn create_tx(&self, initial_credit: u32) -> (ChannelId, Arc<dyn ChannelSink>) {
287        let (id, sink) = self.create_tx_channel(initial_credit);
288        (id, sink as Arc<dyn ChannelSink>)
289    }
290
291    fn create_rx(&self, initial_credit: u32) -> (ChannelId, roam_types::BoundChannelReceiver) {
292        let channel_id = self.shared.channel_ids.lock().alloc();
293        let rx = self.register_rx_channel(channel_id, initial_credit);
294        (channel_id, rx)
295    }
296
297    fn bind_tx(&self, channel_id: ChannelId, initial_credit: u32) -> Arc<dyn ChannelSink> {
298        let inner = DriverChannelSink {
299            sender: self.sender.clone(),
300            channel_id,
301            local_control_tx: self.local_control_tx.clone(),
302        };
303        let sink = Arc::new(CreditSink::new(inner, initial_credit));
304        self.shared
305            .channel_credits
306            .lock()
307            .insert(channel_id, Arc::clone(sink.credit()));
308        sink
309    }
310
311    fn register_rx(
312        &self,
313        channel_id: ChannelId,
314        initial_credit: u32,
315    ) -> roam_types::BoundChannelReceiver {
316        self.register_rx_channel(channel_id, initial_credit)
317    }
318
319    fn channel_liveness(&self) -> Option<ChannelLivenessHandle> {
320        self.drop_guard
321            .as_ref()
322            .map(|guard| guard.clone() as ChannelLivenessHandle)
323    }
324}
325
326/// Implements [`Caller`]: allocates a request ID, registers a response slot,
327/// sends the call through the connection, and awaits the response.
328#[derive(Clone)]
329pub struct DriverCaller {
330    sender: ConnectionSender,
331    shared: Arc<DriverShared>,
332    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
333    closed_rx: watch::Receiver<bool>,
334    _drop_guard: Option<Arc<CallerDropGuard>>,
335}
336
337impl DriverCaller {
338    /// Allocate a channel ID and create a credit-controlled sink for outbound items.
339    ///
340    /// `initial_credit` is the const generic `N` from `Tx<T, N>`.
341    /// The returned sink enforces credit; the semaphore is registered so
342    /// `GrantCredit` messages can add permits.
343    pub fn create_tx_channel(
344        &self,
345        initial_credit: u32,
346    ) -> (ChannelId, Arc<CreditSink<DriverChannelSink>>) {
347        let channel_id = self.shared.channel_ids.lock().alloc();
348        let inner = DriverChannelSink {
349            sender: self.sender.clone(),
350            channel_id,
351            local_control_tx: self.local_control_tx.clone(),
352        };
353        let sink = Arc::new(CreditSink::new(inner, initial_credit));
354        self.shared
355            .channel_credits
356            .lock()
357            .insert(channel_id, Arc::clone(sink.credit()));
358        (channel_id, sink)
359    }
360
361    /// Returns the underlying connection sender.
362    ///
363    /// Used by in-crate tests that need to inject raw messages for cancellation
364    /// and channel protocol testing.
365    #[cfg(test)]
366    pub(crate) fn connection_sender(&self) -> &ConnectionSender {
367        &self.sender
368    }
369
370    /// Register an inbound channel (Rx on our side) and return the receiver.
371    ///
372    /// The channel ID comes from the peer (e.g. from `RequestCall.channels`).
373    /// The returned receiver should be bound to an `Rx` handle via `Rx::bind()`.
374    pub fn register_rx_channel(
375        &self,
376        channel_id: ChannelId,
377        initial_credit: u32,
378    ) -> roam_types::BoundChannelReceiver {
379        let (tx, rx) = tokio::sync::mpsc::channel(64);
380        let mut terminal_buffered = false;
381        // Drain any buffered messages that arrived before registration.
382        if let Some(buffered) = self.shared.channel_buffers.lock().remove(&channel_id) {
383            for msg in buffered {
384                let is_terminal = matches!(
385                    msg,
386                    IncomingChannelMessage::Close(_) | IncomingChannelMessage::Reset(_)
387                );
388                let _ = tx.try_send(msg);
389                if is_terminal {
390                    terminal_buffered = true;
391                    break;
392                }
393            }
394        }
395        if terminal_buffered {
396            self.shared.channel_credits.lock().remove(&channel_id);
397            return roam_types::BoundChannelReceiver {
398                receiver: rx,
399                liveness: self.channel_liveness(),
400                replenisher: None,
401            };
402        }
403
404        self.shared.channel_senders.lock().insert(channel_id, tx);
405        roam_types::BoundChannelReceiver {
406            receiver: rx,
407            liveness: self.channel_liveness(),
408            replenisher: Some(Arc::new(DriverChannelCreditReplenisher::new(
409                channel_id,
410                initial_credit,
411                self.local_control_tx.clone(),
412            )) as ChannelCreditReplenisherHandle),
413        }
414    }
415}
416
417impl ChannelBinder for DriverCaller {
418    fn create_tx(&self, initial_credit: u32) -> (ChannelId, Arc<dyn ChannelSink>) {
419        let (id, sink) = self.create_tx_channel(initial_credit);
420        (id, sink as Arc<dyn ChannelSink>)
421    }
422
423    fn create_rx(&self, initial_credit: u32) -> (ChannelId, roam_types::BoundChannelReceiver) {
424        let channel_id = self.shared.channel_ids.lock().alloc();
425        let rx = self.register_rx_channel(channel_id, initial_credit);
426        (channel_id, rx)
427    }
428
429    fn bind_tx(&self, channel_id: ChannelId, initial_credit: u32) -> Arc<dyn ChannelSink> {
430        let inner = DriverChannelSink {
431            sender: self.sender.clone(),
432            channel_id,
433            local_control_tx: self.local_control_tx.clone(),
434        };
435        let sink = Arc::new(CreditSink::new(inner, initial_credit));
436        self.shared
437            .channel_credits
438            .lock()
439            .insert(channel_id, Arc::clone(sink.credit()));
440        sink
441    }
442
443    fn register_rx(
444        &self,
445        channel_id: ChannelId,
446        initial_credit: u32,
447    ) -> roam_types::BoundChannelReceiver {
448        self.register_rx_channel(channel_id, initial_credit)
449    }
450
451    fn channel_liveness(&self) -> Option<ChannelLivenessHandle> {
452        self._drop_guard
453            .as_ref()
454            .map(|guard| guard.clone() as ChannelLivenessHandle)
455    }
456}
457
458impl Caller for DriverCaller {
459    fn call<'a>(
460        &'a self,
461        call: RequestCall<'a>,
462    ) -> impl std::future::Future<Output = Result<SelfRef<RequestResponse<'static>>, RoamError>>
463    + MaybeSend
464    + 'a {
465        async {
466            // Allocate a request ID.
467            let req_id = self.shared.request_ids.lock().alloc();
468
469            // Register the response slot before sending, so the driver can
470            // route the response even if it arrives before we start awaiting.
471            let (tx, rx) = moire::sync::oneshot::channel("driver.response");
472            self.shared.pending_responses.lock().insert(req_id, tx);
473
474            // Send the call. This awaits the conduit permit and serializes
475            // the borrowed payload all the way to the link's write buffer.
476            let send_result = self
477                .sender
478                .send(ConnectionMessage::Request(RequestMessage {
479                    id: req_id,
480                    body: RequestBody::Call(call),
481                }))
482                .await;
483
484            if send_result.is_err() {
485                // Clean up the pending slot.
486                self.shared.pending_responses.lock().remove(&req_id);
487                return Err(RoamError::Cancelled);
488            }
489
490            // Await the response from the driver loop.
491            let response_msg: SelfRef<RequestMessage<'static>> = rx
492                .named("awaiting_response")
493                .await
494                .map_err(|_| RoamError::Cancelled)?;
495
496            // Extract the Response variant from the RequestMessage.
497            let response = response_msg.map(|m| match m.body {
498                RequestBody::Response(r) => r,
499                _ => unreachable!("pending_responses only gets Response variants"),
500            });
501
502            Ok(response)
503        }
504        .named("Caller::call")
505    }
506
507    fn closed(&self) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> {
508        Box::pin(async move {
509            if *self.closed_rx.borrow() {
510                return;
511            }
512            let mut rx = self.closed_rx.clone();
513            while rx.changed().await.is_ok() {
514                if *rx.borrow() {
515                    return;
516                }
517            }
518        })
519    }
520
521    fn is_connected(&self) -> bool {
522        !*self.closed_rx.borrow()
523    }
524
525    fn channel_binder(&self) -> Option<&dyn ChannelBinder> {
526        Some(self)
527    }
528}
529
530// r[impl rpc.handler]
531// r[impl rpc.request]
532// r[impl rpc.response]
533// r[impl rpc.pipelining]
534/// Per-connection driver. Handles in-flight request tracking, dispatches
535/// incoming calls to a Handler, and manages channel state/flow control.
536pub struct Driver<H: Handler<DriverReplySink>> {
537    sender: ConnectionSender,
538    rx: mpsc::Receiver<SelfRef<ConnectionMessage<'static>>>,
539    failures_rx: mpsc::UnboundedReceiver<(RequestId, &'static str)>,
540    closed_rx: watch::Receiver<bool>,
541    local_control_rx: mpsc::UnboundedReceiver<DriverLocalControl>,
542    handler: Arc<H>,
543    shared: Arc<DriverShared>,
544    /// In-flight server-side handler tasks, keyed by request ID.
545    /// Used to abort handlers on cancel.
546    in_flight_handlers: BTreeMap<RequestId, moire::task::JoinHandle<()>>,
547    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
548    drop_control_seed: Option<mpsc::UnboundedSender<DropControlRequest>>,
549    drop_control_request: DropControlRequest,
550    drop_guard: SyncMutex<Option<Weak<CallerDropGuard>>>,
551}
552
553enum DriverLocalControl {
554    CloseChannel {
555        channel_id: ChannelId,
556    },
557    GrantCredit {
558        channel_id: ChannelId,
559        additional: u32,
560    },
561}
562
563struct DriverChannelCreditReplenisher {
564    channel_id: ChannelId,
565    threshold: u32,
566    local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
567    pending: std::sync::Mutex<u32>,
568}
569
570impl DriverChannelCreditReplenisher {
571    fn new(
572        channel_id: ChannelId,
573        initial_credit: u32,
574        local_control_tx: mpsc::UnboundedSender<DriverLocalControl>,
575    ) -> Self {
576        Self {
577            channel_id,
578            threshold: (initial_credit / 2).max(1),
579            local_control_tx,
580            pending: std::sync::Mutex::new(0),
581        }
582    }
583}
584
585impl ChannelCreditReplenisher for DriverChannelCreditReplenisher {
586    fn on_item_consumed(&self) {
587        let mut pending = self.pending.lock().expect("pending credit mutex poisoned");
588        *pending += 1;
589        if *pending < self.threshold {
590            return;
591        }
592
593        let additional = *pending;
594        *pending = 0;
595        let _ = self.local_control_tx.send(DriverLocalControl::GrantCredit {
596            channel_id: self.channel_id,
597            additional,
598        });
599    }
600}
601
602impl<H: Handler<DriverReplySink>> Driver<H> {
603    pub fn new(handle: ConnectionHandle, handler: H) -> Self {
604        let conn_id = handle.connection_id();
605        let ConnectionHandle {
606            sender,
607            rx,
608            failures_rx,
609            control_tx,
610            closed_rx,
611            parity,
612        } = handle;
613        let drop_control_request = DropControlRequest::Close(conn_id);
614        let (local_control_tx, local_control_rx) = mpsc::unbounded_channel("driver.local_control");
615        Self {
616            sender,
617            rx,
618            failures_rx,
619            closed_rx,
620            local_control_rx,
621            handler: Arc::new(handler),
622            shared: Arc::new(DriverShared {
623                pending_responses: SyncMutex::new("driver.pending_responses", BTreeMap::new()),
624                request_ids: SyncMutex::new("driver.request_ids", IdAllocator::new(parity)),
625                channel_ids: SyncMutex::new("driver.channel_ids", IdAllocator::new(parity)),
626                channel_senders: SyncMutex::new("driver.channel_senders", BTreeMap::new()),
627                channel_buffers: SyncMutex::new("driver.channel_buffers", BTreeMap::new()),
628                channel_credits: SyncMutex::new("driver.channel_credits", BTreeMap::new()),
629            }),
630            in_flight_handlers: BTreeMap::new(),
631            local_control_tx,
632            drop_control_seed: control_tx,
633            drop_control_request,
634            drop_guard: SyncMutex::new("driver.drop_guard", None),
635        }
636    }
637
638    /// Get a cloneable caller handle for making outgoing calls.
639    // r[impl rpc.caller.liveness.refcounted]
640    // r[impl rpc.caller.liveness.last-drop-closes-connection]
641    // r[impl rpc.caller.liveness.root-internal-close]
642    // r[impl rpc.caller.liveness.root-teardown-condition]
643    fn existing_drop_guard(&self) -> Option<Arc<CallerDropGuard>> {
644        self.drop_guard.lock().as_ref().and_then(Weak::upgrade)
645    }
646
647    fn connection_drop_guard(&self) -> Option<Arc<CallerDropGuard>> {
648        let drop_guard = if let Some(existing) = self.existing_drop_guard() {
649            Some(existing)
650        } else if let Some(seed) = &self.drop_control_seed {
651            let mut guard = self.drop_guard.lock();
652            if let Some(existing) = guard.as_ref().and_then(Weak::upgrade) {
653                Some(existing)
654            } else {
655                let arc = Arc::new(CallerDropGuard {
656                    control_tx: seed.clone(),
657                    request: self.drop_control_request,
658                });
659                *guard = Some(Arc::downgrade(&arc));
660                Some(arc)
661            }
662        } else {
663            None
664        };
665        drop_guard
666    }
667
668    pub fn caller(&self) -> DriverCaller {
669        let drop_guard = self.connection_drop_guard();
670        DriverCaller {
671            sender: self.sender.clone(),
672            shared: Arc::clone(&self.shared),
673            local_control_tx: self.local_control_tx.clone(),
674            closed_rx: self.closed_rx.clone(),
675            _drop_guard: drop_guard,
676        }
677    }
678
679    fn internal_binder(&self) -> DriverChannelBinder {
680        DriverChannelBinder {
681            sender: self.sender.clone(),
682            shared: Arc::clone(&self.shared),
683            local_control_tx: self.local_control_tx.clone(),
684            drop_guard: self.existing_drop_guard(),
685        }
686    }
687
688    // r[impl rpc.pipelining]
689    /// Main loop: receive messages from the session and dispatch them.
690    /// Handler calls run as spawned tasks — we don't block the driver
691    /// loop waiting for a handler to finish.
692    pub async fn run(&mut self) {
693        loop {
694            tokio::select! {
695                msg = self.rx.recv() => {
696                    match msg {
697                        Some(msg) => self.handle_msg(msg),
698                        None => break,
699                    }
700                }
701                Some((req_id, _reason)) = self.failures_rx.recv() => {
702                    // Clean up the handler tracking entry.
703                    self.in_flight_handlers.remove(&req_id);
704                    if self.shared.pending_responses.lock().remove(&req_id).is_none() {
705                        // Incoming call — handler failed to reply.
706                        // Wire format is always Result<T, RoamError<E>>, so encode
707                        // Cancelled as Err(...) in that envelope.
708                        let error: Result<(), RoamError<core::convert::Infallible>> =
709                            Err(RoamError::Cancelled);
710                        let _ = self.sender.send_response(req_id, RequestResponse {
711                            ret: Payload::outgoing(&error),
712                            channels: vec![],
713                            metadata: Default::default(),
714                        }).await;
715                    }
716                }
717                Some(ctrl) = self.local_control_rx.recv() => {
718                    self.handle_local_control(ctrl).await;
719                }
720            }
721        }
722
723        for (_, handle) in std::mem::take(&mut self.in_flight_handlers) {
724            handle.abort();
725        }
726        self.shared.pending_responses.lock().clear();
727
728        // Connection is gone: drop channel runtime state so any registered Rx
729        // receivers observe closure instead of hanging on recv().
730        self.shared.channel_senders.lock().clear();
731        self.shared.channel_buffers.lock().clear();
732        self.shared.channel_credits.lock().clear();
733    }
734
735    async fn handle_local_control(&mut self, control: DriverLocalControl) {
736        match control {
737            DriverLocalControl::CloseChannel { channel_id } => {
738                let _ = self
739                    .sender
740                    .send(ConnectionMessage::Channel(ChannelMessage {
741                        id: channel_id,
742                        body: ChannelBody::Close(ChannelClose {
743                            metadata: Default::default(),
744                        }),
745                    }))
746                    .await;
747            }
748            DriverLocalControl::GrantCredit {
749                channel_id,
750                additional,
751            } => {
752                let _ = self
753                    .sender
754                    .send(ConnectionMessage::Channel(ChannelMessage {
755                        id: channel_id,
756                        body: ChannelBody::GrantCredit(roam_types::ChannelGrantCredit {
757                            additional,
758                        }),
759                    }))
760                    .await;
761            }
762        }
763    }
764
765    fn handle_msg(&mut self, msg: SelfRef<ConnectionMessage<'static>>) {
766        let is_request = matches!(&*msg, ConnectionMessage::Request(_));
767        if is_request {
768            let msg = msg.map(|m| match m {
769                ConnectionMessage::Request(r) => r,
770                _ => unreachable!(),
771            });
772            self.handle_request(msg);
773        } else {
774            let msg = msg.map(|m| match m {
775                ConnectionMessage::Channel(c) => c,
776                _ => unreachable!(),
777            });
778            self.handle_channel(msg);
779        }
780    }
781
782    fn handle_request(&mut self, msg: SelfRef<RequestMessage<'static>>) {
783        let req_id = msg.id;
784        let is_call = matches!(&msg.body, RequestBody::Call(_));
785        let is_response = matches!(&msg.body, RequestBody::Response(_));
786        let is_cancel = matches!(&msg.body, RequestBody::Cancel(_));
787
788        if is_call {
789            // r[impl rpc.request]
790            // r[impl rpc.error.scope]
791            let reply = DriverReplySink {
792                sender: Some(self.sender.clone()),
793                request_id: req_id,
794                binder: self.internal_binder(),
795            };
796            let call = msg.map(|m| match m.body {
797                RequestBody::Call(c) => c,
798                _ => unreachable!(),
799            });
800            let handler = Arc::clone(&self.handler);
801            let join_handle = moire::task::spawn(
802                async move {
803                    handler.handle(call, reply).await;
804                }
805                .named("handler"),
806            );
807            self.in_flight_handlers.insert(req_id, join_handle);
808        } else if is_response {
809            // r[impl rpc.response.one-per-request]
810            if let Some(tx) = self.shared.pending_responses.lock().remove(&req_id) {
811                let _: Result<(), _> = tx.send(msg);
812            }
813        } else if is_cancel {
814            // r[impl rpc.cancel]
815            // r[impl rpc.cancel.channels]
816            // Abort the in-flight handler task. Channels are intentionally left
817            // intact — they have independent lifecycles per spec.
818            if let Some(handle) = self.in_flight_handlers.remove(&req_id) {
819                handle.abort();
820            }
821            // The response is sent automatically: aborting drops DriverReplySink →
822            // mark_failure fires → failures_rx arm sends RoamError::Cancelled.
823        }
824    }
825
826    fn handle_channel(&mut self, msg: SelfRef<ChannelMessage<'static>>) {
827        let chan_id = msg.id;
828
829        // Look up the channel sender from the shared registry (handles registered
830        // by both the driver and any DriverCaller that set up channels).
831        let sender = self.shared.channel_senders.lock().get(&chan_id).cloned();
832
833        match &msg.body {
834            // r[impl rpc.channel.item]
835            ChannelBody::Item(_item) => {
836                if let Some(tx) = &sender {
837                    let item = msg.map(|m| match m.body {
838                        ChannelBody::Item(item) => item,
839                        _ => unreachable!(),
840                    });
841                    // try_send: if the Rx has been dropped or the buffer is full, drop the item.
842                    let _ = tx.try_send(IncomingChannelMessage::Item(item));
843                } else {
844                    // Channel not yet registered — buffer until register_rx_channel is called.
845                    let item = msg.map(|m| match m.body {
846                        ChannelBody::Item(item) => item,
847                        _ => unreachable!(),
848                    });
849                    self.shared
850                        .channel_buffers
851                        .lock()
852                        .entry(chan_id)
853                        .or_default()
854                        .push(IncomingChannelMessage::Item(item));
855                }
856            }
857            // r[impl rpc.channel.close]
858            ChannelBody::Close(_close) => {
859                if let Some(tx) = &sender {
860                    let close = msg.map(|m| match m.body {
861                        ChannelBody::Close(close) => close,
862                        _ => unreachable!(),
863                    });
864                    let _ = tx.try_send(IncomingChannelMessage::Close(close));
865                } else {
866                    // Channel not yet registered — buffer the close.
867                    let close = msg.map(|m| match m.body {
868                        ChannelBody::Close(close) => close,
869                        _ => unreachable!(),
870                    });
871                    self.shared
872                        .channel_buffers
873                        .lock()
874                        .entry(chan_id)
875                        .or_default()
876                        .push(IncomingChannelMessage::Close(close));
877                }
878                self.shared.channel_senders.lock().remove(&chan_id);
879                self.shared.channel_credits.lock().remove(&chan_id);
880            }
881            // r[impl rpc.channel.reset]
882            ChannelBody::Reset(_reset) => {
883                if let Some(tx) = &sender {
884                    let reset = msg.map(|m| match m.body {
885                        ChannelBody::Reset(reset) => reset,
886                        _ => unreachable!(),
887                    });
888                    let _ = tx.try_send(IncomingChannelMessage::Reset(reset));
889                } else {
890                    // Channel not yet registered — buffer the reset.
891                    let reset = msg.map(|m| match m.body {
892                        ChannelBody::Reset(reset) => reset,
893                        _ => unreachable!(),
894                    });
895                    self.shared
896                        .channel_buffers
897                        .lock()
898                        .entry(chan_id)
899                        .or_default()
900                        .push(IncomingChannelMessage::Reset(reset));
901                }
902                self.shared.channel_senders.lock().remove(&chan_id);
903                self.shared.channel_credits.lock().remove(&chan_id);
904            }
905            // r[impl rpc.flow-control.credit.grant]
906            // r[impl rpc.flow-control.credit.grant.additive]
907            ChannelBody::GrantCredit(grant) => {
908                if let Some(semaphore) = self.shared.channel_credits.lock().get(&chan_id) {
909                    semaphore.add_permits(grant.additional as usize);
910                }
911            }
912        }
913    }
914}