1use crate::dialects::DialectType;
12use crate::expressions::*;
13use crate::schema::Schema;
14
/// Errors the table-isolation pass can describe.
///
/// NOTE(review): nothing in this file constructs this type. `isolate_table_selects`
/// returns a plain `Expression` (not a `Result`), and `maybe_wrap_table` leaves
/// tables without an alias unwrapped instead of reporting `MissingAlias`.
/// Confirm whether callers elsewhere rely on it before extending it.
#[derive(Debug, Clone, thiserror::Error)]
pub enum IsolateTableSelectsError {
    /// A table that needed isolating had no alias. The payload is presumably the
    /// offending table's name (TODO confirm against any construction site).
    #[error("Tables require an alias: {0}")]
    MissingAlias(String),
}
21
22pub fn isolate_table_selects(
45 expression: Expression,
46 schema: Option<&dyn Schema>,
47 _dialect: Option<DialectType>,
48) -> Expression {
49 match expression {
50 Expression::Select(select) => {
51 let transformed = isolate_select(*select, schema);
52 Expression::Select(Box::new(transformed))
53 }
54 Expression::Union(mut union) => {
55 let left = std::mem::replace(&mut union.left, Expression::Null(Null));
56 union.left = isolate_table_selects(left, schema, _dialect);
57 let right = std::mem::replace(&mut union.right, Expression::Null(Null));
58 union.right = isolate_table_selects(right, schema, _dialect);
59 Expression::Union(union)
60 }
61 Expression::Intersect(mut intersect) => {
62 let left = std::mem::replace(&mut intersect.left, Expression::Null(Null));
63 intersect.left = isolate_table_selects(left, schema, _dialect);
64 let right = std::mem::replace(&mut intersect.right, Expression::Null(Null));
65 intersect.right = isolate_table_selects(right, schema, _dialect);
66 Expression::Intersect(intersect)
67 }
68 Expression::Except(mut except) => {
69 let left = std::mem::replace(&mut except.left, Expression::Null(Null));
70 except.left = isolate_table_selects(left, schema, _dialect);
71 let right = std::mem::replace(&mut except.right, Expression::Null(Null));
72 except.right = isolate_table_selects(right, schema, _dialect);
73 Expression::Except(except)
74 }
75 other => other,
76 }
77}
78
79fn isolate_select(mut select: Select, schema: Option<&dyn Schema>) -> Select {
82 if let Some(ref mut with) = select.with {
84 for cte in &mut with.ctes {
85 cte.this = isolate_table_selects(cte.this.clone(), schema, None);
86 }
87 }
88
89 if let Some(ref mut from) = select.from {
91 for expr in &mut from.expressions {
92 if let Expression::Subquery(ref mut sq) = expr {
93 sq.this = isolate_table_selects(sq.this.clone(), schema, None);
94 }
95 }
96 }
97 for join in &mut select.joins {
98 if let Expression::Subquery(ref mut sq) = join.this {
99 sq.this = isolate_table_selects(sq.this.clone(), schema, None);
100 }
101 }
102
103 let source_count = count_sources(&select);
105
106 if source_count <= 1 {
108 return select;
109 }
110
111 if let Some(ref mut from) = select.from {
113 from.expressions = from
114 .expressions
115 .drain(..)
116 .map(|expr| maybe_wrap_table(expr, schema))
117 .collect();
118 }
119
120 for join in &mut select.joins {
122 join.this = maybe_wrap_table(join.this.clone(), schema);
123 }
124
125 select
126}
127
128fn count_sources(select: &Select) -> usize {
132 let from_count = select
133 .from
134 .as_ref()
135 .map(|f| f.expressions.len())
136 .unwrap_or(0);
137 let join_count = select.joins.len();
138 from_count + join_count
139}
140
141fn maybe_wrap_table(expression: Expression, schema: Option<&dyn Schema>) -> Expression {
151 match expression {
152 Expression::Table(ref table) => {
153 if let Some(s) = schema {
157 let table_name = full_table_name(table);
158 if s.column_names(&table_name).unwrap_or_default().is_empty() {
159 return expression;
160 }
161 }
162
163 let alias_name = match &table.alias {
167 Some(alias) if !alias.name.is_empty() => alias.name.clone(),
168 _ => return expression,
169 };
170
171 wrap_table_in_subquery(*table.clone(), &alias_name)
172 }
173 _ => expression,
174 }
175}
176
177fn wrap_table_in_subquery(table: TableRef, alias_name: &str) -> Expression {
182 let inner_select = Select::new()
184 .column(Expression::Star(Star {
185 table: None,
186 except: None,
187 replace: None,
188 rename: None,
189 trailing_comments: Vec::new(),
190 span: None,
191 }))
192 .from(Expression::Table(Box::new(table)));
193
194 Expression::Subquery(Box::new(Subquery {
196 this: Expression::Select(Box::new(inner_select)),
197 alias: Some(Identifier::new(alias_name)),
198 column_aliases: Vec::new(),
199 alias_explicit_as: false,
200 alias_keyword: None,
201 order_by: None,
202 limit: None,
203 offset: None,
204 distribute_by: None,
205 sort_by: None,
206 cluster_by: None,
207 lateral: false,
208 modifiers_inside: false,
209 trailing_comments: Vec::new(),
210 inferred_type: None,
211 }))
212}
213
214fn full_table_name(table: &TableRef) -> String {
219 let mut parts = Vec::new();
220 if let Some(ref catalog) = table.catalog {
221 parts.push(catalog.name.as_str());
222 }
223 if let Some(ref schema) = table.schema {
224 parts.push(schema.name.as_str());
225 }
226 parts.push(&table.name.name);
227 parts.join(".")
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;
    use crate::schema::MappingSchema;

    /// Parses `sql` and returns its first statement, panicking on parse failure.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    /// Renders `expr` back to SQL with the default generator.
    ///
    /// Named `render` rather than `gen`: `gen` is a reserved keyword from the
    /// Rust 2024 edition onwards.
    fn render(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parses `sql`, runs the pass with `schema`, and renders the result.
    fn isolate(sql: &str, schema: Option<&dyn Schema>) -> String {
        render(&isolate_table_selects(parse(sql), schema, None))
    }

    /// Builds a schema in which each of `tables` has a single `id INT` column.
    fn schema_with_int_id(tables: &[&str]) -> MappingSchema {
        let mut schema = MappingSchema::new();
        for &table in tables {
            schema
                .add_table(
                    table,
                    &[(
                        "id".to_string(),
                        DataType::Int {
                            length: None,
                            integer_spelling: false,
                        },
                    )],
                    None,
                )
                .unwrap();
        }
        schema
    }

    #[test]
    fn test_single_table_unchanged() {
        let output = isolate("SELECT * FROM t AS t", None);
        assert!(
            !output.contains("(SELECT"),
            "Single table should not be wrapped: {output}"
        );
    }

    #[test]
    fn test_single_subquery_unchanged() {
        let output = isolate("SELECT * FROM (SELECT 1) AS t", None);
        assert_eq!(
            output.matches("(SELECT").count(),
            1,
            "Single subquery source should not gain extra wrapping: {output}"
        );
    }

    #[test]
    fn test_two_tables_joined() {
        let output = isolate("SELECT * FROM a AS a JOIN b AS b ON a.id = b.id", None);
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "FROM table should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "JOIN table should be wrapped: {output}"
        );
    }

    #[test]
    fn test_table_with_join_subquery() {
        let output = isolate(
            "SELECT * FROM a AS a JOIN (SELECT * FROM b) AS b ON a.id = b.id",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Bare table should be wrapped: {output}"
        );
        assert_eq!(
            output.matches("(SELECT * FROM b)").count(),
            1,
            "Already-subquery source should not be double-wrapped: {output}"
        );
    }

    #[test]
    fn test_no_alias_not_wrapped() {
        let output = isolate("SELECT * FROM a JOIN b ON a.id = b.id", None);
        assert!(
            !output.contains("(SELECT * FROM a"),
            "Table without alias should not be wrapped: {output}"
        );
    }

    #[test]
    fn test_schema_known_table_wrapped() {
        let schema = schema_with_int_id(&["a", "b"]);
        let output = isolate(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id",
            Some(&schema),
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Known table 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "Known table 'b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_schema_unknown_table_not_wrapped() {
        // Only 'a' is registered; 'b' is unknown to the schema.
        let schema = schema_with_int_id(&["a"]);
        let output = isolate(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id",
            Some(&schema),
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Known table 'a' should be wrapped: {output}"
        );
        assert!(
            !output.contains("(SELECT * FROM b AS b) AS b"),
            "Unknown table 'b' should NOT be wrapped: {output}"
        );
    }

    #[test]
    fn test_cte_inner_query_processed() {
        let output = isolate(
            "WITH cte AS (SELECT * FROM x AS x JOIN y AS y ON x.id = y.id) SELECT * FROM cte AS c",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM x AS x) AS x"),
            "CTE inner table 'x' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM y AS y) AS y"),
            "CTE inner table 'y' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_nested_subquery_processed() {
        let output = isolate(
            "SELECT * FROM (SELECT * FROM a AS a JOIN b AS b ON a.id = b.id) AS sub",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Nested inner table 'a' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_join_subquery_processed() {
        let output = isolate(
            "SELECT * FROM x AS x JOIN (SELECT * FROM a AS a JOIN b AS b ON a.id = b.id) AS sub ON x.id = sub.id",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM x AS x) AS x"),
            "Outer table 'x' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "JOIN-subquery inner table 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "JOIN-subquery inner table 'b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_union_both_sides_processed() {
        let output = isolate(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id UNION ALL SELECT * FROM c AS c",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "UNION left side should be processed: {output}"
        );
        assert!(
            !output.contains("(SELECT * FROM c AS c) AS c"),
            "UNION right side (single source) should not be wrapped: {output}"
        );
    }

    #[test]
    fn test_intersect_both_sides_processed() {
        let output = isolate(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id INTERSECT SELECT * FROM c AS c JOIN d AS d ON c.id = d.id",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "INTERSECT left side should be processed: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM d AS d) AS d"),
            "INTERSECT right side should be processed: {output}"
        );
    }

    #[test]
    fn test_except_both_sides_processed() {
        let output = isolate(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id EXCEPT SELECT * FROM c AS c JOIN d AS d ON c.id = d.id",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "EXCEPT left side should be processed: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM d AS d) AS d"),
            "EXCEPT right side should be processed: {output}"
        );
    }

    #[test]
    fn test_cross_join() {
        let output = isolate("SELECT * FROM a AS a CROSS JOIN b AS b", None);
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "CROSS JOIN table 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "CROSS JOIN table 'b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_multiple_from_tables() {
        let output = isolate("SELECT * FROM a AS a, b AS b", None);
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Comma-join table 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "Comma-join table 'b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_three_way_join() {
        let output = isolate(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id JOIN c AS c ON b.id = c.id",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Three-way join: 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "Three-way join: 'b' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM c AS c) AS c"),
            "Three-way join: 'c' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_qualified_table_name_with_schema() {
        let schema = schema_with_int_id(&["mydb.a", "mydb.b"]);
        let output = isolate(
            "SELECT * FROM mydb.a AS a JOIN mydb.b AS b ON a.id = b.id",
            Some(&schema),
        );
        assert!(
            output.contains("(SELECT * FROM mydb.a AS a) AS a"),
            "Qualified table 'mydb.a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM mydb.b AS b) AS b"),
            "Qualified table 'mydb.b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_non_select_expression_unchanged() {
        let expr = parse("INSERT INTO t VALUES (1)");
        let original = render(&expr);
        let output = render(&isolate_table_selects(expr, None, None));
        assert_eq!(original, output, "Non-SELECT should be unchanged");
    }
}