use std::collections::HashSet;
use std::sync::{LazyLock, Mutex};
use bytemuck::{cast_slice, Pod, Zeroable};
use wgpu::util::DeviceExt;
use crate::bytecode::{Instruction, Program};
use crate::error::{Error, Result};
use crate::shaders::{eval_shader, scatter_shader};
use crate::FileContext;
/// Host-side lookup tables that `execute_gpu` uploads to describe how
/// pattern matches relate to rules and their strings.
///
/// All slices are borrowed; the plan itself is `Copy`.
#[derive(Debug, Clone, Copy)]
pub struct GpuEvaluationPlan<'a> {
/// Number of strings per rule. `execute_gpu` requires exactly one entry per
/// program, and uses the maximum entry to size the per-rule GPU tables.
pub rule_string_counts: &'a [usize],
/// Uploaded verbatim to the scatter shader (binding 1).
/// NOTE(review): presumably `[start, len]` ranges into `rule_list`, indexed by
/// pattern id — confirm against the scatter shader.
pub pattern_to_rules: &'a [[u32; 2]],
/// Uploaded verbatim to the scatter shader (binding 2).
/// NOTE(review): presumably the flat list of rule ids addressed by
/// `pattern_to_rules` — confirm.
pub rule_list: &'a [u32],
/// Uploaded verbatim to the scatter shader (binding 3).
/// NOTE(review): presumably the rule-local string index paired with each
/// `rule_list` entry — confirm.
pub string_local_ids: &'a [u32],
/// Pattern ids whose matches are filtered out on the host and never reach the GPU.
pub sentinel_pattern_ids: &'a [u32],
/// Maximum number of match positions cached per (rule, string) slot. It sizes
/// the position/length buffers and is passed to the scatter shader generator.
pub max_cached_positions: usize,
}
/// Generic 16-byte uniform block of four `u32` values shared by both compute passes.
///
/// `execute_gpu` fills it as follows:
/// * scatter pass: `x` = number of (non-sentinel) matches, `y` = max strings per rule,
///   `z` = max cached positions, `w` = file size in bytes;
/// * eval pass: `x` = number of rules, with `y`, `z`, `w` as above.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable)]
struct Params {
x: u32,
y: u32,
z: u32,
w: u32,
}
/// Capacity of the fired-rule readback. The result buffer holds one count word
/// followed by this many rule indices. `execute_gpu` returns an error if the GPU
/// reports a larger count.
const MAX_FIRED_INDICES: u32 = 1024;
/// First component of the in-process pipeline cache key in `execute_gpu`.
/// NOTE(review): the cache only lives for the process lifetime, so this value
/// does not currently affect which pipeline is selected.
const EVAL_SHADER_VERSION: u32 = 1;
pub fn execute_gpu(
device: &wgpu::Device,
queue: &wgpu::Queue,
programs: &[Program],
plan: GpuEvaluationPlan<'_>,
matches: &[matchkit::Match],
file_bytes: &[u8],
file_ctx: FileContext,
) -> Result<Vec<bool>> {
if plan.rule_string_counts.len() != programs.len() {
return Err(Error::BytecodeValidation {
message: format!(
"rule/program count mismatch: {} string-count entries for {} programs",
plan.rule_string_counts.len(),
programs.len()
),
});
}
if file_bytes.len() > u32::MAX as usize {
return Err(Error::Gpu {
message: "file size exceeds u32::MAX, not supported by GPU evaluator".to_string(),
});
}
let file_size = file_bytes.len() as u32;
let rule_count = programs.len() as u32;
let max_strings = plan
.rule_string_counts
.iter()
.copied()
.max()
.unwrap_or(0) as u32;
let sentinel_ids: HashSet<u32> = plan.sentinel_pattern_ids.iter().copied().collect();
let filtered_matches: Vec<matchkit::Match> = matches
.iter()
.filter(|matched| !sentinel_ids.contains(&matched.pattern_id))
.cloned()
.collect();
let match_count = filtered_matches.len() as u32;
let flat_programs = programs
.iter()
.flat_map(|program| program.instructions.iter().copied())
.collect::<Vec<Instruction>>();
let mut spans = Vec::<[u32; 2]>::with_capacity(programs.len());
let mut start = 0u32;
for program in programs {
let len = program.instructions.len() as u32;
spans.push([start, len]);
start = start.checked_add(len).ok_or_else(|| Error::Gpu {
message: "total bytecode instruction count exceeds u32::MAX".to_string(),
})?;
}
let max_buffer = device.limits().max_storage_buffer_binding_size as u64;
let max_buffer_size = device.limits().max_buffer_size;
let checked_mul = |left: u64, right: u64| -> Result<u64> {
left.checked_mul(right).ok_or_else(|| Error::Gpu {
message: "buffer size calculation overflow".to_string(),
})
};
let rule_count_u64 = rule_count as u64;
let max_strings_u64 = max_strings as u64;
let max_cached_u64 = plan.max_cached_positions as u64;
let rule_bitmap_bytes = checked_mul(rule_count_u64, 32)?;
let rule_counts_bytes = checked_mul(checked_mul(rule_count_u64, max_strings_u64)?, 4)?;
let rule_positions_bytes = checked_mul(
checked_mul(checked_mul(rule_count_u64, max_strings_u64)?, max_cached_u64)?,
4,
)?;
let fired_count_bytes = 4u64;
let fired_indices_bytes = (MAX_FIRED_INDICES as u64) * 4;
let match_bytes_len = cast_slice::<matchkit::Match, u8>(filtered_matches.as_slice()).len() as u64;
let programs_bytes = checked_mul(flat_programs.len() as u64, std::mem::size_of::<Instruction>() as u64)?;
let spans_bytes = checked_mul(spans.len() as u64, 8)?;
let file_bytes_len = file_bytes.len() as u64;
for (name, size) in [
("rule_bitmaps", rule_bitmap_bytes),
("rule_counts", rule_counts_bytes),
("rule_positions", rule_positions_bytes),
("fired_count", fired_count_bytes),
("fired_indices", fired_indices_bytes),
("matches", match_bytes_len),
("programs", programs_bytes),
("spans", spans_bytes),
("file_bytes", file_bytes_len),
] {
if size > max_buffer {
return Err(Error::Gpu {
message: format!(
"buffer '{name}' ({size} bytes) exceeds max_storage_buffer_binding_size ({max_buffer} bytes)"
),
});
}
if size > max_buffer_size {
return Err(Error::Gpu {
message: format!("buffer '{name}' ({size} bytes) exceeds max_buffer_size ({max_buffer_size} bytes)"),
});
}
}
let matches_buf = if filtered_matches.is_empty() {
zeroed_storage(device, "vyre matches", 16)
} else {
device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("vyre matches"),
contents: cast_slice(filtered_matches.as_slice()),
usage: wgpu::BufferUsages::STORAGE,
})
};
let pattern_to_rules_buf = safe_storage_buffer(device, "vyre pattern_to_rules", cast_slice(plan.pattern_to_rules));
let rule_list_buf = safe_storage_buffer(device, "vyre rule_list", cast_slice(plan.rule_list));
let string_ids_buf = safe_storage_buffer(device, "vyre string_local_ids", cast_slice(plan.string_local_ids));
let rule_bitmaps_buf = zeroed_storage(device, "vyre rule_bitmaps", (rule_count as usize * 32).max(16));
let rule_counts_buf = zeroed_storage(
device,
"vyre rule_counts",
(rule_count as usize * max_strings as usize * 4).max(16),
);
let rule_positions_buf = zeroed_storage(
device,
"vyre rule_positions",
(rule_count as usize * max_strings as usize * plan.max_cached_positions * 4).max(16),
);
let rule_lengths_buf = zeroed_storage(
device,
"vyre rule_lengths",
(rule_count as usize * max_strings as usize * plan.max_cached_positions * 4).max(16),
);
let programs_buf = safe_storage_buffer(device, "vyre programs", cast_slice(flat_programs.as_slice()));
let spans_buf = safe_storage_buffer(device, "vyre spans", cast_slice(spans.as_slice()));
let padded_file_bytes = if file_bytes.len() % 4 == 0 {
file_bytes.to_vec()
} else {
let mut padded = file_bytes.to_vec();
padded.resize((file_bytes.len() + 3) & !3, 0);
padded
};
let file_bytes_buf = safe_storage_buffer(device, "vyre file_bytes", &padded_file_bytes);
let fired_results_buf = zeroed_storage(device, "vyre fired_results", ((1 + MAX_FIRED_INDICES) as usize * 4).max(16));
let scatter_params_buf = uniform_buffer(
device,
"vyre scatter params",
&Params {
x: match_count,
y: max_strings,
z: plan.max_cached_positions as u32,
w: file_size,
},
);
let eval_params_buf = uniform_buffer(
device,
"vyre eval params",
&Params {
x: rule_count,
y: max_strings,
z: plan.max_cached_positions as u32,
w: file_size,
},
);
let file_ctx_buf = uniform_buffer(device, "vyre file context", &file_ctx);
static CACHED_PIPELINES: LazyLock<
Mutex<
std::collections::HashMap<
(u32, u32),
std::result::Result<(wgpu::ComputePipeline, wgpu::ComputePipeline), String>,
>,
>,
> = LazyLock::new(|| Mutex::new(std::collections::HashMap::new()));
let max_stack = max_stack_needed(programs);
let mut cache = CACHED_PIPELINES.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
let pipelines = cache.entry((EVAL_SHADER_VERSION, max_stack)).or_insert_with(|| {
let scatter_shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("vyre scatter shader"),
source: wgpu::ShaderSource::Wgsl(
scatter_shader::build_scatter_shader(plan.max_cached_positions as u32).into(),
),
});
let eval_shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("vyre eval shader"),
source: wgpu::ShaderSource::Wgsl(
eval_shader::build_eval_shader(max_stack, crate::MAX_FOR_ITERATIONS).into(),
),
});
let scatter_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("vyre scatter pipeline"),
layout: None,
module: &scatter_shader_module,
entry_point: Some("main"),
compilation_options: Default::default(),
cache: None,
});
let eval_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("vyre eval pipeline"),
layout: None,
module: &eval_shader_module,
entry_point: Some("main"),
compilation_options: Default::default(),
cache: None,
});
Ok((scatter_pipeline, eval_pipeline))
});
let (scatter_pipeline, eval_pipeline) = pipelines.as_ref().map_err(|message| Error::Gpu {
message: message.clone(),
})?;
let scatter_bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("vyre scatter bind group"),
layout: &scatter_pipeline.get_bind_group_layout(0),
entries: &[
bg_entry(0, &matches_buf),
bg_entry(1, &pattern_to_rules_buf),
bg_entry(2, &rule_list_buf),
bg_entry(3, &string_ids_buf),
bg_entry(4, &rule_bitmaps_buf),
bg_entry(5, &rule_counts_buf),
bg_entry(6, &rule_positions_buf),
bg_entry(7, &rule_lengths_buf),
bg_entry(8, &scatter_params_buf),
],
});
let eval_bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("vyre eval bind group"),
layout: &eval_pipeline.get_bind_group_layout(0),
entries: &[
bg_entry(0, &rule_bitmaps_buf),
bg_entry(1, &rule_counts_buf),
bg_entry(2, &rule_positions_buf),
bg_entry(3, &rule_lengths_buf),
bg_entry(4, &programs_buf),
bg_entry(5, &spans_buf),
bg_entry(6, &file_bytes_buf),
bg_entry(7, &fired_results_buf),
bg_entry(8, &eval_params_buf),
bg_entry(9, &file_ctx_buf),
],
});
let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("vyre eval encoder"),
});
{
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("vyre scatter pass"),
timestamp_writes: None,
});
pass.set_pipeline(scatter_pipeline);
pass.set_bind_group(0, &scatter_bg, &[]);
pass.dispatch_workgroups(match_count.max(1), 1, 1);
}
{
let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("vyre eval pass"),
timestamp_writes: None,
});
pass.set_pipeline(eval_pipeline);
pass.set_bind_group(0, &eval_bg, &[]);
pass.dispatch_workgroups(rule_count.max(1), 1, 1);
}
let readback_size = ((1 + MAX_FIRED_INDICES) as usize * 4) as u64;
let readback = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("vyre fired results readback"),
size: readback_size,
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
mapped_at_creation: false,
});
encoder.copy_buffer_to_buffer(&fired_results_buf, 0, &readback, 0, readback_size);
queue.submit(std::iter::once(encoder.finish()));
let slice = readback.slice(..);
let (sender, receiver) = std::sync::mpsc::channel();
slice.map_async(wgpu::MapMode::Read, move |result| {
let _ = sender.send(result);
});
let _ = device.poll(wgpu::Maintain::Wait);
match receiver.recv() {
Ok(Ok(())) => {}
Ok(Err(error)) => {
return Err(Error::Gpu {
message: format!("failed to map GPU results: {error:?}"),
});
}
Err(error) => {
return Err(Error::Gpu {
message: format!("failed to receive GPU map status: {error}"),
});
}
}
let data = slice.get_mapped_range();
let words: Vec<u32> = cast_slice(&data).to_vec();
drop(data);
readback.unmap();
let fired = words.first().copied().unwrap_or(0);
if fired > MAX_FIRED_INDICES {
return Err(Error::Gpu {
message: format!("GPU produced {fired} fired indices, exceeding readback limit {MAX_FIRED_INDICES}"),
});
}
let mut hit_bits = vec![false; programs.len()];
for rule_id in words.iter().skip(1).take(fired as usize) {
if let Some(hit) = hit_bits.get_mut(*rule_id as usize) {
*hit = true;
}
}
Ok(hit_bits)
}
pub fn cached_device() -> Result<&'static (wgpu::Device, wgpu::Queue)> {
static DEVICE: LazyLock<Result<(wgpu::Device, wgpu::Queue)>> = LazyLock::new(init_device);
DEVICE.as_ref().map_err(|error| Error::Gpu {
message: error.to_string(),
})
}
fn init_device() -> Result<(wgpu::Device, wgpu::Queue)> {
let instance = wgpu::Instance::default();
let adapter = pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions::default()))
.ok_or_else(|| Error::Gpu {
message: "failed to acquire adapter".to_string(),
})?;
pollster::block_on(adapter.request_device(&wgpu::DeviceDescriptor::default(), None))
.map_err(|error| Error::Gpu {
message: format!("failed to acquire device: {error}"),
})
}
/// Builds a bind-group entry that exposes the whole of `buffer` at slot `binding`.
fn bg_entry(binding: u32, buffer: &wgpu::Buffer) -> wgpu::BindGroupEntry<'_> {
    let resource = buffer.as_entire_binding();
    wgpu::BindGroupEntry { binding, resource }
}
/// Creates a `UNIFORM`-usage buffer initialised with the raw bytes of `value`.
fn uniform_buffer<T: Pod>(device: &wgpu::Device, label: &str, value: &T) -> wgpu::Buffer {
    let contents: &[u8] = cast_slice(std::slice::from_ref(value));
    let descriptor = wgpu::util::BufferInitDescriptor {
        label: Some(label),
        contents,
        usage: wgpu::BufferUsages::UNIFORM,
    };
    device.create_buffer_init(&descriptor)
}
/// Uploads `bytes` into a new `STORAGE`-usage buffer labelled `label`.
///
/// An empty slice instead yields a 16-byte buffer from `zeroed_storage`, so
/// callers always get a non-empty buffer to bind.
fn safe_storage_buffer(device: &wgpu::Device, label: &str, bytes: &[u8]) -> wgpu::Buffer {
    if !bytes.is_empty() {
        let descriptor = wgpu::util::BufferInitDescriptor {
            label: Some(label),
            contents: bytes,
            usage: wgpu::BufferUsages::STORAGE,
        };
        return device.create_buffer_init(&descriptor);
    }
    zeroed_storage(device, label, 16)
}
/// Creates an unmapped storage buffer of `size` bytes that can also be used as
/// a copy source and destination.
///
/// No data is written. The "zeroed" contents rely on wgpu zero-initialising
/// newly created buffers.
fn zeroed_storage(device: &wgpu::Device, label: &str, size: usize) -> wgpu::Buffer {
    let usage =
        wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC | wgpu::BufferUsages::COPY_DST;
    let descriptor = wgpu::BufferDescriptor {
        label: Some(label),
        size: size as u64,
        usage,
        mapped_at_creation: false,
    };
    device.create_buffer(&descriptor)
}
/// Estimates the operand-stack depth the eval shader must support for `programs`.
///
/// Each program's instructions are walked in order:
/// * `stack_delta` is accumulated per instruction, with the running depth
///   clamped at zero;
/// * each program contributes its highest depth, but at least 1;
/// * the result is the largest contribution, and never less than 8.
///
/// NOTE(review): this is a straight-line pass. It does not follow jumps or
/// loop back-edges.
fn max_stack_needed(programs: &[Program]) -> u32 {
    programs
        .iter()
        .map(|program| {
            let (_, highest) = program.instructions.iter().fold(
                (0i32, 0i32),
                |(depth, highest), instruction| {
                    let next_depth = (depth + stack_delta(*instruction)).max(0);
                    (next_depth, highest.max(next_depth))
                },
            );
            highest.max(1) as u32
        })
        .fold(8u32, u32::max)
}
fn stack_delta(instruction: Instruction) -> i32 {
match instruction.kind() {
Ok(
crate::bytecode::Opcode::PushTrue
| crate::bytecode::Opcode::PushFalse
| crate::bytecode::Opcode::PushImmediate
| crate::bytecode::Opcode::PushFileSize
| crate::bytecode::Opcode::PushEntryCount
| crate::bytecode::Opcode::PushNumStrings
| crate::bytecode::Opcode::PushEntropy
| crate::bytecode::Opcode::PushIsPe
| crate::bytecode::Opcode::PushIsDll
| crate::bytecode::Opcode::PushNumSections
| crate::bytecode::Opcode::PushNumImports
| crate::bytecode::Opcode::PushEntryPoint
| crate::bytecode::Opcode::PushHasSignature
| crate::bytecode::Opcode::PushMagicU32
| crate::bytecode::Opcode::PushIs64bit
| crate::bytecode::Opcode::PushStringMatched
| crate::bytecode::Opcode::PushStringCount
| crate::bytecode::Opcode::PushStringOffset
| crate::bytecode::Opcode::PushStringLength
| crate::bytecode::Opcode::ReadIntAt,
) => 1,
Ok(crate::bytecode::Opcode::Not) => 0,
Ok(crate::bytecode::Opcode::ForAny | crate::bytecode::Opcode::ForAll) => -2,
Ok(crate::bytecode::Opcode::ForN) => -3,
Ok(crate::bytecode::Opcode::Halt | crate::bytecode::Opcode::EndFor) => 0,
Ok(_) => -1,
Err(_) => 0,
}
}