mumugpu/operators/
elementwise.rs1#![allow(dead_code, unused_imports)]
2use core_mumu::parser::types::Value;
3
4pub fn ensure_float2d(v: &Value) -> Result<&Vec<Vec<f64>>, String> {
6 match v {
7 Value::Float2DArray(rows) => Ok(rows),
8 _ => Err("Argument must be a Float2DArray".to_string()),
9 }
10}
11
12pub fn elementwise_op(
15 a: &Vec<Vec<f64>>,
16 b: &Vec<Vec<f64>>,
17 fun: fn(f64, f64) -> f64,
18) -> Result<Value, String> {
19 if a.len() != b.len() {
20 return Err("Shape mismatch in elementwise op (row count)".to_string());
21 }
22 if let (Some(ar0), Some(br0)) = (a.get(0), b.get(0)) {
23 if ar0.len() != br0.len() {
24 return Err("Shape mismatch in elementwise op (col count)".to_string());
25 }
26 }
27 for (i, (ra, rb)) in a.iter().zip(b.iter()).enumerate() {
28 if ra.len() != rb.len() {
29 return Err(format!(
30 "Shape mismatch in elementwise op at row {} ({} vs {})",
31 i,
32 ra.len(),
33 rb.len()
34 ));
35 }
36 }
37
38 let mut out = Vec::with_capacity(a.len());
39 for (ra, rb) in a.iter().zip(b.iter()) {
40 let mut row = Vec::with_capacity(ra.len());
41 for (xa, xb) in ra.iter().zip(rb.iter()) {
42 row.push(fun(*xa, *xb));
43 }
44 out.push(row);
45 }
46 Ok(Value::Float2DArray(out))
47}