1use axum::{
8 body::Body,
9 extract::Request,
10 http::{header, StatusCode},
11 middleware::Next,
12 response::{IntoResponse, Response},
13 Json,
14};
15use std::sync::Arc;
16
/// Compares two byte slices for equality without stopping at the first
/// differing byte.
///
/// Slices of different lengths are rejected immediately, so only the length
/// (not the position of a mismatch) influences how much work is done.
/// Returns `true` when both slices have the same length and identical bytes;
/// two empty slices compare equal.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }

    // OR together the XOR of every byte pair: any differing pair leaves at
    // least one bit set in the accumulated value.
    let diff = a
        .iter()
        .zip(b)
        .fold(0u8, |bits, (&lhs, &rhs)| bits | (lhs ^ rhs));
    diff == 0
}
39
40fn any_key_matches(keys: &[String], token: &str) -> bool {
45 let token_bytes = token.as_bytes();
46 let mut matched = false;
47 for key in keys {
48 if constant_time_eq(key.as_bytes(), token_bytes) {
49 matched = true;
50 }
51 }
53 matched
54}
55
/// Shared authentication configuration handed to the auth middleware.
///
/// Cloning is cheap: the key list sits behind an [`Arc`], so clones share the
/// same allocation.
#[derive(Clone)]
pub struct AuthState {
    /// Accepted API keys. An empty list means authentication is disabled.
    pub api_keys: Arc<Vec<String>>,
}

/// Hand-written so that formatting the state with `{:?}` (for example in a
/// log line) never prints the secret keys; only their count is shown.
impl std::fmt::Debug for AuthState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthState")
            .field(
                "api_keys",
                &format_args!("<{} redacted>", self.api_keys.len()),
            )
            .finish()
    }
}

impl AuthState {
    /// Creates a state that accepts exactly the given API keys.
    pub fn new(api_keys: Vec<String>) -> Self {
        Self {
            api_keys: Arc::new(api_keys),
        }
    }

    /// Returns `true` when at least one API key is configured.
    pub fn auth_enabled(&self) -> bool {
        !self.api_keys.is_empty()
    }
}
76
/// Returns `true` for the endpoints that are served without authentication.
///
/// Only exact matches count: `/health`, `/ready` and `/metrics`. Sub-paths
/// such as `/health/extra` are not public.
fn is_public_path(path: &str) -> bool {
    matches!(path, "/health" | "/ready" | "/metrics")
}
81
/// Extracts the token from an `Authorization` header value of the form
/// `Bearer <token>`.
///
/// The scheme is matched case-insensitively, and whitespace around both the
/// whole value and the token is trimmed. Returns `None` when the value does
/// not start with the `Bearer ` scheme or when the token is empty.
///
/// The prefix is taken with checked slicing, so input whose seventh byte is
/// not a UTF-8 character boundary yields `None` instead of panicking.
fn extract_bearer_token(header_value: &str) -> Option<&str> {
    const SCHEME: &str = "bearer ";

    let trimmed = header_value.trim();
    // `get` returns `None` both for too-short input and for a cut that would
    // land inside a multi-byte character.
    let scheme = trimmed.get(..SCHEME.len())?;
    if !scheme.eq_ignore_ascii_case(SCHEME) {
        return None;
    }

    // Safe to index: the successful `get` above proved this offset is a
    // character boundary within `trimmed`.
    let token = trimmed[SCHEME.len()..].trim();
    (!token.is_empty()).then_some(token)
}
96
97pub async fn auth_middleware(
101 axum::extract::State(state): axum::extract::State<AuthState>,
102 request: Request<Body>,
103 next: Next,
104) -> Response {
105 if !state.auth_enabled() {
107 return next.run(request).await;
108 }
109
110 if is_public_path(request.uri().path()) {
112 return next.run(request).await;
113 }
114
115 let auth_header = request
117 .headers()
118 .get(header::AUTHORIZATION)
119 .and_then(|v| v.to_str().ok());
120
121 match auth_header {
122 Some(value) => match extract_bearer_token(value) {
123 Some(token) if any_key_matches(&state.api_keys, token) => next.run(request).await,
124 Some(_) => unauthorized_response("invalid API key"),
125 None => {
126 unauthorized_response("invalid Authorization header format, expected: Bearer <key>")
127 }
128 },
129 None => unauthorized_response("missing Authorization header"),
130 }
131}
132
133fn unauthorized_response(message: &str) -> Response {
135 (
136 StatusCode::UNAUTHORIZED,
137 Json(serde_json::json!({
138 "error": "Unauthorized",
139 "message": message
140 })),
141 )
142 .into_response()
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns string literals into the owned key list the code under test takes.
    fn owned_keys(raw: &[&str]) -> Vec<String> {
        raw.iter().map(|key| (*key).to_owned()).collect()
    }

    // --- AuthState ---------------------------------------------------------

    #[test]
    fn test_auth_state_disabled_when_empty() {
        assert!(!AuthState::new(Vec::new()).auth_enabled());
    }

    #[test]
    fn test_auth_state_enabled_with_keys() {
        assert!(AuthState::new(owned_keys(&["key1"])).auth_enabled());
    }

    // --- is_public_path ----------------------------------------------------

    #[test]
    fn test_is_public_path_health() {
        let path = "/health";
        assert!(is_public_path(path));
    }

    #[test]
    fn test_is_public_path_ready() {
        let path = "/ready";
        assert!(is_public_path(path));
    }

    #[test]
    fn test_is_public_path_metrics() {
        let path = "/metrics";
        assert!(is_public_path(path));
    }

    #[test]
    fn test_is_public_path_other() {
        for path in ["/collections", "/query", "/health/extra"] {
            assert!(!is_public_path(path), "unexpectedly public: {path:?}");
        }
    }

    // --- extract_bearer_token ----------------------------------------------

    #[test]
    fn test_extract_bearer_token_valid() {
        let accepted = [
            "Bearer my-key",
            "bearer my-key",
            "BEARER my-key",
            " Bearer my-key ",
        ];
        for raw in accepted {
            assert_eq!(extract_bearer_token(raw), Some("my-key"), "input: {raw:?}");
        }
    }

    #[test]
    fn test_extract_bearer_token_invalid() {
        for raw in ["Basic abc123", "my-key", "Bearer", ""] {
            assert_eq!(extract_bearer_token(raw), None, "input: {raw:?}");
        }
    }

    #[test]
    fn test_extract_bearer_token_whitespace_only() {
        let parsed = extract_bearer_token("Bearer ");
        assert!(parsed.is_none());
    }

    // --- constant_time_eq --------------------------------------------------

    #[test]
    fn test_constant_time_eq_identical() {
        let key = b"secret-key-42";
        assert!(constant_time_eq(key, b"secret-key-42"));
    }

    #[test]
    fn test_constant_time_eq_different_content() {
        let (left, right) = (b"secret-key-42", b"secret-key-43");
        assert!(!constant_time_eq(left, right));
    }

    #[test]
    fn test_constant_time_eq_different_length() {
        let short: &[u8] = b"short";
        let long: &[u8] = b"longer-key";
        assert!(!constant_time_eq(short, long));
    }

    #[test]
    fn test_constant_time_eq_empty() {
        let nothing: &[u8] = b"";
        assert!(constant_time_eq(nothing, b""));
    }

    // --- any_key_matches ---------------------------------------------------

    #[test]
    fn test_any_key_matches_found() {
        let keys = owned_keys(&["key-a", "key-b"]);
        assert!(any_key_matches(&keys, "key-b"));
    }

    #[test]
    fn test_any_key_matches_not_found() {
        let keys = owned_keys(&["key-a", "key-b"]);
        assert!(!any_key_matches(&keys, "key-c"));
    }

    #[test]
    fn test_any_key_matches_empty_keys() {
        let keys = owned_keys(&[]);
        assert!(!any_key_matches(&keys, "anything"));
    }
}