// cjc_runtime/sparse_solvers.rs
1//! Iterative solvers for sparse linear systems.
2//!
3//! All solvers operate on CSR matrices and use deterministic floating-point
4//! reductions via `binned_sum_f64` to guarantee bit-identical results across
5//! runs and platforms.
6
7use crate::accumulator::binned_sum_f64;
8use crate::sparse::SparseCsr;
9
/// Result of an iterative solver.
///
/// Returned by [`cg_solve`], [`gmres_solve`], and [`bicgstab_solve`].
/// `residual` is relative to `||b||` (except for the trivial `b = 0` case,
/// where it is reported as exactly `0.0`).
#[derive(Debug, Clone)]
pub struct SolverResult {
    /// Solution vector.
    pub x: Vec<f64>,
    /// Number of iterations used.
    pub iterations: usize,
    /// Final residual norm (L2).
    pub residual: f64,
    /// Whether the solver converged to within the requested tolerance.
    pub converged: bool,
}
22
23// ---------------------------------------------------------------------------
24// Vector helpers — all reductions use binned_sum_f64 for determinism.
25// ---------------------------------------------------------------------------
26
27/// Deterministic dot product of two vectors.
28fn dot(a: &[f64], b: &[f64]) -> f64 {
29    let products: Vec<f64> = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).collect();
30    binned_sum_f64(&products)
31}
32
33/// Deterministic L2 norm.
34fn norm2(v: &[f64]) -> f64 {
35    dot(v, v).sqrt()
36}
37
/// In-place AXPY update: `y[i] += a * x[i]` for each element pair.
///
/// Iteration stops at the end of the shorter slice.
fn axpy(a: f64, x: &[f64], y: &mut [f64]) {
    y.iter_mut()
        .zip(x)
        .for_each(|(yi, xi)| *yi += a * xi);
}
44
45/// r = b - A*x
46fn compute_residual(a: &SparseCsr, x: &[f64], b: &[f64]) -> Vec<f64> {
47    let ax = spmv(a, x);
48    b.iter().zip(ax.iter()).map(|(&bi, &axi)| bi - axi).collect()
49}
50
51/// Sparse matrix-vector product using binned summation for each row.
52fn spmv(a: &SparseCsr, x: &[f64]) -> Vec<f64> {
53    let mut y = vec![0.0f64; a.nrows];
54    for row in 0..a.nrows {
55        let start = a.row_offsets[row];
56        let end = a.row_offsets[row + 1];
57        let products: Vec<f64> = (start..end)
58            .map(|idx| a.values[idx] * x[a.col_indices[idx]])
59            .collect();
60        y[row] = binned_sum_f64(&products);
61    }
62    y
63}
64
65// ---------------------------------------------------------------------------
66// Conjugate Gradient (CG)
67// ---------------------------------------------------------------------------
68
69/// Conjugate Gradient solver for symmetric positive-definite systems Ax = b.
70///
71/// # Arguments
72/// * `a` — SPD sparse matrix (must be square, n x n)
73/// * `b` — right-hand side vector (length n)
74/// * `tol` — convergence tolerance on the relative residual norm
75/// * `max_iter` — maximum number of iterations
76///
77/// # Determinism
78/// All inner products and norms use `binned_sum_f64`.
79pub fn cg_solve(a: &SparseCsr, b: &[f64], tol: f64, max_iter: usize) -> SolverResult {
80    let n = b.len();
81    assert_eq!(a.nrows, n, "cg_solve: matrix rows must match rhs length");
82    assert_eq!(a.ncols, n, "cg_solve: matrix must be square");
83
84    let mut x = vec![0.0f64; n];
85    let mut r = b.to_vec(); // r = b - A*0 = b
86    let mut p = r.clone();
87    let mut rs_old = dot(&r, &r);
88    let b_norm = norm2(b);
89
90    if b_norm == 0.0 {
91        return SolverResult {
92            x,
93            iterations: 0,
94            residual: 0.0,
95            converged: true,
96        };
97    }
98
99    for iter in 0..max_iter {
100        let ap = spmv(a, &p);
101        let p_ap = dot(&p, &ap);
102
103        if p_ap == 0.0 {
104            // Breakdown — p is in the null space
105            return SolverResult {
106                x,
107                iterations: iter + 1,
108                residual: norm2(&r) / b_norm,
109                converged: false,
110            };
111        }
112
113        let alpha = rs_old / p_ap;
114
115        // x = x + alpha * p
116        axpy(alpha, &p, &mut x);
117        // r = r - alpha * A*p
118        axpy(-alpha, &ap, &mut r);
119
120        let rs_new = dot(&r, &r);
121        let res_norm = rs_new.sqrt() / b_norm;
122
123        if res_norm < tol {
124            return SolverResult {
125                x,
126                iterations: iter + 1,
127                residual: res_norm,
128                converged: true,
129            };
130        }
131
132        let beta = rs_new / rs_old;
133        // p = r + beta * p
134        for i in 0..n {
135            p[i] = r[i] + beta * p[i];
136        }
137        rs_old = rs_new;
138    }
139
140    SolverResult {
141        x,
142        iterations: max_iter,
143        residual: norm2(&r) / b_norm,
144        converged: false,
145    }
146}
147
148// ---------------------------------------------------------------------------
149// GMRES (Generalized Minimum Residual)
150// ---------------------------------------------------------------------------
151
152/// GMRES solver for general (non-symmetric) systems Ax = b with restarts.
153///
154/// # Arguments
155/// * `a` — sparse matrix (must be square, n x n)
156/// * `b` — right-hand side vector (length n)
157/// * `tol` — convergence tolerance on the relative residual norm
158/// * `max_iter` — maximum total number of iterations
159/// * `restart` — restart dimension (Krylov subspace size before restart)
160///
161/// # Determinism
162/// All inner products, norms, and Givens rotations use deterministic arithmetic.
163pub fn gmres_solve(
164    a: &SparseCsr,
165    b: &[f64],
166    tol: f64,
167    max_iter: usize,
168    restart: usize,
169) -> SolverResult {
170    let n = b.len();
171    assert_eq!(a.nrows, n, "gmres_solve: matrix rows must match rhs length");
172    assert_eq!(a.ncols, n, "gmres_solve: matrix must be square");
173
174    let b_norm = norm2(b);
175    if b_norm == 0.0 {
176        return SolverResult {
177            x: vec![0.0; n],
178            iterations: 0,
179            residual: 0.0,
180            converged: true,
181        };
182    }
183
184    let mut x = vec![0.0f64; n];
185    let mut total_iter = 0;
186
187    while total_iter < max_iter {
188        let r = compute_residual(a, &x, b);
189        let r_norm = norm2(&r);
190
191        if r_norm / b_norm < tol {
192            return SolverResult {
193                x,
194                iterations: total_iter,
195                residual: r_norm / b_norm,
196                converged: true,
197            };
198        }
199
200        // Arnoldi + Givens rotation based GMRES
201        let m = restart.min(max_iter - total_iter);
202
203        // Krylov basis vectors V[0..m+1], each of length n
204        let mut v: Vec<Vec<f64>> = Vec::with_capacity(m + 1);
205        // v[0] = r / ||r||
206        let v0: Vec<f64> = r.iter().map(|&ri| ri / r_norm).collect();
207        v.push(v0);
208
209        // Upper Hessenberg matrix H (stored as (m+1) x m)
210        let mut h = vec![vec![0.0f64; m]; m + 1];
211
212        // Givens rotation parameters
213        let mut cs = vec![0.0f64; m];
214        let mut sn = vec![0.0f64; m];
215
216        // Right-hand side of the least-squares problem
217        let mut g = vec![0.0f64; m + 1];
218        g[0] = r_norm;
219
220        let mut last_res = r_norm / b_norm;
221
222        for j in 0..m {
223            total_iter += 1;
224
225            // w = A * v[j]
226            let mut w = spmv(a, &v[j]);
227
228            // Modified Gram-Schmidt orthogonalization
229            for i in 0..=j {
230                h[i][j] = dot(&w, &v[i]);
231                // w = w - h[i][j] * v[i]
232                axpy(-h[i][j], &v[i], &mut w);
233            }
234
235            h[j + 1][j] = norm2(&w);
236
237            if h[j + 1][j] > 1e-14 {
238                let inv = 1.0 / h[j + 1][j];
239                let vn: Vec<f64> = w.iter().map(|&wi| wi * inv).collect();
240                v.push(vn);
241            } else {
242                // Lucky breakdown — w is zero, push zero vector
243                v.push(vec![0.0; n]);
244            }
245
246            // Apply previous Givens rotations to column j of H
247            for i in 0..j {
248                let tmp = cs[i] * h[i][j] + sn[i] * h[i + 1][j];
249                h[i + 1][j] = -sn[i] * h[i][j] + cs[i] * h[i + 1][j];
250                h[i][j] = tmp;
251            }
252
253            // Compute new Givens rotation for row (j, j+1)
254            let (c, s) = givens_rotation(h[j][j], h[j + 1][j]);
255            cs[j] = c;
256            sn[j] = s;
257
258            h[j][j] = c * h[j][j] + s * h[j + 1][j];
259            h[j + 1][j] = 0.0;
260
261            // Update g
262            let tmp = cs[j] * g[j] + sn[j] * g[j + 1];
263            g[j + 1] = -sn[j] * g[j] + cs[j] * g[j + 1];
264            g[j] = tmp;
265
266            last_res = g[j + 1].abs() / b_norm;
267
268            if last_res < tol {
269                // Solve the upper triangular system H * y = g
270                let y = solve_upper_triangular(&h, &g, j + 1);
271                // x = x + V * y
272                update_solution(&mut x, &v, &y, j + 1);
273                return SolverResult {
274                    x,
275                    iterations: total_iter,
276                    residual: last_res,
277                    converged: true,
278                };
279            }
280        }
281
282        // End of restart cycle — solve and update
283        let y = solve_upper_triangular(&h, &g, m);
284        update_solution(&mut x, &v, &y, m);
285
286        if last_res < tol {
287            return SolverResult {
288                x,
289                iterations: total_iter,
290                residual: last_res,
291                converged: true,
292            };
293        }
294    }
295
296    let r = compute_residual(a, &x, b);
297    SolverResult {
298        x,
299        iterations: total_iter,
300        residual: norm2(&r) / b_norm,
301        converged: false,
302    }
303}
304
/// Compute Givens rotation parameters (c, s) such that
/// [c  s] [a]   [r]
/// [-s c] [b] = [0]
///
/// Uses the numerically stable formulation that divides the smaller-magnitude
/// entry by the larger one, avoiding overflow in the intermediate square.
fn givens_rotation(a: f64, b: f64) -> (f64, f64) {
    if b == 0.0 {
        // Nothing to annihilate — identity rotation.
        return (1.0, 0.0);
    }
    if a.abs() > b.abs() {
        let t = b / a;
        let c = 1.0 / (1.0 + t * t).sqrt();
        (c, c * t)
    } else {
        let t = a / b;
        let s = 1.0 / (1.0 + t * t).sqrt();
        (s * t, s)
    }
}
323
324/// Solve upper triangular system H[0..k, 0..k] * y = g[0..k] by back substitution.
325fn solve_upper_triangular(h: &[Vec<f64>], g: &[f64], k: usize) -> Vec<f64> {
326    let mut y = vec![0.0f64; k];
327    for i in (0..k).rev() {
328        let mut sum_terms: Vec<f64> = Vec::with_capacity(k - i);
329        sum_terms.push(g[i]);
330        for j in (i + 1)..k {
331            sum_terms.push(-h[i][j] * y[j]);
332        }
333        let s = binned_sum_f64(&sum_terms);
334        if h[i][i].abs() > 1e-30 {
335            y[i] = s / h[i][i];
336        }
337    }
338    y
339}
340
341/// x += V[0..k] * y[0..k]
342fn update_solution(x: &mut [f64], v: &[Vec<f64>], y: &[f64], k: usize) {
343    for i in 0..k {
344        axpy(y[i], &v[i], x);
345    }
346}
347
348// ---------------------------------------------------------------------------
349// BiCGSTAB (Biconjugate Gradient Stabilized)
350// ---------------------------------------------------------------------------
351
/// BiCGSTAB solver for general (non-symmetric) systems Ax = b.
///
/// # Arguments
/// * `a` — sparse matrix (must be square, n x n)
/// * `b` — right-hand side vector (length n)
/// * `tol` — convergence tolerance on the relative residual norm
/// * `max_iter` — maximum number of iterations
///
/// # Determinism
/// All inner products and reductions use `binned_sum_f64`.
pub fn bicgstab_solve(
    a: &SparseCsr,
    b: &[f64],
    tol: f64,
    max_iter: usize,
) -> SolverResult {
    let n = b.len();
    assert_eq!(a.nrows, n, "bicgstab_solve: matrix rows must match rhs length");
    assert_eq!(a.ncols, n, "bicgstab_solve: matrix must be square");

    // A zero right-hand side is solved exactly by x = 0.
    let b_norm = norm2(b);
    if b_norm == 0.0 {
        return SolverResult {
            x: vec![0.0; n],
            iterations: 0,
            residual: 0.0,
            converged: true,
        };
    }

    let mut x = vec![0.0f64; n];
    let mut r = b.to_vec(); // r = b - A*0 = b
    let r0_hat = r.clone();  // shadow residual, kept fixed

    // Scalars seeded at 1.0 so the first beta reduces to 0 contribution
    // (p and v start at zero, giving p = r on the first iteration).
    let mut rho = 1.0f64;
    let mut alpha = 1.0f64;
    let mut omega = 1.0f64;

    let mut v = vec![0.0f64; n];
    let mut p = vec![0.0f64; n];

    for iter in 0..max_iter {
        let rho_new = dot(&r0_hat, &r);

        if rho_new.abs() < 1e-30 {
            // Breakdown — r has become (numerically) orthogonal to r0_hat.
            return SolverResult {
                x,
                iterations: iter + 1,
                residual: norm2(&r) / b_norm,
                converged: false,
            };
        }

        let beta = (rho_new / rho) * (alpha / omega);
        rho = rho_new;

        // p = r + beta * (p - omega * v)
        for i in 0..n {
            p[i] = r[i] + beta * (p[i] - omega * v[i]);
        }

        // v = A * p
        v = spmv(a, &p);

        let r0_v = dot(&r0_hat, &v);
        if r0_v.abs() < 1e-30 {
            // Breakdown — cannot form alpha = rho / (r0_hat · v).
            return SolverResult {
                x,
                iterations: iter + 1,
                residual: norm2(&r) / b_norm,
                converged: false,
            };
        }

        alpha = rho / r0_v;

        // s = r - alpha * v  (intermediate "half-step" residual)
        let s: Vec<f64> = r.iter().zip(v.iter()).map(|(&ri, &vi)| ri - alpha * vi).collect();

        // If the half-step already converged, take only the alpha*p update.
        let s_norm = norm2(&s) / b_norm;
        if s_norm < tol {
            // x = x + alpha * p
            axpy(alpha, &p, &mut x);
            return SolverResult {
                x,
                iterations: iter + 1,
                residual: s_norm,
                converged: true,
            };
        }

        // t = A * s
        let t = spmv(a, &s);

        let t_t = dot(&t, &t);
        if t_t.abs() < 1e-30 {
            // Breakdown — t is (numerically) zero; omega would be undefined.
            // Keep the alpha*p half-step before reporting failure.
            axpy(alpha, &p, &mut x);
            return SolverResult {
                x,
                iterations: iter + 1,
                residual: norm2(&s) / b_norm,
                converged: false,
            };
        }

        // Stabilization parameter: least-squares minimizer of ||s - omega*t||.
        omega = dot(&t, &s) / t_t;

        // x = x + alpha * p + omega * s
        axpy(alpha, &p, &mut x);
        axpy(omega, &s, &mut x);

        // r = s - omega * t
        r = s.iter().zip(t.iter()).map(|(&si, &ti)| si - omega * ti).collect();

        let res_norm = norm2(&r) / b_norm;
        if res_norm < tol {
            return SolverResult {
                x,
                iterations: iter + 1,
                residual: res_norm,
                converged: true,
            };
        }

        // omega ~ 0 would poison the next beta = (rho'/rho)(alpha/omega);
        // stop now rather than divide by ~0 on the next iteration.
        if omega.abs() < 1e-30 {
            return SolverResult {
                x,
                iterations: iter + 1,
                residual: res_norm,
                converged: false,
            };
        }
    }

    // Iteration budget exhausted without reaching the tolerance.
    let res = norm2(&r) / b_norm;
    SolverResult {
        x,
        iterations: max_iter,
        residual: res,
        converged: false,
    }
}
495
496// ---------------------------------------------------------------------------
497// Tests
498// ---------------------------------------------------------------------------
499
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::SparseCsr;

    /// Build a CSR matrix from dense data.
    /// Explicit zeros in `data` are dropped, per CSR convention.
    fn csr_from_dense(data: &[f64], nrows: usize, ncols: usize) -> SparseCsr {
        let mut values = Vec::new();
        let mut col_indices = Vec::new();
        let mut row_offsets = vec![0usize];

        for r in 0..nrows {
            for c in 0..ncols {
                let v = data[r * ncols + c];
                if v != 0.0 {
                    values.push(v);
                    col_indices.push(c);
                }
            }
            // One offset entry per row boundary.
            row_offsets.push(values.len());
        }

        SparseCsr { values, col_indices, row_offsets, nrows, ncols }
    }

    /// Build a tridiagonal SPD matrix: diag=4, off-diag=-1.
    /// Strictly diagonally dominant, hence symmetric positive-definite.
    fn tridiag_spd(n: usize) -> SparseCsr {
        let mut data = vec![0.0; n * n];
        for i in 0..n {
            data[i * n + i] = 4.0;
            if i > 0 {
                data[i * n + (i - 1)] = -1.0;
            }
            if i + 1 < n {
                data[i * n + (i + 1)] = -1.0;
            }
        }
        csr_from_dense(&data, n, n)
    }

    // -- CG tests --

    #[test]
    fn test_cg_tridiag() {
        let n = 10;
        let a = tridiag_spd(n);
        let b: Vec<f64> = (1..=n as i64).map(|i| i as f64).collect();

        let result = cg_solve(&a, &b, 1e-10, 100);
        assert!(result.converged, "CG did not converge: residual={}", result.residual);
        assert!(result.residual < 1e-10);

        // Verify A*x ≈ b
        let ax = spmv(&a, &result.x);
        for i in 0..n {
            assert!(
                (ax[i] - b[i]).abs() < 1e-8,
                "CG solution mismatch at i={}: ax={} b={}",
                i, ax[i], b[i]
            );
        }
    }

    #[test]
    fn test_cg_identity() {
        // Identity system: solution must equal the rhs.
        let a = csr_from_dense(&[1.0, 0.0, 0.0, 1.0], 2, 2);
        let b = vec![3.0, 7.0];
        let result = cg_solve(&a, &b, 1e-12, 10);
        assert!(result.converged);
        assert!((result.x[0] - 3.0).abs() < 1e-10);
        assert!((result.x[1] - 7.0).abs() < 1e-10);
    }

    #[test]
    fn test_cg_zero_rhs() {
        // b = 0 must short-circuit: zero iterations, exact zero solution.
        let a = tridiag_spd(5);
        let b = vec![0.0; 5];
        let result = cg_solve(&a, &b, 1e-10, 100);
        assert!(result.converged);
        assert_eq!(result.iterations, 0);
        for &xi in &result.x {
            assert_eq!(xi, 0.0);
        }
    }

    #[test]
    fn test_cg_determinism() {
        // Two identical runs must agree bit-for-bit (binned summation).
        let a = tridiag_spd(20);
        let b: Vec<f64> = (0..20).map(|i| (i as f64).sin()).collect();

        let r1 = cg_solve(&a, &b, 1e-10, 200);
        let r2 = cg_solve(&a, &b, 1e-10, 200);

        assert_eq!(r1.x, r2.x, "CG is not deterministic");
        assert_eq!(r1.iterations, r2.iterations);
        assert_eq!(r1.residual, r2.residual);
    }

    // -- GMRES tests --

    #[test]
    fn test_gmres_nonsymmetric() {
        // Non-symmetric system
        let a = csr_from_dense(
            &[4.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 0.0, 2.0],
            3, 3,
        );
        let b = vec![1.0, 2.0, 3.0];

        let result = gmres_solve(&a, &b, 1e-10, 100, 30);
        assert!(result.converged, "GMRES did not converge: residual={}", result.residual);

        let ax = spmv(&a, &result.x);
        for i in 0..3 {
            assert!(
                (ax[i] - b[i]).abs() < 1e-8,
                "GMRES mismatch at i={}: ax={} b={}",
                i, ax[i], b[i]
            );
        }
    }

    #[test]
    fn test_gmres_identity() {
        let a = csr_from_dense(&[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0], 3, 3);
        let b = vec![5.0, 6.0, 7.0];
        let result = gmres_solve(&a, &b, 1e-12, 10, 10);
        assert!(result.converged, "GMRES did not converge: residual={}", result.residual);
        for i in 0..3 {
            assert!((result.x[i] - b[i]).abs() < 1e-10,
                "GMRES identity mismatch at i={}: x={} b={}", i, result.x[i], b[i]);
        }
    }

    #[test]
    fn test_gmres_zero_rhs() {
        // b = 0 must short-circuit before any Arnoldi work.
        let a = csr_from_dense(&[2.0, 1.0, 0.0, 3.0], 2, 2);
        let b = vec![0.0, 0.0];
        let result = gmres_solve(&a, &b, 1e-10, 100, 10);
        assert!(result.converged);
        assert_eq!(result.iterations, 0);
    }

    #[test]
    fn test_gmres_determinism() {
        let a = csr_from_dense(
            &[4.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 0.0, 2.0],
            3, 3,
        );
        let b = vec![1.0, 2.0, 3.0];

        let r1 = gmres_solve(&a, &b, 1e-10, 100, 30);
        let r2 = gmres_solve(&a, &b, 1e-10, 100, 30);

        assert_eq!(r1.x, r2.x, "GMRES is not deterministic");
        assert_eq!(r1.iterations, r2.iterations);
    }

    // -- BiCGSTAB tests --

    #[test]
    fn test_bicgstab_nonsymmetric() {
        let a = csr_from_dense(
            &[4.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 0.0, 2.0],
            3, 3,
        );
        let b = vec![1.0, 2.0, 3.0];

        let result = bicgstab_solve(&a, &b, 1e-10, 100);
        assert!(result.converged, "BiCGSTAB did not converge: residual={}", result.residual);

        let ax = spmv(&a, &result.x);
        for i in 0..3 {
            assert!(
                (ax[i] - b[i]).abs() < 1e-8,
                "BiCGSTAB mismatch at i={}: ax={} b={}",
                i, ax[i], b[i]
            );
        }
    }

    #[test]
    fn test_bicgstab_spd() {
        // BiCGSTAB should also work for SPD systems
        let a = tridiag_spd(10);
        let b: Vec<f64> = (1..=10).map(|i| i as f64).collect();

        let result = bicgstab_solve(&a, &b, 1e-10, 200);
        assert!(result.converged, "BiCGSTAB did not converge: residual={}", result.residual);

        let ax = spmv(&a, &result.x);
        for i in 0..10 {
            assert!(
                (ax[i] - b[i]).abs() < 1e-8,
                "BiCGSTAB SPD mismatch at i={}",
                i
            );
        }
    }

    #[test]
    fn test_bicgstab_identity() {
        let a = csr_from_dense(&[1.0, 0.0, 0.0, 1.0], 2, 2);
        let b = vec![3.0, 7.0];
        let result = bicgstab_solve(&a, &b, 1e-12, 10);
        assert!(result.converged);
        assert!((result.x[0] - 3.0).abs() < 1e-10);
        assert!((result.x[1] - 7.0).abs() < 1e-10);
    }

    #[test]
    fn test_bicgstab_zero_rhs() {
        let a = csr_from_dense(&[2.0, 1.0, 0.0, 3.0], 2, 2);
        let b = vec![0.0, 0.0];
        let result = bicgstab_solve(&a, &b, 1e-10, 100);
        assert!(result.converged);
        assert_eq!(result.iterations, 0);
    }

    #[test]
    fn test_bicgstab_determinism() {
        let a = tridiag_spd(15);
        let b: Vec<f64> = (0..15).map(|i| (i as f64 * 0.7).cos()).collect();

        let r1 = bicgstab_solve(&a, &b, 1e-10, 200);
        let r2 = bicgstab_solve(&a, &b, 1e-10, 200);

        assert_eq!(r1.x, r2.x, "BiCGSTAB is not deterministic");
        assert_eq!(r1.iterations, r2.iterations);
        assert_eq!(r1.residual, r2.residual);
    }

    // -- Cross-solver agreement --

    #[test]
    fn test_solvers_agree_on_spd_system() {
        // An SPD system is valid input for all three solvers; the unique
        // solution means they must agree to within the tolerance.
        let a = tridiag_spd(8);
        let b: Vec<f64> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];

        let cg = cg_solve(&a, &b, 1e-12, 200);
        let gmres = gmres_solve(&a, &b, 1e-12, 200, 20);
        let bicg = bicgstab_solve(&a, &b, 1e-12, 200);

        assert!(cg.converged);
        assert!(gmres.converged);
        assert!(bicg.converged);

        // All three should produce the same solution (within tolerance)
        for i in 0..8 {
            assert!(
                (cg.x[i] - gmres.x[i]).abs() < 1e-8,
                "CG vs GMRES disagree at i={}: {} vs {}",
                i, cg.x[i], gmres.x[i]
            );
            assert!(
                (cg.x[i] - bicg.x[i]).abs() < 1e-8,
                "CG vs BiCGSTAB disagree at i={}: {} vs {}",
                i, cg.x[i], bicg.x[i]
            );
        }
    }
}