1use std::{collections::BTreeMap, pin::Pin, sync::Arc, time::Duration};
2
3use moire::sync::mpsc;
4use roam_types::{
5 ChannelMessage, Conduit, ConduitRx, ConduitTx, ConduitTxPermit, ConnectionAccept,
6 ConnectionClose, ConnectionId, ConnectionOpen, ConnectionReject, ConnectionSettings,
7 IdAllocator, MaybeSend, MaybeSync, Message, MessageFamily, MessagePayload, Metadata, Parity,
8 RequestBody, RequestId, RequestMessage, RequestResponse, SelfRef, SessionRole,
9};
10use tokio::sync::watch;
11use tracing::{debug, warn};
12
13mod builders;
14pub use builders::*;
15
/// Protocol version advertised in `Hello`.
///
/// The acceptor side of the handshake refuses any `Hello` carrying a different value.
pub const PROTOCOL_VERSION: u32 = 7;
19
/// Keepalive settings for a session's root connection.
///
/// When present, the session periodically sends a `Ping` on the root connection and stops its
/// run loop if the matching `Pong` does not arrive in time. A zero value for either field
/// disables keepalive entirely.
#[derive(Debug, Clone, Copy)]
pub struct SessionKeepaliveConfig {
    /// Delay before the next ping, measured from session start and from each matching pong.
    pub ping_interval: Duration,
    /// How long to wait for the pong of an outstanding ping before giving up on the peer.
    pub pong_timeout: Duration,
}
26
/// Decides whether to accept virtual connections opened by the peer.
pub trait ConnectionAcceptor: Send + 'static {
    /// Called synchronously from the session loop when the peer sends `ConnectionOpen` for
    /// `conn_id`, so a slow implementation stalls the whole session.
    ///
    /// Return `Ok` to accept: the returned settings and metadata are sent back in
    /// `ConnectionAccept`, and the returned `setup` closure then receives the new
    /// connection's handle. Return `Err(metadata)` to reject; that metadata is sent back in
    /// `ConnectionReject`.
    fn accept(
        &self,
        conn_id: ConnectionId,
        peer_settings: &ConnectionSettings,
        metadata: &[roam_types::MetadataEntry],
    ) -> Result<AcceptedConnection, Metadata<'static>>;
}
47
/// Successful outcome of [`ConnectionAcceptor::accept`].
pub struct AcceptedConnection {
    /// Our settings for the new connection; echoed to the peer in `ConnectionAccept`.
    pub settings: ConnectionSettings,
    /// Metadata sent to the peer in `ConnectionAccept`.
    pub metadata: Metadata<'static>,
    /// Invoked once with the new connection's handle, after the `ConnectionAccept` send has
    /// been attempted.
    pub setup: Box<dyn FnOnce(ConnectionHandle) + Send>,
}
57
/// A `SessionHandle::open_connection` request, forwarded to the session run loop.
struct OpenRequest {
    /// Our settings for the new connection; sent to the peer in `ConnectionOpen`.
    settings: ConnectionSettings,
    /// Metadata sent to the peer in `ConnectionOpen`.
    metadata: Metadata<'static>,
    /// Resolved with the new handle on accept, or with an error on reject / send failure.
    result_tx: moire::sync::oneshot::Sender<Result<ConnectionHandle, SessionError>>,
}
67
/// A `SessionHandle::close_connection` request, forwarded to the session run loop.
struct CloseRequest {
    /// Virtual connection to close (closing the root connection is refused).
    conn_id: ConnectionId,
    /// Metadata sent to the peer in `ConnectionClose`.
    metadata: Metadata<'static>,
    /// Resolved once the close has been processed.
    result_tx: moire::sync::oneshot::Sender<Result<(), SessionError>>,
}
73
/// Control messages that can be queued without awaiting (see `send_drop_control`).
#[derive(Debug, Clone, Copy)]
pub(crate) enum DropControlRequest {
    /// Stop the session run loop.
    Shutdown,
    /// Close the given connection. For the root connection this only marks it as internally
    /// closed; the session then stops once no virtual connections remain.
    Close(ConnectionId),
}
79
80#[cfg(not(target_arch = "wasm32"))]
81fn send_drop_control(
82 tx: &mpsc::UnboundedSender<DropControlRequest>,
83 req: DropControlRequest,
84) -> Result<(), ()> {
85 tx.send(req).map_err(|_| ())
86}
87
/// wasm32 variant of `send_drop_control`: queues `req` via `try_send` without awaiting.
///
/// Returns `Err(())` when the send fails, e.g. because the session loop's receiver is gone.
#[cfg(target_arch = "wasm32")]
fn send_drop_control(
    tx: &mpsc::UnboundedSender<DropControlRequest>,
    req: DropControlRequest,
) -> Result<(), ()> {
    match tx.try_send(req) {
        Ok(()) => Ok(()),
        Err(_unsent) => Err(()),
    }
}
95
/// Cloneable handle for driving a session from outside its run loop.
///
/// Every method forwards a request to `Session::run` over a channel.
#[derive(Clone)]
pub struct SessionHandle {
    /// Requests to open virtual connections.
    open_tx: mpsc::Sender<OpenRequest>,
    /// Requests to close virtual connections.
    close_tx: mpsc::Sender<CloseRequest>,
    /// Fire-and-forget control requests (shutdown / close).
    control_tx: mpsc::UnboundedSender<DropControlRequest>,
}
112
113impl SessionHandle {
114 pub async fn open_connection(
121 &self,
122 settings: ConnectionSettings,
123 metadata: Metadata<'static>,
124 ) -> Result<ConnectionHandle, SessionError> {
125 let (result_tx, result_rx) = moire::sync::oneshot::channel("session.open_result");
126 self.open_tx
127 .send(OpenRequest {
128 settings,
129 metadata,
130 result_tx,
131 })
132 .await
133 .map_err(|_| SessionError::Protocol("session closed".into()))?;
134 result_rx
135 .await
136 .map_err(|_| SessionError::Protocol("session closed".into()))?
137 }
138
139 pub async fn close_connection(
146 &self,
147 conn_id: ConnectionId,
148 metadata: Metadata<'static>,
149 ) -> Result<(), SessionError> {
150 let (result_tx, result_rx) = moire::sync::oneshot::channel("session.close_result");
151 self.close_tx
152 .send(CloseRequest {
153 conn_id,
154 metadata,
155 result_tx,
156 })
157 .await
158 .map_err(|_| SessionError::Protocol("session closed".into()))?;
159 result_rx
160 .await
161 .map_err(|_| SessionError::Protocol("session closed".into()))?
162 }
163
164 pub fn shutdown(&self) -> Result<(), SessionError> {
166 send_drop_control(&self.control_tx, DropControlRequest::Shutdown)
167 .map_err(|_| SessionError::Protocol("session closed".into()))
168 }
169}
170
/// One end of a multiplexed session over a conduit.
///
/// Owns the receive half of the conduit and the table of connections (the root connection
/// plus any virtual ones). `Session::run` processes inbound messages together with requests
/// coming from `SessionHandle`s and `ConnectionHandle`s.
pub struct Session<C: Conduit> {
    /// Receive half of the conduit.
    rx: C::Rx,

    /// Whether we initiated or accepted the handshake.
    role: SessionRole,

    /// Parity of the connection ids we allocate; the peer allocates from the other parity.
    parity: Parity,

    /// Send half of the conduit, shared with every `ConnectionSender`.
    sess_core: Arc<SessionCore>,

    /// Known connections by id, including outbound opens still awaiting the peer's answer.
    conns: BTreeMap<ConnectionId, ConnectionSlot>,
    /// Set when a `Close` control request targets the root connection; the session then
    /// shuts down once no virtual connections remain.
    root_closed_internal: bool,

    /// Allocator for the ids of connections we open.
    conn_ids: IdAllocator<ConnectionId>,

    /// Decides on peer-initiated connections; `None` rejects all of them.
    on_connection: Option<Box<dyn ConnectionAcceptor>>,

    /// Open requests from `SessionHandle::open_connection`.
    open_rx: mpsc::Receiver<OpenRequest>,

    /// Close requests from `SessionHandle::close_connection`.
    close_rx: mpsc::Receiver<CloseRequest>,

    /// Control sender kept by the session and cloned into every `ConnectionHandle`.
    control_tx: mpsc::UnboundedSender<DropControlRequest>,
    /// Control requests consumed by the run loop.
    control_rx: mpsc::UnboundedReceiver<DropControlRequest>,

    /// Keepalive configuration; `None` disables keepalive.
    keepalive: Option<SessionKeepaliveConfig>,
}
216
/// Mutable keepalive state owned by the session run loop.
#[derive(Debug)]
struct KeepaliveRuntime {
    /// Copied from `SessionKeepaliveConfig::ping_interval`.
    ping_interval: Duration,
    /// Copied from `SessionKeepaliveConfig::pong_timeout`.
    pong_timeout: Duration,
    /// When the next ping is due; only consulted while no pong is outstanding.
    next_ping_at: tokio::time::Instant,
    /// Nonce of the ping still awaiting its pong, if any.
    waiting_pong_nonce: Option<u64>,
    /// Deadline for the outstanding pong; meaningless while `waiting_pong_nonce` is `None`.
    pong_deadline: tokio::time::Instant,
    /// Nonce for the next ping (wraps on overflow).
    next_ping_nonce: u64,
}
226
/// Session-side bookkeeping for an established connection.
#[derive(Debug)]
pub struct ConnectionState {
    /// Id of this connection within the session.
    pub id: ConnectionId,

    /// Settings we advertised for this connection.
    pub local_settings: ConnectionSettings,

    /// Settings the peer advertised for this connection.
    pub peer_settings: ConnectionSettings,

    /// Delivers inbound request/channel messages to the connection's handle.
    conn_tx: mpsc::Sender<SelfRef<ConnectionMessage<'static>>>,
    /// Set to `true` when the session removes this connection.
    closed_tx: watch::Sender<bool>,
}
244
/// Entry in the session's connection table.
#[derive(Debug)]
enum ConnectionSlot {
    /// Established connection (root, accepted inbound, or answered outbound).
    Active(ConnectionState),
    /// Outbound `ConnectionOpen` sent, waiting for `ConnectionAccept` / `ConnectionReject`.
    PendingOutbound(PendingOutboundData),
}
250
/// State kept for an outbound open until the peer answers.
struct PendingOutboundData {
    /// Settings we sent in `ConnectionOpen`.
    local_settings: ConnectionSettings,
    /// Resolves the opener's `open_connection` call; taken when the peer answers.
    result_tx: Option<moire::sync::oneshot::Sender<Result<ConnectionHandle, SessionError>>>,
}
256
257impl std::fmt::Debug for PendingOutboundData {
258 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
259 f.debug_struct("PendingOutbound")
260 .field("local_settings", &self.local_settings)
261 .finish()
262 }
263}
264
/// Cloneable sender that stamps outgoing messages with one connection's id.
#[derive(Clone)]
pub(crate) struct ConnectionSender {
    /// Id written into every outgoing `Message`.
    connection_id: ConnectionId,
    /// Send half of the session's conduit.
    sess_core: Arc<SessionCore>,
    /// Feeds the owning `ConnectionHandle`'s `failures_rx`.
    failures: Arc<mpsc::UnboundedSender<(RequestId, &'static str)>>,
}
271
272fn forwarded_payload<'a>(payload: &'a roam_types::Payload<'static>) -> roam_types::Payload<'a> {
273 let roam_types::Payload::Incoming(bytes) = payload else {
274 unreachable!("proxy forwarding expects decoded incoming payload bytes")
275 };
276 roam_types::Payload::Incoming(bytes)
277}
278
279fn forwarded_request_body<'a>(body: &'a RequestBody<'static>) -> RequestBody<'a> {
280 match body {
281 RequestBody::Call(call) => RequestBody::Call(roam_types::RequestCall {
282 method_id: call.method_id,
283 channels: call.channels.clone(),
284 metadata: call.metadata.clone(),
285 args: forwarded_payload(&call.args),
286 }),
287 RequestBody::Response(response) => RequestBody::Response(RequestResponse {
288 channels: response.channels.clone(),
289 metadata: response.metadata.clone(),
290 ret: forwarded_payload(&response.ret),
291 }),
292 RequestBody::Cancel(cancel) => RequestBody::Cancel(roam_types::RequestCancel {
293 metadata: cancel.metadata.clone(),
294 }),
295 }
296}
297
298fn forwarded_channel_body<'a>(
299 body: &'a roam_types::ChannelBody<'static>,
300) -> roam_types::ChannelBody<'a> {
301 match body {
302 roam_types::ChannelBody::Item(item) => {
303 roam_types::ChannelBody::Item(roam_types::ChannelItem {
304 item: forwarded_payload(&item.item),
305 })
306 }
307 roam_types::ChannelBody::Close(close) => {
308 roam_types::ChannelBody::Close(roam_types::ChannelClose {
309 metadata: close.metadata.clone(),
310 })
311 }
312 roam_types::ChannelBody::Reset(reset) => {
313 roam_types::ChannelBody::Reset(roam_types::ChannelReset {
314 metadata: reset.metadata.clone(),
315 })
316 }
317 roam_types::ChannelBody::GrantCredit(credit) => {
318 roam_types::ChannelBody::GrantCredit(roam_types::ChannelGrantCredit {
319 additional: credit.additional,
320 })
321 }
322 }
323}
324
325impl ConnectionSender {
326 pub async fn send<'a>(&self, msg: ConnectionMessage<'a>) -> Result<(), ()> {
328 let payload = match msg {
329 ConnectionMessage::Request(r) => MessagePayload::RequestMessage(r),
330 ConnectionMessage::Channel(c) => MessagePayload::ChannelMessage(c),
331 };
332 let message = Message {
333 connection_id: self.connection_id,
334 payload,
335 };
336 self.sess_core.send(message).await.map_err(|_| ())
337 }
338
339 pub(crate) async fn send_owned(
341 &self,
342 msg: SelfRef<ConnectionMessage<'static>>,
343 ) -> Result<(), ()> {
344 let payload = match &*msg {
345 ConnectionMessage::Request(request) => MessagePayload::RequestMessage(RequestMessage {
346 id: request.id,
347 body: forwarded_request_body(&request.body),
348 }),
349 ConnectionMessage::Channel(channel) => MessagePayload::ChannelMessage(ChannelMessage {
350 id: channel.id,
351 body: forwarded_channel_body(&channel.body),
352 }),
353 };
354
355 self.sess_core
356 .send(Message {
357 connection_id: self.connection_id,
358 payload,
359 })
360 .await
361 .map_err(|_| ())
362 }
363
364 pub async fn send_response<'a>(
366 &self,
367 request_id: RequestId,
368 response: RequestResponse<'a>,
369 ) -> Result<(), ()> {
370 self.send(ConnectionMessage::Request(RequestMessage {
371 id: request_id,
372 body: RequestBody::Response(response),
373 }))
374 .await
375 }
376
377 pub fn mark_failure(&self, request_id: RequestId, reason: &'static str) {
380 let _ = self.failures.send((request_id, reason));
381 }
382}
383
/// Caller-side endpoint of one connection (root or virtual) on a session.
pub struct ConnectionHandle {
    /// Sends messages on this connection.
    pub(crate) sender: ConnectionSender,
    /// Inbound request/channel messages routed here by the session.
    pub(crate) rx: mpsc::Receiver<SelfRef<ConnectionMessage<'static>>>,
    /// Failures reported through `ConnectionSender::mark_failure`.
    pub(crate) failures_rx: mpsc::UnboundedReceiver<(RequestId, &'static str)>,
    /// Control channel back to the session (e.g. `proxy_connections` uses it to request a
    /// close).
    pub(crate) control_tx: Option<mpsc::UnboundedSender<DropControlRequest>>,
    /// Becomes `true` once the session removes this connection.
    pub(crate) closed_rx: watch::Receiver<bool>,
    /// Parity taken from this connection's local settings.
    pub parity: Parity,
}
393
394impl std::fmt::Debug for ConnectionHandle {
395 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
396 f.debug_struct("ConnectionHandle")
397 .field("connection_id", &self.sender.connection_id)
398 .finish()
399 }
400}
401
/// Per-connection message: the subset of `MessagePayload` routed to connection handles.
pub(crate) enum ConnectionMessage<'payload> {
    /// Request, response or cancel for an RPC call.
    Request(RequestMessage<'payload>),
    /// Item, close, reset or credit grant on a channel.
    Channel(ChannelMessage<'payload>),
}
406
407impl ConnectionHandle {
408 pub fn connection_id(&self) -> ConnectionId {
410 self.sender.connection_id
411 }
412
413 pub async fn closed(&self) {
415 if *self.closed_rx.borrow() {
416 return;
417 }
418 let mut rx = self.closed_rx.clone();
419 while rx.changed().await.is_ok() {
420 if *rx.borrow() {
421 return;
422 }
423 }
424 }
425
426 pub fn is_connected(&self) -> bool {
428 !*self.closed_rx.borrow()
429 }
430}
431
432pub async fn proxy_connections(left: ConnectionHandle, right: ConnectionHandle) {
438 let left_conn_id = left.connection_id();
439 let right_conn_id = right.connection_id();
440 let ConnectionHandle {
441 sender: left_sender,
442 rx: mut left_rx,
443 failures_rx: _left_failures_rx,
444 control_tx: left_control_tx,
445 closed_rx: _left_closed_rx,
446 parity: _left_parity,
447 } = left;
448 let ConnectionHandle {
449 sender: right_sender,
450 rx: mut right_rx,
451 failures_rx: _right_failures_rx,
452 control_tx: right_control_tx,
453 closed_rx: _right_closed_rx,
454 parity: _right_parity,
455 } = right;
456
457 loop {
458 tokio::select! {
459 msg = left_rx.recv() => {
460 let Some(msg) = msg else {
461 break;
462 };
463 if right_sender.send_owned(msg).await.is_err() {
464 break;
465 }
466 }
467 msg = right_rx.recv() => {
468 let Some(msg) = msg else {
469 break;
470 };
471 if left_sender.send_owned(msg).await.is_err() {
472 break;
473 }
474 }
475 }
476 }
477
478 if let Some(tx) = left_control_tx.as_ref() {
479 let _ = send_drop_control(tx, DropControlRequest::Close(left_conn_id));
480 }
481 if let Some(tx) = right_control_tx.as_ref() {
482 let _ = send_drop_control(tx, DropControlRequest::Close(right_conn_id));
483 }
484}
485
/// Errors produced by session and connection operations.
#[derive(Debug)]
pub enum SessionError {
    /// Underlying I/O failure.
    Io(std::io::Error),
    /// Handshake/protocol violation, a failed send, or a closed session.
    Protocol(String),
    /// The peer rejected a connection; carries the rejection metadata.
    Rejected(Metadata<'static>),
}
493
494impl std::fmt::Display for SessionError {
495 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
496 match self {
497 Self::Io(e) => write!(f, "io error: {e}"),
498 Self::Protocol(msg) => write!(f, "protocol error: {msg}"),
499 Self::Rejected(_) => write!(f, "connection rejected"),
500 }
501 }
502}
503
504impl std::error::Error for SessionError {}
505
506impl<C> Session<C>
507where
508 C: Conduit<Msg = MessageFamily>,
509 C::Tx: MaybeSend + MaybeSync + 'static,
510 for<'p> <C::Tx as ConduitTx>::Permit<'p>: MaybeSend,
511 C::Rx: MaybeSend,
512{
513 #[allow(clippy::too_many_arguments)]
514 fn pre_handshake(
515 tx: C::Tx,
516 rx: C::Rx,
517 on_connection: Option<Box<dyn ConnectionAcceptor>>,
518 open_rx: mpsc::Receiver<OpenRequest>,
519 close_rx: mpsc::Receiver<CloseRequest>,
520 control_tx: mpsc::UnboundedSender<DropControlRequest>,
521 control_rx: mpsc::UnboundedReceiver<DropControlRequest>,
522 keepalive: Option<SessionKeepaliveConfig>,
523 ) -> Self {
524 let sess_core = Arc::new(SessionCore { tx: Box::new(tx) });
525 Session {
526 rx,
527 role: SessionRole::Initiator, parity: Parity::Odd, sess_core,
530 conns: BTreeMap::new(),
531 root_closed_internal: false,
532 conn_ids: IdAllocator::new(Parity::Odd), on_connection,
534 open_rx,
535 close_rx,
536 control_tx,
537 control_rx,
538 keepalive,
539 }
540 }
541
542 async fn establish_as_initiator(
544 &mut self,
545 settings: ConnectionSettings,
546 metadata: Metadata<'_>,
547 ) -> Result<ConnectionHandle, SessionError> {
548 use roam_types::{Hello, MessagePayload};
549
550 self.role = SessionRole::Initiator;
551 self.parity = settings.parity;
552 self.conn_ids = IdAllocator::new(settings.parity);
553
554 self.sess_core
556 .send(Message {
557 connection_id: ConnectionId::ROOT,
558 payload: MessagePayload::Hello(Hello {
559 version: PROTOCOL_VERSION,
560 connection_settings: settings.clone(),
561 metadata,
562 }),
563 })
564 .await
565 .map_err(|_| SessionError::Protocol("failed to send Hello".into()))?;
566
567 let peer_settings = match self.rx.recv().await {
569 Ok(Some(msg)) => {
570 let payload = msg.map(|m| m.payload);
571 match &*payload {
572 MessagePayload::HelloYourself(hy) => hy.connection_settings.clone(),
573 MessagePayload::ProtocolError(e) => {
574 return Err(SessionError::Protocol(e.description.to_owned()));
575 }
576 _ => {
577 return Err(SessionError::Protocol("expected HelloYourself".into()));
578 }
579 }
580 }
581 Ok(None) => {
582 return Err(SessionError::Protocol(
583 "peer closed during handshake".into(),
584 ));
585 }
586 Err(e) => return Err(SessionError::Protocol(e.to_string())),
587 };
588
589 Ok(self.make_root_handle(settings, peer_settings))
590 }
591
592 #[moire::instrument]
594 async fn establish_as_acceptor(
595 &mut self,
596 settings: ConnectionSettings,
597 metadata: Metadata<'_>,
598 ) -> Result<ConnectionHandle, SessionError> {
599 use roam_types::{HelloYourself, MessagePayload};
600
601 self.role = SessionRole::Acceptor;
602
603 let peer_settings = match self.rx.recv().await {
605 Ok(Some(msg)) => {
606 let payload = msg.map(|m| m.payload);
607 match &*payload {
608 MessagePayload::Hello(h) => {
609 if h.version != PROTOCOL_VERSION {
610 return Err(SessionError::Protocol(format!(
611 "version mismatch: got {}, expected {PROTOCOL_VERSION}",
612 h.version
613 )));
614 }
615 h.connection_settings.clone()
616 }
617 MessagePayload::ProtocolError(e) => {
618 return Err(SessionError::Protocol(e.description.to_owned()));
619 }
620 _ => {
621 return Err(SessionError::Protocol("expected Hello".into()));
622 }
623 }
624 }
625 Ok(None) => {
626 return Err(SessionError::Protocol(
627 "peer closed during handshake".into(),
628 ));
629 }
630 Err(e) => return Err(SessionError::Protocol(e.to_string())),
631 };
632
633 let our_settings = ConnectionSettings {
635 parity: peer_settings.parity.other(),
636 ..settings
637 };
638 self.parity = our_settings.parity;
639 self.conn_ids = IdAllocator::new(our_settings.parity);
640
641 self.sess_core
643 .send(Message {
644 connection_id: ConnectionId::ROOT,
645 payload: MessagePayload::HelloYourself(HelloYourself {
646 connection_settings: our_settings.clone(),
647 metadata,
648 }),
649 })
650 .await
651 .map_err(|_| SessionError::Protocol("failed to send HelloYourself".into()))?;
652
653 Ok(self.make_root_handle(our_settings, peer_settings))
654 }
655
656 fn make_root_handle(
657 &mut self,
658 local_settings: ConnectionSettings,
659 peer_settings: ConnectionSettings,
660 ) -> ConnectionHandle {
661 self.make_connection_handle(ConnectionId::ROOT, local_settings, peer_settings)
662 }
663
664 fn make_connection_handle(
665 &mut self,
666 conn_id: ConnectionId,
667 local_settings: ConnectionSettings,
668 peer_settings: ConnectionSettings,
669 ) -> ConnectionHandle {
670 let label = format!("session.conn{}", conn_id.0);
671 let (conn_tx, conn_rx) = mpsc::channel::<SelfRef<ConnectionMessage<'static>>>(&label, 64);
672 let (failures_tx, failures_rx) = mpsc::unbounded_channel(format!("{label}.failures"));
673 let (closed_tx, closed_rx) = watch::channel(false);
674
675 let sender = ConnectionSender {
676 connection_id: conn_id,
677 sess_core: Arc::clone(&self.sess_core),
678 failures: Arc::new(failures_tx),
679 };
680
681 let parity = local_settings.parity;
682 self.conns.insert(
683 conn_id,
684 ConnectionSlot::Active(ConnectionState {
685 id: conn_id,
686 local_settings,
687 peer_settings,
688 conn_tx,
689 closed_tx,
690 }),
691 );
692
693 ConnectionHandle {
694 sender,
695 rx: conn_rx,
696 failures_rx,
697 control_tx: Some(self.control_tx.clone()),
698 closed_rx,
699 parity,
700 }
701 }
702
703 pub async fn run(&mut self) {
708 let mut keepalive_runtime = self.make_keepalive_runtime();
709 let mut keepalive_tick = keepalive_runtime.as_ref().map(|_| {
710 let mut interval = tokio::time::interval(Duration::from_millis(10));
711 interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
712 interval
713 });
714
715 loop {
716 tokio::select! {
717 msg = self.rx.recv() => {
718 match msg {
719 Ok(Some(msg)) => self.handle_message(msg, &mut keepalive_runtime).await,
720 Ok(None) => {
721 warn!("session recv loop ended: conduit returned EOF");
722 break;
723 }
724 Err(error) => {
725 warn!(error = %error, "session recv loop ended: conduit recv error");
726 break;
727 }
728 }
729 }
730 Some(req) = self.open_rx.recv() => {
731 self.handle_open_request(req).await;
732 }
733 Some(req) = self.close_rx.recv() => {
734 self.handle_close_request(req).await;
735 }
736 Some(req) = self.control_rx.recv() => {
737 if !self.handle_drop_control_request(req).await {
738 break;
739 }
740 }
741 _ = async {
742 if let Some(interval) = keepalive_tick.as_mut() {
743 interval.tick().await;
744 }
745 }, if keepalive_tick.is_some() => {
746 if !self.handle_keepalive_tick(&mut keepalive_runtime).await {
747 break;
748 }
749 }
750 }
751 }
752
753 self.close_all_connections();
755 debug!("session recv loop exited");
756 }
757
758 async fn handle_message(
759 &mut self,
760 msg: SelfRef<Message<'static>>,
761 keepalive_runtime: &mut Option<KeepaliveRuntime>,
762 ) {
763 let conn_id = msg.connection_id;
764 roam_types::selfref_match!(msg, payload {
765 MessagePayload::ConnectionClose(_) => {
767 if conn_id.is_root() {
768 warn!("received ConnectionClose for root connection");
769 } else {
770 debug!(conn_id = conn_id.0, "received ConnectionClose for virtual connection");
771 }
772 self.remove_connection(&conn_id);
776 self.maybe_request_shutdown_after_root_closed();
777 }
778 MessagePayload::ConnectionOpen(open) => {
779 self.handle_inbound_open(conn_id, open).await;
780 }
781 MessagePayload::ConnectionAccept(accept) => {
782 self.handle_inbound_accept(conn_id, accept);
783 }
784 MessagePayload::ConnectionReject(reject) => {
785 self.handle_inbound_reject(conn_id, reject);
786 }
787 MessagePayload::RequestMessage(r) => {
788 let conn_tx = match self.conns.get(&conn_id) {
789 Some(ConnectionSlot::Active(state)) => state.conn_tx.clone(),
790 _ => return,
791 };
792 if conn_tx.send(r.map(ConnectionMessage::Request)).await.is_err() {
793 self.remove_connection(&conn_id);
794 self.maybe_request_shutdown_after_root_closed();
795 }
796 }
797 MessagePayload::ChannelMessage(c) => {
798 let conn_tx = match self.conns.get(&conn_id) {
799 Some(ConnectionSlot::Active(state)) => state.conn_tx.clone(),
800 _ => return,
801 };
802 if conn_tx.send(c.map(ConnectionMessage::Channel)).await.is_err() {
803 self.remove_connection(&conn_id);
804 self.maybe_request_shutdown_after_root_closed();
805 }
806 }
807 MessagePayload::Ping(ping) => {
808 let _ = self
809 .sess_core
810 .send(Message {
811 connection_id: conn_id,
812 payload: MessagePayload::Pong(roam_types::Pong { nonce: ping.nonce }),
813 })
814 .await;
815 }
816 MessagePayload::Pong(pong) => {
817 if conn_id.is_root() {
818 self.handle_keepalive_pong(pong.nonce, keepalive_runtime);
819 }
820 }
821 })
823 }
824
825 fn make_keepalive_runtime(&self) -> Option<KeepaliveRuntime> {
826 let config = self.keepalive?;
827 if config.ping_interval.is_zero() || config.pong_timeout.is_zero() {
828 warn!("keepalive disabled due to non-positive interval/timeout");
829 return None;
830 }
831 let now = tokio::time::Instant::now();
832 Some(KeepaliveRuntime {
833 ping_interval: config.ping_interval,
834 pong_timeout: config.pong_timeout,
835 next_ping_at: now + config.ping_interval,
836 waiting_pong_nonce: None,
837 pong_deadline: now,
838 next_ping_nonce: 1,
839 })
840 }
841
842 fn handle_keepalive_pong(&self, nonce: u64, keepalive_runtime: &mut Option<KeepaliveRuntime>) {
843 let Some(runtime) = keepalive_runtime.as_mut() else {
844 return;
845 };
846 if runtime.waiting_pong_nonce != Some(nonce) {
847 return;
848 }
849 runtime.waiting_pong_nonce = None;
850 runtime.next_ping_at = tokio::time::Instant::now() + runtime.ping_interval;
851 }
852
853 async fn handle_keepalive_tick(
854 &mut self,
855 keepalive_runtime: &mut Option<KeepaliveRuntime>,
856 ) -> bool {
857 let Some(runtime) = keepalive_runtime.as_mut() else {
858 return true;
859 };
860 let now = tokio::time::Instant::now();
861
862 if let Some(waiting_nonce) = runtime.waiting_pong_nonce {
863 if now >= runtime.pong_deadline {
864 warn!(
865 nonce = waiting_nonce,
866 timeout_ms = runtime.pong_timeout.as_millis(),
867 "keepalive timeout waiting for pong"
868 );
869 return false;
870 }
871 return true;
872 }
873
874 if now < runtime.next_ping_at {
875 return true;
876 }
877
878 let nonce = runtime.next_ping_nonce;
879 if self
880 .sess_core
881 .send(Message {
882 connection_id: ConnectionId::ROOT,
883 payload: MessagePayload::Ping(roam_types::Ping { nonce }),
884 })
885 .await
886 .is_err()
887 {
888 warn!("failed to send keepalive ping");
889 return false;
890 }
891
892 runtime.waiting_pong_nonce = Some(nonce);
893 runtime.pong_deadline = now + runtime.pong_timeout;
894 runtime.next_ping_at = now + runtime.ping_interval;
895 runtime.next_ping_nonce = runtime.next_ping_nonce.wrapping_add(1);
896 true
897 }
898
899 async fn handle_inbound_open(
900 &mut self,
901 conn_id: ConnectionId,
902 open: SelfRef<ConnectionOpen<'static>>,
903 ) {
904 let peer_parity = self.parity.other();
906 if !conn_id.has_parity(peer_parity) {
907 let _ = self
909 .sess_core
910 .send(Message {
911 connection_id: conn_id,
912 payload: MessagePayload::ConnectionReject(roam_types::ConnectionReject {
913 metadata: vec![],
914 }),
915 })
916 .await;
917 return;
918 }
919
920 if self.conns.contains_key(&conn_id) {
922 let _ = self
924 .sess_core
925 .send(Message {
926 connection_id: conn_id,
927 payload: MessagePayload::ConnectionReject(roam_types::ConnectionReject {
928 metadata: vec![],
929 }),
930 })
931 .await;
932 return;
933 }
934
935 let acceptor = match &self.on_connection {
938 Some(a) => a,
939 None => {
940 let _ = self
941 .sess_core
942 .send(Message {
943 connection_id: conn_id,
944 payload: MessagePayload::ConnectionReject(roam_types::ConnectionReject {
945 metadata: vec![],
946 }),
947 })
948 .await;
949 return;
950 }
951 };
952
953 match acceptor.accept(conn_id, &open.connection_settings, &open.metadata) {
954 Ok(accepted) => {
955 let handle = self.make_connection_handle(
957 conn_id,
958 accepted.settings.clone(),
959 open.connection_settings.clone(),
960 );
961
962 let _ = self
964 .sess_core
965 .send(Message {
966 connection_id: conn_id,
967 payload: MessagePayload::ConnectionAccept(roam_types::ConnectionAccept {
968 connection_settings: accepted.settings,
969 metadata: accepted.metadata,
970 }),
971 })
972 .await;
973
974 (accepted.setup)(handle);
976 }
977 Err(reject_metadata) => {
978 let _ = self
979 .sess_core
980 .send(Message {
981 connection_id: conn_id,
982 payload: MessagePayload::ConnectionReject(roam_types::ConnectionReject {
983 metadata: reject_metadata,
984 }),
985 })
986 .await;
987 }
988 }
989 }
990
991 fn handle_inbound_accept(
992 &mut self,
993 conn_id: ConnectionId,
994 accept: SelfRef<ConnectionAccept<'static>>,
995 ) {
996 let slot = self.remove_connection(&conn_id);
997 match slot {
998 Some(ConnectionSlot::PendingOutbound(mut pending)) => {
999 let handle = self.make_connection_handle(
1000 conn_id,
1001 pending.local_settings.clone(),
1002 accept.connection_settings.clone(),
1003 );
1004
1005 if let Some(tx) = pending.result_tx.take() {
1006 let _ = tx.send(Ok(handle));
1007 }
1008 }
1009 Some(other) => {
1010 self.conns.insert(conn_id, other);
1012 }
1013 None => {
1014 }
1016 }
1017 }
1018
1019 fn handle_inbound_reject(
1020 &mut self,
1021 conn_id: ConnectionId,
1022 reject: SelfRef<ConnectionReject<'static>>,
1023 ) {
1024 let slot = self.remove_connection(&conn_id);
1025 match slot {
1026 Some(ConnectionSlot::PendingOutbound(mut pending)) => {
1027 if let Some(tx) = pending.result_tx.take() {
1028 let _ = tx.send(Err(SessionError::Rejected(reject.metadata.to_vec())));
1029 }
1030 }
1031 Some(other) => {
1032 self.conns.insert(conn_id, other);
1033 }
1034 None => {}
1035 }
1036 }
1037
1038 async fn handle_open_request(&mut self, req: OpenRequest) {
1040 let conn_id = self.conn_ids.alloc();
1041
1042 let send_result = self
1044 .sess_core
1045 .send(Message {
1046 connection_id: conn_id,
1047 payload: MessagePayload::ConnectionOpen(ConnectionOpen {
1048 connection_settings: req.settings.clone(),
1049 metadata: req.metadata,
1050 }),
1051 })
1052 .await;
1053
1054 if send_result.is_err() {
1055 let _ = req.result_tx.send(Err(SessionError::Protocol(
1056 "failed to send ConnectionOpen".into(),
1057 )));
1058 return;
1059 }
1060
1061 self.conns.insert(
1064 conn_id,
1065 ConnectionSlot::PendingOutbound(PendingOutboundData {
1066 local_settings: req.settings,
1067 result_tx: Some(req.result_tx),
1068 }),
1069 );
1070 }
1071
1072 async fn handle_close_request(&mut self, req: CloseRequest) {
1074 if req.conn_id.is_root() {
1075 let _ = req.result_tx.send(Err(SessionError::Protocol(
1076 "cannot close root connection".into(),
1077 )));
1078 return;
1079 }
1080
1081 if self.remove_connection(&req.conn_id).is_none() {
1084 let _ = req
1085 .result_tx
1086 .send(Err(SessionError::Protocol("connection not found".into())));
1087 return;
1088 }
1089
1090 let send_result = self
1092 .sess_core
1093 .send(Message {
1094 connection_id: req.conn_id,
1095 payload: MessagePayload::ConnectionClose(ConnectionClose {
1096 metadata: req.metadata,
1097 }),
1098 })
1099 .await;
1100
1101 if send_result.is_err() {
1102 let _ = req.result_tx.send(Err(SessionError::Protocol(
1103 "failed to send ConnectionClose".into(),
1104 )));
1105 return;
1106 }
1107
1108 let _ = req.result_tx.send(Ok(()));
1109 self.maybe_request_shutdown_after_root_closed();
1110 }
1111
1112 async fn handle_drop_control_request(&mut self, req: DropControlRequest) -> bool {
1113 match req {
1114 DropControlRequest::Shutdown => {
1115 debug!("session shutdown requested");
1116 false
1117 }
1118 DropControlRequest::Close(conn_id) => {
1119 if conn_id.is_root() {
1121 debug!("root callers dropped; internally closing root connection");
1123 self.root_closed_internal = true;
1124 return self.has_virtual_connections();
1126 }
1127
1128 if self.remove_connection(&conn_id).is_some() {
1129 let _ = self
1130 .sess_core
1131 .send(Message {
1132 connection_id: conn_id,
1133 payload: MessagePayload::ConnectionClose(ConnectionClose {
1134 metadata: vec![],
1135 }),
1136 })
1137 .await;
1138 }
1139
1140 !self.root_closed_internal || self.has_virtual_connections()
1141 }
1142 }
1143 }
1144
1145 fn has_virtual_connections(&self) -> bool {
1146 self.conns.keys().any(|id| !id.is_root())
1147 }
1148
1149 fn remove_connection(&mut self, conn_id: &ConnectionId) -> Option<ConnectionSlot> {
1150 let slot = self.conns.remove(conn_id);
1151 if let Some(ConnectionSlot::Active(state)) = &slot {
1152 let _ = state.closed_tx.send(true);
1153 }
1154 slot
1155 }
1156
1157 fn close_all_connections(&mut self) {
1158 for slot in self.conns.values() {
1159 if let ConnectionSlot::Active(state) = slot {
1160 let _ = state.closed_tx.send(true);
1161 }
1162 }
1163 self.conns.clear();
1164 }
1165
1166 fn maybe_request_shutdown_after_root_closed(&self) {
1167 if self.root_closed_internal && !self.has_virtual_connections() {
1168 let _ = send_drop_control(&self.control_tx, DropControlRequest::Shutdown);
1169 }
1170 }
1171}
1172
/// Send half of a session's conduit, shared by the session and every `ConnectionSender`.
pub(crate) struct SessionCore {
    /// Type-erased conduit sender.
    tx: Box<dyn DynConduitTx>,
}
1176
1177impl SessionCore {
1178 pub(crate) async fn send<'a>(&self, msg: Message<'a>) -> Result<(), ()> {
1179 self.tx.send_msg(msg).await.map_err(|_| ())
1180 }
1181}
1182
/// Boxed future; `Send` everywhere except wasm32, where futures are single-threaded.
#[cfg(not(target_arch = "wasm32"))]
type BoxFuture<'a, T> = Pin<Box<dyn std::future::Future<Output = T> + Send + 'a>>;
/// Boxed future without a `Send` bound (wasm32).
#[cfg(target_arch = "wasm32")]
type BoxFuture<'a, T> = Pin<Box<dyn std::future::Future<Output = T> + 'a>>;
1187
/// Object-safe wrapper over a conduit sender, so `SessionCore` can hold it without being
/// generic over the conduit type. Requires `Send + Sync` off wasm32.
#[cfg(not(target_arch = "wasm32"))]
pub trait DynConduitTx: Send + Sync {
    /// Sends one message, returning a boxed future.
    fn send_msg<'a>(&'a self, msg: Message<'a>) -> BoxFuture<'a, std::io::Result<()>>;
}
/// wasm32 variant of `DynConduitTx`, without thread-safety bounds.
#[cfg(target_arch = "wasm32")]
pub trait DynConduitTx {
    /// Sends one message, returning a boxed future.
    fn send_msg<'a>(&'a self, msg: Message<'a>) -> BoxFuture<'a, std::io::Result<()>>;
}
1196
1197impl<T> DynConduitTx for T
1200where
1201 T: ConduitTx<Msg = MessageFamily> + MaybeSend + MaybeSync,
1202 for<'p> <T as ConduitTx>::Permit<'p>: MaybeSend,
1203{
1204 fn send_msg<'a>(&'a self, msg: Message<'a>) -> BoxFuture<'a, std::io::Result<()>> {
1205 Box::pin(async move {
1206 let permit = self.reserve().await?;
1207 permit
1208 .send(msg)
1209 .map_err(|e| std::io::Error::other(e.to_string()))
1210 })
1211 }
1212}
1213
#[cfg(test)]
mod tests {
    use moire::sync::mpsc;
    use roam_types::{
        Backing, Conduit, ConnectionAccept, ConnectionReject, MessageFamily, SelfRef,
    };

    use super::*;

    /// In-memory conduit type used by every test in this module.
    type TestConduit = crate::BareConduit<MessageFamily, crate::MemoryLink>;

    /// Builds a `Session` over one end of an in-memory link, with no handshake, no acceptor
    /// and no keepalive.
    fn make_session() -> Session<TestConduit> {
        let (a, b) = crate::memory_link_pair(32);
        // The peer end is leaked rather than dropped, presumably so the link stays open and
        // sends from the session under test do not fail.
        std::mem::forget(b);
        let conduit = crate::BareConduit::new(a);
        let (tx, rx) = conduit.split();
        // The request senders are dropped when this returns; the tests call the handler
        // methods directly instead of going through `run`.
        let (_open_tx, open_rx) = mpsc::channel::<OpenRequest>("session.open.test", 4);
        let (_close_tx, close_rx) = mpsc::channel::<CloseRequest>("session.close.test", 4);
        let (control_tx, control_rx) = mpsc::unbounded_channel("session.control.test");
        Session::pre_handshake(
            tx, rx, None, open_rx, close_rx, control_tx, control_rx, None,
        )
    }

    /// A `ConnectionAccept` advertising even peer parity, backed by an empty buffer.
    fn accept_ref() -> SelfRef<ConnectionAccept<'static>> {
        SelfRef::owning(
            Backing::Boxed(Box::<[u8]>::default()),
            ConnectionAccept {
                connection_settings: ConnectionSettings {
                    parity: Parity::Even,
                    max_concurrent_requests: 64,
                },
                metadata: vec![],
            },
        )
    }

    /// A `ConnectionReject` with empty metadata, backed by an empty buffer.
    fn reject_ref() -> SelfRef<ConnectionReject<'static>> {
        SelfRef::owning(
            Backing::Boxed(Box::<[u8]>::default()),
            ConnectionReject { metadata: vec![] },
        )
    }

    /// The first accept resolves the pending open; a second accept for the same id must
    /// leave the now-active connection in the table.
    #[tokio::test]
    async fn duplicate_connection_accept_is_ignored_after_first() {
        let mut session = make_session();
        let conn_id = ConnectionId(1);
        let (result_tx, result_rx) = moire::sync::oneshot::channel("session.test.open_result");

        session.conns.insert(
            conn_id,
            ConnectionSlot::PendingOutbound(PendingOutboundData {
                local_settings: ConnectionSettings {
                    parity: Parity::Odd,
                    max_concurrent_requests: 64,
                },
                result_tx: Some(result_tx),
            }),
        );

        session.handle_inbound_accept(conn_id, accept_ref());
        let handle = result_rx
            .await
            .expect("pending outbound result should resolve")
            .expect("accept should resolve as Ok");
        assert_eq!(handle.connection_id(), conn_id);

        session.handle_inbound_accept(conn_id, accept_ref());
        assert!(
            matches!(
                session.conns.get(&conn_id),
                Some(ConnectionSlot::Active(ConnectionState { id, .. })) if *id == conn_id
            ),
            "duplicate accept should keep existing active connection state"
        );
    }

    /// The first reject fails the pending open; a second reject must not recreate a slot.
    #[tokio::test]
    async fn duplicate_connection_reject_is_ignored_after_first() {
        let mut session = make_session();
        let conn_id = ConnectionId(1);
        let (result_tx, result_rx) = moire::sync::oneshot::channel("session.test.open_result");

        session.conns.insert(
            conn_id,
            ConnectionSlot::PendingOutbound(PendingOutboundData {
                local_settings: ConnectionSettings {
                    parity: Parity::Odd,
                    max_concurrent_requests: 64,
                },
                result_tx: Some(result_tx),
            }),
        );

        session.handle_inbound_reject(conn_id, reject_ref());
        let result = result_rx
            .await
            .expect("pending outbound result should resolve");
        assert!(
            matches!(result, Err(SessionError::Rejected(_))),
            "expected rejection, got: {result:?}"
        );

        session.handle_inbound_reject(conn_id, reject_ref());
        assert!(
            !session.conns.contains_key(&conn_id),
            "duplicate reject should not recreate connection state"
        );
    }

    /// Accept/reject for an id that was never opened must not touch the connection table.
    #[test]
    fn out_of_order_accept_or_reject_without_pending_is_ignored() {
        let mut session = make_session();
        let conn_id = ConnectionId(99);

        session.handle_inbound_accept(conn_id, accept_ref());
        session.handle_inbound_reject(conn_id, reject_ref());

        assert!(
            session.conns.is_empty(),
            "out-of-order accept/reject should not mutate empty connection table"
        );
    }

    /// Closing a still-pending outbound open succeeds, and the opener's result channel is
    /// closed because the pending slot (and its sender) is dropped.
    #[tokio::test]
    async fn close_request_clears_pending_outbound_open() {
        let mut session = make_session();
        let (open_result_tx, open_result_rx) = moire::sync::oneshot::channel("session.open.result");
        let (close_result_tx, close_result_rx) =
            moire::sync::oneshot::channel("session.close.result");

        session.conns.insert(
            ConnectionId(1),
            ConnectionSlot::PendingOutbound(PendingOutboundData {
                local_settings: ConnectionSettings {
                    parity: Parity::Odd,
                    max_concurrent_requests: 64,
                },
                result_tx: Some(open_result_tx),
            }),
        );

        session
            .handle_close_request(CloseRequest {
                conn_id: ConnectionId(1),
                metadata: vec![],
                result_tx: close_result_tx,
            })
            .await;

        let close_result = close_result_rx
            .await
            .expect("close result should be delivered");
        assert!(
            close_result.is_ok(),
            "close should succeed for pending outbound connection"
        );

        assert!(
            open_result_rx.await.is_err(),
            "pending open result channel should be closed once the pending slot is removed"
        );
    }
}