1use super::config::TenantConfig;
4use super::context::TenantContext;
5use super::strategy::ColumnType;
6use crate::error::{QueryError, QueryResult};
7use crate::middleware::{BoxFuture, Middleware, Next, QueryContext, QueryResponse, QueryType};
8use std::sync::{Arc, RwLock};
9
10pub struct TenantMiddleware {
12 config: TenantConfig,
13 current_tenant: Arc<RwLock<Option<TenantContext>>>,
14}
15
16impl TenantMiddleware {
17 pub fn new(config: TenantConfig) -> Self {
19 Self {
20 config,
21 current_tenant: Arc::new(RwLock::new(None)),
22 }
23 }
24
25 pub fn set_tenant(&self, ctx: TenantContext) {
27 *self.current_tenant.write().expect("lock poisoned") = Some(ctx);
28 }
29
30 pub fn clear_tenant(&self) {
32 *self.current_tenant.write().expect("lock poisoned") = None;
33 }
34
35 pub fn current_tenant(&self) -> Option<TenantContext> {
37 self.current_tenant.read().expect("lock poisoned").clone()
38 }
39
40 pub fn scoped(&self, ctx: TenantContext) -> TenantScope {
42 self.set_tenant(ctx);
43 TenantScope {
44 middleware: Arc::new(self.clone()),
45 }
46 }
47
48 fn apply_row_level_filter(&self, sql: &str, tenant_id: &str) -> (String, Vec<String>) {
50 let config = match self.config.row_level_config() {
51 Some(c) => c,
52 None => return (sql.to_string(), vec![]),
53 };
54
55 let column = &config.column;
56 let tenant_value = match config.column_type {
57 ColumnType::String => format!("'{}'", tenant_id.replace('\'', "''")),
58 ColumnType::Uuid => format!("'{}'::uuid", tenant_id),
59 ColumnType::Integer | ColumnType::BigInt => tenant_id.to_string(),
60 };
61
62 let modified_sql = self.inject_tenant_filter(sql, column, &tenant_value);
64 (modified_sql, vec![tenant_id.to_string()])
65 }
66
67 fn inject_tenant_filter(&self, sql: &str, column: &str, value: &str) -> String {
69 let sql_upper = sql.to_uppercase();
70 let filter = format!("{} = {}", column, value);
71
72 if sql_upper.starts_with("SELECT") {
74 if let Some(where_pos) = sql_upper.find("WHERE") {
75 let (before, after) = sql.split_at(where_pos + 5);
77 return format!("{} {} AND {}", before.trim(), filter, after.trim());
78 } else if let Some(order_pos) = sql_upper.find("ORDER BY") {
79 let (before, after) = sql.split_at(order_pos);
81 return format!("{} WHERE {} {}", before.trim(), filter, after);
82 } else if let Some(limit_pos) = sql_upper.find("LIMIT") {
83 let (before, after) = sql.split_at(limit_pos);
85 return format!("{} WHERE {} {}", before.trim(), filter, after);
86 } else {
87 return format!("{} WHERE {}", sql.trim(), filter);
89 }
90 }
91
92 if sql_upper.starts_with("UPDATE") {
94 if let Some(where_pos) = sql_upper.find("WHERE") {
95 let (before, after) = sql.split_at(where_pos + 5);
96 return format!("{} {} AND {}", before.trim(), filter, after.trim());
97 } else if let Some(returning_pos) = sql_upper.find("RETURNING") {
98 let (before, after) = sql.split_at(returning_pos);
99 return format!("{} WHERE {} {}", before.trim(), filter, after);
100 } else {
101 return format!("{} WHERE {}", sql.trim(), filter);
102 }
103 }
104
105 if sql_upper.starts_with("DELETE") {
107 if let Some(where_pos) = sql_upper.find("WHERE") {
108 let (before, after) = sql.split_at(where_pos + 5);
109 return format!("{} {} AND {}", before.trim(), filter, after.trim());
110 } else if let Some(returning_pos) = sql_upper.find("RETURNING") {
111 let (before, after) = sql.split_at(returning_pos);
112 return format!("{} WHERE {} {}", before.trim(), filter, after);
113 } else {
114 return format!("{} WHERE {}", sql.trim(), filter);
115 }
116 }
117
118 if sql_upper.starts_with("INSERT")
120 && self
121 .config
122 .row_level_config()
123 .is_some_and(|c| c.auto_insert)
124 {
125 }
128
129 sql.to_string()
130 }
131
132 fn apply_schema_isolation(&self, tenant_id: &str) -> Option<String> {
134 self.config
135 .schema_config()
136 .map(|c| c.search_path(tenant_id))
137 }
138}
139
140impl Clone for TenantMiddleware {
141 fn clone(&self) -> Self {
142 Self {
143 config: self.config.clone(),
144 current_tenant: Arc::clone(&self.current_tenant),
145 }
146 }
147}
148
149impl std::fmt::Debug for TenantMiddleware {
150 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
151 f.debug_struct("TenantMiddleware")
152 .field("config", &self.config)
153 .field("has_tenant", &self.current_tenant().is_some())
154 .finish()
155 }
156}
157
158impl Middleware for TenantMiddleware {
159 fn handle<'a>(
160 &'a self,
161 mut ctx: QueryContext,
162 next: Next<'a>,
163 ) -> BoxFuture<'a, QueryResult<QueryResponse>> {
164 Box::pin(async move {
165 let tenant_ctx = match self.current_tenant() {
167 Some(ctx) => ctx,
168 None => {
169 if self.config.require_tenant {
171 if let Some(default) = &self.config.default_tenant {
172 TenantContext::new(default.clone())
173 } else {
174 return Err(QueryError::internal(
175 "Tenant context required but not provided",
176 ));
177 }
178 } else {
179 return next.run(ctx).await;
181 }
182 }
183 };
184
185 if self.config.allow_bypass && tenant_ctx.should_bypass() {
187 if self.config.log_tenant_context {
188 tracing::debug!(
189 tenant_id = %tenant_ctx.id,
190 bypass = true,
191 "Tenant filter bypassed"
192 );
193 }
194 return next.run(ctx).await;
195 }
196
197 if self.config.strategy.is_row_level() {
199 let query_type = ctx.query_type();
200
201 if self.config.enforce_on_writes
203 && matches!(
204 query_type,
205 QueryType::Insert | QueryType::Update | QueryType::Delete
206 )
207 {
208 }
210
211 let (modified_sql, _extra_params) =
213 self.apply_row_level_filter(ctx.sql(), tenant_ctx.id.as_str());
214
215 ctx = ctx.with_sql(modified_sql);
217 }
218
219 if self.config.strategy.is_schema_based() {
221 if let Some(search_path) = self.apply_schema_isolation(tenant_ctx.id.as_str()) {
222 ctx.metadata_mut().set_schema_override(Some(
225 self.config
226 .schema_config()
227 .unwrap()
228 .schema_name(tenant_ctx.id.as_str()),
229 ));
230
231 if self.config.log_tenant_context {
233 tracing::debug!(
234 tenant_id = %tenant_ctx.id,
235 search_path = %search_path,
236 "Setting schema for tenant"
237 );
238 }
239 }
240 }
241
242 if self.config.log_tenant_context {
244 tracing::debug!(
245 tenant_id = %tenant_ctx.id,
246 strategy = ?self.config.strategy,
247 sql = %ctx.sql(),
248 "Executing query with tenant context"
249 );
250 }
251
252 ctx.metadata_mut().tenant_id = Some(tenant_ctx.id.to_string());
254
255 next.run(ctx).await
257 })
258 }
259
260 fn name(&self) -> &'static str {
261 "TenantMiddleware"
262 }
263}
264
265pub struct TenantScope {
267 middleware: Arc<TenantMiddleware>,
268}
269
270impl Drop for TenantScope {
271 fn drop(&mut self) {
272 self.middleware.clear_tenant();
273 }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a middleware that filters on a `tenant_id` column.
    fn row_level_middleware() -> TenantMiddleware {
        TenantMiddleware::new(TenantConfig::row_level("tenant_id"))
    }

    /// Runs the row-level rewrite for tenant `tenant-123` and returns only the SQL.
    fn rewrite(middleware: &TenantMiddleware, sql: &str) -> String {
        middleware.apply_row_level_filter(sql, "tenant-123").0
    }

    #[test]
    fn test_row_level_filter_select() {
        let middleware = row_level_middleware();

        let without_where = rewrite(&middleware, "SELECT * FROM users");
        assert!(without_where.contains("WHERE tenant_id = 'tenant-123'"));

        let with_where = rewrite(&middleware, "SELECT * FROM users WHERE active = true");
        assert!(with_where.contains("tenant_id = 'tenant-123' AND active = true"));
    }

    #[test]
    fn test_row_level_filter_update() {
        let middleware = row_level_middleware();

        let without_where = rewrite(&middleware, "UPDATE users SET name = 'Bob'");
        assert!(without_where.contains("WHERE tenant_id = 'tenant-123'"));

        let with_where = rewrite(&middleware, "UPDATE users SET name = 'Bob' WHERE id = 1");
        assert!(with_where.contains("tenant_id = 'tenant-123' AND id = 1"));
    }

    #[test]
    fn test_row_level_filter_delete() {
        let middleware = row_level_middleware();

        let rewritten = rewrite(&middleware, "DELETE FROM users");
        assert!(rewritten.contains("WHERE tenant_id = 'tenant-123'"));
    }

    #[test]
    fn test_tenant_scope() {
        let middleware = row_level_middleware();

        {
            let _scope = middleware.scoped(TenantContext::new("tenant-123"));
            let active = middleware.current_tenant();
            assert!(active.is_some());
            assert_eq!(active.unwrap().id.as_str(), "tenant-123");
        }

        assert!(middleware.current_tenant().is_none());
    }
}
333}