1use crate::dialects::DialectType;
12use crate::expressions::*;
13use crate::schema::Schema;
14
15#[derive(Debug, Clone, thiserror::Error)]
17pub enum IsolateTableSelectsError {
18 #[error("Tables require an alias: {0}")]
19 MissingAlias(String),
20}
21
22pub fn isolate_table_selects(
45 expression: Expression,
46 schema: Option<&dyn Schema>,
47 _dialect: Option<DialectType>,
48) -> Expression {
49 match expression {
50 Expression::Select(select) => {
51 let transformed = isolate_select(*select, schema);
52 Expression::Select(Box::new(transformed))
53 }
54 Expression::Union(mut union) => {
55 union.left = isolate_table_selects(union.left, schema, _dialect);
56 union.right = isolate_table_selects(union.right, schema, _dialect);
57 Expression::Union(union)
58 }
59 Expression::Intersect(mut intersect) => {
60 intersect.left = isolate_table_selects(intersect.left, schema, _dialect);
61 intersect.right = isolate_table_selects(intersect.right, schema, _dialect);
62 Expression::Intersect(intersect)
63 }
64 Expression::Except(mut except) => {
65 except.left = isolate_table_selects(except.left, schema, _dialect);
66 except.right = isolate_table_selects(except.right, schema, _dialect);
67 Expression::Except(except)
68 }
69 other => other,
70 }
71}
72
73fn isolate_select(mut select: Select, schema: Option<&dyn Schema>) -> Select {
76 if let Some(ref mut with) = select.with {
78 for cte in &mut with.ctes {
79 cte.this = isolate_table_selects(cte.this.clone(), schema, None);
80 }
81 }
82
83 if let Some(ref mut from) = select.from {
85 for expr in &mut from.expressions {
86 if let Expression::Subquery(ref mut sq) = expr {
87 sq.this = isolate_table_selects(sq.this.clone(), schema, None);
88 }
89 }
90 }
91 for join in &mut select.joins {
92 if let Expression::Subquery(ref mut sq) = join.this {
93 sq.this = isolate_table_selects(sq.this.clone(), schema, None);
94 }
95 }
96
97 let source_count = count_sources(&select);
99
100 if source_count <= 1 {
102 return select;
103 }
104
105 if let Some(ref mut from) = select.from {
107 from.expressions = from
108 .expressions
109 .drain(..)
110 .map(|expr| maybe_wrap_table(expr, schema))
111 .collect();
112 }
113
114 for join in &mut select.joins {
116 join.this = maybe_wrap_table(join.this.clone(), schema);
117 }
118
119 select
120}
121
122fn count_sources(select: &Select) -> usize {
126 let from_count = select
127 .from
128 .as_ref()
129 .map(|f| f.expressions.len())
130 .unwrap_or(0);
131 let join_count = select.joins.len();
132 from_count + join_count
133}
134
135fn maybe_wrap_table(expression: Expression, schema: Option<&dyn Schema>) -> Expression {
145 match expression {
146 Expression::Table(ref table) => {
147 if let Some(s) = schema {
151 let table_name = full_table_name(table);
152 if s.column_names(&table_name).unwrap_or_default().is_empty() {
153 return expression;
154 }
155 }
156
157 let alias_name = match &table.alias {
161 Some(alias) if !alias.name.is_empty() => alias.name.clone(),
162 _ => return expression,
163 };
164
165 wrap_table_in_subquery(table.clone(), &alias_name)
166 }
167 _ => expression,
168 }
169}
170
171fn wrap_table_in_subquery(table: TableRef, alias_name: &str) -> Expression {
176 let inner_select = Select::new()
178 .column(Expression::Star(Star {
179 table: None,
180 except: None,
181 replace: None,
182 rename: None,
183 trailing_comments: Vec::new(),
184 }))
185 .from(Expression::Table(table));
186
187 Expression::Subquery(Box::new(Subquery {
189 this: Expression::Select(Box::new(inner_select)),
190 alias: Some(Identifier::new(alias_name)),
191 column_aliases: Vec::new(),
192 order_by: None,
193 limit: None,
194 offset: None,
195 distribute_by: None,
196 sort_by: None,
197 cluster_by: None,
198 lateral: false,
199 modifiers_inside: false,
200 trailing_comments: Vec::new(),
201 }))
202}
203
204fn full_table_name(table: &TableRef) -> String {
209 let mut parts = Vec::new();
210 if let Some(ref catalog) = table.catalog {
211 parts.push(catalog.name.as_str());
212 }
213 if let Some(ref schema) = table.schema {
214 parts.push(schema.name.as_str());
215 }
216 parts.push(&table.name.name);
217 parts.join(".")
218}
219
/// Tests for the table-isolation pass, driven end-to-end through the parser and generator.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;
    use crate::schema::MappingSchema;

    /// Parses `sql` and returns its first statement.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    /// Renders `expr` back to SQL with the default generator.
    fn render(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parses `sql`, runs the isolation pass on it, and renders the result.
    fn isolate(sql: &str, schema: Option<&dyn Schema>) -> String {
        render(&isolate_table_selects(parse(sql), schema, None))
    }

    /// Expected wrapped form of table `name` aliased as itself.
    fn wrapped(name: &str) -> String {
        format!("(SELECT * FROM {name} AS {name}) AS {name}")
    }

    /// Builds a schema in which every listed table has a single `id` INT column.
    fn schema_with_tables(tables: &[&str]) -> MappingSchema {
        let mut schema = MappingSchema::new();
        for &table in tables {
            schema
                .add_table(
                    table,
                    &[("id".to_string(), DataType::Int { length: None, integer_spelling: false })],
                    None,
                )
                .unwrap();
        }
        schema
    }

    #[test]
    fn test_single_table_unchanged() {
        let output = isolate("SELECT * FROM t AS t", None);
        assert!(
            !output.contains("(SELECT"),
            "Single table should not be wrapped: {output}"
        );
    }

    #[test]
    fn test_single_subquery_unchanged() {
        let output = isolate("SELECT * FROM (SELECT 1) AS t", None);
        assert_eq!(
            output.matches("(SELECT").count(),
            1,
            "Single subquery source should not gain extra wrapping: {output}"
        );
    }

    #[test]
    fn test_two_tables_joined() {
        let output = isolate("SELECT * FROM a AS a JOIN b AS b ON a.id = b.id", None);
        assert!(
            output.contains(&wrapped("a")),
            "FROM table should be wrapped: {output}"
        );
        assert!(
            output.contains(&wrapped("b")),
            "JOIN table should be wrapped: {output}"
        );
    }

    #[test]
    fn test_table_with_join_subquery() {
        let output = isolate(
            "SELECT * FROM a AS a JOIN (SELECT * FROM b) AS b ON a.id = b.id",
            None,
        );
        assert!(
            output.contains(&wrapped("a")),
            "Bare table should be wrapped: {output}"
        );
        assert_eq!(
            output.matches("(SELECT * FROM b)").count(),
            1,
            "Already-subquery source should not be double-wrapped: {output}"
        );
    }

    #[test]
    fn test_no_alias_not_wrapped() {
        let output = isolate("SELECT * FROM a JOIN b ON a.id = b.id", None);
        assert!(
            !output.contains("(SELECT * FROM a"),
            "Table without alias should not be wrapped: {output}"
        );
    }

    #[test]
    fn test_schema_known_table_wrapped() {
        let schema = schema_with_tables(&["a", "b"]);
        let output = isolate(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id",
            Some(&schema),
        );
        for name in ["a", "b"] {
            assert!(
                output.contains(&wrapped(name)),
                "Known table '{name}' should be wrapped: {output}"
            );
        }
    }

    #[test]
    fn test_schema_unknown_table_not_wrapped() {
        // Only 'a' is registered; 'b' is unknown to the schema.
        let schema = schema_with_tables(&["a"]);
        let output = isolate(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id",
            Some(&schema),
        );
        assert!(
            output.contains(&wrapped("a")),
            "Known table 'a' should be wrapped: {output}"
        );
        assert!(
            !output.contains(&wrapped("b")),
            "Unknown table 'b' should NOT be wrapped: {output}"
        );
    }

    #[test]
    fn test_cte_inner_query_processed() {
        let output = isolate(
            "WITH cte AS (SELECT * FROM x AS x JOIN y AS y ON x.id = y.id) SELECT * FROM cte AS c",
            None,
        );
        for name in ["x", "y"] {
            assert!(
                output.contains(&wrapped(name)),
                "CTE inner table '{name}' should be wrapped: {output}"
            );
        }
    }

    #[test]
    fn test_nested_subquery_processed() {
        let output = isolate(
            "SELECT * FROM (SELECT * FROM a AS a JOIN b AS b ON a.id = b.id) AS sub",
            None,
        );
        assert!(
            output.contains(&wrapped("a")),
            "Nested inner table 'a' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_union_both_sides_processed() {
        let output = isolate(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id UNION ALL SELECT * FROM c AS c",
            None,
        );
        assert!(
            output.contains(&wrapped("a")),
            "UNION left side should be processed: {output}"
        );
        assert!(
            !output.contains(&wrapped("c")),
            "UNION right side (single source) should not be wrapped: {output}"
        );
    }

    #[test]
    fn test_cross_join() {
        let output = isolate("SELECT * FROM a AS a CROSS JOIN b AS b", None);
        for name in ["a", "b"] {
            assert!(
                output.contains(&wrapped(name)),
                "CROSS JOIN table '{name}' should be wrapped: {output}"
            );
        }
    }

    #[test]
    fn test_multiple_from_tables() {
        let output = isolate("SELECT * FROM a AS a, b AS b", None);
        for name in ["a", "b"] {
            assert!(
                output.contains(&wrapped(name)),
                "Comma-join table '{name}' should be wrapped: {output}"
            );
        }
    }

    #[test]
    fn test_three_way_join() {
        let output = isolate(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id JOIN c AS c ON b.id = c.id",
            None,
        );
        for name in ["a", "b", "c"] {
            assert!(
                output.contains(&wrapped(name)),
                "Three-way join: '{name}' should be wrapped: {output}"
            );
        }
    }

    #[test]
    fn test_qualified_table_name_with_schema() {
        let schema = schema_with_tables(&["mydb.a", "mydb.b"]);
        let output = isolate(
            "SELECT * FROM mydb.a AS a JOIN mydb.b AS b ON a.id = b.id",
            Some(&schema),
        );
        for name in ["a", "b"] {
            let expected = format!("(SELECT * FROM mydb.{name} AS {name}) AS {name}");
            assert!(
                output.contains(&expected),
                "Qualified table 'mydb.{name}' should be wrapped: {output}"
            );
        }
    }

    #[test]
    fn test_non_select_expression_unchanged() {
        let expr = parse("INSERT INTO t VALUES (1)");
        let original = render(&expr);
        let output = render(&isolate_table_selects(expr, None, None));
        assert_eq!(original, output, "Non-SELECT should be unchanged");
    }
}