Skip to main content

slim_datapath/
message_processing.rs

1// Copyright AGNTCY Contributors (https://github.com/agntcy)
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::{HashMap, HashSet};
5use std::net::SocketAddr;
6use std::pin::Pin;
7use std::sync::Arc;
8
9use crate::api::DataPlaneServiceServer;
10use display_error_chain::ErrorChainExt;
11use parking_lot::RwLock;
12use slim_config::client::ClientConfig;
13use slim_config::client::TransportChannel;
14use slim_config::component::configuration::Configuration;
15use slim_config::server::ServerConfig;
16use slim_config::server_handler::ServerHandler;
17use slim_config::websocket::server as websocket_server;
18use slim_config::websocket::server::AcceptedWebSocketConnection;
19use tokio::sync::mpsc::{self, Sender};
20use tokio::sync::oneshot;
21use tokio::task::JoinHandle;
22use tokio_stream::wrappers::ReceiverStream;
23use tokio_stream::{Stream, StreamExt};
24use tokio_util::sync::CancellationToken;
25
26use tonic::{Request, Response, Status};
27use tracing::{Instrument, debug, error, info, warn};
28
29#[cfg(feature = "otel_tracing")]
30use crate::otel_tracing;
31
32use crate::api::ProtoPublishType as PublishType;
33use crate::api::ProtoSubscribeType as SubscribeType;
34use crate::api::ProtoSubscriptionAckType as SubscriptionAckType;
35use crate::api::ProtoUnsubscribeType as UnsubscribeType;
36use crate::api::proto::dataplane::v1::Message;
37
38use crate::api::proto::dataplane::v1::data_plane_service_client::DataPlaneServiceClient;
39use crate::api::proto::dataplane::v1::data_plane_service_server::DataPlaneService;
40use crate::api::{
41    LinkNegotiationPayload, ProtoLink, ProtoLinkMessageType as LinkType, ProtoLinkType, ProtoName,
42};
43use crate::connection::{Channel, Connection};
44use crate::errors::{DataPathError, MessageContext};
45use crate::forwarder::Forwarder;
46use crate::messages::utils::SlimHeaderFlags;
47use crate::sync::peer as sync_peer;
48use crate::sync::remote::{RemoteSync, SubscriptionInfo};
49use crate::tables::connection_table::ConnectionTable;
50use crate::tables::subscription_table::SubscriptionTableImpl;
51use crate::tables::{ConnType, MatchFilter};
52use crate::websocket;
53
/// Result of updating subscription state (pure state change, no forwarding).
///
/// The value only describes what changed; acting on it (for example
/// forwarding the subscription elsewhere) is left to the caller.
struct SubscriptionOutcome {
    /// Whether an aggregate transition occurred (0→1 or 1→0).
    transition: bool,
    /// Whether the source connection is a peer connection.
    is_peer_conn: bool,
    /// The forward-to connection (controller), if any.
    forward_conn: Option<u64>,
}
63
// Serializes tests that rely on process environment variables.
// Only compiled for test builds; the mutex guards no data (`()`), it exists
// purely so such tests can take turns.
#[cfg(test)]
static ENV_LOCK: tokio::sync::Mutex<()> = tokio::sync::Mutex::const_new(());
67
/// Shared state behind a [`MessageProcessor`].
///
/// A single instance is created by `MessageProcessor::new_internal` and held
/// in an `Arc`, so every clone of the processor observes the same state.
#[derive(Debug)]
struct MessageProcessorInternal {
    /// The forwarder to handle processing events
    forwarder: Forwarder<Connection>,

    /// Drain signal to gracefully close all pending tasks.
    /// `None` once shutdown / drain has been initiated.
    drain_signal: parking_lot::RwLock<Option<drain::Signal>>,

    /// Drain watch to receive drain signal.
    /// `None` once shutdown / drain has been initiated.
    drain_watch: parking_lot::RwLock<Option<drain::Watch>>,

    /// Tx channel towards control plane
    tx_control_plane: RwLock<Option<Sender<Result<Message, Status>>>>,

    /// Tracks subscriptions forwarded to remote connections and handles restore on reconnect.
    remote_sync: RemoteSync,

    /// Service ID for tracing
    service_id: String,

    /// Peer group this node belongs to. Used during link negotiation to verify
    /// that both sides of a peer connection belong to the same deployment.
    /// Empty when no peer config is set.
    deployment_name: String,

    /// Default strict header MAC policy for server-accepted inter-node connections (see [`ServerConfig::require_header_mac`]).
    server_require_header_mac: bool,

    /// Timeout for link negotiation to complete.
    negotiation_timeout: std::time::Duration,

    /// Whether peer-originated publishes should be relayed to other peers.
    /// True for hub-and-spoke (hub) or generic multi-hop topologies.
    /// False for full-mesh (peers deliver directly — 1-hop rule).
    relay_peer_publishes: bool,

    /// Peer sync component for subscription forwarding and peer lifecycle.
    /// Initialized as standalone; replaced with a peer-aware instance when peers are configured.
    peer_sync: parking_lot::RwLock<crate::sync::PeerSync>,
}
108
/// Handle to the data-plane message processor.
///
/// Cloning is cheap: the handle only wraps an `Arc` around the shared
/// internal state, so all clones operate on the same processor.
#[derive(Debug, Clone)]
pub struct MessageProcessor {
    /// Reference-counted shared state.
    internal: Arc<MessageProcessorInternal>,
}
113
impl Default for MessageProcessor {
    /// Build a processor with an empty service id; equivalent to
    /// `MessageProcessor::new_with_service_id(String::new())`.
    fn default() -> Self {
        Self::new_with_service_id(String::new())
    }
}
119
/// Describes how a connection enters [`MessageProcessor::process_stream`].
///
/// Local connections are pre-registered in the table; remote connections are
/// only inserted after the mandatory link negotiation completes.
enum StreamSetup {
    /// Connection already in the table (local connections).
    /// Carries the connection id under which it was registered.
    Registered(u64),
    /// Remote connection not yet in the table; will be inserted after negotiation.
    Pending {
        /// The connection to insert once negotiation succeeds (boxed).
        connection: Box<Connection>,
        /// Optional index of an already-known connection.
        /// NOTE(review): presumably the table slot to reuse on reconnect —
        /// confirm against `process_stream`.
        existing_index: Option<u64>,
    },
}
133
134impl MessageProcessor {
135    pub fn new_with_service_id(service_id: String) -> Self {
136        Self::new_internal(
137            service_id,
138            String::new(),
139            false,
140            std::time::Duration::from_secs(5),
141            false,
142        )
143    }
144
145    /// Create a processor with the server strict header MAC policy from `server_config`.
146    pub fn new_with_server_config(
147        service_id: String,
148        deployment_name: String,
149        server_config: &ServerConfig,
150        relay_peer_publishes: bool,
151    ) -> Self {
152        Self::new_internal(
153            service_id,
154            deployment_name,
155            server_config.require_header_mac,
156            std::time::Duration::from_secs(server_config.negotiation_timeout_secs),
157            relay_peer_publishes,
158        )
159    }
160
161    fn new_internal(
162        service_id: String,
163        deployment_name: String,
164        server_require_header_mac: bool,
165        negotiation_timeout: std::time::Duration,
166        relay_peer_publishes: bool,
167    ) -> Self {
168        let (signal, watch) = drain::channel();
169        let internal = MessageProcessorInternal {
170            forwarder: Forwarder::new(),
171            drain_signal: RwLock::new(Some(signal)),
172            drain_watch: RwLock::new(Some(watch)),
173            tx_control_plane: RwLock::new(None),
174            remote_sync: RemoteSync::default(),
175            service_id,
176            deployment_name,
177            server_require_header_mac,
178            negotiation_timeout,
179            relay_peer_publishes,
180            peer_sync: parking_lot::RwLock::new(crate::sync::PeerSync::standalone()),
181        };
182        Self {
183            internal: Arc::new(internal),
184        }
185    }
186
    /// Create a processor with default settings (same as [`Default::default`]).
    pub fn new() -> Self {
        Self::default()
    }
190
191    /// Run a data plane server using this message processor's drain watch.
192    /// Dispatch on the configured transport happens inside slim-config via the
193    /// [`ServerHandler`] trait below. Returns a cancellation token that can be
194    /// used to stop the server task.
195    pub async fn run_server(
196        &self,
197        config: &ServerConfig,
198    ) -> Result<CancellationToken, DataPathError> {
199        debug!(%config, "starting dataplane server");
200
201        if config.require_header_mac != self.internal.server_require_header_mac {
202            warn!(
203                configured = config.require_header_mac,
204                processor = self.internal.server_require_header_mac,
205                "server require_header_mac differs from MessageProcessor; inbound connections use the processor value set at construction (prefer MessageProcessor::new_with_server_config)",
206            );
207        }
208
209        let watch = self.get_drain_watch()?;
210        config
211            .run_server(watch, Arc::new(self.clone()))
212            .await
213            .map_err(Into::into)
214    }
215
216    async fn handle_websocket_accepted(&self, accepted: AcceptedWebSocketConnection) {
217        let cancellation_token = CancellationToken::new();
218        let streams =
219            websocket::spawn_transport_tasks(accepted.websocket, cancellation_token.clone());
220
221        let connection = Connection::new(ConnType::Remote, Channel::Client(streams.outbound))
222            .with_remote_addr(accepted.remote_addr)
223            .with_local_addr(accepted.local_addr)
224            .with_require_header_mac(self.internal.server_require_header_mac)
225            .with_cancellation_token(Some(cancellation_token.clone()));
226
227        debug!(
228            remote = ?connection.remote_addr(),
229            local = ?connection.local_addr(),
230            "new websocket connection received from remote",
231        );
232        info!(telemetry = true, counter.num_active_connections = 1);
233
234        if let Err(err) = self.process_stream(
235            streams.inbound,
236            StreamSetup::Pending {
237                connection: Box::new(connection),
238                existing_index: None,
239            },
240            None,
241            cancellation_token,
242            ConnType::Remote,
243            false,
244        ) {
245            error!(error = %err.chain(), "error starting websocket processing stream");
246        }
247    }
248
249    /// Signal all spawned tasks (process_stream, etc.) to begin shutting down.
250    ///
251    /// Unlike [`shutdown`], this is synchronous: it drops the drain signal (which
252    /// notifies all drain watches) and the drain watch, but does NOT wait for the
253    /// tasks to finish.  Safe to call from a synchronous `Drop` implementation.
254    pub fn signal_drain(&self) {
255        self.internal.drain_signal.write().take();
256        self.internal.drain_watch.write().take();
257    }
258
259    pub async fn shutdown(&self) -> Result<(), DataPathError> {
260        // Take the drain signal
261        let signal = self
262            .internal
263            .drain_signal
264            .write()
265            .take()
266            .ok_or(DataPathError::AlreadyClosedError)?;
267
268        // Take drain watch
269        self.internal.drain_watch.write().take();
270
271        // Signal completion to all tasks
272        signal.drain().await;
273
274        Ok(())
275    }
276
277    fn set_tx_control_plane(&self, tx: Sender<Result<Message, Status>>) {
278        let mut tx_guard = self.internal.tx_control_plane.write();
279        *tx_guard = Some(tx);
280    }
281
282    fn get_tx_control_plane(&self) -> Option<Sender<Result<Message, Status>>> {
283        let tx_guard = self.internal.tx_control_plane.read();
284        tx_guard.clone()
285    }
286
    /// Borrow the forwarder shared by all clones of this processor.
    pub fn forwarder(&self) -> &Forwarder<Connection> {
        &self.internal.forwarder
    }
290
    /// Borrow the remote-subscription sync state shared by all clones of this
    /// processor.
    pub(crate) fn remote_sync(&self) -> &RemoteSync {
        &self.internal.remote_sync
    }
294
295    /// Verify SLIM header MAC for inter-node traffic only (local app connections skip this).
296    pub(crate) fn verify_remote_header_mac(
297        &self,
298        conn_index: u64,
299        message: &Message,
300        enforce_strict_verification: bool,
301    ) -> Result<(), DataPathError> {
302        let conn = self
303            .forwarder()
304            .get_connection(conn_index)
305            .ok_or(DataPathError::ConnectionNotFound(conn_index))?;
306        if !matches!(conn.connection_type(), ConnType::Remote | ConnType::Edge) {
307            return Ok(());
308        }
309        let header = message
310            .try_get_slim_header()
311            .ok_or(DataPathError::UnknownMsgType)?;
312
313        let has_wire_mac = header.header_mac.as_ref().is_some_and(|m| !m.is_empty());
314
315        // Publishes must carry a MAC once the inter-node session has derived a key.  Control
316        // messages (subscribe / unsubscribe) may still traverse the same gRPC stream without a
317        // tag on some federation paths; skipping verification only when the tag is absent keeps
318        // tamper detection for application traffic.
319        if (message.is_subscribe() || message.is_unsubscribe()) && !has_wire_mac {
320            if enforce_strict_verification {
321                return Err(DataPathError::NegotiationError(
322                    "empty HMAC is not allowed in strict verification mode".to_string(),
323                ));
324            }
325            return Ok(());
326        }
327
328        let Some(mac) = conn.header_hmac() else {
329            if enforce_strict_verification {
330                return Err(DataPathError::NegotiationError(
331                    "strict header MAC required but link HMAC session is not installed".to_string(),
332                ));
333            }
334            // Do not accept inter-node publishes that already carry an integrity tag until this
335            // side has derived the link MAC; otherwise verification is silently skipped and peers
336            // never see `HeaderIntegrity` failures (including tampered test traffic).
337            if message.is_publish() && has_wire_mac {
338                return Err(DataPathError::HeaderMacAwaitingLinkNegotiation(conn_index));
339            }
340            return Ok(());
341        };
342        let link_id = conn
343            .link_id()
344            .filter(|id| !id.is_empty())
345            .ok_or(DataPathError::HeaderMacAwaitingLinkNegotiation(conn_index))?;
346        mac.verify_slim_header(header, &link_id)
347            .map_err(DataPathError::HeaderIntegrity)
348    }
349
350    pub(crate) fn get_drain_watch(&self) -> Result<drain::Watch, DataPathError> {
351        self.internal
352            .drain_watch
353            .read()
354            .clone()
355            .ok_or(DataPathError::AlreadyClosedError)
356    }
357
358    /// Re-send `remote_subs` as subscribe messages to `conn_index`.
359    /// Delegates to [`RemoteSync::restore`].
360    async fn restore_remote_subscriptions(
361        &self,
362        remote_subs: &HashSet<SubscriptionInfo>,
363        conn_index: u64,
364        restore_tracking: bool,
365    ) {
366        self.remote_sync()
367            .restore(self, remote_subs, conn_index, restore_tracking)
368            .await;
369    }
370
371    async fn try_to_connect(
372        &self,
373        client_config: ClientConfig,
374        local: Option<SocketAddr>,
375        remote: Option<SocketAddr>,
376        existing_conn_index: Option<u64>,
377    ) -> Result<(JoinHandle<()>, u64), DataPathError> {
378        client_config.validate()?;
379
380        let mut watch = std::pin::pin!(self.get_drain_watch()?.signaled());
381        let channel = tokio::select! {
382            _ = &mut watch => {
383                return Err(DataPathError::ShuttingDownError);
384            }
385            res = client_config.to_channel() => {
386                res?
387            }
388        };
389
390        let cancellation_token = CancellationToken::new();
391        let link_id = client_config.link_id.clone();
392
393        match channel {
394            TransportChannel::Grpc(grpc_channel) => {
395                let mut client = DataPlaneServiceClient::new(grpc_channel);
396                let (tx, rx) = mpsc::channel(128);
397                let stream = client
398                    .open_channel(Request::new(ReceiverStream::new(rx)))
399                    .await?;
400
401                let (handle, conn_index_rx) = self.register_remote_connection(
402                    stream.into_inner(),
403                    Channel::Client(tx),
404                    &client_config,
405                    local,
406                    remote,
407                    existing_conn_index,
408                    cancellation_token,
409                    Some(link_id.clone()),
410                )?;
411
412                let conn_index = conn_index_rx.await.map_err(|_| {
413                    DataPathError::NegotiationError(
414                        "negotiation task terminated unexpectedly".to_string(),
415                    )
416                })??;
417
418                // For peer connections established via client config (generic topology),
419                // auto-register in the forwarder and perform full sync.
420                if matches!(client_config.connection_type, ConnType::Peer) {
421                    let fwd = self.peer_sync();
422                    if !fwd.has_peer_state() {
423                        fwd.add_peer_conn_and_sync(self, conn_index);
424                    }
425                }
426
427                Ok((handle, conn_index))
428            }
429            TransportChannel::Websocket(ws_channel) => {
430                let websocket = ws_channel
431                    .take_websocket()
432                    .expect("websocket channel already consumed");
433                let streams =
434                    websocket::spawn_transport_tasks(websocket, cancellation_token.clone());
435
436                let (handle, conn_index_rx) = self.register_remote_connection(
437                    streams.inbound,
438                    Channel::Client(streams.outbound),
439                    &client_config,
440                    local.or(ws_channel.local_addr()),
441                    remote.or(ws_channel.remote_addr()),
442                    existing_conn_index,
443                    cancellation_token,
444                    Some(link_id.clone()),
445                )?;
446
447                let conn_index = conn_index_rx.await.map_err(|_| {
448                    DataPathError::NegotiationError(
449                        "negotiation task terminated unexpectedly".to_string(),
450                    )
451                })??;
452
453                // For peer connections established via client config (generic topology),
454                // auto-register in the forwarder and perform full sync.
455                if matches!(client_config.connection_type, ConnType::Peer) {
456                    let fwd = self.peer_sync();
457                    if !fwd.has_peer_state() {
458                        fwd.add_peer_conn_and_sync(self, conn_index);
459                    }
460                }
461
462                Ok((handle, conn_index))
463            }
464        }
465    }
466
467    /// Common post-connect plumbing shared by every transport: register the
468    /// new [`Connection`] in the forwarder and spawn the per-stream processor.
469    /// Transport-specific code only has to produce the inbound stream + outbound
470    /// channel and call this — see [`Self::try_to_connect`] for client-side
471    /// usage and [`Self::handle_websocket_accepted`] for the server side.
472    #[allow(clippy::too_many_arguments)]
473    fn register_remote_connection<S>(
474        &self,
475        inbound: S,
476        outbound: Channel,
477        client_config: &ClientConfig,
478        local: Option<SocketAddr>,
479        remote: Option<SocketAddr>,
480        existing_conn_index: Option<u64>,
481        cancellation_token: CancellationToken,
482        link_id: Option<String>,
483    ) -> Result<
484        (
485            JoinHandle<()>,
486            oneshot::Receiver<Result<u64, DataPathError>>,
487        ),
488        DataPathError,
489    >
490    where
491        S: Stream<Item = Result<Message, Status>> + Unpin + Send + 'static,
492    {
493        let mut connection = Connection::new(client_config.connection_type, outbound)
494            .with_local_addr(local)
495            .with_remote_addr(remote)
496            .with_config_data(Some(client_config.clone()))
497            .with_require_header_mac(client_config.require_header_mac)
498            .with_cancellation_token(Some(cancellation_token.clone()));
499        if let Some(link_id) = link_id {
500            connection = connection.with_link_id(link_id);
501        }
502
503        debug!(
504            remote = ?connection.remote_addr(),
505            local = ?connection.local_addr(),
506            ?client_config.connection_type,
507            "new connection initiated locally",
508        );
509
510        let (handle, conn_index_rx) = self.process_stream(
511            inbound,
512            StreamSetup::Pending {
513                connection: Box::new(connection),
514                existing_index: existing_conn_index,
515            },
516            Some(client_config.clone()),
517            cancellation_token,
518            client_config.connection_type,
519            false,
520        )?;
521
522        Ok((handle, conn_index_rx))
523    }
524
    /// Open a new outbound connection described by `client_config`.
    ///
    /// Delegates to `try_to_connect` with no existing connection index.
    /// Returns the processing task handle and the index assigned to the
    /// connection.
    pub async fn connect(
        &self,
        client_config: ClientConfig,
        local: Option<SocketAddr>,
        remote: Option<SocketAddr>,
    ) -> Result<(JoinHandle<()>, u64), DataPathError> {
        self.try_to_connect(client_config, local, remote, None)
            .await
    }
534
535    pub fn disconnect(&self, conn: u64) -> Result<ClientConfig, DataPathError> {
536        let connection = match self.forwarder().get_connection(conn) {
537            Some(c) => c,
538            None => {
539                error!(%conn, "error handling disconnect: connection unknown");
540                return Err(DataPathError::DisconnectionError(conn));
541            }
542        };
543
544        let token = match connection.cancellation_token() {
545            Some(t) => t,
546            None => {
547                error!(%conn, "error handling disconnect: missing cancellation token");
548                return Err(DataPathError::DisconnectionError(conn));
549            }
550        };
551
552        // Cancel receiving loop; this triggers deletion of connection state.
553        token.cancel();
554
555        connection
556            .config_data()
557            .cloned()
558            .ok_or(DataPathError::DisconnectionError(conn))
559    }
560
561    #[tracing::instrument(skip_all, fields(service_id = %self.internal.service_id))]
562    pub fn register_local_connection(
563        &self,
564        from_control_plane: bool,
565    ) -> Result<
566        (
567            u64,
568            tokio::sync::mpsc::Sender<Result<Message, Status>>,
569            tokio::sync::mpsc::Receiver<Result<Message, Status>>,
570        ),
571        DataPathError,
572    > {
573        // create a pair tx, rx to be able to send messages with the standard processing loop
574        let (tx1, rx1) = mpsc::channel(512);
575
576        debug!("establishing new local app connection");
577
578        // create a pair tx, rx to be able to receive messages and insert it into the connection table
579        let (tx2, rx2) = mpsc::channel(512);
580
581        // if the call is coming from the control plane set the tx channel
582        // we assume to talk to a single control plane so set the channel only once
583        if from_control_plane && self.get_tx_control_plane().is_none() {
584            self.set_tx_control_plane(tx2.clone());
585        }
586
587        // create a connection
588        let cancellation_token = CancellationToken::new();
589        let connection = Connection::new(ConnType::Local, Channel::Server(tx2))
590            .with_cancellation_token(Some(cancellation_token.clone()));
591
592        // add it to the connection table
593        let conn_id = self
594            .forwarder()
595            .on_connection_established(connection, None)
596            .unwrap();
597
598        debug!(%conn_id, "local connection established");
599        info!(telemetry = true, counter.num_active_connections = 1);
600
601        // this loop will process messages from the local app
602        self.process_stream(
603            ReceiverStream::new(rx1),
604            StreamSetup::Registered(conn_id),
605            None,
606            cancellation_token,
607            ConnType::Local,
608            from_control_plane,
609        )?;
610
611        // return the conn_id and  handles to be used to send and receive messages
612        Ok((conn_id, tx1, rx2))
613    }
614
    /// Send `msg` on connection `out_conn`.
    ///
    /// With the `otel_tracing` feature enabled the message is first prepared
    /// for tracing (span name `send_message`, targeted at `out_conn`); the
    /// actual delivery is done by `send_msg_raw`, whose result is returned.
    pub async fn send_msg(
        &self,
        // `msg` is only mutated when tracing is compiled in, hence the two
        // cfg-selected declarations of the same parameter.
        #[cfg(feature = "otel_tracing")] mut msg: Message,
        #[cfg(not(feature = "otel_tracing"))] msg: Message,
        out_conn: u64,
    ) -> Result<(), DataPathError> {
        #[cfg(feature = "otel_tracing")]
        otel_tracing::prepare_outbound_msg(
            &mut msg,
            "send_message",
            &self.internal.service_id,
            otel_tracing::SpanTarget::Connection(out_conn),
        );
        self.send_msg_raw(msg, out_conn).await
    }
630
631    async fn send_msg_raw(&self, mut msg: Message, out_conn: u64) -> Result<(), DataPathError> {
632        let connection = self.forwarder().get_connection(out_conn);
633        match connection {
634            Some(conn) => {
635                // Link and SubscriptionAck messages have no SLIM header: skip header
636                // manipulation and telemetry span creation.
637                if !msg.is_link() && !msg.is_subscription_ack() {
638                    msg.clear_slim_header();
639                }
640
641                if !msg.is_link()
642                    && !msg.is_subscription_ack()
643                    && matches!(conn.connection_type(), ConnType::Remote | ConnType::Edge)
644                    && conn.require_header_mac()
645                    && conn.header_hmac().is_none()
646                {
647                    return Err(DataPathError::NegotiationError(
648                        "strict header MAC required but link HMAC session is not installed"
649                            .to_string(),
650                    ));
651                }
652
653                if !msg.is_link()
654                    && !msg.is_subscription_ack()
655                    && matches!(conn.connection_type(), ConnType::Remote | ConnType::Edge)
656                    && let Some(mac) = conn.header_hmac()
657                {
658                    let link_id = conn
659                        .link_id()
660                        .or_else(|| conn.config_data().map(|c| c.link_id.clone()))
661                        .filter(|id| !id.is_empty());
662                    if let Some(ref id) = link_id {
663                        let header = msg.get_slim_header_mut();
664
665                        mac.sign_slim_header(header, id.as_str())
666                            .map_err(DataPathError::HeaderIntegrity)?;
667
668                        // Debug / integration-test builds only (`--release` omits this; env var is inert).
669                        // Must run *after* sign so the tag does not cover the mutated preimage fields.
670                        #[cfg(debug_assertions)]
671                        if std::env::var("SLIM_TEST_TAMPER_DESTINATION").is_ok()
672                            && let Some(dest) = header.destination.as_mut()
673                            && let Some(sn) = dest.str_name.as_mut()
674                        {
675                            sn.str_component_2.push_str("-integrity-test-tamper");
676                        }
677                    } else {
678                        return Err(DataPathError::HeaderMacAwaitingLinkNegotiation(out_conn));
679                    }
680                }
681
682                if !msg.is_link()
683                    && !msg.is_subscription_ack()
684                    && matches!(conn.channel(), Channel::Server(_))
685                    && matches!(conn.connection_type(), ConnType::Local)
686                {
687                    msg.get_slim_header_mut().header_mac = None;
688                }
689
690                match conn.channel() {
691                    Channel::Server(s) => {
692                        s.send(Ok(msg))
693                            .await
694                            .map_err(|e| DataPathError::MessageProcessingError {
695                                source: Box::new(DataPathError::ConnectionNotFound(out_conn)),
696                                msg: Box::new(e.0.unwrap_or_default()),
697                            })
698                    }
699                    Channel::Client(s) => {
700                        s.send(msg)
701                            .await
702                            .map_err(|e| DataPathError::MessageProcessingError {
703                                source: Box::new(DataPathError::ConnectionNotFound(out_conn)),
704                                msg: Box::new(e.0),
705                            })
706                    }
707                }
708            }
709            None => Err(DataPathError::ConnectionNotFound(out_conn)),
710        }
711    }
712
713    /// Send a gRPC status error on a server-side connection.
714    /// This causes the client's stream to yield `Err(status)`.
715    async fn send_status(&self, conn_index: u64, status: Status) {
716        if let Some(conn) = self.forwarder().get_connection(conn_index)
717            && let Channel::Server(tx) = conn.channel()
718        {
719            let _ = tx.send(Err(status)).await;
720        }
721    }
722
723    async fn match_and_forward_msg(
724        &self,
725        #[cfg(feature = "otel_tracing")] mut msg: Message,
726        #[cfg(not(feature = "otel_tracing"))] msg: Message,
727        in_connection: u64,
728        fanout: u32,
729        filter: MatchFilter,
730    ) -> Result<(), DataPathError> {
731        let header = msg.get_slim_header();
732        debug!(name = %header.get_dst(), %fanout, "match and forward message");
733
734        // if the message already contains an output connection, use that one
735        // without performing any match in the subscription table
736        if let Some(val) = msg.get_forward_to() {
737            debug!(conn = %val, "forwarding message to connection");
738            return self.send_msg(msg, val).await;
739        }
740
741        let encoded = header.get_encoded_dst();
742
743        match self
744            .forwarder()
745            .on_publish_msg_match(encoded, in_connection, fanout, filter)
746        {
747            Ok(out_vec) => {
748                let len = out_vec.len();
749                // Single destination: preserve per-connection span attributes.
750                if len == 1 {
751                    return self.send_msg(msg, out_vec[0]).await;
752                }
753
754                #[cfg(feature = "otel_tracing")]
755                otel_tracing::prepare_fanout_msg(
756                    &mut msg,
757                    "send_message",
758                    &self.internal.service_id,
759                    len as u32,
760                );
761
762                let mut i = 0usize;
763                while i < len - 1 {
764                    self.send_msg_raw(msg.clone(), out_vec[i]).await?;
765                    i += 1;
766                }
767                self.send_msg_raw(msg, out_vec[i]).await?;
768                Ok(())
769            }
770            Err(e) => {
771                debug!(name = %header.get_dst(), %fanout, error = %e, "no match for publish destination");
772                Err(DataPathError::MessageProcessingError {
773                    source: Box::new(e),
774                    msg: Box::new(msg),
775                })
776            }
777        }
778    }
779
780    /// Dispatch an inbound Link message to the appropriate handler.
781    ///
782    /// Link messages are link-local and must never be processed for local connections
783    /// (they are only exchanged between SLIM nodes).
784    async fn handle_link_message(
785        &self,
786        link: ProtoLink,
787        conn_index: u64,
788        category: ConnType,
789    ) -> Result<(), DataPathError> {
790        if category.is_local() {
791            debug!(%conn_index, "ignoring link message received on local connection");
792            return Ok(());
793        }
794        match link.link_type {
795            Some(ProtoLinkType::LinkNegotiation(payload)) => {
796                self.handle_link_negotiation(&payload, conn_index).await
797            }
798            None => {
799                debug!(%conn_index, "received link message with unset link_type");
800                Ok(())
801            }
802        }
803    }
804
805    /// Handle an inbound link negotiation message arriving in the main loop.
806    ///
807    /// Since negotiation is mandatory and completes before the connection is
808    /// inserted into the table, any link negotiation message arriving here is
809    /// either a duplicate or a protocol error — we log and ignore it.
810    async fn handle_link_negotiation(
811        &self,
812        payload: &LinkNegotiationPayload,
813        in_connection: u64,
814    ) -> Result<(), DataPathError> {
815        debug!(
816            %in_connection,
817            link_id = %payload.link_id,
818            is_reply = payload.is_reply,
819            "ignoring link negotiation message on already-negotiated connection",
820        );
821
822        Ok(())
823    }
824
825    /// Upgrade a server-side connection to Peer after validating identity and deployment_name.
826    /// Notifies PeerSyncManager or auto-registers in the forwarder (generic topology).
827    pub(crate) async fn handle_peer_upgrade(
828        &self,
829        remote_node_id: &str,
830        remote_deployment_name: &str,
831        in_connection: u64,
832        link_id: &str,
833    ) -> Result<(), DataPathError> {
834        // Reject self-connections (can happen when all replicas share the same config).
835        if remote_node_id == self.internal.service_id {
836            warn!(
837                %in_connection, %link_id,
838                "rejecting peer connection from self (same node_id)"
839            );
840            self.send_status(
841                in_connection,
842                Status::permission_denied("self-connection rejected: same node_id"),
843            )
844            .await;
845            let _ = self.disconnect(in_connection);
846            return Ok(());
847        }
848
849        // Verify deployment_name: if we have a deployment_name configured, the remote must match.
850        if !self.internal.deployment_name.is_empty()
851            && remote_deployment_name != self.internal.deployment_name
852        {
853            warn!(
854                %in_connection, %link_id,
855                local_group = %self.internal.deployment_name,
856                remote_group = %remote_deployment_name,
857                "rejecting peer upgrade: deployment_name mismatch"
858            );
859            self.send_status(
860                in_connection,
861                Status::permission_denied("deployment_name mismatch"),
862            )
863            .await;
864            let _ = self.disconnect(in_connection);
865            return Ok(());
866        }
867
868        info!(
869            %in_connection, %link_id, %remote_node_id,
870            "upgrading server-side connection to Peer (negotiation)"
871        );
872        self.connection_table().update(in_connection, |conn| {
873            conn.set_connection_type(ConnType::Peer)
874        });
875
876        self.peer_sync()
877            .on_incoming_peer(self, remote_node_id.to_string(), in_connection);
878
879        Ok(())
880    }
881
882    async fn process_publish(
883        &self,
884        msg: Message,
885        in_connection: u64,
886        filter: MatchFilter,
887    ) -> Result<(), DataPathError> {
888        debug!(
889            %in_connection,
890            ?msg,
891            "received publication"
892        );
893
894        // telemetry /////////////////////////////////////////
895        info!(
896            telemetry = true,
897            monotonic_counter.num_messages_by_type = 1,
898            method = "publish"
899        );
900        //////////////////////////////////////////////////////
901
902        // this function may panic, but at this point we are sure we are processing
903        // a publish message
904        let fanout = msg.get_fanout();
905
906        self.match_and_forward_msg(msg, in_connection, fanout, filter)
907            .await
908    }
909
910    pub(crate) async fn send_subscription_ack(
911        &self,
912        in_connection: u64,
913        subscription_id: u64,
914        result: &Result<(), DataPathError>,
915    ) {
916        let (success, error_msg) = match result {
917            Ok(()) => (true, String::new()),
918            Err(e) => (false, e.to_string()),
919        };
920
921        let ack_msg =
922            Message::builder().build_subscription_ack(subscription_id, success, error_msg);
923
924        if let Err(e) = self.send_msg(ack_msg, in_connection).await {
925            error!(error = %e.chain(), "failed to send subscription ack");
926        }
927    }
928
929    /// Pure state update for a subscription: updates the subscription table
930    /// and returns the outcome (whether a transition occurred, connection type, forward target).
931    /// Does NOT perform any forwarding or event emission.
932    fn update_subscription_state(
933        &self,
934        msg: &Message,
935        conn: u64,
936        forward: Option<u64>,
937        add: bool,
938        subscription_id: u64,
939    ) -> Result<SubscriptionOutcome, DataPathError> {
940        let dst = msg.get_dst();
941
942        // As connection is deleted only after processing, at this point it must exist.
943        let connection = if let Some(c) = self.forwarder().get_connection(conn) {
944            c
945        } else {
946            return Err(DataPathError::ConnectionNotFound(conn));
947        };
948
949        debug!(
950            %conn,
951            %dst,
952            is_local = connection.is_local_connection(),
953            "processing {}subscription state",
954            if add { "" } else { "un" }
955        );
956
957        let is_peer_conn = connection.is_peer_connection();
958
959        let transition = self.forwarder().on_subscription_msg(
960            dst,
961            conn,
962            connection.connection_type(),
963            add,
964            subscription_id,
965        )?;
966
967        Ok(SubscriptionOutcome {
968            transition,
969            is_peer_conn,
970            forward_conn: forward,
971        })
972    }
973
974    // Use a single function to process subscription and unsubscription packets.
975    // The flag add = true is used to add a new subscription while add = false
976    // is used to remove existing state.
977    //
978    // This is the SINGLE entry point for all subscription handling.
979    // All forwarding (to peers, to controller, hub relay) goes through
980    // the PeerSync — no inline forwarding anywhere else.
981    async fn process_subscription(
982        &self,
983        msg: Message,
984        in_connection: u64,
985        add: bool,
986    ) -> Result<(), DataPathError> {
987        debug!(
988            %in_connection,
989            ?msg,
990            "received {}subscription",
991            if add { "" } else { "un" }
992        );
993
994        // telemetry /////////////////////////////////////////
995        info!(
996            telemetry = true,
997            monotonic_counter.num_messages_by_type = 1,
998            message_type = { if add { "subscribe" } else { "unsubscribe" } }
999        );
1000        //////////////////////////////////////////////////////
1001
1002        let subscription_id = msg.get_subscription_id();
1003
1004        debug!(?subscription_id, "received subscription id");
1005
1006        // get header
1007        let header = msg.get_slim_header();
1008
1009        // get in and out connections
1010        let (in_conn, recv_from, forward) = header.get_connections();
1011        let in_conn = recv_from.unwrap_or(in_conn);
1012
1013        // Never forward subscriptions to local connections (they are local apps whose
1014        // routes are already set locally).
1015        let forward = forward.filter(|&out| {
1016            self.forwarder()
1017                .get_connection(out)
1018                .map(|c| !c.is_local_connection())
1019                .unwrap_or(true)
1020        });
1021
1022        // As connection is deleted only after processing, at this point it must exist.
1023        let Some(connection) = self.forwarder().get_connection(in_conn) else {
1024            if let Some(id) = subscription_id {
1025                debug!(%in_conn, "connection not found, sending error ack");
1026                self.send_subscription_ack(
1027                    in_connection,
1028                    id,
1029                    &Err(DataPathError::ConnectionNotFound(in_conn)),
1030                )
1031                .await;
1032            }
1033            return Err(DataPathError::MessageProcessingError {
1034                source: Box::new(DataPathError::ConnectionNotFound(in_conn)),
1035                msg: Box::new(msg),
1036            });
1037        };
1038
1039        // Do not process subscriptions forwarded back to local connections.
1040        if recv_from.is_some() && connection.is_local_connection() {
1041            if let Some(id) = subscription_id {
1042                debug!(%in_conn, "subscription looped back to local connection, acking ok");
1043                self.send_subscription_ack(in_connection, id, &Ok(())).await;
1044            }
1045            return Ok(());
1046        }
1047
1048        // Loop prevention: check if this subscription_id has already been forwarded
1049        // by this node. This prevents loops in ring/mesh topologies where a subscription
1050        // could travel around and come back. Only applies to subscribes (add=true);
1051        // unsubscribes with a seen sub_id are expected (they cancel a prior forwarded sub).
1052        // Unsubscribe loops are bounded by TTL and don't cause state corruption since
1053        // remove operations are idempotent.
1054        let sub_id = subscription_id.unwrap_or(0);
1055        if add && sub_id != 0 && self.peer_sync().has_seen_sub_id(sub_id) {
1056            debug!(
1057                %in_conn,
1058                %sub_id,
1059                "dropping subscription already forwarded by this node (loop prevention)"
1060            );
1061            if let Some(id) = subscription_id {
1062                self.send_subscription_ack(in_connection, id, &Ok(())).await;
1063            }
1064            return Ok(());
1065        }
1066
1067        // Update local state (subscription table) — pure state change, no forwarding.
1068        let outcome = match self.update_subscription_state(&msg, in_conn, forward, add, sub_id) {
1069            Ok(o) => o,
1070            Err(e) => {
1071                if let Some(id) = subscription_id {
1072                    self.send_subscription_ack(in_connection, id, &Err(e)).await;
1073                    // Return Ok since we already sent the error ACK.
1074                    return Ok(());
1075                }
1076                return Err(DataPathError::MessageProcessingError {
1077                    source: Box::new(e),
1078                    msg: Box::new(msg),
1079                });
1080            }
1081        };
1082
1083        // Determine forwarding targets:
1084        // - Peers (All): non-peer subscription with aggregate transition (0→1 or 1→0)
1085        // - Peers (ExcludeConn): peer subscription with remaining TTL >= 2 (relay)
1086        // - Forward conn: controller/remote node when header.forward_to is set
1087        //
1088        // TTL controls propagation depth:
1089        // - TTL=2 on initial send → peer decrements to 1, sees 1 < 2, no relay (full mesh)
1090        // - TTL=3 on initial send → hub decrements to 2, relays; spoke decrements to 1, stops
1091        // - TTL=6 on initial send → allows up to 5 hops of relay (generic topology)
1092        let remaining_ttl = msg.get_ttl();
1093
1094        let (peer_target, peer_ttl) = if !outcome.is_peer_conn && outcome.transition {
1095            // Local/remote subscription transition → forward to ALL peers with configured TTL
1096            let ttl = self.peer_sync().subscription_ttl();
1097            (Some(crate::sync::PeerTarget::All), ttl)
1098        } else if outcome.is_peer_conn && remaining_ttl >= 2 {
1099            // Peer subscription relay: TTL allows further propagation.
1100            // Forward to all peers except the source, using remaining TTL.
1101            (
1102                Some(crate::sync::PeerTarget::ExcludeConn(in_conn)),
1103                remaining_ttl,
1104            )
1105        } else {
1106            (None, 0)
1107        };
1108
1109        let targets = crate::sync::ForwardTargets {
1110            peers: peer_target,
1111            forward_conn: outcome.forward_conn,
1112        };
1113
1114        // If there are forwarding targets, spawn the forwarder task (non-blocking).
1115        // The forwarder will wait for ACKs and then ACK the upstream client.
1116        if targets.has_any() {
1117            let fwd = self.peer_sync();
1118            let dst = msg.get_dst();
1119            debug!(
1120                %in_connection,
1121                %dst,
1122                %remaining_ttl,
1123                %peer_ttl,
1124                ?targets,
1125                "spawning subscription forwarder task"
1126            );
1127            let drain = self.get_drain_watch().ok();
1128            if let Some(drain) = drain {
1129                fwd.spawn_forward_and_ack(
1130                    self.clone(),
1131                    msg,
1132                    dst,
1133                    sub_id,
1134                    add,
1135                    targets,
1136                    in_connection,
1137                    subscription_id,
1138                    peer_ttl,
1139                    drain,
1140                );
1141                return Ok(());
1142            }
1143            // Fallback: drain not available (shutting down).
1144            // ACK immediately as best-effort.
1145        }
1146
1147        // No forwarding needed (or no forwarder) — ACK immediately.
1148        if let Some(id) = subscription_id {
1149            debug!(%in_connection, "sending immediate subscription ack (no forwarding)");
1150            self.send_subscription_ack(in_connection, id, &Ok(())).await;
1151        }
1152
1153        Ok(())
1154    }
1155
1156    pub async fn process_message(
1157        &self,
1158        msg: Message,
1159        in_connection: u64,
1160        category: ConnType,
1161    ) -> Result<(), DataPathError> {
1162        match msg.message_type {
1163            Some(SubscribeType(_)) => self.process_subscription(msg, in_connection, true).await,
1164            Some(UnsubscribeType(_)) => self.process_subscription(msg, in_connection, false).await,
1165            Some(PublishType(_)) => {
1166                let filter = match category {
1167                    ConnType::Peer => {
1168                        if self.internal.relay_peer_publishes {
1169                            MatchFilter::ALL
1170                        } else {
1171                            MatchFilter::EXCLUDE_PEER
1172                        }
1173                    }
1174                    _ => MatchFilter::ALL,
1175                };
1176                self.process_publish(msg, in_connection, filter).await
1177            }
1178            Some(LinkType(link)) => {
1179                self.handle_link_message(link, in_connection, category)
1180                    .await
1181            }
1182            Some(SubscriptionAckType(ack)) => {
1183                let result = if ack.success {
1184                    Ok(())
1185                } else {
1186                    Err(DataPathError::RemoteSubscriptionAckError(ack.error))
1187                };
1188
1189                self.peer_sync().resolve_ack(ack.subscription_id, result);
1190                Ok(())
1191            }
1192            None => unreachable!(
1193                "message type not set; validate() must be called before process_message"
1194            ),
1195        }
1196    }
1197
1198    pub(crate) async fn handle_new_message(
1199        &self,
1200        conn_index: u64,
1201        category: ConnType,
1202        mut msg: Message,
1203    ) -> Result<(), DataPathError> {
1204        debug!(%conn_index, "received message from connection");
1205        info!(
1206            telemetry = true,
1207            monotonic_counter.num_processed_messages = 1
1208        );
1209
1210        // validate message
1211        if let Err(err) = msg.validate() {
1212            info!(
1213                telemetry = true,
1214                monotonic_counter.num_messages_by_type = 1,
1215                message_type = "none"
1216            );
1217
1218            let ret_err = DataPathError::MessageProcessingError {
1219                source: Box::new(err.into()),
1220                msg: Box::new(msg),
1221            };
1222
1223            return Err(ret_err);
1224        }
1225
1226        // Link and SubscriptionAck messages have no SLIM header: skip header processing and telemetry span.
1227        if !msg.is_link() && !msg.is_subscription_ack() {
1228            // add incoming connection to the SLIM header
1229            msg.set_incoming_conn(Some(conn_index));
1230
1231            // TTL processing: decrement for remote messages (hop-by-hop)
1232            if !category.is_local() && msg.decrement_ttl() == 0 {
1233                debug!(%conn_index, "dropping message: TTL expired");
1234                return Err(DataPathError::TtlExpired);
1235            }
1236
1237            #[cfg(feature = "otel_tracing")]
1238            otel_tracing::prepare_inbound_msg(
1239                &mut msg,
1240                "process_local",
1241                &self.internal.service_id,
1242                conn_index,
1243                category.is_local(),
1244            );
1245        }
1246
1247        match self.process_message(msg, conn_index, category).await {
1248            Ok(_) => Ok(()),
1249            Err(e) => {
1250                // telemetry /////////////////////////////////////////
1251                info!(
1252                    telemetry = true,
1253                    monotonic_counter.num_message_process_errors = 1
1254                );
1255                //////////////////////////////////////////////////////
1256
1257                // drop message
1258                Err(e)
1259            }
1260        }
1261    }
1262
1263    #[tracing::instrument(skip_all, fields(service_id = %self.internal.service_id, conn_index))]
1264    async fn send_error_to_local_app(&self, conn_index: u64, err: DataPathError) {
1265        debug!(%conn_index, "sending error to local application");
1266        let connection = self.forwarder().get_connection(conn_index);
1267        match connection {
1268            Some(conn) => {
1269                debug!("try to notify the error to the local application");
1270                if let Channel::Server(tx) = conn.channel() {
1271                    // If the error contains the message, try to extract some session information
1272                    let session_ctx = match &err {
1273                        DataPathError::MessageProcessingError { msg, .. } => {
1274                            MessageContext::from_msg(msg)
1275                        }
1276                        _ => None,
1277                    };
1278
1279                    // Make error message with optional session context using shared type
1280                    let payload = crate::errors::ErrorPayload::new(err.to_string(), session_ctx);
1281                    let error_message = payload.to_json_string();
1282
1283                    // create Status error
1284                    let status = Status::new(tonic::Code::Internal, error_message);
1285
1286                    if tx.send(Err(status)).await.is_err() {
1287                        debug!(error = %err.chain(), "unable to notify the error to the local app");
1288                    }
1289                }
1290            }
1291            None => {
1292                error!(
1293                    "error sending error to local app: connection {:?} not found",
1294                    conn_index
1295                );
1296            }
1297        }
1298    }
1299
1300    #[tracing::instrument(skip_all, fields(service_id = %self.internal.service_id, conn_index))]
1301    async fn reconnect(
1302        &self,
1303        client_conf: ClientConfig,
1304        conn_index: u64,
1305        cancellation_token: &CancellationToken,
1306    ) -> bool {
1307        info!("connection lost with remote endpoint, attempting to reconnect");
1308
1309        let is_peer = self
1310            .forwarder()
1311            .get_connection(conn_index)
1312            .map(|c| c.connection_type() == ConnType::Peer)
1313            .unwrap_or(false);
1314
1315        // For remote/controller connections: save the subscriptions we forwarded to this
1316        // connection so we can replay them after reconnecting.
1317        // For peer connections: we do a full sync instead (no need to save).
1318        let remote_subscriptions = if !is_peer {
1319            self.remote_sync()
1320                .get_subscriptions_for_reconnect(conn_index)
1321        } else {
1322            Default::default()
1323        };
1324
1325        tokio::select! {
1326            _ = cancellation_token.cancelled() => {
1327                debug!("cancellation token signaled, stopping reconnection process");
1328                false
1329            }
1330            res = self.try_to_connect(client_conf, None, None, Some(conn_index)) => {
1331                match res {
1332                    Ok(_) => {
1333                        info!("connection re-established successfully");
1334                        if is_peer {
1335                            // Peer connection: full sync (send local + remote subscriptions).
1336                            let ttl = self.peer_sync().subscription_ttl();
1337                            if let Err(e) = sync_peer::send_local_remote_sync(
1338                                self, conn_index, ttl,
1339                            )
1340                            .await
1341                            {
1342                                warn!(
1343                                    error = %e,
1344                                    "failed to send full sync after peer reconnect"
1345                                );
1346                            }
1347                        } else {
1348                            // Remote/controller: restore only what was previously forwarded.
1349                            self.restore_remote_subscriptions(
1350                                &remote_subscriptions,
1351                                conn_index,
1352                                false,
1353                            )
1354                            .await;
1355                        }
1356                        true
1357                    }
1358                    Err(e) => {
1359                        error!(error = %e.chain(), "unable to reconnect to remote node");
1360                        false
1361                    }
1362                }
1363            }
1364        }
1365    }
1366
1367    /// Send an UNSUBSCRIBE message to the control plane for each subscription in `local_subs`.
1368    ///
1369    /// This is the single authoritative place that constructs and delivers CP unsubscribe
1370    /// notifications on connection loss, used by both the immediate cleanup path and the deferred
1371    /// TTL-expiry path.
1372    async fn notify_control_plane_subscriptions_lost(
1373        tx_cp: Option<Sender<Result<Message, Status>>>,
1374        local_subs: HashMap<ProtoName, HashSet<u64>>,
1375        conn_index: u64,
1376    ) {
1377        let Some(tx) = tx_cp else { return };
1378        for local_sub in local_subs.into_keys() {
1379            debug!(
1380                %local_sub,
1381                "notify control plane about lost subscription",
1382            );
1383            let msg = Message::builder()
1384                .source(local_sub.clone())
1385                .destination(local_sub.clone())
1386                .flags(SlimHeaderFlags::default().with_recv_from(conn_index))
1387                .build_unsubscribe()
1388                .unwrap();
1389            if let Err(e) = tx.send(Ok(msg)).await {
1390                debug!(
1391                    %local_sub,
1392                    error = %e.chain(),
1393                    "failed to send unsubscribe to control plane",
1394                );
1395            }
1396        }
1397    }
1398
1399    /// Resolve the connection index for a new stream.
1400    ///
1401    /// For local (already-registered) connections, returns immediately.
1402    /// For remote connections, runs mandatory link negotiation, inserts the
1403    /// connection into the table, and handles peer upgrade if negotiated.
1404    ///
1405    /// Returns `Some((conn_index, category))` on success or `None` if the
1406    /// connection setup failed (the error is sent on `conn_index_tx`).
1407    async fn resolve_connection(
1408        &self,
1409        stream: &mut (impl Stream<Item = Result<Message, Status>> + Unpin + Send),
1410        setup: StreamSetup,
1411        category: ConnType,
1412        conn_index_tx: oneshot::Sender<Result<u64, DataPathError>>,
1413        watch: &drain::Watch,
1414        token: &CancellationToken,
1415    ) -> Option<(u64, ConnType)> {
1416        match setup {
1417            StreamSetup::Registered(idx) => {
1418                let _ = conn_index_tx.send(Ok(idx));
1419                Some((idx, category))
1420            }
1421            StreamSetup::Pending {
1422                connection,
1423                existing_index,
1424            } => {
1425                self.negotiate_and_register(
1426                    stream,
1427                    *connection,
1428                    existing_index,
1429                    category,
1430                    conn_index_tx,
1431                    watch,
1432                    token,
1433                )
1434                .await
1435            }
1436        }
1437    }
1438
1439    /// Perform link negotiation, register the connection, and handle peer upgrade.
1440    ///
1441    /// Returns `Some((conn_index, category))` on success or `None` on failure.
1442    #[allow(clippy::too_many_arguments)]
1443    async fn negotiate_and_register(
1444        &self,
1445        stream: &mut (impl Stream<Item = Result<Message, Status>> + Unpin + Send),
1446        mut connection: Connection,
1447        existing_index: Option<u64>,
1448        category: ConnType,
1449        conn_index_tx: oneshot::Sender<Result<u64, DataPathError>>,
1450        watch: &drain::Watch,
1451        token: &CancellationToken,
1452    ) -> Option<(u64, ConnType)> {
1453        let timeout = self.internal.negotiation_timeout;
1454        let params = crate::negotiation::NegotiationParams {
1455            node_id: &self.internal.service_id,
1456            deployment_name: &self.internal.deployment_name,
1457            connection_type: category,
1458        };
1459
1460        let negotiation_result = tokio::select! {
1461            result = tokio::time::timeout(
1462                timeout,
1463                crate::negotiation::run_negotiation(&mut connection, stream, &params),
1464            ) => match result {
1465                Ok(r) => r,
1466                Err(_) => Err(DataPathError::NegotiationError(
1467                    "timed out waiting for link negotiation".to_string(),
1468                )),
1469            },
1470            _ = watch.clone().signaled() => {
1471                info!("shutting down during link negotiation");
1472                let _ = conn_index_tx.send(Err(DataPathError::ShuttingDownError));
1473                return None;
1474            }
1475            _ = token.cancelled() => {
1476                info!("connection cancelled during link negotiation");
1477                let _ = conn_index_tx.send(Err(DataPathError::ShuttingDownError));
1478                return None;
1479            }
1480        };
1481
1482        let result = match negotiation_result {
1483            Ok(r) => r,
1484            Err(e) => {
1485                error!(error = %e.chain(), "link negotiation failed, closing connection");
1486                let _ = conn_index_tx.send(Err(e));
1487                info!(telemetry = true, counter.num_active_connections = -1);
1488                return None;
1489            }
1490        };
1491
1492        // Insert the fully-negotiated connection into the table.
1493        let idx = match self
1494            .forwarder()
1495            .on_connection_established(connection, existing_index)
1496        {
1497            Some(idx) => idx,
1498            None => {
1499                let _ = conn_index_tx.send(Err(DataPathError::ConnectionTableAddError));
1500                info!(telemetry = true, counter.num_active_connections = -1);
1501                return None;
1502            }
1503        };
1504
1505        debug!(%idx, "connection registered after link negotiation");
1506
1507        // Handle connection-type-specific post-negotiation logic.
1508        let category = match result.connection_type {
1509            ConnType::Peer => {
1510                let link_id = self
1511                    .forwarder()
1512                    .get_connection(idx)
1513                    .and_then(|c| c.link_id())
1514                    .unwrap_or_default();
1515                if let Err(e) = self
1516                    .handle_peer_upgrade(
1517                        &result.remote_node_id,
1518                        &result.remote_deployment_name,
1519                        idx,
1520                        &link_id,
1521                    )
1522                    .await
1523                {
1524                    error!(error = %e.chain(), "peer upgrade failed after negotiation");
1525                    let _ = conn_index_tx.send(Err(e));
1526                    info!(telemetry = true, counter.num_active_connections = -1);
1527                    return None;
1528                }
1529                ConnType::Peer
1530            }
1531            other => other,
1532        };
1533
1534        let _ = conn_index_tx.send(Ok(idx));
1535        Some((idx, category))
1536    }
1537
1538    fn process_stream(
1539        &self,
1540        mut stream: impl Stream<Item = Result<Message, Status>> + Unpin + Send + 'static,
1541        setup: StreamSetup,
1542        client_config: Option<ClientConfig>,
1543        cancellation_token: CancellationToken,
1544        category: ConnType,
1545        from_control_plane: bool,
1546    ) -> Result<
1547        (
1548            JoinHandle<()>,
1549            oneshot::Receiver<Result<u64, DataPathError>>,
1550        ),
1551        DataPathError,
1552    > {
1553        // Clone self to be able to move it into the spawned task
1554        let self_clone = self.clone();
1555        let token_clone = cancellation_token.clone();
1556        let client_conf_clone = client_config.clone();
1557        let tx_cp: Option<Sender<Result<Message, Status>>> = self.get_tx_control_plane();
1558        let watch = self.get_drain_watch()?;
1559        let is_local = category.is_local();
1560
1561        let (conn_index_tx, conn_index_rx) = oneshot::channel();
1562
1563        let require_header_mac = match &setup {
1564            StreamSetup::Registered(idx) => self
1565                .forwarder()
1566                .get_connection(*idx)
1567                .map(|c| c.require_header_mac())
1568                .unwrap_or(false),
1569            StreamSetup::Pending { connection, .. } => connection.require_header_mac(),
1570        };
1571
1572        let span = tracing::info_span!(
1573            "process_stream",
1574            service_id = %self.internal.service_id,
1575            conn_index = match &setup {
1576                StreamSetup::Registered(idx) => *idx,
1577                _ => 0,
1578            },
1579            is_local,
1580        );
1581
1582        let handle = tokio::spawn(async move {
1583            let mut try_to_reconnect = true;
1584
1585            // Resolve the conn_index: either already registered (local) or
1586            // perform negotiation + table insertion (remote).
1587            let Some((conn_index, category)) = self_clone
1588                .resolve_connection(
1589                    &mut stream,
1590                    setup,
1591                    category,
1592                    conn_index_tx,
1593                    &watch,
1594                    &token_clone,
1595                )
1596                .await
1597            else {
1598                return;
1599            };
1600
1601            let mut watch = std::pin::pin!(watch.signaled());
1602            loop {
1603                tokio::select! {
1604                    next = stream.next() => {
1605                        match next {
1606                            Some(result) => {
1607                                match result {
1608                                    Ok(msg) => {
1609                                        if !is_local
1610                                            && !msg.is_link()
1611                                            && !msg.is_subscription_ack()
1612                                            && let Err(e) = self_clone
1613                                                .verify_remote_header_mac(conn_index, &msg, require_header_mac)
1614                                        {
1615                                            error!(
1616                                                %conn_index,
1617                                                error = %e.chain(),
1618                                                "SLIM header integrity verification failed",
1619                                            );
1620                                            continue;
1621                                        }
1622                                        // check if we need to send the message to the control plane
1623                                        // we send the message if
1624                                        // 1. the message is coming from remote
1625                                        // 2. it is not coming from the control plane itself
1626                                        // 3. the control plane exists
1627                                        if !is_local && !from_control_plane && let Some(txcp) = &tx_cp {
1628                                            match msg.get_type() {
1629                                                PublishType(_) | LinkType(_) | SubscriptionAckType(_) => {/* do nothing */}
1630                                                _ => {
1631                                                    // send subscriptions and unsubscriptions
1632                                                    // to the control plane
1633                                                    let _ = txcp.send(Ok(msg.clone())).await;
1634                                                }
1635                                            }
1636                                        }
1637
1638                                        if let Err(e) = self_clone.handle_new_message(conn_index, category, msg).await {
1639                                            // Checking if NegotiationError occurred
1640                                            if matches!(e, DataPathError::NegotiationError(_)) {
1641                                                error!(%conn_index, "fatal link negotiation error, closing connection");
1642                                                try_to_reconnect = false;
1643                                                break;
1644                                            }
1645                                            debug!(%conn_index, error = %e.chain(), "error processing incoming message");
1646                                            // If the message is coming from a local app, notify it
1647                                            if is_local {
1648                                                // try to forward error to the local app
1649                                                self_clone.send_error_to_local_app(conn_index, e).await;
1650                                            }
1651                                        }
1652                                    }
1653                                    Err(e) => {
1654                                        if e.code() == tonic::Code::PermissionDenied {
1655                                            warn!(
1656                                                %conn_index,
1657                                                message = %e.message(),
1658                                                "connection rejected by remote, will not reconnect"
1659                                            );
1660                                            try_to_reconnect = false;
1661                                        } else if let Some(io_err) = MessageProcessor::match_for_io_error(&e) {
1662                                            if io_err.kind() == std::io::ErrorKind::BrokenPipe {
1663                                                info!(%conn_index, "connection closed by peer");
1664                                            }
1665                                        } else {
1666                                            error!(error = %e.chain(), "error receiving messages");
1667                                        }
1668                                        break;
1669                                    }
1670                                }
1671                            }
1672                            None => {
1673                                debug!(%conn_index, "end of stream");
1674                                break;
1675                            }
1676                        }
1677                    }
1678                    _ = &mut watch => {
1679                        info!(%conn_index, "shutting down stream on drain");
1680                        try_to_reconnect = false;
1681                        break;
1682                    }
1683                    _ = token_clone.cancelled() => {
1684                        info!(%conn_index, "shutting down stream on cancellation token");
1685                        try_to_reconnect = false;
1686                        break;
1687                    }
1688                }
1689            }
1690
1691            // we drop rx now as otherwise the connection will be closed only
1692            // when the task is dropped and we want to make sure that the rx
1693            // stream is closed as soon as possible
1694            drop(stream);
1695
1696            let mut connected = false;
1697
1698            if try_to_reconnect
1699                && !matches!(category, ConnType::Remote)
1700                && let Some(config) = client_conf_clone
1701            {
1702                // Break the span chain: reconnect → try_to_connect → process_stream
1703                // would otherwise nest under the current process_stream span on every
1704                // reconnection, growing the span hierarchy unboundedly.
1705                connected = self_clone.reconnect(config, conn_index, &token_clone)
1706                    .instrument(tracing::Span::none())
1707                    .await;
1708            } else {
1709                debug!(%conn_index, "close connection")
1710            }
1711
1712            if !connected {
1713                // Delete connection state from all tables.
1714                let local_subs = self_clone
1715                    .forwarder()
1716                    .on_connection_drop(conn_index, category);
1717                let _remote_subs = self_clone
1718                    .remote_sync()
1719                    .on_connection_drop(conn_index);
1720
1721                // Remove peer connection from forwarder's peer list if applicable.
1722                if matches!(category, ConnType::Peer) {
1723                    self_clone.peer_sync().remove_peer_conn(conn_index);
1724                }
1725
1726                // Notify peer sync about names that are no longer reachable.
1727                // For generic topologies (TTL-based relay), we also need to notify
1728                // when a peer drops — the seen_sub_ids tracking ensures we only
1729                // send unsubscribes for subscriptions we actually forwarded.
1730                {
1731                    let fwd = self_clone.peer_sync();
1732                    for name in local_subs.keys() {
1733                        let still_reachable = name.name.is_some_and(|enc| {
1734                            self_clone
1735                                .forwarder()
1736                                .on_publish_msg_match(enc, u64::MAX, u32::MAX, MatchFilter::ALL)
1737                                .is_ok()
1738                        });
1739                        if !still_reachable {
1740                            debug!(
1741                                %name,
1742                                %conn_index,
1743                                ?category,
1744                                "notifying peers of unsubscription (connection drop)"
1745                            );
1746                            fwd.notify_peers_unsubscribe(&self_clone, name).await;
1747                        } else {
1748                            debug!(
1749                                %name,
1750                                %conn_index,
1751                                ?category,
1752                                "name still reachable, not emitting removal"
1753                            );
1754                        }
1755                    }
1756                }
1757
1758
1759                // Notify the control plane about lost subscriptions.
1760                if !is_local {
1761                    MessageProcessor::notify_control_plane_subscriptions_lost(
1762                        tx_cp, local_subs, conn_index,
1763                    )
1764                    .await;
1765                }
1766
1767                info!(telemetry = true, counter.num_active_connections = -1);
1768            }
1769        }.instrument(span));
1770
1771        Ok((handle, conn_index_rx))
1772    }
1773
1774    fn match_for_io_error(err_status: &Status) -> Option<&std::io::Error> {
1775        let mut err: &(dyn std::error::Error + 'static) = err_status;
1776
1777        loop {
1778            if let Some(io_err) = err.downcast_ref::<std::io::Error>() {
1779                return Some(io_err);
1780            }
1781
1782            // h2::Error do not expose std::io::Error with `source()`
1783            // https://github.com/hyperium/h2/pull/462
1784            if let Some(h2_err) = err.downcast_ref::<h2::Error>()
1785                && let Some(io_err) = h2_err.get_io()
1786            {
1787                return Some(io_err);
1788            }
1789
1790            err = err.source()?;
1791        }
1792    }
1793
1794    pub fn subscription_table(&self) -> &SubscriptionTableImpl {
1795        &self.internal.forwarder.subscription_table
1796    }
1797
1798    pub fn connection_table(&self) -> &ConnectionTable<Connection> {
1799        &self.internal.forwarder.connection_table
1800    }
1801
1802    /// The node identity used for cross-node communication.
1803    pub fn service_id(&self) -> &str {
1804        &self.internal.service_id
1805    }
1806
1807    /// Set the peer sync component.
1808    pub fn set_peer_sync(&self, peer_sync: crate::sync::PeerSync) {
1809        *self.internal.peer_sync.write() = peer_sync;
1810    }
1811
1812    /// Get a clone of the peer sync component.
1813    pub(crate) fn peer_sync(&self) -> crate::sync::PeerSync {
1814        self.internal.peer_sync.read().clone()
1815    }
1816}
1817
1818impl ServerHandler for MessageProcessor {
1819    fn grpc_routes(&self) -> Option<tonic::service::Routes> {
1820        let svc = DataPlaneServiceServer::from_arc(Arc::new(self.clone()));
1821        Some(tonic::service::Routes::new(svc))
1822    }
1823
1824    fn on_websocket_accepted(&self) -> Option<websocket_server::OnAcceptedWebSocket> {
1825        let processor = self.clone();
1826        Some(Arc::new(move |accepted| {
1827            let processor = processor.clone();
1828            Box::pin(async move { processor.handle_websocket_accepted(accepted).await })
1829        }))
1830    }
1831}
1832
1833#[tonic::async_trait]
1834impl DataPlaneService for MessageProcessor {
1835    type OpenChannelStream = Pin<Box<dyn Stream<Item = Result<Message, Status>> + Send + 'static>>;
1836
1837    async fn open_channel(
1838        &self,
1839        request: Request<tonic::Streaming<Message>>,
1840    ) -> Result<Response<Self::OpenChannelStream>, Status> {
1841        let remote_addr = request.remote_addr();
1842        let local_addr = request.local_addr();
1843
1844        let stream = request.into_inner();
1845        let (tx, rx) = mpsc::channel(128);
1846
1847        let connection = Connection::new(ConnType::Remote, Channel::Server(tx))
1848            .with_remote_addr(remote_addr)
1849            .with_local_addr(local_addr)
1850            .with_require_header_mac(self.internal.server_require_header_mac);
1851
1852        debug!(
1853            remote = ?connection.remote_addr(),
1854            local = ?connection.local_addr(),
1855            "new connection received from remote",
1856        );
1857        info!(telemetry = true, counter.num_active_connections = 1);
1858
1859        self.process_stream(
1860            stream,
1861            StreamSetup::Pending {
1862                connection: Box::new(connection),
1863                existing_index: None,
1864            },
1865            None,
1866            CancellationToken::new(),
1867            ConnType::Remote,
1868            false,
1869        )
1870        .map_err(|e| {
1871            error!(error = %e.chain(), "error starting new processing stream");
1872            Status::unavailable(format!("error processing stream: {:?}", e))
1873        })?;
1874
1875        let out_stream = ReceiverStream::new(rx);
1876        Ok(Response::new(
1877            Box::pin(out_stream) as Self::OpenChannelStream
1878        ))
1879    }
1880}
1881
1882#[cfg(test)]
1883mod tests {
1884    use std::time::Duration;
1885
1886    use super::*;
1887    use crate::api::{ProtoMessage, ProtoName, ProtoSubscriptionAck};
1888    use crate::header_mac::HeaderMacSession;
1889    use crate::sync::remote::SubscriptionInfo;
1890    use tonic::Status;
1891
1892    async fn assert_failed_subscription_ack_is_sent(add: bool) {
1893        let processor = MessageProcessor::new();
1894        let (in_connection, _tx, mut rx) = processor
1895            .register_local_connection(false)
1896            .expect("failed to create local connection");
1897
1898        let source = ProtoName::from_strings(["org", "ns", "source"]).with_id(1);
1899        let destination = ProtoName::from_strings(["org", "ns", "destination"]).with_id(2);
1900        let ack_id: u64 = if add { 1 } else { 2 };
1901        let invalid_connection = u64::MAX - 1;
1902
1903        let builder = Message::builder()
1904            .source(source.clone())
1905            .destination(destination.clone())
1906            .incoming_conn(invalid_connection)
1907            .subscription_id(ack_id);
1908
1909        let msg = if add {
1910            builder.build_subscribe().unwrap()
1911        } else {
1912            builder.build_unsubscribe().unwrap()
1913        };
1914
1915        let result = processor
1916            .process_subscription(msg, in_connection, add)
1917            .await;
1918        assert!(matches!(
1919            result,
1920            Err(DataPathError::MessageProcessingError { .. })
1921        ));
1922
1923        let ack_msg = tokio::time::timeout(Duration::from_secs(1), rx.recv())
1924            .await
1925            .expect("timeout waiting for ack")
1926            .expect("ack channel closed")
1927            .expect("failed to receive ack message");
1928
1929        assert!(matches!(ack_msg.get_type(), SubscriptionAckType(_)));
1930        let ack = ack_msg.get_subscription_ack();
1931        assert_eq!(ack.subscription_id, ack_id);
1932        assert!(!ack.success, "failed ack should have success=false");
1933        assert!(
1934            !ack.error.is_empty(),
1935            "failed ack should include an error message"
1936        );
1937    }
1938
1939    #[tokio::test]
1940    async fn test_process_subscription_sends_failed_ack_on_subscribe_error() {
1941        assert_failed_subscription_ack_is_sent(true).await;
1942    }
1943
1944    #[tokio::test]
1945    async fn test_process_subscription_sends_failed_ack_on_unsubscribe_error() {
1946        assert_failed_subscription_ack_is_sent(false).await;
1947    }
1948
1949    // ── handle_link_message ───────────────────────────────────────────────────
1950
1951    #[tokio::test]
1952    async fn test_handle_link_message_is_local_ignored() {
1953        let processor = MessageProcessor::new();
1954        let link = ProtoLink { link_type: None };
1955        assert!(
1956            processor
1957                .handle_link_message(link, 0, ConnType::Local)
1958                .await
1959                .is_ok()
1960        );
1961    }
1962
1963    #[tokio::test]
1964    async fn test_handle_link_message_none_link_type_ignored() {
1965        let processor = MessageProcessor::new();
1966        let link = ProtoLink { link_type: None };
1967        assert!(
1968            processor
1969                .handle_link_message(link, 0, ConnType::Remote)
1970                .await
1971                .is_ok()
1972        );
1973    }
1974
1975    // ── handle_link_negotiation ───────────────────────────────────────────────
1976
1977    /// After negotiation completes and the connection is inserted into the table,
1978    /// any further link negotiation messages arriving in the main loop are simply
1979    /// logged and ignored (the handler is a no-op). This test verifies that.
1980    #[tokio::test]
1981    async fn test_handle_link_negotiation_post_negotiation_is_noop() {
1982        let processor = MessageProcessor::new();
1983        let payload = LinkNegotiationPayload {
1984            link_id: uuid::Uuid::new_v4().to_string(),
1985            slim_version: "1.0.0".into(),
1986            is_reply: false,
1987            link_ecdh_public_key: vec![],
1988            connection_type: 0,
1989            node_id: String::new(),
1990            deployment_name: String::new(),
1991        };
1992        // Unknown connection: handler returns Ok without panic.
1993        assert!(
1994            processor
1995                .handle_link_negotiation(&payload, u64::MAX)
1996                .await
1997                .is_ok()
1998        );
1999        // Known connection: handler still returns Ok (noop).
2000        let (conn_id, _rx) = make_negotiated_server_conn(&processor, "1.2.0");
2001        assert!(
2002            processor
2003                .handle_link_negotiation(&payload, conn_id)
2004                .await
2005                .is_ok()
2006        );
2007    }
2008
2009    // ── process_subscription: remote ack path ─────────────────────────────────
2010
2011    /// Helper: create a server connection that is already negotiated with given version and
2012    /// a test HMAC session, suitable for testing routing and MAC verification.
2013    fn make_negotiated_server_conn(
2014        processor: &MessageProcessor,
2015        version: &str,
2016    ) -> (u64, tokio::sync::mpsc::Receiver<Result<Message, Status>>) {
2017        let (tx, rx) = mpsc::channel(16);
2018        let conn = Connection::new(ConnType::Remote, Channel::Server(tx))
2019            .with_require_header_mac(processor.internal.server_require_header_mac)
2020            .with_negotiation(&uuid::Uuid::new_v4().to_string(), version)
2021            .with_header_hmac(HeaderMacSession::new(b"01234567890123456789012345678901").unwrap());
2022        let conn_id = processor
2023            .forwarder()
2024            .on_connection_established(conn, None)
2025            .unwrap();
2026        (conn_id, rx)
2027    }
2028
2029    #[tokio::test]
2030    async fn test_negotiation_timeout_configurable() {
2031        let server_config = ServerConfig {
2032            endpoint: "localhost:12345".to_string(),
2033            negotiation_timeout_secs: 1, // 1 second timeout
2034            ..Default::default()
2035        };
2036        let processor = MessageProcessor::new_with_server_config(
2037            "test_service".to_string(),
2038            String::new(),
2039            &server_config,
2040            false,
2041        );
2042
2043        assert_eq!(
2044            processor.internal.negotiation_timeout,
2045            std::time::Duration::from_secs(1)
2046        );
2047    }
2048
2049    #[test]
2050    fn verify_remote_header_mac_strict_rejects_publish_without_mac_session() {
2051        let processor = MessageProcessor::new();
2052        // Create a negotiated connection WITHOUT header HMAC installed.
2053        let (tx, _rx) = mpsc::channel(16);
2054        let conn = Connection::new(ConnType::Remote, Channel::Server(tx))
2055            .with_require_header_mac(true)
2056            .with_negotiation(&uuid::Uuid::new_v4().to_string(), "1.2.0");
2057        let remote_conn = processor
2058            .forwarder()
2059            .on_connection_established(conn, None)
2060            .unwrap();
2061        let c = processor.forwarder().get_connection(remote_conn).unwrap();
2062        assert!(c.header_hmac().is_none());
2063
2064        let source = ProtoName::from_strings(["org", "default", "a"]).with_id(1);
2065        let dest = ProtoName::from_strings(["org", "default", "b"]).with_id(2);
2066        let msg = ProtoMessage::builder()
2067            .source(source)
2068            .destination(dest)
2069            .application_payload("text/plain", b"hey".to_vec())
2070            .build_publish()
2071            .expect("publish");
2072
2073        let err = processor
2074            .verify_remote_header_mac(remote_conn, &msg, true)
2075            .expect_err("unsigned publish must fail in strict mode without MAC session");
2076        assert!(matches!(err, DataPathError::NegotiationError(_)));
2077    }
2078
2079    #[test]
2080    fn verify_remote_header_mac_accepts_signed_inter_node_publish() {
2081        let processor = MessageProcessor::new();
2082        let (remote_conn, _rx) = make_negotiated_server_conn(&processor, "1.2.0");
2083        let link_id = processor
2084            .forwarder()
2085            .get_connection(remote_conn)
2086            .unwrap()
2087            .link_id()
2088            .expect("link id after negotiation");
2089
2090        let source = ProtoName::from_strings(["org", "default", "a"]).with_id(1);
2091        let dest = ProtoName::from_strings(["org", "default", "b"]).with_id(2);
2092        let require_header_mac = true;
2093        let mut msg = ProtoMessage::builder()
2094            .source(source)
2095            .destination(dest)
2096            .application_payload("text/plain", b"hey".to_vec())
2097            .build_publish()
2098            .expect("publish");
2099
2100        let mac = HeaderMacSession::new(b"01234567890123456789012345678901").unwrap();
2101        mac.sign_slim_header(msg.get_slim_header_mut(), &link_id)
2102            .expect("sign header");
2103
2104        assert!(
2105            processor
2106                .verify_remote_header_mac(remote_conn, &msg, require_header_mac)
2107                .is_ok()
2108        );
2109    }
2110
2111    #[test]
2112    fn verify_remote_header_mac_rejects_destination_tamper_after_sign() {
2113        let processor = MessageProcessor::new();
2114        let (remote_conn, _rx) = make_negotiated_server_conn(&processor, "1.2.0");
2115        let link_id = processor
2116            .forwarder()
2117            .get_connection(remote_conn)
2118            .unwrap()
2119            .link_id()
2120            .expect("link id after negotiation");
2121
2122        let source = ProtoName::from_strings(["org", "default", "a"]).with_id(1);
2123        let dest = ProtoName::from_strings(["org", "default", "b"]).with_id(2);
2124        let mut msg = ProtoMessage::builder()
2125            .source(source)
2126            .destination(dest)
2127            .application_payload("text/plain", b"hey".to_vec())
2128            .build_publish()
2129            .expect("publish");
2130
2131        let mac = HeaderMacSession::new(b"01234567890123456789012345678901").unwrap();
2132        let require_header_mac = true;
2133        mac.sign_slim_header(msg.get_slim_header_mut(), &link_id)
2134            .expect("sign header");
2135
2136        let header = msg.get_slim_header_mut();
2137        if let Some(dest) = header.destination.as_mut()
2138            && let Some(sn) = dest.str_name.as_mut()
2139        {
2140            sn.str_component_2.push_str("-integrity-test-tamper");
2141        }
2142
2143        let err = processor
2144            .verify_remote_header_mac(remote_conn, &msg, require_header_mac)
2145            .expect_err("tampered header must fail MAC verify");
2146        assert!(matches!(err, DataPathError::HeaderIntegrity(_)));
2147    }
2148
2149    #[tokio::test]
2150    #[allow(clippy::disallowed_methods)]
2151    async fn test_send_msg_raw_tamper_destination_env_var() {
2152        let _guard = ENV_LOCK.lock().await;
2153        unsafe {
2154            std::env::set_var("SLIM_TEST_TAMPER_DESTINATION", "1");
2155        }
2156
2157        let processor = MessageProcessor::new();
2158        let (conn_id, mut rx) = make_negotiated_server_conn(&processor, "1.2.0");
2159
2160        let source = ProtoName::from_strings(["org", "default", "a"]).with_id(1);
2161        let dest = ProtoName::from_strings(["org", "default", "b"]).with_id(2);
2162        let msg = ProtoMessage::builder()
2163            .source(source)
2164            .destination(dest)
2165            .application_payload("text/plain", b"hey".to_vec())
2166            .build_publish()
2167            .expect("publish");
2168
2169        processor
2170            .send_msg_raw(msg, conn_id)
2171            .await
2172            .expect("send_msg_raw failed");
2173
2174        let sent_msg = rx.recv().await.unwrap().unwrap();
2175        let header = sent_msg.get_slim_header();
2176        let dest_name = header.destination.as_ref().expect("destination");
2177        let str_name = dest_name.str_name.as_ref().expect("str_name");
2178        let require_header_mac = true;
2179
2180        // The tampering happens in send_msg_raw if the env var is set.
2181        assert!(str_name.str_component_2.ends_with("-integrity-test-tamper"));
2182
2183        // Also verify that verify_remote_header_mac rejects it.
2184        let err = processor
2185            .verify_remote_header_mac(conn_id, &sent_msg, require_header_mac)
2186            .expect_err("tampered header must fail MAC verify");
2187        assert!(matches!(err, DataPathError::HeaderIntegrity(_)));
2188
2189        unsafe {
2190            std::env::remove_var("SLIM_TEST_TAMPER_DESTINATION");
2191        }
2192    }
2193
2194    #[tokio::test]
2195    async fn test_process_subscription_remote_ack_path_success() {
2196        // Arrange: relay processor, local app connection, and a "remote" server
2197        // connection whose version is ≥ 1.2.0.
2198        let processor = MessageProcessor::new();
2199        let (local_conn, _tx_local, mut rx_local) = processor
2200            .register_local_connection(false)
2201            .expect("failed to create local connection");
2202
2203        let (remote_conn, mut rx_remote) = make_negotiated_server_conn(&processor, "1.2.0");
2204
2205        let source = ProtoName::from_strings(["org", "ns", "src"]).with_id(1);
2206        let destination = ProtoName::from_strings(["org", "ns", "dst"]).with_id(2);
2207        let upstream_ack_id: u64 = 100;
2208
2209        // Build subscribe: forward_to = remote_conn, with upstream ack ID.
2210        let sub_msg = Message::builder()
2211            .source(source.clone())
2212            .destination(destination.clone())
2213            .incoming_conn(local_conn)
2214            .forward_to(remote_conn)
2215            .subscription_id(upstream_ack_id)
2216            .build_subscribe()
2217            .unwrap();
2218
2219        // Act: process_subscription should spawn the retry task and return Ok(()).
2220        let result = processor
2221            .process_subscription(sub_msg, local_conn, true)
2222            .await;
2223        assert!(result.is_ok());
2224
2225        // The relay must have forwarded the subscribe to the remote connection.
2226        // Give the spawned task a moment to send the message.
2227        let forwarded = tokio::time::timeout(Duration::from_secs(1), rx_remote.recv())
2228            .await
2229            .expect("timeout waiting for forwarded subscribe")
2230            .expect("forwarded subscribe channel closed")
2231            .unwrap();
2232        assert!(matches!(forwarded.get_type(), SubscribeType(_)));
2233
2234        // The forwarded message must carry the same subscription_id as the original.
2235        let forwarded_sub_id = forwarded
2236            .get_subscription_id()
2237            .expect("forwarded subscribe must carry the same subscription_id");
2238        assert_eq!(
2239            forwarded_sub_id, upstream_ack_id,
2240            "subscription_id must not change when forwarding"
2241        );
2242
2243        // Simulate the remote node sending back a success SubscriptionAck.
2244        let ack = ProtoSubscriptionAck {
2245            subscription_id: upstream_ack_id,
2246            success: true,
2247            error: String::new(),
2248        };
2249        processor.peer_sync().resolve_ack(
2250            ack.subscription_id,
2251            if ack.success {
2252                Ok(())
2253            } else {
2254                Err(DataPathError::RemoteSubscriptionAckError(ack.error.clone()))
2255            },
2256        );
2257
2258        // The relay must now forward the upstream ACK to the local connection.
2259        let upstream_ack = tokio::time::timeout(Duration::from_secs(2), rx_local.recv())
2260            .await
2261            .expect("timeout waiting for upstream ack")
2262            .expect("upstream ack channel closed")
2263            .expect("upstream ack should be Ok");
2264
2265        assert!(matches!(upstream_ack.get_type(), SubscriptionAckType(_)));
2266        let ack_inner = upstream_ack.get_subscription_ack();
2267        assert_eq!(ack_inner.subscription_id, upstream_ack_id);
2268        assert!(ack_inner.success);
2269    }
2270
2271    #[tokio::test]
2272    async fn test_process_subscription_remote_ack_error_forwarded_upstream() {
2273        // Remote node (v1.2.0) sends back a failure ACK; relay must forward it upstream.
2274        let processor = MessageProcessor::new();
2275        let (local_conn, _tx_local, mut rx_local) = processor
2276            .register_local_connection(false)
2277            .expect("failed to create local connection");
2278
2279        let (remote_conn, mut rx_remote) = make_negotiated_server_conn(&processor, "1.2.0");
2280
2281        let source = ProtoName::from_strings(["org", "ns", "src"]).with_id(1);
2282        let destination = ProtoName::from_strings(["org", "ns", "dst"]).with_id(2);
2283        let upstream_ack_id: u64 = 102;
2284
2285        let sub_msg = Message::builder()
2286            .source(source.clone())
2287            .destination(destination.clone())
2288            .incoming_conn(local_conn)
2289            .forward_to(remote_conn)
2290            .subscription_id(upstream_ack_id)
2291            .build_subscribe()
2292            .unwrap();
2293
2294        processor
2295            .process_subscription(sub_msg, local_conn, true)
2296            .await
2297            .unwrap();
2298
2299        let forwarded = tokio::time::timeout(Duration::from_secs(1), rx_remote.recv())
2300            .await
2301            .expect("timeout")
2302            .expect("channel closed")
2303            .unwrap();
2304
2305        let forwarded_sub_id = forwarded
2306            .get_subscription_id()
2307            .expect("forwarded subscribe must carry the same subscription_id");
2308        assert_eq!(
2309            forwarded_sub_id, upstream_ack_id,
2310            "subscription_id must not change when forwarding"
2311        );
2312
2313        // Simulate remote failure via SubscriptionAck.
2314        let ack = ProtoSubscriptionAck {
2315            subscription_id: upstream_ack_id,
2316            success: false,
2317            error: "remote error".to_string(),
2318        };
2319        processor.peer_sync().resolve_ack(
2320            ack.subscription_id,
2321            if ack.success {
2322                Ok(())
2323            } else {
2324                Err(DataPathError::RemoteSubscriptionAckError(ack.error.clone()))
2325            },
2326        );
2327
2328        let upstream_ack = tokio::time::timeout(Duration::from_secs(2), rx_local.recv())
2329            .await
2330            .expect("timeout")
2331            .expect("channel closed")
2332            .expect("must be Ok");
2333
2334        assert!(matches!(upstream_ack.get_type(), SubscriptionAckType(_)));
2335        let ack_inner = upstream_ack.get_subscription_ack();
2336        assert_eq!(ack_inner.subscription_id, upstream_ack_id);
2337        assert!(!ack_inner.success);
2338        assert!(!ack_inner.error.is_empty());
2339    }
2340
2341    // ── notify_control_plane_subscriptions_lost ───────────────────────────────
2342
2343    #[tokio::test]
2344    async fn test_notify_cp_subs_lost_sends_unsubscribes() {
2345        let (tx, mut rx) = mpsc::channel::<Result<Message, Status>>(16);
2346        let mut subs = HashMap::new();
2347        let name = ProtoName::from_strings(["org", "default", "svc"]);
2348        subs.insert(name.clone(), HashSet::from([1u64, 2u64]));
2349
2350        MessageProcessor::notify_control_plane_subscriptions_lost(Some(tx), subs, 42).await;
2351
2352        let msg = rx.recv().await.unwrap().unwrap();
2353        assert!(matches!(msg.get_type(), UnsubscribeType(_)));
2354        assert_eq!(msg.get_source(), name.clone());
2355    }
2356
2357    #[tokio::test]
2358    async fn test_notify_cp_subs_lost_no_tx_is_noop() {
2359        let subs = HashMap::from([(
2360            ProtoName::from_strings(["org", "default", "svc"]),
2361            HashSet::from([1u64]),
2362        )]);
2363        // Should not panic or hang.
2364        MessageProcessor::notify_control_plane_subscriptions_lost(None, subs, 1).await;
2365    }
2366
2367    #[tokio::test]
2368    async fn test_notify_cp_subs_lost_empty_subs() {
2369        let (tx, mut rx) = mpsc::channel::<Result<Message, Status>>(16);
2370        MessageProcessor::notify_control_plane_subscriptions_lost(Some(tx), HashMap::new(), 1)
2371            .await;
2372        // No messages should be sent.
2373        assert!(rx.try_recv().is_err());
2374    }
2375
2376    // ── restore_remote_subscriptions ──────────────────────────────────────────
2377
2378    #[tokio::test]
2379    async fn test_restore_remote_subscriptions_with_tracking() {
2380        let processor = MessageProcessor::new();
2381        let (conn_id, mut rx) = make_negotiated_server_conn(&processor, "1.2.0");
2382
2383        let source = ProtoName::from_strings(["org", "default", "src"]);
2384        let dest = ProtoName::from_strings(["org", "default", "dst"]);
2385        let sub = SubscriptionInfo::new(source.clone(), dest.clone(), "id1".into(), conn_id, 7);
2386        let subs = HashSet::from([sub]);
2387
2388        processor
2389            .restore_remote_subscriptions(&subs, conn_id, true)
2390            .await;
2391
2392        // The subscribe message should have been sent.
2393        let msg = rx.recv().await.unwrap().unwrap();
2394        assert!(matches!(msg.get_type(), SubscribeType(_)));
2395
2396        // With restore_tracking=true, the forwarded subscription should be tracked.
2397        let tracked = processor
2398            .remote_sync()
2399            .get_subscriptions_for_reconnect(conn_id);
2400        assert_eq!(tracked.len(), 1);
2401    }
2402
2403    #[tokio::test]
2404    async fn test_restore_remote_subscriptions_without_tracking() {
2405        let processor = MessageProcessor::new();
2406        let (conn_id, mut rx) = make_negotiated_server_conn(&processor, "1.2.0");
2407
2408        let source = ProtoName::from_strings(["org", "default", "src"]);
2409        let dest = ProtoName::from_strings(["org", "default", "dst"]);
2410        let sub = SubscriptionInfo::new(source.clone(), dest.clone(), "id1".into(), conn_id, 7);
2411        let subs = HashSet::from([sub]);
2412
2413        processor
2414            .restore_remote_subscriptions(&subs, conn_id, false)
2415            .await;
2416
2417        // Message sent.
2418        let msg = rx.recv().await.unwrap().unwrap();
2419        assert!(matches!(msg.get_type(), SubscribeType(_)));
2420
2421        // With restore_tracking=false, forwarded subscription table should NOT be updated.
2422        let tracked = processor
2423            .remote_sync()
2424            .get_subscriptions_for_reconnect(conn_id);
2425        assert!(tracked.is_empty());
2426    }
2427}