1use std::collections::HashMap;
29use std::future::Future;
30use std::io;
31use std::pin::Pin;
32use std::sync::Arc;
33use std::time::Duration;
34
35use facet::Facet;
36
37use crate::runtime::{Mutex, Receiver, channel, sleep, spawn, spawn_with_abort};
38use crate::{
39 ChannelError, ChannelRegistry, ConnectionHandle, Context, DriverMessage, MessageTransport,
40 ResponseData, RoamError, Role, ServiceDispatcher, TransportError,
41};
42use roam_wire::{ConnectionId, Hello, Message};
43
/// Limits agreed during the handshake.
///
/// Each value is the minimum of what the two peers advertised in their `Hello`.
#[derive(Clone, Debug)]
pub struct Negotiated {
    /// Largest payload, in bytes, accepted for a request or data message.
    pub max_payload_size: u32,
    /// Initial flow-control credit given to each connection's channels.
    pub initial_credit: u32,
}
52
/// Error raised while establishing or running a connection.
#[derive(Debug)]
pub enum ConnectionError {
    /// The underlying transport failed.
    Io(std::io::Error),
    /// The peer broke a protocol rule.
    ProtocolViolation {
        /// Identifier of the violated rule (e.g. `"message.decode-error"`).
        rule_id: &'static str,
        /// Extra human-readable detail; may be empty.
        context: String,
    },
    /// Dispatching a request failed.
    Dispatch(String),
    /// The connection was closed.
    Closed,
}

impl std::fmt::Display for ConnectionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Io(err) => write!(f, "IO error: {err}"),
            Self::ProtocolViolation { rule_id, context } => {
                write!(f, "protocol violation: {rule_id}: {context}")
            }
            Self::Dispatch(message) => write!(f, "dispatch error: {message}"),
            Self::Closed => f.write_str("connection closed"),
        }
    }
}

impl std::error::Error for ConnectionError {
    /// Only the I/O variant wraps an underlying error.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        if let Self::Io(err) = self {
            Some(err)
        } else {
            None
        }
    }
}

impl From<std::io::Error> for ConnectionError {
    fn from(err: std::io::Error) -> Self {
        Self::Io(err)
    }
}
98
99#[derive(Debug, Clone)]
101pub struct HandshakeConfig {
102 pub max_payload_size: u32,
104 pub initial_channel_credit: u32,
106}
107
108impl Default for HandshakeConfig {
109 fn default() -> Self {
110 Self {
111 max_payload_size: 1024 * 1024, initial_channel_credit: 64 * 1024, }
114 }
115}
116
117impl HandshakeConfig {
118 pub fn to_hello(&self) -> Hello {
120 Hello::V2 {
121 max_payload_size: self.max_payload_size,
122 initial_channel_credit: self.initial_channel_credit,
123 }
124 }
125}
126
127pub trait MessageConnector: Send + Sync + 'static {
132 type Transport: MessageTransport;
134
135 fn connect(&self) -> impl Future<Output = io::Result<Self::Transport>> + Send;
137}
138
/// Controls how [`FramedClient`] retries failed connections and calls.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Maximum number of attempts; the operation gives up once this many have failed.
    pub max_attempts: u32,
    /// Delay after the first failed attempt.
    pub initial_backoff: Duration,
    /// Upper bound on any single delay.
    pub max_backoff: Duration,
    /// Factor applied to the delay after each additional failure.
    pub backoff_multiplier: f64,
}

impl Default for RetryPolicy {
    /// Three attempts, starting at 100 ms and doubling up to 5 s.
    fn default() -> Self {
        Self {
            max_attempts: 3,
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_secs(5),
            backoff_multiplier: 2.0,
        }
    }
}

impl RetryPolicy {
    /// Returns the delay to wait after failed attempt number `attempt` (1-based).
    ///
    /// The delay is `initial_backoff * backoff_multiplier^(attempt - 1)`, capped at
    /// `max_backoff`; `attempt == 0` is treated like `1`.
    ///
    /// This never panics:
    /// - Results too large for a [`Duration`] (including infinity) saturate to `max_backoff`.
    /// - Negative or NaN results (from an unusual multiplier) yield [`Duration::ZERO`].
    pub fn backoff_for_attempt(&self, attempt: u32) -> Duration {
        // Clamp instead of `as i32`, which would wrap huge attempt counts negative.
        let exponent = i32::try_from(attempt.saturating_sub(1)).unwrap_or(i32::MAX);
        let secs = self.initial_backoff.as_secs_f64() * self.backoff_multiplier.powi(exponent);

        if secs.is_nan() || secs <= 0.0 {
            return Duration::ZERO;
        }

        // `Duration::mul_f64` would panic on overflow; saturate at the cap instead.
        Duration::try_from_secs_f64(secs)
            .map_or(self.max_backoff, |backoff| backoff.min(self.max_backoff))
    }
}
173
174#[derive(Debug)]
176pub enum ConnectError {
177 RetriesExhausted {
179 original: io::Error,
181 attempts: u32,
183 },
184 ConnectFailed(io::Error),
186 Rpc(TransportError),
188 Rejected(String),
190}
191
192impl std::fmt::Display for ConnectError {
193 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
194 match self {
195 ConnectError::RetriesExhausted { original, attempts } => {
196 write!(
197 f,
198 "reconnection failed after {attempts} attempts: {original}"
199 )
200 }
201 ConnectError::ConnectFailed(e) => write!(f, "connection failed: {e}"),
202 ConnectError::Rpc(e) => write!(f, "RPC error: {e}"),
203 ConnectError::Rejected(reason) => write!(f, "connection rejected: {reason}"),
204 }
205 }
206}
207
208impl std::error::Error for ConnectError {
209 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
210 match self {
211 ConnectError::RetriesExhausted { original, .. } => Some(original),
212 ConnectError::ConnectFailed(e) => Some(e),
213 ConnectError::Rpc(e) => Some(e),
214 ConnectError::Rejected(_) => None,
215 }
216 }
217}
218
219impl From<TransportError> for ConnectError {
220 fn from(e: TransportError) -> Self {
221 ConnectError::Rpc(e)
222 }
223}
224
225pub async fn accept_framed<T, D>(
261 transport: T,
262 config: HandshakeConfig,
263 dispatcher: D,
264) -> Result<(ConnectionHandle, IncomingConnections, Driver<T, D>), ConnectionError>
265where
266 T: MessageTransport,
267 D: ServiceDispatcher,
268{
269 establish(transport, config.to_hello(), dispatcher, Role::Acceptor).await
270}
271
272pub fn connect_framed<C, D>(
281 connector: C,
282 config: HandshakeConfig,
283 dispatcher: D,
284) -> FramedClient<C, D>
285where
286 C: MessageConnector,
287 D: ServiceDispatcher + Clone,
288{
289 FramedClient {
290 connector: Arc::new(connector),
291 config,
292 dispatcher,
293 retry_policy: RetryPolicy::default(),
294 state: Arc::new(Mutex::new(None)),
295 }
296}
297
298pub fn connect_framed_with_policy<C, D>(
300 connector: C,
301 config: HandshakeConfig,
302 dispatcher: D,
303 retry_policy: RetryPolicy,
304) -> FramedClient<C, D>
305where
306 C: MessageConnector,
307 D: ServiceDispatcher + Clone,
308{
309 FramedClient {
310 connector: Arc::new(connector),
311 config,
312 dispatcher,
313 retry_policy,
314 state: Arc::new(Mutex::new(None)),
315 }
316}
317
318struct FramedClientState {
320 handle: ConnectionHandle,
321}
322
323pub struct FramedClient<C, D> {
328 connector: Arc<C>,
329 config: HandshakeConfig,
330 dispatcher: D,
331 retry_policy: RetryPolicy,
332 state: Arc<Mutex<Option<FramedClientState>>>,
333}
334
335impl<C, D> Clone for FramedClient<C, D>
336where
337 D: Clone,
338{
339 fn clone(&self) -> Self {
340 Self {
341 connector: self.connector.clone(),
342 config: self.config.clone(),
343 dispatcher: self.dispatcher.clone(),
344 retry_policy: self.retry_policy.clone(),
345 state: self.state.clone(),
346 }
347 }
348}
349
350impl<C, D> FramedClient<C, D>
351where
352 C: MessageConnector,
353 D: ServiceDispatcher + Clone + 'static,
354{
355 pub async fn handle(&self) -> Result<ConnectionHandle, ConnectError> {
357 self.ensure_connected().await
358 }
359
360 async fn ensure_connected(&self) -> Result<ConnectionHandle, ConnectError> {
361 let mut state = self.state.lock().await;
362
363 if let Some(ref conn) = *state {
364 return Ok(conn.handle.clone());
367 }
368
369 let conn = self.connect_internal().await?;
370 let handle = conn.handle.clone();
371 *state = Some(conn);
372 Ok(handle)
373 }
374
375 async fn connect_internal(&self) -> Result<FramedClientState, ConnectError> {
376 let transport = self
377 .connector
378 .connect()
379 .await
380 .map_err(ConnectError::ConnectFailed)?;
381
382 let (handle, _incoming, driver) = establish(
383 transport,
384 self.config.to_hello(),
385 self.dispatcher.clone(),
386 Role::Initiator,
387 )
388 .await
389 .map_err(|e| ConnectError::ConnectFailed(connection_error_to_io(e)))?;
390
391 spawn(async move {
396 let _ = driver.run().await;
397 });
398
399 Ok(FramedClientState { handle })
400 }
401
402 pub async fn call_raw(
404 &self,
405 method_id: u64,
406 payload: Vec<u8>,
407 ) -> Result<Vec<u8>, ConnectError> {
408 let mut last_error: Option<io::Error> = None;
409 let mut attempt = 0u32;
410
411 loop {
412 let handle = match self.ensure_connected().await {
413 Ok(h) => h,
414 Err(ConnectError::ConnectFailed(e)) => {
415 attempt += 1;
416 if attempt >= self.retry_policy.max_attempts {
417 return Err(ConnectError::RetriesExhausted {
418 original: last_error.unwrap_or(e),
419 attempts: attempt,
420 });
421 }
422 last_error = Some(e);
423 let backoff = self.retry_policy.backoff_for_attempt(attempt);
424 sleep(backoff).await;
425 continue;
426 }
427 Err(e) => return Err(e),
428 };
429
430 match handle.call_raw(method_id, payload.clone()).await {
431 Ok(response) => return Ok(response),
432 Err(TransportError::Encode(e)) => {
433 return Err(ConnectError::Rpc(TransportError::Encode(e)));
434 }
435 Err(TransportError::ConnectionClosed) | Err(TransportError::DriverGone) => {
436 {
437 let mut state = self.state.lock().await;
438 *state = None;
439 }
440
441 attempt += 1;
442 if attempt >= self.retry_policy.max_attempts {
443 let error = last_error.unwrap_or_else(|| {
444 io::Error::new(io::ErrorKind::ConnectionReset, "connection closed")
445 });
446 return Err(ConnectError::RetriesExhausted {
447 original: error,
448 attempts: attempt,
449 });
450 }
451
452 last_error = Some(io::Error::new(
453 io::ErrorKind::ConnectionReset,
454 "connection closed",
455 ));
456 let backoff = self.retry_policy.backoff_for_attempt(attempt);
457 sleep(backoff).await;
458 }
459 }
460 }
461 }
462}
463
464impl<C, D> crate::Caller for FramedClient<C, D>
465where
466 C: MessageConnector,
467 D: ServiceDispatcher + Clone + 'static,
468{
469 async fn call_with_metadata<T: Facet<'static> + Send>(
470 &self,
471 method_id: u64,
472 args: &mut T,
473 metadata: roam_wire::Metadata,
474 ) -> Result<ResponseData, TransportError> {
475 let mut attempt = 0u32;
476
477 loop {
478 let handle = match self.ensure_connected().await {
479 Ok(h) => h,
480 Err(ConnectError::ConnectFailed(_)) => {
481 attempt += 1;
482 if attempt >= self.retry_policy.max_attempts {
483 return Err(TransportError::ConnectionClosed);
484 }
485 let backoff = self.retry_policy.backoff_for_attempt(attempt);
486 sleep(backoff).await;
487 continue;
488 }
489 Err(ConnectError::RetriesExhausted { .. }) => {
490 return Err(TransportError::ConnectionClosed);
491 }
492 Err(ConnectError::Rpc(e)) => return Err(e),
493 Err(ConnectError::Rejected(_)) => {
494 return Err(TransportError::ConnectionClosed);
496 }
497 };
498
499 match handle
500 .call_with_metadata(method_id, args, metadata.clone())
501 .await
502 {
503 Ok(response) => return Ok(response),
504 Err(TransportError::Encode(e)) => {
505 return Err(TransportError::Encode(e));
506 }
507 Err(TransportError::ConnectionClosed) | Err(TransportError::DriverGone) => {
508 {
509 let mut state = self.state.lock().await;
510 *state = None;
511 }
512
513 attempt += 1;
514 if attempt >= self.retry_policy.max_attempts {
515 return Err(TransportError::ConnectionClosed);
516 }
517
518 let backoff = self.retry_policy.backoff_for_attempt(attempt);
519 sleep(backoff).await;
520 }
521 }
522 }
523 }
524
525 fn bind_response_streams<R: Facet<'static>>(&self, response: &mut R, channels: &[u64]) {
526 let _ = (response, channels);
533 }
534}
535
536fn connection_error_to_io(e: ConnectionError) -> io::Error {
537 match e {
538 ConnectionError::Io(io_err) => io_err,
539 ConnectionError::ProtocolViolation { rule_id, context } => io::Error::new(
540 io::ErrorKind::InvalidData,
541 format!("protocol violation: {rule_id}: {context}"),
542 ),
543 ConnectionError::Dispatch(msg) => io::Error::other(format!("dispatch error: {msg}")),
544 ConnectionError::Closed => {
545 io::Error::new(io::ErrorKind::ConnectionReset, "connection closed")
546 }
547 }
548}
549
550struct ConnectionState {
562 #[allow(dead_code)]
564 conn_id: ConnectionId,
565 handle: ConnectionHandle,
567 server_channel_registry: ChannelRegistry,
569 dispatcher: Option<Box<dyn ServiceDispatcher>>,
572 pending_responses:
574 HashMap<u64, crate::runtime::OneshotSender<Result<ResponseData, TransportError>>>,
575 in_flight_server_requests: HashMap<u64, crate::runtime::AbortHandle>,
578}
579
580impl ConnectionState {
581 fn new(
583 conn_id: ConnectionId,
584 driver_tx: crate::runtime::Sender<DriverMessage>,
585 role: Role,
586 initial_credit: u32,
587 diagnostic_state: Option<Arc<crate::diagnostic::DiagnosticState>>,
588 dispatcher: Option<Box<dyn ServiceDispatcher>>,
589 ) -> Self {
590 let handle = ConnectionHandle::new_with_diagnostics(
591 conn_id,
592 driver_tx.clone(),
593 role,
594 initial_credit,
595 diagnostic_state,
596 );
597 let server_channel_registry =
598 ChannelRegistry::new_with_credit_and_role(conn_id, initial_credit, driver_tx, role);
599 Self {
600 conn_id,
601 handle,
602 server_channel_registry,
603 dispatcher,
604 pending_responses: HashMap::new(),
605 in_flight_server_requests: HashMap::new(),
606 }
607 }
608
609 fn fail_pending_responses(&mut self) {
611 for (_, tx) in self.pending_responses.drain() {
612 let _ = tx.send(Err(TransportError::ConnectionClosed));
613 }
614 }
615
616 fn abort_in_flight_requests(&mut self) {
618 for (_, abort_handle) in self.in_flight_server_requests.drain() {
619 abort_handle.abort();
620 }
621 }
622}
623
624pub struct IncomingConnection {
630 request_id: u64,
632 pub metadata: roam_wire::Metadata,
634 response_tx: crate::runtime::OneshotSender<IncomingConnectionResponse>,
636}
637
638impl IncomingConnection {
639 pub async fn accept(
650 self,
651 metadata: roam_wire::Metadata,
652 dispatcher: Option<Box<dyn ServiceDispatcher>>,
653 ) -> Result<ConnectionHandle, TransportError> {
654 let (handle_tx, handle_rx) = crate::runtime::oneshot();
655 let _ = self.response_tx.send(IncomingConnectionResponse::Accept {
656 request_id: self.request_id,
657 metadata,
658 dispatcher,
659 handle_tx,
660 });
661 let result: Result<ConnectionHandle, _> =
662 handle_rx.await.map_err(|_| TransportError::DriverGone)?;
663 result
664 }
665
666 pub fn reject(self, reason: String, metadata: roam_wire::Metadata) {
668 let _ = self.response_tx.send(IncomingConnectionResponse::Reject {
669 request_id: self.request_id,
670 reason,
671 metadata,
672 });
673 }
674}
675
676enum IncomingConnectionResponse {
678 Accept {
679 request_id: u64,
680 metadata: roam_wire::Metadata,
681 dispatcher: Option<Box<dyn ServiceDispatcher>>,
682 handle_tx: crate::runtime::OneshotSender<Result<ConnectionHandle, TransportError>>,
683 },
684 Reject {
685 request_id: u64,
686 reason: String,
687 metadata: roam_wire::Metadata,
688 },
689}
690
691struct PendingConnect {
693 response_tx: crate::runtime::OneshotSender<Result<ConnectionHandle, ConnectError>>,
695 dispatcher: Option<Box<dyn ServiceDispatcher>>,
697}
698
699pub struct Driver<T, D> {
711 io: T,
712 dispatcher: D,
713 #[allow(dead_code)]
714 role: Role,
715 negotiated: Negotiated,
716 driver_rx: Receiver<DriverMessage>,
718 driver_tx: crate::runtime::Sender<DriverMessage>,
720 connections: HashMap<ConnectionId, ConnectionState>,
722 next_conn_id: u64,
725 pending_connects: HashMap<u64, PendingConnect>,
727 incoming_connections_tx: Option<crate::runtime::Sender<IncomingConnection>>,
730 incoming_response_rx: Option<Receiver<IncomingConnectionResponse>>,
732 incoming_response_tx: crate::runtime::Sender<IncomingConnectionResponse>,
733 diagnostic_state: Option<Arc<crate::diagnostic::DiagnosticState>>,
735}
736
737impl<T, D> Driver<T, D>
738where
739 T: MessageTransport,
740 D: ServiceDispatcher,
741{
742 pub fn root_handle(&self) -> ConnectionHandle {
747 self.connections
748 .get(&ConnectionId::ROOT)
749 .expect("root connection always exists")
750 .handle
751 .clone()
752 }
753
754 pub async fn run(mut self) -> Result<(), ConnectionError> {
756 use futures_util::FutureExt;
757
758 loop {
759 futures_util::select! {
760 msg = self.driver_rx.recv().fuse() => {
761 if let Some(msg) = msg {
762 self.handle_driver_message(msg).await?;
763 }
764 }
765
766 response = async {
768 if let Some(rx) = &mut self.incoming_response_rx {
769 rx.recv().await
770 } else {
771 std::future::pending().await
772 }
773 }.fuse() => {
774 if let Some(response) = response {
775 self.handle_incoming_response(response).await?;
776 }
777 }
778
779 result = self.io.recv().fuse() => {
780 match self.handle_recv(result).await {
781 Ok(true) => continue,
782 Ok(false) => return Ok(()),
783 Err(e) => return Err(e),
784 }
785 }
786 }
787 }
788 }
789
790 async fn handle_incoming_response(
792 &mut self,
793 response: IncomingConnectionResponse,
794 ) -> Result<(), ConnectionError> {
795 match response {
796 IncomingConnectionResponse::Accept {
797 request_id,
798 metadata,
799 dispatcher,
800 handle_tx,
801 } => {
802 let conn_id = ConnectionId::new(self.next_conn_id);
805 self.next_conn_id += 1;
806
807 let conn_state = ConnectionState::new(
809 conn_id,
810 self.driver_tx.clone(),
811 self.role,
812 self.negotiated.initial_credit,
813 self.diagnostic_state.clone(),
814 dispatcher,
815 );
816 let handle = conn_state.handle.clone();
817 self.connections.insert(conn_id, conn_state);
818
819 let msg = Message::Accept {
821 request_id,
822 conn_id,
823 metadata,
824 };
825 self.io.send(&msg).await?;
826
827 let _ = handle_tx.send(Ok(handle));
829 }
830 IncomingConnectionResponse::Reject {
831 request_id,
832 reason,
833 metadata,
834 } => {
835 let msg = Message::Reject {
836 request_id,
837 reason,
838 metadata,
839 };
840 self.io.send(&msg).await?;
841 }
842 }
843 Ok(())
844 }
845
846 async fn handle_driver_message(&mut self, msg: DriverMessage) -> Result<(), ConnectionError> {
847 match msg {
848 DriverMessage::Call {
849 conn_id,
850 request_id,
851 method_id,
852 metadata,
853 channels,
854 payload,
855 response_tx,
856 } => {
857 if let Some(conn) = self.connections.get_mut(&conn_id) {
859 conn.pending_responses.insert(request_id, response_tx);
860 } else {
861 let _ = response_tx.send(Err(TransportError::ConnectionClosed));
863 return Ok(());
864 }
865 let req = Message::Request {
866 conn_id,
867 request_id,
868 method_id,
869 metadata,
870 channels,
871 payload,
872 };
873 self.io.send(&req).await?;
874 }
875 DriverMessage::Data {
876 conn_id,
877 channel_id,
878 payload,
879 } => {
880 let wire_msg = Message::Data {
881 conn_id,
882 channel_id,
883 payload,
884 };
885 self.io.send(&wire_msg).await?;
886 }
887 DriverMessage::Close {
888 conn_id,
889 channel_id,
890 } => {
891 let wire_msg = Message::Close {
892 conn_id,
893 channel_id,
894 };
895 self.io.send(&wire_msg).await?;
896 }
897 DriverMessage::Response {
898 conn_id,
899 request_id,
900 channels,
901 payload,
902 } => {
903 let should_send = if let Some(conn) = self.connections.get_mut(&conn_id) {
907 conn.in_flight_server_requests.remove(&request_id).is_some()
908 } else {
909 false
910 };
911 if !should_send {
912 return Ok(());
913 }
914 let wire_msg = Message::Response {
915 conn_id,
916 request_id,
917 metadata: vec![],
918 channels,
919 payload,
920 };
921 self.io.send(&wire_msg).await?;
922 }
923 DriverMessage::Connect {
924 request_id,
925 metadata,
926 response_tx,
927 dispatcher,
928 } => {
929 self.pending_connects.insert(
934 request_id,
935 PendingConnect {
936 response_tx,
937 dispatcher,
938 },
939 );
940 let wire_msg = Message::Connect {
942 request_id,
943 metadata,
944 };
945 self.io.send(&wire_msg).await?;
946 }
947 }
948 Ok(())
949 }
950
951 async fn handle_recv(
952 &mut self,
953 result: std::io::Result<Option<Message>>,
954 ) -> Result<bool, ConnectionError> {
955 let msg = match result {
956 Ok(Some(m)) => m,
957 Ok(None) => return Ok(false),
958 Err(e) => {
959 let raw = self.io.last_decoded();
960 if raw.len() >= 2 && raw[0] == 0x00 && raw[1] != 0x00 {
961 return Err(self.goodbye("message.hello.unknown-version").await);
962 }
963 if !raw.is_empty() && raw[0] >= 12 {
964 return Err(self.goodbye("message.unknown-variant").await);
965 }
966 if e.kind() == std::io::ErrorKind::InvalidData {
967 return Err(self.goodbye("message.decode-error").await);
968 }
969 return Err(ConnectionError::Io(e));
970 }
971 };
972
973 match self.handle_message(msg).await {
974 Ok(()) => Ok(true),
975 Err(ConnectionError::Closed) => Ok(false),
976 Err(e) => Err(e),
977 }
978 }
979
980 async fn handle_message(&mut self, msg: Message) -> Result<(), ConnectionError> {
981 match msg {
982 Message::Hello(_) => {
983 }
985 Message::Connect {
986 request_id,
987 metadata,
988 } => {
989 if let Some(tx) = &self.incoming_connections_tx {
992 let (response_tx, response_rx) = crate::runtime::oneshot();
994 let incoming = IncomingConnection {
995 request_id,
996 metadata,
997 response_tx,
998 };
999 if tx.try_send(incoming).is_ok() {
1000 let incoming_response_tx = self.incoming_response_tx.clone();
1002 spawn(async move {
1003 if let Ok(response) = response_rx.await {
1004 let _ = incoming_response_tx.send(response).await;
1005 }
1006 });
1007 } else {
1008 let msg = Message::Reject {
1010 request_id,
1011 reason: "not listening".into(),
1012 metadata: vec![],
1013 };
1014 self.io.send(&msg).await?;
1015 }
1016 } else {
1017 let msg = Message::Reject {
1020 request_id,
1021 reason: "not listening".into(),
1022 metadata: vec![],
1023 };
1024 self.io.send(&msg).await?;
1025 }
1026 }
1027 Message::Accept {
1028 request_id,
1029 conn_id,
1030 metadata: _,
1031 } => {
1032 if let Some(pending) = self.pending_connects.remove(&request_id) {
1037 let conn_state = ConnectionState::new(
1041 conn_id,
1042 self.driver_tx.clone(),
1043 self.role,
1044 self.negotiated.initial_credit,
1045 self.diagnostic_state.clone(),
1046 pending.dispatcher,
1047 );
1048 let handle = conn_state.handle.clone();
1049 self.connections.insert(conn_id, conn_state);
1050 let _ = pending.response_tx.send(Ok(handle));
1051 }
1052 }
1054 Message::Reject {
1055 request_id,
1056 reason,
1057 metadata: _,
1058 } => {
1059 if let Some(pending) = self.pending_connects.remove(&request_id) {
1063 let _ = pending
1064 .response_tx
1065 .send(Err(ConnectError::Rejected(reason)));
1066 }
1067 }
1069 Message::Goodbye { conn_id, reason: _ } => {
1070 if conn_id.is_root() {
1072 for (_, mut conn) in self.connections.drain() {
1074 conn.fail_pending_responses();
1075 conn.abort_in_flight_requests();
1076 }
1077 return Err(ConnectionError::Closed);
1078 } else {
1079 if let Some(mut conn) = self.connections.remove(&conn_id) {
1082 conn.fail_pending_responses();
1083 conn.abort_in_flight_requests();
1084 }
1085 }
1086 }
1087 Message::Request {
1088 conn_id,
1089 request_id,
1090 method_id,
1091 metadata,
1092 channels,
1093 payload,
1094 } => {
1095 self.handle_incoming_request(
1096 conn_id, request_id, method_id, metadata, channels, payload,
1097 )
1098 .await?;
1099 }
1100 Message::Response {
1101 conn_id,
1102 request_id,
1103 channels,
1104 payload,
1105 ..
1106 } => {
1107 if let Some(conn) = self.connections.get_mut(&conn_id)
1109 && let Some(tx) = conn.pending_responses.remove(&request_id)
1110 {
1111 let _ = tx.send(Ok(ResponseData { payload, channels }));
1112 }
1113 }
1115 Message::Cancel {
1116 conn_id,
1117 request_id,
1118 } => {
1119 self.handle_cancel(conn_id, request_id).await?;
1122 }
1123 Message::Data {
1124 conn_id,
1125 channel_id,
1126 payload,
1127 } => {
1128 self.handle_data(conn_id, channel_id, payload).await?;
1129 }
1130 Message::Close {
1131 conn_id,
1132 channel_id,
1133 } => {
1134 self.handle_close(conn_id, channel_id).await?;
1135 }
1136 Message::Reset {
1137 conn_id,
1138 channel_id,
1139 } => {
1140 self.handle_reset(conn_id, channel_id)?;
1141 }
1142 Message::Credit {
1143 conn_id,
1144 channel_id,
1145 bytes,
1146 } => {
1147 self.handle_credit(conn_id, channel_id, bytes)?;
1148 }
1149 }
1150 Ok(())
1151 }
1152
1153 async fn handle_incoming_request(
1154 &mut self,
1155 conn_id: ConnectionId,
1156 request_id: u64,
1157 method_id: u64,
1158 metadata: Vec<(String, roam_wire::MetadataValue)>,
1159 channels: Vec<u64>,
1160 payload: Vec<u8>,
1161 ) -> Result<(), ConnectionError> {
1162 let conn = match self.connections.get_mut(&conn_id) {
1164 Some(c) => c,
1165 None => {
1166 return Err(self.goodbye("message.conn-id").await);
1168 }
1169 };
1170
1171 if conn.in_flight_server_requests.contains_key(&request_id) {
1173 return Err(self.goodbye("call.request-id.duplicate-detection").await);
1174 }
1175
1176 if let Err(rule_id) = roam_wire::validate_metadata(&metadata) {
1177 return Err(self.goodbye(rule_id).await);
1178 }
1179
1180 if payload.len() as u32 > self.negotiated.max_payload_size {
1181 return Err(self.goodbye("flow.call.payload-limit").await);
1182 }
1183
1184 let cx = Context::new(
1185 conn_id,
1186 roam_wire::RequestId::new(request_id),
1187 roam_wire::MethodId::new(method_id),
1188 metadata,
1189 channels,
1190 );
1191
1192 let dispatcher: &dyn ServiceDispatcher = if let Some(ref conn_dispatcher) = conn.dispatcher
1194 {
1195 conn_dispatcher.as_ref()
1196 } else {
1197 &self.dispatcher
1198 };
1199
1200 debug!(
1201 conn_id = conn_id.raw(),
1202 request_id, method_id, "dispatching incoming request"
1203 );
1204
1205 let handler_fut = dispatcher.dispatch(&cx, payload, &mut conn.server_channel_registry);
1206
1207 let abort_handle = spawn_with_abort(async move {
1209 handler_fut.await;
1210 });
1211 conn.in_flight_server_requests
1212 .insert(request_id, abort_handle);
1213 Ok(())
1214 }
1215
1216 async fn handle_cancel(
1222 &mut self,
1223 conn_id: ConnectionId,
1224 request_id: u64,
1225 ) -> Result<(), ConnectionError> {
1226 let conn = match self.connections.get_mut(&conn_id) {
1228 Some(c) => c,
1229 None => {
1230 return Ok(());
1232 }
1233 };
1234
1235 if let Some(abort_handle) = conn.in_flight_server_requests.remove(&request_id) {
1237 abort_handle.abort();
1239
1240 let wire_msg = Message::Response {
1243 conn_id,
1244 request_id,
1245 metadata: vec![],
1246 channels: vec![],
1247 payload: vec![1, 3],
1249 };
1250 self.io.send(&wire_msg).await?;
1251 }
1252 Ok(())
1255 }
1256
1257 async fn handle_data(
1258 &mut self,
1259 conn_id: ConnectionId,
1260 channel_id: u64,
1261 payload: Vec<u8>,
1262 ) -> Result<(), ConnectionError> {
1263 if channel_id == 0 {
1264 return Err(self.goodbye("channeling.id.zero-reserved").await);
1265 }
1266
1267 if payload.len() as u32 > self.negotiated.max_payload_size {
1268 return Err(self.goodbye("flow.call.payload-limit").await);
1269 }
1270
1271 let conn = match self.connections.get_mut(&conn_id) {
1273 Some(c) => c,
1274 None => return Err(self.goodbye("message.conn-id").await),
1275 };
1276
1277 let result = if conn.server_channel_registry.contains_incoming(channel_id) {
1278 conn.server_channel_registry
1279 .route_data(channel_id, payload)
1280 .await
1281 } else if conn.handle.contains_channel(channel_id) {
1282 conn.handle.route_data(channel_id, payload).await
1283 } else {
1284 Err(ChannelError::Unknown)
1285 };
1286
1287 match result {
1288 Ok(()) => Ok(()),
1289 Err(ChannelError::Unknown) => Err(self.goodbye("channeling.unknown").await),
1290 Err(ChannelError::DataAfterClose) => {
1291 Err(self.goodbye("channeling.data-after-close").await)
1292 }
1293 Err(ChannelError::CreditOverrun) => {
1294 Err(self.goodbye("flow.channel.credit-overrun").await)
1295 }
1296 }
1297 }
1298
1299 async fn handle_close(
1300 &mut self,
1301 conn_id: ConnectionId,
1302 channel_id: u64,
1303 ) -> Result<(), ConnectionError> {
1304 if channel_id == 0 {
1305 return Err(self.goodbye("channeling.id.zero-reserved").await);
1306 }
1307
1308 let conn = match self.connections.get_mut(&conn_id) {
1309 Some(c) => c,
1310 None => return Err(self.goodbye("message.conn-id").await),
1311 };
1312
1313 if conn.server_channel_registry.contains(channel_id) {
1314 conn.server_channel_registry.close(channel_id);
1315 } else if conn.handle.contains_channel(channel_id) {
1316 conn.handle.close_channel(channel_id);
1317 } else {
1318 return Err(self.goodbye("channeling.unknown").await);
1319 }
1320 Ok(())
1321 }
1322
1323 fn handle_reset(
1324 &mut self,
1325 conn_id: ConnectionId,
1326 channel_id: u64,
1327 ) -> Result<(), ConnectionError> {
1328 if let Some(conn) = self.connections.get_mut(&conn_id) {
1329 if conn.server_channel_registry.contains(channel_id) {
1330 conn.server_channel_registry.reset(channel_id);
1331 } else if conn.handle.contains_channel(channel_id) {
1332 conn.handle.reset_channel(channel_id);
1333 }
1334 }
1335 Ok(())
1336 }
1337
1338 fn handle_credit(
1339 &mut self,
1340 conn_id: ConnectionId,
1341 channel_id: u64,
1342 bytes: u32,
1343 ) -> Result<(), ConnectionError> {
1344 if let Some(conn) = self.connections.get_mut(&conn_id) {
1345 if conn.server_channel_registry.contains(channel_id) {
1346 conn.server_channel_registry
1347 .receive_credit(channel_id, bytes);
1348 } else if conn.handle.contains_channel(channel_id) {
1349 conn.handle.receive_credit(channel_id, bytes);
1350 }
1351 }
1352 Ok(())
1353 }
1354
1355 async fn goodbye(&mut self, rule_id: &'static str) -> ConnectionError {
1356 for (_, conn) in self.connections.iter_mut() {
1358 conn.fail_pending_responses();
1359 conn.abort_in_flight_requests();
1360 }
1361
1362 let _ = self
1363 .io
1364 .send(&Message::Goodbye {
1365 conn_id: ConnectionId::ROOT,
1366 reason: rule_id.into(),
1367 })
1368 .await;
1369
1370 ConnectionError::ProtocolViolation {
1371 rule_id,
1372 context: String::new(),
1373 }
1374 }
1375}
1376
1377pub async fn initiate_framed<T, D>(
1393 transport: T,
1394 config: HandshakeConfig,
1395 dispatcher: D,
1396) -> Result<(ConnectionHandle, IncomingConnections, Driver<T, D>), ConnectionError>
1397where
1398 T: MessageTransport,
1399 D: ServiceDispatcher,
1400{
1401 establish(transport, config.to_hello(), dispatcher, Role::Initiator).await
1402}
1403
1404pub type IncomingConnections = Receiver<IncomingConnection>;
1416
1417async fn establish<T, D>(
1418 mut io: T,
1419 our_hello: Hello,
1420 dispatcher: D,
1421 role: Role,
1422) -> Result<(ConnectionHandle, IncomingConnections, Driver<T, D>), ConnectionError>
1423where
1424 T: MessageTransport,
1425 D: ServiceDispatcher,
1426{
1427 io.send(&Message::Hello(our_hello.clone())).await?;
1429
1430 let peer_hello = match io.recv_timeout(Duration::from_secs(5)).await {
1432 Ok(Some(Message::Hello(Hello::V2 {
1433 max_payload_size,
1434 initial_channel_credit,
1435 }))) => Hello::V2 {
1436 max_payload_size,
1437 initial_channel_credit,
1438 },
1439 Ok(Some(Message::Hello(Hello::V1 { .. }))) => {
1440 let _ = io
1442 .send(&Message::Goodbye {
1443 conn_id: ConnectionId::ROOT,
1444 reason: "message.hello.unknown-version".into(),
1445 })
1446 .await;
1447 return Err(ConnectionError::ProtocolViolation {
1448 rule_id: "message.hello.unknown-version",
1449 context: "received Hello::V1, but V1 is no longer supported".into(),
1450 });
1451 }
1452 Ok(Some(_)) => {
1453 let _ = io
1454 .send(&Message::Goodbye {
1455 conn_id: ConnectionId::ROOT,
1456 reason: "message.hello.ordering".into(),
1457 })
1458 .await;
1459 return Err(ConnectionError::ProtocolViolation {
1460 rule_id: "message.hello.ordering",
1461 context: "received non-Hello before Hello exchange".into(),
1462 });
1463 }
1464 Ok(None) => return Err(ConnectionError::Closed),
1465 Err(e) => {
1466 let raw = io.last_decoded();
1467 let is_unknown_hello = raw.len() >= 2 && raw[0] == 0x00 && raw[1] > 0x01;
1468 let version = if is_unknown_hello { raw[1] } else { 0 };
1469
1470 if is_unknown_hello {
1471 let _ = io
1472 .send(&Message::Goodbye {
1473 conn_id: ConnectionId::ROOT,
1474 reason: "message.hello.unknown-version".into(),
1475 })
1476 .await;
1477 return Err(ConnectionError::ProtocolViolation {
1478 rule_id: "message.hello.unknown-version",
1479 context: format!("unknown Hello version: {version}"),
1480 });
1481 }
1482 return Err(ConnectionError::Io(e));
1483 }
1484 };
1485
1486 let (our_max, our_credit) = match &our_hello {
1488 Hello::V2 {
1489 max_payload_size,
1490 initial_channel_credit,
1491 } => (*max_payload_size, *initial_channel_credit),
1492 Hello::V1 { .. } => unreachable!("we always send V2"),
1493 };
1494 let (peer_max, peer_credit) = match &peer_hello {
1495 Hello::V2 {
1496 max_payload_size,
1497 initial_channel_credit,
1498 } => (*max_payload_size, *initial_channel_credit),
1499 Hello::V1 { .. } => unreachable!("V1 is rejected above"),
1500 };
1501
1502 let negotiated = Negotiated {
1503 max_payload_size: our_max.min(peer_max),
1504 initial_credit: our_credit.min(peer_credit),
1505 };
1506
1507 let (driver_tx, driver_rx) = channel(256);
1509
1510 let root_conn = ConnectionState::new(
1514 ConnectionId::ROOT,
1515 driver_tx.clone(),
1516 role,
1517 negotiated.initial_credit,
1518 None,
1519 None,
1520 );
1521 let handle = root_conn.handle.clone();
1522
1523 let mut connections = HashMap::new();
1524 connections.insert(ConnectionId::ROOT, root_conn);
1525
1526 let (incoming_connections_tx, incoming_connections_rx) = channel(64);
1529
1530 let (incoming_response_tx, incoming_response_rx) = channel(64);
1532
1533 let driver = Driver {
1534 io,
1535 dispatcher,
1536 role,
1537 negotiated: negotiated.clone(),
1538 driver_rx,
1539 driver_tx,
1540 connections,
1541 next_conn_id: 1, pending_connects: HashMap::new(),
1543 incoming_connections_tx: Some(incoming_connections_tx), incoming_response_rx: Some(incoming_response_rx),
1545 incoming_response_tx,
1546 diagnostic_state: None,
1547 };
1548
1549 Ok((handle, incoming_connections_rx, driver))
1550}
1551
1552pub struct NoDispatcher;
1560
1561impl ServiceDispatcher for NoDispatcher {
1562 fn method_ids(&self) -> Vec<u64> {
1563 vec![]
1564 }
1565
1566 fn dispatch(
1567 &self,
1568 cx: &Context,
1569 _payload: Vec<u8>,
1570 registry: &mut ChannelRegistry,
1571 ) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>> {
1572 let conn_id = cx.conn_id;
1573 let request_id = cx.request_id.raw();
1574 let driver_tx = registry.driver_tx();
1575 Box::pin(async move {
1576 let response: Result<(), RoamError<()>> = Err(RoamError::UnknownMethod);
1577 let payload = facet_postcard::to_vec(&response).unwrap_or_default();
1578 let _ = driver_tx
1579 .send(DriverMessage::Response {
1580 conn_id,
1581 request_id,
1582 channels: Vec::new(),
1583 payload,
1584 })
1585 .await;
1586 })
1587 }
1588}
1589
1590impl Clone for NoDispatcher {
1591 fn clone(&self) -> Self {
1592 NoDispatcher
1593 }
1594}
1595
#[cfg(test)]
mod tests {
    use super::*;

    /// The default policy doubles from 100 ms and caps at 5 s.
    #[test]
    fn test_backoff_calculation() {
        let policy = RetryPolicy::default();
        let expectations = [
            (1, Duration::from_millis(100)),
            (2, Duration::from_millis(200)),
            (3, Duration::from_millis(400)),
            (10, Duration::from_secs(5)),
        ];
        for (attempt, expected) in expectations {
            assert_eq!(
                policy.backoff_for_attempt(attempt),
                expected,
                "attempt {attempt}"
            );
        }
    }
}
1608}