use crate::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
use thiserror::Error;
use vyre_spec::AlgebraicLaw;
/// Returns the source string the shader provider has registered for this operation.
///
/// This is a thin pass-through to
/// `crate::transform::compiler::shader_provider::source` using the key
/// `"dataflow_fixpoint"`; whatever that lookup yields (including `None`) is
/// returned unchanged.
#[must_use]
pub fn source() -> Option<&'static str> {
    // Key under which this operation is looked up in the shader provider.
    const SHADER_KEY: &str = "dataflow_fixpoint";
    crate::transform::compiler::shader_provider::source(SHADER_KEY)
}
/// Builds the IR [`Program`] for a single relaxation step of the dataflow fixpoint.
///
/// Every argument is a buffer *name*. The buffers are declared as `u32`
/// storage buffers at bindings 0 through 4, in argument order:
///
/// * `state` (read-write) — one word per node.
/// * `transfer` (read-only) — one word per node, ORed with the node's state.
/// * `successor_offsets` (read-only) — entries `node` and `node + 1` delimit
///   the node's range inside `successors`.
/// * `successors` (read-only) — successor node indices.
/// * `changed_flag` (read-write, declared with count 1) — slot 0 is atomically
///   exchanged with 1 whenever a successor's state word gains bits.
///
/// Each invocation uses its axis-0 invocation id as the node index, computes
/// `state[node] | transfer[node]`, and ORs that value into the state word of
/// every successor in the node's range. The program is wrapped with a
/// workgroup size of `[64, 1, 1]`.
///
/// NOTE(review): the generated body contains no comparison of the invocation
/// id against a node count, and the write to `state` is a plain store rather
/// than an atomic — presumably the dispatcher/validator accounts for both;
/// confirm against the callers.
#[must_use]
pub fn relax_step_program(
    state: &str,
    transfer: &str,
    successor_offsets: &str,
    successors: &str,
    changed_flag: &str,
) -> Program {
    // --- Per-node prologue: identify the node and compute what it propagates.
    let bind_node = Node::let_bind("node", Expr::InvocationId { axis: 0 });
    let bind_state = Node::let_bind("state_n", Expr::load(state, Expr::var("node")));
    let bind_transfer = Node::let_bind("transfer_n", Expr::load(transfer, Expr::var("node")));
    let bind_propagated = Node::let_bind(
        "propagated",
        Expr::bitor(Expr::var("state_n"), Expr::var("transfer_n")),
    );

    // --- Successor range [start, end) of this node.
    let bind_start = Node::let_bind("start", Expr::load(successor_offsets, Expr::var("node")));
    let next_node = Expr::add(Expr::var("node"), Expr::u32(1));
    let bind_end = Node::let_bind("end", Expr::load(successor_offsets, next_node));

    // --- Work done for each successor slot `i`.
    let per_successor = vec![
        Node::let_bind("succ", Expr::load(successors, Expr::var("i"))),
        Node::let_bind("old", Expr::load(state, Expr::var("succ"))),
        Node::let_bind(
            "new",
            Expr::bitor(Expr::var("old"), Expr::var("propagated")),
        ),
        // Only write back (and raise the flag) when the join adds bits.
        Node::if_then(
            Expr::ne(Expr::var("new"), Expr::var("old")),
            vec![
                Node::store(state, Expr::var("succ"), Expr::var("new")),
                // The exchanged-out value is bound but never read.
                Node::let_bind(
                    "chg",
                    Expr::atomic_exchange(changed_flag, Expr::u32(0), Expr::u32(1)),
                ),
            ],
        ),
    ];
    let relax_successors =
        Node::loop_for("i", Expr::var("start"), Expr::var("end"), per_successor);

    let body = vec![
        bind_node,
        bind_state,
        bind_transfer,
        bind_propagated,
        bind_start,
        bind_end,
        relax_successors,
    ];

    let buffers = vec![
        BufferDecl::storage(state, 0, BufferAccess::ReadWrite, DataType::U32),
        BufferDecl::storage(transfer, 1, BufferAccess::ReadOnly, DataType::U32),
        BufferDecl::storage(successor_offsets, 2, BufferAccess::ReadOnly, DataType::U32),
        BufferDecl::storage(successors, 3, BufferAccess::ReadOnly, DataType::U32),
        BufferDecl::storage(changed_flag, 4, BufferAccess::ReadWrite, DataType::U32)
            .with_count(1),
    ];

    Program::wrapped(buffers, [64, 1, 1], body)
}
/// Computes the dataflow fixpoint on the CPU by repeated in-place relaxation.
///
/// The graph is given in CSR form: the successors of node `n` are
/// `successors[successor_offsets[n]..successor_offsets[n + 1]]`. During each
/// sweep every node ORs `state[n] | transfer[n]` into the state word of each
/// of its successors. Updates are applied in place, so nodes visited later in
/// a sweep observe bits written earlier in that same sweep.
///
/// On success the returned [`FixpointResult`] carries the final state words
/// and the number of sweeps executed, counting the last sweep that changed
/// nothing.
///
/// # Errors
///
/// * Any error reported by [`validate_graph`] for the given inputs.
/// * [`DataflowFixpointError::DidNotConverge`] when every one of the
///   `max_iterations` sweeps still changed some state word. This includes
///   `max_iterations == 0`, where no sweep runs at all.
pub fn compute_fixpoint(
    initial_state: &[u32],
    transfer: &[u32],
    successor_offsets: &[u32],
    successors: &[u32],
    max_iterations: u32,
) -> Result<FixpointResult, DataflowFixpointError> {
    // After validation there is one transfer word per node, `node_count + 1`
    // monotone offsets that stay inside `successors`, and every successor id
    // is below `node_count`. All indexing below relies on those guarantees.
    validate_graph(initial_state.len(), transfer, successor_offsets, successors)?;

    let mut state = initial_state.to_vec();
    for iteration in 0..max_iterations {
        let mut changed = false;
        // `windows(2)` yields exactly one `[start, end]` offset pair per node.
        for (node, (bounds, &mask)) in successor_offsets.windows(2).zip(transfer).enumerate() {
            let propagated = state[node] | mask;
            let start = usize::try_from(bounds[0])
                .map_err(|_| DataflowFixpointError::OffsetOverflow)?;
            let end = usize::try_from(bounds[1])
                .map_err(|_| DataflowFixpointError::OffsetOverflow)?;
            for &successor in &successors[start..end] {
                let successor_index = usize::try_from(successor)
                    .map_err(|_| DataflowFixpointError::NodeIndexOverflow)?;
                let joined = state[successor_index] | propagated;
                if joined != state[successor_index] {
                    state[successor_index] = joined;
                    changed = true;
                }
            }
        }
        if !changed {
            // `iteration < max_iterations <= u32::MAX`, so `+ 1` cannot overflow.
            return Ok(FixpointResult {
                state,
                iterations: iteration + 1,
            });
        }
    }
    Err(DataflowFixpointError::DidNotConverge { max_iterations })
}
/// Errors reported by `validate_graph` and `compute_fixpoint`.
///
/// Every message starts with a stable `Dataflow…` tag and ends with a
/// `Fix:` hint. The enum is `#[non_exhaustive]`, so matches outside this
/// crate need a wildcard arm.
#[derive(Debug, Clone, PartialEq, Eq, Error)]
#[non_exhaustive]
pub enum DataflowFixpointError {
/// The transfer slice length differs from the node count.
#[error(
"DataflowTransferLength: expected {expected} transfer entries, got {got}. Fix: emit one transfer mask per CFG node."
)]
TransferLength {
expected: usize,
got: usize,
},
/// The offsets slice length differs from `node_count + 1`.
#[error(
"DataflowOffsetLength: expected {expected} offsets, got {got}. Fix: emit node_count + 1 CSR offsets."
)]
OffsetLength {
expected: usize,
got: usize,
},
/// An offset is smaller than the one before it, or larger than the
/// successors slice length.
#[error(
"DataflowInvalidOffset: CSR offsets must be monotone and within successors. Fix: rebuild successor_offsets."
)]
InvalidOffset,
/// A `u32` offset could not be converted to `usize`.
#[error(
"DataflowOffsetOverflow: CSR offset cannot fit usize. Fix: split the graph before dispatch."
)]
OffsetOverflow,
/// A `u32` successor id could not be converted to `usize`.
#[error(
"DataflowNodeIndexOverflow: node id cannot fit usize. Fix: split the graph before dispatch."
)]
NodeIndexOverflow,
/// A successor id is greater than or equal to the node count.
#[error(
"DataflowInvalidSuccessor: successor {successor} outside node_count {node_count}. Fix: validate CFG edge endpoints."
)]
InvalidSuccessor {
successor: u32,
node_count: usize,
},
/// `compute_fixpoint` used up `max_iterations` sweeps without reaching a
/// sweep that left the state unchanged.
#[error(
"DataflowDidNotConverge: no fixed point within {max_iterations} iterations. Fix: raise the bounded iteration cap or inspect non-monotone transfer data."
)]
DidNotConverge {
max_iterations: u32,
},
}
/// Unit struct (no fields) named after the dataflow fixpoint operation.
///
/// Its inherent `impl` in this file defines no items.
// NOTE(review): presumably a marker/handle used to register or refer to the
// operation elsewhere in the crate — confirm against its users.
#[derive(Debug, Default, Clone, Copy)]
pub struct DataflowFixpointOp;
/// Successful outcome of `compute_fixpoint`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FixpointResult {
/// Final state words, one per node, after the last sweep.
pub state: Vec<u32>,
/// Number of sweeps executed, including the final sweep that changed nothing.
pub iterations: u32,
}
// Inherent impl with no items: this file exposes the operation through the
// free functions of the module rather than through methods on the type.
impl DataflowFixpointOp {}
/// Algebraic laws declared for this operation.
///
/// The slice holds a single `AlgebraicLaw::Bounded` entry covering the full
/// `u32` range (`lo: 0`, `hi: u32::MAX`).
// NOTE(review): presumably consumed by a law checker elsewhere — confirm what
// `Bounded` is taken to constrain (likely the produced state words).
pub const LAWS: &[AlgebraicLaw] = &[AlgebraicLaw::Bounded {
lo: 0,
hi: u32::MAX,
}];
/// Checks that a CSR successor graph is well-formed for `node_count` nodes.
///
/// The checks run in this order and the first failure is returned:
///
/// 1. `transfer` has exactly `node_count` entries.
/// 2. `offsets` has exactly `node_count + 1` entries.
/// 3. Every offset is at least as large as the previous one (starting from 0)
///    and no larger than `successors.len()`.
/// 4. Every entry of `successors` is a node index below `node_count`.
///
/// # Errors
///
/// * [`DataflowFixpointError::TransferLength`] — check 1 failed.
/// * [`DataflowFixpointError::OffsetLength`] — check 2 failed.
/// * [`DataflowFixpointError::OffsetOverflow`] — an offset does not fit `usize`.
/// * [`DataflowFixpointError::InvalidOffset`] — check 3 failed.
/// * [`DataflowFixpointError::NodeIndexOverflow`] — a successor id does not
///   fit `usize`.
/// * [`DataflowFixpointError::InvalidSuccessor`] — check 4 failed.
pub fn validate_graph(
    node_count: usize,
    transfer: &[u32],
    offsets: &[u32],
    successors: &[u32],
) -> Result<(), DataflowFixpointError> {
    if transfer.len() != node_count {
        return Err(DataflowFixpointError::TransferLength {
            expected: node_count,
            got: transfer.len(),
        });
    }

    // Saturating so that an absurd `node_count` cannot overflow the comparison.
    let expected_offsets = node_count.saturating_add(1);
    if offsets.len() != expected_offsets {
        return Err(DataflowFixpointError::OffsetLength {
            expected: expected_offsets,
            got: offsets.len(),
        });
    }

    // Offsets must never decrease and must stay inside `successors`.
    let mut previous = 0usize;
    for &offset in offsets {
        let current =
            usize::try_from(offset).map_err(|_| DataflowFixpointError::OffsetOverflow)?;
        if current < previous || current > successors.len() {
            return Err(DataflowFixpointError::InvalidOffset);
        }
        previous = current;
    }

    // Every edge endpoint must name an existing node.
    for &successor in successors {
        let index =
            usize::try_from(successor).map_err(|_| DataflowFixpointError::NodeIndexOverflow)?;
        if index >= node_count {
            return Err(DataflowFixpointError::InvalidSuccessor {
                successor,
                node_count,
            });
        }
    }

    Ok(())
}
/// Workgroup dimensions for this operation: 64 × 1 × 1.
///
/// The same `[64, 1, 1]` value is what `relax_step_program` passes to
/// `Program::wrapped`.
pub const WORKGROUP_SIZE: [u32; 3] = [64, 1, 1];
#[cfg(test)]
mod ir_program_tests {
    use super::*;

    /// Descriptive buffer names, in the order `relax_step_program` takes them.
    const BUFFER_NAMES: [&str; 5] = [
        "state",
        "transfer",
        "successor_offsets",
        "successors",
        "changed_flag",
    ];

    /// Builds the relax-step program using [`BUFFER_NAMES`].
    fn named_program() -> Program {
        let [state, transfer, offsets, successors, changed] = BUFFER_NAMES;
        relax_step_program(state, transfer, offsets, successors, changed)
    }

    /// The generated IR passes the crate's program validator without errors.
    #[test]
    fn relax_step_program_validates() {
        let program = named_program();
        let errors = crate::validate::validate::validate(&program);
        assert!(errors.is_empty(), "dataflow IR must validate: {errors:?}");
    }

    /// Encoding to wire bytes and decoding again keeps the buffer count and
    /// the workgroup size.
    #[test]
    fn relax_step_program_wire_round_trips() {
        let program = relax_step_program("s", "t", "o", "sc", "cf");
        let encoded = program
            .to_wire()
            .expect("Fix: serialize; restore this invariant before continuing.");
        let decoded = Program::from_wire(&encoded)
            .expect("Fix: decode; restore this invariant before continuing.");
        assert_eq!(decoded.buffers().len(), 5);
        assert_eq!(decoded.workgroup_size(), [64, 1, 1]);
    }

    /// Buffers are declared in the same order as the function's arguments.
    #[test]
    fn relax_step_program_declares_five_buffers_in_csr_order() {
        let program = named_program();
        let declared: Vec<&str> = program
            .buffers()
            .iter()
            .map(|buffer| buffer.name())
            .collect();
        assert_eq!(declared, BUFFER_NAMES.to_vec());
    }
}