#![cfg(feature = "tls")]
#![cfg_attr(docsrs, doc(cfg(feature = "tls")))]
use std::convert::Infallible;
use std::future::Future;
use std::sync::Arc;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use compio::net::TcpListener;
use compio::tls::TlsAcceptor;
use cyper_core::HyperStream;
use futures_util::future::Either;
use hyper::server::conn::http1;
#[cfg(feature = "http2")]
use hyper::server::conn::http2;
use hyper::service::service_fn;
use rustls::ServerConfig as RustlsServerConfig;
#[cfg(feature = "http2")]
use send_wrapper::SendWrapper;
use tako_rs_core::body::TakoBody;
use tako_rs_core::conn_info::ConnInfo;
use tako_rs_core::conn_info::TlsInfo;
use tako_rs_core::router::Router;
#[cfg(feature = "signals")]
use tako_rs_core::signals::transport as signal_tx;
use tako_rs_core::types::BoxError;
use tokio::sync::Notify;
use tokio_util::sync::CancellationToken;
use crate::ServerConfig;
/// Serves `router` over TLS with the default [`ServerConfig`] and no shutdown signal.
///
/// `certs` and `key` are PEM file paths, falling back to `cert.pem` / `key.pem` when
/// `None`. Any error from [`run`] is logged instead of being returned.
pub async fn serve_tls(
    listener: TcpListener,
    router: Router,
    certs: Option<&str>,
    key: Option<&str>,
) {
    let no_signal: Option<std::future::Pending<()>> = None;
    let outcome = run(listener, router, certs, key, no_signal, ServerConfig::default()).await;
    if let Err(e) = outcome {
        tracing::error!("TLS server error: {e}");
    }
}
/// Serves `router` over TLS with the default [`ServerConfig`] until `signal` resolves.
///
/// `certs` and `key` are PEM file paths, falling back to `cert.pem` / `key.pem` when
/// `None`. Any error from [`run`] is logged instead of being returned.
pub async fn serve_tls_with_shutdown(
    listener: TcpListener,
    router: Router,
    certs: Option<&str>,
    key: Option<&str>,
    signal: impl Future<Output = ()>,
) {
    let shutdown = Some(signal);
    let Err(e) = run(listener, router, certs, key, shutdown, ServerConfig::default()).await
    else {
        return;
    };
    tracing::error!("TLS server error: {e}");
}
pub async fn serve_tls_with_config(
listener: TcpListener,
router: Router,
certs: Option<&str>,
key: Option<&str>,
config: ServerConfig,
) {
if let Err(e) = run(
listener,
router,
certs,
key,
None::<std::future::Pending<()>>,
config,
)
.await
{
tracing::error!("TLS server error: {e}");
}
}
/// Serves `router` over TLS with a caller-supplied [`ServerConfig`] until `signal` resolves.
///
/// `certs` and `key` are PEM file paths, falling back to `cert.pem` / `key.pem` when
/// `None`. Any error from [`run`] is logged instead of being returned.
pub async fn serve_tls_with_shutdown_and_config(
    listener: TcpListener,
    router: Router,
    certs: Option<&str>,
    key: Option<&str>,
    signal: impl Future<Output = ()>,
    config: ServerConfig,
) {
    let result = run(listener, router, certs, key, Some(signal), config).await;
    if let Err(e) = result {
        tracing::error!("TLS server error: {e}");
    }
}
/// Serves `router` using a pre-built rustls configuration and no shutdown signal.
///
/// The rustls config is used as-is (no ALPN protocols are added here). Any error from
/// [`run_with_config`] is logged instead of being returned.
pub async fn serve_tls_with_rustls_config(
    listener: TcpListener,
    router: Router,
    rustls_config: Arc<RustlsServerConfig>,
    config: ServerConfig,
) {
    let no_signal: Option<std::future::Pending<()>> = None;
    let Err(e) = run_with_config(listener, router, rustls_config, no_signal, config).await
    else {
        return;
    };
    tracing::error!("TLS server error: {e}");
}
pub async fn serve_tls_with_rustls_config_and_shutdown(
listener: TcpListener,
router: Router,
rustls_config: Arc<RustlsServerConfig>,
signal: impl Future<Output = ()>,
config: ServerConfig,
) {
if let Err(e) = run_with_config(listener, router, rustls_config, Some(signal), config).await {
tracing::error!("TLS server error: {e}");
}
}
/// Builds a rustls server config from PEM files and runs the TLS server.
///
/// `certs` and `key` default to `cert.pem` and `key.pem` when `None`. The ALPN list
/// offers `h2` (most preferred) and `http/1.1` when the `http2` feature is enabled,
/// and only `http/1.1` otherwise. Serving is delegated to [`run_with_config`].
///
/// # Errors
///
/// Returns an error if the certificate chain or key cannot be loaded, if rustls
/// rejects the pair, or if [`run_with_config`] fails.
pub async fn run(
    listener: TcpListener,
    router: Router,
    certs: Option<&str>,
    key: Option<&str>,
    signal: Option<impl Future<Output = ()>>,
    config: ServerConfig,
) -> Result<(), BoxError> {
    #[cfg(feature = "tako-tracing")]
    tako_rs_core::tracing::init_tracing();

    let cert_path = certs.unwrap_or("cert.pem");
    let key_path = key.unwrap_or("key.pem");
    let cert_chain = load_certs(cert_path)?;
    let private_key = load_key(key_path)?;

    // rustls treats this list as most-preferred first.
    let alpn_protocols: Vec<Vec<u8>> = if cfg!(feature = "http2") {
        vec![b"h2".to_vec(), b"http/1.1".to_vec()]
    } else {
        vec![b"http/1.1".to_vec()]
    };

    let mut rustls_config = RustlsServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(cert_chain, private_key)?;
    rustls_config.alpn_protocols = alpn_protocols;

    run_with_config(listener, router, Arc::new(rustls_config), signal, config).await
}
/// Runs the compio TLS server with an already-built rustls configuration.
///
/// Accepts TCP connections from `listener` and performs the TLS handshake, bounded by
/// `config.tls_handshake_timeout`. Each connection is then served with hyper:
/// - HTTP/2, when the `http2` feature is enabled and the client negotiated `h2` via ALPN;
/// - HTTP/1.1 with upgrades, otherwise.
///
/// Every request gets the peer address and a [`ConnInfo`] inserted into its extensions
/// before it is dispatched to `router`. Accept errors are logged and retried after a
/// delay from `config.accept_backoff`. When `config.max_connections` is set, an accepted
/// connection waits for a free slot before it is served.
///
/// When `signal` resolves, the server:
/// 1. stops accepting and closes the listening socket;
/// 2. aborts pending handshakes;
/// 3. asks established connections to shut down gracefully;
/// 4. waits up to `config.drain_timeout` for them to finish.
///
/// With `signal == None` the server runs until this future is dropped.
///
/// # Errors
///
/// Returns an error only if the listener's local address cannot be read. Per-connection
/// failures are logged and never stop the server.
pub async fn run_with_config(
    listener: TcpListener,
    router: Router,
    tls_config: Arc<RustlsServerConfig>,
    signal: Option<impl Future<Output = ()>>,
    config: ServerConfig,
) -> Result<(), BoxError> {
    #[cfg(feature = "tako-tracing")]
    tako_rs_core::tracing::init_tracing();
    let acceptor = TlsAcceptor::from(tls_config);
    let router = Arc::new(router);
    #[cfg(feature = "plugins")]
    router.setup_plugins_once();
    let addr_str = listener.local_addr()?.to_string();
    #[cfg(feature = "signals")]
    signal_tx::emit_server_started(&addr_str, "tcp", true).await;
    tracing::info!("Tako TLS listening on {}", addr_str);

    let inflight = Arc::new(AtomicUsize::new(0));
    let drain_notify = Arc::new(Notify::new());
    let drain_timeout = config.drain_timeout;
    let tls_handshake_timeout = config.tls_handshake_timeout;
    let keep_alive = config.keep_alive;
    #[cfg(feature = "http2")]
    let h2_max_concurrent_streams = config.h2_max_concurrent_streams;
    #[cfg(feature = "http2")]
    let h2_max_header_list_size = config.h2_max_header_list_size;
    #[cfg(feature = "http2")]
    let h2_max_send_buf_size = config.h2_max_send_buf_size;
    #[cfg(feature = "http2")]
    let h2_max_pending_accept_reset_streams = config.h2_max_pending_accept_reset_streams;
    #[cfg(feature = "http2")]
    let h2_keep_alive_interval = config.h2_keep_alive_interval;
    let max_conn_semaphore = config
        .max_connections
        .map(|n| Arc::new(tokio::sync::Semaphore::new(n)));
    let mut accept_backoff = config.accept_backoff;

    // Cancelled once shutdown is observed. It aborts pending handshakes and tells
    // established connections to shut down gracefully.
    let cancel = CancellationToken::new();
    // One pinned shutdown future shared by every `select` below. It must not be polled
    // again after it completes, so every branch that observes it leaves the loop.
    let mut shutdown = std::pin::pin!(async move {
        match signal {
            Some(signal) => signal.await,
            None => std::future::pending::<()>().await,
        }
    });

    loop {
        let accept = std::pin::pin!(listener.accept());
        let (stream, addr) = match futures_util::future::select(accept, shutdown.as_mut()).await
        {
            Either::Left((Ok(accepted), _)) => {
                accept_backoff.reset();
                accepted
            }
            Either::Left((Err(err), _)) => {
                tracing::warn!("compio TLS accept failed: {err}; backing off");
                let delay = accept_backoff.current_and_grow();
                let sleep = std::pin::pin!(compio::time::sleep(delay));
                match futures_util::future::select(sleep, shutdown.as_mut()).await {
                    Either::Left(((), _)) => continue,
                    Either::Right(_) => break,
                }
            }
            Either::Right(_) => break,
        };

        let permit = match max_conn_semaphore.as_ref() {
            Some(sem) => {
                let acquire = std::pin::pin!(sem.clone().acquire_owned());
                match futures_util::future::select(acquire, shutdown.as_mut()).await {
                    Either::Left((Ok(permit), _)) => Some(permit),
                    // Only reachable if the semaphore is closed; drop the connection.
                    Either::Left((Err(_), _)) => continue,
                    Either::Right(_) => break,
                }
            }
            None => None,
        };

        let acceptor = acceptor.clone();
        let router = router.clone();
        let guard =
            crate::server_compio::ConnectionGuard::new(inflight.clone(), drain_notify.clone());
        let conn_cancel = cancel.clone();
        compio::runtime::spawn(async move {
            // Held for the whole task. The guard presumably decrements `inflight` and
            // notifies `drain_notify` on drop; confirm in `server_compio::ConnectionGuard`.
            let _permit = permit;
            let _guard = guard;

            // Handshake, bounded by the timeout and aborted on shutdown. Scoped so the
            // timer and the cancellation wait are released once the handshake is over.
            let tls_stream = {
                let handshake_deadline =
                    std::pin::pin!(compio::time::sleep(tls_handshake_timeout));
                let shutdown_wait = std::pin::pin!(conn_cancel.cancelled());
                let deadline_or_shutdown = std::pin::pin!(futures_util::future::select(
                    handshake_deadline,
                    shutdown_wait
                ));
                let accept_fut = std::pin::pin!(acceptor.accept(stream));
                let outcome = futures_util::future::select(accept_fut, deadline_or_shutdown).await;
                match outcome {
                    Either::Left((Ok(s), _)) => s,
                    Either::Left((Err(e), _)) => {
                        tracing::error!("TLS error: {e}");
                        return;
                    }
                    Either::Right((Either::Left(_), _)) => {
                        tracing::warn!(
                            "TLS handshake timeout after {tls_handshake_timeout:?} from {addr}"
                        );
                        return;
                    }
                    Either::Right((Either::Right(_), _)) => {
                        tracing::debug!("TLS handshake aborted by shutdown from {addr}");
                        return;
                    }
                }
            };
            #[cfg(feature = "signals")]
            signal_tx::emit_connection_opened(&addr.to_string(), true, None).await;

            let alpn_proto = tls_stream
                .negotiated_alpn()
                .map(std::borrow::Cow::into_owned);
            let is_h2 = matches!(alpn_proto.as_deref(), Some(b"h2"));
            let tls_info = TlsInfo {
                alpn: alpn_proto,
                sni: None,
                version: None,
            };
            let conn_info = if is_h2 {
                ConnInfo::h2_tls(addr, tls_info)
            } else {
                ConnInfo::h1_tls(addr, tls_info)
            };

            let io = HyperStream::new(tls_stream);
            let svc = service_fn(move |mut req| {
                let r = router.clone();
                let conn_info = conn_info.clone();
                async move {
                    req.extensions_mut().insert(addr);
                    req.extensions_mut().insert(conn_info);
                    let response = r.dispatch(req.map(TakoBody::new)).await;
                    Ok::<_, Infallible>(response)
                }
            });

            #[cfg(feature = "http2")]
            if is_h2 {
                let mut h2 = http2::Builder::new(CompioH2Executor);
                h2.timer(CompioH2Timer)
                    .max_concurrent_streams(h2_max_concurrent_streams)
                    .max_header_list_size(h2_max_header_list_size)
                    .max_send_buf_size(h2_max_send_buf_size)
                    .max_pending_accept_reset_streams(h2_max_pending_accept_reset_streams);
                if let Some(interval) = h2_keep_alive_interval {
                    h2.keep_alive_interval(Some(interval));
                }
                let connection =
                    std::pin::pin!(h2.serve_connection(io, ServiceSendWrapper::new(svc)));
                let shutdown_wait = std::pin::pin!(conn_cancel.cancelled());
                // On shutdown, ask hyper to finish in-flight streams and close, instead
                // of holding the connection open until the drain timeout.
                let result = match futures_util::future::select(connection, shutdown_wait).await
                {
                    Either::Left((result, _)) => result,
                    Either::Right((_, mut pending)) => {
                        pending.as_mut().graceful_shutdown();
                        pending.await
                    }
                };
                if let Err(e) = result {
                    tracing::error!("HTTP/2 error: {e}");
                }
                #[cfg(feature = "signals")]
                signal_tx::emit_connection_closed(&addr.to_string(), true, None).await;
                return;
            }

            let mut h1 = http1::Builder::new();
            h1.keep_alive(keep_alive);
            let connection = std::pin::pin!(h1.serve_connection(io, svc).with_upgrades());
            let shutdown_wait = std::pin::pin!(conn_cancel.cancelled());
            // On shutdown, let the current request finish and then close, so idle
            // keep-alive connections do not hold the drain until its timeout.
            let result = match futures_util::future::select(connection, shutdown_wait).await {
                Either::Left((result, _)) => result,
                Either::Right((_, mut pending)) => {
                    pending.as_mut().graceful_shutdown();
                    pending.await
                }
            };
            if let Err(e) = result {
                if e.is_incomplete_message() {
                    tracing::debug!("TLS HTTP/1.1 client disconnected mid-message: {e}");
                } else {
                    tracing::error!("HTTP/1.1 error: {e}");
                }
            }
            #[cfg(feature = "signals")]
            signal_tx::emit_connection_closed(&addr.to_string(), true, None).await;
        })
        .detach();
    }

    // The loop only exits once the shutdown signal fired. Close the listening socket
    // first so new clients are refused instead of sitting in the backlog during the drain.
    drop(listener);
    cancel.cancel();
    tracing::info!("Shutdown signal received, draining TLS connections...");
    drain_tls_connections(&inflight, &drain_notify, drain_timeout).await;
    tracing::info!("TLS server shut down gracefully");
    Ok(())
}

/// Waits until `inflight` drops to zero, giving up after `drain_timeout`.
///
/// The counter is re-checked each time `drain_notify` fires. Presumably the connection
/// guard notifies on drop; confirm in `server_compio::ConnectionGuard`. Logs a warning
/// with the remaining count if the timeout is reached.
async fn drain_tls_connections(
    inflight: &AtomicUsize,
    drain_notify: &Notify,
    drain_timeout: std::time::Duration,
) {
    let deadline = std::time::Instant::now() + drain_timeout;
    loop {
        if inflight.load(Ordering::SeqCst) == 0 {
            return;
        }
        let now = std::time::Instant::now();
        if now >= deadline {
            break;
        }
        let notified = std::pin::pin!(drain_notify.notified());
        let timer = std::pin::pin!(compio::time::sleep(deadline - now));
        if let Either::Right(_) = futures_util::future::select(notified, timer).await {
            break;
        }
    }
    tracing::warn!(
        "Drain timeout ({:?}) exceeded, {} TLS connections still active",
        drain_timeout,
        inflight.load(Ordering::SeqCst)
    );
}
pub use tako_rs_core::tls::load_certs;
pub use tako_rs_core::tls::load_key;
/// Wraps the per-connection service in [`SendWrapper`] so it can be passed to hyper's
/// HTTP/2 server, which needs `Send` services and futures.
///
/// NOTE(review): `SendWrapper` is only usable on the thread that created it. Confirm
/// that every task spawned through `CompioH2Executor` stays on the connection's thread.
#[cfg(feature = "http2")]
struct ServiceSendWrapper<T>(SendWrapper<T>);
#[cfg(feature = "http2")]
impl<T> ServiceSendWrapper<T> {
    /// Wraps `inner`, tying it to the calling thread.
    fn new(inner: T) -> Self {
        let wrapped = SendWrapper::new(inner);
        ServiceSendWrapper(wrapped)
    }
}
/// Forwards every request to the wrapped service and wraps each returned future in
/// [`SendWrapper`] as well.
#[cfg(feature = "http2")]
impl<R, T> hyper::service::Service<R> for ServiceSendWrapper<T>
where
    T: hyper::service::Service<R>,
{
    type Response = T::Response;
    type Error = T::Error;
    type Future = SendWrapper<T::Future>;

    fn call(&self, req: R) -> Self::Future {
        let inner: &T = &self.0;
        let response_future = inner.call(req);
        SendWrapper::new(response_future)
    }
}
/// hyper executor that spawns HTTP/2 background tasks on the current compio runtime.
#[cfg(feature = "http2")]
#[derive(Debug, Clone)]
struct CompioH2Executor;
#[cfg(feature = "http2")]
impl<F> hyper::rt::Executor<F> for CompioH2Executor
where
    F: std::future::Future<Output = ()> + Send + 'static,
{
    /// Spawns `fut` and detaches it; nothing here joins or cancels the task.
    fn execute(&self, fut: F) {
        let task = compio::runtime::spawn(fut);
        task.detach();
    }
}
/// hyper timer backed by `compio::time`, used by the HTTP/2 builder.
#[cfg(feature = "http2")]
#[derive(Debug, Clone)]
struct CompioH2Timer;

/// A boxed compio sleep future, wrapped in [`SendWrapper`] so it can be returned as a
/// `hyper::rt::Sleep` (which requires `Send + Sync`).
#[cfg(feature = "http2")]
struct CompioSleep(SendWrapper<std::pin::Pin<Box<dyn std::future::Future<Output = ()>>>>);

#[cfg(feature = "http2")]
impl CompioSleep {
    /// Boxes `fut`, wraps it and erases it to the trait object hyper expects.
    fn boxed(
        fut: impl std::future::Future<Output = ()> + 'static,
    ) -> std::pin::Pin<Box<dyn hyper::rt::Sleep>> {
        let erased: std::pin::Pin<Box<dyn std::future::Future<Output = ()>>> = Box::pin(fut);
        Box::pin(CompioSleep(SendWrapper::new(erased)))
    }
}

#[cfg(feature = "http2")]
impl std::future::Future for CompioSleep {
    type Output = ();

    fn poll(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<()> {
        // `CompioSleep` is `Unpin` (explicit impl below), so unpinning is safe.
        let this = self.get_mut();
        this.0.as_mut().poll(cx)
    }
}

#[cfg(feature = "http2")]
impl Unpin for CompioSleep {}

#[cfg(feature = "http2")]
impl hyper::rt::Sleep for CompioSleep {}

#[cfg(feature = "http2")]
impl hyper::rt::Timer for CompioH2Timer {
    fn sleep(&self, duration: std::time::Duration) -> std::pin::Pin<Box<dyn hyper::rt::Sleep>> {
        CompioSleep::boxed(compio::time::sleep(duration))
    }

    fn sleep_until(&self, deadline: std::time::Instant) -> std::pin::Pin<Box<dyn hyper::rt::Sleep>> {
        CompioSleep::boxed(compio::time::sleep_until(deadline))
    }
}