// math_audio_solvers/iterative/cgs.rs

//! CGS (Conjugate Gradient Squared) solver
//!
//! CGS is a Krylov subspace method for non-symmetric systems.
//! It can converge faster than BiCG but may be less stable.

use crate::blas_helpers::{inner_product, vector_norm};
use crate::traits::{ComplexField, LinearOperator, SolverStatus};
use ndarray::Array1;
use num_traits::{FromPrimitive, ToPrimitive, Zero};

/// Tuning parameters for the CGS solver.
#[derive(Debug, Clone)]
pub struct CgsConfig<R> {
    /// Maximum number of iterations
    pub max_iterations: usize,
    /// Relative tolerance for convergence
    pub tolerance: R,
    /// Print progress every N iterations (0 = no output)
    pub print_interval: usize,
}

22impl Default for CgsConfig<f64> {
23    fn default() -> Self {
24        Self {
25            max_iterations: 1000,
26            tolerance: 1e-6,
27            print_interval: 0,
28        }
29    }
30}
31
32/// CGS solver result
33#[derive(Debug)]
34pub struct CgsSolution<T: ComplexField> {
35    /// Solution vector
36    pub x: Array1<T>,
37    /// Number of iterations
38    pub iterations: usize,
39    /// Final relative residual
40    pub residual: T::Real,
41    /// Whether convergence was achieved
42    pub converged: bool,
43    /// Solver status
44    pub status: SolverStatus,
45}
46
47/// Solve Ax = b using the CGS method
48pub fn cgs<T, A>(operator: &A, b: &Array1<T>, config: &CgsConfig<T::Real>) -> CgsSolution<T>
49where
50    T: ComplexField,
51    A: LinearOperator<T>,
52{
53    let n = b.len();
54    let mut x = Array1::from_elem(n, T::zero());
55
56    let b_norm = vector_norm(b);
57    let tol_threshold = T::Real::from_f64(1e-15).unwrap();
58    if b_norm < tol_threshold {
59        return CgsSolution {
60            x,
61            iterations: 0,
62            residual: T::Real::zero(),
63            converged: true,
64            status: SolverStatus::Converged,
65        };
66    }
67
68    // Initial residual
69    let mut r = b.clone();
70    let r0 = r.clone(); // Shadow residual
71
72    let mut rho = inner_product(&r0, &r);
73    let mut p = r.clone();
74    let mut u = r.clone();
75
76    for iter in 0..config.max_iterations {
77        // v = A * p
78        let v = operator.apply(&p);
79
80        let sigma = inner_product(&r0, &v);
81        if sigma.norm() < T::Real::from_f64(1e-20).unwrap() {
82            return CgsSolution {
83                x,
84                iterations: iter,
85                residual: vector_norm(&r) / b_norm,
86                converged: false,
87                status: SolverStatus::Breakdown,
88            };
89        }
90
91        let alpha = rho / sigma;
92
93        // q = u - alpha * v
94        let q = &u - &v.mapv(|vi| vi * alpha);
95
96        // w = A * (u + q)
97        let u_plus_q = &u + &q;
98        let w = operator.apply(&u_plus_q);
99
100        // x = x + alpha * (u + q)
101        x = &x + &u_plus_q.mapv(|ui| ui * alpha);
102
103        // r = r - alpha * w
104        r = &r - &w.mapv(|wi| wi * alpha);
105
106        let rel_residual = vector_norm(&r) / b_norm;
107
108        if config.print_interval > 0 && (iter + 1) % config.print_interval == 0 {
109            log::info!(
110                "CGS iteration {}: relative residual = {:.6e}",
111                iter + 1,
112                rel_residual.to_f64().unwrap_or(0.0)
113            );
114        }
115
116        if rel_residual < config.tolerance {
117            return CgsSolution {
118                x,
119                iterations: iter + 1,
120                residual: rel_residual,
121                converged: true,
122                status: SolverStatus::Converged,
123            };
124        }
125
126        let rho_new = inner_product(&r0, &r);
127        if rho_new.norm() < T::Real::from_f64(1e-20).unwrap() {
128            return CgsSolution {
129                x,
130                iterations: iter + 1,
131                residual: rel_residual,
132                converged: false,
133                status: SolverStatus::Breakdown,
134            };
135        }
136
137        let beta = rho_new / rho;
138        rho = rho_new;
139
140        // u = r + beta * q
141        u = &r + &q.mapv(|qi| qi * beta);
142
143        // p = u + beta * (q + beta * p)
144        let q_plus_beta_p = &q + &p.mapv(|pi| pi * beta);
145        p = &u + &q_plus_beta_p.mapv(|vi| vi * beta);
146    }
147
148    let rel_residual = vector_norm(&r) / b_norm;
149    CgsSolution {
150        x,
151        iterations: config.max_iterations,
152        residual: rel_residual,
153        converged: false,
154        status: SolverStatus::MaxIterationsReached,
155    }
156}
157
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::CsrMatrix;
    use ndarray::array;
    use num_complex::Complex64;

    /// CGS on a small 2x2 real-valued system must converge and the
    /// returned x must actually satisfy Ax = b.
    #[test]
    fn test_cgs_simple() {
        // Helper: complex number with zero imaginary part.
        let re = |v: f64| Complex64::new(v, 0.0);

        let dense = array![[re(4.0), re(1.0)], [re(1.0), re(3.0)]];
        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![re(1.0), re(2.0)];

        let config = CgsConfig {
            max_iterations: 100,
            tolerance: 1e-10,
            print_interval: 0,
        };

        let solution = cgs(&a, &b, &config);
        assert!(solution.converged, "CGS should converge");

        // Verify the residual of the returned iterate directly.
        let ax = a.matvec(&solution.x);
        let error: f64 = (&ax - &b).iter().map(|e| e.norm_sqr()).sum::<f64>().sqrt();
        assert!(error < 1e-8, "Solution should satisfy Ax = b");
    }
}
189}