Skip to main content

logicaffeine_compile/analysis/
callgraph.rs

1use std::collections::{HashMap, HashSet};
2
3use logicaffeine_base::{Interner, Symbol};
4use logicaffeine_language::ast::{Expr, Stmt};
5use logicaffeine_language::ast::stmt::ClosureBody;
6
7/// Whole-program call graph for the LOGOS compilation pipeline.
8///
9/// Captures all direct and closure-embedded call edges between user-defined
10/// functions. Used by `ReadonlyParams` for transitive mutation detection and
11/// by liveness analysis for inter-procedural precision.
12pub struct CallGraph {
13    /// Direct call edges: fn_sym → set of directly called function symbols.
14    pub edges: HashMap<Symbol, HashSet<Symbol>>,
15    /// Set of native (extern) function symbols.
16    pub native_fns: HashSet<Symbol>,
17    /// Strongly connected components (Kosaraju's algorithm).
18    pub sccs: Vec<Vec<Symbol>>,
19}
20
21impl CallGraph {
22    /// Build the call graph from a program's top-level statements.
23    ///
24    /// Walks all `FunctionDef` bodies, collecting `Stmt::Call` and
25    /// `Expr::Call` targets, including calls inside closure bodies.
26    pub fn build(stmts: &[Stmt<'_>], _interner: &Interner) -> Self {
27        let mut edges: HashMap<Symbol, HashSet<Symbol>> = HashMap::new();
28        let mut native_fns: HashSet<Symbol> = HashSet::new();
29
30        for stmt in stmts {
31            if let Stmt::FunctionDef { name, body, is_native, .. } = stmt {
32                edges.entry(*name).or_default();
33                if *is_native {
34                    native_fns.insert(*name);
35                } else {
36                    let callees = edges.entry(*name).or_default();
37                    collect_calls_from_stmts(body, callees);
38                }
39            }
40        }
41
42        let sccs = compute_sccs(&edges);
43
44        Self { edges, native_fns, sccs }
45    }
46
47    /// Returns all functions reachable from `fn_sym` via the call graph.
48    ///
49    /// Does not include `fn_sym` itself unless it is part of a cycle.
50    pub fn reachable_from(&self, fn_sym: Symbol) -> HashSet<Symbol> {
51        let mut visited = HashSet::new();
52        let mut stack = Vec::new();
53
54        if let Some(callees) = self.edges.get(&fn_sym) {
55            for &c in callees {
56                if c != fn_sym {
57                    stack.push(c);
58                }
59            }
60        }
61
62        while let Some(f) = stack.pop() {
63            if visited.insert(f) {
64                if let Some(callees) = self.edges.get(&f) {
65                    for &c in callees {
66                        if !visited.contains(&c) {
67                            stack.push(c);
68                        }
69                    }
70                }
71            }
72        }
73
74        visited
75    }
76
77    /// Returns `true` if `fn_sym` participates in a recursive cycle
78    /// (direct self-call or mutual recursion via SCC membership).
79    pub fn is_recursive(&self, fn_sym: Symbol) -> bool {
80        // Direct self-edge
81        if self.edges.get(&fn_sym).map(|s| s.contains(&fn_sym)).unwrap_or(false) {
82            return true;
83        }
84        // Mutual recursion: fn_sym is in an SCC with more than one member
85        for scc in &self.sccs {
86            if scc.len() > 1 && scc.contains(&fn_sym) {
87                return true;
88            }
89        }
90        false
91    }
92}
93
94// =============================================================================
95// Call collection from AST
96// =============================================================================
97
98fn collect_calls_from_stmts(stmts: &[Stmt<'_>], calls: &mut HashSet<Symbol>) {
99    for stmt in stmts {
100        collect_calls_from_stmt(stmt, calls);
101    }
102}
103
104fn collect_calls_from_stmt(stmt: &Stmt<'_>, calls: &mut HashSet<Symbol>) {
105    match stmt {
106        Stmt::Call { function, args } => {
107            calls.insert(*function);
108            for arg in args {
109                collect_calls_from_expr(arg, calls);
110            }
111        }
112        Stmt::Let { value, .. } => collect_calls_from_expr(value, calls),
113        Stmt::Set { value, .. } => collect_calls_from_expr(value, calls),
114        Stmt::Return { value: Some(v) } => collect_calls_from_expr(v, calls),
115        Stmt::If { cond, then_block, else_block } => {
116            collect_calls_from_expr(cond, calls);
117            collect_calls_from_stmts(then_block, calls);
118            if let Some(else_b) = else_block {
119                collect_calls_from_stmts(else_b, calls);
120            }
121        }
122        Stmt::While { cond, body, .. } => {
123            collect_calls_from_expr(cond, calls);
124            collect_calls_from_stmts(body, calls);
125        }
126        Stmt::Repeat { iterable, body, .. } => {
127            collect_calls_from_expr(iterable, calls);
128            collect_calls_from_stmts(body, calls);
129        }
130        Stmt::Push { value, collection } => {
131            collect_calls_from_expr(value, calls);
132            collect_calls_from_expr(collection, calls);
133        }
134        Stmt::Pop { collection, .. } => collect_calls_from_expr(collection, calls),
135        Stmt::Add { value, collection } => {
136            collect_calls_from_expr(value, calls);
137            collect_calls_from_expr(collection, calls);
138        }
139        Stmt::Remove { value, collection } => {
140            collect_calls_from_expr(value, calls);
141            collect_calls_from_expr(collection, calls);
142        }
143        Stmt::SetIndex { collection, index, value } => {
144            collect_calls_from_expr(collection, calls);
145            collect_calls_from_expr(index, calls);
146            collect_calls_from_expr(value, calls);
147        }
148        Stmt::SetField { object, value, .. } => {
149            collect_calls_from_expr(object, calls);
150            collect_calls_from_expr(value, calls);
151        }
152        Stmt::Inspect { target, arms, .. } => {
153            collect_calls_from_expr(target, calls);
154            for arm in arms {
155                collect_calls_from_stmts(arm.body, calls);
156            }
157        }
158        Stmt::Concurrent { tasks } | Stmt::Parallel { tasks } => {
159            collect_calls_from_stmts(tasks, calls);
160        }
161        Stmt::Zone { body, .. } => collect_calls_from_stmts(body, calls),
162        _ => {}
163    }
164}
165
166fn collect_calls_from_expr(expr: &Expr<'_>, calls: &mut HashSet<Symbol>) {
167    match expr {
168        Expr::Call { function, args } => {
169            calls.insert(*function);
170            for arg in args {
171                collect_calls_from_expr(arg, calls);
172            }
173        }
174        Expr::Closure { body, .. } => match body {
175            ClosureBody::Expression(e) => collect_calls_from_expr(e, calls),
176            ClosureBody::Block(stmts) => collect_calls_from_stmts(stmts, calls),
177        },
178        Expr::BinaryOp { left, right, .. } => {
179            collect_calls_from_expr(left, calls);
180            collect_calls_from_expr(right, calls);
181        }
182        Expr::Index { collection, index } => {
183            collect_calls_from_expr(collection, calls);
184            collect_calls_from_expr(index, calls);
185        }
186        Expr::Slice { collection, start, end } => {
187            collect_calls_from_expr(collection, calls);
188            collect_calls_from_expr(start, calls);
189            collect_calls_from_expr(end, calls);
190        }
191        Expr::Length { collection } => collect_calls_from_expr(collection, calls),
192        Expr::Contains { collection, value } => {
193            collect_calls_from_expr(collection, calls);
194            collect_calls_from_expr(value, calls);
195        }
196        Expr::Union { left, right } | Expr::Intersection { left, right } => {
197            collect_calls_from_expr(left, calls);
198            collect_calls_from_expr(right, calls);
199        }
200        Expr::FieldAccess { object, .. } => collect_calls_from_expr(object, calls),
201        Expr::List(items) | Expr::Tuple(items) => {
202            for item in items {
203                collect_calls_from_expr(item, calls);
204            }
205        }
206        Expr::Range { start, end } => {
207            collect_calls_from_expr(start, calls);
208            collect_calls_from_expr(end, calls);
209        }
210        Expr::Copy { expr } | Expr::Give { value: expr } => {
211            collect_calls_from_expr(expr, calls);
212        }
213        Expr::OptionSome { value } => collect_calls_from_expr(value, calls),
214        Expr::WithCapacity { value, capacity } => {
215            collect_calls_from_expr(value, calls);
216            collect_calls_from_expr(capacity, calls);
217        }
218        Expr::CallExpr { callee, args } => {
219            collect_calls_from_expr(callee, calls);
220            for arg in args {
221                collect_calls_from_expr(arg, calls);
222            }
223        }
224        _ => {}
225    }
226}
227
228// =============================================================================
229// Kosaraju's SCC algorithm
230// =============================================================================
231
/// Computes the strongly connected components of a directed graph with
/// Kosaraju's algorithm.
///
/// The nodes are every key of `edges` plus every value that appears only as a
/// target; each node lands in exactly one component. Components are returned
/// in topological order of the condensed graph: if a member of component A has
/// an edge into a different component B, A is listed before B. The order of
/// members within a component, and among unrelated components, follows hash
/// iteration order and is therefore unspecified.
///
/// Both depth-first passes use an explicit stack instead of recursion, so a
/// very long call chain cannot overflow the thread's stack.
fn compute_sccs<T>(edges: &HashMap<T, HashSet<T>>) -> Vec<Vec<T>>
where
    T: Copy + Eq + std::hash::Hash,
{
    // Pass 1: DFS over the forward graph, recording nodes in finish order.
    let mut visited: HashSet<T> = HashSet::new();
    let mut finish_order: Vec<T> = Vec::new();
    for &v in edges.keys() {
        if !visited.contains(&v) {
            dfs_finish(v, edges, &mut visited, &mut finish_order);
        }
    }

    // Build the reversed graph (callee -> set of callers).
    let mut rev_edges: HashMap<T, HashSet<T>> = HashMap::new();
    for (&src, targets) in edges {
        for &dst in targets {
            rev_edges.entry(dst).or_default().insert(src);
        }
    }

    // Pass 2: DFS over the reversed graph in decreasing finish time; each
    // search that starts from an unassigned node yields exactly one SCC.
    let mut assigned: HashSet<T> = HashSet::new();
    let mut sccs: Vec<Vec<T>> = Vec::new();
    for &v in finish_order.iter().rev() {
        if !assigned.contains(&v) {
            let mut scc = Vec::new();
            dfs_collect(v, &rev_edges, &mut assigned, &mut scc);
            sccs.push(scc);
        }
    }

    sccs
}

/// Iterative depth-first search from `v` that appends each newly visited node
/// to `finish_order` once all of its successors have been explored (post-order).
///
/// Nodes already in `visited` are skipped. Successors are explored in the
/// iteration order of their `HashSet`, exactly as a recursive DFS would.
fn dfs_finish<T>(
    v: T,
    edges: &HashMap<T, HashSet<T>>,
    visited: &mut HashSet<T>,
    finish_order: &mut Vec<T>,
) where
    T: Copy + Eq + std::hash::Hash,
{
    if !visited.insert(v) {
        return;
    }
    // Each frame holds a node and the unexplored remainder of its successor set.
    let mut stack = vec![(v, edges.get(&v).map(|succs| succs.iter()))];
    while let Some((node, pending)) = stack.last_mut() {
        let next = pending.as_mut().and_then(|it| it.next()).copied();
        match next {
            Some(succ) => {
                if visited.insert(succ) {
                    stack.push((succ, edges.get(&succ).map(|s| s.iter())));
                }
            }
            None => {
                // All successors done: this node finishes now.
                finish_order.push(*node);
                stack.pop();
            }
        }
    }
}

/// Iterative depth-first search from `v` that appends every newly visited node
/// to `scc` in pre-order.
///
/// Nodes already in `visited` are skipped, which is what confines the search
/// to a single component during Kosaraju's second pass.
fn dfs_collect<T>(
    v: T,
    edges: &HashMap<T, HashSet<T>>,
    visited: &mut HashSet<T>,
    scc: &mut Vec<T>,
) where
    T: Copy + Eq + std::hash::Hash,
{
    if !visited.insert(v) {
        return;
    }
    scc.push(v);
    // Each frame is the unexplored remainder of one node's successor set.
    let mut stack = vec![edges.get(&v).map(|succs| succs.iter())];
    while let Some(pending) = stack.last_mut() {
        let next = pending.as_mut().and_then(|it| it.next()).copied();
        match next {
            Some(succ) => {
                if visited.insert(succ) {
                    scc.push(succ);
                    stack.push(edges.get(&succ).map(|s| s.iter()));
                }
            }
            None => {
                stack.pop();
            }
        }
    }
}