1use ndarray::Array1;
14use num_complex::Complex64;
15
/// Settings that control the CGS iteration.
///
/// `PartialEq` is derived so configurations can be compared directly (for
/// example in tests or when caching solver setups).
#[derive(Debug, Clone, PartialEq)]
pub struct CgsConfig {
    /// Upper bound on the number of CGS iterations performed.
    pub max_iterations: usize,
    /// Relative residual below which the iteration is reported as converged.
    pub tolerance: f64,
    /// Progress is written to stderr every `print_interval` iterations;
    /// `0` disables progress output.
    pub print_interval: usize,
}

impl Default for CgsConfig {
    /// Returns 1000 iterations, a tolerance of `1e-6`, and progress output
    /// every 10 iterations.
    fn default() -> Self {
        Self {
            max_iterations: 1000,
            tolerance: 1e-6,
            print_interval: 10,
        }
    }
}
36
37#[derive(Debug)]
39pub struct CgsSolution {
40 pub x: Array1<Complex64>,
42 pub iterations: usize,
44 pub residual: f64,
46 pub converged: bool,
48}
49
50pub fn cgs_solve<F>(
68 matvec: F,
69 b: &Array1<Complex64>,
70 x0: Option<&Array1<Complex64>>,
71 config: &CgsConfig,
72) -> CgsSolution
73where
74 F: Fn(&Array1<Complex64>) -> Array1<Complex64>,
75{
76 let n = b.len();
77
78 let mut x = match x0 {
80 Some(x0) => x0.clone(),
81 None => Array1::zeros(n),
82 };
83
84 let ax = matvec(&x);
86 let mut r: Array1<Complex64> = b - &ax;
87
88 let r_tilde = r.clone();
90
91 let mut u = r.clone();
93 let mut p = r.clone();
94
95 let err_ori = residual_norm(&r);
97 if err_ori < 1e-15 {
98 return CgsSolution {
99 x,
100 iterations: 0,
101 residual: 0.0,
102 converged: true,
103 };
104 }
105
106 let mut rho = inner_product(&r, &r_tilde);
108
109 let mut iterations = 0;
110 let mut err_rel = 1.0;
111
112 for j in 0..config.max_iterations {
113 iterations = j + 1;
114
115 let v = matvec(&p);
117
118 let v_r_tilde = inner_product(&v, &r_tilde);
120 if v_r_tilde.norm() < 1e-30 {
121 return CgsSolution {
123 x,
124 iterations,
125 residual: err_rel,
126 converged: false,
127 };
128 }
129 let alpha = rho / v_r_tilde;
130
131 let q: Array1<Complex64> = &u - &(&v * alpha);
133
134 let u_q: Array1<Complex64> = &u + &q;
136
137 let a_uq = matvec(&u_q);
139
140 x = &x + &(&u_q * alpha);
142
143 r = &r - &(&a_uq * alpha);
145
146 let r_norm = residual_norm(&r);
148 err_rel = r_norm / err_ori;
149
150 if config.print_interval > 0 && j % config.print_interval == 0 {
152 eprintln!("CGS iteration {}: relative residual = {:.6e}", j, err_rel);
153 }
154
155 if err_rel < config.tolerance {
157 return CgsSolution {
158 x,
159 iterations,
160 residual: err_rel,
161 converged: true,
162 };
163 }
164
165 let rho_new = inner_product(&r, &r_tilde);
167
168 if rho.norm() < 1e-30 {
170 return CgsSolution {
172 x,
173 iterations,
174 residual: err_rel,
175 converged: false,
176 };
177 }
178 let beta = rho_new / rho;
179
180 u = &r + &(&q * beta);
182
183 let q_beta_p: Array1<Complex64> = &q + &(&p * beta);
185 p = &u + &(&q_beta_p * beta);
186
187 rho = rho_new;
188 }
189
190 CgsSolution {
191 x,
192 iterations,
193 residual: err_rel,
194 converged: false,
195 }
196}
197
198fn inner_product(x: &Array1<Complex64>, y: &Array1<Complex64>) -> Complex64 {
200 x.iter()
201 .zip(y.iter())
202 .map(|(xi, yi)| xi.conj() * yi)
203 .sum()
204}
205
206fn residual_norm(r: &Array1<Complex64>) -> f64 {
208 r.iter().map(|ri| ri.norm_sqr()).sum::<f64>().sqrt()
209}
210
211pub fn cgs_solve_preconditioned<F, P>(
215 matvec: F,
216 precond_solve: P,
217 b: &Array1<Complex64>,
218 x0: Option<&Array1<Complex64>>,
219 config: &CgsConfig,
220) -> CgsSolution
221where
222 F: Fn(&Array1<Complex64>) -> Array1<Complex64>,
223 P: Fn(&Array1<Complex64>) -> Array1<Complex64>,
224{
225 let n = b.len();
226
227 let mut x = match x0 {
229 Some(x0) => x0.clone(),
230 None => Array1::zeros(n),
231 };
232
233 let ax = matvec(&x);
235 let r0: Array1<Complex64> = b - &ax;
236
237 let mut r = precond_solve(&r0);
239
240 let r_tilde = r.clone();
242
243 let mut u = r.clone();
245 let mut p = r.clone();
246
247 let err_ori = residual_norm(&r);
249 if err_ori < 1e-15 {
250 return CgsSolution {
251 x,
252 iterations: 0,
253 residual: 0.0,
254 converged: true,
255 };
256 }
257
258 let mut rho = inner_product(&r, &r_tilde);
259
260 let mut iterations = 0;
261 let mut err_rel = 1.0;
262
263 for j in 0..config.max_iterations {
264 iterations = j + 1;
265
266 let ap = matvec(&p);
268 let v = precond_solve(&ap);
269
270 let v_r_tilde = inner_product(&v, &r_tilde);
272 if v_r_tilde.norm() < 1e-30 {
273 return CgsSolution {
274 x,
275 iterations,
276 residual: err_rel,
277 converged: false,
278 };
279 }
280 let alpha = rho / v_r_tilde;
281
282 let q: Array1<Complex64> = &u - &(&v * alpha);
284
285 let u_q: Array1<Complex64> = &u + &q;
287
288 let a_uq = matvec(&u_q);
290 let ma_uq = precond_solve(&a_uq);
291
292 x = &x + &(&u_q * alpha);
294
295 r = &r - &(&ma_uq * alpha);
297
298 let r_norm = residual_norm(&r);
300 err_rel = r_norm / err_ori;
301
302 if config.print_interval > 0 && j % config.print_interval == 0 {
303 eprintln!(
304 "CGS (precond) iteration {}: relative residual = {:.6e}",
305 j, err_rel
306 );
307 }
308
309 if err_rel < config.tolerance {
310 return CgsSolution {
311 x,
312 iterations,
313 residual: err_rel,
314 converged: true,
315 };
316 }
317
318 let rho_new = inner_product(&r, &r_tilde);
319 if rho.norm() < 1e-30 {
320 return CgsSolution {
321 x,
322 iterations,
323 residual: err_rel,
324 converged: false,
325 };
326 }
327 let beta = rho_new / rho;
328
329 u = &r + &(&q * beta);
330 let q_beta_p: Array1<Complex64> = &q + &(&p * beta);
331 p = &u + &(&q_beta_p * beta);
332
333 rho = rho_new;
334 }
335
336 CgsSolution {
337 x,
338 iterations,
339 residual: err_rel,
340 converged: false,
341 }
342}
343
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Complex number with zero imaginary part.
    fn re(value: f64) -> Complex64 {
        Complex64::new(value, 0.0)
    }

    /// Solver configuration with progress output switched off.
    fn quiet_config(max_iterations: usize, tolerance: f64) -> CgsConfig {
        CgsConfig {
            max_iterations,
            tolerance,
            print_interval: 0,
        }
    }

    /// Euclidean norm of `lhs - rhs`.
    fn distance(lhs: &Array1<Complex64>, rhs: &Array1<Complex64>) -> f64 {
        (lhs - rhs).iter().map(|e| e.norm_sqr()).sum::<f64>().sqrt()
    }

    #[test]
    fn test_cgs_simple() {
        let a = Array2::from_shape_vec((2, 2), vec![re(4.0), re(1.0), re(1.0), re(3.0)]).unwrap();
        let b = Array1::from_vec(vec![re(1.0), re(2.0)]);

        let matvec = |x: &Array1<Complex64>| a.dot(x);
        let config = quiet_config(100, 1e-10);

        let solution = cgs_solve(&matvec, &b, None, &config);

        assert!(solution.converged, "CGS should converge");

        let error = distance(&a.dot(&solution.x), &b);
        assert!(error < 1e-8, "Solution should satisfy Ax = b");
    }

    #[test]
    fn test_cgs_complex() {
        let a = Array2::from_shape_vec(
            (2, 2),
            vec![
                Complex64::new(2.0, 1.0),
                Complex64::new(0.0, -1.0),
                Complex64::new(0.0, 1.0),
                Complex64::new(2.0, -1.0),
            ],
        )
        .unwrap();
        let b = Array1::from_vec(vec![Complex64::new(1.0, 1.0), Complex64::new(1.0, -1.0)]);

        let matvec = |x: &Array1<Complex64>| a.dot(x);
        let config = quiet_config(100, 1e-10);

        let solution = cgs_solve(&matvec, &b, None, &config);

        assert!(solution.converged, "CGS should converge for complex system");

        let error = distance(&a.dot(&solution.x), &b);
        assert!(error < 1e-8, "Solution should satisfy Ax = b");
    }

    #[test]
    fn test_cgs_identity() {
        let n = 5;
        let b = Array1::from_vec((1..=n).map(|i| re(i as f64)).collect::<Vec<_>>());

        let matvec = |x: &Array1<Complex64>| x.clone();
        let config = quiet_config(10, 1e-12);

        let solution = cgs_solve(&matvec, &b, None, &config);

        assert!(solution.converged);
        assert!(solution.iterations <= 2);

        let error = distance(&solution.x, &b);
        assert!(error < 1e-10);
    }

    #[test]
    fn test_cgs_zero_rhs_returns_immediately() {
        // With b = 0 and the default zero initial guess the residual is zero,
        // so the solver must return before entering the iteration loop.
        let b: Array1<Complex64> = Array1::zeros(3);
        let matvec = |x: &Array1<Complex64>| x.clone();

        let solution = cgs_solve(&matvec, &b, None, &CgsConfig::default());

        assert!(solution.converged);
        assert_eq!(solution.iterations, 0);
        assert_eq!(solution.residual, 0.0);
    }

    #[test]
    fn test_cgs_preconditioned_jacobi() {
        let a = Array2::from_shape_vec((2, 2), vec![re(4.0), re(1.0), re(1.0), re(3.0)]).unwrap();
        let b = Array1::from_vec(vec![re(1.0), re(2.0)]);

        // Jacobi preconditioner: multiply by the inverse of the diagonal of `a`.
        let inv_diag = Array1::from_vec(vec![re(1.0 / 4.0), re(1.0 / 3.0)]);

        let matvec = |x: &Array1<Complex64>| a.dot(x);
        let precond = |v: &Array1<Complex64>| v * &inv_diag;
        let config = quiet_config(100, 1e-10);

        let solution = cgs_solve_preconditioned(&matvec, &precond, &b, None, &config);

        assert!(solution.converged, "preconditioned CGS should converge");

        let error = distance(&a.dot(&solution.x), &b);
        assert!(error < 1e-8, "Solution should satisfy Ax = b");
    }

    #[test]
    fn test_inner_product() {
        let x = Array1::from_vec(vec![Complex64::new(1.0, 2.0), Complex64::new(3.0, -1.0)]);
        let y = Array1::from_vec(vec![Complex64::new(2.0, 0.0), Complex64::new(0.0, 1.0)]);

        let result = inner_product(&x, &y);

        // conj(1+2i)*2 + conj(3-i)*i = (2-4i) + (-1+3i) = 1 - i
        let expected = Complex64::new(1.0, -1.0);
        assert!((result - expected).norm() < 1e-10);
    }
}