#![allow(dead_code, unused_imports)]
use mumu::parser::types::{Value, TensorValue, ElemType};
/// Applies a binary `f32` function element by element to two tensors.
///
/// Both tensors' `data` buffers are decoded as packed little-endian `f32`
/// values (4 bytes each). The result is a new `Float32` tensor with `a`'s
/// shape, where element `i` is `fun(a[i], b[i])`.
///
/// This function does not inspect `elem_type`. The inputs are assumed to
/// already hold f32 data (see `ensure_f32_tensor` for that validation).
///
/// # Errors
///
/// Returns `Err` if:
/// - the shapes of `a` and `b` differ,
/// - the two data buffers have different byte lengths, or
/// - the data length is not a multiple of 4 bytes.
pub fn elementwise_op(
    a: &TensorValue,
    b: &TensorValue,
    fun: fn(f32, f32) -> f32,
) -> Result<Value, String> {
    if a.shape != b.shape {
        return Err("Shape mismatch in elementwise op".to_string());
    }
    // Equal shapes do not guarantee equal buffers for malformed tensors.
    // Check explicitly so we return an error instead of panicking on an
    // out-of-range slice.
    if a.data.len() != b.data.len() {
        return Err(format!(
            "Data length mismatch in elementwise op: {} vs {} bytes",
            a.data.len(),
            b.data.len()
        ));
    }
    // A trailing partial element would otherwise panic when sliced.
    if a.data.len() % 4 != 0 {
        return Err(format!(
            "Tensor data length {} is not a multiple of 4 bytes (f32)",
            a.data.len()
        ));
    }
    let mut out = Vec::with_capacity(a.data.len());
    // chunks_exact(4) yields exactly-4-byte slices, so f32_from_bytes cannot
    // panic here. The length checks above guarantee no remainder.
    for (a_bytes, b_bytes) in a.data.chunks_exact(4).zip(b.data.chunks_exact(4)) {
        let r = fun(f32_from_bytes(a_bytes), f32_from_bytes(b_bytes));
        out.extend_from_slice(&r.to_le_bytes());
    }
    Ok(Value::Tensor(TensorValue {
        elem_type: ElemType::Float32,
        shape: a.shape.clone(),
        data: out,
    }))
}
/// Decodes four little-endian bytes into an `f32`.
///
/// # Panics
///
/// Panics if `bytes.len() != 4` (the length check is done by
/// `copy_from_slice`).
fn f32_from_bytes(bytes: &[u8]) -> f32 {
    let mut word = [0u8; 4];
    word.copy_from_slice(bytes);
    // Reassemble the raw bit pattern, then reinterpret it as an IEEE-754 float.
    f32::from_bits(u32::from_le_bytes(word))
}
/// Borrows the tensor inside `v` if it is a `Float32` tensor.
///
/// # Errors
///
/// - If `v` is a tensor whose `elem_type` is not `Float32`, the error names
///   the actual element type.
/// - If `v` is not a tensor at all, a generic error is returned.
pub fn ensure_f32_tensor(v: &Value) -> Result<&TensorValue, String> {
    match v {
        Value::Tensor(tensor) if tensor.elem_type == ElemType::Float32 => Ok(tensor),
        Value::Tensor(tensor) => Err(format!(
            "Expected Float32 Tensor, got {:?}",
            tensor.elem_type
        )),
        _ => Err("Argument must be a Float32 Tensor".to_string()),
    }
}