mumugpu/operators/elementwise.rs

1#![allow(dead_code, unused_imports)]
2use core_mumu::parser::types::Value;
3
4/// Ensure the value is a Float2DArray and return a reference to it.
5pub fn ensure_float2d(v: &Value) -> Result<&Vec<Vec<f64>>, String> {
6    match v {
7        Value::Float2DArray(rows) => Ok(rows),
8        _ => Err("Argument must be a Float2DArray".to_string()),
9    }
10}
11
12/// Elementwise operation on two Float2DArray values with identical shapes.
13/// Returns a new Float2DArray with the same shape.
14pub fn elementwise_op(
15    a: &Vec<Vec<f64>>,
16    b: &Vec<Vec<f64>>,
17    fun: fn(f64, f64) -> f64,
18) -> Result<Value, String> {
19    if a.len() != b.len() {
20        return Err("Shape mismatch in elementwise op (row count)".to_string());
21    }
22    if let (Some(ar0), Some(br0)) = (a.get(0), b.get(0)) {
23        if ar0.len() != br0.len() {
24            return Err("Shape mismatch in elementwise op (col count)".to_string());
25        }
26    }
27    for (i, (ra, rb)) in a.iter().zip(b.iter()).enumerate() {
28        if ra.len() != rb.len() {
29            return Err(format!(
30                "Shape mismatch in elementwise op at row {} ({} vs {})",
31                i,
32                ra.len(),
33                rb.len()
34            ));
35        }
36    }
37
38    let mut out = Vec::with_capacity(a.len());
39    for (ra, rb) in a.iter().zip(b.iter()) {
40        let mut row = Vec::with_capacity(ra.len());
41        for (xa, xb) in ra.iter().zip(rb.iter()) {
42            row.push(fun(*xa, *xb));
43        }
44        out.push(row);
45    }
46    Ok(Value::Float2DArray(out))
47}