#[cfg(loom)]
use loom::sync::atomic::{AtomicUsize, Ordering};
#[cfg(not(loom))]
use std::sync::atomic::{AtomicUsize, Ordering};
use crate::pipeline::backend::{ConformDispatchConfig, ExecutionModel, WgslBackend};
/// A [`WgslBackend`] that routes calls to `backend_a` until `swap_after`
/// calls have been made, then routes every later call to `backend_b`.
///
/// The four routed trait methods (`dispatch`, `dispatch_program`,
/// `dispatch_batch`, `execute`) all advance the same call counter.
pub struct SwapBackend<'a> {
/// Backend that serves calls whose index is below `swap_after`.
backend_a: &'a dyn WgslBackend,
/// Backend that serves calls whose index is `swap_after` or higher.
backend_b: &'a dyn WgslBackend,
/// Number of calls routed to `backend_a` before switching to `backend_b`.
swap_after: usize,
/// Calls made so far; incremented atomically on every routed call.
count: AtomicUsize,
}
impl<'a> SwapBackend<'a> {
#[inline]
pub fn new(
backend_a: &'a dyn WgslBackend,
backend_b: &'a dyn WgslBackend,
swap_after: usize,
) -> Self {
Self {
backend_a,
backend_b,
swap_after,
count: AtomicUsize::new(0),
}
}
}
impl<'a> SwapBackend<'a> {
    /// Claims the next call index and returns the backend that must serve it.
    ///
    /// Indices below `swap_after` go to `backend_a`; all others go to
    /// `backend_b`. Every routed trait method funnels through here, so they
    /// all share one counter: for example, an `execute` call moves the swap
    /// point for later `dispatch` calls as well.
    fn route_next_call(&self) -> &'a dyn WgslBackend {
        let idx = self.count.fetch_add(1, Ordering::SeqCst);
        if idx < self.swap_after {
            self.backend_a
        } else {
            self.backend_b
        }
    }
}

impl WgslBackend for SwapBackend<'_> {
    fn name(&self) -> &str {
        "swap"
    }

    /// Forwards to whichever backend [`SwapBackend::route_next_call`] picks.
    fn dispatch(
        &self,
        wgsl: &str,
        input: &[u8],
        output_size: usize,
        config: ConformDispatchConfig,
    ) -> Result<Vec<u8>, String> {
        self.route_next_call()
            .dispatch(wgsl, input, output_size, config)
    }

    /// Forwards to whichever backend [`SwapBackend::route_next_call`] picks.
    fn dispatch_program(
        &self,
        program: &[u8],
        input: &[u8],
        output_size: usize,
        config: ConformDispatchConfig,
    ) -> Result<Vec<u8>, String> {
        self.route_next_call()
            .dispatch_program(program, input, output_size, config)
    }

    /// Forwards the whole batch to a single backend; a batch counts as one
    /// call toward `swap_after`.
    fn dispatch_batch(
        &self,
        wgsl: &str,
        inputs: &[Vec<u8>],
        output_sizes: &[usize],
        config: ConformDispatchConfig,
    ) -> Result<Vec<Vec<u8>>, String> {
        self.route_next_call()
            .dispatch_batch(wgsl, inputs, output_sizes, config)
    }

    /// Forwards to whichever backend [`SwapBackend::route_next_call`] picks.
    fn execute(&self, model: &ExecutionModel) -> Result<Vec<u8>, String> {
        self.route_next_call().execute(model)
    }
}
/// Dispatches `program` through a [`SwapBackend`] and checks the output
/// against pure runs on `backend_a` and on `backend_b`.
///
/// The program is serialized once with `to_wire`. It is then dispatched
/// three times:
/// - through a freshly built `SwapBackend` (the "mixed" run);
/// - on `backend_a` alone;
/// - on `backend_b` alone.
///
/// The mixed output is returned only if all three outputs are byte-identical.
///
/// NOTE(review): the `SwapBackend` built here receives exactly one
/// `dispatch_program` call. The "mixed" run is therefore served entirely by
/// `backend_b` when `swap_after == 0`, and entirely by `backend_a` otherwise,
/// so every `swap_after >= 1` behaves the same. Confirm this is the intended
/// coverage.
///
/// # Errors
///
/// Returns `Err` in these cases:
/// - serialization fails;
/// - any of the three dispatches fails (the backend's message is included);
/// - the outputs disagree. The divergence message says, for each baseline,
///   whether it matched and otherwise the first differing byte offset and
///   both lengths.
#[inline]
pub fn with_backend_swap(
    backend_a: &dyn WgslBackend,
    backend_b: &dyn WgslBackend,
    program: &vyre::ir::Program,
    input: &[u8],
    output_size: usize,
    config: ConformDispatchConfig,
    swap_after: usize,
) -> Result<Vec<u8>, String> {
    let swap = SwapBackend::new(backend_a, backend_b, swap_after);
    let bytes = program
        .to_wire()
        .map_err(|e| format!("with_backend_swap failed to serialize program: {e}"))?;
    let mixed = swap
        .dispatch_program(&bytes, input, output_size, config.clone())
        .map_err(|e| format!("with_backend_swap failed at swap_after={swap_after}: {e}"))?;
    let pure_a = backend_a
        .dispatch_program(&bytes, input, output_size, config.clone())
        .map_err(|e| format!("with_backend_swap pure backend A baseline failed: {e}"))?;
    let pure_b = backend_b
        .dispatch_program(&bytes, input, output_size, config)
        .map_err(|e| format!("with_backend_swap pure backend B baseline failed: {e}"))?;
    if mixed != pure_a || mixed != pure_b {
        // Say which baseline diverged and where, so the failure is actionable
        // without re-running with extra logging.
        return Err(format!(
            "backend swap produced output that differs from pure baselines at swap_after={swap_after} (vs backend A: {}; vs backend B: {}). Fix: reject illegal state transfer or make backend A/B layouts byte-compatible.",
            describe_divergence(&mixed, &pure_a),
            describe_divergence(&mixed, &pure_b),
        ));
    }
    Ok(mixed)
}

/// Summarizes how `actual` differs from `expected`, for use in error messages.
///
/// Returns `"identical"` when the slices are equal. Otherwise it reports the
/// first differing byte offset and both lengths. When one slice is a strict
/// prefix of the other, the reported offset is the shorter length.
fn describe_divergence(actual: &[u8], expected: &[u8]) -> String {
    if actual == expected {
        return "identical".to_string();
    }
    let offset = actual
        .iter()
        .zip(expected)
        .position(|(x, y)| x != y)
        .unwrap_or_else(|| actual.len().min(expected.len()));
    format!(
        "first differing byte at offset {offset} (len {} vs {})",
        actual.len(),
        expected.len()
    )
}
#[cfg(test)]
mod tests {
    use super::{with_backend_swap, SwapBackend};
    use crate::pipeline::backend::{ConformDispatchConfig, WgslBackend};

    /// Test double whose `dispatch` and `dispatch_program` ignore their inputs
    /// and always return a fixed `output`.
    struct MockBackend {
        name: &'static str,
        output: Vec<u8>,
    }

    impl WgslBackend for MockBackend {
        fn name(&self) -> &str {
            self.name
        }
        fn dispatch(
            &self,
            _wgsl: &str,
            _input: &[u8],
            _output_size: usize,
            _config: ConformDispatchConfig,
        ) -> Result<Vec<u8>, String> {
            Ok(self.output.clone())
        }
        fn dispatch_program(
            &self,
            _program: &[u8],
            _input: &[u8],
            _output_size: usize,
            _config: ConformDispatchConfig,
        ) -> Result<Vec<u8>, String> {
            Ok(self.output.clone())
        }
    }

    #[test]
    fn swap_backend_uses_a_before_threshold() {
        let a = MockBackend {
            name: "a",
            output: vec![0xAA],
        };
        let b = MockBackend {
            name: "b",
            output: vec![0xBB],
        };
        let swap = SwapBackend::new(&a, &b, 2);
        assert_eq!(
            swap.dispatch("", &[], 1, ConformDispatchConfig::default())
                .unwrap(),
            vec![0xAA]
        );
        assert_eq!(
            swap.dispatch("", &[], 1, ConformDispatchConfig::default())
                .unwrap(),
            vec![0xAA]
        );
    }

    #[test]
    fn swap_backend_uses_b_after_threshold() {
        let a = MockBackend {
            name: "a",
            output: vec![0xAA],
        };
        let b = MockBackend {
            name: "b",
            output: vec![0xBB],
        };
        let swap = SwapBackend::new(&a, &b, 1);
        assert_eq!(
            swap.dispatch("", &[], 1, ConformDispatchConfig::default())
                .unwrap(),
            vec![0xAA]
        );
        assert_eq!(
            swap.dispatch("", &[], 1, ConformDispatchConfig::default())
                .unwrap(),
            vec![0xBB]
        );
    }

    #[test]
    fn swap_backend_with_zero_threshold_uses_b_immediately() {
        let a = MockBackend {
            name: "a",
            output: vec![0xAA],
        };
        let b = MockBackend {
            name: "b",
            output: vec![0xBB],
        };
        let swap = SwapBackend::new(&a, &b, 0);
        assert_eq!(
            swap.dispatch("", &[], 1, ConformDispatchConfig::default())
                .unwrap(),
            vec![0xBB]
        );
    }

    #[test]
    fn swap_backend_shares_counter_across_methods() {
        // A `dispatch_program` call must consume the same slot a `dispatch`
        // call would, so the second call (of a different kind) lands on B.
        let a = MockBackend {
            name: "a",
            output: vec![0xAA],
        };
        let b = MockBackend {
            name: "b",
            output: vec![0xBB],
        };
        let swap = SwapBackend::new(&a, &b, 1);
        assert_eq!(
            swap.dispatch_program(&[], &[], 1, ConformDispatchConfig::default())
                .unwrap(),
            vec![0xAA]
        );
        assert_eq!(
            swap.dispatch("", &[], 1, ConformDispatchConfig::default())
                .unwrap(),
            vec![0xBB]
        );
    }

    #[test]
    fn with_backend_swap_accepts_byte_identical_backends() {
        let a = MockBackend {
            name: "a",
            output: vec![0x5A; 4],
        };
        let b = MockBackend {
            name: "b",
            output: vec![0x5A; 4],
        };
        let program = vyre::ir::Program::new(vec![], [1, 1, 1], vec![vyre::ir::Node::Return]);
        for swap_after in [0, 1, 2] {
            let result = with_backend_swap(
                &a,
                &b,
                &program,
                &[],
                4,
                ConformDispatchConfig::default(),
                swap_after,
            );
            assert_eq!(
                result,
                Ok(vec![0x5A; 4]),
                "identical backends must pass at swap_after={swap_after}"
            );
        }
    }

    #[test]
    fn with_backend_swap_detects_divergence() {
        let a = MockBackend {
            name: "a",
            output: vec![0x00; 4],
        };
        let b = MockBackend {
            name: "b",
            output: vec![0xFF; 4],
        };
        let program = vyre::ir::Program::new(vec![], [1, 1, 1], vec![vyre::ir::Node::Return]);
        let result = with_backend_swap(
            &a,
            &b,
            &program,
            &[],
            4,
            ConformDispatchConfig::default(),
            0,
        );
        assert!(
            result.is_err(),
            "mixed run must be rejected when pure baselines diverge"
        );
        let result = with_backend_swap(
            &a,
            &b,
            &program,
            &[],
            4,
            ConformDispatchConfig::default(),
            1,
        );
        assert!(
            result.is_err(),
            "mixed run must be rejected when pure baselines diverge"
        );
    }

    #[test]
    fn with_backend_swap_propagates_backend_b_error() {
        struct FailingBackend;
        impl WgslBackend for FailingBackend {
            fn name(&self) -> &str {
                "failing"
            }
            fn dispatch(
                &self,
                _wgsl: &str,
                _input: &[u8],
                _output_size: usize,
                _config: ConformDispatchConfig,
            ) -> Result<Vec<u8>, String> {
                Err("backend b error".to_string())
            }
            fn dispatch_program(
                &self,
                _program: &[u8],
                _input: &[u8],
                _output_size: usize,
                _config: ConformDispatchConfig,
            ) -> Result<Vec<u8>, String> {
                Err("backend b error".to_string())
            }
        }
        let a = MockBackend {
            name: "a",
            output: vec![0x00; 4],
        };
        let program = vyre::ir::Program::new(vec![], [1, 1, 1], vec![vyre::ir::Node::Return]);
        let result = with_backend_swap(
            &a,
            &FailingBackend,
            &program,
            &[],
            4,
            ConformDispatchConfig::default(),
            0,
        );
        assert!(
            result.is_err(),
            "expected error from backend B, got: {:?}",
            result
        );
        let msg = result.unwrap_err();
        assert!(
            msg.contains("backend b error"),
            "error must propagate from backend B, got: {msg}"
        );
    }
}