use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use num_traits::{Float, Zero};
use super::core::{compute_norm_vec, dot_vec, matvec, SolverResult};
/// Solves the square linear system `A x = b` with the MINRES iteration.
///
/// Each step extends a Lanczos basis with the three-term recurrence
/// `A v_k - alpha_k v_k - beta_k v_{k-1}`, folds the new tridiagonal column
/// into a QR factorisation with Givens rotations, and updates the iterate
/// along a search direction built from the two previous directions.
/// MINRES is intended for symmetric (possibly indefinite) matrices; the
/// symmetry of `a` is not checked here.
///
/// # Arguments
///
/// * `a` - Coefficient matrix; must be 2-dimensional and square (`n x n`).
/// * `b` - Right-hand side holding exactly `n` elements.
/// * `x0` - Optional initial guess with `n` elements; zeros when `None`.
/// * `tol` - Relative tolerance on `||r|| / ||b||`; defaults to `1e-6`.
/// * `max_iter` - Iteration cap; defaults to `2 * n`.
///
/// # Returns
///
/// A [`SolverResult`] with the final iterate, the number of iterations
/// started, a residual norm and a convergence flag. When the iteration
/// stops on its own convergence test, `residual_norm` is the recurrence
/// estimate `|phi_bar|`. When it stops because of the iteration cap or a
/// numerical breakdown, `residual_norm` is recomputed from `b - A x` and
/// `converged` is `false`.
///
/// # Errors
///
/// * [`NumRs2Error::DimensionMismatch`] if `a` is not a square 2-D array.
/// * [`NumRs2Error::ShapeMismatch`] if `b` or `x0` does not hold `n` elements.
/// * Any error returned by `matvec` is propagated unchanged.
pub fn minres<T>(
    a: &Array<T>,
    b: &Array<T>,
    x0: Option<&Array<T>>,
    tol: Option<T>,
    max_iter: Option<usize>,
) -> Result<SolverResult<T>>
where
    T: Float + Clone + Zero,
{
    let shape = a.shape();
    if shape.len() != 2 || shape[0] != shape[1] {
        return Err(NumRs2Error::DimensionMismatch(
            "Matrix must be square".to_string(),
        ));
    }
    let n = shape[0];
    if b.size() != n {
        return Err(NumRs2Error::ShapeMismatch {
            expected: vec![n],
            actual: b.shape(),
        });
    }

    let tol = tol.unwrap_or_else(|| T::from(1e-6).unwrap_or(T::epsilon()));
    let max_iter = max_iter.unwrap_or(n * 2);
    // Absolute threshold below which a norm or pivot is treated as zero.
    let tiny = T::from(1e-14).unwrap_or(T::epsilon());

    let b_vec = b.to_vec();
    let b_norm = compute_norm_vec(&b_vec);
    if b_norm < tiny {
        // A (numerically) zero right-hand side is answered with the zero vector.
        return Ok(SolverResult {
            solution: Array::zeros(&[n]),
            iterations: 0,
            residual_norm: T::zero(),
            converged: true,
        });
    }

    let mut x_vec = match x0 {
        Some(x0_arr) => {
            // Reject a guess of the wrong length up front instead of relying
            // on whatever `matvec` does with mismatched operands.
            if x0_arr.size() != n {
                return Err(NumRs2Error::ShapeMismatch {
                    expected: vec![n],
                    actual: x0_arr.shape(),
                });
            }
            x0_arr.to_vec()
        }
        None => vec![T::zero(); n],
    };

    // Initial residual r0 = b - A x0.
    let x_arr = Array::from_vec(x_vec.clone());
    let ax_vec = matvec(a, &x_arr)?.to_vec();
    let r_vec: Vec<T> = b_vec
        .iter()
        .zip(ax_vec.iter())
        .map(|(&bi, &axi)| bi - axi)
        .collect();
    let beta1 = compute_norm_vec(&r_vec);
    if beta1 < tiny {
        return Ok(SolverResult {
            solution: Array::from_vec(x_vec),
            iterations: 0,
            residual_norm: T::zero(),
            converged: true,
        });
    }

    // Lanczos vectors v_{k-1}, v_k and search directions d_{k-1}, d_{k-2}.
    let mut v_prev = vec![T::zero(); n];
    let mut v: Vec<T> = r_vec.iter().map(|&ri| ri / beta1).collect();
    let mut d_prev = vec![T::zero(); n];
    let mut d_prev2 = vec![T::zero(); n];

    // Givens rotations from the two previous steps (identity to start with).
    let mut c_prev = T::one();
    let mut s_prev = T::zero();
    let mut c_prev2 = T::one();
    let mut s_prev2 = T::zero();

    // Running residual-norm estimate (up to sign) and last off-diagonal entry.
    let mut phi_bar = beta1;
    let mut beta_k = T::zero();

    let mut iterations = 0;
    for k in 0..max_iter {
        iterations = k + 1;

        // Lanczos step: A v_k = beta_k v_{k-1} + alpha_k v_k + beta_{k+1} v_{k+1}.
        let av_vec = matvec(a, &Array::from_vec(v.clone()))?.to_vec();
        let alpha_k = dot_vec(&v, &av_vec);
        let v_new: Vec<T> = av_vec
            .iter()
            .zip(v.iter())
            .zip(v_prev.iter())
            .map(|((&avi, &vi), &vpi)| avi - alpha_k * vi - beta_k * vpi)
            .collect();
        let beta_next = compute_norm_vec(&v_new);

        // Apply the two previous rotations to the new tridiagonal column.
        let epsilon_k = s_prev2 * beta_k;
        let beta_rotated = c_prev2 * beta_k;
        let delta_k = c_prev * beta_rotated + s_prev * alpha_k;
        let gamma_tilde = -s_prev * beta_rotated + c_prev * alpha_k;
        let gamma_k = (gamma_tilde * gamma_tilde + beta_next * beta_next).sqrt();

        // Breakdown: the pivot vanished (or is NaN), so no rotation or search
        // direction can be formed and `x` cannot be improved. Zeroing
        // `phi_bar` here would falsely report an exact solve, so stop and let
        // the true residual be computed below.
        if gamma_k.is_nan() || gamma_k <= tiny {
            break;
        }

        let c_k = gamma_tilde / gamma_k;
        let s_k = beta_next / gamma_k;

        let d_new: Vec<T> = v
            .iter()
            .zip(d_prev.iter())
            .zip(d_prev2.iter())
            .map(|((&vi, &d1), &d2)| (vi - delta_k * d1 - epsilon_k * d2) / gamma_k)
            .collect();

        let tau_k = c_k * phi_bar;
        for (xi, &di) in x_vec.iter_mut().zip(d_new.iter()) {
            *xi = *xi + tau_k * di;
        }

        phi_bar = -s_k * phi_bar;
        let residual_norm = phi_bar.abs();
        if residual_norm / b_norm < tol {
            return Ok(SolverResult {
                solution: Array::from_vec(x_vec),
                iterations,
                residual_norm,
                converged: true,
            });
        }

        // The Lanczos process cannot be continued once beta_{k+1} vanishes.
        if beta_next < tiny {
            return Ok(SolverResult {
                solution: Array::from_vec(x_vec),
                iterations,
                residual_norm,
                converged: residual_norm / b_norm < tol,
            });
        }

        // Shift the recurrences for the next step.
        v_prev = v;
        v = v_new.into_iter().map(|vi| vi / beta_next).collect();
        d_prev2 = d_prev;
        d_prev = d_new;
        c_prev2 = c_prev;
        s_prev2 = s_prev;
        c_prev = c_k;
        s_prev = s_k;
        beta_k = beta_next;
    }

    // Iteration cap reached or breakdown: report the true residual ||b - A x||.
    let x_arr = Array::from_vec(x_vec);
    let ax_vec = matvec(a, &x_arr)?.to_vec();
    let final_residual: T = b_vec
        .iter()
        .zip(ax_vec.iter())
        .map(|(&bi, &axi)| {
            let diff = bi - axi;
            diff * diff
        })
        .fold(T::zero(), |acc, sq| acc + sq)
        .sqrt();

    Ok(SolverResult {
        solution: x_arr,
        iterations,
        residual_norm: final_residual,
        converged: false,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Builds an `n x n` matrix from a flat list of `n * n` entries.
    fn square_matrix(entries: Vec<f64>, n: usize) -> Array<f64> {
        Array::from_vec(entries).reshape(&[n, n])
    }

    /// Asserts that `a * x` matches `b` element by element within `eps`.
    fn assert_solves(a: &Array<f64>, b: &Array<f64>, x: &Array<f64>, eps: f64) {
        let ax = matvec(a, x).expect("matvec should work");
        for i in 0..b.size() {
            assert_relative_eq!(
                ax.get(&[i]).expect("valid"),
                b.get(&[i]).expect("valid"),
                epsilon = eps
            );
        }
    }

    #[test]
    fn test_minres_symmetric_indefinite() {
        let matrix = square_matrix(vec![2.0, 1.0, 1.0, -1.0], 2);
        let rhs = Array::from_vec(vec![1.0, 0.0]);

        let outcome =
            minres(&matrix, &rhs, None, Some(1e-6), Some(100)).expect("Should solve");

        assert!(
            outcome.converged,
            "MINRES should converge for symmetric indefinite system"
        );
        assert_solves(&matrix, &rhs, &outcome.solution, 1e-5);
    }

    #[test]
    fn test_minres_spd_matrix() {
        let matrix = square_matrix(vec![4.0, 1.0, 1.0, 3.0], 2);
        let rhs = Array::from_vec(vec![1.0, 2.0]);

        let outcome =
            minres(&matrix, &rhs, None, Some(1e-6), Some(100)).expect("Should solve");

        assert!(outcome.converged, "MINRES should work for SPD matrices");
        assert_solves(&matrix, &rhs, &outcome.solution, 1e-5);
    }

    #[test]
    fn test_minres_saddle_point() {
        let matrix = square_matrix(
            vec![3.0, 1.0, 0.0, 1.0, 2.0, 1.0, 0.0, 1.0, -1.0],
            3,
        );
        let rhs = Array::from_vec(vec![1.0, 2.0, 1.0]);

        let outcome =
            minres(&matrix, &rhs, None, Some(1e-6), Some(150)).expect("Should solve");

        assert!(
            outcome.converged,
            "MINRES should handle saddle point problems"
        );
        assert_solves(&matrix, &rhs, &outcome.solution, 1e-5);
    }

    #[test]
    fn test_minres_identity_matrix() {
        let matrix = square_matrix(vec![1.0, 0.0, 0.0, 1.0], 2);
        let rhs = Array::from_vec(vec![3.0, 4.0]);

        let outcome =
            minres(&matrix, &rhs, None, Some(1e-10), Some(100)).expect("Should solve");

        assert!(outcome.converged);
        assert!(
            outcome.iterations <= 2,
            "Identity should converge in <= 2 iterations"
        );
        // With A = I the solution must reproduce the right-hand side itself.
        for i in 0..2 {
            assert_relative_eq!(
                outcome.solution.get(&[i]).expect("valid"),
                rhs.get(&[i]).expect("valid"),
                epsilon = 1e-9
            );
        }
    }

    #[test]
    fn test_minres_larger_indefinite() {
        let matrix = square_matrix(
            vec![
                4.0, 1.0, 0.0, 0.0, //
                1.0, 3.0, 1.0, 0.0, //
                0.0, 1.0, -2.0, 1.0, //
                0.0, 0.0, 1.0, 2.0,
            ],
            4,
        );
        let rhs = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0]);

        let outcome =
            minres(&matrix, &rhs, None, Some(1e-6), Some(200)).expect("Should solve");

        assert!(
            outcome.converged,
            "MINRES should converge for larger indefinite system"
        );
        assert_solves(&matrix, &rhs, &outcome.solution, 1e-4);
    }

    #[test]
    fn test_minres_with_initial_guess() {
        let matrix = square_matrix(vec![2.0, 1.0, 1.0, -1.0], 2);
        let rhs = Array::from_vec(vec![1.0, 0.0]);
        let guess = Array::from_vec(vec![0.5, 0.5]);

        let outcome = minres(&matrix, &rhs, Some(&guess), Some(1e-6), Some(100))
            .expect("Should solve");

        assert!(outcome.converged);
        assert_solves(&matrix, &rhs, &outcome.solution, 1e-5);
    }

    #[test]
    fn test_minres_zero_rhs() {
        let matrix = square_matrix(vec![2.0, 1.0, 1.0, -1.0], 2);
        let rhs = Array::from_vec(vec![0.0, 0.0]);

        let outcome =
            minres(&matrix, &rhs, None, Some(1e-6), Some(100)).expect("Should solve");

        assert!(outcome.converged);
        assert_eq!(outcome.iterations, 0, "Zero RHS should converge immediately");
        for i in 0..2 {
            assert_relative_eq!(
                outcome.solution.get(&[i]).expect("valid"),
                0.0,
                epsilon = 1e-10
            );
        }
    }
}