use crate::{CorteqApp, TenantContext};
use actix_web::{
    dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
    error::{ErrorBadRequest, ErrorInternalServerError, ErrorUnauthorized},
    web, Error, FromRequest, HttpMessage,
};
use futures::future::{ready, LocalBoxFuture, Ready};
use std::rc::Rc;
11
12#[derive(Debug, Clone)]
28pub struct TenantExtractor(pub TenantContext);
29
30impl FromRequest for TenantExtractor {
31 type Error = Error;
32 type Future = Ready<Result<Self, Self::Error>>;
33
34 fn from_request(
35 req: &actix_web::HttpRequest,
36 _payload: &mut actix_web::dev::Payload,
37 ) -> Self::Future {
38 match req.extensions().get::<TenantContext>().cloned() {
40 Some(ctx) => ready(Ok(TenantExtractor(ctx))),
41 None => ready(Err(ErrorBadRequest("Missing tenant context"))),
42 }
43 }
44}
45
46#[derive(Debug, Default)]
51pub struct TenantContextMiddleware;
52
53impl<S, B> Transform<S, ServiceRequest> for TenantContextMiddleware
54where
55 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
56 S::Future: 'static,
57 B: 'static,
58{
59 type Response = ServiceResponse<B>;
60 type Error = Error;
61 type InitError = ();
62 type Transform = TenantContextMiddlewareService<S>;
63 type Future = Ready<Result<Self::Transform, Self::InitError>>;
64
65 fn new_transform(&self, service: S) -> Self::Future {
66 ready(Ok(TenantContextMiddlewareService {
67 service: Rc::new(service),
68 }))
69 }
70}
71
/// Per-service middleware instance created by `TenantContextMiddleware`.
///
/// The inner service sits behind an `Rc` so that each request's boxed
/// `'static` future can hold its own handle to it.
pub struct TenantContextMiddlewareService<S> {
    /// The next service in the chain.
    service: Rc<S>,
}

impl<S> std::fmt::Debug for TenantContextMiddlewareService<S> {
    /// Omits the inner service, which is not required to implement `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("TenantContextMiddlewareService");
        builder.finish_non_exhaustive()
    }
}
83
84fn extract_jwt_token(req: &ServiceRequest) -> Option<String> {
90 if let Some(cookie_header) = req.headers().get("cookie") {
92 if let Ok(cookie_str) = cookie_header.to_str() {
93 for cookie_pair in cookie_str.split(';') {
95 let parts: Vec<&str> = cookie_pair.trim().splitn(2, '=').collect();
96 if parts.len() == 2 && parts[0] == "token" {
97 tracing::debug!("JWT token found in cookie");
98 return Some(parts[1].to_string());
99 }
100 }
101 }
102 }
103
104 if let Some(auth_header) = req.headers().get("authorization") {
106 if let Ok(auth_str) = auth_header.to_str() {
107 if let Some(token) = auth_str.strip_prefix("Bearer ") {
108 tracing::debug!("JWT token found in Authorization header");
109 return Some(token.to_string());
110 }
111 }
112 }
113
114 None
115}
116
117impl<S, B> Service<ServiceRequest> for TenantContextMiddlewareService<S>
118where
119 S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
120 S::Future: 'static,
121 B: 'static,
122{
123 type Response = ServiceResponse<B>;
124 type Error = Error;
125 type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;
126
127 forward_ready!(service);
128
129 fn call(&self, req: ServiceRequest) -> Self::Future {
130 let svc = self.service.clone();
131
132 Box::pin(async move {
133 let token = extract_jwt_token(&req);
135
136 if token.is_none() {
137 tracing::warn!("No JWT token found in request");
138 return Err(ErrorUnauthorized("Authentication required"));
139 }
140
141 let token = token.unwrap();
142
143 let app_config = req.app_data::<web::Data<CorteqApp>>().ok_or_else(|| {
145 tracing::error!("CorteqApp not found in app_data");
146 ErrorBadRequest("Application configuration error")
147 })?;
148
149 let claims = app_config.jwt_service.decode(&token).map_err(|e| {
151 tracing::warn!("JWT validation failed: {}", e);
152 ErrorUnauthorized("Invalid or expired token")
153 })?;
154
155 let tenant_id = claims.tenant_id;
156
157 let tenant_context = if let Some(ctx) = app_config.tenant_cache.get(&tenant_id).await {
159 tracing::debug!("Tenant {} found in cache", tenant_id);
160 ctx
161 } else {
162 tracing::debug!("Cache miss for tenant {}, loading from database", tenant_id);
164
165 let tenant = app_config
166 .tenant_repository
167 .find_by_id(&tenant_id)
168 .await
169 .map_err(|e| {
170 tracing::error!("Database error loading tenant: {}", e);
171 ErrorBadRequest("Failed to load tenant")
172 })?
173 .ok_or_else(|| {
174 tracing::warn!("Tenant {} not found or soft-deleted", tenant_id);
175 ErrorUnauthorized("Tenant not found")
176 })?;
177
178 let ctx = TenantContext::from(tenant);
180
181 app_config.tenant_cache.set(tenant_id, ctx.clone()).await;
183
184 ctx
185 };
186
187 req.extensions_mut().insert(tenant_context);
189
190 let res = svc.call(req).await?;
192 Ok(res)
193 })
194 }
195}
196
197