use vyre::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
use crate::region::wrap_anonymous;
use crate::scan::builders::{append_match, append_match_subgroup, load_packed_byte};
use crate::scan::dfa::CompiledDfa;
#[cfg(any(test, feature = "cpu-parity"))]
use super::ClassicAcAutomaton;
#[path = "bounded_ranges/prefilter.rs"]
mod prefilter;
pub use prefilter::{
build_ac_bounded_ranges_prefilter_program, build_ac_bounded_ranges_prefilter_program_ext,
build_ac_bounded_ranges_suffix3_prefilter_program,
build_ac_bounded_ranges_suffix3_prefilter_program_ext,
classic_ac_bounded_ranges_prefilter_program, classic_ac_bounded_ranges_prefilter_program_ext,
classic_ac_bounded_ranges_suffix3_prefilter_program,
classic_ac_bounded_ranges_suffix3_prefilter_program_ext,
classic_ac_bounded_ranges_suffix3_presence_by_region_program_ext,
classic_ac_bounded_ranges_suffix3_presence_program_ext, presence_bitmap_words,
presence_by_region_words, try_build_ac_bounded_ranges_prefilter_program,
try_build_ac_bounded_ranges_prefilter_program_ext,
try_build_ac_bounded_ranges_suffix3_prefilter_program,
try_build_ac_bounded_ranges_suffix3_prefilter_program_ext,
try_build_ac_bounded_ranges_suffix3_presence_by_region_program,
try_build_ac_bounded_ranges_suffix3_presence_program,
};
/// Builds the bounded-ranges Aho-Corasick match program with subgroup-coalesced match appends.
///
/// This is [`classic_ac_bounded_ranges_program_ext`] with `use_subgroup_coalesce` fixed to
/// `true`. See that function for the buffer binding layout and the meaning of each parameter.
#[must_use]
#[allow(clippy::too_many_arguments)] // mirrors the `_ext` signature: one name/size per buffer
pub fn classic_ac_bounded_ranges_program(
    haystack: &str,
    transitions: &str,
    output_offsets: &str,
    output_records: &str,
    pattern_lengths: &str,
    haystack_len: &str,
    match_count: &str,
    matches: &str,
    state_count: u32,
    output_records_len: u32,
    pattern_count: u32,
    max_matches: u32,
    max_pattern_len: u32,
) -> Program {
    // Subgroup-coalesced emission is the default strategy for this entry point.
    let coalesce_in_subgroup = true;
    classic_ac_bounded_ranges_program_ext(
        haystack,
        transitions,
        output_offsets,
        output_records,
        pattern_lengths,
        haystack_len,
        match_count,
        matches,
        state_count,
        output_records_len,
        pattern_count,
        max_matches,
        max_pattern_len,
        coalesce_in_subgroup,
    )
}
/// Builds the bounded-ranges Aho-Corasick match program with explicit buffer names.
///
/// One invocation is dispatched per haystack byte. Invocation `i` (axis 0) does work only while
/// `i < haystack_len[0]`. It re-walks the DFA over the window of at most `max_pattern_len`
/// bytes ending at `i` and appends a match for every output of the final state.
///
/// Buffer bindings (all `u32` element type):
/// - 0 `haystack`: read-only, no declared count (presumably packed bytes, see
///   `load_packed_byte`).
/// - 1 `transitions`: read-only, `state_count * 256` words (saturating).
/// - 2 `output_offsets`: read-only, `state_count + 1` words (saturating).
/// - 3 `output_records`: read-only, `output_records_len` words.
/// - 4 `pattern_lengths`: read-only, `pattern_count` words.
/// - 5 `haystack_len`: read-only, 1 word.
/// - 6 `match_count`: read-write, 1 word.
/// - 7 `matches`: output, `max_matches * 3` words (saturating).
///
/// `max_pattern_len` is clamped to at least 1. When `use_subgroup_coalesce` is set, matches
/// are appended with `append_match_subgroup`; otherwise `append_match` is used.
/// The workgroup size is `[128, 1, 1]`.
#[must_use]
#[allow(clippy::too_many_arguments)] // one name/size per bound buffer
pub fn classic_ac_bounded_ranges_program_ext(
    haystack: &str,
    transitions: &str,
    output_offsets: &str,
    output_records: &str,
    pattern_lengths: &str,
    haystack_len: &str,
    match_count: &str,
    matches: &str,
    state_count: u32,
    output_records_len: u32,
    pattern_count: u32,
    max_matches: u32,
    max_pattern_len: u32,
    use_subgroup_coalesce: bool,
) -> Program {
    let window_len = max_pattern_len.max(1);

    // Guard: only invocations that map to a real haystack byte do any work.
    let within_haystack = Expr::lt(Expr::var("i"), Expr::load(haystack_len, Expr::u32(0)));
    let per_byte_scan = bounded_ranges_scan_nodes(
        haystack,
        transitions,
        output_offsets,
        output_records,
        pattern_lengths,
        match_count,
        matches,
        window_len,
        use_subgroup_coalesce,
    );
    let kernel_body = vec![
        Node::let_bind("i", Expr::InvocationId { axis: 0 }),
        Node::if_then(within_haystack, per_byte_scan),
    ];

    // Declared element counts; saturation keeps oversized inputs from wrapping.
    let transition_words = state_count.saturating_mul(256);
    let offset_words = state_count.saturating_add(1);
    let match_words = max_matches.saturating_mul(3);

    let buffers = vec![
        BufferDecl::storage(haystack, 0, BufferAccess::ReadOnly, DataType::U32),
        BufferDecl::storage(transitions, 1, BufferAccess::ReadOnly, DataType::U32)
            .with_count(transition_words),
        BufferDecl::storage(output_offsets, 2, BufferAccess::ReadOnly, DataType::U32)
            .with_count(offset_words),
        BufferDecl::storage(output_records, 3, BufferAccess::ReadOnly, DataType::U32)
            .with_count(output_records_len),
        BufferDecl::storage(pattern_lengths, 4, BufferAccess::ReadOnly, DataType::U32)
            .with_count(pattern_count),
        BufferDecl::storage(haystack_len, 5, BufferAccess::ReadOnly, DataType::U32)
            .with_count(1),
        BufferDecl::read_write(match_count, 6, DataType::U32).with_count(1),
        BufferDecl::output(matches, 7, DataType::U32).with_count(match_words),
    ];

    Program::wrapped(
        buffers,
        [128, 1, 1],
        vec![wrap_anonymous(
            "vyre-libs::matching::classic_ac_bounded_ranges",
            kernel_body,
        )],
    )
}
/// Builds the per-invocation IR that emits every match reported at haystack byte `i`.
///
/// The caller must already have bound `i`. Starting from DFA state 0, the invocation feeds the
/// bytes at positions `scan_start..scan_end` through `transitions`, where `scan_end = i + 1`
/// and `scan_start = scan_end - max_pattern_len` (clamped to 0). It then iterates
/// `output_records[output_offsets[state]..output_offsets[state + 1]]`. For each `pattern_id`
/// it appends a match spanning `match_start..scan_end`, with
/// `match_start = scan_end - pattern_lengths[pattern_id]` (clamped to 0).
///
/// When `use_subgroup_coalesce` is set, matches are appended with `append_match_subgroup`;
/// otherwise `append_match` is used.
///
/// NOTE(review): restarting from state 0 on a bounded window presumably relies on no DFA
/// state being deeper than `max_pattern_len` — confirm against the DFA builder.
#[allow(clippy::too_many_arguments)] // one name per bound buffer plus two knobs
fn bounded_ranges_scan_nodes(
    haystack: &str,
    transitions: &str,
    output_offsets: &str,
    output_records: &str,
    pattern_lengths: &str,
    match_count: &str,
    matches: &str,
    max_pattern_len: u32,
    use_subgroup_coalesce: bool,
) -> Vec<Node> {
    let window_len = max_pattern_len.max(1);
    let window_end = Expr::add(Expr::var("i"), Expr::u32(1));
    let window_start = Expr::select(
        Expr::lt(Expr::var("i"), Expr::u32(window_len - 1)),
        Expr::u32(0),
        Expr::sub(window_end.clone(), Expr::u32(window_len)),
    );

    let (fetch_byte, byte) = load_packed_byte(haystack, Expr::var("step"));
    let next_state = Expr::load(
        transitions,
        Expr::add(Expr::mul(Expr::var("state"), Expr::u32(256)), byte),
    );

    // Saturating start: patterns longer than the bytes seen so far start at 0.
    let clamped_start = Expr::select(
        Expr::lt(Expr::var("scan_end"), Expr::var("pat_len")),
        Expr::u32(0),
        Expr::sub(Expr::var("scan_end"), Expr::var("pat_len")),
    );
    let mut emit_outputs = vec![
        Node::let_bind("pattern_id", Expr::load(output_records, Expr::var("out_idx"))),
        Node::let_bind("pat_len", Expr::load(pattern_lengths, Expr::var("pattern_id"))),
        Node::let_bind("match_start", clamped_start),
    ];
    if use_subgroup_coalesce {
        emit_outputs.extend(append_match_subgroup(
            matches,
            match_count,
            Expr::var("pattern_id"),
            Expr::var("match_start"),
            Expr::var("scan_end"),
            Expr::bool(true),
        ));
    } else {
        emit_outputs.push(append_match(
            matches,
            match_count,
            Expr::var("pattern_id"),
            Expr::var("match_start"),
            Expr::var("scan_end"),
        ));
    }

    vec![
        Node::let_bind("state", Expr::u32(0)),
        Node::let_bind("scan_start", window_start),
        Node::let_bind("scan_end", window_end),
        Node::loop_for(
            "step",
            Expr::var("scan_start"),
            Expr::var("scan_end"),
            vec![fetch_byte, Node::assign("state", next_state)],
        ),
        Node::let_bind("out_begin", Expr::load(output_offsets, Expr::var("state"))),
        Node::let_bind(
            "out_end",
            Expr::load(output_offsets, Expr::add(Expr::var("state"), Expr::u32(1))),
        ),
        Node::loop_for(
            "out_idx",
            Expr::var("out_begin"),
            Expr::var("out_end"),
            emit_outputs,
        ),
    ]
}
fn bounded_ranges_presence_nodes(
haystack: &str,
transitions: &str,
output_offsets: &str,
output_records: &str,
presence: &str,
max_pattern_len: u32,
) -> Vec<Node> {
let max_pattern_len = max_pattern_len.max(1);
let i = Expr::var("i");
let end = Expr::add(i.clone(), Expr::u32(1));
let scan_start = Expr::select(
Expr::lt(i.clone(), Expr::u32(max_pattern_len - 1)),
Expr::u32(0),
Expr::sub(end.clone(), Expr::u32(max_pattern_len)),
);
let (load_step_byte, step_byte) = load_packed_byte(haystack, Expr::var("step"));
vec![
Node::let_bind("state", Expr::u32(0)),
Node::let_bind("scan_start", scan_start),
Node::let_bind("scan_end", end),
Node::loop_for(
"step",
Expr::var("scan_start"),
Expr::var("scan_end"),
vec![
load_step_byte,
Node::assign(
"state",
Expr::load(
transitions,
Expr::add(Expr::mul(Expr::var("state"), Expr::u32(256)), step_byte),
),
),
],
),
Node::let_bind("out_begin", Expr::load(output_offsets, Expr::var("state"))),
Node::let_bind(
"out_end",
Expr::load(output_offsets, Expr::add(Expr::var("state"), Expr::u32(1))),
),
Node::loop_for("out_idx", Expr::var("out_begin"), Expr::var("out_end"), {
vec![
Node::let_bind(
"pattern_id",
Expr::load(output_records, Expr::var("out_idx")),
),
Node::let_bind(
"_vyre_presence_prev",
Expr::atomic_or(
presence,
Expr::shr(Expr::var("pattern_id"), Expr::u32(5)),
Expr::shl(
Expr::u32(1),
Expr::bitand(Expr::var("pattern_id"), Expr::u32(31)),
),
),
),
]
}),
]
}
/// Builds the per-invocation IR that marks, per haystack region, the patterns reported at byte `i`.
///
/// The caller must already have bound `i`. Starting from DFA state 0, the bytes at positions
/// `scan_start..scan_end` (with `scan_end = i + 1` and a window of at most
/// `max(max_pattern_len, 1)` bytes) are fed through `transitions`. The outputs of the final
/// state are `output_records[output_offsets[state]..output_offsets[state + 1]]`.
///
/// Only when that output range is non-empty, the invocation:
/// 1. binary-searches `region_starts` for the last index `r` with
///    `region_starts[r] <= i + region_base[0]`, and
/// 2. for every output `pattern_id`, atomically ORs bit `pattern_id & 31` into word
///    `r * max(presence_words, 1) + (pattern_id >> 5)` of `presence`.
///
/// NOTE(review): the search assumes `region_starts` is sorted ascending and non-empty
/// (`buf_len(region_starts) - 1` would wrap for an empty buffer) — confirm against callers.
/// `log2_max_regions` (clamped to at least 1) is the fixed number of search iterations. It
/// presumably must be at least ceil(log2(region count)) for the search to converge — TODO confirm.
#[allow(clippy::too_many_arguments)]
fn bounded_ranges_presence_by_region_nodes(
haystack: &str,
transitions: &str,
output_offsets: &str,
output_records: &str,
presence: &str,
region_starts: &str,
region_base: &str,
max_pattern_len: u32,
presence_words: u32,
log2_max_regions: u32,
) -> Vec<Node> {
let max_pattern_len = max_pattern_len.max(1);
let i = Expr::var("i");
let end = Expr::add(i.clone(), Expr::u32(1));
// Window start: 0 near the beginning of the haystack, else `end - max_pattern_len`.
let scan_start = Expr::select(
Expr::lt(i.clone(), Expr::u32(max_pattern_len - 1)),
Expr::u32(0),
Expr::sub(end.clone(), Expr::u32(max_pattern_len)),
);
let (load_step_byte, step_byte) = load_packed_byte(haystack, Expr::var("step"));
// Region lookup + presence marking; only executed when the final state has outputs.
let region_and_emit = vec![
// Position compared against `region_starts`: local index `i` shifted by `region_base[0]`
// (presumably the offset of this haystack chunk in the full input — confirm).
Node::let_bind(
"rs_pos",
Expr::add(i.clone(), Expr::load(region_base, Expr::u32(0))),
),
Node::let_bind("rs_lo", Expr::u32(0)),
Node::let_bind(
"rs_hi",
Expr::sub(Expr::buf_len(region_starts), Expr::u32(1)),
),
// Fixed-iteration binary search for the last index whose start is <= `rs_pos`.
// `rs_mid` rounds up, so taking `rs_lo = rs_mid` always makes progress. Once
// `rs_lo == rs_hi`, further iterations leave `rs_lo` unchanged.
Node::loop_for(
"rs_step",
Expr::u32(0),
Expr::u32(log2_max_regions.max(1)),
vec![
Node::let_bind(
"rs_mid",
Expr::div(
Expr::add(
Expr::add(Expr::var("rs_lo"), Expr::var("rs_hi")),
Expr::u32(1),
),
Expr::u32(2),
),
),
Node::let_bind(
"rs_cond",
Expr::le(
Expr::load(region_starts, Expr::var("rs_mid")),
Expr::var("rs_pos"),
),
),
Node::assign(
"rs_lo",
Expr::select(
Expr::var("rs_cond"),
Expr::var("rs_mid"),
Expr::var("rs_lo"),
),
),
Node::assign(
"rs_hi",
Expr::select(
Expr::var("rs_cond"),
Expr::var("rs_hi"),
Expr::sub(Expr::var("rs_mid"), Expr::u32(1)),
),
),
],
),
// First presence word of the selected region; each region owns `max(presence_words, 1)` words.
Node::let_bind(
"rs_base",
Expr::mul(Expr::var("rs_lo"), Expr::u32(presence_words.max(1))),
),
Node::loop_for("out_idx", Expr::var("out_begin"), Expr::var("out_end"), {
vec![
Node::let_bind(
"pattern_id",
Expr::load(output_records, Expr::var("out_idx")),
),
// Set bit `pattern_id & 31` of word `rs_base + pattern_id / 32`; the previous
// word value is bound but unused.
Node::let_bind(
"_vyre_presence_prev",
Expr::atomic_or(
presence,
Expr::add(
Expr::var("rs_base"),
Expr::shr(Expr::var("pattern_id"), Expr::u32(5)),
),
Expr::shl(
Expr::u32(1),
Expr::bitand(Expr::var("pattern_id"), Expr::u32(31)),
),
),
),
]
}),
];
vec![
Node::let_bind("state", Expr::u32(0)),
Node::let_bind("scan_start", scan_start),
Node::let_bind("scan_end", end),
// Walk the DFA over the bounded window from the root state.
Node::loop_for(
"step",
Expr::var("scan_start"),
Expr::var("scan_end"),
vec![
load_step_byte,
Node::assign(
"state",
Expr::load(
transitions,
Expr::add(Expr::mul(Expr::var("state"), Expr::u32(256)), step_byte),
),
),
],
),
Node::let_bind("out_begin", Expr::load(output_offsets, Expr::var("state"))),
Node::let_bind(
"out_end",
Expr::load(output_offsets, Expr::add(Expr::var("state"), Expr::u32(1))),
),
// Skip the region search entirely when the final state reports no patterns.
Node::if_then(
Expr::lt(Expr::var("out_begin"), Expr::var("out_end")),
region_and_emit,
),
]
}
/// Builds the bounded-ranges match program for `dfa` with subgroup-coalesced match appends.
///
/// This is [`build_ac_bounded_ranges_program_ext`] with `use_subgroup_coalesce = true`. Like
/// that function, it never fails: on error it logs to stderr and returns a placeholder program.
#[must_use]
pub fn build_ac_bounded_ranges_program(
    dfa: &CompiledDfa,
    pattern_count: u32,
    max_matches: u32,
) -> Program {
    let coalesce_in_subgroup = true;
    build_ac_bounded_ranges_program_ext(dfa, pattern_count, max_matches, coalesce_in_subgroup)
}
/// Builds the bounded-ranges match program for `dfa`, falling back to a placeholder on error.
///
/// This is an infallible wrapper over [`try_build_ac_bounded_ranges_program_ext`]. If the DFA
/// cannot be encoded, the error is printed to stderr. A placeholder program with the same
/// buffer names is returned instead: one DFA state, no output records, no patterns, and the
/// given `max_matches`. Prefer the `try_` variant when the caller can handle the error.
#[must_use]
pub fn build_ac_bounded_ranges_program_ext(
    dfa: &CompiledDfa,
    pattern_count: u32,
    max_matches: u32,
    use_subgroup_coalesce: bool,
) -> Program {
    try_build_ac_bounded_ranges_program_ext(dfa, pattern_count, max_matches, use_subgroup_coalesce)
        .unwrap_or_else(|error| {
            eprintln!("vyre-libs AC bounded-ranges program build failed: {error}");
            empty_ac_bounded_ranges_program(max_matches, use_subgroup_coalesce)
        })
}
/// Fallibly builds the bounded-ranges match program for `dfa` with subgroup-coalesced appends.
///
/// This is [`try_build_ac_bounded_ranges_program_ext`] with `use_subgroup_coalesce = true`.
///
/// # Errors
///
/// Propagates the error of [`try_build_ac_bounded_ranges_program_ext`], e.g. when the DFA's
/// output record count does not fit in `u32`.
pub fn try_build_ac_bounded_ranges_program(
    dfa: &CompiledDfa,
    pattern_count: u32,
    max_matches: u32,
) -> Result<Program, String> {
    let coalesce_in_subgroup = true;
    try_build_ac_bounded_ranges_program_ext(dfa, pattern_count, max_matches, coalesce_in_subgroup)
}
/// Fallibly builds the bounded-ranges match program for `dfa` using the canonical buffer names.
///
/// The buffers are named `"haystack"`, `"transitions"`, `"output_offsets"`,
/// `"output_records"`, `"pattern_lengths"`, `"haystack_len"`, `"match_count"` and `"matches"`.
/// The state count, output record count and longest pattern length come from `dfa`.
///
/// # Errors
///
/// Returns a descriptive error (including a suggested fix) when the DFA is too large to
/// describe with `u32` GPU buffer metadata:
/// - the transition-table length `dfa.state_count * 256` overflows `u32`, or
/// - `dfa.output_records.len()` does not fit in `u32`.
pub fn try_build_ac_bounded_ranges_program_ext(
    dfa: &CompiledDfa,
    pattern_count: u32,
    max_matches: u32,
    use_subgroup_coalesce: bool,
) -> Result<Program, String> {
    // The transition table is declared with `state_count * 256` words and indexed on the GPU as
    // `state * 256 + byte` in u32 arithmetic. The unchecked builder would saturate the declared
    // length and silently describe a truncated table, so reject that case here.
    if dfa.state_count.checked_mul(256).is_none() {
        return Err(format!(
            "AC bounded-ranges DFA state count {} overflows the u32 transition table length (state_count * 256). Fix: shard the pattern set or lower the DFA budget before dispatch.",
            dfa.state_count
        ));
    }
    let output_records_len = u32::try_from(dfa.output_records.len()).map_err(|source| {
        format!(
            "AC bounded-ranges DFA output record count {} exceeds u32 GPU buffer metadata: {source}. Fix: shard the pattern set or lower the DFA budget before dispatch.",
            dfa.output_records.len()
        )
    })?;
    Ok(classic_ac_bounded_ranges_program_ext(
        "haystack",
        "transitions",
        "output_offsets",
        "output_records",
        "pattern_lengths",
        "haystack_len",
        "match_count",
        "matches",
        dfa.state_count,
        output_records_len,
        pattern_count,
        max_matches,
        dfa.max_pattern_len,
        use_subgroup_coalesce,
    ))
}
/// Builds a placeholder bounded-ranges program using the canonical buffer names.
///
/// The placeholder describes a DFA with a single state, no output records and no patterns.
/// `max_matches` and `use_subgroup_coalesce` are forwarded unchanged. A zero
/// `max_pattern_len` is passed; the program builder clamps it to a one-byte window.
fn empty_ac_bounded_ranges_program(max_matches: u32, use_subgroup_coalesce: bool) -> Program {
    let root_only_states = 1;
    let no_output_records = 0;
    let no_patterns = 0;
    let no_pattern_len = 0;
    classic_ac_bounded_ranges_program_ext(
        "haystack",
        "transitions",
        "output_offsets",
        "output_records",
        "pattern_lengths",
        "haystack_len",
        "match_count",
        "matches",
        root_only_states,
        no_output_records,
        no_patterns,
        max_matches,
        no_pattern_len,
        use_subgroup_coalesce,
    )
}
/// CPU reference for the bounded-ranges match program, used for parity checks.
///
/// Walks `ac.dfa` over the whole `haystack` from state 0. After each byte at position `pos`,
/// it emits one `(pattern_id, start, end)` tuple per output of the current state. Here
/// `end = pos + 1` (exclusive) and `start = end - pattern_lengths[pattern_id]`, saturating
/// at 0. Pattern ids missing from `pattern_lengths` are treated as length 0.
///
/// # Panics
///
/// Panics if the DFA tables are inconsistent, i.e. a transition or output offset indexes out
/// of bounds.
#[must_use]
#[cfg(any(test, feature = "cpu-parity"))]
pub fn classic_ac_bounded_ranges_scan(
    ac: &ClassicAcAutomaton,
    pattern_lengths: &[u32],
    haystack: &[u8],
) -> Vec<(u32, u32, u32)> {
    let dfa = &ac.dfa;
    let mut found = Vec::new();
    let mut current = 0u32;
    for (index, &byte) in haystack.iter().enumerate() {
        current = dfa.transitions[(current as usize) * 256 + usize::from(byte)];
        // Exclusive end of every match reported at this byte (same for all outputs).
        let end = (index as u32).saturating_add(1);
        let slot = current as usize;
        let first = dfa.output_offsets[slot] as usize;
        let last = dfa.output_offsets[slot + 1] as usize;
        found.extend(dfa.output_records[first..last].iter().map(|&pattern_id| {
            let len = pattern_lengths
                .get(pattern_id as usize)
                .copied()
                .unwrap_or(0);
            (pattern_id, end.saturating_sub(len), end)
        }));
    }
    found
}