1use crate::auth::jwt::AuthService;
2use crate::types::Claims;
3use axum::{extract::Request, http::StatusCode, middleware::Next, response::Response};
4use std::sync::Arc;
5
6pub async fn auth_middleware(auth_service: Arc<AuthService>, req: Request, next: Next) -> Response {
7 if let Some(auth_header) = req.headers().get("authorization") {
9 if let Ok(auth_str) = auth_header.to_str() {
10 if let Some(token) = auth_str.strip_prefix("Bearer ") {
11 match auth_service.verify_token(token) {
12 Ok(claims) => {
13 let mut req = req;
14 req.extensions_mut().insert(claims);
15 return next.run(req).await;
16 }
17 Err(_) => {
18 }
20 }
21 }
22 }
23 }
24
25 Response::builder()
27 .status(StatusCode::UNAUTHORIZED)
28 .body("Unauthorized".into())
29 .unwrap()
30}
31
32use axum::extract::FromRequestParts;
34use axum::http::request::Parts;
35
36pub struct AuthUser(pub Claims);
37
38impl<S> FromRequestParts<S> for AuthUser
39where
40 S: Send + Sync,
41{
42 type Rejection = StatusCode;
43
44 async fn from_request_parts(
45 parts: &mut Parts,
46 _state: &S,
47 ) -> std::result::Result<Self, Self::Rejection> {
48 parts
49 .extensions
50 .get::<Claims>()
51 .cloned()
52 .map(AuthUser)
53 .ok_or(StatusCode::UNAUTHORIZED)
54 }
55}