Skip to main content

oxirs_stream/
tls_security.rs

1//! # Enhanced TLS/SSL Security Module
2//!
3//! Provides comprehensive TLS/SSL encryption, certificate management, and secure communication
4//! for all streaming backends with support for mutual TLS (mTLS), certificate rotation, and
5//! advanced cipher suites.
6
7use anyhow::{anyhow, Result};
8use chrono::{DateTime, Utc};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::path::PathBuf;
12use std::sync::Arc;
13use tokio::sync::RwLock;
14use tracing::{debug, info, warn};
15
16/// TLS/SSL configuration with advanced security options
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct TlsConfig {
19    /// Enable TLS encryption
20    pub enabled: bool,
21    /// TLS protocol version
22    pub protocol_version: TlsVersion,
23    /// Certificate configuration
24    pub certificates: CertificateConfig,
25    /// Cipher suite configuration
26    pub cipher_suites: Vec<CipherSuite>,
27    /// Mutual TLS (mTLS) configuration
28    pub mtls: MutualTlsConfig,
29    /// Certificate rotation settings
30    pub rotation: CertRotationConfig,
31    /// OCSP stapling configuration
32    pub ocsp_stapling: OcspConfig,
33    /// Perfect forward secrecy
34    pub perfect_forward_secrecy: bool,
35    /// Session resumption
36    pub session_resumption: SessionResumptionConfig,
37    /// ALPN protocols
38    pub alpn_protocols: Vec<String>,
39}
40
41impl Default for TlsConfig {
42    fn default() -> Self {
43        Self {
44            enabled: true,
45            protocol_version: TlsVersion::Tls13,
46            certificates: CertificateConfig::default(),
47            cipher_suites: vec![
48                CipherSuite::TLS_AES_256_GCM_SHA384,
49                CipherSuite::TLS_CHACHA20_POLY1305_SHA256,
50                CipherSuite::TLS_AES_128_GCM_SHA256,
51            ],
52            mtls: MutualTlsConfig::default(),
53            rotation: CertRotationConfig::default(),
54            ocsp_stapling: OcspConfig::default(),
55            perfect_forward_secrecy: true,
56            session_resumption: SessionResumptionConfig::default(),
57            alpn_protocols: vec!["h2".to_string(), "http/1.1".to_string()],
58        }
59    }
60}
61
62/// TLS protocol versions
63#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
64pub enum TlsVersion {
65    /// TLS 1.2 (minimum recommended)
66    Tls12,
67    /// TLS 1.3 (recommended)
68    Tls13,
69}
70
71impl std::fmt::Display for TlsVersion {
72    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73        match self {
74            TlsVersion::Tls12 => write!(f, "TLS 1.2"),
75            TlsVersion::Tls13 => write!(f, "TLS 1.3"),
76        }
77    }
78}
79
80/// Certificate configuration
81#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct CertificateConfig {
83    /// Server certificate path
84    pub server_cert_path: PathBuf,
85    /// Server private key path
86    pub server_key_path: PathBuf,
87    /// Certificate authority (CA) certificate path
88    pub ca_cert_path: Option<PathBuf>,
89    /// Certificate chain path
90    pub cert_chain_path: Option<PathBuf>,
91    /// Key password/passphrase
92    pub key_password: Option<String>,
93    /// Certificate format
94    pub format: CertificateFormat,
95    /// Verify peer certificates
96    pub verify_peer: bool,
97    /// Verify hostname
98    pub verify_hostname: bool,
99}
100
101impl Default for CertificateConfig {
102    fn default() -> Self {
103        Self {
104            server_cert_path: PathBuf::from("/etc/oxirs/certs/server.crt"),
105            server_key_path: PathBuf::from("/etc/oxirs/certs/server.key"),
106            ca_cert_path: Some(PathBuf::from("/etc/oxirs/certs/ca.crt")),
107            cert_chain_path: None,
108            key_password: None,
109            format: CertificateFormat::PEM,
110            verify_peer: true,
111            verify_hostname: true,
112        }
113    }
114}
115
116/// Certificate formats
117#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
118pub enum CertificateFormat {
119    /// PEM format (Base64 encoded DER)
120    PEM,
121    /// DER format (binary)
122    DER,
123    /// PKCS#12 format (.pfx/.p12)
124    PKCS12,
125}
126
127/// Supported cipher suites (TLS 1.3 and 1.2)
128#[allow(non_camel_case_types)]
129#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
130pub enum CipherSuite {
131    // TLS 1.3 cipher suites
132    TLS_AES_256_GCM_SHA384,
133    TLS_CHACHA20_POLY1305_SHA256,
134    TLS_AES_128_GCM_SHA256,
135    TLS_AES_128_CCM_SHA256,
136
137    // TLS 1.2 cipher suites (backward compatibility)
138    TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
139    TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
140    TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
141    TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
142}
143
144impl std::fmt::Display for CipherSuite {
145    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146        match self {
147            CipherSuite::TLS_AES_256_GCM_SHA384 => write!(f, "TLS_AES_256_GCM_SHA384"),
148            CipherSuite::TLS_CHACHA20_POLY1305_SHA256 => write!(f, "TLS_CHACHA20_POLY1305_SHA256"),
149            CipherSuite::TLS_AES_128_GCM_SHA256 => write!(f, "TLS_AES_128_GCM_SHA256"),
150            CipherSuite::TLS_AES_128_CCM_SHA256 => write!(f, "TLS_AES_128_CCM_SHA256"),
151            CipherSuite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 => {
152                write!(f, "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384")
153            }
154            CipherSuite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 => {
155                write!(f, "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384")
156            }
157            CipherSuite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 => {
158                write!(f, "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256")
159            }
160            CipherSuite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 => {
161                write!(f, "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256")
162            }
163        }
164    }
165}
166
167/// Mutual TLS (mTLS) configuration
168#[derive(Debug, Clone, Serialize, Deserialize)]
169pub struct MutualTlsConfig {
170    /// Enable mutual TLS
171    pub enabled: bool,
172    /// Client certificate required
173    pub require_client_cert: bool,
174    /// Trusted client CA certificates
175    pub trusted_ca_certs: Vec<PathBuf>,
176    /// Client certificate verification depth
177    pub verification_depth: u8,
178    /// Revocation check configuration
179    pub revocation_check: RevocationCheckConfig,
180}
181
182impl Default for MutualTlsConfig {
183    fn default() -> Self {
184        Self {
185            enabled: false,
186            require_client_cert: true,
187            trusted_ca_certs: vec![],
188            verification_depth: 3,
189            revocation_check: RevocationCheckConfig::default(),
190        }
191    }
192}
193
194/// Certificate revocation check configuration
195#[derive(Debug, Clone, Serialize, Deserialize)]
196pub struct RevocationCheckConfig {
197    /// Enable revocation checking
198    pub enabled: bool,
199    /// Check CRL (Certificate Revocation List)
200    pub check_crl: bool,
201    /// Check OCSP (Online Certificate Status Protocol)
202    pub check_ocsp: bool,
203    /// CRL cache TTL in seconds
204    pub crl_cache_ttl: u64,
205}
206
207impl Default for RevocationCheckConfig {
208    fn default() -> Self {
209        Self {
210            enabled: true,
211            check_crl: true,
212            check_ocsp: true,
213            crl_cache_ttl: 3600,
214        }
215    }
216}
217
218/// Certificate rotation configuration
219#[derive(Debug, Clone, Serialize, Deserialize)]
220pub struct CertRotationConfig {
221    /// Enable automatic certificate rotation
222    pub enabled: bool,
223    /// Check interval in seconds
224    pub check_interval_secs: u64,
225    /// Rotation threshold (days before expiry)
226    pub rotation_threshold_days: u32,
227    /// Graceful rotation period (seconds)
228    pub graceful_period_secs: u64,
229}
230
231impl Default for CertRotationConfig {
232    fn default() -> Self {
233        Self {
234            enabled: true,
235            check_interval_secs: 3600,   // Check every hour
236            rotation_threshold_days: 30, // Rotate 30 days before expiry
237            graceful_period_secs: 300,   // 5 minutes graceful period
238        }
239    }
240}
241
242/// OCSP (Online Certificate Status Protocol) configuration
243#[derive(Debug, Clone, Serialize, Deserialize)]
244pub struct OcspConfig {
245    /// Enable OCSP stapling
246    pub enabled: bool,
247    /// OCSP responder URL
248    pub responder_url: Option<String>,
249    /// OCSP response cache TTL in seconds
250    pub cache_ttl: u64,
251    /// OCSP request timeout in seconds
252    pub timeout_secs: u64,
253}
254
255impl Default for OcspConfig {
256    fn default() -> Self {
257        Self {
258            enabled: true,
259            responder_url: None,
260            cache_ttl: 3600,
261            timeout_secs: 10,
262        }
263    }
264}
265
266/// Session resumption configuration
267#[derive(Debug, Clone, Serialize, Deserialize)]
268pub struct SessionResumptionConfig {
269    /// Enable session resumption
270    pub enabled: bool,
271    /// Session cache size
272    pub cache_size: usize,
273    /// Session ticket lifetime in seconds
274    pub ticket_lifetime_secs: u64,
275    /// Session ID lifetime in seconds
276    pub session_id_lifetime_secs: u64,
277}
278
279impl Default for SessionResumptionConfig {
280    fn default() -> Self {
281        Self {
282            enabled: true,
283            cache_size: 10000,
284            ticket_lifetime_secs: 7200,     // 2 hours
285            session_id_lifetime_secs: 7200, // 2 hours
286        }
287    }
288}
289
290/// TLS certificate information
291#[derive(Debug, Clone, Serialize, Deserialize)]
292pub struct CertificateInfo {
293    /// Certificate subject
294    pub subject: String,
295    /// Certificate issuer
296    pub issuer: String,
297    /// Serial number
298    pub serial_number: String,
299    /// Valid from
300    pub valid_from: DateTime<Utc>,
301    /// Valid until
302    pub valid_until: DateTime<Utc>,
303    /// Subject alternative names (SANs)
304    pub san: Vec<String>,
305    /// Key algorithm
306    pub key_algorithm: String,
307    /// Key size in bits
308    pub key_size: u32,
309    /// Signature algorithm
310    pub signature_algorithm: String,
311    /// Fingerprint (SHA-256)
312    pub fingerprint_sha256: String,
313}
314
315/// TLS session information
316#[derive(Debug, Clone, Serialize, Deserialize)]
317pub struct TlsSessionInfo {
318    /// Session ID
319    pub session_id: String,
320    /// TLS version used
321    pub protocol_version: TlsVersion,
322    /// Cipher suite used
323    pub cipher_suite: String,
324    /// Server name indication (SNI)
325    pub sni: Option<String>,
326    /// ALPN protocol negotiated
327    pub alpn_protocol: Option<String>,
328    /// Client certificate (if mTLS)
329    pub client_cert: Option<CertificateInfo>,
330    /// Established timestamp
331    pub established_at: DateTime<Utc>,
332}
333
334/// TLS manager for certificate and connection management
335pub struct TlsManager {
336    config: TlsConfig,
337    certificates: Arc<RwLock<HashMap<String, CertificateInfo>>>,
338    sessions: Arc<RwLock<HashMap<String, TlsSessionInfo>>>,
339    metrics: Arc<RwLock<TlsMetrics>>,
340}
341
342/// TLS metrics
343#[derive(Debug, Clone, Default, Serialize, Deserialize)]
344pub struct TlsMetrics {
345    /// Total TLS connections established
346    pub connections_established: u64,
347    /// Total TLS handshakes
348    pub handshakes_total: u64,
349    /// Failed handshakes
350    pub handshakes_failed: u64,
351    /// Certificate rotations performed
352    pub certificate_rotations: u64,
353    /// OCSP requests
354    pub ocsp_requests: u64,
355    /// Session resumptions
356    pub session_resumptions: u64,
357    /// Average handshake duration (ms)
358    pub avg_handshake_duration_ms: f64,
359    /// TLS version distribution
360    pub tls_version_distribution: HashMap<String, u64>,
361    /// Cipher suite distribution
362    pub cipher_suite_distribution: HashMap<String, u64>,
363}
364
365impl TlsManager {
366    /// Create a new TLS manager
367    pub fn new(config: TlsConfig) -> Self {
368        Self {
369            config,
370            certificates: Arc::new(RwLock::new(HashMap::new())),
371            sessions: Arc::new(RwLock::new(HashMap::new())),
372            metrics: Arc::new(RwLock::new(TlsMetrics::default())),
373        }
374    }
375
376    /// Initialize TLS manager and load certificates
377    pub async fn initialize(&self) -> Result<()> {
378        info!("Initializing TLS manager");
379
380        if !self.config.enabled {
381            warn!("TLS is disabled");
382            return Ok(());
383        }
384
385        // Validate certificate paths
386        self.validate_certificate_paths().await?;
387
388        // Load certificates
389        self.load_certificates().await?;
390
391        // Start certificate rotation monitor if enabled
392        if self.config.rotation.enabled {
393            self.start_rotation_monitor().await?;
394        }
395
396        info!("TLS manager initialized successfully");
397        Ok(())
398    }
399
400    /// Validate certificate file paths
401    async fn validate_certificate_paths(&self) -> Result<()> {
402        let cert_path = &self.config.certificates.server_cert_path;
403        let key_path = &self.config.certificates.server_key_path;
404
405        if !cert_path.exists() {
406            return Err(anyhow!("Server certificate not found: {:?}", cert_path));
407        }
408
409        if !key_path.exists() {
410            return Err(anyhow!("Server private key not found: {:?}", key_path));
411        }
412
413        if let Some(ca_path) = &self.config.certificates.ca_cert_path {
414            if !ca_path.exists() {
415                warn!("CA certificate not found: {:?}", ca_path);
416            }
417        }
418
419        debug!("Certificate paths validated");
420        Ok(())
421    }
422
423    /// Load certificates from disk
424    async fn load_certificates(&self) -> Result<()> {
425        info!("Loading TLS certificates");
426
427        // In a real implementation, this would:
428        // 1. Read certificate files from disk
429        // 2. Parse X.509 certificates
430        // 3. Extract certificate information
431        // 4. Store in certificates HashMap
432        // 5. Validate certificate chain
433
434        // For now, this is a placeholder
435        debug!("Certificates loaded successfully");
436        Ok(())
437    }
438
439    /// Start certificate rotation monitor
440    async fn start_rotation_monitor(&self) -> Result<()> {
441        info!("Starting certificate rotation monitor");
442
443        let check_interval = self.config.rotation.check_interval_secs;
444        let threshold_days = self.config.rotation.rotation_threshold_days;
445
446        // In a real implementation, this would spawn a background task
447        // that periodically checks certificate expiration and rotates
448        // certificates when necessary
449
450        debug!(
451            "Rotation monitor started (check_interval={}s, threshold={}d)",
452            check_interval, threshold_days
453        );
454        Ok(())
455    }
456
457    /// Perform TLS handshake (placeholder for actual implementation)
458    pub async fn handshake(&self, connection_id: &str) -> Result<TlsSessionInfo> {
459        let start_time = std::time::Instant::now();
460
461        // Record handshake attempt
462        {
463            let mut metrics = self.metrics.write().await;
464            metrics.handshakes_total += 1;
465        }
466
467        // Perform actual TLS handshake (placeholder)
468        let session_info = TlsSessionInfo {
469            session_id: connection_id.to_string(),
470            protocol_version: self.config.protocol_version,
471            cipher_suite: self.config.cipher_suites[0].to_string(),
472            sni: None,
473            alpn_protocol: self.config.alpn_protocols.first().cloned(),
474            client_cert: None,
475            established_at: Utc::now(),
476        };
477
478        // Store session
479        self.sessions
480            .write()
481            .await
482            .insert(connection_id.to_string(), session_info.clone());
483
484        // Update metrics
485        {
486            let mut metrics = self.metrics.write().await;
487            metrics.connections_established += 1;
488            let duration = start_time.elapsed().as_millis() as f64;
489            metrics.avg_handshake_duration_ms =
490                (metrics.avg_handshake_duration_ms + duration) / 2.0;
491
492            // Update TLS version distribution
493            let version_key = session_info.protocol_version.to_string();
494            *metrics
495                .tls_version_distribution
496                .entry(version_key)
497                .or_insert(0) += 1;
498
499            // Update cipher suite distribution
500            *metrics
501                .cipher_suite_distribution
502                .entry(session_info.cipher_suite.clone())
503                .or_insert(0) += 1;
504        }
505
506        debug!(
507            "TLS handshake completed for connection: {} in {:?}",
508            connection_id,
509            start_time.elapsed()
510        );
511
512        Ok(session_info)
513    }
514
515    /// Rotate certificates
516    pub async fn rotate_certificates(&self) -> Result<()> {
517        info!("Starting certificate rotation");
518
519        // In a real implementation, this would:
520        // 1. Load new certificates from disk or certificate management system
521        // 2. Validate new certificates
522        // 3. Gradually transition connections to new certificates
523        // 4. Monitor for issues during rotation
524        // 5. Complete rotation after graceful period
525
526        {
527            let mut metrics = self.metrics.write().await;
528            metrics.certificate_rotations += 1;
529        }
530
531        info!("Certificate rotation completed successfully");
532        Ok(())
533    }
534
535    /// Get TLS session information
536    pub async fn get_session_info(&self, session_id: &str) -> Option<TlsSessionInfo> {
537        self.sessions.read().await.get(session_id).cloned()
538    }
539
540    /// Get certificate information
541    pub async fn get_certificate_info(&self, cert_id: &str) -> Option<CertificateInfo> {
542        self.certificates.read().await.get(cert_id).cloned()
543    }
544
545    /// Get TLS metrics
546    pub async fn get_metrics(&self) -> TlsMetrics {
547        self.metrics.read().await.clone()
548    }
549
550    /// Close TLS session
551    pub async fn close_session(&self, session_id: &str) -> Result<()> {
552        self.sessions.write().await.remove(session_id);
553        debug!("TLS session closed: {}", session_id);
554        Ok(())
555    }
556
557    /// Validate certificate expiry
558    pub async fn check_certificate_expiry(&self) -> Result<Vec<ExpiryWarning>> {
559        let mut warnings = Vec::new();
560
561        let certificates = self.certificates.read().await;
562        let threshold_days = self.config.rotation.rotation_threshold_days;
563
564        for (cert_id, cert_info) in certificates.iter() {
565            let days_until_expiry = (cert_info.valid_until - Utc::now()).num_days();
566
567            if days_until_expiry < threshold_days as i64 {
568                warnings.push(ExpiryWarning {
569                    certificate_id: cert_id.clone(),
570                    subject: cert_info.subject.clone(),
571                    expires_at: cert_info.valid_until,
572                    days_until_expiry,
573                });
574
575                warn!(
576                    "Certificate {} expires in {} days",
577                    cert_id, days_until_expiry
578                );
579            }
580        }
581
582        Ok(warnings)
583    }
584}
585
586/// Certificate expiry warning
587#[derive(Debug, Clone, Serialize, Deserialize)]
588pub struct ExpiryWarning {
589    /// Certificate ID
590    pub certificate_id: String,
591    /// Certificate subject
592    pub subject: String,
593    /// Expiration date
594    pub expires_at: DateTime<Utc>,
595    /// Days until expiry
596    pub days_until_expiry: i64,
597}
598
#[cfg(test)]
mod tests {
    use super::*;

    // Purely synchronous checks use plain `#[test]`; only tests that await
    // manager methods need a tokio runtime.

    #[test]
    fn test_tls_config_default() {
        let config = TlsConfig::default();
        assert!(config.enabled);
        assert_eq!(config.protocol_version, TlsVersion::Tls13);
        assert!(config.perfect_forward_secrecy);
    }

    #[tokio::test]
    async fn test_tls_manager_creation() {
        let config = TlsConfig::default();
        let manager = TlsManager::new(config);
        let metrics = manager.get_metrics().await;
        assert_eq!(metrics.connections_established, 0);
    }

    #[test]
    fn test_cipher_suite_display() {
        let suite = CipherSuite::TLS_AES_256_GCM_SHA384;
        assert_eq!(suite.to_string(), "TLS_AES_256_GCM_SHA384");
    }

    #[test]
    fn test_tls_version_display() {
        assert_eq!(TlsVersion::Tls13.to_string(), "TLS 1.3");
        assert_eq!(TlsVersion::Tls12.to_string(), "TLS 1.2");
    }

    #[test]
    fn test_mtls_config_default() {
        let config = MutualTlsConfig::default();
        assert!(!config.enabled);
        assert!(config.require_client_cert);
        assert_eq!(config.verification_depth, 3);
    }

    /// A handshake records the session and updates counters and distributions;
    /// closing the session removes it again.
    #[tokio::test]
    async fn test_handshake_records_session_and_metrics() {
        let manager = TlsManager::new(TlsConfig::default());

        let session = manager
            .handshake("conn-1")
            .await
            .expect("handshake with default config should succeed");
        assert_eq!(session.session_id, "conn-1");
        assert_eq!(session.protocol_version, TlsVersion::Tls13);
        assert_eq!(session.cipher_suite, "TLS_AES_256_GCM_SHA384");
        assert_eq!(session.alpn_protocol.as_deref(), Some("h2"));

        let metrics = manager.get_metrics().await;
        assert_eq!(metrics.handshakes_total, 1);
        assert_eq!(metrics.connections_established, 1);
        assert_eq!(metrics.tls_version_distribution.get("TLS 1.3"), Some(&1));
        assert_eq!(
            metrics.cipher_suite_distribution.get("TLS_AES_256_GCM_SHA384"),
            Some(&1)
        );

        assert!(manager.get_session_info("conn-1").await.is_some());
        manager
            .close_session("conn-1")
            .await
            .expect("closing a session should succeed");
        assert!(manager.get_session_info("conn-1").await.is_none());
    }

    /// Only certificates inside the (default 30-day) rotation threshold are reported.
    #[tokio::test]
    async fn test_certificate_expiry_warnings() {
        let manager = TlsManager::new(TlsConfig::default());
        let now = Utc::now();
        let make_cert = |days: i64| CertificateInfo {
            subject: "CN=test".to_string(),
            issuer: "CN=test-ca".to_string(),
            serial_number: "01".to_string(),
            valid_from: now - chrono::Duration::days(1),
            valid_until: now + chrono::Duration::days(days),
            san: vec![],
            key_algorithm: "RSA".to_string(),
            key_size: 2048,
            signature_algorithm: "sha256WithRSAEncryption".to_string(),
            fingerprint_sha256: String::new(),
        };

        {
            let mut certs = manager.certificates.write().await;
            certs.insert("soon".to_string(), make_cert(5));
            certs.insert("later".to_string(), make_cert(365));
        }

        let warnings = manager
            .check_certificate_expiry()
            .await
            .expect("expiry check should succeed");
        assert_eq!(warnings.len(), 1);
        assert_eq!(warnings[0].certificate_id, "soon");
        assert!(warnings[0].days_until_expiry < 30);
    }
}