Skip to main content

altair_server/
builder.rs

1//! Typed builder for [`crate::Server`].
2
3use crate::error::{Error, Result};
4use crate::health::{self, HealthResponder};
5use crate::middleware::DefaultStack;
6use crate::server::Server;
7use axum::Router;
8use axum::handler::Handler;
9use axum::response::IntoResponse;
10use std::net::SocketAddr;
11use std::sync::Arc;
12use std::time::Duration;
13use tokio::net::TcpListener;
14use tower_http::cors::CorsLayer;
15
/// Address the listener binds to unless `bind_addr`/`bind_socket` overrides it.
const DEFAULT_BIND_ADDR: &str = "0.0.0.0:8080";
/// Per-request deadline used unless `request_timeout` overrides it.
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Route at which the built-in health endpoint is mounted by default.
const DEFAULT_HEALTH_PATH: &str = "/health";
/// One mebibyte (2^20 bytes).
const MIB: usize = 1 << 20;
/// Default request body size limit: 2 MiB, the same size as axum's built-in
/// limit, but enforced by the tower-http middleware so callers can change it.
///
/// NOTE(review): raising it above 2 MiB only takes effect if axum's own
/// `DefaultBodyLimit` is disabled or raised as well — confirm in `DefaultStack`.
const DEFAULT_BODY_LIMIT_BYTES: usize = 2 * MIB;
22
23/// Typed builder for [`Server`].
24///
25/// Construct via [`Server::builder`](crate::Server::builder).
26///
27/// # Defaults
28///
29/// - bind address: `0.0.0.0:8080`
30/// - request timeout: 30s (applied via `tower_http::timeout::TimeoutLayer`)
31/// - request body limit: 2 MiB (via [`Self::request_body_limit`])
32/// - tracing, request-id, health endpoint at `/health`: enabled
33/// - CORS, compression, shutdown timeout: disabled / unset
34#[must_use]
35#[allow(clippy::struct_excessive_bools)] // each toggle is an independent middleware knob
36pub struct ServerBuilder {
37    bind_addr: String,
38    router: Router<()>,
39    tracing: bool,
40    request_id: bool,
41    timeout: Duration,
42    body_limit: usize,
43    cors: Option<CorsLayer>,
44    compression: bool,
45    health_enabled: bool,
46    health_path: String,
47    health_responder: HealthResponder,
48    shutdown_timeout: Option<Duration>,
49}
50
51impl Default for ServerBuilder {
52    fn default() -> Self {
53        Self {
54            bind_addr: DEFAULT_BIND_ADDR.to_string(),
55            router: Router::new(),
56            tracing: true,
57            request_id: true,
58            timeout: DEFAULT_TIMEOUT,
59            body_limit: DEFAULT_BODY_LIMIT_BYTES,
60            cors: None,
61            compression: false,
62            health_enabled: true,
63            health_path: DEFAULT_HEALTH_PATH.to_string(),
64            health_responder: health::default_responder(),
65            shutdown_timeout: None,
66        }
67    }
68}
69
70impl ServerBuilder {
71    /// Create a builder with defaults.
72    pub fn new() -> Self {
73        Self::default()
74    }
75
76    /// Set the bind address as a string (e.g. `"0.0.0.0:3000"`, `"[::]:8080"`).
77    pub fn bind_addr(mut self, addr: impl Into<String>) -> Self {
78        self.bind_addr = addr.into();
79        self
80    }
81
82    /// Set the bind address as a parsed `SocketAddr`.
83    pub fn bind_socket(mut self, addr: SocketAddr) -> Self {
84        self.bind_addr = addr.to_string();
85        self
86    }
87
88    /// Register a route, delegating to [`axum::Router::route`].
89    pub fn route<H, T>(mut self, path: &str, handler: H) -> Self
90    where
91        H: Handler<T, ()>,
92        T: 'static,
93    {
94        self.router = self.router.route(path, axum::routing::any(handler));
95        self
96    }
97
98    /// Merge another router (delegates to [`axum::Router::merge`]).
99    pub fn merge(mut self, other: Router) -> Self {
100        self.router = self.router.merge(other);
101        self
102    }
103
104    /// Mount a router at a nested path (delegates to [`axum::Router::nest`]).
105    pub fn nest(mut self, prefix: &str, router: Router) -> Self {
106        self.router = self.router.nest(prefix, router);
107        self
108    }
109
110    /// Set the per-request timeout. Default 30s.
111    ///
112    /// The timeout wraps the entire request, including all middleware
113    /// (tracing, CORS, custom layers) and the handler itself. Slow
114    /// middleware will count against this deadline.
115    pub fn request_timeout(mut self, d: Duration) -> Self {
116        self.timeout = d;
117        self
118    }
119
120    /// Cap the size of incoming request bodies (default 2 MiB).
121    ///
122    /// Requests with bodies larger than this receive an immediate
123    /// `413 Payload Too Large` response without buffering the full body.
124    /// Mitigates slow-drip and body-bomb attacks against public-facing
125    /// servers.
126    pub fn request_body_limit(mut self, bytes: usize) -> Self {
127        self.body_limit = bytes;
128        self
129    }
130
131    /// Bound the graceful shutdown drain (default: unbounded).
132    ///
133    /// After the shutdown future resolves, axum stops accepting new
134    /// connections and waits for in-flight requests to finish. Without
135    /// a bound, a stuck handler keeps the server alive forever. Set this
136    /// to enforce a deadline; in-flight requests still running after the
137    /// deadline will be dropped and `run_with_shutdown` returns
138    /// `Err(Error::ShutdownTimeout)`.
139    pub fn shutdown_timeout(mut self, d: Duration) -> Self {
140        self.shutdown_timeout = Some(d);
141        self
142    }
143
144    /// Disable the default tracing middleware.
145    pub fn disable_tracing(mut self) -> Self {
146        self.tracing = false;
147        self
148    }
149
150    /// Disable the default request-id middleware.
151    pub fn disable_request_id(mut self) -> Self {
152        self.request_id = false;
153        self
154    }
155
156    /// Enable CORS with permissive defaults (`CorsLayer::permissive()`).
157    pub fn enable_cors(mut self) -> Self {
158        self.cors = Some(CorsLayer::permissive());
159        self
160    }
161
162    /// Enable CORS with a custom [`CorsLayer`].
163    pub fn enable_cors_with(mut self, layer: CorsLayer) -> Self {
164        self.cors = Some(layer);
165        self
166    }
167
168    /// Enable response compression (gzip/br/zstd).
169    pub fn enable_compression(mut self) -> Self {
170        self.compression = true;
171        self
172    }
173
174    /// Customise the health endpoint path. Default `/health`.
175    pub fn health_path(mut self, path: &str) -> Self {
176        self.health_path = path.to_string();
177        self
178    }
179
180    /// Provide a custom responder for the health endpoint.
181    pub fn health_response<F, R>(mut self, responder: F) -> Self
182    where
183        F: Fn() -> R + Send + Sync + 'static,
184        R: IntoResponse + 'static,
185    {
186        self.health_responder = Arc::new(move || responder().into_response());
187        self
188    }
189
190    /// Disable the built-in health endpoint.
191    pub fn disable_health(mut self) -> Self {
192        self.health_enabled = false;
193        self
194    }
195
196    /// Bind the listener and build a [`Server`] ready to run.
197    pub async fn build(self) -> Result<Server> {
198        let addr: SocketAddr = self.bind_addr.parse().map_err(|e| {
199            Error::Configuration(format!("invalid bind address '{}': {e}", self.bind_addr))
200        })?;
201
202        let listener = TcpListener::bind(addr).await.map_err(|e| Error::Bind {
203            addr: self.bind_addr.clone(),
204            source: e,
205        })?;
206        let local_addr = listener.local_addr().map_err(Error::from)?;
207
208        // Register health first so it always wins on its configured path.
209        let router = health::install(
210            self.router,
211            self.health_enabled,
212            &self.health_path,
213            self.health_responder,
214        );
215
216        let stack = DefaultStack {
217            tracing: self.tracing,
218            request_id: self.request_id,
219            timeout: self.timeout,
220            body_limit: self.body_limit,
221            cors: self.cors,
222            compression: self.compression,
223        };
224
225        let router = stack.apply(router);
226
227        Ok(Server::from_parts(
228            router,
229            listener,
230            local_addr,
231            self.shutdown_timeout,
232        ))
233    }
234}
235
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder that asks the OS for a free port on the IPv4 loopback interface.
    fn loopback_builder() -> ServerBuilder {
        ServerBuilder::new().bind_addr("127.0.0.1:0")
    }

    #[tokio::test]
    async fn build_with_defaults_binds_ephemeral_port() {
        let server = loopback_builder().build().await.unwrap();
        let bound = server.local_addr();
        let ip_text = bound.ip().to_string();
        assert_eq!(ip_text, "127.0.0.1");
        assert_ne!(bound.port(), 0);
    }

    #[tokio::test]
    async fn build_rejects_invalid_bind_address() {
        let outcome = ServerBuilder::new()
            .bind_addr("not a socket address")
            .build()
            .await;
        let is_config_error = matches!(outcome, Err(Error::Configuration(_)));
        assert!(is_config_error);
    }

    #[tokio::test]
    async fn build_with_bind_socket_works() {
        let requested = SocketAddr::from(([127, 0, 0, 1], 0));
        let server = ServerBuilder::new()
            .bind_socket(requested)
            .build()
            .await
            .unwrap();
        let ip_text = server.local_addr().ip().to_string();
        assert_eq!(ip_text, "127.0.0.1");
    }

    #[tokio::test]
    async fn build_with_custom_timeout() {
        let five_seconds = Duration::from_secs(5);
        let server = loopback_builder()
            .request_timeout(five_seconds)
            .build()
            .await
            .unwrap();
        let _ = server.local_addr();
    }
}