vyre-conform 0.1.0

Conformance suite for vyre backends — proves byte-identical output to CPU reference
Documentation
//! Backend abstraction and WGSL wrapping for the conformance suite.
//!
//! This module defines the [`WgslBackend`] trait — the single interface every
//! vyre implementation must satisfy — and provides the test shader wrapper that
//! turns an operation's WGSL fragment into a complete compute shader.

use super::wgpu;
use super::{
    unsupported_execution_model_error, ConformDispatchConfig, ExecutionModel, ExecutionModelKind,
    OneShotDispatch,
};
use crate::spec::types::Convention;

/// Runtime adapter for the conformance suite.
///
/// Any vyre implementation — wgpu, CUDA, Metal, TPU, FPGA — implements
/// this single trait. The universal entry point is `execute()`, whose
/// `ExecutionModel` argument is non-exhaustive so new execution models can be
/// added without breaking existing backends. Adding a new `ExecutionModel`
/// variant is a minor version bump, not a major one.
///
/// `dispatch()` remains the stable primitive for today's one-shot WGSL batch
/// dispatch model. Backends that only support one-shot dispatch can keep
/// implementing `dispatch()` and inherit the default `execute()` plumbing.
///
/// # Contract
///
/// A `WgslBackend` that has been constructed is **ready to dispatch**.
/// GPU availability is enforced at construction time, not per-call.
/// There is no `available()` check — if you have a `&dyn WgslBackend`,
/// it works. Period.
///
/// Every fallible method returns `Err(String)` carrying an actionable
/// message that includes a `Fix: ...` hint.
pub(crate) trait WgslBackend: Send + Sync {
    /// Human-readable name for this backend (used in reports).
    fn name(&self) -> &str;

    /// Human-readable backend version for archival certificates.
    ///
    /// Defaults to `"unspecified"`.
    fn version(&self) -> &str {
        "unspecified"
    }

    /// Cross-check the declared `version()` string against whatever
    /// runtime the backend actually holds, returning an actionable
    /// `Fix:`-prefixed message on mismatch.
    ///
    /// # Audit L.1.34
    /// The conform-gate audit flagged that the backend version string
    /// was recorded in the certificate but never verified against the
    /// runtime adapter (e.g. `wgpu::Adapter::get_info().driver_info`).
    /// A backend that lied about its version could pass conformance
    /// and ship a misleading certificate; downstream consumers had no
    /// way to detect the drift.
    ///
    /// Every backend that talks to a real runtime (wgpu, Vulkan, CUDA)
    /// should override this method to compare `self.version()` against
    /// the adapter / driver info it actually holds, and return an
    /// `Err("Fix: ...")` describing the mismatch.
    ///
    /// The default implementation is a no-op to preserve
    /// source-compatibility for mock and reference backends that have
    /// no external runtime to verify against; the certify gate calls
    /// this before emitting a certificate, so a default no-op means
    /// "no runtime-level check performed" — the certificate cannot be
    /// any *less* trustworthy than the backend's own `version()` claim.
    fn verify_runtime_info(&self) -> Result<(), String> {
        Ok(())
    }

    /// Execute a WGSL compute shader against input bytes.
    ///
    /// The conformance suite wraps the op's WGSL in a test shader using
    /// `wrap_shader()` before calling this. The backend receives complete,
    /// ready-to-compile WGSL.
    ///
    /// Returns exactly `output_size` bytes, or an actionable error with "Fix: ...".
    fn dispatch(
        &self,
        wgsl: &str,
        input: &[u8],
        output_size: usize,
        config: ConformDispatchConfig,
    ) -> Result<Vec<u8>, String>;

    /// Execution models this backend can run.
    ///
    /// Backends that only implement the current WGSL one-shot dispatch model
    /// can rely on the default, which advertises `ExecutionModelKind::OneShot`
    /// only. Future backends may advertise additional `ExecutionModelKind`
    /// values as the runtime grows persistent, streaming, event-driven, or
    /// multi-device models.
    fn supported_models(&self) -> &[ExecutionModelKind] {
        &[ExecutionModelKind::OneShot]
    }

    /// Execute a vyre workload through the universal execution-model API.
    ///
    /// The default implementation first rejects any model whose kind is not
    /// listed in `supported_models()`, returning the shared
    /// unsupported-execution-model error. It then routes
    /// `ExecutionModel::OneShot` to `dispatch()`, which keeps existing
    /// one-shot backends compatible.
    ///
    /// Future execution models should add a new `ExecutionModel` variant and a
    /// default match arm that returns an actionable unsupported-model error
    /// unless the backend overrides this method.
    fn execute(&self, model: &ExecutionModel) -> Result<Vec<u8>, String> {
        if !self.supported_models().contains(&model.kind()) {
            return Err(unsupported_execution_model_error(
                self.name(),
                model.kind_name(),
            ));
        }

        match model {
            // `dispatch` takes the config by value while `model` is only
            // borrowed, so the config has to be cloned here.
            ExecutionModel::OneShot(dispatch) => self.dispatch(
                &dispatch.wgsl,
                &dispatch.input,
                dispatch.output_size,
                dispatch.config.clone(),
            ),
        }
    }

    /// Execute one wrapped WGSL shader against a batch of independent inputs.
    ///
    /// Implementors with native batching should override this method and
    /// dispatch the whole batch in one backend submission. The default keeps
    /// existing backends source-compatible by issuing one `dispatch()` per
    /// input, in order, and stopping at the first error.
    ///
    /// Returns one output per input in the same order, each exactly matching
    /// the requested output size, or an actionable error with "Fix: ...".
    /// Mismatched `inputs` / `output_sizes` lengths are rejected before any
    /// dispatch happens.
    fn dispatch_batch(
        &self,
        wgsl: &str,
        inputs: &[Vec<u8>],
        output_sizes: &[usize],
        config: ConformDispatchConfig,
    ) -> Result<Vec<Vec<u8>>, String> {
        if inputs.len() != output_sizes.len() {
            return Err(format!(
                "batch input/output length mismatch: {} inputs, {} output sizes. Fix: pass one output size for every batch input.",
                inputs.len(),
                output_sizes.len()
            ));
        }
        // Collecting into `Result` short-circuits on the first failed dispatch.
        inputs
            .iter()
            .zip(output_sizes)
            .map(|(input, &output_size)| self.dispatch(wgsl, input, output_size, config.clone()))
            .collect()
    }

    /// Execute a serialized vyre IR program against input bytes.
    ///
    /// `program` must be a `Program::to_wire()` blob. Backends with native IR,
    /// CUDA, PTX, SPIR-V, MSL, or other target lowerings should override this
    /// method and lower directly. WGSL-only backends may use the default path,
    /// which validates the wire blob by deserializing it, lowers through the
    /// reference WGSL lowering, and calls `dispatch()`.
    ///
    /// Returns exactly `output_size` bytes, or an actionable error with "Fix: ...".
    /// Decode failures and lowering failures both carry a `Fix:` hint.
    fn dispatch_program(
        &self,
        program: &[u8],
        input: &[u8],
        output_size: usize,
        config: ConformDispatchConfig,
    ) -> Result<Vec<u8>, String> {
        let program = vyre::ir::Program::from_wire(program).map_err(|err| {
            format!(
                "failed to decode serialized IR program: {err}. Fix: pass a blob produced by `vyre::ir::Program::to_wire()`."
            )
        })?;
        let wgsl = vyre::lower::wgsl::lower(&program).map_err(|err| {
            format!(
                "failed to lower serialized IR program to WGSL: {err}. Fix: provide valid vyre IR or override dispatch_program with a native backend lowering."
            )
        })?;
        self.dispatch(&wgsl, input, output_size, config)
    }

    /// Maximum supported convention version. Defaults to `Convention::V1`.
    fn max_convention(&self) -> Convention {
        Convention::V1
    }

    /// Maximum compute invocations per workgroup supported by this backend.
    ///
    /// Returns `None` if the limit is not known (e.g., CPU mocks).
    fn max_workgroup_invocations(&self) -> Option<u32> {
        None
    }
}

/// Build the complete conformance compute shader around an op's WGSL fragment.
///
/// The buffer layout is selected by `config.convention`:
/// - V1: `input` (binding 0, read), `output` (binding 1, `read_write`),
///   `params` (binding 2, uniform)
/// - V2: everything in V1 plus `lookup` (binding 3, read)
///
/// The generated entry point is always `vyre_conform_main`, annotated with
/// `@workgroup_size(config.workgroup_size)`. The fragment must define
/// `fn vyre_op(index: u32, input_len: u32) -> u32`. Each invocation whose
/// `global_invocation_id.x` is within `arrayLength(&output.data)` stores
/// `vyre_op(index, params.input_len)` into that output word.
///
/// If `op_wgsl` already contains `@compute`, it is treated as a complete
/// shader (such as the output of `vyre::lower::wgsl::lower`) and returned
/// unchanged. Wrapping it would emit a second entry point and reference an
/// undefined `vyre_op`. Callers dispatching such a shader must use its own
/// entry point (`main`) instead of `vyre_conform_main`.
#[inline]
pub fn wrap_shader(op_wgsl: &str, config: &ConformDispatchConfig) -> String {
    use std::fmt::Write as _;

    // Already a full compute shader: hand it back untouched.
    if op_wgsl.contains("@compute") {
        return op_wgsl.to_string();
    }

    // Only the V2 convention adds the lookup table binding.
    let lookup_decl = match config.convention {
        Convention::V1 => "",
        Convention::V2 { .. } => "\n@group(0) @binding(3)\nvar<storage, read> lookup: Bytes;\n",
    };
    let workgroup_size = config.workgroup_size;

    let mut wrapped = String::with_capacity(op_wgsl.len() + 1024);
    wrapped.push_str(op_wgsl);

    // `fmt::Write` for `String` is infallible, so the `Result` is discarded.
    let _ = write!(
        wrapped,
        r"

struct Bytes {{
    data: array<u32>,
}};

struct Params {{
    // Original byte length of the input buffer before word-padding.
    // Shaders must use this to ignore zero-padded trailing bytes.
    input_len: u32,
    output_len: u32,
    _pad0: u32,
    _pad1: u32,
}};

@group(0) @binding(0)
var<storage, read> input: Bytes;

@group(0) @binding(1)
var<storage, read_write> output: Bytes;

@group(0) @binding(2)
var<uniform> params: Params;
{lookup_decl}
@compute @workgroup_size({workgroup_size})
fn vyre_conform_main(@builtin(global_invocation_id) gid: vec3<u32>) {{
    let index = gid.x;
    if (index >= arrayLength(&output.data)) {{
        return;
    }}
    output.data[index] = vyre_op(index, params.input_len);
}}
"
    );

    wrapped
}

#[cfg(test)]
mod tests {
    //! Unit tests for `wrap_shader` and the `ConformDispatchConfig` defaults.

    use super::{wrap_shader, ConformDispatchConfig};
    use crate::spec::types::Convention;

    /// Minimal op fragment defining the `vyre_op` hook the wrapper calls.
    fn test_op_wgsl() -> &'static str {
        "fn vyre_op(index: u32, input_len: u32) -> u32 { return input.data[index]; }"
    }

    #[test]
    fn wrap_shader_includes_op_fragment() {
        let config = ConformDispatchConfig::default();
        let shader = wrap_shader(test_op_wgsl(), &config);
        assert!(shader.contains("fn vyre_op"), "op fragment missing");
    }

    #[test]
    fn wrap_shader_v1_has_input_output_params() {
        let config = ConformDispatchConfig::default();
        let shader = wrap_shader(test_op_wgsl(), &config);
        assert!(
            shader.contains("@group(0) @binding(0)"),
            "missing input binding"
        );
        assert!(
            shader.contains("@group(0) @binding(1)"),
            "missing output binding"
        );
        assert!(
            shader.contains("@group(0) @binding(2)"),
            "missing params binding"
        );
    }

    #[test]
    fn wrap_shader_v1_no_lookup() {
        let config = ConformDispatchConfig::default();
        let shader = wrap_shader(test_op_wgsl(), &config);
        assert!(
            !shader.contains("@binding(3)"),
            "V1 should not have lookup binding"
        );
    }

    #[test]
    fn wrap_shader_v2_has_lookup() {
        let config = ConformDispatchConfig {
            convention: Convention::V2 { lookup_binding: 3 },
            ..Default::default()
        };
        let shader = wrap_shader(test_op_wgsl(), &config);
        assert!(
            shader.contains("@binding(3)"),
            "V2 should have lookup binding"
        );
        assert!(
            shader.contains("lookup: Bytes"),
            "V2 should have lookup buffer"
        );
    }

    #[test]
    fn wrap_shader_embeds_workgroup_size() {
        let config = ConformDispatchConfig {
            workgroup_size: 128,
            ..Default::default()
        };
        let shader = wrap_shader(test_op_wgsl(), &config);
        assert!(
            shader.contains("@workgroup_size(128)"),
            "workgroup size not embedded"
        );
    }

    #[test]
    fn wrap_shader_entry_point_name() {
        let config = ConformDispatchConfig::default();
        let shader = wrap_shader(test_op_wgsl(), &config);
        assert!(
            shader.contains("fn vyre_conform_main"),
            "wrong entry point name"
        );
    }

    #[test]
    fn wrap_shader_has_bounds_check() {
        let config = ConformDispatchConfig::default();
        let shader = wrap_shader(test_op_wgsl(), &config);
        assert!(shader.contains("arrayLength"), "missing bounds check");
    }

    #[test]
    fn wrap_shader_forwards_input_len_to_op() {
        let config = ConformDispatchConfig::default();
        let shader = wrap_shader(test_op_wgsl(), &config);
        assert!(
            shader.contains("vyre_op(index, params.input_len)"),
            "entry point must pass the unpadded input length to vyre_op"
        );
    }

    #[test]
    fn wrap_shader_emits_exactly_one_entry_point() {
        let config = ConformDispatchConfig::default();
        let shader = wrap_shader(test_op_wgsl(), &config);
        assert_eq!(
            shader.matches("@compute").count(),
            1,
            "wrapped fragment must contain a single compute entry point"
        );
    }

    #[test]
    fn wrap_shader_passes_through_complete_shader() {
        // A fragment that already declares `@compute` (e.g. lowered IR) must
        // be returned verbatim rather than wrapped a second time.
        let lowered =
            "@compute @workgroup_size(64)\nfn main(@builtin(global_invocation_id) gid: vec3<u32>) {}\n";
        let config = ConformDispatchConfig::default();
        let shader = wrap_shader(lowered, &config);
        assert_eq!(shader, lowered, "complete shader must not be re-wrapped");
    }

    #[test]
    fn dispatch_config_default_values() {
        let config = ConformDispatchConfig::default();
        assert_eq!(config.workgroup_size, 1);
        assert_eq!(config.workgroup_count, 1);
        assert_eq!(config.convention, Convention::V1);
        assert!(config.lookup_data.is_none());
    }
}

/// Construct the wgpu-backed conformance backend, or explain how to get one.
///
/// This is the single entry point for callers that need real hardware. It
/// never panics when a GPU is missing; it reports the problem and leaves the
/// caller to decide whether to abort, retry, or skip.
///
/// # Errors
///
/// Returns `Err` with a `Fix:`-prefixed setup message when
/// `wgpu::WgpuBackend::new()` yields no backend (no usable GPU adapter).
#[inline]
pub fn require_gpu() -> Result<wgpu::WgpuBackend, String> {
    match wgpu::WgpuBackend::new() {
        Some(backend) => Ok(backend),
        None => Err(String::from(
            "Fix: no discrete or integrated GPU adapter is available for vyre-conform. \
             Install a supported GPU/driver stack and run with the `gpu` feature; set \
             VYRE_CONFORM_GPU_REQUIRED=1 in CI to make adapter discovery fail before tests run.",
        )),
    }
}

/// GPU parity tests.
///
/// Compiled only when building tests with the `gpu` feature enabled
/// (`cfg(all(test, feature = "gpu"))`). Otherwise this module is skipped.
#[cfg(all(test, feature = "gpu"))]
mod gpu_parity;