use std::any::Any;
use std::sync::Arc;
use crate::error::ExtensionError;
/// A TLS protocol version.
///
/// Variants are declared oldest-first, so the derived `Ord` orders them
/// chronologically (`Tls10 < Tls11 < Tls12 < Tls13`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum TlsVersion {
    /// TLS 1.0 (deprecated by RFC 8996).
    Tls10,
    /// TLS 1.1 (deprecated by RFC 8996).
    Tls11,
    /// TLS 1.2.
    Tls12,
    /// TLS 1.3.
    Tls13,
}

impl TlsVersion {
    /// Returns `true` for the versions RFC 8996 deprecates: TLS 1.0 and TLS 1.1.
    #[must_use]
    #[inline]
    pub const fn is_deprecated(self) -> bool {
        matches!(self, Self::Tls10 | Self::Tls11)
    }

    /// Human-readable protocol name, e.g. `"TLS 1.2"`.
    const fn label(self) -> &'static str {
        match self {
            Self::Tls10 => "TLS 1.0",
            Self::Tls11 => "TLS 1.1",
            Self::Tls12 => "TLS 1.2",
            Self::Tls13 => "TLS 1.3",
        }
    }
}

impl std::fmt::Display for TlsVersion {
    /// Writes the human-readable name (`"TLS 1.0"` … `"TLS 1.3"`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` (unlike `write_str`) honours width/alignment/precision flags,
        // so `{:>10}` and friends behave exactly as they would for a `&str`.
        f.pad(self.label())
    }
}
/// Source of a TLS client configuration plus the security properties that
/// [`audit_tls_provider`] inspects.
///
/// Implementors must be `Send + Sync` so a provider can be shared behind a
/// `&dyn TlsConfigProvider` (or an `Arc` of one) across threads.
pub trait TlsConfigProvider: Send + Sync {
    // --- identity -------------------------------------------------------

    /// Name of this provider; audit warnings quote it in their messages.
    fn provider_name(&self) -> &str;

    /// Name of the concrete type behind the `Any` returned by
    /// [`client_config`](Self::client_config) (presumably so callers know
    /// what to downcast to).
    fn config_type_name(&self) -> &str;

    // --- security properties -------------------------------------------

    /// Lowest TLS version the configuration allows.
    fn min_tls_version(&self) -> TlsVersion;

    /// Whether the provider supports mutual TLS.
    fn supports_mtls(&self) -> bool;

    /// `true` when certificate verification is disabled.
    fn accepts_invalid_certs(&self) -> bool;

    // --- configuration --------------------------------------------------

    /// Builds the type-erased client configuration.
    ///
    /// # Errors
    ///
    /// Returns an [`ExtensionError`] when the configuration cannot be built.
    fn client_config(&self) -> Result<Arc<dyn Any + Send + Sync>, ExtensionError>;
}
#[must_use]
pub fn audit_tls_provider(
provider: &dyn TlsConfigProvider,
) -> Vec<crate::warning::ExtensionWarning> {
let mut warnings = Vec::new();
if provider.accepts_invalid_certs() {
warnings.push(crate::warning::ExtensionWarning {
code: "TLS_NO_VERIFY",
severity: crate::warning::WarningSeverity::High,
message: format!(
"TLS provider {:?} has certificate verification disabled",
provider.provider_name()
),
cwe: Some(295),
});
}
let min_version = provider.min_tls_version();
if min_version.is_deprecated() {
warnings.push(crate::warning::ExtensionWarning {
code: "TLS_DEPRECATED_VERSION",
severity: crate::warning::WarningSeverity::Medium,
message: format!(
"TLS provider {:?} allows deprecated {} (RFC 8996)",
provider.provider_name(),
min_version,
),
cwe: Some(327), });
}
warnings
}
#[cfg(test)]
mod tests {
    // The stub providers return string literals through the trait's `&str`
    // signature; clippy's `unnecessary_literal_bound` would ask each of them
    // to spell out `&'static str`, so the lint is silenced once for the module.
    #![allow(clippy::unnecessary_literal_bound)]

    use super::*;

    /// Verifies certificates, requires TLS 1.2, and hands out the stored `String`.
    struct SecureProvider {
        config: Arc<String>,
    }

    impl TlsConfigProvider for SecureProvider {
        fn client_config(&self) -> Result<Arc<dyn Any + Send + Sync>, ExtensionError> {
            Ok(self.config.clone())
        }
        fn provider_name(&self) -> &str {
            "test-secure"
        }
        fn config_type_name(&self) -> &str {
            "String"
        }
        fn min_tls_version(&self) -> TlsVersion {
            TlsVersion::Tls12
        }
        fn supports_mtls(&self) -> bool {
            false
        }
        fn accepts_invalid_certs(&self) -> bool {
            false
        }
    }

    /// Trips both audit checks: skips certificate verification and allows TLS 1.0.
    struct InsecureProvider;

    impl TlsConfigProvider for InsecureProvider {
        fn client_config(&self) -> Result<Arc<dyn Any + Send + Sync>, ExtensionError> {
            Ok(Arc::new(42u32))
        }
        fn provider_name(&self) -> &str {
            "test-insecure"
        }
        fn config_type_name(&self) -> &str {
            "u32"
        }
        fn min_tls_version(&self) -> TlsVersion {
            TlsVersion::Tls10
        }
        fn supports_mtls(&self) -> bool {
            false
        }
        fn accepts_invalid_certs(&self) -> bool {
            true
        }
    }

    /// Fails to build its config, mimicking a missing certificate file.
    struct FailingProvider;

    impl TlsConfigProvider for FailingProvider {
        fn client_config(&self) -> Result<Arc<dyn Any + Send + Sync>, ExtensionError> {
            Err(ExtensionError::new("certificate file not found"))
        }
        fn provider_name(&self) -> &str {
            "test-failing"
        }
        fn config_type_name(&self) -> &str {
            "never"
        }
        fn min_tls_version(&self) -> TlsVersion {
            TlsVersion::Tls13
        }
        fn supports_mtls(&self) -> bool {
            true
        }
        fn accepts_invalid_certs(&self) -> bool {
            false
        }
    }

    /// Configurable provider for audit tests; only the audited properties vary.
    struct StubProvider {
        name: &'static str,
        min_version: TlsVersion,
        mtls: bool,
        insecure: bool,
    }

    impl TlsConfigProvider for StubProvider {
        fn client_config(&self) -> Result<Arc<dyn Any + Send + Sync>, ExtensionError> {
            Ok(Arc::new(()))
        }
        fn provider_name(&self) -> &str {
            self.name
        }
        fn config_type_name(&self) -> &str {
            "()"
        }
        fn min_tls_version(&self) -> TlsVersion {
            self.min_version
        }
        fn supports_mtls(&self) -> bool {
            self.mtls
        }
        fn accepts_invalid_certs(&self) -> bool {
            self.insecure
        }
    }

    #[test]
    fn secure_provider_returns_config() {
        let provider = SecureProvider {
            config: Arc::new("test-config".to_string()),
        };
        let config = provider.client_config().expect("secure provider should build");
        let s = config
            .downcast_ref::<String>()
            .expect("config should be a String");
        assert_eq!(s, "test-config");
    }

    #[test]
    fn secure_provider_metadata() {
        let provider = SecureProvider {
            config: Arc::new(String::new()),
        };
        assert_eq!(provider.provider_name(), "test-secure");
        assert_eq!(provider.config_type_name(), "String");
        assert_eq!(provider.min_tls_version(), TlsVersion::Tls12);
        assert!(!provider.supports_mtls());
        assert!(!provider.accepts_invalid_certs());
    }

    #[test]
    fn failing_provider_returns_error() {
        let provider = FailingProvider;
        let err = provider.client_config().unwrap_err();
        assert_eq!(err.as_str(), "certificate file not found");
        assert!(provider.supports_mtls());
    }

    #[test]
    fn downcast_wrong_type_returns_none() {
        let provider = SecureProvider {
            config: Arc::new("hello".to_string()),
        };
        let config = provider.client_config().expect("secure provider should build");
        assert!(config.downcast_ref::<u32>().is_none());
    }

    #[test]
    fn audit_secure_provider_no_warnings() {
        let provider = SecureProvider {
            config: Arc::new(String::new()),
        };
        let warnings = audit_tls_provider(&provider);
        assert!(warnings.is_empty());
    }

    #[test]
    fn audit_insecure_provider_flags_issues() {
        let warnings = audit_tls_provider(&InsecureProvider);
        assert_eq!(warnings.len(), 2);

        let cert_warning = warnings
            .iter()
            .find(|w| w.code == "TLS_NO_VERIFY")
            .expect("missing TLS_NO_VERIFY warning");
        assert_eq!(cert_warning.cwe, Some(295));
        assert_eq!(cert_warning.severity, crate::warning::WarningSeverity::High);
        // The provider name is Debug-formatted, i.e. quoted.
        assert!(cert_warning.message.contains("\"test-insecure\""));

        let version_warning = warnings
            .iter()
            .find(|w| w.code == "TLS_DEPRECATED_VERSION")
            .expect("missing TLS_DEPRECATED_VERSION warning");
        assert_eq!(version_warning.cwe, Some(327));
        assert_eq!(
            version_warning.severity,
            crate::warning::WarningSeverity::Medium
        );
        assert!(version_warning.message.contains("TLS 1.0"));
    }

    #[test]
    fn tls_version_ordering() {
        assert!(TlsVersion::Tls10 < TlsVersion::Tls11);
        assert!(TlsVersion::Tls11 < TlsVersion::Tls12);
        assert!(TlsVersion::Tls12 < TlsVersion::Tls13);
    }

    #[test]
    fn tls_version_deprecated() {
        assert!(TlsVersion::Tls10.is_deprecated());
        assert!(TlsVersion::Tls11.is_deprecated());
        assert!(!TlsVersion::Tls12.is_deprecated());
        assert!(!TlsVersion::Tls13.is_deprecated());
    }

    #[test]
    fn tls_version_display() {
        assert_eq!(TlsVersion::Tls10.to_string(), "TLS 1.0");
        assert_eq!(TlsVersion::Tls11.to_string(), "TLS 1.1");
        assert_eq!(TlsVersion::Tls12.to_string(), "TLS 1.2");
        assert_eq!(TlsVersion::Tls13.to_string(), "TLS 1.3");
    }

    #[test]
    fn provider_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<SecureProvider>();
        assert_send_sync::<InsecureProvider>();
        assert_send_sync::<FailingProvider>();
        assert_send_sync::<StubProvider>();
    }

    #[test]
    fn trait_object_is_send_sync() {
        fn assert_send_sync<T: Send + Sync + ?Sized>() {}
        assert_send_sync::<dyn TlsConfigProvider>();
    }

    #[test]
    fn tls11_is_deprecated() {
        let provider = StubProvider {
            name: "tls11-test",
            min_version: TlsVersion::Tls11,
            mtls: false,
            insecure: false,
        };
        let warnings = audit_tls_provider(&provider);
        assert_eq!(warnings.len(), 1);
        assert_eq!(warnings[0].code, "TLS_DEPRECATED_VERSION");
        assert!(warnings[0].message.contains("TLS 1.1"));
    }

    #[test]
    fn tls13_no_warnings() {
        let provider = StubProvider {
            name: "tls13-test",
            min_version: TlsVersion::Tls13,
            mtls: true,
            insecure: false,
        };
        let warnings = audit_tls_provider(&provider);
        assert!(warnings.is_empty());
    }

    #[test]
    fn audit_only_invalid_certs_not_deprecated_version() {
        let provider = StubProvider {
            name: "cert-insecure",
            min_version: TlsVersion::Tls13,
            mtls: false,
            insecure: true,
        };
        let warnings = audit_tls_provider(&provider);
        assert_eq!(warnings.len(), 1);
        assert_eq!(warnings[0].code, "TLS_NO_VERIFY");
        assert_eq!(warnings[0].cwe, Some(295));
    }

    #[test]
    fn tls_version_is_not_deprecated() {
        // Deprecation must agree with the version ordering: everything below
        // TLS 1.2 is deprecated, everything from TLS 1.2 upward is not.
        let all = [
            TlsVersion::Tls10,
            TlsVersion::Tls11,
            TlsVersion::Tls12,
            TlsVersion::Tls13,
        ];
        for version in all {
            assert_eq!(
                version.is_deprecated(),
                version < TlsVersion::Tls12,
                "{version}"
            );
        }
    }

    #[test]
    fn failing_provider_metadata() {
        let provider = FailingProvider;
        assert_eq!(provider.provider_name(), "test-failing");
        assert_eq!(provider.config_type_name(), "never");
        assert_eq!(provider.min_tls_version(), TlsVersion::Tls13);
        assert!(!provider.accepts_invalid_certs());
    }
}