Skip to main content

axon/
auth_middleware.rs

//! Auth Middleware — role-based authentication gate for AxonServer.
//!
//! Replaces the simple single-token `check_auth` with ApiKeyManager-backed
//! validation that enforces role-based access control on all protected endpoints.
//!
//! Endpoint access levels:
//!   - Public:   no auth required (health, version, rate-limit)
//!   - ReadOnly: any valid key (metrics, list daemons, logs, keys list, session reads)
//!   - Write:    Operator or Admin (deploy, estimate, events, supervisor control, session writes)
//!   - Admin:    Admin only (key management: create, revoke, rotate)
//!
//! When ApiKeyManager is disabled (no auth_token configured), all requests pass.
//! When enabled, the Bearer token is validated against the key registry and the role is checked.
14
15use crate::api_keys::{ApiKeyManager, KeyRole, ValidationResult};
16use axum::http::{HeaderMap, StatusCode};
17use serde::{Deserialize, Serialize};
18
19// ── Access levels ───────────────────────────────────────────────────────
20
21/// Required access level for an endpoint.
22#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
23#[serde(rename_all = "lowercase")]
24pub enum AccessLevel {
25    /// No authentication required.
26    Public,
27    /// Any valid key (Admin, Operator, or ReadOnly).
28    ReadOnly,
29    /// Operator or Admin — write operations.
30    Write,
31    /// Admin only — key management.
32    Admin,
33}
34
35impl AccessLevel {
36    pub fn as_str(&self) -> &'static str {
37        match self {
38            AccessLevel::Public => "public",
39            AccessLevel::ReadOnly => "readonly",
40            AccessLevel::Write => "write",
41            AccessLevel::Admin => "admin",
42        }
43    }
44}
45
46// ── Auth result ─────────────────────────────────────────────────────────
47
48/// Result of an authentication check.
49#[derive(Debug, Clone)]
50pub struct AuthResult {
51    /// Whether the request is authorized.
52    pub authorized: bool,
53    /// The key name that was used (if any).
54    pub key_name: Option<String>,
55    /// The role of the key (if any).
56    pub role: Option<KeyRole>,
57    /// Per-key rate limit override (if any).
58    pub rate_limit: Option<u32>,
59}
60
61impl AuthResult {
62    fn allowed(v: &ValidationResult) -> Self {
63        AuthResult {
64            authorized: true,
65            key_name: v.key_name.clone(),
66            role: v.role,
67            rate_limit: v.rate_limit,
68        }
69    }
70
71    fn public() -> Self {
72        AuthResult {
73            authorized: true,
74            key_name: None,
75            role: None,
76            rate_limit: None,
77        }
78    }
79}
80
81// ── Token extraction ────────────────────────────────────────────────────
82
83/// Extract Bearer token from Authorization header.
84fn extract_bearer(headers: &HeaderMap) -> Option<&str> {
85    headers
86        .get("authorization")
87        .and_then(|v| v.to_str().ok())
88        .and_then(|v| v.strip_prefix("Bearer "))
89}
90
91// ── Auth gate ───────────────────────────────────────────────────────────
92
93/// Check authentication and authorization for a request.
94///
95/// - If `api_keys` is disabled, all requests pass (backwards compat).
96/// - For `Public` endpoints, always passes.
97/// - For `ReadOnly` endpoints, any valid key is sufficient.
98/// - For `Write` endpoints, key must have `can_write()` (Operator or Admin).
99/// - For `Admin` endpoints, key must have `can_manage_keys()` (Admin only).
100///
101/// Returns `Ok(AuthResult)` if authorized, `Err(StatusCode)` if not.
102pub fn check(
103    api_keys: &mut ApiKeyManager,
104    headers: &HeaderMap,
105    level: AccessLevel,
106) -> Result<AuthResult, StatusCode> {
107    // Public endpoints always pass
108    if level == AccessLevel::Public {
109        return Ok(AuthResult::public());
110    }
111
112    // If key management is disabled, all requests pass (no auth configured)
113    if !api_keys.is_enabled() {
114        return Ok(AuthResult::public());
115    }
116
117    // Extract Bearer token
118    let token = match extract_bearer(headers) {
119        Some(t) => t,
120        None => return Err(StatusCode::UNAUTHORIZED),
121    };
122
123    // Validate token
124    let validation = api_keys.validate(token);
125    if !validation.valid {
126        return Err(StatusCode::FORBIDDEN);
127    }
128
129    // Check role permissions
130    let role = validation.role.unwrap_or(KeyRole::ReadOnly);
131    match level {
132        AccessLevel::Public => Ok(AuthResult::allowed(&validation)),
133        AccessLevel::ReadOnly => {
134            // Any valid key can read
135            Ok(AuthResult::allowed(&validation))
136        }
137        AccessLevel::Write => {
138            if role.can_write() {
139                Ok(AuthResult::allowed(&validation))
140            } else {
141                Err(StatusCode::FORBIDDEN)
142            }
143        }
144        AccessLevel::Admin => {
145            if role.can_manage_keys() {
146                Ok(AuthResult::allowed(&validation))
147            } else {
148                Err(StatusCode::FORBIDDEN)
149            }
150        }
151    }
152}
153
154/// Check authentication without recording usage (for peek/status endpoints).
155pub fn peek(
156    api_keys: &ApiKeyManager,
157    headers: &HeaderMap,
158    level: AccessLevel,
159) -> Result<AuthResult, StatusCode> {
160    if level == AccessLevel::Public {
161        return Ok(AuthResult::public());
162    }
163
164    if !api_keys.is_enabled() {
165        return Ok(AuthResult::public());
166    }
167
168    let token = match extract_bearer(headers) {
169        Some(t) => t,
170        None => return Err(StatusCode::UNAUTHORIZED),
171    };
172
173    let validation = api_keys.peek(token);
174    if !validation.valid {
175        return Err(StatusCode::FORBIDDEN);
176    }
177
178    let role = validation.role.unwrap_or(KeyRole::ReadOnly);
179    match level {
180        AccessLevel::Public | AccessLevel::ReadOnly => Ok(AuthResult::allowed(&validation)),
181        AccessLevel::Write => {
182            if role.can_write() {
183                Ok(AuthResult::allowed(&validation))
184            } else {
185                Err(StatusCode::FORBIDDEN)
186            }
187        }
188        AccessLevel::Admin => {
189            if role.can_manage_keys() {
190                Ok(AuthResult::allowed(&validation))
191            } else {
192                Err(StatusCode::FORBIDDEN)
193            }
194        }
195    }
196}
197
198/// Classify an endpoint path + method into an AccessLevel.
199pub fn classify_endpoint(method: &str, path: &str) -> AccessLevel {
200    // Public endpoints — no auth required
201    if path.starts_with("/v1/health") || path == "/v1/version" || path == "/v1/rate-limit" {
202        return AccessLevel::Public;
203    }
204
205    // Admin endpoints — key management writes
206    if path.starts_with("/v1/keys") && method != "GET" {
207        return AccessLevel::Admin;
208    }
209
210    // Write endpoints
211    match method {
212        "POST" | "PUT" | "DELETE" | "PATCH" => AccessLevel::Write,
213        _ => AccessLevel::ReadOnly,
214    }
215}
216
217// ── Tests ────────────────────────────────────────────────────────────────
218
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api_keys::ApiKeyManager;

    /// Every access level that requires credentials when auth is enabled.
    const PROTECTED: [AccessLevel; 3] = [
        AccessLevel::ReadOnly,
        AccessLevel::Write,
        AccessLevel::Admin,
    ];

    /// Build a header map, optionally carrying `Authorization: Bearer <token>`.
    fn make_headers(token: Option<&str>) -> HeaderMap {
        let mut headers = HeaderMap::new();
        if let Some(value) = token.map(|t| format!("Bearer {t}")) {
            headers.insert("authorization", value.parse().unwrap());
        }
        headers
    }

    /// Assert that every `(method, path)` pair classifies as `expected`.
    fn assert_classified(expected: AccessLevel, cases: &[(&str, &str)]) {
        for &(method, path) in cases {
            assert_eq!(classify_endpoint(method, path), expected);
        }
    }

    #[test]
    fn public_always_passes() {
        let mut keys = ApiKeyManager::new(Some("master"));
        let headers = make_headers(None);

        let outcome = check(&mut keys, &headers, AccessLevel::Public).unwrap();

        assert!(outcome.authorized);
        assert!(outcome.key_name.is_none());
    }

    #[test]
    fn disabled_manager_allows_all() {
        let mut keys = ApiKeyManager::new(None);
        let headers = make_headers(None);

        for level in PROTECTED {
            assert!(check(&mut keys, &headers, level).is_ok());
        }
    }

    #[test]
    fn missing_token_returns_unauthorized() {
        let mut keys = ApiKeyManager::new(Some("master"));
        let headers = make_headers(None);

        for level in [AccessLevel::ReadOnly, AccessLevel::Write] {
            let status = check(&mut keys, &headers, level).unwrap_err();
            assert_eq!(status, StatusCode::UNAUTHORIZED);
        }
    }

    #[test]
    fn invalid_token_returns_forbidden() {
        let mut keys = ApiKeyManager::new(Some("master"));
        let headers = make_headers(Some("wrong_token"));

        let status = check(&mut keys, &headers, AccessLevel::ReadOnly).unwrap_err();
        assert_eq!(status, StatusCode::FORBIDDEN);
    }

    #[test]
    fn admin_key_has_full_access() {
        let mut keys = ApiKeyManager::new(Some("admin_tok"));
        let headers = make_headers(Some("admin_tok"));

        let read = check(&mut keys, &headers, AccessLevel::ReadOnly).unwrap();
        assert!(read.authorized);
        assert_eq!(read.role, Some(KeyRole::Admin));

        let write = check(&mut keys, &headers, AccessLevel::Write).unwrap();
        assert!(write.authorized);

        let admin = check(&mut keys, &headers, AccessLevel::Admin).unwrap();
        assert!(admin.authorized);
    }

    #[test]
    fn operator_can_write_not_admin() {
        let mut keys = ApiKeyManager::new(Some("master"));
        keys.create_key("op", "op_tok", KeyRole::Operator, None);
        let headers = make_headers(Some("op_tok"));

        assert!(check(&mut keys, &headers, AccessLevel::ReadOnly).is_ok());
        assert!(check(&mut keys, &headers, AccessLevel::Write).is_ok());

        let status = check(&mut keys, &headers, AccessLevel::Admin).unwrap_err();
        assert_eq!(status, StatusCode::FORBIDDEN);
    }

    #[test]
    fn readonly_can_only_read() {
        let mut keys = ApiKeyManager::new(Some("master"));
        keys.create_key("viewer", "view_tok", KeyRole::ReadOnly, None);
        let headers = make_headers(Some("view_tok"));

        assert!(check(&mut keys, &headers, AccessLevel::ReadOnly).is_ok());

        for level in [AccessLevel::Write, AccessLevel::Admin] {
            let status = check(&mut keys, &headers, level).unwrap_err();
            assert_eq!(status, StatusCode::FORBIDDEN);
        }
    }

    #[test]
    fn check_records_usage() {
        let mut keys = ApiKeyManager::new(Some("master"));
        keys.create_key("svc", "svc_tok", KeyRole::Operator, None);
        let headers = make_headers(Some("svc_tok"));

        // Three successful checks must be counted as three uses.
        let levels = [
            AccessLevel::ReadOnly,
            AccessLevel::ReadOnly,
            AccessLevel::Write,
        ];
        for level in levels {
            check(&mut keys, &headers, level).unwrap();
        }

        let listed = keys.list();
        let entry = listed.iter().find(|k| k.name == "svc").unwrap();
        assert_eq!(entry.request_count, 3);
        assert!(entry.last_used.is_some());
    }

    #[test]
    fn peek_does_not_record_usage() {
        let mut keys = ApiKeyManager::new(Some("master"));
        keys.create_key("peeker", "peek_tok", KeyRole::ReadOnly, None);
        let headers = make_headers(Some("peek_tok"));

        for _ in 0..2 {
            peek(&keys, &headers, AccessLevel::ReadOnly).unwrap();
        }

        let listed = keys.list();
        let entry = listed.iter().find(|k| k.name == "peeker").unwrap();
        assert_eq!(entry.request_count, 0);
        assert!(entry.last_used.is_none());
    }

    #[test]
    fn auth_result_carries_rate_limit() {
        let mut keys = ApiKeyManager::new(Some("master"));
        keys.create_key("limited", "lim_tok", KeyRole::Operator, Some(50));
        let headers = make_headers(Some("lim_tok"));

        let outcome = check(&mut keys, &headers, AccessLevel::Write).unwrap();

        assert_eq!(outcome.rate_limit, Some(50));
        assert_eq!(outcome.key_name, Some("limited".to_string()));
    }

    #[test]
    fn classify_public_endpoints() {
        assert_classified(
            AccessLevel::Public,
            &[
                ("GET", "/v1/health"),
                ("GET", "/v1/health/live"),
                ("GET", "/v1/health/ready"),
                ("GET", "/v1/version"),
                ("GET", "/v1/rate-limit"),
            ],
        );
    }

    #[test]
    fn classify_readonly_endpoints() {
        assert_classified(
            AccessLevel::ReadOnly,
            &[
                ("GET", "/v1/metrics"),
                ("GET", "/v1/daemons"),
                ("GET", "/v1/logs"),
                ("GET", "/v1/keys"),
                ("GET", "/v1/session"),
            ],
        );
    }

    #[test]
    fn classify_write_endpoints() {
        assert_classified(
            AccessLevel::Write,
            &[
                ("POST", "/v1/deploy"),
                ("POST", "/v1/estimate"),
                ("POST", "/v1/events"),
                ("DELETE", "/v1/daemons/x"),
            ],
        );
    }

    #[test]
    fn classify_admin_endpoints() {
        assert_classified(
            AccessLevel::Admin,
            &[
                ("POST", "/v1/keys"),
                ("POST", "/v1/keys/revoke"),
                ("POST", "/v1/keys/rotate"),
            ],
        );
    }

    #[test]
    fn revoked_key_denied() {
        let mut keys = ApiKeyManager::new(Some("master"));
        keys.create_key("temp", "temp_tok", KeyRole::Operator, None);
        keys.revoke("temp_tok");

        let headers = make_headers(Some("temp_tok"));
        let status = check(&mut keys, &headers, AccessLevel::ReadOnly).unwrap_err();
        assert_eq!(status, StatusCode::FORBIDDEN);
    }
}
390}