math_audio_solvers/iterative/
cg.rs1use crate::blas_helpers::{inner_product, vector_norm};
7use crate::traits::{ComplexField, LinearOperator};
8use ndarray::Array1;
9use num_traits::{FromPrimitive, ToPrimitive, Zero};
10
/// Configuration for the conjugate gradient solver [`cg`].
#[derive(Debug, Clone)]
pub struct CgConfig<R> {
    /// Maximum number of CG iterations to perform before giving up.
    pub max_iterations: usize,
    /// Convergence threshold on the relative residual `||r|| / ||b||`;
    /// the solver stops as soon as the relative residual is strictly below it.
    pub tolerance: R,
    /// Emit a `log::info!` progress line every `print_interval` iterations;
    /// `0` disables progress logging.
    pub print_interval: usize,
}
21
22impl Default for CgConfig<f64> {
23 fn default() -> Self {
24 Self {
25 max_iterations: 1000,
26 tolerance: 1e-6,
27 print_interval: 0,
28 }
29 }
30}
31
32#[derive(Debug)]
34pub struct CgSolution<T: ComplexField> {
35 pub x: Array1<T>,
37 pub iterations: usize,
39 pub residual: T::Real,
41 pub converged: bool,
43}
44
45pub fn cg<T, A>(operator: &A, b: &Array1<T>, config: &CgConfig<T::Real>) -> CgSolution<T>
50where
51 T: ComplexField,
52 A: LinearOperator<T>,
53{
54 let n = b.len();
55 let mut x = Array1::from_elem(n, T::zero());
56
57 let b_norm = vector_norm(b);
58 let tol_threshold = T::Real::from_f64(1e-15).unwrap();
59 if b_norm < tol_threshold {
60 return CgSolution {
61 x,
62 iterations: 0,
63 residual: T::Real::zero(),
64 converged: true,
65 };
66 }
67
68 let mut r = b.clone();
70 let mut p = r.clone();
71 let mut rho = inner_product(&r, &r);
72
73 for iter in 0..config.max_iterations {
74 let q = operator.apply(&p);
76
77 let pq = inner_product(&p, &q);
79 if pq.norm() < T::Real::from_f64(1e-30).unwrap() {
80 return CgSolution {
81 x,
82 iterations: iter,
83 residual: vector_norm(&r) / b_norm,
84 converged: false,
85 };
86 }
87
88 let alpha = rho / pq;
89
90 x = &x + &p.mapv(|pi| pi * alpha);
92
93 r = &r - &q.mapv(|qi| qi * alpha);
95
96 let rel_residual = vector_norm(&r) / b_norm;
97
98 if config.print_interval > 0 && (iter + 1) % config.print_interval == 0 {
99 log::info!(
100 "CG iteration {}: relative residual = {:.6e}",
101 iter + 1,
102 rel_residual.to_f64().unwrap_or(0.0)
103 );
104 }
105
106 if rel_residual < config.tolerance {
107 return CgSolution {
108 x,
109 iterations: iter + 1,
110 residual: rel_residual,
111 converged: true,
112 };
113 }
114
115 let rho_new = inner_product(&r, &r);
116 if rho.norm() < T::Real::from_f64(1e-30).unwrap() {
117 return CgSolution {
118 x,
119 iterations: iter + 1,
120 residual: rel_residual,
121 converged: false,
122 };
123 }
124
125 let beta = rho_new / rho;
126 rho = rho_new;
127
128 p = &r + &p.mapv(|pi| pi * beta);
130 }
131
132 let rel_residual = vector_norm(&r) / b_norm;
133 CgSolution {
134 x,
135 iterations: config.max_iterations,
136 residual: rel_residual,
137 converged: false,
138 }
139}
140
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::CsrMatrix;
    use ndarray::array;

    #[test]
    fn test_cg_spd() {
        let dense = array![[4.0_f64, 1.0], [1.0, 3.0],];

        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![1.0_f64, 2.0];

        let config = CgConfig {
            max_iterations: 100,
            tolerance: 1e-10,
            print_interval: 0,
        };

        let solution = cg(&a, &b, &config);

        assert!(solution.converged, "CG should converge for SPD matrix");

        let ax = a.matvec(&solution.x);
        let error: f64 = (&ax - &b).iter().map(|e| e * e).sum::<f64>().sqrt();
        assert!(error < 1e-8, "Solution should satisfy Ax = b");
    }

    #[test]
    fn test_cg_identity() {
        let n = 5;
        let id: CsrMatrix<f64> = CsrMatrix::identity(n);
        let b = Array1::from_iter((1..=n).map(|i| i as f64));

        let config = CgConfig {
            max_iterations: 10,
            tolerance: 1e-12,
            print_interval: 0,
        };

        let solution = cg(&id, &b, &config);

        assert!(solution.converged);
        assert!(solution.iterations <= 2);

        let error: f64 = (&solution.x - &b).iter().map(|e| e * e).sum::<f64>().sqrt();
        assert!(error < 1e-10);
    }

    /// A zero right-hand side takes the early-exit path: x = 0, converged, no iterations.
    #[test]
    fn test_cg_zero_rhs() {
        let id: CsrMatrix<f64> = CsrMatrix::identity(3);
        let b = Array1::from_elem(3, 0.0_f64);

        let config = CgConfig {
            max_iterations: 10,
            tolerance: 1e-10,
            print_interval: 0,
        };

        let solution = cg(&id, &b, &config);

        assert!(solution.converged);
        assert_eq!(solution.iterations, 0);
        assert_eq!(solution.residual, 0.0);
        assert!(solution.x.iter().all(|&v| v == 0.0));
    }

    /// With no iterations allowed, the initial guess is returned unconverged
    /// with relative residual ||b|| / ||b|| = 1.
    #[test]
    fn test_cg_zero_iterations() {
        let id: CsrMatrix<f64> = CsrMatrix::identity(3);
        let b = array![1.0_f64, 2.0, 3.0];

        let config = CgConfig {
            max_iterations: 0,
            tolerance: 1e-10,
            print_interval: 0,
        };

        let solution = cg(&id, &b, &config);

        assert!(!solution.converged);
        assert_eq!(solution.iterations, 0);
        assert!((solution.residual - 1.0).abs() < 1e-12);
        assert!(solution.x.iter().all(|&v| v == 0.0));
    }
}