use crate::csr_array::CsrArray;
use crate::error::{SparseError, SparseResult};
use crate::linalg::interface::AsLinearOperator;
use crate::linalg::{bicgstab, cg, gmres, BiCGSTABOptions, CGOptions, GMRESOptions};
use crate::sparray::SparseArray;
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::numeric::{Float, NumAssign, SparseElement};
use std::fmt::Debug;
pub fn solve_saddle_point<T>(
a: &CsrArray<T>,
b: &CsrArray<T>,
f: &Array1<T>,
g: &Array1<T>,
tol: T,
max_iter: usize,
) -> SparseResult<(Array1<T>, Array1<T>)>
where
T: Float + SparseElement + Debug + Copy + std::iter::Sum + NumAssign + 'static,
{
let (n, n2) = a.shape();
let (m, n3) = b.shape();
if n != n2 {
return Err(SparseError::ValueError(
"Matrix A must be square".to_string(),
));
}
if n != n3 {
return Err(SparseError::ValueError(
"Matrix B columns must match A dimensions".to_string(),
));
}
if f.len() != n {
return Err(SparseError::DimensionMismatch {
expected: n,
found: f.len(),
});
}
if g.len() != m {
return Err(SparseError::DimensionMismatch {
expected: m,
found: g.len(),
});
}
let total_size = n + m;
let mut rows = Vec::new();
let mut cols = Vec::new();
let mut data = Vec::new();
let (a_rows, a_cols, a_data) = a.find();
for i in 0..a_rows.len() {
rows.push(a_rows[i]);
cols.push(a_cols[i]);
data.push(a_data[i]);
}
let (b_rows, b_cols, b_data) = b.find();
for i in 0..b_rows.len() {
rows.push(b_cols[i]);
cols.push(b_rows[i] + n);
data.push(b_data[i]);
}
for i in 0..b_rows.len() {
rows.push(b_rows[i] + n);
cols.push(b_cols[i]);
data.push(b_data[i]);
}
let system_matrix =
CsrArray::from_triplets(&rows, &cols, &data, (total_size, total_size), false)?;
let mut rhs = Array1::zeros(total_size);
for i in 0..n {
rhs[i] = f[i];
}
for i in 0..m {
rhs[n + i] = g[i];
}
let options = GMRESOptions {
rtol: tol,
max_iter,
restart: 30.min(total_size / 2),
..Default::default()
};
let rhs_vec = rhs.to_vec();
let result = gmres(&*system_matrix.as_linear_operator(), &rhs_vec, options)?;
let x = Array1::from_vec(result.x[0..n].to_vec());
let y = Array1::from_vec(result.x[n..total_size].to_vec());
Ok((x, y))
}
pub fn solve_block_2x2<T>(
a11: &CsrArray<T>,
a12: &CsrArray<T>,
a21: &CsrArray<T>,
a22: &CsrArray<T>,
b1: &Array1<T>,
b2: &Array1<T>,
tol: T,
max_iter: usize,
) -> SparseResult<(Array1<T>, Array1<T>)>
where
T: Float + SparseElement + Debug + Copy + std::iter::Sum + NumAssign + 'static,
{
let (n1, n1_2) = a11.shape();
let (n2, n2_2) = a22.shape();
if n1 != n1_2 || n2 != n2_2 {
return Err(SparseError::ValueError(
"Diagonal blocks must be square".to_string(),
));
}
if a12.shape() != (n1, n2) || a21.shape() != (n2, n1) {
return Err(SparseError::ValueError(
"Off-diagonal blocks have incompatible dimensions".to_string(),
));
}
if b1.len() != n1 {
return Err(SparseError::DimensionMismatch {
expected: n1,
found: b1.len(),
});
}
if b2.len() != n2 {
return Err(SparseError::DimensionMismatch {
expected: n2,
found: b2.len(),
});
}
let mut x1 = Array1::zeros(n1);
let mut x2 = Array1::zeros(n2);
for _ in 0..max_iter {
let x1_old = x1.clone();
let x2_old = x2.clone();
let rhs1 = b1 - &a12.dot_vector(&x2.view())?;
let cg_options = CGOptions {
rtol: tol,
max_iter: 50,
..Default::default()
};
let rhs1_vec = rhs1.to_vec();
let result1 = cg(&*a11.as_linear_operator(), &rhs1_vec, cg_options)?;
x1 = Array1::from_vec(result1.x);
let rhs2 = b2 - &a21.dot_vector(&x1.view())?;
let cg_options2 = CGOptions {
rtol: tol,
max_iter: 50,
..Default::default()
};
let rhs2_vec = rhs2.to_vec();
let result2 = cg(&*a22.as_linear_operator(), &rhs2_vec, cg_options2)?;
x2 = Array1::from_vec(result2.x);
let diff1: T = x1
.iter()
.zip(x1_old.iter())
.map(|(&a, &b)| (a - b) * (a - b))
.sum();
let diff2: T = x2
.iter()
.zip(x2_old.iter())
.map(|(&a, &b)| (a - b) * (a - b))
.sum();
let total_diff = (diff1 + diff2).sqrt();
if total_diff < tol {
break;
}
}
Ok((x1, x2))
}
pub fn solve_kronecker_system<T>(
a: &CsrArray<T>,
b: &CsrArray<T>,
c: &Array2<T>,
tol: T,
max_iter: usize,
) -> SparseResult<Array2<T>>
where
T: Float + SparseElement + Debug + Copy + std::iter::Sum + 'static,
{
let (m, m2) = a.shape();
let (n, n2) = b.shape();
if m != m2 || n != n2 {
return Err(SparseError::ValueError(
"Matrices A and B must be square".to_string(),
));
}
if c.shape() != [n, m] {
return Err(SparseError::ShapeMismatch {
expected: (n, m),
found: (c.shape()[0], c.shape()[1]),
});
}
let mut x = c.clone();
for iter in 0..max_iter {
let x_old = x.clone();
let mut residual = c.clone();
let mut bx = Array2::zeros((n, m));
for i in 0..n {
for j in 0..m {
let mut sum = T::sparse_zero();
for k in 0..n {
sum = sum + b.get(i, k) * x[[k, j]];
}
bx[[i, j]] = sum;
}
}
let mut bxat = Array2::zeros((n, m));
for i in 0..n {
for j in 0..m {
let mut sum = T::sparse_zero();
for k in 0..m {
sum = sum + bx[[i, k]] * a.get(j, k); }
bxat[[i, j]] = sum;
}
}
residual = &residual - &bxat;
let res_norm: T = residual.iter().map(|&r| r * r).sum();
let res_norm = res_norm.sqrt();
if res_norm < tol {
break;
}
let alpha = T::from(0.1)
.ok_or_else(|| SparseError::ComputationError("Cannot convert 0.1".to_string()))?;
x = &x + &residual.mapv(|r| alpha * r);
if iter > 10 {
let diff: T = x
.iter()
.zip(x_old.iter())
.map(|(&a, &b)| (a - b) * (a - b))
.sum();
let diff = diff.sqrt();
if diff
< tol
* T::from(0.01).ok_or_else(|| {
SparseError::ComputationError("Cannot convert 0.01".to_string())
})?
{
break;
}
}
}
Ok(x)
}
pub fn solve_arrow_matrix<T>(
diag: &Array1<T>,
arrow_row: &Array1<T>,
arrow_col: &Array1<T>,
rhs: &Array1<T>,
) -> SparseResult<Array1<T>>
where
T: Float + SparseElement + Debug + Copy + 'static,
{
let n = diag.len();
if arrow_row.len() != n - 1 {
return Err(SparseError::DimensionMismatch {
expected: n - 1,
found: arrow_row.len(),
});
}
if arrow_col.len() != n - 1 {
return Err(SparseError::DimensionMismatch {
expected: n - 1,
found: arrow_col.len(),
});
}
if rhs.len() != n {
return Err(SparseError::DimensionMismatch {
expected: n,
found: rhs.len(),
});
}
let mut solution = Array1::zeros(n);
for i in 0..n {
if diag[i].abs()
< T::from(1e-10)
.ok_or_else(|| SparseError::ComputationError("Cannot convert 1e-10".to_string()))?
{
return Err(SparseError::ComputationError(
"Singular diagonal element in arrow matrix".to_string(),
));
}
solution[i] = rhs[i] / diag[i];
}
let mut correction = T::sparse_zero();
for i in 1..n {
correction = correction + arrow_col[i - 1] * solution[i] / diag[i];
}
let denom = diag[0] + correction;
if denom.abs()
< T::from(1e-10)
.ok_or_else(|| SparseError::ComputationError("Cannot convert 1e-10".to_string()))?
{
return Err(SparseError::ComputationError(
"Singular system in arrow matrix".to_string(),
));
}
solution[0] = (rhs[0] - correction) / denom;
for i in 1..n {
solution[i] = solution[i] - arrow_row[i - 1] * solution[0] / diag[i];
}
Ok(solution)
}
pub fn solve_banded_system<T>(
matrix: &CsrArray<T>,
rhs: &Array1<T>,
bandwidth: usize,
tol: T,
max_iter: usize,
) -> SparseResult<Array1<T>>
where
T: Float + SparseElement + Debug + Copy + std::iter::Sum + NumAssign + 'static,
{
let (n, n2) = matrix.shape();
if n != n2 {
return Err(SparseError::ValueError("Matrix must be square".to_string()));
}
if rhs.len() != n {
return Err(SparseError::DimensionMismatch {
expected: n,
found: rhs.len(),
});
}
let options = BiCGSTABOptions {
rtol: tol,
max_iter,
..Default::default()
};
let rhs_vec = rhs.to_vec();
let result = bicgstab(&*matrix.as_linear_operator(), &rhs_vec, options)?;
Ok(Array1::from_vec(result.x))
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
#[test]
fn test_saddle_point_solver() {
let rows_a = vec![0, 1];
let cols_a = vec![0, 1];
let data_a = vec![2.0, 3.0];
let a = CsrArray::from_triplets(&rows_a, &cols_a, &data_a, (2, 2), false).expect("Failed");
let rows_b = vec![0, 0];
let cols_b = vec![0, 1];
let data_b = vec![1.0, 1.0];
let b = CsrArray::from_triplets(&rows_b, &cols_b, &data_b, (1, 2), false).expect("Failed");
let f = Array1::from_vec(vec![1.0, 2.0]);
let g = Array1::from_vec(vec![3.0]);
let result = solve_saddle_point(&a, &b, &f, &g, 1e-6, 100);
assert!(result.is_ok());
let (x, y) = result.expect("Failed");
assert_eq!(x.len(), 2);
assert_eq!(y.len(), 1);
}
#[test]
fn test_block_2x2_solver() {
let rows_a11 = vec![0, 1];
let cols_a11 = vec![0, 1];
let data_a11 = vec![4.0, 5.0];
let a11 = CsrArray::from_triplets(&rows_a11, &cols_a11, &data_a11, (2, 2), false)
.expect("Failed");
let rows_a12 = vec![0];
let cols_a12 = vec![0];
let data_a12 = vec![1.0];
let a12 = CsrArray::from_triplets(&rows_a12, &cols_a12, &data_a12, (2, 1), false)
.expect("Failed");
let rows_a21 = vec![0];
let cols_a21 = vec![0];
let data_a21 = vec![1.0];
let a21 = CsrArray::from_triplets(&rows_a21, &cols_a21, &data_a21, (1, 2), false)
.expect("Failed");
let rows_a22 = vec![0];
let cols_a22 = vec![0];
let data_a22 = vec![3.0];
let a22 = CsrArray::from_triplets(&rows_a22, &cols_a22, &data_a22, (1, 1), false)
.expect("Failed");
let b1 = Array1::from_vec(vec![1.0, 2.0]);
let b2 = Array1::from_vec(vec![3.0]);
let result = solve_block_2x2(&a11, &a12, &a21, &a22, &b1, &b2, 1e-6, 100);
assert!(result.is_ok());
let (x1, x2) = result.expect("Failed");
assert_eq!(x1.len(), 2);
assert_eq!(x2.len(), 1);
}
#[test]
fn test_arrow_matrix_solver() {
let diag = Array1::from_vec(vec![3.0, 2.0, 4.0, 5.0]);
let arrow_row = Array1::from_vec(vec![1.0, 0.5, 0.3]);
let arrow_col = Array1::from_vec(vec![0.8, 0.6, 0.4]);
let rhs = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
let result = solve_arrow_matrix(&diag, &arrow_row, &arrow_col, &rhs);
assert!(result.is_ok());
let solution = result.expect("Failed");
assert_eq!(solution.len(), 4);
for &val in solution.iter() {
assert!(val.is_finite());
}
}
#[test]
fn test_banded_system_solver() {
let rows = vec![0, 0, 1, 1, 1, 2, 2, 2, 3, 3];
let cols = vec![0, 1, 0, 1, 2, 1, 2, 3, 2, 3];
let data = vec![2.0, -1.0, -1.0, 2.0, -1.0, -1.0, 2.0, -1.0, -1.0, 2.0];
let matrix = CsrArray::from_triplets(&rows, &cols, &data, (4, 4), false).expect("Failed");
let rhs = Array1::from_vec(vec![1.0, 0.0, 0.0, 1.0]);
let result = solve_banded_system(&matrix, &rhs, 1, 1e-6, 100);
assert!(result.is_ok());
let solution = result.expect("Failed");
assert_eq!(solution.len(), 4);
let ax = matrix.dot_vector(&solution.view()).expect("Failed");
for i in 0..4 {
assert_relative_eq!(ax[i], rhs[i], epsilon = 1e-4);
}
}
}