Skip to main content

roam_core/session/
mod.rs

1use std::{collections::BTreeMap, pin::Pin, sync::Arc, time::Duration};
2
3use moire::sync::mpsc;
4use roam_types::{
5    ChannelMessage, Conduit, ConduitRx, ConduitTx, ConduitTxPermit, ConnectionAccept,
6    ConnectionClose, ConnectionId, ConnectionOpen, ConnectionReject, ConnectionSettings,
7    IdAllocator, MaybeSend, MaybeSync, Message, MessageFamily, MessagePayload, Metadata, Parity,
8    RequestBody, RequestId, RequestMessage, RequestResponse, SelfRef, SessionRole,
9};
10use tracing::{debug, warn};
11
12mod builders;
13pub use builders::*;
14
// r[impl session.handshake]
/// Current roam session protocol version.
///
/// The initiator sends this in `Hello`. The acceptor compares it with its own
/// value and fails the handshake with a "version mismatch"
/// `SessionError::Protocol` when they differ.
pub const PROTOCOL_VERSION: u32 = 7;
18
/// Session-level protocol keepalive configuration.
///
/// When configured, `Session::run` sends `Ping` messages on the root
/// connection. It stops the session if the matching `Pong` does not arrive
/// within `pong_timeout`. A zero value for either field disables keepalive,
/// and a warning is logged.
#[derive(Debug, Clone, Copy)]
pub struct SessionKeepaliveConfig {
    /// Minimum delay before the next ping. It is measured from when a ping is
    /// sent and restarted when the matching pong arrives.
    pub ping_interval: Duration,
    /// How long to wait for the pong that answers an outstanding ping.
    pub pong_timeout: Duration,
}
25
26// ---------------------------------------------------------------------------
27// Connection acceptor trait
28// ---------------------------------------------------------------------------
29
/// Callback for accepting or rejecting inbound virtual connections.
///
/// Registered on the session via the builder's `.on_connection()` method.
/// Called synchronously from the session run loop when a peer sends
/// `ConnectionOpen`. The acceptor returns either an `AcceptedConnection`
/// (with settings, metadata, and a setup callback that spawns the driver)
/// or rejection metadata.
// r[impl rpc.virtual-connection.accept]
pub trait ConnectionAcceptor: Send + 'static {
    /// Decide whether to accept the inbound virtual connection `conn_id`.
    ///
    /// `peer_settings` and `metadata` describe the peer's open request.
    /// Return `Ok(AcceptedConnection)` to accept, or `Err(metadata)` to
    /// reject with that metadata.
    fn accept(
        &self,
        conn_id: ConnectionId,
        peer_settings: &ConnectionSettings,
        metadata: &[roam_types::MetadataEntry],
    ) -> Result<AcceptedConnection, Metadata<'static>>;
}
46
/// Result of accepting a virtual connection (returned by
/// `ConnectionAcceptor::accept`).
pub struct AcceptedConnection {
    /// Our settings for this connection.
    pub settings: ConnectionSettings,
    /// Metadata to send back in ConnectionAccept.
    pub metadata: Metadata<'static>,
    /// Callback that receives the ConnectionHandle and spawns a Driver.
    /// It is a `FnOnce`, so it can be invoked at most once.
    pub setup: Box<dyn FnOnce(ConnectionHandle) + Send>,
}
56
57// ---------------------------------------------------------------------------
58// Open/close request types (from SessionHandle → run loop)
59// ---------------------------------------------------------------------------
60
/// Request sent by `SessionHandle::open_connection` asking the run loop to
/// open an outbound virtual connection.
struct OpenRequest {
    /// Our settings for the new connection.
    settings: ConnectionSettings,
    /// Caller-supplied metadata for the open request.
    metadata: Metadata<'static>,
    /// Completed by the run loop with the new handle or an error.
    result_tx: moire::sync::oneshot::Sender<Result<ConnectionHandle, SessionError>>,
}

/// Request sent by `SessionHandle::close_connection` asking the run loop to
/// close a virtual connection.
struct CloseRequest {
    /// Connection to close.
    conn_id: ConnectionId,
    /// Caller-supplied metadata for the close.
    metadata: Metadata<'static>,
    /// Completed by the run loop once the close has been handled.
    result_tx: moire::sync::oneshot::Sender<Result<(), SessionError>>,
}
72
/// Control requests delivered to the session run loop over an unbounded
/// channel.
///
/// Sending on this channel is synchronous, so these requests can be issued
/// from non-async code such as `SessionHandle::shutdown`.
#[derive(Debug, Clone, Copy)]
pub(crate) enum DropControlRequest {
    /// Stop the whole session.
    Shutdown,
    /// Close the given connection.
    Close(ConnectionId),
}
78
79#[cfg(not(target_arch = "wasm32"))]
80fn send_drop_control(
81    tx: &mpsc::UnboundedSender<DropControlRequest>,
82    req: DropControlRequest,
83) -> Result<(), ()> {
84    tx.send(req).map_err(|_| ())
85}
86
87#[cfg(target_arch = "wasm32")]
88fn send_drop_control(
89    tx: &mpsc::UnboundedSender<DropControlRequest>,
90    req: DropControlRequest,
91) -> Result<(), ()> {
92    tx.try_send(req).map_err(|_| ())
93}
94
95// ---------------------------------------------------------------------------
96// SessionHandle — cloneable handle for opening/closing virtual connections
97// ---------------------------------------------------------------------------
98
/// Cloneable handle for opening and closing virtual connections.
///
/// Returned by the session builder alongside the `Session` and root
/// `ConnectionHandle`. The session's `run()` loop must be running
/// concurrently for requests to be processed.
// r[impl rpc.virtual-connection.open]
#[derive(Clone)]
pub struct SessionHandle {
    /// Open requests, consumed by the run loop.
    open_tx: mpsc::Sender<OpenRequest>,
    /// Close requests, consumed by the run loop.
    close_tx: mpsc::Sender<CloseRequest>,
    /// Synchronous control channel (used by `shutdown`).
    control_tx: mpsc::UnboundedSender<DropControlRequest>,
}
111
112impl SessionHandle {
113    /// Open a new virtual connection on the session.
114    ///
115    /// Allocates a connection ID, sends `ConnectionOpen` to the peer, and
116    /// waits for `ConnectionAccept` or `ConnectionReject`. The session's
117    /// `run()` loop processes the response and completes the returned future.
118    // r[impl connection.open]
119    pub async fn open_connection(
120        &self,
121        settings: ConnectionSettings,
122        metadata: Metadata<'static>,
123    ) -> Result<ConnectionHandle, SessionError> {
124        let (result_tx, result_rx) = moire::sync::oneshot::channel("session.open_result");
125        self.open_tx
126            .send(OpenRequest {
127                settings,
128                metadata,
129                result_tx,
130            })
131            .await
132            .map_err(|_| SessionError::Protocol("session closed".into()))?;
133        result_rx
134            .await
135            .map_err(|_| SessionError::Protocol("session closed".into()))?
136    }
137
138    /// Close a virtual connection.
139    ///
140    /// Sends `ConnectionClose` to the peer and removes the connection slot.
141    /// After this returns, no further messages will be routed to the
142    /// connection's driver.
143    // r[impl connection.close]
144    pub async fn close_connection(
145        &self,
146        conn_id: ConnectionId,
147        metadata: Metadata<'static>,
148    ) -> Result<(), SessionError> {
149        let (result_tx, result_rx) = moire::sync::oneshot::channel("session.close_result");
150        self.close_tx
151            .send(CloseRequest {
152                conn_id,
153                metadata,
154                result_tx,
155            })
156            .await
157            .map_err(|_| SessionError::Protocol("session closed".into()))?;
158        result_rx
159            .await
160            .map_err(|_| SessionError::Protocol("session closed".into()))?
161    }
162
163    /// Request shutdown of the entire session (root + all virtual connections).
164    pub fn shutdown(&self) -> Result<(), SessionError> {
165        send_drop_control(&self.control_tx, DropControlRequest::Shutdown)
166            .map_err(|_| SessionError::Protocol("session closed".into()))
167    }
168}
169
170// ---------------------------------------------------------------------------
171// Session
172// ---------------------------------------------------------------------------
173
/// Session state machine.
///
/// Owns the conduit receiver and the table of connections. `run()` demuxes
/// incoming messages by connection ID and services requests coming from
/// `SessionHandle`.
// r[impl session]
// r[impl rpc.one-service-per-connection]
pub struct Session<C: Conduit> {
    /// Conduit receiver
    rx: C::Rx,

    /// Whether we initiated or accepted the handshake (set during establishment).
    // r[impl session.role]
    role: SessionRole,

    /// Our local parity — determines which connection IDs we allocate.
    // r[impl session.parity]
    parity: Parity,

    /// Shared core (for sending) — also held by all ConnectionSenders.
    sess_core: Arc<SessionCore>,

    /// Connection state keyed by ID: active connections and outbound opens
    /// still awaiting the peer's answer.
    conns: BTreeMap<ConnectionId, ConnectionSlot>,
    /// Whether the root connection was internally closed because all root callers dropped.
    root_closed_internal: bool,

    /// Allocator for outbound virtual connection IDs (uses session parity).
    conn_ids: IdAllocator<ConnectionId>,

    /// Callback for accepting inbound virtual connections.
    on_connection: Option<Box<dyn ConnectionAcceptor>>,

    /// Receiver for open requests from SessionHandle.
    open_rx: mpsc::Receiver<OpenRequest>,

    /// Receiver for close requests from SessionHandle.
    close_rx: mpsc::Receiver<CloseRequest>,

    /// Sender/receiver for drop-driven session/connection control requests.
    /// The session keeps its own sender (cloned into every ConnectionHandle).
    control_tx: mpsc::UnboundedSender<DropControlRequest>,
    control_rx: mpsc::UnboundedReceiver<DropControlRequest>,

    /// Optional proactive keepalive runtime config for connection ID 0.
    keepalive: Option<SessionKeepaliveConfig>,
}
215
/// Mutable keepalive state used by `Session::run`, built from
/// `SessionKeepaliveConfig`.
#[derive(Debug)]
struct KeepaliveRuntime {
    ping_interval: Duration,
    pong_timeout: Duration,
    /// Earliest time at which the next ping may be sent.
    next_ping_at: tokio::time::Instant,
    /// Nonce of the outstanding ping, if we are waiting for its pong.
    waiting_pong_nonce: Option<u64>,
    /// Deadline for the outstanding pong; only meaningful while
    /// `waiting_pong_nonce` is `Some`.
    pong_deadline: tokio::time::Instant,
    /// Nonce for the next ping (wraps on overflow).
    next_ping_nonce: u64,
}
225
// r[impl connection]
/// Static data for one active connection.
#[derive(Debug)]
pub struct ConnectionState {
    /// Unique connection identifier
    pub id: ConnectionId,

    /// Our settings
    pub local_settings: ConnectionSettings,

    /// The peer's settings
    pub peer_settings: ConnectionSettings,

    /// Sender for routing incoming messages to the per-connection driver task.
    /// Dropping it (by removing the slot) ends the driver's receive stream.
    conn_tx: mpsc::Sender<SelfRef<ConnectionMessage<'static>>>,
}
242
/// Entry in `Session::conns` describing one connection's lifecycle state.
#[derive(Debug)]
enum ConnectionSlot {
    /// Established connection; request/channel messages are routed to its driver.
    Active(ConnectionState),
    /// Outbound open that is waiting for the peer's accept or reject.
    PendingOutbound(PendingOutboundData),
}
248
249/// Debug-printable wrapper that omits the oneshot sender.
250struct PendingOutboundData {
251    local_settings: ConnectionSettings,
252    result_tx: Option<moire::sync::oneshot::Sender<Result<ConnectionHandle, SessionError>>>,
253}
254
255impl std::fmt::Debug for PendingOutboundData {
256    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
257        f.debug_struct("PendingOutbound")
258            .field("local_settings", &self.local_settings)
259            .finish()
260    }
261}
262
/// Clonable send side bound to one connection ID.
///
/// All clones share the session's `SessionCore` for the actual transmit.
#[derive(Clone)]
pub(crate) struct ConnectionSender {
    /// Connection ID stamped on every outgoing message.
    connection_id: ConnectionId,
    /// Shared session transmit core.
    sess_core: Arc<SessionCore>,
    /// Failure reports `(request_id, reason)`; the receiving end is the
    /// matching `ConnectionHandle::failures_rx`.
    failures: Arc<mpsc::UnboundedSender<(RequestId, &'static str)>>,
}
269
270fn forwarded_payload<'a>(payload: &'a roam_types::Payload<'static>) -> roam_types::Payload<'a> {
271    let roam_types::Payload::Incoming(bytes) = payload else {
272        unreachable!("proxy forwarding expects decoded incoming payload bytes")
273    };
274    roam_types::Payload::Incoming(bytes)
275}
276
277fn forwarded_request_body<'a>(body: &'a RequestBody<'static>) -> RequestBody<'a> {
278    match body {
279        RequestBody::Call(call) => RequestBody::Call(roam_types::RequestCall {
280            method_id: call.method_id,
281            channels: call.channels.clone(),
282            metadata: call.metadata.clone(),
283            args: forwarded_payload(&call.args),
284        }),
285        RequestBody::Response(response) => RequestBody::Response(RequestResponse {
286            channels: response.channels.clone(),
287            metadata: response.metadata.clone(),
288            ret: forwarded_payload(&response.ret),
289        }),
290        RequestBody::Cancel(cancel) => RequestBody::Cancel(roam_types::RequestCancel {
291            metadata: cancel.metadata.clone(),
292        }),
293    }
294}
295
296fn forwarded_channel_body<'a>(
297    body: &'a roam_types::ChannelBody<'static>,
298) -> roam_types::ChannelBody<'a> {
299    match body {
300        roam_types::ChannelBody::Item(item) => {
301            roam_types::ChannelBody::Item(roam_types::ChannelItem {
302                item: forwarded_payload(&item.item),
303            })
304        }
305        roam_types::ChannelBody::Close(close) => {
306            roam_types::ChannelBody::Close(roam_types::ChannelClose {
307                metadata: close.metadata.clone(),
308            })
309        }
310        roam_types::ChannelBody::Reset(reset) => {
311            roam_types::ChannelBody::Reset(roam_types::ChannelReset {
312                metadata: reset.metadata.clone(),
313            })
314        }
315        roam_types::ChannelBody::GrantCredit(credit) => {
316            roam_types::ChannelBody::GrantCredit(roam_types::ChannelGrantCredit {
317                additional: credit.additional,
318            })
319        }
320    }
321}
322
323impl ConnectionSender {
324    /// Send an arbitrary connection message
325    pub async fn send<'a>(&self, msg: ConnectionMessage<'a>) -> Result<(), ()> {
326        let payload = match msg {
327            ConnectionMessage::Request(r) => MessagePayload::RequestMessage(r),
328            ConnectionMessage::Channel(c) => MessagePayload::ChannelMessage(c),
329        };
330        let message = Message {
331            connection_id: self.connection_id,
332            payload,
333        };
334        self.sess_core.send(message).await.map_err(|_| ())
335    }
336
337    /// Send a received connection message without re-materializing payload values.
338    pub(crate) async fn send_owned(
339        &self,
340        msg: SelfRef<ConnectionMessage<'static>>,
341    ) -> Result<(), ()> {
342        let payload = match &*msg {
343            ConnectionMessage::Request(request) => MessagePayload::RequestMessage(RequestMessage {
344                id: request.id,
345                body: forwarded_request_body(&request.body),
346            }),
347            ConnectionMessage::Channel(channel) => MessagePayload::ChannelMessage(ChannelMessage {
348                id: channel.id,
349                body: forwarded_channel_body(&channel.body),
350            }),
351        };
352
353        self.sess_core
354            .send(Message {
355                connection_id: self.connection_id,
356                payload,
357            })
358            .await
359            .map_err(|_| ())
360    }
361
362    /// Send a response specifically
363    pub async fn send_response<'a>(
364        &self,
365        request_id: RequestId,
366        response: RequestResponse<'a>,
367    ) -> Result<(), ()> {
368        self.send(ConnectionMessage::Request(RequestMessage {
369            id: request_id,
370            body: RequestBody::Response(response),
371        }))
372        .await
373    }
374
375    /// Mark a request as failed by removing any pending response slot.
376    /// Called when a send error occurs or no reply was sent.
377    pub fn mark_failure(&self, request_id: RequestId, reason: &'static str) {
378        let _ = self.failures.send((request_id, reason));
379    }
380}
381
/// Endpoint for one connection (root or virtual), including its outgoing
/// sender and the stream of messages the session routes to it.
///
/// Typically handed to a driver (see `AcceptedConnection::setup`).
pub struct ConnectionHandle {
    /// Outgoing side, bound to this connection's ID.
    pub(crate) sender: ConnectionSender,
    /// Incoming request/channel messages routed here by the session run loop.
    pub(crate) rx: mpsc::Receiver<SelfRef<ConnectionMessage<'static>>>,
    /// Failure reports produced by `ConnectionSender::mark_failure`.
    pub(crate) failures_rx: mpsc::UnboundedReceiver<(RequestId, &'static str)>,
    /// Channel for asking the session to close this connection or shut down.
    pub(crate) control_tx: Option<mpsc::UnboundedSender<DropControlRequest>>,
    /// The parity this side should use for allocating request/channel IDs.
    pub parity: Parity,
}
390
391impl std::fmt::Debug for ConnectionHandle {
392    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
393        f.debug_struct("ConnectionHandle")
394            .field("connection_id", &self.sender.connection_id)
395            .finish()
396    }
397}
398
/// Connection-scoped message: the subset of session messages that is routed
/// to, or sent from, a single connection.
pub(crate) enum ConnectionMessage<'payload> {
    /// RPC request traffic (call, response, cancel).
    Request(RequestMessage<'payload>),
    /// Channel traffic (item, close, reset, credit).
    Channel(ChannelMessage<'payload>),
}
403
404impl ConnectionHandle {
405    /// Returns the connection ID for this handle.
406    pub fn connection_id(&self) -> ConnectionId {
407        self.sender.connection_id
408    }
409}
410
411/// Forward all request/channel traffic between two connections.
412///
413/// This is a protocol-level bridge: it does not inspect service schemas or method IDs.
414/// It exits when either side closes or a forward send fails, then requests closure of
415/// both underlying connections.
416pub async fn proxy_connections(left: ConnectionHandle, right: ConnectionHandle) {
417    let left_conn_id = left.connection_id();
418    let right_conn_id = right.connection_id();
419    let ConnectionHandle {
420        sender: left_sender,
421        rx: mut left_rx,
422        failures_rx: _left_failures_rx,
423        control_tx: left_control_tx,
424        parity: _left_parity,
425    } = left;
426    let ConnectionHandle {
427        sender: right_sender,
428        rx: mut right_rx,
429        failures_rx: _right_failures_rx,
430        control_tx: right_control_tx,
431        parity: _right_parity,
432    } = right;
433
434    loop {
435        tokio::select! {
436            msg = left_rx.recv() => {
437                let Some(msg) = msg else {
438                    break;
439                };
440                if right_sender.send_owned(msg).await.is_err() {
441                    break;
442                }
443            }
444            msg = right_rx.recv() => {
445                let Some(msg) = msg else {
446                    break;
447                };
448                if left_sender.send_owned(msg).await.is_err() {
449                    break;
450                }
451            }
452        }
453    }
454
455    if let Some(tx) = left_control_tx.as_ref() {
456        let _ = send_drop_control(tx, DropControlRequest::Close(left_conn_id));
457    }
458    if let Some(tx) = right_control_tx.as_ref() {
459        let _ = send_drop_control(tx, DropControlRequest::Close(right_conn_id));
460    }
461}
462
463/// Errors that can occur during session establishment or operation.
464#[derive(Debug)]
465pub enum SessionError {
466    Io(std::io::Error),
467    Protocol(String),
468    Rejected(Metadata<'static>),
469}
470
471impl std::fmt::Display for SessionError {
472    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
473        match self {
474            Self::Io(e) => write!(f, "io error: {e}"),
475            Self::Protocol(msg) => write!(f, "protocol error: {msg}"),
476            Self::Rejected(_) => write!(f, "connection rejected"),
477        }
478    }
479}
480
481impl std::error::Error for SessionError {}
482
483impl<C> Session<C>
484where
485    C: Conduit<Msg = MessageFamily>,
486    C::Tx: MaybeSend + MaybeSync + 'static,
487    for<'p> <C::Tx as ConduitTx>::Permit<'p>: MaybeSend,
488    C::Rx: MaybeSend,
489{
490    #[allow(clippy::too_many_arguments)]
491    fn pre_handshake(
492        tx: C::Tx,
493        rx: C::Rx,
494        on_connection: Option<Box<dyn ConnectionAcceptor>>,
495        open_rx: mpsc::Receiver<OpenRequest>,
496        close_rx: mpsc::Receiver<CloseRequest>,
497        control_tx: mpsc::UnboundedSender<DropControlRequest>,
498        control_rx: mpsc::UnboundedReceiver<DropControlRequest>,
499        keepalive: Option<SessionKeepaliveConfig>,
500    ) -> Self {
501        let sess_core = Arc::new(SessionCore { tx: Box::new(tx) });
502        Session {
503            rx,
504            role: SessionRole::Initiator, // overwritten in establish_as_*
505            parity: Parity::Odd,          // overwritten in establish_as_*
506            sess_core,
507            conns: BTreeMap::new(),
508            root_closed_internal: false,
509            conn_ids: IdAllocator::new(Parity::Odd), // overwritten in establish_as_*
510            on_connection,
511            open_rx,
512            close_rx,
513            control_tx,
514            control_rx,
515            keepalive,
516        }
517    }
518
519    // r[impl session.handshake]
520    async fn establish_as_initiator(
521        &mut self,
522        settings: ConnectionSettings,
523        metadata: Metadata<'_>,
524    ) -> Result<ConnectionHandle, SessionError> {
525        use roam_types::{Hello, MessagePayload};
526
527        self.role = SessionRole::Initiator;
528        self.parity = settings.parity;
529        self.conn_ids = IdAllocator::new(settings.parity);
530
531        // Send Hello
532        self.sess_core
533            .send(Message {
534                connection_id: ConnectionId::ROOT,
535                payload: MessagePayload::Hello(Hello {
536                    version: PROTOCOL_VERSION,
537                    connection_settings: settings.clone(),
538                    metadata,
539                }),
540            })
541            .await
542            .map_err(|_| SessionError::Protocol("failed to send Hello".into()))?;
543
544        // Receive HelloYourself
545        let peer_settings = match self.rx.recv().await {
546            Ok(Some(msg)) => {
547                let payload = msg.map(|m| m.payload);
548                match &*payload {
549                    MessagePayload::HelloYourself(hy) => hy.connection_settings.clone(),
550                    MessagePayload::ProtocolError(e) => {
551                        return Err(SessionError::Protocol(e.description.to_owned()));
552                    }
553                    _ => {
554                        return Err(SessionError::Protocol("expected HelloYourself".into()));
555                    }
556                }
557            }
558            Ok(None) => {
559                return Err(SessionError::Protocol(
560                    "peer closed during handshake".into(),
561                ));
562            }
563            Err(e) => return Err(SessionError::Protocol(e.to_string())),
564        };
565
566        Ok(self.make_root_handle(settings, peer_settings))
567    }
568
569    // r[impl session.handshake]
570    #[moire::instrument]
571    async fn establish_as_acceptor(
572        &mut self,
573        settings: ConnectionSettings,
574        metadata: Metadata<'_>,
575    ) -> Result<ConnectionHandle, SessionError> {
576        use roam_types::{HelloYourself, MessagePayload};
577
578        self.role = SessionRole::Acceptor;
579
580        // Receive Hello
581        let peer_settings = match self.rx.recv().await {
582            Ok(Some(msg)) => {
583                let payload = msg.map(|m| m.payload);
584                match &*payload {
585                    MessagePayload::Hello(h) => {
586                        if h.version != PROTOCOL_VERSION {
587                            return Err(SessionError::Protocol(format!(
588                                "version mismatch: got {}, expected {PROTOCOL_VERSION}",
589                                h.version
590                            )));
591                        }
592                        h.connection_settings.clone()
593                    }
594                    MessagePayload::ProtocolError(e) => {
595                        return Err(SessionError::Protocol(e.description.to_owned()));
596                    }
597                    _ => {
598                        return Err(SessionError::Protocol("expected Hello".into()));
599                    }
600                }
601            }
602            Ok(None) => {
603                return Err(SessionError::Protocol(
604                    "peer closed during handshake".into(),
605                ));
606            }
607            Err(e) => return Err(SessionError::Protocol(e.to_string())),
608        };
609
610        // Acceptor parity is opposite of initiator
611        let our_settings = ConnectionSettings {
612            parity: peer_settings.parity.other(),
613            ..settings
614        };
615        self.parity = our_settings.parity;
616        self.conn_ids = IdAllocator::new(our_settings.parity);
617
618        // Send HelloYourself
619        self.sess_core
620            .send(Message {
621                connection_id: ConnectionId::ROOT,
622                payload: MessagePayload::HelloYourself(HelloYourself {
623                    connection_settings: our_settings.clone(),
624                    metadata,
625                }),
626            })
627            .await
628            .map_err(|_| SessionError::Protocol("failed to send HelloYourself".into()))?;
629
630        Ok(self.make_root_handle(our_settings, peer_settings))
631    }
632
633    fn make_root_handle(
634        &mut self,
635        local_settings: ConnectionSettings,
636        peer_settings: ConnectionSettings,
637    ) -> ConnectionHandle {
638        self.make_connection_handle(ConnectionId::ROOT, local_settings, peer_settings)
639    }
640
641    fn make_connection_handle(
642        &mut self,
643        conn_id: ConnectionId,
644        local_settings: ConnectionSettings,
645        peer_settings: ConnectionSettings,
646    ) -> ConnectionHandle {
647        let label = format!("session.conn{}", conn_id.0);
648        let (conn_tx, conn_rx) = mpsc::channel::<SelfRef<ConnectionMessage<'static>>>(&label, 64);
649        let (failures_tx, failures_rx) = mpsc::unbounded_channel(format!("{label}.failures"));
650
651        let sender = ConnectionSender {
652            connection_id: conn_id,
653            sess_core: Arc::clone(&self.sess_core),
654            failures: Arc::new(failures_tx),
655        };
656
657        let parity = local_settings.parity;
658        self.conns.insert(
659            conn_id,
660            ConnectionSlot::Active(ConnectionState {
661                id: conn_id,
662                local_settings,
663                peer_settings,
664                conn_tx,
665            }),
666        );
667
668        ConnectionHandle {
669            sender,
670            rx: conn_rx,
671            failures_rx,
672            control_tx: Some(self.control_tx.clone()),
673            parity,
674        }
675    }
676
    /// Run the session recv loop: read from the conduit, demux by connection
    /// ID, and route to the appropriate connection's driver. Also processes
    /// open/close requests from the SessionHandle.
    ///
    /// The loop ends on conduit EOF or error, when handling a drop-control
    /// request returns `false`, or when a keepalive tick returns `false`
    /// (pong timeout or ping send failure). On exit all connection slots are
    /// dropped.
    // r[impl zerocopy.framing.pipeline.incoming]
    pub async fn run(&mut self) {
        let mut keepalive_runtime = self.make_keepalive_runtime();
        // Keepalive deadlines are polled on a fixed 10 ms tick, which only
        // exists when keepalive is enabled.
        let mut keepalive_tick = keepalive_runtime.as_ref().map(|_| {
            let mut interval = tokio::time::interval(Duration::from_millis(10));
            interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
            interval
        });

        loop {
            // Branches with a `Some(req) = ...` pattern are disabled for the
            // current `select!` when their channel yields `None`. The conduit
            // branch always matches, so `select!` never runs out of branches.
            tokio::select! {
                msg = self.rx.recv() => {
                    match msg {
                        Ok(Some(msg)) => self.handle_message(msg, &mut keepalive_runtime).await,
                        Ok(None) => {
                            warn!("session recv loop ended: conduit returned EOF");
                            break;
                        }
                        Err(error) => {
                            warn!(error = %error, "session recv loop ended: conduit recv error");
                            break;
                        }
                    }
                }
                Some(req) = self.open_rx.recv() => {
                    self.handle_open_request(req).await;
                }
                Some(req) = self.close_rx.recv() => {
                    self.handle_close_request(req).await;
                }
                Some(req) = self.control_rx.recv() => {
                    if !self.handle_drop_control_request(req).await {
                        break;
                    }
                }
                _ = async {
                    if let Some(interval) = keepalive_tick.as_mut() {
                        interval.tick().await;
                    }
                }, if keepalive_tick.is_some() => {
                    if !self.handle_keepalive_tick(&mut keepalive_runtime).await {
                        break;
                    }
                }
            }
        }

        // Drop all connection slots so per-connection drivers exit immediately.
        self.conns.clear();
        debug!("session recv loop exited");
    }
731
    /// Dispatch one message received from the conduit.
    ///
    /// - `ConnectionClose` removes the connection's slot.
    /// - `ConnectionOpen`, `ConnectionAccept` and `ConnectionReject` go to
    ///   their dedicated handlers.
    /// - Request and channel messages are forwarded to the connection's
    ///   driver, or dropped if the connection is not `Active`.
    /// - Pings are answered with a `Pong` echoing the nonce.
    /// - Pongs on the root connection feed the keepalive state.
    async fn handle_message(
        &mut self,
        msg: SelfRef<Message<'static>>,
        keepalive_runtime: &mut Option<KeepaliveRuntime>,
    ) {
        let conn_id = msg.connection_id;
        roam_types::selfref_match!(msg, payload {
            // r[impl connection.close.semantics]
            MessagePayload::ConnectionClose(_) => {
                if conn_id.is_root() {
                    warn!("received ConnectionClose for root connection");
                } else {
                    debug!(conn_id = conn_id.0, "received ConnectionClose for virtual connection");
                }
                // Remove the connection — dropping conn_tx causes the Driver's rx
                // to return None, which exits its run loop. All in-flight handlers
                // are dropped, triggering DriverReplySink::drop → Cancelled responses.
                self.conns.remove(&conn_id);
                self.maybe_request_shutdown_after_root_closed();
            }
            MessagePayload::ConnectionOpen(open) => {
                self.handle_inbound_open(conn_id, open).await;
            }
            MessagePayload::ConnectionAccept(accept) => {
                self.handle_inbound_accept(conn_id, accept);
            }
            MessagePayload::ConnectionReject(reject) => {
                self.handle_inbound_reject(conn_id, reject);
            }
            MessagePayload::RequestMessage(r) => {
                let conn_tx = match self.conns.get(&conn_id) {
                    Some(ConnectionSlot::Active(state)) => state.conn_tx.clone(),
                    _ => return,
                };
                // NOTE(review): conn_tx appears bounded (capacity 64 in
                // make_connection_handle), so this await stalls the whole run
                // loop while that driver's queue is full.
                // A send error means the driver's receiver is gone: drop the slot.
                if conn_tx.send(r.map(ConnectionMessage::Request)).await.is_err() {
                    self.conns.remove(&conn_id);
                    self.maybe_request_shutdown_after_root_closed();
                }
            }
            MessagePayload::ChannelMessage(c) => {
                let conn_tx = match self.conns.get(&conn_id) {
                    Some(ConnectionSlot::Active(state)) => state.conn_tx.clone(),
                    _ => return,
                };
                // Same delivery / cleanup rules as RequestMessage above.
                if conn_tx.send(c.map(ConnectionMessage::Channel)).await.is_err() {
                    self.conns.remove(&conn_id);
                    self.maybe_request_shutdown_after_root_closed();
                }
            }
            MessagePayload::Ping(ping) => {
                // Best-effort reply on the same connection; send errors are ignored.
                let _ = self
                    .sess_core
                    .send(Message {
                        connection_id: conn_id,
                        payload: MessagePayload::Pong(roam_types::Pong { nonce: ping.nonce }),
                    })
                    .await;
            }
            MessagePayload::Pong(pong) => {
                // Keepalive only runs on the root connection.
                if conn_id.is_root() {
                    self.handle_keepalive_pong(pong.nonce, keepalive_runtime);
                }
            }
            // Hello, HelloYourself, ProtocolError: not valid post-handshake, drop.
        })
    }
798
799    fn make_keepalive_runtime(&self) -> Option<KeepaliveRuntime> {
800        let config = self.keepalive?;
801        if config.ping_interval.is_zero() || config.pong_timeout.is_zero() {
802            warn!("keepalive disabled due to non-positive interval/timeout");
803            return None;
804        }
805        let now = tokio::time::Instant::now();
806        Some(KeepaliveRuntime {
807            ping_interval: config.ping_interval,
808            pong_timeout: config.pong_timeout,
809            next_ping_at: now + config.ping_interval,
810            waiting_pong_nonce: None,
811            pong_deadline: now,
812            next_ping_nonce: 1,
813        })
814    }
815
816    fn handle_keepalive_pong(&self, nonce: u64, keepalive_runtime: &mut Option<KeepaliveRuntime>) {
817        let Some(runtime) = keepalive_runtime.as_mut() else {
818            return;
819        };
820        if runtime.waiting_pong_nonce != Some(nonce) {
821            return;
822        }
823        runtime.waiting_pong_nonce = None;
824        runtime.next_ping_at = tokio::time::Instant::now() + runtime.ping_interval;
825    }
826
827    async fn handle_keepalive_tick(
828        &mut self,
829        keepalive_runtime: &mut Option<KeepaliveRuntime>,
830    ) -> bool {
831        let Some(runtime) = keepalive_runtime.as_mut() else {
832            return true;
833        };
834        let now = tokio::time::Instant::now();
835
836        if let Some(waiting_nonce) = runtime.waiting_pong_nonce {
837            if now >= runtime.pong_deadline {
838                warn!(
839                    nonce = waiting_nonce,
840                    timeout_ms = runtime.pong_timeout.as_millis(),
841                    "keepalive timeout waiting for pong"
842                );
843                return false;
844            }
845            return true;
846        }
847
848        if now < runtime.next_ping_at {
849            return true;
850        }
851
852        let nonce = runtime.next_ping_nonce;
853        if self
854            .sess_core
855            .send(Message {
856                connection_id: ConnectionId::ROOT,
857                payload: MessagePayload::Ping(roam_types::Ping { nonce }),
858            })
859            .await
860            .is_err()
861        {
862            warn!("failed to send keepalive ping");
863            return false;
864        }
865
866        runtime.waiting_pong_nonce = Some(nonce);
867        runtime.pong_deadline = now + runtime.pong_timeout;
868        runtime.next_ping_at = now + runtime.ping_interval;
869        runtime.next_ping_nonce = runtime.next_ping_nonce.wrapping_add(1);
870        true
871    }
872
873    async fn handle_inbound_open(
874        &mut self,
875        conn_id: ConnectionId,
876        open: SelfRef<ConnectionOpen<'static>>,
877    ) {
878        // Validate: connection ID must match peer's parity (opposite of ours).
879        let peer_parity = self.parity.other();
880        if !conn_id.has_parity(peer_parity) {
881            // Protocol error: wrong parity. For now, just reject.
882            let _ = self
883                .sess_core
884                .send(Message {
885                    connection_id: conn_id,
886                    payload: MessagePayload::ConnectionReject(roam_types::ConnectionReject {
887                        metadata: vec![],
888                    }),
889                })
890                .await;
891            return;
892        }
893
894        // Validate: connection ID must not already be in use.
895        if self.conns.contains_key(&conn_id) {
896            // Protocol error: duplicate connection ID.
897            let _ = self
898                .sess_core
899                .send(Message {
900                    connection_id: conn_id,
901                    payload: MessagePayload::ConnectionReject(roam_types::ConnectionReject {
902                        metadata: vec![],
903                    }),
904                })
905                .await;
906            return;
907        }
908
909        // r[impl connection.open.rejection]
910        // Call the acceptor callback. If none is registered, reject.
911        let acceptor = match &self.on_connection {
912            Some(a) => a,
913            None => {
914                let _ = self
915                    .sess_core
916                    .send(Message {
917                        connection_id: conn_id,
918                        payload: MessagePayload::ConnectionReject(roam_types::ConnectionReject {
919                            metadata: vec![],
920                        }),
921                    })
922                    .await;
923                return;
924            }
925        };
926
927        match acceptor.accept(conn_id, &open.connection_settings, &open.metadata) {
928            Ok(accepted) => {
929                // Create the connection handle and activate it.
930                let handle = self.make_connection_handle(
931                    conn_id,
932                    accepted.settings.clone(),
933                    open.connection_settings.clone(),
934                );
935
936                // Send ConnectionAccept to the peer.
937                let _ = self
938                    .sess_core
939                    .send(Message {
940                        connection_id: conn_id,
941                        payload: MessagePayload::ConnectionAccept(roam_types::ConnectionAccept {
942                            connection_settings: accepted.settings,
943                            metadata: accepted.metadata,
944                        }),
945                    })
946                    .await;
947
948                // Let the acceptor set up its driver.
949                (accepted.setup)(handle);
950            }
951            Err(reject_metadata) => {
952                let _ = self
953                    .sess_core
954                    .send(Message {
955                        connection_id: conn_id,
956                        payload: MessagePayload::ConnectionReject(roam_types::ConnectionReject {
957                            metadata: reject_metadata,
958                        }),
959                    })
960                    .await;
961            }
962        }
963    }
964
965    fn handle_inbound_accept(
966        &mut self,
967        conn_id: ConnectionId,
968        accept: SelfRef<ConnectionAccept<'static>>,
969    ) {
970        let slot = self.conns.remove(&conn_id);
971        match slot {
972            Some(ConnectionSlot::PendingOutbound(mut pending)) => {
973                let handle = self.make_connection_handle(
974                    conn_id,
975                    pending.local_settings.clone(),
976                    accept.connection_settings.clone(),
977                );
978
979                if let Some(tx) = pending.result_tx.take() {
980                    let _ = tx.send(Ok(handle));
981                }
982            }
983            Some(other) => {
984                // Not pending outbound — put it back and ignore.
985                self.conns.insert(conn_id, other);
986            }
987            None => {
988                // No pending open for this ID — ignore.
989            }
990        }
991    }
992
993    fn handle_inbound_reject(
994        &mut self,
995        conn_id: ConnectionId,
996        reject: SelfRef<ConnectionReject<'static>>,
997    ) {
998        let slot = self.conns.remove(&conn_id);
999        match slot {
1000            Some(ConnectionSlot::PendingOutbound(mut pending)) => {
1001                if let Some(tx) = pending.result_tx.take() {
1002                    let _ = tx.send(Err(SessionError::Rejected(reject.metadata.to_vec())));
1003                }
1004            }
1005            Some(other) => {
1006                self.conns.insert(conn_id, other);
1007            }
1008            None => {}
1009        }
1010    }
1011
1012    // r[impl connection.open]
1013    async fn handle_open_request(&mut self, req: OpenRequest) {
1014        let conn_id = self.conn_ids.alloc();
1015
1016        // Send ConnectionOpen to the peer.
1017        let send_result = self
1018            .sess_core
1019            .send(Message {
1020                connection_id: conn_id,
1021                payload: MessagePayload::ConnectionOpen(ConnectionOpen {
1022                    connection_settings: req.settings.clone(),
1023                    metadata: req.metadata,
1024                }),
1025            })
1026            .await;
1027
1028        if send_result.is_err() {
1029            let _ = req.result_tx.send(Err(SessionError::Protocol(
1030                "failed to send ConnectionOpen".into(),
1031            )));
1032            return;
1033        }
1034
1035        // Store the pending state. The run loop will complete the oneshot
1036        // when ConnectionAccept or ConnectionReject arrives.
1037        self.conns.insert(
1038            conn_id,
1039            ConnectionSlot::PendingOutbound(PendingOutboundData {
1040                local_settings: req.settings,
1041                result_tx: Some(req.result_tx),
1042            }),
1043        );
1044    }
1045
1046    // r[impl connection.close]
1047    async fn handle_close_request(&mut self, req: CloseRequest) {
1048        if req.conn_id.is_root() {
1049            let _ = req.result_tx.send(Err(SessionError::Protocol(
1050                "cannot close root connection".into(),
1051            )));
1052            return;
1053        }
1054
1055        // Remove the connection slot — this drops conn_tx and causes the
1056        // Driver to exit cleanly.
1057        if self.conns.remove(&req.conn_id).is_none() {
1058            let _ = req
1059                .result_tx
1060                .send(Err(SessionError::Protocol("connection not found".into())));
1061            return;
1062        }
1063
1064        // Send ConnectionClose to the peer.
1065        let send_result = self
1066            .sess_core
1067            .send(Message {
1068                connection_id: req.conn_id,
1069                payload: MessagePayload::ConnectionClose(ConnectionClose {
1070                    metadata: req.metadata,
1071                }),
1072            })
1073            .await;
1074
1075        if send_result.is_err() {
1076            let _ = req.result_tx.send(Err(SessionError::Protocol(
1077                "failed to send ConnectionClose".into(),
1078            )));
1079            return;
1080        }
1081
1082        let _ = req.result_tx.send(Ok(()));
1083        self.maybe_request_shutdown_after_root_closed();
1084    }
1085
1086    async fn handle_drop_control_request(&mut self, req: DropControlRequest) -> bool {
1087        match req {
1088            DropControlRequest::Shutdown => {
1089                debug!("session shutdown requested");
1090                false
1091            }
1092            DropControlRequest::Close(conn_id) => {
1093                // r[impl rpc.caller.liveness.last-drop-closes-connection]
1094                if conn_id.is_root() {
1095                    // r[impl rpc.caller.liveness.root-internal-close]
1096                    debug!("root callers dropped; internally closing root connection");
1097                    self.root_closed_internal = true;
1098                    // r[impl rpc.caller.liveness.root-teardown-condition]
1099                    return self.has_virtual_connections();
1100                }
1101
1102                if self.conns.remove(&conn_id).is_some() {
1103                    let _ = self
1104                        .sess_core
1105                        .send(Message {
1106                            connection_id: conn_id,
1107                            payload: MessagePayload::ConnectionClose(ConnectionClose {
1108                                metadata: vec![],
1109                            }),
1110                        })
1111                        .await;
1112                }
1113
1114                !self.root_closed_internal || self.has_virtual_connections()
1115            }
1116        }
1117    }
1118
1119    fn has_virtual_connections(&self) -> bool {
1120        self.conns.keys().any(|id| !id.is_root())
1121    }
1122
1123    fn maybe_request_shutdown_after_root_closed(&self) {
1124        if self.root_closed_internal && !self.has_virtual_connections() {
1125            let _ = send_drop_control(&self.control_tx, DropControlRequest::Shutdown);
1126        }
1127    }
1128}
1129
1130pub(crate) struct SessionCore {
1131    tx: Box<dyn DynConduitTx>,
1132}
1133
1134impl SessionCore {
1135    pub(crate) async fn send<'a>(&self, msg: Message<'a>) -> Result<(), ()> {
1136        self.tx.send_msg(msg).await.map_err(|_| ())
1137    }
1138}
1139
/// Boxed, pinned future used for dynamic dispatch in `DynConduitTx`.
///
/// On non-wasm32 targets the future must be `Send`.
#[cfg(not(target_arch = "wasm32"))]
type BoxFuture<'a, T> = Pin<Box<dyn std::future::Future<Output = T> + Send + 'a>>;
/// Boxed, pinned future used for dynamic dispatch in `DynConduitTx`.
///
/// On wasm32 the `Send` bound is dropped.
#[cfg(target_arch = "wasm32")]
type BoxFuture<'a, T> = Pin<Box<dyn std::future::Future<Output = T> + 'a>>;
1144
/// Object-safe view of a conduit's sending half, so `SessionCore` can hold it
/// as `Box<dyn DynConduitTx>`.
///
/// On non-wasm32 targets implementors must be `Send + Sync`.
#[cfg(not(target_arch = "wasm32"))]
pub trait DynConduitTx: Send + Sync {
    /// Sends one message, resolving once it has been handed to the conduit.
    fn send_msg<'a>(&'a self, msg: Message<'a>) -> BoxFuture<'a, std::io::Result<()>>;
}
/// Object-safe view of a conduit's sending half, so `SessionCore` can hold it
/// as `Box<dyn DynConduitTx>`.
///
/// On wasm32 the `Send + Sync` supertraits are dropped.
#[cfg(target_arch = "wasm32")]
pub trait DynConduitTx {
    /// Sends one message, resolving once it has been handed to the conduit.
    fn send_msg<'a>(&'a self, msg: Message<'a>) -> BoxFuture<'a, std::io::Result<()>>;
}
1153
1154// r[impl zerocopy.send]
1155// r[impl zerocopy.framing.pipeline.outgoing]
1156impl<T> DynConduitTx for T
1157where
1158    T: ConduitTx<Msg = MessageFamily> + MaybeSend + MaybeSync,
1159    for<'p> <T as ConduitTx>::Permit<'p>: MaybeSend,
1160{
1161    fn send_msg<'a>(&'a self, msg: Message<'a>) -> BoxFuture<'a, std::io::Result<()>> {
1162        Box::pin(async move {
1163            let permit = self.reserve().await?;
1164            permit
1165                .send(msg)
1166                .map_err(|e| std::io::Error::other(e.to_string()))
1167        })
1168    }
1169}
1170
#[cfg(test)]
mod tests {
    use moire::sync::mpsc;
    use roam_types::{
        Backing, Conduit, ConnectionAccept, ConnectionReject, MessageFamily, SelfRef,
    };

    use super::*;

    type TestConduit = crate::BareConduit<MessageFamily, crate::MemoryLink>;

    /// Builds a pre-handshake session over an in-memory link.
    ///
    /// The peer half of the link is leaked on purpose so `sess_core` sends do
    /// not fail with a broken pipe. The open/close request senders are
    /// dropped when this function returns, because these tests call the
    /// handlers directly.
    fn make_session() -> Session<TestConduit> {
        let (local_link, peer_link) = crate::memory_link_pair(32);
        std::mem::forget(peer_link);
        let bare = crate::BareConduit::new(local_link);
        let (tx, rx) = bare.split();
        let (_open_tx, open_rx) = mpsc::channel::<OpenRequest>("session.open.test", 4);
        let (_close_tx, close_rx) = mpsc::channel::<CloseRequest>("session.close.test", 4);
        let (control_tx, control_rx) = mpsc::unbounded_channel("session.control.test");
        Session::pre_handshake(
            tx, rx, None, open_rx, close_rx, control_tx, control_rx, None,
        )
    }

    /// Connection settings shared by every test in this module.
    fn test_settings(parity: Parity) -> ConnectionSettings {
        ConnectionSettings {
            parity,
            max_concurrent_requests: 64,
        }
    }

    /// An owned `ConnectionAccept` as the peer would send it.
    fn accept_ref() -> SelfRef<ConnectionAccept<'static>> {
        let accept = ConnectionAccept {
            connection_settings: test_settings(Parity::Even),
            metadata: vec![],
        };
        SelfRef::owning(Backing::Boxed(Box::<[u8]>::default()), accept)
    }

    /// An owned, metadata-free `ConnectionReject` as the peer would send it.
    fn reject_ref() -> SelfRef<ConnectionReject<'static>> {
        let reject = ConnectionReject { metadata: vec![] };
        SelfRef::owning(Backing::Boxed(Box::<[u8]>::default()), reject)
    }

    #[tokio::test]
    async fn duplicate_connection_accept_is_ignored_after_first() {
        let mut session = make_session();
        let conn_id = ConnectionId(1);
        let (result_tx, result_rx) = moire::sync::oneshot::channel("session.test.open_result");
        let pending = PendingOutboundData {
            local_settings: test_settings(Parity::Odd),
            result_tx: Some(result_tx),
        };
        session
            .conns
            .insert(conn_id, ConnectionSlot::PendingOutbound(pending));

        // The first accept resolves the pending open with a live handle.
        session.handle_inbound_accept(conn_id, accept_ref());
        let handle = result_rx
            .await
            .expect("pending outbound result should resolve")
            .expect("accept should resolve as Ok");
        assert_eq!(handle.connection_id(), conn_id);

        // A second accept must leave the now-active slot alone. `handle` is
        // still alive here, so the connection is not being torn down.
        session.handle_inbound_accept(conn_id, accept_ref());
        let slot = session.conns.get(&conn_id);
        assert!(
            matches!(
                slot,
                Some(ConnectionSlot::Active(ConnectionState { id, .. })) if *id == conn_id
            ),
            "duplicate accept should keep existing active connection state"
        );
    }

    #[tokio::test]
    async fn duplicate_connection_reject_is_ignored_after_first() {
        let mut session = make_session();
        let conn_id = ConnectionId(1);
        let (result_tx, result_rx) = moire::sync::oneshot::channel("session.test.open_result");
        let pending = PendingOutboundData {
            local_settings: test_settings(Parity::Odd),
            result_tx: Some(result_tx),
        };
        session
            .conns
            .insert(conn_id, ConnectionSlot::PendingOutbound(pending));

        // The first reject fails the pending open.
        session.handle_inbound_reject(conn_id, reject_ref());
        let outcome = result_rx
            .await
            .expect("pending outbound result should resolve");
        assert!(
            matches!(outcome, Err(SessionError::Rejected(_))),
            "expected rejection, got: {outcome:?}"
        );

        // A second reject must not resurrect the slot.
        session.handle_inbound_reject(conn_id, reject_ref());
        assert!(
            !session.conns.contains_key(&conn_id),
            "duplicate reject should not recreate connection state"
        );
    }

    #[test]
    fn out_of_order_accept_or_reject_without_pending_is_ignored() {
        let mut session = make_session();
        let stray_id = ConnectionId(99);

        session.handle_inbound_accept(stray_id, accept_ref());
        session.handle_inbound_reject(stray_id, reject_ref());

        assert!(
            session.conns.is_empty(),
            "out-of-order accept/reject should not mutate empty connection table"
        );
    }

    #[tokio::test]
    async fn close_request_clears_pending_outbound_open() {
        let mut session = make_session();
        let pending_id = ConnectionId(1);
        let (open_result_tx, open_result_rx) = moire::sync::oneshot::channel("session.open.result");
        let (close_result_tx, close_result_rx) =
            moire::sync::oneshot::channel("session.close.result");

        let pending = PendingOutboundData {
            local_settings: test_settings(Parity::Odd),
            result_tx: Some(open_result_tx),
        };
        session
            .conns
            .insert(pending_id, ConnectionSlot::PendingOutbound(pending));

        let close_req = CloseRequest {
            conn_id: pending_id,
            metadata: vec![],
            result_tx: close_result_tx,
        };
        session.handle_close_request(close_req).await;

        let close_outcome = close_result_rx
            .await
            .expect("close result should be delivered");
        assert!(
            close_outcome.is_ok(),
            "close should succeed for pending outbound connection"
        );

        // Removing the pending slot drops its sender, closing the open channel.
        assert!(
            open_result_rx.await.is_err(),
            "pending open result channel should be closed once the pending slot is removed"
        );
    }
}