use crate::mitm::ca_key_manager::CaKeyManager;
use anyhow::Result;
use base64::Engine;
use lru::LruCache;
use rcgen::{Certificate, CertificateParams, DnType, KeyPair, SanType};
use std::net::IpAddr;
use std::num::NonZeroUsize;
use std::sync::Arc;
use std::time::{Duration, Instant};
use thiserror::Error;
use tokio::sync::Mutex;
use tracing::{debug, warn};
/// Classification of the host a client is connecting to.
///
/// Used as the key of the certificate cache and to decide which subject /
/// subject-alternative-name a generated certificate carries.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub enum HostIdentifier {
    /// An ordinary DNS name, stored exactly as supplied.
    Domain(String),
    /// A wildcard DNS name: any hostname beginning with `*.`, stored as supplied.
    Wildcard(String),
    /// A non-loopback IPv4 or IPv6 literal.
    IpAddress(IpAddr),
    /// The local machine: the name `localhost` or any loopback address.
    Localhost,
}

impl HostIdentifier {
    /// Classifies `hostname` into a [`HostIdentifier`].
    ///
    /// Rules, applied in order:
    ///
    /// 1. `localhost` (ASCII case-insensitive, because DNS names are
    ///    case-insensitive) is [`HostIdentifier::Localhost`].
    /// 2. A parseable IP literal is `Localhost` when it is a loopback address
    ///    (`127.0.0.0/8` or `::1`), otherwise [`HostIdentifier::IpAddress`].
    /// 3. A purely numeric dotted string starting with `127.` (for example the
    ///    shorthand `127.1`, which does not parse as an `IpAddr`) is still
    ///    treated as `Localhost`, so the loopback bypass stays conservative.
    /// 4. A name beginning with `*.` is [`HostIdentifier::Wildcard`].
    /// 5. Everything else is [`HostIdentifier::Domain`].
    ///
    /// A DNS name that merely *starts* with `127.` (such as `127.example.com`)
    /// is a regular domain, not loopback.
    pub fn from_hostname(hostname: &str) -> Self {
        if hostname.eq_ignore_ascii_case("localhost") {
            return Self::Localhost;
        }

        if let Ok(ip) = hostname.parse::<IpAddr>() {
            return if ip.is_loopback() {
                Self::Localhost
            } else {
                Self::IpAddress(ip)
            };
        }

        // Only digits-and-dots strings may use the `127.` prefix rule; a prefix
        // test alone would let attacker-chosen DNS names such as
        // `127.0.0.1.evil.example` masquerade as loopback.
        let numeric_dotted = hostname
            .bytes()
            .all(|byte| byte.is_ascii_digit() || byte == b'.');
        if numeric_dotted && hostname.starts_with("127.") {
            return Self::Localhost;
        }

        if hostname.starts_with("*.") {
            return Self::Wildcard(hostname.to_string());
        }

        Self::Domain(hostname.to_string())
    }
}
/// A cache entry: a shared certificate plus the instant the entry was built.
#[derive(Clone)]
struct CachedCertificate {
    /// Shared handle to the certificate handed out on cache hits.
    cert: Arc<Certificate>,
    /// Monotonic timestamp captured when this entry was constructed.
    created_at: Instant,
}

impl CachedCertificate {
    /// Wraps `cert` and stamps the entry with the current instant.
    fn new(cert: Arc<Certificate>) -> Self {
        let created_at = Instant::now();
        Self { cert, created_at }
    }

    /// Returns `true` once the entry has existed for strictly longer than `ttl`.
    fn is_expired(&self, ttl: Duration) -> bool {
        let age = self.created_at.elapsed();
        age > ttl
    }
}
/// Errors reported while obtaining a certificate for an intercepted host.
#[derive(Debug, Error)]
pub enum MitmError {
/// The requested host is classified as `HostIdentifier::Localhost`; no
/// certificate is generated for it.
#[error("Localhost bypass - MITM not allowed for localhost")]
LocalhostBypass,
/// Key generation, certificate construction, signing, or re-parsing failed.
/// The payload is the underlying error rendered with `to_string()`.
#[error("Certificate generation failed: {0}")]
CertGenerationFailed(String),
/// A hostname was rejected as invalid.
/// NOTE(review): not constructed anywhere in the visible part of this file;
/// presumably raised by callers — verify.
#[error("Invalid hostname: {0}")]
InvalidHostname(String),
/// A certificate-cache operation failed.
/// NOTE(review): not constructed anywhere in the visible part of this file;
/// presumably raised by callers — verify.
#[error("Cache error: {0}")]
CacheError(String),
}
/// Issues per-host certificates signed by the CA held in a `CaKeyManager`,
/// keeping recently issued certificates in an LRU cache with a time-to-live.
pub struct CertificateAuthority {
// Source of the CA certificate used to sign generated host certificates.
ca_manager: Arc<CaKeyManager>,
// LRU cache of issued certificates keyed by host, guarded by an async mutex.
cache: Arc<Mutex<LruCache<HostIdentifier, CachedCertificate>>>,
// Capacity figure reported as the second element of `cache_stats`.
max_cache_size: usize,
// How long a cached certificate may be served before it is regenerated.
cert_ttl: Duration,
}
impl CertificateAuthority {
/// Creates a certificate authority whose cached certificates live for one day.
///
/// Delegates to `with_ttl` with a TTL of 86 400 seconds.
pub fn new(ca_manager: Arc<CaKeyManager>, max_cache_size: usize) -> Self {
    let one_day = Duration::from_secs(24 * 60 * 60);
    Self::with_ttl(ca_manager, max_cache_size, one_day)
}
/// Creates a certificate authority with an explicit cache TTL.
///
/// * `ca_manager` – provides the CA certificate used for signing.
/// * `max_cache_size` – maximum number of cached certificates. `0` is not a
///   usable LRU capacity, so it selects a fallback capacity of 1000.
/// * `cert_ttl` – how long a cached certificate is served before being
///   regenerated.
///
/// The capacity recorded on the instance (and reported by `cache_stats`) is
/// the capacity the cache was actually built with, so it stays accurate even
/// when the fallback is used.
pub fn with_ttl(
    ca_manager: Arc<CaKeyManager>,
    max_cache_size: usize,
    cert_ttl: Duration,
) -> Self {
    const FALLBACK_CAPACITY: usize = 1000;

    let capacity = NonZeroUsize::new(max_cache_size).unwrap_or_else(|| {
        // Invariant: the fallback constant is a non-zero literal.
        NonZeroUsize::new(FALLBACK_CAPACITY).expect("fallback cache capacity is non-zero")
    });

    Self {
        ca_manager,
        cache: Arc::new(Mutex::new(LruCache::new(capacity))),
        // Record the effective capacity rather than the raw argument.
        max_cache_size: capacity.get(),
        cert_ttl,
    }
}
/// Returns the certificate for `host`, generating and caching one if needed.
///
/// A cached certificate is returned while it is younger than the configured
/// TTL; an expired entry is evicted and replaced by a freshly generated one.
/// The cache lock is released while the certificate is being generated, so
/// concurrent requests for the same uncached host may each generate their own.
///
/// # Errors
///
/// * `MitmError::LocalhostBypass` when `host` is `HostIdentifier::Localhost`.
/// * Any error returned by certificate generation.
pub async fn get_or_generate(
    &self,
    host: HostIdentifier,
) -> Result<Arc<Certificate>, MitmError> {
    if host == HostIdentifier::Localhost {
        warn!("Attempted MITM on localhost - bypassing");
        return Err(MitmError::LocalhostBypass);
    }

    {
        let mut entries = self.cache.lock().await;
        // Outer `Option`: was there an entry at all?
        // Inner `Option`: the certificate, if that entry is still fresh.
        let lookup = entries.get(&host).map(|entry| {
            if entry.is_expired(self.cert_ttl) {
                None
            } else {
                Some(Arc::clone(&entry.cert))
            }
        });
        match lookup {
            Some(Some(fresh)) => {
                debug!(host = ?host, "Certificate cache hit (valid)");
                return Ok(fresh);
            }
            Some(None) => {
                debug!(host = ?host, "Certificate cache hit but expired, regenerating");
                entries.pop(&host);
            }
            None => {}
        }
    }

    debug!(host = ?host, "Generating new certificate");
    let generated = Arc::new(self.generate_certificate(&host).await?);

    let mut entries = self.cache.lock().await;
    entries.put(host, CachedCertificate::new(Arc::clone(&generated)));
    drop(entries);

    Ok(generated)
}
/// Builds a certificate for `host` using the CA certificate from `ca_manager`.
///
/// The subject common name and the single subject-alternative-name are taken
/// from `host`; validity runs from one day before now to 90 days after now.
///
/// # Errors
///
/// * `MitmError::LocalhostBypass` when `host` is `HostIdentifier::Localhost`.
/// * `MitmError::CertGenerationFailed` when key generation, certificate
///   construction, signing, or re-parsing fails.
///
/// NOTE(review): declared `async`, but the body contains no `.await`.
async fn generate_certificate(&self, host: &HostIdentifier) -> Result<Certificate, MitmError> {
let mut params = CertificateParams::default();
// Subject CN and SAN mirror the host: DNS name for domains and wildcards,
// IP SAN for address literals.
match host {
HostIdentifier::Domain(domain) => {
params
.distinguished_name
.push(DnType::CommonName, domain.clone());
params.subject_alt_names = vec![SanType::DnsName(domain.clone())];
}
HostIdentifier::Wildcard(wildcard) => {
params
.distinguished_name
.push(DnType::CommonName, wildcard.clone());
params.subject_alt_names = vec![SanType::DnsName(wildcard.clone())];
}
HostIdentifier::IpAddress(ip) => {
params
.distinguished_name
.push(DnType::CommonName, ip.to_string());
params.subject_alt_names = vec![SanType::IpAddress(*ip)];
}
HostIdentifier::Localhost => {
return Err(MitmError::LocalhostBypass);
}
}
// Back-date the start of validity by one day; the certificate expires after 90 days.
params.not_before = time::OffsetDateTime::now_utc() - time::Duration::days(1);
params.not_after = time::OffsetDateTime::now_utc() + time::Duration::days(90);
let serial_number = self.generate_serial_number();
params.serial_number = Some(serial_number.into());
// Fresh ECDSA P-256 key pair for this host.
// NOTE(review): `key_pair` is never assigned to `params`, so `temp_cert`
// below is built without it — presumably rcgen generates a separate key for
// `temp_cert`; confirm against the rcgen version in use.
let key_pair = KeyPair::generate(&rcgen::PKCS_ECDSA_P256_SHA256)
.map_err(|e| MitmError::CertGenerationFailed(e.to_string()))?;
let temp_cert = Certificate::from_params(params)
.map_err(|e| MitmError::CertGenerationFailed(e.to_string()))?;
// Sign the temporary certificate with the CA certificate, producing DER bytes.
let ca_cert = self.ca_manager.certificate();
let cert_der = temp_cert
.serialize_der_with_signer(&*ca_cert)
.map_err(|e| MitmError::CertGenerationFailed(e.to_string()))?;
// Wrap the DER bytes in PEM armor by hand (base64 body on a single line).
let cert_pem = format!(
"-----BEGIN CERTIFICATE-----\n{}\n-----END CERTIFICATE-----",
base64::engine::general_purpose::STANDARD.encode(&cert_der)
);
// Re-parse the signed PEM into parameters paired with `key_pair`, and build
// the returned `Certificate` from those parameters.
// NOTE(review): this uses the API named `from_ca_cert_pem` on a leaf
// certificate and pairs it with a key that was not part of `temp_cert` —
// verify this round trip yields the intended certificate/key combination.
let cert = Certificate::from_params(
CertificateParams::from_ca_cert_pem(&cert_pem, key_pair)
.map_err(|e| MitmError::CertGenerationFailed(e.to_string()))?,
)
.map_err(|e| MitmError::CertGenerationFailed(e.to_string()))?;
Ok(cert)
}
/// Produces a 64-bit certificate serial number.
///
/// The upper 32 bits are the current Unix timestamp in seconds truncated to
/// `u32`; the lower 32 bits come from the thread-local random generator.
fn generate_serial_number(&self) -> u64 {
    use rand::Rng;

    let entropy: u32 = rand::thread_rng().gen();
    // Deliberately truncating cast: only the low 32 bits of the timestamp are kept.
    let epoch_secs = chrono::Utc::now().timestamp() as u32;

    (u64::from(epoch_secs) << 32) | u64::from(entropy)
}
/// Returns `(entries currently cached, recorded maximum cache size)`.
///
/// Waits for the cache lock to read the current entry count.
pub async fn cache_stats(&self) -> (usize, usize) {
    let current_len = self.cache.lock().await.len();
    (current_len, self.max_cache_size)
}
/// Removes every cached certificate.
///
/// Waits for the cache lock before clearing.
pub async fn clear_cache(&self) {
    self.cache.lock().await.clear();
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that representative hostnames map to the expected
    /// `HostIdentifier` variant.
    #[test]
    fn test_host_identifier_parsing() {
        // Loopback spellings all collapse to `Localhost`.
        for loopback in ["localhost", "127.0.0.1", "::1"] {
            let parsed = HostIdentifier::from_hostname(loopback);
            assert!(matches!(parsed, HostIdentifier::Localhost));
        }

        let private_ip = HostIdentifier::from_hostname("192.168.1.1");
        assert!(matches!(private_ip, HostIdentifier::IpAddress(_)));

        let wildcard = HostIdentifier::from_hostname("*.example.com");
        assert!(matches!(wildcard, HostIdentifier::Wildcard(_)));

        let plain = HostIdentifier::from_hostname("example.com");
        assert!(matches!(plain, HostIdentifier::Domain(_)));
    }
}