//! Conjugate Gradient Squared (CGS) iterative solver
//! (`math_audio_solvers/iterative/cgs.rs`).
use crate::traits::{ComplexField, LinearOperator};
use ndarray::Array1;
use num_traits::{Float, FromPrimitive, ToPrimitive, Zero};

/// Tuning parameters for the [`cgs`] solver.
///
/// `R` is the real scalar type of the tolerance (e.g. `f64`).
#[derive(Clone, Debug)]
pub struct CgsConfig<R> {
    /// Upper bound on the number of CGS iterations (each one applies the operator twice).
    pub max_iterations: usize,
    /// Convergence threshold: the solve stops as soon as the relative residual
    /// `||r|| / ||b||` is strictly below this value.
    pub tolerance: R,
    /// Emit a `log::info!` progress line every `print_interval` iterations;
    /// `0` disables progress logging.
    pub print_interval: usize,
}

impl Default for CgsConfig<f64> {
    /// At most 1000 iterations, a relative tolerance of `1e-6`, and no progress logging.
    fn default() -> Self {
        CgsConfig {
            tolerance: 1e-6,
            max_iterations: 1000,
            print_interval: 0,
        }
    }
}

31#[derive(Debug)]
33pub struct CgsSolution<T: ComplexField> {
34 pub x: Array1<T>,
36 pub iterations: usize,
38 pub residual: T::Real,
40 pub converged: bool,
42}
43
44pub fn cgs<T, A>(operator: &A, b: &Array1<T>, config: &CgsConfig<T::Real>) -> CgsSolution<T>
46where
47 T: ComplexField,
48 A: LinearOperator<T>,
49{
50 let n = b.len();
51 let mut x = Array1::from_elem(n, T::zero());
52
53 let b_norm = vector_norm(b);
54 let tol_threshold = T::Real::from_f64(1e-15).unwrap();
55 if b_norm < tol_threshold {
56 return CgsSolution {
57 x,
58 iterations: 0,
59 residual: T::Real::zero(),
60 converged: true,
61 };
62 }
63
64 let mut r = b.clone();
66 let r0 = r.clone(); let mut rho = inner_product(&r0, &r);
69 let mut p = r.clone();
70 let mut u = r.clone();
71
72 for iter in 0..config.max_iterations {
73 let v = operator.apply(&p);
75
76 let sigma = inner_product(&r0, &v);
77 if sigma.norm() < T::Real::from_f64(1e-30).unwrap() {
78 return CgsSolution {
79 x,
80 iterations: iter,
81 residual: vector_norm(&r) / b_norm,
82 converged: false,
83 };
84 }
85
86 let alpha = rho / sigma;
87
88 let q = &u - &v.mapv(|vi| vi * alpha);
90
91 let u_plus_q = &u + &q;
93 let w = operator.apply(&u_plus_q);
94
95 x = &x + &u_plus_q.mapv(|ui| ui * alpha);
97
98 r = &r - &w.mapv(|wi| wi * alpha);
100
101 let rel_residual = vector_norm(&r) / b_norm;
102
103 if config.print_interval > 0 && (iter + 1) % config.print_interval == 0 {
104 log::info!(
105 "CGS iteration {}: relative residual = {:.6e}",
106 iter + 1,
107 rel_residual.to_f64().unwrap_or(0.0)
108 );
109 }
110
111 if rel_residual < config.tolerance {
112 return CgsSolution {
113 x,
114 iterations: iter + 1,
115 residual: rel_residual,
116 converged: true,
117 };
118 }
119
120 let rho_new = inner_product(&r0, &r);
121 if rho.norm() < T::Real::from_f64(1e-30).unwrap() {
122 return CgsSolution {
123 x,
124 iterations: iter + 1,
125 residual: rel_residual,
126 converged: false,
127 };
128 }
129
130 let beta = rho_new / rho;
131 rho = rho_new;
132
133 u = &r + &q.mapv(|qi| qi * beta);
135
136 let q_plus_beta_p = &q + &p.mapv(|pi| pi * beta);
138 p = &u + &q_plus_beta_p.mapv(|vi| vi * beta);
139 }
140
141 let rel_residual = vector_norm(&r) / b_norm;
142 CgsSolution {
143 x,
144 iterations: config.max_iterations,
145 residual: rel_residual,
146 converged: false,
147 }
148}
149
150#[inline]
151fn inner_product<T: ComplexField>(x: &Array1<T>, y: &Array1<T>) -> T {
152 x.iter()
153 .zip(y.iter())
154 .fold(T::zero(), |acc, (&xi, &yi)| acc + xi.conj() * yi)
155}
156
157#[inline]
158fn vector_norm<T: ComplexField>(x: &Array1<T>) -> T::Real {
159 x.iter()
160 .map(|xi| xi.norm_sqr())
161 .fold(T::Real::zero(), |acc, v| acc + v)
162 .sqrt()
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::CsrMatrix;
    use ndarray::array;
    use num_complex::Complex64;

    /// Symmetric positive-definite 2x2 system.
    #[test]
    fn test_cgs_simple() {
        let dense = array![
            [Complex64::new(4.0, 0.0), Complex64::new(1.0, 0.0)],
            [Complex64::new(1.0, 0.0), Complex64::new(3.0, 0.0)],
        ];

        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)];

        let config = CgsConfig {
            max_iterations: 100,
            tolerance: 1e-10,
            print_interval: 0,
        };

        let solution = cgs(&a, &b, &config);

        assert!(solution.converged, "CGS should converge");

        let ax = a.matvec(&solution.x);
        let error: f64 = (&ax - &b).iter().map(|e| e.norm_sqr()).sum::<f64>().sqrt();
        assert!(error < 1e-8, "Solution should satisfy Ax = b");
    }

    /// Complex, non-symmetric (non-Hermitian) system, which is the case CGS targets.
    #[test]
    fn test_cgs_nonsymmetric_complex() {
        let dense = array![
            [Complex64::new(3.0, 1.0), Complex64::new(1.0, 0.0)],
            [Complex64::new(-1.0, 0.0), Complex64::new(2.0, -1.0)],
        ];

        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![Complex64::new(1.0, 0.0), Complex64::new(0.0, 2.0)];

        let config = CgsConfig {
            max_iterations: 100,
            tolerance: 1e-10,
            print_interval: 0,
        };

        let solution = cgs(&a, &b, &config);

        assert!(solution.converged, "CGS should converge on a non-symmetric system");

        let ax = a.matvec(&solution.x);
        let error: f64 = (&ax - &b).iter().map(|e| e.norm_sqr()).sum::<f64>().sqrt();
        assert!(error < 1e-8, "Solution should satisfy Ax = b, error = {}", error);
    }

    /// A zero right-hand side takes the early-return path without iterating.
    #[test]
    fn test_cgs_zero_rhs() {
        let dense = array![
            [Complex64::new(2.0, 0.0), Complex64::new(0.0, 0.0)],
            [Complex64::new(0.0, 0.0), Complex64::new(5.0, 0.0)],
        ];

        let a = CsrMatrix::from_dense(&dense, 1e-15);
        let b = array![Complex64::new(0.0, 0.0), Complex64::new(0.0, 0.0)];

        let solution = cgs(&a, &b, &CgsConfig::default());

        assert!(solution.converged, "zero RHS should be reported as converged");
        assert_eq!(solution.iterations, 0);
        assert!(solution.residual == 0.0);
        assert!(solution.x.iter().all(|xi| xi.norm() == 0.0));
    }
}