use crate::error::{Error, Result};
use moka::future::Cache;
use rand::RngExt;
use rcgen::{
    BasicConstraints, CertificateParams, DistinguishedName, DnType, ExtendedKeyUsagePurpose,
    IsCa, Issuer, KeyPair, KeyUsagePurpose, SanType,
};
use rustls_pki_types::pem::{PemObject, SectionKind};
use std::net::IpAddr;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use time::{Duration, OffsetDateTime};
use tokio::fs;
use tokio::io::AsyncWriteExt;
use tokio_rustls::rustls::pki_types::{CertificateDer, PrivateKeyDer};
/// Validity period of generated leaf certificates, in seconds (365 days).
const TTL_SECS: i64 = 60 * 60 * 24 * 365;
/// Time-to-live of a cached leaf certificate, in seconds.
///
/// Half of [`TTL_SECS`], so a certificate handed out from the cache still has
/// roughly half of its validity period left.
const CACHE_TTL: u64 = TTL_SECS as u64 / 2;
/// Seconds by which a certificate's `not_before` is moved into the past
/// (presumably to tolerate small clock differences between peers).
const NOT_BEFORE_OFFSET: i64 = 60;
/// A certificate authority persisted on disk that signs per-host leaf certificates.
pub struct CertificateAuthority {
    /// Signing identity (CA certificate + CA key pair) used for leaf certificates.
    issuer: Issuer<'static, KeyPair>,
    /// DER encoding of the CA certificate; appended to every issued chain.
    ca_cert_der: CertificateDer<'static>,
    /// DER encoding of the CA private key.
    // Never read inside this module; it is kept alongside the certificate so the
    // key material stays available. NOTE(review): drop the field if no caller needs it.
    #[allow(dead_code)]
    ca_key_der: PrivateKeyDer<'static>,
    /// Directory that holds `ca_cert.pem` and `ca_key.pem`.
    storage_path: PathBuf,
}
impl CertificateAuthority {
pub async fn new(storage_path: impl AsRef<Path>) -> Result<Self> {
let storage_path = storage_path.as_ref().to_path_buf();
if !storage_path.exists() {
fs::create_dir_all(&storage_path).await?;
}
let ca_cert_path = storage_path.join("ca_cert.pem");
let ca_key_path = storage_path.join("ca_key.pem");
let (issuer, ca_cert_der, ca_key_der) = if ca_cert_path.exists() && ca_key_path.exists() {
Self::load_ca(&ca_cert_path, &ca_key_path).await?
} else {
Self::generate_ca(&ca_cert_path, &ca_key_path).await?
};
Ok(Self {
issuer,
ca_cert_der,
ca_key_der,
storage_path,
})
}
async fn load_ca(
cert_path: &Path,
key_path: &Path,
) -> Result<(
Issuer<'static, KeyPair>,
CertificateDer<'static>,
PrivateKeyDer<'static>,
)> {
let cert_pem = fs::read_to_string(cert_path).await?;
let key_pem = fs::read_to_string(key_path).await?;
let key_pair = KeyPair::from_pem(&key_pem)
.map_err(|e| Error::certificate_error(format!("Failed to parse CA key: {}", e)))?;
let issuer = Issuer::from_ca_cert_pem(&cert_pem, key_pair).map_err(|e| {
Error::certificate_error(format!("Failed to create issuer from CA cert: {}", e))
})?;
let mut found: Option<Vec<u8>> = None;
for item in <(SectionKind, Vec<u8>) as PemObject>::pem_slice_iter(cert_pem.as_bytes()) {
match item {
Ok((kind, contents)) => {
if kind == SectionKind::Certificate {
found = Some(contents);
break;
}
}
Err(e) => {
return Err(Error::certificate_error(format!(
"Failed to parse PEM: {}",
e
)));
}
}
}
let cert_der_vec =
found.ok_or_else(|| Error::certificate_error("No certificate found in PEM"))?;
let cert_der = CertificateDer::from(cert_der_vec);
let key_der = PrivateKeyDer::try_from(issuer.key().serialize_der())
.map_err(|_| Error::certificate_error("Failed to serialize CA key"))?;
Ok((issuer, cert_der, key_der))
}
async fn generate_ca(
cert_path: &Path,
key_path: &Path,
) -> Result<(
Issuer<'static, KeyPair>,
CertificateDer<'static>,
PrivateKeyDer<'static>,
)> {
let mut params = CertificateParams::default();
let mut dn = DistinguishedName::new();
dn.push(DnType::CommonName, "Slinger MITM Proxy CA");
dn.push(DnType::OrganizationName, "Emo-Crab");
dn.push(DnType::CountryName, "CN");
dn.push(DnType::LocalityName, "Internet");
dn.push(DnType::StateOrProvinceName, "World");
params.distinguished_name = dn;
params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
params.key_usages = vec![KeyUsagePurpose::KeyCertSign, KeyUsagePurpose::CrlSign];
let now = OffsetDateTime::now_utc();
params.not_before = now;
params.not_after = now + Duration::days(3650);
let key_pair = KeyPair::generate()
.map_err(|e| Error::certificate_error(format!("Failed to generate key pair: {}", e)))?;
let cert = params
.self_signed(&key_pair)
.map_err(|e| Error::certificate_error(format!("Failed to generate CA: {}", e)))?;
let cert_pem = cert.pem();
let key_pem = key_pair.serialize_pem();
let mut cert_file = fs::File::create(cert_path).await?;
cert_file.write_all(cert_pem.as_bytes()).await?;
let mut key_file = fs::File::create(key_path).await?;
key_file.write_all(key_pem.as_bytes()).await?;
let cert_der = CertificateDer::from(cert.der().to_vec());
let key_der = PrivateKeyDer::try_from(key_pair.serialize_der())
.map_err(|_| Error::certificate_error("Failed to serialize CA key DER"))?;
let issuer = Issuer::from_ca_cert_pem(&cert_pem, key_pair)
.map_err(|e| Error::certificate_error(format!("Failed to create issuer: {}", e)))?;
Ok((issuer, cert_der, key_der))
}
pub fn generate_server_cert(
&self,
domain: &str,
) -> Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>)> {
let mut params = CertificateParams::default();
params.serial_number = Some(rand::rng().random::<u64>().into());
let mut dn = DistinguishedName::new();
dn.push(DnType::CommonName, domain);
params.distinguished_name = dn;
params.subject_alt_names = if let Ok(ip) = domain.parse::<IpAddr>() {
let mut sans = Vec::new();
sans.push(SanType::IpAddress(ip));
if let Ok(dns_name) = domain.try_into() {
sans.push(SanType::DnsName(dns_name));
}
sans
} else {
vec![SanType::DnsName(domain.try_into().map_err(|_| {
Error::certificate_error(format!("Invalid domain name: {}", domain))
})?)]
};
let now = OffsetDateTime::now_utc();
params.not_before = now - Duration::seconds(NOT_BEFORE_OFFSET);
params.not_after = now + Duration::seconds(TTL_SECS);
let key_pair = KeyPair::generate()
.map_err(|e| Error::certificate_error(format!("Failed to generate key pair: {}", e)))?;
let cert = params
.signed_by(&key_pair, &self.issuer)
.map_err(|e| Error::certificate_error(format!("Failed to sign server cert: {}", e)))?;
let cert_der = CertificateDer::from(cert.der().to_vec());
let key_der = PrivateKeyDer::try_from(key_pair.serialize_der())
.map_err(|_| Error::certificate_error("Failed to serialize server key"))?;
Ok((vec![cert_der, self.ca_cert_der.clone()], key_der))
}
pub async fn ca_cert_pem(&self) -> Result<String> {
let ca_cert_path = self.storage_path.join("ca_cert.pem");
tokio::fs::read_to_string(&ca_cert_path)
.await
.map_err(|e| Error::certificate_error(format!("Failed to read CA cert: {}", e)))
}
pub fn ca_cert_path(&self) -> PathBuf {
self.storage_path.join("ca_cert.pem")
}
}
/// A leaf certificate chain together with its private key, as held in the cache.
type CachedServerCert = (Vec<CertificateDer<'static>>, PrivateKeyDer<'static>);

/// Hands out TLS server certificates signed by a [`CertificateAuthority`],
/// keeping issued ones in a bounded, time-limited in-memory cache keyed by host.
pub struct CertificateManager {
    /// Authority that signs new leaf certificates.
    ca: CertificateAuthority,
    /// Host string (DNS name or IP literal) mapped to its issued chain and key.
    cert_cache: Cache<String, Arc<CachedServerCert>>,
}
impl CertificateManager {
    /// Creates a manager backed by the CA stored under `storage_path`.
    ///
    /// The CA is loaded or generated via [`CertificateAuthority::new`]. Issued leaf
    /// certificates are cached in memory, with at most 1000 entries, each kept for
    /// `CACHE_TTL` seconds after insertion.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`CertificateAuthority::new`].
    pub async fn new(storage_path: impl AsRef<Path>) -> Result<Self> {
        let ca = CertificateAuthority::new(storage_path).await?;
        let cert_cache = Cache::builder()
            .max_capacity(1000)
            .time_to_live(std::time::Duration::from_secs(CACHE_TTL))
            .build();
        Ok(Self { ca, cert_cache })
    }

    /// Returns a certificate chain (`[leaf, ca]`) and private key for `domain`.
    ///
    /// `domain` may be a DNS name or an IP address literal; both are served from the
    /// cache when possible. On a miss a new certificate is issued by the CA, cached,
    /// and returned.
    ///
    /// # Errors
    ///
    /// Returns an error if a new certificate has to be issued and that fails (for
    /// example because `domain` is neither an IP literal nor a valid DNS name).
    pub async fn get_server_cert(
        &self,
        domain: &str,
    ) -> Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>)> {
        // IP literals used to bypass this lookup while still being inserted, which
        // forced a fresh key and signature on every request. All hosts now share it.
        if let Some(cached) = self.cert_cache.get(domain).await {
            let (cert_chain, key) = cached.as_ref();
            return Ok((cert_chain.clone(), key.clone_key()));
        }
        // NOTE(review): concurrent misses for the same host each issue a certificate
        // and the last insert wins; every result is valid, just redundant work.
        let (cert_chain, key) = self.ca.generate_server_cert(domain)?;
        let entry = Arc::new((cert_chain.clone(), key.clone_key()));
        self.cert_cache.insert(domain.to_owned(), entry).await;
        Ok((cert_chain, key))
    }

    /// Reads the CA certificate PEM from disk.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`CertificateAuthority::ca_cert_pem`].
    pub async fn ca_cert_pem(&self) -> Result<String> {
        self.ca.ca_cert_pem().await
    }

    /// Path of the CA certificate PEM file.
    pub fn ca_cert_path(&self) -> PathBuf {
        self.ca.ca_cert_path()
    }
}