Skip to main content

nabled_ml/
iterative.rs

1//! Iterative linear system solvers over ndarray matrices.
2
3use std::fmt;
4
5use nabled_core::scalar::NabledReal;
6use nabled_linalg::lu;
7use ndarray::{Array1, Array2};
8use num_complex::Complex64;
9
/// Lower bound applied to caller-supplied tolerances (`1e-12`), so no solver
/// is asked to converge below this level.
const DEFAULT_TOLERANCE: f64 = 1e-12;
11
/// Configuration for iterative solvers.
///
/// The type parameter `T` is the real scalar type of the tolerance and
/// defaults to `f64`.
#[derive(Debug, Clone)]
pub struct IterativeConfig<T = f64> {
    /// Relative residual tolerance. Solvers raise values below `1e-12` to
    /// that floor.
    pub tolerance:      T,
    /// Maximum iterations.
    pub max_iterations: usize,
}

impl IterativeConfig<f64> {
    /// Default configuration for `f64`: tolerance `1e-10`, at most 1000
    /// iterations.
    #[must_use]
    pub const fn default_f64() -> Self {
        Self { max_iterations: 1000, tolerance: 1e-10 }
    }
}

impl Default for IterativeConfig<f64> {
    fn default() -> Self {
        Self::default_f64()
    }
}

impl IterativeConfig<f32> {
    /// Default configuration for `f32`: tolerance `1e-6`, at most 1000
    /// iterations.
    #[must_use]
    pub const fn default_f32() -> Self {
        Self { max_iterations: 1000, tolerance: 1e-6 }
    }
}

impl Default for IterativeConfig<f32> {
    fn default() -> Self {
        Self::default_f32()
    }
}
40
/// Error type for iterative solvers.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum IterativeError {
    /// Matrix is empty.
    EmptyMatrix,
    /// Dimensions do not align.
    DimensionMismatch,
    /// Maximum iterations reached without convergence.
    MaxIterationsExceeded,
    /// Matrix is not positive definite (CG).
    NotPositiveDefinite,
    /// Algorithm breakdown.
    Breakdown,
}

impl fmt::Display for IterativeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // One fixed human-readable message per variant.
        let message = match self {
            Self::EmptyMatrix => "Matrix is empty",
            Self::DimensionMismatch => "Dimension mismatch",
            Self::MaxIterationsExceeded => "Maximum iterations exceeded",
            Self::NotPositiveDefinite => "Matrix is not positive definite",
            Self::Breakdown => "Algorithm breakdown",
        };
        f.write_str(message)
    }
}

impl std::error::Error for IterativeError {}
69
70fn default_tolerance<T: NabledReal>() -> T {
71    T::from_f64(DEFAULT_TOLERANCE).unwrap_or_else(T::epsilon)
72}
73
74fn vector_norm<T: NabledReal>(vector: &Array1<T>) -> T {
75    vector
76        .iter()
77        .map(|value| *value * *value)
78        .fold(T::zero(), |acc, value| acc + value)
79        .sqrt()
80}
81
82fn vector_norm_complex(vector: &Array1<Complex64>) -> f64 {
83    vector.iter().map(Complex64::norm_sqr).sum::<f64>().sqrt()
84}
85
86#[cfg(feature = "lapack-provider")]
87trait IterativeLinearScalar: NabledReal + std::ops::SubAssign + ndarray_linalg::Lapack {}
88
89#[cfg(feature = "lapack-provider")]
90impl<T> IterativeLinearScalar for T where
91    T: NabledReal + std::ops::SubAssign + ndarray_linalg::Lapack
92{
93}
94
95#[cfg(not(feature = "lapack-provider"))]
96trait IterativeLinearScalar: NabledReal + std::ops::SubAssign {}
97
98#[cfg(not(feature = "lapack-provider"))]
99impl<T> IterativeLinearScalar for T where T: NabledReal + std::ops::SubAssign {}
100
101/// Conjugate Gradient for SPD systems `Ax=b`.
102///
103/// # Errors
104/// Returns an error when inputs are invalid or convergence fails.
105pub fn conjugate_gradient<T>(
106    matrix_a: &Array2<T>,
107    matrix_b: &Array1<T>,
108    config: &IterativeConfig<T>,
109) -> Result<Array1<T>, IterativeError>
110where
111    T: NabledReal,
112{
113    if matrix_a.is_empty() || matrix_b.is_empty() {
114        return Err(IterativeError::EmptyMatrix);
115    }
116    if matrix_a.nrows() != matrix_a.ncols() || matrix_a.nrows() != matrix_b.len() {
117        return Err(IterativeError::DimensionMismatch);
118    }
119
120    let n = matrix_b.len();
121    let mut x = Array1::<T>::zeros(n);
122    let mut r = matrix_b.clone();
123    let mut p = r.clone();
124    let mut rs_old = r.dot(&r);
125
126    let tolerance = config.tolerance.max(default_tolerance::<T>());
127    if rs_old.sqrt() <= tolerance {
128        return Ok(x);
129    }
130
131    for _ in 0..config.max_iterations {
132        let ap = matrix_a.dot(&p);
133        let curvature = p.dot(&ap);
134        if curvature <= tolerance {
135            return Err(IterativeError::NotPositiveDefinite);
136        }
137
138        let alpha = rs_old / curvature;
139        x = &x + &p.mapv(|value| alpha * value);
140        r = &r - &ap.mapv(|value| alpha * value);
141
142        let rs_new = r.dot(&r);
143        if rs_new.sqrt() <= tolerance {
144            return Ok(x);
145        }
146
147        let beta = rs_new / rs_old;
148        p = &r + &p.mapv(|value| beta * value);
149        rs_old = rs_new;
150    }
151
152    Err(IterativeError::MaxIterationsExceeded)
153}
154
155/// Conjugate Gradient for Hermitian positive-definite systems `Ax=b`.
156///
157/// # Errors
158/// Returns an error when inputs are invalid or convergence fails.
159pub fn conjugate_gradient_complex(
160    matrix_a: &Array2<Complex64>,
161    matrix_b: &Array1<Complex64>,
162    config: &IterativeConfig<f64>,
163) -> Result<Array1<Complex64>, IterativeError> {
164    if matrix_a.is_empty() || matrix_b.is_empty() {
165        return Err(IterativeError::EmptyMatrix);
166    }
167    if matrix_a.nrows() != matrix_a.ncols() || matrix_a.nrows() != matrix_b.len() {
168        return Err(IterativeError::DimensionMismatch);
169    }
170
171    let n = matrix_b.len();
172    let mut x = Array1::<Complex64>::zeros(n);
173    let mut r = matrix_b.clone();
174    let mut p = r.clone();
175    let mut rs_old = r.iter().zip(r.iter()).map(|(lhs, rhs)| lhs.conj() * rhs).sum::<Complex64>();
176    let tolerance = config.tolerance.max(DEFAULT_TOLERANCE);
177
178    if rs_old.re.max(0.0).sqrt() <= tolerance {
179        return Ok(x);
180    }
181
182    for _ in 0..config.max_iterations {
183        let ap = matrix_a.dot(&p);
184        let curvature =
185            p.iter().zip(ap.iter()).map(|(lhs, rhs)| lhs.conj() * rhs).sum::<Complex64>();
186        if curvature.re <= tolerance || curvature.norm() <= tolerance {
187            return Err(IterativeError::NotPositiveDefinite);
188        }
189
190        let alpha = rs_old / curvature;
191        x = &x + &(alpha * &p);
192        r = &r - &(alpha * &ap);
193
194        let rs_new = r.iter().zip(r.iter()).map(|(lhs, rhs)| lhs.conj() * rhs).sum::<Complex64>();
195        if rs_new.re.max(0.0).sqrt() <= tolerance {
196            return Ok(x);
197        }
198
199        if rs_old.norm() <= tolerance {
200            return Err(IterativeError::Breakdown);
201        }
202        let beta = rs_new / rs_old;
203        p = &r + &(beta * &p);
204        rs_old = rs_new;
205    }
206
207    Err(IterativeError::MaxIterationsExceeded)
208}
209
210fn solve_linear<T>(matrix: &Array2<T>, rhs: &Array1<T>) -> Result<Array1<T>, IterativeError>
211where
212    T: IterativeLinearScalar,
213{
214    lu::solve(matrix, rhs).map_err(|_| IterativeError::Breakdown)
215}
216
217/// GMRES for general systems `Ax=b`.
218///
219/// # Errors
220/// Returns an error when inputs are invalid or convergence fails.
221#[allow(clippy::many_single_char_names)]
222#[cfg(feature = "lapack-provider")]
223pub fn gmres<T>(
224    matrix_a: &Array2<T>,
225    matrix_b: &Array1<T>,
226    config: &IterativeConfig<T>,
227) -> Result<Array1<T>, IterativeError>
228where
229    T: NabledReal + std::ops::SubAssign + ndarray_linalg::Lapack,
230{
231    gmres_impl(matrix_a, matrix_b, config)
232}
233
234/// GMRES for general systems `Ax=b`.
235///
236/// # Errors
237/// Returns an error when inputs are invalid or convergence fails.
238#[allow(clippy::many_single_char_names)]
239#[cfg(not(feature = "lapack-provider"))]
240pub fn gmres<T>(
241    matrix_a: &Array2<T>,
242    matrix_b: &Array1<T>,
243    config: &IterativeConfig<T>,
244) -> Result<Array1<T>, IterativeError>
245where
246    T: NabledReal + std::ops::SubAssign,
247{
248    gmres_impl(matrix_a, matrix_b, config)
249}
250
251#[allow(clippy::many_single_char_names)]
252fn gmres_impl<T>(
253    matrix_a: &Array2<T>,
254    matrix_b: &Array1<T>,
255    config: &IterativeConfig<T>,
256) -> Result<Array1<T>, IterativeError>
257where
258    T: IterativeLinearScalar,
259{
260    if matrix_a.is_empty() || matrix_b.is_empty() {
261        return Err(IterativeError::EmptyMatrix);
262    }
263    if matrix_a.nrows() != matrix_a.ncols() || matrix_a.nrows() != matrix_b.len() {
264        return Err(IterativeError::DimensionMismatch);
265    }
266
267    let n = matrix_b.len();
268    let m = n.min(config.max_iterations.max(1));
269    let mut basis = Array2::<T>::zeros((n, m + 1));
270    let mut hessenberg = Array2::<T>::zeros((m + 1, m));
271
272    let beta = vector_norm(matrix_b);
273    let tolerance = config.tolerance.max(default_tolerance::<T>());
274    if beta <= tolerance {
275        return Ok(Array1::<T>::zeros(n));
276    }
277
278    for row in 0..n {
279        basis[[row, 0]] = matrix_b[row] / beta;
280    }
281
282    let mut effective_m = m;
283    for j in 0..m {
284        let mut w = matrix_a.dot(&basis.column(j));
285
286        for i in 0..=j {
287            let vi = basis.column(i);
288            let hij = vi.dot(&w);
289            hessenberg[[i, j]] = hij;
290            for row in 0..n {
291                w[row] -= hij * basis[[row, i]];
292            }
293        }
294
295        let norm_w = vector_norm(&w);
296        hessenberg[[j + 1, j]] = norm_w;
297        if norm_w <= tolerance {
298            effective_m = j + 1;
299            break;
300        }
301        for row in 0..n {
302            basis[[row, j + 1]] = w[row] / norm_w;
303        }
304    }
305
306    let h = hessenberg.slice(ndarray::s![..(effective_m + 1), ..effective_m]);
307    let ht = h.t();
308    let normal_matrix = ht.dot(&h);
309
310    let mut rhs_ls = Array1::<T>::zeros(effective_m + 1);
311    rhs_ls[0] = beta;
312    let normal_rhs = ht.dot(&rhs_ls);
313
314    let y = solve_linear(&normal_matrix, &normal_rhs)?;
315    let x = basis.slice(ndarray::s![.., ..effective_m]).dot(&y);
316
317    let residual = matrix_b - &matrix_a.dot(&x);
318    if vector_norm(&residual) <= tolerance {
319        Ok(x)
320    } else {
321        Err(IterativeError::MaxIterationsExceeded)
322    }
323}
324
325/// GMRES for complex general systems `Ax=b`.
326///
327/// # Errors
328/// Returns an error when inputs are invalid or convergence fails.
329#[allow(clippy::many_single_char_names)]
330pub fn gmres_complex(
331    matrix_a: &Array2<Complex64>,
332    matrix_b: &Array1<Complex64>,
333    config: &IterativeConfig<f64>,
334) -> Result<Array1<Complex64>, IterativeError> {
335    if matrix_a.is_empty() || matrix_b.is_empty() {
336        return Err(IterativeError::EmptyMatrix);
337    }
338    if matrix_a.nrows() != matrix_a.ncols() || matrix_a.nrows() != matrix_b.len() {
339        return Err(IterativeError::DimensionMismatch);
340    }
341
342    let n = matrix_b.len();
343    let m = n.min(config.max_iterations.max(1));
344    let mut basis = Array2::<Complex64>::zeros((n, m + 1));
345    let mut hessenberg = Array2::<Complex64>::zeros((m + 1, m));
346    let tolerance = config.tolerance.max(DEFAULT_TOLERANCE);
347
348    let beta = vector_norm_complex(matrix_b);
349    if beta <= tolerance {
350        return Ok(Array1::<Complex64>::zeros(n));
351    }
352
353    for row in 0..n {
354        basis[[row, 0]] = matrix_b[row] / beta;
355    }
356
357    let mut effective_m = m;
358    for j in 0..m {
359        let mut w = matrix_a.dot(&basis.column(j));
360
361        for i in 0..=j {
362            let vi = basis.column(i);
363            let hij = vi.iter().zip(w.iter()).map(|(lhs, rhs)| lhs.conj() * rhs).sum::<Complex64>();
364            hessenberg[[i, j]] = hij;
365            for row in 0..n {
366                w[row] -= hij * basis[[row, i]];
367            }
368        }
369
370        let norm_w = vector_norm_complex(&w);
371        hessenberg[[j + 1, j]] = Complex64::new(norm_w, 0.0);
372        if norm_w <= tolerance {
373            effective_m = j + 1;
374            break;
375        }
376        for row in 0..n {
377            basis[[row, j + 1]] = w[row] / norm_w;
378        }
379    }
380
381    let h = hessenberg.slice(ndarray::s![..(effective_m + 1), ..effective_m]);
382    let h_conj_t = h.mapv(|value| value.conj()).reversed_axes();
383    let normal_matrix = h_conj_t.dot(&h);
384
385    let mut rhs_ls = Array1::<Complex64>::zeros(effective_m + 1);
386    rhs_ls[0] = Complex64::new(beta, 0.0);
387    let normal_rhs = h_conj_t.dot(&rhs_ls);
388
389    let y =
390        lu::solve_complex(&normal_matrix, &normal_rhs).map_err(|_| IterativeError::Breakdown)?;
391    let x = basis.slice(ndarray::s![.., ..effective_m]).dot(&y);
392
393    let residual = matrix_b - &matrix_a.dot(&x);
394    if vector_norm_complex(&residual) <= tolerance {
395        Ok(x)
396    } else {
397        Err(IterativeError::MaxIterationsExceeded)
398    }
399}
400
#[cfg(test)]
mod tests {
    use ndarray::{Array1, Array2};
    use num_complex::Complex64;

    use super::*;

    /// Builds a 2x2 matrix from its four entries in row-major order.
    fn square2<T: Clone>(entries: [T; 4]) -> Array2<T> {
        Array2::from_shape_vec((2, 2), entries.to_vec()).unwrap()
    }

    /// `true` when every entry of `matrix · solution` is strictly within `tol`
    /// of the matching entry of `rhs`.
    fn reproduces_f64(
        matrix: &Array2<f64>,
        solution: &Array1<f64>,
        rhs: &Array1<f64>,
        tol: f64,
    ) -> bool {
        matrix.dot(solution).iter().zip(rhs.iter()).all(|(got, want)| (got - want).abs() < tol)
    }

    /// `f32` counterpart of [`reproduces_f64`].
    fn reproduces_f32(
        matrix: &Array2<f32>,
        solution: &Array1<f32>,
        rhs: &Array1<f32>,
        tol: f32,
    ) -> bool {
        matrix.dot(solution).iter().zip(rhs.iter()).all(|(got, want)| (got - want).abs() < tol)
    }

    /// Complex counterpart of [`reproduces_f64`], measuring distance by modulus.
    fn reproduces_complex(
        matrix: &Array2<Complex64>,
        solution: &Array1<Complex64>,
        rhs: &Array1<Complex64>,
        tol: f64,
    ) -> bool {
        matrix.dot(solution).iter().zip(rhs.iter()).all(|(got, want)| (got - want).norm() < tol)
    }

    #[test]
    fn cg_solves_spd_system() {
        let matrix = square2([4.0_f64, 1.0, 1.0, 3.0]);
        let rhs = Array1::from_vec(vec![1.0_f64, 2.0]);
        let config = IterativeConfig::<f64>::default();
        let solution = conjugate_gradient(&matrix, &rhs, &config).unwrap();
        assert!(reproduces_f64(&matrix, &solution, &rhs, 1e-8));
    }

    #[test]
    fn gmres_solves_small_system() {
        let matrix = square2([3.0_f64, 1.0, 1.0, 2.0]);
        let rhs = Array1::from_vec(vec![9.0_f64, 8.0]);
        let config = IterativeConfig::<f64>::default();
        let solution = gmres(&matrix, &rhs, &config).unwrap();
        assert!(reproduces_f64(&matrix, &solution, &rhs, 1e-8));
    }

    #[test]
    fn real_f32_solvers_work() {
        let matrix = square2([4.0_f32, 1.0, 1.0, 3.0]);
        let rhs = Array1::from_vec(vec![1.0_f32, 2.0]);
        let config = IterativeConfig::<f32>::default();

        let cg = conjugate_gradient(&matrix, &rhs, &config).unwrap();
        let gm = gmres(&matrix, &rhs, &config).unwrap();

        assert!(reproduces_f32(&matrix, &cg, &rhs, 1e-4));
        assert!(reproduces_f32(&matrix, &gm, &rhs, 1e-4));
    }

    #[test]
    fn cg_rejects_dimension_mismatch() {
        let matrix = Array2::<f64>::eye(2);
        let rhs = Array1::from_vec(vec![1.0_f64, 2.0, 3.0]);
        let outcome = conjugate_gradient(&matrix, &rhs, &IterativeConfig::<f64>::default());
        assert!(matches!(outcome, Err(IterativeError::DimensionMismatch)));
    }

    #[test]
    fn gmres_rejects_empty_input() {
        let matrix = Array2::<f64>::zeros((0, 0));
        let rhs = Array1::<f64>::zeros(0);
        let outcome = gmres(&matrix, &rhs, &IterativeConfig::<f64>::default());
        assert!(matches!(outcome, Err(IterativeError::EmptyMatrix)));
    }

    #[test]
    fn cg_returns_zero_for_zero_rhs() {
        let matrix = Array2::<f64>::eye(2);
        let rhs = Array1::from_vec(vec![0.0_f64, 0.0]);
        let config = IterativeConfig::<f64>::default();
        let solution = conjugate_gradient(&matrix, &rhs, &config).unwrap();
        assert!(solution.iter().all(|entry| entry.abs() < 1e-12_f64));
    }

    #[test]
    fn cg_complex_solves_hermitian_spd_system() {
        let matrix = square2([
            Complex64::new(4.0, 0.0),
            Complex64::new(1.0, 1.0),
            Complex64::new(1.0, -1.0),
            Complex64::new(3.0, 0.0),
        ]);
        let rhs = Array1::from_vec(vec![Complex64::new(1.0, 0.5), Complex64::new(2.0, -1.0)]);
        let solution =
            conjugate_gradient_complex(&matrix, &rhs, &IterativeConfig::default()).unwrap();
        assert!(reproduces_complex(&matrix, &solution, &rhs, 1e-7));
    }

    #[test]
    fn gmres_complex_solves_small_system() {
        let matrix = square2([
            Complex64::new(3.0, 1.0),
            Complex64::new(1.0, -0.5),
            Complex64::new(0.5, 1.0),
            Complex64::new(2.0, -1.0),
        ]);
        let rhs = Array1::from_vec(vec![Complex64::new(1.0, 2.0), Complex64::new(3.0, -1.0)]);
        let solution = gmres_complex(&matrix, &rhs, &IterativeConfig::default()).unwrap();
        assert!(reproduces_complex(&matrix, &solution, &rhs, 1e-7));
    }

    #[test]
    fn cg_complex_rejects_dimension_mismatch() {
        let matrix = square2([
            Complex64::new(1.0, 0.0),
            Complex64::new(0.0, 0.0),
            Complex64::new(0.0, 0.0),
            Complex64::new(1.0, 0.0),
        ]);
        let rhs = Array1::from_vec(vec![Complex64::new(1.0, 0.0)]);
        let outcome = conjugate_gradient_complex(&matrix, &rhs, &IterativeConfig::default());
        assert!(matches!(outcome, Err(IterativeError::DimensionMismatch)));
    }

    #[test]
    fn gmres_complex_rejects_empty_input() {
        let matrix = Array2::<Complex64>::zeros((0, 0));
        let rhs = Array1::<Complex64>::zeros(0);
        let outcome = gmres_complex(&matrix, &rhs, &IterativeConfig::default());
        assert!(matches!(outcome, Err(IterativeError::EmptyMatrix)));
    }
}