use super::StreamingRunner;
use crate::generate::generators::InputGenerator;
use crate::spec::builder::BuildError;
use crate::{DataType, OpSignature, OpSpec};
use std::fs;
use std::sync::{Arc, Mutex};
use vyre::{BackendError, DispatchConfig, Program, VyreBackend};
/// Deterministic test generator: yields `count` cases named `case:<n>` whose
/// payload is `n` encoded as 8 little-endian bytes.
struct FixedGenerator {
    /// How many cases `generate` emits.
    count: u64,
}

impl InputGenerator for FixedGenerator {
    fn name(&self) -> &str {
        "fixed"
    }

    /// Accepts only the single-`U64`-input, `U64`-output signature.
    fn handles(&self, signature: &OpSignature) -> bool {
        let single_u64_input = signature.inputs == [DataType::U64];
        single_u64_input && signature.output == DataType::U64
    }

    /// Ignores both the signature and the seed; the case list depends only on
    /// `count`, which keeps test IDs stable across runs.
    fn generate(&self, _signature: &OpSignature, _seed: u64) -> Vec<(String, Vec<u8>)> {
        let mut cases = Vec::new();
        for idx in 0..self.count {
            let label = format!("case:{idx}");
            let payload = idx.to_le_bytes().to_vec();
            cases.push((label, payload));
        }
        cases
    }
}
/// Fake backend that echoes (up to) the first 8 bytes of its first input
/// buffer and records how many input buffers each `dispatch` call received.
#[derive(Default)]
struct RecordingBackend {
    /// `inputs.len()` of every `dispatch` call, in call order.
    batches: Mutex<Vec<usize>>,
    /// When set, the first output byte has all its bits flipped.
    corrupt: bool,
}

impl VyreBackend for RecordingBackend {
    fn id(&self) -> &'static str {
        "recording"
    }

    /// Returns a single output buffer holding a prefix of the first input
    /// buffer (empty when no inputs were supplied), optionally corrupted.
    fn dispatch(
        &self,
        _program: &Program,
        inputs: &[Vec<u8>],
        _config: &DispatchConfig,
    ) -> Result<Vec<Vec<u8>>, BackendError> {
        let source: &[u8] = match inputs.first() {
            Some(buffer) => buffer.as_slice(),
            None => &[],
        };
        let mut echoed = source[..source.len().min(8)].to_vec();
        if self.corrupt {
            if let Some(head) = echoed.first_mut() {
                // Flipping every bit guarantees the byte differs from the input.
                *head ^= 0xFF;
            }
        }
        // Tolerate poisoning so one panicking test cannot mask another's data.
        self.batches
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .push(inputs.len());
        Ok(vec![echoed])
    }
}
/// Builds the minimal `U64 -> U64` spec shared by the streaming tests.
///
/// The CPU reference echoes at most the first 8 input bytes. This is the same
/// truncation `RecordingBackend` applies to its first input buffer, so an
/// uncorrupted backend is expected to agree with it.
///
/// # Errors
///
/// Returns the builder's `BuildError` if the spec is rejected.
fn fixed_spec(id: &'static str) -> Result<OpSpec, BuildError> {
    OpSpec::builder(id)
        .signature(OpSignature {
            inputs: vec![DataType::U64],
            output: DataType::U64,
        })
        // Truncate instead of slicing `[..8]`: a short input then shows up as
        // a comparison failure rather than a panic inside the reference.
        .cpu_fn(|input| input[..input.len().min(8)].to_vec())
        .wgsl_fn(|| {
            "fn vyre_op(index: u32, input_len: u32) -> u32 { return input.data[index]; }"
                .to_string()
        })
        .ir_program(Some(vyre::Program::empty))
        .category(crate::Category::A {
            composition_of: vec![id],
        })
        .laws(vec![crate::spec::law::AlgebraicLaw::Bounded {
            lo: 0,
            hi: u32::MAX,
        }])
        .strictness(crate::spec::types::Strictness::Strict)
        .version(1)
        .workgroup_size(Some(1))
        .build()
}
/// With `skip(2)` and `shard(1, 2)` over ten fixed cases, the runner must pass
/// every case. The backend must then have recorded exactly 18 dispatches,
/// each carrying a single input buffer.
#[test]
fn sharding_and_skip_use_deterministic_global_test_ids() -> Result<(), String> {
    let spec = match fixed_spec("streaming.test.shard") {
        Ok(spec) => spec,
        Err(err) => return Err(format!("Fix: streaming fixture spec must build: {err:?}")),
    };
    let backend = RecordingBackend::default();

    let failures = StreamingRunner::new()
        .batch_size(3)
        .skip(2)
        .shard(1, 2)
        .with_generator(Box::new(FixedGenerator { count: 10 }))
        .run(&backend, &[spec]);
    assert!(failures.is_empty());

    let recorded = backend
        .batches
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    assert_eq!(*recorded, vec![1; 18]);
    Ok(())
}
/// The progress callback must see `tested`/`passed` counter snapshots at both
/// 4 and 8 when four fixed cases are run with a batch size of 4.
#[test]
fn progress_callback_reports_runner_counters() -> Result<(), String> {
    let spec = fixed_spec("streaming.test.progress")
        .map_err(|err| format!("Fix: streaming fixture spec must build: {err:?}"))?;
    let backend = RecordingBackend::default();

    // Shared sink: the callback appends, the test inspects afterwards.
    let progress = Arc::new(Mutex::new(Vec::new()));
    let sink = Arc::clone(&progress);
    let failures = StreamingRunner::new()
        .batch_size(4)
        .with_generator(Box::new(FixedGenerator { count: 4 }))
        .on_progress(move |snapshot| {
            let mut seen = sink.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
            seen.push(snapshot);
        })
        .run(&backend, &[spec]);
    assert!(failures.is_empty());

    let snapshots = progress
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    for (tested, passed) in [(4, 4), (8, 8)] {
        assert!(snapshots
            .iter()
            .any(|item| item.tested == tested && item.passed == passed));
    }
    Ok(())
}
/// A corrupting backend must produce failures. The test expects the runner to
/// persist exactly one `.bin` regression file under
/// `$CARGO_MANIFEST_DIR/regressions/<op_id>/`.
#[test]
fn failed_inputs_are_persisted_as_binary_regressions() -> Result<(), String> {
    /// Best-effort removal of the fixture directory however the test exits
    /// (success, early `?` return, or assertion panic). This keeps a failed run
    /// from leaving stale files inside the source tree.
    struct RemoveDirOnDrop<'a>(&'a std::path::Path);

    impl Drop for RemoveDirOnDrop<'_> {
        fn drop(&mut self) {
            // Ignored on purpose: this is only a fallback; the success path
            // removes the directory explicitly and reports errors.
            let _ = fs::remove_dir_all(self.0);
        }
    }

    let op_id = "streaming_test_failure";
    let dir = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
        .join("regressions")
        .join(op_id);
    // Start from a clean slate; a missing directory is the expected case.
    match fs::remove_dir_all(&dir) {
        Ok(()) => {}
        Err(error) if error.kind() == std::io::ErrorKind::NotFound => {}
        Err(error) => {
            return Err(format!(
                "Fix: streaming regression fixture cleanup must remove stale dir: {error}"
            ));
        }
    }
    let _cleanup = RemoveDirOnDrop(dir.as_path());

    let backend = RecordingBackend {
        batches: Mutex::new(Vec::new()),
        corrupt: true,
    };
    let failures = StreamingRunner::new()
        .batch_size(2)
        .with_generator(Box::new(FixedGenerator { count: 1 }))
        .run(
            &backend,
            &[fixed_spec(op_id)
                .map_err(|err| format!("Fix: streaming fixture spec must build: {err:?}"))?],
        );
    assert_eq!(
        failures.len(),
        2,
        "corrupting backend must produce exactly two failures"
    );

    let entries = fs::read_dir(&dir).map_err(|error| {
        format!("Fix: streaming regression fixture must create regression dir: {error}")
    })?;
    let mut bin_count = 0usize;
    for entry in entries {
        // Surface unreadable entries instead of silently skipping them.
        let entry = entry.map_err(|error| {
            format!("Fix: streaming regression fixture dir entry must be readable: {error}")
        })?;
        if entry.path().extension().and_then(|ext| ext.to_str()) == Some("bin") {
            bin_count += 1;
        }
    }
    assert_eq!(
        bin_count, 1,
        "exactly one .bin regression file must be persisted"
    );

    fs::remove_dir_all(&dir).map_err(|error| {
        format!("Fix: streaming regression fixture cleanup must remove regression dir: {error}")
    })?;
    Ok(())
}