#![cfg(all(feature = "blas-lapack", not(target_arch = "wasm32")))]
use runmat_builtins::{CellArray, Tensor as Matrix, Value};
use runmat_runtime::{blas::*, call_builtin, lapack::*};
#[test]
fn test_blas_matrix_multiplication() {
let a = Matrix::new_2d(vec![1.0, 2.0, 3.0, 4.0], 2, 2).unwrap();
let b = Matrix::new_2d(vec![2.0, 1.0, 1.0, 2.0], 2, 2).unwrap();
let result = blas_matrix_mul(&a, &b).unwrap();
assert_eq!(result.data, vec![4.0, 10.0, 5.0, 11.0]);
assert_eq!(result.rows(), 2);
assert_eq!(result.cols(), 2);
}
#[test]
fn test_blas_matrix_vector_multiplication() {
let matrix = Matrix::new_2d(vec![1.0, 2.0, 3.0, 4.0], 2, 2).unwrap();
let vector = vec![2.0, 3.0];
let result = blas_matrix_vector_mul(&matrix, &vector).unwrap();
assert_eq!(result, vec![8.0, 18.0]);
}
#[test]
fn test_blas_dot_product() {
let a = vec![1.0, 2.0, 3.0];
let b = vec![4.0, 5.0, 6.0];
let result = blas_dot_product(&a, &b).unwrap();
assert_eq!(result, 32.0);
}
#[test]
fn test_blas_vector_norm() {
let vector = vec![3.0, 4.0]; let norm = blas_vector_norm(&vector);
assert!((norm - 5.0).abs() < 1e-10);
}
#[test]
fn test_blas_vector_operations() {
let mut x = vec![1.0, 2.0, 3.0];
let y = vec![4.0, 5.0, 6.0];
let mut result = y.clone();
blas_scale_vector(&mut x, 2.0);
assert_eq!(x, vec![2.0, 4.0, 6.0]);
blas_vector_add(2.0, &x, &mut result).unwrap();
assert_eq!(result, vec![8.0, 13.0, 18.0]);
}
#[test]
fn test_lapack_linear_solve() {
let a = Matrix::new_2d(vec![2.0, 1.0, 1.0, 3.0], 2, 2).unwrap();
let b = vec![5.0, 6.0];
let solution = lapack_solve_linear_system(&a, &b).unwrap();
assert!((solution[0] - 1.8).abs() < 1e-10);
assert!((solution[1] - 1.4).abs() < 1e-10);
}
#[test]
fn test_lapack_lu_decomposition() {
let matrix = Matrix::new_2d(vec![4.0, 3.0, 2.0, 1.0], 2, 2).unwrap();
let lu = lapack_lu_decomposition(&matrix).unwrap();
assert_eq!(lu.lu_matrix.rows(), 2);
assert_eq!(lu.lu_matrix.cols(), 2);
assert_eq!(lu.pivots.len(), 2);
}
#[test]
fn test_lapack_qr_decomposition() {
let matrix = Matrix::new_2d(vec![1.0, 1.0, 1.0, 2.0], 2, 2).unwrap();
let qr = lapack_qr_decomposition(&matrix).unwrap();
assert_eq!(qr.q.rows(), 2);
assert_eq!(qr.q.cols(), 2);
assert_eq!(qr.r.rows(), 2);
assert_eq!(qr.r.cols(), 2);
let qt = match call_builtin("transpose", &[Value::Tensor(qr.q.clone())]).unwrap() {
Value::Tensor(t) => t,
other => panic!("expected tensor transpose, got {other:?}"),
};
let qt_q = blas_matrix_mul(&qt, &qr.q).unwrap();
let tolerance = 1e-10;
assert!((qt_q.data[0] - 1.0).abs() < tolerance); assert!((qt_q.data[1]).abs() < tolerance); assert!((qt_q.data[2]).abs() < tolerance); assert!((qt_q.data[3] - 1.0).abs() < tolerance); }
#[test]
fn test_lapack_eigenvalues() {
let matrix = Matrix::new_2d(vec![2.0, 1.0, 1.0, 2.0], 2, 2).unwrap();
let eig = lapack_eigenvalues(&matrix, false).unwrap();
assert_eq!(eig.eigenvalues.len(), 2);
let mut eigenvals = eig.eigenvalues;
eigenvals.sort_by(|a, b| a.partial_cmp(b).unwrap());
assert!((eigenvals[0] - 1.0).abs() < 1e-10);
assert!((eigenvals[1] - 3.0).abs() < 1e-10);
}
#[test]
fn test_lapack_determinant() {
let matrix = Matrix::new_2d(vec![2.0, 1.0, 1.0, 2.0], 2, 2).unwrap();
let det = lapack_determinant(&matrix).unwrap();
assert!((det - 3.0).abs() < 1e-10);
}
#[test]
fn test_lapack_matrix_inverse() {
let matrix = Matrix::new_2d(vec![2.0, 1.0, 1.0, 2.0], 2, 2).unwrap();
let inv = lapack_matrix_inverse(&matrix).unwrap();
let expected = [2.0 / 3.0, -1.0 / 3.0, -1.0 / 3.0, 2.0 / 3.0];
for (i, &expected_val) in expected.iter().enumerate() {
assert!((inv.data[i] - expected_val).abs() < 1e-10);
}
let identity = blas_matrix_mul(&matrix, &inv).unwrap();
let tolerance = 1e-10;
assert!((identity.data[0] - 1.0).abs() < tolerance);
assert!((identity.data[1]).abs() < tolerance);
assert!((identity.data[2]).abs() < tolerance);
assert!((identity.data[3] - 1.0).abs() < tolerance);
}
#[test]
fn test_builtin_blas_functions() {
let a = Matrix::new_2d(vec![1.0, 2.0, 3.0, 4.0], 2, 2).unwrap();
let b = Matrix::new_2d(vec![2.0, 1.0, 1.0, 2.0], 2, 2).unwrap();
match call_builtin("blas_matmul", &[Value::Tensor(a), Value::Tensor(b)]) {
Ok(Value::Tensor(m)) => {
assert_eq!(m.data, vec![4.0, 10.0, 5.0, 11.0]);
}
Err(message) if message.message().contains("Undefined function") => {
return;
}
Ok(other) => panic!("Expected matrix result, got {other:?}"),
Err(message) => panic!("Unexpected error calling blas_matmul builtin: {message}"),
}
let vec_a = CellArray::new(
vec![Value::Num(1.0), Value::Num(2.0), Value::Num(3.0)],
1,
3,
)
.unwrap();
let vec_b = CellArray::new(
vec![Value::Num(4.0), Value::Num(5.0), Value::Num(6.0)],
1,
3,
)
.unwrap();
match call_builtin("dot", &[Value::Cell(vec_a), Value::Cell(vec_b)]) {
Ok(Value::Num(dot)) => assert_eq!(dot, 32.0),
Err(message) if message.message().contains("Undefined function") => (),
Ok(other) => panic!("Expected numeric result, got {other:?}"),
Err(message) => panic!("Unexpected error calling dot builtin: {message}"),
}
}
#[test]
fn test_builtin_lapack_functions() {
let a = Matrix::new_2d(vec![2.0, 1.0, 1.0, 3.0], 2, 2).unwrap();
let b_vec = CellArray::new(vec![Value::Num(5.0), Value::Num(6.0)], 2, 1).unwrap();
match call_builtin("solve", &[Value::Tensor(a), Value::Cell(b_vec)]) {
Ok(Value::Cell(sol)) => {
assert_eq!(sol.data.len(), 2);
if let (Value::Num(x), Value::Num(y)) = (&*sol.data[0], &*sol.data[1]) {
assert!((x - 1.8).abs() < 1e-10);
assert!((y - 1.4).abs() < 1e-10);
} else {
panic!("Expected numeric solution");
}
}
Err(message) if message.message().contains("Undefined function") => (),
Ok(other) => panic!("Expected vector result, got {other:?}"),
Err(message) => panic!("Unexpected error calling solve builtin: {message}"),
}
let matrix = Matrix::new_2d(vec![2.0, 1.0, 1.0, 2.0], 2, 2).unwrap();
let det_result = call_builtin("det", &[Value::Tensor(matrix)]).unwrap();
if let Value::Num(det) = det_result {
assert!((det - 3.0).abs() < 1e-10);
} else {
panic!("Expected numeric result");
}
}
#[test]
fn test_error_handling() {
let a = Matrix::new_2d(vec![1.0, 2.0], 1, 2).unwrap();
let b = Matrix::new_2d(vec![1.0, 2.0, 3.0], 1, 3).unwrap();
assert!(blas_matrix_mul(&a, &b).is_err());
let non_square = Matrix::new_2d(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 2, 3).unwrap();
assert!(lapack_determinant(&non_square).is_err());
assert!(lapack_matrix_inverse(&non_square).is_err());
let vec_a = vec![1.0, 2.0];
let vec_b = vec![1.0, 2.0, 3.0];
assert!(blas_dot_product(&vec_a, &vec_b).is_err());
}