use crate::banded_array::BandedArray;
use crate::error::SparseResult;
use crate::sparray::SparseArray;
use scirs2_core::ndarray::Array2;
use scirs2_core::numeric::{Float, One, Zero};
use scirs2_core::SparseElement;
use std::fmt::Debug;
/// Matrix-flavoured name for [`BandedArray`]; the two types are identical,
/// so every `BandedArray<T>` API is available on `BandedMatrix<T>`.
pub type BandedMatrix<T> = BandedArray<T>;
impl<T> BandedMatrix<T>
where
T: SparseElement
+ std::fmt::Display
+ Float
+ Zero
+ One
+ Send
+ Sync
+ 'static
+ std::ops::AddAssign,
{
pub fn matmul(&self, other: &BandedMatrix<T>) -> SparseResult<BandedMatrix<T>> {
let a_dense = self.to_array();
let b_dense = other.to_array();
if a_dense.ncols() != b_dense.nrows() {
return Err(crate::error::SparseError::DimensionMismatch {
expected: a_dense.ncols(),
found: b_dense.nrows(),
});
}
let result_dense = a_dense.dot(&b_dense);
let max_bandwidth = self.kl() + self.ku() + other.kl() + other.ku();
Self::from_dense(&result_dense, max_bandwidth, max_bandwidth)
}
pub fn from_dense(dense: &Array2<T>, kl: usize, ku: usize) -> SparseResult<Self> {
let (rows, cols) = dense.dim();
let mut result = Self::zeros((rows, cols), kl, ku);
for i in 0..rows {
for j in 0..cols {
if result.is_in_band(i, j) {
let val = dense[[i, j]];
if !SparseElement::is_zero(&val) {
result.set_unchecked(i, j, val);
}
}
}
}
Ok(result)
}
pub fn get_mut(&mut self, row: usize, col: usize) -> Option<&mut T> {
if !self.is_in_band(row, col) {
return None;
}
let band_idx = self.ku() + col - row;
if band_idx < self.kl() + self.ku() + 1 && row < self.shape().0 {
Some(&mut self.data_mut()[[band_idx, row]])
} else {
None
}
}
#[allow(dead_code)]
fn banded_data_mut(&mut self) -> &mut Array2<T> {
BandedArray::data_mut(self)
}
pub fn set(&mut self, row: usize, col: usize, value: T) -> SparseResult<()> {
if !self.is_in_band(row, col) {
if !SparseElement::is_zero(&value) {
return Err(crate::error::SparseError::ValueError(format!(
"Cannot set non-zero element at ({row}, {col}) outside band structure"
)));
}
return Ok(());
}
self.set_unchecked(row, col, value);
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// A tridiagonal constructor call must place each diagonal where expected.
    #[test]
    fn test_banded_matrix_creation() {
        let diag = vec![1.0, 2.0, 3.0];
        let lower = vec![4.0, 5.0];
        let upper = vec![6.0, 7.0];
        let matrix = BandedMatrix::tridiagonal(&diag, &lower, &upper).expect("Operation failed");
        assert_eq!(matrix.shape(), (3, 3));

        // (row, col, expected): main diagonal, then sub-, then super-diagonal.
        let expected_entries = [
            (0, 0, 1.0),
            (1, 1, 2.0),
            (2, 2, 3.0),
            (1, 0, 4.0),
            (2, 1, 5.0),
            (0, 1, 6.0),
            (1, 2, 7.0),
        ];
        for &(row, col, expected) in &expected_entries {
            assert_eq!(matrix.get(row, col), expected);
        }
    }

    /// `set` accepts in-band writes, rejects non-zero out-of-band writes and
    /// tolerates zero out-of-band writes.
    #[test]
    fn test_banded_matrix_set() {
        let mut matrix = BandedMatrix::<f64>::zeros((3, 3), 1, 1);
        let in_band_writes = [(0, 0, 1.0), (0, 1, 2.0), (1, 0, 3.0)];

        for &(row, col, value) in &in_band_writes {
            assert!(matrix.set(row, col, value).is_ok());
        }
        // (0, 2) lies outside a bandwidth-1 band.
        assert!(matrix.set(0, 2, 4.0).is_err());
        assert!(matrix.set(0, 2, 0.0).is_ok());

        for &(row, col, value) in &in_band_writes {
            assert_eq!(matrix.get(row, col), value);
        }
    }

    /// Product of two 3x3 tridiagonal matrices keeps its shape and has the
    /// expected top-left entry.
    #[test]
    fn test_banded_matrix_matmul() {
        let lhs = BandedMatrix::tridiagonal(&[2.0, 2.0, 2.0], &[1.0, 1.0], &[1.0, 1.0])
            .expect("Operation failed");
        let rhs = BandedMatrix::tridiagonal(&[1.0, 1.0, 1.0], &[0.5, 0.5], &[0.5, 0.5])
            .expect("Operation failed");
        let product = lhs.matmul(&rhs).expect("Operation failed");

        let (rows, cols) = product.shape();
        assert!(rows == 3 && cols == 3);
        // 2 * 1 + 1 * 0.5 = 2.5
        assert_relative_eq!(product.get(0, 0), 2.5, epsilon = 1e-10);
    }

    /// Packing a dense tridiagonal matrix keeps the band and reads zero
    /// outside of it.
    #[test]
    fn test_from_dense() {
        let values = vec![1.0, 2.0, 0.0, 3.0, 4.0, 5.0, 0.0, 6.0, 7.0];
        let dense = Array2::from_shape_vec((3, 3), values).expect("Operation failed");
        let banded = BandedMatrix::from_dense(&dense, 1, 1).expect("Operation failed");

        // (row, col, expected); the last two positions are the out-of-band corners.
        let expected_entries = [
            (0, 0, 1.0),
            (0, 1, 2.0),
            (1, 0, 3.0),
            (1, 1, 4.0),
            (1, 2, 5.0),
            (2, 1, 6.0),
            (2, 2, 7.0),
            (0, 2, 0.0),
            (2, 0, 0.0),
        ];
        for &(row, col, expected) in &expected_entries {
            assert_eq!(banded.get(row, col), expected);
        }
    }
}