1use crate::dialects::DialectType;
12use crate::expressions::*;
13use crate::schema::Schema;
14
15#[derive(Debug, Clone, thiserror::Error)]
17pub enum IsolateTableSelectsError {
18 #[error("Tables require an alias: {0}")]
19 MissingAlias(String),
20}
21
22pub fn isolate_table_selects(
45 expression: Expression,
46 schema: Option<&dyn Schema>,
47 _dialect: Option<DialectType>,
48) -> Expression {
49 match expression {
50 Expression::Select(select) => {
51 let transformed = isolate_select(*select, schema);
52 Expression::Select(Box::new(transformed))
53 }
54 Expression::Union(mut union) => {
55 union.left = isolate_table_selects(union.left, schema, _dialect);
56 union.right = isolate_table_selects(union.right, schema, _dialect);
57 Expression::Union(union)
58 }
59 Expression::Intersect(mut intersect) => {
60 intersect.left = isolate_table_selects(intersect.left, schema, _dialect);
61 intersect.right = isolate_table_selects(intersect.right, schema, _dialect);
62 Expression::Intersect(intersect)
63 }
64 Expression::Except(mut except) => {
65 except.left = isolate_table_selects(except.left, schema, _dialect);
66 except.right = isolate_table_selects(except.right, schema, _dialect);
67 Expression::Except(except)
68 }
69 other => other,
70 }
71}
72
73fn isolate_select(mut select: Select, schema: Option<&dyn Schema>) -> Select {
76 if let Some(ref mut with) = select.with {
78 for cte in &mut with.ctes {
79 cte.this = isolate_table_selects(cte.this.clone(), schema, None);
80 }
81 }
82
83 if let Some(ref mut from) = select.from {
85 for expr in &mut from.expressions {
86 if let Expression::Subquery(ref mut sq) = expr {
87 sq.this = isolate_table_selects(sq.this.clone(), schema, None);
88 }
89 }
90 }
91 for join in &mut select.joins {
92 if let Expression::Subquery(ref mut sq) = join.this {
93 sq.this = isolate_table_selects(sq.this.clone(), schema, None);
94 }
95 }
96
97 let source_count = count_sources(&select);
99
100 if source_count <= 1 {
102 return select;
103 }
104
105 if let Some(ref mut from) = select.from {
107 from.expressions = from
108 .expressions
109 .drain(..)
110 .map(|expr| maybe_wrap_table(expr, schema))
111 .collect();
112 }
113
114 for join in &mut select.joins {
116 join.this = maybe_wrap_table(join.this.clone(), schema);
117 }
118
119 select
120}
121
122fn count_sources(select: &Select) -> usize {
126 let from_count = select
127 .from
128 .as_ref()
129 .map(|f| f.expressions.len())
130 .unwrap_or(0);
131 let join_count = select.joins.len();
132 from_count + join_count
133}
134
135fn maybe_wrap_table(expression: Expression, schema: Option<&dyn Schema>) -> Expression {
145 match expression {
146 Expression::Table(ref table) => {
147 if let Some(s) = schema {
151 let table_name = full_table_name(table);
152 if s.column_names(&table_name).unwrap_or_default().is_empty() {
153 return expression;
154 }
155 }
156
157 let alias_name = match &table.alias {
161 Some(alias) if !alias.name.is_empty() => alias.name.clone(),
162 _ => return expression,
163 };
164
165 wrap_table_in_subquery(table.clone(), &alias_name)
166 }
167 _ => expression,
168 }
169}
170
171fn wrap_table_in_subquery(table: TableRef, alias_name: &str) -> Expression {
176 let inner_select = Select::new()
178 .column(Expression::Star(Star {
179 table: None,
180 except: None,
181 replace: None,
182 rename: None,
183 trailing_comments: Vec::new(),
184 span: None,
185 }))
186 .from(Expression::Table(table));
187
188 Expression::Subquery(Box::new(Subquery {
190 this: Expression::Select(Box::new(inner_select)),
191 alias: Some(Identifier::new(alias_name)),
192 column_aliases: Vec::new(),
193 order_by: None,
194 limit: None,
195 offset: None,
196 distribute_by: None,
197 sort_by: None,
198 cluster_by: None,
199 lateral: false,
200 modifiers_inside: false,
201 trailing_comments: Vec::new(),
202 inferred_type: None,
203 }))
204}
205
206fn full_table_name(table: &TableRef) -> String {
211 let mut parts = Vec::new();
212 if let Some(ref catalog) = table.catalog {
213 parts.push(catalog.name.as_str());
214 }
215 if let Some(ref schema) = table.schema {
216 parts.push(schema.name.as_str());
217 }
218 parts.push(&table.name.name);
219 parts.join(".")
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;
    use crate::schema::MappingSchema;

    /// Parses `sql` and returns its first statement.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    /// Renders an expression back to SQL with the default generator.
    ///
    /// Deliberately not named `gen`: that identifier is a reserved keyword from the
    /// Rust 2024 edition onward.
    fn to_sql(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parses `sql`, runs the pass with `schema` (and no dialect), and renders the result.
    fn isolate(sql: &str, schema: Option<&dyn Schema>) -> String {
        to_sql(&isolate_table_selects(parse(sql), schema, None))
    }

    /// Builds a schema in which every listed table has a single `id` column of type INT.
    fn int_id_schema(tables: &[&str]) -> MappingSchema {
        let mut schema = MappingSchema::new();
        for &table in tables {
            schema
                .add_table(
                    table,
                    &[(
                        "id".to_string(),
                        DataType::Int {
                            length: None,
                            integer_spelling: false,
                        },
                    )],
                    None,
                )
                .unwrap();
        }
        schema
    }

    #[test]
    fn test_single_table_unchanged() {
        let output = isolate("SELECT * FROM t AS t", None);
        assert!(
            !output.contains("(SELECT"),
            "Single table should not be wrapped: {output}"
        );
    }

    #[test]
    fn test_single_subquery_unchanged() {
        let output = isolate("SELECT * FROM (SELECT 1) AS t", None);
        assert_eq!(
            output.matches("(SELECT").count(),
            1,
            "Single subquery source should not gain extra wrapping: {output}"
        );
    }

    #[test]
    fn test_two_tables_joined() {
        let output = isolate("SELECT * FROM a AS a JOIN b AS b ON a.id = b.id", None);
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "FROM table should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "JOIN table should be wrapped: {output}"
        );
    }

    #[test]
    fn test_table_with_join_subquery() {
        let output = isolate(
            "SELECT * FROM a AS a JOIN (SELECT * FROM b) AS b ON a.id = b.id",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Bare table should be wrapped: {output}"
        );
        assert_eq!(
            output.matches("(SELECT * FROM b)").count(),
            1,
            "Already-subquery source should not be double-wrapped: {output}"
        );
    }

    #[test]
    fn test_no_alias_not_wrapped() {
        let output = isolate("SELECT * FROM a JOIN b ON a.id = b.id", None);
        assert!(
            !output.contains("(SELECT * FROM a"),
            "Table without alias should not be wrapped: {output}"
        );
    }

    #[test]
    fn test_schema_known_table_wrapped() {
        let schema = int_id_schema(&["a", "b"]);
        let output = isolate(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id",
            Some(&schema),
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Known table 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "Known table 'b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_schema_unknown_table_not_wrapped() {
        // Only `a` is registered; `b` must be treated as unknown.
        let schema = int_id_schema(&["a"]);
        let output = isolate(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id",
            Some(&schema),
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Known table 'a' should be wrapped: {output}"
        );
        assert!(
            !output.contains("(SELECT * FROM b AS b) AS b"),
            "Unknown table 'b' should NOT be wrapped: {output}"
        );
    }

    #[test]
    fn test_cte_inner_query_processed() {
        let output = isolate(
            "WITH cte AS (SELECT * FROM x AS x JOIN y AS y ON x.id = y.id) SELECT * FROM cte AS c",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM x AS x) AS x"),
            "CTE inner table 'x' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM y AS y) AS y"),
            "CTE inner table 'y' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_nested_subquery_processed() {
        let output = isolate(
            "SELECT * FROM (SELECT * FROM a AS a JOIN b AS b ON a.id = b.id) AS sub",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Nested inner table 'a' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_union_both_sides_processed() {
        let output = isolate(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id UNION ALL SELECT * FROM c AS c",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "UNION left side should be processed: {output}"
        );
        assert!(
            !output.contains("(SELECT * FROM c AS c) AS c"),
            "UNION right side (single source) should not be wrapped: {output}"
        );
    }

    #[test]
    fn test_cross_join() {
        let output = isolate("SELECT * FROM a AS a CROSS JOIN b AS b", None);
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "CROSS JOIN table 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "CROSS JOIN table 'b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_multiple_from_tables() {
        let output = isolate("SELECT * FROM a AS a, b AS b", None);
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Comma-join table 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "Comma-join table 'b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_three_way_join() {
        let output = isolate(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id JOIN c AS c ON b.id = c.id",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Three-way join: 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "Three-way join: 'b' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM c AS c) AS c"),
            "Three-way join: 'c' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_qualified_table_name_with_schema() {
        let schema = int_id_schema(&["mydb.a", "mydb.b"]);
        let output = isolate(
            "SELECT * FROM mydb.a AS a JOIN mydb.b AS b ON a.id = b.id",
            Some(&schema),
        );
        assert!(
            output.contains("(SELECT * FROM mydb.a AS a) AS a"),
            "Qualified table 'mydb.a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM mydb.b AS b) AS b"),
            "Qualified table 'mydb.b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_non_select_expression_unchanged() {
        let expr = parse("INSERT INTO t VALUES (1)");
        let original = to_sql(&expr);
        let output = to_sql(&isolate_table_selects(expr, None, None));
        assert_eq!(original, output, "Non-SELECT should be unchanged");
    }
}