1use std::collections::HashMap;
15use std::future::Future;
16use std::sync::Arc;
17
18use axum::body::{Body, Bytes};
19use axum::extract::State;
20use axum::http::{HeaderMap, HeaderValue, Method as AxumMethod, Response, StatusCode, Uri};
21use axum::response::IntoResponse;
22use axum::routing::any;
23use serde_json::Value;
24
25use crate::config::{Config, Method};
26use crate::router::{Match, Router, RouterError};
27use crate::template::{render, TemplateContext};
28
/// A configured mock server: a route table plus the settings needed to expose it over HTTP.
///
/// Created with [`Server::from_config`]; CORS handling is off until enabled via
/// [`Server::with_cors`].
#[derive(Debug, Clone)]
pub struct Server {
    /// Route table used to resolve every incoming request.
    router: Router,
    /// Listen address exactly as supplied by the configuration; it is normalized
    /// (e.g. `:8080` shorthand) only when the listener is bound in `serve`.
    listen: String,
    /// Whether CORS preflight answering and CORS response headers are enabled.
    cors: bool,
}
36
/// Errors produced while constructing or running a [`Server`].
#[derive(Debug, thiserror::Error)]
pub enum ServerError {
    /// The router rejected the configured routes. Converted automatically from
    /// [`RouterError`] when `?` is used.
    #[error("invalid routes: {0}")]
    Router(#[from] RouterError),

    /// The TCP listener could not be bound to `addr` (the normalized listen address).
    #[error("could not bind to {addr}: {source}")]
    Bind {
        /// Address the bind was attempted on.
        addr: String,
        /// Underlying I/O failure reported by the bind call.
        #[source]
        source: std::io::Error,
    },

    /// An I/O error surfaced while serving.
    #[error("server error: {0}")]
    Serve(#[source] std::io::Error),

    /// A response status was not accepted by `StatusCode::from_u16`.
    #[error("invalid status code: {0}")]
    InvalidStatus(u16),
}
60
61impl Server {
62 pub fn from_config(config: Config) -> Result<Self, ServerError> {
66 let router = Router::new(config.routes)?;
67 Ok(Server {
68 router,
69 listen: config.listen,
70 cors: false,
71 })
72 }
73
74 pub fn with_cors(mut self, enabled: bool) -> Self {
76 self.cors = enabled;
77 self
78 }
79
80 pub fn app(&self) -> axum::Router {
85 build_app(self.router.clone(), self.cors)
86 }
87
88 pub fn route_count(&self) -> usize {
90 self.router.len()
91 }
92
93 pub async fn serve<F>(&self, shutdown: F) -> Result<(), ServerError>
95 where
96 F: Future<Output = ()> + Send + 'static,
97 {
98 let listen = normalize_listen(&self.listen);
99 let listener = tokio::net::TcpListener::bind(&listen)
100 .await
101 .map_err(|source| ServerError::Bind {
102 addr: listen.clone(),
103 source,
104 })?;
105 let addr = listener_local_addr(&listener);
106 tracing::info!("listening on {addr}");
107 let app = self.app();
108 axum::serve(listener, app)
109 .with_graceful_shutdown(shutdown)
110 .await
111 .map_err(ServerError::Serve)?;
112 Ok(())
113 }
114}
115
/// Expands the `:PORT` shorthand into a bindable address.
///
/// An address starting with `:` has `0.0.0.0` prepended (so `:8080` becomes
/// `0.0.0.0:8080`); any other input is returned unchanged.
fn normalize_listen(addr: &str) -> String {
    addr.strip_prefix(':')
        .map_or_else(|| addr.to_owned(), |rest| format!("0.0.0.0:{rest}"))
}
129
130fn listener_local_addr(listener: &tokio::net::TcpListener) -> String {
131 listener
132 .local_addr()
133 .map(|a| a.to_string())
134 .unwrap_or_else(|_| "(unknown)".to_string())
135}
136
137pub fn build_app(router: Router, cors: bool) -> axum::Router {
143 let state = Arc::new(AppState { router, cors });
144 axum::Router::new().fallback(any(handler)).with_state(state)
145}
146
/// State shared with the request handler; `build_app` wraps it in an `Arc`.
#[derive(Clone)]
struct AppState {
    /// Route table used to resolve each request.
    router: Router,
    /// Whether CORS preflight handling and CORS response headers are enabled.
    cors: bool,
}
153
154async fn handler(
156 State(state): State<Arc<AppState>>,
157 method: AxumMethod,
158 uri: Uri,
159 headers: HeaderMap,
160 body: Bytes,
161) -> Response<Body> {
162 let method_str = method.as_str().to_string();
163 let path = uri.path().to_string();
164
165 if state.cors
168 && method == AxumMethod::OPTIONS
169 && headers.contains_key("access-control-request-method")
170 {
171 tracing::info!(%method_str, %path, status = 204, "cors preflight");
172 return cors_preflight(&headers);
173 }
174
175 let Some(core_method) = Method::from_http_str(method.as_str()) else {
176 tracing::info!(%method_str, %path, status = 404, "unsupported method");
177 return not_found(state.cors);
178 };
179
180 let query = parse_query(uri.query().unwrap_or(""));
181 let header_map = collect_headers(&headers);
182 let request_body: Value = serde_json::from_slice(&body).unwrap_or(Value::Null);
183
184 let Some(Match {
185 path_params,
186 response,
187 }) = state
188 .router
189 .resolve(core_method, &path, &query, &header_map, &request_body)
190 else {
191 tracing::info!(%method_str, %path, status = 404, "no matching route");
192 return not_found(state.cors);
193 };
194
195 if let Some(delay) = response.delay {
197 tracing::debug!(?delay, "applying artificial delay");
198 tokio::time::sleep(delay).await;
199 }
200
201 let rendered = response.body.map(|b| {
203 let ctx = TemplateContext {
204 path: path_params.clone(),
205 query: query.clone(),
206 headers: header_map.clone(),
207 body: request_body.clone(),
208 };
209 render(&b, &ctx)
210 });
211
212 let status = response.status;
213 let close_connection = response.close_connection;
214
215 let mut resp = build_response(status, &response.headers, rendered, close_connection)
216 .unwrap_or_else(|_| internal_error());
217
218 if state.cors {
219 add_cors_headers(resp.headers_mut());
220 }
221
222 tracing::info!(%method_str, %path, status, "handled");
223 resp
224}
225
/// Parses a raw URL query string into a key/value map.
///
/// * Pairs are separated by `&`; empty pairs are skipped.
/// * Each pair is split on its first `=`; a pair without `=` maps the key to `""`.
/// * Keys and values are URL-decoded after splitting (`%XX` escapes and `+` as a
///   space), so an encoded `=` or `&` never affects how pairs are split.
/// * When a key repeats, the last occurrence wins.
fn parse_query(query: &str) -> HashMap<String, String> {
    query
        .split('&')
        .filter(|pair| !pair.is_empty())
        .map(|pair| {
            let (key, value) = pair.split_once('=').unwrap_or((pair, ""));
            (decode_query_component(key), decode_query_component(value))
        })
        .collect()
}

/// URL-decodes one query key or value.
///
/// `+` becomes a space and `%XX` becomes the byte `0xXX`. A `%` that is not
/// followed by two hex digits is kept literally. Decoded bytes that are not valid
/// UTF-8 are replaced with U+FFFD.
fn decode_query_component(raw: &str) -> String {
    fn hex_value(byte: u8) -> Option<u8> {
        match byte {
            b'0'..=b'9' => Some(byte - b'0'),
            b'a'..=b'f' => Some(byte - b'a' + 10),
            b'A'..=b'F' => Some(byte - b'A' + 10),
            _ => None,
        }
    }

    let mut rest = raw.as_bytes();
    let mut decoded = Vec::with_capacity(rest.len());

    while let Some((&byte, tail)) = rest.split_first() {
        rest = tail;
        match byte {
            b'+' => decoded.push(b' '),
            b'%' => match tail {
                [hi, lo, after @ ..] => match (hex_value(*hi), hex_value(*lo)) {
                    (Some(high), Some(low)) => {
                        decoded.push((high << 4) | low);
                        rest = after;
                    }
                    // Not a valid escape: keep the `%` and reprocess what follows.
                    _ => decoded.push(b'%'),
                },
                _ => decoded.push(b'%'),
            },
            other => decoded.push(other),
        }
    }

    String::from_utf8_lossy(&decoded).into_owned()
}
250
251fn collect_headers(headers: &HeaderMap) -> HashMap<String, String> {
253 let mut map = HashMap::new();
254 for (name, value) in headers.iter() {
255 let key = name.as_str().to_ascii_lowercase();
256 let val = value.to_str().unwrap_or("").to_string();
257 map.entry(key).or_insert(val);
258 }
259 map
260}
261
262fn build_response(
264 status: u16,
265 headers: &HashMap<String, String>,
266 body: Option<Value>,
267 close_connection: bool,
268) -> Result<Response<Body>, ServerError> {
269 let status = StatusCode::from_u16(status).map_err(|_| ServerError::InvalidStatus(status))?;
270
271 let mut builder = Response::builder().status(status);
272
273 let has_content_type = headers
274 .keys()
275 .any(|k| k.eq_ignore_ascii_case("content-type"));
276
277 for (name, value) in headers {
278 builder = builder.header(name.as_str(), value.as_str());
279 }
280
281 if close_connection {
282 builder = builder.header("connection", "close");
283 }
284
285 let bytes = if let Some(body) = body {
286 if !has_content_type {
287 builder = builder.header("content-type", "application/json");
288 }
289 serde_json::to_vec(&body).unwrap_or_default()
290 } else {
291 Vec::new()
292 };
293
294 Ok(builder.body(Body::from(bytes)).unwrap())
295}
296
297fn add_cors_headers(headers: &mut HeaderMap) {
299 headers.insert("access-control-allow-origin", HeaderValue::from_static("*"));
300 headers.insert("vary", HeaderValue::from_static("origin"));
303}
304
305fn cors_preflight(req_headers: &HeaderMap) -> Response<Body> {
310 let allow_headers = req_headers
311 .get("access-control-request-headers")
312 .cloned()
313 .unwrap_or_else(|| HeaderValue::from_static("*"));
314
315 Response::builder()
316 .status(StatusCode::NO_CONTENT)
317 .header("access-control-allow-origin", "*")
318 .header(
319 "access-control-allow-methods",
320 "GET, POST, PUT, PATCH, DELETE, OPTIONS",
321 )
322 .header("access-control-allow-headers", allow_headers)
323 .header("access-control-max-age", "86400")
324 .header("vary", "origin")
325 .body(Body::empty())
326 .unwrap()
327}
328
329fn not_found(cors: bool) -> Response<Body> {
330 let mut resp = (
331 StatusCode::NOT_FOUND,
332 [(axum::http::header::CONTENT_TYPE, "application/json")],
333 r#"{"error":"no matching route"}"#,
334 )
335 .into_response();
336 if cors {
337 add_cors_headers(resp.headers_mut());
338 }
339 resp
340}
341
342fn internal_error() -> Response<Body> {
343 (
344 StatusCode::INTERNAL_SERVER_ERROR,
345 [(axum::http::header::CONTENT_TYPE, "application/json")],
346 r#"{"error":"internal mockd error"}"#,
347 )
348 .into_response()
349}
350
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads a response header as text; panics when it is missing or not a valid string.
    fn header_text<'a>(resp: &'a Response<Body>, name: &str) -> &'a str {
        resp.headers().get(name).unwrap().to_str().unwrap()
    }

    #[test]
    fn parse_query_basic() {
        let parsed = parse_query("role=admin&tenant=a&flag");
        for (key, expected) in [("role", "admin"), ("tenant", "a"), ("flag", "")] {
            assert_eq!(parsed.get(key).unwrap(), expected);
        }
    }

    #[test]
    fn parse_query_empty() {
        let parsed = parse_query("");
        assert!(parsed.is_empty());
    }

    #[test]
    fn collect_headers_lowercases() {
        let mut raw = HeaderMap::new();
        raw.insert("X-Tenant-Id", HeaderValue::from_static("a"));

        let collected = collect_headers(&raw);
        assert_eq!(collected.get("x-tenant-id").unwrap(), "a");
    }

    #[test]
    fn build_response_sets_json_content_type_when_body_present() {
        let no_headers = HashMap::new();
        let body = serde_json::json!({"ok": true});

        let resp = build_response(200, &no_headers, Some(body), false).unwrap();

        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(header_text(&resp, "content-type"), "application/json");
    }

    #[test]
    fn build_response_keeps_explicit_content_type() {
        let headers = HashMap::from([("Content-Type".to_string(), "text/plain".to_string())]);

        let resp = build_response(200, &headers, Some(Value::String("hi".into())), false).unwrap();

        assert_eq!(header_text(&resp, "content-type"), "text/plain");
    }

    #[test]
    fn build_response_close_connection_header() {
        let resp = build_response(500, &HashMap::new(), None, true).unwrap();
        assert_eq!(header_text(&resp, "connection"), "close");
    }

    #[test]
    fn build_response_rejects_invalid_status() {
        // 6000 is outside the range `StatusCode::from_u16` accepts.
        let outcome = build_response(6000, &HashMap::new(), None, false);
        assert!(matches!(outcome, Err(ServerError::InvalidStatus(6000))));
    }

    #[test]
    fn normalize_listen_handles_shorthand() {
        let cases = [
            (":8080", "0.0.0.0:8080"),
            ("127.0.0.1:9000", "127.0.0.1:9000"),
            ("[::1]:8080", "[::1]:8080"),
        ];
        for (input, expected) in cases {
            assert_eq!(normalize_listen(input), expected);
        }
    }

    #[test]
    fn cors_preflight_has_cors_headers() {
        let resp = cors_preflight(&HeaderMap::new());

        assert_eq!(resp.status(), StatusCode::NO_CONTENT);
        assert_eq!(header_text(&resp, "access-control-allow-origin"), "*");
        assert!(header_text(&resp, "access-control-allow-methods").contains("GET"));
        // Without a requested header list, every header is allowed.
        assert_eq!(header_text(&resp, "access-control-allow-headers"), "*");
    }

    #[test]
    fn cors_preflight_echoes_requested_headers() {
        let requested = "X-Tenant-Id, Authorization";
        let mut req = HeaderMap::new();
        req.insert(
            "access-control-request-headers",
            requested.parse().unwrap(),
        );

        let resp = cors_preflight(&req);

        assert_eq!(
            header_text(&resp, "access-control-allow-headers"),
            "X-Tenant-Id, Authorization"
        );
    }
}