1use crate::dialects::DialectType;
12use crate::expressions::*;
13use crate::schema::Schema;
14
/// Error type declared for the table-select isolation pass.
///
/// NOTE(review): nothing in the visible part of this file constructs this error;
/// tables without an alias are left untouched by `maybe_wrap_table` rather than
/// reported. Confirm whether other modules rely on it.
#[derive(Debug, Clone, thiserror::Error)]
pub enum IsolateTableSelectsError {
    /// A table had no alias; the payload is interpolated into the message.
    #[error("Tables require an alias: {0}")]
    MissingAlias(String),
}
21
22pub fn isolate_table_selects(
45 expression: Expression,
46 schema: Option<&dyn Schema>,
47 _dialect: Option<DialectType>,
48) -> Expression {
49 match expression {
50 Expression::Select(select) => {
51 let transformed = isolate_select(*select, schema);
52 Expression::Select(Box::new(transformed))
53 }
54 Expression::Union(mut union) => {
55 let left = std::mem::replace(&mut union.left, Expression::Null(Null));
56 union.left = isolate_table_selects(left, schema, _dialect);
57 let right = std::mem::replace(&mut union.right, Expression::Null(Null));
58 union.right = isolate_table_selects(right, schema, _dialect);
59 Expression::Union(union)
60 }
61 Expression::Intersect(mut intersect) => {
62 let left = std::mem::replace(&mut intersect.left, Expression::Null(Null));
63 intersect.left = isolate_table_selects(left, schema, _dialect);
64 let right = std::mem::replace(&mut intersect.right, Expression::Null(Null));
65 intersect.right = isolate_table_selects(right, schema, _dialect);
66 Expression::Intersect(intersect)
67 }
68 Expression::Except(mut except) => {
69 let left = std::mem::replace(&mut except.left, Expression::Null(Null));
70 except.left = isolate_table_selects(left, schema, _dialect);
71 let right = std::mem::replace(&mut except.right, Expression::Null(Null));
72 except.right = isolate_table_selects(right, schema, _dialect);
73 Expression::Except(except)
74 }
75 other => other,
76 }
77}
78
79fn isolate_select(mut select: Select, schema: Option<&dyn Schema>) -> Select {
82 if let Some(ref mut with) = select.with {
84 for cte in &mut with.ctes {
85 cte.this = isolate_table_selects(cte.this.clone(), schema, None);
86 }
87 }
88
89 if let Some(ref mut from) = select.from {
91 for expr in &mut from.expressions {
92 if let Expression::Subquery(ref mut sq) = expr {
93 sq.this = isolate_table_selects(sq.this.clone(), schema, None);
94 }
95 }
96 }
97 for join in &mut select.joins {
98 if let Expression::Subquery(ref mut sq) = join.this {
99 sq.this = isolate_table_selects(sq.this.clone(), schema, None);
100 }
101 }
102
103 let source_count = count_sources(&select);
105
106 if source_count <= 1 {
108 return select;
109 }
110
111 if let Some(ref mut from) = select.from {
113 from.expressions = from
114 .expressions
115 .drain(..)
116 .map(|expr| maybe_wrap_table(expr, schema))
117 .collect();
118 }
119
120 for join in &mut select.joins {
122 join.this = maybe_wrap_table(join.this.clone(), schema);
123 }
124
125 select
126}
127
128fn count_sources(select: &Select) -> usize {
132 let from_count = select
133 .from
134 .as_ref()
135 .map(|f| f.expressions.len())
136 .unwrap_or(0);
137 let join_count = select.joins.len();
138 from_count + join_count
139}
140
141fn maybe_wrap_table(expression: Expression, schema: Option<&dyn Schema>) -> Expression {
151 match expression {
152 Expression::Table(ref table) => {
153 if let Some(s) = schema {
157 let table_name = full_table_name(table);
158 if s.column_names(&table_name).unwrap_or_default().is_empty() {
159 return expression;
160 }
161 }
162
163 let alias_name = match &table.alias {
167 Some(alias) if !alias.name.is_empty() => alias.name.clone(),
168 _ => return expression,
169 };
170
171 wrap_table_in_subquery(*table.clone(), &alias_name)
172 }
173 _ => expression,
174 }
175}
176
177fn wrap_table_in_subquery(table: TableRef, alias_name: &str) -> Expression {
182 let inner_select = Select::new()
184 .column(Expression::Star(Star {
185 table: None,
186 except: None,
187 replace: None,
188 rename: None,
189 trailing_comments: Vec::new(),
190 span: None,
191 }))
192 .from(Expression::Table(Box::new(table)));
193
194 Expression::Subquery(Box::new(Subquery {
196 this: Expression::Select(Box::new(inner_select)),
197 alias: Some(Identifier::new(alias_name)),
198 column_aliases: Vec::new(),
199 order_by: None,
200 limit: None,
201 offset: None,
202 distribute_by: None,
203 sort_by: None,
204 cluster_by: None,
205 lateral: false,
206 modifiers_inside: false,
207 trailing_comments: Vec::new(),
208 inferred_type: None,
209 }))
210}
211
212fn full_table_name(table: &TableRef) -> String {
217 let mut parts = Vec::new();
218 if let Some(ref catalog) = table.catalog {
219 parts.push(catalog.name.as_str());
220 }
221 if let Some(ref schema) = table.schema {
222 parts.push(schema.name.as_str());
223 }
224 parts.push(&table.name.name);
225 parts.join(".")
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;
    use crate::schema::MappingSchema;

    /// Parses `sql` and returns its first statement.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    /// Renders `expr` back to SQL text.
    fn gen(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parses `sql`, runs the pass with the given schema and no dialect, and returns
    /// the generated SQL.
    fn isolate(sql: &str, schema: Option<&dyn Schema>) -> String {
        gen(&isolate_table_selects(parse(sql), schema, None))
    }

    /// Builds a schema in which each listed table has a single integer `id` column.
    fn schema_with_id_tables(table_names: &[&str]) -> MappingSchema {
        let mut schema = MappingSchema::new();
        for &table_name in table_names {
            schema
                .add_table(
                    table_name,
                    &[(
                        "id".to_string(),
                        DataType::Int {
                            length: None,
                            integer_spelling: false,
                        },
                    )],
                    None,
                )
                .unwrap();
        }
        schema
    }

    #[test]
    fn test_single_table_unchanged() {
        let output = isolate("SELECT * FROM t AS t", None);
        assert!(
            !output.contains("(SELECT"),
            "Single table should not be wrapped: {output}"
        );
    }

    #[test]
    fn test_single_subquery_unchanged() {
        let output = isolate("SELECT * FROM (SELECT 1) AS t", None);
        assert_eq!(
            output.matches("(SELECT").count(),
            1,
            "Single subquery source should not gain extra wrapping: {output}"
        );
    }

    #[test]
    fn test_two_tables_joined() {
        let output = isolate("SELECT * FROM a AS a JOIN b AS b ON a.id = b.id", None);
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "FROM table should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "JOIN table should be wrapped: {output}"
        );
    }

    #[test]
    fn test_table_with_join_subquery() {
        let output = isolate(
            "SELECT * FROM a AS a JOIN (SELECT * FROM b) AS b ON a.id = b.id",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Bare table should be wrapped: {output}"
        );
        assert_eq!(
            output.matches("(SELECT * FROM b)").count(),
            1,
            "Already-subquery source should not be double-wrapped: {output}"
        );
    }

    #[test]
    fn test_no_alias_not_wrapped() {
        let output = isolate("SELECT * FROM a JOIN b ON a.id = b.id", None);
        assert!(
            !output.contains("(SELECT * FROM a"),
            "Table without alias should not be wrapped: {output}"
        );
    }

    #[test]
    fn test_schema_known_table_wrapped() {
        let schema = schema_with_id_tables(&["a", "b"]);
        let output = isolate(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id",
            Some(&schema),
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Known table 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "Known table 'b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_schema_unknown_table_not_wrapped() {
        // Only `a` is registered; `b` is unknown to the schema.
        let schema = schema_with_id_tables(&["a"]);
        let output = isolate(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id",
            Some(&schema),
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Known table 'a' should be wrapped: {output}"
        );
        assert!(
            !output.contains("(SELECT * FROM b AS b) AS b"),
            "Unknown table 'b' should NOT be wrapped: {output}"
        );
    }

    #[test]
    fn test_cte_inner_query_processed() {
        let output = isolate(
            "WITH cte AS (SELECT * FROM x AS x JOIN y AS y ON x.id = y.id) SELECT * FROM cte AS c",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM x AS x) AS x"),
            "CTE inner table 'x' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM y AS y) AS y"),
            "CTE inner table 'y' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_nested_subquery_processed() {
        let output = isolate(
            "SELECT * FROM (SELECT * FROM a AS a JOIN b AS b ON a.id = b.id) AS sub",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Nested inner table 'a' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_union_both_sides_processed() {
        let output = isolate(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id UNION ALL SELECT * FROM c AS c",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "UNION left side should be processed: {output}"
        );
        assert!(
            !output.contains("(SELECT * FROM c AS c) AS c"),
            "UNION right side (single source) should not be wrapped: {output}"
        );
    }

    #[test]
    fn test_cross_join() {
        let output = isolate("SELECT * FROM a AS a CROSS JOIN b AS b", None);
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "CROSS JOIN table 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "CROSS JOIN table 'b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_multiple_from_tables() {
        let output = isolate("SELECT * FROM a AS a, b AS b", None);
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Comma-join table 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "Comma-join table 'b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_three_way_join() {
        let output = isolate(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id JOIN c AS c ON b.id = c.id",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Three-way join: 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "Three-way join: 'b' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM c AS c) AS c"),
            "Three-way join: 'c' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_qualified_table_name_with_schema() {
        let schema = schema_with_id_tables(&["mydb.a", "mydb.b"]);
        let output = isolate(
            "SELECT * FROM mydb.a AS a JOIN mydb.b AS b ON a.id = b.id",
            Some(&schema),
        );
        assert!(
            output.contains("(SELECT * FROM mydb.a AS a) AS a"),
            "Qualified table 'mydb.a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM mydb.b AS b) AS b"),
            "Qualified table 'mydb.b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_non_select_expression_unchanged() {
        let expr = parse("INSERT INTO t VALUES (1)");
        // Render before the pass consumes the expression.
        let original = gen(&expr);
        let output = gen(&isolate_table_selects(expr, None, None));
        assert_eq!(original, output, "Non-SELECT should be unchanged");
    }
}