1use serde::{Deserialize, Serialize};
34
35use crate::expr::{Expr, lower_expression};
36use crate::stmt::Statement;
37
38#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
40pub struct CallSite {
41 pub callee_parts: Vec<String>,
43 pub callee_display: String,
45 pub arg_count: usize,
49 pub context: CallContext,
52}
53
54#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
55#[serde(rename_all = "snake_case")]
56pub enum CallContext {
57 Statement,
59 Assignment,
61 ControlFlow,
63 ReturnValue,
65}
66
67#[must_use]
76pub fn extract_call_sites(stmts: &[Statement]) -> Vec<CallSite> {
77 extract_call_sites_bounded(stmts).0
78}
79
80#[must_use]
87pub fn extract_call_sites_bounded(stmts: &[Statement]) -> (Vec<CallSite>, crate::RecursionOutcome) {
88 let mut out: Vec<CallSite> = Vec::new();
89 let mut outcome = crate::RecursionOutcome::default();
90 walk_call_sites(stmts, 0, &mut out, &mut outcome);
91 (out, outcome)
92}
93
94fn walk_call_sites(
95 stmts: &[Statement],
96 depth: usize,
97 out: &mut Vec<CallSite>,
98 outcome: &mut crate::RecursionOutcome,
99) {
100 macro_rules! recurse_body {
106 ($text:expr) => {{
107 if depth + 1 >= crate::MAX_RELOWER_DEPTH {
108 outcome.note_truncated();
109 } else {
110 let lowered = crate::lower_statement_body($text);
111 walk_call_sites(&lowered, depth + 1, out, outcome);
112 }
113 }};
114 }
115 for stmt in stmts {
116 match stmt {
117 Statement::Assignment { rhs_text, .. } => {
118 collect_calls(&lower_expression(rhs_text), CallContext::Assignment, out);
119 }
120 Statement::Return {
121 value_text: Some(v),
122 } => {
123 collect_calls(&lower_expression(v), CallContext::ReturnValue, out);
124 }
125 Statement::If {
126 arms,
127 else_body_text,
128 } => {
129 for arm in arms {
130 collect_calls(
131 &lower_expression(&arm.cond_text),
132 CallContext::ControlFlow,
133 out,
134 );
135 recurse_body!(&arm.body_text);
136 }
137 if let Some(eb) = else_body_text {
138 recurse_body!(eb);
139 }
140 }
141 Statement::WhileLoop {
142 cond_text,
143 body_text,
144 } => {
145 collect_calls(&lower_expression(cond_text), CallContext::ControlFlow, out);
146 recurse_body!(body_text);
147 }
148 Statement::ForLoop {
149 range_text,
150 body_text,
151 ..
152 } => {
153 collect_calls(&lower_expression(range_text), CallContext::ControlFlow, out);
154 recurse_body!(body_text);
155 }
156 Statement::BareLoop { body_text } => {
157 recurse_body!(body_text);
158 }
159 Statement::NestedBlock { body_text } => {
160 let inner = strip_block_wrapper(body_text);
165 if inner != body_text.as_str() {
166 recurse_body!(inner);
167 } else {
168 collect_calls(&lower_expression(body_text), CallContext::Statement, out);
172 }
173 }
174 Statement::Unrecognized { raw_text, .. } => {
175 let e = lower_expression(raw_text);
177 collect_calls(&e, CallContext::Statement, out);
178 }
179 _ => {}
180 }
181 }
182}
183
/// Strips a leading `DECLARE`/`BEGIN` keyword and the trailing `END …` from block text.
///
/// Matching is ASCII case-insensitive and only accepts whole words: a keyword
/// that is directly followed (or, for `END`, also preceded) by an identifier
/// character is part of a longer name and is left alone. `begin_work(1);` is
/// therefore not a block, and the `end` inside `send_msg` or `pending` is never
/// taken for the closing keyword.
///
/// Returns:
/// - `text` itself, untrimmed, when it does not open with `DECLARE` or `BEGIN`,
///   so callers can detect "nothing stripped" by comparing with their input;
/// - otherwise the text after the opening keyword, cut just before the last
///   standalone `END` (dropping anything after it, such as `;` or a label) and
///   trimmed. If no standalone `END` exists, everything after the opening
///   keyword is returned.
///
/// This is a plain text scan: string literals and comments are not recognised,
/// so an `END` inside them can still be picked as the closing keyword.
pub(crate) fn strip_block_wrapper(text: &str) -> &str {
    /// True for characters that can continue an identifier and so deny a word boundary.
    fn is_ident_char(c: char) -> bool {
        c.is_alphanumeric() || matches!(c, '_' | '$' | '#')
    }

    /// Returns what follows `keyword` when `s` opens with it as a whole word.
    fn strip_keyword<'a>(s: &'a str, keyword: &str) -> Option<&'a str> {
        // `get` yields `None` when `s` is too short or the cut is not a char boundary.
        let head = s.get(..keyword.len())?;
        if !head.eq_ignore_ascii_case(keyword) {
            return None;
        }
        let rest = &s[keyword.len()..];
        match rest.chars().next() {
            Some(c) if is_ident_char(c) => None,
            _ => Some(rest),
        }
    }

    let trimmed = text.trim();
    let after_open = match strip_keyword(trimmed, "DECLARE")
        .or_else(|| strip_keyword(trimmed, "BEGIN"))
    {
        Some(rest) => rest.trim_start(),
        None => return text,
    };

    // Scan right-to-left for the last standalone END. The byte comparison runs
    // first: once it matches, `pos` and `pos + 3` are known to sit on ASCII
    // bytes, so both slices below are on char boundaries.
    let close = after_open
        .as_bytes()
        .windows(3)
        .enumerate()
        .rev()
        .find(|&(pos, window)| {
            window.eq_ignore_ascii_case(b"END")
                && !after_open[..pos]
                    .chars()
                    .next_back()
                    .map_or(false, is_ident_char)
                && !after_open[pos + 3..]
                    .chars()
                    .next()
                    .map_or(false, is_ident_char)
        })
        .map(|(pos, _)| pos);

    match close {
        Some(pos) => after_open[..pos].trim_end(),
        None => after_open,
    }
}
213
214fn collect_calls(expr: &Expr, ctx: CallContext, out: &mut Vec<CallSite>) {
215 match expr {
216 Expr::Call { callee, args } => {
217 out.push(CallSite {
218 callee_parts: callee.parts.clone(),
219 callee_display: callee.display.clone(),
220 arg_count: args.len(),
221 context: ctx,
222 });
223 for a in args {
224 collect_calls(a, ctx, out);
225 }
226 }
227 Expr::Binary { lhs, rhs, .. } => {
228 collect_calls(lhs, ctx, out);
229 collect_calls(rhs, ctx, out);
230 }
231 Expr::Unary { operand, .. } => collect_calls(operand, ctx, out),
232 _ => {}
233 }
234}
235
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lower_statement_body;

    /// Lowers `src` and extracts its call sites in one step.
    fn sites_in(src: &str) -> Vec<CallSite> {
        let stmts = lower_statement_body(src);
        extract_call_sites(&stmts)
    }

    /// Display names of `sites`, in extraction order.
    fn display_names(sites: &[CallSite]) -> Vec<&str> {
        sites.iter().map(|c| c.callee_display.as_str()).collect()
    }

    #[test]
    fn assignment_rhs_call_extracted() {
        let calls = sites_in("v_total := compute_sum(a, b);");
        assert_eq!(calls.len(), 1);
        let only = &calls[0];
        assert_eq!(only.callee_parts, vec!["COMPUTE_SUM"]);
        assert_eq!(only.arg_count, 2);
        assert_eq!(only.context, CallContext::Assignment);
    }

    #[test]
    fn nested_call_yields_both_callees() {
        let calls = sites_in("v := nvl(compute(x), 0);");
        let names = display_names(&calls);
        assert!(names.contains(&"nvl"));
        assert!(names.contains(&"compute"));
    }

    #[test]
    fn return_value_call_context() {
        let calls = sites_in("RETURN compute_total(p_id);");
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].context, CallContext::ReturnValue);
    }

    #[test]
    fn statement_level_proc_call_extracted() {
        let calls = sites_in("billing_pkg.post_invoice(p_id, p_amount);");
        assert_eq!(calls.len(), 1);
        let only = &calls[0];
        assert_eq!(only.callee_parts, vec!["BILLING_PKG", "POST_INVOICE"]);
        assert_eq!(only.context, CallContext::Statement);
        assert_eq!(only.arg_count, 2);
    }

    #[test]
    fn if_condition_and_body_calls_extracted() {
        let calls = sites_in("IF is_valid(p_id) THEN log_event('ok'); END IF;");
        let names = display_names(&calls);
        assert!(names.contains(&"is_valid"));
        assert!(names.contains(&"log_event"));
    }

    #[test]
    fn for_loop_body_calls_recursed() {
        let calls = sites_in("FOR i IN 1..10 LOOP process_row(i); END LOOP;");
        assert!(calls.iter().any(|c| c.callee_display == "process_row"));
    }

    #[test]
    fn no_calls_in_pure_arithmetic() {
        let calls = sites_in("v := a + b * 2;");
        assert!(calls.is_empty());
    }

    #[test]
    fn binary_operands_searched_for_calls() {
        let calls = sites_in("v := f(x) + g(y);");
        let names = display_names(&calls);
        assert!(names.contains(&"f"));
        assert!(names.contains(&"g"));
    }

    #[test]
    fn callsite_serde_round_trip() {
        let calls = sites_in("v := compute(a);");
        let original = &calls[0];
        let json = serde_json::to_string(original).unwrap();
        let back: CallSite = serde_json::from_str(&json).unwrap();
        assert_eq!(&back, original);
        // The context must use the snake_case spelling on the wire.
        assert!(json.contains("\"context\":\"assignment\""));
    }

    #[test]
    fn nested_block_calls_recursed() {
        let calls = sites_in("BEGIN inner_proc(1); END;");
        assert!(calls.iter().any(|c| c.callee_display == "inner_proc"));
    }

    #[test]
    fn wide_assignment_rhs_chain_does_not_overflow_call_walk() {
        // Half a million `f(x)` operands joined by OR: the walk has to finish
        // without exhausting the stack. NOTE(review): the assertion message
        // suggests lowering keeps only a bounded prefix of the chain; confirm
        // against `lower_expression`.
        let n = 500_000usize;
        let rhs = vec!["f(x)"; n].join(" OR ");
        let calls = sites_in(&format!("v := {rhs};"));
        assert!(
            !calls.is_empty(),
            "the shallow prefix of the chain still yields call sites"
        );
    }

    #[test]
    fn non_shrinking_for_update_does_not_stack_overflow_and_reports_limit() {
        // NOTE(review): per the test name, re-lowering `FOR UPDATE` does not
        // shrink the text, so only the depth cap stops the recursion.
        let stmts = vec![Statement::BareLoop {
            body_text: "FOR UPDATE".to_string(),
        }];
        let (calls, outcome) = extract_call_sites_bounded(&stmts);
        assert!(
            outcome.limit_hit,
            "the non-shrinking `FOR UPDATE` BareLoop must trip the \
             bounded depth cap, outcome={outcome:?}, calls={calls:?}"
        );
        assert!(outcome.truncated_bodies >= 1);
        // The unbounded entry point must also return instead of overflowing.
        let _ = extract_call_sites(&stmts);
    }

    #[test]
    fn parenthesised_call_operand_keeps_inner_call_edge() {
        let calls = sites_in("v := nvl((compute(x)), 0);");
        let names = display_names(&calls);
        assert!(
            names.contains(&"nvl"),
            "outer nvl call must be recorded: {names:?}"
        );
        assert!(
            names.contains(&"compute"),
            "the parenthesised inner compute call must survive: {names:?}"
        );
    }
}