Skip to main content

reasonkit_web/portal/
middleware.rs

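//! # Portal Middleware
//!
//! JWT authentication middleware for protecting routes.
//!
//! ## Usage
//!
//! Each middleware is attached with [`axum::middleware::from_fn`]. `require_scope`
//! only reads the claims that `require_auth` stores on the request, so
//! `require_auth` must be the outer layer, i.e. the one added last:
//!
//! ```ignore
//! use axum::{middleware, routing::get, Router};
//! use reasonkit_web::portal::middleware::{require_auth, require_scope};
//!
//! let protected_routes = Router::new()
//!     .route("/protected", get(handler))
//!     .layer(middleware::from_fn(require_scope("write")))
//!     .layer(middleware::from_fn(require_auth));
//! ```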
14
15use axum::{
16    extract::Request,
17    http::{header, StatusCode},
18    middleware::Next,
19    response::Response,
20    Json,
21};
22use serde::Serialize;
23
24use crate::portal::auth::{AuthService, Claims};
25
26/// Authentication error response
27#[derive(Debug, Serialize)]
28pub struct AuthErrorResponse {
29    pub error: String,
30    pub code: String,
31}
32
33/// Extract JWT token from Authorization header
34fn extract_token(req: &Request) -> Option<&str> {
35    req.headers()
36        .get(header::AUTHORIZATION)
37        .and_then(|value| value.to_str().ok())
38        .and_then(|value| value.strip_prefix("Bearer "))
39}
40
41/// JWT authentication middleware
42///
43/// Validates the JWT token and injects Claims into request extensions.
44pub async fn require_auth(
45    req: Request,
46    next: Next,
47) -> Result<Response, (StatusCode, Json<AuthErrorResponse>)> {
48    // Extract token from header
49    let token = extract_token(&req).ok_or_else(|| {
50        (
51            StatusCode::UNAUTHORIZED,
52            Json(AuthErrorResponse {
53                error: "Missing or invalid Authorization header".to_string(),
54                code: "MISSING_TOKEN".to_string(),
55            }),
56        )
57    })?;
58
59    // Create auth service with default config
60    let auth_service = AuthService::new(Default::default());
61
62    // Validate token
63    let claims = auth_service.validate_token(token).map_err(|e| {
64        (
65            StatusCode::UNAUTHORIZED,
66            Json(AuthErrorResponse {
67                error: e.to_string(),
68                code: "INVALID_TOKEN".to_string(),
69            }),
70        )
71    })?;
72
73    // Check if it's an access token (not refresh)
74    if claims.token_type != crate::portal::auth::TokenType::Access {
75        return Err((
76            StatusCode::UNAUTHORIZED,
77            Json(AuthErrorResponse {
78                error: "Invalid token type".to_string(),
79                code: "WRONG_TOKEN_TYPE".to_string(),
80            }),
81        ));
82    }
83
84    // Insert claims into request extensions for handlers to access
85    let mut req = req;
86    req.extensions_mut().insert(claims);
87
88    Ok(next.run(req).await)
89}
90
91/// Optional authentication middleware
92///
93/// Validates JWT if present but allows unauthenticated requests.
94pub async fn optional_auth(req: Request, next: Next) -> Response {
95    // Try to extract and validate token
96    if let Some(token) = extract_token(&req) {
97        let auth_service = AuthService::new(Default::default());
98        if let Ok(claims) = auth_service.validate_token(token) {
99            if claims.token_type == crate::portal::auth::TokenType::Access {
100                let mut req = req;
101                req.extensions_mut().insert(claims);
102                return next.run(req).await;
103            }
104        }
105    }
106
107    // Continue without authentication
108    next.run(req).await
109}
110
111/// Scope-checking middleware
112///
113/// Requires the authenticated user to have a specific scope.
114#[allow(clippy::type_complexity)]
115pub fn require_scope(
116    required_scope: &'static str,
117) -> impl Fn(
118    Request,
119    Next,
120) -> std::pin::Pin<
121    Box<
122        dyn std::future::Future<Output = Result<Response, (StatusCode, Json<AuthErrorResponse>)>>
123            + Send,
124    >,
125> + Clone {
126    move |req: Request, next: Next| {
127        Box::pin(async move {
128            // Get claims from extensions (set by require_auth)
129            let claims = req.extensions().get::<Claims>().ok_or_else(|| {
130                (
131                    StatusCode::UNAUTHORIZED,
132                    Json(AuthErrorResponse {
133                        error: "Not authenticated".to_string(),
134                        code: "NOT_AUTHENTICATED".to_string(),
135                    }),
136                )
137            })?;
138
139            // Check if user has required scope
140            if !claims
141                .scopes
142                .iter()
143                .any(|s| s == required_scope || s == "admin")
144            {
145                return Err((
146                    StatusCode::FORBIDDEN,
147                    Json(AuthErrorResponse {
148                        error: format!("Missing required scope: {}", required_scope),
149                        code: "INSUFFICIENT_SCOPE".to_string(),
150                    }),
151                ));
152            }
153
154            Ok(next.run(req).await)
155        })
156    }
157}
158
159/// Extract authenticated user claims from request
160pub fn get_claims(req: &Request) -> Option<&Claims> {
161    req.extensions().get::<Claims>()
162}
163
164/// Axum extractor for authenticated claims
165#[derive(Debug, Clone)]
166pub struct AuthClaims(pub Claims);
167
168#[axum::async_trait]
169impl<S> axum::extract::FromRequestParts<S> for AuthClaims
170where
171    S: Send + Sync,
172{
173    type Rejection = (StatusCode, Json<AuthErrorResponse>);
174
175    async fn from_request_parts(
176        parts: &mut axum::http::request::Parts,
177        _state: &S,
178    ) -> Result<Self, Self::Rejection> {
179        parts
180            .extensions
181            .get::<Claims>()
182            .cloned()
183            .map(AuthClaims)
184            .ok_or_else(|| {
185                (
186                    StatusCode::UNAUTHORIZED,
187                    Json(AuthErrorResponse {
188                        error: "Not authenticated".to_string(),
189                        code: "NOT_AUTHENTICATED".to_string(),
190                    }),
191                )
192            })
193    }
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::Request;

    /// Build an empty-bodied request, optionally carrying an `Authorization` header.
    fn request_with_authorization(value: Option<&str>) -> Request<Body> {
        let builder = Request::builder();
        let builder = match value {
            Some(header_value) => builder.header("Authorization", header_value),
            None => builder,
        };
        builder
            .body(Body::empty())
            .expect("static test request parts are always valid")
    }

    #[test]
    fn test_extract_token_valid() {
        let req = request_with_authorization(Some("Bearer test_token_123"));
        assert_eq!(extract_token(&req), Some("test_token_123"));
    }

    #[test]
    fn test_extract_token_missing() {
        let req = request_with_authorization(None);
        assert_eq!(extract_token(&req), None);
    }

    #[test]
    fn test_extract_token_invalid_format() {
        let req = request_with_authorization(Some("Basic user:pass"));
        assert_eq!(extract_token(&req), None);
    }
}