1use crate::error::{ProxyError, Result};
23use dashmap::DashMap;
24use rustls::pki_types::{CertificateDer, PrivateKeyDer};
25use rustls::server::{ClientHello, ResolvesServerCert};
26use rustls::sign::CertifiedKey;
27use std::io::BufReader;
28use std::sync::{Arc, RwLock};
29use tracing::{debug, trace, warn};
30
/// TLS certificate resolver that picks a certificate from the SNI hostname of a handshake.
///
/// Resolution order (see `resolve_cert`): an exact match on the normalized domain, then a
/// single-label wildcard entry (`*.parent.domain`), then the optional default certificate.
#[derive(Debug)]
pub struct SniCertResolver {
    /// Per-domain certificates keyed by normalized domain name; wildcard entries are stored
    /// under keys such as `"*.example.com"`.
    certs: DashMap<String, Arc<CertifiedKey>>,
    /// Optional fallback certificate returned when no per-domain entry matches.
    default_cert: RwLock<Option<Arc<CertifiedKey>>>,
}
49
50impl SniCertResolver {
51 #[must_use]
61 pub fn new() -> Self {
62 Self {
63 certs: DashMap::new(),
64 default_cert: RwLock::new(None),
65 }
66 }
67
68 pub fn load_cert(&self, domain: &str, cert_pem: &str, key_pem: &str) -> Result<()> {
92 let certified_key = create_certified_key(cert_pem, key_pem)?;
93 let domain_normalized = normalize_domain(domain);
94
95 debug!(domain = %domain_normalized, "Loaded TLS certificate");
96 self.certs
97 .insert(domain_normalized, Arc::new(certified_key));
98
99 Ok(())
100 }
101
102 pub fn set_default_cert(&self, cert_pem: &str, key_pem: &str) -> Result<()> {
125 let certified_key = create_certified_key(cert_pem, key_pem)?;
126
127 debug!("Set default TLS certificate");
128 let mut default = self.default_cert.write().expect("RwLock poisoned");
129 *default = Some(Arc::new(certified_key));
130
131 Ok(())
132 }
133
134 pub fn remove_cert(&self, domain: &str) {
146 let domain_normalized = normalize_domain(domain);
147 if self.certs.remove(&domain_normalized).is_some() {
148 debug!(domain = %domain_normalized, "Removed TLS certificate");
149 }
150 }
151
152 pub fn refresh_cert(&self, domain: &str, cert_pem: &str, key_pem: &str) -> Result<()> {
173 let certified_key = create_certified_key(cert_pem, key_pem)?;
174 let domain_normalized = normalize_domain(domain);
175
176 debug!(domain = %domain_normalized, "Refreshed TLS certificate");
177 self.certs
178 .insert(domain_normalized, Arc::new(certified_key));
179
180 Ok(())
181 }
182
183 #[must_use]
193 pub fn has_cert(&self, domain: &str) -> bool {
194 let domain_normalized = normalize_domain(domain);
195 self.certs.contains_key(&domain_normalized)
196 }
197
198 #[must_use]
200 pub fn cert_count(&self) -> usize {
201 self.certs.len()
202 }
203
204 #[must_use]
206 pub fn domains(&self) -> Vec<String> {
207 self.certs.iter().map(|r| r.key().clone()).collect()
208 }
209
210 #[must_use]
212 pub fn has_default_cert(&self) -> bool {
213 self.default_cert.read().is_ok_and(|guard| guard.is_some())
214 }
215
216 fn resolve_cert(&self, server_name: Option<&str>) -> Option<Arc<CertifiedKey>> {
218 let server_name = server_name?;
219 let normalized = normalize_domain(server_name);
220
221 if let Some(cert) = self.certs.get(&normalized) {
223 trace!(domain = %normalized, "Exact certificate match");
224 return Some(Arc::clone(cert.value()));
225 }
226
227 if let Some(wildcard_domain) = get_wildcard_domain(&normalized) {
229 if let Some(cert) = self.certs.get(&wildcard_domain) {
230 trace!(
231 domain = %normalized,
232 wildcard = %wildcard_domain,
233 "Wildcard certificate match"
234 );
235 return Some(Arc::clone(cert.value()));
236 }
237 }
238
239 if let Ok(guard) = self.default_cert.read() {
242 if let Some(default) = guard.as_ref() {
243 trace!(domain = %normalized, "Using default certificate");
244 return Some(Arc::clone(default));
245 }
246 }
247
248 warn!(domain = %normalized, "No certificate found");
249 None
250 }
251}
252
253impl Default for SniCertResolver {
254 fn default() -> Self {
255 Self::new()
256 }
257}
258
259impl ResolvesServerCert for SniCertResolver {
260 fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
261 let server_name = client_hello.server_name();
262 self.resolve_cert(server_name)
263 }
264}
265
266fn create_certified_key(cert_pem: &str, key_pem: &str) -> Result<CertifiedKey> {
282 let certs = parse_certificates(cert_pem)?;
284 if certs.is_empty() {
285 return Err(ProxyError::Tls("No certificates found in PEM".to_string()));
286 }
287
288 let key = parse_private_key(key_pem)?;
290
291 let signing_key = rustls::crypto::ring::sign::any_supported_type(&key)
293 .map_err(|e| ProxyError::Tls(format!("Failed to create signing key: {e}")))?;
294
295 Ok(CertifiedKey::new(certs, signing_key))
296}
297
298fn parse_certificates(pem: &str) -> Result<Vec<CertificateDer<'static>>> {
300 let mut reader = BufReader::new(pem.as_bytes());
301 let certs: Vec<CertificateDer<'static>> = rustls_pemfile::certs(&mut reader)
302 .collect::<std::result::Result<Vec<_>, _>>()
303 .map_err(|e| ProxyError::Tls(format!("Failed to parse certificate PEM: {e}")))?;
304
305 Ok(certs)
306}
307
308fn parse_private_key(pem: &str) -> Result<PrivateKeyDer<'static>> {
310 let mut reader = BufReader::new(pem.as_bytes());
311
312 loop {
314 match rustls_pemfile::read_one(&mut reader) {
315 Ok(Some(rustls_pemfile::Item::Pkcs1Key(key))) => {
316 return Ok(PrivateKeyDer::Pkcs1(key));
317 }
318 Ok(Some(rustls_pemfile::Item::Pkcs8Key(key))) => {
319 return Ok(PrivateKeyDer::Pkcs8(key));
320 }
321 Ok(Some(rustls_pemfile::Item::Sec1Key(key))) => {
322 return Ok(PrivateKeyDer::Sec1(key));
323 }
324 Ok(Some(_)) => {
325 }
327 Ok(None) => {
328 return Err(ProxyError::Tls("No private key found in PEM".to_string()));
329 }
330 Err(e) => {
331 return Err(ProxyError::Tls(format!(
332 "Failed to parse private key PEM: {e}"
333 )));
334 }
335 }
336 }
337}
338
/// Canonicalizes a hostname for use as a certificate-map key.
///
/// Steps:
/// - surrounding whitespace is trimmed;
/// - one trailing root dot is removed, because `"example.com."` names the same host as
///   `"example.com"` and SNI values never carry the dot (RFC 6066 §3);
/// - ASCII letters are lowercased.
///
/// DNS case-insensitivity is defined for ASCII only (RFC 4343), and SNI hostnames are
/// ASCII (IDNs travel as punycode A-labels). Non-ASCII characters are therefore left as-is
/// rather than Unicode-case-folded; `to_lowercase` would, for example, turn KELVIN SIGN
/// U+212A into ASCII `k`.
fn normalize_domain(domain: &str) -> String {
    let trimmed = domain.trim();
    let without_root = trimmed.strip_suffix('.').unwrap_or(trimmed);
    without_root.to_ascii_lowercase()
}
343
/// Derives the single-label wildcard key that could cover `domain`.
///
/// The leftmost label is replaced by `*` (`"api.example.com"` → `"*.example.com"`). Only
/// one label is ever replaced, so `"a.b.example.com"` maps to `"*.b.example.com"`, not to
/// `"*.example.com"`. Returns `None` when `domain` has fewer than three labels, so
/// `"example.com"` never maps to `"*.com"`.
fn get_wildcard_domain(domain: &str) -> Option<String> {
    let (_leftmost, parent) = domain.split_once('.')?;
    parent.contains('.').then(|| format!("*.{parent}"))
}
358
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_normalize_domain() {
        assert_eq!(normalize_domain("Example.COM"), "example.com");
        assert_eq!(normalize_domain(" foo.bar.com "), "foo.bar.com");
        assert_eq!(normalize_domain("API.Example.ORG"), "api.example.org");
    }

    #[test]
    fn test_get_wildcard_domain() {
        assert_eq!(
            get_wildcard_domain("foo.example.com"),
            Some("*.example.com".to_string())
        );
        assert_eq!(
            get_wildcard_domain("bar.foo.example.com"),
            Some("*.foo.example.com".to_string())
        );
        assert_eq!(get_wildcard_domain("example.com"), None);
        assert_eq!(get_wildcard_domain("localhost"), None);
    }

    #[test]
    fn test_sni_resolver_new() {
        let resolver = SniCertResolver::new();
        assert_eq!(resolver.cert_count(), 0);
        assert!(resolver.domains().is_empty());
    }

    #[test]
    fn test_sni_resolver_default() {
        let resolver = SniCertResolver::default();
        assert_eq!(resolver.cert_count(), 0);
    }

    /// Generates a self-signed certificate and key (both PEM) for `localhost` and
    /// `example.com` using rcgen.
    fn generate_test_cert() -> (String, String) {
        use rcgen::{generate_simple_self_signed, CertifiedKey as RcgenCertifiedKey};

        let subject_alt_names = vec!["localhost".to_string(), "example.com".to_string()];
        let RcgenCertifiedKey { cert, key_pair } =
            generate_simple_self_signed(subject_alt_names).unwrap();

        (cert.pem(), key_pair.serialize_pem())
    }

    // Nothing below awaits, so plain `#[test]` is used instead of spinning up a tokio
    // runtime per test.

    #[test]
    fn test_load_cert() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        let result = resolver.load_cert("example.com", &cert_pem, &key_pem);
        assert!(result.is_ok());
        assert!(resolver.has_cert("example.com"));
        assert_eq!(resolver.cert_count(), 1);
    }

    #[test]
    fn test_load_cert_case_insensitive() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver
            .load_cert("Example.COM", &cert_pem, &key_pem)
            .unwrap();
        assert!(resolver.has_cert("example.com"));
        assert!(resolver.has_cert("EXAMPLE.COM"));
    }

    #[test]
    fn test_remove_cert() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver
            .load_cert("example.com", &cert_pem, &key_pem)
            .unwrap();
        assert!(resolver.has_cert("example.com"));

        resolver.remove_cert("example.com");
        assert!(!resolver.has_cert("example.com"));
        assert_eq!(resolver.cert_count(), 0);
    }

    #[test]
    fn test_refresh_cert() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver
            .load_cert("example.com", &cert_pem, &key_pem)
            .unwrap();
        let before = resolver.resolve_cert(Some("example.com")).unwrap();

        let (new_cert_pem, new_key_pem) = generate_test_cert();
        let result = resolver.refresh_cert("example.com", &new_cert_pem, &new_key_pem);
        assert!(result.is_ok());
        assert_eq!(resolver.cert_count(), 1);

        // The entry must actually have been replaced, not left pointing at the old key.
        let after = resolver.resolve_cert(Some("example.com")).unwrap();
        assert!(!std::sync::Arc::ptr_eq(&before, &after));
    }

    #[test]
    fn test_refresh_cert_invalid_pem_keeps_existing() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver
            .load_cert("example.com", &cert_pem, &key_pem)
            .unwrap();
        let before = resolver.resolve_cert(Some("example.com")).unwrap();

        assert!(resolver
            .refresh_cert("example.com", "not a valid PEM", "not a valid PEM")
            .is_err());

        let after = resolver.resolve_cert(Some("example.com")).unwrap();
        assert!(std::sync::Arc::ptr_eq(&before, &after));
    }

    #[test]
    fn test_set_default_cert() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        let result = resolver.set_default_cert(&cert_pem, &key_pem);
        assert!(result.is_ok());
        assert!(resolver.has_default_cert());

        // The default certificate is not counted as a per-domain entry.
        assert_eq!(resolver.cert_count(), 0);
    }

    #[test]
    fn test_has_default_cert() {
        let resolver = SniCertResolver::new();
        assert!(!resolver.has_default_cert());

        let (cert_pem, key_pem) = generate_test_cert();
        resolver.set_default_cert(&cert_pem, &key_pem).unwrap();

        assert!(resolver.has_default_cert());
    }

    #[test]
    fn test_domains() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver
            .load_cert("api.example.com", &cert_pem, &key_pem)
            .unwrap();
        resolver
            .load_cert("web.example.com", &cert_pem, &key_pem)
            .unwrap();

        let domains = resolver.domains();
        assert_eq!(domains.len(), 2);
        assert!(domains.contains(&"api.example.com".to_string()));
        assert!(domains.contains(&"web.example.com".to_string()));
    }

    #[test]
    fn test_resolve_exact_match() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver
            .load_cert("example.com", &cert_pem, &key_pem)
            .unwrap();

        assert!(resolver.resolve_cert(Some("example.com")).is_some());
        assert!(resolver.resolve_cert(Some("EXAMPLE.com")).is_some());
    }

    #[test]
    fn test_resolve_wildcard_match() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver
            .load_cert("*.example.com", &cert_pem, &key_pem)
            .unwrap();

        assert!(resolver.resolve_cert(Some("api.example.com")).is_some());
        assert!(resolver.resolve_cert(Some("web.example.com")).is_some());

        // A wildcard covers exactly one label: neither the apex nor deeper names match.
        assert!(resolver.resolve_cert(Some("example.com")).is_none());
        assert!(resolver.resolve_cert(Some("a.b.example.com")).is_none());
    }

    #[test]
    fn test_resolve_exact_preferred_over_wildcard() {
        let resolver = SniCertResolver::new();
        let (exact_cert, exact_key) = generate_test_cert();
        let (wild_cert, wild_key) = generate_test_cert();

        resolver
            .load_cert("api.example.com", &exact_cert, &exact_key)
            .unwrap();
        let exact = resolver.resolve_cert(Some("api.example.com")).unwrap();

        resolver
            .load_cert("*.example.com", &wild_cert, &wild_key)
            .unwrap();
        let resolved = resolver.resolve_cert(Some("api.example.com")).unwrap();

        assert!(std::sync::Arc::ptr_eq(&exact, &resolved));
    }

    #[test]
    fn test_resolve_default_fallback() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver.set_default_cert(&cert_pem, &key_pem).unwrap();

        assert!(resolver.resolve_cert(Some("unknown.com")).is_some());
    }

    #[test]
    fn test_resolve_no_match() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();

        resolver
            .load_cert("example.com", &cert_pem, &key_pem)
            .unwrap();

        assert!(resolver.resolve_cert(Some("other.com")).is_none());
    }

    #[test]
    fn test_resolve_none_server_name() {
        let resolver = SniCertResolver::new();

        assert!(resolver.resolve_cert(None).is_none());
    }

    #[test]
    fn test_invalid_cert_pem() {
        // Non-PEM input parses to an empty chain rather than an error.
        let result = parse_certificates("not a valid PEM");
        assert!(result.is_ok());
        assert!(result.unwrap().is_empty());
    }

    #[test]
    fn test_invalid_key_pem() {
        let result = parse_private_key("not a valid PEM");
        assert!(result.is_err());
    }

    #[test]
    fn test_create_certified_key_empty_certs() {
        let (_, key_pem) = generate_test_cert();
        let result = create_certified_key("", &key_pem);
        assert!(result.is_err());
    }
}