use vyre::ir::Program;
use crate::spec::value::Value;
/// Runs `program` on the reference interpreter under an armed allocation-failure probe.
///
/// The probe (`crate::meta::oom::probe`) is asked to fail allocation number
/// `fail_at_nth`. That index must be non-zero; presumably it is 1-based, but confirm
/// against `probe`. Both the input conversion (`Value::to_reference_values`) and the
/// reference run happen inside the probed closure.
///
/// On a normal return, this yields whatever the reference run produced, converted back
/// with `Value::from_reference_values`.
///
/// # Errors
///
/// Returns `Err(String)` when:
/// - `fail_at_nth` is `0`. This is checked before any probing happens.
/// - The probe reports that the injected allocation failure fired.
/// - The closure panicked for a reason the probe did not attribute to the injection.
/// - The reference run itself returned an error. Its message is forwarded via `to_string`.
/// - The probe reports a normal return, but the closure never stored a result.
#[inline]
pub fn with_oom_injection(
    program: &Program,
    inputs: &[Value],
    fail_at_nth: usize,
) -> Result<Vec<Value>, String> {
    use crate::meta::oom::{probe, ProbeOutcome};

    if fail_at_nth == 0 {
        return Err(
            "OOM injection: fail_at_nth must be > 0. Fix: pass a positive allocation index."
                .to_string(),
        );
    }

    // The closure cannot return a value through `probe`, so it records the run's result
    // here for inspection once the probe reports how the closure ended.
    let mut captured = None;
    let report = probe(fail_at_nth, || {
        let converted = Value::to_reference_values(inputs);
        let run = vyre_reference::run(program, &converted);
        captured = Some(
            run.map(Value::from_reference_values)
                .map_err(|err| err.to_string()),
        );
    });

    match report.outcome {
        ProbeOutcome::OomInjected => Err(format!(
            "OOM injection triggered: allocation {fail_at_nth} failed"
        )),
        ProbeOutcome::Panicked(message) => Err(format!(
            "non-OOM panic during OOM injection: {message}. Fix: repair the reference panic before treating this as allocator failure."
        )),
        ProbeOutcome::Returned => captured.unwrap_or_else(|| {
            Err("OOM injection: reference run did not produce a result. Fix: keep probe closure side effects intact.".to_string())
        }),
    }
}
/// Checks that the wgpu backend refuses an adversarially large output size for `op`.
///
/// The function wraps `op`'s WGSL (`op.wgsl_fn`) with a one-workgroup dispatch
/// config. The workgroup size is `op.workgroup_size`, defaulting to 64. It then
/// dispatches the shader with `input` and an output size of `usize::MAX / 2`. The call
/// runs inside `catch_unwind` so that a backend panic is reported rather than
/// propagated.
///
/// Returns `Ok(())` when the dispatch returns an `Err`.
///
/// NOTE(review): any dispatch error counts as a pass, including ones unrelated to the
/// size guard (e.g. a shader failure). Confirm that this is intended.
///
/// # Errors
///
/// - No GPU adapter is available (`WgpuBackend::new()` returned `None`).
/// - The dispatch succeeded despite the adversarial size, meaning the OOM guard is missing.
/// - The dispatch panicked instead of returning a structured error.
#[inline]
pub fn run_gpu_oom_suite(op: &crate::OpSpec, input: &[u8]) -> Result<(), String> {
    use crate::pipeline::backend::wgpu::WgpuBackend;
    use crate::pipeline::backend::{wrap_shader, ConformDispatchConfig, WgslBackend};
    use std::panic::{catch_unwind, AssertUnwindSafe};

    // Half of the address space: a size the backend is expected to reject rather than
    // attempt to allocate.
    const ADVERSARIAL_OUTPUT_SIZE: usize = usize::MAX / 2;

    let backend = WgpuBackend::new()
        .ok_or_else(|| "Fix: no GPU adapter available for gpu oom suite.".to_string())?;

    let config = ConformDispatchConfig {
        workgroup_size: op.workgroup_size.unwrap_or(64),
        workgroup_count: 1,
        convention: op.convention,
        lookup_data: None,
        buffer_init: crate::spec::types::BufferInitPolicy::default(),
    };
    let source = (op.wgsl_fn)();
    let shader = wrap_shader(&source, &config);

    let outcome = catch_unwind(AssertUnwindSafe(|| {
        <WgpuBackend as WgslBackend>::dispatch(
            &backend,
            &shader,
            input,
            ADVERSARIAL_OUTPUT_SIZE,
            config,
        )
    }));

    match outcome {
        Err(_) => Err(format!(
            "Fix: op {} panicked during GPU OOM injection. GPU backend must return a structured error instead of panicking.",
            op.id
        )),
        Ok(Err(_)) => Ok(()),
        Ok(Ok(_)) => Err(format!(
            "Fix: op {} dispatched successfully with adversarial output_size {}. GPU OOM guard is missing.",
            op.id, ADVERSARIAL_OUTPUT_SIZE
        )),
    }
}
#[cfg(test)]
mod tests {
    use super::with_oom_injection;
    use vyre::ir::{Node, Program};

    /// Minimal `Program` fixture built from `(vec![], [1, 1, 1], vec![Node::Return])`.
    fn empty_program() -> Program {
        Program::new(vec![], [1, 1, 1], vec![Node::Return])
    }

    /// A zero allocation index must be rejected before any probing, with a message that
    /// names the violated precondition.
    #[test]
    fn with_oom_injection_rejects_zero_fail_at() {
        let message = match with_oom_injection(&empty_program(), &[], 0) {
            Ok(_) => panic!("expected an error for fail_at_nth == 0"),
            Err(message) => message,
        };
        assert!(
            message.contains("fail_at_nth must be > 0"),
            "must reject zero fail_at_nth"
        );
    }
}