vyre-conform 0.1.0

Conformance suite for vyre backends — proves byte-identical output to CPU reference
Documentation
use super::StreamingRunner;
use crate::generate::generators::InputGenerator;
use crate::spec::builder::BuildError;
use crate::{DataType, OpSignature, OpSpec};
use std::fs;
use std::sync::{Arc, Mutex};
use vyre::{BackendError, DispatchConfig, Program, VyreBackend};

/// Deterministic test generator: produces `count` cases labelled
/// `case:<index>`, each carrying the index encoded as 8 little-endian bytes.
///
/// The seed and signature passed to `generate` are ignored, so every call
/// yields exactly the same cases.
struct FixedGenerator {
    /// Number of cases emitted per `generate` call.
    count: u64,
}

impl InputGenerator for FixedGenerator {
    fn name(&self) -> &str {
        "fixed"
    }

    /// Accepts only the unary `u64 -> u64` signature.
    fn handles(&self, signature: &OpSignature) -> bool {
        let single_u64_input = signature.inputs == [DataType::U64];
        single_u64_input && signature.output == DataType::U64
    }

    fn generate(&self, _signature: &OpSignature, _seed: u64) -> Vec<(String, Vec<u8>)> {
        let mut cases = Vec::new();
        for index in 0..self.count {
            let label = format!("case:{index}");
            cases.push((label, index.to_le_bytes().to_vec()));
        }
        cases
    }
}

/// Backend test double that echoes back at most the first 8 bytes of its
/// first input and records how many inputs arrived in each `dispatch` call.
///
/// With `corrupt` set, the first output byte is bit-flipped so the result
/// no longer matches its input.
#[derive(Default)]
struct RecordingBackend {
    /// `inputs.len()` observed by every `dispatch` call, in call order.
    batches: Mutex<Vec<usize>>,
    /// When true, XOR the first output byte with `0xFF`.
    corrupt: bool,
}

impl VyreBackend for RecordingBackend {
    fn id(&self) -> &'static str {
        "recording"
    }

    fn dispatch(
        &self,
        _program: &Program,
        inputs: &[Vec<u8>],
        _config: &DispatchConfig,
    ) -> Result<Vec<Vec<u8>>, BackendError> {
        // A dispatch with no inputs is treated as an empty buffer, not an error.
        let source: &[u8] = match inputs.first() {
            Some(bytes) => bytes.as_slice(),
            None => &[],
        };
        let mut echoed: Vec<u8> = source.iter().take(8).copied().collect();
        if self.corrupt {
            if let Some(head) = echoed.first_mut() {
                *head ^= 0xFF;
            }
        }
        // Recover from poisoning so a panic in another test thread cannot
        // hide the batch log.
        let mut log = self
            .batches
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        log.push(inputs.len());
        drop(log);
        Ok(vec![echoed])
    }
}

/// Builds the `u64 -> u64` fixture spec shared by the streaming tests.
///
/// The CPU reference returns at most the first 8 bytes of its input. This is
/// the same clamp `RecordingBackend` applies to its first input, so short
/// inputs do not panic the reference. The spec lists itself as its own
/// composition (`Category::A`).
///
/// # Errors
///
/// Propagates any [`BuildError`] reported by the spec builder.
fn fixed_spec(id: &'static str) -> Result<OpSpec, BuildError> {
    OpSpec::builder(id)
        .signature(OpSignature {
            inputs: vec![DataType::U64],
            output: DataType::U64,
        })
        // Clamp instead of slicing `[..8]` unconditionally: an input shorter
        // than 8 bytes would otherwise panic inside the reference function.
        .cpu_fn(|input| input[..input.len().min(8)].to_vec())
        .wgsl_fn(|| {
            "fn vyre_op(index: u32, input_len: u32) -> u32 { return input.data[index]; }"
                .to_string()
        })
        .ir_program(Some(vyre::Program::empty))
        .category(crate::Category::A {
            composition_of: vec![id],
        })
        .laws(vec![crate::spec::law::AlgebraicLaw::Bounded {
            lo: 0,
            hi: u32::MAX,
        }])
        .strictness(crate::spec::types::Strictness::Strict)
        .version(1)
        .workgroup_size(Some(1))
        .build()
}

/// With skip and sharding enabled, the runner must select inputs
/// deterministically and report no failures against an echoing backend.
#[test]
fn sharding_and_skip_use_deterministic_global_test_ids() -> Result<(), String> {
    let backend = RecordingBackend::default();
    let spec = match fixed_spec("streaming.test.shard") {
        Ok(spec) => spec,
        Err(err) => return Err(format!("Fix: streaming fixture spec must build: {err:?}")),
    };

    let failures = StreamingRunner::new()
        .batch_size(3)
        .skip(2)
        .shard(1, 2)
        .with_generator(Box::new(FixedGenerator { count: 10 }))
        .run(&backend, &[spec]);
    assert!(failures.is_empty());

    // IR backends dispatch one concrete Program per input, and each accepted
    // input is dispatched twice so the runner can observe determinism. Every
    // recorded dispatch therefore carries exactly one input.
    let expected_batches: Vec<usize> = vec![1; 18];
    let recorded = backend
        .batches
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    assert_eq!(*recorded, expected_batches);
    Ok(())
}

/// The `on_progress` callback must receive snapshots whose `tested` and
/// `passed` counters reach 4 and 8 when every case passes.
#[test]
fn progress_callback_reports_runner_counters() -> Result<(), String> {
    let backend = RecordingBackend::default();
    let spec = fixed_spec("streaming.test.progress")
        .map_err(|err| format!("Fix: streaming fixture spec must build: {err:?}"))?;

    // One handle is moved into the callback; the other is read after the run.
    let snapshot_log = Arc::new(Mutex::new(Vec::new()));
    let callback_handle = Arc::clone(&snapshot_log);

    let failures = StreamingRunner::new()
        .batch_size(4)
        .with_generator(Box::new(FixedGenerator { count: 4 }))
        .on_progress(move |snapshot| {
            let mut sink = callback_handle
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            sink.push(snapshot);
        })
        .run(&backend, &[spec]);
    assert!(failures.is_empty());

    let seen = snapshot_log
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    for milestone in [4, 8] {
        assert!(seen
            .iter()
            .any(|item| item.tested == milestone && item.passed == milestone));
    }
    Ok(())
}

/// A backend that corrupts its output must cause the failing input to be
/// persisted as a `.bin` file under `$CARGO_MANIFEST_DIR/regressions/<op_id>`.
#[test]
fn failed_inputs_are_persisted_as_binary_regressions() -> Result<(), String> {
    // Removes the fixture directory when dropped. A panicking assertion
    // unwinds past the explicit cleanup at the end of this test, and without
    // this guard it would leave regression files in the crate's source tree.
    struct RemoveDirOnDrop<'a>(&'a std::path::Path);

    impl Drop for RemoveDirOnDrop<'_> {
        fn drop(&mut self) {
            // Best-effort only. On success the directory has already been
            // removed (and any error reported), so `NotFound` is expected here.
            let _ = fs::remove_dir_all(self.0);
        }
    }

    let op_id = "streaming_test_failure";
    let dir = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
        .join("regressions")
        .join(op_id);
    // Start from a clean slate; a missing directory is not an error.
    match fs::remove_dir_all(&dir) {
        Ok(()) => {}
        Err(error) if error.kind() == std::io::ErrorKind::NotFound => {}
        Err(error) => {
            return Err(format!(
                "Fix: streaming regression fixture cleanup must remove stale dir: {error}"
            ));
        }
    }
    // Named binding (not `_`) so the guard lives until the end of the test.
    let _cleanup = RemoveDirOnDrop(dir.as_path());

    let backend = RecordingBackend {
        batches: Mutex::new(Vec::new()),
        corrupt: true,
    };
    let failures = StreamingRunner::new()
        .batch_size(2)
        .with_generator(Box::new(FixedGenerator { count: 1 }))
        .run(
            &backend,
            &[fixed_spec(op_id)
                .map_err(|err| format!("Fix: streaming fixture spec must build: {err:?}"))?],
        );

    // NOTE(review): a single generated input yields two reported failures but
    // only one `.bin` file. This is presumably one failure per dispatch versus
    // one file per distinct input; confirm against `StreamingRunner`.
    assert_eq!(failures.len(), 2);
    let bin_count = fs::read_dir(&dir)
        .map_err(|error| {
            format!("Fix: streaming regression fixture must create regression dir: {error}")
        })?
        .flatten()
        .filter(|entry| entry.path().extension().and_then(|ext| ext.to_str()) == Some("bin"))
        .count();
    assert_eq!(bin_count, 1);

    // Explicit cleanup so removal errors are reported on the success path.
    fs::remove_dir_all(&dir).map_err(|error| {
        format!("Fix: streaming regression fixture cleanup must remove regression dir: {error}")
    })?;
    Ok(())
}