use axum::{
Router,
http::{HeaderValue, Method},
routing::get,
};
use tower_http::cors::{AllowOrigin, Any, CorsLayer};
use crate::handlers::{query_handler, stats_handler, version_handler};
use crate::state::AppState;
/// Builds the application's HTTP router.
///
/// Registers `GET /query`, `GET /stats` and `GET /version`, all sharing `state`.
/// `cors_origins` is handed to [`build_cors_layer`]; when that yields a layer it wraps
/// every route, otherwise the router is returned without any CORS handling.
pub fn build_router(state: AppState, cors_origins: Vec<HeaderValue>) -> Router {
    let app = Router::new()
        .route("/query", get(query_handler))
        .route("/stats", get(stats_handler))
        .route("/version", get(version_handler))
        .with_state(state);

    if let Some(cors_layer) = build_cors_layer(cors_origins) {
        app.layer(cors_layer)
    } else {
        app
    }
}
fn build_cors_layer(origins: Vec<HeaderValue>) -> Option<CorsLayer> {
if origins.is_empty() {
return None;
}
let base = CorsLayer::new().allow_methods([Method::GET]);
if origins.len() == 1 && origins[0].as_bytes() == b"*" {
Some(base.allow_origin(Any))
} else {
Some(base.allow_origin(AllowOrigin::list(origins)))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;
    use tower::ServiceExt;

    /// Name of the response header every CORS test inspects.
    const ACAO: &str = "access-control-allow-origin";

    /// Application state for router tests. Judging by its name, the database path is not
    /// expected to exist when the router is built (NOTE(review): confirm `AppState::new`
    /// does not open it eagerly).
    fn make_state() -> AppState {
        let db_path = PathBuf::from("/tmp/does-not-need-to-exist-yet.db");
        AppState::new(db_path)
    }

    /// A bodiless `GET /version` request carrying the given `Origin` header.
    fn version_request(origin: &str) -> axum::http::Request<axum::body::Body> {
        let builder = axum::http::Request::builder()
            .uri("/version")
            .header("Origin", origin);
        builder.body(axum::body::Body::empty()).unwrap()
    }

    /// Sends `version_request(origin)` through a router configured with `cors_origins`
    /// and returns the response's `Access-Control-Allow-Origin` header, if present.
    async fn acao_header(cors_origins: Vec<HeaderValue>, origin: &str) -> Option<HeaderValue> {
        let app = build_router(make_state(), cors_origins);
        let response = app.oneshot(version_request(origin)).await.unwrap();
        response.headers().get(ACAO).cloned()
    }

    #[test]
    fn build_router_produces_a_router_from_a_valid_state() {
        let _router: Router = build_router(make_state(), Vec::new());
    }

    #[tokio::test]
    async fn no_cors_origins_does_not_add_acao_header() {
        let acao = acao_header(Vec::new(), "https://example.com").await;
        assert!(acao.is_none(), "CORS disabled: ACAO header must be absent");
    }

    #[tokio::test]
    async fn wildcard_cors_adds_acao_star_for_any_origin() {
        let wildcard = vec![HeaderValue::from_static("*")];
        let acao = acao_header(wildcard, "https://example.com").await;
        assert_eq!(acao.as_ref().and_then(|v| v.to_str().ok()), Some("*"));
    }

    #[tokio::test]
    async fn specific_origin_reflects_matching_origin() {
        let allowed: HeaderValue = "https://app.example.com".parse().unwrap();
        let acao = acao_header(vec![allowed], "https://app.example.com").await;
        assert_eq!(
            acao.as_ref().and_then(|v| v.to_str().ok()),
            Some("https://app.example.com"),
        );
    }

    #[tokio::test]
    async fn specific_origin_omits_acao_for_unmatched_origin() {
        let allowed: HeaderValue = "https://app.example.com".parse().unwrap();
        let acao = acao_header(vec![allowed], "https://evil.example.com").await;
        assert!(
            acao.is_none(),
            "non-allowed origin must not receive ACAO header"
        );
    }
}