use anyhow::{anyhow, Result};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, info, warn};
/// Top-level TLS settings consumed by [`TlsManager`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TlsConfig {
    /// Master switch; when `false`, `TlsManager::initialize` returns early.
    pub enabled: bool,
    /// Protocol version that `TlsManager::handshake` records for sessions.
    pub protocol_version: TlsVersion,
    /// Certificate/key locations and peer-verification flags.
    pub certificates: CertificateConfig,
    /// Cipher suites; the first entry is the one `TlsManager::handshake` reports.
    pub cipher_suites: Vec<CipherSuite>,
    /// Mutual (client-certificate) TLS settings.
    pub mtls: MutualTlsConfig,
    /// Certificate rotation monitor settings.
    pub rotation: CertRotationConfig,
    /// OCSP stapling settings.
    pub ocsp_stapling: OcspConfig,
    /// Whether perfect forward secrecy is requested.
    /// NOTE(review): not read by `TlsManager` in this file.
    pub perfect_forward_secrecy: bool,
    /// Session resumption settings.
    pub session_resumption: SessionResumptionConfig,
    /// ALPN protocol identifiers; `TlsManager::handshake` reports the first one.
    pub alpn_protocols: Vec<String>,
}

impl Default for TlsConfig {
    /// TLS 1.3 with the AES-256-GCM, ChaCha20-Poly1305 and AES-128-GCM suites
    /// (in that order), PFS on, ALPN `h2` then `http/1.1`; every nested
    /// section uses its own `Default`.
    fn default() -> Self {
        let cipher_suites = Vec::from([
            CipherSuite::TLS_AES_256_GCM_SHA384,
            CipherSuite::TLS_CHACHA20_POLY1305_SHA256,
            CipherSuite::TLS_AES_128_GCM_SHA256,
        ]);
        let alpn_protocols = ["h2", "http/1.1"]
            .iter()
            .map(|&proto| String::from(proto))
            .collect();

        Self {
            enabled: true,
            protocol_version: TlsVersion::Tls13,
            certificates: Default::default(),
            cipher_suites,
            mtls: Default::default(),
            rotation: Default::default(),
            ocsp_stapling: Default::default(),
            perfect_forward_secrecy: true,
            session_resumption: Default::default(),
            alpn_protocols,
        }
    }
}
/// TLS protocol versions that can be configured.
///
/// Variant order matters for serde formats that encode enums by index;
/// append new variants at the end.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum TlsVersion {
    /// TLS 1.2.
    Tls12,
    /// TLS 1.3 (the `TlsConfig` default).
    Tls13,
}

impl std::fmt::Display for TlsVersion {
    /// Writes the human-readable name (`TLS 1.2` / `TLS 1.3`). The same text is
    /// used as the key in `TlsMetrics::tls_version_distribution`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Tls12 => "TLS 1.2",
            Self::Tls13 => "TLS 1.3",
        };
        f.write_str(label)
    }
}
/// Locations of the server's certificate material and peer-verification flags.
///
/// `Debug` is implemented by hand so that `key_password` is never printed
/// (e.g. when a whole `TlsConfig` is logged with `{:?}`).
#[derive(Clone, Serialize, Deserialize)]
pub struct CertificateConfig {
    /// Server certificate file; must exist for `TlsManager::initialize` to succeed.
    pub server_cert_path: PathBuf,
    /// Server private key file; must exist for `TlsManager::initialize` to succeed.
    pub server_key_path: PathBuf,
    /// Optional CA certificate; a missing file is only logged as a warning.
    pub ca_cert_path: Option<PathBuf>,
    /// Optional certificate chain file.
    pub cert_chain_path: Option<PathBuf>,
    /// Passphrase for the private key. Redacted in `Debug` output.
    /// NOTE(review): still emitted by `Serialize`; avoid serializing configs into logs.
    pub key_password: Option<String>,
    /// Encoding of the certificate/key files.
    pub format: CertificateFormat,
    /// Whether the peer's certificate should be verified.
    pub verify_peer: bool,
    /// Whether the peer's hostname should be verified.
    pub verify_hostname: bool,
}

impl std::fmt::Debug for CertificateConfig {
    /// Same shape as the derived `Debug`, except that a present `key_password`
    /// is shown as `Some("<redacted>")`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CertificateConfig")
            .field("server_cert_path", &self.server_cert_path)
            .field("server_key_path", &self.server_key_path)
            .field("ca_cert_path", &self.ca_cert_path)
            .field("cert_chain_path", &self.cert_chain_path)
            .field(
                "key_password",
                &self.key_password.as_ref().map(|_| "<redacted>"),
            )
            .field("format", &self.format)
            .field("verify_peer", &self.verify_peer)
            .field("verify_hostname", &self.verify_hostname)
            .finish()
    }
}

impl Default for CertificateConfig {
    /// PEM files under `/etc/oxirs/certs/`, no chain file, no key password,
    /// peer and hostname verification on.
    fn default() -> Self {
        Self {
            server_cert_path: PathBuf::from("/etc/oxirs/certs/server.crt"),
            server_key_path: PathBuf::from("/etc/oxirs/certs/server.key"),
            ca_cert_path: Some(PathBuf::from("/etc/oxirs/certs/ca.crt")),
            cert_chain_path: None,
            key_password: None,
            format: CertificateFormat::PEM,
            verify_peer: true,
            verify_hostname: true,
        }
    }
}
/// Encoding of the certificate/key files referenced by `CertificateConfig`.
///
/// NOTE(review): `TlsManager` does not consult this value yet.
/// Variant order matters for serde formats that encode enums by index;
/// append new variants at the end.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum CertificateFormat {
/// PEM encoding (the `CertificateConfig` default).
PEM,
/// DER encoding.
DER,
/// PKCS#12 bundle.
PKCS12,
}
/// Cipher suites selectable via `TlsConfig::cipher_suites`.
///
/// Variants are spelled exactly like the standard IANA suite names. Going by
/// those names, the `TLS_AES_*` / `TLS_CHACHA20_*` entries are TLS 1.3 suites
/// and the `TLS_ECDHE_*` entries are TLS 1.2 suites.
/// Variant order matters for serde formats that encode enums by index;
/// append new variants at the end.
// The SCREAMING_CASE variant names intentionally mirror the standard suite names.
#[allow(non_camel_case_types)]
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum CipherSuite {
TLS_AES_256_GCM_SHA384,
TLS_CHACHA20_POLY1305_SHA256,
TLS_AES_128_GCM_SHA256,
TLS_AES_128_CCM_SHA256,
TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
}
impl std::fmt::Display for CipherSuite {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
CipherSuite::TLS_AES_256_GCM_SHA384 => write!(f, "TLS_AES_256_GCM_SHA384"),
CipherSuite::TLS_CHACHA20_POLY1305_SHA256 => write!(f, "TLS_CHACHA20_POLY1305_SHA256"),
CipherSuite::TLS_AES_128_GCM_SHA256 => write!(f, "TLS_AES_128_GCM_SHA256"),
CipherSuite::TLS_AES_128_CCM_SHA256 => write!(f, "TLS_AES_128_CCM_SHA256"),
CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 => {
write!(f, "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384")
}
CipherSuite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 => {
write!(f, "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384")
}
CipherSuite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 => {
write!(f, "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256")
}
CipherSuite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 => {
write!(f, "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256")
}
}
}
}
/// Mutual TLS (client certificate) settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MutualTlsConfig {
    /// Whether mTLS is turned on.
    pub enabled: bool,
    /// Whether clients must present a certificate.
    pub require_client_cert: bool,
    /// CA certificates trusted for client verification.
    pub trusted_ca_certs: Vec<PathBuf>,
    /// Certificate chain verification depth (presumably a maximum — confirm).
    pub verification_depth: u8,
    /// Revocation checks for client certificates.
    pub revocation_check: RevocationCheckConfig,
}

impl Default for MutualTlsConfig {
    /// Disabled, client certificate required once enabled, no trusted CAs,
    /// verification depth 3, default revocation checks.
    fn default() -> Self {
        let revocation_check = RevocationCheckConfig::default();
        Self {
            enabled: false,
            require_client_cert: true,
            trusted_ca_certs: Vec::new(),
            verification_depth: 3,
            revocation_check,
        }
    }
}
/// Certificate revocation checking options (used by `MutualTlsConfig`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RevocationCheckConfig {
    /// Whether revocation checking is enabled.
    pub enabled: bool,
    /// Whether certificate revocation lists are consulted.
    pub check_crl: bool,
    /// Whether OCSP is consulted.
    pub check_ocsp: bool,
    /// CRL cache lifetime; presumably seconds — confirm against the consumer.
    pub crl_cache_ttl: u64,
}

impl Default for RevocationCheckConfig {
    /// All checks on; CRL cache TTL of 3600.
    fn default() -> Self {
        // 3600, i.e. one hour if the TTL is in seconds.
        const CRL_CACHE_TTL: u64 = 60 * 60;
        Self {
            check_crl: true,
            check_ocsp: true,
            crl_cache_ttl: CRL_CACHE_TTL,
            enabled: true,
        }
    }
}
/// Certificate rotation settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CertRotationConfig {
    /// Whether `TlsManager::initialize` starts the rotation monitor.
    pub enabled: bool,
    /// Interval between rotation checks, in seconds.
    pub check_interval_secs: u64,
    /// Certificates with fewer than this many days left are reported by
    /// `TlsManager::check_certificate_expiry`.
    pub rotation_threshold_days: u32,
    /// Grace period in seconds (presumably old/new certificate overlap — confirm).
    pub graceful_period_secs: u64,
}

impl Default for CertRotationConfig {
    /// Hourly checks, 30-day threshold, 5-minute grace period.
    fn default() -> Self {
        const ONE_HOUR_SECS: u64 = 60 * 60;
        const FIVE_MINUTES_SECS: u64 = 5 * 60;
        Self {
            enabled: true,
            check_interval_secs: ONE_HOUR_SECS,
            rotation_threshold_days: 30,
            graceful_period_secs: FIVE_MINUTES_SECS,
        }
    }
}
/// OCSP stapling settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OcspConfig {
    /// Whether OCSP stapling is enabled.
    pub enabled: bool,
    /// Optional explicit OCSP responder URL.
    pub responder_url: Option<String>,
    /// Cache lifetime for OCSP responses; presumably seconds — confirm.
    pub cache_ttl: u64,
    /// OCSP request timeout, in seconds.
    pub timeout_secs: u64,
}

impl Default for OcspConfig {
    /// Enabled, no explicit responder, cache TTL 3600, 10-second timeout.
    fn default() -> Self {
        Self {
            responder_url: None,
            enabled: true,
            timeout_secs: 10,
            cache_ttl: 60 * 60,
        }
    }
}
/// TLS session resumption settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionResumptionConfig {
    /// Whether session resumption is enabled.
    pub enabled: bool,
    /// Session cache capacity (presumably number of entries — confirm).
    pub cache_size: usize,
    /// Session ticket lifetime, in seconds.
    pub ticket_lifetime_secs: u64,
    /// Session-id lifetime, in seconds.
    pub session_id_lifetime_secs: u64,
}

impl Default for SessionResumptionConfig {
    /// Enabled, cache size 10 000, two-hour ticket and session-id lifetimes.
    fn default() -> Self {
        const TWO_HOURS_SECS: u64 = 2 * 60 * 60;
        Self {
            enabled: true,
            cache_size: 10_000,
            ticket_lifetime_secs: TWO_HOURS_SECS,
            session_id_lifetime_secs: TWO_HOURS_SECS,
        }
    }
}
/// Summary of a certificate held in `TlsManager`'s certificate registry.
///
/// NOTE(review): the encodings of the string fields (serial number,
/// fingerprint) and the unit of `key_size` are not established in this file —
/// confirm against whatever produces these values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CertificateInfo {
pub subject: String,
pub issuer: String,
pub serial_number: String,
/// Start of the validity window.
pub valid_from: DateTime<Utc>,
/// End of the validity window; `TlsManager::check_certificate_expiry` compares it with the current time.
pub valid_until: DateTime<Utc>,
/// Subject alternative names (presumably DNS names / IPs — confirm).
pub san: Vec<String>,
pub key_algorithm: String,
/// Key size — presumably in bits; confirm.
pub key_size: u32,
pub signature_algorithm: String,
/// SHA-256 fingerprint; textual encoding (hex, separators) not established here.
pub fingerprint_sha256: String,
}
/// Per-connection record produced by `TlsManager::handshake` and kept in the
/// manager's session registry until `TlsManager::close_session`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TlsSessionInfo {
/// Session key; `handshake` uses the caller-supplied connection id.
pub session_id: String,
/// Protocol version recorded for the session; `handshake` copies it from the config.
pub protocol_version: TlsVersion,
/// Cipher suite name as rendered by `CipherSuite`'s `Display`.
pub cipher_suite: String,
/// Server Name Indication, if known; `handshake` leaves it `None`.
pub sni: Option<String>,
/// ALPN protocol; `handshake` reports the first configured one, if any.
pub alpn_protocol: Option<String>,
/// Client certificate summary; `handshake` leaves it `None`.
pub client_cert: Option<CertificateInfo>,
/// When the session record was created.
pub established_at: DateTime<Utc>,
}
/// Holds the TLS configuration plus certificate metadata, session records and
/// metrics. Mutable state sits behind tokio `RwLock`s and is accessed through
/// `&self` async methods.
pub struct TlsManager {
/// Configuration supplied to `TlsManager::new`.
config: TlsConfig,
/// Certificate summaries keyed by certificate id.
/// NOTE(review): nothing in this file inserts into this map yet.
certificates: Arc<RwLock<HashMap<String, CertificateInfo>>>,
/// Session records keyed by connection id (filled by `handshake`, drained by `close_session`).
sessions: Arc<RwLock<HashMap<String, TlsSessionInfo>>>,
/// Aggregate counters; `get_metrics` returns a snapshot clone.
metrics: Arc<RwLock<TlsMetrics>>,
}
/// Aggregate TLS counters maintained by `TlsManager`; `TlsManager::get_metrics`
/// returns a cloned snapshot.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TlsMetrics {
/// Sessions successfully recorded by `handshake`.
pub connections_established: u64,
/// Handshake attempts; incremented at the start of every `handshake` call.
pub handshakes_total: u64,
/// Failed handshake attempts.
pub handshakes_failed: u64,
/// Number of `rotate_certificates` calls.
pub certificate_rotations: u64,
/// OCSP requests. NOTE(review): not updated by `TlsManager` in this file.
pub ocsp_requests: u64,
/// Resumed sessions. NOTE(review): not updated by `TlsManager` in this file.
pub session_resumptions: u64,
/// Handshake duration statistic in milliseconds, updated by `handshake`.
pub avg_handshake_duration_ms: f64,
/// Session counts keyed by the protocol version's `Display` text (e.g. "TLS 1.3").
pub tls_version_distribution: HashMap<String, u64>,
/// Session counts keyed by cipher suite name.
pub cipher_suite_distribution: HashMap<String, u64>,
}
impl TlsManager {
    /// Creates a manager for `config` with empty certificate and session
    /// registries and zeroed metrics. Performs no I/O; call
    /// [`TlsManager::initialize`] to validate the configured files.
    pub fn new(config: TlsConfig) -> Self {
        Self {
            config,
            certificates: Arc::new(RwLock::new(HashMap::new())),
            sessions: Arc::new(RwLock::new(HashMap::new())),
            metrics: Arc::new(RwLock::new(TlsMetrics::default())),
        }
    }

    /// Validates certificate paths, loads certificates and, when
    /// `rotation.enabled` is set, starts the rotation monitor.
    ///
    /// When TLS is disabled this only logs a warning and returns `Ok(())`.
    ///
    /// # Errors
    ///
    /// Returns an error if the server certificate or private key file is missing.
    pub async fn initialize(&self) -> Result<()> {
        info!("Initializing TLS manager");
        if !self.config.enabled {
            warn!("TLS is disabled");
            return Ok(());
        }
        self.validate_certificate_paths().await?;
        self.load_certificates().await?;
        if self.config.rotation.enabled {
            self.start_rotation_monitor().await?;
        }
        info!("TLS manager initialized successfully");
        Ok(())
    }

    /// Checks that the configured certificate files exist.
    ///
    /// A missing server certificate or key is an error; a missing CA
    /// certificate is only logged as a warning.
    // NOTE(review): `Path::exists` is a blocking filesystem call inside an async
    // fn; tolerable for a one-off startup check, not for a hot path.
    async fn validate_certificate_paths(&self) -> Result<()> {
        let cert_path = &self.config.certificates.server_cert_path;
        let key_path = &self.config.certificates.server_key_path;
        if !cert_path.exists() {
            return Err(anyhow!("Server certificate not found: {:?}", cert_path));
        }
        if !key_path.exists() {
            return Err(anyhow!("Server private key not found: {:?}", key_path));
        }
        if let Some(ca_path) = &self.config.certificates.ca_cert_path {
            if !ca_path.exists() {
                warn!("CA certificate not found: {:?}", ca_path);
            }
        }
        debug!("Certificate paths validated");
        Ok(())
    }

    /// Placeholder for certificate loading.
    ///
    /// NOTE(review): currently only logs; nothing is read from disk and the
    /// `certificates` registry is left unchanged.
    async fn load_certificates(&self) -> Result<()> {
        info!("Loading TLS certificates");
        debug!("Certificates loaded successfully");
        Ok(())
    }

    /// Logs the rotation monitor parameters.
    ///
    /// NOTE(review): no background task is spawned yet, so the check interval
    /// and threshold are only reported, not acted upon.
    async fn start_rotation_monitor(&self) -> Result<()> {
        info!("Starting certificate rotation monitor");
        let check_interval = self.config.rotation.check_interval_secs;
        let threshold_days = self.config.rotation.rotation_threshold_days;
        debug!(
            "Rotation monitor started (check_interval={}s, threshold={}d)",
            check_interval, threshold_days
        );
        Ok(())
    }

    /// Records a TLS session for `connection_id` and updates the metrics.
    ///
    /// No network I/O or cryptographic negotiation happens here. The session
    /// reports the configured protocol version, the first configured cipher
    /// suite and the first ALPN protocol (if any). It is stored under
    /// `connection_id`, replacing any previous entry with that id.
    ///
    /// # Errors
    ///
    /// Returns an error, and counts a failed handshake, when
    /// `config.cipher_suites` is empty.
    pub async fn handshake(&self, connection_id: &str) -> Result<TlsSessionInfo> {
        let start_time = std::time::Instant::now();
        self.metrics.write().await.handshakes_total += 1;

        // `TlsConfig` is deserializable, so an empty suite list is reachable;
        // indexing with `[0]` would panic the task instead of failing cleanly.
        let cipher_suite = match self.config.cipher_suites.first() {
            Some(suite) => suite.to_string(),
            None => {
                self.metrics.write().await.handshakes_failed += 1;
                warn!(
                    "TLS handshake failed for connection {}: no cipher suites configured",
                    connection_id
                );
                return Err(anyhow!("No cipher suites configured"));
            }
        };

        let session_info = TlsSessionInfo {
            session_id: connection_id.to_string(),
            protocol_version: self.config.protocol_version,
            cipher_suite,
            sni: None,
            alpn_protocol: self.config.alpn_protocols.first().cloned(),
            client_cert: None,
            established_at: Utc::now(),
        };
        self.sessions
            .write()
            .await
            .insert(connection_id.to_string(), session_info.clone());

        {
            let mut metrics = self.metrics.write().await;
            metrics.connections_established += 1;
            // Float milliseconds: `as_millis()` would truncate sub-ms handshakes to 0.
            let duration_ms = start_time.elapsed().as_secs_f64() * 1000.0;
            // Incremental arithmetic mean over successful handshakes. Only this
            // method increments `connections_established`, so it is the sample
            // count (>= 1 here). The old `(avg + d) / 2` halved the first sample.
            let samples = metrics.connections_established as f64;
            let prev_avg = metrics.avg_handshake_duration_ms;
            metrics.avg_handshake_duration_ms = prev_avg + (duration_ms - prev_avg) / samples;

            let version_key = session_info.protocol_version.to_string();
            *metrics
                .tls_version_distribution
                .entry(version_key)
                .or_insert(0) += 1;
            *metrics
                .cipher_suite_distribution
                .entry(session_info.cipher_suite.clone())
                .or_insert(0) += 1;
        }

        debug!(
            "TLS handshake completed for connection: {} in {:?}",
            connection_id,
            start_time.elapsed()
        );
        Ok(session_info)
    }

    /// Records a certificate rotation in the metrics.
    ///
    /// NOTE(review): only increments `certificate_rotations`; no certificate
    /// material is reloaded yet.
    pub async fn rotate_certificates(&self) -> Result<()> {
        info!("Starting certificate rotation");
        {
            let mut metrics = self.metrics.write().await;
            metrics.certificate_rotations += 1;
        }
        info!("Certificate rotation completed successfully");
        Ok(())
    }

    /// Returns a clone of the session recorded under `session_id`, if any.
    pub async fn get_session_info(&self, session_id: &str) -> Option<TlsSessionInfo> {
        self.sessions.read().await.get(session_id).cloned()
    }

    /// Returns a clone of the certificate summary stored under `cert_id`, if any.
    pub async fn get_certificate_info(&self, cert_id: &str) -> Option<CertificateInfo> {
        self.certificates.read().await.get(cert_id).cloned()
    }

    /// Returns a snapshot (clone) of the current metrics.
    pub async fn get_metrics(&self) -> TlsMetrics {
        self.metrics.read().await.clone()
    }

    /// Removes the session recorded under `session_id`, if present.
    /// Always returns `Ok(())`, including for unknown ids.
    pub async fn close_session(&self, session_id: &str) -> Result<()> {
        self.sessions.write().await.remove(session_id);
        debug!("TLS session closed: {}", session_id);
        Ok(())
    }

    /// Lists certificates whose remaining lifetime is below
    /// `rotation.rotation_threshold_days`, including already-expired ones.
    ///
    /// `days_until_expiry` is whole days (chrono `num_days`, truncated toward
    /// zero) and is negative for certificates that expired more than a day ago.
    pub async fn check_certificate_expiry(&self) -> Result<Vec<ExpiryWarning>> {
        // One reference instant so every certificate is judged against the same "now".
        let now = Utc::now();
        let threshold_days = i64::from(self.config.rotation.rotation_threshold_days);
        let certificates = self.certificates.read().await;
        let mut warnings = Vec::new();
        for (cert_id, cert_info) in certificates.iter() {
            let days_until_expiry = (cert_info.valid_until - now).num_days();
            if days_until_expiry >= threshold_days {
                continue;
            }
            // `num_days` truncates, so a certificate that expired hours ago has
            // 0 days; compare timestamps to report expiry correctly.
            if cert_info.valid_until <= now {
                warn!(
                    "Certificate {} has expired (valid until {})",
                    cert_id, cert_info.valid_until
                );
            } else {
                warn!(
                    "Certificate {} expires in {} days",
                    cert_id, days_until_expiry
                );
            }
            warnings.push(ExpiryWarning {
                certificate_id: cert_id.clone(),
                subject: cert_info.subject.clone(),
                expires_at: cert_info.valid_until,
                days_until_expiry,
            });
        }
        Ok(warnings)
    }
}
/// A certificate that is expired or inside the rotation threshold, as reported
/// by `TlsManager::check_certificate_expiry`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExpiryWarning {
/// Key of the certificate in the manager's certificate registry.
pub certificate_id: String,
/// Copied from `CertificateInfo::subject`.
pub subject: String,
/// Copied from `CertificateInfo::valid_until`.
pub expires_at: DateTime<Utc>,
/// Whole days until `expires_at` (chrono `num_days`, truncated toward zero);
/// negative once the certificate has been expired for at least a day.
pub days_until_expiry: i64,
}
#[cfg(test)]
mod tests {
    use super::*;

    // Tests that never `.await` use plain `#[test]`; a tokio runtime is only
    // spun up for the ones that exercise async `TlsManager` methods.

    #[test]
    fn test_tls_config_default() {
        let config = TlsConfig::default();
        assert!(config.enabled);
        assert_eq!(config.protocol_version, TlsVersion::Tls13);
        assert!(config.perfect_forward_secrecy);
    }

    #[tokio::test]
    async fn test_tls_manager_creation() {
        let manager = TlsManager::new(TlsConfig::default());
        let metrics = manager.get_metrics().await;
        assert_eq!(metrics.connections_established, 0);
    }

    #[test]
    fn test_cipher_suite_display() {
        assert_eq!(
            CipherSuite::TLS_AES_256_GCM_SHA384.to_string(),
            "TLS_AES_256_GCM_SHA384"
        );
        assert_eq!(
            CipherSuite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256.to_string(),
            "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256"
        );
    }

    #[test]
    fn test_tls_version_display() {
        assert_eq!(TlsVersion::Tls13.to_string(), "TLS 1.3");
        assert_eq!(TlsVersion::Tls12.to_string(), "TLS 1.2");
    }

    #[test]
    fn test_mtls_config_default() {
        let config = MutualTlsConfig::default();
        assert!(!config.enabled);
        assert!(config.require_client_cert);
        assert_eq!(config.verification_depth, 3);
    }

    #[tokio::test]
    async fn test_handshake_session_lifecycle() {
        let manager = TlsManager::new(TlsConfig::default());
        let session = manager
            .handshake("conn-1")
            .await
            .expect("handshake with the default config should succeed");
        assert_eq!(session.session_id, "conn-1");
        assert_eq!(session.protocol_version, TlsVersion::Tls13);
        assert_eq!(session.cipher_suite, "TLS_AES_256_GCM_SHA384");
        assert_eq!(session.alpn_protocol.as_deref(), Some("h2"));
        assert!(session.client_cert.is_none());
        assert!(manager.get_session_info("conn-1").await.is_some());

        let metrics = manager.get_metrics().await;
        assert_eq!(metrics.handshakes_total, 1);
        assert_eq!(metrics.connections_established, 1);
        assert_eq!(metrics.tls_version_distribution.get("TLS 1.3"), Some(&1));
        assert_eq!(
            metrics.cipher_suite_distribution.get("TLS_AES_256_GCM_SHA384"),
            Some(&1)
        );

        manager.close_session("conn-1").await.unwrap();
        assert!(manager.get_session_info("conn-1").await.is_none());
    }

    #[tokio::test]
    async fn test_check_certificate_expiry_without_certificates() {
        let manager = TlsManager::new(TlsConfig::default());
        let warnings = manager.check_certificate_expiry().await.unwrap();
        assert!(warnings.is_empty());
    }
}