Skip to main content

mockd/
server.rs

//! HTTP server layer (Axum).
//!
//! [`Server`] wires a compiled [`Router`] to an Axum application. A single
//! fallback handler receives every request, delegates matching to the router,
//! applies any configured delay, renders template expressions in the response
//! body and produces the HTTP response.
//!
//! Optional CORS support adds permissive cross-origin headers and answers
//! `OPTIONS` preflight requests before route matching.
//!
//! For testing, [`build_app`] returns a plain `axum::Router` that can be bound
//! to any address (including an ephemeral one).
13
14use std::collections::HashMap;
15use std::future::Future;
16use std::sync::Arc;
17
18use axum::body::{Body, Bytes};
19use axum::extract::State;
20use axum::http::{HeaderMap, HeaderValue, Method as AxumMethod, Response, StatusCode, Uri};
21use axum::response::IntoResponse;
22use axum::routing::any;
23use serde_json::Value;
24
25use crate::config::{Config, Method};
26use crate::router::{Match, Router, RouterError};
27use crate::template::{render, TemplateContext};
28
29/// A mockd server bound to a configuration.
30#[derive(Debug, Clone)]
31pub struct Server {
32    router: Router,
33    listen: String,
34    cors: bool,
35}
36
37/// Errors that can occur while building or running a [`Server`].
38#[derive(Debug, thiserror::Error)]
39pub enum ServerError {
40    /// The routes could not be compiled.
41    #[error("invalid routes: {0}")]
42    Router(#[from] RouterError),
43
44    /// The configured listen address could not be bound.
45    #[error("could not bind to {addr}: {source}")]
46    Bind {
47        addr: String,
48        #[source]
49        source: std::io::Error,
50    },
51
52    /// The server stopped with an error.
53    #[error("server error: {0}")]
54    Serve(#[source] std::io::Error),
55
56    /// A response status code was invalid.
57    #[error("invalid status code: {0}")]
58    InvalidStatus(u16),
59}
60
61impl Server {
62    /// Build a server from a parsed [`Config`].
63    ///
64    /// CORS is disabled by default; use [`Server::with_cors`] to enable it.
65    pub fn from_config(config: Config) -> Result<Self, ServerError> {
66        let router = Router::new(config.routes)?;
67        Ok(Server {
68            router,
69            listen: config.listen,
70            cors: false,
71        })
72    }
73
74    /// Enable or disable permissive CORS support.
75    pub fn with_cors(mut self, enabled: bool) -> Self {
76        self.cors = enabled;
77        self
78    }
79
80    /// Build the Axum application for this server.
81    ///
82    /// Exposed primarily for integration tests; [`Server::serve`] is the
83    /// normal entry point.
84    pub fn app(&self) -> axum::Router {
85        build_app(self.router.clone(), self.cors)
86    }
87
88    /// Number of compiled routes.
89    pub fn route_count(&self) -> usize {
90        self.router.len()
91    }
92
93    /// Bind the configured address and serve until interrupted.
94    pub async fn serve<F>(&self, shutdown: F) -> Result<(), ServerError>
95    where
96        F: Future<Output = ()> + Send + 'static,
97    {
98        let listen = normalize_listen(&self.listen);
99        let listener = tokio::net::TcpListener::bind(&listen)
100            .await
101            .map_err(|source| ServerError::Bind {
102                addr: listen.clone(),
103                source,
104            })?;
105        let addr = listener_local_addr(&listener);
106        tracing::info!("listening on {addr}");
107        let app = self.app();
108        axum::serve(listener, app)
109            .with_graceful_shutdown(shutdown)
110            .await
111            .map_err(ServerError::Serve)?;
112        Ok(())
113    }
114}
115
/// Expand the `:PORT` shorthand into an all-interfaces IPv4 address, following
/// the convention used by Go, nginx and others.
///
/// | input            | output           |
/// |------------------|------------------|
/// | `:8080`          | `0.0.0.0:8080`   |
/// | `127.0.0.1:8080` | unchanged        |
/// | `[::1]:8080`     | unchanged        |
fn normalize_listen(addr: &str) -> String {
    match addr.strip_prefix(':') {
        Some(port) => format!("0.0.0.0:{port}"),
        None => addr.to_owned(),
    }
}
129
130fn listener_local_addr(listener: &tokio::net::TcpListener) -> String {
131    listener
132        .local_addr()
133        .map(|a| a.to_string())
134        .unwrap_or_else(|_| "(unknown)".to_string())
135}
136
137/// Build an Axum application backed by the given router.
138///
139/// When `cors` is `true`, every response carries permissive CORS headers and
140/// `OPTIONS` preflight requests are answered with `204 No Content` without
141/// being forwarded to the router.
142pub fn build_app(router: Router, cors: bool) -> axum::Router {
143    let state = Arc::new(AppState { router, cors });
144    axum::Router::new().fallback(any(handler)).with_state(state)
145}
146
147/// Shared per-application state passed to the handler.
148#[derive(Clone)]
149struct AppState {
150    router: Router,
151    cors: bool,
152}
153
154/// The single request handler that backs every mockd route.
155async fn handler(
156    State(state): State<Arc<AppState>>,
157    method: AxumMethod,
158    uri: Uri,
159    headers: HeaderMap,
160    body: Bytes,
161) -> Response<Body> {
162    let method_str = method.as_str().to_string();
163    let path = uri.path().to_string();
164
165    // CORS preflight: must be handled before route matching because OPTIONS
166    // is not part of mockd's Method enum.
167    if state.cors
168        && method == AxumMethod::OPTIONS
169        && headers.contains_key("access-control-request-method")
170    {
171        tracing::info!(%method_str, %path, status = 204, "cors preflight");
172        return cors_preflight(&headers);
173    }
174
175    let Some(core_method) = Method::from_http_str(method.as_str()) else {
176        tracing::info!(%method_str, %path, status = 404, "unsupported method");
177        return not_found(state.cors);
178    };
179
180    let query = parse_query(uri.query().unwrap_or(""));
181    let header_map = collect_headers(&headers);
182    let request_body: Value = serde_json::from_slice(&body).unwrap_or(Value::Null);
183
184    let Some(Match {
185        path_params,
186        response,
187    }) = state
188        .router
189        .resolve(core_method, &path, &query, &header_map, &request_body)
190    else {
191        tracing::info!(%method_str, %path, status = 404, "no matching route");
192        return not_found(state.cors);
193    };
194
195    // Optional artificial delay (used to test timeouts).
196    if let Some(delay) = response.delay {
197        tracing::debug!(?delay, "applying artificial delay");
198        tokio::time::sleep(delay).await;
199    }
200
201    // Render template expressions in the body using the request context.
202    let rendered = response.body.map(|b| {
203        let ctx = TemplateContext {
204            path: path_params.clone(),
205            query: query.clone(),
206            headers: header_map.clone(),
207            body: request_body.clone(),
208        };
209        render(&b, &ctx)
210    });
211
212    let status = response.status;
213    let close_connection = response.close_connection;
214
215    let mut resp = build_response(status, &response.headers, rendered, close_connection)
216        .unwrap_or_else(|_| internal_error());
217
218    if state.cors {
219        add_cors_headers(resp.headers_mut());
220    }
221
222    tracing::info!(%method_str, %path, status, "handled");
223    resp
224}
225
/// Split a raw query string into a key/value map.
///
/// Values are **not** percent-decoded in this MVP. A key without `=` maps to
/// an empty string, empty `&&` segments are ignored, and when a key repeats
/// the last occurrence wins.
fn parse_query(query: &str) -> HashMap<String, String> {
    query
        .split('&')
        .filter(|segment| !segment.is_empty())
        .map(|segment| {
            let (key, value) = segment.split_once('=').unwrap_or((segment, ""));
            (key.to_owned(), value.to_owned())
        })
        .collect()
}
250
251/// Collect request headers into a map with lower-cased keys.
252fn collect_headers(headers: &HeaderMap) -> HashMap<String, String> {
253    let mut map = HashMap::new();
254    for (name, value) in headers.iter() {
255        let key = name.as_str().to_ascii_lowercase();
256        let val = value.to_str().unwrap_or("").to_string();
257        map.entry(key).or_insert(val);
258    }
259    map
260}
261
262/// Build an HTTP response from the mockd response definition.
263fn build_response(
264    status: u16,
265    headers: &HashMap<String, String>,
266    body: Option<Value>,
267    close_connection: bool,
268) -> Result<Response<Body>, ServerError> {
269    let status = StatusCode::from_u16(status).map_err(|_| ServerError::InvalidStatus(status))?;
270
271    let mut builder = Response::builder().status(status);
272
273    let has_content_type = headers
274        .keys()
275        .any(|k| k.eq_ignore_ascii_case("content-type"));
276
277    for (name, value) in headers {
278        builder = builder.header(name.as_str(), value.as_str());
279    }
280
281    if close_connection {
282        builder = builder.header("connection", "close");
283    }
284
285    let bytes = if let Some(body) = body {
286        if !has_content_type {
287            builder = builder.header("content-type", "application/json");
288        }
289        serde_json::to_vec(&body).unwrap_or_default()
290    } else {
291        Vec::new()
292    };
293
294    Ok(builder.body(Body::from(bytes)).unwrap())
295}
296
297/// Append the permissive CORS headers used for non-preflight responses.
298fn add_cors_headers(headers: &mut HeaderMap) {
299    headers.insert("access-control-allow-origin", HeaderValue::from_static("*"));
300    // The response varies by Origin, even though we always echo `*`, so that
301    // caches do not return a CORS-configured response to a non-CORS request.
302    headers.insert("vary", HeaderValue::from_static("origin"));
303}
304
305/// Build a `204 No Content` response for a CORS preflight.
306///
307/// The allowed request headers are echoed from `Access-Control-Request-Headers`
308/// if present, otherwise `*` is advertised.
309fn cors_preflight(req_headers: &HeaderMap) -> Response<Body> {
310    let allow_headers = req_headers
311        .get("access-control-request-headers")
312        .cloned()
313        .unwrap_or_else(|| HeaderValue::from_static("*"));
314
315    Response::builder()
316        .status(StatusCode::NO_CONTENT)
317        .header("access-control-allow-origin", "*")
318        .header(
319            "access-control-allow-methods",
320            "GET, POST, PUT, PATCH, DELETE, OPTIONS",
321        )
322        .header("access-control-allow-headers", allow_headers)
323        .header("access-control-max-age", "86400")
324        .header("vary", "origin")
325        .body(Body::empty())
326        .unwrap()
327}
328
329fn not_found(cors: bool) -> Response<Body> {
330    let mut resp = (
331        StatusCode::NOT_FOUND,
332        [(axum::http::header::CONTENT_TYPE, "application/json")],
333        r#"{"error":"no matching route"}"#,
334    )
335        .into_response();
336    if cors {
337        add_cors_headers(resp.headers_mut());
338    }
339    resp
340}
341
342fn internal_error() -> Response<Body> {
343    (
344        StatusCode::INTERNAL_SERVER_ERROR,
345        [(axum::http::header::CONTENT_TYPE, "application/json")],
346        r#"{"error":"internal mockd error"}"#,
347    )
348        .into_response()
349}
350
#[cfg(test)]
mod tests {
    use super::*;

    /// Read header `name` from `resp` as a string slice, failing the test if
    /// it is missing or not visible ASCII.
    fn header_str<'a>(resp: &'a Response<Body>, name: &str) -> &'a str {
        resp.headers()
            .get(name)
            .unwrap_or_else(|| panic!("response has no `{name}` header"))
            .to_str()
            .expect("header value should be visible ASCII")
    }

    #[test]
    fn parse_query_basic() {
        let q = parse_query("role=admin&tenant=a&flag");
        assert_eq!(q["role"], "admin");
        assert_eq!(q["tenant"], "a");
        assert_eq!(q["flag"], "");
    }

    #[test]
    fn parse_query_empty() {
        assert!(parse_query("").is_empty());
    }

    #[test]
    fn collect_headers_lowercases() {
        let mut request_headers = HeaderMap::new();
        request_headers.insert("X-Tenant-Id", "a".parse().unwrap());
        let collected = collect_headers(&request_headers);
        assert_eq!(collected["x-tenant-id"], "a");
    }

    #[test]
    fn build_response_sets_json_content_type_when_body_present() {
        let body = Some(serde_json::json!({"ok": true}));
        let resp = build_response(200, &HashMap::new(), body, false).unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(header_str(&resp, "content-type"), "application/json");
    }

    #[test]
    fn build_response_keeps_explicit_content_type() {
        let configured =
            HashMap::from([("Content-Type".to_string(), "text/plain".to_string())]);
        let body = Some(Value::String("hi".into()));
        let resp = build_response(200, &configured, body, false).unwrap();
        assert_eq!(header_str(&resp, "content-type"), "text/plain");
    }

    #[test]
    fn build_response_close_connection_header() {
        let resp = build_response(500, &HashMap::new(), None, true).unwrap();
        assert_eq!(header_str(&resp, "connection"), "close");
    }

    #[test]
    fn build_response_rejects_invalid_status() {
        // `StatusCode::from_u16` only accepts codes in 100..=999, so 6000
        // must be rejected.
        let err = build_response(6000, &HashMap::new(), None, false).unwrap_err();
        assert!(matches!(err, ServerError::InvalidStatus(6000)));
    }

    #[test]
    fn normalize_listen_handles_shorthand() {
        let cases = [
            (":8080", "0.0.0.0:8080"),
            ("127.0.0.1:9000", "127.0.0.1:9000"),
            ("[::1]:8080", "[::1]:8080"),
        ];
        for (input, expected) in cases {
            assert_eq!(normalize_listen(input), expected);
        }
    }

    #[test]
    fn cors_preflight_has_cors_headers() {
        let resp = cors_preflight(&HeaderMap::new());
        assert_eq!(resp.status(), StatusCode::NO_CONTENT);
        assert_eq!(header_str(&resp, "access-control-allow-origin"), "*");
        assert!(header_str(&resp, "access-control-allow-methods").contains("GET"));
        // Without Access-Control-Request-Headers the preflight advertises "*".
        assert_eq!(header_str(&resp, "access-control-allow-headers"), "*");
    }

    #[test]
    fn cors_preflight_echoes_requested_headers() {
        let mut req = HeaderMap::new();
        req.insert(
            "access-control-request-headers",
            "X-Tenant-Id, Authorization".parse().unwrap(),
        );
        let resp = cors_preflight(&req);
        assert_eq!(
            header_str(&resp, "access-control-allow-headers"),
            "X-Tenant-Id, Authorization"
        );
    }
}