1use std::collections::{HashMap, HashSet};
7
8use super::config::{IsolationStrategy, TenantConfig, TenantId};
9
/// Outcome of running a query through [`TenantQueryTransformer::transform`].
#[derive(Debug, Clone)]
pub struct TransformResult {
    /// The SQL text to execute (rewritten or original).
    pub query: String,
    /// Whether a tenant rewrite was applied to `query`.
    pub transformed: bool,
    /// Names of the tables that tenant scoping was applied to.
    pub filtered_tables: Vec<String>,
    /// Non-fatal notes produced while rewriting.
    pub warnings: Vec<String>,
}

impl TransformResult {
    /// Wraps `query` as-is: not transformed, no filtered tables, no warnings.
    pub fn passthrough(query: impl Into<String>) -> Self {
        TransformResult {
            query: query.into(),
            transformed: false,
            filtered_tables: vec![],
            warnings: vec![],
        }
    }

    /// Wraps a rewritten `query` together with the `tables` it was scoped on.
    pub fn transformed(query: impl Into<String>, tables: Vec<String>) -> Self {
        let mut result = Self::passthrough(query);
        result.transformed = true;
        result.filtered_tables = tables;
        result
    }

    /// Returns `self` with `warning` appended to [`TransformResult::warnings`].
    pub fn with_warning(self, warning: impl Into<String>) -> Self {
        let TransformResult {
            query,
            transformed,
            filtered_tables,
            mut warnings,
        } = self;
        warnings.push(warning.into());
        TransformResult {
            query,
            transformed,
            filtered_tables,
            warnings,
        }
    }
}
53
/// Rewrites SQL so that statements touching tenant-scoped tables only see the current
/// tenant's rows (row-level isolation).
///
/// Built with the consuming builder methods on the `impl` block; deriving `Clone` lets a
/// base configuration be copied and extended per use site.
#[derive(Debug, Clone)]
pub struct TenantQueryTransformer {
    /// Lower-cased table name -> name of its tenant column.
    tenant_tables: HashMap<String, String>,
    /// Lower-cased table names that are never filtered, even if registered.
    excluded_tables: HashSet<String>,
    /// Emit `$1` placeholders instead of inlined tenant literals.
    use_parameters: bool,
    /// Optional predicate template with `{column}` / `{value}` placeholders.
    filter_template: Option<String>,
}
68
69impl Default for TenantQueryTransformer {
70 fn default() -> Self {
71 Self::new()
72 }
73}
74
75impl TenantQueryTransformer {
76 pub fn new() -> Self {
78 Self {
79 tenant_tables: HashMap::new(),
80 excluded_tables: HashSet::new(),
81 use_parameters: false,
82 filter_template: None,
83 }
84 }
85
86 pub fn register_table(mut self, table: impl Into<String>, column: impl Into<String>) -> Self {
88 self.tenant_tables
89 .insert(table.into().to_lowercase(), column.into());
90 self
91 }
92
93 pub fn register_tables(mut self, tables: &[&str], column: impl Into<String>) -> Self {
95 let col = column.into();
96 for table in tables {
97 self.tenant_tables.insert(table.to_lowercase(), col.clone());
98 }
99 self
100 }
101
102 pub fn exclude_table(mut self, table: impl Into<String>) -> Self {
104 self.excluded_tables.insert(table.into().to_lowercase());
105 self
106 }
107
108 pub fn with_parameters(mut self) -> Self {
110 self.use_parameters = true;
111 self
112 }
113
114 pub fn with_filter_template(mut self, template: impl Into<String>) -> Self {
116 self.filter_template = Some(template.into());
117 self
118 }
119
120 pub fn get_tenant_column(&self, table: &str) -> Option<&str> {
122 self.tenant_tables
123 .get(&table.to_lowercase())
124 .map(|s| s.as_str())
125 }
126
127 pub fn requires_filtering(&self, table: &str) -> bool {
129 let lower = table.to_lowercase();
130 self.tenant_tables.contains_key(&lower) && !self.excluded_tables.contains(&lower)
131 }
132
133 pub fn transform(
135 &self,
136 query: &str,
137 tenant: &TenantId,
138 config: &TenantConfig,
139 ) -> TransformResult {
140 let tenant_column = match &config.isolation {
142 IsolationStrategy::Row { tenant_column, .. } => tenant_column,
143 _ => return TransformResult::passthrough(query),
144 };
145
146 let upper = query.trim().to_uppercase();
148
149 if upper.starts_with("SELECT") {
150 self.transform_select(query, tenant, tenant_column)
151 } else if upper.starts_with("UPDATE") {
152 self.transform_update(query, tenant, tenant_column)
153 } else if upper.starts_with("DELETE") {
154 self.transform_delete(query, tenant, tenant_column)
155 } else if upper.starts_with("INSERT") {
156 self.transform_insert(query, tenant, tenant_column)
157 } else {
158 TransformResult::passthrough(query)
159 }
160 }
161
162 fn transform_select(
164 &self,
165 query: &str,
166 tenant: &TenantId,
167 tenant_column: &str,
168 ) -> TransformResult {
169 let tables = self.extract_tables(query);
170 let filtered_tables: Vec<String> = tables
171 .iter()
172 .filter(|t| self.requires_filtering(t))
173 .cloned()
174 .collect();
175
176 if filtered_tables.is_empty() {
177 return TransformResult::passthrough(query);
178 }
179
180 let filter = self.build_filter(tenant, tenant_column, &filtered_tables);
181 let transformed = self.inject_where_clause(query, &filter);
182
183 TransformResult::transformed(transformed, filtered_tables)
184 }
185
186 fn transform_update(
188 &self,
189 query: &str,
190 tenant: &TenantId,
191 tenant_column: &str,
192 ) -> TransformResult {
193 let table = self.extract_update_table(query);
194
195 if let Some(table) = table {
196 if self.requires_filtering(&table) {
197 let filter = self.build_single_filter(tenant, tenant_column);
198 let transformed = self.inject_where_clause(query, &filter);
199 return TransformResult::transformed(transformed, vec![table]);
200 }
201 }
202
203 TransformResult::passthrough(query)
204 }
205
206 fn transform_delete(
208 &self,
209 query: &str,
210 tenant: &TenantId,
211 tenant_column: &str,
212 ) -> TransformResult {
213 let table = self.extract_delete_table(query);
214
215 if let Some(table) = table {
216 if self.requires_filtering(&table) {
217 let filter = self.build_single_filter(tenant, tenant_column);
218 let transformed = self.inject_where_clause(query, &filter);
219 return TransformResult::transformed(transformed, vec![table]);
220 }
221 }
222
223 TransformResult::passthrough(query)
224 }
225
226 fn transform_insert(
228 &self,
229 query: &str,
230 tenant: &TenantId,
231 tenant_column: &str,
232 ) -> TransformResult {
233 let table = self.extract_insert_table(query);
234
235 if let Some(table) = table {
236 if self.requires_filtering(&table) {
237 let transformed = self.inject_tenant_value(query, tenant, tenant_column);
239 return TransformResult::transformed(transformed, vec![table])
240 .with_warning("Tenant column injection may require schema awareness");
241 }
242 }
243
244 TransformResult::passthrough(query)
245 }
246
247 fn build_filter(&self, tenant: &TenantId, default_column: &str, tables: &[String]) -> String {
249 let filters: Vec<String> = tables
250 .iter()
251 .map(|table| {
252 let column = self.get_tenant_column(table).unwrap_or(default_column);
253 if self.use_parameters {
254 format!("{}.{} = $1", table, column)
255 } else {
256 format!("{}.{} = '{}'", table, column, tenant.0)
257 }
258 })
259 .collect();
260
261 filters.join(" AND ")
262 }
263
264 fn build_single_filter(&self, tenant: &TenantId, column: &str) -> String {
266 if self.use_parameters {
267 format!("{} = $1", column)
268 } else {
269 match &self.filter_template {
270 Some(template) => template
271 .replace("{column}", column)
272 .replace("{value}", &tenant.0),
273 None => format!("{} = '{}'", column, tenant.0),
274 }
275 }
276 }
277
278 fn inject_where_clause(&self, query: &str, filter: &str) -> String {
280 let upper = query.to_uppercase();
281
282 if let Some(where_pos) = upper.find(" WHERE ") {
284 let (before, after) = query.split_at(where_pos + 7);
286 format!("{}{} AND {}", before, filter, after)
287 } else {
288 let insert_before = [" ORDER ", " GROUP ", " LIMIT ", " HAVING ", " UNION "]
291 .iter()
292 .filter_map(|kw| upper.find(kw))
293 .min();
294
295 match insert_before {
296 Some(pos) => {
297 let (before, after) = query.split_at(pos);
298 format!("{} WHERE {}{}", before, filter, after)
299 }
300 None => {
301 format!("{} WHERE {}", query.trim_end_matches(';'), filter)
303 }
304 }
305 }
306 }
307
308 fn inject_tenant_value(&self, query: &str, tenant: &TenantId, column: &str) -> String {
310 let upper = query.to_uppercase();
312
313 if let Some(values_pos) = upper.find(" VALUES ") {
314 if let Some(paren_pos) = query[values_pos..].find('(') {
315 let insert_pos = values_pos + paren_pos + 1;
316
317 if let Some(cols_start) = upper.find('(') {
319 if cols_start < values_pos {
320 let cols_end = upper[cols_start..].find(')').unwrap_or(0) + cols_start;
322 let before_cols_end = &query[..cols_end];
323 let after_cols_end = &query[cols_end..];
324
325 let with_column =
327 format!("{}, {}{}", before_cols_end, column, after_cols_end);
328
329 let upper_new = with_column.to_uppercase();
331 if let Some(new_values_pos) = upper_new.find(" VALUES ") {
332 if let Some(new_paren_pos) = with_column[new_values_pos..].find('(') {
333 let new_insert_pos = new_values_pos + new_paren_pos + 1;
334 let before = &with_column[..new_insert_pos];
335 let after = &with_column[new_insert_pos..];
336 return format!("{}'{}'", before, tenant.0)
337 + if !after.starts_with(')') { ", " } else { "" }
338 + after;
339 }
340 }
341 }
342 }
343
344 let before = &query[..insert_pos];
346 let after = &query[insert_pos..];
347 return format!("{}'{}'", before, tenant.0)
348 + if !after.starts_with(')') { ", " } else { "" }
349 + after;
350 }
351 }
352
353 query.to_string()
354 }
355
356 fn extract_tables(&self, query: &str) -> Vec<String> {
358 let upper = query.to_uppercase();
359 let mut tables = Vec::new();
360
361 if let Some(from_pos) = upper.find(" FROM ") {
363 let after_from = &query[from_pos + 6..];
364
365 let end_markers = [
367 " WHERE ", " JOIN ", " LEFT ", " RIGHT ", " INNER ", " OUTER ", " GROUP ",
368 " ORDER ", " LIMIT ", " HAVING ",
369 ];
370 let end_pos = end_markers
371 .iter()
372 .filter_map(|m| after_from.to_uppercase().find(m))
373 .min()
374 .unwrap_or(after_from.len());
375
376 let table_section = &after_from[..end_pos];
377
378 for part in table_section.split(',') {
380 let trimmed = part.trim();
381 if let Some(table) = trimmed.split_whitespace().next() {
382 let clean =
383 table.trim_matches(|c| c == '"' || c == '`' || c == '[' || c == ']');
384 if !clean.is_empty() {
385 tables.push(clean.to_string());
386 }
387 }
388 }
389 }
390
391 let words: Vec<&str> = query.split_whitespace().collect();
393 for (i, word) in words.iter().enumerate() {
394 if word.to_uppercase() == "JOIN" && i + 1 < words.len() {
395 let table =
396 words[i + 1].trim_matches(|c| c == '"' || c == '`' || c == '[' || c == ']');
397 if !table.is_empty() && !tables.contains(&table.to_string()) {
398 tables.push(table.to_string());
399 }
400 }
401 }
402
403 tables
404 }
405
406 fn extract_update_table(&self, query: &str) -> Option<String> {
408 let upper = query.to_uppercase();
409 if let Some(update_pos) = upper.find("UPDATE ") {
410 let after_update = &query[update_pos + 7..];
411 if let Some(set_pos) = after_update.to_uppercase().find(" SET ") {
412 let table_section = &after_update[..set_pos];
413 let table = table_section
414 .split_whitespace()
415 .next()?
416 .trim_matches(|c| c == '"' || c == '`');
417 return Some(table.to_string());
418 }
419 }
420 None
421 }
422
423 fn extract_delete_table(&self, query: &str) -> Option<String> {
425 let upper = query.to_uppercase();
426 if let Some(from_pos) = upper.find(" FROM ") {
427 let after_from = &query[from_pos + 6..];
428 let end_pos = after_from
429 .to_uppercase()
430 .find(" WHERE ")
431 .unwrap_or(after_from.len());
432 let table_section = &after_from[..end_pos];
433 let table = table_section
434 .split_whitespace()
435 .next()?
436 .trim_matches(|c| c == '"' || c == '`');
437 return Some(table.to_string());
438 }
439 None
440 }
441
442 fn extract_insert_table(&self, query: &str) -> Option<String> {
444 let upper = query.to_uppercase();
445 if let Some(into_pos) = upper.find(" INTO ") {
446 let after_into = &query[into_pos + 6..];
447 let end_pos = after_into
448 .find(|c: char| c == '(' || c.is_whitespace())
449 .unwrap_or(after_into.len());
450 let table = after_into[..end_pos]
451 .trim()
452 .trim_matches(|c| c == '"' || c == '`');
453 return Some(table.to_string());
454 }
455 None
456 }
457
458 pub fn set_schema_search_path(
460 &self,
461 _tenant: &TenantId,
462 config: &TenantConfig,
463 ) -> Option<String> {
464 if let IsolationStrategy::Schema { schema_name, .. } = &config.isolation {
465 Some(format!("SET search_path TO {}", schema_name))
466 } else {
467 None
468 }
469 }
470
471 pub fn use_database(&self, _tenant: &TenantId, config: &TenantConfig) -> Option<String> {
473 if let IsolationStrategy::Database { database_name } = &config.isolation {
474 Some(format!("USE {}", database_name))
475 } else {
476 None
477 }
478 }
479}
480
481pub fn validate_query(query: &str, _tenant: &TenantId, config: &TenantConfig) -> QueryValidation {
483 let mut validation = QueryValidation {
484 valid: true,
485 violations: Vec::new(),
486 };
487
488 let upper = query.to_uppercase();
489
490 if let IsolationStrategy::Row { tenant_column, .. } = &config.isolation {
492 if upper.contains(&format!("{} =", tenant_column.to_uppercase())) {
494 let set_pattern = format!("SET {} =", tenant_column.to_uppercase());
495 if upper.contains(&set_pattern) {
496 validation.valid = false;
497 validation
498 .violations
499 .push(format!("Cannot modify tenant column: {}", tenant_column));
500 }
501 }
502
503 if upper.starts_with("TRUNCATE ") {
505 validation.valid = false;
506 validation
507 .violations
508 .push("TRUNCATE not allowed with row-level isolation".to_string());
509 }
510
511 if upper.contains("DROP TABLE") {
513 validation.valid = false;
514 validation
515 .violations
516 .push("DROP TABLE not allowed with row-level isolation".to_string());
517 }
518 }
519
520 if let IsolationStrategy::Schema { schema_name, .. } = &config.isolation {
522 let parts: Vec<&str> = upper.split_whitespace().collect();
524 for part in parts {
525 if part.contains('.') && !part.starts_with(&schema_name.to_uppercase()) {
526 let schema = part.split('.').next().unwrap_or("");
527 if !schema.eq_ignore_ascii_case("pg_catalog")
528 && !schema.eq_ignore_ascii_case("information_schema")
529 {
530 validation.valid = false;
531 validation
532 .violations
533 .push(format!("Cross-schema access not allowed: {}", part));
534 }
535 }
536 }
537 }
538
539 validation
540}
541
/// Verdict returned by `validate_query`.
#[derive(Clone, Debug)]
pub struct QueryValidation {
    /// `false` as soon as at least one isolation rule was violated.
    pub valid: bool,
    /// One human-readable message per violated rule, in detection order.
    pub violations: Vec<String>,
}
551
#[cfg(test)]
mod tests {
    use super::*;

    /// Row-isolated config on `shared_db`, keyed by `tenant_id`.
    fn create_row_config(tenant_id: &str) -> TenantConfig {
        TenantConfig::builder()
            .id(tenant_id)
            .name("Test")
            .row_isolation("shared_db", "tenant_id")
            .build()
    }

    /// Schema-isolated config for tenant `acme` using `schema`.
    fn acme_schema_config(schema: &str) -> TenantConfig {
        TenantConfig::builder()
            .id("acme")
            .name("Acme")
            .schema_isolation("shared", schema)
            .build()
    }

    /// Transformer with only `users` registered on `tenant_id`.
    fn users_only() -> TenantQueryTransformer {
        TenantQueryTransformer::new().register_table("users", "tenant_id")
    }

    /// Runs `sql` through `transformer` as tenant `acme` under row isolation.
    fn transform_as_acme(transformer: &TenantQueryTransformer, sql: &str) -> TransformResult {
        transformer.transform(sql, &TenantId::new("acme"), &create_row_config("acme"))
    }

    #[test]
    fn test_transform_select() {
        let transformer = users_only().register_table("orders", "tenant_id");
        let result = transform_as_acme(&transformer, "SELECT * FROM users WHERE active = true");

        assert!(result.transformed);
        assert!(result.query.contains("tenant_id = 'acme'"));
        assert!(result.query.contains("AND active = true"));
    }

    #[test]
    fn test_transform_select_no_where() {
        let result = transform_as_acme(&users_only(), "SELECT * FROM users ORDER BY id");

        assert!(result.transformed);
        assert!(result.query.contains("WHERE users.tenant_id = 'acme'"));
        assert!(result.query.contains("ORDER BY id"));
    }

    #[test]
    fn test_transform_update() {
        let result = transform_as_acme(
            &users_only(),
            "UPDATE users SET name = 'John' WHERE id = 1",
        );

        assert!(result.transformed);
        assert!(result.query.contains("tenant_id = 'acme'"));
    }

    #[test]
    fn test_transform_delete() {
        let result = transform_as_acme(&users_only(), "DELETE FROM users WHERE id = 1");

        assert!(result.transformed);
        assert!(result.query.contains("tenant_id = 'acme'"));
    }

    #[test]
    fn test_no_transform_for_unregistered_table() {
        let result = transform_as_acme(&users_only(), "SELECT * FROM logs WHERE level = 'error'");
        assert!(!result.transformed);
    }

    #[test]
    fn test_no_transform_for_schema_isolation() {
        let tenant = TenantId::new("acme");
        let config = acme_schema_config("acme");

        let result = users_only().transform("SELECT * FROM users", &tenant, &config);
        assert!(!result.transformed);
    }

    #[test]
    fn test_excluded_tables() {
        let transformer = users_only()
            .register_table("audit_log", "tenant_id")
            .exclude_table("audit_log");

        let result = transform_as_acme(&transformer, "SELECT * FROM audit_log");
        assert!(!result.transformed);
    }

    #[test]
    fn test_extract_tables() {
        let transformer = TenantQueryTransformer::new();
        let queries = [
            "SELECT * FROM users u, orders o WHERE u.id = o.user_id",
            "SELECT * FROM users JOIN orders ON users.id = orders.user_id",
        ];

        for sql in queries.iter() {
            let tables = transformer.extract_tables(sql);
            assert!(tables.contains(&"users".to_string()), "{}", sql);
            assert!(tables.contains(&"orders".to_string()), "{}", sql);
        }
    }

    #[test]
    fn test_set_schema_search_path() {
        let transformer = TenantQueryTransformer::new();
        let tenant = TenantId::new("acme");
        let config = acme_schema_config("acme_schema");

        assert_eq!(
            transformer.set_schema_search_path(&tenant, &config),
            Some("SET search_path TO acme_schema".to_string())
        );
    }

    #[test]
    fn test_query_validation() {
        let tenant = TenantId::new("acme");
        let config = create_row_config("acme");
        let cases = [
            ("SELECT * FROM users", true),
            ("TRUNCATE users", false),
            ("DROP TABLE users", false),
        ];

        for &(sql, expected) in cases.iter() {
            assert_eq!(validate_query(sql, &tenant, &config).valid, expected, "{}", sql);
        }
    }

    #[test]
    fn test_schema_cross_access_validation() {
        let tenant = TenantId::new("acme");
        let config = acme_schema_config("acme");
        let cases = [
            ("SELECT * FROM acme.users", true),
            ("SELECT * FROM other_tenant.users", false),
            ("SELECT * FROM pg_catalog.pg_tables", true),
        ];

        for &(sql, expected) in cases.iter() {
            assert_eq!(validate_query(sql, &tenant, &config).valid, expected, "{}", sql);
        }
    }
}