use super::SpecializedMatrix;
use crate::error::{LinalgError, LinalgResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign, One, Zero};
use std::fmt::Debug;
use std::iter::Sum;
#[derive(Debug, Clone)]
pub struct TridiagonalMatrix<A>
where
A: Float + NumAssign + Zero + Sum + One + ScalarOperand + Send + Sync + Debug,
{
diag: Array1<A>,
superdiag: Array1<A>,
subdiag: Array1<A>,
n: usize,
}
impl<A> TridiagonalMatrix<A>
where
    A: Float + NumAssign + Zero + Sum + One + ScalarOperand + Send + Sync + Debug,
{
    /// Builds an `n x n` tridiagonal matrix from its three bands.
    ///
    /// # Arguments
    ///
    /// * `diag` - main diagonal, length `n` (must be non-empty)
    /// * `superdiag` - entries `(i, i + 1)`, length `n - 1`
    /// * `subdiag` - entries `(i + 1, i)`, length `n - 1`
    ///
    /// # Errors
    ///
    /// Returns [`LinalgError::ShapeError`] if `diag` is empty, or if either
    /// off-diagonal does not have exactly `diag.len() - 1` elements.
    pub fn new(
        diag: ArrayView1<A>,
        superdiag: ArrayView1<A>,
        subdiag: ArrayView1<A>,
    ) -> LinalgResult<Self> {
        let n = diag.len();
        // Guard before computing `n - 1`, which would underflow (and panic in
        // debug builds) for an empty diagonal.
        if n == 0 {
            return Err(LinalgError::ShapeError(
                "Main diagonal of a tridiagonal matrix must not be empty".to_string(),
            ));
        }
        if superdiag.len() != n - 1 || subdiag.len() != n - 1 {
            return Err(LinalgError::ShapeError(format!(
                "Diagonal lengths are incompatible. Main diagonal: {}, superdiagonal: {}, subdiagonal: {}",
                n, superdiag.len(), subdiag.len()
            )));
        }
        Ok(Self {
            diag: diag.to_owned(),
            superdiag: superdiag.to_owned(),
            subdiag: subdiag.to_owned(),
            n,
        })
    }

    /// Extracts the three central bands of a square dense matrix.
    ///
    /// Only entries `(i, i)`, `(i, i + 1)` and `(i + 1, i)` are read; any other
    /// entries of `a` are ignored (they are not checked to be zero).
    ///
    /// # Errors
    ///
    /// Returns [`LinalgError::ShapeError`] if `a` is not square or is empty.
    pub fn frommatrix(a: &ArrayView2<A>) -> LinalgResult<Self> {
        if a.nrows() != a.ncols() {
            return Err(LinalgError::ShapeError(format!(
                "Matrix must be square to convert to tridiagonal, got shape {:?}",
                a.shape()
            )));
        }
        let n = a.nrows();
        if n == 0 {
            return Err(LinalgError::ShapeError(
                "Cannot convert an empty matrix to tridiagonal form".to_string(),
            ));
        }
        // For n == 1 the off-diagonal bands are simply empty.
        let diag = Array1::from_shape_fn(n, |i| a[[i, i]]);
        let superdiag = Array1::from_shape_fn(n - 1, |i| a[[i, i + 1]]);
        let subdiag = Array1::from_shape_fn(n - 1, |i| a[[i + 1, i]]);
        Ok(Self {
            diag,
            superdiag,
            subdiag,
            n,
        })
    }

    /// Solves `A x = b` with the Thomas algorithm (tridiagonal Gaussian
    /// elimination without pivoting) in `O(n)` time.
    ///
    /// Because no row interchanges are performed, a nonsingular matrix whose
    /// elimination meets a (near-)zero pivot, such as `[[0, 1], [1, 0]]`, is
    /// also reported as singular. The pivot test is absolute (`|pivot| <
    /// A::epsilon()`), not scaled by the magnitude of the matrix entries.
    ///
    /// # Errors
    ///
    /// * [`LinalgError::ShapeError`] if `b.len()` differs from the matrix dimension.
    /// * [`LinalgError::SingularMatrixError`] if any pivot, including the first
    ///   diagonal entry, falls below the threshold above.
    pub fn solve(&self, b: &ArrayView1<A>) -> LinalgResult<Array1<A>> {
        let n = self.n;
        if b.len() != n {
            return Err(LinalgError::ShapeError(format!(
                "Right-hand side length {} does not match matrix dimension {}",
                b.len(),
                self.n
            )));
        }
        if n == 0 {
            return Ok(Array1::zeros(0));
        }

        // Forward sweep: eliminate the subdiagonal. Afterwards row i reads
        // x[i] + c_prime[i] * x[i + 1] = d_prime[i] (no c_prime term on the last row).
        let mut c_prime: Array1<A> = Array1::zeros(n - 1);
        let mut d_prime: Array1<A> = Array1::zeros(n);
        for i in 0..n {
            let (pivot, rhs) = if i == 0 {
                (self.diag[0], b[0])
            } else {
                let lower = self.subdiag[i - 1];
                (
                    self.diag[i] - lower * c_prime[i - 1],
                    b[i] - lower * d_prime[i - 1],
                )
            };
            // Every pivot, the first one included, is checked so that a zero
            // leading diagonal entry cannot silently produce inf/NaN.
            if pivot.abs() < A::epsilon() {
                return Err(LinalgError::SingularMatrixError(
                    "Tridiagonal matrix is singular".to_string(),
                ));
            }
            if i + 1 < n {
                c_prime[i] = self.superdiag[i] / pivot;
            }
            d_prime[i] = rhs / pivot;
        }

        // Back substitution, overwriting d_prime with the solution in place.
        // The last entry is already final.
        for i in (0..n - 1).rev() {
            let next = d_prime[i + 1];
            d_prime[i] -= c_prime[i] * next;
        }
        Ok(d_prime)
    }
}
impl<A> SpecializedMatrix<A> for TridiagonalMatrix<A>
where
    A: Float + NumAssign + Zero + Sum + One + ScalarOperand + Send + Sync + Debug,
{
    fn nrows(&self) -> usize {
        self.n
    }

    fn ncols(&self) -> usize {
        self.n
    }

    /// Returns entry `(i, j)`; positions outside the three bands are zero.
    ///
    /// # Errors
    ///
    /// Returns [`LinalgError::IndexError`] if `i` or `j` is not below the dimension.
    fn get(&self, i: usize, j: usize) -> LinalgResult<A> {
        if i >= self.n || j >= self.n {
            return Err(LinalgError::IndexError(format!(
                "Index ({}, {}) out of bounds for matrix of size {}",
                i, j, self.n
            )));
        }
        if i == j {
            return Ok(self.diag[i]);
        }
        if i + 1 == j {
            return Ok(self.superdiag[i]);
        }
        if i == j + 1 {
            return Ok(self.subdiag[j]);
        }
        Ok(A::zero())
    }

    /// Computes `A x` in `O(n)` without materialising the dense matrix.
    ///
    /// # Errors
    ///
    /// Returns [`LinalgError::ShapeError`] if `x.len()` differs from the dimension.
    fn matvec(&self, x: &ArrayView1<A>) -> LinalgResult<Array1<A>> {
        if x.len() != self.n {
            return Err(LinalgError::ShapeError(format!(
                "Vector length {} does not match matrix dimension {}",
                x.len(),
                self.n
            )));
        }
        Ok(tridiagonal_matvec(&self.subdiag, &self.diag, &self.superdiag, x))
    }

    /// Computes `Aᵀ x` in `O(n)` without materialising the dense matrix.
    ///
    /// # Errors
    ///
    /// Returns [`LinalgError::ShapeError`] if `x.len()` differs from the dimension.
    fn matvec_transpose(&self, x: &ArrayView1<A>) -> LinalgResult<Array1<A>> {
        if x.len() != self.n {
            return Err(LinalgError::ShapeError(format!(
                "Vector length {} does not match matrix dimension {}",
                x.len(),
                self.n
            )));
        }
        // Transposing a tridiagonal matrix swaps its sub- and superdiagonal.
        Ok(tridiagonal_matvec(&self.superdiag, &self.diag, &self.subdiag, x))
    }

    /// Expands the matrix into a dense `n x n` array (zeros off the bands).
    fn to_dense(&self) -> LinalgResult<Array2<A>> {
        let mut a = Array2::zeros((self.n, self.n));
        // Iterating the stored bands (rather than `0..n - 1`) avoids an
        // underflow when n == 0.
        for (i, &d) in self.diag.iter().enumerate() {
            a[[i, i]] = d;
        }
        for (i, &u) in self.superdiag.iter().enumerate() {
            a[[i, i + 1]] = u;
        }
        for (i, &l) in self.subdiag.iter().enumerate() {
            a[[i + 1, i]] = l;
        }
        Ok(a)
    }
}

/// Multiplies the tridiagonal matrix with bands `lower` / `diag` / `upper` by `x`.
///
/// Row `i` of the result is
/// `lower[i - 1] * x[i - 1] + diag[i] * x[i] + upper[i] * x[i + 1]`,
/// with the terms that fall outside the matrix omitted. Callers must ensure
/// `x.len() == diag.len()` and that both off-diagonals hold `diag.len() - 1` entries.
fn tridiagonal_matvec<A: Float>(
    lower: &Array1<A>,
    diag: &Array1<A>,
    upper: &Array1<A>,
    x: &ArrayView1<A>,
) -> Array1<A> {
    let n = diag.len();
    Array1::from_shape_fn(n, |i| {
        let mut acc = diag[i] * x[i];
        if i > 0 {
            acc = lower[i - 1] * x[i - 1] + acc;
        }
        if i + 1 < n {
            acc = acc + upper[i] * x[i + 1];
        }
        acc
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// 4x4 fixture: diag [1, 2, 3, 4], superdiag [5, 6, 7], subdiag [8, 9, 10].
    fn sample_4x4() -> TridiagonalMatrix<f64> {
        let diag = array![1.0, 2.0, 3.0, 4.0];
        let superdiag = array![5.0, 6.0, 7.0];
        let subdiag = array![8.0, 9.0, 10.0];
        TridiagonalMatrix::new(diag.view(), superdiag.view(), subdiag.view())
            .expect("Operation failed")
    }

    /// 1x1 matrix holding `value`.
    fn single(value: f64) -> TridiagonalMatrix<f64> {
        let diag = array![value];
        let empty = Array1::<f64>::zeros(0);
        TridiagonalMatrix::new(diag.view(), empty.view(), empty.view())
            .expect("Operation failed")
    }

    /// Asserts element-wise closeness, including equal lengths.
    fn assert_all_close(actual: &Array1<f64>, expected: &[f64], epsilon: f64) {
        assert_eq!(actual.len(), expected.len());
        for (a, e) in actual.iter().zip(expected.iter()) {
            assert_relative_eq!(*a, *e, epsilon = epsilon);
        }
    }

    #[test]
    fn test_tridiagonal_creation() {
        let tri = sample_4x4();
        assert_eq!(tri.nrows(), 4);
        assert_eq!(tri.ncols(), 4);
        let cases = [
            ((0, 0), 1.0),
            ((0, 1), 5.0),
            ((1, 0), 8.0),
            ((1, 1), 2.0),
            ((1, 2), 6.0),
            ((2, 1), 9.0),
            ((0, 2), 0.0),
            ((0, 3), 0.0),
            ((2, 0), 0.0),
        ];
        for &((i, j), expected) in cases.iter() {
            assert_relative_eq!(tri.get(i, j).expect("Operation failed"), expected);
        }
    }

    #[test]
    fn test_frommatrix() {
        let a = array![
            [1.0, 5.0, 0.0, 0.0],
            [8.0, 2.0, 6.0, 0.0],
            [0.0, 9.0, 3.0, 7.0],
            [0.0, 0.0, 10.0, 4.0]
        ];
        let tri = TridiagonalMatrix::frommatrix(&a.view()).expect("Operation failed");
        assert_eq!(tri.nrows(), 4);
        assert_eq!(tri.ncols(), 4);
        assert_all_close(&tri.diag, &[1.0, 2.0, 3.0, 4.0], f64::EPSILON);
        assert_all_close(&tri.superdiag, &[5.0, 6.0, 7.0], f64::EPSILON);
        assert_all_close(&tri.subdiag, &[8.0, 9.0, 10.0], f64::EPSILON);
    }

    #[test]
    fn test_matvec() {
        let tri = sample_4x4();
        let x = array![1.0, 2.0, 3.0, 4.0];
        let y = tri.matvec(&x.view()).expect("Operation failed");
        let expected = [
            1.0 * 1.0 + 5.0 * 2.0,
            8.0 * 1.0 + 2.0 * 2.0 + 6.0 * 3.0,
            9.0 * 2.0 + 3.0 * 3.0 + 7.0 * 4.0,
            10.0 * 3.0 + 4.0 * 4.0,
        ];
        assert_all_close(&y, &expected, 1e-10);
    }

    #[test]
    fn test_matvec_transpose() {
        let tri = sample_4x4();
        let x = array![1.0, 2.0, 3.0, 4.0];
        let y = tri.matvec_transpose(&x.view()).expect("Operation failed");
        let expected = [
            1.0 * 1.0 + 8.0 * 2.0,
            5.0 * 1.0 + 2.0 * 2.0 + 9.0 * 3.0,
            6.0 * 2.0 + 3.0 * 3.0 + 10.0 * 4.0,
            7.0 * 3.0 + 4.0 * 4.0,
        ];
        assert_all_close(&y, &expected, 1e-10);
    }

    #[test]
    fn test_to_dense() {
        let diag = array![1.0, 2.0, 3.0];
        let superdiag = array![4.0, 5.0];
        let subdiag = array![6.0, 7.0];
        let tri = TridiagonalMatrix::new(diag.view(), superdiag.view(), subdiag.view())
            .expect("Operation failed");
        let dense = tri.to_dense().expect("Operation failed");
        let expected = array![[1.0, 4.0, 0.0], [6.0, 2.0, 5.0], [0.0, 7.0, 3.0]];
        assert_eq!(dense.shape(), &[3, 3]);
        for (got, want) in dense.iter().zip(expected.iter()) {
            assert_relative_eq!(*got, *want, epsilon = 1e-10);
        }
    }

    #[test]
    fn test_solve() {
        let diag = array![2.0, 2.0, 2.0, 2.0];
        let superdiag = array![-1.0, -1.0, -1.0];
        let subdiag = array![-1.0, -1.0, -1.0];
        let tri = TridiagonalMatrix::new(diag.view(), superdiag.view(), subdiag.view())
            .expect("Operation failed");
        let b = array![1.0, 2.0, 3.0, 4.0];
        let x = tri.solve(&b.view()).expect("Operation failed");
        let ax = tri.matvec(&x.view()).expect("Operation failed");
        assert_all_close(&ax, b.as_slice().expect("contiguous"), 1e-10);
    }

    #[test]
    fn test_single_element() {
        let tri = single(4.0);
        assert_eq!(tri.nrows(), 1);
        let y = tri.matvec(&array![3.0].view()).expect("Operation failed");
        assert_all_close(&y, &[12.0], 1e-12);
        let yt = tri.matvec_transpose(&array![3.0].view()).expect("Operation failed");
        assert_all_close(&yt, &[12.0], 1e-12);
        let x = tri.solve(&array![8.0].view()).expect("Operation failed");
        assert_all_close(&x, &[2.0], 1e-12);
        let dense = tri.to_dense().expect("Operation failed");
        assert_eq!(dense.shape(), &[1, 1]);
        assert_relative_eq!(dense[[0, 0]], 4.0);

        let from = TridiagonalMatrix::frommatrix(&array![[5.0]].view()).expect("Operation failed");
        assert_eq!(from.nrows(), 1);
        assert_eq!(from.superdiag.len(), 0);
        assert_eq!(from.subdiag.len(), 0);
        assert_relative_eq!(from.diag[0], 5.0);
    }

    #[test]
    fn test_shape_and_index_errors() {
        let diag = array![1.0, 2.0, 3.0];
        let short = array![1.0];
        let ok = array![1.0, 2.0];
        assert!(TridiagonalMatrix::new(diag.view(), short.view(), ok.view()).is_err());
        assert!(TridiagonalMatrix::new(diag.view(), ok.view(), short.view()).is_err());

        let rect = Array2::<f64>::zeros((2, 3));
        assert!(TridiagonalMatrix::frommatrix(&rect.view()).is_err());

        let tri = sample_4x4();
        let wrong = array![1.0, 2.0];
        assert!(tri.matvec(&wrong.view()).is_err());
        assert!(tri.matvec_transpose(&wrong.view()).is_err());
        assert!(tri.solve(&wrong.view()).is_err());
        assert!(tri.get(4, 0).is_err());
        assert!(tri.get(0, 4).is_err());
    }

    #[test]
    fn test_solve_reports_zero_pivot() {
        // [[1, 1], [1, 1]] is singular: the second pivot is 1 - 1 * 1 = 0.
        let ones = array![1.0, 1.0];
        let off = array![1.0];
        let tri = TridiagonalMatrix::new(ones.view(), off.view(), off.view())
            .expect("Operation failed");
        let err = tri
            .solve(&array![1.0, 2.0].view())
            .expect_err("singular matrix must be rejected");
        assert!(matches!(err, LinalgError::SingularMatrixError(_)));
    }
}