mumumatrix/
lib.rs

1// matrix/src/lib.rs
2
3use mumu::{
4    parser::interpreter::Interpreter,
5    parser::types::{Value, FunctionValue},
6};
7use std::sync::{Arc, Mutex};
8
9fn matrix_add(_interp: &mut Interpreter, args: Vec<Value>) -> Result<Value, String> {
10    if args.len() != 2 {
11        return Err("matrix:add expects 2 arguments".to_string());
12    }
13    match (&args[0], &args[1]) {
14        (Value::Int2DArray(xs), Value::Int2DArray(ys)) if xs.len() == ys.len() && xs[0].len() == ys[0].len() => {
15            let result: Vec<Vec<i32>> = xs.iter().zip(ys.iter())
16                .map(|(row_x, row_y)| row_x.iter().zip(row_y.iter()).map(|(a, b)| a + b).collect())
17                .collect();
18            Ok(Value::Int2DArray(result))
19        }
20        (Value::Float2DArray(xs), Value::Float2DArray(ys)) if xs.len() == ys.len() && xs[0].len() == ys[0].len() => {
21            let result: Vec<Vec<f64>> = xs.iter().zip(ys.iter())
22                .map(|(row_x, row_y)| row_x.iter().zip(row_y.iter()).map(|(a, b)| a + b).collect())
23                .collect();
24            Ok(Value::Float2DArray(result))
25        }
26        (Value::Int2DArray(xs), Value::Float2DArray(ys)) if xs.len() == ys.len() && xs[0].len() == ys[0].len() => {
27            let result: Vec<Vec<f64>> = xs.iter().zip(ys.iter())
28                .map(|(row_x, row_y)| row_x.iter().zip(row_y.iter()).map(|(a, b)| *a as f64 + b).collect())
29                .collect();
30            Ok(Value::Float2DArray(result))
31        }
32        (Value::Float2DArray(xs), Value::Int2DArray(ys)) if xs.len() == ys.len() && xs[0].len() == ys[0].len() => {
33            let result: Vec<Vec<f64>> = xs.iter().zip(ys.iter())
34                .map(|(row_x, row_y)| row_x.iter().zip(row_y.iter()).map(|(a, b)| a + *b as f64).collect())
35                .collect();
36            Ok(Value::Float2DArray(result))
37        }
38        _ => Err("Not numeric".to_string()),
39    }
40}
41
42fn matrix_subtract(_interp: &mut Interpreter, args: Vec<Value>) -> Result<Value, String> {
43    if args.len() != 2 {
44        return Err("matrix:subtract expects 2 arguments".to_string());
45    }
46    match (&args[0], &args[1]) {
47        (Value::Int2DArray(xs), Value::Int2DArray(ys)) if xs.len() == ys.len() && xs[0].len() == ys[0].len() => {
48            let result: Vec<Vec<i32>> = xs.iter().zip(ys.iter())
49                .map(|(row_x, row_y)| row_x.iter().zip(row_y.iter()).map(|(a, b)| a - b).collect())
50                .collect();
51            Ok(Value::Int2DArray(result))
52        }
53        (Value::Float2DArray(xs), Value::Float2DArray(ys)) if xs.len() == ys.len() && xs[0].len() == ys[0].len() => {
54            let result: Vec<Vec<f64>> = xs.iter().zip(ys.iter())
55                .map(|(row_x, row_y)| row_x.iter().zip(row_y.iter()).map(|(a, b)| a - b).collect())
56                .collect();
57            Ok(Value::Float2DArray(result))
58        }
59        (Value::Int2DArray(xs), Value::Float2DArray(ys)) if xs.len() == ys.len() && xs[0].len() == ys[0].len() => {
60            let result: Vec<Vec<f64>> = xs.iter().zip(ys.iter())
61                .map(|(row_x, row_y)| row_x.iter().zip(row_y.iter()).map(|(a, b)| *a as f64 - b).collect())
62                .collect();
63            Ok(Value::Float2DArray(result))
64        }
65        (Value::Float2DArray(xs), Value::Int2DArray(ys)) if xs.len() == ys.len() && xs[0].len() == ys[0].len() => {
66            let result: Vec<Vec<f64>> = xs.iter().zip(ys.iter())
67                .map(|(row_x, row_y)| row_x.iter().zip(row_y.iter()).map(|(a, b)| a - *b as f64).collect())
68                .collect();
69            Ok(Value::Float2DArray(result))
70        }
71        _ => Err("Not numeric".to_string()),
72    }
73}
74
75fn matrix_multiply(_interp: &mut Interpreter, args: Vec<Value>) -> Result<Value, String> {
76    if args.len() != 2 {
77        return Err("matrix:multiply expects 2 arguments".to_string());
78    }
79    match (&args[0], &args[1]) {
80        (Value::Int2DArray(a), Value::Int2DArray(b)) => matrix_mul_int(a, b),
81        (Value::Float2DArray(a), Value::Float2DArray(b)) => matrix_mul_float(a, b),
82        (Value::Int2DArray(a), Value::Float2DArray(b)) => matrix_mul_float2d_cast(a, b),
83        (Value::Float2DArray(a), Value::Int2DArray(b)) => matrix_mul_float2d_cast_rev(a, b),
84        _ => Err("Not numeric".to_string()),
85    }
86}
87
88fn matrix_mul_int(a: &Vec<Vec<i32>>, b: &Vec<Vec<i32>>) -> Result<Value, String> {
89    let n = a.len();
90    let m = a[0].len();
91    let p = b[0].len();
92    if b.len() != m {
93        return Err(format!("dimension mismatch: left={}x{}, right={}x{}", n, m, b.len(), p));
94    }
95    let mut result = vec![vec![0i32; p]; n];
96    for i in 0..n {
97        for j in 0..p {
98            for k in 0..m {
99                result[i][j] += a[i][k] * b[k][j];
100            }
101        }
102    }
103    Ok(Value::Int2DArray(result))
104}
105
106fn matrix_mul_float(a: &Vec<Vec<f64>>, b: &Vec<Vec<f64>>) -> Result<Value, String> {
107    let n = a.len();
108    let m = a[0].len();
109    let p = b[0].len();
110    if b.len() != m {
111        return Err(format!("dimension mismatch: left={}x{}, right={}x{}", n, m, b.len(), p));
112    }
113    let mut result = vec![vec![0.0f64; p]; n];
114    for i in 0..n {
115        for j in 0..p {
116            for k in 0..m {
117                result[i][j] += a[i][k] * b[k][j];
118            }
119        }
120    }
121    Ok(Value::Float2DArray(result))
122}
123
124fn matrix_mul_float2d_cast(a: &Vec<Vec<i32>>, b: &Vec<Vec<f64>>) -> Result<Value, String> {
125    let n = a.len();
126    let m = a[0].len();
127    let p = b[0].len();
128    if b.len() != m {
129        return Err(format!("dimension mismatch: left={}x{}, right={}x{}", n, m, b.len(), p));
130    }
131    let mut result = vec![vec![0.0f64; p]; n];
132    for i in 0..n {
133        for j in 0..p {
134            for k in 0..m {
135                result[i][j] += a[i][k] as f64 * b[k][j];
136            }
137        }
138    }
139    Ok(Value::Float2DArray(result))
140}
141
142fn matrix_mul_float2d_cast_rev(a: &Vec<Vec<f64>>, b: &Vec<Vec<i32>>) -> Result<Value, String> {
143    let n = a.len();
144    let m = a[0].len();
145    let p = b[0].len();
146    if b.len() != m {
147        return Err(format!("dimension mismatch: left={}x{}, right={}x{}", n, m, b.len(), p));
148    }
149    let mut result = vec![vec![0.0f64; p]; n];
150    for i in 0..n {
151        for j in 0..p {
152            for k in 0..m {
153                result[i][j] += a[i][k] * b[k][j] as f64;
154            }
155        }
156    }
157    Ok(Value::Float2DArray(result))
158}
159
/// Returns the transpose of `m`: `out[c][r] == m[r][c]`.
///
/// The column count is taken from the first row, and a matrix with no rows
/// transposes to an empty matrix. Accepts `&[Vec<T>]`, so `&Vec<Vec<T>>`
/// callers still work through deref coercion.
///
/// # Panics
/// Panics if any row is shorter than the first row. Cells beyond the first
/// row's length in longer rows are ignored. Callers should reject ragged
/// input beforehand.
fn transpose2d<T: Clone>(m: &[Vec<T>]) -> Vec<Vec<T>> {
    let cols = m.first().map_or(0, Vec::len);
    (0..cols)
        .map(|c| m.iter().map(|row| row[c].clone()).collect())
        .collect()
}
167
168fn matrix_transpose(_interp: &mut Interpreter, args: Vec<Value>) -> Result<Value, String> {
169    if args.len() != 1 {
170        return Err("matrix:transpose expects 1 argument".to_string());
171    }
172    match &args[0] {
173        Value::Int2DArray(xs) => Ok(Value::Int2DArray(transpose2d(xs))),
174        Value::Float2DArray(xs) => Ok(Value::Float2DArray(transpose2d(xs))),
175        _ => Err("Not numeric".to_string()),
176    }
177}
178
179#[no_mangle]
180pub unsafe extern "C" fn Cargo_lock(
181    interp_ptr: *mut std::ffi::c_void,
182    _extra_str: *const std::ffi::c_void,
183) -> i32 {
184    let interp = &mut *(interp_ptr as *mut Interpreter);
185
186    macro_rules! reg {
187        ($name:expr, $f:expr) => {{
188            let func = Arc::new(Mutex::new($f));
189            interp.register_dynamic_function($name, func);
190            interp.set_variable($name, Value::Function(Box::new(FunctionValue::Named($name.into()))));
191        }};
192    }
193    reg!("matrix:add", matrix_add);
194    reg!("matrix:subtract", matrix_subtract);
195    reg!("matrix:multiply", matrix_multiply);
196    reg!("matrix:transpose", matrix_transpose);
197
198    0
199}