Skip to main content

logdive_api/
router.rs

1//! Axum router construction.
2//!
3//! Extracted from `main.rs` so integration tests can build the same router
4//! the binary uses without duplicating route definitions. The router is
5//! pure data — no I/O happens here; `AppState` carries the configuration
6//! and all I/O is deferred into the handler layer.
7//!
8//! [`build_router`] accepts a pre-parsed list of allowed CORS origins.
9//! Parsing, validation, and the wildcard `*` check happen in `main.rs`,
10//! keeping this module free of clap and environment-variable concerns.
11
12use axum::{
13    Router,
14    http::{HeaderValue, Method},
15    routing::get,
16};
17use tower_http::cors::{AllowOrigin, Any, CorsLayer};
18
19use crate::handlers::{query_handler, stats_handler, version_handler};
20use crate::state::AppState;
21
22/// Build the application router with all endpoints wired up.
23///
24/// `cors_origins` controls the CORS policy applied to every route:
25///
26/// - `[]` (empty) — CORS disabled; no `Access-Control-Allow-Origin` header
27///   is ever added. This is the default and appropriate for local or
28///   server-side-only consumers.
29/// - `[HeaderValue::from_static("*")]` — wildcard; any origin is reflected.
30///   The single-element invariant is enforced by the caller in `main.rs`.
31/// - Any other non-empty list — exactly those origins are allowed; the
32///   matching origin is reflected back in the response header.
33///
34/// The returned router is ready for `axum::serve` in the binary or
35/// `tower::ServiceExt::oneshot` in tests.
36pub fn build_router(state: AppState, cors_origins: Vec<HeaderValue>) -> Router {
37    let router = Router::new()
38        .route("/query", get(query_handler))
39        .route("/stats", get(stats_handler))
40        .route("/version", get(version_handler))
41        .with_state(state);
42
43    match build_cors_layer(cors_origins) {
44        Some(cors) => router.layer(cors),
45        None => router,
46    }
47}
48
49/// Construct a [`CorsLayer`] from the parsed origin list, or return `None`
50/// when CORS should be disabled entirely (empty list).
51///
52/// Kept private and separate from [`build_router`] so the CORS policy
53/// logic is testable without constructing a full router.
54///
55/// Allowed methods are locked to `GET` only — the API is read-only and
56/// must never advertise write methods to cross-origin callers.
57fn build_cors_layer(origins: Vec<HeaderValue>) -> Option<CorsLayer> {
58    if origins.is_empty() {
59        return None;
60    }
61
62    let base = CorsLayer::new().allow_methods([Method::GET]);
63
64    // A single raw `*` byte sequence means "allow any origin". Comparing
65    // bytes rather than going through `HeaderValue::from_static` is
66    // deliberate: it matches regardless of how the caller constructed the
67    // HeaderValue and avoids a redundant allocation.
68    if origins.len() == 1 && origins[0].as_bytes() == b"*" {
69        Some(base.allow_origin(Any))
70    } else {
71        Some(base.allow_origin(AllowOrigin::list(origins)))
72    }
73}
74
75// ---------------------------------------------------------------------------
76// Tests
77// ---------------------------------------------------------------------------
78
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::{Request, Response, header};
    use std::path::PathBuf;
    use tower::ServiceExt; // for `oneshot`

    /// Build an `AppState` with a placeholder database path; the endpoints
    /// exercised here do not need the file to exist.
    fn make_state() -> AppState {
        AppState::new(PathBuf::from("/tmp/does-not-need-to-exist-yet.db"))
    }

    /// Construct a GET /version request with the given Origin header.
    /// /version needs no DB, so make_state's phantom path is fine.
    fn version_request(origin: &str) -> Request<Body> {
        Request::builder()
            .uri("/version")
            .header(header::ORIGIN, origin)
            .body(Body::empty())
            .unwrap()
    }

    /// Return the `Access-Control-Allow-Origin` header as a string, if any.
    fn acao<B>(resp: &Response<B>) -> Option<&str> {
        resp.headers()
            .get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
            .and_then(|v| v.to_str().ok())
    }

    #[test]
    fn build_router_produces_a_router_from_a_valid_state() {
        // Compile-time and type-plumbing smoke test. Real behaviour is
        // validated by the integration test suite and the CORS tests below.
        let _router: Router = build_router(make_state(), vec![]);
    }

    #[test]
    fn empty_origin_list_yields_no_cors_layer() {
        assert!(build_cors_layer(Vec::new()).is_none());
    }

    #[test]
    fn non_empty_origin_lists_yield_a_cors_layer() {
        assert!(build_cors_layer(vec![HeaderValue::from_static("*")]).is_some());
        let allowed: HeaderValue = "https://app.example.com".parse().unwrap();
        assert!(build_cors_layer(vec![allowed]).is_some());
    }

    #[tokio::test]
    async fn no_cors_origins_does_not_add_acao_header() {
        let resp = build_router(make_state(), vec![])
            .oneshot(version_request("https://example.com"))
            .await
            .unwrap();
        assert!(
            acao(&resp).is_none(),
            "CORS disabled: ACAO header must be absent"
        );
    }

    #[tokio::test]
    async fn wildcard_cors_adds_acao_star_for_any_origin() {
        let resp = build_router(make_state(), vec![HeaderValue::from_static("*")])
            .oneshot(version_request("https://example.com"))
            .await
            .unwrap();
        assert_eq!(acao(&resp), Some("*"));
    }

    #[tokio::test]
    async fn specific_origin_reflects_matching_origin() {
        let allowed: HeaderValue = "https://app.example.com".parse().unwrap();
        let resp = build_router(make_state(), vec![allowed])
            .oneshot(version_request("https://app.example.com"))
            .await
            .unwrap();
        assert_eq!(acao(&resp), Some("https://app.example.com"));
    }

    #[tokio::test]
    async fn specific_origin_omits_acao_for_unmatched_origin() {
        let allowed: HeaderValue = "https://app.example.com".parse().unwrap();
        let resp = build_router(make_state(), vec![allowed])
            .oneshot(version_request("https://evil.example.com"))
            .await
            .unwrap();
        // CORS enforcement is the browser's job. The server omits ACAO for
        // a non-matching origin but does not reject the request outright —
        // that is correct per spec.
        assert!(
            acao(&resp).is_none(),
            "non-allowed origin must not receive ACAO header"
        );
    }

    #[tokio::test]
    async fn preflight_advertises_only_get() {
        // The API is read-only: even when a browser asks to POST, the
        // advertised method list must contain GET alone.
        let allowed: HeaderValue = "https://app.example.com".parse().unwrap();
        let preflight = Request::builder()
            .method(Method::OPTIONS)
            .uri("/version")
            .header(header::ORIGIN, "https://app.example.com")
            .header(header::ACCESS_CONTROL_REQUEST_METHOD, "POST")
            .body(Body::empty())
            .unwrap();
        let resp = build_router(make_state(), vec![allowed])
            .oneshot(preflight)
            .await
            .unwrap();
        assert_eq!(
            resp.headers()
                .get(header::ACCESS_CONTROL_ALLOW_METHODS)
                .and_then(|v| v.to_str().ok()),
            Some("GET"),
            "read-only API must only advertise GET to cross-origin callers"
        );
    }
}