use vyre_foundation::ir::{Expr, Node};
const DEFAULT_SUBGROUP_WIDTH: u32 = 32;
const ALPHABET_SIZE: u32 = 256;
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParallelDfaBindings {
pub transitions: &'static str,
pub haystack: &'static str,
pub lane_prefix: &'static str,
pub lane_next: &'static str,
pub out_state_by_lane: &'static str,
pub file_start: &'static str,
pub file_end: &'static str,
pub transition_base: &'static str,
pub initial_state: &'static str,
pub state_count: &'static str,
pub subgroup_width: u32,
}
impl Default for ParallelDfaBindings {
fn default() -> Self {
Self {
transitions: "transitions",
haystack: "haystack",
lane_prefix: "lane_prefix",
lane_next: "lane_next",
out_state_by_lane: "out_state_by_lane",
file_start: "file_start",
file_end: "file_end",
transition_base: "transition_base",
initial_state: "initial_state",
state_count: "state_count",
subgroup_width: DEFAULT_SUBGROUP_WIDTH,
}
}
}
#[must_use]
pub fn dfa_byte_scanner_parallel_composition() -> Vec<Node> {
dfa_byte_scanner_parallel_composition_with(&ParallelDfaBindings::default())
}
#[must_use]
pub fn dfa_byte_scanner_parallel_composition_with(bindings: &ParallelDfaBindings) -> Vec<Node> {
let mut nodes = vec![
Node::let_bind("lane_id", Expr::invocation_local_x()),
Node::let_bind(
"lane_byte_pos",
Expr::add(Expr::var(bindings.file_start), Expr::var("lane_id")),
),
Node::let_bind(
"lane_active",
Expr::lt(Expr::var("lane_byte_pos"), Expr::var(bindings.file_end)),
),
Node::let_bind(
"lane_byte",
Expr::select(
Expr::var("lane_active"),
Expr::cast(
vyre_foundation::ir::DataType::U32,
Expr::load(bindings.haystack, Expr::var("lane_byte_pos")),
),
Expr::u32(0),
),
),
Node::loop_for(
"state",
Expr::u32(0),
Expr::var(bindings.state_count),
vec![Node::store(
bindings.lane_prefix,
table_index("lane_id", bindings.state_count, Expr::var("state")),
Expr::select(
Expr::var("lane_active"),
Expr::load(
bindings.transitions,
Expr::add(
Expr::var(bindings.transition_base),
Expr::add(
Expr::mul(Expr::var("state"), Expr::u32(ALPHABET_SIZE)),
Expr::var("lane_byte"),
),
),
),
Expr::var("state"),
),
)],
),
Node::barrier(),
];
let mut stride = 1;
while stride < bindings.subgroup_width {
append_prefix_stage(&mut nodes, bindings, stride);
stride *= 2;
}
nodes.extend([
Node::store(
bindings.out_state_by_lane,
Expr::var("lane_id"),
Expr::load(
bindings.lane_prefix,
table_index(
"lane_id",
bindings.state_count,
Expr::var(bindings.initial_state),
),
),
),
Node::barrier(),
]);
nodes
}
fn append_prefix_stage(nodes: &mut Vec<Node>, bindings: &ParallelDfaBindings, stride: u32) {
nodes.push(Node::loop_for(
"state",
Expr::u32(0),
Expr::var(bindings.state_count),
vec![
Node::let_bind(
"source_lane",
Expr::select(
Expr::ge(Expr::var("lane_id"), Expr::u32(stride)),
Expr::sub(Expr::var("lane_id"), Expr::u32(stride)),
Expr::u32(0),
),
),
Node::let_bind(
"previous_state",
Expr::subgroup_shuffle(
Expr::load(
bindings.lane_prefix,
table_index("lane_id", bindings.state_count, Expr::var("state")),
),
Expr::var("source_lane"),
),
),
Node::let_bind(
"composed_state",
Expr::select(
Expr::ge(Expr::var("lane_id"), Expr::u32(stride)),
Expr::load(
bindings.lane_prefix,
table_index("lane_id", bindings.state_count, Expr::var("previous_state")),
),
Expr::load(
bindings.lane_prefix,
table_index("lane_id", bindings.state_count, Expr::var("state")),
),
),
),
Node::store(
bindings.lane_next,
table_index("lane_id", bindings.state_count, Expr::var("state")),
Expr::var("composed_state"),
),
],
));
nodes.push(Node::barrier());
nodes.push(Node::loop_for(
"state",
Expr::u32(0),
Expr::var(bindings.state_count),
vec![Node::store(
bindings.lane_prefix,
table_index("lane_id", bindings.state_count, Expr::var("state")),
Expr::load(
bindings.lane_next,
table_index("lane_id", bindings.state_count, Expr::var("state")),
),
)],
));
nodes.push(Node::barrier());
}
fn table_index(lane_var: &str, state_count_var: &str, state: Expr) -> Expr {
Expr::add(
Expr::mul(Expr::var(lane_var), Expr::var(state_count_var)),
state,
)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn parallel_dfa_fragment_has_prefix_barriers_and_output() {
let nodes = dfa_byte_scanner_parallel_composition();
assert!(
nodes
.iter()
.filter(|node| matches!(
node,
Node::Barrier {
ordering: vyre::memory_model::MemoryOrdering::SeqCst
}
))
.count()
>= 2,
"prefix composition must synchronize scratch-table stages"
);
assert!(
stores_buffer(&nodes, "out_state_by_lane"),
"fragment must publish per-lane states"
);
}
fn stores_buffer(nodes: &[Node], name: &str) -> bool {
nodes.iter().any(|node| match node {
Node::Store { buffer, .. } => buffer.as_str() == name,
Node::Block(body) | Node::Loop { body, .. } => stores_buffer(body, name),
Node::Region { body, .. } => stores_buffer(body, name),
Node::If {
then, otherwise, ..
} => stores_buffer(then, name) || stores_buffer(otherwise, name),
_ => false,
})
}
}