use std::fmt;
use serde::{Deserialize, Serialize};
use crate::security::errors::{Result, SecurityError};
/// TLS protocol versions recognised by this module.
///
/// Variants are declared oldest-first, so the derived `PartialOrd`/`Ord`
/// make an older version compare as less than a newer one. Do not reorder
/// them: minimum-version checks rely on that ordering.
///
/// Marked `#[non_exhaustive]`, so further variants may be added later.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
#[non_exhaustive]
pub enum TlsVersion {
/// TLS 1.0 (displayed as `"TLS 1.0"`).
V1_0,
/// TLS 1.1 (displayed as `"TLS 1.1"`).
V1_1,
/// TLS 1.2 (displayed as `"TLS 1.2"`).
V1_2,
/// TLS 1.3 (displayed as `"TLS 1.3"`).
V1_3,
}
/// Human-readable rendering of a [`TlsVersion`], e.g. `"TLS 1.2"`.
impl fmt::Display for TlsVersion {
    /// Writes the fixed label for this version to `f`.
    ///
    /// The label is written verbatim; width/padding flags on the formatter
    /// are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pick the label first, then emit it with a single write.
        let label = match *self {
            Self::V1_0 => "TLS 1.0",
            Self::V1_1 => "TLS 1.1",
            Self::V1_2 => "TLS 1.2",
            Self::V1_3 => "TLS 1.3",
        };
        f.write_str(label)
    }
}
/// Snapshot of the transport-security properties of one connection, as
/// consumed by `TlsEnforcer::validate_connection`.
///
/// All fields are public and independent: nothing in this type prevents
/// inconsistent combinations (for example `is_secure == false` together with
/// `has_client_cert == true`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TlsConnection {
/// `true` when the connection is TLS-protected; `false` for plain HTTP.
pub is_secure: bool,
/// TLS version of the connection. The enforcer only compares it against the
/// configured minimum when `is_secure` is `true`.
pub version: TlsVersion,
/// `true` when a client certificate was provided.
pub has_client_cert: bool,
/// `true` when the provided client certificate passed validation. The
/// enforcer only looks at this when `has_client_cert` is `true`.
pub client_cert_valid: bool,
}
impl TlsConnection {
    /// Describes a plain-HTTP connection: not secure and without a client
    /// certificate.
    ///
    /// `version` is set to [`TlsVersion::V1_2`]; `TlsEnforcer::validate_connection`
    /// only compares the version when `is_secure` is `true`.
    #[must_use]
    pub const fn new_http() -> Self {
        Self {
            has_client_cert: false,
            client_cert_valid: false,
            version: TlsVersion::V1_2,
            is_secure: false,
        }
    }

    /// Describes a TLS connection at `version` with no client certificate.
    #[must_use]
    pub const fn new_secure(version: TlsVersion) -> Self {
        // Start from the "no certificate" baseline and switch TLS on.
        let mut conn = Self::new_http();
        conn.is_secure = true;
        conn.version = version;
        conn
    }

    /// Describes a TLS connection at `version` whose client certificate is
    /// both present and marked valid.
    #[must_use]
    pub const fn new_secure_with_client_cert(version: TlsVersion) -> Self {
        let mut conn = Self::new_secure(version);
        conn.has_client_cert = true;
        conn.client_cert_valid = true;
        conn
    }
}
/// Policy describing what `TlsEnforcer` demands of a connection.
///
/// The three settings are independent public fields, so any combination can
/// be constructed directly or deserialized.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TlsConfig {
/// When `true`, connections that are not secure are rejected.
pub tls_required: bool,
/// When `true`, connections without a client certificate are rejected.
pub mtls_required: bool,
/// Oldest TLS version accepted on a secure connection.
pub min_version: TlsVersion,
}
impl TlsConfig {
    /// Values returned by [`TlsConfig::permissive`].
    const PERMISSIVE_PRESET: Self = Self {
        min_version: TlsVersion::V1_2,
        tls_required: false,
        mtls_required: false,
    };

    /// Values returned by [`TlsConfig::standard`].
    const STANDARD_PRESET: Self = Self {
        min_version: TlsVersion::V1_2,
        tls_required: true,
        mtls_required: false,
    };

    /// Values returned by [`TlsConfig::strict`].
    const STRICT_PRESET: Self = Self {
        min_version: TlsVersion::V1_3,
        tls_required: true,
        mtls_required: true,
    };

    /// Preset that requires neither TLS nor a client certificate; secure
    /// connections must still be at least TLS 1.2.
    #[must_use]
    pub const fn permissive() -> Self {
        Self::PERMISSIVE_PRESET
    }

    /// Preset that requires TLS (minimum TLS 1.2) but no client certificate.
    #[must_use]
    pub const fn standard() -> Self {
        Self::STANDARD_PRESET
    }

    /// Preset that requires TLS 1.3 or newer and a client certificate.
    #[must_use]
    pub const fn strict() -> Self {
        Self::STRICT_PRESET
    }
}
/// Validates `TlsConnection` values against a fixed `TlsConfig` policy.
#[derive(Debug, Clone)]
pub struct TlsEnforcer {
// Policy applied by `validate_connection`; readable through `config()`.
config: TlsConfig,
}
impl TlsEnforcer {
    /// Builds an enforcer that applies `config`.
    #[must_use]
    pub const fn from_config(config: TlsConfig) -> Self {
        Self { config }
    }

    /// Enforcer using [`TlsConfig::permissive`].
    #[must_use]
    pub const fn permissive() -> Self {
        Self::from_config(TlsConfig::permissive())
    }

    /// Enforcer using [`TlsConfig::standard`].
    #[must_use]
    pub const fn standard() -> Self {
        Self::from_config(TlsConfig::standard())
    }

    /// Enforcer using [`TlsConfig::strict`].
    #[must_use]
    pub const fn strict() -> Self {
        Self::from_config(TlsConfig::strict())
    }

    /// Checks `conn` against the configured policy.
    ///
    /// The checks run in a fixed order and the first failure is returned.
    ///
    /// # Errors
    ///
    /// * [`SecurityError::TlsRequired`] - TLS is required and `conn` is not
    ///   secure.
    /// * [`SecurityError::TlsVersionTooOld`] - `conn` is secure but its
    ///   version is below the configured minimum.
    /// * [`SecurityError::MtlsRequired`] - a client certificate is required
    ///   and either none was provided or the connection is not secure.
    /// * [`SecurityError::InvalidClientCert`] - a client certificate was
    ///   provided but is not marked valid (checked even when mTLS is not
    ///   required).
    pub fn validate_connection(&self, conn: &TlsConnection) -> Result<()> {
        let policy = &self.config;

        if policy.tls_required && !conn.is_secure {
            return Err(SecurityError::TlsRequired {
                detail: "HTTPS required, but connection is HTTP".to_string(),
            });
        }

        // The version is only meaningful on a secure connection.
        if conn.is_secure && conn.version < policy.min_version {
            return Err(SecurityError::TlsVersionTooOld {
                current: conn.version,
                required: policy.min_version,
            });
        }

        if policy.mtls_required && !conn.has_client_cert {
            return Err(SecurityError::MtlsRequired {
                detail: "Client certificate required, but none provided".to_string(),
            });
        }

        if conn.has_client_cert && !conn.client_cert_valid {
            return Err(SecurityError::InvalidClientCert {
                detail: "Client certificate provided but validation failed".to_string(),
            });
        }

        // Mutual TLS cannot be satisfied without TLS: certificate flags on a
        // plaintext connection must not count as client authentication, even
        // when `tls_required` is off. Kept last so that every case that was
        // already rejected above still reports the same error variant.
        if policy.mtls_required && !conn.is_secure {
            return Err(SecurityError::MtlsRequired {
                detail: "Client certificate required, but connection is not TLS".to_string(),
            });
        }

        Ok(())
    }

    /// Returns the policy this enforcer applies.
    #[must_use]
    pub const fn config(&self) -> &TlsConfig {
        &self.config
    }
}
// Unit tests for the TLS policy types and `TlsEnforcer::validate_connection`.
//
// Passing cases use `unwrap_or_else(|e| panic!(...))` so that a failure
// prints the returned error; failing cases assert on the error variant.
#[cfg(test)]
mod tests {
use super::*;
// Permissive policy: a plain HTTP connection is accepted.
#[test]
fn test_http_allowed_when_tls_not_required() {
let enforcer = TlsEnforcer::permissive();
let conn = TlsConnection::new_http();
enforcer
.validate_connection(&conn)
.unwrap_or_else(|e| panic!("expected HTTP allowed when TLS not required: {e}"));
}
// Standard policy: a plain HTTP connection yields `TlsRequired`.
#[test]
fn test_http_rejected_when_tls_required() {
let enforcer = TlsEnforcer::standard();
let conn = TlsConnection::new_http();
let result = enforcer.validate_connection(&conn);
assert!(matches!(result, Err(SecurityError::TlsRequired { .. })));
}
// Standard policy: a TLS 1.3 connection is accepted.
#[test]
fn test_https_allowed_when_tls_required() {
let enforcer = TlsEnforcer::standard();
let conn = TlsConnection::new_secure(TlsVersion::V1_3);
enforcer
.validate_connection(&conn)
.unwrap_or_else(|e| panic!("expected HTTPS allowed when TLS required: {e}"));
}
// Strict policy (minimum TLS 1.3): TLS 1.0 yields `TlsVersionTooOld`.
#[test]
fn test_tls_1_0_rejected_when_min_1_3() {
let enforcer = TlsEnforcer::strict(); let conn = TlsConnection::new_secure(TlsVersion::V1_0);
let result = enforcer.validate_connection(&conn);
assert!(matches!(result, Err(SecurityError::TlsVersionTooOld { .. })));
}
// Strict policy (minimum TLS 1.3): TLS 1.2 yields `TlsVersionTooOld`.
#[test]
fn test_tls_1_2_rejected_when_min_1_3() {
let enforcer = TlsEnforcer::strict(); let conn = TlsConnection::new_secure(TlsVersion::V1_2);
let result = enforcer.validate_connection(&conn);
assert!(matches!(result, Err(SecurityError::TlsVersionTooOld { .. })));
}
// Standard policy (minimum TLS 1.2): a newer version is accepted.
#[test]
fn test_tls_1_3_allowed_when_min_1_2() {
let enforcer = TlsEnforcer::standard(); let conn = TlsConnection::new_secure(TlsVersion::V1_3);
enforcer
.validate_connection(&conn)
.unwrap_or_else(|e| panic!("expected TLS 1.3 allowed when min 1.2: {e}"));
}
// Standard policy (minimum TLS 1.2): exactly the minimum is accepted.
#[test]
fn test_tls_1_2_allowed_when_min_1_2() {
let enforcer = TlsEnforcer::standard(); let conn = TlsConnection::new_secure(TlsVersion::V1_2);
enforcer
.validate_connection(&conn)
.unwrap_or_else(|e| panic!("expected TLS 1.2 allowed when min 1.2: {e}"));
}
// Permissive policy with an HTTP connection: validation succeeds, i.e. the
// version field of a non-secure connection does not cause a rejection.
#[test]
fn test_tls_version_check_skipped_for_http() {
let enforcer = TlsEnforcer::permissive();
let conn = TlsConnection::new_http();
enforcer
.validate_connection(&conn)
.unwrap_or_else(|e| panic!("expected version check skipped for HTTP: {e}"));
}
// Standard policy (mTLS off): a secure connection without a client
// certificate is accepted.
#[test]
fn test_client_cert_optional_when_mtls_not_required() {
let enforcer = TlsEnforcer::standard(); let conn = TlsConnection::new_secure(TlsVersion::V1_3);
enforcer.validate_connection(&conn).unwrap_or_else(|e| {
panic!("expected no client cert needed when mTLS not required: {e}")
});
}
// Strict policy (mTLS on): a secure connection without a client certificate
// yields `MtlsRequired`.
#[test]
fn test_client_cert_required_when_mtls_required() {
let enforcer = TlsEnforcer::strict(); let conn = TlsConnection::new_secure(TlsVersion::V1_3);
let result = enforcer.validate_connection(&conn);
assert!(matches!(result, Err(SecurityError::MtlsRequired { .. })));
}
// Strict policy (mTLS on): a secure connection with a valid client
// certificate is accepted.
#[test]
fn test_client_cert_allowed_when_mtls_required() {
let enforcer = TlsEnforcer::strict(); let conn = TlsConnection::new_secure_with_client_cert(TlsVersion::V1_3);
enforcer.validate_connection(&conn).unwrap_or_else(|e| {
panic!("expected valid client cert accepted when mTLS required: {e}")
});
}
// A certificate that is present but not valid yields `InvalidClientCert`.
// The connection is built field-by-field because no constructor produces
// this combination.
#[test]
fn test_invalid_cert_rejected() {
let enforcer = TlsEnforcer::strict();
let conn = TlsConnection {
is_secure: true,
version: TlsVersion::V1_3,
has_client_cert: true,
client_cert_valid: false, };
let result = enforcer.validate_connection(&conn);
assert!(matches!(result, Err(SecurityError::InvalidClientCert { .. })));
}
// Strict policy: TLS 1.3 with a valid client certificate is accepted.
#[test]
fn test_valid_cert_accepted() {
let enforcer = TlsEnforcer::strict();
let conn = TlsConnection::new_secure_with_client_cert(TlsVersion::V1_3);
enforcer
.validate_connection(&conn)
.unwrap_or_else(|e| panic!("expected valid cert accepted: {e}"));
}
// Strict policy end-to-end: one fully valid connection passes, then each
// requirement (TLS, minimum version, client certificate) is violated in turn
// and the matching error variant is expected.
#[test]
fn test_all_3_tls_settings_enforced_together() {
let enforcer = TlsEnforcer::strict();
let valid_conn = TlsConnection::new_secure_with_client_cert(TlsVersion::V1_3);
enforcer
.validate_connection(&valid_conn)
.unwrap_or_else(|e| panic!("expected all checks to pass: {e}"));
let http_conn = TlsConnection::new_http();
assert!(matches!(
enforcer.validate_connection(&http_conn),
Err(SecurityError::TlsRequired { .. })
));
let old_tls_conn = TlsConnection::new_secure(TlsVersion::V1_2);
assert!(matches!(
enforcer.validate_connection(&old_tls_conn),
Err(SecurityError::TlsVersionTooOld { .. })
));
let no_cert_conn = TlsConnection::new_secure(TlsVersion::V1_3);
assert!(matches!(
enforcer.validate_connection(&no_cert_conn),
Err(SecurityError::MtlsRequired { .. })
));
}
// Checks the payload carried by errors: `TlsRequired` has a non-empty detail
// mentioning HTTP/HTTPS, and `TlsVersionTooOld` reports both the offered and
// the required version.
#[test]
fn test_error_messages_clear_and_loggable() {
let enforcer = TlsEnforcer::strict();
let tls_required_err = enforcer.validate_connection(&TlsConnection::new_http());
if let Err(SecurityError::TlsRequired { detail }) = tls_required_err {
assert!(!detail.is_empty());
assert!(detail.contains("HTTP") || detail.contains("HTTPS"));
} else {
panic!("Expected TlsRequired error");
}
let tls_version_err =
enforcer.validate_connection(&TlsConnection::new_secure(TlsVersion::V1_0));
if let Err(SecurityError::TlsVersionTooOld { current, required }) = tls_version_err {
assert_eq!(current, TlsVersion::V1_0);
assert_eq!(required, TlsVersion::V1_3);
} else {
panic!("Expected TlsVersionTooOld error");
}
}
// Pins the field values of the permissive preset.
#[test]
fn test_permissive_config() {
let config = TlsConfig::permissive();
assert!(!config.tls_required);
assert!(!config.mtls_required);
assert_eq!(config.min_version, TlsVersion::V1_2);
}
// Pins the field values of the standard preset.
#[test]
fn test_standard_config() {
let config = TlsConfig::standard();
assert!(config.tls_required);
assert!(!config.mtls_required);
assert_eq!(config.min_version, TlsVersion::V1_2);
}
// Pins the field values of the strict preset.
#[test]
fn test_strict_config() {
let config = TlsConfig::strict();
assert!(config.tls_required);
assert!(config.mtls_required);
assert_eq!(config.min_version, TlsVersion::V1_3);
}
// Each enforcer shortcut exposes the corresponding preset through `config()`.
#[test]
fn test_enforcer_helpers() {
let permissive = TlsEnforcer::permissive();
assert!(!permissive.config().tls_required);
let standard = TlsEnforcer::standard();
assert!(standard.config().tls_required);
let strict = TlsEnforcer::strict();
assert!(strict.config().mtls_required);
}
// Pins the `Display` text of every version.
#[test]
fn test_tls_version_display() {
assert_eq!(TlsVersion::V1_0.to_string(), "TLS 1.0");
assert_eq!(TlsVersion::V1_1.to_string(), "TLS 1.1");
assert_eq!(TlsVersion::V1_2.to_string(), "TLS 1.2");
assert_eq!(TlsVersion::V1_3.to_string(), "TLS 1.3");
}
// Older versions must compare less than newer ones; the minimum-version
// check depends on this ordering.
#[test]
fn test_tls_version_ordering() {
assert!(TlsVersion::V1_0 < TlsVersion::V1_1);
assert!(TlsVersion::V1_1 < TlsVersion::V1_2);
assert!(TlsVersion::V1_2 < TlsVersion::V1_3);
assert!(TlsVersion::V1_3 > TlsVersion::V1_2);
}
// Pins the flags set by each `TlsConnection` constructor.
#[test]
fn test_tls_connection_helpers() {
let http_conn = TlsConnection::new_http();
assert!(!http_conn.is_secure);
let secure_conn = TlsConnection::new_secure(TlsVersion::V1_3);
assert!(secure_conn.is_secure);
assert!(!secure_conn.has_client_cert);
let mtls_conn = TlsConnection::new_secure_with_client_cert(TlsVersion::V1_3);
assert!(mtls_conn.is_secure);
assert!(mtls_conn.has_client_cert);
assert!(mtls_conn.client_cert_valid);
}
// A hand-built config (TLS on, mTLS off, minimum TLS 1.2) behaves like the
// standard preset: HTTP is rejected, TLS 1.2 and TLS 1.3 without a client
// certificate are accepted.
#[test]
fn test_custom_config_from_individual_settings() {
let config = TlsConfig {
tls_required: true,
mtls_required: false,
min_version: TlsVersion::V1_2,
};
let enforcer = TlsEnforcer::from_config(config);
let http_conn = TlsConnection::new_http();
assert!(matches!(
enforcer.validate_connection(&http_conn),
Err(SecurityError::TlsRequired { .. })
));
let secure_conn = TlsConnection::new_secure(TlsVersion::V1_2);
enforcer
.validate_connection(&secure_conn)
.unwrap_or_else(|e| panic!("expected HTTPS with TLS 1.2 to pass: {e}"));
let no_cert_conn = TlsConnection::new_secure(TlsVersion::V1_3);
enforcer
.validate_connection(&no_cert_conn)
.unwrap_or_else(|e| panic!("expected HTTPS without client cert to pass: {e}"));
}
// Certificate flags on a non-secure connection do not bypass the TLS
// requirement: the result is still `TlsRequired`.
#[test]
fn test_http_with_certificate_info_still_fails_when_tls_required() {
let enforcer = TlsEnforcer::standard();
let http_with_cert_info = TlsConnection {
is_secure: false, version: TlsVersion::V1_2,
has_client_cert: true,
client_cert_valid: true,
};
assert!(matches!(
enforcer.validate_connection(&http_with_cert_info),
Err(SecurityError::TlsRequired { .. })
));
}
}