// math_audio_solvers/iterative/cg.rs
use std::cmp::Ordering;

use ndarray::Array1;
use num_traits::{FromPrimitive, ToPrimitive, Zero};

use crate::blas_helpers::{inner_product, vector_norm};
use crate::traits::{ComplexField, LinearOperator, SolverStatus};

/// Stopping criteria and logging options shared by [`cg`] and [`pcg`].
///
/// `R` is the real scalar type in which the tolerance is expressed.
#[derive(Debug, Clone)]
pub struct CgConfig<R> {
    /// Upper bound on the number of iterations performed.
    pub max_iterations: usize,
    /// Threshold on the relative residual `‖r‖ / ‖b‖`; the solvers report
    /// convergence once the relative residual is strictly below this value.
    pub tolerance: R,
    /// Emit a `log::info!` progress line every `print_interval` iterations;
    /// `0` turns progress logging off.
    pub print_interval: usize,
}

22impl Default for CgConfig<f64> {
23 fn default() -> Self {
24 Self {
25 max_iterations: 1000,
26 tolerance: 1e-6,
27 print_interval: 0,
28 }
29 }
30}
31
32#[derive(Debug)]
34pub struct CgSolution<T: ComplexField> {
35 pub x: Array1<T>,
37 pub iterations: usize,
39 pub residual: T::Real,
41 pub converged: bool,
43 pub status: SolverStatus,
45}
46
47pub fn cg<T, A>(operator: &A, b: &Array1<T>, config: &CgConfig<T::Real>) -> CgSolution<T>
52where
53 T: ComplexField,
54 A: LinearOperator<T>,
55{
56 let n = b.len();
57 let mut x = Array1::from_elem(n, T::zero());
58
59 let b_norm = vector_norm(b);
60 let tol_threshold = T::Real::from_f64(1e-15).unwrap();
61 if b_norm < tol_threshold {
62 return CgSolution {
63 x,
64 iterations: 0,
65 residual: T::Real::zero(),
66 converged: true,
67 status: SolverStatus::Converged,
68 };
69 }
70
71 let mut r = b.clone();
73 let mut p = r.clone();
74 let mut rho = inner_product(&r, &r);
75
76 for iter in 0..config.max_iterations {
77 let q = operator.apply(&p);
79
80 let pq = inner_product(&p, &q);
82 if pq.norm() < T::Real::from_f64(1e-20).unwrap() {
83 return CgSolution {
84 x,
85 iterations: iter,
86 residual: vector_norm(&r) / b_norm,
87 converged: false,
88 status: SolverStatus::Breakdown,
89 };
90 }
91
92 let alpha = rho / pq;
93
94 x = &x + &p.mapv(|pi| pi * alpha);
96
97 r = &r - &q.mapv(|qi| qi * alpha);
99
100 let rel_residual = vector_norm(&r) / b_norm;
101
102 if config.print_interval > 0 && (iter + 1) % config.print_interval == 0 {
103 log::info!(
104 "CG iteration {}: relative residual = {:.6e}",
105 iter + 1,
106 rel_residual.to_f64().unwrap_or(0.0)
107 );
108 }
109
110 if rel_residual < config.tolerance {
111 return CgSolution {
112 x,
113 iterations: iter + 1,
114 residual: rel_residual,
115 converged: true,
116 status: SolverStatus::Converged,
117 };
118 }
119
120 let rho_new = inner_product(&r, &r);
121 if rho_new.norm() < T::Real::from_f64(1e-20).unwrap() {
122 return CgSolution {
123 x,
124 iterations: iter + 1,
125 residual: rel_residual,
126 converged: false,
127 status: SolverStatus::Breakdown,
128 };
129 }
130
131 let beta = rho_new / rho;
132 rho = rho_new;
133
134 p = &r + &p.mapv(|pi| pi * beta);
136 }
137
138 let rel_residual = vector_norm(&r) / b_norm;
139 CgSolution {
140 x,
141 iterations: config.max_iterations,
142 residual: rel_residual,
143 converged: false,
144 status: SolverStatus::MaxIterationsReached,
145 }
146}
147
148pub fn pcg<T, A, P>(
156 operator: &A,
157 precond: &P,
158 b: &Array1<T>,
159 config: &CgConfig<T::Real>,
160) -> CgSolution<T>
161where
162 T: ComplexField,
163 A: LinearOperator<T>,
164 P: crate::traits::Preconditioner<T>,
165{
166 let n = b.len();
167 let mut x = Array1::from_elem(n, T::zero());
168
169 let b_norm = vector_norm(b);
170 let tol_threshold = T::Real::from_f64(1e-15).unwrap();
171 if b_norm < tol_threshold {
172 return CgSolution {
173 x,
174 iterations: 0,
175 residual: T::Real::zero(),
176 converged: true,
177 status: SolverStatus::Converged,
178 };
179 }
180
181 let mut r = b.clone();
183 let mut z = precond.apply(&r);
184 let mut p = z.clone();
185 let mut rho = inner_product(&r, &z);
186
187 for iter in 0..config.max_iterations {
188 let q = operator.apply(&p);
189
190 let pq = inner_product(&p, &q);
191 if pq.norm() < T::Real::from_f64(1e-20).unwrap() {
192 return CgSolution {
193 x,
194 iterations: iter,
195 residual: vector_norm(&r) / b_norm,
196 converged: false,
197 status: SolverStatus::Breakdown,
198 };
199 }
200
201 let alpha = rho / pq;
202
203 x = &x + &p.mapv(|pi| pi * alpha);
205
206 r = &r - &q.mapv(|qi| qi * alpha);
208
209 let rel_residual = vector_norm(&r) / b_norm;
210
211 if config.print_interval > 0 && (iter + 1) % config.print_interval == 0 {
212 log::info!(
213 "PCG iteration {}: relative residual = {:.6e}",
214 iter + 1,
215 rel_residual.to_f64().unwrap_or(0.0)
216 );
217 }
218
219 if rel_residual < config.tolerance {
220 return CgSolution {
221 x,
222 iterations: iter + 1,
223 residual: rel_residual,
224 converged: true,
225 status: SolverStatus::Converged,
226 };
227 }
228
229 z = precond.apply(&r);
230 let rho_new = inner_product(&r, &z);
231 if rho_new.norm() < T::Real::from_f64(1e-20).unwrap() {
232 return CgSolution {
233 x,
234 iterations: iter + 1,
235 residual: rel_residual,
236 converged: false,
237 status: SolverStatus::Breakdown,
238 };
239 }
240
241 let beta = rho_new / rho;
242 rho = rho_new;
243
244 p = &z + &p.mapv(|pi| pi * beta);
246 }
247
248 let rel_residual = vector_norm(&r) / b_norm;
249 CgSolution {
250 x,
251 iterations: config.max_iterations,
252 residual: rel_residual,
253 converged: false,
254 status: SolverStatus::MaxIterationsReached,
255 }
256}
257
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::CsrMatrix;
    use ndarray::array;

    /// Euclidean norm of a real vector; used to measure solution errors.
    fn euclidean_norm(v: &Array1<f64>) -> f64 {
        v.iter().map(|e| e * e).sum::<f64>().sqrt()
    }

    /// A small symmetric positive-definite system must be solved accurately.
    #[test]
    fn test_cg_spd() {
        let a = CsrMatrix::from_dense(&array![[4.0_f64, 1.0], [1.0, 3.0]], 1e-15);
        let b = array![1.0_f64, 2.0];
        let settings = CgConfig {
            max_iterations: 100,
            tolerance: 1e-10,
            print_interval: 0,
        };

        let solution = cg(&a, &b, &settings);
        assert!(solution.converged, "CG should converge for SPD matrix");

        let residual_norm = euclidean_norm(&(&a.matvec(&solution.x) - &b));
        assert!(residual_norm < 1e-8, "Solution should satisfy Ax = b");
    }

    /// With the identity operator, CG should recover `b` almost immediately.
    #[test]
    fn test_cg_identity() {
        let dim = 5;
        let id: CsrMatrix<f64> = CsrMatrix::identity(dim);
        let b: Array1<f64> = (1..=dim).map(|i| i as f64).collect();
        let settings = CgConfig {
            max_iterations: 10,
            tolerance: 1e-12,
            print_interval: 0,
        };

        let solution = cg(&id, &b, &settings);
        assert!(solution.converged);
        assert!(solution.iterations <= 2);

        let error_norm = euclidean_norm(&(&solution.x - &b));
        assert!(error_norm < 1e-10);
    }
}