use crate::OpSpec;
/// Builds the vyre IR program used to certify `spec` against a single `input`.
///
/// The program is taken from the spec itself (`spec.program()`) or, if the spec has none,
/// looked up in `vyre::ops::registry` by `spec.id`. The returned program has its entry op id
/// set to `spec.id`. Its output buffers are sized to hold as many bytes as the CPU reference
/// `(spec.cpu_fn)(input)` produces.
///
/// # Errors
///
/// Returns a `Fix: …` message when neither the spec nor the registry provides a program.
/// In that case the CPU reference is not run.
pub fn program_for_spec_input(spec: &OpSpec, input: &[u8]) -> Result<vyre::Program, String> {
    // Resolve the program before running the CPU reference, so that a missing registration
    // fails fast instead of first paying for a reference computation whose result is discarded.
    let program = spec
        .program()
        .or_else(|| vyre::ops::registry::lookup_program(spec.id))
        .ok_or_else(|| {
            format!(
                "Fix: OpSpec `{}` has no vyre IR program; register `ir_program` in the spec or add the op to vyre::ops::registry before certification.",
                spec.id
            )
        })?;
    // Only the byte length of the reference output is needed to size the GPU output buffers.
    let output_len = (spec.cpu_fn)(input).len();
    Ok(size_output_buffers(
        program.with_entry_op_id(spec.id),
        output_len,
    ))
}
/// Rebuilds `program` so that every output buffer can hold `output_size` bytes.
///
/// For each buffer whose `is_output()` is true, `count` is set to
/// `ceil(output_size / element_size)` elements. The count saturates at `u32::MAX` if it does
/// not fit. All other buffers are copied unchanged.
///
/// The rebuilt program reuses the original `workgroup_size` and a clone of the entry body. The
/// original `entry_op_id` is moved over as well.
fn size_output_buffers(mut program: vyre::Program, output_size: usize) -> vyre::Program {
    let resized_buffers: Vec<_> = program
        .buffers()
        .iter()
        .cloned()
        .map(|mut decl| {
            if decl.is_output() {
                // Guard against a zero element size so the division below can never panic.
                let bytes_per_element = element_size_bytes(decl.element()).max(1);
                let elements_needed = output_size.div_ceil(bytes_per_element);
                decl.count = elements_needed.try_into().unwrap_or(u32::MAX);
            }
            decl
        })
        .collect();
    let entry_body = (*program.entry).clone();
    let mut rebuilt = vyre::Program::new(resized_buffers, program.workgroup_size, entry_body);
    rebuilt.entry_op_id = program.entry_op_id.take();
    rebuilt
}
/// Number of bytes one element of `data_type` occupies when sizing a buffer.
///
/// `Bytes` is 1 byte per element, `U64`/`Vec2U32` are 8 and `Vec4U32` is 16. `Bool` and the
/// 32-bit scalars are 4. Any variant not listed here also falls back to 4 bytes.
fn element_size_bytes(data_type: vyre::ir::DataType) -> usize {
    use vyre::ir::DataType;

    match data_type {
        DataType::Bytes => 1,
        DataType::U64 | DataType::Vec2U32 => 8,
        DataType::Vec4U32 => 16,
        // NOTE(review): `Bool` is sized as a 4-byte word here, presumably because the GPU
        // buffers have no 1-byte bool representation — confirm against the backend.
        DataType::Bool | DataType::U32 | DataType::I32 | DataType::F32 => 4,
        _ => 4,
    }
}