llm_registry_api/
auth.rs

1//! Authentication middleware
2//!
3//! This module provides JWT-based authentication middleware for protecting API routes.
4
5use axum::{
6    body::Body,
7    extract::{Request, State},
8    http::{header::AUTHORIZATION, StatusCode},
9    middleware::Next,
10    response::{IntoResponse, Response},
11};
12use std::sync::Arc;
13use tracing::{debug, warn};
14
15use crate::{
16    error::ErrorResponse,
17    jwt::{Claims, JwtManager, TokenError},
18};
19
20/// Extension for storing authenticated user claims in requests
21#[derive(Debug, Clone)]
22pub struct AuthUser {
23    /// JWT claims for the authenticated user
24    pub claims: Claims,
25}
26
27impl AuthUser {
28    /// Create new authenticated user
29    pub fn new(claims: Claims) -> Self {
30        Self { claims }
31    }
32
33    /// Get user ID
34    pub fn user_id(&self) -> &str {
35        &self.claims.sub
36    }
37
38    /// Get user email
39    pub fn email(&self) -> Option<&str> {
40        self.claims.email.as_deref()
41    }
42
43    /// Check if user has a role
44    pub fn has_role(&self, role: &str) -> bool {
45        self.claims.has_role(role)
46    }
47
48    /// Check if user has any of the roles
49    pub fn has_any_role(&self, roles: &[&str]) -> bool {
50        self.claims.has_any_role(roles)
51    }
52
53    /// Check if user has all of the roles
54    pub fn has_all_roles(&self, roles: &[&str]) -> bool {
55        self.claims.has_all_roles(roles)
56    }
57}
58
59/// Authentication state containing JWT manager
60#[derive(Clone)]
61pub struct AuthState {
62    jwt_manager: Arc<JwtManager>,
63}
64
65impl AuthState {
66    /// Create new auth state
67    pub fn new(jwt_manager: JwtManager) -> Self {
68        Self {
69            jwt_manager: Arc::new(jwt_manager),
70        }
71    }
72
73    /// Get JWT manager reference
74    pub fn jwt_manager(&self) -> &JwtManager {
75        &self.jwt_manager
76    }
77}
78
79/// Required authentication middleware
80///
81/// This middleware requires a valid JWT token in the Authorization header.
82/// If authentication fails, it returns a 401 Unauthorized response.
83///
84/// # Usage
85///
86/// ```rust,no_run
87/// use axum::{Router, routing::get, middleware};
88/// use llm_registry_api::auth::{require_auth, AuthState};
89/// use llm_registry_api::jwt::{JwtConfig, JwtManager};
90///
91/// # async fn example() {
92/// let jwt_manager = JwtManager::new(JwtConfig::default()).unwrap();
93/// let auth_state = AuthState::new(jwt_manager);
94///
95/// let app = Router::new()
96///     .route("/protected", get(|| async { "Protected content" }))
97///     .layer(middleware::from_fn_with_state(auth_state.clone(), require_auth));
98/// # }
99/// ```
100pub async fn require_auth(
101    State(auth_state): State<AuthState>,
102    mut request: Request,
103    next: Next,
104) -> Result<Response, AuthError> {
105    debug!("Authenticating request");
106
107    // Extract Authorization header
108    let auth_header = request
109        .headers()
110        .get(AUTHORIZATION)
111        .and_then(|h| h.to_str().ok())
112        .ok_or(AuthError::MissingToken)?;
113
114    // Extract token from header
115    let token = JwtManager::extract_token_from_header(auth_header)
116        .map_err(|_| AuthError::InvalidToken)?;
117
118    // Validate token
119    let claims = auth_state
120        .jwt_manager
121        .validate_token(token)
122        .map_err(|e| match e {
123            TokenError::Expired => AuthError::ExpiredToken,
124            TokenError::NotYetValid => AuthError::InvalidToken,
125            _ => AuthError::InvalidToken,
126        })?;
127
128    debug!("User authenticated: {}", claims.sub);
129
130    // Add user to request extensions
131    request.extensions_mut().insert(AuthUser::new(claims));
132
133    Ok(next.run(request).await)
134}
135
136/// Optional authentication middleware
137///
138/// This middleware attempts to authenticate the user but does not fail if
139/// authentication is unsuccessful. Use this for endpoints that have optional
140/// authentication (e.g., public content that can be personalized for logged-in users).
141///
142/// # Usage
143///
144/// ```rust,no_run
145/// use axum::{Router, routing::get, middleware};
146/// use llm_registry_api::auth::{optional_auth, AuthState};
147/// use llm_registry_api::jwt::{JwtConfig, JwtManager};
148///
149/// # async fn example() {
150/// let jwt_manager = JwtManager::new(JwtConfig::default()).unwrap();
151/// let auth_state = AuthState::new(jwt_manager);
152///
153/// let app = Router::new()
154///     .route("/public", get(|| async { "Public content" }))
155///     .layer(middleware::from_fn_with_state(auth_state.clone(), optional_auth));
156/// # }
157/// ```
158pub async fn optional_auth(
159    State(auth_state): State<AuthState>,
160    mut request: Request,
161    next: Next,
162) -> Response {
163    debug!("Attempting optional authentication");
164
165    // Try to extract and validate token
166    if let Some(auth_header) = request.headers().get(AUTHORIZATION) {
167        if let Ok(header_str) = auth_header.to_str() {
168            if let Ok(token) = JwtManager::extract_token_from_header(header_str) {
169                if let Ok(claims) = auth_state.jwt_manager.validate_token(token) {
170                    debug!("User optionally authenticated: {}", claims.sub);
171                    request.extensions_mut().insert(AuthUser::new(claims));
172                }
173            }
174        }
175    }
176
177    next.run(request).await
178}
179
180/// Role-based authentication middleware
181///
182/// This middleware requires authentication AND checks if the user has one of the
183/// specified roles.
184///
185/// # Usage
186///
187/// ```rust,no_run
188/// use axum::{Router, routing::get, middleware};
189/// use llm_registry_api::auth::{require_role, AuthState};
190/// use llm_registry_api::jwt::{JwtConfig, JwtManager};
191///
192/// # async fn example() {
193/// let jwt_manager = JwtManager::new(JwtConfig::default()).unwrap();
194/// let auth_state = AuthState::new(jwt_manager);
195///
196/// let roles = vec!["admin".to_string(), "moderator".to_string()];
197///
198/// let app = Router::new()
199///     .route("/admin", get(|| async { "Admin content" }))
200///     .layer(middleware::from_fn_with_state(
201///         (auth_state.clone(), roles),
202///         require_role,
203///     ));
204/// # }
205/// ```
206pub async fn require_role(
207    State((auth_state, allowed_roles)): State<(AuthState, Vec<String>)>,
208    mut request: Request,
209    next: Next,
210) -> Result<Response, AuthError> {
211    debug!("Authenticating request with role check");
212
213    // First authenticate
214    let auth_header = request
215        .headers()
216        .get(AUTHORIZATION)
217        .and_then(|h| h.to_str().ok())
218        .ok_or(AuthError::MissingToken)?;
219
220    let token = JwtManager::extract_token_from_header(auth_header)
221        .map_err(|_| AuthError::InvalidToken)?;
222
223    let claims = auth_state
224        .jwt_manager
225        .validate_token(token)
226        .map_err(|e| match e {
227            TokenError::Expired => AuthError::ExpiredToken,
228            TokenError::NotYetValid => AuthError::InvalidToken,
229            _ => AuthError::InvalidToken,
230        })?;
231
232    // Check roles
233    let role_refs: Vec<&str> = allowed_roles.iter().map(|s| s.as_str()).collect();
234    if !claims.has_any_role(&role_refs) {
235        warn!("User {} lacks required role", claims.sub);
236        return Err(AuthError::InsufficientPermissions);
237    }
238
239    debug!("User authenticated with role: {}", claims.sub);
240    request.extensions_mut().insert(AuthUser::new(claims));
241
242    Ok(next.run(request).await)
243}
244
245/// Extract authenticated user from request
246///
247/// This is a helper function to extract the AuthUser from request extensions.
248/// Use this in your handlers after authentication middleware.
249///
250/// # Example
251///
252/// ```rust,no_run
253/// use axum::{extract::Extension};
254/// use llm_registry_api::auth::AuthUser;
255///
256/// async fn my_handler(Extension(user): Extension<AuthUser>) -> String {
257///     format!("Hello, user {}", user.user_id())
258/// }
259/// ```
260pub fn extract_user(request: &Request<Body>) -> Result<&AuthUser, AuthError> {
261    request
262        .extensions()
263        .get::<AuthUser>()
264        .ok_or(AuthError::Unauthenticated)
265}
266
/// Errors produced by the authentication middlewares and helpers in this module.
///
/// When converted into a response, every variant yields `401 Unauthorized`
/// except [`AuthError::InsufficientPermissions`], which yields `403 Forbidden`.
///
/// `Clone`, `PartialEq` and `Eq` are derived so callers and tests can compare
/// errors directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AuthError {
    /// The request carried no usable `Authorization` header.
    MissingToken,

    /// The token was malformed, could not be extracted from the header, or
    /// failed validation for a reason other than expiry.
    InvalidToken,

    /// The token has expired.
    ExpiredToken,

    /// No authenticated user was found on the request (see `extract_user`).
    Unauthenticated,

    /// The user is authenticated but holds none of the required roles.
    InsufficientPermissions,
}
285
286impl IntoResponse for AuthError {
287    fn into_response(self) -> Response {
288        let (status, message) = match self {
289            AuthError::MissingToken => (
290                StatusCode::UNAUTHORIZED,
291                "Missing authentication token",
292            ),
293            AuthError::InvalidToken => (
294                StatusCode::UNAUTHORIZED,
295                "Invalid authentication token",
296            ),
297            AuthError::ExpiredToken => (
298                StatusCode::UNAUTHORIZED,
299                "Authentication token has expired",
300            ),
301            AuthError::Unauthenticated => (
302                StatusCode::UNAUTHORIZED,
303                "Authentication required",
304            ),
305            AuthError::InsufficientPermissions => (
306                StatusCode::FORBIDDEN,
307                "Insufficient permissions",
308            ),
309        };
310
311        let error_response = ErrorResponse {
312            status: status.as_u16(),
313            error: message.to_string(),
314            code: None,
315            timestamp: chrono::Utc::now(),
316        };
317
318        (status, axum::Json(error_response)).into_response()
319    }
320}
321
322impl std::fmt::Display for AuthError {
323    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
324        match self {
325            AuthError::MissingToken => write!(f, "Missing authentication token"),
326            AuthError::InvalidToken => write!(f, "Invalid authentication token"),
327            AuthError::ExpiredToken => write!(f, "Authentication token has expired"),
328            AuthError::Unauthenticated => write!(f, "Authentication required"),
329            AuthError::InsufficientPermissions => write!(f, "Insufficient permissions"),
330        }
331    }
332}
333
334impl std::error::Error for AuthError {}
335
#[cfg(test)]
mod tests {
    use super::*;
    use crate::jwt::{JwtConfig, JwtManager};
    use axum::{
        body::Body,
        extract::Extension,
        http::{Request, StatusCode},
        middleware,
        routing::get,
        Router,
    };
    use tower::ServiceExt;

    fn create_test_jwt_manager() -> JwtManager {
        let config = JwtConfig::new("test-secret-key")
            .with_issuer("test")
            .with_audience("test");
        JwtManager::new(config).unwrap()
    }

    /// Collect a response body into a `String` so tests can assert on content.
    async fn body_string(response: Response) -> String {
        let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        String::from_utf8(bytes.to_vec()).unwrap()
    }

    async fn protected_handler(Extension(user): Extension<AuthUser>) -> String {
        format!("Hello, {}", user.user_id())
    }

    /// Greet the user attached by `optional_auth`, or a guest when none is present.
    async fn optional_handler(user: Option<Extension<AuthUser>>) -> String {
        match user {
            Some(Extension(u)) => format!("Hello, {}", u.user_id()),
            None => "Hello, guest".to_string(),
        }
    }

    fn protected_app(auth_state: AuthState) -> Router {
        Router::new()
            .route("/protected", get(protected_handler))
            .layer(middleware::from_fn_with_state(auth_state, require_auth))
    }

    fn optional_app(auth_state: AuthState) -> Router {
        Router::new()
            .route("/public", get(optional_handler))
            .layer(middleware::from_fn_with_state(auth_state, optional_auth))
    }

    #[tokio::test]
    async fn test_require_auth_with_valid_token() {
        let jwt_manager = create_test_jwt_manager();
        let token = jwt_manager.generate_token("user123").unwrap();
        let app = protected_app(AuthState::new(jwt_manager));

        let request = Request::builder()
            .uri("/protected")
            .header(AUTHORIZATION, format!("Bearer {}", token))
            .body(Body::empty())
            .unwrap();

        let response = app.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(body_string(response).await, "Hello, user123");
    }

    #[tokio::test]
    async fn test_require_auth_without_token() {
        let app = protected_app(AuthState::new(create_test_jwt_manager()));

        let request = Request::builder()
            .uri("/protected")
            .body(Body::empty())
            .unwrap();

        let response = app.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_require_auth_with_invalid_token() {
        let app = protected_app(AuthState::new(create_test_jwt_manager()));

        let request = Request::builder()
            .uri("/protected")
            .header(AUTHORIZATION, "Bearer invalid.token.here")
            .body(Body::empty())
            .unwrap();

        let response = app.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_require_auth_with_non_ascii_header() {
        let app = protected_app(AuthState::new(create_test_jwt_manager()));

        let request = Request::builder()
            .uri("/protected")
            .header(
                AUTHORIZATION,
                axum::http::HeaderValue::from_bytes(b"Bearer \xff").unwrap(),
            )
            .body(Body::empty())
            .unwrap();

        let response = app.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_optional_auth_with_token() {
        let jwt_manager = create_test_jwt_manager();
        let token = jwt_manager.generate_token("user123").unwrap();
        let app = optional_app(AuthState::new(jwt_manager));

        let request = Request::builder()
            .uri("/public")
            .header(AUTHORIZATION, format!("Bearer {}", token))
            .body(Body::empty())
            .unwrap();

        let response = app.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        // The user must actually have been attached, not just the request allowed.
        assert_eq!(body_string(response).await, "Hello, user123");
    }

    #[tokio::test]
    async fn test_optional_auth_without_token() {
        let app = optional_app(AuthState::new(create_test_jwt_manager()));

        let request = Request::builder()
            .uri("/public")
            .body(Body::empty())
            .unwrap();

        let response = app.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(body_string(response).await, "Hello, guest");
    }

    #[tokio::test]
    async fn test_optional_auth_with_invalid_token_falls_back_to_guest() {
        let app = optional_app(AuthState::new(create_test_jwt_manager()));

        let request = Request::builder()
            .uri("/public")
            .header(AUTHORIZATION, "Bearer invalid.token.here")
            .body(Body::empty())
            .unwrap();

        let response = app.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(body_string(response).await, "Hello, guest");
    }

    #[test]
    fn test_auth_user() {
        let claims = crate::jwt::Claims::new("user123", "test", "test", 3600)
            .with_email("user@example.com")
            .with_role("admin");

        let auth_user = AuthUser::new(claims);

        assert_eq!(auth_user.user_id(), "user123");
        assert_eq!(auth_user.email(), Some("user@example.com"));
        assert!(auth_user.has_role("admin"));
        assert!(!auth_user.has_role("moderator"));
    }

    #[test]
    fn test_extract_user() {
        let mut request = Request::builder().uri("/").body(Body::empty()).unwrap();
        assert!(matches!(
            extract_user(&request),
            Err(AuthError::Unauthenticated)
        ));

        let claims = crate::jwt::Claims::new("user123", "test", "test", 3600);
        request.extensions_mut().insert(AuthUser::new(claims));
        let user = extract_user(&request).unwrap();
        assert_eq!(user.user_id(), "user123");
    }

    #[test]
    fn test_auth_error_status_codes() {
        let cases = [
            (AuthError::MissingToken, StatusCode::UNAUTHORIZED),
            (AuthError::InvalidToken, StatusCode::UNAUTHORIZED),
            (AuthError::ExpiredToken, StatusCode::UNAUTHORIZED),
            (AuthError::Unauthenticated, StatusCode::UNAUTHORIZED),
            (AuthError::InsufficientPermissions, StatusCode::FORBIDDEN),
        ];
        for (error, expected) in cases {
            let label = error.to_string();
            assert_eq!(error.into_response().status(), expected, "{}", label);
        }
    }
}