use std::sync::Arc;
use crate::ir::{Node, Program};
use crate::optimizer::{fingerprint_program, vyre_pass, PassAnalysis, PassResult};
use crate::visit::node_map;
/// Optimizer pass that promotes the children of a `Region`'s lone `Block`
/// into the `Region` body itself: `Region { [Block(xs)] }` is rewritten to
/// `Region { xs }`, keeping the region's `generator` and `source_region`.
///
/// Registered through `vyre_pass` under the name
/// `region_promote_singleton_block`, with empty `requires` and `invalidates`
/// lists.
#[derive(Debug, Default)]
#[vyre_pass(
name = "region_promote_singleton_block",
requires = [],
invalidates = []
)]
pub struct RegionPromoteSingletonBlockPass;
impl RegionPromoteSingletonBlockPass {
    /// Reports whether the pass has anything to rewrite.
    ///
    /// Returns `PassAnalysis::RUN` when some entry node has a descendant (as
    /// visited by `node_map::any_descendant`) that is a `Region` whose body is
    /// exactly one `Block`. Otherwise returns `PassAnalysis::SKIP`.
    #[must_use]
    pub fn analyze(program: &Program) -> PassAnalysis {
        let mut is_candidate = is_singleton_block_region;
        let has_candidate = program
            .entry()
            .iter()
            .any(|root| node_map::any_descendant(root, &mut is_candidate));
        if has_candidate {
            PassAnalysis::RUN
        } else {
            PassAnalysis::SKIP
        }
    }

    /// Runs `promote_node` over every entry node, in order, and rebuilds the
    /// program around the rewritten entry.
    ///
    /// `changed` in the result is `true` when `promote_node` promoted at
    /// least one singleton block.
    #[must_use]
    pub fn transform(program: Program) -> PassResult {
        // Build the empty-entry shell first, because `into_entry_vec` consumes `program`.
        let scaffold = program.with_rewritten_entry(Vec::new());
        let mut changed = false;
        let mut rewritten = Vec::new();
        for node in program.into_entry_vec() {
            rewritten.push(promote_node(node, &mut changed));
        }
        PassResult {
            changed,
            program: scaffold.with_rewritten_entry(rewritten),
        }
    }

    /// Fingerprint of `program`, delegated unchanged to `fingerprint_program`.
    #[must_use]
    pub fn fingerprint(program: &Program) -> u64 {
        fingerprint_program(program)
    }
}
/// Rewrites `node` bottom-up, promoting the children of singleton `Block`s
/// into their enclosing `Region`.
///
/// The node's children are rewritten first via `node_map::map_children`, so
/// singleton-block regions nested below `node` are handled in the same call.
/// If the rewritten node is a `Region` whose body is exactly one `Block`, that
/// block is replaced by its children. The step repeats while the body is still
/// a lone `Block`, so `Region { [Block([Block(xs)])] }` becomes `Region { xs }`.
/// As a result, running the pass a second time finds nothing left to promote.
///
/// The region's `generator` and `source_region` are kept unchanged.
/// `*changed` is set to `true` when at least one block is promoted and is
/// never reset to `false`.
fn promote_node(node: Node, changed: &mut bool) -> Node {
    let recursed = node_map::map_children(node, &mut |child| promote_node(child, changed));
    match recursed {
        Node::Region {
            generator,
            source_region,
            body,
        } => {
            if !matches!(body.as_slice(), [Node::Block(_)]) {
                // Nothing to promote: hand back the existing `Arc` instead of
                // cloning a shared body or allocating a fresh one.
                return Node::Region {
                    generator,
                    source_region,
                    body,
                };
            }
            *changed = true;
            // Take ownership of the body, cloning only when the `Arc` is shared.
            let mut body_vec = Arc::try_unwrap(body).unwrap_or_else(|shared| (*shared).clone());
            // Peel directly nested singleton blocks until the body is something
            // else, so the result no longer matches `is_singleton_block_region`.
            while matches!(body_vec.as_slice(), [Node::Block(_)]) {
                body_vec = match body_vec.pop() {
                    Some(Node::Block(inner)) => inner,
                    _ => unreachable!("body was just matched as [Node::Block(_)]"),
                };
            }
            Node::Region {
                generator,
                source_region,
                body: Arc::new(body_vec),
            }
        }
        other => other,
    }
}
/// Predicate used by `analyze`. Returns `true` only for a `Region` whose body
/// holds exactly one node, and that node is a `Block`.
fn is_singleton_block_region(node: &Node) -> bool {
    match node {
        Node::Region { body, .. } => matches!(body.as_slice(), [Node::Block(_)]),
        _ => false,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use super::RegionPromoteSingletonBlockPass as Pass;
    use crate::ir::model::expr::Ident;
    use crate::ir::{BufferAccess, BufferDecl, DataType, Expr, Node};

    /// The single storage buffer every test program declares.
    fn buf() -> BufferDecl {
        BufferDecl::storage("buf", 0, BufferAccess::ReadWrite, DataType::U32).with_count(4)
    }

    /// `buf[index] = value` as a store node.
    fn store(index: u32, value: u32) -> Node {
        Node::store("buf", Expr::u32(index), Expr::u32(value))
    }

    /// A `Region` generated by `test_op`, with no source region, over `body`.
    fn region_with_body(body: Vec<Node>) -> Node {
        Node::Region {
            generator: Ident::from("test_op"),
            source_region: None,
            body: Arc::new(body),
        }
    }

    /// A `Region` whose body is one `Block` containing `block_children`.
    fn singleton_block_region(block_children: Vec<Node>) -> Node {
        region_with_body(vec![Node::Block(block_children)])
    }

    /// A `[1, 1, 1]` program over `buf()` with the given entry nodes.
    fn program_with_entry(entry: Vec<Node>) -> Program {
        Program::wrapped(vec![buf()], [1, 1, 1], entry)
    }

    /// Sum of `count_singleton_block_regions` over a list of nodes.
    fn count_in(nodes: &[Node]) -> usize {
        nodes.iter().map(count_singleton_block_regions).sum()
    }

    /// Counts singleton-block `Region`s at or below `node`, descending through
    /// `Region`, `If`, `Loop`, and `Block` children.
    fn count_singleton_block_regions(node: &Node) -> usize {
        match node {
            Node::Region { body, .. } => {
                usize::from(matches!(body.as_slice(), [Node::Block(_)])) + count_in(body)
            }
            Node::If {
                then, otherwise, ..
            } => count_in(then) + count_in(otherwise),
            Node::Loop { body, .. } | Node::Block(body) => count_in(body),
            _ => 0,
        }
    }

    /// Singleton-block `Region`s left anywhere in `program`'s entry.
    fn remaining_singleton_block_regions(program: &Program) -> usize {
        program
            .entry()
            .iter()
            .map(count_singleton_block_regions)
            .sum()
    }

    /// Transforms `entry` and checks that a nested singleton-block `Region` was
    /// found (`changed`) and that none remain afterwards.
    fn assert_nested_region_unwrapped(entry: Vec<Node>, why: &str) {
        let result = Pass::transform(program_with_entry(entry));
        assert!(result.changed, "{}", why);
        assert_eq!(remaining_singleton_block_regions(&result.program), 0);
    }

    #[test]
    fn skip_analysis_on_program_without_region() {
        let program = program_with_entry(vec![store(0, 7)]);
        assert_eq!(Pass::analyze(&program), PassAnalysis::SKIP);
    }

    #[test]
    fn skip_analysis_on_region_with_multiple_children() {
        let program = program_with_entry(vec![region_with_body(vec![store(0, 7), store(1, 8)])]);
        assert_eq!(Pass::analyze(&program), PassAnalysis::SKIP);
    }

    #[test]
    fn run_analysis_on_singleton_block_region() {
        let program = program_with_entry(vec![singleton_block_region(vec![store(0, 7)])]);
        assert_eq!(Pass::analyze(&program), PassAnalysis::RUN);
    }

    #[test]
    fn transform_unwraps_simple_singleton_block() {
        let program = program_with_entry(vec![singleton_block_region(vec![store(0, 7)])]);
        let result = Pass::transform(program);
        assert!(result.changed);
        assert_eq!(
            remaining_singleton_block_regions(&result.program),
            0,
            "no singleton-block Regions must remain"
        );
    }

    #[test]
    fn transform_preserves_region_op_id() {
        let program = program_with_entry(vec![singleton_block_region(vec![store(0, 7)])]);
        let result = Pass::transform(program);
        let root = &result.program.entry()[0];
        if let Node::Region { generator, .. } = root {
            assert_eq!(generator.as_str(), "test_op", "op id must be preserved");
        } else {
            panic!("expected Region at top of entry; got {:?}", root);
        }
    }

    #[test]
    fn transform_lifts_inner_children_to_region_body() {
        let block_children = vec![store(0, 1), store(1, 2), store(2, 3)];
        let result = Pass::transform(program_with_entry(vec![singleton_block_region(
            block_children,
        )]));
        let body_len = match &result.program.entry()[0] {
            Node::Region { body, .. } => body.len(),
            other => panic!("expected Region; got {other:?}"),
        };
        assert_eq!(
            body_len, 3,
            "the 3 inner Block children must be promoted to Region body"
        );
    }

    #[test]
    fn transform_skips_multi_child_region_unchanged() {
        let program = program_with_entry(vec![region_with_body(vec![store(0, 7), store(1, 8)])]);
        let result = Pass::transform(program);
        assert!(
            !result.changed,
            "multi-child Region must not match the singleton rule"
        );
    }

    #[test]
    fn transform_handles_nested_singleton_block_regions() {
        let inner = singleton_block_region(vec![store(0, 1)]);
        let outer = singleton_block_region(vec![inner]);
        let result = Pass::transform(program_with_entry(vec![outer]));
        assert!(result.changed);
        assert_eq!(
            remaining_singleton_block_regions(&result.program),
            0,
            "every layer of singleton-block Region must be unwrapped, including nested"
        );
    }

    #[test]
    fn transform_is_idempotent() {
        let program = program_with_entry(vec![singleton_block_region(vec![store(0, 7)])]);
        let first = Pass::transform(program);
        let second = Pass::transform(first.program.clone());
        assert!(first.changed);
        assert!(!second.changed, "second run must report no change");
    }

    #[test]
    fn transform_preserves_region_inside_removed_block() {
        let entry = vec![Node::Block(vec![region_with_body(vec![store(0, 7)])])];
        let result = Pass::transform(program_with_entry(entry));
        let first_child_is_region = match &result.program.entry()[0] {
            Node::Region { body, .. } => matches!(body[0], Node::Region { .. }),
            other => panic!("expected root Region; got {other:?}"),
        };
        assert!(
            first_child_is_region,
            "inner Region must still be present after transparent Block removal"
        );
    }

    #[test]
    fn transform_handles_empty_program() {
        let result = Pass::transform(program_with_entry(Vec::new()));
        assert!(!result.changed);
    }

    #[test]
    fn transform_unwraps_region_inside_if_branch() {
        let inner = singleton_block_region(vec![store(0, 1)]);
        let branch = Node::if_then(Expr::lt(Expr::u32(0), Expr::u32(1)), vec![inner]);
        assert_nested_region_unwrapped(
            vec![branch],
            "Regions inside If branches must be processed",
        );
    }

    #[test]
    fn transform_unwraps_region_inside_loop_body() {
        let inner = singleton_block_region(vec![store(0, 1)]);
        let counted_loop = Node::Loop {
            var: Ident::from("i"),
            from: Expr::u32(0),
            to: Expr::u32(4),
            body: vec![inner],
        };
        assert_nested_region_unwrapped(
            vec![counted_loop],
            "Regions inside Loop bodies must be processed",
        );
    }

    #[test]
    fn fingerprint_returns_stable_value() {
        let program = program_with_entry(vec![singleton_block_region(vec![store(0, 7)])]);
        assert_eq!(Pass::fingerprint(&program), Pass::fingerprint(&program));
    }
}