use crate::ir_inner::model::node::Node;
use crate::ir_inner::model::program::Program;
use rustc_hash::FxHashMap;
use std::cell::RefCell;
/// Default node-count threshold used by [`run`]: a region whose body counts at
/// most this many nodes is inlined into its parent.
pub const DEFAULT_INLINE_THRESHOLD: usize = 64;
// Per-thread pool of spare `Vec<Node>` buffers. `take_scratch` removes a buffer
// from the pool and `return_scratch` pushes one back. The pool starts out empty.
thread_local! {
static SCRATCH_POOL: RefCell<Vec<Vec<Node>>> = const { RefCell::new(Vec::new()) };
}
/// Returns an empty `Vec<Node>` whose capacity is at least `min_capacity`.
///
/// The first pooled buffer that is large enough is reused; when none qualifies
/// a fresh buffer is allocated instead.
///
/// # Panics
///
/// Panics if the thread-local pool is already borrowed when this is called.
fn take_scratch(min_capacity: usize) -> Vec<Node> {
    SCRATCH_POOL.with(|cell| {
        let mut buffers = cell.borrow_mut();
        let Some(slot) = buffers
            .iter()
            .position(|buf| buf.capacity() >= min_capacity)
        else {
            // Nothing pooled is big enough: allocate.
            return Vec::with_capacity(min_capacity);
        };
        // Order inside the pool is irrelevant, so the O(1) removal is fine.
        let mut reused = buffers.swap_remove(slot);
        reused.clear();
        reused
    })
}
/// Hands a scratch buffer back to the thread-local pool for later reuse.
///
/// Buffers that own no allocation are dropped instead of pooled. The callers in
/// this file `std::mem::take` the contents out of their buffer before returning
/// it, which leaves a zero-capacity `Vec`; pooling those would only grow the
/// pool without bound and make every `take_scratch` scan longer.
///
/// The pool is also capped so a long-lived thread cannot accumulate an
/// unbounded number of idle buffers. Returning is best-effort: if the pool is
/// currently borrowed, or already full, the buffer is simply dropped.
fn return_scratch(mut v: Vec<Node>) {
    /// Maximum number of idle buffers retained per thread.
    const MAX_POOLED_BUFFERS: usize = 64;

    // Nothing to reuse: an unallocated Vec is as cheap to create as to fetch.
    if v.capacity() == 0 {
        return;
    }
    v.clear();
    SCRATCH_POOL.with(|pool| {
        if let Ok(mut pool) = pool.try_borrow_mut() {
            if pool.len() < MAX_POOLED_BUFFERS {
                pool.push(v);
            }
        }
    });
}
/// Runs the region-inlining pass using [`DEFAULT_INLINE_THRESHOLD`].
///
/// Equivalent to `run_with_threshold(program, DEFAULT_INLINE_THRESHOLD)`.
#[inline]
#[must_use]
pub fn run(program: Program) -> Program {
    let threshold = DEFAULT_INLINE_THRESHOLD;
    run_with_threshold(program, threshold)
}
/// Runs the region-inlining pass over `program`'s entry nodes.
///
/// A region whose body counts at most `threshold` nodes is replaced by its
/// body; larger regions are kept, with their bodies processed recursively.
/// The rewritten node list is installed via `Program::with_rewritten_entry`.
#[must_use]
pub fn run_with_threshold(program: Program, threshold: usize) -> Program {
    // Work on an owned copy of the entry nodes; `program` itself is consumed
    // only at the very end.
    let original = program.entry().to_vec();
    let mut rewritten = Vec::with_capacity(original.len());
    // Per-run cache of region body sizes, shared by the whole traversal.
    let mut region_counts = FxHashMap::default();
    inline_nodes_into(original, threshold, &mut region_counts, &mut rewritten);
    program.with_rewritten_entry(rewritten)
}
/// Recursively rewrites `nodes`, appending the result to `out`.
///
/// * `Node::Region`: if its body counts at most `threshold` nodes (as measured
///   by `count_nodes`), the region wrapper is dropped and the processed body is
///   appended directly to `out`. Otherwise the region is kept and only its
///   body is processed.
/// * `Node::Block`, `Node::Loop`, `Node::If`: the node is kept and each child
///   list is processed recursively.
/// * Any other node is appended unchanged.
///
/// `region_counts` is the size cache handed through to `count_nodes`.
fn inline_nodes_into(
    nodes: Vec<Node>,
    threshold: usize,
    region_counts: &mut FxHashMap<usize, usize>,
    out: &mut Vec<Node>,
) {
    use std::mem;
    use std::sync::Arc;

    for node in nodes {
        match node {
            Node::Region {
                body,
                generator,
                source_region,
            } => {
                // Measure while the body is still behind its `Arc`; the next
                // statement consumes it.
                let size = count_nodes(&body, region_counts);
                // Take the body by value when this is the last reference,
                // otherwise fall back to cloning the shared contents.
                let owned_body =
                    Arc::try_unwrap(body).unwrap_or_else(|shared| (*shared).clone());
                if size > threshold {
                    // Too large to inline: keep the wrapper, rewrite the inside.
                    let mut kept = take_scratch(owned_body.len());
                    inline_nodes_into(owned_body, threshold, region_counts, &mut kept);
                    out.push(Node::Region {
                        generator,
                        source_region,
                        // `mem::take` moves the allocation into the node and
                        // leaves `kept` as an empty, unallocated Vec.
                        body: Arc::new(mem::take(&mut kept)),
                    });
                    return_scratch(kept);
                } else {
                    // Small enough: splice the processed body straight into `out`.
                    inline_nodes_into(owned_body, threshold, region_counts, out);
                }
            }
            Node::Block(children) => {
                let mut rewritten = take_scratch(children.len());
                inline_nodes_into(children, threshold, region_counts, &mut rewritten);
                out.push(Node::Block(mem::take(&mut rewritten)));
                return_scratch(rewritten);
            }
            Node::Loop {
                var,
                from,
                to,
                body,
            } => {
                let mut rewritten = take_scratch(body.len());
                inline_nodes_into(body, threshold, region_counts, &mut rewritten);
                out.push(Node::Loop {
                    var,
                    from,
                    to,
                    body: mem::take(&mut rewritten),
                });
                return_scratch(rewritten);
            }
            Node::If {
                cond,
                then: then_nodes,
                otherwise: else_nodes,
            } => {
                // Both buffers are acquired before either branch is processed.
                let mut then_out = take_scratch(then_nodes.len());
                let mut else_out = take_scratch(else_nodes.len());
                inline_nodes_into(then_nodes, threshold, region_counts, &mut then_out);
                inline_nodes_into(else_nodes, threshold, region_counts, &mut else_out);
                out.push(Node::If {
                    cond,
                    then: mem::take(&mut then_out),
                    otherwise: mem::take(&mut else_out),
                });
                return_scratch(then_out);
                return_scratch(else_out);
            }
            // Nodes without child lists pass through untouched.
            leaf => out.push(leaf),
        }
    }
}
/// Counts every node in `nodes`, including all nested descendants.
///
/// Each node contributes one for itself plus the count of its child lists
/// (`Block` children, `Loop` body, both `If` branches, `Region` body).
///
/// Region body sizes are memoized in `region_counts`, keyed by the address of
/// the region's `Arc` allocation; the cached value excludes the region node
/// itself.
///
/// NOTE(review): keys are raw addresses, so a cached entry is only meaningful
/// while the `Arc` it was computed from is still alive — confirm callers keep
/// the counted nodes alive for as long as `region_counts` is in use.
fn count_nodes(nodes: &[Node], region_counts: &mut FxHashMap<usize, usize>) -> usize {
    let mut total = 0;
    for node in nodes {
        let descendants = match node {
            Node::Block(children) => count_nodes(children, region_counts),
            Node::Loop { body, .. } => count_nodes(body, region_counts),
            Node::If {
                then, otherwise, ..
            } => count_nodes(then, region_counts) + count_nodes(otherwise, region_counts),
            Node::Region { body, .. } => {
                let key = std::sync::Arc::as_ptr(body) as usize;
                // Copy the cached value out so the map is free to be mutated
                // by the recursive call on a miss.
                let cached = region_counts.get(&key).copied();
                match cached {
                    Some(hit) => hit,
                    None => {
                        let fresh = count_nodes(body, region_counts);
                        region_counts.insert(key, fresh);
                        fresh
                    }
                }
            }
            _ => 0,
        };
        // One for the node itself, plus everything beneath it.
        total += 1 + descendants;
    }
    total
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{BufferDecl, DataType, Expr, Program};

    /// Builds a `Node::Region` with the given generator name and body and no
    /// source-region annotation.
    fn region(generator: &'static str, body: Vec<Node>) -> Node {
        Node::Region {
            generator: generator.into(),
            source_region: None,
            body: std::sync::Arc::new(body),
        }
    }

    /// Wraps `entry` in a program that declares a single read-write `out`
    /// buffer, the fixture shared by every test here.
    fn program_with_entry(entry: Vec<Node>) -> Program {
        Program::wrapped(
            vec![BufferDecl::read_write("out", 0, DataType::U32)],
            [1, 1, 1],
            entry,
        )
    }

    /// A one-node region is below the default threshold and must disappear.
    #[test]
    fn small_region_inlines() {
        let store = Node::store("out", Expr::u32(0), Expr::u32(42));
        let prog = program_with_entry(vec![region("test", vec![store])]);
        let optimized = run(prog);
        let first = &optimized.entry()[0];
        assert!(
            !matches!(first, Node::Region { .. }),
            "small Region must inline"
        );
        assert!(matches!(first, Node::Store { .. }));
    }

    /// A 100-node region exceeds a threshold of 64 and must be kept.
    #[test]
    fn large_region_stays_wrapped() {
        let stores: Vec<Node> = (0..100)
            .map(|i| Node::store("out", Expr::u32(i), Expr::u32(i)))
            .collect();
        let prog = program_with_entry(vec![region("test", stores)]);
        let optimized = run_with_threshold(prog, 64);
        assert!(
            matches!(&optimized.entry()[0], Node::Region { .. }),
            "large Region must stay wrapped"
        );
    }

    /// Small regions nested inside small regions collapse all the way down.
    #[test]
    fn nested_small_regions_all_inline() {
        let inner = region(
            "inner",
            vec![Node::store("out", Expr::u32(0), Expr::u32(1))],
        );
        let outer = region("outer", vec![inner]);
        let optimized = run(program_with_entry(vec![outer]));
        assert_eq!(optimized.entry().len(), 1);
        assert!(matches!(&optimized.entry()[0], Node::Store { .. }));
    }

    /// The pass descends into loop bodies and inlines regions found there.
    #[test]
    fn regions_inside_loops_also_inline() {
        let inner = region(
            "inner",
            vec![Node::store("out", Expr::var("i"), Expr::u32(1))],
        );
        let loop_node = Node::loop_for("i", Expr::u32(0), Expr::u32(4), vec![inner]);
        let optimized = run(program_with_entry(vec![loop_node]));
        let Node::Loop { body, .. } = &optimized.entry()[0] else {
            panic!("expected Loop");
        };
        assert_eq!(body.len(), 1);
        assert!(
            matches!(&body[0], Node::Store { .. }),
            "Region inside Loop must inline to just the Store"
        );
    }
}