#![cfg(feature = "tls")]
#![cfg_attr(docsrs, doc(cfg(feature = "tls")))]
use std::convert::Infallible;
use std::fs::File;
use std::future::Future;
use std::io::BufReader;
use std::sync::Arc;
use std::time::Duration;
use hyper::server::conn::http1;
#[cfg(feature = "http2")]
use hyper::server::conn::http2;
use hyper::service::service_fn;
#[cfg(feature = "http2")]
use hyper_util::rt::TokioExecutor;
use hyper_util::rt::TokioIo;
use rustls::pki_types::CertificateDer;
use rustls::pki_types::PrivateKeyDer;
use rustls_pemfile::certs;
use rustls_pemfile::pkcs8_private_keys;
use tokio::net::TcpListener;
use tokio::task::JoinSet;
use tokio_rustls::TlsAcceptor;
use tokio_rustls::rustls::ServerConfig;
use crate::body::TakoBody;
use crate::router::Router;
#[cfg(feature = "signals")]
use crate::signals::Signal;
#[cfg(feature = "signals")]
use crate::signals::SignalArbiter;
#[cfg(feature = "signals")]
use crate::signals::ids;
use crate::types::BoxError;
/// Upper bound on how long `run` waits for in-flight connection tasks to finish
/// once the shutdown signal has fired; tasks still running after this are aborted.
const DEFAULT_DRAIN_TIMEOUT: Duration = Duration::from_secs(30);
/// Serves `router` over TLS on `listener` with no shutdown signal.
///
/// `certs` and `key` are paths to PEM files; when `None`, [`run`] falls back to
/// `"cert.pem"` and `"key.pem"` respectively.
///
/// Any error returned by [`run`] is logged via `tracing::error!` rather than
/// propagated, so this function always returns `()`.
pub async fn serve_tls(
    listener: TcpListener,
    router: Router,
    certs: Option<&str>,
    key: Option<&str>,
) {
    // `run` is generic over the signal future, so the `None` needs a concrete
    // type; `Pending<()>` is used purely as a type-level placeholder.
    let no_signal = None::<std::future::Pending<()>>;

    match run(listener, router, certs, key, no_signal).await {
        Ok(()) => {}
        Err(e) => {
            tracing::error!("TLS server error: {e}");
        }
    }
}
/// Serves `router` over TLS on `listener` until `signal` resolves.
///
/// `certs` and `key` are paths to PEM files; when `None`, [`run`] falls back to
/// `"cert.pem"` and `"key.pem"` respectively.
///
/// Once `signal` completes, [`run`] stops accepting new connections and waits
/// up to `DEFAULT_DRAIN_TIMEOUT` for existing ones before aborting them.
///
/// Any error returned by [`run`] is logged via `tracing::error!` rather than
/// propagated, so this function always returns `()`.
pub async fn serve_tls_with_shutdown(
    listener: TcpListener,
    router: Router,
    certs: Option<&str>,
    key: Option<&str>,
    signal: impl Future<Output = ()>,
) {
    let outcome = run(listener, router, certs, key, Some(signal)).await;

    match outcome {
        Ok(()) => {}
        Err(e) => {
            tracing::error!("TLS server error: {e}");
        }
    }
}
/// Runs the TLS accept loop for `router` on `listener`.
///
/// Steps, in order:
/// 1. Loads the certificate chain and private key from the PEM files at
///    `certs` / `key` (defaulting to `"cert.pem"` / `"key.pem"`).
/// 2. Builds a rustls `ServerConfig` without client authentication and
///    advertises `h2` + `http/1.1` via ALPN (only `http/1.1` when the `http2`
///    feature is disabled).
/// 3. Accepts TCP connections, performs the TLS handshake in a spawned task and
///    serves the connection with HTTP/2 or HTTP/1.1 depending on the negotiated
///    ALPN protocol.
/// 4. When `signal` resolves, stops accepting and waits up to
///    `DEFAULT_DRAIN_TIMEOUT` for the spawned connection tasks to finish,
///    aborting whatever is left afterwards.
///
/// If `signal` is `None` the loop only ends when accepting a connection fails.
///
/// # Errors
///
/// Returns an error if the certificate or key cannot be loaded, if the TLS
/// configuration is rejected, if the listener's local address cannot be read,
/// or if `listener.accept()` fails. An accept failure returns immediately; the
/// drain phase is not run in that case.
pub async fn run(
    listener: TcpListener,
    router: Router,
    certs: Option<&str>,
    key: Option<&str>,
    signal: Option<impl Future<Output = ()>>,
) -> Result<(), BoxError> {
    #[cfg(feature = "tako-tracing")]
    crate::tracing::init_tracing();

    let certs = load_certs(certs.unwrap_or("cert.pem"))?;
    let key = load_key(key.unwrap_or("key.pem"))?;

    let mut config = ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(certs, key)?;

    // ALPN list is in preference order: offer h2 first when it is compiled in.
    #[cfg(feature = "http2")]
    {
        config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
    }
    #[cfg(not(feature = "http2"))]
    {
        config.alpn_protocols = vec![b"http/1.1".to_vec()];
    }

    let acceptor = TlsAcceptor::from(Arc::new(config));
    let router = Arc::new(router);

    #[cfg(feature = "plugins")]
    router.setup_plugins_once();

    let addr_str = listener.local_addr()?.to_string();

    #[cfg(feature = "signals")]
    {
        SignalArbiter::emit_app(
            Signal::with_capacity(ids::SERVER_STARTED, 3)
                .meta("addr", addr_str.clone())
                .meta("transport", "tcp")
                .meta("tls", "true"),
        )
        .await;
    }

    tracing::info!("Tako TLS listening on {}", addr_str);

    // Every connection task lives in this set so shutdown can wait for them.
    let mut join_set = JoinSet::new();

    // Normalise the optional signal into a single future that never resolves
    // when no signal was supplied, so the select below needs only one branch.
    let signal = signal.map(|s| Box::pin(s));
    let signal_fused = async {
        if let Some(s) = signal {
            s.await;
        } else {
            std::future::pending::<()>().await;
        }
    };
    tokio::pin!(signal_fused);

    loop {
        tokio::select! {
            result = listener.accept() => {
                let (stream, addr) = result?;
                // Best effort: failing to disable Nagle is not fatal.
                let _ = stream.set_nodelay(true);
                let acceptor = acceptor.clone();
                let router = router.clone();

                join_set.spawn(async move {
                    let tls_stream = match acceptor.accept(stream).await {
                        Ok(s) => s,
                        Err(e) => {
                            tracing::error!("TLS error: {e}");
                            return;
                        }
                    };

                    #[cfg(feature = "signals")]
                    {
                        SignalArbiter::emit_app(
                            Signal::with_capacity(ids::CONNECTION_OPENED, 2)
                                .meta("remote_addr", addr.to_string())
                                .meta("tls", "true"),
                        )
                        .await;
                    }

                    // Capture the negotiated protocol before the stream is
                    // moved into the hyper IO wrapper.
                    #[cfg(feature = "http2")]
                    let proto = tls_stream.get_ref().1.alpn_protocol().map(|p| p.to_vec());

                    let io = TokioIo::new(tls_stream);
                    let svc = service_fn(move |mut req| {
                        let r = router.clone();
                        async move {
                            #[cfg(feature = "signals")]
                            let path = req.uri().path().to_string();
                            #[cfg(feature = "signals")]
                            let method = req.method().to_string();

                            // Expose the peer address to handlers via request extensions.
                            req.extensions_mut().insert(addr);

                            #[cfg(feature = "signals")]
                            {
                                SignalArbiter::emit_app(
                                    Signal::with_capacity(ids::REQUEST_STARTED, 2)
                                        .meta("method", method.clone())
                                        .meta("path", path.clone()),
                                )
                                .await;
                            }

                            let response = r.dispatch(req.map(TakoBody::incoming)).await;

                            #[cfg(feature = "signals")]
                            {
                                SignalArbiter::emit_app(
                                    Signal::with_capacity(ids::REQUEST_COMPLETED, 3)
                                        .meta("method", method)
                                        .meta("path", path)
                                        .meta("status", response.status().as_u16().to_string()),
                                )
                                .await;
                            }

                            Ok::<_, Infallible>(response)
                        }
                    });

                    #[cfg(feature = "http2")]
                    if proto.as_deref() == Some(b"h2") {
                        let h2 = http2::Builder::new(TokioExecutor::new());
                        if let Err(e) = h2.serve_connection(io, svc).await {
                            tracing::error!("HTTP/2 error: {e}");
                        }

                        #[cfg(feature = "signals")]
                        {
                            SignalArbiter::emit_app(
                                Signal::with_capacity(ids::CONNECTION_CLOSED, 2)
                                    .meta("remote_addr", addr.to_string())
                                    .meta("tls", "true"),
                            )
                            .await;
                        }
                        return;
                    }

                    // Anything that did not negotiate h2 is served as HTTP/1.1.
                    let mut h1 = http1::Builder::new();
                    h1.keep_alive(true);

                    if let Err(e) = h1.serve_connection(io, svc).with_upgrades().await {
                        tracing::error!("HTTP/1.1 error: {e}");
                    }

                    #[cfg(feature = "signals")]
                    {
                        SignalArbiter::emit_app(
                            Signal::with_capacity(ids::CONNECTION_CLOSED, 2)
                                .meta("remote_addr", addr.to_string())
                                .meta("tls", "true"),
                        )
                        .await;
                    }
                });
            }
            // Reap finished connection tasks while the server is running.
            // A `JoinSet` keeps every completed task until it is joined, so
            // without this branch the set would grow by one entry per
            // connection for the whole lifetime of the server. When the set is
            // empty `join_next` yields `None`, the pattern does not match and
            // the branch is simply skipped for this iteration.
            Some(_finished) = join_set.join_next() => {}
            () = &mut signal_fused => {
                tracing::info!("Shutdown signal received, draining TLS connections...");
                break;
            }
        }
    }

    // Give in-flight connections a bounded amount of time to complete.
    let drain = tokio::time::timeout(DEFAULT_DRAIN_TIMEOUT, async {
        while join_set.join_next().await.is_some() {}
    });

    if drain.await.is_err() {
        tracing::warn!(
            "Drain timeout ({:?}) exceeded, aborting {} remaining TLS connections",
            DEFAULT_DRAIN_TIMEOUT,
            join_set.len()
        );
        join_set.abort_all();
    }

    tracing::info!("TLS server shut down gracefully");
    Ok(())
}
/// Reads every PEM-encoded certificate from the file at `path`.
///
/// # Errors
///
/// Fails if the file cannot be opened, or if reading any certificate section
/// from it yields an error. The first such error stops the read.
pub fn load_certs(path: &str) -> anyhow::Result<Vec<CertificateDer<'static>>> {
    let file = match File::open(path) {
        Ok(f) => f,
        Err(e) => {
            return Err(anyhow::anyhow!("failed to open cert file '{}': {}", path, e));
        }
    };
    let mut reader = BufReader::new(file);

    let mut chain = Vec::new();
    for entry in certs(&mut reader) {
        match entry {
            Ok(cert) => chain.push(cert),
            Err(e) => {
                return Err(anyhow::anyhow!("failed to parse certs from '{}': {}", path, e));
            }
        }
    }
    Ok(chain)
}
pub fn load_key(path: &str) -> anyhow::Result<PrivateKeyDer<'static>> {
let mut rd = BufReader::new(
File::open(path).map_err(|e| anyhow::anyhow!("failed to open key file '{}': {}", path, e))?,
);
pkcs8_private_keys(&mut rd)
.next()
.ok_or_else(|| anyhow::anyhow!("no private key found in '{}'", path))?
.map(|k| k.into())
.map_err(|e| anyhow::anyhow!("bad private key in '{}': {}", path, e))
}