1use std::collections::{HashMap, HashSet};
5use std::net::SocketAddr;
6use std::pin::Pin;
7use std::sync::Arc;
8
9use crate::api::DataPlaneServiceServer;
10use display_error_chain::ErrorChainExt;
11use parking_lot::RwLock;
12use slim_config::client::ClientConfig;
13use slim_config::client::TransportChannel;
14use slim_config::component::configuration::Configuration;
15use slim_config::server::ServerConfig;
16use slim_config::server_handler::ServerHandler;
17use slim_config::websocket::server as websocket_server;
18use slim_config::websocket::server::AcceptedWebSocketConnection;
19use tokio::sync::mpsc::{self, Sender};
20use tokio::sync::oneshot;
21use tokio::task::JoinHandle;
22use tokio_stream::wrappers::ReceiverStream;
23use tokio_stream::{Stream, StreamExt};
24use tokio_util::sync::CancellationToken;
25
26use tonic::{Request, Response, Status};
27use tracing::{Instrument, debug, error, info, warn};
28
29#[cfg(feature = "otel_tracing")]
30use crate::otel_tracing;
31
32use crate::api::ProtoPublishType as PublishType;
33use crate::api::ProtoSubscribeType as SubscribeType;
34use crate::api::ProtoSubscriptionAckType as SubscriptionAckType;
35use crate::api::ProtoUnsubscribeType as UnsubscribeType;
36use crate::api::proto::dataplane::v1::Message;
37
38use crate::api::proto::dataplane::v1::data_plane_service_client::DataPlaneServiceClient;
39use crate::api::proto::dataplane::v1::data_plane_service_server::DataPlaneService;
40use crate::api::{
41 LinkNegotiationPayload, ProtoLink, ProtoLinkMessageType as LinkType, ProtoLinkType, ProtoName,
42};
43use crate::connection::{Channel, Connection};
44use crate::errors::{DataPathError, MessageContext};
45use crate::forwarder::Forwarder;
46use crate::messages::utils::SlimHeaderFlags;
47use crate::sync::peer as sync_peer;
48use crate::sync::remote::{RemoteSync, SubscriptionInfo};
49use crate::tables::connection_table::ConnectionTable;
50use crate::tables::subscription_table::SubscriptionTableImpl;
51use crate::tables::{ConnType, MatchFilter};
52use crate::websocket;
53
/// Result of applying one subscribe/unsubscribe to the local forwarding state.
///
/// Produced by `update_subscription_state` and consumed by
/// `process_subscription` to decide where the subscription is propagated.
struct SubscriptionOutcome {
    /// Value returned by `Forwarder::on_subscription_msg` for this update.
    /// NOTE(review): presumably `true` when the update changed whether the name
    /// has any subscriber — confirm against the forwarder implementation.
    transition: bool,
    /// Whether the connection the subscription was applied to is a peer connection.
    is_peer_conn: bool,
    /// Connection the subscription should additionally be forwarded to, if any.
    forward_conn: Option<u64>,
}
63
// Test-only async mutex shared by the whole module.
// NOTE(review): presumably serializes tests that mutate process environment
// variables — confirm at the use sites.
#[cfg(test)]
static ENV_LOCK: tokio::sync::Mutex<()> = tokio::sync::Mutex::const_new(());
67
/// Shared state behind a [`MessageProcessor`] handle.
///
/// A single instance is created by `MessageProcessor::new_internal` and shared
/// through an `Arc` by every clone of the processor.
#[derive(Debug)]
struct MessageProcessorInternal {
    /// Connection and subscription state used to look up connections and to
    /// match publications against subscriptions.
    forwarder: Forwarder<Connection>,

    /// Drain signal; taken (left as `None`) by `signal_drain` and `shutdown`.
    drain_signal: parking_lot::RwLock<Option<drain::Signal>>,

    /// Drain watch handed out (cloned) by `get_drain_watch`; cleared on drain/shutdown.
    drain_watch: parking_lot::RwLock<Option<drain::Watch>>,

    /// Sender towards the control plane; set by `register_local_connection`
    /// the first time it is called with `from_control_plane == true`.
    tx_control_plane: RwLock<Option<Sender<Result<Message, Status>>>>,

    /// Remote subscription tracking used when restoring subscriptions after a reconnect.
    remote_sync: RemoteSync,

    /// Identifier of this node; compared with the remote node id in
    /// `handle_peer_upgrade` and attached to tracing spans.
    service_id: String,

    /// Deployment name of this node. When non-empty, peer upgrades from a
    /// different deployment name are rejected.
    deployment_name: String,

    /// `require_header_mac` value applied to inbound (server-side) connections.
    server_require_header_mac: bool,

    /// Link negotiation timeout (5 seconds unless built from a `ServerConfig`).
    /// NOTE(review): the consumer of this value is outside this chunk.
    negotiation_timeout: std::time::Duration,

    /// When `false`, publications received from peer connections are matched
    /// with `MatchFilter::EXCLUDE_PEER`; when `true`, with `MatchFilter::ALL`.
    relay_peer_publishes: bool,

    /// Peer synchronization state; starts as `PeerSync::standalone()`.
    peer_sync: parking_lot::RwLock<crate::sync::PeerSync>,
}
108
/// Cheaply cloneable handle to the data-plane message processor.
///
/// All clones share the same [`MessageProcessorInternal`] through an `Arc`.
#[derive(Debug, Clone)]
pub struct MessageProcessor {
    /// Shared processor state.
    internal: Arc<MessageProcessorInternal>,
}
113
impl Default for MessageProcessor {
    /// Builds a processor with an empty service id and the defaults applied by
    /// `MessageProcessor::new_with_service_id`.
    fn default() -> Self {
        Self::new_with_service_id(String::new())
    }
}
119
/// Describes how the connection backing a stream is set up when the stream is
/// handed to `process_stream`.
enum StreamSetup {
    /// The connection is already registered in the connection table under
    /// this index (used for local connections).
    Registered(u64),
    /// The connection has not been registered yet (used for remote and
    /// WebSocket connections).
    Pending {
        /// The connection to register. Boxed, presumably to keep the enum small.
        connection: Box<Connection>,
        /// Index of an existing connection to reuse, if any; `try_to_connect`
        /// passes the old index here on reconnect.
        existing_index: Option<u64>,
    },
}
133
134impl MessageProcessor {
135 pub fn new_with_service_id(service_id: String) -> Self {
136 Self::new_internal(
137 service_id,
138 String::new(),
139 false,
140 std::time::Duration::from_secs(5),
141 false,
142 )
143 }
144
145 pub fn new_with_server_config(
147 service_id: String,
148 deployment_name: String,
149 server_config: &ServerConfig,
150 relay_peer_publishes: bool,
151 ) -> Self {
152 Self::new_internal(
153 service_id,
154 deployment_name,
155 server_config.require_header_mac,
156 std::time::Duration::from_secs(server_config.negotiation_timeout_secs),
157 relay_peer_publishes,
158 )
159 }
160
161 fn new_internal(
162 service_id: String,
163 deployment_name: String,
164 server_require_header_mac: bool,
165 negotiation_timeout: std::time::Duration,
166 relay_peer_publishes: bool,
167 ) -> Self {
168 let (signal, watch) = drain::channel();
169 let internal = MessageProcessorInternal {
170 forwarder: Forwarder::new(),
171 drain_signal: RwLock::new(Some(signal)),
172 drain_watch: RwLock::new(Some(watch)),
173 tx_control_plane: RwLock::new(None),
174 remote_sync: RemoteSync::default(),
175 service_id,
176 deployment_name,
177 server_require_header_mac,
178 negotiation_timeout,
179 relay_peer_publishes,
180 peer_sync: parking_lot::RwLock::new(crate::sync::PeerSync::standalone()),
181 };
182 Self {
183 internal: Arc::new(internal),
184 }
185 }
186
    /// Creates a processor with default settings; equivalent to
    /// `MessageProcessor::default()`.
    pub fn new() -> Self {
        Self::default()
    }
190
191 pub async fn run_server(
196 &self,
197 config: &ServerConfig,
198 ) -> Result<CancellationToken, DataPathError> {
199 debug!(%config, "starting dataplane server");
200
201 if config.require_header_mac != self.internal.server_require_header_mac {
202 warn!(
203 configured = config.require_header_mac,
204 processor = self.internal.server_require_header_mac,
205 "server require_header_mac differs from MessageProcessor; inbound connections use the processor value set at construction (prefer MessageProcessor::new_with_server_config)",
206 );
207 }
208
209 let watch = self.get_drain_watch()?;
210 config
211 .run_server(watch, Arc::new(self.clone()))
212 .await
213 .map_err(Into::into)
214 }
215
216 async fn handle_websocket_accepted(&self, accepted: AcceptedWebSocketConnection) {
217 let cancellation_token = CancellationToken::new();
218 let streams =
219 websocket::spawn_transport_tasks(accepted.websocket, cancellation_token.clone());
220
221 let connection = Connection::new(ConnType::Remote, Channel::Client(streams.outbound))
222 .with_remote_addr(accepted.remote_addr)
223 .with_local_addr(accepted.local_addr)
224 .with_require_header_mac(self.internal.server_require_header_mac)
225 .with_cancellation_token(Some(cancellation_token.clone()));
226
227 debug!(
228 remote = ?connection.remote_addr(),
229 local = ?connection.local_addr(),
230 "new websocket connection received from remote",
231 );
232 info!(telemetry = true, counter.num_active_connections = 1);
233
234 if let Err(err) = self.process_stream(
235 streams.inbound,
236 StreamSetup::Pending {
237 connection: Box::new(connection),
238 existing_index: None,
239 },
240 None,
241 cancellation_token,
242 ConnType::Remote,
243 false,
244 ) {
245 error!(error = %err.chain(), "error starting websocket processing stream");
246 }
247 }
248
    /// Takes the drain signal and the drain watch out of the processor and
    /// drops them, without waiting for anything.
    ///
    /// After this call `get_drain_watch` and `shutdown` return
    /// `AlreadyClosedError`.
    /// NOTE(review): dropping the `drain::Signal` presumably notifies watchers
    /// without waiting for them to finish — confirm against the `drain` crate.
    pub fn signal_drain(&self) {
        self.internal.drain_signal.write().take();
        self.internal.drain_watch.write().take();
    }
258
259 pub async fn shutdown(&self) -> Result<(), DataPathError> {
260 let signal = self
262 .internal
263 .drain_signal
264 .write()
265 .take()
266 .ok_or(DataPathError::AlreadyClosedError)?;
267
268 self.internal.drain_watch.write().take();
270
271 signal.drain().await;
273
274 Ok(())
275 }
276
277 fn set_tx_control_plane(&self, tx: Sender<Result<Message, Status>>) {
278 let mut tx_guard = self.internal.tx_control_plane.write();
279 *tx_guard = Some(tx);
280 }
281
282 fn get_tx_control_plane(&self) -> Option<Sender<Result<Message, Status>>> {
283 let tx_guard = self.internal.tx_control_plane.read();
284 tx_guard.clone()
285 }
286
    /// Returns the forwarder holding this processor's connection and
    /// subscription state.
    pub fn forwarder(&self) -> &Forwarder<Connection> {
        &self.internal.forwarder
    }
290
    /// Returns the remote subscription synchronization state.
    pub(crate) fn remote_sync(&self) -> &RemoteSync {
        &self.internal.remote_sync
    }
294
    /// Verifies the SLIM header MAC of `message` received on `conn_index`.
    ///
    /// Only connections of type `Remote` or `Edge` are checked; messages on any
    /// other connection type are accepted without inspection.
    ///
    /// # Errors
    /// - `ConnectionNotFound` if `conn_index` is not a known connection.
    /// - `UnknownMsgType` if the message carries no SLIM header.
    /// - `NegotiationError` when `enforce_strict_verification` is set and either
    ///   a subscribe/unsubscribe carries no MAC or no HMAC session is installed
    ///   on the connection.
    /// - `HeaderMacAwaitingLinkNegotiation` when the MAC cannot be checked yet:
    ///   a publish carrying a MAC arrived before the HMAC session was installed,
    ///   or the connection has no (non-empty) link id.
    /// - `HeaderIntegrity` when the MAC verification itself fails.
    pub(crate) fn verify_remote_header_mac(
        &self,
        conn_index: u64,
        message: &Message,
        enforce_strict_verification: bool,
    ) -> Result<(), DataPathError> {
        let conn = self
            .forwarder()
            .get_connection(conn_index)
            .ok_or(DataPathError::ConnectionNotFound(conn_index))?;
        // Header MACs only apply to remote and edge links.
        if !matches!(conn.connection_type(), ConnType::Remote | ConnType::Edge) {
            return Ok(());
        }
        let header = message
            .try_get_slim_header()
            .ok_or(DataPathError::UnknownMsgType)?;

        // A present-but-empty MAC is treated the same as a missing one.
        let has_wire_mac = header.header_mac.as_ref().is_some_and(|m| !m.is_empty());

        // Unsigned subscribe/unsubscribe messages are tolerated unless strict
        // verification is requested.
        if (message.is_subscribe() || message.is_unsubscribe()) && !has_wire_mac {
            if enforce_strict_verification {
                return Err(DataPathError::NegotiationError(
                    "empty HMAC is not allowed in strict verification mode".to_string(),
                ));
            }
            return Ok(());
        }

        // No HMAC session on this connection: nothing to verify against.
        let Some(mac) = conn.header_hmac() else {
            if enforce_strict_verification {
                return Err(DataPathError::NegotiationError(
                    "strict header MAC required but link HMAC session is not installed".to_string(),
                ));
            }
            // A signed publish cannot be accepted unverified; report that the
            // link negotiation has not completed yet.
            if message.is_publish() && has_wire_mac {
                return Err(DataPathError::HeaderMacAwaitingLinkNegotiation(conn_index));
            }
            return Ok(());
        };
        // The link id is an input of the verification, so it must be known and non-empty.
        let link_id = conn
            .link_id()
            .filter(|id| !id.is_empty())
            .ok_or(DataPathError::HeaderMacAwaitingLinkNegotiation(conn_index))?;
        mac.verify_slim_header(header, &link_id)
            .map_err(DataPathError::HeaderIntegrity)
    }
349
350 pub(crate) fn get_drain_watch(&self) -> Result<drain::Watch, DataPathError> {
351 self.internal
352 .drain_watch
353 .read()
354 .clone()
355 .ok_or(DataPathError::AlreadyClosedError)
356 }
357
    /// Restores the subscriptions in `remote_subs` on connection `conn_index`
    /// by delegating to `RemoteSync::restore`.
    ///
    /// `restore_tracking` is passed through unchanged.
    /// NOTE(review): presumably controls whether the tracking state is rebuilt
    /// as well — confirm in `RemoteSync::restore`.
    async fn restore_remote_subscriptions(
        &self,
        remote_subs: &HashSet<SubscriptionInfo>,
        conn_index: u64,
        restore_tracking: bool,
    ) {
        self.remote_sync()
            .restore(self, remote_subs, conn_index, restore_tracking)
            .await;
    }
370
371 async fn try_to_connect(
372 &self,
373 client_config: ClientConfig,
374 local: Option<SocketAddr>,
375 remote: Option<SocketAddr>,
376 existing_conn_index: Option<u64>,
377 ) -> Result<(JoinHandle<()>, u64), DataPathError> {
378 client_config.validate()?;
379
380 let mut watch = std::pin::pin!(self.get_drain_watch()?.signaled());
381 let channel = tokio::select! {
382 _ = &mut watch => {
383 return Err(DataPathError::ShuttingDownError);
384 }
385 res = client_config.to_channel() => {
386 res?
387 }
388 };
389
390 let cancellation_token = CancellationToken::new();
391 let link_id = client_config.link_id.clone();
392
393 match channel {
394 TransportChannel::Grpc(grpc_channel) => {
395 let mut client = DataPlaneServiceClient::new(grpc_channel);
396 let (tx, rx) = mpsc::channel(128);
397 let stream = client
398 .open_channel(Request::new(ReceiverStream::new(rx)))
399 .await?;
400
401 let (handle, conn_index_rx) = self.register_remote_connection(
402 stream.into_inner(),
403 Channel::Client(tx),
404 &client_config,
405 local,
406 remote,
407 existing_conn_index,
408 cancellation_token,
409 Some(link_id.clone()),
410 )?;
411
412 let conn_index = conn_index_rx.await.map_err(|_| {
413 DataPathError::NegotiationError(
414 "negotiation task terminated unexpectedly".to_string(),
415 )
416 })??;
417
418 if matches!(client_config.connection_type, ConnType::Peer) {
421 let fwd = self.peer_sync();
422 if !fwd.has_peer_state() {
423 fwd.add_peer_conn_and_sync(self, conn_index);
424 }
425 }
426
427 Ok((handle, conn_index))
428 }
429 TransportChannel::Websocket(ws_channel) => {
430 let websocket = ws_channel
431 .take_websocket()
432 .expect("websocket channel already consumed");
433 let streams =
434 websocket::spawn_transport_tasks(websocket, cancellation_token.clone());
435
436 let (handle, conn_index_rx) = self.register_remote_connection(
437 streams.inbound,
438 Channel::Client(streams.outbound),
439 &client_config,
440 local.or(ws_channel.local_addr()),
441 remote.or(ws_channel.remote_addr()),
442 existing_conn_index,
443 cancellation_token,
444 Some(link_id.clone()),
445 )?;
446
447 let conn_index = conn_index_rx.await.map_err(|_| {
448 DataPathError::NegotiationError(
449 "negotiation task terminated unexpectedly".to_string(),
450 )
451 })??;
452
453 if matches!(client_config.connection_type, ConnType::Peer) {
456 let fwd = self.peer_sync();
457 if !fwd.has_peer_state() {
458 fwd.add_peer_conn_and_sync(self, conn_index);
459 }
460 }
461
462 Ok((handle, conn_index))
463 }
464 }
465 }
466
467 #[allow(clippy::too_many_arguments)]
473 fn register_remote_connection<S>(
474 &self,
475 inbound: S,
476 outbound: Channel,
477 client_config: &ClientConfig,
478 local: Option<SocketAddr>,
479 remote: Option<SocketAddr>,
480 existing_conn_index: Option<u64>,
481 cancellation_token: CancellationToken,
482 link_id: Option<String>,
483 ) -> Result<
484 (
485 JoinHandle<()>,
486 oneshot::Receiver<Result<u64, DataPathError>>,
487 ),
488 DataPathError,
489 >
490 where
491 S: Stream<Item = Result<Message, Status>> + Unpin + Send + 'static,
492 {
493 let mut connection = Connection::new(client_config.connection_type, outbound)
494 .with_local_addr(local)
495 .with_remote_addr(remote)
496 .with_config_data(Some(client_config.clone()))
497 .with_require_header_mac(client_config.require_header_mac)
498 .with_cancellation_token(Some(cancellation_token.clone()));
499 if let Some(link_id) = link_id {
500 connection = connection.with_link_id(link_id);
501 }
502
503 debug!(
504 remote = ?connection.remote_addr(),
505 local = ?connection.local_addr(),
506 ?client_config.connection_type,
507 "new connection initiated locally",
508 );
509
510 let (handle, conn_index_rx) = self.process_stream(
511 inbound,
512 StreamSetup::Pending {
513 connection: Box::new(connection),
514 existing_index: existing_conn_index,
515 },
516 Some(client_config.clone()),
517 cancellation_token,
518 client_config.connection_type,
519 false,
520 )?;
521
522 Ok((handle, conn_index_rx))
523 }
524
    /// Opens a new outbound connection described by `client_config`.
    ///
    /// `local` and `remote` are optional socket addresses recorded on the
    /// connection. Returns the handle of the stream-processing task and the
    /// index assigned to the connection.
    ///
    /// # Errors
    /// Propagates any error from `try_to_connect`.
    pub async fn connect(
        &self,
        client_config: ClientConfig,
        local: Option<SocketAddr>,
        remote: Option<SocketAddr>,
    ) -> Result<(JoinHandle<()>, u64), DataPathError> {
        self.try_to_connect(client_config, local, remote, None)
            .await
    }
534
535 pub fn disconnect(&self, conn: u64) -> Result<ClientConfig, DataPathError> {
536 let connection = match self.forwarder().get_connection(conn) {
537 Some(c) => c,
538 None => {
539 error!(%conn, "error handling disconnect: connection unknown");
540 return Err(DataPathError::DisconnectionError(conn));
541 }
542 };
543
544 let token = match connection.cancellation_token() {
545 Some(t) => t,
546 None => {
547 error!(%conn, "error handling disconnect: missing cancellation token");
548 return Err(DataPathError::DisconnectionError(conn));
549 }
550 };
551
552 token.cancel();
554
555 connection
556 .config_data()
557 .cloned()
558 .ok_or(DataPathError::DisconnectionError(conn))
559 }
560
561 #[tracing::instrument(skip_all, fields(service_id = %self.internal.service_id))]
562 pub fn register_local_connection(
563 &self,
564 from_control_plane: bool,
565 ) -> Result<
566 (
567 u64,
568 tokio::sync::mpsc::Sender<Result<Message, Status>>,
569 tokio::sync::mpsc::Receiver<Result<Message, Status>>,
570 ),
571 DataPathError,
572 > {
573 let (tx1, rx1) = mpsc::channel(512);
575
576 debug!("establishing new local app connection");
577
578 let (tx2, rx2) = mpsc::channel(512);
580
581 if from_control_plane && self.get_tx_control_plane().is_none() {
584 self.set_tx_control_plane(tx2.clone());
585 }
586
587 let cancellation_token = CancellationToken::new();
589 let connection = Connection::new(ConnType::Local, Channel::Server(tx2))
590 .with_cancellation_token(Some(cancellation_token.clone()));
591
592 let conn_id = self
594 .forwarder()
595 .on_connection_established(connection, None)
596 .unwrap();
597
598 debug!(%conn_id, "local connection established");
599 info!(telemetry = true, counter.num_active_connections = 1);
600
601 self.process_stream(
603 ReceiverStream::new(rx1),
604 StreamSetup::Registered(conn_id),
605 None,
606 cancellation_token,
607 ConnType::Local,
608 from_control_plane,
609 )?;
610
611 Ok((conn_id, tx1, rx2))
613 }
614
    /// Sends `msg` on connection `out_conn`.
    ///
    /// With the `otel_tracing` feature enabled the message is first passed to
    /// `otel_tracing::prepare_outbound_msg` (which is why the parameter is
    /// `mut` only in that configuration); delivery itself is done by
    /// `send_msg_raw`.
    ///
    /// # Errors
    /// Propagates any error from `send_msg_raw`.
    pub async fn send_msg(
        &self,
        #[cfg(feature = "otel_tracing")] mut msg: Message,
        #[cfg(not(feature = "otel_tracing"))] msg: Message,
        out_conn: u64,
    ) -> Result<(), DataPathError> {
        #[cfg(feature = "otel_tracing")]
        otel_tracing::prepare_outbound_msg(
            &mut msg,
            "send_message",
            &self.internal.service_id,
            otel_tracing::SpanTarget::Connection(out_conn),
        );
        self.send_msg_raw(msg, out_conn).await
    }
630
    /// Delivers `msg` on connection `out_conn` without any tracing preparation.
    ///
    /// Link and subscription-ack messages are sent as they are. For every other
    /// message the function, in order:
    /// 1. calls `clear_slim_header` on the message;
    /// 2. on `Remote`/`Edge` connections that require a header MAC, refuses to
    ///    send while no HMAC session is installed;
    /// 3. on `Remote`/`Edge` connections with an HMAC session, signs the SLIM
    ///    header using the link id of the connection (falling back to the one
    ///    in its client configuration);
    /// 4. on local server-channel connections, removes the header MAC.
    ///
    /// # Errors
    /// - `ConnectionNotFound` if `out_conn` is unknown.
    /// - `NegotiationError` if a header MAC is required but no HMAC session exists.
    /// - `HeaderMacAwaitingLinkNegotiation` if an HMAC session exists but no
    ///   non-empty link id is available.
    /// - `HeaderIntegrity` if signing fails.
    /// - `MessageProcessingError` (with source `ConnectionNotFound`) if the
    ///   channel to the connection is closed; the unsent message is attached.
    async fn send_msg_raw(&self, mut msg: Message, out_conn: u64) -> Result<(), DataPathError> {
        let connection = self.forwarder().get_connection(out_conn);
        match connection {
            Some(conn) => {
                if !msg.is_link() && !msg.is_subscription_ack() {
                    // NOTE(review): presumably resets per-hop header fields
                    // before the message leaves this node — confirm in
                    // `clear_slim_header`.
                    msg.clear_slim_header();
                }

                // Strict mode: never send unsigned data on a remote/edge link.
                if !msg.is_link()
                    && !msg.is_subscription_ack()
                    && matches!(conn.connection_type(), ConnType::Remote | ConnType::Edge)
                    && conn.require_header_mac()
                    && conn.header_hmac().is_none()
                {
                    return Err(DataPathError::NegotiationError(
                        "strict header MAC required but link HMAC session is not installed"
                            .to_string(),
                    ));
                }

                // Sign the header whenever an HMAC session is available on a
                // remote/edge link.
                if !msg.is_link()
                    && !msg.is_subscription_ack()
                    && matches!(conn.connection_type(), ConnType::Remote | ConnType::Edge)
                    && let Some(mac) = conn.header_hmac()
                {
                    // Prefer the link id stored on the connection, fall back
                    // to the configured one; an empty id counts as missing.
                    let link_id = conn
                        .link_id()
                        .or_else(|| conn.config_data().map(|c| c.link_id.clone()))
                        .filter(|id| !id.is_empty());
                    if let Some(ref id) = link_id {
                        let header = msg.get_slim_header_mut();

                        mac.sign_slim_header(header, id.as_str())
                            .map_err(DataPathError::HeaderIntegrity)?;

                        // Debug builds only: when SLIM_TEST_TAMPER_DESTINATION
                        // is set, modify the destination AFTER signing.
                        // NOTE(review): presumably lets integrity tests
                        // provoke a MAC verification failure on the receiver.
                        #[cfg(debug_assertions)]
                        if std::env::var("SLIM_TEST_TAMPER_DESTINATION").is_ok()
                            && let Some(dest) = header.destination.as_mut()
                            && let Some(sn) = dest.str_name.as_mut()
                        {
                            sn.str_component_2.push_str("-integrity-test-tamper");
                        }
                    } else {
                        return Err(DataPathError::HeaderMacAwaitingLinkNegotiation(out_conn));
                    }
                }

                // The MAC is removed before the message is handed to a local application.
                if !msg.is_link()
                    && !msg.is_subscription_ack()
                    && matches!(conn.channel(), Channel::Server(_))
                    && matches!(conn.connection_type(), ConnType::Local)
                {
                    msg.get_slim_header_mut().header_mac = None;
                }

                // Server channels carry `Result<Message, Status>`, client
                // channels carry bare messages. A send failure means the
                // receiving side is gone.
                match conn.channel() {
                    Channel::Server(s) => {
                        s.send(Ok(msg))
                            .await
                            .map_err(|e| DataPathError::MessageProcessingError {
                                source: Box::new(DataPathError::ConnectionNotFound(out_conn)),
                                msg: Box::new(e.0.unwrap_or_default()),
                            })
                    }
                    Channel::Client(s) => {
                        s.send(msg)
                            .await
                            .map_err(|e| DataPathError::MessageProcessingError {
                                source: Box::new(DataPathError::ConnectionNotFound(out_conn)),
                                msg: Box::new(e.0),
                            })
                    }
                }
            }
            None => Err(DataPathError::ConnectionNotFound(out_conn)),
        }
    }
712
713 async fn send_status(&self, conn_index: u64, status: Status) {
716 if let Some(conn) = self.forwarder().get_connection(conn_index)
717 && let Channel::Server(tx) = conn.channel()
718 {
719 let _ = tx.send(Err(status)).await;
720 }
721 }
722
723 async fn match_and_forward_msg(
724 &self,
725 #[cfg(feature = "otel_tracing")] mut msg: Message,
726 #[cfg(not(feature = "otel_tracing"))] msg: Message,
727 in_connection: u64,
728 fanout: u32,
729 filter: MatchFilter,
730 ) -> Result<(), DataPathError> {
731 let header = msg.get_slim_header();
732 debug!(name = %header.get_dst(), %fanout, "match and forward message");
733
734 if let Some(val) = msg.get_forward_to() {
737 debug!(conn = %val, "forwarding message to connection");
738 return self.send_msg(msg, val).await;
739 }
740
741 let encoded = header.get_encoded_dst();
742
743 match self
744 .forwarder()
745 .on_publish_msg_match(encoded, in_connection, fanout, filter)
746 {
747 Ok(out_vec) => {
748 let len = out_vec.len();
749 if len == 1 {
751 return self.send_msg(msg, out_vec[0]).await;
752 }
753
754 #[cfg(feature = "otel_tracing")]
755 otel_tracing::prepare_fanout_msg(
756 &mut msg,
757 "send_message",
758 &self.internal.service_id,
759 len as u32,
760 );
761
762 let mut i = 0usize;
763 while i < len - 1 {
764 self.send_msg_raw(msg.clone(), out_vec[i]).await?;
765 i += 1;
766 }
767 self.send_msg_raw(msg, out_vec[i]).await?;
768 Ok(())
769 }
770 Err(e) => {
771 debug!(name = %header.get_dst(), %fanout, error = %e, "no match for publish destination");
772 Err(DataPathError::MessageProcessingError {
773 source: Box::new(e),
774 msg: Box::new(msg),
775 })
776 }
777 }
778 }
779
780 async fn handle_link_message(
785 &self,
786 link: ProtoLink,
787 conn_index: u64,
788 category: ConnType,
789 ) -> Result<(), DataPathError> {
790 if category.is_local() {
791 debug!(%conn_index, "ignoring link message received on local connection");
792 return Ok(());
793 }
794 match link.link_type {
795 Some(ProtoLinkType::LinkNegotiation(payload)) => {
796 self.handle_link_negotiation(&payload, conn_index).await
797 }
798 None => {
799 debug!(%conn_index, "received link message with unset link_type");
800 Ok(())
801 }
802 }
803 }
804
    /// Handles a link negotiation message that arrives through the regular
    /// message path.
    ///
    /// As the log message states, such a message is treated as belonging to an
    /// already-negotiated connection: it is logged at debug level and
    /// otherwise ignored. Always returns `Ok(())`.
    async fn handle_link_negotiation(
        &self,
        payload: &LinkNegotiationPayload,
        in_connection: u64,
    ) -> Result<(), DataPathError> {
        debug!(
            %in_connection,
            link_id = %payload.link_id,
            is_reply = payload.is_reply,
            "ignoring link negotiation message on already-negotiated connection",
        );

        Ok(())
    }
824
825 pub(crate) async fn handle_peer_upgrade(
828 &self,
829 remote_node_id: &str,
830 remote_deployment_name: &str,
831 in_connection: u64,
832 link_id: &str,
833 ) -> Result<(), DataPathError> {
834 if remote_node_id == self.internal.service_id {
836 warn!(
837 %in_connection, %link_id,
838 "rejecting peer connection from self (same node_id)"
839 );
840 self.send_status(
841 in_connection,
842 Status::permission_denied("self-connection rejected: same node_id"),
843 )
844 .await;
845 let _ = self.disconnect(in_connection);
846 return Ok(());
847 }
848
849 if !self.internal.deployment_name.is_empty()
851 && remote_deployment_name != self.internal.deployment_name
852 {
853 warn!(
854 %in_connection, %link_id,
855 local_group = %self.internal.deployment_name,
856 remote_group = %remote_deployment_name,
857 "rejecting peer upgrade: deployment_name mismatch"
858 );
859 self.send_status(
860 in_connection,
861 Status::permission_denied("deployment_name mismatch"),
862 )
863 .await;
864 let _ = self.disconnect(in_connection);
865 return Ok(());
866 }
867
868 info!(
869 %in_connection, %link_id, %remote_node_id,
870 "upgrading server-side connection to Peer (negotiation)"
871 );
872 self.connection_table().update(in_connection, |conn| {
873 conn.set_connection_type(ConnType::Peer)
874 });
875
876 self.peer_sync()
877 .on_incoming_peer(self, remote_node_id.to_string(), in_connection);
878
879 Ok(())
880 }
881
882 async fn process_publish(
883 &self,
884 msg: Message,
885 in_connection: u64,
886 filter: MatchFilter,
887 ) -> Result<(), DataPathError> {
888 debug!(
889 %in_connection,
890 ?msg,
891 "received publication"
892 );
893
894 info!(
896 telemetry = true,
897 monotonic_counter.num_messages_by_type = 1,
898 method = "publish"
899 );
900 let fanout = msg.get_fanout();
905
906 self.match_and_forward_msg(msg, in_connection, fanout, filter)
907 .await
908 }
909
910 pub(crate) async fn send_subscription_ack(
911 &self,
912 in_connection: u64,
913 subscription_id: u64,
914 result: &Result<(), DataPathError>,
915 ) {
916 let (success, error_msg) = match result {
917 Ok(()) => (true, String::new()),
918 Err(e) => (false, e.to_string()),
919 };
920
921 let ack_msg =
922 Message::builder().build_subscription_ack(subscription_id, success, error_msg);
923
924 if let Err(e) = self.send_msg(ack_msg, in_connection).await {
925 error!(error = %e.chain(), "failed to send subscription ack");
926 }
927 }
928
929 fn update_subscription_state(
933 &self,
934 msg: &Message,
935 conn: u64,
936 forward: Option<u64>,
937 add: bool,
938 subscription_id: u64,
939 ) -> Result<SubscriptionOutcome, DataPathError> {
940 let dst = msg.get_dst();
941
942 let connection = if let Some(c) = self.forwarder().get_connection(conn) {
944 c
945 } else {
946 return Err(DataPathError::ConnectionNotFound(conn));
947 };
948
949 debug!(
950 %conn,
951 %dst,
952 is_local = connection.is_local_connection(),
953 "processing {}subscription state",
954 if add { "" } else { "un" }
955 );
956
957 let is_peer_conn = connection.is_peer_connection();
958
959 let transition = self.forwarder().on_subscription_msg(
960 dst,
961 conn,
962 connection.connection_type(),
963 add,
964 subscription_id,
965 )?;
966
967 Ok(SubscriptionOutcome {
968 transition,
969 is_peer_conn,
970 forward_conn: forward,
971 })
972 }
973
    /// Processes a subscribe (`add == true`) or unsubscribe message received
    /// on `in_connection`.
    ///
    /// The subscription is applied to the connection named by the header's
    /// `recv_from` field when present, otherwise to the header's incoming
    /// connection. Depending on the outcome the message is then forwarded to
    /// peers and/or an explicit forward connection by a spawned task (which is
    /// also given the information needed to ack), or acknowledged immediately.
    ///
    /// Acks are only sent when the message carries a subscription id, and
    /// always on `in_connection`.
    ///
    /// # Errors
    /// Returns `MessageProcessingError` if the target connection is unknown,
    /// or if the state update fails and the message carries no subscription id.
    /// When it does carry one, the failure is reported through an error ack
    /// and `Ok(())` is returned.
    async fn process_subscription(
        &self,
        msg: Message,
        in_connection: u64,
        add: bool,
    ) -> Result<(), DataPathError> {
        debug!(
            %in_connection,
            ?msg,
            "received {}subscription",
            if add { "" } else { "un" }
        );

        info!(
            telemetry = true,
            monotonic_counter.num_messages_by_type = 1,
            message_type = { if add { "subscribe" } else { "unsubscribe" } }
        );
        let subscription_id = msg.get_subscription_id();

        debug!(?subscription_id, "received subscription id");

        let header = msg.get_slim_header();

        // `recv_from`, when set, overrides the connection the subscription is applied to.
        let (in_conn, recv_from, forward) = header.get_connections();
        let in_conn = recv_from.unwrap_or(in_conn);

        // Never forward a subscription to a local connection. An unknown
        // forward connection is kept at this point.
        let forward = forward.filter(|&out| {
            self.forwarder()
                .get_connection(out)
                .map(|c| !c.is_local_connection())
                .unwrap_or(true)
        });

        let Some(connection) = self.forwarder().get_connection(in_conn) else {
            // The ack goes back on the connection the message actually arrived on.
            if let Some(id) = subscription_id {
                debug!(%in_conn, "connection not found, sending error ack");
                self.send_subscription_ack(
                    in_connection,
                    id,
                    &Err(DataPathError::ConnectionNotFound(in_conn)),
                )
                .await;
            }
            return Err(DataPathError::MessageProcessingError {
                source: Box::new(DataPathError::ConnectionNotFound(in_conn)),
                msg: Box::new(msg),
            });
        };

        // A `recv_from` pointing at a local connection is treated as a
        // loop-back: acknowledge and change nothing.
        if recv_from.is_some() && connection.is_local_connection() {
            if let Some(id) = subscription_id {
                debug!(%in_conn, "subscription looped back to local connection, acking ok");
                self.send_subscription_ack(in_connection, id, &Ok(())).await;
            }
            return Ok(());
        }

        // Id 0 stands for "no subscription id" from here on.
        let sub_id = subscription_id.unwrap_or(0);
        // Loop prevention: a subscribe whose id this node has already seen is
        // acknowledged and dropped.
        if add && sub_id != 0 && self.peer_sync().has_seen_sub_id(sub_id) {
            debug!(
                %in_conn,
                %sub_id,
                "dropping subscription already forwarded by this node (loop prevention)"
            );
            if let Some(id) = subscription_id {
                self.send_subscription_ack(in_connection, id, &Ok(())).await;
            }
            return Ok(());
        }

        let outcome = match self.update_subscription_state(&msg, in_conn, forward, add, sub_id) {
            Ok(o) => o,
            Err(e) => {
                // With a subscription id the sender learns about the failure
                // through the ack, so the error is not propagated as well.
                if let Some(id) = subscription_id {
                    self.send_subscription_ack(in_connection, id, &Err(e)).await;
                    return Ok(());
                }
                return Err(DataPathError::MessageProcessingError {
                    source: Box::new(e),
                    msg: Box::new(msg),
                });
            }
        };

        let remaining_ttl = msg.get_ttl();

        // Peer propagation:
        // - not from a peer and the forwarder reported a transition: send to
        //   all peers with the configured subscription TTL;
        // - from a peer with a remaining TTL of at least 2: relay to the other
        //   peers with the remaining TTL;
        // - otherwise: no peer propagation.
        let (peer_target, peer_ttl) = if !outcome.is_peer_conn && outcome.transition {
            let ttl = self.peer_sync().subscription_ttl();
            (Some(crate::sync::PeerTarget::All), ttl)
        } else if outcome.is_peer_conn && remaining_ttl >= 2 {
            (
                Some(crate::sync::PeerTarget::ExcludeConn(in_conn)),
                remaining_ttl,
            )
        } else {
            (None, 0)
        };

        let targets = crate::sync::ForwardTargets {
            peers: peer_target,
            forward_conn: outcome.forward_conn,
        };

        if targets.has_any() {
            let fwd = self.peer_sync();
            let dst = msg.get_dst();
            debug!(
                %in_connection,
                %dst,
                %remaining_ttl,
                %peer_ttl,
                ?targets,
                "spawning subscription forwarder task"
            );
            // If the processor is already draining no task is spawned and the
            // code falls through to the immediate ack below.
            let drain = self.get_drain_watch().ok();
            if let Some(drain) = drain {
                fwd.spawn_forward_and_ack(
                    self.clone(),
                    msg,
                    dst,
                    sub_id,
                    add,
                    targets,
                    in_connection,
                    subscription_id,
                    peer_ttl,
                    drain,
                );
                return Ok(());
            }
        }

        if let Some(id) = subscription_id {
            debug!(%in_connection, "sending immediate subscription ack (no forwarding)");
            self.send_subscription_ack(in_connection, id, &Ok(())).await;
        }

        Ok(())
    }
1155
1156 pub async fn process_message(
1157 &self,
1158 msg: Message,
1159 in_connection: u64,
1160 category: ConnType,
1161 ) -> Result<(), DataPathError> {
1162 match msg.message_type {
1163 Some(SubscribeType(_)) => self.process_subscription(msg, in_connection, true).await,
1164 Some(UnsubscribeType(_)) => self.process_subscription(msg, in_connection, false).await,
1165 Some(PublishType(_)) => {
1166 let filter = match category {
1167 ConnType::Peer => {
1168 if self.internal.relay_peer_publishes {
1169 MatchFilter::ALL
1170 } else {
1171 MatchFilter::EXCLUDE_PEER
1172 }
1173 }
1174 _ => MatchFilter::ALL,
1175 };
1176 self.process_publish(msg, in_connection, filter).await
1177 }
1178 Some(LinkType(link)) => {
1179 self.handle_link_message(link, in_connection, category)
1180 .await
1181 }
1182 Some(SubscriptionAckType(ack)) => {
1183 let result = if ack.success {
1184 Ok(())
1185 } else {
1186 Err(DataPathError::RemoteSubscriptionAckError(ack.error))
1187 };
1188
1189 self.peer_sync().resolve_ack(ack.subscription_id, result);
1190 Ok(())
1191 }
1192 None => unreachable!(
1193 "message type not set; validate() must be called before process_message"
1194 ),
1195 }
1196 }
1197
1198 pub(crate) async fn handle_new_message(
1199 &self,
1200 conn_index: u64,
1201 category: ConnType,
1202 mut msg: Message,
1203 ) -> Result<(), DataPathError> {
1204 debug!(%conn_index, "received message from connection");
1205 info!(
1206 telemetry = true,
1207 monotonic_counter.num_processed_messages = 1
1208 );
1209
1210 if let Err(err) = msg.validate() {
1212 info!(
1213 telemetry = true,
1214 monotonic_counter.num_messages_by_type = 1,
1215 message_type = "none"
1216 );
1217
1218 let ret_err = DataPathError::MessageProcessingError {
1219 source: Box::new(err.into()),
1220 msg: Box::new(msg),
1221 };
1222
1223 return Err(ret_err);
1224 }
1225
1226 if !msg.is_link() && !msg.is_subscription_ack() {
1228 msg.set_incoming_conn(Some(conn_index));
1230
1231 if !category.is_local() && msg.decrement_ttl() == 0 {
1233 debug!(%conn_index, "dropping message: TTL expired");
1234 return Err(DataPathError::TtlExpired);
1235 }
1236
1237 #[cfg(feature = "otel_tracing")]
1238 otel_tracing::prepare_inbound_msg(
1239 &mut msg,
1240 "process_local",
1241 &self.internal.service_id,
1242 conn_index,
1243 category.is_local(),
1244 );
1245 }
1246
1247 match self.process_message(msg, conn_index, category).await {
1248 Ok(_) => Ok(()),
1249 Err(e) => {
1250 info!(
1252 telemetry = true,
1253 monotonic_counter.num_message_process_errors = 1
1254 );
1255 Err(e)
1259 }
1260 }
1261 }
1262
1263 #[tracing::instrument(skip_all, fields(service_id = %self.internal.service_id, conn_index))]
1264 async fn send_error_to_local_app(&self, conn_index: u64, err: DataPathError) {
1265 debug!(%conn_index, "sending error to local application");
1266 let connection = self.forwarder().get_connection(conn_index);
1267 match connection {
1268 Some(conn) => {
1269 debug!("try to notify the error to the local application");
1270 if let Channel::Server(tx) = conn.channel() {
1271 let session_ctx = match &err {
1273 DataPathError::MessageProcessingError { msg, .. } => {
1274 MessageContext::from_msg(msg)
1275 }
1276 _ => None,
1277 };
1278
1279 let payload = crate::errors::ErrorPayload::new(err.to_string(), session_ctx);
1281 let error_message = payload.to_json_string();
1282
1283 let status = Status::new(tonic::Code::Internal, error_message);
1285
1286 if tx.send(Err(status)).await.is_err() {
1287 debug!(error = %err.chain(), "unable to notify the error to the local app");
1288 }
1289 }
1290 }
1291 None => {
1292 error!(
1293 "error sending error to local app: connection {:?} not found",
1294 conn_index
1295 );
1296 }
1297 }
1298 }
1299
    /// Tries to re-establish the outbound connection `conn_index` using
    /// `client_conf`, reusing the same connection index.
    ///
    /// The attempt is abandoned as soon as `cancellation_token` is cancelled.
    /// After a successful reconnect, peer connections are sent a full sync via
    /// `sync_peer::send_local_remote_sync`; other connections get the remote
    /// subscriptions captured before the attempt restored.
    ///
    /// Returns `true` if the connection was re-established, `false` if the
    /// attempt failed or was cancelled.
    #[tracing::instrument(skip_all, fields(service_id = %self.internal.service_id, conn_index))]
    async fn reconnect(
        &self,
        client_conf: ClientConfig,
        conn_index: u64,
        cancellation_token: &CancellationToken,
    ) -> bool {
        info!("connection lost with remote endpoint, attempting to reconnect");

        // A connection that is no longer in the table is treated as non-peer.
        let is_peer = self
            .forwarder()
            .get_connection(conn_index)
            .map(|c| c.connection_type() == ConnType::Peer)
            .unwrap_or(false);

        // Capture the subscriptions to restore before reconnecting; peers use
        // a full sync instead, so nothing is captured for them.
        let remote_subscriptions = if !is_peer {
            self.remote_sync()
                .get_subscriptions_for_reconnect(conn_index)
        } else {
            Default::default()
        };

        tokio::select! {
            _ = cancellation_token.cancelled() => {
                debug!("cancellation token signaled, stopping reconnection process");
                false
            }
            res = self.try_to_connect(client_conf, None, None, Some(conn_index)) => {
                match res {
                    Ok(_) => {
                        info!("connection re-established successfully");
                        if is_peer {
                            let ttl = self.peer_sync().subscription_ttl();
                            // A failed sync is logged; the reconnect still
                            // counts as successful.
                            if let Err(e) = sync_peer::send_local_remote_sync(
                                self, conn_index, ttl,
                            )
                            .await
                            {
                                warn!(
                                    error = %e,
                                    "failed to send full sync after peer reconnect"
                                );
                            }
                        } else {
                            self.restore_remote_subscriptions(
                                &remote_subscriptions,
                                conn_index,
                                false,
                            )
                            .await;
                        }
                        true
                    }
                    Err(e) => {
                        error!(error = %e.chain(), "unable to reconnect to remote node");
                        false
                    }
                }
            }
        }
    }
1366
1367 async fn notify_control_plane_subscriptions_lost(
1373 tx_cp: Option<Sender<Result<Message, Status>>>,
1374 local_subs: HashMap<ProtoName, HashSet<u64>>,
1375 conn_index: u64,
1376 ) {
1377 let Some(tx) = tx_cp else { return };
1378 for local_sub in local_subs.into_keys() {
1379 debug!(
1380 %local_sub,
1381 "notify control plane about lost subscription",
1382 );
1383 let msg = Message::builder()
1384 .source(local_sub.clone())
1385 .destination(local_sub.clone())
1386 .flags(SlimHeaderFlags::default().with_recv_from(conn_index))
1387 .build_unsubscribe()
1388 .unwrap();
1389 if let Err(e) = tx.send(Ok(msg)).await {
1390 debug!(
1391 %local_sub,
1392 error = %e.chain(),
1393 "failed to send unsubscribe to control plane",
1394 );
1395 }
1396 }
1397 }
1398
1399 async fn resolve_connection(
1408 &self,
1409 stream: &mut (impl Stream<Item = Result<Message, Status>> + Unpin + Send),
1410 setup: StreamSetup,
1411 category: ConnType,
1412 conn_index_tx: oneshot::Sender<Result<u64, DataPathError>>,
1413 watch: &drain::Watch,
1414 token: &CancellationToken,
1415 ) -> Option<(u64, ConnType)> {
1416 match setup {
1417 StreamSetup::Registered(idx) => {
1418 let _ = conn_index_tx.send(Ok(idx));
1419 Some((idx, category))
1420 }
1421 StreamSetup::Pending {
1422 connection,
1423 existing_index,
1424 } => {
1425 self.negotiate_and_register(
1426 stream,
1427 *connection,
1428 existing_index,
1429 category,
1430 conn_index_tx,
1431 watch,
1432 token,
1433 )
1434 .await
1435 }
1436 }
1437 }
1438
1439 #[allow(clippy::too_many_arguments)]
1443 async fn negotiate_and_register(
1444 &self,
1445 stream: &mut (impl Stream<Item = Result<Message, Status>> + Unpin + Send),
1446 mut connection: Connection,
1447 existing_index: Option<u64>,
1448 category: ConnType,
1449 conn_index_tx: oneshot::Sender<Result<u64, DataPathError>>,
1450 watch: &drain::Watch,
1451 token: &CancellationToken,
1452 ) -> Option<(u64, ConnType)> {
1453 let timeout = self.internal.negotiation_timeout;
1454 let params = crate::negotiation::NegotiationParams {
1455 node_id: &self.internal.service_id,
1456 deployment_name: &self.internal.deployment_name,
1457 connection_type: category,
1458 };
1459
1460 let negotiation_result = tokio::select! {
1461 result = tokio::time::timeout(
1462 timeout,
1463 crate::negotiation::run_negotiation(&mut connection, stream, ¶ms),
1464 ) => match result {
1465 Ok(r) => r,
1466 Err(_) => Err(DataPathError::NegotiationError(
1467 "timed out waiting for link negotiation".to_string(),
1468 )),
1469 },
1470 _ = watch.clone().signaled() => {
1471 info!("shutting down during link negotiation");
1472 let _ = conn_index_tx.send(Err(DataPathError::ShuttingDownError));
1473 return None;
1474 }
1475 _ = token.cancelled() => {
1476 info!("connection cancelled during link negotiation");
1477 let _ = conn_index_tx.send(Err(DataPathError::ShuttingDownError));
1478 return None;
1479 }
1480 };
1481
1482 let result = match negotiation_result {
1483 Ok(r) => r,
1484 Err(e) => {
1485 error!(error = %e.chain(), "link negotiation failed, closing connection");
1486 let _ = conn_index_tx.send(Err(e));
1487 info!(telemetry = true, counter.num_active_connections = -1);
1488 return None;
1489 }
1490 };
1491
1492 let idx = match self
1494 .forwarder()
1495 .on_connection_established(connection, existing_index)
1496 {
1497 Some(idx) => idx,
1498 None => {
1499 let _ = conn_index_tx.send(Err(DataPathError::ConnectionTableAddError));
1500 info!(telemetry = true, counter.num_active_connections = -1);
1501 return None;
1502 }
1503 };
1504
1505 debug!(%idx, "connection registered after link negotiation");
1506
1507 let category = match result.connection_type {
1509 ConnType::Peer => {
1510 let link_id = self
1511 .forwarder()
1512 .get_connection(idx)
1513 .and_then(|c| c.link_id())
1514 .unwrap_or_default();
1515 if let Err(e) = self
1516 .handle_peer_upgrade(
1517 &result.remote_node_id,
1518 &result.remote_deployment_name,
1519 idx,
1520 &link_id,
1521 )
1522 .await
1523 {
1524 error!(error = %e.chain(), "peer upgrade failed after negotiation");
1525 let _ = conn_index_tx.send(Err(e));
1526 info!(telemetry = true, counter.num_active_connections = -1);
1527 return None;
1528 }
1529 ConnType::Peer
1530 }
1531 other => other,
1532 };
1533
1534 let _ = conn_index_tx.send(Ok(idx));
1535 Some((idx, category))
1536 }
1537
1538 fn process_stream(
1539 &self,
1540 mut stream: impl Stream<Item = Result<Message, Status>> + Unpin + Send + 'static,
1541 setup: StreamSetup,
1542 client_config: Option<ClientConfig>,
1543 cancellation_token: CancellationToken,
1544 category: ConnType,
1545 from_control_plane: bool,
1546 ) -> Result<
1547 (
1548 JoinHandle<()>,
1549 oneshot::Receiver<Result<u64, DataPathError>>,
1550 ),
1551 DataPathError,
1552 > {
1553 let self_clone = self.clone();
1555 let token_clone = cancellation_token.clone();
1556 let client_conf_clone = client_config.clone();
1557 let tx_cp: Option<Sender<Result<Message, Status>>> = self.get_tx_control_plane();
1558 let watch = self.get_drain_watch()?;
1559 let is_local = category.is_local();
1560
1561 let (conn_index_tx, conn_index_rx) = oneshot::channel();
1562
1563 let require_header_mac = match &setup {
1564 StreamSetup::Registered(idx) => self
1565 .forwarder()
1566 .get_connection(*idx)
1567 .map(|c| c.require_header_mac())
1568 .unwrap_or(false),
1569 StreamSetup::Pending { connection, .. } => connection.require_header_mac(),
1570 };
1571
1572 let span = tracing::info_span!(
1573 "process_stream",
1574 service_id = %self.internal.service_id,
1575 conn_index = match &setup {
1576 StreamSetup::Registered(idx) => *idx,
1577 _ => 0,
1578 },
1579 is_local,
1580 );
1581
1582 let handle = tokio::spawn(async move {
1583 let mut try_to_reconnect = true;
1584
1585 let Some((conn_index, category)) = self_clone
1588 .resolve_connection(
1589 &mut stream,
1590 setup,
1591 category,
1592 conn_index_tx,
1593 &watch,
1594 &token_clone,
1595 )
1596 .await
1597 else {
1598 return;
1599 };
1600
1601 let mut watch = std::pin::pin!(watch.signaled());
1602 loop {
1603 tokio::select! {
1604 next = stream.next() => {
1605 match next {
1606 Some(result) => {
1607 match result {
1608 Ok(msg) => {
1609 if !is_local
1610 && !msg.is_link()
1611 && !msg.is_subscription_ack()
1612 && let Err(e) = self_clone
1613 .verify_remote_header_mac(conn_index, &msg, require_header_mac)
1614 {
1615 error!(
1616 %conn_index,
1617 error = %e.chain(),
1618 "SLIM header integrity verification failed",
1619 );
1620 continue;
1621 }
1622 if !is_local && !from_control_plane && let Some(txcp) = &tx_cp {
1628 match msg.get_type() {
1629 PublishType(_) | LinkType(_) | SubscriptionAckType(_) => {}
1630 _ => {
1631 let _ = txcp.send(Ok(msg.clone())).await;
1634 }
1635 }
1636 }
1637
1638 if let Err(e) = self_clone.handle_new_message(conn_index, category, msg).await {
1639 if matches!(e, DataPathError::NegotiationError(_)) {
1641 error!(%conn_index, "fatal link negotiation error, closing connection");
1642 try_to_reconnect = false;
1643 break;
1644 }
1645 debug!(%conn_index, error = %e.chain(), "error processing incoming message");
1646 if is_local {
1648 self_clone.send_error_to_local_app(conn_index, e).await;
1650 }
1651 }
1652 }
1653 Err(e) => {
1654 if e.code() == tonic::Code::PermissionDenied {
1655 warn!(
1656 %conn_index,
1657 message = %e.message(),
1658 "connection rejected by remote, will not reconnect"
1659 );
1660 try_to_reconnect = false;
1661 } else if let Some(io_err) = MessageProcessor::match_for_io_error(&e) {
1662 if io_err.kind() == std::io::ErrorKind::BrokenPipe {
1663 info!(%conn_index, "connection closed by peer");
1664 }
1665 } else {
1666 error!(error = %e.chain(), "error receiving messages");
1667 }
1668 break;
1669 }
1670 }
1671 }
1672 None => {
1673 debug!(%conn_index, "end of stream");
1674 break;
1675 }
1676 }
1677 }
1678 _ = &mut watch => {
1679 info!(%conn_index, "shutting down stream on drain");
1680 try_to_reconnect = false;
1681 break;
1682 }
1683 _ = token_clone.cancelled() => {
1684 info!(%conn_index, "shutting down stream on cancellation token");
1685 try_to_reconnect = false;
1686 break;
1687 }
1688 }
1689 }
1690
1691 drop(stream);
1695
1696 let mut connected = false;
1697
1698 if try_to_reconnect
1699 && !matches!(category, ConnType::Remote)
1700 && let Some(config) = client_conf_clone
1701 {
1702 connected = self_clone.reconnect(config, conn_index, &token_clone)
1706 .instrument(tracing::Span::none())
1707 .await;
1708 } else {
1709 debug!(%conn_index, "close connection")
1710 }
1711
1712 if !connected {
1713 let local_subs = self_clone
1715 .forwarder()
1716 .on_connection_drop(conn_index, category);
1717 let _remote_subs = self_clone
1718 .remote_sync()
1719 .on_connection_drop(conn_index);
1720
1721 if matches!(category, ConnType::Peer) {
1723 self_clone.peer_sync().remove_peer_conn(conn_index);
1724 }
1725
1726 {
1731 let fwd = self_clone.peer_sync();
1732 for name in local_subs.keys() {
1733 let still_reachable = name.name.is_some_and(|enc| {
1734 self_clone
1735 .forwarder()
1736 .on_publish_msg_match(enc, u64::MAX, u32::MAX, MatchFilter::ALL)
1737 .is_ok()
1738 });
1739 if !still_reachable {
1740 debug!(
1741 %name,
1742 %conn_index,
1743 ?category,
1744 "notifying peers of unsubscription (connection drop)"
1745 );
1746 fwd.notify_peers_unsubscribe(&self_clone, name).await;
1747 } else {
1748 debug!(
1749 %name,
1750 %conn_index,
1751 ?category,
1752 "name still reachable, not emitting removal"
1753 );
1754 }
1755 }
1756 }
1757
1758
1759 if !is_local {
1761 MessageProcessor::notify_control_plane_subscriptions_lost(
1762 tx_cp, local_subs, conn_index,
1763 )
1764 .await;
1765 }
1766
1767 info!(telemetry = true, counter.num_active_connections = -1);
1768 }
1769 }.instrument(span));
1770
1771 Ok((handle, conn_index_rx))
1772 }
1773
1774 fn match_for_io_error(err_status: &Status) -> Option<&std::io::Error> {
1775 let mut err: &(dyn std::error::Error + 'static) = err_status;
1776
1777 loop {
1778 if let Some(io_err) = err.downcast_ref::<std::io::Error>() {
1779 return Some(io_err);
1780 }
1781
1782 if let Some(h2_err) = err.downcast_ref::<h2::Error>()
1785 && let Some(io_err) = h2_err.get_io()
1786 {
1787 return Some(io_err);
1788 }
1789
1790 err = err.source()?;
1791 }
1792 }
1793
1794 pub fn subscription_table(&self) -> &SubscriptionTableImpl {
1795 &self.internal.forwarder.subscription_table
1796 }
1797
1798 pub fn connection_table(&self) -> &ConnectionTable<Connection> {
1799 &self.internal.forwarder.connection_table
1800 }
1801
1802 pub fn service_id(&self) -> &str {
1804 &self.internal.service_id
1805 }
1806
1807 pub fn set_peer_sync(&self, peer_sync: crate::sync::PeerSync) {
1809 *self.internal.peer_sync.write() = peer_sync;
1810 }
1811
1812 pub(crate) fn peer_sync(&self) -> crate::sync::PeerSync {
1814 self.internal.peer_sync.read().clone()
1815 }
1816}
1817
1818impl ServerHandler for MessageProcessor {
1819 fn grpc_routes(&self) -> Option<tonic::service::Routes> {
1820 let svc = DataPlaneServiceServer::from_arc(Arc::new(self.clone()));
1821 Some(tonic::service::Routes::new(svc))
1822 }
1823
1824 fn on_websocket_accepted(&self) -> Option<websocket_server::OnAcceptedWebSocket> {
1825 let processor = self.clone();
1826 Some(Arc::new(move |accepted| {
1827 let processor = processor.clone();
1828 Box::pin(async move { processor.handle_websocket_accepted(accepted).await })
1829 }))
1830 }
1831}
1832
1833#[tonic::async_trait]
1834impl DataPlaneService for MessageProcessor {
1835 type OpenChannelStream = Pin<Box<dyn Stream<Item = Result<Message, Status>> + Send + 'static>>;
1836
1837 async fn open_channel(
1838 &self,
1839 request: Request<tonic::Streaming<Message>>,
1840 ) -> Result<Response<Self::OpenChannelStream>, Status> {
1841 let remote_addr = request.remote_addr();
1842 let local_addr = request.local_addr();
1843
1844 let stream = request.into_inner();
1845 let (tx, rx) = mpsc::channel(128);
1846
1847 let connection = Connection::new(ConnType::Remote, Channel::Server(tx))
1848 .with_remote_addr(remote_addr)
1849 .with_local_addr(local_addr)
1850 .with_require_header_mac(self.internal.server_require_header_mac);
1851
1852 debug!(
1853 remote = ?connection.remote_addr(),
1854 local = ?connection.local_addr(),
1855 "new connection received from remote",
1856 );
1857 info!(telemetry = true, counter.num_active_connections = 1);
1858
1859 self.process_stream(
1860 stream,
1861 StreamSetup::Pending {
1862 connection: Box::new(connection),
1863 existing_index: None,
1864 },
1865 None,
1866 CancellationToken::new(),
1867 ConnType::Remote,
1868 false,
1869 )
1870 .map_err(|e| {
1871 error!(error = %e.chain(), "error starting new processing stream");
1872 Status::unavailable(format!("error processing stream: {:?}", e))
1873 })?;
1874
1875 let out_stream = ReceiverStream::new(rx);
1876 Ok(Response::new(
1877 Box::pin(out_stream) as Self::OpenChannelStream
1878 ))
1879 }
1880}
1881
1882#[cfg(test)]
1883mod tests {
1884 use std::time::Duration;
1885
1886 use super::*;
1887 use crate::api::{ProtoMessage, ProtoName, ProtoSubscriptionAck};
1888 use crate::header_mac::HeaderMacSession;
1889 use crate::sync::remote::SubscriptionInfo;
1890 use tonic::Status;
1891
1892 async fn assert_failed_subscription_ack_is_sent(add: bool) {
1893 let processor = MessageProcessor::new();
1894 let (in_connection, _tx, mut rx) = processor
1895 .register_local_connection(false)
1896 .expect("failed to create local connection");
1897
1898 let source = ProtoName::from_strings(["org", "ns", "source"]).with_id(1);
1899 let destination = ProtoName::from_strings(["org", "ns", "destination"]).with_id(2);
1900 let ack_id: u64 = if add { 1 } else { 2 };
1901 let invalid_connection = u64::MAX - 1;
1902
1903 let builder = Message::builder()
1904 .source(source.clone())
1905 .destination(destination.clone())
1906 .incoming_conn(invalid_connection)
1907 .subscription_id(ack_id);
1908
1909 let msg = if add {
1910 builder.build_subscribe().unwrap()
1911 } else {
1912 builder.build_unsubscribe().unwrap()
1913 };
1914
1915 let result = processor
1916 .process_subscription(msg, in_connection, add)
1917 .await;
1918 assert!(matches!(
1919 result,
1920 Err(DataPathError::MessageProcessingError { .. })
1921 ));
1922
1923 let ack_msg = tokio::time::timeout(Duration::from_secs(1), rx.recv())
1924 .await
1925 .expect("timeout waiting for ack")
1926 .expect("ack channel closed")
1927 .expect("failed to receive ack message");
1928
1929 assert!(matches!(ack_msg.get_type(), SubscriptionAckType(_)));
1930 let ack = ack_msg.get_subscription_ack();
1931 assert_eq!(ack.subscription_id, ack_id);
1932 assert!(!ack.success, "failed ack should have success=false");
1933 assert!(
1934 !ack.error.is_empty(),
1935 "failed ack should include an error message"
1936 );
1937 }
1938
1939 #[tokio::test]
1940 async fn test_process_subscription_sends_failed_ack_on_subscribe_error() {
1941 assert_failed_subscription_ack_is_sent(true).await;
1942 }
1943
1944 #[tokio::test]
1945 async fn test_process_subscription_sends_failed_ack_on_unsubscribe_error() {
1946 assert_failed_subscription_ack_is_sent(false).await;
1947 }
1948
1949 #[tokio::test]
1952 async fn test_handle_link_message_is_local_ignored() {
1953 let processor = MessageProcessor::new();
1954 let link = ProtoLink { link_type: None };
1955 assert!(
1956 processor
1957 .handle_link_message(link, 0, ConnType::Local)
1958 .await
1959 .is_ok()
1960 );
1961 }
1962
1963 #[tokio::test]
1964 async fn test_handle_link_message_none_link_type_ignored() {
1965 let processor = MessageProcessor::new();
1966 let link = ProtoLink { link_type: None };
1967 assert!(
1968 processor
1969 .handle_link_message(link, 0, ConnType::Remote)
1970 .await
1971 .is_ok()
1972 );
1973 }
1974
1975 #[tokio::test]
1981 async fn test_handle_link_negotiation_post_negotiation_is_noop() {
1982 let processor = MessageProcessor::new();
1983 let payload = LinkNegotiationPayload {
1984 link_id: uuid::Uuid::new_v4().to_string(),
1985 slim_version: "1.0.0".into(),
1986 is_reply: false,
1987 link_ecdh_public_key: vec![],
1988 connection_type: 0,
1989 node_id: String::new(),
1990 deployment_name: String::new(),
1991 };
1992 assert!(
1994 processor
1995 .handle_link_negotiation(&payload, u64::MAX)
1996 .await
1997 .is_ok()
1998 );
1999 let (conn_id, _rx) = make_negotiated_server_conn(&processor, "1.2.0");
2001 assert!(
2002 processor
2003 .handle_link_negotiation(&payload, conn_id)
2004 .await
2005 .is_ok()
2006 );
2007 }
2008
2009 fn make_negotiated_server_conn(
2014 processor: &MessageProcessor,
2015 version: &str,
2016 ) -> (u64, tokio::sync::mpsc::Receiver<Result<Message, Status>>) {
2017 let (tx, rx) = mpsc::channel(16);
2018 let conn = Connection::new(ConnType::Remote, Channel::Server(tx))
2019 .with_require_header_mac(processor.internal.server_require_header_mac)
2020 .with_negotiation(&uuid::Uuid::new_v4().to_string(), version)
2021 .with_header_hmac(HeaderMacSession::new(b"01234567890123456789012345678901").unwrap());
2022 let conn_id = processor
2023 .forwarder()
2024 .on_connection_established(conn, None)
2025 .unwrap();
2026 (conn_id, rx)
2027 }
2028
2029 #[tokio::test]
2030 async fn test_negotiation_timeout_configurable() {
2031 let server_config = ServerConfig {
2032 endpoint: "localhost:12345".to_string(),
2033 negotiation_timeout_secs: 1, ..Default::default()
2035 };
2036 let processor = MessageProcessor::new_with_server_config(
2037 "test_service".to_string(),
2038 String::new(),
2039 &server_config,
2040 false,
2041 );
2042
2043 assert_eq!(
2044 processor.internal.negotiation_timeout,
2045 std::time::Duration::from_secs(1)
2046 );
2047 }
2048
/// In strict mode, a publish on a connection without a MAC session is
/// rejected with a `NegotiationError`.
#[test]
fn verify_remote_header_mac_strict_rejects_publish_without_mac_session() {
    let processor = MessageProcessor::new();

    // Negotiated connection that deliberately has no header MAC session.
    let (tx, _rx) = mpsc::channel(16);
    let link_id = uuid::Uuid::new_v4().to_string();
    let conn = Connection::new(ConnType::Remote, Channel::Server(tx))
        .with_require_header_mac(true)
        .with_negotiation(&link_id, "1.2.0");
    let remote_conn = processor
        .forwarder()
        .on_connection_established(conn, None)
        .unwrap();
    let registered = processor.forwarder().get_connection(remote_conn).unwrap();
    assert!(registered.header_hmac().is_none());

    let publish = ProtoMessage::builder()
        .source(ProtoName::from_strings(["org", "default", "a"]).with_id(1))
        .destination(ProtoName::from_strings(["org", "default", "b"]).with_id(2))
        .application_payload("text/plain", b"hey".to_vec())
        .build_publish()
        .expect("publish");

    let verification = processor.verify_remote_header_mac(remote_conn, &publish, true);
    let err =
        verification.expect_err("unsigned publish must fail in strict mode without MAC session");
    assert!(matches!(err, DataPathError::NegotiationError(_)));
}
2078
/// A publish signed with the connection's key and link id passes verification.
#[test]
fn verify_remote_header_mac_accepts_signed_inter_node_publish() {
    let processor = MessageProcessor::new();
    let (remote_conn, _rx) = make_negotiated_server_conn(&processor, "1.2.0");
    let link_id = processor
        .forwarder()
        .get_connection(remote_conn)
        .unwrap()
        .link_id()
        .expect("link id after negotiation");

    let require_header_mac = true;
    let mut publish = ProtoMessage::builder()
        .source(ProtoName::from_strings(["org", "default", "a"]).with_id(1))
        .destination(ProtoName::from_strings(["org", "default", "b"]).with_id(2))
        .application_payload("text/plain", b"hey".to_vec())
        .build_publish()
        .expect("publish");

    // Same key material as `make_negotiated_server_conn` installs on the connection.
    let signer = HeaderMacSession::new(b"01234567890123456789012345678901").unwrap();
    signer
        .sign_slim_header(publish.get_slim_header_mut(), &link_id)
        .expect("sign header");

    let verification =
        processor.verify_remote_header_mac(remote_conn, &publish, require_header_mac);
    assert!(verification.is_ok());
}
2110
/// Changing the destination after signing makes verification fail with
/// `HeaderIntegrity`.
#[test]
fn verify_remote_header_mac_rejects_destination_tamper_after_sign() {
    let processor = MessageProcessor::new();
    let (remote_conn, _rx) = make_negotiated_server_conn(&processor, "1.2.0");
    let link_id = processor
        .forwarder()
        .get_connection(remote_conn)
        .unwrap()
        .link_id()
        .expect("link id after negotiation");

    let mut publish = ProtoMessage::builder()
        .source(ProtoName::from_strings(["org", "default", "a"]).with_id(1))
        .destination(ProtoName::from_strings(["org", "default", "b"]).with_id(2))
        .application_payload("text/plain", b"hey".to_vec())
        .build_publish()
        .expect("publish");

    let signer = HeaderMacSession::new(b"01234567890123456789012345678901").unwrap();
    let require_header_mac = true;
    signer
        .sign_slim_header(publish.get_slim_header_mut(), &link_id)
        .expect("sign header");

    // Tamper with the signed destination name.
    let signed_header = publish.get_slim_header_mut();
    let tamper_target = signed_header
        .destination
        .as_mut()
        .and_then(|destination| destination.str_name.as_mut());
    if let Some(str_name) = tamper_target {
        str_name.str_component_2.push_str("-integrity-test-tamper");
    }

    let verification =
        processor.verify_remote_header_mac(remote_conn, &publish, require_header_mac);
    let err = verification.expect_err("tampered header must fail MAC verify");
    assert!(matches!(err, DataPathError::HeaderIntegrity(_)));
}
2148
2149 #[tokio::test]
2150 #[allow(clippy::disallowed_methods)]
2151 async fn test_send_msg_raw_tamper_destination_env_var() {
2152 let _guard = ENV_LOCK.lock().await;
2153 unsafe {
2154 std::env::set_var("SLIM_TEST_TAMPER_DESTINATION", "1");
2155 }
2156
2157 let processor = MessageProcessor::new();
2158 let (conn_id, mut rx) = make_negotiated_server_conn(&processor, "1.2.0");
2159
2160 let source = ProtoName::from_strings(["org", "default", "a"]).with_id(1);
2161 let dest = ProtoName::from_strings(["org", "default", "b"]).with_id(2);
2162 let msg = ProtoMessage::builder()
2163 .source(source)
2164 .destination(dest)
2165 .application_payload("text/plain", b"hey".to_vec())
2166 .build_publish()
2167 .expect("publish");
2168
2169 processor
2170 .send_msg_raw(msg, conn_id)
2171 .await
2172 .expect("send_msg_raw failed");
2173
2174 let sent_msg = rx.recv().await.unwrap().unwrap();
2175 let header = sent_msg.get_slim_header();
2176 let dest_name = header.destination.as_ref().expect("destination");
2177 let str_name = dest_name.str_name.as_ref().expect("str_name");
2178 let require_header_mac = true;
2179
2180 assert!(str_name.str_component_2.ends_with("-integrity-test-tamper"));
2182
2183 let err = processor
2185 .verify_remote_header_mac(conn_id, &sent_msg, require_header_mac)
2186 .expect_err("tampered header must fail MAC verify");
2187 assert!(matches!(err, DataPathError::HeaderIntegrity(_)));
2188
2189 unsafe {
2190 std::env::remove_var("SLIM_TEST_TAMPER_DESTINATION");
2191 }
2192 }
2193
2194 #[tokio::test]
2195 async fn test_process_subscription_remote_ack_path_success() {
2196 let processor = MessageProcessor::new();
2199 let (local_conn, _tx_local, mut rx_local) = processor
2200 .register_local_connection(false)
2201 .expect("failed to create local connection");
2202
2203 let (remote_conn, mut rx_remote) = make_negotiated_server_conn(&processor, "1.2.0");
2204
2205 let source = ProtoName::from_strings(["org", "ns", "src"]).with_id(1);
2206 let destination = ProtoName::from_strings(["org", "ns", "dst"]).with_id(2);
2207 let upstream_ack_id: u64 = 100;
2208
2209 let sub_msg = Message::builder()
2211 .source(source.clone())
2212 .destination(destination.clone())
2213 .incoming_conn(local_conn)
2214 .forward_to(remote_conn)
2215 .subscription_id(upstream_ack_id)
2216 .build_subscribe()
2217 .unwrap();
2218
2219 let result = processor
2221 .process_subscription(sub_msg, local_conn, true)
2222 .await;
2223 assert!(result.is_ok());
2224
2225 let forwarded = tokio::time::timeout(Duration::from_secs(1), rx_remote.recv())
2228 .await
2229 .expect("timeout waiting for forwarded subscribe")
2230 .expect("forwarded subscribe channel closed")
2231 .unwrap();
2232 assert!(matches!(forwarded.get_type(), SubscribeType(_)));
2233
2234 let forwarded_sub_id = forwarded
2236 .get_subscription_id()
2237 .expect("forwarded subscribe must carry the same subscription_id");
2238 assert_eq!(
2239 forwarded_sub_id, upstream_ack_id,
2240 "subscription_id must not change when forwarding"
2241 );
2242
2243 let ack = ProtoSubscriptionAck {
2245 subscription_id: upstream_ack_id,
2246 success: true,
2247 error: String::new(),
2248 };
2249 processor.peer_sync().resolve_ack(
2250 ack.subscription_id,
2251 if ack.success {
2252 Ok(())
2253 } else {
2254 Err(DataPathError::RemoteSubscriptionAckError(ack.error.clone()))
2255 },
2256 );
2257
2258 let upstream_ack = tokio::time::timeout(Duration::from_secs(2), rx_local.recv())
2260 .await
2261 .expect("timeout waiting for upstream ack")
2262 .expect("upstream ack channel closed")
2263 .expect("upstream ack should be Ok");
2264
2265 assert!(matches!(upstream_ack.get_type(), SubscriptionAckType(_)));
2266 let ack_inner = upstream_ack.get_subscription_ack();
2267 assert_eq!(ack_inner.subscription_id, upstream_ack_id);
2268 assert!(ack_inner.success);
2269 }
2270
2271 #[tokio::test]
2272 async fn test_process_subscription_remote_ack_error_forwarded_upstream() {
2273 let processor = MessageProcessor::new();
2275 let (local_conn, _tx_local, mut rx_local) = processor
2276 .register_local_connection(false)
2277 .expect("failed to create local connection");
2278
2279 let (remote_conn, mut rx_remote) = make_negotiated_server_conn(&processor, "1.2.0");
2280
2281 let source = ProtoName::from_strings(["org", "ns", "src"]).with_id(1);
2282 let destination = ProtoName::from_strings(["org", "ns", "dst"]).with_id(2);
2283 let upstream_ack_id: u64 = 102;
2284
2285 let sub_msg = Message::builder()
2286 .source(source.clone())
2287 .destination(destination.clone())
2288 .incoming_conn(local_conn)
2289 .forward_to(remote_conn)
2290 .subscription_id(upstream_ack_id)
2291 .build_subscribe()
2292 .unwrap();
2293
2294 processor
2295 .process_subscription(sub_msg, local_conn, true)
2296 .await
2297 .unwrap();
2298
2299 let forwarded = tokio::time::timeout(Duration::from_secs(1), rx_remote.recv())
2300 .await
2301 .expect("timeout")
2302 .expect("channel closed")
2303 .unwrap();
2304
2305 let forwarded_sub_id = forwarded
2306 .get_subscription_id()
2307 .expect("forwarded subscribe must carry the same subscription_id");
2308 assert_eq!(
2309 forwarded_sub_id, upstream_ack_id,
2310 "subscription_id must not change when forwarding"
2311 );
2312
2313 let ack = ProtoSubscriptionAck {
2315 subscription_id: upstream_ack_id,
2316 success: false,
2317 error: "remote error".to_string(),
2318 };
2319 processor.peer_sync().resolve_ack(
2320 ack.subscription_id,
2321 if ack.success {
2322 Ok(())
2323 } else {
2324 Err(DataPathError::RemoteSubscriptionAckError(ack.error.clone()))
2325 },
2326 );
2327
2328 let upstream_ack = tokio::time::timeout(Duration::from_secs(2), rx_local.recv())
2329 .await
2330 .expect("timeout")
2331 .expect("channel closed")
2332 .expect("must be Ok");
2333
2334 assert!(matches!(upstream_ack.get_type(), SubscriptionAckType(_)));
2335 let ack_inner = upstream_ack.get_subscription_ack();
2336 assert_eq!(ack_inner.subscription_id, upstream_ack_id);
2337 assert!(!ack_inner.success);
2338 assert!(!ack_inner.error.is_empty());
2339 }
2340
2341 #[tokio::test]
2344 async fn test_notify_cp_subs_lost_sends_unsubscribes() {
2345 let (tx, mut rx) = mpsc::channel::<Result<Message, Status>>(16);
2346 let mut subs = HashMap::new();
2347 let name = ProtoName::from_strings(["org", "default", "svc"]);
2348 subs.insert(name.clone(), HashSet::from([1u64, 2u64]));
2349
2350 MessageProcessor::notify_control_plane_subscriptions_lost(Some(tx), subs, 42).await;
2351
2352 let msg = rx.recv().await.unwrap().unwrap();
2353 assert!(matches!(msg.get_type(), UnsubscribeType(_)));
2354 assert_eq!(msg.get_source(), name.clone());
2355 }
2356
2357 #[tokio::test]
2358 async fn test_notify_cp_subs_lost_no_tx_is_noop() {
2359 let subs = HashMap::from([(
2360 ProtoName::from_strings(["org", "default", "svc"]),
2361 HashSet::from([1u64]),
2362 )]);
2363 MessageProcessor::notify_control_plane_subscriptions_lost(None, subs, 1).await;
2365 }
2366
2367 #[tokio::test]
2368 async fn test_notify_cp_subs_lost_empty_subs() {
2369 let (tx, mut rx) = mpsc::channel::<Result<Message, Status>>(16);
2370 MessageProcessor::notify_control_plane_subscriptions_lost(Some(tx), HashMap::new(), 1)
2371 .await;
2372 assert!(rx.try_recv().is_err());
2374 }
2375
2376 #[tokio::test]
2379 async fn test_restore_remote_subscriptions_with_tracking() {
2380 let processor = MessageProcessor::new();
2381 let (conn_id, mut rx) = make_negotiated_server_conn(&processor, "1.2.0");
2382
2383 let source = ProtoName::from_strings(["org", "default", "src"]);
2384 let dest = ProtoName::from_strings(["org", "default", "dst"]);
2385 let sub = SubscriptionInfo::new(source.clone(), dest.clone(), "id1".into(), conn_id, 7);
2386 let subs = HashSet::from([sub]);
2387
2388 processor
2389 .restore_remote_subscriptions(&subs, conn_id, true)
2390 .await;
2391
2392 let msg = rx.recv().await.unwrap().unwrap();
2394 assert!(matches!(msg.get_type(), SubscribeType(_)));
2395
2396 let tracked = processor
2398 .remote_sync()
2399 .get_subscriptions_for_reconnect(conn_id);
2400 assert_eq!(tracked.len(), 1);
2401 }
2402
2403 #[tokio::test]
2404 async fn test_restore_remote_subscriptions_without_tracking() {
2405 let processor = MessageProcessor::new();
2406 let (conn_id, mut rx) = make_negotiated_server_conn(&processor, "1.2.0");
2407
2408 let source = ProtoName::from_strings(["org", "default", "src"]);
2409 let dest = ProtoName::from_strings(["org", "default", "dst"]);
2410 let sub = SubscriptionInfo::new(source.clone(), dest.clone(), "id1".into(), conn_id, 7);
2411 let subs = HashSet::from([sub]);
2412
2413 processor
2414 .restore_remote_subscriptions(&subs, conn_id, false)
2415 .await;
2416
2417 let msg = rx.recv().await.unwrap().unwrap();
2419 assert!(matches!(msg.get_type(), SubscribeType(_)));
2420
2421 let tracked = processor
2423 .remote_sync()
2424 .get_subscriptions_for_reconnect(conn_id);
2425 assert!(tracked.is_empty());
2426 }
2427}