1use axum::{
8 body::Body,
9 extract::Request,
10 http::{header, StatusCode},
11 middleware::Next,
12 response::{IntoResponse, Response},
13 Json,
14};
15use std::sync::Arc;
16
/// Compares two byte slices without stopping at the first differing byte.
///
/// The lengths are compared up front and a mismatch returns `false`
/// immediately, so the length of the expected value is not hidden. When the
/// lengths agree, every byte pair is visited regardless of where (or whether)
/// a difference occurs. Returns `true` exactly when `a == b`.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() == b.len() {
        // XOR each pair and OR the results together: the accumulator stays zero
        // only if every pair matched, and at the source level no branch depends
        // on the byte values.
        let diff = a
            .iter()
            .zip(b)
            .fold(0u8, |bits, (lhs, rhs)| bits | (lhs ^ rhs));
        diff == 0
    } else {
        false
    }
}
39
40fn any_key_matches(keys: &[String], token: &str) -> bool {
45 let token_bytes = token.as_bytes();
46 let mut matched = false;
47 for key in keys {
48 if constant_time_eq(key.as_bytes(), token_bytes) {
49 matched = true;
50 }
51 }
53 matched
54}
55
/// Shared authentication configuration consumed by [`auth_middleware`].
///
/// `Debug` is implemented by hand rather than derived so that the configured
/// API keys can never leak into logs or panic messages; only the number of
/// keys is printed.
#[derive(Clone)]
pub struct AuthState {
    /// Accepted API keys. An empty list disables authentication entirely.
    pub api_keys: Arc<Vec<String>>,
}

impl std::fmt::Debug for AuthState {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never print the key material itself, only how many keys exist.
        f.debug_struct("AuthState")
            .field(
                "api_keys",
                &format_args!("<{} redacted>", self.api_keys.len()),
            )
            .finish()
    }
}

impl AuthState {
    /// Creates the state from the configured keys, wrapping them in an [`Arc`]
    /// so that clones of the state share one allocation.
    ///
    /// Passing an empty vector disables authentication
    /// (see [`AuthState::auth_enabled`]).
    pub fn new(api_keys: Vec<String>) -> Self {
        Self {
            api_keys: Arc::new(api_keys),
        }
    }

    /// Returns `true` when at least one API key is configured, i.e. when
    /// requests to non-public paths must present a valid bearer token.
    pub fn auth_enabled(&self) -> bool {
        !self.api_keys.is_empty()
    }
}
76
/// Reports whether `path` is served without authentication.
///
/// Only exact matches count: `/health`, `/ready`, and their `/v1`-prefixed
/// variants. Sub-paths such as `/health/extra` and paths with a trailing slash
/// still require a key.
fn is_public_path(path: &str) -> bool {
    const PUBLIC_PATHS: [&str; 4] = ["/health", "/ready", "/v1/health", "/v1/ready"];
    PUBLIC_PATHS.contains(&path)
}
97
/// Extracts the credential from an `Authorization: Bearer <token>` value.
///
/// The scheme name is matched case-insensitively, and whitespace around the
/// whole value and around the token is ignored. Returns `None` when the scheme
/// is not `Bearer`, when nothing follows it, or when the value is too short to
/// hold a token. Never panics, including on non-ASCII input.
fn extract_bearer_token(header_value: &str) -> Option<&str> {
    const PREFIX: &str = "bearer ";

    let trimmed = header_value.trim();
    // `get` instead of `[..7]`: indexing panics when byte 7 falls inside a
    // multi-byte character (e.g. "Beareré"), while `get` yields `None`.
    let scheme = trimmed.get(..PREFIX.len())?;
    if !scheme.eq_ignore_ascii_case(PREFIX) {
        return None;
    }
    // The prefix that just matched is pure ASCII, so this offset is a char boundary.
    let token = trimmed[PREFIX.len()..].trim();
    Some(token).filter(|t| !t.is_empty())
}
112
113pub async fn auth_middleware(
117 axum::extract::State(state): axum::extract::State<AuthState>,
118 request: Request<Body>,
119 next: Next,
120) -> Response {
121 if !state.auth_enabled() {
123 return next.run(request).await;
124 }
125
126 if is_public_path(request.uri().path()) {
128 return next.run(request).await;
129 }
130
131 let auth_header = request
133 .headers()
134 .get(header::AUTHORIZATION)
135 .and_then(|v| v.to_str().ok());
136
137 match auth_header {
138 Some(value) => match extract_bearer_token(value) {
139 Some(token) if any_key_matches(&state.api_keys, token) => next.run(request).await,
140 Some(_) => unauthorized_response("invalid API key"),
141 None => {
142 unauthorized_response("invalid Authorization header format, expected: Bearer <key>")
143 }
144 },
145 None => unauthorized_response("missing Authorization header"),
146 }
147}
148
149fn unauthorized_response(message: &str) -> Response {
151 (
152 StatusCode::UNAUTHORIZED,
153 Json(serde_json::json!({
154 "error": "Unauthorized",
155 "message": message
156 })),
157 )
158 .into_response()
159}
160
#[cfg(test)]
mod tests {
    use super::*;

    // --- AuthState ---------------------------------------------------------

    #[test]
    fn test_auth_state_disabled_when_empty() {
        assert!(!AuthState::new(Vec::new()).auth_enabled());
    }

    #[test]
    fn test_auth_state_enabled_with_keys() {
        assert!(AuthState::new(vec![String::from("key1")]).auth_enabled());
    }

    // --- is_public_path ----------------------------------------------------

    #[test]
    fn test_is_public_path_health() {
        assert!(is_public_path("/health"), "/health must stay public");
    }

    #[test]
    fn test_is_public_path_ready() {
        assert!(is_public_path("/ready"), "/ready must stay public");
    }

    #[test]
    fn test_is_public_path_metrics_is_protected() {
        for path in ["/metrics", "/v1/metrics"] {
            assert!(!is_public_path(path), "{} must require auth", path);
        }
    }

    #[test]
    fn test_is_public_path_versioned_health() {
        assert!(is_public_path("/v1/health"), "/v1/health must stay public");
    }

    #[test]
    fn test_is_public_path_versioned_ready() {
        assert!(is_public_path("/v1/ready"), "/v1/ready must stay public");
    }

    #[test]
    fn test_is_public_path_other() {
        let protected = ["/collections", "/query", "/health/extra", "/v1/collections"];
        for path in protected {
            assert!(!is_public_path(path), "{} must require auth", path);
        }
    }

    // --- extract_bearer_token ----------------------------------------------

    #[test]
    fn test_extract_bearer_token_valid() {
        let headers = ["Bearer my-key", "bearer my-key", "BEARER my-key", " Bearer my-key "];
        for raw in headers {
            assert_eq!(extract_bearer_token(raw), Some("my-key"), "header {:?}", raw);
        }
    }

    #[test]
    fn test_extract_bearer_token_invalid() {
        for raw in ["Basic abc123", "my-key", "Bearer", ""] {
            assert_eq!(extract_bearer_token(raw), None, "header {:?}", raw);
        }
    }

    #[test]
    fn test_extract_bearer_token_whitespace_only() {
        assert!(extract_bearer_token("Bearer ").is_none());
    }

    // --- constant_time_eq --------------------------------------------------

    #[test]
    fn test_constant_time_eq_identical() {
        let (left, right): (&[u8], &[u8]) = (b"secret-key-42", b"secret-key-42");
        assert!(constant_time_eq(left, right));
    }

    #[test]
    fn test_constant_time_eq_different_content() {
        let (left, right): (&[u8], &[u8]) = (b"secret-key-42", b"secret-key-43");
        assert!(!constant_time_eq(left, right));
    }

    #[test]
    fn test_constant_time_eq_different_length() {
        let (left, right): (&[u8], &[u8]) = (b"short", b"longer-key");
        assert!(!constant_time_eq(left, right));
    }

    #[test]
    fn test_constant_time_eq_empty() {
        let empty: &[u8] = b"";
        assert!(constant_time_eq(empty, empty));
    }

    // --- any_key_matches ---------------------------------------------------

    /// Fixture shared by the lookup tests: two distinct configured keys.
    fn two_keys() -> Vec<String> {
        ["key-a", "key-b"].iter().map(|k| k.to_string()).collect()
    }

    #[test]
    fn test_any_key_matches_found() {
        assert!(any_key_matches(&two_keys(), "key-b"));
    }

    #[test]
    fn test_any_key_matches_not_found() {
        assert!(!any_key_matches(&two_keys(), "key-c"));
    }

    #[test]
    fn test_any_key_matches_empty_keys() {
        let none: Vec<String> = Vec::new();
        assert!(!any_key_matches(&none, "anything"));
    }
}
279}