// ares/auth/middleware.rs

1use crate::auth::jwt::AuthService;
2use crate::types::Claims;
3use axum::{extract::Request, http::StatusCode, middleware::Next, response::Response};
4use std::sync::Arc;
5
6/// Axum middleware that validates JWT tokens from the Authorization header.
7///
8/// Expects tokens in the format: `Authorization: Bearer <token>`
9/// On success, injects `Claims` into request extensions for downstream handlers.
10pub async fn auth_middleware(auth_service: Arc<AuthService>, req: Request, next: Next) -> Response {
11    // Extract Authorization header
12    if let Some(auth_header) = req.headers().get("authorization") {
13        if let Ok(auth_str) = auth_header.to_str() {
14            if let Some(token) = auth_str.strip_prefix("Bearer ") {
15                match auth_service.verify_token(token) {
16                    Ok(claims) => {
17                        let mut req = req;
18                        req.extensions_mut().insert(claims);
19                        return next.run(req).await;
20                    }
21                    Err(e) => {
22                        tracing::debug!("Token verification failed: {}", e);
23                    }
24                }
25            }
26        }
27    }
28
29    // No valid token provided - return JSON error for consistency
30    Response::builder()
31        .status(StatusCode::UNAUTHORIZED)
32        .header("Content-Type", "application/json")
33        .body(r#"{"error":"Unauthorized"}"#.into())
34        .unwrap()
35}
36
37// Extractor for claims
38use axum::extract::FromRequestParts;
39use axum::http::request::Parts;
40
41/// Extractor for authenticated user claims.
42///
43/// Use in handler signatures to require authentication:
44/// ```ignore
45/// async fn handler(AuthUser(claims): AuthUser) -> impl IntoResponse {
46///     format!("Hello, {}", claims.sub)
47/// }
48/// ```
49pub struct AuthUser(pub Claims);
50
51impl<S> FromRequestParts<S> for AuthUser
52where
53    S: Send + Sync,
54{
55    type Rejection = (StatusCode, axum::Json<serde_json::Value>);
56
57    async fn from_request_parts(
58        parts: &mut Parts,
59        _state: &S,
60    ) -> std::result::Result<Self, Self::Rejection> {
61        parts
62            .extensions
63            .get::<Claims>()
64            .cloned()
65            .map(AuthUser)
66            .ok_or_else(|| {
67                (
68                    StatusCode::UNAUTHORIZED,
69                    axum::Json(serde_json::json!({"error": "Unauthorized"})),
70                )
71            })
72    }
73}
74
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode},
        routing::get,
        Router,
    };
    use tower::ServiceExt;

    /// Builds an `AuthService` with a fixed test secret and token lifetimes.
    fn create_test_auth_service() -> Arc<AuthService> {
        Arc::new(AuthService::new(
            "test-secret-key-that-is-at-least-32-chars".to_string(),
            900,
            604800,
        ))
    }

    async fn protected_handler() -> &'static str {
        "protected content"
    }

    /// Handler that only succeeds if the `AuthUser` extractor finds claims.
    async fn extractor_handler(AuthUser(_claims): AuthUser) -> &'static str {
        "extracted"
    }

    /// Router with both test routes wrapped in `auth_middleware`.
    fn create_test_app(auth_service: Arc<AuthService>) -> Router {
        Router::new()
            .route("/protected", get(protected_handler))
            .route("/whoami", get(extractor_handler))
            .layer(axum::middleware::from_fn(move |req, next| {
                let auth = auth_service.clone();
                async move { auth_middleware(auth, req, next).await }
            }))
    }

    /// Sends a GET to `uri`, adding the raw `Authorization` header value when given.
    async fn send(app: Router, uri: &str, authorization: Option<&str>) -> axum::response::Response {
        let mut builder = Request::builder().uri(uri);
        if let Some(value) = authorization {
            builder = builder.header("Authorization", value);
        }
        app.oneshot(builder.body(Body::empty()).unwrap())
            .await
            .unwrap()
    }

    /// Issues an access token and formats it as a `Bearer` header value.
    fn bearer_header(auth_service: &AuthService) -> String {
        let tokens = auth_service
            .generate_tokens("user-123", "test@example.com")
            .expect("should generate tokens");
        format!("Bearer {}", tokens.access_token)
    }

    #[tokio::test]
    async fn test_middleware_no_auth_header() {
        let app = create_test_app(create_test_auth_service());

        let response = send(app, "/protected", None).await;

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_middleware_invalid_token() {
        let app = create_test_app(create_test_auth_service());

        let response = send(app, "/protected", Some("Bearer invalid.token.here")).await;

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_middleware_valid_token() {
        let auth_service = create_test_auth_service();
        let header = bearer_header(&auth_service);
        let app = create_test_app(auth_service);

        let response = send(app, "/protected", Some(&header)).await;

        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_middleware_malformed_auth_header() {
        let app = create_test_app(create_test_auth_service());

        // Missing "Bearer " prefix
        let response = send(app, "/protected", Some("some-token-without-bearer")).await;

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_unauthorized_response_is_json() {
        let app = create_test_app(create_test_auth_service());

        let response = send(app, "/protected", None).await;

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
        assert_eq!(
            response
                .headers()
                .get(axum::http::header::CONTENT_TYPE)
                .and_then(|v| v.to_str().ok()),
            Some("application/json")
        );
        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(json, serde_json::json!({"error": "Unauthorized"}));
    }

    #[tokio::test]
    async fn test_auth_user_extractor_with_valid_token() {
        let auth_service = create_test_auth_service();
        let header = bearer_header(&auth_service);
        let app = create_test_app(auth_service);

        let response = send(app, "/whoami", Some(&header)).await;

        // The middleware must insert exactly the type the extractor reads.
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn test_auth_user_extractor_rejects_without_middleware() {
        // No middleware layer: no claims in the extensions, so extraction must fail.
        let app = Router::new().route("/whoami", get(extractor_handler));

        let response = send(app, "/whoami", None).await;

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
        let body = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        let json: serde_json::Value = serde_json::from_slice(&body).unwrap();
        assert_eq!(json, serde_json::json!({"error": "Unauthorized"}));
    }
}