#![deny(unsafe_code)]
#![deny(missing_docs)]
use crate::tls::{TlsConfig, TlsMode};
/// Security policy that a [`TlsConfig`] must satisfy.
///
/// Build one with [`TlsInvariants::new`] (baseline) or [`TlsInvariants::strict`],
/// adjust it with the `with_*` builders, then evaluate a configuration with
/// [`TlsInvariants::check`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TlsInvariants {
    /// Require a post-quantum key exchange, i.e. any mode other than
    /// `TlsMode::Classic`. This flag also enables the key-logging check.
    require_pq: bool,
    /// Lowest acceptable TLS protocol version, as its on-the-wire `u16` value
    /// (`0x0303` = TLS 1.2, `0x0304` = TLS 1.3).
    min_tls_version: u16,
    /// Require the configured minimum protocol version to be at least TLS 1.2.
    require_forward_secrecy: bool,
}
/// Outcome of evaluating a TLS configuration with [`TlsInvariants::check`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InvariantResult {
    /// Set by `check` to `true` exactly when `violations` is empty.
    pub passed: bool,
    /// Human-readable description of each violated invariant, prefixed with its
    /// identifier (`INV-1` … `INV-4`) and listed in the order the checks run.
    pub violations: Vec<String>,
}
impl TlsInvariants {
    /// TLS 1.2 as its on-the-wire protocol version value.
    const TLS_1_2: u16 = 0x0303;
    /// TLS 1.3 as its on-the-wire protocol version value.
    const TLS_1_3: u16 = 0x0304;

    /// Creates the baseline policy.
    ///
    /// The baseline requires TLS 1.3 as the minimum version and requires forward
    /// secrecy. Post-quantum key exchange is optional.
    #[must_use]
    pub fn new() -> Self {
        Self { require_pq: false, min_tls_version: Self::TLS_1_3, require_forward_secrecy: true }
    }

    /// Creates the strict policy.
    ///
    /// This is the same as [`TlsInvariants::new`], but it also requires a
    /// post-quantum (non-`Classic`) key exchange. Because the key-logging check
    /// (INV-4) is gated on that requirement, it is enabled too.
    #[must_use]
    pub fn strict() -> Self {
        // Derive from `new()` so the two policies cannot drift apart.
        Self { require_pq: true, ..Self::new() }
    }

    /// Sets whether post-quantum key exchange is required (INV-1).
    ///
    /// The same flag also controls the key-logging check (INV-4).
    #[must_use]
    pub fn with_pq_required(mut self, required: bool) -> Self {
        self.require_pq = required;
        self
    }

    /// Sets the minimum acceptable TLS protocol version (INV-2).
    ///
    /// The version is given as its on-the-wire `u16` value, e.g. `0x0303` for
    /// TLS 1.2 or `0x0304` for TLS 1.3.
    #[must_use]
    pub fn with_min_tls_version(mut self, version: u16) -> Self {
        self.min_tls_version = version;
        self
    }

    /// Sets whether forward secrecy is required (INV-3).
    ///
    /// When enabled, a configured minimum protocol version below TLS 1.2 is
    /// rejected. Both [`TlsInvariants::new`] and [`TlsInvariants::strict`]
    /// enable it.
    #[must_use]
    pub fn with_forward_secrecy_required(mut self, required: bool) -> Self {
        self.require_forward_secrecy = required;
        self
    }

    /// Checks `config` against every enabled invariant and collects all violations.
    ///
    /// - **INV-1**: if PQ is required, `config.mode` must not be `TlsMode::Classic`.
    /// - **INV-2**: `config.min_protocol_version`, when set, must be at least the
    ///   policy's minimum TLS version.
    /// - **INV-3**: if forward secrecy is required, `config.min_protocol_version`,
    ///   when set, must be at least TLS 1.2.
    /// - **INV-4**: if PQ is required, key logging must be disabled.
    ///
    /// INV-2 and INV-3 are skipped when `config.min_protocol_version` is `None`.
    /// NOTE(review): presumably `None` means "use the TLS library default"; confirm
    /// that default satisfies the policy.
    ///
    /// In the returned [`InvariantResult`], `passed` is `true` exactly when no
    /// violation was found.
    #[must_use]
    pub fn check(&self, config: &TlsConfig) -> InvariantResult {
        let mut violations = Vec::new();

        if self.require_pq && config.mode == TlsMode::Classic {
            violations
                .push("INV-1: PQ key exchange required but Classic mode selected".to_string());
        }

        // Convert once: both version invariants compare on-the-wire values.
        let min_version: Option<u16> = config.min_protocol_version.map(Into::into);
        if let Some(ver_num) = min_version {
            if ver_num < self.min_tls_version {
                violations.push(format!(
                    "INV-2: Minimum TLS version 0x{:04x} is below required 0x{:04x}",
                    ver_num, self.min_tls_version
                ));
            }
            if self.require_forward_secrecy && ver_num < Self::TLS_1_2 {
                violations.push("INV-3: Forward secrecy requires TLS 1.2+ minimum".to_string());
            }
        }

        // Key logging is only treated as a violation under the PQ (strict) policy.
        if self.require_pq && config.enable_key_logging {
            violations
                .push("INV-4: Key logging enabled with strict security requirements".to_string());
        }

        InvariantResult { passed: violations.is_empty(), violations }
    }
}
impl Default for TlsInvariants {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns whether any violation in `result` carries the invariant id `id`.
    fn violated(result: &InvariantResult, id: &str) -> bool {
        result.violations.iter().any(|v| v.contains(id))
    }

    #[test]
    fn test_default_invariants_pass_default_config_succeeds() {
        let inv = TlsInvariants::new();
        let config = TlsConfig::new();
        let result = inv.check(&config);
        assert!(result.passed);
        assert!(result.violations.is_empty());
    }

    #[test]
    fn test_strict_invariants_fail_classic_mode_fails() {
        let inv = TlsInvariants::strict();
        let config = TlsConfig { mode: TlsMode::Classic, ..TlsConfig::default() };
        let result = inv.check(&config);
        assert!(!result.passed);
        assert!(violated(&result, "INV-1"));
    }

    #[test]
    fn test_strict_invariants_pass_hybrid_mode_succeeds() {
        let inv = TlsInvariants::strict();
        let config = TlsConfig::new();
        let result = inv.check(&config);
        assert!(result.passed);
    }

    #[test]
    fn test_strict_invariants_fail_key_logging_fails() {
        let inv = TlsInvariants::strict();
        let config = TlsConfig::new().with_key_logging();
        let result = inv.check(&config);
        assert!(!result.passed);
        assert!(violated(&result, "INV-4"));
    }

    #[test]
    fn test_key_logging_allowed_without_pq_requirement_succeeds() {
        // INV-4 is gated on the PQ requirement, which the baseline policy leaves off.
        let inv = TlsInvariants::new();
        let config = TlsConfig::new().with_key_logging();
        let result = inv.check(&config);
        assert!(!violated(&result, "INV-4"));
    }

    #[test]
    fn test_custom_min_version_rejects_tls12_fails() {
        let inv = TlsInvariants::new().with_min_tls_version(0x0304);
        let mut config = TlsConfig::new();
        config.min_protocol_version = Some(rustls::ProtocolVersion::TLSv1_2);
        let result = inv.check(&config);
        assert!(!result.passed);
        assert!(result
            .violations
            .iter()
            .any(|v| v == "INV-2: Minimum TLS version 0x0303 is below required 0x0304"));
    }

    #[test]
    fn test_forward_secrecy_rejects_pre_tls12_fails() {
        // Lower the version floor so only the forward-secrecy check (INV-3) can fire.
        let inv = TlsInvariants::new().with_min_tls_version(0x0301);
        let mut config = TlsConfig::new();
        config.min_protocol_version = Some(rustls::ProtocolVersion::TLSv1_1);
        let result = inv.check(&config);
        assert!(!result.passed);
        assert!(violated(&result, "INV-3"));
        assert!(!violated(&result, "INV-2"));
    }

    #[test]
    fn test_tls12_meets_forward_secrecy_floor_succeeds() {
        let inv = TlsInvariants::new().with_min_tls_version(0x0303);
        let mut config = TlsConfig::new();
        config.min_protocol_version = Some(rustls::ProtocolVersion::TLSv1_2);
        let result = inv.check(&config);
        assert!(!violated(&result, "INV-2"));
        assert!(!violated(&result, "INV-3"));
    }

    #[test]
    fn test_pq_required_builder_succeeds() {
        let inv = TlsInvariants::new().with_pq_required(true);
        let config = TlsConfig { mode: TlsMode::Classic, ..TlsConfig::default() };
        let result = inv.check(&config);
        assert!(!result.passed);
        assert!(violated(&result, "INV-1"));
    }

    #[test]
    fn test_invariant_result_debug_succeeds() {
        let result = InvariantResult { passed: true, violations: vec![] };
        let debug = format!("{:?}", result);
        assert!(debug.contains("passed: true"));
    }
}