use crate::ir::DataType;
use crate::ops::{AlgebraicLaw, Backend, IntrinsicDescriptor, OpSpec};
/// Operand types accepted by `workgroup.state_machine`: three `u32` values.
///
/// NOTE(review): the role of each operand is not visible in this file — confirm
/// against the kernel / CPU implementation before relying on positional meaning.
pub const INPUTS: &[DataType] = &[
    DataType::U32,
    DataType::U32,
    DataType::U32,
];
/// Result types produced by `workgroup.state_machine`: two `u32` values.
pub const OUTPUTS: &[DataType] = &[
    DataType::U32,
    DataType::U32,
];
/// Algebraic laws declared for this op — intentionally none.
pub const LAWS: &[AlgebraicLaw] = &[];
/// Operation descriptor for the `workgroup.state_machine` intrinsic.
///
/// Built from the [`INPUTS`] / [`OUTPUTS`] / [`LAWS`] signature, gated to the
/// WGSL backend via [`wgsl_only`], and paired with an [`IntrinsicDescriptor`]
/// naming the `workgroup_state_machine_kernel` entry point, the
/// `workgroup-sram-uniform-lookup` tag, and
/// `crate::ops::cpu_op::structured_intrinsic_cpu`.
///
/// NOTE(review): the exact meaning of each `OpSpec::intrinsic` /
/// `IntrinsicDescriptor::new` argument is defined outside this file — confirm
/// there (e.g. whether the CPU function is a reference/fallback path).
pub const SPEC: OpSpec = OpSpec::intrinsic(
    "workgroup.state_machine",
    INPUTS,
    OUTPUTS,
    LAWS,
    wgsl_only,
    IntrinsicDescriptor::new(
        "workgroup_state_machine_kernel",
        "workgroup-sram-uniform-lookup",
        crate::ops::cpu_op::structured_intrinsic_cpu,
    ),
);
/// One row of a state-machine transition table.
///
/// When the machine is in `state` and receives `event`, this row moves it to
/// `next_state` and appends `action` to the machine's action log.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Transition {
    /// State this row applies to.
    pub state: u32,
    /// Event that triggers this row.
    pub event: u32,
    /// State entered when the row fires.
    pub next_state: u32,
    /// Action code recorded when the row fires.
    pub action: u32,
}
/// Status code returned by a successful state-machine step.
///
/// Discriminants are pinned explicitly (`0`, `1`, `2`) so a variant can be
/// converted to an integer with `as`. In the code visible alongside this type,
/// stepping only ever yields `Ok(StateMachineStatus::Ok)`; failures are
/// reported through `StateMachineError` instead.
/// NOTE(review): `NoMatch` / `Divergent` presumably mirror a kernel-side status
/// encoding — confirm before removing them.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum StateMachineStatus {
    /// The step matched a transition and was applied.
    Ok = 0,
    /// No transition matched (numeric code 1).
    NoMatch = 1,
    /// Lanes disagreed on the event (numeric code 2).
    Divergent = 2,
}
/// Errors reported when building or stepping a `WorkgroupStateMachine`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StateMachineError {
    /// No transition row matches the current state and the supplied event.
    NoMatch,
    /// The lanes passed to a uniform step did not all carry the same event.
    Divergent,
    /// The transition table was empty, or a uniform step received no lane
    /// events at all (both conditions share this variant).
    EmptyTable,
}

// Human-readable messages so the error can be printed and used as a
// `std::error::Error` (e.g. boxed into `Box<dyn Error>` with `?`).
impl std::fmt::Display for StateMachineError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::NoMatch => "no transition matches the current state and event",
            Self::Divergent => "lane events diverge; a uniform step requires identical events",
            Self::EmptyTable => "empty transition table or empty lane event list",
        };
        f.write_str(message)
    }
}

impl std::error::Error for StateMachineError {}
/// State machine driven by a fixed table of [`Transition`] rows.
///
/// Tracks the current state and an append-only log of the action codes emitted
/// by successful steps. The public constructors reject an empty table.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct WorkgroupStateMachine {
    /// Transition rows, searched in order; the first matching row wins.
    transitions: Vec<Transition>,
    /// State the machine is currently in.
    current: u32,
    /// Action codes recorded by successful steps, oldest first.
    actions: Vec<u32>,
}
impl WorkgroupStateMachine {
    /// Builds a machine over `transitions`, starting in state `initial`.
    ///
    /// # Errors
    ///
    /// Returns [`StateMachineError::EmptyTable`] when `transitions` is empty.
    pub fn new(transitions: Vec<Transition>, initial: u32) -> Result<Self, StateMachineError> {
        match Self::try_new(transitions, initial) {
            Some(machine) => Ok(machine),
            None => Err(StateMachineError::EmptyTable),
        }
    }

    /// Like [`Self::new`], but signals an empty transition table with `None`.
    ///
    /// The action log starts out empty.
    pub fn try_new(transitions: Vec<Transition>, initial: u32) -> Option<Self> {
        if transitions.is_empty() {
            None
        } else {
            Some(Self {
                transitions,
                current: initial,
                actions: Vec::new(),
            })
        }
    }

    /// Returns the state the machine is currently in.
    #[must_use]
    pub fn state(&self) -> u32 {
        self.current
    }

    /// Returns the action codes recorded since construction or the last
    /// [`Self::reset`], oldest first.
    #[must_use]
    pub fn actions(&self) -> &[u32] {
        self.actions.as_slice()
    }

    /// Applies `event` to the current state.
    ///
    /// Rows are searched in table order and the first row whose `state` and
    /// `event` both match is applied. Applying a row moves to its `next_state`
    /// and appends its `action` to the log. On success the result is always
    /// `Ok(StateMachineStatus::Ok)`.
    ///
    /// # Errors
    ///
    /// * [`StateMachineError::NoMatch`] if no row matches. The state and the log
    ///   are left untouched.
    /// * [`StateMachineError::EmptyTable`] if the table is empty. This is a
    ///   defensive check: the constructors in this impl never produce an empty
    ///   table.
    pub fn step(&mut self, event: u32) -> Result<StateMachineStatus, StateMachineError> {
        if self.transitions.is_empty() {
            return Err(StateMachineError::EmptyTable);
        }
        let from = self.current;
        let matched = self
            .transitions
            .iter()
            .copied()
            .find(|row| row.state == from && row.event == event);
        match matched {
            Some(row) => {
                self.current = row.next_state;
                self.actions.push(row.action);
                Ok(StateMachineStatus::Ok)
            }
            None => Err(StateMachineError::NoMatch),
        }
    }

    /// Applies one event shared by every lane in `lane_events`.
    ///
    /// All entries must be equal. When they are, the machine performs a single
    /// [`Self::step`] with that event.
    ///
    /// # Errors
    ///
    /// * [`StateMachineError::EmptyTable`] if `lane_events` is empty.
    ///   NOTE(review): this reuses the empty-table variant for an empty lane
    ///   list. It is kept as-is because callers may match on it.
    /// * [`StateMachineError::Divergent`] if any lane carries a different event.
    ///   The machine is left untouched.
    /// * Otherwise, whatever [`Self::step`] returns.
    pub fn step_uniform(
        &mut self,
        lane_events: &[u32],
    ) -> Result<StateMachineStatus, StateMachineError> {
        let (&leader, followers) = lane_events
            .split_first()
            .ok_or(StateMachineError::EmptyTable)?;
        if followers.iter().all(|&lane| lane == leader) {
            self.step(leader)
        } else {
            Err(StateMachineError::Divergent)
        }
    }

    /// Moves the machine to `state` and clears the action log.
    ///
    /// The transition table is kept.
    pub fn reset(&mut self, state: u32) {
        self.actions.clear();
        self.current = state;
    }
}
/// Backend predicate used by `SPEC`: accepts only [`Backend::Wgsl`].
pub fn wgsl_only(backend: &Backend) -> bool {
    matches!(*backend, Backend::Wgsl)
}