solvers/iterative/
cg.rs

//! CG (Conjugate Gradient) solver
//!
//! The Conjugate Gradient method for symmetric (Hermitian, for complex scalars) positive
//! definite systems. It is the method of choice for such matrices: each iterate minimizes
//! the A-norm of the error over the current Krylov subspace.
5
6use crate::traits::{ComplexField, LinearOperator};
7use ndarray::Array1;
8use num_traits::{Float, FromPrimitive, ToPrimitive, Zero};
9
/// Tuning knobs for the [`cg`] solver.
///
/// `R` is the real scalar type in which the tolerance is expressed.
#[derive(Clone)]
#[derive(Debug)]
pub struct CgConfig<R> {
    /// Upper bound on the number of CG iterations performed.
    pub max_iterations: usize,
    /// The solve stops once the relative residual `||r|| / ||b||` drops below this value.
    pub tolerance: R,
    /// Emit a progress log line every `print_interval` iterations; `0` disables logging.
    pub print_interval: usize,
}

impl Default for CgConfig<f64> {
    /// 1000 iterations, relative tolerance `1e-6`, no progress output.
    fn default() -> Self {
        let tolerance = 1.0e-6;
        let print_interval = 0;
        Self {
            print_interval,
            tolerance,
            max_iterations: 1_000,
        }
    }
}
30
31/// CG solver result
32#[derive(Debug)]
33pub struct CgSolution<T: ComplexField> {
34    /// Solution vector
35    pub x: Array1<T>,
36    /// Number of iterations
37    pub iterations: usize,
38    /// Final relative residual
39    pub residual: T::Real,
40    /// Whether convergence was achieved
41    pub converged: bool,
42}
43
44/// Solve Ax = b using the Conjugate Gradient method
45///
46/// Note: This method is only correct for symmetric positive definite matrices.
47/// For non-symmetric systems, use GMRES or BiCGSTAB instead.
48pub fn cg<T, A>(operator: &A, b: &Array1<T>, config: &CgConfig<T::Real>) -> CgSolution<T>
49where
50    T: ComplexField,
51    A: LinearOperator<T>,
52{
53    let n = b.len();
54    let mut x = Array1::from_elem(n, T::zero());
55
56    let b_norm = vector_norm(b);
57    let tol_threshold = T::Real::from_f64(1e-15).unwrap();
58    if b_norm < tol_threshold {
59        return CgSolution {
60            x,
61            iterations: 0,
62            residual: T::Real::zero(),
63            converged: true,
64        };
65    }
66
67    // Initial residual r = b - Ax = b (since x = 0)
68    let mut r = b.clone();
69    let mut p = r.clone();
70    let mut rho = inner_product(&r, &r);
71
72    for iter in 0..config.max_iterations {
73        // q = A * p
74        let q = operator.apply(&p);
75
76        // alpha = rho / (p, q)
77        let pq = inner_product(&p, &q);
78        if pq.norm() < T::Real::from_f64(1e-30).unwrap() {
79            return CgSolution {
80                x,
81                iterations: iter,
82                residual: vector_norm(&r) / b_norm,
83                converged: false,
84            };
85        }
86
87        let alpha = rho / pq;
88
89        // x = x + alpha * p
90        x = &x + &p.mapv(|pi| pi * alpha);
91
92        // r = r - alpha * q
93        r = &r - &q.mapv(|qi| qi * alpha);
94
95        let rel_residual = vector_norm(&r) / b_norm;
96
97        if config.print_interval > 0 && (iter + 1) % config.print_interval == 0 {
98            log::info!(
99                "CG iteration {}: relative residual = {:.6e}",
100                iter + 1,
101                rel_residual.to_f64().unwrap_or(0.0)
102            );
103        }
104
105        if rel_residual < config.tolerance {
106            return CgSolution {
107                x,
108                iterations: iter + 1,
109                residual: rel_residual,
110                converged: true,
111            };
112        }
113
114        let rho_new = inner_product(&r, &r);
115        if rho.norm() < T::Real::from_f64(1e-30).unwrap() {
116            return CgSolution {
117                x,
118                iterations: iter + 1,
119                residual: rel_residual,
120                converged: false,
121            };
122        }
123
124        let beta = rho_new / rho;
125        rho = rho_new;
126
127        // p = r + beta * p
128        p = &r + &p.mapv(|pi| pi * beta);
129    }
130
131    let rel_residual = vector_norm(&r) / b_norm;
132    CgSolution {
133        x,
134        iterations: config.max_iterations,
135        residual: rel_residual,
136        converged: false,
137    }
138}
139
140#[inline]
141fn inner_product<T: ComplexField>(x: &Array1<T>, y: &Array1<T>) -> T {
142    x.iter()
143        .zip(y.iter())
144        .fold(T::zero(), |acc, (&xi, &yi)| acc + xi.conj() * yi)
145}
146
147#[inline]
148fn vector_norm<T: ComplexField>(x: &Array1<T>) -> T::Real {
149    x.iter()
150        .map(|xi| xi.norm_sqr())
151        .fold(T::Real::zero(), |acc, v| acc + v)
152        .sqrt()
153}
154
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::CsrMatrix;
    use ndarray::{array, Array2};

    /// Euclidean norm of a real vector, used to measure `||Ax - b||`.
    fn l2(v: &Array1<f64>) -> f64 {
        v.iter().map(|e| e * e).sum::<f64>().sqrt()
    }

    #[test]
    fn test_cg_spd() {
        // Symmetric positive definite matrix
        let dense = array![[4.0_f64, 1.0], [1.0, 3.0],];

        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![1.0_f64, 2.0];

        let config = CgConfig {
            max_iterations: 100,
            tolerance: 1e-10,
            print_interval: 0,
        };

        let solution = cg(&a, &b, &config);

        assert!(solution.converged, "CG should converge for SPD matrix");

        let ax = a.matvec(&solution.x);
        assert!(l2(&(&ax - &b)) < 1e-8, "Solution should satisfy Ax = b");
    }

    #[test]
    fn test_cg_identity() {
        let n = 5;
        let id: CsrMatrix<f64> = CsrMatrix::identity(n);
        let b = Array1::from_iter((1..=n).map(|i| i as f64));

        let config = CgConfig {
            max_iterations: 10,
            tolerance: 1e-12,
            print_interval: 0,
        };

        let solution = cg(&id, &b, &config);

        assert!(solution.converged);
        assert!(solution.iterations <= 2);
        assert!(l2(&(&solution.x - &b)) < 1e-10);
    }

    #[test]
    fn test_cg_zero_rhs() {
        // A zero right-hand side must short-circuit to x = 0 without iterating.
        let id: CsrMatrix<f64> = CsrMatrix::identity(4);
        let b = Array1::<f64>::zeros(4);

        let config = CgConfig {
            max_iterations: 10,
            tolerance: 1e-10,
            print_interval: 0,
        };

        let solution = cg(&id, &b, &config);

        assert!(solution.converged);
        assert_eq!(solution.iterations, 0);
        assert_eq!(solution.residual, 0.0);
        assert!(solution.x.iter().all(|&xi| xi == 0.0));
    }

    #[test]
    fn test_cg_no_iterations_allowed() {
        // With no iterations permitted, x stays 0 and the relative residual is ||b||/||b|| = 1.
        let id: CsrMatrix<f64> = CsrMatrix::identity(3);
        let b = array![1.0_f64, 2.0, 3.0];

        let config = CgConfig {
            max_iterations: 0,
            tolerance: 1e-10,
            print_interval: 0,
        };

        let solution = cg(&id, &b, &config);

        assert!(!solution.converged);
        assert_eq!(solution.iterations, 0);
        assert!((solution.residual - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_cg_laplacian_1d() {
        // 1D Laplacian (2 on the diagonal, -1 off it) is SPD and needs several
        // iterations, exercising the x / r / p recurrences beyond the first step.
        let n = 10;
        let dense = Array2::from_shape_fn((n, n), |(i, j)| {
            if i == j {
                2.0_f64
            } else if i + 1 == j || j + 1 == i {
                -1.0
            } else {
                0.0
            }
        });
        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = Array1::from_elem(n, 1.0_f64);

        let config = CgConfig {
            max_iterations: 100,
            tolerance: 1e-10,
            print_interval: 0,
        };

        let solution = cg(&a, &b, &config);

        assert!(solution.converged, "CG should converge for the 1D Laplacian");
        assert!(solution.iterations > 1);
        assert!(solution.iterations <= n);

        let ax = a.matvec(&solution.x);
        assert!(l2(&(&ax - &b)) < 1e-8, "Solution should satisfy Ax = b");
    }
}