use std::path::PathBuf;
use std::sync::{Arc, OnceLock, RwLock};
use notify::{RecursiveMode, Watcher};
use rustls::crypto::CryptoProvider;
use rustls::pki_types::pem::PemObject;
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use rustls::server::{ClientHello, ResolvesServerCert};
use rustls::sign::CertifiedKey;
use rustls::ServerConfig;
use tokio_util::sync::CancellationToken;
use crate::config::ProxyEntry;
use crate::ProxyError;
/// Origin of the certificate the proxy is currently serving.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CertSource {
    /// Certificate and key were loaded from files under the certs directory.
    Certmesh,
    /// Certificate was generated in-process as a self-signed fallback.
    SelfSigned,
}

impl CertSource {
    /// Returns the stable lowercase label for this source
    /// (`"certmesh"` or `"self-signed"`).
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Certmesh => "certmesh",
            Self::SelfSigned => "self-signed",
        }
    }
}
/// Server certificate resolver whose certificate can be replaced at runtime.
///
/// Every handshake is answered with whatever [`CertifiedKey`] is currently
/// stored, regardless of the SNI name in the client hello.
#[derive(Debug)]
pub struct CertResolver {
    /// The certificate currently served; replaced wholesale by `swap`.
    current: RwLock<Arc<CertifiedKey>>,
}

impl CertResolver {
    /// Creates a resolver that serves `initial` until the first `swap`.
    fn new(initial: Arc<CertifiedKey>) -> Self {
        Self {
            current: RwLock::new(initial),
        }
    }

    /// Replaces the served certificate with `next`.
    ///
    /// A poisoned lock is recovered rather than ignored: the lock only guards
    /// an `Arc` pointer, so there is no partially-updated state to protect
    /// against, and silently skipping the swap would keep serving a stale
    /// certificate forever.
    fn swap(&self, next: Arc<CertifiedKey>) {
        let previous = {
            let mut guard = self
                .current
                .write()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            std::mem::replace(&mut *guard, next)
        };
        // Release the old certificate only after the write lock is gone so
        // concurrent handshakes never wait on its destructor.
        drop(previous);
    }
}

impl ResolvesServerCert for CertResolver {
    /// Returns the currently installed certificate for every client hello.
    ///
    /// Always yields `Some`: a poisoned lock is recovered (see `swap`) instead
    /// of returning `None`, which would fail every subsequent handshake.
    fn resolve(&self, _client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
        let guard = self
            .current
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        Some(Arc::clone(&guard))
    }
}
/// Result of `build_tls`: the server TLS configuration together with the
/// pieces needed to report on and later replace the served certificate.
pub struct TlsSetup {
/// Shared rustls server configuration; its cert resolver is `resolver`.
pub config: Arc<ServerConfig>,
/// Where the initially installed certificate came from.
pub cert_source: CertSource,
/// Handle to the resolver wired into `config`, kept so the certificate
/// can be swapped later (e.g. by `spawn_cert_watcher`).
pub resolver: Arc<CertResolver>,
}
/// Builds the TLS server configuration for a proxy entry.
///
/// The initial certificate comes from `resolve_initial` (on-disk files when
/// usable, otherwise a freshly generated self-signed certificate) and is
/// installed behind a [`CertResolver`] so it can be replaced later without
/// rebuilding the [`ServerConfig`].
///
/// # Errors
///
/// Returns [`ProxyError`] when no initial certificate can be produced or when
/// the crypto provider rejects the default protocol versions.
pub fn build_tls(entry: &ProxyEntry) -> Result<TlsSetup, ProxyError> {
    let (initial_key, cert_source) = resolve_initial(entry)?;
    let resolver = Arc::new(CertResolver::new(initial_key));

    let versioned = ServerConfig::builder_with_provider(provider())
        .with_safe_default_protocol_versions()
        .map_err(|e| ProxyError::Io(format!("tls config: {e}")))?;

    // The config holds a type-erased handle; the concrete one is returned so
    // callers can swap certificates.
    let handshake_resolver: Arc<dyn ResolvesServerCert> = resolver.clone();
    let config = Arc::new(
        versioned
            .with_no_client_auth()
            .with_cert_resolver(handshake_resolver),
    );

    Ok(TlsSetup {
        config,
        cert_source,
        resolver,
    })
}
pub fn spawn_cert_watcher(
entry: ProxyEntry,
resolver: Arc<CertResolver>,
cancel: CancellationToken,
) -> Option<notify::RecommendedWatcher> {
let certs_dir = koi_common::paths::koi_certs_dir();
if let Err(e) = std::fs::create_dir_all(&certs_dir) {
tracing::warn!(error = %e, "Proxy cert watcher: cannot create certs dir; hot-reload disabled");
return None;
}
let (tx, mut rx) = tokio::sync::mpsc::channel::<()>(8);
let mut watcher =
match notify::recommended_watcher(move |res: notify::Result<notify::Event>| {
if res.is_ok() {
let _ = tx.try_send(());
}
}) {
Ok(w) => w,
Err(e) => {
tracing::warn!(error = %e, "Proxy cert watcher: init failed; hot-reload disabled");
return None;
}
};
if let Err(e) = watcher.watch(&certs_dir, RecursiveMode::Recursive) {
tracing::warn!(error = %e, dir = %certs_dir.display(),
"Proxy cert watcher: watch failed; hot-reload disabled");
return None;
}
tokio::spawn(async move {
loop {
tokio::select! {
_ = cancel.cancelled() => break,
msg = rx.recv() => {
if msg.is_none() {
break;
}
while rx.try_recv().is_ok() {}
reload_cert(&entry, &resolver).await;
}
}
}
});
Some(watcher)
}
/// Re-reads the on-disk certificate for `entry` and installs it in `resolver`.
///
/// The lookup touches the filesystem, so it runs on the blocking thread pool.
/// Outcomes:
/// - a usable certificate is found: it is swapped in and an info line is logged;
/// - nothing usable is found: the currently served certificate is kept;
/// - the blocking task itself fails (panic or runtime shutdown): a warning is
///   logged instead of being silently discarded, and the current certificate
///   is kept.
async fn reload_cert(entry: &ProxyEntry, resolver: &Arc<CertResolver>) {
    let owned_entry = entry.clone();
    let outcome = tokio::task::spawn_blocking(move || find_file_cert(&owned_entry)).await;

    match outcome {
        Ok(Some(certified)) => {
            resolver.swap(certified);
            tracing::info!(name = %entry.name, "Proxy TLS cert reloaded");
        }
        Ok(None) => {
            // Events fire for anything under the certs dir, so "nothing to
            // load" is routine; keep it at debug level.
            tracing::debug!(
                name = %entry.name,
                "Proxy TLS cert reload found no usable cert files; keeping current cert"
            );
        }
        Err(e) => {
            tracing::warn!(
                name = %entry.name, error = %e,
                "Proxy TLS cert reload task failed; keeping current cert"
            );
        }
    }
}
/// Picks the certificate to serve at startup.
///
/// A usable on-disk certificate wins and is reported as
/// [`CertSource::Certmesh`]; otherwise a self-signed certificate is generated
/// and reported as [`CertSource::SelfSigned`].
///
/// # Errors
///
/// Fails only on the fallback path, when self-signed generation or parsing of
/// the generated PEM fails.
fn resolve_initial(entry: &ProxyEntry) -> Result<(Arc<CertifiedKey>, CertSource), ProxyError> {
    match find_file_cert(entry) {
        Some(from_disk) => Ok((from_disk, CertSource::Certmesh)),
        None => {
            let (cert_pem, key_pem) = generate_self_signed(entry)?;
            build_certified_key(&cert_pem, &key_pem)
                .map(|generated| (generated, CertSource::SelfSigned))
        }
    }
}
/// Lists the directories searched for certificate files, in priority order.
///
/// The first candidate is `<certs dir>/<entry name>`. When the machine
/// hostname can be determined, is non-empty and differs from the entry name,
/// `<certs dir>/<hostname>` is appended as a second candidate.
fn cert_candidate_dirs(entry: &ProxyEntry) -> Vec<PathBuf> {
    let certs_root = koi_common::paths::koi_certs_dir();
    let by_entry_name = certs_root.join(&entry.name);

    let by_hostname = hostname::get()
        .ok()
        .map(|raw| raw.to_string_lossy().to_string())
        .filter(|host| !host.is_empty() && *host != entry.name)
        .map(|host| certs_root.join(host));

    std::iter::once(by_entry_name).chain(by_hostname).collect()
}
/// Looks for a usable `fullchain.pem` / `key.pem` pair on disk.
///
/// Candidate directories come from `cert_candidate_dirs` and are tried in
/// order. A directory is skipped quietly when either file is missing. When
/// both files exist but cannot be read or parsed, a warning is logged — so a
/// silent fallback to another source can be diagnosed — and the next
/// candidate is tried.
///
/// Returns the first certificate that loads successfully, or `None` when no
/// candidate yields one.
fn find_file_cert(entry: &ProxyEntry) -> Option<Arc<CertifiedKey>> {
    for dir in cert_candidate_dirs(entry) {
        let cert = dir.join("fullchain.pem");
        let key = dir.join("key.pem");
        if !(cert.is_file() && key.is_file()) {
            continue;
        }

        // Both files exist, so a read failure here (permissions, file removed
        // mid-rotation, ...) is worth surfacing rather than swallowing.
        let (cert_pem, key_pem) = match (std::fs::read(&cert), std::fs::read(&key)) {
            (Ok(cert_pem), Ok(key_pem)) => (cert_pem, key_pem),
            (Err(e), _) | (_, Err(e)) => {
                tracing::warn!(
                    name = %entry.name, dir = %dir.display(), error = %e,
                    "Proxy cert files present but unreadable; trying next source"
                );
                continue;
            }
        };

        match build_certified_key(&cert_pem, &key_pem) {
            Ok(certified) => return Some(certified),
            Err(e) => tracing::warn!(
                name = %entry.name, dir = %dir.display(), error = %e,
                "Proxy cert files present but unusable; trying next source"
            ),
        }
    }
    None
}
/// Generates a self-signed certificate and private key, both PEM-encoded.
///
/// Subject alternative names are, in order and without duplicates:
/// `localhost`, the entry name (if non-empty) and the machine hostname (if it
/// can be determined and is non-empty).
///
/// Returns `(certificate PEM, private key PEM)`.
///
/// # Errors
///
/// Returns [`ProxyError::Io`] when certificate generation fails.
fn generate_self_signed(entry: &ProxyEntry) -> Result<(Vec<u8>, Vec<u8>), ProxyError> {
    let machine_host = hostname::get()
        .ok()
        .map(|raw| raw.to_string_lossy().to_string());

    let mut sans = vec!["localhost".to_string()];
    for name in std::iter::once(entry.name.clone()).chain(machine_host) {
        if !name.is_empty() && !sans.contains(&name) {
            sans.push(name);
        }
    }

    match rcgen::generate_simple_self_signed(sans) {
        Ok(generated) => {
            let cert_pem = generated.cert.pem().into_bytes();
            let key_pem = generated.key_pair.serialize_pem().into_bytes();
            Ok((cert_pem, key_pem))
        }
        Err(e) => Err(ProxyError::Io(format!(
            "self-signed cert generation failed: {e}"
        ))),
    }
}
fn build_certified_key(cert_pem: &[u8], key_pem: &[u8]) -> Result<Arc<CertifiedKey>, ProxyError> {
let certs: Vec<CertificateDer<'static>> = CertificateDer::pem_slice_iter(cert_pem)
.collect::<Result<_, _>>()
.map_err(|e| ProxyError::Io(format!("cert parse: {e}")))?;
if certs.is_empty() {
return Err(ProxyError::Io("no certificates in PEM".to_string()));
}
let key: PrivateKeyDer<'static> = PrivateKeyDer::from_pem_slice(key_pem)
.map_err(|e| ProxyError::Io(format!("key parse: {e}")))?;
let signing_key = provider()
.key_provider
.load_private_key(key)
.map_err(|e| ProxyError::Io(format!("load private key: {e}")))?;
Ok(Arc::new(CertifiedKey::new(certs, signing_key)))
}
/// Returns the process-wide crypto provider (aws-lc-rs defaults).
///
/// The provider is built once on first use; every call hands out a new `Arc`
/// handle to that same instance.
fn provider() -> Arc<CryptoProvider> {
    static PROVIDER: OnceLock<Arc<CryptoProvider>> = OnceLock::new();
    let shared = PROVIDER.get_or_init(|| {
        let built = rustls::crypto::aws_lc_rs::default_provider();
        Arc::new(built)
    });
    Arc::clone(shared)
}