use crate::ir_inner::model::node::Node;
use crate::ir_inner::model::program::Program;
use smallvec::SmallVec;
/// Default threshold that [`run`] forwards to [`run_with_threshold`].
///
/// A `Region` is spliced into its parent when the capped count of nodes in its body
/// (nested children included) is less than or equal to this value.
pub const DEFAULT_INLINE_THRESHOLD: usize = 64;
/// Hands out an empty buffer that can hold at least `min_capacity` nodes without growing.
///
/// The first pooled buffer whose capacity is large enough is removed from `pool` (the last
/// pooled buffer takes over its slot) and reused; when none fits, a fresh buffer with the
/// requested capacity is allocated instead.
fn take_scratch(pool: &mut Vec<Vec<Node>>, min_capacity: usize) -> Vec<Node> {
    let reusable = pool
        .iter()
        .position(|candidate| candidate.capacity() >= min_capacity);
    match reusable {
        Some(slot) => {
            // Order inside the pool is irrelevant, so the O(1) removal is fine.
            let mut buffer = pool.swap_remove(slot);
            buffer.clear();
            buffer
        }
        None => Vec::with_capacity(min_capacity),
    }
}
/// Gives a scratch buffer back to `pool` so a later `take_scratch` call can reuse its
/// allocation.
///
/// The buffer is cleared before it is stored. Buffers without any capacity are dropped
/// instead of pooled: they own no allocation worth recycling (this is what
/// `std::mem::take` leaves behind after a buffer's contents were moved elsewhere), and
/// every such entry would only lengthen the linear search `take_scratch` performs.
fn return_scratch(pool: &mut Vec<Vec<Node>>, mut v: Vec<Node>) {
    if v.capacity() == 0 {
        // Nothing to recycle; dropping an unallocated Vec is free.
        return;
    }
    v.clear();
    pool.push(v);
}
/// Runs the region-inlining pass with [`DEFAULT_INLINE_THRESHOLD`].
///
/// Equivalent to calling [`run_with_threshold`] with that constant; the rewritten
/// program is returned.
#[must_use]
#[inline]
pub fn run(program: Program) -> Program {
    let threshold = DEFAULT_INLINE_THRESHOLD;
    run_with_threshold(program, threshold)
}
/// Runs the region-inlining pass over the program's entry nodes.
///
/// The entry node list handed out by `Program::map_entry` is rewritten by
/// `inline_nodes_into`: a `Region` whose capped node count is at most `threshold` is
/// replaced by its body, larger ones keep their wrapper. The closure's result is handed
/// back to `map_entry` (presumably installed as the new entry — confirm against
/// `Program::map_entry`).
#[must_use]
pub fn run_with_threshold(program: Program, threshold: usize) -> Program {
    program.map_entry(|entry_nodes| {
        // One pool of reusable child buffers serves the whole traversal.
        let mut recycled_buffers = Vec::new();
        let mut rewritten = Vec::with_capacity(entry_nodes.len());
        inline_nodes_into(
            entry_nodes,
            threshold,
            &mut recycled_buffers,
            &mut rewritten,
        );
        rewritten
    })
}
/// Appends `nodes` to `out`, splicing small `Region`s into their parent and rewriting the
/// children of every `Block`, `Loop`, `If`, and retained `Region` the same way.
///
/// A `Region` is replaced by its recursively processed body when `count_nodes_capped`
/// reports at most `threshold` nodes for that body; otherwise the wrapper is kept and only
/// its body is rewritten. Every other node kind is moved to `out` unchanged.
///
/// `scratch_pool` holds reusable child buffers shared by the whole traversal.
fn inline_nodes_into(
    nodes: Vec<Node>,
    threshold: usize,
    scratch_pool: &mut Vec<Vec<Node>>,
    out: &mut Vec<Node>,
) {
    for node in nodes {
        match node {
            Node::Region {
                body,
                generator,
                source_region,
            } => {
                // Count while the body is still borrowed, before it is taken apart.
                let count = count_nodes_capped(&body, threshold);
                // Take the node list by value when this is the only owner of the `Arc`;
                // otherwise fall back to cloning the shared list.
                let body_vec = match std::sync::Arc::try_unwrap(body) {
                    Ok(v) => v,
                    Err(arc) => (*arc).clone(),
                };
                if count <= threshold {
                    // Small enough: the wrapper disappears and the body lands in `out`.
                    inline_nodes_into(body_vec, threshold, scratch_pool, out);
                } else {
                    let new_body = rebuild_children(body_vec, threshold, scratch_pool);
                    out.push(Node::Region {
                        generator,
                        source_region,
                        body: std::sync::Arc::new(new_body),
                    });
                }
            }
            Node::Block(children) => {
                let new_children = rebuild_children(children, threshold, scratch_pool);
                out.push(Node::Block(new_children));
            }
            Node::Loop {
                var,
                from,
                to,
                body,
            } => {
                let body = rebuild_children(body, threshold, scratch_pool);
                out.push(Node::Loop {
                    var,
                    from,
                    to,
                    body,
                });
            }
            Node::If {
                cond,
                then,
                otherwise,
            } => {
                // `then` is processed before `otherwise`; the second call can reuse the
                // scratch buffer the first one just returned.
                let then = rebuild_children(then, threshold, scratch_pool);
                let otherwise = rebuild_children(otherwise, threshold, scratch_pool);
                out.push(Node::If {
                    cond,
                    then,
                    otherwise,
                });
            }
            other => out.push(other),
        }
    }
}

/// Runs `inline_nodes_into` over one child list and returns the rewritten list.
///
/// The list is assembled in a pooled scratch buffer and then moved into an exactly sized
/// vector with `Vec::append`. `append` empties the scratch buffer but leaves its
/// allocation in place, so the buffer returned to `scratch_pool` still has capacity to
/// offer the next caller. Moving the scratch buffer itself into the node would leave only
/// a zero-capacity placeholder to return, which grows the pool without ever satisfying a
/// non-empty request.
fn rebuild_children(
    children: Vec<Node>,
    threshold: usize,
    scratch_pool: &mut Vec<Vec<Node>>,
) -> Vec<Node> {
    let mut scratch = take_scratch(scratch_pool, children.len());
    inline_nodes_into(children, threshold, scratch_pool, &mut scratch);
    let mut rebuilt = Vec::with_capacity(scratch.len());
    rebuilt.append(&mut scratch);
    return_scratch(scratch_pool, scratch);
    rebuilt
}
/// Counts the nodes reachable from `nodes`, giving up once the count exceeds `threshold`.
///
/// Every node contributes one, and the children of `Block`, `Loop`, `If` (both branches),
/// and `Region` nodes are counted too. The walk uses an explicit work list rather than
/// recursion. As soon as the running total reaches `threshold + 1` (saturating) that
/// value is returned, so the result is never larger than `threshold + 1`.
fn count_nodes_capped(nodes: &[Node], threshold: usize) -> usize {
    let limit = threshold.saturating_add(1);
    let mut seen = 0usize;
    let mut pending: SmallVec<[&[Node]; 16]> = SmallVec::new();
    pending.push(nodes);
    while let Some(slice) = pending.pop() {
        for node in slice {
            seen = seen.saturating_add(1);
            if seen >= limit {
                // Past the threshold: the exact total no longer matters.
                return limit;
            }
            match node {
                Node::Region { body, .. } => {
                    pending.push(body);
                }
                Node::If {
                    then, otherwise, ..
                } => {
                    pending.push(otherwise);
                    pending.push(then);
                }
                Node::Loop { body: children, .. } | Node::Block(children) => {
                    pending.push(children);
                }
                _ => {}
            }
        }
    }
    seen
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{BufferDecl, DataType, Expr, Program};
    use std::sync::Arc;

    /// Wraps `body` in a `Region` that carries no source region.
    fn region(generator: &'static str, body: Vec<Node>) -> Node {
        Node::Region {
            generator: generator.into(),
            source_region: None,
            body: Arc::new(body),
        }
    }

    /// Passes `entry` to `Program::wrapped` together with the single read-write `out`
    /// buffer declaration and the `[1, 1, 1]` size used by every test here.
    fn program_with_entry(entry: Vec<Node>) -> super::Program {
        Program::wrapped(
            vec![BufferDecl::read_write("out", 0, DataType::U32)],
            [1, 1, 1],
            entry,
        )
    }

    /// A one-store Region is far below the default threshold and must vanish.
    #[test]
    fn small_region_inlines() {
        let store = Node::store("out", Expr::u32(0), Expr::u32(42));
        let wrapped = region("test", vec![store]);
        let optimized = run(program_with_entry(vec![wrapped]));
        assert!(
            !matches!(&optimized.entry()[0], Node::Region { .. }),
            "small Region must inline"
        );
        assert!(matches!(&optimized.entry()[0], Node::Store { .. }));
    }

    /// One hundred stores exceed a threshold of 64, so the wrapper survives.
    #[test]
    fn large_region_stays_wrapped() {
        let stores: Vec<Node> = (0..100)
            .map(|i| Node::store("out", Expr::u32(i), Expr::u32(i)))
            .collect();
        let wrapped = region("test", stores);
        let optimized = run_with_threshold(program_with_entry(vec![wrapped]), 64);
        assert!(
            matches!(&optimized.entry()[0], Node::Region { .. }),
            "large Region must stay wrapped"
        );
    }

    /// Counting stops at `threshold + 1` no matter how large the body is.
    #[test]
    fn generated_large_region_count_is_capped_at_inline_threshold() {
        let stores: Vec<Node> = (0..4096)
            .map(|i| Node::store("out", Expr::u32(i), Expr::u32(i)))
            .collect();
        assert_eq!(
            count_nodes_capped(&stores, 64),
            65,
            "Fix: region-inline must stop counting once a generated body exceeds the inline threshold."
        );
    }

    /// A small Region nested in another small Region collapses to the bare store.
    #[test]
    fn nested_small_regions_all_inline() {
        let store = Node::store("out", Expr::u32(0), Expr::u32(1));
        let inner = region("inner", vec![store]);
        let outer = region("outer", vec![inner]);
        let optimized = run(program_with_entry(vec![outer]));
        assert_eq!(optimized.entry().len(), 1);
        assert!(matches!(&optimized.entry()[0], Node::Store { .. }));
    }

    /// Regions are also inlined when they sit inside a loop body.
    #[test]
    fn regions_inside_loops_also_inline() {
        let store = Node::store("out", Expr::var("i"), Expr::u32(1));
        let wrapped = region("inner", vec![store]);
        let loop_node = Node::loop_for("i", Expr::u32(0), Expr::u32(4), vec![wrapped]);
        let optimized = run(program_with_entry(vec![loop_node]));
        let Node::Loop { body, .. } = &optimized.entry()[0] else {
            panic!("expected Loop");
        };
        assert_eq!(body.len(), 1);
        assert!(
            matches!(&body[0], Node::Store { .. }),
            "Region inside Loop must inline to just the Store"
        );
    }
}