patronus/expr/
traversal.rs

1// Copyright 2024 Cornell University
2// released under BSD 3-Clause License
3// author: Kevin Laeufer <laeufer@cornell.edu>
4
5//! # Expression Traversals
6//!
7//! Contains functions to simplify non-recursive implementations of expression traversals.
8
9use crate::expr::{Context, Expr, ExprRef, ForEachChild};
10
11/// Visits expression nodes bottom up while propagating values
12#[inline]
13pub fn bottom_up<R>(
14    ctx: &Context,
15    expr: ExprRef,
16    f: impl FnMut(&Context, ExprRef, &[R]) -> R,
17) -> R {
18    bottom_up_multi_pat(
19        ctx,
20        expr,
21        |_ctx, expr, children| expr.for_each_child(|c| children.push(*c)),
22        f,
23    )
24}
25
26/// Visits expression nodes bottom up while propagating values.
27/// Can match patterns with multiple nodes that will turn into a single output value.
28#[inline]
29pub fn bottom_up_multi_pat<R>(
30    ctx: &Context,
31    expr: ExprRef,
32    mut get_children: impl FnMut(&Context, &Expr, &mut Vec<ExprRef>),
33    mut f: impl FnMut(&Context, ExprRef, &[R]) -> R,
34) -> R {
35    let mut todo = vec![(expr, false)];
36    let mut stack = Vec::with_capacity(4);
37    let mut child_vec = Vec::with_capacity(4);
38
39    while let Some((e, bottom_up)) = todo.pop() {
40        let expr = &ctx[e];
41
42        // Check if there are children that we need to compute first.
43        if !bottom_up {
44            // check if there are child expressions to evaluate
45            debug_assert!(child_vec.is_empty());
46            get_children(ctx, expr, &mut child_vec);
47            if !child_vec.is_empty() {
48                todo.push((e, true));
49                for c in child_vec.drain(..) {
50                    todo.push((c, false));
51                }
52                continue;
53            }
54        }
55
56        // Otherwise, all arguments are available on the stack for us to use.
57        let num_children = expr.num_children();
58        let values = &stack[stack.len() - num_children..];
59        let result = f(ctx, e, values);
60        stack.truncate(stack.len() - num_children);
61        stack.push(result);
62    }
63
64    debug_assert_eq!(stack.len(), 1);
65    stack.pop().unwrap()
66}
67
/// Returned by the visitor passed to [`top_down`] to decide whether the traversal
/// descends into the children of the node that was just visited.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TraversalCmd {
    /// Do not visit the children of the current node.
    Stop,
    /// Schedule the children of the current node for a visit.
    Continue,
}
73
74/// Visits expression from top to bottom. Halts exploration if a false is returned.
75#[inline]
76pub fn top_down(
77    ctx: &Context,
78    expr: ExprRef,
79    mut f: impl FnMut(&Context, ExprRef) -> TraversalCmd,
80) {
81    let mut todo = vec![expr];
82    while let Some(e) = todo.pop() {
83        let do_continue = f(ctx, e) == TraversalCmd::Continue;
84        if do_continue {
85            ctx[e].for_each_child(|&c| todo.push(c));
86        }
87    }
88}