#[cfg(loom)]
use loom::sync::atomic::{AtomicUsize, Ordering};
#[cfg(not(loom))]
use std::sync::atomic::{AtomicUsize, Ordering};
use crate::pipeline::backend::{ConformDispatchConfig, ExecutionModel, WgslBackend};
/// Selects what [`FaultInjectingBackend`] does with calls at or past its threshold.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum FaultMode {
    /// Return a synthetic error without invoking the inner backend.
    Synthetic,
    /// Invoke the inner backend and require it to fail with a `Fix: ` hint.
    RequireInnerActionable,
}
/// A [`WgslBackend`] decorator that lets the first `fail_after` calls through untouched and
/// then applies a fault-injection policy to every later call.
///
/// `dispatch`, `dispatch_program`, `dispatch_batch` and `execute` all advance the same
/// counter by one per call.
pub struct FaultInjectingBackend<'a> {
    /// Backend that delegated calls are forwarded to.
    inner: &'a dyn WgslBackend,
    /// Policy applied to calls at or past the threshold.
    mode: FaultMode,
    /// Number of leading calls that are delegated without any fault handling.
    fail_after: usize,
    /// Calls observed so far; atomic because it is advanced through `&self`.
    count: AtomicUsize,
}
impl<'a> FaultInjectingBackend<'a> {
#[inline]
pub(crate) fn new(inner: &'a dyn WgslBackend, fail_after: usize) -> Self {
Self {
inner,
fail_after,
count: AtomicUsize::new(0),
mode: FaultMode::Synthetic,
}
}
#[inline]
pub(crate) fn requiring_inner_actionable_error(
inner: &'a dyn WgslBackend,
fail_after: usize,
) -> Self {
Self {
inner,
fail_after,
count: AtomicUsize::new(0),
mode: FaultMode::RequireInnerActionable,
}
}
fn choose_fault_or_delegate<T>(
&self,
delegate: impl FnOnce() -> Result<T, String>,
) -> Result<T, String> {
let idx = self.count.fetch_add(1, Ordering::SeqCst);
if idx < self.fail_after {
return delegate();
}
match self.mode {
FaultMode::Synthetic => Err(format!(
"Fault injection: dispatch {idx} failed (fail_after={}). Fix: retry or degrade gracefully.",
self.fail_after
)),
FaultMode::RequireInnerActionable => match delegate() {
Err(err) if err.contains("Fix: ") => Err(err),
Err(err) => Err(format!(
"Fault injection delegated dispatch {idx} to backend `{}` but error was not actionable: {err}. Fix: backend errors must include `Fix: ...`.",
self.inner.name()
)),
Ok(_) => Err(format!(
"Fault injection expected backend `{}` to fail at dispatch {idx}, but it succeeded. Fix: configure the inner backend fault path before using delegated fault mode.",
self.inner.name()
)),
},
}
}
}
impl WgslBackend for FaultInjectingBackend<'_> {
    /// Reports the wrapped backend's name. This does not advance the fault counter.
    fn name(&self) -> &str {
        self.inner.name()
    }

    /// Forwards a WGSL dispatch to the inner backend, subject to the fault policy.
    fn dispatch(
        &self,
        wgsl: &str,
        input: &[u8],
        output_size: usize,
        config: ConformDispatchConfig,
    ) -> Result<Vec<u8>, String> {
        let backend = self.inner;
        let call = move || backend.dispatch(wgsl, input, output_size, config);
        self.choose_fault_or_delegate(call)
    }

    /// Forwards a serialized-program dispatch to the inner backend, subject to the fault policy.
    fn dispatch_program(
        &self,
        program: &[u8],
        input: &[u8],
        output_size: usize,
        config: ConformDispatchConfig,
    ) -> Result<Vec<u8>, String> {
        let backend = self.inner;
        let call = move || backend.dispatch_program(program, input, output_size, config);
        self.choose_fault_or_delegate(call)
    }

    /// Forwards a whole batch as one counted call, subject to the fault policy.
    fn dispatch_batch(
        &self,
        wgsl: &str,
        inputs: &[Vec<u8>],
        output_sizes: &[usize],
        config: ConformDispatchConfig,
    ) -> Result<Vec<Vec<u8>>, String> {
        let backend = self.inner;
        let call = move || backend.dispatch_batch(wgsl, inputs, output_sizes, config);
        self.choose_fault_or_delegate(call)
    }

    /// Forwards model execution to the inner backend, subject to the fault policy.
    fn execute(&self, model: &ExecutionModel) -> Result<Vec<u8>, String> {
        let backend = self.inner;
        self.choose_fault_or_delegate(move || backend.execute(model))
    }
}
#[inline]
pub fn with_fault_at_dispatch(
backend: &dyn WgslBackend,
program: &vyre::ir::Program,
input: &[u8],
output_size: usize,
config: ConformDispatchConfig,
fail_after: usize,
) -> Result<Vec<u8>, String> {
let fault_backend = FaultInjectingBackend::new(backend, fail_after);
let bytes = program
.to_wire()
.map_err(|e| format!("with_fault_at_dispatch failed to serialize program: {e}"))?;
fault_backend
.dispatch_program(&bytes, input, output_size, config)
.map_err(|e| format!("with_fault_at_dispatch failed at fail_after={fail_after}: {e}"))
}
#[inline]
pub fn with_fault_at_dispatch_requiring_backend_error(
backend: &dyn WgslBackend,
program: &vyre::ir::Program,
input: &[u8],
output_size: usize,
config: ConformDispatchConfig,
fail_after: usize,
) -> Result<Vec<u8>, String> {
let fault_backend =
FaultInjectingBackend::requiring_inner_actionable_error(backend, fail_after);
let bytes = program.to_wire().map_err(|e| {
format!("with_fault_at_dispatch_requiring_backend_error failed to serialize program: {e}")
})?;
fault_backend
.dispatch_program(&bytes, input, output_size, config)
.map_err(|e| {
format!(
"with_fault_at_dispatch_requiring_backend_error failed at fail_after={fail_after}: {e}"
)
})
}
#[cfg(test)]
mod tests {
    use super::{
        with_fault_at_dispatch, with_fault_at_dispatch_requiring_backend_error,
        FaultInjectingBackend,
    };
    use crate::pipeline::backend::{ConformDispatchConfig, WgslBackend};

    /// Backend that always succeeds with a fixed output.
    struct MockBackend {
        output: Vec<u8>,
    }
    impl WgslBackend for MockBackend {
        fn name(&self) -> &str {
            "mock"
        }
        fn dispatch(
            &self,
            _wgsl: &str,
            _input: &[u8],
            _output_size: usize,
            _config: ConformDispatchConfig,
        ) -> Result<Vec<u8>, String> {
            Ok(self.output.clone())
        }
        fn dispatch_program(
            &self,
            _program: &[u8],
            _input: &[u8],
            _output_size: usize,
            _config: ConformDispatchConfig,
        ) -> Result<Vec<u8>, String> {
            Ok(self.output.clone())
        }
    }

    /// Backend that always fails with a fixed error message.
    struct FailingBackend {
        error: String,
    }
    impl WgslBackend for FailingBackend {
        fn name(&self) -> &str {
            "failing"
        }
        fn dispatch(
            &self,
            _wgsl: &str,
            _input: &[u8],
            _output_size: usize,
            _config: ConformDispatchConfig,
        ) -> Result<Vec<u8>, String> {
            Err(self.error.clone())
        }
        fn dispatch_program(
            &self,
            _program: &[u8],
            _input: &[u8],
            _output_size: usize,
            _config: ConformDispatchConfig,
        ) -> Result<Vec<u8>, String> {
            Err(self.error.clone())
        }
    }

    /// Minimal program that serializes successfully.
    fn empty_program() -> vyre::ir::Program {
        vyre::ir::Program::new(vec![], [1, 1, 1], vec![vyre::ir::Node::Return])
    }

    #[test]
    fn fault_backend_succeeds_before_threshold() {
        let inner = MockBackend {
            output: vec![0xAB, 0xCD],
        };
        let fault = FaultInjectingBackend::new(&inner, 2);
        assert!(fault
            .dispatch("", &[], 2, ConformDispatchConfig::default())
            .is_ok());
        assert!(fault
            .dispatch("", &[], 2, ConformDispatchConfig::default())
            .is_ok());
    }

    #[test]
    fn fault_backend_fails_at_threshold_without_panic() {
        let inner = MockBackend {
            output: vec![0xAB, 0xCD],
        };
        let fault = FaultInjectingBackend::new(&inner, 1);
        assert!(fault
            .dispatch("", &[], 2, ConformDispatchConfig::default())
            .is_ok());
        let result = fault.dispatch("", &[], 2, ConformDispatchConfig::default());
        assert!(
            result.is_err(),
            "expected structured error at fail_after threshold, got: {:?}",
            result
        );
        let msg = result.unwrap_err();
        assert!(
            msg.contains("Fault injection"),
            "error must mention fault injection, got: {msg}"
        );
    }

    #[test]
    fn fault_backend_keeps_failing_after_threshold() {
        let inner = MockBackend { output: vec![0x01] };
        let fault = FaultInjectingBackend::new(&inner, 1);
        assert!(fault
            .dispatch("", &[], 1, ConformDispatchConfig::default())
            .is_ok());
        for _ in 0..2 {
            assert!(fault
                .dispatch("", &[], 1, ConformDispatchConfig::default())
                .is_err());
        }
    }

    #[test]
    fn delegated_fault_delegates_before_threshold() {
        let inner = FailingBackend {
            error: "boom".to_string(),
        };
        let fault = FaultInjectingBackend::requiring_inner_actionable_error(&inner, 1);
        // Before the threshold the inner error is returned verbatim, with no actionability check.
        let result = fault.dispatch("", &[], 2, ConformDispatchConfig::default());
        assert_eq!(result, Err("boom".to_string()));
    }

    #[test]
    fn delegated_fault_passes_through_actionable_backend_error() {
        let inner = FailingBackend {
            error: "boom. Fix: reset the device".to_string(),
        };
        let fault = FaultInjectingBackend::requiring_inner_actionable_error(&inner, 0);
        let result = fault.dispatch("", &[], 2, ConformDispatchConfig::default());
        assert_eq!(result, Err("boom. Fix: reset the device".to_string()));
    }

    #[test]
    fn delegated_fault_rejects_non_actionable_backend_error() {
        let inner = FailingBackend {
            error: "boom".to_string(),
        };
        let fault = FaultInjectingBackend::requiring_inner_actionable_error(&inner, 0);
        let msg = fault
            .dispatch("", &[], 2, ConformDispatchConfig::default())
            .unwrap_err();
        assert!(msg.contains("was not actionable"), "got: {msg}");
        assert!(msg.contains("boom"), "got: {msg}");
    }

    #[test]
    fn delegated_fault_rejects_unexpected_success() {
        let inner = MockBackend { output: vec![0xAB] };
        let fault = FaultInjectingBackend::requiring_inner_actionable_error(&inner, 0);
        let msg = fault
            .dispatch("", &[], 1, ConformDispatchConfig::default())
            .unwrap_err();
        assert!(msg.contains("but it succeeded"), "got: {msg}");
    }

    #[test]
    fn with_fault_at_dispatch_detects_failure() {
        let inner = MockBackend {
            output: vec![0x00; 4],
        };
        let program = empty_program();
        let result = with_fault_at_dispatch(
            &inner,
            &program,
            &[],
            4,
            ConformDispatchConfig::default(),
            0,
        );
        assert!(
            result.is_err(),
            "expected failure when fail_after=0, got: {:?}",
            result
        );
    }

    #[test]
    fn with_fault_at_dispatch_allows_success() {
        let inner = MockBackend {
            output: vec![0x00; 4],
        };
        let program = empty_program();
        let result = with_fault_at_dispatch(
            &inner,
            &program,
            &[],
            4,
            ConformDispatchConfig::default(),
            5,
        );
        assert!(
            result.is_ok(),
            "expected success when fail_after is high, got: {:?}",
            result
        );
    }

    #[test]
    fn with_fault_at_dispatch_requiring_backend_error_wraps_actionable_error() {
        let inner = FailingBackend {
            error: "boom. Fix: reset the device".to_string(),
        };
        let program = empty_program();
        let msg = with_fault_at_dispatch_requiring_backend_error(
            &inner,
            &program,
            &[],
            4,
            ConformDispatchConfig::default(),
            0,
        )
        .unwrap_err();
        assert!(
            msg.starts_with("with_fault_at_dispatch_requiring_backend_error failed at fail_after=0"),
            "got: {msg}"
        );
        assert!(msg.contains("Fix: reset the device"), "got: {msg}");
    }
}