use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use num_traits::Zero;
use std::fmt;
use std::ops::{Add, Mul};
/// A `rows × cols` matrix that stores only the entries on a fixed set of diagonals
/// around the main diagonal (`sub_diagonals` below it, `super_diagonals` above it).
///
/// Positions outside the band are not stored; reads of them yield `T::default()`.
#[derive(Clone)]
pub struct BandedMatrix<T> {
// Number of logical rows of the matrix.
rows: usize,
// Number of logical columns of the matrix.
cols: usize,
// Count of stored diagonals strictly below the main diagonal.
sub_diagonals: usize,
// Count of stored diagonals strictly above the main diagonal.
super_diagonals: usize,
// Band storage of shape `(sub_diagonals + super_diagonals + 1, cols)`: one row per
// stored diagonal, with the main diagonal at row `sub_diagonals`.
data: Array<T>,
}
impl<T> BandedMatrix<T>
where
    T: Clone + Default + Zero + PartialEq,
{
    /// Creates a `rows × cols` banded matrix with `sub_diagonals` diagonals below and
    /// `super_diagonals` diagonals above the main one, every stored entry set to
    /// `T::default()`.
    ///
    /// Storage is a dense `(sub_diagonals + super_diagonals + 1) × cols` array, so memory is
    /// proportional to `band_width() * cols` rather than `rows * cols`.
    pub fn new(rows: usize, cols: usize, sub_diagonals: usize, super_diagonals: usize) -> Self {
        let bands = sub_diagonals + super_diagonals + 1;
        let data = Array::full(&[bands, cols], T::default());
        Self {
            rows,
            cols,
            sub_diagonals,
            super_diagonals,
            data,
        }
    }

    /// Builds a banded matrix from a dense 2-D array, copying the entries that lie inside
    /// the requested band.
    ///
    /// Entries outside the band are ignored — they are *not* checked to be zero — so a
    /// dense input with non-zero off-band values is silently truncated to its band.
    ///
    /// # Errors
    ///
    /// Returns `NumRs2Error::DimensionMismatch` if `array` is not 2-D, and propagates any
    /// error from reading `array` or writing the band storage.
    pub fn from_array(
        array: &Array<T>,
        sub_diagonals: usize,
        super_diagonals: usize,
    ) -> Result<Self> {
        if array.ndim() != 2 {
            return Err(NumRs2Error::DimensionMismatch(format!(
                "Banded matrix must be created from 2D array, got {}-dimensional array",
                array.ndim()
            )));
        }
        let shape = array.shape();
        let (rows, cols) = (shape[0], shape[1]);
        let mut banded = Self::new(rows, cols, sub_diagonals, super_diagonals);
        // Visit only each row's band columns instead of the whole dense matrix:
        // O(rows * band_width) instead of O(rows * cols).
        for i in 0..rows {
            for j in banded.band_column_range(i) {
                let value = array.get(&[i, j])?;
                banded.set(i, j, value.clone())?;
            }
        }
        Ok(banded)
    }

    /// Number of logical rows.
    pub fn nrows(&self) -> usize {
        self.rows
    }

    /// Number of logical columns.
    pub fn ncols(&self) -> usize {
        self.cols
    }

    /// Number of stored diagonals strictly below the main diagonal.
    pub fn sub_diagonals(&self) -> usize {
        self.sub_diagonals
    }

    /// Number of stored diagonals strictly above the main diagonal.
    pub fn super_diagonals(&self) -> usize {
        self.super_diagonals
    }

    /// Total number of stored diagonals: `sub_diagonals + super_diagonals + 1`.
    pub fn band_width(&self) -> usize {
        self.sub_diagonals + self.super_diagonals + 1
    }

    /// Returns `true` when `(i, j)` lies on a stored diagonal, i.e.
    /// `-sub_diagonals <= j - i <= super_diagonals`.
    ///
    /// This does not check `(i, j)` against the matrix shape.
    pub fn is_in_band(&self, i: usize, j: usize) -> bool {
        let diagonal = j as isize - i as isize;
        diagonal >= -(self.sub_diagonals as isize) && diagonal <= self.super_diagonals as isize
    }

    /// Returns the element at `(i, j)`; positions outside the band read as `T::default()`.
    ///
    /// # Errors
    ///
    /// Returns `NumRs2Error::IndexOutOfBounds` if `(i, j)` is outside the matrix, and
    /// propagates any error from reading the band storage.
    pub fn get(&self, i: usize, j: usize) -> Result<T> {
        self.check_index(i, j)?;
        if !self.is_in_band(i, j) {
            return Ok(T::default());
        }
        let (band_row, band_col) = self.storage_index(i, j);
        Ok(self.data.get(&[band_row, band_col])?.clone())
    }

    /// Stores `value` at `(i, j)`.
    ///
    /// # Errors
    ///
    /// Returns `NumRs2Error::IndexOutOfBounds` if `(i, j)` is outside the matrix,
    /// `NumRs2Error::InvalidOperation` if it is outside the band, and propagates any error
    /// from writing the band storage.
    pub fn set(&mut self, i: usize, j: usize, value: T) -> Result<()> {
        self.check_index(i, j)?;
        if !self.is_in_band(i, j) {
            return Err(NumRs2Error::InvalidOperation(format!(
                "Cannot set element at ({}, {}), position is outside the band",
                i, j
            )));
        }
        let (band_row, band_col) = self.storage_index(i, j);
        self.data.set(&[band_row, band_col], value)
    }

    /// Expands the matrix into a dense `rows × cols` array, filling off-band positions with
    /// `T::default()`.
    pub fn to_array(&self) -> Array<T> {
        let mut array = Array::full(&[self.rows, self.cols], T::default());
        for i in 0..self.rows {
            // Off-band cells already hold the default fill value; only copy the band.
            for j in self.band_column_range(i) {
                let value = self
                    .get(i, j)
                    .expect("to_array: index within band should be valid");
                array
                    .set(&[i, j], value)
                    .expect("to_array: index within array bounds should be valid");
            }
        }
        array
    }

    /// Returns the main diagonal, of length `min(rows, cols)`.
    pub fn diagonal(&self) -> Vec<T> {
        (0..self.rows.min(self.cols))
            .map(|i| {
                self.get(i, i)
                    .expect("diagonal: index within min(rows, cols) should be valid")
            })
            .collect()
    }

    /// Returns `true` if the matrix has as many rows as columns.
    pub fn is_square(&self) -> bool {
        self.rows == self.cols
    }

    /// Fails with `IndexOutOfBounds` unless `(i, j)` lies inside the `rows × cols` matrix.
    fn check_index(&self, i: usize, j: usize) -> Result<()> {
        if i >= self.rows || j >= self.cols {
            return Err(NumRs2Error::IndexOutOfBounds(format!(
                "Index ({}, {}) out of bounds for banded matrix with shape ({}, {})",
                i, j, self.rows, self.cols
            )));
        }
        Ok(())
    }

    /// Maps an in-bounds, in-band position `(i, j)` to its `(band_row, band_col)` slot in
    /// `data`. The result is meaningless for positions that fail those checks.
    fn storage_index(&self, i: usize, j: usize) -> (usize, usize) {
        let diagonal = j as isize - i as isize;
        // Sub-diagonals occupy rows `0..sub_diagonals`, the main diagonal row
        // `sub_diagonals`, and super-diagonals the rows after it.
        let band_row = (self.sub_diagonals as isize + diagonal) as usize;
        // Each diagonal is packed from column 0 by the smaller coordinate, which is always
        // below `cols`, so every slot fits in the `cols`-wide storage.
        (band_row, i.min(j))
    }

    /// Columns of row `i` that lie inside the band, clamped to `0..cols`.
    ///
    /// The range is empty when row `i` has no in-band columns (e.g. tall matrices).
    fn band_column_range(&self, i: usize) -> std::ops::Range<usize> {
        let end = i
            .saturating_add(self.super_diagonals)
            .saturating_add(1)
            .min(self.cols);
        let start = i.saturating_sub(self.sub_diagonals).min(end);
        start..end
    }
}
impl<T> BandedMatrix<T>
where
    T: Clone + Default + Zero + PartialEq + Add<Output = T> + Mul<Output = T>,
{
    /// Computes the matrix-vector product `A · vec`, touching only in-band entries.
    ///
    /// Returns a vector of length `nrows()`. A row with no in-band columns yields
    /// `T::zero()`.
    ///
    /// # Errors
    ///
    /// Returns `NumRs2Error::ShapeMismatch` if `vec.len() != ncols()`.
    pub fn matvec(&self, vec: &[T]) -> Result<Vec<T>> {
        if vec.len() != self.cols {
            return Err(NumRs2Error::ShapeMismatch {
                expected: vec![self.cols],
                actual: vec![vec.len()],
            });
        }
        let result: Vec<T> = (0..self.rows)
            .map(|i| {
                // Band columns of row i, clamped to the matrix width; saturating arithmetic
                // avoids overflow for very wide bands.
                let j_end = i
                    .saturating_add(self.super_diagonals)
                    .saturating_add(1)
                    .min(self.cols);
                // Clamp start so the slice below is never inverted (tall matrices can have
                // rows entirely left of the band).
                let j_start = i.saturating_sub(self.sub_diagonals).min(j_end);
                // Accumulate from the additive identity, not `Default`, so the sum is
                // correct for any `Zero` type.
                vec[j_start..j_end]
                    .iter()
                    .zip(j_start..)
                    .fold(T::zero(), |acc, (x_j, j)| {
                        let a_ij = self
                            .get(i, j)
                            .expect("matvec: index within band should be valid");
                        acc + a_ij * x_j.clone()
                    })
            })
            .collect();
        Ok(result)
    }
}
impl<T> fmt::Display for BandedMatrix<T>
where
    T: Clone + fmt::Display + Default + Zero + PartialEq,
{
    /// Writes a header line with the shape and band widths, then one `[a, b, ...]` line per
    /// row. Positions outside the band are rendered as the literal `0`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        writeln!(
            f,
            "BandedMatrix({}, {}, sub_diagonals={}, super_diagonals={})",
            self.rows, self.cols, self.sub_diagonals, self.super_diagonals
        )?;
        for row in 0..self.rows {
            f.write_str("[")?;
            for col in 0..self.cols {
                if col != 0 {
                    f.write_str(", ")?;
                }
                if self.is_in_band(row, col) {
                    let entry = self
                        .get(row, col)
                        .expect("Display: index within matrix bounds should be valid");
                    write!(f, "{}", entry)?;
                } else {
                    f.write_str("0")?;
                }
            }
            f.write_str("]\n")?;
        }
        Ok(())
    }
}