use std::mem;
use bytemuck::Pod;
use smallvec::SmallVec;
use vyre_foundation::ir::Program;
use crate::backend::{BackendError, DispatchConfig, VyreBackend};
pub trait TypedDispatchExt: VyreBackend {
fn dispatch_bytes(
&self,
program: &Program,
inputs: &[&[u8]],
config: &DispatchConfig,
) -> Result<Vec<Vec<u8>>, BackendError> {
self.dispatch_borrowed(program, inputs, config)
}
fn dispatch_pod<T: Pod>(
&self,
program: &Program,
inputs: &[&[T]],
config: &DispatchConfig,
) -> Result<Vec<Vec<T>>, BackendError> {
let byte_inputs: SmallVec<[&[u8]; 8]> = inputs
.iter()
.map(|input| bytemuck::cast_slice::<T, u8>(input))
.collect();
let outputs = self.dispatch_borrowed(program, &byte_inputs, config)?;
decode_pod_outputs(outputs)
}
fn dispatch_u32(
&self,
program: &Program,
inputs: &[&[u32]],
config: &DispatchConfig,
) -> Result<Vec<Vec<u32>>, BackendError> {
self.dispatch_pod(program, inputs, config)
}
fn dispatch_f32(
&self,
program: &Program,
inputs: &[&[f32]],
config: &DispatchConfig,
) -> Result<Vec<Vec<f32>>, BackendError> {
self.dispatch_pod(program, inputs, config)
}
}
/// Blanket implementation: every `VyreBackend`, sized or unsized, gets the
/// typed dispatch helpers through the trait's provided methods.
impl<B> TypedDispatchExt for B where B: VyreBackend + ?Sized {}
/// Decodes each raw output buffer into a vector of `T` values.
///
/// Buffers are decoded in order; the first buffer that fails to decode aborts
/// the whole conversion and its error is returned.
///
/// # Errors
///
/// Returns `BackendError::InvalidProgram` when `T` is zero-sized, or when any
/// buffer's length is not a multiple of `size_of::<T>()`.
fn decode_pod_outputs<T: Pod>(outputs: Vec<Vec<u8>>) -> Result<Vec<Vec<T>>, BackendError> {
    let element_size = mem::size_of::<T>();
    // A zero-sized element would make the per-buffer length check divide by
    // zero, so it is rejected before any buffer is examined.
    if element_size == 0 {
        return Err(BackendError::InvalidProgram {
            fix: "Fix: typed dispatch does not support zero-sized POD outputs.".to_string(),
        });
    }

    let mut decoded = Vec::with_capacity(outputs.len());
    for (position, raw) in outputs.into_iter().enumerate() {
        decoded.push(decode_pod_output::<T>(position, raw, element_size)?);
    }
    Ok(decoded)
}
/// Decodes one raw output buffer into `T` values.
///
/// `index` is only used to identify the buffer in the error message. `width`
/// is the element size in bytes; the caller in this file passes
/// `size_of::<T>()` after checking that it is non-zero. A `width` of zero
/// would panic on the remainder computation below.
///
/// # Errors
///
/// Returns `BackendError::InvalidProgram` when `bytes.len()` is not a multiple
/// of `width`.
fn decode_pod_output<T: Pod>(
    index: usize,
    bytes: Vec<u8>,
    width: usize,
) -> Result<Vec<T>, BackendError> {
    // A trailing partial element means the buffer cannot be a sequence of `T`.
    if bytes.len() % width != 0 {
        return Err(BackendError::InvalidProgram {
            fix: format!(
                "Fix: output buffer {index} has {} bytes, which is not a whole number of {}-byte typed values.",
                bytes.len(),
                width
            ),
        });
    }

    // Each value is read with an unaligned read, so the byte buffer does not
    // need to satisfy `T`'s alignment.
    let mut values = Vec::with_capacity(bytes.len() / width);
    for chunk in bytes.chunks_exact(width) {
        values.push(bytemuck::pod_read_unaligned::<T>(chunk));
    }
    Ok(values)
}
#[cfg(test)]
mod tests {
    use std::collections::HashSet;
    use std::sync::OnceLock;

    use vyre_foundation::ir::{OpId, Program};

    use super::*;
    use crate::backend::private;

    /// Test backend whose `dispatch` returns a copy of its inputs as outputs.
    struct EchoBackend;

    impl private::Sealed for EchoBackend {}

    impl VyreBackend for EchoBackend {
        fn id(&self) -> &'static str {
            "typed-dispatch-test"
        }

        fn supported_ops(&self) -> &HashSet<OpId> {
            // Lazily initialised empty set shared by every call.
            static EMPTY_OPS: OnceLock<HashSet<OpId>> = OnceLock::new();
            EMPTY_OPS.get_or_init(HashSet::new)
        }

        fn dispatch(
            &self,
            _program: &Program,
            inputs: &[Vec<u8>],
            _config: &DispatchConfig,
        ) -> Result<Vec<Vec<u8>>, BackendError> {
            Ok(inputs.to_vec())
        }
    }

    /// Round-trips `u32` inputs through the echo backend and expects the same
    /// values back after the byte cast and decode.
    #[test]
    fn dispatch_u32_packs_inputs_and_decodes_outputs() {
        let echo = EchoBackend;
        let words = [1u32, 2, 0x0102_0304];
        let result = echo.dispatch_u32(&Program::empty(), &[&words], &DispatchConfig::default());
        let decoded = match result {
            Ok(decoded) => decoded,
            Err(error) => panic!("typed u32 dispatch must succeed: {error}"),
        };
        assert_eq!(decoded, vec![words.to_vec()]);
    }

    /// A three-byte buffer cannot hold whole `u32` values and must be rejected
    /// with the width-specific error message.
    #[test]
    fn typed_decode_rejects_partial_words() {
        let truncated = vec![vec![1u8, 2, 3]];
        let error =
            decode_pod_outputs::<u32>(truncated).expect_err("partial u32 output must fail");
        let message = error.to_string();
        assert!(
            message.contains("whole number of 4-byte"),
            "malformed typed output must produce actionable width error: {error}"
        );
    }
}