1use anyhow::{anyhow, Result};
8use chrono::{DateTime, Utc};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::path::PathBuf;
12use std::sync::Arc;
13use tokio::sync::RwLock;
14use tracing::{debug, info, warn};
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct TlsConfig {
19 pub enabled: bool,
21 pub protocol_version: TlsVersion,
23 pub certificates: CertificateConfig,
25 pub cipher_suites: Vec<CipherSuite>,
27 pub mtls: MutualTlsConfig,
29 pub rotation: CertRotationConfig,
31 pub ocsp_stapling: OcspConfig,
33 pub perfect_forward_secrecy: bool,
35 pub session_resumption: SessionResumptionConfig,
37 pub alpn_protocols: Vec<String>,
39}
40
41impl Default for TlsConfig {
42 fn default() -> Self {
43 Self {
44 enabled: true,
45 protocol_version: TlsVersion::Tls13,
46 certificates: CertificateConfig::default(),
47 cipher_suites: vec![
48 CipherSuite::TLS_AES_256_GCM_SHA384,
49 CipherSuite::TLS_CHACHA20_POLY1305_SHA256,
50 CipherSuite::TLS_AES_128_GCM_SHA256,
51 ],
52 mtls: MutualTlsConfig::default(),
53 rotation: CertRotationConfig::default(),
54 ocsp_stapling: OcspConfig::default(),
55 perfect_forward_secrecy: true,
56 session_resumption: SessionResumptionConfig::default(),
57 alpn_protocols: vec!["h2".to_string(), "http/1.1".to_string()],
58 }
59 }
60}
61
62#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
64pub enum TlsVersion {
65 Tls12,
67 Tls13,
69}
70
71impl std::fmt::Display for TlsVersion {
72 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73 match self {
74 TlsVersion::Tls12 => write!(f, "TLS 1.2"),
75 TlsVersion::Tls13 => write!(f, "TLS 1.3"),
76 }
77 }
78}
79
80#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct CertificateConfig {
83 pub server_cert_path: PathBuf,
85 pub server_key_path: PathBuf,
87 pub ca_cert_path: Option<PathBuf>,
89 pub cert_chain_path: Option<PathBuf>,
91 pub key_password: Option<String>,
93 pub format: CertificateFormat,
95 pub verify_peer: bool,
97 pub verify_hostname: bool,
99}
100
101impl Default for CertificateConfig {
102 fn default() -> Self {
103 Self {
104 server_cert_path: PathBuf::from("/etc/oxirs/certs/server.crt"),
105 server_key_path: PathBuf::from("/etc/oxirs/certs/server.key"),
106 ca_cert_path: Some(PathBuf::from("/etc/oxirs/certs/ca.crt")),
107 cert_chain_path: None,
108 key_password: None,
109 format: CertificateFormat::PEM,
110 verify_peer: true,
111 verify_hostname: true,
112 }
113 }
114}
115
116#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
118pub enum CertificateFormat {
119 PEM,
121 DER,
123 PKCS12,
125}
126
127#[allow(non_camel_case_types)]
129#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
130pub enum CipherSuite {
131 TLS_AES_256_GCM_SHA384,
133 TLS_CHACHA20_POLY1305_SHA256,
134 TLS_AES_128_GCM_SHA256,
135 TLS_AES_128_CCM_SHA256,
136
137 TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
139 TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
140 TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
141 TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
142}
143
144impl std::fmt::Display for CipherSuite {
145 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146 match self {
147 CipherSuite::TLS_AES_256_GCM_SHA384 => write!(f, "TLS_AES_256_GCM_SHA384"),
148 CipherSuite::TLS_CHACHA20_POLY1305_SHA256 => write!(f, "TLS_CHACHA20_POLY1305_SHA256"),
149 CipherSuite::TLS_AES_128_GCM_SHA256 => write!(f, "TLS_AES_128_GCM_SHA256"),
150 CipherSuite::TLS_AES_128_CCM_SHA256 => write!(f, "TLS_AES_128_CCM_SHA256"),
151 CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 => {
152 write!(f, "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384")
153 }
154 CipherSuite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 => {
155 write!(f, "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384")
156 }
157 CipherSuite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 => {
158 write!(f, "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256")
159 }
160 CipherSuite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 => {
161 write!(f, "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256")
162 }
163 }
164 }
165}
166
167#[derive(Debug, Clone, Serialize, Deserialize)]
169pub struct MutualTlsConfig {
170 pub enabled: bool,
172 pub require_client_cert: bool,
174 pub trusted_ca_certs: Vec<PathBuf>,
176 pub verification_depth: u8,
178 pub revocation_check: RevocationCheckConfig,
180}
181
182impl Default for MutualTlsConfig {
183 fn default() -> Self {
184 Self {
185 enabled: false,
186 require_client_cert: true,
187 trusted_ca_certs: vec![],
188 verification_depth: 3,
189 revocation_check: RevocationCheckConfig::default(),
190 }
191 }
192}
193
194#[derive(Debug, Clone, Serialize, Deserialize)]
196pub struct RevocationCheckConfig {
197 pub enabled: bool,
199 pub check_crl: bool,
201 pub check_ocsp: bool,
203 pub crl_cache_ttl: u64,
205}
206
207impl Default for RevocationCheckConfig {
208 fn default() -> Self {
209 Self {
210 enabled: true,
211 check_crl: true,
212 check_ocsp: true,
213 crl_cache_ttl: 3600,
214 }
215 }
216}
217
218#[derive(Debug, Clone, Serialize, Deserialize)]
220pub struct CertRotationConfig {
221 pub enabled: bool,
223 pub check_interval_secs: u64,
225 pub rotation_threshold_days: u32,
227 pub graceful_period_secs: u64,
229}
230
231impl Default for CertRotationConfig {
232 fn default() -> Self {
233 Self {
234 enabled: true,
235 check_interval_secs: 3600, rotation_threshold_days: 30, graceful_period_secs: 300, }
239 }
240}
241
242#[derive(Debug, Clone, Serialize, Deserialize)]
244pub struct OcspConfig {
245 pub enabled: bool,
247 pub responder_url: Option<String>,
249 pub cache_ttl: u64,
251 pub timeout_secs: u64,
253}
254
255impl Default for OcspConfig {
256 fn default() -> Self {
257 Self {
258 enabled: true,
259 responder_url: None,
260 cache_ttl: 3600,
261 timeout_secs: 10,
262 }
263 }
264}
265
266#[derive(Debug, Clone, Serialize, Deserialize)]
268pub struct SessionResumptionConfig {
269 pub enabled: bool,
271 pub cache_size: usize,
273 pub ticket_lifetime_secs: u64,
275 pub session_id_lifetime_secs: u64,
277}
278
279impl Default for SessionResumptionConfig {
280 fn default() -> Self {
281 Self {
282 enabled: true,
283 cache_size: 10000,
284 ticket_lifetime_secs: 7200, session_id_lifetime_secs: 7200, }
287 }
288}
289
290#[derive(Debug, Clone, Serialize, Deserialize)]
292pub struct CertificateInfo {
293 pub subject: String,
295 pub issuer: String,
297 pub serial_number: String,
299 pub valid_from: DateTime<Utc>,
301 pub valid_until: DateTime<Utc>,
303 pub san: Vec<String>,
305 pub key_algorithm: String,
307 pub key_size: u32,
309 pub signature_algorithm: String,
311 pub fingerprint_sha256: String,
313}
314
315#[derive(Debug, Clone, Serialize, Deserialize)]
317pub struct TlsSessionInfo {
318 pub session_id: String,
320 pub protocol_version: TlsVersion,
322 pub cipher_suite: String,
324 pub sni: Option<String>,
326 pub alpn_protocol: Option<String>,
328 pub client_cert: Option<CertificateInfo>,
330 pub established_at: DateTime<Utc>,
332}
333
334pub struct TlsManager {
336 config: TlsConfig,
337 certificates: Arc<RwLock<HashMap<String, CertificateInfo>>>,
338 sessions: Arc<RwLock<HashMap<String, TlsSessionInfo>>>,
339 metrics: Arc<RwLock<TlsMetrics>>,
340}
341
342#[derive(Debug, Clone, Default, Serialize, Deserialize)]
344pub struct TlsMetrics {
345 pub connections_established: u64,
347 pub handshakes_total: u64,
349 pub handshakes_failed: u64,
351 pub certificate_rotations: u64,
353 pub ocsp_requests: u64,
355 pub session_resumptions: u64,
357 pub avg_handshake_duration_ms: f64,
359 pub tls_version_distribution: HashMap<String, u64>,
361 pub cipher_suite_distribution: HashMap<String, u64>,
363}
364
365impl TlsManager {
366 pub fn new(config: TlsConfig) -> Self {
368 Self {
369 config,
370 certificates: Arc::new(RwLock::new(HashMap::new())),
371 sessions: Arc::new(RwLock::new(HashMap::new())),
372 metrics: Arc::new(RwLock::new(TlsMetrics::default())),
373 }
374 }
375
376 pub async fn initialize(&self) -> Result<()> {
378 info!("Initializing TLS manager");
379
380 if !self.config.enabled {
381 warn!("TLS is disabled");
382 return Ok(());
383 }
384
385 self.validate_certificate_paths().await?;
387
388 self.load_certificates().await?;
390
391 if self.config.rotation.enabled {
393 self.start_rotation_monitor().await?;
394 }
395
396 info!("TLS manager initialized successfully");
397 Ok(())
398 }
399
400 async fn validate_certificate_paths(&self) -> Result<()> {
402 let cert_path = &self.config.certificates.server_cert_path;
403 let key_path = &self.config.certificates.server_key_path;
404
405 if !cert_path.exists() {
406 return Err(anyhow!("Server certificate not found: {:?}", cert_path));
407 }
408
409 if !key_path.exists() {
410 return Err(anyhow!("Server private key not found: {:?}", key_path));
411 }
412
413 if let Some(ca_path) = &self.config.certificates.ca_cert_path {
414 if !ca_path.exists() {
415 warn!("CA certificate not found: {:?}", ca_path);
416 }
417 }
418
419 debug!("Certificate paths validated");
420 Ok(())
421 }
422
423 async fn load_certificates(&self) -> Result<()> {
425 info!("Loading TLS certificates");
426
427 debug!("Certificates loaded successfully");
436 Ok(())
437 }
438
439 async fn start_rotation_monitor(&self) -> Result<()> {
441 info!("Starting certificate rotation monitor");
442
443 let check_interval = self.config.rotation.check_interval_secs;
444 let threshold_days = self.config.rotation.rotation_threshold_days;
445
446 debug!(
451 "Rotation monitor started (check_interval={}s, threshold={}d)",
452 check_interval, threshold_days
453 );
454 Ok(())
455 }
456
457 pub async fn handshake(&self, connection_id: &str) -> Result<TlsSessionInfo> {
459 let start_time = std::time::Instant::now();
460
461 {
463 let mut metrics = self.metrics.write().await;
464 metrics.handshakes_total += 1;
465 }
466
467 let session_info = TlsSessionInfo {
469 session_id: connection_id.to_string(),
470 protocol_version: self.config.protocol_version,
471 cipher_suite: self.config.cipher_suites[0].to_string(),
472 sni: None,
473 alpn_protocol: self.config.alpn_protocols.first().cloned(),
474 client_cert: None,
475 established_at: Utc::now(),
476 };
477
478 self.sessions
480 .write()
481 .await
482 .insert(connection_id.to_string(), session_info.clone());
483
484 {
486 let mut metrics = self.metrics.write().await;
487 metrics.connections_established += 1;
488 let duration = start_time.elapsed().as_millis() as f64;
489 metrics.avg_handshake_duration_ms =
490 (metrics.avg_handshake_duration_ms + duration) / 2.0;
491
492 let version_key = session_info.protocol_version.to_string();
494 *metrics
495 .tls_version_distribution
496 .entry(version_key)
497 .or_insert(0) += 1;
498
499 *metrics
501 .cipher_suite_distribution
502 .entry(session_info.cipher_suite.clone())
503 .or_insert(0) += 1;
504 }
505
506 debug!(
507 "TLS handshake completed for connection: {} in {:?}",
508 connection_id,
509 start_time.elapsed()
510 );
511
512 Ok(session_info)
513 }
514
515 pub async fn rotate_certificates(&self) -> Result<()> {
517 info!("Starting certificate rotation");
518
519 {
527 let mut metrics = self.metrics.write().await;
528 metrics.certificate_rotations += 1;
529 }
530
531 info!("Certificate rotation completed successfully");
532 Ok(())
533 }
534
535 pub async fn get_session_info(&self, session_id: &str) -> Option<TlsSessionInfo> {
537 self.sessions.read().await.get(session_id).cloned()
538 }
539
540 pub async fn get_certificate_info(&self, cert_id: &str) -> Option<CertificateInfo> {
542 self.certificates.read().await.get(cert_id).cloned()
543 }
544
545 pub async fn get_metrics(&self) -> TlsMetrics {
547 self.metrics.read().await.clone()
548 }
549
550 pub async fn close_session(&self, session_id: &str) -> Result<()> {
552 self.sessions.write().await.remove(session_id);
553 debug!("TLS session closed: {}", session_id);
554 Ok(())
555 }
556
557 pub async fn check_certificate_expiry(&self) -> Result<Vec<ExpiryWarning>> {
559 let mut warnings = Vec::new();
560
561 let certificates = self.certificates.read().await;
562 let threshold_days = self.config.rotation.rotation_threshold_days;
563
564 for (cert_id, cert_info) in certificates.iter() {
565 let days_until_expiry = (cert_info.valid_until - Utc::now()).num_days();
566
567 if days_until_expiry < threshold_days as i64 {
568 warnings.push(ExpiryWarning {
569 certificate_id: cert_id.clone(),
570 subject: cert_info.subject.clone(),
571 expires_at: cert_info.valid_until,
572 days_until_expiry,
573 });
574
575 warn!(
576 "Certificate {} expires in {} days",
577 cert_id, days_until_expiry
578 );
579 }
580 }
581
582 Ok(warnings)
583 }
584}
585
586#[derive(Debug, Clone, Serialize, Deserialize)]
588pub struct ExpiryWarning {
589 pub certificate_id: String,
591 pub subject: String,
593 pub expires_at: DateTime<Utc>,
595 pub days_until_expiry: i64,
597}
598
#[cfg(test)]
mod tests {
    use super::*;

    // Purely synchronous checks use plain `#[test]`; only tests that await
    // manager methods need a Tokio runtime.

    #[test]
    fn test_tls_config_default() {
        let config = TlsConfig::default();
        assert!(config.enabled);
        assert_eq!(config.protocol_version, TlsVersion::Tls13);
        assert!(config.perfect_forward_secrecy);
    }

    #[tokio::test]
    async fn test_tls_manager_creation() {
        let config = TlsConfig::default();
        let manager = TlsManager::new(config);
        let metrics = manager.get_metrics().await;
        assert_eq!(metrics.connections_established, 0);
    }

    #[test]
    fn test_cipher_suite_display() {
        let suite = CipherSuite::TLS_AES_256_GCM_SHA384;
        assert_eq!(suite.to_string(), "TLS_AES_256_GCM_SHA384");
    }

    #[test]
    fn test_tls_version_display() {
        assert_eq!(TlsVersion::Tls13.to_string(), "TLS 1.3");
        assert_eq!(TlsVersion::Tls12.to_string(), "TLS 1.2");
    }

    #[test]
    fn test_mtls_config_default() {
        let config = MutualTlsConfig::default();
        assert!(!config.enabled);
        assert!(config.require_client_cert);
        assert_eq!(config.verification_depth, 3);
    }

    #[tokio::test]
    async fn test_handshake_records_session_and_metrics() {
        let manager = TlsManager::new(TlsConfig::default());
        let session = manager
            .handshake("conn-1")
            .await
            .expect("handshake with default config succeeds");

        assert_eq!(session.session_id, "conn-1");
        assert_eq!(session.protocol_version, TlsVersion::Tls13);
        assert_eq!(session.cipher_suite, "TLS_AES_256_GCM_SHA384");
        assert_eq!(session.alpn_protocol.as_deref(), Some("h2"));

        let metrics = manager.get_metrics().await;
        assert_eq!(metrics.handshakes_total, 1);
        assert_eq!(metrics.connections_established, 1);
        assert_eq!(metrics.handshakes_failed, 0);
        assert_eq!(metrics.tls_version_distribution.get("TLS 1.3"), Some(&1));
        assert_eq!(
            metrics.cipher_suite_distribution.get("TLS_AES_256_GCM_SHA384"),
            Some(&1)
        );
    }

    #[tokio::test]
    async fn test_close_session_removes_session() {
        let manager = TlsManager::new(TlsConfig::default());
        manager
            .handshake("conn-2")
            .await
            .expect("handshake with default config succeeds");
        assert!(manager.get_session_info("conn-2").await.is_some());

        manager
            .close_session("conn-2")
            .await
            .expect("closing a session never fails");
        assert!(manager.get_session_info("conn-2").await.is_none());
    }

    #[tokio::test]
    async fn test_expiry_check_without_certificates_is_empty() {
        let manager = TlsManager::new(TlsConfig::default());
        let warnings = manager
            .check_certificate_expiry()
            .await
            .expect("expiry check succeeds");
        assert!(warnings.is_empty());
    }
}