Skip to main content

heliosdb_proxy/multi_tenancy/
transformer.rs

1//! Tenant Query Transformer
2//!
3//! This module transforms SQL queries to apply tenant isolation, primarily
4//! for row-level security where queries need WHERE clause injection.
5
6use std::collections::{HashMap, HashSet};
7
8use super::config::{IsolationStrategy, TenantConfig, TenantId};
9
/// Outcome of running a query through the tenant query transformer.
///
/// When `transformed` is `false`, `query` holds the caller's text unchanged.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TransformResult {
    /// Query to forward to the backend (rewritten, or the original text).
    pub query: String,

    /// Whether a tenant filter or tenant value was injected.
    pub transformed: bool,

    /// Tables that received a tenant filter or tenant value.
    pub filtered_tables: Vec<String>,

    /// Non-fatal notes produced while transforming.
    pub warnings: Vec<String>,
}
25
26impl TransformResult {
27    /// Create a passthrough result (no transformation)
28    pub fn passthrough(query: impl Into<String>) -> Self {
29        Self {
30            query: query.into(),
31            transformed: false,
32            filtered_tables: Vec::new(),
33            warnings: Vec::new(),
34        }
35    }
36
37    /// Create a transformed result
38    pub fn transformed(query: impl Into<String>, tables: Vec<String>) -> Self {
39        Self {
40            query: query.into(),
41            transformed: true,
42            filtered_tables: tables,
43            warnings: Vec::new(),
44        }
45    }
46
47    /// Add a warning
48    pub fn with_warning(mut self, warning: impl Into<String>) -> Self {
49        self.warnings.push(warning.into());
50        self
51    }
52}
53
/// Query transformer for tenant isolation.
///
/// Configured through a consuming builder API (`new`, `register_table`,
/// `exclude_table`, `with_parameters`, `with_filter_template`).
#[derive(Debug, Clone)]
pub struct TenantQueryTransformer {
    /// Lowercased table name -> name of that table's tenant column.
    tenant_tables: HashMap<String, String>,

    /// Lowercased names of tables that are never filtered, even if registered.
    excluded_tables: HashSet<String>,

    /// When true, filters compare the tenant column against the `$1`
    /// placeholder instead of an inline tenant literal.
    use_parameters: bool,

    /// Optional custom filter template with `{column}` and `{value}`
    /// placeholders; when unset, filters take the form `{column} = '{value}'`.
    filter_template: Option<String>,
}
68
69impl Default for TenantQueryTransformer {
70    fn default() -> Self {
71        Self::new()
72    }
73}
74
75impl TenantQueryTransformer {
76    /// Create a new query transformer
77    pub fn new() -> Self {
78        Self {
79            tenant_tables: HashMap::new(),
80            excluded_tables: HashSet::new(),
81            use_parameters: false,
82            filter_template: None,
83        }
84    }
85
86    /// Register a table with its tenant column
87    pub fn register_table(mut self, table: impl Into<String>, column: impl Into<String>) -> Self {
88        self.tenant_tables
89            .insert(table.into().to_lowercase(), column.into());
90        self
91    }
92
93    /// Register multiple tables with the same column
94    pub fn register_tables(mut self, tables: &[&str], column: impl Into<String>) -> Self {
95        let col = column.into();
96        for table in tables {
97            self.tenant_tables
98                .insert(table.to_lowercase(), col.clone());
99        }
100        self
101    }
102
103    /// Exclude a table from filtering
104    pub fn exclude_table(mut self, table: impl Into<String>) -> Self {
105        self.excluded_tables.insert(table.into().to_lowercase());
106        self
107    }
108
109    /// Use parameterized queries
110    pub fn with_parameters(mut self) -> Self {
111        self.use_parameters = true;
112        self
113    }
114
115    /// Set custom filter template
116    pub fn with_filter_template(mut self, template: impl Into<String>) -> Self {
117        self.filter_template = Some(template.into());
118        self
119    }
120
121    /// Get the tenant column for a table
122    pub fn get_tenant_column(&self, table: &str) -> Option<&str> {
123        self.tenant_tables
124            .get(&table.to_lowercase())
125            .map(|s| s.as_str())
126    }
127
128    /// Check if a table requires filtering
129    pub fn requires_filtering(&self, table: &str) -> bool {
130        let lower = table.to_lowercase();
131        self.tenant_tables.contains_key(&lower) && !self.excluded_tables.contains(&lower)
132    }
133
134    /// Transform a query for a tenant
135    pub fn transform(
136        &self,
137        query: &str,
138        tenant: &TenantId,
139        config: &TenantConfig,
140    ) -> TransformResult {
141        // Only transform for row-level isolation
142        let tenant_column = match &config.isolation {
143            IsolationStrategy::Row { tenant_column, .. } => tenant_column,
144            _ => return TransformResult::passthrough(query),
145        };
146
147        // Parse and transform query
148        let upper = query.trim().to_uppercase();
149
150        if upper.starts_with("SELECT") {
151            self.transform_select(query, tenant, tenant_column)
152        } else if upper.starts_with("UPDATE") {
153            self.transform_update(query, tenant, tenant_column)
154        } else if upper.starts_with("DELETE") {
155            self.transform_delete(query, tenant, tenant_column)
156        } else if upper.starts_with("INSERT") {
157            self.transform_insert(query, tenant, tenant_column)
158        } else {
159            TransformResult::passthrough(query)
160        }
161    }
162
163    /// Transform a SELECT query
164    fn transform_select(
165        &self,
166        query: &str,
167        tenant: &TenantId,
168        tenant_column: &str,
169    ) -> TransformResult {
170        let tables = self.extract_tables(query);
171        let filtered_tables: Vec<String> = tables
172            .iter()
173            .filter(|t| self.requires_filtering(t))
174            .cloned()
175            .collect();
176
177        if filtered_tables.is_empty() {
178            return TransformResult::passthrough(query);
179        }
180
181        let filter = self.build_filter(tenant, tenant_column, &filtered_tables);
182        let transformed = self.inject_where_clause(query, &filter);
183
184        TransformResult::transformed(transformed, filtered_tables)
185    }
186
187    /// Transform an UPDATE query
188    fn transform_update(
189        &self,
190        query: &str,
191        tenant: &TenantId,
192        tenant_column: &str,
193    ) -> TransformResult {
194        let table = self.extract_update_table(query);
195
196        if let Some(table) = table {
197            if self.requires_filtering(&table) {
198                let filter = self.build_single_filter(tenant, tenant_column);
199                let transformed = self.inject_where_clause(query, &filter);
200                return TransformResult::transformed(transformed, vec![table]);
201            }
202        }
203
204        TransformResult::passthrough(query)
205    }
206
207    /// Transform a DELETE query
208    fn transform_delete(
209        &self,
210        query: &str,
211        tenant: &TenantId,
212        tenant_column: &str,
213    ) -> TransformResult {
214        let table = self.extract_delete_table(query);
215
216        if let Some(table) = table {
217            if self.requires_filtering(&table) {
218                let filter = self.build_single_filter(tenant, tenant_column);
219                let transformed = self.inject_where_clause(query, &filter);
220                return TransformResult::transformed(transformed, vec![table]);
221            }
222        }
223
224        TransformResult::passthrough(query)
225    }
226
227    /// Transform an INSERT query (add tenant_id column)
228    fn transform_insert(
229        &self,
230        query: &str,
231        tenant: &TenantId,
232        tenant_column: &str,
233    ) -> TransformResult {
234        let table = self.extract_insert_table(query);
235
236        if let Some(table) = table {
237            if self.requires_filtering(&table) {
238                // For INSERT, we need to add tenant_id to the values
239                let transformed = self.inject_tenant_value(query, tenant, tenant_column);
240                return TransformResult::transformed(transformed, vec![table])
241                    .with_warning("Tenant column injection may require schema awareness");
242            }
243        }
244
245        TransformResult::passthrough(query)
246    }
247
248    /// Build a filter clause for multiple tables
249    fn build_filter(
250        &self,
251        tenant: &TenantId,
252        default_column: &str,
253        tables: &[String],
254    ) -> String {
255        let filters: Vec<String> = tables
256            .iter()
257            .map(|table| {
258                let column = self
259                    .get_tenant_column(table)
260                    .unwrap_or(default_column);
261                if self.use_parameters {
262                    format!("{}.{} = $1", table, column)
263                } else {
264                    format!("{}.{} = '{}'", table, column, tenant.0)
265                }
266            })
267            .collect();
268
269        filters.join(" AND ")
270    }
271
272    /// Build a single filter clause
273    fn build_single_filter(&self, tenant: &TenantId, column: &str) -> String {
274        if self.use_parameters {
275            format!("{} = $1", column)
276        } else {
277            match &self.filter_template {
278                Some(template) => template
279                    .replace("{column}", column)
280                    .replace("{value}", &tenant.0),
281                None => format!("{} = '{}'", column, tenant.0),
282            }
283        }
284    }
285
286    /// Inject WHERE clause into query
287    fn inject_where_clause(&self, query: &str, filter: &str) -> String {
288        let upper = query.to_uppercase();
289
290        // Find existing WHERE clause
291        if let Some(where_pos) = upper.find(" WHERE ") {
292            // Add to existing WHERE
293            let (before, after) = query.split_at(where_pos + 7);
294            format!("{}{} AND {}", before, filter, after)
295        } else {
296            // Find position to insert WHERE
297            // Look for ORDER BY, GROUP BY, LIMIT, etc.
298            let insert_before = [" ORDER ", " GROUP ", " LIMIT ", " HAVING ", " UNION "]
299                .iter()
300                .filter_map(|kw| upper.find(kw))
301                .min();
302
303            match insert_before {
304                Some(pos) => {
305                    let (before, after) = query.split_at(pos);
306                    format!("{} WHERE {}{}", before, filter, after)
307                }
308                None => {
309                    // Append at end
310                    format!("{} WHERE {}", query.trim_end_matches(';'), filter)
311                }
312            }
313        }
314    }
315
316    /// Inject tenant value into INSERT statement
317    fn inject_tenant_value(&self, query: &str, tenant: &TenantId, column: &str) -> String {
318        // Simple implementation - a real one would parse SQL properly
319        let upper = query.to_uppercase();
320
321        if let Some(values_pos) = upper.find(" VALUES ") {
322            if let Some(paren_pos) = query[values_pos..].find('(') {
323                let insert_pos = values_pos + paren_pos + 1;
324
325                // Check if there's a column list
326                if let Some(cols_start) = upper.find('(') {
327                    if cols_start < values_pos {
328                        // There's a column list - add column to it
329                        let cols_end = upper[cols_start..].find(')').unwrap_or(0) + cols_start;
330                        let before_cols_end = &query[..cols_end];
331                        let after_cols_end = &query[cols_end..];
332
333                        // Insert column name
334                        let with_column =
335                            format!("{}, {}{}", before_cols_end, column, after_cols_end);
336
337                        // Now insert the value
338                        let upper_new = with_column.to_uppercase();
339                        if let Some(new_values_pos) = upper_new.find(" VALUES ") {
340                            if let Some(new_paren_pos) = with_column[new_values_pos..].find('(') {
341                                let new_insert_pos = new_values_pos + new_paren_pos + 1;
342                                let before = &with_column[..new_insert_pos];
343                                let after = &with_column[new_insert_pos..];
344                                return format!("{}'{}'", before, tenant.0)
345                                    + if !after.starts_with(')') { ", " } else { "" }
346                                    + after;
347                            }
348                        }
349                    }
350                }
351
352                // No column list or couldn't parse - just add to values
353                let before = &query[..insert_pos];
354                let after = &query[insert_pos..];
355                return format!("{}'{}'", before, tenant.0)
356                    + if !after.starts_with(')') { ", " } else { "" }
357                    + after;
358            }
359        }
360
361        query.to_string()
362    }
363
364    /// Extract table names from SELECT query
365    fn extract_tables(&self, query: &str) -> Vec<String> {
366        let upper = query.to_uppercase();
367        let mut tables = Vec::new();
368
369        // Find FROM clause
370        if let Some(from_pos) = upper.find(" FROM ") {
371            let after_from = &query[from_pos + 6..];
372
373            // Find end of table list
374            let end_markers = [" WHERE ", " JOIN ", " LEFT ", " RIGHT ", " INNER ", " OUTER ",
375                " GROUP ", " ORDER ", " LIMIT ", " HAVING "];
376            let end_pos = end_markers
377                .iter()
378                .filter_map(|m| after_from.to_uppercase().find(m))
379                .min()
380                .unwrap_or(after_from.len());
381
382            let table_section = &after_from[..end_pos];
383
384            // Parse table names (handle aliases)
385            for part in table_section.split(',') {
386                let trimmed = part.trim();
387                if let Some(table) = trimmed.split_whitespace().next() {
388                    let clean = table
389                        .trim_matches(|c| c == '"' || c == '`' || c == '[' || c == ']');
390                    if !clean.is_empty() {
391                        tables.push(clean.to_string());
392                    }
393                }
394            }
395        }
396
397        // Also look for JOINs
398        let words: Vec<&str> = query.split_whitespace().collect();
399        for (i, word) in words.iter().enumerate() {
400            if word.to_uppercase() == "JOIN" && i + 1 < words.len() {
401                let table = words[i + 1]
402                    .trim_matches(|c| c == '"' || c == '`' || c == '[' || c == ']');
403                if !table.is_empty() && !tables.contains(&table.to_string()) {
404                    tables.push(table.to_string());
405                }
406            }
407        }
408
409        tables
410    }
411
412    /// Extract table name from UPDATE query
413    fn extract_update_table(&self, query: &str) -> Option<String> {
414        let upper = query.to_uppercase();
415        if let Some(update_pos) = upper.find("UPDATE ") {
416            let after_update = &query[update_pos + 7..];
417            if let Some(set_pos) = after_update.to_uppercase().find(" SET ") {
418                let table_section = &after_update[..set_pos];
419                let table = table_section
420                    .trim()
421                    .split_whitespace()
422                    .next()?
423                    .trim_matches(|c| c == '"' || c == '`');
424                return Some(table.to_string());
425            }
426        }
427        None
428    }
429
430    /// Extract table name from DELETE query
431    fn extract_delete_table(&self, query: &str) -> Option<String> {
432        let upper = query.to_uppercase();
433        if let Some(from_pos) = upper.find(" FROM ") {
434            let after_from = &query[from_pos + 6..];
435            let end_pos = after_from.to_uppercase().find(" WHERE ")
436                .unwrap_or(after_from.len());
437            let table_section = &after_from[..end_pos];
438            let table = table_section
439                .trim()
440                .split_whitespace()
441                .next()?
442                .trim_matches(|c| c == '"' || c == '`');
443            return Some(table.to_string());
444        }
445        None
446    }
447
448    /// Extract table name from INSERT query
449    fn extract_insert_table(&self, query: &str) -> Option<String> {
450        let upper = query.to_uppercase();
451        if let Some(into_pos) = upper.find(" INTO ") {
452            let after_into = &query[into_pos + 6..];
453            let end_pos = after_into.find(|c: char| c == '(' || c.is_whitespace())
454                .unwrap_or(after_into.len());
455            let table = after_into[..end_pos]
456                .trim()
457                .trim_matches(|c| c == '"' || c == '`');
458            return Some(table.to_string());
459        }
460        None
461    }
462
463    /// Generate SET search_path command for schema isolation
464    pub fn set_schema_search_path(
465        &self,
466        _tenant: &TenantId,
467        config: &TenantConfig,
468    ) -> Option<String> {
469        if let IsolationStrategy::Schema { schema_name, .. } = &config.isolation {
470            Some(format!("SET search_path TO {}", schema_name))
471        } else {
472            None
473        }
474    }
475
476    /// Generate USE database command for database isolation
477    pub fn use_database(&self, _tenant: &TenantId, config: &TenantConfig) -> Option<String> {
478        if let IsolationStrategy::Database { database_name } = &config.isolation {
479            Some(format!("USE {}", database_name))
480        } else {
481            None
482        }
483    }
484}
485
486/// Validate that a query doesn't try to bypass tenant isolation
487pub fn validate_query(query: &str, _tenant: &TenantId, config: &TenantConfig) -> QueryValidation {
488    let mut validation = QueryValidation {
489        valid: true,
490        violations: Vec::new(),
491    };
492
493    let upper = query.to_uppercase();
494
495    // Check for dangerous operations
496    if let IsolationStrategy::Row { tenant_column, .. } = &config.isolation {
497        // Check if query tries to modify tenant column
498        if upper.contains(&format!("{} =", tenant_column.to_uppercase())) {
499            let set_pattern = format!("SET {} =", tenant_column.to_uppercase());
500            if upper.contains(&set_pattern) {
501                validation.valid = false;
502                validation
503                    .violations
504                    .push(format!("Cannot modify tenant column: {}", tenant_column));
505            }
506        }
507
508        // Check for TRUNCATE (bypasses row-level security)
509        if upper.starts_with("TRUNCATE ") {
510            validation.valid = false;
511            validation
512                .violations
513                .push("TRUNCATE not allowed with row-level isolation".to_string());
514        }
515
516        // Check for DROP TABLE
517        if upper.contains("DROP TABLE") {
518            validation.valid = false;
519            validation
520                .violations
521                .push("DROP TABLE not allowed with row-level isolation".to_string());
522        }
523    }
524
525    // Check for cross-schema access in schema isolation
526    if let IsolationStrategy::Schema { schema_name, .. } = &config.isolation {
527        // Look for schema.table patterns that don't match tenant's schema
528        let parts: Vec<&str> = upper.split_whitespace().collect();
529        for part in parts {
530            if part.contains('.') && !part.starts_with(&schema_name.to_uppercase()) {
531                let schema = part.split('.').next().unwrap_or("");
532                if !schema.eq_ignore_ascii_case("pg_catalog")
533                    && !schema.eq_ignore_ascii_case("information_schema")
534                {
535                    validation.valid = false;
536                    validation.violations.push(format!(
537                        "Cross-schema access not allowed: {}",
538                        part
539                    ));
540                }
541            }
542        }
543    }
544
545    validation
546}
547
/// Result of [`validate_query`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QueryValidation {
    /// `false` whenever at least one violation was recorded.
    pub valid: bool,

    /// Human-readable description of each violation, in detection order.
    pub violations: Vec<String>,
}
557
#[cfg(test)]
mod tests {
    use super::*;

    /// Row-isolated config for `tenant_id` on the shared `shared_db` database.
    fn create_row_config(tenant_id: &str) -> TenantConfig {
        TenantConfig::builder()
            .id(tenant_id)
            .name("Test")
            .row_isolation("shared_db", "tenant_id")
            .build()
    }

    /// Schema-isolated config for tenant `acme` living in `schema`.
    fn acme_schema_config(schema: &str) -> TenantConfig {
        TenantConfig::builder()
            .id("acme")
            .name("Acme")
            .schema_isolation("shared", schema)
            .build()
    }

    /// Transformer that only knows about the `users` table.
    fn users_transformer() -> TenantQueryTransformer {
        TenantQueryTransformer::new().register_table("users", "tenant_id")
    }

    /// Runs `sql` through `transformer` as tenant `acme` under row isolation.
    fn transform_for_acme(transformer: &TenantQueryTransformer, sql: &str) -> TransformResult {
        let tenant = TenantId::new("acme");
        let config = create_row_config("acme");
        transformer.transform(sql, &tenant, &config)
    }

    #[test]
    fn test_transform_select() {
        let transformer = users_transformer().register_table("orders", "tenant_id");
        let result = transform_for_acme(&transformer, "SELECT * FROM users WHERE active = true");

        assert!(result.transformed);
        assert!(result.query.contains("tenant_id = 'acme'"));
        assert!(result.query.contains("AND active = true"));
    }

    #[test]
    fn test_transform_select_no_where() {
        let result = transform_for_acme(&users_transformer(), "SELECT * FROM users ORDER BY id");

        assert!(result.transformed);
        assert!(result.query.contains("WHERE users.tenant_id = 'acme'"));
        assert!(result.query.contains("ORDER BY id"));
    }

    #[test]
    fn test_transform_update() {
        let result = transform_for_acme(
            &users_transformer(),
            "UPDATE users SET name = 'John' WHERE id = 1",
        );

        assert!(result.transformed);
        assert!(result.query.contains("tenant_id = 'acme'"));
    }

    #[test]
    fn test_transform_delete() {
        let result = transform_for_acme(&users_transformer(), "DELETE FROM users WHERE id = 1");

        assert!(result.transformed);
        assert!(result.query.contains("tenant_id = 'acme'"));
    }

    #[test]
    fn test_no_transform_for_unregistered_table() {
        let result = transform_for_acme(
            &users_transformer(),
            "SELECT * FROM logs WHERE level = 'error'",
        );

        assert!(!result.transformed);
    }

    #[test]
    fn test_no_transform_for_schema_isolation() {
        let result = users_transformer().transform(
            "SELECT * FROM users",
            &TenantId::new("acme"),
            &acme_schema_config("acme"),
        );

        assert!(!result.transformed);
    }

    #[test]
    fn test_excluded_tables() {
        let transformer = users_transformer()
            .register_table("audit_log", "tenant_id")
            .exclude_table("audit_log");

        assert!(!transform_for_acme(&transformer, "SELECT * FROM audit_log").transformed);
    }

    #[test]
    fn test_extract_tables() {
        let transformer = TenantQueryTransformer::new();
        let queries = [
            "SELECT * FROM users u, orders o WHERE u.id = o.user_id",
            "SELECT * FROM users JOIN orders ON users.id = orders.user_id",
        ];

        for sql in queries.iter() {
            let tables = transformer.extract_tables(sql);
            for expected in ["users", "orders"].iter() {
                assert!(
                    tables.contains(&expected.to_string()),
                    "{} missing from {:?}",
                    expected,
                    tables
                );
            }
        }
    }

    #[test]
    fn test_set_schema_search_path() {
        let path = TenantQueryTransformer::new()
            .set_schema_search_path(&TenantId::new("acme"), &acme_schema_config("acme_schema"));

        assert_eq!(path, Some("SET search_path TO acme_schema".to_string()));
    }

    #[test]
    fn test_query_validation() {
        let tenant = TenantId::new("acme");
        let config = create_row_config("acme");
        let cases = [
            ("SELECT * FROM users", true), // plain read is fine
            ("TRUNCATE users", false),     // bypasses row filtering
            ("DROP TABLE users", false),   // destroys shared data
        ];

        for (sql, expected) in cases.iter() {
            assert_eq!(validate_query(sql, &tenant, &config).valid, *expected, "{}", sql);
        }
    }

    #[test]
    fn test_schema_cross_access_validation() {
        let tenant = TenantId::new("acme");
        let config = acme_schema_config("acme");
        let cases = [
            ("SELECT * FROM acme.users", true),            // own schema
            ("SELECT * FROM other_tenant.users", false),   // another tenant's schema
            ("SELECT * FROM pg_catalog.pg_tables", true),  // system catalog
        ];

        for (sql, expected) in cases.iter() {
            assert_eq!(validate_query(sql, &tenant, &config).valid, *expected, "{}", sql);
        }
    }
}