#![cfg(feature = "http3")]
#![cfg_attr(docsrs, doc(cfg(feature = "http3")))]
use std::fs::File;
use std::future::Future;
use std::io::BufReader;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use bytes::Buf;
use bytes::Bytes;
use h3::quic::BidiStream;
use h3::server::RequestStream;
use http::Request;
use http_body::Body;
use quinn::crypto::rustls::QuicServerConfig;
use rustls::pki_types::CertificateDer;
use rustls::pki_types::PrivateKeyDer;
use rustls_pemfile::certs;
use rustls_pemfile::pkcs8_private_keys;
use crate::body::TakoBody;
use crate::router::Router;
#[cfg(feature = "signals")]
use crate::signals::Signal;
#[cfg(feature = "signals")]
use crate::signals::SignalArbiter;
#[cfg(feature = "signals")]
use crate::signals::ids;
use crate::types::BoxError;
const DEFAULT_DRAIN_TIMEOUT: Duration = Duration::from_secs(30);
/// Serves `router` over HTTP/3 (QUIC) on `addr` without any shutdown signal.
///
/// `certs` and `key` are paths to PEM files; `None` falls back to `cert.pem`
/// and `key.pem`. Startup and runtime failures are logged through `tracing`
/// rather than returned to the caller.
pub async fn serve_h3(router: Router, addr: &str, certs: Option<&str>, key: Option<&str>) {
    let no_shutdown: Option<std::future::Pending<()>> = None;
    let outcome = run(router, addr, certs, key, no_shutdown).await;
    if let Err(e) = outcome {
        tracing::error!("HTTP/3 server error: {e}");
    }
}
pub async fn serve_h3_with_shutdown(
router: Router,
addr: &str,
certs: Option<&str>,
key: Option<&str>,
signal: impl Future<Output = ()>,
) {
if let Err(e) = run(router, addr, certs, key, Some(signal)).await {
tracing::error!("HTTP/3 server error: {e}");
}
}
async fn run(
router: Router,
addr: &str,
certs: Option<&str>,
key: Option<&str>,
signal: Option<impl Future<Output = ()>>,
) -> Result<(), BoxError> {
#[cfg(feature = "tako-tracing")]
crate::tracing::init_tracing();
let _ = rustls::crypto::ring::default_provider().install_default();
let certs_vec = load_certs(certs.unwrap_or("cert.pem"))?;
let key = load_key(key.unwrap_or("key.pem"))?;
let mut tls_config = rustls::ServerConfig::builder()
.with_no_client_auth()
.with_single_cert(certs_vec, key)?;
tls_config.max_early_data_size = u32::MAX;
tls_config.alpn_protocols = vec![b"h3".to_vec()];
let server_config =
quinn::ServerConfig::with_crypto(Arc::new(QuicServerConfig::try_from(tls_config)?));
let socket_addr: SocketAddr = addr.parse()?;
let endpoint = quinn::Endpoint::server(server_config, socket_addr)?;
let router = Arc::new(router);
#[cfg(feature = "plugins")]
router.setup_plugins_once();
let addr_str = endpoint.local_addr()?.to_string();
#[cfg(feature = "signals")]
{
SignalArbiter::emit_app(
Signal::with_capacity(ids::SERVER_STARTED, 3)
.meta("addr", addr_str.clone())
.meta("transport", "quic")
.meta("protocol", "h3"),
)
.await;
}
tracing::info!("Tako HTTP/3 listening on {}", addr_str);
let mut join_set = tokio::task::JoinSet::new();
let signal = signal.map(|s| Box::pin(s));
let signal_fused = async {
if let Some(s) = signal {
s.await;
} else {
std::future::pending::<()>().await;
}
};
tokio::pin!(signal_fused);
loop {
tokio::select! {
maybe_conn = endpoint.accept() => {
let Some(new_conn) = maybe_conn else { break };
let router = router.clone();
join_set.spawn(async move {
match new_conn.await {
Ok(conn) => {
let remote_addr = conn.remote_address();
#[cfg(feature = "signals")]
{
SignalArbiter::emit_app(
Signal::with_capacity(ids::CONNECTION_OPENED, 2)
.meta("remote_addr", remote_addr.to_string())
.meta("protocol", "h3"),
)
.await;
}
if let Err(e) = handle_connection(conn, router, remote_addr).await {
tracing::error!("HTTP/3 connection error: {e}");
}
#[cfg(feature = "signals")]
{
SignalArbiter::emit_app(
Signal::with_capacity(ids::CONNECTION_CLOSED, 2)
.meta("remote_addr", remote_addr.to_string())
.meta("protocol", "h3"),
)
.await;
}
}
Err(e) => {
tracing::error!("QUIC connection failed: {e}");
}
}
});
}
() = &mut signal_fused => {
tracing::info!("Shutdown signal received, draining HTTP/3 connections...");
break;
}
}
}
endpoint.close(0u32.into(), b"server shutting down");
let drain = tokio::time::timeout(DEFAULT_DRAIN_TIMEOUT, async {
while join_set.join_next().await.is_some() {}
});
if drain.await.is_err() {
tracing::warn!(
"Drain timeout ({:?}) exceeded, aborting {} remaining HTTP/3 connections",
DEFAULT_DRAIN_TIMEOUT,
join_set.len()
);
join_set.abort_all();
}
endpoint.wait_idle().await;
tracing::info!("HTTP/3 server shut down gracefully");
Ok(())
}
/// Drives one established QUIC connection as an HTTP/3 server connection.
///
/// Performs the HTTP/3 setup via `h3::server::Connection::new`, then accepts
/// requests in a loop. Each request is resolved and handled on its own
/// `tokio::spawn`ed task. The loop stops when `accept` yields `Ok(None)` or
/// fails; an accept failure is logged, and the function still returns `Ok`.
///
/// NOTE(review): request tasks are detached because their handles are not
/// kept. Aborting this connection's task therefore does not cancel requests
/// that are already running.
///
/// # Errors
///
/// Returns an error only if the HTTP/3 connection setup fails.
async fn handle_connection(
    conn: quinn::Connection,
    router: Arc<Router>,
    remote_addr: SocketAddr,
) -> Result<(), BoxError> {
    let quic = h3_quinn::Connection::new(conn);
    let mut h3_conn = h3::server::Connection::new(quic).await?;
    loop {
        let resolver = match h3_conn.accept().await {
            Ok(Some(resolver)) => resolver,
            Ok(None) => break,
            Err(e) => {
                tracing::error!("HTTP/3 accept error: {e}");
                break;
            }
        };
        let router = Arc::clone(&router);
        tokio::spawn(async move {
            let (req, stream) = match resolver.resolve_request().await {
                Ok(pair) => pair,
                Err(e) => {
                    tracing::error!("HTTP/3 request resolve error: {e}");
                    return;
                }
            };
            if let Err(e) = handle_request(req, stream, router, remote_addr).await {
                tracing::error!("HTTP/3 request error: {e}");
            }
        });
    }
    Ok(())
}
async fn handle_request<S>(
req: Request<()>,
mut stream: RequestStream<S, Bytes>,
router: Arc<Router>,
remote_addr: SocketAddr,
) -> Result<(), BoxError>
where
S: BidiStream<Bytes>,
{
#[cfg(feature = "signals")]
let path = req.uri().path().to_string();
#[cfg(feature = "signals")]
let method = req.method().to_string();
#[cfg(feature = "signals")]
{
SignalArbiter::emit_app(
Signal::with_capacity(ids::REQUEST_STARTED, 3)
.meta("method", method.clone())
.meta("path", path.clone())
.meta("protocol", "h3"),
)
.await;
}
let mut body_bytes = Vec::new();
while let Some(mut chunk) = stream.recv_data().await? {
while chunk.has_remaining() {
let bytes = chunk.chunk();
body_bytes.extend_from_slice(bytes);
chunk.advance(bytes.len());
}
}
let (parts, _) = req.into_parts();
let body = TakoBody::from(Bytes::from(body_bytes));
let mut tako_req = Request::from_parts(parts, body);
tako_req.extensions_mut().insert(remote_addr);
let response = router.dispatch(tako_req).await;
#[cfg(feature = "signals")]
{
SignalArbiter::emit_app(
Signal::with_capacity(ids::REQUEST_COMPLETED, 4)
.meta("method", method)
.meta("path", path)
.meta("status", response.status().as_u16().to_string())
.meta("protocol", "h3"),
)
.await;
}
let (parts, body) = response.into_parts();
let resp = http::Response::from_parts(parts, ());
stream.send_response(resp).await?;
let mut body = std::pin::pin!(body);
while let Some(frame) = std::future::poll_fn(|cx| body.as_mut().poll_frame(cx)).await {
match frame {
Ok(frame) => {
if let Some(data) = frame.data_ref().filter(|d| !d.is_empty()) {
stream.send_data(data.clone()).await?;
}
}
Err(e) => {
tracing::error!("HTTP/3 body frame error: {e}");
break;
}
}
}
stream.finish().await?;
Ok(())
}
/// Reads every PEM-encoded certificate from the file at `path`.
///
/// Certificates are returned in file order, which is normally the end-entity
/// certificate followed by its intermediates.
///
/// # Errors
///
/// Fails if the file cannot be opened, if a PEM section cannot be parsed, or
/// if the file contains no certificates at all.
pub fn load_certs(path: &str) -> anyhow::Result<Vec<CertificateDer<'static>>> {
    let file = File::open(path)
        .map_err(|e| anyhow::anyhow!("failed to open cert file '{}': {}", path, e))?;
    let mut rd = BufReader::new(file);
    let chain = certs(&mut rd)
        .collect::<Result<Vec<_>, _>>()
        .map_err(|e| anyhow::anyhow!("failed to parse certs from '{}': {}", path, e))?;
    // Reject an empty chain here, with the file path in the message, rather
    // than letting it reach the TLS configuration step.
    if chain.is_empty() {
        anyhow::bail!("no certificates found in '{}'", path);
    }
    Ok(chain)
}
pub fn load_key(path: &str) -> anyhow::Result<PrivateKeyDer<'static>> {
let mut rd = BufReader::new(
File::open(path).map_err(|e| anyhow::anyhow!("failed to open key file '{}': {}", path, e))?,
);
pkcs8_private_keys(&mut rd)
.next()
.ok_or_else(|| anyhow::anyhow!("no private key found in '{}'", path))?
.map(|k| k.into())
.map_err(|e| anyhow::anyhow!("bad private key in '{}': {}", path, e))
}