use crate::scan::classic_ac::{
build_ac_bounded_count_suffix3_prefilter_program,
build_ac_bounded_ranges_suffix3_prefilter_program_ext,
classic_ac_candidate_suffix3_bloom_words, presence_bitmap_words, presence_by_region_words,
try_build_ac_bounded_ranges_suffix3_prefilter_program_ext,
try_build_ac_bounded_ranges_suffix3_presence_by_region_program,
try_build_ac_bounded_ranges_suffix3_presence_program, CLASSIC_AC_SUFFIX2_MASK_WORDS,
};
use crate::scan::dfa::{dfa_compile, CompiledDfa};
use crate::scan::dispatch_io::ScanDispatchScratch;
use std::borrow::Cow;
use std::collections::TryReserveError;
use vyre::ir::{Expr, Node, Program};
use vyre::{DispatchConfig, VyreBackend};
pub use vyre_foundation::match_result::Match;
use vyre_primitives::hash::fnv1a::{fnv1a64_initial_state, fnv1a64_update_byte};
use vyre_primitives::matching::DfaWireError;
const LITERAL_SET_DEFAULT_MAX_MATCHES: u32 = 10_000;
const MATCH_TRIPLE_WORDS: u32 = 3;
const U32_BYTES: usize = std::mem::size_of::<u32>();
const U32_COUNTER_BYTES: usize = std::mem::size_of::<u32>();
const LITERAL_SET_INPUT_COUNT: usize = 10;
const LITERAL_SET_COUNT_INPUT_COUNT: usize = 8;
pub const LITERAL_SET_MATCH_COUNT_RESOURCE_INDEX: usize = 6;
pub const LITERAL_SET_MATCHES_RESOURCE_INDEX: usize = 10;
pub const LITERAL_SET_RESET_RESOURCE_INDICES: [usize; 1] = [LITERAL_SET_MATCH_COUNT_RESOURCE_INDEX];
pub const LITERAL_SET_SCAN_RESOURCE_INDICES: [usize; 11] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
pub const LITERAL_SET_COUNT_RESOURCE_INDEX: usize = 7;
pub const LITERAL_SET_COUNT_RESET_RESOURCE_INDICES: [usize; 1] = [LITERAL_SET_COUNT_RESOURCE_INDEX];
pub const LITERAL_SET_COUNT_SCAN_RESOURCE_INDICES: [usize; 8] = [0, 1, 2, 3, 4, 5, 6, 7];
pub type LiteralMatch = Match;
#[derive(Debug)]
pub enum LiteralSetCompileError {
PatternCountOverflow {
count: usize,
},
PatternLengthOverflow {
pattern_index: usize,
len: usize,
},
PatternByteCountOverflow,
PatternByteCountExceedsGpuAbi {
count: usize,
},
StorageReserveFailed {
field: &'static str,
requested: usize,
message: String,
},
DispatchProgramBuildFailed {
message: String,
},
}
impl std::fmt::Display for LiteralSetCompileError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::PatternCountOverflow { count } => write!(
f,
"literal_set pattern count {count} exceeds u32 capacity. Fix: shard the pattern set before GPU compilation."
),
Self::PatternLengthOverflow { pattern_index, len } => write!(
f,
"literal_set pattern {pattern_index} length {len} exceeds u32 capacity. Fix: split or reject oversized literals before GPU compilation."
),
Self::PatternByteCountOverflow => write!(
f,
"literal_set total pattern byte count overflowed host usize. Fix: shard the pattern set before GPU compilation."
),
Self::PatternByteCountExceedsGpuAbi { count } => write!(
f,
"literal_set total pattern byte count {count} exceeds u32 capacity. Fix: shard the pattern set before GPU compilation."
),
Self::StorageReserveFailed {
field,
requested,
message,
} => write!(
f,
"literal_set compile failed to reserve {requested} {field} slot(s): {message}. Fix: shard the pattern set before GPU compilation."
),
Self::DispatchProgramBuildFailed { message } => write!(
f,
"literal_set DFA dispatch program build failed: {message}"
),
}
}
}
impl std::error::Error for LiteralSetCompileError {}
pub struct GpuLiteralSet {
pub dfa: CompiledDfa,
pub pattern_bytes: Vec<u32>,
pub pattern_offsets: Vec<u32>,
pub pattern_lengths: Vec<u32>,
pub program: Program,
}
#[derive(Debug, Default)]
pub struct LiteralSetScanScratch {
pub dispatch: ScanDispatchScratch,
cached_program: Option<CachedLiteralSetProgram>,
cached_count_program: Option<CachedLiteralSetCountProgram>,
cached_prefilter: Option<LiteralSetPrefilterTables>,
}
#[derive(Clone, Debug)]
pub struct LiteralSetPreparedScan {
pub program: Program,
pub inputs: Vec<Vec<u8>>,
pub dispatch_config: DispatchConfig,
pub haystack_len: u32,
pub max_matches: u32,
pub matches_output_bytes: usize,
pub encoded_input_bytes: u64,
}
impl LiteralSetPreparedScan {
#[must_use]
pub const fn match_count_readback_bytes(&self) -> usize {
U32_COUNTER_BYTES
}
pub fn match_triples_readback_bytes(
&self,
match_count: u32,
) -> Result<usize, vyre::BackendError> {
literal_set_match_triple_bytes(match_count.min(self.max_matches))
}
pub fn decode_outputs_into(
&self,
outputs: &[Vec<u8>],
matches: &mut Vec<Match>,
) -> Result<(), vyre::BackendError> {
decode_literal_set_outputs_into(outputs, self.max_matches, matches)
}
}
#[derive(Clone, Debug)]
pub struct LiteralSetPreparedCount {
pub program: Program,
pub inputs: Vec<Vec<u8>>,
pub dispatch_config: DispatchConfig,
pub haystack_len: u32,
pub encoded_input_bytes: u64,
}
impl LiteralSetPreparedCount {
#[must_use]
pub const fn count_readback_bytes(&self) -> usize {
U32_COUNTER_BYTES
}
pub fn decode_outputs(&self, outputs: &[Vec<u8>]) -> Result<u32, vyre::BackendError> {
decode_literal_set_count_outputs(outputs)
}
}
#[derive(Debug)]
struct CachedLiteralSetProgram {
base_fingerprint: [u8; 32],
max_matches: u32,
program: Program,
}
#[derive(Debug)]
struct CachedLiteralSetCountProgram {
pattern_fingerprint: u64,
program: Program,
}
#[derive(Debug)]
struct LiteralSetPrefilterTables {
pattern_fingerprint: u64,
candidate_end_mask: [u32; 8],
candidate_suffix2_mask: [u32; CLASSIC_AC_SUFFIX2_MASK_WORDS],
candidate_suffix3_bloom: Vec<u32>,
}
impl GpuLiteralSet {
#[must_use]
pub fn compile(patterns: &[&[u8]]) -> Self {
match Self::try_compile(patterns) {
Ok(compiled) => compiled,
Err(error) => {
eprintln!("vyre-libs GpuLiteralSet::compile failed: {error}");
Self::empty_after_compile_failure()
}
}
}
pub fn try_compile(patterns: &[&[u8]]) -> Result<Self, LiteralSetCompileError> {
let dfa = dfa_compile(patterns);
let declared_pattern_count = u32::try_from(patterns.len()).map_err(|_| {
LiteralSetCompileError::PatternCountOverflow {
count: patterns.len(),
}
})?;
let total_pattern_bytes = patterns.iter().try_fold(0usize, |sum, pattern| {
sum.checked_add(pattern.len())
.ok_or(LiteralSetCompileError::PatternByteCountOverflow)
})?;
u32::try_from(total_pattern_bytes).map_err(|_| {
LiteralSetCompileError::PatternByteCountExceedsGpuAbi {
count: total_pattern_bytes,
}
})?;
let mut pattern_lengths = Vec::new();
reserve_vec(&mut pattern_lengths, patterns.len(), "pattern length")?;
let mut pattern_offsets = Vec::new();
reserve_vec(&mut pattern_offsets, patterns.len(), "pattern offset")?;
let mut pattern_bytes = Vec::new();
reserve_vec(
&mut pattern_bytes,
total_pattern_bytes,
"packed pattern byte",
)?;
for (pattern_index, pattern) in patterns.iter().enumerate() {
let offset = u32::try_from(pattern_bytes.len()).map_err(|_| {
LiteralSetCompileError::PatternByteCountExceedsGpuAbi {
count: pattern_bytes.len(),
}
})?;
let len = u32::try_from(pattern.len()).map_err(|_| {
LiteralSetCompileError::PatternLengthOverflow {
pattern_index,
len: pattern.len(),
}
})?;
pattern_offsets.push(offset);
pattern_lengths.push(len);
pattern_bytes.extend(pattern.iter().map(|&byte| u32::from(byte)));
}
let program = try_build_literal_set_program(&dfa, declared_pattern_count)
.map_err(|message| LiteralSetCompileError::DispatchProgramBuildFailed { message })?;
Ok(Self {
dfa,
pattern_bytes,
pattern_offsets,
pattern_lengths,
program,
})
}
fn empty_after_compile_failure() -> Self {
let dfa = dfa_compile(&[]);
let program = build_literal_set_program(&dfa, 0);
Self {
dfa,
pattern_bytes: Vec::new(),
pattern_offsets: Vec::new(),
pattern_lengths: Vec::new(),
program,
}
}
#[must_use]
pub fn reference_scan(&self, haystack: &[u8]) -> Vec<Match> {
let mut state = 0u32;
let mut results = Vec::new();
for (pos, &byte) in haystack.iter().enumerate() {
state = self.dfa.transitions[(state as usize) * 256 + (byte as usize)];
let begin = self.dfa.output_offsets[state as usize] as usize;
let end = self.dfa.output_offsets[state as usize + 1] as usize;
for &pattern_id in &self.dfa.output_records[begin..end] {
let len = self.pattern_lengths[pattern_id as usize];
results.push(Match::new(
pattern_id,
(pos as u32 + 1).saturating_sub(len),
pos as u32 + 1,
));
}
}
results.sort_unstable();
results
}
pub fn scan<B: VyreBackend + ?Sized>(
&self,
backend: &B,
haystack: &[u8],
max_matches: u32,
) -> Result<Vec<Match>, vyre::BackendError> {
let mut matches = Vec::new();
self.scan_into(backend, haystack, max_matches, &mut matches)?;
Ok(matches)
}
pub fn scan_into<B: VyreBackend + ?Sized>(
&self,
backend: &B,
haystack: &[u8],
max_matches: u32,
matches: &mut Vec<Match>,
) -> Result<(), vyre::BackendError> {
let mut scratch = ScanDispatchScratch::default();
self.scan_into_with_scratch(backend, haystack, max_matches, matches, &mut scratch)
}
pub fn count<B: VyreBackend + ?Sized>(
&self,
backend: &B,
haystack: &[u8],
) -> Result<u32, vyre::BackendError> {
let mut scratch = LiteralSetScanScratch::default();
self.count_with_literal_scratch(backend, haystack, &mut scratch)
}
pub fn scan_into_with_scratch<B: VyreBackend + ?Sized>(
&self,
backend: &B,
haystack: &[u8],
max_matches: u32,
matches: &mut Vec<Match>,
scratch: &mut ScanDispatchScratch,
) -> Result<(), vyre::BackendError> {
let dispatch_program = self.program_for_match_capacity(max_matches)?;
let prefilter_tables = self.build_prefilter_tables()?;
self.scan_into_with_program(
backend,
haystack,
max_matches,
matches,
scratch,
dispatch_program.as_ref(),
&prefilter_tables,
)
}
pub fn prepare_literal_scratch(
&self,
max_matches: u32,
scratch: &mut LiteralSetScanScratch,
) -> Result<(), vyre::BackendError> {
self.program_for_match_capacity_cached(max_matches, &mut scratch.cached_program)?;
self.prefilter_tables_cached(&mut scratch.cached_prefilter)?;
Ok(())
}
pub fn prepare_count_scratch(
&self,
scratch: &mut LiteralSetScanScratch,
) -> Result<(), vyre::BackendError> {
self.count_program_cached(&mut scratch.cached_count_program)?;
self.prefilter_tables_cached(&mut scratch.cached_prefilter)?;
Ok(())
}
pub fn prepare_scan_dispatch(
&self,
haystack: &[u8],
max_matches: u32,
) -> Result<LiteralSetPreparedScan, vyre::BackendError> {
let dispatch_program = self.program_for_match_capacity(max_matches)?;
let prefilter_tables = self.build_prefilter_tables()?;
self.prepare_scan_dispatch_with_program(
haystack,
max_matches,
dispatch_program.as_ref(),
&prefilter_tables,
)
}
pub fn prepare_count_dispatch(
&self,
haystack: &[u8],
) -> Result<LiteralSetPreparedCount, vyre::BackendError> {
let count_program = self.count_program();
let prefilter_tables = self.build_prefilter_tables()?;
self.prepare_count_dispatch_with_program(haystack, &count_program, &prefilter_tables)
}
pub fn scan_into_with_literal_scratch<B: VyreBackend + ?Sized>(
&self,
backend: &B,
haystack: &[u8],
max_matches: u32,
matches: &mut Vec<Match>,
scratch: &mut LiteralSetScanScratch,
) -> Result<(), vyre::BackendError> {
let cached_program = &mut scratch.cached_program;
let dispatch_program =
self.program_for_match_capacity_cached(max_matches, cached_program)?;
let prefilter_tables = self.prefilter_tables_cached(&mut scratch.cached_prefilter)?;
self.scan_into_with_program(
backend,
haystack,
max_matches,
matches,
&mut scratch.dispatch,
dispatch_program,
prefilter_tables,
)
}
pub fn count_with_literal_scratch<B: VyreBackend + ?Sized>(
&self,
backend: &B,
haystack: &[u8],
scratch: &mut LiteralSetScanScratch,
) -> Result<u32, vyre::BackendError> {
let count_program = self.count_program_cached(&mut scratch.cached_count_program)?;
let prefilter_tables = self.prefilter_tables_cached(&mut scratch.cached_prefilter)?;
self.count_with_program(
backend,
haystack,
&mut scratch.dispatch,
count_program,
prefilter_tables,
)
}
pub fn scan_presence<B: VyreBackend + ?Sized>(
&self,
backend: &B,
haystack: &[u8],
) -> Result<Vec<u32>, vyre::BackendError> {
let mut scratch = ScanDispatchScratch::default();
self.scan_presence_with_scratch(backend, haystack, &mut scratch)
}
pub fn scan_presence_with_scratch<B: VyreBackend + ?Sized>(
&self,
backend: &B,
haystack: &[u8],
scratch: &mut ScanDispatchScratch,
) -> Result<Vec<u32>, vyre::BackendError> {
use crate::scan::dispatch_io;
let pattern_count = u32::try_from(self.pattern_lengths.len()).map_err(|_| {
vyre::BackendError::new(
"literal_set presence: pattern count exceeds u32 GPU ABI".to_string(),
)
})?;
let presence_words = presence_bitmap_words(pattern_count) as usize;
let program =
try_build_ac_bounded_ranges_suffix3_presence_program(&self.dfa, pattern_count)
.map_err(vyre::BackendError::new)?;
let prefilter_tables = self.build_prefilter_tables()?;
let haystack_len = dispatch_io::scan_guard(
haystack,
"literal_set_presence",
dispatch_io::DEFAULT_MAX_SCAN_BYTES,
)?;
dispatch_io::pack_haystack_u32_into(haystack, &mut scratch.haystack_bytes)?;
let haystack_bytes = scratch.haystack_bytes.as_slice();
let transition_bytes = dispatch_io::u32_words_as_le_bytes(&self.dfa.transitions);
let output_offset_bytes = dispatch_io::u32_words_as_le_bytes(&self.dfa.output_offsets);
let output_record_bytes = dispatch_io::u32_words_as_le_bytes(&self.dfa.output_records);
let pattern_length_bytes = dispatch_io::u32_words_as_le_bytes(&self.pattern_lengths);
let candidate_end_mask_bytes =
dispatch_io::u32_words_as_le_bytes(&prefilter_tables.candidate_end_mask);
let candidate_suffix2_mask_bytes =
dispatch_io::u32_words_as_le_bytes(&prefilter_tables.candidate_suffix2_mask);
let candidate_suffix3_bloom_bytes =
dispatch_io::u32_words_as_le_bytes(&prefilter_tables.candidate_suffix3_bloom);
let haystack_len_word = [haystack_len];
let haystack_len_bytes = dispatch_io::u32_words_as_le_bytes(&haystack_len_word);
let presence_zeroed = vec![0u8; presence_words.saturating_mul(4)];
let config =
dispatch_io::byte_scan_dispatch_config(haystack_len, program.workgroup_size[0]);
let borrowed_inputs: smallvec::SmallVec<[&[u8]; 10]> = [
haystack_bytes, transition_bytes.as_ref(), output_offset_bytes.as_ref(), output_record_bytes.as_ref(), pattern_length_bytes.as_ref(), haystack_len_bytes.as_ref(), presence_zeroed.as_slice(), candidate_end_mask_bytes.as_ref(), candidate_suffix2_mask_bytes.as_ref(), candidate_suffix3_bloom_bytes.as_ref(), ]
.into_iter()
.collect();
let outputs = backend.dispatch_borrowed(&program, &borrowed_inputs, &config)?;
let presence_bytes = dispatch_io::try_output_bytes(&outputs, 0, "literal_set presence")?;
let words: Vec<u32> = presence_bytes
.chunks_exact(4)
.take(presence_words)
.map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
.collect();
Ok(words)
}
pub fn scan_presence_by_region<B: VyreBackend + ?Sized>(
&self,
backend: &B,
haystack: &[u8],
region_starts: &[u32],
) -> Result<Vec<u32>, vyre::BackendError> {
let mut scratch = ScanDispatchScratch::default();
self.scan_presence_by_region_with_scratch(backend, haystack, region_starts, 0, &mut scratch)
}
pub fn scan_presence_by_region_with_scratch<B: VyreBackend + ?Sized>(
&self,
backend: &B,
haystack: &[u8],
region_starts: &[u32],
region_base: u32,
scratch: &mut ScanDispatchScratch,
) -> Result<Vec<u32>, vyre::BackendError> {
use crate::scan::dispatch_io;
let pattern_count = u32::try_from(self.pattern_lengths.len()).map_err(|_| {
vyre::BackendError::new(
"literal_set region-presence: pattern count exceeds u32 GPU ABI".to_string(),
)
})?;
let region_count = u32::try_from(region_starts.len()).map_err(|_| {
vyre::BackendError::new(
"literal_set region-presence: region count exceeds u32 GPU ABI".to_string(),
)
})?;
if region_count == 0 {
return Err(vyre::BackendError::new(
"literal_set region-presence: region_starts must be non-empty. Fix: pass one start offset per coalesced file, beginning with 0.".to_string(),
));
}
if region_starts[0] != 0 {
return Err(vyre::BackendError::new(
"literal_set region-presence: region_starts[0] must be 0 (the kernel binary-search lower bound). Fix: the first coalesced file must start at offset 0.".to_string(),
));
}
let presence_words = presence_bitmap_words(pattern_count) as usize;
let total_words = presence_by_region_words(pattern_count, region_count) as usize;
let program = try_build_ac_bounded_ranges_suffix3_presence_by_region_program(
&self.dfa,
pattern_count,
region_count,
)
.map_err(vyre::BackendError::new)?;
let prefilter_tables = self.build_prefilter_tables()?;
let haystack_len = dispatch_io::scan_guard(
haystack,
"literal_set_presence_by_region",
dispatch_io::DEFAULT_MAX_SCAN_BYTES,
)?;
dispatch_io::pack_haystack_u32_into(haystack, &mut scratch.haystack_bytes)?;
let haystack_bytes = scratch.haystack_bytes.as_slice();
let transition_bytes = dispatch_io::u32_words_as_le_bytes(&self.dfa.transitions);
let output_offset_bytes = dispatch_io::u32_words_as_le_bytes(&self.dfa.output_offsets);
let output_record_bytes = dispatch_io::u32_words_as_le_bytes(&self.dfa.output_records);
let pattern_length_bytes = dispatch_io::u32_words_as_le_bytes(&self.pattern_lengths);
let candidate_end_mask_bytes =
dispatch_io::u32_words_as_le_bytes(&prefilter_tables.candidate_end_mask);
let candidate_suffix2_mask_bytes =
dispatch_io::u32_words_as_le_bytes(&prefilter_tables.candidate_suffix2_mask);
let candidate_suffix3_bloom_bytes =
dispatch_io::u32_words_as_le_bytes(&prefilter_tables.candidate_suffix3_bloom);
let haystack_len_word = [haystack_len];
let haystack_len_bytes = dispatch_io::u32_words_as_le_bytes(&haystack_len_word);
let region_starts_bytes = dispatch_io::u32_words_as_le_bytes(region_starts);
let region_base_bytes = region_base.to_le_bytes();
let presence_zeroed = vec![0u8; total_words.saturating_mul(4)];
let config =
dispatch_io::byte_scan_dispatch_config(haystack_len, program.workgroup_size[0]);
let borrowed_inputs: smallvec::SmallVec<[&[u8]; 12]> = [
haystack_bytes, transition_bytes.as_ref(), output_offset_bytes.as_ref(), output_record_bytes.as_ref(), pattern_length_bytes.as_ref(), haystack_len_bytes.as_ref(), presence_zeroed.as_slice(), candidate_end_mask_bytes.as_ref(), candidate_suffix2_mask_bytes.as_ref(), candidate_suffix3_bloom_bytes.as_ref(), region_starts_bytes.as_ref(), region_base_bytes.as_slice(), ]
.into_iter()
.collect();
let outputs = backend.dispatch_borrowed(&program, &borrowed_inputs, &config)?;
let presence_bytes =
dispatch_io::try_output_bytes(&outputs, 0, "literal_set presence_by_region")?;
let words: Vec<u32> = presence_bytes
.chunks_exact(4)
.take(total_words)
.map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
.collect();
Ok(words)
}
fn scan_into_with_program<B: VyreBackend + ?Sized>(
&self,
backend: &B,
haystack: &[u8],
max_matches: u32,
matches: &mut Vec<Match>,
scratch: &mut ScanDispatchScratch,
dispatch_program: &Program,
prefilter_tables: &LiteralSetPrefilterTables,
) -> Result<(), vyre::BackendError> {
use crate::scan::dispatch_io;
matches.clear();
let haystack_len =
dispatch_io::scan_guard(haystack, "literal_set", dispatch_io::DEFAULT_MAX_SCAN_BYTES)?;
dispatch_io::pack_haystack_u32_into(haystack, &mut scratch.haystack_bytes)?;
let haystack_bytes = scratch.haystack_bytes.as_slice();
let transition_bytes = dispatch_io::u32_words_as_le_bytes(&self.dfa.transitions);
let output_offset_bytes = dispatch_io::u32_words_as_le_bytes(&self.dfa.output_offsets);
let output_record_bytes = dispatch_io::u32_words_as_le_bytes(&self.dfa.output_records);
let pattern_length_bytes = dispatch_io::u32_words_as_le_bytes(&self.pattern_lengths);
let candidate_end_mask_bytes =
dispatch_io::u32_words_as_le_bytes(&prefilter_tables.candidate_end_mask);
let candidate_suffix2_mask_bytes =
dispatch_io::u32_words_as_le_bytes(&prefilter_tables.candidate_suffix2_mask);
let candidate_suffix3_bloom_bytes =
dispatch_io::u32_words_as_le_bytes(&prefilter_tables.candidate_suffix3_bloom);
let haystack_len_word = [haystack_len];
let haystack_len_bytes = dispatch_io::u32_words_as_le_bytes(&haystack_len_word);
let match_count_bytes = [0u8; 4];
let config = dispatch_io::byte_scan_dispatch_config(
haystack_len,
dispatch_program.workgroup_size[0],
);
let borrowed_inputs: smallvec::SmallVec<[&[u8]; 10]> = [
haystack_bytes,
transition_bytes.as_ref(),
output_offset_bytes.as_ref(),
output_record_bytes.as_ref(),
pattern_length_bytes.as_ref(),
haystack_len_bytes.as_ref(),
match_count_bytes.as_slice(),
candidate_end_mask_bytes.as_ref(),
candidate_suffix2_mask_bytes.as_ref(),
candidate_suffix3_bloom_bytes.as_ref(),
]
.into_iter()
.collect();
let outputs = backend.dispatch_borrowed(&dispatch_program, &borrowed_inputs, &config)?;
decode_literal_set_outputs_into(&outputs, max_matches, matches)?;
Ok(())
}
fn count_with_program<B: VyreBackend + ?Sized>(
&self,
backend: &B,
haystack: &[u8],
scratch: &mut ScanDispatchScratch,
count_program: &Program,
prefilter_tables: &LiteralSetPrefilterTables,
) -> Result<u32, vyre::BackendError> {
use crate::scan::dispatch_io;
let haystack_len =
dispatch_io::scan_guard(haystack, "literal_set", dispatch_io::DEFAULT_MAX_SCAN_BYTES)?;
dispatch_io::pack_haystack_u32_into(haystack, &mut scratch.haystack_bytes)?;
let haystack_bytes = scratch.haystack_bytes.as_slice();
let transition_bytes = dispatch_io::u32_words_as_le_bytes(&self.dfa.transitions);
let output_offset_bytes = dispatch_io::u32_words_as_le_bytes(&self.dfa.output_offsets);
let candidate_end_mask_bytes =
dispatch_io::u32_words_as_le_bytes(&prefilter_tables.candidate_end_mask);
let candidate_suffix2_mask_bytes =
dispatch_io::u32_words_as_le_bytes(&prefilter_tables.candidate_suffix2_mask);
let candidate_suffix3_bloom_bytes =
dispatch_io::u32_words_as_le_bytes(&prefilter_tables.candidate_suffix3_bloom);
let haystack_len_word = [haystack_len];
let haystack_len_bytes = dispatch_io::u32_words_as_le_bytes(&haystack_len_word);
let match_count_bytes = [0u8; U32_COUNTER_BYTES];
let config =
dispatch_io::byte_scan_dispatch_config(haystack_len, count_program.workgroup_size[0]);
let borrowed_inputs: smallvec::SmallVec<[&[u8]; 8]> = [
haystack_bytes,
transition_bytes.as_ref(),
output_offset_bytes.as_ref(),
candidate_end_mask_bytes.as_ref(),
candidate_suffix2_mask_bytes.as_ref(),
candidate_suffix3_bloom_bytes.as_ref(),
haystack_len_bytes.as_ref(),
match_count_bytes.as_slice(),
]
.into_iter()
.collect();
let outputs = backend.dispatch_borrowed(count_program, &borrowed_inputs, &config)?;
decode_literal_set_count_outputs(&outputs)
}
fn prepare_scan_dispatch_with_program(
&self,
haystack: &[u8],
max_matches: u32,
dispatch_program: &Program,
prefilter_tables: &LiteralSetPrefilterTables,
) -> Result<LiteralSetPreparedScan, vyre::BackendError> {
use crate::scan::dispatch_io;
let haystack_len =
dispatch_io::scan_guard(haystack, "literal_set", dispatch_io::DEFAULT_MAX_SCAN_BYTES)?;
let (_, matches_output_bytes) = literal_set_match_output_layout(max_matches)?;
let mut inputs = Vec::new();
vyre_foundation::allocation::try_reserve_vec_to_capacity(
&mut inputs,
LITERAL_SET_INPUT_COUNT,
)
.map_err(|source| {
vyre::BackendError::new(format!(
"literal_set prepared scan could not reserve {LITERAL_SET_INPUT_COUNT} input buffer slot(s): {source}. Fix: shard the literal set or haystack before preparing resident dispatch."
))
})?;
let mut haystack_bytes = Vec::new();
dispatch_io::pack_haystack_u32_into(haystack, &mut haystack_bytes)?;
inputs.push(haystack_bytes);
inputs.push(copy_u32_words_as_le_bytes(
&self.dfa.transitions,
"transition table",
)?);
inputs.push(copy_u32_words_as_le_bytes(
&self.dfa.output_offsets,
"output offset table",
)?);
inputs.push(copy_u32_words_as_le_bytes(
&self.dfa.output_records,
"output record table",
)?);
inputs.push(copy_u32_words_as_le_bytes(
&self.pattern_lengths,
"pattern length table",
)?);
inputs.push(haystack_len.to_le_bytes().to_vec());
inputs.push(vec![0_u8; U32_COUNTER_BYTES]);
inputs.push(copy_u32_words_as_le_bytes(
&prefilter_tables.candidate_end_mask,
"candidate end mask",
)?);
inputs.push(copy_u32_words_as_le_bytes(
&prefilter_tables.candidate_suffix2_mask,
"candidate suffix2 mask",
)?);
inputs.push(copy_u32_words_as_le_bytes(
&prefilter_tables.candidate_suffix3_bloom,
"candidate suffix3 bloom",
)?);
let encoded_input_bytes = inputs.iter().try_fold(0_u64, |sum, input| {
let len = u64::try_from(input.len()).map_err(|source| {
vyre::BackendError::new(format!(
"literal_set prepared scan input byte length does not fit u64: {source}. Fix: shard the scan before dispatch."
))
})?;
sum.checked_add(len).ok_or_else(|| {
vyre::BackendError::new(
"literal_set prepared scan input byte total overflowed u64. Fix: shard the scan before dispatch.",
)
})
})?;
Ok(LiteralSetPreparedScan {
program: dispatch_program.clone(),
inputs,
dispatch_config: dispatch_io::byte_scan_dispatch_config(
haystack_len,
dispatch_program.workgroup_size[0],
),
haystack_len,
max_matches,
matches_output_bytes,
encoded_input_bytes,
})
}
fn prepare_count_dispatch_with_program(
&self,
haystack: &[u8],
count_program: &Program,
prefilter_tables: &LiteralSetPrefilterTables,
) -> Result<LiteralSetPreparedCount, vyre::BackendError> {
use crate::scan::dispatch_io;
let haystack_len =
dispatch_io::scan_guard(haystack, "literal_set", dispatch_io::DEFAULT_MAX_SCAN_BYTES)?;
let mut inputs = Vec::new();
vyre_foundation::allocation::try_reserve_vec_to_capacity(
&mut inputs,
LITERAL_SET_COUNT_INPUT_COUNT,
)
.map_err(|source| {
vyre::BackendError::new(format!(
"literal_set prepared count could not reserve {LITERAL_SET_COUNT_INPUT_COUNT} input buffer slot(s): {source}. Fix: shard the literal set or haystack before preparing resident dispatch."
))
})?;
let mut haystack_bytes = Vec::new();
dispatch_io::pack_haystack_u32_into(haystack, &mut haystack_bytes)?;
inputs.push(haystack_bytes);
inputs.push(copy_u32_words_as_le_bytes(
&self.dfa.transitions,
"transition table",
)?);
inputs.push(copy_u32_words_as_le_bytes(
&self.dfa.output_offsets,
"output offset table",
)?);
inputs.push(copy_u32_words_as_le_bytes(
&prefilter_tables.candidate_end_mask,
"candidate end mask",
)?);
inputs.push(copy_u32_words_as_le_bytes(
&prefilter_tables.candidate_suffix2_mask,
"candidate suffix2 mask",
)?);
inputs.push(copy_u32_words_as_le_bytes(
&prefilter_tables.candidate_suffix3_bloom,
"candidate suffix3 bloom",
)?);
inputs.push(haystack_len.to_le_bytes().to_vec());
inputs.push(vec![0_u8; U32_COUNTER_BYTES]);
let encoded_input_bytes = inputs.iter().try_fold(0_u64, |sum, input| {
let len = u64::try_from(input.len()).map_err(|source| {
vyre::BackendError::new(format!(
"literal_set prepared count input byte length does not fit u64: {source}. Fix: shard the scan before dispatch."
))
})?;
sum.checked_add(len).ok_or_else(|| {
vyre::BackendError::new(
"literal_set prepared count input byte total overflowed u64. Fix: shard the scan before dispatch.",
)
})
})?;
Ok(LiteralSetPreparedCount {
program: count_program.clone(),
inputs,
dispatch_config: dispatch_io::byte_scan_dispatch_config(
haystack_len,
count_program.workgroup_size[0],
),
haystack_len,
encoded_input_bytes,
})
}
fn prefilter_tables_cached<'a>(
&'a self,
cached_prefilter: &'a mut Option<LiteralSetPrefilterTables>,
) -> Result<&'a LiteralSetPrefilterTables, vyre::BackendError> {
let pattern_fingerprint = self.pattern_fingerprint();
let reuse_cached = cached_prefilter
.as_ref()
.is_some_and(|cached| cached.pattern_fingerprint == pattern_fingerprint);
if !reuse_cached {
*cached_prefilter =
Some(self.build_prefilter_tables_with_fingerprint(pattern_fingerprint)?);
}
cached_prefilter.as_ref().ok_or_else(|| {
vyre::BackendError::new(
"literal_set failed to retain cached suffix-prefilter tables. Fix: retry with generic ScanDispatchScratch.",
)
})
}
fn build_prefilter_tables(&self) -> Result<LiteralSetPrefilterTables, vyre::BackendError> {
self.build_prefilter_tables_with_fingerprint(self.pattern_fingerprint())
}
fn count_program_cached<'a>(
&'a self,
cached_count_program: &'a mut Option<CachedLiteralSetCountProgram>,
) -> Result<&'a Program, vyre::BackendError> {
let pattern_fingerprint = self.pattern_fingerprint();
let reuse_cached = cached_count_program
.as_ref()
.is_some_and(|cached| cached.pattern_fingerprint == pattern_fingerprint);
if !reuse_cached {
*cached_count_program = Some(CachedLiteralSetCountProgram {
pattern_fingerprint,
program: self.count_program(),
});
}
cached_count_program
.as_ref()
.map(|cached| &cached.program)
.ok_or_else(|| {
vyre::BackendError::new(
"literal_set failed to retain the cached count program. Fix: retry without reusable scratch.",
)
})
}
fn count_program(&self) -> Program {
build_ac_bounded_count_suffix3_prefilter_program(&self.dfa)
}
fn build_prefilter_tables_with_fingerprint(
&self,
pattern_fingerprint: u64,
) -> Result<LiteralSetPrefilterTables, vyre::BackendError> {
let pattern_vectors = self.materialize_pattern_bytes()?;
let pattern_refs = pattern_vectors
.iter()
.map(Vec::as_slice)
.collect::<Vec<_>>();
Ok(LiteralSetPrefilterTables {
pattern_fingerprint,
candidate_end_mask: literal_set_candidate_end_byte_mask_words(&pattern_refs),
candidate_suffix2_mask: literal_set_candidate_suffix2_mask_words(&pattern_refs),
candidate_suffix3_bloom: classic_ac_candidate_suffix3_bloom_words(&pattern_refs),
})
}
fn materialize_pattern_bytes(&self) -> Result<Vec<Vec<u8>>, vyre::BackendError> {
if self.pattern_offsets.len() != self.pattern_lengths.len() {
return Err(vyre::BackendError::new(format!(
"literal_set pattern metadata is malformed: {} offsets for {} lengths. Fix: rebuild the literal set with GpuLiteralSet::try_compile before dispatch.",
self.pattern_offsets.len(),
self.pattern_lengths.len()
)));
}
let mut patterns = Vec::new();
vyre_foundation::allocation::try_reserve_vec_to_capacity(
&mut patterns,
self.pattern_lengths.len(),
)
.map_err(|source| {
vyre::BackendError::new(format!(
"literal_set could not reserve {} decoded pattern slot(s): {source}. Fix: shard the pattern set before dispatch.",
self.pattern_lengths.len()
))
})?;
for (pattern_index, (&offset, &len)) in self
.pattern_offsets
.iter()
.zip(&self.pattern_lengths)
.enumerate()
{
let start = usize::try_from(offset).map_err(|source| {
vyre::BackendError::new(format!(
"literal_set pattern {pattern_index} offset {offset} cannot fit host usize: {source}. Fix: rebuild the literal set with GpuLiteralSet::try_compile before dispatch."
))
})?;
let len = usize::try_from(len).map_err(|source| {
vyre::BackendError::new(format!(
"literal_set pattern {pattern_index} length {len} cannot fit host usize: {source}. Fix: rebuild the literal set with GpuLiteralSet::try_compile before dispatch."
))
})?;
let end = start.checked_add(len).ok_or_else(|| {
vyre::BackendError::new(format!(
"literal_set pattern {pattern_index} byte range overflows host usize. Fix: rebuild the literal set with GpuLiteralSet::try_compile before dispatch."
))
})?;
let words = self.pattern_bytes.get(start..end).ok_or_else(|| {
vyre::BackendError::new(format!(
"literal_set pattern {pattern_index} byte range {start}..{end} exceeds packed pattern byte table length {}. Fix: rebuild the literal set with GpuLiteralSet::try_compile before dispatch.",
self.pattern_bytes.len()
))
})?;
let mut pattern = Vec::new();
vyre_foundation::allocation::try_reserve_vec_to_capacity(&mut pattern, words.len())
.map_err(|source| {
vyre::BackendError::new(format!(
"literal_set could not reserve {} byte(s) for pattern {pattern_index}: {source}. Fix: shard the pattern set before dispatch.",
words.len()
))
})?;
for (byte_index, &word) in words.iter().enumerate() {
let byte = u8::try_from(word).map_err(|source| {
vyre::BackendError::new(format!(
"literal_set pattern {pattern_index} byte {byte_index} has non-byte word {word}: {source}. Fix: rebuild the literal set with GpuLiteralSet::try_compile before dispatch."
))
})?;
pattern.push(byte);
}
patterns.push(pattern);
}
Ok(patterns)
}
fn pattern_fingerprint(&self) -> u64 {
let mut hash = fnv1a64_initial_state();
for words in [
self.pattern_offsets.as_slice(),
self.pattern_lengths.as_slice(),
self.pattern_bytes.as_slice(),
] {
for &word in words {
for byte in word.to_le_bytes() {
hash = fnv1a64_update_byte(hash, byte);
}
}
}
hash
}
fn program_for_match_capacity_cached<'a>(
&'a self,
max_matches: u32,
cached_program: &'a mut Option<CachedLiteralSetProgram>,
) -> Result<&'a Program, vyre::BackendError> {
let (declared_words, readback_bytes) = literal_set_match_output_layout(max_matches)?;
if self.compiled_matches_output_satisfies(declared_words, readback_bytes)? {
return Ok(&self.program);
}
let base_fingerprint = self.program.fingerprint();
let reuse_cached = cached_program.as_ref().is_some_and(|cached| {
cached.max_matches == max_matches && cached.base_fingerprint == base_fingerprint
});
if !reuse_cached {
let program = self.rewrite_program_for_match_layout(declared_words, readback_bytes);
*cached_program = Some(CachedLiteralSetProgram {
base_fingerprint,
max_matches,
program,
});
}
match cached_program.as_ref() {
Some(cached) => Ok(&cached.program),
None => Err(vyre::BackendError::new(
"literal_set failed to retain the cached match-capacity program. Fix: retry with generic ScanDispatchScratch.",
)),
}
}
fn program_for_match_capacity(
&self,
max_matches: u32,
) -> Result<Cow<'_, Program>, vyre::BackendError> {
let (declared_words, readback_bytes) = literal_set_match_output_layout(max_matches)?;
if self.compiled_matches_output_satisfies(declared_words, readback_bytes)? {
return Ok(Cow::Borrowed(&self.program));
}
Ok(Cow::Owned(self.rewrite_program_for_match_layout(
declared_words,
readback_bytes,
)))
}
fn compiled_matches_output_satisfies(
&self,
declared_words: u32,
readback_bytes: usize,
) -> Result<bool, vyre::BackendError> {
let matches_output = self
.program
.buffers()
.iter()
.find(|buffer| buffer.name() == "matches" && buffer.is_output())
.ok_or_else(|| {
vyre::BackendError::new(
"literal_set program is missing its matches output buffer. Fix: rebuild the literal set with GpuLiteralSet::try_compile before dispatch.",
)
})?;
Ok(matches_output.count == declared_words
&& (matches_output.output_byte_range().is_none()
|| matches_output.output_byte_range() == Some(0..readback_bytes)))
}
fn rewrite_program_for_match_layout(
&self,
declared_words: u32,
readback_bytes: usize,
) -> Program {
let buffers = self
.program
.buffers()
.iter()
.cloned()
.map(|buffer| {
if buffer.name() == "matches" && buffer.is_output() {
buffer
.with_count(declared_words)
.with_output_byte_range(0..readback_bytes)
} else {
buffer
}
})
.collect::<Vec<_>>();
self.program.with_rewritten_buffers(buffers)
}
pub fn to_bytes(&self) -> Result<Vec<u8>, LiteralSetWireError> {
let mut w = vyre_foundation::serial::envelope::WireWriter::new(
LITERAL_SET_WIRE_MAGIC,
LITERAL_SET_WIRE_VERSION,
);
w.write_section(&self.program.to_bytes())
.map_err(LiteralSetWireError::WireFraming)?;
let dfa_bytes = self
.dfa
.to_bytes()
.map_err(LiteralSetWireError::InvalidDfa)?;
w.write_section(&dfa_bytes)
.map_err(LiteralSetWireError::WireFraming)?;
w.write_words(&self.pattern_offsets)
.map_err(LiteralSetWireError::WireFraming)?;
w.write_words(&self.pattern_lengths)
.map_err(LiteralSetWireError::WireFraming)?;
w.write_words(&self.pattern_bytes)
.map_err(LiteralSetWireError::WireFraming)?;
Ok(w.into_bytes())
}
pub fn from_bytes(bytes: &[u8]) -> Result<Self, LiteralSetWireError> {
let (mut r, wire_version) =
literal_set_wire_reader(bytes).map_err(LiteralSetWireError::WireFraming)?;
let program_bytes = r.read_section().map_err(LiteralSetWireError::WireFraming)?;
if wire_version == LITERAL_SET_WIRE_VERSION {
Program::from_bytes(program_bytes)
.map_err(|e| LiteralSetWireError::InvalidProgram(format!("{e}")))?;
}
let dfa_bytes = r.read_section().map_err(LiteralSetWireError::WireFraming)?;
let dfa = CompiledDfa::from_bytes(dfa_bytes).map_err(LiteralSetWireError::InvalidDfa)?;
let pattern_offsets = r.read_words().map_err(LiteralSetWireError::WireFraming)?;
let pattern_lengths = r.read_words().map_err(LiteralSetWireError::WireFraming)?;
let pattern_bytes = r.read_words().map_err(LiteralSetWireError::WireFraming)?;
let pattern_count =
u32::try_from(pattern_lengths.len()).map_err(|source| {
LiteralSetWireError::InvalidProgram(format!(
"literal_set decoded pattern length count {} exceeds u32 GPU buffer metadata: {source}. Fix: shard the pattern set before caching.",
pattern_lengths.len()
))
})?;
let program = try_build_literal_set_program(&dfa, pattern_count).map_err(|message| {
LiteralSetWireError::InvalidProgram(format!(
"literal_set decoded DFA cannot rebuild current dispatch Program: {message}"
))
})?;
Ok(Self {
dfa,
pattern_bytes,
pattern_offsets,
pattern_lengths,
program,
})
}
}
fn literal_set_wire_reader(
bytes: &[u8],
) -> Result<
(vyre_foundation::serial::envelope::WireReader<'_>, u32),
vyre_foundation::serial::envelope::EnvelopeError,
> {
match vyre_foundation::serial::envelope::WireReader::new(
bytes,
LITERAL_SET_WIRE_MAGIC,
LITERAL_SET_WIRE_VERSION,
) {
Ok(reader) => Ok((reader, LITERAL_SET_WIRE_VERSION)),
Err(vyre_foundation::serial::envelope::EnvelopeError::VersionMismatch {
found:
legacy_version @ (LITERAL_SET_LEGACY_LITERAL_COMPARE_WIRE_VERSION
| LITERAL_SET_LEGACY_BOUNDED_DFA_WIRE_VERSION),
..
}) => vyre_foundation::serial::envelope::WireReader::new(
bytes,
LITERAL_SET_WIRE_MAGIC,
legacy_version,
)
.map(|reader| (reader, legacy_version)),
Err(error) => Err(error),
}
}
fn literal_set_candidate_end_byte_mask_words(patterns: &[&[u8]]) -> [u32; 8] {
let mut mask = [0_u32; 8];
for pattern in patterns
.iter()
.copied()
.filter(|pattern| !pattern.is_empty())
{
let byte = usize::from(pattern[pattern.len() - 1]);
mask[byte / 32] |= 1_u32 << (byte % 32);
}
mask
}
fn literal_set_candidate_suffix2_mask_words(
patterns: &[&[u8]],
) -> [u32; CLASSIC_AC_SUFFIX2_MASK_WORDS] {
let mut mask = [0_u32; CLASSIC_AC_SUFFIX2_MASK_WORDS];
for pattern in patterns
.iter()
.copied()
.filter(|pattern| !pattern.is_empty())
{
match pattern.len() {
1 => {
let current = usize::from(pattern[0]);
for previous in 0..=u8::MAX {
set_suffix2_candidate_bit(&mut mask, usize::from(previous), current);
}
}
len => {
set_suffix2_candidate_bit(
&mut mask,
usize::from(pattern[len - 2]),
usize::from(pattern[len - 1]),
);
}
}
}
mask
}
fn set_suffix2_candidate_bit(
mask: &mut [u32; CLASSIC_AC_SUFFIX2_MASK_WORDS],
previous: usize,
current: usize,
) {
let suffix = (previous << 8) | current;
mask[suffix / 32] |= 1_u32 << (suffix % 32);
}
fn reserve_vec<T>(
vec: &mut Vec<T>,
requested: usize,
field: &'static str,
) -> Result<(), LiteralSetCompileError> {
vyre_foundation::allocation::try_reserve_vec_to_capacity(vec, requested).map_err(
|source: TryReserveError| LiteralSetCompileError::StorageReserveFailed {
field,
requested,
message: source.to_string(),
},
)
}
fn copy_u32_words_as_le_bytes(
words: &[u32],
field: &'static str,
) -> Result<Vec<u8>, vyre::BackendError> {
let byte_len = words.len().checked_mul(U32_BYTES).ok_or_else(|| {
vyre::BackendError::new(format!(
"literal_set prepared scan {field} byte length overflowed host usize. Fix: shard the literal set before preparing resident dispatch."
))
})?;
let mut bytes = Vec::new();
vyre_foundation::allocation::try_reserve_vec_to_capacity(&mut bytes, byte_len).map_err(
|source| {
vyre::BackendError::new(format!(
"literal_set prepared scan could not reserve {byte_len} byte(s) for {field}: {source}. Fix: shard the literal set before preparing resident dispatch."
))
},
)?;
if cfg!(target_endian = "little") {
bytes.extend_from_slice(bytemuck::cast_slice(words));
} else {
for &word in words {
bytes.extend_from_slice(&word.to_le_bytes());
}
}
Ok(bytes)
}
fn decode_literal_set_outputs_into(
outputs: &[Vec<u8>],
max_matches: u32,
matches: &mut Vec<Match>,
) -> Result<(), vyre::BackendError> {
let count_bytes =
crate::scan::dispatch_io::try_output_bytes(outputs, 0, "literal_set match count")?;
let count =
crate::scan::dispatch_io::try_read_u32_prefix(count_bytes, "literal_set match count")?;
let matches_bytes =
crate::scan::dispatch_io::try_output_bytes(outputs, 1, "literal_set matches")?;
crate::scan::dispatch_io::try_unpack_match_triples_exact_prefix_into(
matches_bytes,
count.min(max_matches),
matches,
)
}
fn decode_literal_set_count_outputs(outputs: &[Vec<u8>]) -> Result<u32, vyre::BackendError> {
let count_bytes = crate::scan::dispatch_io::try_output_bytes(outputs, 0, "literal_set count")?;
crate::scan::dispatch_io::try_read_u32_prefix(count_bytes, "literal_set count")
}
fn literal_set_match_triple_bytes(count: u32) -> Result<usize, vyre::BackendError> {
let words = count.checked_mul(MATCH_TRIPLE_WORDS).ok_or_else(|| {
vyre::BackendError::new(format!(
"literal_set match count {count} overflows the GPU match-output word count. Fix: lower max_matches or split the scan before dispatch."
))
})?;
usize::try_from(words)
.ok()
.and_then(|words| words.checked_mul(U32_BYTES))
.ok_or_else(|| {
vyre::BackendError::new(format!(
"literal_set match count {count} overflows host match-output byte sizing. Fix: lower max_matches or split the scan before dispatch."
))
})
}
fn literal_set_match_output_layout(max_matches: u32) -> Result<(u32, usize), vyre::BackendError> {
let words = max_matches.checked_mul(MATCH_TRIPLE_WORDS).ok_or_else(|| {
vyre::BackendError::new(format!(
"literal_set max_matches={max_matches} overflows the GPU match-output word count. Fix: lower max_matches or split the scan before dispatch."
))
})?;
let byte_len = literal_set_match_triple_bytes(max_matches)?;
Ok((words.max(1), byte_len))
}
#[cfg(test)]
mod compile_tests {
use super::*;
#[derive(Clone)]
struct LiteralReadbackBackend {
outputs: Vec<Vec<u8>>,
}
impl vyre::backend::private::Sealed for LiteralReadbackBackend {}
impl VyreBackend for LiteralReadbackBackend {
fn id(&self) -> &'static str {
"literal-readback-test"
}
fn dispatch(
&self,
_program: &Program,
_inputs: &[Vec<u8>],
_config: &vyre::DispatchConfig,
) -> Result<Vec<Vec<u8>>, vyre::BackendError> {
Ok(self.outputs.clone())
}
fn dispatch_borrowed(
&self,
_program: &Program,
_inputs: &[&[u8]],
_config: &vyre::DispatchConfig,
) -> Result<Vec<Vec<u8>>, vyre::BackendError> {
Ok(self.outputs.clone())
}
}
#[derive(Clone)]
struct RecordingLiteralBackend {
outputs: Vec<Vec<u8>>,
observed_matches_layouts:
std::sync::Arc<std::sync::Mutex<Vec<(u32, Option<std::ops::Range<usize>>)>>>,
observed_program_buffer_ptrs: std::sync::Arc<std::sync::Mutex<Vec<usize>>>,
observed_input_lengths: std::sync::Arc<std::sync::Mutex<Vec<Vec<usize>>>>,
}
impl RecordingLiteralBackend {
fn new(outputs: Vec<Vec<u8>>) -> Self {
Self {
outputs,
observed_matches_layouts: std::sync::Arc::default(),
observed_program_buffer_ptrs: std::sync::Arc::default(),
observed_input_lengths: std::sync::Arc::default(),
}
}
fn observed_matches_layouts(&self) -> Vec<(u32, Option<std::ops::Range<usize>>)> {
self.observed_matches_layouts
.lock()
.expect("Fix: recording literal backend mutex should not be poisoned")
.clone()
}
fn observed_program_buffer_ptrs(&self) -> Vec<usize> {
self.observed_program_buffer_ptrs
.lock()
.expect("Fix: recording literal backend mutex should not be poisoned")
.clone()
}
fn observed_input_lengths(&self) -> Vec<Vec<usize>> {
self.observed_input_lengths
.lock()
.expect("Fix: recording literal backend mutex should not be poisoned")
.clone()
}
}
impl vyre::backend::private::Sealed for RecordingLiteralBackend {}
impl VyreBackend for RecordingLiteralBackend {
fn id(&self) -> &'static str {
"literal-recording-test"
}
fn dispatch(
&self,
program: &Program,
inputs: &[Vec<u8>],
config: &vyre::DispatchConfig,
) -> Result<Vec<Vec<u8>>, vyre::BackendError> {
let borrowed = inputs.iter().map(Vec::as_slice).collect::<Vec<_>>();
self.dispatch_borrowed(program, &borrowed, config)
}
fn dispatch_borrowed(
&self,
program: &Program,
inputs: &[&[u8]],
_config: &vyre::DispatchConfig,
) -> Result<Vec<Vec<u8>>, vyre::BackendError> {
let matches = program
.buffers()
.iter()
.find(|buffer| buffer.name() == "matches")
.ok_or_else(|| vyre::BackendError::new("test program omitted matches buffer"))?;
self.observed_matches_layouts
.lock()
.map_err(|_| vyre::BackendError::new("test observation mutex poisoned"))?
.push((matches.count, matches.output_byte_range()));
self.observed_program_buffer_ptrs
.lock()
.map_err(|_| vyre::BackendError::new("test observation mutex poisoned"))?
.push(program.buffers().as_ptr() as usize);
self.observed_input_lengths
.lock()
.map_err(|_| vyre::BackendError::new("test observation mutex poisoned"))?
.push(inputs.iter().map(|input| input.len()).collect());
Ok(self.outputs.clone())
}
}
#[derive(Clone)]
struct RecordingCountBackend {
outputs: Vec<Vec<u8>>,
observed_input_lengths: std::sync::Arc<std::sync::Mutex<Vec<Vec<usize>>>>,
observed_buffer_names: std::sync::Arc<std::sync::Mutex<Vec<Vec<String>>>>,
}
impl RecordingCountBackend {
fn new(outputs: Vec<Vec<u8>>) -> Self {
Self {
outputs,
observed_input_lengths: std::sync::Arc::default(),
observed_buffer_names: std::sync::Arc::default(),
}
}
fn observed_input_lengths(&self) -> Vec<Vec<usize>> {
self.observed_input_lengths
.lock()
.expect("Fix: recording count backend mutex should not be poisoned")
.clone()
}
fn observed_buffer_names(&self) -> Vec<Vec<String>> {
self.observed_buffer_names
.lock()
.expect("Fix: recording count backend mutex should not be poisoned")
.clone()
}
}
impl vyre::backend::private::Sealed for RecordingCountBackend {}
impl VyreBackend for RecordingCountBackend {
fn id(&self) -> &'static str {
"literal-count-recording-test"
}
fn dispatch(
&self,
program: &Program,
inputs: &[Vec<u8>],
config: &vyre::DispatchConfig,
) -> Result<Vec<Vec<u8>>, vyre::BackendError> {
let borrowed = inputs.iter().map(Vec::as_slice).collect::<Vec<_>>();
self.dispatch_borrowed(program, &borrowed, config)
}
fn dispatch_borrowed(
&self,
program: &Program,
inputs: &[&[u8]],
_config: &vyre::DispatchConfig,
) -> Result<Vec<Vec<u8>>, vyre::BackendError> {
self.observed_input_lengths
.lock()
.map_err(|_| vyre::BackendError::new("test observation mutex poisoned"))?
.push(inputs.iter().map(|input| input.len()).collect());
self.observed_buffer_names
.lock()
.map_err(|_| vyre::BackendError::new("test observation mutex poisoned"))?
.push(
program
.buffers()
.iter()
.map(|buffer| buffer.name().to_string())
.collect(),
);
Ok(self.outputs.clone())
}
}
fn match_count_bytes(count: u32) -> Vec<u8> {
count.to_le_bytes().to_vec()
}
fn match_triple_bytes(pattern_id: u32, start: u32, end: u32) -> Vec<u8> {
let mut bytes = Vec::with_capacity(12);
bytes.extend_from_slice(&pattern_id.to_le_bytes());
bytes.extend_from_slice(&start.to_le_bytes());
bytes.extend_from_slice(&end.to_le_bytes());
bytes
}
fn decode_u32_words(bytes: &[u8]) -> Vec<u32> {
bytes
.chunks_exact(4)
.map(|chunk| u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
.collect()
}
fn decode_reference_matches(outputs: &[vyre_reference::value::Value]) -> Vec<Match> {
let count = decode_u32_words(&outputs[0].to_bytes())[0] as usize;
decode_u32_words(&outputs[1].to_bytes())
.into_iter()
.take(count.saturating_mul(3))
.collect::<Vec<_>>()
.chunks_exact(3)
.map(|chunk| Match::new(chunk[0], chunk[1], chunk[2]))
.collect()
}
#[test]
fn try_compile_packs_offsets_lengths_and_bytes_without_truncation() {
let compiled = GpuLiteralSet::try_compile(&[b"ab".as_slice(), b"cde".as_slice()])
.expect("Fix: small literal set must compile");
assert_eq!(compiled.pattern_offsets, vec![0, 2]);
assert_eq!(compiled.pattern_lengths, vec![2, 3]);
assert_eq!(
compiled.pattern_bytes,
vec![
b'a' as u32,
b'b' as u32,
b'c' as u32,
b'd' as u32,
b'e' as u32
]
);
}
#[test]
fn compile_empty_patterns_matches_fallible_compile_contract() {
let compat = GpuLiteralSet::compile(&[]);
let fallible =
GpuLiteralSet::try_compile(&[]).expect("Fix: empty literal set must compile");
assert_eq!(compat.pattern_offsets, fallible.pattern_offsets);
assert_eq!(compat.pattern_lengths, fallible.pattern_lengths);
assert_eq!(compat.pattern_bytes, fallible.pattern_bytes);
}
#[test]
fn literal_prefilter_masks_are_derived_from_literal_suffixes() {
let patterns: [&[u8]; 3] = [b"a", b"bc", b"token"];
let end_mask = literal_set_candidate_end_byte_mask_words(&patterns);
let suffix2_mask = literal_set_candidate_suffix2_mask_words(&patterns);
let end_contains = |byte: u8| {
let byte = usize::from(byte);
(end_mask[byte / 32] & (1_u32 << (byte % 32))) != 0
};
let suffix2_contains = |previous: u8, current: u8| {
let suffix = (usize::from(previous) << 8) | usize::from(current);
(suffix2_mask[suffix / 32] & (1_u32 << (suffix % 32))) != 0
};
assert!(end_contains(b'a'));
assert!(end_contains(b'c'));
assert!(end_contains(b'n'));
assert!(!end_contains(b'z'));
assert!(suffix2_contains(0, b'a'));
assert!(suffix2_contains(u8::MAX, b'a'));
assert!(suffix2_contains(b'b', b'c'));
assert!(suffix2_contains(b'e', b'n'));
assert!(!suffix2_contains(b'x', b'n'));
}
#[test]
fn prepare_literal_scratch_populates_reusable_program_and_prefilter_tables() {
let engine =
GpuLiteralSet::try_compile(&[b"a".as_slice(), b"bc".as_slice(), b"token".as_slice()])
.expect("Fix: small literal set must compile");
let mut scratch = LiteralSetScanScratch::default();
engine
.prepare_literal_scratch(3, &mut scratch)
.expect("Fix: literal hot-loop scratch preparation should build derived state");
assert!(
scratch.cached_program.is_some(),
"Fix: non-default match cap should prepare a reusable rewritten Program."
);
let prefilter = scratch
.cached_prefilter
.as_ref()
.expect("Fix: scratch preparation should cache suffix-prefilter tables.");
assert_ne!(
prefilter.candidate_end_mask, [0; 8],
"Fix: suffix-prefilter preparation must materialize candidate-end bits."
);
assert!(
prefilter
.candidate_suffix2_mask
.iter()
.any(|&word| word != 0),
"Fix: suffix-prefilter preparation must materialize suffix2 candidate bits."
);
assert!(
prefilter
.candidate_suffix3_bloom
.iter()
.any(|&word| word != 0),
"Fix: suffix-prefilter preparation must materialize suffix3 candidate bits."
);
assert!(scratch.cached_count_program.is_none());
}
#[test]
fn prepare_count_scratch_populates_count_program_and_prefilter_tables() {
let engine =
GpuLiteralSet::try_compile(&[b"a".as_slice(), b"bc".as_slice(), b"token".as_slice()])
.expect("Fix: small literal set must compile");
let mut scratch = LiteralSetScanScratch::default();
engine
.prepare_count_scratch(&mut scratch)
.expect("Fix: literal count scratch preparation should build derived state");
assert!(
scratch.cached_count_program.is_some(),
"Fix: count hot-loop scratch should prepare the count-only program."
);
assert!(
scratch.cached_prefilter.is_some(),
"Fix: count hot-loop scratch should prepare suffix-prefilter tables."
);
assert!(
scratch.cached_program.is_none(),
"Fix: count scratch preparation should not build match-list output programs."
);
}
#[test]
fn prepare_scan_dispatch_matches_borrowed_input_layout() {
let engine =
GpuLiteralSet::try_compile(&[b"a".as_slice(), b"bc".as_slice(), b"token".as_slice()])
.expect("Fix: small literal set must compile");
let plan = engine
.prepare_scan_dispatch(b"xx token bc a", 3)
.expect("Fix: prepared literal scan dispatch should own input buffers");
let backend = RecordingLiteralBackend::new(vec![match_count_bytes(0), Vec::new()]);
let mut matches = Vec::new();
engine
.scan_into(&backend, b"xx token bc a", 3, &mut matches)
.expect("Fix: recording backend should accept literal scan");
assert_eq!(plan.inputs.len(), LITERAL_SET_INPUT_COUNT);
assert_eq!(
backend.observed_input_lengths()[0],
plan.inputs.iter().map(Vec::len).collect::<Vec<_>>(),
"Fix: prepared dispatch buffers must stay in the same ABI order as direct scan dispatch."
);
assert_eq!(
plan.dispatch_config.grid_override,
Some([1, 1, 1]),
"Fix: prepared dispatch must preserve byte-scan grid geometry."
);
assert_eq!(plan.match_count_readback_bytes(), U32_COUNTER_BYTES);
assert_eq!(
plan.match_triples_readback_bytes(u32::MAX)
.expect("Fix: clamped readback sizing should not overflow"),
plan.matches_output_bytes
);
assert_eq!(
plan.encoded_input_bytes,
plan.inputs
.iter()
.map(|input| input.len() as u64)
.sum::<u64>()
);
}
#[test]
fn prepared_scan_decodes_resident_style_readback() {
let engine = GpuLiteralSet::try_compile(&[b"a".as_slice(), b"bc".as_slice()])
.expect("Fix: small literal set must compile");
let plan = engine
.prepare_scan_dispatch(b"abc", 2)
.expect("Fix: prepared literal scan dispatch should build");
let outputs = vec![
match_count_bytes(2),
[match_triple_bytes(0, 0, 1), match_triple_bytes(1, 1, 3)].concat(),
];
let mut matches = Vec::new();
plan.decode_outputs_into(&outputs, &mut matches)
.expect("Fix: prepared scan decoder should read count plus match triples");
assert_eq!(
matches,
vec![Match::new(0, 0, 1), Match::new(1, 1, 3)],
"Fix: prepared dispatch decode must match public GpuLiteralSet scan semantics."
);
}
#[test]
fn literal_count_uses_count_only_program_and_readback() {
let engine = GpuLiteralSet::try_compile(&[b"a".as_slice(), b"bc".as_slice()])
.expect("Fix: small literal set must compile");
let backend = RecordingCountBackend::new(vec![match_count_bytes(3)]);
let mut scratch = LiteralSetScanScratch::default();
let count = engine
.count_with_literal_scratch(&backend, b"abcabc", &mut scratch)
.expect("Fix: literal count dispatch should decode one count output");
assert_eq!(count, 3);
assert_eq!(
backend.observed_input_lengths()[0].len(),
LITERAL_SET_COUNT_INPUT_COUNT,
"Fix: count-only dispatch must not upload output_records or pattern lengths."
);
assert_eq!(
backend.observed_buffer_names()[0],
vec![
"haystack",
"transitions",
"output_offsets",
"candidate_end_mask",
"candidate_suffix2_mask",
"candidate_suffix3_bloom",
"haystack_len",
"match_count"
],
"Fix: literal count must dispatch the suffix3 count program ABI."
);
assert!(
scratch.cached_count_program.is_some(),
"Fix: count hot loops should reuse the count program."
);
}
#[test]
fn prepare_count_dispatch_matches_count_input_layout() {
let engine = GpuLiteralSet::try_compile(&[b"a".as_slice(), b"bc".as_slice()])
.expect("Fix: small literal set must compile");
let plan = engine
.prepare_count_dispatch(b"abcabc")
.expect("Fix: prepared literal count dispatch should own input buffers");
let backend = RecordingCountBackend::new(vec![match_count_bytes(3)]);
let count = engine
.count(&backend, b"abcabc")
.expect("Fix: recording backend should accept literal count");
assert_eq!(count, 3);
assert_eq!(plan.inputs.len(), LITERAL_SET_COUNT_INPUT_COUNT);
assert_eq!(
backend.observed_input_lengths()[0],
plan.inputs.iter().map(Vec::len).collect::<Vec<_>>(),
"Fix: prepared count buffers must stay in the same ABI order as direct count dispatch."
);
assert_eq!(plan.dispatch_config.grid_override, Some([1, 1, 1]));
assert_eq!(plan.count_readback_bytes(), U32_COUNTER_BYTES);
assert_eq!(
plan.decode_outputs(&[match_count_bytes(3)])
.expect("Fix: prepared count decoder should read one u32"),
3
);
assert_eq!(
plan.encoded_input_bytes,
plan.inputs
.iter()
.map(|input| input.len() as u64)
.sum::<u64>()
);
}
#[test]
fn reserve_vec_reports_compile_storage_failure() {
let mut scratch = Vec::<u8>::new();
let error = reserve_vec(&mut scratch, usize::MAX, "adversarial scratch")
.expect_err("Fix: usize::MAX reserve must fail instead of silently truncating");
match error {
LiteralSetCompileError::StorageReserveFailed {
field, requested, ..
} => {
assert_eq!(field, "adversarial scratch");
assert_eq!(requested, usize::MAX);
}
other => panic!("expected storage reserve failure, got {other:?}"),
}
assert!(scratch.is_empty());
}
#[test]
fn literal_scan_rejects_short_match_count_readback() {
let engine = GpuLiteralSet::compile(&[b"a".as_slice()]);
let backend = LiteralReadbackBackend {
outputs: vec![vec![1, 2, 3], Vec::new()],
};
let mut matches = vec![Match::new(99, 1, 2)];
let err = engine
.scan_into(&backend, b"a", 1, &mut matches)
.expect_err("short literal match-count readback must fail");
let msg = err.to_string();
assert!(
matches.is_empty(),
"scan errors must not expose stale matches"
);
assert!(
msg.contains("literal_set match count") && msg.contains("requires 4 bytes"),
"literal scan counter error must name the malformed output: {msg}"
);
}
#[test]
fn literal_scan_rejects_missing_match_output_slot() {
let engine = GpuLiteralSet::compile(&[b"a".as_slice()]);
let backend = LiteralReadbackBackend {
outputs: vec![match_count_bytes(1)],
};
let mut matches = Vec::new();
let err = engine
.scan_into(&backend, b"a", 1, &mut matches)
.expect_err("missing literal match output must fail");
let msg = err.to_string();
assert!(
msg.contains("literal_set matches") && msg.contains("output index 1"),
"literal scan missing-output error must identify the omitted slot: {msg}"
);
}
#[test]
fn literal_scan_rejects_match_payload_shorter_than_reported_count() {
let engine = GpuLiteralSet::compile(&[b"a".as_slice()]);
let backend = LiteralReadbackBackend {
outputs: vec![match_count_bytes(2), match_triple_bytes(0, 0, 1)],
};
let mut matches = vec![Match::new(99, 1, 2)];
let err = engine
.scan_into(&backend, b"a", 2, &mut matches)
.expect_err("short literal match payload must fail");
let msg = err.to_string();
assert!(
matches.is_empty(),
"scan errors must not expose stale matches"
);
assert!(
msg.contains("readback was 12 byte(s)")
&& msg.contains("count=2")
&& msg.contains("requires 24 byte(s)"),
"literal scan match-payload error must identify observed and required bytes: {msg}"
);
}
#[test]
fn literal_scan_exposes_scratch_backed_dispatch_staging() {
let production = include_str!("literal_set.rs")
.split("#[cfg(test)]")
.next()
.expect("Fix: literal_set.rs must contain production section");
assert!(
production.contains("pub fn scan_into_with_scratch")
&& production.contains("ScanDispatchScratch")
&& production.contains("LiteralSetScanScratch")
&& production.contains("pack_haystack_u32_into")
&& !production.contains(concat!("pack_haystack_u32", "(haystack)")),
"Fix: literal scan hot path must expose reusable dispatch scratch and avoid fresh haystack packing allocations."
);
assert!(
!production.contains(".expect(") && !production.contains(".unwrap("),
"Fix: literal_set production wrappers must not panic."
);
let program_debug = format!("{:#?}", GpuLiteralSet::compile(&[b"a".as_slice()]).program);
assert!(
!program_debug.contains("_vyre_match_leader"),
"Fix: literal-set GPU program must use the CUDA-lowerable append primitive, not subgroup leader append."
);
let engine = GpuLiteralSet::compile(&[b"a".as_slice(), b"bc".as_slice()]);
let buffer_names = engine
.program
.buffers()
.iter()
.map(|buffer| buffer.name())
.collect::<Vec<_>>();
assert_eq!(
buffer_names,
vec![
"haystack",
"transitions",
"output_offsets",
"output_records",
"pattern_lengths",
"haystack_len",
"match_count",
"candidate_end_mask",
"candidate_suffix2_mask",
"candidate_suffix3_bloom",
"matches"
],
"Fix: public literal-set dispatch must run on the suffix-prefiltered bounded DFA table layout, not the old literal-byte compare ABI."
);
assert!(
!program_debug.contains("pattern_bytes")
&& !program_debug.contains("pattern_offsets")
&& !program_debug.contains("_pid")
&& !program_debug.contains("_literal_matched"),
"Fix: literal-set GPU program must not retain the per-pattern literal compare loop."
);
}
#[test]
fn literal_scan_sizes_match_output_to_requested_cap() {
let engine = GpuLiteralSet::compile(&[b"a".as_slice()]);
let mut payload = match_triple_bytes(0, 0, 1);
payload.extend_from_slice(&match_triple_bytes(0, 3, 4));
let backend = RecordingLiteralBackend::new(vec![match_count_bytes(2), payload]);
let mut matches = Vec::new();
engine
.scan_into(&backend, b"a--a", 2, &mut matches)
.expect("Fix: literal scan with two-match cap should dispatch");
assert_eq!(matches, vec![Match::new(0, 0, 1), Match::new(0, 3, 4)]);
assert_eq!(backend.observed_matches_layouts(), vec![(6, Some(0..24))]);
}
#[test]
fn literal_scan_uploads_dfa_tables_instead_of_literal_compare_tables() {
let engine = GpuLiteralSet::compile(&[
b"AKIA".as_slice(),
b"ghp_".as_slice(),
b"Authorization: Bearer ".as_slice(),
]);
let backend = RecordingLiteralBackend::new(vec![match_count_bytes(0), Vec::new()]);
let mut matches = Vec::new();
engine
.scan_into(
&backend,
b"prefix Authorization: Bearer token",
4,
&mut matches,
)
.expect("Fix: literal scan should dispatch with DFA table inputs");
assert!(matches.is_empty());
let packed_haystack_len =
crate::scan::dispatch_io::pack_haystack_u32(b"prefix Authorization: Bearer token")
.len();
let prefilter = engine
.build_prefilter_tables()
.expect("Fix: small literal-set prefilter tables should build");
assert_eq!(
backend.observed_input_lengths(),
vec![vec![
packed_haystack_len,
engine.dfa.transitions.len() * U32_BYTES,
engine.dfa.output_offsets.len() * U32_BYTES,
engine.dfa.output_records.len() * U32_BYTES,
engine.pattern_lengths.len() * U32_BYTES,
U32_BYTES,
U32_BYTES,
prefilter.candidate_end_mask.len() * U32_BYTES,
prefilter.candidate_suffix2_mask.len() * U32_BYTES,
prefilter.candidate_suffix3_bloom.len() * U32_BYTES,
]],
"Fix: public literal-set scan must upload haystack, DFA tables, suffix-prefilter masks, haystack_len, and match_count."
);
}
#[test]
fn literal_set_dfa_program_reference_eval_matches_public_oracle() {
let patterns: [&[u8]; 5] = [b"a", b"bc", b"abcd", b"BEGIN", b"token"];
let haystack = b"zabcd BEGIN token abcdbc";
let engine = GpuLiteralSet::compile(&patterns);
let prefilter = engine
.build_prefilter_tables()
.expect("Fix: small literal-set prefilter tables should build");
let inputs = vec![
vyre_reference::value::Value::from(crate::scan::dispatch_io::pack_haystack_u32(
haystack,
)),
vyre_reference::value::Value::from(crate::scan::dispatch_io::pack_u32_slice(
&engine.dfa.transitions,
)),
vyre_reference::value::Value::from(crate::scan::dispatch_io::pack_u32_slice(
&engine.dfa.output_offsets,
)),
vyre_reference::value::Value::from(crate::scan::dispatch_io::pack_u32_slice(
&engine.dfa.output_records,
)),
vyre_reference::value::Value::from(crate::scan::dispatch_io::pack_u32_slice(
&engine.pattern_lengths,
)),
vyre_reference::value::Value::from(crate::scan::dispatch_io::pack_u32_slice(&[
haystack.len() as u32,
])),
vyre_reference::value::Value::from(crate::scan::dispatch_io::pack_u32_slice(&[0])),
vyre_reference::value::Value::from(crate::scan::dispatch_io::pack_u32_slice(
&prefilter.candidate_end_mask,
)),
vyre_reference::value::Value::from(crate::scan::dispatch_io::pack_u32_slice(
&prefilter.candidate_suffix2_mask,
)),
vyre_reference::value::Value::from(crate::scan::dispatch_io::pack_u32_slice(
&prefilter.candidate_suffix3_bloom,
)),
];
let outputs = vyre_reference::reference_eval(&engine.program, &inputs).expect(
"Fix: public literal-set suffix-prefiltered bounded-DFA program should evaluate in reference backend.",
);
let mut actual = decode_reference_matches(&outputs);
let mut expected = engine.reference_scan(haystack);
actual.sort_unstable();
expected.sort_unstable();
assert_eq!(actual, expected);
}
#[test]
fn literal_scan_default_cap_uses_compiled_output_layout() {
let engine = GpuLiteralSet::compile(&[b"a".as_slice()]);
let backend = RecordingLiteralBackend::new(vec![match_count_bytes(0), Vec::new()]);
let mut matches = Vec::new();
engine
.scan_into(
&backend,
b"no hits",
LITERAL_SET_DEFAULT_MAX_MATCHES,
&mut matches,
)
.expect("Fix: default literal scan cap should use the compiled program layout");
assert!(matches.is_empty());
assert_eq!(backend.observed_matches_layouts(), vec![(30_000, None)]);
}
#[test]
fn literal_scan_zero_match_cap_reads_no_match_payload() {
let engine = GpuLiteralSet::compile(&[b"a".as_slice()]);
let backend = RecordingLiteralBackend::new(vec![match_count_bytes(1), Vec::new()]);
let mut matches = vec![Match::new(99, 1, 2)];
engine
.scan_into(&backend, b"a", 0, &mut matches)
.expect("Fix: literal scan with zero cap should return an empty decoded prefix");
assert!(matches.is_empty());
assert_eq!(backend.observed_matches_layouts(), vec![(1, Some(0..0))]);
}
#[test]
fn literal_scan_expands_match_output_above_legacy_fixed_cap() {
let engine = GpuLiteralSet::compile(&[b"a".as_slice()]);
let backend = RecordingLiteralBackend::new(vec![match_count_bytes(0), Vec::new()]);
let mut matches = Vec::new();
engine
.scan_into(&backend, b"no hits", 20_001, &mut matches)
.expect("Fix: literal scan should honor caps above the compiled default");
assert!(matches.is_empty());
assert_eq!(
backend.observed_matches_layouts(),
vec![(60_003, Some(0..240_012))]
);
}
#[test]
fn literal_scan_literal_scratch_reuses_rewritten_program_for_same_cap() {
let engine = GpuLiteralSet::compile(&[b"a".as_slice()]);
let backend = RecordingLiteralBackend::new(vec![match_count_bytes(0), Vec::new()]);
let mut matches = Vec::new();
let mut scratch = LiteralSetScanScratch::default();
engine
.scan_into_with_literal_scratch(&backend, b"first", 2, &mut matches, &mut scratch)
.expect("Fix: first cap-specific literal scan should dispatch");
engine
.scan_into_with_literal_scratch(&backend, b"second", 2, &mut matches, &mut scratch)
.expect("Fix: repeated cap-specific literal scan should dispatch");
assert_eq!(
backend.observed_matches_layouts(),
vec![(6, Some(0..24)), (6, Some(0..24))]
);
let ptrs = backend.observed_program_buffer_ptrs();
assert_eq!(ptrs.len(), 2);
assert_eq!(
ptrs[0], ptrs[1],
"Fix: literal-set scan scratch must reuse the rewritten Program for stable caps"
);
}
#[test]
fn literal_scan_literal_scratch_rebuilds_rewritten_program_when_cap_changes() {
let engine = GpuLiteralSet::compile(&[b"a".as_slice()]);
let backend = RecordingLiteralBackend::new(vec![match_count_bytes(0), Vec::new()]);
let mut matches = Vec::new();
let mut scratch = LiteralSetScanScratch::default();
engine
.scan_into_with_literal_scratch(&backend, b"first", 2, &mut matches, &mut scratch)
.expect("Fix: first cap-specific literal scan should dispatch");
engine
.scan_into_with_literal_scratch(&backend, b"second", 3, &mut matches, &mut scratch)
.expect("Fix: changed cap-specific literal scan should dispatch");
assert_eq!(
backend.observed_matches_layouts(),
vec![(6, Some(0..24)), (9, Some(0..36))]
);
let ptrs = backend.observed_program_buffer_ptrs();
assert_eq!(ptrs.len(), 2);
assert_ne!(
ptrs[0], ptrs[1],
"Fix: literal-set scan scratch must rebuild cached Program when cap changes"
);
}
#[test]
fn literal_scan_rejects_match_cap_that_overflows_output_words() {
let engine = GpuLiteralSet::compile(&[b"a".as_slice()]);
let backend = RecordingLiteralBackend::new(vec![match_count_bytes(0), Vec::new()]);
let mut matches = Vec::new();
let err = engine
.scan_into(&backend, b"a", u32::MAX, &mut matches)
.expect_err("Fix: overflowing literal max_matches must fail before dispatch");
let msg = err.to_string();
assert!(msg.contains("literal_set max_matches"));
assert!(msg.contains("overflows the GPU match-output word count"));
assert!(backend.observed_matches_layouts().is_empty());
}
}
const LITERAL_SET_WIRE_MAGIC: &[u8; 4] = b"VLIT";
const LITERAL_SET_WIRE_VERSION: u32 = 3;
const LITERAL_SET_LEGACY_BOUNDED_DFA_WIRE_VERSION: u32 = 2;
const LITERAL_SET_LEGACY_LITERAL_COMPARE_WIRE_VERSION: u32 = 1;
#[derive(Debug)]
#[non_exhaustive]
pub enum LiteralSetWireError {
WireFraming(vyre_foundation::serial::envelope::EnvelopeError),
InvalidProgram(String),
InvalidDfa(DfaWireError),
}
impl std::fmt::Display for LiteralSetWireError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::WireFraming(e) => write!(f, "GpuLiteralSet wire envelope: {e}"),
Self::InvalidProgram(msg) => {
write!(f, "GpuLiteralSet wire blob has invalid Program: {msg}")
}
Self::InvalidDfa(e) => {
write!(f, "GpuLiteralSet wire blob has invalid DFA: {e}")
}
}
}
}
impl std::error::Error for LiteralSetWireError {}
fn try_build_literal_set_program(dfa: &CompiledDfa, pattern_count: u32) -> Result<Program, String> {
try_build_ac_bounded_ranges_suffix3_prefilter_program_ext(
dfa,
pattern_count,
LITERAL_SET_DEFAULT_MAX_MATCHES,
false,
)
}
fn build_literal_set_program(dfa: &CompiledDfa, pattern_count: u32) -> Program {
build_ac_bounded_ranges_suffix3_prefilter_program_ext(
dfa,
pattern_count,
LITERAL_SET_DEFAULT_MAX_MATCHES,
false,
)
}
pub fn dfa_to_jit_ir(dfa: &CompiledDfa, state_var: &str, byte_expr: Expr) -> Node {
build_state_cascade(dfa, 0, state_var, byte_expr)
}
fn build_state_cascade(dfa: &CompiledDfa, state: u32, state_var: &str, byte_expr: Expr) -> Node {
let mut arms = Vec::new();
for byte in 0..=255 {
let next_state = dfa.transitions[(state as usize) * 256 + byte];
if next_state != 0 {
arms.push((byte as u32, next_state));
}
}
if arms.is_empty() {
return Node::Assign {
name: state_var.into(),
value: Expr::u32(0),
};
}
let mut node = Node::Assign {
name: state_var.into(),
value: Expr::u32(0),
};
for (byte, next) in arms.into_iter().rev() {
node = Node::If {
cond: Expr::eq(byte_expr.clone(), Expr::u32(byte)),
then: vec![Node::Assign {
name: state_var.into(),
value: Expr::u32(next),
}],
otherwise: vec![node],
};
}
node
}