// Invocations per workgroup; also the stride between consecutive steps in state_machine_events.
const STATE_MACHINE_LANES = 64u;
// Upper bound on how many transition-table entries the kernel scans per step.
const STATE_MACHINE_TRANSITIONS = 8u;

// StateMachineResult.status values.
const STATE_MACHINE_OK = 0u;          // a transition matched and was applied
const STATE_MACHINE_NO_MATCH = 1u;    // all lanes agreed on the event, but no transition matched
const STATE_MACHINE_DIVERGENT = 2u;   // at least one lane's event differed from lane 0's

// StateMachineResult.action when no transition was applied.
const STATE_MACHINE_EMPTY_ACTION = 0xffffffffu;
// One row of the transition table. When the machine is in `state` and the step's event equals
// `event`, the kernel moves to `next_state` and reports `action` in the step's result.
struct Transition {
state: u32,
event: u32,
next_state: u32,
action: u32,
}
// Run configuration; the kernel reads only element 0 of the params buffer.
struct StateMachineParams {
// State the machine starts in before step 0.
initial_state: u32,
// Number of valid rows in the transition table (scanning is additionally capped at
// STATE_MACHINE_TRANSITIONS).
transition_count: u32,
// Number of steps to run; one result is written per step.
step_count: u32,
// Not read by the kernel. NOTE(review): presumably padding / future use — confirm.
reserved: u32,
}
// Per-step outcome written by lane 0.
struct StateMachineResult {
// One of STATE_MACHINE_OK, STATE_MACHINE_NO_MATCH, STATE_MACHINE_DIVERGENT.
status: u32,
// Machine state after this step (unchanged unless status is STATE_MACHINE_OK).
state: u32,
// Action of the applied transition, or STATE_MACHINE_EMPTY_ACTION if none was applied.
action: u32,
// Number of lanes whose event for this step differed from lane 0's event.
divergent_count: u32,
}
// Event stream: the event for (step, lane) is at index step * STATE_MACHINE_LANES + lane.
@group(0) @binding(0) var<storage, read> state_machine_events: array<u32>;
// Transition table; only the first min(transition_count, STATE_MACHINE_TRANSITIONS) rows are scanned.
@group(0) @binding(1) var<storage, read> state_machine_transitions: array<Transition>;
// Run configuration; only element 0 is read.
@group(0) @binding(2) var<storage, read> state_machine_params: array<StateMachineParams>;
// Output: one StateMachineResult per step, indexed by step.
@group(0) @binding(3) var<storage, read_write> state_machine_results: array<StateMachineResult>;
// Current machine state; written and read only by lane 0.
var<workgroup> state_machine_current: u32;
// Lane 0's event for the current step, broadcast to the other lanes for comparison.
var<workgroup> state_machine_event: u32;
// Count of lanes whose event differs from state_machine_event in the current step.
var<workgroup> state_machine_divergent: atomic<u32>;
// Runs one state machine cooperatively across a workgroup of STATE_MACHINE_LANES invocations.
//
// For each step s in [0, params.step_count):
//   * every lane reads its event at state_machine_events[s * STATE_MACHINE_LANES + lane];
//   * lanes whose event differs from lane 0's event are counted in state_machine_divergent;
//   * if no lane diverged, lane 0 applies the first transition (among the first
//     min(transition_count, STATE_MACHINE_TRANSITIONS) rows) matching (current state, event)
//     and reports STATE_MACHINE_OK with its action, or STATE_MACHINE_NO_MATCH with
//     STATE_MACHINE_EMPTY_ACTION if none matched;
//   * if any lane diverged, the state is left unchanged and STATE_MACHINE_DIVERGENT is reported;
//   * lane 0 writes (status, resulting state, action, divergent lane count) to
//     state_machine_results[s].
// NOTE(review): no explicit bounds checks on the events/results buffers; assumes the host sizes
// them for params.step_count steps — confirm.
//
// The workgroup size is tied to STATE_MACHINE_LANES so the event stride above cannot drift
// from the actual number of invocations.
@compute @workgroup_size(STATE_MACHINE_LANES)
fn workgroup_state_machine_kernel(@builtin(local_invocation_id) local_id: vec3<u32>) {
    let lane = local_id.x;
    // Every invocation loads the same element, so params.step_count is the same loop bound for
    // the whole workgroup; the workgroupBarrier() calls inside the step loop depend on that.
    let params = state_machine_params[0];
    // Equivalent to stopping when `i >= transition_count || i >= STATE_MACHINE_TRANSITIONS`.
    let transition_limit = min(params.transition_count, STATE_MACHINE_TRANSITIONS);

    if (lane == 0u) {
        state_machine_current = params.initial_state;
    }
    workgroupBarrier();

    // Named step_index to avoid shadowing the WGSL builtin `step()`.
    for (var step_index = 0u; step_index < params.step_count; step_index = step_index + 1u) {
        let lane_event = state_machine_events[step_index * STATE_MACHINE_LANES + lane];

        // Lane 0 publishes its event as the reference and clears the divergence counter.
        if (lane == 0u) {
            state_machine_event = lane_event;
            atomicStore(&state_machine_divergent, 0u);
        }
        workgroupBarrier();

        // Count the lanes whose event disagrees with lane 0's.
        if (lane_event != state_machine_event) {
            _ = atomicAdd(&state_machine_divergent, 1u);
        }
        workgroupBarrier();

        if (lane == 0u) {
            let divergent_count = atomicLoad(&state_machine_divergent);
            var status = STATE_MACHINE_NO_MATCH;
            var action = STATE_MACHINE_EMPTY_ACTION;
            if (divergent_count == 0u) {
                // The first matching row wins; later matches are ignored.
                for (var transition_index = 0u; transition_index < transition_limit; transition_index = transition_index + 1u) {
                    let transition = state_machine_transitions[transition_index];
                    if (transition.state == state_machine_current &&
                        transition.event == state_machine_event) {
                        state_machine_current = transition.next_state;
                        action = transition.action;
                        status = STATE_MACHINE_OK;
                        break;
                    }
                }
            } else {
                status = STATE_MACHINE_DIVERGENT;
            }
            state_machine_results[step_index] =
                StateMachineResult(status, state_machine_current, action, divergent_count);
        }
        // All lanes finish this step before lane 0 overwrites state_machine_event for the next one.
        workgroupBarrier();
    }
}