prax_query/tenant/
middleware.rs

1//! Tenant middleware for automatic query filtering.
2
3use super::config::TenantConfig;
4use super::context::TenantContext;
5use super::strategy::ColumnType;
6use crate::error::{QueryError, QueryResult};
7use crate::middleware::{BoxFuture, Middleware, Next, QueryContext, QueryResponse, QueryType};
8use std::sync::{Arc, RwLock};
9
10/// Middleware that automatically applies tenant filtering to queries.
11pub struct TenantMiddleware {
12    config: TenantConfig,
13    current_tenant: Arc<RwLock<Option<TenantContext>>>,
14}
15
16impl TenantMiddleware {
17    /// Create a new tenant middleware with the given config.
18    pub fn new(config: TenantConfig) -> Self {
19        Self {
20            config,
21            current_tenant: Arc::new(RwLock::new(None)),
22        }
23    }
24
25    /// Set the current tenant context.
26    pub fn set_tenant(&self, ctx: TenantContext) {
27        *self.current_tenant.write().expect("lock poisoned") = Some(ctx);
28    }
29
30    /// Clear the current tenant context.
31    pub fn clear_tenant(&self) {
32        *self.current_tenant.write().expect("lock poisoned") = None;
33    }
34
35    /// Get the current tenant context.
36    pub fn current_tenant(&self) -> Option<TenantContext> {
37        self.current_tenant.read().expect("lock poisoned").clone()
38    }
39
40    /// Create a scoped tenant context (automatically clears on drop).
41    pub fn scoped(&self, ctx: TenantContext) -> TenantScope {
42        self.set_tenant(ctx);
43        TenantScope {
44            middleware: Arc::new(self.clone()),
45        }
46    }
47
48    /// Apply row-level filtering to a SQL query.
49    fn apply_row_level_filter(&self, sql: &str, tenant_id: &str) -> (String, Vec<String>) {
50        let config = match self.config.row_level_config() {
51            Some(c) => c,
52            None => return (sql.to_string(), vec![]),
53        };
54
55        let column = &config.column;
56        let tenant_value = match config.column_type {
57            ColumnType::String => format!("'{}'", tenant_id.replace('\'', "''")),
58            ColumnType::Uuid => format!("'{}'::uuid", tenant_id),
59            ColumnType::Integer | ColumnType::BigInt => tenant_id.to_string(),
60        };
61
62        // Parse and modify SQL
63        let modified_sql = self.inject_tenant_filter(sql, column, &tenant_value);
64        (modified_sql, vec![tenant_id.to_string()])
65    }
66
67    /// Inject tenant filter into SQL.
68    fn inject_tenant_filter(&self, sql: &str, column: &str, value: &str) -> String {
69        let sql_upper = sql.to_uppercase();
70        let filter = format!("{} = {}", column, value);
71
72        // Handle SELECT queries
73        if sql_upper.starts_with("SELECT") {
74            if let Some(where_pos) = sql_upper.find("WHERE") {
75                // Insert after WHERE
76                let (before, after) = sql.split_at(where_pos + 5);
77                return format!("{} {} AND {}", before.trim(), filter, after.trim());
78            } else if let Some(order_pos) = sql_upper.find("ORDER BY") {
79                // Insert before ORDER BY
80                let (before, after) = sql.split_at(order_pos);
81                return format!("{} WHERE {} {}", before.trim(), filter, after);
82            } else if let Some(limit_pos) = sql_upper.find("LIMIT") {
83                // Insert before LIMIT
84                let (before, after) = sql.split_at(limit_pos);
85                return format!("{} WHERE {} {}", before.trim(), filter, after);
86            } else {
87                // Append WHERE clause
88                return format!("{} WHERE {}", sql.trim(), filter);
89            }
90        }
91
92        // Handle UPDATE queries
93        if sql_upper.starts_with("UPDATE") {
94            if let Some(where_pos) = sql_upper.find("WHERE") {
95                let (before, after) = sql.split_at(where_pos + 5);
96                return format!("{} {} AND {}", before.trim(), filter, after.trim());
97            } else if let Some(returning_pos) = sql_upper.find("RETURNING") {
98                let (before, after) = sql.split_at(returning_pos);
99                return format!("{} WHERE {} {}", before.trim(), filter, after);
100            } else {
101                return format!("{} WHERE {}", sql.trim(), filter);
102            }
103        }
104
105        // Handle DELETE queries
106        if sql_upper.starts_with("DELETE") {
107            if let Some(where_pos) = sql_upper.find("WHERE") {
108                let (before, after) = sql.split_at(where_pos + 5);
109                return format!("{} {} AND {}", before.trim(), filter, after.trim());
110            } else if let Some(returning_pos) = sql_upper.find("RETURNING") {
111                let (before, after) = sql.split_at(returning_pos);
112                return format!("{} WHERE {} {}", before.trim(), filter, after);
113            } else {
114                return format!("{} WHERE {}", sql.trim(), filter);
115            }
116        }
117
118        // Handle INSERT queries (add tenant_id column)
119        if sql_upper.starts_with("INSERT")
120            && self
121                .config
122                .row_level_config()
123                .is_some_and(|c| c.auto_insert)
124        {
125            // This is simplified - real implementation would parse the INSERT properly
126            // For now, we assume tenant_id is included in the data
127        }
128
129        sql.to_string()
130    }
131
132    /// Apply schema-based isolation.
133    fn apply_schema_isolation(&self, tenant_id: &str) -> Option<String> {
134        self.config
135            .schema_config()
136            .map(|c| c.search_path(tenant_id))
137    }
138}
139
140impl Clone for TenantMiddleware {
141    fn clone(&self) -> Self {
142        Self {
143            config: self.config.clone(),
144            current_tenant: Arc::clone(&self.current_tenant),
145        }
146    }
147}
148
149impl std::fmt::Debug for TenantMiddleware {
150    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
151        f.debug_struct("TenantMiddleware")
152            .field("config", &self.config)
153            .field("has_tenant", &self.current_tenant().is_some())
154            .finish()
155    }
156}
157
158impl Middleware for TenantMiddleware {
159    fn handle<'a>(
160        &'a self,
161        mut ctx: QueryContext,
162        next: Next<'a>,
163    ) -> BoxFuture<'a, QueryResult<QueryResponse>> {
164        Box::pin(async move {
165            // Get tenant context
166            let tenant_ctx = match self.current_tenant() {
167                Some(ctx) => ctx,
168                None => {
169                    // No tenant context
170                    if self.config.require_tenant {
171                        if let Some(default) = &self.config.default_tenant {
172                            TenantContext::new(default.clone())
173                        } else {
174                            return Err(QueryError::internal(
175                                "Tenant context required but not provided",
176                            ));
177                        }
178                    } else {
179                        // No tenant filtering
180                        return next.run(ctx).await;
181                    }
182                }
183            };
184
185            // Check for bypass
186            if self.config.allow_bypass && tenant_ctx.should_bypass() {
187                if self.config.log_tenant_context {
188                    tracing::debug!(
189                        tenant_id = %tenant_ctx.id,
190                        bypass = true,
191                        "Tenant filter bypassed"
192                    );
193                }
194                return next.run(ctx).await;
195            }
196
197            // Apply row-level filtering if configured
198            if self.config.strategy.is_row_level() {
199                let query_type = ctx.query_type();
200
201                // Validate writes
202                if self.config.enforce_on_writes
203                    && matches!(
204                        query_type,
205                        QueryType::Insert | QueryType::Update | QueryType::Delete
206                    )
207                {
208                    // For writes, we need to ensure tenant_id is included
209                }
210
211                // Apply filter to query
212                let (modified_sql, _extra_params) =
213                    self.apply_row_level_filter(ctx.sql(), tenant_ctx.id.as_str());
214
215                // Update context with modified SQL
216                ctx = ctx.with_sql(modified_sql);
217            }
218
219            // Apply schema-based isolation if configured
220            if self.config.strategy.is_schema_based() {
221                if let Some(search_path) = self.apply_schema_isolation(tenant_ctx.id.as_str()) {
222                    // The search_path should be set on the connection
223                    // This is typically done by the connection manager
224                    ctx.metadata_mut().set_schema_override(Some(
225                        self.config
226                            .schema_config()
227                            .unwrap()
228                            .schema_name(tenant_ctx.id.as_str()),
229                    ));
230
231                    // Log the schema setting
232                    if self.config.log_tenant_context {
233                        tracing::debug!(
234                            tenant_id = %tenant_ctx.id,
235                            search_path = %search_path,
236                            "Setting schema for tenant"
237                        );
238                    }
239                }
240            }
241
242            // Log tenant context
243            if self.config.log_tenant_context {
244                tracing::debug!(
245                    tenant_id = %tenant_ctx.id,
246                    strategy = ?self.config.strategy,
247                    sql = %ctx.sql(),
248                    "Executing query with tenant context"
249                );
250            }
251
252            // Set tenant in metadata for downstream middleware
253            ctx.metadata_mut().tenant_id = Some(tenant_ctx.id.to_string());
254
255            // Continue with modified query
256            next.run(ctx).await
257        })
258    }
259
260    fn name(&self) -> &'static str {
261        "TenantMiddleware"
262    }
263}
264
265/// A scoped tenant context that clears on drop.
266pub struct TenantScope {
267    middleware: Arc<TenantMiddleware>,
268}
269
270impl Drop for TenantScope {
271    fn drop(&mut self) {
272        self.middleware.clear_tenant();
273    }
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    /// Tenant id used by every test.
    const TENANT: &str = "tenant-123";
    /// Clause expected when the statement had no `WHERE` of its own.
    const TENANT_WHERE: &str = "WHERE tenant_id = 'tenant-123'";

    /// Builds a middleware that filters on the `tenant_id` column.
    fn row_level_middleware() -> TenantMiddleware {
        TenantMiddleware::new(TenantConfig::row_level("tenant_id"))
    }

    /// Runs the row-level rewrite for `TENANT` and returns only the SQL.
    fn filtered(sql: &str) -> String {
        row_level_middleware().apply_row_level_filter(sql, TENANT).0
    }

    #[test]
    fn test_row_level_filter_select() {
        assert!(filtered("SELECT * FROM users").contains(TENANT_WHERE));

        let with_where = filtered("SELECT * FROM users WHERE active = true");
        assert!(with_where.contains("tenant_id = 'tenant-123' AND active = true"));
    }

    #[test]
    fn test_row_level_filter_update() {
        assert!(filtered("UPDATE users SET name = 'Bob'").contains(TENANT_WHERE));

        let with_where = filtered("UPDATE users SET name = 'Bob' WHERE id = 1");
        assert!(with_where.contains("tenant_id = 'tenant-123' AND id = 1"));
    }

    #[test]
    fn test_row_level_filter_delete() {
        assert!(filtered("DELETE FROM users").contains(TENANT_WHERE));
    }

    #[test]
    fn test_tenant_scope() {
        let middleware = row_level_middleware();

        let scope = middleware.scoped(TenantContext::new("tenant-123"));
        let active = middleware.current_tenant();
        assert!(active.is_some());
        assert_eq!(active.unwrap().id.as_str(), "tenant-123");

        // Ending the scope clears the tenant.
        drop(scope);
        assert!(middleware.current_tenant().is_none());
    }
}