1use crate::traits::{ComplexField, LinearOperator};
7use ndarray::Array1;
8use num_traits::{Float, FromPrimitive, ToPrimitive, Zero};
9
/// Tuning parameters for the conjugate-gradient solver [`cg`].
///
/// `R` is the real scalar type in which the tolerance is expressed.
#[derive(Debug, Clone)]
pub struct CgConfig<R> {
    /// Upper bound on the number of CG iterations performed.
    pub max_iterations: usize,
    /// Stopping threshold on the relative residual `||r|| / ||b||`.
    pub tolerance: R,
    /// Emit a progress log line every `print_interval` iterations; `0` turns logging off.
    pub print_interval: usize,
}
20
21impl Default for CgConfig<f64> {
22 fn default() -> Self {
23 Self {
24 max_iterations: 1000,
25 tolerance: 1e-6,
26 print_interval: 0,
27 }
28 }
29}
30
31#[derive(Debug)]
33pub struct CgSolution<T: ComplexField> {
34 pub x: Array1<T>,
36 pub iterations: usize,
38 pub residual: T::Real,
40 pub converged: bool,
42}
43
44pub fn cg<T, A>(operator: &A, b: &Array1<T>, config: &CgConfig<T::Real>) -> CgSolution<T>
49where
50 T: ComplexField,
51 A: LinearOperator<T>,
52{
53 let n = b.len();
54 let mut x = Array1::from_elem(n, T::zero());
55
56 let b_norm = vector_norm(b);
57 let tol_threshold = T::Real::from_f64(1e-15).unwrap();
58 if b_norm < tol_threshold {
59 return CgSolution {
60 x,
61 iterations: 0,
62 residual: T::Real::zero(),
63 converged: true,
64 };
65 }
66
67 let mut r = b.clone();
69 let mut p = r.clone();
70 let mut rho = inner_product(&r, &r);
71
72 for iter in 0..config.max_iterations {
73 let q = operator.apply(&p);
75
76 let pq = inner_product(&p, &q);
78 if pq.norm() < T::Real::from_f64(1e-30).unwrap() {
79 return CgSolution {
80 x,
81 iterations: iter,
82 residual: vector_norm(&r) / b_norm,
83 converged: false,
84 };
85 }
86
87 let alpha = rho / pq;
88
89 x = &x + &p.mapv(|pi| pi * alpha);
91
92 r = &r - &q.mapv(|qi| qi * alpha);
94
95 let rel_residual = vector_norm(&r) / b_norm;
96
97 if config.print_interval > 0 && (iter + 1) % config.print_interval == 0 {
98 log::info!(
99 "CG iteration {}: relative residual = {:.6e}",
100 iter + 1,
101 rel_residual.to_f64().unwrap_or(0.0)
102 );
103 }
104
105 if rel_residual < config.tolerance {
106 return CgSolution {
107 x,
108 iterations: iter + 1,
109 residual: rel_residual,
110 converged: true,
111 };
112 }
113
114 let rho_new = inner_product(&r, &r);
115 if rho.norm() < T::Real::from_f64(1e-30).unwrap() {
116 return CgSolution {
117 x,
118 iterations: iter + 1,
119 residual: rel_residual,
120 converged: false,
121 };
122 }
123
124 let beta = rho_new / rho;
125 rho = rho_new;
126
127 p = &r + &p.mapv(|pi| pi * beta);
129 }
130
131 let rel_residual = vector_norm(&r) / b_norm;
132 CgSolution {
133 x,
134 iterations: config.max_iterations,
135 residual: rel_residual,
136 converged: false,
137 }
138}
139
140#[inline]
141fn inner_product<T: ComplexField>(x: &Array1<T>, y: &Array1<T>) -> T {
142 x.iter()
143 .zip(y.iter())
144 .fold(T::zero(), |acc, (&xi, &yi)| acc + xi.conj() * yi)
145}
146
147#[inline]
148fn vector_norm<T: ComplexField>(x: &Array1<T>) -> T::Real {
149 x.iter()
150 .map(|xi| xi.norm_sqr())
151 .fold(T::Real::zero(), |acc, v| acc + v)
152 .sqrt()
153}
154
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::CsrMatrix;
    use ndarray::array;

    /// Euclidean norm of a real vector, used to measure residuals and errors in assertions.
    fn l2_norm(v: &Array1<f64>) -> f64 {
        v.iter().map(|e| e * e).sum::<f64>().sqrt()
    }

    #[test]
    fn test_cg_spd() {
        let dense_a = array![[4.0_f64, 1.0], [1.0, 3.0]];
        let matrix = CsrMatrix::from_dense(&dense_a, 1e-15);
        let rhs = array![1.0_f64, 2.0];
        let config = CgConfig {
            max_iterations: 100,
            tolerance: 1e-10,
            print_interval: 0,
        };

        let solution = cg(&matrix, &rhs, &config);
        assert!(solution.converged, "CG should converge for SPD matrix");

        let product = matrix.matvec(&solution.x);
        let residual = l2_norm(&(&product - &rhs));
        assert!(residual < 1e-8, "Solution should satisfy Ax = b");
    }

    #[test]
    fn test_cg_identity() {
        let n = 5;
        let identity: CsrMatrix<f64> = CsrMatrix::identity(n);
        let rhs: Array1<f64> = (1..=n).map(|i| i as f64).collect();
        let config = CgConfig {
            max_iterations: 10,
            tolerance: 1e-12,
            print_interval: 0,
        };

        let solution = cg(&identity, &rhs, &config);

        assert!(solution.converged);
        assert!(solution.iterations <= 2);

        let error = l2_norm(&(&solution.x - &rhs));
        assert!(error < 1e-10);
    }
}