mod buffers;
mod input;
mod readback;
mod resources;
mod transitions;
use self::buffers::{entry, storage_buffer};
use self::input::{checked_input_len, dispatch_groups, DfaParams};
use self::readback::{read_matches, read_one_u32};
use self::resources::ScanResources;
use self::transitions::{build_accept_map, compile_pipeline, validate_compile_inputs};
use std::sync::Mutex;
use vyre::error::{Error, Result};
/// Byte-class count (256).
// NOTE(review): not referenced in this file; presumably the number of distinct
// input byte values each transition row is indexed by — confirm in `transitions`.
pub(crate) const BYTE_CLASSES: usize = 256;
/// Sentinel equal to `u32::MAX`.
// NOTE(review): not referenced in this file; the name suggests it marks a state
// with no accepting pattern in the accept map — confirm in `transitions`.
pub(crate) const SENTINEL_NO_ACCEPT: u32 = 0xFFFF_FFFF;
/// Match capacity that [`GpuDfa::compile`] passes to
/// [`GpuDfa::compile_with_max_matches`].
pub const DEFAULT_MAX_MATCHES: u32 = 65_536;
/// Match-count limit constant.
// NOTE(review): not referenced in this file; presumably the largest `max_matches`
// accepted by `validate_compile_inputs` — confirm there.
pub const MAX_DFA_MATCHES: u32 = 1_000_000;
/// A DFA compiled for GPU scanning: the compute pipeline, the device-side
/// tables built at compile time, and a pool of per-scan scratch resources.
///
/// Built by [`GpuDfa::compile`] or [`GpuDfa::compile_with_max_matches`].
#[non_exhaustive]
pub struct GpuDfa {
// Clone of the device passed at compile time; `scan` rejects any other device.
device: wgpu::Device,
// Result of `crate::runtime::device::is_cached_device` for the compile device;
// `scan_shared` refuses to run when this is false.
compiled_with_cached_device: bool,
// Compute pipeline produced by `compile_pipeline`.
pipeline: wgpu::ComputePipeline,
// Storage buffer built from the `transitions` slice (bind group entry 1).
transition_buffer: wgpu::Buffer,
// Storage buffer built from the accept map returned by `build_accept_map`
// (bind group entry 2).
accept_buffer: wgpu::Buffer,
// Storage buffer built from `pattern_lengths` (bind group entry 6).
pattern_length_buffer: wgpu::Buffer,
// Host-side copy of the pattern lengths supplied at compile time.
pattern_lengths: Vec<u32>,
// State count supplied at compile time, converted to `u32`.
state_count: u32,
// Upper bound on the number of matches read back by a single scan.
max_matches: u32,
// Pool of scan resources; `scan` takes an entry whose capacity covers the
// input (or creates one) and pushes it back when the scan finishes.
resources: Mutex<Vec<ScanResources>>,
}
/// Debug output lists only host-side metadata (state count, pattern count and
/// match capacity); the remaining fields are elided via `finish_non_exhaustive`.
impl std::fmt::Debug for GpuDfa {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let pattern_count = self.pattern_lengths.len();
        let mut builder = formatter.debug_struct("GpuDfa");
        builder.field("state_count", &self.state_count);
        builder.field("pattern_count", &pattern_count);
        builder.field("max_matches", &self.max_matches);
        builder.finish_non_exhaustive()
    }
}
fn zero_dfa_input_padding(
queue: &wgpu::Queue,
buffer: &wgpu::Buffer,
written: usize,
) -> Result<()> {
let padding_len = written.next_multiple_of(4) - written;
if padding_len == 0 {
return Ok(());
}
let offset = u64::try_from(written).map_err(|source| Error::Dfa {
message: format!(
"DFA input byte count {written} cannot fit u64: {source}. Fix: scan chunks smaller than 4 GiB."
),
})?;
let padding = [0u8; 4];
queue.write_buffer(buffer, offset, &padding[..padding_len]);
Ok(())
}
impl GpuDfa {
/// Compiles a GPU DFA using the default match capacity.
///
/// Forwards every argument unchanged to [`GpuDfa::compile_with_max_matches`]
/// with `max_matches` set to [`DEFAULT_MAX_MATCHES`].
///
/// # Errors
///
/// Propagates any error returned by [`GpuDfa::compile_with_max_matches`].
pub fn compile(
    device: &wgpu::Device,
    transitions: &[u32],
    state_count: usize,
    accept_states: &[(u32, u32)],
    output_links: &[u32],
    pattern_lengths: &[u32],
) -> Result<Self> {
    let max_matches = DEFAULT_MAX_MATCHES;
    Self::compile_with_max_matches(
        device,
        transitions,
        state_count,
        accept_states,
        output_links,
        pattern_lengths,
        max_matches,
    )
}
pub fn compile_with_max_matches(
device: &wgpu::Device,
transitions: &[u32],
state_count: usize,
accept_states: &[(u32, u32)],
output_links: &[u32],
pattern_lengths: &[u32],
max_matches: u32,
) -> Result<Self> {
validate_compile_inputs(
device,
transitions,
state_count,
accept_states,
output_links,
pattern_lengths,
max_matches,
)?;
let accept_map = build_accept_map(state_count, accept_states)?;
let pipeline = compile_pipeline(device)?;
let transition_buffer = storage_buffer(device, "vyre dfa transitions", transitions);
let accept_buffer = storage_buffer(device, "vyre dfa accept map", &accept_map);
let pattern_length_buffer =
storage_buffer(device, "vyre dfa pattern lengths", pattern_lengths);
Ok(Self {
device: device.clone(),
compiled_with_cached_device: crate::runtime::device::is_cached_device(device),
pipeline,
transition_buffer,
accept_buffer,
pattern_length_buffer,
pattern_lengths: pattern_lengths.to_vec(),
state_count: u32::try_from(state_count).map_err(|source| Error::Dfa {
message: format!(
"DFA state_count {state_count} cannot fit u32: {source}. Fix: split the automaton or reduce states."
),
})?,
max_matches,
resources: Mutex::new(Vec::new()),
})
}
/// Scans `input` on the GPU and returns the matches that were found.
///
/// `device` must compare equal to the device this DFA was compiled with.
/// Empty input returns an empty vector before any scan resources are acquired.
/// Otherwise a [`ScanResources`] entry is taken from the internal pool (or
/// created), used for the scan, and pushed back afterwards regardless of
/// whether the scan succeeded.
///
/// # Errors
///
/// Returns [`Error::Dfa`] when `device` differs from the compile device, and
/// propagates errors from input-length checking, resource acquisition, the
/// scan itself, and resource release. A release failure takes precedence over
/// the scan outcome.
pub fn scan(
    &self,
    device: &wgpu::Device,
    queue: &wgpu::Queue,
    input: &[u8],
    command_encoder: Option<&mut wgpu::CommandEncoder>,
) -> Result<Vec<vyre::Match>> {
    let same_device = *device == self.device;
    if !same_device {
        return Err(Error::Dfa {
            message: "DFA scan device differs from compile device. Fix: scan with the same wgpu::Device and matching Queue used to compile the DFA.".to_string(),
        });
    }
    if input.is_empty() {
        return Ok(Vec::new());
    }

    let input_len = checked_input_len(input)?;
    let mut scratch = self.acquire_scan_resources(input_len)?;
    let outcome =
        self.scan_with_resources(queue, input, input_len, &mut scratch, command_encoder);

    // Return the scratch resources first; only if that succeeds is the scan
    // outcome (success or failure) handed back to the caller.
    self.release_scan_resources(scratch).and(outcome)
}
fn acquire_scan_resources(&self, input_len: u32) -> Result<ScanResources> {
let mut pool = self.resources.lock().map_err(|source| Error::Dfa {
message: format!("DFA scan resources mutex is poisoned: {source}. Fix: recreate the compiled DFA and inspect panics from concurrent scan tasks."),
})?;
if let Some(index) = pool
.iter()
.position(|resources| resources.max_input_len >= input_len)
{
return Ok(pool.swap_remove(index));
}
drop(pool);
ScanResources::new(&self.device, input_len, self.max_matches)
}
fn release_scan_resources(&self, resources: ScanResources) -> Result<()> {
let mut pool = self.resources.lock().map_err(|source| Error::Dfa {
message: format!("DFA scan resources mutex is poisoned while releasing resources: {source}. Fix: recreate the compiled DFA and inspect panics from concurrent scan tasks."),
})?;
pool.push(resources);
Ok(())
}
/// Runs one DFA scan over `input` using caller-provided `resources` and reads
/// the matches back to the host.
///
/// Steps, in order:
/// 1. Upload `input` (plus zero padding up to a 4-byte boundary) and the scan
///    parameters through `queue`.
/// 2. Record a command buffer that clears the match counter, dispatches the
///    scan pipeline and copies the counter into its readback buffer; submit it
///    and read the reported match count.
/// 3. If at least one match was captured, copy that many matches into the match
///    readback buffer with a second submission and read them back.
///
/// The number of matches read back is the reported count clamped to this DFA's
/// `max_matches`. The returned matches are sorted with `sort_unstable`.
///
/// `command_encoder` must be `None`. This API returns readback results, which
/// cannot exist until the commands have been submitted, so an external encoder
/// is rejected *before* anything is written to the queue or recorded into the
/// caller's encoder. Recording first would leave commands that reference
/// `resources` in an encoder the caller may submit later.
///
/// # Errors
///
/// Returns [`Error::Dfa`] when `input_len` exceeds `resources.max_input_len`
/// or when `command_encoder` is `Some`, and propagates errors from padding,
/// match-buffer sizing and both readbacks.
//
// NOTE(review): `input` is uploaded with its exact length; the WebGPU
// `writeBuffer` rules require the size to be a multiple of four. Confirm the
// targeted wgpu version accepts unaligned lengths.
pub(crate) fn scan_with_resources(
    &self,
    queue: &wgpu::Queue,
    input: &[u8],
    input_len: u32,
    resources: &mut ScanResources,
    command_encoder: Option<&mut wgpu::CommandEncoder>,
) -> Result<Vec<vyre::Match>> {
    if input_len > resources.max_input_len {
        return Err(Error::Dfa {
            message: format!(
                "DFA input length {input_len} exceeds ScanResources capacity {}. Fix: create larger ScanResources.",
                resources.max_input_len
            ),
        });
    }
    // Fail fast: do not touch the queue or the caller's encoder for a request
    // that can never produce a result.
    if command_encoder.is_some() {
        return Err(Error::Dfa {
            message: "DFA scan was called with an external command encoder, but this API returns readback matches that are unavailable until the caller submits that encoder. Fix: call with `None` for immediate submit/readback, or add a deferred DFA API that returns readback buffers.".to_string(),
        });
    }

    let params = DfaParams {
        input_len,
        state_count: self.state_count,
        max_matches: self.max_matches,
        _pad: 0,
    };
    queue.write_buffer(&resources.input_buffer, 0, input);
    zero_dfa_input_padding(queue, &resources.input_buffer, input.len())?;
    queue.write_buffer(&resources.params_buffer, 0, bytemuck::bytes_of(&params));

    let bind_group = self.create_bind_group(
        &resources.input_buffer,
        &resources.match_buffer,
        &resources.match_count_buffer,
        &resources.params_buffer,
    );

    let mut encoder = self
        .device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("vyre dfa dispatch and readback"),
        });
    // Pooled resources may carry a counter from a previous scan.
    encoder.clear_buffer(&resources.match_count_buffer, 0, None);
    {
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("vyre dfa scan pass"),
            timestamp_writes: None,
        });
        pass.set_pipeline(&self.pipeline);
        pass.set_bind_group(0, &bind_group, &[]);
        pass.dispatch_workgroups(dispatch_groups(input_len), 1, 1);
    }
    // The counter is a single u32, hence the 4-byte copy.
    encoder.copy_buffer_to_buffer(
        &resources.match_count_buffer,
        0,
        &resources.count_readback,
        0,
        4,
    );

    let count_submission = queue.submit(std::iter::once(encoder.finish()));
    let reported = read_one_u32(
        &self.device,
        &resources.count_readback,
        "match count",
        count_submission,
    )?;
    // Never read back more entries than the match capacity.
    let captured = reported.min(self.max_matches);
    if captured == 0 {
        return Ok(Vec::new());
    }

    let mut readback_encoder =
        self.device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("vyre dfa match readback"),
            });
    readback_encoder.copy_buffer_to_buffer(
        &resources.match_buffer,
        0,
        &resources.match_readback,
        0,
        buffers::match_buffer_size(captured)?,
    );
    let match_submission = queue.submit(std::iter::once(readback_encoder.finish()));
    let mut matches = read_matches(
        &self.device,
        &resources.match_readback,
        captured,
        match_submission,
    )?;
    matches.sort_unstable();
    Ok(matches)
}
pub fn scan_shared(&self, input: &[u8]) -> Result<Vec<vyre::Match>> {
if !self.compiled_with_cached_device {
return Err(Error::Dfa {
message: "DFA was compiled with a non-shared GPU device. Fix: compile with vyre::runtime::cached_device() before calling scan_shared(), or call scan() with the original device and queue.".to_string(),
});
}
let (device, queue) = crate::runtime::cached_device()?;
self.scan(device, queue, input, None)
}
#[must_use]
pub fn state_count(&self) -> u32 {
self.state_count
}
#[must_use]
pub fn max_matches(&self) -> u32 {
self.max_matches
}
/// Returns the host-side copy of the pattern lengths supplied at compile time.
#[must_use]
pub fn pattern_lengths(&self) -> &[u32] {
    self.pattern_lengths.as_slice()
}
/// Builds the bind group for one scan from per-scan buffers and this DFA's
/// compiled tables.
///
/// The layout is taken from bind group 0 of the pipeline. Binding indices:
/// 0 input, 1 transitions, 2 accept map, 3 matches, 4 match count, 5 params,
/// 6 pattern lengths.
pub(crate) fn create_bind_group(
    &self,
    input_buffer: &wgpu::Buffer,
    match_buffer: &wgpu::Buffer,
    match_count_buffer: &wgpu::Buffer,
    params_buffer: &wgpu::Buffer,
) -> wgpu::BindGroup {
    let layout = self.pipeline.get_bind_group_layout(0);
    let entries = [
        entry(0, input_buffer),
        entry(1, &self.transition_buffer),
        entry(2, &self.accept_buffer),
        entry(3, match_buffer),
        entry(4, match_count_buffer),
        entry(5, params_buffer),
        entry(6, &self.pattern_length_buffer),
    ];
    let descriptor = wgpu::BindGroupDescriptor {
        label: Some("vyre dfa bind group"),
        layout: &layout,
        entries: &entries,
    };
    self.device.create_bind_group(&descriptor)
}
}