use sqlparser::ast::{
BinaryOperator, Expr, JoinConstraint, JoinOperator, SetExpr, Statement, TableFactor,
TableWithJoins,
};
use sqlparser::dialect::PostgreSqlDialect;
use sqlparser::parser::Parser;
use std::collections::{HashMap, HashSet, VecDeque};
/// A join route from one non-root table (`source_table`) to the root table,
/// as derived by [`extract_join_paths`].
///
/// Read the fields in order to walk the route from the source toward the
/// root:
/// 1. Start at `source_table.initial_col`.
/// 2. Pass through each intermediate table in `steps`.
/// 3. Finish at `<root>.root_join_col`.
///
/// NOTE(review): presumably consumed to propagate ("cascade") changes from
/// `source_table` to the affected root rows — confirm with callers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct JoinPath {
    /// The table the route starts from (never the root table itself).
    pub source_table: String,
    /// Column on `source_table` used in its join to the next table on the route.
    pub initial_col: String,
    /// Intermediate tables between `source_table` and the root, ordered from
    /// the source side toward the root; empty when the source joins the root
    /// directly.
    pub steps: Vec<JoinStep>,
    /// Column on the root table used in its join to the last non-root table
    /// on the route.
    pub root_join_col: String,
}
/// One intermediate table on a [`JoinPath`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct JoinStep {
    /// The intermediate table.
    pub table_name: String,
    /// Column on this table that joins to the previous table on the route
    /// (the source side).
    pub lookup_col: String,
    /// Column on this table that joins to the next table on the route
    /// (the root side).
    pub carry_col: String,
}
/// One column equality `left_table.left_col = right_table.right_col` found in
/// the query. Table names are as produced by the extractors: real table names,
/// with aliases resolved where the alias is known.
#[derive(Debug, Clone)]
struct JoinEdge {
    left_table: String,
    left_col: String,
    right_table: String,
    right_col: String,
}

/// Undirected multigraph of tables connected by join equalities.
#[derive(Debug)]
struct JoinGraph {
    /// Names of every table registered via `add_table` (aliases excluded).
    tables: HashSet<String>,
    /// Join equalities, in the order they were discovered.
    edges: Vec<JoinEdge>,
    /// Alias → table name.
    aliases: HashMap<String, String>,
}

impl JoinGraph {
    /// Creates an empty graph.
    fn new() -> Self {
        Self {
            tables: HashSet::new(),
            edges: Vec::new(),
            aliases: HashMap::new(),
        }
    }

    /// Registers `name` as a table and, if given, `alias` as another name for it.
    fn add_table(&mut self, name: &str, alias: Option<&str>) {
        self.tables.insert(name.to_string());
        if let Some(a) = alias {
            self.aliases.insert(a.to_string(), name.to_string());
        }
    }

    /// Maps an alias to its table name; any unknown name is returned unchanged.
    fn resolve(&self, name: &str) -> String {
        self.aliases
            .get(name)
            .cloned()
            .unwrap_or_else(|| name.to_string())
    }

    /// Appends a join equality. Duplicate edges are kept.
    fn add_edge(&mut self, edge: JoinEdge) {
        self.edges.push(edge);
    }

    /// Returns every edge linking `a` and `b`, in either direction, in
    /// discovery order.
    fn edges_between(&self, a: &str, b: &str) -> Vec<&JoinEdge> {
        self.edges
            .iter()
            .filter(|e| {
                (e.left_table == a && e.right_table == b)
                    || (e.left_table == b && e.right_table == a)
            })
            .collect()
    }

    /// Returns the distinct tables directly joined to `table`.
    ///
    /// Order is deterministic: first-seen edge order. This matters because
    /// breadth-first search over this graph breaks ties between equally short
    /// routes by neighbor order. With a `HashSet` (randomized iteration per
    /// process), the chosen route could differ between runs.
    fn neighbors(&self, table: &str) -> Vec<String> {
        let mut result: Vec<String> = Vec::new();
        for edge in &self.edges {
            let other = if edge.left_table == table {
                &edge.right_table
            } else if edge.right_table == table {
                &edge.left_table
            } else {
                continue;
            };
            // Graphs are tiny (one per query), so a linear dedupe is fine.
            if !result.contains(other) {
                result.push(other.clone());
            }
        }
        result
    }
}
/// Parses a single `SELECT` statement and derives, for every other table in
/// its FROM clause that is joined (directly or transitively) to `root_table`,
/// the chain of join columns leading from that table to the root.
///
/// Join edges come from column equalities in three places:
/// - `JOIN ... ON` predicates.
/// - `JOIN ... USING` column lists.
/// - The top-level `WHERE` clause (implicit comma joins).
///
/// Only the first statement in `select_sql` is examined; any further
/// statements are ignored. `root_table` is compared against unqualified table
/// names (the last segment of each FROM name).
///
/// Paths are returned sorted by `source_table`, so output is deterministic
/// across runs. Tables with no join route to the root are omitted.
///
/// # Errors
/// Returns `Err` when any of the following holds:
/// - The SQL fails to parse or contains no statement.
/// - The statement is not a query.
/// - The query body is not a plain `SELECT` (e.g. `UNION`).
/// - A FROM item or joined relation is not a plain table (subquery, derived
///   table, ...).
///
/// Returns `Ok(vec![])` when there is no FROM clause or `root_table` does not
/// appear in it.
pub fn extract_join_paths(select_sql: &str, root_table: &str) -> Result<Vec<JoinPath>, String> {
    let dialect = PostgreSqlDialect {};
    let stmts = Parser::new(&dialect)
        .try_with_sql(select_sql)
        .map_err(|e| format!("SQL init error: {e}"))?
        .parse_statements()
        .map_err(|e| format!("SQL parse error: {e}"))?;
    let stmt = stmts
        .into_iter()
        .next()
        .ok_or_else(|| "Empty SQL statement".to_string())?;
    let query = match stmt {
        Statement::Query(q) => q,
        _ => return Err("Only SELECT queries are supported".to_string()),
    };
    let select = match *query.body {
        SetExpr::Select(s) => s,
        _ => return Err("UNION/set operations not yet supported for cascade paths".to_string()),
    };
    if select.from.is_empty() {
        return Ok(vec![]);
    }

    let mut graph = JoinGraph::new();
    for table_with_joins in &select.from {
        build_graph_from_table_with_joins(table_with_joins, &mut graph)?;
    }
    if let Some(where_expr) = &select.selection {
        extract_implicit_joins(where_expr, &mut graph);
    }
    if !graph.tables.contains(root_table) {
        return Ok(vec![]);
    }

    // `graph.tables` is a HashSet, whose iteration order is randomized per
    // process; sort so callers get a stable, reproducible result order.
    let mut leaves: Vec<&String> = graph
        .tables
        .iter()
        .filter(|t| t.as_str() != root_table)
        .collect();
    leaves.sort();

    Ok(leaves
        .into_iter()
        .filter_map(|leaf| build_path(&graph, leaf, root_table))
        .collect())
}
/// Registers every table in one FROM item (`base JOIN t1 ... JOIN t2 ...`)
/// with `graph`, and records the join edges implied by its explicit join
/// constraints.
///
/// - `ON <predicate>`: column equalities are extracted from the predicate.
/// - `USING (c1, ...)`: one edge per column, linking the FROM item's first
///   (base) table to the joined table.
/// - `NATURAL`, constraint-less joins, and operators other than
///   INNER/LEFT/RIGHT/FULL (e.g. `CROSS JOIN`): the joined table is
///   registered, but no edges are added.
///
/// # Errors
/// Propagates the `extract_table_info` error for any relation that is not a
/// plain table.
fn build_graph_from_table_with_joins(
    twj: &TableWithJoins,
    graph: &mut JoinGraph,
) -> Result<(), String> {
    let (base_table, base_alias) = extract_table_info(&twj.relation)?;
    graph.add_table(&base_table, base_alias.as_deref());

    for join in twj.joins.iter() {
        let (joined_table, joined_alias) = extract_table_info(&join.relation)?;
        graph.add_table(&joined_table, joined_alias.as_deref());

        let (JoinOperator::Inner(constraint)
        | JoinOperator::LeftOuter(constraint)
        | JoinOperator::RightOuter(constraint)
        | JoinOperator::FullOuter(constraint)) = &join.join_operator
        else {
            continue;
        };

        match constraint {
            JoinConstraint::On(predicate) => extract_equalities(predicate, graph),
            JoinConstraint::Using(columns) => {
                // NOTE(review): every USING column is attributed to the base
                // table. In `a JOIN b USING (x) JOIN c USING (y)`, `y` may
                // really belong to `b`; resolving that needs schema info.
                for column in columns {
                    graph.add_edge(JoinEdge {
                        left_table: base_table.clone(),
                        left_col: column.value.clone(),
                        right_table: joined_table.clone(),
                        right_col: column.value.clone(),
                    });
                }
            }
            JoinConstraint::Natural | JoinConstraint::None => {}
        }
    }
    Ok(())
}
/// Returns `(table_name, alias)` for a plain table reference.
///
/// Only the last segment of a qualified name is kept, so `public.tb_order`
/// becomes `tb_order`.
///
/// # Errors
/// Fails for any non-table FROM item (subquery, derived table, ...), or if
/// the name has no segments.
fn extract_table_info(factor: &TableFactor) -> Result<(String, Option<String>), String> {
    let TableFactor::Table { name, alias, .. } = factor else {
        return Err(
            "Subqueries and derived tables not supported in cascade path analysis".to_string(),
        );
    };
    let table_name = match name.0.last() {
        Some(ident) => ident.value.clone(),
        None => return Err("Invalid table name".to_string()),
    };
    let alias_name = alias.as_ref().map(|a| a.name.value.clone());
    Ok((table_name, alias_name))
}
/// Walks a boolean predicate and records every `a.x = b.y` equality between
/// two different tables as a join edge.
///
/// Only `AND` conjunctions and parenthesised sub-expressions are descended
/// into. Equalities under `OR`, `NOT`, function calls, etc. are ignored, as
/// are equalities where either side is not a qualified column reference.
fn extract_equalities(expr: &Expr, graph: &mut JoinGraph) {
    match expr {
        Expr::Nested(inner) => extract_equalities(inner, graph),
        Expr::BinaryOp {
            left,
            op: BinaryOperator::And,
            right,
        } => {
            extract_equalities(left, graph);
            extract_equalities(right, graph);
        }
        Expr::BinaryOp {
            left,
            op: BinaryOperator::Eq,
            right,
        } => {
            let Some((lhs_table, lhs_col)) = extract_col_ref(left, graph) else {
                return;
            };
            let Some((rhs_table, rhs_col)) = extract_col_ref(right, graph) else {
                return;
            };
            // Equalities within one table are skipped. This includes
            // self-joins, whose aliases resolve to the same table name.
            if lhs_table == rhs_table {
                return;
            }
            graph.add_edge(JoinEdge {
                left_table: lhs_table,
                left_col: lhs_col,
                right_table: rhs_table,
                right_col: rhs_col,
            });
        }
        _ => {}
    }
}
/// Interprets `expr` as a qualified column reference and returns
/// `(table, column)`.
///
/// - `alias.col` / `table.col`: the qualifier is resolved through the
///   graph's alias map.
/// - `schema.table.col` (or longer): the second-to-last segment is the table.
///   This matches how FROM names are reduced to their last segment, so
///   fully-qualified join predicates still produce edges. Postgres reads
///   `a.b.c` as schema.table.column, and only unaliased tables may be
///   referenced this way, so no alias resolution is applied.
///
/// Unqualified identifiers and any other expression yield `None`.
fn extract_col_ref(expr: &Expr, graph: &JoinGraph) -> Option<(String, String)> {
    let Expr::CompoundIdentifier(parts) = expr else {
        return None;
    };
    match parts.as_slice() {
        [qualifier, column] => Some((graph.resolve(&qualifier.value), column.value.clone())),
        [.., table, column] => Some((table.value.clone(), column.value.clone())),
        _ => None,
    }
}
/// Records join edges implied by a `WHERE` clause (comma-style implicit
/// joins such as `FROM a, b WHERE b.fk = a.pk`).
///
/// The predicate is handled exactly like an explicit `ON` clause.
fn extract_implicit_joins(expr: &Expr, graph: &mut JoinGraph) {
    extract_equalities(expr, graph)
}
fn build_path(graph: &JoinGraph, leaf: &str, root: &str) -> Option<JoinPath> {
let mut queue = VecDeque::new();
let mut visited = HashSet::new();
let mut parent_map: HashMap<String, String> = HashMap::new();
queue.push_back(leaf.to_string());
visited.insert(leaf.to_string());
while let Some(current) = queue.pop_front() {
if current == root {
break;
}
for neighbor in graph.neighbors(¤t) {
if !visited.contains(&neighbor) {
visited.insert(neighbor.clone());
parent_map.insert(neighbor.clone(), current.clone());
queue.push_back(neighbor);
}
}
}
if !visited.contains(root) {
return None; }
let mut chain = vec![root.to_string()];
let mut current = root.to_string();
loop {
if current == leaf {
break;
}
let prev = parent_map.get(¤t)?;
chain.push(prev.clone());
current = prev.clone();
}
if chain.len() < 2 {
return None;
}
chain.reverse();
let first_edge = graph.edges_between(&chain[0], &chain[1]);
let first_edge = first_edge.first()?;
let initial_col = if first_edge.left_table == chain[0] {
first_edge.left_col.clone()
} else {
first_edge.right_col.clone()
};
let mut steps = Vec::new();
for i in 1..chain.len() - 1 {
let table = &chain[i];
let incoming_edge = graph.edges_between(&chain[i - 1], table);
let incoming = incoming_edge.first()?;
let lookup_col = if incoming.left_table == *table {
incoming.left_col.clone()
} else {
incoming.right_col.clone()
};
let outgoing_edge = graph.edges_between(table, &chain[i + 1]);
let outgoing = outgoing_edge.first()?;
let carry_col = if outgoing.left_table == *table {
outgoing.left_col.clone()
} else {
outgoing.right_col.clone()
};
steps.push(JoinStep {
table_name: table.clone(),
lookup_col,
carry_col,
});
}
let last_idx = chain.len() - 1;
let root_edge = graph.edges_between(&chain[last_idx - 1], &chain[last_idx]);
let root_edge = root_edge.first()?;
let root_join_col = if root_edge.left_table == chain[last_idx] {
root_edge.left_col.clone()
} else {
root_edge.right_col.clone()
};
Some(JoinPath {
source_table: chain[0].clone(),
initial_col,
steps,
root_join_col,
})
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_single_join() {
        let sql = "SELECT o.pk_order, g.name \
                   FROM tb_order o \
                   JOIN tb_group g ON g.fk_order = o.pk_order";
        let paths = extract_join_paths(sql, "tb_order").unwrap();
        assert_eq!(paths.len(), 1);
        let path = &paths[0];
        assert_eq!(path.source_table, "tb_group");
        assert_eq!(path.initial_col, "fk_order");
        assert!(
            path.steps.is_empty(),
            "single hop should have no intermediate steps"
        );
        assert_eq!(
            path.root_join_col, "pk_order",
            "root_join_col should be the root's PK (child→parent FK)"
        );
    }

    #[test]
    fn test_two_hop_chain() {
        let sql = "SELECT o.pk_order, g.name, i.value \
                   FROM tb_order o \
                   JOIN tb_group g ON g.fk_order = o.pk_order \
                   JOIN tb_item i ON i.fk_group = g.pk_group";
        let paths = extract_join_paths(sql, "tb_order").unwrap();
        assert_eq!(paths.len(), 2);
        let group_path = paths.iter().find(|p| p.source_table == "tb_group").unwrap();
        assert_eq!(group_path.initial_col, "fk_order");
        assert!(group_path.steps.is_empty());
        assert_eq!(group_path.root_join_col, "pk_order");
        let item_path = paths.iter().find(|p| p.source_table == "tb_item").unwrap();
        assert_eq!(item_path.initial_col, "fk_group");
        assert_eq!(item_path.steps.len(), 1);
        let hop = &item_path.steps[0];
        assert_eq!(hop.table_name, "tb_group");
        assert_eq!(hop.lookup_col, "pk_group");
        assert_eq!(hop.carry_col, "fk_order");
        // The final hop (tb_group → tb_order) lands on the root's PK.
        assert_eq!(item_path.root_join_col, "pk_order");
    }

    #[test]
    fn test_three_hop_chain() {
        let sql = "SELECT o.pk_order, g.name, i.value, d.detail \
                   FROM tb_order o \
                   JOIN tb_group g ON g.fk_order = o.pk_order \
                   JOIN tb_item i ON i.fk_group = g.pk_group \
                   JOIN tb_detail d ON d.fk_item = i.pk_item";
        let paths = extract_join_paths(sql, "tb_order").unwrap();
        let detail_path = paths
            .iter()
            .find(|p| p.source_table == "tb_detail")
            .unwrap();
        assert_eq!(detail_path.initial_col, "fk_item");
        assert_eq!(detail_path.steps.len(), 2);
        assert_eq!(detail_path.steps[0].table_name, "tb_item");
        assert_eq!(detail_path.steps[0].lookup_col, "pk_item");
        assert_eq!(detail_path.steps[0].carry_col, "fk_group");
        assert_eq!(detail_path.steps[1].table_name, "tb_group");
        assert_eq!(detail_path.steps[1].lookup_col, "pk_group");
        assert_eq!(detail_path.steps[1].carry_col, "fk_order");
        assert_eq!(detail_path.root_join_col, "pk_order");
    }

    #[test]
    fn test_aliases_resolved() {
        let sql = "SELECT 1 FROM tb_order o JOIN tb_group g ON g.fk_order = o.pk_order";
        let paths = extract_join_paths(sql, "tb_order").unwrap();
        assert_eq!(paths.len(), 1);
        assert_eq!(paths[0].source_table, "tb_group");
        assert_eq!(paths[0].initial_col, "fk_order");
    }

    #[test]
    fn test_left_join() {
        let sql = "SELECT 1 FROM tb_order o LEFT JOIN tb_group g ON g.fk_order = o.pk_order";
        let paths = extract_join_paths(sql, "tb_order").unwrap();
        assert_eq!(paths.len(), 1);
        assert_eq!(paths[0].source_table, "tb_group");
    }

    #[test]
    fn test_implicit_join() {
        let sql = "SELECT 1 FROM tb_order o, tb_group g WHERE g.fk_order = o.pk_order";
        let paths = extract_join_paths(sql, "tb_order").unwrap();
        assert_eq!(paths.len(), 1);
        assert_eq!(paths[0].source_table, "tb_group");
        assert_eq!(paths[0].initial_col, "fk_order");
    }

    #[test]
    fn test_using_join() {
        // USING links the FROM item's base table to the joined table on the
        // same column name.
        let sql = "SELECT 1 FROM tb_order JOIN tb_group USING (pk_order)";
        let paths = extract_join_paths(sql, "tb_order").unwrap();
        assert_eq!(paths.len(), 1);
        assert_eq!(paths[0].source_table, "tb_group");
        assert_eq!(paths[0].initial_col, "pk_order");
        assert!(paths[0].steps.is_empty());
        assert_eq!(paths[0].root_join_col, "pk_order");
    }

    #[test]
    fn test_on_with_and_and_parentheses() {
        // The non-column equality (`g.active = true`) must be ignored while
        // the join equality inside the parenthesised AND is still found.
        let sql = "SELECT 1 FROM tb_order o \
                   JOIN tb_group g ON (g.fk_order = o.pk_order AND g.active = true)";
        let paths = extract_join_paths(sql, "tb_order").unwrap();
        assert_eq!(paths.len(), 1);
        assert_eq!(paths[0].source_table, "tb_group");
        assert_eq!(paths[0].initial_col, "fk_order");
        assert_eq!(paths[0].root_join_col, "pk_order");
    }

    #[test]
    fn test_cross_join_yields_no_path() {
        // CROSS JOIN registers the table but contributes no join edge.
        let sql = "SELECT 1 FROM tb_order o CROSS JOIN tb_group g";
        let paths = extract_join_paths(sql, "tb_order").unwrap();
        assert!(paths.is_empty());
    }

    #[test]
    fn test_subquery_returns_error() {
        let sql = "SELECT 1 FROM (SELECT * FROM tb_a) sub";
        let result = extract_join_paths(sql, "tb_a");
        assert!(result.is_err());
    }

    #[test]
    fn test_non_select_statement_returns_error() {
        let sql = "DELETE FROM tb_order";
        assert!(extract_join_paths(sql, "tb_order").is_err());
    }

    #[test]
    fn test_union_returns_error() {
        let sql = "SELECT 1 FROM tb_a UNION SELECT 1 FROM tb_b";
        assert!(extract_join_paths(sql, "tb_a").is_err());
    }

    #[test]
    fn test_root_not_found_returns_empty() {
        let sql = "SELECT 1 FROM tb_order o JOIN tb_group g ON g.fk_order = o.pk_order";
        let paths = extract_join_paths(sql, "tb_nonexistent").unwrap();
        assert!(paths.is_empty());
    }

    #[test]
    fn test_lookup_table_join_root_holds_fk() {
        let sql = "SELECT o.pk_order, o.fk_currency, c.iso_code \
                   FROM tb_order o \
                   LEFT JOIN tb_currency c ON o.fk_currency = c.pk_currency";
        let paths = extract_join_paths(sql, "tb_order").unwrap();
        assert_eq!(paths.len(), 1);
        let path = &paths[0];
        assert_eq!(path.source_table, "tb_currency");
        assert_eq!(path.initial_col, "pk_currency");
        assert!(
            path.steps.is_empty(),
            "direct lookup join has no intermediate steps"
        );
        assert_eq!(
            path.root_join_col, "fk_currency",
            "root_join_col must be the FK on the root table, not its PK"
        );
    }
}