1use std::num::NonZeroUsize;
2use std::sync::{Arc, Mutex};
3
4use lru::LruCache;
5use rustls::ServerConfig;
6use rustls::crypto::CryptoProvider;
7use rustls::server::{ClientHello, ResolvesServerCert};
8use rustls::sign::CertifiedKey;
9
10use crate::CertificateAuthority;
11
/// Number of per-hostname certificates `MitmCertResolver::new` keeps cached before the
/// least recently used entry is evicted.
const DEFAULT_CACHE_CAPACITY: usize = 1_024;
13
/// rustls server-certificate resolver that issues a certificate for each SNI hostname
/// using `ca`, and keeps recently issued certificates in a bounded LRU cache.
pub struct MitmCertResolver {
    /// Authority asked to generate a certificate and private key for each new hostname.
    ca: CertificateAuthority,
    /// Issued certificates keyed by SNI hostname. Wrapped in a `Mutex` because
    /// `resolve` only receives `&self`, while `LruCache::get`/`put` need `&mut`.
    cache: Mutex<LruCache<String, Arc<CertifiedKey>>>,
}
24
25impl std::fmt::Debug for MitmCertResolver {
26 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
27 f.debug_struct("MitmCertResolver").finish_non_exhaustive()
28 }
29}
30
31impl MitmCertResolver {
32 pub fn new(ca: CertificateAuthority) -> Self {
34 Self::with_cache_capacity(ca, DEFAULT_CACHE_CAPACITY)
35 }
36
37 pub fn with_cache_capacity(ca: CertificateAuthority, capacity: usize) -> Self {
39 Self {
40 ca,
41 cache: Mutex::new(LruCache::new(NonZeroUsize::new(capacity).unwrap())),
42 }
43 }
44
45 pub fn into_server_config(self) -> ServerConfig {
51 ServerConfig::builder()
52 .with_no_client_auth()
53 .with_cert_resolver(Arc::new(self))
54 }
55}
56
57impl ResolvesServerCert for MitmCertResolver {
58 fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
59 let hostname = client_hello.server_name()?;
60
61 {
62 let mut cache = self.cache.lock().unwrap();
63 if let Some(key) = cache.get(hostname) {
64 return Some(key.clone());
65 }
66 }
67
68 let (cert_der, key_der) = self.ca.generate_cert(hostname).ok()?;
69 let provider = CryptoProvider::get_default()?;
70 let signing_key = provider.key_provider.load_private_key(key_der).ok()?;
71 let certified = Arc::new(CertifiedKey::new(vec![cert_der], signing_key));
72
73 let mut cache = self.cache.lock().unwrap();
74 if let Some(existing) = cache.get(hostname) {
75 return Some(existing.clone());
76 }
77 cache.put(hostname.to_string(), certified.clone());
78 Some(certified)
79 }
80}