use crate::control::security::catalog::procedure_types::ProcedureRoutability;
/// Decides how a stored procedure can be routed, based on the DML statements in its body.
///
/// The body is parsed with `parse_block`, and every `INSERT` / `UPDATE` / `DELETE` target
/// found anywhere in the statement tree is collected into a set.
///
/// Returns `SingleCollection(name)` only when exactly one distinct target collection is found.
/// It returns `MultiCollection` when:
/// - the body fails to parse,
/// - no DML target is found, or
/// - more than one distinct target is found.
pub fn extract_routability(body_sql: &str) -> ProcedureRoutability {
    let mut targets = std::collections::HashSet::new();
    match crate::control::planner::procedural::parse_block(body_sql) {
        Ok(parsed) => collect_dml_targets(&parsed.statements, &mut targets),
        // A body we cannot parse cannot be analysed; fall back to the conservative answer.
        Err(_) => return ProcedureRoutability::MultiCollection,
    }

    // Peek at (at most) the first two distinct targets: routing to a single collection is
    // only possible when there is one and no second.
    let mut distinct = targets.into_iter();
    match (distinct.next(), distinct.next()) {
        (Some(only), None) => ProcedureRoutability::SingleCollection(only),
        _ => ProcedureRoutability::MultiCollection,
    }
}
/// Recursively walks `stmts` and adds the target collection of every recognised DML
/// statement to `collections`.
///
/// Raw SQL statements are inspected with `extract_dml_target_collection`. The walk descends
/// into these nested bodies:
/// - the `then` block of an `IF`,
/// - each `ELSIF` branch body,
/// - the optional `ELSE` block,
/// - `LOOP`, `WHILE` and `FOR` bodies.
///
/// NOTE(review): every other `Statement` variant is skipped by the catch-all arm. If the enum
/// has other variants that own nested statement lists, their DML is not seen here. Confirm
/// against the AST definition.
fn collect_dml_targets(
    stmts: &[crate::control::planner::procedural::ast::Statement],
    collections: &mut std::collections::HashSet<String>,
) {
    use crate::control::planner::procedural::ast::Statement;

    for statement in stmts.iter() {
        match statement {
            // `Option<String>` iterates over zero or one item, so this inserts only on a match.
            Statement::Sql { sql } => collections.extend(extract_dml_target_collection(sql)),
            Statement::If {
                then_block,
                elsif_branches,
                else_block,
                ..
            } => {
                collect_dml_targets(then_block, collections);
                for elsif in elsif_branches.iter() {
                    collect_dml_targets(&elsif.body, collections);
                }
                if let Some(otherwise) = else_block {
                    collect_dml_targets(otherwise, collections);
                }
            }
            Statement::Loop { body }
            | Statement::While { body, .. }
            | Statement::For { body, .. } => collect_dml_targets(body, collections),
            _ => (),
        }
    }
}
/// Returns the lower-cased name of the collection targeted by a single DML statement.
///
/// Recognised forms are `INSERT INTO <name> ...`, `UPDATE <name> ...` and
/// `DELETE FROM <name> ...`. Keywords are matched token by token and case-insensitively, so
/// any whitespace (including newlines) may separate them.
///
/// The name is cleaned up before it is returned:
/// - a column list glued onto it is removed (`INSERT INTO users(id, name)` gives `users`),
/// - a trailing statement terminator is removed (`DELETE FROM logs;` gives `logs`).
///
/// Returns `None` when:
/// - the statement is not one of the recognised forms,
/// - the statement is truncated before the name,
/// - the cleaned name is empty.
fn extract_dml_target_collection(sql: &str) -> Option<String> {
    let mut tokens = sql.split_whitespace();
    let verb = tokens.next()?;

    let raw_target = if verb.eq_ignore_ascii_case("UPDATE") {
        tokens.next()?
    } else if verb.eq_ignore_ascii_case("INSERT") {
        // Require the `INTO` token instead of a literal "INSERT INTO" prefix, so that
        // extra spaces or a line break between the keywords are still recognised.
        tokens.next().filter(|kw| kw.eq_ignore_ascii_case("INTO"))?;
        tokens.next()?
    } else if verb.eq_ignore_ascii_case("DELETE") {
        tokens.next().filter(|kw| kw.eq_ignore_ascii_case("FROM"))?;
        tokens.next()?
    } else {
        return None;
    };

    // Cut at the first '(' rather than trimming: in `foo(a,b)` the paren is in the middle
    // of the token, so trimming '(' from the ends would leave it untouched.
    let name = raw_target
        .split_once('(')
        .map_or(raw_target, |(head, _)| head)
        .trim_end_matches(';');

    (!name.is_empty()).then(|| name.to_lowercase())
}