use super::SpecializedMatrix;
use crate::error::{LinalgError, LinalgResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign, One, Zero};
use std::fmt::Debug;
use std::iter::Sum;
/// A square symmetric matrix that stores only its lower triangle.
///
/// The constructors in this module keep every entry strictly above the diagonal of
/// `data` at zero. Reads of an upper-triangle position are answered from the mirrored
/// lower-triangle entry.
#[derive(Debug, Clone)]
pub struct SymmetricMatrix<A>
where
    A: Float + NumAssign + Zero + Sum + One + ScalarOperand + Send + Sync + Debug,
{
    /// `n x n` storage: the lower triangle (diagonal included) holds the values and
    /// the strict upper triangle is zero.
    data: Array2<A>,
    /// Matrix dimension (number of rows, which equals the number of columns).
    n: usize,
}
impl<A> SymmetricMatrix<A>
where
    A: Float + NumAssign + Zero + Sum + One + ScalarOperand + Send + Sync + Debug,
{
    /// Builds a symmetric matrix from its lower-triangular part.
    ///
    /// The entries of `lower` on and below the diagonal define the matrix.
    ///
    /// # Errors
    ///
    /// * [`LinalgError::ShapeError`] if `lower` is not square.
    /// * [`LinalgError::InvalidInputError`] if any entry strictly above the diagonal
    ///   is non-zero. Such an entry would otherwise be silently ignored.
    pub fn new(lower: ArrayView2<A>) -> LinalgResult<Self> {
        let n = lower.nrows();
        if lower.ncols() != n {
            return Err(LinalgError::ShapeError(format!(
                "Lower triangular part must be square, got shape {:?}",
                lower.shape()
            )));
        }
        for i in 0..n {
            for j in i + 1..n {
                if lower[[i, j]] != A::zero() {
                    return Err(LinalgError::InvalidInputError(
                        "Lower triangular part must have zeros above the diagonal".to_string(),
                    ));
                }
            }
        }
        Ok(Self {
            data: lower.to_owned(),
            n,
        })
    }

    /// Builds a symmetric matrix from a full square matrix, keeping only its lower triangle.
    ///
    /// Mirrored entries are accepted as equal when
    /// `|a_ij - a_ji| <= epsilon * max(1, |a_ij|, |a_ji|)`. This relative tolerance
    /// keeps round-off in large-magnitude entries from causing spurious rejections. For
    /// entries of magnitude <= 1 it reduces to the absolute bound `epsilon`.
    ///
    /// Comparisons involving NaN never exceed the tolerance, so NaN entries are not
    /// rejected by this check.
    ///
    /// # Errors
    ///
    /// * [`LinalgError::ShapeError`] if `a` is not square.
    /// * [`LinalgError::InvalidInputError`] if some mirrored pair differs by more than
    ///   the tolerance.
    pub fn frommatrix(a: &ArrayView2<A>) -> LinalgResult<Self> {
        if a.nrows() != a.ncols() {
            return Err(LinalgError::ShapeError(format!(
                "Matrix must be square to be symmetric, got shape {:?}",
                a.shape()
            )));
        }
        let n = a.nrows();
        for i in 0..n {
            for j in i + 1..n {
                let a_ij = a[[i, j]];
                let a_ji = a[[j, i]];
                // An absolute epsilon (~2.2e-16 for f64) is far below the rounding error
                // of large entries, so scale it by the magnitude of the pair.
                let scale = A::one().max(a_ij.abs()).max(a_ji.abs());
                if (a_ij - a_ji).abs() > A::epsilon() * scale {
                    return Err(LinalgError::InvalidInputError(format!(
                        "Matrix is not symmetric, a[{i}, {j}] != a[{j}, {i}]"
                    )));
                }
            }
        }
        let mut lower = Array2::zeros((n, n));
        for i in 0..n {
            for j in 0..=i {
                lower[[i, j]] = a[[i, j]];
            }
        }
        Ok(Self { data: lower, n })
    }

    /// Computes the lower-triangular Cholesky factor `L` such that `A = L * L^T`.
    ///
    /// The factor is computed row by row. The result is a dense `n x n` array with
    /// zeros strictly above the diagonal.
    ///
    /// # Errors
    ///
    /// [`LinalgError::InvalidInputError`] ("Matrix is not positive definite") if any
    /// pivot is not strictly positive, including when it is NaN.
    pub fn cholesky(&self) -> LinalgResult<Array2<A>> {
        let n = self.n;
        let mut l = Array2::<A>::zeros((n, n));
        for i in 0..n {
            // Off-diagonal entries of row i; each uses the already-finished row j < i.
            for j in 0..i {
                let mut sum = A::zero();
                for k in 0..j {
                    sum += l[[i, k]] * l[[j, k]];
                }
                l[[i, j]] = (self.data[[i, j]] - sum) / l[[j, j]];
            }
            // Diagonal pivot: a_ii minus the squared norm of row i computed so far.
            let mut sum = A::zero();
            for k in 0..i {
                sum += l[[i, k]] * l[[i, k]];
            }
            let diag_val = self.data[[i, i]] - sum;
            // `<= 0` alone is false for NaN, which would silently propagate a NaN factor.
            if diag_val.is_nan() || diag_val <= A::zero() {
                return Err(LinalgError::InvalidInputError(
                    "Matrix is not positive definite".to_string(),
                ));
            }
            l[[i, i]] = diag_val.sqrt();
        }
        Ok(l)
    }

    /// Solves `A x = b` using the Cholesky factorization of `A`.
    ///
    /// The method factors `A = L * L^T`, solves `L y = b` by forward substitution, then
    /// solves `L^T x = y` by back substitution. The factorization is recomputed on
    /// every call.
    ///
    /// # Errors
    ///
    /// * [`LinalgError::ShapeError`] if `b.len()` differs from the matrix dimension.
    /// * Any error returned by [`Self::cholesky`], e.g. when `A` is not positive definite.
    pub fn solve(&self, b: &ArrayView1<A>) -> LinalgResult<Array1<A>> {
        let n = self.n;
        if b.len() != n {
            return Err(LinalgError::ShapeError(format!(
                "Right-hand side length {} does not match matrix dimension {}",
                b.len(),
                n
            )));
        }
        let l = self.cholesky()?;

        // Forward substitution: L y = b.
        let mut y = Array1::<A>::zeros(n);
        for i in 0..n {
            let mut sum = A::zero();
            for j in 0..i {
                sum += l[[i, j]] * y[j];
            }
            y[i] = (b[i] - sum) / l[[i, i]];
        }

        // Back substitution: L^T x = y, reading L^T[i, j] as l[[j, i]].
        let mut x = Array1::<A>::zeros(n);
        for i in (0..n).rev() {
            let mut sum = A::zero();
            for j in i + 1..n {
                sum += l[[j, i]] * x[j];
            }
            x[i] = (y[i] - sum) / l[[i, i]];
        }
        Ok(x)
    }
}
impl<A> SpecializedMatrix<A> for SymmetricMatrix<A>
where
    A: Float + NumAssign + Zero + Sum + One + ScalarOperand + Send + Sync + Debug,
{
    /// Number of rows, which equals the matrix dimension.
    fn nrows(&self) -> usize {
        self.n
    }

    /// Number of columns; always equal to the number of rows.
    fn ncols(&self) -> usize {
        self.n
    }

    /// Returns entry `(i, j)`. Positions above the diagonal are read from the mirrored
    /// lower-triangle entry.
    ///
    /// # Errors
    ///
    /// [`LinalgError::IndexError`] if `i` or `j` is not less than the dimension.
    fn get(&self, i: usize, j: usize) -> LinalgResult<A> {
        let in_bounds = i < self.n && j < self.n;
        if !in_bounds {
            return Err(LinalgError::IndexError(format!(
                "Index ({}, {}) out of bounds for matrix of size {}",
                i, j, self.n
            )));
        }
        // Only the lower triangle is stored, so fold (i, j) onto row >= col.
        let (row, col) = (i.max(j), i.min(j));
        Ok(self.data[[row, col]])
    }

    /// Computes `A x` while reading only the stored lower triangle.
    ///
    /// Each off-diagonal entry `a_rc` (with `c < r`) contributes to both `out[r]` and
    /// `out[c]`.
    ///
    /// # Errors
    ///
    /// [`LinalgError::ShapeError`] if `x.len()` differs from the dimension.
    fn matvec(&self, x: &ArrayView1<A>) -> LinalgResult<Array1<A>> {
        let n = self.n;
        if x.len() != n {
            return Err(LinalgError::ShapeError(format!(
                "Vector length {} does not match matrix dimension {}",
                x.len(),
                n
            )));
        }
        let mut out = Array1::<A>::zeros(n);
        for row in 0..n {
            let x_row = x[row];
            out[row] += self.data[[row, row]] * x_row;
            for col in 0..row {
                let coeff = self.data[[row, col]];
                out[row] += coeff * x[col];
                out[col] += coeff * x_row;
            }
        }
        Ok(out)
    }

    /// Computes `A^T x`, which equals `A x` because the matrix is symmetric.
    fn matvec_transpose(&self, x: &ArrayView1<A>) -> LinalgResult<Array1<A>> {
        self.matvec(x)
    }

    /// Expands the stored lower triangle into a full dense `n x n` array.
    fn to_dense(&self) -> LinalgResult<Array2<A>> {
        let lower = &self.data;
        let dense = Array2::from_shape_fn((self.n, self.n), |(r, c)| {
            if c > r {
                lower[[c, r]]
            } else {
                lower[[r, c]]
            }
        });
        Ok(dense)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// Symmetric 3x3 matrix used by the structural (non-factorization) tests.
    fn sample() -> Array2<f64> {
        array![[1.0, 2.0, 3.0], [2.0, 4.0, 5.0], [3.0, 5.0, 6.0]]
    }

    /// Symmetric positive-definite 3x3 matrix used by the Cholesky / solve tests.
    fn spd() -> Array2<f64> {
        array![[4.0, 2.0, 1.0], [2.0, 3.0, 0.5], [1.0, 0.5, 6.0]]
    }

    #[test]
    fn test_symmetric_creation() {
        let a = sample();
        let sym = SymmetricMatrix::frommatrix(&a.view()).expect("Operation failed");
        assert_eq!(sym.nrows(), 3);
        assert_eq!(sym.ncols(), 3);
        for i in 0..3 {
            for j in 0..3 {
                assert_relative_eq!(
                    sym.get(i, j).expect("Operation failed"),
                    a[[i, j]],
                    epsilon = 1e-10
                );
            }
        }
    }

    #[test]
    fn test_new_from_lower() {
        let lower = array![[1.0, 0.0, 0.0], [2.0, 4.0, 0.0], [3.0, 5.0, 6.0]];
        let sym = SymmetricMatrix::new(lower.view()).expect("Operation failed");
        let dense = sym.to_dense().expect("Operation failed");
        let a = sample();
        for i in 0..3 {
            for j in 0..3 {
                assert_relative_eq!(dense[[i, j]], a[[i, j]], epsilon = 1e-10);
            }
        }
    }

    #[test]
    fn test_new_rejects_nonzero_upper() {
        let not_lower = array![[1.0, 9.0], [2.0, 4.0]];
        assert!(SymmetricMatrix::new(not_lower.view()).is_err());
    }

    #[test]
    fn test_non_square_error() {
        let a = array![[1.0, 2.0, 3.0], [2.0, 4.0, 5.0]];
        assert!(SymmetricMatrix::frommatrix(&a.view()).is_err());
        assert!(SymmetricMatrix::new(a.view()).is_err());
    }

    #[test]
    fn test_non_symmetric_error() {
        let a = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
        let result = SymmetricMatrix::frommatrix(&a.view());
        assert!(result.is_err());
    }

    #[test]
    fn test_get_out_of_bounds() {
        let a = sample();
        let sym = SymmetricMatrix::frommatrix(&a.view()).expect("Operation failed");
        assert!(sym.get(3, 0).is_err());
        assert!(sym.get(0, 3).is_err());
    }

    #[test]
    fn test_matvec() {
        let a = sample();
        let sym = SymmetricMatrix::frommatrix(&a.view()).expect("Operation failed");
        let x = array![1.0, 2.0, 3.0];
        let y = sym.matvec(&x.view()).expect("Operation failed");
        let expected = array![
            1.0 * 1.0 + 2.0 * 2.0 + 3.0 * 3.0,
            2.0 * 1.0 + 4.0 * 2.0 + 5.0 * 3.0,
            3.0 * 1.0 + 5.0 * 2.0 + 6.0 * 3.0
        ];
        assert_eq!(y.len(), 3);
        for i in 0..3 {
            assert_relative_eq!(y[i], expected[i], epsilon = 1e-10);
        }
    }

    #[test]
    fn test_matvec_dimension_mismatch() {
        let a = sample();
        let sym = SymmetricMatrix::frommatrix(&a.view()).expect("Operation failed");
        let x = array![1.0, 2.0];
        assert!(sym.matvec(&x.view()).is_err());
    }

    #[test]
    fn test_matvec_transpose() {
        let a = sample();
        let sym = SymmetricMatrix::frommatrix(&a.view()).expect("Operation failed");
        let x = array![1.0, 2.0, 3.0];
        let y1 = sym.matvec(&x.view()).expect("Operation failed");
        let y2 = sym.matvec_transpose(&x.view()).expect("Operation failed");
        assert_eq!(y1.len(), 3);
        assert_eq!(y2.len(), 3);
        for i in 0..3 {
            assert_relative_eq!(y1[i], y2[i], epsilon = 1e-10);
        }
    }

    #[test]
    fn test_to_dense() {
        let a = sample();
        let sym = SymmetricMatrix::frommatrix(&a.view()).expect("Operation failed");
        let dense = sym.to_dense().expect("Operation failed");
        assert_eq!(dense.shape(), &[3, 3]);
        for i in 0..3 {
            for j in 0..3 {
                assert_relative_eq!(dense[[i, j]], a[[i, j]], epsilon = 1e-10);
            }
        }
    }

    #[test]
    fn test_cholesky() {
        let a = spd();
        let sym = SymmetricMatrix::frommatrix(&a.view()).expect("Operation failed");
        let l = sym.cholesky().expect("Operation failed");
        for i in 0..3 {
            // The factor must be lower triangular.
            for j in i + 1..3 {
                assert_eq!(l[[i, j]], 0.0);
            }
            // And L * L^T must reproduce A.
            for j in 0..3 {
                let reconstructed: f64 = (0..=i.min(j)).map(|k| l[[i, k]] * l[[j, k]]).sum();
                assert_relative_eq!(reconstructed, a[[i, j]], epsilon = 1e-10);
            }
        }
    }

    #[test]
    fn test_cholesky_not_positive_definite() {
        // Eigenvalues are 3 and -1, so the second pivot is 1 - 2^2 = -3.
        let a = array![[1.0, 2.0], [2.0, 1.0]];
        let sym = SymmetricMatrix::frommatrix(&a.view()).expect("Operation failed");
        assert!(sym.cholesky().is_err());
    }

    #[test]
    fn test_solve() {
        let a = spd();
        let sym = SymmetricMatrix::frommatrix(&a.view()).expect("Operation failed");
        let b = array![1.0, 2.0, 3.0];
        let x = sym.solve(&b.view()).expect("Operation failed");
        let ax = sym.matvec(&x.view()).expect("Operation failed");
        assert_eq!(ax.len(), 3);
        for i in 0..3 {
            assert_relative_eq!(ax[i], b[i], epsilon = 1e-10);
        }
    }

    #[test]
    fn test_solve_dimension_mismatch() {
        let a = spd();
        let sym = SymmetricMatrix::frommatrix(&a.view()).expect("Operation failed");
        let b = array![1.0, 2.0];
        assert!(sym.solve(&b.view()).is_err());
    }
}