Skip to main content

roam_core/session/
mod.rs

1use std::{collections::BTreeMap, pin::Pin, sync::Arc, time::Duration};
2
3use moire::sync::mpsc;
4use roam_types::{
5    ChannelMessage, Conduit, ConduitRx, ConduitTx, ConduitTxPermit, ConnectionAccept,
6    ConnectionClose, ConnectionId, ConnectionOpen, ConnectionReject, ConnectionSettings,
7    IdAllocator, MaybeSend, MaybeSync, Message, MessageFamily, MessagePayload, Metadata, Parity,
8    RequestBody, RequestId, RequestMessage, RequestResponse, SelfRef, SessionRole,
9};
10use tokio::sync::watch;
11use tracing::{debug, warn};
12
13mod builders;
14pub use builders::*;
15
// r[impl session.handshake]
/// Version number of the roam session protocol implemented here.
///
/// The initiator sends it in `Hello`; the acceptor requires an exact match and
/// fails the handshake on any other value.
pub const PROTOCOL_VERSION: u32 = 7;

/// Tuning for the proactive keepalive that runs on the root connection.
///
/// If either duration is zero, keepalive is disabled entirely.
#[derive(Clone, Copy, Debug)]
pub struct SessionKeepaliveConfig {
    /// Delay before a ping is due, measured from session start or from the
    /// last matching pong.
    pub ping_interval: Duration,
    /// How long an outstanding ping may wait for its pong before the session
    /// gives up on the peer.
    pub pong_timeout: Duration,
}
26
27// ---------------------------------------------------------------------------
28// Connection acceptor trait
29// ---------------------------------------------------------------------------
30
31/// Callback for accepting or rejecting inbound virtual connections.
32///
33/// Registered on the session via the builder's `.on_connection()` method.
34/// Called synchronously from the session run loop when a peer sends
35/// `ConnectionOpen`. The acceptor returns either an `AcceptedConnection`
36/// (with settings, metadata, and a setup callback that spawns the driver)
37/// or rejection metadata.
38// r[impl rpc.virtual-connection.accept]
39pub trait ConnectionAcceptor: Send + 'static {
40    fn accept(
41        &self,
42        conn_id: ConnectionId,
43        peer_settings: &ConnectionSettings,
44        metadata: &[roam_types::MetadataEntry],
45    ) -> Result<AcceptedConnection, Metadata<'static>>;
46}
47
48/// Result of accepting a virtual connection.
49pub struct AcceptedConnection {
50    /// Our settings for this connection.
51    pub settings: ConnectionSettings,
52    /// Metadata to send back in ConnectionAccept.
53    pub metadata: Metadata<'static>,
54    /// Callback that receives the ConnectionHandle and spawns a Driver.
55    pub setup: Box<dyn FnOnce(ConnectionHandle) + Send>,
56}
57
58// ---------------------------------------------------------------------------
59// Open/close request types (from SessionHandle → run loop)
60// ---------------------------------------------------------------------------
61
62struct OpenRequest {
63    settings: ConnectionSettings,
64    metadata: Metadata<'static>,
65    result_tx: moire::sync::oneshot::Sender<Result<ConnectionHandle, SessionError>>,
66}
67
68struct CloseRequest {
69    conn_id: ConnectionId,
70    metadata: Metadata<'static>,
71    result_tx: moire::sync::oneshot::Sender<Result<(), SessionError>>,
72}
73
74#[derive(Debug, Clone, Copy)]
75pub(crate) enum DropControlRequest {
76    Shutdown,
77    Close(ConnectionId),
78}
79
80#[cfg(not(target_arch = "wasm32"))]
81fn send_drop_control(
82    tx: &mpsc::UnboundedSender<DropControlRequest>,
83    req: DropControlRequest,
84) -> Result<(), ()> {
85    tx.send(req).map_err(|_| ())
86}
87
88#[cfg(target_arch = "wasm32")]
89fn send_drop_control(
90    tx: &mpsc::UnboundedSender<DropControlRequest>,
91    req: DropControlRequest,
92) -> Result<(), ()> {
93    tx.try_send(req).map_err(|_| ())
94}
95
96// ---------------------------------------------------------------------------
97// SessionHandle — cloneable handle for opening/closing virtual connections
98// ---------------------------------------------------------------------------
99
100/// Cloneable handle for opening and closing virtual connections.
101///
102/// Returned by the session builder alongside the `Session` and root
103/// `ConnectionHandle`. The session's `run()` loop must be running
104/// concurrently for requests to be processed.
105// r[impl rpc.virtual-connection.open]
106#[derive(Clone)]
107pub struct SessionHandle {
108    open_tx: mpsc::Sender<OpenRequest>,
109    close_tx: mpsc::Sender<CloseRequest>,
110    control_tx: mpsc::UnboundedSender<DropControlRequest>,
111}
112
113impl SessionHandle {
114    /// Open a new virtual connection on the session.
115    ///
116    /// Allocates a connection ID, sends `ConnectionOpen` to the peer, and
117    /// waits for `ConnectionAccept` or `ConnectionReject`. The session's
118    /// `run()` loop processes the response and completes the returned future.
119    // r[impl connection.open]
120    pub async fn open_connection(
121        &self,
122        settings: ConnectionSettings,
123        metadata: Metadata<'static>,
124    ) -> Result<ConnectionHandle, SessionError> {
125        let (result_tx, result_rx) = moire::sync::oneshot::channel("session.open_result");
126        self.open_tx
127            .send(OpenRequest {
128                settings,
129                metadata,
130                result_tx,
131            })
132            .await
133            .map_err(|_| SessionError::Protocol("session closed".into()))?;
134        result_rx
135            .await
136            .map_err(|_| SessionError::Protocol("session closed".into()))?
137    }
138
139    /// Close a virtual connection.
140    ///
141    /// Sends `ConnectionClose` to the peer and removes the connection slot.
142    /// After this returns, no further messages will be routed to the
143    /// connection's driver.
144    // r[impl connection.close]
145    pub async fn close_connection(
146        &self,
147        conn_id: ConnectionId,
148        metadata: Metadata<'static>,
149    ) -> Result<(), SessionError> {
150        let (result_tx, result_rx) = moire::sync::oneshot::channel("session.close_result");
151        self.close_tx
152            .send(CloseRequest {
153                conn_id,
154                metadata,
155                result_tx,
156            })
157            .await
158            .map_err(|_| SessionError::Protocol("session closed".into()))?;
159        result_rx
160            .await
161            .map_err(|_| SessionError::Protocol("session closed".into()))?
162    }
163
164    /// Request shutdown of the entire session (root + all virtual connections).
165    pub fn shutdown(&self) -> Result<(), SessionError> {
166        send_drop_control(&self.control_tx, DropControlRequest::Shutdown)
167            .map_err(|_| SessionError::Protocol("session closed".into()))
168    }
169}
170
171// ---------------------------------------------------------------------------
172// Session
173// ---------------------------------------------------------------------------
174
175/// Session state machine.
176// r[impl session]
177// r[impl rpc.one-service-per-connection]
178pub struct Session<C: Conduit> {
179    /// Conduit receiver
180    rx: C::Rx,
181
182    // r[impl session.role]
183    role: SessionRole,
184
185    /// Our local parity — determines which connection IDs we allocate.
186    // r[impl session.parity]
187    parity: Parity,
188
189    /// Shared core (for sending) — also held by all ConnectionSenders.
190    sess_core: Arc<SessionCore>,
191
192    /// Connection state (active, pending inbound, pending outbound).
193    conns: BTreeMap<ConnectionId, ConnectionSlot>,
194    /// Whether the root connection was internally closed because all root callers dropped.
195    root_closed_internal: bool,
196
197    /// Allocator for outbound virtual connection IDs (uses session parity).
198    conn_ids: IdAllocator<ConnectionId>,
199
200    /// Callback for accepting inbound virtual connections.
201    on_connection: Option<Box<dyn ConnectionAcceptor>>,
202
203    /// Receiver for open requests from SessionHandle.
204    open_rx: mpsc::Receiver<OpenRequest>,
205
206    /// Receiver for close requests from SessionHandle.
207    close_rx: mpsc::Receiver<CloseRequest>,
208
209    /// Sender/receiver for drop-driven session/connection control requests.
210    control_tx: mpsc::UnboundedSender<DropControlRequest>,
211    control_rx: mpsc::UnboundedReceiver<DropControlRequest>,
212
213    /// Optional proactive keepalive runtime config for connection ID 0.
214    keepalive: Option<SessionKeepaliveConfig>,
215}
216
217#[derive(Debug)]
218struct KeepaliveRuntime {
219    ping_interval: Duration,
220    pong_timeout: Duration,
221    next_ping_at: tokio::time::Instant,
222    waiting_pong_nonce: Option<u64>,
223    pong_deadline: tokio::time::Instant,
224    next_ping_nonce: u64,
225}
226
227// r[impl connection]
228/// Static data for one active connection.
229#[derive(Debug)]
230pub struct ConnectionState {
231    /// Unique connection identifier
232    pub id: ConnectionId,
233
234    /// Our settings
235    pub local_settings: ConnectionSettings,
236
237    /// The peer's settings
238    pub peer_settings: ConnectionSettings,
239
240    /// Sender for routing incoming messages to the per-connection driver task.
241    conn_tx: mpsc::Sender<SelfRef<ConnectionMessage<'static>>>,
242    closed_tx: watch::Sender<bool>,
243}
244
245#[derive(Debug)]
246enum ConnectionSlot {
247    Active(ConnectionState),
248    PendingOutbound(PendingOutboundData),
249}
250
251/// Debug-printable wrapper that omits the oneshot sender.
252struct PendingOutboundData {
253    local_settings: ConnectionSettings,
254    result_tx: Option<moire::sync::oneshot::Sender<Result<ConnectionHandle, SessionError>>>,
255}
256
257impl std::fmt::Debug for PendingOutboundData {
258    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
259        f.debug_struct("PendingOutbound")
260            .field("local_settings", &self.local_settings)
261            .finish()
262    }
263}
264
265#[derive(Clone)]
266pub(crate) struct ConnectionSender {
267    connection_id: ConnectionId,
268    sess_core: Arc<SessionCore>,
269    failures: Arc<mpsc::UnboundedSender<(RequestId, &'static str)>>,
270}
271
272fn forwarded_payload<'a>(payload: &'a roam_types::Payload<'static>) -> roam_types::Payload<'a> {
273    let roam_types::Payload::Incoming(bytes) = payload else {
274        unreachable!("proxy forwarding expects decoded incoming payload bytes")
275    };
276    roam_types::Payload::Incoming(bytes)
277}
278
279fn forwarded_request_body<'a>(body: &'a RequestBody<'static>) -> RequestBody<'a> {
280    match body {
281        RequestBody::Call(call) => RequestBody::Call(roam_types::RequestCall {
282            method_id: call.method_id,
283            channels: call.channels.clone(),
284            metadata: call.metadata.clone(),
285            args: forwarded_payload(&call.args),
286        }),
287        RequestBody::Response(response) => RequestBody::Response(RequestResponse {
288            channels: response.channels.clone(),
289            metadata: response.metadata.clone(),
290            ret: forwarded_payload(&response.ret),
291        }),
292        RequestBody::Cancel(cancel) => RequestBody::Cancel(roam_types::RequestCancel {
293            metadata: cancel.metadata.clone(),
294        }),
295    }
296}
297
298fn forwarded_channel_body<'a>(
299    body: &'a roam_types::ChannelBody<'static>,
300) -> roam_types::ChannelBody<'a> {
301    match body {
302        roam_types::ChannelBody::Item(item) => {
303            roam_types::ChannelBody::Item(roam_types::ChannelItem {
304                item: forwarded_payload(&item.item),
305            })
306        }
307        roam_types::ChannelBody::Close(close) => {
308            roam_types::ChannelBody::Close(roam_types::ChannelClose {
309                metadata: close.metadata.clone(),
310            })
311        }
312        roam_types::ChannelBody::Reset(reset) => {
313            roam_types::ChannelBody::Reset(roam_types::ChannelReset {
314                metadata: reset.metadata.clone(),
315            })
316        }
317        roam_types::ChannelBody::GrantCredit(credit) => {
318            roam_types::ChannelBody::GrantCredit(roam_types::ChannelGrantCredit {
319                additional: credit.additional,
320            })
321        }
322    }
323}
324
325impl ConnectionSender {
326    /// Send an arbitrary connection message
327    pub async fn send<'a>(&self, msg: ConnectionMessage<'a>) -> Result<(), ()> {
328        let payload = match msg {
329            ConnectionMessage::Request(r) => MessagePayload::RequestMessage(r),
330            ConnectionMessage::Channel(c) => MessagePayload::ChannelMessage(c),
331        };
332        let message = Message {
333            connection_id: self.connection_id,
334            payload,
335        };
336        self.sess_core.send(message).await.map_err(|_| ())
337    }
338
339    /// Send a received connection message without re-materializing payload values.
340    pub(crate) async fn send_owned(
341        &self,
342        msg: SelfRef<ConnectionMessage<'static>>,
343    ) -> Result<(), ()> {
344        let payload = match &*msg {
345            ConnectionMessage::Request(request) => MessagePayload::RequestMessage(RequestMessage {
346                id: request.id,
347                body: forwarded_request_body(&request.body),
348            }),
349            ConnectionMessage::Channel(channel) => MessagePayload::ChannelMessage(ChannelMessage {
350                id: channel.id,
351                body: forwarded_channel_body(&channel.body),
352            }),
353        };
354
355        self.sess_core
356            .send(Message {
357                connection_id: self.connection_id,
358                payload,
359            })
360            .await
361            .map_err(|_| ())
362    }
363
364    /// Send a response specifically
365    pub async fn send_response<'a>(
366        &self,
367        request_id: RequestId,
368        response: RequestResponse<'a>,
369    ) -> Result<(), ()> {
370        self.send(ConnectionMessage::Request(RequestMessage {
371            id: request_id,
372            body: RequestBody::Response(response),
373        }))
374        .await
375    }
376
377    /// Mark a request as failed by removing any pending response slot.
378    /// Called when a send error occurs or no reply was sent.
379    pub fn mark_failure(&self, request_id: RequestId, reason: &'static str) {
380        let _ = self.failures.send((request_id, reason));
381    }
382}
383
384pub struct ConnectionHandle {
385    pub(crate) sender: ConnectionSender,
386    pub(crate) rx: mpsc::Receiver<SelfRef<ConnectionMessage<'static>>>,
387    pub(crate) failures_rx: mpsc::UnboundedReceiver<(RequestId, &'static str)>,
388    pub(crate) control_tx: Option<mpsc::UnboundedSender<DropControlRequest>>,
389    pub(crate) closed_rx: watch::Receiver<bool>,
390    /// The parity this side should use for allocating request/channel IDs.
391    pub parity: Parity,
392}
393
394impl std::fmt::Debug for ConnectionHandle {
395    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
396        f.debug_struct("ConnectionHandle")
397            .field("connection_id", &self.sender.connection_id)
398            .finish()
399    }
400}
401
402pub(crate) enum ConnectionMessage<'payload> {
403    Request(RequestMessage<'payload>),
404    Channel(ChannelMessage<'payload>),
405}
406
407impl ConnectionHandle {
408    /// Returns the connection ID for this handle.
409    pub fn connection_id(&self) -> ConnectionId {
410        self.sender.connection_id
411    }
412
413    /// Resolve when this connection closes.
414    pub async fn closed(&self) {
415        if *self.closed_rx.borrow() {
416            return;
417        }
418        let mut rx = self.closed_rx.clone();
419        while rx.changed().await.is_ok() {
420            if *rx.borrow() {
421                return;
422            }
423        }
424    }
425
426    /// Return whether this connection is still considered connected.
427    pub fn is_connected(&self) -> bool {
428        !*self.closed_rx.borrow()
429    }
430}
431
432/// Forward all request/channel traffic between two connections.
433///
434/// This is a protocol-level bridge: it does not inspect service schemas or method IDs.
435/// It exits when either side closes or a forward send fails, then requests closure of
436/// both underlying connections.
437pub async fn proxy_connections(left: ConnectionHandle, right: ConnectionHandle) {
438    let left_conn_id = left.connection_id();
439    let right_conn_id = right.connection_id();
440    let ConnectionHandle {
441        sender: left_sender,
442        rx: mut left_rx,
443        failures_rx: _left_failures_rx,
444        control_tx: left_control_tx,
445        closed_rx: _left_closed_rx,
446        parity: _left_parity,
447    } = left;
448    let ConnectionHandle {
449        sender: right_sender,
450        rx: mut right_rx,
451        failures_rx: _right_failures_rx,
452        control_tx: right_control_tx,
453        closed_rx: _right_closed_rx,
454        parity: _right_parity,
455    } = right;
456
457    loop {
458        tokio::select! {
459            msg = left_rx.recv() => {
460                let Some(msg) = msg else {
461                    break;
462                };
463                if right_sender.send_owned(msg).await.is_err() {
464                    break;
465                }
466            }
467            msg = right_rx.recv() => {
468                let Some(msg) = msg else {
469                    break;
470                };
471                if left_sender.send_owned(msg).await.is_err() {
472                    break;
473                }
474            }
475        }
476    }
477
478    if let Some(tx) = left_control_tx.as_ref() {
479        let _ = send_drop_control(tx, DropControlRequest::Close(left_conn_id));
480    }
481    if let Some(tx) = right_control_tx.as_ref() {
482        let _ = send_drop_control(tx, DropControlRequest::Close(right_conn_id));
483    }
484}
485
486/// Errors that can occur during session establishment or operation.
487#[derive(Debug)]
488pub enum SessionError {
489    Io(std::io::Error),
490    Protocol(String),
491    Rejected(Metadata<'static>),
492}
493
494impl std::fmt::Display for SessionError {
495    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
496        match self {
497            Self::Io(e) => write!(f, "io error: {e}"),
498            Self::Protocol(msg) => write!(f, "protocol error: {msg}"),
499            Self::Rejected(_) => write!(f, "connection rejected"),
500        }
501    }
502}
503
504impl std::error::Error for SessionError {}
505
506impl<C> Session<C>
507where
508    C: Conduit<Msg = MessageFamily>,
509    C::Tx: MaybeSend + MaybeSync + 'static,
510    for<'p> <C::Tx as ConduitTx>::Permit<'p>: MaybeSend,
511    C::Rx: MaybeSend,
512{
513    #[allow(clippy::too_many_arguments)]
514    fn pre_handshake(
515        tx: C::Tx,
516        rx: C::Rx,
517        on_connection: Option<Box<dyn ConnectionAcceptor>>,
518        open_rx: mpsc::Receiver<OpenRequest>,
519        close_rx: mpsc::Receiver<CloseRequest>,
520        control_tx: mpsc::UnboundedSender<DropControlRequest>,
521        control_rx: mpsc::UnboundedReceiver<DropControlRequest>,
522        keepalive: Option<SessionKeepaliveConfig>,
523    ) -> Self {
524        let sess_core = Arc::new(SessionCore { tx: Box::new(tx) });
525        Session {
526            rx,
527            role: SessionRole::Initiator, // overwritten in establish_as_*
528            parity: Parity::Odd,          // overwritten in establish_as_*
529            sess_core,
530            conns: BTreeMap::new(),
531            root_closed_internal: false,
532            conn_ids: IdAllocator::new(Parity::Odd), // overwritten in establish_as_*
533            on_connection,
534            open_rx,
535            close_rx,
536            control_tx,
537            control_rx,
538            keepalive,
539        }
540    }
541
542    // r[impl session.handshake]
543    async fn establish_as_initiator(
544        &mut self,
545        settings: ConnectionSettings,
546        metadata: Metadata<'_>,
547    ) -> Result<ConnectionHandle, SessionError> {
548        use roam_types::{Hello, MessagePayload};
549
550        self.role = SessionRole::Initiator;
551        self.parity = settings.parity;
552        self.conn_ids = IdAllocator::new(settings.parity);
553
554        // Send Hello
555        self.sess_core
556            .send(Message {
557                connection_id: ConnectionId::ROOT,
558                payload: MessagePayload::Hello(Hello {
559                    version: PROTOCOL_VERSION,
560                    connection_settings: settings.clone(),
561                    metadata,
562                }),
563            })
564            .await
565            .map_err(|_| SessionError::Protocol("failed to send Hello".into()))?;
566
567        // Receive HelloYourself
568        let peer_settings = match self.rx.recv().await {
569            Ok(Some(msg)) => {
570                let payload = msg.map(|m| m.payload);
571                match &*payload {
572                    MessagePayload::HelloYourself(hy) => hy.connection_settings.clone(),
573                    MessagePayload::ProtocolError(e) => {
574                        return Err(SessionError::Protocol(e.description.to_owned()));
575                    }
576                    _ => {
577                        return Err(SessionError::Protocol("expected HelloYourself".into()));
578                    }
579                }
580            }
581            Ok(None) => {
582                return Err(SessionError::Protocol(
583                    "peer closed during handshake".into(),
584                ));
585            }
586            Err(e) => return Err(SessionError::Protocol(e.to_string())),
587        };
588
589        Ok(self.make_root_handle(settings, peer_settings))
590    }
591
592    // r[impl session.handshake]
593    #[moire::instrument]
594    async fn establish_as_acceptor(
595        &mut self,
596        settings: ConnectionSettings,
597        metadata: Metadata<'_>,
598    ) -> Result<ConnectionHandle, SessionError> {
599        use roam_types::{HelloYourself, MessagePayload};
600
601        self.role = SessionRole::Acceptor;
602
603        // Receive Hello
604        let peer_settings = match self.rx.recv().await {
605            Ok(Some(msg)) => {
606                let payload = msg.map(|m| m.payload);
607                match &*payload {
608                    MessagePayload::Hello(h) => {
609                        if h.version != PROTOCOL_VERSION {
610                            return Err(SessionError::Protocol(format!(
611                                "version mismatch: got {}, expected {PROTOCOL_VERSION}",
612                                h.version
613                            )));
614                        }
615                        h.connection_settings.clone()
616                    }
617                    MessagePayload::ProtocolError(e) => {
618                        return Err(SessionError::Protocol(e.description.to_owned()));
619                    }
620                    _ => {
621                        return Err(SessionError::Protocol("expected Hello".into()));
622                    }
623                }
624            }
625            Ok(None) => {
626                return Err(SessionError::Protocol(
627                    "peer closed during handshake".into(),
628                ));
629            }
630            Err(e) => return Err(SessionError::Protocol(e.to_string())),
631        };
632
633        // Acceptor parity is opposite of initiator
634        let our_settings = ConnectionSettings {
635            parity: peer_settings.parity.other(),
636            ..settings
637        };
638        self.parity = our_settings.parity;
639        self.conn_ids = IdAllocator::new(our_settings.parity);
640
641        // Send HelloYourself
642        self.sess_core
643            .send(Message {
644                connection_id: ConnectionId::ROOT,
645                payload: MessagePayload::HelloYourself(HelloYourself {
646                    connection_settings: our_settings.clone(),
647                    metadata,
648                }),
649            })
650            .await
651            .map_err(|_| SessionError::Protocol("failed to send HelloYourself".into()))?;
652
653        Ok(self.make_root_handle(our_settings, peer_settings))
654    }
655
656    fn make_root_handle(
657        &mut self,
658        local_settings: ConnectionSettings,
659        peer_settings: ConnectionSettings,
660    ) -> ConnectionHandle {
661        self.make_connection_handle(ConnectionId::ROOT, local_settings, peer_settings)
662    }
663
664    fn make_connection_handle(
665        &mut self,
666        conn_id: ConnectionId,
667        local_settings: ConnectionSettings,
668        peer_settings: ConnectionSettings,
669    ) -> ConnectionHandle {
670        let label = format!("session.conn{}", conn_id.0);
671        let (conn_tx, conn_rx) = mpsc::channel::<SelfRef<ConnectionMessage<'static>>>(&label, 64);
672        let (failures_tx, failures_rx) = mpsc::unbounded_channel(format!("{label}.failures"));
673        let (closed_tx, closed_rx) = watch::channel(false);
674
675        let sender = ConnectionSender {
676            connection_id: conn_id,
677            sess_core: Arc::clone(&self.sess_core),
678            failures: Arc::new(failures_tx),
679        };
680
681        let parity = local_settings.parity;
682        self.conns.insert(
683            conn_id,
684            ConnectionSlot::Active(ConnectionState {
685                id: conn_id,
686                local_settings,
687                peer_settings,
688                conn_tx,
689                closed_tx,
690            }),
691        );
692
693        ConnectionHandle {
694            sender,
695            rx: conn_rx,
696            failures_rx,
697            control_tx: Some(self.control_tx.clone()),
698            closed_rx,
699            parity,
700        }
701    }
702
703    /// Run the session recv loop: read from the conduit, demux by connection
704    /// ID, and route to the appropriate connection's driver. Also processes
705    /// open/close requests from the SessionHandle.
706    // r[impl zerocopy.framing.pipeline.incoming]
707    pub async fn run(&mut self) {
708        let mut keepalive_runtime = self.make_keepalive_runtime();
709        let mut keepalive_tick = keepalive_runtime.as_ref().map(|_| {
710            let mut interval = tokio::time::interval(Duration::from_millis(10));
711            interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
712            interval
713        });
714
715        loop {
716            tokio::select! {
717                msg = self.rx.recv() => {
718                    match msg {
719                        Ok(Some(msg)) => self.handle_message(msg, &mut keepalive_runtime).await,
720                        Ok(None) => {
721                            warn!("session recv loop ended: conduit returned EOF");
722                            break;
723                        }
724                        Err(error) => {
725                            warn!(error = %error, "session recv loop ended: conduit recv error");
726                            break;
727                        }
728                    }
729                }
730                Some(req) = self.open_rx.recv() => {
731                    self.handle_open_request(req).await;
732                }
733                Some(req) = self.close_rx.recv() => {
734                    self.handle_close_request(req).await;
735                }
736                Some(req) = self.control_rx.recv() => {
737                    if !self.handle_drop_control_request(req).await {
738                        break;
739                    }
740                }
741                _ = async {
742                    if let Some(interval) = keepalive_tick.as_mut() {
743                        interval.tick().await;
744                    }
745                }, if keepalive_tick.is_some() => {
746                    if !self.handle_keepalive_tick(&mut keepalive_runtime).await {
747                        break;
748                    }
749                }
750            }
751        }
752
753        // Drop all connection slots so per-connection drivers exit immediately.
754        self.close_all_connections();
755        debug!("session recv loop exited");
756    }
757
/// Dispatch one post-handshake message read from the conduit.
///
/// - `ConnectionClose`: removes the connection. Closing the root is logged as a
///   warning. Then calls `maybe_request_shutdown_after_root_closed`.
/// - `ConnectionOpen` / `ConnectionAccept` / `ConnectionReject`: delegated to
///   the `handle_inbound_*` helpers.
/// - `RequestMessage` / `ChannelMessage`: forwarded to the driver of the
///   matching `Active` connection, or dropped if there is none. If the driver's
///   channel is closed, the connection is removed.
/// - `Ping`: answered on the same connection with a `Pong` echoing the nonce.
/// - `Pong`: fed to keepalive tracking, but only for the root connection.
///
/// NOTE(review): forwarding awaits `conn_tx.send` on a channel created with
/// capacity 64 in `make_connection_handle`. A slow driver therefore presumably
/// stalls this whole loop, including keepalive handling. Confirm that this
/// backpressure is intended.
758    async fn handle_message(
759        &mut self,
760        msg: SelfRef<Message<'static>>,
761        keepalive_runtime: &mut Option<KeepaliveRuntime>,
762    ) {
763        let conn_id = msg.connection_id;
764        roam_types::selfref_match!(msg, payload {
765            // r[impl connection.close.semantics]
766            MessagePayload::ConnectionClose(_) => {
767                if conn_id.is_root() {
768                    warn!("received ConnectionClose for root connection");
769                } else {
770                    debug!(conn_id = conn_id.0, "received ConnectionClose for virtual connection");
771                }
772                // Remove the connection — dropping conn_tx causes the Driver's rx
773                // to return None, which exits its run loop. All in-flight handlers
774                // are dropped, triggering DriverReplySink::drop → Cancelled responses.
775                self.remove_connection(&conn_id);
776                self.maybe_request_shutdown_after_root_closed();
777            }
778            MessagePayload::ConnectionOpen(open) => {
779                self.handle_inbound_open(conn_id, open).await;
780            }
781            MessagePayload::ConnectionAccept(accept) => {
782                self.handle_inbound_accept(conn_id, accept);
783            }
784            MessagePayload::ConnectionReject(reject) => {
785                self.handle_inbound_reject(conn_id, reject);
786            }
787            MessagePayload::RequestMessage(r) => {
                // Only Active slots receive traffic; unknown or pending IDs are dropped.
                // The sender is cloned so `self.conns` is not borrowed across the
                // await, which leaves `self` free for `remove_connection` below.
788                let conn_tx = match self.conns.get(&conn_id) {
789                    Some(ConnectionSlot::Active(state)) => state.conn_tx.clone(),
790                    _ => return,
791                };
792                if conn_tx.send(r.map(ConnectionMessage::Request)).await.is_err() {
793                    self.remove_connection(&conn_id);
794                    self.maybe_request_shutdown_after_root_closed();
795                }
796            }
797            MessagePayload::ChannelMessage(c) => {
                // Same routing rule and borrow-avoiding clone as for RequestMessage.
798                let conn_tx = match self.conns.get(&conn_id) {
799                    Some(ConnectionSlot::Active(state)) => state.conn_tx.clone(),
800                    _ => return,
801                };
802                if conn_tx.send(c.map(ConnectionMessage::Channel)).await.is_err() {
803                    self.remove_connection(&conn_id);
804                    self.maybe_request_shutdown_after_root_closed();
805                }
806            }
807            MessagePayload::Ping(ping) => {
                // Best effort: a failed Pong send is ignored here.
808                let _ = self
809                    .sess_core
810                    .send(Message {
811                        connection_id: conn_id,
812                        payload: MessagePayload::Pong(roam_types::Pong { nonce: ping.nonce }),
813                    })
814                    .await;
815            }
816            MessagePayload::Pong(pong) => {
                // Keepalive pings are only sent on the root connection.
817                if conn_id.is_root() {
818                    self.handle_keepalive_pong(pong.nonce, keepalive_runtime);
819                }
820            }
821            // Hello, HelloYourself, ProtocolError: not valid post-handshake, drop.
822        })
823    }
824
825    fn make_keepalive_runtime(&self) -> Option<KeepaliveRuntime> {
826        let config = self.keepalive?;
827        if config.ping_interval.is_zero() || config.pong_timeout.is_zero() {
828            warn!("keepalive disabled due to non-positive interval/timeout");
829            return None;
830        }
831        let now = tokio::time::Instant::now();
832        Some(KeepaliveRuntime {
833            ping_interval: config.ping_interval,
834            pong_timeout: config.pong_timeout,
835            next_ping_at: now + config.ping_interval,
836            waiting_pong_nonce: None,
837            pong_deadline: now,
838            next_ping_nonce: 1,
839        })
840    }
841
842    fn handle_keepalive_pong(&self, nonce: u64, keepalive_runtime: &mut Option<KeepaliveRuntime>) {
843        let Some(runtime) = keepalive_runtime.as_mut() else {
844            return;
845        };
846        if runtime.waiting_pong_nonce != Some(nonce) {
847            return;
848        }
849        runtime.waiting_pong_nonce = None;
850        runtime.next_ping_at = tokio::time::Instant::now() + runtime.ping_interval;
851    }
852
853    async fn handle_keepalive_tick(
854        &mut self,
855        keepalive_runtime: &mut Option<KeepaliveRuntime>,
856    ) -> bool {
857        let Some(runtime) = keepalive_runtime.as_mut() else {
858            return true;
859        };
860        let now = tokio::time::Instant::now();
861
862        if let Some(waiting_nonce) = runtime.waiting_pong_nonce {
863            if now >= runtime.pong_deadline {
864                warn!(
865                    nonce = waiting_nonce,
866                    timeout_ms = runtime.pong_timeout.as_millis(),
867                    "keepalive timeout waiting for pong"
868                );
869                return false;
870            }
871            return true;
872        }
873
874        if now < runtime.next_ping_at {
875            return true;
876        }
877
878        let nonce = runtime.next_ping_nonce;
879        if self
880            .sess_core
881            .send(Message {
882                connection_id: ConnectionId::ROOT,
883                payload: MessagePayload::Ping(roam_types::Ping { nonce }),
884            })
885            .await
886            .is_err()
887        {
888            warn!("failed to send keepalive ping");
889            return false;
890        }
891
892        runtime.waiting_pong_nonce = Some(nonce);
893        runtime.pong_deadline = now + runtime.pong_timeout;
894        runtime.next_ping_at = now + runtime.ping_interval;
895        runtime.next_ping_nonce = runtime.next_ping_nonce.wrapping_add(1);
896        true
897    }
898
899    async fn handle_inbound_open(
900        &mut self,
901        conn_id: ConnectionId,
902        open: SelfRef<ConnectionOpen<'static>>,
903    ) {
904        // Validate: connection ID must match peer's parity (opposite of ours).
905        let peer_parity = self.parity.other();
906        if !conn_id.has_parity(peer_parity) {
907            // Protocol error: wrong parity. For now, just reject.
908            let _ = self
909                .sess_core
910                .send(Message {
911                    connection_id: conn_id,
912                    payload: MessagePayload::ConnectionReject(roam_types::ConnectionReject {
913                        metadata: vec![],
914                    }),
915                })
916                .await;
917            return;
918        }
919
920        // Validate: connection ID must not already be in use.
921        if self.conns.contains_key(&conn_id) {
922            // Protocol error: duplicate connection ID.
923            let _ = self
924                .sess_core
925                .send(Message {
926                    connection_id: conn_id,
927                    payload: MessagePayload::ConnectionReject(roam_types::ConnectionReject {
928                        metadata: vec![],
929                    }),
930                })
931                .await;
932            return;
933        }
934
935        // r[impl connection.open.rejection]
936        // Call the acceptor callback. If none is registered, reject.
937        let acceptor = match &self.on_connection {
938            Some(a) => a,
939            None => {
940                let _ = self
941                    .sess_core
942                    .send(Message {
943                        connection_id: conn_id,
944                        payload: MessagePayload::ConnectionReject(roam_types::ConnectionReject {
945                            metadata: vec![],
946                        }),
947                    })
948                    .await;
949                return;
950            }
951        };
952
953        match acceptor.accept(conn_id, &open.connection_settings, &open.metadata) {
954            Ok(accepted) => {
955                // Create the connection handle and activate it.
956                let handle = self.make_connection_handle(
957                    conn_id,
958                    accepted.settings.clone(),
959                    open.connection_settings.clone(),
960                );
961
962                // Send ConnectionAccept to the peer.
963                let _ = self
964                    .sess_core
965                    .send(Message {
966                        connection_id: conn_id,
967                        payload: MessagePayload::ConnectionAccept(roam_types::ConnectionAccept {
968                            connection_settings: accepted.settings,
969                            metadata: accepted.metadata,
970                        }),
971                    })
972                    .await;
973
974                // Let the acceptor set up its driver.
975                (accepted.setup)(handle);
976            }
977            Err(reject_metadata) => {
978                let _ = self
979                    .sess_core
980                    .send(Message {
981                        connection_id: conn_id,
982                        payload: MessagePayload::ConnectionReject(roam_types::ConnectionReject {
983                            metadata: reject_metadata,
984                        }),
985                    })
986                    .await;
987            }
988        }
989    }
990
991    fn handle_inbound_accept(
992        &mut self,
993        conn_id: ConnectionId,
994        accept: SelfRef<ConnectionAccept<'static>>,
995    ) {
996        let slot = self.remove_connection(&conn_id);
997        match slot {
998            Some(ConnectionSlot::PendingOutbound(mut pending)) => {
999                let handle = self.make_connection_handle(
1000                    conn_id,
1001                    pending.local_settings.clone(),
1002                    accept.connection_settings.clone(),
1003                );
1004
1005                if let Some(tx) = pending.result_tx.take() {
1006                    let _ = tx.send(Ok(handle));
1007                }
1008            }
1009            Some(other) => {
1010                // Not pending outbound — put it back and ignore.
1011                self.conns.insert(conn_id, other);
1012            }
1013            None => {
1014                // No pending open for this ID — ignore.
1015            }
1016        }
1017    }
1018
1019    fn handle_inbound_reject(
1020        &mut self,
1021        conn_id: ConnectionId,
1022        reject: SelfRef<ConnectionReject<'static>>,
1023    ) {
1024        let slot = self.remove_connection(&conn_id);
1025        match slot {
1026            Some(ConnectionSlot::PendingOutbound(mut pending)) => {
1027                if let Some(tx) = pending.result_tx.take() {
1028                    let _ = tx.send(Err(SessionError::Rejected(reject.metadata.to_vec())));
1029                }
1030            }
1031            Some(other) => {
1032                self.conns.insert(conn_id, other);
1033            }
1034            None => {}
1035        }
1036    }
1037
1038    // r[impl connection.open]
1039    async fn handle_open_request(&mut self, req: OpenRequest) {
1040        let conn_id = self.conn_ids.alloc();
1041
1042        // Send ConnectionOpen to the peer.
1043        let send_result = self
1044            .sess_core
1045            .send(Message {
1046                connection_id: conn_id,
1047                payload: MessagePayload::ConnectionOpen(ConnectionOpen {
1048                    connection_settings: req.settings.clone(),
1049                    metadata: req.metadata,
1050                }),
1051            })
1052            .await;
1053
1054        if send_result.is_err() {
1055            let _ = req.result_tx.send(Err(SessionError::Protocol(
1056                "failed to send ConnectionOpen".into(),
1057            )));
1058            return;
1059        }
1060
1061        // Store the pending state. The run loop will complete the oneshot
1062        // when ConnectionAccept or ConnectionReject arrives.
1063        self.conns.insert(
1064            conn_id,
1065            ConnectionSlot::PendingOutbound(PendingOutboundData {
1066                local_settings: req.settings,
1067                result_tx: Some(req.result_tx),
1068            }),
1069        );
1070    }
1071
1072    // r[impl connection.close]
1073    async fn handle_close_request(&mut self, req: CloseRequest) {
1074        if req.conn_id.is_root() {
1075            let _ = req.result_tx.send(Err(SessionError::Protocol(
1076                "cannot close root connection".into(),
1077            )));
1078            return;
1079        }
1080
1081        // Remove the connection slot — this drops conn_tx and causes the
1082        // Driver to exit cleanly.
1083        if self.remove_connection(&req.conn_id).is_none() {
1084            let _ = req
1085                .result_tx
1086                .send(Err(SessionError::Protocol("connection not found".into())));
1087            return;
1088        }
1089
1090        // Send ConnectionClose to the peer.
1091        let send_result = self
1092            .sess_core
1093            .send(Message {
1094                connection_id: req.conn_id,
1095                payload: MessagePayload::ConnectionClose(ConnectionClose {
1096                    metadata: req.metadata,
1097                }),
1098            })
1099            .await;
1100
1101        if send_result.is_err() {
1102            let _ = req.result_tx.send(Err(SessionError::Protocol(
1103                "failed to send ConnectionClose".into(),
1104            )));
1105            return;
1106        }
1107
1108        let _ = req.result_tx.send(Ok(()));
1109        self.maybe_request_shutdown_after_root_closed();
1110    }
1111
1112    async fn handle_drop_control_request(&mut self, req: DropControlRequest) -> bool {
1113        match req {
1114            DropControlRequest::Shutdown => {
1115                debug!("session shutdown requested");
1116                false
1117            }
1118            DropControlRequest::Close(conn_id) => {
1119                // r[impl rpc.caller.liveness.last-drop-closes-connection]
1120                if conn_id.is_root() {
1121                    // r[impl rpc.caller.liveness.root-internal-close]
1122                    debug!("root callers dropped; internally closing root connection");
1123                    self.root_closed_internal = true;
1124                    // r[impl rpc.caller.liveness.root-teardown-condition]
1125                    return self.has_virtual_connections();
1126                }
1127
1128                if self.remove_connection(&conn_id).is_some() {
1129                    let _ = self
1130                        .sess_core
1131                        .send(Message {
1132                            connection_id: conn_id,
1133                            payload: MessagePayload::ConnectionClose(ConnectionClose {
1134                                metadata: vec![],
1135                            }),
1136                        })
1137                        .await;
1138                }
1139
1140                !self.root_closed_internal || self.has_virtual_connections()
1141            }
1142        }
1143    }
1144
1145    fn has_virtual_connections(&self) -> bool {
1146        self.conns.keys().any(|id| !id.is_root())
1147    }
1148
1149    fn remove_connection(&mut self, conn_id: &ConnectionId) -> Option<ConnectionSlot> {
1150        let slot = self.conns.remove(conn_id);
1151        if let Some(ConnectionSlot::Active(state)) = &slot {
1152            let _ = state.closed_tx.send(true);
1153        }
1154        slot
1155    }
1156
1157    fn close_all_connections(&mut self) {
1158        for slot in self.conns.values() {
1159            if let ConnectionSlot::Active(state) = slot {
1160                let _ = state.closed_tx.send(true);
1161            }
1162        }
1163        self.conns.clear();
1164    }
1165
1166    fn maybe_request_shutdown_after_root_closed(&self) {
1167        if self.root_closed_internal && !self.has_virtual_connections() {
1168            let _ = send_drop_control(&self.control_tx, DropControlRequest::Shutdown);
1169        }
1170    }
1171}
1172
1173pub(crate) struct SessionCore {
1174    tx: Box<dyn DynConduitTx>,
1175}
1176
1177impl SessionCore {
1178    pub(crate) async fn send<'a>(&self, msg: Message<'a>) -> Result<(), ()> {
1179        self.tx.send_msg(msg).await.map_err(|_| ())
1180    }
1181}
1182
/// Boxed, pinned future returned by [`DynConduitTx::send_msg`].
///
/// Outside `wasm32` the future is required to be `Send`.
#[cfg(not(target_arch = "wasm32"))]
type BoxFuture<'a, T> = Pin<Box<dyn std::future::Future<Output = T> + Send + 'a>>;
/// Boxed, pinned future returned by [`DynConduitTx::send_msg`].
///
/// On `wasm32` the `Send` bound is omitted.
#[cfg(target_arch = "wasm32")]
type BoxFuture<'a, T> = Pin<Box<dyn std::future::Future<Output = T> + 'a>>;

/// Object-safe sending interface so the session can store its conduit sender
/// as `Box<dyn DynConduitTx>`.
///
/// Outside `wasm32` implementors must be `Send + Sync`.
#[cfg(not(target_arch = "wasm32"))]
pub trait DynConduitTx: Send + Sync {
    /// Sends a single message, resolving to the I/O outcome.
    fn send_msg<'a>(&'a self, msg: Message<'a>) -> BoxFuture<'a, std::io::Result<()>>;
}
/// Object-safe sending interface so the session can store its conduit sender
/// as `Box<dyn DynConduitTx>`.
///
/// On `wasm32` there are no `Send`/`Sync` requirements.
#[cfg(target_arch = "wasm32")]
pub trait DynConduitTx {
    /// Sends a single message, resolving to the I/O outcome.
    fn send_msg<'a>(&'a self, msg: Message<'a>) -> BoxFuture<'a, std::io::Result<()>>;
}
1196
1197// r[impl zerocopy.send]
1198// r[impl zerocopy.framing.pipeline.outgoing]
1199impl<T> DynConduitTx for T
1200where
1201    T: ConduitTx<Msg = MessageFamily> + MaybeSend + MaybeSync,
1202    for<'p> <T as ConduitTx>::Permit<'p>: MaybeSend,
1203{
1204    fn send_msg<'a>(&'a self, msg: Message<'a>) -> BoxFuture<'a, std::io::Result<()>> {
1205        Box::pin(async move {
1206            let permit = self.reserve().await?;
1207            permit
1208                .send(msg)
1209                .map_err(|e| std::io::Error::other(e.to_string()))
1210        })
1211    }
1212}
1213
// Unit tests for the connection-slot bookkeeping driven by inbound
// ConnectionAccept/ConnectionReject and local close requests.
#[cfg(test)]
mod tests {
    use moire::sync::mpsc;
    use roam_types::{
        Backing, Conduit, ConnectionAccept, ConnectionReject, MessageFamily, SelfRef,
    };

    use super::*;

    type TestConduit = crate::BareConduit<MessageFamily, crate::MemoryLink>;

    // Builds a pre-handshake session over one end of an in-memory link pair.
    // The request senders (`_open_tx`, `_close_tx`) are dropped on return; the
    // tests drive the handler methods directly instead.
    fn make_session() -> Session<TestConduit> {
        let (a, b) = crate::memory_link_pair(32);
        // Keep the peer link alive so sess_core sends don't fail with broken pipe.
        std::mem::forget(b);
        let conduit = crate::BareConduit::new(a);
        let (tx, rx) = conduit.split();
        let (_open_tx, open_rx) = mpsc::channel::<OpenRequest>("session.open.test", 4);
        let (_close_tx, close_rx) = mpsc::channel::<CloseRequest>("session.close.test", 4);
        let (control_tx, control_rx) = mpsc::unbounded_channel("session.control.test");
        Session::pre_handshake(
            tx, rx, None, open_rx, close_rx, control_tx, control_rx, None,
        )
    }

    // A ConnectionAccept (even parity, 64 concurrent requests, no metadata)
    // owned by a SelfRef over an empty boxed backing buffer.
    fn accept_ref() -> SelfRef<ConnectionAccept<'static>> {
        SelfRef::owning(
            Backing::Boxed(Box::<[u8]>::default()),
            ConnectionAccept {
                connection_settings: ConnectionSettings {
                    parity: Parity::Even,
                    max_concurrent_requests: 64,
                },
                metadata: vec![],
            },
        )
    }

    // A ConnectionReject with no metadata, owned like `accept_ref`.
    fn reject_ref() -> SelfRef<ConnectionReject<'static>> {
        SelfRef::owning(
            Backing::Boxed(Box::<[u8]>::default()),
            ConnectionReject { metadata: vec![] },
        )
    }

    // The first accept resolves the pending open with a handle; a second accept
    // for the same ID must leave the now-active slot in place.
    #[tokio::test]
    async fn duplicate_connection_accept_is_ignored_after_first() {
        let mut session = make_session();
        let conn_id = ConnectionId(1);
        let (result_tx, result_rx) = moire::sync::oneshot::channel("session.test.open_result");

        session.conns.insert(
            conn_id,
            ConnectionSlot::PendingOutbound(PendingOutboundData {
                local_settings: ConnectionSettings {
                    parity: Parity::Odd,
                    max_concurrent_requests: 64,
                },
                result_tx: Some(result_tx),
            }),
        );

        session.handle_inbound_accept(conn_id, accept_ref());
        let handle = result_rx
            .await
            .expect("pending outbound result should resolve")
            .expect("accept should resolve as Ok");
        assert_eq!(handle.connection_id(), conn_id);

        session.handle_inbound_accept(conn_id, accept_ref());
        assert!(
            matches!(
                session.conns.get(&conn_id),
                Some(ConnectionSlot::Active(ConnectionState { id, .. })) if *id == conn_id
            ),
            "duplicate accept should keep existing active connection state"
        );
    }

    // The first reject fails the pending open with `Rejected`; a second reject
    // must not bring the slot back.
    #[tokio::test]
    async fn duplicate_connection_reject_is_ignored_after_first() {
        let mut session = make_session();
        let conn_id = ConnectionId(1);
        let (result_tx, result_rx) = moire::sync::oneshot::channel("session.test.open_result");

        session.conns.insert(
            conn_id,
            ConnectionSlot::PendingOutbound(PendingOutboundData {
                local_settings: ConnectionSettings {
                    parity: Parity::Odd,
                    max_concurrent_requests: 64,
                },
                result_tx: Some(result_tx),
            }),
        );

        session.handle_inbound_reject(conn_id, reject_ref());
        let result = result_rx
            .await
            .expect("pending outbound result should resolve");
        assert!(
            matches!(result, Err(SessionError::Rejected(_))),
            "expected rejection, got: {result:?}"
        );

        session.handle_inbound_reject(conn_id, reject_ref());
        assert!(
            !session.conns.contains_key(&conn_id),
            "duplicate reject should not recreate connection state"
        );
    }

    // Accept/reject for an ID with no slot at all must leave the table empty.
    #[test]
    fn out_of_order_accept_or_reject_without_pending_is_ignored() {
        let mut session = make_session();
        let conn_id = ConnectionId(99);

        session.handle_inbound_accept(conn_id, accept_ref());
        session.handle_inbound_reject(conn_id, reject_ref());

        assert!(
            session.conns.is_empty(),
            "out-of-order accept/reject should not mutate empty connection table"
        );
    }

    // Closing a still-pending outbound open succeeds, and dropping the pending
    // slot closes the open-result channel without delivering a value.
    #[tokio::test]
    async fn close_request_clears_pending_outbound_open() {
        let mut session = make_session();
        let (open_result_tx, open_result_rx) = moire::sync::oneshot::channel("session.open.result");
        let (close_result_tx, close_result_rx) =
            moire::sync::oneshot::channel("session.close.result");

        session.conns.insert(
            ConnectionId(1),
            ConnectionSlot::PendingOutbound(PendingOutboundData {
                local_settings: ConnectionSettings {
                    parity: Parity::Odd,
                    max_concurrent_requests: 64,
                },
                result_tx: Some(open_result_tx),
            }),
        );

        session
            .handle_close_request(CloseRequest {
                conn_id: ConnectionId(1),
                metadata: vec![],
                result_tx: close_result_tx,
            })
            .await;

        let close_result = close_result_rx
            .await
            .expect("close result should be delivered");
        assert!(
            close_result.is_ok(),
            "close should succeed for pending outbound connection"
        );

        assert!(
            open_result_rx.await.is_err(),
            "pending open result channel should be closed once the pending slot is removed"
        );
    }
}