#[cfg(test)]
mod tests {
use crate::core::Expression;
use crate::matrices::{CoreMatrixOps, Matrix, MatrixDecomposition};
/// Structural checks on the LU factorization of a dense 2x2 integer matrix.
///
/// Verifies that `L * U` has the same shape as `P * A` (or as `A` when no
/// permutation is reported), that `L` has ones on its diagonal and zeros above
/// it, and that `U` has zeros below its diagonal. Only shapes of the products
/// are compared here, not their entries.
#[test]
fn test_lu_decomposition_correctness() {
    let a = Matrix::dense(vec![
        vec![Expression::integer(2), Expression::integer(1)],
        vec![Expression::integer(4), Expression::integer(3)],
    ]);
    let factors = a.lu_decomposition().unwrap();
    let recomposed = factors
        .l
        .multiply(&factors.u)
        .expect("L * U should succeed");

    // The reference shape is that of P * A when a permutation exists, else A.
    let reference_shape = match &factors.p {
        Some(perm) => perm
            .multiply(&a)
            .expect("P * A should succeed")
            .dimensions(),
        None => a.dimensions(),
    };
    assert_eq!(reference_shape, recomposed.dimensions());

    // L: unit diagonal, zeros strictly above the diagonal.
    let (l_rows, l_cols) = factors.l.dimensions();
    for (row, col) in (0..l_rows).flat_map(|r| (0..l_cols).map(move |c| (r, c))) {
        let entry = factors.l.get_element(row, col);
        if row == col {
            assert_eq!(entry, Expression::integer(1));
        } else if row < col {
            assert!(entry.is_zero());
        }
    }

    // U: zeros strictly below the diagonal.
    let (u_rows, u_cols) = factors.u.dimensions();
    let below_diagonal = (0..u_rows)
        .flat_map(|r| (0..u_cols).map(move |c| (r, c)))
        .filter(|&(r, c)| r > c);
    for (row, col) in below_diagonal {
        let entry = factors.u.get_element(row, col);
        assert!(entry.is_zero());
    }
}
/// LU factorization of structured matrices (identity, diagonal).
///
/// The identity must factor into two `Matrix::Identity` values; a diagonal
/// matrix must yield an identity `L` and a `U` equal to the input itself.
#[test]
fn test_lu_decomposition_special_cases() {
    // Identity case: both factors keep the Identity representation.
    let eye = Matrix::identity(3);
    let eye_factors = eye.lu_decomposition().unwrap();
    for factor in [&eye_factors.l, &eye_factors.u] {
        assert!(matches!(factor, Matrix::Identity(_)));
    }

    // Diagonal case: L is the identity and U is the input unchanged.
    let diag = Matrix::diagonal(vec![
        Expression::integer(2),
        Expression::integer(3),
        Expression::integer(4),
    ]);
    let diag_factors = diag.lu_decomposition().unwrap();
    assert!(matches!(diag_factors.l, Matrix::Identity(_)));
    assert_eq!(diag_factors.u, diag);
}
/// Structural checks on the QR factorization of a dense 2x2 integer matrix.
///
/// Verifies that `Q * R` has the shape of the input, that every entry of `R`
/// below the diagonal is zero, and that `Q` has the same dimensions as the
/// input. Entry-wise equality of `Q * R` and the input is not checked.
#[test]
fn test_qr_decomposition_correctness() {
    let a = Matrix::dense(vec![
        vec![Expression::integer(1), Expression::integer(1)],
        vec![Expression::integer(0), Expression::integer(1)],
    ]);
    let factors = a.qr_decomposition().unwrap();

    let recomposed = factors.q.multiply(&factors.r).unwrap();
    assert_eq!(a.dimensions(), recomposed.dimensions());

    // R must be upper triangular: for each row, only columns left of the
    // diagonal (col < row) are inspected.
    let (r_rows, r_cols) = factors.r.dimensions();
    for row in 0..r_rows {
        for col in (0..r_cols).take(row) {
            let entry = factors.r.get_element(row, col);
            assert!(entry.is_zero());
        }
    }

    // Q is expected to match the input's row and column counts.
    let (q_rows, q_cols) = factors.q.dimensions();
    assert_eq!(q_rows, a.dimensions().0);
    assert_eq!(q_cols, a.dimensions().1);
}
/// QR factorization of structured matrices (identity, zero).
///
/// The identity must factor into two `Matrix::Identity` values; the zero
/// matrix must yield an identity `Q` and a `Matrix::Zero` for `R`.
#[test]
fn test_qr_decomposition_special_cases() {
    // Identity case: Q and R both keep the Identity representation.
    let eye = Matrix::identity(2);
    let eye_factors = eye.qr_decomposition().unwrap();
    for factor in [&eye_factors.q, &eye_factors.r] {
        assert!(matches!(factor, Matrix::Identity(_)));
    }

    // Zero case: Q = I, R = 0.
    let nil = Matrix::zero(2, 2);
    let nil_factors = nil.qr_decomposition().unwrap();
    assert!(matches!(nil_factors.q, Matrix::Identity(_)));
    assert!(matches!(nil_factors.r, Matrix::Zero(_)));
}
/// Structural checks on the Cholesky factorization of a dense 2x2 matrix.
///
/// The input `[[4, 2], [2, 3]]` is symmetric with positive leading minors
/// (4 and 4*3 - 2*2 = 8), so a factorization must exist. The test verifies
/// that `L * L^T` has the shape of the input and that every entry of `L`
/// above the diagonal is zero. Entry-wise equality is not checked.
#[test]
fn test_cholesky_decomposition_correctness() {
    let matrix = Matrix::dense(vec![
        vec![Expression::integer(4), Expression::integer(2)],
        vec![Expression::integer(2), Expression::integer(3)],
    ]);

    // Fail loudly when no factorization is produced: silently skipping the
    // assertions on `None` would let the test pass without checking anything.
    let chol = matrix
        .cholesky_decomposition()
        .expect("symmetric positive-definite matrix should have a Cholesky decomposition");

    let l_transpose = chol.l.transpose();
    let llt_product = chol.l.multiply(&l_transpose).unwrap();
    assert_eq!(matrix.dimensions(), llt_product.dimensions());

    // L must be lower triangular: entries with row < col are zero.
    let (l_rows, l_cols) = chol.l.dimensions();
    for i in 0..l_rows {
        for j in (i + 1)..l_cols {
            let elem = chol.l.get_element(i, j);
            assert!(elem.is_zero());
        }
    }
}
/// Cholesky factorization of structured matrices.
///
/// Each structured input (identity, scalar, diagonal) must produce a factor
/// `L` stored in the same structured variant as the input.
#[test]
fn test_cholesky_decomposition_special_cases() {
    // Identity -> Identity.
    let eye = Matrix::identity(3);
    let eye_factor = eye.cholesky_decomposition().unwrap();
    assert!(matches!(eye_factor.l, Matrix::Identity(_)));

    // Scalar matrix -> Scalar.
    let scaled = Matrix::scalar(2, Expression::integer(4));
    let scaled_factor = scaled.cholesky_decomposition().unwrap();
    assert!(matches!(scaled_factor.l, Matrix::Scalar(_)));

    // Diagonal matrix -> Diagonal.
    let diag = Matrix::diagonal(vec![Expression::integer(4), Expression::integer(9)]);
    let diag_factor = diag.cholesky_decomposition().unwrap();
    assert!(matches!(diag_factor.l, Matrix::Diagonal(_)));
}
/// Structural checks on the SVD of a dense 2x2 integer matrix.
///
/// Verifies that `U * (Sigma * V^T)` has the shape of the input and that every
/// off-diagonal entry of `Sigma` is zero. Entry-wise equality of the product
/// and the input is not checked.
#[test]
fn test_svd_decomposition_correctness() {
    let a = Matrix::dense(vec![
        vec![Expression::integer(1), Expression::integer(2)],
        vec![Expression::integer(3), Expression::integer(4)],
    ]);
    let factors = a.svd_decomposition().unwrap();

    // Recompose right-to-left: first Sigma * V^T, then U * (that).
    let right = factors.sigma.multiply(&factors.vt).unwrap();
    let recomposed = factors.u.multiply(&right).unwrap();
    assert_eq!(a.dimensions(), recomposed.dimensions());

    // Sigma must be diagonal.
    let (rows, cols) = factors.sigma.dimensions();
    let off_diagonal = (0..rows)
        .flat_map(|r| (0..cols).map(move |c| (r, c)))
        .filter(|&(r, c)| r != c);
    for (row, col) in off_diagonal {
        let entry = factors.sigma.get_element(row, col);
        assert!(entry.is_zero());
    }
}
/// SVD of structured matrices (identity, zero, diagonal).
///
/// The identity must yield three `Matrix::Identity` factors; the zero matrix
/// must yield a `Matrix::Zero` sigma; a diagonal matrix must yield a
/// `Matrix::Diagonal` sigma.
#[test]
fn test_svd_special_cases() {
    // Identity case: U, Sigma and V^T all keep the Identity representation.
    let eye = Matrix::identity(2);
    let eye_factors = eye.svd_decomposition().unwrap();
    for factor in [&eye_factors.u, &eye_factors.sigma, &eye_factors.vt] {
        assert!(matches!(factor, Matrix::Identity(_)));
    }

    // Zero case: only Sigma's representation is checked.
    let nil = Matrix::zero(2, 2);
    let nil_factors = nil.svd_decomposition().unwrap();
    assert!(matches!(nil_factors.sigma, Matrix::Zero(_)));

    // Diagonal case: only Sigma's representation is checked.
    let diag = Matrix::diagonal(vec![Expression::integer(3), Expression::integer(4)]);
    let diag_factors = diag.svd_decomposition().unwrap();
    assert!(matches!(diag_factors.sigma, Matrix::Diagonal(_)));
}
/// Rank of structured matrices: full rank for the identity, zero for the zero
/// matrix, and the count of non-zero entries for a diagonal matrix.
#[test]
fn test_matrix_rank() {
    // 3x3 identity has rank 3.
    assert_eq!(Matrix::identity(3).rank(), 3);

    // 3x3 zero matrix has rank 0.
    assert_eq!(Matrix::zero(3, 3).rank(), 0);

    // diag(1, 0, 3) has two non-zero diagonal entries, hence rank 2.
    let with_zero_pivot = Matrix::diagonal(vec![
        Expression::integer(1),
        Expression::integer(0),
        Expression::integer(3),
    ]);
    assert_eq!(with_zero_pivot.rank(), 2);
}
/// Positive-definiteness is reported for structured matrices whose diagonal
/// entries are all positive. Only positive cases are exercised here.
#[test]
fn test_positive_definite_check() {
    // Identity.
    assert!(Matrix::identity(2).is_positive_definite());

    // Scalar matrix 5 * I.
    assert!(Matrix::scalar(2, Expression::integer(5)).is_positive_definite());

    // diag(1, 2, 3).
    let positive_diag = Matrix::diagonal(vec![
        Expression::integer(1),
        Expression::integer(2),
        Expression::integer(3),
    ]);
    assert!(positive_diag.is_positive_definite());
}
/// Condition number of matrices whose diagonal entries are all equal.
///
/// Both the identity and `diag(2, 2)` are expected to report a condition
/// number equal to the integer expression `1`.
#[test]
fn test_condition_number() {
    // Identity.
    let eye_cond = Matrix::identity(2).condition_number();
    assert_eq!(eye_cond, Expression::integer(1));

    // diag(2, 2).
    let uniform_diag = Matrix::diagonal(vec![Expression::integer(2), Expression::integer(2)]);
    assert_eq!(uniform_diag.condition_number(), Expression::integer(1));
}
}