// math_audio_solvers/iterative/cgs.rs

//! CGS (Conjugate Gradient Squared) solver
//!
//! CGS is a Krylov subspace method for non-symmetric systems.
//! It can converge faster than BiCG but may be less stable.
5
6use crate::blas_helpers::{inner_product, vector_norm};
7use crate::traits::{ComplexField, LinearOperator};
8use ndarray::Array1;
9use num_traits::{FromPrimitive, ToPrimitive, Zero};
10
/// Configuration for the CGS iterative solver.
///
/// `R` is the real scalar type in which the convergence tolerance is expressed.
#[derive(Debug, Clone)]
pub struct CgsConfig<R> {
    /// Maximum number of iterations to run before giving up.
    pub max_iterations: usize,
    /// Convergence threshold on the relative residual `||r|| / ||b||`.
    pub tolerance: R,
    /// Log progress every `print_interval` iterations (0 disables logging).
    pub print_interval: usize,
}
21
22impl Default for CgsConfig<f64> {
23    fn default() -> Self {
24        Self {
25            max_iterations: 1000,
26            tolerance: 1e-6,
27            print_interval: 0,
28        }
29    }
30}
31
32/// CGS solver result
33#[derive(Debug)]
34pub struct CgsSolution<T: ComplexField> {
35    /// Solution vector
36    pub x: Array1<T>,
37    /// Number of iterations
38    pub iterations: usize,
39    /// Final relative residual
40    pub residual: T::Real,
41    /// Whether convergence was achieved
42    pub converged: bool,
43}
44
45/// Solve Ax = b using the CGS method
46pub fn cgs<T, A>(operator: &A, b: &Array1<T>, config: &CgsConfig<T::Real>) -> CgsSolution<T>
47where
48    T: ComplexField,
49    A: LinearOperator<T>,
50{
51    let n = b.len();
52    let mut x = Array1::from_elem(n, T::zero());
53
54    let b_norm = vector_norm(b);
55    let tol_threshold = T::Real::from_f64(1e-15).unwrap();
56    if b_norm < tol_threshold {
57        return CgsSolution {
58            x,
59            iterations: 0,
60            residual: T::Real::zero(),
61            converged: true,
62        };
63    }
64
65    // Initial residual
66    let mut r = b.clone();
67    let r0 = r.clone(); // Shadow residual
68
69    let mut rho = inner_product(&r0, &r);
70    let mut p = r.clone();
71    let mut u = r.clone();
72
73    for iter in 0..config.max_iterations {
74        // v = A * p
75        let v = operator.apply(&p);
76
77        let sigma = inner_product(&r0, &v);
78        if sigma.norm() < T::Real::from_f64(1e-30).unwrap() {
79            return CgsSolution {
80                x,
81                iterations: iter,
82                residual: vector_norm(&r) / b_norm,
83                converged: false,
84            };
85        }
86
87        let alpha = rho / sigma;
88
89        // q = u - alpha * v
90        let q = &u - &v.mapv(|vi| vi * alpha);
91
92        // w = A * (u + q)
93        let u_plus_q = &u + &q;
94        let w = operator.apply(&u_plus_q);
95
96        // x = x + alpha * (u + q)
97        x = &x + &u_plus_q.mapv(|ui| ui * alpha);
98
99        // r = r - alpha * w
100        r = &r - &w.mapv(|wi| wi * alpha);
101
102        let rel_residual = vector_norm(&r) / b_norm;
103
104        if config.print_interval > 0 && (iter + 1) % config.print_interval == 0 {
105            log::info!(
106                "CGS iteration {}: relative residual = {:.6e}",
107                iter + 1,
108                rel_residual.to_f64().unwrap_or(0.0)
109            );
110        }
111
112        if rel_residual < config.tolerance {
113            return CgsSolution {
114                x,
115                iterations: iter + 1,
116                residual: rel_residual,
117                converged: true,
118            };
119        }
120
121        let rho_new = inner_product(&r0, &r);
122        if rho.norm() < T::Real::from_f64(1e-30).unwrap() {
123            return CgsSolution {
124                x,
125                iterations: iter + 1,
126                residual: rel_residual,
127                converged: false,
128            };
129        }
130
131        let beta = rho_new / rho;
132        rho = rho_new;
133
134        // u = r + beta * q
135        u = &r + &q.mapv(|qi| qi * beta);
136
137        // p = u + beta * (q + beta * p)
138        let q_plus_beta_p = &q + &p.mapv(|pi| pi * beta);
139        p = &u + &q_plus_beta_p.mapv(|vi| vi * beta);
140    }
141
142    let rel_residual = vector_norm(&r) / b_norm;
143    CgsSolution {
144        x,
145        iterations: config.max_iterations,
146        residual: rel_residual,
147        converged: false,
148    }
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::CsrMatrix;
    use ndarray::array;
    use num_complex::Complex64;

    /// A small well-conditioned 2x2 system must be solved to tight tolerance.
    #[test]
    fn test_cgs_simple() {
        let dense = array![
            [Complex64::new(4.0, 0.0), Complex64::new(1.0, 0.0)],
            [Complex64::new(1.0, 0.0), Complex64::new(3.0, 0.0)],
        ];

        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)];

        let config = CgsConfig {
            max_iterations: 100,
            tolerance: 1e-10,
            print_interval: 0,
        };

        let solution = cgs(&a, &b, &config);

        assert!(solution.converged, "CGS should converge");

        // Verify against the true residual rather than the recurrence residual.
        let ax = a.matvec(&solution.x);
        let error: f64 = (&ax - &b).iter().map(|e| e.norm_sqr()).sum::<f64>().sqrt();
        assert!(error < 1e-8, "Solution should satisfy Ax = b");
    }

    /// A zero right-hand side returns the zero vector without iterating.
    #[test]
    fn test_cgs_zero_rhs() {
        let dense = array![
            [Complex64::new(4.0, 0.0), Complex64::new(1.0, 0.0)],
            [Complex64::new(1.0, 0.0), Complex64::new(3.0, 0.0)],
        ];

        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![Complex64::new(0.0, 0.0), Complex64::new(0.0, 0.0)];

        let config = CgsConfig {
            max_iterations: 100,
            tolerance: 1e-10,
            print_interval: 0,
        };

        let solution = cgs(&a, &b, &config);

        assert!(solution.converged, "zero right-hand side is trivially solved");
        assert_eq!(solution.iterations, 0);
        assert_eq!(solution.residual, 0.0);
        assert!(solution.x.iter().all(|xi| xi.norm_sqr() == 0.0));
    }

    /// With no iteration budget the initial guess is returned as not converged.
    #[test]
    fn test_cgs_zero_iteration_budget() {
        let dense = array![
            [Complex64::new(4.0, 0.0), Complex64::new(1.0, 0.0)],
            [Complex64::new(1.0, 0.0), Complex64::new(3.0, 0.0)],
        ];

        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)];

        let config = CgsConfig {
            max_iterations: 0,
            tolerance: 1e-10,
            print_interval: 0,
        };

        let solution = cgs(&a, &b, &config);

        assert!(!solution.converged);
        assert_eq!(solution.iterations, 0);
        // The residual of the zero initial guess is b itself.
        assert!((solution.residual - 1.0).abs() < 1e-12);
    }
}