pub mod auth;
pub mod error;
pub mod jetstream;
pub mod routes;
pub mod state;
pub mod ws;
use axum::extract::DefaultBodyLimit;
use axum::http::{header, Method};
use axum::{
middleware::map_response,
routing::{delete, get, post},
Router,
};
use tower_http::cors::{Any, CorsLayer};
use tower_http::trace::TraceLayer;
use crate::error::{normalize_problem_response, problem_response, AppErrorKind};
pub use state::{AppState, CellRecord, CellState, FormationRecord, FormationStatus};
const FORMATIONS_POST_MAX_BYTES: usize = 64 * 1024;
pub fn router(state: AppState) -> Router {
Router::new()
.route(
"/v1/formations",
post(routes::formations::create_formation)
.layer(DefaultBodyLimit::max(FORMATIONS_POST_MAX_BYTES)),
)
.route("/v1/formations", get(routes::formations::list_formations))
.route("/v1/formations/:id", get(routes::formations::get_formation))
.route(
"/v1/formations/:id",
delete(routes::formations::delete_formation),
)
.route(
"/v1/formations/by-name/:name",
get(routes::formations::get_formation_by_name),
)
.route(
"/v1/formations/by-name/:name",
delete(routes::formations::delete_formation_by_name),
)
.route(
"/v1/formations/:id/status",
post(routes::formations::update_formation_status),
)
.route("/v1/cells", get(routes::cells::list_cells))
.route("/v1/cells/:id", get(routes::cells::get_cell))
.route("/v1/version", get(routes::meta::get_version))
.route("/v1/events", get(routes::events::list_events))
.route("/ws/events", get(ws::ws_events))
.fallback(not_found_handler)
.method_not_allowed_fallback(method_not_allowed_handler)
.layer(TraceLayer::new_for_http())
.layer(
CorsLayer::new()
.allow_origin(Any)
.allow_methods([Method::GET, Method::OPTIONS])
.allow_headers([header::AUTHORIZATION, header::CONTENT_TYPE]),
)
.layer(map_response(normalize_problem_response))
.with_state(state)
}
/// Maximum number of characters of client-supplied request text (path, method)
/// echoed back in an error detail. It bounds how much of the request is
/// reflected into the response body.
const MAX_ECHOED_CHARS: usize = 200;

/// Returns at most [`MAX_ECHOED_CHARS`] characters of `s` as an owned string.
///
/// Counts `char`s rather than bytes, so the cut never lands inside a multi-byte
/// UTF-8 sequence.
fn truncate_for_detail(s: &str) -> String {
    s.chars().take(MAX_ECHOED_CHARS).collect()
}

/// Router fallback for requests whose path matches no registered route.
///
/// Produces an `AppErrorKind::NotFound` problem response whose detail quotes the
/// request path, truncated by [`truncate_for_detail`].
async fn not_found_handler(req: axum::extract::Request) -> axum::response::Response {
    let path = truncate_for_detail(req.uri().path());
    problem_response(AppErrorKind::NotFound, format!("no route matched '{path}'"))
}

/// Router fallback for requests whose path matched a route but whose HTTP method
/// has no handler there.
///
/// Produces an `AppErrorKind::MethodNotAllowed` problem response quoting both
/// the method and the path.
///
/// The method is truncated exactly like the path. HTTP extension methods are
/// arbitrary tokens and `http::Method` places no length cap on them, so echoing
/// the method verbatim would reflect unbounded client input into the body.
async fn method_not_allowed_handler(req: axum::extract::Request) -> axum::response::Response {
    let method = truncate_for_detail(req.method().as_str());
    let path = truncate_for_detail(req.uri().path());
    problem_response(
        AppErrorKind::MethodNotAllowed,
        format!("method '{method}' not allowed for '{path}' — see Allow header"),
    )
}
#[cfg(test)]
mod cors_tests {
    use super::*;
    use axum::body::Body;
    use axum::http::{header, Method, Request, StatusCode};
    use tower::ServiceExt;

    /// Fires a CORS preflight at `/v1/formations` and returns the allowed methods.
    ///
    /// The preflight is `OPTIONS` from `origin`, asking permission to use
    /// `requested_method` with an `authorization` header, sent to a freshly built
    /// router. The helper asserts the preflight is answered with `200 OK`.
    ///
    /// It returns the upper-cased `Access-Control-Allow-Methods` value. That is
    /// the empty string when the header is missing or not valid visible ASCII.
    async fn preflight_allowed_methods(origin: &str, requested_method: &str) -> String {
        let app = router(AppState::new(None, "test-token"));

        let preflight = Request::builder()
            .method(Method::OPTIONS)
            .uri("/v1/formations")
            .header(header::ORIGIN, origin)
            .header(header::ACCESS_CONTROL_REQUEST_METHOD, requested_method)
            .header(header::ACCESS_CONTROL_REQUEST_HEADERS, "authorization")
            .body(Body::empty())
            .unwrap();

        let resp = app.oneshot(preflight).await.expect("router response");
        assert_eq!(resp.status(), StatusCode::OK);

        resp.headers()
            .get(header::ACCESS_CONTROL_ALLOW_METHODS)
            .and_then(|value| value.to_str().ok())
            .map(str::to_ascii_uppercase)
            .unwrap_or_default()
    }

    /// A hostile origin asking for POST must not see POST advertised, while GET
    /// still is (which also proves the header was present at all).
    #[tokio::test]
    async fn cors_preflight_for_post_does_not_allow_post() {
        let allow_methods = preflight_allowed_methods("http://attacker.example", "POST").await;

        assert!(
            !allow_methods.contains("POST"),
            "POST must not appear in Access-Control-Allow-Methods (got {allow_methods:?})",
        );
        assert!(
            allow_methods.contains("GET"),
            "GET must appear in Access-Control-Allow-Methods (got {allow_methods:?})",
        );
    }

    /// A preflight asking for GET is answered with GET among the allowed methods.
    #[tokio::test]
    async fn cors_preflight_for_get_is_allowed() {
        let allow_methods = preflight_allowed_methods("http://localhost:9999", "GET").await;

        assert!(
            allow_methods.contains("GET"),
            "GET must be in Access-Control-Allow-Methods (got {allow_methods:?})",
        );
    }
}