Skip to main content

slim_datapath/
connection.rs

1// Copyright AGNTCY Contributors (https://github.com/agntcy)
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::api::proto::dataplane::v1::Message;
5use semver::Version;
6use slim_config::client::ClientConfig;
7use std::net::SocketAddr;
8use tokio::sync::mpsc;
9use tokio_util::sync::CancellationToken;
10use tonic::Status;
11
12use crate::header_mac::HeaderMacSession;
13
14#[derive(Debug, Clone)]
15pub enum Channel {
16    Server(mpsc::Sender<Result<Message, Status>>),
17    Client(mpsc::Sender<Message>),
18}
19
20use crate::tables::ConnType;
21
22#[derive(Clone)]
23/// Connection information.
24pub struct Connection {
25    /// Remote address and port. Not available for local connections
26    remote_addr: Option<SocketAddr>,
27
28    /// Local address and port. Not available for remote connections
29    local_addr: Option<SocketAddr>,
30
31    /// Channel to send messages
32    channel: Channel,
33
34    /// Configuration data for the connection.
35    config_data: Option<ClientConfig>,
36
37    /// Connection type
38    connection_type: ConnType,
39
40    /// cancellation token to stop the receiving loop on this connection
41    cancellation_token: Option<CancellationToken>,
42
43    /// Link identifier shared between both sides of a remote link.
44    link_id: Option<String>,
45
46    /// SLIM version of the remote peer (set during negotiation).
47    remote_slim_version: Option<Version>,
48
49    /// HMAC session derived from the ECDH key exchange (set during negotiation).
50    header_hmac: Option<HeaderMacSession>,
51
52    /// Strict header MAC policy for this connection (fixed at establishment).
53    require_header_mac: bool,
54
55    /// Remote node identifier, set during link negotiation.
56    peer_node_id: Option<String>,
57}
58
59impl std::fmt::Debug for Connection {
60    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61        f.debug_struct("Connection")
62            .field("remote_addr", &self.remote_addr)
63            .field("local_addr", &self.local_addr)
64            .field("channel", &self.channel)
65            .field("config_data", &self.config_data)
66            .field("connection_type", &self.connection_type)
67            .field("link_id", &self.link_id)
68            .field("remote_slim_version", &self.remote_slim_version)
69            .field("header_hmac", &self.header_hmac.is_some())
70            .finish_non_exhaustive()
71    }
72}
73
74/// Implementation of Connection
75impl Connection {
76    /// Create a new Connection
77    pub fn new(connection_type: ConnType, channel: Channel) -> Self {
78        Self {
79            remote_addr: None,
80            local_addr: None,
81            channel,
82            config_data: None,
83            connection_type,
84            cancellation_token: None,
85            link_id: None,
86            remote_slim_version: None,
87            header_hmac: None,
88            require_header_mac: false,
89            peer_node_id: None,
90        }
91    }
92
93    /// Set whether strict header MAC verification applies on this connection.
94    pub(crate) fn with_require_header_mac(self, require_header_mac: bool) -> Self {
95        Self {
96            require_header_mac,
97            ..self
98        }
99    }
100
101    pub(crate) fn require_header_mac(&self) -> bool {
102        self.require_header_mac
103    }
104
105    /// Set the remote address
106    pub(crate) fn with_remote_addr(self, remote_addr: Option<SocketAddr>) -> Self {
107        Self {
108            remote_addr,
109            ..self
110        }
111    }
112
113    /// Set the local address
114    pub(crate) fn with_local_addr(self, local_addr: Option<SocketAddr>) -> Self {
115        Self { local_addr, ..self }
116    }
117
118    /// Set the configuration data for the connection
119    pub(crate) fn with_config_data(self, config_data: Option<ClientConfig>) -> Self {
120        Self {
121            config_data,
122            ..self
123        }
124    }
125
126    pub(crate) fn header_hmac(&self) -> Option<&HeaderMacSession> {
127        self.header_hmac.as_ref()
128    }
129
130    pub(crate) fn install_header_hmac(&mut self, mac: HeaderMacSession) {
131        self.header_hmac = Some(mac);
132    }
133
134    /// Get the remote address
135    pub fn remote_addr(&self) -> Option<&SocketAddr> {
136        self.remote_addr.as_ref()
137    }
138
139    /// Get the local address
140    pub fn local_addr(&self) -> Option<&SocketAddr> {
141        self.local_addr.as_ref()
142    }
143
144    /// Get the channel
145    pub(crate) fn channel(&self) -> &Channel {
146        &self.channel
147    }
148
149    pub fn config_data(&self) -> Option<&ClientConfig> {
150        self.config_data.as_ref()
151    }
152
153    /// Get the connection type
154    pub fn connection_type(&self) -> ConnType {
155        self.connection_type
156    }
157
158    /// Upgrade the connection type (e.g., from Remote to Peer after negotiation).
159    pub(crate) fn set_connection_type(&mut self, conn_type: ConnType) {
160        self.connection_type = conn_type;
161    }
162
163    /// Return true if is a local connection
164    pub(crate) fn is_local_connection(&self) -> bool {
165        matches!(self.connection_type, ConnType::Local)
166    }
167
168    /// Return true if is a peer connection (same deployment replica)
169    #[allow(dead_code)]
170    pub(crate) fn is_peer_connection(&self) -> bool {
171        matches!(self.connection_type, ConnType::Peer)
172    }
173
174    /// Return true if this node initiated the connection (outbound dial).
175    ///
176    /// gRPC inbound peers use [`Channel::Server`]; outbound dials use [`Channel::Client`]
177    /// with [`config_data`](Self::config_data) set from [`ClientConfig`].
178    ///
179    /// WebSocket is asymmetric: the server accept path still uses [`Channel::Client`] for
180    /// writes, but leaves `config_data` unset, so inbound WebSocket is distinguished from
181    /// outbound WebSocket (which always carries `config_data` from the dial).
182    pub fn is_outgoing(&self) -> bool {
183        matches!(self.channel, Channel::Client(_)) && self.config_data.is_some()
184    }
185
186    /// Set cancellation token
187    pub(crate) fn with_cancellation_token(
188        self,
189        cancellation_token: Option<CancellationToken>,
190    ) -> Self {
191        Self {
192            cancellation_token,
193            ..self
194        }
195    }
196
197    /// Get cancellation token
198    pub(crate) fn cancellation_token(&self) -> Option<&CancellationToken> {
199        self.cancellation_token.as_ref()
200    }
201
202    /// Set the link identifier at construction time (client side).
203    pub(crate) fn with_link_id(mut self, link_id: String) -> Self {
204        self.link_id = Some(link_id);
205        self
206    }
207
208    /// Set the shared link identifier for this connection.
209    pub fn set_link_id(&mut self, link_id: String) {
210        self.link_id = Some(link_id);
211    }
212
213    /// Get the shared link identifier for this connection.
214    pub fn link_id(&self) -> Option<String> {
215        self.link_id.clone()
216    }
217
218    /// Get the SLIM version of the remote peer.
219    pub fn remote_slim_version(&self) -> Option<Version> {
220        self.remote_slim_version.clone()
221    }
222
223    /// Get the remote peer's node identifier (set during link negotiation).
224    pub fn peer_node_id(&self) -> Option<&str> {
225        self.peer_node_id.as_deref()
226    }
227
228    /// Set the remote peer's node identifier.
229    pub(crate) fn set_peer_node_id(&mut self, node_id: String) {
230        self.peer_node_id = Some(node_id);
231    }
232
233    /// Returns true if link negotiation has completed (remote_slim_version is set).
234    pub fn is_negotiated(&self) -> bool {
235        self.remote_slim_version.is_some()
236    }
237
238    /// Complete link negotiation on the server (incoming) path.
239    ///
240    /// Stores `link_id` and `version`. Returns `false` if `link_id` is empty or
241    /// negotiation is already complete (replay protection).
242    pub fn complete_negotiation_as_server(&mut self, link_id: &str, version: Version) -> bool {
243        if self.remote_slim_version.is_some() {
244            return false;
245        }
246        if link_id.is_empty() {
247            return false;
248        }
249        self.link_id = Some(link_id.to_string());
250        self.remote_slim_version = Some(version);
251        true
252    }
253
254    /// Complete link negotiation on the client (outgoing) path.
255    ///
256    /// Verifies the echoed `link_id` matches what was set, then stores `version`.
257    /// Returns `false` if there is a mismatch or negotiation is already complete.
258    pub fn complete_negotiation_as_client(&mut self, link_id: &str, version: Version) -> bool {
259        if self.remote_slim_version.is_some() {
260            return false;
261        }
262        if self.link_id.as_deref() != Some(link_id) {
263            return false;
264        }
265        self.remote_slim_version = Some(version);
266        true
267    }
268
269    /// Send a message directly through this connection's channel.
270    pub(crate) async fn send(&self, msg: Message) -> Result<(), crate::errors::DataPathError> {
271        match &self.channel {
272            Channel::Server(tx) => tx
273                .send(Ok(msg))
274                .await
275                .map_err(|_| crate::errors::DataPathError::ConnectionSendError),
276            Channel::Client(tx) => tx
277                .send(msg)
278                .await
279                .map_err(|_| crate::errors::DataPathError::ConnectionSendError),
280        }
281    }
282
283    /// Set negotiation state at construction time.
284    pub fn with_negotiation(mut self, link_id: &str, version: &str) -> Self {
285        self.link_id = Some(link_id.to_string());
286        self.remote_slim_version = version.parse().ok();
287        self
288    }
289
290    /// Set header HMAC at construction time (for testing).
291    #[cfg(test)]
292    pub(crate) fn with_header_hmac(mut self, mac: HeaderMacSession) -> Self {
293        self.header_hmac = Some(mac);
294        self
295    }
296}
297
#[cfg(test)]
mod tests {
    use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};

    use super::*;
    use tokio::sync::mpsc;

    /// Inbound-style connection backed by a `Channel::Server` sender.
    fn server_conn() -> Connection {
        let (tx, _rx) = mpsc::channel(1);
        Connection::new(ConnType::Remote, Channel::Server(tx))
    }

    /// Outbound-style connection: `Channel::Client` plus client config data.
    fn client_conn() -> Connection {
        let (tx, _rx) = mpsc::channel(1);
        Connection::new(ConnType::Remote, Channel::Client(tx))
            .with_config_data(Some(ClientConfig::default()))
    }

    #[test]
    fn test_is_outgoing_client() {
        assert!(client_conn().is_outgoing());
    }

    #[test]
    fn test_is_outgoing_server() {
        assert!(!server_conn().is_outgoing());
    }

    #[test]
    fn test_is_outgoing_websocket_inbound() {
        let (tx, _rx) = mpsc::channel(1);
        let conn = Connection::new(ConnType::Remote, Channel::Client(tx));
        assert!(!conn.is_outgoing());
    }

    #[test]
    fn test_link_id_initially_none() {
        assert!(server_conn().link_id().is_none());
    }

    #[test]
    fn test_connection_format_print() {
        // Build addresses directly; no fallible resolution round-trip needed.
        let remote = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 8080));
        let local = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(127, 0, 0, 1), 8081));

        let conn = client_conn()
            .with_remote_addr(Some(remote))
            .with_local_addr(Some(local));
        let debug = format!("{conn:?}");

        assert!(debug.starts_with("Connection"));
        assert!(debug.contains("connection_type: Remote"));
        assert!(debug.contains("remote_addr: Some"));
        assert!(debug.contains("local_addr: Some"));
    }

    #[test]
    fn test_set_and_get_link_id() {
        let mut conn = server_conn();
        conn.set_link_id("my-link".to_string());
        assert_eq!(conn.link_id(), Some("my-link".to_string()));
    }

    #[test]
    fn test_with_link_id_sets_link_id() {
        let conn = client_conn().with_link_id("builder-link".to_string());
        assert_eq!(conn.link_id(), Some("builder-link".to_string()));
    }

    #[test]
    fn test_remote_slim_version_initially_none() {
        assert!(server_conn().remote_slim_version().is_none());
    }

    #[test]
    fn test_is_negotiated_initially_false() {
        assert!(!server_conn().is_negotiated());
        assert!(!client_conn().is_negotiated());
    }

    #[test]
    fn test_is_negotiated_true_after_server_negotiation() {
        let mut conn = server_conn();
        assert!(conn.complete_negotiation_as_server("link-id", Version::parse("1.0.0").unwrap()));
        assert!(conn.is_negotiated());
    }

    #[test]
    fn test_is_negotiated_true_after_client_negotiation() {
        let mut conn = client_conn();
        let id = uuid::Uuid::new_v4().to_string();
        conn.set_link_id(id.clone());
        assert!(conn.complete_negotiation_as_client(&id, Version::parse("1.0.0").unwrap()));
        assert!(conn.is_negotiated());
    }

    #[test]
    fn test_with_negotiation_marks_negotiated() {
        let conn = client_conn().with_negotiation("neg-link", "1.2.3");
        assert!(conn.is_negotiated());
        assert_eq!(conn.link_id(), Some("neg-link".to_string()));
        assert_eq!(conn.remote_slim_version(), Some(Version::new(1, 2, 3)));
    }

    #[test]
    fn test_complete_negotiation_as_server_stores_link_id() {
        let mut conn = server_conn();
        let id = "my-custom-link-id";
        let v = Version::parse("1.2.3").unwrap();
        assert!(conn.complete_negotiation_as_server(id, v.clone()));
        assert_eq!(conn.link_id(), Some(id.to_string()));
        assert_eq!(conn.remote_slim_version(), Some(v));
    }

    #[test]
    fn test_complete_negotiation_as_server_rejects_empty_link_id() {
        let mut conn = server_conn();
        assert!(!conn.complete_negotiation_as_server("", Version::parse("1.0.0").unwrap()));
        assert!(conn.link_id().is_none());
        assert!(conn.remote_slim_version().is_none());
    }

    #[test]
    fn test_complete_negotiation_as_server_replay_returns_false() {
        let mut conn = server_conn();
        let id = uuid::Uuid::new_v4().to_string();
        let v1 = Version::parse("1.0.0").unwrap();
        assert!(conn.complete_negotiation_as_server(&id, v1.clone()));
        // Second call must be rejected; state must not change.
        assert!(!conn.complete_negotiation_as_server(&id, Version::parse("2.0.0").unwrap()));
        assert_eq!(conn.remote_slim_version(), Some(v1));
    }

    #[test]
    fn test_complete_negotiation_as_client_accepts_matching_link_id() {
        let mut conn = client_conn();
        let id = uuid::Uuid::new_v4().to_string();
        conn.set_link_id(id.clone());
        let v = Version::parse("1.0.0").unwrap();
        assert!(conn.complete_negotiation_as_client(&id, v.clone()));
        assert_eq!(conn.remote_slim_version(), Some(v));
    }

    #[test]
    fn test_complete_negotiation_as_client_rejects_mismatched_link_id() {
        let mut conn = client_conn();
        conn.set_link_id(uuid::Uuid::new_v4().to_string());
        assert!(!conn.complete_negotiation_as_client("wrong-id", Version::parse("1.0.0").unwrap()));
        assert!(conn.remote_slim_version().is_none());
    }

    #[test]
    fn test_complete_negotiation_as_client_rejects_without_local_link_id() {
        let mut conn = client_conn();
        assert!(!conn.complete_negotiation_as_client("any-id", Version::parse("1.0.0").unwrap()));
        assert!(!conn.is_negotiated());
    }

    #[test]
    fn test_complete_negotiation_as_client_replay_returns_false() {
        let mut conn = client_conn();
        let id = uuid::Uuid::new_v4().to_string();
        conn.set_link_id(id.clone());
        let v1 = Version::parse("1.0.0").unwrap();
        assert!(conn.complete_negotiation_as_client(&id, v1.clone()));
        // Second call must be rejected; state must not change.
        assert!(!conn.complete_negotiation_as_client(&id, Version::parse("2.0.0").unwrap()));
        assert_eq!(conn.remote_slim_version(), Some(v1));
    }

    #[test]
    fn test_require_header_mac_defaults_false_and_is_settable() {
        assert!(!server_conn().require_header_mac());
        assert!(server_conn().with_require_header_mac(true).require_header_mac());
    }

    #[test]
    fn test_peer_node_id_roundtrip() {
        let mut conn = server_conn();
        assert!(conn.peer_node_id().is_none());
        conn.set_peer_node_id("node-a".to_string());
        assert_eq!(conn.peer_node_id(), Some("node-a"));
    }

    #[test]
    fn test_connection_type_predicates() {
        let (tx, _rx) = mpsc::channel(1);
        let mut conn = Connection::new(ConnType::Local, Channel::Client(tx));
        assert!(conn.is_local_connection());
        assert!(!conn.is_peer_connection());
        conn.set_connection_type(ConnType::Peer);
        assert!(!conn.is_local_connection());
        assert!(conn.is_peer_connection());
    }
}