1use crate::errors::{AuthError, Result};
9use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
10use ring::signature;
11use ring::signature::UnparsedPublicKey;
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use x509_parser::{certificate::X509Certificate, parse_x509_certificate};
15
16#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
18pub enum MutualTlsMethod {
19 PkiMutualTls,
21
22 SelfSignedTlsClientAuth,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct X509CertificateInfo {
29 pub thumbprint: String,
31
32 pub subject_dn: String,
34
35 pub issuer_dn: String,
37
38 pub serial_number: String,
40
41 pub not_before: chrono::DateTime<chrono::Utc>,
43 pub not_after: chrono::DateTime<chrono::Utc>,
44
45 pub san_dns: Vec<String>,
47 pub san_uri: Vec<String>,
48 pub san_email: Vec<String>,
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct CertificateConfirmation {
54 #[serde(rename = "x5t#S256")]
56 pub x5t_s256: String,
57}
58
59#[derive(Debug, Clone)]
61pub struct MutualTlsClientConfig {
62 pub client_id: String,
64
65 pub auth_method: MutualTlsMethod,
67
68 pub ca_certificates: Vec<Vec<u8>>,
70
71 pub client_certificate: Option<Vec<u8>>,
73
74 pub expected_subject_dn: Option<String>,
76
77 pub certificate_bound_access_tokens: bool,
79}
80
81#[derive(Debug, Clone)]
83pub struct MutualTlsAuthResult {
84 pub client_id: String,
86
87 pub certificate_info: X509CertificateInfo,
89
90 pub is_valid: bool,
92
93 pub validation_errors: Vec<String>,
95}
96
97#[derive(Debug)]
99pub struct MutualTlsManager {
100 clients: tokio::sync::RwLock<HashMap<String, MutualTlsClientConfig>>,
102
103 ca_store: Vec<Vec<u8>>,
105}
106
107impl MutualTlsManager {
108 pub fn new() -> Self {
110 Self {
111 clients: tokio::sync::RwLock::new(HashMap::new()),
112 ca_store: Vec::new(),
113 }
114 }
115
116 pub fn add_ca_certificate(&mut self, ca_cert: Vec<u8>) -> Result<()> {
118 let (_, cert) = parse_x509_certificate(&ca_cert)
120 .map_err(|_| AuthError::auth_method("mtls", "Invalid CA certificate format"))?;
121
122 if !cert
124 .basic_constraints()
125 .map(|bc| bc.unwrap().value.ca)
126 .unwrap_or(false)
127 {
128 return Err(AuthError::auth_method(
129 "mtls",
130 "Certificate is not a CA certificate",
131 ));
132 }
133
134 self.ca_store.push(ca_cert);
135 Ok(())
136 }
137
138 pub async fn register_client(&self, config: MutualTlsClientConfig) -> Result<()> {
140 self.validate_client_config(&config)?;
141
142 let mut clients = self.clients.write().await;
143 clients.insert(config.client_id.clone(), config);
144
145 Ok(())
146 }
147
148 pub async fn authenticate_client(
150 &self,
151 client_id: &str,
152 client_certificate: &[u8],
153 ) -> Result<MutualTlsAuthResult> {
154 let clients = self.clients.read().await;
155 let client_config = clients
156 .get(client_id)
157 .ok_or_else(|| AuthError::auth_method("mtls", "Client not registered for mTLS"))?;
158
159 let (_, cert) = parse_x509_certificate(client_certificate)
161 .map_err(|_| AuthError::auth_method("mtls", "Invalid client certificate format"))?;
162
163 let cert_info = self.extract_certificate_info(&cert, client_certificate)?;
165
166 let (is_valid, validation_errors) = match client_config.auth_method {
168 MutualTlsMethod::PkiMutualTls => {
169 self.validate_pki_certificate(&cert, client_config).await
170 }
171 MutualTlsMethod::SelfSignedTlsClientAuth => {
172 self.validate_self_signed_certificate(&cert, client_config)
173 .await
174 }
175 };
176
177 Ok(MutualTlsAuthResult {
178 client_id: client_id.to_string(),
179 certificate_info: cert_info,
180 is_valid,
181 validation_errors,
182 })
183 }
184
185 pub fn create_certificate_confirmation(
187 &self,
188 client_certificate: &[u8],
189 ) -> Result<CertificateConfirmation> {
190 let thumbprint = self.calculate_certificate_thumbprint(client_certificate)?;
191
192 Ok(CertificateConfirmation {
193 x5t_s256: thumbprint,
194 })
195 }
196
197 pub fn validate_certificate_bound_token(
199 &self,
200 token_confirmation: &CertificateConfirmation,
201 client_certificate: &[u8],
202 ) -> Result<bool> {
203 let current_thumbprint = self.calculate_certificate_thumbprint(client_certificate)?;
204
205 Ok(token_confirmation.x5t_s256 == current_thumbprint)
206 }
207
208 pub async fn validate_client_certificate(
210 &self,
211 client_certificate: &[u8],
212 client_id: &str,
213 ) -> Result<()> {
214 let (_, cert) = parse_x509_certificate(client_certificate)
216 .map_err(|_| AuthError::auth_method("mtls", "Invalid client certificate format"))?;
217
218 let now = std::time::SystemTime::now()
220 .duration_since(std::time::UNIX_EPOCH)
221 .unwrap()
222 .as_secs() as i64;
223
224 if cert.validity.not_before.timestamp() > now {
225 return Err(AuthError::auth_method(
226 "mtls",
227 "Client certificate not yet valid",
228 ));
229 }
230
231 if cert.validity.not_after.timestamp() < now {
232 return Err(AuthError::auth_method(
233 "mtls",
234 "Client certificate has expired",
235 ));
236 }
237
238 if self.ca_store.is_empty() {
240 return Err(AuthError::auth_method(
241 "mtls",
242 "No trusted CA certificates configured",
243 ));
244 }
245
246 self.perform_full_chain_validation(client_certificate, client_id)
248 .await?;
249
250 let clients = self.clients.read().await;
252 if !clients.contains_key(client_id) {
253 return Err(AuthError::auth_method(
254 "mtls",
255 "Client not registered for mTLS",
256 ));
257 }
258
259 Ok(())
260 }
261
262 fn extract_certificate_info(
264 &self,
265 cert: &X509Certificate,
266 cert_der: &[u8],
267 ) -> Result<X509CertificateInfo> {
268 let thumbprint = self.calculate_certificate_thumbprint(cert_der)?;
270
271 let subject_dn = cert.subject().to_string();
273 let issuer_dn = cert.issuer().to_string();
274
275 let serial_number = hex::encode(cert.serial.to_bytes_be());
277
278 let not_before =
280 chrono::DateTime::from_timestamp(cert.validity().not_before.timestamp(), 0)
281 .unwrap_or_default();
282 let not_after = chrono::DateTime::from_timestamp(cert.validity().not_after.timestamp(), 0)
283 .unwrap_or_default();
284
285 let mut san_dns = Vec::new();
287 let mut san_uri = Vec::new();
288 let mut san_email = Vec::new();
289
290 if let Ok(Some(san_ext)) = cert.subject_alternative_name() {
292 for name in &san_ext.value.general_names {
293 match name {
294 x509_parser::extensions::GeneralName::DNSName(dns) => {
295 san_dns.push(dns.to_string());
296 }
297 x509_parser::extensions::GeneralName::URI(uri) => {
298 san_uri.push(uri.to_string());
299 }
300 x509_parser::extensions::GeneralName::RFC822Name(email) => {
301 san_email.push(email.to_string());
302 }
303 x509_parser::extensions::GeneralName::IPAddress(ip) => {
304 if ip.len() == 4 {
306 let ip_addr = format!("{}.{}.{}.{}", ip[0], ip[1], ip[2], ip[3]);
308 san_dns.push(ip_addr); } else if ip.len() == 16 {
310 let ip_addr = format!(
312 "{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}:{:02x}{:02x}",
313 ip[0],
314 ip[1],
315 ip[2],
316 ip[3],
317 ip[4],
318 ip[5],
319 ip[6],
320 ip[7],
321 ip[8],
322 ip[9],
323 ip[10],
324 ip[11],
325 ip[12],
326 ip[13],
327 ip[14],
328 ip[15]
329 );
330 san_dns.push(ip_addr);
331 }
332 }
333 _ => {
334 }
336 }
337 }
338 }
339
340 Ok(X509CertificateInfo {
341 thumbprint,
342 subject_dn,
343 issuer_dn,
344 serial_number,
345 not_before,
346 not_after,
347 san_dns,
348 san_uri,
349 san_email,
350 })
351 }
352
353 fn calculate_certificate_thumbprint(&self, cert_der: &[u8]) -> Result<String> {
355 use sha2::{Digest, Sha256};
356
357 let mut hasher = Sha256::new();
358 hasher.update(cert_der);
359 let hash = hasher.finalize();
360
361 Ok(URL_SAFE_NO_PAD.encode(hash))
362 }
363
364 async fn validate_pki_certificate(
366 &self,
367 cert: &X509Certificate<'_>,
368 client_config: &MutualTlsClientConfig,
369 ) -> (bool, Vec<String>) {
370 let mut errors = Vec::new();
371
372 let now = chrono::Utc::now().timestamp();
374 if cert.validity().not_before.timestamp() > now {
375 errors.push("Certificate is not yet valid".to_string());
376 }
377 if cert.validity().not_after.timestamp() < now {
378 errors.push("Certificate has expired".to_string());
379 }
380
381 if let Some(expected_subject) = &client_config.expected_subject_dn {
383 let actual_subject = cert.subject().to_string();
384 if !actual_subject.contains(expected_subject) {
385 errors.push(format!(
386 "Subject DN does not match expected pattern: {}",
387 expected_subject
388 ));
389 }
390 }
391
392 let mut ca_validated = false;
394 for ca_cert_der in &self.ca_store {
395 if let Ok((_, ca_cert)) = parse_x509_certificate(ca_cert_der) {
396 if cert.issuer() == ca_cert.subject() {
398 ca_validated = true;
399 break;
400 }
401 }
402 }
403
404 if !ca_validated && !self.ca_store.is_empty() {
405 errors.push("Certificate not signed by trusted CA".to_string());
406 }
407
408 if let Ok(Some(key_usage)) = cert.key_usage()
410 && !key_usage.value.digital_signature()
411 {
412 errors.push("Certificate does not allow digital signatures".to_string());
413 }
414
415 (errors.is_empty(), errors)
416 }
417
418 async fn validate_self_signed_certificate(
420 &self,
421 cert: &X509Certificate<'_>,
422 client_config: &MutualTlsClientConfig,
423 ) -> (bool, Vec<String>) {
424 let mut errors = Vec::new();
425
426 let now = chrono::Utc::now().timestamp();
428 if cert.validity().not_before.timestamp() > now {
429 errors.push("Certificate is not yet valid".to_string());
430 }
431 if cert.validity().not_after.timestamp() < now {
432 errors.push("Certificate has expired".to_string());
433 }
434
435 if let Some(registered_cert_der) = &client_config.client_certificate {
437 if let Ok((_, registered_cert)) = parse_x509_certificate(registered_cert_der) {
438 if cert.public_key().raw != registered_cert.public_key().raw {
440 errors.push("Certificate does not match registered certificate".to_string());
441 }
442 } else {
443 errors.push("Invalid registered certificate".to_string());
444 }
445 } else {
446 errors.push("No registered certificate for self-signed authentication".to_string());
447 }
448
449 if let Some(expected_subject) = &client_config.expected_subject_dn {
451 let actual_subject = cert.subject().to_string();
452 if !actual_subject.contains(expected_subject) {
453 errors.push(format!(
454 "Subject DN does not match expected pattern: {}",
455 expected_subject
456 ));
457 }
458 }
459
460 (errors.is_empty(), errors)
461 }
462
463 fn validate_client_config(&self, config: &MutualTlsClientConfig) -> Result<()> {
465 match config.auth_method {
466 MutualTlsMethod::PkiMutualTls => {
467 if config.ca_certificates.is_empty() && self.ca_store.is_empty() {
468 return Err(AuthError::auth_method(
469 "mtls",
470 "PKI authentication requires CA certificates",
471 ));
472 }
473 }
474 MutualTlsMethod::SelfSignedTlsClientAuth => {
475 if config.client_certificate.is_none() {
476 return Err(AuthError::auth_method(
477 "mtls",
478 "Self-signed authentication requires registered client certificate",
479 ));
480 }
481 }
482 }
483
484 Ok(())
485 }
486
487 async fn perform_full_chain_validation(&self, cert_der: &[u8], client_id: &str) -> Result<()> {
489 let (_, client_cert) = parse_x509_certificate(cert_der)
491 .map_err(|_| AuthError::auth_method("mtls", "Invalid client certificate format"))?;
492
493 let mut ca_validated = false;
495 let mut validation_errors = Vec::new();
496
497 for ca_der in &self.ca_store {
498 match self
499 .validate_certificate_against_ca(&client_cert, ca_der)
500 .await
501 {
502 Ok(()) => {
503 ca_validated = true;
504 break;
505 }
506 Err(e) => {
507 validation_errors.push(format!("CA validation failed: {}", e));
508 }
509 }
510 }
511
512 if !ca_validated {
513 return Err(AuthError::auth_method(
514 "mtls",
515 format!(
516 "Certificate chain validation failed. Errors: {}",
517 validation_errors.join("; ")
518 ),
519 ));
520 }
521
522 let clients = self.clients.read().await;
524 if !clients.contains_key(client_id) {
525 return Err(AuthError::auth_method(
526 "mtls",
527 "Client not registered for mTLS",
528 ));
529 }
530
531 if let Some(client_config) = clients.get(client_id)
533 && let Some(expected_cert) = &client_config.client_certificate
534 && expected_cert != cert_der {
535 return Err(AuthError::auth_method(
536 "mtls",
537 "Client certificate does not match registered certificate",
538 ));
539 }
540
541 Ok(())
542 }
543
544 async fn validate_certificate_against_ca<'a>(
546 &self,
547 client_cert: &'a X509Certificate<'a>,
548 ca_der: &[u8],
549 ) -> Result<()> {
550 let (_, ca_cert) = parse_x509_certificate(ca_der)
552 .map_err(|_| AuthError::auth_method("mtls", "Invalid CA certificate format"))?;
553
554 if client_cert.issuer() != ca_cert.subject() {
556 return Err(AuthError::auth_method(
557 "mtls",
558 "Certificate issuer does not match CA subject",
559 ));
560 }
561
562 self.verify_certificate_signature(client_cert, &ca_cert)
564 .await?;
565
566 let now = std::time::SystemTime::now()
568 .duration_since(std::time::UNIX_EPOCH)
569 .unwrap()
570 .as_secs();
571
572 let not_before = client_cert.validity().not_before.timestamp() as u64;
573 let not_after = client_cert.validity().not_after.timestamp() as u64;
574
575 if now < not_before {
576 return Err(AuthError::auth_method(
577 "mtls",
578 "Client certificate is not yet valid",
579 ));
580 }
581
582 if now > not_after {
583 return Err(AuthError::auth_method(
584 "mtls",
585 "Client certificate has expired",
586 ));
587 }
588
589 let ca_not_before = ca_cert.validity().not_before.timestamp() as u64;
591 let ca_not_after = ca_cert.validity().not_after.timestamp() as u64;
592
593 if now < ca_not_before || now > ca_not_after {
594 return Err(AuthError::auth_method(
595 "mtls",
596 "CA certificate is not valid at current time",
597 ));
598 }
599
600 Ok(())
601 }
602
603 async fn verify_certificate_signature<'a>(
605 &self,
606 client_cert: &'a X509Certificate<'a>,
607 ca_cert: &'a X509Certificate<'a>,
608 ) -> Result<()> {
609 let ca_public_key = ca_cert.public_key();
611 let ca_public_key_der = ca_public_key.raw;
612
613 let signature_algorithm = match client_cert
615 .signature_algorithm
616 .algorithm
617 .to_string()
618 .as_str()
619 {
620 "1.2.840.113549.1.1.11" => &signature::RSA_PKCS1_2048_8192_SHA256, "1.2.840.113549.1.1.12" => &signature::RSA_PKCS1_2048_8192_SHA384, "1.2.840.113549.1.1.13" => &signature::RSA_PKCS1_2048_8192_SHA512, _ => {
624 return Err(AuthError::auth_method(
625 "mtls",
626 "Unsupported signature algorithm for certificate validation",
627 ));
628 }
629 };
630
631 let public_key = UnparsedPublicKey::new(signature_algorithm, ca_public_key_der);
633
634 let tbs_certificate_der = &client_cert.tbs_certificate.as_ref();
636 let signature_value = &client_cert.signature_value.data;
637
638 public_key
640 .verify(tbs_certificate_der, signature_value)
641 .map_err(|_| {
642 AuthError::auth_method("mtls", "Certificate signature verification failed")
643 })?;
644
645 Ok(())
646 }
647}
648
649impl Default for MutualTlsManager {
650 fn default() -> Self {
651 Self::new()
652 }
653}
654
#[cfg(test)]
mod tests {
    use super::*;

    /// Self-signed registration whose certificate bytes are a placeholder, not real DER.
    fn create_test_client_config() -> MutualTlsClientConfig {
        MutualTlsClientConfig {
            client_id: String::from("test_client"),
            auth_method: MutualTlsMethod::SelfSignedTlsClientAuth,
            ca_certificates: vec![],
            client_certificate: Some(Vec::from(&b"dummy_cert"[..])),
            expected_subject_dn: Some(String::from("CN=test_client")),
            certificate_bound_access_tokens: true,
        }
    }

    #[tokio::test]
    async fn test_mtls_manager_creation() {
        let manager = MutualTlsManager::new();
        assert!(manager.ca_store.is_empty());
    }

    #[tokio::test]
    async fn test_client_registration() {
        let manager = MutualTlsManager::new();
        manager
            .register_client(create_test_client_config())
            .await
            .unwrap();
    }

    #[test]
    fn test_certificate_confirmation() {
        let manager = MutualTlsManager::new();
        let bound_cert: &[u8] = b"dummy_certificate_data";
        let other_cert: &[u8] = b"different_certificate_data";

        let confirmation = manager.create_certificate_confirmation(bound_cert).unwrap();
        assert!(!confirmation.x5t_s256.is_empty());

        // Only the certificate the confirmation was built from should match it.
        let matches = |cert: &[u8]| {
            manager
                .validate_certificate_bound_token(&confirmation, cert)
                .unwrap()
        };
        assert!(matches(bound_cert));
        assert!(!matches(other_cert));
    }
}
707
708