1use std::collections::HashMap;
15use std::sync::Arc;
16
17use axum::body::{Body, Bytes};
18use axum::extract::State;
19use axum::http::{HeaderMap, HeaderValue, Method as AxumMethod, Response, StatusCode, Uri};
20use axum::response::IntoResponse;
21use axum::routing::any;
22use serde_json::Value;
23
24use crate::config::{Config, Method};
25use crate::router::{Match, Router, RouterError};
26use crate::template::{render, TemplateContext};
27
28#[derive(Debug, Clone)]
30pub struct Server {
31 router: Router,
32 listen: String,
33 cors: bool,
34}
35
36#[derive(Debug, thiserror::Error)]
38pub enum ServerError {
39 #[error("invalid routes: {0}")]
41 Router(#[from] RouterError),
42
43 #[error("could not bind to {addr}: {source}")]
45 Bind {
46 addr: String,
47 #[source]
48 source: std::io::Error,
49 },
50
51 #[error("server error: {0}")]
53 Serve(#[source] std::io::Error),
54
55 #[error("invalid status code: {0}")]
57 InvalidStatus(u16),
58}
59
60impl Server {
61 pub fn from_config(config: Config) -> Result<Self, ServerError> {
65 let router = Router::new(config.routes)?;
66 Ok(Server {
67 router,
68 listen: config.listen,
69 cors: false,
70 })
71 }
72
73 pub fn with_cors(mut self, enabled: bool) -> Self {
75 self.cors = enabled;
76 self
77 }
78
79 pub fn app(&self) -> axum::Router {
84 build_app(self.router.clone(), self.cors)
85 }
86
87 pub fn route_count(&self) -> usize {
89 self.router.len()
90 }
91
92 pub async fn serve(&self) -> Result<(), ServerError> {
94 let listen = normalize_listen(&self.listen);
95 let listener = tokio::net::TcpListener::bind(&listen)
96 .await
97 .map_err(|source| ServerError::Bind {
98 addr: listen.clone(),
99 source,
100 })?;
101 let addr = listener_local_addr(&listener);
102 tracing::info!("listening on {addr}");
103 let app = self.app();
104 axum::serve(listener, app)
105 .await
106 .map_err(ServerError::Serve)?;
107 Ok(())
108 }
109}
110
/// Expands the `:PORT` shorthand into `0.0.0.0:PORT`.
///
/// Any other address (for example `127.0.0.1:9000` or `[::1]:8080`) is
/// returned unchanged.
fn normalize_listen(addr: &str) -> String {
    match addr.strip_prefix(':') {
        Some(port) => format!("0.0.0.0:{port}"),
        None => addr.to_owned(),
    }
}
124
125fn listener_local_addr(listener: &tokio::net::TcpListener) -> String {
126 listener
127 .local_addr()
128 .map(|a| a.to_string())
129 .unwrap_or_else(|_| "(unknown)".to_string())
130}
131
132pub fn build_app(router: Router, cors: bool) -> axum::Router {
138 let state = Arc::new(AppState { router, cors });
139 axum::Router::new().fallback(any(handler)).with_state(state)
140}
141
142#[derive(Clone)]
144struct AppState {
145 router: Router,
146 cors: bool,
147}
148
149async fn handler(
151 State(state): State<Arc<AppState>>,
152 method: AxumMethod,
153 uri: Uri,
154 headers: HeaderMap,
155 body: Bytes,
156) -> Response<Body> {
157 let method_str = method.as_str().to_string();
158 let path = uri.path().to_string();
159
160 if state.cors
163 && method == AxumMethod::OPTIONS
164 && headers.contains_key("access-control-request-method")
165 {
166 tracing::info!(%method_str, %path, status = 204, "cors preflight");
167 return cors_preflight(&headers);
168 }
169
170 let Some(core_method) = Method::from_http_str(method.as_str()) else {
171 tracing::info!(%method_str, %path, status = 404, "unsupported method");
172 return not_found(state.cors);
173 };
174
175 let query = parse_query(uri.query().unwrap_or(""));
176 let header_map = collect_headers(&headers);
177 let request_body: Value = serde_json::from_slice(&body).unwrap_or(Value::Null);
178
179 let Some(Match {
180 path_params,
181 response,
182 }) = state
183 .router
184 .resolve(core_method, &path, &query, &header_map, &request_body)
185 else {
186 tracing::info!(%method_str, %path, status = 404, "no matching route");
187 return not_found(state.cors);
188 };
189
190 if let Some(delay) = response.delay {
192 tracing::debug!(?delay, "applying artificial delay");
193 tokio::time::sleep(delay).await;
194 }
195
196 let rendered = response.body.map(|b| {
198 let ctx = TemplateContext {
199 path: path_params.clone(),
200 query: query.clone(),
201 headers: header_map.clone(),
202 body: request_body.clone(),
203 };
204 render(&b, &ctx)
205 });
206
207 let status = response.status;
208 let close_connection = response.close_connection;
209
210 let mut resp = build_response(status, &response.headers, rendered, close_connection)
211 .unwrap_or_else(|_| internal_error());
212
213 if state.cors {
214 add_cors_headers(resp.headers_mut());
215 }
216
217 tracing::info!(%method_str, %path, status, "handled");
218 resp
219}
220
/// Parses a raw query string (without the leading `?`) into a key/value map.
///
/// Keys and values are decoded the way `application/x-www-form-urlencoded`
/// data is decoded: `+` becomes a space and `%XX` escapes become bytes.
/// Malformed escapes are kept literally, and invalid UTF-8 is replaced
/// lossily. A pair without `=` maps to an empty value, and empty pairs
/// (from `&&` or a trailing `&`) are ignored. When a key repeats, the last
/// value wins.
fn parse_query(query: &str) -> HashMap<String, String> {
    query
        .split('&')
        .filter(|pair| !pair.is_empty())
        .map(|pair| {
            let (key, value) = pair.split_once('=').unwrap_or((pair, ""));
            (decode_query_component(key), decode_query_component(value))
        })
        .collect()
}

/// Decodes one form-urlencoded key or value (`+` -> space, `%XX` -> byte).
fn decode_query_component(raw: &str) -> String {
    let bytes = raw.as_bytes();
    let mut out = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            b'+' => {
                out.push(b' ');
                i += 1;
            }
            b'%' => {
                let hi = bytes.get(i + 1).copied().and_then(hex_digit_value);
                let lo = bytes.get(i + 2).copied().and_then(hex_digit_value);
                match (hi, lo) {
                    (Some(hi), Some(lo)) => {
                        out.push((hi << 4) | lo);
                        i += 3;
                    }
                    // Not a valid escape: keep the '%' literally.
                    _ => {
                        out.push(b'%');
                        i += 1;
                    }
                }
            }
            other => {
                out.push(other);
                i += 1;
            }
        }
    }
    String::from_utf8_lossy(&out).into_owned()
}

/// Value of an ASCII hex digit (`0-9`, `a-f`, `A-F`), or `None`.
fn hex_digit_value(b: u8) -> Option<u8> {
    // `to_digit` rejects sign characters, unlike `u8::from_str_radix`.
    char::from(b).to_digit(16).map(|d| d as u8)
}
245
246fn collect_headers(headers: &HeaderMap) -> HashMap<String, String> {
248 let mut map = HashMap::new();
249 for (name, value) in headers.iter() {
250 let key = name.as_str().to_ascii_lowercase();
251 let val = value.to_str().unwrap_or("").to_string();
252 map.entry(key).or_insert(val);
253 }
254 map
255}
256
257fn build_response(
259 status: u16,
260 headers: &HashMap<String, String>,
261 body: Option<Value>,
262 close_connection: bool,
263) -> Result<Response<Body>, ServerError> {
264 let status = StatusCode::from_u16(status).map_err(|_| ServerError::InvalidStatus(status))?;
265
266 let mut builder = Response::builder().status(status);
267
268 let has_content_type = headers
269 .keys()
270 .any(|k| k.eq_ignore_ascii_case("content-type"));
271
272 for (name, value) in headers {
273 builder = builder.header(name.as_str(), value.as_str());
274 }
275
276 if close_connection {
277 builder = builder.header("connection", "close");
278 }
279
280 let bytes = if let Some(body) = body {
281 if !has_content_type {
282 builder = builder.header("content-type", "application/json");
283 }
284 serde_json::to_vec(&body).unwrap_or_default()
285 } else {
286 Vec::new()
287 };
288
289 Ok(builder.body(Body::from(bytes)).unwrap())
290}
291
292fn add_cors_headers(headers: &mut HeaderMap) {
294 headers.insert("access-control-allow-origin", HeaderValue::from_static("*"));
295 headers.insert("vary", HeaderValue::from_static("origin"));
298}
299
300fn cors_preflight(req_headers: &HeaderMap) -> Response<Body> {
305 let allow_headers = req_headers
306 .get("access-control-request-headers")
307 .cloned()
308 .unwrap_or_else(|| HeaderValue::from_static("*"));
309
310 Response::builder()
311 .status(StatusCode::NO_CONTENT)
312 .header("access-control-allow-origin", "*")
313 .header(
314 "access-control-allow-methods",
315 "GET, POST, PUT, PATCH, DELETE, OPTIONS",
316 )
317 .header("access-control-allow-headers", allow_headers)
318 .header("access-control-max-age", "86400")
319 .header("vary", "origin")
320 .body(Body::empty())
321 .unwrap()
322}
323
324fn not_found(cors: bool) -> Response<Body> {
325 let mut resp = (
326 StatusCode::NOT_FOUND,
327 [(axum::http::header::CONTENT_TYPE, "application/json")],
328 r#"{"error":"no matching route"}"#,
329 )
330 .into_response();
331 if cors {
332 add_cors_headers(resp.headers_mut());
333 }
334 resp
335}
336
337fn internal_error() -> Response<Body> {
338 (
339 StatusCode::INTERNAL_SERVER_ERROR,
340 [(axum::http::header::CONTENT_TYPE, "application/json")],
341 r#"{"error":"internal mockd error"}"#,
342 )
343 .into_response()
344}
345
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns header `name` from `resp` as a string; panics if it is
    /// missing or `to_str` fails.
    fn header_str<'a>(resp: &'a Response<Body>, name: &str) -> &'a str {
        resp.headers().get(name).unwrap().to_str().unwrap()
    }

    #[test]
    fn parse_query_basic() {
        let q = parse_query("role=admin&tenant=a&flag");
        assert_eq!(q["role"], "admin");
        assert_eq!(q["tenant"], "a");
        assert_eq!(q["flag"], "");
    }

    #[test]
    fn parse_query_empty() {
        assert!(parse_query("").is_empty());
    }

    #[test]
    fn collect_headers_lowercases() {
        let mut incoming = HeaderMap::new();
        incoming.insert("X-Tenant-Id", "a".parse().unwrap());
        assert_eq!(collect_headers(&incoming)["x-tenant-id"], "a");
    }

    #[test]
    fn build_response_sets_json_content_type_when_body_present() {
        let body = Some(serde_json::json!({"ok": true}));
        let resp = build_response(200, &HashMap::new(), body, false).unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(header_str(&resp, "content-type"), "application/json");
    }

    #[test]
    fn build_response_keeps_explicit_content_type() {
        let configured = HashMap::from([("Content-Type".to_string(), "text/plain".to_string())]);
        let body = Some(Value::String("hi".into()));
        let resp = build_response(200, &configured, body, false).unwrap();
        assert_eq!(header_str(&resp, "content-type"), "text/plain");
    }

    #[test]
    fn build_response_close_connection_header() {
        let resp = build_response(500, &HashMap::new(), None, true).unwrap();
        assert_eq!(header_str(&resp, "connection"), "close");
    }

    #[test]
    fn build_response_rejects_invalid_status() {
        let outcome = build_response(6000, &HashMap::new(), None, false);
        assert!(matches!(outcome, Err(ServerError::InvalidStatus(6000))));
    }

    #[test]
    fn normalize_listen_handles_shorthand() {
        let cases = [
            (":8080", "0.0.0.0:8080"),
            ("127.0.0.1:9000", "127.0.0.1:9000"),
            ("[::1]:8080", "[::1]:8080"),
        ];
        for (input, expected) in cases {
            assert_eq!(normalize_listen(input), expected);
        }
    }

    #[test]
    fn cors_preflight_has_cors_headers() {
        let resp = cors_preflight(&HeaderMap::new());
        assert_eq!(resp.status(), StatusCode::NO_CONTENT);
        assert_eq!(header_str(&resp, "access-control-allow-origin"), "*");
        assert!(header_str(&resp, "access-control-allow-methods").contains("GET"));
        // With no requested headers, the wildcard is advertised.
        assert_eq!(header_str(&resp, "access-control-allow-headers"), "*");
    }

    #[test]
    fn cors_preflight_echoes_requested_headers() {
        let mut request = HeaderMap::new();
        request.insert(
            "access-control-request-headers",
            "X-Tenant-Id, Authorization".parse().unwrap(),
        );
        let resp = cors_preflight(&request);
        assert_eq!(
            header_str(&resp, "access-control-allow-headers"),
            "X-Tenant-Id, Authorization"
        );
    }
}