use crate::error::{InterpolateError, InterpolateResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::numeric::{Float, FromPrimitive, Zero};
use std::fmt::Debug;
use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, RemAssign, Sub, SubAssign};
/// Square matrix stored in compact band form.
///
/// Only entries `A[i, j]` with `-kl <= j - i <= ku` are stored; every entry outside that
/// band is implicitly zero. Entry `A[i, j]` lives at `band_data[[ku + i - j, i]]`: the
/// band row is selected by the diagonal offset (row `ku` is the main diagonal) and the
/// column of `band_data` is the matrix *row* `i`.
#[derive(Debug, Clone)]
pub struct BandMatrix<T>
where
T: Float + Copy,
{
/// Matrix dimension; the matrix is `size x size`.
size: usize,
/// Number of stored sub-diagonals (below the main diagonal).
kl: usize,
/// Number of stored super-diagonals (above the main diagonal).
ku: usize,
/// Band storage with shape `(kl + ku + 1, size)`.
band_data: Array2<T>,
}
impl<T> BandMatrix<T>
where
    T: Float + Copy + Zero + AddAssign,
{
    /// Creates a `size x size` band matrix with `kl` sub-diagonals and `ku`
    /// super-diagonals; every in-band entry starts at zero.
    pub fn new(size: usize, kl: usize, ku: usize) -> Self {
        let band_data = Array2::zeros((kl + ku + 1, size));
        Self {
            size,
            kl,
            ku,
            band_data,
        }
    }

    /// Builds a band matrix from the band of a dense square matrix.
    ///
    /// Only entries with `-kl <= j - i <= ku` are copied; entries outside the band are
    /// silently discarded.
    ///
    /// # Errors
    /// Returns an invalid-input error when `dense` is not square.
    pub fn from_dense(dense: &ArrayView2<T>, kl: usize, ku: usize) -> InterpolateResult<Self> {
        if dense.nrows() != dense.ncols() {
            return Err(InterpolateError::invalid_input(
                "matrix must be square".to_string(),
            ));
        }
        let size = dense.nrows();
        let mut band_matrix = Self::new(size, kl, ku);
        // Visit only the in-band columns of each row instead of all `size * size` entries.
        for i in 0..size {
            for j in band_matrix.band_cols(i) {
                band_matrix.band_data[[ku + i - j, i]] = dense[[i, j]];
            }
        }
        Ok(band_matrix)
    }

    /// Matrix dimension (the matrix is `size x size`).
    pub fn size(&self) -> usize {
        self.size
    }

    /// Number of stored sub-diagonals (`kl`).
    pub fn subdiagonals(&self) -> usize {
        self.kl
    }

    /// Number of stored super-diagonals (`ku`).
    pub fn superdiagonals(&self) -> usize {
        self.ku
    }

    /// Maps matrix coordinates `(i, j)` to the row of `band_data` that stores them.
    ///
    /// Returns `None` when `(i, j)` is outside the matrix or outside the band.
    fn band_row(&self, i: usize, j: usize) -> Option<usize> {
        if i >= self.size || j >= self.size {
            return None;
        }
        if j >= i {
            let offset = j - i;
            // Closure (not `then_some`) so `ku - offset` is only evaluated when it cannot underflow.
            (offset <= self.ku).then(|| self.ku - offset)
        } else {
            let offset = i - j;
            (offset <= self.kl).then(|| self.ku + offset)
        }
    }

    /// Column indices `j` of row `i` that fall inside the band (clipped to the matrix).
    fn band_cols(&self, i: usize) -> std::ops::Range<usize> {
        i.saturating_sub(self.kl)..(i + self.ku + 1).min(self.size)
    }

    /// Sets the main-diagonal entry `A[i, i]`; ignored when `i` is out of range.
    pub fn set_diagonal(&mut self, i: usize, value: T) {
        if i < self.size {
            self.band_data[[self.ku, i]] = value;
        }
    }

    /// Sets the first super-diagonal entry `A[i, i + 1]`.
    ///
    /// Silently ignored when that entry lies outside the matrix or outside the band
    /// (for example when `ku == 0`).
    pub fn set_superdiagonal(&mut self, i: usize, value: T) {
        let target = i.checked_add(1).and_then(|j| self.band_row(i, j));
        if let Some(row) = target {
            self.band_data[[row, i]] = value;
        }
    }

    /// Sets the first sub-diagonal entry `A[i, i - 1]`.
    ///
    /// Silently ignored when `i == 0`, when `i` is out of range, or when the entry lies
    /// outside the band (for example when `kl == 0`).
    pub fn set_subdiagonal(&mut self, i: usize, value: T) {
        let target = i.checked_sub(1).and_then(|j| self.band_row(i, j));
        if let Some(row) = target {
            self.band_data[[row, i]] = value;
        }
    }

    /// Sets entry `A[i, j]`.
    ///
    /// # Errors
    /// Returns an invalid-input error when `(i, j)` is outside the matrix, or when it
    /// lies outside the band structure and therefore has no storage.
    pub fn set(&mut self, i: usize, j: usize, value: T) -> InterpolateResult<()> {
        if i >= self.size || j >= self.size {
            return Err(InterpolateError::invalid_input(
                "indices out of bounds".to_string(),
            ));
        }
        let row = self.band_row(i, j).ok_or_else(|| {
            InterpolateError::invalid_input("element outside band structure".to_string())
        })?;
        self.band_data[[row, i]] = value;
        Ok(())
    }

    /// Returns entry `A[i, j]`; zero for indices outside the matrix or outside the band.
    pub fn get(&self, i: usize, j: usize) -> T {
        self.band_row(i, j)
            .map_or_else(T::zero, |row| self.band_data[[row, i]])
    }

    /// Expands the band into a dense `size x size` matrix (zeros outside the band).
    pub fn to_dense(&self) -> Array2<T> {
        let mut dense = Array2::zeros((self.size, self.size));
        for i in 0..self.size {
            for j in self.band_cols(i) {
                dense[[i, j]] = self.band_data[[self.ku + i - j, i]];
            }
        }
        dense
    }

    /// Computes `A * x`, touching only in-band entries.
    ///
    /// Zero coefficients are skipped, so a stored zero never turns a non-finite `x[j]`
    /// into NaN.
    ///
    /// # Errors
    /// Returns an invalid-input error when `x.len() != size`.
    pub fn multiply_vector(&self, x: &ArrayView1<T>) -> InterpolateResult<Array1<T>> {
        if x.len() != self.size {
            return Err(InterpolateError::invalid_input(
                "vector dimension must match matrix size".to_string(),
            ));
        }
        let mut y = Array1::<T>::zeros(self.size);
        for i in 0..self.size {
            let mut sum = T::zero();
            for j in self.band_cols(i) {
                let a_ij = self.band_data[[self.ku + i - j, i]];
                if a_ij != T::zero() {
                    sum += a_ij * x[j];
                }
            }
            y[i] = sum;
        }
        Ok(y)
    }

    /// Raw band storage; `A[i, j]` lives at `[ku + i - j, i]`.
    pub fn band_data(&self) -> &Array2<T> {
        &self.band_data
    }

    /// Mutable raw band storage; `A[i, j]` lives at `[ku + i - j, i]`.
    pub fn band_data_mut(&mut self) -> &mut Array2<T> {
        &mut self.band_data
    }
}
/// Sparse matrix in compressed sparse row (CSR) form.
///
/// Row `i` owns the stored entries `data[row_ptrs[i]..row_ptrs[i + 1]]`, whose column
/// positions are the matching slice of `col_indices`. Within a row, the column indices
/// are in ascending order; `from_dense` produces them that way, and `get` relies on it
/// for its binary search.
#[derive(Debug, Clone)]
pub struct CSRMatrix<T>
where
T: Float + Copy,
{
/// Number of rows.
nrows: usize,
/// Number of columns.
ncols: usize,
/// Row start offsets into `col_indices`/`data`; length `nrows + 1`.
row_ptrs: Vec<usize>,
/// Column index of each stored entry.
col_indices: Vec<usize>,
/// Value of each stored entry.
data: Vec<T>,
}
impl<T> CSRMatrix<T>
where
    T: Float + Copy + Zero + AddAssign,
{
    /// Creates an `nrows x ncols` matrix with no stored entries (all zeros).
    pub fn new(nrows: usize, ncols: usize) -> Self {
        Self {
            nrows,
            ncols,
            row_ptrs: vec![0; nrows + 1],
            col_indices: Vec::new(),
            data: Vec::new(),
        }
    }

    /// Compresses a dense matrix, keeping only entries with `|value| > tolerance`.
    ///
    /// The comparison is strict, so entries equal to `tolerance` (and NaN entries) are
    /// not stored. Column indices come out in ascending order within each row.
    pub fn from_dense(dense: &ArrayView2<T>, tolerance: T) -> Self {
        let (nrows, ncols) = dense.dim();
        let mut row_ptrs = Vec::with_capacity(nrows + 1);
        let mut col_indices = Vec::new();
        let mut data = Vec::new();

        row_ptrs.push(0);
        for row in dense.outer_iter() {
            for (col, &value) in row.iter().enumerate() {
                if value.abs() > tolerance {
                    col_indices.push(col);
                    data.push(value);
                }
            }
            // Running total of stored entries marks where the next row begins.
            row_ptrs.push(data.len());
        }

        Self {
            nrows,
            ncols,
            row_ptrs,
            col_indices,
            data,
        }
    }

    /// `(nrows, ncols)` of the matrix.
    pub fn shape(&self) -> (usize, usize) {
        (self.nrows, self.ncols)
    }

    /// Number of explicitly stored entries.
    pub fn nnz(&self) -> usize {
        self.data.len()
    }

    /// Computes `A * x` by iterating over the stored entries of each row.
    ///
    /// # Errors
    /// Returns an invalid-input error when `x.len() != ncols`.
    pub fn multiply_vector(&self, x: &ArrayView1<T>) -> InterpolateResult<Array1<T>> {
        if x.len() != self.ncols {
            return Err(InterpolateError::invalid_input(
                "vector dimension must match matrix columns".to_string(),
            ));
        }
        let mut product = Array1::<T>::zeros(self.nrows);
        for (row, slot) in product.iter_mut().enumerate() {
            let span = self.row_ptrs[row]..self.row_ptrs[row + 1];
            let mut acc = T::zero();
            for (&col, &value) in self.col_indices[span.clone()]
                .iter()
                .zip(&self.data[span])
            {
                acc += value * x[col];
            }
            *slot = acc;
        }
        Ok(product)
    }

    /// Returns `A[i, j]`, or zero when `(i, j)` is out of range or not stored.
    ///
    /// Uses a binary search over the row's (sorted) column indices.
    pub fn get(&self, i: usize, j: usize) -> T {
        if i >= self.nrows || j >= self.ncols {
            return T::zero();
        }
        let first = self.row_ptrs[i];
        let cols = &self.col_indices[first..self.row_ptrs[i + 1]];
        // Lower bound: first position whose column is not less than `j`.
        let offset = cols.partition_point(|&c| c < j);
        match cols.get(offset) {
            Some(&c) if c == j => self.data[first + offset],
            _ => T::zero(),
        }
    }

    /// Expands the matrix into a dense `nrows x ncols` array.
    pub fn to_dense(&self) -> Array2<T> {
        let mut dense = Array2::zeros((self.nrows, self.ncols));
        for (row, bounds) in self.row_ptrs.windows(2).enumerate() {
            for k in bounds[0]..bounds[1] {
                dense[[row, self.col_indices[k]]] = self.data[k];
            }
        }
        dense
    }

    /// Raw CSR arrays: `(row_ptrs, col_indices, data)`.
    pub fn data(&self) -> (&[usize], &[usize], &[T]) {
        (
            self.row_ptrs.as_slice(),
            self.col_indices.as_slice(),
            self.data.as_slice(),
        )
    }
}
#[allow(dead_code)]
pub fn solve_band_system<T>(
band_matrix: &BandMatrix<T>,
rhs: &ArrayView1<T>,
) -> InterpolateResult<Array1<T>>
where
T: Float
+ FromPrimitive
+ Debug
+ Add<Output = T>
+ Sub<Output = T>
+ Mul<Output = T>
+ Div<Output = T>
+ AddAssign
+ SubAssign
+ MulAssign
+ DivAssign
+ RemAssign
+ Zero
+ Copy,
{
if rhs.len() != band_matrix.size() {
return Err(InterpolateError::invalid_input(
"RHS vector size must match _matrix size".to_string(),
));
}
let _n = band_matrix.size();
let _kl = band_matrix.subdiagonals();
let _ku = band_matrix.superdiagonals();
let dense = band_matrix.to_dense();
solve_dense_system(&dense.view(), rhs)
}
/// Solves the dense square system `matrix * x = rhs` by Gaussian elimination with
/// partial (row) pivoting, followed by back substitution.
///
/// # Errors
/// Returns an invalid-input error when `matrix` is not square, when `rhs` does not have
/// one entry per row, or when the largest available pivot in some column has absolute
/// value below `1e-14` (singular or nearly singular matrix). The threshold is absolute,
/// not scaled to the magnitude of the matrix.
///
/// # Panics
/// Panics if `1e-14` cannot be represented in `T` (never the case for `f32`/`f64`).
pub(crate) fn solve_dense_system<T>(
    matrix: &ArrayView2<T>,
    rhs: &ArrayView1<T>,
) -> InterpolateResult<Array1<T>>
where
    T: Float
        + FromPrimitive
        + Debug
        + Add<Output = T>
        + Sub<Output = T>
        + Mul<Output = T>
        + Div<Output = T>
        + AddAssign
        + SubAssign
        + MulAssign
        + DivAssign
        + RemAssign
        + Zero
        + Copy,
{
    let n = matrix.nrows();
    if matrix.ncols() != n {
        return Err(InterpolateError::invalid_input(
            "matrix must be square".to_string(),
        ));
    }
    if rhs.len() != n {
        return Err(InterpolateError::invalid_input(
            "RHS vector size must match matrix size".to_string(),
        ));
    }

    // Augmented matrix [A | b]; the right-hand side occupies the last column.
    let rhs_col = n;
    let mut aug = Array2::<T>::zeros((n, n + 1));
    for row in 0..n {
        for col in 0..n {
            aug[[row, col]] = matrix[[row, col]];
        }
        aug[[row, rhs_col]] = rhs[row];
    }

    // Forward elimination to upper-triangular form.
    for pivot_col in 0..n {
        // Partial pivoting: take the row with the largest magnitude in this column;
        // on ties the earliest row wins.
        let (pivot_row, pivot_mag) = ((pivot_col + 1)..n).fold(
            (pivot_col, aug[[pivot_col, pivot_col]].abs()),
            |(best_row, best_mag), row| {
                let mag = aug[[row, pivot_col]].abs();
                if mag > best_mag {
                    (row, mag)
                } else {
                    (best_row, best_mag)
                }
            },
        );
        if pivot_mag < T::from_f64(1e-14).expect("Operation failed") {
            return Err(InterpolateError::invalid_input(
                "matrix is singular or nearly singular".to_string(),
            ));
        }

        if pivot_row != pivot_col {
            for col in 0..=rhs_col {
                aug.swap([pivot_col, col], [pivot_row, col]);
            }
        }

        for row in (pivot_col + 1)..n {
            let factor = aug[[row, pivot_col]] / aug[[pivot_col, pivot_col]];
            for col in pivot_col..=rhs_col {
                let scaled = factor * aug[[pivot_col, col]];
                aug[[row, col]] -= scaled;
            }
        }
    }

    // Back substitution on the triangular system.
    let mut solution = Array1::<T>::zeros(n);
    for row in (0..n).rev() {
        let mut acc = aug[[row, rhs_col]];
        for col in (row + 1)..n {
            acc -= aug[[row, col]] * solution[col];
        }
        solution[row] = acc / aug[[row, row]];
    }
    Ok(solution)
}
/// Solves `A * x = rhs` for a sparse CSR matrix with Jacobi iteration, starting from
/// `x = 0`.
///
/// Each sweep computes `x_new[i] = (rhs[i] - sum_{j != i} A[i, j] * x[j]) / A[i, i]`.
/// Iteration stops once the Euclidean norm of `x_new - x` drops below `tolerance`.
/// Jacobi iteration is only guaranteed to converge for suitable matrices (for example,
/// strictly diagonally dominant ones).
///
/// # Errors
/// Returns an invalid-input error when:
/// - `rhs` has the wrong length or the matrix is not square;
/// - a diagonal entry has absolute value below `1e-14`;
/// - the iteration does not converge within `max_iterations` sweeps.
///
/// # Panics
/// Panics if `1e-14` cannot be represented in `T` (never the case for `f32`/`f64`).
#[allow(dead_code)]
pub fn solve_sparse_system<T>(
    sparse_matrix: &CSRMatrix<T>,
    rhs: &ArrayView1<T>,
    tolerance: T,
    max_iterations: usize,
) -> InterpolateResult<Array1<T>>
where
    T: Float
        + FromPrimitive
        + Debug
        + Add<Output = T>
        + Sub<Output = T>
        + Mul<Output = T>
        + Div<Output = T>
        + AddAssign
        + SubAssign
        + MulAssign
        + DivAssign
        + RemAssign
        + Zero
        + Copy,
{
    let n = sparse_matrix.nrows;
    if rhs.len() != n {
        return Err(InterpolateError::invalid_input(
            "RHS vector size must match matrix size".to_string(),
        ));
    }
    // Without this check, a column index >= nrows would index past the end of `x`.
    if sparse_matrix.ncols != n {
        return Err(InterpolateError::invalid_input(
            "matrix must be square".to_string(),
        ));
    }

    let diag_tol = T::from_f64(1e-14).expect("Operation failed");
    let mut x = Array1::<T>::zeros(n);
    let mut x_new = Array1::<T>::zeros(n);

    for _ in 0..max_iterations {
        for i in 0..n {
            let start = sparse_matrix.row_ptrs[i];
            let end = sparse_matrix.row_ptrs[i + 1];
            let mut off_diagonal_sum = T::zero();
            let mut diagonal = T::zero();
            for k in start..end {
                let j = sparse_matrix.col_indices[k];
                let a_ij = sparse_matrix.data[k];
                if i == j {
                    diagonal = a_ij;
                } else {
                    off_diagonal_sum += a_ij * x[j];
                }
            }
            // A missing diagonal entry is treated the same as a stored zero.
            if diagonal.abs() < diag_tol {
                return Err(InterpolateError::invalid_input(
                    "matrix has zero diagonal element".to_string(),
                ));
            }
            x_new[i] = (rhs[i] - off_diagonal_sum) / diagonal;
        }

        let mut diff_norm = T::zero();
        for i in 0..n {
            let diff = x_new[i] - x[i];
            diff_norm += diff * diff;
        }
        if diff_norm.sqrt() < tolerance {
            return Ok(x_new);
        }
        // Every entry of `x_new` is rewritten in the next sweep, so swapping the buffers
        // is equivalent to copying `x_new` into `x`, without the copy.
        std::mem::swap(&mut x, &mut x_new);
    }

    Err(InterpolateError::invalid_input(
        "iterative solver failed to converge".to_string(),
    ))
}
/// Solves the linear least-squares problem `min ||matrix * x - rhs||_2` through the
/// normal equations `(AᵀA + λI) x = Aᵀb`.
///
/// Despite its name, `tolerance` is a Tikhonov (ridge) regularisation term: when it is
/// `Some(λ)`, `λ` is added to every diagonal entry of `AᵀA` before solving. Forming the
/// normal equations squares the condition number of `A`, so ill-conditioned problems may
/// need a positive `λ`.
///
/// # Errors
/// Returns an invalid-input error when `rhs.len() != matrix.nrows()`, or when
/// `solve_dense_system` rejects the normal-equation matrix as singular (for example
/// rank-deficient `A` without regularisation).
#[allow(dead_code)]
pub fn solve_structured_least_squares<T>(
    matrix: &ArrayView2<T>,
    rhs: &ArrayView1<T>,
    tolerance: Option<T>,
) -> InterpolateResult<Array1<T>>
where
    T: Float
        + FromPrimitive
        + Debug
        + Add<Output = T>
        + Sub<Output = T>
        + Mul<Output = T>
        + Div<Output = T>
        + AddAssign
        + SubAssign
        + MulAssign
        + DivAssign
        + RemAssign
        + Zero
        + Copy,
{
    let m = matrix.nrows();
    let n = matrix.ncols();
    if rhs.len() != m {
        return Err(InterpolateError::invalid_input(
            "RHS vector size must match matrix rows".to_string(),
        ));
    }

    // Gram matrix AᵀA is symmetric: compute the upper triangle and mirror it. This
    // roughly halves the work and gives the same values, because IEEE multiplication
    // is commutative.
    let mut ata = Array2::<T>::zeros((n, n));
    for i in 0..n {
        for j in i..n {
            let mut sum = T::zero();
            for k in 0..m {
                sum += matrix[[k, i]] * matrix[[k, j]];
            }
            ata[[i, j]] = sum;
            ata[[j, i]] = sum;
        }
    }

    let mut atb = Array1::<T>::zeros(n);
    for i in 0..n {
        let mut sum = T::zero();
        for k in 0..m {
            sum += matrix[[k, i]] * rhs[k];
        }
        atb[i] = sum;
    }

    if let Some(reg) = tolerance {
        for i in 0..n {
            ata[[i, i]] += reg;
        }
    }

    solve_dense_system(&ata.view(), &atb.view())
}
/// Allocates an all-zero `n x n` band matrix for B-spline systems of the given degree.
///
/// The band is symmetric: `degree` sub-diagonals and `degree` super-diagonals around
/// the main diagonal.
#[allow(dead_code)]
pub fn create_bspline_band_matrix<T>(n: usize, degree: usize) -> BandMatrix<T>
where
    T: Float + Copy + Zero + AddAssign,
{
    let (lower, upper) = (degree, degree);
    BandMatrix::new(n, lower, upper)
}
/// Dense matrix-vector product `matrix * vector` (build with the `simd` feature).
///
/// Uses `vectorized_matvec_simd_f64` when `is_simd_available()` reports support *and*
/// `T` is exactly `f64`; otherwise it falls back to the cache-blocked scalar kernel.
///
/// # Errors
/// Returns an invalid-input error when `vector.len() != matrix.ncols()`, and propagates
/// any error from the selected kernel.
#[cfg(feature = "simd")]
#[allow(dead_code)]
pub fn vectorized_matvec<T>(
    matrix: &ArrayView2<T>,
    vector: &ArrayView1<T>,
) -> InterpolateResult<Array1<T>>
where
    T: Float + Copy + Zero + AddAssign + 'static,
{
    use crate::simd_optimized::is_simd_available;
    use std::any::TypeId;

    let (rows, cols) = matrix.dim();
    if vector.len() != cols {
        return Err(InterpolateError::invalid_input(
            "vector size must match matrix columns".to_string(),
        ));
    }

    let mut product = Array1::zeros(rows);
    let take_f64_path = is_simd_available() && TypeId::of::<T>() == TypeId::of::<f64>();
    if take_f64_path {
        vectorized_matvec_simd_f64(matrix, vector, &mut product)?;
    } else {
        vectorized_matvec_scalar(matrix, vector, &mut product)?;
    }
    Ok(product)
}
/// Row-by-row dense matrix-vector product written into `result`.
///
/// Despite the name, this is currently a plain scalar loop over a generic `T`. Each
/// `result[i]` is overwritten, not accumulated, with the dot product of row `i` and
/// `vector`. Panics if `result` has fewer than `matrix.nrows()` entries or `vector` has
/// fewer than `matrix.ncols()` entries. Always returns `Ok(())`.
#[cfg(feature = "simd")]
#[allow(dead_code)]
fn vectorized_matvec_simd_f64<T>(
    matrix: &ArrayView2<T>,
    vector: &ArrayView1<T>,
    result: &mut Array1<T>,
) -> InterpolateResult<()>
where
    T: Float + Copy + Zero + AddAssign,
{
    for (row_idx, row) in matrix.outer_iter().enumerate() {
        let mut acc = T::zero();
        for (col_idx, &coeff) in row.iter().enumerate() {
            acc += coeff * vector[col_idx];
        }
        result[row_idx] = acc;
    }
    Ok(())
}
/// Dense matrix-vector product `matrix * vector` (build without the `simd` feature).
///
/// Always uses the cache-blocked scalar kernel.
///
/// # Errors
/// Returns an invalid-input error when `vector.len() != matrix.ncols()`.
#[cfg(not(feature = "simd"))]
#[allow(dead_code)]
pub fn vectorized_matvec<T>(
    matrix: &ArrayView2<T>,
    vector: &ArrayView1<T>,
) -> InterpolateResult<Array1<T>>
where
    T: Float + Copy + Zero + AddAssign + 'static,
{
    let (rows, cols) = matrix.dim();
    if vector.len() != cols {
        return Err(InterpolateError::invalid_input(
            "vector size must match matrix columns".to_string(),
        ));
    }
    let mut product = Array1::zeros(rows);
    vectorized_matvec_scalar(matrix, vector, &mut product).map(|()| product)
}
/// Cache-blocked dense matrix-vector product that *accumulates* into `result`.
///
/// The matrix is processed in 64 x 64 tiles. For each tile, every row's partial dot
/// product is added to `result[i]`, so `result` must be zero-initialised by the caller
/// to obtain `matrix * vector`. Panics if `result` or `vector` is shorter than the
/// matching matrix dimension. Always returns `Ok(())`.
#[allow(dead_code)]
fn vectorized_matvec_scalar<T>(
    matrix: &ArrayView2<T>,
    vector: &ArrayView1<T>,
    result: &mut Array1<T>,
) -> InterpolateResult<()>
where
    T: Float + Copy + Zero + AddAssign,
{
    const TILE: usize = 64;
    let (rows, cols) = matrix.dim();

    let mut row_lo = 0;
    while row_lo < rows {
        let row_hi = (row_lo + TILE).min(rows);
        let mut col_lo = 0;
        while col_lo < cols {
            let col_hi = (col_lo + TILE).min(cols);
            for r in row_lo..row_hi {
                let mut partial = T::zero();
                for c in col_lo..col_hi {
                    partial += matrix[[r, c]] * vector[c];
                }
                result[r] += partial;
            }
            col_lo = col_hi;
        }
        row_lo = row_hi;
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_band_matrix_operations() {
        let mut band_matrix = BandMatrix::new(3, 1, 1);
        band_matrix.set_diagonal(0, 2.0);
        band_matrix.set_diagonal(1, 2.0);
        band_matrix.set_diagonal(2, 2.0);
        band_matrix.set_superdiagonal(0, -1.0);
        band_matrix.set_superdiagonal(1, -1.0);
        band_matrix.set_subdiagonal(1, -1.0);
        band_matrix.set_subdiagonal(2, -1.0);
        assert_eq!(band_matrix.get(0, 0), 2.0);
        assert_eq!(band_matrix.get(0, 1), -1.0);
        assert_eq!(band_matrix.get(0, 2), 0.0);
        assert_eq!(band_matrix.get(1, 0), -1.0);
        assert_eq!(band_matrix.get(1, 1), 2.0);
        let x = array![1.0, 2.0, 3.0];
        let y = band_matrix
            .multiply_vector(&x.view())
            .expect("Operation failed");
        assert_relative_eq!(y[0], 0.0, epsilon = 1e-10);
        assert_relative_eq!(y[1], 0.0, epsilon = 1e-10);
        assert_relative_eq!(y[2], 4.0, epsilon = 1e-10);
    }

    #[test]
    fn test_band_dense_roundtrip_asymmetric_bandwidth() {
        // kl = 2, ku = 1: every non-zero below lies within two sub-diagonals.
        let dense = array![
            [1.0, 2.0, 0.0, 0.0],
            [3.0, 4.0, 5.0, 0.0],
            [6.0, 7.0, 8.0, 9.0],
            [0.0, 10.0, 11.0, 12.0]
        ];
        let band = BandMatrix::from_dense(&dense.view(), 2, 1).expect("Operation failed");
        assert_eq!(band.get(0, 1), 2.0);
        assert_eq!(band.get(2, 0), 6.0);
        assert_eq!(band.get(3, 1), 10.0);
        assert_eq!(band.get(3, 0), 0.0);
        assert_eq!(band.get(0, 2), 0.0);
        assert_eq!(band.to_dense(), dense);
    }

    #[test]
    fn test_sparse_matrix_operations() {
        let dense = array![[2.0, -1.0, 0.0], [-1.0, 2.0, -1.0], [0.0, -1.0, 2.0]];
        let sparse = CSRMatrix::from_dense(&dense.view(), 1e-12);
        assert_eq!(sparse.shape(), (3, 3));
        assert_eq!(sparse.nnz(), 7);
        assert_eq!(sparse.get(0, 0), 2.0);
        assert_eq!(sparse.get(0, 1), -1.0);
        assert_eq!(sparse.get(0, 2), 0.0);
        let x = array![1.0, 2.0, 3.0];
        let y = sparse.multiply_vector(&x.view()).expect("Operation failed");
        assert_relative_eq!(y[0], 0.0, epsilon = 1e-10);
        assert_relative_eq!(y[1], 0.0, epsilon = 1e-10);
        assert_relative_eq!(y[2], 4.0, epsilon = 1e-10);
    }

    #[test]
    fn test_band_system_solver() {
        let mut matrix = BandMatrix::new(3, 1, 1);
        matrix.set_diagonal(0, 1.0);
        matrix.set_diagonal(1, 2.0);
        matrix.set_diagonal(2, 1.0);
        matrix.set_superdiagonal(1, 1.0);
        // Out of range for a 3x3 matrix (A[2, 3] does not exist): silently ignored.
        matrix.set_superdiagonal(2, 1.0);
        matrix.set_subdiagonal(1, 1.0);
        matrix.set_subdiagonal(2, 1.0);
        let rhs = array![2.0, 4.0, 2.0];
        let solution = solve_band_system(&matrix, &rhs.view()).expect("Operation failed");
        let verification = matrix
            .multiply_vector(&solution.view())
            .expect("Operation failed");
        for i in 0..3 {
            assert_relative_eq!(verification[i], rhs[i], epsilon = 1e-10);
        }
    }

    #[test]
    fn test_band_system_solver_with_pivoting() {
        // Tiny leading diagonal forces row interchanges; kl = 2, ku = 1.
        let dense = array![
            [1e-3, 2.0, 0.0, 0.0],
            [3.0, 1.0, 4.0, 0.0],
            [5.0, 2.0, 1.0, 1.0],
            [0.0, 1.0, 3.0, 2.0]
        ];
        let matrix = BandMatrix::from_dense(&dense.view(), 2, 1).expect("Operation failed");
        let expected = array![1.0, -1.0, 2.0, 0.5];
        let rhs = matrix
            .multiply_vector(&expected.view())
            .expect("Operation failed");
        let solution = solve_band_system(&matrix, &rhs.view()).expect("Operation failed");
        for i in 0..4 {
            assert_relative_eq!(solution[i], expected[i], epsilon = 1e-10);
        }
    }

    #[test]
    fn test_band_system_solver_pentadiagonal_matches_dense() {
        let n = 6;
        let mut dense = Array2::<f64>::zeros((n, n));
        for i in 0..n {
            dense[[i, i]] = 6.0;
            if i + 1 < n {
                dense[[i, i + 1]] = -2.0;
                dense[[i + 1, i]] = -2.0;
            }
            if i + 2 < n {
                dense[[i, i + 2]] = 1.0;
                dense[[i + 2, i]] = 1.0;
            }
        }
        let matrix = BandMatrix::from_dense(&dense.view(), 2, 2).expect("Operation failed");
        let expected = Array1::from_iter((0..n).map(|i| i as f64 + 1.0));
        let rhs = matrix
            .multiply_vector(&expected.view())
            .expect("Operation failed");
        let band_solution = solve_band_system(&matrix, &rhs.view()).expect("Operation failed");
        let dense_solution =
            solve_dense_system(&dense.view(), &rhs.view()).expect("Operation failed");
        for i in 0..n {
            assert_relative_eq!(band_solution[i], expected[i], epsilon = 1e-10);
            assert_relative_eq!(band_solution[i], dense_solution[i], epsilon = 1e-12);
        }
    }

    #[test]
    fn test_sparse_system_solver() {
        let dense = array![[2.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 4.0]];
        let sparse = CSRMatrix::from_dense(&dense.view(), 1e-12);
        let rhs = array![4.0, 9.0, 16.0];
        let solution =
            solve_sparse_system(&sparse, &rhs.view(), 1e-10, 100).expect("Operation failed");
        assert_relative_eq!(solution[0], 2.0, epsilon = 1e-8);
        assert_relative_eq!(solution[1], 3.0, epsilon = 1e-8);
        assert_relative_eq!(solution[2], 4.0, epsilon = 1e-8);
    }

    #[test]
    fn test_bspline_band_matrix_creation() {
        let band_matrix = create_bspline_band_matrix::<f64>(10, 3);
        assert_eq!(band_matrix.size(), 10);
        assert_eq!(band_matrix.subdiagonals(), 3);
        assert_eq!(band_matrix.superdiagonals(), 3);
    }

    #[test]
    fn test_structured_least_squares() {
        let matrix = array![[1.0, 1.0], [2.0, 1.0], [3.0, 1.0]];
        let rhs = array![2.0, 3.0, 4.0];
        let solution = solve_structured_least_squares(&matrix.view(), &rhs.view(), None)
            .expect("Operation failed");
        let residual = {
            let mut r = Array1::zeros(3);
            for i in 0..3 {
                let mut pred = 0.0;
                for j in 0..2 {
                    pred += matrix[[i, j]] * solution[j];
                }
                r[i] = rhs[i] - pred;
            }
            r
        };
        let residual_norm: f64 = residual.iter().map(|&x| x * x).sum::<f64>().sqrt();
        assert!(residual_norm < 1e-10);
    }
}