math_audio_solvers/iterative/
cgs.rs1use crate::blas_helpers::{inner_product, vector_norm};
7use crate::traits::{ComplexField, LinearOperator};
8use ndarray::Array1;
9use num_traits::{FromPrimitive, ToPrimitive, Zero};
10
/// Configuration for the [`cgs`] Conjugate Gradient Squared solver.
///
/// `R` is the real scalar type of the tolerance; `cgs` takes a
/// `CgsConfig<T::Real>` for a field type `T`.
#[derive(Debug, Clone, PartialEq)]
pub struct CgsConfig<R> {
    /// Maximum number of iterations before giving up (the solve then reports
    /// `converged == false`).
    pub max_iterations: usize,
    /// Stop once the relative residual `‖r‖ / ‖b‖` drops strictly below this value.
    pub tolerance: R,
    /// Emit a `log::info!` progress line every `print_interval` iterations;
    /// `0` disables progress logging.
    pub print_interval: usize,
}

impl Default for CgsConfig<f64> {
    /// 1000 iterations, relative tolerance `1e-6`, progress logging disabled.
    fn default() -> Self {
        Self {
            max_iterations: 1000,
            tolerance: 1e-6,
            print_interval: 0,
        }
    }
}
31
32#[derive(Debug)]
34pub struct CgsSolution<T: ComplexField> {
35 pub x: Array1<T>,
37 pub iterations: usize,
39 pub residual: T::Real,
41 pub converged: bool,
43}
44
45pub fn cgs<T, A>(operator: &A, b: &Array1<T>, config: &CgsConfig<T::Real>) -> CgsSolution<T>
47where
48 T: ComplexField,
49 A: LinearOperator<T>,
50{
51 let n = b.len();
52 let mut x = Array1::from_elem(n, T::zero());
53
54 let b_norm = vector_norm(b);
55 let tol_threshold = T::Real::from_f64(1e-15).unwrap();
56 if b_norm < tol_threshold {
57 return CgsSolution {
58 x,
59 iterations: 0,
60 residual: T::Real::zero(),
61 converged: true,
62 };
63 }
64
65 let mut r = b.clone();
67 let r0 = r.clone(); let mut rho = inner_product(&r0, &r);
70 let mut p = r.clone();
71 let mut u = r.clone();
72
73 for iter in 0..config.max_iterations {
74 let v = operator.apply(&p);
76
77 let sigma = inner_product(&r0, &v);
78 if sigma.norm() < T::Real::from_f64(1e-30).unwrap() {
79 return CgsSolution {
80 x,
81 iterations: iter,
82 residual: vector_norm(&r) / b_norm,
83 converged: false,
84 };
85 }
86
87 let alpha = rho / sigma;
88
89 let q = &u - &v.mapv(|vi| vi * alpha);
91
92 let u_plus_q = &u + &q;
94 let w = operator.apply(&u_plus_q);
95
96 x = &x + &u_plus_q.mapv(|ui| ui * alpha);
98
99 r = &r - &w.mapv(|wi| wi * alpha);
101
102 let rel_residual = vector_norm(&r) / b_norm;
103
104 if config.print_interval > 0 && (iter + 1) % config.print_interval == 0 {
105 log::info!(
106 "CGS iteration {}: relative residual = {:.6e}",
107 iter + 1,
108 rel_residual.to_f64().unwrap_or(0.0)
109 );
110 }
111
112 if rel_residual < config.tolerance {
113 return CgsSolution {
114 x,
115 iterations: iter + 1,
116 residual: rel_residual,
117 converged: true,
118 };
119 }
120
121 let rho_new = inner_product(&r0, &r);
122 if rho.norm() < T::Real::from_f64(1e-30).unwrap() {
123 return CgsSolution {
124 x,
125 iterations: iter + 1,
126 residual: rel_residual,
127 converged: false,
128 };
129 }
130
131 let beta = rho_new / rho;
132 rho = rho_new;
133
134 u = &r + &q.mapv(|qi| qi * beta);
136
137 let q_plus_beta_p = &q + &p.mapv(|pi| pi * beta);
139 p = &u + &q_plus_beta_p.mapv(|vi| vi * beta);
140 }
141
142 let rel_residual = vector_norm(&r) / b_norm;
143 CgsSolution {
144 x,
145 iterations: config.max_iterations,
146 residual: rel_residual,
147 converged: false,
148 }
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::CsrMatrix;
    use ndarray::array;
    use num_complex::Complex64;

    #[test]
    fn test_cgs_simple() {
        // Symmetric positive-definite 2x2 system.
        let dense = array![
            [Complex64::new(4.0, 0.0), Complex64::new(1.0, 0.0)],
            [Complex64::new(1.0, 0.0), Complex64::new(3.0, 0.0)],
        ];

        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)];

        let config = CgsConfig {
            max_iterations: 100,
            tolerance: 1e-10,
            print_interval: 0,
        };

        let solution = cgs(&a, &b, &config);

        assert!(solution.converged, "CGS should converge");

        let ax = a.matvec(&solution.x);
        let error: f64 = (&ax - &b).iter().map(|e| e.norm_sqr()).sum::<f64>().sqrt();
        assert!(error < 1e-8, "Solution should satisfy Ax = b");
    }

    #[test]
    fn test_cgs_nonsymmetric() {
        // CGS targets non-Hermitian systems; the exact solution here is x = [0.1, 0.6].
        let dense = array![
            [Complex64::new(4.0, 0.0), Complex64::new(1.0, 0.0)],
            [Complex64::new(2.0, 0.0), Complex64::new(3.0, 0.0)],
        ];

        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)];

        let config = CgsConfig {
            max_iterations: 100,
            tolerance: 1e-10,
            print_interval: 0,
        };

        let solution = cgs(&a, &b, &config);

        assert!(solution.converged, "CGS should converge on a nonsymmetric system");

        let expected = [Complex64::new(0.1, 0.0), Complex64::new(0.6, 0.0)];
        for (xi, ei) in solution.x.iter().zip(expected.iter()) {
            assert!((*xi - *ei).norm() < 1e-8, "got {}, expected {}", xi, ei);
        }
    }

    #[test]
    fn test_cgs_zero_rhs() {
        // A zero right-hand side must return x = 0 immediately.
        let dense = array![
            [Complex64::new(4.0, 0.0), Complex64::new(1.0, 0.0)],
            [Complex64::new(1.0, 0.0), Complex64::new(3.0, 0.0)],
        ];

        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![Complex64::new(0.0, 0.0), Complex64::new(0.0, 0.0)];

        let solution = cgs(&a, &b, &CgsConfig::default());

        assert!(solution.converged, "zero RHS is trivially solved");
        assert_eq!(solution.iterations, 0);
        assert_eq!(solution.residual, 0.0);
        assert!(solution.x.iter().all(|xi| *xi == Complex64::new(0.0, 0.0)));
    }
}