1use std::{collections::BTreeMap, pin::Pin, sync::Arc, time::Duration};
2
3use moire::sync::mpsc;
4use roam_types::{
5 ChannelMessage, Conduit, ConduitRx, ConduitTx, ConduitTxPermit, ConnectionAccept,
6 ConnectionClose, ConnectionId, ConnectionOpen, ConnectionReject, ConnectionSettings,
7 IdAllocator, MaybeSend, MaybeSync, Message, MessageFamily, MessagePayload, Metadata, Parity,
8 RequestBody, RequestId, RequestMessage, RequestResponse, SelfRef, SessionRole,
9};
10use tracing::{debug, warn};
11
12mod builders;
13pub use builders::*;
14
/// Wire protocol version carried in the `Hello` handshake message.
///
/// The acceptor side refuses a `Hello` whose `version` differs from this value.
pub const PROTOCOL_VERSION: u32 = 7;
18
/// Keepalive settings for a session's root connection.
///
/// When configured, the session pings the peer on the root connection and stops its run
/// loop if the matching pong does not arrive in time. A zero value for either field
/// disables keepalive (a warning is logged).
#[derive(Debug, Clone, Copy)]
pub struct SessionKeepaliveConfig {
    /// Delay before the next ping, measured from the previous ping (and reset when its
    /// pong arrives).
    pub ping_interval: Duration,
    /// How long to wait for a pong before giving up on the peer.
    pub pong_timeout: Duration,
}
25
/// Decides whether to accept virtual connections opened by the peer.
///
/// Installed on a [`Session`] and consulted for each inbound `ConnectionOpen` that passes
/// the session's own connection-id checks. Without an acceptor every inbound open is
/// rejected.
pub trait ConnectionAcceptor: Send + 'static {
    /// Accepts or rejects the inbound connection `conn_id`.
    ///
    /// `peer_settings` and `metadata` are what the peer sent in its `ConnectionOpen`.
    /// On `Ok`, the session registers the connection, replies with `ConnectionAccept`
    /// carrying the returned settings and metadata, and then passes the new
    /// [`ConnectionHandle`] to [`AcceptedConnection::setup`]. On `Err(metadata)`, the
    /// session replies with `ConnectionReject` carrying that metadata.
    fn accept(
        &self,
        conn_id: ConnectionId,
        peer_settings: &ConnectionSettings,
        metadata: &[roam_types::MetadataEntry],
    ) -> Result<AcceptedConnection, Metadata<'static>>;
}
46
/// A [`ConnectionAcceptor`]'s decision to accept an inbound connection.
pub struct AcceptedConnection {
    /// Local settings for the connection; also sent to the peer in `ConnectionAccept`.
    pub settings: ConnectionSettings,
    /// Metadata sent to the peer in `ConnectionAccept`.
    pub metadata: Metadata<'static>,
    /// Invoked once with the new connection's handle. It runs after the session has
    /// attempted to send `ConnectionAccept`; that send's result is not checked.
    pub setup: Box<dyn FnOnce(ConnectionHandle) + Send>,
}
56
/// Request from a [`SessionHandle`] to open an outbound virtual connection.
struct OpenRequest {
    /// Local settings, sent to the peer in `ConnectionOpen`.
    settings: ConnectionSettings,
    /// Metadata sent to the peer in `ConnectionOpen`.
    metadata: Metadata<'static>,
    /// Receives the connection handle once the peer accepts, or the failure.
    result_tx: moire::sync::oneshot::Sender<Result<ConnectionHandle, SessionError>>,
}
66
/// Request from a [`SessionHandle`] to close a virtual connection.
struct CloseRequest {
    /// Connection to close; requests for the root connection are refused.
    conn_id: ConnectionId,
    /// Metadata sent to the peer in `ConnectionClose`.
    metadata: Metadata<'static>,
    /// Receives the outcome of the close.
    result_tx: moire::sync::oneshot::Sender<Result<(), SessionError>>,
}
72
/// Control messages for the session run loop.
///
/// They travel over an unbounded channel, so they can be queued from synchronous code
/// (see `send_drop_control`).
#[derive(Debug, Clone, Copy)]
pub(crate) enum DropControlRequest {
    /// Stop the session run loop.
    Shutdown,
    /// Close the given connection. For the root id this only marks the root as closed
    /// internally; the session keeps running while virtual connections remain.
    Close(ConnectionId),
}
78
79#[cfg(not(target_arch = "wasm32"))]
80fn send_drop_control(
81 tx: &mpsc::UnboundedSender<DropControlRequest>,
82 req: DropControlRequest,
83) -> Result<(), ()> {
84 tx.send(req).map_err(|_| ())
85}
86
/// Queues `req` on a session's control channel without waiting (wasm32 build).
///
/// This target goes through `try_send`. Returns `Err(())` when the channel refuses the
/// message.
#[cfg(target_arch = "wasm32")]
fn send_drop_control(
    tx: &mpsc::UnboundedSender<DropControlRequest>,
    req: DropControlRequest,
) -> Result<(), ()> {
    match tx.try_send(req) {
        Ok(()) => Ok(()),
        Err(_) => Err(()),
    }
}
94
/// Cloneable handle for sending requests to a running [`Session`].
///
/// Each request is forwarded over a channel to the session's run loop, which carries
/// it out.
#[derive(Clone)]
pub struct SessionHandle {
    /// Outbound connection-open requests.
    open_tx: mpsc::Sender<OpenRequest>,
    /// Connection-close requests.
    close_tx: mpsc::Sender<CloseRequest>,
    /// Control requests such as shutdown.
    control_tx: mpsc::UnboundedSender<DropControlRequest>,
}
111
112impl SessionHandle {
113 pub async fn open_connection(
120 &self,
121 settings: ConnectionSettings,
122 metadata: Metadata<'static>,
123 ) -> Result<ConnectionHandle, SessionError> {
124 let (result_tx, result_rx) = moire::sync::oneshot::channel("session.open_result");
125 self.open_tx
126 .send(OpenRequest {
127 settings,
128 metadata,
129 result_tx,
130 })
131 .await
132 .map_err(|_| SessionError::Protocol("session closed".into()))?;
133 result_rx
134 .await
135 .map_err(|_| SessionError::Protocol("session closed".into()))?
136 }
137
138 pub async fn close_connection(
145 &self,
146 conn_id: ConnectionId,
147 metadata: Metadata<'static>,
148 ) -> Result<(), SessionError> {
149 let (result_tx, result_rx) = moire::sync::oneshot::channel("session.close_result");
150 self.close_tx
151 .send(CloseRequest {
152 conn_id,
153 metadata,
154 result_tx,
155 })
156 .await
157 .map_err(|_| SessionError::Protocol("session closed".into()))?;
158 result_rx
159 .await
160 .map_err(|_| SessionError::Protocol("session closed".into()))?
161 }
162
163 pub fn shutdown(&self) -> Result<(), SessionError> {
165 send_drop_control(&self.control_tx, DropControlRequest::Shutdown)
166 .map_err(|_| SessionError::Protocol("session closed".into()))
167 }
168}
169
/// One end of a roam session multiplexed over a conduit.
///
/// Owns the receive half of the conduit and the connection table. Its [`Session::run`]
/// loop dispatches inbound messages and serves requests from [`SessionHandle`]s.
pub struct Session<C: Conduit> {
    /// Receive half of the conduit.
    rx: C::Rx,

    /// Handshake role of this side (provisional until the handshake runs).
    role: SessionRole,

    /// Parity of the connection ids this side allocates; the peer uses the other one.
    parity: Parity,

    /// Shared send half; every `ConnectionSender` holds a clone of this `Arc`.
    sess_core: Arc<SessionCore>,

    /// All known connections: the root, active virtual ones, and pending outbound opens.
    conns: BTreeMap<ConnectionId, ConnectionSlot>,
    /// Set once the root connection has been closed locally. From then on the session
    /// shuts down as soon as no virtual connections remain.
    root_closed_internal: bool,

    /// Allocator for outbound connection ids; re-created with `parity` after handshake.
    conn_ids: IdAllocator<ConnectionId>,

    /// Decides inbound `ConnectionOpen`s; when `None`, all of them are rejected.
    on_connection: Option<Box<dyn ConnectionAcceptor>>,

    /// Requests from `SessionHandle::open_connection`.
    open_rx: mpsc::Receiver<OpenRequest>,

    /// Requests from `SessionHandle::close_connection`.
    close_rx: mpsc::Receiver<CloseRequest>,

    /// Sender half of the control channel; cloned into every connection handle.
    control_tx: mpsc::UnboundedSender<DropControlRequest>,
    /// Receiver half of the control channel, polled by the run loop.
    control_rx: mpsc::UnboundedReceiver<DropControlRequest>,

    /// Root-connection keepalive settings, if enabled.
    keepalive: Option<SessionKeepaliveConfig>,
}
215
/// Mutable keepalive state owned by the session run loop.
#[derive(Debug)]
struct KeepaliveRuntime {
    /// Copied from `SessionKeepaliveConfig::ping_interval`.
    ping_interval: Duration,
    /// Copied from `SessionKeepaliveConfig::pong_timeout`.
    pong_timeout: Duration,
    /// Earliest instant at which the next ping may be sent.
    next_ping_at: tokio::time::Instant,
    /// Nonce of the outstanding ping, or `None` when no pong is awaited.
    waiting_pong_nonce: Option<u64>,
    /// Deadline for the outstanding ping's pong; only meaningful while
    /// `waiting_pong_nonce` is `Some`.
    pong_deadline: tokio::time::Instant,
    /// Nonce to use for the next ping; wraps on overflow.
    next_ping_nonce: u64,
}
225
/// Bookkeeping for an established (active) connection.
#[derive(Debug)]
pub struct ConnectionState {
    /// Id of the connection.
    pub id: ConnectionId,

    /// Settings this side uses for the connection.
    pub local_settings: ConnectionSettings,

    /// Settings the peer advertised for the connection.
    pub peer_settings: ConnectionSettings,

    /// Delivers inbound request and channel messages to the connection's handle.
    conn_tx: mpsc::Sender<SelfRef<ConnectionMessage<'static>>>,
}
242
/// An entry in the session's connection table.
#[derive(Debug)]
enum ConnectionSlot {
    /// Open connection; inbound messages are routed to its handle.
    Active(ConnectionState),
    /// `ConnectionOpen` has been sent; waiting for the peer's accept or reject.
    PendingOutbound(PendingOutboundData),
}
248
/// State of an outbound open that is awaiting the peer's answer.
struct PendingOutboundData {
    /// Settings sent in `ConnectionOpen`; become the connection's local settings.
    local_settings: ConnectionSettings,
    /// Where the outcome goes; taken when the peer accepts or rejects.
    result_tx: Option<moire::sync::oneshot::Sender<Result<ConnectionHandle, SessionError>>>,
}
254
255impl std::fmt::Debug for PendingOutboundData {
256 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
257 f.debug_struct("PendingOutbound")
258 .field("local_settings", &self.local_settings)
259 .finish()
260 }
261}
262
/// Sends messages on one connection through the session's shared conduit sender.
#[derive(Clone)]
pub(crate) struct ConnectionSender {
    /// Connection id stamped on every outgoing message.
    connection_id: ConnectionId,
    /// Shared send half of the session's conduit.
    sess_core: Arc<SessionCore>,
    /// Reports per-request failures to the owning handle's `failures_rx`.
    failures: Arc<mpsc::UnboundedSender<(RequestId, &'static str)>>,
}
269
270fn forwarded_payload<'a>(payload: &'a roam_types::Payload<'static>) -> roam_types::Payload<'a> {
271 let roam_types::Payload::Incoming(bytes) = payload else {
272 unreachable!("proxy forwarding expects decoded incoming payload bytes")
273 };
274 roam_types::Payload::Incoming(bytes)
275}
276
277fn forwarded_request_body<'a>(body: &'a RequestBody<'static>) -> RequestBody<'a> {
278 match body {
279 RequestBody::Call(call) => RequestBody::Call(roam_types::RequestCall {
280 method_id: call.method_id,
281 channels: call.channels.clone(),
282 metadata: call.metadata.clone(),
283 args: forwarded_payload(&call.args),
284 }),
285 RequestBody::Response(response) => RequestBody::Response(RequestResponse {
286 channels: response.channels.clone(),
287 metadata: response.metadata.clone(),
288 ret: forwarded_payload(&response.ret),
289 }),
290 RequestBody::Cancel(cancel) => RequestBody::Cancel(roam_types::RequestCancel {
291 metadata: cancel.metadata.clone(),
292 }),
293 }
294}
295
296fn forwarded_channel_body<'a>(
297 body: &'a roam_types::ChannelBody<'static>,
298) -> roam_types::ChannelBody<'a> {
299 match body {
300 roam_types::ChannelBody::Item(item) => {
301 roam_types::ChannelBody::Item(roam_types::ChannelItem {
302 item: forwarded_payload(&item.item),
303 })
304 }
305 roam_types::ChannelBody::Close(close) => {
306 roam_types::ChannelBody::Close(roam_types::ChannelClose {
307 metadata: close.metadata.clone(),
308 })
309 }
310 roam_types::ChannelBody::Reset(reset) => {
311 roam_types::ChannelBody::Reset(roam_types::ChannelReset {
312 metadata: reset.metadata.clone(),
313 })
314 }
315 roam_types::ChannelBody::GrantCredit(credit) => {
316 roam_types::ChannelBody::GrantCredit(roam_types::ChannelGrantCredit {
317 additional: credit.additional,
318 })
319 }
320 }
321}
322
323impl ConnectionSender {
324 pub async fn send<'a>(&self, msg: ConnectionMessage<'a>) -> Result<(), ()> {
326 let payload = match msg {
327 ConnectionMessage::Request(r) => MessagePayload::RequestMessage(r),
328 ConnectionMessage::Channel(c) => MessagePayload::ChannelMessage(c),
329 };
330 let message = Message {
331 connection_id: self.connection_id,
332 payload,
333 };
334 self.sess_core.send(message).await.map_err(|_| ())
335 }
336
337 pub(crate) async fn send_owned(
339 &self,
340 msg: SelfRef<ConnectionMessage<'static>>,
341 ) -> Result<(), ()> {
342 let payload = match &*msg {
343 ConnectionMessage::Request(request) => MessagePayload::RequestMessage(RequestMessage {
344 id: request.id,
345 body: forwarded_request_body(&request.body),
346 }),
347 ConnectionMessage::Channel(channel) => MessagePayload::ChannelMessage(ChannelMessage {
348 id: channel.id,
349 body: forwarded_channel_body(&channel.body),
350 }),
351 };
352
353 self.sess_core
354 .send(Message {
355 connection_id: self.connection_id,
356 payload,
357 })
358 .await
359 .map_err(|_| ())
360 }
361
362 pub async fn send_response<'a>(
364 &self,
365 request_id: RequestId,
366 response: RequestResponse<'a>,
367 ) -> Result<(), ()> {
368 self.send(ConnectionMessage::Request(RequestMessage {
369 id: request_id,
370 body: RequestBody::Response(response),
371 }))
372 .await
373 }
374
375 pub fn mark_failure(&self, request_id: RequestId, reason: &'static str) {
378 let _ = self.failures.send((request_id, reason));
379 }
380}
381
/// Owned endpoint of one connection (root or virtual) of a session.
pub struct ConnectionHandle {
    /// Sends messages on this connection.
    pub(crate) sender: ConnectionSender,
    /// Inbound request and channel messages the session routed to this connection.
    pub(crate) rx: mpsc::Receiver<SelfRef<ConnectionMessage<'static>>>,
    /// Failures reported through `ConnectionSender::mark_failure`.
    pub(crate) failures_rx: mpsc::UnboundedReceiver<(RequestId, &'static str)>,
    /// Control channel of the owning session. For example, `proxy_connections` uses it to
    /// request this connection's closure once forwarding ends.
    pub(crate) control_tx: Option<mpsc::UnboundedSender<DropControlRequest>>,
    /// Parity taken from this side's local settings for the connection.
    pub parity: Parity,
}
390
391impl std::fmt::Debug for ConnectionHandle {
392 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
393 f.debug_struct("ConnectionHandle")
394 .field("connection_id", &self.sender.connection_id)
395 .finish()
396 }
397}
398
/// A message scoped to a single connection: either request traffic or channel traffic.
pub(crate) enum ConnectionMessage<'payload> {
    /// A request, response, or cancellation.
    Request(RequestMessage<'payload>),
    /// An item, close, reset, or credit grant on a channel.
    Channel(ChannelMessage<'payload>),
}
403
404impl ConnectionHandle {
405 pub fn connection_id(&self) -> ConnectionId {
407 self.sender.connection_id
408 }
409}
410
411pub async fn proxy_connections(left: ConnectionHandle, right: ConnectionHandle) {
417 let left_conn_id = left.connection_id();
418 let right_conn_id = right.connection_id();
419 let ConnectionHandle {
420 sender: left_sender,
421 rx: mut left_rx,
422 failures_rx: _left_failures_rx,
423 control_tx: left_control_tx,
424 parity: _left_parity,
425 } = left;
426 let ConnectionHandle {
427 sender: right_sender,
428 rx: mut right_rx,
429 failures_rx: _right_failures_rx,
430 control_tx: right_control_tx,
431 parity: _right_parity,
432 } = right;
433
434 loop {
435 tokio::select! {
436 msg = left_rx.recv() => {
437 let Some(msg) = msg else {
438 break;
439 };
440 if right_sender.send_owned(msg).await.is_err() {
441 break;
442 }
443 }
444 msg = right_rx.recv() => {
445 let Some(msg) = msg else {
446 break;
447 };
448 if left_sender.send_owned(msg).await.is_err() {
449 break;
450 }
451 }
452 }
453 }
454
455 if let Some(tx) = left_control_tx.as_ref() {
456 let _ = send_drop_control(tx, DropControlRequest::Close(left_conn_id));
457 }
458 if let Some(tx) = right_control_tx.as_ref() {
459 let _ = send_drop_control(tx, DropControlRequest::Close(right_conn_id));
460 }
461}
462
463#[derive(Debug)]
465pub enum SessionError {
466 Io(std::io::Error),
467 Protocol(String),
468 Rejected(Metadata<'static>),
469}
470
471impl std::fmt::Display for SessionError {
472 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
473 match self {
474 Self::Io(e) => write!(f, "io error: {e}"),
475 Self::Protocol(msg) => write!(f, "protocol error: {msg}"),
476 Self::Rejected(_) => write!(f, "connection rejected"),
477 }
478 }
479}
480
481impl std::error::Error for SessionError {}
482
483impl<C> Session<C>
484where
485 C: Conduit<Msg = MessageFamily>,
486 C::Tx: MaybeSend + MaybeSync + 'static,
487 for<'p> <C::Tx as ConduitTx>::Permit<'p>: MaybeSend,
488 C::Rx: MaybeSend,
489{
490 #[allow(clippy::too_many_arguments)]
491 fn pre_handshake(
492 tx: C::Tx,
493 rx: C::Rx,
494 on_connection: Option<Box<dyn ConnectionAcceptor>>,
495 open_rx: mpsc::Receiver<OpenRequest>,
496 close_rx: mpsc::Receiver<CloseRequest>,
497 control_tx: mpsc::UnboundedSender<DropControlRequest>,
498 control_rx: mpsc::UnboundedReceiver<DropControlRequest>,
499 keepalive: Option<SessionKeepaliveConfig>,
500 ) -> Self {
501 let sess_core = Arc::new(SessionCore { tx: Box::new(tx) });
502 Session {
503 rx,
504 role: SessionRole::Initiator, parity: Parity::Odd, sess_core,
507 conns: BTreeMap::new(),
508 root_closed_internal: false,
509 conn_ids: IdAllocator::new(Parity::Odd), on_connection,
511 open_rx,
512 close_rx,
513 control_tx,
514 control_rx,
515 keepalive,
516 }
517 }
518
    /// Runs the initiator side of the handshake and returns the root connection handle.
    ///
    /// Adopts `settings.parity` for connection ids, sends `Hello` (with
    /// [`PROTOCOL_VERSION`], `settings` and `metadata`) on the root connection, and then
    /// expects `HelloYourself` as the first message back.
    ///
    /// # Errors
    ///
    /// [`SessionError::Protocol`] in any of these cases:
    /// - `Hello` cannot be sent;
    /// - the peer answers with `ProtocolError` or with anything other than `HelloYourself`;
    /// - the conduit closes or fails to receive.
    async fn establish_as_initiator(
        &mut self,
        settings: ConnectionSettings,
        metadata: Metadata<'_>,
    ) -> Result<ConnectionHandle, SessionError> {
        use roam_types::{Hello, MessagePayload};

        self.role = SessionRole::Initiator;
        // The initiator chooses its own parity; the acceptor takes the opposite one.
        self.parity = settings.parity;
        self.conn_ids = IdAllocator::new(settings.parity);

        self.sess_core
            .send(Message {
                connection_id: ConnectionId::ROOT,
                payload: MessagePayload::Hello(Hello {
                    version: PROTOCOL_VERSION,
                    connection_settings: settings.clone(),
                    metadata,
                }),
            })
            .await
            .map_err(|_| SessionError::Protocol("failed to send Hello".into()))?;

        // The reply must be HelloYourself, or a ProtocolError explaining the refusal.
        let peer_settings = match self.rx.recv().await {
            Ok(Some(msg)) => {
                let payload = msg.map(|m| m.payload);
                match &*payload {
                    MessagePayload::HelloYourself(hy) => hy.connection_settings.clone(),
                    MessagePayload::ProtocolError(e) => {
                        return Err(SessionError::Protocol(e.description.to_owned()));
                    }
                    _ => {
                        return Err(SessionError::Protocol("expected HelloYourself".into()));
                    }
                }
            }
            Ok(None) => {
                return Err(SessionError::Protocol(
                    "peer closed during handshake".into(),
                ));
            }
            Err(e) => return Err(SessionError::Protocol(e.to_string())),
        };

        Ok(self.make_root_handle(settings, peer_settings))
    }
568
    /// Runs the acceptor side of the handshake and returns the root connection handle.
    ///
    /// Steps:
    /// 1. Wait for the peer's `Hello` and check its version against [`PROTOCOL_VERSION`].
    /// 2. Adopt the parity opposite to the peer's (other fields come from `settings`).
    /// 3. Reply with `HelloYourself` carrying those settings and `metadata`.
    ///
    /// # Errors
    ///
    /// [`SessionError::Protocol`] in any of these cases:
    /// - the first message is not `Hello` (a `ProtocolError` description is passed through);
    /// - the version differs;
    /// - the conduit closes or fails to receive;
    /// - `HelloYourself` cannot be sent.
    #[moire::instrument]
    async fn establish_as_acceptor(
        &mut self,
        settings: ConnectionSettings,
        metadata: Metadata<'_>,
    ) -> Result<ConnectionHandle, SessionError> {
        use roam_types::{HelloYourself, MessagePayload};

        self.role = SessionRole::Acceptor;

        let peer_settings = match self.rx.recv().await {
            Ok(Some(msg)) => {
                let payload = msg.map(|m| m.payload);
                match &*payload {
                    MessagePayload::Hello(h) => {
                        if h.version != PROTOCOL_VERSION {
                            return Err(SessionError::Protocol(format!(
                                "version mismatch: got {}, expected {PROTOCOL_VERSION}",
                                h.version
                            )));
                        }
                        h.connection_settings.clone()
                    }
                    MessagePayload::ProtocolError(e) => {
                        return Err(SessionError::Protocol(e.description.to_owned()));
                    }
                    _ => {
                        return Err(SessionError::Protocol("expected Hello".into()));
                    }
                }
            }
            Ok(None) => {
                return Err(SessionError::Protocol(
                    "peer closed during handshake".into(),
                ));
            }
            Err(e) => return Err(SessionError::Protocol(e.to_string())),
        };

        // Use the opposite parity so the two sides never allocate the same connection id.
        let our_settings = ConnectionSettings {
            parity: peer_settings.parity.other(),
            ..settings
        };
        self.parity = our_settings.parity;
        self.conn_ids = IdAllocator::new(our_settings.parity);

        self.sess_core
            .send(Message {
                connection_id: ConnectionId::ROOT,
                payload: MessagePayload::HelloYourself(HelloYourself {
                    connection_settings: our_settings.clone(),
                    metadata,
                }),
            })
            .await
            .map_err(|_| SessionError::Protocol("failed to send HelloYourself".into()))?;

        Ok(self.make_root_handle(our_settings, peer_settings))
    }
632
633 fn make_root_handle(
634 &mut self,
635 local_settings: ConnectionSettings,
636 peer_settings: ConnectionSettings,
637 ) -> ConnectionHandle {
638 self.make_connection_handle(ConnectionId::ROOT, local_settings, peer_settings)
639 }
640
641 fn make_connection_handle(
642 &mut self,
643 conn_id: ConnectionId,
644 local_settings: ConnectionSettings,
645 peer_settings: ConnectionSettings,
646 ) -> ConnectionHandle {
647 let label = format!("session.conn{}", conn_id.0);
648 let (conn_tx, conn_rx) = mpsc::channel::<SelfRef<ConnectionMessage<'static>>>(&label, 64);
649 let (failures_tx, failures_rx) = mpsc::unbounded_channel(format!("{label}.failures"));
650
651 let sender = ConnectionSender {
652 connection_id: conn_id,
653 sess_core: Arc::clone(&self.sess_core),
654 failures: Arc::new(failures_tx),
655 };
656
657 let parity = local_settings.parity;
658 self.conns.insert(
659 conn_id,
660 ConnectionSlot::Active(ConnectionState {
661 id: conn_id,
662 local_settings,
663 peer_settings,
664 conn_tx,
665 }),
666 );
667
668 ConnectionHandle {
669 sender,
670 rx: conn_rx,
671 failures_rx,
672 control_tx: Some(self.control_tx.clone()),
673 parity,
674 }
675 }
676
    /// Drives the session until one of these happens:
    /// - the conduit ends or fails;
    /// - a shutdown is requested;
    /// - a keepalive pong is overdue.
    ///
    /// Each iteration handles whichever event is ready first:
    /// - an inbound conduit message;
    /// - an open or close request from a [`SessionHandle`];
    /// - a control request;
    /// - the keepalive tick (only when keepalive is enabled).
    ///
    /// On exit the connection table is cleared. This drops every connection's inbound
    /// sender and every pending open's result sender.
    pub async fn run(&mut self) {
        let mut keepalive_runtime = self.make_keepalive_runtime();
        // Keepalive state is polled on a fixed 10 ms tick; the runtime decides when a
        // ping is due or a pong overdue. `Delay` avoids bursts of catch-up ticks.
        let mut keepalive_tick = keepalive_runtime.as_ref().map(|_| {
            let mut interval = tokio::time::interval(Duration::from_millis(10));
            interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
            interval
        });

        loop {
            tokio::select! {
                msg = self.rx.recv() => {
                    match msg {
                        Ok(Some(msg)) => self.handle_message(msg, &mut keepalive_runtime).await,
                        Ok(None) => {
                            warn!("session recv loop ended: conduit returned EOF");
                            break;
                        }
                        Err(error) => {
                            warn!(error = %error, "session recv loop ended: conduit recv error");
                            break;
                        }
                    }
                }
                // Once every `SessionHandle` is gone, the open/close receivers yield
                // `None`. The pattern then fails and `select!` disables that branch for
                // the iteration.
                Some(req) = self.open_rx.recv() => {
                    self.handle_open_request(req).await;
                }
                Some(req) = self.close_rx.recv() => {
                    self.handle_close_request(req).await;
                }
                Some(req) = self.control_rx.recv() => {
                    if !self.handle_drop_control_request(req).await {
                        break;
                    }
                }
                _ = async {
                    if let Some(interval) = keepalive_tick.as_mut() {
                        interval.tick().await;
                    }
                }, if keepalive_tick.is_some() => {
                    if !self.handle_keepalive_tick(&mut keepalive_runtime).await {
                        break;
                    }
                }
            }
        }

        self.conns.clear();
        debug!("session recv loop exited");
    }
731
    /// Dispatches one inbound conduit message by payload type.
    ///
    /// - `ConnectionClose`: drops the connection's slot (warning if it is the root), then
    ///   re-checks whether an internally closed root means the session should stop.
    /// - `ConnectionOpen` / `ConnectionAccept` / `ConnectionReject`: delegated to the
    ///   `handle_inbound_*` methods.
    /// - `RequestMessage` / `ChannelMessage`: forwarded to the active connection's handle.
    ///   They are ignored for unknown or pending ids. If the handle's receiver is gone,
    ///   the slot is dropped.
    /// - `Ping`: answered with a `Pong` carrying the same nonce on the same connection.
    /// - `Pong`: fed to the keepalive state, on the root connection only.
    async fn handle_message(
        &mut self,
        msg: SelfRef<Message<'static>>,
        keepalive_runtime: &mut Option<KeepaliveRuntime>,
    ) {
        let conn_id = msg.connection_id;
        // Each arm binds a `SelfRef` to the matched variant's data (see the `.map` calls
        // and the `handle_inbound_*` signatures).
        roam_types::selfref_match!(msg, payload {
            MessagePayload::ConnectionClose(_) => {
                if conn_id.is_root() {
                    warn!("received ConnectionClose for root connection");
                } else {
                    debug!(conn_id = conn_id.0, "received ConnectionClose for virtual connection");
                }
                // Dropping the slot drops its sender, ending the handle's inbound stream.
                self.conns.remove(&conn_id);
                self.maybe_request_shutdown_after_root_closed();
            }
            MessagePayload::ConnectionOpen(open) => {
                self.handle_inbound_open(conn_id, open).await;
            }
            MessagePayload::ConnectionAccept(accept) => {
                self.handle_inbound_accept(conn_id, accept);
            }
            MessagePayload::ConnectionReject(reject) => {
                self.handle_inbound_reject(conn_id, reject);
            }
            MessagePayload::RequestMessage(r) => {
                let conn_tx = match self.conns.get(&conn_id) {
                    Some(ConnectionSlot::Active(state)) => state.conn_tx.clone(),
                    _ => return,
                };
                // A send error means the handle's receiver was dropped.
                if conn_tx.send(r.map(ConnectionMessage::Request)).await.is_err() {
                    self.conns.remove(&conn_id);
                    self.maybe_request_shutdown_after_root_closed();
                }
            }
            MessagePayload::ChannelMessage(c) => {
                let conn_tx = match self.conns.get(&conn_id) {
                    Some(ConnectionSlot::Active(state)) => state.conn_tx.clone(),
                    _ => return,
                };
                if conn_tx.send(c.map(ConnectionMessage::Channel)).await.is_err() {
                    self.conns.remove(&conn_id);
                    self.maybe_request_shutdown_after_root_closed();
                }
            }
            MessagePayload::Ping(ping) => {
                let _ = self
                    .sess_core
                    .send(Message {
                        connection_id: conn_id,
                        payload: MessagePayload::Pong(roam_types::Pong { nonce: ping.nonce }),
                    })
                    .await;
            }
            MessagePayload::Pong(pong) => {
                // Keepalive pings are only sent on the root connection.
                if conn_id.is_root() {
                    self.handle_keepalive_pong(pong.nonce, keepalive_runtime);
                }
            }
        })
    }
798
799 fn make_keepalive_runtime(&self) -> Option<KeepaliveRuntime> {
800 let config = self.keepalive?;
801 if config.ping_interval.is_zero() || config.pong_timeout.is_zero() {
802 warn!("keepalive disabled due to non-positive interval/timeout");
803 return None;
804 }
805 let now = tokio::time::Instant::now();
806 Some(KeepaliveRuntime {
807 ping_interval: config.ping_interval,
808 pong_timeout: config.pong_timeout,
809 next_ping_at: now + config.ping_interval,
810 waiting_pong_nonce: None,
811 pong_deadline: now,
812 next_ping_nonce: 1,
813 })
814 }
815
816 fn handle_keepalive_pong(&self, nonce: u64, keepalive_runtime: &mut Option<KeepaliveRuntime>) {
817 let Some(runtime) = keepalive_runtime.as_mut() else {
818 return;
819 };
820 if runtime.waiting_pong_nonce != Some(nonce) {
821 return;
822 }
823 runtime.waiting_pong_nonce = None;
824 runtime.next_ping_at = tokio::time::Instant::now() + runtime.ping_interval;
825 }
826
827 async fn handle_keepalive_tick(
828 &mut self,
829 keepalive_runtime: &mut Option<KeepaliveRuntime>,
830 ) -> bool {
831 let Some(runtime) = keepalive_runtime.as_mut() else {
832 return true;
833 };
834 let now = tokio::time::Instant::now();
835
836 if let Some(waiting_nonce) = runtime.waiting_pong_nonce {
837 if now >= runtime.pong_deadline {
838 warn!(
839 nonce = waiting_nonce,
840 timeout_ms = runtime.pong_timeout.as_millis(),
841 "keepalive timeout waiting for pong"
842 );
843 return false;
844 }
845 return true;
846 }
847
848 if now < runtime.next_ping_at {
849 return true;
850 }
851
852 let nonce = runtime.next_ping_nonce;
853 if self
854 .sess_core
855 .send(Message {
856 connection_id: ConnectionId::ROOT,
857 payload: MessagePayload::Ping(roam_types::Ping { nonce }),
858 })
859 .await
860 .is_err()
861 {
862 warn!("failed to send keepalive ping");
863 return false;
864 }
865
866 runtime.waiting_pong_nonce = Some(nonce);
867 runtime.pong_deadline = now + runtime.pong_timeout;
868 runtime.next_ping_at = now + runtime.ping_interval;
869 runtime.next_ping_nonce = runtime.next_ping_nonce.wrapping_add(1);
870 true
871 }
872
873 async fn handle_inbound_open(
874 &mut self,
875 conn_id: ConnectionId,
876 open: SelfRef<ConnectionOpen<'static>>,
877 ) {
878 let peer_parity = self.parity.other();
880 if !conn_id.has_parity(peer_parity) {
881 let _ = self
883 .sess_core
884 .send(Message {
885 connection_id: conn_id,
886 payload: MessagePayload::ConnectionReject(roam_types::ConnectionReject {
887 metadata: vec![],
888 }),
889 })
890 .await;
891 return;
892 }
893
894 if self.conns.contains_key(&conn_id) {
896 let _ = self
898 .sess_core
899 .send(Message {
900 connection_id: conn_id,
901 payload: MessagePayload::ConnectionReject(roam_types::ConnectionReject {
902 metadata: vec![],
903 }),
904 })
905 .await;
906 return;
907 }
908
909 let acceptor = match &self.on_connection {
912 Some(a) => a,
913 None => {
914 let _ = self
915 .sess_core
916 .send(Message {
917 connection_id: conn_id,
918 payload: MessagePayload::ConnectionReject(roam_types::ConnectionReject {
919 metadata: vec![],
920 }),
921 })
922 .await;
923 return;
924 }
925 };
926
927 match acceptor.accept(conn_id, &open.connection_settings, &open.metadata) {
928 Ok(accepted) => {
929 let handle = self.make_connection_handle(
931 conn_id,
932 accepted.settings.clone(),
933 open.connection_settings.clone(),
934 );
935
936 let _ = self
938 .sess_core
939 .send(Message {
940 connection_id: conn_id,
941 payload: MessagePayload::ConnectionAccept(roam_types::ConnectionAccept {
942 connection_settings: accepted.settings,
943 metadata: accepted.metadata,
944 }),
945 })
946 .await;
947
948 (accepted.setup)(handle);
950 }
951 Err(reject_metadata) => {
952 let _ = self
953 .sess_core
954 .send(Message {
955 connection_id: conn_id,
956 payload: MessagePayload::ConnectionReject(roam_types::ConnectionReject {
957 metadata: reject_metadata,
958 }),
959 })
960 .await;
961 }
962 }
963 }
964
965 fn handle_inbound_accept(
966 &mut self,
967 conn_id: ConnectionId,
968 accept: SelfRef<ConnectionAccept<'static>>,
969 ) {
970 let slot = self.conns.remove(&conn_id);
971 match slot {
972 Some(ConnectionSlot::PendingOutbound(mut pending)) => {
973 let handle = self.make_connection_handle(
974 conn_id,
975 pending.local_settings.clone(),
976 accept.connection_settings.clone(),
977 );
978
979 if let Some(tx) = pending.result_tx.take() {
980 let _ = tx.send(Ok(handle));
981 }
982 }
983 Some(other) => {
984 self.conns.insert(conn_id, other);
986 }
987 None => {
988 }
990 }
991 }
992
993 fn handle_inbound_reject(
994 &mut self,
995 conn_id: ConnectionId,
996 reject: SelfRef<ConnectionReject<'static>>,
997 ) {
998 let slot = self.conns.remove(&conn_id);
999 match slot {
1000 Some(ConnectionSlot::PendingOutbound(mut pending)) => {
1001 if let Some(tx) = pending.result_tx.take() {
1002 let _ = tx.send(Err(SessionError::Rejected(reject.metadata.to_vec())));
1003 }
1004 }
1005 Some(other) => {
1006 self.conns.insert(conn_id, other);
1007 }
1008 None => {}
1009 }
1010 }
1011
1012 async fn handle_open_request(&mut self, req: OpenRequest) {
1014 let conn_id = self.conn_ids.alloc();
1015
1016 let send_result = self
1018 .sess_core
1019 .send(Message {
1020 connection_id: conn_id,
1021 payload: MessagePayload::ConnectionOpen(ConnectionOpen {
1022 connection_settings: req.settings.clone(),
1023 metadata: req.metadata,
1024 }),
1025 })
1026 .await;
1027
1028 if send_result.is_err() {
1029 let _ = req.result_tx.send(Err(SessionError::Protocol(
1030 "failed to send ConnectionOpen".into(),
1031 )));
1032 return;
1033 }
1034
1035 self.conns.insert(
1038 conn_id,
1039 ConnectionSlot::PendingOutbound(PendingOutboundData {
1040 local_settings: req.settings,
1041 result_tx: Some(req.result_tx),
1042 }),
1043 );
1044 }
1045
1046 async fn handle_close_request(&mut self, req: CloseRequest) {
1048 if req.conn_id.is_root() {
1049 let _ = req.result_tx.send(Err(SessionError::Protocol(
1050 "cannot close root connection".into(),
1051 )));
1052 return;
1053 }
1054
1055 if self.conns.remove(&req.conn_id).is_none() {
1058 let _ = req
1059 .result_tx
1060 .send(Err(SessionError::Protocol("connection not found".into())));
1061 return;
1062 }
1063
1064 let send_result = self
1066 .sess_core
1067 .send(Message {
1068 connection_id: req.conn_id,
1069 payload: MessagePayload::ConnectionClose(ConnectionClose {
1070 metadata: req.metadata,
1071 }),
1072 })
1073 .await;
1074
1075 if send_result.is_err() {
1076 let _ = req.result_tx.send(Err(SessionError::Protocol(
1077 "failed to send ConnectionClose".into(),
1078 )));
1079 return;
1080 }
1081
1082 let _ = req.result_tx.send(Ok(()));
1083 self.maybe_request_shutdown_after_root_closed();
1084 }
1085
1086 async fn handle_drop_control_request(&mut self, req: DropControlRequest) -> bool {
1087 match req {
1088 DropControlRequest::Shutdown => {
1089 debug!("session shutdown requested");
1090 false
1091 }
1092 DropControlRequest::Close(conn_id) => {
1093 if conn_id.is_root() {
1095 debug!("root callers dropped; internally closing root connection");
1097 self.root_closed_internal = true;
1098 return self.has_virtual_connections();
1100 }
1101
1102 if self.conns.remove(&conn_id).is_some() {
1103 let _ = self
1104 .sess_core
1105 .send(Message {
1106 connection_id: conn_id,
1107 payload: MessagePayload::ConnectionClose(ConnectionClose {
1108 metadata: vec![],
1109 }),
1110 })
1111 .await;
1112 }
1113
1114 !self.root_closed_internal || self.has_virtual_connections()
1115 }
1116 }
1117 }
1118
1119 fn has_virtual_connections(&self) -> bool {
1120 self.conns.keys().any(|id| !id.is_root())
1121 }
1122
1123 fn maybe_request_shutdown_after_root_closed(&self) {
1124 if self.root_closed_internal && !self.has_virtual_connections() {
1125 let _ = send_drop_control(&self.control_tx, DropControlRequest::Shutdown);
1126 }
1127 }
1128}
1129
1130pub(crate) struct SessionCore {
1131 tx: Box<dyn DynConduitTx>,
1132}
1133
1134impl SessionCore {
1135 pub(crate) async fn send<'a>(&self, msg: Message<'a>) -> Result<(), ()> {
1136 self.tx.send_msg(msg).await.map_err(|_| ())
1137 }
1138}
1139
/// Boxed future returned by `DynConduitTx::send_msg`; must be `Send` off wasm32.
#[cfg(not(target_arch = "wasm32"))]
type BoxFuture<'a, T> = Pin<Box<dyn std::future::Future<Output = T> + Send + 'a>>;
/// Boxed future returned by `DynConduitTx::send_msg`; no `Send` bound on wasm32.
#[cfg(target_arch = "wasm32")]
type BoxFuture<'a, T> = Pin<Box<dyn std::future::Future<Output = T> + 'a>>;
1144
/// Object-safe conduit sender, so `SessionCore` can hold any conduit's send half as
/// `Box<dyn DynConduitTx>`. Requires `Send + Sync` except on wasm32.
#[cfg(not(target_arch = "wasm32"))]
pub trait DynConduitTx: Send + Sync {
    /// Sends `msg` on the conduit.
    fn send_msg<'a>(&'a self, msg: Message<'a>) -> BoxFuture<'a, std::io::Result<()>>;
}
/// Object-safe conduit sender (wasm32 build, without thread-safety bounds).
#[cfg(target_arch = "wasm32")]
pub trait DynConduitTx {
    /// Sends `msg` on the conduit.
    fn send_msg<'a>(&'a self, msg: Message<'a>) -> BoxFuture<'a, std::io::Result<()>>;
}
1153
1154impl<T> DynConduitTx for T
1157where
1158 T: ConduitTx<Msg = MessageFamily> + MaybeSend + MaybeSync,
1159 for<'p> <T as ConduitTx>::Permit<'p>: MaybeSend,
1160{
1161 fn send_msg<'a>(&'a self, msg: Message<'a>) -> BoxFuture<'a, std::io::Result<()>> {
1162 Box::pin(async move {
1163 let permit = self.reserve().await?;
1164 permit
1165 .send(msg)
1166 .map_err(|e| std::io::Error::other(e.to_string()))
1167 })
1168 }
1169}
1170
// Unit tests for the session's connection-table bookkeeping, driven directly through
// the handler methods instead of a running `run` loop.
#[cfg(test)]
mod tests {
    use moire::sync::mpsc;
    use roam_types::{
        Backing, Conduit, ConnectionAccept, ConnectionReject, MessageFamily, SelfRef,
    };

    use super::*;

    type TestConduit = crate::BareConduit<MessageFamily, crate::MemoryLink>;

    /// Builds a pre-handshake session over one end of an in-memory link.
    ///
    /// The session has no acceptor and no keepalive. Its open/close request senders are
    /// dropped when this function returns.
    fn make_session() -> Session<TestConduit> {
        let (a, b) = crate::memory_link_pair(32);
        // Leak the other end so the link stays open while nothing reads from it
        // (`32` is presumably the link's buffer size — confirm in memory_link_pair).
        std::mem::forget(b);
        let conduit = crate::BareConduit::new(a);
        let (tx, rx) = conduit.split();
        let (_open_tx, open_rx) = mpsc::channel::<OpenRequest>("session.open.test", 4);
        let (_close_tx, close_rx) = mpsc::channel::<CloseRequest>("session.close.test", 4);
        let (control_tx, control_rx) = mpsc::unbounded_channel("session.control.test");
        Session::pre_handshake(
            tx, rx, None, open_rx, close_rx, control_tx, control_rx, None,
        )
    }

    /// A `ConnectionAccept` with even parity and empty metadata, backed by an empty buffer.
    fn accept_ref() -> SelfRef<ConnectionAccept<'static>> {
        SelfRef::owning(
            Backing::Boxed(Box::<[u8]>::default()),
            ConnectionAccept {
                connection_settings: ConnectionSettings {
                    parity: Parity::Even,
                    max_concurrent_requests: 64,
                },
                metadata: vec![],
            },
        )
    }

    /// A `ConnectionReject` with empty metadata, backed by an empty buffer.
    fn reject_ref() -> SelfRef<ConnectionReject<'static>> {
        SelfRef::owning(
            Backing::Boxed(Box::<[u8]>::default()),
            ConnectionReject { metadata: vec![] },
        )
    }

    // The first accept resolves the pending open; a second accept must not disturb the
    // now-active connection.
    #[tokio::test]
    async fn duplicate_connection_accept_is_ignored_after_first() {
        let mut session = make_session();
        let conn_id = ConnectionId(1);
        let (result_tx, result_rx) = moire::sync::oneshot::channel("session.test.open_result");

        session.conns.insert(
            conn_id,
            ConnectionSlot::PendingOutbound(PendingOutboundData {
                local_settings: ConnectionSettings {
                    parity: Parity::Odd,
                    max_concurrent_requests: 64,
                },
                result_tx: Some(result_tx),
            }),
        );

        session.handle_inbound_accept(conn_id, accept_ref());
        let handle = result_rx
            .await
            .expect("pending outbound result should resolve")
            .expect("accept should resolve as Ok");
        assert_eq!(handle.connection_id(), conn_id);

        session.handle_inbound_accept(conn_id, accept_ref());
        assert!(
            matches!(
                session.conns.get(&conn_id),
                Some(ConnectionSlot::Active(ConnectionState { id, .. })) if *id == conn_id
            ),
            "duplicate accept should keep existing active connection state"
        );
    }

    // The first reject fails the pending open; a second reject must not recreate a slot.
    #[tokio::test]
    async fn duplicate_connection_reject_is_ignored_after_first() {
        let mut session = make_session();
        let conn_id = ConnectionId(1);
        let (result_tx, result_rx) = moire::sync::oneshot::channel("session.test.open_result");

        session.conns.insert(
            conn_id,
            ConnectionSlot::PendingOutbound(PendingOutboundData {
                local_settings: ConnectionSettings {
                    parity: Parity::Odd,
                    max_concurrent_requests: 64,
                },
                result_tx: Some(result_tx),
            }),
        );

        session.handle_inbound_reject(conn_id, reject_ref());
        let result = result_rx
            .await
            .expect("pending outbound result should resolve");
        assert!(
            matches!(result, Err(SessionError::Rejected(_))),
            "expected rejection, got: {result:?}"
        );

        session.handle_inbound_reject(conn_id, reject_ref());
        assert!(
            !session.conns.contains_key(&conn_id),
            "duplicate reject should not recreate connection state"
        );
    }

    // Accept/reject for an id that was never opened must leave the table empty.
    #[test]
    fn out_of_order_accept_or_reject_without_pending_is_ignored() {
        let mut session = make_session();
        let conn_id = ConnectionId(99);

        session.handle_inbound_accept(conn_id, accept_ref());
        session.handle_inbound_reject(conn_id, reject_ref());

        assert!(
            session.conns.is_empty(),
            "out-of-order accept/reject should not mutate empty connection table"
        );
    }

    // Closing a still-pending outbound open succeeds and drops the opener's result
    // sender, so the opener observes a closed channel.
    #[tokio::test]
    async fn close_request_clears_pending_outbound_open() {
        let mut session = make_session();
        let (open_result_tx, open_result_rx) = moire::sync::oneshot::channel("session.open.result");
        let (close_result_tx, close_result_rx) =
            moire::sync::oneshot::channel("session.close.result");

        session.conns.insert(
            ConnectionId(1),
            ConnectionSlot::PendingOutbound(PendingOutboundData {
                local_settings: ConnectionSettings {
                    parity: Parity::Odd,
                    max_concurrent_requests: 64,
                },
                result_tx: Some(open_result_tx),
            }),
        );

        session
            .handle_close_request(CloseRequest {
                conn_id: ConnectionId(1),
                metadata: vec![],
                result_tx: close_result_tx,
            })
            .await;

        let close_result = close_result_rx
            .await
            .expect("close result should be delivered");
        assert!(
            close_result.is_ok(),
            "close should succeed for pending outbound connection"
        );

        assert!(
            open_result_rx.await.is_err(),
            "pending open result channel should be closed once the pending slot is removed"
        );
    }
}