use mathhook_core::matrices::eigenvalues::EigenOperations;
use mathhook_core::matrices::{Matrix, MatrixDecomposition, MatrixOperations};
use mathhook_core::{expr, symbol, Expression, Number, Simplify};
fn number_to_f64(n: &Number) -> f64 {
match n {
Number::Integer(i) => *i as f64,
Number::Float(f) => *f,
Number::BigInteger(bi) => bi.to_string().parse::<f64>().unwrap_or(f64::NAN),
Number::Rational(r) => {
let numer = r.numer().to_string().parse::<f64>().unwrap_or(f64::NAN);
let denom = r.denom().to_string().parse::<f64>().unwrap_or(f64::NAN);
numer / denom
}
}
}
fn assert_numerically_equal(left: &Expression, right: &Expression) {
const EPSILON: f64 = 1e-10;
match (left, right) {
(Expression::Number(l), Expression::Number(r)) => {
let l_float = number_to_f64(l);
let r_float = number_to_f64(r);
assert!(
(l_float - r_float).abs() < EPSILON,
"Values not numerically equal: {} vs {}",
l_float,
r_float
);
}
_ => panic!("Expected numeric values, got {:?} and {:?}", left, right),
}
}
fn assert_is_zero_or_undefined(expr: &Expression) {
const EPSILON: f64 = 1e-10;
match expr {
Expression::Number(n) => {
let val = number_to_f64(n);
assert!(val.abs() < EPSILON, "Expected zero, got {}", val);
}
Expression::Function { name, args } if name.as_ref() == "undefined" && args.is_empty() => {
}
_ => panic!("Expected zero or undefined, got {:?}", expr),
}
}
#[test]
fn test_matrix_creation_2x2() {
let m = Expression::matrix(vec![vec![expr!(1), expr!(2)], vec![expr!(3), expr!(4)]]);
let (rows, cols) = m.matrix_dimensions().unwrap();
assert_eq!(rows, 2);
assert_eq!(cols, 2);
}
#[test]
fn test_matrix_creation_3x3() {
let m = Expression::matrix(vec![
vec![expr!(1), expr!(2), expr!(3)],
vec![expr!(4), expr!(5), expr!(6)],
vec![expr!(7), expr!(8), expr!(9)],
]);
let (rows, cols) = m.matrix_dimensions().unwrap();
assert_eq!(rows, 3);
assert_eq!(cols, 3);
}
#[test]
fn test_identity_matrix_creation() {
let id = Expression::identity_matrix(3);
let (rows, cols) = id.matrix_dimensions().unwrap();
assert_eq!(rows, 3);
assert_eq!(cols, 3);
assert!(id.is_identity_matrix());
}
#[test]
fn test_matrix_addition() {
let a = Expression::matrix(vec![vec![expr!(1), expr!(2)], vec![expr!(3), expr!(4)]]);
let b = Expression::matrix(vec![vec![expr!(5), expr!(6)], vec![expr!(7), expr!(8)]]);
let sum = a.matrix_add(&b).simplify();
let expected = Expression::matrix(vec![vec![expr!(6), expr!(8)], vec![expr!(10), expr!(12)]]);
assert_eq!(sum, expected);
}
#[test]
fn test_matrix_scalar_multiplication() {
let m = Expression::matrix(vec![vec![expr!(1), expr!(2)], vec![expr!(3), expr!(4)]]);
let scaled = m.matrix_scalar_multiply(&expr!(2)).simplify();
let expected = Expression::matrix(vec![vec![expr!(2), expr!(4)], vec![expr!(6), expr!(8)]]);
assert_eq!(scaled, expected);
}
#[test]
#[ignore = "BUG: simplify() on matrix results causes stack overflow - needs investigation"]
fn test_matrix_multiplication() {
let a = Expression::matrix(vec![vec![expr!(1), expr!(2)], vec![expr!(3), expr!(4)]]);
let b = Expression::matrix(vec![vec![expr!(5), expr!(6)], vec![expr!(7), expr!(8)]]);
let product = a.matrix_multiply(&b).simplify();
let expected = Expression::matrix(vec![vec![expr!(19), expr!(22)], vec![expr!(43), expr!(50)]]);
assert_eq!(product, expected);
}
#[test]
fn test_matrix_transpose() {
let m = Expression::matrix(vec![
vec![expr!(1), expr!(2), expr!(3)],
vec![expr!(4), expr!(5), expr!(6)],
]);
let transposed = m.matrix_transpose().simplify();
let expected = Expression::matrix(vec![
vec![expr!(1), expr!(4)],
vec![expr!(2), expr!(5)],
vec![expr!(3), expr!(6)],
]);
assert_eq!(transposed, expected);
}
#[test]
fn test_determinant_2x2() {
let m = Expression::matrix(vec![vec![expr!(1), expr!(2)], vec![expr!(3), expr!(4)]]);
let det = m.matrix_determinant().simplify();
assert_numerically_equal(&det, &expr!(-2));
}
#[test]
fn test_determinant_3x3_singular() {
let m = Expression::matrix(vec![
vec![expr!(1), expr!(2), expr!(3)],
vec![expr!(4), expr!(5), expr!(6)],
vec![expr!(7), expr!(8), expr!(9)],
]);
let det = m.matrix_determinant().simplify();
assert_is_zero_or_undefined(&det);
}
#[test]
fn test_determinant_identity() {
let m = Expression::identity_matrix(3);
let det = m.matrix_determinant().simplify();
assert_eq!(det, expr!(1));
}
#[test]
fn test_determinant_diagonal() {
let m = Expression::matrix(vec![
vec![expr!(2), expr!(0), expr!(0)],
vec![expr!(0), expr!(3), expr!(0)],
vec![expr!(0), expr!(0), expr!(5)],
]);
let det = m.matrix_determinant().simplify();
assert_eq!(det, expr!(30));
}
#[test]
#[ignore = "BUG: Depends on simplify() which causes stack overflow on matrix expressions"]
fn test_inverse_2x2() {
let m = Expression::matrix(vec![vec![expr!(4), expr!(7)], vec![expr!(2), expr!(6)]]);
let inv = m.matrix_inverse().simplify();
let product = m.matrix_multiply(&inv).simplify();
let identity = Expression::identity_matrix(2);
assert_eq!(product, identity);
}
#[test]
#[ignore = "BUG: Depends on simplify() which causes stack overflow on matrix expressions"]
fn test_inverse_3x3() {
let m = Expression::matrix(vec![
vec![expr!(1), expr!(2), expr!(3)],
vec![expr!(0), expr!(1), expr!(4)],
vec![expr!(5), expr!(6), expr!(0)],
]);
let inv = m.matrix_inverse().simplify();
let product = m.matrix_multiply(&inv).simplify();
let identity = Expression::identity_matrix(3);
assert_eq!(product, identity);
}
#[test]
fn test_matrix_trace() {
let m = Expression::matrix(vec![
vec![expr!(1), expr!(2), expr!(3)],
vec![expr!(4), expr!(5), expr!(6)],
vec![expr!(7), expr!(8), expr!(9)],
]);
let trace = m.matrix_trace().simplify();
assert_eq!(trace, expr!(15));
}
#[test]
fn test_matrix_power_zero() {
let m = Expression::matrix(vec![vec![expr!(1), expr!(1)], vec![expr!(1), expr!(0)]]);
let m0 = m.matrix_power(&expr!(0)).simplify();
let identity = Expression::identity_matrix(2);
assert_eq!(m0, identity);
}
#[test]
fn test_matrix_power_one() {
let m = Expression::matrix(vec![vec![expr!(1), expr!(1)], vec![expr!(1), expr!(0)]]);
let m1 = m.matrix_power(&expr!(1)).simplify();
assert_eq!(m1, m);
}
#[test]
#[ignore = "BUG: matrix_power causes stack overflow - needs investigation"]
fn test_matrix_power_two() {
let m = Expression::matrix(vec![vec![expr!(1), expr!(1)], vec![expr!(1), expr!(0)]]);
let m2 = m.matrix_power(&expr!(2)).simplify();
let expected = Expression::matrix(vec![vec![expr!(2), expr!(1)], vec![expr!(1), expr!(1)]]);
assert_eq!(m2, expected);
}
#[test]
fn test_is_diagonal() {
let diagonal = Expression::matrix(vec![
vec![expr!(1), expr!(0), expr!(0)],
vec![expr!(0), expr!(2), expr!(0)],
vec![expr!(0), expr!(0), expr!(3)],
]);
assert!(diagonal.is_diagonal());
let non_diagonal = Expression::matrix(vec![vec![expr!(1), expr!(2)], vec![expr!(3), expr!(4)]]);
assert!(!non_diagonal.is_diagonal());
}
#[test]
fn test_is_zero_matrix() {
let zero = Expression::matrix(vec![vec![expr!(0), expr!(0)], vec![expr!(0), expr!(0)]]);
assert!(zero.is_zero_matrix());
let non_zero = Expression::matrix(vec![vec![expr!(1), expr!(0)], vec![expr!(0), expr!(0)]]);
assert!(!non_zero.is_zero_matrix());
}
#[test]
fn test_matrix_with_symbols() {
let a = symbol!(a);
let b = symbol!(b);
let c = symbol!(c);
let d = symbol!(d);
let m = Expression::matrix(vec![
vec![Expression::symbol(a.clone()), Expression::symbol(b.clone())],
vec![Expression::symbol(c.clone()), Expression::symbol(d.clone())],
]);
let det = m.matrix_determinant().simplify();
let expected = expr!((a * d) - (b * c));
assert_eq!(det, expected.simplify());
}
#[test]
fn test_eigenvalues_2x2_diagonal() {
let matrix = Matrix::diagonal(vec![Expression::integer(2), Expression::integer(3)]);
let eigenvalues = matrix.eigenvalues();
assert_eq!(eigenvalues.len(), 2);
assert_eq!(eigenvalues[0], Expression::integer(2));
assert_eq!(eigenvalues[1], Expression::integer(3));
}
#[test]
#[ignore = "BUG: eigenvalues() not returning correct values for upper triangular matrix"]
fn test_eigenvalues_2x2_general() {
let matrix = Matrix::from_arrays([[1, 2], [0, 3]]);
let eigenvalues = matrix.eigenvalues();
assert_eq!(eigenvalues.len(), 2);
assert!(
eigenvalues.contains(&Expression::integer(1))
&& eigenvalues.contains(&Expression::integer(3))
);
}
#[test]
fn test_eigenvectors_2x2() {
let matrix = Matrix::diagonal(vec![Expression::integer(2), Expression::integer(3)]);
if let Some(decomp) = matrix.eigen_decomposition() {
assert_eq!(decomp.eigenvalues.len(), 2);
let (rows, cols) = decomp.eigenvectors.dimensions();
assert_eq!(rows, 2);
assert_eq!(cols, 2);
}
}
#[test]
fn test_lu_decomposition() {
let matrix = Matrix::from_arrays([[2, 1, 1], [4, 3, 3], [8, 7, 9]]);
let lu = matrix.lu_decomposition();
assert!(
lu.is_some(),
"LU decomposition should succeed for this matrix"
);
let decomp = lu.unwrap();
let (l_rows, l_cols) = decomp.l.dimensions();
let (u_rows, u_cols) = decomp.u.dimensions();
assert_eq!(l_rows, 3);
assert_eq!(l_cols, 3);
assert_eq!(u_rows, 3);
assert_eq!(u_cols, 3);
}
#[test]
fn test_qr_decomposition() {
let matrix = Matrix::from_arrays([[1, 1, 0], [1, 0, 1], [0, 1, 1]]);
let qr = matrix.qr_decomposition();
assert!(qr.is_some(), "QR decomposition should succeed");
let decomp = qr.unwrap();
let (q_rows, q_cols) = decomp.q.dimensions();
let (r_rows, r_cols) = decomp.r.dimensions();
assert_eq!(q_rows, 3);
assert_eq!(q_cols, 3);
assert_eq!(r_rows, 3);
assert_eq!(r_cols, 3);
}
#[test]
fn test_svd_decomposition() {
let matrix = Matrix::from_arrays([[1, 2], [3, 4], [5, 6]]);
let svd = matrix.svd_decomposition();
assert!(svd.is_some(), "SVD decomposition should succeed");
let decomp = svd.unwrap();
let (u_rows, _u_cols) = decomp.u.dimensions();
let (vt_rows, _vt_cols) = decomp.vt.dimensions();
assert_eq!(u_rows, 3); assert_eq!(vt_rows, 2); }
#[test]
fn test_cholesky_decomposition() {
let matrix = Matrix::from_arrays([[4, 2, 1], [2, 3, 0], [1, 0, 2]]);
let chol = matrix.cholesky_decomposition();
if let Some(decomp) = chol {
let (l_rows, l_cols) = decomp.l.dimensions();
assert_eq!(l_rows, 3);
assert_eq!(l_cols, 3);
}
}
#[test]
#[ignore = "BUG: rank() returns 0 for rank-deficient matrices"]
fn test_matrix_rank() {
let full_rank = Matrix::from_arrays([[1, 0, 0], [0, 1, 0], [0, 0, 1]]);
assert_eq!(full_rank.rank(), 3);
let rank_2 = Matrix::from_arrays([[1, 2, 3], [4, 5, 6], [5, 7, 9]]);
assert_eq!(rank_2.rank(), 2);
}
#[test]
#[ignore = "BUG: is_symmetric() returns false for symmetric matrices"]
fn test_matrix_is_symmetric() {
let symmetric = Matrix::from_arrays([[1, 2, 3], [2, 4, 5], [3, 5, 6]]);
assert!(symmetric.is_symmetric());
let non_symmetric = Matrix::from_arrays([[1, 2, 3], [4, 5, 6], [7, 8, 9]]);
assert!(!non_symmetric.is_symmetric());
}
#[test]
fn test_matrix_condition_number() {
let identity = Matrix::identity(3);
let cond = identity.condition_number();
match cond {
Expression::Number(n) => {
let val = number_to_f64(&n);
assert!(
(1.0..1.1).contains(&val),
"Identity should have cond ~1, got {}",
val
);
}
_ => {
}
}
}
#[test]
fn test_solve_linear_system_ax_equals_b() {
let a = Matrix::from_arrays([[2, 1], [1, 3]]);
let b = vec![Expression::integer(5), Expression::integer(7)];
let result = a.solve(&b);
assert!(
result.is_ok(),
"solve should succeed for non-singular system"
);
let x = result.unwrap();
assert_eq!(x.len(), 2);
}
#[test]
fn test_least_squares_overdetermined() {
let a = Matrix::from_arrays([[1, 1], [1, 2], [1, 3]]);
let b = vec![
Expression::integer(1),
Expression::integer(2),
Expression::integer(2),
];
let result = a.solve_least_squares(&b);
assert!(result.is_ok(), "least squares should succeed");
let x = result.unwrap();
assert_eq!(x.len(), 2);
}
#[test]
fn test_matrix_exponential() {
let zero = Matrix::from_arrays([[0, 0], [0, 0]]);
let exp_zero = zero.matrix_exponential();
if let Some(result) = exp_zero {
let (rows, cols) = result.dimensions();
assert_eq!(rows, 2);
assert_eq!(cols, 2);
}
}