1use super::config::TenantConfig;
4use super::context::TenantContext;
5use super::strategy::ColumnType;
6use crate::error::{QueryError, QueryResult};
7use crate::middleware::{BoxFuture, Middleware, Next, QueryContext, QueryResponse, QueryType};
8use std::sync::{Arc, RwLock};
9
/// Query middleware that applies tenant isolation to every query it handles: row-level SQL
/// filtering and/or a per-tenant schema override, depending on `config.strategy`.
pub struct TenantMiddleware {
    /// Tenant isolation settings (strategy, required/default tenant, bypass and logging flags).
    config: TenantConfig,
    /// Tenant applied to queries. The `Arc` is shared with every clone of this middleware,
    /// so setting or clearing it through any clone affects all of them.
    current_tenant: Arc<RwLock<Option<TenantContext>>>,
}
15
16impl TenantMiddleware {
17 pub fn new(config: TenantConfig) -> Self {
19 Self {
20 config,
21 current_tenant: Arc::new(RwLock::new(None)),
22 }
23 }
24
25 pub fn set_tenant(&self, ctx: TenantContext) {
27 *self.current_tenant.write().expect("lock poisoned") = Some(ctx);
28 }
29
30 pub fn clear_tenant(&self) {
32 *self.current_tenant.write().expect("lock poisoned") = None;
33 }
34
35 pub fn current_tenant(&self) -> Option<TenantContext> {
37 self.current_tenant.read().expect("lock poisoned").clone()
38 }
39
40 pub fn scoped(&self, ctx: TenantContext) -> TenantScope {
42 self.set_tenant(ctx);
43 TenantScope {
44 middleware: Arc::new(self.clone()),
45 }
46 }
47
48 fn apply_row_level_filter(&self, sql: &str, tenant_id: &str) -> (String, Vec<String>) {
50 let config = match self.config.row_level_config() {
51 Some(c) => c,
52 None => return (sql.to_string(), vec![]),
53 };
54
55 let column = &config.column;
56 let tenant_value = match config.column_type {
57 ColumnType::String => format!("'{}'", tenant_id.replace('\'', "''")),
58 ColumnType::Uuid => format!("'{}'::uuid", tenant_id),
59 ColumnType::Integer | ColumnType::BigInt => tenant_id.to_string(),
60 };
61
62 let modified_sql = self.inject_tenant_filter(sql, column, &tenant_value);
64 (modified_sql, vec![tenant_id.to_string()])
65 }
66
67 fn inject_tenant_filter(&self, sql: &str, column: &str, value: &str) -> String {
69 let sql_upper = sql.to_uppercase();
70 let filter = format!("{} = {}", column, value);
71
72 if sql_upper.starts_with("SELECT") {
74 if let Some(where_pos) = sql_upper.find("WHERE") {
75 let (before, after) = sql.split_at(where_pos + 5);
77 return format!("{} {} AND {}", before.trim(), filter, after.trim());
78 } else if let Some(order_pos) = sql_upper.find("ORDER BY") {
79 let (before, after) = sql.split_at(order_pos);
81 return format!("{} WHERE {} {}", before.trim(), filter, after);
82 } else if let Some(limit_pos) = sql_upper.find("LIMIT") {
83 let (before, after) = sql.split_at(limit_pos);
85 return format!("{} WHERE {} {}", before.trim(), filter, after);
86 } else {
87 return format!("{} WHERE {}", sql.trim(), filter);
89 }
90 }
91
92 if sql_upper.starts_with("UPDATE") {
94 if let Some(where_pos) = sql_upper.find("WHERE") {
95 let (before, after) = sql.split_at(where_pos + 5);
96 return format!("{} {} AND {}", before.trim(), filter, after.trim());
97 } else if let Some(returning_pos) = sql_upper.find("RETURNING") {
98 let (before, after) = sql.split_at(returning_pos);
99 return format!("{} WHERE {} {}", before.trim(), filter, after);
100 } else {
101 return format!("{} WHERE {}", sql.trim(), filter);
102 }
103 }
104
105 if sql_upper.starts_with("DELETE") {
107 if let Some(where_pos) = sql_upper.find("WHERE") {
108 let (before, after) = sql.split_at(where_pos + 5);
109 return format!("{} {} AND {}", before.trim(), filter, after.trim());
110 } else if let Some(returning_pos) = sql_upper.find("RETURNING") {
111 let (before, after) = sql.split_at(returning_pos);
112 return format!("{} WHERE {} {}", before.trim(), filter, after);
113 } else {
114 return format!("{} WHERE {}", sql.trim(), filter);
115 }
116 }
117
118 if sql_upper.starts_with("INSERT") && self.config.row_level_config().map_or(false, |c| c.auto_insert) {
120 }
123
124 sql.to_string()
125 }
126
127 fn apply_schema_isolation(&self, tenant_id: &str) -> Option<String> {
129 self.config
130 .schema_config()
131 .map(|c| c.search_path(tenant_id))
132 }
133}
134
135impl Clone for TenantMiddleware {
136 fn clone(&self) -> Self {
137 Self {
138 config: self.config.clone(),
139 current_tenant: Arc::clone(&self.current_tenant),
140 }
141 }
142}
143
144impl std::fmt::Debug for TenantMiddleware {
145 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146 f.debug_struct("TenantMiddleware")
147 .field("config", &self.config)
148 .field("has_tenant", &self.current_tenant().is_some())
149 .finish()
150 }
151}
152
153impl Middleware for TenantMiddleware {
154 fn handle<'a>(
155 &'a self,
156 mut ctx: QueryContext,
157 next: Next<'a>,
158 ) -> BoxFuture<'a, QueryResult<QueryResponse>> {
159 Box::pin(async move {
160 let tenant_ctx = match self.current_tenant() {
162 Some(ctx) => ctx,
163 None => {
164 if self.config.require_tenant {
166 if let Some(default) = &self.config.default_tenant {
167 TenantContext::new(default.clone())
168 } else {
169 return Err(QueryError::internal(
170 "Tenant context required but not provided",
171 ));
172 }
173 } else {
174 return next.run(ctx).await;
176 }
177 }
178 };
179
180 if self.config.allow_bypass && tenant_ctx.should_bypass() {
182 if self.config.log_tenant_context {
183 tracing::debug!(
184 tenant_id = %tenant_ctx.id,
185 bypass = true,
186 "Tenant filter bypassed"
187 );
188 }
189 return next.run(ctx).await;
190 }
191
192 if self.config.strategy.is_row_level() {
194 let query_type = ctx.query_type();
195
196 if self.config.enforce_on_writes
198 && matches!(
199 query_type,
200 QueryType::Insert | QueryType::Update | QueryType::Delete
201 )
202 {
203 }
205
206 let (modified_sql, _extra_params) =
208 self.apply_row_level_filter(ctx.sql(), tenant_ctx.id.as_str());
209
210 ctx = ctx.with_sql(modified_sql);
212 }
213
214 if self.config.strategy.is_schema_based() {
216 if let Some(search_path) = self.apply_schema_isolation(tenant_ctx.id.as_str()) {
217 ctx.metadata_mut().set_schema_override(Some(
220 self.config
221 .schema_config()
222 .unwrap()
223 .schema_name(tenant_ctx.id.as_str()),
224 ));
225
226 if self.config.log_tenant_context {
228 tracing::debug!(
229 tenant_id = %tenant_ctx.id,
230 search_path = %search_path,
231 "Setting schema for tenant"
232 );
233 }
234 }
235 }
236
237 if self.config.log_tenant_context {
239 tracing::debug!(
240 tenant_id = %tenant_ctx.id,
241 strategy = ?self.config.strategy,
242 sql = %ctx.sql(),
243 "Executing query with tenant context"
244 );
245 }
246
247 ctx.metadata_mut().tenant_id = Some(tenant_ctx.id.to_string());
249
250 next.run(ctx).await
252 })
253 }
254
255 fn name(&self) -> &'static str {
256 "TenantMiddleware"
257 }
258}
259
260pub struct TenantScope {
262 middleware: Arc<TenantMiddleware>,
263}
264
265impl Drop for TenantScope {
266 fn drop(&mut self) {
267 self.middleware.clear_tenant();
268 }
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the row-level rewrite for `tenant-123` against a `tenant_id` column config and
    /// returns only the rewritten SQL.
    fn filtered(sql: &str) -> String {
        let middleware = TenantMiddleware::new(TenantConfig::row_level("tenant_id"));
        middleware.apply_row_level_filter(sql, "tenant-123").0
    }

    #[test]
    fn test_row_level_filter_select() {
        let without_where = filtered("SELECT * FROM users");
        assert!(without_where.contains("WHERE tenant_id = 'tenant-123'"));

        let with_where = filtered("SELECT * FROM users WHERE active = true");
        assert!(with_where.contains("tenant_id = 'tenant-123' AND active = true"));
    }

    #[test]
    fn test_row_level_filter_update() {
        let without_where = filtered("UPDATE users SET name = 'Bob'");
        assert!(without_where.contains("WHERE tenant_id = 'tenant-123'"));

        let with_where = filtered("UPDATE users SET name = 'Bob' WHERE id = 1");
        assert!(with_where.contains("tenant_id = 'tenant-123' AND id = 1"));
    }

    #[test]
    fn test_row_level_filter_delete() {
        let rewritten = filtered("DELETE FROM users");
        assert!(rewritten.contains("WHERE tenant_id = 'tenant-123'"));
    }

    #[test]
    fn test_tenant_scope() {
        let middleware = TenantMiddleware::new(TenantConfig::row_level("tenant_id"));

        {
            let _scope = middleware.scoped(TenantContext::new("tenant-123"));
            let tenant = middleware
                .current_tenant()
                .expect("tenant should be set inside the scope");
            assert_eq!(tenant.id.as_str(), "tenant-123");
        }

        // Leaving the block dropped the scope, which clears the tenant.
        assert!(middleware.current_tenant().is_none());
    }
}
341
342