vyre-conform 0.1.0

Conformance suite for vyre backends — proves byte-identical output to CPU reference
Documentation
//! Trait implementations that route [`WgslBackend`] and `vyre::VyreBackend`
//! calls through wgpu.
//!
//! This file contains the `impl WgslBackend for WgpuBackend` and
//! `impl vyre::VyreBackend for WgpuBackend` blocks. Both funnel into a shared
//! dispatch helper that handles shader validation, buffer creation,
//! bind-group assembly, command encoding, and asynchronous readback.

use crate::pipeline::backend::{ConformDispatchConfig, WgslBackend};
use bytemuck::cast_slice;
use std::collections::{hash_map::DefaultHasher, HashSet};
use std::hash::{Hash, Hasher};
use std::sync::{LazyLock, RwLock};
use wgpu::util::DeviceExt;

use super::backend::WgpuBackend;
use super::byte_words::pad_to_words;
use super::capabilities::adapter_naga_capabilities;
use super::context::get_gpu;
use super::readback::wait_for_readback;

const STACK_INIT_BYTES: usize = 4096;

/// Naga parse + validate cache: each distinct WGSL source is validated once
/// per process, keyed by a 64-bit hash. Subsequent dispatches of the same
/// shader skip the parse / validator walk. The cache is bounded only by the
/// number of distinct shaders the conform suite emits (hundreds), so a plain
/// HashSet is adequate.
fn wgsl_already_validated(wgsl: &str) -> bool {
    static SEEN: LazyLock<RwLock<HashSet<u64>>> = LazyLock::new(|| RwLock::new(HashSet::new()));
    let mut hasher = DefaultHasher::new();
    wgsl.hash(&mut hasher);
    let key = hasher.finish();
    {
        let Ok(guard) = SEEN.read() else {
            return false;
        };
        if guard.contains(&key) {
            return true;
        }
    }
    let Ok(mut guard) = SEEN.write() else {
        return false;
    };
    !guard.insert(key)
}

impl WgslBackend for WgpuBackend {
    fn name(&self) -> &str {
        "wgpu"
    }

    fn version(&self) -> &str {
        "24.0"
    }

    fn max_workgroup_invocations(&self) -> Option<u32> {
        let ctx = get_gpu()?;
        Some(ctx.device.limits().max_compute_invocations_per_workgroup)
    }

    fn dispatch(
        &self,
        wgsl: &str,
        input: &[u8],
        output_size: usize,
        config: ConformDispatchConfig,
    ) -> Result<Vec<u8>, String> {
        dispatch_wgsl(wgsl, "vyre_conform_main", input, output_size, config)
    }
}

impl vyre::VyreBackend for WgpuBackend {
    fn id(&self) -> &'static str {
        "wgpu"
    }

    fn dispatch(
        &self,
        program: &vyre::Program,
        inputs: &[Vec<u8>],
        _config: &vyre::DispatchConfig,
    ) -> Result<Vec<Vec<u8>>, vyre::BackendError> {
        let wgsl = vyre::lower::wgsl::lower(program).map_err(|error| {
            vyre::BackendError::new(format!(
                "failed to lower vyre IR to WGSL: {error}. Fix: provide a valid Program accepted by the WGSL lowering pipeline."
            ))
        })?;
        let input = inputs.first().map(Vec::as_slice).unwrap_or(&[]);
        let output_size = output_size_from_program(program)?;
        let config = ConformDispatchConfig {
            workgroup_size: program.workgroup_size[0].max(1),
            workgroup_count: output_size
                .div_ceil(4)
                .max(1)
                .try_into()
                .unwrap_or(u32::MAX),
            convention: crate::spec::types::Convention::V1,
            lookup_data: None,
            buffer_init: crate::spec::types::BufferInitPolicy::default(),
        };
        dispatch_wgsl(&wgsl, "main", input, output_size, config)
            .map(|output| vec![output])
            .map_err(vyre::BackendError::new)
    }
}

fn output_size_from_program(program: &vyre::Program) -> Result<usize, vyre::BackendError> {
    let output = program
        .buffers
        .iter()
        .find(|buffer| buffer.is_output())
        .ok_or_else(|| {
            vyre::BackendError::new(
                "program has no output buffer. Fix: declare exactly one output buffer in the vyre Program.",
            )
        })?;
    let count = usize::try_from(output.count()).map_err(|_| {
        vyre::BackendError::new(
            "program output element count exceeds usize. Fix: split the dispatch into smaller output buffers.",
        )
    })?;
    Ok(count.saturating_mul(element_size_bytes(output.element())))
}

fn element_size_bytes(data_type: vyre::ir::DataType) -> usize {
    match data_type {
        vyre::ir::DataType::Bool
        | vyre::ir::DataType::U32
        | vyre::ir::DataType::I32
        | vyre::ir::DataType::F32 => 4,
        vyre::ir::DataType::U64 | vyre::ir::DataType::Vec2U32 => 8,
        vyre::ir::DataType::Vec4U32 => 16,
        vyre::ir::DataType::Bytes => 1,
        _ => 4,
    }
}

fn dispatch_wgsl(
    wgsl: &str,
    entry_point: &str,
    input: &[u8],
    output_size: usize,
    config: ConformDispatchConfig,
) -> Result<Vec<u8>, String> {
    let ctx = get_gpu().ok_or_else(|| {
        "Fix: no GPU adapter available. vyre-conform GPU parity tests require a GPU.".to_string()
    })?;

    // Validate shader through naga for actionable error messages only on
    // the first time this exact WGSL is seen. Subsequent dispatches skip
    // the parse + validator walk, which dominated the per-call cost.
    if !wgsl_already_validated(wgsl) {
        let naga_caps =
            adapter_naga_capabilities(ctx.adapter_features, ctx.adapter_downlevel.clone());
        match naga::front::wgsl::parse_str(wgsl) {
            Ok(module) => {
                if let Err(e) =
                    naga::valid::Validator::new(naga::valid::ValidationFlags::all(), naga_caps)
                        .validate(&module)
                {
                    return Err(format!(
                            "Fix: WGSL shader fails naga validation: {e}. The shader parses but has semantic errors."
                        ));
                }
            }
            Err(e) => {
                return Err(format!(
                        "Fix: WGSL shader fails naga parsing: {e}. The shader source is syntactically invalid."
                    ));
            }
        }
    }

    // The compute pipeline is keyed by (WGSL source, entry point) in the
    // core runtime's compile_compute_pipeline cache, so repeated dispatch
    // of the same shader reuses the compiled module + pipeline instead of
    // paying shader creation + naga-in-wgpu validation on every call.
    let pipeline = vyre_wgpu::runtime::compile_compute_pipeline(
        &ctx.device,
        "vyre-conform pipeline",
        wgsl,
        entry_point,
    )
    .map_err(|e| format!("Fix: compute pipeline compilation failed: {e}"))?;

    // Prepare input buffer
    let input_words = pad_to_words(input);
    let input_buffer = ctx
        .device
        .create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("input"),
            contents: cast_slice(&input_words),
            usage: wgpu::BufferUsages::STORAGE,
        });

    // Prepare output buffer (zero-initialized or poison-filled)
    let output_word_count = output_size.div_ceil(4).max(1);
    let output_bytes = output_word_count * 4;
    let init_byte = match config.buffer_init {
        crate::spec::types::BufferInitPolicy::Poison => 0xCD,
        _ => 0x00,
    };
    let stack_init = [init_byte; STACK_INIT_BYTES];
    let heap_init;
    let output_init = if output_bytes <= stack_init.len() {
        &stack_init[..output_bytes]
    } else {
        heap_init = vec![init_byte; output_bytes];
        heap_init.as_slice()
    };
    let output_buffer = ctx
        .device
        .create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("output"),
            contents: output_init,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
        });

    // Prepare params uniform
    // input_len is the original byte length BEFORE padding so shaders can
    // distinguish true input bytes from zero-pad fill.
    let input_len_u32 = u32::try_from(input.len()).map_err(|_| {
        format!(
            "Fix: input length {} exceeds u32 capacity; split the dispatch into u32-sized chunks",
            input.len()
        )
    })?;
    let output_len_u32 = u32::try_from(output_word_count).map_err(|_| {
        format!(
            "Fix: output_word_count {output_word_count} exceeds u32 capacity; reduce output_size"
        )
    })?;
    let params = [
        input_len_u32,  // input_len: bytes_before_padding
        output_len_u32, // output_len (in u32 words)
        0u32,           // _pad0
        0u32,           // _pad1
    ];
    let params_buffer = ctx
        .device
        .create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("params"),
            contents: cast_slice(&params),
            usage: wgpu::BufferUsages::UNIFORM,
        });

    // Readback buffer
    let readback_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("readback"),
        size: output_bytes as u64,
        usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
        mapped_at_creation: false,
    });

    // Bind group
    let bind_group_layout = pipeline.get_bind_group_layout(0);
    let mut entries = Vec::with_capacity(4);
    entries.push(wgpu::BindGroupEntry {
        binding: 0,
        resource: input_buffer.as_entire_binding(),
    });
    entries.push(wgpu::BindGroupEntry {
        binding: 1,
        resource: output_buffer.as_entire_binding(),
    });
    entries.push(wgpu::BindGroupEntry {
        binding: 2,
        resource: params_buffer.as_entire_binding(),
    });

    // Optional lookup buffer for V2
    let lookup_buffer;
    if let Some(ref lookup_data) = config.lookup_data {
        let lookup_words = pad_to_words(lookup_data);
        lookup_buffer = ctx
            .device
            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some("lookup"),
                contents: cast_slice(&lookup_words),
                usage: wgpu::BufferUsages::STORAGE,
            });
        entries.push(wgpu::BindGroupEntry {
            binding: 3,
            resource: lookup_buffer.as_entire_binding(),
        });
    }

    let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
        label: Some("vyre-conform bind group"),
        layout: &bind_group_layout,
        entries: &entries,
    });

    // Dispatch
    let mut encoder = ctx
        .device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("vyre-conform dispatch"),
        });
    {
        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
            label: Some("vyre-conform compute"),
            timestamp_writes: None,
        });
        pass.set_pipeline(&pipeline);
        pass.set_bind_group(0, &bind_group, &[]);
        pass.dispatch_workgroups(config.workgroup_count, 1, 1);
    }
    encoder.copy_buffer_to_buffer(&output_buffer, 0, &readback_buffer, 0, output_bytes as u64);
    ctx.queue.submit(std::iter::once(encoder.finish()));

    // Readback
    let slice = readback_buffer.slice(..);
    let (sender, receiver) = std::sync::mpsc::channel();
    slice.map_async(wgpu::MapMode::Read, move |result| {
        let _ = sender.send(result);
    });
    wait_for_readback(&ctx.device, receiver)?
        .map_err(|e| format!("Fix: GPU readback mapping failed: {e:?}"))?;

    let mapped = slice.get_mapped_range();
    let result = mapped[..output_size].to_vec();
    drop(mapped);
    readback_buffer.unmap();

    Ok(result)
}