use vyre::VyreBackend;
use vyre_foundation::ir::Program;
use vyre_foundation::match_result::Match;
use super::nfa;
/// Number of `u32` words in the NFA state bitset used by the CPU reference scan
/// (one bit per state, 32 states per word); aliased from the subgroup NFA primitive.
const NFA_LANES: usize = vyre_primitives::nfa::subgroup_nfa::LANES_PER_SUBGROUP;
/// A multi-pattern scan pipeline: the dispatchable program together with the
/// NFA tables and plan that the scan and reference-scan methods consume.
///
/// Values are produced by `build` or by `RulePipeline::from_bytes`.
#[derive(Debug, Clone)]
pub struct RulePipeline {
/// Program handed to the backend on every scan dispatch.
pub program: Program,
/// Byte-transition table; the reference scan reads `NFA_LANES` words at
/// offset `state * 256 * NFA_LANES + byte * NFA_LANES`.
pub transition_table: Vec<u32>,
/// Epsilon table; the reference scan reads `NFA_LANES` words at offset
/// `state * NFA_LANES`.
pub epsilon_table: Vec<u32>,
/// NFA plan: state count, input length and accept-state metadata.
pub plan: nfa::NfaPlan,
}
impl RulePipeline {
/// Scans `haystack` on `backend` and returns the matches in a new vector.
///
/// Convenience wrapper around `scan_into`; `max_matches` is forwarded
/// unchanged.
///
/// # Errors
/// Propagates any error returned by `scan_into`.
pub fn scan<B: VyreBackend + ?Sized>(
    &self,
    backend: &B,
    haystack: &[u8],
    max_matches: u32,
) -> Result<Vec<Match>, vyre::BackendError> {
    let mut found = Vec::new();
    self.scan_into(backend, haystack, max_matches, &mut found)
        .map(|()| found)
}
/// Scans `haystack` on `backend` with an explicit `max_scan_bytes` bound and
/// returns the matches in a new vector.
///
/// Convenience wrapper around `scan_bounded_into`.
///
/// # Errors
/// Propagates any error returned by `scan_bounded_into`.
pub fn scan_bounded<B: VyreBackend + ?Sized>(
    &self,
    backend: &B,
    haystack: &[u8],
    max_matches: u32,
    max_scan_bytes: u32,
) -> Result<Vec<Match>, vyre::BackendError> {
    let mut found = Vec::new();
    self.scan_bounded_into(backend, haystack, max_matches, max_scan_bytes, &mut found)
        .map(|()| found)
}
/// Scans `haystack` on `backend`, writing the matches into `matches`.
///
/// Delegates to `scan_bounded_into` with `u32::MAX` as the scan-byte bound.
///
/// # Errors
/// Propagates any error returned by `scan_bounded_into`.
pub fn scan_into<B: VyreBackend + ?Sized>(
    &self,
    backend: &B,
    haystack: &[u8],
    max_matches: u32,
    matches: &mut Vec<Match>,
) -> Result<(), vyre::BackendError> {
    let unbounded_scan_bytes = u32::MAX;
    self.scan_bounded_into(
        backend,
        haystack,
        max_matches,
        unbounded_scan_bytes,
        matches,
    )
}
/// Scans `haystack` on `backend` with an explicit `max_scan_bytes` bound,
/// writing the matches into `matches`.
///
/// Creates a default `ScanDispatchScratch` for this one call and delegates to
/// `scan_bounded_into_with_scratch`; callers that scan repeatedly can call
/// that method directly with their own scratch value.
///
/// # Errors
/// Propagates any error returned by `scan_bounded_into_with_scratch`.
pub fn scan_bounded_into<B: VyreBackend + ?Sized>(
    &self,
    backend: &B,
    haystack: &[u8],
    max_matches: u32,
    max_scan_bytes: u32,
    matches: &mut Vec<Match>,
) -> Result<(), vyre::BackendError> {
    use crate::scan::dispatch_io::ScanDispatchScratch;

    let mut one_shot_scratch = ScanDispatchScratch::default();
    self.scan_bounded_into_with_scratch(
        backend,
        haystack,
        max_matches,
        max_scan_bytes,
        matches,
        &mut one_shot_scratch,
    )
}
/// Dispatches the pipeline program on `backend` and decodes the hit buffer
/// into `matches`, reusing the byte buffers held by `scratch`.
///
/// Six input slices are handed to the backend, in this order: packed
/// haystack, transition table bytes, epsilon table bytes, zeroed hit buffer,
/// haystack length (little-endian) and `max_scan_bytes` (little-endian).
/// Output slot 0 is read back as a `u32` hit counter followed by match
/// triples; at most `max_matches` triples are decoded.
///
/// `matches` is cleared before anything else happens, so a failed call does
/// not leave matches from a previous call in it.
///
/// # Errors
/// Returns an error when the haystack guard, hit-buffer reservation, haystack
/// packing, backend dispatch or readback decoding fails.
pub fn scan_bounded_into_with_scratch<B: VyreBackend + ?Sized>(
    &self,
    backend: &B,
    haystack: &[u8],
    max_matches: u32,
    max_scan_bytes: u32,
    matches: &mut Vec<Match>,
    scratch: &mut crate::scan::dispatch_io::ScanDispatchScratch,
) -> Result<(), vyre::BackendError> {
    use crate::scan::dispatch_io as io;

    // Label attached to every readback decoding error.
    const READBACK_LABEL: &str = "RulePipeline hit buffer";

    matches.clear();
    let haystack_len =
        io::scan_guard(haystack, "RulePipeline::scan", io::DEFAULT_MAX_SCAN_BYTES)?;

    // Refill the reusable scratch buffers for this dispatch.
    zeroed_hit_buffer_into(max_matches, &mut scratch.hit_bytes)?;
    io::pack_haystack_u32_into(haystack, &mut scratch.haystack_bytes)?;

    let zeroed_hits = scratch.hit_bytes.as_slice();
    let packed_haystack = scratch.haystack_bytes.as_slice();
    let transition_bytes = io::u32_words_as_le_bytes(&self.transition_table);
    let epsilon_bytes = io::u32_words_as_le_bytes(&self.epsilon_table);
    let haystack_len_le = haystack_len.to_le_bytes();
    let max_scan_bytes_le = max_scan_bytes.to_le_bytes();
    let config = io::candidate_start_dispatch_config(haystack_len);

    // Slot order is part of the program's input contract; keep it fixed.
    let inputs: [&[u8]; 6] = [
        packed_haystack,
        transition_bytes.as_ref(),
        epsilon_bytes.as_ref(),
        zeroed_hits,
        &haystack_len_le[..],
        &max_scan_bytes_le[..],
    ];
    let outputs = backend.dispatch_borrowed(&self.program, &inputs, &config)?;

    let readback = io::try_output_bytes(&outputs, 0, READBACK_LABEL)?;
    let reported = io::try_read_u32_prefix(readback, READBACK_LABEL)?;
    // Never decode more triples than the hit buffer was sized for.
    let accepted = reported.min(max_matches);
    io::try_unpack_match_triples_exact_prefix_into(&readback[4..], accepted, matches)?;
    Ok(())
}
/// Best-effort CPU reference scan.
///
/// Returns the matches from `try_reference_scan`; if that fails, the error is
/// printed to stderr and an empty vector is returned instead.
#[must_use]
pub fn reference_scan(&self, haystack: &[u8]) -> Vec<Match> {
    self.try_reference_scan(haystack).unwrap_or_else(|error| {
        eprintln!("vyre-libs RulePipeline::reference_scan failed: {error}");
        Vec::new()
    })
}
/// CPU reference scan that returns its matches in a new vector.
///
/// # Errors
/// Propagates any error returned by `try_reference_scan_into`.
pub fn try_reference_scan(&self, haystack: &[u8]) -> Result<Vec<Match>, vyre::BackendError> {
    let mut collected = Vec::new();
    self.try_reference_scan_into(haystack, &mut collected)
        .map(|()| collected)
}
/// CPU reference scan that writes its matches into a caller-supplied vector.
///
/// For every start offset the NFA is seeded with state 0, epsilon-closed, and
/// stepped one haystack byte at a time; each accept state that is live after a
/// step records `Match::new(pattern_id, start, cursor + 1)`. The collected
/// matches are sorted before returning.
///
/// `results` is cleared on entry and is left empty whenever an error is
/// returned, so callers never observe stale or partial matches.
///
/// # Errors
/// Returns an error when `scan_guard` rejects the haystack, when an offset
/// does not fit in `u32`, or when the transition table has no row for a
/// reachable `(state, byte)` pair.
pub fn try_reference_scan_into(
    &self,
    haystack: &[u8],
    results: &mut Vec<Match>,
) -> Result<(), vyre::BackendError> {
    // Clear before validating so a rejected haystack cannot leave matches
    // from an earlier call behind (the dispatch path clears first as well).
    results.clear();
    crate::scan::dispatch_io::scan_guard(haystack, "RulePipeline::reference_scan", u32::MAX)?;
    let outcome = self.reference_scan_all_starts(haystack, results);
    if outcome.is_err() {
        // An error must not be accompanied by a partial match list.
        results.clear();
    }
    outcome
}

/// Runs the reference NFA from every start offset of `haystack`, appending
/// matches to `results` and sorting them at the end.
fn reference_scan_all_starts(
    &self,
    haystack: &[u8],
    results: &mut Vec<Match>,
) -> Result<(), vyre::BackendError> {
    let num_states = self.plan.num_states as usize;
    for start in 0..haystack.len() {
        let start_u32 = u32::try_from(start).map_err(|_| {
            vyre::BackendError::new(
                "RulePipeline::reference_scan start offset exceeds u32 capacity. Fix: split the haystack before parity scanning.",
            )
        })?;
        let mut state = [0_u32; NFA_LANES];
        let mut next = [0_u32; NFA_LANES];
        state[0] = 1;
        close_epsilon(&mut state, &self.epsilon_table, num_states);
        for (cursor, &byte) in haystack.iter().enumerate().skip(start) {
            next.fill(0);
            for (lane, &peer) in state.iter().enumerate() {
                if peer == 0 {
                    continue;
                }
                for bit in 0..32 {
                    if (peer >> bit) & 1 == 0 {
                        continue;
                    }
                    let src_state = lane * 32 + bit;
                    if src_state >= num_states {
                        continue;
                    }
                    let base = src_state * 256 * NFA_LANES + usize::from(byte) * NFA_LANES;
                    // Checked row lookup: a pipeline decoded from bytes can
                    // carry a table shorter than its plan, and that must be
                    // reported as an error instead of a panic.
                    let row = self
                        .transition_table
                        .get(base..base + NFA_LANES)
                        .ok_or_else(|| {
                            vyre::BackendError::new(
                                "RulePipeline::reference_scan transition table is shorter than the plan requires. Fix: rebuild the pipeline so the transition table covers every state.",
                            )
                        })?;
                    for (slot, &bits) in next.iter_mut().zip(row) {
                        *slot |= bits;
                    }
                }
            }
            std::mem::swap(&mut state, &mut next);
            close_epsilon(&mut state, &self.epsilon_table, num_states);
            for (&accept_state, &(pattern_id, _pattern_len)) in self
                .plan
                .accept_state_ids
                .iter()
                .zip(&self.plan.accept_states)
            {
                let lane = (accept_state / 32) as usize;
                let bit = accept_state % 32;
                if lane < state.len() && (state[lane] & (1_u32 << bit)) != 0 {
                    let end_u32 = u32::try_from(cursor + 1).map_err(|_| {
                        vyre::BackendError::new(
                            "RulePipeline::reference_scan end offset exceeds u32 capacity. Fix: split the haystack before parity scanning.",
                        )
                    })?;
                    results.push(Match::new(pattern_id, start_u32, end_u32));
                }
            }
            // With no live state left, every later step from this start
            // yields an empty set and no match, so stop early.
            if state.iter().all(|&word| word == 0) {
                break;
            }
        }
    }
    results.sort_unstable();
    Ok(())
}
}
/// Expands `state` in place to its epsilon closure.
///
/// Each live state below `num_states` ORs its `NFA_LANES`-word row of
/// `epsilon_table` (starting at `state * NFA_LANES`) into `state`. Words that
/// fall outside the table are skipped. The expansion repeats until a round
/// changes nothing, for at most `num_states` rounds. An empty table or a zero
/// state count leaves `state` untouched.
fn close_epsilon(state: &mut [u32; NFA_LANES], epsilon_table: &[u32], num_states: usize) {
    if num_states == 0 || epsilon_table.is_empty() {
        return;
    }
    for _round in 0..num_states {
        // Sources are read from a copy so that bits added during this round
        // are only expanded in the next one.
        let before = *state;
        let live_sources = before
            .iter()
            .enumerate()
            .flat_map(|(lane, &word)| {
                (0..32_usize)
                    .filter(move |&bit| (word >> bit) & 1 != 0)
                    .map(move |bit| lane * 32 + bit)
            })
            .filter(|&src_state| src_state < num_states);
        for src_state in live_sources {
            let row_start = src_state * NFA_LANES;
            for (offset, slot) in state.iter_mut().enumerate() {
                if let Some(&bits) = epsilon_table.get(row_start + offset) {
                    *slot |= bits;
                }
            }
        }
        if *state == before {
            break;
        }
    }
}
/// Builds a `RulePipeline` for `patterns`.
///
/// The plan, program, transition table and epsilon table are each derived
/// from `patterns` through the `nfa` module; `input_buf`, `hit_buf` and
/// `input_len` are forwarded to `nfa::nfa_scan`, and `input_len` is also
/// applied to the plan.
#[must_use]
pub fn build(patterns: &[&str], input_buf: &str, hit_buf: &str, input_len: u32) -> RulePipeline {
    // Field initialisers run in written order: plan, program, then tables.
    RulePipeline {
        plan: nfa::compile(patterns).for_input_len(input_len),
        program: nfa::nfa_scan(patterns, input_buf, hit_buf, input_len),
        transition_table: nfa::build_transition_table(patterns),
        epsilon_table: nfa::build_epsilon_table(patterns),
    }
}
/// Computes the byte length of a hit buffer able to hold `max_matches` matches.
///
/// The buffer is one counter word followed by three words per match, four
/// bytes per word: `(max_matches * 3 + 1) * 4`.
///
/// # Errors
/// Returns an error when `max_matches` does not fit in `usize` or when the
/// word or byte count overflows.
pub(crate) fn hit_buffer_byte_len(max_matches: u32) -> Result<usize, vyre::BackendError> {
    const WORDS_PER_MATCH: usize = 3;
    const COUNTER_WORDS: usize = 1;
    const BYTES_PER_WORD: usize = 4;

    let match_count = match usize::try_from(max_matches) {
        Ok(count) => count,
        Err(_) => {
            return Err(vyre::BackendError::new(
                "RulePipeline::scan max_matches exceeds host usize capacity. Fix: reduce max_matches or shard the scan.",
            ))
        }
    };
    let total_words = match match_count
        .checked_mul(WORDS_PER_MATCH)
        .and_then(|words| words.checked_add(COUNTER_WORDS))
    {
        Some(words) => words,
        None => {
            return Err(vyre::BackendError::new(
                "RulePipeline::scan hit-buffer word count overflowed. Fix: reduce max_matches or shard the scan.",
            ))
        }
    };
    match total_words.checked_mul(BYTES_PER_WORD) {
        Some(byte_len) => Ok(byte_len),
        None => Err(vyre::BackendError::new(
            "RulePipeline::scan hit-buffer byte count overflowed. Fix: reduce max_matches or shard the scan.",
        )),
    }
}
/// Allocates a new zero-filled hit buffer sized for `max_matches` matches.
///
/// # Errors
/// Propagates errors from `hit_buffer_byte_len` and `zeroed_hit_buffer_into`.
fn zeroed_hit_buffer(max_matches: u32) -> Result<Vec<u8>, vyre::BackendError> {
    let expected_len = hit_buffer_byte_len(max_matches)?;
    let mut buffer: Vec<u8> = Vec::new();
    match zeroed_hit_buffer_into(max_matches, &mut buffer) {
        Ok(()) => {
            debug_assert_eq!(buffer.len(), expected_len);
            Ok(buffer)
        }
        Err(error) => Err(error),
    }
}
/// Resets `bytes` to a zero-filled hit buffer sized for `max_matches` matches.
///
/// The vector is cleared, capacity is reserved through the fallible
/// reservation helper, and the vector is then resized with zeros to the
/// length given by `hit_buffer_byte_len`.
///
/// # Errors
/// Returns an error when the length computation overflows or the reservation
/// fails.
fn zeroed_hit_buffer_into(max_matches: u32, bytes: &mut Vec<u8>) -> Result<(), vyre::BackendError> {
    let byte_len = hit_buffer_byte_len(max_matches)?;
    bytes.clear();
    if let Err(source) =
        vyre_foundation::allocation::try_reserve_vec_to_capacity(bytes, byte_len)
    {
        return Err(vyre::BackendError::new(format!(
            "RulePipeline::scan could not reserve {byte_len} hit-buffer byte(s): {source}. Fix: lower max_matches or shard the scan."
        )));
    }
    bytes.resize(byte_len, 0);
    Ok(())
}
/// Reserves capacity in `vec` through the fallible reservation helper,
/// passing `requested` as the size.
///
/// # Errors
/// A reservation failure is reported as
/// `PipelineWireError::StorageReserveFailed`, carrying `field`, `requested`
/// and the underlying error's message.
fn reserve_wire_vec<T>(
    vec: &mut Vec<T>,
    requested: usize,
    field: &'static str,
) -> Result<(), PipelineWireError> {
    match vyre_foundation::allocation::try_reserve_vec_to_capacity(vec, requested) {
        Ok(()) => Ok(()),
        Err(source) => Err(PipelineWireError::StorageReserveFailed {
            field,
            requested,
            message: source.to_string(),
        }),
    }
}
/// Magic bytes passed to the wire writer and reader for `RulePipeline` blobs.
const PIPELINE_WIRE_MAGIC: &[u8; 4] = b"VRPL";
/// Wire format version passed to the wire writer and reader alongside the magic.
const PIPELINE_WIRE_VERSION: u32 = 4;
/// Errors produced while encoding or decoding a `RulePipeline` wire blob.
#[derive(Debug)]
#[non_exhaustive]
pub enum PipelineWireError {
/// The wire envelope writer or reader reported an error.
WireFraming(vyre_foundation::serial::envelope::EnvelopeError),
/// The embedded `Program` section failed to decode; carries the decoder's message.
InvalidProgram(String),
/// Array lengths are inconsistent, or a length computation overflowed while encoding.
ShapeMismatch {
/// Static description of the mismatch.
reason: &'static str,
},
/// A fallible vector reservation failed during serialization.
StorageReserveFailed {
/// Name of the storage that was being reserved.
field: &'static str,
/// Number of slots that were requested.
requested: usize,
/// Message of the underlying reservation error.
message: String,
},
}
/// Human-readable messages for each wire error; reservation failures also
/// carry a "Fix:" hint.
impl std::fmt::Display for PipelineWireError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::StorageReserveFailed {
                field,
                requested,
                message,
            } => f.write_fmt(format_args!(
                "RulePipeline wire serialization could not reserve {requested} {field} slot(s): {message}. Fix: shard the pattern pipeline before serialization."
            )),
            Self::ShapeMismatch { reason } => f.write_fmt(format_args!(
                "RulePipeline wire blob shape mismatch: {reason}"
            )),
            Self::InvalidProgram(msg) => f.write_fmt(format_args!(
                "RulePipeline wire blob has invalid Program: {msg}"
            )),
            Self::WireFraming(e) => {
                f.write_fmt(format_args!("RulePipeline wire envelope: {e}"))
            }
        }
    }
}
// Empty impl: the trait's default methods are used, so no `source()` override is provided.
impl std::error::Error for PipelineWireError {}
impl RulePipeline {
/// Serializes the pipeline into a wire blob.
///
/// Layout, in write order: `num_states`, `input_len`, the program section,
/// transition table words, epsilon table words, accept states flattened to
/// `(pattern_id, len)` word pairs, accept state ids, and one anchor-flag word
/// per accept state (bit 0 = start anchored, bit 1 = end anchored). A missing
/// anchor entry is written as "not anchored".
///
/// # Errors
/// Returns `WireFraming` when the writer rejects a section, `ShapeMismatch`
/// when the flattened accept-state word count overflows, and
/// `StorageReserveFailed` when a temporary vector cannot be reserved.
pub fn to_bytes(&self) -> Result<Vec<u8>, PipelineWireError> {
    use vyre_foundation::serial::envelope::WireWriter;

    let framing = PipelineWireError::WireFraming;
    let plan = &self.plan;
    let accept_count = plan.accept_states.len();

    let mut writer = WireWriter::new(PIPELINE_WIRE_MAGIC, PIPELINE_WIRE_VERSION);
    writer.write_u32(plan.num_states);
    writer.write_u32(plan.input_len);
    writer
        .write_section(&self.program.to_bytes())
        .map_err(framing)?;
    writer.write_words(&self.transition_table).map_err(framing)?;
    writer.write_words(&self.epsilon_table).map_err(framing)?;

    // Accept states go out as consecutive (pattern_id, len) word pairs.
    let pair_words = accept_count
        .checked_mul(2)
        .ok_or(PipelineWireError::ShapeMismatch {
            reason: "accept_states length overflows flattened word count",
        })?;
    let mut accept_flat: Vec<u32> = Vec::new();
    reserve_wire_vec(&mut accept_flat, pair_words, "accept state word")?;
    accept_flat.extend(
        plan.accept_states
            .iter()
            .flat_map(|&(pid, len)| [pid, len]),
    );
    writer.write_words(&accept_flat).map_err(framing)?;
    writer.write_words(&plan.accept_state_ids).map_err(framing)?;

    // One flag word per accept state: bit 0 = start anchor, bit 1 = end anchor.
    let mut anchor_flags: Vec<u32> = Vec::new();
    reserve_wire_vec(&mut anchor_flags, accept_count, "accept anchor flag")?;
    anchor_flags.extend((0..accept_count).map(|idx| {
        let start_anchored = plan
            .accept_start_anchored
            .get(idx)
            .copied()
            .unwrap_or(false);
        let end_anchored = plan.accept_end_anchored.get(idx).copied().unwrap_or(false);
        u32::from(start_anchored) | (u32::from(end_anchored) << 1)
    }));
    writer.write_words(&anchor_flags).map_err(framing)?;

    Ok(writer.into_bytes())
}
pub fn from_bytes(bytes: &[u8]) -> Result<Self, PipelineWireError> {
let mut r = vyre_foundation::serial::envelope::WireReader::new(
bytes,
PIPELINE_WIRE_MAGIC,
PIPELINE_WIRE_VERSION,
)
.map_err(PipelineWireError::WireFraming)?;
let num_states = r.read_u32().map_err(PipelineWireError::WireFraming)?;
let input_len = r.read_u32().map_err(PipelineWireError::WireFraming)?;
let program_bytes = r.read_section().map_err(PipelineWireError::WireFraming)?;
let program = vyre_foundation::ir::Program::from_bytes(program_bytes)
.map_err(|e| PipelineWireError::InvalidProgram(format!("{e}")))?;
let transition_table = r.read_words().map_err(PipelineWireError::WireFraming)?;
let epsilon_table = r.read_words().map_err(PipelineWireError::WireFraming)?;
let accept_flat = r.read_words().map_err(PipelineWireError::WireFraming)?;
let accept_state_ids = r.read_words().map_err(PipelineWireError::WireFraming)?;
let anchor_flags = r.read_words().map_err(PipelineWireError::WireFraming)?;
if accept_flat.len() % 2 != 0 {
return Err(PipelineWireError::ShapeMismatch {
reason: "accept_states array length is not even",
});
}
let accept_states: Vec<(u32, u32)> =
accept_flat.chunks_exact(2).map(|w| (w[0], w[1])).collect();
if accept_state_ids.len() != accept_states.len() {
return Err(PipelineWireError::ShapeMismatch {
reason: "accept_state_ids length disagrees with accept_states length",
});
}
if anchor_flags.len() != accept_states.len() {
return Err(PipelineWireError::ShapeMismatch {
reason: "accept anchor flag length disagrees with accept_states length",
});
}
let accept_start_anchored = anchor_flags.iter().map(|flags| flags & 1 != 0).collect();
let accept_end_anchored = anchor_flags.iter().map(|flags| flags & 2 != 0).collect();
Ok(RulePipeline {
program,
transition_table,
epsilon_table,
plan: nfa::NfaPlan {
num_states,
input_len,
accept_states,
accept_state_ids,
accept_start_anchored,
accept_end_anchored,
},
})
}
}
#[cfg(test)]
/// Unit tests for table shapes, the CPU reference scan, hit-buffer sizing and
/// readback validation of the dispatch path.
mod tests {
use super::*;
/// Backend stub: every dispatch ignores its inputs and returns a clone of `outputs`.
#[derive(Clone)]
struct RuleReadbackBackend {
outputs: Vec<Vec<u8>>,
}
// The stub implements the sealed marker trait so it can implement `VyreBackend`.
impl vyre::backend::private::Sealed for RuleReadbackBackend {}
impl VyreBackend for RuleReadbackBackend {
fn id(&self) -> &'static str {
"rule-readback-test"
}
fn dispatch(
&self,
_program: &Program,
_inputs: &[Vec<u8>],
_config: &vyre::DispatchConfig,
) -> Result<Vec<Vec<u8>>, vyre::BackendError> {
Ok(self.outputs.clone())
}
fn dispatch_borrowed(
&self,
_program: &Program,
_inputs: &[&[u8]],
_config: &vyre::DispatchConfig,
) -> Result<Vec<Vec<u8>>, vyre::BackendError> {
Ok(self.outputs.clone())
}
}
/// Builds a hit-buffer readback: little-endian `count` followed by the raw `triples` bytes.
fn hit_buffer_bytes(count: u32, triples: &[u8]) -> Vec<u8> {
let mut bytes = Vec::with_capacity(4 + triples.len());
bytes.extend_from_slice(&count.to_le_bytes());
bytes.extend_from_slice(triples);
bytes
}
/// Encodes one match as 12 bytes: `pattern_id`, `start`, `end`, each little-endian.
fn match_triple_bytes(pattern_id: u32, start: u32, end: u32) -> Vec<u8> {
let mut bytes = Vec::with_capacity(12);
bytes.extend_from_slice(&pattern_id.to_le_bytes());
bytes.extend_from_slice(&start.to_le_bytes());
bytes.extend_from_slice(&end.to_le_bytes());
bytes
}
/// Table lengths must equal `num_states * 256 * LANES` (transitions) and
/// `num_states * LANES` (epsilons).
#[test]
fn integrator_returns_primitive_compatible_tables() {
let pipe = build(&["abc"], "input", "hits", 16);
let plan = nfa::compile(&["abc"]);
let expected_trans_len = (plan.num_states as usize)
* 256
* vyre_primitives::nfa::subgroup_nfa::LANES_PER_SUBGROUP;
let expected_eps_len =
(plan.num_states as usize) * vyre_primitives::nfa::subgroup_nfa::LANES_PER_SUBGROUP;
assert_eq!(pipe.transition_table.len(), expected_trans_len);
assert_eq!(pipe.epsilon_table.len(), expected_eps_len);
}
/// The plan stored by `build` carries the expected state count, input length
/// and number of accept states.
#[test]
fn integrator_plan_matches_compile() {
let pipe = build(&["ab", "cd"], "input", "hits", 8);
assert_eq!(pipe.plan.num_states, 5);
assert_eq!(pipe.plan.input_len, 8);
assert_eq!(pipe.plan.accept_states.len(), 2);
}
/// The `_into` reference scan must agree with the owned variant and keep the
/// caller's vector capacity.
#[test]
fn rule_pipeline_reference_scan_into_matches_owned_scan_and_reuses_scratch() {
let pipe = build(&["ab", "bc"], "input", "hits", 16);
let owned = pipe.reference_scan(b"zabc");
let mut scratch = Vec::with_capacity(16);
let retained_capacity = scratch.capacity();
pipe.try_reference_scan_into(b"zabc", &mut scratch)
.expect("Fix: RulePipeline CPU oracle should scan small haystacks");
assert_eq!(scratch, owned);
assert!(scratch.capacity() >= retained_capacity);
assert_eq!(scratch, vec![Match::new(0, 1, 3), Match::new(1, 2, 4)]);
}
/// A hit buffer for 2 matches is `(2 * 3 + 1) * 4` bytes, all zero.
#[test]
fn rule_pipeline_hit_buffer_allocation_is_checked_and_zeroed() {
let bytes = super::zeroed_hit_buffer(2)
.expect("Fix: small RulePipeline hit buffer should allocate");
assert_eq!(bytes.len(), (2 * 3 + 1) * 4);
assert!(bytes.iter().all(|&byte| byte == 0));
}
/// Refilling a dirty, larger scratch vector must zero it, set the exact
/// length and keep its capacity.
#[test]
fn rule_pipeline_hit_buffer_into_reuses_and_zeroes_scratch() {
let mut scratch = vec![0xAA; 128];
let retained = scratch.capacity();
super::zeroed_hit_buffer_into(3, &mut scratch)
.expect("Fix: RulePipeline hit buffer scratch should reserve");
assert_eq!(scratch.len(), (3 * 3 + 1) * 4);
assert!(scratch.iter().all(|&byte| byte == 0));
assert!(scratch.capacity() >= retained);
}
/// A backend that returns no outputs must produce an error naming output
/// index 0, and stale matches must be gone.
#[test]
fn rule_pipeline_scan_rejects_missing_hit_output_slot() {
let pipe = build(&["ab"], "input", "hits", 16);
let backend = RuleReadbackBackend {
outputs: Vec::new(),
};
let mut matches = vec![Match::new(99, 1, 2)];
let err = pipe
.scan_into(&backend, b"ab", 1, &mut matches)
.expect_err("missing RulePipeline hit output must fail");
let msg = err.to_string();
assert!(
matches.is_empty(),
"scan errors must not expose stale matches"
);
assert!(
msg.contains("RulePipeline hit buffer") && msg.contains("output index 0"),
"RulePipeline missing-output error must identify the omitted slot: {msg}"
);
}
/// A 3-byte readback is too short for the 4-byte hit counter and must be rejected.
#[test]
fn rule_pipeline_scan_rejects_short_hit_counter_readback() {
let pipe = build(&["ab"], "input", "hits", 16);
let backend = RuleReadbackBackend {
outputs: vec![vec![1, 2, 3]],
};
let mut matches = vec![Match::new(99, 1, 2)];
let err = pipe
.scan_into(&backend, b"ab", 1, &mut matches)
.expect_err("short RulePipeline counter readback must fail");
let msg = err.to_string();
assert!(
matches.is_empty(),
"scan errors must not expose stale matches"
);
assert!(
msg.contains("RulePipeline hit buffer") && msg.contains("requires 4 bytes"),
"RulePipeline counter error must name the malformed hit buffer: {msg}"
);
}
/// A counter of 2 with only one 12-byte triple in the payload must be rejected.
#[test]
fn rule_pipeline_scan_rejects_match_payload_shorter_than_reported_count() {
let pipe = build(&["ab"], "input", "hits", 16);
let backend = RuleReadbackBackend {
outputs: vec![hit_buffer_bytes(2, &match_triple_bytes(0, 0, 2))],
};
let mut matches = vec![Match::new(99, 1, 2)];
let err = pipe
.scan_into(&backend, b"ab", 2, &mut matches)
.expect_err("short RulePipeline match payload must fail");
let msg = err.to_string();
assert!(
matches.is_empty(),
"scan errors must not expose stale matches"
);
assert!(
msg.contains("readback was 12 byte(s)")
&& msg.contains("count=2")
&& msg.contains("requires 24 byte(s)"),
"RulePipeline match-payload error must identify observed and required bytes: {msg}"
);
}
/// Source-text guard: reads `mega_scan.rs`, keeps the text before the first
/// test-cfg attribute, and checks that it declares the fixed-size state
/// arrays and contains neither of the two forbidden allocation spellings.
// NOTE(review): presumably `mega_scan.rs` is this very file — confirm; if so,
// comments in the production section must avoid the checked strings too.
#[test]
fn rule_pipeline_reference_scan_state_is_stack_backed() {
let production = include_str!("mega_scan.rs")
.split("#[cfg(test)]")
.next()
.expect("Fix: mega_scan.rs must contain production section");
assert!(
production.contains("let mut state = [0_u32; NFA_LANES];")
&& production.contains("let mut next = [0_u32; NFA_LANES];")
&& production.contains("next.fill(0);")
&& !production.contains("vec![0_u32;")
&& !production.contains("Vec::with_capacity"),
"Fix: RulePipeline scan and wire paths must use checked shared reservation helpers instead of nested subgroup vector allocation or infallible capacity allocation."
);
}
/// The built program must declare the buffer named by `nfa::HAYSTACK_LEN_BUF`.
#[test]
fn rule_pipeline_program_declares_haystack_len_buffer() {
let pipe = build(&["ab"], "input", "hits", 1024);
let names: Vec<&str> = pipe.program.buffers.iter().map(|b| b.name()).collect();
assert!(
names.iter().any(|n| *n == super::nfa::HAYSTACK_LEN_BUF),
"Fix: nfa_scan must declare `{}` so the cursor loop bound \
is runtime-supplied; without it, RulePipeline can only \
dispatch at exactly its compile-time input_len.",
super::nfa::HAYSTACK_LEN_BUF
);
}
/// The built program must declare the buffer named by `nfa::MAX_SCAN_BYTES_BUF`.
#[test]
fn rule_pipeline_program_declares_max_scan_bytes_buffer() {
let pipe = build(&["ab"], "input", "hits", 1024);
let names: Vec<&str> = pipe.program.buffers.iter().map(|b| b.name()).collect();
assert!(
names.iter().any(|n| *n == super::nfa::MAX_SCAN_BYTES_BUF),
"Fix: nfa_scan must declare `{}` so the per-workgroup \
cursor cap is runtime-supplied; without it, RulePipeline \
dispatches at O(N²) per shard.",
super::nfa::MAX_SCAN_BYTES_BUF
);
}
}