1use crate::ast::{Program, Statement};
38use std::collections::{HashMap, HashSet};
39
/// Call graph over the words a program defines.
///
/// Only calls whose target is itself a word defined in the program are
/// recorded; anything else (builtins, unknown names) is not an edge.
#[derive(Clone, Debug)]
pub struct CallGraph {
    /// Caller word name -> names of the defined words it calls.
    edges: HashMap<String, HashSet<String>>,
    /// Names of every word defined in the program.
    words: HashSet<String>,
    /// Strongly connected components that contain at least one cycle.
    recursive_sccs: Vec<HashSet<String>>,
}
51
52impl CallGraph {
53 pub fn build(program: &Program) -> Self {
58 let mut edges: HashMap<String, HashSet<String>> = HashMap::new();
59 let words: HashSet<String> = program.words.iter().map(|w| w.name.clone()).collect();
60
61 for word in &program.words {
62 let callees = extract_calls(&word.body, &words);
63 edges.insert(word.name.clone(), callees);
64 }
65
66 let mut graph = CallGraph {
67 edges,
68 words,
69 recursive_sccs: Vec::new(),
70 };
71
72 graph.recursive_sccs = graph.find_sccs();
74
75 graph
76 }
77
78 pub fn is_recursive(&self, word: &str) -> bool {
80 self.recursive_sccs.iter().any(|scc| scc.contains(word))
81 }
82
83 pub fn are_mutually_recursive(&self, word1: &str, word2: &str) -> bool {
85 self.recursive_sccs
86 .iter()
87 .any(|scc| scc.contains(word1) && scc.contains(word2))
88 }
89
90 pub fn get_cycle(&self, word: &str) -> Option<&HashSet<String>> {
92 self.recursive_sccs.iter().find(|scc| scc.contains(word))
93 }
94
95 pub fn recursive_cycles(&self) -> &[HashSet<String>] {
97 &self.recursive_sccs
98 }
99
100 pub fn callees(&self, word: &str) -> Option<&HashSet<String>> {
102 self.edges.get(word)
103 }
104
105 fn find_sccs(&self) -> Vec<HashSet<String>> {
111 let mut index_counter = 0;
112 let mut stack: Vec<String> = Vec::new();
113 let mut on_stack: HashSet<String> = HashSet::new();
114 let mut indices: HashMap<String, usize> = HashMap::new();
115 let mut lowlinks: HashMap<String, usize> = HashMap::new();
116 let mut sccs: Vec<HashSet<String>> = Vec::new();
117
118 for word in &self.words {
119 if !indices.contains_key(word) {
120 self.tarjan_visit(
121 word,
122 &mut index_counter,
123 &mut stack,
124 &mut on_stack,
125 &mut indices,
126 &mut lowlinks,
127 &mut sccs,
128 );
129 }
130 }
131
132 sccs.into_iter()
134 .filter(|scc| {
135 if scc.len() > 1 {
136 true
138 } else if scc.len() == 1 {
139 let word = scc.iter().next().unwrap();
141 self.edges
142 .get(word)
143 .map(|callees| callees.contains(word))
144 .unwrap_or(false)
145 } else {
146 false
147 }
148 })
149 .collect()
150 }
151
152 #[allow(clippy::too_many_arguments)]
154 fn tarjan_visit(
155 &self,
156 word: &str,
157 index_counter: &mut usize,
158 stack: &mut Vec<String>,
159 on_stack: &mut HashSet<String>,
160 indices: &mut HashMap<String, usize>,
161 lowlinks: &mut HashMap<String, usize>,
162 sccs: &mut Vec<HashSet<String>>,
163 ) {
164 let index = *index_counter;
165 *index_counter += 1;
166 indices.insert(word.to_string(), index);
167 lowlinks.insert(word.to_string(), index);
168 stack.push(word.to_string());
169 on_stack.insert(word.to_string());
170
171 if let Some(callees) = self.edges.get(word) {
173 for callee in callees {
174 if !self.words.contains(callee) {
175 continue;
177 }
178 if !indices.contains_key(callee) {
179 self.tarjan_visit(
181 callee,
182 index_counter,
183 stack,
184 on_stack,
185 indices,
186 lowlinks,
187 sccs,
188 );
189 let callee_lowlink = *lowlinks.get(callee).unwrap();
190 let word_lowlink = lowlinks.get_mut(word).unwrap();
191 *word_lowlink = (*word_lowlink).min(callee_lowlink);
192 } else if on_stack.contains(callee) {
193 let callee_index = *indices.get(callee).unwrap();
195 let word_lowlink = lowlinks.get_mut(word).unwrap();
196 *word_lowlink = (*word_lowlink).min(callee_index);
197 }
198 }
199 }
200
201 if lowlinks.get(word) == indices.get(word) {
203 let mut scc = HashSet::new();
204 loop {
205 let w = stack.pop().unwrap();
206 on_stack.remove(&w);
207 scc.insert(w.clone());
208 if w == word {
209 break;
210 }
211 }
212 sccs.push(scc);
213 }
214 }
215}
216
217fn extract_calls(statements: &[Statement], known_words: &HashSet<String>) -> HashSet<String> {
221 let mut calls = HashSet::new();
222
223 for stmt in statements {
224 extract_calls_from_statement(stmt, known_words, &mut calls);
225 }
226
227 calls
228}
229
230fn extract_calls_from_statement(
232 stmt: &Statement,
233 known_words: &HashSet<String>,
234 calls: &mut HashSet<String>,
235) {
236 match stmt {
237 Statement::WordCall { name, .. } => {
238 if known_words.contains(name) {
240 calls.insert(name.clone());
241 }
242 }
243 Statement::If {
244 then_branch,
245 else_branch,
246 } => {
247 for s in then_branch {
248 extract_calls_from_statement(s, known_words, calls);
249 }
250 if let Some(else_stmts) = else_branch {
251 for s in else_stmts {
252 extract_calls_from_statement(s, known_words, calls);
253 }
254 }
255 }
256 Statement::Quotation { body, .. } => {
257 for s in body {
258 extract_calls_from_statement(s, known_words, calls);
259 }
260 }
261 Statement::Match { arms } => {
262 for arm in arms {
263 for s in &arm.body {
264 extract_calls_from_statement(s, known_words, calls);
265 }
266 }
267 }
268 Statement::IntLiteral(_)
270 | Statement::FloatLiteral(_)
271 | Statement::BoolLiteral(_)
272 | Statement::StringLiteral(_)
273 | Statement::Symbol(_) => {}
274 }
275}
276
/// The set of words that take part in some recursive cycle.
///
/// `TailCallInfo::should_use_musttail` consults this set to classify calls.
#[derive(Clone, Debug)]
// NOTE(review): presumably allowed because nothing outside tests reads this
// type yet; confirm before removing the allowance.
#[allow(dead_code)]
pub struct TailCallInfo {
    /// Every word that belongs to at least one recursive cycle.
    pub recursive_words: HashSet<String>,
}
295
296impl TailCallInfo {
297 pub fn from_call_graph(graph: &CallGraph) -> Self {
299 let mut recursive_words = HashSet::new();
300 for scc in graph.recursive_cycles() {
301 recursive_words.extend(scc.iter().cloned());
302 }
303 TailCallInfo { recursive_words }
304 }
305
306 #[allow(dead_code)] pub fn should_use_musttail(&self, caller: &str, callee: &str) -> bool {
316 self.recursive_words.contains(caller) && self.recursive_words.contains(callee)
317 }
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast::WordDef;

    /// Wraps `words` in a program with no includes or unions.
    fn program(words: Vec<WordDef>) -> Program {
        Program {
            includes: vec![],
            unions: vec![],
            words,
        }
    }

    /// A call to `name` with no source span.
    fn call(name: &str) -> Statement {
        Statement::WordCall {
            name: name.to_string(),
            span: None,
        }
    }

    /// A word with an explicit body and no effect, source, or lint allowances.
    fn word_with_body(name: &str, body: Vec<Statement>) -> WordDef {
        WordDef {
            name: name.to_string(),
            effect: None,
            body,
            source: None,
            allowed_lints: vec![],
        }
    }

    /// A word whose body is just a sequence of calls.
    fn make_word(name: &str, calls: Vec<&str>) -> WordDef {
        word_with_body(name, calls.into_iter().map(call).collect())
    }

    #[test]
    fn test_no_recursion() {
        let graph = CallGraph::build(&program(vec![
            make_word("foo", vec!["bar"]),
            make_word("bar", vec![]),
            make_word("baz", vec!["foo"]),
        ]));
        assert!(!graph.is_recursive("foo"));
        assert!(!graph.is_recursive("bar"));
        assert!(!graph.is_recursive("baz"));
        assert!(graph.recursive_cycles().is_empty());
    }

    #[test]
    fn test_direct_recursion() {
        let graph = CallGraph::build(&program(vec![
            make_word("countdown", vec!["countdown"]),
            make_word("helper", vec![]),
        ]));
        assert!(graph.is_recursive("countdown"));
        assert!(!graph.is_recursive("helper"));
        assert_eq!(graph.recursive_cycles().len(), 1);
    }

    #[test]
    fn test_mutual_recursion_pair() {
        let graph = CallGraph::build(&program(vec![
            make_word("ping", vec!["pong"]),
            make_word("pong", vec!["ping"]),
        ]));
        assert!(graph.is_recursive("ping"));
        assert!(graph.is_recursive("pong"));
        assert!(graph.are_mutually_recursive("ping", "pong"));
        assert_eq!(graph.recursive_cycles().len(), 1);
        assert_eq!(graph.recursive_cycles()[0].len(), 2);
    }

    #[test]
    fn test_mutual_recursion_triple() {
        let graph = CallGraph::build(&program(vec![
            make_word("a", vec!["b"]),
            make_word("b", vec!["c"]),
            make_word("c", vec!["a"]),
        ]));
        assert!(graph.is_recursive("a"));
        assert!(graph.is_recursive("b"));
        assert!(graph.is_recursive("c"));
        assert!(graph.are_mutually_recursive("a", "b"));
        assert!(graph.are_mutually_recursive("b", "c"));
        assert!(graph.are_mutually_recursive("a", "c"));
        assert_eq!(graph.recursive_cycles().len(), 1);
        assert_eq!(graph.recursive_cycles()[0].len(), 3);
    }

    #[test]
    fn test_multiple_independent_cycles() {
        let graph = CallGraph::build(&program(vec![
            make_word("ping", vec!["pong"]),
            make_word("pong", vec!["ping"]),
            make_word("even", vec!["odd"]),
            make_word("odd", vec!["even"]),
            make_word("main", vec!["ping", "even"]),
        ]));
        assert!(graph.is_recursive("ping"));
        assert!(graph.is_recursive("pong"));
        assert!(graph.is_recursive("even"));
        assert!(graph.is_recursive("odd"));
        assert!(!graph.is_recursive("main"));

        assert!(graph.are_mutually_recursive("ping", "pong"));
        assert!(graph.are_mutually_recursive("even", "odd"));
        assert!(!graph.are_mutually_recursive("ping", "even"));

        assert_eq!(graph.recursive_cycles().len(), 2);
    }

    #[test]
    fn test_calls_to_unknown_words() {
        let graph = CallGraph::build(&program(vec![make_word(
            "foo",
            vec!["dup", "drop", "unknown_builtin"],
        )]));
        assert!(!graph.is_recursive("foo"));
        assert!(graph.callees("foo").unwrap().is_empty());
    }

    #[test]
    fn test_callees_of_undefined_word() {
        let graph = CallGraph::build(&program(vec![make_word("foo", vec!["foo"])]));
        assert!(graph.callees("missing").is_none());
        assert!(!graph.is_recursive("missing"));
        assert!(graph.callees("foo").unwrap().contains("foo"));
    }

    #[test]
    fn test_get_cycle() {
        let graph = CallGraph::build(&program(vec![
            make_word("ping", vec!["pong"]),
            make_word("pong", vec!["ping"]),
            make_word("main", vec!["ping"]),
        ]));
        let cycle = graph.get_cycle("pong").expect("pong is part of a cycle");
        assert_eq!(cycle.len(), 2);
        assert!(cycle.contains("ping"));
        assert!(!cycle.contains("main"));
        assert!(graph.get_cycle("main").is_none());
    }

    #[test]
    fn test_tail_call_info() {
        let graph = CallGraph::build(&program(vec![
            make_word("ping", vec!["pong"]),
            make_word("pong", vec!["ping"]),
            make_word("helper", vec![]),
        ]));
        let info = TailCallInfo::from_call_graph(&graph);

        assert!(info.should_use_musttail("ping", "pong"));
        assert!(info.should_use_musttail("pong", "ping"));
        assert!(!info.should_use_musttail("helper", "ping"));
        assert!(!info.should_use_musttail("ping", "helper"));
    }

    #[test]
    fn test_cycle_with_builtins_interspersed() {
        let graph = CallGraph::build(&program(vec![
            make_word("foo", vec!["dup", "drop", "bar"]),
            make_word("bar", vec!["swap", "foo"]),
        ]));
        assert!(graph.is_recursive("foo"));
        assert!(graph.is_recursive("bar"));
        assert!(graph.are_mutually_recursive("foo", "bar"));

        let foo_callees = graph.callees("foo").unwrap();
        assert!(foo_callees.contains("bar"));
        assert!(!foo_callees.contains("dup"));
        assert!(!foo_callees.contains("drop"));
    }

    #[test]
    fn test_cycle_through_quotation() {
        let foo = word_with_body(
            "foo",
            vec![
                Statement::Quotation {
                    id: 0,
                    body: vec![call("bar")],
                    span: None,
                },
                call("call"),
            ],
        );
        let graph = CallGraph::build(&program(vec![foo, make_word("bar", vec!["foo"])]));
        assert!(graph.is_recursive("foo"));
        assert!(graph.is_recursive("bar"));
        assert!(graph.are_mutually_recursive("foo", "bar"));
    }

    #[test]
    fn test_cycle_through_if_branch() {
        let even = word_with_body(
            "even",
            vec![Statement::If {
                then_branch: vec![],
                else_branch: Some(vec![call("odd")]),
            }],
        );
        let odd = word_with_body(
            "odd",
            vec![Statement::If {
                then_branch: vec![],
                else_branch: Some(vec![call("even")]),
            }],
        );
        let graph = CallGraph::build(&program(vec![even, odd]));
        assert!(graph.is_recursive("even"));
        assert!(graph.is_recursive("odd"));
        assert!(graph.are_mutually_recursive("even", "odd"));
    }

    #[test]
    fn test_call_nested_in_quotation_inside_if() {
        let looping = word_with_body(
            "loop",
            vec![Statement::If {
                then_branch: vec![Statement::Quotation {
                    id: 0,
                    body: vec![call("loop")],
                    span: None,
                }],
                else_branch: None,
            }],
        );
        let graph = CallGraph::build(&program(vec![looping]));
        assert!(graph.is_recursive("loop"));
        assert_eq!(graph.recursive_cycles().len(), 1);
    }

    #[test]
    fn test_long_cycle() {
        let names: Vec<String> = (0..200).map(|i| format!("w{i}")).collect();
        // Each word calls the next; the last wraps around to the first.
        let words: Vec<WordDef> = names
            .iter()
            .zip(names.iter().cycle().skip(1))
            .map(|(name, next)| make_word(name, vec![next.as_str()]))
            .collect();

        let graph = CallGraph::build(&program(words));
        assert_eq!(graph.recursive_cycles().len(), 1);
        assert_eq!(graph.recursive_cycles()[0].len(), names.len());
        assert!(graph.are_mutually_recursive("w0", "w199"));
    }
}