#![allow(dead_code)]
use mumu::parser::types::{Value, TensorValue, ElemType};
/// Multiplies two rank-2 tensors on the CPU: `a` (m×k) · `b` (k×n) → m×n.
///
/// Both inputs are read as row-major, little-endian `f32` values (the inputs'
/// `elem_type` is not inspected). The result is a new `ElemType::Float32`
/// tensor of shape `[m, n]` in the same layout.
///
/// # Errors
///
/// Returns `Err` if either input is not rank 2, or if the inner dimensions differ.
/// Also returns `Err` if either byte buffer is shorter than its declared shape
/// requires, or if the result size would overflow `usize`.
pub fn cpu_matrix_multiply(a: &TensorValue, b: &TensorValue) -> Result<Value, String> {
    if a.shape.len() != 2 || b.shape.len() != 2 {
        return Err("gpu:multiply => rank must be 2".to_string());
    }
    let m = a.shape[0];
    let k1 = a.shape[1];
    let k2 = b.shape[0];
    let n = b.shape[1];
    if k1 != k2 {
        return Err(format!("Dimension mismatch: left is {}×{}, right is {}×{}", m, k1, k2, n));
    }

    // Number of bytes a rows×cols matrix of 4-byte floats occupies; `None` on overflow.
    let bytes_for = |rows: usize, cols: usize| rows.checked_mul(cols).and_then(|e| e.checked_mul(4));

    // Validate buffers up front so malformed tensors yield `Err` instead of a
    // slice-index panic in the middle of the multiplication.
    match bytes_for(m, k1) {
        Some(len) if a.data.len() >= len => {}
        _ => {
            return Err(format!(
                "gpu:multiply => left tensor has {} bytes, too few for shape {}×{}",
                a.data.len(), m, k1
            ))
        }
    }
    match bytes_for(k2, n) {
        Some(len) if b.data.len() >= len => {}
        _ => {
            return Err(format!(
                "gpu:multiply => right tensor has {} bytes, too few for shape {}×{}",
                b.data.len(), k2, n
            ))
        }
    }
    let out_len = bytes_for(m, n)
        .ok_or_else(|| format!("gpu:multiply => result shape {}×{} is too large", m, n))?;

    let mut out = Vec::with_capacity(out_len);
    for row in 0..m {
        // Row `row` of `a` is contiguous in row-major order: k1 consecutive f32s.
        let a_row = &a.data[row * k1 * 4..(row + 1) * k1 * 4];
        for col in 0..n {
            let mut sum = 0f32;
            for (x, a_bytes) in a_row.chunks_exact(4).enumerate() {
                // Column `col` of `b` is strided by n elements.
                let b_index = (x * n + col) * 4;
                sum += f32_from_bytes(a_bytes) * f32_from_bytes(&b.data[b_index..b_index + 4]);
            }
            out.extend_from_slice(&sum.to_le_bytes());
        }
    }
    Ok(Value::Tensor(TensorValue {
        elem_type: ElemType::Float32,
        shape: vec![m, n],
        data: out,
    }))
}
/// Decodes a single little-endian `f32` from a slice of exactly four bytes.
///
/// # Panics
///
/// Panics if `bytes` is not exactly 4 bytes long.
fn f32_from_bytes(bytes: &[u8]) -> f32 {
    match *bytes {
        [b0, b1, b2, b3] => f32::from_le_bytes([b0, b1, b2, b3]),
        _ => panic!("f32_from_bytes: expected 4 bytes, got {}", bytes.len()),
    }
}
/// Inverts a 2×2 matrix on the CPU using the closed form
/// `inv([[a, b], [c, d]]) = (1 / det) * [[d, -b], [-c, a]]`, where `det = a*d - b*c`.
///
/// The four input elements are read as row-major, little-endian `f32`. The
/// result is a new `ElemType::Float32` tensor of shape `[2, 2]`.
///
/// # Errors
///
/// Returns `Err` if the tensor's shape is not `[2, 2]`, or if its buffer holds
/// fewer than 16 bytes. Also returns `Err` if `|det| < 1e-12`.
pub fn cpu_inverse_2x2(t: &TensorValue) -> Result<Value, String> {
    // Without this check any tensor would be "inverted" using just its first four elements.
    if t.shape.len() != 2 || t.shape[0] != 2 || t.shape[1] != 2 {
        return Err(format!("Inverse requires shape [2, 2], got {:?}", t.shape));
    }
    // Guard the fixed-offset slices below, which would otherwise panic.
    if t.data.len() < 16 {
        return Err(format!(
            "Inverse requires 16 bytes of data for a 2x2 f32 matrix, got {}",
            t.data.len()
        ));
    }
    let a = f32_from_bytes(&t.data[0..4]);
    let b = f32_from_bytes(&t.data[4..8]);
    let c = f32_from_bytes(&t.data[8..12]);
    let d = f32_from_bytes(&t.data[12..16]);
    let det = a*d - b*c;
    if det.abs() < 1e-12 {
        return Err("Det is near zero => cannot invert".to_string());
    }
    let invd = 1.0 / det;
    let ra = d * invd;
    let rb = -b * invd;
    let rc = -c * invd;
    let rd = a * invd;
    let mut out = Vec::with_capacity(16);
    out.extend_from_slice(&ra.to_le_bytes());
    out.extend_from_slice(&rb.to_le_bytes());
    out.extend_from_slice(&rc.to_le_bytes());
    out.extend_from_slice(&rd.to_le_bytes());
    Ok(Value::Tensor(TensorValue {
        elem_type: ElemType::Float32,
        shape: vec![2,2],
        data: out,
    }))
}
/// Transposes a rank-2 tensor of 4-byte elements on the CPU.
///
/// Element `(r, c)` of the row-major input becomes element `(c, r)` of the
/// output, whose shape is `[cols, rows]`. Elements are moved as raw 4-byte
/// groups without being decoded. The result is always tagged
/// `ElemType::Float32`, whatever the input's `elem_type`.
///
/// # Errors
///
/// Returns `Err` if the tensor is not rank 2, or if its byte buffer is shorter
/// than `rows * cols * 4`. Also returns `Err` if that size overflows `usize`.
pub fn cpu_transpose_2d(t: &TensorValue) -> Result<Value, String> {
    if t.shape.len() != 2 {
        return Err("Rank must be 2".to_string());
    }
    let rows = t.shape[0];
    let cols = t.shape[1];
    let needed = rows
        .checked_mul(cols)
        .and_then(|elems| elems.checked_mul(4))
        .ok_or_else(|| format!("Transpose shape {}x{} is too large", rows, cols))?;
    // Reject short buffers up front instead of panicking on an out-of-range slice.
    if t.data.len() < needed {
        return Err(format!(
            "Tensor data is {} bytes but shape {}x{} needs {}",
            t.data.len(), rows, cols, needed
        ));
    }
    // Size the output from the shape, not the input buffer, so it holds exactly
    // cols*rows elements with no stray trailing bytes.
    let mut out = vec![0u8; needed];
    for r in 0..rows {
        for c in 0..cols {
            let src_i = (r * cols + c) * 4;
            let dst_i = (c * rows + r) * 4;
            out[dst_i..dst_i + 4].copy_from_slice(&t.data[src_i..src_i + 4]);
        }
    }
    Ok(Value::Tensor(TensorValue {
        elem_type: ElemType::Float32,
        shape: vec![cols, rows],
        data: out,
    }))
}
pub fn cpu_reduce_sum(t: &TensorValue) -> Result<Value, String> {
let mut accum = 0f32;
for chunk in t.data.chunks_exact(4) {
accum += f32_from_bytes(chunk);
}
let out = accum.to_le_bytes();
Ok(Value::Tensor(TensorValue {
elem_type: ElemType::Float32,
shape: vec![1],
data: out.to_vec(),
}))
}
pub fn cpu_scale_tensor(t: &TensorValue, scalar: f32) -> Result<Value, String> {
let mut out = Vec::with_capacity(t.data.len());
for chunk in t.data.chunks_exact(4) {
let x = f32_from_bytes(chunk);
let s = x * scalar;
out.extend_from_slice(&s.to_le_bytes());
}
Ok(Value::Tensor(TensorValue {
elem_type: ElemType::Float32,
shape: t.shape.clone(),
data: out,
}))
}
pub fn perform_compute_multiply(
_ctx: &crate::vulkan::AshVulkanContext,
a: &TensorValue,
b: &TensorValue
) -> Result<Value, String> {
cpu_matrix_multiply(a, b)
}
/// Element-wise sum `a + b` of two tensors with identical shapes.
///
/// `_ctx` is accepted for API parity but is not used. The computation is
/// delegated to `crate::operators::elementwise_op`.
///
/// # Errors
///
/// Returns `Err` if the shapes differ. Otherwise returns whatever
/// `elementwise_op` produces.
pub fn perform_compute_add(
    _ctx: &crate::vulkan::AshVulkanContext,
    a: &TensorValue,
    b: &TensorValue
) -> Result<Value, String> {
    if a.shape == b.shape {
        crate::operators::elementwise_op(a, b, |lhs, rhs| lhs + rhs)
    } else {
        Err("Shape mismatch in gpu:add".to_string())
    }
}
/// Element-wise difference `a - b` of two tensors with identical shapes.
///
/// `_ctx` is accepted for API parity but is not used. The computation is
/// delegated to `crate::operators::elementwise_op`.
///
/// # Errors
///
/// Returns `Err` if the shapes differ. Otherwise returns whatever
/// `elementwise_op` produces.
pub fn perform_compute_subtract(
    _ctx: &crate::vulkan::AshVulkanContext,
    a: &TensorValue,
    b: &TensorValue
) -> Result<Value, String> {
    if a.shape == b.shape {
        crate::operators::elementwise_op(a, b, |lhs, rhs| lhs - rhs)
    } else {
        Err("Shape mismatch in gpu:subtract".to_string())
    }
}
/// Hadamard (element-wise) product `a ⊙ b` of two tensors with identical shapes.
///
/// `_ctx` is accepted for API parity but is not used. The computation is
/// delegated to `crate::operators::elementwise_op`.
///
/// # Errors
///
/// Returns `Err` if the shapes differ. Otherwise returns whatever
/// `elementwise_op` produces.
pub fn perform_compute_hadamard(
    _ctx: &crate::vulkan::AshVulkanContext,
    a: &TensorValue,
    b: &TensorValue
) -> Result<Value, String> {
    if a.shape == b.shape {
        crate::operators::elementwise_op(a, b, |lhs, rhs| lhs * rhs)
    } else {
        Err("Shape mismatch in gpu:hadamard".to_string())
    }
}
/// Transposes a rank-2 tensor, exposed under the GPU entry-point signature.
///
/// `_ctx` is accepted for API parity but is not used. The rank is checked
/// here so the caller gets a transpose-specific message. The rest is handled
/// by `cpu_transpose_2d`.
///
/// # Errors
///
/// Returns `Err` if `t` is not rank 2, or if `cpu_transpose_2d` fails.
pub fn perform_compute_transpose(
    _ctx: &crate::vulkan::AshVulkanContext,
    t: &TensorValue
) -> Result<Value, String> {
    match t.shape.len() {
        2 => cpu_transpose_2d(t),
        _ => Err("Rank must be 2 for transpose".to_string()),
    }
}