plsql-ir 0.1.0

Typed semantic intermediate representation for plsql-intelligence
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
//! Call-site edge extraction.
//!
//! Walks a lowered statement body and pulls out every
//! procedure / function invocation as a [`CallSite`]. The
//! dependency-graph layer resolves each `callee` to a concrete
//! node (via `plsql_symbols::resolve_reference`) and mints a
//! `Calls` edge; this module's job is purely *extraction* — find
//! the call sites and their shape.
//!
//! Calls appear in three places:
//!
//! 1. Statement-level procedure calls — a bare
//!    `Statement::Unrecognized` line whose text is
//!    `pkg.proc(args);` (the stmt recogniser leaves these
//!    unclassified because they're neither assignment nor
//!    control flow).
//! 2. Expression-embedded function calls — inside an
//!    `Assignment.rhs_text`, an `If` arm condition, a loop
//!    range, a `Return` value, etc.
//! 3. Nested calls — `nvl(compute(x), 0)` yields both `nvl`
//!    and `compute`.
//!
//! ## /oracle evidence
//!
//! * `DATABASE-REFERENCE.md` PL/SQL Language Reference — the
//!   call grammar (positional / named notation, package-
//!   qualified vs bare) drives what counts as a callee.
//! * `LOW-LEVEL-CATALOGS.md` Data Dictionary View Families —
//!   `ALL_DEPENDENCIES` with `DEPENDENCY_TYPE` is the
//!   server-side mirror the depgraph cross-checks `Calls`
//!   edges against.

use serde::{Deserialize, Serialize};

use crate::expr::{Expr, lower_expression};
use crate::stmt::Statement;

/// One extracted call site.
///
/// `collect_calls` pushes one of these for every `Expr::Call` node it
/// visits, including calls nested inside another call's arguments.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct CallSite {
    /// Dotted callee path, case-folded for the lookup key, one element
    /// per path segment (e.g. `billing_pkg.post_invoice` is stored as
    /// `["BILLING_PKG", "POST_INVOICE"]`).
    pub callee_parts: Vec<String>,
    /// Source-form callee path preserved for diagnostics.
    pub callee_display: String,
    /// Number of arguments written at the call (the length of the
    /// lowered argument list). Named-notation args count toward this
    /// arity exactly like positional ones; the depgraph's overload
    /// resolver (SYM-009) handles named-vs-positional matching.
    pub arg_count: usize,
    /// Context the call appeared in — drives the edge's
    /// confidence + the report wording.
    pub context: CallContext,
}

/// Syntactic position a [`CallSite`] was found in.
///
/// Serialized in `snake_case`, so `ReturnValue` becomes
/// `"return_value"` on the wire.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CallContext {
    /// Statement-level procedure call (`pkg.proc(args);`). Also used
    /// for a nested block whose text has no `BEGIN`/`DECLARE` wrapper
    /// to strip and is therefore lowered as a single expression.
    Statement,
    /// Function call inside an assignment RHS.
    Assignment,
    /// Function call inside a control-flow condition / range: an `IF`
    /// arm condition, a `WHILE` condition or a `FOR` loop range.
    ControlFlow,
    /// Function call inside a RETURN expression.
    ReturnValue,
}

/// Collect every call site found in a lowered statement body.
///
/// Convenience front-end to [`extract_call_sites_bounded`] that keeps
/// the historical signature: the walk is still depth-guarded, so a
/// malformed unit whose re-lowered body never shrinks cannot overflow
/// the stack, but the recursion outcome is discarded here. Call
/// [`extract_call_sites_bounded`] directly when the typed
/// [`plsql_core::UnknownReason::AnalysisRecursionLimit`] degradation
/// has to be surfaced.
#[must_use]
pub fn extract_call_sites(stmts: &[Statement]) -> Vec<CallSite> {
    let (sites, _outcome) = extract_call_sites_bounded(stmts);
    sites
}

/// Depth-bounded form of [`extract_call_sites`].
///
/// Walks `stmts` starting at depth 0 and returns the call sites in
/// discovery order together with a [`RecursionOutcome`] that records
/// whether (and how often) a nested body was abandoned at the
/// recursion-depth cap instead of being walked. When
/// `outcome.limit_hit` is set the caller must emit an honest typed
/// diagnostic (R13 — never silently truncate).
#[must_use]
pub fn extract_call_sites_bounded(stmts: &[Statement]) -> (Vec<CallSite>, crate::RecursionOutcome) {
    let mut outcome = crate::RecursionOutcome::default();
    let mut sites = Vec::new();
    walk_call_sites(stmts, 0, &mut sites, &mut outcome);
    (sites, outcome)
}

fn walk_call_sites(
    stmts: &[Statement],
    depth: usize,
    out: &mut Vec<CallSite>,
    outcome: &mut crate::RecursionOutcome,
) {
    // Recurse into a re-lowered body only while we have depth
    // budget left. At the cap we stop descending and record the
    // truncation so the caller can surface it honestly — we do
    // NOT silently drop it and we do NOT keep recursing (which
    // would stack-overflow on a non-shrinking malformed slice).
    macro_rules! recurse_body {
        ($text:expr) => {{
            if depth + 1 >= crate::MAX_RELOWER_DEPTH {
                outcome.note_truncated();
            } else {
                let lowered = crate::lower_statement_body($text);
                walk_call_sites(&lowered, depth + 1, out, outcome);
            }
        }};
    }
    for stmt in stmts {
        match stmt {
            Statement::Assignment { rhs_text, .. } => {
                collect_calls(&lower_expression(rhs_text), CallContext::Assignment, out);
            }
            Statement::Return {
                value_text: Some(v),
            } => {
                collect_calls(&lower_expression(v), CallContext::ReturnValue, out);
            }
            Statement::If {
                arms,
                else_body_text,
            } => {
                for arm in arms {
                    collect_calls(
                        &lower_expression(&arm.cond_text),
                        CallContext::ControlFlow,
                        out,
                    );
                    recurse_body!(&arm.body_text);
                }
                if let Some(eb) = else_body_text {
                    recurse_body!(eb);
                }
            }
            Statement::WhileLoop {
                cond_text,
                body_text,
            } => {
                collect_calls(&lower_expression(cond_text), CallContext::ControlFlow, out);
                recurse_body!(body_text);
            }
            Statement::ForLoop {
                range_text,
                body_text,
                ..
            } => {
                collect_calls(&lower_expression(range_text), CallContext::ControlFlow, out);
                recurse_body!(body_text);
            }
            Statement::BareLoop { body_text } => {
                recurse_body!(body_text);
            }
            Statement::NestedBlock { body_text } => {
                // Strip the BEGIN…END / DECLARE…END wrapper before
                // re-lowering, otherwise the stmt recogniser keeps
                // classifying the same text as a NestedBlock and
                // recursion never terminates.
                let inner = strip_block_wrapper(body_text);
                if inner != body_text.as_str() {
                    recurse_body!(inner);
                } else {
                    // No wrapper to strip — treat the text as a
                    // single expression candidate instead of
                    // recursing.
                    collect_calls(&lower_expression(body_text), CallContext::Statement, out);
                }
            }
            Statement::Unrecognized { raw_text, .. } => {
                // Statement-level procedure call: `pkg.proc(args);`.
                let e = lower_expression(raw_text);
                collect_calls(&e, CallContext::Statement, out);
            }
            _ => {}
        }
    }
}

/// Strip a leading `DECLARE`/`BEGIN` keyword and a trailing
/// `END[ label][;]` from a block body so the inner statements can be
/// re-lowered without re-triggering the NestedBlock classification.
///
/// Both keywords are matched ASCII-case-insensitively and only as
/// whole words: a neighbouring identifier byte (ASCII alphanumeric,
/// `_`, `$`, `#`, or any non-ASCII byte) disqualifies the match. So
/// `beginning_balance := 0;` is not treated as a `BEGIN` block, and the
/// `end` inside `send_mail` or `pending` is not treated as the closing
/// `END`.
///
/// Returns:
/// * `text` itself, untrimmed, when the trimmed text does not open with
///   a standalone `DECLARE` or `BEGIN`;
/// * otherwise the text after the opener, cut just before the *last*
///   standalone `END` (anything after that `END` — a label, `;` — is
///   dropped) and trimmed on both sides. If there is no standalone
///   `END`, only the opener is removed.
///
/// String literals and comments are not parsed, so a standalone `END`
/// word inside one still counts as a closer.
///
/// Shared with the sibling re-lowering walks
/// [`crate::flow_intra`] (taint) and [`crate::dml_edges`]
/// (Reads/Writes edges) so all three descend into anonymous
/// `BEGIN … END` / `DECLARE … END` sub-blocks identically — the
/// returned slice is a sub-slice of `text`, and every cut lands next
/// to an ASCII keyword byte, so it is always on a UTF-8 char boundary
/// even for multi-byte content.
pub(crate) fn strip_block_wrapper(text: &str) -> &str {
    /// Bytes that may continue an identifier. Non-ASCII bytes are
    /// included so a keyword glued to multi-byte text is never mistaken
    /// for a standalone word.
    fn is_ident_byte(b: u8) -> bool {
        b.is_ascii_alphanumeric() || matches!(b, b'_' | b'$' | b'#') || !b.is_ascii()
    }

    let trimmed = text.trim();
    let bytes = trimmed.as_bytes();

    // Length of the opening keyword, if the text starts with one that
    // is not merely the prefix of a longer identifier.
    let opener_len = ["DECLARE", "BEGIN"]
        .iter()
        .copied()
        .find(|kw| {
            bytes.len() >= kw.len()
                && bytes[..kw.len()].eq_ignore_ascii_case(kw.as_bytes())
                && bytes.get(kw.len()).map_or(true, |&b| !is_ident_byte(b))
        })
        .map(str::len);
    let after_open = match opener_len {
        Some(len) => trimmed[len..].trim_start(),
        None => return text,
    };

    // Locate the last standalone `END`, scanning from the back so the
    // block's own closer wins over any inner `END IF` / `END LOOP`.
    let inner = after_open.as_bytes();
    let closer_at = inner
        .windows(3)
        .enumerate()
        .rev()
        .find(|&(pos, word)| {
            word.eq_ignore_ascii_case(b"END")
                && (pos == 0 || !is_ident_byte(inner[pos - 1]))
                && inner.get(pos + 3).map_or(true, |&b| !is_ident_byte(b))
        })
        .map(|(pos, _)| pos);

    match closer_at {
        Some(pos) => after_open[..pos].trim_end(),
        None => after_open,
    }
}

/// Append every call found in `expr` to `out`, tagging each with `ctx`.
///
/// The traversal is pre-order: a call is recorded before the calls in
/// its own arguments, arguments are visited left to right, and a
/// binary node's left operand is visited before its right one. Only
/// call, binary and unary nodes are descended into; every other
/// expression shape is ignored.
fn collect_calls(expr: &Expr, ctx: CallContext, out: &mut Vec<CallSite>) {
    match expr {
        Expr::Unary { operand, .. } => collect_calls(operand, ctx, out),
        Expr::Binary { lhs, rhs, .. } => {
            for side in [lhs, rhs] {
                collect_calls(side, ctx, out);
            }
        }
        Expr::Call { callee, args } => {
            // Record the outer call first, then look for calls nested
            // in its arguments (e.g. `nvl(compute(x), 0)`).
            let site = CallSite {
                callee_parts: callee.parts.clone(),
                callee_display: callee.display.clone(),
                arg_count: args.len(),
                context: ctx,
            };
            out.push(site);
            for arg in args {
                collect_calls(arg, ctx, out);
            }
        }
        _ => {}
    }
}

// Unit tests for call-site extraction. Most tests lower a small PL/SQL
// snippet with `lower_statement_body` and inspect the `CallSite`s that
// `extract_call_sites` returns for it.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lower_statement_body;

    // A call on an assignment RHS is reported once, with a case-folded
    // callee, its argument count and the `Assignment` context.
    #[test]
    fn assignment_rhs_call_extracted() {
        let stmts = lower_statement_body("v_total := compute_sum(a, b);");
        let calls = extract_call_sites(&stmts);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].callee_parts, vec!["COMPUTE_SUM"]);
        assert_eq!(calls[0].arg_count, 2);
        assert_eq!(calls[0].context, CallContext::Assignment);
    }

    // A call used as another call's argument is reported alongside the
    // outer call.
    #[test]
    fn nested_call_yields_both_callees() {
        let stmts = lower_statement_body("v := nvl(compute(x), 0);");
        let calls = extract_call_sites(&stmts);
        let names: Vec<&str> = calls.iter().map(|c| c.callee_display.as_str()).collect();
        assert!(names.contains(&"nvl"));
        assert!(names.contains(&"compute"));
    }

    // A call inside a RETURN expression carries the `ReturnValue` context.
    #[test]
    fn return_value_call_context() {
        let stmts = lower_statement_body("RETURN compute_total(p_id);");
        let calls = extract_call_sites(&stmts);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].context, CallContext::ReturnValue);
    }

    // A bare package-qualified procedure call is reported with a
    // two-part callee path and the `Statement` context.
    #[test]
    fn statement_level_proc_call_extracted() {
        let stmts = lower_statement_body("billing_pkg.post_invoice(p_id, p_amount);");
        let calls = extract_call_sites(&stmts);
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].callee_parts, vec!["BILLING_PKG", "POST_INVOICE"]);
        assert_eq!(calls[0].context, CallContext::Statement);
        assert_eq!(calls[0].arg_count, 2);
    }

    // Both the IF condition and the re-lowered THEN body are searched.
    #[test]
    fn if_condition_and_body_calls_extracted() {
        let src = "IF is_valid(p_id) THEN log_event('ok'); END IF;";
        let stmts = lower_statement_body(src);
        let calls = extract_call_sites(&stmts);
        let names: Vec<&str> = calls.iter().map(|c| c.callee_display.as_str()).collect();
        assert!(names.contains(&"is_valid"));
        assert!(names.contains(&"log_event"));
    }

    // A FOR loop body is re-lowered and its calls are reported.
    #[test]
    fn for_loop_body_calls_recursed() {
        let src = "FOR i IN 1..10 LOOP process_row(i); END LOOP;";
        let stmts = lower_statement_body(src);
        let calls = extract_call_sites(&stmts);
        assert!(calls.iter().any(|c| c.callee_display == "process_row"));
    }

    // An expression with no call syntax yields no call sites.
    #[test]
    fn no_calls_in_pure_arithmetic() {
        let stmts = lower_statement_body("v := a + b * 2;");
        let calls = extract_call_sites(&stmts);
        assert!(calls.is_empty());
    }

    // Both operands of a binary expression are searched.
    #[test]
    fn binary_operands_searched_for_calls() {
        let stmts = lower_statement_body("v := f(x) + g(y);");
        let calls = extract_call_sites(&stmts);
        let names: Vec<&str> = calls.iter().map(|c| c.callee_display.as_str()).collect();
        assert!(names.contains(&"f"));
        assert!(names.contains(&"g"));
    }

    // `CallSite` survives a JSON round trip and the context is
    // serialized in snake_case.
    #[test]
    fn callsite_serde_round_trip() {
        let stmts = lower_statement_body("v := compute(a);");
        let calls = extract_call_sites(&stmts);
        let json = serde_json::to_string(&calls[0]).unwrap();
        let back: CallSite = serde_json::from_str(&json).unwrap();
        assert_eq!(back, calls[0]);
        assert!(json.contains("\"context\":\"assignment\""));
    }

    // Calls inside an anonymous BEGIN … END block are reported.
    #[test]
    fn nested_block_calls_recursed() {
        let stmts = lower_statement_body("BEGIN inner_proc(1); END;");
        let calls = extract_call_sites(&stmts);
        assert!(calls.iter().any(|c| c.callee_display == "inner_proc"));
    }

    // oracle-aqum.1: the UNGUARDED expression-walk path. An
    // assignment whose RHS is a crafted flat binary chain
    // `a OR a OR … OR a` (here with calls so `collect_calls` has work
    // to do) used to lower into a recursion-depth tree as deep as the
    // operand count, and `collect_calls` (calls.rs:118) re-walked that
    // `Box<Expr>` chain to the same depth — overflowing the stack and
    // aborting `analyze`. With the lowering depth cap the produced tree
    // is bounded, so this walk terminates without a panic / SIGABRT.
    #[test]
    fn wide_assignment_rhs_chain_does_not_overflow_call_walk() {
        let n = 500_000usize;
        let mut rhs = String::with_capacity(n * 8);
        for i in 0..n {
            if i > 0 {
                rhs.push_str(" OR ");
            }
            rhs.push_str("f(x)");
        }
        let stmt = format!("v := {rhs};");
        let stmts = lower_statement_body(&stmt);
        // Must simply terminate (no stack overflow / abort). We do not
        // assert the call count — the deep tail is honestly truncated at
        // the depth cap — only that the walk is bounded and safe.
        let calls = extract_call_sites(&stmts);
        assert!(
            !calls.is_empty(),
            "the shallow prefix of the chain still yields call sites"
        );
    }

    // oracle-v4wa: the exact crash shape from the bundled public
    // fixture `corpus/synthetic/l1/pkg_error_handling.pkb`. A
    // `SELECT … FOR UPDATE;` body fragment leaves the bare token
    // `FOR UPDATE`; the text-scanner's `classify_loop` treats
    // `FOR …` as a FOR-loop, finds no `IN` and no `END LOOP`, and
    // falls back to a `BareLoop` whose `body_text` is *the same
    // string* `FOR UPDATE`. Re-lowering it yields the identical
    // non-shrinking `BareLoop` → before the depth guard this
    // recursed unbounded and aborted the whole `analyze`
    // (SIGABRT / "stack overflow"). It must now terminate and
    // report the truncation honestly (R13).
    #[test]
    fn non_shrinking_for_update_does_not_stack_overflow_and_reports_limit() {
        let stmts = vec![Statement::BareLoop {
            body_text: "FOR UPDATE".to_string(),
        }];
        let (calls, outcome) = extract_call_sites_bounded(&stmts);
        assert!(
            outcome.limit_hit,
            "the non-shrinking `FOR UPDATE` BareLoop must trip the \
             bounded depth cap, outcome={outcome:?}, calls={calls:?}"
        );
        assert!(outcome.truncated_bodies >= 1);
        // The back-compat wrapper must also simply terminate
        // (no panic / abort) rather than recurse unbounded.
        let _ = extract_call_sites(&stmts);
    }

    // oracle-hrzg.5: a parenthesised call operand `nvl((compute(x)), 0)`
    // must still record the inner `compute` call edge. Before the
    // `recognise_paren_group` recognizer, the `(compute(x))` argument
    // lowered to `Raw{UnrecognizedShape}` (the call recogniser bailed on
    // a bare `(...)` whose name part is empty), dropping the COMPUTE call
    // site that the un-parenthesised form records.
    #[test]
    fn parenthesised_call_operand_keeps_inner_call_edge() {
        let stmts = lower_statement_body("v := nvl((compute(x)), 0);");
        let calls = extract_call_sites(&stmts);
        let names: Vec<&str> = calls.iter().map(|c| c.callee_display.as_str()).collect();
        assert!(
            names.contains(&"nvl"),
            "outer nvl call must be recorded: {names:?}"
        );
        assert!(
            names.contains(&"compute"),
            "the parenthesised inner compute call must survive: {names:?}"
        );
    }
}