1use std::collections::HashMap;
5
6use slim_datapath::api::{CommandPayload, ProtoSessionType};
7
8use crate::{SessionError, timer_factory::TimerSettings};
9
/// MLS-related options carried by a [`SessionConfig`].
///
/// The derived `Default` sets `header_integrity_validation_percent` to `0`.
#[derive(Clone, Debug, Default, PartialEq)]
pub struct MlsSettings {
    /// Header-integrity validation percentage.
    ///
    /// When built by `SessionConfig::from_join_request`, the wire value is
    /// clamped to at most 100. NOTE(review): presumably the share of
    /// messages whose header integrity is validated — confirm upstream.
    pub header_integrity_validation_percent: u32,
}
14
15#[derive(Default, Clone, Debug, PartialEq)]
16pub struct SessionConfig {
17 pub session_type: ProtoSessionType,
19
20 pub max_retries: Option<u32>,
22
23 pub interval: Option<std::time::Duration>,
25
26 pub mls_settings: Option<MlsSettings>,
28
29 pub initiator: bool,
31
32 pub metadata: HashMap<String, String>,
34}
35
36impl SessionConfig {
37 #[allow(deprecated)]
38 fn mls_settings_from_join(
39 join: &slim_datapath::api::JoinRequestPayload,
40 ) -> Option<MlsSettings> {
41 if join.mls_settings.is_some() {
42 let header_integrity_validation_percent = join
43 .mls_settings
44 .as_ref()
45 .map(|wire| wire.header_integrity_validation_percent.min(100))
46 .unwrap_or(100);
47 Some(MlsSettings {
48 header_integrity_validation_percent,
49 })
50 } else {
51 None
52 }
53 }
54
55 pub fn with_session_type(self, session_type: ProtoSessionType) -> Self {
56 Self {
57 session_type,
58 max_retries: self.max_retries,
59 interval: self.interval,
60 initiator: self.initiator,
61 metadata: self.metadata,
62 mls_settings: self.mls_settings,
63 }
64 }
65
66 pub fn get_timer_settings(&self) -> TimerSettings {
67 TimerSettings::constant(self.interval.unwrap_or(std::time::Duration::from_secs(1)))
68 .with_max_retries(self.max_retries.unwrap_or(10))
69 }
70
71 pub fn from_join_request(
72 session_type: ProtoSessionType,
73 payload: &CommandPayload,
74 metadata: HashMap<String, String>,
75 initiator: bool,
76 ) -> Result<Self, SessionError> {
77 let join = payload.as_join_request_payload()?;
78 let (duration, max_retries) = if let Some(ts) = &join.timer_settings {
79 (
80 Some(std::time::Duration::from_millis(ts.timeout as u64)),
81 Some(ts.max_retries),
82 )
83 } else {
84 (None, None)
85 };
86
87 Ok(SessionConfig {
88 session_type,
89 max_retries,
90 interval: duration,
91 mls_settings: Self::mls_settings_from_join(join),
92 initiator,
93 metadata,
94 })
95 }
96}
97
#[cfg(test)]
mod tests {
    use super::*;
    use slim_datapath::api::CommandPayload;
    use slim_datapath::api::ProtoName as Name;
    use slim_datapath::messages::utils::MessageError;
    use std::time::Duration;

    #[test]
    fn test_default() {
        let config = SessionConfig::default();
        assert_eq!(config.session_type, ProtoSessionType::Unspecified);
        assert_eq!(config.max_retries, None);
        assert_eq!(config.interval, None);
        assert!(config.mls_settings.is_none());
        assert!(!config.initiator);
        assert!(config.metadata.is_empty());
    }

    #[test]
    fn test_with_session_type() {
        let mut metadata = HashMap::new();
        metadata.insert("key1".to_string(), "value1".to_string());

        let config = SessionConfig {
            session_type: ProtoSessionType::Unspecified,
            max_retries: Some(5),
            interval: Some(Duration::from_secs(10)),
            initiator: true,
            metadata,
            mls_settings: Some(MlsSettings::default()),
        };

        let new_config = config.with_session_type(ProtoSessionType::Multicast);

        assert_eq!(new_config.session_type, ProtoSessionType::Multicast);
        assert_eq!(new_config.max_retries, Some(5));
        assert_eq!(new_config.interval, Some(Duration::from_secs(10)));
        assert_eq!(new_config.mls_settings, Some(MlsSettings::default()));
        assert!(new_config.initiator);
        assert_eq!(new_config.metadata.len(), 1);
        assert_eq!(new_config.metadata.get("key1"), Some(&"value1".to_string()));
    }

    #[test]
    fn test_with_session_type_point_to_point() {
        let config = SessionConfig::default();
        let new_config = config.with_session_type(ProtoSessionType::PointToPoint);
        assert_eq!(new_config.session_type, ProtoSessionType::PointToPoint);
    }

    #[test]
    fn test_from_join_request_with_timer_settings() {
        let dest = Name::from_strings(["dest", "", ""]);
        let payload = CommandPayload::builder().join_request(
            Some(3),
            Some(Duration::from_millis(500)),
            Some(dest),
            Some(slim_datapath::api::ProtoMlsSettings {
                header_integrity_validation_percent: 100,
            }),
        );

        let mut metadata = HashMap::new();
        metadata.insert("test_key".to_string(), "test_value".to_string());

        let config = SessionConfig::from_join_request(
            ProtoSessionType::Multicast,
            &payload,
            metadata,
            true,
        )
        .unwrap();

        assert_eq!(config.session_type, ProtoSessionType::Multicast);
        assert_eq!(config.max_retries, Some(3));
        assert_eq!(config.interval, Some(Duration::from_millis(500)));
        assert_eq!(
            config.mls_settings,
            Some(MlsSettings {
                header_integrity_validation_percent: 100,
            })
        );
        assert!(config.initiator);
        assert_eq!(config.metadata.len(), 1);
        assert_eq!(
            config.metadata.get("test_key"),
            Some(&"test_value".to_string())
        );
    }

    #[test]
    fn test_from_join_request_without_timer_settings() {
        let dest = Name::from_strings(["dest", "", ""]);
        let payload = CommandPayload::builder().join_request(None, None, Some(dest), None);

        let metadata = HashMap::new();

        let config = SessionConfig::from_join_request(
            ProtoSessionType::PointToPoint,
            &payload,
            metadata,
            false,
        )
        .unwrap();

        assert_eq!(config.session_type, ProtoSessionType::PointToPoint);
        assert_eq!(config.max_retries, None);
        assert_eq!(config.interval, None);
        assert!(config.mls_settings.is_none());
        assert!(!config.initiator);
        assert!(config.metadata.is_empty());
    }

    #[test]
    fn test_from_join_request_with_mls_enabled() {
        let dest = Name::from_strings(["dest", "", ""]);
        let payload = CommandPayload::builder().join_request(
            Some(10),
            Some(Duration::from_secs(5)),
            Some(dest),
            Some(slim_datapath::api::ProtoMlsSettings {
                header_integrity_validation_percent: 100,
            }),
        );

        let config = SessionConfig::from_join_request(
            ProtoSessionType::Multicast,
            &payload,
            HashMap::new(),
            false,
        )
        .unwrap();

        assert_eq!(
            config.mls_settings,
            Some(MlsSettings {
                header_integrity_validation_percent: 100,
            })
        );
    }

    #[test]
    fn test_from_join_request_clamps_header_integrity_percent() {
        // Values above 100 received on the wire must be clamped to 100.
        let dest = Name::from_strings(["dest", "", ""]);
        let payload = CommandPayload::builder().join_request(
            None,
            None,
            Some(dest),
            Some(slim_datapath::api::ProtoMlsSettings {
                header_integrity_validation_percent: 150,
            }),
        );

        let config = SessionConfig::from_join_request(
            ProtoSessionType::Multicast,
            &payload,
            HashMap::new(),
            false,
        )
        .unwrap();

        assert_eq!(
            config.mls_settings,
            Some(MlsSettings {
                header_integrity_validation_percent: 100,
            })
        );
    }

    #[test]
    fn test_from_join_request_keeps_header_integrity_percent_below_cap() {
        // Values at or below the cap must pass through unchanged.
        let dest = Name::from_strings(["dest", "", ""]);
        let payload = CommandPayload::builder().join_request(
            None,
            None,
            Some(dest),
            Some(slim_datapath::api::ProtoMlsSettings {
                header_integrity_validation_percent: 42,
            }),
        );

        let config = SessionConfig::from_join_request(
            ProtoSessionType::Multicast,
            &payload,
            HashMap::new(),
            false,
        )
        .unwrap();

        assert_eq!(
            config.mls_settings,
            Some(MlsSettings {
                header_integrity_validation_percent: 42,
            })
        );
    }

    #[test]
    fn test_from_join_request_invalid_payload() {
        let payload = CommandPayload::builder().leave_request();

        let result = SessionConfig::from_join_request(
            ProtoSessionType::Multicast,
            &payload,
            HashMap::new(),
            true,
        );

        assert!(result.is_err_and(|e| matches!(
            e,
            SessionError::MessageError(MessageError::InvalidCommandPayloadType {
                expected: _,
                got: _
            })
        )));
    }

    #[test]
    fn test_clone() {
        let mut metadata = HashMap::new();
        metadata.insert("key".to_string(), "value".to_string());

        let config = SessionConfig {
            session_type: ProtoSessionType::Multicast,
            max_retries: Some(7),
            interval: Some(Duration::from_millis(1000)),
            initiator: false,
            metadata,
            mls_settings: Some(MlsSettings::default()),
        };

        let cloned = config.clone();

        assert_eq!(cloned.session_type, config.session_type);
        assert_eq!(cloned.max_retries, config.max_retries);
        assert_eq!(cloned.interval, config.interval);
        assert_eq!(cloned.mls_settings, config.mls_settings);
        assert_eq!(cloned.initiator, config.initiator);
        assert_eq!(cloned.metadata, config.metadata);
    }

    #[test]
    fn test_from_join_request_with_large_timeout() {
        let dest = Name::from_strings(["dest", "", ""]);
        let payload = CommandPayload::builder().join_request(
            Some(100),
            Some(Duration::from_secs(3600)),
            Some(dest),
            None,
        );

        let config = SessionConfig::from_join_request(
            ProtoSessionType::Multicast,
            &payload,
            HashMap::new(),
            true,
        )
        .unwrap();

        assert_eq!(config.max_retries, Some(100));
        assert_eq!(config.interval, Some(Duration::from_secs(3600)));
    }

    #[test]
    fn test_metadata_preservation() {
        let mut metadata = HashMap::new();
        metadata.insert("key1".to_string(), "value1".to_string());
        metadata.insert("key2".to_string(), "value2".to_string());
        metadata.insert("key3".to_string(), "value3".to_string());

        let config = SessionConfig {
            session_type: ProtoSessionType::Unspecified,
            max_retries: None,
            interval: None,
            initiator: false,
            metadata,
            mls_settings: None,
        };

        let new_config = config.with_session_type(ProtoSessionType::Multicast);

        assert_eq!(new_config.metadata.len(), 3);
        assert_eq!(new_config.metadata.get("key1"), Some(&"value1".to_string()));
        assert_eq!(new_config.metadata.get("key2"), Some(&"value2".to_string()));
        assert_eq!(new_config.metadata.get("key3"), Some(&"value3".to_string()));
    }
}