// vyre 0.4.0
//
// GPU compute intermediate representation with a standard operation library.
//! Example 02 — Dispatch the XOR program on a real GPU and verify the bit pattern.
//!
//! This example acquires a wgpu device, builds the same XOR program as Example 01
//! (this time via the standard op library), lowers it to WGSL, compiles a compute
//! pipeline, uploads input buffers, dispatches one workgroup, reads back the
//! result, and checks every element against the expected bit pattern.

// NOTE: vyre core does not yet ship a high-level per-op GPU dispatcher for
// primitive operations. This example therefore uses the low-level runtime
// helpers (`cached_device`, `compile_compute_pipeline`, `bg_entry`) and
// manual wgpu dispatch. This is the same path engines use internally.

#[cfg(not(feature = "gpu"))]
fn main() {
    // Built without the `gpu` feature there is no device path to run, so
    // explain why on stderr and fail with a non-zero status.
    let reason = "This example requires the `gpu` feature.";
    eprintln!("{reason}");
    std::process::exit(1);
}

#[cfg(feature = "gpu")]
fn main() {
    use std::iter;
    use vyre::ir::validate;
    use vyre::ops::primitive::xor::Xor;
    use vyre::runtime::{bg_entry, cached_device, compile_compute_pipeline};
    use wgpu::util::DeviceExt;

    // Number of `u32` lanes uploaded, dispatched over, and read back.
    //
    // NOTE(review): the single `dispatch_workgroups(1, 1, 1)` below only covers
    // every lane if the workgroup size of `Xor::program()` is at least this
    // large — confirm against the lowered WGSL before raising this value.
    const ELEMENT_COUNT: usize = 64;
    // Byte size shared by the output buffer, the readback buffer, and the copy.
    const BYTE_LEN: u64 = (ELEMENT_COUNT * std::mem::size_of::<u32>()) as u64;

    // 1. Acquire a device via `vyre::runtime::cached_device()`.
    //
    //    The cached device is a process-wide singleton. If the host has no
    //    compatible GPU (or the driver is missing), we print a clear error
    //    and exit with code 2 rather than panicking.
    let (device, queue) = match cached_device() {
        Ok(pair) => pair,
        Err(error) => {
            eprintln!("GPU unavailable: {error}");
            std::process::exit(2);
        }
    };

    // 2. Build the XOR IR program using the standard operation library.
    //
    //    `Xor::program()` produces the canonical 3-buffer U32 program that
    //    implements element-wise bitwise XOR. We validate it before sending it
    //    to the backend.
    let program = Xor::program();
    let errors = validate(&program);
    assert!(
        errors.is_empty(),
        "validation failed for Xor::program(): {errors:?}"
    );

    // 3. Lower to WGSL and compile a wgpu compute pipeline.
    let wgsl =
        vyre::lower::wgsl::lower(&program).expect("WGSL lowering must succeed for a valid program");
    let pipeline = compile_compute_pipeline(device, "xor_example", &wgsl, "main")
        .expect("pipeline compilation must succeed");

    // 4. Provide input data and compute the expected output on the CPU.
    //
    //    0xAAAAAAAA (1010…) and 0x55555555 (0101…) are bitwise complements, so
    //    every lane of `a ^ b` has all 32 bits set: 0xFFFFFFFF. The comparison
    //    below uses this CPU reference instead of a hard-coded constant.
    let a = vec![0xAAAA_AAAAu32; ELEMENT_COUNT];
    let b = vec![0x5555_5555u32; ELEMENT_COUNT];
    let expected: Vec<u32> = a.iter().zip(&b).map(|(x, y)| x ^ y).collect();

    let a_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
        label: Some("xor_a"),
        contents: bytemuck::cast_slice(&a),
        usage: wgpu::BufferUsages::STORAGE,
    });
    let b_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
        label: Some("xor_b"),
        contents: bytemuck::cast_slice(&b),
        usage: wgpu::BufferUsages::STORAGE,
    });
    let out_buffer = device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("xor_out"),
        size: BYTE_LEN,
        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
        mapped_at_creation: false,
    });

    // 5. Bind the buffers and encode the dispatch.
    //    Binding slots 0/1/2 are assumed to match the program's a/b/out
    //    buffer order — TODO confirm against `Xor::program()`.
    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("xor_bind_group"),
        layout: &pipeline.get_bind_group_layout(0),
        entries: &[
            bg_entry(0, &a_buffer),
            bg_entry(1, &b_buffer),
            bg_entry(2, &out_buffer),
        ],
    });

    // Storage buffers cannot be mapped directly, so results are copied into
    // a MAP_READ staging buffer.
    let readback = device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("xor_readback"),
        size: BYTE_LEN,
        usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
        mapped_at_creation: false,
    });

    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
        label: Some("xor_encoder"),
    });
    {
        // The pass mutably borrows `encoder`; end it before encoding the copy.
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("xor_pass"),
            timestamp_writes: None,
        });
        pass.set_pipeline(&pipeline);
        pass.set_bind_group(0, &bind_group, &[]);
        pass.dispatch_workgroups(1, 1, 1);
    }
    encoder.copy_buffer_to_buffer(&out_buffer, 0, &readback, 0, BYTE_LEN);
    queue.submit(iter::once(encoder.finish()));

    // 6. Read back the output buffer and compare every lane with the CPU
    //    reference, reporting the first mismatch.
    let output = read_u32_buffer(device, &readback, ELEMENT_COUNT);
    if let Some((index, (&got, &want))) = output
        .iter()
        .zip(&expected)
        .enumerate()
        .find(|(_, (got, want))| got != want)
    {
        eprintln!(
            "First divergence at index {}: expected 0x{:08X}, got 0x{:08X}",
            index, want, got
        );
        std::process::exit(1);
    }

    println!("GPU XOR dispatch successful — all {ELEMENT_COUNT} elements are 0xFFFFFFFF");
}

/// Maps `buffer` for reading and returns its first `word_len` `u32` words.
///
/// Requests the map, blocks on `device.poll(wgpu::Maintain::Wait)` until the
/// request has completed, copies the words out of the mapped view, and unmaps
/// `buffer` before returning. Words are decoded in native byte order, the same
/// order `bytemuck::cast_slice` used to upload them.
///
/// # Panics
///
/// Panics if the map callback channel disconnects without a result, or if the
/// buffer fails to map for reading.
/// NOTE(review): `buffer.slice` presumably also panics when `word_len` words
/// exceed the buffer size — confirm against the wgpu version in use.
#[cfg(feature = "gpu")]
fn read_u32_buffer(device: &wgpu::Device, buffer: &wgpu::Buffer, word_len: usize) -> Vec<u32> {
    const WORD_BYTES: usize = std::mem::size_of::<u32>();

    let byte_len = (word_len * WORD_BYTES) as u64;
    let slice = buffer.slice(0..byte_len);
    let (sender, receiver) = std::sync::mpsc::channel();
    slice.map_async(wgpu::MapMode::Read, move |result| {
        // If the receiver is already gone there is nobody left to notify.
        let _ = sender.send(result);
    });
    // `Maintain::Wait` blocks until submitted work (including the map request)
    // has finished, so the callback has run before `recv` below; the poll
    // status itself is not needed.
    let _ = device.poll(wgpu::Maintain::Wait);
    receiver
        .recv()
        .expect("readback map callback must report completion")
        .expect("readback buffer must map for reading");

    // Decode straight from the mapped view. Casting an intermediate `Vec<u8>`
    // to `&[u32]` with `bytemuck::cast_slice` panics whenever that allocation
    // is not 4-byte aligned, which `Vec<u8>` does not guarantee.
    let mapped = slice.get_mapped_range();
    let words: Vec<u32> = mapped
        .chunks_exact(WORD_BYTES)
        .map(|word| u32::from_ne_bytes([word[0], word[1], word[2], word[3]]))
        .collect();
    // The mapped view must be released before the buffer can be unmapped.
    drop(mapped);
    buffer.unmap();
    words
}