vyre-conform 0.1.0

Conformance suite for vyre backends — proves byte-identical output to CPU reference
Documentation
#![cfg(feature = "gpu")]

//! GPU algebra L2 conformance tests.

use std::sync::mpsc;

use vyre_conform::algebra::verify_gpu_laws_witnessed;
use vyre_conform::backend::{DispatchConfig, VyreBackend};
use vyre_conform::specs::primitive;
use vyre_conform::types::Convention;
use wgpu::util::DeviceExt;

/// Witness count passed as the third argument of `verify_gpu_laws_witnessed` for every
/// primitive spec checked by the test below.
// NOTE(review): presumably the number of generated (a, b, c) witnesses per law — confirm in
// `vyre_conform::algebra`.
const GPU_LAW_WITNESSES: u64 = 128;

/// Test-local `wgpu` backend: the device/queue pair every conformance dispatch runs on.
struct WgpuBackend {
    device: wgpu::Device,
    queue: wgpu::Queue,
}

impl WgpuBackend {
    /// Acquires the default adapter and a device created from a default descriptor.
    ///
    /// Returns `None` when no adapter is found or device creation fails, so the caller can
    /// skip GPU-only checks instead of failing.
    fn new_if_available() -> Option<Self> {
        let instance = wgpu::Instance::default();
        let adapter_options = wgpu::RequestAdapterOptions::default();
        let adapter = pollster::block_on(instance.request_adapter(&adapter_options))?;

        let device_descriptor = wgpu::DeviceDescriptor::default();
        match pollster::block_on(adapter.request_device(&device_descriptor, None)) {
            Ok((device, queue)) => Some(Self { device, queue }),
            Err(_) => None,
        }
    }
}

impl VyreBackend for WgpuBackend {
    fn name(&self) -> &str {
        "wgpu"
    }

    fn max_convention(&self) -> Convention {
        Convention::V1
    }

    fn dispatch(
        &self,
        wgsl: &str,
        input: &[u8],
        output_size: usize,
        config: DispatchConfig,
    ) -> Result<Vec<u8>, String> {
        let input_padded = padded(input);
        let output_padded_size = padded_size(output_size);
        let params = params_bytes(
            (input_padded.len() / 4) as u32,
            (output_padded_size / 4) as u32,
        );

        let input_buffer = self
            .device
            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some("gpu algebra input"),
                contents: &input_padded,
                usage: wgpu::BufferUsages::STORAGE,
            });
        let output_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("gpu algebra output"),
            size: output_padded_size as u64,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });
        let readback_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("gpu algebra readback"),
            size: output_padded_size as u64,
            usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
            mapped_at_creation: false,
        });
        let params_buffer = self
            .device
            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some("gpu algebra params"),
                contents: &params,
                usage: wgpu::BufferUsages::UNIFORM,
            });

        let module = self
            .device
            .create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some("gpu algebra shader"),
                source: wgpu::ShaderSource::Wgsl(wgsl.into()),
            });
        let pipeline = self
            .device
            .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("gpu algebra pipeline"),
                layout: None,
                module: &module,
                entry_point: Some("vyre_conform_main"),
                compilation_options: wgpu::PipelineCompilationOptions::default(),
                cache: None,
            });
        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("gpu algebra bind group"),
            layout: &pipeline.get_bind_group_layout(0),
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: input_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: output_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: params_buffer.as_entire_binding(),
                },
            ],
        });

        let mut encoder = self
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("gpu algebra encoder"),
            });
        {
            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("gpu algebra pass"),
                timestamp_writes: None,
            });
            pass.set_pipeline(&pipeline);
            pass.set_bind_group(0, &bind_group, &[]);
            pass.dispatch_workgroups(config.workgroup_count, 1, 1);
        }
        encoder.copy_buffer_to_buffer(
            &output_buffer,
            0,
            &readback_buffer,
            0,
            output_padded_size as u64,
        );
        self.queue.submit(std::iter::once(encoder.finish()));

        readback(
            &self.device,
            &readback_buffer,
            output_padded_size,
            output_size,
        )
    }
}

/// Checks every algebraic law of every primitive spec on a real `wgpu` device.
///
/// The test skips with a message when no adapter/device can be created. It panics with the
/// witness values on the first law violation. It also requires that at least one law and more
/// than 1 000 cases actually ran.
#[test]
fn all_primitive_laws_hold_on_wgpu_when_available() {
    let backend = match WgpuBackend::new_if_available() {
        Some(backend) => backend,
        None => {
            eprintln!("skipping GPU algebra L2 conformance: no wgpu adapter/device available");
            return;
        }
    };

    let specs = primitive::specs();
    assert_eq!(
        specs.len(),
        32,
        "primitive registry must contain all 32 ops"
    );

    // Walk each spec's law results in registry order. Laws that produced no cases are
    // ignored, and the first violation aborts the test.
    let (total_laws, total_cases) = specs
        .iter()
        .flat_map(|spec| verify_gpu_laws_witnessed(&backend, spec, GPU_LAW_WITNESSES))
        .filter(|result| result.cases_tested != 0)
        .fold((0u64, 0u64), |(laws, cases), result| {
            let tested = result.cases_tested;
            if let Some(violation) = result.violation {
                panic!(
                    "GPU LAW FAILED: {}\nOp: {}\nLaw: {}\na={}, b={}, c={}\ngpu/lhs={}, expected/rhs={}",
                    violation.message,
                    violation.op_id,
                    violation.law,
                    violation.a,
                    violation.b,
                    violation.c,
                    violation.lhs,
                    violation.rhs,
                );
            }
            (laws + 1, cases + tested)
        });

    assert!(total_laws > 0, "no primitive laws were verified on GPU");
    assert!(total_cases > 1_000, "too few GPU law cases: {total_cases}");
}

/// Maps the first `padded_size` bytes of `readback_buffer` and copies out the first
/// `output_size` of them.
///
/// The map request is completed by polling `device` with `Maintain::Wait`. The buffer is
/// unmapped again after the copy.
///
/// # Errors
///
/// Returns `Err` when the map callback's result cannot be received, or when mapping fails.
///
/// # Panics
///
/// Panics if `output_size > padded_size` (the slice of the mapped view is out of range).
fn readback(
    device: &wgpu::Device,
    readback_buffer: &wgpu::Buffer,
    padded_size: usize,
    output_size: usize,
) -> Result<Vec<u8>, String> {
    let mapped = readback_buffer.slice(0..padded_size as u64);
    let (tx, rx) = mpsc::channel();
    mapped.map_async(wgpu::MapMode::Read, move |status| {
        // If the receiver is already gone there is nobody left to report to.
        let _ = tx.send(status);
    });
    // Drive the device so the map request above can complete.
    let _ = device.poll(wgpu::Maintain::Wait);

    let status = rx.recv().map_err(|error| {
        format!("failed to receive readback: {error}. Fix: check GPU device availability")
    })?;
    status.map_err(|error| {
        format!("failed to map readback buffer: {error:?}. Fix: check buffer usage flags")
    })?;

    // The mapped view must be released before `unmap`, so keep it inside this scope.
    let bytes = {
        let view = mapped.get_mapped_range();
        view[..output_size].to_vec()
    };
    readback_buffer.unmap();
    Ok(bytes)
}

/// Returns a copy of `input` zero-extended to `padded_size(input.len())` bytes.
fn padded(input: &[u8]) -> Vec<u8> {
    let mut bytes = input.to_vec();
    // Only ever grows: `padded_size(n) >= n` for every `n`.
    bytes.resize(padded_size(input.len()), 0);
    bytes
}

/// Rounds `size` up to the next multiple of 4 bytes, never returning less than 16.
fn padded_size(size: usize) -> usize {
    const MIN_PADDED_BYTES: usize = 16;
    size.next_multiple_of(4).max(MIN_PADDED_BYTES)
}

/// Encodes the 16-byte params uniform.
///
/// Bytes 0..4 hold `input_len` and bytes 4..8 hold `output_len`, both as little-endian `u32`.
/// Bytes 8..16 stay zero.
fn params_bytes(input_len: u32, output_len: u32) -> [u8; 16] {
    let words = [input_len, output_len, 0, 0];
    let mut bytes = [0u8; 16];
    for (chunk, word) in bytes.chunks_exact_mut(4).zip(words) {
        chunk.copy_from_slice(&word.to_le_bytes());
    }
    bytes
}