bem/core/solver/
cgs.rs

1//! Conjugate Gradient Squared (CGS) solver
2//!
3//! Direct port of NC_IterativeSolverCGS from NC_CommonFunctions.cpp.
4//!
//! The CGS algorithm is based on A. Meister p.168 and typically converges
//! faster than BiCG for non-symmetric systems.
7//!
8//! ## Algorithm
9//!
10//! CGS squares the convergence polynomial of BiCG without needing
11//! the transpose matrix. This makes it suitable for BEM systems.
12
13use ndarray::Array1;
14use num_complex::Complex64;
15
/// Configuration shared by the CGS solvers in this module.
///
/// The solvers only read this struct; it is passed by shared reference.
#[derive(Debug, Clone, PartialEq)]
pub struct CgsConfig {
    /// Maximum number of iterations before the solver gives up.
    pub max_iterations: usize,
    /// Relative tolerance for convergence: the iteration stops as soon as the
    /// residual norm divided by the initial residual norm drops below it.
    pub tolerance: f64,
    /// Print progress to stderr every N iterations (0 = no output).
    pub print_interval: usize,
}
26
27impl Default for CgsConfig {
28    fn default() -> Self {
29        Self {
30            max_iterations: 1000,
31            tolerance: 1e-6,
32            print_interval: 10,
33        }
34    }
35}
36
37/// CGS solver result
38#[derive(Debug)]
39pub struct CgsSolution {
40    /// Solution vector
41    pub x: Array1<Complex64>,
42    /// Number of iterations performed
43    pub iterations: usize,
44    /// Final relative residual
45    pub residual: f64,
46    /// Whether convergence was achieved
47    pub converged: bool,
48}
49
50/// Solve Ax = b using the Conjugate Gradient Squared method
51///
52/// # Arguments
53/// * `matvec` - Function to compute A*x for a given x
54/// * `b` - Right-hand side vector
55/// * `x0` - Optional initial guess (defaults to zero)
56/// * `config` - Solver configuration
57///
58/// # Returns
59/// Solution struct containing x, iteration count, and convergence info
60///
61/// # Example
62/// ```ignore
63/// let config = CgsConfig::default();
64/// let matvec = |x: &Array1<Complex64>| system.matvec(x);
65/// let solution = cgs_solve(&matvec, &rhs, None, &config);
66/// ```
67pub fn cgs_solve<F>(
68    matvec: F,
69    b: &Array1<Complex64>,
70    x0: Option<&Array1<Complex64>>,
71    config: &CgsConfig,
72) -> CgsSolution
73where
74    F: Fn(&Array1<Complex64>) -> Array1<Complex64>,
75{
76    let n = b.len();
77
78    // Initialize solution vector
79    let mut x = match x0 {
80        Some(x0) => x0.clone(),
81        None => Array1::zeros(n),
82    };
83
84    // Initial residual: r = b - A*x
85    let ax = matvec(&x);
86    let mut r: Array1<Complex64> = b - &ax;
87
88    // r̃₀ = r₀ (shadow residual, kept constant)
89    let r_tilde = r.clone();
90
91    // Initialize vectors
92    let mut u = r.clone();
93    let mut p = r.clone();
94
95    // Compute initial residual norm
96    let err_ori = residual_norm(&r);
97    if err_ori < 1e-15 {
98        return CgsSolution {
99            x,
100            iterations: 0,
101            residual: 0.0,
102            converged: true,
103        };
104    }
105
106    // (r₀, r̃₀)
107    let mut rho = inner_product(&r, &r_tilde);
108
109    let mut iterations = 0;
110    let mut err_rel = 1.0;
111
112    for j in 0..config.max_iterations {
113        iterations = j + 1;
114
115        // v = A * p
116        let v = matvec(&p);
117
118        // α = (r_j, r̃₀) / (v, r̃₀)
119        let v_r_tilde = inner_product(&v, &r_tilde);
120        if v_r_tilde.norm() < 1e-30 {
121            // Breakdown - return current solution
122            return CgsSolution {
123                x,
124                iterations,
125                residual: err_rel,
126                converged: false,
127            };
128        }
129        let alpha = rho / v_r_tilde;
130
131        // q = u - α*v
132        let q: Array1<Complex64> = &u - &(&v * alpha);
133
134        // u_q = u + q
135        let u_q: Array1<Complex64> = &u + &q;
136
137        // A*(u + q)
138        let a_uq = matvec(&u_q);
139
140        // x_{j+1} = x_j + α*(u + q)
141        x = &x + &(&u_q * alpha);
142
143        // r_{j+1} = r_j - α*A*(u + q)
144        r = &r - &(&a_uq * alpha);
145
146        // Compute residual norm
147        let r_norm = residual_norm(&r);
148        err_rel = r_norm / err_ori;
149
150        // Print progress
151        if config.print_interval > 0 && j % config.print_interval == 0 {
152            eprintln!("CGS iteration {}: relative residual = {:.6e}", j, err_rel);
153        }
154
155        // Check convergence
156        if err_rel < config.tolerance {
157            return CgsSolution {
158                x,
159                iterations,
160                residual: err_rel,
161                converged: true,
162            };
163        }
164
165        // (r_{j+1}, r̃₀)
166        let rho_new = inner_product(&r, &r_tilde);
167
168        // β = (r_{j+1}, r̃₀) / (r_j, r̃₀)
169        if rho.norm() < 1e-30 {
170            // Breakdown
171            return CgsSolution {
172                x,
173                iterations,
174                residual: err_rel,
175                converged: false,
176            };
177        }
178        let beta = rho_new / rho;
179
180        // u_{j+1} = r_{j+1} + β*q
181        u = &r + &(&q * beta);
182
183        // p_{j+1} = u_{j+1} + β*(q + β*p_j)
184        let q_beta_p: Array1<Complex64> = &q + &(&p * beta);
185        p = &u + &(&q_beta_p * beta);
186
187        rho = rho_new;
188    }
189
190    CgsSolution {
191        x,
192        iterations,
193        residual: err_rel,
194        converged: false,
195    }
196}
197
198/// Compute inner product (x, y) = Σ conj(x_i) * y_i
199fn inner_product(x: &Array1<Complex64>, y: &Array1<Complex64>) -> Complex64 {
200    x.iter()
201        .zip(y.iter())
202        .map(|(xi, yi)| xi.conj() * yi)
203        .sum()
204}
205
206/// Compute residual norm ||r||₂
207fn residual_norm(r: &Array1<Complex64>) -> f64 {
208    r.iter().map(|ri| ri.norm_sqr()).sum::<f64>().sqrt()
209}
210
211/// CGS solver with preconditioner
212///
213/// Solves M⁻¹Ax = M⁻¹b where M is the preconditioner
214pub fn cgs_solve_preconditioned<F, P>(
215    matvec: F,
216    precond_solve: P,
217    b: &Array1<Complex64>,
218    x0: Option<&Array1<Complex64>>,
219    config: &CgsConfig,
220) -> CgsSolution
221where
222    F: Fn(&Array1<Complex64>) -> Array1<Complex64>,
223    P: Fn(&Array1<Complex64>) -> Array1<Complex64>,
224{
225    let n = b.len();
226
227    // Initialize solution vector
228    let mut x = match x0 {
229        Some(x0) => x0.clone(),
230        None => Array1::zeros(n),
231    };
232
233    // Initial residual: r = b - A*x
234    let ax = matvec(&x);
235    let r0: Array1<Complex64> = b - &ax;
236
237    // Apply preconditioner to initial residual
238    let mut r = precond_solve(&r0);
239
240    // r̃₀ = r₀ (shadow residual)
241    let r_tilde = r.clone();
242
243    // Initialize vectors
244    let mut u = r.clone();
245    let mut p = r.clone();
246
247    // Compute initial residual norm
248    let err_ori = residual_norm(&r);
249    if err_ori < 1e-15 {
250        return CgsSolution {
251            x,
252            iterations: 0,
253            residual: 0.0,
254            converged: true,
255        };
256    }
257
258    let mut rho = inner_product(&r, &r_tilde);
259
260    let mut iterations = 0;
261    let mut err_rel = 1.0;
262
263    for j in 0..config.max_iterations {
264        iterations = j + 1;
265
266        // v = M⁻¹ * A * p
267        let ap = matvec(&p);
268        let v = precond_solve(&ap);
269
270        // α = (r_j, r̃₀) / (v, r̃₀)
271        let v_r_tilde = inner_product(&v, &r_tilde);
272        if v_r_tilde.norm() < 1e-30 {
273            return CgsSolution {
274                x,
275                iterations,
276                residual: err_rel,
277                converged: false,
278            };
279        }
280        let alpha = rho / v_r_tilde;
281
282        // q = u - α*v
283        let q: Array1<Complex64> = &u - &(&v * alpha);
284
285        // u_q = u + q
286        let u_q: Array1<Complex64> = &u + &q;
287
288        // A*(u + q)
289        let a_uq = matvec(&u_q);
290        let ma_uq = precond_solve(&a_uq);
291
292        // x_{j+1} = x_j + α*(u + q)
293        x = &x + &(&u_q * alpha);
294
295        // r_{j+1} = r_j - α*M⁻¹*A*(u + q)
296        r = &r - &(&ma_uq * alpha);
297
298        // Compute residual norm
299        let r_norm = residual_norm(&r);
300        err_rel = r_norm / err_ori;
301
302        if config.print_interval > 0 && j % config.print_interval == 0 {
303            eprintln!(
304                "CGS (precond) iteration {}: relative residual = {:.6e}",
305                j, err_rel
306            );
307        }
308
309        if err_rel < config.tolerance {
310            return CgsSolution {
311                x,
312                iterations,
313                residual: err_rel,
314                converged: true,
315            };
316        }
317
318        let rho_new = inner_product(&r, &r_tilde);
319        if rho.norm() < 1e-30 {
320            return CgsSolution {
321                x,
322                iterations,
323                residual: err_rel,
324                converged: false,
325            };
326        }
327        let beta = rho_new / rho;
328
329        u = &r + &(&q * beta);
330        let q_beta_p: Array1<Complex64> = &q + &(&p * beta);
331        p = &u + &(&q_beta_p * beta);
332
333        rho = rho_new;
334    }
335
336    CgsSolution {
337        x,
338        iterations,
339        residual: err_rel,
340        converged: false,
341    }
342}
343
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Shorthand for building a complex number.
    fn c(re: f64, im: f64) -> Complex64 {
        Complex64::new(re, im)
    }

    /// Euclidean distance between two complex vectors.
    fn l2_distance(a: &Array1<Complex64>, b: &Array1<Complex64>) -> f64 {
        (a - b).iter().map(|e| e.norm_sqr()).sum::<f64>().sqrt()
    }

    /// Solver configuration with progress output switched off.
    fn quiet_config(max_iterations: usize, tolerance: f64) -> CgsConfig {
        CgsConfig {
            max_iterations,
            tolerance,
            print_interval: 0,
        }
    }

    /// Real, symmetric positive definite 2x2 matrix used by several tests.
    fn spd_matrix() -> Array2<Complex64> {
        Array2::from_shape_vec(
            (2, 2),
            vec![c(4.0, 0.0), c(1.0, 0.0), c(1.0, 0.0), c(3.0, 0.0)],
        )
        .unwrap()
    }

    #[test]
    fn test_cgs_simple() {
        // Simple 2x2 positive definite system
        let a = spd_matrix();
        let b = Array1::from_vec(vec![c(1.0, 0.0), c(2.0, 0.0)]);

        let matvec = |x: &Array1<Complex64>| a.dot(x);
        let config = quiet_config(100, 1e-10);

        let solution = cgs_solve(&matvec, &b, None, &config);

        assert!(solution.converged, "CGS should converge");

        // Verify solution: Ax ≈ b
        let ax = a.dot(&solution.x);
        assert!(l2_distance(&ax, &b) < 1e-8, "Solution should satisfy Ax = b");
    }

    #[test]
    fn test_cgs_complex() {
        // Complex system
        let a = Array2::from_shape_vec(
            (2, 2),
            vec![c(2.0, 1.0), c(0.0, -1.0), c(0.0, 1.0), c(2.0, -1.0)],
        )
        .unwrap();

        let b = Array1::from_vec(vec![c(1.0, 1.0), c(1.0, -1.0)]);

        let matvec = |x: &Array1<Complex64>| a.dot(x);
        let config = quiet_config(100, 1e-10);

        let solution = cgs_solve(&matvec, &b, None, &config);

        assert!(solution.converged, "CGS should converge for complex system");

        // Verify solution
        let ax = a.dot(&solution.x);
        assert!(l2_distance(&ax, &b) < 1e-8, "Solution should satisfy Ax = b");
    }

    #[test]
    fn test_cgs_identity() {
        // Identity matrix - should converge in 1 iteration
        let n = 5;
        let b = Array1::from_vec((1..=n).map(|i| c(i as f64, 0.0)).collect::<Vec<_>>());

        let matvec = |x: &Array1<Complex64>| x.clone();
        let config = quiet_config(10, 1e-12);

        let solution = cgs_solve(&matvec, &b, None, &config);

        assert!(solution.converged);
        assert!(solution.iterations <= 2); // Should converge very quickly

        // x should equal b for identity matrix
        assert!(l2_distance(&solution.x, &b) < 1e-10);
    }

    #[test]
    fn test_cgs_zero_rhs() {
        // A zero right-hand side with a zero initial guess has a zero initial
        // residual, so the solver must return before iterating.
        let b: Array1<Complex64> = Array1::zeros(3);

        let matvec = |x: &Array1<Complex64>| x.clone();
        let config = quiet_config(10, 1e-12);

        let solution = cgs_solve(&matvec, &b, None, &config);

        assert!(solution.converged);
        assert_eq!(solution.iterations, 0);
        assert_eq!(solution.residual, 0.0);
        assert!(solution.x.iter().all(|xi| xi.norm() == 0.0));
    }

    #[test]
    fn test_cgs_preconditioned_jacobi() {
        // Same SPD system as `test_cgs_simple`, with a Jacobi (inverse
        // diagonal) preconditioner.
        let a = spd_matrix();
        let b = Array1::from_vec(vec![c(1.0, 0.0), c(2.0, 0.0)]);
        let diag_inv = Array1::from_vec(vec![c(1.0 / 4.0, 0.0), c(1.0 / 3.0, 0.0)]);

        let matvec = |x: &Array1<Complex64>| a.dot(x);
        let precond = |y: &Array1<Complex64>| y * &diag_inv;
        let config = quiet_config(100, 1e-10);

        let solution = cgs_solve_preconditioned(&matvec, &precond, &b, None, &config);

        assert!(solution.converged, "Preconditioned CGS should converge");

        // Verify against the ORIGINAL (unpreconditioned) system.
        let ax = a.dot(&solution.x);
        assert!(l2_distance(&ax, &b) < 1e-8, "Solution should satisfy Ax = b");
    }

    #[test]
    fn test_inner_product() {
        let x = Array1::from_vec(vec![c(1.0, 2.0), c(3.0, -1.0)]);
        let y = Array1::from_vec(vec![c(2.0, 0.0), c(0.0, 1.0)]);

        let result = inner_product(&x, &y);

        // (1-2i)*2 + (3+1i)*(1i) = 2-4i + 3i-1 = 1-i
        let expected = c(1.0, -1.0);
        assert!((result - expected).norm() < 1e-10);
    }
}