use crate::ast::{Program, Statement};
use std::collections::{HashMap, HashSet};
/// Call graph over the words defined in a [`Program`], together with the
/// cached list of strongly connected components that contain recursion.
#[derive(Debug, Clone)]
pub struct CallGraph {
/// Maps each defined word to the set of defined words its body calls
/// (calls to names that are not defined words are not recorded).
edges: HashMap<String, HashSet<String>>,
/// Names of every word defined in the program.
words: HashSet<String>,
/// Strongly connected components that represent recursion: components with
/// more than one word, or a single word that calls itself.
recursive_sccs: Vec<HashSet<String>>,
}
impl CallGraph {
    /// Builds the call graph for every word defined in `program`.
    ///
    /// Only calls to words defined in `program` become edges. The recursive
    /// strongly connected components are computed once here and cached, so the
    /// query methods below are cheap lookups.
    pub fn build(program: &Program) -> Self {
        let words: HashSet<String> = program.words.iter().map(|w| w.name.clone()).collect();
        // If a name is defined more than once, the last definition's callees win
        // (same as repeated `HashMap::insert`).
        let edges: HashMap<String, HashSet<String>> = program
            .words
            .iter()
            .map(|word| (word.name.clone(), extract_calls(&word.body, &words)))
            .collect();
        let mut graph = CallGraph {
            edges,
            words,
            recursive_sccs: Vec::new(),
        };
        graph.recursive_sccs = graph.find_sccs();
        graph
    }

    /// Returns `true` if `word` is part of a recursive cycle, either by calling
    /// itself directly or through mutual recursion with other words.
    pub fn is_recursive(&self, word: &str) -> bool {
        self.recursive_sccs.iter().any(|scc| scc.contains(word))
    }

    /// Returns `true` if `word1` and `word2` belong to the same recursive cycle.
    ///
    /// Passing the same word twice returns `true` exactly when that word is
    /// recursive.
    pub fn are_mutually_recursive(&self, word1: &str, word2: &str) -> bool {
        self.recursive_sccs
            .iter()
            .any(|scc| scc.contains(word1) && scc.contains(word2))
    }

    /// Returns every recursive strongly connected component.
    ///
    /// For a given program the order of the components is deterministic: it
    /// follows a depth-first traversal that visits words in sorted name order.
    pub fn recursive_cycles(&self) -> &[HashSet<String>] {
        &self.recursive_sccs
    }

    /// Returns the defined words called by `word`, or `None` if `word` is not a
    /// defined word.
    pub fn callees(&self, word: &str) -> Option<&HashSet<String>> {
        self.edges.get(word)
    }

    /// Computes the strongly connected components that represent recursion.
    ///
    /// This is Tarjan's algorithm, but it keeps an explicit work stack instead of
    /// using native recursion. Deep call chains therefore cannot overflow the
    /// thread's stack.
    ///
    /// Words and their callees are visited in sorted order, so the order of the
    /// returned components does not depend on `HashSet` iteration order.
    ///
    /// A component is kept if it has more than one word, or if it is a single
    /// word that calls itself.
    fn find_sccs(&self) -> Vec<HashSet<String>> {
        // Dense ids over the sorted names: id order is name order.
        let mut names: Vec<&str> = self.words.iter().map(String::as_str).collect();
        names.sort_unstable();
        let id_of: HashMap<&str, usize> = names
            .iter()
            .enumerate()
            .map(|(id, &name)| (name, id))
            .collect();

        // Successor ids, restricted to defined words and sorted so the
        // traversal is deterministic (and binary-searchable below).
        let successors: Vec<Vec<usize>> = names
            .iter()
            .map(|&name| {
                let mut ids: Vec<usize> = self
                    .edges
                    .get(name)
                    .into_iter()
                    .flatten()
                    .filter_map(|callee| id_of.get(callee.as_str()).copied())
                    .collect();
                ids.sort_unstable();
                ids
            })
            .collect();

        let count = names.len();
        let mut dfs_index: Vec<Option<usize>> = vec![None; count];
        let mut lowlink: Vec<usize> = vec![0; count];
        let mut on_stack: Vec<bool> = vec![false; count];
        let mut tarjan_stack: Vec<usize> = Vec::new();
        // Simulated call stack: (node, position of the next successor to examine).
        let mut work: Vec<(usize, usize)> = Vec::new();
        let mut next_index = 0;
        let mut components: Vec<Vec<usize>> = Vec::new();

        for root in 0..count {
            if dfs_index[root].is_some() {
                continue;
            }
            work.push((root, 0));

            while let Some(frame) = work.last_mut() {
                let node = frame.0;
                let pos = frame.1;
                frame.1 += 1;

                // A frame is always processed immediately after it is pushed,
                // so position 0 marks the moment the node is first entered.
                if pos == 0 {
                    dfs_index[node] = Some(next_index);
                    lowlink[node] = next_index;
                    next_index += 1;
                    tarjan_stack.push(node);
                    on_stack[node] = true;
                }

                match successors[node].get(pos).copied() {
                    Some(succ) => match dfs_index[succ] {
                        // Tree edge: "recurse" into the unvisited callee.
                        None => work.push((succ, 0)),
                        // Back/cross edge into the current search path.
                        Some(succ_index) if on_stack[succ] => {
                            lowlink[node] = lowlink[node].min(succ_index);
                        }
                        // Edge into an already-completed component.
                        Some(_) => {}
                    },
                    None => {
                        // All successors handled: "return" from this node.
                        work.pop();
                        if dfs_index[node] == Some(lowlink[node]) {
                            let mut component = Vec::new();
                            loop {
                                let member = tarjan_stack
                                    .pop()
                                    .expect("Tarjan invariant: component root is on the stack");
                                on_stack[member] = false;
                                component.push(member);
                                if member == node {
                                    break;
                                }
                            }
                            components.push(component);
                        }
                        // Propagate the lowlink to the caller, as a recursive
                        // implementation would after its call returns.
                        if let Some(&(parent, _)) = work.last() {
                            lowlink[parent] = lowlink[parent].min(lowlink[node]);
                        }
                    }
                }
            }
        }

        components
            .into_iter()
            .filter(|component| match component.as_slice() {
                [single] => successors[*single].binary_search(single).is_ok(),
                members => members.len() > 1,
            })
            .map(|component| {
                component
                    .into_iter()
                    .map(|id| names[id].to_owned())
                    .collect::<HashSet<String>>()
            })
            .collect()
    }
}
/// Collects the set of `known_words` invoked anywhere in `statements`,
/// including inside nested branches, quotations, and match arms.
fn extract_calls(statements: &[Statement], known_words: &HashSet<String>) -> HashSet<String> {
    statements.iter().fold(HashSet::new(), |mut found, stmt| {
        extract_calls_from_statement(stmt, known_words, &mut found);
        found
    })
}
/// Adds to `calls` every word from `known_words` that `stmt` invokes. It looks
/// inside `if`/`else` branches, quotation bodies, and `match` arm bodies.
///
/// Literals and symbols contribute nothing. Traversal uses an explicit worklist
/// rather than recursion. `calls` is a set, so the visiting order has no effect
/// on the result.
fn extract_calls_from_statement(
    stmt: &Statement,
    known_words: &HashSet<String>,
    calls: &mut HashSet<String>,
) {
    let mut pending = vec![stmt];
    while let Some(current) = pending.pop() {
        match current {
            Statement::WordCall { name, .. } if known_words.contains(name) => {
                calls.insert(name.clone());
            }
            // Calls to names that are not defined words are ignored.
            Statement::WordCall { .. } => {}
            Statement::If {
                then_branch,
                else_branch,
                span: _,
            } => {
                pending.extend(then_branch.iter());
                if let Some(else_stmts) = else_branch {
                    pending.extend(else_stmts.iter());
                }
            }
            Statement::Quotation { body, .. } => pending.extend(body.iter()),
            Statement::Match { arms, span: _ } => {
                pending.extend(arms.iter().flat_map(|arm| arm.body.iter()));
            }
            Statement::IntLiteral(_)
            | Statement::FloatLiteral(_)
            | Statement::BoolLiteral(_)
            | Statement::StringLiteral(_)
            | Statement::Symbol(_) => {}
        }
    }
}
#[cfg(test)]
mod tests;