mumu-matrix 0.1.1

Matrix operations for the mumu/lava language
Documentation
// matrix/src/lib.rs

use mumu::{
    parser::interpreter::Interpreter,
    parser::types::{Value, FunctionValue},
};
use std::sync::{Arc, Mutex};

fn matrix_add(_interp: &mut Interpreter, args: Vec<Value>) -> Result<Value, String> {
    if args.len() != 2 {
        return Err("matrix:add expects 2 arguments".to_string());
    }
    match (&args[0], &args[1]) {
        (Value::Int2DArray(xs), Value::Int2DArray(ys)) if xs.len() == ys.len() && xs[0].len() == ys[0].len() => {
            let result: Vec<Vec<i32>> = xs.iter().zip(ys.iter())
                .map(|(row_x, row_y)| row_x.iter().zip(row_y.iter()).map(|(a, b)| a + b).collect())
                .collect();
            Ok(Value::Int2DArray(result))
        }
        (Value::Float2DArray(xs), Value::Float2DArray(ys)) if xs.len() == ys.len() && xs[0].len() == ys[0].len() => {
            let result: Vec<Vec<f64>> = xs.iter().zip(ys.iter())
                .map(|(row_x, row_y)| row_x.iter().zip(row_y.iter()).map(|(a, b)| a + b).collect())
                .collect();
            Ok(Value::Float2DArray(result))
        }
        (Value::Int2DArray(xs), Value::Float2DArray(ys)) if xs.len() == ys.len() && xs[0].len() == ys[0].len() => {
            let result: Vec<Vec<f64>> = xs.iter().zip(ys.iter())
                .map(|(row_x, row_y)| row_x.iter().zip(row_y.iter()).map(|(a, b)| *a as f64 + b).collect())
                .collect();
            Ok(Value::Float2DArray(result))
        }
        (Value::Float2DArray(xs), Value::Int2DArray(ys)) if xs.len() == ys.len() && xs[0].len() == ys[0].len() => {
            let result: Vec<Vec<f64>> = xs.iter().zip(ys.iter())
                .map(|(row_x, row_y)| row_x.iter().zip(row_y.iter()).map(|(a, b)| a + *b as f64).collect())
                .collect();
            Ok(Value::Float2DArray(result))
        }
        _ => Err("Not numeric".to_string()),
    }
}

fn matrix_subtract(_interp: &mut Interpreter, args: Vec<Value>) -> Result<Value, String> {
    if args.len() != 2 {
        return Err("matrix:subtract expects 2 arguments".to_string());
    }
    match (&args[0], &args[1]) {
        (Value::Int2DArray(xs), Value::Int2DArray(ys)) if xs.len() == ys.len() && xs[0].len() == ys[0].len() => {
            let result: Vec<Vec<i32>> = xs.iter().zip(ys.iter())
                .map(|(row_x, row_y)| row_x.iter().zip(row_y.iter()).map(|(a, b)| a - b).collect())
                .collect();
            Ok(Value::Int2DArray(result))
        }
        (Value::Float2DArray(xs), Value::Float2DArray(ys)) if xs.len() == ys.len() && xs[0].len() == ys[0].len() => {
            let result: Vec<Vec<f64>> = xs.iter().zip(ys.iter())
                .map(|(row_x, row_y)| row_x.iter().zip(row_y.iter()).map(|(a, b)| a - b).collect())
                .collect();
            Ok(Value::Float2DArray(result))
        }
        (Value::Int2DArray(xs), Value::Float2DArray(ys)) if xs.len() == ys.len() && xs[0].len() == ys[0].len() => {
            let result: Vec<Vec<f64>> = xs.iter().zip(ys.iter())
                .map(|(row_x, row_y)| row_x.iter().zip(row_y.iter()).map(|(a, b)| *a as f64 - b).collect())
                .collect();
            Ok(Value::Float2DArray(result))
        }
        (Value::Float2DArray(xs), Value::Int2DArray(ys)) if xs.len() == ys.len() && xs[0].len() == ys[0].len() => {
            let result: Vec<Vec<f64>> = xs.iter().zip(ys.iter())
                .map(|(row_x, row_y)| row_x.iter().zip(row_y.iter()).map(|(a, b)| a - *b as f64).collect())
                .collect();
            Ok(Value::Float2DArray(result))
        }
        _ => Err("Not numeric".to_string()),
    }
}

/// `matrix:multiply` builtin: matrix product `left × right`.
///
/// Two `Int2DArray`s give an `Int2DArray`. If either operand is a
/// `Float2DArray`, the product is computed in `f64` and a `Float2DArray` is
/// returned.
///
/// # Errors
/// Wrong argument count, a non-matrix argument ("Not numeric"), ragged rows,
/// incompatible inner dimensions, or `i32` overflow in the integer product.
fn matrix_multiply(_interp: &mut Interpreter, args: Vec<Value>) -> Result<Value, String> {
    if args.len() != 2 {
        return Err("matrix:multiply expects 2 arguments".to_string());
    }
    match (&args[0], &args[1]) {
        (Value::Int2DArray(a), Value::Int2DArray(b)) => matrix_mul_int(a, b),
        (Value::Float2DArray(a), Value::Float2DArray(b)) => matrix_mul_float(a, b),
        (Value::Int2DArray(a), Value::Float2DArray(b)) => matrix_mul_float(a, b),
        (Value::Float2DArray(a), Value::Int2DArray(b)) => matrix_mul_float(a, b),
        _ => Err("Not numeric".to_string()),
    }
}

/// Validates the operands of `a × b` and returns the column count of the result.
///
/// Both matrices must be rectangular (all rows the same length). The column
/// count of `a` must equal the row count of `b`. A matrix with no rows counts
/// as having zero columns, so empty operands are handled without indexing
/// `a[0]` / `b[0]`.
fn mul_dims<A, B>(a: &[Vec<A>], b: &[Vec<B>]) -> Result<usize, String> {
    let m = a.first().map_or(0, Vec::len);
    let p = b.first().map_or(0, Vec::len);
    if a.iter().any(|row| row.len() != m) || b.iter().any(|row| row.len() != p) {
        return Err("matrix:multiply: all rows of a matrix must have the same length".to_string());
    }
    if b.len() != m {
        return Err(format!(
            "dimension mismatch: left={}x{}, right={}x{}",
            a.len(),
            m,
            b.len(),
            p
        ));
    }
    Ok(p)
}

/// Integer matrix product with overflow detection.
///
/// Each output element is accumulated in `i64`, where a single `i32 × i32`
/// product cannot overflow. The sum is converted back to `i32` at the end, so
/// an error is returned only when the sum overflows `i64` or the final value
/// does not fit in `i32`. Previously the code panicked in debug builds and
/// silently wrapped in release.
fn matrix_mul_int(a: &[Vec<i32>], b: &[Vec<i32>]) -> Result<Value, String> {
    let p = mul_dims(a, b)?;
    let result = a
        .iter()
        .map(|row| {
            (0..p)
                .map(|j| {
                    row.iter()
                        .zip(b)
                        .try_fold(0i64, |acc, (&x, b_row)| {
                            acc.checked_add(i64::from(x) * i64::from(b_row[j]))
                        })
                        .and_then(|sum| i32::try_from(sum).ok())
                        .ok_or_else(|| "matrix:multiply: integer overflow".to_string())
                })
                .collect::<Result<Vec<i32>, String>>()
        })
        .collect::<Result<Vec<Vec<i32>>, String>>()?;
    Ok(Value::Int2DArray(result))
}

/// Matrix product computed in `f64`, for every combination that involves at
/// least one float operand.
///
/// Elements are widened with `f64::from`, which is exact for `i32`. Each
/// output element is the sum of its products in index order, starting from
/// `0.0`, so results match a plain triple loop bit for bit.
fn matrix_mul_float<A, B>(a: &[Vec<A>], b: &[Vec<B>]) -> Result<Value, String>
where
    A: Copy,
    B: Copy,
    f64: From<A> + From<B>,
{
    let p = mul_dims(a, b)?;
    let result: Vec<Vec<f64>> = a
        .iter()
        .map(|row| {
            (0..p)
                .map(|j| {
                    row.iter().zip(b).fold(0.0f64, |acc, (&x, b_row)| {
                        acc + f64::from(x) * f64::from(b_row[j])
                    })
                })
                .collect::<Vec<f64>>()
        })
        .collect();
    Ok(Value::Float2DArray(result))
}

/// Transposes a row-major matrix, turning an `r x c` matrix into a `c x r` one.
///
/// The column count is taken from the first row, and a matrix with no rows
/// transposes to an empty matrix. Elements are cloned into the new layout.
///
/// # Panics
///
/// Panics if any row is shorter than the first row. Elements of rows longer
/// than the first are ignored. Callers should reject ragged input beforehand.
fn transpose2d<T: Clone>(m: &[Vec<T>]) -> Vec<Vec<T>> {
    let cols = m.first().map_or(0, Vec::len);
    (0..cols)
        .map(|c| m.iter().map(|row| row[c].clone()).collect())
        .collect()
}

/// `matrix:transpose` builtin: swaps the rows and columns of a matrix.
///
/// The element type is preserved: `Int2DArray` stays `Int2DArray` and
/// `Float2DArray` stays `Float2DArray`.
///
/// # Errors
/// Wrong argument count, a non-matrix argument ("Not numeric"), or a ragged
/// matrix (rows of differing lengths). `transpose2d` would otherwise panic on
/// a ragged matrix or silently drop elements from it.
fn matrix_transpose(_interp: &mut Interpreter, args: Vec<Value>) -> Result<Value, String> {
    /// Rejects matrices whose rows do not all have the same length.
    fn ensure_rectangular<T>(m: &[Vec<T>]) -> Result<(), String> {
        let cols = m.first().map_or(0, Vec::len);
        if m.iter().all(|row| row.len() == cols) {
            Ok(())
        } else {
            Err("matrix:transpose: all rows must have the same length".to_string())
        }
    }

    if args.len() != 1 {
        return Err("matrix:transpose expects 1 argument".to_string());
    }
    match &args[0] {
        Value::Int2DArray(xs) => {
            ensure_rectangular(xs)?;
            Ok(Value::Int2DArray(transpose2d(xs)))
        }
        Value::Float2DArray(xs) => {
            ensure_rectangular(xs)?;
            Ok(Value::Float2DArray(transpose2d(xs)))
        }
        _ => Err("Not numeric".to_string()),
    }
}

/// Plugin entry point: registers the `matrix:add`, `matrix:subtract`,
/// `matrix:multiply` and `matrix:transpose` builtins on the interpreter.
///
/// Each builtin is registered as a dynamic function. It is also bound as a
/// variable of the same name, holding a `FunctionValue::Named` reference to it.
///
/// Returns `0` once everything is registered. If `interp_ptr` is null, returns
/// `-1` without registering anything. NOTE(review): confirm that the loader
/// treats a nonzero return as a load failure. The exported symbol name is
/// unusual; presumably it is the name the loader resolves, so confirm before
/// renaming it.
///
/// # Safety
///
/// `interp_ptr` must be null or point to a live `Interpreter` that nothing
/// else accesses for the duration of this call. `_extra_str` is ignored.
#[no_mangle]
pub unsafe extern "C" fn Cargo_lock(
    interp_ptr: *mut std::ffi::c_void,
    _extra_str: *const std::ffi::c_void,
) -> i32 {
    // SAFETY: per the contract above, the pointer is either null (mapped to
    // `None` by `as_mut` instead of being dereferenced) or a valid, exclusively
    // borrowed `Interpreter`.
    let interp = match unsafe { (interp_ptr as *mut Interpreter).as_mut() } {
        Some(interp) => interp,
        None => return -1,
    };

    // Registers `$f` under `$name` and exposes it as a callable variable.
    macro_rules! reg {
        ($name:expr, $f:expr) => {{
            let func = Arc::new(Mutex::new($f));
            interp.register_dynamic_function($name, func);
            interp.set_variable($name, Value::Function(Box::new(FunctionValue::Named($name.into()))));
        }};
    }
    reg!("matrix:add", matrix_add);
    reg!("matrix:subtract", matrix_subtract);
    reg!("matrix:multiply", matrix_multiply);
    reg!("matrix:transpose", matrix_transpose);

    0
}