use datafusion::error::DataFusionError;
use datafusion::sql::sqlparser;
use datafusion::sql::sqlparser::{
ast::{Expr, Query, Select, SelectItem, SetExpr, Statement, TableFactor, TableWithJoins},
dialect::GenericDialect,
parser::Parser,
};
use std::collections::HashSet;
/// Parses `sql_query` and collects the relation names and CTE names it mentions.
///
/// The SQL text is parsed with [`GenericDialect`] and may contain several
/// statements; the results of all of them are merged.
///
/// # Returns
///
/// A `(table_names, cte_names)` pair. Every name is lower-cased, and for
/// qualified names only the last identifier is kept. A CTE that is referenced
/// in a `FROM` clause shows up in *both* sets, so callers that want physical
/// tables only should subtract `cte_names` from `table_names`.
///
/// Statement coverage:
/// * `Query` – the whole query is walked.
/// * `INSERT` – only the source query is walked; the target is not recorded.
/// * `UPDATE` – the target (with joins), the assigned values and the `WHERE`
///   predicate are walked. Other clauses are not inspected.
/// * `DELETE` – the multi-table list, the `FROM` list (with or without the
///   `FROM` keyword), the `USING` list and the `WHERE` predicate are walked.
/// * Any other statement contributes nothing.
///
/// # Errors
///
/// Returns [`DataFusionError::Execution`] when the SQL text cannot be parsed.
pub fn extract_table_names_and_ctes(
    sql_query: &str,
) -> Result<(HashSet<String>, HashSet<String>), DataFusionError> {
    let dialect = GenericDialect {};
    let statements = Parser::parse_sql(&dialect, sql_query)
        .map_err(|e| DataFusionError::Execution(format!("Failed to parse SQL: {}", e)))?;

    let mut table_names = HashSet::new();
    let mut cte_names = HashSet::new();

    for statement in statements {
        match statement {
            Statement::Query(query) => {
                extract_from_query(&query, &mut table_names, &mut cte_names);
            }
            Statement::Insert(insert) => {
                if let Some(query) = &insert.source {
                    extract_from_query(query, &mut table_names, &mut cte_names);
                }
            }
            Statement::Update {
                table,
                assignments,
                selection,
                ..
            } => {
                extract_from_table_with_joins(&table, &mut table_names);
                // `SET col = (SELECT ...)` can reference further tables.
                for assignment in &assignments {
                    extract_from_expr(&assignment.value, &mut table_names);
                }
                if let Some(predicate) = &selection {
                    extract_from_expr(predicate, &mut table_names);
                }
            }
            Statement::Delete(delete) => {
                // Multi-table form: `DELETE t1, t2 FROM ...`.
                for table in &delete.tables {
                    if let Some(table_name) = table.0.last() {
                        table_names.insert(table_name.value.to_lowercase());
                    }
                }
                // The generic dialect accepts `DELETE t WHERE ...` without the
                // `FROM` keyword, so both variants must be handled.
                let from_tables = match &delete.from {
                    sqlparser::ast::FromTable::WithFromKeyword(tables)
                    | sqlparser::ast::FromTable::WithoutFromKeyword(tables) => tables,
                };
                for from_table in from_tables {
                    extract_from_table_with_joins(from_table, &mut table_names);
                }
                for using_table in delete.using.iter().flatten() {
                    extract_from_table_with_joins(using_table, &mut table_names);
                }
                if let Some(predicate) = &delete.selection {
                    extract_from_expr(predicate, &mut table_names);
                }
            }
            // Other statements are not inspected.
            _ => {}
        }
    }

    Ok((table_names, cte_names))
}
/// Walks a full query: first every CTE of its `WITH` clause, then its body.
///
/// Each CTE alias is recorded lower-cased in `cte_names`, and the CTE's own
/// query is walked recursively with the same two accumulators.
fn extract_from_query(
    query: &Query,
    table_names: &mut HashSet<String>,
    cte_names: &mut HashSet<String>,
) {
    let ctes = query.with.iter().flat_map(|with| with.cte_tables.iter());
    for cte in ctes {
        let alias = cte.alias.name.value.to_lowercase();
        cte_names.insert(alias);
        extract_from_query(&cte.query, table_names, cte_names);
    }

    extract_from_set_expr(&query.body, table_names, cte_names);
}
/// Walks the body of a query (`SELECT`, nested query, set operation, ...).
///
/// * `SELECT` and parenthesised queries are delegated to their own walkers.
/// * Both sides of a set operation (`UNION`, `EXCEPT`, ...) are walked.
/// * `TABLE <name>` records the lower-cased name when one is present.
/// * `VALUES`, and `INSERT`/`UPDATE` used as a query body, contribute nothing.
fn extract_from_set_expr(
    set_expr: &SetExpr,
    table_names: &mut HashSet<String>,
    cte_names: &mut HashSet<String>,
) {
    match set_expr {
        SetExpr::SetOperation { left, right, .. } => {
            for side in [left, right] {
                extract_from_set_expr(side, table_names, cte_names);
            }
        }
        SetExpr::Select(select) => extract_from_select(select, table_names, cte_names),
        SetExpr::Query(query) => extract_from_query(query, table_names, cte_names),
        SetExpr::Table(table) => {
            // `Extend` on an `Option` inserts zero or one element.
            table_names.extend(table.table_name.as_deref().map(str::to_lowercase));
        }
        // Listed explicitly (no wildcard) so that new variants fail to compile.
        SetExpr::Values(_) | SetExpr::Insert(_) | SetExpr::Update(_) => {}
    }
}
/// Walks one `SELECT`: its `FROM` sources plus every expression position
/// that is inspected for subqueries.
///
/// Inspected expression positions are the projection (plain and aliased
/// items only), `WHERE`, explicit `GROUP BY` expressions and `HAVING`.
/// `_cte_names` is accepted for signature symmetry but is not used here.
fn extract_from_select(
    select: &Select,
    table_names: &mut HashSet<String>,
    _cte_names: &mut HashSet<String>,
) {
    for source in &select.from {
        extract_from_table_with_joins(source, table_names);
    }

    // Wildcard projection items carry no expression and are skipped.
    for item in &select.projection {
        if let SelectItem::UnnamedExpr(expr) | SelectItem::ExprWithAlias { expr, .. } = item {
            extract_from_expr(expr, table_names);
        }
    }

    // `GROUP BY ALL` has no expressions to walk.
    if let sqlparser::ast::GroupByExpr::Expressions(keys, _) = &select.group_by {
        for key in keys {
            extract_from_expr(key, table_names);
        }
    }

    // WHERE and HAVING predicates, whichever are present.
    for predicate in select.selection.iter().chain(select.having.iter()) {
        extract_from_expr(predicate, table_names);
    }
}
/// Walks a `FROM` item: its leading relation and every joined relation.
///
/// For joins that carry a constraint (inner, outer, semi and anti joins), an
/// `ON <expr>` constraint is additionally scanned for subqueries. `USING`,
/// `NATURAL` and unconstrained joins have no expression to inspect. Join
/// kinds not listed below only contribute their relation.
fn extract_from_table_with_joins(
    table_with_joins: &TableWithJoins,
    table_names: &mut HashSet<String>,
) {
    extract_from_table_factor(&table_with_joins.relation, table_names);

    for join in &table_with_joins.joins {
        extract_from_table_factor(&join.relation, table_names);

        match &join.join_operator {
            sqlparser::ast::JoinOperator::Inner(constraint)
            | sqlparser::ast::JoinOperator::LeftOuter(constraint)
            | sqlparser::ast::JoinOperator::RightOuter(constraint)
            | sqlparser::ast::JoinOperator::FullOuter(constraint)
            | sqlparser::ast::JoinOperator::LeftSemi(constraint)
            | sqlparser::ast::JoinOperator::RightSemi(constraint)
            | sqlparser::ast::JoinOperator::LeftAnti(constraint)
            | sqlparser::ast::JoinOperator::RightAnti(constraint) => {
                if let sqlparser::ast::JoinConstraint::On(expr) = constraint {
                    extract_from_expr(expr, table_names);
                }
            }
            _ => {}
        }
    }
}
/// Records the relation(s) named by a single table factor.
///
/// * Plain table: the last identifier of its (possibly qualified) name is
///   stored lower-cased.
/// * Derived table (subquery in `FROM`): the subquery is walked. CTE names
///   declared inside it go into a throw-away set and are not reported to
///   the caller.
/// * Nested join: walked recursively.
/// * Every other factor kind contributes nothing.
fn extract_from_table_factor(table_factor: &TableFactor, table_names: &mut HashSet<String>) {
    match table_factor {
        TableFactor::Table { name, .. } => {
            let unqualified = name.0.last().map(|ident| ident.value.to_lowercase());
            // `Extend` on an `Option` inserts zero or one element.
            table_names.extend(unqualified);
        }
        TableFactor::Derived { subquery, .. } => {
            let mut scoped_cte_names = HashSet::new();
            extract_from_query(subquery, table_names, &mut scoped_cte_names);
        }
        TableFactor::NestedJoin {
            table_with_joins, ..
        } => extract_from_table_with_joins(table_with_joins, table_names),
        // Listed explicitly (no wildcard) so that new variants fail to compile.
        TableFactor::TableFunction { .. }
        | TableFactor::UNNEST { .. }
        | TableFactor::Pivot { .. }
        | TableFactor::Unpivot { .. }
        | TableFactor::MatchRecognize { .. }
        | TableFactor::JsonTable { .. }
        | TableFactor::Function { .. }
        | TableFactor::OpenJsonTable { .. } => {}
    }
}
/// Recursively scans an expression for subqueries and records the tables
/// those subqueries read from.
///
/// Subqueries are walked with a throw-away CTE set, so CTE names declared
/// inside them are not reported to the caller.
///
/// Operands are followed through: binary/unary operators, `ANY`/`ALL`
/// comparisons, casts, parentheses, `IS [NOT] NULL`, `[I]LIKE`, function
/// arguments (including a subquery passed as the argument list), `CASE`,
/// `IN (list)`, `IN (subquery)` and `BETWEEN`. Any other expression kind is
/// treated as a leaf and contributes nothing.
fn extract_from_expr(expr: &Expr, table_names: &mut HashSet<String>) {
    match expr {
        Expr::Subquery(query) | Expr::Exists { subquery: query, .. } => {
            let mut scoped_cte_names = HashSet::new();
            extract_from_query(query, table_names, &mut scoped_cte_names);
        }
        Expr::InSubquery { expr, subquery, .. } => {
            // The tested expression may itself contain a subquery.
            extract_from_expr(expr, table_names);
            let mut scoped_cte_names = HashSet::new();
            extract_from_query(subquery, table_names, &mut scoped_cte_names);
        }
        Expr::BinaryOp { left, right, .. }
        | Expr::AnyOp { left, right, .. }
        | Expr::AllOp { left, right, .. } => {
            extract_from_expr(left, table_names);
            extract_from_expr(right, table_names);
        }
        Expr::UnaryOp { expr: inner, .. }
        | Expr::Cast { expr: inner, .. }
        | Expr::Nested(inner)
        | Expr::IsNull(inner)
        | Expr::IsNotNull(inner) => {
            extract_from_expr(inner, table_names);
        }
        Expr::Like { expr, pattern, .. } | Expr::ILike { expr, pattern, .. } => {
            extract_from_expr(expr, table_names);
            extract_from_expr(pattern, table_names);
        }
        Expr::Function(func) => match &func.args {
            sqlparser::ast::FunctionArguments::None => {}
            sqlparser::ast::FunctionArguments::Subquery(query) => {
                let mut scoped_cte_names = HashSet::new();
                extract_from_query(query, table_names, &mut scoped_cte_names);
            }
            sqlparser::ast::FunctionArguments::List(arg_list) => {
                for arg in &arg_list.args {
                    match arg {
                        sqlparser::ast::FunctionArg::Named { arg, .. }
                        | sqlparser::ast::FunctionArg::Unnamed(arg)
                        | sqlparser::ast::FunctionArg::ExprNamed { arg, .. } => {
                            extract_from_function_arg_expr(arg, table_names);
                        }
                    }
                }
            }
        },
        Expr::Case {
            operand,
            conditions,
            results,
            else_result,
            ..
        } => {
            if let Some(op) = operand {
                extract_from_expr(op, table_names);
            }
            for branch in conditions.iter().chain(results.iter()) {
                extract_from_expr(branch, table_names);
            }
            if let Some(else_expr) = else_result {
                extract_from_expr(else_expr, table_names);
            }
        }
        Expr::InList { expr, list, .. } => {
            extract_from_expr(expr, table_names);
            for item in list {
                extract_from_expr(item, table_names);
            }
        }
        Expr::Between {
            expr, low, high, ..
        } => {
            extract_from_expr(expr, table_names);
            extract_from_expr(low, table_names);
            extract_from_expr(high, table_names);
        }
        // Identifiers, literals and all remaining expression kinds.
        _ => {}
    }
}
/// Scans a single function argument for subqueries.
///
/// Only an expression argument is inspected; `*` and `qualifier.*`
/// arguments carry no expression and are ignored.
fn extract_from_function_arg_expr(
    arg: &sqlparser::ast::FunctionArgExpr,
    table_names: &mut HashSet<String>,
) {
    if let sqlparser::ast::FunctionArgExpr::Expr(inner) = arg {
        extract_from_expr(inner, table_names);
    }
}