use anyhow::{Context, Result};
use rcgen::{
BasicConstraints, CertificateParams, DistinguishedName, DnType, IsCa, KeyPair,
KeyUsagePurpose, SanType,
};
use rustls::pki_types::CertificateDer;
use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tokio::sync::RwLock;
use crate::audit::{AuditEventType, AuditLogger};
/// Issues TLS leaf certificates signed by a locally stored root CA and caches them.
///
/// Leaf certificate/key pairs are cached both in memory and as files under `cache_dir`.
pub struct CertManager {
    /// Directory holding the per-domain `.crt` / `.key` files.
    cache_dir: PathBuf,
    /// PEM-encoded root CA certificate.
    ca_cert_pem: String,
    /// PEM-encoded root CA private key.
    ca_key_pem: String,
    /// In-memory cache: domain -> (certificate PEM, private key PEM).
    mem_cache: Arc<RwLock<HashMap<String, (String, String)>>>,
    /// Optional sink for audit events (e.g. certificate generation).
    audit_logger: Option<Arc<AuditLogger>>,
}
impl CertManager {
pub fn new(data_dir: &Path) -> Result<Self> {
Self::new_with_logger(data_dir, None)
}
pub fn new_with_logger(data_dir: &Path, audit_logger: Option<Arc<AuditLogger>>) -> Result<Self> {
let ca_dir = data_dir.join("certs").join("ca");
let cache_dir = data_dir.join("certs").join("cache");
fs::create_dir_all(&ca_dir)?;
fs::create_dir_all(&cache_dir)?;
let cert_path = ca_dir.join("cert.pem");
let key_path = ca_dir.join("key.pem");
let (ca_cert_pem, ca_key_pem) = if cert_path.exists() && key_path.exists() {
let ca_cert_contents = fs::read_to_string(&cert_path)
.with_context(|| "Failed to read CA cert")?;
let ca_key_contents = fs::read_to_string(&key_path)
.with_context(|| "Failed to read CA key")?;
(ca_cert_contents, ca_key_contents)
} else {
let (ca_cert_pem, ca_key_pem) = generate_ca_cert()?;
fs::write(&cert_path, &ca_cert_pem)?;
fs::write(&key_path, &ca_key_pem)?;
let export_path = data_dir.join("certs/ca.crt");
fs::write(&export_path, &ca_cert_pem)?;
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
fs::set_permissions(&key_path, fs::Permissions::from_mode(0o600))?;
}
tracing::info!("Generated new CA certificate at {}", cert_path.display());
if let Some(logger) = &audit_logger {
let _ = logger.log(
AuditEventType::CertGenerated,
format!("Generated CA certificate at {}", cert_path.display()),
true,
);
}
(ca_cert_pem, ca_key_pem)
};
let _ca_key_pair = KeyPair::from_pem(&ca_key_pem)
.with_context(|| "Failed to parse CA key pair")?;
Ok(Self {
ca_cert_pem,
ca_key_pem,
cache_dir,
mem_cache: Arc::new(RwLock::new(HashMap::new())),
audit_logger,
})
}
#[allow(dead_code)]
pub fn ca_cert_pem(&self) -> &str {
&self.ca_cert_pem
}
pub async fn get_or_create_cert(&self, domain: &str) -> Result<(String, String)> {
{
let cache = self.mem_cache.read().await;
if let Some(cached) = cache.get(domain) {
return Ok(cached.clone());
}
}
let cert_file = self.cache_dir.join(format!("{}.crt", domain));
let key_file = self.cache_dir.join(format!("{}.key", domain));
if cert_file.exists() && key_file.exists() {
let cert_pem = fs::read_to_string(&cert_file)?;
let key_pem = fs::read_to_string(&key_file)?;
let mut cache = self.mem_cache.write().await;
cache.insert(domain.to_string(), (cert_pem.clone(), key_pem.clone()));
return Ok((cert_pem, key_pem));
}
let (cert_pem, key_pem) = self.sign_server_cert(domain)?;
fs::write(&cert_file, &cert_pem)?;
fs::write(&key_file, &key_pem)?;
{
let mut cache = self.mem_cache.write().await;
cache.insert(domain.to_string(), (cert_pem.clone(), key_pem.clone()));
}
tracing::debug!("Generated server certificate for {}", domain);
if let Some(logger) = &self.audit_logger {
let _ = logger.log(
AuditEventType::CertGenerated,
format!("Generated server certificate for domain: {}", domain),
true,
);
}
Ok((cert_pem, key_pem))
}
fn sign_server_cert(&self, domain: &str) -> Result<(String, String)> {
let ca_key_pair = KeyPair::from_pem(&self.ca_key_pem)
.with_context(|| "Failed to parse CA key pair")?;
let ca_cert_der = pem_to_der(&self.ca_cert_pem)?;
let ca_params = CertificateParams::from_ca_cert_der(&ca_cert_der)
.with_context(|| "Failed to parse CA cert params")?;
let ca_cert = ca_params
.self_signed(&ca_key_pair)
.with_context(|| "Failed to reconstruct CA certificate")?;
let mut params = CertificateParams::new(vec![domain.to_string()])
.with_context(|| format!("Failed to create cert params for {}", domain))?;
let mut dn = DistinguishedName::new();
dn.push(DnType::CommonName, domain);
dn.push(DnType::OrganizationName, "FakeKey Proxy");
params.distinguished_name = dn;
params.subject_alt_names = vec![SanType::DnsName(domain.try_into()?)];
let server_key_pair = KeyPair::generate()
.with_context(|| "Failed to generate server key pair")?;
let server_cert = params
.signed_by(&server_key_pair, &ca_cert, &ca_key_pair)
.with_context(|| format!("Failed to sign server cert for {}", domain))?;
let cert_pem = server_cert.pem();
let key_pem = server_key_pair.serialize_pem();
Ok((cert_pem, key_pem))
}
pub async fn make_server_config(&self, domain: &str) -> Result<Arc<rustls::ServerConfig>> {
let (cert_pem, key_pem) = self.get_or_create_cert(domain).await?;
let certs = rustls_pemfile::certs(&mut cert_pem.as_bytes())
.collect::<Result<Vec<_>, _>>()
.with_context(|| "Failed to parse server cert")?;
let key = rustls_pemfile::private_key(&mut key_pem.as_bytes())
.with_context(|| "Failed to parse server key")?
.ok_or_else(|| anyhow::anyhow!("No private key found"))?;
let config = rustls::ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(certs, key)
.with_context(|| "Failed to build server TLS config")?;
Ok(Arc::new(config))
}
}
/// Generates a new self-signed root CA with subject `CN=FakeKey Root CA, O=FakeKey`.
///
/// The CA has unconstrained basic constraints, the `keyCertSign` and `cRLSign` key usages,
/// and a validity window from 2024-01-01 to 2034-01-01.
///
/// Returns `(certificate_pem, private_key_pem)`.
fn generate_ca_cert() -> Result<(String, String)> {
    let mut ca_params = CertificateParams::new(Vec::<String>::new())
        .with_context(|| "Failed to create CA cert params")?;

    let mut subject = DistinguishedName::new();
    subject.push(DnType::CommonName, "FakeKey Root CA");
    subject.push(DnType::OrganizationName, "FakeKey");

    ca_params.distinguished_name = subject;
    ca_params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
    ca_params.key_usages = Vec::from([KeyUsagePurpose::KeyCertSign, KeyUsagePurpose::CrlSign]);
    ca_params.not_before = rcgen::date_time_ymd(2024, 1, 1);
    ca_params.not_after = rcgen::date_time_ymd(2034, 1, 1);

    let ca_key = KeyPair::generate().with_context(|| "Failed to generate CA key pair")?;
    let ca_cert = ca_params
        .self_signed(&ca_key)
        .with_context(|| "Failed to self-sign CA certificate")?;

    let cert_pem = ca_cert.pem();
    let key_pem = ca_key.serialize_pem();
    Ok((cert_pem, key_pem))
}
/// Decodes the first PEM `CERTIFICATE` section of `pem_str` into DER.
///
/// # Errors
/// Fails if any certificate section is malformed, or if the input contains no certificate.
fn pem_to_der(pem_str: &str) -> Result<CertificateDer<'static>> {
    let mut reader = pem_str.as_bytes();
    let parsed = rustls_pemfile::certs(&mut reader)
        .collect::<Result<Vec<_>, _>>()
        .with_context(|| "Failed to parse PEM")?;
    match parsed.into_iter().next() {
        Some(first) => Ok(first),
        None => anyhow::bail!("No certificate found in PEM"),
    }
}