#[cfg(not(loom))]
#[cfg(loom)]
use loom::sync::Mutex;
#[cfg(loom)]
use loom::sync::{Mutex, Once};
#[cfg(not(loom))]
use std::sync::{Mutex, Once};
use vyre::ir::{BufferAccess, DataTypeSizeBytes, Program};
use crate::spec::value::Value;
use vyre_reference;
/// A backend the conformance harness can execute a [`Program`] on.
///
/// Implementors must be `Send + Sync` because they are stored as
/// `&'static dyn HarnessBackend` in a process-wide registry.
pub trait HarnessBackend: Send + Sync {
    /// Identifier reported alongside this backend's result.
    fn name(&self) -> &str;

    /// Executes `program` against `inputs`.
    ///
    /// On success, returns the produced values together with the raw output
    /// byte length the backend observed. [`with_every_backend`] checks that
    /// length against the program's expected output size.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when execution fails.
    fn run_with_byte_length(
        &self,
        program: &Program,
        inputs: &[Value],
    ) -> Result<(Vec<Value>, usize), String>;

    /// Executes `program` and keeps only the produced values, dropping the
    /// byte length reported by [`HarnessBackend::run_with_byte_length`].
    ///
    /// # Errors
    ///
    /// Propagates any error from [`HarnessBackend::run_with_byte_length`].
    fn run(&self, program: &Program, inputs: &[Value]) -> Result<Vec<Value>, String> {
        let (values, _byte_length) = self.run_with_byte_length(program, inputs)?;
        Ok(values)
    }
}
/// Harness backend that delegates execution to the `vyre_reference` interpreter.
pub struct ReferenceBackend;

impl HarnessBackend for ReferenceBackend {
    fn name(&self) -> &str {
        "reference"
    }

    /// Converts `inputs` to reference values, runs `vyre_reference::run`, and
    /// reports the serialized byte length of the values it produced.
    ///
    /// # Errors
    ///
    /// Returns the interpreter's error rendered with `to_string`, or any error
    /// raised while serializing the produced values.
    fn run_with_byte_length(
        &self,
        program: &Program,
        inputs: &[Value],
    ) -> Result<(Vec<Value>, usize), String> {
        let converted = Value::to_reference_values(inputs);
        let raw = match vyre_reference::run(program, &converted) {
            Ok(raw) => raw,
            Err(error) => return Err(error.to_string()),
        };
        let values = Value::from_reference_values(raw);
        // The serialized bytes are discarded; only their length is reported.
        let byte_length = super::backend::values_to_bytes(&values)?.len();
        Ok((values, byte_length))
    }
}
/// Upper bound (1 GiB) on the serialized backend output accepted by
/// `values_to_bytes`.
const MAX_BACKEND_OUTPUT_BYTES: usize = 1 << 30;

/// Process-wide list of registered backends; read through `backend_registry`.
static REGISTRY: Mutex<Vec<&'static dyn HarnessBackend>> = Mutex::new(Vec::new());

/// Guards the one-time seeding of `REGISTRY` with the reference backend.
static INIT: Once = Once::new();
/// Seeds `REGISTRY` with [`ReferenceBackend`] exactly once.
///
/// A poisoned lock is recovered through `into_inner` instead of panicking.
fn init_registry() {
    INIT.call_once(|| {
        let mut backends = REGISTRY
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        backends.push(&ReferenceBackend);
    });
}
/// Returns a snapshot of every registered backend, with the reference backend
/// first.
///
/// The returned `Vec` is a copy, so callers may call [`register_backend`]
/// while iterating over it without contending on the registry lock.
#[inline]
pub fn backend_registry() -> Vec<&'static dyn HarnessBackend> {
    init_registry();
    let guard = REGISTRY
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    guard.to_vec()
}
/// Appends `backend` to the process-wide registry so later harness runs
/// include it.
///
/// The reference backend is seeded before the push, so it stays first.
/// Registrations are not deduplicated.
#[inline]
pub fn register_backend(backend: &'static dyn HarnessBackend) {
    init_registry();
    let mut registry = REGISTRY
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    registry.push(backend);
}
/// Runs `program` on every registered backend and passes each outcome to
/// `handler` as `(backend name, result)`.
///
/// The expected output size is computed once, up front. If that computation
/// fails, every backend receives a copy of that error and is not executed.
/// Otherwise a successful run is only reported as `Ok` when its raw output
/// byte length matches the expected size exactly.
#[inline]
pub fn with_every_backend<F>(program: &Program, inputs: &[Value], mut handler: F)
where
    F: FnMut(&str, Result<Vec<Value>, String>),
{
    let expected_output_bytes = expected_output_size(program, inputs);
    for backend in backend_registry() {
        let name = backend.name();
        let outcome = expected_output_bytes.clone().and_then(|expected| {
            let (values, actual_len) = backend.run_with_byte_length(program, inputs)?;
            verify_raw_output_size(name, actual_len, expected)?;
            Ok(values)
        });
        handler(name, outcome);
    }
}
/// Computes how many output bytes `program` is expected to produce.
///
/// Non-workgroup buffers are paired positionally with `inputs`; workgroup
/// buffers consume no input slot. Each `ReadWrite` buffer contributes the
/// serialized length of its paired input. If that input is missing or
/// serializes to zero bytes, the buffer contributes the size of one element of
/// its element type instead.
///
/// Returns `Ok(None)` when the program declares no `ReadWrite` buffer, meaning
/// there is no output-size contract to enforce.
///
/// # Errors
///
/// Returns an error if the running total overflows `usize`.
fn expected_output_size(program: &Program, inputs: &[Value]) -> Result<Option<usize>, String> {
    let mut input_index = 0usize;
    let mut total = 0usize;
    let mut saw_output = false;
    for decl in program.buffers() {
        if decl.access() == BufferAccess::Workgroup {
            continue;
        }
        // Every non-workgroup buffer consumes one input slot, even when it is
        // not an output, so the positional pairing stays aligned.
        let slot = input_index;
        input_index += 1;
        if decl.access() != BufferAccess::ReadWrite {
            // Only ReadWrite buffers contribute to the output size, so skip
            // serializing their inputs.
            continue;
        }
        saw_output = true;
        let len = inputs
            .get(slot)
            .map(|value| value.to_bytes().len())
            .filter(|&len| len > 0)
            .unwrap_or_else(|| decl.element().size_bytes());
        total = total.checked_add(len).ok_or_else(|| {
            "backend output-size contract overflowed usize. Fix: bound output buffers before dispatch."
                .to_string()
        })?;
    }
    Ok(saw_output.then_some(total))
}
/// Checks a backend's reported raw output length against the expected size.
///
/// `None` for `expected_output_bytes` means no size contract applies, so any
/// `actual` length is accepted.
///
/// # Errors
///
/// Returns a message naming `backend_name` when `actual` differs from the
/// expected byte count.
fn verify_raw_output_size(
    backend_name: &str,
    actual: usize,
    expected_output_bytes: Option<usize>,
) -> Result<(), String> {
    match expected_output_bytes {
        Some(expected) if actual != expected => Err(format!(
            "backend {backend_name} returned {actual} output bytes, expected exactly {expected}. Fix: honor WgslBackend::dispatch output_size without truncation or padding."
        )),
        _ => Ok(()),
    }
}
/// Serializes `values` back to back, enforcing `MAX_BACKEND_OUTPUT_BYTES`.
///
/// # Errors
///
/// See `values_to_bytes_bounded`.
fn values_to_bytes(values: &[Value]) -> Result<Vec<u8>, String> {
    let cap = MAX_BACKEND_OUTPUT_BYTES;
    values_to_bytes_bounded(values, cap)
}

/// Concatenates the serialized bytes of each value in order, failing once the
/// running total would exceed `max_bytes`.
///
/// # Errors
///
/// Returns an error if the running length overflows `usize`, or if it exceeds
/// `max_bytes`.
fn values_to_bytes_bounded(values: &[Value], max_bytes: usize) -> Result<Vec<u8>, String> {
    values.iter().try_fold(Vec::new(), |mut bytes, value| {
        let chunk = value.to_bytes();
        let Some(next_len) = bytes.len().checked_add(chunk.len()) else {
            return Err("backend returned output whose byte length overflowed usize. Fix: bound backend result values before returning from HarnessBackend::run.".to_string());
        };
        if next_len > max_bytes {
            return Err(format!(
                "backend returned {next_len} bytes, exceeding the {max_bytes}-byte harness cap. Fix: honor output-size contracts and do not return unbounded Value vectors."
            ));
        }
        bytes.extend(chunk);
        Ok(bytes)
    })
}
#[cfg(test)]
mod tests {
    use super::{backend_registry, register_backend, with_every_backend, HarnessBackend};
    use crate::spec::value::Value;
    use vyre::ir::Program;

    /// Mock backend that always succeeds with a fixed set of values.
    struct GoodBackend {
        output: Vec<Value>,
    }

    impl HarnessBackend for GoodBackend {
        fn name(&self) -> &str {
            "good-mock"
        }

        fn run_with_byte_length(
            &self,
            _program: &Program,
            _inputs: &[Value],
        ) -> Result<(Vec<Value>, usize), String> {
            let byte_len = super::values_to_bytes(&self.output)?.len();
            Ok((self.output.clone(), byte_len))
        }
    }

    /// Mock backend that always fails.
    struct BadBackend;

    impl HarnessBackend for BadBackend {
        fn name(&self) -> &str {
            "bad-mock"
        }

        fn run_with_byte_length(
            &self,
            _program: &Program,
            _inputs: &[Value],
        ) -> Result<(Vec<Value>, usize), String> {
            Err(String::from("mock backend failure"))
        }
    }

    /// Leaks `backend` so it satisfies the registry's `'static` requirement.
    fn leak(backend: impl HarnessBackend + 'static) -> &'static dyn HarnessBackend {
        Box::leak(Box::new(backend))
    }

    #[test]
    fn backend_registry_includes_reference() {
        let has_reference = backend_registry()
            .iter()
            .any(|backend| backend.name() == "reference");
        assert!(has_reference, "registry must contain reference backend");
    }

    #[test]
    fn register_backend_extends_registry() {
        register_backend(leak(GoodBackend {
            output: vec![Value::U32(42)],
        }));
        let registered = backend_registry()
            .iter()
            .any(|backend| backend.name() == "good-mock");
        assert!(registered, "registry must contain newly registered backend");
    }

    #[test]
    fn with_every_backend_yields_results() {
        let program = Program::new(vec![], [1, 1, 1], vec![vyre::ir::Node::Return]);
        let inputs: &[Value] = &[];
        register_backend(leak(GoodBackend {
            output: vec![Value::U32(7)],
        }));
        register_backend(leak(BadBackend));

        let mut names: Vec<String> = Vec::new();
        let (mut ok_count, mut err_count) = (0usize, 0usize);
        with_every_backend(&program, inputs, |name, result| {
            names.push(name.to_owned());
            if result.is_ok() {
                ok_count += 1;
            } else {
                err_count += 1;
            }
        });

        // Other tests share the global registry, so only lower bounds are checked.
        let ran = |expected: &str| names.iter().any(|name| name == expected);
        assert!(ran("reference"), "reference must run");
        assert!(ran("good-mock"), "good-mock must run");
        assert!(ran("bad-mock"), "bad-mock must run");
        assert!(ok_count >= 2, "expected at least two successes");
        assert!(err_count >= 1, "expected at least one failure");
    }
}