1use std::collections::{HashMap, HashSet};
7
8use super::config::{IsolationStrategy, TenantConfig, TenantId};
9
/// Outcome of rewriting a query on behalf of a tenant.
#[derive(Clone, Debug)]
pub struct TransformResult {
    /// The query to execute: the rewritten text, or the original text when untouched.
    pub query: String,
    /// Whether tenant rewriting was applied (`false` for passthrough results).
    pub transformed: bool,
    /// Tables the rewrite scoped to the tenant.
    pub filtered_tables: Vec<String>,
    /// Non-fatal notes about the rewrite.
    pub warnings: Vec<String>,
}

impl TransformResult {
    /// Wraps a query that is returned unchanged.
    pub fn passthrough(query: impl Into<String>) -> Self {
        Self::assemble(query.into(), false, Vec::new())
    }

    /// Wraps a rewritten query together with the tables it now filters.
    pub fn transformed(query: impl Into<String>, tables: Vec<String>) -> Self {
        Self::assemble(query.into(), true, tables)
    }

    /// Appends a warning and returns the result, builder-style.
    pub fn with_warning(mut self, warning: impl Into<String>) -> Self {
        let message: String = warning.into();
        self.warnings.push(message);
        self
    }

    /// Shared constructor; every result starts without warnings.
    fn assemble(query: String, transformed: bool, filtered_tables: Vec<String>) -> Self {
        TransformResult {
            query,
            transformed,
            filtered_tables,
            warnings: Vec::new(),
        }
    }
}
53
54pub struct TenantQueryTransformer {
56 tenant_tables: HashMap<String, String>,
58
59 excluded_tables: HashSet<String>,
61
62 use_parameters: bool,
64
65 filter_template: Option<String>,
67}
68
69impl Default for TenantQueryTransformer {
70 fn default() -> Self {
71 Self::new()
72 }
73}
74
75impl TenantQueryTransformer {
76 pub fn new() -> Self {
78 Self {
79 tenant_tables: HashMap::new(),
80 excluded_tables: HashSet::new(),
81 use_parameters: false,
82 filter_template: None,
83 }
84 }
85
86 pub fn register_table(mut self, table: impl Into<String>, column: impl Into<String>) -> Self {
88 self.tenant_tables
89 .insert(table.into().to_lowercase(), column.into());
90 self
91 }
92
93 pub fn register_tables(mut self, tables: &[&str], column: impl Into<String>) -> Self {
95 let col = column.into();
96 for table in tables {
97 self.tenant_tables
98 .insert(table.to_lowercase(), col.clone());
99 }
100 self
101 }
102
103 pub fn exclude_table(mut self, table: impl Into<String>) -> Self {
105 self.excluded_tables.insert(table.into().to_lowercase());
106 self
107 }
108
109 pub fn with_parameters(mut self) -> Self {
111 self.use_parameters = true;
112 self
113 }
114
115 pub fn with_filter_template(mut self, template: impl Into<String>) -> Self {
117 self.filter_template = Some(template.into());
118 self
119 }
120
121 pub fn get_tenant_column(&self, table: &str) -> Option<&str> {
123 self.tenant_tables
124 .get(&table.to_lowercase())
125 .map(|s| s.as_str())
126 }
127
128 pub fn requires_filtering(&self, table: &str) -> bool {
130 let lower = table.to_lowercase();
131 self.tenant_tables.contains_key(&lower) && !self.excluded_tables.contains(&lower)
132 }
133
134 pub fn transform(
136 &self,
137 query: &str,
138 tenant: &TenantId,
139 config: &TenantConfig,
140 ) -> TransformResult {
141 let tenant_column = match &config.isolation {
143 IsolationStrategy::Row { tenant_column, .. } => tenant_column,
144 _ => return TransformResult::passthrough(query),
145 };
146
147 let upper = query.trim().to_uppercase();
149
150 if upper.starts_with("SELECT") {
151 self.transform_select(query, tenant, tenant_column)
152 } else if upper.starts_with("UPDATE") {
153 self.transform_update(query, tenant, tenant_column)
154 } else if upper.starts_with("DELETE") {
155 self.transform_delete(query, tenant, tenant_column)
156 } else if upper.starts_with("INSERT") {
157 self.transform_insert(query, tenant, tenant_column)
158 } else {
159 TransformResult::passthrough(query)
160 }
161 }
162
163 fn transform_select(
165 &self,
166 query: &str,
167 tenant: &TenantId,
168 tenant_column: &str,
169 ) -> TransformResult {
170 let tables = self.extract_tables(query);
171 let filtered_tables: Vec<String> = tables
172 .iter()
173 .filter(|t| self.requires_filtering(t))
174 .cloned()
175 .collect();
176
177 if filtered_tables.is_empty() {
178 return TransformResult::passthrough(query);
179 }
180
181 let filter = self.build_filter(tenant, tenant_column, &filtered_tables);
182 let transformed = self.inject_where_clause(query, &filter);
183
184 TransformResult::transformed(transformed, filtered_tables)
185 }
186
187 fn transform_update(
189 &self,
190 query: &str,
191 tenant: &TenantId,
192 tenant_column: &str,
193 ) -> TransformResult {
194 let table = self.extract_update_table(query);
195
196 if let Some(table) = table {
197 if self.requires_filtering(&table) {
198 let filter = self.build_single_filter(tenant, tenant_column);
199 let transformed = self.inject_where_clause(query, &filter);
200 return TransformResult::transformed(transformed, vec![table]);
201 }
202 }
203
204 TransformResult::passthrough(query)
205 }
206
207 fn transform_delete(
209 &self,
210 query: &str,
211 tenant: &TenantId,
212 tenant_column: &str,
213 ) -> TransformResult {
214 let table = self.extract_delete_table(query);
215
216 if let Some(table) = table {
217 if self.requires_filtering(&table) {
218 let filter = self.build_single_filter(tenant, tenant_column);
219 let transformed = self.inject_where_clause(query, &filter);
220 return TransformResult::transformed(transformed, vec![table]);
221 }
222 }
223
224 TransformResult::passthrough(query)
225 }
226
227 fn transform_insert(
229 &self,
230 query: &str,
231 tenant: &TenantId,
232 tenant_column: &str,
233 ) -> TransformResult {
234 let table = self.extract_insert_table(query);
235
236 if let Some(table) = table {
237 if self.requires_filtering(&table) {
238 let transformed = self.inject_tenant_value(query, tenant, tenant_column);
240 return TransformResult::transformed(transformed, vec![table])
241 .with_warning("Tenant column injection may require schema awareness");
242 }
243 }
244
245 TransformResult::passthrough(query)
246 }
247
248 fn build_filter(
250 &self,
251 tenant: &TenantId,
252 default_column: &str,
253 tables: &[String],
254 ) -> String {
255 let filters: Vec<String> = tables
256 .iter()
257 .map(|table| {
258 let column = self
259 .get_tenant_column(table)
260 .unwrap_or(default_column);
261 if self.use_parameters {
262 format!("{}.{} = $1", table, column)
263 } else {
264 format!("{}.{} = '{}'", table, column, tenant.0)
265 }
266 })
267 .collect();
268
269 filters.join(" AND ")
270 }
271
272 fn build_single_filter(&self, tenant: &TenantId, column: &str) -> String {
274 if self.use_parameters {
275 format!("{} = $1", column)
276 } else {
277 match &self.filter_template {
278 Some(template) => template
279 .replace("{column}", column)
280 .replace("{value}", &tenant.0),
281 None => format!("{} = '{}'", column, tenant.0),
282 }
283 }
284 }
285
286 fn inject_where_clause(&self, query: &str, filter: &str) -> String {
288 let upper = query.to_uppercase();
289
290 if let Some(where_pos) = upper.find(" WHERE ") {
292 let (before, after) = query.split_at(where_pos + 7);
294 format!("{}{} AND {}", before, filter, after)
295 } else {
296 let insert_before = [" ORDER ", " GROUP ", " LIMIT ", " HAVING ", " UNION "]
299 .iter()
300 .filter_map(|kw| upper.find(kw))
301 .min();
302
303 match insert_before {
304 Some(pos) => {
305 let (before, after) = query.split_at(pos);
306 format!("{} WHERE {}{}", before, filter, after)
307 }
308 None => {
309 format!("{} WHERE {}", query.trim_end_matches(';'), filter)
311 }
312 }
313 }
314 }
315
316 fn inject_tenant_value(&self, query: &str, tenant: &TenantId, column: &str) -> String {
318 let upper = query.to_uppercase();
320
321 if let Some(values_pos) = upper.find(" VALUES ") {
322 if let Some(paren_pos) = query[values_pos..].find('(') {
323 let insert_pos = values_pos + paren_pos + 1;
324
325 if let Some(cols_start) = upper.find('(') {
327 if cols_start < values_pos {
328 let cols_end = upper[cols_start..].find(')').unwrap_or(0) + cols_start;
330 let before_cols_end = &query[..cols_end];
331 let after_cols_end = &query[cols_end..];
332
333 let with_column =
335 format!("{}, {}{}", before_cols_end, column, after_cols_end);
336
337 let upper_new = with_column.to_uppercase();
339 if let Some(new_values_pos) = upper_new.find(" VALUES ") {
340 if let Some(new_paren_pos) = with_column[new_values_pos..].find('(') {
341 let new_insert_pos = new_values_pos + new_paren_pos + 1;
342 let before = &with_column[..new_insert_pos];
343 let after = &with_column[new_insert_pos..];
344 return format!("{}'{}'", before, tenant.0)
345 + if !after.starts_with(')') { ", " } else { "" }
346 + after;
347 }
348 }
349 }
350 }
351
352 let before = &query[..insert_pos];
354 let after = &query[insert_pos..];
355 return format!("{}'{}'", before, tenant.0)
356 + if !after.starts_with(')') { ", " } else { "" }
357 + after;
358 }
359 }
360
361 query.to_string()
362 }
363
364 fn extract_tables(&self, query: &str) -> Vec<String> {
366 let upper = query.to_uppercase();
367 let mut tables = Vec::new();
368
369 if let Some(from_pos) = upper.find(" FROM ") {
371 let after_from = &query[from_pos + 6..];
372
373 let end_markers = [" WHERE ", " JOIN ", " LEFT ", " RIGHT ", " INNER ", " OUTER ",
375 " GROUP ", " ORDER ", " LIMIT ", " HAVING "];
376 let end_pos = end_markers
377 .iter()
378 .filter_map(|m| after_from.to_uppercase().find(m))
379 .min()
380 .unwrap_or(after_from.len());
381
382 let table_section = &after_from[..end_pos];
383
384 for part in table_section.split(',') {
386 let trimmed = part.trim();
387 if let Some(table) = trimmed.split_whitespace().next() {
388 let clean = table
389 .trim_matches(|c| c == '"' || c == '`' || c == '[' || c == ']');
390 if !clean.is_empty() {
391 tables.push(clean.to_string());
392 }
393 }
394 }
395 }
396
397 let words: Vec<&str> = query.split_whitespace().collect();
399 for (i, word) in words.iter().enumerate() {
400 if word.to_uppercase() == "JOIN" && i + 1 < words.len() {
401 let table = words[i + 1]
402 .trim_matches(|c| c == '"' || c == '`' || c == '[' || c == ']');
403 if !table.is_empty() && !tables.contains(&table.to_string()) {
404 tables.push(table.to_string());
405 }
406 }
407 }
408
409 tables
410 }
411
412 fn extract_update_table(&self, query: &str) -> Option<String> {
414 let upper = query.to_uppercase();
415 if let Some(update_pos) = upper.find("UPDATE ") {
416 let after_update = &query[update_pos + 7..];
417 if let Some(set_pos) = after_update.to_uppercase().find(" SET ") {
418 let table_section = &after_update[..set_pos];
419 let table = table_section
420 .trim()
421 .split_whitespace()
422 .next()?
423 .trim_matches(|c| c == '"' || c == '`');
424 return Some(table.to_string());
425 }
426 }
427 None
428 }
429
430 fn extract_delete_table(&self, query: &str) -> Option<String> {
432 let upper = query.to_uppercase();
433 if let Some(from_pos) = upper.find(" FROM ") {
434 let after_from = &query[from_pos + 6..];
435 let end_pos = after_from.to_uppercase().find(" WHERE ")
436 .unwrap_or(after_from.len());
437 let table_section = &after_from[..end_pos];
438 let table = table_section
439 .trim()
440 .split_whitespace()
441 .next()?
442 .trim_matches(|c| c == '"' || c == '`');
443 return Some(table.to_string());
444 }
445 None
446 }
447
448 fn extract_insert_table(&self, query: &str) -> Option<String> {
450 let upper = query.to_uppercase();
451 if let Some(into_pos) = upper.find(" INTO ") {
452 let after_into = &query[into_pos + 6..];
453 let end_pos = after_into.find(|c: char| c == '(' || c.is_whitespace())
454 .unwrap_or(after_into.len());
455 let table = after_into[..end_pos]
456 .trim()
457 .trim_matches(|c| c == '"' || c == '`');
458 return Some(table.to_string());
459 }
460 None
461 }
462
463 pub fn set_schema_search_path(
465 &self,
466 _tenant: &TenantId,
467 config: &TenantConfig,
468 ) -> Option<String> {
469 if let IsolationStrategy::Schema { schema_name, .. } = &config.isolation {
470 Some(format!("SET search_path TO {}", schema_name))
471 } else {
472 None
473 }
474 }
475
476 pub fn use_database(&self, _tenant: &TenantId, config: &TenantConfig) -> Option<String> {
478 if let IsolationStrategy::Database { database_name } = &config.isolation {
479 Some(format!("USE {}", database_name))
480 } else {
481 None
482 }
483 }
484}
485
486pub fn validate_query(query: &str, _tenant: &TenantId, config: &TenantConfig) -> QueryValidation {
488 let mut validation = QueryValidation {
489 valid: true,
490 violations: Vec::new(),
491 };
492
493 let upper = query.to_uppercase();
494
495 if let IsolationStrategy::Row { tenant_column, .. } = &config.isolation {
497 if upper.contains(&format!("{} =", tenant_column.to_uppercase())) {
499 let set_pattern = format!("SET {} =", tenant_column.to_uppercase());
500 if upper.contains(&set_pattern) {
501 validation.valid = false;
502 validation
503 .violations
504 .push(format!("Cannot modify tenant column: {}", tenant_column));
505 }
506 }
507
508 if upper.starts_with("TRUNCATE ") {
510 validation.valid = false;
511 validation
512 .violations
513 .push("TRUNCATE not allowed with row-level isolation".to_string());
514 }
515
516 if upper.contains("DROP TABLE") {
518 validation.valid = false;
519 validation
520 .violations
521 .push("DROP TABLE not allowed with row-level isolation".to_string());
522 }
523 }
524
525 if let IsolationStrategy::Schema { schema_name, .. } = &config.isolation {
527 let parts: Vec<&str> = upper.split_whitespace().collect();
529 for part in parts {
530 if part.contains('.') && !part.starts_with(&schema_name.to_uppercase()) {
531 let schema = part.split('.').next().unwrap_or("");
532 if !schema.eq_ignore_ascii_case("pg_catalog")
533 && !schema.eq_ignore_ascii_case("information_schema")
534 {
535 validation.valid = false;
536 validation.violations.push(format!(
537 "Cross-schema access not allowed: {}",
538 part
539 ));
540 }
541 }
542 }
543 }
544
545 validation
546}
547
/// Verdict returned by [`validate_query`]: whether the query is allowed and, if
/// not, which rules it broke.
#[derive(Clone, Debug)]
pub struct QueryValidation {
    /// `true` when no violations were recorded.
    pub valid: bool,
    /// Human-readable description of each violated rule.
    pub violations: Vec<String>,
}
557
#[cfg(test)]
mod tests {
    use super::*;

    /// Config for `tenant_id` using row-level isolation on the `tenant_id` column.
    fn create_row_config(tenant_id: &str) -> TenantConfig {
        TenantConfig::builder()
            .id(tenant_id)
            .name("Test")
            .row_isolation("shared_db", "tenant_id")
            .build()
    }

    /// Config for tenant "acme" using schema isolation with the given schema.
    fn acme_schema_config(schema: &'static str) -> TenantConfig {
        TenantConfig::builder()
            .id("acme")
            .name("Acme")
            .schema_isolation("shared", schema)
            .build()
    }

    /// Transforms `sql` for tenant "acme" under row-level isolation.
    fn transform_for_acme(transformer: &TenantQueryTransformer, sql: &str) -> TransformResult {
        let tenant = TenantId::new("acme");
        let config = create_row_config("acme");
        transformer.transform(sql, &tenant, &config)
    }

    /// Transformer with only `users` registered on `tenant_id`.
    fn users_only() -> TenantQueryTransformer {
        TenantQueryTransformer::new().register_table("users", "tenant_id")
    }

    #[test]
    fn test_transform_select() {
        let transformer = users_only().register_table("orders", "tenant_id");
        let outcome = transform_for_acme(&transformer, "SELECT * FROM users WHERE active = true");

        assert!(outcome.transformed);
        assert!(outcome.query.contains("tenant_id = 'acme'"));
        assert!(outcome.query.contains("AND active = true"));
    }

    #[test]
    fn test_transform_select_no_where() {
        let outcome = transform_for_acme(&users_only(), "SELECT * FROM users ORDER BY id");

        assert!(outcome.transformed);
        assert!(outcome.query.contains("WHERE users.tenant_id = 'acme'"));
        assert!(outcome.query.contains("ORDER BY id"));
    }

    #[test]
    fn test_transform_update() {
        let outcome = transform_for_acme(&users_only(), "UPDATE users SET name = 'John' WHERE id = 1");

        assert!(outcome.transformed);
        assert!(outcome.query.contains("tenant_id = 'acme'"));
    }

    #[test]
    fn test_transform_delete() {
        let outcome = transform_for_acme(&users_only(), "DELETE FROM users WHERE id = 1");

        assert!(outcome.transformed);
        assert!(outcome.query.contains("tenant_id = 'acme'"));
    }

    #[test]
    fn test_no_transform_for_unregistered_table() {
        let outcome = transform_for_acme(&users_only(), "SELECT * FROM logs WHERE level = 'error'");

        assert!(!outcome.transformed);
    }

    #[test]
    fn test_no_transform_for_schema_isolation() {
        let tenant = TenantId::new("acme");
        let config = acme_schema_config("acme");

        let outcome = users_only().transform("SELECT * FROM users", &tenant, &config);

        assert!(!outcome.transformed);
    }

    #[test]
    fn test_excluded_tables() {
        let transformer = users_only()
            .register_table("audit_log", "tenant_id")
            .exclude_table("audit_log");

        let outcome = transform_for_acme(&transformer, "SELECT * FROM audit_log");

        assert!(!outcome.transformed);
    }

    #[test]
    fn test_extract_tables() {
        let transformer = TenantQueryTransformer::new();
        let queries = [
            "SELECT * FROM users u, orders o WHERE u.id = o.user_id",
            "SELECT * FROM users JOIN orders ON users.id = orders.user_id",
        ];

        for sql in &queries {
            let tables = transformer.extract_tables(sql);
            for expected in &["users", "orders"] {
                assert!(
                    tables.contains(&expected.to_string()),
                    "{} missing from {:?} for {}",
                    expected,
                    tables,
                    sql
                );
            }
        }
    }

    #[test]
    fn test_set_schema_search_path() {
        let tenant = TenantId::new("acme");
        let config = acme_schema_config("acme_schema");

        let statement = TenantQueryTransformer::new().set_schema_search_path(&tenant, &config);

        assert_eq!(statement, Some("SET search_path TO acme_schema".to_string()));
    }

    #[test]
    fn test_query_validation() {
        let tenant = TenantId::new("acme");
        let config = create_row_config("acme");
        let cases = [
            ("SELECT * FROM users", true),
            ("TRUNCATE users", false),
            ("DROP TABLE users", false),
        ];

        for &(sql, expected_valid) in &cases {
            let verdict = validate_query(sql, &tenant, &config);
            assert_eq!(verdict.valid, expected_valid, "unexpected verdict for {}", sql);
        }
    }

    #[test]
    fn test_schema_cross_access_validation() {
        let tenant = TenantId::new("acme");
        let config = acme_schema_config("acme");
        let cases = [
            ("SELECT * FROM acme.users", true),
            ("SELECT * FROM other_tenant.users", false),
            ("SELECT * FROM pg_catalog.pg_tables", true),
        ];

        for &(sql, expected_valid) in &cases {
            let verdict = validate_query(sql, &tenant, &config);
            assert_eq!(verdict.valid, expected_valid, "unexpected verdict for {}", sql);
        }
    }
}