corteq/
middleware.rs

//! Middleware for tenant context extraction and validation.

use crate::{CorteqApp, TenantContext};
use actix_web::{
    dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
    error::{ErrorBadRequest, ErrorInternalServerError, ErrorUnauthorized},
    web, Error, FromRequest, HttpMessage,
};
use futures::future::{ready, LocalBoxFuture, Ready};
use std::rc::Rc;

12/// Actix Web extractor for tenant context
13///
14/// This extractor ensures that every request has a valid tenant context.
15/// It will automatically reject requests without valid tenant information.
16///
17/// # Example
18///
19/// ```rust,no_run
20/// use actix_web::{web, HttpResponse};
21/// use corteq::TenantExtractor;
22///
23/// async fn my_handler(tenant: TenantExtractor) -> HttpResponse {
24///     HttpResponse::Ok().body(format!("Tenant: {}", tenant.0.tenant_id))
25/// }
26/// ```
27#[derive(Debug, Clone)]
28pub struct TenantExtractor(pub TenantContext);
29
30impl FromRequest for TenantExtractor {
31    type Error = Error;
32    type Future = Ready<Result<Self, Self::Error>>;
33
34    fn from_request(
35        req: &actix_web::HttpRequest,
36        _payload: &mut actix_web::dev::Payload,
37    ) -> Self::Future {
38        // Try to get tenant context from request extensions
39        match req.extensions().get::<TenantContext>().cloned() {
40            Some(ctx) => ready(Ok(TenantExtractor(ctx))),
41            None => ready(Err(ErrorBadRequest("Missing tenant context"))),
42        }
43    }
44}
45
/// Middleware factory that authenticates requests and attaches tenant context.
///
/// Every wrapped request must carry a JWT, either as a `token` cookie or as an
/// `Authorization: Bearer` header. The token is decoded, and the tenant it
/// names is resolved from the tenant cache or, on a miss, from the tenant
/// repository. The resulting [`TenantContext`] is inserted into the request
/// extensions, where handlers can read it through [`TenantExtractor`].
#[derive(Debug, Default)]
pub struct TenantContextMiddleware;

53impl<S, B> Transform<S, ServiceRequest> for TenantContextMiddleware
54where
55    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
56    S::Future: 'static,
57    B: 'static,
58{
59    type Response = ServiceResponse<B>;
60    type Error = Error;
61    type InitError = ();
62    type Transform = TenantContextMiddlewareService<S>;
63    type Future = Ready<Result<Self::Transform, Self::InitError>>;
64
65    fn new_transform(&self, service: S) -> Self::Future {
66        ready(Ok(TenantContextMiddlewareService {
67            service: Rc::new(service),
68        }))
69    }
70}
71
/// Per-worker service built by [`TenantContextMiddleware`]; it wraps the next
/// service in the chain.
pub struct TenantContextMiddlewareService<S> {
    // Held in an `Rc` so `call` can move a cheap handle into its boxed future.
    service: Rc<S>,
}

impl<S> std::fmt::Debug for TenantContextMiddlewareService<S> {
    /// Renders as `TenantContextMiddlewareService { .. }`. The inner service is
    /// left out because `S` is not required to implement `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("TenantContextMiddlewareService");
        builder.finish_non_exhaustive()
    }
}

84/// Extract JWT token from request
85///
86/// Checks in order of priority:
87/// 1. Cookie named "token" (httpOnly cookie)
88/// 2. Authorization header with Bearer scheme
89fn extract_jwt_token(req: &ServiceRequest) -> Option<String> {
90    // Priority 1: Check for httpOnly cookie named "token"
91    if let Some(cookie_header) = req.headers().get("cookie") {
92        if let Ok(cookie_str) = cookie_header.to_str() {
93            // Parse cookies manually (simple approach)
94            for cookie_pair in cookie_str.split(';') {
95                let parts: Vec<&str> = cookie_pair.trim().splitn(2, '=').collect();
96                if parts.len() == 2 && parts[0] == "token" {
97                    tracing::debug!("JWT token found in cookie");
98                    return Some(parts[1].to_string());
99                }
100            }
101        }
102    }
103
104    // Priority 2: Check Authorization header for Bearer token
105    if let Some(auth_header) = req.headers().get("authorization") {
106        if let Ok(auth_str) = auth_header.to_str() {
107            if let Some(token) = auth_str.strip_prefix("Bearer ") {
108                tracing::debug!("JWT token found in Authorization header");
109                return Some(token.to_string());
110            }
111        }
112    }
113
114    None
115}
116
117impl<S, B> Service<ServiceRequest> for TenantContextMiddlewareService<S>
118where
119    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
120    S::Future: 'static,
121    B: 'static,
122{
123    type Response = ServiceResponse<B>;
124    type Error = Error;
125    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
126
127    forward_ready!(service);
128
129    fn call(&self, req: ServiceRequest) -> Self::Future {
130        let svc = self.service.clone();
131
132        Box::pin(async move {
133            // Extract JWT token from Cookie (priority) or Authorization header
134            let token = extract_jwt_token(&req);
135
136            if token.is_none() {
137                tracing::warn!("No JWT token found in request");
138                return Err(ErrorUnauthorized("Authentication required"));
139            }
140
141            let token = token.unwrap();
142
143            // Get CorteqApp from app_data
144            let app_config = req.app_data::<web::Data<CorteqApp>>().ok_or_else(|| {
145                tracing::error!("CorteqApp not found in app_data");
146                ErrorBadRequest("Application configuration error")
147            })?;
148
149            // Decode JWT to extract claims
150            let claims = app_config.jwt_service.decode(&token).map_err(|e| {
151                tracing::warn!("JWT validation failed: {}", e);
152                ErrorUnauthorized("Invalid or expired token")
153            })?;
154
155            let tenant_id = claims.tenant_id;
156
157            // Try to get tenant context from cache first
158            let tenant_context = if let Some(ctx) = app_config.tenant_cache.get(&tenant_id).await {
159                tracing::debug!("Tenant {} found in cache", tenant_id);
160                ctx
161            } else {
162                // Cache miss - load from database
163                tracing::debug!("Cache miss for tenant {}, loading from database", tenant_id);
164
165                let tenant = app_config
166                    .tenant_repository
167                    .find_by_id(&tenant_id)
168                    .await
169                    .map_err(|e| {
170                        tracing::error!("Database error loading tenant: {}", e);
171                        ErrorBadRequest("Failed to load tenant")
172                    })?
173                    .ok_or_else(|| {
174                        tracing::warn!("Tenant {} not found or soft-deleted", tenant_id);
175                        ErrorUnauthorized("Tenant not found")
176                    })?;
177
178                // Create tenant context from loaded tenant
179                let ctx = TenantContext::from(tenant);
180
181                // Cache the context for future requests
182                app_config.tenant_cache.set(tenant_id, ctx.clone()).await;
183
184                ctx
185            };
186
187            // Inject tenant context into request extensions for downstream handlers
188            req.extensions_mut().insert(tenant_context);
189
190            // Continue to the next service
191            let res = svc.call(req).await?;
192            Ok(res)
193        })
194    }
195}
196
// Integration tests live in the `tests/` directory.