use crate::error::{Error, Result};
use crate::health::{self, HealthResponder};
use crate::middleware::DefaultStack;
use crate::server::Server;
use axum::Router;
use axum::handler::Handler;
use axum::response::IntoResponse;
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::net::TcpListener;
use tower_http::cors::CorsLayer;
/// Bind address used by `ServerBuilder::default()`: port 8080 on the IPv4
/// unspecified address `0.0.0.0`.
const DEFAULT_BIND_ADDR: &str = "0.0.0.0:8080";
/// Timeout value used by `ServerBuilder::default()` (30 seconds).
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);
/// Health-check path used by `ServerBuilder::default()`.
const DEFAULT_HEALTH_PATH: &str = "/health";
/// Consuming builder that collects listener, routing, middleware and
/// health-check settings and turns them into a `Server` via `build`.
///
/// Every setter takes `self` by value and returns it, so calls are meant to
/// be chained; `#[must_use]` warns when a builder value is dropped unused.
#[must_use]
// The four `bool` fields are independent on/off switches (tracing, request id,
// compression, health), which is why the excessive-bools lint is silenced.
#[allow(clippy::struct_excessive_bools)] pub struct ServerBuilder {
/// Listen address in string form; parsed into a `SocketAddr` by `build`.
bind_addr: String,
/// Router accumulated through `route`, `merge` and `nest`.
router: Router<()>,
/// Forwarded to `DefaultStack::tracing` by `build`; defaults to `true`.
tracing: bool,
/// Forwarded to `DefaultStack::request_id` by `build`; defaults to `true`.
request_id: bool,
/// Forwarded to `DefaultStack::timeout` by `build`; defaults to `DEFAULT_TIMEOUT`.
timeout: Duration,
/// Optional CORS layer forwarded to `DefaultStack::cors`; `None` by default.
cors: Option<CorsLayer>,
/// Forwarded to `DefaultStack::compression` by `build`; defaults to `false`.
compression: bool,
/// Passed to `health::install` as the enabled flag; defaults to `true`.
health_enabled: bool,
/// Path passed to `health::install`; defaults to `DEFAULT_HEALTH_PATH`.
health_path: String,
/// Responder passed to `health::install`; defaults to `health::default_responder()`.
health_responder: HealthResponder,
}
impl Default for ServerBuilder {
    /// Builds the baseline configuration.
    ///
    /// Defaults: listen on `DEFAULT_BIND_ADDR`, empty router, tracing and
    /// request-id flags on, `DEFAULT_TIMEOUT` as the timeout, no CORS layer,
    /// compression off, and the health check enabled at `DEFAULT_HEALTH_PATH`
    /// with the responder returned by `health::default_responder()`.
    fn default() -> Self {
        // Owned copies of the string defaults; setters may replace them later.
        let bind_addr = String::from(DEFAULT_BIND_ADDR);
        let health_path = String::from(DEFAULT_HEALTH_PATH);

        Self {
            bind_addr,
            router: Router::new(),
            tracing: true,
            request_id: true,
            timeout: DEFAULT_TIMEOUT,
            cors: None,
            compression: false,
            health_enabled: true,
            health_path,
            health_responder: health::default_responder(),
        }
    }
}
impl ServerBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn bind_addr(mut self, addr: impl Into<String>) -> Self {
self.bind_addr = addr.into();
self
}
pub fn bind_socket(mut self, addr: SocketAddr) -> Self {
self.bind_addr = addr.to_string();
self
}
pub fn route<H, T>(mut self, path: &str, handler: H) -> Self
where
H: Handler<T, ()>,
T: 'static,
{
self.router = self.router.route(path, axum::routing::any(handler));
self
}
pub fn merge(mut self, other: Router) -> Self {
self.router = self.router.merge(other);
self
}
pub fn nest(mut self, prefix: &str, router: Router) -> Self {
self.router = self.router.nest(prefix, router);
self
}
pub fn request_timeout(mut self, d: Duration) -> Self {
self.timeout = d;
self
}
pub fn disable_tracing(mut self) -> Self {
self.tracing = false;
self
}
pub fn disable_request_id(mut self) -> Self {
self.request_id = false;
self
}
pub fn enable_cors(mut self) -> Self {
self.cors = Some(CorsLayer::permissive());
self
}
pub fn enable_cors_with(mut self, layer: CorsLayer) -> Self {
self.cors = Some(layer);
self
}
pub fn enable_compression(mut self) -> Self {
self.compression = true;
self
}
pub fn health_path(mut self, path: &str) -> Self {
self.health_path = path.to_string();
self
}
pub fn health_response<F, R>(mut self, responder: F) -> Self
where
F: Fn() -> R + Send + Sync + 'static,
R: IntoResponse + 'static,
{
self.health_responder = Arc::new(move || responder().into_response());
self
}
pub fn disable_health(mut self) -> Self {
self.health_enabled = false;
self
}
pub async fn build(self) -> Result<Server> {
let addr: SocketAddr = self.bind_addr.parse().map_err(|e| {
Error::Configuration(format!("invalid bind address '{}': {e}", self.bind_addr))
})?;
let listener = TcpListener::bind(addr).await.map_err(|e| Error::Bind {
addr: self.bind_addr.clone(),
source: e,
})?;
let local_addr = listener.local_addr().map_err(Error::from)?;
let router = health::install(
self.router,
self.health_enabled,
&self.health_path,
self.health_responder,
);
let stack = DefaultStack {
tracing: self.tracing,
request_id: self.request_id,
timeout: self.timeout,
cors: self.cors,
compression: self.compression,
};
let router = stack.apply(router);
Ok(Server::from_parts(router, listener, local_addr))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Loopback address with port 0, so each test gets whatever free port the
    /// OS hands out instead of competing for a fixed one.
    const LOOPBACK_ANY_PORT: &str = "127.0.0.1:0";

    /// A default builder pointed at the loopback address binds successfully
    /// and reports a concrete, non-zero port.
    #[tokio::test]
    async fn build_with_defaults_binds_ephemeral_port() {
        let builder = ServerBuilder::new().bind_addr(LOOPBACK_ANY_PORT);
        let server = builder.build().await.unwrap();

        let bound = server.local_addr();
        assert_eq!(bound.ip().to_string(), "127.0.0.1");
        assert!(bound.port() > 0);
    }

    /// A bind address that is not an `ip:port` literal is rejected as a
    /// configuration error.
    #[tokio::test]
    async fn build_rejects_invalid_bind_address() {
        let builder = ServerBuilder::new().bind_addr("not a socket address");
        let outcome = builder.build().await;

        assert!(matches!(outcome, Err(Error::Configuration(_))));
    }

    /// `bind_socket` accepts a pre-parsed `SocketAddr` and binds to it.
    #[tokio::test]
    async fn build_with_bind_socket_works() {
        let requested = LOOPBACK_ANY_PORT.parse::<SocketAddr>().unwrap();
        let builder = ServerBuilder::new().bind_socket(requested);
        let server = builder.build().await.unwrap();

        assert_eq!(server.local_addr().ip().to_string(), "127.0.0.1");
    }

    /// Overriding the request timeout does not prevent the server from building.
    #[tokio::test]
    async fn build_with_custom_timeout() {
        let builder = ServerBuilder::new()
            .bind_addr(LOOPBACK_ANY_PORT)
            .request_timeout(Duration::from_secs(5));
        let server = builder.build().await.unwrap();

        let _ = server.local_addr();
    }
}