1use crate::error::{Error, Result};
4use crate::health::{self, HealthResponder};
5use crate::middleware::DefaultStack;
6use crate::server::Server;
7use axum::Router;
8use axum::handler::Handler;
9use axum::response::IntoResponse;
10use std::net::SocketAddr;
11use std::sync::Arc;
12use std::time::Duration;
13use tokio::net::TcpListener;
14use tower_http::cors::CorsLayer;
15
/// Listen address used when neither `bind_addr` nor `bind_socket` is called.
const DEFAULT_BIND_ADDR: &str = "0.0.0.0:8080";
/// Timeout handed to the middleware stack unless `request_timeout` overrides it.
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Route of the built-in health endpoint unless `health_path` overrides it.
const DEFAULT_HEALTH_PATH: &str = "/health";
/// Request-body size limit in bytes (2 MiB) unless `request_body_limit` overrides it.
const DEFAULT_BODY_LIMIT_BYTES: usize = 2 << 20;
22
23#[must_use]
35#[allow(clippy::struct_excessive_bools)] pub struct ServerBuilder {
37 bind_addr: String,
38 router: Router<()>,
39 tracing: bool,
40 request_id: bool,
41 timeout: Duration,
42 body_limit: usize,
43 cors: Option<CorsLayer>,
44 compression: bool,
45 health_enabled: bool,
46 health_path: String,
47 health_responder: HealthResponder,
48 shutdown_timeout: Option<Duration>,
49}
50
51impl Default for ServerBuilder {
52 fn default() -> Self {
53 Self {
54 bind_addr: DEFAULT_BIND_ADDR.to_string(),
55 router: Router::new(),
56 tracing: true,
57 request_id: true,
58 timeout: DEFAULT_TIMEOUT,
59 body_limit: DEFAULT_BODY_LIMIT_BYTES,
60 cors: None,
61 compression: false,
62 health_enabled: true,
63 health_path: DEFAULT_HEALTH_PATH.to_string(),
64 health_responder: health::default_responder(),
65 shutdown_timeout: None,
66 }
67 }
68}
69
70impl ServerBuilder {
71 pub fn new() -> Self {
73 Self::default()
74 }
75
76 pub fn bind_addr(mut self, addr: impl Into<String>) -> Self {
78 self.bind_addr = addr.into();
79 self
80 }
81
82 pub fn bind_socket(mut self, addr: SocketAddr) -> Self {
84 self.bind_addr = addr.to_string();
85 self
86 }
87
88 pub fn route<H, T>(mut self, path: &str, handler: H) -> Self
90 where
91 H: Handler<T, ()>,
92 T: 'static,
93 {
94 self.router = self.router.route(path, axum::routing::any(handler));
95 self
96 }
97
98 pub fn merge(mut self, other: Router) -> Self {
100 self.router = self.router.merge(other);
101 self
102 }
103
104 pub fn nest(mut self, prefix: &str, router: Router) -> Self {
106 self.router = self.router.nest(prefix, router);
107 self
108 }
109
110 pub fn request_timeout(mut self, d: Duration) -> Self {
116 self.timeout = d;
117 self
118 }
119
120 pub fn request_body_limit(mut self, bytes: usize) -> Self {
127 self.body_limit = bytes;
128 self
129 }
130
131 pub fn shutdown_timeout(mut self, d: Duration) -> Self {
140 self.shutdown_timeout = Some(d);
141 self
142 }
143
144 pub fn disable_tracing(mut self) -> Self {
146 self.tracing = false;
147 self
148 }
149
150 pub fn disable_request_id(mut self) -> Self {
152 self.request_id = false;
153 self
154 }
155
156 pub fn enable_cors(mut self) -> Self {
158 self.cors = Some(CorsLayer::permissive());
159 self
160 }
161
162 pub fn enable_cors_with(mut self, layer: CorsLayer) -> Self {
164 self.cors = Some(layer);
165 self
166 }
167
168 pub fn enable_compression(mut self) -> Self {
170 self.compression = true;
171 self
172 }
173
174 pub fn health_path(mut self, path: &str) -> Self {
176 self.health_path = path.to_string();
177 self
178 }
179
180 pub fn health_response<F, R>(mut self, responder: F) -> Self
182 where
183 F: Fn() -> R + Send + Sync + 'static,
184 R: IntoResponse + 'static,
185 {
186 self.health_responder = Arc::new(move || responder().into_response());
187 self
188 }
189
190 pub fn disable_health(mut self) -> Self {
192 self.health_enabled = false;
193 self
194 }
195
196 pub async fn build(self) -> Result<Server> {
198 let addr: SocketAddr = self.bind_addr.parse().map_err(|e| {
199 Error::Configuration(format!("invalid bind address '{}': {e}", self.bind_addr))
200 })?;
201
202 let listener = TcpListener::bind(addr).await.map_err(|e| Error::Bind {
203 addr: self.bind_addr.clone(),
204 source: e,
205 })?;
206 let local_addr = listener.local_addr().map_err(Error::from)?;
207
208 let router = health::install(
210 self.router,
211 self.health_enabled,
212 &self.health_path,
213 self.health_responder,
214 );
215
216 let stack = DefaultStack {
217 tracing: self.tracing,
218 request_id: self.request_id,
219 timeout: self.timeout,
220 body_limit: self.body_limit,
221 cors: self.cors,
222 compression: self.compression,
223 };
224
225 let router = stack.apply(router);
226
227 Ok(Server::from_parts(
228 router,
229 listener,
230 local_addr,
231 self.shutdown_timeout,
232 ))
233 }
234}
235
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder bound to an OS-assigned loopback port so tests never collide.
    fn loopback() -> ServerBuilder {
        ServerBuilder::new().bind_addr("127.0.0.1:0")
    }

    #[tokio::test]
    async fn build_with_defaults_binds_ephemeral_port() {
        let server = loopback()
            .build()
            .await
            .expect("default build should succeed");
        let addr = server.local_addr();
        assert_eq!(addr.ip().to_string(), "127.0.0.1");
        assert!(addr.port() > 0);
    }

    #[tokio::test]
    async fn build_rejects_invalid_bind_address() {
        let result = ServerBuilder::new()
            .bind_addr("not a socket address")
            .build()
            .await;
        // The message must name the offending address so users can fix it.
        assert!(matches!(
            result,
            Err(Error::Configuration(ref msg)) if msg.contains("not a socket address")
        ));
    }

    #[tokio::test]
    async fn build_with_bind_socket_works() {
        let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
        let server = ServerBuilder::new()
            .bind_socket(addr)
            .build()
            .await
            .expect("build from SocketAddr should succeed");
        assert_eq!(server.local_addr().ip().to_string(), "127.0.0.1");
    }

    #[tokio::test]
    async fn build_with_custom_timeout() {
        let server = loopback()
            .request_timeout(Duration::from_secs(5))
            .build()
            .await
            .expect("build with custom timeout should succeed");
        assert!(server.local_addr().port() > 0);
    }

    #[tokio::test]
    async fn build_with_every_option_configured() {
        let api = Router::new().route("/v1", axum::routing::get(|| async { "v1" }));
        let server = loopback()
            .route("/ping", || async { "pong" })
            .nest("/api", api)
            .request_timeout(Duration::from_secs(1))
            .request_body_limit(1024)
            .shutdown_timeout(Duration::from_secs(1))
            .disable_tracing()
            .disable_request_id()
            .enable_cors()
            .enable_compression()
            .health_path("/healthz")
            .health_response(|| "up")
            .build()
            .await
            .expect("fully configured build should succeed");
        assert!(server.local_addr().port() > 0);
    }

    #[tokio::test]
    async fn build_with_health_disabled_and_custom_cors() {
        let server = loopback()
            .disable_health()
            .enable_cors_with(CorsLayer::new())
            .build()
            .await
            .expect("build without health endpoint should succeed");
        assert!(server.local_addr().port() > 0);
    }
}