1use std::collections::{HashMap, HashSet};
5use std::net::SocketAddr;
6use std::pin::Pin;
7use std::sync::Arc;
8
9use crate::api::DataPlaneServiceServer;
10use display_error_chain::ErrorChainExt;
11use parking_lot::RwLock;
12use slim_config::client::ClientConfig;
13use slim_config::client::TransportChannel;
14use slim_config::component::configuration::Configuration;
15use slim_config::server::ServerConfig;
16use slim_config::server_handler::ServerHandler;
17use slim_config::websocket::server as websocket_server;
18use slim_config::websocket::server::AcceptedWebSocketConnection;
19use tokio::sync::mpsc::{self, Sender};
20use tokio::task::JoinHandle;
21use tokio_stream::wrappers::ReceiverStream;
22use tokio_stream::{Stream, StreamExt};
23use tokio_util::sync::CancellationToken;
24
25use tonic::{Request, Response, Status};
26use tracing::{Instrument, debug, error, info, warn};
27
28#[cfg(feature = "otel_tracing")]
29use crate::otel_tracing;
30
31use crate::api::ProtoMessage;
32use crate::api::ProtoPublishType as PublishType;
33use crate::api::ProtoSubscribeType as SubscribeType;
34use crate::api::ProtoSubscriptionAckType as SubscriptionAckType;
35use crate::api::ProtoUnsubscribeType as UnsubscribeType;
36use crate::api::proto::dataplane::v1::Message;
37
38use crate::api::proto::dataplane::v1::data_plane_service_client::DataPlaneServiceClient;
39use crate::api::proto::dataplane::v1::data_plane_service_server::DataPlaneService;
40use crate::api::{
41 LinkNegotiationPayload, ProtoLink, ProtoLinkMessageType as LinkType, ProtoLinkType, ProtoName,
42};
43use crate::connection::{Channel, Connection};
44use crate::errors::{DataPathError, MessageContext};
45use crate::forwarder::Forwarder;
46use crate::link_ecdh::{self, X25519_PUBLIC_KEY_LEN};
47use crate::messages::utils::SlimHeaderFlags;
48use crate::recovery::RecoveryTable;
49use crate::tables::connection_table::ConnectionTable;
50use crate::tables::remote_subscription_table::SubscriptionInfo;
51use crate::tables::subscription_table::SubscriptionTableImpl;
52use crate::tables::{ConnType, MatchFilter};
53use crate::websocket;
54use semver;
55
56fn local_version() -> &'static str {
57 slim_version::version()
58}
59
// Async mutex compiled into test builds only.
// NOTE(review): presumably serializes tests that mutate process-wide environment
// variables (e.g. `SLIM_TEST_TAMPER_DESTINATION`) — confirm against the test module.
#[cfg(test)]
static ENV_LOCK: tokio::sync::Mutex<()> = tokio::sync::Mutex::const_new(());
63
64#[derive(Debug)]
65struct MessageProcessorInternal {
66 forwarder: Forwarder<Connection>,
68
69 drain_signal: parking_lot::RwLock<Option<drain::Signal>>,
71
72 drain_watch: parking_lot::RwLock<Option<drain::Watch>>,
74
75 tx_control_plane: RwLock<Option<Sender<Result<Message, Status>>>>,
77
78 recovery_table: RecoveryTable,
80
81 sub_ack_manager: crate::subscription_ack::RemoteSubAckManager,
83
84 service_id: String,
86
87 server_require_header_mac: bool,
89
90 link_hmac_timeout: std::time::Duration,
92
93 link_hmac_poll_interval: std::time::Duration,
95}
96
/// Cheaply cloneable handle to the data-plane message processor.
///
/// All clones share the same internal state through an `Arc`.
#[derive(Debug, Clone)]
pub struct MessageProcessor {
    /// Shared processor state.
    internal: Arc<MessageProcessorInternal>,
}
101
102impl Default for MessageProcessor {
103 fn default() -> Self {
104 Self::new_with_service_id(String::new())
105 }
106}
107
108impl MessageProcessor {
109 pub fn new_with_service_id(service_id: String) -> Self {
110 Self::new_with_options(service_id, None)
111 }
112
113 pub fn new_with_options(service_id: String, recovery_ttl: Option<std::time::Duration>) -> Self {
114 Self::new_internal(
115 service_id,
116 recovery_ttl,
117 false,
118 std::time::Duration::from_secs(5),
119 std::time::Duration::from_millis(5),
120 )
121 }
122
123 pub fn new_with_server_config(
125 service_id: String,
126 server_config: &ServerConfig,
127 recovery_ttl: Option<std::time::Duration>,
128 ) -> Self {
129 Self::new_internal(
130 service_id,
131 recovery_ttl,
132 server_config.require_header_mac,
133 std::time::Duration::from_secs(server_config.link_hmac_timeout_secs),
134 std::time::Duration::from_millis(server_config.link_hmac_poll_interval_ms),
135 )
136 }
137
138 fn new_internal(
139 service_id: String,
140 recovery_ttl: Option<std::time::Duration>,
141 server_require_header_mac: bool,
142 link_hmac_timeout: std::time::Duration,
143 link_hmac_poll_interval: std::time::Duration,
144 ) -> Self {
145 let (signal, watch) = drain::channel();
146 let recovery_table = match recovery_ttl {
147 Some(ttl) => RecoveryTable::new(ttl),
148 None => RecoveryTable::default(),
149 };
150 let internal = MessageProcessorInternal {
151 forwarder: Forwarder::new(),
152 drain_signal: RwLock::new(Some(signal)),
153 drain_watch: RwLock::new(Some(watch)),
154 tx_control_plane: RwLock::new(None),
155 recovery_table,
156 sub_ack_manager: crate::subscription_ack::RemoteSubAckManager::new(),
157 service_id,
158 server_require_header_mac,
159 link_hmac_timeout,
160 link_hmac_poll_interval,
161 };
162 Self {
163 internal: Arc::new(internal),
164 }
165 }
166
167 pub fn new() -> Self {
168 Self::default()
169 }
170
171 pub async fn run_server(
176 &self,
177 config: &ServerConfig,
178 ) -> Result<CancellationToken, DataPathError> {
179 debug!(%config, "starting dataplane server");
180
181 if config.require_header_mac != self.internal.server_require_header_mac {
182 warn!(
183 configured = config.require_header_mac,
184 processor = self.internal.server_require_header_mac,
185 "server require_header_mac differs from MessageProcessor; inbound connections use the processor value set at construction (prefer MessageProcessor::new_with_server_config)",
186 );
187 }
188
189 let watch = self.get_drain_watch()?;
190 config
191 .run_server(watch, Arc::new(self.clone()))
192 .await
193 .map_err(Into::into)
194 }
195
196 async fn handle_websocket_accepted(&self, accepted: AcceptedWebSocketConnection) {
197 let cancellation_token = CancellationToken::new();
198 let streams =
199 websocket::spawn_transport_tasks(accepted.websocket, cancellation_token.clone());
200
201 let connection = Connection::new(ConnType::Remote, Channel::Client(streams.outbound))
202 .with_remote_addr(accepted.remote_addr)
203 .with_local_addr(accepted.local_addr)
204 .with_require_header_mac(self.internal.server_require_header_mac)
205 .with_cancellation_token(Some(cancellation_token.clone()));
206
207 debug!(
208 remote = ?connection.remote_addr(),
209 local = ?connection.local_addr(),
210 "new websocket connection received from remote",
211 );
212 info!(telemetry = true, counter.num_active_connections = 1);
213
214 let conn_index = match self.forwarder().on_connection_established(connection, None) {
215 Some(index) => index,
216 None => {
217 error!("failed to add websocket connection to table");
218 cancellation_token.cancel();
219 return;
220 }
221 };
222
223 if let Err(err) = self.process_stream(
224 streams.inbound,
225 conn_index,
226 None,
227 cancellation_token,
228 ConnType::Remote,
229 false,
230 ) {
231 error!(error = %err.chain(), "error starting websocket processing stream");
232 }
233 }
234
235 pub fn signal_drain(&self) {
241 self.internal.drain_signal.write().take();
242 self.internal.drain_watch.write().take();
243 }
244
245 pub async fn shutdown(&self) -> Result<(), DataPathError> {
246 let signal = self
248 .internal
249 .drain_signal
250 .write()
251 .take()
252 .ok_or(DataPathError::AlreadyClosedError)?;
253
254 self.internal.drain_watch.write().take();
256
257 signal.drain().await;
259
260 Ok(())
261 }
262
263 fn set_tx_control_plane(&self, tx: Sender<Result<Message, Status>>) {
264 let mut tx_guard = self.internal.tx_control_plane.write();
265 *tx_guard = Some(tx);
266 }
267
268 fn get_tx_control_plane(&self) -> Option<Sender<Result<Message, Status>>> {
269 let tx_guard = self.internal.tx_control_plane.read();
270 tx_guard.clone()
271 }
272
273 fn forwarder(&self) -> &Forwarder<Connection> {
274 &self.internal.forwarder
275 }
276
    /// Verifies the SLIM header MAC of `message` received on connection `conn_index`.
    ///
    /// Only connections whose category is `ConnType::Remote` are checked; any other
    /// category returns `Ok(())` immediately.
    ///
    /// # Errors
    /// - `ConnectionNotFound` when `conn_index` is not in the connection table.
    /// - `UnknownMsgType` when the message exposes no SLIM header.
    /// - `NegotiationError` when `enforce_strict_verification` is set and either a
    ///   subscribe/unsubscribe carries no MAC or no link HMAC session is installed.
    /// - `HeaderMacAwaitingLinkNegotiation` when a publish carries a MAC but no session is
    ///   installed yet, or when the connection has no valid UUID v4 link id.
    /// - `HeaderIntegrity` when the MAC check itself fails.
    pub(crate) fn verify_remote_header_mac(
        &self,
        conn_index: u64,
        message: &Message,
        enforce_strict_verification: bool,
    ) -> Result<(), DataPathError> {
        let conn = self
            .forwarder()
            .get_connection(conn_index)
            .ok_or(DataPathError::ConnectionNotFound(conn_index))?;
        // Non-remote connections are never MAC-checked.
        if !matches!(conn.category(), ConnType::Remote) {
            return Ok(());
        }
        let header = message
            .try_get_slim_header()
            .ok_or(DataPathError::UnknownMsgType)?;

        // A MAC counts as present only when the field is set and non-empty.
        let has_wire_mac = header.header_mac.as_ref().is_some_and(|m| !m.is_empty());

        // Subscribe/unsubscribe without a MAC: rejected in strict mode, otherwise accepted
        // as-is. This check runs before the session lookup, so it applies even when a
        // link HMAC session is installed.
        if (message.is_subscribe() || message.is_unsubscribe()) && !has_wire_mac {
            if enforce_strict_verification {
                return Err(DataPathError::NegotiationError(
                    "empty HMAC is not allowed in strict verification mode".to_string(),
                ));
            }
            return Ok(());
        }

        // No HMAC session installed on this connection yet.
        let Some(mac) = conn.header_hmac() else {
            if enforce_strict_verification {
                return Err(DataPathError::NegotiationError(
                    "strict header MAC required but link HMAC session is not installed".to_string(),
                ));
            }
            // A signed publish cannot be verified before negotiation completes.
            if message.is_publish() && has_wire_mac {
                return Err(DataPathError::HeaderMacAwaitingLinkNegotiation(conn_index));
            }
            return Ok(());
        };
        // The link id is an input to verification and must be a valid UUID v4.
        let link_id = conn
            .link_id()
            .filter(|id| slim_config::grpc::client::is_valid_uuid_v4(id))
            .ok_or(DataPathError::HeaderMacAwaitingLinkNegotiation(conn_index))?;
        mac.verify_slim_header(header, &link_id)
            .map_err(DataPathError::HeaderIntegrity)
    }
331
332 pub(crate) fn remove_sub_ack(&self, subscription_id: u64) {
333 self.internal.sub_ack_manager.remove(subscription_id);
334 }
335
336 fn get_drain_watch(&self) -> Result<drain::Watch, DataPathError> {
337 self.internal
338 .drain_watch
339 .read()
340 .clone()
341 .ok_or(DataPathError::AlreadyClosedError)
342 }
343
344 async fn restore_remote_subscriptions(
354 &self,
355 remote_subs: &HashSet<SubscriptionInfo>,
356 conn_index: u64,
357 restore_tracking: bool,
358 ) {
359 for r in remote_subs {
360 let sub_msg = Message::builder()
361 .source(r.source().clone())
362 .destination(r.name().clone())
363 .identity(r.source_identity())
364 .build_subscribe()
365 .unwrap();
366 if let Err(e) = self.send_msg(sub_msg, conn_index).await {
367 error!(
368 error = %e.chain(), %conn_index,
369 "error restoring subscription on remote node",
370 );
371 } else if restore_tracking {
372 self.forwarder().on_forwarded_subscription(
373 r.source().clone(),
374 r.name().clone(),
375 r.source_identity().clone(),
376 conn_index,
377 true,
378 r.subscription_id(),
379 );
380 }
381 }
382 }
383
    /// Opens an outgoing connection described by `client_config` and performs the client
    /// side of link negotiation.
    ///
    /// Steps: validate the configuration, establish the transport channel (aborting if the
    /// processor starts draining first), generate an ephemeral X25519 key pair, register
    /// the connection and start its stream processing, send the link negotiation message,
    /// and finally wait for the link HMAC session when `client_config.require_header_mac`
    /// is set.
    ///
    /// `existing_conn_index` is passed through to the forwarder when registering the
    /// connection. Returns the stream-processing task handle and the connection index.
    ///
    /// # Errors
    /// Configuration/transport errors, `ShuttingDownError` when draining is signalled
    /// while connecting, `LinkKeyGeneration`, registration errors, and errors from
    /// `await_link_hmac_ready`.
    ///
    /// # Panics
    /// Panics if the websocket channel has already been consumed.
    ///
    // NOTE(review): when `await_link_hmac_ready` fails, this returns without cancelling the
    // token of the connection registered just above — confirm the caller or the stream
    // task tears that connection down.
    async fn try_to_connect(
        &self,
        client_config: ClientConfig,
        local: Option<SocketAddr>,
        remote: Option<SocketAddr>,
        existing_conn_index: Option<u64>,
    ) -> Result<(JoinHandle<()>, u64), DataPathError> {
        client_config.validate()?;

        // Race channel establishment against the drain signal.
        let mut watch = std::pin::pin!(self.get_drain_watch()?.signaled());
        let channel = tokio::select! {
            _ = &mut watch => {
                return Err(DataPathError::ShuttingDownError);
            }
            res = client_config.to_channel() => {
                res?
            }
        };

        let cancellation_token = CancellationToken::new();
        let link_id = client_config.link_id.clone();

        match channel {
            TransportChannel::Grpc(grpc_channel) => {
                let mut client = DataPlaneServiceClient::new(grpc_channel);
                // Outbound messages are queued on `tx` and streamed to the server from `rx`.
                let (tx, rx) = mpsc::channel(128);
                let stream = client
                    .open_channel(Request::new(ReceiverStream::new(rx)))
                    .await?;

                // Ephemeral key pair: the private half is stored on the connection, the
                // public half is sent in the negotiation message.
                let (ecdh_sk, ecdh_pk) = link_ecdh::generate_x25519_ephemeral()
                    .map_err(|_| DataPathError::LinkKeyGeneration)?;

                let (handle, conn_index) = self.register_remote_connection(
                    stream.into_inner(),
                    Channel::Client(tx),
                    &client_config,
                    local,
                    remote,
                    existing_conn_index,
                    cancellation_token,
                    Some(link_id.clone()),
                    Some(ecdh_sk),
                )?;

                self.send_client_link_negotiation(&link_id, conn_index, Some(ecdh_pk))
                    .await;
                self.await_link_hmac_ready(conn_index, client_config.require_header_mac)
                    .await?;

                Ok((handle, conn_index))
            }
            TransportChannel::Websocket(ws_channel) => {
                let websocket = ws_channel
                    .take_websocket()
                    .expect("websocket channel already consumed");
                let streams =
                    websocket::spawn_transport_tasks(websocket, cancellation_token.clone());

                let (ecdh_sk, ecdh_pk) = link_ecdh::generate_x25519_ephemeral()
                    .map_err(|_| DataPathError::LinkKeyGeneration)?;

                // Caller-supplied addresses win; otherwise use the ones the websocket reports.
                let (handle, conn_index) = self.register_remote_connection(
                    streams.inbound,
                    Channel::Client(streams.outbound),
                    &client_config,
                    local.or(ws_channel.local_addr()),
                    remote.or(ws_channel.remote_addr()),
                    existing_conn_index,
                    cancellation_token,
                    Some(link_id.clone()),
                    Some(ecdh_sk),
                )?;

                self.send_client_link_negotiation(&link_id, conn_index, Some(ecdh_pk))
                    .await;
                self.await_link_hmac_ready(conn_index, client_config.require_header_mac)
                    .await?;

                Ok((handle, conn_index))
            }
        }
    }
467
468 async fn send_client_link_negotiation(
470 &self,
471 link_id: &str,
472 conn_index: u64,
473 ecdh_public_key: Option<Vec<u8>>,
474 ) {
475 let negotiation_msg = ProtoMessage::builder().build_link_negotiation(
476 link_id,
477 local_version(),
478 false,
479 ecdh_public_key,
480 );
481 if let Err(e) = self.send_msg(negotiation_msg, conn_index).await {
482 debug!(
483 %conn_index,
484 error = %e.chain(),
485 "failed to send link negotiation (remote may be an older SLIM instance)",
486 );
487 }
488 }
489
490 async fn await_link_hmac_ready(
492 &self,
493 conn_index: u64,
494 require_header_mac: bool,
495 ) -> Result<(), DataPathError> {
496 if !require_header_mac {
497 return Ok(());
498 }
499
500 let timeout = self.internal.link_hmac_timeout;
501 let start = tokio::time::Instant::now();
502 while start.elapsed() < timeout {
503 match self.forwarder().get_connection(conn_index) {
504 Some(conn) if conn.header_hmac().is_some() => return Ok(()),
505 Some(_) => {
506 tokio::time::sleep(self.internal.link_hmac_poll_interval).await;
507 }
508 None => return Err(DataPathError::ConnectionNotFound(conn_index)),
509 }
510 }
511
512 Err(DataPathError::NegotiationError(
513 "timed out waiting for link HMAC session after negotiation".to_string(),
514 ))
515 }
516
517 #[allow(clippy::too_many_arguments)]
523 fn register_remote_connection<S>(
524 &self,
525 inbound: S,
526 outbound: Channel,
527 client_config: &ClientConfig,
528 local: Option<SocketAddr>,
529 remote: Option<SocketAddr>,
530 existing_conn_index: Option<u64>,
531 cancellation_token: CancellationToken,
532 link_id: Option<String>,
533 outbound_ecdh_private: Option<aws_lc_rs::agreement::EphemeralPrivateKey>,
534 ) -> Result<(JoinHandle<()>, u64), DataPathError>
535 where
536 S: Stream<Item = Result<Message, Status>> + Unpin + Send + 'static,
537 {
538 let mut connection = Connection::new(ConnType::Remote, outbound)
539 .with_local_addr(local)
540 .with_remote_addr(remote)
541 .with_config_data(Some(client_config.clone()))
542 .with_require_header_mac(client_config.require_header_mac)
543 .with_cancellation_token(Some(cancellation_token.clone()));
544 if let Some(link_id) = link_id {
545 connection = connection.with_link_id(link_id);
546 }
547
548 if let Some(ecdh_sk) = outbound_ecdh_private {
549 connection.set_outbound_ecdh_private(ecdh_sk);
550 }
551
552 debug!(
553 remote = ?connection.remote_addr(),
554 local = ?connection.local_addr(),
555 "new connection initiated locally",
556 );
557
558 let conn_index = self
559 .forwarder()
560 .on_connection_established(connection, existing_conn_index)
561 .ok_or(DataPathError::ConnectionTableAddError)?;
562
563 debug!(%conn_index, is_local = false, "new connection index");
564
565 let handle = self.process_stream(
566 inbound,
567 conn_index,
568 Some(client_config.clone()),
569 cancellation_token,
570 ConnType::Remote,
571 false,
572 )?;
573
574 Ok((handle, conn_index))
575 }
576
577 pub async fn connect(
578 &self,
579 client_config: ClientConfig,
580 local: Option<SocketAddr>,
581 remote: Option<SocketAddr>,
582 ) -> Result<(JoinHandle<()>, u64), DataPathError> {
583 self.try_to_connect(client_config, local, remote, None)
584 .await
585 }
586
587 pub fn disconnect(&self, conn: u64) -> Result<ClientConfig, DataPathError> {
588 let connection = match self.forwarder().get_connection(conn) {
589 Some(c) => c,
590 None => {
591 error!(%conn, "error handling disconnect: connection unknown");
592 return Err(DataPathError::DisconnectionError(conn));
593 }
594 };
595
596 let token = match connection.cancellation_token() {
597 Some(t) => t,
598 None => {
599 error!(%conn, "error handling disconnect: missing cancellation token");
600 return Err(DataPathError::DisconnectionError(conn));
601 }
602 };
603
604 token.cancel();
606
607 connection
608 .config_data()
609 .cloned()
610 .ok_or(DataPathError::DisconnectionError(conn))
611 }
612
613 #[tracing::instrument(skip_all, fields(service_id = %self.internal.service_id))]
614 pub fn register_local_connection(
615 &self,
616 from_control_plane: bool,
617 ) -> Result<
618 (
619 u64,
620 tokio::sync::mpsc::Sender<Result<Message, Status>>,
621 tokio::sync::mpsc::Receiver<Result<Message, Status>>,
622 ),
623 DataPathError,
624 > {
625 let (tx1, rx1) = mpsc::channel(512);
627
628 debug!("establishing new local app connection");
629
630 let (tx2, rx2) = mpsc::channel(512);
632
633 if from_control_plane && self.get_tx_control_plane().is_none() {
636 self.set_tx_control_plane(tx2.clone());
637 }
638
639 let cancellation_token = CancellationToken::new();
641 let connection = Connection::new(ConnType::Local, Channel::Server(tx2))
642 .with_cancellation_token(Some(cancellation_token.clone()));
643
644 let conn_id = self
646 .forwarder()
647 .on_connection_established(connection, None)
648 .unwrap();
649
650 debug!(%conn_id, "local connection established");
651 info!(telemetry = true, counter.num_active_connections = 1);
652
653 self.process_stream(
655 ReceiverStream::new(rx1),
656 conn_id,
657 None,
658 cancellation_token,
659 ConnType::Local,
660 from_control_plane,
661 )?;
662
663 Ok((conn_id, tx1, rx2))
665 }
666
    /// Sends `msg` on connection `out_conn`.
    ///
    /// With the `otel_tracing` feature enabled the message is first prepared for tracing
    /// (operation name `"send_message"`, targeting the outgoing connection); the actual
    /// delivery is done by `send_msg_raw`.
    ///
    /// # Errors
    /// Propagates the errors of `send_msg_raw`.
    pub async fn send_msg(
        &self,
        // `msg` is only mutable when tracing needs to modify it.
        #[cfg(feature = "otel_tracing")] mut msg: Message,
        #[cfg(not(feature = "otel_tracing"))] msg: Message,
        out_conn: u64,
    ) -> Result<(), DataPathError> {
        #[cfg(feature = "otel_tracing")]
        otel_tracing::prepare_outbound_msg(
            &mut msg,
            "send_message",
            &self.internal.service_id,
            otel_tracing::SpanTarget::Connection(out_conn),
        );
        self.send_msg_raw(msg, out_conn).await
    }
682
    /// Delivers `msg` on connection `out_conn`, signing the SLIM header when the
    /// connection has a link HMAC session.
    ///
    /// Link messages and subscription acks skip all header handling and are sent as-is.
    /// For every other message:
    /// 1. the SLIM header is cleared via `clear_slim_header`;
    /// 2. on a remote connection that requires a header MAC but has no session, sending
    ///    is refused;
    /// 3. on a remote connection with a session, the header is signed using the link id;
    /// 4. on a local server-channel connection, the header MAC is removed.
    ///
    /// # Errors
    /// - `ConnectionNotFound` when `out_conn` is unknown.
    /// - `NegotiationError` for case 2 above.
    /// - `HeaderMacAwaitingLinkNegotiation` when a session exists but no valid UUID v4
    ///   link id is available.
    /// - `HeaderIntegrity` when signing fails.
    /// - `MessageProcessingError` (wrapping `ConnectionNotFound`) when the channel send
    ///   fails; the unsent message is attached.
    async fn send_msg_raw(&self, mut msg: Message, out_conn: u64) -> Result<(), DataPathError> {
        let connection = self.forwarder().get_connection(out_conn);
        match connection {
            Some(conn) => {
                // NOTE(review): presumably resets per-hop header fields before the header
                // is signed again — confirm against `clear_slim_header`.
                if !msg.is_link() && !msg.is_subscription_ack() {
                    msg.clear_slim_header();
                }

                // Strict connections must not emit unsigned data messages.
                if !msg.is_link()
                    && !msg.is_subscription_ack()
                    && matches!(conn.category(), ConnType::Remote)
                    && conn.require_header_mac()
                    && conn.header_hmac().is_none()
                {
                    return Err(DataPathError::NegotiationError(
                        "strict header MAC required but link HMAC session is not installed"
                            .to_string(),
                    ));
                }

                // Sign the header whenever a session is installed on a remote connection.
                if !msg.is_link()
                    && !msg.is_subscription_ack()
                    && matches!(conn.category(), ConnType::Remote)
                    && let Some(mac) = conn.header_hmac()
                {
                    // Prefer the connection's link id, fall back to the one in its client
                    // configuration; either must be a valid UUID v4.
                    let link_id = conn
                        .link_id()
                        .or_else(|| conn.config_data().map(|c| c.link_id.clone()))
                        .filter(|id| slim_config::grpc::client::is_valid_uuid_v4(id));
                    if let Some(ref id) = link_id {
                        let header = msg.get_slim_header_mut();

                        mac.sign_slim_header(header, id.as_str())
                            .map_err(DataPathError::HeaderIntegrity)?;

                        // Debug builds only: when `SLIM_TEST_TAMPER_DESTINATION` is set, the
                        // destination is altered after signing.
                        #[cfg(debug_assertions)]
                        if std::env::var("SLIM_TEST_TAMPER_DESTINATION").is_ok()
                            && let Some(dest) = header.destination.as_mut()
                            && let Some(sn) = dest.str_name.as_mut()
                        {
                            sn.str_component_2.push_str("-integrity-test-tamper");
                        }
                    } else {
                        return Err(DataPathError::HeaderMacAwaitingLinkNegotiation(out_conn));
                    }
                }

                // Messages handed to local applications never carry a header MAC.
                if !msg.is_link()
                    && !msg.is_subscription_ack()
                    && matches!(conn.channel(), Channel::Server(_))
                    && matches!(conn.category(), ConnType::Local)
                {
                    msg.get_slim_header_mut().header_mac = None;
                }

                match conn.channel() {
                    // Server channels carry `Result<Message, Status>` items.
                    Channel::Server(s) => {
                        s.send(Ok(msg))
                            .await
                            .map_err(|e| DataPathError::MessageProcessingError {
                                source: Box::new(DataPathError::ConnectionNotFound(out_conn)),
                                msg: Box::new(e.0.unwrap_or_default()),
                            })
                    }
                    // Client channels carry bare messages.
                    Channel::Client(s) => {
                        s.send(msg)
                            .await
                            .map_err(|e| DataPathError::MessageProcessingError {
                                source: Box::new(DataPathError::ConnectionNotFound(out_conn)),
                                msg: Box::new(e.0),
                            })
                    }
                }
            }
            None => Err(DataPathError::ConnectionNotFound(out_conn)),
        }
    }
764
765 async fn match_and_forward_msg(
766 &self,
767 #[cfg(feature = "otel_tracing")] mut msg: Message,
768 #[cfg(not(feature = "otel_tracing"))] msg: Message,
769 in_connection: u64,
770 fanout: u32,
771 filter: MatchFilter,
772 ) -> Result<(), DataPathError> {
773 let header = msg.get_slim_header();
774 debug!(name = %header.get_dst(), %fanout, "match and forward message");
775
776 if let Some(val) = msg.get_forward_to() {
779 debug!(conn = %val, "forwarding message to connection");
780 return self.send_msg(msg, val).await;
781 }
782
783 let encoded = header.get_encoded_dst();
784
785 match self
786 .forwarder()
787 .on_publish_msg_match(encoded, in_connection, fanout, filter)
788 {
789 Ok(out_vec) => {
790 let len = out_vec.len();
791 if len == 1 {
793 return self.send_msg(msg, out_vec[0]).await;
794 }
795
796 #[cfg(feature = "otel_tracing")]
797 otel_tracing::prepare_fanout_msg(
798 &mut msg,
799 "send_message",
800 &self.internal.service_id,
801 len as u32,
802 );
803
804 let mut i = 0usize;
805 while i < len - 1 {
806 self.send_msg_raw(msg.clone(), out_vec[i]).await?;
807 i += 1;
808 }
809 self.send_msg_raw(msg, out_vec[i]).await?;
810 Ok(())
811 }
812 Err(e) => {
813 debug!(name = %header.get_dst(), %fanout, error = %e, "no match for publish destination");
814 Err(DataPathError::MessageProcessingError {
815 source: Box::new(e),
816 msg: Box::new(msg),
817 })
818 }
819 }
820 }
821
822 async fn handle_link_message(
827 &self,
828 link: ProtoLink,
829 conn_index: u64,
830 category: ConnType,
831 ) -> Result<(), DataPathError> {
832 if category.is_local() {
833 debug!(%conn_index, "ignoring link message received on local connection");
834 return Ok(());
835 }
836 match link.link_type {
837 Some(ProtoLinkType::LinkNegotiation(payload)) => {
838 self.handle_link_negotiation(&payload, conn_index).await
839 }
840 None => {
841 debug!(%conn_index, "received link message with unset link_type");
842 Ok(())
843 }
844 }
845 }
846
    /// Handles a link negotiation payload received on connection `in_connection`.
    ///
    /// Requests are only honoured on incoming connections and replies only on outgoing
    /// ones; anything else, an unknown connection, or an unparsable remote version is
    /// ignored with a debug log and `Ok(())`.
    ///
    /// Client path (reply): completes the negotiation and, when the peer sent a
    /// correctly sized public key and a private key is stored on the connection, derives
    /// and installs the header HMAC.
    ///
    /// Server path (request): completes the negotiation, derives and installs the header
    /// HMAC from a fresh ephemeral key when the peer sent a correctly sized public key,
    /// restores routes stored in the recovery table for this link id, and sends the
    /// reply (including the server public key only when the HMAC was installed).
    ///
    /// # Errors
    /// Returns `NegotiationError` when the connection requires a header MAC and the
    /// peer's key has the wrong length or no HMAC session ends up installed, when key
    /// derivation fails on the client path, and when server key generation fails.
    async fn handle_link_negotiation(
        &self,
        payload: &LinkNegotiationPayload,
        in_connection: u64,
    ) -> Result<(), DataPathError> {
        let link_id = &payload.link_id;
        let remote_version = &payload.slim_version;

        debug!(
            %in_connection,
            %link_id,
            %remote_version,
            is_reply = payload.is_reply,
            "received link negotiation",
        );

        let Some(conn) = self.forwarder().get_connection(in_connection) else {
            debug!(%in_connection, "ignoring link negotiation request received on unknown connection");
            return Ok(());
        };

        // Strictness is a per-connection property.
        let strict = conn.require_header_mac();

        // Direction check: requests belong on incoming links, replies on outgoing ones.
        match (conn.is_outgoing(), payload.is_reply) {
            (true, false) => {
                debug!(%in_connection, "ignoring link negotiation request received on outgoing connection");
                return Ok(());
            }
            (false, true) => {
                debug!(%in_connection, "ignoring link negotiation reply received on incoming connection");
                return Ok(());
            }
            _ => {}
        }

        let version = match semver::Version::parse(remote_version) {
            Ok(v) => v,
            Err(e) => {
                debug!(%in_connection, %remote_version, error = %e, "ignoring link negotiation with unparsable remote SLIM version");
                return Ok(());
            }
        };

        if payload.is_reply {
            // Client path: we sent the request, this is the server's answer.
            if strict && payload.link_ecdh_public_key.len() != X25519_PUBLIC_KEY_LEN {
                return Err(DataPathError::NegotiationError(
                    "public key length is invalid".to_string(),
                ));
            }

            if !conn.complete_negotiation_as_client(link_id, version) {
                debug!(%in_connection, %link_id, "ignoring link negotiation reply");
                return Ok(());
            }

            // Derive the shared HMAC key from our stored private key and the peer's
            // public key; the private key is taken out of the connection.
            if payload.link_ecdh_public_key.len() == X25519_PUBLIC_KEY_LEN
                && let Some(sk) = conn.take_outbound_ecdh_private()
            {
                match link_ecdh::derive_header_mac_from_ecdh(
                    sk,
                    payload.link_ecdh_public_key.as_slice(),
                    link_id,
                ) {
                    Ok(mac) => conn.install_header_hmac(mac),
                    Err(e) => {
                        error!(
                            %in_connection,
                            error = %e,
                            "link ECDH key derivation failed (client path)",
                        );
                        return Err(DataPathError::NegotiationError(
                            "failed to generate client exchange key".to_string(),
                        ));
                    }
                }
            }

            if strict && conn.header_hmac().is_none() {
                return Err(DataPathError::NegotiationError(
                    "strict header MAC required but link HMAC session is not installed".to_string(),
                ));
            }
        } else {
            // Server path: a peer asks to negotiate the link.
            if strict && payload.link_ecdh_public_key.len() != X25519_PUBLIC_KEY_LEN {
                return Err(DataPathError::NegotiationError(
                    "public key length is invalid".to_string(),
                ));
            }

            if !conn.complete_negotiation_as_server(link_id, version) {
                debug!(%in_connection, %link_id, "ignoring link negotiation request");
                return Ok(());
            }

            let peer_ecdh = payload.link_ecdh_public_key.as_slice();
            // Stays `None` unless the HMAC is installed, so the reply only advertises a
            // key that is actually in use.
            let mut server_reply_ecdh: Option<Vec<u8>> = None;
            if peer_ecdh.len() == X25519_PUBLIC_KEY_LEN {
                match link_ecdh::generate_x25519_ephemeral() {
                    Ok((server_sk, server_pk)) => {
                        match link_ecdh::derive_header_mac_from_ecdh(server_sk, peer_ecdh, link_id)
                        {
                            Ok(mac) => {
                                conn.install_header_hmac(mac);
                                server_reply_ecdh = Some(server_pk);
                            }
                            Err(e) => {
                                error!(
                                    %in_connection,
                                    error = %e,
                                    "link ECDH key derivation failed (server path)",
                                );
                                // Non-strict links continue without a header HMAC.
                                if strict {
                                    return Err(DataPathError::NegotiationError(
                                        "failed to derive header MAC from link ECDH (server path)"
                                            .to_string(),
                                    ));
                                }
                            }
                        }
                    }
                    Err(_) => {
                        error!(%in_connection, "failed to generate server link ECDH key");
                        return Err(DataPathError::NegotiationError(
                            "failed to generate server exchange key".to_string(),
                        ));
                    }
                }
            }

            if strict && conn.header_hmac().is_none() {
                return Err(DataPathError::NegotiationError(
                    "strict header MAC required but link HMAC session is not installed".to_string(),
                ));
            }

            // A recovery entry for this link id means the peer was connected before.
            if let Some(entry) = self.internal.recovery_table.take(link_id) {
                info!(%in_connection, %link_id, "recovering routes for reconnected peer");

                // Re-add the stored subscriptions on the new connection index.
                for (name, sub_ids) in &entry.local_subs {
                    for &subscription_id in sub_ids {
                        if let Err(e) = self.forwarder().on_subscription_msg(
                            name.clone(),
                            in_connection,
                            ConnType::Remote,
                            true,
                            subscription_id,
                        ) {
                            error!(
                                error = %e.chain(), %in_connection,
                                "error re-adding local subscription during recovery",
                            );
                        }
                    }
                }

                // Re-send the subscriptions previously forwarded to this peer and track
                // them again.
                self.restore_remote_subscriptions(&entry.remote_subs, in_connection, true)
                    .await;
            }

            // The reply is sent after recovery; a send failure is only logged.
            let reply = ProtoMessage::builder().build_link_negotiation(
                link_id,
                local_version(),
                true,
                server_reply_ecdh,
            );
            if let Err(e) = self.send_msg(reply, in_connection).await {
                debug!(
                    %in_connection,
                    error = %e.chain(),
                    "failed to send link negotiation reply",
                );
            }
        }

        Ok(())
    }
1043
1044 async fn process_publish(
1045 &self,
1046 msg: Message,
1047 in_connection: u64,
1048 filter: MatchFilter,
1049 ) -> Result<(), DataPathError> {
1050 debug!(
1051 %in_connection,
1052 ?msg,
1053 "received publication"
1054 );
1055
1056 info!(
1058 telemetry = true,
1059 monotonic_counter.num_messages_by_type = 1,
1060 method = "publish"
1061 );
1062 let fanout = msg.get_fanout();
1067
1068 self.match_and_forward_msg(msg, in_connection, fanout, filter)
1069 .await
1070 }
1071
1072 pub(crate) async fn send_subscription_ack(
1073 &self,
1074 in_connection: u64,
1075 subscription_id: u64,
1076 result: &Result<(), DataPathError>,
1077 ) {
1078 let (success, error_msg) = match result {
1079 Ok(()) => (true, String::new()),
1080 Err(e) => (false, e.to_string()),
1081 };
1082
1083 let ack_msg =
1084 Message::builder().build_subscription_ack(subscription_id, success, error_msg);
1085
1086 if let Err(e) = self.send_msg(ack_msg, in_connection).await {
1087 error!(error = %e.chain(), "failed to send subscription ack");
1088 }
1089 }
1090
1091 async fn process_subscription_update_and_forward(
1092 &self,
1093 msg: Message,
1094 conn: u64,
1095 forward: Option<u64>,
1096 add: bool,
1097 subscription_id: u64,
1098 ) -> Result<(), DataPathError> {
1099 let dst = msg.get_dst();
1100
1101 let connection = if let Some(c) = self.forwarder().get_connection(conn) {
1103 c
1104 } else {
1105 return Err(DataPathError::MessageProcessingError {
1106 source: Box::new(DataPathError::ConnectionNotFound(conn)),
1107 msg: Box::new(msg),
1108 });
1109 };
1110
1111 debug!(
1112 %conn,
1113 %dst,
1114 is_local = connection.is_local_connection(),
1115 "processing {}subscription",
1116 if add { "" } else { "un" }
1117 );
1118
1119 self.forwarder().on_subscription_msg(
1120 dst.clone(),
1121 conn,
1122 connection.category(),
1123 add,
1124 subscription_id,
1125 )?;
1126
1127 match forward {
1128 None => Ok(()),
1129 Some(out_conn) => {
1130 debug!(
1131 %out_conn,
1132 "forwarding {}subscription to connection",
1133 if add { "" } else { "un" }
1134 );
1135
1136 let source = msg.get_source();
1137 let identity = msg.get_identity();
1138
1139 self.send_msg(msg, out_conn).await.map(|_| {
1140 self.forwarder().on_forwarded_subscription(
1141 source,
1142 dst,
1143 identity,
1144 out_conn,
1145 add,
1146 subscription_id,
1147 );
1148 })
1149 }
1150 }
1151 }
1152
    /// Handles a subscribe (`add = true`) or unsubscribe (`add = false`) message received
    /// on `in_connection`, including acknowledgement to the sender.
    ///
    /// The subscription is applied to the connection named by the header's `recv_from`
    /// when present, otherwise to the header's incoming connection. A `forward` target is
    /// honoured unless it resolves to a local connection.
    ///
    /// Acknowledgement (only when the message carries a subscription id):
    /// - with a forward target and a successful update, a background retry task is
    ///   spawned and the ack is produced by that path;
    /// - otherwise an ack with the local result is sent immediately.
    ///
    /// # Errors
    /// `MessageProcessingError` (wrapping `ConnectionNotFound`) when the target connection
    /// is unknown, or the error from applying/forwarding the subscription.
    ///
    // NOTE(review): the ack waiter is registered on every path but only consumed on the
    // remote-ack path — confirm that `RemoteSubAckManager` cleans up when `rx` is dropped.
    async fn process_subscription(
        &self,
        msg: Message,
        in_connection: u64,
        add: bool,
    ) -> Result<(), DataPathError> {
        debug!(
            %in_connection,
            ?msg,
            "received {}subscription",
            if add { "" } else { "un" }
        );

        info!(
            telemetry = true,
            monotonic_counter.num_messages_by_type = 1,
            message_type = { if add { "subscribe" } else { "unsubscribe" } }
        );
        let subscription_id = msg.get_subscription_id();

        debug!(?subscription_id, "received subscription id");

        let header = msg.get_slim_header();

        // `recv_from`, when set, overrides the connection the subscription applies to.
        let (in_conn, recv_from, forward) = header.get_connections();
        let in_conn = recv_from.unwrap_or(in_conn);

        // Drop the forward target when it is a known local connection; unknown targets
        // are kept.
        let forward = forward.filter(|&out| {
            self.forwarder()
                .get_connection(out)
                .map(|c| !c.is_local_connection())
                .unwrap_or(true)
        });

        // A remaining forward target means the final ack comes from the remote side.
        let use_remote_ack = forward.is_some();

        let Some(connection) = self.forwarder().get_connection(in_conn) else {
            if let Some(id) = subscription_id {
                debug!(%in_conn, "connection not found, sending error ack");
                self.send_subscription_ack(
                    in_connection,
                    id,
                    &Err(DataPathError::ConnectionNotFound(in_conn)),
                )
                .await;
            }
            return Err(DataPathError::MessageProcessingError {
                source: Box::new(DataPathError::ConnectionNotFound(in_conn)),
                msg: Box::new(msg),
            });
        };

        // A subscription redirected (`recv_from`) onto a local connection is acknowledged
        // without touching the tables.
        if recv_from.is_some() && connection.is_local_connection() {
            if let Some(id) = subscription_id {
                debug!(%in_conn, "subscription looped back to local connection, acking ok");
                self.send_subscription_ack(in_connection, id, &Ok(())).await;
            }
            return Ok(());
        }

        debug!(use_remote_ack, dst = %msg.get_dst(), forward_to = forward, "subscription: ack path decision");

        // Messages without a subscription id are tracked under id 0.
        let sub_id = subscription_id.unwrap_or(0);

        // Registered before forwarding, so the waiter exists by the time the message is sent.
        let rx = self.internal.sub_ack_manager.register(sub_id);

        let result = self
            .process_subscription_update_and_forward(msg.clone(), in_conn, forward, add, sub_id)
            .await;

        if use_remote_ack && result.is_ok() {
            // `use_remote_ack` implies `forward` is `Some`.
            let out_conn = forward.unwrap();

            // The spawned task takes over the message, the waiter and the ack handling.
            tokio::spawn(crate::subscription_ack::retry_loop(
                self.clone(),
                sub_id,
                msg,
                out_conn,
                in_connection,
                subscription_id,
                rx,
            ));

            return Ok(());
        }

        if let Some(id) = subscription_id {
            debug!(%in_connection, ok = result.is_ok(), "sending immediate subscription ack");
            self.send_subscription_ack(in_connection, id, &result).await;
        }

        result
    }
1268
1269 pub async fn process_message(
1270 &self,
1271 msg: Message,
1272 in_connection: u64,
1273 category: ConnType,
1274 ) -> Result<(), DataPathError> {
1275 match msg.message_type {
1276 Some(SubscribeType(_)) => self.process_subscription(msg, in_connection, true).await,
1277 Some(UnsubscribeType(_)) => self.process_subscription(msg, in_connection, false).await,
1278 Some(PublishType(_)) => {
1279 let filter = match category {
1280 ConnType::Peer => MatchFilter::EXCLUDE_PEER,
1281 _ => MatchFilter::ALL,
1282 };
1283 self.process_publish(msg, in_connection, filter).await
1284 }
1285 Some(LinkType(link)) => {
1286 self.handle_link_message(link, in_connection, category)
1287 .await
1288 }
1289 Some(SubscriptionAckType(ack)) => {
1290 let result = if ack.success {
1291 Ok(())
1292 } else {
1293 Err(DataPathError::RemoteSubscriptionAckError(ack.error.clone()))
1294 };
1295 self.internal
1296 .sub_ack_manager
1297 .resolve(ack.subscription_id, result);
1298 Ok(())
1299 }
1300 None => unreachable!(
1301 "message type not set; validate() must be called before process_message"
1302 ),
1303 }
1304 }
1305
1306 async fn handle_new_message(
1307 &self,
1308 conn_index: u64,
1309 category: ConnType,
1310 mut msg: Message,
1311 ) -> Result<(), DataPathError> {
1312 debug!(%conn_index, "received message from connection");
1313 info!(
1314 telemetry = true,
1315 monotonic_counter.num_processed_messages = 1
1316 );
1317
1318 if let Err(err) = msg.validate() {
1320 info!(
1321 telemetry = true,
1322 monotonic_counter.num_messages_by_type = 1,
1323 message_type = "none"
1324 );
1325
1326 let ret_err = DataPathError::MessageProcessingError {
1327 source: Box::new(err.into()),
1328 msg: Box::new(msg),
1329 };
1330
1331 return Err(ret_err);
1332 }
1333
1334 if !msg.is_link() && !msg.is_subscription_ack() {
1336 msg.set_incoming_conn(Some(conn_index));
1338
1339 #[cfg(feature = "otel_tracing")]
1340 otel_tracing::prepare_inbound_msg(
1341 &mut msg,
1342 "process_local",
1343 &self.internal.service_id,
1344 conn_index,
1345 category.is_local(),
1346 );
1347 }
1348
1349 match self.process_message(msg, conn_index, category).await {
1350 Ok(_) => Ok(()),
1351 Err(e) => {
1352 info!(
1354 telemetry = true,
1355 monotonic_counter.num_message_process_errors = 1
1356 );
1357 Err(e)
1361 }
1362 }
1363 }
1364
1365 #[tracing::instrument(skip_all, fields(service_id = %self.internal.service_id, conn_index))]
1366 async fn send_error_to_local_app(&self, conn_index: u64, err: DataPathError) {
1367 debug!(%conn_index, "sending error to local application");
1368 let connection = self.forwarder().get_connection(conn_index);
1369 match connection {
1370 Some(conn) => {
1371 debug!("try to notify the error to the local application");
1372 if let Channel::Server(tx) = conn.channel() {
1373 let session_ctx = match &err {
1375 DataPathError::MessageProcessingError { msg, .. } => {
1376 MessageContext::from_msg(msg)
1377 }
1378 _ => None,
1379 };
1380
1381 let payload = crate::errors::ErrorPayload::new(err.to_string(), session_ctx);
1383 let error_message = payload.to_json_string();
1384
1385 let status = Status::new(tonic::Code::Internal, error_message);
1387
1388 if tx.send(Err(status)).await.is_err() {
1389 debug!(error = %err.chain(), "unable to notify the error to the local app");
1390 }
1391 }
1392 }
1393 None => {
1394 error!(
1395 "error sending error to local app: connection {:?} not found",
1396 conn_index
1397 );
1398 }
1399 }
1400 }
1401
1402 #[tracing::instrument(skip_all, fields(service_id = %self.internal.service_id, conn_index))]
1403 async fn reconnect(
1404 &self,
1405 client_conf: ClientConfig,
1406 conn_index: u64,
1407 cancellation_token: &CancellationToken,
1408 ) -> bool {
1409 info!("connection lost with remote endpoint, attempting to reconnect");
1410
1411 let remote_subscriptions = self
1416 .forwarder()
1417 .get_subscriptions_forwarded_on_connection(conn_index);
1418
1419 tokio::select! {
1420 _ = cancellation_token.cancelled() => {
1421 debug!("cancellation token signaled, stopping reconnection process");
1422 false
1423 }
1424 res = self.try_to_connect(client_conf, None, None, Some(conn_index)) => {
1425 match res {
1426 Ok(_) => {
1427 info!("connection re-established successfully");
1428 self.restore_remote_subscriptions(
1433 &remote_subscriptions,
1434 conn_index,
1435 false,
1436 )
1437 .await;
1438 true
1439 }
1440 Err(e) => {
1441 error!(error = %e.chain(), "unable to reconnect to remote node");
1442 false
1443 }
1444 }
1445 }
1446 }
1447 }
1448
1449 async fn notify_control_plane_subscriptions_lost(
1455 tx_cp: Option<Sender<Result<Message, Status>>>,
1456 local_subs: HashMap<ProtoName, HashSet<u64>>,
1457 conn_index: u64,
1458 ) {
1459 let Some(tx) = tx_cp else { return };
1460 for local_sub in local_subs.into_keys() {
1461 debug!(
1462 %local_sub,
1463 "notify control plane about lost subscription",
1464 );
1465 let msg = Message::builder()
1466 .source(local_sub.clone())
1467 .destination(local_sub.clone())
1468 .flags(SlimHeaderFlags::default().with_recv_from(conn_index))
1469 .build_unsubscribe()
1470 .unwrap();
1471 if let Err(e) = tx.send(Ok(msg)).await {
1472 debug!(
1473 %local_sub,
1474 error = %e.chain(),
1475 "failed to send unsubscribe to control plane",
1476 );
1477 }
1478 }
1479 }
1480
1481 fn process_stream(
1482 &self,
1483 mut stream: impl Stream<Item = Result<Message, Status>> + Unpin + Send + 'static,
1484 conn_index: u64,
1485 client_config: Option<ClientConfig>,
1486 cancellation_token: CancellationToken,
1487 category: ConnType,
1488 from_control_plane: bool,
1489 ) -> Result<JoinHandle<()>, DataPathError> {
1490 let self_clone = self.clone();
1492 let token_clone = cancellation_token.clone();
1493 let client_conf_clone = client_config.clone();
1494 let tx_cp: Option<Sender<Result<Message, Status>>> = self.get_tx_control_plane();
1495 let watch = self.get_drain_watch()?;
1496 let is_local = category.is_local();
1497 let span = tracing::info_span!(
1498 "process_stream",
1499 service_id = %self.internal.service_id,
1500 %conn_index,
1501 is_local,
1502 );
1503 let require_header_mac = self
1504 .forwarder()
1505 .get_connection(conn_index)
1506 .map(|c| c.require_header_mac())
1507 .unwrap_or(false);
1508
1509 let handle = tokio::spawn(async move {
1510 let mut try_to_reconnect = true;
1511
1512 let mut watch = std::pin::pin!(watch.signaled());
1513 loop {
1514 tokio::select! {
1515 next = stream.next() => {
1516 match next {
1517 Some(result) => {
1518 match result {
1519 Ok(msg) => {
1520 if !is_local
1521 && !msg.is_link()
1522 && !msg.is_subscription_ack()
1523 && let Err(e) = self_clone
1524 .verify_remote_header_mac(conn_index, &msg, require_header_mac)
1525 {
1526 error!(
1527 %conn_index,
1528 error = %e.chain(),
1529 "SLIM header integrity verification failed",
1530 );
1531 continue;
1532 }
1533 if !is_local && !from_control_plane && let Some(txcp) = &tx_cp {
1539 match msg.get_type() {
1540 PublishType(_) | LinkType(_) | SubscriptionAckType(_) => {}
1541 _ => {
1542 let _ = txcp.send(Ok(msg.clone())).await;
1545 }
1546 }
1547 }
1548
1549 if let Err(e) = self_clone.handle_new_message(conn_index, category, msg).await {
1550 if matches!(e, DataPathError::NegotiationError(_)) {
1552 error!(%conn_index, "fatal link negotiation error, closing connection");
1553 break;
1554 }
1555 debug!(%conn_index, error = %e.chain(), "error processing incoming message");
1556 if is_local {
1558 self_clone.send_error_to_local_app(conn_index, e).await;
1560 }
1561 }
1562 }
1563 Err(e) => {
1564 if let Some(io_err) = MessageProcessor::match_for_io_error(&e) {
1565 if io_err.kind() == std::io::ErrorKind::BrokenPipe {
1566 info!(%conn_index, "connection closed by peer");
1567 }
1568 } else {
1569 error!(error = %e.chain(), "error receiving messages");
1570 }
1571 break;
1572 }
1573 }
1574 }
1575 None => {
1576 debug!(%conn_index, "end of stream");
1577 break;
1578 }
1579 }
1580 }
1581 _ = &mut watch => {
1582 info!(%conn_index, "shutting down stream on drain");
1583 try_to_reconnect = false;
1584 break;
1585 }
1586 _ = token_clone.cancelled() => {
1587 info!(%conn_index, "shutting down stream on cancellation token");
1588 try_to_reconnect = false;
1589 break;
1590 }
1591 }
1592 }
1593
1594 drop(stream);
1598
1599 let is_client_connection = client_conf_clone.is_some();
1602 let mut connected = false;
1603
1604 if try_to_reconnect && let Some(config) = client_conf_clone {
1605 connected = self_clone.reconnect(config, conn_index, &token_clone)
1609 .instrument(tracing::Span::none())
1610 .await;
1611 } else {
1612 debug!(%conn_index, "close connection")
1613 }
1614
1615 if !connected {
1616 let link_id = if !is_local && !is_client_connection {
1619 self_clone
1620 .forwarder()
1621 .get_connection(conn_index)
1622 .and_then(|c| c.link_id())
1623 } else {
1624 None
1625 };
1626
1627 let (local_subs, remote_subs) = self_clone
1629 .forwarder()
1630 .on_connection_drop(conn_index, category);
1631
1632 let recovery_enabled =
1633 !self_clone.internal.recovery_table.ttl().is_zero();
1634
1635 if let Some(lid) = link_id.filter(|_| recovery_enabled) {
1636 info!(
1640 %conn_index, %lid,
1641 "connection lost, storing recovery state (TTL: {:?})",
1642 self_clone.internal.recovery_table.ttl(),
1643 );
1644 self_clone
1645 .internal
1646 .recovery_table
1647 .store(lid.clone(), local_subs, remote_subs);
1648
1649 if let Ok(drain) = self_clone.get_drain_watch() {
1651 let tx_cp_ttl = tx_cp;
1652 let mp = self_clone.clone();
1653 self_clone.internal.recovery_table.spawn_ttl_task(
1654 lid,
1655 drain,
1656 move |entry| async move {
1657 info!("recovery window expired, notifying control plane");
1658 let unreachable = entry
1664 .local_subs
1665 .into_iter()
1666 .filter(|(name, _)| {
1667 mp.forwarder()
1668 .on_publish_msg_match(name.name.unwrap(), u64::MAX, u32::MAX, MatchFilter::ALL)
1669 .is_err()
1670 })
1671 .collect();
1672 MessageProcessor::notify_control_plane_subscriptions_lost(
1673 tx_cp_ttl,
1674 unreachable,
1675 conn_index,
1676 )
1677 .await;
1678 },
1679 );
1680 }
1681 } else {
1682 if !is_local {
1685 MessageProcessor::notify_control_plane_subscriptions_lost(
1686 tx_cp, local_subs, conn_index,
1687 )
1688 .await;
1689 }
1690 }
1691
1692 info!(telemetry = true, counter.num_active_connections = -1);
1693 }
1694 }.instrument(span));
1695
1696 Ok(handle)
1697 }
1698
1699 fn match_for_io_error(err_status: &Status) -> Option<&std::io::Error> {
1700 let mut err: &(dyn std::error::Error + 'static) = err_status;
1701
1702 loop {
1703 if let Some(io_err) = err.downcast_ref::<std::io::Error>() {
1704 return Some(io_err);
1705 }
1706
1707 if let Some(h2_err) = err.downcast_ref::<h2::Error>()
1710 && let Some(io_err) = h2_err.get_io()
1711 {
1712 return Some(io_err);
1713 }
1714
1715 err = err.source()?;
1716 }
1717 }
1718
1719 pub fn subscription_table(&self) -> &SubscriptionTableImpl {
1720 &self.internal.forwarder.subscription_table
1721 }
1722
1723 pub fn connection_table(&self) -> &ConnectionTable<Connection> {
1724 &self.internal.forwarder.connection_table
1725 }
1726}
1727
1728impl ServerHandler for MessageProcessor {
1729 fn grpc_routes(&self) -> Option<tonic::service::Routes> {
1730 let svc = DataPlaneServiceServer::from_arc(Arc::new(self.clone()));
1731 Some(tonic::service::Routes::new(svc))
1732 }
1733
1734 fn on_websocket_accepted(&self) -> Option<websocket_server::OnAcceptedWebSocket> {
1735 let processor = self.clone();
1736 Some(Arc::new(move |accepted| {
1737 let processor = processor.clone();
1738 Box::pin(async move { processor.handle_websocket_accepted(accepted).await })
1739 }))
1740 }
1741}
1742
1743#[tonic::async_trait]
1744impl DataPlaneService for MessageProcessor {
1745 type OpenChannelStream = Pin<Box<dyn Stream<Item = Result<Message, Status>> + Send + 'static>>;
1746
1747 async fn open_channel(
1748 &self,
1749 request: Request<tonic::Streaming<Message>>,
1750 ) -> Result<Response<Self::OpenChannelStream>, Status> {
1751 let remote_addr = request.remote_addr();
1752 let local_addr = request.local_addr();
1753
1754 let stream = request.into_inner();
1755 let (tx, rx) = mpsc::channel(128);
1756
1757 let connection = Connection::new(ConnType::Remote, Channel::Server(tx))
1758 .with_remote_addr(remote_addr)
1759 .with_local_addr(local_addr)
1760 .with_require_header_mac(self.internal.server_require_header_mac);
1761
1762 debug!(
1763 remote = ?connection.remote_addr(),
1764 local = ?connection.local_addr(),
1765 "new connection received from remote",
1766 );
1767 info!(telemetry = true, counter.num_active_connections = 1);
1768
1769 let conn_index = self
1771 .forwarder()
1772 .on_connection_established(connection, None)
1773 .unwrap();
1774
1775 self.process_stream(
1776 stream,
1777 conn_index,
1778 None,
1779 CancellationToken::new(),
1780 ConnType::Remote,
1781 false,
1782 )
1783 .map_err(|e| {
1784 error!(error = %e.chain(), "error starting new processing stream");
1785 Status::unavailable(format!("error processing stream: {:?}", e))
1786 })?;
1787
1788 let out_stream = ReceiverStream::new(rx);
1789 Ok(Response::new(
1790 Box::pin(out_stream) as Self::OpenChannelStream
1791 ))
1792 }
1793}
1794
1795#[cfg(test)]
1796mod tests {
1797 use slim_config::client::ClientConfig;
1798 use slim_config::grpc::client::is_valid_uuid_v4;
1799 use std::sync::Arc;
1800 use std::time::Duration;
1801
1802 use super::*;
1803 use crate::api::{ProtoName, ProtoSubscriptionAck};
1804 use crate::header_mac::HeaderMacSession;
1805 use crate::tables::remote_subscription_table::SubscriptionInfo;
1806 use tonic::Status;
1807
1808 async fn assert_failed_subscription_ack_is_sent(add: bool) {
1809 let processor = MessageProcessor::new();
1810 let (in_connection, _tx, mut rx) = processor
1811 .register_local_connection(false)
1812 .expect("failed to create local connection");
1813
1814 let source = ProtoName::from_strings(["org", "ns", "source"]).with_id(1);
1815 let destination = ProtoName::from_strings(["org", "ns", "destination"]).with_id(2);
1816 let ack_id: u64 = if add { 1 } else { 2 };
1817 let invalid_connection = u64::MAX - 1;
1818
1819 let builder = Message::builder()
1820 .source(source.clone())
1821 .destination(destination.clone())
1822 .incoming_conn(invalid_connection)
1823 .subscription_id(ack_id);
1824
1825 let msg = if add {
1826 builder.build_subscribe().unwrap()
1827 } else {
1828 builder.build_unsubscribe().unwrap()
1829 };
1830
1831 let result = processor
1832 .process_subscription(msg, in_connection, add)
1833 .await;
1834 assert!(matches!(
1835 result,
1836 Err(DataPathError::MessageProcessingError { .. })
1837 ));
1838
1839 let ack_msg = tokio::time::timeout(Duration::from_secs(1), rx.recv())
1840 .await
1841 .expect("timeout waiting for ack")
1842 .expect("ack channel closed")
1843 .expect("failed to receive ack message");
1844
1845 assert!(matches!(ack_msg.get_type(), SubscriptionAckType(_)));
1846 let ack = ack_msg.get_subscription_ack();
1847 assert_eq!(ack.subscription_id, ack_id);
1848 assert!(!ack.success, "failed ack should have success=false");
1849 assert!(
1850 !ack.error.is_empty(),
1851 "failed ack should include an error message"
1852 );
1853 }
1854
1855 #[tokio::test]
1856 async fn test_process_subscription_sends_failed_ack_on_subscribe_error() {
1857 assert_failed_subscription_ack_is_sent(true).await;
1858 }
1859
1860 #[tokio::test]
1861 async fn test_process_subscription_sends_failed_ack_on_unsubscribe_error() {
1862 assert_failed_subscription_ack_is_sent(false).await;
1863 }
1864
/// A freshly generated v4 UUID passes validation.
#[test]
fn test_is_valid_uuid_v4_accepts_v4() {
    let generated = uuid::Uuid::new_v4().to_string();
    let accepted = is_valid_uuid_v4(&generated);
    assert!(accepted);
}
1870
/// Arbitrary text and the empty string are not valid v4 UUIDs.
#[test]
fn test_is_valid_uuid_v4_rejects_non_uuid_string() {
    for candidate in ["not-a-uuid", ""] {
        assert!(!is_valid_uuid_v4(candidate));
    }
}
1876
/// A well-formed UUID whose version nibble is 1 is rejected.
#[test]
fn test_is_valid_uuid_v4_rejects_non_v4_uuid() {
    let version_one = "00000000-0000-1000-8000-000000000000";
    assert!(!is_valid_uuid_v4(version_one));
}
1882
1883 #[tokio::test]
1886 async fn test_handle_link_message_is_local_ignored() {
1887 let processor = MessageProcessor::new();
1888 let link = ProtoLink { link_type: None };
1889 assert!(
1890 processor
1891 .handle_link_message(link, 0, ConnType::Local)
1892 .await
1893 .is_ok()
1894 );
1895 }
1896
1897 #[tokio::test]
1898 async fn test_handle_link_message_none_link_type_ignored() {
1899 let processor = MessageProcessor::new();
1900 let link = ProtoLink { link_type: None };
1901 assert!(
1902 processor
1903 .handle_link_message(link, 0, ConnType::Remote)
1904 .await
1905 .is_ok()
1906 );
1907 }
1908
1909 fn make_server_conn(
1912 processor: &MessageProcessor,
1913 ) -> (u64, tokio::sync::mpsc::Receiver<Result<Message, Status>>) {
1914 let (tx, rx) = mpsc::channel(16);
1915 let conn = Connection::new(ConnType::Remote, Channel::Server(tx))
1916 .with_require_header_mac(processor.internal.server_require_header_mac);
1917 let conn_id = processor
1918 .forwarder()
1919 .on_connection_established(conn, None)
1920 .unwrap();
1921 (conn_id, rx)
1922 }
1923
1924 fn make_client_conn(
1925 processor: &MessageProcessor,
1926 ) -> (u64, tokio::sync::mpsc::Receiver<Message>) {
1927 let (tx, rx) = mpsc::channel(16);
1928 let conn = Connection::new(ConnType::Remote, Channel::Client(tx))
1929 .with_config_data(Some(ClientConfig::default()));
1930 let conn_id = processor
1931 .forwarder()
1932 .on_connection_established(conn, None)
1933 .unwrap();
1934 (conn_id, rx)
1935 }
1936
1937 #[tokio::test]
1938 async fn test_handle_link_negotiation_unknown_connection_ignored() {
1939 let processor = MessageProcessor::new();
1940 let payload = LinkNegotiationPayload {
1941 link_id: uuid::Uuid::new_v4().to_string(),
1942 slim_version: "1.0.0".into(),
1943 is_reply: false,
1944 link_ecdh_public_key: vec![],
1945 };
1946 assert!(
1947 processor
1948 .handle_link_negotiation(&payload, u64::MAX)
1949 .await
1950 .is_ok()
1951 );
1952 }
1953
1954 #[tokio::test]
1955 async fn test_handle_link_negotiation_role_outgoing_receives_request_ignored() {
1956 let processor = MessageProcessor::new();
1957 let (conn_id, _rx) = make_client_conn(&processor);
1958 let payload = LinkNegotiationPayload {
1959 link_id: uuid::Uuid::new_v4().to_string(),
1960 slim_version: "1.0.0".into(),
1961 is_reply: false, link_ecdh_public_key: vec![],
1963 };
1964 assert!(
1965 processor
1966 .handle_link_negotiation(&payload, conn_id)
1967 .await
1968 .is_ok()
1969 );
1970 assert!(
1971 processor
1972 .forwarder()
1973 .get_connection(conn_id)
1974 .unwrap()
1975 .remote_slim_version()
1976 .is_none()
1977 );
1978 }
1979
1980 #[tokio::test]
1981 async fn test_handle_link_negotiation_role_incoming_receives_reply_ignored() {
1982 let processor = MessageProcessor::new();
1983 let (conn_id, _rx) = make_server_conn(&processor);
1984 let payload = LinkNegotiationPayload {
1985 link_id: uuid::Uuid::new_v4().to_string(),
1986 slim_version: "1.0.0".into(),
1987 is_reply: true, link_ecdh_public_key: vec![],
1989 };
1990 assert!(
1991 processor
1992 .handle_link_negotiation(&payload, conn_id)
1993 .await
1994 .is_ok()
1995 );
1996 assert!(
1997 processor
1998 .forwarder()
1999 .get_connection(conn_id)
2000 .unwrap()
2001 .remote_slim_version()
2002 .is_none()
2003 );
2004 }
2005
2006 #[tokio::test]
2007 async fn test_handle_link_negotiation_unparsable_version_ignored() {
2008 let processor = MessageProcessor::new();
2009 let (conn_id, _rx) = make_server_conn(&processor);
2010 let payload = LinkNegotiationPayload {
2011 link_id: uuid::Uuid::new_v4().to_string(),
2012 slim_version: "not-semver".into(),
2013 is_reply: false,
2014 link_ecdh_public_key: vec![],
2015 };
2016 assert!(
2017 processor
2018 .handle_link_negotiation(&payload, conn_id)
2019 .await
2020 .is_ok()
2021 );
2022 assert!(
2023 processor
2024 .forwarder()
2025 .get_connection(conn_id)
2026 .unwrap()
2027 .remote_slim_version()
2028 .is_none()
2029 );
2030 }
2031
2032 #[tokio::test]
2033 async fn test_handle_link_negotiation_server_invalid_uuid_ignored() {
2034 let processor = MessageProcessor::new();
2035 let (conn_id, _rx) = make_server_conn(&processor);
2036 let payload = LinkNegotiationPayload {
2037 link_id: "not-a-uuid".into(),
2038 slim_version: "1.0.0".into(),
2039 is_reply: false,
2040 link_ecdh_public_key: vec![],
2041 };
2042 assert!(
2043 processor
2044 .handle_link_negotiation(&payload, conn_id)
2045 .await
2046 .is_ok()
2047 );
2048 assert!(
2049 processor
2050 .forwarder()
2051 .get_connection(conn_id)
2052 .unwrap()
2053 .remote_slim_version()
2054 .is_none()
2055 );
2056 }
2057
2058 #[tokio::test]
2059 async fn test_handle_link_negotiation_server_strict_rejects_missing_ecdh() {
2060 let mut server_config = ServerConfig::with_endpoint("127.0.0.1:0");
2061 server_config.require_header_mac = true;
2062 let processor =
2063 MessageProcessor::new_with_server_config("test".into(), &server_config, None);
2064 let (conn_id, _rx) = make_server_conn(&processor);
2065 let payload = LinkNegotiationPayload {
2066 link_id: uuid::Uuid::new_v4().to_string(),
2067 slim_version: "1.2.3".into(),
2068 is_reply: false,
2069 link_ecdh_public_key: vec![],
2070 };
2071 let err = processor
2072 .handle_link_negotiation(&payload, conn_id)
2073 .await
2074 .expect_err("strict mode must reject negotiation without peer ECDH");
2075 assert!(matches!(err, DataPathError::NegotiationError(_)));
2076 let conn = processor.forwarder().get_connection(conn_id).unwrap();
2077 assert!(conn.remote_slim_version().is_none());
2078 }
2079
2080 #[tokio::test]
2081 async fn test_handle_link_negotiation_server_happy_path() {
2082 let processor = MessageProcessor::new();
2083 let (conn_id, mut rx) = make_server_conn(&processor);
2084 let link_id = uuid::Uuid::new_v4().to_string();
2085 let payload = LinkNegotiationPayload {
2086 link_id: link_id.clone(),
2087 slim_version: "1.2.3".into(),
2088 is_reply: false,
2089 link_ecdh_public_key: vec![],
2090 };
2091 assert!(
2092 processor
2093 .handle_link_negotiation(&payload, conn_id)
2094 .await
2095 .is_ok()
2096 );
2097 let conn = processor.forwarder().get_connection(conn_id).unwrap();
2098 assert_eq!(conn.link_id(), Some(link_id));
2099 assert_eq!(
2100 conn.remote_slim_version(),
2101 Some(semver::Version::parse("1.2.3").unwrap())
2102 );
2103 let reply = rx.try_recv().expect("reply should be sent").unwrap();
2105 assert!(reply.is_link());
2106 }
2107
2108 #[tokio::test]
2109 async fn test_handle_link_negotiation_server_replay_protection() {
2110 let processor = MessageProcessor::new();
2111 let (conn_id, mut rx) = make_server_conn(&processor);
2112 let link_id = uuid::Uuid::new_v4().to_string();
2113 let payload = LinkNegotiationPayload {
2114 link_id: link_id.clone(),
2115 slim_version: "1.0.0".into(),
2116 is_reply: false,
2117 link_ecdh_public_key: vec![],
2118 };
2119 assert!(
2121 processor
2122 .handle_link_negotiation(&payload, conn_id)
2123 .await
2124 .is_ok()
2125 );
2126 assert!(rx.try_recv().is_ok());
2127 assert!(
2129 processor
2130 .handle_link_negotiation(&payload, conn_id)
2131 .await
2132 .is_ok()
2133 );
2134 assert!(rx.try_recv().is_err());
2135 }
2136
2137 #[tokio::test]
2138 async fn test_handle_link_negotiation_client_happy_path() {
2139 let processor = MessageProcessor::new();
2140 let (conn_id, _rx) = make_client_conn(&processor);
2141 let link_id = uuid::Uuid::new_v4().to_string();
2142 let conn = processor.forwarder().get_connection(conn_id).unwrap();
2143 conn.set_link_id(link_id.clone());
2144 let payload = LinkNegotiationPayload {
2145 link_id: link_id.clone(),
2146 slim_version: "2.0.0".into(),
2147 is_reply: true,
2148 link_ecdh_public_key: vec![],
2149 };
2150 assert!(
2151 processor
2152 .handle_link_negotiation(&payload, conn_id)
2153 .await
2154 .is_ok()
2155 );
2156 assert_eq!(
2157 conn.remote_slim_version(),
2158 Some(semver::Version::parse("2.0.0").unwrap())
2159 );
2160 }
2161
2162 #[tokio::test]
2163 async fn test_handle_link_negotiation_client_link_id_mismatch_ignored() {
2164 let processor = MessageProcessor::new();
2165 let (conn_id, _rx) = make_client_conn(&processor);
2166 let conn = processor.forwarder().get_connection(conn_id).unwrap();
2167 conn.set_link_id("correct-id".to_string());
2168 let payload = LinkNegotiationPayload {
2169 link_id: "wrong-id".into(),
2170 slim_version: "1.0.0".into(),
2171 is_reply: true,
2172 link_ecdh_public_key: vec![],
2173 };
2174 assert!(
2175 processor
2176 .handle_link_negotiation(&payload, conn_id)
2177 .await
2178 .is_ok()
2179 );
2180 assert!(conn.remote_slim_version().is_none());
2181 }
2182
2183 #[tokio::test]
2184 async fn test_handle_link_negotiation_client_replay_protection() {
2185 let processor = MessageProcessor::new();
2186 let (conn_id, _rx) = make_client_conn(&processor);
2187 let link_id = uuid::Uuid::new_v4().to_string();
2188 let conn = processor.forwarder().get_connection(conn_id).unwrap();
2189 conn.set_link_id(link_id.clone());
2190 let payload = LinkNegotiationPayload {
2191 link_id: link_id.clone(),
2192 slim_version: "1.0.0".into(),
2193 is_reply: true,
2194 link_ecdh_public_key: vec![],
2195 };
2196 assert!(
2198 processor
2199 .handle_link_negotiation(&payload, conn_id)
2200 .await
2201 .is_ok()
2202 );
2203 let stored = conn.remote_slim_version();
2204 assert!(stored.is_some());
2205 assert!(
2207 processor
2208 .handle_link_negotiation(&payload, conn_id)
2209 .await
2210 .is_ok()
2211 );
2212 assert_eq!(conn.remote_slim_version(), stored);
2213 }
2214
2215 fn negotiate_conn(processor: &MessageProcessor, conn_id: u64, version: &str) {
2220 let c = processor.forwarder().get_connection(conn_id).unwrap();
2221 c.complete_negotiation_as_server(
2222 &uuid::Uuid::new_v4().to_string(),
2223 semver::Version::parse(version).unwrap(),
2224 );
2225 c.test_install_header_mac(Arc::new(
2226 HeaderMacSession::new(b"01234567890123456789012345678901").unwrap(),
2227 ));
2228 }
2229
2230 #[tokio::test]
2231 async fn test_await_link_hmac_ready_timeout_configurable() {
2232 let server_config = ServerConfig {
2233 endpoint: "localhost:12345".to_string(),
2234 link_hmac_timeout_secs: 1, link_hmac_poll_interval_ms: 10, ..Default::default()
2237 };
2238 let processor = MessageProcessor::new_with_server_config(
2239 "test_service".to_string(),
2240 &server_config,
2241 None,
2242 );
2243
2244 assert_eq!(
2245 processor.internal.link_hmac_timeout,
2246 std::time::Duration::from_secs(1)
2247 );
2248 assert_eq!(
2249 processor.internal.link_hmac_poll_interval,
2250 std::time::Duration::from_millis(10)
2251 );
2252
2253 let (conn_id, _tx, _rx) = processor
2255 .register_local_connection(false)
2256 .expect("failed to register local connection");
2257
2258 let start = std::time::Instant::now();
2260 let result = processor.await_link_hmac_ready(conn_id, true).await;
2261 let elapsed = start.elapsed();
2262
2263 assert!(result.is_err());
2264 assert!(elapsed >= std::time::Duration::from_millis(900));
2265 assert!(elapsed < std::time::Duration::from_secs(3));
2266 }
2267
/// In strict mode, a publish on a negotiated connection that has no MAC
/// session installed is rejected with `NegotiationError`.
#[test]
fn verify_remote_header_mac_strict_rejects_publish_without_mac_session() {
    let processor = MessageProcessor::new();
    let (remote_conn, _rx) = make_server_conn(&processor);

    // Negotiate without installing a header MAC session.
    let conn = processor.forwarder().get_connection(remote_conn).unwrap();
    let link_id = uuid::Uuid::new_v4().to_string();
    conn.complete_negotiation_as_server(&link_id, semver::Version::parse("1.2.0").unwrap());
    assert!(conn.header_hmac().is_none());

    let msg = ProtoMessage::builder()
        .source(ProtoName::from_strings(["org", "default", "a"]).with_id(1))
        .destination(ProtoName::from_strings(["org", "default", "b"]).with_id(2))
        .application_payload("text/plain", b"hey".to_vec())
        .build_publish()
        .expect("publish");

    let verdict = processor.verify_remote_header_mac(remote_conn, &msg, true);
    let err = verdict.expect_err("unsigned publish must fail in strict mode without MAC session");
    assert!(matches!(err, DataPathError::NegotiationError(_)));
}
2293
/// A publish whose header was signed with the same key and link id as the
/// connection's MAC session passes strict verification.
#[test]
fn verify_remote_header_mac_accepts_signed_inter_node_publish() {
    let processor = MessageProcessor::new();
    let (remote_conn, _rx) = make_server_conn(&processor);
    negotiate_conn(&processor, remote_conn, "1.2.0");

    let conn = processor.forwarder().get_connection(remote_conn).unwrap();
    let link_id = conn.link_id().expect("link id after negotiation");

    let mut msg = ProtoMessage::builder()
        .source(ProtoName::from_strings(["org", "default", "a"]).with_id(1))
        .destination(ProtoName::from_strings(["org", "default", "b"]).with_id(2))
        .application_payload("text/plain", b"hey".to_vec())
        .build_publish()
        .expect("publish");

    // Same key as the session installed by `negotiate_conn`.
    let signer = HeaderMacSession::new(b"01234567890123456789012345678901").unwrap();
    signer
        .sign_slim_header(msg.get_slim_header_mut(), &link_id)
        .expect("sign header");

    let verdict = processor.verify_remote_header_mac(remote_conn, &msg, true);
    assert!(verdict.is_ok());
}
2326
/// Changing the destination after the header was signed makes strict
/// verification fail with `HeaderIntegrity`.
#[test]
fn verify_remote_header_mac_rejects_destination_tamper_after_sign() {
    let processor = MessageProcessor::new();
    let (remote_conn, _rx) = make_server_conn(&processor);
    negotiate_conn(&processor, remote_conn, "1.2.0");

    let conn = processor.forwarder().get_connection(remote_conn).unwrap();
    let link_id = conn.link_id().expect("link id after negotiation");

    let mut msg = ProtoMessage::builder()
        .source(ProtoName::from_strings(["org", "default", "a"]).with_id(1))
        .destination(ProtoName::from_strings(["org", "default", "b"]).with_id(2))
        .application_payload("text/plain", b"hey".to_vec())
        .build_publish()
        .expect("publish");

    let signer = HeaderMacSession::new(b"01234567890123456789012345678901").unwrap();
    signer
        .sign_slim_header(msg.get_slim_header_mut(), &link_id)
        .expect("sign header");

    // Alter the signed destination so the MAC no longer matches.
    let tampered_name = msg
        .get_slim_header_mut()
        .destination
        .as_mut()
        .and_then(|dest| dest.str_name.as_mut());
    if let Some(sn) = tampered_name {
        sn.str_component_2.push_str("-integrity-test-tamper");
    }

    let verdict = processor.verify_remote_header_mac(remote_conn, &msg, true);
    let err = verdict.expect_err("tampered header must fail MAC verify");
    assert!(matches!(err, DataPathError::HeaderIntegrity(_)));
}
2365
2366 #[tokio::test]
2367 #[allow(clippy::disallowed_methods)]
2368 async fn test_send_msg_raw_tamper_destination_env_var() {
2369 let _guard = ENV_LOCK.lock().await;
2370 unsafe {
2371 std::env::set_var("SLIM_TEST_TAMPER_DESTINATION", "1");
2372 }
2373
2374 let processor = MessageProcessor::new();
2375 let (conn_id, mut rx) = make_server_conn(&processor);
2376 negotiate_conn(&processor, conn_id, "1.2.0");
2377
2378 let source = ProtoName::from_strings(["org", "default", "a"]).with_id(1);
2379 let dest = ProtoName::from_strings(["org", "default", "b"]).with_id(2);
2380 let msg = ProtoMessage::builder()
2381 .source(source)
2382 .destination(dest)
2383 .application_payload("text/plain", b"hey".to_vec())
2384 .build_publish()
2385 .expect("publish");
2386
2387 processor
2388 .send_msg_raw(msg, conn_id)
2389 .await
2390 .expect("send_msg_raw failed");
2391
2392 let sent_msg = rx.recv().await.unwrap().unwrap();
2393 let header = sent_msg.get_slim_header();
2394 let dest_name = header.destination.as_ref().expect("destination");
2395 let str_name = dest_name.str_name.as_ref().expect("str_name");
2396 let require_header_mac = true;
2397
2398 assert!(str_name.str_component_2.ends_with("-integrity-test-tamper"));
2400
2401 let err = processor
2403 .verify_remote_header_mac(conn_id, &sent_msg, require_header_mac)
2404 .expect_err("tampered header must fail MAC verify");
2405 assert!(matches!(err, DataPathError::HeaderIntegrity(_)));
2406
2407 unsafe {
2408 std::env::remove_var("SLIM_TEST_TAMPER_DESTINATION");
2409 }
2410 }
2411
2412 #[tokio::test]
2413 async fn test_process_subscription_remote_ack_path_success() {
2414 let processor = MessageProcessor::new();
2417 let (local_conn, _tx_local, mut rx_local) = processor
2418 .register_local_connection(false)
2419 .expect("failed to create local connection");
2420
2421 let (remote_conn, mut rx_remote) = make_server_conn(&processor);
2422 negotiate_conn(&processor, remote_conn, "1.2.0");
2423
2424 let source = ProtoName::from_strings(["org", "ns", "src"]).with_id(1);
2425 let destination = ProtoName::from_strings(["org", "ns", "dst"]).with_id(2);
2426 let upstream_ack_id: u64 = 100;
2427
2428 let sub_msg = Message::builder()
2430 .source(source.clone())
2431 .destination(destination.clone())
2432 .incoming_conn(local_conn)
2433 .forward_to(remote_conn)
2434 .subscription_id(upstream_ack_id)
2435 .build_subscribe()
2436 .unwrap();
2437
2438 let result = processor
2440 .process_subscription(sub_msg, local_conn, true)
2441 .await;
2442 assert!(result.is_ok());
2443
2444 let forwarded = tokio::time::timeout(Duration::from_secs(1), rx_remote.recv())
2447 .await
2448 .expect("timeout waiting for forwarded subscribe")
2449 .expect("forwarded subscribe channel closed")
2450 .unwrap();
2451 assert!(matches!(forwarded.get_type(), SubscribeType(_)));
2452
2453 let forwarded_sub_id = forwarded
2455 .get_subscription_id()
2456 .expect("forwarded subscribe must carry the same subscription_id");
2457 assert_eq!(
2458 forwarded_sub_id, upstream_ack_id,
2459 "subscription_id must not change when forwarding"
2460 );
2461
2462 let ack = ProtoSubscriptionAck {
2464 subscription_id: upstream_ack_id,
2465 success: true,
2466 error: String::new(),
2467 };
2468 processor.internal.sub_ack_manager.resolve(
2469 ack.subscription_id,
2470 if ack.success {
2471 Ok(())
2472 } else {
2473 Err(DataPathError::RemoteSubscriptionAckError(ack.error.clone()))
2474 },
2475 );
2476
2477 let upstream_ack = tokio::time::timeout(Duration::from_secs(2), rx_local.recv())
2479 .await
2480 .expect("timeout waiting for upstream ack")
2481 .expect("upstream ack channel closed")
2482 .expect("upstream ack should be Ok");
2483
2484 assert!(matches!(upstream_ack.get_type(), SubscriptionAckType(_)));
2485 let ack_inner = upstream_ack.get_subscription_ack();
2486 assert_eq!(ack_inner.subscription_id, upstream_ack_id);
2487 assert!(ack_inner.success);
2488 }
2489
2490 #[tokio::test]
2491 async fn test_process_subscription_remote_ack_error_forwarded_upstream() {
2492 let processor = MessageProcessor::new();
2494 let (local_conn, _tx_local, mut rx_local) = processor
2495 .register_local_connection(false)
2496 .expect("failed to create local connection");
2497
2498 let (remote_conn, mut rx_remote) = make_server_conn(&processor);
2499 negotiate_conn(&processor, remote_conn, "1.2.0");
2500
2501 let source = ProtoName::from_strings(["org", "ns", "src"]).with_id(1);
2502 let destination = ProtoName::from_strings(["org", "ns", "dst"]).with_id(2);
2503 let upstream_ack_id: u64 = 102;
2504
2505 let sub_msg = Message::builder()
2506 .source(source.clone())
2507 .destination(destination.clone())
2508 .incoming_conn(local_conn)
2509 .forward_to(remote_conn)
2510 .subscription_id(upstream_ack_id)
2511 .build_subscribe()
2512 .unwrap();
2513
2514 processor
2515 .process_subscription(sub_msg, local_conn, true)
2516 .await
2517 .unwrap();
2518
2519 let forwarded = tokio::time::timeout(Duration::from_secs(1), rx_remote.recv())
2520 .await
2521 .expect("timeout")
2522 .expect("channel closed")
2523 .unwrap();
2524
2525 let forwarded_sub_id = forwarded
2526 .get_subscription_id()
2527 .expect("forwarded subscribe must carry the same subscription_id");
2528 assert_eq!(
2529 forwarded_sub_id, upstream_ack_id,
2530 "subscription_id must not change when forwarding"
2531 );
2532
2533 let ack = ProtoSubscriptionAck {
2535 subscription_id: upstream_ack_id,
2536 success: false,
2537 error: "remote error".to_string(),
2538 };
2539 processor.internal.sub_ack_manager.resolve(
2540 ack.subscription_id,
2541 if ack.success {
2542 Ok(())
2543 } else {
2544 Err(DataPathError::RemoteSubscriptionAckError(ack.error.clone()))
2545 },
2546 );
2547
2548 let upstream_ack = tokio::time::timeout(Duration::from_secs(2), rx_local.recv())
2549 .await
2550 .expect("timeout")
2551 .expect("channel closed")
2552 .expect("must be Ok");
2553
2554 assert!(matches!(upstream_ack.get_type(), SubscriptionAckType(_)));
2555 let ack_inner = upstream_ack.get_subscription_ack();
2556 assert_eq!(ack_inner.subscription_id, upstream_ack_id);
2557 assert!(!ack_inner.success);
2558 assert!(!ack_inner.error.is_empty());
2559 }
2560
2561 fn make_test_subscribe(sub_id: u64) -> Message {
2564 let source = ProtoName::from_strings(["org", "ns", "src"]).with_id(1);
2565 let destination = ProtoName::from_strings(["org", "ns", "dst"]).with_id(2);
2566 Message::builder()
2567 .source(source)
2568 .destination(destination)
2569 .subscription_id(sub_id)
2570 .build_subscribe()
2571 .unwrap()
2572 }
2573
2574 #[tokio::test(start_paused = true)]
2575 async fn test_retry_loop_ack_received_before_timeout() {
2576 let processor = MessageProcessor::new();
2578 let (local_conn, _tx_local, mut rx_local) = processor
2579 .register_local_connection(false)
2580 .expect("failed to create local connection");
2581 let (remote_conn, mut rx_remote) = make_server_conn(&processor);
2582
2583 let sub_id: u64 = 1000;
2584 let msg = make_test_subscribe(sub_id);
2585 let rx = processor.internal.sub_ack_manager.register(sub_id);
2586
2587 let proc_clone = processor.clone();
2588 let handle = tokio::spawn(crate::subscription_ack::retry_loop(
2589 proc_clone,
2590 sub_id,
2591 msg,
2592 remote_conn,
2593 local_conn,
2594 Some(sub_id),
2595 rx,
2596 ));
2597
2598 processor.internal.sub_ack_manager.resolve(sub_id, Ok(()));
2600
2601 handle.await.unwrap();
2602
2603 assert!(
2605 rx_remote.try_recv().is_err(),
2606 "no retry send expected when ack arrives before timeout"
2607 );
2608
2609 let ack = rx_local
2611 .try_recv()
2612 .expect("upstream ack should have been sent")
2613 .unwrap();
2614 assert!(ack.get_subscription_ack().success);
2615 }
2616
2617 #[tokio::test(start_paused = true)]
2618 async fn test_retry_loop_timeout_then_retry_send_then_ack() {
2619 let processor = MessageProcessor::new();
2621 let (local_conn, _tx_local, mut rx_local) = processor
2622 .register_local_connection(false)
2623 .expect("failed to create local connection");
2624 let (remote_conn, mut rx_remote) = make_server_conn(&processor);
2625
2626 let sub_id: u64 = 1001;
2627 let msg = make_test_subscribe(sub_id);
2628 let rx = processor.internal.sub_ack_manager.register(sub_id);
2629
2630 let proc_clone = processor.clone();
2631 let handle = tokio::spawn(crate::subscription_ack::retry_loop(
2632 proc_clone,
2633 sub_id,
2634 msg,
2635 remote_conn,
2636 local_conn,
2637 Some(sub_id),
2638 rx,
2639 ));
2640
2641 tokio::time::sleep(crate::subscription_ack::TIMEOUT + Duration::from_millis(100)).await;
2643
2644 let retried = rx_remote
2646 .try_recv()
2647 .expect("retry send expected after first timeout")
2648 .unwrap();
2649 assert!(retried.get_subscription_id().is_some());
2650
2651 processor.internal.sub_ack_manager.resolve(sub_id, Ok(()));
2653
2654 handle.await.unwrap();
2655
2656 let ack = rx_local
2658 .try_recv()
2659 .expect("upstream ack should have been sent")
2660 .unwrap();
2661 assert!(ack.get_subscription_ack().success);
2662 }
2663
2664 #[tokio::test(start_paused = true)]
2665 async fn test_retry_loop_retry_send_fails() {
2666 let processor = MessageProcessor::new();
2669 let (local_conn, _tx_local, mut rx_local) = processor
2670 .register_local_connection(false)
2671 .expect("failed to create local connection");
2672 let (remote_conn, _rx_remote) = make_server_conn(&processor);
2673
2674 let sub_id: u64 = 1002;
2675 let msg = make_test_subscribe(sub_id);
2676 let rx = processor.internal.sub_ack_manager.register(sub_id);
2677
2678 processor.connection_table().remove(remote_conn);
2680
2681 let proc_clone = processor.clone();
2682 let handle = tokio::spawn(crate::subscription_ack::retry_loop(
2683 proc_clone,
2684 sub_id,
2685 msg,
2686 remote_conn,
2687 local_conn,
2688 Some(sub_id),
2689 rx,
2690 ));
2691
2692 tokio::time::sleep(crate::subscription_ack::TIMEOUT + Duration::from_millis(100)).await;
2694
2695 handle.await.unwrap();
2696
2697 let ack = rx_local
2699 .try_recv()
2700 .expect("upstream ack should have been sent")
2701 .unwrap();
2702 assert!(!ack.get_subscription_ack().success);
2703 }
2704
2705 #[tokio::test(start_paused = true)]
2706 async fn test_retry_loop_all_retries_exhausted() {
2707 let processor = MessageProcessor::new();
2709 let (local_conn, _tx_local, mut rx_local) = processor
2710 .register_local_connection(false)
2711 .expect("failed to create local connection");
2712 let (remote_conn, mut rx_remote) = make_server_conn(&processor);
2713
2714 let sub_id: u64 = 1003;
2715 let msg = make_test_subscribe(sub_id);
2716 let rx = processor.internal.sub_ack_manager.register(sub_id);
2717
2718 let proc_clone = processor.clone();
2719 let handle = tokio::spawn(crate::subscription_ack::retry_loop(
2720 proc_clone,
2721 sub_id,
2722 msg,
2723 remote_conn,
2724 local_conn,
2725 Some(sub_id),
2726 rx,
2727 ));
2728
2729 for _ in 0..=crate::subscription_ack::MAX_RETRIES {
2731 tokio::time::sleep(crate::subscription_ack::TIMEOUT + Duration::from_millis(100)).await;
2732 }
2733
2734 handle.await.unwrap();
2735
2736 let mut retry_count = 0;
2739 while rx_remote.try_recv().is_ok() {
2740 retry_count += 1;
2741 }
2742 assert_eq!(
2743 retry_count,
2744 crate::subscription_ack::MAX_RETRIES as usize,
2745 "expected {} retry sends",
2746 crate::subscription_ack::MAX_RETRIES,
2747 );
2748
2749 let ack = rx_local
2751 .try_recv()
2752 .expect("upstream ack should have been sent")
2753 .unwrap();
2754 let ack_inner = ack.get_subscription_ack();
2755 assert!(
2756 !ack_inner.success,
2757 "ack must indicate failure after exhausting retries"
2758 );
2759 assert!(!ack_inner.error.is_empty());
2760 }
2761
2762 #[tokio::test(start_paused = true)]
2763 async fn test_retry_loop_no_upstream_subscription_id() {
2764 let processor = MessageProcessor::new();
2766 let (_local_conn, _tx_local, mut rx_local) = processor
2767 .register_local_connection(false)
2768 .expect("failed to create local connection");
2769 let (remote_conn, _rx_remote) = make_server_conn(&processor);
2770
2771 let sub_id: u64 = 1004;
2772 let msg = make_test_subscribe(sub_id);
2773 let rx = processor.internal.sub_ack_manager.register(sub_id);
2774
2775 let proc_clone = processor.clone();
2776 let handle = tokio::spawn(crate::subscription_ack::retry_loop(
2777 proc_clone,
2778 sub_id,
2779 msg,
2780 remote_conn,
2781 0, None,
2783 rx,
2784 ));
2785
2786 processor.internal.sub_ack_manager.resolve(sub_id, Ok(()));
2788
2789 handle.await.unwrap();
2790
2791 assert!(
2793 rx_local.try_recv().is_err(),
2794 "no upstream ack when upstream_subscription_id is None"
2795 );
2796 }
2797
2798 #[tokio::test(start_paused = true)]
2799 async fn test_retry_loop_sender_dropped() {
2800 let processor = MessageProcessor::new();
2803 let (local_conn, _tx_local, mut rx_local) = processor
2804 .register_local_connection(false)
2805 .expect("failed to create local connection");
2806 let (remote_conn, _rx_remote) = make_server_conn(&processor);
2807
2808 let sub_id: u64 = 1005;
2809 let msg = make_test_subscribe(sub_id);
2810 let rx = processor.internal.sub_ack_manager.register(sub_id);
2811
2812 processor.internal.sub_ack_manager.remove(sub_id);
2814
2815 let proc_clone = processor.clone();
2816 let handle = tokio::spawn(crate::subscription_ack::retry_loop(
2817 proc_clone,
2818 sub_id,
2819 msg,
2820 remote_conn,
2821 local_conn,
2822 Some(sub_id),
2823 rx,
2824 ));
2825
2826 handle.await.unwrap();
2827
2828 let ack = rx_local
2830 .try_recv()
2831 .expect("upstream ack should have been sent")
2832 .unwrap();
2833 assert!(!ack.get_subscription_ack().success);
2834 }
2835
/// A TTL passed to `new_with_options` is the one reported by the recovery table.
#[test]
fn test_new_with_options_custom_ttl() {
    let requested = Duration::from_secs(5);
    let processor = MessageProcessor::new_with_options("svc".into(), Some(requested));

    let actual = processor.internal.recovery_table.ttl();
    assert_eq!(actual, requested);
}
2847
/// With no TTL supplied, the recovery table reports a TTL of 30 seconds.
#[test]
fn test_new_with_options_none_uses_default() {
    let processor = MessageProcessor::new_with_options("svc".into(), None);

    let actual = processor.internal.recovery_table.ttl();
    assert_eq!(actual, Duration::from_secs(30));
}
2856
/// A zero TTL is accepted and reported back unchanged.
#[test]
fn test_new_with_options_zero_ttl() {
    let zero = Some(Duration::ZERO);
    let processor = MessageProcessor::new_with_options("svc".into(), zero);

    let actual = processor.internal.recovery_table.ttl();
    assert!(actual.is_zero());
}
2862
2863 #[tokio::test]
2866 async fn test_notify_cp_subs_lost_sends_unsubscribes() {
2867 let (tx, mut rx) = mpsc::channel::<Result<Message, Status>>(16);
2868 let mut subs = HashMap::new();
2869 let name = ProtoName::from_strings(["org", "default", "svc"]);
2870 subs.insert(name.clone(), HashSet::from([1u64, 2u64]));
2871
2872 MessageProcessor::notify_control_plane_subscriptions_lost(Some(tx), subs, 42).await;
2873
2874 let msg = rx.recv().await.unwrap().unwrap();
2875 assert!(matches!(msg.get_type(), UnsubscribeType(_)));
2876 assert_eq!(msg.get_source(), name.clone());
2877 }
2878
2879 #[tokio::test]
2880 async fn test_notify_cp_subs_lost_no_tx_is_noop() {
2881 let subs = HashMap::from([(
2882 ProtoName::from_strings(["org", "default", "svc"]),
2883 HashSet::from([1u64]),
2884 )]);
2885 MessageProcessor::notify_control_plane_subscriptions_lost(None, subs, 1).await;
2887 }
2888
2889 #[tokio::test]
2890 async fn test_notify_cp_subs_lost_empty_subs() {
2891 let (tx, mut rx) = mpsc::channel::<Result<Message, Status>>(16);
2892 MessageProcessor::notify_control_plane_subscriptions_lost(Some(tx), HashMap::new(), 1)
2893 .await;
2894 assert!(rx.try_recv().is_err());
2896 }
2897
2898 #[tokio::test]
2901 async fn test_link_negotiation_server_triggers_route_recovery() {
2902 let processor = MessageProcessor::new();
2903 let (conn_id, _rx) = make_server_conn(&processor);
2904
2905 let link_id = uuid::Uuid::new_v4().to_string();
2906 let sub_name = ProtoName::from_strings(["org", "default", "recovered"]);
2907
2908 let mut local_subs = HashMap::new();
2910 local_subs.insert(sub_name.clone(), HashSet::from([99u64]));
2911 processor
2912 .internal
2913 .recovery_table
2914 .store(link_id.clone(), local_subs, HashSet::new());
2915
2916 let payload = LinkNegotiationPayload {
2918 link_id: link_id.clone(),
2919 slim_version: "1.0.0".into(),
2920 is_reply: false,
2921 link_ecdh_public_key: vec![],
2922 };
2923 processor
2924 .handle_link_negotiation(&payload, conn_id)
2925 .await
2926 .unwrap();
2927
2928 let result = processor.forwarder().on_publish_msg_match(
2930 sub_name.name.unwrap(),
2931 u64::MAX,
2932 1,
2933 MatchFilter::ALL,
2934 );
2935 assert!(result.is_ok(), "recovered subscription should be routable");
2936 assert_eq!(result.unwrap(), vec![conn_id]);
2937 }
2938
2939 #[tokio::test]
2940 async fn test_link_negotiation_server_recovery_restores_remote_subs() {
2941 let processor = MessageProcessor::new();
2942 let (conn_id, mut rx) = make_server_conn(&processor);
2943
2944 let link_id = uuid::Uuid::new_v4().to_string();
2945 let source = ProtoName::from_strings(["org", "default", "src"]);
2946 let dest = ProtoName::from_strings(["org", "default", "dst"]);
2947
2948 let remote_sub =
2949 SubscriptionInfo::new(source.clone(), dest.clone(), "identity".into(), conn_id, 42);
2950
2951 processor.internal.recovery_table.store(
2953 link_id.clone(),
2954 HashMap::new(),
2955 HashSet::from([remote_sub]),
2956 );
2957
2958 let payload = LinkNegotiationPayload {
2959 link_id: link_id.clone(),
2960 slim_version: "1.0.0".into(),
2961 is_reply: false,
2962 link_ecdh_public_key: vec![],
2963 };
2964 processor
2965 .handle_link_negotiation(&payload, conn_id)
2966 .await
2967 .unwrap();
2968
2969 let sub_msg = rx.recv().await.unwrap().unwrap();
2971 assert!(matches!(sub_msg.get_type(), SubscribeType(_)));
2972 let reply = rx.recv().await.unwrap().unwrap();
2973 assert!(reply.is_link());
2974 }
2975
2976 #[tokio::test]
2977 async fn test_link_negotiation_server_no_recovery_entry() {
2978 let processor = MessageProcessor::new();
2979 let (conn_id, mut rx) = make_server_conn(&processor);
2980
2981 let link_id = uuid::Uuid::new_v4().to_string();
2982 let payload = LinkNegotiationPayload {
2984 link_id: link_id.clone(),
2985 slim_version: "1.0.0".into(),
2986 is_reply: false,
2987 link_ecdh_public_key: vec![],
2988 };
2989 processor
2990 .handle_link_negotiation(&payload, conn_id)
2991 .await
2992 .unwrap();
2993
2994 let reply = rx.try_recv().unwrap().unwrap();
2996 assert!(reply.is_link());
2997 assert!(rx.try_recv().is_err());
2998 }
2999
3000 #[tokio::test]
3003 async fn test_restore_remote_subscriptions_with_tracking() {
3004 let processor = MessageProcessor::new();
3005 let (conn_id, mut rx) = make_server_conn(&processor);
3006
3007 let source = ProtoName::from_strings(["org", "default", "src"]);
3008 let dest = ProtoName::from_strings(["org", "default", "dst"]);
3009 let sub = SubscriptionInfo::new(source.clone(), dest.clone(), "id1".into(), conn_id, 7);
3010 let subs = HashSet::from([sub]);
3011
3012 processor
3013 .restore_remote_subscriptions(&subs, conn_id, true)
3014 .await;
3015
3016 let msg = rx.recv().await.unwrap().unwrap();
3018 assert!(matches!(msg.get_type(), SubscribeType(_)));
3019
3020 let tracked = processor
3022 .forwarder()
3023 .get_subscriptions_forwarded_on_connection(conn_id);
3024 assert_eq!(tracked.len(), 1);
3025 }
3026
3027 #[tokio::test]
3028 async fn test_restore_remote_subscriptions_without_tracking() {
3029 let processor = MessageProcessor::new();
3030 let (conn_id, mut rx) = make_server_conn(&processor);
3031
3032 let source = ProtoName::from_strings(["org", "default", "src"]);
3033 let dest = ProtoName::from_strings(["org", "default", "dst"]);
3034 let sub = SubscriptionInfo::new(source.clone(), dest.clone(), "id1".into(), conn_id, 7);
3035 let subs = HashSet::from([sub]);
3036
3037 processor
3038 .restore_remote_subscriptions(&subs, conn_id, false)
3039 .await;
3040
3041 let msg = rx.recv().await.unwrap().unwrap();
3043 assert!(matches!(msg.get_type(), SubscribeType(_)));
3044
3045 let tracked = processor
3047 .forwarder()
3048 .get_subscriptions_forwarded_on_connection(conn_id);
3049 assert!(tracked.is_empty());
3050 }
3051}