//! mumu-gpu 0.1.0
//!
//! GPU/Vulkan matrix and tensor operations for the mumu/lava language.
#![allow(dead_code, unused_imports)]
use mumu::parser::types::{Value, TensorValue, ElemType};

/// Applies `fun` pairwise to the elements of two tensors and returns the
/// result as a new `Float32` tensor with the same shape as `a`.
///
/// Both byte buffers are decoded as consecutive little-endian `f32` values,
/// and the output is encoded the same way.
///
/// # Errors
///
/// Returns `Err` when:
/// * the two shapes differ,
/// * either tensor's `elem_type` is not `Float32`,
/// * the byte buffers differ in length or are not a whole number of `f32`s.
pub fn elementwise_op(a: &TensorValue, b: &TensorValue, fun: fn(f32,f32)->f32) -> Result<Value, String> {
    const F32_SIZE: usize = std::mem::size_of::<f32>();

    if a.shape != b.shape {
        return Err("Shape mismatch in elementwise op".to_string());
    }
    // The bytes are reinterpreted as f32 below, so any other element type
    // would silently produce garbage labelled as Float32.
    for tensor in [a, b] {
        if tensor.elem_type != ElemType::Float32 {
            return Err(format!("Expected Float32 Tensor, got {:?}", tensor.elem_type));
        }
    }
    // Matching shapes do not guarantee matching buffers; reject malformed
    // data up front instead of panicking on an out-of-range slice.
    if a.data.len() != b.data.len() || a.data.len() % F32_SIZE != 0 {
        return Err(format!(
            "Malformed tensor data in elementwise op: byte lengths {} and {} must be equal and a multiple of {}",
            a.data.len(),
            b.data.len(),
            F32_SIZE
        ));
    }

    let mut out = Vec::with_capacity(a.data.len());
    let pairs = a.data.chunks_exact(F32_SIZE).zip(b.data.chunks_exact(F32_SIZE));
    for (lhs, rhs) in pairs {
        // `chunks_exact` yields exactly F32_SIZE bytes per chunk, which is
        // what `f32_from_bytes` requires.
        let r = fun(f32_from_bytes(lhs), f32_from_bytes(rhs));
        out.extend_from_slice(&r.to_le_bytes());
    }

    Ok(Value::Tensor(TensorValue {
        elem_type: ElemType::Float32,
        shape: a.shape.clone(),
        data: out,
    }))
}

/// Decodes a little-endian `f32` from a 4-byte slice.
///
/// # Panics
///
/// Panics if `bytes` is not exactly 4 bytes long.
fn f32_from_bytes(bytes: &[u8]) -> f32 {
    let mut word = [0u8; 4];
    // `copy_from_slice` enforces the 4-byte length requirement.
    word.copy_from_slice(bytes);
    let bits = u32::from_le_bytes(word);
    f32::from_bits(bits)
}

/// Borrows the tensor inside `v` if it is a `Float32` tensor.
///
/// # Errors
///
/// Returns `Err` when `v` is a tensor with a different element type (the
/// message includes that type), or when `v` is not a tensor at all.
pub fn ensure_f32_tensor(v: &Value) -> Result<&TensorValue, String> {
    match v {
        Value::Tensor(tensor) if tensor.elem_type != ElemType::Float32 => Err(format!(
            "Expected Float32 Tensor, got {:?}",
            tensor.elem_type
        )),
        Value::Tensor(tensor) => Ok(tensor),
        _ => Err(String::from("Argument must be a Float32 Tensor")),
    }
}