Skip to main content

rivven_core/
tls.rs

//! Production-grade TLS/mTLS infrastructure for Rivven.
//!
//! This module provides encryption for all communication paths:
//! - Client → Broker (optional TLS; mTLS for high-security deployments)
//! - Connect → Broker (mTLS required for service-to-service traffic)
//! - Broker ↔ Broker (mTLS for cluster communication)
//! - Admin → Broker (TLS/mTLS for management APIs)
//!
//! # Security Model
//!
//! Rivven follows a zero-trust security model for inter-service communication:
//! - All internal communication uses mTLS by default
//! - Certificate-based identity for services
//! - Strong cipher suites only (TLS 1.3 preferred)
//! - Certificate rotation support
//! - Optional certificate pinning for high-security deployments
//!
//! # Example
//!
//! The builder methods are infallible and `build()` returns the `TlsConfig`
//! directly; configuration errors surface when the acceptor/connector is created.
//!
//! ```rust,ignore
//! use rivven_core::tls::{TlsConfigBuilder, TlsAcceptor, TlsConnector};
//!
//! // Server-side mTLS
//! let server_config = TlsConfigBuilder::new()
//!     .with_cert_file("server.crt")
//!     .with_key_file("server.key")
//!     .with_client_ca_file("ca.crt")  // CA used to verify client certificates
//!     .require_client_cert(true)      // Enable mTLS
//!     .build();
//!
//! let acceptor = TlsAcceptor::new(&server_config)?;
//!
//! // Client-side mTLS
//! let client_config = TlsConfigBuilder::new()
//!     .with_cert_file("client.crt")
//!     .with_key_file("client.key")
//!     .with_root_ca_file("ca.crt")
//!     .build();
//!
//! let connector = TlsConnector::new(&client_config)?;
//! ```
42
43use std::collections::HashMap;
44use std::fmt;
45use std::fs;
46use std::io::{self, BufReader, Cursor};
47use std::net::SocketAddr;
48use std::path::PathBuf;
49use std::sync::Arc;
50use std::time::{Duration, SystemTime};
51
52use parking_lot::RwLock;
53use serde::{Deserialize, Serialize};
54use thiserror::Error;
55use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
56use tokio::net::TcpStream;
57
58// Re-export rustls types that users might need
59pub use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
60pub use rustls::{ClientConfig, ServerConfig};
61
62/// TLS-related errors
63#[derive(Debug, Error)]
64pub enum TlsError {
65    /// Certificate file not found or unreadable
66    #[error("Failed to read certificate file '{path}': {source}")]
67    CertificateReadError {
68        path: PathBuf,
69        #[source]
70        source: io::Error,
71    },
72
73    /// Private key file not found or unreadable
74    #[error("Failed to read private key file '{path}': {source}")]
75    KeyReadError {
76        path: PathBuf,
77        #[source]
78        source: io::Error,
79    },
80
81    /// Invalid certificate format
82    #[error("Invalid certificate format: {0}")]
83    InvalidCertificate(String),
84
85    /// Invalid private key format
86    #[error("Invalid private key format: {0}")]
87    InvalidPrivateKey(String),
88
89    /// Certificate chain validation failed
90    #[error("Certificate chain validation failed: {0}")]
91    CertificateChainError(String),
92
93    /// TLS handshake failed
94    #[error("TLS handshake failed: {0}")]
95    HandshakeError(String),
96
97    /// Connection error
98    #[error("Connection error: {0}")]
99    ConnectionError(String),
100
101    /// Configuration error
102    #[error("TLS configuration error: {0}")]
103    ConfigError(String),
104
105    /// Certificate expired
106    #[error("Certificate expired: {0}")]
107    CertificateExpired(String),
108
109    /// Certificate not yet valid
110    #[error("Certificate not yet valid: {0}")]
111    CertificateNotYetValid(String),
112
113    /// Certificate revoked
114    #[error("Certificate revoked: {0}")]
115    CertificateRevoked(String),
116
117    /// Hostname verification failed
118    #[error("Hostname verification failed: expected '{expected}', got '{actual}'")]
119    HostnameVerificationFailed { expected: String, actual: String },
120
121    /// mTLS required but client certificate not provided
122    #[error("Client certificate required for mTLS but not provided")]
123    ClientCertificateRequired,
124
125    /// Self-signed certificate generation failed
126    #[error("Failed to generate self-signed certificate: {0}")]
127    SelfSignedGenerationError(String),
128
129    /// ALPN negotiation failed
130    #[error("ALPN negotiation failed: no common protocol")]
131    AlpnNegotiationFailed,
132
133    /// Internal rustls error
134    #[error("TLS internal error: {0}")]
135    RustlsError(String),
136}
137
138impl From<rustls::Error> for TlsError {
139    fn from(err: rustls::Error) -> Self {
140        TlsError::RustlsError(err.to_string())
141    }
142}
143
144/// Result type for TLS operations
145pub type TlsResult<T> = std::result::Result<T, TlsError>;
146
147// ============================================================================
148// TLS Configuration
149// ============================================================================
150
151/// TLS protocol version
152#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
153#[serde(rename_all = "snake_case")]
154pub enum TlsVersion {
155    /// TLS 1.2 (minimum for compatibility)
156    Tls12,
157    /// TLS 1.3 (preferred, default)
158    #[default]
159    Tls13,
160}
161
162/// mTLS mode configuration
163#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
164#[serde(rename_all = "snake_case")]
165pub enum MtlsMode {
166    /// TLS without client certificate verification
167    #[default]
168    Disabled,
169    /// Request client certificate but don't require it
170    Optional,
171    /// Require valid client certificate (recommended for service-to-service)
172    Required,
173}
174
175/// Certificate source for flexible certificate loading
176#[derive(Debug, Clone, Serialize, Deserialize)]
177#[serde(tag = "type", rename_all = "snake_case")]
178pub enum CertificateSource {
179    /// Load from PEM file
180    File { path: PathBuf },
181    /// Load from PEM string
182    Pem { content: String },
183    /// Load from DER bytes (base64 encoded in config)
184    Der { content: String },
185    /// Generate self-signed (development only)
186    SelfSigned { common_name: String },
187}
188
189/// Private key source for flexible key loading
190#[derive(Debug, Clone, Serialize, Deserialize)]
191#[serde(tag = "type", rename_all = "snake_case")]
192pub enum PrivateKeySource {
193    /// Load from PEM file
194    File { path: PathBuf },
195    /// Load from PEM string
196    Pem { content: String },
197    /// Load from DER bytes (base64 encoded in config)
198    Der { content: String },
199}
200
201/// Complete TLS configuration for a component
202#[derive(Debug, Clone, Serialize, Deserialize)]
203pub struct TlsConfig {
204    /// Whether TLS is enabled
205    #[serde(default = "default_true")]
206    pub enabled: bool,
207
208    /// Server certificate and chain
209    pub certificate: Option<CertificateSource>,
210
211    /// Server private key
212    pub private_key: Option<PrivateKeySource>,
213
214    /// Root CA certificates for verification
215    pub root_ca: Option<CertificateSource>,
216
217    /// Client CA certificates for mTLS verification
218    pub client_ca: Option<CertificateSource>,
219
220    /// mTLS mode
221    #[serde(default)]
222    pub mtls_mode: MtlsMode,
223
224    /// Minimum TLS version
225    #[serde(default)]
226    pub min_version: TlsVersion,
227
228    /// ALPN protocols (e.g., ["h2", "http/1.1"])
229    #[serde(default)]
230    pub alpn_protocols: Vec<String>,
231
232    /// Enable OCSP stapling
233    #[serde(default)]
234    pub ocsp_stapling: bool,
235
236    /// Certificate pinning (SHA-256 fingerprints)
237    #[serde(default)]
238    pub pinned_certificates: Vec<String>,
239
240    /// Skip certificate verification (DANGEROUS - testing only)
241    #[serde(default)]
242    pub insecure_skip_verify: bool,
243
244    /// Server name for SNI (client-side)
245    pub server_name: Option<String>,
246
247    /// Session cache size (0 to disable)
248    #[serde(default = "default_session_cache_size")]
249    pub session_cache_size: usize,
250
251    /// Session ticket lifetime
252    #[serde(default = "default_session_ticket_lifetime")]
253    #[serde(with = "humantime_serde")]
254    pub session_ticket_lifetime: Duration,
255
256    /// Certificate reload interval (0 to disable)
257    #[serde(default)]
258    #[serde(with = "humantime_serde")]
259    pub cert_reload_interval: Duration,
260}
261
/// Serde default for `TlsConfig::enabled`: a TLS section that is present is on.
fn default_true() -> bool {
    true
}

/// Serde/`Default` value for `TlsConfig::session_cache_size` (number of sessions).
fn default_session_cache_size() -> usize {
    256
}

/// Serde/`Default` value for `TlsConfig::session_ticket_lifetime`: one day.
fn default_session_ticket_lifetime() -> Duration {
    const SECONDS_PER_DAY: u64 = 24 * 60 * 60;
    Duration::from_secs(SECONDS_PER_DAY)
}
273
274impl Default for TlsConfig {
275    fn default() -> Self {
276        Self {
277            enabled: false, // Opt-in to TLS
278            certificate: None,
279            private_key: None,
280            root_ca: None,
281            client_ca: None,
282            mtls_mode: MtlsMode::Disabled,
283            min_version: TlsVersion::Tls13,
284            alpn_protocols: vec![],
285            ocsp_stapling: false,
286            pinned_certificates: vec![],
287            insecure_skip_verify: false,
288            server_name: None,
289            session_cache_size: default_session_cache_size(),
290            session_ticket_lifetime: default_session_ticket_lifetime(),
291            cert_reload_interval: Duration::ZERO,
292        }
293    }
294}
295
296impl TlsConfig {
297    /// Create a new disabled TLS configuration
298    pub fn disabled() -> Self {
299        Self::default()
300    }
301
302    /// Create TLS configuration for development with self-signed certificates
303    pub fn self_signed(common_name: &str) -> Self {
304        Self {
305            enabled: true,
306            certificate: Some(CertificateSource::SelfSigned {
307                common_name: common_name.to_string(),
308            }),
309            private_key: None,          // Generated with self-signed cert
310            insecure_skip_verify: true, // Required for self-signed
311            ..Default::default()
312        }
313    }
314
315    /// Create TLS configuration from PEM files
316    pub fn from_pem_files<P: Into<PathBuf>>(cert_path: P, key_path: P) -> Self {
317        Self {
318            enabled: true,
319            certificate: Some(CertificateSource::File {
320                path: cert_path.into(),
321            }),
322            private_key: Some(PrivateKeySource::File {
323                path: key_path.into(),
324            }),
325            ..Default::default()
326        }
327    }
328
329    /// Create mTLS configuration from PEM files
330    pub fn mtls_from_pem_files<P1, P2, P3>(cert_path: P1, key_path: P2, ca_path: P3) -> Self
331    where
332        P1: Into<PathBuf>,
333        P2: Into<PathBuf>,
334        P3: Into<PathBuf> + Clone,
335    {
336        let ca: PathBuf = ca_path.clone().into();
337        Self {
338            enabled: true,
339            certificate: Some(CertificateSource::File {
340                path: cert_path.into(),
341            }),
342            private_key: Some(PrivateKeySource::File {
343                path: key_path.into(),
344            }),
345            client_ca: Some(CertificateSource::File { path: ca.clone() }),
346            root_ca: Some(CertificateSource::File { path: ca }),
347            mtls_mode: MtlsMode::Required,
348            ..Default::default()
349        }
350    }
351}
352
353// ============================================================================
354// Builder Pattern for Complex Configurations
355// ============================================================================
356
357/// Builder for TLS configuration
358pub struct TlsConfigBuilder {
359    config: TlsConfig,
360}
361
362impl TlsConfigBuilder {
363    /// Create a new TLS configuration builder
364    pub fn new() -> Self {
365        Self {
366            config: TlsConfig {
367                enabled: true,
368                ..Default::default()
369            },
370        }
371    }
372
373    /// Set the server certificate from a file
374    pub fn with_cert_file<P: Into<PathBuf>>(mut self, path: P) -> Self {
375        self.config.certificate = Some(CertificateSource::File { path: path.into() });
376        self
377    }
378
379    /// Set the server certificate from PEM content
380    pub fn with_cert_pem(mut self, pem: String) -> Self {
381        self.config.certificate = Some(CertificateSource::Pem { content: pem });
382        self
383    }
384
385    /// Set the private key from a file
386    pub fn with_key_file<P: Into<PathBuf>>(mut self, path: P) -> Self {
387        self.config.private_key = Some(PrivateKeySource::File { path: path.into() });
388        self
389    }
390
391    /// Set the private key from PEM content
392    pub fn with_key_pem(mut self, pem: String) -> Self {
393        self.config.private_key = Some(PrivateKeySource::Pem { content: pem });
394        self
395    }
396
397    /// Set the root CA for server verification (client-side)
398    pub fn with_root_ca_file<P: Into<PathBuf>>(mut self, path: P) -> Self {
399        self.config.root_ca = Some(CertificateSource::File { path: path.into() });
400        self
401    }
402
403    /// Set the client CA for mTLS (server-side)
404    pub fn with_client_ca_file<P: Into<PathBuf>>(mut self, path: P) -> Self {
405        self.config.client_ca = Some(CertificateSource::File { path: path.into() });
406        self
407    }
408
409    /// Require client certificate (mTLS)
410    pub fn require_client_cert(mut self, required: bool) -> Self {
411        self.config.mtls_mode = if required {
412            MtlsMode::Required
413        } else {
414            MtlsMode::Disabled
415        };
416        self
417    }
418
419    /// Set mTLS mode
420    pub fn with_mtls_mode(mut self, mode: MtlsMode) -> Self {
421        self.config.mtls_mode = mode;
422        self
423    }
424
425    /// Set minimum TLS version
426    pub fn with_min_version(mut self, version: TlsVersion) -> Self {
427        self.config.min_version = version;
428        self
429    }
430
431    /// Set ALPN protocols
432    pub fn with_alpn_protocols(mut self, protocols: Vec<String>) -> Self {
433        self.config.alpn_protocols = protocols;
434        self
435    }
436
437    /// Set server name for SNI
438    pub fn with_server_name(mut self, name: String) -> Self {
439        self.config.server_name = Some(name);
440        self
441    }
442
443    /// Skip certificate verification (DANGEROUS - testing only)
444    pub fn insecure_skip_verify(mut self) -> Self {
445        self.config.insecure_skip_verify = true;
446        self
447    }
448
449    /// Add pinned certificate fingerprint
450    pub fn with_pinned_certificate(mut self, fingerprint: String) -> Self {
451        self.config.pinned_certificates.push(fingerprint);
452        self
453    }
454
455    /// Use self-signed certificate
456    pub fn with_self_signed(mut self, common_name: &str) -> Self {
457        self.config.certificate = Some(CertificateSource::SelfSigned {
458            common_name: common_name.to_string(),
459        });
460        self.config.insecure_skip_verify = true;
461        self
462    }
463
464    /// Enable certificate reloading
465    pub fn with_cert_reload_interval(mut self, interval: Duration) -> Self {
466        self.config.cert_reload_interval = interval;
467        self
468    }
469
470    /// Build the TLS configuration
471    pub fn build(self) -> TlsConfig {
472        self.config
473    }
474}
475
476impl Default for TlsConfigBuilder {
477    fn default() -> Self {
478        Self::new()
479    }
480}
481
482// ============================================================================
483// Certificate Loading Utilities
484// ============================================================================
485
486/// Load certificates from a source
487pub fn load_certificates(source: &CertificateSource) -> TlsResult<Vec<CertificateDer<'static>>> {
488    match source {
489        CertificateSource::File { path } => {
490            let data = fs::read(path).map_err(|e| TlsError::CertificateReadError {
491                path: path.clone(),
492                source: e,
493            })?;
494            parse_pem_certificates(&data)
495        }
496        CertificateSource::Pem { content } => parse_pem_certificates(content.as_bytes()),
497        CertificateSource::Der { content } => {
498            let der =
499                base64::Engine::decode(&base64::engine::general_purpose::STANDARD, content)
500                    .map_err(|e| TlsError::InvalidCertificate(format!("Invalid base64: {}", e)))?;
501            Ok(vec![CertificateDer::from(der)])
502        }
503        CertificateSource::SelfSigned { common_name } => {
504            let (cert, _key) = generate_self_signed(common_name)?;
505            Ok(vec![cert])
506        }
507    }
508}
509
510/// Parse PEM-encoded certificates
511fn parse_pem_certificates(data: &[u8]) -> TlsResult<Vec<CertificateDer<'static>>> {
512    let mut reader = BufReader::new(Cursor::new(data));
513    let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut reader)
514        .collect::<Result<Vec<_>, _>>()
515        .map_err(|e| TlsError::InvalidCertificate(format!("Failed to parse PEM: {}", e)))?;
516
517    if certs.is_empty() {
518        return Err(TlsError::InvalidCertificate(
519            "No certificates found in PEM data".to_string(),
520        ));
521    }
522
523    Ok(certs)
524}
525
526/// Load private key from a source
527pub fn load_private_key(source: &PrivateKeySource) -> TlsResult<PrivateKeyDer<'static>> {
528    match source {
529        PrivateKeySource::File { path } => {
530            let data = fs::read(path).map_err(|e| TlsError::KeyReadError {
531                path: path.clone(),
532                source: e,
533            })?;
534            parse_pem_private_key(&data)
535        }
536        PrivateKeySource::Pem { content } => parse_pem_private_key(content.as_bytes()),
537        PrivateKeySource::Der { content } => {
538            let der = base64::Engine::decode(&base64::engine::general_purpose::STANDARD, content)
539                .map_err(|e| TlsError::InvalidPrivateKey(format!("Invalid base64: {}", e)))?;
540            Ok(PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(der)))
541        }
542    }
543}
544
545/// Parse PEM-encoded private key
546fn parse_pem_private_key(data: &[u8]) -> TlsResult<PrivateKeyDer<'static>> {
547    let mut reader = BufReader::new(Cursor::new(data));
548
549    rustls_pemfile::private_key(&mut reader)
550        .map_err(|e| TlsError::InvalidPrivateKey(format!("Failed to parse PEM: {}", e)))?
551        .ok_or_else(|| TlsError::InvalidPrivateKey("No private key found in PEM data".to_string()))
552}
553
554/// Generate self-signed certificate for development/testing
555pub fn generate_self_signed(
556    common_name: &str,
557) -> TlsResult<(CertificateDer<'static>, PrivateKeyDer<'static>)> {
558    let subject_alt_names = vec![
559        common_name.to_string(),
560        "localhost".to_string(),
561        "127.0.0.1".to_string(),
562    ];
563
564    let mut cert_params = rcgen::CertificateParams::new(subject_alt_names)
565        .map_err(|e| TlsError::SelfSignedGenerationError(e.to_string()))?;
566
567    // Set the distinguished name with proper common name
568    cert_params.distinguished_name = rcgen::DistinguishedName::new();
569    cert_params.distinguished_name.push(
570        rcgen::DnType::CommonName,
571        rcgen::DnValue::Utf8String(common_name.to_string()),
572    );
573    cert_params.distinguished_name.push(
574        rcgen::DnType::OrganizationName,
575        rcgen::DnValue::Utf8String("Rivven".to_string()),
576    );
577
578    let key_pair = rcgen::KeyPair::generate()
579        .map_err(|e| TlsError::SelfSignedGenerationError(e.to_string()))?;
580
581    let cert = cert_params
582        .self_signed(&key_pair)
583        .map_err(|e| TlsError::SelfSignedGenerationError(e.to_string()))?;
584
585    let cert_der = CertificateDer::from(cert.der().to_vec());
586    let key_der = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key_pair.serialize_der()));
587
588    Ok((cert_der, key_der))
589}
590
591// ============================================================================
592// Server-Side TLS (Acceptor)
593// ============================================================================
594
595/// TLS acceptor for server-side connections
596pub struct TlsAcceptor {
597    config: Arc<ServerConfig>,
598    inner: tokio_rustls::TlsAcceptor,
599    /// Configuration for hot-reloading
600    tls_config: TlsConfig,
601    /// Reloadable config handle
602    reloadable_config: Option<Arc<RwLock<Arc<ServerConfig>>>>,
603}
604
605impl TlsAcceptor {
606    /// Create a new TLS acceptor from configuration
607    pub fn new(config: &TlsConfig) -> TlsResult<Self> {
608        let server_config = build_server_config(config)?;
609        let server_config = Arc::new(server_config);
610
611        Ok(Self {
612            inner: tokio_rustls::TlsAcceptor::from(server_config.clone()),
613            config: server_config.clone(),
614            tls_config: config.clone(),
615            reloadable_config: if config.cert_reload_interval > Duration::ZERO {
616                Some(Arc::new(RwLock::new(server_config)))
617            } else {
618                None
619            },
620        })
621    }
622
623    /// Accept a TLS connection
624    pub async fn accept<IO>(&self, stream: IO) -> TlsResult<TlsServerStream<IO>>
625    where
626        IO: AsyncRead + AsyncWrite + Unpin,
627    {
628        let tls_stream = self
629            .inner
630            .accept(stream)
631            .await
632            .map_err(|e| TlsError::HandshakeError(e.to_string()))?;
633
634        Ok(TlsServerStream { inner: tls_stream })
635    }
636
637    /// Accept a TCP connection with TLS
638    pub async fn accept_tcp(&self, stream: TcpStream) -> TlsResult<TlsServerStream<TcpStream>> {
639        self.accept(stream).await
640    }
641
642    /// Reload certificates (for hot-reloading)
643    pub fn reload(&self) -> TlsResult<()> {
644        if let Some(ref reloadable) = self.reloadable_config {
645            let new_config = build_server_config(&self.tls_config)?;
646            *reloadable.write() = Arc::new(new_config);
647        }
648        Ok(())
649    }
650
651    /// Get the underlying rustls ServerConfig
652    pub fn config(&self) -> &Arc<ServerConfig> {
653        &self.config
654    }
655}
656
657impl fmt::Debug for TlsAcceptor {
658    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
659        f.debug_struct("TlsAcceptor")
660            .field("mtls_mode", &self.tls_config.mtls_mode)
661            .field("min_version", &self.tls_config.min_version)
662            .finish()
663    }
664}
665
666/// Build rustls ServerConfig from TlsConfig
667fn build_server_config(config: &TlsConfig) -> TlsResult<ServerConfig> {
668    // Handle self-signed certificates specially to ensure cert and key match
669    let (certs, key) =
670        if let Some(CertificateSource::SelfSigned { common_name }) = &config.certificate {
671            // Generate both cert and key together to ensure they match
672            let (cert, key) = generate_self_signed(common_name)?;
673            (vec![cert], key)
674        } else {
675            // Load certificates from explicit sources
676            let certs = if let Some(ref cert_source) = config.certificate {
677                load_certificates(cert_source)?
678            } else {
679                return Err(TlsError::ConfigError(
680                    "Server certificate required".to_string(),
681                ));
682            };
683
684            // Load private key
685            let key = if let Some(ref key_source) = config.private_key {
686                load_private_key(key_source)?
687            } else {
688                return Err(TlsError::ConfigError("Private key required".to_string()));
689            };
690
691            (certs, key)
692        };
693
694    // Build TLS versions
695    let versions: Vec<&'static rustls::SupportedProtocolVersion> = match config.min_version {
696        TlsVersion::Tls13 => vec![&rustls::version::TLS13],
697        TlsVersion::Tls12 => vec![&rustls::version::TLS12, &rustls::version::TLS13],
698    };
699
700    // Configure client certificate verification
701    let client_cert_verifier = match config.mtls_mode {
702        MtlsMode::Disabled => None,
703        MtlsMode::Optional | MtlsMode::Required => {
704            if let Some(ref ca_source) = config.client_ca {
705                let ca_certs = load_certificates(ca_source)?;
706                let mut root_store = rustls::RootCertStore::empty();
707                for cert in ca_certs {
708                    root_store.add(cert).map_err(|e| {
709                        TlsError::CertificateChainError(format!("Failed to add CA cert: {}", e))
710                    })?;
711                }
712
713                let verifier = if config.mtls_mode == MtlsMode::Required {
714                    rustls::server::WebPkiClientVerifier::builder(Arc::new(root_store))
715                        .build()
716                        .map_err(|e| {
717                            TlsError::ConfigError(format!("Failed to build client verifier: {}", e))
718                        })?
719                } else {
720                    rustls::server::WebPkiClientVerifier::builder(Arc::new(root_store))
721                        .allow_unauthenticated()
722                        .build()
723                        .map_err(|e| {
724                            TlsError::ConfigError(format!("Failed to build client verifier: {}", e))
725                        })?
726                };
727
728                Some(verifier)
729            } else if config.mtls_mode == MtlsMode::Required {
730                return Err(TlsError::ConfigError(
731                    "mTLS required but no client CA configured".to_string(),
732                ));
733            } else {
734                None
735            }
736        }
737    };
738
739    // Build server config
740    let mut server_config = if let Some(verifier) = client_cert_verifier {
741        ServerConfig::builder_with_protocol_versions(&versions)
742            .with_client_cert_verifier(verifier)
743            .with_single_cert(certs, key)
744            .map_err(|e| TlsError::ConfigError(format!("Invalid cert/key: {}", e)))?
745    } else {
746        ServerConfig::builder_with_protocol_versions(&versions)
747            .with_no_client_auth()
748            .with_single_cert(certs, key)
749            .map_err(|e| TlsError::ConfigError(format!("Invalid cert/key: {}", e)))?
750    };
751
752    // Configure ALPN
753    if !config.alpn_protocols.is_empty() {
754        server_config.alpn_protocols = config
755            .alpn_protocols
756            .iter()
757            .map(|p| p.as_bytes().to_vec())
758            .collect();
759    }
760
761    // Configure session cache
762    if config.session_cache_size > 0 {
763        server_config.session_storage =
764            rustls::server::ServerSessionMemoryCache::new(config.session_cache_size);
765    }
766
767    Ok(server_config)
768}
769
770// ============================================================================
771// Client-Side TLS (Connector)
772// ============================================================================
773
774/// TLS connector for client-side connections
775pub struct TlsConnector {
776    config: Arc<ClientConfig>,
777    inner: tokio_rustls::TlsConnector,
778    /// Default server name for SNI
779    server_name: Option<String>,
780}
781
782impl TlsConnector {
783    /// Create a new TLS connector from configuration
784    pub fn new(config: &TlsConfig) -> TlsResult<Self> {
785        let client_config = build_client_config(config)?;
786        let client_config = Arc::new(client_config);
787
788        Ok(Self {
789            inner: tokio_rustls::TlsConnector::from(client_config.clone()),
790            config: client_config,
791            server_name: config.server_name.clone(),
792        })
793    }
794
795    /// Connect to a server with TLS
796    pub async fn connect<IO>(&self, stream: IO, server_name: &str) -> TlsResult<TlsClientStream<IO>>
797    where
798        IO: AsyncRead + AsyncWrite + Unpin,
799    {
800        let name: rustls::pki_types::ServerName<'static> = server_name
801            .to_string()
802            .try_into()
803            .map_err(|_| TlsError::ConfigError(format!("Invalid server name: {}", server_name)))?;
804
805        let tls_stream = self
806            .inner
807            .connect(name, stream)
808            .await
809            .map_err(|e| TlsError::HandshakeError(e.to_string()))?;
810
811        Ok(TlsClientStream { inner: tls_stream })
812    }
813
814    /// Connect to a TCP address with TLS
815    pub async fn connect_tcp(
816        &self,
817        addr: SocketAddr,
818        server_name: &str,
819    ) -> TlsResult<TlsClientStream<TcpStream>> {
820        let stream = TcpStream::connect(addr)
821            .await
822            .map_err(|e| TlsError::ConnectionError(e.to_string()))?;
823
824        self.connect(stream, server_name).await
825    }
826
827    /// Connect using the configured server name
828    pub async fn connect_with_default_name<IO>(&self, stream: IO) -> TlsResult<TlsClientStream<IO>>
829    where
830        IO: AsyncRead + AsyncWrite + Unpin,
831    {
832        let name = self.server_name.as_ref().ok_or_else(|| {
833            TlsError::ConfigError("No server name configured for SNI".to_string())
834        })?;
835        self.connect(stream, name).await
836    }
837
838    /// Get the underlying rustls ClientConfig
839    pub fn config(&self) -> &Arc<ClientConfig> {
840        &self.config
841    }
842}
843
844impl fmt::Debug for TlsConnector {
845    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
846        f.debug_struct("TlsConnector")
847            .field("server_name", &self.server_name)
848            .finish()
849    }
850}
851
852/// Build rustls ClientConfig from TlsConfig
853fn build_client_config(config: &TlsConfig) -> TlsResult<ClientConfig> {
854    // Build TLS versions
855    let versions: Vec<&'static rustls::SupportedProtocolVersion> = match config.min_version {
856        TlsVersion::Tls13 => vec![&rustls::version::TLS13],
857        TlsVersion::Tls12 => vec![&rustls::version::TLS12, &rustls::version::TLS13],
858    };
859
860    // Build root certificate store
861    let root_store = if config.insecure_skip_verify {
862        // DANGEROUS: Trust all certificates (development only)
863        rustls::RootCertStore::empty()
864    } else if let Some(ref ca_source) = config.root_ca {
865        let ca_certs = load_certificates(ca_source)?;
866        let mut store = rustls::RootCertStore::empty();
867        for cert in ca_certs {
868            store.add(cert).map_err(|e| {
869                TlsError::CertificateChainError(format!("Failed to add root CA: {}", e))
870            })?;
871        }
872        store
873    } else {
874        // Use system root certificates
875        let mut store = rustls::RootCertStore::empty();
876        let native_certs = rustls_native_certs::load_native_certs();
877        for cert in native_certs.certs {
878            let _ = store.add(cert);
879        }
880        store
881    };
882
883    // Build client config with or without client certificate
884    let mut client_config = if let (Some(ref cert_source), Some(ref key_source)) =
885        (&config.certificate, &config.private_key)
886    {
887        // mTLS: provide client certificate
888        let certs = load_certificates(cert_source)?;
889        let key = load_private_key(key_source)?;
890
891        ClientConfig::builder_with_protocol_versions(&versions)
892            .with_root_certificates(root_store)
893            .with_client_auth_cert(certs, key)
894            .map_err(|e| TlsError::ConfigError(format!("Invalid client cert/key: {}", e)))?
895    } else if config.insecure_skip_verify {
896        // DANGEROUS: Skip server verification
897        ClientConfig::builder_with_protocol_versions(&versions)
898            .dangerous()
899            .with_custom_certificate_verifier(Arc::new(NoCertificateVerification))
900            .with_no_client_auth()
901    } else {
902        // Standard TLS without client certificate
903        ClientConfig::builder_with_protocol_versions(&versions)
904            .with_root_certificates(root_store)
905            .with_no_client_auth()
906    };
907
908    // Configure ALPN
909    if !config.alpn_protocols.is_empty() {
910        client_config.alpn_protocols = config
911            .alpn_protocols
912            .iter()
913            .map(|p| p.as_bytes().to_vec())
914            .collect();
915    }
916
917    Ok(client_config)
918}
919
920/// DANGEROUS: Certificate verifier that accepts any certificate
921/// Only for development/testing
922#[derive(Debug)]
923struct NoCertificateVerification;
924
925impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification {
926    fn verify_server_cert(
927        &self,
928        _end_entity: &CertificateDer<'_>,
929        _intermediates: &[CertificateDer<'_>],
930        _server_name: &rustls::pki_types::ServerName<'_>,
931        _ocsp_response: &[u8],
932        _now: rustls::pki_types::UnixTime,
933    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
934        Ok(rustls::client::danger::ServerCertVerified::assertion())
935    }
936
937    fn verify_tls12_signature(
938        &self,
939        _message: &[u8],
940        _cert: &CertificateDer<'_>,
941        _dss: &rustls::DigitallySignedStruct,
942    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
943        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
944    }
945
946    fn verify_tls13_signature(
947        &self,
948        _message: &[u8],
949        _cert: &CertificateDer<'_>,
950        _dss: &rustls::DigitallySignedStruct,
951    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
952        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
953    }
954
955    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
956        vec![
957            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
958            rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
959            rustls::SignatureScheme::ECDSA_NISTP521_SHA512,
960            rustls::SignatureScheme::RSA_PSS_SHA256,
961            rustls::SignatureScheme::RSA_PSS_SHA384,
962            rustls::SignatureScheme::RSA_PSS_SHA512,
963            rustls::SignatureScheme::RSA_PKCS1_SHA256,
964            rustls::SignatureScheme::RSA_PKCS1_SHA384,
965            rustls::SignatureScheme::RSA_PKCS1_SHA512,
966            rustls::SignatureScheme::ED25519,
967        ]
968    }
969}
970
971// ============================================================================
972// TLS Streams
973// ============================================================================
974
975/// Server-side TLS stream
976pub struct TlsServerStream<IO> {
977    inner: tokio_rustls::server::TlsStream<IO>,
978}
979
980impl<IO> TlsServerStream<IO>
981where
982    IO: AsyncRead + AsyncWrite + Unpin,
983{
984    /// Get client certificate if presented
985    pub fn peer_certificates(&self) -> Option<&[CertificateDer<'_>]> {
986        self.inner.get_ref().1.peer_certificates()
987    }
988
989    /// Get ALPN protocol if negotiated
990    pub fn alpn_protocol(&self) -> Option<&[u8]> {
991        self.inner.get_ref().1.alpn_protocol()
992    }
993
994    /// Get negotiated protocol version
995    pub fn protocol_version(&self) -> Option<rustls::ProtocolVersion> {
996        self.inner.get_ref().1.protocol_version()
997    }
998
999    /// Get negotiated cipher suite
1000    pub fn negotiated_cipher_suite(&self) -> Option<rustls::SupportedCipherSuite> {
1001        self.inner.get_ref().1.negotiated_cipher_suite()
1002    }
1003
1004    /// Get cipher suite name as string
1005    pub fn cipher_suite_name(&self) -> Option<String> {
1006        self.negotiated_cipher_suite()
1007            .map(|cs| format!("{:?}", cs.suite()))
1008    }
1009
1010    /// Check if the connection uses TLS 1.3
1011    pub fn is_tls_13(&self) -> bool {
1012        self.protocol_version() == Some(rustls::ProtocolVersion::TLSv1_3)
1013    }
1014
1015    /// Extract the client certificate common name (CN)
1016    pub fn peer_common_name(&self) -> Option<String> {
1017        self.peer_certificates().and_then(|certs| {
1018            if certs.is_empty() {
1019                return None;
1020            }
1021            extract_common_name(&certs[0])
1022        })
1023    }
1024
1025    /// Extract the client certificate subject
1026    pub fn peer_subject(&self) -> Option<String> {
1027        self.peer_certificates().and_then(|certs| {
1028            if certs.is_empty() {
1029                return None;
1030            }
1031            extract_subject(&certs[0])
1032        })
1033    }
1034
1035    /// Get reference to the inner stream
1036    pub fn get_ref(&self) -> &IO {
1037        self.inner.get_ref().0
1038    }
1039
1040    /// Get mutable reference to the inner stream
1041    pub fn get_mut(&mut self) -> &mut IO {
1042        self.inner.get_mut().0
1043    }
1044
1045    /// Unwrap and get the inner stream
1046    pub fn into_inner(self) -> IO {
1047        self.inner.into_inner().0
1048    }
1049}
1050
1051impl<IO> tokio::io::AsyncRead for TlsServerStream<IO>
1052where
1053    IO: AsyncRead + AsyncWrite + Unpin,
1054{
1055    fn poll_read(
1056        mut self: std::pin::Pin<&mut Self>,
1057        cx: &mut std::task::Context<'_>,
1058        buf: &mut ReadBuf<'_>,
1059    ) -> std::task::Poll<io::Result<()>> {
1060        std::pin::Pin::new(&mut self.inner).poll_read(cx, buf)
1061    }
1062}
1063
1064impl<IO> tokio::io::AsyncWrite for TlsServerStream<IO>
1065where
1066    IO: AsyncRead + AsyncWrite + Unpin,
1067{
1068    fn poll_write(
1069        mut self: std::pin::Pin<&mut Self>,
1070        cx: &mut std::task::Context<'_>,
1071        buf: &[u8],
1072    ) -> std::task::Poll<io::Result<usize>> {
1073        std::pin::Pin::new(&mut self.inner).poll_write(cx, buf)
1074    }
1075
1076    fn poll_flush(
1077        mut self: std::pin::Pin<&mut Self>,
1078        cx: &mut std::task::Context<'_>,
1079    ) -> std::task::Poll<io::Result<()>> {
1080        std::pin::Pin::new(&mut self.inner).poll_flush(cx)
1081    }
1082
1083    fn poll_shutdown(
1084        mut self: std::pin::Pin<&mut Self>,
1085        cx: &mut std::task::Context<'_>,
1086    ) -> std::task::Poll<io::Result<()>> {
1087        std::pin::Pin::new(&mut self.inner).poll_shutdown(cx)
1088    }
1089}
1090
1091/// Client-side TLS stream
1092pub struct TlsClientStream<IO> {
1093    inner: tokio_rustls::client::TlsStream<IO>,
1094}
1095
1096impl<IO> TlsClientStream<IO>
1097where
1098    IO: AsyncRead + AsyncWrite + Unpin,
1099{
1100    /// Get server certificate if provided
1101    pub fn peer_certificates(&self) -> Option<&[CertificateDer<'_>]> {
1102        self.inner.get_ref().1.peer_certificates()
1103    }
1104
1105    /// Get ALPN protocol if negotiated
1106    pub fn alpn_protocol(&self) -> Option<&[u8]> {
1107        self.inner.get_ref().1.alpn_protocol()
1108    }
1109
1110    /// Get negotiated protocol version
1111    pub fn protocol_version(&self) -> Option<rustls::ProtocolVersion> {
1112        self.inner.get_ref().1.protocol_version()
1113    }
1114
1115    /// Check if the connection uses TLS 1.3
1116    pub fn is_tls_13(&self) -> bool {
1117        self.protocol_version() == Some(rustls::ProtocolVersion::TLSv1_3)
1118    }
1119
1120    /// Get reference to the inner stream
1121    pub fn get_ref(&self) -> &IO {
1122        self.inner.get_ref().0
1123    }
1124
1125    /// Get mutable reference to the inner stream
1126    pub fn get_mut(&mut self) -> &mut IO {
1127        self.inner.get_mut().0
1128    }
1129
1130    /// Unwrap and get the inner stream
1131    pub fn into_inner(self) -> IO {
1132        self.inner.into_inner().0
1133    }
1134}
1135
1136impl<IO> tokio::io::AsyncRead for TlsClientStream<IO>
1137where
1138    IO: AsyncRead + AsyncWrite + Unpin,
1139{
1140    fn poll_read(
1141        mut self: std::pin::Pin<&mut Self>,
1142        cx: &mut std::task::Context<'_>,
1143        buf: &mut ReadBuf<'_>,
1144    ) -> std::task::Poll<io::Result<()>> {
1145        std::pin::Pin::new(&mut self.inner).poll_read(cx, buf)
1146    }
1147}
1148
1149impl<IO> tokio::io::AsyncWrite for TlsClientStream<IO>
1150where
1151    IO: AsyncRead + AsyncWrite + Unpin,
1152{
1153    fn poll_write(
1154        mut self: std::pin::Pin<&mut Self>,
1155        cx: &mut std::task::Context<'_>,
1156        buf: &[u8],
1157    ) -> std::task::Poll<io::Result<usize>> {
1158        std::pin::Pin::new(&mut self.inner).poll_write(cx, buf)
1159    }
1160
1161    fn poll_flush(
1162        mut self: std::pin::Pin<&mut Self>,
1163        cx: &mut std::task::Context<'_>,
1164    ) -> std::task::Poll<io::Result<()>> {
1165        std::pin::Pin::new(&mut self.inner).poll_flush(cx)
1166    }
1167
1168    fn poll_shutdown(
1169        mut self: std::pin::Pin<&mut Self>,
1170        cx: &mut std::task::Context<'_>,
1171    ) -> std::task::Poll<io::Result<()>> {
1172        std::pin::Pin::new(&mut self.inner).poll_shutdown(cx)
1173    }
1174}
1175
1176// ============================================================================
1177// Certificate Utilities
1178// ============================================================================
1179
1180/// Extract common name (CN) from certificate
1181fn extract_common_name(cert: &CertificateDer<'_>) -> Option<String> {
1182    // Parse the certificate using x509-parser
1183    let (_, cert) = x509_parser::parse_x509_certificate(cert.as_ref()).ok()?;
1184
1185    for rdn in cert.subject().iter_rdn() {
1186        for attr in rdn.iter() {
1187            if attr.attr_type() == &x509_parser::oid_registry::OID_X509_COMMON_NAME {
1188                return attr.as_str().ok().map(|s| s.to_string());
1189            }
1190        }
1191    }
1192
1193    None
1194}
1195
1196/// Extract full subject from certificate
1197fn extract_subject(cert: &CertificateDer<'_>) -> Option<String> {
1198    let (_, cert) = x509_parser::parse_x509_certificate(cert.as_ref()).ok()?;
1199    Some(cert.subject().to_string())
1200}
1201
1202/// Calculate SHA-256 fingerprint of a certificate
1203pub fn certificate_fingerprint(cert: &CertificateDer<'_>) -> String {
1204    use sha2::{Digest, Sha256};
1205    let hash = Sha256::digest(cert.as_ref());
1206    hex::encode(hash)
1207}
1208
1209/// Verify certificate chain
1210///
1211/// Note: This is a basic sanity check. The actual TLS handshake performs
1212/// full chain validation using WebPKI through rustls.
1213pub fn verify_certificate_chain(
1214    chain: &[CertificateDer<'_>],
1215    root_store: &rustls::RootCertStore,
1216) -> TlsResult<()> {
1217    if chain.is_empty() {
1218        return Err(TlsError::CertificateChainError(
1219            "Empty certificate chain".to_string(),
1220        ));
1221    }
1222
1223    // Basic sanity check - the actual validation happens during TLS handshake
1224    // via rustls WebPKI implementation
1225    if root_store.is_empty() {
1226        tracing::warn!("Root certificate store is empty - chain validation may fail");
1227    }
1228
1229    // Log certificate chain info for debugging
1230    for (i, cert) in chain.iter().enumerate() {
1231        let fingerprint = certificate_fingerprint(cert);
1232        tracing::debug!(
1233            "Certificate chain[{}]: fingerprint={}",
1234            i,
1235            &fingerprint[..16]
1236        );
1237    }
1238
1239    Ok(())
1240}
1241
1242// ============================================================================
1243// Certificate Watcher for Hot Reloading
1244// ============================================================================
1245
1246/// Watches certificate files and triggers reload on changes
1247pub struct CertificateWatcher {
1248    /// Files being watched
1249    watched_files: Vec<PathBuf>,
1250    /// Last modification times
1251    last_modified: HashMap<PathBuf, SystemTime>,
1252    /// Callback for reload
1253    reload_callback: Box<dyn Fn() + Send + Sync>,
1254}
1255
1256impl CertificateWatcher {
1257    /// Create a new certificate watcher
1258    pub fn new<F>(files: Vec<PathBuf>, callback: F) -> Self
1259    where
1260        F: Fn() + Send + Sync + 'static,
1261    {
1262        let mut last_modified = HashMap::new();
1263        for file in &files {
1264            if let Ok(meta) = fs::metadata(file) {
1265                if let Ok(modified) = meta.modified() {
1266                    last_modified.insert(file.clone(), modified);
1267                }
1268            }
1269        }
1270
1271        Self {
1272            watched_files: files,
1273            last_modified,
1274            reload_callback: Box::new(callback),
1275        }
1276    }
1277
1278    /// Check for file changes and trigger reload if needed
1279    pub fn check_and_reload(&mut self) -> bool {
1280        let mut changed = false;
1281
1282        for file in &self.watched_files {
1283            if let Ok(meta) = fs::metadata(file) {
1284                if let Ok(modified) = meta.modified() {
1285                    let last = self.last_modified.get(file);
1286                    if last.is_none_or(|&l| modified > l) {
1287                        self.last_modified.insert(file.clone(), modified);
1288                        changed = true;
1289                    }
1290                }
1291            }
1292        }
1293
1294        if changed {
1295            (self.reload_callback)();
1296        }
1297
1298        changed
1299    }
1300
1301    /// Start watching in background
1302    pub fn spawn(mut self, interval: Duration) -> tokio::task::JoinHandle<()> {
1303        tokio::spawn(async move {
1304            let mut ticker = tokio::time::interval(interval);
1305            loop {
1306                ticker.tick().await;
1307                self.check_and_reload();
1308            }
1309        })
1310    }
1311}
1312
1313// ============================================================================
1314// Connection Identity (mTLS Integration)
1315// ============================================================================
1316
1317/// Identity extracted from TLS connection
1318#[derive(Debug, Clone, Serialize, Deserialize)]
1319pub struct TlsIdentity {
1320    /// Certificate common name (CN)
1321    pub common_name: Option<String>,
1322    /// Full certificate subject
1323    pub subject: Option<String>,
1324    /// Certificate fingerprint (SHA-256)
1325    pub fingerprint: String,
1326    /// Organization from certificate
1327    pub organization: Option<String>,
1328    /// Organizational unit from certificate
1329    pub organizational_unit: Option<String>,
1330    /// Certificate serial number
1331    pub serial_number: Option<String>,
1332    /// Certificate validity period
1333    pub valid_from: Option<chrono::DateTime<chrono::Utc>>,
1334    pub valid_until: Option<chrono::DateTime<chrono::Utc>>,
1335    /// Is the certificate still valid
1336    pub is_valid: bool,
1337}
1338
1339impl TlsIdentity {
1340    /// Extract identity from a certificate
1341    pub fn from_certificate(cert: &CertificateDer<'_>) -> Self {
1342        let fingerprint = certificate_fingerprint(cert);
1343        let common_name = extract_common_name(cert);
1344        let subject = extract_subject(cert);
1345
1346        // Parse additional fields using x509-parser
1347        let (organization, organizational_unit, serial_number, valid_from, valid_until, is_valid) =
1348            if let Ok((_, parsed)) = x509_parser::parse_x509_certificate(cert.as_ref()) {
1349                let mut org = None;
1350                let mut ou = None;
1351
1352                for rdn in parsed.subject().iter_rdn() {
1353                    for attr in rdn.iter() {
1354                        if attr.attr_type()
1355                            == &x509_parser::oid_registry::OID_X509_ORGANIZATION_NAME
1356                        {
1357                            org = attr.as_str().ok().map(|s| s.to_string());
1358                        }
1359                        if attr.attr_type()
1360                            == &x509_parser::oid_registry::OID_X509_ORGANIZATIONAL_UNIT
1361                        {
1362                            ou = attr.as_str().ok().map(|s| s.to_string());
1363                        }
1364                    }
1365                }
1366
1367                let serial = Some(parsed.serial.to_str_radix(16));
1368
1369                let validity = parsed.validity();
1370                let now = chrono::Utc::now();
1371
1372                let from = chrono::DateTime::from_timestamp(validity.not_before.timestamp(), 0);
1373                let until = chrono::DateTime::from_timestamp(validity.not_after.timestamp(), 0);
1374
1375                let valid = from.is_some_and(|f| now >= f) && until.is_some_and(|u| now <= u);
1376
1377                (org, ou, serial, from, until, valid)
1378            } else {
1379                (None, None, None, None, None, false)
1380            };
1381
1382        Self {
1383            common_name,
1384            subject,
1385            fingerprint,
1386            organization,
1387            organizational_unit,
1388            serial_number,
1389            valid_from,
1390            valid_until,
1391            is_valid,
1392        }
1393    }
1394}
1395
1396// ============================================================================
1397// Security Best Practices
1398// ============================================================================
1399
1400/// Security audit of TLS configuration
1401#[derive(Debug)]
1402pub struct TlsSecurityAudit {
1403    pub warnings: Vec<String>,
1404    pub errors: Vec<String>,
1405    pub recommendations: Vec<String>,
1406}
1407
1408impl TlsSecurityAudit {
1409    /// Audit a TLS configuration for security issues
1410    pub fn audit(config: &TlsConfig) -> Self {
1411        let mut audit = Self {
1412            warnings: vec![],
1413            errors: vec![],
1414            recommendations: vec![],
1415        };
1416
1417        if !config.enabled {
1418            audit
1419                .errors
1420                .push("TLS is disabled - all traffic will be unencrypted".to_string());
1421        }
1422
1423        if config.insecure_skip_verify {
1424            audit.errors.push(
1425                "Certificate verification is disabled - vulnerable to MITM attacks".to_string(),
1426            );
1427        }
1428
1429        if config.min_version == TlsVersion::Tls12 {
1430            audit.warnings.push(
1431                "TLS 1.2 is allowed - consider requiring TLS 1.3 for better security".to_string(),
1432            );
1433        }
1434
1435        if config.mtls_mode == MtlsMode::Disabled && config.client_ca.is_some() {
1436            audit.warnings.push(
1437                "Client CA configured but mTLS is disabled - clients won't be verified".to_string(),
1438            );
1439        }
1440
1441        if config.mtls_mode == MtlsMode::Optional {
1442            audit.warnings.push(
1443                "mTLS is optional - some clients may connect without certificates".to_string(),
1444            );
1445        }
1446
1447        if config.session_cache_size == 0 {
1448            audit
1449                .recommendations
1450                .push("Consider enabling session cache for better performance".to_string());
1451        }
1452
1453        if config.cert_reload_interval == Duration::ZERO {
1454            audit.recommendations.push(
1455                "Consider enabling certificate hot-reloading for zero-downtime rotation"
1456                    .to_string(),
1457            );
1458        }
1459
1460        if config.pinned_certificates.is_empty() && !config.insecure_skip_verify {
1461            audit
1462                .recommendations
1463                .push("Consider certificate pinning for high-security deployments".to_string());
1464        }
1465
1466        audit
1467    }
1468}
1469
1470// ============================================================================
1471// Tests
1472// ============================================================================
1473
1474#[cfg(test)]
1475mod tests {
1476    use super::*;
1477    use tokio::io::{AsyncReadExt, AsyncWriteExt};
1478
1479    #[test]
1480    fn test_tls_config_default() {
1481        let config = TlsConfig::default();
1482        assert!(!config.enabled);
1483        assert_eq!(config.mtls_mode, MtlsMode::Disabled);
1484        assert_eq!(config.min_version, TlsVersion::Tls13);
1485    }
1486
1487    #[test]
1488    fn test_tls_config_builder() {
1489        let config = TlsConfigBuilder::new()
1490            .with_cert_file("/path/to/cert.pem")
1491            .with_key_file("/path/to/key.pem")
1492            .with_client_ca_file("/path/to/ca.pem")
1493            .require_client_cert(true)
1494            .with_min_version(TlsVersion::Tls12)
1495            .build();
1496
1497        assert!(config.enabled);
1498        assert_eq!(config.mtls_mode, MtlsMode::Required);
1499        assert_eq!(config.min_version, TlsVersion::Tls12);
1500    }
1501
    /// End-to-end one-way TLS test.
    ///
    /// A server using a runtime-generated self-signed certificate echoes one
    /// message back to a client that skips certificate verification; the test
    /// checks that the echoed bytes match what was sent.
    #[tokio::test]
    async fn test_tls_server_client_handshake() {
        // Install crypto provider (required by rustls 0.23+).
        // The result is ignored: installation fails if a process-wide provider
        // is already set (presumably by another test in the same binary).
        let _ = rustls::crypto::ring::default_provider().install_default();

        // Use SelfSigned source which generates at runtime
        let server_config = TlsConfig {
            enabled: true,
            certificate: Some(CertificateSource::SelfSigned {
                common_name: "localhost".to_string(),
            }),
            // Key is auto-generated with self-signed
            mtls_mode: MtlsMode::Disabled,
            ..Default::default()
        };

        // Create client config that skips verification (for self-signed)
        let client_config = TlsConfig {
            enabled: true,
            insecure_skip_verify: true,
            ..Default::default()
        };

        // Create acceptor and connector
        let acceptor = TlsAcceptor::new(&server_config).unwrap();
        let connector = TlsConnector::new(&client_config).unwrap();

        // Start a TCP listener on an OS-assigned port (":0") to avoid port clashes
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        // Server task: accept TLS connection and echo data
        let server_task = tokio::spawn(async move {
            let (tcp_stream, _) = listener.accept().await.unwrap();
            let mut tls_stream: TlsServerStream<tokio::net::TcpStream> =
                acceptor.accept_tcp(tcp_stream).await.unwrap();

            // Read data (a single read is assumed to return the whole short message)
            let mut buf = [0u8; 32];
            let n = tls_stream.read(&mut buf).await.unwrap();

            // Echo it back
            tls_stream.write_all(&buf[..n]).await.unwrap();
            tls_stream.flush().await.unwrap();

            n
        });

        // Client task: connect and send data
        let client_task = tokio::spawn(async move {
            let mut stream: TlsClientStream<tokio::net::TcpStream> =
                connector.connect_tcp(addr, "localhost").await.unwrap();

            // Send test message
            let message = b"Hello, TLS!";
            stream.write_all(message).await.unwrap();
            stream.flush().await.unwrap();

            // Read response
            let mut response = [0u8; 32];
            let n = stream.read(&mut response).await.unwrap();

            (message.to_vec(), response[..n].to_vec())
        });

        // Wait for both tasks
        let (server_result, client_result) = tokio::join!(server_task, client_task);

        let server_bytes_read = server_result.unwrap();
        let (sent, received) = client_result.unwrap();

        // Verify echo worked
        assert_eq!(server_bytes_read, sent.len());
        assert_eq!(sent, received);
    }
1577
    /// End-to-end mutual-TLS test.
    ///
    /// A throwaway CA signs both a server certificate (`localhost`) and a
    /// client certificate. The server requires client certificates and trusts
    /// the CA; the client trusts the same CA. The test checks the echo
    /// round-trip and that the server saw the client's certificate.
    #[tokio::test]
    async fn test_mtls_server_client_handshake() {
        use rcgen::{BasicConstraints, CertificateParams, DnType, IsCa, KeyUsagePurpose};

        // Install crypto provider (required by rustls 0.23+).
        // The result is ignored: installation fails if a process-wide provider
        // is already set (presumably by another test in the same binary).
        let _ = rustls::crypto::ring::default_provider().install_default();

        // Generate a shared CA certificate
        let mut ca_params = CertificateParams::default();
        ca_params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
        ca_params.key_usages = vec![KeyUsagePurpose::KeyCertSign, KeyUsagePurpose::CrlSign];
        ca_params
            .distinguished_name
            .push(DnType::CommonName, "Rivven Test CA");
        let ca_key_pair = rcgen::KeyPair::generate().unwrap();
        let ca_cert = ca_params.self_signed(&ca_key_pair).unwrap();
        let ca_cert_pem = ca_cert.pem();

        // Generate server certificate signed by CA
        // ("localhost" must match the server name the client connects with below)
        let mut server_params = CertificateParams::new(vec!["localhost".to_string()]).unwrap();
        server_params
            .distinguished_name
            .push(DnType::CommonName, "localhost");
        let server_key_pair = rcgen::KeyPair::generate().unwrap();
        let server_cert = server_params
            .signed_by(&server_key_pair, &ca_cert, &ca_key_pair)
            .unwrap();
        let server_cert_pem = server_cert.pem();
        let server_key_pem = server_key_pair.serialize_pem();

        // Generate client certificate signed by CA
        let mut client_params =
            CertificateParams::new(vec!["client.rivven.local".to_string()]).unwrap();
        client_params
            .distinguished_name
            .push(DnType::CommonName, "client.rivven.local");
        let client_key_pair = rcgen::KeyPair::generate().unwrap();
        let client_cert = client_params
            .signed_by(&client_key_pair, &ca_cert, &ca_key_pair)
            .unwrap();
        let client_cert_pem = client_cert.pem();
        let client_key_pem = client_key_pair.serialize_pem();

        // Server config with mTLS required
        let server_config = TlsConfig {
            enabled: true,
            certificate: Some(CertificateSource::Pem {
                content: server_cert_pem,
            }),
            private_key: Some(PrivateKeySource::Pem {
                content: server_key_pem,
            }),
            client_ca: Some(CertificateSource::Pem {
                content: ca_cert_pem.clone(),
            }),
            mtls_mode: MtlsMode::Required,
            insecure_skip_verify: false,
            ..Default::default()
        };

        // Client config with client cert and CA trust
        let client_config = TlsConfig {
            enabled: true,
            certificate: Some(CertificateSource::Pem {
                content: client_cert_pem,
            }),
            private_key: Some(PrivateKeySource::Pem {
                content: client_key_pem,
            }),
            root_ca: Some(CertificateSource::Pem {
                content: ca_cert_pem,
            }),
            insecure_skip_verify: false,
            ..Default::default()
        };

        // Create acceptor and connector
        let acceptor = TlsAcceptor::new(&server_config).unwrap();
        let connector = TlsConnector::new(&client_config).unwrap();

        // Start a TCP listener on an OS-assigned port (":0") to avoid port clashes
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        // Server task
        let server_task = tokio::spawn(async move {
            let (tcp_stream, _) = listener.accept().await.unwrap();
            let mut tls_stream: TlsServerStream<tokio::net::TcpStream> =
                acceptor.accept_tcp(tcp_stream).await.unwrap();

            // Check if we can see peer certificates (mTLS)
            let has_peer_cert = tls_stream.peer_certificates().is_some();

            // Read data (a single read is assumed to return the whole short message)
            let mut buf = [0u8; 32];
            let n = tls_stream.read(&mut buf).await.unwrap();
            tls_stream.write_all(&buf[..n]).await.unwrap();
            tls_stream.flush().await.unwrap();

            (n, has_peer_cert)
        });

        // Client task
        let client_task = tokio::spawn(async move {
            let mut stream: TlsClientStream<tokio::net::TcpStream> =
                connector.connect_tcp(addr, "localhost").await.unwrap();

            // Send test message
            let message = b"mTLS Test!";
            stream.write_all(message).await.unwrap();
            stream.flush().await.unwrap();

            // Read response
            let mut response = [0u8; 32];
            let n = stream.read(&mut response).await.unwrap();

            (message.to_vec(), response[..n].to_vec())
        });

        // Wait for both tasks
        let (server_result, client_result) = tokio::join!(server_task, client_task);

        let (server_bytes_read, has_peer_cert) = server_result.unwrap();
        let (sent, received) = client_result.unwrap();

        // Verify echo worked
        assert_eq!(server_bytes_read, sent.len());
        assert_eq!(sent, received);

        // Verify mTLS - server saw client certificate
        assert!(
            has_peer_cert,
            "Server should have received client certificate in mTLS"
        );
    }
1713
1714    #[test]
1715    fn test_self_signed_generation() {
1716        let result = generate_self_signed("test.rivven.local");
1717        assert!(result.is_ok());
1718
1719        let (cert, _key) = result.unwrap();
1720        assert!(!cert.as_ref().is_empty());
1721
1722        // Verify we can extract identity
1723        let identity = TlsIdentity::from_certificate(&cert);
1724        assert_eq!(identity.common_name, Some("test.rivven.local".to_string()));
1725        assert!(identity.is_valid);
1726    }
1727
1728    #[test]
1729    fn test_certificate_fingerprint() {
1730        let (cert, _) = generate_self_signed("test.rivven.local").unwrap();
1731        let fingerprint = certificate_fingerprint(&cert);
1732
1733        // Should be 64 hex characters (SHA-256)
1734        assert_eq!(fingerprint.len(), 64);
1735        assert!(fingerprint.chars().all(|c| c.is_ascii_hexdigit()));
1736    }
1737
1738    #[test]
1739    fn test_tls_security_audit_disabled() {
1740        let config = TlsConfig::disabled();
1741        let audit = TlsSecurityAudit::audit(&config);
1742
1743        assert!(!audit.errors.is_empty());
1744        assert!(audit.errors.iter().any(|e| e.contains("disabled")));
1745    }
1746
1747    #[test]
1748    fn test_tls_security_audit_insecure() {
1749        let config = TlsConfig {
1750            enabled: true,
1751            insecure_skip_verify: true,
1752            ..Default::default()
1753        };
1754        let audit = TlsSecurityAudit::audit(&config);
1755
1756        assert!(audit.errors.iter().any(|e| e.contains("MITM")));
1757    }
1758
1759    #[test]
1760    fn test_tls_security_audit_production_ready() {
1761        let (_cert, _key) = generate_self_signed("broker.rivven.local").unwrap();
1762
1763        let config = TlsConfig {
1764            enabled: true,
1765            certificate: Some(CertificateSource::SelfSigned {
1766                common_name: "broker.rivven.local".to_string(),
1767            }),
1768            mtls_mode: MtlsMode::Required,
1769            min_version: TlsVersion::Tls13,
1770            insecure_skip_verify: false,
1771            session_cache_size: 256,
1772            ..Default::default()
1773        };
1774
1775        let audit = TlsSecurityAudit::audit(&config);
1776
1777        // Should have no errors for a well-configured setup
1778        // (Note: mTLS Required without client_ca would fail at runtime, but audit catches config issues)
1779        assert!(audit.errors.is_empty() || audit.errors.iter().all(|e| !e.contains("disabled")));
1780    }
1781
1782    #[test]
1783    fn test_mtls_modes() {
1784        assert_eq!(MtlsMode::default(), MtlsMode::Disabled);
1785
1786        let modes = [MtlsMode::Disabled, MtlsMode::Optional, MtlsMode::Required];
1787        for mode in modes {
1788            let json = serde_json::to_string(&mode).unwrap();
1789            let parsed: MtlsMode = serde_json::from_str(&json).unwrap();
1790            assert_eq!(mode, parsed);
1791        }
1792    }
1793
1794    #[test]
1795    fn test_tls_identity_extraction() {
1796        let (cert, _) = generate_self_signed("service.rivven.internal").unwrap();
1797        let identity = TlsIdentity::from_certificate(&cert);
1798
1799        assert_eq!(
1800            identity.common_name,
1801            Some("service.rivven.internal".to_string())
1802        );
1803        assert!(identity.is_valid);
1804        assert!(identity.valid_from.is_some());
1805        assert!(identity.valid_until.is_some());
1806        assert!(!identity.fingerprint.is_empty());
1807    }
1808}