use crate::error::{IntegrateError, IntegrateResult};
use scirs2_core::ndarray::{Array1, ArrayView1, ArrayView2};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;
/// Selects which solver [`auto_solve_linear_system`] uses for `a · x = b`.
///
/// The enum is fieldless, so it also derives `Eq` and `Hash` and can be used
/// as a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LinearSolverType {
    /// Dense Gaussian elimination with partial pivoting (`solve_linear_system`).
    Direct,
    /// Restarted GMRES with default parameters (`solve_gmres`).
    Iterative,
    /// Direct solver for matrices with fewer than 100 rows, GMRES otherwise.
    Auto,
}
/// Solves the dense linear system `a · x = b` by Gaussian elimination with
/// partial (row) pivoting, followed by back substitution.
///
/// The inputs are copied; neither `a` nor `b` is modified.
///
/// # Errors
///
/// Returns `IntegrateError::ValueError` when
/// * `a` is not square,
/// * `b.len()` differs from the number of rows of `a`, or
/// * a pivot is numerically zero, i.e. its magnitude does not exceed
///   `1e-14 · max|a_ij|`, or it is NaN.
#[allow(dead_code)]
pub fn solve_linear_system<F>(a: &ArrayView2<F>, b: &ArrayView1<F>) -> IntegrateResult<Array1<F>>
where
    F: Float
        + FromPrimitive
        + Debug
        + std::ops::AddAssign
        + std::ops::SubAssign
        + std::ops::MulAssign,
{
    let n = a.shape()[0];
    if a.shape()[0] != a.shape()[1] {
        return Err(IntegrateError::ValueError(format!(
            "Matrix must be square to solve linear system, got shape {:?}",
            a.shape()
        )));
    }
    if b.len() != n {
        return Err(IntegrateError::ValueError(
            format!("Right-hand side vector dimensions incompatible with matrix: matrix has {} rows but vector has {} elements",
            n, b.len())
        ));
    }

    // Judge singularity relative to the matrix scale. A uniformly tiny but
    // well-conditioned matrix is then not rejected, and a huge singular one
    // is not accepted because of rounding residue. For a zero matrix the
    // threshold is 0, and the zero pivot is still rejected.
    let rel_tol = F::from_f64(1e-14).unwrap_or_else(F::epsilon);
    let scale = a.iter().fold(F::zero(), |acc, &v| acc.max(v.abs()));
    let threshold = rel_tol * scale;

    let mut a_copy = a.to_owned();
    let mut b_copy = b.to_owned();

    for k in 0..n {
        // Partial pivoting: pick the row i >= k with the largest |a[i, k]|.
        let mut pivot_idx = k;
        let mut max_val = a_copy[[k, k]].abs();
        for i in (k + 1)..n {
            let val = a_copy[[i, k]].abs();
            if val > max_val {
                max_val = val;
                pivot_idx = i;
            }
        }
        if max_val.is_nan() || max_val <= threshold {
            return Err(IntegrateError::ValueError(
                "Matrix is singular or nearly singular".to_string(),
            ));
        }
        if pivot_idx != k {
            // Columns < k of both rows are already zero, so only k..n need swapping.
            for j in k..n {
                a_copy.swap([k, j], [pivot_idx, j]);
            }
            b_copy.swap(k, pivot_idx);
        }
        // Eliminate the entries below the pivot.
        for i in (k + 1)..n {
            let factor = a_copy[[i, k]] / a_copy[[k, k]];
            b_copy[i] = b_copy[i] - factor * b_copy[k];
            a_copy[[i, k]] = F::zero();
            for j in (k + 1)..n {
                a_copy[[i, j]] = a_copy[[i, j]] - factor * a_copy[[k, j]];
            }
        }
    }

    // Back substitution on the resulting upper-triangular system.
    let mut x = Array1::<F>::zeros(n);
    for i in (0..n).rev() {
        let mut sum = b_copy[i];
        for j in (i + 1)..n {
            sum -= a_copy[[i, j]] * x[j];
        }
        x[i] = sum / a_copy[[i, i]];
    }
    Ok(x)
}
/// Returns the Euclidean (L2) norm of `v`, i.e. `sqrt(Σ v_i²)`.
///
/// The squares are summed directly without rescaling, so entries of very
/// large or very small magnitude can overflow or underflow the intermediate sum.
#[allow(dead_code)]
pub fn vector_norm<F>(v: &ArrayView1<F>) -> F
where
    F: Float,
{
    let sum_of_squares = v.iter().fold(F::zero(), |acc, &elem| acc + elem * elem);
    sum_of_squares.sqrt()
}
/// Returns the Frobenius norm of `m`, i.e. `sqrt(Σ m_ij²)` over every entry.
///
/// As with a plain sum of squares, extreme entry magnitudes can overflow or
/// underflow the intermediate sum.
#[allow(dead_code)]
pub fn matrix_norm<F>(m: &ArrayView2<F>) -> F
where
    F: Float,
{
    m.iter()
        .fold(F::zero(), |total, &entry| total + entry * entry)
        .sqrt()
}
/// Solves `a · x = b` with the solver selected by `solver_type`.
///
/// * `Direct` uses `solve_linear_system` (Gaussian elimination).
/// * `Iterative` uses `solve_gmres` with its default iteration budget,
///   tolerance and restart length.
/// * `Auto` uses the direct solver when `a` has fewer than 100 rows and
///   GMRES otherwise.
///
/// # Errors
///
/// Propagates any error returned by the selected solver.
#[allow(dead_code)]
pub fn auto_solve_linear_system<F>(
    a: &ArrayView2<F>,
    b: &ArrayView1<F>,
    solver_type: LinearSolverType,
) -> IntegrateResult<Array1<F>>
where
    F: Float
        + FromPrimitive
        + Debug
        + std::ops::AddAssign
        + std::ops::SubAssign
        + std::ops::MulAssign
        + std::default::Default
        + std::iter::Sum
        + scirs2_core::ndarray::ScalarOperand
        + std::ops::DivAssign,
{
    // Row count from which `Auto` switches to the iterative solver.
    const ITERATIVE_FROM_ROWS: usize = 100;

    let use_direct = match solver_type {
        LinearSolverType::Direct => true,
        LinearSolverType::Iterative => false,
        LinearSolverType::Auto => a.shape()[0] < ITERATIVE_FROM_ROWS,
    };

    if use_direct {
        solve_linear_system(a, b)
    } else {
        solve_gmres(a, b, None, None, None)
    }
}
/// Solves `a · x = b`. This is currently a thin alias for `solve_linear_system`.
///
/// Gaussian elimination with partial pivoting implicitly performs a
/// row-pivoted LU factorization. The factors are not stored or returned.
///
/// # Errors
///
/// Same as `solve_linear_system`: a non-square matrix, mismatched dimensions,
/// or a (nearly) singular matrix.
#[allow(dead_code)]
pub fn solve_lu<F>(a: &ArrayView2<F>, b: &ArrayView1<F>) -> IntegrateResult<Array1<F>>
where
    F: Float
        + FromPrimitive
        + Debug
        + std::ops::AddAssign
        + std::ops::SubAssign
        + std::ops::MulAssign,
{
    let solution = solve_linear_system(a, b)?;
    Ok(solution)
}
/// Solves `a · x = b` with restarted GMRES(m), starting from `x = 0`.
///
/// # Parameters
///
/// * `max_iter` – total budget of inner (Arnoldi) iterations summed over all
///   restart cycles. Defaults to `min(n, 50)`.
/// * `tol` – convergence tolerance. Defaults to `1e-10`. The solve returns
///   immediately when `‖b‖ < tol`. A cycle is skipped when the current
///   residual norm is `< tol`. Iteration stops once the residual norm falls
///   below `tol · ‖b‖`.
/// * `restart` – Krylov subspace dimension per cycle. Defaults to
///   `min(n, 20)`. A value of 0 is treated as 1, because 0 would make no progress.
///
/// # Returns
///
/// The final iterate. NOTE: when the iteration budget runs out before the
/// tolerance is met, the last iterate is still returned as `Ok`. Callers that
/// need a convergence guarantee should check the residual themselves.
///
/// # Errors
///
/// `IntegrateError::ValueError` if `a` is not square or `b.len()` does not
/// match its size.
#[allow(dead_code)]
pub fn solve_gmres<F>(
    a: &ArrayView2<F>,
    b: &ArrayView1<F>,
    max_iter: Option<usize>,
    tol: Option<F>,
    restart: Option<usize>,
) -> IntegrateResult<Array1<F>>
where
    F: Float
        + FromPrimitive
        + Debug
        + std::ops::AddAssign
        + std::ops::SubAssign
        + std::ops::MulAssign
        + Default
        + std::iter::Sum
        + scirs2_core::ndarray::ScalarOperand
        + std::ops::DivAssign,
{
    let n = a.nrows();
    if n != a.ncols() {
        return Err(IntegrateError::ValueError(
            "Matrix must be square".to_string(),
        ));
    }
    if n != b.len() {
        return Err(IntegrateError::ValueError(
            "Matrix and vector dimensions must match".to_string(),
        ));
    }
    let max_iter = max_iter.unwrap_or(std::cmp::min(n, 50));
    let tol = tol.unwrap_or_else(|| F::from_f64(1e-10).expect("Operation failed"));
    // With restart == 0 each cycle would do nothing and `outer_iter` would
    // never advance, giving an infinite loop.
    let restart = restart.unwrap_or(std::cmp::min(n, 20)).max(1);
    // Norm of the new Arnoldi vector below which the Krylov space is
    // treated as invariant ("lucky breakdown").
    let breakdown_tol = F::from_f64(1e-14).expect("Operation failed");

    let l2_norm = |u: &Array1<F>| u.iter().map(|&e| e * e).sum::<F>().sqrt();
    // r = b - A·x for a given iterate.
    let residual = |x_cur: &Array1<F>| {
        let ax = a.dot(x_cur);
        let mut res = b.to_owned();
        for (ri, &axi) in res.iter_mut().zip(ax.iter()) {
            *ri -= axi;
        }
        res
    };

    let mut x = Array1::<F>::zeros(n);
    // With x = 0 the initial residual is simply b.
    let mut r = b.to_owned();
    let initial_norm = l2_norm(&r);
    if initial_norm < tol {
        return Ok(x);
    }

    let mut outer_iter = 0;
    while outer_iter < max_iter {
        let m = std::cmp::min(restart, max_iter - outer_iter);
        let beta = l2_norm(&r);
        if beta < tol {
            break;
        }

        // The cycle maintains four pieces of state:
        //   v      – orthonormal Krylov basis,
        //   h      – (m+1)×m Hessenberg matrix, reduced in place to upper triangular,
        //   cs, sn – Givens rotations applied so far,
        //   g      – the rotated right-hand side β·e₁.
        let mut v: Vec<Array1<F>> = Vec::with_capacity(m + 1);
        v.push(&r / beta);
        let mut h = vec![vec![F::zero(); m]; m + 1];
        let mut cs = vec![F::zero(); m];
        let mut sn = vec![F::zero(); m];
        let mut g = vec![F::zero(); m + 1];
        g[0] = beta;

        // Number of basis columns that take part in the least-squares solution.
        let mut k = 0;
        for j in 0..m {
            // Arnoldi step using modified Gram–Schmidt.
            let mut w = a.dot(&v[j]);
            for i in 0..=j {
                let hij = v[i].dot(&w);
                h[i][j] = hij;
                for (wk, &vk) in w.iter_mut().zip(v[i].iter()) {
                    *wk -= hij * vk;
                }
            }
            let w_norm = l2_norm(&w);
            h[j + 1][j] = w_norm;

            // Apply the rotations stored from earlier columns to column j.
            for i in 0..j {
                let upper = cs[i] * h[i][j] + sn[i] * h[i + 1][j];
                h[i + 1][j] = -sn[i] * h[i][j] + cs[i] * h[i + 1][j];
                h[i][j] = upper;
            }
            // Build the new rotation that annihilates h[j + 1][j].
            let denom = h[j][j].hypot(h[j + 1][j]);
            if denom == F::zero() {
                // Column j vanished, so the reduced system is singular.
                // Keep only the first j columns.
                break;
            }
            cs[j] = h[j][j] / denom;
            sn[j] = h[j + 1][j] / denom;
            h[j][j] = denom;
            h[j + 1][j] = F::zero();
            g[j + 1] = -sn[j] * g[j];
            g[j] = cs[j] * g[j];
            // Column j is now triangular and is part of the solution, even on breakdown.
            k = j + 1;

            // |g[j + 1]| is the residual norm of the current least-squares solution.
            if w_norm < breakdown_tol || g[j + 1].abs() < tol * initial_norm {
                break;
            }
            if j + 1 < m {
                v.push(&w / w_norm);
            }
        }

        // Solve the k×k upper-triangular system R·y = g[..k].
        let mut y = vec![F::zero(); k];
        for i in (0..k).rev() {
            let mut sum = g[i];
            for l in (i + 1)..k {
                sum -= h[i][l] * y[l];
            }
            y[i] = sum / h[i][i];
        }
        // x += V_k · y
        for (yi, vi) in y.iter().zip(v.iter()) {
            for (xe, &ve) in x.iter_mut().zip(vi.iter()) {
                *xe += *yi * ve;
            }
        }

        r = residual(&x);
        if l2_norm(&r) < tol * initial_norm {
            break;
        }
        outer_iter += m;
    }
    Ok(x)
}