seqc/call_graph.rs
1//! Call graph analysis for detecting mutual recursion
2//!
3//! This module builds a call graph from a Seq program and detects
4//! strongly connected components (SCCs) to identify mutual recursion cycles.
5//!
6//! # Usage
7//!
8//! ```ignore
9//! let call_graph = CallGraph::build(&program);
10//! let cycles = call_graph.recursive_cycles();
11//! ```
12//!
13//! # Primary Use Cases
14//!
15//! 1. **Type checker divergence detection**: The type checker uses the call graph
16//! to identify mutually recursive tail calls, enabling correct type inference
17//! for patterns like even/odd that would otherwise require branch unification.
18//!
19//! 2. **Future optimizations**: The call graph infrastructure can support dead code
20//! detection, inlining decisions, and diagnostic tools.
21//!
22//! # Implementation Details
23//!
24//! - **Algorithm**: Tarjan's SCC algorithm, O(V + E) time complexity
25//! - **Builtins**: Calls to builtins/external words are excluded from the graph
26//! (they don't affect recursion detection since they always return)
27//! - **Quotations**: Calls within quotations are included in the analysis
28//! - **Match arms**: Calls within match arms are included in the analysis
29//!
30//! # Note on Tail Call Optimization
31//!
32//! The existing codegen already emits `musttail` for all tail calls to user-defined
33//! words (see `codegen/statements.rs`). This means mutual TCO works automatically
34//! without needing explicit call graph checks in codegen. The call graph is primarily
35//! used for type checking, not for enabling TCO.
36
37use crate::ast::{Program, Statement};
38use std::collections::{HashMap, HashSet};
39
/// Directed graph of word-to-word calls within a single Seq program.
///
/// Nodes are the words defined in the program; an edge `a -> b` means the body
/// of `a` contains a call to the user-defined word `b`. The recursive strongly
/// connected components are computed once, when the graph is built, and cached.
#[derive(Debug, Clone)]
pub struct CallGraph {
    /// For each defined word, the set of user-defined words its body calls.
    edges: HashMap<String, HashSet<String>>,
    /// Names of every word defined in the program (the graph's nodes).
    words: HashSet<String>,
    /// Cached recursive SCCs: every multi-word component (mutual recursion)
    /// plus every single word that calls itself (direct recursion).
    recursive_sccs: Vec<HashSet<String>>,
}
51
52impl CallGraph {
53 /// Build a call graph from a program.
54 ///
55 /// This extracts all word-to-word call relationships, including calls
56 /// within quotations, if branches, and match arms.
57 pub fn build(program: &Program) -> Self {
58 let mut edges: HashMap<String, HashSet<String>> = HashMap::new();
59 let words: HashSet<String> = program.words.iter().map(|w| w.name.clone()).collect();
60
61 for word in &program.words {
62 let callees = extract_calls(&word.body, &words);
63 edges.insert(word.name.clone(), callees);
64 }
65
66 let mut graph = CallGraph {
67 edges,
68 words,
69 recursive_sccs: Vec::new(),
70 };
71
72 // Compute SCCs and identify recursive cycles
73 graph.recursive_sccs = graph.find_sccs();
74
75 graph
76 }
77
78 /// Check if a word is part of any recursive cycle (direct or mutual).
79 pub fn is_recursive(&self, word: &str) -> bool {
80 self.recursive_sccs.iter().any(|scc| scc.contains(word))
81 }
82
83 /// Check if two words are in the same recursive cycle (mutually recursive).
84 pub fn are_mutually_recursive(&self, word1: &str, word2: &str) -> bool {
85 self.recursive_sccs
86 .iter()
87 .any(|scc| scc.contains(word1) && scc.contains(word2))
88 }
89
90 /// Get all recursive cycles (SCCs with recursion).
91 pub fn recursive_cycles(&self) -> &[HashSet<String>] {
92 &self.recursive_sccs
93 }
94
95 /// Get the words that a given word calls.
96 pub fn callees(&self, word: &str) -> Option<&HashSet<String>> {
97 self.edges.get(word)
98 }
99
100 /// Find strongly connected components using Tarjan's algorithm.
101 ///
102 /// Returns only SCCs that represent recursion:
103 /// - Multi-word SCCs (mutual recursion)
104 /// - Single-word SCCs where the word calls itself (direct recursion)
105 fn find_sccs(&self) -> Vec<HashSet<String>> {
106 let mut index_counter = 0;
107 let mut stack: Vec<String> = Vec::new();
108 let mut on_stack: HashSet<String> = HashSet::new();
109 let mut indices: HashMap<String, usize> = HashMap::new();
110 let mut lowlinks: HashMap<String, usize> = HashMap::new();
111 let mut sccs: Vec<HashSet<String>> = Vec::new();
112
113 for word in &self.words {
114 if !indices.contains_key(word) {
115 self.tarjan_visit(
116 word,
117 &mut index_counter,
118 &mut stack,
119 &mut on_stack,
120 &mut indices,
121 &mut lowlinks,
122 &mut sccs,
123 );
124 }
125 }
126
127 // Filter to only recursive SCCs
128 sccs.into_iter()
129 .filter(|scc| {
130 if scc.len() > 1 {
131 // Multi-word SCC = mutual recursion
132 true
133 } else if scc.len() == 1 {
134 // Single-word SCC: check if it calls itself
135 let word = scc.iter().next().unwrap();
136 self.edges
137 .get(word)
138 .map(|callees| callees.contains(word))
139 .unwrap_or(false)
140 } else {
141 false
142 }
143 })
144 .collect()
145 }
146
147 /// Tarjan's algorithm recursive visit.
148 #[allow(clippy::too_many_arguments)]
149 fn tarjan_visit(
150 &self,
151 word: &str,
152 index_counter: &mut usize,
153 stack: &mut Vec<String>,
154 on_stack: &mut HashSet<String>,
155 indices: &mut HashMap<String, usize>,
156 lowlinks: &mut HashMap<String, usize>,
157 sccs: &mut Vec<HashSet<String>>,
158 ) {
159 let index = *index_counter;
160 *index_counter += 1;
161 indices.insert(word.to_string(), index);
162 lowlinks.insert(word.to_string(), index);
163 stack.push(word.to_string());
164 on_stack.insert(word.to_string());
165
166 // Visit all callees
167 if let Some(callees) = self.edges.get(word) {
168 for callee in callees {
169 if !self.words.contains(callee) {
170 // External word (builtin), skip
171 continue;
172 }
173 if !indices.contains_key(callee) {
174 // Not yet visited
175 self.tarjan_visit(
176 callee,
177 index_counter,
178 stack,
179 on_stack,
180 indices,
181 lowlinks,
182 sccs,
183 );
184 let callee_lowlink = *lowlinks.get(callee).unwrap();
185 let word_lowlink = lowlinks.get_mut(word).unwrap();
186 *word_lowlink = (*word_lowlink).min(callee_lowlink);
187 } else if on_stack.contains(callee) {
188 // Callee is on stack, part of current SCC
189 let callee_index = *indices.get(callee).unwrap();
190 let word_lowlink = lowlinks.get_mut(word).unwrap();
191 *word_lowlink = (*word_lowlink).min(callee_index);
192 }
193 }
194 }
195
196 // If word is a root node, pop the SCC
197 if lowlinks.get(word) == indices.get(word) {
198 let mut scc = HashSet::new();
199 loop {
200 let w = stack.pop().unwrap();
201 on_stack.remove(&w);
202 scc.insert(w.clone());
203 if w == word {
204 break;
205 }
206 }
207 sccs.push(scc);
208 }
209 }
210}
211
212/// Extract all word calls from a list of statements.
213///
214/// This recursively descends into quotations, if branches, and match arms.
215fn extract_calls(statements: &[Statement], known_words: &HashSet<String>) -> HashSet<String> {
216 let mut calls = HashSet::new();
217
218 for stmt in statements {
219 extract_calls_from_statement(stmt, known_words, &mut calls);
220 }
221
222 calls
223}
224
225/// Extract word calls from a single statement.
226fn extract_calls_from_statement(
227 stmt: &Statement,
228 known_words: &HashSet<String>,
229 calls: &mut HashSet<String>,
230) {
231 match stmt {
232 Statement::WordCall { name, .. } => {
233 // Only track calls to user-defined words
234 if known_words.contains(name) {
235 calls.insert(name.clone());
236 }
237 }
238 Statement::If {
239 then_branch,
240 else_branch,
241 span: _,
242 } => {
243 for s in then_branch {
244 extract_calls_from_statement(s, known_words, calls);
245 }
246 if let Some(else_stmts) = else_branch {
247 for s in else_stmts {
248 extract_calls_from_statement(s, known_words, calls);
249 }
250 }
251 }
252 Statement::Quotation { body, .. } => {
253 for s in body {
254 extract_calls_from_statement(s, known_words, calls);
255 }
256 }
257 Statement::Match { arms, span: _ } => {
258 for arm in arms {
259 for s in &arm.body {
260 extract_calls_from_statement(s, known_words, calls);
261 }
262 }
263 }
264 // Literals don't contain calls
265 Statement::IntLiteral(_)
266 | Statement::FloatLiteral(_)
267 | Statement::BoolLiteral(_)
268 | Statement::StringLiteral(_)
269 | Statement::Symbol(_) => {}
270 }
271}
272
273#[cfg(test)]
274mod tests;