use std::sync::Arc;
use crate::ir::Node;
/// Rebuilds `node` with `f` applied to each of its immediate children.
///
/// Only one level is visited:
/// - For `If`, `f` runs on every `then` child, then on every `otherwise` child.
/// - For `Loop`, `Block`, and `Region`, it runs on every body child, front to back.
///
/// Each child is replaced by whatever `f` returns. `f` must recurse itself if a
/// deeper rewrite is wanted. Any other node kind is returned as-is, and `f` is
/// never called.
///
/// A `Region` body is moved out of its `Arc` when this is the sole reference
/// and cloned otherwise, so other holders of the `Arc` keep the old contents.
#[must_use]
pub fn map_children<F>(node: Node, f: &mut F) -> Node
where
    F: FnMut(Node) -> Node,
{
    // Applies the callback to every element of one child list, preserving order.
    fn map_each<G>(nodes: Vec<Node>, g: &mut G) -> Vec<Node>
    where
        G: FnMut(Node) -> Node,
    {
        nodes.into_iter().map(g).collect()
    }

    match node {
        Node::If { cond, then, otherwise } => {
            // Separate bindings make the callback order explicit: all of
            // `then` is rewritten before any of `otherwise`.
            let then = map_each(then, f);
            let otherwise = map_each(otherwise, f);
            Node::If { cond, then, otherwise }
        }
        Node::Loop { var, from, to, body } => Node::Loop {
            var,
            from,
            to,
            body: map_each(body, f),
        },
        Node::Block(children) => Node::Block(map_each(children, f)),
        Node::Region { generator, source_region, body } => {
            // Take ownership without copying when possible; otherwise clone so
            // the other `Arc` holders are not affected by the rewrite.
            let owned = Arc::try_unwrap(body).unwrap_or_else(|shared| (*shared).clone());
            Node::Region {
                generator,
                source_region,
                body: Arc::new(map_each(owned, f)),
            }
        }
        leaf => leaf,
    }
}
/// Rebuilds `node` with `f` applied to each of its child lists as a whole.
///
/// `f` receives an entire `Vec<Node>` at once, and whatever it returns becomes
/// the new child list:
/// - For `If`, it is called on `then` first and then on `otherwise`.
/// - For `Loop`, `Block`, and `Region`, it is called once on the body.
///
/// Any other node kind is returned as-is, and `f` is never called.
///
/// A `Region` body is moved out of its `Arc` when this is the sole reference
/// and cloned otherwise, so other holders of the `Arc` keep the old contents.
#[must_use]
pub fn map_body<F>(node: Node, f: &mut F) -> Node
where
    F: FnMut(Vec<Node>) -> Vec<Node>,
{
    match node {
        Node::If { cond, then, otherwise } => {
            // Explicit sequencing: the `then` list is handed to `f` before `otherwise`.
            let new_then = f(then);
            let new_otherwise = f(otherwise);
            Node::If {
                cond,
                then: new_then,
                otherwise: new_otherwise,
            }
        }
        Node::Loop { var, from, to, body } => {
            let body = f(body);
            Node::Loop { var, from, to, body }
        }
        Node::Block(children) => Node::Block(f(children)),
        Node::Region { generator, source_region, body } => {
            // Reuse the allocation when uniquely owned; clone when shared.
            let owned = Arc::try_unwrap(body).unwrap_or_else(|shared| (*shared).clone());
            Node::Region {
                generator,
                source_region,
                body: Arc::new(f(owned)),
            }
        }
        unchanged => unchanged,
    }
}
/// Returns `true` if `pred` holds for `node` itself or for any node nested
/// inside it.
///
/// The walk uses an explicit stack instead of recursion, so very deep trees
/// cannot overflow the call stack. The visit order is still exactly that of a
/// recursive pre-order walk:
/// - the node itself comes first;
/// - for `If`, all `then` children come before the `otherwise` children;
/// - for `Loop`, `Block`, and `Region`, body children are visited front to back.
///
/// Other node kinds are treated as leaves. The walk stops at the first node
/// for which `pred` returns `true`, so `pred` may see only a prefix of the tree.
#[must_use]
pub fn any_descendant<P>(node: &Node, pred: &mut P) -> bool
where
    P: FnMut(&Node) -> bool,
{
    let mut stack: Vec<&Node> = vec![node];
    while let Some(current) = stack.pop() {
        if pred(current) {
            return true;
        }
        // Children are pushed in reverse so that the first child is popped
        // next, which keeps the order identical to recursive pre-order.
        match current {
            Node::If { then, otherwise, .. } => {
                // `otherwise` goes on the stack first so every `then` child
                // (and its subtree) is visited before any `otherwise` child.
                stack.extend(otherwise.iter().rev());
                stack.extend(then.iter().rev());
            }
            Node::Loop { body, .. } | Node::Block(body) => stack.extend(body.iter().rev()),
            Node::Region { body, .. } => stack.extend(body.iter().rev()),
            _ => {}
        }
    }
    false
}
// Unit tests for `map_children` and `any_descendant`, built from real IR nodes.
#[cfg(test)]
mod tests {
use super::*;
use crate::ir::model::expr::Ident;
use crate::ir::Expr;
// Builds a `Node::store` leaf targeting the buffer "buf", with `idx` and
// `value` wrapped as `Expr::u32` literals (argument meaning per `Node::store`).
fn store_at(idx: u32, value: u32) -> Node {
Node::store("buf", Expr::u32(idx), Expr::u32(value))
}
// Note: `map_children` applies the callback to direct children only; the
// "recurses_into" tests below therefore check one level of children each.
#[test]
fn map_children_recurses_into_if_branches() {
let input =
Node::if_then_else(Expr::bool(true), vec![store_at(0, 1)], vec![store_at(0, 2)]);
let mut count = 0;
// Replace every Store child with an empty Block so the rewrite is observable.
let mapped = map_children(input, &mut |n| {
count += 1;
match n {
Node::Store { .. } => Node::Block(Vec::new()),
other => other,
}
});
assert_eq!(count, 2, "callback must fire once per branch's store");
match mapped {
Node::If {
then, otherwise, ..
} => {
assert!(matches!(then[0], Node::Block(_)));
assert!(matches!(otherwise[0], Node::Block(_)));
}
other => panic!("expected Node::If; got {other:?}"),
}
}
#[test]
fn map_children_recurses_into_loop_body() {
let input = Node::Loop {
var: Ident::from("i"),
from: Expr::u32(0),
to: Expr::u32(4),
body: vec![store_at(0, 7), store_at(1, 8)],
};
let mut count = 0;
let _mapped = map_children(input, &mut |n| {
count += 1;
n
});
assert_eq!(count, 2);
}
#[test]
fn map_children_recurses_into_block() {
let input = Node::Block(vec![store_at(0, 1), store_at(1, 2), store_at(2, 3)]);
let mut count = 0;
let _mapped = map_children(input, &mut |n| {
count += 1;
n
});
assert_eq!(count, 3);
}
#[test]
fn map_children_recurses_into_region_body() {
// The `Arc` is uniquely owned here, so `map_children` takes the `try_unwrap` Ok path.
let input = Node::Region {
generator: Ident::from("test_op"),
source_region: None,
body: Arc::new(vec![store_at(0, 1)]),
};
let mut count = 0;
let mapped = map_children(input, &mut |n| {
count += 1;
n
});
assert_eq!(count, 1);
assert!(matches!(mapped, Node::Region { .. }));
}
#[test]
fn map_children_preserves_op_id_through_region_unwrap_clone_path() {
let body = Arc::new(vec![store_at(0, 1)]);
// A second strong reference forces `Arc::try_unwrap` to fail, exercising the clone path.
let _keepalive = Arc::clone(&body);
let input = Node::Region {
generator: Ident::from("test_op_with_clone"),
source_region: None,
body,
};
let mapped = map_children(input, &mut |n| n);
// Only `generator` is checked here; body contents are not asserted.
match mapped {
Node::Region { generator, .. } => {
assert_eq!(generator.as_str(), "test_op_with_clone");
}
other => panic!("expected Region; got {other:?}"),
}
}
#[test]
fn map_children_preserves_loop_metadata() {
let input = Node::Loop {
var: Ident::from("ix"),
from: Expr::u32(2),
to: Expr::u32(9),
body: Vec::new(),
};
let mapped = map_children(input, &mut |n| n);
match mapped {
Node::Loop { var, from, to, .. } => {
assert_eq!(var.as_str(), "ix");
assert!(matches!(from, Expr::LitU32(2)));
assert!(matches!(to, Expr::LitU32(9)));
}
other => panic!("expected Loop; got {other:?}"),
}
}
#[test]
fn map_children_returns_non_container_nodes_unchanged() {
let input = store_at(0, 7);
let mut fired = false;
// If the callback ever ran, `unreachable!` would panic and fail the test
// before the `fired` assertion is reached; that assertion is a secondary guard.
let mapped = map_children(input, &mut |_n| {
fired = true;
unreachable!("non-container nodes must not invoke the callback")
});
assert!(!fired, "no children = no callback invocations");
assert!(matches!(mapped, Node::Store { .. }));
}
#[test]
fn any_descendant_finds_match_at_root() {
let node = store_at(0, 7);
assert!(any_descendant(&node, &mut |n| matches!(
n,
Node::Store { .. }
)));
}
#[test]
fn any_descendant_recurses_into_nested_region() {
// Block -> If(then) -> Region -> Store: the match is three container levels down.
let node = Node::Block(vec![Node::if_then(
Expr::bool(true),
vec![Node::Region {
generator: Ident::from("nested"),
source_region: None,
body: Arc::new(vec![store_at(0, 1)]),
}],
)]);
assert!(any_descendant(&node, &mut |n| matches!(
n,
Node::Store { .. }
)));
}
#[test]
fn any_descendant_returns_false_when_no_match() {
let node = Node::Block(vec![Node::if_then_else(
Expr::bool(true),
vec![Node::Block(Vec::new())],
vec![Node::Block(Vec::new())],
)]);
assert!(!any_descendant(&node, &mut |n| matches!(
n,
Node::Store { .. }
)));
}
#[test]
fn any_descendant_short_circuits() {
let node = Node::Block(vec![
store_at(0, 1),
store_at(1, 2),
store_at(2, 3),
store_at(3, 4),
]);
let mut visited = 0;
// Expected visits: the root Block (no match), then the first Store (match) = 2.
let _ = any_descendant(&node, &mut |n| {
visited += 1;
matches!(n, Node::Store { .. })
});
assert!(
visited <= 2,
"any_descendant must short-circuit; visited {visited}"
);
}
// Wraps a single Store leaf in `depth` nested `Node::if_then` nodes.
fn deep_if_tree(depth: usize) -> Node {
let mut node = store_at(0, 1);
for _ in 0..depth {
node = Node::if_then(Expr::bool(true), vec![node]);
}
node
}
// Straightforward recursive pre-order reference walker, used only to check
// that the iterative `any_descendant` visits nodes in the same order.
fn any_descendant_recursive<P>(node: &Node, pred: &mut P) -> bool
where
P: FnMut(&Node) -> bool,
{
if pred(node) {
return true;
}
match node {
Node::If {
then, otherwise, ..
} => {
then.iter().any(|n| any_descendant_recursive(n, pred))
|| otherwise.iter().any(|n| any_descendant_recursive(n, pred))
}
Node::Loop { body, .. } => body.iter().any(|n| any_descendant_recursive(n, pred)),
Node::Block(body) => body.iter().any(|n| any_descendant_recursive(n, pred)),
Node::Region { body, .. } => body.iter().any(|n| any_descendant_recursive(n, pred)),
_ => false,
}
}
#[test]
fn any_descendant_iterative_no_stack_overflow_on_deep_tree() {
// NOTE(review): dropping `deep` at the end of the test still recurses
// through Node's drop glue (unless Node has a custom iterative Drop, which
// is not visible here), so much larger depths may overflow on drop
// regardless of the walker.
let deep = deep_if_tree(1000);
assert!(
any_descendant(&deep, &mut |n| matches!(n, Node::Store { .. })),
"must find the Store leaf at depth 1000 without stack overflow"
);
}
#[test]
fn any_descendant_iterative_matches_recursive_traversal() {
// Covers every container kind: Block -> If { then: Loop, otherwise: Region }.
let tree = Node::Block(vec![Node::if_then_else(
Expr::bool(true),
vec![Node::Loop {
var: Ident::from("i"),
from: Expr::u32(0),
to: Expr::u32(2),
body: vec![store_at(1, 2)],
}],
vec![Node::Region {
generator: Ident::from("r"),
source_region: None,
body: Arc::new(vec![store_at(3, 4)]),
}],
)]);
// Record node addresses so the comparison checks identity and order, not
// just structural equality. Both predicates return false, forcing a full walk.
let mut recursive_ptrs = Vec::new();
let _ = any_descendant_recursive(&tree, &mut |n| {
recursive_ptrs.push(n as *const Node);
false
});
let mut iterative_ptrs = Vec::new();
let _ = any_descendant(&tree, &mut |n| {
iterative_ptrs.push(n as *const Node);
false
});
assert_eq!(
recursive_ptrs, iterative_ptrs,
"iterative walker must visit the exact same nodes in the exact same pre-order as the recursive reference"
);
}
}