use std::num::NonZeroUsize;
use std::sync::{Arc, Mutex};
use lru::LruCache;
use rustls::ServerConfig;
use rustls::crypto::CryptoProvider;
use rustls::server::{ClientHello, ResolvesServerCert};
use rustls::sign::CertifiedKey;
use crate::CertificateAuthority;
/// Number of per-hostname certified keys cached by a resolver built with
/// `MitmCertResolver::new`.
const DEFAULT_CACHE_CAPACITY: usize = 1_024;
/// rustls server-certificate resolver that mints a certificate per SNI
/// hostname (as the name suggests, intended for TLS interception).
///
/// Certificates are produced by the wrapped `CertificateAuthority` and kept in
/// a bounded, least-recently-used cache keyed by hostname.
pub struct MitmCertResolver {
    /// Authority asked to generate a certificate/key pair for each new hostname.
    ca: CertificateAuthority,
    /// Hostname -> certified key. Wrapped in a `Mutex` because resolution only
    /// has `&self` but needs to mutate the LRU ordering.
    cache: Mutex<LruCache<String, Arc<CertifiedKey>>>,
}
impl std::fmt::Debug for MitmCertResolver {
    /// Renders as `MitmCertResolver { .. }` without exposing any field,
    /// so neither the CA nor the cached keys appear in debug output.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut repr = formatter.debug_struct("MitmCertResolver");
        repr.finish_non_exhaustive()
    }
}
impl MitmCertResolver {
    /// Creates a resolver backed by `ca` whose cache holds up to
    /// `DEFAULT_CACHE_CAPACITY` hostnames.
    #[must_use]
    pub fn new(ca: CertificateAuthority) -> Self {
        Self::with_cache_capacity(ca, DEFAULT_CACHE_CAPACITY)
    }

    /// Creates a resolver backed by `ca` whose LRU cache holds at most
    /// `capacity` certified keys.
    ///
    /// A `capacity` of `0` is treated as `1` instead of panicking, so the most
    /// recently minted certificate is always cached.
    #[must_use]
    pub fn with_cache_capacity(ca: CertificateAuthority, capacity: usize) -> Self {
        // `LruCache` requires a non-zero capacity; clamp caller input rather
        // than panicking on it.
        let capacity = NonZeroUsize::new(capacity).unwrap_or(NonZeroUsize::MIN);
        Self {
            ca,
            cache: Mutex::new(LruCache::new(capacity)),
        }
    }

    /// Consumes the resolver and builds a rustls [`ServerConfig`] that requests
    /// no client authentication and obtains every server certificate from this
    /// resolver.
    ///
    /// # Panics
    ///
    /// `ServerConfig::builder()` uses the process-default rustls
    /// `CryptoProvider`; per the rustls documentation it panics when no default
    /// is installed and none can be selected from enabled crate features.
    #[must_use]
    pub fn into_server_config(self) -> ServerConfig {
        ServerConfig::builder()
            .with_no_client_auth()
            .with_cert_resolver(Arc::new(self))
    }
}
impl ResolvesServerCert for MitmCertResolver {
    /// Returns the certified key for the SNI hostname in `client_hello`,
    /// minting and caching one the first time a hostname is seen.
    ///
    /// Returns `None` when the client sent no server name, when no
    /// process-default `CryptoProvider` is installed, when the CA fails to
    /// generate a certificate, or when the generated private key cannot be
    /// loaded as a signing key.
    fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
        let hostname = client_hello.server_name()?;

        // Fast path. The guard is a temporary dropped at the end of this
        // statement, so certificate generation below runs without the lock.
        // The locked sections only perform LRU get/put, so a poisoned lock is
        // recovered instead of panicking inside the handshake callback.
        let cached = self
            .cache
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .get(hostname)
            .cloned();
        if cached.is_some() {
            return cached;
        }

        // Look up the provider before minting: if it is missing, the key
        // could never be loaded and generating a certificate would be wasted.
        let provider = CryptoProvider::get_default()?;
        let (cert_der, key_der) = self.ca.generate_cert(hostname).ok()?;
        let signing_key = provider.key_provider.load_private_key(key_der).ok()?;
        let minted = Arc::new(CertifiedKey::new(vec![cert_der], signing_key));

        // Another handshake may have cached a key for this hostname while the
        // lock was released; prefer that entry so concurrent handshakes
        // converge on a single cached certificate.
        let mut cache = self
            .cache
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        if let Some(existing) = cache.get(hostname) {
            return Some(Arc::clone(existing));
        }
        cache.put(hostname.to_owned(), Arc::clone(&minted));
        Some(minted)
    }
}