use crate::ir::{Node, Program};
use crate::optimizer::{fingerprint_program, vyre_pass, PassAnalysis, PassResult};
use std::sync::Arc;
/// Optimizer pass that merges adjacent `Node::Region` pairs whose generators
/// form a pair listed in `FUSION_RULES` into a single region carrying the
/// fused generator name.
///
/// Registered through `vyre_pass` under the name `region_fusion_hint`, with
/// empty `requires` and `invalidates` lists.
#[derive(Debug, Default)]
#[vyre_pass(
name = "region_fusion_hint",
requires = [],
invalidates = []
)]
pub struct RegionFusionHintPass;
impl RegionFusionHintPass {
#[must_use]
pub fn analyze(program: &Program) -> PassAnalysis {
if program.entry().iter().any(|n| has_candidate_pair(n)) {
PassAnalysis::RUN
} else {
PassAnalysis::SKIP
}
}
#[must_use]
pub fn transform(program: Program) -> PassResult {
let scaffold = program.with_rewritten_entry(Vec::new());
let mut changed = false;
let entry: Vec<Node> = fuse_in_body(program.into_entry_vec(), &mut changed);
PassResult {
program: scaffold.with_rewritten_entry(entry),
changed,
}
}
#[must_use]
pub fn fingerprint(program: &Program) -> u64 {
fingerprint_program(program)
}
}
/// Fusion table: each entry is `(left generator, right generator, fused generator)`.
///
/// A pair matches only in this order — the left name must belong to the
/// region that comes first and the right name to the region immediately
/// after it (see `lookup_fused`).
const FUSION_RULES: &[(&str, &str, &str)] = &[
(
"vyre-libs::nn::linear",
"vyre-libs::nn::relu",
"vyre-libs::nn::linear_relu",
),
(
"vyre-libs::nn::linear",
"vyre-libs::nn::silu",
"vyre-libs::nn::linear_silu",
),
];
/// Looks up the fused generator name for the ordered pair `(left_gen, right_gen)`.
///
/// Returns the third element of the first `FUSION_RULES` entry whose first two
/// elements equal the arguments in that order, or `None` when no entry matches.
fn lookup_fused(left_gen: &str, right_gen: &str) -> Option<&'static str> {
    for &(first, second, fused) in FUSION_RULES {
        if first == left_gen && second == right_gen {
            return Some(fused);
        }
    }
    None
}
/// Fuses matching adjacent regions in `body`, after first rewriting every
/// child through `recurse`.
///
/// When a `Node::Region` is immediately followed by another region and
/// `lookup_fused` knows the pair, both are replaced by one region that has
/// the fused generator name, the first region's `source_region`, and the two
/// bodies concatenated in order. The second region's `source_region` is
/// discarded. `changed` is set to `true` whenever a fusion happens; it is
/// never reset to `false` here.
fn fuse_in_body(body: Vec<Node>, changed: &mut bool) -> Vec<Node> {
    let rewritten: Vec<Node> = body
        .into_iter()
        .map(|child| recurse(child, changed))
        .collect();
    let mut fused_nodes: Vec<Node> = Vec::with_capacity(rewritten.len());
    let mut pending = rewritten.into_iter().peekable();

    while let Some(current) = pending.next() {
        // Only regions can start a fusable pair; everything else passes through.
        let (generator, source_region, first_body) = match current {
            Node::Region {
                generator,
                source_region,
                body,
            } => (generator, source_region, body),
            other => {
                fused_nodes.push(other);
                continue;
            }
        };

        // Peek at the follower without consuming it, so that a non-matching
        // follower can still begin a pair of its own on the next iteration.
        let fused_name = match pending.peek() {
            Some(Node::Region {
                generator: follower,
                ..
            }) => lookup_fused(generator.as_str(), follower.as_str()),
            _ => None,
        };
        let Some(fused_name) = fused_name else {
            fused_nodes.push(Node::Region {
                generator,
                source_region,
                body: first_body,
            });
            continue;
        };

        let Some(Node::Region {
            body: second_body, ..
        }) = pending.next()
        else {
            unreachable!("peek confirmed Region above");
        };

        // Take each body out of its `Arc` when uniquely owned, cloning otherwise.
        let mut merged: Vec<Node> =
            Arc::try_unwrap(first_body).unwrap_or_else(|shared| (*shared).clone());
        let tail: Vec<Node> =
            Arc::try_unwrap(second_body).unwrap_or_else(|shared| (*shared).clone());
        merged.extend(tail);

        *changed = true;
        fused_nodes.push(Node::Region {
            generator: crate::ir::Ident::from(fused_name),
            source_region,
            body: Arc::new(merged),
        });
    }

    fused_nodes
}
/// Applies `fuse_in_body` to every child list owned by `node` and rebuilds
/// the node around the results.
///
/// `If`, `Loop`, `Block` and `Region` nodes are descended into; any other
/// node is returned unchanged.
fn recurse(node: Node, changed: &mut bool) -> Node {
    match node {
        Node::If {
            cond,
            then,
            otherwise,
        } => {
            let then = fuse_in_body(then, changed);
            let otherwise = fuse_in_body(otherwise, changed);
            Node::If {
                cond,
                then,
                otherwise,
            }
        }
        Node::Loop {
            var,
            from,
            to,
            body,
        } => {
            let body = fuse_in_body(body, changed);
            Node::Loop {
                var,
                from,
                to,
                body,
            }
        }
        Node::Block(children) => Node::Block(fuse_in_body(children, changed)),
        Node::Region {
            generator,
            source_region,
            body,
        } => {
            // Reuse the vector when this is the only handle to it; clone it otherwise.
            let children: Vec<Node> =
                Arc::try_unwrap(body).unwrap_or_else(|shared| (*shared).clone());
            let children = fuse_in_body(children, changed);
            Node::Region {
                generator,
                source_region,
                body: Arc::new(children),
            }
        }
        leaf => leaf,
    }
}
/// Reports whether any child list nested inside `node` contains a fusable pair.
///
/// `Region`, `If`, `Loop` and `Block` nodes are searched recursively; every
/// other node yields `false`. Only lists *inside* `node` are examined — `node`
/// itself is never paired with its own siblings.
fn has_candidate_pair(node: &Node) -> bool {
    match node {
        Node::Region { body, .. } => body_has_candidate_pair(body),
        Node::If {
            then, otherwise, ..
        } => body_has_candidate_pair(then) || body_has_candidate_pair(otherwise),
        Node::Loop { body, .. } | Node::Block(body) => body_has_candidate_pair(body),
        _ => false,
    }
}

/// Returns `true` when `body` holds two adjacent regions known to
/// `lookup_fused`, or when any node in `body` contains such a pair deeper down.
fn body_has_candidate_pair(body: &[Node]) -> bool {
    let adjacent_match = body.windows(2).any(|pair| {
        matches!(
            (&pair[0], &pair[1]),
            (Node::Region { generator: a, .. }, Node::Region { generator: b, .. })
                if lookup_fused(a.as_str(), b.as_str()).is_some()
        )
    });
    adjacent_match || body.iter().any(has_candidate_pair)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{BufferAccess, BufferDecl, DataType, Ident, Node};

    // Generator names shared by the tests below.
    const LINEAR: &str = "vyre-libs::nn::linear";
    const RELU: &str = "vyre-libs::nn::relu";
    const SILU: &str = "vyre-libs::nn::silu";
    const LINEAR_RELU: &str = "vyre-libs::nn::linear_relu";
    const LINEAR_SILU: &str = "vyre-libs::nn::linear_silu";

    /// Single storage buffer declaration used by every test program.
    fn buf() -> BufferDecl {
        let decl = BufferDecl::storage("buf", 0, BufferAccess::ReadWrite, DataType::U32);
        decl.with_count(4)
    }

    /// Builds a region with the given generator name, no source region, and `body`.
    fn region(gen: &str, body: Vec<Node>) -> Node {
        let generator = Ident::from(gen);
        let body = Arc::new(body);
        Node::Region {
            generator,
            source_region: None,
            body,
        }
    }

    /// Wraps `entry` in a program with one buffer and a `[1, 1, 1]` size triple.
    fn program(entry: Vec<Node>) -> Program {
        let buffers = vec![buf()];
        Program::wrapped(buffers, [1, 1, 1], entry)
    }

    /// Runs the pass's `transform` on a program built from `entry`.
    fn run(entry: Vec<Node>) -> PassResult {
        RegionFusionHintPass::transform(program(entry))
    }

    /// Collects the generator name of every region reachable from `nodes`,
    /// parents before their children.
    fn region_generators(nodes: &[Node]) -> Vec<String> {
        fn collect(nodes: &[Node], names: &mut Vec<String>) {
            for node in nodes {
                match node {
                    Node::Region {
                        generator, body, ..
                    } => {
                        names.push(generator.as_str().to_owned());
                        collect(body.as_ref(), names);
                    }
                    Node::If {
                        then, otherwise, ..
                    } => {
                        collect(then, names);
                        collect(otherwise, names);
                    }
                    Node::Loop { body, .. } | Node::Block(body) => collect(body, names),
                    _ => {}
                }
            }
        }

        let mut names = Vec::new();
        collect(nodes, &mut names);
        names
    }

    #[test]
    fn fuses_linear_then_relu() {
        let result = run(vec![
            region(LINEAR, vec![Node::Return]),
            region(RELU, vec![Node::Return]),
        ]);
        assert!(result.changed, "linear+relu must fuse");
        let gens = region_generators(result.program.entry());
        assert!(
            gens.iter().any(|g| g == LINEAR_RELU),
            "generators after fuse: {gens:?}"
        );
    }

    #[test]
    fn fuses_linear_then_silu() {
        let result = run(vec![
            region(LINEAR, vec![Node::Return]),
            region(SILU, vec![Node::Return]),
        ]);
        assert!(result.changed, "linear+silu must fuse");
        let gens = region_generators(result.program.entry());
        assert!(
            gens.iter().any(|g| g == LINEAR_SILU),
            "generators after fuse: {gens:?}"
        );
    }

    #[test]
    fn keeps_when_order_reversed() {
        let result = run(vec![
            region(RELU, vec![Node::Return]),
            region(LINEAR, vec![Node::Return]),
        ]);
        assert!(!result.changed, "wrong order must not fuse");
    }

    #[test]
    fn keeps_when_no_rule_matches() {
        let result = run(vec![
            region("foo::bar", vec![Node::Return]),
            region("baz::qux", vec![Node::Return]),
        ]);
        assert!(!result.changed);
    }

    #[test]
    fn analyze_skips_when_no_candidate() {
        let prog = program(vec![region("foo::bar", vec![Node::Return])]);
        let outcome = RegionFusionHintPass::analyze(&prog);
        if !matches!(outcome, PassAnalysis::SKIP) {
            panic!("expected SKIP, got {outcome:?}");
        }
    }

    #[test]
    fn fuses_inside_wrapping_region() {
        let inner = vec![
            region(LINEAR, vec![Node::Return]),
            region(RELU, vec![Node::Return]),
        ];
        let result = run(vec![region("wrapper", inner)]);
        assert!(result.changed);
        let gens = region_generators(result.program.entry());
        assert!(
            gens.iter().any(|g| g == LINEAR_RELU),
            "generators: {gens:?}"
        );
    }
}