use std::sync::Arc;
use std::thread;
use vyre::ir::Program;
pub use crate::enforce::enforcers::determinism::{DeterminismReport, Divergence};
use crate::spec::value::Value;
/// One observation whose output differed from the baseline observation.
///
/// The determinism suites in this file take the result with the lowest
/// iteration index as the baseline and emit one `ThreadDivergence` for every
/// later result that is not equal to it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ThreadDivergence {
/// Global iteration index of the diverging observation.
pub iteration: usize,
/// Index of the worker thread that produced the diverging observation.
pub thread_id: usize,
/// Baseline output, serialized to bytes.
pub expected: Vec<u8>,
/// Output of the diverging observation, serialized to bytes.
pub actual: Vec<u8>,
}
/// Output of a single run, tagged with where it came from so results gathered
/// from several worker threads can be ordered and attributed afterwards.
#[derive(Debug)]
struct TaggedResult {
/// Global iteration index: the worker's starting offset plus its local index.
iteration: usize,
/// Index of the worker thread that performed the run.
thread_id: usize,
/// Values produced by the run.
values: Vec<Value>,
}
/// Summary returned by the determinism suites in this file.
///
/// An empty `divergences` list means every observation matched the baseline.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ThreadDeterminismReport {
/// Total number of iterations, as passed in by the caller.
pub iterations: usize,
/// Number of worker threads, as passed in by the caller.
pub threads: usize,
/// Every observation that differed from the baseline observation.
pub divergences: Vec<ThreadDivergence>,
}
/// Checks the reference interpreter for run-to-run determinism.
///
/// Delegates to [`run_determinism_suite_with_runner`] with a runner that
/// converts the inputs via `Value::to_reference_values`, executes the program
/// through `vyre_reference::run`, and converts the outputs back with
/// `Value::from_reference_values`.
///
/// # Errors
///
/// Returns any error reported by [`run_determinism_suite_with_runner`],
/// including interpreter errors, which are converted with `to_string()`.
#[inline]
pub fn run_determinism_suite(
    program: &Program,
    inputs: &[Value],
    iterations: usize,
    threads: usize,
) -> Result<ThreadDeterminismReport, String> {
    /// Performs one observation on the reference interpreter.
    fn reference_runner(program: &Program, inputs: &[Value]) -> Result<Vec<Value>, String> {
        let reference_inputs = Value::to_reference_values(inputs);
        match vyre_reference::run(program, &reference_inputs) {
            Ok(outputs) => Ok(Value::from_reference_values(outputs)),
            Err(err) => Err(err.to_string()),
        }
    }

    run_determinism_suite_with_runner(program, inputs, iterations, threads, reference_runner)
}
/// Runs `runner` `iterations` times, spread over worker threads, and reports
/// every observation whose output differs from the first one.
///
/// The iterations are split into contiguous chunks, one per worker; the first
/// `iterations % threads` workers take one extra iteration. The observation
/// with iteration index 0 is the baseline, and each later observation that is
/// not equal to it is recorded as a [`ThreadDivergence`].
///
/// Workers that would receive no iterations (`threads > iterations`) are not
/// spawned; the report still carries the requested `threads` value.
///
/// Every spawned worker is joined before this function returns, including on
/// failure, so no worker keeps running after the call.
///
/// # Errors
///
/// * `threads == 0` or `iterations < 2`.
/// * A worker thread could not be spawned.
/// * `runner` returned an error (the error of the lowest-indexed failing
///   worker is returned).
/// * A worker thread panicked.
#[inline]
pub fn run_determinism_suite_with_runner<F>(
    program: &Program,
    inputs: &[Value],
    iterations: usize,
    threads: usize,
    runner: F,
) -> Result<ThreadDeterminismReport, String>
where
    F: Fn(&Program, &[Value]) -> Result<Vec<Value>, String> + Send + Sync + Clone + 'static,
{
    if threads == 0 {
        return Err("threads must be > 0".to_string());
    }
    if iterations < 2 {
        return Err(
            "iterations must be >= 2 to detect nondeterminism. Fix: run at least two observations."
                .to_string(),
        );
    }

    let program = Arc::new(program.clone());
    let inputs: Arc<[Value]> = inputs.to_vec().into();
    let runner = Arc::new(runner);

    let per_thread = iterations / threads;
    let remainder = iterations % threads;
    // With more threads than iterations, only the first `iterations` workers
    // get any work; spawning the rest would create idle threads.
    let active_threads = threads.min(iterations);

    let mut handles = Vec::with_capacity(active_threads);
    let mut spawn_error: Option<String> = None;
    // Running offset of the next worker's first global iteration index.
    let mut next_iteration = 0usize;
    for t in 0..active_threads {
        let count = per_thread + usize::from(t < remainder);
        let start_iteration = next_iteration;
        next_iteration += count;

        let prog = Arc::clone(&program);
        let ins = Arc::clone(&inputs);
        let run = Arc::clone(&runner);
        let spawned = thread::Builder::new()
            .name(format!("vyre-h5-worker-{t}"))
            .spawn(move || -> Result<Vec<TaggedResult>, String> {
                let mut results = Vec::with_capacity(count);
                for local_index in 0..count {
                    let values = run(&prog, &ins)?;
                    results.push(TaggedResult {
                        iteration: start_iteration + local_index,
                        thread_id: t,
                        values,
                    });
                }
                Ok(results)
            });
        match spawned {
            Ok(handle) => handles.push(handle),
            Err(err) => {
                // Stop spawning, but fall through to the join loop so the
                // workers already started are not left detached.
                spawn_error = Some(format!("failed to spawn determinism worker {t}: {err}"));
                break;
            }
        }
    }

    // Join every worker, even after a failure: returning early would drop the
    // remaining handles and leave those threads running in the background.
    let mut all_results: Vec<TaggedResult> = Vec::with_capacity(iterations);
    let mut worker_error: Option<String> = None;
    for handle in handles {
        let outcome = match handle.join() {
            Ok(outcome) => outcome,
            Err(_) => Err("thread panicked during determinism suite".to_string()),
        };
        match outcome {
            Ok(results) => all_results.extend(results),
            Err(e) => {
                if worker_error.is_none() {
                    worker_error = Some(e);
                }
            }
        }
    }
    // A spawn failure takes precedence over any worker failure.
    if let Some(e) = spawn_error.or(worker_error) {
        return Err(e);
    }

    all_results.sort_by_key(|result| result.iteration);
    let mut divergences = Vec::new();
    if let Some((first, rest)) = all_results.split_first() {
        let baseline = &first.values;
        // Serialize the baseline at most once, and only if something diverged.
        let mut expected_bytes: Option<Vec<u8>> = None;
        for result in rest {
            if result.values != *baseline {
                let expected = expected_bytes
                    .get_or_insert_with(|| values_to_bytes(baseline))
                    .clone();
                divergences.push(ThreadDivergence {
                    iteration: result.iteration,
                    thread_id: result.thread_id,
                    expected,
                    actual: values_to_bytes(&result.values),
                });
            }
        }
    }

    Ok(ThreadDeterminismReport {
        iterations,
        threads,
        divergences,
    })
}
/// Concatenates the `to_bytes()` output of each value, in order, into one
/// byte buffer.
fn values_to_bytes(values: &[Value]) -> Vec<u8> {
    let mut bytes: Vec<u8> = Vec::new();
    for value in values {
        bytes.extend(value.to_bytes());
    }
    bytes
}
/// Executes `op` on the GPU `iterations` times, spread over worker threads,
/// and reports every observation whose GPU output differs from the first one.
///
/// Each worker creates its own `WgpuBackend` and calls
/// `crate::pipeline::execution::execute_op` once per assigned iteration,
/// using `op.workgroup_size` or 64 when that is unset. Only the first element
/// of the returned tuple (the GPU output) is recorded; the second is ignored.
///
/// The iterations are split into contiguous chunks, one per worker; the first
/// `iterations % threads` workers take one extra iteration. The observation
/// with iteration index 0 is the baseline.
///
/// Workers that would receive no iterations (`threads > iterations`) are not
/// spawned, so no backend is created for them; the report still carries the
/// requested `threads` value.
///
/// Every spawned worker is joined before this function returns, including on
/// failure, so no worker keeps running after the call.
///
/// # Errors
///
/// * `threads == 0` or `iterations < 2`.
/// * A worker thread could not be spawned.
/// * `WgpuBackend::new()` returned `None` in a worker.
/// * `execute_op` returned an error (the error of the lowest-indexed failing
///   worker is returned).
/// * A worker thread panicked.
#[inline]
pub fn run_gpu_determinism_suite(
    op: &crate::OpSpec,
    input: &[u8],
    iterations: usize,
    threads: usize,
) -> Result<ThreadDeterminismReport, String> {
    if threads == 0 {
        return Err("threads must be > 0".to_string());
    }
    if iterations < 2 {
        return Err(
            "iterations must be >= 2 to detect nondeterminism. Fix: run at least two observations."
                .to_string(),
        );
    }

    let op = std::sync::Arc::new(op.clone());
    let input: std::sync::Arc<[u8]> = input.to_vec().into();

    let per_thread = iterations / threads;
    let remainder = iterations % threads;
    // With more threads than iterations, only the first `iterations` workers
    // get any work; spawning the rest would only create unused backends.
    let active_threads = threads.min(iterations);

    let mut handles = Vec::with_capacity(active_threads);
    let mut spawn_error: Option<String> = None;
    // Running offset of the next worker's first global iteration index.
    let mut next_iteration = 0usize;
    for t in 0..active_threads {
        let count = per_thread + usize::from(t < remainder);
        let start_iteration = next_iteration;
        next_iteration += count;

        let op = std::sync::Arc::clone(&op);
        let input = std::sync::Arc::clone(&input);
        let spawned = std::thread::Builder::new()
            .name(format!("vyre-h7-gpu-worker-{t}"))
            .spawn(move || -> Result<Vec<TaggedResult>, String> {
                let backend =
                    crate::pipeline::backend::wgpu::WgpuBackend::new().ok_or_else(|| {
                        "Fix: no GPU adapter available for gpu determinism suite.".to_string()
                    })?;
                let workgroup_size = op.workgroup_size.unwrap_or(64);
                let mut results = Vec::with_capacity(count);
                for local_index in 0..count {
                    let (gpu, _cpu) = crate::pipeline::execution::execute_op(
                        &backend,
                        &op,
                        &input,
                        workgroup_size,
                    )?;
                    results.push(TaggedResult {
                        iteration: start_iteration + local_index,
                        thread_id: t,
                        values: vec![Value::Bytes(gpu)],
                    });
                }
                Ok(results)
            });
        match spawned {
            Ok(handle) => handles.push(handle),
            Err(err) => {
                // Stop spawning, but fall through to the join loop so the
                // workers already started are not left detached.
                spawn_error = Some(format!("failed to spawn gpu determinism worker {t}: {err}"));
                break;
            }
        }
    }

    // Join every worker, even after a failure: returning early would drop the
    // remaining handles and leave those threads running in the background.
    let mut all_results: Vec<TaggedResult> = Vec::with_capacity(iterations);
    let mut worker_error: Option<String> = None;
    for handle in handles {
        let outcome = match handle.join() {
            Ok(outcome) => outcome,
            Err(_) => Err("thread panicked during gpu determinism suite".to_string()),
        };
        match outcome {
            Ok(results) => all_results.extend(results),
            Err(e) => {
                if worker_error.is_none() {
                    worker_error = Some(e);
                }
            }
        }
    }
    // A spawn failure takes precedence over any worker failure.
    if let Some(e) = spawn_error.or(worker_error) {
        return Err(e);
    }

    all_results.sort_by_key(|r| r.iteration);
    let mut divergences = Vec::new();
    if let Some((first, rest)) = all_results.split_first() {
        let baseline = &first.values;
        // Serialize the baseline at most once, and only if something diverged.
        let mut expected_bytes: Option<Vec<u8>> = None;
        for result in rest {
            if result.values != *baseline {
                let expected = expected_bytes
                    .get_or_insert_with(|| values_to_bytes(baseline))
                    .clone();
                divergences.push(ThreadDivergence {
                    iteration: result.iteration,
                    thread_id: result.thread_id,
                    expected,
                    actual: values_to_bytes(&result.values),
                });
            }
        }
    }

    Ok(ThreadDeterminismReport {
        iterations,
        threads,
        divergences,
    })
}
#[cfg(test)]
mod tests {
    use super::{run_determinism_suite, run_determinism_suite_with_runner};
    use crate::spec::value::Value;
    use std::sync::Arc;
    use vyre::ir::Program;

    /// Builds a program whose only node is `Return`.
    fn empty_program() -> Program {
        Program::new(vec![], [1, 1, 1], vec![vyre::ir::Node::Return])
    }

    /// The reference interpreter must produce no divergences.
    #[test]
    fn determinism_suite_passes_for_reference() {
        let program = empty_program();
        let report = run_determinism_suite(&program, &[], 20, 4)
            .expect("suite should complete without error");
        assert!(
            report.divergences.is_empty(),
            "reference interpreter must be deterministic, got divergences: {:?}",
            report.divergences
        );
    }

    /// A runner that returns a different value on exactly one call must be
    /// reported, whichever iteration that call ends up being.
    #[test]
    fn determinism_suite_detects_divergence() {
        let program = empty_program();
        let call_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);
        let runner = move |_p: &Program, _i: &[Value]| {
            let c = cc.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            if c == 3 {
                Ok(vec![Value::U32(99)])
            } else {
                Ok(vec![Value::U32(42)])
            }
        };
        let report = run_determinism_suite_with_runner(&program, &[], 10, 4, runner)
            .expect("suite should complete without runner error");
        assert!(
            !report.divergences.is_empty(),
            "expected at least one divergence for non-deterministic runner"
        );
    }

    #[test]
    fn determinism_suite_rejects_zero_threads() {
        let program = empty_program();
        let result = run_determinism_suite(&program, &[], 1, 0);
        assert!(
            result.is_err(),
            "expected error for zero threads, got: {:?}",
            result
        );
        assert!(result.unwrap_err().contains("threads must be > 0"));
    }

    #[test]
    fn determinism_suite_rejects_zero_iterations() {
        let program = empty_program();
        let result = run_determinism_suite(&program, &[], 0, 2);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("iterations must be >= 2"));
    }

    /// One iteration is the actual boundary of the `>= 2` check.
    #[test]
    fn determinism_suite_rejects_single_iteration() {
        let program = empty_program();
        let result = run_determinism_suite(&program, &[], 1, 2);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("iterations must be >= 2"));
    }

    /// An error returned by the runner must come back to the caller unchanged.
    #[test]
    fn determinism_suite_propagates_runner_error() {
        let program = empty_program();
        let runner = |_p: &Program, _i: &[Value]| -> Result<Vec<Value>, String> {
            Err("runner failed".to_string())
        };
        let result = run_determinism_suite_with_runner(&program, &[], 4, 2, runner);
        assert_eq!(result, Err("runner failed".to_string()));
    }

    /// Requesting more threads than iterations must still run every iteration
    /// and report the requested counts.
    #[test]
    fn determinism_suite_handles_more_threads_than_iterations() {
        let program = empty_program();
        let runner = |_p: &Program, _i: &[Value]| -> Result<Vec<Value>, String> {
            Ok(vec![Value::U32(7)])
        };
        let report = run_determinism_suite_with_runner(&program, &[], 3, 8, runner)
            .expect("suite should complete without runner error");
        assert_eq!(report.iterations, 3);
        assert_eq!(report.threads, 8);
        assert!(report.divergences.is_empty());
    }
}