use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use num_traits::{Float, Zero};
/// Tuning parameters for an iterative solver.
///
/// `SolverConfig::default()` yields `max_iter = 1000`, `tol = 1e-6`,
/// `restart = Some(30)` and `use_preconditioner = false`.
#[derive(Debug, Clone)]
pub struct SolverConfig<T: Float> {
    /// Iteration limit for the solver (enforced by the solver
    /// implementations, which are not part of this definition).
    pub max_iter: usize,
    /// Convergence tolerance. NOTE(review): whether it is compared against an
    /// absolute or a relative residual is decided by each solver — confirm.
    pub tol: T,
    /// Restart length. Presumably the Krylov subspace size for restarted
    /// methods such as GMRES, with `None` meaning "never restart" — confirm
    /// against the solvers that read it.
    pub restart: Option<usize>,
    /// Whether the solver should apply a preconditioner.
    pub use_preconditioner: bool,
}
impl<T: Float> Default for SolverConfig<T> {
fn default() -> Self {
Self {
max_iter: 1000,
tol: T::from(1e-6).expect("1e-6 is a valid f64 constant"),
restart: Some(30),
use_preconditioner: false,
}
}
}
/// Outcome of an iterative solve.
#[derive(Debug, Clone)]
pub struct SolverResult<T: Clone> {
    /// Final iterate produced by the solver.
    pub solution: Array<T>,
    /// Number of iterations the solver reports having performed.
    pub iterations: usize,
    /// Residual norm at termination. NOTE(review): absolute vs. relative
    /// residual is decided by the producing solver — confirm.
    pub residual_norm: T,
    /// Whether the solver reported convergence (presumably `false` when it
    /// stopped on its iteration limit instead — confirm).
    pub converged: bool,
}
/// Returns the Euclidean (L2) norm `sqrt(Σ v[i]²)` of a slice.
///
/// The sum of squares is accumulated relative to the largest magnitude seen so
/// far (the scaled scheme used by the reference BLAS `nrm2`). As a result the
/// norm does not overflow to `∞` for large entries (e.g. `1e200_f64` or
/// `1e20_f32`) and does not underflow to zero for tiny ones, as a plain `Σ x²`
/// would. The cost is one division per non-zero element.
///
/// Special values behave like the naive formula:
/// - any NaN entry yields NaN;
/// - otherwise, any infinite entry yields `+∞`;
/// - an empty slice yields zero.
#[inline]
pub fn compute_norm_vec<T: Float>(v: &[T]) -> T {
    // Invariant: Σ_{seen} x² == scale² · ssq, where scale = max |x| seen so far.
    let mut scale = T::zero();
    let mut ssq = T::one();
    let mut saw_infinity = false;
    for &x in v {
        let a = x.abs();
        if a.is_nan() {
            return T::nan();
        }
        if a.is_infinite() {
            // Defer returning ∞: a later NaN must still win, as in `∞ + NaN = NaN`.
            saw_infinity = true;
        } else if a > scale {
            // New largest magnitude: rescale the accumulated sum to it.
            let ratio = scale / a;
            ssq = T::one() + ssq * ratio * ratio;
            scale = a;
        } else if a > T::zero() {
            // Here 0 < a <= scale, so the division is safe and ratio <= 1.
            let ratio = a / scale;
            ssq = ssq + ratio * ratio;
        }
    }
    if saw_infinity {
        T::infinity()
    } else {
        scale * ssq.sqrt()
    }
}
/// Returns the dot product `Σ a[i] * b[i]` of two slices.
///
/// Products are accumulated left to right, starting from zero. Elements are
/// paired positionally, and pairing stops at the end of the shorter slice. If
/// the lengths differ, the trailing elements of the longer slice are silently
/// ignored.
#[inline]
pub fn dot_vec<T: Float>(a: &[T], b: &[T]) -> T {
    let mut total = T::zero();
    for (&lhs, &rhs) in a.iter().zip(b) {
        total = total + lhs * rhs;
    }
    total
}
/// Computes the matrix-vector product `A · x` and returns it as a 1-D array.
///
/// `x` is cloned (the caller's array is untouched) and reshaped into an
/// `n × 1` column, where `n = x.size()`. It is then multiplied with `a` via
/// `Array::matmul`, and the `m × 1` product is flattened to shape `[m]`. Here
/// `m` is the row count of `a`, which may differ from `n` when `a` is not
/// square.
///
/// # Errors
///
/// Propagates any error returned by `Array::matmul`. That is presumably raised
/// when the column count of `a` does not equal `x.size()` — confirm against
/// `matmul`.
pub fn matvec<T>(a: &Array<T>, x: &Array<T>) -> Result<Array<T>>
where
    T: Float + Clone + Zero,
{
    let n = x.size();
    let x_col = x.clone().reshape(&[n, 1]);
    let product = a.matmul(&x_col)?;
    // The product has one entry per row of `a`, not per entry of `x`, so
    // flatten using its own size. Reshaping to `n` breaks for non-square `a`.
    let rows = product.size();
    Ok(product.reshape(&[rows]))
}
pub fn compute_norm<T>(v: &Array<T>) -> Result<T>
where
T: Float + Clone + Zero,
{
let n = v.size();
let mut sum = T::zero();
for i in 0..n {
let val = v.get(&[i])?;
sum = sum + val * val;
}
Ok(sum.sqrt())
}
/// Checks that `a` and `b` form a well-posed square linear system `A x = b`.
///
/// `a` must be 2-D with equal row and column counts. `b` must contain exactly
/// as many elements as `a` has rows; only `b.size()` is checked, not its shape.
///
/// Returns the system dimension `n` on success.
///
/// # Errors
///
/// * `NumRs2Error::DimensionMismatch` if `a` is not a 2-D square matrix.
/// * `NumRs2Error::ShapeMismatch` (expected `[n]`, actual `b.shape()`) if
///   `b.size() != n`.
pub fn validate_system<T>(a: &Array<T>, b: &Array<T>) -> Result<usize>
where
    T: Float + Clone,
{
    let dims = a.shape();
    let n = match &dims[..] {
        [rows, cols] if rows == cols => *rows,
        _ => {
            return Err(NumRs2Error::DimensionMismatch(
                "Matrix must be square".to_string(),
            ));
        }
    };

    if b.size() == n {
        Ok(n)
    } else {
        Err(NumRs2Error::ShapeMismatch {
            expected: vec![n],
            actual: b.shape(),
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_solver_config_default() {
        let config: SolverConfig<f64> = SolverConfig::default();
        assert_eq!(config.max_iter, 1000);
        assert!((config.tol - 1e-6).abs() < 1e-12);
        assert_eq!(config.restart, Some(30));
        assert!(!config.use_preconditioner);
    }

    #[test]
    fn test_compute_norm_vec() {
        let v = vec![3.0, 4.0];
        let norm = compute_norm_vec(&v);
        assert!((norm - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_compute_norm_vec_empty() {
        // The norm of an empty vector is zero, not NaN.
        let v: Vec<f64> = Vec::new();
        assert_eq!(compute_norm_vec(&v), 0.0);
    }

    #[test]
    fn test_dot_vec() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        let dot = dot_vec(&a, &b);
        assert!((dot - 32.0).abs() < 1e-10);
    }

    #[test]
    fn test_dot_vec_empty() {
        let empty: Vec<f64> = Vec::new();
        assert_eq!(dot_vec(&empty, &empty), 0.0);
    }

    #[test]
    fn test_validate_system() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
        let b = Array::from_vec(vec![1.0, 2.0]);
        let n = validate_system(&a, &b).expect("Should validate");
        assert_eq!(n, 2);
    }

    #[test]
    fn test_validate_system_non_square() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).reshape(&[2, 3]);
        let b = Array::from_vec(vec![1.0, 2.0]);
        let result = validate_system(&a, &b);
        assert!(matches!(result, Err(NumRs2Error::DimensionMismatch(_))));
    }

    #[test]
    fn test_validate_system_not_2d() {
        // A 1-D "matrix" must be rejected as non-square.
        let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let b = Array::from_vec(vec![1.0, 2.0]);
        let result = validate_system(&a, &b);
        assert!(matches!(result, Err(NumRs2Error::DimensionMismatch(_))));
    }

    #[test]
    fn test_validate_system_rhs_length_mismatch() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0]).reshape(&[2, 2]);
        let b = Array::from_vec(vec![1.0, 2.0, 3.0]);
        let result = validate_system(&a, &b);
        assert!(matches!(result, Err(NumRs2Error::ShapeMismatch { .. })));
    }
}