vyre 0.4.0

GPU compute intermediate representation with a standard operation library
Documentation
#![cfg(feature = "gpu")]

use std::sync::mpsc;

use bytemuck::cast_slice;
use vyre::ir::{validate, BufferDecl, DataType, Expr, Node, Program};
use vyre::ops::match_ops::dfa_scan::{DfaScan, NO_MATCH};
use wgpu::util::DeviceExt;

/// Width of one state's row in the flattened DFA transition table built by
/// the DFA test in this file. Rows are indexed by raw byte value, so 256
/// entries cover every possible `u8`.
const BYTE_CLASSES: usize = 256;

/// Builds the element-wise add program with a direct `Expr::add` node,
/// dispatches it on the GPU, and checks the four summed outputs.
#[test]
fn manual_ir_add_dispatches_on_gpu() {
    let Some((device, queue)) = cached_device() else {
        return;
    };
    let program = add_program(false);

    // Operands are named so the expected sums below are easy to check by eye.
    let lhs = [1u32, 2, 3, 4];
    let rhs = [10u32, 20, 30, 40];
    let output = dispatch_three_buffer_u32(device, queue, &program, &lhs, &rhs);

    assert_eq!(output, [11, 22, 33, 44]);
}

/// Same check as the manual add test, but the sum is expressed as an
/// `Expr::call` to `"primitive.math.add"`, so the call form must lower and
/// dispatch to the same results.
#[test]
fn call_inlined_add_dispatches_on_gpu() {
    let Some((device, queue)) = cached_device() else {
        return;
    };
    let program = add_program(true);

    // Operands are named so the expected sums below are easy to check by eye.
    let lhs = [1u32, 2, 3, 4];
    let rhs = [10u32, 20, 30, 40];
    let output = dispatch_three_buffer_u32(device, queue, &program, &lhs, &rhs);

    assert_eq!(output, [11, 22, 33, 44]);
}

/// Lowers the `DfaScan` IR program, runs it on the GPU over a small haystack
/// with a hand-built two-pattern DFA ("ab" and "cd"), and checks that both
/// pattern ids show up in the match buffer that is read back.
#[test]
fn dfa_scan_ir_dispatches_on_gpu() {
    // Each match record read back from the GPU occupies three u32 words.
    const WORDS_PER_MATCH: usize = 3;

    let Some((device, queue)) = cached_device() else {
        return;
    };
    let program = DfaScan::program();
    assert_validation_passes(&program);
    let pipeline = compile_program(device, &program, "gpu_dispatch_dfa_scan");

    // Haystack, packed into little-endian u32 words for the storage buffer.
    let input = b"xxabxxcdxx";
    let input_words = pack_input_words(input);

    // Flattened transition table: row = current state, column = input byte.
    // Every entry not listed below stays 0.
    let state_count = 5;
    let mut transitions = vec![0u32; state_count * BYTE_CLASSES];
    let edges = [
        (0usize, b'a', 1u32),
        (1, b'b', 2),
        (0, b'c', 3),
        (3, b'd', 4),
    ];
    for (state, byte, next_state) in edges {
        transitions[state * BYTE_CLASSES + usize::from(byte)] = next_state;
    }

    // State 2 accepts pattern 0 and state 4 accepts pattern 1; every other
    // state carries the NO_MATCH sentinel.
    let mut accept_map = vec![NO_MATCH; state_count];
    for (state, pattern_id) in [(2usize, 0u32), (4, 1)] {
        accept_map[state] = pattern_id;
    }

    let max_matches = 8u32;
    let match_words = usize::try_from(max_matches)
        .expect("max_matches fits usize")
        .checked_mul(WORDS_PER_MATCH)
        .expect("match buffer word count fits usize");
    let match_bytes = byte_len(match_words);

    let input_len = u32::try_from(input.len()).expect("input len fits u32");
    let state_total = u32::try_from(state_count).expect("state count fits u32");
    let params = [input_len, state_total, max_matches, 0];
    let pattern_lengths = [2u32, 2];

    let input_buffer = storage_init(device, "gpu_dispatch_dfa_input", &input_words);
    let transition_buffer = storage_init(device, "gpu_dispatch_dfa_transitions", &transitions);
    let accept_buffer = storage_init(device, "gpu_dispatch_dfa_accept", &accept_map);
    let matches_buffer = storage_empty(
        device,
        "gpu_dispatch_dfa_matches",
        match_words,
        wgpu::BufferUsages::COPY_SRC,
    );
    // COPY_DST is needed on the counter because it is cleared by the encoder
    // before the pass runs.
    let count_buffer = storage_empty(
        device,
        "gpu_dispatch_dfa_count",
        1,
        wgpu::BufferUsages::COPY_SRC | wgpu::BufferUsages::COPY_DST,
    );
    let params_buffer = storage_init(device, "gpu_dispatch_dfa_params", &params);
    let pattern_lengths_buffer =
        storage_init(device, "gpu_dispatch_dfa_pattern_lengths", &pattern_lengths);

    let layout = pipeline.get_bind_group_layout(0);
    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("gpu_dispatch_dfa_bind_group"),
        layout: &layout,
        entries: &[
            vyre_wgpu::runtime::bg_entry(0, &input_buffer),
            vyre_wgpu::runtime::bg_entry(1, &transition_buffer),
            vyre_wgpu::runtime::bg_entry(2, &accept_buffer),
            vyre_wgpu::runtime::bg_entry(3, &matches_buffer),
            vyre_wgpu::runtime::bg_entry(4, &count_buffer),
            vyre_wgpu::runtime::bg_entry(5, &params_buffer),
            vyre_wgpu::runtime::bg_entry(6, &pattern_lengths_buffer),
        ],
    });

    let count_readback = readback_buffer(device, "gpu_dispatch_dfa_count_readback", 4);
    let matches_readback =
        readback_buffer(device, "gpu_dispatch_dfa_matches_readback", match_bytes);

    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("gpu_dispatch_dfa_encoder"),
    });
    encoder.clear_buffer(&count_buffer, 0, None);
    {
        // Scoped so the pass is dropped before the encoder records the copies.
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("gpu_dispatch_dfa_pass"),
            timestamp_writes: None,
        });
        pass.set_pipeline(&pipeline);
        pass.set_bind_group(0, &bind_group, &[]);
        pass.dispatch_workgroups(1, 1, 1);
    }
    encoder.copy_buffer_to_buffer(&count_buffer, 0, &count_readback, 0, 4);
    encoder.copy_buffer_to_buffer(&matches_buffer, 0, &matches_readback, 0, match_bytes);
    queue.submit(std::iter::once(encoder.finish()));

    // The reported count is clamped to the capacity of the match buffer
    // before it is used to size the readback.
    let reported = read_u32_buffer(device, &count_readback, 1)[0];
    let captured = reported.min(max_matches);
    // Per the note that accompanied this test: the shader runs a full
    // workgroup and each in-range thread scans from its own byte offset, so
    // overlapping matches are possible. The test therefore only requires at
    // least two matches and that both pattern ids are present.
    assert!(
        captured >= 2,
        "expected at least 2 DFA matches, got {captured}"
    );

    let captured_words =
        usize::try_from(captured).expect("captured fits usize") * WORDS_PER_MATCH;
    let raw_matches = read_u32_buffer(device, &matches_readback, captured_words);
    let matches: Vec<_> = raw_matches
        .chunks_exact(WORDS_PER_MATCH)
        .map(|fields| vyre::Match::new(fields[0], fields[1], fields[2]))
        .collect();

    let has_pattern = |id: u32| matches.iter().any(|m| m.pattern_id == id);
    assert!(
        has_pattern(0),
        "pattern 0 ('ab') not found in DFA matches: {matches:?}"
    );
    assert!(
        has_pattern(1),
        "pattern 1 ('cd') not found in DFA matches: {matches:?}"
    );
}

/// Builds the IR program `out[idx] = a[idx] + b[idx]` over three `u32`
/// buffers (bindings 0, 1, 2) with a `[4, 1, 1]` workgroup size.
///
/// When `use_call` is true the sum is expressed as a call to
/// `"primitive.math.add"`; otherwise it is a direct `Expr::add` node. The
/// store is guarded so invocations at or past `buf_len("out")` write nothing.
fn add_program(use_call: bool) -> Program {
    let index = Expr::var("idx");
    let lhs = Expr::load("a", index.clone());
    let rhs = Expr::load("b", index.clone());
    let total = if use_call {
        Expr::call("primitive.math.add", vec![lhs, rhs])
    } else {
        Expr::add(lhs, rhs)
    };

    let buffers = vec![
        BufferDecl::read("a", 0, DataType::U32),
        BufferDecl::read("b", 1, DataType::U32),
        BufferDecl::output("out", 2, DataType::U32),
    ];

    let bind_index = Node::let_bind("idx", Expr::gid_x());
    let in_bounds = Expr::lt(index.clone(), Expr::buf_len("out"));
    let guarded_store = Node::if_then(in_bounds, vec![Node::store("out", index, total)]);

    Program::new(buffers, [4, 1, 1], vec![bind_index, guarded_store])
}

/// Validates, lowers, and runs `program` on the GPU with `a` and `b` bound at
/// bindings 0 and 1 and an output buffer of the same length at binding 2,
/// then returns the output buffer contents.
///
/// A single workgroup is dispatched.
///
/// # Panics
///
/// Panics if `a` and `b` differ in length, if IR validation reports errors,
/// or if lowering, compilation, or the readback mapping fails.
fn dispatch_three_buffer_u32(
    device: &wgpu::Device,
    queue: &wgpu::Queue,
    program: &Program,
    a: &[u32],
    b: &[u32],
) -> Vec<u32> {
    assert_eq!(a.len(), b.len());
    assert_validation_passes(program);
    let pipeline = compile_program(device, program, "gpu_dispatch_add");

    let word_count = a.len();
    let lhs_buffer = storage_init(device, "gpu_dispatch_a", a);
    let rhs_buffer = storage_init(device, "gpu_dispatch_b", b);
    // COPY_SRC lets the result be copied into the mappable readback buffer.
    let result_buffer = storage_empty(
        device,
        "gpu_dispatch_out",
        word_count,
        wgpu::BufferUsages::COPY_SRC,
    );

    let layout = pipeline.get_bind_group_layout(0);
    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("gpu_dispatch_add_bind_group"),
        layout: &layout,
        entries: &[
            vyre_wgpu::runtime::bg_entry(0, &lhs_buffer),
            vyre_wgpu::runtime::bg_entry(1, &rhs_buffer),
            vyre_wgpu::runtime::bg_entry(2, &result_buffer),
        ],
    });

    let result_bytes = byte_len(word_count);
    let readback = readback_buffer(device, "gpu_dispatch_add_readback", result_bytes);

    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("gpu_dispatch_add_encoder"),
    });
    {
        // Scoped so the pass is dropped before the encoder records the copy.
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("gpu_dispatch_add_pass"),
            timestamp_writes: None,
        });
        pass.set_pipeline(&pipeline);
        pass.set_bind_group(0, &bind_group, &[]);
        pass.dispatch_workgroups(1, 1, 1);
    }
    encoder.copy_buffer_to_buffer(&result_buffer, 0, &readback, 0, result_bytes);
    queue.submit(std::iter::once(encoder.finish()));

    read_u32_buffer(device, &readback, word_count)
}

fn cached_device() -> Option<&'static (wgpu::Device, wgpu::Queue)> {
    match vyre_wgpu::runtime::cached_device() {
        Ok(pair) => Some(pair),
        Err(error) => {
            panic!("GPU required on this machine (RTX 5090 / 4090 available per project invariant) — do not silently skip: gpu_dispatch integration test: {error}");
            None
        }
    }
}

/// Runs IR validation on `program` and fails the calling test if any errors
/// are reported.
///
/// # Panics
///
/// Panics with the debug-formatted error list when validation is not clean.
fn assert_validation_passes(program: &Program) {
    let errors = validate(program);
    let is_clean = errors.is_empty();
    assert!(
        is_clean,
        "IR validation failed for GPU dispatch smoke test: {errors:?}"
    );
}

/// Lowers `program` to WGSL and compiles it into a compute pipeline with
/// entry point `"main"`, using `label` for the pipeline and in diagnostics.
///
/// # Panics
///
/// Panics if lowering or pipeline compilation fails; the message names the
/// failing stage and `label`.
fn compile_program(
    device: &wgpu::Device,
    program: &Program,
    label: &'static str,
) -> wgpu::ComputePipeline {
    let wgsl = match vyre::lower::wgsl::lower(program) {
        Ok(source) => source,
        Err(error) => panic!("WGSL lowering failed for {label}: {error}"),
    };
    match vyre_wgpu::runtime::compile_compute_pipeline(device, label, &wgsl, "main") {
        Ok(pipeline) => pipeline,
        Err(error) => panic!("WGSL compilation failed for {label}: {error}"),
    }
}

/// Creates a buffer initialised with `data` whose only usage flag is
/// `STORAGE`.
fn storage_init(device: &wgpu::Device, label: &'static str, data: &[u32]) -> wgpu::Buffer {
    let usage = wgpu::BufferUsages::STORAGE;
    storage_init_with_usage(device, label, data, usage)
}

/// Creates a buffer initialised with the bytes of `data`, using exactly the
/// usage flags passed in `usage`.
fn storage_init_with_usage(
    device: &wgpu::Device,
    label: &'static str,
    data: &[u32],
    usage: wgpu::BufferUsages,
) -> wgpu::Buffer {
    // Reinterpret the words as raw bytes for upload.
    let contents: &[u8] = cast_slice(data);
    let descriptor = wgpu::util::BufferInitDescriptor {
        label: Some(label),
        contents,
        usage,
    };
    device.create_buffer_init(&descriptor)
}

/// Creates an unmapped buffer sized for `word_len` `u32` words, with
/// `STORAGE` usage plus whatever flags `extra_usage` adds.
fn storage_empty(
    device: &wgpu::Device,
    label: &'static str,
    word_len: usize,
    extra_usage: wgpu::BufferUsages,
) -> wgpu::Buffer {
    let descriptor = wgpu::BufferDescriptor {
        label: Some(label),
        size: byte_len(word_len),
        usage: extra_usage | wgpu::BufferUsages::STORAGE,
        mapped_at_creation: false,
    };
    device.create_buffer(&descriptor)
}

/// Creates an unmapped buffer of `size` bytes that can be a copy destination
/// and can be mapped for reading.
fn readback_buffer(device: &wgpu::Device, label: &'static str, size: u64) -> wgpu::Buffer {
    let usage = wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST;
    let descriptor = wgpu::BufferDescriptor {
        label: Some(label),
        size,
        usage,
        mapped_at_creation: false,
    };
    device.create_buffer(&descriptor)
}

/// Maps the first `word_len` `u32` words of `buffer` for reading, waits for
/// the mapping to complete, and returns the words as an owned vector.
///
/// The buffer is unmapped again before returning.
///
/// # Panics
///
/// Panics if the map callback never reports back or if the mapping fails.
fn read_u32_buffer(device: &wgpu::Device, buffer: &wgpu::Buffer, word_len: usize) -> Vec<u32> {
    const WORD_BYTES: usize = std::mem::size_of::<u32>();

    let mapped_bytes = byte_len(word_len);
    let slice = buffer.slice(0..mapped_bytes);

    // The map callback hands its result over a channel so this function can
    // block on it after polling the device.
    let (sender, receiver) = mpsc::channel();
    slice.map_async(wgpu::MapMode::Read, move |result| {
        let _ = sender.send(result);
    });
    // Exhaustive match with no wildcard arm: both outcomes are acceptable here.
    match device.poll(wgpu::Maintain::Wait) {
        wgpu::MaintainResult::Ok | wgpu::MaintainResult::SubmissionQueueEmpty => {}
    }
    receiver
        .recv()
        .expect("readback map callback must report completion")
        .expect("readback buffer must map for reading");

    let mapped = slice.get_mapped_range();
    // Decode each 4-byte group with `from_ne_bytes` rather than copying into a
    // `Vec<u8>` and reinterpreting it as `&[u32]`. A byte vector carries no
    // 4-byte alignment guarantee, and a reinterpreting cast panics on
    // misaligned input. This form is alignment-independent and also skips the
    // intermediate byte copy.
    let words: Vec<u32> = mapped
        .chunks_exact(WORD_BYTES)
        .map(|word| u32::from_ne_bytes([word[0], word[1], word[2], word[3]]))
        .collect();
    // The mapped view must be released before the buffer is unmapped.
    drop(mapped);
    buffer.unmap();
    words
}

/// Converts a count of `u32` words into a byte length as `u64`.
///
/// # Panics
///
/// Panics if the byte count overflows `usize` or does not fit in `u64`.
fn byte_len(word_len: usize) -> u64 {
    const WORD_BYTES: usize = std::mem::size_of::<u32>();

    let Some(total_bytes) = word_len.checked_mul(WORD_BYTES) else {
        panic!("u32 buffer byte length must not overflow usize");
    };
    u64::try_from(total_bytes).expect("u32 buffer byte length must fit u64")
}

/// Packs `input` into little-endian `u32` words, four bytes per word.
///
/// A trailing group shorter than four bytes is zero-padded in its high
/// bytes. Empty input yields an empty vector.
fn pack_input_words(input: &[u8]) -> Vec<u32> {
    let groups = input.chunks(4);
    let mut words = Vec::with_capacity(groups.len());
    for group in groups {
        // Byte `i` of the group lands in bits `8*i .. 8*i+8` (little-endian);
        // positions missing from a short final group stay zero.
        let word = group
            .iter()
            .enumerate()
            .fold(0u32, |acc, (position, &byte)| {
                acc | (u32::from(byte) << (8 * position))
            });
        words.push(word);
    }
    words
}