use mumu::{
parser::interpreter::Interpreter,
parser::types::{Value, FunctionValue},
};
use std::sync::{Arc, Mutex};
fn matrix_add(_interp: &mut Interpreter, args: Vec<Value>) -> Result<Value, String> {
if args.len() != 2 {
return Err("matrix:add expects 2 arguments".to_string());
}
match (&args[0], &args[1]) {
(Value::Int2DArray(xs), Value::Int2DArray(ys)) if xs.len() == ys.len() && xs[0].len() == ys[0].len() => {
let result: Vec<Vec<i32>> = xs.iter().zip(ys.iter())
.map(|(row_x, row_y)| row_x.iter().zip(row_y.iter()).map(|(a, b)| a + b).collect())
.collect();
Ok(Value::Int2DArray(result))
}
(Value::Float2DArray(xs), Value::Float2DArray(ys)) if xs.len() == ys.len() && xs[0].len() == ys[0].len() => {
let result: Vec<Vec<f64>> = xs.iter().zip(ys.iter())
.map(|(row_x, row_y)| row_x.iter().zip(row_y.iter()).map(|(a, b)| a + b).collect())
.collect();
Ok(Value::Float2DArray(result))
}
(Value::Int2DArray(xs), Value::Float2DArray(ys)) if xs.len() == ys.len() && xs[0].len() == ys[0].len() => {
let result: Vec<Vec<f64>> = xs.iter().zip(ys.iter())
.map(|(row_x, row_y)| row_x.iter().zip(row_y.iter()).map(|(a, b)| *a as f64 + b).collect())
.collect();
Ok(Value::Float2DArray(result))
}
(Value::Float2DArray(xs), Value::Int2DArray(ys)) if xs.len() == ys.len() && xs[0].len() == ys[0].len() => {
let result: Vec<Vec<f64>> = xs.iter().zip(ys.iter())
.map(|(row_x, row_y)| row_x.iter().zip(row_y.iter()).map(|(a, b)| a + *b as f64).collect())
.collect();
Ok(Value::Float2DArray(result))
}
_ => Err("Not numeric".to_string()),
}
}
fn matrix_subtract(_interp: &mut Interpreter, args: Vec<Value>) -> Result<Value, String> {
if args.len() != 2 {
return Err("matrix:subtract expects 2 arguments".to_string());
}
match (&args[0], &args[1]) {
(Value::Int2DArray(xs), Value::Int2DArray(ys)) if xs.len() == ys.len() && xs[0].len() == ys[0].len() => {
let result: Vec<Vec<i32>> = xs.iter().zip(ys.iter())
.map(|(row_x, row_y)| row_x.iter().zip(row_y.iter()).map(|(a, b)| a - b).collect())
.collect();
Ok(Value::Int2DArray(result))
}
(Value::Float2DArray(xs), Value::Float2DArray(ys)) if xs.len() == ys.len() && xs[0].len() == ys[0].len() => {
let result: Vec<Vec<f64>> = xs.iter().zip(ys.iter())
.map(|(row_x, row_y)| row_x.iter().zip(row_y.iter()).map(|(a, b)| a - b).collect())
.collect();
Ok(Value::Float2DArray(result))
}
(Value::Int2DArray(xs), Value::Float2DArray(ys)) if xs.len() == ys.len() && xs[0].len() == ys[0].len() => {
let result: Vec<Vec<f64>> = xs.iter().zip(ys.iter())
.map(|(row_x, row_y)| row_x.iter().zip(row_y.iter()).map(|(a, b)| *a as f64 - b).collect())
.collect();
Ok(Value::Float2DArray(result))
}
(Value::Float2DArray(xs), Value::Int2DArray(ys)) if xs.len() == ys.len() && xs[0].len() == ys[0].len() => {
let result: Vec<Vec<f64>> = xs.iter().zip(ys.iter())
.map(|(row_x, row_y)| row_x.iter().zip(row_y.iter()).map(|(a, b)| a - *b as f64).collect())
.collect();
Ok(Value::Float2DArray(result))
}
_ => Err("Not numeric".to_string()),
}
}
/// `matrix:multiply` — matrix product of two 2-D numeric arrays.
///
/// Picks the `matrix_mul_*` helper that matches the operand types. Two
/// integer matrices use the integer helper, and every combination involving
/// a float matrix uses a float helper. Shape checking is left to the helper.
///
/// Errors with `"matrix:multiply expects 2 arguments"` on a wrong argument
/// count and with `"Not numeric"` if either operand is not a 2-D numeric array.
fn matrix_multiply(_interp: &mut Interpreter, args: Vec<Value>) -> Result<Value, String> {
    let (lhs, rhs) = match args.as_slice() {
        [lhs, rhs] => (lhs, rhs),
        _ => return Err("matrix:multiply expects 2 arguments".to_string()),
    };
    match (lhs, rhs) {
        (Value::Int2DArray(left), Value::Int2DArray(right)) => matrix_mul_int(left, right),
        (Value::Float2DArray(left), Value::Float2DArray(right)) => matrix_mul_float(left, right),
        (Value::Int2DArray(left), Value::Float2DArray(right)) => {
            matrix_mul_float2d_cast(left, right)
        }
        (Value::Float2DArray(left), Value::Int2DArray(right)) => {
            matrix_mul_float2d_cast_rev(left, right)
        }
        _ => Err("Not numeric".to_string()),
    }
}
/// Multiplies an `n×m` matrix `a` by an `m×p` matrix `b`.
///
/// Each output cell starts at `zero` and is folded over
/// `k = 0..m` with `accumulate(acc, a[i][k], b[k][j])`. `accumulate` returns
/// `None` to signal an arithmetic failure (integer overflow).
///
/// An operand with no rows is treated as having zero columns instead of
/// panicking on `[0]`.
///
/// # Errors
/// - `"dimension mismatch: left=NxM, right=RxP"` if `b` does not have `m` rows;
/// - a ragged-matrix message if the rows of either operand differ in length;
/// - `"integer overflow in matrix:multiply"` if `accumulate` returns `None`.
fn matmul_with<A: Copy, B: Copy, R: Copy>(
    a: &[Vec<A>],
    b: &[Vec<B>],
    zero: R,
    accumulate: impl Fn(R, A, B) -> Option<R>,
) -> Result<Vec<Vec<R>>, String> {
    let n = a.len();
    let m = a.first().map_or(0, Vec::len);
    let p = b.first().map_or(0, Vec::len);
    if b.len() != m {
        return Err(format!("dimension mismatch: left={}x{}, right={}x{}", n, m, b.len(), p));
    }
    // The products below assume rectangular operands. Reject ragged rows up
    // front instead of panicking on a short row or ignoring extra entries.
    if let Some(i) = a.iter().position(|row| row.len() != m) {
        return Err(format!(
            "ragged matrix: left row {} has {} columns, expected {}",
            i,
            a[i].len(),
            m
        ));
    }
    if let Some(k) = b.iter().position(|row| row.len() != p) {
        return Err(format!(
            "ragged matrix: right row {} has {} columns, expected {}",
            k,
            b[k].len(),
            p
        ));
    }
    a.iter()
        .map(|a_row| {
            (0..p)
                .map(|j| {
                    // Sum in k order from `zero`, matching a plain `+=` loop.
                    a_row
                        .iter()
                        .zip(b)
                        .try_fold(zero, |acc, (&x, b_row)| accumulate(acc, x, b_row[j]))
                        .ok_or_else(|| "integer overflow in matrix:multiply".to_string())
                })
                .collect::<Result<Vec<R>, String>>()
        })
        .collect()
}

/// Integer matrix product. Overflow of `i32` is reported as an error
/// instead of panicking (debug) or wrapping (release).
fn matrix_mul_int(a: &[Vec<i32>], b: &[Vec<i32>]) -> Result<Value, String> {
    matmul_with(a, b, 0i32, |acc: i32, x: i32, y: i32| {
        x.checked_mul(y).and_then(|prod| acc.checked_add(prod))
    })
    .map(Value::Int2DArray)
}

/// Float matrix product.
fn matrix_mul_float(a: &[Vec<f64>], b: &[Vec<f64>]) -> Result<Value, String> {
    matmul_with(a, b, 0.0f64, |acc: f64, x: f64, y: f64| Some(acc + x * y))
        .map(Value::Float2DArray)
}

/// Integer × float matrix product; the left operand is widened to `f64`.
fn matrix_mul_float2d_cast(a: &[Vec<i32>], b: &[Vec<f64>]) -> Result<Value, String> {
    matmul_with(a, b, 0.0f64, |acc: f64, x: i32, y: f64| {
        Some(acc + f64::from(x) * y)
    })
    .map(Value::Float2DArray)
}

/// Float × integer matrix product; the right operand is widened to `f64`.
fn matrix_mul_float2d_cast_rev(a: &[Vec<f64>], b: &[Vec<i32>]) -> Result<Value, String> {
    matmul_with(a, b, 0.0f64, |acc: f64, x: f64, y: i32| {
        Some(acc + x * f64::from(y))
    })
    .map(Value::Float2DArray)
}
/// Returns the transpose of a matrix stored as a slice of rows.
///
/// The column count is taken from the first row. An input with no rows (or
/// whose first row is empty) yields an empty result.
///
/// # Panics
/// Panics if any row is shorter than the first row. Entries beyond the first
/// row's length in longer rows are ignored. Callers should validate that the
/// input is rectangular beforehand.
fn transpose2d<T: Clone>(m: &[Vec<T>]) -> Vec<Vec<T>> {
    let cols = m.first().map_or(0, Vec::len);
    (0..cols)
        .map(|c| m.iter().map(|row| row[c].clone()).collect())
        .collect()
}
/// `matrix:transpose` — swaps the rows and columns of a 2-D numeric array.
///
/// The element type is preserved: `Int2DArray` stays integer and
/// `Float2DArray` stays float. An empty matrix transposes to an empty matrix.
///
/// # Errors
/// - `"matrix:transpose expects 1 argument"` on a wrong argument count;
/// - a ragged-matrix message if the rows are not all the same length;
/// - `"Not numeric"` if the argument is not a 2-D numeric array.
fn matrix_transpose(_interp: &mut Interpreter, args: Vec<Value>) -> Result<Value, String> {
    /// Checks that every row has the same length as the first row.
    ///
    /// `transpose2d` panics on short rows and drops entries of long rows, so
    /// the check must run before it does.
    fn check_rectangular<T>(rows: &[Vec<T>]) -> Result<(), String> {
        let width = rows.first().map_or(0, Vec::len);
        match rows.iter().position(|row| row.len() != width) {
            Some(i) => Err(format!(
                "matrix:transpose ragged matrix: row {} has {} columns, expected {}",
                i,
                rows[i].len(),
                width
            )),
            None => Ok(()),
        }
    }

    if args.len() != 1 {
        return Err("matrix:transpose expects 1 argument".to_string());
    }
    match &args[0] {
        Value::Int2DArray(xs) => {
            check_rectangular(xs)?;
            Ok(Value::Int2DArray(transpose2d(xs)))
        }
        Value::Float2DArray(xs) => {
            check_rectangular(xs)?;
            Ok(Value::Float2DArray(transpose2d(xs)))
        }
        _ => Err("Not numeric".to_string()),
    }
}
/// C-ABI entry point that registers the `matrix:*` builtins with an interpreter.
///
/// For each builtin, it calls `register_dynamic_function` with the Rust
/// implementation. It then sets a variable of the same name to
/// `Value::Function(FunctionValue::Named(name))`.
///
/// Returns `0` on success and `-1` if `interp_ptr` is null.
///
/// # Safety
/// `interp_ptr` must be null or a valid, properly aligned pointer to a live
/// `Interpreter` that no other reference accesses for the duration of this
/// call. `_extra_str` is not read.
#[no_mangle]
pub unsafe extern "C" fn Cargo_lock(
    interp_ptr: *mut std::ffi::c_void,
    _extra_str: *const std::ffi::c_void,
) -> i32 {
    // Dereferencing null is undefined behaviour, so report failure instead.
    // NOTE(review): confirm the host loader treats any non-zero value as failure.
    if interp_ptr.is_null() {
        return -1;
    }
    // SAFETY: the pointer is non-null (checked above), and the caller guarantees
    // it points to a live, exclusively accessed `Interpreter` (see `# Safety`).
    let interp = unsafe { &mut *(interp_ptr as *mut Interpreter) };
    macro_rules! reg {
        ($name:expr, $f:expr) => {{
            let func = Arc::new(Mutex::new($f));
            interp.register_dynamic_function($name, func);
            interp.set_variable($name, Value::Function(Box::new(FunctionValue::Named($name.into()))));
        }};
    }
    reg!("matrix:add", matrix_add);
    reg!("matrix:subtract", matrix_subtract);
    reg!("matrix:multiply", matrix_multiply);
    reg!("matrix:transpose", matrix_transpose);
    0
}