use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use rcgen::{Certificate, CertificateParams, DnType, KeyPair};
use rustls::pki_types::PrivateKeyDer;
use rustls::ServerConfig;
/// Mints TLS server configurations whose leaf certificates are signed by a single CA,
/// caching one [`ServerConfig`] per server name.
// NOTE(review): no code shown here ever removes cache entries. If SNI values are
// client-controlled, the cache grows without bound — confirm whether a cap or eviction is needed.
pub(crate) struct CertSigner {
    /// Issuer certificate for minted leaves, rebuilt from the CA PEM in `CertSigner::new`.
    ca_cert: Certificate,
    /// The CA's private key; it signs every leaf certificate.
    ca_key: KeyPair,
    /// Minted configs keyed by server name (SNI), shared via `Arc` so cache hits are cheap.
    cache: Mutex<HashMap<String, Arc<ServerConfig>>>,
}
impl CertSigner {
    /// Builds a signer from a PEM-encoded CA certificate and its PEM-encoded private key.
    ///
    /// The CA certificate's parameters are parsed out of `ca_cert_pem` and re-signed with the
    /// CA key. This gives rcgen an issuer [`Certificate`] to sign leaf certificates with.
    ///
    /// # Errors
    ///
    /// Returns an error of kind [`std::io::ErrorKind::InvalidData`] if the key PEM or the
    /// certificate PEM cannot be parsed, or if rebuilding the CA certificate fails.
    pub(crate) fn new(ca_cert_pem: &str, ca_key_pem: &str) -> std::io::Result<Self> {
        let invalid = std::io::ErrorKind::InvalidData;
        let ca_key = KeyPair::from_pem(ca_key_pem)
            .map_err(|e| Self::io_error(invalid, "CA key", e))?;
        let ca_params = CertificateParams::from_ca_cert_pem(ca_cert_pem)
            .map_err(|e| Self::io_error(invalid, "CA cert", e))?;
        let ca_cert = ca_params
            .self_signed(&ca_key)
            .map_err(|e| Self::io_error(invalid, "CA rebuild", e))?;
        Ok(Self { ca_cert, ca_key, cache: Mutex::new(HashMap::new()) })
    }

    /// Returns a TLS server config that presents a CA-signed leaf certificate for `sni`.
    ///
    /// Configs are cached per host name. Names are compared case-insensitively for ASCII, so
    /// repeated calls for the same host return the same [`Arc`]. On a cache miss a new key
    /// pair and certificate are generated without holding the cache lock. If two threads race
    /// on the same host, the first config stored wins and both callers receive it.
    ///
    /// # Errors
    ///
    /// Returns an error if key generation, certificate signing, or building the rustls config
    /// fails. If rcgen rejects `sni` while building the certificate parameters, the error
    /// kind is [`std::io::ErrorKind::InvalidInput`].
    pub(crate) fn server_config_for(&self, sni: &str) -> std::io::Result<Arc<ServerConfig>> {
        // DNS names compare case-insensitively, so normalise the key (and the name we mint)
        // to avoid signing a separate certificate for every casing of the same host.
        let host = sni.to_ascii_lowercase();

        // The guard is a temporary of this statement, so the lock is released before minting.
        let cached = self.lock_cache().get(host.as_str()).cloned();
        if let Some(cfg) = cached {
            return Ok(cfg);
        }

        // Mint outside the lock: key generation and signing are slow and must not block
        // lookups for hosts that are already cached.
        let minted = Arc::new(self.mint(&host)?);

        // Another thread may have stored this host while we were unlocked. Keep the entry that
        // landed first instead of overwriting it, so every caller shares a single config.
        let shared = Arc::clone(self.lock_cache().entry(host).or_insert(minted));
        Ok(shared)
    }

    /// Generates a fresh key pair and a CA-signed leaf certificate for `host`. Wraps them in a
    /// rustls [`ServerConfig`] that uses no client auth and advertises only `http/1.1` ALPN.
    fn mint(&self, host: &str) -> std::io::Result<ServerConfig> {
        let leaf_key = KeyPair::generate()
            .map_err(|e| Self::io_error(std::io::ErrorKind::Other, "leaf keygen", e))?;
        let mut params = CertificateParams::new(vec![host.to_string()])
            .map_err(|e| Self::io_error(std::io::ErrorKind::InvalidInput, "leaf params", e))?;
        params.distinguished_name.push(DnType::CommonName, host);
        let leaf = params
            .signed_by(&leaf_key, &self.ca_cert, &self.ca_key)
            .map_err(|e| Self::io_error(std::io::ErrorKind::Other, "leaf sign", e))?;

        let chain = vec![leaf.der().clone()];
        let key_der = PrivateKeyDer::Pkcs8(leaf_key.serialize_der().into());
        let mut cfg = ServerConfig::builder()
            .with_no_client_auth()
            .with_single_cert(chain, key_der)
            .map_err(|e| Self::io_error(std::io::ErrorKind::Other, "server cfg", e))?;
        cfg.alpn_protocols = vec![b"http/1.1".to_vec()];
        Ok(cfg)
    }

    /// Locks the config cache, recovering the guard if the mutex was poisoned.
    ///
    /// In this impl the map is only read or given complete entries (ready-made `Arc`s). A
    /// panic elsewhere while the lock is held therefore cannot leave it half-updated.
    /// Recovering stops one panic from making every later handshake panic too.
    fn lock_cache(&self) -> std::sync::MutexGuard<'_, HashMap<String, Arc<ServerConfig>>> {
        self.cache.lock().unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Builds an `io::Error` of `kind` whose message is `"{context}: {err}"`.
    fn io_error(
        kind: std::io::ErrorKind,
        context: &str,
        err: impl std::fmt::Display,
    ) -> std::io::Error {
        std::io::Error::new(kind, format!("{context}: {err}"))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Loads the test CA via `resolve_ca` and builds a `CertSigner` from its PEM material.
    fn build_signer() -> CertSigner {
        let material = crate::transparent_proxy::resolve_ca(None, None, true)
            .unwrap()
            .unwrap();
        CertSigner::new(&material.cert_pem, &material.key_pem).unwrap()
    }

    /// Different host names must each get their own config.
    #[test]
    fn mints_distinct_configs_per_sni() {
        let signer = build_signer();
        let openai = signer.server_config_for("api.openai.com").unwrap();
        let example = signer.server_config_for("example.com").unwrap();
        assert!(!Arc::ptr_eq(&openai, &example));
    }

    /// Asking twice for the same host name must return the cached `Arc`.
    #[test]
    fn caches_by_sni() {
        let signer = build_signer();
        let first = signer.server_config_for("example.com").unwrap();
        let second = signer.server_config_for("example.com").unwrap();
        assert!(Arc::ptr_eq(&first, &second));
    }
}