use std::error::Error;
use std::fmt;
use vyre_foundation::ir::Program;
use vyre_primitives::matching::{nfa_to_dfa, CompiledDfa, NfaTables, NfaToDfaError};
use crate::scan::classic_ac::try_build_ac_bounded_ranges_program_ext;
use crate::scan::regex_compile::{compile_regex_set, CompiledRegexSet, RegexCompileError};
/// Result of lowering a regex set into a DFA-driven scan program.
#[derive(Debug, Clone)]
pub struct RegexDfaPipeline {
/// Program returned by `try_build_ac_bounded_ranges_program_ext` for `dfa`.
pub program: Program,
/// DFA returned by `nfa_to_dfa` for the compiled regex NFA tables.
pub dfa: CompiledDfa,
/// One entry per input pattern, indexed by pattern id: the largest accept
/// length recorded for that pattern, or 0 when no accept state names it.
pub pattern_lengths: Vec<u32>,
}
/// Failure modes of the regex → DFA pipeline builders in this file.
#[derive(Debug)]
#[non_exhaustive]
pub enum RegexDfaError {
/// `compile_regex_set` rejected the pattern set.
Compile(RegexCompileError),
/// `nfa_to_dfa` failed; the tests in this file expect `StateExplosion` when
/// the DFA state budget is too small.
Lower(NfaToDfaError),
/// A checked size conversion, a table reservation, or the program builder failed.
Size {
/// Text describing the failure; produced by this file or forwarded from
/// the program builder.
message: String,
},
}
/// Renders every variant as `<failing stage>: <underlying detail>`.
impl fmt::Display for RegexDfaError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Size { message } => {
                formatter.write_fmt(format_args!("regex DFA sizing failed: {message}"))
            }
            Self::Lower(error) => formatter.write_fmt(format_args!(
                "NFA → DFA subset construction failed: {error}"
            )),
            Self::Compile(error) => {
                formatter.write_fmt(format_args!("regex NFA compile failed: {error}"))
            }
        }
    }
}
// Marker impl with an empty body: only the default `Error` methods are used, so
// `source()` is not overridden and the wrapped errors surface via `Display`/`Debug`.
impl Error for RegexDfaError {}
impl From<RegexCompileError> for RegexDfaError {
fn from(error: RegexCompileError) -> Self {
Self::Compile(error)
}
}
impl From<NfaToDfaError> for RegexDfaError {
fn from(error: NfaToDfaError) -> Self {
Self::Lower(error)
}
}
/// Builds the regex DFA pipeline for `patterns` with subgroup coalescing enabled.
///
/// Thin wrapper over [`build_regex_dfa_pipeline_ext`] that always passes
/// `use_subgroup_coalesce = true`; all other arguments are forwarded unchanged.
///
/// # Errors
/// Whatever [`build_regex_dfa_pipeline_ext`] returns.
pub fn build_regex_dfa_pipeline(
    patterns: &[&str],
    max_matches: u32,
    max_dfa_states: usize,
) -> Result<RegexDfaPipeline, RegexDfaError> {
    let use_subgroup_coalesce = true;
    build_regex_dfa_pipeline_ext(patterns, max_matches, max_dfa_states, use_subgroup_coalesce)
}
pub fn build_regex_dfa_pipeline_ext(
patterns: &[&str],
max_matches: u32,
max_dfa_states: usize,
use_subgroup_coalesce: bool,
) -> Result<RegexDfaPipeline, RegexDfaError> {
let regex_set = compile_regex_set(patterns)?;
finish_regex_dfa_pipeline(
regex_set,
patterns,
max_matches,
max_dfa_states,
use_subgroup_coalesce,
)
}
/// Builds the regex DFA pipeline after patching the compiled NFA transition table
/// with `add_implicit_dotstar_prefix`.
///
/// The tests in this file rely on the patched DFA accepting a pattern at any
/// haystack offset in a single pass. Subgroup coalescing is always enabled here.
///
/// # Errors
/// A compile failure is converted through `From` into [`RegexDfaError`]; any error
/// from the lowering step is returned as-is.
pub fn build_regex_dfa_unanchored(
    patterns: &[&str],
    max_matches: u32,
    max_dfa_states: usize,
) -> Result<RegexDfaPipeline, RegexDfaError> {
    let mut regex_set = compile_regex_set(patterns)?;

    let nfa_state_count = regex_set.plan.num_states as usize;
    add_implicit_dotstar_prefix(&mut regex_set.transition_table, nfa_state_count);

    let use_subgroup_coalesce = true;
    finish_regex_dfa_pipeline(
        regex_set,
        patterns,
        max_matches,
        max_dfa_states,
        use_subgroup_coalesce,
    )
}
/// ORs bit 0 into the first word of each of the first 256 entries of `transition_table`.
///
/// The table is treated as `num_states * 256` entries of equal width; the width
/// (words per entry) is derived from the table length. The function does nothing
/// when `num_states` is 0, when the table is empty, or when the table length is
/// not a multiple of `num_states * 256`.
///
/// NOTE(review): presumably the table is laid out state-major with state 0 as the
/// start state, so this adds a self-loop on every byte — confirm against
/// `compile_regex_set`.
fn add_implicit_dotstar_prefix(transition_table: &mut [u32], num_states: usize) {
    let entry_count = num_states.saturating_mul(256);
    if entry_count == 0 {
        return;
    }

    let total_words = transition_table.len();
    if total_words % entry_count != 0 {
        return;
    }

    let words_per_entry = total_words / entry_count;
    if words_per_entry == 0 {
        // Empty table: nothing to patch (and `step_by(0)` would panic).
        return;
    }

    // The first word of entry `byte` sits at `byte * words_per_entry`; the table
    // holds at least 256 entries, so all 256 strided positions are in bounds.
    transition_table
        .iter_mut()
        .step_by(words_per_entry)
        .take(256)
        .for_each(|word| *word |= 1);
}
/// Lowers an already-compiled regex NFA to a DFA and wraps it in a scan program.
///
/// Steps, in order:
/// 1. copy the pattern id of every accept state into a table for `nfa_to_dfa`;
/// 2. record, per pattern id, the largest accept length seen (`pattern_lengths`);
/// 3. run `nfa_to_dfa`, bounded by `max_dfa_states`;
/// 4. build the program with `try_build_ac_bounded_ranges_program_ext`.
///
/// `patterns` is only used for its length (pattern count and size of
/// `pattern_lengths`); `max_matches` and `use_subgroup_coalesce` are forwarded to
/// the program builder.
///
/// # Errors
/// * [`RegexDfaError::Size`] when a table reservation fails, the pattern count does
///   not fit in `u32`, an accept pattern id does not fit in `usize`, or the program
///   builder reports an error.
/// * [`RegexDfaError::Lower`] when `nfa_to_dfa` fails.
fn finish_regex_dfa_pipeline(
    regex_set: CompiledRegexSet,
    patterns: &[&str],
    max_matches: u32,
    max_dfa_states: usize,
    use_subgroup_coalesce: bool,
) -> Result<RegexDfaPipeline, RegexDfaError> {
    let accept_states = &regex_set.plan.accept_states;

    // Pattern id of each accept state, in accept-state order.
    let mut accept_pattern_ids: Vec<u32> = Vec::new();
    reserve_regex_vec(
        &mut accept_pattern_ids,
        accept_states.len(),
        "accept pattern id table",
    )?;
    accept_pattern_ids.extend(accept_states.iter().map(|(pid, _)| *pid));

    let max_pattern_len = accept_states
        .iter()
        .map(|(_, len)| *len)
        .max()
        .unwrap_or(0);

    let pattern_count = u32::try_from(patterns.len()).map_err(|source| RegexDfaError::Size {
        message: format!(
            "pattern count {} exceeds u32 GPU buffer metadata: {source}. Fix: shard the regex set before building a DFA dispatch.",
            patterns.len()
        ),
    })?;

    // One slot per input pattern, holding the largest accept length for that pattern id.
    let mut pattern_lengths: Vec<u32> = Vec::new();
    reserve_regex_vec(&mut pattern_lengths, patterns.len(), "pattern length table")?;
    pattern_lengths.resize(patterns.len(), 0);
    for (pid, len) in accept_states {
        let idx = usize::try_from(*pid).map_err(|source| RegexDfaError::Size {
            message: format!(
                "accept pattern id {pid} cannot fit usize for pattern-length indexing: {source}. Fix: shard the regex set before building a DFA dispatch."
            ),
        })?;
        // Pattern ids outside `0..patterns.len()` are skipped rather than reported.
        if let Some(slot) = pattern_lengths.get_mut(idx) {
            if *len > *slot {
                *slot = *len;
            }
        }
    }

    let tables = NfaTables {
        num_states: regex_set.plan.num_states,
        transition_table: &regex_set.transition_table,
        epsilon_table: &regex_set.epsilon_table,
        accept_state_ids: &regex_set.plan.accept_state_ids,
        accept_pattern_ids: &accept_pattern_ids,
        max_pattern_len,
    };
    let dfa = nfa_to_dfa(&tables, max_dfa_states)?;

    // The builder's error is forwarded verbatim as the `Size` message.
    let program = try_build_ac_bounded_ranges_program_ext(
        &dfa,
        pattern_count,
        max_matches,
        use_subgroup_coalesce,
    )
    .map_err(|message| RegexDfaError::Size { message })?;

    Ok(RegexDfaPipeline {
        program,
        dfa,
        pattern_lengths,
    })
}
/// Reserves capacity in `vec` via `vyre_foundation::allocation::try_reserve_vec_to_capacity`,
/// turning a failure into [`RegexDfaError::Size`].
///
/// `label` names the table in the error message; `requested` is passed to the
/// reservation helper and echoed in the message.
/// NOTE(review): presumably `requested` is the target element capacity — confirm
/// against `try_reserve_vec_to_capacity`.
fn reserve_regex_vec<T>(
    vec: &mut Vec<T>,
    requested: usize,
    label: &'static str,
) -> Result<(), RegexDfaError> {
    let reservation = vyre_foundation::allocation::try_reserve_vec_to_capacity(vec, requested);
    reservation.map_err(|source| {
        let message = format!(
            "regex DFA {label} reservation failed for {requested} item(s): {source}. Fix: shard the regex set or lower the DFA budget before dispatch."
        );
        RegexDfaError::Size { message }
    })
}
#[cfg(test)]
mod tests {
use super::*;
/// Walks `dfa` over `haystack` once, starting in state 0, and returns every end
/// offset (byte index + 1) at which the state reached has a non-zero `accept` entry.
///
/// Each step is a single lookup at `transitions[state * 256 + byte]`.
fn single_pass_accept_ends(dfa: &CompiledDfa, haystack: &[u8]) -> Vec<usize> {
    let mut current = 0u32;
    haystack
        .iter()
        .enumerate()
        .filter_map(|(offset, &byte)| {
            current = dfa.transitions[current as usize * 256 + byte as usize];
            let accepting = dfa.accept[current as usize] != 0;
            accepting.then(|| offset + 1)
        })
        .collect()
}
/// The unanchored build must match `abc` at any offset in one pass, while the
/// anchored build only matches when the haystack starts with the pattern.
#[test]
fn unanchored_dfa_matches_at_any_offset_single_pass() {
    let pattern_set = ["abc"];
    let anchored_dfa = build_regex_dfa_pipeline(&pattern_set, 1024, 1024)
        .expect("anchored compiles")
        .dfa;
    let unanchored_dfa = build_regex_dfa_unanchored(&pattern_set, 1024, 1024)
        .expect("unanchored compiles")
        .dfa;

    // Pattern preceded by non-matching bytes.
    let prefixed = b"xxabc";
    assert_eq!(
        single_pass_accept_ends(&unanchored_dfa, prefixed),
        vec![5],
        "unanchored DFA must match `abc` after a non-matching prefix"
    );
    let anchored_hits = single_pass_accept_ends(&anchored_dfa, prefixed);
    assert!(
        anchored_hits.is_empty(),
        "anchored DFA must NOT match `abc` after a non-matching prefix"
    );

    // Pattern at offset 0: both builds agree.
    let exact = b"abc";
    assert_eq!(single_pass_accept_ends(&unanchored_dfa, exact), vec![3]);
    assert_eq!(single_pass_accept_ends(&anchored_dfa, exact), vec![3]);

    // Two occurrences separated by one byte.
    assert_eq!(
        single_pass_accept_ends(&unanchored_dfa, b"abcxabc"),
        vec![3, 7],
        "unanchored DFA must find all occurrences"
    );
}
/// A `ghp_` token embedded in a quoted assignment must be accepted exactly once,
/// at the byte just before the closing quote.
#[test]
fn unanchored_dfa_finds_overlap_body_token_single_pass() {
    let pipeline =
        build_regex_dfa_unanchored(&["ghp_[A-Za-z0-9]{36}"], 1024, 16384).expect("compiles");
    let dfa = pipeline.dfa;

    let hay = b"at = \"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\"";
    let token_end = hay.len() - 1;

    assert_eq!(
        single_pass_accept_ends(&dfa, hay),
        vec![token_end],
        "unanchored DFA must accept the ghp_ token exactly before the closing quote"
    );
}
/// For each (pattern, haystack) fixture the unanchored DFA must accept at the end
/// of the embedded token and must not report any accept past that end.
///
/// The expected end is the haystack length, minus one when the haystack ends with
/// a closing double quote.
#[test]
fn unanchored_dfa_finds_all_parity_gate_misses_single_pass() {
    let cases: &[(&str, &[u8])] = &[
        (
            "ghp_[A-Za-z0-9]{36}",
            b"at = \"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\"",
        ),
        (
            "gho_[A-Za-z0-9]{36}",
            b"ken: \"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\"",
        ),
        (
            "ghu_[A-Za-z0-9]{36}",
            b"Key: \"ghu_m7BOv2Uj0AZZK088M7RQJkZX3EgBVV1Xt7i2\"",
        ),
        (
            "ghu_[A-Za-z0-9]{36}",
            b"OKEN: ghu_4u5ef0rIhtKpPV1F0dPwwhXNMpEXkB0tWWQv",
        ),
        (
            "xox[baprs]-[A-Za-z0-9-]{10,48}",
            b"Key: \"xoxb-1234567890-1234567890-EXAMPLE-TOKEN\"",
        ),
        (
            "xox[baprs]-[A-Za-z0-9-]{10,48}",
            b"_KEY=\"xoxb-32790994721-16118213278-q5KLPWcLboh0tchHpJPgWhuC\"",
        ),
    ];

    for &(pat, hay) in cases.iter() {
        let dfa = match build_regex_dfa_unanchored(&[pat], 1024, 16384) {
            Ok(pipeline) => pipeline.dfa,
            Err(e) => panic!("pattern {pat:?} must compile: {e:?}"),
        };
        let ends = single_pass_accept_ends(&dfa, hay);

        // A trailing quote is not part of the token.
        let expected_end = match hay.last() {
            Some(&b'"') => hay.len() - 1,
            _ => hay.len(),
        };

        let reaches_token_end = ends.contains(&expected_end);
        assert!(
            reaches_token_end,
            "dense CompiledDfa for {pat:?} must accept the full token end {expected_end} \
             in {:?} (single-pass); got {ends:?}. state_count={}",
            String::from_utf8_lossy(hay),
            dfa.state_count,
        );

        let stays_within_token = ends.iter().all(|end| *end <= expected_end);
        assert!(
            stays_within_token,
            "dense CompiledDfa for {pat:?} must not accept past token boundary {expected_end} \
             in {:?}; got {ends:?}. state_count={}",
            String::from_utf8_lossy(hay),
            dfa.state_count,
        );
    }
}
/// A single literal must produce a DFA with at least four states, a one-element
/// `pattern_lengths`, an accept entry for pattern 0, and a program declaring every
/// buffer name listed below.
#[test]
fn literal_pattern_set_lowers_through_to_dfa_program() {
    let pipeline =
        build_regex_dfa_pipeline(&["abc"], 1024, 1024).expect("Fix: literal must compile");
    let dfa = &pipeline.dfa;

    assert!(
        dfa.state_count >= 4,
        "literal 'abc' DFA must have at least 4 states (entry + 3 progress); got {}",
        dfa.state_count
    );
    assert_eq!(
        pipeline.pattern_lengths,
        vec![3],
        "single literal 'abc' must have pattern_lengths = [3]"
    );

    // Accept entries store pattern id + 1, so pattern 0 shows up as 1.
    let accepts_pattern_zero = dfa.accept.iter().any(|&encoded| encoded == 1);
    assert!(
        accepts_pattern_zero,
        "at least one DFA state must accept pattern 0 (encoded as accept = 1)"
    );

    let names: Vec<&str> = pipeline
        .program
        .buffers
        .iter()
        .map(|buffer| buffer.name())
        .collect();
    let required_buffers = [
        "haystack",
        "transitions",
        "output_offsets",
        "output_records",
        "pattern_lengths",
        "haystack_len",
        "match_count",
        "matches",
    ];
    for expected in required_buffers {
        assert!(
            names.contains(&expected),
            "RegexDfaPipeline program must declare buffer `{expected}` for AC dispatch; got {names:?}"
        );
    }
}
/// Two literals must keep distinct pattern ids through lowering: accept entry 1
/// for the first pattern and 2 for the second.
#[test]
fn multi_literal_set_emits_distinct_accept_pids() {
    let pipeline = build_regex_dfa_pipeline(&["abc", "xyz"], 1024, 1024)
        .expect("Fix: two literals must compile");
    assert_eq!(pipeline.pattern_lengths, vec![3, 3]);

    let accept_table = &pipeline.dfa.accept;
    let first_pattern_seen = accept_table.iter().any(|&entry| entry == 1);
    let second_pattern_seen = accept_table.iter().any(|&entry| entry == 2);

    assert!(
        first_pattern_seen,
        "no DFA state accepts pid 0 - 'abc' lost in lower"
    );
    assert!(
        second_pattern_seen,
        "no DFA state accepts pid 1 - 'xyz' lost in lower"
    );
}
/// A DFA budget of one state must be reported as `Lower(StateExplosion)` rather
/// than panicking.
#[test]
fn state_explosion_surfaces_as_error_not_panic() {
    let other = build_regex_dfa_pipeline(&["abc"], 1024, 1)
        .expect_err("max_dfa_states=1 must trip state explosion");

    let is_state_explosion = matches!(
        other,
        RegexDfaError::Lower(NfaToDfaError::StateExplosion { .. })
    );
    if !is_state_explosion {
        panic!("expected Lower(StateExplosion), got {other:?}");
    }
}
/// A pattern with a character class must lower to a DFA with at least one
/// non-zero accept entry.
#[test]
fn character_class_pattern_lowers_to_acceptor_dfa() {
    let pipeline = build_regex_dfa_pipeline(&["[ab]c"], 1024, 1024)
        .expect("Fix: character class must compile");

    let has_accepting_state = pipeline.dfa.accept.iter().any(|&entry| entry != 0);
    assert!(
        has_accepting_state,
        "DFA for '[ab]c' must accept at least one state"
    );
}
// Source-text guard: reads this file at compile time and checks that the production
// section (everything before the test module marker) still contains the checked
// conversions and the fallible builder call, and no unchecked pattern-count cast.
// The literals below are coupled to the exact text of the production code.
#[test]
fn regex_dfa_pipeline_uses_checked_size_conversions() {
// Keep only the text that precedes the `#[cfg(test)]` / `mod tests` marker.
let production = include_str!("regex_dfa.rs")
.split("\n#[cfg(test)]\nmod tests")
.next()
.expect("Fix: regex DFA production section should precede tests");
// Sizing failures must go through the structured error variant.
assert!(
production.contains("RegexDfaError::Size"),
"Fix: regex DFA sizing failures must be structured errors, not panics or unchecked casts."
);
// Pattern count must be narrowed with a checked conversion.
assert!(
production.contains("u32::try_from(patterns.len())"),
"Fix: regex DFA pattern count must use checked conversion for GPU ABI metadata."
);
// Accept pattern ids must be widened with a checked conversion before indexing.
assert!(
production.contains("usize::try_from(*pid)"),
"Fix: regex DFA accept pattern ids must use checked host indexing conversion."
);
// The fallible program builder must be the one in use.
assert!(
production.contains("try_build_ac_bounded_ranges_program_ext"),
"Fix: regex DFA must call the fallible AC program builder."
);
// The unchecked cast must not reappear in production code.
assert!(
!production.contains("patterns.len() as u32"),
"Fix: regex DFA must not narrow pattern counts with unchecked casts."
);
}
}