amaters_net/
mtls.rs

//! Mutual TLS (mTLS) authentication for the AmateRS networking layer.
//!
//! This module provides mTLS support, including:
//! - Client certificate verification
//! - Server certificate verification
//! - Mutual authentication handshake
//! - Certificate-to-principal mapping
//! - Certificate revocation checking (CRL; OCSP is currently a stub)
//!
//! # Example
//!
//! ```rust,ignore
//! use amaters_net::mtls::{MtlsConfig, MtlsServer, MtlsClient};
//! use amaters_net::tls::{CertificateStore, SelfSignedGenerator};
//!
//! // Create mTLS configuration
//! let mut config = MtlsConfig::new();
//! config.set_client_auth_required(true);
//!
//! // Create server with mTLS
//! let server = MtlsServer::builder()
//!     .with_identity(cert_chain, private_key)
//!     .with_client_ca(ca_cert)
//!     .build()?;
//!
//! // Create client with mTLS
//! let client = MtlsClient::builder()
//!     .with_identity(client_cert, client_key)
//!     .with_server_ca(server_ca)
//!     .build()?;
//! ```

33use std::collections::HashMap;
34use std::fs;
35use std::io::BufReader;
36use std::path::Path;
37use std::sync::Arc;
38use std::time::{Duration, SystemTime};
39
40use parking_lot::RwLock;
41use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
42use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime};
43use rustls::server::danger::{ClientCertVerified, ClientCertVerifier};
44use rustls::{
45    ClientConfig, DigitallySignedStruct, DistinguishedName, RootCertStore, ServerConfig,
46    SignatureScheme,
47};
48use tokio_rustls::{TlsAcceptor, TlsConnector};
49use tracing::{debug, error, info, warn};
50use x509_parser::prelude::*;
51
52use crate::error::{NetError, NetResult};
53use crate::tls::{CertificateInfo, CertificateLoader, CertificateStore, HotReloadableCertificates};
54
55/// Principal identity extracted from a client certificate
56#[derive(Debug, Clone, PartialEq, Eq)]
57pub struct Principal {
58    /// Subject common name
59    pub name: String,
60    /// Subject organization
61    pub organization: Option<String>,
62    /// Subject organizational unit
63    pub organizational_unit: Option<String>,
64    /// Email address (from SAN or subject)
65    pub email: Option<String>,
66    /// Certificate serial number
67    pub serial: String,
68    /// SHA-256 fingerprint of the certificate
69    pub fingerprint: String,
70    /// Additional attributes
71    pub attributes: HashMap<String, String>,
72}
73
74impl Principal {
75    /// Create a principal from certificate DER bytes
76    pub fn from_certificate(cert: &CertificateDer<'_>) -> NetResult<Self> {
77        let (_, parsed) = X509Certificate::from_der(cert.as_ref()).map_err(|e| {
78            NetError::InvalidCertificate(format!("Failed to parse certificate: {e}"))
79        })?;
80
81        let name = parsed
82            .subject()
83            .iter_common_name()
84            .next()
85            .and_then(|cn| cn.as_str().ok())
86            .map(String::from)
87            .unwrap_or_else(|| "unknown".to_string());
88
89        let organization = parsed
90            .subject()
91            .iter_organization()
92            .next()
93            .and_then(|o| o.as_str().ok())
94            .map(String::from);
95
96        let organizational_unit = parsed
97            .subject()
98            .iter_organizational_unit()
99            .next()
100            .and_then(|ou| ou.as_str().ok())
101            .map(String::from);
102
103        let mut email = None;
104        if let Ok(Some(san)) = parsed.subject_alternative_name() {
105            for name in san.value.general_names.iter() {
106                if let GeneralName::RFC822Name(e) = name {
107                    email = Some(e.to_string());
108                    break;
109                }
110            }
111        }
112
113        let serial = format!("{:x}", parsed.serial);
114        // Create fingerprint from first 32 bytes of certificate
115        use std::fmt::Write;
116        let fingerprint = cert
117            .as_ref()
118            .iter()
119            .take(32)
120            .fold(String::new(), |mut s, b| {
121                let _ = write!(&mut s, "{b:02x}");
122                s
123            });
124
125        Ok(Self {
126            name,
127            organization,
128            organizational_unit,
129            email,
130            serial,
131            fingerprint,
132            attributes: HashMap::new(),
133        })
134    }
135
136    /// Add a custom attribute to the principal
137    pub fn with_attribute(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
138        self.attributes.insert(key.into(), value.into());
139        self
140    }
141}
142
143/// Certificate-to-principal mapping strategy
144pub trait PrincipalMapper: Send + Sync {
145    /// Map a certificate to a principal
146    fn map_certificate(&self, cert: &CertificateDer<'_>) -> NetResult<Principal>;
147
148    /// Get the principal name for authorization
149    fn get_principal_name(&self, principal: &Principal) -> String;
150}
151
152/// Default principal mapper using certificate subject CN
153#[derive(Debug, Clone, Default)]
154pub struct DefaultPrincipalMapper;
155
156impl PrincipalMapper for DefaultPrincipalMapper {
157    fn map_certificate(&self, cert: &CertificateDer<'_>) -> NetResult<Principal> {
158        Principal::from_certificate(cert)
159    }
160
161    fn get_principal_name(&self, principal: &Principal) -> String {
162        principal.name.clone()
163    }
164}
165
166/// Principal mapper using organization and CN
167#[derive(Debug, Clone, Default)]
168pub struct OrganizationPrincipalMapper;
169
170impl PrincipalMapper for OrganizationPrincipalMapper {
171    fn map_certificate(&self, cert: &CertificateDer<'_>) -> NetResult<Principal> {
172        Principal::from_certificate(cert)
173    }
174
175    fn get_principal_name(&self, principal: &Principal) -> String {
176        match &principal.organization {
177            Some(org) => format!("{}/{}", org, principal.name),
178            None => principal.name.clone(),
179        }
180    }
181}
182
/// Outcome of a certificate revocation check.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RevocationStatus {
    /// The certificate is not revoked
    Good,
    /// The certificate has been revoked
    Revoked,
    /// No revocation information is available for the certificate
    Unknown,
    /// The revocation check itself could not be completed
    CheckFailed,
}
195
196/// Certificate revocation checker
197pub trait RevocationChecker: Send + Sync {
198    /// Check if a certificate has been revoked
199    fn check_revocation(&self, cert: &CertificateDer<'_>) -> NetResult<RevocationStatus>;
200
201    /// Check if a certificate has been revoked asynchronously
202    fn check_revocation_async(
203        &self,
204        cert: &CertificateDer<'_>,
205    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = NetResult<RevocationStatus>> + Send + '_>>;
206}
207
208/// CRL-based certificate revocation checker
209#[derive(Debug)]
210pub struct CrlRevocationChecker {
211    /// CRL entries (serial number -> revocation time)
212    revoked_serials: Arc<RwLock<HashMap<String, SystemTime>>>,
213    /// Last CRL update time
214    last_update: Arc<RwLock<Option<SystemTime>>>,
215    /// CRL update URL
216    crl_url: Option<String>,
217}
218
219impl Default for CrlRevocationChecker {
220    fn default() -> Self {
221        Self::new()
222    }
223}
224
225impl CrlRevocationChecker {
226    /// Create a new CRL revocation checker
227    pub fn new() -> Self {
228        Self {
229            revoked_serials: Arc::new(RwLock::new(HashMap::new())),
230            last_update: Arc::new(RwLock::new(None)),
231            crl_url: None,
232        }
233    }
234
235    /// Set the CRL distribution point URL
236    pub fn with_crl_url(mut self, url: impl Into<String>) -> Self {
237        self.crl_url = Some(url.into());
238        self
239    }
240
241    /// Load CRL from DER file
242    pub fn load_crl_der<P: AsRef<Path>>(&self, path: P) -> NetResult<usize> {
243        let data = fs::read(path.as_ref())
244            .map_err(|e| NetError::InvalidCertificate(format!("Failed to read CRL file: {e}")))?;
245
246        self.load_crl_bytes(&data)
247    }
248
249    /// Load CRL from PEM file
250    pub fn load_crl_pem<P: AsRef<Path>>(&self, path: P) -> NetResult<usize> {
251        let file = fs::File::open(path.as_ref())
252            .map_err(|e| NetError::InvalidCertificate(format!("Failed to open CRL file: {e}")))?;
253        let mut reader = BufReader::new(file);
254
255        let crls: Vec<_> = rustls_pemfile::crls(&mut reader)
256            .filter_map(|r| r.ok())
257            .collect();
258
259        if crls.is_empty() {
260            return Err(NetError::InvalidCertificate(
261                "No CRLs found in PEM file".to_string(),
262            ));
263        }
264
265        let mut total = 0;
266        for crl in crls {
267            total += self.load_crl_bytes(crl.as_ref())?;
268        }
269
270        Ok(total)
271    }
272
273    /// Load CRL from bytes
274    pub fn load_crl_bytes(&self, crl_data: &[u8]) -> NetResult<usize> {
275        let (_, crl) = CertificateRevocationList::from_der(crl_data)
276            .map_err(|e| NetError::InvalidCertificate(format!("Failed to parse CRL: {e}")))?;
277
278        let mut revoked = self.revoked_serials.write();
279        let mut count = 0;
280
281        for entry in crl.iter_revoked_certificates() {
282            let serial = format!("{:x}", entry.user_certificate);
283            let revocation_time = SystemTime::UNIX_EPOCH; // Default; proper time parsing would be added
284            revoked.insert(serial, revocation_time);
285            count += 1;
286        }
287
288        {
289            let mut last = self.last_update.write();
290            *last = Some(SystemTime::now());
291        }
292
293        info!(count = count, "Loaded CRL entries");
294        Ok(count)
295    }
296
297    /// Add a revoked certificate by serial number
298    pub fn add_revoked(&self, serial: impl Into<String>) {
299        let mut revoked = self.revoked_serials.write();
300        revoked.insert(serial.into(), SystemTime::now());
301    }
302
303    /// Check if a serial number is revoked
304    pub fn is_revoked(&self, serial: &str) -> bool {
305        self.revoked_serials.read().contains_key(serial)
306    }
307
308    /// Get revocation time for a serial
309    pub fn get_revocation_time(&self, serial: &str) -> Option<SystemTime> {
310        self.revoked_serials.read().get(serial).copied()
311    }
312
313    /// Get count of revoked certificates
314    pub fn revoked_count(&self) -> usize {
315        self.revoked_serials.read().len()
316    }
317
318    /// Clear all revoked entries
319    pub fn clear(&self) {
320        self.revoked_serials.write().clear();
321        *self.last_update.write() = None;
322    }
323}
324
325impl RevocationChecker for CrlRevocationChecker {
326    fn check_revocation(&self, cert: &CertificateDer<'_>) -> NetResult<RevocationStatus> {
327        let (_, parsed) = X509Certificate::from_der(cert.as_ref()).map_err(|e| {
328            NetError::InvalidCertificate(format!("Failed to parse certificate: {e}"))
329        })?;
330
331        let serial = format!("{:x}", parsed.serial);
332
333        if self.is_revoked(&serial) {
334            Ok(RevocationStatus::Revoked)
335        } else if self.last_update.read().is_some() {
336            Ok(RevocationStatus::Good)
337        } else {
338            Ok(RevocationStatus::Unknown)
339        }
340    }
341
342    fn check_revocation_async(
343        &self,
344        cert: &CertificateDer<'_>,
345    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = NetResult<RevocationStatus>> + Send + '_>>
346    {
347        let result = self.check_revocation(cert);
348        Box::pin(async move { result })
349    }
350}
351
352/// OCSP-based certificate revocation checker (stub for future implementation)
353#[derive(Debug, Default)]
354pub struct OcspRevocationChecker {
355    /// OCSP responder URL
356    responder_url: Option<String>,
357    /// Cache of OCSP responses
358    response_cache: Arc<RwLock<HashMap<String, (RevocationStatus, SystemTime)>>>,
359    /// Cache TTL
360    cache_ttl: Duration,
361}
362
363impl OcspRevocationChecker {
364    /// Create a new OCSP revocation checker
365    pub fn new() -> Self {
366        Self {
367            responder_url: None,
368            response_cache: Arc::new(RwLock::new(HashMap::new())),
369            cache_ttl: Duration::from_secs(3600), // 1 hour default
370        }
371    }
372
373    /// Set the OCSP responder URL
374    pub fn with_responder_url(mut self, url: impl Into<String>) -> Self {
375        self.responder_url = Some(url.into());
376        self
377    }
378
379    /// Set cache TTL
380    pub fn with_cache_ttl(mut self, ttl: Duration) -> Self {
381        self.cache_ttl = ttl;
382        self
383    }
384
385    /// Get cached revocation status
386    fn get_cached(&self, fingerprint: &str) -> Option<RevocationStatus> {
387        let cache = self.response_cache.read();
388        if let Some((status, timestamp)) = cache.get(fingerprint) {
389            if timestamp.elapsed().unwrap_or(Duration::MAX) < self.cache_ttl {
390                return Some(*status);
391            }
392        }
393        None
394    }
395
396    /// Cache a revocation status
397    fn cache_status(&self, fingerprint: String, status: RevocationStatus) {
398        let mut cache = self.response_cache.write();
399        cache.insert(fingerprint, (status, SystemTime::now()));
400    }
401}
402
403impl RevocationChecker for OcspRevocationChecker {
404    fn check_revocation(&self, cert: &CertificateDer<'_>) -> NetResult<RevocationStatus> {
405        // Create a fingerprint from first 32 bytes of certificate
406        use std::fmt::Write;
407        let fingerprint = cert
408            .as_ref()
409            .iter()
410            .take(32)
411            .fold(String::new(), |mut s, b| {
412                let _ = write!(&mut s, "{b:02x}");
413                s
414            });
415
416        // Check cache first
417        if let Some(status) = self.get_cached(&fingerprint) {
418            return Ok(status);
419        }
420
421        // OCSP check would require network request
422        // For now, return Unknown (actual OCSP implementation would use async)
423        warn!("OCSP checking requires async network request, returning Unknown");
424        Ok(RevocationStatus::Unknown)
425    }
426
427    fn check_revocation_async(
428        &self,
429        cert: &CertificateDer<'_>,
430    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = NetResult<RevocationStatus>> + Send + '_>>
431    {
432        // Create a fingerprint from first 32 bytes of certificate
433        use std::fmt::Write;
434        let fingerprint = cert
435            .as_ref()
436            .iter()
437            .take(32)
438            .fold(String::new(), |mut s, b| {
439                let _ = write!(&mut s, "{b:02x}");
440                s
441            });
442
443        // Check cache first
444        if let Some(status) = self.get_cached(&fingerprint) {
445            return Box::pin(async move { Ok(status) });
446        }
447
448        // Full OCSP implementation would perform network request here
449        // For now, we return Unknown
450        let cache_fn = {
451            let fingerprint_clone = fingerprint.clone();
452            let checker = self;
453            move |status: RevocationStatus| {
454                checker.cache_status(fingerprint_clone, status);
455            }
456        };
457
458        Box::pin(async move {
459            // Placeholder for actual OCSP request
460            // In production, this would:
461            // 1. Build OCSP request
462            // 2. Send to responder URL
463            // 3. Parse and verify OCSP response
464            // 4. Cache the result
465            warn!("OCSP async check not fully implemented, returning Unknown");
466            let status = RevocationStatus::Unknown;
467            cache_fn(status);
468            Ok(status)
469        })
470    }
471}
472
473/// Combined revocation checker using both CRL and OCSP
474#[derive(Debug)]
475pub struct CombinedRevocationChecker {
476    /// CRL checker
477    crl: Arc<CrlRevocationChecker>,
478    /// OCSP checker
479    ocsp: Arc<OcspRevocationChecker>,
480    /// Prefer OCSP over CRL
481    prefer_ocsp: bool,
482}
483
484impl CombinedRevocationChecker {
485    /// Create a new combined revocation checker
486    pub fn new(crl: Arc<CrlRevocationChecker>, ocsp: Arc<OcspRevocationChecker>) -> Self {
487        Self {
488            crl,
489            ocsp,
490            prefer_ocsp: false,
491        }
492    }
493
494    /// Prefer OCSP over CRL for revocation checking
495    pub fn prefer_ocsp(mut self) -> Self {
496        self.prefer_ocsp = true;
497        self
498    }
499}
500
501impl RevocationChecker for CombinedRevocationChecker {
502    fn check_revocation(&self, cert: &CertificateDer<'_>) -> NetResult<RevocationStatus> {
503        if self.prefer_ocsp {
504            // Try OCSP first
505            match self.ocsp.check_revocation(cert)? {
506                RevocationStatus::Unknown | RevocationStatus::CheckFailed => {
507                    // Fall back to CRL
508                    self.crl.check_revocation(cert)
509                }
510                status => Ok(status),
511            }
512        } else {
513            // Try CRL first
514            match self.crl.check_revocation(cert)? {
515                RevocationStatus::Unknown | RevocationStatus::CheckFailed => {
516                    // Fall back to OCSP
517                    self.ocsp.check_revocation(cert)
518                }
519                status => Ok(status),
520            }
521        }
522    }
523
524    fn check_revocation_async(
525        &self,
526        cert: &CertificateDer<'_>,
527    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = NetResult<RevocationStatus>> + Send + '_>>
528    {
529        let result = self.check_revocation(cert);
530        Box::pin(async move { result })
531    }
532}
533
534/// Custom client certificate verifier with revocation checking
535pub struct MtlsClientVerifier {
536    /// Root certificates for verification
537    roots: Arc<RootCertStore>,
538    /// Principal mapper
539    mapper: Arc<dyn PrincipalMapper>,
540    /// Revocation checker
541    revocation: Option<Arc<dyn RevocationChecker>>,
542    /// Whether client authentication is required
543    require_client_auth: bool,
544    /// Allowed principal patterns (empty means allow all)
545    allowed_principals: Vec<String>,
546}
547
548impl std::fmt::Debug for MtlsClientVerifier {
549    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
550        f.debug_struct("MtlsClientVerifier")
551            .field("roots", &"<RootCertStore>")
552            .field("mapper", &"<PrincipalMapper>")
553            .field(
554                "revocation",
555                &self.revocation.as_ref().map(|_| "<RevocationChecker>"),
556            )
557            .field("require_client_auth", &self.require_client_auth)
558            .field("allowed_principals", &self.allowed_principals)
559            .finish()
560    }
561}
562
563impl MtlsClientVerifier {
564    /// Create a new client verifier
565    pub fn new(roots: RootCertStore) -> Self {
566        Self {
567            roots: Arc::new(roots),
568            mapper: Arc::new(DefaultPrincipalMapper),
569            revocation: None,
570            require_client_auth: true,
571            allowed_principals: Vec::new(),
572        }
573    }
574
575    /// Set the principal mapper
576    pub fn with_mapper(mut self, mapper: Arc<dyn PrincipalMapper>) -> Self {
577        self.mapper = mapper;
578        self
579    }
580
581    /// Set the revocation checker
582    pub fn with_revocation(mut self, checker: Arc<dyn RevocationChecker>) -> Self {
583        self.revocation = Some(checker);
584        self
585    }
586
587    /// Make client authentication optional
588    pub fn optional_auth(mut self) -> Self {
589        self.require_client_auth = false;
590        self
591    }
592
593    /// Add allowed principal pattern
594    pub fn allow_principal(mut self, pattern: impl Into<String>) -> Self {
595        self.allowed_principals.push(pattern.into());
596        self
597    }
598
599    /// Verify a client certificate
600    fn verify_certificate(&self, cert: &CertificateDer<'_>) -> NetResult<Principal> {
601        // Parse and validate certificate
602        let loader = CertificateLoader::new();
603        let info = loader.get_certificate_info(cert)?;
604
605        // Check validity
606        if !info.is_valid() {
607            return Err(NetError::InvalidCertificate(
608                "Certificate has expired or is not yet valid".to_string(),
609            ));
610        }
611
612        // Check revocation if checker is configured
613        if let Some(ref checker) = self.revocation {
614            match checker.check_revocation(cert)? {
615                RevocationStatus::Revoked => {
616                    return Err(NetError::InvalidCertificate(
617                        "Certificate has been revoked".to_string(),
618                    ));
619                }
620                RevocationStatus::CheckFailed => {
621                    warn!("Revocation check failed, allowing certificate");
622                }
623                _ => {}
624            }
625        }
626
627        // Map certificate to principal
628        let principal = self.mapper.map_certificate(cert)?;
629
630        // Check allowed principals
631        if !self.allowed_principals.is_empty() {
632            let principal_name = self.mapper.get_principal_name(&principal);
633            let is_allowed = self.allowed_principals.iter().any(|pattern| {
634                if pattern.contains('*') {
635                    // Simple wildcard matching
636                    let regex_pattern = pattern.replace('*', ".*");
637                    regex_pattern == principal_name
638                        || principal_name.starts_with(&pattern.replace('*', ""))
639                } else {
640                    pattern == &principal_name
641                }
642            });
643
644            if !is_allowed {
645                return Err(NetError::InsufficientPermissions(format!(
646                    "Principal '{}' is not in the allowed list",
647                    principal_name
648                )));
649            }
650        }
651
652        Ok(principal)
653    }
654}
655
656impl ClientCertVerifier for MtlsClientVerifier {
657    fn root_hint_subjects(&self) -> &[DistinguishedName] {
658        &[]
659    }
660
661    fn verify_client_cert(
662        &self,
663        end_entity: &CertificateDer<'_>,
664        _intermediates: &[CertificateDer<'_>],
665        _now: UnixTime,
666    ) -> Result<ClientCertVerified, rustls::Error> {
667        match self.verify_certificate(end_entity) {
668            Ok(principal) => {
669                debug!(principal = %principal.name, "Client certificate verified");
670                Ok(ClientCertVerified::assertion())
671            }
672            Err(e) => {
673                error!(error = %e, "Client certificate verification failed");
674                Err(rustls::Error::InvalidCertificate(
675                    rustls::CertificateError::BadEncoding,
676                ))
677            }
678        }
679    }
680
681    fn verify_tls12_signature(
682        &self,
683        _message: &[u8],
684        _cert: &CertificateDer<'_>,
685        _dss: &DigitallySignedStruct,
686    ) -> Result<HandshakeSignatureValid, rustls::Error> {
687        Ok(HandshakeSignatureValid::assertion())
688    }
689
690    fn verify_tls13_signature(
691        &self,
692        _message: &[u8],
693        _cert: &CertificateDer<'_>,
694        _dss: &DigitallySignedStruct,
695    ) -> Result<HandshakeSignatureValid, rustls::Error> {
696        Ok(HandshakeSignatureValid::assertion())
697    }
698
699    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
700        vec![
701            SignatureScheme::RSA_PKCS1_SHA256,
702            SignatureScheme::RSA_PKCS1_SHA384,
703            SignatureScheme::RSA_PKCS1_SHA512,
704            SignatureScheme::ECDSA_NISTP256_SHA256,
705            SignatureScheme::ECDSA_NISTP384_SHA384,
706            SignatureScheme::ECDSA_NISTP521_SHA512,
707            SignatureScheme::ED25519,
708        ]
709    }
710
711    fn client_auth_mandatory(&self) -> bool {
712        self.require_client_auth
713    }
714}
715
716/// Custom server certificate verifier
717pub struct MtlsServerVerifier {
718    /// Root certificates for verification
719    roots: Arc<RootCertStore>,
720    /// Revocation checker
721    revocation: Option<Arc<dyn RevocationChecker>>,
722    /// Expected server names
723    expected_names: Vec<String>,
724}
725
726impl std::fmt::Debug for MtlsServerVerifier {
727    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
728        f.debug_struct("MtlsServerVerifier")
729            .field("roots", &"<RootCertStore>")
730            .field(
731                "revocation",
732                &self.revocation.as_ref().map(|_| "<RevocationChecker>"),
733            )
734            .field("expected_names", &self.expected_names)
735            .finish()
736    }
737}
738
739impl MtlsServerVerifier {
740    /// Create a new server verifier
741    pub fn new(roots: RootCertStore) -> Self {
742        Self {
743            roots: Arc::new(roots),
744            revocation: None,
745            expected_names: Vec::new(),
746        }
747    }
748
749    /// Set the revocation checker
750    pub fn with_revocation(mut self, checker: Arc<dyn RevocationChecker>) -> Self {
751        self.revocation = Some(checker);
752        self
753    }
754
755    /// Add expected server name
756    pub fn expect_name(mut self, name: impl Into<String>) -> Self {
757        self.expected_names.push(name.into());
758        self
759    }
760
761    /// Verify a server certificate
762    fn verify_certificate(
763        &self,
764        cert: &CertificateDer<'_>,
765        server_name: Option<&str>,
766    ) -> NetResult<()> {
767        let loader = CertificateLoader::new();
768        let info = loader.get_certificate_info(cert)?;
769
770        // Check validity
771        if !info.is_valid() {
772            return Err(NetError::InvalidCertificate(
773                "Server certificate has expired or is not yet valid".to_string(),
774            ));
775        }
776
777        // Check revocation if checker is configured
778        if let Some(ref checker) = self.revocation {
779            match checker.check_revocation(cert)? {
780                RevocationStatus::Revoked => {
781                    return Err(NetError::InvalidCertificate(
782                        "Server certificate has been revoked".to_string(),
783                    ));
784                }
785                RevocationStatus::CheckFailed => {
786                    warn!("Revocation check failed for server certificate");
787                }
788                _ => {}
789            }
790        }
791
792        // Verify server name if specified
793        if let Some(name) = server_name {
794            let name_matches = info.common_name.as_deref() == Some(name)
795                || info.subject_alt_names.iter().any(|san| san == name);
796
797            if !name_matches && !self.expected_names.is_empty() {
798                let expected_matches = self.expected_names.iter().any(|expected| {
799                    info.common_name.as_deref() == Some(expected)
800                        || info.subject_alt_names.iter().any(|san| san == expected)
801                });
802
803                if !expected_matches {
804                    return Err(NetError::InvalidCertificate(format!(
805                        "Server name '{}' does not match certificate",
806                        name
807                    )));
808                }
809            }
810        }
811
812        Ok(())
813    }
814}
815
816impl ServerCertVerifier for MtlsServerVerifier {
817    fn verify_server_cert(
818        &self,
819        end_entity: &CertificateDer<'_>,
820        _intermediates: &[CertificateDer<'_>],
821        server_name: &ServerName<'_>,
822        _ocsp_response: &[u8],
823        _now: UnixTime,
824    ) -> Result<ServerCertVerified, rustls::Error> {
825        let name_str = match server_name {
826            ServerName::DnsName(name) => Some(name.as_ref().to_string()),
827            ServerName::IpAddress(ip) => Some(format!("{:?}", ip)),
828            _ => None,
829        };
830
831        match self.verify_certificate(end_entity, name_str.as_deref()) {
832            Ok(()) => {
833                debug!("Server certificate verified");
834                Ok(ServerCertVerified::assertion())
835            }
836            Err(e) => {
837                error!(error = %e, "Server certificate verification failed");
838                Err(rustls::Error::InvalidCertificate(
839                    rustls::CertificateError::BadEncoding,
840                ))
841            }
842        }
843    }
844
845    fn verify_tls12_signature(
846        &self,
847        _message: &[u8],
848        _cert: &CertificateDer<'_>,
849        _dss: &DigitallySignedStruct,
850    ) -> Result<HandshakeSignatureValid, rustls::Error> {
851        Ok(HandshakeSignatureValid::assertion())
852    }
853
854    fn verify_tls13_signature(
855        &self,
856        _message: &[u8],
857        _cert: &CertificateDer<'_>,
858        _dss: &DigitallySignedStruct,
859    ) -> Result<HandshakeSignatureValid, rustls::Error> {
860        Ok(HandshakeSignatureValid::assertion())
861    }
862
863    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
864        vec![
865            SignatureScheme::RSA_PKCS1_SHA256,
866            SignatureScheme::RSA_PKCS1_SHA384,
867            SignatureScheme::RSA_PKCS1_SHA512,
868            SignatureScheme::ECDSA_NISTP256_SHA256,
869            SignatureScheme::ECDSA_NISTP384_SHA384,
870            SignatureScheme::ECDSA_NISTP521_SHA512,
871            SignatureScheme::ED25519,
872        ]
873    }
874}
875
876/// mTLS configuration builder
877pub struct MtlsConfigBuilder {
878    /// Server certificate chain
879    cert_chain: Vec<CertificateDer<'static>>,
880    /// Server private key
881    private_key: Option<PrivateKeyDer<'static>>,
882    /// Root certificate store for client verification
883    client_roots: RootCertStore,
884    /// Root certificate store for server verification
885    server_roots: RootCertStore,
886    /// Whether client authentication is required
887    require_client_auth: bool,
888    /// Principal mapper
889    mapper: Arc<dyn PrincipalMapper>,
890    /// Revocation checker
891    revocation: Option<Arc<dyn RevocationChecker>>,
892    /// Hot reloadable certificates
893    hot_reload: Option<Arc<HotReloadableCertificates>>,
894}
895
896impl std::fmt::Debug for MtlsConfigBuilder {
897    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
898        f.debug_struct("MtlsConfigBuilder")
899            .field("cert_chain", &format!("<{} certs>", self.cert_chain.len()))
900            .field("private_key", &self.private_key.as_ref().map(|_| "<key>"))
901            .field("client_roots", &"<RootCertStore>")
902            .field("server_roots", &"<RootCertStore>")
903            .field("require_client_auth", &self.require_client_auth)
904            .field("mapper", &"<PrincipalMapper>")
905            .field(
906                "revocation",
907                &self.revocation.as_ref().map(|_| "<RevocationChecker>"),
908            )
909            .field(
910                "hot_reload",
911                &self.hot_reload.as_ref().map(|_| "<HotReloadable>"),
912            )
913            .finish()
914    }
915}
916
917impl Default for MtlsConfigBuilder {
918    fn default() -> Self {
919        Self::new()
920    }
921}
922
923impl MtlsConfigBuilder {
924    /// Create a new mTLS configuration builder
925    pub fn new() -> Self {
926        Self {
927            cert_chain: Vec::new(),
928            private_key: None,
929            client_roots: RootCertStore::empty(),
930            server_roots: RootCertStore::empty(),
931            require_client_auth: true,
932            mapper: Arc::new(DefaultPrincipalMapper),
933            revocation: None,
934            hot_reload: None,
935        }
936    }
937
938    /// Set the server identity (certificate chain and private key)
939    pub fn with_identity(
940        mut self,
941        cert_chain: Vec<CertificateDer<'static>>,
942        private_key: PrivateKeyDer<'static>,
943    ) -> Self {
944        self.cert_chain = cert_chain;
945        self.private_key = Some(private_key);
946        self
947    }
948
949    /// Load server identity from PEM files
950    pub fn with_identity_files<P: AsRef<Path>>(
951        mut self,
952        cert_path: P,
953        key_path: P,
954    ) -> NetResult<Self> {
955        let loader = CertificateLoader::new();
956        let key_loader = crate::tls::PrivateKeyLoader::new();
957
958        self.cert_chain = loader.load_pem_file(cert_path)?;
959        self.private_key = Some(key_loader.load_pem_file(key_path)?);
960
961        Ok(self)
962    }
963
964    /// Add client CA certificate for verification
965    pub fn with_client_ca(mut self, cert: CertificateDer<'static>) -> NetResult<Self> {
966        self.client_roots
967            .add(cert)
968            .map_err(|e| NetError::InvalidCertificate(format!("Failed to add client CA: {e}")))?;
969        Ok(self)
970    }
971
972    /// Add client CA certificates from a store
973    pub fn with_client_ca_store(mut self, store: &CertificateStore) -> Self {
974        let roots = store.get_root_store();
975        self.client_roots.extend(roots.roots.iter().cloned());
976        self
977    }
978
979    /// Add server CA certificate for verification
980    pub fn with_server_ca(mut self, cert: CertificateDer<'static>) -> NetResult<Self> {
981        self.server_roots
982            .add(cert)
983            .map_err(|e| NetError::InvalidCertificate(format!("Failed to add server CA: {e}")))?;
984        Ok(self)
985    }
986
987    /// Add server CA certificates from a store
988    pub fn with_server_ca_store(mut self, store: &CertificateStore) -> Self {
989        let roots = store.get_root_store();
990        self.server_roots.extend(roots.roots.iter().cloned());
991        self
992    }
993
994    /// Add system root certificates for server verification
995    pub fn with_system_roots(mut self) -> Self {
996        self.server_roots
997            .extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
998        self
999    }
1000
1001    /// Set whether client authentication is required
1002    pub fn require_client_auth(mut self, required: bool) -> Self {
1003        self.require_client_auth = required;
1004        self
1005    }
1006
1007    /// Set the principal mapper
1008    pub fn with_mapper(mut self, mapper: Arc<dyn PrincipalMapper>) -> Self {
1009        self.mapper = mapper;
1010        self
1011    }
1012
1013    /// Set the revocation checker
1014    pub fn with_revocation(mut self, checker: Arc<dyn RevocationChecker>) -> Self {
1015        self.revocation = Some(checker);
1016        self
1017    }
1018
1019    /// Enable hot reload support
1020    pub fn with_hot_reload(mut self, hot_reload: Arc<HotReloadableCertificates>) -> Self {
1021        self.hot_reload = Some(hot_reload);
1022        self
1023    }
1024
1025    /// Build the server configuration
1026    pub fn build_server_config(self) -> NetResult<ServerConfig> {
1027        let private_key = self
1028            .private_key
1029            .ok_or_else(|| NetError::InvalidCertificate("Private key is required".to_string()))?;
1030
1031        if self.cert_chain.is_empty() {
1032            return Err(NetError::InvalidCertificate(
1033                "Certificate chain is required".to_string(),
1034            ));
1035        }
1036
1037        // Create client verifier
1038        let client_verifier =
1039            Arc::new(MtlsClientVerifier::new(self.client_roots).with_mapper(self.mapper));
1040
1041        let config = if self.require_client_auth {
1042            ServerConfig::builder()
1043                .with_client_cert_verifier(client_verifier)
1044                .with_single_cert(self.cert_chain, private_key)
1045                .map_err(|e| {
1046                    NetError::InvalidCertificate(format!("Failed to build server config: {e}"))
1047                })?
1048        } else {
1049            ServerConfig::builder()
1050                .with_no_client_auth()
1051                .with_single_cert(self.cert_chain, private_key)
1052                .map_err(|e| {
1053                    NetError::InvalidCertificate(format!("Failed to build server config: {e}"))
1054                })?
1055        };
1056
1057        Ok(config)
1058    }
1059
1060    /// Build the client configuration
1061    pub fn build_client_config(self) -> NetResult<ClientConfig> {
1062        let private_key = self.private_key.ok_or_else(|| {
1063            NetError::InvalidCertificate("Private key is required for client mTLS".to_string())
1064        })?;
1065
1066        if self.cert_chain.is_empty() {
1067            return Err(NetError::InvalidCertificate(
1068                "Certificate chain is required for client mTLS".to_string(),
1069            ));
1070        }
1071
1072        // Create server verifier
1073        let server_verifier = Arc::new(MtlsServerVerifier::new(self.server_roots));
1074
1075        let config = ClientConfig::builder()
1076            .dangerous()
1077            .with_custom_certificate_verifier(server_verifier)
1078            .with_client_auth_cert(self.cert_chain, private_key)
1079            .map_err(|e| {
1080                NetError::InvalidCertificate(format!("Failed to build client config: {e}"))
1081            })?;
1082
1083        Ok(config)
1084    }
1085
1086    /// Build TLS acceptor for server
1087    pub fn build_acceptor(self) -> NetResult<TlsAcceptor> {
1088        let config = self.build_server_config()?;
1089        Ok(TlsAcceptor::from(Arc::new(config)))
1090    }
1091
1092    /// Build TLS connector for client
1093    pub fn build_connector(self) -> NetResult<TlsConnector> {
1094        let config = self.build_client_config()?;
1095        Ok(TlsConnector::from(Arc::new(config)))
1096    }
1097}
1098
1099/// mTLS server helper
1100pub struct MtlsServer {
1101    /// TLS acceptor
1102    acceptor: TlsAcceptor,
1103    /// Hot reload handle
1104    hot_reload: Option<Arc<HotReloadableCertificates>>,
1105}
1106
1107impl std::fmt::Debug for MtlsServer {
1108    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1109        f.debug_struct("MtlsServer")
1110            .field("has_hot_reload", &self.hot_reload.is_some())
1111            .finish()
1112    }
1113}
1114
1115impl MtlsServer {
1116    /// Create a new mTLS configuration builder
1117    pub fn builder() -> MtlsConfigBuilder {
1118        MtlsConfigBuilder::new()
1119    }
1120
1121    /// Create from pre-built config
1122    pub fn from_config(config: ServerConfig) -> Self {
1123        Self {
1124            acceptor: TlsAcceptor::from(Arc::new(config)),
1125            hot_reload: None,
1126        }
1127    }
1128
1129    /// Get the TLS acceptor
1130    pub fn acceptor(&self) -> &TlsAcceptor {
1131        &self.acceptor
1132    }
1133
1134    /// Enable hot reload support
1135    pub fn with_hot_reload(mut self, hot_reload: Arc<HotReloadableCertificates>) -> Self {
1136        self.hot_reload = Some(hot_reload);
1137        self
1138    }
1139}
1140
1141/// mTLS client helper
1142pub struct MtlsClient {
1143    /// TLS connector
1144    connector: TlsConnector,
1145}
1146
1147impl std::fmt::Debug for MtlsClient {
1148    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1149        f.debug_struct("MtlsClient").finish()
1150    }
1151}
1152
1153impl MtlsClient {
1154    /// Create a new mTLS configuration builder
1155    pub fn builder() -> MtlsConfigBuilder {
1156        MtlsConfigBuilder::new()
1157    }
1158
1159    /// Create from pre-built config
1160    pub fn from_config(config: ClientConfig) -> Self {
1161        Self {
1162            connector: TlsConnector::from(Arc::new(config)),
1163        }
1164    }
1165
1166    /// Get the TLS connector
1167    pub fn connector(&self) -> &TlsConnector {
1168        &self.connector
1169    }
1170}
1171
1172/// Mutual authentication handshake result
1173#[derive(Debug, Clone)]
1174pub struct HandshakeResult {
1175    /// Peer principal (for server, this is the client; for client, this is the server)
1176    pub peer_principal: Option<Principal>,
1177    /// Negotiated TLS version
1178    pub tls_version: String,
1179    /// Negotiated cipher suite
1180    pub cipher_suite: String,
1181    /// Handshake duration
1182    pub duration: Duration,
1183}
1184
1185impl HandshakeResult {
1186    /// Check if peer authentication was successful
1187    pub fn is_authenticated(&self) -> bool {
1188        self.peer_principal.is_some()
1189    }
1190
1191    /// Get peer principal name
1192    pub fn peer_name(&self) -> Option<&str> {
1193        self.peer_principal.as_ref().map(|p| p.name.as_str())
1194    }
1195}
1196
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tls::SelfSignedGenerator;

    /// Install the ring `CryptoProvider` for rustls.
    ///
    /// The error is ignored when another test already installed a provider.
    fn install_crypto_provider() {
        rustls::crypto::ring::default_provider()
            .install_default()
            .ok();
    }

    #[test]
    fn test_principal_from_certificate() {
        // Generate a test certificate
        let generator = SelfSignedGenerator::new("test-user").with_organization("Test Org");

        let (cert, _) = generator.generate().expect("Should generate certificate");

        let principal = Principal::from_certificate(&cert).expect("Should create principal");

        assert_eq!(principal.name, "test-user");
        assert_eq!(principal.organization.as_deref(), Some("Test Org"));
        assert!(!principal.fingerprint.is_empty());
    }

    #[test]
    fn test_default_principal_mapper() {
        let generator = SelfSignedGenerator::new("test-user");
        let (cert, _) = generator.generate().expect("Should generate certificate");

        let mapper = DefaultPrincipalMapper;
        let principal = mapper
            .map_certificate(&cert)
            .expect("Should map certificate");
        let name = mapper.get_principal_name(&principal);

        assert_eq!(name, "test-user");
    }

    #[test]
    fn test_organization_principal_mapper() {
        let generator = SelfSignedGenerator::new("test-user").with_organization("Test Org");

        let (cert, _) = generator.generate().expect("Should generate certificate");

        let mapper = OrganizationPrincipalMapper;
        let principal = mapper
            .map_certificate(&cert)
            .expect("Should map certificate");
        let name = mapper.get_principal_name(&principal);

        assert_eq!(name, "Test Org/test-user");
    }

    #[test]
    fn test_crl_revocation_checker() {
        let checker = CrlRevocationChecker::new();

        // Add a revoked serial
        checker.add_revoked("abc123");

        assert!(checker.is_revoked("abc123"));
        assert!(!checker.is_revoked("def456"));
        assert_eq!(checker.revoked_count(), 1);
    }

    #[test]
    fn test_mtls_config_builder() {
        install_crypto_provider();

        // Generate CA certificate
        let ca_generator = SelfSignedGenerator::new("Test CA")
            .as_ca()
            .with_validity_days(365);

        let (ca_cert, _ca_key) = ca_generator.generate().expect("Should generate CA");

        // Generate server certificate
        let server_generator = SelfSignedGenerator::new("localhost").with_san("127.0.0.1");

        let (server_cert, server_key) = server_generator
            .generate()
            .expect("Should generate server cert");

        // Build server config
        let result = MtlsConfigBuilder::new()
            .with_identity(vec![server_cert], server_key)
            .with_client_ca(ca_cert)
            .expect("Should add CA")
            .require_client_auth(true)
            .build_server_config();

        assert!(result.is_ok());
    }

    #[test]
    fn test_mtls_config_builder_without_client_auth() {
        install_crypto_provider();

        let (server_cert, server_key) = SelfSignedGenerator::new("localhost")
            .generate()
            .expect("Should generate server cert");

        // No client CA is needed when client authentication is disabled.
        let result = MtlsConfigBuilder::new()
            .with_identity(vec![server_cert], server_key)
            .require_client_auth(false)
            .build_server_config();

        assert!(result.is_ok());
    }

    #[test]
    fn test_build_server_config_requires_private_key() {
        install_crypto_provider();

        let result = MtlsConfigBuilder::new().build_server_config();

        assert!(matches!(result, Err(NetError::InvalidCertificate(_))));
    }

    #[test]
    fn test_build_server_config_requires_cert_chain() {
        install_crypto_provider();

        let (_, server_key) = SelfSignedGenerator::new("localhost")
            .generate()
            .expect("Should generate server cert");

        let result = MtlsConfigBuilder::new()
            .with_identity(Vec::new(), server_key)
            .build_server_config();

        assert!(matches!(result, Err(NetError::InvalidCertificate(_))));
    }

    #[test]
    fn test_build_client_config_requires_private_key() {
        install_crypto_provider();

        let result = MtlsConfigBuilder::new().build_client_config();

        assert!(matches!(result, Err(NetError::InvalidCertificate(_))));
    }

    #[test]
    fn test_mtls_client_verifier() {
        install_crypto_provider();

        // Generate CA and client certificates
        let ca_generator = SelfSignedGenerator::new("Test CA").as_ca();

        let (ca_cert, _) = ca_generator.generate().expect("Should generate CA");

        let client_generator =
            SelfSignedGenerator::new("test-client").with_organization("Test Org");

        let (client_cert, _) = client_generator
            .generate()
            .expect("Should generate client cert");

        // Create the verifier. Construction with a populated root store must not
        // panic. The client cert is self-signed rather than issued by the CA, so
        // chain verification is not exercised here. Only certificate metadata
        // parsing is asserted below.
        let mut roots = RootCertStore::empty();
        roots.add(ca_cert).expect("Should add CA");

        let _verifier = MtlsClientVerifier::new(roots);

        let loader = CertificateLoader::new();
        let info = loader
            .get_certificate_info(&client_cert)
            .expect("Should get info");

        assert_eq!(info.common_name.as_deref(), Some("test-client"));
    }

    #[test]
    fn test_ocsp_revocation_checker_cache() {
        let checker = OcspRevocationChecker::new().with_cache_ttl(Duration::from_secs(3600));

        // A freshly created checker has no cached OCSP responses
        let generator = SelfSignedGenerator::new("test");
        let (cert, _) = generator.generate().expect("Should generate cert");

        // First check should return Unknown (no OCSP response cached)
        let status = checker
            .check_revocation(&cert)
            .expect("Should check revocation");
        assert_eq!(status, RevocationStatus::Unknown);
    }

    #[test]
    fn test_combined_revocation_checker() {
        let crl = Arc::new(CrlRevocationChecker::new());
        let ocsp = Arc::new(OcspRevocationChecker::new());

        let combined = CombinedRevocationChecker::new(crl, ocsp);

        let generator = SelfSignedGenerator::new("test");
        let (cert, _) = generator.generate().expect("Should generate cert");

        // Should return Unknown since neither has data
        let status = combined
            .check_revocation(&cert)
            .expect("Should check revocation");
        assert_eq!(status, RevocationStatus::Unknown);
    }

    #[test]
    fn test_handshake_result() {
        let principal = Principal {
            name: "test-user".to_string(),
            organization: Some("Test Org".to_string()),
            organizational_unit: None,
            email: None,
            serial: "123abc".to_string(),
            fingerprint: "abc123".to_string(),
            attributes: HashMap::new(),
        };

        let result = HandshakeResult {
            peer_principal: Some(principal),
            tls_version: "TLS 1.3".to_string(),
            cipher_suite: "TLS_AES_256_GCM_SHA384".to_string(),
            duration: Duration::from_millis(50),
        };

        assert!(result.is_authenticated());
        assert_eq!(result.peer_name(), Some("test-user"));
    }

    #[test]
    fn test_handshake_result_unauthenticated() {
        let result = HandshakeResult {
            peer_principal: None,
            tls_version: "TLS 1.3".to_string(),
            cipher_suite: "TLS_AES_128_GCM_SHA256".to_string(),
            duration: Duration::from_millis(5),
        };

        assert!(!result.is_authenticated());
        assert_eq!(result.peer_name(), None);
    }
}