use crate::ir::{AtomicOp, Expr, Ident, Node, Program};
use crate::ir_inner::model::expr::GeneratorRef;
use rustc_hash::FxHashMap;
use std::sync::OnceLock;
/// Dense, zero-based position of a node in the per-node columns of `ProgramFacts`.
///
/// Indices are handed out in the order nodes are recorded, so the derived
/// ordering compares recording order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct NodeIndex(pub u32);
/// Tag recorded for every visited statement node.
///
/// `walk_node` records exactly one `NodeKind` per `Node`, using the variant
/// with the same name as the `Node` variant. `#[repr(u8)]` fixes the tag to a
/// single byte; the declaration order below determines the discriminants.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
#[repr(u8)]
pub enum NodeKind {
/// `Node::Let` — also recorded in `ProgramFacts::lets`.
Let,
/// `Node::Assign` — also recorded in `ProgramFacts::assigns`.
Assign,
/// `Node::Store` — records a `BufferRefKind::Write` on its target buffer.
Store,
/// `Node::If` — both arms become children of this node.
If,
/// `Node::Loop` — its induction variable is recorded in `ProgramFacts::loop_vars`.
Loop,
/// `Node::IndirectDispatch` — records a `BufferRefKind::IndirectCount`.
IndirectDispatch,
/// `Node::AsyncLoad`.
AsyncLoad,
/// `Node::AsyncStore`.
AsyncStore,
/// `Node::AsyncWait`.
AsyncWait,
/// `Node::Trap`.
Trap,
/// `Node::Resume`.
Resume,
/// `Node::Return`.
Return,
/// `Node::Barrier`.
Barrier,
/// `Node::Block` — its statements become children of this node.
Block,
/// `Node::Region` — provenance wrapper; metadata goes to `ProgramFacts::regions`.
Region,
/// `Node::Opaque` — recorded without inspecting its payload.
Opaque,
}
/// How a statement touches a named buffer, as recorded in `ProgramFacts::buffer_refs`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
pub enum BufferRefKind {
/// Recorded for `Expr::Load` and `Expr::BufLen`.
Read,
/// Recorded for the target of `Node::Store` and the destination of `Node::AsyncStore`.
Write,
/// Recorded for `Expr::Atomic`; carries the atomic operation.
Atomic(AtomicOp),
/// Recorded for the destination of `Node::AsyncLoad`.
AsyncDestination,
/// Recorded for the source of both `Node::AsyncLoad` and `Node::AsyncStore`.
AsyncSource,
/// Recorded for the count buffer of `Node::IndirectDispatch`.
IndirectCount,
}
/// Metadata captured for each `Node::Region` encountered during the walk.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RegionMeta {
/// Index of the `Region` node itself.
pub node: NodeIndex,
/// The region's `generator` identifier, copied verbatim from the node.
pub generator: Ident,
/// The region's `source_region`, copied verbatim from the node.
pub source_region: Option<GeneratorRef>,
}
/// Flat, column-oriented facts collected from one preorder walk of a `Program`.
///
/// The `Vec` columns are filled by `ProgramFacts::build`. The `OnceLock`
/// fields are lookup indices derived from those columns; each is built on the
/// first query that needs it and reused afterwards.
#[derive(Debug, Default)]
pub struct ProgramFacts {
/// Kind of each node, indexed by `NodeIndex`.
kinds: Vec<NodeKind>,
/// Immediate parent of each node (`None` for top-level entry nodes), indexed by `NodeIndex`.
parent: Vec<Option<NodeIndex>>,
/// `(Let node, bound name)` in walk order.
lets: Vec<(NodeIndex, Ident)>,
/// `(Assign node, target name)` in walk order.
assigns: Vec<(NodeIndex, Ident)>,
/// `(Loop node, induction variable)` in walk order.
loop_vars: Vec<(NodeIndex, Ident)>,
/// `(owning statement, variable name)` for every `Expr::Var` read.
var_reads: Vec<(NodeIndex, Ident)>,
/// `(owning statement, buffer name, kind)` for every buffer touch.
buffer_refs: Vec<(NodeIndex, Ident, BufferRefKind)>,
/// Metadata for every `Region` node, in walk order.
regions: Vec<RegionMeta>,
/// Lazy index: name -> `Let` sites (built by `let_sites_of`).
let_index: OnceLock<FxHashMap<Ident, Vec<NodeIndex>>>,
/// Lazy index: name -> `Assign` sites (built by `assign_sites_of`).
assign_index: OnceLock<FxHashMap<Ident, Vec<NodeIndex>>>,
/// Lazy index: name -> reading statements (built by `var_read_sites_of`).
var_read_index: OnceLock<FxHashMap<Ident, Vec<NodeIndex>>>,
/// Lazy index: buffer -> `(statement, kind)` touches (built by `buffer_refs_of`).
buffer_index: OnceLock<FxHashMap<Ident, Vec<(NodeIndex, BufferRefKind)>>>,
/// Lazy index: Region node -> position in `regions` (built by `region_at`).
region_index_by_node: OnceLock<FxHashMap<NodeIndex, usize>>,
/// Lazy index: generator -> positions in `regions` (built by `regions_by_generator`).
region_index_by_generator: OnceLock<FxHashMap<Ident, Vec<usize>>>,
}
impl ProgramFacts {
#[must_use]
pub fn build(program: &Program) -> Self {
let mut facts = Self::default();
for node in program.entry() {
walk_node(node, None, &mut facts);
}
facts
}
#[must_use]
pub fn node_count(&self) -> usize {
self.kinds.len()
}
#[must_use]
pub fn kind_at(&self, idx: NodeIndex) -> NodeKind {
self.kinds[idx.0 as usize]
}
#[must_use]
pub fn parent_of(&self, idx: NodeIndex) -> Option<NodeIndex> {
self.parent[idx.0 as usize]
}
#[must_use]
pub fn is_descendant_of(&self, node: NodeIndex, ancestor: NodeIndex) -> bool {
let mut current = Some(node);
while let Some(idx) = current {
if idx == ancestor {
return true;
}
current = self.parent_of(idx);
}
false
}
pub fn iter_nodes(&self) -> impl Iterator<Item = (NodeIndex, NodeKind)> + '_ {
self.kinds
.iter()
.copied()
.enumerate()
.map(|(i, kind)| (NodeIndex(i as u32), kind))
}
pub fn iter_regionless_nodes(&self) -> impl Iterator<Item = (NodeIndex, NodeKind)> + '_ {
self.iter_nodes()
.filter(|(_, kind)| *kind != NodeKind::Region)
}
#[must_use]
pub fn regionless_parent_of(&self, idx: NodeIndex) -> Option<NodeIndex> {
let mut parent = self.parent_of(idx);
while let Some(candidate) = parent {
if self.kind_at(candidate) != NodeKind::Region {
return Some(candidate);
}
parent = self.parent_of(candidate);
}
None
}
#[must_use]
pub fn has_kind(&self, kind: NodeKind) -> bool {
self.kinds.iter().any(|&k| k == kind)
}
#[must_use]
pub fn lets(&self) -> &[(NodeIndex, Ident)] {
&self.lets
}
#[must_use]
pub fn assigns(&self) -> &[(NodeIndex, Ident)] {
&self.assigns
}
#[must_use]
pub fn loop_vars(&self) -> &[(NodeIndex, Ident)] {
&self.loop_vars
}
#[must_use]
pub fn var_reads(&self) -> &[(NodeIndex, Ident)] {
&self.var_reads
}
#[must_use]
pub fn buffer_refs(&self) -> &[(NodeIndex, Ident, BufferRefKind)] {
&self.buffer_refs
}
#[must_use]
pub fn let_sites_of(&self, name: &str) -> &[NodeIndex] {
let map = self.let_index.get_or_init(|| build_index(&self.lets));
map.get(name).map_or(&[], Vec::as_slice)
}
#[must_use]
pub fn assign_sites_of(&self, name: &str) -> &[NodeIndex] {
let map = self.assign_index.get_or_init(|| build_index(&self.assigns));
map.get(name).map_or(&[], Vec::as_slice)
}
#[must_use]
pub fn var_read_sites_of(&self, name: &str) -> &[NodeIndex] {
let map = self
.var_read_index
.get_or_init(|| build_index(&self.var_reads));
map.get(name).map_or(&[], Vec::as_slice)
}
#[must_use]
pub fn buffer_refs_of(&self, name: &str) -> &[(NodeIndex, BufferRefKind)] {
let map = self.buffer_index.get_or_init(|| {
let mut out: FxHashMap<Ident, Vec<(NodeIndex, BufferRefKind)>> = FxHashMap::default();
for (idx, buffer, kind) in &self.buffer_refs {
out.entry(buffer.clone()).or_default().push((*idx, *kind));
}
out
});
map.get(name).map_or(&[], Vec::as_slice)
}
#[must_use]
pub fn regions(&self) -> &[RegionMeta] {
&self.regions
}
#[must_use]
pub fn region_at(&self, idx: NodeIndex) -> Option<&RegionMeta> {
let map = self.region_index_by_node.get_or_init(|| {
self.regions
.iter()
.enumerate()
.map(|(i, meta)| (meta.node, i))
.collect()
});
map.get(&idx).and_then(|&i| self.regions.get(i))
}
pub fn regions_by_generator(&self, generator: &str) -> impl Iterator<Item = &RegionMeta> + '_ {
let map = self.region_index_by_generator.get_or_init(|| {
let mut out: FxHashMap<Ident, Vec<usize>> = FxHashMap::default();
for (i, meta) in self.regions.iter().enumerate() {
out.entry(meta.generator.clone()).or_default().push(i);
}
out
});
map.get(generator)
.map(|v| v.as_slice())
.unwrap_or(&[])
.iter()
.filter_map(move |&i| self.regions.get(i))
}
#[must_use]
pub fn is_name_rebound(&self, name: &str) -> bool {
let lets = self.let_sites_of(name);
if lets.len() > 1 {
return true;
}
if !self.assign_sites_of(name).is_empty() {
return true;
}
self.loop_vars.iter().any(|(_, var)| var.as_str() == name)
}
#[must_use]
pub fn buffers_provably_distinct(&self, buf_a: &str, buf_b: &str) -> bool {
if buf_a == buf_b {
return false;
}
let a_seen = self.buffer_refs.iter().any(|(_, b, _)| b.as_str() == buf_a);
let b_seen = self.buffer_refs.iter().any(|(_, b, _)| b.as_str() == buf_b);
a_seen && b_seen
}
#[must_use]
pub fn buffer_escapes(&self, name: &str) -> bool {
self.buffer_refs.iter().any(|(_, b, kind)| {
b.as_str() == name
&& matches!(
kind,
BufferRefKind::Write
| BufferRefKind::Atomic(_)
| BufferRefKind::AsyncDestination
| BufferRefKind::IndirectCount
)
})
}
#[must_use]
pub fn escaping_buffers(&self) -> FxHashMap<Ident, ()> {
let mut out: FxHashMap<Ident, ()> = FxHashMap::default();
for (_, name, kind) in &self.buffer_refs {
if matches!(
kind,
BufferRefKind::Write
| BufferRefKind::Atomic(_)
| BufferRefKind::AsyncDestination
| BufferRefKind::IndirectCount
) {
out.insert(name.clone(), ());
}
}
out
}
}
/// Groups `(site, name)` rows by name, keeping each name's sites in the
/// order the rows appear.
fn build_index(rows: &[(NodeIndex, Ident)]) -> FxHashMap<Ident, Vec<NodeIndex>> {
    let empty: FxHashMap<Ident, Vec<NodeIndex>> = FxHashMap::default();
    rows.iter().fold(empty, |mut index, (site, name)| {
        index.entry(name.clone()).or_default().push(*site);
        index
    })
}
/// Appends a node row of `kind` under `parent` and returns its new index.
///
/// Indices are assigned densely in call order, so a preorder walk that calls
/// this before visiting children yields preorder indices.
///
/// # Panics
///
/// Panics if the new index does not fit in `u32`. An unchecked `as u32` cast
/// here would silently wrap and alias an earlier node's index.
fn record_node(facts: &mut ProgramFacts, kind: NodeKind, parent: Option<NodeIndex>) -> NodeIndex {
    let raw = u32::try_from(facts.kinds.len())
        .expect("ProgramFacts: node count exceeds the u32 NodeIndex range");
    let idx = NodeIndex(raw);
    facts.kinds.push(kind);
    facts.parent.push(parent);
    idx
}
fn walk_node(node: &Node, parent: Option<NodeIndex>, facts: &mut ProgramFacts) {
match node {
Node::Let { name, value } => {
let idx = record_node(facts, NodeKind::Let, parent);
facts.lets.push((idx, name.clone()));
walk_expr(value, idx, facts);
}
Node::Assign { name, value } => {
let idx = record_node(facts, NodeKind::Assign, parent);
facts.assigns.push((idx, name.clone()));
walk_expr(value, idx, facts);
}
Node::Store {
buffer,
index,
value,
} => {
let idx = record_node(facts, NodeKind::Store, parent);
facts
.buffer_refs
.push((idx, buffer.clone(), BufferRefKind::Write));
walk_expr(index, idx, facts);
walk_expr(value, idx, facts);
}
Node::If {
cond,
then,
otherwise,
} => {
let idx = record_node(facts, NodeKind::If, parent);
walk_expr(cond, idx, facts);
for n in then {
walk_node(n, Some(idx), facts);
}
for n in otherwise {
walk_node(n, Some(idx), facts);
}
}
Node::Loop {
var,
from,
to,
body,
} => {
let idx = record_node(facts, NodeKind::Loop, parent);
facts.loop_vars.push((idx, var.clone()));
walk_expr(from, idx, facts);
walk_expr(to, idx, facts);
for n in body {
walk_node(n, Some(idx), facts);
}
}
Node::IndirectDispatch { count_buffer, .. } => {
let idx = record_node(facts, NodeKind::IndirectDispatch, parent);
facts
.buffer_refs
.push((idx, count_buffer.clone(), BufferRefKind::IndirectCount));
}
Node::AsyncLoad {
source,
destination,
offset,
size,
..
} => {
let idx = record_node(facts, NodeKind::AsyncLoad, parent);
facts
.buffer_refs
.push((idx, source.clone(), BufferRefKind::AsyncSource));
facts
.buffer_refs
.push((idx, destination.clone(), BufferRefKind::AsyncDestination));
walk_expr(offset, idx, facts);
walk_expr(size, idx, facts);
}
Node::AsyncStore {
source,
destination,
offset,
size,
..
} => {
let idx = record_node(facts, NodeKind::AsyncStore, parent);
facts
.buffer_refs
.push((idx, source.clone(), BufferRefKind::AsyncSource));
facts
.buffer_refs
.push((idx, destination.clone(), BufferRefKind::Write));
walk_expr(offset, idx, facts);
walk_expr(size, idx, facts);
}
Node::AsyncWait { .. } => {
record_node(facts, NodeKind::AsyncWait, parent);
}
Node::Trap { address, .. } => {
let idx = record_node(facts, NodeKind::Trap, parent);
walk_expr(address, idx, facts);
}
Node::Resume { .. } => {
record_node(facts, NodeKind::Resume, parent);
}
Node::Return => {
record_node(facts, NodeKind::Return, parent);
}
Node::Barrier { .. } => {
record_node(facts, NodeKind::Barrier, parent);
}
Node::Block(body) => {
let idx = record_node(facts, NodeKind::Block, parent);
for n in body {
walk_node(n, Some(idx), facts);
}
}
Node::Region {
generator,
source_region,
body,
} => {
let idx = record_node(facts, NodeKind::Region, parent);
facts.regions.push(RegionMeta {
node: idx,
generator: generator.clone(),
source_region: source_region.clone(),
});
for n in body.iter() {
walk_node(n, Some(idx), facts);
}
}
Node::Opaque(_) => {
record_node(facts, NodeKind::Opaque, parent);
}
}
}
/// Scans `expr` recursively and attributes every variable read and buffer
/// reference it contains to the statement `owning_node`.
///
/// Expressions never get node rows of their own. An expression's own buffer
/// reference (if any) is recorded before its sub-expressions, which are
/// visited in field order.
fn walk_expr(expr: &Expr, owning_node: NodeIndex, facts: &mut ProgramFacts) {
    match expr {
        // Leaves with nothing to record.
        Expr::LitU32(_)
        | Expr::LitI32(_)
        | Expr::LitF32(_)
        | Expr::LitBool(_)
        | Expr::InvocationId { .. }
        | Expr::WorkgroupId { .. }
        | Expr::LocalId { .. }
        | Expr::SubgroupLocalId
        | Expr::SubgroupSize
        | Expr::Opaque(_) => {}
        Expr::Var(var_name) => facts.var_reads.push((owning_node, var_name.clone())),
        // `BufLen` is tracked as a read of its buffer.
        Expr::BufLen { buffer } => {
            let row = (owning_node, buffer.clone(), BufferRefKind::Read);
            facts.buffer_refs.push(row);
        }
        Expr::Load { buffer, index } => {
            let row = (owning_node, buffer.clone(), BufferRefKind::Read);
            facts.buffer_refs.push(row);
            walk_expr(index, owning_node, facts);
        }
        Expr::Atomic {
            op,
            buffer,
            index,
            expected,
            value,
            ..
        } => {
            let row = (owning_node, buffer.clone(), BufferRefKind::Atomic(*op));
            facts.buffer_refs.push(row);
            walk_expr(index, owning_node, facts);
            // Only compare-exchange style atomics carry an expected operand.
            if let Some(comparand) = expected.as_deref() {
                walk_expr(comparand, owning_node, facts);
            }
            walk_expr(value, owning_node, facts);
        }
        // Single-operand forms.
        Expr::UnOp { operand, .. } => walk_expr(operand, owning_node, facts),
        Expr::Cast { value, .. } => walk_expr(value, owning_node, facts),
        Expr::SubgroupBallot { cond } => walk_expr(cond, owning_node, facts),
        Expr::SubgroupAdd { value } => walk_expr(value, owning_node, facts),
        // Multi-operand forms.
        Expr::BinOp { left: lhs, right: rhs, .. } => {
            walk_expr(lhs, owning_node, facts);
            walk_expr(rhs, owning_node, facts);
        }
        Expr::SubgroupShuffle { value, lane } => {
            walk_expr(value, owning_node, facts);
            walk_expr(lane, owning_node, facts);
        }
        Expr::Select {
            cond,
            true_val: when_true,
            false_val: when_false,
        } => {
            walk_expr(cond, owning_node, facts);
            walk_expr(when_true, owning_node, facts);
            walk_expr(when_false, owning_node, facts);
        }
        Expr::Fma { a, b, c } => {
            walk_expr(a, owning_node, facts);
            walk_expr(b, owning_node, facts);
            walk_expr(c, owning_node, facts);
        }
        Expr::Call { args, .. } => {
            for argument in args {
                walk_expr(argument, owning_node, facts);
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{BufferAccess, BufferDecl, DataType, Expr, Node};
    use crate::runtime::memory_model::MemoryOrdering;

    fn buf(name: &str) -> BufferDecl {
        BufferDecl::storage(name, 0, BufferAccess::ReadWrite, DataType::U32).with_count(4)
    }

    /// Wraps `entry` in a program declaring buffers `a` and `b`.
    fn program(entry: Vec<Node>) -> Program {
        Program::wrapped(vec![buf("a"), buf("b")], [1, 1, 1], entry)
    }

    #[test]
    fn empty_program_has_only_region_node() {
        let facts = ProgramFacts::build(&program(Vec::new()));
        assert_eq!(facts.node_count(), 1);
        assert_eq!(facts.kind_at(NodeIndex(0)), NodeKind::Region);
        assert!(facts.lets().is_empty());
        assert!(facts.var_reads().is_empty());
        assert!(facts.buffer_refs().is_empty());
    }

    #[test]
    fn let_sites_recorded_in_preorder() {
        let facts = ProgramFacts::build(&program(vec![
            Node::let_bind("x", Expr::u32(1)),
            Node::let_bind("y", Expr::u32(2)),
        ]));
        let lets = facts.lets();
        assert_eq!(lets.len(), 2);
        assert_eq!(lets[0].1.as_str(), "x");
        assert_eq!(lets[1].1.as_str(), "y");
    }

    #[test]
    fn nested_if_collects_var_reads_and_buffer_refs() {
        let facts = ProgramFacts::build(&program(vec![
            Node::let_bind("x", Expr::u32(7)),
            Node::If {
                cond: Expr::var("c"),
                then: vec![Node::store("a", Expr::var("x"), Expr::u32(1))],
                otherwise: vec![Node::store("b", Expr::var("x"), Expr::u32(2))],
            },
        ]));
        let var_reads: Vec<&str> = facts.var_reads().iter().map(|(_, n)| n.as_str()).collect();
        assert!(var_reads.contains(&"c"));
        let x_count = var_reads.iter().filter(|n| **n == "x").count();
        assert_eq!(x_count, 2, "x read in both arms");
        let a_writes: Vec<_> = facts
            .buffer_refs_of("a")
            .iter()
            .filter(|(_, k)| *k == BufferRefKind::Write)
            .collect();
        assert_eq!(a_writes.len(), 1);
        let b_writes: Vec<_> = facts
            .buffer_refs_of("b")
            .iter()
            .filter(|(_, k)| *k == BufferRefKind::Write)
            .collect();
        assert_eq!(b_writes.len(), 1);
    }

    #[test]
    fn let_sites_of_resolves_via_lookup_index() {
        let facts = ProgramFacts::build(&program(vec![
            Node::let_bind("x", Expr::u32(1)),
            Node::Block(vec![Node::let_bind("x", Expr::u32(2))]),
        ]));
        let sites = facts.let_sites_of("x");
        assert_eq!(sites.len(), 2, "both Let-sites of `x` are recorded");
        assert!(facts.let_sites_of("missing").is_empty());
    }

    #[test]
    fn descendant_query_uses_parent_column() {
        let facts = ProgramFacts::build(&program(vec![Node::Block(vec![Node::let_bind(
            "x",
            Expr::u32(1),
        )])]));
        let root = facts.regions()[0].node;
        let let_idx = facts.lets()[0].0;
        assert!(facts.is_descendant_of(root, root));
        assert!(facts.is_descendant_of(let_idx, root));
        assert!(!facts.is_descendant_of(root, let_idx));
    }

    #[test]
    fn atomic_buffer_refs_record_op() {
        let facts = ProgramFacts::build(&program(vec![Node::let_bind(
            "x",
            Expr::Atomic {
                op: AtomicOp::Add,
                buffer: Ident::from("a"),
                index: Box::new(Expr::u32(0)),
                expected: None,
                value: Box::new(Expr::u32(1)),
                ordering: MemoryOrdering::Relaxed,
            },
        )]));
        let touches = facts.buffer_refs_of("a");
        assert_eq!(touches.len(), 1);
        assert_eq!(touches[0].1, BufferRefKind::Atomic(AtomicOp::Add));
    }

    #[test]
    fn is_name_rebound_detects_every_shape() {
        let facts_single = ProgramFacts::build(&program(vec![Node::let_bind("x", Expr::u32(1))]));
        assert!(!facts_single.is_name_rebound("x"));
        assert!(!facts_single.is_name_rebound("y"));
        let facts_assign = ProgramFacts::build(&program(vec![
            Node::let_bind("x", Expr::u32(1)),
            Node::Assign {
                name: Ident::from("x"),
                value: Expr::u32(2),
            },
        ]));
        assert!(facts_assign.is_name_rebound("x"));
        let facts_loop = ProgramFacts::build(&program(vec![Node::Loop {
            var: Ident::from("i"),
            from: Expr::u32(0),
            to: Expr::u32(4),
            body: vec![],
        }]));
        assert!(facts_loop.is_name_rebound("i"));
        let facts_double_let = ProgramFacts::build(&program(vec![
            Node::let_bind("x", Expr::u32(1)),
            Node::Block(vec![Node::let_bind("x", Expr::u32(2))]),
        ]));
        assert!(facts_double_let.is_name_rebound("x"));
    }

    #[test]
    fn loop_vars_recorded_for_every_loop() {
        let facts = ProgramFacts::build(&program(vec![Node::Loop {
            var: Ident::from("i"),
            from: Expr::u32(0),
            to: Expr::u32(4),
            body: vec![Node::Loop {
                var: Ident::from("j"),
                from: Expr::u32(0),
                to: Expr::u32(4),
                body: vec![],
            }],
        }]));
        let names: Vec<&str> = facts.loop_vars().iter().map(|(_, n)| n.as_str()).collect();
        assert_eq!(names, vec!["i", "j"]);
    }

    #[test]
    fn parent_of_reports_immediate_container() {
        let facts = ProgramFacts::build(&program(vec![Node::If {
            cond: Expr::var("c"),
            then: vec![Node::let_bind("x", Expr::u32(1))],
            otherwise: vec![Node::let_bind("y", Expr::u32(2))],
        }]));
        let region = NodeIndex(0);
        assert_eq!(facts.kind_at(region), NodeKind::Region);
        assert_eq!(facts.parent_of(region), None);
        let if_idx = facts
            .iter_nodes()
            .find(|(_, k)| *k == NodeKind::If)
            .map(|(i, _)| i)
            .expect("If node present");
        assert_eq!(facts.parent_of(if_idx), Some(region));
        let let_idxs: Vec<_> = facts.lets().iter().map(|(i, _)| *i).collect();
        for let_idx in let_idxs {
            assert_eq!(facts.parent_of(let_idx), Some(if_idx));
        }
    }

    #[test]
    fn buffer_refs_of_separates_read_and_write() {
        let facts = ProgramFacts::build(&program(vec![Node::store(
            "a",
            Expr::u32(0),
            Expr::Load {
                buffer: Ident::from("b"),
                index: Box::new(Expr::u32(0)),
            },
        )]));
        let a_touches = facts.buffer_refs_of("a");
        assert_eq!(a_touches.len(), 1);
        assert_eq!(a_touches[0].1, BufferRefKind::Write);
        let b_touches = facts.buffer_refs_of("b");
        assert_eq!(b_touches.len(), 1);
        assert_eq!(b_touches[0].1, BufferRefKind::Read);
    }

    #[test]
    fn has_kind_short_circuits_missing_variants() {
        let facts = ProgramFacts::build(&program(vec![Node::let_bind("x", Expr::u32(1))]));
        assert!(facts.has_kind(NodeKind::Let));
        assert!(!facts.has_kind(NodeKind::Loop));
        assert!(!facts.has_kind(NodeKind::Trap));
    }

    #[test]
    fn iter_nodes_yields_preorder() {
        let facts = ProgramFacts::build(&program(vec![
            Node::let_bind("x", Expr::u32(1)),
            Node::let_bind("y", Expr::u32(2)),
        ]));
        let kinds: Vec<NodeKind> = facts.iter_nodes().map(|(_, k)| k).collect();
        assert_eq!(kinds, vec![NodeKind::Region, NodeKind::Let, NodeKind::Let]);
    }

    #[test]
    fn regions_records_wrapping_and_nested() {
        let inner = Node::Region {
            generator: Ident::from("inner_pass"),
            source_region: None,
            body: std::sync::Arc::new(vec![Node::let_bind("x", Expr::u32(1))]),
        };
        let facts = ProgramFacts::build(&program(vec![Node::let_bind("z", Expr::u32(0)), inner]));
        let regions = facts.regions();
        assert_eq!(regions.len(), 2, "wrapping Region + inner Region");
        assert!(regions.iter().any(|r| r.generator.as_str() == "inner_pass"));
    }

    #[test]
    fn region_at_resolves_by_node_index() {
        let inner = Node::Region {
            generator: Ident::from("custom"),
            source_region: None,
            body: std::sync::Arc::new(vec![]),
        };
        // The sibling `Let` guarantees a non-Region node exists, so the
        // negative `region_at` check below always runs.
        let facts = ProgramFacts::build(&program(vec![Node::let_bind("x", Expr::u32(1)), inner]));
        let region_idx = facts
            .iter_nodes()
            .filter(|(_, k)| *k == NodeKind::Region)
            .map(|(i, _)| i)
            .find(|i| {
                facts
                    .region_at(*i)
                    .map(|m| m.generator.as_str() == "custom")
                    .unwrap_or(false)
            })
            .expect("custom-generator Region present");
        let meta = facts.region_at(region_idx).expect("region recorded");
        assert_eq!(meta.generator.as_str(), "custom");
        assert_eq!(meta.source_region, None);
        let let_idx = facts.lets()[0].0;
        assert_eq!(facts.kind_at(let_idx), NodeKind::Let);
        assert!(facts.region_at(let_idx).is_none());
    }

    #[test]
    fn regions_by_generator_filters_by_ident() {
        let entry = vec![
            Node::Region {
                generator: Ident::from("vec"),
                source_region: None,
                body: std::sync::Arc::new(vec![Node::let_bind("x", Expr::u32(1))]),
            },
            Node::Region {
                generator: Ident::from("dce"),
                source_region: None,
                body: std::sync::Arc::new(vec![Node::let_bind("y", Expr::u32(2))]),
            },
            Node::Region {
                generator: Ident::from("vec"),
                source_region: None,
                body: std::sync::Arc::new(vec![Node::let_bind("z", Expr::u32(3))]),
            },
        ];
        let facts = ProgramFacts::build(&program(entry));
        let vec_count = facts.regions_by_generator("vec").count();
        assert_eq!(vec_count, 2);
        let dce_count = facts.regions_by_generator("dce").count();
        assert_eq!(dce_count, 1);
        let missing = facts.regions_by_generator("missing").count();
        assert_eq!(missing, 0);
    }

    #[test]
    fn regionless_nodes_skip_provenance_wrappers() {
        let facts = ProgramFacts::build(&program(vec![
            Node::let_bind("root", Expr::u32(0)),
            Node::Region {
                generator: Ident::from("inner"),
                source_region: None,
                body: std::sync::Arc::new(vec![Node::let_bind("nested", Expr::u32(1))]),
            },
        ]));
        let kinds: Vec<NodeKind> = facts
            .iter_regionless_nodes()
            .map(|(_, kind)| kind)
            .collect();
        assert_eq!(kinds, vec![NodeKind::Let, NodeKind::Let]);
    }

    #[test]
    fn regionless_parent_skips_only_region_ancestors() {
        let facts = ProgramFacts::build(&program(vec![Node::Block(vec![Node::Region {
            generator: Ident::from("inner"),
            source_region: None,
            body: std::sync::Arc::new(vec![Node::let_bind("x", Expr::u32(1))]),
        }])]));
        let block = facts
            .iter_nodes()
            .find(|(_, kind)| *kind == NodeKind::Block)
            .map(|(idx, _)| idx)
            .expect("Block node present");
        let let_idx = facts.lets()[0].0;
        assert_eq!(facts.regionless_parent_of(block), None);
        assert_eq!(facts.regionless_parent_of(let_idx), Some(block));
    }

    #[test]
    fn buffers_provably_distinct_for_distinct_names() {
        let facts = ProgramFacts::build(&program(vec![
            Node::store("a", Expr::u32(0), Expr::u32(1)),
            Node::store("b", Expr::u32(0), Expr::u32(2)),
        ]));
        assert!(facts.buffers_provably_distinct("a", "b"));
        assert!(facts.buffers_provably_distinct("b", "a"));
    }

    #[test]
    fn buffers_provably_distinct_rejects_same_name() {
        let facts =
            ProgramFacts::build(&program(vec![Node::store("a", Expr::u32(0), Expr::u32(1))]));
        assert!(!facts.buffers_provably_distinct("a", "a"));
    }

    #[test]
    fn buffers_provably_distinct_rejects_phantom_name() {
        let facts =
            ProgramFacts::build(&program(vec![Node::store("a", Expr::u32(0), Expr::u32(1))]));
        assert!(!facts.buffers_provably_distinct("a", "phantom"));
    }

    #[test]
    fn buffer_does_not_escape_when_read_only() {
        let facts = ProgramFacts::build(&program(vec![Node::let_bind(
            "x",
            Expr::Load {
                buffer: Ident::from("a"),
                index: Box::new(Expr::u32(0)),
            },
        )]));
        assert!(!facts.buffer_escapes("a"));
    }

    #[test]
    fn buffer_escapes_when_stored_to() {
        let facts =
            ProgramFacts::build(&program(vec![Node::store("a", Expr::u32(0), Expr::u32(1))]));
        assert!(facts.buffer_escapes("a"));
    }

    #[test]
    fn buffer_escapes_when_atomically_touched() {
        let facts = ProgramFacts::build(&program(vec![Node::let_bind(
            "x",
            Expr::Atomic {
                op: AtomicOp::Add,
                buffer: Ident::from("a"),
                index: Box::new(Expr::u32(0)),
                expected: None,
                value: Box::new(Expr::u32(1)),
                ordering: MemoryOrdering::Relaxed,
            },
        )]));
        assert!(facts.buffer_escapes("a"));
    }

    #[test]
    fn escaping_buffers_enumerates_set() {
        let facts = ProgramFacts::build(&program(vec![
            Node::store("a", Expr::u32(0), Expr::u32(1)),
            Node::let_bind(
                "x",
                Expr::Load {
                    buffer: Ident::from("b"),
                    index: Box::new(Expr::u32(0)),
                },
            ),
        ]));
        let escaping = facts.escaping_buffers();
        assert_eq!(escaping.len(), 1);
        assert!(escaping.keys().any(|k| k.as_str() == "a"));
    }

    #[test]
    fn buffer_queries_ignore_declared_but_unreferenced_buffers() {
        // `b` is declared by `program` but never touched by the entry.
        let facts =
            ProgramFacts::build(&program(vec![Node::store("a", Expr::u32(0), Expr::u32(1))]));
        assert!(facts.buffer_refs_of("b").is_empty());
        assert!(!facts.buffer_escapes("b"));
        assert!(!facts.buffers_provably_distinct("a", "b"));
        assert!(!facts.escaping_buffers().keys().any(|k| k.as_str() == "b"));
    }
}