// vyre-conform 0.1.0
//
// Conformance suite for vyre backends — proves byte-identical output to the CPU reference.
// Documentation
//! H3 — fault-at-dispatch harness.
//!
//! Simulates backend dispatch failure by wrapping a [`WgslBackend`] and
//! returning a structured error after N dispatches. This verifies that
//! the conformance suite (and future backends) handle device loss or
//! transient dispatch failures gracefully — without panicking.

#[cfg(loom)]
use loom::sync::atomic::{AtomicUsize, Ordering};
#[cfg(not(loom))]
use std::sync::atomic::{AtomicUsize, Ordering};

use crate::pipeline::backend::{ConformDispatchConfig, ExecutionModel, WgslBackend};

/// Selects what the wrapper does once the dispatch counter has reached the
/// configured threshold.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum FaultMode {
    /// Return a wrapper-generated error without calling the inner backend.
    Synthetic,
    /// Call the inner backend anyway and require it to fail with an error
    /// containing `Fix: `; a success or a non-actionable error is replaced by
    /// a wrapper-generated error.
    RequireInnerActionable,
}

/// A mock backend that fails after a configurable number of dispatches.
///
/// Before the failure threshold, it delegates to an inner backend. At the
/// threshold and beyond, it returns a structured error.
pub struct FaultInjectingBackend<'a> {
    /// Backend that receives every dispatch issued before the threshold (and,
    /// in delegated fault mode, the dispatches at or past it as well).
    inner: &'a dyn WgslBackend,
    /// 0-based dispatch index at which failures begin.
    fail_after: usize,
    /// Number of dispatch calls observed so far; incremented once per call,
    /// whether the call ends up delegated or failed.
    count: AtomicUsize,
    /// How a dispatch at or past the threshold is turned into an error.
    mode: FaultMode,
}

impl<'a> FaultInjectingBackend<'a> {
    /// Wrap `inner` so that the dispatch with 0-based index `fail_after`
    /// (and all subsequent dispatches) return an error.
    #[inline]
    pub(crate) fn new(inner: &'a dyn WgslBackend, fail_after: usize) -> Self {
        Self {
            inner,
            fail_after,
            count: AtomicUsize::new(0),
            mode: FaultMode::Synthetic,
        }
    }

    /// Wrap `inner` so that the failing dispatch is delegated to `inner`, and
    /// the resulting backend error is required to contain `Fix: `.
    #[inline]
    pub(crate) fn requiring_inner_actionable_error(
        inner: &'a dyn WgslBackend,
        fail_after: usize,
    ) -> Self {
        Self {
            inner,
            fail_after,
            count: AtomicUsize::new(0),
            mode: FaultMode::RequireInnerActionable,
        }
    }

    fn choose_fault_or_delegate<T>(
        &self,
        delegate: impl FnOnce() -> Result<T, String>,
    ) -> Result<T, String> {
        let idx = self.count.fetch_add(1, Ordering::SeqCst);
        if idx < self.fail_after {
            return delegate();
        }

        match self.mode {
            FaultMode::Synthetic => Err(format!(
                "Fault injection: dispatch {idx} failed (fail_after={}). Fix: retry or degrade gracefully.",
                self.fail_after
            )),
            FaultMode::RequireInnerActionable => match delegate() {
                Err(err) if err.contains("Fix: ") => Err(err),
                Err(err) => Err(format!(
                    "Fault injection delegated dispatch {idx} to backend `{}` but error was not actionable: {err}. Fix: backend errors must include `Fix: ...`.",
                    self.inner.name()
                )),
                Ok(_) => Err(format!(
                    "Fault injection expected backend `{}` to fail at dispatch {idx}, but it succeeded. Fix: configure the inner backend fault path before using delegated fault mode.",
                    self.inner.name()
                )),
            },
        }
    }
}

impl WgslBackend for FaultInjectingBackend<'_> {
    /// Reports the wrapped backend's name; the wrapper has no name of its own.
    fn name(&self) -> &str {
        self.inner.name()
    }

    /// Counts one dispatch, then forwards to the wrapped backend's `dispatch`
    /// or fails, as decided by `choose_fault_or_delegate`.
    fn dispatch(
        &self,
        wgsl: &str,
        input: &[u8],
        output_size: usize,
        config: ConformDispatchConfig,
    ) -> Result<Vec<u8>, String> {
        let inner = self.inner;
        self.choose_fault_or_delegate(move || inner.dispatch(wgsl, input, output_size, config))
    }

    /// Counts one dispatch, then forwards to the wrapped backend's
    /// `dispatch_program` or fails, as decided by `choose_fault_or_delegate`.
    fn dispatch_program(
        &self,
        program: &[u8],
        input: &[u8],
        output_size: usize,
        config: ConformDispatchConfig,
    ) -> Result<Vec<u8>, String> {
        let inner = self.inner;
        self.choose_fault_or_delegate(move || {
            inner.dispatch_program(program, input, output_size, config)
        })
    }

    /// Counts the whole batch as a single dispatch, then forwards to the
    /// wrapped backend's `dispatch_batch` or fails.
    fn dispatch_batch(
        &self,
        wgsl: &str,
        inputs: &[Vec<u8>],
        output_sizes: &[usize],
        config: ConformDispatchConfig,
    ) -> Result<Vec<Vec<u8>>, String> {
        let inner = self.inner;
        self.choose_fault_or_delegate(move || {
            inner.dispatch_batch(wgsl, inputs, output_sizes, config)
        })
    }

    /// Counts one dispatch, then forwards to the wrapped backend's `execute`
    /// or fails, as decided by `choose_fault_or_delegate`.
    fn execute(&self, model: &ExecutionModel) -> Result<Vec<u8>, String> {
        let inner = self.inner;
        self.choose_fault_or_delegate(move || inner.execute(model))
    }
}

/// Run `program` once through a synthetic fault-injecting wrapper around
/// `backend`.
///
/// The program is encoded with `to_wire` and handed to
/// [`WgslBackend::dispatch_program`] on a fresh [`FaultInjectingBackend`], so
/// this call is always dispatch index 0: it fails when `fail_after` is 0 and
/// is forwarded to `backend` otherwise.
///
/// # Errors
///
/// Returns a prefixed message when serialization fails or when the dispatch
/// (injected fault or inner backend failure) returns an error.
#[inline]
pub fn with_fault_at_dispatch(
    backend: &dyn WgslBackend,
    program: &vyre::ir::Program,
    input: &[u8],
    output_size: usize,
    config: ConformDispatchConfig,
    fail_after: usize,
) -> Result<Vec<u8>, String> {
    let bytes = match program.to_wire() {
        Ok(encoded) => encoded,
        Err(e) => {
            return Err(format!(
                "with_fault_at_dispatch failed to serialize program: {e}"
            ))
        }
    };

    let outcome = FaultInjectingBackend::new(backend, fail_after)
        .dispatch_program(&bytes, input, output_size, config);

    match outcome {
        Ok(output) => Ok(output),
        Err(e) => Err(format!(
            "with_fault_at_dispatch failed at fail_after={fail_after}: {e}"
        )),
    }
}

/// Run `program` once through a delegating fault wrapper around `backend`.
///
/// Unlike the synthetic variant, a dispatch at or past `fail_after` is still
/// forwarded to `backend`, and the error `backend` returns is required to
/// contain `Fix: `.
///
/// # Errors
///
/// Returns a prefixed message when serialization fails or when the dispatch
/// returns an error (including the wrapper's own complaint when the backend
/// succeeded or produced a non-actionable error at the fault point).
#[inline]
pub fn with_fault_at_dispatch_requiring_backend_error(
    backend: &dyn WgslBackend,
    program: &vyre::ir::Program,
    input: &[u8],
    output_size: usize,
    config: ConformDispatchConfig,
    fail_after: usize,
) -> Result<Vec<u8>, String> {
    let bytes = match program.to_wire() {
        Ok(encoded) => encoded,
        Err(e) => {
            return Err(format!(
                "with_fault_at_dispatch_requiring_backend_error failed to serialize program: {e}"
            ))
        }
    };

    let outcome = FaultInjectingBackend::requiring_inner_actionable_error(backend, fail_after)
        .dispatch_program(&bytes, input, output_size, config);

    match outcome {
        Ok(output) => Ok(output),
        Err(e) => Err(format!(
            "with_fault_at_dispatch_requiring_backend_error failed at fail_after={fail_after}: {e}"
        )),
    }
}

#[cfg(test)]
mod tests {

    use super::{with_fault_at_dispatch, FaultInjectingBackend};
    use crate::pipeline::backend::{ConformDispatchConfig, WgslBackend};

    /// Inner backend stub that always succeeds with a fixed payload.
    struct MockBackend {
        output: Vec<u8>,
    }

    impl WgslBackend for MockBackend {
        fn name(&self) -> &str {
            "mock"
        }

        fn dispatch(
            &self,
            _wgsl: &str,
            _input: &[u8],
            _output_size: usize,
            _config: ConformDispatchConfig,
        ) -> Result<Vec<u8>, String> {
            Ok(self.output.clone())
        }

        fn dispatch_program(
            &self,
            _program: &[u8],
            _input: &[u8],
            _output_size: usize,
            _config: ConformDispatchConfig,
        ) -> Result<Vec<u8>, String> {
            Ok(self.output.clone())
        }
    }

    /// Issue one empty WGSL dispatch through `fault` with default config.
    fn dispatch_once(fault: &FaultInjectingBackend<'_>) -> Result<Vec<u8>, String> {
        fault.dispatch("", &[], 2, ConformDispatchConfig::default())
    }

    /// Push a single-`Return` program through `with_fault_at_dispatch` on a
    /// zero-filled mock backend, failing from dispatch index `fail_after`.
    fn run_trivial_program(fail_after: usize) -> Result<Vec<u8>, String> {
        let inner = MockBackend {
            output: vec![0x00; 4],
        };
        let program = vyre::ir::Program::new(vec![], [1, 1, 1], vec![vyre::ir::Node::Return]);
        with_fault_at_dispatch(
            &inner,
            &program,
            &[],
            4,
            ConformDispatchConfig::default(),
            fail_after,
        )
    }

    #[test]
    fn fault_backend_succeeds_before_threshold() {
        let inner = MockBackend {
            output: vec![0xAB, 0xCD],
        };
        let fault = FaultInjectingBackend::new(&inner, 2);

        // Indices 0 and 1 are both below the threshold of 2.
        for _ in 0..2 {
            assert!(dispatch_once(&fault).is_ok());
        }
    }

    #[test]
    fn fault_backend_fails_at_threshold_without_panic() {
        let inner = MockBackend {
            output: vec![0xAB, 0xCD],
        };
        let fault = FaultInjectingBackend::new(&inner, 1);

        // Index 0 is below the threshold and reaches the mock.
        assert!(dispatch_once(&fault).is_ok());

        // Index 1 hits the threshold and must come back as an error value.
        let result = dispatch_once(&fault);
        assert!(
            result.is_err(),
            "expected structured error at fail_after threshold, got: {:?}",
            result
        );
        let msg = result.unwrap_err();
        assert!(
            msg.contains("Fault injection"),
            "error must mention fault injection, got: {msg}"
        );
    }

    #[test]
    fn with_fault_at_dispatch_detects_failure() {
        // A threshold of 0 makes the very first dispatch fail.
        let result = run_trivial_program(0);
        assert!(
            result.is_err(),
            "expected failure when fail_after=0, got: {:?}",
            result
        );
    }

    #[test]
    fn with_fault_at_dispatch_allows_success() {
        // A threshold of 5 leaves the single dispatch (index 0) untouched.
        let result = run_trivial_program(5);
        assert!(
            result.is_ok(),
            "expected success when fail_after is high, got: {:?}",
            result
        );
    }
}