1use mumu::{
4 parser::interpreter::Interpreter,
5 parser::types::{Value, FunctionValue},
6};
7use std::sync::{Arc, Mutex};
8
9fn matrix_add(_interp: &mut Interpreter, args: Vec<Value>) -> Result<Value, String> {
10 if args.len() != 2 {
11 return Err("matrix:add expects 2 arguments".to_string());
12 }
13 match (&args[0], &args[1]) {
14 (Value::Int2DArray(xs), Value::Int2DArray(ys)) if xs.len() == ys.len() && xs[0].len() == ys[0].len() => {
15 let result: Vec<Vec<i32>> = xs.iter().zip(ys.iter())
16 .map(|(row_x, row_y)| row_x.iter().zip(row_y.iter()).map(|(a, b)| a + b).collect())
17 .collect();
18 Ok(Value::Int2DArray(result))
19 }
20 (Value::Float2DArray(xs), Value::Float2DArray(ys)) if xs.len() == ys.len() && xs[0].len() == ys[0].len() => {
21 let result: Vec<Vec<f64>> = xs.iter().zip(ys.iter())
22 .map(|(row_x, row_y)| row_x.iter().zip(row_y.iter()).map(|(a, b)| a + b).collect())
23 .collect();
24 Ok(Value::Float2DArray(result))
25 }
26 (Value::Int2DArray(xs), Value::Float2DArray(ys)) if xs.len() == ys.len() && xs[0].len() == ys[0].len() => {
27 let result: Vec<Vec<f64>> = xs.iter().zip(ys.iter())
28 .map(|(row_x, row_y)| row_x.iter().zip(row_y.iter()).map(|(a, b)| *a as f64 + b).collect())
29 .collect();
30 Ok(Value::Float2DArray(result))
31 }
32 (Value::Float2DArray(xs), Value::Int2DArray(ys)) if xs.len() == ys.len() && xs[0].len() == ys[0].len() => {
33 let result: Vec<Vec<f64>> = xs.iter().zip(ys.iter())
34 .map(|(row_x, row_y)| row_x.iter().zip(row_y.iter()).map(|(a, b)| a + *b as f64).collect())
35 .collect();
36 Ok(Value::Float2DArray(result))
37 }
38 _ => Err("Not numeric".to_string()),
39 }
40}
41
42fn matrix_subtract(_interp: &mut Interpreter, args: Vec<Value>) -> Result<Value, String> {
43 if args.len() != 2 {
44 return Err("matrix:subtract expects 2 arguments".to_string());
45 }
46 match (&args[0], &args[1]) {
47 (Value::Int2DArray(xs), Value::Int2DArray(ys)) if xs.len() == ys.len() && xs[0].len() == ys[0].len() => {
48 let result: Vec<Vec<i32>> = xs.iter().zip(ys.iter())
49 .map(|(row_x, row_y)| row_x.iter().zip(row_y.iter()).map(|(a, b)| a - b).collect())
50 .collect();
51 Ok(Value::Int2DArray(result))
52 }
53 (Value::Float2DArray(xs), Value::Float2DArray(ys)) if xs.len() == ys.len() && xs[0].len() == ys[0].len() => {
54 let result: Vec<Vec<f64>> = xs.iter().zip(ys.iter())
55 .map(|(row_x, row_y)| row_x.iter().zip(row_y.iter()).map(|(a, b)| a - b).collect())
56 .collect();
57 Ok(Value::Float2DArray(result))
58 }
59 (Value::Int2DArray(xs), Value::Float2DArray(ys)) if xs.len() == ys.len() && xs[0].len() == ys[0].len() => {
60 let result: Vec<Vec<f64>> = xs.iter().zip(ys.iter())
61 .map(|(row_x, row_y)| row_x.iter().zip(row_y.iter()).map(|(a, b)| *a as f64 - b).collect())
62 .collect();
63 Ok(Value::Float2DArray(result))
64 }
65 (Value::Float2DArray(xs), Value::Int2DArray(ys)) if xs.len() == ys.len() && xs[0].len() == ys[0].len() => {
66 let result: Vec<Vec<f64>> = xs.iter().zip(ys.iter())
67 .map(|(row_x, row_y)| row_x.iter().zip(row_y.iter()).map(|(a, b)| a - *b as f64).collect())
68 .collect();
69 Ok(Value::Float2DArray(result))
70 }
71 _ => Err("Not numeric".to_string()),
72 }
73}
74
75fn matrix_multiply(_interp: &mut Interpreter, args: Vec<Value>) -> Result<Value, String> {
76 if args.len() != 2 {
77 return Err("matrix:multiply expects 2 arguments".to_string());
78 }
79 match (&args[0], &args[1]) {
80 (Value::Int2DArray(a), Value::Int2DArray(b)) => matrix_mul_int(a, b),
81 (Value::Float2DArray(a), Value::Float2DArray(b)) => matrix_mul_float(a, b),
82 (Value::Int2DArray(a), Value::Float2DArray(b)) => matrix_mul_float2d_cast(a, b),
83 (Value::Float2DArray(a), Value::Int2DArray(b)) => matrix_mul_float2d_cast_rev(a, b),
84 _ => Err("Not numeric".to_string()),
85 }
86}
87
88fn matrix_mul_int(a: &Vec<Vec<i32>>, b: &Vec<Vec<i32>>) -> Result<Value, String> {
89 let n = a.len();
90 let m = a[0].len();
91 let p = b[0].len();
92 if b.len() != m {
93 return Err(format!("dimension mismatch: left={}x{}, right={}x{}", n, m, b.len(), p));
94 }
95 let mut result = vec![vec![0i32; p]; n];
96 for i in 0..n {
97 for j in 0..p {
98 for k in 0..m {
99 result[i][j] += a[i][k] * b[k][j];
100 }
101 }
102 }
103 Ok(Value::Int2DArray(result))
104}
105
106fn matrix_mul_float(a: &Vec<Vec<f64>>, b: &Vec<Vec<f64>>) -> Result<Value, String> {
107 let n = a.len();
108 let m = a[0].len();
109 let p = b[0].len();
110 if b.len() != m {
111 return Err(format!("dimension mismatch: left={}x{}, right={}x{}", n, m, b.len(), p));
112 }
113 let mut result = vec![vec![0.0f64; p]; n];
114 for i in 0..n {
115 for j in 0..p {
116 for k in 0..m {
117 result[i][j] += a[i][k] * b[k][j];
118 }
119 }
120 }
121 Ok(Value::Float2DArray(result))
122}
123
124fn matrix_mul_float2d_cast(a: &Vec<Vec<i32>>, b: &Vec<Vec<f64>>) -> Result<Value, String> {
125 let n = a.len();
126 let m = a[0].len();
127 let p = b[0].len();
128 if b.len() != m {
129 return Err(format!("dimension mismatch: left={}x{}, right={}x{}", n, m, b.len(), p));
130 }
131 let mut result = vec![vec![0.0f64; p]; n];
132 for i in 0..n {
133 for j in 0..p {
134 for k in 0..m {
135 result[i][j] += a[i][k] as f64 * b[k][j];
136 }
137 }
138 }
139 Ok(Value::Float2DArray(result))
140}
141
142fn matrix_mul_float2d_cast_rev(a: &Vec<Vec<f64>>, b: &Vec<Vec<i32>>) -> Result<Value, String> {
143 let n = a.len();
144 let m = a[0].len();
145 let p = b[0].len();
146 if b.len() != m {
147 return Err(format!("dimension mismatch: left={}x{}, right={}x{}", n, m, b.len(), p));
148 }
149 let mut result = vec![vec![0.0f64; p]; n];
150 for i in 0..n {
151 for j in 0..p {
152 for k in 0..m {
153 result[i][j] += a[i][k] * b[k][j] as f64;
154 }
155 }
156 }
157 Ok(Value::Float2DArray(result))
158}
159
/// Returns the transpose of a row-major 2-D vector. Element `[r][c]` of the
/// input becomes element `[c][r]` of the output.
///
/// The column count is taken from the first row. An input with no rows (or
/// whose first row is empty) transposes to an empty vector.
///
/// # Panics
/// Panics if any row is shorter than the first row. Elements beyond the first
/// row's length in longer rows are dropped. Callers should pass a rectangular
/// matrix.
fn transpose2d<T: Clone>(m: &[Vec<T>]) -> Vec<Vec<T>> {
    let cols = m.first().map_or(0, Vec::len);
    (0..cols)
        .map(|c| m.iter().map(|row| row[c].clone()).collect())
        .collect()
}
167
168fn matrix_transpose(_interp: &mut Interpreter, args: Vec<Value>) -> Result<Value, String> {
169 if args.len() != 1 {
170 return Err("matrix:transpose expects 1 argument".to_string());
171 }
172 match &args[0] {
173 Value::Int2DArray(xs) => Ok(Value::Int2DArray(transpose2d(xs))),
174 Value::Float2DArray(xs) => Ok(Value::Float2DArray(transpose2d(xs))),
175 _ => Err("Not numeric".to_string()),
176 }
177}
178
179#[no_mangle]
180pub unsafe extern "C" fn Cargo_lock(
181 interp_ptr: *mut std::ffi::c_void,
182 _extra_str: *const std::ffi::c_void,
183) -> i32 {
184 let interp = &mut *(interp_ptr as *mut Interpreter);
185
186 macro_rules! reg {
187 ($name:expr, $f:expr) => {{
188 let func = Arc::new(Mutex::new($f));
189 interp.register_dynamic_function($name, func);
190 interp.set_variable($name, Value::Function(Box::new(FunctionValue::Named($name.into()))));
191 }};
192 }
193 reg!("matrix:add", matrix_add);
194 reg!("matrix:subtract", matrix_subtract);
195 reg!("matrix:multiply", matrix_multiply);
196 reg!("matrix:transpose", matrix_transpose);
197
198 0
199}