use axum::{
Json,
extract::{Request, State},
http::StatusCode,
middleware::Next,
response::{IntoResponse, Response},
};
use serde_json::json;
/// Axum middleware that rejects requests lacking a valid `X-Api-Key` header.
///
/// When no key is configured (`State(None)`), every request is passed through
/// unchanged. Otherwise the first `X-Api-Key` header value is compared with the
/// configured key byte-for-byte via [`constant_time_eq`]:
///
/// * matching key — the request continues to `next`;
/// * header present but not matching — `401` with "Invalid API key";
/// * header absent — `401` with "Missing X-Api-Key header".
///
/// The header is compared as raw bytes rather than through
/// `HeaderValue::to_str`, which fails on anything but visible ASCII. Going
/// through `to_str` made a present-but-non-ASCII header look "missing", and
/// made a configured key containing non-ASCII characters impossible to match.
pub async fn require_api_key(
    State(expected_api_key): State<Option<String>>,
    request: Request,
    next: Next,
) -> Response {
    let Some(expected_key) = expected_api_key else {
        // Authentication is disabled when no key is configured.
        return next.run(request).await;
    };

    // Header-name lookup in `http` is case-insensitive.
    let provided_key = request
        .headers()
        .get("x-api-key")
        .map(|value| value.as_bytes());

    match provided_key {
        Some(provided) if constant_time_eq(provided, expected_key.as_bytes()) => {
            next.run(request).await
        }
        Some(_) => unauthorized_response("Invalid API key"),
        None => unauthorized_response("Missing X-Api-Key header"),
    }
}
/// Compares two byte slices for equality without stopping at the first
/// mismatching byte.
///
/// Every byte pair is XOR-ed and the differences are OR-ed into one
/// accumulator. The intent is that, for equal-length inputs, running time does
/// not reveal where the slices first differ (the compiler does not formally
/// guarantee this). Slices of different length are rejected immediately, so
/// timing can still reveal a length mismatch.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() == b.len() {
        let diff = a
            .iter()
            .zip(b)
            .fold(0u8, |acc, (&lhs, &rhs)| acc | (lhs ^ rhs));
        diff == 0
    } else {
        false
    }
}
/// Builds a `401 Unauthorized` response whose JSON body is
/// `{"error": {"code": "unauthorized", "message": <message>}}`.
fn unauthorized_response(message: &str) -> Response {
    let payload = json!({
        "error": {
            "code": "unauthorized",
            "message": message
        }
    });
    (StatusCode::UNAUTHORIZED, Json(payload)).into_response()
}
// Panicking via `unwrap`/`expect` is the intended failure mode inside tests.
#[allow(clippy::unwrap_used, clippy::expect_used)]
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        Router,
        body::Body,
        http::{Request, StatusCode},
        middleware,
        routing::get,
    };
    use tower::ServiceExt;

    async fn test_handler() -> impl IntoResponse {
        (StatusCode::OK, "Success")
    }

    /// Router with a single `GET /test` route guarded by `require_api_key`
    /// configured with `api_key` (`None` disables the check).
    fn app(api_key: Option<&str>) -> Router {
        Router::new()
            .route("/test", get(test_handler))
            .layer(middleware::from_fn_with_state(
                api_key.map(str::to_owned),
                require_api_key,
            ))
    }

    /// `GET /test` request carrying `X-Api-Key: <api_key>` when `api_key` is `Some`.
    fn build_request(api_key: Option<&str>) -> Request<Body> {
        let builder = Request::builder().uri("/test");
        let builder = match api_key {
            Some(key) => builder.header("X-Api-Key", key),
            None => builder,
        };
        builder.body(Body::empty()).unwrap()
    }

    /// Collects a response body into a UTF-8 `String`.
    async fn body_text(response: axum::response::Response) -> String {
        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        String::from_utf8(body.to_vec()).unwrap()
    }

    #[test]
    fn test_constant_time_eq() {
        assert!(constant_time_eq(b"secret", b"secret"));
        assert!(constant_time_eq(b"", b""));
        assert!(!constant_time_eq(b"secret", b"secreT"));
        assert!(!constant_time_eq(b"secret", b"secret!"));
        assert!(!constant_time_eq(b"", b"x"));
    }

    #[tokio::test]
    async fn test_no_api_key_configured() {
        let response = app(None).oneshot(build_request(None)).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_valid_api_key() {
        let response = app(Some("test-secret-key"))
            .oneshot(build_request(Some("test-secret-key")))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_invalid_api_key() {
        let response = app(Some("correct-key"))
            .oneshot(build_request(Some("wrong-key")))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
        let body = body_text(response).await;
        assert!(body.contains("Invalid API key"));
    }

    #[tokio::test]
    async fn test_missing_api_key() {
        let response = app(Some("required-key"))
            .oneshot(build_request(None))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
        let body = body_text(response).await;
        assert!(body.contains("Missing X-Api-Key header"));
    }

    #[tokio::test]
    async fn test_api_key_case_sensitive() {
        let response = app(Some("CaseSensitiveKey"))
            .oneshot(build_request(Some("casesensitivekey")))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_header_name_case_insensitive() {
        // Lowercase header name on purpose; `build_request` sends `X-Api-Key`.
        let request = Request::builder()
            .uri("/test")
            .header("x-api-key", "test-key")
            .body(Body::empty())
            .unwrap();
        let response = app(Some("test-key")).oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_whitespace_in_api_key() {
        // A trailing space in the configured key is significant: nothing is trimmed.
        let response = app(Some("key-with-space "))
            .oneshot(build_request(Some("key-with-space")))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }
}