1use crate::auth::jwt::AuthService;
2use crate::types::Claims;
3use axum::{extract::Request, http::StatusCode, middleware::Next, response::Response};
4use std::sync::Arc;
5
6pub async fn auth_middleware(auth_service: Arc<AuthService>, req: Request, next: Next) -> Response {
11 if let Some(auth_header) = req.headers().get("authorization") {
13 if let Ok(auth_str) = auth_header.to_str() {
14 if let Some(token) = auth_str.strip_prefix("Bearer ") {
15 match auth_service.verify_token(token) {
16 Ok(claims) => {
17 let mut req = req;
18 req.extensions_mut().insert(claims);
19 return next.run(req).await;
20 }
21 Err(e) => {
22 tracing::debug!("Token verification failed: {}", e);
23 }
24 }
25 }
26 }
27 }
28
29 Response::builder()
31 .status(StatusCode::UNAUTHORIZED)
32 .header("Content-Type", "application/json")
33 .body(r#"{"error":"Unauthorized"}"#.into())
34 .unwrap()
35}
36
37use axum::extract::FromRequestParts;
39use axum::http::request::Parts;
40
41pub struct AuthUser(pub Claims);
50
51impl<S> FromRequestParts<S> for AuthUser
52where
53 S: Send + Sync,
54{
55 type Rejection = (StatusCode, axum::Json<serde_json::Value>);
56
57 async fn from_request_parts(
58 parts: &mut Parts,
59 _state: &S,
60 ) -> std::result::Result<Self, Self::Rejection> {
61 parts
62 .extensions
63 .get::<Claims>()
64 .cloned()
65 .map(AuthUser)
66 .ok_or_else(|| {
67 (
68 StatusCode::UNAUTHORIZED,
69 axum::Json(serde_json::json!({"error": "Unauthorized"})),
70 )
71 })
72 }
73}
74
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode},
        routing::get,
        Router,
    };
    use tower::ServiceExt;

    /// Creates the shared `AuthService` fixture used by every test in this module.
    fn create_test_auth_service() -> Arc<AuthService> {
        let secret = "test-secret-key-that-is-at-least-32-chars".to_string();
        // NOTE(review): the two numbers are presumably token lifetimes in seconds; confirm
        // against `AuthService::new`.
        Arc::new(AuthService::new(secret, 900, 604800))
    }

    /// Trivial handler mounted behind the middleware; reaching it means the request was let in.
    async fn protected_handler() -> &'static str {
        "protected content"
    }

    /// Builds a router whose single `/protected` route is wrapped in `auth_middleware`.
    fn create_test_app(auth_service: Arc<AuthService>) -> Router {
        let auth_layer = axum::middleware::from_fn(move |req, next| {
            let auth = auth_service.clone();
            async move { auth_middleware(auth, req, next).await }
        });

        Router::new()
            .route("/protected", get(protected_handler))
            .layer(auth_layer)
    }

    /// Sends one request to `/protected`, adding an `Authorization` header when one is given,
    /// and returns the status code of the response.
    async fn status_for(app: Router, authorization: Option<&str>) -> StatusCode {
        let mut builder = Request::builder().uri("/protected");
        if let Some(value) = authorization {
            builder = builder.header("Authorization", value);
        }
        let request = builder.body(Body::empty()).unwrap();

        app.oneshot(request).await.unwrap().status()
    }

    #[tokio::test]
    async fn test_middleware_no_auth_header() {
        let app = create_test_app(create_test_auth_service());

        let status = status_for(app, None).await;

        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_middleware_invalid_token() {
        let app = create_test_app(create_test_auth_service());

        let status = status_for(app, Some("Bearer invalid.token.here")).await;

        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn test_middleware_valid_token() {
        let auth_service = create_test_auth_service();
        let tokens = auth_service
            .generate_tokens("user-123", "test@example.com")
            .expect("should generate tokens");
        let header_value = format!("Bearer {}", tokens.access_token);
        let app = create_test_app(auth_service);

        let status = status_for(app, Some(&header_value)).await;

        assert_eq!(status, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_middleware_malformed_auth_header() {
        let app = create_test_app(create_test_auth_service());

        let status = status_for(app, Some("some-token-without-bearer")).await;

        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }
}