use sqlparser::ast::{
CreateView, Cte, Expr, Query, Select, SelectItem, SetExpr, Spanned, Statement, TableFactor,
TableWithJoins,
};
use crate::analyzer::helpers::{infer_expr_type, line_col_to_offset};
use crate::types::{AstColumnInfo, AstContext, AstTableInfo, CteInfo, SubqueryInfo};
/// A projection alias (`expr AS name`) that later items of the same SELECT list may refer to.
///
/// All offsets are byte offsets into the SQL text the alias was extracted from.
#[derive(Debug, Clone)]
pub struct LateralAliasInfo {
    /// The alias identifier's value (without surrounding quotes).
    pub name: String,
    /// Byte offset at the end of the alias identifier's span; checked to be a char boundary.
    pub definition_end: usize,
    /// Byte offset where the enclosing SELECT's projection list begins.
    pub projection_start: usize,
    /// Byte offset where the enclosing SELECT's projection list ends (FROM start or fallback).
    pub projection_end: usize,
}
/// Recursion guard for every AST walker in this module; deeper nodes are silently skipped.
const MAX_EXTRACTION_DEPTH: usize = 50;
/// Upper bound on the number of lateral aliases collected across all statements.
const MAX_LATERAL_ALIASES: usize = 1000;
/// Builds an [`AstContext`] by walking every statement, starting each walk at depth 0.
///
/// Results from all statements are merged into a single context; later insertions overwrite
/// earlier ones that share a key.
pub(crate) fn extract_ast_context(statements: &[Statement]) -> AstContext {
    let mut context = AstContext::default();
    statements
        .iter()
        .for_each(|statement| extract_from_statement(statement, &mut context, 0));
    context
}
/// Dispatches a single statement to the query walker.
///
/// Only statements that embed a query are walked: plain queries, `INSERT ... <query>`,
/// `CREATE TABLE ... AS <query>` and `CREATE VIEW ... AS <query>`. Everything else is ignored.
fn extract_from_statement(stmt: &Statement, ctx: &mut AstContext, depth: usize) {
    // Stop recursing on pathologically deep input.
    if depth > MAX_EXTRACTION_DEPTH {
        return;
    }
    match stmt {
        Statement::Query(query) => extract_from_query(query, ctx, depth),
        Statement::Insert(insert) => insert
            .source
            .iter()
            .for_each(|source| extract_from_query(source, ctx, depth)),
        Statement::CreateTable(ct) => ct
            .query
            .iter()
            .for_each(|query| extract_from_query(query, ctx, depth)),
        Statement::CreateView(CreateView { query, .. }) => extract_from_query(query, ctx, depth),
        _ => {}
    }
}
/// Walks a query: registers its CTEs, walks each CTE body, then walks the main body.
///
/// CTE bodies are walked so that table and subquery aliases declared inside them (e.g. the `u`
/// in `WITH c AS (SELECT * FROM users u) ...`) are collected too, just like aliases inside
/// derived subqueries. Each CTE body is walked *before* its definition is inserted, so a nested
/// CTE of the same name inside the body cannot overwrite the outer definition.
fn extract_from_query(query: &Query, ctx: &mut AstContext, depth: usize) {
    if depth > MAX_EXTRACTION_DEPTH {
        return;
    }
    if let Some(with) = &query.with {
        let is_recursive = with.recursive;
        for cte in &with.cte_tables {
            extract_from_query(&cte.query, ctx, depth + 1);
            if let Some(info) = extract_cte_info(cte, is_recursive) {
                ctx.cte_definitions.insert(info.name.clone(), info);
            }
        }
    }
    extract_from_set_expr(&query.body, ctx, depth + 1);
}
/// Walks a set expression: a SELECT, a parenthesised query, or both sides of a set operation.
///
/// Other set-expression kinds (VALUES, TABLE, embedded DML) are not walked.
fn extract_from_set_expr(set_expr: &SetExpr, ctx: &mut AstContext, depth: usize) {
    if depth > MAX_EXTRACTION_DEPTH {
        return;
    }
    match set_expr {
        SetExpr::Select(select) => extract_from_select(select, ctx, depth),
        SetExpr::Query(query) => extract_from_query(query, ctx, depth),
        SetExpr::SetOperation { left, right, .. } => {
            for branch in [left, right] {
                extract_from_set_expr(branch, ctx, depth + 1);
            }
        }
        _ => {}
    }
}
/// Walks every comma-separated FROM item of a SELECT at the same depth.
fn extract_from_select(select: &Select, ctx: &mut AstContext, depth: usize) {
    if depth <= MAX_EXTRACTION_DEPTH {
        select
            .from
            .iter()
            .for_each(|from_item| extract_from_table_with_joins(from_item, ctx, depth));
    }
}
/// Walks a FROM item's base relation followed by each joined relation, in source order.
fn extract_from_table_with_joins(twj: &TableWithJoins, ctx: &mut AstContext, depth: usize) {
    if depth > MAX_EXTRACTION_DEPTH {
        return;
    }
    let joined = twj.joins.iter().map(|join| &join.relation);
    for factor in std::iter::once(&twj.relation).chain(joined) {
        extract_from_table_factor(factor, ctx, depth);
    }
}
/// Records the alias introduced by a single FROM/JOIN relation and recurses into nested ones.
///
/// - Plain tables are keyed by their alias, or by the last part of their (possibly qualified)
///   name when unaliased.
/// - Aliased derived tables record the subquery's projected columns, and the subquery is
///   walked either way.
/// - Nested joins are walked.
/// - Aliased `UNNEST` relations are recorded as tables.
/// - Other relation kinds are ignored.
fn extract_from_table_factor(tf: &TableFactor, ctx: &mut AstContext, depth: usize) {
    if depth > MAX_EXTRACTION_DEPTH {
        return;
    }
    match tf {
        TableFactor::Table { name, alias, .. } => {
            let key = match alias {
                Some(table_alias) => table_alias.name.value.clone(),
                // Only render the full qualified name if the name has no parts at all.
                None => name
                    .0
                    .last()
                    .map_or_else(|| name.to_string(), |part| part.to_string()),
            };
            ctx.table_aliases.insert(key, AstTableInfo);
        }
        TableFactor::Derived {
            subquery, alias, ..
        } => {
            if let Some(subquery_alias) = alias {
                let info = SubqueryInfo {
                    projected_columns: extract_projected_columns_from_query(subquery),
                };
                ctx.subquery_aliases
                    .insert(subquery_alias.name.value.clone(), info);
            }
            extract_from_query(subquery, ctx, depth + 1);
        }
        TableFactor::NestedJoin {
            table_with_joins, ..
        } => extract_from_table_with_joins(table_with_joins, ctx, depth + 1),
        TableFactor::UNNEST {
            alias: Some(unnest_alias),
            ..
        } => {
            ctx.table_aliases
                .insert(unnest_alias.name.value.clone(), AstTableInfo);
        }
        _ => {}
    }
}
/// Summarises one CTE: its name, any explicitly declared column list, and its projected columns.
///
/// For a recursive CTE only the base-case (anchor) branch is inspected for projected columns.
/// The current implementation always returns `Some`.
fn extract_cte_info(cte: &Cte, is_recursive: bool) -> Option<CteInfo> {
    let alias = &cte.alias;
    let declared_columns = alias
        .columns
        .iter()
        .map(|column| column.name.value.clone())
        .collect::<Vec<String>>();
    let projected_columns = if !is_recursive {
        extract_projected_columns_from_query(&cte.query)
    } else {
        extract_base_case_columns(&cte.query)
    };
    Some(CteInfo {
        name: alias.name.value.clone(),
        declared_columns,
        projected_columns,
    })
}
fn extract_base_case_columns(query: &Query) -> Vec<AstColumnInfo> {
match &*query.body {
SetExpr::SetOperation { left, .. } => {
if let SetExpr::Select(select) = &**left {
extract_select_columns(select)
} else {
vec![]
}
}
SetExpr::Select(select) => extract_select_columns(select),
_ => vec![],
}
}
fn extract_projected_columns_from_query(query: &Query) -> Vec<AstColumnInfo> {
match &*query.body {
SetExpr::Select(select) => extract_select_columns(select),
SetExpr::SetOperation { left, .. } => {
if let SetExpr::Select(select) = &**left {
extract_select_columns(select)
} else {
vec![]
}
}
_ => vec![],
}
}
/// Describes each projection item of `select` as an [`AstColumnInfo`], in order.
///
/// - Aliased expressions use the alias as the name.
/// - Unaliased expressions get a derived name.
/// - Wildcards are reported as `*` or `qualifier.*` with no inferred type.
fn extract_select_columns(select: &Select) -> Vec<AstColumnInfo> {
    select
        .projection
        .iter()
        .enumerate()
        .map(|(position, item)| match item {
            SelectItem::ExprWithAlias { alias, expr } => AstColumnInfo {
                name: alias.value.clone(),
                data_type: infer_data_type(expr),
            },
            SelectItem::UnnamedExpr(expr) => AstColumnInfo {
                name: derive_column_name(expr, position),
                data_type: infer_data_type(expr),
            },
            SelectItem::Wildcard(_) => AstColumnInfo {
                name: "*".to_string(),
                data_type: None,
            },
            SelectItem::QualifiedWildcard(qualifier, _) => AstColumnInfo {
                name: format!("{}.*", qualifier),
                data_type: None,
            },
        })
        .collect()
}
/// Picks a display name for an unaliased projection expression at position `index`.
///
/// - Identifiers use their (last) part.
/// - Function calls use their lowercased name.
/// - CASE expressions and scalar subqueries get `case_N` / `subquery_N`.
/// - Everything else, including casts, gets `col_N`.
fn derive_column_name(expr: &Expr, index: usize) -> String {
    let positional = || format!("col_{}", index);
    match expr {
        Expr::Identifier(ident) => ident.value.clone(),
        Expr::CompoundIdentifier(parts) => parts
            .last()
            .map_or_else(positional, |last| last.value.clone()),
        Expr::Function(func) => func.name.to_string().to_lowercase(),
        Expr::Case { .. } => format!("case_{}", index),
        Expr::Subquery(_) => format!("subquery_{}", index),
        _ => positional(),
    }
}
/// Infers an expression's type and renders it as an uppercase type name, if one is known.
fn infer_data_type(expr: &Expr) -> Option<String> {
    let canonical = infer_expr_type(expr)?;
    Some(canonical.as_uppercase_str().to_string())
}
pub(crate) fn extract_lateral_aliases(
statements: &[Statement],
sql: &str,
) -> Vec<LateralAliasInfo> {
let mut aliases = Vec::with_capacity(64);
for stmt in statements {
if aliases.len() >= MAX_LATERAL_ALIASES {
break;
}
if let Statement::Query(query) = stmt {
if let Some(with) = &query.with {
for cte in &with.cte_tables {
if aliases.len() >= MAX_LATERAL_ALIASES {
break;
}
extract_lateral_aliases_from_set_expr(&cte.query.body, sql, &mut aliases, 0);
}
}
if aliases.len() < MAX_LATERAL_ALIASES {
extract_lateral_aliases_from_set_expr(&query.body, sql, &mut aliases, 0);
}
}
}
aliases
}
fn extract_lateral_aliases_from_set_expr(
set_expr: &SetExpr,
sql: &str,
aliases: &mut Vec<LateralAliasInfo>,
depth: usize,
) {
if depth > MAX_EXTRACTION_DEPTH || aliases.len() >= MAX_LATERAL_ALIASES {
return;
}
match set_expr {
SetExpr::Select(select) => {
extract_lateral_aliases_from_select(select, sql, aliases);
}
SetExpr::Query(query) => {
if let Some(with) = &query.with {
for cte in &with.cte_tables {
if aliases.len() >= MAX_LATERAL_ALIASES {
break;
}
extract_lateral_aliases_from_set_expr(&cte.query.body, sql, aliases, depth + 1);
}
}
if aliases.len() < MAX_LATERAL_ALIASES {
extract_lateral_aliases_from_set_expr(&query.body, sql, aliases, depth + 1);
}
}
SetExpr::SetOperation { left, right, .. } => {
extract_lateral_aliases_from_set_expr(left, sql, aliases, depth + 1);
if aliases.len() < MAX_LATERAL_ALIASES {
extract_lateral_aliases_from_set_expr(right, sql, aliases, depth + 1);
}
}
_ => {}
}
}
/// Records every `expr AS alias` item of `select`'s projection list as a [`LateralAliasInfo`].
///
/// - All recorded aliases share the span returned by `compute_projection_span`. If that span
///   is unavailable, the SELECT is skipped entirely.
/// - `definition_end` is the byte offset of the alias span's end. Aliases whose offset cannot
///   be resolved to a char boundary inside `sql` are dropped.
/// - Stops once `aliases` holds `MAX_LATERAL_ALIASES` entries.
fn extract_lateral_aliases_from_select(
    select: &Select,
    sql: &str,
    aliases: &mut Vec<LateralAliasInfo>,
) {
    if aliases.len() >= MAX_LATERAL_ALIASES {
        return;
    }
    let (projection_start, projection_end) = match compute_projection_span(select, sql) {
        Some(span) => span,
        None => return,
    };
    let named_items = select.projection.iter().filter_map(|item| match item {
        SelectItem::ExprWithAlias { alias, .. } => Some(alias),
        _ => None,
    });
    for alias in named_items {
        if aliases.len() >= MAX_LATERAL_ALIASES {
            break;
        }
        let resolved_end = line_col_to_offset(
            sql,
            alias.span.end.line as usize,
            alias.span.end.column as usize,
        )
        .filter(|&offset| offset <= sql.len() && sql.is_char_boundary(offset));
        if let Some(definition_end) = resolved_end {
            aliases.push(LateralAliasInfo {
                name: alias.value.clone(),
                definition_end,
                projection_start,
                projection_end,
            });
        }
    }
}
fn compute_projection_span(select: &Select, sql: &str) -> Option<(usize, usize)> {
if select.projection.is_empty() {
return None;
}
let first_span = select
.projection
.iter()
.filter_map(select_item_span)
.next()
.or_else(|| {
let span = select.span();
if span.start.line > 0 && span.start.column > 0 {
Some((span.start.line, span.start.column))
} else {
None
}
})?;
let start = line_col_to_offset(sql, first_span.0 as usize, first_span.1 as usize)?;
let end = if let Some(from_item) = select.from.first() {
compute_from_clause_start(from_item, sql).unwrap_or_else(|| {
select
.projection
.last()
.and_then(|item| {
let span = select_item_end_span(item)?;
line_col_to_offset(sql, span.0 as usize, span.1 as usize)
})
.unwrap_or(sql.len())
})
} else {
sql.len()
};
if start <= sql.len() && end <= sql.len() && start <= end {
Some((start, end))
} else {
None
}
}
/// Estimates the byte offset where the FROM clause starts, given the clause's first item.
///
/// The first relation's start offset is located, then the preceding
/// `FROM_KEYWORD_SEARCH_WINDOW` bytes are searched backwards for the keyword `FROM`
/// (ASCII case-insensitive). If the keyword is not found, the relation's own start offset is
/// returned.
///
/// Returns `None` in two cases:
/// - the relation has no usable position;
/// - the computed offset does not fall on a char boundary inside `sql`. Slicing is done with
///   `str::get` instead of indexing, so a bad offset cannot cause a panic.
fn compute_from_clause_start(from_item: &TableWithJoins, sql: &str) -> Option<usize> {
    // How far before the first relation to look for the FROM keyword.
    const FROM_KEYWORD_SEARCH_WINDOW: usize = 50;

    let (line, column) = table_factor_span(&from_item.relation)?;
    let table_start = line_col_to_offset(sql, line as usize, column as usize)?;
    let search_start = find_char_boundary_before(
        sql,
        table_start.saturating_sub(FROM_KEYWORD_SEARCH_WINDOW),
    );
    // `table_start` is not validated by `line_col_to_offset`'s contract as seen here, so a
    // past-the-end or mid-character offset yields `None` rather than a slicing panic.
    let search_area = sql.get(search_start..table_start)?;
    let from_start = rfind_ascii_case_insensitive(search_area, b"FROM")
        .map_or(table_start, |pos| search_start + pos);
    Some(from_start)
}
/// Returns the greatest char-boundary index `<= pos` in `s`, clamping to `s.len()` when `pos`
/// is at or past the end.
fn find_char_boundary_before(s: &str, pos: usize) -> usize {
    if pos >= s.len() {
        return s.len();
    }
    let mut idx = pos;
    // Index 0 is always a char boundary, so this never underflows.
    while !s.is_char_boundary(idx) {
        idx -= 1;
    }
    idx
}
/// Returns the byte index of the last ASCII-case-insensitive occurrence of `needle` in
/// `haystack`, or `None` if `needle` is empty or does not occur.
///
/// When `needle` is ASCII, a match always starts on a char boundary, because ASCII bytes
/// never appear inside multi-byte UTF-8 sequences.
fn rfind_ascii_case_insensitive(haystack: &str, needle: &[u8]) -> Option<usize> {
    // `windows(0)` would panic, so an empty needle is rejected up front.
    if needle.is_empty() {
        return None;
    }
    haystack
        .as_bytes()
        .windows(needle.len())
        .rposition(|window| window.eq_ignore_ascii_case(needle))
}
/// Returns the `(line, column)` where a relation starts.
///
/// - Plain tables: the first part of their name.
/// - Derived tables: their subquery body, provided it carries a real (non-zero) line.
/// - Other relation kinds: `None`.
fn table_factor_span(tf: &TableFactor) -> Option<(u64, u64)> {
    let start = match tf {
        TableFactor::Table { name, .. } => name.0.first()?.span().start,
        TableFactor::Derived { subquery, .. } => {
            let body_start = subquery.body.span().start;
            if body_start.line == 0 {
                return None;
            }
            body_start
        }
        _ => return None,
    };
    Some((start.line, start.column))
}
/// Returns a start `(line, column)` for a projection item, if one is recorded.
///
/// - Expressions use their own start position.
/// - Qualified wildcards use the qualifier's start.
/// - A bare `*` only yields a position when it has an `EXCLUDE` list. In that case the first
///   excluded identifier's position is used.
fn select_item_span(item: &SelectItem) -> Option<(u64, u64)> {
    match item {
        SelectItem::ExprWithAlias { expr, .. } | SelectItem::UnnamedExpr(expr) => {
            expr_start_span(expr)
        }
        SelectItem::Wildcard(opts) => {
            let first_excluded = match opts.opt_exclude.as_ref()? {
                sqlparser::ast::ExcludeSelectItem::Single(ident) => ident,
                sqlparser::ast::ExcludeSelectItem::Multiple(idents) => idents.first()?,
            };
            Some((
                first_excluded.span.start.line,
                first_excluded.span.start.column,
            ))
        }
        SelectItem::QualifiedWildcard(name, _) => {
            let start = name.span().start;
            Some((start.line, start.column))
        }
    }
}
/// Returns an end `(line, column)` for a projection item, if one is recorded.
///
/// - Aliased items use the alias's end.
/// - Unaliased expressions use the expression's end.
/// - Qualified wildcards use the qualifier's end.
/// - A bare `*` yields `None`.
fn select_item_end_span(item: &SelectItem) -> Option<(u64, u64)> {
    match item {
        SelectItem::UnnamedExpr(expr) => expr_end_span(expr),
        SelectItem::ExprWithAlias { alias, .. } => {
            let end = &alias.span.end;
            Some((end.line, end.column))
        }
        SelectItem::QualifiedWildcard(name, _) => {
            let end = name.span().end;
            Some((end.line, end.column))
        }
        SelectItem::Wildcard(_) => None,
    }
}
/// Returns the expression's start `(line, column)`, or `None` when either coordinate is 0
/// (i.e. no real location was recorded).
fn expr_start_span(expr: &Expr) -> Option<(u64, u64)> {
    let start = expr.span().start;
    let has_location = start.line > 0 && start.column > 0;
    has_location.then(|| (start.line, start.column))
}
/// Returns the expression's end `(line, column)`, or `None` when either coordinate is 0
/// (i.e. no real location was recorded).
fn expr_end_span(expr: &Expr) -> Option<(u64, u64)> {
    let end = expr.span().end;
    let has_location = end.line > 0 && end.column > 0;
    has_location.then(|| (end.line, end.column))
}
#[cfg(test)]
mod tests {
    use super::*;
    use sqlparser::parser::Parser;

    /// Parses `sql` with the generic dialect; panics on parse errors (test inputs are valid).
    fn parse_sql(sql: &str) -> Vec<Statement> {
        Parser::parse_sql(&sqlparser::dialect::GenericDialect {}, sql).unwrap()
    }

    #[test]
    fn test_extract_cte() {
        let sql = "WITH cte AS (SELECT id, name FROM users) SELECT * FROM cte";
        let stmts = parse_sql(sql);
        let ctx = extract_ast_context(&stmts);
        assert!(ctx.cte_definitions.contains_key("cte"));
        let cte = &ctx.cte_definitions["cte"];
        assert_eq!(cte.name, "cte");
        assert_eq!(cte.projected_columns.len(), 2);
        assert_eq!(cte.projected_columns[0].name, "id");
        assert_eq!(cte.projected_columns[1].name, "name");
    }

    #[test]
    fn test_extract_cte_with_declared_columns() {
        let sql = "WITH cte(a, b) AS (SELECT id, name FROM users) SELECT * FROM cte";
        let stmts = parse_sql(sql);
        let ctx = extract_ast_context(&stmts);
        let cte = &ctx.cte_definitions["cte"];
        assert_eq!(cte.declared_columns, vec!["a", "b"]);
    }

    #[test]
    fn test_extract_table_alias() {
        let sql = "SELECT * FROM users u JOIN orders o ON u.id = o.user_id";
        let stmts = parse_sql(sql);
        let ctx = extract_ast_context(&stmts);
        assert!(ctx.table_aliases.contains_key("u"));
        assert!(ctx.table_aliases.contains_key("o"));
    }

    #[test]
    fn test_extract_subquery_alias() {
        let sql = "SELECT * FROM (SELECT a, b FROM t) AS sub WHERE sub.a = 1";
        let stmts = parse_sql(sql);
        let ctx = extract_ast_context(&stmts);
        assert!(ctx.subquery_aliases.contains_key("sub"));
        let sub = &ctx.subquery_aliases["sub"];
        assert_eq!(sub.projected_columns.len(), 2);
        assert_eq!(sub.projected_columns[0].name, "a");
        assert_eq!(sub.projected_columns[1].name, "b");
    }

    #[test]
    fn test_extract_lateral_subquery() {
        let sql = "SELECT * FROM users u, LATERAL (SELECT * FROM orders WHERE user_id = u.id) AS o";
        let stmts = parse_sql(sql);
        let ctx = extract_ast_context(&stmts);
        assert!(ctx.subquery_aliases.contains_key("o"));
    }

    #[test]
    fn test_extract_column_with_alias() {
        let sql =
            "WITH cte AS (SELECT id AS user_id, name AS user_name FROM users) SELECT * FROM cte";
        let stmts = parse_sql(sql);
        let ctx = extract_ast_context(&stmts);
        let cte = &ctx.cte_definitions["cte"];
        assert_eq!(cte.projected_columns[0].name, "user_id");
        assert_eq!(cte.projected_columns[1].name, "user_name");
    }

    #[test]
    fn test_extract_function_column_name() {
        let sql = "WITH cte AS (SELECT COUNT(*), SUM(amount) FROM orders) SELECT * FROM cte";
        let stmts = parse_sql(sql);
        let ctx = extract_ast_context(&stmts);
        let cte = &ctx.cte_definitions["cte"];
        assert!(cte.projected_columns[0]
            .name
            .to_lowercase()
            .contains("count"));
    }

    #[test]
    fn test_extract_wildcard() {
        let sql = "WITH cte AS (SELECT * FROM users) SELECT * FROM cte";
        let stmts = parse_sql(sql);
        let ctx = extract_ast_context(&stmts);
        let cte = &ctx.cte_definitions["cte"];
        assert_eq!(cte.projected_columns[0].name, "*");
    }

    #[test]
    fn test_extract_recursive_cte() {
        let sql = r#"
WITH RECURSIVE cte AS (
SELECT 1 AS n
UNION ALL
SELECT n + 1 FROM cte WHERE n < 10
)
SELECT * FROM cte
"#;
        let stmts = parse_sql(sql);
        let ctx = extract_ast_context(&stmts);
        let cte = &ctx.cte_definitions["cte"];
        assert_eq!(cte.projected_columns.len(), 1);
        assert_eq!(cte.projected_columns[0].name, "n");
    }

    #[test]
    fn test_has_enrichment() {
        let sql = "SELECT * FROM users";
        let stmts = parse_sql(sql);
        let ctx = extract_ast_context(&stmts);
        assert!(ctx.has_enrichment());
    }

    #[test]
    fn test_empty_context() {
        let ctx = AstContext::default();
        assert!(!ctx.has_enrichment());
    }

    #[test]
    fn test_extract_lateral_aliases_single() {
        let sql = "SELECT price * qty AS total FROM orders";
        let stmts = parse_sql(sql);
        let aliases = extract_lateral_aliases(&stmts, sql);
        assert_eq!(aliases.len(), 1);
        assert_eq!(aliases[0].name, "total");
        assert!(aliases[0].definition_end > 0);
        assert!(aliases[0].definition_end <= sql.len());
    }

    #[test]
    fn test_extract_lateral_aliases_with_leading_wildcard() {
        let sql = "SELECT *, price * qty AS total, discount AS disc FROM orders";
        let stmts = parse_sql(sql);
        let aliases = extract_lateral_aliases(&stmts, sql);
        let names: Vec<_> = aliases.iter().map(|a| a.name.as_str()).collect();
        assert_eq!(names, vec!["total", "disc"]);
    }

    #[test]
    fn test_extract_lateral_aliases_multiple() {
        let sql = "SELECT a AS x, b AS y, c AS z FROM t";
        let stmts = parse_sql(sql);
        let aliases = extract_lateral_aliases(&stmts, sql);
        assert_eq!(aliases.len(), 3);
        assert_eq!(aliases[0].name, "x");
        assert_eq!(aliases[1].name, "y");
        assert_eq!(aliases[2].name, "z");
        assert!(aliases[0].definition_end < aliases[1].definition_end);
        assert!(aliases[1].definition_end < aliases[2].definition_end);
    }

    #[test]
    fn test_extract_lateral_aliases_with_expression() {
        let sql = "SELECT price * qty AS total, total * 0.1 AS tax FROM orders";
        let stmts = parse_sql(sql);
        let aliases = extract_lateral_aliases(&stmts, sql);
        assert_eq!(aliases.len(), 2);
        assert_eq!(aliases[0].name, "total");
        assert_eq!(aliases[1].name, "tax");
    }

    #[test]
    fn test_extract_lateral_aliases_no_aliases() {
        let sql = "SELECT price, qty FROM orders";
        let stmts = parse_sql(sql);
        let aliases = extract_lateral_aliases(&stmts, sql);
        assert!(aliases.is_empty());
    }

    #[test]
    fn test_extract_lateral_aliases_mixed() {
        let sql = "SELECT a, b AS alias_b, c FROM t";
        let stmts = parse_sql(sql);
        let aliases = extract_lateral_aliases(&stmts, sql);
        assert_eq!(aliases.len(), 1);
        assert_eq!(aliases[0].name, "alias_b");
    }

    #[test]
    fn test_extract_lateral_aliases_quoted() {
        let sql = r#"SELECT a AS "My Total", b AS "Tax Amount" FROM t"#;
        let stmts = parse_sql(sql);
        let aliases = extract_lateral_aliases(&stmts, sql);
        assert_eq!(aliases.len(), 2);
        assert_eq!(aliases[0].name, "My Total");
        assert_eq!(aliases[1].name, "Tax Amount");
    }

    #[test]
    fn test_extract_lateral_aliases_subquery_in_from() {
        let sql = "SELECT * FROM (SELECT a AS x, b AS y FROM t) sub";
        let stmts = parse_sql(sql);
        let aliases = extract_lateral_aliases(&stmts, sql);
        assert_eq!(aliases.len(), 0);
    }

    #[test]
    fn test_extract_lateral_aliases_outer_select_with_alias() {
        let sql = "SELECT sub.x AS outer_x FROM (SELECT a AS x FROM t) sub";
        let stmts = parse_sql(sql);
        let aliases = extract_lateral_aliases(&stmts, sql);
        assert_eq!(aliases.len(), 1);
        assert_eq!(aliases[0].name, "outer_x");
    }

    #[test]
    fn test_extract_lateral_aliases_with_unicode() {
        let sql = "SELECT '日本語' AS label, value AS val FROM t";
        let stmts = parse_sql(sql);
        let aliases = extract_lateral_aliases(&stmts, sql);
        assert_eq!(aliases.len(), 2);
        assert_eq!(aliases[0].name, "label");
        assert_eq!(aliases[1].name, "val");
    }

    #[test]
    fn test_extract_lateral_aliases_cte_scope_isolation() {
        let sql =
            "WITH cte AS (SELECT a AS inner_alias FROM t) SELECT cte.a AS outer_alias FROM cte";
        let stmts = parse_sql(sql);
        let aliases = extract_lateral_aliases(&stmts, sql);
        assert_eq!(aliases.len(), 2);
        let inner = aliases.iter().find(|a| a.name == "inner_alias").unwrap();
        let outer = aliases.iter().find(|a| a.name == "outer_alias").unwrap();
        assert!(
            inner.projection_start < outer.projection_start,
            "CTE projection should start before outer SELECT projection"
        );
        // The CTE's projection must end (at its own FROM) before the outer projection begins;
        // otherwise the inner alias would leak into the outer SELECT list.
        assert!(
            inner.projection_end <= outer.projection_start,
            "CTE and outer SELECT projections should not overlap"
        );
    }

    #[test]
    fn test_extract_lateral_aliases_projection_span_validity() {
        let sql = "SELECT a AS x, b AS y FROM t";
        let stmts = parse_sql(sql);
        let aliases = extract_lateral_aliases(&stmts, sql);
        assert_eq!(aliases.len(), 2);
        for alias in &aliases {
            assert!(
                alias.definition_end <= alias.projection_end,
                "Alias definition should be within projection span"
            );
            assert!(
                alias.projection_start < alias.definition_end,
                "Projection should start before alias definition ends"
            );
        }
    }

    #[test]
    fn test_rfind_ascii_case_insensitive() {
        assert_eq!(
            rfind_ascii_case_insensitive("select x FROM t from", b"FROM"),
            Some(16)
        );
        assert_eq!(rfind_ascii_case_insensitive("select x FROM t", b"from"), Some(9));
        assert_eq!(rfind_ascii_case_insensitive("abc", b"ABCD"), None);
        assert_eq!(rfind_ascii_case_insensitive("abc", b""), None);
    }

    #[test]
    fn test_find_char_boundary_before() {
        // "é" occupies bytes 1..3, so byte 2 is mid-character.
        let s = "aé";
        assert_eq!(find_char_boundary_before(s, 0), 0);
        assert_eq!(find_char_boundary_before(s, 2), 1);
        assert_eq!(find_char_boundary_before(s, 10), s.len());
    }
}