Skip to main content

mockd/
server.rs

//! HTTP server layer (Axum).
//!
//! [`Server`] wires a compiled route table ([`Router`]) to an Axum
//! application. A single fallback handler receives every request, delegates
//! matching to the router, applies any configured delay, renders template
//! expressions in the response body and produces the HTTP response.
//!
//! Optional CORS support adds permissive cross-origin headers and answers
//! `OPTIONS` preflight requests before route matching.
//!
//! For testing, [`build_app`] returns a plain `axum::Router` that can be bound
//! to any address (including an ephemeral one).
13
14use std::collections::HashMap;
15use std::sync::Arc;
16
17use axum::body::{Body, Bytes};
18use axum::extract::State;
19use axum::http::{HeaderMap, HeaderValue, Method as AxumMethod, Response, StatusCode, Uri};
20use axum::response::IntoResponse;
21use axum::routing::any;
22use serde_json::Value;
23
24use crate::config::{Config, Method};
25use crate::router::{Match, Router, RouterError};
26use crate::template::{render, TemplateContext};
27
28/// A mockd server bound to a configuration.
29#[derive(Debug, Clone)]
30pub struct Server {
31    router: Router,
32    listen: String,
33    cors: bool,
34}
35
36/// Errors that can occur while building or running a [`Server`].
37#[derive(Debug, thiserror::Error)]
38pub enum ServerError {
39    /// The routes could not be compiled.
40    #[error("invalid routes: {0}")]
41    Router(#[from] RouterError),
42
43    /// The configured listen address could not be bound.
44    #[error("could not bind to {addr}: {source}")]
45    Bind {
46        addr: String,
47        #[source]
48        source: std::io::Error,
49    },
50
51    /// The server stopped with an error.
52    #[error("server error: {0}")]
53    Serve(#[source] std::io::Error),
54
55    /// A response status code was invalid.
56    #[error("invalid status code: {0}")]
57    InvalidStatus(u16),
58}
59
60impl Server {
61    /// Build a server from a parsed [`Config`].
62    ///
63    /// CORS is disabled by default; use [`Server::with_cors`] to enable it.
64    pub fn from_config(config: Config) -> Result<Self, ServerError> {
65        let router = Router::new(config.routes)?;
66        Ok(Server {
67            router,
68            listen: config.listen,
69            cors: false,
70        })
71    }
72
73    /// Enable or disable permissive CORS support.
74    pub fn with_cors(mut self, enabled: bool) -> Self {
75        self.cors = enabled;
76        self
77    }
78
79    /// Build the Axum application for this server.
80    ///
81    /// Exposed primarily for integration tests; [`Server::serve`] is the
82    /// normal entry point.
83    pub fn app(&self) -> axum::Router {
84        build_app(self.router.clone(), self.cors)
85    }
86
87    /// Number of compiled routes.
88    pub fn route_count(&self) -> usize {
89        self.router.len()
90    }
91
92    /// Bind the configured address and serve until interrupted.
93    pub async fn serve(&self) -> Result<(), ServerError> {
94        let listen = normalize_listen(&self.listen);
95        let listener = tokio::net::TcpListener::bind(&listen)
96            .await
97            .map_err(|source| ServerError::Bind {
98                addr: listen.clone(),
99                source,
100            })?;
101        let addr = listener_local_addr(&listener);
102        tracing::info!("listening on {addr}");
103        let app = self.app();
104        axum::serve(listener, app)
105            .await
106            .map_err(ServerError::Serve)?;
107        Ok(())
108    }
109}
110
/// Expand the `:PORT` shorthand to an all-interfaces IPv4 address.
///
/// This follows the convention used by Go, nginx and others. Any address that
/// does not start with `:` is returned unchanged.
///
/// - `:8080`          -> `0.0.0.0:8080`
/// - `127.0.0.1:8080` -> unchanged
/// - `[::1]:8080`     -> unchanged
fn normalize_listen(addr: &str) -> String {
    match addr.strip_prefix(':') {
        Some(port) => format!("0.0.0.0:{port}"),
        None => addr.to_owned(),
    }
}
124
125fn listener_local_addr(listener: &tokio::net::TcpListener) -> String {
126    listener
127        .local_addr()
128        .map(|a| a.to_string())
129        .unwrap_or_else(|_| "(unknown)".to_string())
130}
131
132/// Build an Axum application backed by the given router.
133///
134/// When `cors` is `true`, every response carries permissive CORS headers and
135/// `OPTIONS` preflight requests are answered with `204 No Content` without
136/// being forwarded to the router.
137pub fn build_app(router: Router, cors: bool) -> axum::Router {
138    let state = Arc::new(AppState { router, cors });
139    axum::Router::new().fallback(any(handler)).with_state(state)
140}
141
142/// Shared per-application state passed to the handler.
143#[derive(Clone)]
144struct AppState {
145    router: Router,
146    cors: bool,
147}
148
149/// The single request handler that backs every mockd route.
150async fn handler(
151    State(state): State<Arc<AppState>>,
152    method: AxumMethod,
153    uri: Uri,
154    headers: HeaderMap,
155    body: Bytes,
156) -> Response<Body> {
157    let method_str = method.as_str().to_string();
158    let path = uri.path().to_string();
159
160    // CORS preflight: must be handled before route matching because OPTIONS
161    // is not part of mockd's Method enum.
162    if state.cors
163        && method == AxumMethod::OPTIONS
164        && headers.contains_key("access-control-request-method")
165    {
166        tracing::info!(%method_str, %path, status = 204, "cors preflight");
167        return cors_preflight(&headers);
168    }
169
170    let Some(core_method) = Method::from_http_str(method.as_str()) else {
171        tracing::info!(%method_str, %path, status = 404, "unsupported method");
172        return not_found(state.cors);
173    };
174
175    let query = parse_query(uri.query().unwrap_or(""));
176    let header_map = collect_headers(&headers);
177    let request_body: Value = serde_json::from_slice(&body).unwrap_or(Value::Null);
178
179    let Some(Match {
180        path_params,
181        response,
182    }) = state
183        .router
184        .resolve(core_method, &path, &query, &header_map, &request_body)
185    else {
186        tracing::info!(%method_str, %path, status = 404, "no matching route");
187        return not_found(state.cors);
188    };
189
190    // Optional artificial delay (used to test timeouts).
191    if let Some(delay) = response.delay {
192        tracing::debug!(?delay, "applying artificial delay");
193        tokio::time::sleep(delay).await;
194    }
195
196    // Render template expressions in the body using the request context.
197    let rendered = response.body.map(|b| {
198        let ctx = TemplateContext {
199            path: path_params.clone(),
200            query: query.clone(),
201            headers: header_map.clone(),
202            body: request_body.clone(),
203        };
204        render(&b, &ctx)
205    });
206
207    let status = response.status;
208    let close_connection = response.close_connection;
209
210    let mut resp = build_response(status, &response.headers, rendered, close_connection)
211        .unwrap_or_else(|_| internal_error());
212
213    if state.cors {
214        add_cors_headers(resp.headers_mut());
215    }
216
217    tracing::info!(%method_str, %path, status, "handled");
218    resp
219}
220
/// Split a raw query string into a `key -> value` map.
///
/// Note: values are **not** percent-decoded in this MVP.
/// - Empty segments (e.g. from `&&`) are skipped.
/// - A key without `=` maps to an empty string.
/// - Only the first `=` separates key from value.
/// - When a key repeats, the last occurrence wins.
fn parse_query(query: &str) -> HashMap<String, String> {
    query
        .split('&')
        .filter(|segment| !segment.is_empty())
        .map(|segment| {
            let (key, value) = segment.split_once('=').unwrap_or((segment, ""));
            (key.to_string(), value.to_string())
        })
        .collect()
}
245
246/// Collect request headers into a map with lower-cased keys.
247fn collect_headers(headers: &HeaderMap) -> HashMap<String, String> {
248    let mut map = HashMap::new();
249    for (name, value) in headers.iter() {
250        let key = name.as_str().to_ascii_lowercase();
251        let val = value.to_str().unwrap_or("").to_string();
252        map.entry(key).or_insert(val);
253    }
254    map
255}
256
257/// Build an HTTP response from the mockd response definition.
258fn build_response(
259    status: u16,
260    headers: &HashMap<String, String>,
261    body: Option<Value>,
262    close_connection: bool,
263) -> Result<Response<Body>, ServerError> {
264    let status = StatusCode::from_u16(status).map_err(|_| ServerError::InvalidStatus(status))?;
265
266    let mut builder = Response::builder().status(status);
267
268    let has_content_type = headers
269        .keys()
270        .any(|k| k.eq_ignore_ascii_case("content-type"));
271
272    for (name, value) in headers {
273        builder = builder.header(name.as_str(), value.as_str());
274    }
275
276    if close_connection {
277        builder = builder.header("connection", "close");
278    }
279
280    let bytes = if let Some(body) = body {
281        if !has_content_type {
282            builder = builder.header("content-type", "application/json");
283        }
284        serde_json::to_vec(&body).unwrap_or_default()
285    } else {
286        Vec::new()
287    };
288
289    Ok(builder.body(Body::from(bytes)).unwrap())
290}
291
292/// Append the permissive CORS headers used for non-preflight responses.
293fn add_cors_headers(headers: &mut HeaderMap) {
294    headers.insert("access-control-allow-origin", HeaderValue::from_static("*"));
295    // The response varies by Origin, even though we always echo `*`, so that
296    // caches do not return a CORS-configured response to a non-CORS request.
297    headers.insert("vary", HeaderValue::from_static("origin"));
298}
299
300/// Build a `204 No Content` response for a CORS preflight.
301///
302/// The allowed request headers are echoed from `Access-Control-Request-Headers`
303/// if present, otherwise `*` is advertised.
304fn cors_preflight(req_headers: &HeaderMap) -> Response<Body> {
305    let allow_headers = req_headers
306        .get("access-control-request-headers")
307        .cloned()
308        .unwrap_or_else(|| HeaderValue::from_static("*"));
309
310    Response::builder()
311        .status(StatusCode::NO_CONTENT)
312        .header("access-control-allow-origin", "*")
313        .header(
314            "access-control-allow-methods",
315            "GET, POST, PUT, PATCH, DELETE, OPTIONS",
316        )
317        .header("access-control-allow-headers", allow_headers)
318        .header("access-control-max-age", "86400")
319        .header("vary", "origin")
320        .body(Body::empty())
321        .unwrap()
322}
323
324fn not_found(cors: bool) -> Response<Body> {
325    let mut resp = (
326        StatusCode::NOT_FOUND,
327        [(axum::http::header::CONTENT_TYPE, "application/json")],
328        r#"{"error":"no matching route"}"#,
329    )
330        .into_response();
331    if cors {
332        add_cors_headers(resp.headers_mut());
333    }
334    resp
335}
336
337fn internal_error() -> Response<Body> {
338    (
339        StatusCode::INTERNAL_SERVER_ERROR,
340        [(axum::http::header::CONTENT_TYPE, "application/json")],
341        r#"{"error":"internal mockd error"}"#,
342    )
343        .into_response()
344}
345
#[cfg(test)]
mod tests {
    use super::*;

    /// Read header `name` from `resp` as a string slice.
    ///
    /// Fails the test if the header is absent or not a valid string.
    fn header<'a>(resp: &'a Response<Body>, name: &str) -> &'a str {
        resp.headers()
            .get(name)
            .unwrap_or_else(|| panic!("missing `{name}` header"))
            .to_str()
            .expect("header value should be a valid string")
    }

    #[test]
    fn parse_query_basic() {
        let parsed = parse_query("role=admin&tenant=a&flag");
        assert_eq!(parsed["role"], "admin");
        assert_eq!(parsed["tenant"], "a");
        assert_eq!(parsed["flag"], "");
    }

    #[test]
    fn parse_query_empty() {
        assert!(parse_query("").is_empty());
    }

    #[test]
    fn collect_headers_lowercases() {
        let mut request = HeaderMap::new();
        request.insert("X-Tenant-Id", "a".parse().unwrap());
        let collected = collect_headers(&request);
        assert_eq!(collected["x-tenant-id"], "a");
    }

    #[test]
    fn build_response_sets_json_content_type_when_body_present() {
        let body = Some(serde_json::json!({"ok": true}));
        let resp = build_response(200, &HashMap::new(), body, false).unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(header(&resp, "content-type"), "application/json");
    }

    #[test]
    fn build_response_keeps_explicit_content_type() {
        let configured = HashMap::from([("Content-Type".to_string(), "text/plain".to_string())]);
        let body = Some(Value::String("hi".into()));
        let resp = build_response(200, &configured, body, false).unwrap();
        assert_eq!(header(&resp, "content-type"), "text/plain");
    }

    #[test]
    fn build_response_close_connection_header() {
        let resp = build_response(500, &HashMap::new(), None, true).unwrap();
        assert_eq!(header(&resp, "connection"), "close");
    }

    #[test]
    fn build_response_rejects_invalid_status() {
        // Status codes must be in the 100..=999 range; 6000 is rejected by
        // the underlying `StatusCode::from_u16`.
        let err = build_response(6000, &HashMap::new(), None, false).unwrap_err();
        assert!(matches!(err, ServerError::InvalidStatus(6000)));
    }

    #[test]
    fn normalize_listen_handles_shorthand() {
        assert_eq!(normalize_listen(":8080"), "0.0.0.0:8080");
        assert_eq!(normalize_listen("127.0.0.1:9000"), "127.0.0.1:9000");
        assert_eq!(normalize_listen("[::1]:8080"), "[::1]:8080");
    }

    #[test]
    fn cors_preflight_has_cors_headers() {
        let resp = cors_preflight(&HeaderMap::new());
        assert_eq!(resp.status(), StatusCode::NO_CONTENT);
        assert_eq!(header(&resp, "access-control-allow-origin"), "*");
        assert!(header(&resp, "access-control-allow-methods").contains("GET"));
        // No Access-Control-Request-Headers -> default to "*".
        assert_eq!(header(&resp, "access-control-allow-headers"), "*");
    }

    #[test]
    fn cors_preflight_echoes_requested_headers() {
        let mut request = HeaderMap::new();
        request.insert(
            "access-control-request-headers",
            "X-Tenant-Id, Authorization".parse().unwrap(),
        );
        let resp = cors_preflight(&request);
        assert_eq!(
            header(&resp, "access-control-allow-headers"),
            "X-Tenant-Id, Authorization"
        );
    }
}