Skip to main content

plsql_ir/
calls.rs

1//! Call-site edge extraction.
2//!
3//! Walks a lowered statement body and pulls out every
4//! procedure / function invocation as a [`CallSite`]. The
5//! dependency-graph layer resolves each `callee` to a concrete
6//! node (via `plsql_symbols::resolve_reference`) and mints a
7//! `Calls` edge; this module's job is purely *extraction* — find
8//! the call sites and their shape.
9//!
10//! Calls appear in three places:
11//!
12//! 1. Statement-level procedure calls — a bare
13//!    `Statement::Unrecognized` line whose text is
14//!    `pkg.proc(args);` (the stmt recogniser leaves these
15//!    unclassified because they're neither assignment nor
16//!    control flow).
17//! 2. Expression-embedded function calls — inside an
18//!    `Assignment.rhs_text`, an `If` arm condition, a loop
19//!    range, a `Return` value, etc.
20//! 3. Nested calls — `nvl(compute(x), 0)` yields both `nvl`
21//!    and `compute`.
22//!
23//! ## /oracle evidence
24//!
25//! * `DATABASE-REFERENCE.md` PL/SQL Language Reference — the
26//!   call grammar (positional / named notation, package-
27//!   qualified vs bare) drives what counts as a callee.
28//! * `LOW-LEVEL-CATALOGS.md` Data Dictionary View Families —
29//!   `ALL_DEPENDENCIES` with `DEPENDENCY_TYPE` is the
30//!   server-side mirror the depgraph cross-checks `Calls`
31//!   edges against.
32
33use serde::{Deserialize, Serialize};
34
35use crate::expr::{Expr, lower_expression};
36use crate::stmt::Statement;
37
/// One extracted call site.
///
/// A value of this type is pushed by `collect_calls` for every
/// `Expr::Call` node it visits. It records only the *shape* of the
/// call (who is called, with how many arguments, and where the call
/// textually sits).
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct CallSite {
    /// Dotted callee path split into its parts, case-folded for the
    /// lookup key. The tests in this file observe upper-cased parts,
    /// e.g. `billing_pkg.post_invoice` yields
    /// `["BILLING_PKG", "POST_INVOICE"]`.
    pub callee_parts: Vec<String>,
    /// Source-form callee path preserved for diagnostics.
    pub callee_display: String,
    /// Number of arguments supplied at the call — the length of the
    /// lowered call's argument list. Named-notation args still count
    /// toward arity here; the depgraph's overload resolver (SYM-009)
    /// handles named-vs-positional matching.
    pub arg_count: usize,
    /// Context the call appeared in — drives the edge's
    /// confidence + the report wording.
    pub context: CallContext,
}
53
/// Syntactic position in which a [`CallSite`] was found.
///
/// Serialised in `snake_case` (for example `Assignment` becomes
/// `"assignment"`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CallContext {
    /// Statement-level procedure call (`pkg.proc(args);`). Also used
    /// when a `NestedBlock` body has no `BEGIN`/`DECLARE` wrapper to
    /// strip and its text is scanned as a single expression.
    Statement,
    /// Function call inside an assignment RHS.
    Assignment,
    /// Function call inside a control-flow condition / range
    /// (`IF` arm condition, `WHILE` condition, `FOR` range).
    ControlFlow,
    /// Function call inside a RETURN expression.
    ReturnValue,
}
66
67/// Extract every call site from a lowered statement body.
68///
69/// Backwards-compatible wrapper around
70/// [`extract_call_sites_bounded`]: the recursion is depth-guarded
71/// so a malformed unit whose re-lowered body fails to shrink can
72/// never stack-overflow. Callers that need to surface the typed
73/// [`plsql_core::UnknownReason::AnalysisRecursionLimit`] degradation
74/// should call [`extract_call_sites_bounded`] directly.
75#[must_use]
76pub fn extract_call_sites(stmts: &[Statement]) -> Vec<CallSite> {
77    extract_call_sites_bounded(stmts).0
78}
79
80/// Depth-bounded variant of [`extract_call_sites`]. Returns the
81/// extracted call sites plus a [`RecursionOutcome`] recording
82/// whether (and how often) a nested body was abandoned at the
83/// recursion-depth cap rather than walked unbounded. The caller is
84/// responsible for emitting an honest typed diagnostic when
85/// `outcome.limit_hit` (R13 — never silently truncate).
86#[must_use]
87pub fn extract_call_sites_bounded(stmts: &[Statement]) -> (Vec<CallSite>, crate::RecursionOutcome) {
88    let mut out: Vec<CallSite> = Vec::new();
89    let mut outcome = crate::RecursionOutcome::default();
90    walk_call_sites(stmts, 0, &mut out, &mut outcome);
91    (out, outcome)
92}
93
94fn walk_call_sites(
95    stmts: &[Statement],
96    depth: usize,
97    out: &mut Vec<CallSite>,
98    outcome: &mut crate::RecursionOutcome,
99) {
100    // Recurse into a re-lowered body only while we have depth
101    // budget left. At the cap we stop descending and record the
102    // truncation so the caller can surface it honestly — we do
103    // NOT silently drop it and we do NOT keep recursing (which
104    // would stack-overflow on a non-shrinking malformed slice).
105    macro_rules! recurse_body {
106        ($text:expr) => {{
107            if depth + 1 >= crate::MAX_RELOWER_DEPTH {
108                outcome.note_truncated();
109            } else {
110                let lowered = crate::lower_statement_body($text);
111                walk_call_sites(&lowered, depth + 1, out, outcome);
112            }
113        }};
114    }
115    for stmt in stmts {
116        match stmt {
117            Statement::Assignment { rhs_text, .. } => {
118                collect_calls(&lower_expression(rhs_text), CallContext::Assignment, out);
119            }
120            Statement::Return {
121                value_text: Some(v),
122            } => {
123                collect_calls(&lower_expression(v), CallContext::ReturnValue, out);
124            }
125            Statement::If {
126                arms,
127                else_body_text,
128            } => {
129                for arm in arms {
130                    collect_calls(
131                        &lower_expression(&arm.cond_text),
132                        CallContext::ControlFlow,
133                        out,
134                    );
135                    recurse_body!(&arm.body_text);
136                }
137                if let Some(eb) = else_body_text {
138                    recurse_body!(eb);
139                }
140            }
141            Statement::WhileLoop {
142                cond_text,
143                body_text,
144            } => {
145                collect_calls(&lower_expression(cond_text), CallContext::ControlFlow, out);
146                recurse_body!(body_text);
147            }
148            Statement::ForLoop {
149                range_text,
150                body_text,
151                ..
152            } => {
153                collect_calls(&lower_expression(range_text), CallContext::ControlFlow, out);
154                recurse_body!(body_text);
155            }
156            Statement::BareLoop { body_text } => {
157                recurse_body!(body_text);
158            }
159            Statement::NestedBlock { body_text } => {
160                // Strip the BEGIN…END / DECLARE…END wrapper before
161                // re-lowering, otherwise the stmt recogniser keeps
162                // classifying the same text as a NestedBlock and
163                // recursion never terminates.
164                let inner = strip_block_wrapper(body_text);
165                if inner != body_text.as_str() {
166                    recurse_body!(inner);
167                } else {
168                    // No wrapper to strip — treat the text as a
169                    // single expression candidate instead of
170                    // recursing.
171                    collect_calls(&lower_expression(body_text), CallContext::Statement, out);
172                }
173            }
174            Statement::Unrecognized { raw_text, .. } => {
175                // Statement-level procedure call: `pkg.proc(args);`.
176                let e = lower_expression(raw_text);
177                collect_calls(&e, CallContext::Statement, out);
178            }
179            _ => {}
180        }
181    }
182}
183
/// Strip a leading `DECLARE`/`BEGIN` and a trailing `END[;]`
/// from a block body so the inner statements can be re-lowered
/// without re-triggering the NestedBlock classification.
///
/// Matching is ASCII case-insensitive and keyword-aware:
///
/// * the opening keyword must be followed by a non-identifier
///   character (or end of text), so `beginning(x);` or
///   `declared_total := 1;` are returned untouched instead of losing
///   their first letters;
/// * the closing `END` is the last *whole-word* occurrence, so an
///   identifier or block label that merely contains the letters
///   (`append`, `send`, `END backend;`) is never mistaken for the
///   terminator.
///
/// If no opening keyword is present the input is returned unchanged
/// (not even trimmed), which lets callers detect "nothing stripped"
/// by comparing against the original text. If an opening keyword is
/// present but no closing `END` keyword is found, only the opener is
/// removed.
///
/// Shared with the sibling re-lowering walks
/// [`crate::flow_intra`] (taint) and [`crate::dml_edges`]
/// (Reads/Writes edges) so all three descend into anonymous
/// `BEGIN … END` / `DECLARE … END` sub-blocks identically — the
/// returned slice is a sub-slice of `text`, so it is always on a
/// UTF-8 char boundary even for multi-byte content.
pub(crate) fn strip_block_wrapper(text: &str) -> &str {
    let trimmed = text.trim();
    let after_open = match strip_keyword_prefix(trimmed, "DECLARE")
        .or_else(|| strip_keyword_prefix(trimmed, "BEGIN"))
    {
        Some(rest) => rest.trim_start(),
        None => return text,
    };
    // Drop a trailing `END;` / `END` / `END label;`.
    match rfind_keyword(after_open, "END") {
        Some(pos) => after_open[..pos].trim_end(),
        None => after_open,
    }
}

/// Whether `c` can be part of an unquoted PL/SQL identifier
/// (letters, digits, `_`, `$`, `#`). Used for keyword word-boundary
/// checks.
fn is_plsql_ident_char(c: char) -> bool {
    c.is_alphanumeric() || matches!(c, '_' | '$' | '#')
}

/// If `text` starts with the ASCII `keyword` (case-insensitively) as
/// a whole word, return the remainder after it.
fn strip_keyword_prefix<'a>(text: &'a str, keyword: &str) -> Option<&'a str> {
    // `get` yields `None` when `text` is too short or the cut would
    // fall inside a multi-byte character — neither can be a match.
    let head = text.get(..keyword.len())?;
    let rest = &text[keyword.len()..];
    if head.eq_ignore_ascii_case(keyword) && !rest.starts_with(is_plsql_ident_char) {
        Some(rest)
    } else {
        None
    }
}

/// Byte offset of the last whole-word, ASCII case-insensitive
/// occurrence of the ASCII `keyword` in `text`, if any.
fn rfind_keyword(text: &str, keyword: &str) -> Option<usize> {
    let hay = text.as_bytes();
    let needle = keyword.as_bytes();
    if needle.is_empty() || needle.len() > hay.len() {
        return None;
    }
    (0..=hay.len() - needle.len()).rev().find(|&start| {
        let end = start + needle.len();
        // The byte comparison runs first: once it succeeds, `start`
        // and `end` sit next to ASCII bytes and are therefore valid
        // char boundaries for the slices below.
        hay[start..end].eq_ignore_ascii_case(needle)
            && !text[..start].ends_with(is_plsql_ident_char)
            && !text[end..].starts_with(is_plsql_ident_char)
    })
}
213
214fn collect_calls(expr: &Expr, ctx: CallContext, out: &mut Vec<CallSite>) {
215    match expr {
216        Expr::Call { callee, args } => {
217            out.push(CallSite {
218                callee_parts: callee.parts.clone(),
219                callee_display: callee.display.clone(),
220                arg_count: args.len(),
221                context: ctx,
222            });
223            for a in args {
224                collect_calls(a, ctx, out);
225            }
226        }
227        Expr::Binary { lhs, rhs, .. } => {
228            collect_calls(lhs, ctx, out);
229            collect_calls(rhs, ctx, out);
230        }
231        Expr::Unary { operand, .. } => collect_calls(operand, ctx, out),
232        _ => {}
233    }
234}
235
// Unit tests for call-site extraction. Most tests lower a small
// PL/SQL snippet with `lower_statement_body` and inspect what
// `extract_call_sites` reports; the recursion-limit test builds a
// `Statement` by hand instead.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lower_statement_body;

    // A single call on an assignment RHS is reported once, with
    // upper-cased parts, its argument count and the Assignment context.
    #[test]
    fn assignment_rhs_call_extracted() {
        let stmts = lower_statement_body("v_total := compute_sum(a, b);");
        let calls = extract_call_sites(&stmts);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].callee_parts, vec!["COMPUTE_SUM"]);
        assert_eq!(calls[0].arg_count, 2);
        assert_eq!(calls[0].context, CallContext::Assignment);
    }

    // A call nested in another call's argument list is reported
    // alongside the outer call.
    #[test]
    fn nested_call_yields_both_callees() {
        let stmts = lower_statement_body("v := nvl(compute(x), 0);");
        let calls = extract_call_sites(&stmts);
        let names: Vec<&str> = calls.iter().map(|c| c.callee_display.as_str()).collect();
        assert!(names.contains(&"nvl"));
        assert!(names.contains(&"compute"));
    }

    // A call inside a RETURN expression carries the ReturnValue context.
    #[test]
    fn return_value_call_context() {
        let stmts = lower_statement_body("RETURN compute_total(p_id);");
        let calls = extract_call_sites(&stmts);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].context, CallContext::ReturnValue);
    }

    // A bare `pkg.proc(args);` line is reported as a Statement-context
    // call with a two-part callee path.
    #[test]
    fn statement_level_proc_call_extracted() {
        let stmts = lower_statement_body("billing_pkg.post_invoice(p_id, p_amount);");
        let calls = extract_call_sites(&stmts);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].callee_parts, vec!["BILLING_PKG", "POST_INVOICE"]);
        assert_eq!(calls[0].context, CallContext::Statement);
        assert_eq!(calls[0].arg_count, 2);
    }

    // Both the IF condition call and the call inside the THEN body
    // are reported.
    #[test]
    fn if_condition_and_body_calls_extracted() {
        let src = "IF is_valid(p_id) THEN log_event('ok'); END IF;";
        let stmts = lower_statement_body(src);
        let calls = extract_call_sites(&stmts);
        let names: Vec<&str> = calls.iter().map(|c| c.callee_display.as_str()).collect();
        assert!(names.contains(&"is_valid"));
        assert!(names.contains(&"log_event"));
    }

    // The body of a FOR loop is re-lowered and its calls are found.
    #[test]
    fn for_loop_body_calls_recursed() {
        let src = "FOR i IN 1..10 LOOP process_row(i); END LOOP;";
        let stmts = lower_statement_body(src);
        let calls = extract_call_sites(&stmts);
        assert!(calls.iter().any(|c| c.callee_display == "process_row"));
    }

    // An expression without any call yields no call sites.
    #[test]
    fn no_calls_in_pure_arithmetic() {
        let stmts = lower_statement_body("v := a + b * 2;");
        let calls = extract_call_sites(&stmts);
        assert!(calls.is_empty());
    }

    // Calls on both sides of a binary operator are reported.
    #[test]
    fn binary_operands_searched_for_calls() {
        let stmts = lower_statement_body("v := f(x) + g(y);");
        let calls = extract_call_sites(&stmts);
        let names: Vec<&str> = calls.iter().map(|c| c.callee_display.as_str()).collect();
        assert!(names.contains(&"f"));
        assert!(names.contains(&"g"));
    }

    // A CallSite survives a JSON round trip, and the context is
    // serialised in snake_case.
    #[test]
    fn callsite_serde_round_trip() {
        let stmts = lower_statement_body("v := compute(a);");
        let calls = extract_call_sites(&stmts);
        let json = serde_json::to_string(&calls[0]).unwrap();
        let back: CallSite = serde_json::from_str(&json).unwrap();
        assert_eq!(back, calls[0]);
        assert!(json.contains("\"context\":\"assignment\""));
    }

    // A call inside an anonymous `BEGIN … END;` block is found.
    #[test]
    fn nested_block_calls_recursed() {
        let stmts = lower_statement_body("BEGIN inner_proc(1); END;");
        let calls = extract_call_sites(&stmts);
        assert!(calls.iter().any(|c| c.callee_display == "inner_proc"));
    }

    // oracle-aqum.1: the UNGUARDED expression-walk path. An
    // assignment whose RHS is a crafted flat binary chain
    // `a OR a OR … OR a` (here with calls so `collect_calls` has work
    // to do) used to lower into a recursion-depth tree as deep as the
    // operand count, and `collect_calls` (calls.rs:118) re-walked that
    // `Box<Expr>` chain to the same depth — overflowing the stack and
    // aborting `analyze`. With the lowering depth cap the produced tree
    // is bounded, so this walk terminates without a panic / SIGABRT.
    #[test]
    fn wide_assignment_rhs_chain_does_not_overflow_call_walk() {
        let n = 500_000usize;
        let mut rhs = String::with_capacity(n * 8);
        for i in 0..n {
            if i > 0 {
                rhs.push_str(" OR ");
            }
            rhs.push_str("f(x)");
        }
        let stmt = format!("v := {rhs};");
        let stmts = lower_statement_body(&stmt);
        // Must simply terminate (no stack overflow / abort). We do not
        // assert the call count — the deep tail is honestly truncated at
        // the depth cap — only that the walk is bounded and safe.
        let calls = extract_call_sites(&stmts);
        assert!(
            !calls.is_empty(),
            "the shallow prefix of the chain still yields call sites"
        );
    }

    // oracle-v4wa: the exact crash shape from the bundled public
    // fixture `corpus/synthetic/l1/pkg_error_handling.pkb`. A
    // `SELECT … FOR UPDATE;` body fragment leaves the bare token
    // `FOR UPDATE`; the text-scanner's `classify_loop` treats
    // `FOR …` as a FOR-loop, finds no `IN` and no `END LOOP`, and
    // falls back to a `BareLoop` whose `body_text` is *the same
    // string* `FOR UPDATE`. Re-lowering it yields the identical
    // non-shrinking `BareLoop` → before the depth guard this
    // recursed unbounded and aborted the whole `analyze`
    // (SIGABRT / "stack overflow"). It must now terminate and
    // report the truncation honestly (R13).
    #[test]
    fn non_shrinking_for_update_does_not_stack_overflow_and_reports_limit() {
        let stmts = vec![Statement::BareLoop {
            body_text: "FOR UPDATE".to_string(),
        }];
        let (calls, outcome) = extract_call_sites_bounded(&stmts);
        assert!(
            outcome.limit_hit,
            "the non-shrinking `FOR UPDATE` BareLoop must trip the \
             bounded depth cap, outcome={outcome:?}, calls={calls:?}"
        );
        assert!(outcome.truncated_bodies >= 1);
        // The back-compat wrapper must also simply terminate
        // (no panic / abort) rather than recurse unbounded.
        let _ = extract_call_sites(&stmts);
    }

    // oracle-hrzg.5: a parenthesised call operand `nvl((compute(x)), 0)`
    // must still record the inner `compute` call edge. Before the
    // `recognise_paren_group` recognizer, the `(compute(x))` argument
    // lowered to `Raw{UnrecognizedShape}` (the call recogniser bailed on
    // a bare `(...)` whose name part is empty), dropping the COMPUTE call
    // site that the un-parenthesised form records.
    #[test]
    fn parenthesised_call_operand_keeps_inner_call_edge() {
        let stmts = lower_statement_body("v := nvl((compute(x)), 0);");
        let calls = extract_call_sites(&stmts);
        let names: Vec<&str> = calls.iter().map(|c| c.callee_display.as_str()).collect();
        assert!(
            names.contains(&"nvl"),
            "outer nvl call must be recorded: {names:?}"
        );
        assert!(
            names.contains(&"compute"),
            "the parenthesised inner compute call must survive: {names:?}"
        );
    }
}