use crate::http_server::{HttpServerConfig, HttpServerError, Result};
use axum::{Router, http::StatusCode, response::IntoResponse, routing::get};
use std::future::IntoFuture;
use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use tokio::net::TcpListener;
#[cfg(not(feature = "shutdown"))]
use tokio::signal;
use tokio::sync::watch;
use tower::limit::ConcurrencyLimitLayer;
use tower_http::timeout::TimeoutLayer;
use tower_http::trace::TraceLayer;
/// HTTP server built on `axum::serve` with optional health endpoints and
/// graceful shutdown.
///
/// Holds its [`HttpServerConfig`] together with a shared readiness flag that
/// the health routes report and that is cleared when shutdown begins.
pub struct HttpServer {
    /// Readiness flag shared with the readiness routes (and, with the `health`
    /// feature, with the check registered in the health registry).
    ready: Arc<AtomicBool>,
    /// Bind address, timeouts, request limit, and endpoint toggles.
    config: HttpServerConfig,
}
impl HttpServer {
#[must_use]
pub fn new(config: HttpServerConfig) -> Self {
let ready = Arc::new(AtomicBool::new(true));
#[cfg(feature = "health")]
{
let r = Arc::clone(&ready);
crate::health::HealthRegistry::register("http_server", move || {
if r.load(Ordering::Relaxed) {
crate::health::HealthStatus::Healthy
} else {
crate::health::HealthStatus::Unhealthy
}
});
}
Self { config, ready }
}
#[must_use]
pub fn bind(address: impl Into<String>) -> Self {
Self::new(HttpServerConfig::new(address))
}
pub fn set_ready(&self, ready: bool) {
self.ready.store(ready, Ordering::SeqCst);
}
#[must_use]
pub fn is_ready(&self) -> bool {
self.ready.load(Ordering::SeqCst)
}
#[must_use]
pub fn ready_flag(&self) -> Arc<AtomicBool> {
Arc::clone(&self.ready)
}
pub async fn serve(self, app: Router) -> Result<()> {
#[cfg(feature = "shutdown")]
{
let token = crate::shutdown::install_signal_handler();
self.serve_with_shutdown(app, token.cancelled_owned()).await
}
#[cfg(not(feature = "shutdown"))]
{
self.serve_with_shutdown(app, shutdown_signal()).await
}
}
pub async fn serve_with_shutdown<F>(self, app: Router, shutdown: F) -> Result<()>
where
F: std::future::Future<Output = ()> + Send + 'static,
{
self.config.validate().map_err(HttpServerError::TlsConfig)?;
let shutdown_timeout = self.config.shutdown_timeout();
let app = self.build_router(app);
let addr: SocketAddr =
self.config
.bind_address
.parse()
.map_err(|e| HttpServerError::Bind {
address: self.config.bind_address.clone(),
source: std::io::Error::new(std::io::ErrorKind::InvalidInput, e),
})?;
let listener = TcpListener::bind(addr)
.await
.map_err(|e| HttpServerError::Bind {
address: self.config.bind_address.clone(),
source: e,
})?;
#[cfg(feature = "logger")]
tracing::info!(address = %addr, "HTTP server listening");
let (drain_started_tx, drain_started_rx) = tokio::sync::oneshot::channel();
let drain_started_tx = std::sync::Mutex::new(Some(drain_started_tx));
let ready_for_signal = Arc::clone(&self.ready);
let signal = async move {
shutdown.await;
ready_for_signal.store(false, Ordering::SeqCst);
if let Some(tx) = drain_started_tx.lock().ok().and_then(|mut g| g.take()) {
let _ = tx.send(());
}
};
let serve = axum::serve(listener, app)
.with_graceful_shutdown(signal)
.into_future();
tokio::pin!(serve);
tokio::select! {
result = &mut serve => result.map_err(HttpServerError::Io)?,
() = async {
let _ = drain_started_rx.await;
tokio::time::sleep(shutdown_timeout).await;
} => {
#[cfg(feature = "logger")]
tracing::warn!(
timeout_ms = u64::try_from(shutdown_timeout.as_millis()).unwrap_or(u64::MAX),
"HTTP server graceful drain timed out -- forcing exit"
);
}
}
#[cfg(feature = "logger")]
tracing::info!("HTTP server shut down gracefully");
Ok(())
}
pub async fn serve_with_handle(self, app: Router) -> Result<(ShutdownHandle, ServerFuture)> {
self.config.validate().map_err(HttpServerError::TlsConfig)?;
let (tx, rx) = watch::channel(());
let handle = ShutdownHandle { sender: tx };
let shutdown_timeout = self.config.shutdown_timeout();
let app = self.build_router(app);
let addr: SocketAddr =
self.config
.bind_address
.parse()
.map_err(|e| HttpServerError::Bind {
address: self.config.bind_address.clone(),
source: std::io::Error::new(std::io::ErrorKind::InvalidInput, e),
})?;
let listener = TcpListener::bind(addr)
.await
.map_err(|e| HttpServerError::Bind {
address: self.config.bind_address.clone(),
source: e,
})?;
#[cfg(feature = "logger")]
tracing::info!(address = %addr, "HTTP server listening");
let (drain_started_tx, drain_started_rx) = tokio::sync::oneshot::channel();
let drain_started_tx = std::sync::Mutex::new(Some(drain_started_tx));
let ready_for_signal = Arc::clone(&self.ready);
let signal = async move {
let _ = rx.clone().changed().await;
ready_for_signal.store(false, Ordering::SeqCst);
if let Some(tx) = drain_started_tx.lock().ok().and_then(|mut g| g.take()) {
let _ = tx.send(());
}
};
let future = ServerFuture {
inner: Box::pin(async move {
let serve = axum::serve(listener, app)
.with_graceful_shutdown(signal)
.into_future();
tokio::pin!(serve);
tokio::select! {
result = &mut serve => result.map_err(HttpServerError::Io)?,
() = async {
let _ = drain_started_rx.await;
tokio::time::sleep(shutdown_timeout).await;
} => {
#[cfg(feature = "logger")]
tracing::warn!(
timeout_ms = u64::try_from(shutdown_timeout.as_millis()).unwrap_or(u64::MAX),
"HTTP server graceful drain timed out -- forcing exit"
);
}
}
Ok(())
}),
};
Ok((handle, future))
}
fn build_router(&self, app: Router) -> Router {
let mut router = app;
if self.config.enable_health_endpoints {
let ready = Arc::clone(&self.ready);
let r1 = Arc::clone(&ready);
let r2 = Arc::clone(&ready);
router = router
.route("/health/live", get(health_live))
.route("/health/ready", get(move || health_ready(Arc::clone(&r1))))
.route("/healthz", get(health_live))
.route("/readyz", get(move || health_ready(Arc::clone(&r2))));
}
#[cfg(all(feature = "health", feature = "serde_json"))]
if self.config.enable_health_endpoints {
router = router.route("/health/detailed", get(health_detailed));
}
#[cfg(feature = "config")]
if self.config.enable_config_endpoint {
router = router.route("/config", get(config_dump));
}
router
.layer(TraceLayer::new_for_http())
.layer(TimeoutLayer::with_status_code(
StatusCode::REQUEST_TIMEOUT,
self.config.request_timeout(),
))
.layer(ConcurrencyLimitLayer::new(self.config.max_connections))
}
}
/// Handle used to stop a server started with [`HttpServer::serve_with_handle`].
///
/// Clones share the same underlying watch channel. The server waits on
/// `watch::Receiver::changed`, which also completes (with an error that is
/// ignored) once every sender has been dropped. Dropping the last handle
/// without calling [`ShutdownHandle::shutdown`] therefore starts a graceful
/// shutdown as well.
#[derive(Clone, Debug)]
pub struct ShutdownHandle {
    sender: watch::Sender<()>,
}
impl ShutdownHandle {
    /// Signals the server to begin a graceful shutdown.
    ///
    /// The send result is ignored. Sending only fails when no receiver
    /// remains, i.e. the shutdown future holding it no longer exists.
    pub fn shutdown(self) {
        let _ = self.sender.send(());
    }
}
/// Future that drives a server started with [`HttpServer::serve_with_handle`].
///
/// The wrapped work is lazy: connections are not accepted until this future
/// is polled. It resolves when the server stops, whether after a graceful
/// drain, after the drain timeout elapses, or with an error.
#[must_use = "the server does not run unless this future is awaited or spawned"]
pub struct ServerFuture {
    inner: std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>,
}
impl std::future::Future for ServerFuture {
    type Output = Result<()>;
    fn poll(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        // `Pin<Box<_>>` is `Unpin`, so reaching the field through `&mut Self` is sound.
        self.inner.as_mut().poll(cx)
    }
}
/// Liveness handler: always answers `200 OK` with the plain-text body `OK`.
async fn health_live() -> impl IntoResponse {
    let body = "OK";
    (StatusCode::OK, body)
}
/// Readiness handler: `200 OK` only when this server's ready flag is set and,
/// with the `health` feature, the health registry also reports ready.
/// Otherwise it answers `503 Service Unavailable` with `NOT READY`.
async fn health_ready(ready: Arc<AtomicBool>) -> impl IntoResponse {
    let serving = ready.load(Ordering::SeqCst);
    #[cfg(feature = "health")]
    let registry_ok = crate::health::HealthRegistry::is_ready();
    #[cfg(not(feature = "health"))]
    let registry_ok = true;
    match (serving, registry_ok) {
        (true, true) => (StatusCode::OK, "OK"),
        _ => (StatusCode::SERVICE_UNAVAILABLE, "NOT READY"),
    }
}
/// Detailed health handler: returns the health registry's JSON report as the
/// response body.
#[cfg(all(feature = "health", feature = "serde_json"))]
async fn health_detailed() -> impl IntoResponse {
    axum::Json(crate::health::HealthRegistry::to_json())
}
/// Handler for `/config`: a pretty-printed JSON document containing the
/// effective configuration, the defaults, and the key and type name of every
/// registered configuration section.
///
/// NOTE(review): this exposes the effective configuration over HTTP; confirm
/// `dump_effective` redacts secrets before enabling the endpoint publicly.
#[cfg(feature = "config")]
async fn config_dump() -> impl IntoResponse {
    use crate::config::registry;

    let effective = registry::dump_effective();
    let defaults = registry::dump_defaults();
    let sections: Vec<_> = registry::sections()
        .iter()
        .map(|section| {
            serde_json::json!({
                "key": section.key,
                "type": section.type_name,
            })
        })
        .collect();
    let body = serde_json::json!({
        "effective": effective,
        "defaults": defaults,
        "sections": sections,
    });
    let rendered = serde_json::to_string_pretty(&body).unwrap_or_default();
    (
        StatusCode::OK,
        [("content-type", "application/json")],
        rendered,
    )
}
/// Resolves when the process receives Ctrl+C or, on Unix, SIGTERM.
///
/// If a handler cannot be installed, the failure is logged (with the `logger`
/// feature) and that source is ignored instead of panicking. A panic here
/// would end the shutdown future early, which axum treats as a shutdown
/// request, so the server would stop immediately. With a source ignored, the
/// other source still works, and a signal whose handler failed keeps the OS
/// default disposition.
#[cfg(not(feature = "shutdown"))]
async fn shutdown_signal() {
    let ctrl_c = async {
        if let Err(err) = signal::ctrl_c().await {
            #[cfg(feature = "logger")]
            tracing::error!(
                error = %err,
                "failed to install Ctrl+C handler; ignoring this shutdown source"
            );
            #[cfg(not(feature = "logger"))]
            let _ = err;
            std::future::pending::<()>().await;
        }
    };
    #[cfg(unix)]
    let terminate = async {
        match signal::unix::signal(signal::unix::SignalKind::terminate()) {
            Ok(mut stream) => {
                stream.recv().await;
            }
            Err(err) => {
                #[cfg(feature = "logger")]
                tracing::error!(
                    error = %err,
                    "failed to install SIGTERM handler; ignoring this shutdown source"
                );
                #[cfg(not(feature = "logger"))]
                let _ = err;
                std::future::pending::<()>().await;
            }
        }
    };
    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();
    tokio::select! {
        () = ctrl_c => {},
        () = terminate => {},
    }
    #[cfg(feature = "logger")]
    tracing::info!("Shutdown signal received, starting graceful shutdown");
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::Request;
    use tower::ServiceExt;

    /// Port 0 asks the OS for a free port, so tests that bind a real listener
    /// cannot collide with each other or with other processes on the machine.
    const EPHEMERAL_ADDR: &str = "127.0.0.1:0";

    /// Sends an empty-bodied request for `path` through `app` and returns the
    /// response status.
    async fn status_of(app: Router, path: &str) -> StatusCode {
        app.oneshot(Request::builder().uri(path).body(Body::empty()).unwrap())
            .await
            .unwrap()
            .status()
    }

    #[tokio::test]
    async fn test_health_live() {
        let server = HttpServer::new(HttpServerConfig::default());
        let app = server.build_router(Router::new());
        assert_eq!(status_of(app, "/health/live").await, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_health_ready_when_ready() {
        #[cfg(feature = "health")]
        crate::health::HealthRegistry::reset();
        let server = HttpServer::new(HttpServerConfig::default());
        server.set_ready(true);
        let app = server.build_router(Router::new());
        assert_eq!(status_of(app, "/health/ready").await, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_health_ready_when_not_ready() {
        let server = HttpServer::new(HttpServerConfig::default());
        server.set_ready(false);
        let app = server.build_router(Router::new());
        assert_eq!(
            status_of(app, "/health/ready").await,
            StatusCode::SERVICE_UNAVAILABLE
        );
    }

    #[tokio::test]
    async fn test_server_with_handle() {
        let server = HttpServer::new(HttpServerConfig::new(EPHEMERAL_ADDR));
        let app = Router::new().route("/", get(|| async { "Hello" }));
        let (handle, future) = server.serve_with_handle(app).await.unwrap();
        handle.shutdown();
        future.await.unwrap();
    }

    #[tokio::test]
    async fn k8s_standard_health_paths_are_mounted() {
        // NOTE(review): `HealthRegistry` appears to be process-global (it is
        // used through associated functions and has a `reset`), and every
        // `HttpServer::new` registers its own ready flag in it. Tests that
        // clear a server's ready flag can therefore race with this one.
        // Resetting first matches `test_health_ready_when_ready` but does not
        // remove that race.
        #[cfg(feature = "health")]
        crate::health::HealthRegistry::reset();
        let server = HttpServer::new(HttpServerConfig::default());
        let app = server.build_router(Router::new());
        for path in ["/healthz", "/readyz"] {
            assert_eq!(
                status_of(app.clone(), path).await,
                StatusCode::OK,
                "path={path}"
            );
        }
    }

    #[tokio::test]
    async fn shutdown_signal_flips_ready_before_drain() {
        let server = HttpServer::new(HttpServerConfig::new(EPHEMERAL_ADDR));
        let ready = server.ready_flag();
        assert!(ready.load(Ordering::SeqCst), "ready starts true");
        let app = Router::new().route(
            "/slow",
            get(|| async {
                tokio::time::sleep(std::time::Duration::from_millis(200)).await;
                "done"
            }),
        );
        let (handle, future) = server.serve_with_handle(app).await.unwrap();
        let server_task = tokio::spawn(future);
        handle.shutdown();
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;
        assert!(
            !ready.load(Ordering::SeqCst),
            "ready must flip to false post-shutdown",
        );
        let _ = server_task.await;
    }

    #[test]
    fn test_ready_flag() {
        let server = HttpServer::new(HttpServerConfig::default());
        assert!(server.is_ready());
        server.set_ready(false);
        assert!(!server.is_ready());
        server.set_ready(true);
        assert!(server.is_ready());
    }
}