1use crate::dialects::DialectType;
12use crate::expressions::*;
13use crate::schema::Schema;
14
15#[derive(Debug, Clone, thiserror::Error)]
17pub enum IsolateTableSelectsError {
18 #[error("Tables require an alias: {0}")]
19 MissingAlias(String),
20}
21
22pub fn isolate_table_selects(
45 expression: Expression,
46 schema: Option<&dyn Schema>,
47 _dialect: Option<DialectType>,
48) -> Expression {
49 match expression {
50 Expression::Select(select) => {
51 let transformed = isolate_select(*select, schema);
52 Expression::Select(Box::new(transformed))
53 }
54 Expression::Union(mut union) => {
55 union.left = isolate_table_selects(union.left, schema, _dialect);
56 union.right = isolate_table_selects(union.right, schema, _dialect);
57 Expression::Union(union)
58 }
59 Expression::Intersect(mut intersect) => {
60 intersect.left = isolate_table_selects(intersect.left, schema, _dialect);
61 intersect.right = isolate_table_selects(intersect.right, schema, _dialect);
62 Expression::Intersect(intersect)
63 }
64 Expression::Except(mut except) => {
65 except.left = isolate_table_selects(except.left, schema, _dialect);
66 except.right = isolate_table_selects(except.right, schema, _dialect);
67 Expression::Except(except)
68 }
69 other => other,
70 }
71}
72
73fn isolate_select(mut select: Select, schema: Option<&dyn Schema>) -> Select {
76 if let Some(ref mut with) = select.with {
78 for cte in &mut with.ctes {
79 cte.this = isolate_table_selects(cte.this.clone(), schema, None);
80 }
81 }
82
83 if let Some(ref mut from) = select.from {
85 for expr in &mut from.expressions {
86 if let Expression::Subquery(ref mut sq) = expr {
87 sq.this = isolate_table_selects(sq.this.clone(), schema, None);
88 }
89 }
90 }
91 for join in &mut select.joins {
92 if let Expression::Subquery(ref mut sq) = join.this {
93 sq.this = isolate_table_selects(sq.this.clone(), schema, None);
94 }
95 }
96
97 let source_count = count_sources(&select);
99
100 if source_count <= 1 {
102 return select;
103 }
104
105 if let Some(ref mut from) = select.from {
107 from.expressions = from
108 .expressions
109 .drain(..)
110 .map(|expr| maybe_wrap_table(expr, schema))
111 .collect();
112 }
113
114 for join in &mut select.joins {
116 join.this = maybe_wrap_table(join.this.clone(), schema);
117 }
118
119 select
120}
121
122fn count_sources(select: &Select) -> usize {
126 let from_count = select
127 .from
128 .as_ref()
129 .map(|f| f.expressions.len())
130 .unwrap_or(0);
131 let join_count = select.joins.len();
132 from_count + join_count
133}
134
135fn maybe_wrap_table(expression: Expression, schema: Option<&dyn Schema>) -> Expression {
145 match expression {
146 Expression::Table(ref table) => {
147 if let Some(s) = schema {
151 let table_name = full_table_name(table);
152 if s.column_names(&table_name).unwrap_or_default().is_empty() {
153 return expression;
154 }
155 }
156
157 let alias_name = match &table.alias {
161 Some(alias) if !alias.name.is_empty() => alias.name.clone(),
162 _ => return expression,
163 };
164
165 wrap_table_in_subquery(table.clone(), &alias_name)
166 }
167 _ => expression,
168 }
169}
170
171fn wrap_table_in_subquery(table: TableRef, alias_name: &str) -> Expression {
176 let inner_select = Select::new()
178 .column(Expression::Star(Star {
179 table: None,
180 except: None,
181 replace: None,
182 rename: None,
183 trailing_comments: Vec::new(),
184 }))
185 .from(Expression::Table(table));
186
187 Expression::Subquery(Box::new(Subquery {
189 this: Expression::Select(Box::new(inner_select)),
190 alias: Some(Identifier::new(alias_name)),
191 column_aliases: Vec::new(),
192 order_by: None,
193 limit: None,
194 offset: None,
195 distribute_by: None,
196 sort_by: None,
197 cluster_by: None,
198 lateral: false,
199 modifiers_inside: false,
200 trailing_comments: Vec::new(),
201 }))
202}
203
204fn full_table_name(table: &TableRef) -> String {
209 let mut parts = Vec::new();
210 if let Some(ref catalog) = table.catalog {
211 parts.push(catalog.name.as_str());
212 }
213 if let Some(ref schema) = table.schema {
214 parts.push(schema.name.as_str());
215 }
216 parts.push(&table.name.name);
217 parts.join(".")
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;
    use crate::schema::MappingSchema;

    /// Parses `sql` and returns its first statement.
    fn parse(sql: &str) -> Expression {
        let statements = Parser::parse_sql(sql).expect("Failed to parse");
        statements[0].clone()
    }

    /// Renders `expr` back to SQL with the default generator.
    fn render(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parses `sql`, runs the isolation pass with `schema`, and renders the result.
    fn isolate_sql(sql: &str, schema: Option<&dyn Schema>) -> String {
        let transformed = isolate_table_selects(parse(sql), schema, None);
        render(&transformed)
    }

    /// Builds a schema in which each listed table has a single integer `id` column.
    fn schema_with_id_column(tables: &[&str]) -> MappingSchema {
        let mut schema = MappingSchema::new();
        for &name in tables {
            let columns = [(
                "id".to_string(),
                DataType::Int {
                    length: None,
                    integer_spelling: false,
                },
            )];
            schema.add_table(name, &columns, None).unwrap();
        }
        schema
    }

    #[test]
    fn test_single_table_unchanged() {
        let output = isolate_sql("SELECT * FROM t AS t", None);
        assert!(
            !output.contains("(SELECT"),
            "Single table should not be wrapped: {output}"
        );
    }

    #[test]
    fn test_single_subquery_unchanged() {
        let output = isolate_sql("SELECT * FROM (SELECT 1) AS t", None);
        assert_eq!(
            output.matches("(SELECT").count(),
            1,
            "Single subquery source should not gain extra wrapping: {output}"
        );
    }

    #[test]
    fn test_two_tables_joined() {
        let output = isolate_sql("SELECT * FROM a AS a JOIN b AS b ON a.id = b.id", None);
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "FROM table should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "JOIN table should be wrapped: {output}"
        );
    }

    #[test]
    fn test_table_with_join_subquery() {
        let output = isolate_sql(
            "SELECT * FROM a AS a JOIN (SELECT * FROM b) AS b ON a.id = b.id",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Bare table should be wrapped: {output}"
        );
        assert_eq!(
            output.matches("(SELECT * FROM b)").count(),
            1,
            "Already-subquery source should not be double-wrapped: {output}"
        );
    }

    #[test]
    fn test_no_alias_not_wrapped() {
        let output = isolate_sql("SELECT * FROM a JOIN b ON a.id = b.id", None);
        assert!(
            !output.contains("(SELECT * FROM a"),
            "Table without alias should not be wrapped: {output}"
        );
    }

    #[test]
    fn test_schema_known_table_wrapped() {
        let schema = schema_with_id_column(&["a", "b"]);
        let output = isolate_sql(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id",
            Some(&schema),
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Known table 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "Known table 'b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_schema_unknown_table_not_wrapped() {
        // Only `a` is registered; `b` is unknown to the schema.
        let schema = schema_with_id_column(&["a"]);
        let output = isolate_sql(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id",
            Some(&schema),
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Known table 'a' should be wrapped: {output}"
        );
        assert!(
            !output.contains("(SELECT * FROM b AS b) AS b"),
            "Unknown table 'b' should NOT be wrapped: {output}"
        );
    }

    #[test]
    fn test_cte_inner_query_processed() {
        let output = isolate_sql(
            "WITH cte AS (SELECT * FROM x AS x JOIN y AS y ON x.id = y.id) SELECT * FROM cte AS c",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM x AS x) AS x"),
            "CTE inner table 'x' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM y AS y) AS y"),
            "CTE inner table 'y' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_nested_subquery_processed() {
        let output = isolate_sql(
            "SELECT * FROM (SELECT * FROM a AS a JOIN b AS b ON a.id = b.id) AS sub",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Nested inner table 'a' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_union_both_sides_processed() {
        let output = isolate_sql(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id UNION ALL SELECT * FROM c AS c",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "UNION left side should be processed: {output}"
        );
        assert!(
            !output.contains("(SELECT * FROM c AS c) AS c"),
            "UNION right side (single source) should not be wrapped: {output}"
        );
    }

    #[test]
    fn test_cross_join() {
        let output = isolate_sql("SELECT * FROM a AS a CROSS JOIN b AS b", None);
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "CROSS JOIN table 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "CROSS JOIN table 'b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_multiple_from_tables() {
        let output = isolate_sql("SELECT * FROM a AS a, b AS b", None);
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Comma-join table 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "Comma-join table 'b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_three_way_join() {
        let output = isolate_sql(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id JOIN c AS c ON b.id = c.id",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Three-way join: 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "Three-way join: 'b' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM c AS c) AS c"),
            "Three-way join: 'c' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_qualified_table_name_with_schema() {
        let schema = schema_with_id_column(&["mydb.a", "mydb.b"]);
        let output = isolate_sql(
            "SELECT * FROM mydb.a AS a JOIN mydb.b AS b ON a.id = b.id",
            Some(&schema),
        );
        assert!(
            output.contains("(SELECT * FROM mydb.a AS a) AS a"),
            "Qualified table 'mydb.a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM mydb.b AS b) AS b"),
            "Qualified table 'mydb.b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_non_select_expression_unchanged() {
        let statement = parse("INSERT INTO t VALUES (1)");
        let original = render(&statement);
        let output = render(&isolate_table_selects(statement, None, None));
        assert_eq!(original, output, "Non-SELECT should be unchanged");
    }
}