math_audio_solvers/iterative/
cg.rs

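//! CG (Conjugate Gradient) solver
//!
//! The Conjugate Gradient method for symmetric positive definite (SPD) systems.
//! It is the standard choice for SPD matrices: in exact arithmetic, each iterate
//! minimizes the A-norm of the error over the current Krylov subspace.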
4//! This is the method of choice for SPD matrices as it has optimal convergence.
5
6use crate::blas_helpers::{inner_product, vector_norm};
7use crate::traits::{ComplexField, LinearOperator};
8use ndarray::Array1;
9use num_traits::{FromPrimitive, ToPrimitive, Zero};
10
/// Tuning parameters for the CG solver.
#[derive(Clone, Debug)]
pub struct CgConfig<R> {
    /// Upper bound on the number of CG iterations.
    pub max_iterations: usize,
    /// Stopping threshold on the relative residual `||r|| / ||b||`.
    pub tolerance: R,
    /// Log progress every `print_interval` iterations; `0` disables logging.
    pub print_interval: usize,
}

impl Default for CgConfig<f64> {
    /// 1000 iterations, relative tolerance `1e-6`, logging disabled.
    fn default() -> Self {
        CgConfig {
            tolerance: 1.0e-6,
            max_iterations: 1_000,
            print_interval: 0,
        }
    }
}
31
32/// CG solver result
33#[derive(Debug)]
34pub struct CgSolution<T: ComplexField> {
35    /// Solution vector
36    pub x: Array1<T>,
37    /// Number of iterations
38    pub iterations: usize,
39    /// Final relative residual
40    pub residual: T::Real,
41    /// Whether convergence was achieved
42    pub converged: bool,
43}
44
45/// Solve Ax = b using the Conjugate Gradient method
46///
47/// Note: This method is only correct for symmetric positive definite matrices.
48/// For non-symmetric systems, use GMRES or BiCGSTAB instead.
49pub fn cg<T, A>(operator: &A, b: &Array1<T>, config: &CgConfig<T::Real>) -> CgSolution<T>
50where
51    T: ComplexField,
52    A: LinearOperator<T>,
53{
54    let n = b.len();
55    let mut x = Array1::from_elem(n, T::zero());
56
57    let b_norm = vector_norm(b);
58    let tol_threshold = T::Real::from_f64(1e-15).unwrap();
59    if b_norm < tol_threshold {
60        return CgSolution {
61            x,
62            iterations: 0,
63            residual: T::Real::zero(),
64            converged: true,
65        };
66    }
67
68    // Initial residual r = b - Ax = b (since x = 0)
69    let mut r = b.clone();
70    let mut p = r.clone();
71    let mut rho = inner_product(&r, &r);
72
73    for iter in 0..config.max_iterations {
74        // q = A * p
75        let q = operator.apply(&p);
76
77        // alpha = rho / (p, q)
78        let pq = inner_product(&p, &q);
79        if pq.norm() < T::Real::from_f64(1e-30).unwrap() {
80            return CgSolution {
81                x,
82                iterations: iter,
83                residual: vector_norm(&r) / b_norm,
84                converged: false,
85            };
86        }
87
88        let alpha = rho / pq;
89
90        // x = x + alpha * p
91        x = &x + &p.mapv(|pi| pi * alpha);
92
93        // r = r - alpha * q
94        r = &r - &q.mapv(|qi| qi * alpha);
95
96        let rel_residual = vector_norm(&r) / b_norm;
97
98        if config.print_interval > 0 && (iter + 1) % config.print_interval == 0 {
99            log::info!(
100                "CG iteration {}: relative residual = {:.6e}",
101                iter + 1,
102                rel_residual.to_f64().unwrap_or(0.0)
103            );
104        }
105
106        if rel_residual < config.tolerance {
107            return CgSolution {
108                x,
109                iterations: iter + 1,
110                residual: rel_residual,
111                converged: true,
112            };
113        }
114
115        let rho_new = inner_product(&r, &r);
116        if rho.norm() < T::Real::from_f64(1e-30).unwrap() {
117            return CgSolution {
118                x,
119                iterations: iter + 1,
120                residual: rel_residual,
121                converged: false,
122            };
123        }
124
125        let beta = rho_new / rho;
126        rho = rho_new;
127
128        // p = r + beta * p
129        p = &r + &p.mapv(|pi| pi * beta);
130    }
131
132    let rel_residual = vector_norm(&r) / b_norm;
133    CgSolution {
134        x,
135        iterations: config.max_iterations,
136        residual: rel_residual,
137        converged: false,
138    }
139}
140
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::CsrMatrix;
    use ndarray::array;

    /// Euclidean norm of a real vector; used to measure how far a result is off.
    fn l2_norm(v: &Array1<f64>) -> f64 {
        v.iter().map(|e| e * e).sum::<f64>().sqrt()
    }

    #[test]
    fn test_cg_spd() {
        // Small symmetric positive definite system.
        let matrix = CsrMatrix::from_dense(&array![[4.0_f64, 1.0], [1.0, 3.0],], 1e-15);
        let rhs = array![1.0_f64, 2.0];
        let settings = CgConfig {
            max_iterations: 100,
            tolerance: 1e-10,
            print_interval: 0,
        };

        let result = cg(&matrix, &rhs, &settings);
        assert!(result.converged, "CG should converge for SPD matrix");

        // The returned x must actually satisfy A x = b.
        let mismatch = &matrix.matvec(&result.x) - &rhs;
        assert!(l2_norm(&mismatch) < 1e-8, "Solution should satisfy Ax = b");
    }

    #[test]
    fn test_cg_identity() {
        let size = 5;
        let eye: CsrMatrix<f64> = CsrMatrix::identity(size);
        let rhs: Array1<f64> = (1..=size).map(|k| k as f64).collect();
        let settings = CgConfig {
            max_iterations: 10,
            tolerance: 1e-12,
            print_interval: 0,
        };

        let result = cg(&eye, &rhs, &settings);
        assert!(result.converged);
        // With A = I the first step already lands on the exact solution.
        assert!(result.iterations <= 2);
        assert!(l2_norm(&(&result.x - &rhs)) < 1e-10);
    }
}
190}