Skip to main content

slim_session/
session_config.rs

1// Copyright AGNTCY Contributors (https://github.com/agntcy)
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::HashMap;
5
6use slim_datapath::api::{CommandPayload, ProtoSessionType};
7
8use crate::{SessionError, timer_factory::TimerSettings};
9
/// MLS-related options attached to a session.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct MlsSettings {
    /// Header integrity validation level, expressed as a percentage.
    ///
    /// When read from a join request the value is capped at 100.
    /// NOTE(review): the derived `Default` yields 0 — confirm that is the intended default.
    pub header_integrity_validation_percent: u32,
}
14
15#[derive(Default, Clone, Debug, PartialEq)]
16pub struct SessionConfig {
17    /// session type
18    pub session_type: ProtoSessionType,
19
20    /// number of retries for each message/rtx
21    pub max_retries: Option<u32>,
22
23    /// interval between retries
24    pub interval: Option<std::time::Duration>,
25
26    /// MLS related settings
27    pub mls_settings: Option<MlsSettings>,
28
29    /// true is the local endpoint is initiator of the session
30    pub initiator: bool,
31
32    /// metadata related to the sessions
33    pub metadata: HashMap<String, String>,
34}
35
36impl SessionConfig {
37    #[allow(deprecated)]
38    fn mls_settings_from_join(
39        join: &slim_datapath::api::JoinRequestPayload,
40    ) -> Option<MlsSettings> {
41        if join.mls_settings.is_some() {
42            let header_integrity_validation_percent = join
43                .mls_settings
44                .as_ref()
45                .map(|wire| wire.header_integrity_validation_percent.min(100))
46                .unwrap_or(100);
47            Some(MlsSettings {
48                header_integrity_validation_percent,
49            })
50        } else {
51            None
52        }
53    }
54
55    pub fn with_session_type(self, session_type: ProtoSessionType) -> Self {
56        Self {
57            session_type,
58            max_retries: self.max_retries,
59            interval: self.interval,
60            initiator: self.initiator,
61            metadata: self.metadata,
62            mls_settings: self.mls_settings,
63        }
64    }
65
66    pub fn get_timer_settings(&self) -> TimerSettings {
67        TimerSettings::constant(self.interval.unwrap_or(std::time::Duration::from_secs(1)))
68            .with_max_retries(self.max_retries.unwrap_or(10))
69    }
70
71    pub fn from_join_request(
72        session_type: ProtoSessionType,
73        payload: &CommandPayload,
74        metadata: HashMap<String, String>,
75        initiator: bool,
76    ) -> Result<Self, SessionError> {
77        let join = payload.as_join_request_payload()?;
78        let (duration, max_retries) = if let Some(ts) = &join.timer_settings {
79            (
80                Some(std::time::Duration::from_millis(ts.timeout as u64)),
81                Some(ts.max_retries),
82            )
83        } else {
84            (None, None)
85        };
86
87        Ok(SessionConfig {
88            session_type,
89            max_retries,
90            interval: duration,
91            mls_settings: Self::mls_settings_from_join(join),
92            initiator,
93            metadata,
94        })
95    }
96}
97
#[cfg(test)]
mod tests {
    use super::*;
    use slim_datapath::api::CommandPayload;
    use slim_datapath::api::ProtoName as Name;
    use slim_datapath::messages::utils::MessageError;
    use std::time::Duration;

    #[test]
    fn test_default() {
        let config = SessionConfig::default();
        assert_eq!(config.session_type, ProtoSessionType::Unspecified);
        assert_eq!(config.max_retries, None);
        assert_eq!(config.interval, None);
        assert!(config.mls_settings.is_none());
        assert!(!config.initiator);
        assert!(config.metadata.is_empty());
    }

    #[test]
    fn test_with_session_type() {
        let mut metadata = HashMap::new();
        metadata.insert("key1".to_string(), "value1".to_string());

        let config = SessionConfig {
            session_type: ProtoSessionType::Unspecified,
            max_retries: Some(5),
            interval: Some(Duration::from_secs(10)),
            initiator: true,
            metadata,
            mls_settings: Some(MlsSettings::default()),
        };

        let new_config = config.with_session_type(ProtoSessionType::Multicast);

        assert_eq!(new_config.session_type, ProtoSessionType::Multicast);
        assert_eq!(new_config.max_retries, Some(5));
        assert_eq!(new_config.interval, Some(Duration::from_secs(10)));
        assert_eq!(new_config.mls_settings, Some(MlsSettings::default()));
        assert!(new_config.initiator);
        assert_eq!(new_config.metadata.len(), 1);
        assert_eq!(new_config.metadata.get("key1"), Some(&"value1".to_string()));
    }

    #[test]
    fn test_with_session_type_point_to_point() {
        let config = SessionConfig::default();
        let new_config = config.with_session_type(ProtoSessionType::PointToPoint);
        assert_eq!(new_config.session_type, ProtoSessionType::PointToPoint);
    }

    #[test]
    fn test_from_join_request_with_timer_settings() {
        let dest = Name::from_strings(["dest", "", ""]);
        let payload = CommandPayload::builder().join_request(
            Some(3),
            Some(Duration::from_millis(500)),
            Some(dest),
            Some(slim_datapath::api::ProtoMlsSettings {
                header_integrity_validation_percent: 100,
            }),
        );

        let mut metadata = HashMap::new();
        metadata.insert("test_key".to_string(), "test_value".to_string());

        let config = SessionConfig::from_join_request(
            ProtoSessionType::Multicast,
            &payload,
            metadata,
            true,
        )
        .unwrap();

        assert_eq!(config.session_type, ProtoSessionType::Multicast);
        assert_eq!(config.max_retries, Some(3));
        assert_eq!(config.interval, Some(Duration::from_millis(500)));
        assert_eq!(
            config.mls_settings,
            Some(MlsSettings {
                header_integrity_validation_percent: 100
            })
        );
        assert!(config.initiator);
        assert_eq!(config.metadata.len(), 1);
        assert_eq!(
            config.metadata.get("test_key"),
            Some(&"test_value".to_string())
        );
    }

    #[test]
    fn test_from_join_request_without_timer_settings() {
        let dest = Name::from_strings(["dest", "", ""]);
        let payload = CommandPayload::builder().join_request(None, None, Some(dest), None);

        let metadata = HashMap::new();

        let config = SessionConfig::from_join_request(
            ProtoSessionType::PointToPoint,
            &payload,
            metadata,
            false,
        )
        .unwrap();

        assert_eq!(config.session_type, ProtoSessionType::PointToPoint);
        assert_eq!(config.max_retries, None);
        assert_eq!(config.interval, None);
        assert!(config.mls_settings.is_none());
        assert!(!config.initiator);
        assert!(config.metadata.is_empty());
    }

    #[test]
    fn test_from_join_request_with_mls_enabled() {
        let dest = Name::from_strings(["dest", "", ""]);
        let payload = CommandPayload::builder().join_request(
            Some(10),
            Some(Duration::from_secs(5)),
            Some(dest),
            Some(slim_datapath::api::ProtoMlsSettings {
                header_integrity_validation_percent: 100,
            }),
        );

        let config = SessionConfig::from_join_request(
            ProtoSessionType::Multicast,
            &payload,
            HashMap::new(),
            false,
        )
        .unwrap();

        assert_eq!(
            config.mls_settings,
            Some(MlsSettings {
                header_integrity_validation_percent: 100
            })
        );
    }

    #[test]
    fn test_from_join_request_mls_percent_is_clamped() {
        // (value on the wire, value expected in the resulting config)
        let cases = [(0, 0), (42, 42), (100, 100), (101, 100), (u32::MAX, 100)];

        for (wire_percent, expected) in cases {
            let dest = Name::from_strings(["dest", "", ""]);
            let payload = CommandPayload::builder().join_request(
                None,
                None,
                Some(dest),
                Some(slim_datapath::api::ProtoMlsSettings {
                    header_integrity_validation_percent: wire_percent,
                }),
            );

            let config = SessionConfig::from_join_request(
                ProtoSessionType::Multicast,
                &payload,
                HashMap::new(),
                false,
            )
            .unwrap();

            assert_eq!(
                config.mls_settings,
                Some(MlsSettings {
                    header_integrity_validation_percent: expected
                }),
                "wire percent {wire_percent}"
            );
        }
    }

    #[test]
    fn test_from_join_request_invalid_payload() {
        // Create a payload that is not a join request
        let payload = CommandPayload::builder().leave_request();

        let result = SessionConfig::from_join_request(
            ProtoSessionType::Multicast,
            &payload,
            HashMap::new(),
            true,
        );

        assert!(result.is_err_and(|e| matches!(
            e,
            SessionError::MessageError(MessageError::InvalidCommandPayloadType {
                expected: _,
                got: _
            })
        )));
    }

    #[test]
    fn test_clone() {
        let mut metadata = HashMap::new();
        metadata.insert("key".to_string(), "value".to_string());

        let config = SessionConfig {
            session_type: ProtoSessionType::Multicast,
            max_retries: Some(7),
            interval: Some(Duration::from_millis(1000)),
            initiator: false,
            metadata,
            mls_settings: Some(MlsSettings::default()),
        };

        let cloned = config.clone();

        assert_eq!(cloned.session_type, config.session_type);
        assert_eq!(cloned.max_retries, config.max_retries);
        assert_eq!(cloned.interval, config.interval);
        assert_eq!(cloned.mls_settings, config.mls_settings);
        assert_eq!(cloned.initiator, config.initiator);
        assert_eq!(cloned.metadata, config.metadata);
        assert_eq!(cloned, config);
    }

    #[test]
    fn test_from_join_request_with_large_timeout() {
        let dest = Name::from_strings(["dest", "", ""]);
        let payload = CommandPayload::builder().join_request(
            Some(100),
            Some(Duration::from_secs(3600)), // 1 hour
            Some(dest),
            None,
        );

        let config = SessionConfig::from_join_request(
            ProtoSessionType::Multicast,
            &payload,
            HashMap::new(),
            true,
        )
        .unwrap();

        assert_eq!(config.max_retries, Some(100));
        assert_eq!(config.interval, Some(Duration::from_secs(3600)));
    }

    #[test]
    fn test_metadata_preservation() {
        let mut metadata = HashMap::new();
        metadata.insert("key1".to_string(), "value1".to_string());
        metadata.insert("key2".to_string(), "value2".to_string());
        metadata.insert("key3".to_string(), "value3".to_string());

        let config = SessionConfig {
            session_type: ProtoSessionType::Unspecified,
            max_retries: None,
            interval: None,
            initiator: false,
            metadata,
            mls_settings: None,
        };

        let new_config = config.with_session_type(ProtoSessionType::Multicast);

        assert_eq!(new_config.metadata.len(), 3);
        assert_eq!(new_config.metadata.get("key1"), Some(&"value1".to_string()));
        assert_eq!(new_config.metadata.get("key2"), Some(&"value2".to_string()));
        assert_eq!(new_config.metadata.get("key3"), Some(&"value3".to_string()));
    }
}