1use crate::accumulator::binned_sum_f64;
8use crate::sparse::SparseCsr;
9
/// Outcome of an iterative linear solve.
///
/// Every solver in this module returns this type, including when it stops early
/// because of a numerical breakdown or because it ran out of iterations.
#[derive(Debug, Clone, PartialEq)]
pub struct SolverResult {
    /// Final approximate solution vector.
    pub x: Vec<f64>,
    /// Number of iterations the solver performed before returning.
    pub iterations: usize,
    /// Relative residual `||b - A x|| / ||b||` as tracked by the solver. Depending on
    /// the solver and exit path this may be a recursively updated or estimated value
    /// rather than a freshly recomputed one; it is `0.0` when `b` is zero.
    pub residual: f64,
    /// `true` when the solver met the requested tolerance (or `b` was zero).
    pub converged: bool,
}
24
25fn dot(a: &[f64], b: &[f64]) -> f64 {
31 let products: Vec<f64> = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).collect();
32 binned_sum_f64(&products)
33}
34
35fn norm2(v: &[f64]) -> f64 {
37 dot(v, v).sqrt()
38}
39
/// In-place BLAS-style update `y += a * x`.
///
/// Only the first `min(x.len(), y.len())` entries of `y` are modified.
fn axpy(a: f64, x: &[f64], y: &mut [f64]) {
    y.iter_mut()
        .zip(x)
        .for_each(|(target, &source)| *target += a * source);
}
46
47fn compute_residual(a: &SparseCsr, x: &[f64], b: &[f64]) -> Vec<f64> {
49 let ax = spmv(a, x);
50 b.iter().zip(ax.iter()).map(|(&bi, &axi)| bi - axi).collect()
51}
52
53fn spmv(a: &SparseCsr, x: &[f64]) -> Vec<f64> {
55 let mut y = vec![0.0f64; a.nrows];
56 for row in 0..a.nrows {
57 let start = a.row_offsets[row];
58 let end = a.row_offsets[row + 1];
59 let products: Vec<f64> = (start..end)
60 .map(|idx| a.values[idx] * x[a.col_indices[idx]])
61 .collect();
62 y[row] = binned_sum_f64(&products);
63 }
64 y
65}
66
67pub fn cg_solve(a: &SparseCsr, b: &[f64], tol: f64, max_iter: usize) -> SolverResult {
82 let n = b.len();
83 assert_eq!(a.nrows, n, "cg_solve: matrix rows must match rhs length");
84 assert_eq!(a.ncols, n, "cg_solve: matrix must be square");
85
86 let mut x = vec![0.0f64; n];
87 let mut r = b.to_vec(); let mut p = r.clone();
89 let mut rs_old = dot(&r, &r);
90 let b_norm = norm2(b);
91
92 if b_norm == 0.0 {
93 return SolverResult {
94 x,
95 iterations: 0,
96 residual: 0.0,
97 converged: true,
98 };
99 }
100
101 for iter in 0..max_iter {
102 let ap = spmv(a, &p);
103 let p_ap = dot(&p, &ap);
104
105 if p_ap == 0.0 {
106 return SolverResult {
108 x,
109 iterations: iter + 1,
110 residual: norm2(&r) / b_norm,
111 converged: false,
112 };
113 }
114
115 let alpha = rs_old / p_ap;
116
117 axpy(alpha, &p, &mut x);
119 axpy(-alpha, &ap, &mut r);
121
122 let rs_new = dot(&r, &r);
123 let res_norm = rs_new.sqrt() / b_norm;
124
125 if res_norm < tol {
126 return SolverResult {
127 x,
128 iterations: iter + 1,
129 residual: res_norm,
130 converged: true,
131 };
132 }
133
134 let beta = rs_new / rs_old;
135 for i in 0..n {
137 p[i] = r[i] + beta * p[i];
138 }
139 rs_old = rs_new;
140 }
141
142 SolverResult {
143 x,
144 iterations: max_iter,
145 residual: norm2(&r) / b_norm,
146 converged: false,
147 }
148}
149
150pub fn gmres_solve(
166 a: &SparseCsr,
167 b: &[f64],
168 tol: f64,
169 max_iter: usize,
170 restart: usize,
171) -> SolverResult {
172 let n = b.len();
173 assert_eq!(a.nrows, n, "gmres_solve: matrix rows must match rhs length");
174 assert_eq!(a.ncols, n, "gmres_solve: matrix must be square");
175
176 let b_norm = norm2(b);
177 if b_norm == 0.0 {
178 return SolverResult {
179 x: vec![0.0; n],
180 iterations: 0,
181 residual: 0.0,
182 converged: true,
183 };
184 }
185
186 let mut x = vec![0.0f64; n];
187 let mut total_iter = 0;
188
189 while total_iter < max_iter {
190 let r = compute_residual(a, &x, b);
191 let r_norm = norm2(&r);
192
193 if r_norm / b_norm < tol {
194 return SolverResult {
195 x,
196 iterations: total_iter,
197 residual: r_norm / b_norm,
198 converged: true,
199 };
200 }
201
202 let m = restart.min(max_iter - total_iter);
204
205 let mut v: Vec<Vec<f64>> = Vec::with_capacity(m + 1);
207 let v0: Vec<f64> = r.iter().map(|&ri| ri / r_norm).collect();
209 v.push(v0);
210
211 let mut h = vec![vec![0.0f64; m]; m + 1];
213
214 let mut cs = vec![0.0f64; m];
216 let mut sn = vec![0.0f64; m];
217
218 let mut g = vec![0.0f64; m + 1];
220 g[0] = r_norm;
221
222 let mut last_res = r_norm / b_norm;
223
224 for j in 0..m {
225 total_iter += 1;
226
227 let mut w = spmv(a, &v[j]);
229
230 for i in 0..=j {
232 h[i][j] = dot(&w, &v[i]);
233 axpy(-h[i][j], &v[i], &mut w);
235 }
236
237 h[j + 1][j] = norm2(&w);
238
239 if h[j + 1][j] > 1e-14 {
240 let inv = 1.0 / h[j + 1][j];
241 let vn: Vec<f64> = w.iter().map(|&wi| wi * inv).collect();
242 v.push(vn);
243 } else {
244 v.push(vec![0.0; n]);
246 }
247
248 for i in 0..j {
250 let tmp = cs[i] * h[i][j] + sn[i] * h[i + 1][j];
251 h[i + 1][j] = -sn[i] * h[i][j] + cs[i] * h[i + 1][j];
252 h[i][j] = tmp;
253 }
254
255 let (c, s) = givens_rotation(h[j][j], h[j + 1][j]);
257 cs[j] = c;
258 sn[j] = s;
259
260 h[j][j] = c * h[j][j] + s * h[j + 1][j];
261 h[j + 1][j] = 0.0;
262
263 let tmp = cs[j] * g[j] + sn[j] * g[j + 1];
265 g[j + 1] = -sn[j] * g[j] + cs[j] * g[j + 1];
266 g[j] = tmp;
267
268 last_res = g[j + 1].abs() / b_norm;
269
270 if last_res < tol {
271 let y = solve_upper_triangular(&h, &g, j + 1);
273 update_solution(&mut x, &v, &y, j + 1);
275 return SolverResult {
276 x,
277 iterations: total_iter,
278 residual: last_res,
279 converged: true,
280 };
281 }
282 }
283
284 let y = solve_upper_triangular(&h, &g, m);
286 update_solution(&mut x, &v, &y, m);
287
288 if last_res < tol {
289 return SolverResult {
290 x,
291 iterations: total_iter,
292 residual: last_res,
293 converged: true,
294 };
295 }
296 }
297
298 let r = compute_residual(a, &x, b);
299 SolverResult {
300 x,
301 iterations: total_iter,
302 residual: norm2(&r) / b_norm,
303 converged: false,
304 }
305}
306
/// Computes a Givens rotation `(c, s)` that annihilates `b` in the pair `(a, b)`:
/// `[c s; -s c] · [a; b] = [ρ; 0]`, with `c² + s² = 1`.
///
/// The ratio of the smaller to the larger magnitude is squared, so the intermediate
/// `1 + ratio²` cannot overflow. `b == 0` yields the identity rotation `(1, 0)`.
fn givens_rotation(a: f64, b: f64) -> (f64, f64) {
    if b == 0.0 {
        return (1.0, 0.0);
    }

    if a.abs() > b.abs() {
        // |a| dominates: derive c first, then s = c * (b / a).
        let ratio = b / a;
        let cosine = (1.0 + ratio * ratio).sqrt().recip();
        (cosine, cosine * ratio)
    } else {
        // |b| dominates (or ties): derive s first, then c = s * (a / b).
        let ratio = a / b;
        let sine = (1.0 + ratio * ratio).sqrt().recip();
        (sine * ratio, sine)
    }
}
325
326fn solve_upper_triangular(h: &[Vec<f64>], g: &[f64], k: usize) -> Vec<f64> {
328 let mut y = vec![0.0f64; k];
329 for i in (0..k).rev() {
330 let mut sum_terms: Vec<f64> = Vec::with_capacity(k - i);
331 sum_terms.push(g[i]);
332 for j in (i + 1)..k {
333 sum_terms.push(-h[i][j] * y[j]);
334 }
335 let s = binned_sum_f64(&sum_terms);
336 if h[i][i].abs() > 1e-30 {
337 y[i] = s / h[i][i];
338 }
339 }
340 y
341}
342
343fn update_solution(x: &mut [f64], v: &[Vec<f64>], y: &[f64], k: usize) {
345 for i in 0..k {
346 axpy(y[i], &v[i], x);
347 }
348}
349
350pub fn bicgstab_solve(
365 a: &SparseCsr,
366 b: &[f64],
367 tol: f64,
368 max_iter: usize,
369) -> SolverResult {
370 let n = b.len();
371 assert_eq!(a.nrows, n, "bicgstab_solve: matrix rows must match rhs length");
372 assert_eq!(a.ncols, n, "bicgstab_solve: matrix must be square");
373
374 let b_norm = norm2(b);
375 if b_norm == 0.0 {
376 return SolverResult {
377 x: vec![0.0; n],
378 iterations: 0,
379 residual: 0.0,
380 converged: true,
381 };
382 }
383
384 let mut x = vec![0.0f64; n];
385 let mut r = b.to_vec(); let r0_hat = r.clone(); let mut rho = 1.0f64;
389 let mut alpha = 1.0f64;
390 let mut omega = 1.0f64;
391
392 let mut v = vec![0.0f64; n];
393 let mut p = vec![0.0f64; n];
394
395 for iter in 0..max_iter {
396 let rho_new = dot(&r0_hat, &r);
397
398 if rho_new.abs() < 1e-30 {
399 return SolverResult {
401 x,
402 iterations: iter + 1,
403 residual: norm2(&r) / b_norm,
404 converged: false,
405 };
406 }
407
408 let beta = (rho_new / rho) * (alpha / omega);
409 rho = rho_new;
410
411 for i in 0..n {
413 p[i] = r[i] + beta * (p[i] - omega * v[i]);
414 }
415
416 v = spmv(a, &p);
418
419 let r0_v = dot(&r0_hat, &v);
420 if r0_v.abs() < 1e-30 {
421 return SolverResult {
422 x,
423 iterations: iter + 1,
424 residual: norm2(&r) / b_norm,
425 converged: false,
426 };
427 }
428
429 alpha = rho / r0_v;
430
431 let s: Vec<f64> = r.iter().zip(v.iter()).map(|(&ri, &vi)| ri - alpha * vi).collect();
433
434 let s_norm = norm2(&s) / b_norm;
435 if s_norm < tol {
436 axpy(alpha, &p, &mut x);
438 return SolverResult {
439 x,
440 iterations: iter + 1,
441 residual: s_norm,
442 converged: true,
443 };
444 }
445
446 let t = spmv(a, &s);
448
449 let t_t = dot(&t, &t);
450 if t_t.abs() < 1e-30 {
451 axpy(alpha, &p, &mut x);
452 return SolverResult {
453 x,
454 iterations: iter + 1,
455 residual: norm2(&s) / b_norm,
456 converged: false,
457 };
458 }
459
460 omega = dot(&t, &s) / t_t;
461
462 axpy(alpha, &p, &mut x);
464 axpy(omega, &s, &mut x);
465
466 r = s.iter().zip(t.iter()).map(|(&si, &ti)| si - omega * ti).collect();
468
469 let res_norm = norm2(&r) / b_norm;
470 if res_norm < tol {
471 return SolverResult {
472 x,
473 iterations: iter + 1,
474 residual: res_norm,
475 converged: true,
476 };
477 }
478
479 if omega.abs() < 1e-30 {
480 return SolverResult {
481 x,
482 iterations: iter + 1,
483 residual: res_norm,
484 converged: false,
485 };
486 }
487 }
488
489 let res = norm2(&r) / b_norm;
490 SolverResult {
491 x,
492 iterations: max_iter,
493 residual: res,
494 converged: false,
495 }
496}
497
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::SparseCsr;

    /// Row-major data of a 3×3 non-symmetric matrix (entry (1,2) is 1, (2,1) is 0).
    const NONSYM_3X3: [f64; 9] = [4.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 0.0, 2.0];

    /// Builds a CSR matrix from a row-major dense `nrows × ncols` array, keeping only
    /// the entries that are not exactly zero.
    fn csr_from_dense(data: &[f64], nrows: usize, ncols: usize) -> SparseCsr {
        let mut values = Vec::new();
        let mut col_indices = Vec::new();
        let mut row_offsets = vec![0usize];

        for r in 0..nrows {
            let row = &data[r * ncols..(r + 1) * ncols];
            for (c, &entry) in row.iter().enumerate().filter(|&(_, &e)| e != 0.0) {
                values.push(entry);
                col_indices.push(c);
            }
            // Each row ends where the next one begins.
            row_offsets.push(values.len());
        }

        SparseCsr {
            values,
            col_indices,
            row_offsets,
            nrows,
            ncols,
        }
    }

    /// `n × n` symmetric tridiagonal matrix with 4 on the diagonal and -1 on both
    /// off-diagonals; strictly diagonally dominant, hence SPD.
    fn tridiag_spd(n: usize) -> SparseCsr {
        let mut dense = vec![0.0; n * n];
        for i in 0..n {
            let row_start = i * n;
            dense[row_start + i] = 4.0;
            if let Some(left) = i.checked_sub(1) {
                dense[row_start + left] = -1.0;
            }
            if i + 1 < n {
                dense[row_start + i + 1] = -1.0;
            }
        }
        csr_from_dense(&dense, n, n)
    }

    /// CG solves an SPD tridiagonal system, and `A x` reproduces `b`.
    #[test]
    fn test_cg_tridiag() {
        let n = 10;
        let a = tridiag_spd(n);
        let b: Vec<f64> = (1..=n).map(|i| i as f64).collect();

        let result = cg_solve(&a, &b, 1e-10, 100);
        assert!(result.converged, "CG did not converge: residual={}", result.residual);
        assert!(result.residual < 1e-10);

        // Check A x against b directly rather than trusting the reported residual.
        let ax = spmv(&a, &result.x);
        for (i, (&axi, &bi)) in ax.iter().zip(&b).enumerate() {
            assert!(
                (axi - bi).abs() < 1e-8,
                "CG solution mismatch at i={}: ax={} b={}",
                i,
                axi,
                bi
            );
        }
    }

    #[test]
    fn test_cg_identity() {
        let eye = csr_from_dense(&[1.0, 0.0, 0.0, 1.0], 2, 2);
        let result = cg_solve(&eye, &[3.0, 7.0], 1e-12, 10);
        assert!(result.converged);
        assert!((result.x[0] - 3.0).abs() < 1e-10);
        assert!((result.x[1] - 7.0).abs() < 1e-10);
    }

    #[test]
    fn test_cg_zero_rhs() {
        let result = cg_solve(&tridiag_spd(5), &[0.0; 5], 1e-10, 100);
        assert!(result.converged);
        assert_eq!(result.iterations, 0);
        result.x.iter().for_each(|&xi| assert_eq!(xi, 0.0));
    }

    /// Two identical CG runs must agree bit for bit.
    #[test]
    fn test_cg_determinism() {
        let a = tridiag_spd(20);
        let b: Vec<f64> = (0..20).map(|i| f64::from(i).sin()).collect();

        let (first, second) = (cg_solve(&a, &b, 1e-10, 200), cg_solve(&a, &b, 1e-10, 200));

        assert_eq!(first.x, second.x, "CG is not deterministic");
        assert_eq!(first.iterations, second.iterations);
        assert_eq!(first.residual, second.residual);
    }

    /// GMRES handles a non-symmetric system, and `A x` reproduces `b`.
    #[test]
    fn test_gmres_nonsymmetric() {
        let a = csr_from_dense(&NONSYM_3X3, 3, 3);
        let b = vec![1.0, 2.0, 3.0];

        let result = gmres_solve(&a, &b, 1e-10, 100, 30);
        assert!(result.converged, "GMRES did not converge: residual={}", result.residual);

        let ax = spmv(&a, &result.x);
        for (i, (&axi, &bi)) in ax.iter().zip(&b).enumerate() {
            assert!(
                (axi - bi).abs() < 1e-8,
                "GMRES mismatch at i={}: ax={} b={}",
                i,
                axi,
                bi
            );
        }
    }

    #[test]
    fn test_gmres_identity() {
        let eye = csr_from_dense(&[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0], 3, 3);
        let b = vec![5.0, 6.0, 7.0];
        let result = gmres_solve(&eye, &b, 1e-12, 10, 10);
        assert!(result.converged, "GMRES did not converge: residual={}", result.residual);
        for (i, (&xi, &bi)) in result.x.iter().zip(&b).enumerate() {
            assert!(
                (xi - bi).abs() < 1e-10,
                "GMRES identity mismatch at i={}: x={} b={}",
                i,
                xi,
                bi
            );
        }
    }

    #[test]
    fn test_gmres_zero_rhs() {
        let a = csr_from_dense(&[2.0, 1.0, 0.0, 3.0], 2, 2);
        let result = gmres_solve(&a, &[0.0, 0.0], 1e-10, 100, 10);
        assert!(result.converged);
        assert_eq!(result.iterations, 0);
    }

    /// Two identical GMRES runs must agree bit for bit.
    #[test]
    fn test_gmres_determinism() {
        let a = csr_from_dense(&NONSYM_3X3, 3, 3);
        let b = vec![1.0, 2.0, 3.0];

        let (first, second) = (
            gmres_solve(&a, &b, 1e-10, 100, 30),
            gmres_solve(&a, &b, 1e-10, 100, 30),
        );

        assert_eq!(first.x, second.x, "GMRES is not deterministic");
        assert_eq!(first.iterations, second.iterations);
    }

    /// BiCGSTAB handles a non-symmetric system, and `A x` reproduces `b`.
    #[test]
    fn test_bicgstab_nonsymmetric() {
        let a = csr_from_dense(&NONSYM_3X3, 3, 3);
        let b = vec![1.0, 2.0, 3.0];

        let result = bicgstab_solve(&a, &b, 1e-10, 100);
        assert!(result.converged, "BiCGSTAB did not converge: residual={}", result.residual);

        let ax = spmv(&a, &result.x);
        for (i, (&axi, &bi)) in ax.iter().zip(&b).enumerate() {
            assert!(
                (axi - bi).abs() < 1e-8,
                "BiCGSTAB mismatch at i={}: ax={} b={}",
                i,
                axi,
                bi
            );
        }
    }

    #[test]
    fn test_bicgstab_spd() {
        let a = tridiag_spd(10);
        let b: Vec<f64> = (1..=10).map(|i| i as f64).collect();

        let result = bicgstab_solve(&a, &b, 1e-10, 200);
        assert!(result.converged, "BiCGSTAB did not converge: residual={}", result.residual);

        let ax = spmv(&a, &result.x);
        for (i, (&axi, &bi)) in ax.iter().zip(&b).enumerate() {
            assert!((axi - bi).abs() < 1e-8, "BiCGSTAB SPD mismatch at i={}", i);
        }
    }

    #[test]
    fn test_bicgstab_identity() {
        let eye = csr_from_dense(&[1.0, 0.0, 0.0, 1.0], 2, 2);
        let result = bicgstab_solve(&eye, &[3.0, 7.0], 1e-12, 10);
        assert!(result.converged);
        assert!((result.x[0] - 3.0).abs() < 1e-10);
        assert!((result.x[1] - 7.0).abs() < 1e-10);
    }

    #[test]
    fn test_bicgstab_zero_rhs() {
        let a = csr_from_dense(&[2.0, 1.0, 0.0, 3.0], 2, 2);
        let result = bicgstab_solve(&a, &[0.0, 0.0], 1e-10, 100);
        assert!(result.converged);
        assert_eq!(result.iterations, 0);
    }

    /// Two identical BiCGSTAB runs must agree bit for bit.
    #[test]
    fn test_bicgstab_determinism() {
        let a = tridiag_spd(15);
        let b: Vec<f64> = (0..15).map(|i| (f64::from(i) * 0.7).cos()).collect();

        let (first, second) = (
            bicgstab_solve(&a, &b, 1e-10, 200),
            bicgstab_solve(&a, &b, 1e-10, 200),
        );

        assert_eq!(first.x, second.x, "BiCGSTAB is not deterministic");
        assert_eq!(first.iterations, second.iterations);
        assert_eq!(first.residual, second.residual);
    }

    /// On an SPD system all three solvers must land on the same solution.
    #[test]
    fn test_solvers_agree_on_spd_system() {
        let a = tridiag_spd(8);
        let b: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        let cg = cg_solve(&a, &b, 1e-12, 200);
        let gmres = gmres_solve(&a, &b, 1e-12, 200, 20);
        let bicg = bicgstab_solve(&a, &b, 1e-12, 200);

        assert!(cg.converged);
        assert!(gmres.converged);
        assert!(bicg.converged);

        let triples = cg.x.iter().zip(&gmres.x).zip(&bicg.x);
        for (i, ((&xc, &xg), &xb)) in triples.enumerate() {
            assert!(
                (xc - xg).abs() < 1e-8,
                "CG vs GMRES disagree at i={}: {} vs {}",
                i,
                xc,
                xg
            );
            assert!(
                (xc - xb).abs() < 1e-8,
                "CG vs BiCGSTAB disagree at i={}: {} vs {}",
                i,
                xc,
                xb
            );
        }
    }
}