use std::collections::HashMap;
use slim_datapath::api::{CommandPayload, ProtoSessionType};
use crate::{SessionError, timer_factory::TimerSettings};
/// MLS-related settings attached to a session.
///
/// The only field is a plain integer, so total equality (`Eq`) is derived
/// alongside `PartialEq`.
#[derive(Default, Clone, Debug, PartialEq, Eq)]
pub struct MlsSettings {
    /// Header-integrity validation percentage. Values built from a join
    /// request are clamped to at most `100`; the derived `Default` is `0`.
    /// NOTE(review): presumably the share of messages whose headers are
    /// integrity-validated — confirm against the MLS layer.
    pub header_integrity_validation_percent: u32,
}
#[derive(Default, Clone, Debug, PartialEq)]
pub struct SessionConfig {
pub session_type: ProtoSessionType,
pub max_retries: Option<u32>,
pub interval: Option<std::time::Duration>,
pub mls_settings: Option<MlsSettings>,
pub initiator: bool,
pub metadata: HashMap<String, String>,
}
impl SessionConfig {
    /// Timer interval used by [`Self::get_timer_settings`] when
    /// `interval` is `None`.
    const DEFAULT_TIMER_INTERVAL: std::time::Duration = std::time::Duration::from_secs(1);

    /// Retry limit used by [`Self::get_timer_settings`] when `max_retries`
    /// is `None`.
    const DEFAULT_MAX_RETRIES: u32 = 10;

    /// Upper bound applied to the header-integrity validation percentage
    /// received in a join request.
    const MAX_HEADER_INTEGRITY_VALIDATION_PERCENT: u32 = 100;

    /// Converts the optional wire MLS settings of a join request into local
    /// [`MlsSettings`].
    ///
    /// Returns `None` when the join request carries no MLS settings; otherwise
    /// the header-integrity validation percentage is clamped to at most 100.
    // NOTE(review): the allow presumably silences deprecation warnings on the
    // generated wire fields (`mls_settings` / `header_integrity_validation_percent`)
    // — confirm against `slim_datapath::api`.
    #[allow(deprecated)]
    fn mls_settings_from_join(
        join: &slim_datapath::api::JoinRequestPayload,
    ) -> Option<MlsSettings> {
        join.mls_settings.as_ref().map(|wire| MlsSettings {
            header_integrity_validation_percent: wire
                .header_integrity_validation_percent
                .min(Self::MAX_HEADER_INTEGRITY_VALIDATION_PERCENT),
        })
    }

    /// Returns this configuration with `session_type` replaced; every other
    /// field is carried over unchanged.
    pub fn with_session_type(self, session_type: ProtoSessionType) -> Self {
        // Struct-update syntax moves all remaining fields from `self`, so fields
        // added in the future are carried over automatically.
        Self {
            session_type,
            ..self
        }
    }

    /// Builds the [`TimerSettings`] for this session.
    ///
    /// Calls `TimerSettings::constant` with the configured `interval` (falling
    /// back to 1 second) and sets the retry limit via `with_max_retries` from
    /// `max_retries` (falling back to 10).
    pub fn get_timer_settings(&self) -> TimerSettings {
        TimerSettings::constant(self.interval.unwrap_or(Self::DEFAULT_TIMER_INTERVAL))
            .with_max_retries(self.max_retries.unwrap_or(Self::DEFAULT_MAX_RETRIES))
    }

    /// Builds a session configuration from a join-request command payload.
    ///
    /// If the join request carries timer settings, they populate `interval`
    /// (the wire `timeout` is interpreted as milliseconds) and `max_retries`.
    /// Otherwise both stay `None`. MLS settings are converted from the request.
    /// The `session_type`, `metadata` and `initiator` arguments are stored
    /// unchanged.
    ///
    /// # Errors
    ///
    /// Propagates the error from `CommandPayload::as_join_request_payload`
    /// when `payload` is not a join request (e.g. a leave request yields
    /// `SessionError::MessageError(MessageError::InvalidCommandPayloadType { .. })`).
    pub fn from_join_request(
        session_type: ProtoSessionType,
        payload: &CommandPayload,
        metadata: HashMap<String, String>,
        initiator: bool,
    ) -> Result<Self, SessionError> {
        let join = payload.as_join_request_payload()?;

        let (interval, max_retries) = match &join.timer_settings {
            Some(ts) => (
                // NOTE(review): `as u64` assumes `timeout` is a non-negative integer
                // no wider than 64 bits — confirm against the proto definition.
                Some(std::time::Duration::from_millis(ts.timeout as u64)),
                Some(ts.max_retries),
            ),
            None => (None, None),
        };

        Ok(Self {
            session_type,
            max_retries,
            interval,
            mls_settings: Self::mls_settings_from_join(join),
            initiator,
            metadata,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use slim_datapath::api::CommandPayload;
    use slim_datapath::api::ProtoName as Name;
    use slim_datapath::messages::utils::MessageError;
    use std::time::Duration;

    // The derived `Default` leaves the session type `Unspecified`, every
    // optional field `None`, `initiator` false and the metadata map empty.
    #[test]
    fn test_default() {
        let config = SessionConfig::default();
        assert_eq!(config.session_type, ProtoSessionType::Unspecified);
        assert_eq!(config.max_retries, None);
        assert_eq!(config.interval, None);
        assert!(config.mls_settings.is_none());
        assert!(!config.initiator);
        assert!(config.metadata.is_empty());
    }

    // `with_session_type` replaces only the session type. Retries, interval,
    // MLS settings, initiator and metadata must all carry over unchanged.
    #[test]
    fn test_with_session_type() {
        let mut metadata = HashMap::new();
        metadata.insert("key1".to_string(), "value1".to_string());
        let config = SessionConfig {
            session_type: ProtoSessionType::Unspecified,
            max_retries: Some(5),
            interval: Some(Duration::from_secs(10)),
            initiator: true,
            metadata: metadata.clone(),
            mls_settings: Some(MlsSettings::default()),
        };
        let new_config = config.with_session_type(ProtoSessionType::Multicast);
        assert_eq!(new_config.session_type, ProtoSessionType::Multicast);
        assert_eq!(new_config.max_retries, Some(5));
        assert_eq!(new_config.interval, Some(Duration::from_secs(10)));
        assert!(new_config.mls_settings.is_some());
        assert!(new_config.initiator);
        assert_eq!(new_config.metadata.len(), 1);
        assert_eq!(new_config.metadata.get("key1"), Some(&"value1".to_string()));
    }

    // Same as above for the point-to-point session type, starting from defaults.
    #[test]
    fn test_with_session_type_point_to_point() {
        let config = SessionConfig::default();
        let new_config = config.with_session_type(ProtoSessionType::PointToPoint);
        assert_eq!(new_config.session_type, ProtoSessionType::PointToPoint);
    }

    // Timer settings in the join request populate `max_retries` and `interval`
    // (the 500 ms duration round-trips exactly). MLS settings become `Some`.
    // The caller's `initiator` flag and metadata are stored as given.
    #[test]
    fn test_from_join_request_with_timer_settings() {
        let dest = Name::from_strings(["dest", "", ""]);
        let payload = CommandPayload::builder().join_request(
            Some(3),
            Some(Duration::from_millis(500)),
            Some(dest),
            Some(slim_datapath::api::ProtoMlsSettings {
                header_integrity_validation_percent: 100,
            }),
        );
        let mut metadata = HashMap::new();
        metadata.insert("test_key".to_string(), "test_value".to_string());
        let config = SessionConfig::from_join_request(
            ProtoSessionType::Multicast,
            &payload,
            metadata.clone(),
            true,
        )
        .unwrap();
        assert_eq!(config.session_type, ProtoSessionType::Multicast);
        assert_eq!(config.max_retries, Some(3));
        assert_eq!(config.interval, Some(Duration::from_millis(500)));
        assert!(config.mls_settings.is_some());
        assert!(config.initiator);
        assert_eq!(config.metadata.len(), 1);
        assert_eq!(
            config.metadata.get("test_key"),
            Some(&"test_value".to_string())
        );
    }

    // A join request without timer or MLS settings leaves `max_retries`,
    // `interval` and `mls_settings` all `None`.
    #[test]
    fn test_from_join_request_without_timer_settings() {
        let dest = Name::from_strings(["dest", "", ""]);
        let payload = CommandPayload::builder().join_request(None, None, Some(dest), None);
        let metadata = HashMap::new();
        let config = SessionConfig::from_join_request(
            ProtoSessionType::PointToPoint,
            &payload,
            metadata,
            false,
        )
        .unwrap();
        assert_eq!(config.session_type, ProtoSessionType::PointToPoint);
        assert_eq!(config.max_retries, None);
        assert_eq!(config.interval, None);
        assert!(config.mls_settings.is_none());
        assert!(!config.initiator);
        assert!(config.metadata.is_empty());
    }

    // MLS settings present on the wire yield `Some` local MLS settings.
    #[test]
    fn test_from_join_request_with_mls_enabled() {
        let dest = Name::from_strings(["dest", "", ""]);
        let payload = CommandPayload::builder().join_request(
            Some(10),
            Some(Duration::from_secs(5)),
            Some(dest),
            Some(slim_datapath::api::ProtoMlsSettings {
                header_integrity_validation_percent: 100,
            }),
        );
        let config = SessionConfig::from_join_request(
            ProtoSessionType::Multicast,
            &payload,
            HashMap::new(),
            false,
        )
        .unwrap();
        assert!(config.mls_settings.is_some());
    }

    // A non-join payload (here a leave request) is rejected with
    // `InvalidCommandPayloadType` wrapped in `SessionError::MessageError`.
    #[test]
    fn test_from_join_request_invalid_payload() {
        let payload = CommandPayload::builder().leave_request();
        let result = SessionConfig::from_join_request(
            ProtoSessionType::Multicast,
            &payload,
            HashMap::new(),
            true,
        );
        assert!(result.is_err_and(|e| matches!(
            e,
            SessionError::MessageError(MessageError::InvalidCommandPayloadType {
                expected: _,
                got: _
            })
        )));
    }

    // The derived `Clone` copies every field.
    #[test]
    fn test_clone() {
        let mut metadata = HashMap::new();
        metadata.insert("key".to_string(), "value".to_string());
        let config = SessionConfig {
            session_type: ProtoSessionType::Multicast,
            max_retries: Some(7),
            interval: Some(Duration::from_millis(1000)),
            initiator: false,
            metadata: metadata.clone(),
            mls_settings: Some(MlsSettings::default()),
        };
        let cloned = config.clone();
        assert_eq!(cloned.session_type, config.session_type);
        assert_eq!(cloned.max_retries, config.max_retries);
        assert_eq!(cloned.interval, config.interval);
        assert_eq!(cloned.mls_settings, config.mls_settings);
        assert_eq!(cloned.initiator, config.initiator);
        assert_eq!(cloned.metadata, config.metadata);
    }

    // A one-hour timeout must survive the wire round-trip unchanged, along
    // with a larger retry count.
    #[test]
    fn test_from_join_request_with_large_timeout() {
        let dest = Name::from_strings(["dest", "", ""]);
        let payload = CommandPayload::builder().join_request(
            Some(100),
            Some(Duration::from_secs(3600)),
            Some(dest),
            None,
        );
        let config = SessionConfig::from_join_request(
            ProtoSessionType::Multicast,
            &payload,
            HashMap::new(),
            true,
        )
        .unwrap();
        assert_eq!(config.max_retries, Some(100));
        assert_eq!(config.interval, Some(Duration::from_secs(3600)));
    }

    // `with_session_type` keeps every metadata entry, not just one.
    #[test]
    fn test_metadata_preservation() {
        let mut metadata = HashMap::new();
        metadata.insert("key1".to_string(), "value1".to_string());
        metadata.insert("key2".to_string(), "value2".to_string());
        metadata.insert("key3".to_string(), "value3".to_string());
        let config = SessionConfig {
            session_type: ProtoSessionType::Unspecified,
            max_retries: None,
            interval: None,
            initiator: false,
            metadata: metadata.clone(),
            mls_settings: None,
        };
        let new_config = config.with_session_type(ProtoSessionType::Multicast);
        assert_eq!(new_config.metadata.len(), 3);
        assert_eq!(new_config.metadata.get("key1"), Some(&"value1".to_string()));
        assert_eq!(new_config.metadata.get("key2"), Some(&"value2".to_string()));
        assert_eq!(new_config.metadata.get("key3"), Some(&"value3".to_string()));
    }
}