use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use axum::Router;
use tokio::net::{TcpListener, TcpStream};
use tokio_util::sync::CancellationToken;
use koi_certmesh::serve::{
serve_mtls, serve_plain, AdaptiveServerConfig, TLS_HANDSHAKE_FIRST_BYTE,
};
use koi_certmesh::CertmeshCore;
use koi_common::posture::Posture;
/// How long `dispatch_connection` waits for a new client's first byte before
/// dropping the connection (10 s).
const PEEK_TIMEOUT: Duration = Duration::from_millis(10_000);
pub async fn serve_adaptive(
core: Arc<CertmeshCore>,
router: Router,
addr: SocketAddr,
cancel: CancellationToken,
) -> std::io::Result<()> {
let listener = TcpListener::bind(addr).await?;
let mut posture_rx = core.watch_posture();
let mut posture = *posture_rx.borrow_and_update();
let mut tls_config = build_tls_config(&core, posture).await;
tracing::info!(%addr, ?posture, "same-port dial: listening");
loop {
tokio::select! {
_ = cancel.cancelled() => return Ok(()),
changed = posture_rx.changed() => {
if changed.is_err() {
return Ok(()); }
posture = *posture_rx.borrow_and_update();
tls_config = build_tls_config(&core, posture).await;
tracing::info!(
?posture,
"same-port dial: posture changed — new connections use the updated protocol"
);
}
accepted = listener.accept() => {
let (tcp, peer) = match accepted {
Ok(v) => v,
Err(e) => {
tracing::warn!(error = %e, "same-port dial: accept error");
continue;
}
};
let router = router.clone();
let cancel_conn = cancel.clone();
let secure = posture.signed;
let cfg = tls_config.clone();
tokio::spawn(async move {
dispatch_connection(tcp, peer, secure, cfg, router, cancel_conn).await;
});
}
}
}
}
/// Build the mTLS server config for `posture`, or return `None` when TLS cannot
/// be offered.
///
/// For a posture that is not `signed`, returns `None` without consulting the
/// core. For a signed posture, it loads the node's local identity and builds an
/// [`AdaptiveServerConfig`] from its `cert_pem`, `key_pem` and `ca_cert_pem`.
///
/// Both failure cases are logged and yield `None`: a missing identity, and a
/// config-build error. With `None` in a secure posture, `dispatch_connection`
/// drops incoming TLS connections.
async fn build_tls_config(core: &CertmeshCore, posture: Posture) -> Option<AdaptiveServerConfig> {
    if !posture.signed {
        return None;
    }
    let Some(id) = core.local_identity().await else {
        // Report the root cause once, at config-build time. Otherwise only the
        // per-connection "no mTLS config available" warnings would appear.
        tracing::error!(
            "same-port dial: secure posture but no local identity — \
             secure connections will be refused"
        );
        return None;
    };
    match AdaptiveServerConfig::from_identity(&id.cert_pem, &id.key_pem, &id.ca_cert_pem) {
        Ok(cfg) => Some(cfg),
        Err(e) => {
            tracing::error!(
                error = %e,
                "same-port dial: secure posture but could not build mTLS config — \
                 secure connections will be refused"
            );
            None
        }
    }
}
/// Route one accepted connection to the plaintext or mTLS server.
///
/// The first byte of the stream is peeked (not consumed), waiting at most
/// [`PEEK_TIMEOUT`]. It is compared with [`TLS_HANDSHAKE_FIRST_BYTE`] to decide
/// whether the client is opening a TLS handshake. Outcomes:
///
/// | `secure` | client sent TLS | outcome                                      |
/// |----------|-----------------|----------------------------------------------|
/// | false    | no              | `serve_plain`                                |
/// | false    | yes             | dropped (warn: Open node cannot terminate TLS) |
/// | true     | no              | dropped (warn: mTLS required)                |
/// | true     | yes             | `serve_mtls` if `cfg` is `Some`, else dropped (warn) |
///
/// A peek error or timeout is logged at debug level and drops the connection.
/// A peek that returns zero bytes drops the connection silently.
async fn dispatch_connection(
    tcp: TcpStream,
    peer: SocketAddr,
    secure: bool,
    cfg: Option<AdaptiveServerConfig>,
    router: Router,
    cancel: CancellationToken,
) {
    let mut probe = [0u8; 1];
    let peeked = tokio::time::timeout(PEEK_TIMEOUT, tcp.peek(&mut probe)).await;
    let available = match peeked {
        Err(_) => {
            tracing::debug!(%peer, "same-port dial: peek timed out");
            return;
        }
        Ok(Err(e)) => {
            tracing::debug!(%peer, error = %e, "same-port dial: peek failed");
            return;
        }
        Ok(Ok(available)) => available,
    };
    // Nothing to classify: the peer sent no bytes.
    if available == 0 {
        return;
    }
    let wants_tls = probe[0] == TLS_HANDSHAKE_FIRST_BYTE;

    // Open node: only plaintext is acceptable.
    if !secure {
        if wants_tls {
            tracing::warn!(
                %peer,
                "same-port dial: refused a TLS connection to an Open node (no identity to terminate TLS)"
            );
        } else {
            serve_plain(tcp, router, cancel).await;
        }
        return;
    }

    // Secure node: only mTLS is acceptable.
    if !wants_tls {
        tracing::warn!(
            %peer,
            "same-port dial: refused a plaintext connection to a secure node (mTLS required)"
        );
        return;
    }
    if let Some(cfg) = cfg {
        serve_mtls(tcp, cfg, router, cancel).await;
    } else {
        tracing::warn!(
            %peer,
            "same-port dial: secure posture but no mTLS config available — dropping TLS connection"
        );
    }
}
// End-to-end tests for `serve_adaptive`. Each test starts the real server on a
// loopback port and dials it with the `koi_certmesh::mtls` client helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use axum::extract::Extension;
    use axum::routing::{get, post};
    use koi_certmesh::http::ClientCn;
    use koi_certmesh::{ca, roster::Roster, CertmeshCore, CertmeshPaths};

    /// Create a fresh, empty data directory for `tag` under the system temp dir,
    /// keyed by `tag` and the test process id.
    fn isolated_paths(tag: &str) -> CertmeshPaths {
        let dir = std::env::temp_dir().join(format!("koi-emb-serve-{tag}-{}", std::process::id()));
        // Remove leftovers from a previous run that reused the same name.
        let _ = std::fs::remove_dir_all(&dir);
        std::fs::create_dir_all(&dir).unwrap();
        CertmeshPaths::with_data_dir(dir)
    }

    /// An uninitialized core (no CA created). The tests below treat it as an
    /// Open node that serves plaintext only.
    fn open_core(tag: &str) -> Arc<CertmeshCore> {
        Arc::new(CertmeshCore::uninitialized_with_paths(isolated_paths(tag)))
    }

    /// A core that has created a CA and self-enrolled, so its posture is `signed`
    /// and it holds a local identity (asserted before returning).
    async fn secure_core(tag: &str) -> Arc<CertmeshCore> {
        // NOTE(review): presumably keeps CA secrets out of the OS credential store
        // during tests — confirm. Process-wide, and racy with parallel tests.
        std::env::set_var("KOI_NO_CREDENTIAL_STORE", "1");
        let paths = isolated_paths(tag);
        let ca = ca::create_ca("test-pass", &[3u8; 32], &paths).unwrap().0;
        let roster = Roster::new(false, false, None);
        let core = CertmeshCore::new_with_paths(ca, roster, None, paths);
        core.self_enroll().await.expect("self-enroll");
        assert!(core.posture().signed, "core should be secure");
        Arc::new(core)
    }

    /// `GET /ping` → `"pong"`; needs no client identity.
    fn plain_router() -> Router {
        Router::new().route("/ping", get(|| async { "pong" }))
    }

    /// `POST /echo` → the `ClientCn` extension, i.e. the CN the mTLS layer
    /// extracted from the client's certificate.
    fn cn_router() -> Router {
        Router::new().route(
            "/echo",
            post(|Extension(ClientCn(cn)): Extension<ClientCn>| async move { cn }),
        )
    }

    /// Both routes, so one server can be probed before and after a posture flip.
    fn combined_router() -> Router {
        Router::new()
            .route("/ping", get(|| async { "pong" }))
            .route(
                "/echo",
                post(|Extension(ClientCn(cn)): Extension<ClientCn>| async move { cn }),
            )
    }

    /// Reserve an ephemeral loopback port and return its address.
    ///
    /// The probe listener is dropped on return so `serve_adaptive` can bind the
    /// same address. Another process could grab the port in between, which is a
    /// small source of flakiness.
    async fn bind_addr() -> SocketAddr {
        let l = TcpListener::bind(("127.0.0.1", 0)).await.unwrap();
        l.local_addr().unwrap()
    }

    #[tokio::test]
    async fn open_node_serves_plaintext() {
        let core = open_core("open-plain");
        let addr = bind_addr().await;
        let cancel = CancellationToken::new();
        let server = tokio::spawn(serve_adaptive(core, plain_router(), addr, cancel.clone()));
        // Give the spawned server time to bind before dialing.
        tokio::time::sleep(Duration::from_millis(50)).await;
        // `mtls::get` takes no certificate arguments: a plaintext request.
        let (status, body) = koi_certmesh::mtls::get(&addr.ip().to_string(), addr.port(), "/ping")
            .await
            .expect("plain GET to Open node");
        assert_eq!(status, 200);
        assert_eq!(body, "pong");
        cancel.cancel();
        let _ = server.await;
    }

    #[tokio::test]
    async fn secure_node_serves_mtls() {
        let core = secure_core("secure-mtls").await;
        let id = core.local_identity().await.expect("identity");
        let addr = bind_addr().await;
        let cancel = CancellationToken::new();
        let server = tokio::spawn(serve_adaptive(
            Arc::clone(&core),
            cn_router(),
            addr,
            cancel.clone(),
        ));
        tokio::time::sleep(Duration::from_millis(50)).await;
        // The client presents the server node's own identity. The echoed CN must
        // equal that identity's hostname.
        let (status, body) = koi_certmesh::mtls::post_json(
            &addr.ip().to_string(),
            addr.port(),
            "/echo",
            "{}",
            &id.cert_pem,
            &id.key_pem,
            &id.ca_cert_pem,
        )
        .await
        .expect("mTLS POST to secure node");
        assert_eq!(status, 200);
        assert_eq!(body, id.hostname, "the server authenticated our leaf CN");
        cancel.cancel();
        let _ = server.await;
    }

    #[tokio::test]
    async fn secure_node_refuses_plaintext() {
        let core = secure_core("secure-refuse-plain").await;
        let addr = bind_addr().await;
        let cancel = CancellationToken::new();
        let server = tokio::spawn(serve_adaptive(core, plain_router(), addr, cancel.clone()));
        tokio::time::sleep(Duration::from_millis(50)).await;
        let result = koi_certmesh::mtls::get(&addr.ip().to_string(), addr.port(), "/ping").await;
        // A dropped connection (error) or any non-200 status counts as refusal.
        assert!(
            result.is_err() || result.as_ref().unwrap().0 != 200,
            "secure node must refuse plaintext; got {result:?}"
        );
        cancel.cancel();
        let _ = server.await;
    }

    #[tokio::test]
    async fn open_node_refuses_tls() {
        let core = open_core("open-refuse-tls");
        // A separate secure core only supplies valid client credentials.
        let client = secure_core("open-refuse-tls-client").await;
        let id = client.local_identity().await.unwrap();
        let addr = bind_addr().await;
        let cancel = CancellationToken::new();
        let server = tokio::spawn(serve_adaptive(core, cn_router(), addr, cancel.clone()));
        tokio::time::sleep(Duration::from_millis(50)).await;
        let result = koi_certmesh::mtls::post_json(
            &addr.ip().to_string(),
            addr.port(),
            "/echo",
            "{}",
            &id.cert_pem,
            &id.key_pem,
            &id.ca_cert_pem,
        )
        .await;
        assert!(result.is_err(), "Open node must refuse TLS; got {result:?}");
        cancel.cancel();
        let _ = server.await;
    }

    /// Start Open, create a CA through the core while the server keeps running,
    /// and check that the same listener switches from plaintext to mTLS.
    #[tokio::test]
    async fn live_flip_open_to_secure_without_restart() {
        std::env::set_var("KOI_NO_CREDENTIAL_STORE", "1");
        let paths = isolated_paths("flip");
        let core = Arc::new(CertmeshCore::uninitialized_with_paths(paths));
        let addr = bind_addr().await;
        let cancel = CancellationToken::new();
        let server = tokio::spawn(serve_adaptive(
            Arc::clone(&core),
            combined_router(),
            addr,
            cancel.clone(),
        ));
        tokio::time::sleep(Duration::from_millis(50)).await;
        let (status, body) = koi_certmesh::mtls::get(&addr.ip().to_string(), addr.port(), "/ping")
            .await
            .expect("plain works while Open");
        assert_eq!(status, 200);
        assert_eq!(body, "pong");
        let req = koi_certmesh::protocol::CreateCaRequest {
            passphrase: "test-pass-strong".to_string(),
            entropy_hex: koi_common::encoding::hex_encode(&[8u8; 32]),
            operator: None,
            enrollment_open: false,
            requires_approval: false,
            auto_unlock: false,
            totp_secret_hex: None,
        };
        core.create(req).await.expect("create CA");
        // Let the server's posture watcher observe the change and rebuild its
        // TLS config. Timing-based: a slow machine could make this flaky.
        tokio::time::sleep(Duration::from_millis(250)).await;
        let plain = koi_certmesh::mtls::get(&addr.ip().to_string(), addr.port(), "/ping").await;
        assert!(
            plain.is_err() || plain.as_ref().unwrap().0 != 200,
            "plaintext must be refused after the flip; got {plain:?}"
        );
        let id = core.local_identity().await.expect("identity after create");
        let (status, body) = koi_certmesh::mtls::post_json(
            &addr.ip().to_string(),
            addr.port(),
            "/echo",
            "{}",
            &id.cert_pem,
            &id.key_pem,
            &id.ca_cert_pem,
        )
        .await
        .expect("mTLS works after the flip");
        assert_eq!(status, 200);
        assert_eq!(body, id.hostname);
        cancel.cancel();
        let _ = server.await;
    }
}