vyre-wgpu 0.1.0

wgpu backend for the vyre IR: implements `VyreBackend` and owns the GPU runtime, buffer pool, and pipeline cache.
Documentation
//! Host-side DFA workflow dispatcher.
//!
//! NOTE: This is NOT an IR op domain. It accepts runtime DFA tables and input
//! bytes, compiles and owns GPU pipelines and buffers, dispatches scanning
//! kernels, performs deterministic readback sorting, and returns
//! `vyre::Match` values. The IR-side match domain lives under
//! `vyre::ops::match_ops`; those modules produce `Program` values that go
//! through validate and lower.

mod buffers;
mod input;
mod readback;
mod resources;
mod transitions;

use self::buffers::{entry, storage_buffer};
use self::input::{checked_input_len, dispatch_groups, DfaParams};
use self::readback::{read_matches, read_one_u32};
use self::resources::ScanResources;
use self::transitions::{build_accept_map, compile_pipeline, validate_compile_inputs};
use std::sync::Mutex;
use vyre::error::{Error, Result};

/// Number of distinct `u8` values (256). Presumably the width of one DFA
/// transition row — confirm against the `transitions` module.
pub(crate) const BYTE_CLASSES: usize = u8::MAX as usize + 1;

/// All-ones `u32` sentinel; presumably marks accept-map entries for states
/// that accept no pattern — confirm against `build_accept_map`.
pub(crate) const SENTINEL_NO_ACCEPT: u32 = u32::MAX;

/// Default maximum matches captured by one GPU DFA scan (2^16).
pub const DEFAULT_MAX_MATCHES: u32 = 1 << 16;

/// Maximum matches a single GPU DFA scan may allocate/read back.
pub const MAX_DFA_MATCHES: u32 = 1_000_000;

/// A GPU-compiled DFA scanner.
///
/// Holds the compute pipeline, the device-side tables uploaded once at
/// compile time, and a pool of per-scan buffer sets that `scan` takes and
/// returns so repeated scans can reuse GPU allocations.
#[non_exhaustive]
pub struct GpuDfa {
    /// Device the pipeline and buffers were created on; `scan` rejects any other device.
    device: wgpu::Device,
    /// Whether `device` is the shared runtime device; gates `scan_shared`.
    compiled_with_cached_device: bool,
    /// Compute pipeline produced by `compile_pipeline`.
    pipeline: wgpu::ComputePipeline,
    /// Storage buffer holding the caller-supplied transition table (binding 1).
    transition_buffer: wgpu::Buffer,
    /// Storage buffer holding the accept map built by `build_accept_map` (binding 2).
    accept_buffer: wgpu::Buffer,
    /// Storage buffer holding the pattern lengths (binding 6).
    pattern_length_buffer: wgpu::Buffer,
    /// Host-side copy of the pattern lengths, returned by `pattern_lengths()`.
    pattern_lengths: Vec<u32>,
    /// Number of DFA states, checked to fit `u32` at compile time.
    state_count: u32,
    /// Upper bound on matches captured and read back per scan.
    max_matches: u32,
    /// Idle per-scan buffer sets; taken before a scan and pushed back after it.
    resources: Mutex<Vec<ScanResources>>,
}

impl std::fmt::Debug for GpuDfa {
    /// Summarises the scanner by its sizes only; GPU handles and buffers are
    /// intentionally omitted (hence the non-exhaustive finish).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let pattern_count = self.pattern_lengths.len();
        let mut builder = f.debug_struct("GpuDfa");
        builder.field("state_count", &self.state_count);
        builder.field("pattern_count", &pattern_count);
        builder.field("max_matches", &self.max_matches);
        builder.finish_non_exhaustive()
    }
}

fn zero_dfa_input_padding(
    queue: &wgpu::Queue,
    buffer: &wgpu::Buffer,
    written: usize,
) -> Result<()> {
    let padding_len = written.next_multiple_of(4) - written;
    if padding_len == 0 {
        return Ok(());
    }
    let offset = u64::try_from(written).map_err(|source| Error::Dfa {
        message: format!(
            "DFA input byte count {written} cannot fit u64: {source}. Fix: scan chunks smaller than 4 GiB."
        ),
    })?;
    let padding = [0u8; 4];
    queue.write_buffer(buffer, offset, &padding[..padding_len]);
    Ok(())
}

impl GpuDfa {
    /// Internal pipeline helper. User-facing DFA construction lives in
    /// `std::pattern::aho_corasick_build`.
    ///
    /// # Errors
    /// Returns [`Error::Dfa`] when the table, output links, or bindings are invalid.
    pub fn compile(
        device: &wgpu::Device,
        transitions: &[u32],
        state_count: usize,
        accept_states: &[(u32, u32)],
        output_links: &[u32],
        pattern_lengths: &[u32],
    ) -> Result<Self> {
        Self::compile_with_max_matches(
            device,
            transitions,
            state_count,
            accept_states,
            output_links,
            pattern_lengths,
            DEFAULT_MAX_MATCHES,
        )
    }

    /// Internal pipeline helper. User-facing DFA construction lives in
    /// `std::pattern::aho_corasick_build`.
    ///
    /// # Errors
    /// Returns [`Error::Dfa`] when validation of the DFA or GPU binding sizes fails.
    pub fn compile_with_max_matches(
        device: &wgpu::Device,
        transitions: &[u32],
        state_count: usize,
        accept_states: &[(u32, u32)],
        output_links: &[u32],
        pattern_lengths: &[u32],
        max_matches: u32,
    ) -> Result<Self> {
        validate_compile_inputs(
            device,
            transitions,
            state_count,
            accept_states,
            output_links,
            pattern_lengths,
            max_matches,
        )?;

        let accept_map = build_accept_map(state_count, accept_states)?;
        let pipeline = compile_pipeline(device)?;
        let transition_buffer = storage_buffer(device, "vyre dfa transitions", transitions);
        let accept_buffer = storage_buffer(device, "vyre dfa accept map", &accept_map);
        let pattern_length_buffer =
            storage_buffer(device, "vyre dfa pattern lengths", pattern_lengths);

        Ok(Self {
            device: device.clone(),
            compiled_with_cached_device: crate::runtime::device::is_cached_device(device),
            pipeline,
            transition_buffer,
            accept_buffer,
            pattern_length_buffer,
            pattern_lengths: pattern_lengths.to_vec(),
            state_count: u32::try_from(state_count).map_err(|source| Error::Dfa {
                message: format!(
                    "DFA state_count {state_count} cannot fit u32: {source}. Fix: split the automaton or reduce states."
                ),
            })?,
            max_matches,
            resources: Mutex::new(Vec::new()),
        })
    }

    /// Scan input bytes on the GPU and return captured matches.
    ///
    /// # Errors
    /// Returns [`Error::Dfa`] if the input, device, queue, or readback is invalid.
    pub fn scan(
        &self,
        device: &wgpu::Device,
        queue: &wgpu::Queue,
        input: &[u8],
        command_encoder: Option<&mut wgpu::CommandEncoder>,
    ) -> Result<Vec<vyre::Match>> {
        if *device != self.device {
            return Err(Error::Dfa {
                message: "DFA scan device differs from compile device. Fix: scan with the same wgpu::Device and matching Queue used to compile the DFA.".to_string(),
            });
        }
        if input.is_empty() {
            return Ok(Vec::new());
        }
        let input_len = checked_input_len(input)?;
        let mut resources = self.acquire_scan_resources(input_len)?;
        let result =
            self.scan_with_resources(queue, input, input_len, &mut resources, command_encoder);
        self.release_scan_resources(resources)?;
        result
    }

    fn acquire_scan_resources(&self, input_len: u32) -> Result<ScanResources> {
        let mut pool = self.resources.lock().map_err(|source| Error::Dfa {
            message: format!("DFA scan resources mutex is poisoned: {source}. Fix: recreate the compiled DFA and inspect panics from concurrent scan tasks."),
        })?;
        if let Some(index) = pool
            .iter()
            .position(|resources| resources.max_input_len >= input_len)
        {
            return Ok(pool.swap_remove(index));
        }
        drop(pool);
        ScanResources::new(&self.device, input_len, self.max_matches)
    }

    fn release_scan_resources(&self, resources: ScanResources) -> Result<()> {
        let mut pool = self.resources.lock().map_err(|source| Error::Dfa {
            message: format!("DFA scan resources mutex is poisoned while releasing resources: {source}. Fix: recreate the compiled DFA and inspect panics from concurrent scan tasks."),
        })?;
        pool.push(resources);
        Ok(())
    }

    pub(crate) fn scan_with_resources(
        &self,
        queue: &wgpu::Queue,
        input: &[u8],
        input_len: u32,
        resources: &mut ScanResources,
        command_encoder: Option<&mut wgpu::CommandEncoder>,
    ) -> Result<Vec<vyre::Match>> {
        if input_len > resources.max_input_len {
            return Err(Error::Dfa {
                message: format!(
                    "DFA input length {input_len} exceeds ScanResources capacity {}. Fix: create larger ScanResources.",
                    resources.max_input_len
                ),
            });
        }

        let params = DfaParams {
            input_len,
            state_count: self.state_count,
            max_matches: self.max_matches,
            _pad: 0,
        };
        queue.write_buffer(&resources.input_buffer, 0, input);
        zero_dfa_input_padding(queue, &resources.input_buffer, input.len())?;
        queue.write_buffer(&resources.params_buffer, 0, bytemuck::bytes_of(&params));

        let bind_group = self.create_bind_group(
            &resources.input_buffer,
            &resources.match_buffer,
            &resources.match_count_buffer,
            &resources.params_buffer,
        );
        let mut owned_encoder = command_encoder.is_none().then(|| {
            self.device
                .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                    label: Some("vyre dfa dispatch and readback"),
                })
        });
        let encoder = if let Some(encoder) = command_encoder {
            encoder
        } else {
            owned_encoder
                .as_mut()
                .expect("owned encoder must be present when command_encoder is omitted")
        };
        encoder.clear_buffer(&resources.match_count_buffer, 0, None);
        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("vyre dfa scan pass"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&self.pipeline);
            pass.set_bind_group(0, &bind_group, &[]);
            pass.dispatch_workgroups(dispatch_groups(input_len), 1, 1);
        }
        encoder.copy_buffer_to_buffer(
            &resources.match_count_buffer,
            0,
            &resources.count_readback,
            0,
            4,
        );
        let Some(owned_encoder) = owned_encoder else {
            return Err(Error::Dfa {
                message: "DFA scan was called with an external command encoder, but this API returns readback matches that are unavailable until the caller submits that encoder. Fix: call with `None` for immediate submit/readback, or add a deferred DFA API that returns readback buffers.".to_string(),
            });
        };
        let count_submission = queue.submit(std::iter::once(owned_encoder.finish()));

        let reported = read_one_u32(
            &self.device,
            &resources.count_readback,
            "match count",
            count_submission,
        )?;
        let captured = reported.min(self.max_matches);
        if captured == 0 {
            return Ok(Vec::new());
        }
        let mut readback_encoder =
            self.device
                .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                    label: Some("vyre dfa match readback"),
                });
        readback_encoder.copy_buffer_to_buffer(
            &resources.match_buffer,
            0,
            &resources.match_readback,
            0,
            buffers::match_buffer_size(captured)?,
        );
        let match_submission = queue.submit(std::iter::once(readback_encoder.finish()));
        let mut matches = read_matches(
            &self.device,
            &resources.match_readback,
            captured,
            match_submission,
        )?;
        matches.sort_unstable();
        Ok(matches)
    }

    /// Scan input bytes using the shared vyre device.
    ///
    /// # Errors
    /// Returns [`Error::Dfa`] if this DFA was not compiled on the shared runtime device.
    pub fn scan_shared(&self, input: &[u8]) -> Result<Vec<vyre::Match>> {
        if !self.compiled_with_cached_device {
            return Err(Error::Dfa {
                message: "DFA was compiled with a non-shared GPU device. Fix: compile with vyre::runtime::cached_device() before calling scan_shared(), or call scan() with the original device and queue.".to_string(),
            });
        }
        let (device, queue) = crate::runtime::cached_device()?;
        self.scan(device, queue, input, None)
    }

    /// Number of DFA states in the compiled scanner.
    #[must_use]
    pub fn state_count(&self) -> u32 {
        self.state_count
    }

    /// Maximum number of matches this scanner captures.
    #[must_use]
    pub fn max_matches(&self) -> u32 {
        self.max_matches
    }

    /// Pattern lengths supplied at compile time.
    #[must_use]
    pub fn pattern_lengths(&self) -> &[u32] {
        &self.pattern_lengths
    }

    pub(crate) fn create_bind_group(
        &self,
        input_buffer: &wgpu::Buffer,
        match_buffer: &wgpu::Buffer,
        match_count_buffer: &wgpu::Buffer,
        params_buffer: &wgpu::Buffer,
    ) -> wgpu::BindGroup {
        let layout = self.pipeline.get_bind_group_layout(0);
        self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("vyre dfa bind group"),
            layout: &layout,
            entries: &[
                entry(0, input_buffer),
                entry(1, &self.transition_buffer),
                entry(2, &self.accept_buffer),
                entry(3, match_buffer),
                entry(4, match_count_buffer),
                entry(5, params_buffer),
                entry(6, &self.pattern_length_buffer),
            ],
        })
    }
}