use axum::{
extract::{Query, Request, State},
http::StatusCode,
middleware::Next,
response::Response,
};
use crate::{api::ApiKeyQuery, AppState};
/// Rejects requests whose API key does not match the configured one.
///
/// When `state.api_key` is `None`, authentication is disabled and every request is
/// forwarded to `next` unchanged. When a key is configured, the request must carry an
/// identical key in its query string (deserialized into `ApiKeyQuery::api_key`);
/// otherwise it is rejected without reaching the inner handler.
///
/// # Errors
///
/// Returns [`StatusCode::UNAUTHORIZED`] when a key is configured and the request either
/// omits it or supplies a different one.
pub async fn auth_middleware(
    State(state): State<AppState>,
    Query(query): Query<ApiKeyQuery>,
    request: Request,
    next: Next,
) -> Result<Response, StatusCode> {
    // No key configured: authentication is disabled for every request.
    let Some(expected_key) = state.api_key.as_deref() else {
        return Ok(next.run(request).await);
    };

    // Compare without early exit so timing does not leak how much of a guess matched.
    let authorized = match query.api_key.as_deref() {
        Some(provided_key) => constant_time_eq(provided_key.as_bytes(), expected_key.as_bytes()),
        None => false,
    };
    if !authorized {
        return Err(StatusCode::UNAUTHORIZED);
    }
    Ok(next.run(request).await)
}

/// Compares two byte strings without short-circuiting on the first mismatching byte.
///
/// Used for secret comparison so that response timing does not reveal how many leading
/// bytes of a guessed key were correct. The length check still returns early, so the
/// *length* of the expected key is not hidden. This is best-effort: the compiler gives no
/// hard constant-time guarantee (a dedicated crate such as `subtle` is stronger).
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    // OR together the XOR of every byte pair; any difference leaves a non-zero bit.
    a.iter().zip(b).fold(0u8, |diff, (x, y)| diff | (x ^ y)) == 0
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode},
        middleware,
        routing::get,
        Router,
    };
    use std::sync::{Arc, RwLock};
    use tower::ServiceExt;
    use crate::beacon::BeaconController;

    /// Trivial handler; a `200 OK` response means the middleware let the request through.
    async fn dummy_handler() -> &'static str {
        "OK"
    }

    /// Builds a router serving `/test` behind `auth_middleware`, configured with
    /// `api_key` as the expected key (`None` disables authentication).
    fn app_with_key(api_key: Option<&str>) -> Router {
        let state = AppState {
            beacon: Arc::new(RwLock::new(BeaconController::mock())),
            api_key: api_key.map(String::from),
        };
        Router::new()
            .route("/test", get(dummy_handler))
            .layer(middleware::from_fn_with_state(
                state.clone(),
                auth_middleware,
            ))
            .with_state(state)
    }

    /// Sends one empty-body request to `uri` through `app` and returns the response status.
    async fn status_for(app: Router, uri: &str) -> StatusCode {
        app.oneshot(Request::builder().uri(uri).body(Body::empty()).unwrap())
            .await
            .unwrap()
            .status()
    }

    #[tokio::test]
    async fn test_auth_middleware_without_api_key() {
        let status = status_for(app_with_key(None), "/test").await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_auth_middleware_without_api_key_ignores_provided_key() {
        let status = status_for(app_with_key(None), "/test?apiKey=anything").await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_auth_middleware_with_correct_api_key() {
        let api_key = "test-key-123";
        let uri = format!("/test?apiKey={api_key}");
        let status = status_for(app_with_key(Some(api_key)), &uri).await;
        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_auth_middleware_with_incorrect_api_key() {
        let status = status_for(app_with_key(Some("correct-key")), "/test?apiKey=wrong-key").await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_auth_middleware_missing_api_key() {
        let status = status_for(app_with_key(Some("required-key")), "/test").await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_auth_middleware_rejects_prefix_of_key() {
        let status = status_for(app_with_key(Some("correct-key")), "/test?apiKey=correct").await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_auth_middleware_rejects_key_with_extra_suffix() {
        let status = status_for(
            app_with_key(Some("correct-key")),
            "/test?apiKey=correct-key-extra",
        )
        .await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_auth_middleware_rejects_empty_key_value() {
        let status = status_for(app_with_key(Some("required-key")), "/test?apiKey=").await;
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }
}