ares/auth/middleware.rs

use std::sync::Arc;

use axum::{
    extract::Request,
    http::{header, StatusCode},
    middleware::Next,
    response::Response,
};

use crate::auth::jwt::AuthService;
use crate::types::Claims;
5
6pub async fn auth_middleware(auth_service: Arc<AuthService>, req: Request, next: Next) -> Response {
7    // Extract Authorization header
8    if let Some(auth_header) = req.headers().get("authorization") {
9        if let Ok(auth_str) = auth_header.to_str() {
10            if let Some(token) = auth_str.strip_prefix("Bearer ") {
11                match auth_service.verify_token(token) {
12                    Ok(claims) => {
13                        let mut req = req;
14                        req.extensions_mut().insert(claims);
15                        return next.run(req).await;
16                    }
17                    Err(_) => {
18                        // Invalid token
19                    }
20                }
21            }
22        }
23    }
24
25    // No valid token provided
26    Response::builder()
27        .status(StatusCode::UNAUTHORIZED)
28        .body("Unauthorized".into())
29        .unwrap()
30}
31
// Extractor for the `Claims` that `auth_middleware` stores in the request extensions.
33use axum::extract::FromRequestParts;
34use axum::http::request::Parts;
35
36pub struct AuthUser(pub Claims);
37
38impl<S> FromRequestParts<S> for AuthUser
39where
40    S: Send + Sync,
41{
42    type Rejection = StatusCode;
43
44    async fn from_request_parts(
45        parts: &mut Parts,
46        _state: &S,
47    ) -> std::result::Result<Self, Self::Rejection> {
48        parts
49            .extensions
50            .get::<Claims>()
51            .cloned()
52            .map(AuthUser)
53            .ok_or(StatusCode::UNAUTHORIZED)
54    }
55}