logicaffeine_compile/analysis/
callgraph.rs1use std::collections::{HashMap, HashSet};
2
3use logicaffeine_base::{Interner, Symbol};
4use logicaffeine_language::ast::{Expr, Stmt};
5use logicaffeine_language::ast::stmt::ClosureBody;
6
7pub struct CallGraph {
13 pub edges: HashMap<Symbol, HashSet<Symbol>>,
15 pub native_fns: HashSet<Symbol>,
17 pub sccs: Vec<Vec<Symbol>>,
19}
20
21impl CallGraph {
22 pub fn build(stmts: &[Stmt<'_>], _interner: &Interner) -> Self {
27 let mut edges: HashMap<Symbol, HashSet<Symbol>> = HashMap::new();
28 let mut native_fns: HashSet<Symbol> = HashSet::new();
29
30 for stmt in stmts {
31 if let Stmt::FunctionDef { name, body, is_native, .. } = stmt {
32 edges.entry(*name).or_default();
33 if *is_native {
34 native_fns.insert(*name);
35 } else {
36 let callees = edges.entry(*name).or_default();
37 collect_calls_from_stmts(body, callees);
38 }
39 }
40 }
41
42 let sccs = compute_sccs(&edges);
43
44 Self { edges, native_fns, sccs }
45 }
46
47 pub fn reachable_from(&self, fn_sym: Symbol) -> HashSet<Symbol> {
51 let mut visited = HashSet::new();
52 let mut stack = Vec::new();
53
54 if let Some(callees) = self.edges.get(&fn_sym) {
55 for &c in callees {
56 if c != fn_sym {
57 stack.push(c);
58 }
59 }
60 }
61
62 while let Some(f) = stack.pop() {
63 if visited.insert(f) {
64 if let Some(callees) = self.edges.get(&f) {
65 for &c in callees {
66 if !visited.contains(&c) {
67 stack.push(c);
68 }
69 }
70 }
71 }
72 }
73
74 visited
75 }
76
77 pub fn is_recursive(&self, fn_sym: Symbol) -> bool {
80 if self.edges.get(&fn_sym).map(|s| s.contains(&fn_sym)).unwrap_or(false) {
82 return true;
83 }
84 for scc in &self.sccs {
86 if scc.len() > 1 && scc.contains(&fn_sym) {
87 return true;
88 }
89 }
90 false
91 }
92}
93
94fn collect_calls_from_stmts(stmts: &[Stmt<'_>], calls: &mut HashSet<Symbol>) {
99 for stmt in stmts {
100 collect_calls_from_stmt(stmt, calls);
101 }
102}
103
104fn collect_calls_from_stmt(stmt: &Stmt<'_>, calls: &mut HashSet<Symbol>) {
105 match stmt {
106 Stmt::Call { function, args } => {
107 calls.insert(*function);
108 for arg in args {
109 collect_calls_from_expr(arg, calls);
110 }
111 }
112 Stmt::Let { value, .. } => collect_calls_from_expr(value, calls),
113 Stmt::Set { value, .. } => collect_calls_from_expr(value, calls),
114 Stmt::Return { value: Some(v) } => collect_calls_from_expr(v, calls),
115 Stmt::If { cond, then_block, else_block } => {
116 collect_calls_from_expr(cond, calls);
117 collect_calls_from_stmts(then_block, calls);
118 if let Some(else_b) = else_block {
119 collect_calls_from_stmts(else_b, calls);
120 }
121 }
122 Stmt::While { cond, body, .. } => {
123 collect_calls_from_expr(cond, calls);
124 collect_calls_from_stmts(body, calls);
125 }
126 Stmt::Repeat { iterable, body, .. } => {
127 collect_calls_from_expr(iterable, calls);
128 collect_calls_from_stmts(body, calls);
129 }
130 Stmt::Push { value, collection } => {
131 collect_calls_from_expr(value, calls);
132 collect_calls_from_expr(collection, calls);
133 }
134 Stmt::Pop { collection, .. } => collect_calls_from_expr(collection, calls),
135 Stmt::Add { value, collection } => {
136 collect_calls_from_expr(value, calls);
137 collect_calls_from_expr(collection, calls);
138 }
139 Stmt::Remove { value, collection } => {
140 collect_calls_from_expr(value, calls);
141 collect_calls_from_expr(collection, calls);
142 }
143 Stmt::SetIndex { collection, index, value } => {
144 collect_calls_from_expr(collection, calls);
145 collect_calls_from_expr(index, calls);
146 collect_calls_from_expr(value, calls);
147 }
148 Stmt::SetField { object, value, .. } => {
149 collect_calls_from_expr(object, calls);
150 collect_calls_from_expr(value, calls);
151 }
152 Stmt::Inspect { target, arms, .. } => {
153 collect_calls_from_expr(target, calls);
154 for arm in arms {
155 collect_calls_from_stmts(arm.body, calls);
156 }
157 }
158 Stmt::Concurrent { tasks } | Stmt::Parallel { tasks } => {
159 collect_calls_from_stmts(tasks, calls);
160 }
161 Stmt::Zone { body, .. } => collect_calls_from_stmts(body, calls),
162 _ => {}
163 }
164}
165
166fn collect_calls_from_expr(expr: &Expr<'_>, calls: &mut HashSet<Symbol>) {
167 match expr {
168 Expr::Call { function, args } => {
169 calls.insert(*function);
170 for arg in args {
171 collect_calls_from_expr(arg, calls);
172 }
173 }
174 Expr::Closure { body, .. } => match body {
175 ClosureBody::Expression(e) => collect_calls_from_expr(e, calls),
176 ClosureBody::Block(stmts) => collect_calls_from_stmts(stmts, calls),
177 },
178 Expr::BinaryOp { left, right, .. } => {
179 collect_calls_from_expr(left, calls);
180 collect_calls_from_expr(right, calls);
181 }
182 Expr::Index { collection, index } => {
183 collect_calls_from_expr(collection, calls);
184 collect_calls_from_expr(index, calls);
185 }
186 Expr::Slice { collection, start, end } => {
187 collect_calls_from_expr(collection, calls);
188 collect_calls_from_expr(start, calls);
189 collect_calls_from_expr(end, calls);
190 }
191 Expr::Length { collection } => collect_calls_from_expr(collection, calls),
192 Expr::Contains { collection, value } => {
193 collect_calls_from_expr(collection, calls);
194 collect_calls_from_expr(value, calls);
195 }
196 Expr::Union { left, right } | Expr::Intersection { left, right } => {
197 collect_calls_from_expr(left, calls);
198 collect_calls_from_expr(right, calls);
199 }
200 Expr::FieldAccess { object, .. } => collect_calls_from_expr(object, calls),
201 Expr::List(items) | Expr::Tuple(items) => {
202 for item in items {
203 collect_calls_from_expr(item, calls);
204 }
205 }
206 Expr::Range { start, end } => {
207 collect_calls_from_expr(start, calls);
208 collect_calls_from_expr(end, calls);
209 }
210 Expr::Copy { expr } | Expr::Give { value: expr } => {
211 collect_calls_from_expr(expr, calls);
212 }
213 Expr::OptionSome { value } => collect_calls_from_expr(value, calls),
214 Expr::WithCapacity { value, capacity } => {
215 collect_calls_from_expr(value, calls);
216 collect_calls_from_expr(capacity, calls);
217 }
218 Expr::CallExpr { callee, args } => {
219 collect_calls_from_expr(callee, calls);
220 for arg in args {
221 collect_calls_from_expr(arg, calls);
222 }
223 }
224 _ => {}
225 }
226}
227
228fn compute_sccs(edges: &HashMap<Symbol, HashSet<Symbol>>) -> Vec<Vec<Symbol>> {
233 let nodes: Vec<Symbol> = edges.keys().copied().collect();
234
235 let mut visited: HashSet<Symbol> = HashSet::new();
237 let mut finish_order: Vec<Symbol> = Vec::new();
238
239 for &v in &nodes {
240 if !visited.contains(&v) {
241 dfs_finish(v, edges, &mut visited, &mut finish_order);
242 }
243 }
244
245 let mut rev_edges: HashMap<Symbol, HashSet<Symbol>> = HashMap::new();
247 for (&src, callees) in edges {
248 for &dst in callees {
249 rev_edges.entry(dst).or_default().insert(src);
250 }
251 }
252
253 let mut visited2: HashSet<Symbol> = HashSet::new();
255 let mut sccs: Vec<Vec<Symbol>> = Vec::new();
256
257 for &v in finish_order.iter().rev() {
258 if !visited2.contains(&v) {
259 let mut scc = Vec::new();
260 dfs_collect(v, &rev_edges, &mut visited2, &mut scc);
261 sccs.push(scc);
262 }
263 }
264
265 sccs
266}
267
268fn dfs_finish(
269 v: Symbol,
270 edges: &HashMap<Symbol, HashSet<Symbol>>,
271 visited: &mut HashSet<Symbol>,
272 finish_order: &mut Vec<Symbol>,
273) {
274 if !visited.insert(v) {
275 return;
276 }
277 if let Some(callees) = edges.get(&v) {
278 for &callee in callees {
279 dfs_finish(callee, edges, visited, finish_order);
280 }
281 }
282 finish_order.push(v);
283}
284
285fn dfs_collect(
286 v: Symbol,
287 edges: &HashMap<Symbol, HashSet<Symbol>>,
288 visited: &mut HashSet<Symbol>,
289 scc: &mut Vec<Symbol>,
290) {
291 if !visited.insert(v) {
292 return;
293 }
294 scc.push(v);
295 if let Some(callees) = edges.get(&v) {
296 for &callee in callees {
297 dfs_collect(callee, edges, visited, scc);
298 }
299 }
300}