1use std::fs;
31use std::io::BufReader;
32use std::path::Path;
33use std::sync::Arc;
34use std::time::{Duration, SystemTime};
35
36use parking_lot::RwLock;
37use rcgen::{CertificateParams, DistinguishedName, DnType, Issuer, KeyPair, SanType};
38use rustls::RootCertStore;
39use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer};
40use tokio::sync::watch;
41use tracing::{debug, error, info, warn};
42use x509_parser::prelude::*;
43
44use crate::error::{NetError, NetResult};
45
/// Encoding of certificate material.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CertificateFormat {
    /// Base64 text wrapped in `-----BEGIN ...-----` / `-----END ...-----` armor.
    Pem,
    /// Raw binary ASN.1 DER encoding.
    Der,
}
54
/// Kind of private key contained in raw DER data, used to select its wrapper type.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PrivateKeyType {
    /// RSA private key.
    Rsa,
    /// ECDSA private key.
    Ecdsa,
    /// Ed25519 private key.
    Ed25519,
    /// Generic PKCS#8 `PrivateKeyInfo` structure.
    Pkcs8,
}
67
/// Metadata extracted from a parsed X.509 certificate.
#[derive(Debug, Clone)]
pub struct CertificateInfo {
    /// First Common Name (CN) attribute of the subject, if present.
    pub common_name: Option<String>,
    /// Subject alternative names rendered as strings (DNS names, IP addresses, e-mails, URIs).
    pub subject_alt_names: Vec<String>,
    /// First Common Name (CN) attribute of the issuer, if present.
    pub issuer: Option<String>,
    /// Serial number rendered as lowercase hexadecimal.
    pub serial_number: String,
    /// Start of the validity period.
    pub not_before: SystemTime,
    /// End of the validity period.
    pub not_after: SystemTime,
    /// Whether the certificate is flagged as a certificate authority.
    pub is_ca: bool,
    /// Names of the key-usage bits that are set (e.g. `"digitalSignature"`).
    pub key_usage: Vec<String>,
    /// Extended key usages: named purposes (e.g. `"serverAuth"`) and raw OIDs.
    pub extended_key_usage: Vec<String>,
    /// Lowercase hex fingerprint of the DER encoding.
    pub fingerprint_sha256: String,
}

impl CertificateInfo {
    /// Returns `true` if the current time lies within `[not_before, not_after]` (inclusive).
    pub fn is_valid(&self) -> bool {
        let now = SystemTime::now();
        now >= self.not_before && now <= self.not_after
    }

    /// Returns the time remaining until `not_after`, or `Duration::ZERO` if already expired.
    ///
    /// This always returns `Some`; the `Option` is kept for API compatibility.
    pub fn time_to_expiry(&self) -> Option<Duration> {
        // Read the clock exactly once. Two separate reads could straddle the expiry
        // instant and yield `None`, making `expires_within` report `false` for an
        // already-expired certificate.
        let remaining = self
            .not_after
            .duration_since(SystemTime::now())
            .unwrap_or(Duration::ZERO);
        Some(remaining)
    }

    /// Returns `true` if the certificate expires within `duration` from now, or has already expired.
    pub fn expires_within(&self, duration: Duration) -> bool {
        self.time_to_expiry()
            .is_some_and(|remaining| remaining <= duration)
    }
}
115
116#[derive(Debug, Clone)]
118pub struct CertificateLoader {
119 validate_on_load: bool,
121}
122
123impl Default for CertificateLoader {
124 fn default() -> Self {
125 Self::new()
126 }
127}
128
129impl CertificateLoader {
130 pub fn new() -> Self {
132 Self {
133 validate_on_load: true,
134 }
135 }
136
137 pub fn without_validation() -> Self {
139 Self {
140 validate_on_load: false,
141 }
142 }
143
144 pub fn load_pem_file<P: AsRef<Path>>(
154 &self,
155 path: P,
156 ) -> NetResult<Vec<CertificateDer<'static>>> {
157 let path = path.as_ref();
158 debug!(path = %path.display(), "Loading PEM certificates from file");
159
160 let file = fs::File::open(path)
161 .map_err(|e| NetError::InvalidCertificate(format!("Failed to open PEM file: {e}")))?;
162 let mut reader = BufReader::new(file);
163
164 let certs: Vec<_> = rustls_pemfile::certs(&mut reader)
165 .filter_map(|result| result.ok())
166 .collect();
167
168 if certs.is_empty() {
169 return Err(NetError::InvalidCertificate(
170 "No certificates found in PEM file".to_string(),
171 ));
172 }
173
174 if self.validate_on_load {
175 for cert in &certs {
176 self.validate_certificate_der(cert)?;
177 }
178 }
179
180 info!(count = certs.len(), "Loaded certificates from PEM file");
181 Ok(certs)
182 }
183
184 pub fn load_der_file<P: AsRef<Path>>(&self, path: P) -> NetResult<CertificateDer<'static>> {
194 let path = path.as_ref();
195 debug!(path = %path.display(), "Loading DER certificate from file");
196
197 let der_data = fs::read(path)
198 .map_err(|e| NetError::InvalidCertificate(format!("Failed to read DER file: {e}")))?;
199
200 let cert = CertificateDer::from(der_data);
201
202 if self.validate_on_load {
203 self.validate_certificate_der(&cert)?;
204 }
205
206 info!("Loaded DER certificate from file");
207 Ok(cert)
208 }
209
210 pub fn load_pem_bytes(&self, pem_data: &[u8]) -> NetResult<Vec<CertificateDer<'static>>> {
212 let mut reader = BufReader::new(pem_data);
213
214 let certs: Vec<_> = rustls_pemfile::certs(&mut reader)
215 .filter_map(|result| result.ok())
216 .collect();
217
218 if certs.is_empty() {
219 return Err(NetError::InvalidCertificate(
220 "No certificates found in PEM data".to_string(),
221 ));
222 }
223
224 if self.validate_on_load {
225 for cert in &certs {
226 self.validate_certificate_der(cert)?;
227 }
228 }
229
230 Ok(certs)
231 }
232
233 pub fn load_der_bytes(&self, der_data: &[u8]) -> NetResult<CertificateDer<'static>> {
235 let cert = CertificateDer::from(der_data.to_vec());
236
237 if self.validate_on_load {
238 self.validate_certificate_der(&cert)?;
239 }
240
241 Ok(cert)
242 }
243
244 fn validate_certificate_der(&self, cert: &CertificateDer<'_>) -> NetResult<()> {
246 let (_, parsed) = X509Certificate::from_der(cert.as_ref()).map_err(|e| {
247 NetError::InvalidCertificate(format!("Failed to parse certificate: {e}"))
248 })?;
249
250 let now = ASN1Time::now();
252 if parsed.validity().not_before > now {
253 return Err(NetError::InvalidCertificate(
254 "Certificate is not yet valid".to_string(),
255 ));
256 }
257 if parsed.validity().not_after < now {
258 return Err(NetError::InvalidCertificate(
259 "Certificate has expired".to_string(),
260 ));
261 }
262
263 Ok(())
264 }
265
266 pub fn get_certificate_info(&self, cert: &CertificateDer<'_>) -> NetResult<CertificateInfo> {
268 let (_, parsed) = X509Certificate::from_der(cert.as_ref()).map_err(|e| {
269 NetError::InvalidCertificate(format!("Failed to parse certificate: {e}"))
270 })?;
271
272 let common_name = parsed
273 .subject()
274 .iter_common_name()
275 .next()
276 .and_then(|cn| cn.as_str().ok())
277 .map(String::from);
278
279 let issuer = parsed
280 .issuer()
281 .iter_common_name()
282 .next()
283 .and_then(|cn| cn.as_str().ok())
284 .map(String::from);
285
286 let mut subject_alt_names = Vec::new();
287 if let Ok(Some(san)) = parsed.subject_alternative_name() {
288 for name in san.value.general_names.iter() {
289 match name {
290 GeneralName::DNSName(dns) => subject_alt_names.push(dns.to_string()),
291 GeneralName::IPAddress(ip) => {
292 if ip.len() == 4 {
293 subject_alt_names
294 .push(format!("{}.{}.{}.{}", ip[0], ip[1], ip[2], ip[3]));
295 } else if ip.len() == 16 {
296 let mut parts = Vec::with_capacity(8);
298 for i in 0..8 {
299 let val = u16::from_be_bytes([ip[i * 2], ip[i * 2 + 1]]);
300 parts.push(format!("{val:x}"));
301 }
302 subject_alt_names.push(parts.join(":"));
303 }
304 }
305 GeneralName::RFC822Name(email) => subject_alt_names.push(email.to_string()),
306 GeneralName::URI(uri) => subject_alt_names.push(uri.to_string()),
307 _ => {}
308 }
309 }
310 }
311
312 let serial_number = format!("{:x}", parsed.serial);
313
314 let not_before = asn1_time_to_system_time(&parsed.validity().not_before);
315 let not_after = asn1_time_to_system_time(&parsed.validity().not_after);
316
317 let is_ca = parsed.is_ca();
318
319 let mut key_usage = Vec::new();
320 if let Ok(Some(ku)) = parsed.key_usage() {
321 let flags = ku.value;
322 if flags.digital_signature() {
323 key_usage.push("digitalSignature".to_string());
324 }
325 if flags.non_repudiation() {
326 key_usage.push("nonRepudiation".to_string());
327 }
328 if flags.key_encipherment() {
329 key_usage.push("keyEncipherment".to_string());
330 }
331 if flags.data_encipherment() {
332 key_usage.push("dataEncipherment".to_string());
333 }
334 if flags.key_agreement() {
335 key_usage.push("keyAgreement".to_string());
336 }
337 if flags.key_cert_sign() {
338 key_usage.push("keyCertSign".to_string());
339 }
340 if flags.crl_sign() {
341 key_usage.push("cRLSign".to_string());
342 }
343 }
344
345 let mut extended_key_usage = Vec::new();
346 if let Ok(Some(eku)) = parsed.extended_key_usage() {
347 for oid in eku.value.other.iter() {
348 extended_key_usage.push(oid.to_string());
349 }
350 if eku.value.any {
351 extended_key_usage.push("anyExtendedKeyUsage".to_string());
352 }
353 if eku.value.server_auth {
354 extended_key_usage.push("serverAuth".to_string());
355 }
356 if eku.value.client_auth {
357 extended_key_usage.push("clientAuth".to_string());
358 }
359 if eku.value.code_signing {
360 extended_key_usage.push("codeSigning".to_string());
361 }
362 if eku.value.email_protection {
363 extended_key_usage.push("emailProtection".to_string());
364 }
365 if eku.value.time_stamping {
366 extended_key_usage.push("timeStamping".to_string());
367 }
368 if eku.value.ocsp_signing {
369 extended_key_usage.push("ocspSigning".to_string());
370 }
371 }
372
373 use std::fmt::Write;
375 let fingerprint_sha256 = cert
376 .as_ref()
377 .iter()
378 .take(32) .fold(String::new(), |mut s, b| {
380 let _ = write!(&mut s, "{b:02x}");
381 s
382 });
383
384 Ok(CertificateInfo {
385 common_name,
386 subject_alt_names,
387 issuer,
388 serial_number,
389 not_before,
390 not_after,
391 is_ca,
392 key_usage,
393 extended_key_usage,
394 fingerprint_sha256,
395 })
396 }
397}
398
399fn asn1_time_to_system_time(time: &ASN1Time) -> SystemTime {
401 let timestamp = time.timestamp();
403 if timestamp >= 0 {
404 SystemTime::UNIX_EPOCH + Duration::from_secs(timestamp as u64)
405 } else {
406 SystemTime::UNIX_EPOCH
408 }
409}
410
411#[derive(Debug, Clone)]
413pub struct PrivateKeyLoader;
414
415impl Default for PrivateKeyLoader {
416 fn default() -> Self {
417 Self::new()
418 }
419}
420
421impl PrivateKeyLoader {
422 pub fn new() -> Self {
424 Self
425 }
426
427 pub fn load_pem_file<P: AsRef<Path>>(&self, path: P) -> NetResult<PrivateKeyDer<'static>> {
431 let path = path.as_ref();
432 debug!(path = %path.display(), "Loading private key from PEM file");
433
434 let file = fs::File::open(path)
435 .map_err(|e| NetError::InvalidCertificate(format!("Failed to open key file: {e}")))?;
436 let mut reader = BufReader::new(file);
437
438 self.load_from_reader(&mut reader)
439 }
440
441 pub fn load_pem_bytes(&self, pem_data: &[u8]) -> NetResult<PrivateKeyDer<'static>> {
443 let mut reader = BufReader::new(pem_data);
444 self.load_from_reader(&mut reader)
445 }
446
447 pub fn load_der_file<P: AsRef<Path>>(
449 &self,
450 path: P,
451 key_type: PrivateKeyType,
452 ) -> NetResult<PrivateKeyDer<'static>> {
453 let path = path.as_ref();
454 debug!(path = %path.display(), "Loading private key from DER file");
455
456 let der_data = fs::read(path)
457 .map_err(|e| NetError::InvalidCertificate(format!("Failed to read key file: {e}")))?;
458
459 self.load_der_bytes(&der_data, key_type)
460 }
461
462 pub fn load_der_bytes(
464 &self,
465 der_data: &[u8],
466 key_type: PrivateKeyType,
467 ) -> NetResult<PrivateKeyDer<'static>> {
468 let key = match key_type {
469 PrivateKeyType::Rsa => PrivateKeyDer::Pkcs1(der_data.to_vec().into()),
470 PrivateKeyType::Ecdsa | PrivateKeyType::Ed25519 => {
471 PrivateKeyDer::Sec1(der_data.to_vec().into())
472 }
473 PrivateKeyType::Pkcs8 => {
474 PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(der_data.to_vec()))
475 }
476 };
477
478 Ok(key)
479 }
480
481 pub fn load_encrypted_pem_file<P: AsRef<Path>>(
487 &self,
488 path: P,
489 _password: &str,
490 ) -> NetResult<PrivateKeyDer<'static>> {
491 warn!(
495 "Encrypted key loading: attempting to load key, password decryption may require external tools"
496 );
497 self.load_pem_file(path)
498 }
499
500 fn load_from_reader<R: std::io::BufRead>(
502 &self,
503 reader: &mut R,
504 ) -> NetResult<PrivateKeyDer<'static>> {
505 let mut original_data: Vec<u8> = Vec::new();
507 reader
508 .read_to_end(&mut original_data)
509 .map_err(|e| NetError::InvalidCertificate(format!("Failed to read key data: {e}")))?;
510
511 let mut cursor = std::io::Cursor::new(&original_data);
512
513 if let Some(Ok(key)) = rustls_pemfile::pkcs8_private_keys(&mut cursor).next() {
515 info!("Loaded PKCS#8 private key");
516 return Ok(PrivateKeyDer::Pkcs8(key));
517 }
518
519 let mut cursor = std::io::Cursor::new(&original_data);
521 if let Some(Ok(key)) = rustls_pemfile::rsa_private_keys(&mut cursor).next() {
522 info!("Loaded RSA private key");
523 return Ok(PrivateKeyDer::Pkcs1(key));
524 }
525
526 let mut cursor = std::io::Cursor::new(&original_data);
528 if let Some(Ok(key)) = rustls_pemfile::ec_private_keys(&mut cursor).next() {
529 info!("Loaded EC private key");
530 return Ok(PrivateKeyDer::Sec1(key));
531 }
532
533 Err(NetError::InvalidCertificate(
534 "No valid private key found in PEM data (tried PKCS#8, RSA, EC formats)".to_string(),
535 ))
536 }
537}
538
539#[derive(Debug, Clone)]
541pub struct SelfSignedGenerator {
542 subject_alt_names: Vec<String>,
544 common_name: String,
546 organization: Option<String>,
548 validity_days: u32,
550 is_ca: bool,
552}
553
554impl SelfSignedGenerator {
555 pub fn new(common_name: impl Into<String>) -> Self {
561 Self {
562 common_name: common_name.into(),
563 subject_alt_names: vec!["localhost".to_string()],
564 organization: None,
565 validity_days: 365,
566 is_ca: false,
567 }
568 }
569
570 pub fn with_san(mut self, san: impl Into<String>) -> Self {
572 self.subject_alt_names.push(san.into());
573 self
574 }
575
576 pub fn with_sans<I, S>(mut self, sans: I) -> Self
578 where
579 I: IntoIterator<Item = S>,
580 S: Into<String>,
581 {
582 self.subject_alt_names
583 .extend(sans.into_iter().map(|s| s.into()));
584 self
585 }
586
587 pub fn with_organization(mut self, org: impl Into<String>) -> Self {
589 self.organization = Some(org.into());
590 self
591 }
592
593 pub fn with_validity_days(mut self, days: u32) -> Self {
595 self.validity_days = days;
596 self
597 }
598
599 pub fn as_ca(mut self) -> Self {
601 self.is_ca = true;
602 self
603 }
604
605 pub fn generate(&self) -> NetResult<(CertificateDer<'static>, PrivateKeyDer<'static>)> {
611 let mut params = CertificateParams::default();
612
613 let mut dn = DistinguishedName::new();
615 dn.push(DnType::CommonName, &self.common_name);
616 if let Some(ref org) = self.organization {
617 dn.push(DnType::OrganizationName, org);
618 }
619 params.distinguished_name = dn;
620
621 params.subject_alt_names = self
623 .subject_alt_names
624 .iter()
625 .map(|name| {
626 if let Ok(ip) = name.parse::<std::net::IpAddr>() {
628 SanType::IpAddress(ip)
629 } else {
630 SanType::DnsName(name.clone().try_into().unwrap_or_else(|_| {
631 "localhost"
632 .to_string()
633 .try_into()
634 .expect("localhost is valid DNS name")
635 }))
636 }
637 })
638 .collect();
639
640 params.not_before = rcgen::date_time_ymd(
642 chrono::Utc::now().year(),
643 chrono::Utc::now().month() as u8,
644 chrono::Utc::now().day() as u8,
645 );
646
647 let future = chrono::Utc::now() + chrono::Duration::days(self.validity_days as i64);
648 params.not_after =
649 rcgen::date_time_ymd(future.year(), future.month() as u8, future.day() as u8);
650
651 if self.is_ca {
653 params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
654 }
655
656 let key_pair = KeyPair::generate().map_err(|e| {
658 NetError::InvalidCertificate(format!("Failed to generate key pair: {e}"))
659 })?;
660
661 let cert = params.self_signed(&key_pair).map_err(|e| {
663 NetError::InvalidCertificate(format!("Failed to generate certificate: {e}"))
664 })?;
665
666 let cert_der = CertificateDer::from(cert.der().to_vec());
667 let key_der = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key_pair.serialize_der()));
668
669 info!(
670 common_name = %self.common_name,
671 is_ca = self.is_ca,
672 validity_days = self.validity_days,
673 "Generated self-signed certificate"
674 );
675
676 Ok((cert_der, key_der))
677 }
678
679 pub fn generate_signed_by_keypair(
684 &self,
685 ca_key_pair: &KeyPair,
686 ca_common_name: &str,
687 ) -> NetResult<(CertificateDer<'static>, PrivateKeyDer<'static>)> {
688 let mut params = CertificateParams::default();
689
690 let mut dn = DistinguishedName::new();
692 dn.push(DnType::CommonName, &self.common_name);
693 if let Some(ref org) = self.organization {
694 dn.push(DnType::OrganizationName, org);
695 }
696 params.distinguished_name = dn;
697
698 params.subject_alt_names = self
700 .subject_alt_names
701 .iter()
702 .map(|name| {
703 if let Ok(ip) = name.parse::<std::net::IpAddr>() {
704 SanType::IpAddress(ip)
705 } else {
706 SanType::DnsName(name.clone().try_into().unwrap_or_else(|_| {
707 "localhost"
708 .to_string()
709 .try_into()
710 .expect("localhost is valid DNS name")
711 }))
712 }
713 })
714 .collect();
715
716 params.not_before = rcgen::date_time_ymd(
718 chrono::Utc::now().year(),
719 chrono::Utc::now().month() as u8,
720 chrono::Utc::now().day() as u8,
721 );
722
723 let future = chrono::Utc::now() + chrono::Duration::days(self.validity_days as i64);
724 params.not_after =
725 rcgen::date_time_ymd(future.year(), future.month() as u8, future.day() as u8);
726
727 let key_pair = KeyPair::generate().map_err(|e| {
729 NetError::InvalidCertificate(format!("Failed to generate key pair: {e}"))
730 })?;
731
732 let mut ca_params = CertificateParams::default();
734 ca_params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
735
736 let mut issuer_dn = DistinguishedName::new();
738 issuer_dn.push(DnType::CommonName, ca_common_name);
739 ca_params.distinguished_name = issuer_dn;
740
741 let issuer = Issuer::from_params(&ca_params, ca_key_pair);
743
744 let signed_cert = params.signed_by(&key_pair, &issuer).map_err(|e| {
746 NetError::InvalidCertificate(format!("Failed to sign certificate: {e}"))
747 })?;
748
749 let cert_der = CertificateDer::from(signed_cert.der().to_vec());
750 let key_der = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key_pair.serialize_der()));
751
752 info!(
753 common_name = %self.common_name,
754 "Generated CA-signed certificate"
755 );
756
757 Ok((cert_der, key_der))
758 }
759}
760
761use chrono::Datelike;
762
763#[derive(Debug)]
765pub struct CertificateStore {
766 roots: Arc<RwLock<RootCertStore>>,
768 cert_chain: Arc<RwLock<Vec<CertificateDer<'static>>>>,
770 cert_info: Arc<RwLock<Vec<CertificateInfo>>>,
772}
773
774impl Default for CertificateStore {
775 fn default() -> Self {
776 Self::new()
777 }
778}
779
780impl Clone for CertificateStore {
781 fn clone(&self) -> Self {
782 Self {
783 roots: Arc::new(RwLock::new((*self.roots.read()).clone())),
784 cert_chain: Arc::new(RwLock::new(self.cert_chain.read().clone())),
785 cert_info: Arc::new(RwLock::new(self.cert_info.read().clone())),
786 }
787 }
788}
789
790impl CertificateStore {
791 pub fn new() -> Self {
793 Self {
794 roots: Arc::new(RwLock::new(RootCertStore::empty())),
795 cert_chain: Arc::new(RwLock::new(Vec::new())),
796 cert_info: Arc::new(RwLock::new(Vec::new())),
797 }
798 }
799
800 pub fn add_system_roots(&mut self) -> NetResult<usize> {
802 let mut roots = self.roots.write();
803 let count_before = roots.len();
804
805 roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
806
807 let added = roots.len() - count_before;
808 info!(count = added, "Added system root certificates");
809 Ok(added)
810 }
811
812 pub fn add_certificate(&mut self, cert: CertificateDer<'static>) -> NetResult<()> {
814 let loader = CertificateLoader::new();
815 let info = loader.get_certificate_info(&cert)?;
816
817 if !info.is_ca {
818 warn!(common_name = ?info.common_name, "Adding non-CA certificate to root store");
819 }
820
821 {
822 let mut roots = self.roots.write();
823 roots.add(cert.clone()).map_err(|e| {
824 NetError::InvalidCertificate(format!("Failed to add certificate: {e}"))
825 })?;
826 }
827
828 {
829 let mut chain = self.cert_chain.write();
830 chain.push(cert);
831 }
832
833 {
834 let mut infos = self.cert_info.write();
835 infos.push(info);
836 }
837
838 Ok(())
839 }
840
841 pub fn add_certificates_from_file<P: AsRef<Path>>(&mut self, path: P) -> NetResult<usize> {
843 let loader = CertificateLoader::new();
844 let certs = loader.load_pem_file(path)?;
845
846 let count = certs.len();
847 for cert in certs {
848 self.add_certificate(cert)?;
849 }
850
851 Ok(count)
852 }
853
854 pub fn get_root_store(&self) -> RootCertStore {
856 self.roots.read().clone()
857 }
858
859 pub fn get_cert_chain(&self) -> Vec<CertificateDer<'static>> {
861 self.cert_chain.read().clone()
862 }
863
864 pub fn len(&self) -> usize {
866 self.roots.read().len()
867 }
868
869 pub fn is_empty(&self) -> bool {
871 self.roots.read().is_empty()
872 }
873
874 pub fn get_certificate_infos(&self) -> Vec<CertificateInfo> {
876 self.cert_info.read().clone()
877 }
878
879 pub fn check_expiring(&self, within: Duration) -> Vec<CertificateInfo> {
881 self.cert_info
882 .read()
883 .iter()
884 .filter(|info| info.expires_within(within))
885 .cloned()
886 .collect()
887 }
888}
889
890#[derive(Debug, Clone)]
892enum PrivateKeyData {
893 Pkcs8(Vec<u8>),
894 Pkcs1(Vec<u8>),
895 Sec1(Vec<u8>),
896}
897
898impl PrivateKeyData {
899 fn from_key(key: &PrivateKeyDer<'_>) -> Self {
901 match key {
902 PrivateKeyDer::Pkcs8(k) => Self::Pkcs8(k.secret_pkcs8_der().to_vec()),
903 PrivateKeyDer::Pkcs1(k) => Self::Pkcs1(k.secret_pkcs1_der().to_vec()),
904 PrivateKeyDer::Sec1(k) => Self::Sec1(k.secret_sec1_der().to_vec()),
905 _ => Self::Pkcs8(Vec::new()), }
907 }
908
909 fn to_key(&self) -> PrivateKeyDer<'static> {
911 match self {
912 Self::Pkcs8(data) => PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(data.clone())),
913 Self::Pkcs1(data) => PrivateKeyDer::Pkcs1(data.clone().into()),
914 Self::Sec1(data) => PrivateKeyDer::Sec1(data.clone().into()),
915 }
916 }
917}
918
919pub struct HotReloadableCertificates {
923 cert_chain: Arc<RwLock<Vec<CertificateDer<'static>>>>,
925 private_key_data: Arc<RwLock<Option<PrivateKeyData>>>,
927 update_tx: watch::Sender<u64>,
929 version: Arc<RwLock<u64>>,
931 cert_path: Arc<RwLock<Option<std::path::PathBuf>>>,
933 key_path: Arc<RwLock<Option<std::path::PathBuf>>>,
935}
936
937impl std::fmt::Debug for HotReloadableCertificates {
938 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
939 f.debug_struct("HotReloadableCertificates")
940 .field("version", &*self.version.read())
941 .field("cert_count", &self.cert_chain.read().len())
942 .field("has_key", &self.private_key_data.read().is_some())
943 .finish()
944 }
945}
946
947impl Default for HotReloadableCertificates {
948 fn default() -> Self {
949 Self::new()
950 }
951}
952
953impl HotReloadableCertificates {
954 pub fn new() -> Self {
956 let (update_tx, _) = watch::channel(0u64);
957 Self {
958 cert_chain: Arc::new(RwLock::new(Vec::new())),
959 private_key_data: Arc::new(RwLock::new(None)),
960 update_tx,
961 version: Arc::new(RwLock::new(0)),
962 cert_path: Arc::new(RwLock::new(None)),
963 key_path: Arc::new(RwLock::new(None)),
964 }
965 }
966
967 pub fn load_from_files<P: AsRef<Path>>(&self, cert_path: P, key_path: P) -> NetResult<()> {
969 let cert_path = cert_path.as_ref();
970 let key_path = key_path.as_ref();
971
972 let loader = CertificateLoader::new();
973 let key_loader = PrivateKeyLoader::new();
974
975 let certs = loader.load_pem_file(cert_path)?;
976 let key = key_loader.load_pem_file(key_path)?;
977
978 {
979 let mut chain = self.cert_chain.write();
980 *chain = certs;
981 }
982
983 {
984 let mut pk = self.private_key_data.write();
985 *pk = Some(PrivateKeyData::from_key(&key));
986 }
987
988 {
989 let mut cp = self.cert_path.write();
990 *cp = Some(cert_path.to_path_buf());
991 }
992
993 {
994 let mut kp = self.key_path.write();
995 *kp = Some(key_path.to_path_buf());
996 }
997
998 self.increment_version();
999
1000 info!(
1001 cert_path = %cert_path.display(),
1002 key_path = %key_path.display(),
1003 "Loaded certificates from files"
1004 );
1005
1006 Ok(())
1007 }
1008
1009 pub fn reload(&self) -> NetResult<()> {
1011 let cert_path = self.cert_path.read().clone();
1012 let key_path = self.key_path.read().clone();
1013
1014 match (cert_path, key_path) {
1015 (Some(cp), Some(kp)) => {
1016 self.load_from_files(&cp, &kp)?;
1017 info!("Reloaded certificates");
1018 Ok(())
1019 }
1020 _ => Err(NetError::InvalidCertificate(
1021 "No certificate paths configured for reload".to_string(),
1022 )),
1023 }
1024 }
1025
1026 pub fn set_certificates(
1028 &self,
1029 certs: Vec<CertificateDer<'static>>,
1030 key: PrivateKeyDer<'static>,
1031 ) {
1032 {
1033 let mut chain = self.cert_chain.write();
1034 *chain = certs;
1035 }
1036
1037 {
1038 let mut pk = self.private_key_data.write();
1039 *pk = Some(PrivateKeyData::from_key(&key));
1040 }
1041
1042 self.increment_version();
1043 }
1044
1045 pub fn get_cert_chain(&self) -> Vec<CertificateDer<'static>> {
1047 self.cert_chain.read().clone()
1048 }
1049
1050 pub fn get_private_key(&self) -> Option<PrivateKeyDer<'static>> {
1052 self.private_key_data.read().as_ref().map(|k| k.to_key())
1053 }
1054
1055 pub fn get_version(&self) -> u64 {
1057 *self.version.read()
1058 }
1059
1060 pub fn subscribe(&self) -> watch::Receiver<u64> {
1062 self.update_tx.subscribe()
1063 }
1064
1065 fn increment_version(&self) {
1067 let mut version = self.version.write();
1068 *version += 1;
1069 let _ = self.update_tx.send(*version);
1070 }
1071
1072 pub fn start_file_watcher(
1076 self: Arc<Self>,
1077 check_interval: Duration,
1078 ) -> NetResult<tokio::task::JoinHandle<()>> {
1079 let cert_path = self.cert_path.read().clone();
1080 let key_path = self.key_path.read().clone();
1081
1082 let (cert_path, key_path) = match (cert_path, key_path) {
1083 (Some(cp), Some(kp)) => (cp, kp),
1084 _ => {
1085 return Err(NetError::InvalidCertificate(
1086 "No certificate paths configured for file watching".to_string(),
1087 ));
1088 }
1089 };
1090
1091 let handle = tokio::spawn(async move {
1092 let mut last_cert_modified = get_file_modified(&cert_path);
1093 let mut last_key_modified = get_file_modified(&key_path);
1094
1095 loop {
1096 tokio::time::sleep(check_interval).await;
1097
1098 let cert_modified = get_file_modified(&cert_path);
1099 let key_modified = get_file_modified(&key_path);
1100
1101 let cert_changed = cert_modified != last_cert_modified;
1102 let key_changed = key_modified != last_key_modified;
1103
1104 if cert_changed || key_changed {
1105 info!(
1106 cert_changed = cert_changed,
1107 key_changed = key_changed,
1108 "Detected certificate file change, reloading"
1109 );
1110
1111 match self.reload() {
1112 Ok(()) => {
1113 last_cert_modified = cert_modified;
1114 last_key_modified = key_modified;
1115 }
1116 Err(e) => {
1117 error!(error = %e, "Failed to reload certificates");
1118 }
1119 }
1120 }
1121 }
1122 });
1123
1124 Ok(handle)
1125 }
1126}
1127
/// Returns the last-modification time of `path`.
///
/// Returns `None` if the metadata cannot be read (e.g. the file is missing) or the
/// platform does not report modification times.
fn get_file_modified<P: AsRef<Path>>(path: P) -> Option<SystemTime> {
    let metadata = match fs::metadata(path.as_ref()) {
        Ok(meta) => meta,
        Err(_) => return None,
    };
    metadata.modified().ok()
}
1134
#[cfg(test)]
mod tests {
    use super::*;
    use std::env::temp_dir;
    use std::path::PathBuf;

    /// Returns a temp-file path unique to this process, so concurrent test runs
    /// (several `cargo test` processes sharing one temp dir) cannot clobber each other.
    fn unique_temp_path(file_name: &str) -> PathBuf {
        temp_dir().join(format!("{}_{}", std::process::id(), file_name))
    }

    #[test]
    fn test_self_signed_generator() {
        let generator = SelfSignedGenerator::new("test.example.com")
            .with_san("localhost")
            .with_san("127.0.0.1")
            .with_organization("Test Org")
            .with_validity_days(30);

        let result = generator.generate();
        assert!(result.is_ok());

        let (cert, _key) = result.expect("Should generate certificate");
        assert!(!cert.as_ref().is_empty());

        let loader = CertificateLoader::new();
        let info = loader
            .get_certificate_info(&cert)
            .expect("Should parse certificate");

        assert_eq!(info.common_name.as_deref(), Some("test.example.com"));
        assert!(info.is_valid());
    }

    #[test]
    fn test_ca_certificate_generation() {
        let ca_generator = SelfSignedGenerator::new("Test CA")
            .as_ca()
            .with_validity_days(365);

        let (ca_cert, _ca_key) = ca_generator.generate().expect("Should generate CA");

        let loader = CertificateLoader::new();
        let ca_info = loader
            .get_certificate_info(&ca_cert)
            .expect("Should parse CA certificate");

        assert!(ca_info.is_ca);
        assert_eq!(ca_info.common_name.as_deref(), Some("Test CA"));
    }

    #[test]
    fn test_certificate_store() {
        let mut store = CertificateStore::new();

        let generator = SelfSignedGenerator::new("test").as_ca();
        let (cert, _) = generator.generate().expect("Should generate certificate");

        assert!(store.is_empty());
        store.add_certificate(cert).expect("Should add certificate");
        assert!(!store.is_empty());
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn test_certificate_store_system_roots() {
        let mut store = CertificateStore::new();
        let added = store.add_system_roots().expect("Should add system roots");

        assert!(added > 0);
        assert!(!store.is_empty());
    }

    #[test]
    fn test_certificate_info_validity() {
        let generator = SelfSignedGenerator::new("test").with_validity_days(30);

        let (cert, _) = generator.generate().expect("Should generate certificate");

        let loader = CertificateLoader::new();
        let info = loader.get_certificate_info(&cert).expect("Should get info");

        assert!(info.is_valid());
        assert!(!info.expires_within(Duration::from_secs(0)));

        assert!(info.expires_within(Duration::from_secs(31 * 24 * 60 * 60)));
    }

    #[test]
    fn test_hot_reloadable_certificates() {
        let hot_certs = HotReloadableCertificates::new();

        let generator = SelfSignedGenerator::new("test");
        let (cert, key) = generator.generate().expect("Should generate certificate");

        assert_eq!(hot_certs.get_version(), 0);

        hot_certs.set_certificates(vec![cert], key);

        assert_eq!(hot_certs.get_version(), 1);
        assert!(!hot_certs.get_cert_chain().is_empty());
        assert!(hot_certs.get_private_key().is_some());
    }

    #[test]
    fn test_pem_certificate_loading() {
        let generator = SelfSignedGenerator::new("test");
        let (cert, _) = generator.generate().expect("Should generate certificate");

        let pem_content = format!(
            "-----BEGIN CERTIFICATE-----\n{}\n-----END CERTIFICATE-----\n",
            base64_encode(cert.as_ref())
        );

        let temp_path = unique_temp_path("test_cert.pem");
        fs::write(&temp_path, &pem_content).expect("Should write temp file");

        let loader = CertificateLoader::new();
        let result = loader.load_pem_file(&temp_path);

        // Clean up before asserting, so a failure doesn't leave the file behind.
        let _ = fs::remove_file(&temp_path);

        assert!(result.is_ok());
    }

    #[test]
    fn test_pem_bytes_without_certificates_is_rejected() {
        let loader = CertificateLoader::new();
        assert!(loader.load_pem_bytes(b"").is_err());
        assert!(loader.load_pem_bytes(b"not a pem document").is_err());
    }

    #[test]
    fn test_der_certificate_loading() {
        let generator = SelfSignedGenerator::new("test");
        let (cert, _) = generator.generate().expect("Should generate certificate");

        let temp_path = unique_temp_path("test_cert.der");
        fs::write(&temp_path, cert.as_ref()).expect("Should write temp file");

        let loader = CertificateLoader::new();
        let result = loader.load_der_file(&temp_path);

        // Clean up before asserting, so a failure doesn't leave the file behind.
        let _ = fs::remove_file(&temp_path);

        assert!(result.is_ok());
    }

    #[test]
    fn test_private_key_pem_roundtrip() {
        let (_, key) = SelfSignedGenerator::new("test")
            .generate()
            .expect("Should generate certificate");
        let PrivateKeyDer::Pkcs8(pkcs8) = &key else {
            panic!("generator should emit PKCS#8 keys");
        };

        let pem = format!(
            "-----BEGIN PRIVATE KEY-----\n{}\n-----END PRIVATE KEY-----\n",
            base64_encode(pkcs8.secret_pkcs8_der())
        );

        let loaded = PrivateKeyLoader::new()
            .load_pem_bytes(pem.as_bytes())
            .expect("Should parse PKCS#8 key");

        assert!(matches!(loaded, PrivateKeyDer::Pkcs8(_)));
        assert_eq!(loaded.secret_der(), key.secret_der());
    }

    /// Minimal standard base64 encoder (with padding, 64-char lines) for building PEM fixtures.
    fn base64_encode(data: &[u8]) -> String {
        const ALPHABET: &[u8; 64] =
            b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

        let mut result = String::new();
        let mut i = 0;

        while i < data.len() {
            let b1 = data[i];
            let b2 = data.get(i + 1).copied().unwrap_or(0);
            let b3 = data.get(i + 2).copied().unwrap_or(0);

            result.push(ALPHABET[(b1 >> 2) as usize] as char);
            result.push(ALPHABET[(((b1 & 0x03) << 4) | (b2 >> 4)) as usize] as char);

            if i + 1 < data.len() {
                result.push(ALPHABET[(((b2 & 0x0f) << 2) | (b3 >> 6)) as usize] as char);
            } else {
                result.push('=');
            }

            if i + 2 < data.len() {
                result.push(ALPHABET[(b3 & 0x3f) as usize] as char);
            } else {
                result.push('=');
            }

            i += 3;
        }

        let mut formatted = String::new();
        for (idx, ch) in result.chars().enumerate() {
            if idx > 0 && idx % 64 == 0 {
                formatted.push('\n');
            }
            formatted.push(ch);
        }

        formatted
    }
}