math_audio_solvers/iterative/
cgs.rs

1//! CGS (Conjugate Gradient Squared) solver
2//!
3//! CGS is a Krylov subspace method for non-symmetric systems.
4//! It can converge faster than BiCG but may be less stable.
5
6use crate::traits::{ComplexField, LinearOperator};
7use ndarray::Array1;
8use num_traits::{Float, FromPrimitive, ToPrimitive, Zero};
9
/// Settings that control a CGS solve.
///
/// `R` is the real scalar type in which the tolerance is expressed.
#[derive(Clone, Debug)]
pub struct CgsConfig<R> {
    /// Upper bound on the number of iterations to run.
    pub max_iterations: usize,
    /// Relative tolerance used to decide convergence.
    pub tolerance: R,
    /// Progress is printed every `print_interval` iterations; `0` turns
    /// progress output off.
    pub print_interval: usize,
}
20
21impl Default for CgsConfig<f64> {
22    fn default() -> Self {
23        Self {
24            max_iterations: 1000,
25            tolerance: 1e-6,
26            print_interval: 0,
27        }
28    }
29}
30
31/// CGS solver result
32#[derive(Debug)]
33pub struct CgsSolution<T: ComplexField> {
34    /// Solution vector
35    pub x: Array1<T>,
36    /// Number of iterations
37    pub iterations: usize,
38    /// Final relative residual
39    pub residual: T::Real,
40    /// Whether convergence was achieved
41    pub converged: bool,
42}
43
44/// Solve Ax = b using the CGS method
45pub fn cgs<T, A>(operator: &A, b: &Array1<T>, config: &CgsConfig<T::Real>) -> CgsSolution<T>
46where
47    T: ComplexField,
48    A: LinearOperator<T>,
49{
50    let n = b.len();
51    let mut x = Array1::from_elem(n, T::zero());
52
53    let b_norm = vector_norm(b);
54    let tol_threshold = T::Real::from_f64(1e-15).unwrap();
55    if b_norm < tol_threshold {
56        return CgsSolution {
57            x,
58            iterations: 0,
59            residual: T::Real::zero(),
60            converged: true,
61        };
62    }
63
64    // Initial residual
65    let mut r = b.clone();
66    let r0 = r.clone(); // Shadow residual
67
68    let mut rho = inner_product(&r0, &r);
69    let mut p = r.clone();
70    let mut u = r.clone();
71
72    for iter in 0..config.max_iterations {
73        // v = A * p
74        let v = operator.apply(&p);
75
76        let sigma = inner_product(&r0, &v);
77        if sigma.norm() < T::Real::from_f64(1e-30).unwrap() {
78            return CgsSolution {
79                x,
80                iterations: iter,
81                residual: vector_norm(&r) / b_norm,
82                converged: false,
83            };
84        }
85
86        let alpha = rho / sigma;
87
88        // q = u - alpha * v
89        let q = &u - &v.mapv(|vi| vi * alpha);
90
91        // w = A * (u + q)
92        let u_plus_q = &u + &q;
93        let w = operator.apply(&u_plus_q);
94
95        // x = x + alpha * (u + q)
96        x = &x + &u_plus_q.mapv(|ui| ui * alpha);
97
98        // r = r - alpha * w
99        r = &r - &w.mapv(|wi| wi * alpha);
100
101        let rel_residual = vector_norm(&r) / b_norm;
102
103        if config.print_interval > 0 && (iter + 1) % config.print_interval == 0 {
104            log::info!(
105                "CGS iteration {}: relative residual = {:.6e}",
106                iter + 1,
107                rel_residual.to_f64().unwrap_or(0.0)
108            );
109        }
110
111        if rel_residual < config.tolerance {
112            return CgsSolution {
113                x,
114                iterations: iter + 1,
115                residual: rel_residual,
116                converged: true,
117            };
118        }
119
120        let rho_new = inner_product(&r0, &r);
121        if rho.norm() < T::Real::from_f64(1e-30).unwrap() {
122            return CgsSolution {
123                x,
124                iterations: iter + 1,
125                residual: rel_residual,
126                converged: false,
127            };
128        }
129
130        let beta = rho_new / rho;
131        rho = rho_new;
132
133        // u = r + beta * q
134        u = &r + &q.mapv(|qi| qi * beta);
135
136        // p = u + beta * (q + beta * p)
137        let q_plus_beta_p = &q + &p.mapv(|pi| pi * beta);
138        p = &u + &q_plus_beta_p.mapv(|vi| vi * beta);
139    }
140
141    let rel_residual = vector_norm(&r) / b_norm;
142    CgsSolution {
143        x,
144        iterations: config.max_iterations,
145        residual: rel_residual,
146        converged: false,
147    }
148}
149
150#[inline]
151fn inner_product<T: ComplexField>(x: &Array1<T>, y: &Array1<T>) -> T {
152    x.iter()
153        .zip(y.iter())
154        .fold(T::zero(), |acc, (&xi, &yi)| acc + xi.conj() * yi)
155}
156
157#[inline]
158fn vector_norm<T: ComplexField>(x: &Array1<T>) -> T::Real {
159    x.iter()
160        .map(|xi| xi.norm_sqr())
161        .fold(T::Real::zero(), |acc, v| acc + v)
162        .sqrt()
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::CsrMatrix;
    use ndarray::array;
    use num_complex::Complex64;

    /// Solves a 2x2 system with purely real entries and checks both the
    /// convergence flag and the residual of the returned solution.
    #[test]
    fn test_cgs_simple() {
        // Shorthand for a complex number with zero imaginary part.
        let re = |value: f64| Complex64::new(value, 0.0);

        let dense = array![[re(4.0), re(1.0)], [re(1.0), re(3.0)]];
        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![re(1.0), re(2.0)];

        let config = CgsConfig {
            max_iterations: 100,
            tolerance: 1e-10,
            ..CgsConfig::default()
        };

        let solution = cgs(&a, &b, &config);
        assert!(solution.converged, "CGS should converge");

        // Verify the answer independently of the solver's own residual.
        let defect = &a.matvec(&solution.x) - &b;
        let error: f64 = defect.iter().map(|e| e.norm_sqr()).sum::<f64>().sqrt();
        assert!(error < 1e-8, "Solution should satisfy Ax = b");
    }
}