#[cfg(not(feature = "std"))]
use alloc::{
string::{String, ToString},
vec::Vec,
};
/// TLS settings for either side of a connection, assembled with chainable
/// builder methods.
///
/// Every field is public; the `with_*` / [`TlsConfig::allow_self_signed`]
/// methods are conveniences that consume `self`, update one or two fields and
/// hand the config back so calls can be chained. Their results are
/// `#[must_use]` because dropping the returned value discards the update.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TlsConfig {
    /// Certificate file path; set together with `key_path` by [`TlsConfig::with_cert`].
    pub cert_path: Option<String>,
    /// Private-key file path; set together with `cert_path` by [`TlsConfig::with_cert`].
    pub key_path: Option<String>,
    /// CA certificate path (presumably used to verify the peer — confirm at the use site).
    pub ca_path: Option<String>,
    /// Expected server name (NOTE(review): SNI/hostname use is not visible in this file).
    pub server_name: Option<String>,
    /// Opt-in flag for self-signed peer certificates; also makes
    /// [`TlsConfig::is_valid_client`] return `true` on its own.
    pub allow_self_signed: bool,
    /// Minimum protocol version as a free-form string such as `"1.3"`; not validated here.
    pub min_version: Option<String>,
    /// ALPN protocol identifiers, in the order they were added.
    pub alpn_protocols: Vec<String>,
}

impl TlsConfig {
    /// Creates an empty configuration; identical to [`TlsConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the certificate and private-key paths together, replacing any previous values.
    #[must_use]
    pub fn with_cert(mut self, cert_path: &str, key_path: &str) -> Self {
        self.cert_path = Some(cert_path.to_string());
        self.key_path = Some(key_path.to_string());
        self
    }

    /// Sets the CA certificate path, replacing any previous value.
    #[must_use]
    pub fn with_ca(mut self, ca_path: &str) -> Self {
        self.ca_path = Some(ca_path.to_string());
        self
    }

    /// Sets the expected server name, replacing any previous value.
    #[must_use]
    pub fn with_server_name(mut self, name: &str) -> Self {
        self.server_name = Some(name.to_string());
        self
    }

    /// Enables the `allow_self_signed` flag. There is no builder to clear it;
    /// assign the public field directly if needed.
    #[must_use]
    pub fn allow_self_signed(mut self) -> Self {
        self.allow_self_signed = true;
        self
    }

    /// Sets the minimum protocol version string; the value is stored verbatim.
    #[must_use]
    pub fn with_min_version(mut self, version: &str) -> Self {
        self.min_version = Some(version.to_string());
        self
    }

    /// Appends one ALPN protocol identifier; repeated calls accumulate in order
    /// (duplicates are not filtered).
    #[must_use]
    pub fn with_alpn(mut self, protocol: &str) -> Self {
        self.alpn_protocols.push(protocol.to_string());
        self
    }

    /// Returns `true` when at least one of `server_name`, `ca_path` or
    /// `allow_self_signed` is set — the minimum this module accepts for a client.
    #[must_use]
    pub fn is_valid_client(&self) -> bool {
        self.server_name.is_some() || self.ca_path.is_some() || self.allow_self_signed
    }

    /// Returns `true` when both a certificate path and a key path are present.
    #[must_use]
    pub fn is_valid_server(&self) -> bool {
        self.cert_path.is_some() && self.key_path.is_some()
    }
}
#[derive(Debug, Clone, Default)]
pub struct DtlsConfig {
pub tls: TlsConfig,
pub mtu: Option<u16>,
pub replay_protection: bool,
pub retransmit_timeout_ms: Option<u32>,
}
impl DtlsConfig {
pub fn new() -> Self {
Self {
replay_protection: true,
..Default::default()
}
}
pub fn from_tls(tls: TlsConfig) -> Self {
Self {
tls,
replay_protection: true,
..Default::default()
}
}
pub fn with_mtu(mut self, mtu: u16) -> Self {
self.mtu = Some(mtu);
self
}
pub fn with_retransmit_timeout(mut self, timeout_ms: u32) -> Self {
self.retransmit_timeout_ms = Some(timeout_ms);
self
}
}
/// Connection state for the TLS layer.
///
/// Transitions between variants are driven by code outside this file. The
/// variant order is significant: it fixes the discriminant values.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum TlsState {
    /// Not connected; this is what `TlsState::default()` returns.
    #[default]
    Disconnected,
    /// Handshake phase.
    Handshaking,
    /// Connected.
    Connected,
    /// Closed.
    Closed,
    /// Error state.
    Error,
}
/// Negotiated parameters of a TLS handshake.
///
/// Nothing in this file constructs this type; its producer lives elsewhere.
/// Field order is kept as-is because it determines the `Debug` output.
#[derive(Clone, Debug)]
pub struct HandshakeResult {
    /// Protocol version, stored as a string.
    pub protocol_version: String,
    /// Cipher suite name, stored as a string.
    pub cipher_suite: String,
    /// Peer fingerprint, when one is available.
    pub peer_fingerprint: Option<String>,
    /// Selected ALPN protocol, when one was negotiated.
    pub alpn_protocol: Option<String>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every builder method stores its argument in the matching field.
    #[test]
    fn test_tls_config_builder() {
        let config = TlsConfig::new()
            .with_cert("/path/to/cert.pem", "/path/to/key.pem")
            .with_ca("/path/to/ca.pem")
            .with_server_name("example.com")
            .with_min_version("1.3")
            .with_alpn("h2");
        assert_eq!(config.cert_path, Some("/path/to/cert.pem".to_string()));
        assert_eq!(config.key_path, Some("/path/to/key.pem".to_string()));
        assert_eq!(config.ca_path, Some("/path/to/ca.pem".to_string()));
        assert_eq!(config.server_name, Some("example.com".to_string()));
        assert_eq!(config.min_version, Some("1.3".to_string()));
        assert_eq!(config.alpn_protocols, vec!["h2".to_string()]);
    }

    /// Repeated `with_alpn` calls append in call order rather than overwrite.
    #[test]
    fn test_tls_config_alpn_accumulates() {
        let config = TlsConfig::new().with_alpn("h2").with_alpn("http/1.1");
        assert_eq!(
            config.alpn_protocols,
            vec!["h2".to_string(), "http/1.1".to_string()]
        );
    }

    /// Positive and negative cases for both validity predicates.
    #[test]
    fn test_tls_config_validation() {
        let client_config = TlsConfig::new().with_server_name("example.com");
        assert!(client_config.is_valid_client());
        assert!(!client_config.is_valid_server());

        let server_config = TlsConfig::new().with_cert("/path/to/cert.pem", "/path/to/key.pem");
        assert!(server_config.is_valid_server());

        let self_signed_client = TlsConfig::new().allow_self_signed();
        assert!(self_signed_client.is_valid_client());

        // A CA path on its own is also enough for a client.
        let ca_client = TlsConfig::new().with_ca("/path/to/ca.pem");
        assert!(ca_client.is_valid_client());

        // An empty configuration satisfies neither role.
        let empty = TlsConfig::new();
        assert!(!empty.is_valid_client());
        assert!(!empty.is_valid_server());
    }

    /// DTLS builders set their fields and `new()` enables replay protection.
    #[test]
    fn test_dtls_config() {
        let dtls = DtlsConfig::new()
            .with_mtu(1400)
            .with_retransmit_timeout(500);
        assert_eq!(dtls.mtu, Some(1400));
        assert_eq!(dtls.retransmit_timeout_ms, Some(500));
        assert!(dtls.replay_protection);
    }

    /// `from_tls` keeps the supplied TLS settings and defaults the DTLS fields.
    #[test]
    fn test_dtls_from_tls() {
        let tls = TlsConfig::new()
            .with_server_name("example.com")
            .with_alpn("coap");
        let dtls = DtlsConfig::from_tls(tls);
        assert_eq!(dtls.tls.server_name, Some("example.com".to_string()));
        assert_eq!(dtls.tls.alpn_protocols, vec!["coap".to_string()]);
        assert!(dtls.replay_protection);
        assert_eq!(dtls.mtu, None);
        assert_eq!(dtls.retransmit_timeout_ms, None);
    }

    /// The default connection state is `Disconnected`.
    #[test]
    fn test_tls_state_default() {
        let state = TlsState::default();
        assert_eq!(state, TlsState::Disconnected);
    }
}