Skip to main content

math_audio_solvers/iterative/
cg.rs

1//! CG (Conjugate Gradient) solver
2//!
//! The Conjugate Gradient method for symmetric positive definite (SPD) systems.
//! It is the method of choice for SPD matrices: in exact arithmetic each iterate
//! minimizes the A-norm of the error over the current Krylov subspace.
5
6use crate::blas_helpers::{inner_product, vector_norm};
7use crate::traits::{ComplexField, LinearOperator, SolverStatus};
8use ndarray::Array1;
9use num_traits::{FromPrimitive, ToPrimitive, Zero};
10
/// Tuning knobs shared by [`cg`] and [`pcg`].
///
/// `R` is the real scalar type of the tolerance (for example `f64`).
#[derive(Debug, Clone)]
pub struct CgConfig<R> {
    /// Upper bound on the number of iterations performed.
    pub max_iterations: usize,
    /// The solver stops once the relative residual `||r|| / ||b||` falls below this value.
    pub tolerance: R,
    /// Emit a `log::info!` progress line every `print_interval` iterations;
    /// `0` disables progress output.
    pub print_interval: usize,
}

impl Default for CgConfig<f64> {
    /// At most 1000 iterations, relative tolerance `1e-6`, no progress output.
    fn default() -> Self {
        let (max_iterations, tolerance, print_interval) = (1000, 1e-6, 0);
        CgConfig {
            max_iterations,
            tolerance,
            print_interval,
        }
    }
}
31
32/// CG solver result
33#[derive(Debug)]
34pub struct CgSolution<T: ComplexField> {
35    /// Solution vector
36    pub x: Array1<T>,
37    /// Number of iterations
38    pub iterations: usize,
39    /// Final relative residual
40    pub residual: T::Real,
41    /// Whether convergence was achieved
42    pub converged: bool,
43    /// Solver status
44    pub status: SolverStatus,
45}
46
47/// Solve Ax = b using the Conjugate Gradient method
48///
49/// Note: This method is only correct for symmetric positive definite matrices.
50/// For non-symmetric systems, use GMRES or BiCGSTAB instead.
51pub fn cg<T, A>(operator: &A, b: &Array1<T>, config: &CgConfig<T::Real>) -> CgSolution<T>
52where
53    T: ComplexField,
54    A: LinearOperator<T>,
55{
56    let n = b.len();
57    let mut x = Array1::from_elem(n, T::zero());
58
59    let b_norm = vector_norm(b);
60    let tol_threshold = T::Real::from_f64(1e-15).unwrap();
61    if b_norm < tol_threshold {
62        return CgSolution {
63            x,
64            iterations: 0,
65            residual: T::Real::zero(),
66            converged: true,
67            status: SolverStatus::Converged,
68        };
69    }
70
71    // Initial residual r = b - Ax = b (since x = 0)
72    let mut r = b.clone();
73    let mut p = r.clone();
74    let mut rho = inner_product(&r, &r);
75
76    for iter in 0..config.max_iterations {
77        // q = A * p
78        let q = operator.apply(&p);
79
80        // alpha = rho / (p, q)
81        let pq = inner_product(&p, &q);
82        if pq.norm() < T::Real::from_f64(1e-20).unwrap() {
83            return CgSolution {
84                x,
85                iterations: iter,
86                residual: vector_norm(&r) / b_norm,
87                converged: false,
88                status: SolverStatus::Breakdown,
89            };
90        }
91
92        let alpha = rho / pq;
93
94        // x = x + alpha * p
95        x = &x + &p.mapv(|pi| pi * alpha);
96
97        // r = r - alpha * q
98        r = &r - &q.mapv(|qi| qi * alpha);
99
100        let rel_residual = vector_norm(&r) / b_norm;
101
102        if config.print_interval > 0 && (iter + 1) % config.print_interval == 0 {
103            log::info!(
104                "CG iteration {}: relative residual = {:.6e}",
105                iter + 1,
106                rel_residual.to_f64().unwrap_or(0.0)
107            );
108        }
109
110        if rel_residual < config.tolerance {
111            return CgSolution {
112                x,
113                iterations: iter + 1,
114                residual: rel_residual,
115                converged: true,
116                status: SolverStatus::Converged,
117            };
118        }
119
120        let rho_new = inner_product(&r, &r);
121        if rho_new.norm() < T::Real::from_f64(1e-20).unwrap() {
122            return CgSolution {
123                x,
124                iterations: iter + 1,
125                residual: rel_residual,
126                converged: false,
127                status: SolverStatus::Breakdown,
128            };
129        }
130
131        let beta = rho_new / rho;
132        rho = rho_new;
133
134        // p = r + beta * p
135        p = &r + &p.mapv(|pi| pi * beta);
136    }
137
138    let rel_residual = vector_norm(&r) / b_norm;
139    CgSolution {
140        x,
141        iterations: config.max_iterations,
142        residual: rel_residual,
143        converged: false,
144        status: SolverStatus::MaxIterationsReached,
145    }
146}
147
148/// Solve Ax = b using the Preconditioned Conjugate Gradient method
149///
150/// Uses a preconditioner M to accelerate convergence. At each iteration,
151/// the preconditioner is applied as z = M^{-1} r, and the search direction
152/// is computed using the preconditioned residual.
153///
154/// Note: Both A and M^{-1} must be symmetric positive definite.
155pub fn pcg<T, A, P>(
156    operator: &A,
157    precond: &P,
158    b: &Array1<T>,
159    config: &CgConfig<T::Real>,
160) -> CgSolution<T>
161where
162    T: ComplexField,
163    A: LinearOperator<T>,
164    P: crate::traits::Preconditioner<T>,
165{
166    let n = b.len();
167    let mut x = Array1::from_elem(n, T::zero());
168
169    let b_norm = vector_norm(b);
170    let tol_threshold = T::Real::from_f64(1e-15).unwrap();
171    if b_norm < tol_threshold {
172        return CgSolution {
173            x,
174            iterations: 0,
175            residual: T::Real::zero(),
176            converged: true,
177            status: SolverStatus::Converged,
178        };
179    }
180
181    // Initial residual r = b - Ax = b (since x = 0)
182    let mut r = b.clone();
183    let mut z = precond.apply(&r);
184    let mut p = z.clone();
185    let mut rho = inner_product(&r, &z);
186
187    for iter in 0..config.max_iterations {
188        let q = operator.apply(&p);
189
190        let pq = inner_product(&p, &q);
191        if pq.norm() < T::Real::from_f64(1e-20).unwrap() {
192            return CgSolution {
193                x,
194                iterations: iter,
195                residual: vector_norm(&r) / b_norm,
196                converged: false,
197                status: SolverStatus::Breakdown,
198            };
199        }
200
201        let alpha = rho / pq;
202
203        // x = x + alpha * p
204        x = &x + &p.mapv(|pi| pi * alpha);
205
206        // r = r - alpha * q
207        r = &r - &q.mapv(|qi| qi * alpha);
208
209        let rel_residual = vector_norm(&r) / b_norm;
210
211        if config.print_interval > 0 && (iter + 1) % config.print_interval == 0 {
212            log::info!(
213                "PCG iteration {}: relative residual = {:.6e}",
214                iter + 1,
215                rel_residual.to_f64().unwrap_or(0.0)
216            );
217        }
218
219        if rel_residual < config.tolerance {
220            return CgSolution {
221                x,
222                iterations: iter + 1,
223                residual: rel_residual,
224                converged: true,
225                status: SolverStatus::Converged,
226            };
227        }
228
229        z = precond.apply(&r);
230        let rho_new = inner_product(&r, &z);
231        if rho_new.norm() < T::Real::from_f64(1e-20).unwrap() {
232            return CgSolution {
233                x,
234                iterations: iter + 1,
235                residual: rel_residual,
236                converged: false,
237                status: SolverStatus::Breakdown,
238            };
239        }
240
241        let beta = rho_new / rho;
242        rho = rho_new;
243
244        // p = z + beta * p
245        p = &z + &p.mapv(|pi| pi * beta);
246    }
247
248    let rel_residual = vector_norm(&r) / b_norm;
249    CgSolution {
250        x,
251        iterations: config.max_iterations,
252        residual: rel_residual,
253        converged: false,
254        status: SolverStatus::MaxIterationsReached,
255    }
256}
257
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::CsrMatrix;
    use ndarray::array;

    /// Euclidean length of a real vector (square root of the sum of squares).
    fn l2_norm(v: &Array1<f64>) -> f64 {
        v.iter().map(|e| e * e).sum::<f64>().sqrt()
    }

    /// CG configuration with progress output disabled.
    fn quiet_config(max_iterations: usize, tolerance: f64) -> CgConfig<f64> {
        CgConfig {
            max_iterations,
            tolerance,
            print_interval: 0,
        }
    }

    #[test]
    fn test_cg_spd() {
        // Small symmetric positive definite system.
        let a = CsrMatrix::from_dense(&array![[4.0_f64, 1.0], [1.0, 3.0]], 1e-15);
        let rhs = array![1.0_f64, 2.0];

        let sol = cg(&a, &rhs, &quiet_config(100, 1e-10));
        assert!(sol.converged, "CG should converge for SPD matrix");

        // The returned x must actually satisfy the system.
        let residual = &a.matvec(&sol.x) - &rhs;
        assert!(l2_norm(&residual) < 1e-8, "Solution should satisfy Ax = b");
    }

    #[test]
    fn test_cg_identity() {
        let n = 5;
        let id: CsrMatrix<f64> = CsrMatrix::identity(n);
        let rhs: Array1<f64> = (1..=n).map(|i| i as f64).collect();

        let sol = cg(&id, &rhs, &quiet_config(10, 1e-12));
        assert!(sol.converged);
        // For A = I the very first step already lands on x = b.
        assert!(sol.iterations <= 2);

        let diff = &sol.x - &rhs;
        assert!(l2_norm(&diff) < 1e-10);
    }
}