use serde::{Deserialize, Serialize};
use crate::expr::{Expr, lower_expression};
use crate::stmt::Statement;
/// One function or procedure invocation found by [`extract_call_sites`].
///
/// Serialisable with serde; the `context` field is written in `snake_case`
/// (e.g. `"assignment"`).
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct CallSite {
/// Name segments of the callee, copied from the lowered callee's `parts`
/// (the tests observe these upper-cased, e.g. `["BILLING_PKG", "POST_INVOICE"]`).
pub callee_parts: Vec<String>,
/// Callee name as displayed, copied from the lowered callee's `display`
/// (the tests observe the source spelling, e.g. `"nvl"`).
pub callee_display: String,
/// Number of arguments in the lowered call expression.
pub arg_count: usize,
/// Syntactic position of the outermost expression this call was found in.
pub context: CallContext,
}
/// Where in a statement a [`CallSite`] was found.
///
/// Calls nested inside arguments or operands inherit the context of the
/// expression they appear in. Serialised in `snake_case`
/// (e.g. `ReturnValue` -> `"return_value"`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CallContext {
/// A call from an unrecognized statement, or from nested-block text that
/// has no `BEGIN`/`DECLARE` wrapper (e.g. a bare procedure call).
Statement,
/// A call inside the right-hand side of an assignment.
Assignment,
/// A call inside an IF arm condition, a WHILE condition, or a FOR range.
ControlFlow,
/// A call inside the value expression of a RETURN statement.
ReturnValue,
}
/// Collects every call site in `stmts`, discarding the recursion report.
///
/// This is a convenience wrapper over [`extract_call_sites_bounded`]. Use
/// that function instead when the caller needs to know whether nested bodies
/// were skipped because the re-lowering depth cap was reached.
#[must_use]
pub fn extract_call_sites(stmts: &[Statement]) -> Vec<CallSite> {
    let (sites, _report) = extract_call_sites_bounded(stmts);
    sites
}
/// Collects every call site in `stmts` and reports whether recursion was cut short.
///
/// Compound statements have their bodies re-lowered and walked recursively,
/// starting from depth 0. Bodies that would reach `crate::MAX_RELOWER_DEPTH`
/// are skipped and recorded via `RecursionOutcome::note_truncated`. Callers
/// can inspect the returned outcome (e.g. `limit_hit`, `truncated_bodies`)
/// to detect an incomplete result.
#[must_use]
pub fn extract_call_sites_bounded(stmts: &[Statement]) -> (Vec<CallSite>, crate::RecursionOutcome) {
    let mut report = crate::RecursionOutcome::default();
    let mut sites = Vec::new();
    walk_call_sites(stmts, 0, &mut sites, &mut report);
    (sites, report)
}
fn walk_call_sites(
stmts: &[Statement],
depth: usize,
out: &mut Vec<CallSite>,
outcome: &mut crate::RecursionOutcome,
) {
macro_rules! recurse_body {
($text:expr) => {{
if depth + 1 >= crate::MAX_RELOWER_DEPTH {
outcome.note_truncated();
} else {
let lowered = crate::lower_statement_body($text);
walk_call_sites(&lowered, depth + 1, out, outcome);
}
}};
}
for stmt in stmts {
match stmt {
Statement::Assignment { rhs_text, .. } => {
collect_calls(&lower_expression(rhs_text), CallContext::Assignment, out);
}
Statement::Return {
value_text: Some(v),
} => {
collect_calls(&lower_expression(v), CallContext::ReturnValue, out);
}
Statement::If {
arms,
else_body_text,
} => {
for arm in arms {
collect_calls(
&lower_expression(&arm.cond_text),
CallContext::ControlFlow,
out,
);
recurse_body!(&arm.body_text);
}
if let Some(eb) = else_body_text {
recurse_body!(eb);
}
}
Statement::WhileLoop {
cond_text,
body_text,
} => {
collect_calls(&lower_expression(cond_text), CallContext::ControlFlow, out);
recurse_body!(body_text);
}
Statement::ForLoop {
range_text,
body_text,
..
} => {
collect_calls(&lower_expression(range_text), CallContext::ControlFlow, out);
recurse_body!(body_text);
}
Statement::BareLoop { body_text } => {
recurse_body!(body_text);
}
Statement::NestedBlock { body_text } => {
let inner = strip_block_wrapper(body_text);
if inner != body_text.as_str() {
recurse_body!(inner);
} else {
collect_calls(&lower_expression(body_text), CallContext::Statement, out);
}
}
Statement::Unrecognized { raw_text, .. } => {
let e = lower_expression(raw_text);
collect_calls(&e, CallContext::Statement, out);
}
_ => {}
}
}
}
/// Strips a `DECLARE`/`BEGIN` ... `END` wrapper from a nested block's text.
///
/// After trimming, the text must open with `DECLARE` or `BEGIN` (ASCII
/// case-insensitive) as a whole word, so `beginner := 1;` is not a block.
/// The opening keyword is removed. Everything from the last standalone `END`
/// keyword onward is also removed, including any trailing label and `;`.
/// `END` appearing inside an identifier (`append`, `send_block`) is ignored.
/// If no standalone `END` exists, the text after the opening keyword is
/// returned as-is.
///
/// For `DECLARE ... BEGIN ... END`, only `DECLARE` is removed from the front,
/// so the result still contains the inner `BEGIN`.
///
/// Returns `text` itself (untrimmed) when it does not open with either
/// keyword. `walk_call_sites` compares the result with its input to decide
/// whether to recurse.
pub(crate) fn strip_block_wrapper(text: &str) -> &str {
    /// True for bytes that may continue a PL/SQL identifier. Non-ASCII bytes
    /// count as identifier bytes, so a keyword glued to a multi-byte character
    /// is never treated as a standalone word.
    fn is_ident_byte(b: u8) -> bool {
        b.is_ascii_alphanumeric() || matches!(b, b'_' | b'$' | b'#') || !b.is_ascii()
    }

    /// If `s` starts with keyword `kw` (ASCII case-insensitive, whole word),
    /// returns the text following it.
    fn strip_keyword<'a>(s: &'a str, kw: &str) -> Option<&'a str> {
        let head = s.as_bytes().get(..kw.len())?;
        if !head.eq_ignore_ascii_case(kw.as_bytes()) {
            return None;
        }
        // The matched prefix is ASCII, so `kw.len()` is a char boundary.
        let rest = &s[kw.len()..];
        match rest.as_bytes().first() {
            Some(&b) if is_ident_byte(b) => None,
            _ => Some(rest),
        }
    }

    /// Byte offset of the last whole-word, ASCII case-insensitive occurrence
    /// of keyword `kw` in `s`.
    fn rfind_keyword(s: &str, kw: &str) -> Option<usize> {
        let bytes = s.as_bytes();
        let kw = kw.as_bytes();
        let last_start = bytes.len().checked_sub(kw.len())?;
        (0..=last_start).rev().find(|&i| {
            bytes[i..i + kw.len()].eq_ignore_ascii_case(kw)
                && (i == 0 || !is_ident_byte(bytes[i - 1]))
                && !matches!(bytes.get(i + kw.len()), Some(&b) if is_ident_byte(b))
        })
    }

    let trimmed = text.trim();
    let after_open =
        match strip_keyword(trimmed, "DECLARE").or_else(|| strip_keyword(trimmed, "BEGIN")) {
            Some(rest) => rest,
            None => return text,
        };
    let body = after_open.trim_start();
    match rfind_keyword(body, "END") {
        // `pos` indexes an ASCII 'E'/'e', so slicing there is char-boundary safe.
        Some(pos) => body[..pos].trim_end(),
        None => body,
    }
}
/// Appends a [`CallSite`] for every `Expr::Call` reachable from `expr`, all
/// tagged with `ctx`.
///
/// The search descends through call arguments, both binary operands and
/// unary operands; other expression variants are not searched. Sites are
/// emitted in pre-order: a call before its arguments, arguments left to
/// right, and a left operand before a right one. An explicit work stack is
/// used instead of recursion, so children are pushed in reverse to pop in
/// source order.
fn collect_calls(expr: &Expr, ctx: CallContext, out: &mut Vec<CallSite>) {
    let mut pending: Vec<&Expr> = vec![expr];
    while let Some(node) = pending.pop() {
        match node {
            Expr::Call { callee, args } => {
                out.push(CallSite {
                    callee_parts: callee.parts.clone(),
                    callee_display: callee.display.clone(),
                    arg_count: args.len(),
                    context: ctx,
                });
                for arg in args.iter().rev() {
                    pending.push(arg);
                }
            }
            Expr::Binary { lhs, rhs, .. } => {
                pending.push(rhs);
                pending.push(lhs);
            }
            Expr::Unary { operand, .. } => pending.push(operand),
            _ => {}
        }
    }
}
// Unit tests for call-site extraction. Most cases lower PL/SQL-like source
// text with `lower_statement_body` and assert on the extracted `CallSite`s;
// the `FOR UPDATE` case builds a `Statement` directly to exercise the
// re-lowering depth cap.
#[cfg(test)]
mod tests {
use super::*;
use crate::lower_statement_body;
// A single call on an assignment RHS: callee parts come back upper-cased,
// and arity and context are recorded.
#[test]
fn assignment_rhs_call_extracted() {
let stmts = lower_statement_body("v_total := compute_sum(a, b);");
let calls = extract_call_sites(&stmts);
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].callee_parts, vec!["COMPUTE_SUM"]);
assert_eq!(calls[0].arg_count, 2);
assert_eq!(calls[0].context, CallContext::Assignment);
}
// A call used as an argument to another call is reported too.
#[test]
fn nested_call_yields_both_callees() {
let stmts = lower_statement_body("v := nvl(compute(x), 0);");
let calls = extract_call_sites(&stmts);
let names: Vec<&str> = calls.iter().map(|c| c.callee_display.as_str()).collect();
assert!(names.contains(&"nvl"));
assert!(names.contains(&"compute"));
}
// Calls in a RETURN value are tagged `ReturnValue`.
#[test]
fn return_value_call_context() {
let stmts = lower_statement_body("RETURN compute_total(p_id);");
let calls = extract_call_sites(&stmts);
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].context, CallContext::ReturnValue);
}
// A bare, package-qualified procedure call is a `Statement`-context call
// whose parts keep each qualifier segment.
#[test]
fn statement_level_proc_call_extracted() {
let stmts = lower_statement_body("billing_pkg.post_invoice(p_id, p_amount);");
let calls = extract_call_sites(&stmts);
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].callee_parts, vec!["BILLING_PKG", "POST_INVOICE"]);
assert_eq!(calls[0].context, CallContext::Statement);
assert_eq!(calls[0].arg_count, 2);
}
// Both the IF condition and the re-lowered IF body are searched.
#[test]
fn if_condition_and_body_calls_extracted() {
let src = "IF is_valid(p_id) THEN log_event('ok'); END IF;";
let stmts = lower_statement_body(src);
let calls = extract_call_sites(&stmts);
let names: Vec<&str> = calls.iter().map(|c| c.callee_display.as_str()).collect();
assert!(names.contains(&"is_valid"));
assert!(names.contains(&"log_event"));
}
// FOR loop bodies are re-lowered and walked.
#[test]
fn for_loop_body_calls_recursed() {
let src = "FOR i IN 1..10 LOOP process_row(i); END LOOP;";
let stmts = lower_statement_body(src);
let calls = extract_call_sites(&stmts);
assert!(calls.iter().any(|c| c.callee_display == "process_row"));
}
// Operators alone never produce call sites.
#[test]
fn no_calls_in_pure_arithmetic() {
let stmts = lower_statement_body("v := a + b * 2;");
let calls = extract_call_sites(&stmts);
assert!(calls.is_empty());
}
// Both operands of a binary expression are searched.
#[test]
fn binary_operands_searched_for_calls() {
let stmts = lower_statement_body("v := f(x) + g(y);");
let calls = extract_call_sites(&stmts);
let names: Vec<&str> = calls.iter().map(|c| c.callee_display.as_str()).collect();
assert!(names.contains(&"f"));
assert!(names.contains(&"g"));
}
// JSON round-trip is lossless and the context serialises in snake_case.
#[test]
fn callsite_serde_round_trip() {
let stmts = lower_statement_body("v := compute(a);");
let calls = extract_call_sites(&stmts);
let json = serde_json::to_string(&calls[0]).unwrap();
let back: CallSite = serde_json::from_str(&json).unwrap();
assert_eq!(back, calls[0]);
assert!(json.contains("\"context\":\"assignment\""));
}
// A BEGIN ... END block is unwrapped and its body walked.
#[test]
fn nested_block_calls_recursed() {
let stmts = lower_statement_body("BEGIN inner_proc(1); END;");
let calls = extract_call_sites(&stmts);
assert!(calls.iter().any(|c| c.callee_display == "inner_proc"));
}
// Stress: a 500k-term OR chain on one RHS must not overflow the stack
// during extraction. Per the assertion message, only a shallow prefix of the
// chain is expected to yield call sites.
#[test]
fn wide_assignment_rhs_chain_does_not_overflow_call_walk() {
let n = 500_000usize;
let mut rhs = String::with_capacity(n * 8);
for i in 0..n {
if i > 0 {
rhs.push_str(" OR ");
}
rhs.push_str("f(x)");
}
let stmt = format!("v := {rhs};");
let stmts = lower_statement_body(&stmt);
let calls = extract_call_sites(&stmts);
assert!(
!calls.is_empty(),
"the shallow prefix of the chain still yields call sites"
);
}
// A BareLoop whose body text ("FOR UPDATE") is expected to re-lower into
// itself. The walk must stop at the depth cap and report the truncation
// rather than recursing until the stack overflows.
#[test]
fn non_shrinking_for_update_does_not_stack_overflow_and_reports_limit() {
let stmts = vec![Statement::BareLoop {
body_text: "FOR UPDATE".to_string(),
}];
let (calls, outcome) = extract_call_sites_bounded(&stmts);
assert!(
outcome.limit_hit,
"the non-shrinking `FOR UPDATE` BareLoop must trip the \
bounded depth cap, outcome={outcome:?}, calls={calls:?}"
);
assert!(outcome.truncated_bodies >= 1);
let _ = extract_call_sites(&stmts);
}
// Redundant parentheses around an argument must not hide the inner call.
#[test]
fn parenthesised_call_operand_keeps_inner_call_edge() {
let stmts = lower_statement_body("v := nvl((compute(x)), 0);");
let calls = extract_call_sites(&stmts);
let names: Vec<&str> = calls.iter().map(|c| c.callee_display.as_str()).collect();
assert!(
names.contains(&"nvl"),
"outer nvl call must be recorded: {names:?}"
);
assert!(
names.contains(&"compute"),
"the parenthesised inner compute call must survive: {names:?}"
);
}
}