use crate::error::{ProxyError, Result};
use dashmap::DashMap;
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use rustls::server::{ClientHello, ResolvesServerCert};
use rustls::sign::CertifiedKey;
use std::io::BufReader;
use std::sync::{Arc, RwLock};
use tracing::{debug, trace, warn};
/// Certificate store that selects a TLS certificate from the client's SNI host name.
///
/// Certificates are registered per domain and looked up by exact name first,
/// then by a `*.parent` wildcard entry, and finally the optional default
/// certificate is used when neither lookup matches.
#[derive(Debug)]
pub struct SniCertResolver {
// Per-domain certificates keyed by the normalized domain name
// (the output of `normalize_domain`); wildcard entries use keys such as
// `*.example.com`.
certs: DashMap<String, Arc<CertifiedKey>>,
// Optional fallback certificate; `None` until `set_default_cert` succeeds.
default_cert: RwLock<Option<Arc<CertifiedKey>>>,
}
impl SniCertResolver {
    /// Creates an empty resolver: no per-domain certificates and no default certificate.
    #[must_use]
    pub fn new() -> Self {
        Self {
            certs: DashMap::new(),
            default_cert: RwLock::new(None),
        }
    }

    /// Parses a PEM certificate/key pair and normalizes `domain`, producing the
    /// key/value pair to store in the certificate map.
    ///
    /// Shared by [`Self::load_cert`] and [`Self::refresh_cert`] so both validate
    /// their input identically. The PEM data is parsed before anything is stored,
    /// so a failure leaves the map untouched.
    fn build_entry(
        domain: &str,
        cert_pem: &str,
        key_pem: &str,
    ) -> Result<(String, Arc<CertifiedKey>)> {
        let certified_key = create_certified_key(cert_pem, key_pem)?;
        Ok((normalize_domain(domain), Arc::new(certified_key)))
    }

    /// Returns a clone of the certificate stored under `key`, if any.
    ///
    /// The map guard is released before returning, so callers never hold a
    /// shard lock while logging or doing further lookups.
    fn lookup(&self, key: &str) -> Option<Arc<CertifiedKey>> {
        self.certs.get(key).map(|entry| Arc::clone(entry.value()))
    }

    /// Returns a clone of the default certificate, if one is configured.
    ///
    /// A poisoned lock is recovered rather than treated as "no default": the
    /// guarded value is a plain `Option<Arc<_>>` that is replaced in a single
    /// assignment, so it is always in a consistent state.
    fn default_cert_snapshot(&self) -> Option<Arc<CertifiedKey>> {
        let guard = self
            .default_cert
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        guard.as_ref().map(Arc::clone)
    }

    /// Registers the certificate for `domain`, replacing any existing entry.
    ///
    /// `domain` is normalized before being used as the key. A wildcard
    /// certificate is registered under a key of the form `*.example.com`.
    ///
    /// # Errors
    ///
    /// Returns [`ProxyError::Tls`] when the certificate or key PEM cannot be
    /// parsed, contains no certificate / private key, or the key type is not
    /// supported.
    pub fn load_cert(&self, domain: &str, cert_pem: &str, key_pem: &str) -> Result<()> {
        let (domain, certified_key) = Self::build_entry(domain, cert_pem, key_pem)?;
        debug!(domain = %domain, "Loaded TLS certificate");
        self.certs.insert(domain, certified_key);
        Ok(())
    }

    /// Sets (or replaces) the fallback certificate used when no per-domain
    /// entry matches.
    ///
    /// # Errors
    ///
    /// Returns [`ProxyError::Tls`] when the PEM data cannot be turned into a
    /// usable certificate/key pair; the previous default is kept in that case.
    pub fn set_default_cert(&self, cert_pem: &str, key_pem: &str) -> Result<()> {
        let certified_key = Arc::new(create_certified_key(cert_pem, key_pem)?);
        debug!("Set default TLS certificate");
        // Recover from poisoning instead of panicking: readers of this lock
        // tolerate poison as well, and the value is overwritten wholesale here.
        let mut slot = self
            .default_cert
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        *slot = Some(certified_key);
        Ok(())
    }

    /// Removes the certificate registered for `domain`, if present.
    pub fn remove_cert(&self, domain: &str) {
        let domain_normalized = normalize_domain(domain);
        if self.certs.remove(&domain_normalized).is_some() {
            debug!(domain = %domain_normalized, "Removed TLS certificate");
        }
    }

    /// Replaces the certificate for `domain` with a freshly parsed one.
    ///
    /// Behaves like [`Self::load_cert`] (the entry is inserted even when none
    /// existed before) but logs the operation as a refresh.
    ///
    /// # Errors
    ///
    /// Returns [`ProxyError::Tls`] when the PEM data cannot be turned into a
    /// usable certificate/key pair; the existing entry is kept in that case.
    pub fn refresh_cert(&self, domain: &str, cert_pem: &str, key_pem: &str) -> Result<()> {
        let (domain, certified_key) = Self::build_entry(domain, cert_pem, key_pem)?;
        debug!(domain = %domain, "Refreshed TLS certificate");
        self.certs.insert(domain, certified_key);
        Ok(())
    }

    /// Returns `true` when a certificate is registered under exactly `domain`
    /// (after normalization). Wildcard and default certificates are not considered.
    #[must_use]
    pub fn has_cert(&self, domain: &str) -> bool {
        let domain_normalized = normalize_domain(domain);
        self.certs.contains_key(&domain_normalized)
    }

    /// Returns the number of per-domain certificates (the default certificate
    /// is not counted).
    #[must_use]
    pub fn cert_count(&self) -> usize {
        self.certs.len()
    }

    /// Returns the normalized domain names that currently have a certificate.
    #[must_use]
    pub fn domains(&self) -> Vec<String> {
        self.certs.iter().map(|r| r.key().clone()).collect()
    }

    /// Returns `true` when a default certificate has been configured.
    #[must_use]
    pub fn has_default_cert(&self) -> bool {
        self.default_cert
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .is_some()
    }

    /// Picks the certificate for `server_name`.
    ///
    /// Lookup order:
    /// 1. exact match on the normalized name,
    /// 2. wildcard entry for the parent domain (`api.example.com` -> `*.example.com`),
    /// 3. the default certificate.
    ///
    /// When the client supplied no SNI name at all, only step 3 applies, so a
    /// configured default certificate also serves such clients.
    /// Returns `None` when nothing matches and no default is configured.
    fn resolve_cert(&self, server_name: Option<&str>) -> Option<Arc<CertifiedKey>> {
        let normalized = match server_name {
            Some(name) => normalize_domain(name),
            None => {
                let fallback = self.default_cert_snapshot();
                if fallback.is_some() {
                    trace!("No SNI provided, using default certificate");
                }
                return fallback;
            }
        };

        if let Some(cert) = self.lookup(&normalized) {
            trace!(domain = %normalized, "Exact certificate match");
            return Some(cert);
        }

        if let Some(wildcard_domain) = get_wildcard_domain(&normalized) {
            if let Some(cert) = self.lookup(&wildcard_domain) {
                trace!(
                    domain = %normalized,
                    wildcard = %wildcard_domain,
                    "Wildcard certificate match"
                );
                return Some(cert);
            }
        }

        if let Some(default) = self.default_cert_snapshot() {
            trace!(domain = %normalized, "Using default certificate");
            return Some(default);
        }

        warn!(domain = %normalized, "No certificate found");
        None
    }
}
impl Default for SniCertResolver {
fn default() -> Self {
Self::new()
}
}
impl ResolvesServerCert for SniCertResolver {
    /// rustls hook: forwards the SNI name from the ClientHello (if any) to
    /// `resolve_cert` and returns whatever certificate it selects.
    fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
        self.resolve_cert(client_hello.server_name())
    }
}
/// Builds a rustls `CertifiedKey` from a PEM certificate chain and a PEM private key.
///
/// The certificate PEM is parsed first and must contain at least one
/// certificate; only then is the private key parsed and converted into a
/// signing key.
///
/// # Errors
///
/// Returns [`ProxyError::Tls`] when either PEM fails to parse, when no
/// certificate or no private key is present, or when the key cannot be turned
/// into a signing key.
fn create_certified_key(cert_pem: &str, key_pem: &str) -> Result<CertifiedKey> {
    let chain = parse_certificates(cert_pem)?;
    if chain.is_empty() {
        return Err(ProxyError::Tls("No certificates found in PEM".to_string()));
    }

    let private_key = parse_private_key(key_pem)?;
    match rustls::crypto::ring::sign::any_supported_type(&private_key) {
        Ok(signer) => Ok(CertifiedKey::new(chain, signer)),
        Err(e) => Err(ProxyError::Tls(format!(
            "Failed to create signing key: {e}"
        ))),
    }
}
/// Extracts every certificate found in `pem`, in the order they appear.
///
/// Input that contains no certificate sections yields an empty vector rather
/// than an error; callers decide whether that is acceptable.
///
/// # Errors
///
/// Returns [`ProxyError::Tls`] as soon as the PEM reader reports an error.
fn parse_certificates(pem: &str) -> Result<Vec<CertificateDer<'static>>> {
    let mut reader = BufReader::new(pem.as_bytes());
    let mut chain: Vec<CertificateDer<'static>> = Vec::new();
    for entry in rustls_pemfile::certs(&mut reader) {
        let der = entry
            .map_err(|e| ProxyError::Tls(format!("Failed to parse certificate PEM: {e}")))?;
        chain.push(der);
    }
    Ok(chain)
}
fn parse_private_key(pem: &str) -> Result<PrivateKeyDer<'static>> {
let mut reader = BufReader::new(pem.as_bytes());
loop {
match rustls_pemfile::read_one(&mut reader) {
Ok(Some(rustls_pemfile::Item::Pkcs1Key(key))) => {
return Ok(PrivateKeyDer::Pkcs1(key));
}
Ok(Some(rustls_pemfile::Item::Pkcs8Key(key))) => {
return Ok(PrivateKeyDer::Pkcs8(key));
}
Ok(Some(rustls_pemfile::Item::Sec1Key(key))) => {
return Ok(PrivateKeyDer::Sec1(key));
}
Ok(Some(_)) => {
}
Ok(None) => {
return Err(ProxyError::Tls("No private key found in PEM".to_string()));
}
Err(e) => {
return Err(ProxyError::Tls(format!(
"Failed to parse private key PEM: {e}"
)));
}
}
}
}
/// Canonicalizes a domain name for use as a certificate-map key.
///
/// Surrounding whitespace is removed, a single trailing root dot
/// (`example.com.`) is dropped, and the result is lower-cased. Dropping the
/// root dot matters because SNI host names are sent without it, so a domain
/// configured in fully-qualified form would otherwise never match.
fn normalize_domain(domain: &str) -> String {
    let trimmed = domain.trim();
    let without_root_dot = trimmed.strip_suffix('.').unwrap_or(trimmed);
    without_root_dot.to_lowercase()
}
/// Derives the wildcard key that would cover `domain`.
///
/// The leftmost label is replaced by `*`, e.g. `api.example.com` becomes
/// `*.example.com`. Names with fewer than three dot-separated parts
/// (`example.com`, `localhost`) have no wildcard form and yield `None`.
fn get_wildcard_domain(domain: &str) -> Option<String> {
    let (_leftmost, parent) = domain.split_once('.')?;
    // The parent must itself contain a dot, i.e. the name had at least three parts.
    parent.contains('.').then(|| format!("*.{parent}"))
}
#[cfg(test)]
mod tests {
    //! Unit tests for the SNI certificate resolver.
    //!
    //! Nothing here awaits, so every test is a plain synchronous `#[test]`;
    //! no async runtime is started.

    use super::*;

    /// Generates a throwaway self-signed certificate and key, both PEM encoded.
    fn generate_test_cert() -> (String, String) {
        use rcgen::{generate_simple_self_signed, CertifiedKey as RcgenCertifiedKey};
        let subject_alt_names = vec!["localhost".to_string(), "example.com".to_string()];
        let RcgenCertifiedKey { cert, key_pair } =
            generate_simple_self_signed(subject_alt_names).unwrap();
        (cert.pem(), key_pair.serialize_pem())
    }

    #[test]
    fn test_normalize_domain() {
        assert_eq!(normalize_domain("Example.COM"), "example.com");
        assert_eq!(normalize_domain(" foo.bar.com "), "foo.bar.com");
        assert_eq!(normalize_domain("API.Example.ORG"), "api.example.org");
    }

    #[test]
    fn test_get_wildcard_domain() {
        assert_eq!(
            get_wildcard_domain("foo.example.com"),
            Some("*.example.com".to_string())
        );
        assert_eq!(
            get_wildcard_domain("bar.foo.example.com"),
            Some("*.foo.example.com".to_string())
        );
        assert_eq!(get_wildcard_domain("example.com"), None);
        assert_eq!(get_wildcard_domain("localhost"), None);
    }

    #[test]
    fn test_sni_resolver_new() {
        let resolver = SniCertResolver::new();
        assert_eq!(resolver.cert_count(), 0);
        assert!(resolver.domains().is_empty());
    }

    #[test]
    fn test_sni_resolver_default() {
        let resolver = SniCertResolver::default();
        assert_eq!(resolver.cert_count(), 0);
    }

    #[test]
    fn test_load_cert() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();
        let result = resolver.load_cert("example.com", &cert_pem, &key_pem);
        assert!(result.is_ok());
        assert!(resolver.has_cert("example.com"));
        assert_eq!(resolver.cert_count(), 1);
    }

    #[test]
    fn test_load_cert_case_insensitive() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();
        resolver
            .load_cert("Example.COM", &cert_pem, &key_pem)
            .unwrap();
        assert!(resolver.has_cert("example.com"));
        assert!(resolver.has_cert("EXAMPLE.COM"));
    }

    #[test]
    fn test_remove_cert() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();
        resolver
            .load_cert("example.com", &cert_pem, &key_pem)
            .unwrap();
        assert!(resolver.has_cert("example.com"));
        resolver.remove_cert("example.com");
        assert!(!resolver.has_cert("example.com"));
        assert_eq!(resolver.cert_count(), 0);
    }

    #[test]
    fn test_refresh_cert() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();
        resolver
            .load_cert("example.com", &cert_pem, &key_pem)
            .unwrap();
        let (new_cert_pem, new_key_pem) = generate_test_cert();
        let result = resolver.refresh_cert("example.com", &new_cert_pem, &new_key_pem);
        assert!(result.is_ok());
        // Refreshing replaces the entry instead of adding a second one.
        assert_eq!(resolver.cert_count(), 1);
    }

    #[test]
    fn test_set_default_cert() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();
        let result = resolver.set_default_cert(&cert_pem, &key_pem);
        assert!(result.is_ok());
        // The default certificate is not counted among per-domain entries.
        assert_eq!(resolver.cert_count(), 0);
    }

    #[test]
    fn test_has_default_cert() {
        let resolver = SniCertResolver::new();
        assert!(!resolver.has_default_cert());
        let (cert_pem, key_pem) = generate_test_cert();
        resolver.set_default_cert(&cert_pem, &key_pem).unwrap();
        assert!(resolver.has_default_cert());
    }

    #[test]
    fn test_domains() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();
        resolver
            .load_cert("api.example.com", &cert_pem, &key_pem)
            .unwrap();
        resolver
            .load_cert("web.example.com", &cert_pem, &key_pem)
            .unwrap();
        let domains = resolver.domains();
        assert_eq!(domains.len(), 2);
        assert!(domains.contains(&"api.example.com".to_string()));
        assert!(domains.contains(&"web.example.com".to_string()));
    }

    #[test]
    fn test_resolve_exact_match() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();
        resolver
            .load_cert("example.com", &cert_pem, &key_pem)
            .unwrap();
        let result = resolver.resolve_cert(Some("example.com"));
        assert!(result.is_some());
    }

    #[test]
    fn test_resolve_wildcard_match() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();
        resolver
            .load_cert("*.example.com", &cert_pem, &key_pem)
            .unwrap();
        let result = resolver.resolve_cert(Some("api.example.com"));
        assert!(result.is_some());
        let result = resolver.resolve_cert(Some("web.example.com"));
        assert!(result.is_some());
        // The wildcard covers sub-domains only, not the bare parent domain.
        let result = resolver.resolve_cert(Some("example.com"));
        assert!(result.is_none());
    }

    #[test]
    fn test_resolve_default_fallback() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();
        resolver.set_default_cert(&cert_pem, &key_pem).unwrap();
        let result = resolver.resolve_cert(Some("unknown.com"));
        assert!(result.is_some());
    }

    #[test]
    fn test_resolve_no_match() {
        let resolver = SniCertResolver::new();
        let (cert_pem, key_pem) = generate_test_cert();
        resolver
            .load_cert("example.com", &cert_pem, &key_pem)
            .unwrap();
        let result = resolver.resolve_cert(Some("other.com"));
        assert!(result.is_none());
    }

    #[test]
    fn test_resolve_none_server_name() {
        // No SNI and nothing configured: there is no certificate to offer.
        let resolver = SniCertResolver::new();
        let result = resolver.resolve_cert(None);
        assert!(result.is_none());
    }

    #[test]
    fn test_invalid_cert_pem() {
        // Text without any PEM sections is not a parse error; it simply
        // contains zero certificates. `create_certified_key` rejects that case.
        let result = parse_certificates("not a valid PEM");
        assert!(result.is_ok());
        assert!(result.unwrap().is_empty());
    }

    #[test]
    fn test_invalid_key_pem() {
        let result = parse_private_key("not a valid PEM");
        assert!(result.is_err());
    }

    #[test]
    fn test_create_certified_key_empty_certs() {
        let (_, key_pem) = generate_test_cert();
        let result = create_certified_key("", &key_pem);
        assert!(result.is_err());
    }
}