use std::sync::Arc;
use axum::extract::Extension;
use axum::Router;
use hyper_util::rt::{TokioExecutor, TokioIo};
use hyper_util::server::conn::auto::Builder;
use rustls::pki_types::pem::PemObject;
use rustls::pki_types::{CertificateDer, PrivateKeyDer};
use tokio::net::TcpListener;
use tokio_rustls::TlsAcceptor;
use tokio_util::sync::CancellationToken;
use x509_parser::prelude::*;
use koi_certmesh::http::ClientCn;
/// Default TCP port for the mTLS inter-node adapter.
///
/// NOTE(review): not referenced in this file; presumably the value callers pass
/// as `port` to `start` — confirm.
pub const DEFAULT_MTLS_PORT: u16 = 5642;
pub async fn start(
port: u16,
certmesh_core: Arc<koi_certmesh::CertmeshCore>,
cert_pem: &str,
key_pem: &str,
ca_cert_pem: &str,
cancel: CancellationToken,
) -> anyhow::Result<()> {
let tls_config = build_tls_config(cert_pem, key_pem, ca_cert_pem)?;
let tls_acceptor = TlsAcceptor::from(Arc::new(tls_config));
let app = Router::new().nest("/v1/certmesh", certmesh_core.inter_node_routes());
let listener = TcpListener::bind(("0.0.0.0", port)).await?;
tracing::info!(port, "mTLS adapter listening");
loop {
let (tcp, addr) = tokio::select! {
res = listener.accept() => match res {
Ok(v) => v,
Err(e) => {
tracing::warn!(error = %e, "mTLS accept error");
continue;
}
},
_ = cancel.cancelled() => {
tracing::debug!("mTLS adapter stopped");
return Ok(());
}
};
let acceptor = tls_acceptor.clone();
let app = app.clone();
let cancel = cancel.clone();
tokio::spawn(async move {
let tls_stream = match acceptor.accept(tcp).await {
Ok(s) => s,
Err(e) => {
tracing::debug!(%addr, error = %e, "mTLS handshake failed");
return;
}
};
let cn = tls_stream
.get_ref()
.1
.peer_certificates()
.and_then(|certs| certs.first())
.and_then(|cert| extract_cn(cert.as_ref()));
let cn = match cn {
Some(cn) => cn,
None => {
tracing::warn!(%addr, "no CN in client certificate");
return;
}
};
tracing::debug!(%addr, %cn, "mTLS authenticated");
let svc = app.layer(Extension(ClientCn(cn)));
let io = TokioIo::new(tls_stream);
let builder = Builder::new(TokioExecutor::new());
let hyper_svc = hyper_util::service::TowerToHyperService::new(svc);
tokio::select! {
res = builder.serve_connection_with_upgrades(io, hyper_svc) => {
if let Err(e) = res {
tracing::debug!(%addr, error = %e, "mTLS connection error");
}
}
_ = cancel.cancelled() => {}
}
});
}
}
/// Extract the first Subject Common Name from a DER-encoded X.509 certificate.
///
/// Returns `None` in any of these cases:
/// - the bytes do not parse as a certificate;
/// - the subject has no CN attribute;
/// - the first CN cannot be decoded as a string;
/// - that CN is empty, because an empty name is not a usable peer identity.
fn extract_cn(cert_der: &[u8]) -> Option<String> {
    let (_, cert) = X509Certificate::from_der(cert_der).ok()?;
    cert.subject()
        .iter_common_name()
        .next()
        .and_then(|cn| cn.as_str().ok())
        .filter(|cn| !cn.is_empty())
        .map(String::from)
}
fn build_tls_config(
cert_pem: &str,
key_pem: &str,
ca_cert_pem: &str,
) -> anyhow::Result<rustls::ServerConfig> {
let certs: Vec<CertificateDer<'static>> =
CertificateDer::pem_slice_iter(cert_pem.as_bytes()).collect::<Result<Vec<_>, _>>()?;
if certs.is_empty() {
anyhow::bail!("no certificates found in server cert PEM");
}
let key: PrivateKeyDer<'static> = PrivateKeyDer::from_pem_slice(key_pem.as_bytes())?;
let mut root_store = rustls::RootCertStore::empty();
let ca_certs: Vec<CertificateDer<'static>> =
CertificateDer::pem_slice_iter(ca_cert_pem.as_bytes()).collect::<Result<Vec<_>, _>>()?;
for ca_cert in ca_certs {
root_store.add(ca_cert)?;
}
let client_verifier =
rustls::server::WebPkiClientVerifier::builder(Arc::new(root_store)).build()?;
let config = rustls::ServerConfig::builder()
.with_client_cert_verifier(client_verifier)
.with_single_cert(certs, key)?;
Ok(config)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Self-sign a throwaway certificate whose subject is exactly `dn`, returning its DER bytes.
    fn self_signed_der(dn: rcgen::DistinguishedName) -> Vec<u8> {
        let mut params = rcgen::CertificateParams::default();
        params.distinguished_name = dn;
        let key_pair = rcgen::KeyPair::generate().expect("keygen");
        let cert = params.self_signed(&key_pair).expect("self-sign");
        let der: &[u8] = cert.der().as_ref();
        der.to_vec()
    }

    #[test]
    fn extract_cn_from_self_signed() {
        let mut dn = rcgen::DistinguishedName::new();
        dn.push(rcgen::DnType::CommonName, "test-host");
        let der = self_signed_der(dn);
        assert_eq!(extract_cn(&der), Some("test-host".to_string()));
    }

    #[test]
    fn extract_cn_returns_none_without_cn() {
        let mut dn = rcgen::DistinguishedName::new();
        dn.push(rcgen::DnType::OrganizationName, "koi");
        let der = self_signed_der(dn);
        assert_eq!(extract_cn(&der), None);
    }

    #[test]
    fn extract_cn_returns_none_for_garbage() {
        assert_eq!(extract_cn(b"not a certificate"), None);
    }

    #[test]
    fn build_tls_config_rejects_empty_server_cert_pem() {
        assert!(build_tls_config("", "", "").is_err());
    }
}