tako-rs 1.1.2

Multi-transport Rust framework for modern network services.
Documentation
#![cfg(feature = "http3")]
#![cfg_attr(docsrs, doc(cfg(feature = "http3")))]

//! HTTP/3 server implementation using QUIC transport.
//!
//! This module provides HTTP/3 support for Tako web servers using the h3 crate
//! with Quinn as the QUIC transport. HTTP/3 offers improved performance over
//! HTTP/1.1 and HTTP/2 through features like reduced latency, better multiplexing,
//! and built-in encryption via QUIC.
//!
//! # Examples
//!
//! ```rust,no_run
//! # #[cfg(feature = "http3")]
//! use tako::{serve_h3, router::Router, Method, responder::Responder, types::Request};
//!
//! # #[cfg(feature = "http3")]
//! async fn hello(_: Request) -> impl Responder {
//!     "Hello, HTTP/3 World!".into_response()
//! }
//!
//! # #[cfg(feature = "http3")]
//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
//! let mut router = Router::new();
//! router.route(Method::GET, "/", hello);
//! serve_h3(router, "[::]:4433", Some("cert.pem"), Some("key.pem")).await;
//! # Ok(())
//! # }
//! ```

use std::fs::File;
use std::future::Future;
use std::io::BufReader;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;

use bytes::Buf;
use bytes::Bytes;
use h3::quic::BidiStream;
use h3::server::RequestStream;
use http::Request;
use http_body::Body;
use quinn::crypto::rustls::QuicServerConfig;
use rustls::pki_types::CertificateDer;
use rustls::pki_types::PrivateKeyDer;
use rustls_pemfile::certs;
use rustls_pemfile::pkcs8_private_keys;

use crate::body::TakoBody;
use crate::router::Router;
#[cfg(feature = "signals")]
use crate::signals::Signal;
#[cfg(feature = "signals")]
use crate::signals::SignalArbiter;
#[cfg(feature = "signals")]
use crate::signals::ids;
use crate::types::BoxError;

/// Default drain timeout for graceful shutdown (30 seconds).
///
/// After the accept loop stops, connection tasks are given this long to finish
/// before the remaining ones are aborted.
const DEFAULT_DRAIN_TIMEOUT: Duration = Duration::from_secs(30);

/// Starts an HTTP/3 server with the given router and certificates.
///
/// This function creates a QUIC endpoint and listens for incoming HTTP/3 connections.
/// Unlike TCP-based servers, HTTP/3 uses UDP and QUIC for transport.
///
/// There is no shutdown signal: the server runs until the QUIC endpoint stops
/// yielding connections. Use [`serve_h3_with_shutdown`] for graceful shutdown.
///
/// # Arguments
///
/// * `router` - The Tako router containing route definitions
/// * `addr` - The socket address to bind to (e.g., "[::]:4433")
/// * `certs` - Optional path to the TLS certificate file (defaults to "cert.pem")
/// * `key` - Optional path to the TLS private key file (defaults to "key.pem")
///
/// # Errors
///
/// Nothing is returned to the caller. Startup failures (unreadable certificate or
/// key, invalid address, bind failure) are logged with `tracing::error!`, and then
/// the function returns.
pub async fn serve_h3(router: Router, addr: &str, certs: Option<&str>, key: Option<&str>) {
  if let Err(e) = run(router, addr, certs, key, None::<std::future::Pending<()>>).await {
    tracing::error!("HTTP/3 server error: {e}");
  }
}

/// Starts an HTTP/3 server with graceful shutdown support.
pub async fn serve_h3_with_shutdown(
  router: Router,
  addr: &str,
  certs: Option<&str>,
  key: Option<&str>,
  signal: impl Future<Output = ()>,
) {
  if let Err(e) = run(router, addr, certs, key, Some(signal)).await {
    tracing::error!("HTTP/3 server error: {e}");
  }
}

/// Runs the HTTP/3 server loop.
///
/// Loads the certificate chain and private key, binds a QUIC endpoint on `addr`
/// that advertises the `h3` ALPN protocol, and spawns one task per incoming
/// connection. The accept loop ends when the endpoint stops yielding connections or
/// when `signal` (if provided) resolves.
///
/// After the loop ends, existing connections stay open for up to
/// [`DEFAULT_DRAIN_TIMEOUT`] so in-flight requests can complete. Any connection tasks
/// still running after that are aborted. Only then is the endpoint closed and
/// awaited until idle.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// * the certificate or key cannot be loaded;
/// * the TLS/QUIC configuration is rejected;
/// * `addr` is not a valid socket address;
/// * the endpoint cannot be bound or its local address queried.
async fn run(
  router: Router,
  addr: &str,
  certs: Option<&str>,
  key: Option<&str>,
  signal: Option<impl Future<Output = ()>>,
) -> Result<(), BoxError> {
  #[cfg(feature = "tako-tracing")]
  crate::tracing::init_tracing();

  // Install the default crypto provider for rustls (required for QUIC/TLS). The
  // result is ignored on purpose: installation can only succeed once per process.
  let _ = rustls::crypto::ring::default_provider().install_default();

  let certs_vec = load_certs(certs.unwrap_or("cert.pem"))?;
  let key = load_key(key.unwrap_or("key.pem"))?;

  let mut tls_config = rustls::ServerConfig::builder()
    .with_no_client_auth()
    .with_single_cert(certs_vec, key)?;

  // Enables 0-RTT early data. NOTE(review): rustls expects 0 or u32::MAX here when
  // used with QUIC, and 0-RTT requests can be replayed; confirm handlers are
  // replay-safe before relying on it.
  tls_config.max_early_data_size = u32::MAX;
  tls_config.alpn_protocols = vec![b"h3".to_vec()];

  let server_config =
    quinn::ServerConfig::with_crypto(Arc::new(QuicServerConfig::try_from(tls_config)?));

  let socket_addr: SocketAddr = addr.parse()?;
  let endpoint = quinn::Endpoint::server(server_config, socket_addr)?;

  let router = Arc::new(router);

  #[cfg(feature = "plugins")]
  router.setup_plugins_once();

  let addr_str = endpoint.local_addr()?.to_string();

  #[cfg(feature = "signals")]
  {
    SignalArbiter::emit_app(
      Signal::with_capacity(ids::SERVER_STARTED, 3)
        .meta("addr", addr_str.clone())
        .meta("transport", "quic")
        .meta("protocol", "h3"),
    )
    .await;
  }

  tracing::info!("Tako HTTP/3 listening on {}", addr_str);

  let mut join_set = tokio::task::JoinSet::new();

  // Without a shutdown signal, wait forever so only the endpoint can end the loop.
  let signal_fused = async {
    match signal {
      Some(s) => s.await,
      None => std::future::pending::<()>().await,
    }
  };
  tokio::pin!(signal_fused);

  loop {
    tokio::select! {
      maybe_conn = endpoint.accept() => {
        let Some(new_conn) = maybe_conn else { break };
        let router = router.clone();

        join_set.spawn(async move {
          match new_conn.await {
            Ok(conn) => {
              let remote_addr = conn.remote_address();

              #[cfg(feature = "signals")]
              {
                SignalArbiter::emit_app(
                  Signal::with_capacity(ids::CONNECTION_OPENED, 2)
                    .meta("remote_addr", remote_addr.to_string())
                    .meta("protocol", "h3"),
                )
                .await;
              }

              if let Err(e) = handle_connection(conn, router, remote_addr).await {
                tracing::error!("HTTP/3 connection error: {e}");
              }

              #[cfg(feature = "signals")]
              {
                SignalArbiter::emit_app(
                  Signal::with_capacity(ids::CONNECTION_CLOSED, 2)
                    .meta("remote_addr", remote_addr.to_string())
                    .meta("protocol", "h3"),
                )
                .await;
              }
            }
            Err(e) => {
              tracing::error!("QUIC connection failed: {e}");
            }
          }
        });
      }
      // Reap finished connection tasks as we go. A JoinSet keeps every completed
      // task until it is joined, so never joining here would grow memory with the
      // total number of connections served. When the set is empty, `join_next`
      // yields `None`, and this branch is disabled for the current iteration.
      Some(joined) = join_set.join_next() => {
        match joined {
          Err(e) if e.is_panic() => tracing::error!("HTTP/3 connection task panicked: {e}"),
          _ => {}
        }
      }
      () = &mut signal_fused => {
        tracing::info!("Shutdown signal received, draining HTTP/3 connections...");
        break;
      }
    }
  }

  // Drain in-flight connections BEFORE closing the endpoint. `Endpoint::close`
  // terminates every connection immediately, so calling it first would abort
  // in-flight requests and make this drain a no-op.
  // NOTE(review): idle keep-alive connections only end when the client closes them
  // or they time out. Sending an HTTP/3 GOAWAY to each connection would shorten the
  // drain.
  let drain = tokio::time::timeout(DEFAULT_DRAIN_TIMEOUT, async {
    while join_set.join_next().await.is_some() {}
  });

  if drain.await.is_err() {
    tracing::warn!(
      "Drain timeout ({:?}) exceeded, aborting {} remaining HTTP/3 connections",
      DEFAULT_DRAIN_TIMEOUT,
      join_set.len()
    );
    join_set.abort_all();
  }

  // Now close whatever is still open and stop accepting new connections.
  endpoint.close(0u32.into(), b"server shutting down");
  endpoint.wait_idle().await;
  tracing::info!("HTTP/3 server shut down gracefully");
  Ok(())
}

/// Handles a single HTTP/3 connection.
async fn handle_connection(
  conn: quinn::Connection,
  router: Arc<Router>,
  remote_addr: SocketAddr,
) -> Result<(), BoxError> {
  let mut h3_conn = h3::server::Connection::new(h3_quinn::Connection::new(conn)).await?;

  loop {
    match h3_conn.accept().await {
      Ok(Some(resolver)) => {
        let router = router.clone();
        tokio::spawn(async move {
          match resolver.resolve_request().await {
            Ok((req, stream)) => {
              if let Err(e) = handle_request(req, stream, router, remote_addr).await {
                tracing::error!("HTTP/3 request error: {e}");
              }
            }
            Err(e) => {
              tracing::error!("HTTP/3 request resolve error: {e}");
            }
          }
        });
      }
      Ok(None) => {
        break;
      }
      Err(e) => {
        tracing::error!("HTTP/3 accept error: {e}");
        break;
      }
    }
  }

  Ok(())
}

/// Handles a single HTTP/3 request.
async fn handle_request<S>(
  req: Request<()>,
  mut stream: RequestStream<S, Bytes>,
  router: Arc<Router>,
  remote_addr: SocketAddr,
) -> Result<(), BoxError>
where
  S: BidiStream<Bytes>,
{
  #[cfg(feature = "signals")]
  let path = req.uri().path().to_string();
  #[cfg(feature = "signals")]
  let method = req.method().to_string();

  #[cfg(feature = "signals")]
  {
    SignalArbiter::emit_app(
      Signal::with_capacity(ids::REQUEST_STARTED, 3)
        .meta("method", method.clone())
        .meta("path", path.clone())
        .meta("protocol", "h3"),
    )
    .await;
  }

  // Collect request body
  let mut body_bytes = Vec::new();
  while let Some(mut chunk) = stream.recv_data().await? {
    while chunk.has_remaining() {
      let bytes = chunk.chunk();
      body_bytes.extend_from_slice(bytes);
      chunk.advance(bytes.len());
    }
  }

  // Build request with body
  let (parts, _) = req.into_parts();
  let body = TakoBody::from(Bytes::from(body_bytes));
  let mut tako_req = Request::from_parts(parts, body);
  tako_req.extensions_mut().insert(remote_addr);

  // Dispatch through router
  let response = router.dispatch(tako_req).await;

  #[cfg(feature = "signals")]
  {
    SignalArbiter::emit_app(
      Signal::with_capacity(ids::REQUEST_COMPLETED, 4)
        .meta("method", method)
        .meta("path", path)
        .meta("status", response.status().as_u16().to_string())
        .meta("protocol", "h3"),
    )
    .await;
  }

  // Send response
  let (parts, body) = response.into_parts();
  let resp = http::Response::from_parts(parts, ());

  stream.send_response(resp).await?;

  // Stream body data frame by frame (supports SSE)
  let mut body = std::pin::pin!(body);
  while let Some(frame) = std::future::poll_fn(|cx| body.as_mut().poll_frame(cx)).await {
    match frame {
      Ok(frame) => {
        if let Some(data) = frame.data_ref().filter(|d| !d.is_empty()) {
          stream.send_data(data.clone()).await?;
        }
      }
      Err(e) => {
        tracing::error!("HTTP/3 body frame error: {e}");
        break;
      }
    }
  }

  stream.finish().await?;

  Ok(())
}

/// Loads TLS certificates from a PEM-encoded file.
///
/// Returns the certificates yielded by `rustls_pemfile::certs`, in the order they
/// are produced.
///
/// # Errors
///
/// Fails if the file cannot be opened or a certificate entry cannot be parsed. The
/// error message includes `path`.
pub fn load_certs(path: &str) -> anyhow::Result<Vec<CertificateDer<'static>>> {
  let file = File::open(path)
    .map_err(|err| anyhow::anyhow!("failed to open cert file '{}': {}", path, err))?;
  let mut reader = BufReader::new(file);

  let parsed: Result<Vec<CertificateDer<'static>>, _> = certs(&mut reader).collect();
  parsed.map_err(|err| anyhow::anyhow!("failed to parse certs from '{}': {}", path, err))
}

/// Loads a private key from a PEM-encoded file.
pub fn load_key(path: &str) -> anyhow::Result<PrivateKeyDer<'static>> {
  let mut rd = BufReader::new(
    File::open(path).map_err(|e| anyhow::anyhow!("failed to open key file '{}': {}", path, e))?,
  );
  pkcs8_private_keys(&mut rd)
    .next()
    .ok_or_else(|| anyhow::anyhow!("no private key found in '{}'", path))?
    .map(|k| k.into())
    .map_err(|e| anyhow::anyhow!("bad private key in '{}': {}", path, e))
}