use crate::dialects::DialectType;
use crate::expressions::*;
use crate::schema::Schema;
/// Errors associated with the table-isolation pass.
///
/// NOTE(review): nothing in the visible part of this file constructs this
/// error — `maybe_wrap_table` hands alias-less tables back unchanged instead
/// of failing. Confirm whether other modules rely on it before removing it.
#[derive(Debug, Clone, thiserror::Error)]
pub enum IsolateTableSelectsError {
/// A table source has no alias. The `String` payload is interpolated into
/// the message after the colon.
#[error("Tables require an alias: {0}")]
MissingAlias(String),
}
/// Isolates the table sources of multi-source `SELECT`s into their own
/// `(SELECT * FROM t AS t) AS t` subqueries.
///
/// * `Select` — rewritten through `isolate_select`.
/// * `Union` / `Intersect` / `Except` — both operands are rewritten
///   recursively and the set operation itself is kept.
/// * anything else — returned untouched.
///
/// `schema`, when given, limits wrapping to tables it has columns for.
/// `_dialect` is only forwarded to the recursive calls on set-operation
/// operands; nothing in this function inspects it.
pub fn isolate_table_selects(
    expression: Expression,
    schema: Option<&dyn Schema>,
    _dialect: Option<DialectType>,
) -> Expression {
    /// Rewrites one set-operation operand where it sits.
    fn rewrite_operand(
        slot: &mut Expression,
        schema: Option<&dyn Schema>,
        dialect: Option<DialectType>,
    ) {
        // `Expression::Null` is a throwaway placeholder: it lets the operand
        // be moved out of the node and is overwritten on the next line.
        let operand = std::mem::replace(slot, Expression::Null(Null));
        *slot = isolate_table_selects(operand, schema, dialect);
    }

    match expression {
        Expression::Select(select) => {
            Expression::Select(Box::new(isolate_select(*select, schema)))
        }
        Expression::Union(mut set_op) => {
            rewrite_operand(&mut set_op.left, schema, _dialect);
            rewrite_operand(&mut set_op.right, schema, _dialect);
            Expression::Union(set_op)
        }
        Expression::Intersect(mut set_op) => {
            rewrite_operand(&mut set_op.left, schema, _dialect);
            rewrite_operand(&mut set_op.right, schema, _dialect);
            Expression::Intersect(set_op)
        }
        Expression::Except(mut set_op) => {
            rewrite_operand(&mut set_op.left, schema, _dialect);
            rewrite_operand(&mut set_op.right, schema, _dialect);
            Expression::Except(set_op)
        }
        untouched => untouched,
    }
}
fn isolate_select(mut select: Select, schema: Option<&dyn Schema>) -> Select {
if let Some(ref mut with) = select.with {
for cte in &mut with.ctes {
cte.this = isolate_table_selects(cte.this.clone(), schema, None);
}
}
if let Some(ref mut from) = select.from {
for expr in &mut from.expressions {
if let Expression::Subquery(ref mut sq) = expr {
sq.this = isolate_table_selects(sq.this.clone(), schema, None);
}
}
}
for join in &mut select.joins {
if let Expression::Subquery(ref mut sq) = join.this {
sq.this = isolate_table_selects(sq.this.clone(), schema, None);
}
}
let source_count = count_sources(&select);
if source_count <= 1 {
return select;
}
if let Some(ref mut from) = select.from {
from.expressions = from
.expressions
.drain(..)
.map(|expr| maybe_wrap_table(expr, schema))
.collect();
}
for join in &mut select.joins {
join.this = maybe_wrap_table(join.this.clone(), schema);
}
select
}
/// Returns how many sources a `SELECT` reads from: the number of entries in
/// its `FROM` clause (zero when there is no `FROM`) plus the number of joins.
fn count_sources(select: &Select) -> usize {
    let from_sources = match &select.from {
        Some(from) => from.expressions.len(),
        None => 0,
    };
    from_sources + select.joins.len()
}
/// Wraps `expression` in an isolating subquery when it is a table that
/// qualifies; otherwise returns it unchanged.
///
/// A table qualifies when both hold:
/// * no schema was supplied, or the schema reports at least one column for
///   the table's dotted name (a lookup error counts as "no columns");
/// * the table has a non-empty alias, which becomes the subquery's alias.
///
/// Non-table expressions are always returned as they are.
fn maybe_wrap_table(expression: Expression, schema: Option<&dyn Schema>) -> Expression {
    match expression {
        Expression::Table(ref table) => {
            // With a schema present, only tables it knows columns for are wrapped.
            let unknown_to_schema = schema.map_or(false, |s| {
                let qualified = full_table_name(table);
                s.column_names(&qualified).unwrap_or_default().is_empty()
            });
            if unknown_to_schema {
                return expression;
            }

            // Without an alias there is no name to give the subquery.
            let usable_alias = table
                .alias
                .as_ref()
                .filter(|alias| !alias.name.is_empty());
            match usable_alias {
                Some(alias) => wrap_table_in_subquery(*table.clone(), &alias.name),
                None => expression,
            }
        }
        _ => expression,
    }
}
/// Builds `(SELECT * FROM <table>) AS <alias_name>`.
///
/// `table` is placed in the inner `FROM` exactly as received (so it keeps
/// whatever alias it already carries), and the surrounding subquery is given
/// `alias_name`. Every other subquery field is left empty / `false`.
fn wrap_table_in_subquery(table: TableRef, alias_name: &str) -> Expression {
    // Unqualified `*` with no EXCEPT / REPLACE / RENAME modifiers.
    let select_all = Expression::Star(Star {
        table: None,
        except: None,
        replace: None,
        rename: None,
        trailing_comments: Vec::new(),
        span: None,
    });
    let source = Expression::Table(Box::new(table));
    let isolated = Select::new().column(select_all).from(source);

    let subquery = Subquery {
        this: Expression::Select(Box::new(isolated)),
        alias: Some(Identifier::new(alias_name)),
        column_aliases: Vec::new(),
        order_by: None,
        limit: None,
        offset: None,
        distribute_by: None,
        sort_by: None,
        cluster_by: None,
        lateral: false,
        modifiers_inside: false,
        trailing_comments: Vec::new(),
        inferred_type: None,
    };
    Expression::Subquery(Box::new(subquery))
}
fn full_table_name(table: &TableRef) -> String {
let mut parts = Vec::new();
if let Some(ref catalog) = table.catalog {
parts.push(catalog.name.as_str());
}
if let Some(ref schema) = table.schema {
parts.push(schema.name.as_str());
}
parts.push(&table.name.name);
parts.join(".")
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;
    use crate::schema::MappingSchema;

    /// Parses `sql` and returns its first statement.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    /// Renders an expression back to SQL text with a default generator.
    fn render(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parses `sql`, runs the pass with no dialect, and returns the generated SQL.
    fn isolate(sql: &str, schema: Option<&dyn Schema>) -> String {
        render(&isolate_table_selects(parse(sql), schema, None))
    }

    /// Builds a schema in which every listed table has a single integer `id` column.
    fn schema_with_id_tables(names: &[&str]) -> MappingSchema {
        let mut schema = MappingSchema::new();
        for name in names.iter().copied() {
            schema
                .add_table(
                    name,
                    &[(
                        "id".to_string(),
                        DataType::Int {
                            length: None,
                            integer_spelling: false,
                        },
                    )],
                    None,
                )
                .unwrap();
        }
        schema
    }

    #[test]
    fn test_single_table_unchanged() {
        let output = isolate("SELECT * FROM t AS t", None);
        assert!(
            !output.contains("(SELECT"),
            "Single table should not be wrapped: {output}"
        );
    }

    #[test]
    fn test_single_subquery_unchanged() {
        let output = isolate("SELECT * FROM (SELECT 1) AS t", None);
        assert_eq!(
            output.matches("(SELECT").count(),
            1,
            "Single subquery source should not gain extra wrapping: {output}"
        );
    }

    #[test]
    fn test_two_tables_joined() {
        let output = isolate("SELECT * FROM a AS a JOIN b AS b ON a.id = b.id", None);
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "FROM table should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "JOIN table should be wrapped: {output}"
        );
    }

    #[test]
    fn test_table_with_join_subquery() {
        let output = isolate(
            "SELECT * FROM a AS a JOIN (SELECT * FROM b) AS b ON a.id = b.id",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Bare table should be wrapped: {output}"
        );
        assert_eq!(
            output.matches("(SELECT * FROM b)").count(),
            1,
            "Already-subquery source should not be double-wrapped: {output}"
        );
    }

    #[test]
    fn test_no_alias_not_wrapped() {
        let output = isolate("SELECT * FROM a JOIN b ON a.id = b.id", None);
        assert!(
            !output.contains("(SELECT * FROM a"),
            "Table without alias should not be wrapped: {output}"
        );
    }

    #[test]
    fn test_schema_known_table_wrapped() {
        let schema = schema_with_id_tables(&["a", "b"]);
        let output = isolate(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id",
            Some(&schema),
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Known table 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "Known table 'b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_schema_unknown_table_not_wrapped() {
        let schema = schema_with_id_tables(&["a"]);
        let output = isolate(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id",
            Some(&schema),
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Known table 'a' should be wrapped: {output}"
        );
        assert!(
            !output.contains("(SELECT * FROM b AS b) AS b"),
            "Unknown table 'b' should NOT be wrapped: {output}"
        );
    }

    #[test]
    fn test_cte_inner_query_processed() {
        let output = isolate(
            "WITH cte AS (SELECT * FROM x AS x JOIN y AS y ON x.id = y.id) SELECT * FROM cte AS c",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM x AS x) AS x"),
            "CTE inner table 'x' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM y AS y) AS y"),
            "CTE inner table 'y' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_nested_subquery_processed() {
        let output = isolate(
            "SELECT * FROM (SELECT * FROM a AS a JOIN b AS b ON a.id = b.id) AS sub",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Nested inner table 'a' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_union_both_sides_processed() {
        let output = isolate(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id UNION ALL SELECT * FROM c AS c",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "UNION left side should be processed: {output}"
        );
        assert!(
            !output.contains("(SELECT * FROM c AS c) AS c"),
            "UNION right side (single source) should not be wrapped: {output}"
        );
    }

    #[test]
    fn test_cross_join() {
        let output = isolate("SELECT * FROM a AS a CROSS JOIN b AS b", None);
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "CROSS JOIN table 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "CROSS JOIN table 'b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_multiple_from_tables() {
        let output = isolate("SELECT * FROM a AS a, b AS b", None);
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Comma-join table 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "Comma-join table 'b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_three_way_join() {
        let output = isolate(
            "SELECT * FROM a AS a JOIN b AS b ON a.id = b.id JOIN c AS c ON b.id = c.id",
            None,
        );
        assert!(
            output.contains("(SELECT * FROM a AS a) AS a"),
            "Three-way join: 'a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM b AS b) AS b"),
            "Three-way join: 'b' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM c AS c) AS c"),
            "Three-way join: 'c' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_qualified_table_name_with_schema() {
        let schema = schema_with_id_tables(&["mydb.a", "mydb.b"]);
        let output = isolate(
            "SELECT * FROM mydb.a AS a JOIN mydb.b AS b ON a.id = b.id",
            Some(&schema),
        );
        assert!(
            output.contains("(SELECT * FROM mydb.a AS a) AS a"),
            "Qualified table 'mydb.a' should be wrapped: {output}"
        );
        assert!(
            output.contains("(SELECT * FROM mydb.b AS b) AS b"),
            "Qualified table 'mydb.b' should be wrapped: {output}"
        );
    }

    #[test]
    fn test_non_select_expression_unchanged() {
        let expr = parse("INSERT INTO t VALUES (1)");
        let original = render(&expr);
        let output = render(&isolate_table_selects(expr, None, None));
        assert_eq!(original, output, "Non-SELECT should be unchanged");
    }
}