// pilgrimage/auth/web_middleware.rs

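//! Web Authentication Middleware
//!
//! This module provides JWT-based authentication middleware for the web console: API
//! endpoints wrapped by it are accessible only to requests carrying a valid
//! `Authorization: Bearer <token>` header.
//!
//! The JWT secret is shared with CLI authentication and stored in
//! `.pilgrimage/jwt_secret` under the user's home directory.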
5
use crate::auth::{DistributedAuthenticator, token::Claims};
use actix_web::{
    body::EitherBody,
    dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
    Error, HttpMessage, HttpResponse,
};
use futures_util::future::LocalBoxFuture;
use log::{debug, warn};
use std::fs;
use std::io::Write;
use std::path::{Path, PathBuf};
use std::rc::Rc;
17
/// Errors produced by the web authentication middleware.
#[derive(Debug)]
pub enum WebAuthError {
    /// No authorization token was supplied.
    TokenMissing,
    /// The supplied token could not be validated.
    TokenInvalid,
    /// The supplied token has expired.
    TokenExpired,
    /// An I/O failure, e.g. while loading or persisting the JWT secret.
    IoError(std::io::Error),
}

impl std::fmt::Display for WebAuthError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            WebAuthError::TokenMissing => write!(f, "Authorization token is missing"),
            WebAuthError::TokenInvalid => write!(f, "Authorization token is invalid"),
            WebAuthError::TokenExpired => write!(f, "Authorization token has expired"),
            WebAuthError::IoError(e) => write!(f, "IO error: {}", e),
        }
    }
}

impl std::error::Error for WebAuthError {
    /// Exposes the wrapped [`std::io::Error`] so error-reporting code can walk the
    /// cause chain; the token variants have no underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            WebAuthError::IoError(e) => Some(e),
            WebAuthError::TokenMissing
            | WebAuthError::TokenInvalid
            | WebAuthError::TokenExpired => None,
        }
    }
}
39
40/// JWT authentication middleware factory
41pub struct JwtAuth {
42    authenticator: Rc<DistributedAuthenticator>,
43}
44
45impl JwtAuth {
46    /// Create a new JWT authentication middleware
47    pub fn new() -> Result<Self, WebAuthError> {
48        let jwt_secret = Self::get_jwt_secret()?;
49        let authenticator = DistributedAuthenticator::new(
50            jwt_secret,
51            "pilgrimage-web".to_string(),
52        );
53
54        Ok(Self {
55            authenticator: Rc::new(authenticator),
56        })
57    }
58
59    /// Get JWT secret (shared with CLI authentication)
60    fn get_jwt_secret() -> Result<Vec<u8>, WebAuthError> {
61        let session_dir = Self::get_session_directory()?;
62        let secret_file = session_dir.join("jwt_secret");
63
64        if secret_file.exists() {
65            let secret_data = fs::read(&secret_file).map_err(WebAuthError::IoError)?;
66            if secret_data.len() == 32 {
67                debug!("Loaded existing JWT secret for web middleware");
68                Ok(secret_data)
69            } else {
70                warn!("Invalid JWT secret file found, using fallback");
71                // If the secret is invalid, create a new one
72                let secret = DistributedAuthenticator::generate_jwt_secret();
73                fs::write(&secret_file, &secret).map_err(WebAuthError::IoError)?;
74                Ok(secret)
75            }
76        } else {
77            // If no secret exists, create one
78            if !session_dir.exists() {
79                fs::create_dir_all(&session_dir).map_err(WebAuthError::IoError)?;
80            }
81            let secret = DistributedAuthenticator::generate_jwt_secret();
82            fs::write(&secret_file, &secret).map_err(WebAuthError::IoError)?;
83            Ok(secret)
84        }
85    }
86
87    /// Get session directory path
88    fn get_session_directory() -> Result<PathBuf, WebAuthError> {
89        if let Some(home_dir) = dirs::home_dir() {
90            Ok(home_dir.join(".pilgrimage"))
91        } else {
92            // Fallback to current directory
93            Ok(PathBuf::from(".pilgrimage"))
94        }
95    }
96
97    /// Extract token from Authorization header
98    fn extract_token(req: &ServiceRequest) -> Option<String> {
99        req.headers()
100            .get("Authorization")
101            .and_then(|h| h.to_str().ok())
102            .and_then(|h| {
103                if h.starts_with("Bearer ") {
104                    Some(h[7..].to_string())
105                } else {
106                    None
107                }
108            })
109    }
110}
111
112impl<S, B> Transform<S, ServiceRequest> for JwtAuth
113where
114    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
115    S::Future: 'static,
116    B: 'static,
117{
118    type Response = ServiceResponse<EitherBody<B>>;
119    type Error = Error;
120    type InitError = ();
121    type Transform = JwtAuthMiddleware<S>;
122    type Future = std::future::Ready<Result<Self::Transform, Self::InitError>>;
123
124    fn new_transform(&self, service: S) -> Self::Future {
125        std::future::ready(Ok(JwtAuthMiddleware {
126            service: Rc::new(service),
127            authenticator: self.authenticator.clone(),
128        }))
129    }
130}
131
132/// JWT authentication middleware service
133pub struct JwtAuthMiddleware<S> {
134    service: Rc<S>,
135    authenticator: Rc<DistributedAuthenticator>,
136}
137
138impl<S, B> Service<ServiceRequest> for JwtAuthMiddleware<S>
139where
140    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
141    S::Future: 'static,
142    B: 'static,
143{
144    type Response = ServiceResponse<EitherBody<B>>;
145    type Error = Error;
146    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
147
148    forward_ready!(service);
149
150    fn call(&self, req: ServiceRequest) -> Self::Future {
151        let service = self.service.clone();
152        let authenticator = self.authenticator.clone();
153
154        Box::pin(async move {
155            // Extract token from Authorization header
156            let token = match JwtAuth::extract_token(&req) {
157                Some(token) => token,
158                None => {
159                    let response = HttpResponse::Unauthorized()
160                        .json(serde_json::json!({
161                            "error": "Authorization token required",
162                            "message": "Please include 'Authorization: Bearer <token>' header"
163                        }));
164                    return Ok(req.into_response(response).map_into_right_body());
165                }
166            };
167
168            // Validate token
169            let validation_result = authenticator.validate_token(&token);
170            if validation_result.valid {
171                debug!("Token validated for user: {:?}", validation_result.client_id);
172
173                // Create a proper claims structure from validation result
174                // For client tokens, we need to decode it properly to get the expiration time
175                let claims = if let Some(client_id) = validation_result.client_id {
176                    // Try to decode the token again to get the actual expiration time
177                    match authenticator.decode_client_token(&token) {
178                        Ok(client_claims) => Claims {
179                            sub: client_claims.username,
180                            exp: client_claims.exp as usize,
181                            roles: client_claims.roles,
182                        },
183                        Err(_) => {
184                            // Fallback if decoding fails
185                            Claims {
186                                sub: client_id,
187                                exp: (std::time::SystemTime::now()
188                                    .duration_since(std::time::UNIX_EPOCH)
189                                    .unwrap()
190                                    .as_secs() + 3600) as usize, // 1 hour from now
191                                roles: validation_result.permissions.clone(),
192                            }
193                        }
194                    }
195                } else if let Some(node_id) = validation_result.node_id {
196                    // For node tokens
197                    Claims {
198                        sub: node_id,
199                        exp: (std::time::SystemTime::now()
200                            .duration_since(std::time::UNIX_EPOCH)
201                            .unwrap()
202                            .as_secs() + 3600) as usize,
203                        roles: validation_result.permissions.clone(),
204                    }
205                } else {
206                    // Unknown token type
207                    Claims {
208                        sub: "unknown".to_string(),
209                        exp: 0,
210                        roles: validation_result.permissions.clone(),
211                    }
212                };
213
214                // Add user information to request extensions for use in handlers
215                req.extensions_mut().insert(claims);
216
217                // Continue to the actual service
218                let res = service.call(req).await?;
219                Ok(res.map_into_left_body())
220            } else {
221                warn!("Token validation failed: {:?}", validation_result.error_message);
222                let response = HttpResponse::Unauthorized()
223                    .json(serde_json::json!({
224                        "error": "Invalid or expired token",
225                        "message": validation_result.error_message.unwrap_or_else(|| "Please login again to get a new token".to_string())
226                    }));
227                Ok(req.into_response(response).map_into_right_body())
228            }
229        })
230    }
231}
232
233/// Helper function to extract authenticated user from request
234/// Use this function within handler functions where you have access to the HttpRequest
235pub fn get_authenticated_user_claims(req: &actix_web::HttpRequest) -> Option<Claims> {
236    req.extensions().get::<Claims>().cloned()
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use actix_web::{http::StatusCode, test, web, App, HttpResponse};

    async fn test_handler() -> Result<HttpResponse, Error> {
        Ok(HttpResponse::Ok().json(serde_json::json!({"message": "Success"})))
    }

    /// Sends `req` to an app whose only route, `/test`, is wrapped by [`JwtAuth`], and
    /// returns the response status.
    ///
    /// NOTE(review): `JwtAuth::new` reads (and may create) the real secret file under
    /// the user's home directory, so these tests touch `~/.pilgrimage`.
    async fn status_for(req: test::TestRequest) -> StatusCode {
        let auth_middleware = JwtAuth::new().expect("Failed to create auth middleware");
        let app = test::init_service(
            App::new()
                .wrap(auth_middleware)
                .route("/test", web::get().to(test_handler)),
        )
        .await;
        test::call_service(&app, req.to_request()).await.status()
    }

    #[actix_rt::test]
    async fn test_jwt_middleware_no_token() {
        let req = test::TestRequest::get().uri("/test");
        assert_eq!(status_for(req).await, StatusCode::UNAUTHORIZED);
    }

    #[actix_rt::test]
    async fn test_jwt_middleware_invalid_token() {
        let req = test::TestRequest::get()
            .uri("/test")
            .insert_header(("Authorization", "Bearer invalid-token"));
        assert_eq!(status_for(req).await, StatusCode::UNAUTHORIZED);
    }

    #[actix_rt::test]
    async fn test_jwt_middleware_non_bearer_scheme() {
        let req = test::TestRequest::get()
            .uri("/test")
            .insert_header(("Authorization", "Basic dXNlcjpwYXNz"));
        assert_eq!(status_for(req).await, StatusCode::UNAUTHORIZED);
    }
}