pub mod auth;
pub mod error;
pub mod jetstream;
pub mod routes;
pub mod state;
pub mod ws;
use axum::extract::DefaultBodyLimit;
use axum::http::{header, Method};
use axum::{
routing::{delete, get, post},
Router,
};
use tower_http::cors::{Any, CorsLayer};
use tower_http::trace::TraceLayer;
pub use state::{AppState, CellRecord, CellState, FormationRecord, FormationStatus};
/// Upper bound, in bytes, on the request body accepted by `POST /v1/formations` (64 KiB).
const FORMATIONS_POST_MAX_BYTES: usize = 65_536;
pub fn router(state: AppState) -> Router {
Router::new()
.route(
"/v1/formations",
post(routes::formations::create_formation)
.layer(DefaultBodyLimit::max(FORMATIONS_POST_MAX_BYTES)),
)
.route("/v1/formations", get(routes::formations::list_formations))
.route("/v1/formations/:id", get(routes::formations::get_formation))
.route(
"/v1/formations/:id",
delete(routes::formations::delete_formation),
)
.route(
"/v1/formations/by-name/:name",
get(routes::formations::get_formation_by_name),
)
.route(
"/v1/formations/by-name/:name",
delete(routes::formations::delete_formation_by_name),
)
.route(
"/v1/formations/:id/status",
post(routes::formations::update_formation_status),
)
.route("/v1/cells", get(routes::cells::list_cells))
.route("/v1/cells/:id", get(routes::cells::get_cell))
.route("/v1/version", get(routes::meta::get_version))
.route("/v1/events", get(routes::events::list_events))
.route("/ws/events", get(ws::ws_events))
.layer(TraceLayer::new_for_http())
.layer(
CorsLayer::new()
.allow_origin(Any)
.allow_methods([Method::GET, Method::OPTIONS])
.allow_headers([header::AUTHORIZATION, header::CONTENT_TYPE]),
)
.with_state(state)
}
#[cfg(test)]
mod cors_tests {
    use super::*;
    use axum::body::Body;
    use axum::http::{header, Method, Request, StatusCode};
    use tower::ServiceExt;

    /// Sends a CORS preflight (`OPTIONS`) asking to use `requested_method` on
    /// `uri` from `origin`, asserts the router answers `200 OK`, and returns the
    /// upper-cased `Access-Control-Allow-Methods` header value (empty when the
    /// header is absent).
    async fn preflight_allow_methods(uri: &str, origin: &str, requested_method: &str) -> String {
        let app = router(AppState::new(None, "test-token"));
        let req = Request::builder()
            .method(Method::OPTIONS)
            .uri(uri)
            .header(header::ORIGIN, origin)
            .header(header::ACCESS_CONTROL_REQUEST_METHOD, requested_method)
            .header(header::ACCESS_CONTROL_REQUEST_HEADERS, "authorization")
            .body(Body::empty())
            .expect("preflight request built from valid static parts");
        let resp = app.oneshot(req).await.expect("router response");
        assert_eq!(resp.status(), StatusCode::OK);
        resp.headers()
            .get(header::ACCESS_CONTROL_ALLOW_METHODS)
            .and_then(|v| v.to_str().ok())
            .unwrap_or_default()
            .to_ascii_uppercase()
    }

    #[tokio::test]
    async fn cors_preflight_for_post_does_not_allow_post() {
        let allow_methods =
            preflight_allow_methods("/v1/formations", "http://attacker.example", "POST").await;
        assert!(
            !allow_methods.contains("POST"),
            "POST must not appear in Access-Control-Allow-Methods (got {allow_methods:?})",
        );
        assert!(
            allow_methods.contains("GET"),
            "GET must appear in Access-Control-Allow-Methods (got {allow_methods:?})",
        );
    }

    #[tokio::test]
    async fn cors_preflight_for_get_is_allowed() {
        let allow_methods =
            preflight_allow_methods("/v1/formations", "http://localhost:9999", "GET").await;
        assert!(
            allow_methods.contains("GET"),
            "GET must be in Access-Control-Allow-Methods (got {allow_methods:?})",
        );
    }

    /// `DELETE /v1/formations/:id` and `DELETE /v1/formations/by-name/:name`
    /// mutate state, so they must be kept out of the advertised methods just
    /// like `POST`.
    #[tokio::test]
    async fn cors_preflight_for_delete_does_not_allow_delete() {
        let allow_methods = preflight_allow_methods(
            "/v1/formations/some-id",
            "http://attacker.example",
            "DELETE",
        )
        .await;
        assert!(
            !allow_methods.contains("DELETE"),
            "DELETE must not appear in Access-Control-Allow-Methods (got {allow_methods:?})",
        );
        // Guards against a vacuous pass when the header is missing entirely.
        assert!(
            allow_methods.contains("GET"),
            "GET must appear in Access-Control-Allow-Methods (got {allow_methods:?})",
        );
    }
}