Skip to main content

rivven_core/
tls.rs

1//! Production-grade TLS/mTLS infrastructure for Rivven.
2//!
3//! This module provides comprehensive encryption for all communication paths:
4//! - Client → Broker (optional TLS, mTLS for high-security)
5//! - Connect → Broker (mTLS required for service-to-service)
6//! - Broker ↔ Broker (mTLS for cluster communication)
7//! - Admin → Broker (TLS/mTLS for management APIs)
8//!
9//! # Security Model
10//!
11//! Rivven follows a zero-trust security model for inter-service communication:
12//! - All internal communication uses mTLS by default
13//! - Certificate-based identity for services
14//! - Strong cipher suites only (TLS 1.3 preferred)
15//! - Certificate rotation support
16//! - Optional certificate pinning for high-security deployments
17//!
18//! # Example
19//!
20//! ```rust,ignore
21//! use rivven_core::tls::{TlsConfigBuilder, TlsAcceptor, TlsConnector};
22//!
23//! // Server-side mTLS
//! let server_config = TlsConfigBuilder::new()
//!     .with_cert_file("server.crt")
//!     .with_key_file("server.key")
//!     .with_client_ca_file("ca.crt") // Enable mTLS
//!     .require_client_cert(true)
//!     .build();
//!
//! let acceptor = TlsAcceptor::new(&server_config)?;
//!
//! // Client-side mTLS
//! let client_config = TlsConfigBuilder::new()
//!     .with_cert_file("client.crt")
//!     .with_key_file("client.key")
//!     .with_root_ca_file("ca.crt")
//!     .build();
//!
//! let connector = TlsConnector::new(&client_config)?;
41//! ```
42
43use std::collections::HashMap;
44use std::fmt;
45use std::fs;
46use std::io::{self, BufReader, Cursor};
47use std::net::SocketAddr;
48use std::path::PathBuf;
49use std::sync::Arc;
50use std::time::{Duration, SystemTime};
51
52use parking_lot::RwLock;
53use serde::{Deserialize, Serialize};
54use thiserror::Error;
55use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
56use tokio::net::TcpStream;
57
58// Re-export rustls types that users might need
59pub use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
60pub use rustls::{ClientConfig, ServerConfig};
61
/// Errors produced while loading certificates and keys, building rustls
/// configurations, and establishing TLS connections.
///
/// Most variants carry a human-readable message. The `*ReadError` variants keep the
/// offending path and expose the underlying [`io::Error`] through `Error::source`.
#[derive(Debug, Error)]
pub enum TlsError {
    /// A certificate file could not be read from disk.
    #[error("Failed to read certificate file '{path}': {source}")]
    CertificateReadError {
        /// Path that was being read.
        path: PathBuf,
        /// Underlying I/O failure.
        #[source]
        source: io::Error,
    },

    /// A private key file could not be read from disk.
    #[error("Failed to read private key file '{path}': {source}")]
    KeyReadError {
        /// Path that was being read.
        path: PathBuf,
        /// Underlying I/O failure.
        #[source]
        source: io::Error,
    },

    /// Certificate data could not be decoded (bad PEM or base64), or contained no
    /// certificates.
    #[error("Invalid certificate format: {0}")]
    InvalidCertificate(String),

    /// Private key data could not be decoded (bad PEM or base64), or contained no key.
    #[error("Invalid private key format: {0}")]
    InvalidPrivateKey(String),

    /// Certificate chain validation failed, including a CA certificate being rejected
    /// while building a trust store.
    #[error("Certificate chain validation failed: {0}")]
    CertificateChainError(String),

    /// The TLS handshake failed; carries the underlying error's message.
    #[error("TLS handshake failed: {0}")]
    HandshakeError(String),

    /// The underlying transport could not be established (e.g. TCP connect failed).
    #[error("Connection error: {0}")]
    ConnectionError(String),

    /// The TLS configuration is incomplete or inconsistent, or rustls rejected it.
    #[error("TLS configuration error: {0}")]
    ConfigError(String),

    /// Certificate expired.
    #[error("Certificate expired: {0}")]
    CertificateExpired(String),

    /// Certificate not yet valid.
    #[error("Certificate not yet valid: {0}")]
    CertificateNotYetValid(String),

    /// Certificate revoked.
    #[error("Certificate revoked: {0}")]
    CertificateRevoked(String),

    /// The peer's certificate does not match the expected host name.
    #[error("Hostname verification failed: expected '{expected}', got '{actual}'")]
    HostnameVerificationFailed { expected: String, actual: String },

    /// mTLS is required but the client did not present a certificate.
    #[error("Client certificate required for mTLS but not provided")]
    ClientCertificateRequired,

    /// Generating a self-signed certificate failed.
    #[error("Failed to generate self-signed certificate: {0}")]
    SelfSignedGenerationError(String),

    /// ALPN negotiation failed: the peers share no protocol.
    #[error("ALPN negotiation failed: no common protocol")]
    AlpnNegotiationFailed,

    /// Any other rustls error, converted via `From<rustls::Error>` (message only).
    #[error("TLS internal error: {0}")]
    RustlsError(String),
}
137
138impl From<rustls::Error> for TlsError {
139    fn from(err: rustls::Error) -> Self {
140        TlsError::RustlsError(err.to_string())
141    }
142}
143
144/// Result type for TLS operations
145pub type TlsResult<T> = std::result::Result<T, TlsError>;
146
147// ============================================================================
148// TLS Configuration
149// ============================================================================
150
/// Minimum TLS protocol version to negotiate.
///
/// Serialized in snake_case (`"tls12"`, `"tls13"`). TLS 1.3 is always offered;
/// choosing [`TlsVersion::Tls12`] additionally enables TLS 1.2.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum TlsVersion {
    /// TLS 1.2 (minimum for compatibility); TLS 1.3 stays enabled as well.
    Tls12,
    /// TLS 1.3 only (preferred, default).
    #[default]
    Tls13,
}
161
/// Server-side policy for client certificates (mTLS).
///
/// Serialized in snake_case: `"disabled"`, `"optional"`, `"required"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum MtlsMode {
    /// TLS without client certificate verification (the default).
    #[default]
    Disabled,
    /// Request a client certificate and verify it if presented, but still accept
    /// clients that send none. Without a configured `client_ca`, client auth is off.
    Optional,
    /// Require a valid client certificate (recommended for service-to-service).
    /// Building a server config fails if no `client_ca` is configured.
    Required,
}
174
/// Where to obtain certificates from.
///
/// Serialized as an internally tagged enum: a `type` field (`"file"`, `"pem"`,
/// `"der"`, `"self_signed"`) alongside the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum CertificateSource {
    /// Load from a PEM file; every certificate in it is used, in order.
    File { path: PathBuf },
    /// Load from PEM text; every certificate in it is used, in order.
    Pem { content: String },
    /// A single DER certificate, base64-encoded (standard alphabet) in config.
    Der { content: String },
    /// Generate a self-signed certificate for `common_name` (development only).
    ///
    /// The generated private key is kept only when a server config is built from it;
    /// `load_certificates` returns the certificate alone and discards the key.
    SelfSigned { common_name: String },
}
188
/// Where to obtain a private key from.
///
/// Serialized as an internally tagged enum: a `type` field (`"file"`, `"pem"`,
/// `"der"`) alongside the variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum PrivateKeySource {
    /// Load from a PEM file (the first private key found is used).
    File { path: PathBuf },
    /// Load from PEM text (the first private key found is used).
    Pem { content: String },
    /// DER bytes, base64-encoded (standard alphabet) in config. `load_private_key`
    /// always interprets them as PKCS#8.
    Der { content: String },
}
200
/// Complete TLS configuration for one component, covering both server and client roles.
///
/// The same struct feeds [`TlsAcceptor::new`] (server side) and [`TlsConnector::new`]
/// (client side). Some fields only matter for one role, as noted per field.
///
/// NOTE(review): `enabled` defaults to `true` when a serialized config omits it
/// (serde `default_true`), but [`TlsConfig::default`] sets it to `false`. Confirm that
/// both behaviours are intended.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TlsConfig {
    /// Whether TLS is enabled. A missing value means `true` when deserializing.
    #[serde(default = "default_true")]
    pub enabled: bool,

    /// Own certificate (and chain).
    ///
    /// For an acceptor this is the server certificate. For a connector it is the
    /// client certificate presented for mTLS, and is only used when `private_key`
    /// is also set.
    pub certificate: Option<CertificateSource>,

    /// Private key matching `certificate`. Not needed for a server using
    /// `CertificateSource::SelfSigned`, whose key is generated with the certificate.
    pub private_key: Option<PrivateKeySource>,

    /// Trust anchors a connector uses to verify servers. When unset, the platform's
    /// native root certificates are used.
    pub root_ca: Option<CertificateSource>,

    /// Trust anchors an acceptor uses to verify client certificates (mTLS).
    pub client_ca: Option<CertificateSource>,

    /// Server-side client-certificate policy.
    #[serde(default)]
    pub mtls_mode: MtlsMode,

    /// Minimum TLS version (TLS 1.3 is always allowed).
    #[serde(default)]
    pub min_version: TlsVersion,

    /// ALPN protocols in preference order (e.g., ["h2", "http/1.1"]). An empty list
    /// leaves ALPN unconfigured.
    #[serde(default)]
    pub alpn_protocols: Vec<String>,

    /// Enable OCSP stapling.
    ///
    /// NOTE(review): not read by the server/client config builders in this part of the
    /// file. Confirm where (or whether) it is honoured.
    #[serde(default)]
    pub ocsp_stapling: bool,

    /// Certificate pinning (SHA-256 fingerprints; exact encoding not shown here).
    ///
    /// NOTE(review): not read by the client config builder in this part of the file.
    /// Confirm that pins are enforced elsewhere.
    #[serde(default)]
    pub pinned_certificates: Vec<String>,

    /// Skip server certificate verification on the client side (DANGEROUS - testing only).
    #[serde(default)]
    pub insecure_skip_verify: bool,

    /// Default server name for SNI (client-side), used by
    /// [`TlsConnector::connect_with_default_name`].
    pub server_name: Option<String>,

    /// Server-side session cache size (0 to disable).
    #[serde(default = "default_session_cache_size")]
    pub session_cache_size: usize,

    /// Session ticket lifetime (humantime format in config).
    ///
    /// NOTE(review): not read by the server config builder in this part of the file.
    /// Confirm where it is applied.
    #[serde(default = "default_session_ticket_lifetime")]
    #[serde(with = "humantime_serde")]
    pub session_ticket_lifetime: Duration,

    /// Certificate reload interval (0 to disable).
    ///
    /// A non-zero value makes [`TlsAcceptor::new`] also keep the server config in a
    /// shared lock that `accept` reads from. The periodic reload itself is not
    /// scheduled in this part of the file.
    #[serde(default)]
    #[serde(with = "humantime_serde")]
    pub cert_reload_interval: Duration,
}
261
/// Serde default for `TlsConfig::enabled`: a config section that omits `enabled` is on.
fn default_true() -> bool {
    true
}

/// Default for `TlsConfig::session_cache_size`: 256 cached sessions.
fn default_session_cache_size() -> usize {
    const CACHED_SESSIONS: usize = 256;
    CACHED_SESSIONS
}

/// Default for `TlsConfig::session_ticket_lifetime`: one day.
fn default_session_ticket_lifetime() -> Duration {
    const ONE_DAY_SECS: u64 = 24 * 60 * 60;
    Duration::from_secs(ONE_DAY_SECS)
}
273
274impl Default for TlsConfig {
275    fn default() -> Self {
276        Self {
277            enabled: false, // Opt-in to TLS
278            certificate: None,
279            private_key: None,
280            root_ca: None,
281            client_ca: None,
282            mtls_mode: MtlsMode::Disabled,
283            min_version: TlsVersion::Tls13,
284            alpn_protocols: vec![],
285            ocsp_stapling: false,
286            pinned_certificates: vec![],
287            insecure_skip_verify: false,
288            server_name: None,
289            session_cache_size: default_session_cache_size(),
290            session_ticket_lifetime: default_session_ticket_lifetime(),
291            cert_reload_interval: Duration::ZERO,
292        }
293    }
294}
295
296impl TlsConfig {
297    /// Create a new disabled TLS configuration
298    pub fn disabled() -> Self {
299        Self::default()
300    }
301
302    /// Create TLS configuration for development with self-signed certificates
303    pub fn self_signed(common_name: &str) -> Self {
304        Self {
305            enabled: true,
306            certificate: Some(CertificateSource::SelfSigned {
307                common_name: common_name.to_string(),
308            }),
309            private_key: None,          // Generated with self-signed cert
310            insecure_skip_verify: true, // Required for self-signed
311            ..Default::default()
312        }
313    }
314
315    /// Create TLS configuration from PEM files
316    pub fn from_pem_files<P: Into<PathBuf>>(cert_path: P, key_path: P) -> Self {
317        Self {
318            enabled: true,
319            certificate: Some(CertificateSource::File {
320                path: cert_path.into(),
321            }),
322            private_key: Some(PrivateKeySource::File {
323                path: key_path.into(),
324            }),
325            ..Default::default()
326        }
327    }
328
329    /// Create mTLS configuration from PEM files
330    pub fn mtls_from_pem_files<P1, P2, P3>(cert_path: P1, key_path: P2, ca_path: P3) -> Self
331    where
332        P1: Into<PathBuf>,
333        P2: Into<PathBuf>,
334        P3: Into<PathBuf> + Clone,
335    {
336        let ca: PathBuf = ca_path.clone().into();
337        Self {
338            enabled: true,
339            certificate: Some(CertificateSource::File {
340                path: cert_path.into(),
341            }),
342            private_key: Some(PrivateKeySource::File {
343                path: key_path.into(),
344            }),
345            client_ca: Some(CertificateSource::File { path: ca.clone() }),
346            root_ca: Some(CertificateSource::File { path: ca }),
347            mtls_mode: MtlsMode::Required,
348            ..Default::default()
349        }
350    }
351}
352
353// ============================================================================
354// Builder Pattern for Complex Configurations
355// ============================================================================
356
/// Fluent builder for [`TlsConfig`].
///
/// Starts from [`TlsConfig::default`] with `enabled = true`. Every `with_*` method
/// consumes and returns the builder, and [`TlsConfigBuilder::build`] yields the config.
/// Building is infallible: certificates and keys are only loaded later, when the config
/// is turned into a [`TlsAcceptor`] or [`TlsConnector`].
pub struct TlsConfigBuilder {
    /// Configuration accumulated so far.
    config: TlsConfig,
}
361
362impl TlsConfigBuilder {
363    /// Create a new TLS configuration builder
364    pub fn new() -> Self {
365        Self {
366            config: TlsConfig {
367                enabled: true,
368                ..Default::default()
369            },
370        }
371    }
372
373    /// Set the server certificate from a file
374    pub fn with_cert_file<P: Into<PathBuf>>(mut self, path: P) -> Self {
375        self.config.certificate = Some(CertificateSource::File { path: path.into() });
376        self
377    }
378
379    /// Set the server certificate from PEM content
380    pub fn with_cert_pem(mut self, pem: String) -> Self {
381        self.config.certificate = Some(CertificateSource::Pem { content: pem });
382        self
383    }
384
385    /// Set the private key from a file
386    pub fn with_key_file<P: Into<PathBuf>>(mut self, path: P) -> Self {
387        self.config.private_key = Some(PrivateKeySource::File { path: path.into() });
388        self
389    }
390
391    /// Set the private key from PEM content
392    pub fn with_key_pem(mut self, pem: String) -> Self {
393        self.config.private_key = Some(PrivateKeySource::Pem { content: pem });
394        self
395    }
396
397    /// Set the root CA for server verification (client-side)
398    pub fn with_root_ca_file<P: Into<PathBuf>>(mut self, path: P) -> Self {
399        self.config.root_ca = Some(CertificateSource::File { path: path.into() });
400        self
401    }
402
403    /// Set the client CA for mTLS (server-side)
404    pub fn with_client_ca_file<P: Into<PathBuf>>(mut self, path: P) -> Self {
405        self.config.client_ca = Some(CertificateSource::File { path: path.into() });
406        self
407    }
408
409    /// Require client certificate (mTLS)
410    pub fn require_client_cert(mut self, required: bool) -> Self {
411        self.config.mtls_mode = if required {
412            MtlsMode::Required
413        } else {
414            MtlsMode::Disabled
415        };
416        self
417    }
418
419    /// Set mTLS mode
420    pub fn with_mtls_mode(mut self, mode: MtlsMode) -> Self {
421        self.config.mtls_mode = mode;
422        self
423    }
424
425    /// Set minimum TLS version
426    pub fn with_min_version(mut self, version: TlsVersion) -> Self {
427        self.config.min_version = version;
428        self
429    }
430
431    /// Set ALPN protocols
432    pub fn with_alpn_protocols(mut self, protocols: Vec<String>) -> Self {
433        self.config.alpn_protocols = protocols;
434        self
435    }
436
437    /// Set server name for SNI
438    pub fn with_server_name(mut self, name: String) -> Self {
439        self.config.server_name = Some(name);
440        self
441    }
442
443    /// Skip certificate verification (DANGEROUS - testing only)
444    pub fn insecure_skip_verify(mut self) -> Self {
445        self.config.insecure_skip_verify = true;
446        self
447    }
448
449    /// Add pinned certificate fingerprint
450    pub fn with_pinned_certificate(mut self, fingerprint: String) -> Self {
451        self.config.pinned_certificates.push(fingerprint);
452        self
453    }
454
455    /// Use self-signed certificate
456    pub fn with_self_signed(mut self, common_name: &str) -> Self {
457        self.config.certificate = Some(CertificateSource::SelfSigned {
458            common_name: common_name.to_string(),
459        });
460        self.config.insecure_skip_verify = true;
461        self
462    }
463
464    /// Enable certificate reloading
465    pub fn with_cert_reload_interval(mut self, interval: Duration) -> Self {
466        self.config.cert_reload_interval = interval;
467        self
468    }
469
470    /// Build the TLS configuration
471    pub fn build(self) -> TlsConfig {
472        self.config
473    }
474}
475
476impl Default for TlsConfigBuilder {
477    fn default() -> Self {
478        Self::new()
479    }
480}
481
482// ============================================================================
483// Certificate Loading Utilities
484// ============================================================================
485
486/// Load certificates from a source
487pub fn load_certificates(source: &CertificateSource) -> TlsResult<Vec<CertificateDer<'static>>> {
488    match source {
489        CertificateSource::File { path } => {
490            let data = fs::read(path).map_err(|e| TlsError::CertificateReadError {
491                path: path.clone(),
492                source: e,
493            })?;
494            parse_pem_certificates(&data)
495        }
496        CertificateSource::Pem { content } => parse_pem_certificates(content.as_bytes()),
497        CertificateSource::Der { content } => {
498            let der =
499                base64::Engine::decode(&base64::engine::general_purpose::STANDARD, content)
500                    .map_err(|e| TlsError::InvalidCertificate(format!("Invalid base64: {}", e)))?;
501            Ok(vec![CertificateDer::from(der)])
502        }
503        CertificateSource::SelfSigned { common_name } => {
504            let (cert, _key) = generate_self_signed(common_name)?;
505            Ok(vec![cert])
506        }
507    }
508}
509
510/// Parse PEM-encoded certificates
511fn parse_pem_certificates(data: &[u8]) -> TlsResult<Vec<CertificateDer<'static>>> {
512    let mut reader = BufReader::new(Cursor::new(data));
513    let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut reader)
514        .collect::<Result<Vec<_>, _>>()
515        .map_err(|e| TlsError::InvalidCertificate(format!("Failed to parse PEM: {}", e)))?;
516
517    if certs.is_empty() {
518        return Err(TlsError::InvalidCertificate(
519            "No certificates found in PEM data".to_string(),
520        ));
521    }
522
523    Ok(certs)
524}
525
526/// Load private key from a source
527pub fn load_private_key(source: &PrivateKeySource) -> TlsResult<PrivateKeyDer<'static>> {
528    match source {
529        PrivateKeySource::File { path } => {
530            let data = fs::read(path).map_err(|e| TlsError::KeyReadError {
531                path: path.clone(),
532                source: e,
533            })?;
534            parse_pem_private_key(&data)
535        }
536        PrivateKeySource::Pem { content } => parse_pem_private_key(content.as_bytes()),
537        PrivateKeySource::Der { content } => {
538            let der = base64::Engine::decode(&base64::engine::general_purpose::STANDARD, content)
539                .map_err(|e| TlsError::InvalidPrivateKey(format!("Invalid base64: {}", e)))?;
540            Ok(PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(der)))
541        }
542    }
543}
544
545/// Parse PEM-encoded private key
546fn parse_pem_private_key(data: &[u8]) -> TlsResult<PrivateKeyDer<'static>> {
547    let mut reader = BufReader::new(Cursor::new(data));
548
549    rustls_pemfile::private_key(&mut reader)
550        .map_err(|e| TlsError::InvalidPrivateKey(format!("Failed to parse PEM: {}", e)))?
551        .ok_or_else(|| TlsError::InvalidPrivateKey("No private key found in PEM data".to_string()))
552}
553
554/// Generate self-signed certificate for development/testing
555pub fn generate_self_signed(
556    common_name: &str,
557) -> TlsResult<(CertificateDer<'static>, PrivateKeyDer<'static>)> {
558    let subject_alt_names = vec![
559        common_name.to_string(),
560        "localhost".to_string(),
561        "127.0.0.1".to_string(),
562    ];
563
564    let mut cert_params = rcgen::CertificateParams::new(subject_alt_names)
565        .map_err(|e| TlsError::SelfSignedGenerationError(e.to_string()))?;
566
567    // Set the distinguished name with proper common name
568    cert_params.distinguished_name = rcgen::DistinguishedName::new();
569    cert_params.distinguished_name.push(
570        rcgen::DnType::CommonName,
571        rcgen::DnValue::Utf8String(common_name.to_string()),
572    );
573    cert_params.distinguished_name.push(
574        rcgen::DnType::OrganizationName,
575        rcgen::DnValue::Utf8String("Rivven".to_string()),
576    );
577
578    let key_pair = rcgen::KeyPair::generate()
579        .map_err(|e| TlsError::SelfSignedGenerationError(e.to_string()))?;
580
581    let cert = cert_params
582        .self_signed(&key_pair)
583        .map_err(|e| TlsError::SelfSignedGenerationError(e.to_string()))?;
584
585    let cert_der = CertificateDer::from(cert.der().to_vec());
586    let key_der = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key_pair.serialize_der()));
587
588    Ok((cert_der, key_der))
589}
590
591// ============================================================================
592// Server-Side TLS (Acceptor)
593// ============================================================================
594
/// Server-side TLS acceptor built from a [`TlsConfig`].
///
/// Wraps a `tokio_rustls::TlsAcceptor`. The source [`TlsConfig`] is retained so that
/// [`TlsAcceptor::reload`] can rebuild the rustls config, e.g. after certificate rotation.
pub struct TlsAcceptor {
    /// Most recently built rustls server config, returned by [`TlsAcceptor::config`].
    config: Arc<ServerConfig>,
    /// Acceptor used by `accept` when hot-reloading is disabled.
    inner: tokio_rustls::TlsAcceptor,
    /// Configuration for hot-reloading: the input that `reload` rebuilds from.
    tls_config: TlsConfig,
    /// Reloadable config handle. It is `Some` only when `cert_reload_interval` is
    /// non-zero, and when present `accept` reads the config from here instead of `inner`.
    reloadable_config: Option<Arc<RwLock<Arc<ServerConfig>>>>,
}
604
605impl TlsAcceptor {
606    /// Create a new TLS acceptor from configuration
607    pub fn new(config: &TlsConfig) -> TlsResult<Self> {
608        let server_config = build_server_config(config)?;
609        let server_config = Arc::new(server_config);
610
611        Ok(Self {
612            inner: tokio_rustls::TlsAcceptor::from(server_config.clone()),
613            config: server_config.clone(),
614            tls_config: config.clone(),
615            reloadable_config: if config.cert_reload_interval > Duration::ZERO {
616                Some(Arc::new(RwLock::new(server_config)))
617            } else {
618                None
619            },
620        })
621    }
622
623    /// Accept a TLS connection
624    ///
625    /// When hot-reloading is enabled, this uses the latest reloaded ServerConfig
626    /// so that new connections pick up rotated certificates immediately.
627    pub async fn accept<IO>(&self, stream: IO) -> TlsResult<TlsServerStream<IO>>
628    where
629        IO: AsyncRead + AsyncWrite + Unpin,
630    {
631        let acceptor = if let Some(ref reloadable) = self.reloadable_config {
632            let config = reloadable.read().clone();
633            tokio_rustls::TlsAcceptor::from(config)
634        } else {
635            self.inner.clone()
636        };
637
638        let tls_stream = acceptor
639            .accept(stream)
640            .await
641            .map_err(|e| TlsError::HandshakeError(e.to_string()))?;
642
643        Ok(TlsServerStream { inner: tls_stream })
644    }
645
646    /// Accept a TCP connection with TLS
647    pub async fn accept_tcp(&self, stream: TcpStream) -> TlsResult<TlsServerStream<TcpStream>> {
648        self.accept(stream).await
649    }
650
651    /// Reload certificates (for hot-reloading)
652    pub fn reload(&mut self) -> TlsResult<()> {
653        let new_config = build_server_config(&self.tls_config)?;
654        let new_config = Arc::new(new_config);
655
656        // Update the active acceptor and config
657        self.inner = tokio_rustls::TlsAcceptor::from(new_config.clone());
658        self.config = new_config.clone();
659
660        if let Some(ref reloadable) = self.reloadable_config {
661            *reloadable.write() = new_config;
662        }
663        Ok(())
664    }
665
666    /// Get the underlying rustls ServerConfig
667    pub fn config(&self) -> &Arc<ServerConfig> {
668        &self.config
669    }
670}
671
672impl fmt::Debug for TlsAcceptor {
673    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
674        f.debug_struct("TlsAcceptor")
675            .field("mtls_mode", &self.tls_config.mtls_mode)
676            .field("min_version", &self.tls_config.min_version)
677            .finish()
678    }
679}
680
681/// Build rustls ServerConfig from TlsConfig
682fn build_server_config(config: &TlsConfig) -> TlsResult<ServerConfig> {
683    // Handle self-signed certificates specially to ensure cert and key match
684    let (certs, key) =
685        if let Some(CertificateSource::SelfSigned { common_name }) = &config.certificate {
686            // Generate both cert and key together to ensure they match
687            let (cert, key) = generate_self_signed(common_name)?;
688            (vec![cert], key)
689        } else {
690            // Load certificates from explicit sources
691            let certs = if let Some(ref cert_source) = config.certificate {
692                load_certificates(cert_source)?
693            } else {
694                return Err(TlsError::ConfigError(
695                    "Server certificate required".to_string(),
696                ));
697            };
698
699            // Load private key
700            let key = if let Some(ref key_source) = config.private_key {
701                load_private_key(key_source)?
702            } else {
703                return Err(TlsError::ConfigError("Private key required".to_string()));
704            };
705
706            (certs, key)
707        };
708
709    // Build TLS versions
710    let versions: Vec<&'static rustls::SupportedProtocolVersion> = match config.min_version {
711        TlsVersion::Tls13 => vec![&rustls::version::TLS13],
712        TlsVersion::Tls12 => vec![&rustls::version::TLS12, &rustls::version::TLS13],
713    };
714
715    // Configure client certificate verification
716    let client_cert_verifier = match config.mtls_mode {
717        MtlsMode::Disabled => None,
718        MtlsMode::Optional | MtlsMode::Required => {
719            if let Some(ref ca_source) = config.client_ca {
720                let ca_certs = load_certificates(ca_source)?;
721                let mut root_store = rustls::RootCertStore::empty();
722                for cert in ca_certs {
723                    root_store.add(cert).map_err(|e| {
724                        TlsError::CertificateChainError(format!("Failed to add CA cert: {}", e))
725                    })?;
726                }
727
728                let verifier = if config.mtls_mode == MtlsMode::Required {
729                    rustls::server::WebPkiClientVerifier::builder(Arc::new(root_store))
730                        .build()
731                        .map_err(|e| {
732                            TlsError::ConfigError(format!("Failed to build client verifier: {}", e))
733                        })?
734                } else {
735                    rustls::server::WebPkiClientVerifier::builder(Arc::new(root_store))
736                        .allow_unauthenticated()
737                        .build()
738                        .map_err(|e| {
739                            TlsError::ConfigError(format!("Failed to build client verifier: {}", e))
740                        })?
741                };
742
743                Some(verifier)
744            } else if config.mtls_mode == MtlsMode::Required {
745                return Err(TlsError::ConfigError(
746                    "mTLS required but no client CA configured".to_string(),
747                ));
748            } else {
749                None
750            }
751        }
752    };
753
754    // Build server config
755    let mut server_config = if let Some(verifier) = client_cert_verifier {
756        ServerConfig::builder_with_protocol_versions(&versions)
757            .with_client_cert_verifier(verifier)
758            .with_single_cert(certs, key)
759            .map_err(|e| TlsError::ConfigError(format!("Invalid cert/key: {}", e)))?
760    } else {
761        ServerConfig::builder_with_protocol_versions(&versions)
762            .with_no_client_auth()
763            .with_single_cert(certs, key)
764            .map_err(|e| TlsError::ConfigError(format!("Invalid cert/key: {}", e)))?
765    };
766
767    // Configure ALPN
768    if !config.alpn_protocols.is_empty() {
769        server_config.alpn_protocols = config
770            .alpn_protocols
771            .iter()
772            .map(|p| p.as_bytes().to_vec())
773            .collect();
774    }
775
776    // Configure session cache
777    if config.session_cache_size > 0 {
778        server_config.session_storage =
779            rustls::server::ServerSessionMemoryCache::new(config.session_cache_size);
780    }
781
782    Ok(server_config)
783}
784
785// ============================================================================
786// Client-Side TLS (Connector)
787// ============================================================================
788
/// Client-side TLS connector built from a [`TlsConfig`].
pub struct TlsConnector {
    /// rustls client config, returned by [`TlsConnector::config`].
    config: Arc<ClientConfig>,
    /// tokio-rustls connector sharing `config`.
    inner: tokio_rustls::TlsConnector,
    /// Default server name for SNI, copied from `TlsConfig::server_name`. Used by
    /// [`TlsConnector::connect_with_default_name`].
    server_name: Option<String>,
}
796
797impl TlsConnector {
798    /// Create a new TLS connector from configuration
799    pub fn new(config: &TlsConfig) -> TlsResult<Self> {
800        let client_config = build_client_config(config)?;
801        let client_config = Arc::new(client_config);
802
803        Ok(Self {
804            inner: tokio_rustls::TlsConnector::from(client_config.clone()),
805            config: client_config,
806            server_name: config.server_name.clone(),
807        })
808    }
809
810    /// Connect to a server with TLS
811    pub async fn connect<IO>(&self, stream: IO, server_name: &str) -> TlsResult<TlsClientStream<IO>>
812    where
813        IO: AsyncRead + AsyncWrite + Unpin,
814    {
815        let name: rustls::pki_types::ServerName<'static> = server_name
816            .to_string()
817            .try_into()
818            .map_err(|_| TlsError::ConfigError(format!("Invalid server name: {}", server_name)))?;
819
820        let tls_stream = self
821            .inner
822            .connect(name, stream)
823            .await
824            .map_err(|e| TlsError::HandshakeError(e.to_string()))?;
825
826        Ok(TlsClientStream { inner: tls_stream })
827    }
828
829    /// Connect to a TCP address with TLS
830    pub async fn connect_tcp(
831        &self,
832        addr: SocketAddr,
833        server_name: &str,
834    ) -> TlsResult<TlsClientStream<TcpStream>> {
835        let stream = TcpStream::connect(addr)
836            .await
837            .map_err(|e| TlsError::ConnectionError(e.to_string()))?;
838
839        self.connect(stream, server_name).await
840    }
841
842    /// Connect using the configured server name
843    pub async fn connect_with_default_name<IO>(&self, stream: IO) -> TlsResult<TlsClientStream<IO>>
844    where
845        IO: AsyncRead + AsyncWrite + Unpin,
846    {
847        let name = self.server_name.as_ref().ok_or_else(|| {
848            TlsError::ConfigError("No server name configured for SNI".to_string())
849        })?;
850        self.connect(stream, name).await
851    }
852
853    /// Get the underlying rustls ClientConfig
854    pub fn config(&self) -> &Arc<ClientConfig> {
855        &self.config
856    }
857}
858
859impl fmt::Debug for TlsConnector {
860    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
861        f.debug_struct("TlsConnector")
862            .field("server_name", &self.server_name)
863            .finish()
864    }
865}
866
867/// Build rustls ClientConfig from TlsConfig
868fn build_client_config(config: &TlsConfig) -> TlsResult<ClientConfig> {
869    // Build TLS versions
870    let versions: Vec<&'static rustls::SupportedProtocolVersion> = match config.min_version {
871        TlsVersion::Tls13 => vec![&rustls::version::TLS13],
872        TlsVersion::Tls12 => vec![&rustls::version::TLS12, &rustls::version::TLS13],
873    };
874
875    // Build root certificate store
876    let root_store = if config.insecure_skip_verify {
877        // DANGEROUS: Trust all certificates (development only)
878        rustls::RootCertStore::empty()
879    } else if let Some(ref ca_source) = config.root_ca {
880        let ca_certs = load_certificates(ca_source)?;
881        let mut store = rustls::RootCertStore::empty();
882        for cert in ca_certs {
883            store.add(cert).map_err(|e| {
884                TlsError::CertificateChainError(format!("Failed to add root CA: {}", e))
885            })?;
886        }
887        store
888    } else {
889        // Use system root certificates
890        let mut store = rustls::RootCertStore::empty();
891        let native_certs = rustls_native_certs::load_native_certs();
892        for cert in native_certs.certs {
893            let _ = store.add(cert);
894        }
895        store
896    };
897
898    // Build client config with or without client certificate
899    let mut client_config = if let (Some(ref cert_source), Some(ref key_source)) =
900        (&config.certificate, &config.private_key)
901    {
902        // mTLS: provide client certificate
903        let certs = load_certificates(cert_source)?;
904        let key = load_private_key(key_source)?;
905
906        ClientConfig::builder_with_protocol_versions(&versions)
907            .with_root_certificates(root_store)
908            .with_client_auth_cert(certs, key)
909            .map_err(|e| TlsError::ConfigError(format!("Invalid client cert/key: {}", e)))?
910    } else if config.insecure_skip_verify {
911        // DANGEROUS: Skip server verification
912        ClientConfig::builder_with_protocol_versions(&versions)
913            .dangerous()
914            .with_custom_certificate_verifier(Arc::new(NoCertificateVerification))
915            .with_no_client_auth()
916    } else {
917        // Standard TLS without client certificate
918        ClientConfig::builder_with_protocol_versions(&versions)
919            .with_root_certificates(root_store)
920            .with_no_client_auth()
921    };
922
923    // Configure ALPN
924    if !config.alpn_protocols.is_empty() {
925        client_config.alpn_protocols = config
926            .alpn_protocols
927            .iter()
928            .map(|p| p.as_bytes().to_vec())
929            .collect();
930    }
931
932    Ok(client_config)
933}
934
/// DANGEROUS: a server-certificate verifier that accepts every certificate
/// and every handshake signature without inspecting them.
///
/// `build_client_config` installs it when `insecure_skip_verify` is set.
/// It provides no protection against man-in-the-middle attacks and must only
/// be used for development/testing.
#[derive(Debug)]
struct NoCertificateVerification;
939
940impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification {
941    fn verify_server_cert(
942        &self,
943        _end_entity: &CertificateDer<'_>,
944        _intermediates: &[CertificateDer<'_>],
945        _server_name: &rustls::pki_types::ServerName<'_>,
946        _ocsp_response: &[u8],
947        _now: rustls::pki_types::UnixTime,
948    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
949        Ok(rustls::client::danger::ServerCertVerified::assertion())
950    }
951
952    fn verify_tls12_signature(
953        &self,
954        _message: &[u8],
955        _cert: &CertificateDer<'_>,
956        _dss: &rustls::DigitallySignedStruct,
957    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
958        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
959    }
960
961    fn verify_tls13_signature(
962        &self,
963        _message: &[u8],
964        _cert: &CertificateDer<'_>,
965        _dss: &rustls::DigitallySignedStruct,
966    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
967        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
968    }
969
970    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
971        vec![
972            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
973            rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
974            rustls::SignatureScheme::ECDSA_NISTP521_SHA512,
975            rustls::SignatureScheme::RSA_PSS_SHA256,
976            rustls::SignatureScheme::RSA_PSS_SHA384,
977            rustls::SignatureScheme::RSA_PSS_SHA512,
978            rustls::SignatureScheme::RSA_PKCS1_SHA256,
979            rustls::SignatureScheme::RSA_PKCS1_SHA384,
980            rustls::SignatureScheme::RSA_PKCS1_SHA512,
981            rustls::SignatureScheme::ED25519,
982        ]
983    }
984}
985
986// ============================================================================
987// TLS Streams
988// ============================================================================
989
990/// Server-side TLS stream
991pub struct TlsServerStream<IO> {
992    inner: tokio_rustls::server::TlsStream<IO>,
993}
994
995impl<IO> TlsServerStream<IO>
996where
997    IO: AsyncRead + AsyncWrite + Unpin,
998{
999    /// Get client certificate if presented
1000    pub fn peer_certificates(&self) -> Option<&[CertificateDer<'_>]> {
1001        self.inner.get_ref().1.peer_certificates()
1002    }
1003
1004    /// Get ALPN protocol if negotiated
1005    pub fn alpn_protocol(&self) -> Option<&[u8]> {
1006        self.inner.get_ref().1.alpn_protocol()
1007    }
1008
1009    /// Get negotiated protocol version
1010    pub fn protocol_version(&self) -> Option<rustls::ProtocolVersion> {
1011        self.inner.get_ref().1.protocol_version()
1012    }
1013
1014    /// Get negotiated cipher suite
1015    pub fn negotiated_cipher_suite(&self) -> Option<rustls::SupportedCipherSuite> {
1016        self.inner.get_ref().1.negotiated_cipher_suite()
1017    }
1018
1019    /// Get cipher suite name as string
1020    pub fn cipher_suite_name(&self) -> Option<String> {
1021        self.negotiated_cipher_suite()
1022            .map(|cs| format!("{:?}", cs.suite()))
1023    }
1024
1025    /// Check if the connection uses TLS 1.3
1026    pub fn is_tls_13(&self) -> bool {
1027        self.protocol_version() == Some(rustls::ProtocolVersion::TLSv1_3)
1028    }
1029
1030    /// Extract the client certificate common name (CN)
1031    pub fn peer_common_name(&self) -> Option<String> {
1032        self.peer_certificates().and_then(|certs| {
1033            if certs.is_empty() {
1034                return None;
1035            }
1036            extract_common_name(&certs[0])
1037        })
1038    }
1039
1040    /// Extract the client certificate subject
1041    pub fn peer_subject(&self) -> Option<String> {
1042        self.peer_certificates().and_then(|certs| {
1043            if certs.is_empty() {
1044                return None;
1045            }
1046            extract_subject(&certs[0])
1047        })
1048    }
1049
1050    /// Get reference to the inner stream
1051    pub fn get_ref(&self) -> &IO {
1052        self.inner.get_ref().0
1053    }
1054
1055    /// Get mutable reference to the inner stream
1056    pub fn get_mut(&mut self) -> &mut IO {
1057        self.inner.get_mut().0
1058    }
1059
1060    /// Unwrap and get the inner stream
1061    pub fn into_inner(self) -> IO {
1062        self.inner.into_inner().0
1063    }
1064}
1065
1066impl<IO> tokio::io::AsyncRead for TlsServerStream<IO>
1067where
1068    IO: AsyncRead + AsyncWrite + Unpin,
1069{
1070    fn poll_read(
1071        mut self: std::pin::Pin<&mut Self>,
1072        cx: &mut std::task::Context<'_>,
1073        buf: &mut ReadBuf<'_>,
1074    ) -> std::task::Poll<io::Result<()>> {
1075        std::pin::Pin::new(&mut self.inner).poll_read(cx, buf)
1076    }
1077}
1078
1079impl<IO> tokio::io::AsyncWrite for TlsServerStream<IO>
1080where
1081    IO: AsyncRead + AsyncWrite + Unpin,
1082{
1083    fn poll_write(
1084        mut self: std::pin::Pin<&mut Self>,
1085        cx: &mut std::task::Context<'_>,
1086        buf: &[u8],
1087    ) -> std::task::Poll<io::Result<usize>> {
1088        std::pin::Pin::new(&mut self.inner).poll_write(cx, buf)
1089    }
1090
1091    fn poll_flush(
1092        mut self: std::pin::Pin<&mut Self>,
1093        cx: &mut std::task::Context<'_>,
1094    ) -> std::task::Poll<io::Result<()>> {
1095        std::pin::Pin::new(&mut self.inner).poll_flush(cx)
1096    }
1097
1098    fn poll_shutdown(
1099        mut self: std::pin::Pin<&mut Self>,
1100        cx: &mut std::task::Context<'_>,
1101    ) -> std::task::Poll<io::Result<()>> {
1102        std::pin::Pin::new(&mut self.inner).poll_shutdown(cx)
1103    }
1104}
1105
1106/// Client-side TLS stream
1107pub struct TlsClientStream<IO> {
1108    inner: tokio_rustls::client::TlsStream<IO>,
1109}
1110
1111impl<IO> TlsClientStream<IO>
1112where
1113    IO: AsyncRead + AsyncWrite + Unpin,
1114{
1115    /// Get server certificate if provided
1116    pub fn peer_certificates(&self) -> Option<&[CertificateDer<'_>]> {
1117        self.inner.get_ref().1.peer_certificates()
1118    }
1119
1120    /// Get ALPN protocol if negotiated
1121    pub fn alpn_protocol(&self) -> Option<&[u8]> {
1122        self.inner.get_ref().1.alpn_protocol()
1123    }
1124
1125    /// Get negotiated protocol version
1126    pub fn protocol_version(&self) -> Option<rustls::ProtocolVersion> {
1127        self.inner.get_ref().1.protocol_version()
1128    }
1129
1130    /// Check if the connection uses TLS 1.3
1131    pub fn is_tls_13(&self) -> bool {
1132        self.protocol_version() == Some(rustls::ProtocolVersion::TLSv1_3)
1133    }
1134
1135    /// Get reference to the inner stream
1136    pub fn get_ref(&self) -> &IO {
1137        self.inner.get_ref().0
1138    }
1139
1140    /// Get mutable reference to the inner stream
1141    pub fn get_mut(&mut self) -> &mut IO {
1142        self.inner.get_mut().0
1143    }
1144
1145    /// Unwrap and get the inner stream
1146    pub fn into_inner(self) -> IO {
1147        self.inner.into_inner().0
1148    }
1149}
1150
1151impl<IO> tokio::io::AsyncRead for TlsClientStream<IO>
1152where
1153    IO: AsyncRead + AsyncWrite + Unpin,
1154{
1155    fn poll_read(
1156        mut self: std::pin::Pin<&mut Self>,
1157        cx: &mut std::task::Context<'_>,
1158        buf: &mut ReadBuf<'_>,
1159    ) -> std::task::Poll<io::Result<()>> {
1160        std::pin::Pin::new(&mut self.inner).poll_read(cx, buf)
1161    }
1162}
1163
1164impl<IO> tokio::io::AsyncWrite for TlsClientStream<IO>
1165where
1166    IO: AsyncRead + AsyncWrite + Unpin,
1167{
1168    fn poll_write(
1169        mut self: std::pin::Pin<&mut Self>,
1170        cx: &mut std::task::Context<'_>,
1171        buf: &[u8],
1172    ) -> std::task::Poll<io::Result<usize>> {
1173        std::pin::Pin::new(&mut self.inner).poll_write(cx, buf)
1174    }
1175
1176    fn poll_flush(
1177        mut self: std::pin::Pin<&mut Self>,
1178        cx: &mut std::task::Context<'_>,
1179    ) -> std::task::Poll<io::Result<()>> {
1180        std::pin::Pin::new(&mut self.inner).poll_flush(cx)
1181    }
1182
1183    fn poll_shutdown(
1184        mut self: std::pin::Pin<&mut Self>,
1185        cx: &mut std::task::Context<'_>,
1186    ) -> std::task::Poll<io::Result<()>> {
1187        std::pin::Pin::new(&mut self.inner).poll_shutdown(cx)
1188    }
1189}
1190
1191// ============================================================================
1192// Certificate Utilities
1193// ============================================================================
1194
1195/// Extract common name (CN) from certificate
1196fn extract_common_name(cert: &CertificateDer<'_>) -> Option<String> {
1197    // Parse the certificate using x509-parser
1198    let (_, cert) = x509_parser::parse_x509_certificate(cert.as_ref()).ok()?;
1199
1200    for rdn in cert.subject().iter_rdn() {
1201        for attr in rdn.iter() {
1202            if attr.attr_type() == &x509_parser::oid_registry::OID_X509_COMMON_NAME {
1203                return attr.as_str().ok().map(|s| s.to_string());
1204            }
1205        }
1206    }
1207
1208    None
1209}
1210
1211/// Extract full subject from certificate
1212fn extract_subject(cert: &CertificateDer<'_>) -> Option<String> {
1213    let (_, cert) = x509_parser::parse_x509_certificate(cert.as_ref()).ok()?;
1214    Some(cert.subject().to_string())
1215}
1216
1217/// Calculate SHA-256 fingerprint of a certificate
1218pub fn certificate_fingerprint(cert: &CertificateDer<'_>) -> String {
1219    use sha2::{Digest, Sha256};
1220    let hash = Sha256::digest(cert.as_ref());
1221    hex::encode(hash)
1222}
1223
1224/// Verify certificate chain
1225///
1226/// Note: This is a basic sanity check. The actual TLS handshake performs
1227/// full chain validation using WebPKI through rustls.
1228pub fn verify_certificate_chain(
1229    chain: &[CertificateDer<'_>],
1230    root_store: &rustls::RootCertStore,
1231) -> TlsResult<()> {
1232    if chain.is_empty() {
1233        return Err(TlsError::CertificateChainError(
1234            "Empty certificate chain".to_string(),
1235        ));
1236    }
1237
1238    // Basic sanity check - the actual validation happens during TLS handshake
1239    // via rustls WebPKI implementation
1240    if root_store.is_empty() {
1241        tracing::warn!("Root certificate store is empty - chain validation may fail");
1242    }
1243
1244    // Log certificate chain info for debugging
1245    for (i, cert) in chain.iter().enumerate() {
1246        let fingerprint = certificate_fingerprint(cert);
1247        tracing::debug!(
1248            "Certificate chain[{}]: fingerprint={}",
1249            i,
1250            &fingerprint[..16]
1251        );
1252    }
1253
1254    Ok(())
1255}
1256
1257// ============================================================================
1258// Certificate Watcher for Hot Reloading
1259// ============================================================================
1260
/// Polls the modification times of certificate files and invokes a reload
/// callback when a change is detected.
///
/// Polling is driven either manually with `check_and_reload` or in the
/// background with `spawn`.
pub struct CertificateWatcher {
    /// Paths that are polled.
    watched_files: Vec<PathBuf>,
    /// Last observed modification time per path (missing if never readable).
    last_modified: HashMap<PathBuf, SystemTime>,
    /// Invoked (with no arguments) after a change is detected.
    reload_callback: Box<dyn Fn() + Send + Sync>,
}
1270
1271impl CertificateWatcher {
1272    /// Create a new certificate watcher
1273    pub fn new<F>(files: Vec<PathBuf>, callback: F) -> Self
1274    where
1275        F: Fn() + Send + Sync + 'static,
1276    {
1277        let mut last_modified = HashMap::new();
1278        for file in &files {
1279            if let Ok(meta) = fs::metadata(file) {
1280                if let Ok(modified) = meta.modified() {
1281                    last_modified.insert(file.clone(), modified);
1282                }
1283            }
1284        }
1285
1286        Self {
1287            watched_files: files,
1288            last_modified,
1289            reload_callback: Box::new(callback),
1290        }
1291    }
1292
1293    /// Check for file changes and trigger reload if needed
1294    pub fn check_and_reload(&mut self) -> bool {
1295        let mut changed = false;
1296
1297        for file in &self.watched_files {
1298            if let Ok(meta) = fs::metadata(file) {
1299                if let Ok(modified) = meta.modified() {
1300                    let last = self.last_modified.get(file);
1301                    if last.is_none_or(|&l| modified > l) {
1302                        self.last_modified.insert(file.clone(), modified);
1303                        changed = true;
1304                    }
1305                }
1306            }
1307        }
1308
1309        if changed {
1310            (self.reload_callback)();
1311        }
1312
1313        changed
1314    }
1315
1316    /// Start watching in background
1317    pub fn spawn(mut self, interval: Duration) -> tokio::task::JoinHandle<()> {
1318        tokio::spawn(async move {
1319            let mut ticker = tokio::time::interval(interval);
1320            loop {
1321                ticker.tick().await;
1322                self.check_and_reload();
1323            }
1324        })
1325    }
1326}
1327
1328// ============================================================================
1329// Connection Identity (mTLS Integration)
1330// ============================================================================
1331
1332/// Identity extracted from TLS connection
1333#[derive(Debug, Clone, Serialize, Deserialize)]
1334pub struct TlsIdentity {
1335    /// Certificate common name (CN)
1336    pub common_name: Option<String>,
1337    /// Full certificate subject
1338    pub subject: Option<String>,
1339    /// Certificate fingerprint (SHA-256)
1340    pub fingerprint: String,
1341    /// Organization from certificate
1342    pub organization: Option<String>,
1343    /// Organizational unit from certificate
1344    pub organizational_unit: Option<String>,
1345    /// Certificate serial number
1346    pub serial_number: Option<String>,
1347    /// Certificate validity period
1348    pub valid_from: Option<chrono::DateTime<chrono::Utc>>,
1349    pub valid_until: Option<chrono::DateTime<chrono::Utc>>,
1350    /// Is the certificate still valid
1351    pub is_valid: bool,
1352}
1353
1354impl TlsIdentity {
1355    /// Extract identity from a certificate
1356    pub fn from_certificate(cert: &CertificateDer<'_>) -> Self {
1357        let fingerprint = certificate_fingerprint(cert);
1358        let common_name = extract_common_name(cert);
1359        let subject = extract_subject(cert);
1360
1361        // Parse additional fields using x509-parser
1362        let (organization, organizational_unit, serial_number, valid_from, valid_until, is_valid) =
1363            if let Ok((_, parsed)) = x509_parser::parse_x509_certificate(cert.as_ref()) {
1364                let mut org = None;
1365                let mut ou = None;
1366
1367                for rdn in parsed.subject().iter_rdn() {
1368                    for attr in rdn.iter() {
1369                        if attr.attr_type()
1370                            == &x509_parser::oid_registry::OID_X509_ORGANIZATION_NAME
1371                        {
1372                            org = attr.as_str().ok().map(|s| s.to_string());
1373                        }
1374                        if attr.attr_type()
1375                            == &x509_parser::oid_registry::OID_X509_ORGANIZATIONAL_UNIT
1376                        {
1377                            ou = attr.as_str().ok().map(|s| s.to_string());
1378                        }
1379                    }
1380                }
1381
1382                let serial = Some(parsed.serial.to_str_radix(16));
1383
1384                let validity = parsed.validity();
1385                let now = chrono::Utc::now();
1386
1387                let from = chrono::DateTime::from_timestamp(validity.not_before.timestamp(), 0);
1388                let until = chrono::DateTime::from_timestamp(validity.not_after.timestamp(), 0);
1389
1390                let valid = from.is_some_and(|f| now >= f) && until.is_some_and(|u| now <= u);
1391
1392                (org, ou, serial, from, until, valid)
1393            } else {
1394                (None, None, None, None, None, false)
1395            };
1396
1397        Self {
1398            common_name,
1399            subject,
1400            fingerprint,
1401            organization,
1402            organizational_unit,
1403            serial_number,
1404            valid_from,
1405            valid_until,
1406            is_valid,
1407        }
1408    }
1409}
1410
1411// ============================================================================
1412// Security Best Practices
1413// ============================================================================
1414
/// Findings from a static security review of a `TlsConfig`
/// (see `TlsSecurityAudit::audit`).
#[derive(Debug)]
pub struct TlsSecurityAudit {
    /// Settings that are weaker than recommended.
    pub warnings: Vec<String>,
    /// Settings that leave traffic unencrypted or open to MITM attacks.
    pub errors: Vec<String>,
    /// Optional hardening or performance suggestions.
    pub recommendations: Vec<String>,
}
1422
1423impl TlsSecurityAudit {
1424    /// Audit a TLS configuration for security issues
1425    pub fn audit(config: &TlsConfig) -> Self {
1426        let mut audit = Self {
1427            warnings: vec![],
1428            errors: vec![],
1429            recommendations: vec![],
1430        };
1431
1432        if !config.enabled {
1433            audit
1434                .errors
1435                .push("TLS is disabled - all traffic will be unencrypted".to_string());
1436        }
1437
1438        if config.insecure_skip_verify {
1439            audit.errors.push(
1440                "Certificate verification is disabled - vulnerable to MITM attacks".to_string(),
1441            );
1442        }
1443
1444        if config.min_version == TlsVersion::Tls12 {
1445            audit.warnings.push(
1446                "TLS 1.2 is allowed - consider requiring TLS 1.3 for better security".to_string(),
1447            );
1448        }
1449
1450        if config.mtls_mode == MtlsMode::Disabled && config.client_ca.is_some() {
1451            audit.warnings.push(
1452                "Client CA configured but mTLS is disabled - clients won't be verified".to_string(),
1453            );
1454        }
1455
1456        if config.mtls_mode == MtlsMode::Optional {
1457            audit.warnings.push(
1458                "mTLS is optional - some clients may connect without certificates".to_string(),
1459            );
1460        }
1461
1462        if config.session_cache_size == 0 {
1463            audit
1464                .recommendations
1465                .push("Consider enabling session cache for better performance".to_string());
1466        }
1467
1468        if config.cert_reload_interval == Duration::ZERO {
1469            audit.recommendations.push(
1470                "Consider enabling certificate hot-reloading for zero-downtime rotation"
1471                    .to_string(),
1472            );
1473        }
1474
1475        if config.pinned_certificates.is_empty() && !config.insecure_skip_verify {
1476            audit
1477                .recommendations
1478                .push("Consider certificate pinning for high-security deployments".to_string());
1479        }
1480
1481        audit
1482    }
1483}
1484
1485// ============================================================================
1486// Tests
1487// ============================================================================
1488
1489#[cfg(test)]
1490mod tests {
1491    use super::*;
1492    use tokio::io::{AsyncReadExt, AsyncWriteExt};
1493
1494    #[test]
1495    fn test_tls_config_default() {
1496        let config = TlsConfig::default();
1497        assert!(!config.enabled);
1498        assert_eq!(config.mtls_mode, MtlsMode::Disabled);
1499        assert_eq!(config.min_version, TlsVersion::Tls13);
1500    }
1501
1502    #[test]
1503    fn test_tls_config_builder() {
1504        let config = TlsConfigBuilder::new()
1505            .with_cert_file("/path/to/cert.pem")
1506            .with_key_file("/path/to/key.pem")
1507            .with_client_ca_file("/path/to/ca.pem")
1508            .require_client_cert(true)
1509            .with_min_version(TlsVersion::Tls12)
1510            .build();
1511
1512        assert!(config.enabled);
1513        assert_eq!(config.mtls_mode, MtlsMode::Required);
1514        assert_eq!(config.min_version, TlsVersion::Tls12);
1515    }
1516
1517    #[tokio::test]
1518    async fn test_tls_server_client_handshake() {
1519        // Install crypto provider (required by rustls 0.23+)
1520        // Use aws_lc_rs which is the default in rustls 0.23+
1521        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
1522
1523        // Use SelfSigned source which generates at runtime
1524        let server_config = TlsConfig {
1525            enabled: true,
1526            certificate: Some(CertificateSource::SelfSigned {
1527                common_name: "localhost".to_string(),
1528            }),
1529            // Key is auto-generated with self-signed
1530            mtls_mode: MtlsMode::Disabled,
1531            ..Default::default()
1532        };
1533
1534        // Create client config that skips verification (for self-signed)
1535        let client_config = TlsConfig {
1536            enabled: true,
1537            insecure_skip_verify: true,
1538            ..Default::default()
1539        };
1540
1541        // Create acceptor and connector
1542        let acceptor = TlsAcceptor::new(&server_config).unwrap();
1543        let connector = TlsConnector::new(&client_config).unwrap();
1544
1545        // Start a TCP listener
1546        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
1547        let addr = listener.local_addr().unwrap();
1548
1549        // Server task: accept TLS connection and echo data
1550        let server_task = tokio::spawn(async move {
1551            let (tcp_stream, _) = listener.accept().await.unwrap();
1552            let mut tls_stream: TlsServerStream<tokio::net::TcpStream> =
1553                acceptor.accept_tcp(tcp_stream).await.unwrap();
1554
1555            // Read data
1556            let mut buf = [0u8; 32];
1557            let n = tls_stream.read(&mut buf).await.unwrap();
1558
1559            // Echo it back
1560            tls_stream.write_all(&buf[..n]).await.unwrap();
1561            tls_stream.flush().await.unwrap();
1562
1563            n
1564        });
1565
1566        // Client task: connect and send data
1567        let client_task = tokio::spawn(async move {
1568            let mut stream: TlsClientStream<tokio::net::TcpStream> =
1569                connector.connect_tcp(addr, "localhost").await.unwrap();
1570
1571            // Send test message
1572            let message = b"Hello, TLS!";
1573            stream.write_all(message).await.unwrap();
1574            stream.flush().await.unwrap();
1575
1576            // Read response
1577            let mut response = [0u8; 32];
1578            let n = stream.read(&mut response).await.unwrap();
1579
1580            (message.to_vec(), response[..n].to_vec())
1581        });
1582
1583        // Wait for both tasks
1584        let (server_result, client_result) = tokio::join!(server_task, client_task);
1585
1586        let server_bytes_read = server_result.unwrap();
1587        let (sent, received) = client_result.unwrap();
1588
1589        // Verify echo worked
1590        assert_eq!(server_bytes_read, sent.len());
1591        assert_eq!(sent, received);
1592    }
1593
1594    #[tokio::test]
1595    async fn test_mtls_server_client_handshake() {
1596        use rcgen::{BasicConstraints, CertificateParams, DnType, IsCa, KeyUsagePurpose};
1597
1598        // Install crypto provider (required by rustls 0.23+)
1599        // Use aws_lc_rs which is the default in rustls 0.23+
1600        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
1601
1602        // Generate a shared CA certificate
1603        let mut ca_params = CertificateParams::default();
1604        ca_params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
1605        ca_params.key_usages = vec![KeyUsagePurpose::KeyCertSign, KeyUsagePurpose::CrlSign];
1606        ca_params
1607            .distinguished_name
1608            .push(DnType::CommonName, "Rivven Test CA");
1609        let ca_key_pair = rcgen::KeyPair::generate().unwrap();
1610        let ca_cert = ca_params.self_signed(&ca_key_pair).unwrap();
1611        let ca_cert_pem = ca_cert.pem();
1612
1613        // Generate server certificate signed by CA
1614        let mut server_params = CertificateParams::new(vec!["localhost".to_string()]).unwrap();
1615        server_params
1616            .distinguished_name
1617            .push(DnType::CommonName, "localhost");
1618        let server_key_pair = rcgen::KeyPair::generate().unwrap();
1619        let server_cert = server_params
1620            .signed_by(&server_key_pair, &ca_cert, &ca_key_pair)
1621            .unwrap();
1622        let server_cert_pem = server_cert.pem();
1623        let server_key_pem = server_key_pair.serialize_pem();
1624
1625        // Generate client certificate signed by CA
1626        let mut client_params =
1627            CertificateParams::new(vec!["client.rivven.local".to_string()]).unwrap();
1628        client_params
1629            .distinguished_name
1630            .push(DnType::CommonName, "client.rivven.local");
1631        let client_key_pair = rcgen::KeyPair::generate().unwrap();
1632        let client_cert = client_params
1633            .signed_by(&client_key_pair, &ca_cert, &ca_key_pair)
1634            .unwrap();
1635        let client_cert_pem = client_cert.pem();
1636        let client_key_pem = client_key_pair.serialize_pem();
1637
1638        // Server config with mTLS required
1639        let server_config = TlsConfig {
1640            enabled: true,
1641            certificate: Some(CertificateSource::Pem {
1642                content: server_cert_pem,
1643            }),
1644            private_key: Some(PrivateKeySource::Pem {
1645                content: server_key_pem,
1646            }),
1647            client_ca: Some(CertificateSource::Pem {
1648                content: ca_cert_pem.clone(),
1649            }),
1650            mtls_mode: MtlsMode::Required,
1651            insecure_skip_verify: false,
1652            ..Default::default()
1653        };
1654
1655        // Client config with client cert and CA trust
1656        let client_config = TlsConfig {
1657            enabled: true,
1658            certificate: Some(CertificateSource::Pem {
1659                content: client_cert_pem,
1660            }),
1661            private_key: Some(PrivateKeySource::Pem {
1662                content: client_key_pem,
1663            }),
1664            root_ca: Some(CertificateSource::Pem {
1665                content: ca_cert_pem,
1666            }),
1667            insecure_skip_verify: false,
1668            ..Default::default()
1669        };
1670
1671        // Create acceptor and connector
1672        let acceptor = TlsAcceptor::new(&server_config).unwrap();
1673        let connector = TlsConnector::new(&client_config).unwrap();
1674
1675        // Start a TCP listener
1676        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
1677        let addr = listener.local_addr().unwrap();
1678
1679        // Server task
1680        let server_task = tokio::spawn(async move {
1681            let (tcp_stream, _) = listener.accept().await.unwrap();
1682            let mut tls_stream: TlsServerStream<tokio::net::TcpStream> =
1683                acceptor.accept_tcp(tcp_stream).await.unwrap();
1684
1685            // Check if we can see peer certificates (mTLS)
1686            let has_peer_cert = tls_stream.peer_certificates().is_some();
1687
1688            // Read data
1689            let mut buf = [0u8; 32];
1690            let n = tls_stream.read(&mut buf).await.unwrap();
1691            tls_stream.write_all(&buf[..n]).await.unwrap();
1692            tls_stream.flush().await.unwrap();
1693
1694            (n, has_peer_cert)
1695        });
1696
1697        // Client task
1698        let client_task = tokio::spawn(async move {
1699            let mut stream: TlsClientStream<tokio::net::TcpStream> =
1700                connector.connect_tcp(addr, "localhost").await.unwrap();
1701
1702            // Send test message
1703            let message = b"mTLS Test!";
1704            stream.write_all(message).await.unwrap();
1705            stream.flush().await.unwrap();
1706
1707            // Read response
1708            let mut response = [0u8; 32];
1709            let n = stream.read(&mut response).await.unwrap();
1710
1711            (message.to_vec(), response[..n].to_vec())
1712        });
1713
1714        // Wait for both tasks
1715        let (server_result, client_result) = tokio::join!(server_task, client_task);
1716
1717        let (server_bytes_read, has_peer_cert) = server_result.unwrap();
1718        let (sent, received) = client_result.unwrap();
1719
1720        // Verify echo worked
1721        assert_eq!(server_bytes_read, sent.len());
1722        assert_eq!(sent, received);
1723
1724        // Verify mTLS - server saw client certificate
1725        assert!(
1726            has_peer_cert,
1727            "Server should have received client certificate in mTLS"
1728        );
1729    }
1730
1731    #[test]
1732    fn test_self_signed_generation() {
1733        let result = generate_self_signed("test.rivven.local");
1734        assert!(result.is_ok());
1735
1736        let (cert, _key) = result.unwrap();
1737        assert!(!cert.as_ref().is_empty());
1738
1739        // Verify we can extract identity
1740        let identity = TlsIdentity::from_certificate(&cert);
1741        assert_eq!(identity.common_name, Some("test.rivven.local".to_string()));
1742        assert!(identity.is_valid);
1743    }
1744
1745    #[test]
1746    fn test_certificate_fingerprint() {
1747        let (cert, _) = generate_self_signed("test.rivven.local").unwrap();
1748        let fingerprint = certificate_fingerprint(&cert);
1749
1750        // Should be 64 hex characters (SHA-256)
1751        assert_eq!(fingerprint.len(), 64);
1752        assert!(fingerprint.chars().all(|c| c.is_ascii_hexdigit()));
1753    }
1754
1755    #[test]
1756    fn test_tls_security_audit_disabled() {
1757        let config = TlsConfig::disabled();
1758        let audit = TlsSecurityAudit::audit(&config);
1759
1760        assert!(!audit.errors.is_empty());
1761        assert!(audit.errors.iter().any(|e| e.contains("disabled")));
1762    }
1763
1764    #[test]
1765    fn test_tls_security_audit_insecure() {
1766        let config = TlsConfig {
1767            enabled: true,
1768            insecure_skip_verify: true,
1769            ..Default::default()
1770        };
1771        let audit = TlsSecurityAudit::audit(&config);
1772
1773        assert!(audit.errors.iter().any(|e| e.contains("MITM")));
1774    }
1775
1776    #[test]
1777    fn test_tls_security_audit_production_ready() {
1778        let (_cert, _key) = generate_self_signed("broker.rivven.local").unwrap();
1779
1780        let config = TlsConfig {
1781            enabled: true,
1782            certificate: Some(CertificateSource::SelfSigned {
1783                common_name: "broker.rivven.local".to_string(),
1784            }),
1785            mtls_mode: MtlsMode::Required,
1786            min_version: TlsVersion::Tls13,
1787            insecure_skip_verify: false,
1788            session_cache_size: 256,
1789            ..Default::default()
1790        };
1791
1792        let audit = TlsSecurityAudit::audit(&config);
1793
1794        // Should have no errors for a well-configured setup
1795        // (Note: mTLS Required without client_ca would fail at runtime, but audit catches config issues)
1796        assert!(audit.errors.is_empty() || audit.errors.iter().all(|e| !e.contains("disabled")));
1797    }
1798
1799    #[test]
1800    fn test_mtls_modes() {
1801        assert_eq!(MtlsMode::default(), MtlsMode::Disabled);
1802
1803        let modes = [MtlsMode::Disabled, MtlsMode::Optional, MtlsMode::Required];
1804        for mode in modes {
1805            let json = serde_json::to_string(&mode).unwrap();
1806            let parsed: MtlsMode = serde_json::from_str(&json).unwrap();
1807            assert_eq!(mode, parsed);
1808        }
1809    }
1810
1811    #[test]
1812    fn test_tls_identity_extraction() {
1813        let (cert, _) = generate_self_signed("service.rivven.internal").unwrap();
1814        let identity = TlsIdentity::from_certificate(&cert);
1815
1816        assert_eq!(
1817            identity.common_name,
1818            Some("service.rivven.internal".to_string())
1819        );
1820        assert!(identity.is_valid);
1821        assert!(identity.valid_from.is_some());
1822        assert!(identity.valid_until.is_some());
1823        assert!(!identity.fingerprint.is_empty());
1824    }
1825}