use crate::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
use thiserror::Error;
use vyre_spec::AlgebraicLaw;
#[must_use]
pub fn source() -> Option<&'static str> {
crate::transform::compiler::shader_provider::source("dominator_tree")
}
pub const IDOM_UNDEFINED: u32 = u32::MAX;
#[must_use]
pub fn relax_step_program(
idom: &str,
pred_offsets: &str,
preds: &str,
rpo_index: &str,
changed_flag: &str,
) -> Program {
let tid = Expr::InvocationId { axis: 0 };
let body = vec![
Node::let_bind("node", tid.clone()),
Node::let_bind("start", Expr::load(pred_offsets, Expr::var("node"))),
Node::let_bind(
"end",
Expr::load(pred_offsets, Expr::add(Expr::var("node"), Expr::u32(1))),
),
Node::let_bind("new_idom", Expr::u32(IDOM_UNDEFINED)),
Node::loop_for(
"i",
Expr::var("start"),
Expr::var("end"),
vec![
Node::let_bind("p", Expr::load(preds, Expr::var("i"))),
Node::let_bind("pi", Expr::load(idom, Expr::var("p"))),
Node::if_then(
Expr::ne(Expr::var("pi"), Expr::u32(IDOM_UNDEFINED)),
vec![Node::if_then(
Expr::eq(Expr::var("new_idom"), Expr::u32(IDOM_UNDEFINED)),
vec![Node::assign("new_idom", Expr::var("p"))],
)],
),
],
),
Node::loop_for(
"j",
Expr::var("start"),
Expr::var("end"),
vec![
Node::let_bind("p2", Expr::load(preds, Expr::var("j"))),
Node::let_bind("pi2", Expr::load(idom, Expr::var("p2"))),
Node::if_then(
Expr::ne(Expr::var("pi2"), Expr::u32(IDOM_UNDEFINED)),
vec![Node::if_then(
Expr::ne(Expr::var("p2"), Expr::var("new_idom")),
vec![
Node::let_bind("finger1", Expr::var("p2")),
Node::let_bind("finger2", Expr::var("new_idom")),
Node::let_bind("converged", Expr::u32(0)),
Node::loop_for(
"k",
Expr::u32(0),
Expr::u32(1024),
vec![Node::if_then(
Expr::eq(Expr::var("converged"), Expr::u32(0)),
vec![
Node::if_then(
Expr::eq(Expr::var("finger1"), Expr::var("finger2")),
vec![Node::assign("converged", Expr::u32(1))],
),
Node::let_bind(
"rpo1",
Expr::load(rpo_index, Expr::var("finger1")),
),
Node::let_bind(
"rpo2",
Expr::load(rpo_index, Expr::var("finger2")),
),
Node::if_then(
Expr::lt(Expr::var("rpo1"), Expr::var("rpo2")),
vec![Node::assign(
"finger1",
Expr::load(idom, Expr::var("finger1")),
)],
),
Node::if_then(
Expr::lt(Expr::var("rpo2"), Expr::var("rpo1")),
vec![Node::assign(
"finger2",
Expr::load(idom, Expr::var("finger2")),
)],
),
],
)],
),
Node::assign("new_idom", Expr::var("finger1")),
],
)],
),
],
),
Node::let_bind("old_idom", Expr::load(idom, Expr::var("node"))),
Node::if_then(
Expr::ne(Expr::var("new_idom"), Expr::var("old_idom")),
vec![
Node::store(idom, Expr::var("node"), Expr::var("new_idom")),
Node::let_bind(
"chg",
Expr::atomic_exchange(changed_flag, Expr::u32(0), Expr::u32(1)),
),
],
),
];
Program::wrapped(
vec![
BufferDecl::storage(idom, 0, BufferAccess::ReadWrite, DataType::U32),
BufferDecl::storage(pred_offsets, 1, BufferAccess::ReadOnly, DataType::U32),
BufferDecl::storage(preds, 2, BufferAccess::ReadOnly, DataType::U32),
BufferDecl::storage(rpo_index, 3, BufferAccess::ReadOnly, DataType::U32),
BufferDecl::storage(changed_flag, 4, BufferAccess::ReadWrite, DataType::U32)
.with_count(1),
],
[64, 1, 1],
body,
)
}
#[derive(Debug, Clone, PartialEq, Eq, Error)]
#[non_exhaustive]
pub enum DominatorTreeError {
#[error(
"DominatorEmptyOffsets: successor_offsets must include node_count + 1 entries. Fix: emit a valid CSR offset table."
)]
EmptyOffsets,
#[error(
"DominatorInvalidEntry: entry {entry} outside node_count {node_count}. Fix: pass a valid CFG entry node."
)]
InvalidEntry {
entry: u32,
node_count: usize,
},
#[error(
"DominatorInvalidOffset: CSR offsets must be monotone and within successors. Fix: rebuild successor_offsets."
)]
InvalidOffset,
#[error(
"DominatorNodeIndexOverflow: node id cannot fit usize. Fix: split the CFG before dispatch."
)]
NodeIndexOverflow,
#[error(
"DominatorInvalidSuccessor: successor {successor} outside node_count {node_count}. Fix: validate CFG edge endpoints."
)]
InvalidSuccessor {
successor: u32,
node_count: usize,
},
}
#[derive(Debug, Default, Clone, Copy)]
pub struct DominatorTreeOp;
#[must_use]
pub fn immediate_dominators(
entry: u32,
successor_offsets: &[u32],
successors: &[u32],
) -> Result<Vec<u32>, DominatorTreeError> {
let node_count = successor_offsets
.len()
.checked_sub(1)
.ok_or(DominatorTreeError::EmptyOffsets)?;
let entry_index = usize::try_from(entry).map_err(|_| DominatorTreeError::NodeIndexOverflow)?;
if entry_index >= node_count {
return Err(DominatorTreeError::InvalidEntry { entry, node_count });
}
validate_graph(node_count, successor_offsets, successors)?;
let rpo = reverse_postorder(entry_index, successor_offsets, successors)?;
let mut order = vec![UNREACHABLE; node_count];
for (rank, &node) in rpo.iter().enumerate() {
order[node] = u32::try_from(rank).map_err(|_| DominatorTreeError::NodeIndexOverflow)?;
}
let preds = predecessors(node_count, successor_offsets, successors)?;
let mut idom = vec![UNREACHABLE; node_count];
idom[entry_index] = entry;
let mut changed = true;
while changed {
changed = false;
for &node in rpo.iter().skip(1) {
let mut new_idom = UNREACHABLE;
for &pred in &preds[node] {
let pred_index =
usize::try_from(pred).map_err(|_| DominatorTreeError::NodeIndexOverflow)?;
if idom[pred_index] != UNREACHABLE {
new_idom = if new_idom == UNREACHABLE {
pred
} else {
intersect(pred, new_idom, &idom, &order)?
};
}
}
if idom[node] != new_idom {
idom[node] = new_idom;
changed = true;
}
}
}
Ok(idom)
}
impl DominatorTreeOp {}
#[must_use]
pub fn index(value: u32) -> Result<usize, DominatorTreeError> {
usize::try_from(value).map_err(|_| DominatorTreeError::NodeIndexOverflow)
}
#[must_use]
pub fn intersect(
mut left: u32,
mut right: u32,
idom: &[u32],
order: &[u32],
) -> Result<u32, DominatorTreeError> {
while left != right {
while order[index(left)?] > order[index(right)?] {
left = idom[index(left)?];
}
while order[index(right)?] > order[index(left)?] {
right = idom[index(right)?];
}
}
Ok(left)
}
pub const LAWS: &[AlgebraicLaw] = &[AlgebraicLaw::Bounded {
lo: 0,
hi: u32::MAX,
}];
#[must_use]
pub fn predecessors(
node_count: usize,
offsets: &[u32],
successors: &[u32],
) -> Result<Vec<Vec<u32>>, DominatorTreeError> {
let mut preds = vec![Vec::new(); node_count];
for node in 0..node_count {
for &succ in &successors[index(offsets[node])?..index(offsets[node + 1])?] {
preds[index(succ)?]
.push(u32::try_from(node).map_err(|_| DominatorTreeError::NodeIndexOverflow)?);
}
}
Ok(preds)
}
#[must_use]
pub fn reverse_postorder(
entry: usize,
offsets: &[u32],
successors: &[u32],
) -> Result<Vec<usize>, DominatorTreeError> {
let mut seen = vec![false; offsets.len() - 1];
let mut stack = Vec::with_capacity(offsets.len());
stack.push((entry, false));
let mut postorder = Vec::with_capacity(offsets.len().saturating_sub(1));
while let Some((node, expanded)) = stack.pop() {
if expanded {
postorder.push(node);
continue;
}
if seen[node] {
continue;
}
seen[node] = true;
stack.push((node, true));
let start = index(offsets[node])?;
let end = index(offsets[node + 1])?;
for &succ in successors[start..end].iter().rev() {
let succ_index = index(succ)?;
if !seen[succ_index] {
stack.push((succ_index, false));
}
}
}
postorder.reverse();
Ok(postorder)
}
pub const UNREACHABLE: u32 = u32::MAX;
#[must_use]
pub fn validate_graph(
node_count: usize,
offsets: &[u32],
successors: &[u32],
) -> Result<(), DominatorTreeError> {
let mut previous = 0usize;
for &offset in offsets {
let current = index(offset)?;
if current < previous || current > successors.len() {
return Err(DominatorTreeError::InvalidOffset);
}
previous = current;
}
for &successor in successors {
if index(successor)? >= node_count {
return Err(DominatorTreeError::InvalidSuccessor {
successor,
node_count,
});
}
}
Ok(())
}
pub const WORKGROUP_SIZE: [u32; 3] = [64, 1, 1];
#[cfg(test)]
mod ir_program_tests {
use super::*;
#[test]
fn relax_step_program_validates() {
let prog = relax_step_program("idom", "po", "preds", "rpo", "cf");
let errors = crate::validate::validate::validate(&prog);
assert!(errors.is_empty(), "dominator IR must validate: {errors:?}");
}
#[test]
fn relax_step_program_wire_round_trips() {
let prog = relax_step_program("idom", "po", "preds", "rpo", "cf");
let bytes = prog
.to_wire()
.expect("Fix: serialize; restore this invariant before continuing.");
let decoded = Program::from_wire(&bytes)
.expect("Fix: decode; restore this invariant before continuing.");
assert_eq!(decoded.buffers().len(), 5);
}
#[test]
fn idom_undefined_is_u32_max() {
assert_eq!(IDOM_UNDEFINED, u32::MAX);
}
}