1use crate::auth::{DistributedAuthenticator, token::Claims};
7use actix_web::{
8 body::EitherBody,
9 dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
10 Error, HttpMessage, HttpResponse,
11};
12use futures_util::future::LocalBoxFuture;
13use log::{debug, warn};
14use std::fs;
15use std::path::PathBuf;
16use std::rc::Rc;
17
/// Errors produced by the web-layer JWT authentication support.
#[derive(Debug)]
pub enum WebAuthError {
    /// Displayed as "Authorization token is missing".
    TokenMissing,
    /// Displayed as "Authorization token is invalid".
    TokenInvalid,
    /// Displayed as "Authorization token has expired".
    TokenExpired,
    /// An I/O failure, e.g. while reading or writing the persisted JWT secret.
    IoError(std::io::Error),
}

impl std::fmt::Display for WebAuthError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            WebAuthError::TokenMissing => write!(f, "Authorization token is missing"),
            WebAuthError::TokenInvalid => write!(f, "Authorization token is invalid"),
            WebAuthError::TokenExpired => write!(f, "Authorization token has expired"),
            WebAuthError::IoError(e) => write!(f, "IO error: {}", e),
        }
    }
}

impl std::error::Error for WebAuthError {
    /// Exposes the wrapped I/O error so error-chain reporters can walk to the root cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            WebAuthError::IoError(e) => Some(e),
            WebAuthError::TokenMissing
            | WebAuthError::TokenInvalid
            | WebAuthError::TokenExpired => None,
        }
    }
}

/// Allows `?` to convert `std::io::Error` into [`WebAuthError::IoError`].
impl From<std::io::Error> for WebAuthError {
    fn from(e: std::io::Error) -> Self {
        WebAuthError::IoError(e)
    }
}
39
40pub struct JwtAuth {
42 authenticator: Rc<DistributedAuthenticator>,
43}
44
45impl JwtAuth {
46 pub fn new() -> Result<Self, WebAuthError> {
48 let jwt_secret = Self::get_jwt_secret()?;
49 let authenticator = DistributedAuthenticator::new(
50 jwt_secret,
51 "pilgrimage-web".to_string(),
52 );
53
54 Ok(Self {
55 authenticator: Rc::new(authenticator),
56 })
57 }
58
59 fn get_jwt_secret() -> Result<Vec<u8>, WebAuthError> {
61 let session_dir = Self::get_session_directory()?;
62 let secret_file = session_dir.join("jwt_secret");
63
64 if secret_file.exists() {
65 let secret_data = fs::read(&secret_file).map_err(WebAuthError::IoError)?;
66 if secret_data.len() == 32 {
67 debug!("Loaded existing JWT secret for web middleware");
68 Ok(secret_data)
69 } else {
70 warn!("Invalid JWT secret file found, using fallback");
71 let secret = DistributedAuthenticator::generate_jwt_secret();
73 fs::write(&secret_file, &secret).map_err(WebAuthError::IoError)?;
74 Ok(secret)
75 }
76 } else {
77 if !session_dir.exists() {
79 fs::create_dir_all(&session_dir).map_err(WebAuthError::IoError)?;
80 }
81 let secret = DistributedAuthenticator::generate_jwt_secret();
82 fs::write(&secret_file, &secret).map_err(WebAuthError::IoError)?;
83 Ok(secret)
84 }
85 }
86
87 fn get_session_directory() -> Result<PathBuf, WebAuthError> {
89 if let Some(home_dir) = dirs::home_dir() {
90 Ok(home_dir.join(".pilgrimage"))
91 } else {
92 Ok(PathBuf::from(".pilgrimage"))
94 }
95 }
96
97 fn extract_token(req: &ServiceRequest) -> Option<String> {
99 req.headers()
100 .get("Authorization")
101 .and_then(|h| h.to_str().ok())
102 .and_then(|h| {
103 if h.starts_with("Bearer ") {
104 Some(h[7..].to_string())
105 } else {
106 None
107 }
108 })
109 }
110}
111
112impl<S, B> Transform<S, ServiceRequest> for JwtAuth
113where
114 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
115 S::Future: 'static,
116 B: 'static,
117{
118 type Response = ServiceResponse<EitherBody<B>>;
119 type Error = Error;
120 type InitError = ();
121 type Transform = JwtAuthMiddleware<S>;
122 type Future = std::future::Ready<Result<Self::Transform, Self::InitError>>;
123
124 fn new_transform(&self, service: S) -> Self::Future {
125 std::future::ready(Ok(JwtAuthMiddleware {
126 service: Rc::new(service),
127 authenticator: self.authenticator.clone(),
128 }))
129 }
130}
131
132pub struct JwtAuthMiddleware<S> {
134 service: Rc<S>,
135 authenticator: Rc<DistributedAuthenticator>,
136}
137
138impl<S, B> Service<ServiceRequest> for JwtAuthMiddleware<S>
139where
140 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
141 S::Future: 'static,
142 B: 'static,
143{
144 type Response = ServiceResponse<EitherBody<B>>;
145 type Error = Error;
146 type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
147
148 forward_ready!(service);
149
150 fn call(&self, req: ServiceRequest) -> Self::Future {
151 let service = self.service.clone();
152 let authenticator = self.authenticator.clone();
153
154 Box::pin(async move {
155 let token = match JwtAuth::extract_token(&req) {
157 Some(token) => token,
158 None => {
159 let response = HttpResponse::Unauthorized()
160 .json(serde_json::json!({
161 "error": "Authorization token required",
162 "message": "Please include 'Authorization: Bearer <token>' header"
163 }));
164 return Ok(req.into_response(response).map_into_right_body());
165 }
166 };
167
168 let validation_result = authenticator.validate_token(&token);
170 if validation_result.valid {
171 debug!("Token validated for user: {:?}", validation_result.client_id);
172
173 let claims = if let Some(client_id) = validation_result.client_id {
176 match authenticator.decode_client_token(&token) {
178 Ok(client_claims) => Claims {
179 sub: client_claims.username,
180 exp: client_claims.exp as usize,
181 roles: client_claims.roles,
182 },
183 Err(_) => {
184 Claims {
186 sub: client_id,
187 exp: (std::time::SystemTime::now()
188 .duration_since(std::time::UNIX_EPOCH)
189 .unwrap()
190 .as_secs() + 3600) as usize, roles: validation_result.permissions.clone(),
192 }
193 }
194 }
195 } else if let Some(node_id) = validation_result.node_id {
196 Claims {
198 sub: node_id,
199 exp: (std::time::SystemTime::now()
200 .duration_since(std::time::UNIX_EPOCH)
201 .unwrap()
202 .as_secs() + 3600) as usize,
203 roles: validation_result.permissions.clone(),
204 }
205 } else {
206 Claims {
208 sub: "unknown".to_string(),
209 exp: 0,
210 roles: validation_result.permissions.clone(),
211 }
212 };
213
214 req.extensions_mut().insert(claims);
216
217 let res = service.call(req).await?;
219 Ok(res.map_into_left_body())
220 } else {
221 warn!("Token validation failed: {:?}", validation_result.error_message);
222 let response = HttpResponse::Unauthorized()
223 .json(serde_json::json!({
224 "error": "Invalid or expired token",
225 "message": validation_result.error_message.unwrap_or_else(|| "Please login again to get a new token".to_string())
226 }));
227 Ok(req.into_response(response).map_into_right_body())
228 }
229 })
230 }
231}
232
233pub fn get_authenticated_user_claims(req: &actix_web::HttpRequest) -> Option<Claims> {
236 req.extensions().get::<Claims>().cloned()
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use actix_web::{test, web, App, HttpResponse};

    /// Handler that only runs when the middleware lets a request through.
    async fn test_handler() -> Result<HttpResponse, Error> {
        Ok(HttpResponse::Ok().json(serde_json::json!({"message": "Success"})))
    }

    /// Builds an app guarded by a fresh `JwtAuth` with a single `/test` route,
    /// sends `request` to it, and returns the response status.
    ///
    /// NOTE(review): `JwtAuth::new()` reads (and may create) the secret file in
    /// the real session directory, so these tests touch the user's home dir.
    async fn status_for(request: test::TestRequest) -> actix_web::http::StatusCode {
        let guard = JwtAuth::new().expect("Failed to create auth middleware");
        let service = test::init_service(
            App::new()
                .wrap(guard)
                .route("/test", web::get().to(test_handler)),
        )
        .await;
        let response = test::call_service(&service, request.to_request()).await;
        response.status()
    }

    #[actix_rt::test]
    async fn test_jwt_middleware_no_token() {
        let status = status_for(test::TestRequest::get().uri("/test")).await;
        assert_eq!(status, actix_web::http::StatusCode::UNAUTHORIZED);
    }

    #[actix_rt::test]
    async fn test_jwt_middleware_invalid_token() {
        let request = test::TestRequest::get()
            .uri("/test")
            .insert_header(("Authorization", "Bearer invalid-token"));
        let status = status_for(request).await;
        assert_eq!(status, actix_web::http::StatusCode::UNAUTHORIZED);
    }
}