math_audio_solvers/iterative/
cgs.rs1use crate::blas_helpers::{inner_product, vector_norm};
7use crate::traits::{ComplexField, LinearOperator, SolverStatus};
8use ndarray::Array1;
9use num_traits::{FromPrimitive, ToPrimitive, Zero};
10
/// Tuning knobs for the [`cgs`] solver.
///
/// `R` is the real scalar type in which the convergence tolerance is expressed.
#[derive(Debug, Clone)]
pub struct CgsConfig<R> {
    /// Maximum number of CGS iterations before giving up.
    pub max_iterations: usize,
    /// Stop once the relative residual `||r|| / ||b||` drops strictly below this value.
    pub tolerance: R,
    /// Emit a progress log line every `print_interval` iterations; `0` turns logging off.
    pub print_interval: usize,
}
21
22impl Default for CgsConfig<f64> {
23 fn default() -> Self {
24 Self {
25 max_iterations: 1000,
26 tolerance: 1e-6,
27 print_interval: 0,
28 }
29 }
30}
31
32#[derive(Debug)]
34pub struct CgsSolution<T: ComplexField> {
35 pub x: Array1<T>,
37 pub iterations: usize,
39 pub residual: T::Real,
41 pub converged: bool,
43 pub status: SolverStatus,
45}
46
47pub fn cgs<T, A>(operator: &A, b: &Array1<T>, config: &CgsConfig<T::Real>) -> CgsSolution<T>
49where
50 T: ComplexField,
51 A: LinearOperator<T>,
52{
53 let n = b.len();
54 let mut x = Array1::from_elem(n, T::zero());
55
56 let b_norm = vector_norm(b);
57 let tol_threshold = T::Real::from_f64(1e-15).unwrap();
58 if b_norm < tol_threshold {
59 return CgsSolution {
60 x,
61 iterations: 0,
62 residual: T::Real::zero(),
63 converged: true,
64 status: SolverStatus::Converged,
65 };
66 }
67
68 let mut r = b.clone();
70 let r0 = r.clone(); let mut rho = inner_product(&r0, &r);
73 let mut p = r.clone();
74 let mut u = r.clone();
75
76 for iter in 0..config.max_iterations {
77 let v = operator.apply(&p);
79
80 let sigma = inner_product(&r0, &v);
81 if sigma.norm() < T::Real::from_f64(1e-20).unwrap() {
82 return CgsSolution {
83 x,
84 iterations: iter,
85 residual: vector_norm(&r) / b_norm,
86 converged: false,
87 status: SolverStatus::Breakdown,
88 };
89 }
90
91 let alpha = rho / sigma;
92
93 let q = &u - &v.mapv(|vi| vi * alpha);
95
96 let u_plus_q = &u + &q;
98 let w = operator.apply(&u_plus_q);
99
100 x = &x + &u_plus_q.mapv(|ui| ui * alpha);
102
103 r = &r - &w.mapv(|wi| wi * alpha);
105
106 let rel_residual = vector_norm(&r) / b_norm;
107
108 if config.print_interval > 0 && (iter + 1) % config.print_interval == 0 {
109 log::info!(
110 "CGS iteration {}: relative residual = {:.6e}",
111 iter + 1,
112 rel_residual.to_f64().unwrap_or(0.0)
113 );
114 }
115
116 if rel_residual < config.tolerance {
117 return CgsSolution {
118 x,
119 iterations: iter + 1,
120 residual: rel_residual,
121 converged: true,
122 status: SolverStatus::Converged,
123 };
124 }
125
126 let rho_new = inner_product(&r0, &r);
127 if rho_new.norm() < T::Real::from_f64(1e-20).unwrap() {
128 return CgsSolution {
129 x,
130 iterations: iter + 1,
131 residual: rel_residual,
132 converged: false,
133 status: SolverStatus::Breakdown,
134 };
135 }
136
137 let beta = rho_new / rho;
138 rho = rho_new;
139
140 u = &r + &q.mapv(|qi| qi * beta);
142
143 let q_plus_beta_p = &q + &p.mapv(|pi| pi * beta);
145 p = &u + &q_plus_beta_p.mapv(|vi| vi * beta);
146 }
147
148 let rel_residual = vector_norm(&r) / b_norm;
149 CgsSolution {
150 x,
151 iterations: config.max_iterations,
152 residual: rel_residual,
153 converged: false,
154 status: SolverStatus::MaxIterationsReached,
155 }
156}
157
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::CsrMatrix;
    use ndarray::array;
    use num_complex::Complex64;

    /// Small symmetric positive-definite 2x2 system matrix shared by the tests.
    fn sample_dense() -> ndarray::Array2<Complex64> {
        array![
            [Complex64::new(4.0, 0.0), Complex64::new(1.0, 0.0)],
            [Complex64::new(1.0, 0.0), Complex64::new(3.0, 0.0)],
        ]
    }

    /// CGS solves a well-conditioned 2x2 system to high accuracy.
    #[test]
    fn test_cgs_simple() {
        let a = CsrMatrix::from_dense(&sample_dense(), 1e-15);
        let b = array![Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)];

        let config = CgsConfig {
            max_iterations: 100,
            tolerance: 1e-10,
            print_interval: 0,
        };

        let solution = cgs(&a, &b, &config);

        assert!(solution.converged, "CGS should converge");

        let ax = a.matvec(&solution.x);
        let error: f64 = (&ax - &b).iter().map(|e| e.norm_sqr()).sum::<f64>().sqrt();
        assert!(error < 1e-8, "Solution should satisfy Ax = b");
    }

    /// A zero right-hand side returns the zero vector immediately, without iterating.
    #[test]
    fn test_cgs_zero_rhs() {
        let a = CsrMatrix::from_dense(&sample_dense(), 1e-15);
        let b = array![Complex64::new(0.0, 0.0), Complex64::new(0.0, 0.0)];

        let config = CgsConfig {
            max_iterations: 100,
            tolerance: 1e-10,
            print_interval: 0,
        };

        let solution = cgs(&a, &b, &config);

        assert!(solution.converged);
        assert_eq!(solution.iterations, 0);
        assert!(matches!(solution.status, SolverStatus::Converged));
        assert!(solution.x.iter().all(|xi| *xi == Complex64::new(0.0, 0.0)));
    }

    /// With no iteration budget the initial guess is returned with relative residual 1.
    #[test]
    fn test_cgs_zero_iteration_budget() {
        let a = CsrMatrix::from_dense(&sample_dense(), 1e-15);
        let b = array![Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)];

        let config = CgsConfig {
            max_iterations: 0,
            tolerance: 1e-10,
            print_interval: 0,
        };

        let solution = cgs(&a, &b, &config);

        assert!(!solution.converged);
        assert_eq!(solution.iterations, 0);
        assert!(matches!(solution.status, SolverStatus::MaxIterationsReached));
        assert!((solution.residual - 1.0).abs() < 1e-12);
    }
}