Skip to main content

seqc/
call_graph.rs

//! Call graph analysis for detecting mutual recursion
//!
//! This module builds a call graph from a Seq program and detects
//! strongly connected components (SCCs) to identify mutual recursion cycles.
//!
//! # Usage
//!
//! ```ignore
//! let call_graph = CallGraph::build(&program);
//! let cycles = call_graph.recursive_cycles();
//! ```
//!
//! # Primary Use Cases
//!
//! 1. **Type checker divergence detection**: The type checker uses the call graph
//!    to identify mutually recursive tail calls, enabling correct type inference
//!    for patterns like even/odd that would otherwise require branch unification.
//!
//! 2. **Future optimizations**: The call graph infrastructure can support dead code
//!    detection, inlining decisions, and diagnostic tools.
//!
//! # Implementation Details
//!
//! - **Algorithm**: Tarjan's SCC algorithm, O(V + E) time complexity
//! - **Builtins**: Calls to builtins/external words are excluded from the graph
//!   (they don't affect recursion detection since they always return)
//! - **Quotations**: Calls within quotations are included in the analysis
//! - **Match arms**: Calls within match arms are included in the analysis
//!
//! # Note on Tail Call Optimization
//!
//! The existing codegen already emits `musttail` for all tail calls to user-defined
//! words (see `codegen/statements.rs`). This means mutual TCO works automatically
//! without needing explicit call graph checks in codegen. The call graph is primarily
//! used for type checking, not for enabling TCO.
36
37use crate::ast::{Program, Statement};
38use std::collections::{HashMap, HashSet};
39
/// Directed graph of word-to-word calls for a single program.
///
/// Nodes are the names of the words defined in the program; an edge
/// `a -> b` means the body of `a` contains a call to `b`.
#[derive(Debug, Clone)]
pub struct CallGraph {
    /// Outgoing edges: caller name -> names of the words it calls.
    edges: HashMap<String, HashSet<String>>,
    /// Names of every word defined in the program (the node set).
    words: HashSet<String>,
    /// Cached recursion groups: strongly connected components that either
    /// hold several words (mutual recursion) or hold one word that calls
    /// itself (direct recursion).
    recursive_sccs: Vec<HashSet<String>>,
}
51
52impl CallGraph {
53    /// Build a call graph from a program.
54    ///
55    /// This extracts all word-to-word call relationships, including calls
56    /// within quotations, if branches, and match arms.
57    pub fn build(program: &Program) -> Self {
58        let mut edges: HashMap<String, HashSet<String>> = HashMap::new();
59        let words: HashSet<String> = program.words.iter().map(|w| w.name.clone()).collect();
60
61        for word in &program.words {
62            let callees = extract_calls(&word.body, &words);
63            edges.insert(word.name.clone(), callees);
64        }
65
66        let mut graph = CallGraph {
67            edges,
68            words,
69            recursive_sccs: Vec::new(),
70        };
71
72        // Compute SCCs and identify recursive cycles
73        graph.recursive_sccs = graph.find_sccs();
74
75        graph
76    }
77
78    /// Check if a word is part of any recursive cycle (direct or mutual).
79    pub fn is_recursive(&self, word: &str) -> bool {
80        self.recursive_sccs.iter().any(|scc| scc.contains(word))
81    }
82
83    /// Check if two words are in the same recursive cycle (mutually recursive).
84    pub fn are_mutually_recursive(&self, word1: &str, word2: &str) -> bool {
85        self.recursive_sccs
86            .iter()
87            .any(|scc| scc.contains(word1) && scc.contains(word2))
88    }
89
90    /// Get all recursive cycles (SCCs with recursion).
91    pub fn recursive_cycles(&self) -> &[HashSet<String>] {
92        &self.recursive_sccs
93    }
94
95    /// Get the words that a given word calls.
96    pub fn callees(&self, word: &str) -> Option<&HashSet<String>> {
97        self.edges.get(word)
98    }
99
100    /// Find strongly connected components using Tarjan's algorithm.
101    ///
102    /// Returns only SCCs that represent recursion:
103    /// - Multi-word SCCs (mutual recursion)
104    /// - Single-word SCCs where the word calls itself (direct recursion)
105    fn find_sccs(&self) -> Vec<HashSet<String>> {
106        let mut index_counter = 0;
107        let mut stack: Vec<String> = Vec::new();
108        let mut on_stack: HashSet<String> = HashSet::new();
109        let mut indices: HashMap<String, usize> = HashMap::new();
110        let mut lowlinks: HashMap<String, usize> = HashMap::new();
111        let mut sccs: Vec<HashSet<String>> = Vec::new();
112
113        for word in &self.words {
114            if !indices.contains_key(word) {
115                self.tarjan_visit(
116                    word,
117                    &mut index_counter,
118                    &mut stack,
119                    &mut on_stack,
120                    &mut indices,
121                    &mut lowlinks,
122                    &mut sccs,
123                );
124            }
125        }
126
127        // Filter to only recursive SCCs
128        sccs.into_iter()
129            .filter(|scc| {
130                if scc.len() > 1 {
131                    // Multi-word SCC = mutual recursion
132                    true
133                } else if scc.len() == 1 {
134                    // Single-word SCC: check if it calls itself
135                    let word = scc.iter().next().unwrap();
136                    self.edges
137                        .get(word)
138                        .map(|callees| callees.contains(word))
139                        .unwrap_or(false)
140                } else {
141                    false
142                }
143            })
144            .collect()
145    }
146
147    /// Tarjan's algorithm recursive visit.
148    #[allow(clippy::too_many_arguments)]
149    fn tarjan_visit(
150        &self,
151        word: &str,
152        index_counter: &mut usize,
153        stack: &mut Vec<String>,
154        on_stack: &mut HashSet<String>,
155        indices: &mut HashMap<String, usize>,
156        lowlinks: &mut HashMap<String, usize>,
157        sccs: &mut Vec<HashSet<String>>,
158    ) {
159        let index = *index_counter;
160        *index_counter += 1;
161        indices.insert(word.to_string(), index);
162        lowlinks.insert(word.to_string(), index);
163        stack.push(word.to_string());
164        on_stack.insert(word.to_string());
165
166        // Visit all callees
167        if let Some(callees) = self.edges.get(word) {
168            for callee in callees {
169                if !self.words.contains(callee) {
170                    // External word (builtin), skip
171                    continue;
172                }
173                if !indices.contains_key(callee) {
174                    // Not yet visited
175                    self.tarjan_visit(
176                        callee,
177                        index_counter,
178                        stack,
179                        on_stack,
180                        indices,
181                        lowlinks,
182                        sccs,
183                    );
184                    let callee_lowlink = *lowlinks.get(callee).unwrap();
185                    let word_lowlink = lowlinks.get_mut(word).unwrap();
186                    *word_lowlink = (*word_lowlink).min(callee_lowlink);
187                } else if on_stack.contains(callee) {
188                    // Callee is on stack, part of current SCC
189                    let callee_index = *indices.get(callee).unwrap();
190                    let word_lowlink = lowlinks.get_mut(word).unwrap();
191                    *word_lowlink = (*word_lowlink).min(callee_index);
192                }
193            }
194        }
195
196        // If word is a root node, pop the SCC
197        if lowlinks.get(word) == indices.get(word) {
198            let mut scc = HashSet::new();
199            loop {
200                let w = stack.pop().unwrap();
201                on_stack.remove(&w);
202                scc.insert(w.clone());
203                if w == word {
204                    break;
205                }
206            }
207            sccs.push(scc);
208        }
209    }
210}
211
212/// Extract all word calls from a list of statements.
213///
214/// This recursively descends into quotations, if branches, and match arms.
215fn extract_calls(statements: &[Statement], known_words: &HashSet<String>) -> HashSet<String> {
216    let mut calls = HashSet::new();
217
218    for stmt in statements {
219        extract_calls_from_statement(stmt, known_words, &mut calls);
220    }
221
222    calls
223}
224
225/// Extract word calls from a single statement.
226fn extract_calls_from_statement(
227    stmt: &Statement,
228    known_words: &HashSet<String>,
229    calls: &mut HashSet<String>,
230) {
231    match stmt {
232        Statement::WordCall { name, .. } => {
233            // Only track calls to user-defined words
234            if known_words.contains(name) {
235                calls.insert(name.clone());
236            }
237        }
238        Statement::If {
239            then_branch,
240            else_branch,
241            span: _,
242        } => {
243            for s in then_branch {
244                extract_calls_from_statement(s, known_words, calls);
245            }
246            if let Some(else_stmts) = else_branch {
247                for s in else_stmts {
248                    extract_calls_from_statement(s, known_words, calls);
249                }
250            }
251        }
252        Statement::Quotation { body, .. } => {
253            for s in body {
254                extract_calls_from_statement(s, known_words, calls);
255            }
256        }
257        Statement::Match { arms, span: _ } => {
258            for arm in arms {
259                for s in &arm.body {
260                    extract_calls_from_statement(s, known_words, calls);
261                }
262            }
263        }
264        // Literals don't contain calls
265        Statement::IntLiteral(_)
266        | Statement::FloatLiteral(_)
267        | Statement::BoolLiteral(_)
268        | Statement::StringLiteral(_)
269        | Statement::Symbol(_) => {}
270    }
271}
272
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::WordDef;

    /// A call to `name` without span information.
    fn call(name: &str) -> Statement {
        Statement::WordCall {
            name: name.to_string(),
            span: None,
        }
    }

    /// A word definition with the given body and no effect, source, or lints.
    fn word_with_body(name: &str, body: Vec<Statement>) -> WordDef {
        WordDef {
            name: name.to_string(),
            effect: None,
            body,
            source: None,
            allowed_lints: vec![],
        }
    }

    /// A word whose body is a flat sequence of calls.
    fn make_word(name: &str, calls: Vec<&str>) -> WordDef {
        word_with_body(name, calls.into_iter().map(call).collect())
    }

    /// Build the call graph of a program consisting only of `words`.
    fn graph_of(words: Vec<WordDef>) -> CallGraph {
        let program = Program {
            includes: vec![],
            unions: vec![],
            words,
        };
        CallGraph::build(&program)
    }

    #[test]
    fn test_no_recursion() {
        let graph = graph_of(vec![
            make_word("foo", vec!["bar"]),
            make_word("bar", vec![]),
            make_word("baz", vec!["foo"]),
        ]);

        for name in ["foo", "bar", "baz"] {
            assert!(!graph.is_recursive(name));
        }
        assert!(graph.recursive_cycles().is_empty());
    }

    #[test]
    fn test_direct_recursion() {
        let graph = graph_of(vec![
            make_word("countdown", vec!["countdown"]),
            make_word("helper", vec![]),
        ]);

        assert!(graph.is_recursive("countdown"));
        assert!(!graph.is_recursive("helper"));
        assert_eq!(graph.recursive_cycles().len(), 1);
    }

    #[test]
    fn test_mutual_recursion_pair() {
        let graph = graph_of(vec![
            make_word("ping", vec!["pong"]),
            make_word("pong", vec!["ping"]),
        ]);

        assert!(graph.is_recursive("ping"));
        assert!(graph.is_recursive("pong"));
        assert!(graph.are_mutually_recursive("ping", "pong"));
        assert_eq!(graph.recursive_cycles().len(), 1);
        assert_eq!(graph.recursive_cycles()[0].len(), 2);
    }

    #[test]
    fn test_mutual_recursion_triple() {
        let graph = graph_of(vec![
            make_word("a", vec!["b"]),
            make_word("b", vec!["c"]),
            make_word("c", vec!["a"]),
        ]);

        assert!(graph.is_recursive("a"));
        assert!(graph.is_recursive("b"));
        assert!(graph.is_recursive("c"));
        assert!(graph.are_mutually_recursive("a", "b"));
        assert!(graph.are_mutually_recursive("b", "c"));
        assert!(graph.are_mutually_recursive("a", "c"));
        assert_eq!(graph.recursive_cycles().len(), 1);
        assert_eq!(graph.recursive_cycles()[0].len(), 3);
    }

    #[test]
    fn test_multiple_independent_cycles() {
        let graph = graph_of(vec![
            // Cycle 1: ping <-> pong
            make_word("ping", vec!["pong"]),
            make_word("pong", vec!["ping"]),
            // Cycle 2: even <-> odd
            make_word("even", vec!["odd"]),
            make_word("odd", vec!["even"]),
            // Non-recursive
            make_word("main", vec!["ping", "even"]),
        ]);

        assert!(graph.is_recursive("ping"));
        assert!(graph.is_recursive("pong"));
        assert!(graph.is_recursive("even"));
        assert!(graph.is_recursive("odd"));
        assert!(!graph.is_recursive("main"));

        assert!(graph.are_mutually_recursive("ping", "pong"));
        assert!(graph.are_mutually_recursive("even", "odd"));
        assert!(!graph.are_mutually_recursive("ping", "even"));

        assert_eq!(graph.recursive_cycles().len(), 2);
    }

    #[test]
    fn test_calls_to_unknown_words() {
        // Calls to builtins or external words should be ignored
        let graph = graph_of(vec![make_word(
            "foo",
            vec!["dup", "drop", "unknown_builtin"],
        )]);

        assert!(!graph.is_recursive("foo"));
        // Callees should only include known words
        assert!(graph.callees("foo").unwrap().is_empty());
    }

    #[test]
    fn test_cycle_with_builtins_interspersed() {
        // Cycles should be detected even when builtins are called between user words
        // e.g., : foo dup drop bar ;  : bar swap foo ;
        let graph = graph_of(vec![
            make_word("foo", vec!["dup", "drop", "bar"]),
            make_word("bar", vec!["swap", "foo"]),
        ]);

        // foo and bar should still form a cycle despite builtin calls
        assert!(graph.is_recursive("foo"));
        assert!(graph.is_recursive("bar"));
        assert!(graph.are_mutually_recursive("foo", "bar"));

        // Builtins should not appear in callees
        let foo_callees = graph.callees("foo").unwrap();
        assert!(foo_callees.contains("bar"));
        assert!(!foo_callees.contains("dup"));
        assert!(!foo_callees.contains("drop"));
    }

    #[test]
    fn test_cycle_through_quotation() {
        // Calls inside quotations should be detected
        // e.g., : foo [ bar ] call ;  : bar foo ;
        let foo = word_with_body(
            "foo",
            vec![
                Statement::Quotation {
                    id: 0,
                    body: vec![call("bar")],
                    span: None,
                },
                call("call"),
            ],
        );

        let graph = graph_of(vec![foo, make_word("bar", vec!["foo"])]);

        // foo calls bar (inside quotation), bar calls foo
        assert!(graph.is_recursive("foo"));
        assert!(graph.is_recursive("bar"));
        assert!(graph.are_mutually_recursive("foo", "bar"));
    }

    #[test]
    fn test_cycle_through_if_branch() {
        // Calls inside if branches should be detected
        let calls_in_else = |name: &str, target: &str| {
            word_with_body(
                name,
                vec![Statement::If {
                    then_branch: vec![],
                    else_branch: Some(vec![call(target)]),
                    span: None,
                }],
            )
        };

        let graph = graph_of(vec![
            calls_in_else("even", "odd"),
            calls_in_else("odd", "even"),
        ]);

        assert!(graph.is_recursive("even"));
        assert!(graph.is_recursive("odd"));
        assert!(graph.are_mutually_recursive("even", "odd"));
    }
}