1use crate::dialects::DialectType;
12use crate::expressions::*;
13use crate::schema::Schema;
14
15#[derive(Debug, Clone, thiserror::Error)]
17pub enum IsolateTableSelectsError {
18 #[error("Tables require an alias: {0}")]
19 MissingAlias(String),
20}
21
22pub fn isolate_table_selects(
45 expression: Expression,
46 schema: Option<&dyn Schema>,
47 _dialect: Option<DialectType>,
48) -> Expression {
49 match expression {
50 Expression::Select(select) => {
51 let transformed = isolate_select(*select, schema);
52 Expression::Select(Box::new(transformed))
53 }
54 Expression::Union(mut union) => {
55 union.left = isolate_table_selects(union.left, schema, _dialect);
56 union.right = isolate_table_selects(union.right, schema, _dialect);
57 Expression::Union(union)
58 }
59 Expression::Intersect(mut intersect) => {
60 intersect.left = isolate_table_selects(intersect.left, schema, _dialect);
61 intersect.right = isolate_table_selects(intersect.right, schema, _dialect);
62 Expression::Intersect(intersect)
63 }
64 Expression::Except(mut except) => {
65 except.left = isolate_table_selects(except.left, schema, _dialect);
66 except.right = isolate_table_selects(except.right, schema, _dialect);
67 Expression::Except(except)
68 }
69 other => other,
70 }
71}
72
73fn isolate_select(mut select: Select, schema: Option<&dyn Schema>) -> Select {
76 if let Some(ref mut with) = select.with {
78 for cte in &mut with.ctes {
79 cte.this = isolate_table_selects(cte.this.clone(), schema, None);
80 }
81 }
82
83 if let Some(ref mut from) = select.from {
85 for expr in &mut from.expressions {
86 if let Expression::Subquery(ref mut sq) = expr {
87 sq.this = isolate_table_selects(sq.this.clone(), schema, None);
88 }
89 }
90 }
91 for join in &mut select.joins {
92 if let Expression::Subquery(ref mut sq) = join.this {
93 sq.this = isolate_table_selects(sq.this.clone(), schema, None);
94 }
95 }
96
97 let source_count = count_sources(&select);
99
100 if source_count <= 1 {
102 return select;
103 }
104
105 if let Some(ref mut from) = select.from {
107 from.expressions = from
108 .expressions
109 .drain(..)
110 .map(|expr| maybe_wrap_table(expr, schema))
111 .collect();
112 }
113
114 for join in &mut select.joins {
116 join.this = maybe_wrap_table(join.this.clone(), schema);
117 }
118
119 select
120}
121
122fn count_sources(select: &Select) -> usize {
126 let from_count = select
127 .from
128 .as_ref()
129 .map(|f| f.expressions.len())
130 .unwrap_or(0);
131 let join_count = select.joins.len();
132 from_count + join_count
133}
134
135fn maybe_wrap_table(expression: Expression, schema: Option<&dyn Schema>) -> Expression {
145 match expression {
146 Expression::Table(ref table) => {
147 if let Some(s) = schema {
151 let table_name = full_table_name(table);
152 if s.column_names(&table_name).unwrap_or_default().is_empty() {
153 return expression;
154 }
155 }
156
157 let alias_name = match &table.alias {
161 Some(alias) if !alias.name.is_empty() => alias.name.clone(),
162 _ => return expression,
163 };
164
165 wrap_table_in_subquery(table.clone(), &alias_name)
166 }
167 _ => expression,
168 }
169}
170
171fn wrap_table_in_subquery(table: TableRef, alias_name: &str) -> Expression {
176 let inner_select = Select::new()
178 .column(Expression::Star(Star {
179 table: None,
180 except: None,
181 replace: None,
182 rename: None,
183 trailing_comments: Vec::new(),
184 span: None,
185 }))
186 .from(Expression::Table(table));
187
188 Expression::Subquery(Box::new(Subquery {
190 this: Expression::Select(Box::new(inner_select)),
191 alias: Some(Identifier::new(alias_name)),
192 column_aliases: Vec::new(),
193 order_by: None,
194 limit: None,
195 offset: None,
196 distribute_by: None,
197 sort_by: None,
198 cluster_by: None,
199 lateral: false,
200 modifiers_inside: false,
201 trailing_comments: Vec::new(),
202 }))
203}
204
205fn full_table_name(table: &TableRef) -> String {
210 let mut parts = Vec::new();
211 if let Some(ref catalog) = table.catalog {
212 parts.push(catalog.name.as_str());
213 }
214 if let Some(ref schema) = table.schema {
215 parts.push(schema.name.as_str());
216 }
217 parts.push(&table.name.name);
218 parts.join(".")
219}
220
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;
    use crate::schema::MappingSchema;

    /// Parses `sql` and returns the first statement it contains.
    fn parse(sql: &str) -> Expression {
        let statements = Parser::parse_sql(sql).expect("Failed to parse");
        statements[0].clone()
    }

    /// Renders `expr` back to SQL with a default generator.
    fn render(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parses `sql`, runs the isolation pass with `schema`, and renders the result.
    fn isolate_sql(sql: &str, schema: Option<&dyn Schema>) -> String {
        render(&isolate_table_selects(parse(sql), schema, None))
    }

    /// Builds a schema in which every name in `tables` has a single `id INT` column.
    fn id_schema(tables: &[&str]) -> MappingSchema {
        let mut schema = MappingSchema::new();
        for &table in tables {
            let columns = [(
                "id".to_string(),
                DataType::Int {
                    length: None,
                    integer_spelling: false,
                },
            )];
            schema.add_table(table, &columns, None).unwrap();
        }
        schema
    }

    #[test]
    fn test_single_table_unchanged() {
        let output = isolate_sql("SELECT * FROM t AS t", None);
        assert!(
            !output.contains("(SELECT"),
            "Single table should not be wrapped: {output}"
        );
    }

    #[test]
    fn test_single_subquery_unchanged() {
        let output = isolate_sql("SELECT * FROM (SELECT 1) AS t", None);
        assert_eq!(
            output.matches("(SELECT").count(),
            1,
            "Single subquery source should not gain extra wrapping: {output}"
        );
    }

    #[test]
    fn test_two_tables_joined() {
        let output = isolate_sql("SELECT * FROM a AS a JOIN b AS b ON a.id = b.id", None);
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "FROM table should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "JOIN table should be wrapped: {output}"
        );
    }

    #[test]
    fn test_table_with_join_subquery() {
        let output = isolate_sql(
            "SELECT * FROM a AS a JOIN (SELECT * FROM b) AS b ON a.id = b.id",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Bare table should be wrapped: {output}"
        );
        assert_eq!(
            output.matches("(SELECT * FROM b)").count(),
            1,
            "Already-subquery source should not be double-wrapped: {output}"
        );
    }

    #[test]
    fn test_no_alias_not_wrapped() {
        let output = isolate_sql("SELECT * FROM a JOIN b ON a.id = b.id", None);
        assert!(
            !output.contains("(SELECT * FROM a"),
            "Table without alias should not be wrapped: {output}"
        );
    }

    #[test]
    fn test_schema_known_table_wrapped() {
        let schema = id_schema(&["a", "b"]);
        let output = isolate_sql(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id",
            Some(&schema),
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Known table 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "Known table 'b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_schema_unknown_table_not_wrapped() {
        // Only `a` is registered, so `b` must be left as a bare table.
        let schema = id_schema(&["a"]);
        let output = isolate_sql(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id",
            Some(&schema),
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Known table 'a' should be wrapped: {output}"
        );
        assert!(
            !output.contains("(SELECT * FROM b AS b) AS b"),
            "Unknown table 'b' should NOT be wrapped: {output}"
        );
    }

    #[test]
    fn test_cte_inner_query_processed() {
        let output = isolate_sql(
            "WITH cte AS (SELECT * FROM x AS x JOIN y AS y ON x.id = y.id) SELECT * FROM cte AS c",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM x AS x) AS x"),
            "CTE inner table 'x' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM y AS y) AS y"),
            "CTE inner table 'y' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_nested_subquery_processed() {
        let output = isolate_sql(
            "SELECT * FROM (SELECT * FROM a AS a JOIN b AS b ON a.id = b.id) AS sub",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Nested inner table 'a' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_union_both_sides_processed() {
        let output = isolate_sql(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id UNION ALL SELECT * FROM c AS c",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "UNION left side should be processed: {output}"
        );
        assert!(
            !output.contains("(SELECT * FROM c AS c) AS c"),
            "UNION right side (single source) should not be wrapped: {output}"
        );
    }

    #[test]
    fn test_cross_join() {
        let output = isolate_sql("SELECT * FROM a AS a CROSS JOIN b AS b", None);
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "CROSS JOIN table 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "CROSS JOIN table 'b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_multiple_from_tables() {
        let output = isolate_sql("SELECT * FROM a AS a, b AS b", None);
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Comma-join table 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "Comma-join table 'b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_three_way_join() {
        let output = isolate_sql(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id JOIN c AS c ON b.id = c.id",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Three-way join: 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "Three-way join: 'b' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM c AS c) AS c"),
            "Three-way join: 'c' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_qualified_table_name_with_schema() {
        // Schema lookups use the qualified `db.table` name.
        let schema = id_schema(&["mydb.a", "mydb.b"]);
        let output = isolate_sql(
            "SELECT * FROM mydb.a AS a JOIN mydb.b AS b ON a.id = b.id",
            Some(&schema),
        );
        assert!(
            output.contains("(SELECT * FROM mydb.a AS a) AS a"),
            "Qualified table 'mydb.a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM mydb.b AS b) AS b"),
            "Qualified table 'mydb.b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_non_select_expression_unchanged() {
        let expr = parse("INSERT INTO t VALUES (1)");
        let original = render(&expr);
        let output = render(&isolate_table_selects(expr, None, None));
        assert_eq!(original, output, "Non-SELECT should be unchanged");
    }
}