1use crate::api_keys::{ApiKeyManager, KeyRole, ValidationResult};
16use axum::http::{HeaderMap, StatusCode};
17use serde::{Deserialize, Serialize};
18
19#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
23#[serde(rename_all = "lowercase")]
24pub enum AccessLevel {
25 Public,
27 ReadOnly,
29 Write,
31 Admin,
33}
34
35impl AccessLevel {
36 pub fn as_str(&self) -> &'static str {
37 match self {
38 AccessLevel::Public => "public",
39 AccessLevel::ReadOnly => "readonly",
40 AccessLevel::Write => "write",
41 AccessLevel::Admin => "admin",
42 }
43 }
44}
45
46#[derive(Debug, Clone)]
50pub struct AuthResult {
51 pub authorized: bool,
53 pub key_name: Option<String>,
55 pub role: Option<KeyRole>,
57 pub rate_limit: Option<u32>,
59}
60
61impl AuthResult {
62 fn allowed(v: &ValidationResult) -> Self {
63 AuthResult {
64 authorized: true,
65 key_name: v.key_name.clone(),
66 role: v.role,
67 rate_limit: v.rate_limit,
68 }
69 }
70
71 fn public() -> Self {
72 AuthResult {
73 authorized: true,
74 key_name: None,
75 role: None,
76 rate_limit: None,
77 }
78 }
79}
80
81fn extract_bearer(headers: &HeaderMap) -> Option<&str> {
85 headers
86 .get("authorization")
87 .and_then(|v| v.to_str().ok())
88 .and_then(|v| v.strip_prefix("Bearer "))
89}
90
91pub fn check(
103 api_keys: &mut ApiKeyManager,
104 headers: &HeaderMap,
105 level: AccessLevel,
106) -> Result<AuthResult, StatusCode> {
107 if level == AccessLevel::Public {
109 return Ok(AuthResult::public());
110 }
111
112 if !api_keys.is_enabled() {
114 return Ok(AuthResult::public());
115 }
116
117 let token = match extract_bearer(headers) {
119 Some(t) => t,
120 None => return Err(StatusCode::UNAUTHORIZED),
121 };
122
123 let validation = api_keys.validate(token);
125 if !validation.valid {
126 return Err(StatusCode::FORBIDDEN);
127 }
128
129 let role = validation.role.unwrap_or(KeyRole::ReadOnly);
131 match level {
132 AccessLevel::Public => Ok(AuthResult::allowed(&validation)),
133 AccessLevel::ReadOnly => {
134 Ok(AuthResult::allowed(&validation))
136 }
137 AccessLevel::Write => {
138 if role.can_write() {
139 Ok(AuthResult::allowed(&validation))
140 } else {
141 Err(StatusCode::FORBIDDEN)
142 }
143 }
144 AccessLevel::Admin => {
145 if role.can_manage_keys() {
146 Ok(AuthResult::allowed(&validation))
147 } else {
148 Err(StatusCode::FORBIDDEN)
149 }
150 }
151 }
152}
153
154pub fn peek(
156 api_keys: &ApiKeyManager,
157 headers: &HeaderMap,
158 level: AccessLevel,
159) -> Result<AuthResult, StatusCode> {
160 if level == AccessLevel::Public {
161 return Ok(AuthResult::public());
162 }
163
164 if !api_keys.is_enabled() {
165 return Ok(AuthResult::public());
166 }
167
168 let token = match extract_bearer(headers) {
169 Some(t) => t,
170 None => return Err(StatusCode::UNAUTHORIZED),
171 };
172
173 let validation = api_keys.peek(token);
174 if !validation.valid {
175 return Err(StatusCode::FORBIDDEN);
176 }
177
178 let role = validation.role.unwrap_or(KeyRole::ReadOnly);
179 match level {
180 AccessLevel::Public | AccessLevel::ReadOnly => Ok(AuthResult::allowed(&validation)),
181 AccessLevel::Write => {
182 if role.can_write() {
183 Ok(AuthResult::allowed(&validation))
184 } else {
185 Err(StatusCode::FORBIDDEN)
186 }
187 }
188 AccessLevel::Admin => {
189 if role.can_manage_keys() {
190 Ok(AuthResult::allowed(&validation))
191 } else {
192 Err(StatusCode::FORBIDDEN)
193 }
194 }
195 }
196}
197
198pub fn classify_endpoint(method: &str, path: &str) -> AccessLevel {
200 if path.starts_with("/v1/health") || path == "/v1/version" || path == "/v1/rate-limit" {
202 return AccessLevel::Public;
203 }
204
205 if path.starts_with("/v1/keys") && method != "GET" {
207 return AccessLevel::Admin;
208 }
209
210 match method {
212 "POST" | "PUT" | "DELETE" | "PATCH" => AccessLevel::Write,
213 _ => AccessLevel::ReadOnly,
214 }
215}
216
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api_keys::ApiKeyManager;

    /// Builds a header map that optionally carries `Authorization: Bearer <token>`.
    fn make_headers(token: Option<&str>) -> HeaderMap {
        let mut h = HeaderMap::new();
        if let Some(t) = token {
            h.insert("authorization", format!("Bearer {t}").parse().unwrap());
        }
        h
    }

    #[test]
    fn public_always_passes() {
        let mut mgr = ApiKeyManager::new(Some("master"));
        let h = make_headers(None);
        let result = check(&mut mgr, &h, AccessLevel::Public).unwrap();
        assert!(result.authorized);
        assert!(result.key_name.is_none());
    }

    #[test]
    fn public_level_ignores_credentials() {
        let mut mgr = ApiKeyManager::new(Some("master"));
        let h = make_headers(Some("wrong_token"));

        let r = check(&mut mgr, &h, AccessLevel::Public).unwrap();
        assert!(r.authorized);
        assert!(r.key_name.is_none());
        assert!(peek(&mgr, &h, AccessLevel::Public).is_ok());
    }

    #[test]
    fn disabled_manager_allows_all() {
        let mut mgr = ApiKeyManager::new(None);
        let h = make_headers(None);

        assert!(check(&mut mgr, &h, AccessLevel::ReadOnly).is_ok());
        assert!(check(&mut mgr, &h, AccessLevel::Write).is_ok());
        assert!(check(&mut mgr, &h, AccessLevel::Admin).is_ok());
    }

    #[test]
    fn missing_token_returns_unauthorized() {
        let mut mgr = ApiKeyManager::new(Some("master"));
        let h = make_headers(None);

        assert_eq!(check(&mut mgr, &h, AccessLevel::ReadOnly).unwrap_err(), StatusCode::UNAUTHORIZED);
        assert_eq!(check(&mut mgr, &h, AccessLevel::Write).unwrap_err(), StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn non_bearer_scheme_is_unauthorized() {
        let mut mgr = ApiKeyManager::new(Some("master"));
        let mut h = HeaderMap::new();
        h.insert("authorization", "Basic master".parse().unwrap());

        assert_eq!(check(&mut mgr, &h, AccessLevel::ReadOnly).unwrap_err(), StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn invalid_token_returns_forbidden() {
        let mut mgr = ApiKeyManager::new(Some("master"));
        let h = make_headers(Some("wrong_token"));

        assert_eq!(check(&mut mgr, &h, AccessLevel::ReadOnly).unwrap_err(), StatusCode::FORBIDDEN);
    }

    #[test]
    fn admin_key_has_full_access() {
        let mut mgr = ApiKeyManager::new(Some("admin_tok"));
        let h = make_headers(Some("admin_tok"));

        let r = check(&mut mgr, &h, AccessLevel::ReadOnly).unwrap();
        assert!(r.authorized);
        assert_eq!(r.role, Some(KeyRole::Admin));

        let r = check(&mut mgr, &h, AccessLevel::Write).unwrap();
        assert!(r.authorized);

        let r = check(&mut mgr, &h, AccessLevel::Admin).unwrap();
        assert!(r.authorized);
    }

    #[test]
    fn operator_can_write_not_admin() {
        let mut mgr = ApiKeyManager::new(Some("master"));
        mgr.create_key("op", "op_tok", KeyRole::Operator, None);
        let h = make_headers(Some("op_tok"));

        assert!(check(&mut mgr, &h, AccessLevel::ReadOnly).is_ok());
        assert!(check(&mut mgr, &h, AccessLevel::Write).is_ok());
        assert_eq!(check(&mut mgr, &h, AccessLevel::Admin).unwrap_err(), StatusCode::FORBIDDEN);
    }

    #[test]
    fn readonly_can_only_read() {
        let mut mgr = ApiKeyManager::new(Some("master"));
        mgr.create_key("viewer", "view_tok", KeyRole::ReadOnly, None);
        let h = make_headers(Some("view_tok"));

        assert!(check(&mut mgr, &h, AccessLevel::ReadOnly).is_ok());
        assert_eq!(check(&mut mgr, &h, AccessLevel::Write).unwrap_err(), StatusCode::FORBIDDEN);
        assert_eq!(check(&mut mgr, &h, AccessLevel::Admin).unwrap_err(), StatusCode::FORBIDDEN);
    }

    #[test]
    fn check_records_usage() {
        let mut mgr = ApiKeyManager::new(Some("master"));
        mgr.create_key("svc", "svc_tok", KeyRole::Operator, None);
        let h = make_headers(Some("svc_tok"));

        check(&mut mgr, &h, AccessLevel::ReadOnly).unwrap();
        check(&mut mgr, &h, AccessLevel::ReadOnly).unwrap();
        check(&mut mgr, &h, AccessLevel::Write).unwrap();

        let list = mgr.list();
        let key = list.iter().find(|k| k.name == "svc").unwrap();
        assert_eq!(key.request_count, 3);
        assert!(key.last_used.is_some());
    }

    #[test]
    fn peek_does_not_record_usage() {
        let mut mgr = ApiKeyManager::new(Some("master"));
        mgr.create_key("peeker", "peek_tok", KeyRole::ReadOnly, None);
        let h = make_headers(Some("peek_tok"));

        peek(&mgr, &h, AccessLevel::ReadOnly).unwrap();
        peek(&mgr, &h, AccessLevel::ReadOnly).unwrap();

        let list = mgr.list();
        let key = list.iter().find(|k| k.name == "peeker").unwrap();
        assert_eq!(key.request_count, 0);
        assert!(key.last_used.is_none());
    }

    #[test]
    fn peek_enforces_roles() {
        let mut mgr = ApiKeyManager::new(Some("master"));
        mgr.create_key("viewer", "view_tok", KeyRole::ReadOnly, None);
        mgr.create_key("op", "op_tok", KeyRole::Operator, None);
        let viewer = make_headers(Some("view_tok"));
        let op = make_headers(Some("op_tok"));

        assert!(peek(&mgr, &viewer, AccessLevel::ReadOnly).is_ok());
        assert_eq!(peek(&mgr, &viewer, AccessLevel::Write).unwrap_err(), StatusCode::FORBIDDEN);
        assert_eq!(peek(&mgr, &viewer, AccessLevel::Admin).unwrap_err(), StatusCode::FORBIDDEN);

        assert!(peek(&mgr, &op, AccessLevel::Write).is_ok());
        assert_eq!(peek(&mgr, &op, AccessLevel::Admin).unwrap_err(), StatusCode::FORBIDDEN);
    }

    #[test]
    fn peek_rejects_missing_and_invalid_tokens() {
        let mgr = ApiKeyManager::new(Some("master"));

        let missing = make_headers(None);
        assert_eq!(peek(&mgr, &missing, AccessLevel::ReadOnly).unwrap_err(), StatusCode::UNAUTHORIZED);

        let wrong = make_headers(Some("wrong_token"));
        assert_eq!(peek(&mgr, &wrong, AccessLevel::ReadOnly).unwrap_err(), StatusCode::FORBIDDEN);
    }

    #[test]
    fn auth_result_carries_rate_limit() {
        let mut mgr = ApiKeyManager::new(Some("master"));
        mgr.create_key("limited", "lim_tok", KeyRole::Operator, Some(50));
        let h = make_headers(Some("lim_tok"));

        let r = check(&mut mgr, &h, AccessLevel::Write).unwrap();
        assert_eq!(r.rate_limit, Some(50));
        assert_eq!(r.key_name, Some("limited".to_string()));
    }

    #[test]
    fn classify_public_endpoints() {
        assert_eq!(classify_endpoint("GET", "/v1/health"), AccessLevel::Public);
        assert_eq!(classify_endpoint("GET", "/v1/health/live"), AccessLevel::Public);
        assert_eq!(classify_endpoint("GET", "/v1/health/ready"), AccessLevel::Public);
        assert_eq!(classify_endpoint("GET", "/v1/version"), AccessLevel::Public);
        assert_eq!(classify_endpoint("GET", "/v1/rate-limit"), AccessLevel::Public);
    }

    #[test]
    fn classify_readonly_endpoints() {
        assert_eq!(classify_endpoint("GET", "/v1/metrics"), AccessLevel::ReadOnly);
        assert_eq!(classify_endpoint("GET", "/v1/daemons"), AccessLevel::ReadOnly);
        assert_eq!(classify_endpoint("GET", "/v1/logs"), AccessLevel::ReadOnly);
        assert_eq!(classify_endpoint("GET", "/v1/keys"), AccessLevel::ReadOnly);
        assert_eq!(classify_endpoint("GET", "/v1/session"), AccessLevel::ReadOnly);
    }

    #[test]
    fn classify_write_endpoints() {
        assert_eq!(classify_endpoint("POST", "/v1/deploy"), AccessLevel::Write);
        assert_eq!(classify_endpoint("POST", "/v1/estimate"), AccessLevel::Write);
        assert_eq!(classify_endpoint("POST", "/v1/events"), AccessLevel::Write);
        assert_eq!(classify_endpoint("DELETE", "/v1/daemons/x"), AccessLevel::Write);
    }

    #[test]
    fn classify_admin_endpoints() {
        assert_eq!(classify_endpoint("POST", "/v1/keys"), AccessLevel::Admin);
        assert_eq!(classify_endpoint("POST", "/v1/keys/revoke"), AccessLevel::Admin);
        assert_eq!(classify_endpoint("POST", "/v1/keys/rotate"), AccessLevel::Admin);
    }

    #[test]
    fn revoked_key_denied() {
        let mut mgr = ApiKeyManager::new(Some("master"));
        mgr.create_key("temp", "temp_tok", KeyRole::Operator, None);
        mgr.revoke("temp_tok");

        let h = make_headers(Some("temp_tok"));
        assert_eq!(check(&mut mgr, &h, AccessLevel::ReadOnly).unwrap_err(), StatusCode::FORBIDDEN);
    }
}