use std::future::Future;
#[cfg(not(feature = "compio"))]
use std::path::PathBuf;
#[cfg(not(feature = "compio"))]
use std::pin::Pin;
#[cfg(any(not(feature = "compio"), feature = "tls"))]
use std::sync::Arc;
use std::time::Duration;
use tako_rs_core::router::Router;
#[cfg(not(feature = "compio"))]
use tokio::net::TcpListener;
use crate::ServerConfig;
/// Handle to a server task started by one of the `spawn_*` methods.
///
/// It can request a graceful shutdown (`trigger`) and wait for the spawned task to
/// finish (`join`).
pub struct ServerHandle {
    /// Cancelled to request shutdown; the future returned by `make_handle` waits on it.
    shutdown: tokio_util::sync::CancellationToken,
    /// Cancelled by `spawn_done` / `spawn_done_compio` once the server future has run.
    done: tokio_util::sync::CancellationToken,
    /// Drain timeout taken from the server's `ServerConfig` when the handle was made.
    drain_timeout: Duration,
}
impl ServerHandle {
    /// Requests a graceful shutdown and returns immediately.
    ///
    /// This only cancels the shutdown token the server loop is watching; use
    /// [`ServerHandle::join`] to wait for the task to actually exit.
    pub fn trigger(&self) {
        self.shutdown.cancel();
    }

    /// Waits until the spawned server task has finished.
    pub async fn join(&self) {
        self.done.cancelled().await;
    }

    /// Requests a graceful shutdown, then waits for the server task to finish.
    ///
    /// NOTE(review): `_timeout` is currently not enforced here — this waits for as long
    /// as the server task takes. Any bound on draining presumably comes from the
    /// server's own `drain_timeout` configuration; confirm before relying on it.
    pub async fn shutdown(self, _timeout: Duration) {
        self.trigger();
        self.join().await;
    }

    /// Returns the drain timeout that was recorded when this handle was created.
    #[inline]
    pub fn drain_timeout(&self) -> Duration {
        self.drain_timeout
    }
}
impl std::fmt::Debug for ServerHandle {
    /// Shows only the drain timeout; the internal cancellation tokens are elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("ServerHandle");
        out.field("drain_timeout", &self.drain_timeout);
        out.finish_non_exhaustive()
    }
}
/// Drives `a` and `b` concurrently and returns as soon as either one completes.
///
/// `a` is polled before `b` on every wake-up, so `a` wins when both are ready at the
/// same time. Whichever future did not finish is dropped when this returns.
pub async fn either<A, B>(a: A, b: B)
where
    A: Future<Output = ()>,
    B: Future<Output = ()>,
{
    let mut first = std::pin::pin!(a);
    let mut second = std::pin::pin!(b);
    std::future::poll_fn(|cx| {
        if first.as_mut().poll(cx).is_ready() {
            return std::task::Poll::Ready(());
        }
        second.as_mut().poll(cx)
    })
    .await;
}
/// Client-certificate (mutual TLS) policy for a TLS listener.
///
/// Both variants carry the root store that `build_rustls_server_config` hands to
/// `rustls::server::WebPkiClientVerifier::builder`.
#[cfg(feature = "tls")]
#[derive(Clone)]
pub enum ClientAuth {
    /// Client certificates are verified against the roots, but clients that present
    /// none are still accepted (the verifier is built with `allow_unauthenticated`).
    Optional(Arc<rustls::RootCertStore>),
    /// Clients must present a certificate that verifies against the roots.
    Required(Arc<rustls::RootCertStore>),
}
#[cfg(feature = "tls")]
impl std::fmt::Debug for ClientAuth {
    /// Prints the variant name; the root store itself is shown as a placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let variant = match self {
            ClientAuth::Optional(_) => "Optional",
            ClientAuth::Required(_) => "Required",
        };
        f.debug_tuple(variant).field(&"<root_store>").finish()
    }
}
/// Certificate material (plus, with the `tls` feature, an optional client-auth
/// policy) for a TLS listener.
///
/// Without the `tls` feature only `PemPaths` exists and it carries no client-auth
/// policy.
#[derive(Clone)]
pub enum TlsCert {
    /// Paths to a certificate chain file and a private key file on disk.
    PemPaths {
        cert_path: String,
        key_path: String,
        #[cfg(feature = "tls")]
        client_auth: Option<ClientAuth>,
    },
    /// An in-memory DER certificate chain and private key, each behind an `Arc`.
    #[cfg(feature = "tls")]
    Der {
        certs: Arc<Vec<rustls::pki_types::CertificateDer<'static>>>,
        key: Arc<rustls::pki_types::PrivateKeyDer<'static>>,
        client_auth: Option<ClientAuth>,
    },
    /// A caller-supplied rustls certificate resolver (e.g. `ReloadableResolver`).
    #[cfg(feature = "tls")]
    Resolver {
        resolver: Arc<dyn rustls::server::ResolvesServerCert>,
        client_auth: Option<ClientAuth>,
    },
}
impl std::fmt::Debug for TlsCert {
    /// Shows paths for `PemPaths` and the client-auth policy for the other variants;
    /// key material and resolvers are never printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TlsCert::PemPaths {
                cert_path,
                key_path,
                ..
            } => {
                let mut out = f.debug_struct("PemPaths");
                out.field("cert_path", cert_path);
                out.field("key_path", key_path);
                out.finish_non_exhaustive()
            }
            #[cfg(feature = "tls")]
            TlsCert::Der { client_auth, .. } => {
                let mut out = f.debug_struct("Der");
                out.field("client_auth", client_auth);
                out.finish_non_exhaustive()
            }
            #[cfg(feature = "tls")]
            TlsCert::Resolver { client_auth, .. } => {
                let mut out = f.debug_struct("Resolver");
                out.field("client_auth", client_auth);
                out.finish_non_exhaustive()
            }
        }
    }
}
impl TlsCert {
    /// Builds a [`TlsCert::PemPaths`] from a certificate-chain path and a private-key
    /// path, with no client-auth policy.
    pub fn pem_paths(cert: impl Into<String>, key: impl Into<String>) -> Self {
        let cert_path: String = cert.into();
        let key_path: String = key.into();
        Self::PemPaths {
            cert_path,
            key_path,
            #[cfg(feature = "tls")]
            client_auth: None,
        }
    }

    /// Like [`TlsCert::pem_paths`], but also attaches a client-certificate policy.
    #[cfg(feature = "tls")]
    pub fn pem_paths_with_client_auth(
        cert: impl Into<String>,
        key: impl Into<String>,
        client_auth: ClientAuth,
    ) -> Self {
        Self::pem_paths(cert, key).with_client_auth(client_auth)
    }

    /// Wraps an in-memory DER chain and key, with no client-auth policy.
    #[cfg(feature = "tls")]
    pub fn der(
        certs: Vec<rustls::pki_types::CertificateDer<'static>>,
        key: rustls::pki_types::PrivateKeyDer<'static>,
    ) -> Self {
        let certs = Arc::new(certs);
        let key = Arc::new(key);
        Self::Der {
            certs,
            key,
            client_auth: None,
        }
    }

    /// Uses a custom rustls certificate resolver, with no client-auth policy.
    #[cfg(feature = "tls")]
    pub fn resolver(resolver: Arc<dyn rustls::server::ResolvesServerCert>) -> Self {
        Self::Resolver {
            client_auth: None,
            resolver,
        }
    }

    /// Replaces the client-auth policy of whichever variant `self` is.
    #[cfg(feature = "tls")]
    pub fn with_client_auth(mut self, auth: ClientAuth) -> Self {
        let slot = match &mut self {
            TlsCert::PemPaths { client_auth, .. }
            | TlsCert::Der { client_auth, .. }
            | TlsCert::Resolver { client_auth, .. } => client_auth,
        };
        *slot = Some(auth);
        self
    }
}
/// A rustls certificate resolver whose `CertifiedKey` can be replaced at runtime.
///
/// Every handshake is served whatever key is currently stored, regardless of the
/// client hello; `reload` / `reload_from_pem` swap in a new one.
#[cfg(feature = "tls")]
pub struct ReloadableResolver {
    // Current certificate + signer, swapped atomically via `ArcSwap`.
    current: arc_swap::ArcSwap<rustls::sign::CertifiedKey>,
}
#[cfg(feature = "tls")]
impl std::fmt::Debug for ReloadableResolver {
    /// Prints only the type name; the stored key material is never shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("ReloadableResolver");
        out.finish_non_exhaustive()
    }
}
#[cfg(feature = "tls")]
impl ReloadableResolver {
    /// Creates a resolver from a PEM certificate chain and private key on disk.
    ///
    /// # Errors
    /// Fails if either file cannot be loaded or the key cannot be turned into a signer.
    pub fn from_pem(cert_path: &str, key_path: &str) -> anyhow::Result<Self> {
        build_certified_key(cert_path, key_path).map(|ck| Self {
            current: arc_swap::ArcSwap::from_pointee(ck),
        })
    }

    /// Loads a new PEM certificate chain and key and makes them the active pair.
    ///
    /// On error the previously active key stays in place.
    ///
    /// # Errors
    /// Same conditions as [`ReloadableResolver::from_pem`].
    pub fn reload_from_pem(&self, cert_path: &str, key_path: &str) -> anyhow::Result<()> {
        let fresh = build_certified_key(cert_path, key_path)?;
        self.reload(fresh);
        Ok(())
    }

    /// Makes `ck` the key served to all subsequent handshakes.
    pub fn reload(&self, ck: rustls::sign::CertifiedKey) {
        self.current.store(Arc::new(ck));
    }
}
#[cfg(feature = "tls")]
impl rustls::server::ResolvesServerCert for ReloadableResolver {
    /// Always returns the currently stored key; the client hello is not consulted.
    fn resolve(
        &self,
        _client_hello: rustls::server::ClientHello<'_>,
    ) -> Option<Arc<rustls::sign::CertifiedKey>> {
        let active = self.current.load_full();
        Some(active)
    }
}
/// Loads a certificate chain and private key from disk and turns them into a
/// [`rustls::sign::CertifiedKey`] using the process-default rustls `CryptoProvider`.
///
/// The first time this runs in a process, it installs aws-lc-rs as the default provider
/// if none is installed yet. If some other provider was already the default at that
/// point (or won a concurrent install), a single warning is logged and that provider
/// is used.
///
/// # Errors
/// Fails if the certificate or key file cannot be loaded, if no provider is available,
/// or if the provider cannot turn the key into a signer.
#[cfg(feature = "tls")]
fn build_certified_key(
    cert_path: &str,
    key_path: &str,
) -> anyhow::Result<rustls::sign::CertifiedKey> {
    let certs = tako_rs_core::tls::load_certs(cert_path)?;
    let key = tako_rs_core::tls::load_key(key_path)?;

    // Provider ownership is decided once per process. Re-checking on every call would
    // see the aws-lc-rs provider Tako itself installed on an earlier call (e.g. before
    // a `reload_from_pem`) and wrongly warn that a foreign provider was pre-installed.
    static PROVIDER_INIT: std::sync::Once = std::sync::Once::new();
    PROVIDER_INIT.call_once(|| {
        let we_installed = rustls::crypto::CryptoProvider::get_default().is_none()
            && rustls::crypto::aws_lc_rs::default_provider()
                .install_default()
                .is_ok();
        if !we_installed {
            tracing::warn!(
                "tako-server: a rustls CryptoProvider was already installed before \
                 `build_certified_key` ran — Tako will use that provider for key \
                 loading instead of installing aws-lc-rs. If signing behavior is \
                 not what you expect (e.g. h3 installed `ring` first), pin the \
                 provider at process startup with `rustls::crypto::aws_lc_rs::\
                 default_provider().install_default()` BEFORE constructing the \
                 server."
            );
        }
    });

    let provider = rustls::crypto::CryptoProvider::get_default().ok_or_else(|| {
        anyhow::anyhow!(
            "no rustls CryptoProvider installed — enable rustls's `aws_lc_rs` or `ring` feature"
        )
    })?;
    let signer = provider
        .key_provider
        .load_private_key(key)
        .map_err(|e| anyhow::anyhow!("failed to load signing key from '{key_path}': {e}"))?;
    Ok(rustls::sign::CertifiedKey::new(certs, signer))
}
/// Builds a rustls server configuration from `cert` and the given ALPN protocol list.
///
/// The client-auth policy stored in `cert` selects the client verifier (none,
/// optional, or required mutual TLS). PEM paths are read from disk, DER material is
/// cloned, and a resolver is shared as-is. When `alpn` contains `h3`, early data is
/// pinned to zero.
///
/// # Errors
/// Fails if the client verifier cannot be built, if PEM files cannot be loaded, or if
/// rustls rejects the certificate/key pair.
#[cfg(feature = "tls")]
pub fn build_rustls_server_config(
    cert: &TlsCert,
    alpn: Vec<Vec<u8>>,
) -> anyhow::Result<Arc<rustls::ServerConfig>> {
    let (TlsCert::PemPaths { client_auth, .. }
    | TlsCert::Der { client_auth, .. }
    | TlsCert::Resolver { client_auth, .. }) = cert;

    let base = rustls::ServerConfig::builder();
    let with_verifier = match client_auth.clone() {
        None => base.with_no_client_auth(),
        Some(policy) => {
            let (roots, optional) = match policy {
                ClientAuth::Optional(roots) => (roots, true),
                ClientAuth::Required(roots) => (roots, false),
            };
            let mut verifier_builder = rustls::server::WebPkiClientVerifier::builder(roots);
            if optional {
                verifier_builder = verifier_builder.allow_unauthenticated();
            }
            let verifier = verifier_builder
                .build()
                .map_err(|e| anyhow::anyhow!("WebPkiClientVerifier build failed: {e}"))?;
            base.with_client_cert_verifier(verifier)
        }
    };

    let mut config = match cert {
        TlsCert::PemPaths {
            cert_path,
            key_path,
            ..
        } => {
            let chain = tako_rs_core::tls::load_certs(cert_path)?;
            let private_key = tako_rs_core::tls::load_key(key_path)?;
            with_verifier
                .with_single_cert(chain, private_key)
                .map_err(|e| anyhow::anyhow!("rustls config build failed: {e}"))?
        }
        TlsCert::Der { certs, key, .. } => with_verifier
            .with_single_cert((**certs).clone(), key.as_ref().clone_key())
            .map_err(|e| anyhow::anyhow!("rustls config build failed: {e}"))?,
        TlsCert::Resolver { resolver, .. } => {
            with_verifier.with_cert_resolver(Arc::clone(resolver))
        }
    };

    let offers_h3 = alpn.iter().any(|proto| proto.as_slice() == b"h3");
    config.alpn_protocols = alpn;
    if offers_h3 {
        // 0-RTT stays disabled whenever HTTP/3 is among the offered protocols.
        config.max_early_data_size = 0;
    }
    Ok(Arc::new(config))
}
/// Builder for [`Server`]: collects a [`ServerConfig`] and optional [`TlsCert`].
#[cfg(not(feature = "compio"))]
#[derive(Debug, Default, Clone)]
pub struct ServerBuilder {
    /// Server configuration; defaults to `ServerConfig::default()`.
    config: ServerConfig,
    /// TLS material needed by `Server::spawn_tls` / `Server::spawn_h3`; `None` by default.
    tls: Option<TlsCert>,
}
#[cfg(not(feature = "compio"))]
impl ServerBuilder {
    /// Replaces the server configuration.
    #[must_use]
    pub fn config(self, config: ServerConfig) -> Self {
        Self { config, ..self }
    }

    /// Sets the TLS material used by the TLS / HTTP/3 spawn methods.
    #[must_use]
    pub fn tls(self, cert: TlsCert) -> Self {
        Self {
            tls: Some(cert),
            ..self
        }
    }

    /// Finishes the builder, producing a [`Server`].
    pub fn build(self) -> Server {
        let Self { config, tls } = self;
        Server { config, tls }
    }
}
/// A configured (Tokio-based) server whose `spawn_*` methods start listener tasks.
#[cfg(not(feature = "compio"))]
#[derive(Debug, Clone)]
pub struct Server {
    /// Configuration cloned into each listener task started by the `spawn_*` methods.
    config: ServerConfig,
    /// TLS material for `spawn_tls` / `spawn_h3`; unread when neither feature is enabled.
    #[cfg_attr(not(any(feature = "tls", feature = "http3")), allow(dead_code))]
    tls: Option<TlsCert>,
}
#[cfg(not(feature = "compio"))]
impl Server {
    /// Returns a [`ServerBuilder`] with the default configuration and no TLS material.
    #[must_use]
    pub fn builder() -> ServerBuilder {
        ServerBuilder::default()
    }

    /// Borrows the configuration this server was built with.
    #[inline]
    pub fn config(&self) -> &ServerConfig {
        &self.config
    }

    /// Serves plain HTTP on `listener` in a background Tokio task.
    pub fn spawn_http(&self, listener: TcpListener, router: Router) -> ServerHandle {
        let (server_handle, stop) = make_handle(self.config.drain_timeout);
        let cfg = self.config.clone();
        spawn_done(server_handle.done.clone(), async move {
            crate::server::serve_with_shutdown_and_config(listener, router, stop, cfg).await;
        });
        server_handle
    }

    /// Serves cleartext HTTP/2 (h2c) on `listener` in a background Tokio task.
    #[cfg(feature = "http2")]
    pub fn spawn_h2c(&self, listener: TcpListener, router: Router) -> ServerHandle {
        let (server_handle, stop) = make_handle(self.config.drain_timeout);
        let cfg = self.config.clone();
        spawn_done(server_handle.done.clone(), async move {
            crate::server_h2c::serve_h2c_with_shutdown_and_config(listener, router, stop, cfg)
                .await;
        });
        server_handle
    }

    /// Serves HTTPS on `listener` in a background Tokio task.
    ///
    /// A PEM-path certificate without client auth is handed straight to the TLS serve
    /// loop; every other [`TlsCert`] is first turned into a rustls config. If building
    /// that config fails, the error is logged and the task ends immediately.
    ///
    /// # Panics
    /// Panics if no [`TlsCert`] was configured on the builder.
    #[cfg(feature = "tls")]
    pub fn spawn_tls(&self, listener: TcpListener, router: Router) -> ServerHandle {
        let Some(tls) = self.tls.clone() else {
            panic!("Server::spawn_tls requires a TlsCert (use builder().tls(...))");
        };
        let (server_handle, stop) = make_handle(self.config.drain_timeout);
        let cfg = self.config.clone();
        let alpn = tls_alpn_for_tcp();
        spawn_done(server_handle.done.clone(), async move {
            match &tls {
                TlsCert::PemPaths {
                    cert_path,
                    key_path,
                    client_auth: None,
                } => {
                    crate::server_tls::serve_tls_with_shutdown_and_config(
                        listener,
                        router,
                        Some(cert_path.as_str()),
                        Some(key_path.as_str()),
                        stop,
                        cfg,
                    )
                    .await;
                }
                _ => match build_rustls_server_config(&tls, alpn) {
                    Ok(rustls_cfg) => {
                        crate::server_tls::serve_tls_with_rustls_config_and_shutdown(
                            listener, router, rustls_cfg, stop, cfg,
                        )
                        .await;
                    }
                    Err(e) => {
                        tracing::error!("Server::spawn_tls: failed to build rustls config: {e}");
                    }
                },
            }
        });
        server_handle
    }

    /// Serves HTTP/3 on `addr` in a background Tokio task.
    ///
    /// Mirrors [`Server::spawn_tls`]: plain PEM paths go straight to the HTTP/3 serve
    /// loop, anything else becomes a rustls config advertising only `h3`.
    ///
    /// # Panics
    /// Panics if no [`TlsCert`] was configured on the builder.
    #[cfg(feature = "http3")]
    pub fn spawn_h3(&self, addr: impl Into<String>, router: Router) -> ServerHandle {
        let Some(tls) = self.tls.clone() else {
            panic!("Server::spawn_h3 requires a TlsCert (use builder().tls(...))");
        };
        let bind_addr: String = addr.into();
        let (server_handle, stop) = make_handle(self.config.drain_timeout);
        let cfg = self.config.clone();
        spawn_done(server_handle.done.clone(), async move {
            match &tls {
                TlsCert::PemPaths {
                    cert_path,
                    key_path,
                    client_auth: None,
                } => {
                    crate::server_h3::serve_h3_with_shutdown_and_config(
                        router,
                        &bind_addr,
                        Some(cert_path.as_str()),
                        Some(key_path.as_str()),
                        stop,
                        cfg,
                    )
                    .await;
                }
                _ => match build_rustls_server_config(&tls, vec![b"h3".to_vec()]) {
                    Ok(rustls_cfg) => {
                        crate::server_h3::serve_h3_with_rustls_config_and_shutdown(
                            router, &bind_addr, rustls_cfg, stop, cfg,
                        )
                        .await;
                    }
                    Err(e) => {
                        tracing::error!("Server::spawn_h3: failed to build rustls config: {e}");
                    }
                },
            }
        });
        server_handle
    }

    /// Serves HTTP over a Unix domain socket at `path` in a background Tokio task.
    #[cfg(unix)]
    pub fn spawn_unix_http(&self, path: impl Into<PathBuf>, router: Router) -> ServerHandle {
        let socket_path: PathBuf = path.into();
        let (server_handle, stop) = make_handle(self.config.drain_timeout);
        let cfg = self.config.clone();
        spawn_done(server_handle.done.clone(), async move {
            crate::server_unix::serve_unix_http_with_shutdown_and_config(
                socket_path,
                router,
                stop,
                cfg,
            )
            .await;
        });
        server_handle
    }

    /// Serves HTTP over vsock (`cid`, `port`) in a background Tokio task.
    #[cfg(all(target_os = "linux", feature = "vsock"))]
    pub fn spawn_vsock_http(&self, cid: u32, port: u32, router: Router) -> ServerHandle {
        let (server_handle, stop) = make_handle(self.config.drain_timeout);
        let cfg = self.config.clone();
        spawn_done(server_handle.done.clone(), async move {
            crate::server_vsock::serve_vsock_http_with_shutdown_and_config(
                cid, port, router, stop, cfg,
            )
            .await;
        });
        server_handle
    }

    /// Serves HTTP behind a PROXY-protocol front on `listener` in a background task.
    pub fn spawn_proxy_protocol(&self, listener: TcpListener, router: Router) -> ServerHandle {
        let (server_handle, stop) = make_handle(self.config.drain_timeout);
        let cfg = self.config.clone();
        spawn_done(server_handle.done.clone(), async move {
            crate::proxy_protocol::serve_http_with_proxy_protocol_shutdown_and_config(
                listener, router, stop, cfg,
            )
            .await;
        });
        server_handle
    }

    /// Runs a raw TCP server on `addr`, calling `handler` for each accepted stream.
    ///
    /// Only the drain timeout of the configuration is used. A server error is logged.
    pub fn spawn_tcp_raw<F>(&self, addr: impl Into<String>, handler: F) -> ServerHandle
    where
        F: Fn(
                tokio::net::TcpStream,
                std::net::SocketAddr,
            ) -> Pin<Box<dyn Future<Output = std::io::Result<()>> + Send>>
            + Send
            + Sync
            + 'static,
    {
        let bind_addr: String = addr.into();
        let (server_handle, stop) = make_handle(self.config.drain_timeout);
        spawn_done(server_handle.done.clone(), async move {
            let outcome = crate::server_tcp::serve_tcp_with_shutdown(&bind_addr, handler, stop).await;
            if let Err(e) = outcome {
                tracing::error!("raw TCP server error: {e}");
            }
        });
        server_handle
    }

    /// Runs a raw UDP server on `addr`, calling `handler` for each received datagram.
    ///
    /// Only the drain timeout of the configuration is used. A server error is logged.
    pub fn spawn_udp_raw<F>(&self, addr: impl Into<String>, handler: F) -> ServerHandle
    where
        F: Fn(
                Vec<u8>,
                std::net::SocketAddr,
                Arc<tokio::net::UdpSocket>,
            ) -> Pin<Box<dyn Future<Output = ()> + Send>>
            + Send
            + Sync
            + 'static,
    {
        let bind_addr: String = addr.into();
        let (server_handle, stop) = make_handle(self.config.drain_timeout);
        spawn_done(server_handle.done.clone(), async move {
            let outcome = crate::server_udp::serve_udp_with_shutdown(&bind_addr, handler, stop).await;
            if let Err(e) = outcome {
                tracing::error!("raw UDP server error: {e}");
            }
        });
        server_handle
    }
}
/// Builder for [`CompioServer`]: collects a [`ServerConfig`] and, with the
/// `compio-tls` feature, optional [`TlsCert`] material.
#[cfg(feature = "compio")]
#[derive(Debug, Default, Clone)]
pub struct CompioServerBuilder {
    /// Server configuration; defaults to `ServerConfig::default()`.
    config: ServerConfig,
    /// TLS material needed by `CompioServer::spawn_tls`; `None` by default.
    #[cfg(feature = "compio-tls")]
    tls: Option<TlsCert>,
}
#[cfg(feature = "compio")]
impl CompioServerBuilder {
    /// Replaces the server configuration.
    #[must_use]
    pub fn config(self, config: ServerConfig) -> Self {
        let mut next = self;
        next.config = config;
        next
    }

    /// Sets the TLS material used by `CompioServer::spawn_tls`.
    #[cfg(feature = "compio-tls")]
    #[must_use]
    pub fn tls(self, cert: TlsCert) -> Self {
        let mut next = self;
        next.tls = Some(cert);
        next
    }

    /// Finishes the builder, producing a [`CompioServer`].
    pub fn build(self) -> CompioServer {
        let config = self.config;
        #[cfg(feature = "compio-tls")]
        let tls = self.tls;
        CompioServer {
            config,
            #[cfg(feature = "compio-tls")]
            tls,
        }
    }
}
/// A configured compio-based server whose `spawn_*` methods start listener tasks.
#[cfg(feature = "compio")]
#[derive(Debug, Clone)]
pub struct CompioServer {
    /// Configuration cloned into each listener task started by the `spawn_*` methods.
    config: ServerConfig,
    /// TLS material for `spawn_tls` (only with the `compio-tls` feature).
    #[cfg(feature = "compio-tls")]
    tls: Option<TlsCert>,
}
#[cfg(feature = "compio")]
impl CompioServer {
    /// Returns a [`CompioServerBuilder`] with the default configuration.
    #[must_use]
    pub fn builder() -> CompioServerBuilder {
        CompioServerBuilder::default()
    }

    /// Borrows the configuration this server was built with.
    #[inline]
    pub fn config(&self) -> &ServerConfig {
        &self.config
    }

    /// Serves plain HTTP on `listener` in a detached compio task.
    pub fn spawn_http(&self, listener: compio::net::TcpListener, router: Router) -> ServerHandle {
        let (server_handle, stop) = make_handle(self.config.drain_timeout);
        let cfg = self.config.clone();
        spawn_done_compio(server_handle.done.clone(), async move {
            crate::server_compio::serve_with_shutdown_and_config(listener, router, stop, cfg)
                .await;
        });
        server_handle
    }

    /// Serves HTTPS on `listener` in a detached compio task.
    ///
    /// A PEM-path certificate without client auth is handed straight to the TLS serve
    /// loop; every other [`TlsCert`] is first turned into a rustls config. If building
    /// that config fails, the error is logged and the task ends immediately.
    ///
    /// # Panics
    /// Panics if no [`TlsCert`] was configured on the builder.
    #[cfg(feature = "compio-tls")]
    pub fn spawn_tls(&self, listener: compio::net::TcpListener, router: Router) -> ServerHandle {
        let Some(tls) = self.tls.clone() else {
            panic!("CompioServer::spawn_tls requires a TlsCert (use builder().tls(...))");
        };
        let (server_handle, stop) = make_handle(self.config.drain_timeout);
        let cfg = self.config.clone();
        let alpn = tls_alpn_for_tcp();
        spawn_done_compio(server_handle.done.clone(), async move {
            match &tls {
                TlsCert::PemPaths {
                    cert_path,
                    key_path,
                    client_auth: None,
                } => {
                    crate::server_tls_compio::serve_tls_with_shutdown_and_config(
                        listener,
                        router,
                        Some(cert_path.as_str()),
                        Some(key_path.as_str()),
                        stop,
                        cfg,
                    )
                    .await;
                }
                _ => match build_rustls_server_config(&tls, alpn) {
                    Ok(rustls_cfg) => {
                        crate::server_tls_compio::serve_tls_with_rustls_config_and_shutdown(
                            listener, router, rustls_cfg, stop, cfg,
                        )
                        .await;
                    }
                    Err(e) => {
                        tracing::error!(
                            "CompioServer::spawn_tls: failed to build rustls config: {e}"
                        );
                    }
                },
            }
        });
        server_handle
    }
}
/// ALPN protocols advertised by TCP TLS listeners: `h2` first when the `http2`
/// feature is enabled, then `http/1.1`.
#[cfg(feature = "tls")]
#[inline]
fn tls_alpn_for_tcp() -> Vec<Vec<u8>> {
    const WITH_H2: &[&[u8]] = &[b"h2", b"http/1.1"];
    const HTTP1_ONLY: &[&[u8]] = &[b"http/1.1"];
    let offered = if cfg!(feature = "http2") {
        WITH_H2
    } else {
        HTTP1_ONLY
    };
    offered.iter().map(|proto| proto.to_vec()).collect()
}
/// Creates a [`ServerHandle`] together with the shutdown future a serve loop awaits.
///
/// The returned future completes once `ServerHandle::trigger` (or `shutdown`) cancels
/// the handle's shutdown token. The handle's `done` token starts un-cancelled.
fn make_handle(
    drain_timeout: Duration,
) -> (ServerHandle, impl Future<Output = ()> + Send + 'static) {
    let shutdown = tokio_util::sync::CancellationToken::new();
    let done = tokio_util::sync::CancellationToken::new();
    let watcher = shutdown.clone();
    let handle = ServerHandle {
        shutdown,
        done,
        drain_timeout,
    };
    (handle, async move { watcher.cancelled().await })
}
/// Runs `fut` as a detached Tokio task and cancels `done` when it is over.
///
/// `done` is cancelled through a `DropGuard`, so it fires when `fut` completes and also
/// when it panics or the task is dropped before completing (e.g. on runtime shutdown).
/// Without the guard, a panicking server loop would leave `ServerHandle::join` and
/// `ServerHandle::shutdown` waiting forever.
#[cfg(not(feature = "compio"))]
fn spawn_done<F>(done: tokio_util::sync::CancellationToken, fut: F)
where
    F: Future<Output = ()> + Send + 'static,
{
    // The JoinHandle is intentionally dropped: the task runs detached.
    tokio::spawn(async move {
        let _cancel_on_exit = done.drop_guard();
        fut.await;
    });
}
/// Runs `fut` as a detached compio task and cancels `done` when it is over.
///
/// `done` is cancelled through a `DropGuard`, so it fires when `fut` completes and also
/// when it panics or the task is dropped before completing. Without the guard, a
/// panicking server loop would leave `ServerHandle::join` / `shutdown` waiting forever.
#[cfg(feature = "compio")]
fn spawn_done_compio<F>(done: tokio_util::sync::CancellationToken, fut: F)
where
    F: Future<Output = ()> + 'static,
{
    compio::runtime::spawn(async move {
        let _cancel_on_exit = done.drop_guard();
        fut.await;
    })
    .detach();
}