Skip to main content

heliosdb_proxy/multi_tenancy/
transformer.rs

1//! Tenant Query Transformer
2//!
3//! This module transforms SQL queries to apply tenant isolation, primarily
4//! for row-level security where queries need WHERE clause injection.
5
6use std::collections::{HashMap, HashSet};
7
8use super::config::{IsolationStrategy, TenantConfig, TenantId};
9
/// Outcome of passing a query through the tenant transformer.
#[derive(Debug, Clone)]
pub struct TransformResult {
    /// SQL text to execute: either the rewritten query or the caller's original text.
    pub query: String,

    /// `true` when the result was built by [`TransformResult::transformed`].
    pub transformed: bool,

    /// Names of the tables a tenant filter was applied to.
    pub filtered_tables: Vec<String>,

    /// Human-readable warnings collected while transforming.
    pub warnings: Vec<String>,
}

impl TransformResult {
    /// Wraps `query` untouched: not transformed, no filtered tables, no warnings.
    pub fn passthrough(query: impl Into<String>) -> Self {
        let query = query.into();
        Self {
            query,
            transformed: false,
            filtered_tables: Vec::new(),
            warnings: Vec::new(),
        }
    }

    /// Wraps a rewritten `query` together with the `tables` that were filtered.
    pub fn transformed(query: impl Into<String>, tables: Vec<String>) -> Self {
        Self {
            transformed: true,
            filtered_tables: tables,
            ..Self::passthrough(query)
        }
    }

    /// Appends `warning` to the result and returns it, builder style.
    pub fn with_warning(mut self, warning: impl Into<String>) -> Self {
        let message: String = warning.into();
        self.warnings.push(message);
        self
    }
}
53
/// Query transformer for tenant isolation.
///
/// Holds the registry of tenant-scoped tables and the options that control how the tenant
/// predicate is rendered. Derives `Debug` and `Clone` so a configured transformer can be
/// logged and shared between components.
#[derive(Debug, Clone)]
pub struct TenantQueryTransformer {
    /// Lower-cased table name -> name of that table's tenant column.
    tenant_tables: HashMap<String, String>,

    /// Lower-cased names of tables that are never filtered, even when registered.
    excluded_tables: HashSet<String>,

    /// When `true`, filters are rendered with a `$1` placeholder instead of a literal.
    use_parameters: bool,

    /// Custom filter template (default: "{column} = '{value}'")
    filter_template: Option<String>,
}
68
69impl Default for TenantQueryTransformer {
70    fn default() -> Self {
71        Self::new()
72    }
73}
74
75impl TenantQueryTransformer {
76    /// Create a new query transformer
77    pub fn new() -> Self {
78        Self {
79            tenant_tables: HashMap::new(),
80            excluded_tables: HashSet::new(),
81            use_parameters: false,
82            filter_template: None,
83        }
84    }
85
86    /// Register a table with its tenant column
87    pub fn register_table(mut self, table: impl Into<String>, column: impl Into<String>) -> Self {
88        self.tenant_tables
89            .insert(table.into().to_lowercase(), column.into());
90        self
91    }
92
93    /// Register multiple tables with the same column
94    pub fn register_tables(mut self, tables: &[&str], column: impl Into<String>) -> Self {
95        let col = column.into();
96        for table in tables {
97            self.tenant_tables.insert(table.to_lowercase(), col.clone());
98        }
99        self
100    }
101
102    /// Exclude a table from filtering
103    pub fn exclude_table(mut self, table: impl Into<String>) -> Self {
104        self.excluded_tables.insert(table.into().to_lowercase());
105        self
106    }
107
108    /// Use parameterized queries
109    pub fn with_parameters(mut self) -> Self {
110        self.use_parameters = true;
111        self
112    }
113
114    /// Set custom filter template
115    pub fn with_filter_template(mut self, template: impl Into<String>) -> Self {
116        self.filter_template = Some(template.into());
117        self
118    }
119
120    /// Get the tenant column for a table
121    pub fn get_tenant_column(&self, table: &str) -> Option<&str> {
122        self.tenant_tables
123            .get(&table.to_lowercase())
124            .map(|s| s.as_str())
125    }
126
127    /// Check if a table requires filtering
128    pub fn requires_filtering(&self, table: &str) -> bool {
129        let lower = table.to_lowercase();
130        self.tenant_tables.contains_key(&lower) && !self.excluded_tables.contains(&lower)
131    }
132
133    /// Transform a query for a tenant
134    pub fn transform(
135        &self,
136        query: &str,
137        tenant: &TenantId,
138        config: &TenantConfig,
139    ) -> TransformResult {
140        // Only transform for row-level isolation
141        let tenant_column = match &config.isolation {
142            IsolationStrategy::Row { tenant_column, .. } => tenant_column,
143            _ => return TransformResult::passthrough(query),
144        };
145
146        // Parse and transform query
147        let upper = query.trim().to_uppercase();
148
149        if upper.starts_with("SELECT") {
150            self.transform_select(query, tenant, tenant_column)
151        } else if upper.starts_with("UPDATE") {
152            self.transform_update(query, tenant, tenant_column)
153        } else if upper.starts_with("DELETE") {
154            self.transform_delete(query, tenant, tenant_column)
155        } else if upper.starts_with("INSERT") {
156            self.transform_insert(query, tenant, tenant_column)
157        } else {
158            TransformResult::passthrough(query)
159        }
160    }
161
162    /// Transform a SELECT query
163    fn transform_select(
164        &self,
165        query: &str,
166        tenant: &TenantId,
167        tenant_column: &str,
168    ) -> TransformResult {
169        let tables = self.extract_tables(query);
170        let filtered_tables: Vec<String> = tables
171            .iter()
172            .filter(|t| self.requires_filtering(t))
173            .cloned()
174            .collect();
175
176        if filtered_tables.is_empty() {
177            return TransformResult::passthrough(query);
178        }
179
180        let filter = self.build_filter(tenant, tenant_column, &filtered_tables);
181        let transformed = self.inject_where_clause(query, &filter);
182
183        TransformResult::transformed(transformed, filtered_tables)
184    }
185
186    /// Transform an UPDATE query
187    fn transform_update(
188        &self,
189        query: &str,
190        tenant: &TenantId,
191        tenant_column: &str,
192    ) -> TransformResult {
193        let table = self.extract_update_table(query);
194
195        if let Some(table) = table {
196            if self.requires_filtering(&table) {
197                let filter = self.build_single_filter(tenant, tenant_column);
198                let transformed = self.inject_where_clause(query, &filter);
199                return TransformResult::transformed(transformed, vec![table]);
200            }
201        }
202
203        TransformResult::passthrough(query)
204    }
205
206    /// Transform a DELETE query
207    fn transform_delete(
208        &self,
209        query: &str,
210        tenant: &TenantId,
211        tenant_column: &str,
212    ) -> TransformResult {
213        let table = self.extract_delete_table(query);
214
215        if let Some(table) = table {
216            if self.requires_filtering(&table) {
217                let filter = self.build_single_filter(tenant, tenant_column);
218                let transformed = self.inject_where_clause(query, &filter);
219                return TransformResult::transformed(transformed, vec![table]);
220            }
221        }
222
223        TransformResult::passthrough(query)
224    }
225
226    /// Transform an INSERT query (add tenant_id column)
227    fn transform_insert(
228        &self,
229        query: &str,
230        tenant: &TenantId,
231        tenant_column: &str,
232    ) -> TransformResult {
233        let table = self.extract_insert_table(query);
234
235        if let Some(table) = table {
236            if self.requires_filtering(&table) {
237                // For INSERT, we need to add tenant_id to the values
238                let transformed = self.inject_tenant_value(query, tenant, tenant_column);
239                return TransformResult::transformed(transformed, vec![table])
240                    .with_warning("Tenant column injection may require schema awareness");
241            }
242        }
243
244        TransformResult::passthrough(query)
245    }
246
247    /// Build a filter clause for multiple tables
248    fn build_filter(&self, tenant: &TenantId, default_column: &str, tables: &[String]) -> String {
249        let filters: Vec<String> = tables
250            .iter()
251            .map(|table| {
252                let column = self.get_tenant_column(table).unwrap_or(default_column);
253                if self.use_parameters {
254                    format!("{}.{} = $1", table, column)
255                } else {
256                    format!("{}.{} = '{}'", table, column, tenant.0)
257                }
258            })
259            .collect();
260
261        filters.join(" AND ")
262    }
263
264    /// Build a single filter clause
265    fn build_single_filter(&self, tenant: &TenantId, column: &str) -> String {
266        if self.use_parameters {
267            format!("{} = $1", column)
268        } else {
269            match &self.filter_template {
270                Some(template) => template
271                    .replace("{column}", column)
272                    .replace("{value}", &tenant.0),
273                None => format!("{} = '{}'", column, tenant.0),
274            }
275        }
276    }
277
278    /// Inject WHERE clause into query
279    fn inject_where_clause(&self, query: &str, filter: &str) -> String {
280        let upper = query.to_uppercase();
281
282        // Find existing WHERE clause
283        if let Some(where_pos) = upper.find(" WHERE ") {
284            // Add to existing WHERE
285            let (before, after) = query.split_at(where_pos + 7);
286            format!("{}{} AND {}", before, filter, after)
287        } else {
288            // Find position to insert WHERE
289            // Look for ORDER BY, GROUP BY, LIMIT, etc.
290            let insert_before = [" ORDER ", " GROUP ", " LIMIT ", " HAVING ", " UNION "]
291                .iter()
292                .filter_map(|kw| upper.find(kw))
293                .min();
294
295            match insert_before {
296                Some(pos) => {
297                    let (before, after) = query.split_at(pos);
298                    format!("{} WHERE {}{}", before, filter, after)
299                }
300                None => {
301                    // Append at end
302                    format!("{} WHERE {}", query.trim_end_matches(';'), filter)
303                }
304            }
305        }
306    }
307
308    /// Inject tenant value into INSERT statement
309    fn inject_tenant_value(&self, query: &str, tenant: &TenantId, column: &str) -> String {
310        // Simple implementation - a real one would parse SQL properly
311        let upper = query.to_uppercase();
312
313        if let Some(values_pos) = upper.find(" VALUES ") {
314            if let Some(paren_pos) = query[values_pos..].find('(') {
315                let insert_pos = values_pos + paren_pos + 1;
316
317                // Check if there's a column list
318                if let Some(cols_start) = upper.find('(') {
319                    if cols_start < values_pos {
320                        // There's a column list - add column to it
321                        let cols_end = upper[cols_start..].find(')').unwrap_or(0) + cols_start;
322                        let before_cols_end = &query[..cols_end];
323                        let after_cols_end = &query[cols_end..];
324
325                        // Insert column name
326                        let with_column =
327                            format!("{}, {}{}", before_cols_end, column, after_cols_end);
328
329                        // Now insert the value
330                        let upper_new = with_column.to_uppercase();
331                        if let Some(new_values_pos) = upper_new.find(" VALUES ") {
332                            if let Some(new_paren_pos) = with_column[new_values_pos..].find('(') {
333                                let new_insert_pos = new_values_pos + new_paren_pos + 1;
334                                let before = &with_column[..new_insert_pos];
335                                let after = &with_column[new_insert_pos..];
336                                return format!("{}'{}'", before, tenant.0)
337                                    + if !after.starts_with(')') { ", " } else { "" }
338                                    + after;
339                            }
340                        }
341                    }
342                }
343
344                // No column list or couldn't parse - just add to values
345                let before = &query[..insert_pos];
346                let after = &query[insert_pos..];
347                return format!("{}'{}'", before, tenant.0)
348                    + if !after.starts_with(')') { ", " } else { "" }
349                    + after;
350            }
351        }
352
353        query.to_string()
354    }
355
356    /// Extract table names from SELECT query
357    fn extract_tables(&self, query: &str) -> Vec<String> {
358        let upper = query.to_uppercase();
359        let mut tables = Vec::new();
360
361        // Find FROM clause
362        if let Some(from_pos) = upper.find(" FROM ") {
363            let after_from = &query[from_pos + 6..];
364
365            // Find end of table list
366            let end_markers = [
367                " WHERE ", " JOIN ", " LEFT ", " RIGHT ", " INNER ", " OUTER ", " GROUP ",
368                " ORDER ", " LIMIT ", " HAVING ",
369            ];
370            let end_pos = end_markers
371                .iter()
372                .filter_map(|m| after_from.to_uppercase().find(m))
373                .min()
374                .unwrap_or(after_from.len());
375
376            let table_section = &after_from[..end_pos];
377
378            // Parse table names (handle aliases)
379            for part in table_section.split(',') {
380                let trimmed = part.trim();
381                if let Some(table) = trimmed.split_whitespace().next() {
382                    let clean =
383                        table.trim_matches(|c| c == '"' || c == '`' || c == '[' || c == ']');
384                    if !clean.is_empty() {
385                        tables.push(clean.to_string());
386                    }
387                }
388            }
389        }
390
391        // Also look for JOINs
392        let words: Vec<&str> = query.split_whitespace().collect();
393        for (i, word) in words.iter().enumerate() {
394            if word.to_uppercase() == "JOIN" && i + 1 < words.len() {
395                let table =
396                    words[i + 1].trim_matches(|c| c == '"' || c == '`' || c == '[' || c == ']');
397                if !table.is_empty() && !tables.contains(&table.to_string()) {
398                    tables.push(table.to_string());
399                }
400            }
401        }
402
403        tables
404    }
405
406    /// Extract table name from UPDATE query
407    fn extract_update_table(&self, query: &str) -> Option<String> {
408        let upper = query.to_uppercase();
409        if let Some(update_pos) = upper.find("UPDATE ") {
410            let after_update = &query[update_pos + 7..];
411            if let Some(set_pos) = after_update.to_uppercase().find(" SET ") {
412                let table_section = &after_update[..set_pos];
413                let table = table_section
414                    .split_whitespace()
415                    .next()?
416                    .trim_matches(|c| c == '"' || c == '`');
417                return Some(table.to_string());
418            }
419        }
420        None
421    }
422
423    /// Extract table name from DELETE query
424    fn extract_delete_table(&self, query: &str) -> Option<String> {
425        let upper = query.to_uppercase();
426        if let Some(from_pos) = upper.find(" FROM ") {
427            let after_from = &query[from_pos + 6..];
428            let end_pos = after_from
429                .to_uppercase()
430                .find(" WHERE ")
431                .unwrap_or(after_from.len());
432            let table_section = &after_from[..end_pos];
433            let table = table_section
434                .split_whitespace()
435                .next()?
436                .trim_matches(|c| c == '"' || c == '`');
437            return Some(table.to_string());
438        }
439        None
440    }
441
442    /// Extract table name from INSERT query
443    fn extract_insert_table(&self, query: &str) -> Option<String> {
444        let upper = query.to_uppercase();
445        if let Some(into_pos) = upper.find(" INTO ") {
446            let after_into = &query[into_pos + 6..];
447            let end_pos = after_into
448                .find(|c: char| c == '(' || c.is_whitespace())
449                .unwrap_or(after_into.len());
450            let table = after_into[..end_pos]
451                .trim()
452                .trim_matches(|c| c == '"' || c == '`');
453            return Some(table.to_string());
454        }
455        None
456    }
457
458    /// Generate SET search_path command for schema isolation
459    pub fn set_schema_search_path(
460        &self,
461        _tenant: &TenantId,
462        config: &TenantConfig,
463    ) -> Option<String> {
464        if let IsolationStrategy::Schema { schema_name, .. } = &config.isolation {
465            Some(format!("SET search_path TO {}", schema_name))
466        } else {
467            None
468        }
469    }
470
471    /// Generate USE database command for database isolation
472    pub fn use_database(&self, _tenant: &TenantId, config: &TenantConfig) -> Option<String> {
473        if let IsolationStrategy::Database { database_name } = &config.isolation {
474            Some(format!("USE {}", database_name))
475        } else {
476            None
477        }
478    }
479}
480
481/// Validate that a query doesn't try to bypass tenant isolation
482pub fn validate_query(query: &str, _tenant: &TenantId, config: &TenantConfig) -> QueryValidation {
483    let mut validation = QueryValidation {
484        valid: true,
485        violations: Vec::new(),
486    };
487
488    let upper = query.to_uppercase();
489
490    // Check for dangerous operations
491    if let IsolationStrategy::Row { tenant_column, .. } = &config.isolation {
492        // Check if query tries to modify tenant column
493        if upper.contains(&format!("{} =", tenant_column.to_uppercase())) {
494            let set_pattern = format!("SET {} =", tenant_column.to_uppercase());
495            if upper.contains(&set_pattern) {
496                validation.valid = false;
497                validation
498                    .violations
499                    .push(format!("Cannot modify tenant column: {}", tenant_column));
500            }
501        }
502
503        // Check for TRUNCATE (bypasses row-level security)
504        if upper.starts_with("TRUNCATE ") {
505            validation.valid = false;
506            validation
507                .violations
508                .push("TRUNCATE not allowed with row-level isolation".to_string());
509        }
510
511        // Check for DROP TABLE
512        if upper.contains("DROP TABLE") {
513            validation.valid = false;
514            validation
515                .violations
516                .push("DROP TABLE not allowed with row-level isolation".to_string());
517        }
518    }
519
520    // Check for cross-schema access in schema isolation
521    if let IsolationStrategy::Schema { schema_name, .. } = &config.isolation {
522        // Look for schema.table patterns that don't match tenant's schema
523        let parts: Vec<&str> = upper.split_whitespace().collect();
524        for part in parts {
525            if part.contains('.') && !part.starts_with(&schema_name.to_uppercase()) {
526                let schema = part.split('.').next().unwrap_or("");
527                if !schema.eq_ignore_ascii_case("pg_catalog")
528                    && !schema.eq_ignore_ascii_case("information_schema")
529                {
530                    validation.valid = false;
531                    validation
532                        .violations
533                        .push(format!("Cross-schema access not allowed: {}", part));
534                }
535            }
536        }
537    }
538
539    validation
540}
541
/// Result of query validation.
///
/// Derives `PartialEq`/`Eq` so two validation results can be compared directly, e.g. in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QueryValidation {
    /// Whether the query is valid.
    pub valid: bool,

    /// Descriptions of the violations found.
    pub violations: Vec<String>,
}
551
#[cfg(test)]
mod tests {
    use super::*;

    /// Tenant identifier shared by the tests.
    fn acme() -> TenantId {
        TenantId::new("acme")
    }

    /// Row-isolated tenant configuration whose tenant column is `tenant_id`.
    fn create_row_config(tenant_id: &str) -> TenantConfig {
        TenantConfig::builder()
            .id(tenant_id)
            .name("Test")
            .row_isolation("shared_db", "tenant_id")
            .build()
    }

    /// Schema-isolated configuration for tenant `acme` using the given schema.
    fn acme_schema_config(schema: &str) -> TenantConfig {
        TenantConfig::builder()
            .id("acme")
            .name("Acme")
            .schema_isolation("shared", schema)
            .build()
    }

    /// Transformer with only the `users` table registered.
    fn users_only() -> TenantQueryTransformer {
        TenantQueryTransformer::new().register_table("users", "tenant_id")
    }

    #[test]
    fn test_transform_select() {
        let transformer = users_only().register_table("orders", "tenant_id");
        let config = create_row_config("acme");

        let result =
            transformer.transform("SELECT * FROM users WHERE active = true", &acme(), &config);

        assert!(result.transformed);
        assert!(result.query.contains("tenant_id = 'acme'"));
        assert!(result.query.contains("AND active = true"));
    }

    #[test]
    fn test_transform_select_no_where() {
        let config = create_row_config("acme");

        let result = users_only().transform("SELECT * FROM users ORDER BY id", &acme(), &config);

        assert!(result.transformed);
        assert!(result.query.contains("WHERE users.tenant_id = 'acme'"));
        assert!(result.query.contains("ORDER BY id"));
    }

    #[test]
    fn test_transform_update() {
        let config = create_row_config("acme");

        let result = users_only().transform(
            "UPDATE users SET name = 'John' WHERE id = 1",
            &acme(),
            &config,
        );

        assert!(result.transformed);
        assert!(result.query.contains("tenant_id = 'acme'"));
    }

    #[test]
    fn test_transform_delete() {
        let config = create_row_config("acme");

        let result = users_only().transform("DELETE FROM users WHERE id = 1", &acme(), &config);

        assert!(result.transformed);
        assert!(result.query.contains("tenant_id = 'acme'"));
    }

    #[test]
    fn test_no_transform_for_unregistered_table() {
        let config = create_row_config("acme");

        let result =
            users_only().transform("SELECT * FROM logs WHERE level = 'error'", &acme(), &config);

        assert!(!result.transformed);
    }

    #[test]
    fn test_no_transform_for_schema_isolation() {
        let config = acme_schema_config("acme");

        let result = users_only().transform("SELECT * FROM users", &acme(), &config);

        assert!(!result.transformed);
    }

    #[test]
    fn test_excluded_tables() {
        let transformer = users_only()
            .register_table("audit_log", "tenant_id")
            .exclude_table("audit_log");
        let config = create_row_config("acme");

        let result = transformer.transform("SELECT * FROM audit_log", &acme(), &config);

        assert!(!result.transformed);
    }

    #[test]
    fn test_extract_tables() {
        let transformer = TenantQueryTransformer::new();
        let queries = [
            "SELECT * FROM users u, orders o WHERE u.id = o.user_id",
            "SELECT * FROM users JOIN orders ON users.id = orders.user_id",
        ];

        for query in queries.iter() {
            let tables = transformer.extract_tables(query);
            assert!(tables.contains(&"users".to_string()));
            assert!(tables.contains(&"orders".to_string()));
        }
    }

    #[test]
    fn test_set_schema_search_path() {
        let transformer = TenantQueryTransformer::new();
        let config = acme_schema_config("acme_schema");

        let path = transformer.set_schema_search_path(&acme(), &config);

        assert_eq!(path.as_deref(), Some("SET search_path TO acme_schema"));
    }

    #[test]
    fn test_query_validation() {
        let tenant = acme();
        let config = create_row_config("acme");

        // Valid query
        assert!(validate_query("SELECT * FROM users", &tenant, &config).valid);

        // Invalid - TRUNCATE
        assert!(!validate_query("TRUNCATE users", &tenant, &config).valid);

        // Invalid - DROP TABLE
        assert!(!validate_query("DROP TABLE users", &tenant, &config).valid);
    }

    #[test]
    fn test_schema_cross_access_validation() {
        let tenant = acme();
        let config = acme_schema_config("acme");

        // Valid - own schema
        assert!(validate_query("SELECT * FROM acme.users", &tenant, &config).valid);

        // Invalid - other tenant's schema
        assert!(!validate_query("SELECT * FROM other_tenant.users", &tenant, &config).valid);

        // Valid - system catalog
        assert!(validate_query("SELECT * FROM pg_catalog.pg_tables", &tenant, &config).valid);
    }
}