1use std::sync::Arc;
2
3use axum::extract::{Request, State};
4use axum::http::StatusCode;
5use axum::middleware::Next;
6use axum::response::Response;
7use subtle::ConstantTimeEq;
8
9use crate::state::ApiState;
10
11fn api_key_matches(provided: Option<&str>, expected: &str) -> bool {
12 provided.is_some_and(|key| bool::from(key.as_bytes().ct_eq(expected.as_bytes())))
13}
14
15pub async fn require_api_key(
16 State(state): State<Arc<ApiState>>,
17 req: Request,
18 next: Next,
19) -> Result<Response, StatusCode> {
20 let Some(ref expected) = state.api_key else {
21 return Ok(next.run(req).await);
22 };
23
24 let provided = req
25 .headers()
26 .get("authorization")
27 .and_then(|v| v.to_str().ok())
28 .and_then(|v| v.strip_prefix("Bearer "))
29 .or_else(|| req.headers().get("x-api-key").and_then(|v| v.to_str().ok()));
30
31 if api_key_matches(provided, expected) {
32 Ok(next.run(req).await)
33 } else {
34 Err(StatusCode::UNAUTHORIZED)
35 }
36}
37
#[cfg(test)]
mod tests {
    use super::api_key_matches;

    #[test]
    fn matches_expected_key() {
        assert!(api_key_matches(Some("secret-key"), "secret-key"));
    }

    #[test]
    fn rejects_wrong_key() {
        assert!(!api_key_matches(Some("wrong-key"), "secret-key"));
    }

    #[test]
    fn rejects_missing_key() {
        assert!(!api_key_matches(None, "secret-key"));
    }

    #[test]
    fn rejects_empty_key() {
        assert!(!api_key_matches(Some(""), "secret-key"));
    }

    /// Keys whose length differs from the expected key take a different path
    /// in the comparison; make sure they are still rejected.
    #[test]
    fn rejects_prefix_and_extension_of_key() {
        assert!(!api_key_matches(Some("secret"), "secret-key"));
        assert!(!api_key_matches(Some("secret-key-extra"), "secret-key"));
    }

    #[test]
    fn comparison_is_case_sensitive() {
        assert!(!api_key_matches(Some("SECRET-KEY"), "secret-key"));
    }
}