tako-rs 1.1.2

Multi-transport Rust framework for modern network services.
Documentation
#![cfg(feature = "tls")]
#![cfg_attr(docsrs, doc(cfg(feature = "tls")))]

//! TLS-enabled HTTP server implementation for secure connections.
//!
//! This module provides TLS/SSL support for Tako web servers using rustls for encryption.
//! It handles secure connection establishment, certificate loading, and supports both
//! HTTP/1.1 and HTTP/2 protocols (when the http2 feature is enabled). The main entry
//! point is `serve_tls` which starts a secure server with the provided certificates.
//!
//! # Examples
//!
//! ```rust,no_run
//! # #[cfg(feature = "tls")]
//! use tako::{serve_tls, router::Router, Method, responder::Responder, types::Request};
//! # #[cfg(feature = "tls")]
//! use tokio::net::TcpListener;
//!
//! # #[cfg(feature = "tls")]
//! async fn hello(_: Request) -> impl Responder {
//!     "Hello, Secure World!".into_response()
//! }
//!
//! # #[cfg(feature = "tls")]
//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
//! let listener = TcpListener::bind("127.0.0.1:8443").await?;
//! let mut router = Router::new();
//! router.route(Method::GET, "/", hello);
//! serve_tls(listener, router, Some("cert.pem"), Some("key.pem")).await;
//! # Ok(())
//! # }
//! ```

use std::convert::Infallible;
use std::fs::File;
use std::future::Future;
use std::io::BufReader;
use std::sync::Arc;
use std::time::Duration;

use hyper::server::conn::http1;
#[cfg(feature = "http2")]
use hyper::server::conn::http2;
use hyper::service::service_fn;
#[cfg(feature = "http2")]
use hyper_util::rt::TokioExecutor;
use hyper_util::rt::TokioIo;
use rustls::pki_types::CertificateDer;
use rustls::pki_types::PrivateKeyDer;
use rustls_pemfile::certs;
use rustls_pemfile::pkcs8_private_keys;
use tokio::net::TcpListener;
use tokio::task::JoinSet;
use tokio_rustls::TlsAcceptor;
use tokio_rustls::rustls::ServerConfig;

use crate::body::TakoBody;
use crate::router::Router;
#[cfg(feature = "signals")]
use crate::signals::Signal;
#[cfg(feature = "signals")]
use crate::signals::SignalArbiter;
#[cfg(feature = "signals")]
use crate::signals::ids;
use crate::types::BoxError;

/// Default drain timeout for graceful shutdown (30 seconds).
///
/// After the shutdown signal fires, `run` waits at most this long for the
/// spawned connection tasks to finish before aborting whatever is left.
const DEFAULT_DRAIN_TIMEOUT: Duration = Duration::from_secs(30);

/// Starts a TLS-enabled HTTP server with the given listener, router, and certificates.
///
/// This is a convenience wrapper around [`run`] that supplies no shutdown
/// signal. `certs` and `key` are file system paths to the PEM-encoded
/// certificate chain and private key; when `None`, [`run`] falls back to
/// `cert.pem` and `key.pem` respectively.
///
/// Any error returned by [`run`] is logged through `tracing::error!` and is
/// not propagated to the caller.
pub async fn serve_tls(
  listener: TcpListener,
  router: Router,
  certs: Option<&str>,
  key: Option<&str>,
) {
  // `run` is generic over the shutdown future, so a concrete (never-resolving)
  // type has to be named even though no signal is passed.
  let no_signal: Option<std::future::Pending<()>> = None;

  match run(listener, router, certs, key, no_signal).await {
    Ok(()) => {}
    Err(e) => {
      tracing::error!("TLS server error: {e}");
    }
  }
}

/// Starts a TLS-enabled HTTP server with graceful shutdown support.
///
/// Behaves like [`serve_tls`], except that the server stops accepting new
/// connections once `signal` resolves; [`run`] then drains the in-flight
/// connections before returning.
///
/// Any error returned by [`run`] is logged through `tracing::error!` and is
/// not propagated to the caller.
pub async fn serve_tls_with_shutdown(
  listener: TcpListener,
  router: Router,
  certs: Option<&str>,
  key: Option<&str>,
  signal: impl Future<Output = ()>,
) {
  let outcome = run(listener, router, certs, key, Some(signal)).await;

  match outcome {
    Ok(()) => {}
    Err(e) => {
      tracing::error!("TLS server error: {e}");
    }
  }
}

/// Runs the TLS server loop, handling secure connections and request dispatch.
pub async fn run(
  listener: TcpListener,
  router: Router,
  certs: Option<&str>,
  key: Option<&str>,
  signal: Option<impl Future<Output = ()>>,
) -> Result<(), BoxError> {
  #[cfg(feature = "tako-tracing")]
  crate::tracing::init_tracing();

  let certs = load_certs(certs.unwrap_or("cert.pem"))?;
  let key = load_key(key.unwrap_or("key.pem"))?;

  let mut config = ServerConfig::builder()
    .with_no_client_auth()
    .with_single_cert(certs, key)?;

  #[cfg(feature = "http2")]
  {
    config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
  }

  #[cfg(not(feature = "http2"))]
  {
    config.alpn_protocols = vec![b"http/1.1".to_vec()];
  }

  let acceptor = TlsAcceptor::from(Arc::new(config));
  let router = Arc::new(router);

  // Setup plugins
  #[cfg(feature = "plugins")]
  router.setup_plugins_once();

  let addr_str = listener.local_addr()?.to_string();

  #[cfg(feature = "signals")]
  {
    // Emit server.started (TLS)
    SignalArbiter::emit_app(
      Signal::with_capacity(ids::SERVER_STARTED, 3)
        .meta("addr", addr_str.clone())
        .meta("transport", "tcp")
        .meta("tls", "true"),
    )
    .await;
  }

  tracing::info!("Tako TLS listening on {}", addr_str);

  let mut join_set = JoinSet::new();
  let signal = signal.map(|s| Box::pin(s));
  let signal_fused = async {
    if let Some(s) = signal {
      s.await;
    } else {
      std::future::pending::<()>().await;
    }
  };
  tokio::pin!(signal_fused);

  loop {
    tokio::select! {
      result = listener.accept() => {
        let (stream, addr) = result?;
        let _ = stream.set_nodelay(true);
        let acceptor = acceptor.clone();
        let router = router.clone();

        join_set.spawn(async move {
          let tls_stream = match acceptor.accept(stream).await {
            Ok(s) => s,
            Err(e) => {
              tracing::error!("TLS error: {e}");
              return;
            }
          };

          #[cfg(feature = "signals")]
          {
            SignalArbiter::emit_app(
              Signal::with_capacity(ids::CONNECTION_OPENED, 2)
                .meta("remote_addr", addr.to_string())
                .meta("tls", "true"),
            )
            .await;
          }

          #[cfg(feature = "http2")]
          let proto = tls_stream.get_ref().1.alpn_protocol().map(|p| p.to_vec());

          let io = TokioIo::new(tls_stream);
          let svc = service_fn(move |mut req| {
            let r = router.clone();
            async move {
              #[cfg(feature = "signals")]
              let path = req.uri().path().to_string();
              #[cfg(feature = "signals")]
              let method = req.method().to_string();

              req.extensions_mut().insert(addr);

              #[cfg(feature = "signals")]
              {
                SignalArbiter::emit_app(
                  Signal::with_capacity(ids::REQUEST_STARTED, 2)
                    .meta("method", method.clone())
                    .meta("path", path.clone()),
                )
                .await;
              }

              let response = r.dispatch(req.map(TakoBody::incoming)).await;

              #[cfg(feature = "signals")]
              {
                SignalArbiter::emit_app(
                  Signal::with_capacity(ids::REQUEST_COMPLETED, 3)
                    .meta("method", method)
                    .meta("path", path)
                    .meta("status", response.status().as_u16().to_string()),
                )
                .await;
              }

              Ok::<_, Infallible>(response)
            }
          });

          #[cfg(feature = "http2")]
          if proto.as_deref() == Some(b"h2") {
            let h2 = http2::Builder::new(TokioExecutor::new());

            if let Err(e) = h2.serve_connection(io, svc).await {
              tracing::error!("HTTP/2 error: {e}");
            }

            #[cfg(feature = "signals")]
            {
              SignalArbiter::emit_app(
                Signal::with_capacity(ids::CONNECTION_CLOSED, 2)
                  .meta("remote_addr", addr.to_string())
                  .meta("tls", "true"),
              )
              .await;
            }
            return;
          }

          let mut h1 = http1::Builder::new();
          h1.keep_alive(true);

          if let Err(e) = h1.serve_connection(io, svc).with_upgrades().await {
            tracing::error!("HTTP/1.1 error: {e}");
          }

          #[cfg(feature = "signals")]
          {
            SignalArbiter::emit_app(
              Signal::with_capacity(ids::CONNECTION_CLOSED, 2)
                .meta("remote_addr", addr.to_string())
                .meta("tls", "true"),
            )
            .await;
          }
        });
      }
      () = &mut signal_fused => {
        tracing::info!("Shutdown signal received, draining TLS connections...");
        break;
      }
    }
  }

  // Drain in-flight connections
  let drain = tokio::time::timeout(DEFAULT_DRAIN_TIMEOUT, async {
    while join_set.join_next().await.is_some() {}
  });

  if drain.await.is_err() {
    tracing::warn!(
      "Drain timeout ({:?}) exceeded, aborting {} remaining TLS connections",
      DEFAULT_DRAIN_TIMEOUT,
      join_set.len()
    );
    join_set.abort_all();
  }

  tracing::info!("TLS server shut down gracefully");
  Ok(())
}

/// Loads TLS certificates from a PEM-encoded file.
///
/// Reads and parses X.509 certificates from the specified file path. The file
/// should contain one or more PEM-encoded certificates.
///
/// # Arguments
///
/// * `path` - File system path to the certificate file
///
/// # Errors
///
/// Returns an error if the file cannot be opened, if its PEM content cannot be
/// parsed, or if it does not contain any certificate.
///
/// # Examples
///
/// ```rust,no_run
/// # #[cfg(feature = "tls")]
/// use tako::server_tls::load_certs;
///
/// # #[cfg(feature = "tls")]
/// # fn example() -> Result<(), Box<dyn std::error::Error>> {
/// let certs = load_certs("server.crt")?;
/// println!("Loaded {} certificates", certs.len());
/// # Ok(())
/// # }
/// ```
pub fn load_certs(path: &str) -> anyhow::Result<Vec<CertificateDer<'static>>> {
  let mut rd = BufReader::new(
    File::open(path).map_err(|e| anyhow::anyhow!("failed to open cert file '{}': {}", path, e))?,
  );
  let chain = certs(&mut rd)
    .collect::<Result<Vec<_>, _>>()
    .map_err(|e| anyhow::anyhow!("failed to parse certs from '{}': {}", path, e))?;

  // A file without any certificate section (for instance a key file passed by
  // mistake) parses "successfully" into an empty chain; report that here,
  // where the offending path is known, mirroring `load_key`'s behaviour.
  if chain.is_empty() {
    anyhow::bail!("no certificates found in '{}'", path);
  }

  Ok(chain)
}

/// Loads a private key from a PEM-encoded file.
///
/// Reads and parses a PKCS#8 private key from the specified file path. The file
/// should contain a single PEM-encoded private key.
///
/// # Arguments
///
/// * `path` - File system path to the private key file
///
/// # Panics
///
/// Panics if the file cannot be opened, read, if no private key is found,
/// or if the private key is malformed or invalid.
///
/// # Examples
///
/// ```rust,no_run
/// # #[cfg(feature = "tls")]
/// use tako::server_tls::load_key;
///
/// # #[cfg(feature = "tls")]
/// # fn example() {
/// let key = load_key("server.key");
/// println!("Loaded private key successfully");
/// # }
/// ```
pub fn load_key(path: &str) -> anyhow::Result<PrivateKeyDer<'static>> {
  let mut rd = BufReader::new(
    File::open(path).map_err(|e| anyhow::anyhow!("failed to open key file '{}': {}", path, e))?,
  );
  pkcs8_private_keys(&mut rd)
    .next()
    .ok_or_else(|| anyhow::anyhow!("no private key found in '{}'", path))?
    .map(|k| k.into())
    .map_err(|e| anyhow::anyhow!("bad private key in '{}': {}", path, e))
}