1use crate::dialects::DialectType;
12use crate::expressions::*;
13use crate::schema::Schema;
14
15#[derive(Debug, Clone, thiserror::Error)]
17pub enum IsolateTableSelectsError {
18 #[error("Tables require an alias: {0}")]
19 MissingAlias(String),
20}
21
22pub fn isolate_table_selects(
45 expression: Expression,
46 schema: Option<&dyn Schema>,
47 _dialect: Option<DialectType>,
48) -> Expression {
49 match expression {
50 Expression::Select(select) => {
51 let transformed = isolate_select(*select, schema);
52 Expression::Select(Box::new(transformed))
53 }
54 Expression::Union(mut union) => {
55 union.left = isolate_table_selects(union.left, schema, _dialect);
56 union.right = isolate_table_selects(union.right, schema, _dialect);
57 Expression::Union(union)
58 }
59 Expression::Intersect(mut intersect) => {
60 intersect.left = isolate_table_selects(intersect.left, schema, _dialect);
61 intersect.right = isolate_table_selects(intersect.right, schema, _dialect);
62 Expression::Intersect(intersect)
63 }
64 Expression::Except(mut except) => {
65 except.left = isolate_table_selects(except.left, schema, _dialect);
66 except.right = isolate_table_selects(except.right, schema, _dialect);
67 Expression::Except(except)
68 }
69 other => other,
70 }
71}
72
73fn isolate_select(mut select: Select, schema: Option<&dyn Schema>) -> Select {
76 if let Some(ref mut with) = select.with {
78 for cte in &mut with.ctes {
79 cte.this = isolate_table_selects(cte.this.clone(), schema, None);
80 }
81 }
82
83 if let Some(ref mut from) = select.from {
85 for expr in &mut from.expressions {
86 if let Expression::Subquery(ref mut sq) = expr {
87 sq.this = isolate_table_selects(sq.this.clone(), schema, None);
88 }
89 }
90 }
91 for join in &mut select.joins {
92 if let Expression::Subquery(ref mut sq) = join.this {
93 sq.this = isolate_table_selects(sq.this.clone(), schema, None);
94 }
95 }
96
97 let source_count = count_sources(&select);
99
100 if source_count <= 1 {
102 return select;
103 }
104
105 if let Some(ref mut from) = select.from {
107 from.expressions = from
108 .expressions
109 .drain(..)
110 .map(|expr| maybe_wrap_table(expr, schema))
111 .collect();
112 }
113
114 for join in &mut select.joins {
116 join.this = maybe_wrap_table(join.this.clone(), schema);
117 }
118
119 select
120}
121
122fn count_sources(select: &Select) -> usize {
126 let from_count = select
127 .from
128 .as_ref()
129 .map(|f| f.expressions.len())
130 .unwrap_or(0);
131 let join_count = select.joins.len();
132 from_count + join_count
133}
134
135fn maybe_wrap_table(expression: Expression, schema: Option<&dyn Schema>) -> Expression {
145 match expression {
146 Expression::Table(ref table) => {
147 if let Some(s) = schema {
151 let table_name = full_table_name(table);
152 if s.column_names(&table_name).unwrap_or_default().is_empty() {
153 return expression;
154 }
155 }
156
157 let alias_name = match &table.alias {
161 Some(alias) if !alias.name.is_empty() => alias.name.clone(),
162 _ => return expression,
163 };
164
165 wrap_table_in_subquery(*table.clone(), &alias_name)
166
167 }
168 _ => expression,
169 }
170}
171
172fn wrap_table_in_subquery(table: TableRef, alias_name: &str) -> Expression {
177 let inner_select = Select::new()
179 .column(Expression::Star(Star {
180 table: None,
181 except: None,
182 replace: None,
183 rename: None,
184 trailing_comments: Vec::new(),
185 span: None,
186 }))
187 .from(Expression::Table(Box::new(table)));
188
189 Expression::Subquery(Box::new(Subquery {
191 this: Expression::Select(Box::new(inner_select)),
192 alias: Some(Identifier::new(alias_name)),
193 column_aliases: Vec::new(),
194 order_by: None,
195 limit: None,
196 offset: None,
197 distribute_by: None,
198 sort_by: None,
199 cluster_by: None,
200 lateral: false,
201 modifiers_inside: false,
202 trailing_comments: Vec::new(),
203 inferred_type: None,
204 }))
205}
206
207fn full_table_name(table: &TableRef) -> String {
212 let mut parts = Vec::new();
213 if let Some(ref catalog) = table.catalog {
214 parts.push(catalog.name.as_str());
215 }
216 if let Some(ref schema) = table.schema {
217 parts.push(schema.name.as_str());
218 }
219 parts.push(&table.name.name);
220 parts.join(".")
221}
222
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;
    use crate::schema::MappingSchema;

    /// Parses `sql` and returns its first statement.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    /// Renders `expr` back to SQL with the default generator.
    fn gen(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parses `sql`, runs the isolation pass with `schema`, and renders the result.
    fn isolate(sql: &str, schema: Option<&dyn Schema>) -> String {
        gen(&isolate_table_selects(parse(sql), schema, None))
    }

    /// Builds a schema in which every listed table has a single `id INT` column.
    fn schema_with_int_id(tables: &[&str]) -> MappingSchema {
        let mut schema = MappingSchema::new();
        for &table in tables {
            schema
                .add_table(
                    table,
                    &[(
                        "id".to_string(),
                        DataType::Int {
                            length: None,
                            integer_spelling: false,
                        },
                    )],
                    None,
                )
                .unwrap();
        }
        schema
    }

    #[test]
    fn test_single_table_unchanged() {
        let output = isolate("SELECT * FROM t AS t", None);
        assert!(
            !output.contains("(SELECT"),
            "Single table should not be wrapped: {output}"
        );
    }

    #[test]
    fn test_single_subquery_unchanged() {
        let output = isolate("SELECT * FROM (SELECT 1) AS t", None);
        assert_eq!(
            output.matches("(SELECT").count(),
            1,
            "Single subquery source should not gain extra wrapping: {output}"
        );
    }

    #[test]
    fn test_two_tables_joined() {
        let output = isolate("SELECT * FROM a AS a JOIN b AS b ON a.id = b.id", None);
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "FROM table should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "JOIN table should be wrapped: {output}"
        );
    }

    #[test]
    fn test_table_with_join_subquery() {
        let output = isolate(
            "SELECT * FROM a AS a JOIN (SELECT * FROM b) AS b ON a.id = b.id",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Bare table should be wrapped: {output}"
        );
        assert_eq!(
            output.matches("(SELECT * FROM b)").count(),
            1,
            "Already-subquery source should not be double-wrapped: {output}"
        );
    }

    #[test]
    fn test_no_alias_not_wrapped() {
        let output = isolate("SELECT * FROM a JOIN b ON a.id = b.id", None);
        assert!(
            !output.contains("(SELECT * FROM a"),
            "Table without alias should not be wrapped: {output}"
        );
    }

    #[test]
    fn test_schema_known_table_wrapped() {
        let schema = schema_with_int_id(&["a", "b"]);
        let output = isolate(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id",
            Some(&schema),
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Known table 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "Known table 'b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_schema_unknown_table_not_wrapped() {
        let schema = schema_with_int_id(&["a"]);
        let output = isolate(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id",
            Some(&schema),
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Known table 'a' should be wrapped: {output}"
        );
        assert!(
            !output.contains("(SELECT * FROM b AS b) AS b"),
            "Unknown table 'b' should NOT be wrapped: {output}"
        );
    }

    #[test]
    fn test_cte_inner_query_processed() {
        let output = isolate(
            "WITH cte AS (SELECT * FROM x AS x JOIN y AS y ON x.id = y.id) SELECT * FROM cte AS c",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM x AS x) AS x"),
            "CTE inner table 'x' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM y AS y) AS y"),
            "CTE inner table 'y' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_nested_subquery_processed() {
        let output = isolate(
            "SELECT * FROM (SELECT * FROM a AS a JOIN b AS b ON a.id = b.id) AS sub",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Nested inner table 'a' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_join_subquery_inner_query_processed() {
        let output = isolate(
            "SELECT * FROM a AS a JOIN (SELECT * FROM x AS x JOIN y AS y ON x.id = y.id) AS s ON a.id = s.id",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Outer table 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM x AS x) AS x"),
            "JOIN subquery inner table 'x' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM y AS y) AS y"),
            "JOIN subquery inner table 'y' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_union_both_sides_processed() {
        let output = isolate(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id UNION ALL SELECT * FROM c AS c",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "UNION left side should be processed: {output}"
        );
        assert!(
            !output.contains("(SELECT * FROM c AS c) AS c"),
            "UNION right side (single source) should not be wrapped: {output}"
        );
    }

    #[test]
    fn test_intersect_both_sides_processed() {
        let output = isolate(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id INTERSECT SELECT * FROM c AS c JOIN d AS d ON c.id = d.id",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "INTERSECT left side should be processed: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM d AS d) AS d"),
            "INTERSECT right side should be processed: {output}"
        );
    }

    #[test]
    fn test_except_both_sides_processed() {
        let output = isolate(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id EXCEPT SELECT * FROM c AS c JOIN d AS d ON c.id = d.id",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "EXCEPT left side should be processed: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM c AS c) AS c"),
            "EXCEPT right side should be processed: {output}"
        );
    }

    #[test]
    fn test_cross_join() {
        let output = isolate("SELECT * FROM a AS a CROSS JOIN b AS b", None);
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "CROSS JOIN table 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "CROSS JOIN table 'b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_multiple_from_tables() {
        let output = isolate("SELECT * FROM a AS a, b AS b", None);
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Comma-join table 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "Comma-join table 'b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_three_way_join() {
        let output = isolate(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id JOIN c AS c ON b.id = c.id",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Three-way join: 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "Three-way join: 'b' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM c AS c) AS c"),
            "Three-way join: 'c' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_qualified_table_name_with_schema() {
        let schema = schema_with_int_id(&["mydb.a", "mydb.b"]);
        let output = isolate(
            "SELECT * FROM mydb.a AS a JOIN mydb.b AS b ON a.id = b.id",
            Some(&schema),
        );
        assert!(
            output.contains("(SELECT * FROM mydb.a AS a) AS a"),
            "Qualified table 'mydb.a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM mydb.b AS b) AS b"),
            "Qualified table 'mydb.b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_non_select_expression_unchanged() {
        let expr = parse("INSERT INTO t VALUES (1)");
        let original = gen(&expr);
        let result = isolate_table_selects(expr, None, None);
        let output = gen(&result);
        assert_eq!(original, output, "Non-SELECT should be unchanged");
    }
}