use crate::ops::AlgebraicLaw;
use crate::ir::transform::compiler::{U32X2_OUTPUTS, U32X4_INPUTS};
use crate::lower::wgsl::compiler::wgsl_backend;
use crate::ops::{IntrinsicDescriptor, OpSpec};
use thiserror::Error;
/// Returns the WGSL text of `dataflow_fixpoint.wgsl`.
///
/// The shader file is embedded into the binary at compile time via
/// `include_str!`, so calling this performs no I/O at runtime.
#[must_use]
pub const fn source() -> &'static str {
    const DATAFLOW_FIXPOINT_WGSL: &str =
        include_str!("../../../lower/wgsl/compiler/dataflow_fixpoint.wgsl");
    DATAFLOW_FIXPOINT_WGSL
}
/// Iterates a bit-set dataflow problem over a CSR-encoded graph until no
/// state word changes.
///
/// Each node owns one `u32` state word. During a sweep every node, in index
/// order, ORs its current state with its transfer mask and joins (bitwise OR)
/// that value into the state of each of its successors. A node's own state is
/// never modified by its own transfer mask; only its successors receive it.
/// Updates are applied in place, so later nodes in the same sweep observe
/// changes made by earlier ones.
///
/// # Parameters
///
/// * `initial_state` - starting state word per node; its length defines the
///   node count.
/// * `transfer` - one mask per node, ORed into what the node propagates.
/// * `successor_offsets` - CSR offsets, `node_count + 1` entries; node `i`
///   owns `successors[offsets[i]..offsets[i + 1]]`.
/// * `successors` - flattened successor node indices.
/// * `max_iterations` - upper bound on the number of sweeps.
///
/// # Returns
///
/// The final state together with the number of sweeps executed, counting the
/// last sweep that observed no change.
///
/// # Errors
///
/// * Any error produced by [`validate_graph`] for malformed input.
/// * [`DataflowFixpointError::OffsetOverflow`] /
///   [`DataflowFixpointError::NodeIndexOverflow`] if a `u32` cannot be
///   converted to `usize`.
/// * [`DataflowFixpointError::DidNotConverge`] if no change-free sweep occurs
///   within `max_iterations` sweeps. Convergence is only recognised by a sweep
///   that changes nothing, so a cap of `0` always yields this error, and
///   reaching the fixed point on the very last permitted sweep does too.
pub fn compute_fixpoint(
    initial_state: &[u32],
    transfer: &[u32],
    successor_offsets: &[u32],
    successors: &[u32],
    max_iterations: u32,
) -> Result<FixpointResult, DataflowFixpointError> {
    validate_graph(initial_state.len(), transfer, successor_offsets, successors)?;

    let widen_offset =
        |raw: u32| usize::try_from(raw).map_err(|_| DataflowFixpointError::OffsetOverflow);

    let mut facts = initial_state.to_vec();
    let mut sweeps = 0u32;
    while sweeps < max_iterations {
        sweeps += 1;
        let mut dirty = false;

        // Adjacent offset pairs delimit each node's successor range; after
        // validation there is exactly one pair per node.
        for (node, bounds) in successor_offsets.windows(2).enumerate() {
            // Snapshot the outgoing value before touching any successor so a
            // self-edge does not feed back into this node's own propagation.
            let outgoing = facts[node] | transfer[node];
            let first = widen_offset(bounds[0])?;
            let last = widen_offset(bounds[1])?;

            for &target in &successors[first..last] {
                let slot = usize::try_from(target)
                    .map_err(|_| DataflowFixpointError::NodeIndexOverflow)?;
                let merged = facts[slot] | outgoing;
                if merged != facts[slot] {
                    facts[slot] = merged;
                    dirty = true;
                }
            }
        }

        if !dirty {
            return Ok(FixpointResult {
                state: facts,
                iterations: sweeps,
            });
        }
    }

    Err(DataflowFixpointError::DidNotConverge { max_iterations })
}
/// Failures reported by `validate_graph` and `compute_fixpoint`.
///
/// Every display message starts with a stable identifier and ends with a
/// `Fix:` hint describing how to repair the input.
#[derive(Debug, Clone, PartialEq, Eq, Error)]
pub enum DataflowFixpointError {
/// The transfer slice length differs from the node count.
#[error("DataflowTransferLength: expected {expected} transfer entries, got {got}. Fix: emit one transfer mask per CFG node.")]
TransferLength {
/// Required number of transfer entries (the node count).
expected: usize,
/// Number of transfer entries actually supplied.
got: usize,
},
/// The offsets slice does not hold `node_count + 1` entries.
#[error("DataflowOffsetLength: expected {expected} offsets, got {got}. Fix: emit node_count + 1 CSR offsets.")]
OffsetLength {
/// Required number of offsets.
expected: usize,
/// Number of offsets actually supplied.
got: usize,
},
/// An offset is smaller than the one before it, or exceeds the length of
/// the successors slice.
#[error("DataflowInvalidOffset: CSR offsets must be monotone and within successors. Fix: rebuild successor_offsets.")]
InvalidOffset,
/// A `u32` offset could not be converted to `usize`.
#[error("DataflowOffsetOverflow: CSR offset cannot fit usize. Fix: split the graph before dispatch.")]
OffsetOverflow,
/// A `u32` successor id could not be converted to `usize`.
#[error("DataflowNodeIndexOverflow: node id cannot fit usize. Fix: split the graph before dispatch.")]
NodeIndexOverflow,
/// A successor entry names a node index that is not below the node count.
#[error("DataflowInvalidSuccessor: successor {successor} outside node_count {node_count}. Fix: validate CFG edge endpoints.")]
InvalidSuccessor {
/// The offending successor id as it appeared in the input.
successor: u32,
/// Number of nodes the graph was validated against.
node_count: usize,
},
/// `compute_fixpoint` exhausted its sweep budget without completing a
/// sweep that left every state word unchanged.
#[error("DataflowDidNotConverge: no fixed point within {max_iterations} iterations. Fix: raise the bounded iteration cap or inspect non-monotone transfer data.")]
DidNotConverge {
/// The sweep cap that was passed to `compute_fixpoint`.
max_iterations: u32,
},
}
/// Field-less marker type for the dataflow fixpoint operation; its associated
/// `SPEC` constant carries the op's registration data.
#[derive(Debug, Default, Clone, Copy)]
pub struct DataflowFixpointOp;
/// Successful outcome of `compute_fixpoint`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FixpointResult {
/// One state word per node, as it stood after the final sweep.
pub state: Vec<u32>,
/// Number of sweeps executed, including the last one that changed nothing.
pub iterations: u32,
}
impl DataflowFixpointOp {
/// Operation specification for `compiler_primitives.dataflow_fixpoint`.
///
/// Built with `OpSpec::intrinsic` from the shared `U32X4_INPUTS` /
/// `U32X2_OUTPUTS` signatures, this module's `LAWS`, the `wgsl_backend`
/// lowering and an `IntrinsicDescriptor`.
// NOTE(review): the meaning of each `OpSpec::intrinsic` /
// `IntrinsicDescriptor::new` argument is defined outside this file; the
// names suggest four u32 inputs and two u32 outputs — confirm there.
pub const SPEC: OpSpec = OpSpec::intrinsic(
"compiler_primitives.dataflow_fixpoint",
U32X4_INPUTS,
U32X2_OUTPUTS,
LAWS,
wgsl_backend,
IntrinsicDescriptor::new(
"compiler_primitives_dataflow_fixpoint",
"workgroup_change_bit",
crate::ops::cpu_op::structured_intrinsic_cpu,
),
);
}
/// Algebraic laws passed to `DataflowFixpointOp::SPEC`: a single `Bounded`
/// law with `lo: 0` and `hi: u32::MAX`.
pub const LAWS: &[AlgebraicLaw] = &[AlgebraicLaw::Bounded {
lo: 0,
hi: u32::MAX,
}];
/// Checks that a CSR-encoded successor graph is structurally sound for
/// `node_count` nodes.
///
/// Checks run in this order and the first failure is returned:
///
/// 1. `transfer` holds exactly `node_count` entries.
/// 2. `offsets` holds exactly `node_count + 1` entries (saturating).
/// 3. Scanning `offsets` front to back, each value is at least the previous
///    one (starting from `0`) and at most `successors.len()`.
/// 4. Every entry of `successors` is a node index below `node_count`.
///
/// The first offset is not required to be `0`, nor the last to equal
/// `successors.len()`; entries of `successors` outside every node's range are
/// still subject to check 4.
///
/// # Errors
///
/// Returns the [`DataflowFixpointError`] variant matching the first violated
/// check, or an overflow variant if a `u32` does not fit in `usize`.
pub fn validate_graph(
    node_count: usize,
    transfer: &[u32],
    offsets: &[u32],
    successors: &[u32],
) -> Result<(), DataflowFixpointError> {
    let transfer_len = transfer.len();
    if transfer_len != node_count {
        return Err(DataflowFixpointError::TransferLength {
            expected: node_count,
            got: transfer_len,
        });
    }

    let expected_offsets = node_count.saturating_add(1);
    let offsets_len = offsets.len();
    if offsets_len != expected_offsets {
        return Err(DataflowFixpointError::OffsetLength {
            expected: expected_offsets,
            got: offsets_len,
        });
    }

    // Thread the previous offset through the fold as the lower bound for the
    // next one; the first offset is bounded below by zero.
    let edge_count = successors.len();
    offsets
        .iter()
        .try_fold(0usize, |floor, &raw| match usize::try_from(raw) {
            Err(_) => Err(DataflowFixpointError::OffsetOverflow),
            Ok(offset) if offset < floor || offset > edge_count => {
                Err(DataflowFixpointError::InvalidOffset)
            }
            Ok(offset) => Ok(offset),
        })?;

    successors
        .iter()
        .try_for_each(|&successor| match usize::try_from(successor) {
            Err(_) => Err(DataflowFixpointError::NodeIndexOverflow),
            Ok(index) if index < node_count => Ok(()),
            Ok(_) => Err(DataflowFixpointError::InvalidSuccessor {
                successor,
                node_count,
            }),
        })
}
/// Three-component workgroup size `[64, 1, 1]`.
// NOTE(review): presumably this must agree with the `@workgroup_size`
// declared in `dataflow_fixpoint.wgsl` — confirm against the shader.
pub const WORKGROUP_SIZE: [u32; 3] = [64, 1, 1];