logdive_api/router.rs
1//! Axum router construction.
2//!
3//! Extracted from `main.rs` so integration tests can build the same router
4//! the binary uses without duplicating route definitions. The router is
5//! pure data — no I/O happens here; `AppState` carries the configuration
6//! and all I/O is deferred into the handler layer.
7//!
8//! [`build_router`] accepts a pre-parsed list of allowed CORS origins.
9//! Parsing, validation, and the wildcard `*` check happen in `main.rs`,
10//! keeping this module free of clap and environment-variable concerns.
11
12use axum::{
13 Router,
14 http::{HeaderValue, Method},
15 routing::get,
16};
17use tower_http::cors::{AllowOrigin, Any, CorsLayer};
18
19use crate::handlers::{query_handler, stats_handler, version_handler};
20use crate::state::AppState;
21
22/// Build the application router with all endpoints wired up.
23///
24/// `cors_origins` controls the CORS policy applied to every route:
25///
26/// - `[]` (empty) — CORS disabled; no `Access-Control-Allow-Origin` header
27/// is ever added. This is the default and appropriate for local or
28/// server-side-only consumers.
29/// - `[HeaderValue::from_static("*")]` — wildcard; any origin is reflected.
30/// The single-element invariant is enforced by the caller in `main.rs`.
31/// - Any other non-empty list — exactly those origins are allowed; the
32/// matching origin is reflected back in the response header.
33///
34/// The returned router is ready for `axum::serve` in the binary or
35/// `tower::ServiceExt::oneshot` in tests.
36pub fn build_router(state: AppState, cors_origins: Vec<HeaderValue>) -> Router {
37 let router = Router::new()
38 .route("/query", get(query_handler))
39 .route("/stats", get(stats_handler))
40 .route("/version", get(version_handler))
41 .with_state(state);
42
43 match build_cors_layer(cors_origins) {
44 Some(cors) => router.layer(cors),
45 None => router,
46 }
47}
48
49/// Construct a [`CorsLayer`] from the parsed origin list, or return `None`
50/// when CORS should be disabled entirely (empty list).
51///
52/// Kept private and separate from [`build_router`] so the CORS policy
53/// logic is testable without constructing a full router.
54///
55/// Allowed methods are locked to `GET` only — the API is read-only and
56/// must never advertise write methods to cross-origin callers.
57fn build_cors_layer(origins: Vec<HeaderValue>) -> Option<CorsLayer> {
58 if origins.is_empty() {
59 return None;
60 }
61
62 let base = CorsLayer::new().allow_methods([Method::GET]);
63
64 // A single raw `*` byte sequence means "allow any origin". Comparing
65 // bytes rather than going through `HeaderValue::from_static` is
66 // deliberate: it matches regardless of how the caller constructed the
67 // HeaderValue and avoids a redundant allocation.
68 if origins.len() == 1 && origins[0].as_bytes() == b"*" {
69 Some(base.allow_origin(Any))
70 } else {
71 Some(base.allow_origin(AllowOrigin::list(origins)))
72 }
73}
74
75// ---------------------------------------------------------------------------
76// Tests
77// ---------------------------------------------------------------------------
78
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request, Response, header},
    };
    use std::path::PathBuf;
    use tower::ServiceExt; // for `oneshot`

    /// State pointing at a database path that these tests never open:
    /// `/version` and CORS preflights are answered without touching the DB.
    fn make_state() -> AppState {
        AppState::new(PathBuf::from("/tmp/does-not-need-to-exist-yet.db"))
    }

    /// Construct a GET /version request with the given Origin header.
    fn version_request(origin: &str) -> Request<Body> {
        Request::builder()
            .uri("/version")
            .header(header::ORIGIN, origin)
            .body(Body::empty())
            .unwrap()
    }

    /// Construct a CORS preflight (OPTIONS /version) asking permission
    /// for a cross-origin GET from `origin`.
    fn version_preflight(origin: &str) -> Request<Body> {
        Request::builder()
            .method(Method::OPTIONS)
            .uri("/version")
            .header(header::ORIGIN, origin)
            .header(header::ACCESS_CONTROL_REQUEST_METHOD, "GET")
            .body(Body::empty())
            .unwrap()
    }

    /// Read a response header as `&str`. Returns `None` when the header is
    /// absent or not valid visible ASCII.
    fn header_str<B>(resp: &Response<B>, name: header::HeaderName) -> Option<&str> {
        resp.headers().get(name).and_then(|v| v.to_str().ok())
    }

    #[test]
    fn build_router_produces_a_router_from_a_valid_state() {
        // Compile-time and type-plumbing smoke test. Real behaviour is
        // validated by the integration test suite and the CORS tests below.
        let _router: Router = build_router(make_state(), vec![]);
    }

    #[test]
    fn empty_origin_list_builds_no_cors_layer() {
        assert!(build_cors_layer(vec![]).is_none());
    }

    #[test]
    fn non_empty_origin_list_builds_a_cors_layer() {
        assert!(build_cors_layer(vec![HeaderValue::from_static("*")]).is_some());
        assert!(
            build_cors_layer(vec![HeaderValue::from_static("https://app.example.com")])
                .is_some()
        );
    }

    #[tokio::test]
    async fn no_cors_origins_does_not_add_acao_header() {
        let resp = build_router(make_state(), vec![])
            .oneshot(version_request("https://example.com"))
            .await
            .unwrap();
        assert!(
            header_str(&resp, header::ACCESS_CONTROL_ALLOW_ORIGIN).is_none(),
            "CORS disabled: ACAO header must be absent"
        );
    }

    #[tokio::test]
    async fn wildcard_cors_adds_acao_star_for_any_origin() {
        let resp = build_router(make_state(), vec![HeaderValue::from_static("*")])
            .oneshot(version_request("https://example.com"))
            .await
            .unwrap();
        assert_eq!(
            header_str(&resp, header::ACCESS_CONTROL_ALLOW_ORIGIN),
            Some("*"),
        );
    }

    #[tokio::test]
    async fn preflight_advertises_get_only() {
        let resp = build_router(make_state(), vec![HeaderValue::from_static("*")])
            .oneshot(version_preflight("https://example.com"))
            .await
            .unwrap();
        assert_eq!(
            header_str(&resp, header::ACCESS_CONTROL_ALLOW_METHODS),
            Some("GET"),
            "read-only API must advertise GET and nothing else"
        );
    }

    #[tokio::test]
    async fn specific_origin_reflects_matching_origin() {
        let allowed: HeaderValue = "https://app.example.com".parse().unwrap();
        let resp = build_router(make_state(), vec![allowed])
            .oneshot(version_request("https://app.example.com"))
            .await
            .unwrap();
        assert_eq!(
            header_str(&resp, header::ACCESS_CONTROL_ALLOW_ORIGIN),
            Some("https://app.example.com"),
        );
    }

    #[tokio::test]
    async fn specific_origin_omits_acao_for_unmatched_origin() {
        let allowed: HeaderValue = "https://app.example.com".parse().unwrap();
        let resp = build_router(make_state(), vec![allowed])
            .oneshot(version_request("https://evil.example.com"))
            .await
            .unwrap();
        // CORS enforcement is the browser's job. The server omits ACAO for
        // a non-matching origin but does not reject the request outright —
        // that is correct per spec.
        assert!(
            header_str(&resp, header::ACCESS_CONTROL_ALLOW_ORIGIN).is_none(),
            "non-allowed origin must not receive ACAO header"
        );
    }
}