prax_query/tenant/
middleware.rs

1//! Tenant middleware for automatic query filtering.
2
3use super::config::TenantConfig;
4use super::context::TenantContext;
5use super::strategy::ColumnType;
6use crate::error::{QueryError, QueryResult};
7use crate::middleware::{BoxFuture, Middleware, Next, QueryContext, QueryResponse, QueryType};
8use std::sync::{Arc, RwLock};
9
10/// Middleware that automatically applies tenant filtering to queries.
11pub struct TenantMiddleware {
12    config: TenantConfig,
13    current_tenant: Arc<RwLock<Option<TenantContext>>>,
14}
15
16impl TenantMiddleware {
17    /// Create a new tenant middleware with the given config.
18    pub fn new(config: TenantConfig) -> Self {
19        Self {
20            config,
21            current_tenant: Arc::new(RwLock::new(None)),
22        }
23    }
24
25    /// Set the current tenant context.
26    pub fn set_tenant(&self, ctx: TenantContext) {
27        *self.current_tenant.write().expect("lock poisoned") = Some(ctx);
28    }
29
30    /// Clear the current tenant context.
31    pub fn clear_tenant(&self) {
32        *self.current_tenant.write().expect("lock poisoned") = None;
33    }
34
35    /// Get the current tenant context.
36    pub fn current_tenant(&self) -> Option<TenantContext> {
37        self.current_tenant.read().expect("lock poisoned").clone()
38    }
39
40    /// Create a scoped tenant context (automatically clears on drop).
41    pub fn scoped(&self, ctx: TenantContext) -> TenantScope {
42        self.set_tenant(ctx);
43        TenantScope {
44            middleware: Arc::new(self.clone()),
45        }
46    }
47
48    /// Apply row-level filtering to a SQL query.
49    fn apply_row_level_filter(&self, sql: &str, tenant_id: &str) -> (String, Vec<String>) {
50        let config = match self.config.row_level_config() {
51            Some(c) => c,
52            None => return (sql.to_string(), vec![]),
53        };
54
55        let column = &config.column;
56        let tenant_value = match config.column_type {
57            ColumnType::String => format!("'{}'", tenant_id.replace('\'', "''")),
58            ColumnType::Uuid => format!("'{}'::uuid", tenant_id),
59            ColumnType::Integer | ColumnType::BigInt => tenant_id.to_string(),
60        };
61
62        // Parse and modify SQL
63        let modified_sql = self.inject_tenant_filter(sql, column, &tenant_value);
64        (modified_sql, vec![tenant_id.to_string()])
65    }
66
67    /// Inject tenant filter into SQL.
68    fn inject_tenant_filter(&self, sql: &str, column: &str, value: &str) -> String {
69        let sql_upper = sql.to_uppercase();
70        let filter = format!("{} = {}", column, value);
71
72        // Handle SELECT queries
73        if sql_upper.starts_with("SELECT") {
74            if let Some(where_pos) = sql_upper.find("WHERE") {
75                // Insert after WHERE
76                let (before, after) = sql.split_at(where_pos + 5);
77                return format!("{} {} AND {}", before.trim(), filter, after.trim());
78            } else if let Some(order_pos) = sql_upper.find("ORDER BY") {
79                // Insert before ORDER BY
80                let (before, after) = sql.split_at(order_pos);
81                return format!("{} WHERE {} {}", before.trim(), filter, after);
82            } else if let Some(limit_pos) = sql_upper.find("LIMIT") {
83                // Insert before LIMIT
84                let (before, after) = sql.split_at(limit_pos);
85                return format!("{} WHERE {} {}", before.trim(), filter, after);
86            } else {
87                // Append WHERE clause
88                return format!("{} WHERE {}", sql.trim(), filter);
89            }
90        }
91
92        // Handle UPDATE queries
93        if sql_upper.starts_with("UPDATE") {
94            if let Some(where_pos) = sql_upper.find("WHERE") {
95                let (before, after) = sql.split_at(where_pos + 5);
96                return format!("{} {} AND {}", before.trim(), filter, after.trim());
97            } else if let Some(returning_pos) = sql_upper.find("RETURNING") {
98                let (before, after) = sql.split_at(returning_pos);
99                return format!("{} WHERE {} {}", before.trim(), filter, after);
100            } else {
101                return format!("{} WHERE {}", sql.trim(), filter);
102            }
103        }
104
105        // Handle DELETE queries
106        if sql_upper.starts_with("DELETE") {
107            if let Some(where_pos) = sql_upper.find("WHERE") {
108                let (before, after) = sql.split_at(where_pos + 5);
109                return format!("{} {} AND {}", before.trim(), filter, after.trim());
110            } else if let Some(returning_pos) = sql_upper.find("RETURNING") {
111                let (before, after) = sql.split_at(returning_pos);
112                return format!("{} WHERE {} {}", before.trim(), filter, after);
113            } else {
114                return format!("{} WHERE {}", sql.trim(), filter);
115            }
116        }
117
118        // Handle INSERT queries (add tenant_id column)
119        if sql_upper.starts_with("INSERT") && self.config.row_level_config().map_or(false, |c| c.auto_insert) {
120            // This is simplified - real implementation would parse the INSERT properly
121            // For now, we assume tenant_id is included in the data
122        }
123
124        sql.to_string()
125    }
126
127    /// Apply schema-based isolation.
128    fn apply_schema_isolation(&self, tenant_id: &str) -> Option<String> {
129        self.config
130            .schema_config()
131            .map(|c| c.search_path(tenant_id))
132    }
133}
134
135impl Clone for TenantMiddleware {
136    fn clone(&self) -> Self {
137        Self {
138            config: self.config.clone(),
139            current_tenant: Arc::clone(&self.current_tenant),
140        }
141    }
142}
143
144impl std::fmt::Debug for TenantMiddleware {
145    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
146        f.debug_struct("TenantMiddleware")
147            .field("config", &self.config)
148            .field("has_tenant", &self.current_tenant().is_some())
149            .finish()
150    }
151}
152
153impl Middleware for TenantMiddleware {
154    fn handle<'a>(
155        &'a self,
156        mut ctx: QueryContext,
157        next: Next<'a>,
158    ) -> BoxFuture<'a, QueryResult<QueryResponse>> {
159        Box::pin(async move {
160            // Get tenant context
161            let tenant_ctx = match self.current_tenant() {
162                Some(ctx) => ctx,
163                None => {
164                    // No tenant context
165                    if self.config.require_tenant {
166                        if let Some(default) = &self.config.default_tenant {
167                            TenantContext::new(default.clone())
168                        } else {
169                            return Err(QueryError::internal(
170                                "Tenant context required but not provided",
171                            ));
172                        }
173                    } else {
174                        // No tenant filtering
175                        return next.run(ctx).await;
176                    }
177                }
178            };
179
180            // Check for bypass
181            if self.config.allow_bypass && tenant_ctx.should_bypass() {
182                if self.config.log_tenant_context {
183                    tracing::debug!(
184                        tenant_id = %tenant_ctx.id,
185                        bypass = true,
186                        "Tenant filter bypassed"
187                    );
188                }
189                return next.run(ctx).await;
190            }
191
192            // Apply row-level filtering if configured
193            if self.config.strategy.is_row_level() {
194                let query_type = ctx.query_type();
195
196                // Validate writes
197                if self.config.enforce_on_writes
198                    && matches!(
199                        query_type,
200                        QueryType::Insert | QueryType::Update | QueryType::Delete
201                    )
202                {
203                    // For writes, we need to ensure tenant_id is included
204                }
205
206                // Apply filter to query
207                let (modified_sql, _extra_params) =
208                    self.apply_row_level_filter(ctx.sql(), tenant_ctx.id.as_str());
209
210                // Update context with modified SQL
211                ctx = ctx.with_sql(modified_sql);
212            }
213
214            // Apply schema-based isolation if configured
215            if self.config.strategy.is_schema_based() {
216                if let Some(search_path) = self.apply_schema_isolation(tenant_ctx.id.as_str()) {
217                    // The search_path should be set on the connection
218                    // This is typically done by the connection manager
219                    ctx.metadata_mut().set_schema_override(Some(
220                        self.config
221                            .schema_config()
222                            .unwrap()
223                            .schema_name(tenant_ctx.id.as_str()),
224                    ));
225
226                    // Log the schema setting
227                    if self.config.log_tenant_context {
228                        tracing::debug!(
229                            tenant_id = %tenant_ctx.id,
230                            search_path = %search_path,
231                            "Setting schema for tenant"
232                        );
233                    }
234                }
235            }
236
237            // Log tenant context
238            if self.config.log_tenant_context {
239                tracing::debug!(
240                    tenant_id = %tenant_ctx.id,
241                    strategy = ?self.config.strategy,
242                    sql = %ctx.sql(),
243                    "Executing query with tenant context"
244                );
245            }
246
247            // Set tenant in metadata for downstream middleware
248            ctx.metadata_mut().tenant_id = Some(tenant_ctx.id.to_string());
249
250            // Continue with modified query
251            next.run(ctx).await
252        })
253    }
254
255    fn name(&self) -> &'static str {
256        "TenantMiddleware"
257    }
258}
259
260/// A scoped tenant context that clears on drop.
261pub struct TenantScope {
262    middleware: Arc<TenantMiddleware>,
263}
264
265impl Drop for TenantScope {
266    fn drop(&mut self) {
267        self.middleware.clear_tenant();
268    }
269}
270
#[cfg(test)]
mod tests {
    use super::*;

    /// Tenant id used by every filter test.
    const TENANT: &str = "tenant-123";

    /// Build a row-level middleware keyed on the `tenant_id` column.
    fn row_level_middleware() -> TenantMiddleware {
        TenantMiddleware::new(TenantConfig::row_level("tenant_id"))
    }

    /// Run `sql` through the row-level filter and return the rewritten SQL.
    fn filtered(sql: &str) -> String {
        let (rewritten, _params) = row_level_middleware().apply_row_level_filter(sql, TENANT);
        rewritten
    }

    #[test]
    fn test_row_level_filter_select() {
        // Without a WHERE clause one is added.
        let bare = filtered("SELECT * FROM users");
        assert!(bare.contains("WHERE tenant_id = 'tenant-123'"));

        // With a WHERE clause the tenant predicate is put in front.
        let conditioned = filtered("SELECT * FROM users WHERE active = true");
        assert!(conditioned.contains("tenant_id = 'tenant-123' AND active = true"));
    }

    #[test]
    fn test_row_level_filter_update() {
        let bare = filtered("UPDATE users SET name = 'Bob'");
        assert!(bare.contains("WHERE tenant_id = 'tenant-123'"));

        let conditioned = filtered("UPDATE users SET name = 'Bob' WHERE id = 1");
        assert!(conditioned.contains("tenant_id = 'tenant-123' AND id = 1"));
    }

    #[test]
    fn test_row_level_filter_delete() {
        let bare = filtered("DELETE FROM users");
        assert!(bare.contains("WHERE tenant_id = 'tenant-123'"));
    }

    #[test]
    fn test_tenant_scope() {
        let middleware = row_level_middleware();

        {
            let _scope = middleware.scoped(TenantContext::new("tenant-123"));

            let active = middleware.current_tenant();
            assert!(active.is_some());
            assert_eq!(active.unwrap().id.as_str(), "tenant-123");
        }

        // Leaving the block dropped the scope, which clears the tenant.
        assert!(middleware.current_tenant().is_none());
    }
}
341
342