use vyre::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
use crate::region::wrap_anonymous;
/// Identifier for this op; passed to `wrap_anonymous` when the program is built
/// and to `OpEntry::new` when the op is registered with the harness.
const OP_ID: &str = "vyre-libs::matching::cooperative_dfa";
/// Number of entries per state row in the flattened transition table
/// (the row stride used when indexing `state * ALPHABET_SIZE + byte`).
const ALPHABET_SIZE: u32 = 256;
use crate::scan::dispatch_io::pack_u32_slice as pack_u32;
/// Builds the lane expression used as the shuffle source for one round.
///
/// The resulting expression is `select(local_lane < offset, 0, local_lane - offset)`;
/// the comparison guards the subtraction for lanes whose index is below `offset`
/// (presumably `select` is `(condition, if_true, if_false)` — confirm against the IR docs).
fn correction_lane(local_lane: Expr, offset: u32) -> Expr {
    let below_offset = Expr::lt(local_lane.clone(), Expr::u32(offset));
    let shifted_lane = Expr::sub(local_lane, Expr::u32(offset));
    Expr::select(below_offset, Expr::u32(0), shifted_lane)
}
/// Builds a load from the `transitions` buffer at `state * ALPHABET_SIZE + byte`,
/// i.e. a lookup into a row-major table with `ALPHABET_SIZE` entries per state.
fn transition_expr(transitions: &str, state: Expr, byte: Expr) -> Expr {
    let row_start = Expr::mul(state, Expr::u32(ALPHABET_SIZE));
    let cell = Expr::add(row_start, byte);
    Expr::load(transitions, cell)
}
/// Shared harness fixture: the bytes of `"banana"` widened to `u32`, the DFA
/// compiled for the single pattern `"a"`, and the expected per-position match
/// flags `[0, 1, 0, 1, 0, 1]`.
fn fixture_case() -> (Vec<u32>, super::CompiledDfa, Vec<u32>) {
    let haystack: Vec<u32> = b"banana".iter().copied().map(u32::from).collect();
    let dfa = super::dfa_compile(&[b"a"]);
    let hits = vec![0, 1, 0, 1, 0, 1];
    (haystack, dfa, hits)
}
/// Builds the per-invocation node list for the cooperative DFA scan, ending in a
/// bounds-guarded store of `store_value` into `matches`.
///
/// Emitted nodes, in order:
/// 1. `idx` (invocation id on axis 0), `in_bounds` (`idx < buf_len(input)`), and
///    `state` initialised to `0`.
/// 2. `safe_idx` (the invocation id when in bounds, otherwise `0`) and `byte`
///    loaded from `input[safe_idx]`, followed by one transition of `state` on `byte`.
/// 3. `floor(log2(max(subgroup_size, 1)))` rounds; round `r` binds
///    `forwarded_state_{r}` to a subgroup shuffle of `state` from the lane given
///    by `correction_lane(local_id, 1 << r)` and reassigns `state` to the
///    transition of that forwarded state on this lane's `byte`.
/// 4. `accepting` loaded from `accept_mask[state]`.
/// 5. `if in_bounds { matches[invocation id] = store_value }`.
///
/// `store_value` may refer to the variables bound above (for example `accepting`).
///
/// NOTE(review): each round re-applies this lane's own byte to a state forwarded
/// from a lower lane — confirm this is the intended correction scheme.
#[must_use]
pub fn cooperative_dfa_scan_body_with_store(
    input: &str,
    transitions: &str,
    accept_mask: &str,
    matches: &str,
    subgroup_size: u32,
    store_value: Expr,
) -> Vec<Node> {
    let global_id = Expr::InvocationId { axis: 0 };
    let lane_id = Expr::LocalId { axis: 0 };
    // Clamp to at least one lane so `ilog2` never sees zero.
    let rounds = subgroup_size.max(1).ilog2();

    // Prelude bindings followed by the first transition on this lane's byte.
    let mut nodes = vec![
        Node::let_bind("idx", global_id.clone()),
        Node::let_bind(
            "in_bounds",
            Expr::lt(Expr::var("idx"), Expr::buf_len(input)),
        ),
        Node::let_bind("state", Expr::u32(0)),
        Node::let_bind(
            "safe_idx",
            Expr::select(Expr::var("in_bounds"), global_id.clone(), Expr::u32(0)),
        ),
        Node::let_bind("byte", Expr::load(input, Expr::var("safe_idx"))),
        Node::assign(
            "state",
            transition_expr(transitions, Expr::var("state"), Expr::var("byte")),
        ),
    ];

    // One shuffle + re-transition pair per round, with the lane offset doubling each time.
    nodes.extend((0..rounds).flat_map(|round| {
        let forwarded = format!("forwarded_state_{round}");
        let shuffle = Expr::SubgroupShuffle {
            value: Box::new(Expr::var("state")),
            lane: Box::new(correction_lane(lane_id.clone(), 1u32 << round)),
        };
        [
            Node::let_bind(forwarded.as_str(), shuffle),
            Node::assign(
                "state",
                transition_expr(
                    transitions,
                    Expr::var(forwarded.as_str()),
                    Expr::var("byte"),
                ),
            ),
        ]
    }));

    nodes.push(Node::let_bind(
        "accepting",
        Expr::load(accept_mask, Expr::var("state")),
    ));

    // Only in-bounds invocations write to the output buffer.
    let guarded_store = Node::Store {
        buffer: matches.into(),
        index: global_id,
        value: store_value,
    };
    nodes.push(Node::if_then(Expr::var("in_bounds"), vec![guarded_store]));
    nodes
}
/// Builds the complete cooperative DFA scan [`Program`].
///
/// Buffers are declared at bindings 0–3:
/// * `input` — read-only `u32`, `input_len` elements;
/// * `transitions` — read-only `u32`, `state_count * ALPHABET_SIZE` elements
///   (saturating multiply);
/// * `accept_mask` — read-only `u32`, `state_count` elements;
/// * `matches` — `u32` output, `input_len` elements.
///
/// The workgroup size is `[max(subgroup_size, 1), 1, 1]`, and each in-bounds
/// invocation stores `1` when `accepting != 0`, otherwise `0`.
#[must_use]
pub fn cooperative_dfa_scan(
    input: &str,
    transitions: &str,
    accept_mask: &str,
    matches: &str,
    input_len: u32,
    state_count: u32,
    subgroup_size: u32,
) -> Program {
    // Normalise the accept-mask word to a strict 0/1 flag.
    let match_flag = Expr::select(
        Expr::ne(Expr::var("accepting"), Expr::u32(0)),
        Expr::u32(1),
        Expr::u32(0),
    );
    let kernel = cooperative_dfa_scan_body_with_store(
        input,
        transitions,
        accept_mask,
        matches,
        subgroup_size,
        match_flag,
    );

    let transition_cells = state_count.saturating_mul(ALPHABET_SIZE);
    let buffers = vec![
        BufferDecl::storage(input, 0, BufferAccess::ReadOnly, DataType::U32)
            .with_count(input_len),
        BufferDecl::storage(transitions, 1, BufferAccess::ReadOnly, DataType::U32)
            .with_count(transition_cells),
        BufferDecl::storage(accept_mask, 2, BufferAccess::ReadOnly, DataType::U32)
            .with_count(state_count),
        BufferDecl::output(matches, 3, DataType::U32).with_count(input_len),
    ];

    // A zero subgroup size is clamped to a single lane.
    let workgroup = [subgroup_size.max(1), 1, 1];
    Program::wrapped(buffers, workgroup, vec![wrap_anonymous(OP_ID, kernel)])
}
/// Sequential CPU reference for the DFA scan.
///
/// Walks `input` from state `0`, looking up `transitions[state * alphabet_size + symbol]`
/// at each step, and returns one flag per input symbol: `1` when the state reached
/// after the symbol has a non-zero `accept_mask` entry, otherwise `0`.
///
/// An all-zero vector of `input.len()` elements is returned when the tables are
/// malformed: `alphabet_size` is zero, `transitions.len()` is not exactly
/// `state_count * alphabet_size` (or that product overflows), or `accept_mask`
/// has fewer than `state_count` entries.
///
/// A symbol `>= alphabet_size`, or a current state `>= state_count`, produces a
/// `0` flag for that position and resets the state to `0`.
#[must_use]
#[cfg(test)]
pub(crate) fn cpu_ref_cooperative_dfa(
    input: &[u32],
    transitions: &[u32],
    accept_mask: &[u32],
    state_count: u32,
    alphabet_size: u32,
) -> Vec<u32> {
    let stride = alphabet_size as usize;
    let states = state_count as usize;

    // `checked_mul` yields `None` on overflow, which never equals `Some(len)`.
    let tables_are_valid = stride != 0
        && states.checked_mul(stride) == Some(transitions.len())
        && accept_mask.len() >= states;
    if !tables_are_valid {
        return vec![0; input.len()];
    }

    let mut current = 0u32;
    input
        .iter()
        .map(|&symbol| {
            if symbol < alphabet_size && current < state_count {
                current = transitions[current as usize * stride + symbol as usize];
                // The table may name a state outside the mask; treat that as non-accepting.
                let accepted =
                    matches!(accept_mask.get(current as usize), Some(&flag) if flag != 0);
                u32::from(accepted)
            } else {
                current = 0;
                0
            }
        })
        .collect()
}
/// Harness input fixture: a single case holding the packed input symbols,
/// transition table and accept mask from `fixture_case`, in that order.
fn fixture_inputs() -> Vec<Vec<Vec<u8>>> {
    let (haystack, dfa, _) = fixture_case();
    let buffers = vec![
        pack_u32(&haystack),
        pack_u32(&dfa.transitions),
        pack_u32(&dfa.accept),
    ];
    vec![buffers]
}
/// Harness expected-output fixture: a single case holding the packed expected
/// match flags from `fixture_case`.
fn fixture_expected_output() -> Vec<Vec<Vec<u8>>> {
    let expected_flags = fixture_case().2;
    let outputs = vec![pack_u32(&expected_flags)];
    vec![outputs]
}
// Registers this op with the harness: an `OpEntry` carrying the op id, a
// closure that builds the program for the shared fixture, and the fixture's
// packed inputs and expected output.
inventory::submit! {
crate::harness::OpEntry::new(
OP_ID,
|| {
let (_, compiled, _) = fixture_case();
cooperative_dfa_scan(
"input",
"transitions",
"accept_mask",
"matches",
// input_len: the fixture input "banana" has six symbols.
6,
compiled.state_count,
// subgroup_size used for the fixture program.
4,
)
},
Some(fixture_inputs),
Some(fixture_expected_output),
)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compiles `patterns` into a DFA and runs the CPU reference over `haystack`,
    /// returning the per-position match flags.
    fn reference_matches(patterns: &[&[u8]], haystack: &[u8]) -> Vec<u32> {
        let dfa = super::super::dfa_compile(patterns);
        let symbols: Vec<u32> = haystack.iter().copied().map(u32::from).collect();
        cpu_ref_cooperative_dfa(
            &symbols,
            &dfa.transitions,
            &dfa.accept,
            dfa.state_count,
            ALPHABET_SIZE,
        )
    }

    /// A single pattern is flagged only at the position where it ends.
    #[test]
    fn cooperative_dfa_single_pattern_abc() {
        let flags = reference_matches(&[b"abc"], b"zabc");
        assert_eq!(flags, vec![0, 0, 0, 1]);
    }

    /// Two overlapping patterns are each flagged at their own end position.
    #[test]
    fn cooperative_dfa_overlapping_multi_pattern() {
        let flags = reference_matches(&[b"abc", b"bcd"], b"xabcd");
        assert_eq!(flags, vec![0, 0, 0, 1, 1]);
    }

    /// An empty input produces no flags.
    #[test]
    fn cooperative_dfa_empty_input() {
        let flags = reference_matches(&[b"abc"], b"");
        assert!(flags.is_empty());
    }
}