/// Errors returned by fallible [`Matrix`] operations.
///
/// Every payload is plain `usize` data, so the type is `Eq` as well as
/// `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MatrixError {
    /// Operand shapes are incompatible for the requested operation.
    DimensionMismatch {
        /// `(rows, cols)` shape the operation expected.
        expected: (usize, usize),
        /// `(rows, cols)` shape that was actually supplied.
        got: (usize, usize),
    },
    /// The operation requires a square matrix but got a `rows × cols` one.
    NotSquare { rows: usize, cols: usize },
    /// The matrix is singular (no usable pivot was found).
    Singular,
    /// The operation requires a symmetric matrix.
    NotSymmetric,
    /// The operation requires a positive-definite matrix.
    NotPositiveDefinite,
    /// The supplied data length does not match `rows * cols`.
    InvalidData { expected: usize, got: usize },
}
impl std::fmt::Display for MatrixError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
MatrixError::DimensionMismatch { expected, got } => {
write!(
f,
"dimension mismatch: expected {}×{}, got {}×{}",
expected.0, expected.1, got.0, got.1
)
}
MatrixError::NotSquare { rows, cols } => {
write!(f, "matrix must be square, got {rows}×{cols}")
}
MatrixError::Singular => write!(f, "matrix is singular"),
MatrixError::NotSymmetric => write!(f, "matrix is not symmetric"),
MatrixError::NotPositiveDefinite => write!(f, "matrix is not positive-definite"),
MatrixError::InvalidData { expected, got } => {
write!(f, "data length mismatch: expected {expected}, got {got}")
}
}
}
}
impl std::error::Error for MatrixError {}
/// A dense matrix of `f64` values backed by a single `Vec<f64>`.
///
/// The constructors in this file keep `data.len() == rows * cols`.
#[derive(Debug, Clone, PartialEq)]
pub struct Matrix {
    // Entries in row-major order: element (r, c) is stored at `r * cols + c`.
    data: Vec<f64>,
    // Number of rows.
    rows: usize,
    // Number of columns.
    cols: usize,
}
impl Matrix {
pub fn new(rows: usize, cols: usize, data: Vec<f64>) -> Result<Self, MatrixError> {
if data.len() != rows * cols {
return Err(MatrixError::InvalidData {
expected: rows * cols,
got: data.len(),
});
}
Ok(Self { data, rows, cols })
}
pub fn from_rows(rows: &[&[f64]]) -> Self {
assert!(!rows.is_empty(), "must have at least one row");
let ncols = rows[0].len();
assert!(ncols > 0, "must have at least one column");
let nrows = rows.len();
let mut data = Vec::with_capacity(nrows * ncols);
for (i, row) in rows.iter().enumerate() {
assert_eq!(
row.len(),
ncols,
"row {i} has {} columns, expected {ncols}",
row.len()
);
data.extend_from_slice(row);
}
Self {
data,
rows: nrows,
cols: ncols,
}
}
pub fn zeros(rows: usize, cols: usize) -> Self {
Self {
data: vec![0.0; rows * cols],
rows,
cols,
}
}
pub fn identity(n: usize) -> Self {
let mut m = Self::zeros(n, n);
for i in 0..n {
m.data[i * n + i] = 1.0;
}
m
}
pub fn from_col(data: &[f64]) -> Self {
Self {
data: data.to_vec(),
rows: data.len(),
cols: 1,
}
}
#[inline]
pub fn rows(&self) -> usize {
self.rows
}
#[inline]
pub fn cols(&self) -> usize {
self.cols
}
#[inline]
pub fn get(&self, row: usize, col: usize) -> f64 {
self.data[row * self.cols + col]
}
#[inline]
pub fn set(&mut self, row: usize, col: usize, value: f64) {
self.data[row * self.cols + col] = value;
}
pub fn data(&self) -> &[f64] {
&self.data
}
#[inline]
pub fn row(&self, row: usize) -> &[f64] {
let start = row * self.cols;
&self.data[start..start + self.cols]
}
pub fn diag(&self) -> Vec<f64> {
let n = self.rows.min(self.cols);
(0..n).map(|i| self.get(i, i)).collect()
}
pub fn is_square(&self) -> bool {
self.rows == self.cols
}
pub fn transpose(&self) -> Self {
let mut result = Self::zeros(self.cols, self.rows);
for i in 0..self.rows {
for j in 0..self.cols {
result.data[j * self.rows + i] = self.data[i * self.cols + j];
}
}
result
}
pub fn add(&self, other: &Self) -> Result<Self, MatrixError> {
if self.rows != other.rows || self.cols != other.cols {
return Err(MatrixError::DimensionMismatch {
expected: (self.rows, self.cols),
got: (other.rows, other.cols),
});
}
let data: Vec<f64> = self
.data
.iter()
.zip(&other.data)
.map(|(a, b)| a + b)
.collect();
Ok(Self {
data,
rows: self.rows,
cols: self.cols,
})
}
pub fn sub(&self, other: &Self) -> Result<Self, MatrixError> {
if self.rows != other.rows || self.cols != other.cols {
return Err(MatrixError::DimensionMismatch {
expected: (self.rows, self.cols),
got: (other.rows, other.cols),
});
}
let data: Vec<f64> = self
.data
.iter()
.zip(&other.data)
.map(|(a, b)| a - b)
.collect();
Ok(Self {
data,
rows: self.rows,
cols: self.cols,
})
}
pub fn scale(&self, c: f64) -> Self {
let data: Vec<f64> = self.data.iter().map(|x| c * x).collect();
Self {
data,
rows: self.rows,
cols: self.cols,
}
}
pub fn mul_mat(&self, other: &Self) -> Result<Self, MatrixError> {
if self.cols != other.rows {
return Err(MatrixError::DimensionMismatch {
expected: (self.rows, self.cols),
got: (other.rows, other.cols),
});
}
let mut result = Self::zeros(self.rows, other.cols);
for i in 0..self.rows {
for k in 0..self.cols {
let a_ik = self.data[i * self.cols + k];
let row_start = i * other.cols;
let other_row_start = k * other.cols;
for j in 0..other.cols {
result.data[row_start + j] += a_ik * other.data[other_row_start + j];
}
}
}
Ok(result)
}
pub fn mul_vec(&self, v: &[f64]) -> Result<Vec<f64>, MatrixError> {
if self.cols != v.len() {
return Err(MatrixError::DimensionMismatch {
expected: (self.rows, self.cols),
got: (v.len(), 1),
});
}
let mut result = vec![0.0; self.rows];
for (i, res) in result.iter_mut().enumerate() {
let row_start = i * self.cols;
*res = self.data[row_start..row_start + self.cols]
.iter()
.zip(v.iter())
.map(|(&a, &b)| a * b)
.sum();
}
Ok(result)
}
pub fn frobenius_norm(&self) -> f64 {
self.data.iter().map(|x| x * x).sum::<f64>().sqrt()
}
pub fn is_symmetric(&self, tol: f64) -> bool {
if self.rows != self.cols {
return false;
}
for i in 0..self.rows {
for j in (i + 1)..self.cols {
if (self.get(i, j) - self.get(j, i)).abs() > tol {
return false;
}
}
}
true
}
fn swap_rows(&mut self, a: usize, b: usize) {
if a == b {
return;
}
let cols = self.cols;
for j in 0..cols {
self.data.swap(a * cols + j, b * cols + j);
}
}
pub fn determinant(&self) -> Result<f64, MatrixError> {
if !self.is_square() {
return Err(MatrixError::NotSquare {
rows: self.rows,
cols: self.cols,
});
}
let n = self.rows;
if n == 0 {
return Ok(1.0);
}
if n == 1 {
return Ok(self.data[0]);
}
let mut work = self.clone();
let mut sign = 1.0_f64;
let pivot_tol = 1e-15 * self.frobenius_norm().max(1e-300);
for k in 0..n {
let mut max_val = work.get(k, k).abs();
let mut max_row = k;
for i in (k + 1)..n {
let v = work.get(i, k).abs();
if v > max_val {
max_val = v;
max_row = i;
}
}
if max_val <= pivot_tol {
return Ok(0.0); }
if max_row != k {
work.swap_rows(k, max_row);
sign = -sign;
}
let pivot = work.get(k, k);
for i in (k + 1)..n {
let factor = work.get(i, k) / pivot;
for j in (k + 1)..n {
let val = work.get(i, j) - factor * work.get(k, j);
work.set(i, j, val);
}
}
}
let mut det = sign;
for i in 0..n {
det *= work.get(i, i);
}
Ok(det)
}
pub fn inverse(&self) -> Result<Self, MatrixError> {
if !self.is_square() {
return Err(MatrixError::NotSquare {
rows: self.rows,
cols: self.cols,
});
}
let n = self.rows;
if n == 0 {
return Ok(Self::zeros(0, 0));
}
let n2 = 2 * n;
let mut aug = Self::zeros(n, n2);
for i in 0..n {
for j in 0..n {
aug.set(i, j, self.get(i, j));
}
aug.set(i, n + i, 1.0);
}
let pivot_tol = 1e-14 * self.frobenius_norm().max(1e-300);
for k in 0..n {
let mut max_val = aug.get(k, k).abs();
let mut max_row = k;
for i in (k + 1)..n {
let v = aug.get(i, k).abs();
if v > max_val {
max_val = v;
max_row = i;
}
}
if max_val <= pivot_tol {
return Err(MatrixError::Singular);
}
if max_row != k {
aug.swap_rows(k, max_row);
}
let pivot = aug.get(k, k);
for j in 0..n2 {
aug.set(k, j, aug.get(k, j) / pivot);
}
for i in 0..n {
if i != k {
let factor = aug.get(i, k);
for j in 0..n2 {
let val = aug.get(i, j) - factor * aug.get(k, j);
aug.set(i, j, val);
}
}
}
}
let mut inv = Self::zeros(n, n);
for i in 0..n {
for j in 0..n {
inv.set(i, j, aug.get(i, n + j));
}
}
Ok(inv)
}
pub fn cholesky(&self) -> Result<Self, MatrixError> {
if !self.is_square() {
return Err(MatrixError::NotSquare {
rows: self.rows,
cols: self.cols,
});
}
let n = self.rows;
let sym_tol = 1e-10 * self.frobenius_norm().max(1e-300);
if !self.is_symmetric(sym_tol) {
return Err(MatrixError::NotSymmetric);
}
let mut l = Self::zeros(n, n);
for j in 0..n {
let mut sum = 0.0;
for k in 0..j {
let ljk = l.get(j, k);
sum += ljk * ljk;
}
let diag = self.get(j, j) - sum;
if diag <= 0.0 {
return Err(MatrixError::NotPositiveDefinite);
}
l.set(j, j, diag.sqrt());
let ljj = l.get(j, j);
for i in (j + 1)..n {
let mut sum = 0.0;
for k in 0..j {
sum += l.get(i, k) * l.get(j, k);
}
l.set(i, j, (self.get(i, j) - sum) / ljj);
}
}
Ok(l)
}
pub fn cholesky_solve(&self, b: &[f64]) -> Result<Vec<f64>, MatrixError> {
if b.len() != self.rows {
return Err(MatrixError::DimensionMismatch {
expected: (self.rows, 1),
got: (b.len(), 1),
});
}
let l = self.cholesky()?;
let y = solve_lower_triangular(&l, b)?;
let lt = l.transpose();
solve_upper_triangular(<, &y)
}
pub fn eigen_symmetric(&self) -> Result<(Vec<f64>, Matrix), MatrixError> {
let n = self.rows;
if !self.is_square() {
return Err(MatrixError::NotSquare {
rows: self.rows,
cols: self.cols,
});
}
let sym_tol = 1e-10 * self.frobenius_norm();
if !self.is_symmetric(sym_tol) {
return Err(MatrixError::NotSymmetric);
}
let mut a = self.data.clone();
let mut v = vec![0.0; n * n];
for i in 0..n {
v[i * n + i] = 1.0;
}
let max_sweeps = 100;
let tol = 1e-15;
for _ in 0..max_sweeps {
let mut off_norm = 0.0;
for i in 0..n {
for j in (i + 1)..n {
off_norm += 2.0 * a[i * n + j] * a[i * n + j];
}
}
off_norm = off_norm.sqrt();
if off_norm < tol {
break;
}
for p in 0..n {
for q in (p + 1)..n {
let apq = a[p * n + q];
if apq.abs() < tol * 0.01 {
continue;
}
let app = a[p * n + p];
let aqq = a[q * n + q];
let diff = aqq - app;
let (cos, sin) = if diff.abs() < 1e-300 {
let s = std::f64::consts::FRAC_1_SQRT_2;
(s, if apq > 0.0 { s } else { -s })
} else {
let tau = diff / (2.0 * apq);
let t = if tau >= 0.0 {
1.0 / (tau + (1.0 + tau * tau).sqrt())
} else {
-1.0 / (-tau + (1.0 + tau * tau).sqrt())
};
let c = 1.0 / (1.0 + t * t).sqrt();
let s = t * c;
(c, s)
};
a[p * n + p] -=
2.0 * sin * cos * apq + sin * sin * (a[q * n + q] - a[p * n + p]);
a[q * n + q] += 2.0 * sin * cos * apq + sin * sin * (aqq - app); a[p * n + q] = 0.0;
a[q * n + p] = 0.0;
a[p * n + p] = app;
a[q * n + q] = aqq;
a[p * n + q] = apq;
a[q * n + p] = apq;
for r in 0..n {
if r == p || r == q {
continue;
}
let arp = a[r * n + p];
let arq = a[r * n + q];
a[r * n + p] = cos * arp - sin * arq;
a[r * n + q] = sin * arp + cos * arq;
a[p * n + r] = a[r * n + p]; a[q * n + r] = a[r * n + q]; }
let new_pp = cos * cos * app - 2.0 * sin * cos * apq + sin * sin * aqq;
let new_qq = sin * sin * app + 2.0 * sin * cos * apq + cos * cos * aqq;
a[p * n + p] = new_pp;
a[q * n + q] = new_qq;
a[p * n + q] = 0.0;
a[q * n + p] = 0.0;
for r in 0..n {
let vp = v[r * n + p];
let vq = v[r * n + q];
v[r * n + p] = cos * vp - sin * vq;
v[r * n + q] = sin * vp + cos * vq;
}
}
}
}
let mut eigen_pairs: Vec<(f64, Vec<f64>)> = (0..n)
.map(|i| {
let eigenvalue = a[i * n + i];
let eigenvector: Vec<f64> = (0..n).map(|r| v[r * n + i]).collect();
(eigenvalue, eigenvector)
})
.collect();
eigen_pairs.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
let eigenvalues: Vec<f64> = eigen_pairs.iter().map(|(val, _)| *val).collect();
let mut eigvec_data = vec![0.0; n * n];
for (col, (_, vec)) in eigen_pairs.iter().enumerate() {
for (row, &val) in vec.iter().enumerate() {
eigvec_data[row * n + col] = val;
}
}
let eigenvectors = Matrix {
data: eigvec_data,
rows: n,
cols: n,
};
Ok((eigenvalues, eigenvectors))
}
}
/// Solves `L·x = b` by forward substitution.
///
/// Only the lower triangle of `l`, including the diagonal, is read.
///
/// # Errors
///
/// Returns [`MatrixError::Singular`] if a diagonal entry has magnitude below
/// `1e-300`.
///
/// # Panics
///
/// Panics if `b` has fewer than `l.rows()` entries.
fn solve_lower_triangular(l: &Matrix, b: &[f64]) -> Result<Vec<f64>, MatrixError> {
    let size = l.rows();
    let mut solution: Vec<f64> = Vec::with_capacity(size);
    for row in 0..size {
        // `solution` currently holds exactly the already-solved unknowns 0..row.
        let partial = solution
            .iter()
            .enumerate()
            .fold(0.0, |acc, (col, &known)| acc + l.get(row, col) * known);
        let pivot = l.get(row, row);
        if pivot.abs() < 1e-300 {
            return Err(MatrixError::Singular);
        }
        solution.push((b[row] - partial) / pivot);
    }
    Ok(solution)
}
/// Solves `U·x = b` by back substitution.
///
/// Only the upper triangle of `u`, including the diagonal, is read.
///
/// # Errors
///
/// Returns [`MatrixError::Singular`] if a diagonal entry has magnitude below
/// `1e-300`.
///
/// # Panics
///
/// Panics if `b` has fewer than `u.rows()` entries.
fn solve_upper_triangular(u: &Matrix, b: &[f64]) -> Result<Vec<f64>, MatrixError> {
    let size = u.rows();
    let mut solution = vec![0.0; size];
    for row in (0..size).rev() {
        // Unknowns row+1..size were solved in earlier (later-indexed) iterations.
        let partial = ((row + 1)..size)
            .fold(0.0, |acc, col| acc + u.get(row, col) * solution[col]);
        let pivot = u.get(row, row);
        if pivot.abs() < 1e-300 {
            return Err(MatrixError::Singular);
        }
        solution[row] = (b[row] - partial) / pivot;
    }
    Ok(solution)
}
impl std::fmt::Display for Matrix {
    /// Renders one bracketed, comma-separated row per line. Each entry is
    /// right-aligned in a 10-character field with 4 decimal places.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        (0..self.rows).try_for_each(|r| {
            f.write_str("[")?;
            for (c, value) in self.row(r).iter().enumerate() {
                let sep = if c == 0 { "" } else { ", " };
                write!(f, "{sep}{value:>10.4}")?;
            }
            f.write_str("]\n")
        })
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests with hand-computed expected values.
    use super::*;

    #[test]
    fn test_new_valid() {
        let m = Matrix::new(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        assert_eq!(m.rows(), 2);
        assert_eq!(m.cols(), 3);
        assert_eq!(m.get(0, 0), 1.0);
        assert_eq!(m.get(1, 2), 6.0);
    }
    #[test]
    fn test_new_invalid_length() {
        assert!(Matrix::new(2, 3, vec![1.0, 2.0]).is_err());
    }
    #[test]
    fn test_from_rows() {
        let m = Matrix::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);
        assert_eq!(m.get(0, 0), 1.0);
        assert_eq!(m.get(1, 1), 4.0);
    }
    #[test]
    fn test_zeros() {
        let m = Matrix::zeros(3, 4);
        assert_eq!(m.rows(), 3);
        assert_eq!(m.cols(), 4);
        assert_eq!(m.get(2, 3), 0.0);
    }
    #[test]
    fn test_identity() {
        let eye = Matrix::identity(3);
        assert_eq!(eye.get(0, 0), 1.0);
        assert_eq!(eye.get(1, 1), 1.0);
        assert_eq!(eye.get(2, 2), 1.0);
        assert_eq!(eye.get(0, 1), 0.0);
        assert_eq!(eye.get(1, 2), 0.0);
    }
    #[test]
    fn test_diag() {
        let m = Matrix::from_rows(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0], &[7.0, 8.0, 9.0]]);
        assert_eq!(m.diag(), vec![1.0, 5.0, 9.0]);
    }
    #[test]
    fn test_transpose() {
        let m = Matrix::from_rows(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]]);
        let t = m.transpose();
        assert_eq!(t.rows(), 3);
        assert_eq!(t.cols(), 2);
        assert_eq!(t.get(0, 0), 1.0);
        assert_eq!(t.get(2, 1), 6.0);
    }
    #[test]
    fn test_transpose_twice() {
        let m = Matrix::from_rows(&[&[1.0, 2.0], &[3.0, 4.0], &[5.0, 6.0]]);
        let tt = m.transpose().transpose();
        assert_eq!(m, tt);
    }
    #[test]
    fn test_add() {
        let a = Matrix::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);
        let b = Matrix::from_rows(&[&[5.0, 6.0], &[7.0, 8.0]]);
        let c = a.add(&b).unwrap();
        assert_eq!(c.get(0, 0), 6.0);
        assert_eq!(c.get(1, 1), 12.0);
    }
    #[test]
    fn test_add_dimension_mismatch() {
        let a = Matrix::zeros(2, 3);
        let b = Matrix::zeros(3, 2);
        assert!(a.add(&b).is_err());
    }
    #[test]
    fn test_sub() {
        let a = Matrix::from_rows(&[&[5.0, 6.0], &[7.0, 8.0]]);
        let b = Matrix::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);
        let c = a.sub(&b).unwrap();
        assert_eq!(c.get(0, 0), 4.0);
        assert_eq!(c.get(1, 1), 4.0);
    }
    #[test]
    fn test_scale() {
        let m = Matrix::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);
        let s = m.scale(2.0);
        assert_eq!(s.get(0, 0), 2.0);
        assert_eq!(s.get(1, 1), 8.0);
    }
    #[test]
    fn test_mul_identity() {
        let a = Matrix::from_rows(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0], &[7.0, 8.0, 9.0]]);
        let eye = Matrix::identity(3);
        let result = a.mul_mat(&eye).unwrap();
        assert_eq!(a, result);
    }
    #[test]
    fn test_mul_2x2() {
        let a = Matrix::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);
        let b = Matrix::from_rows(&[&[5.0, 6.0], &[7.0, 8.0]]);
        let c = a.mul_mat(&b).unwrap();
        assert_eq!(c.get(0, 0), 19.0);
        assert_eq!(c.get(0, 1), 22.0);
        assert_eq!(c.get(1, 0), 43.0);
        assert_eq!(c.get(1, 1), 50.0);
    }
    #[test]
    fn test_mul_nonsquare() {
        let a = Matrix::from_rows(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]]);
        let b = Matrix::from_rows(&[&[7.0, 8.0], &[9.0, 10.0], &[11.0, 12.0]]);
        let c = a.mul_mat(&b).unwrap();
        assert_eq!(c.rows(), 2);
        assert_eq!(c.cols(), 2);
        assert_eq!(c.get(0, 0), 58.0);
        assert_eq!(c.get(0, 1), 64.0);
    }
    #[test]
    fn test_mul_vec() {
        let a = Matrix::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);
        let v = vec![5.0, 6.0];
        let result = a.mul_vec(&v).unwrap();
        assert_eq!(result, vec![17.0, 39.0]);
    }
    #[test]
    fn test_mul_vec_dimension_mismatch() {
        let a = Matrix::zeros(2, 3);
        assert!(matches!(
            a.mul_vec(&[1.0, 2.0]),
            Err(MatrixError::DimensionMismatch { .. })
        ));
    }
    #[test]
    fn test_mul_dimension_mismatch() {
        let a = Matrix::zeros(2, 3);
        let b = Matrix::zeros(2, 3);
        assert!(a.mul_mat(&b).is_err());
    }
    #[test]
    fn test_det_2x2() {
        let m = Matrix::from_rows(&[&[2.0, 3.0], &[1.0, 4.0]]);
        assert!((m.determinant().unwrap() - 5.0).abs() < 1e-10);
    }
    #[test]
    fn test_det_3x3() {
        let m = Matrix::from_rows(&[&[6.0, 1.0, 1.0], &[4.0, -2.0, 5.0], &[2.0, 8.0, 7.0]]);
        assert!((m.determinant().unwrap() - (-306.0)).abs() < 1e-8);
    }
    #[test]
    fn test_det_identity() {
        let eye = Matrix::identity(4);
        assert!((eye.determinant().unwrap() - 1.0).abs() < 1e-10);
    }
    #[test]
    fn test_det_singular() {
        let m = Matrix::from_rows(&[&[1.0, 2.0], &[2.0, 4.0]]);
        assert!(m.determinant().unwrap().abs() < 1e-10);
    }
    #[test]
    fn test_det_not_square() {
        let m = Matrix::zeros(2, 3);
        assert!(m.determinant().is_err());
    }
    #[test]
    fn test_inverse_2x2() {
        let a = Matrix::from_rows(&[&[4.0, 7.0], &[2.0, 6.0]]);
        let inv = a.inverse().unwrap();
        let eye = a.mul_mat(&inv).unwrap();
        for i in 0..2 {
            for j in 0..2 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (eye.get(i, j) - expected).abs() < 1e-10,
                    "A·A⁻¹[{i},{j}] = {}, expected {expected}",
                    eye.get(i, j)
                );
            }
        }
    }
    #[test]
    fn test_inverse_3x3() {
        let a = Matrix::from_rows(&[&[1.0, 2.0, 3.0], &[0.0, 1.0, 4.0], &[5.0, 6.0, 0.0]]);
        let inv = a.inverse().unwrap();
        let eye = a.mul_mat(&inv).unwrap();
        for i in 0..3 {
            for j in 0..3 {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!(
                    (eye.get(i, j) - expected).abs() < 1e-10,
                    "A·A⁻¹[{i},{j}] = {}",
                    eye.get(i, j)
                );
            }
        }
    }
    #[test]
    fn test_inverse_identity() {
        let eye = Matrix::identity(4);
        let inv = eye.inverse().unwrap();
        assert_eq!(eye, inv);
    }
    #[test]
    fn test_inverse_singular() {
        let m = Matrix::from_rows(&[&[1.0, 2.0], &[2.0, 4.0]]);
        assert!(m.inverse().is_err());
    }
    #[test]
    fn test_cholesky_2x2() {
        let a = Matrix::from_rows(&[&[4.0, 2.0], &[2.0, 3.0]]);
        let l = a.cholesky().unwrap();
        assert!(l.get(0, 1).abs() < 1e-15);
        let llt = l.mul_mat(&l.transpose()).unwrap();
        for i in 0..2 {
            for j in 0..2 {
                assert!(
                    (llt.get(i, j) - a.get(i, j)).abs() < 1e-10,
                    "LLᵀ[{i},{j}] = {}, expected {}",
                    llt.get(i, j),
                    a.get(i, j)
                );
            }
        }
    }
    #[test]
    fn test_cholesky_3x3() {
        let a = Matrix::from_rows(&[&[25.0, 15.0, -5.0], &[15.0, 18.0, 0.0], &[-5.0, 0.0, 11.0]]);
        let l = a.cholesky().unwrap();
        let llt = l.mul_mat(&l.transpose()).unwrap();
        for i in 0..3 {
            for j in 0..3 {
                assert!(
                    (llt.get(i, j) - a.get(i, j)).abs() < 1e-10,
                    "LLᵀ[{i},{j}] = {}, A[{i},{j}] = {}",
                    llt.get(i, j),
                    a.get(i, j)
                );
            }
        }
    }
    #[test]
    fn test_cholesky_identity() {
        let eye = Matrix::identity(3);
        let l = eye.cholesky().unwrap();
        assert_eq!(l, eye);
    }
    #[test]
    fn test_cholesky_not_positive_definite() {
        let a = Matrix::from_rows(&[&[1.0, 2.0], &[2.0, 1.0]]);
        assert!(matches!(
            a.cholesky(),
            Err(MatrixError::NotPositiveDefinite)
        ));
    }
    #[test]
    fn test_cholesky_not_symmetric() {
        let a = Matrix::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);
        assert!(matches!(a.cholesky(), Err(MatrixError::NotSymmetric)));
    }
    #[test]
    fn test_cholesky_solve() {
        let a = Matrix::from_rows(&[&[4.0, 2.0], &[2.0, 3.0]]);
        let b = vec![1.0, 2.0];
        let x = a.cholesky_solve(&b).unwrap();
        let ax = a.mul_vec(&x).unwrap();
        for i in 0..2 {
            assert!(
                (ax[i] - b[i]).abs() < 1e-10,
                "Ax[{i}] = {}, b[{i}] = {}",
                ax[i],
                b[i]
            );
        }
    }
    #[test]
    fn test_cholesky_solve_3x3() {
        let a = Matrix::from_rows(&[&[25.0, 15.0, -5.0], &[15.0, 18.0, 0.0], &[-5.0, 0.0, 11.0]]);
        let b = vec![35.0, 33.0, 6.0];
        let x = a.cholesky_solve(&b).unwrap();
        let ax = a.mul_vec(&x).unwrap();
        for i in 0..3 {
            assert!((ax[i] - b[i]).abs() < 1e-10);
        }
    }
    #[test]
    fn test_cholesky_solve_dimension_mismatch() {
        let a = Matrix::identity(2);
        assert!(matches!(
            a.cholesky_solve(&[1.0]),
            Err(MatrixError::DimensionMismatch { .. })
        ));
    }
    #[test]
    fn test_frobenius_norm() {
        let m = Matrix::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);
        assert!((m.frobenius_norm() - 30.0_f64.sqrt()).abs() < 1e-10);
    }
    #[test]
    fn test_is_symmetric() {
        let sym = Matrix::from_rows(&[&[1.0, 2.0, 3.0], &[2.0, 5.0, 6.0], &[3.0, 6.0, 9.0]]);
        assert!(sym.is_symmetric(1e-10));
        let asym = Matrix::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);
        assert!(!asym.is_symmetric(1e-10));
    }

    /// Checks `A·v_k ≈ λ_k·v_k` for every returned pair and that the
    /// eigenvector matrix is orthonormal.
    fn assert_eigen_pairs(a: &Matrix, vals: &[f64], vecs: &Matrix, tol: f64) {
        let n = a.rows();
        for k in 0..n {
            let col: Vec<f64> = (0..n).map(|r| vecs.get(r, k)).collect();
            let av = a.mul_vec(&col).unwrap();
            for r in 0..n {
                assert!(
                    (av[r] - vals[k] * col[r]).abs() < tol,
                    "residual at ({r}, {k}): {} vs {}",
                    av[r],
                    vals[k] * col[r]
                );
            }
        }
        let vtv = vecs.transpose().mul_mat(vecs).unwrap();
        for i in 0..n {
            for j in 0..n {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!((vtv.get(i, j) - expected).abs() < tol);
            }
        }
    }
    #[test]
    fn test_eigen_symmetric_2x2() {
        let a = Matrix::from_rows(&[&[2.0, 1.0], &[1.0, 2.0]]);
        let (vals, vecs) = a.eigen_symmetric().unwrap();
        assert!((vals[0] - 3.0).abs() < 1e-10);
        assert!((vals[1] - 1.0).abs() < 1e-10);
        assert_eigen_pairs(&a, &vals, &vecs, 1e-10);
    }
    #[test]
    fn test_eigen_symmetric_3x3() {
        let a = Matrix::from_rows(&[&[4.0, 1.0, 2.0], &[1.0, 3.0, 0.0], &[2.0, 0.0, 5.0]]);
        let (vals, vecs) = a.eigen_symmetric().unwrap();
        assert!(vals.windows(2).all(|w| w[0] >= w[1]), "not descending: {vals:?}");
        // The eigenvalues sum to the trace.
        assert!((vals.iter().sum::<f64>() - 12.0).abs() < 1e-10);
        assert_eigen_pairs(&a, &vals, &vecs, 1e-9);
    }
    #[test]
    fn test_eigen_diagonal_sorted_descending() {
        let a = Matrix::from_rows(&[&[1.0, 0.0, 0.0], &[0.0, 3.0, 0.0], &[0.0, 0.0, 2.0]]);
        let (vals, vecs) = a.eigen_symmetric().unwrap();
        assert_eq!(vals, vec![3.0, 2.0, 1.0]);
        assert_eq!(vecs.get(1, 0), 1.0);
        assert_eq!(vecs.get(2, 1), 1.0);
        assert_eq!(vecs.get(0, 2), 1.0);
    }
    #[test]
    fn test_eigen_not_square() {
        assert!(matches!(
            Matrix::zeros(2, 3).eigen_symmetric(),
            Err(MatrixError::NotSquare { rows: 2, cols: 3 })
        ));
    }
    #[test]
    fn test_eigen_not_symmetric() {
        let a = Matrix::from_rows(&[&[1.0, 2.0], &[3.0, 4.0]]);
        assert!(matches!(a.eigen_symmetric(), Err(MatrixError::NotSymmetric)));
    }
    #[test]
    fn test_display_matrix() {
        let m = Matrix::from_rows(&[&[1.0, -2.5], &[3.0, 4.0]]);
        assert_eq!(m.to_string(), "[    1.0000,    -2.5000]\n[    3.0000,     4.0000]\n");
    }
    #[test]
    fn test_error_display() {
        assert_eq!(MatrixError::Singular.to_string(), "matrix is singular");
        assert_eq!(
            MatrixError::NotSquare { rows: 2, cols: 3 }.to_string(),
            "matrix must be square, got 2×3"
        );
    }
}
#[cfg(test)]
mod proptests {
    //! Property-based checks of algebraic identities on random matrices.
    use super::*;
    use proptest::prelude::*;

    /// Row-major `(row, col)` index pairs of an `n × n` matrix.
    fn cells(n: usize) -> impl Iterator<Item = (usize, usize)> {
        (0..n).flat_map(move |i| (0..n).map(move |j| (i, j)))
    }

    /// Strategy: `n × n` matrices with independent entries in `[-10, 10)`.
    fn square_matrix(n: usize) -> impl Strategy<Value = Matrix> {
        let entries = proptest::collection::vec(-10.0_f64..10.0, n * n);
        entries.prop_map(move |values| Matrix::new(n, n, values).expect("valid dimensions"))
    }

    /// Strategy: symmetric positive-definite matrices `BᵀB + n·I`, where `B` has
    /// entries in `[-5, 5)`. `BᵀB` is positive semi-definite, so the `n·I`
    /// shift makes every eigenvalue at least `n`.
    fn spd_matrix(n: usize) -> impl Strategy<Value = Matrix> {
        let entries = proptest::collection::vec(-5.0_f64..5.0, n * n);
        entries.prop_map(move |values| {
            let b = Matrix::new(n, n, values).expect("valid dimensions");
            let gram = b.transpose().mul_mat(&b).expect("compatible");
            let shift = Matrix::identity(n).scale(n as f64);
            gram.add(&shift).expect("compatible")
        })
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(200))]

        #[test]
        fn transpose_involution(m in square_matrix(3)) {
            let round_trip = m.transpose().transpose();
            for (i, j) in cells(3) {
                prop_assert!((m.get(i, j) - round_trip.get(i, j)).abs() < 1e-14);
            }
        }

        #[test]
        fn mul_identity_is_identity(m in square_matrix(3)) {
            let eye = Matrix::identity(3);
            let right = m.mul_mat(&eye).unwrap();
            let left = eye.mul_mat(&m).unwrap();
            for (i, j) in cells(3) {
                let want = m.get(i, j);
                prop_assert!((right.get(i, j) - want).abs() < 1e-10);
                prop_assert!((left.get(i, j) - want).abs() < 1e-10);
            }
        }

        #[test]
        fn det_of_product(a in square_matrix(3), b in square_matrix(3)) {
            let expected = a.determinant().unwrap() * b.determinant().unwrap();
            let det_ab = a.mul_mat(&b).unwrap().determinant().unwrap();
            // Relative tolerance, floored at 1 so near-zero determinants compare absolutely.
            let tol = 1e-6 * expected.abs().max(det_ab.abs()).max(1.0);
            prop_assert!(
                (det_ab - expected).abs() < tol,
                "det(AB)={det_ab}, det(A)*det(B)={expected}"
            );
        }

        #[test]
        fn cholesky_roundtrip(a in spd_matrix(3)) {
            let l = a.cholesky().expect("SPD should decompose");
            let llt = l.mul_mat(&l.transpose()).expect("compatible");
            for (i, j) in cells(3) {
                let (got, want) = (llt.get(i, j), a.get(i, j));
                let tol = 1e-8 * want.abs().max(1.0);
                prop_assert!(
                    (got - want).abs() < tol,
                    "LLᵀ[{i},{j}]={}, A[{i},{j}]={}",
                    got, want
                );
            }
        }

        #[test]
        fn cholesky_solve_roundtrip(a in spd_matrix(3), b in proptest::collection::vec(-10.0_f64..10.0, 3)) {
            let x = a.cholesky_solve(&b).expect("SPD solve should work");
            let ax = a.mul_vec(&x).expect("compatible");
            for (i, (&lhs, &rhs)) in ax.iter().zip(&b).enumerate() {
                let tol = 1e-8 * rhs.abs().max(1.0);
                prop_assert!(
                    (lhs - rhs).abs() < tol,
                    "Ax[{i}]={}, b[{i}]={}",
                    lhs, rhs
                );
            }
        }

        #[test]
        fn inverse_roundtrip(a in spd_matrix(3)) {
            let inv = a.inverse().expect("SPD invertible");
            let product = a.mul_mat(&inv).expect("compatible");
            for (i, j) in cells(3) {
                let expected = if i == j { 1.0 } else { 0.0 };
                let entry = product.get(i, j);
                prop_assert!(
                    (entry - expected).abs() < 1e-6,
                    "A·A⁻¹[{i},{j}]={}, expected {expected}",
                    entry
                );
            }
        }
    }
}