use axum::{extract::Request, http::StatusCode, middleware::Next, response::Response};
/// Axum middleware that enforces a static bearer token taken from `QUELCH_MCP_API_KEY`.
///
/// The variable is read on every request.
///
/// - If the variable is not set, every request is forwarded to `next` unchanged.
/// - If it is set, the request must carry `Authorization: Bearer <key>`, where `<key>` equals the
///   variable's value exactly. The scheme name is matched case-insensitively. Otherwise the
///   request is rejected and `next` is never called.
/// - A set-but-empty value counts as a configured (empty) key: only a `Bearer` header with an
///   empty token is accepted.
///
/// # Errors
///
/// Returns [`StatusCode::UNAUTHORIZED`] in two cases:
/// - a key is configured and the request's bearer token is missing, uses another scheme, or does
///   not match;
/// - the variable is set to a value that is not valid Unicode. The key cannot be compared, so the
///   middleware fails closed rather than treating the key as absent.
pub async fn api_key_middleware(req: Request, next: Next) -> Result<Response, StatusCode> {
    let expected = match std::env::var("QUELCH_MCP_API_KEY") {
        Ok(key) => key,
        // No key configured: authentication is disabled.
        Err(std::env::VarError::NotPresent) => return Ok(next.run(req).await),
        // Key configured but unreadable as UTF-8. Treating this as "unset" (as `.ok()` would)
        // would let every request through, so reject instead.
        Err(std::env::VarError::NotUnicode(_)) => return Err(StatusCode::UNAUTHORIZED),
    };

    // Compute a plain bool so the header borrow ends before `req` is moved into `next.run`.
    let authorized = req
        .headers()
        .get(axum::http::header::AUTHORIZATION)
        .and_then(|h| h.to_str().ok())
        .and_then(bearer_token)
        .is_some_and(|token| constant_time_eq(token.as_bytes(), expected.as_bytes()));

    if !authorized {
        return Err(StatusCode::UNAUTHORIZED);
    }
    Ok(next.run(req).await)
}

/// Returns the credentials part of an `Authorization` header value that uses the `Bearer` scheme.
///
/// The scheme name is compared case-insensitively. Everything after the first space is returned
/// verbatim, so `"Bearer "` yields `Some("")`. Any other scheme, or a value with no space at all,
/// yields `None`.
fn bearer_token(header: &str) -> Option<&str> {
    let (scheme, token) = header.split_once(' ')?;
    scheme.eq_ignore_ascii_case("Bearer").then_some(token)
}

/// Compares two byte strings without returning early at the first mismatching byte.
///
/// For inputs of equal length, every byte pair is folded into one accumulator, so the comparison
/// time does not reveal how long a correct key prefix an attacker has guessed. A length mismatch
/// still returns immediately, so the key's length can leak.
///
/// NOTE(review): this is best-effort, not a formally verified constant-time primitive. Prefer a
/// vetted crate such as `subtle` if the project already depends on one.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let diff = a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y));
    std::hint::black_box(diff) == 0
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::{Request, StatusCode};
    use axum::{Router, body::Body, middleware, routing::get};
    use std::sync::{Mutex, MutexGuard};
    use tower::ServiceExt;

    /// Environment variable read by `api_key_middleware`.
    const KEY_VAR: &str = "QUELCH_MCP_API_KEY";

    /// Serialises the tests in this module that mutate `KEY_VAR`. The process environment is
    /// shared by all test-harness threads.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Holds `ENV_LOCK` and overrides `KEY_VAR` for the guard's lifetime.
    ///
    /// On drop, the previous value is restored, or the variable is removed if it was previously
    /// unset. Because this happens in `Drop`, it also runs when a test panics, so a failed
    /// assertion cannot leak the key into later tests.
    struct EnvKeyGuard {
        prev: Option<std::ffi::OsString>,
        // Declared last: fields drop after `Drop::drop`, so the lock is released only once the
        // environment has been restored.
        _lock: MutexGuard<'static, ()>,
    }

    impl EnvKeyGuard {
        /// Acquires `ENV_LOCK`, snapshots `KEY_VAR`, then sets it to `value` (or removes it for
        /// `None`).
        fn set(value: Option<&str>) -> Self {
            // A test that panicked while holding the lock poisons it. The protected data is `()`
            // and the environment is restored in `Drop`, so the poison flag carries no
            // information and must not cascade into unrelated test failures.
            let lock = ENV_LOCK
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            // `var_os` keeps non-Unicode values that `var(..).ok()` would silently drop.
            let prev = std::env::var_os(KEY_VAR);
            // SAFETY: ENV_LOCK serialises every environment mutation in this module.
            // NOTE(review): this assumes no other test in the crate reads or writes the
            // environment concurrently; confirm.
            match value {
                Some(v) => unsafe { std::env::set_var(KEY_VAR, v) },
                None => unsafe { std::env::remove_var(KEY_VAR) },
            }
            Self { prev, _lock: lock }
        }
    }

    impl Drop for EnvKeyGuard {
        fn drop(&mut self) {
            // SAFETY: same invariant as in `EnvKeyGuard::set`; `_lock` is still held here.
            match self.prev.take() {
                Some(v) => unsafe { std::env::set_var(KEY_VAR, v) },
                None => unsafe { std::env::remove_var(KEY_VAR) },
            }
        }
    }

    /// A one-route router (`GET /ping` -> `"pong"`) wrapped in the API-key middleware.
    fn app_with_auth() -> Router {
        Router::new()
            .route("/ping", get(|| async { "pong" }))
            .layer(middleware::from_fn(api_key_middleware))
    }

    /// Sends `GET /ping` through `router`, with `key_header` as the raw `Authorization` value if
    /// given, and returns the response status.
    async fn response_status(router: Router, key_header: Option<&str>) -> StatusCode {
        let mut builder = Request::builder().method("GET").uri("/ping");
        if let Some(k) = key_header {
            builder = builder.header("Authorization", k);
        }
        let req = builder.body(Body::empty()).unwrap();
        router.oneshot(req).await.unwrap().status()
    }

    // Each test below holds `EnvKeyGuard` (a `std::sync::MutexGuard`) across `.await`, which is
    // what `clippy::await_holding_lock` warns about. It cannot deadlock here: `#[tokio::test]`
    // gives every test its own current-thread runtime, and nothing on that runtime contends for
    // `ENV_LOCK`. The lock must span the request because the middleware reads the environment
    // while the request is being handled.

    #[tokio::test]
    #[allow(clippy::await_holding_lock)]
    async fn api_key_middleware_no_auth_required_when_env_unset() {
        let _env = EnvKeyGuard::set(None);
        let status = response_status(app_with_auth(), None).await;
        assert_eq!(status, StatusCode::OK, "no env var → accept all");
    }

    #[tokio::test]
    #[allow(clippy::await_holding_lock)]
    async fn api_key_middleware_rejects_missing_header() {
        let _env = EnvKeyGuard::set(Some("secret123"));
        let status = response_status(app_with_auth(), None).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED, "missing header → 401");
    }

    #[tokio::test]
    #[allow(clippy::await_holding_lock)]
    async fn api_key_middleware_rejects_wrong_value() {
        let _env = EnvKeyGuard::set(Some("secret123"));
        let status = response_status(app_with_auth(), Some("Bearer wrong-key")).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED, "wrong key → 401");
    }

    #[tokio::test]
    #[allow(clippy::await_holding_lock)]
    async fn api_key_middleware_rejects_non_bearer_scheme() {
        let _env = EnvKeyGuard::set(Some("secret123"));
        let status = response_status(app_with_auth(), Some("Basic secret123")).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED, "non-Bearer scheme → 401");
    }

    #[tokio::test]
    #[allow(clippy::await_holding_lock)]
    async fn api_key_middleware_rejects_bare_key_without_scheme() {
        let _env = EnvKeyGuard::set(Some("secret123"));
        let status = response_status(app_with_auth(), Some("secret123")).await;
        assert_eq!(status, StatusCode::UNAUTHORIZED, "bare key without scheme → 401");
    }

    #[tokio::test]
    #[allow(clippy::await_holding_lock)]
    async fn api_key_middleware_passes_correct_value() {
        let _env = EnvKeyGuard::set(Some("secret123"));
        let status = response_status(app_with_auth(), Some("Bearer secret123")).await;
        assert_eq!(status, StatusCode::OK, "correct key → 200");
    }
}