1use crate::accumulator::binned_sum_f64;
8use crate::sparse::SparseCsr;
9
/// Outcome of an iterative linear solve (`cg_solve`, `gmres_solve`, `bicgstab_solve`).
#[derive(Debug, Clone, PartialEq)]
pub struct SolverResult {
    /// Final approximation to the solution of `A x = b`.
    pub x: Vec<f64>,
    /// Number of iterations the solver performed before returning.
    pub iterations: usize,
    /// Relative residual `‖b − A x‖ / ‖b‖` as reported by the solver. Depending on
    /// the solver and exit path this may be a recursively updated estimate rather
    /// than a freshly recomputed residual.
    pub residual: f64,
    /// `true` if the reported relative residual met the requested tolerance.
    pub converged: bool,
}
22
23fn dot(a: &[f64], b: &[f64]) -> f64 {
29 let products: Vec<f64> = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).collect();
30 binned_sum_f64(&products)
31}
32
33fn norm2(v: &[f64]) -> f64 {
35 dot(v, v).sqrt()
36}
37
/// In-place `y ← y + a·x`, element-wise.
///
/// Only the first `min(x.len(), y.len())` entries of `y` are updated.
fn axpy(a: f64, x: &[f64], y: &mut [f64]) {
    y.iter_mut().zip(x).for_each(|(yi, xi)| *yi += a * xi);
}
44
45fn compute_residual(a: &SparseCsr, x: &[f64], b: &[f64]) -> Vec<f64> {
47 let ax = spmv(a, x);
48 b.iter().zip(ax.iter()).map(|(&bi, &axi)| bi - axi).collect()
49}
50
51fn spmv(a: &SparseCsr, x: &[f64]) -> Vec<f64> {
53 let mut y = vec![0.0f64; a.nrows];
54 for row in 0..a.nrows {
55 let start = a.row_offsets[row];
56 let end = a.row_offsets[row + 1];
57 let products: Vec<f64> = (start..end)
58 .map(|idx| a.values[idx] * x[a.col_indices[idx]])
59 .collect();
60 y[row] = binned_sum_f64(&products);
61 }
62 y
63}
64
65pub fn cg_solve(a: &SparseCsr, b: &[f64], tol: f64, max_iter: usize) -> SolverResult {
80 let n = b.len();
81 assert_eq!(a.nrows, n, "cg_solve: matrix rows must match rhs length");
82 assert_eq!(a.ncols, n, "cg_solve: matrix must be square");
83
84 let mut x = vec![0.0f64; n];
85 let mut r = b.to_vec(); let mut p = r.clone();
87 let mut rs_old = dot(&r, &r);
88 let b_norm = norm2(b);
89
90 if b_norm == 0.0 {
91 return SolverResult {
92 x,
93 iterations: 0,
94 residual: 0.0,
95 converged: true,
96 };
97 }
98
99 for iter in 0..max_iter {
100 let ap = spmv(a, &p);
101 let p_ap = dot(&p, &ap);
102
103 if p_ap == 0.0 {
104 return SolverResult {
106 x,
107 iterations: iter + 1,
108 residual: norm2(&r) / b_norm,
109 converged: false,
110 };
111 }
112
113 let alpha = rs_old / p_ap;
114
115 axpy(alpha, &p, &mut x);
117 axpy(-alpha, &ap, &mut r);
119
120 let rs_new = dot(&r, &r);
121 let res_norm = rs_new.sqrt() / b_norm;
122
123 if res_norm < tol {
124 return SolverResult {
125 x,
126 iterations: iter + 1,
127 residual: res_norm,
128 converged: true,
129 };
130 }
131
132 let beta = rs_new / rs_old;
133 for i in 0..n {
135 p[i] = r[i] + beta * p[i];
136 }
137 rs_old = rs_new;
138 }
139
140 SolverResult {
141 x,
142 iterations: max_iter,
143 residual: norm2(&r) / b_norm,
144 converged: false,
145 }
146}
147
148pub fn gmres_solve(
164 a: &SparseCsr,
165 b: &[f64],
166 tol: f64,
167 max_iter: usize,
168 restart: usize,
169) -> SolverResult {
170 let n = b.len();
171 assert_eq!(a.nrows, n, "gmres_solve: matrix rows must match rhs length");
172 assert_eq!(a.ncols, n, "gmres_solve: matrix must be square");
173
174 let b_norm = norm2(b);
175 if b_norm == 0.0 {
176 return SolverResult {
177 x: vec![0.0; n],
178 iterations: 0,
179 residual: 0.0,
180 converged: true,
181 };
182 }
183
184 let mut x = vec![0.0f64; n];
185 let mut total_iter = 0;
186
187 while total_iter < max_iter {
188 let r = compute_residual(a, &x, b);
189 let r_norm = norm2(&r);
190
191 if r_norm / b_norm < tol {
192 return SolverResult {
193 x,
194 iterations: total_iter,
195 residual: r_norm / b_norm,
196 converged: true,
197 };
198 }
199
200 let m = restart.min(max_iter - total_iter);
202
203 let mut v: Vec<Vec<f64>> = Vec::with_capacity(m + 1);
205 let v0: Vec<f64> = r.iter().map(|&ri| ri / r_norm).collect();
207 v.push(v0);
208
209 let mut h = vec![vec![0.0f64; m]; m + 1];
211
212 let mut cs = vec![0.0f64; m];
214 let mut sn = vec![0.0f64; m];
215
216 let mut g = vec![0.0f64; m + 1];
218 g[0] = r_norm;
219
220 let mut last_res = r_norm / b_norm;
221
222 for j in 0..m {
223 total_iter += 1;
224
225 let mut w = spmv(a, &v[j]);
227
228 for i in 0..=j {
230 h[i][j] = dot(&w, &v[i]);
231 axpy(-h[i][j], &v[i], &mut w);
233 }
234
235 h[j + 1][j] = norm2(&w);
236
237 if h[j + 1][j] > 1e-14 {
238 let inv = 1.0 / h[j + 1][j];
239 let vn: Vec<f64> = w.iter().map(|&wi| wi * inv).collect();
240 v.push(vn);
241 } else {
242 v.push(vec![0.0; n]);
244 }
245
246 for i in 0..j {
248 let tmp = cs[i] * h[i][j] + sn[i] * h[i + 1][j];
249 h[i + 1][j] = -sn[i] * h[i][j] + cs[i] * h[i + 1][j];
250 h[i][j] = tmp;
251 }
252
253 let (c, s) = givens_rotation(h[j][j], h[j + 1][j]);
255 cs[j] = c;
256 sn[j] = s;
257
258 h[j][j] = c * h[j][j] + s * h[j + 1][j];
259 h[j + 1][j] = 0.0;
260
261 let tmp = cs[j] * g[j] + sn[j] * g[j + 1];
263 g[j + 1] = -sn[j] * g[j] + cs[j] * g[j + 1];
264 g[j] = tmp;
265
266 last_res = g[j + 1].abs() / b_norm;
267
268 if last_res < tol {
269 let y = solve_upper_triangular(&h, &g, j + 1);
271 update_solution(&mut x, &v, &y, j + 1);
273 return SolverResult {
274 x,
275 iterations: total_iter,
276 residual: last_res,
277 converged: true,
278 };
279 }
280 }
281
282 let y = solve_upper_triangular(&h, &g, m);
284 update_solution(&mut x, &v, &y, m);
285
286 if last_res < tol {
287 return SolverResult {
288 x,
289 iterations: total_iter,
290 residual: last_res,
291 converged: true,
292 };
293 }
294 }
295
296 let r = compute_residual(a, &x, b);
297 SolverResult {
298 x,
299 iterations: total_iter,
300 residual: norm2(&r) / b_norm,
301 converged: false,
302 }
303}
304
/// Computes a Givens rotation `(c, s)`, with `c² + s² = 1`, that annihilates `b`:
/// `[c s; −s c] · [a; b] = [ρ; 0]`.
///
/// The function divides by whichever of `|a|` and `|b|` is larger, so the
/// intermediate ratio never exceeds 1 in magnitude and neither input is squared
/// directly. `b == 0` yields the identity rotation `(1, 0)`.
fn givens_rotation(a: f64, b: f64) -> (f64, f64) {
    if b == 0.0 {
        return (1.0, 0.0);
    }
    if a.abs() > b.abs() {
        let ratio = b / a;
        let cos = 1.0 / (1.0 + ratio * ratio).sqrt();
        (cos, cos * ratio)
    } else {
        let ratio = a / b;
        let sin = 1.0 / (1.0 + ratio * ratio).sqrt();
        (sin * ratio, sin)
    }
}
323
324fn solve_upper_triangular(h: &[Vec<f64>], g: &[f64], k: usize) -> Vec<f64> {
326 let mut y = vec![0.0f64; k];
327 for i in (0..k).rev() {
328 let mut sum_terms: Vec<f64> = Vec::with_capacity(k - i);
329 sum_terms.push(g[i]);
330 for j in (i + 1)..k {
331 sum_terms.push(-h[i][j] * y[j]);
332 }
333 let s = binned_sum_f64(&sum_terms);
334 if h[i][i].abs() > 1e-30 {
335 y[i] = s / h[i][i];
336 }
337 }
338 y
339}
340
341fn update_solution(x: &mut [f64], v: &[Vec<f64>], y: &[f64], k: usize) {
343 for i in 0..k {
344 axpy(y[i], &v[i], x);
345 }
346}
347
348pub fn bicgstab_solve(
363 a: &SparseCsr,
364 b: &[f64],
365 tol: f64,
366 max_iter: usize,
367) -> SolverResult {
368 let n = b.len();
369 assert_eq!(a.nrows, n, "bicgstab_solve: matrix rows must match rhs length");
370 assert_eq!(a.ncols, n, "bicgstab_solve: matrix must be square");
371
372 let b_norm = norm2(b);
373 if b_norm == 0.0 {
374 return SolverResult {
375 x: vec![0.0; n],
376 iterations: 0,
377 residual: 0.0,
378 converged: true,
379 };
380 }
381
382 let mut x = vec![0.0f64; n];
383 let mut r = b.to_vec(); let r0_hat = r.clone(); let mut rho = 1.0f64;
387 let mut alpha = 1.0f64;
388 let mut omega = 1.0f64;
389
390 let mut v = vec![0.0f64; n];
391 let mut p = vec![0.0f64; n];
392
393 for iter in 0..max_iter {
394 let rho_new = dot(&r0_hat, &r);
395
396 if rho_new.abs() < 1e-30 {
397 return SolverResult {
399 x,
400 iterations: iter + 1,
401 residual: norm2(&r) / b_norm,
402 converged: false,
403 };
404 }
405
406 let beta = (rho_new / rho) * (alpha / omega);
407 rho = rho_new;
408
409 for i in 0..n {
411 p[i] = r[i] + beta * (p[i] - omega * v[i]);
412 }
413
414 v = spmv(a, &p);
416
417 let r0_v = dot(&r0_hat, &v);
418 if r0_v.abs() < 1e-30 {
419 return SolverResult {
420 x,
421 iterations: iter + 1,
422 residual: norm2(&r) / b_norm,
423 converged: false,
424 };
425 }
426
427 alpha = rho / r0_v;
428
429 let s: Vec<f64> = r.iter().zip(v.iter()).map(|(&ri, &vi)| ri - alpha * vi).collect();
431
432 let s_norm = norm2(&s) / b_norm;
433 if s_norm < tol {
434 axpy(alpha, &p, &mut x);
436 return SolverResult {
437 x,
438 iterations: iter + 1,
439 residual: s_norm,
440 converged: true,
441 };
442 }
443
444 let t = spmv(a, &s);
446
447 let t_t = dot(&t, &t);
448 if t_t.abs() < 1e-30 {
449 axpy(alpha, &p, &mut x);
450 return SolverResult {
451 x,
452 iterations: iter + 1,
453 residual: norm2(&s) / b_norm,
454 converged: false,
455 };
456 }
457
458 omega = dot(&t, &s) / t_t;
459
460 axpy(alpha, &p, &mut x);
462 axpy(omega, &s, &mut x);
463
464 r = s.iter().zip(t.iter()).map(|(&si, &ti)| si - omega * ti).collect();
466
467 let res_norm = norm2(&r) / b_norm;
468 if res_norm < tol {
469 return SolverResult {
470 x,
471 iterations: iter + 1,
472 residual: res_norm,
473 converged: true,
474 };
475 }
476
477 if omega.abs() < 1e-30 {
478 return SolverResult {
479 x,
480 iterations: iter + 1,
481 residual: res_norm,
482 converged: false,
483 };
484 }
485 }
486
487 let res = norm2(&r) / b_norm;
488 SolverResult {
489 x,
490 iterations: max_iter,
491 residual: res,
492 converged: false,
493 }
494}
495
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::SparseCsr;

    /// Builds a CSR matrix from a row-major dense array, skipping exact zeros.
    fn csr_from_dense(data: &[f64], nrows: usize, ncols: usize) -> SparseCsr {
        let mut values = Vec::new();
        let mut col_indices = Vec::new();
        let mut row_offsets = vec![0usize];

        for r in 0..nrows {
            for c in 0..ncols {
                let v = data[r * ncols + c];
                if v != 0.0 {
                    values.push(v);
                    col_indices.push(c);
                }
            }
            row_offsets.push(values.len());
        }

        SparseCsr { values, col_indices, row_offsets, nrows, ncols }
    }

    /// `n × n` tridiagonal matrix with 4 on the diagonal and −1 on the off-diagonals
    /// (symmetric and strictly diagonally dominant, hence SPD).
    fn tridiag_spd(n: usize) -> SparseCsr {
        let mut data = vec![0.0; n * n];
        for i in 0..n {
            data[i * n + i] = 4.0;
            if i > 0 {
                data[i * n + (i - 1)] = -1.0;
            }
            if i + 1 < n {
                data[i * n + (i + 1)] = -1.0;
            }
        }
        csr_from_dense(&data, n, n)
    }

    /// Asserts `|(A x)ᵢ − bᵢ| < tol` for every component, naming the solver on failure.
    fn assert_solves(a: &SparseCsr, x: &[f64], b: &[f64], tol: f64, label: &str) {
        let ax = spmv(a, x);
        for (i, (&axi, &bi)) in ax.iter().zip(b).enumerate() {
            assert!(
                (axi - bi).abs() < tol,
                "{} mismatch at i={}: ax={} b={}",
                label, i, axi, bi
            );
        }
    }

    /// CG on a small SPD system must converge and actually satisfy A x = b.
    #[test]
    fn test_cg_tridiag() {
        let n = 10;
        let a = tridiag_spd(n);
        let b: Vec<f64> = (1..=n).map(|i| i as f64).collect();

        let result = cg_solve(&a, &b, 1e-10, 100);
        assert!(result.converged, "CG did not converge: residual={}", result.residual);
        assert!(result.residual < 1e-10);
        assert_solves(&a, &result.x, &b, 1e-8, "CG");
    }

    #[test]
    fn test_cg_identity() {
        let a = csr_from_dense(&[1.0, 0.0, 0.0, 1.0], 2, 2);
        let b = vec![3.0, 7.0];
        let result = cg_solve(&a, &b, 1e-12, 10);
        assert!(result.converged);
        assert!((result.x[0] - 3.0).abs() < 1e-10);
        assert!((result.x[1] - 7.0).abs() < 1e-10);
    }

    #[test]
    fn test_cg_zero_rhs() {
        let a = tridiag_spd(5);
        let b = vec![0.0; 5];
        let result = cg_solve(&a, &b, 1e-10, 100);
        assert!(result.converged);
        assert_eq!(result.iterations, 0);
        for &xi in &result.x {
            assert_eq!(xi, 0.0);
        }
    }

    #[test]
    fn test_cg_determinism() {
        let a = tridiag_spd(20);
        let b: Vec<f64> = (0..20).map(|i| (i as f64).sin()).collect();

        let r1 = cg_solve(&a, &b, 1e-10, 200);
        let r2 = cg_solve(&a, &b, 1e-10, 200);

        assert_eq!(r1.x, r2.x, "CG is not deterministic");
        assert_eq!(r1.iterations, r2.iterations);
        assert_eq!(r1.residual, r2.residual);
    }

    /// GMRES on a non-symmetric system.
    #[test]
    fn test_gmres_nonsymmetric() {
        let a = csr_from_dense(
            &[4.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 0.0, 2.0],
            3, 3,
        );
        let b = vec![1.0, 2.0, 3.0];

        let result = gmres_solve(&a, &b, 1e-10, 100, 30);
        assert!(result.converged, "GMRES did not converge: residual={}", result.residual);
        assert_solves(&a, &result.x, &b, 1e-8, "GMRES");
    }

    #[test]
    fn test_gmres_identity() {
        let a = csr_from_dense(&[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0], 3, 3);
        let b = vec![5.0, 6.0, 7.0];
        let result = gmres_solve(&a, &b, 1e-12, 10, 10);
        assert!(result.converged, "GMRES did not converge: residual={}", result.residual);
        for i in 0..3 {
            assert!((result.x[i] - b[i]).abs() < 1e-10,
                "GMRES identity mismatch at i={}: x={} b={}", i, result.x[i], b[i]);
        }
    }

    /// A restart length smaller than n forces several restart cycles, which is
    /// the path the other GMRES tests never reach.
    #[test]
    fn test_gmres_restarted_tridiag() {
        let n = 10;
        let a = tridiag_spd(n);
        let b: Vec<f64> = (1..=n).map(|i| i as f64).collect();

        let result = gmres_solve(&a, &b, 1e-10, 500, 3);
        assert!(
            result.converged,
            "restarted GMRES did not converge: residual={}",
            result.residual
        );
        assert!(result.iterations > 3, "expected more than one restart cycle");
        assert_solves(&a, &result.x, &b, 1e-8, "restarted GMRES");
    }

    #[test]
    fn test_gmres_zero_rhs() {
        let a = csr_from_dense(&[2.0, 1.0, 0.0, 3.0], 2, 2);
        let b = vec![0.0, 0.0];
        let result = gmres_solve(&a, &b, 1e-10, 100, 10);
        assert!(result.converged);
        assert_eq!(result.iterations, 0);
    }

    #[test]
    fn test_gmres_determinism() {
        let a = csr_from_dense(
            &[4.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 0.0, 2.0],
            3, 3,
        );
        let b = vec![1.0, 2.0, 3.0];

        let r1 = gmres_solve(&a, &b, 1e-10, 100, 30);
        let r2 = gmres_solve(&a, &b, 1e-10, 100, 30);

        assert_eq!(r1.x, r2.x, "GMRES is not deterministic");
        assert_eq!(r1.iterations, r2.iterations);
        assert_eq!(r1.residual, r2.residual);
    }

    /// BiCGSTAB on a non-symmetric system.
    #[test]
    fn test_bicgstab_nonsymmetric() {
        let a = csr_from_dense(
            &[4.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 0.0, 2.0],
            3, 3,
        );
        let b = vec![1.0, 2.0, 3.0];

        let result = bicgstab_solve(&a, &b, 1e-10, 100);
        assert!(result.converged, "BiCGSTAB did not converge: residual={}", result.residual);
        assert_solves(&a, &result.x, &b, 1e-8, "BiCGSTAB");
    }

    #[test]
    fn test_bicgstab_spd() {
        let a = tridiag_spd(10);
        let b: Vec<f64> = (1..=10).map(|i| i as f64).collect();

        let result = bicgstab_solve(&a, &b, 1e-10, 200);
        assert!(result.converged, "BiCGSTAB did not converge: residual={}", result.residual);
        assert_solves(&a, &result.x, &b, 1e-8, "BiCGSTAB SPD");
    }

    #[test]
    fn test_bicgstab_identity() {
        let a = csr_from_dense(&[1.0, 0.0, 0.0, 1.0], 2, 2);
        let b = vec![3.0, 7.0];
        let result = bicgstab_solve(&a, &b, 1e-12, 10);
        assert!(result.converged);
        assert!((result.x[0] - 3.0).abs() < 1e-10);
        assert!((result.x[1] - 7.0).abs() < 1e-10);
    }

    #[test]
    fn test_bicgstab_zero_rhs() {
        let a = csr_from_dense(&[2.0, 1.0, 0.0, 3.0], 2, 2);
        let b = vec![0.0, 0.0];
        let result = bicgstab_solve(&a, &b, 1e-10, 100);
        assert!(result.converged);
        assert_eq!(result.iterations, 0);
    }

    #[test]
    fn test_bicgstab_determinism() {
        let a = tridiag_spd(15);
        let b: Vec<f64> = (0..15).map(|i| (i as f64 * 0.7).cos()).collect();

        let r1 = bicgstab_solve(&a, &b, 1e-10, 200);
        let r2 = bicgstab_solve(&a, &b, 1e-10, 200);

        assert_eq!(r1.x, r2.x, "BiCGSTAB is not deterministic");
        assert_eq!(r1.iterations, r2.iterations);
        assert_eq!(r1.residual, r2.residual);
    }

    /// With no iteration budget, every solver returns the initial guess x = 0,
    /// whose relative residual is ‖b‖ / ‖b‖ = 1.
    #[test]
    fn test_max_iter_zero_returns_initial_guess() {
        let a = tridiag_spd(4);
        let b = vec![1.0, -2.0, 3.0, -4.0];

        let results = vec![
            ("CG", cg_solve(&a, &b, 1e-10, 0)),
            ("GMRES", gmres_solve(&a, &b, 1e-10, 0, 4)),
            ("BiCGSTAB", bicgstab_solve(&a, &b, 1e-10, 0)),
        ];
        for (label, result) in results {
            assert!(!result.converged, "{} reported convergence with no iterations", label);
            assert_eq!(result.iterations, 0, "{} iteration count", label);
            assert!(result.x.iter().all(|&xi| xi == 0.0), "{} modified x", label);
            assert!((result.residual - 1.0).abs() < 1e-12, "{} residual", label);
        }
    }

    /// All three solvers must agree on an SPD system where each is applicable.
    #[test]
    fn test_solvers_agree_on_spd_system() {
        let a = tridiag_spd(8);
        let b: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        let cg = cg_solve(&a, &b, 1e-12, 200);
        let gmres = gmres_solve(&a, &b, 1e-12, 200, 20);
        let bicg = bicgstab_solve(&a, &b, 1e-12, 200);

        assert!(cg.converged);
        assert!(gmres.converged);
        assert!(bicg.converged);

        for i in 0..8 {
            assert!(
                (cg.x[i] - gmres.x[i]).abs() < 1e-8,
                "CG vs GMRES disagree at i={}: {} vs {}",
                i, cg.x[i], gmres.x[i]
            );
            assert!(
                (cg.x[i] - bicg.x[i]).abs() < 1e-8,
                "CG vs BiCGSTAB disagree at i={}: {} vs {}",
                i, cg.x[i], bicg.x[i]
            );
        }
    }
}