use super::SpecializedMatrix;
use crate::error::{LinalgError, LinalgResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2, ScalarOperand};
use scirs2_core::numeric::{Float, NumAssign, One, Zero};
use std::fmt::Debug;
use std::iter::Sum;
/// A matrix stored in compact banded form.
///
/// Only the `lower_bandwidth` sub-diagonals, the main diagonal and the
/// `upper_bandwidth` super-diagonals are stored. Diagonal `d = j - i` lives in
/// row `d + lower_bandwidth` of `data`, and element `(i, j)` on that diagonal is
/// stored at column `min(i, j)`, so every diagonal is left-aligned.
#[derive(Debug, Clone)]
pub struct BandedMatrix<
    A: Float + NumAssign + Zero + Sum + One + ScalarOperand + Send + Sync + Debug,
> {
    /// Band storage: `lower_bandwidth + upper_bandwidth + 1` rows by
    /// `min(nrows, ncols)` columns.
    data: Array2<A>,
    /// Number of stored diagonals below the main diagonal.
    lower_bandwidth: usize,
    /// Number of stored diagonals above the main diagonal.
    upper_bandwidth: usize,
    /// Row count of the represented matrix.
    nrows: usize,
    /// Column count of the represented matrix.
    ncols: usize,
}
impl<A> BandedMatrix<A>
where
A: Float + NumAssign + Zero + Sum + One + ScalarOperand + Send + Sync + Debug,
{
pub fn new(
data: ArrayView2<A>,
lower_bandwidth: usize,
upper_bandwidth: usize,
nrows: usize,
ncols: usize,
) -> LinalgResult<Self> {
let expected_rows = lower_bandwidth + upper_bandwidth + 1;
if data.nrows() != expected_rows {
return Err(LinalgError::ShapeError(format!(
"Data should have {expected_rows} rows for a matrix with lower _bandwidth {lower_bandwidth} and upper _bandwidth {upper_bandwidth}"
)));
}
let max_diag_len = std::cmp::min(nrows, ncols);
if data.ncols() != max_diag_len {
return Err(LinalgError::ShapeError(format!(
"Data should have {max_diag_len} columns for a matrix with dimensions {nrows}x{ncols}"
)));
}
Ok(Self {
data: data.to_owned(),
lower_bandwidth,
upper_bandwidth,
nrows,
ncols,
})
}
pub fn frommatrix(
a: &ArrayView2<A>,
lower_bandwidth: usize,
upper_bandwidth: usize,
) -> LinalgResult<Self> {
let nrows = a.nrows();
let ncols = a.ncols();
let max_diag_len = std::cmp::min(nrows, ncols);
let mut data = Array2::zeros((lower_bandwidth + upper_bandwidth + 1, max_diag_len));
for i in 0..nrows {
for j in 0..ncols {
if j < i + upper_bandwidth + 1 && i < j + lower_bandwidth + 1 {
let diag_index = (j as isize - i as isize + lower_bandwidth as isize) as usize;
let diag_pos = if j >= i {
i
} else {
j
};
data[[diag_index, diag_pos]] = a[[i, j]];
}
}
}
Ok(Self {
data,
lower_bandwidth,
upper_bandwidth,
nrows,
ncols,
})
}
pub fn bandwidth(&self) -> usize {
self.lower_bandwidth + self.upper_bandwidth + 1
}
pub fn lower_bandwidth(&self) -> usize {
self.lower_bandwidth
}
pub fn upper_bandwidth(&self) -> usize {
self.upper_bandwidth
}
pub fn solve(&self, b: &ArrayView1<A>) -> LinalgResult<Array1<A>> {
if b.len() != self.nrows {
return Err(LinalgError::ShapeError(format!(
"Right-hand side length {} does not match matrix dimension {}",
b.len(),
self.nrows
)));
}
if self.nrows != self.ncols {
return Err(LinalgError::ShapeError(
"Matrix must be square to solve system".to_string(),
));
}
if self.nrows == 1 {
let a = self.get(0, 0)?;
if a.abs() < A::epsilon() {
return Err(LinalgError::SingularMatrixError(
"Matrix is singular".to_string(),
));
}
let mut x = Array1::zeros(1);
x[0] = b[0] / a;
return Ok(x);
}
if self.lower_bandwidth == 1 && self.upper_bandwidth == 1 {
return self.solve_tridiagonal(b);
}
let mut x = Array1::zeros(self.nrows);
let a_dense = self.to_dense()?;
let mut augmented = Array2::zeros((self.nrows, self.ncols + 1));
for i in 0..self.nrows {
for j in 0..self.ncols {
augmented[[i, j]] = a_dense[[i, j]];
}
augmented[[i, self.ncols]] = b[i];
}
for i in 0..self.nrows - 1 {
let mut max_row = i;
let mut max_val = augmented[[i, i]].abs();
for j in i + 1..self.nrows {
let val = augmented[[j, i]].abs();
if val > max_val {
max_val = val;
max_row = j;
}
}
if max_val < A::epsilon() {
return Err(LinalgError::SingularMatrixError(
"Matrix is singular during Gaussian elimination".to_string(),
));
}
if max_row != i {
for j in i..=self.ncols {
let temp = augmented[[i, j]];
augmented[[i, j]] = augmented[[max_row, j]];
augmented[[max_row, j]] = temp;
}
}
for j in i + 1..self.nrows {
let factor = augmented[[j, i]] / augmented[[i, i]];
for k in i..=self.ncols {
let value_i_k = augmented[[i, k]];
augmented[[j, k]] -= factor * value_i_k;
}
}
}
for i in (0..self.nrows).rev() {
let mut sum = A::zero();
for j in i + 1..self.ncols {
sum += augmented[[i, j]] * x[j];
}
x[i] = (augmented[[i, self.ncols]] - sum) / augmented[[i, i]];
}
Ok(x)
}
fn solve_tridiagonal(&self, b: &ArrayView1<A>) -> LinalgResult<Array1<A>> {
let n = self.nrows;
if n != b.len() {
return Err(LinalgError::ShapeError(format!(
"Matrix rows ({}) must match vector length ({})",
n,
b.len()
)));
}
if self.lower_bandwidth != 1 || self.upper_bandwidth != 1 {
return Err(LinalgError::ShapeError(
"solve_tridiagonal requires a matrix with lower_bandwidth=1 and upper_bandwidth=1"
.to_string(),
));
}
let mut lower = Array1::zeros(n - 1); let mut diag = Array1::zeros(n); let mut upper = Array1::zeros(n - 1);
for i in 0..n {
diag[i] = self.get(i, i)?;
if i > 0 {
lower[i - 1] = self.get(i, i - 1)?;
}
if i < n - 1 {
upper[i] = self.get(i, i + 1)?;
}
}
let mut c_prime = Array1::zeros(n - 1);
let mut d_prime = Array1::zeros(n);
let mut x = Array1::zeros(n);
if diag[0].abs() < A::epsilon() {
return Err(LinalgError::SingularMatrixError(
"Matrix is singular during tridiagonal solve: zero on main diagonal".to_string(),
));
}
d_prime[0] = b[0] / diag[0];
c_prime[0] = upper[0] / diag[0];
for i in 1..n - 1 {
let m = lower[i - 1] / diag[i - 1];
let new_diag = diag[i] - m * upper[i - 1];
if new_diag.abs() < A::epsilon() {
return Err(LinalgError::SingularMatrixError(
"Matrix is singular during tridiagonal solve: zero pivot encountered"
.to_string(),
));
}
d_prime[i] = (b[i] - lower[i - 1] * d_prime[i - 1]) / new_diag;
c_prime[i] = upper[i] / new_diag;
diag[i] = new_diag;
}
if n > 1 {
let m = lower[n - 2] / diag[n - 2];
let new_diag = diag[n - 1] - m * upper[n - 2];
if new_diag.abs() < A::epsilon() {
return Err(LinalgError::SingularMatrixError(
"Matrix is singular during tridiagonal solve: zero pivot in last row"
.to_string(),
));
}
d_prime[n - 1] = (b[n - 1] - lower[n - 2] * d_prime[n - 2]) / new_diag;
}
x[n - 1] = d_prime[n - 1];
for i in (0..n - 1).rev() {
x[i] = d_prime[i] - c_prime[i] * x[i + 1];
}
Ok(x)
}
}
impl<A> SpecializedMatrix<A> for BandedMatrix<A>
where
    A: Float + NumAssign + Zero + Sum + One + ScalarOperand + Send + Sync + Debug,
{
    /// Row count of the represented matrix.
    fn nrows(&self) -> usize {
        self.nrows
    }

    /// Column count of the represented matrix.
    fn ncols(&self) -> usize {
        self.ncols
    }

    /// Returns element `(i, j)`; positions outside the band are structural zeros.
    ///
    /// # Errors
    ///
    /// `LinalgError::IndexError` if `i >= nrows` or `j >= ncols`.
    fn get(&self, i: usize, j: usize) -> LinalgResult<A> {
        let inside_matrix = i < self.nrows && j < self.ncols;
        if !inside_matrix {
            return Err(LinalgError::IndexError(format!(
                "Index ({}, {}) out of bounds for matrix of size {}x{}",
                i, j, self.nrows, self.ncols
            )));
        }
        let inside_band = j <= i + self.upper_bandwidth && i <= j + self.lower_bandwidth;
        Ok(if inside_band {
            // Diagonal `j - i` sits in storage row `j - i + lower_bandwidth`;
            // `i <= j + lower_bandwidth` keeps the unsigned arithmetic from underflowing.
            self.data[[j + self.lower_bandwidth - i, i.min(j)]]
        } else {
            A::zero()
        })
    }

    /// Computes `A x` while visiting only the stored band of each row.
    ///
    /// # Errors
    ///
    /// `LinalgError::ShapeError` if `x.len() != ncols`.
    fn matvec(&self, x: &ArrayView1<A>) -> LinalgResult<Array1<A>> {
        if x.len() != self.ncols {
            return Err(LinalgError::ShapeError(format!(
                "Vector length {} does not match matrix column count {}",
                x.len(),
                self.ncols
            )));
        }
        let kl = self.lower_bandwidth;
        let ku = self.upper_bandwidth;
        let product: Array1<A> = (0..self.nrows)
            .map(|row| {
                let first_col = row.saturating_sub(kl);
                let past_last_col = (row + ku + 1).min(self.ncols);
                // Left-to-right accumulation starting from zero.
                (first_col..past_last_col).fold(A::zero(), |acc, col| {
                    acc + self.data[[col + kl - row, row.min(col)]] * x[col]
                })
            })
            .collect();
        Ok(product)
    }

    /// Computes `Aᵀ x` while visiting only the stored band of each column.
    ///
    /// # Errors
    ///
    /// `LinalgError::ShapeError` if `x.len() != nrows`.
    fn matvec_transpose(&self, x: &ArrayView1<A>) -> LinalgResult<Array1<A>> {
        if x.len() != self.nrows {
            return Err(LinalgError::ShapeError(format!(
                "Vector length {} does not match matrix row count {}",
                x.len(),
                self.nrows
            )));
        }
        let kl = self.lower_bandwidth;
        let ku = self.upper_bandwidth;
        let product: Array1<A> = (0..self.ncols)
            .map(|col| {
                let first_row = col.saturating_sub(ku);
                let past_last_row = (col + kl + 1).min(self.nrows);
                (first_row..past_last_row).fold(A::zero(), |acc, row| {
                    acc + self.data[[col + kl - row, row.min(col)]] * x[row]
                })
            })
            .collect();
        Ok(product)
    }

    /// Expands the band storage into a dense `nrows x ncols` matrix.
    fn to_dense(&self) -> LinalgResult<Array2<A>> {
        let kl = self.lower_bandwidth;
        let mut dense = Array2::zeros((self.nrows, self.ncols));
        for row in 0..self.nrows {
            let past_last_col = (row + self.upper_bandwidth + 1).min(self.ncols);
            for col in row.saturating_sub(kl)..past_last_col {
                dense[[row, col]] = self.data[[col + kl - row, row.min(col)]];
            }
        }
        Ok(dense)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;
    #[test]
    fn test_banded_creation() {
        let mut data = Array2::zeros((4, 5));
        data[[0, 0]] = 1.0;
        data[[0, 1]] = 2.0;
        data[[0, 2]] = 3.0;
        data[[0, 3]] = 4.0;
        data[[1, 0]] = 5.0;
        data[[1, 1]] = 6.0;
        data[[1, 2]] = 7.0;
        data[[1, 3]] = 8.0;
        data[[1, 4]] = 9.0;
        data[[2, 0]] = 10.0;
        data[[2, 1]] = 11.0;
        data[[2, 2]] = 12.0;
        data[[2, 3]] = 13.0;
        data[[3, 0]] = 14.0;
        data[[3, 1]] = 15.0;
        data[[3, 2]] = 16.0;
        let band = BandedMatrix::new(data.view(), 1, 2, 5, 5).expect("Operation failed");
        assert_eq!(band.nrows(), 5);
        assert_eq!(band.ncols(), 5);
        assert_eq!(band.bandwidth(), 4);
        assert_eq!(band.lower_bandwidth(), 1);
        assert_eq!(band.upper_bandwidth(), 2);
        assert_relative_eq!(band.get(0, 0).expect("Operation failed"), 5.0);
        assert_relative_eq!(band.get(0, 1).expect("Operation failed"), 10.0);
        assert_relative_eq!(band.get(0, 2).expect("Operation failed"), 14.0);
        assert_relative_eq!(band.get(1, 0).expect("Operation failed"), 1.0);
        assert_relative_eq!(band.get(1, 1).expect("Operation failed"), 6.0);
        assert_relative_eq!(band.get(1, 2).expect("Operation failed"), 11.0);
        assert_relative_eq!(band.get(1, 3).expect("Operation failed"), 15.0);
        assert_relative_eq!(band.get(0, 3).expect("Operation failed"), 0.0);
        assert_relative_eq!(band.get(0, 4).expect("Operation failed"), 0.0);
        assert_relative_eq!(band.get(3, 0).expect("Operation failed"), 0.0);
    }
    #[test]
    fn test_frommatrix() {
        let a = array![
            [5.0, 10.0, 14.0, 0.0, 0.0],
            [1.0, 6.0, 11.0, 15.0, 0.0],
            [0.0, 2.0, 7.0, 12.0, 16.0],
            [0.0, 0.0, 3.0, 8.0, 13.0],
            [0.0, 0.0, 0.0, 4.0, 9.0]
        ];
        let band = BandedMatrix::frommatrix(&a.view(), 1, 2).expect("Operation failed");
        assert_eq!(band.nrows(), 5);
        assert_eq!(band.ncols(), 5);
        assert_eq!(band.bandwidth(), 4);
        assert_eq!(band.lower_bandwidth(), 1);
        assert_eq!(band.upper_bandwidth(), 2);
        assert_relative_eq!(band.get(0, 0).expect("Operation failed"), 5.0);
        assert_relative_eq!(band.get(0, 1).expect("Operation failed"), 10.0);
        assert_relative_eq!(band.get(0, 2).expect("Operation failed"), 14.0);
        assert_relative_eq!(band.get(1, 0).expect("Operation failed"), 1.0);
        assert_relative_eq!(band.get(1, 1).expect("Operation failed"), 6.0);
        assert_relative_eq!(band.get(1, 2).expect("Operation failed"), 11.0);
        assert_relative_eq!(band.get(1, 3).expect("Operation failed"), 15.0);
        assert_relative_eq!(band.get(0, 3).expect("Operation failed"), 0.0);
        assert_relative_eq!(band.get(0, 4).expect("Operation failed"), 0.0);
        assert_relative_eq!(band.get(3, 0).expect("Operation failed"), 0.0);
    }
    #[test]
    fn test_matvec() {
        let mut data = Array2::zeros((3, 4));
        data[[0, 0]] = 1.0;
        data[[0, 1]] = 2.0;
        data[[0, 2]] = 3.0;
        data[[1, 0]] = 4.0;
        data[[1, 1]] = 5.0;
        data[[1, 2]] = 6.0;
        data[[1, 3]] = 7.0;
        data[[2, 0]] = 8.0;
        data[[2, 1]] = 9.0;
        data[[2, 2]] = 10.0;
        let band = BandedMatrix::new(data.view(), 1, 1, 4, 4).expect("Operation failed");
        let x = array![1.0, 2.0, 3.0, 4.0];
        let y = band.matvec(&x.view()).expect("Operation failed");
        let expected = array![
            4.0 * 1.0 + 8.0 * 2.0,
            1.0 * 1.0 + 5.0 * 2.0 + 9.0 * 3.0,
            2.0 * 2.0 + 6.0 * 3.0 + 10.0 * 4.0,
            3.0 * 3.0 + 7.0 * 4.0
        ];
        assert_eq!(y.len(), 4);
        assert_relative_eq!(y[0], expected[0], epsilon = 1e-10);
        assert_relative_eq!(y[1], expected[1], epsilon = 1e-10);
        assert_relative_eq!(y[2], expected[2], epsilon = 1e-10);
        assert_relative_eq!(y[3], expected[3], epsilon = 1e-10);
    }
    #[test]
    fn test_matvec_transpose() {
        let mut data = Array2::zeros((3, 4));
        data[[0, 0]] = 1.0;
        data[[0, 1]] = 2.0;
        data[[0, 2]] = 3.0;
        data[[1, 0]] = 4.0;
        data[[1, 1]] = 5.0;
        data[[1, 2]] = 6.0;
        data[[1, 3]] = 7.0;
        data[[2, 0]] = 8.0;
        data[[2, 1]] = 9.0;
        data[[2, 2]] = 10.0;
        let band = BandedMatrix::new(data.view(), 1, 1, 4, 4).expect("Operation failed");
        let x = array![1.0, 2.0, 3.0, 4.0];
        let y = band.matvec_transpose(&x.view()).expect("Operation failed");
        let expected = array![
            4.0 * 1.0 + 1.0 * 2.0,
            8.0 * 1.0 + 5.0 * 2.0 + 2.0 * 3.0,
            9.0 * 2.0 + 6.0 * 3.0 + 3.0 * 4.0,
            10.0 * 3.0 + 7.0 * 4.0
        ];
        assert_eq!(y.len(), 4);
        assert_relative_eq!(y[0], expected[0], epsilon = 1e-10);
        assert_relative_eq!(y[1], expected[1], epsilon = 1e-10);
        assert_relative_eq!(y[2], expected[2], epsilon = 1e-10);
        assert_relative_eq!(y[3], expected[3], epsilon = 1e-10);
    }
    #[test]
    fn test_to_dense() {
        let mut data = Array2::zeros((3, 3));
        data[[0, 0]] = 1.0;
        data[[0, 1]] = 2.0;
        data[[1, 0]] = 3.0;
        data[[1, 1]] = 4.0;
        data[[1, 2]] = 5.0;
        data[[2, 0]] = 6.0;
        data[[2, 1]] = 7.0;
        let band = BandedMatrix::new(data.view(), 1, 1, 3, 3).expect("Operation failed");
        let dense = band.to_dense().expect("Operation failed");
        let expected = array![[3.0, 6.0, 0.0], [1.0, 4.0, 7.0], [0.0, 2.0, 5.0]];
        assert_eq!(dense.shape(), &[3, 3]);
        for i in 0..3 {
            for j in 0..3 {
                assert_relative_eq!(dense[[i, j]], expected[[i, j]], epsilon = 1e-10);
            }
        }
    }
    #[test]
    fn test_solve() {
        let a = array![[2.0, -1.0, 0.0], [-1.0, 2.0, -1.0], [0.0, -1.0, 2.0]];
        let band = BandedMatrix::frommatrix(&a.view(), 1, 1).expect("Operation failed");
        let b = array![1.0, 2.0, 3.0];
        let expected = array![2.5, 4.0, 3.5];
        let x = band.solve_tridiagonal(&b.view()).expect("Operation failed");
        assert_eq!(x.len(), 3);
        assert_relative_eq!(x[0], expected[0], epsilon = 1e-10);
        assert_relative_eq!(x[1], expected[1], epsilon = 1e-10);
        assert_relative_eq!(x[2], expected[2], epsilon = 1e-10);
        let ax = band.matvec(&x.view()).expect("Operation failed");
        assert_eq!(ax.len(), 3);
        assert_relative_eq!(ax[0], b[0], epsilon = 1e-10);
        assert_relative_eq!(ax[1], b[1], epsilon = 1e-10);
        assert_relative_eq!(ax[2], b[2], epsilon = 1e-10);
        let x2 = band.solve(&b.view()).expect("Operation failed");
        assert_eq!(x2.len(), 3);
        assert_relative_eq!(x2[0], expected[0], epsilon = 1e-10);
        assert_relative_eq!(x2[1], expected[1], epsilon = 1e-10);
        assert_relative_eq!(x2[2], expected[2], epsilon = 1e-10);
    }
    /// Exercises the non-tridiagonal `solve` path (lower bandwidth 1, upper bandwidth 2).
    #[test]
    fn test_solve_general_band() {
        // Strictly diagonally dominant, hence non-singular.
        let a = array![
            [10.0, 1.0, 2.0, 0.0, 0.0],
            [1.0, 10.0, 1.0, 2.0, 0.0],
            [0.0, 1.0, 10.0, 1.0, 2.0],
            [0.0, 0.0, 1.0, 10.0, 1.0],
            [0.0, 0.0, 0.0, 1.0, 10.0]
        ];
        let band = BandedMatrix::frommatrix(&a.view(), 1, 2).expect("Operation failed");
        let b = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let x = band.solve(&b.view()).expect("Operation failed");
        assert_eq!(x.len(), 5);
        let ax = band.matvec(&x.view()).expect("Operation failed");
        for i in 0..5 {
            assert_relative_eq!(ax[i], b[i], epsilon = 1e-10);
        }
    }
    /// Checks that `frommatrix`, `to_dense`, `matvec` and `matvec_transpose` agree with
    /// the dense matrix `a`, whose out-of-band entries must all be zero.
    fn check_tridiagonal_roundtrip(a: &Array2<f64>) {
        let band = BandedMatrix::frommatrix(&a.view(), 1, 1).expect("Operation failed");
        let dense = band.to_dense().expect("Operation failed");
        assert_eq!(dense.shape(), a.shape());
        for i in 0..a.nrows() {
            for j in 0..a.ncols() {
                assert_relative_eq!(dense[[i, j]], a[[i, j]], epsilon = 1e-12);
            }
        }
        let x = Array1::from_iter((1..=a.ncols()).map(|v| v as f64));
        let y = band.matvec(&x.view()).expect("Operation failed");
        for i in 0..a.nrows() {
            let expected: f64 = (0..a.ncols()).map(|j| a[[i, j]] * x[j]).sum();
            assert_relative_eq!(y[i], expected, epsilon = 1e-10);
        }
        let z = Array1::from_iter((1..=a.nrows()).map(|v| v as f64));
        let w = band.matvec_transpose(&z.view()).expect("Operation failed");
        for j in 0..a.ncols() {
            let expected: f64 = (0..a.nrows()).map(|i| a[[i, j]] * z[i]).sum();
            assert_relative_eq!(w[j], expected, epsilon = 1e-10);
        }
    }
    #[test]
    fn test_rectangular_roundtrip() {
        let wide = array![
            [1.0, 2.0, 0.0, 0.0, 0.0],
            [3.0, 4.0, 5.0, 0.0, 0.0],
            [0.0, 6.0, 7.0, 8.0, 0.0]
        ];
        check_tridiagonal_roundtrip(&wide);
        let tall = array![
            [1.0, 2.0, 0.0],
            [3.0, 4.0, 5.0],
            [0.0, 6.0, 7.0],
            [0.0, 0.0, 8.0],
            [0.0, 0.0, 0.0]
        ];
        check_tridiagonal_roundtrip(&tall);
    }
}