use axum::extract::Extension;
use axum::Router;
use hyper_util::rt::{TokioExecutor, TokioIo};
use hyper_util::server::conn::auto::Builder;
use std::sync::Arc;
use tokio::net::TcpStream;
use tokio_rustls::TlsAcceptor;
use tokio_util::sync::CancellationToken;
use crate::error::CertmeshError;
use crate::http::ClientCn;
use crate::mtls;
/// First byte of a TLS record that carries a handshake message.
///
/// This is the record `ContentType` value `handshake(22)` (hex `0x16`) from RFC 8446 §5.1.
///
/// NOTE(review): presumably compared against the first byte of a new connection to decide
/// between `serve_mtls` and `serve_plain`; the caller is not visible here, so confirm.
pub const TLS_HANDSHAKE_FIRST_BYTE: u8 = 22;
/// Cheaply clonable handle to the rustls server configuration used by [`serve_mtls`].
///
/// The inner config sits behind an [`Arc`], so `clone()` only bumps a reference count.
#[derive(Clone)]
pub struct AdaptiveServerConfig(Arc<rustls::ServerConfig>);

impl AdaptiveServerConfig {
    /// Builds the server configuration from PEM-encoded identity material.
    ///
    /// All three inputs are forwarded unchanged to `mtls::build_server_config`:
    ///
    /// * `cert_pem` – certificate (chain) presumably presented by this server.
    /// * `key_pem` – the private key presumably matching `cert_pem`.
    /// * `ca_cert_pem` – CA certificate; presumably the trust anchor for client
    ///   certificates (TODO confirm against `mtls::build_server_config`).
    ///
    /// # Errors
    ///
    /// Returns whatever error `mtls::build_server_config` reports, converted into
    /// [`CertmeshError`] by `?`.
    pub fn from_identity(
        cert_pem: &str,
        key_pem: &str,
        ca_cert_pem: &str,
    ) -> Result<Self, CertmeshError> {
        let server_config = mtls::build_server_config(cert_pem, key_pem, ca_cert_pem)?;
        Ok(Self(Arc::new(server_config)))
    }
}

impl std::fmt::Debug for AdaptiveServerConfig {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Emit a fixed placeholder rather than the inner rustls config.
        // That config is built from the private key given to `from_identity`.
        formatter.write_str("AdaptiveServerConfig(<rustls::ServerConfig>)")
    }
}
/// Serves `router` over an unencrypted TCP connection.
///
/// The connection is driven by hyper-util's `auto` connection builder, with HTTP upgrades
/// enabled. The function returns when the connection ends or when `cancel` fires. On
/// cancellation the connection future is simply dropped; there is no graceful drain.
/// Connection errors are logged at `debug` level and otherwise ignored.
pub async fn serve_plain(tcp: TcpStream, router: Router, cancel: CancellationToken) {
    // The connection future below borrows the builder, so it must outlive that future.
    let conn_builder = Builder::new(TokioExecutor::new());
    let service = hyper_util::service::TowerToHyperService::new(router);
    let connection = conn_builder.serve_connection_with_upgrades(TokioIo::new(tcp), service);

    tokio::select! {
        _ = cancel.cancelled() => {}
        outcome = connection => {
            if let Err(err) = outcome {
                tracing::debug!(error = %err, "plain connection error");
            }
        }
    }
}
/// Performs a mutual-TLS handshake on `tcp`, then serves `router` over the encrypted
/// stream.
///
/// The client's certificate CN is attached to every request as an `Extension(ClientCn)`.
///
/// The function returns, dropping the connection, when any of these happens:
/// - `cancel` fires, either during the TLS handshake or while serving;
/// - the TLS handshake fails (logged at `debug`);
/// - the peer's first certificate is absent, or `mtls::extract_cn` yields no CN
///   (logged at `warn`);
/// - the HTTP connection ends (errors logged at `debug`).
///
/// On cancellation while serving, the connection future is dropped; there is no
/// graceful drain.
pub async fn serve_mtls(
    tcp: TcpStream,
    config: AdaptiveServerConfig,
    router: Router,
    cancel: CancellationToken,
) {
    let acceptor = TlsAcceptor::from(config.0);

    // Race the handshake against cancellation. Otherwise a peer that stalls mid-handshake
    // would keep this task alive after shutdown was requested.
    // NOTE(review): there is still no handshake timeout, so a stalled peer holds the
    // task until cancellation. Consider `tokio::time::timeout` around `accept`.
    let handshake = tokio::select! {
        res = acceptor.accept(tcp) => res,
        _ = cancel.cancelled() => return,
    };
    let tls = match handshake {
        Ok(stream) => stream,
        Err(e) => {
            tracing::debug!(error = %e, "mTLS handshake failed");
            return;
        }
    };

    // Identify the client by the CN of the first certificate it presented.
    // rustls orders the chain so that the first entry is the peer's own certificate.
    let cn = tls
        .get_ref()
        .1
        .peer_certificates()
        .and_then(|certs| certs.first())
        .and_then(|cert| mtls::extract_cn(cert.as_ref()));
    let Some(cn) = cn else {
        tracing::warn!("no CN in client certificate; dropping connection");
        return;
    };

    // Scope the client identity to this connection's copy of the router.
    let svc = router.layer(Extension(ClientCn(cn)));
    let io = TokioIo::new(tls);
    let hyper_svc = hyper_util::service::TowerToHyperService::new(svc);
    let builder = Builder::new(TokioExecutor::new());
    tokio::select! {
        res = builder.serve_connection_with_upgrades(io, hyper_svc) => {
            if let Err(e) = res {
                tracing::debug!(error = %e, "mTLS connection error");
            }
        }
        _ = cancel.cancelled() => {}
    }
}