1#[cfg(test)]
7use std::collections::HashSet;
8use std::{
9 collections::HashMap,
10 fmt,
11 str::FromStr,
12 sync::{Arc, LazyLock, Mutex},
13};
14
15use rustls::{
16 pki_types::{CertificateDer, PrivateKeyDer, pem::PemObject},
17 server::{ClientHello, ResolvesServerCert},
18 sign::CertifiedKey,
19};
20
21use crate::crypto::any_supported_type;
22use sha2::{Digest, Sha256};
23use sozu_command::{
24 certificate::{
25 CertificateError, Fingerprint, get_cn_and_san_attributes, parse_pem, parse_x509,
26 split_certificate_chain,
27 },
28 logging::ansi_palette,
29 proto::command::{AddCertificate, CertificateAndKey, ReplaceCertificate, SocketAddress},
30};
31
32use crate::metrics::names;
33use crate::router::pattern_trie::{Key, KeyValue, TrieNode};
34
/// Expands to the log prefix used by every message emitted from this module:
/// the `TLS-RESOLVER` tag wrapped between the first two values returned by
/// `ansi_palette()`, followed by a tab and `>>>`.
macro_rules! log_module_context {
    () => {{
        // Only the first two entries of the palette tuple are needed here.
        let palette = ansi_palette();
        format!(
            "{open}TLS-RESOLVER{reset}\t >>>",
            open = palette.0,
            reset = palette.1
        )
    }};
}
53
54static DEFAULT_CERTIFICATE: LazyLock<Option<Arc<CertifiedKey>>> = LazyLock::new(|| {
58 let add = AddCertificate {
59 certificate: CertificateAndKey {
60 certificate: include_str!("../assets/certificate.pem").to_string(),
61 certificate_chain: vec![include_str!("../assets/certificate_chain.pem").to_string()],
62 key: include_str!("../assets/key.pem").to_string(),
63 versions: vec![],
64 names: vec![],
65 },
66 address: SocketAddress::new_v4(0, 0, 0, 0, 8080), expired_at: None,
68 };
69 CertifiedKeyWrapper::try_from(&add).ok().map(|c| c.inner)
70});
71
/// Errors produced while converting certificate requests into resolver
/// entries and while adding, removing or replacing certificates.
#[derive(thiserror::Error, Debug)]
pub enum CertificateResolverError {
    /// Wraps a [`CertificateError`] about the common name / subject alternate
    /// names of a certificate.
    /// NOTE(review): not constructed in the visible part of this file.
    #[error("failed to get common name and subject alternate names from pem, {0}")]
    InvalidCommonNameAndSubjectAlternateNames(CertificateError),
    /// The private key was decoded but `any_supported_type` rejected it; the
    /// payload is the textual form of that rejection.
    #[error("invalid private key: {0}")]
    InvalidPrivateKey(String),
    /// `PrivateKeyDer::from_pem_slice` failed on the supplied key PEM. Every
    /// failure of that call is mapped to this variant.
    #[error("empty key")]
    EmptyKeys,
    /// `parse_x509` failed on the DER contents of the leaf certificate.
    #[error("error parsing x509 cert from bytes: {0}")]
    ParseX509(CertificateError),
    /// `parse_pem` failed on the leaf certificate or on a chain element.
    #[error("error parsing pem formated certificate from bytes: {0}")]
    ParsePem(CertificateError),
    /// Wraps a [`CertificateError`] about overriding names.
    /// NOTE(review): not constructed in the visible part of this file.
    #[error("error parsing overriding names in new certificate: {0}")]
    ParseOverridingNames(CertificateError),
}
87
/// A rustls [`CertifiedKey`] bundled with the metadata the resolver needs to
/// index it. Built from an [`AddCertificate`] through `TryFrom`.
#[derive(Clone, Debug)]
pub struct CertifiedKeyWrapper {
    /// Certificate chain and signing key, shared so clones stay cheap.
    inner: Arc<CertifiedKey>,
    /// Names this certificate is indexed under: the names supplied in the
    /// request when there are any, otherwise the CN/SAN attributes read from
    /// the certificate.
    names: Vec<String>,
    /// Expiration timestamp: the request's `expired_at` override when set,
    /// otherwise the certificate's `not_after` timestamp.
    expiration: i64,
    /// SHA-256 digest of the leaf certificate's decoded PEM contents.
    fingerprint: Fingerprint,
}
99
100impl TryFrom<&AddCertificate> for CertifiedKeyWrapper {
103 type Error = CertificateResolverError;
104
105 fn try_from(add: &AddCertificate) -> Result<Self, Self::Error> {
106 let cert = add.certificate.clone();
107
108 let pem =
109 parse_pem(cert.certificate.as_bytes()).map_err(CertificateResolverError::ParsePem)?;
110
111 let x509 = parse_x509(&pem.contents).map_err(CertificateResolverError::ParseX509)?;
112
113 let overriding_names = if add.certificate.names.is_empty() {
114 get_cn_and_san_attributes(&x509)
115 } else {
116 add.certificate.names.clone()
117 };
118
119 let expiration = add
120 .expired_at
121 .unwrap_or(x509.validity().not_after.timestamp());
122
123 let fingerprint = Fingerprint(Sha256::digest(&pem.contents).iter().cloned().collect());
124
125 let leaf_der = pem.contents;
135 let mut chain = vec![CertificateDer::from(leaf_der.to_owned())];
136 let mut dropped_duplicates = 0usize;
137 for cert in &cert.certificate_chain {
138 for split_pem in split_certificate_chain(cert.to_owned()) {
139 let chain_link = parse_pem(split_pem.as_bytes())
140 .map_err(CertificateResolverError::ParsePem)?
141 .contents;
142
143 if chain_link == leaf_der {
144 dropped_duplicates += 1;
145 continue;
146 }
147 chain.push(CertificateDer::from(chain_link));
148 }
149 }
150 if dropped_duplicates > 0 {
151 debug!(
152 "{} dropped {} duplicate leaf certificate(s) from the supplied chain",
153 log_module_context!(),
154 dropped_duplicates
155 );
156 }
157
158 let private_key = PrivateKeyDer::from_pem_slice(cert.key.as_bytes())
167 .map_err(|_| CertificateResolverError::EmptyKeys)?;
168
169 match any_supported_type(&private_key) {
170 Ok(signing_key) => {
171 let stored_certificate = CertifiedKeyWrapper {
172 inner: Arc::new(CertifiedKey::new(chain, signing_key)),
173 names: overriding_names,
174 expiration,
175 fingerprint,
176 };
177 Ok(stored_certificate)
178 }
179 Err(sign_error) => Err(CertificateResolverError::InvalidPrivateKey(
180 sign_error.to_string(),
181 )),
182 }
183 }
184}
185
/// In-memory store of TLS certificates, indexed by fingerprint and by name.
#[derive(Default, Debug)]
pub struct CertificateResolver {
    /// Trie mapping a name (as bytes) to the fingerprint of the certificate
    /// currently served for it: the last entry of that name's list in
    /// `name_fingerprint_idx`.
    pub domains: TrieNode<Fingerprint>,
    /// Every stored certificate, keyed by its fingerprint.
    certificates: HashMap<Fingerprint, CertifiedKeyWrapper>,
    /// For each name, the `(fingerprint, expiration)` pairs of the
    /// certificates covering it, kept sorted by ascending expiration.
    name_fingerprint_idx: HashMap<String, Vec<(Fingerprint, i64)>>,
}
203
204impl CertificateResolver {
205 pub fn get_certificate(&self, fingerprint: &Fingerprint) -> Option<CertifiedKeyWrapper> {
207 self.certificates.get(fingerprint).map(ToOwned::to_owned)
208 }
209
210 fn publish_min_expiration_gauge(&self) {
224 let Some(min_expiration) = self.certificates.values().map(|c| c.expiration).min() else {
225 return;
233 };
234 let clamped = min_expiration.max(0) as usize;
235 gauge!(names::tls::CERT_MIN_EXPIRES_AT_SECONDS, clamped);
236 }
237
238 pub fn add_certificate(
241 &mut self,
242 add: &AddCertificate,
243 ) -> Result<Fingerprint, CertificateResolverError> {
244 let cert_to_add = CertifiedKeyWrapper::try_from(add)?;
245
246 trace!(
247 "{} adding certificate {:?}",
248 log_module_context!(),
249 cert_to_add
250 );
251
252 if self.certificates.contains_key(&cert_to_add.fingerprint) {
253 return Ok(cert_to_add.fingerprint);
254 }
255
256 for new_name in &cert_to_add.names {
257 let fingerprints_for_this_name = self
258 .name_fingerprint_idx
259 .entry(new_name.to_owned())
260 .or_default();
261
262 fingerprints_for_this_name
263 .push((cert_to_add.fingerprint.clone(), cert_to_add.expiration));
264
265 fingerprints_for_this_name.sort_by_key(|t| t.1);
267
268 let longest_lived_cert = match fingerprints_for_this_name.last() {
269 Some(cert) => cert,
270 None => {
271 error!(
272 "{} no fingerprint for this name, this should not happen",
273 log_module_context!()
274 );
275 continue;
276 }
277 };
278
279 self.domains.remove(&new_name.to_owned().into_bytes());
281 self.domains.insert(
282 new_name.to_owned().into_bytes(),
283 longest_lived_cert.0.to_owned(),
284 );
285 }
286
287 self.certificates
288 .insert(cert_to_add.fingerprint.to_owned(), cert_to_add.clone());
289 self.publish_min_expiration_gauge();
290
291 trace!("{} {:#?}", log_module_context!(), self);
292
293 Ok(cert_to_add.fingerprint)
294 }
295
296 pub fn remove_certificate(
299 &mut self,
300 fingerprint: &Fingerprint,
301 ) -> Result<(), CertificateResolverError> {
302 if let Some(certificate_to_remove) = self.get_certificate(fingerprint) {
303 for name in certificate_to_remove.names {
304 self.domains.domain_remove(&name.as_bytes().to_vec());
305
306 if let std::collections::hash_map::Entry::Occupied(mut entry) =
307 self.name_fingerprint_idx.entry(name.to_owned())
308 {
309 entry.get_mut().retain(|t| &t.0 != fingerprint);
311
312 if let Some(longest_lived_cert) = entry.get().last() {
314 self.domains
315 .insert(name.as_bytes().to_vec(), longest_lived_cert.0.to_owned());
316 }
317
318 if entry.get().is_empty() {
320 entry.remove();
321 }
322 }
323 }
324
325 self.certificates.remove(fingerprint);
326 self.publish_min_expiration_gauge();
327 }
328 trace!("{} {:#?}", log_module_context!(), self);
329
330 Ok(())
331 }
332
333 pub fn replace_certificate(
337 &mut self,
338 replace: &ReplaceCertificate,
339 ) -> Result<Fingerprint, CertificateResolverError> {
340 let add = AddCertificate {
341 address: replace.address.to_owned(),
342 certificate: replace.new_certificate.to_owned(),
343 expired_at: replace.new_expired_at.to_owned(),
344 };
345
346 let new_cert = CertifiedKeyWrapper::try_from(&add)?;
361 let new_fingerprint = new_cert.fingerprint.to_owned();
362
363 if let Ok(old_fingerprint) = Fingerprint::from_str(&replace.old_fingerprint) {
364 if old_fingerprint == new_fingerprint {
365 self.publish_min_expiration_gauge();
368 return Ok(new_fingerprint);
369 }
370 }
371
372 let new_fingerprint = self.add_certificate(&add)?;
373
374 match Fingerprint::from_str(&replace.old_fingerprint) {
375 Ok(old_fingerprint) => self.remove_certificate(&old_fingerprint)?,
376 Err(err) => {
377 warn!(
382 "{} new certificate added but could not remove old one: \
383 failed to parse old fingerprint, {}",
384 log_module_context!(),
385 err
386 );
387 }
388 }
389
390 Ok(new_fingerprint)
391 }
392
393 #[cfg(test)]
396 fn find_certificates_by_names(
397 &self,
398 names: &HashSet<String>,
399 ) -> Result<HashSet<Fingerprint>, CertificateResolverError> {
400 let mut fingerprints = HashSet::new();
401 for name in names {
402 if let Some(fprints) = self.name_fingerprint_idx.get(name) {
403 fprints.iter().for_each(|fingerprint| {
404 fingerprints.insert(fingerprint.to_owned().0);
405 });
406 }
407 }
408
409 Ok(fingerprints)
410 }
411
412 #[cfg(test)]
415 fn certificate_names(
416 &self,
417 fingerprint: &Fingerprint,
418 ) -> Result<HashSet<String>, CertificateResolverError> {
419 if let Some(cert) = self.certificates.get(fingerprint) {
420 return Ok(cert.names.iter().cloned().collect());
421 }
422 Ok(HashSet::new())
423 }
424
425 pub fn domain_lookup(
426 &self,
427 domain: &[u8],
428 accept_wildcard: bool,
429 ) -> Option<&KeyValue<Key, Fingerprint>> {
430 self.domains.domain_lookup(domain, accept_wildcard)
431 }
432
433 pub fn names_for_sni(&self, domain: &[u8]) -> Option<Vec<String>> {
445 let (_, fingerprint) = self.domain_lookup(domain, true)?;
446 self.certificates
447 .get(fingerprint)
448 .map(|cert| cert.names.clone())
449 }
450}
451
/// A [`CertificateResolver`] behind a [`Mutex`]. This is the type that
/// implements rustls' [`ResolvesServerCert`] in this file.
#[derive(Default)]
pub struct MutexCertificateResolver(pub Mutex<CertificateResolver>);
457
458impl ResolvesServerCert for MutexCertificateResolver {
459 fn resolve(&self, client_hello: ClientHello) -> Option<Arc<CertifiedKey>> {
460 let server_name = client_hello.server_name();
461 let sigschemes = client_hello.signature_schemes();
462
463 let Some(name) = server_name else {
464 error!(
465 "{} cannot look up certificate: no SNI from session",
466 log_module_context!()
467 );
468 return None;
469 };
470 trace!(
471 "{} trying to resolve name: {:?} for signature scheme: {:?}",
472 log_module_context!(),
473 name,
474 sigschemes
475 );
476 let resolver = match self.0.lock() {
486 Ok(guard) => guard,
487 Err(poisoned) => {
488 error!(
489 "{} cert resolver mutex poisoned, returning default cert: {:?}",
490 log_module_context!(),
491 poisoned
492 );
493 return DEFAULT_CERTIFICATE.clone();
494 }
495 };
496 if let Some((_, fingerprint)) = resolver.domains.domain_lookup(name.as_bytes(), true) {
497 trace!(
498 "{} looking for certificate for {:?} with fingerprint {:?}",
499 log_module_context!(),
500 name,
501 fingerprint
502 );
503
504 let cert = resolver
505 .certificates
506 .get(fingerprint)
507 .map(|cert| cert.inner.clone());
508
509 trace!(
510 "{} found for fingerprint {}: {}",
511 log_module_context!(),
512 fingerprint,
513 cert.is_some()
514 );
515 return cert;
516 }
517 drop(resolver);
518
519 debug!(
523 "{} default certificate is used for {}",
524 log_module_context!(),
525 name
526 );
527 incr!(names::tls::DEFAULT_CERT_USED);
528 DEFAULT_CERTIFICATE.clone()
529 }
530}
531
532impl MutexCertificateResolver {
533 pub fn names_for_sni(&self, domain: &[u8]) -> Option<Vec<String>> {
539 match self.0.lock() {
540 Ok(guard) => guard.names_for_sni(domain),
541 Err(poisoned) => {
542 error!(
543 "{} cert resolver mutex poisoned, treating as no SAN match: {:?}",
544 log_module_context!(),
545 poisoned
546 );
547 None
548 }
549 }
550 }
551}
552
553impl fmt::Debug for MutexCertificateResolver {
554 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
555 f.write_str("MutexWrappedCertificateResolver")
556 }
557}
558
559#[cfg(test)]
563mod tests {
564 use std::{
565 collections::HashSet,
566 error::Error,
567 time::{Duration, SystemTime},
568 };
569
570 use sozu_command::proto::command::{
572 AddCertificate, CertificateAndKey, ReplaceCertificate, SocketAddress,
573 };
574
575 use super::CertificateResolver;
576
577 #[test]
578 fn lifecycle() -> Result<(), Box<dyn Error + Send + Sync>> {
579 let address = SocketAddress::new_v4(127, 0, 0, 1, 8080);
580 let mut resolver = CertificateResolver::default();
581 let certificate_and_key = CertificateAndKey {
582 certificate: String::from(include_str!("../assets/certificate.pem")),
583 key: String::from(include_str!("../assets/key.pem")),
584 ..Default::default()
585 };
586
587 let fingerprint = resolver
588 .add_certificate(&AddCertificate {
589 address,
590 certificate: certificate_and_key,
591 expired_at: None,
592 })
593 .expect("could not add certificate");
594
595 if resolver.get_certificate(&fingerprint).is_none() {
596 return Err("failed to retrieve certificate".into());
597 }
598
599 let names = resolver.certificate_names(&fingerprint)?;
601
602 if let Err(err) = resolver.remove_certificate(&fingerprint) {
603 return Err(format!("the certificate was not removed, {err}").into());
604 }
605
606 if resolver.get_certificate(&fingerprint).is_some() {
607 return Err("We have retrieved the certificate that should be deleted".into());
608 }
609
610 if !resolver.find_certificates_by_names(&names)?.is_empty() {
611 return Err(
612 "The certificate should be deleted but one of its names is in the index".into(),
613 );
614 }
615
616 Ok(())
617 }
618
619 #[test]
620 fn name_override() -> Result<(), Box<dyn Error + Send + Sync>> {
621 let address = SocketAddress::new_v4(127, 0, 0, 1, 8080);
622 let mut resolver = CertificateResolver::default();
623 let certificate_and_key = CertificateAndKey {
624 certificate: String::from(include_str!("../assets/certificate.pem")),
625 key: String::from(include_str!("../assets/key.pem")),
626 names: vec!["localhost".into(), "lolcatho.st".into()],
627 ..Default::default()
628 };
629
630 let fingerprint = resolver.add_certificate(&AddCertificate {
631 address,
632 certificate: certificate_and_key,
633 expired_at: None,
634 })?;
635
636 if resolver.get_certificate(&fingerprint).is_none() {
637 return Err("failed to retrieve certificate".into());
638 }
639
640 let mut lolcat = HashSet::new();
641 lolcat.insert(String::from("lolcatho.st"));
642 if resolver.find_certificates_by_names(&lolcat)?.is_empty()
643 || resolver.get_certificate(&fingerprint).is_none()
644 {
645 return Err("failed to retrieve certificate with custom names".into());
646 }
647
648 if let Err(err) = resolver.remove_certificate(&fingerprint) {
649 return Err(format!("the certificate could not be removed, {err}").into());
650 }
651
652 let names = resolver.certificate_names(&fingerprint)?;
653 if !resolver.find_certificates_by_names(&names)?.is_empty()
654 && resolver.get_certificate(&fingerprint).is_some()
655 {
656 return Err("We have retrieved the certificate that should be deleted".into());
657 }
658
659 Ok(())
660 }
661
662 #[test]
663 fn keep_resolving_with_wildcard() -> Result<(), Box<dyn Error + Send + Sync>> {
664 let address = SocketAddress::new_v4(127, 0, 0, 1, 8080);
665 let mut resolver = CertificateResolver::default();
666
667 let wildcard_example_org = CertificateAndKey {
670 certificate: String::from(include_str!("../assets/tests/certificate-3.pem")),
671 key: String::from(include_str!("../assets/tests/key.pem")),
672 ..Default::default()
673 };
674
675 let wildcard_example_org_fingerprint = resolver.add_certificate(&AddCertificate {
676 address,
677 certificate: wildcard_example_org,
678 expired_at: Some(
679 (SystemTime::now().duration_since(SystemTime::UNIX_EPOCH)?
680 + Duration::from_secs(365 * 24 * 3600))
681 .as_secs() as i64,
682 ),
683 })?;
684
685 if resolver
686 .get_certificate(&wildcard_example_org_fingerprint)
687 .is_none()
688 {
689 return Err("could not load the 2-year-valid certificate".into());
690 }
691
692 let www_example_org = CertificateAndKey {
696 certificate: String::from(include_str!("../assets/tests/certificate-2.pem")),
697 key: String::from(include_str!("../assets/tests/key.pem")),
698 ..Default::default()
699 };
700
701 let www_example_org_fingerprint = resolver.add_certificate(&AddCertificate {
702 address,
703 certificate: www_example_org,
704 expired_at: Some(
705 (SystemTime::now().duration_since(SystemTime::UNIX_EPOCH)?
706 + Duration::from_secs(2 * 365 * 24 * 3600))
707 .as_secs() as i64,
708 ),
709 })?;
710
711 let www_example_org = resolver
712 .domain_lookup("www.example.org".as_bytes(), true)
713 .expect("there should be a www.example.org cert");
714 assert_eq!(www_example_org.1, www_example_org_fingerprint);
715
716 let test_example_org = resolver
717 .domain_lookup("test.example.org".as_bytes(), true)
718 .expect("there should be a test.example.org cert");
719 assert_eq!(test_example_org.1, wildcard_example_org_fingerprint);
720
721 let example_org = resolver
722 .domain_lookup("example.org".as_bytes(), true)
723 .expect("there should be a example.org cert");
724 assert_eq!(example_org.1, www_example_org_fingerprint);
725
726 resolver
729 .remove_certificate(&www_example_org_fingerprint)
730 .expect("should be able to remove the 2-year certificate");
731
732 let should_be_wildcard_fingerprint = resolver
733 .domain_lookup("www.example.org".as_bytes(), true)
734 .expect("there should be a www.example.org cert");
735 assert_eq!(
736 should_be_wildcard_fingerprint.1,
737 wildcard_example_org_fingerprint
738 );
739
740 assert!(
741 resolver
742 .domain_lookup("example.org".as_bytes(), true)
743 .is_none()
744 );
745
746 Ok(())
747 }
748
749 #[test]
750 fn resolve_the_longer_lived_cert() -> Result<(), Box<dyn Error + Send + Sync>> {
751 let address = SocketAddress::new_v4(127, 0, 0, 1, 8080);
752 let mut resolver = CertificateResolver::default();
753
754 let certificate_and_key_2y = CertificateAndKey {
757 certificate: String::from(include_str!("../assets/tests/certificate-2y.pem")),
758 key: String::from(include_str!("../assets/tests/key-2y.pem")),
759 ..Default::default()
760 };
761
762 let fingerprint_2y = resolver.add_certificate(&AddCertificate {
763 address,
764 certificate: certificate_and_key_2y,
765 expired_at: None,
766 })?;
767
768 if resolver.get_certificate(&fingerprint_2y).is_none() {
769 return Err("could not load the 2-year-valid certificate".into());
770 }
771
772 let certificate_and_key_1y = CertificateAndKey {
775 certificate: String::from(include_str!("../assets/tests/certificate-1y.pem")),
776 key: String::from(include_str!("../assets/tests/key-1y.pem")),
777 ..Default::default()
778 };
779
780 let fingerprint_1y = resolver.add_certificate(&AddCertificate {
781 address,
782 certificate: certificate_and_key_1y,
783 ..Default::default()
784 })?;
785
786 let localhost_cert = resolver
787 .domain_lookup("localhost".as_bytes(), true)
788 .expect("there should be a localhost cert");
789
790 assert_eq!(localhost_cert.1, fingerprint_2y);
791
792 resolver
796 .remove_certificate(&fingerprint_2y)
797 .expect("should be able to remove the 2-year certificate");
798
799 let localhost_cert = resolver
800 .domain_lookup("localhost".as_bytes(), true)
801 .expect("there should be a localhost cert");
802
803 assert_eq!(localhost_cert.1, fingerprint_1y);
804
805 Ok(())
806 }
807
808 #[test]
809 fn expiration_override() -> Result<(), Box<dyn Error + Send + Sync>> {
810 let address = SocketAddress::new_v4(127, 0, 0, 1, 8080);
811 let mut resolver = CertificateResolver::default();
812
813 let certificate_and_key_1y = CertificateAndKey {
816 certificate: String::from(include_str!("../assets/tests/certificate-1y.pem")),
817 key: String::from(include_str!("../assets/tests/key-1y.pem")),
818 ..Default::default()
819 };
820
821 let fingerprint_1y_overriden = resolver.add_certificate(&AddCertificate {
822 address,
823 certificate: certificate_and_key_1y,
824 expired_at: Some(
825 (SystemTime::now().duration_since(SystemTime::UNIX_EPOCH)?
826 + Duration::from_secs(3 * 365 * 24 * 3600))
827 .as_secs() as i64,
828 ),
829 })?;
830
831 if resolver
832 .get_certificate(&fingerprint_1y_overriden)
833 .is_none()
834 {
835 return Err("failed to retrieve certificate".into());
836 }
837
838 let certificate_and_key_2y = CertificateAndKey {
841 certificate: String::from(include_str!("../assets/tests/certificate-2y.pem")),
842 key: String::from(include_str!("../assets/tests/key-2y.pem")),
843 ..Default::default()
844 };
845
846 let fingerprint_2y = resolver.add_certificate(&AddCertificate {
847 address,
848 certificate: certificate_and_key_2y,
849 expired_at: None,
850 })?;
851
852 let localhost_cert = resolver
853 .domain_lookup("localhost".as_bytes(), true)
854 .expect("there should be a localhost cert");
855
856 assert_eq!(localhost_cert.1, fingerprint_1y_overriden);
857
858 resolver
862 .remove_certificate(&fingerprint_1y_overriden)
863 .expect("should be able to remove the 1-year (3-year-overriden) certificate");
864
865 let localhost_cert = resolver
866 .domain_lookup("localhost".as_bytes(), true)
867 .expect("there should be a localhost cert");
868
869 assert_eq!(localhost_cert.1, fingerprint_2y);
870
871 Ok(())
872 }
873
874 #[test]
877 fn replace_certificate_add_before_remove() -> Result<(), Box<dyn Error + Send + Sync>> {
878 let address = SocketAddress::new_v4(127, 0, 0, 1, 8080);
879 let mut resolver = CertificateResolver::default();
880
881 let cert_1y = CertificateAndKey {
883 certificate: String::from(include_str!("../assets/tests/certificate-1y.pem")),
884 key: String::from(include_str!("../assets/tests/key-1y.pem")),
885 ..Default::default()
886 };
887
888 let fingerprint_1y = resolver.add_certificate(&AddCertificate {
889 address,
890 certificate: cert_1y,
891 expired_at: None,
892 })?;
893
894 assert!(
896 resolver
897 .domain_lookup("localhost".as_bytes(), true)
898 .is_some(),
899 "initial certificate should be resolvable"
900 );
901
902 let cert_2y = CertificateAndKey {
904 certificate: String::from(include_str!("../assets/tests/certificate-2y.pem")),
905 key: String::from(include_str!("../assets/tests/key-2y.pem")),
906 ..Default::default()
907 };
908
909 let new_fingerprint = resolver.replace_certificate(&ReplaceCertificate {
910 address,
911 new_certificate: cert_2y,
912 old_fingerprint: fingerprint_1y.to_string(),
913 new_expired_at: None,
914 })?;
915
916 assert!(
918 resolver.get_certificate(&fingerprint_1y).is_none(),
919 "old certificate should have been removed"
920 );
921
922 assert!(
924 resolver.get_certificate(&new_fingerprint).is_some(),
925 "new certificate should be present"
926 );
927 let resolved = resolver
928 .domain_lookup("localhost".as_bytes(), true)
929 .expect("a certificate should resolve for localhost");
930 assert_eq!(
931 resolved.1, new_fingerprint,
932 "resolved certificate should be the replacement"
933 );
934
935 Ok(())
936 }
937
938 #[test]
947 fn replace_certificate_with_same_fingerprint_is_noop()
948 -> Result<(), Box<dyn Error + Send + Sync>> {
949 let address = SocketAddress::new_v4(127, 0, 0, 1, 8080);
950 let mut resolver = CertificateResolver::default();
951
952 let cert = CertificateAndKey {
953 certificate: String::from(include_str!("../assets/tests/certificate-1y.pem")),
954 key: String::from(include_str!("../assets/tests/key-1y.pem")),
955 ..Default::default()
956 };
957
958 let initial_fingerprint = resolver.add_certificate(&AddCertificate {
959 address,
960 certificate: cert.clone(),
961 expired_at: None,
962 })?;
963
964 let returned_fingerprint = resolver.replace_certificate(&ReplaceCertificate {
966 address,
967 new_certificate: cert,
968 old_fingerprint: initial_fingerprint.to_string(),
969 new_expired_at: None,
970 })?;
971
972 assert_eq!(
973 returned_fingerprint, initial_fingerprint,
974 "idempotent replace should return the existing fingerprint"
975 );
976
977 assert!(
978 resolver.get_certificate(&initial_fingerprint).is_some(),
979 "idempotent replace must NOT delete the existing certificate"
980 );
981
982 let resolved = resolver
983 .domain_lookup("localhost".as_bytes(), true)
984 .expect("certificate should still resolve after idempotent replace");
985 assert_eq!(
986 resolved.1, initial_fingerprint,
987 "resolver should still hand back the original fingerprint"
988 );
989
990 Ok(())
991 }
992
993 #[test]
996 fn removal_cleans_up_empty_index_entries() -> Result<(), Box<dyn Error + Send + Sync>> {
997 let address = SocketAddress::new_v4(127, 0, 0, 1, 8080);
998 let mut resolver = CertificateResolver::default();
999
1000 let cert = CertificateAndKey {
1001 certificate: String::from(include_str!("../assets/tests/certificate-1y.pem")),
1002 key: String::from(include_str!("../assets/tests/key-1y.pem")),
1003 ..Default::default()
1004 };
1005
1006 let fingerprint = resolver.add_certificate(&AddCertificate {
1007 address,
1008 certificate: cert,
1009 expired_at: None,
1010 })?;
1011
1012 let names = resolver.certificate_names(&fingerprint)?;
1014 assert!(
1015 !names.is_empty(),
1016 "certificate should have at least one name"
1017 );
1018
1019 for name in &names {
1021 assert!(
1022 resolver.name_fingerprint_idx.contains_key(name),
1023 "name_fingerprint_idx should contain '{name}' before removal"
1024 );
1025 }
1026
1027 resolver.remove_certificate(&fingerprint)?;
1028
1029 for name in &names {
1031 assert!(
1032 !resolver.name_fingerprint_idx.contains_key(name),
1033 "name_fingerprint_idx should not contain empty entry for '{name}' after removal"
1034 );
1035 }
1036
1037 Ok(())
1038 }
1039
1040 #[test]
1054 fn certificate_chain_dedup_drops_duplicate_leaf() -> Result<(), Box<dyn Error + Send + Sync>> {
1055 let address = SocketAddress::new_v4(127, 0, 0, 1, 8080);
1056 let mut resolver = CertificateResolver::default();
1057
1058 let leaf_pem = String::from(include_str!("../assets/certificate.pem"));
1059
1060 let cert_with_duplicated_leaf = CertificateAndKey {
1061 certificate: leaf_pem.clone(),
1062 certificate_chain: vec![leaf_pem],
1063 key: String::from(include_str!("../assets/key.pem")),
1064 ..Default::default()
1065 };
1066
1067 let fingerprint = resolver.add_certificate(&AddCertificate {
1068 address,
1069 certificate: cert_with_duplicated_leaf,
1070 expired_at: None,
1071 })?;
1072
1073 let stored = resolver
1074 .get_certificate(&fingerprint)
1075 .ok_or("resolver lost the certificate after add")?;
1076
1077 assert_eq!(
1078 stored.inner.cert.len(),
1079 1,
1080 "expected dedup to drop the duplicate leaf, got chain of {} cert(s)",
1081 stored.inner.cert.len()
1082 );
1083
1084 Ok(())
1085 }
1086
1087 #[test]
1101 fn certificate_chain_handles_multi_pem_single_entry() -> Result<(), Box<dyn Error + Send + Sync>>
1102 {
1103 let address = SocketAddress::new_v4(127, 0, 0, 1, 8080);
1104 let mut resolver = CertificateResolver::default();
1105
1106 let leaf_pem = String::from(include_str!("../assets/certificate.pem"));
1107 let multi_pem_chain_entry = format!("{leaf_pem}\n{leaf_pem}");
1110
1111 let cert = CertificateAndKey {
1112 certificate: leaf_pem,
1113 certificate_chain: vec![multi_pem_chain_entry],
1114 key: String::from(include_str!("../assets/key.pem")),
1115 ..Default::default()
1116 };
1117
1118 let fingerprint = resolver.add_certificate(&AddCertificate {
1119 address,
1120 certificate: cert,
1121 expired_at: None,
1122 })?;
1123
1124 let stored = resolver
1125 .get_certificate(&fingerprint)
1126 .ok_or("resolver lost the certificate after add")?;
1127
1128 assert_eq!(
1133 stored.inner.cert.len(),
1134 1,
1135 "expected split + dedup to leave only the leaf, got chain of {} cert(s)",
1136 stored.inner.cert.len()
1137 );
1138
1139 Ok(())
1140 }
1141}