use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use scirs2_core::numeric::{Float, FromPrimitive, Zero};
use std::fmt::{Debug, Display};
use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, RemAssign, Sub, SubAssign};
use crate::error::{InterpolateError, InterpolateResult};
pub fn solve_linear_system<T>(a: &ArrayView2<T>, b: &ArrayView1<T>) -> InterpolateResult<Array1<T>>
where
T: Float
+ FromPrimitive
+ Debug
+ Display
+ Add<Output = T>
+ Sub<Output = T>
+ Mul<Output = T>
+ Div<Output = T>
+ Zero
+ AddAssign
+ SubAssign
+ MulAssign
+ DivAssign
+ RemAssign
+ Copy,
{
if a.nrows() != a.ncols() {
return Err(InterpolateError::invalid_input(
"matrix must be square for direct solve".to_string(),
));
}
if a.nrows() != b.len() {
return Err(InterpolateError::invalid_input(
"matrix and vector dimensions must match".to_string(),
));
}
let bandwidth = estimate_bandwidth(a);
let n = a.nrows();
if bandwidth > 0 && bandwidth < n / 4 {
solve_band_system(a, b, bandwidth)
} else {
solve_dense_fallback(a, b)
}
}
pub fn solve_least_squares<T>(a: &ArrayView2<T>, b: &ArrayView1<T>) -> InterpolateResult<Array1<T>>
where
T: Float
+ FromPrimitive
+ Debug
+ Display
+ Add<Output = T>
+ Sub<Output = T>
+ Mul<Output = T>
+ Div<Output = T>
+ Zero
+ AddAssign
+ SubAssign
+ MulAssign
+ DivAssign
+ RemAssign
+ Copy,
{
if a.nrows() != b.len() {
return Err(InterpolateError::invalid_input(
"matrix and vector dimensions must match".to_string(),
));
}
let at = transpose_matrix(a);
let ata = matrix_multiply(&at.view(), a)?;
let atb = matrix_vector_multiply(&at.view(), b)?;
let mut regularized_ata = ata.clone();
let reg = if a.nrows() < a.ncols() {
T::from_f64(1e-8).expect("Operation failed")
} else {
T::from_f64(1e-10).expect("Operation failed")
};
for i in 0..regularized_ata.nrows() {
regularized_ata[[i, i]] += reg;
}
match solve_linear_system(®ularized_ata.view(), &atb.view()) {
Ok(result) => Ok(result),
Err(_) => {
for i in 0..regularized_ata.nrows() {
regularized_ata[[i, i]] += T::from_f64(1e-6).expect("Operation failed");
}
solve_linear_system(®ularized_ata.view(), &atb.view())
}
}
}
fn solve_band_system<T>(
a: &ArrayView2<T>,
b: &ArrayView1<T>,
bandwidth: usize,
) -> InterpolateResult<Array1<T>>
where
T: Float
+ FromPrimitive
+ Debug
+ Add<Output = T>
+ Sub<Output = T>
+ Mul<Output = T>
+ Div<Output = T>
+ AddAssign
+ SubAssign
+ MulAssign
+ DivAssign
+ RemAssign
+ Zero
+ Copy,
{
let _n = a.nrows();
solve_dense_fallback(a, b)
}
fn estimate_bandwidth<T: Float + Zero + FromPrimitive>(matrix: &ArrayView2<T>) -> usize {
let n = matrix.nrows();
let mut max_bandwidth = 0;
let tolerance = T::from_f64(1e-14).expect("Operation failed");
for i in 0..n {
for j in 0..n {
if matrix[[i, j]].abs() > tolerance {
let bandwidth = i.abs_diff(j);
max_bandwidth = max_bandwidth.max(bandwidth);
}
}
}
max_bandwidth
}
fn solve_dense_fallback<T>(
matrix: &ArrayView2<T>,
rhs: &ArrayView1<T>,
) -> InterpolateResult<Array1<T>>
where
T: Float
+ FromPrimitive
+ Debug
+ Add<Output = T>
+ Sub<Output = T>
+ Mul<Output = T>
+ Div<Output = T>
+ AddAssign
+ SubAssign
+ MulAssign
+ DivAssign
+ RemAssign
+ Zero
+ Copy,
{
let n = matrix.nrows();
let mut aug = Array2::zeros((n, n + 1));
for i in 0..n {
for j in 0..n {
aug[[i, j]] = matrix[[i, j]];
}
aug[[i, n]] = rhs[i];
}
for k in 0..n {
let mut max_row = k;
let mut max_val = aug[[k, k]].abs();
for i in (k + 1)..n {
let val = aug[[i, k]].abs();
if val > max_val {
max_val = val;
max_row = i;
}
}
let eps = T::from_f64(2.22e-16).expect("Operation failed"); let threshold = eps
* T::from_usize(n).expect("Operation failed")
* T::from_f64(1e8).expect("Operation failed");
if max_val < threshold {
return Err(InterpolateError::invalid_input(
"matrix is singular or nearly singular".to_string(),
));
}
if max_row != k {
for j in 0..=n {
let temp = aug[[k, j]];
aug[[k, j]] = aug[[max_row, j]];
aug[[max_row, j]] = temp;
}
}
for i in (k + 1)..n {
let factor = aug[[i, k]] / aug[[k, k]];
for j in k..=n {
let temp = aug[[k, j]];
aug[[i, j]] -= factor * temp;
}
}
}
let mut x = Array1::zeros(n);
for i in (0..n).rev() {
let mut sum = aug[[i, n]];
for j in (i + 1)..n {
sum -= aug[[i, j]] * x[j];
}
x[i] = sum / aug[[i, i]];
}
Ok(x)
}
pub fn transpose_matrix<T: Copy + Zero>(matrix: &ArrayView2<T>) -> Array2<T> {
let (rows, cols) = matrix.dim();
let mut result = Array2::<T>::zeros((cols, rows));
for i in 0..rows {
for j in 0..cols {
result[[j, i]] = matrix[[i, j]];
}
}
result
}
pub fn matrix_multiply<T>(a: &ArrayView2<T>, b: &ArrayView2<T>) -> InterpolateResult<Array2<T>>
where
T: Float + Zero + AddAssign + MulAssign + Copy,
{
if a.ncols() != b.nrows() {
return Err(InterpolateError::invalid_input(
"matrix dimensions do not match for multiplication".to_string(),
));
}
let (m, n) = (a.nrows(), b.ncols());
let mut result = Array2::zeros((m, n));
for i in 0..m {
for j in 0..n {
let mut sum = T::zero();
for k in 0..a.ncols() {
sum += a[[i, k]] * b[[k, j]];
}
result[[i, j]] = sum;
}
}
Ok(result)
}
pub fn matrix_vector_multiply<T>(
a: &ArrayView2<T>,
b: &ArrayView1<T>,
) -> InterpolateResult<Array1<T>>
where
T: Float + Zero + AddAssign + Copy,
{
if a.ncols() != b.len() {
return Err(InterpolateError::invalid_input(
"matrix and vector dimensions do not match for multiplication".to_string(),
));
}
let mut result = Array1::zeros(a.nrows());
for i in 0..a.nrows() {
let mut sum = T::zero();
for j in 0..a.ncols() {
sum += a[[i, j]] * b[j];
}
result[i] = sum;
}
Ok(result)
}
pub fn lu_decomposition<T>(
matrix: &ArrayView2<T>,
) -> InterpolateResult<(Array2<T>, Array2<T>, Vec<usize>)>
where
T: Float
+ FromPrimitive
+ Debug
+ Add<Output = T>
+ Sub<Output = T>
+ Mul<Output = T>
+ Div<Output = T>
+ AddAssign
+ SubAssign
+ MulAssign
+ DivAssign
+ Zero
+ Copy,
{
let n = matrix.nrows();
if n != matrix.ncols() {
return Err(InterpolateError::invalid_input(
"matrix must be square for LU decomposition".to_string(),
));
}
let mut l = Array2::zeros((n, n));
let mut u = matrix.to_owned();
let mut perm = Vec::new();
for i in 0..n {
l[[i, i]] = T::one();
perm.push(i);
}
for k in 0..n - 1 {
let mut max_row = k;
let mut max_val = u[[k, k]].abs();
for i in (k + 1)..n {
let val = u[[i, k]].abs();
if val > max_val {
max_val = val;
max_row = i;
}
}
let eps = T::from_f64(2.22e-16).expect("Operation failed"); let threshold = eps
* T::from_usize(n).expect("Operation failed")
* T::from_f64(1000.0).expect("Operation failed");
if max_val < threshold {
return Err(InterpolateError::invalid_input(
"matrix is singular or nearly singular".to_string(),
));
}
if max_row != k {
for j in 0..n {
let temp = u[[k, j]];
u[[k, j]] = u[[max_row, j]];
u[[max_row, j]] = temp;
}
perm.swap(k, max_row);
}
for i in (k + 1)..n {
let factor = u[[i, k]] / u[[k, k]];
l[[i, k]] = factor;
for j in k..n {
let kj_value = u[[k, j]];
u[[i, j]] -= factor * kj_value;
}
}
}
Ok((l, u, perm))
}
pub fn solve_with_lu<T>(
l: &ArrayView2<T>,
u: &ArrayView2<T>,
perm: &[usize],
b: &ArrayView1<T>,
) -> InterpolateResult<Array1<T>>
where
T: Float
+ FromPrimitive
+ Add<Output = T>
+ Sub<Output = T>
+ Mul<Output = T>
+ Div<Output = T>
+ AddAssign
+ SubAssign
+ Zero
+ Copy,
{
let n = b.len();
let mut pb = Array1::zeros(n);
for i in 0..n {
pb[i] = b[perm[i]];
}
let mut y = Array1::zeros(n);
for i in 0..n {
let mut sum = pb[i];
for j in 0..i {
sum -= l[[i, j]] * y[j];
}
y[i] = sum / l[[i, i]];
}
let mut x = Array1::zeros(n);
for i in (0..n).rev() {
let mut sum = y[i];
for j in (i + 1)..n {
sum -= u[[i, j]] * x[j];
}
x[i] = sum / u[[i, i]];
}
Ok(x)
}
pub fn solve_multiple_rhs<T>(
matrix: &ArrayView2<T>,
rhs_matrix: &ArrayView2<T>,
) -> InterpolateResult<Array2<T>>
where
T: Float
+ FromPrimitive
+ Debug
+ Add<Output = T>
+ Sub<Output = T>
+ Mul<Output = T>
+ Div<Output = T>
+ AddAssign
+ SubAssign
+ MulAssign
+ DivAssign
+ Zero
+ Copy,
{
if matrix.nrows() != rhs_matrix.nrows() {
return Err(InterpolateError::invalid_input(
"matrix and RHS dimensions must match".to_string(),
));
}
let (l, u, perm) = lu_decomposition(matrix)?;
let mut solutions = Array2::zeros((matrix.ncols(), rhs_matrix.ncols()));
for j in 0..rhs_matrix.ncols() {
let rhs_col = rhs_matrix.column(j);
let solution = solve_with_lu(&l.view(), &u.view(), &perm, &rhs_col)?;
for i in 0..solution.len() {
solutions[[i, j]] = solution[i];
}
}
Ok(solutions)
}
pub fn condition_number<T>(matrix: &ArrayView2<T>) -> InterpolateResult<T>
where
T: Float + FromPrimitive + Debug + Zero + Copy + AddAssign,
{
let n = matrix.nrows();
if n != matrix.ncols() {
return Err(InterpolateError::invalid_input(
"matrix must be square to compute condition number".to_string(),
));
}
let matrix_norm = frobenius_norm(matrix);
let mut min_diag = T::infinity();
for i in 0..n {
let val = matrix[[i, i]].abs();
if val < min_diag {
min_diag = val;
}
}
if min_diag < T::from_f64(1e-14).expect("Operation failed") {
Ok(T::infinity())
} else {
Ok(matrix_norm / min_diag)
}
}
fn frobenius_norm<T: Float + Zero + Copy + AddAssign>(matrix: &ArrayView2<T>) -> T {
let mut sum = T::zero();
for i in 0..matrix.nrows() {
for j in 0..matrix.ncols() {
let val = matrix[[i, j]];
sum += val * val;
}
}
sum.sqrt()
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::array;
#[test]
fn test_dense_solver() {
let a = array![[2.0, 1.0], [1.0, 3.0]];
let b = array![1.0, 2.0];
let x = solve_dense_fallback(&a.view(), &b.view()).expect("Operation failed");
let verification = matrix_vector_multiply(&a.view(), &x.view()).expect("Operation failed");
for i in 0..b.len() {
assert!((verification[i] - b[i]).abs() < 1e-10);
}
}
#[test]
fn test_matrix_multiply() {
let a = array![[1.0, 2.0], [3.0, 4.0]];
let b = array![[5.0, 6.0], [7.0, 8.0]];
let result = matrix_multiply(&a.view(), &b.view()).expect("Operation failed");
assert_eq!(result[[0, 0]], 19.0);
assert_eq!(result[[0, 1]], 22.0);
assert_eq!(result[[1, 0]], 43.0);
assert_eq!(result[[1, 1]], 50.0);
}
#[test]
fn test_lu_decomposition() {
let a = array![[2.0, 1.0, 1.0], [1.0, 3.0, 2.0], [1.0, 0.0, 0.0]];
let result = lu_decomposition(&a.view());
assert!(result.is_ok());
let (l, u, _perm) = result.expect("Operation failed");
for i in 0..l.nrows() {
assert_eq!(l[[i, i]], 1.0);
for j in (i + 1)..l.ncols() {
assert_eq!(l[[i, j]], 0.0);
}
}
for i in 0..u.nrows() {
for j in 0..i {
assert!(u[[i, j]].abs() < 1e-10);
}
}
}
#[test]
fn test_bandwidth_estimation() {
let mut a = Array2::zeros((4, 4));
a[[0, 0]] = 1.0;
a[[0, 1]] = 1.0;
a[[1, 0]] = 1.0;
a[[1, 1]] = 2.0;
a[[1, 2]] = 1.0;
a[[2, 1]] = 1.0;
a[[2, 2]] = 2.0;
a[[2, 3]] = 1.0;
a[[3, 2]] = 1.0;
a[[3, 3]] = 1.0;
let bandwidth = estimate_bandwidth(&a.view());
assert_eq!(bandwidth, 1); }
}