1use crate::error::Error;
2use moka::sync::Cache;
3use rand::{thread_rng, Rng};
4use rcgen::{
5 BasicConstraints, Certificate, CertificateParams, DistinguishedName, DnType,
6 ExtendedKeyUsagePurpose, IsCa, KeyPair, KeyUsagePurpose, RcgenError, SanType,
7};
8use rustls::{
9 server::{ClientHello, ResolvesServerCert},
10 sign::CertifiedKey,
11};
12use std::sync::Arc;
13use time::{ext::NumericalDuration, OffsetDateTime};
14use tokio_rustls::rustls::{self, ServerConfig};
15
/// Validity period, in days, of every leaf certificate minted by the CA.
const CERT_TTL_DAYS: u64 = 365;
/// How long, in seconds, a minted leaf certificate stays cached: half of its
/// validity period (86_400 seconds per day), so a cached entry is evicted well
/// before the certificate it holds expires.
const CERT_CACHE_TTL_SECONDS: u64 = CERT_TTL_DAYS * 86_400 / 2;
18
19#[derive(Clone)]
25pub struct CertificateAuthority {
26 private_key: rustls::PrivateKey,
27 ca_cert: rustls::Certificate,
28 ca_cert_string: String,
29 cache: Cache<String, Arc<CertifiedKey>>,
30}
31
32impl CertificateAuthority {
33 pub fn gen_ca() -> Result<Certificate, RcgenError> {
34 let mut params = CertificateParams::default();
35 let mut distinguished_name = DistinguishedName::new();
36 distinguished_name.push(DnType::CommonName, "Good-MITM");
37 distinguished_name.push(DnType::OrganizationName, "Good-MITM");
38 distinguished_name.push(DnType::CountryName, "CN");
39 distinguished_name.push(DnType::LocalityName, "CN");
40 params.distinguished_name = distinguished_name;
41 params.key_usages = vec![
42 KeyUsagePurpose::DigitalSignature,
43 KeyUsagePurpose::KeyCertSign,
44 KeyUsagePurpose::CrlSign,
45 ];
46 params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
47 Certificate::from_params(params)
48 }
49
50 pub fn new(
55 private_key: rustls::PrivateKey,
56 ca_cert: rustls::Certificate,
57 ca_cert_string: String,
58 cache_size: u64,
59 ) -> Result<CertificateAuthority, Error> {
60 let ca = CertificateAuthority {
61 private_key,
62 ca_cert,
63 ca_cert_string,
64 cache: Cache::builder()
65 .max_capacity(cache_size)
66 .time_to_live(std::time::Duration::from_secs(CERT_CACHE_TTL_SECONDS))
67 .build(),
68 };
69
70 ca.validate()?;
71 Ok(ca)
72 }
73
74 pub(crate) fn get_certified_key(&self, server_name: &str) -> Arc<CertifiedKey> {
75 if let Some(server_cfg) = self.cache.get(server_name) {
76 return server_cfg;
77 }
78
79 let certs = vec![self.gen_cert(server_name)];
80 let key = rustls::sign::any_supported_type(&self.private_key)
81 .expect("parse any supported private key");
82 let certified_key = Arc::new(CertifiedKey::new(certs, key));
83
84 self.cache
85 .insert(server_name.to_string(), certified_key.clone());
86
87 certified_key
88 }
89
90 fn gen_cert(&self, server_name: &str) -> rustls::Certificate {
91 let mut params = rcgen::CertificateParams::default();
92
93 params.serial_number = Some(thread_rng().gen::<u64>());
94 params.not_before = OffsetDateTime::now_utc().saturating_sub(1.days());
95 params.not_after = OffsetDateTime::now_utc().saturating_add((CERT_TTL_DAYS as i64).days());
96 params
97 .subject_alt_names
98 .push(SanType::DnsName(server_name.to_string()));
99 let mut distinguished_name = DistinguishedName::new();
100 distinguished_name.push(DnType::CommonName, server_name);
101 params.distinguished_name = distinguished_name;
102
103 params.key_usages = vec![KeyUsagePurpose::DigitalSignature];
104 params.extended_key_usages = vec![ExtendedKeyUsagePurpose::ServerAuth];
105
106 let key_pair = KeyPair::from_der(&self.private_key.0).expect("Failed to parse private key");
107 params.alg = key_pair
108 .compatible_algs()
109 .next()
110 .expect("Failed to find compatible algorithm");
111 params.key_pair = Some(key_pair);
112
113 let key_pair = KeyPair::from_der(&self.private_key.0).expect("Failed to parse private key");
114
115 let ca_cert_params = rcgen::CertificateParams::from_ca_cert_der(&self.ca_cert.0, key_pair)
116 .expect("Failed to parse CA certificate");
117 let ca_cert = rcgen::Certificate::from_params(ca_cert_params)
118 .expect("Failed to generate CA certificate");
119
120 let cert = rcgen::Certificate::from_params(params).expect("Failed to generate certificate");
121
122 rustls::Certificate(
123 cert.serialize_der_with_signer(&ca_cert)
124 .expect("Failed to serialize certificate"),
125 )
126 }
127
128 fn validate(&self) -> Result<(), RcgenError> {
129 let key_pair = rcgen::KeyPair::from_der(&self.private_key.0)?;
130 rcgen::CertificateParams::from_ca_cert_der(&self.ca_cert.0, key_pair)?;
131 Ok(())
132 }
133
134 pub fn get_cert(&self) -> String {
135 self.ca_cert_string.clone()
136 }
137
138 pub fn gen_server_config(self: Arc<Self>) -> Arc<ServerConfig> {
139 let server_cfg = ServerConfig::builder()
140 .with_safe_defaults()
141 .with_no_client_auth()
142 .with_cert_resolver(self);
143 Arc::new(server_cfg)
144 }
145}
146
147impl ResolvesServerCert for CertificateAuthority {
148 fn resolve(&self, client_hello: ClientHello) -> Option<Arc<CertifiedKey>> {
149 client_hello
150 .server_name()
151 .map(|name| self.get_certified_key(name))
152 }
153}