Skip to main content

oxiphysics_fem/
parallel_solver.rs

1// Copyright 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4//! Parallel sparse solver using Rayon for multi-threaded finite-element assembly.
5//!
6//! This module provides two complementary components:
7//!
8//! 1. [`ParallelAssembler`] — scatters per-element stiffness matrices concurrently
9//!    into the global CSR matrix, avoiding data races with one Mutex per stored entry.
10//! 2. [`ParallelPcgSolver`] — Preconditioned Conjugate Gradient with Rayon-parallel
11//!    sparse matrix–vector products (SpMV).
12//!
13//! ## Design notes
14//!
15//! * Assembly is parallelised at the **element** level: each element's precomputed
16//!   local stiffness matrix is scattered independently; the scatter step guards every
17//!   CSR value with its own Mutex, so concurrent updates are safe without a global lock.
18//! * The SpMV kernel splits the CSR row range into chunks processed concurrently
19//!   by Rayon.  Each output entry is independent, so no synchronisation is needed.
20//! * The PCG solve loop itself is sequential (the dominant cost is SpMV, which is
21//!   parallel); the dot products use Rayon's `par_iter().map().sum()` reduction.
22//!
23//! ## Usage
24//!
25//! ```
26//! use oxiphysics_fem::parallel_solver::{CsrMatrix, ParallelAssembler, ParallelPcgSolver};
27//!
28//! // Build a tiny 4×4 CSR matrix manually
29//! let mat = CsrMatrix {
30//!     nrows: 4,
31//!     ncols: 4,
32//!     row_offsets: vec![0, 1, 2, 3, 4],
33//!     col_indices:  vec![0, 1, 2, 3],
34//!     values:       vec![4.0, 4.0, 4.0, 4.0],
35//! };
36//! let rhs = vec![1.0, 2.0, 3.0, 4.0];
37//! let mut x   = vec![0.0f64; 4];
38//! let stats = ParallelPcgSolver::default().solve(&mat, &rhs, &mut x);
39//! assert!(stats.converged, "PCG did not converge");
40//! ```
41
42use rayon::prelude::*;
43use std::sync::Mutex;
44
45// ── CsrMatrix ────────────────────────────────────────────────────────────────
46
/// Sparse matrix stored in Compressed Sparse Row (CSR) layout.
///
/// Row `i` owns the half-open index range `row_offsets[i]..row_offsets[i + 1]`.
/// That range selects the row's column indices inside `col_indices` and the
/// matching non-zero values inside `values`.
#[derive(Debug, Clone, Default)]
pub struct CsrMatrix {
    /// Row count.
    pub nrows: usize,
    /// Column count.
    pub ncols: usize,
    /// Start offset of every row followed by one trailing end offset
    /// (`nrows + 1` entries in total).
    pub row_offsets: Vec<usize>,
    /// Column index of each stored entry.
    pub col_indices: Vec<usize>,
    /// Value of each stored entry, parallel to `col_indices`.
    pub values: Vec<f64>,
}
65
66impl CsrMatrix {
67    /// Create a zero matrix with a given sparsity pattern.
68    pub fn new_from_pattern(
69        nrows: usize,
70        ncols: usize,
71        row_offsets: Vec<usize>,
72        col_indices: Vec<usize>,
73    ) -> Self {
74        let nnz = col_indices.len();
75        Self {
76            nrows,
77            ncols,
78            row_offsets,
79            col_indices,
80            values: vec![0.0; nnz],
81        }
82    }
83
84    /// Create an identity matrix of size `n`.
85    pub fn identity(n: usize) -> Self {
86        Self {
87            nrows: n,
88            ncols: n,
89            row_offsets: (0..=n).collect(),
90            col_indices: (0..n).collect(),
91            values: vec![1.0; n],
92        }
93    }
94
95    /// Return the number of non-zero entries.
96    pub fn nnz(&self) -> usize {
97        self.values.len()
98    }
99
100    /// Parallel sparse matrix–vector product: `y = A * x`.
101    ///
102    /// Each row is computed independently, allowing Rayon to distribute rows
103    /// across worker threads with zero synchronisation overhead.
104    pub fn spmv_par(&self, x: &[f64], y: &mut [f64]) {
105        debug_assert_eq!(x.len(), self.ncols);
106        debug_assert_eq!(y.len(), self.nrows);
107
108        y.par_iter_mut().enumerate().for_each(|(i, yi)| {
109            let row_start = self.row_offsets[i];
110            let row_end = self.row_offsets[i + 1];
111            let mut sum = 0.0;
112            for k in row_start..row_end {
113                sum += self.values[k] * x[self.col_indices[k]];
114            }
115            *yi = sum;
116        });
117    }
118
119    /// Sequential sparse matrix–vector product (for comparison / small matrices).
120    pub fn spmv(&self, x: &[f64], y: &mut [f64]) {
121        debug_assert_eq!(x.len(), self.ncols);
122        debug_assert_eq!(y.len(), self.nrows);
123        for (i, yi) in y.iter_mut().enumerate().take(self.nrows) {
124            let row_start = self.row_offsets[i];
125            let row_end = self.row_offsets[i + 1];
126            let mut sum = 0.0;
127            for k in row_start..row_end {
128                sum += self.values[k] * x[self.col_indices[k]];
129            }
130            *yi = sum;
131        }
132    }
133
134    /// Diagonal (Jacobi) preconditioner: returns `1 / A[i,i]` for each row.
135    ///
136    /// Uses `1.0` for rows with zero diagonal to avoid division by zero.
137    pub fn diagonal_preconditioner(&self) -> Vec<f64> {
138        let mut diag = vec![1.0f64; self.nrows];
139        for (i, diag_i) in diag.iter_mut().enumerate().take(self.nrows) {
140            for k in self.row_offsets[i]..self.row_offsets[i + 1] {
141                if self.col_indices[k] == i {
142                    let d = self.values[k];
143                    if d.abs() > 1e-15 {
144                        *diag_i = 1.0 / d;
145                    }
146                    break;
147                }
148            }
149        }
150        diag
151    }
152
153    /// Chunked SpMV: splits the row range into cache-friendly tiles of `chunk_size` rows.
154    ///
155    /// Produces the same result as [`CsrMatrix::spmv_par`] but processes rows in
156    /// contiguous chunks to improve L3 cache utilisation on large matrices.
157    pub fn spmv_chunked(&self, x: &[f64], y: &mut [f64], chunk_size: usize) {
158        debug_assert_eq!(x.len(), self.ncols);
159        debug_assert_eq!(y.len(), self.nrows);
160        let chunk_size = chunk_size.max(1);
161        y.par_chunks_mut(chunk_size)
162            .enumerate()
163            .for_each(|(chunk_idx, y_chunk)| {
164                let row_start = chunk_idx * chunk_size;
165                for (k, yi) in y_chunk.iter_mut().enumerate() {
166                    let row = row_start + k;
167                    let rs = self.row_offsets[row];
168                    let re = self.row_offsets[row + 1];
169                    let mut sum = 0.0;
170                    for j in rs..re {
171                        sum += self.values[j] * x[self.col_indices[j]];
172                    }
173                    *yi = sum;
174                }
175            });
176    }
177}
178
179// ── AssemblyTask ────────────────────────────────────────────────────────────
180
/// Local stiffness contribution of one finite element.
///
/// `global_dofs[k]` is the global row/column index of the element's `k`-th
/// local DOF, and `ke` holds the dense `ndof × ndof` local matrix in row-major
/// order.
#[derive(Debug, Clone)]
pub struct AssemblyTask {
    /// Global DOF index for each local DOF of the element.
    pub global_dofs: Vec<usize>,
    /// Row-major entries of the `(ndof, ndof)` local stiffness matrix.
    pub ke: Vec<f64>,
}

impl AssemblyTask {
    /// Bundle an element's DOF map with its local stiffness matrix.
    ///
    /// Debug builds assert that `ke` has exactly `ndof * ndof` entries; no
    /// other property of the matrix (symmetry, definiteness) is checked.
    pub fn new(global_dofs: Vec<usize>, ke: Vec<f64>) -> Self {
        debug_assert_eq!(ke.len(), global_dofs.len() * global_dofs.len());
        Self { global_dofs, ke }
    }

    /// Count of local degrees of freedom (the length of `global_dofs`).
    pub fn ndof(&self) -> usize {
        self.global_dofs.len()
    }
}
206
207// ── ParallelAssembler ────────────────────────────────────────────────────────
208
/// Assembles per-element stiffness contributions into a global CSR matrix.
///
/// ## Thread safety
///
/// Elements are processed in parallel (Rayon).  The position of every non-zero
/// in the CSR `values` array is fixed by the sparsity pattern before the
/// parallel phase starts.  During the scatter each stored value (and each
/// right-hand-side entry) sits behind its own `Mutex`, so elements that share
/// DOFs may accumulate into the same entry from different threads without a
/// data race and without a global lock.
#[derive(Debug, Default)]
pub struct ParallelAssembler {
    /// Total number of global DOFs.
    pub ndofs: usize,
}
222
223impl ParallelAssembler {
224    /// Create an assembler for a problem with `ndofs` degrees of freedom.
225    pub fn new(ndofs: usize) -> Self {
226        Self { ndofs }
227    }
228
229    /// Assemble all element stiffness matrices into a global CSR matrix.
230    ///
231    /// The sparsity pattern is computed from the union of all element connectivity
232    /// patterns.  Values are accumulated atomically using per-row mutexes.
233    ///
234    /// # Panics
235    ///
236    /// Panics if any `global_dof` index exceeds `ndofs`.
237    pub fn assemble(&self, tasks: &[AssemblyTask]) -> CsrMatrix {
238        // ── Step 1: Build sparsity pattern (sequential) ──────────────────────
239        // For each row, collect the set of column indices that appear.
240        let mut row_cols: Vec<std::collections::BTreeSet<usize>> =
241            vec![std::collections::BTreeSet::new(); self.ndofs];
242
243        for task in tasks {
244            for &row_dof in &task.global_dofs {
245                for &col_dof in &task.global_dofs {
246                    row_cols[row_dof].insert(col_dof);
247                }
248            }
249        }
250
251        // Flatten to CSR arrays
252        let mut row_offsets = vec![0usize; self.ndofs + 1];
253        let mut col_indices: Vec<usize> = Vec::new();
254        for (i, cols) in row_cols.iter().enumerate() {
255            row_offsets[i + 1] = row_offsets[i] + cols.len();
256            col_indices.extend(cols.iter().copied());
257        }
258        let nnz = col_indices.len();
259        let values = vec![0.0f64; nnz];
260
261        // ── Step 2: Per-row lookup table (col → CSR index) ──────────────────
262        let row_col_to_csr: Vec<std::collections::HashMap<usize, usize>> = row_cols
263            .iter()
264            .enumerate()
265            .map(|(i, cols)| {
266                let base = row_offsets[i];
267                cols.iter()
268                    .enumerate()
269                    .map(|(j, &c)| (c, base + j))
270                    .collect()
271            })
272            .collect();
273
274        // ── Step 3: Parallel scatter using per-row mutexes ───────────────────
275        let values_locked: Vec<Mutex<f64>> = values.into_iter().map(Mutex::new).collect();
276
277        tasks.par_iter().for_each(|task| {
278            let ndof = task.ndof();
279            for (li, &row) in task.global_dofs.iter().enumerate() {
280                for (lj, &col) in task.global_dofs.iter().enumerate() {
281                    let ke_val = task.ke[li * ndof + lj];
282                    if let Some(&csr_idx) = row_col_to_csr[row].get(&col) {
283                        // SAFETY: each `csr_idx` is unique per (row, col) pair.
284                        let mut guard = values_locked[csr_idx]
285                            .lock()
286                            .unwrap_or_else(|e| e.into_inner());
287                        *guard += ke_val;
288                    }
289                }
290            }
291        });
292
293        let values: Vec<f64> = values_locked
294            .into_iter()
295            .map(|m| {
296                m.into_inner()
297                    .expect("mutex not poisoned after parallel assembly")
298            })
299            .collect();
300
301        CsrMatrix {
302            nrows: self.ndofs,
303            ncols: self.ndofs,
304            row_offsets,
305            col_indices,
306            values,
307        }
308    }
309
310    /// Assemble a load vector (right-hand side) from per-element force vectors.
311    ///
312    /// `element_forces[e]` must have the same DOF count as `element_dofs[e]`.
313    pub fn assemble_rhs(
314        &self,
315        element_dofs: &[Vec<usize>],
316        element_forces: &[Vec<f64>],
317    ) -> Vec<f64> {
318        let rhs_locked: Vec<Mutex<f64>> = (0..self.ndofs).map(|_| Mutex::new(0.0f64)).collect();
319        element_dofs
320            .par_iter()
321            .zip(element_forces.par_iter())
322            .for_each(|(dofs, forces)| {
323                for (&dof, &f) in dofs.iter().zip(forces.iter()) {
324                    *rhs_locked[dof].lock().unwrap_or_else(|e| e.into_inner()) += f;
325                }
326            });
327        rhs_locked
328            .into_iter()
329            .map(|m| {
330                m.into_inner()
331                    .expect("mutex not poisoned after parallel rhs assembly")
332            })
333            .collect()
334    }
335
336    /// Assemble using element graph coloring for better cache locality.
337    ///
338    /// Computes a graph coloring of the elements (no two elements of the same color
339    /// share a DOF) and assembles each color group in parallel.  The result is
340    /// identical to [`ParallelAssembler::assemble`].
341    pub fn assemble_colored(
342        &self,
343        tasks: &[AssemblyTask],
344        element_dofs: &[Vec<usize>],
345    ) -> CsrMatrix {
346        use crate::solvers::assembly_coloring::{assemble_colored_csr, color_elements};
347        let coloring = color_elements(tasks.len(), element_dofs);
348        assemble_colored_csr(self.ndofs, tasks, &coloring)
349    }
350}
351
352// ── PcgStats ─────────────────────────────────────────────────────────────────
353
/// Convergence report produced by the iterative solvers in this module
/// (for example [`ParallelPcgSolver::solve`]).
#[derive(Debug, Clone, Copy)]
pub struct PcgStats {
    /// How many iterations were carried out.
    pub iterations: usize,
    /// Residual norm at the point the solver stopped.
    pub residual_norm: f64,
    /// `true` when the stopping tolerance was met.
    pub converged: bool,
}
364
365// ── ParallelPcgSolver ────────────────────────────────────────────────────────
366
/// Conjugate Gradient solver with a Jacobi (diagonal) preconditioner and
/// Rayon-parallel kernels.
///
/// Matrix–vector products go through [`CsrMatrix::spmv_par`]; the inner
/// products are reduced with Rayon as well.
#[derive(Debug, Clone)]
pub struct ParallelPcgSolver {
    /// Upper bound on the number of CG iterations.
    pub max_iterations: usize,
    /// Convergence threshold on the relative residual `||r|| / ||b||`.
    pub tolerance: f64,
    /// Convergence threshold on the absolute residual `||r||`.
    pub abs_tolerance: f64,
}

impl Default for ParallelPcgSolver {
    /// 500 iterations, relative tolerance `1e-8`, absolute tolerance `1e-14`.
    fn default() -> Self {
        let (max_iterations, tolerance, abs_tolerance) = (500, 1e-8, 1e-14);
        Self {
            max_iterations,
            tolerance,
            abs_tolerance,
        }
    }
}
390
391impl ParallelPcgSolver {
392    /// Create a new solver with explicit parameters.
393    pub fn new(max_iterations: usize, tolerance: f64) -> Self {
394        Self {
395            max_iterations,
396            tolerance,
397            abs_tolerance: 1e-14,
398        }
399    }
400
401    /// Solve `A * x = b` using Parallel PCG.
402    ///
403    /// `x` is used as the initial guess (zero-initialise for a cold start).
404    ///
405    /// Returns [`PcgStats`] with convergence information.
406    pub fn solve(&self, a: &CsrMatrix, b: &[f64], x: &mut [f64]) -> PcgStats {
407        let n = a.nrows;
408        assert_eq!(b.len(), n);
409        assert_eq!(x.len(), n);
410
411        let m_inv = a.diagonal_preconditioner();
412
413        // r = b - A*x
414        let mut ax = vec![0.0f64; n];
415        a.spmv_par(x, &mut ax);
416        let mut r: Vec<f64> = b.iter().zip(ax.iter()).map(|(bi, ai)| bi - ai).collect();
417
418        let b_norm = dot_par(b, b).sqrt();
419        if b_norm < self.abs_tolerance {
420            return PcgStats {
421                iterations: 0,
422                residual_norm: 0.0,
423                converged: true,
424            };
425        }
426
427        // z = M^{-1} * r
428        let mut z: Vec<f64> = r.iter().zip(m_inv.iter()).map(|(ri, mi)| ri * mi).collect();
429
430        // p = z
431        let mut p = z.clone();
432
433        let mut rz = dot_par(&r, &z);
434
435        let mut ap = vec![0.0f64; n];
436        let mut iters = 0;
437        let mut res_norm = r.iter().map(|ri| ri * ri).sum::<f64>().sqrt();
438
439        for _ in 0..self.max_iterations {
440            if res_norm / b_norm < self.tolerance {
441                break;
442            }
443            if res_norm < self.abs_tolerance {
444                break;
445            }
446
447            // ap = A * p  (parallel SpMV)
448            a.spmv_par(&p, &mut ap);
449
450            let pap = dot_par(&p, &ap);
451            if pap.abs() < 1e-300 {
452                break;
453            }
454            let alpha = rz / pap;
455
456            // x = x + alpha * p,  r = r - alpha * ap  (parallel)
457            x.par_iter_mut()
458                .zip(p.par_iter())
459                .for_each(|(xi, pi)| *xi += alpha * pi);
460            r.par_iter_mut()
461                .zip(ap.par_iter())
462                .for_each(|(ri, api)| *ri -= alpha * api);
463
464            // z = M^{-1} * r
465            z.par_iter_mut()
466                .zip(r.par_iter().zip(m_inv.par_iter()))
467                .for_each(|(zi, (ri, mi))| *zi = ri * mi);
468
469            let rz_new = dot_par(&r, &z);
470            let beta = rz_new / rz.max(1e-300);
471            rz = rz_new;
472
473            // p = z + beta * p  (parallel)
474            p.par_iter_mut()
475                .zip(z.par_iter())
476                .for_each(|(pi, zi)| *pi = zi + beta * *pi);
477
478            res_norm = r.iter().map(|ri| ri * ri).sum::<f64>().sqrt();
479            iters += 1;
480        }
481
482        PcgStats {
483            iterations: iters,
484            residual_norm: res_norm,
485            converged: res_norm / b_norm.max(1e-300) < self.tolerance
486                || res_norm < self.abs_tolerance,
487        }
488    }
489}
490
491/// Parallel dot product using Rayon's reduction.
492fn dot_par(a: &[f64], b: &[f64]) -> f64 {
493    a.par_iter().zip(b.par_iter()).map(|(ai, bi)| ai * bi).sum()
494}
495
496// ── ParallelGmresSolver (restarted GMRES, diagonal preconditioner) ──────────
497
/// Restarted GMRES solver with a diagonal (Jacobi) preconditioner and
/// Rayon-parallel kernels.
///
/// This is an iterative Krylov method, not a direct factorisation.  It handles
/// non-symmetric systems (e.g. convection-dominated problems); for symmetric
/// positive definite systems, prefer [`ParallelPcgSolver`].
#[derive(Debug, Clone)]
pub struct ParallelGmresSolver {
    /// Krylov subspace dimension built before each restart.
    pub krylov_dim: usize,
    /// Maximum number of outer restarts.
    pub max_restarts: usize,
    /// Relative convergence tolerance.
    pub tolerance: f64,
}

impl Default for ParallelGmresSolver {
    /// Krylov dimension 30, at most 10 restarts, tolerance `1e-8`.
    fn default() -> Self {
        Self {
            krylov_dim: 30,
            max_restarts: 10,
            tolerance: 1e-8,
        }
    }
}
521
522impl ParallelGmresSolver {
523    /// Solve `A * x = b` using restarted GMRES with diagonal preconditioning.
524    ///
525    /// For symmetric positive definite systems prefer [`ParallelPcgSolver`] which
526    /// is more efficient.
527    pub fn solve(&self, a: &CsrMatrix, b: &[f64], x: &mut [f64]) -> PcgStats {
528        let n = a.nrows;
529        let m_inv = a.diagonal_preconditioner();
530        let mut res_norm = 0.0f64;
531        // Use preconditioned b-norm for relative stopping criterion
532        let mb: Vec<f64> = b.iter().zip(m_inv.iter()).map(|(bi, mi)| bi * mi).collect();
533        let b_norm = dot_par(&mb, &mb).sqrt().max(1e-300);
534        let mut total_iters = 0;
535
536        for _restart in 0..self.max_restarts {
537            // Compute initial residual r = b - A*x
538            let mut ax = vec![0.0f64; n];
539            a.spmv_par(x, &mut ax);
540            let r0: Vec<f64> = b.iter().zip(ax.iter()).map(|(bi, ai)| bi - ai).collect();
541
542            // Left-preconditioned residual: z0 = M^{-1} * r0
543            let z0: Vec<f64> = r0
544                .iter()
545                .zip(m_inv.iter())
546                .map(|(ri, mi)| ri * mi)
547                .collect();
548            let beta = dot_par(&z0, &z0).sqrt().max(1e-300);
549            res_norm = beta;
550
551            if res_norm / b_norm < self.tolerance {
552                break;
553            }
554
555            // Arnoldi process: build Krylov basis Q (columns), upper Hessenberg H
556            let m = self.krylov_dim.min(n);
557            let mut q: Vec<Vec<f64>> = Vec::with_capacity(m + 1);
558
559            // q[0] = z0 / beta (first Krylov vector in preconditioned space)
560            q.push(z0.iter().map(|v| v / beta).collect());
561
562            // Hessenberg matrix (stored as flat (m+1) x m)
563            let mut h = vec![0.0f64; (m + 1) * m];
564            // Cosines and sines for Givens rotations
565            let mut cs = vec![0.0f64; m];
566            let mut sn = vec![0.0f64; m];
567            // Right-hand side of the least-squares problem: ||M^{-1}r0|| * e_1
568            let mut e1 = vec![0.0f64; m + 1];
569            e1[0] = beta;
570
571            let mut j_stop = m;
572            for j in 0..m {
573                // w = M^{-1} * A * q[j]
574                let mut aqj = vec![0.0f64; n];
575                a.spmv_par(&q[j], &mut aqj);
576                let w: Vec<f64> = aqj.iter().zip(m_inv.iter()).map(|(v, mi)| v * mi).collect();
577
578                // Modified Gram-Schmidt orthogonalization
579                let mut w = w;
580                for i in 0..=j {
581                    let hij = dot_par(&w, &q[i]);
582                    h[i * m + j] = hij;
583                    w.par_iter_mut()
584                        .zip(q[i].par_iter())
585                        .for_each(|(wi, qi)| *wi -= hij * qi);
586                }
587                let w_norm = dot_par(&w, &w).sqrt();
588                h[(j + 1) * m + j] = w_norm;
589
590                // Whether or not w_norm is near zero, we must apply Givens rotations
591                // to column j of H so that back-substitution is correct.
592
593                // Push the next Krylov vector only if it is non-degenerate.
594                let exact_convergence = w_norm < 1e-14;
595                if !exact_convergence {
596                    q.push(w.iter().map(|v| v / w_norm).collect());
597                }
598
599                // Apply previous Givens rotations to column j of H
600                for i in 0..j {
601                    let tmp = cs[i] * h[i * m + j] + sn[i] * h[(i + 1) * m + j];
602                    h[(i + 1) * m + j] = -sn[i] * h[i * m + j] + cs[i] * h[(i + 1) * m + j];
603                    h[i * m + j] = tmp;
604                }
605
606                // Compute new Givens rotation
607                let (c, s) = givens_rotation(h[j * m + j], h[(j + 1) * m + j]);
608                cs[j] = c;
609                sn[j] = s;
610
611                h[j * m + j] = c * h[j * m + j] + s * h[(j + 1) * m + j];
612                h[(j + 1) * m + j] = 0.0;
613                e1[j + 1] = -s * e1[j];
614                e1[j] *= c;
615
616                res_norm = e1[j + 1].abs();
617                total_iters += 1;
618
619                // Stop if converged or if Krylov space is exhausted (exact solve).
620                if res_norm / b_norm < self.tolerance || exact_convergence {
621                    j_stop = j + 1;
622                    break;
623                }
624            }
625
626            // Back-substitution for y (upper triangular solve on H[0..j_stop, 0..j_stop])
627            let mut y = vec![0.0f64; j_stop];
628            for i in (0..j_stop).rev() {
629                y[i] = e1[i];
630                for k in (i + 1)..j_stop {
631                    y[i] -= h[i * m + k] * y[k];
632                }
633                let hii = h[i * m + i];
634                if hii.abs() > 1e-300 {
635                    y[i] /= hii;
636                }
637            }
638
639            // Update solution: x = x + Q * y
640            for j in 0..j_stop {
641                let yj = y[j];
642                x.par_iter_mut()
643                    .zip(q[j].par_iter())
644                    .for_each(|(xi, qji)| *xi += yj * qji);
645            }
646
647            if res_norm / b_norm < self.tolerance {
648                break;
649            }
650        }
651
652        PcgStats {
653            iterations: total_iters,
654            residual_norm: res_norm,
655            converged: res_norm / b_norm < self.tolerance,
656        }
657    }
658}
659
/// Givens rotation coefficients `(c, s)` satisfying
/// `[c, s; -s, c] * [a; b] = [r; 0]`.
///
/// When `|b| < 1e-300` there is nothing to eliminate and the identity rotation
/// `(1, 0)` is returned.
fn givens_rotation(a: f64, b: f64) -> (f64, f64) {
    let b_is_negligible = b.abs() < 1e-300;
    if b_is_negligible {
        return (1.0, 0.0);
    }
    let radius = a.hypot(b);
    (a / radius, b / radius)
}
669
670// ── Tests ────────────────────────────────────────────────────────────────────
671
#[cfg(test)]
mod tests {
    use super::*;

    /// Square matrix of order `n` holding `diag_val` on its diagonal.
    fn diag_matrix(n: usize, diag_val: f64) -> CsrMatrix {
        let row_offsets: Vec<usize> = (0..=n).collect();
        let col_indices: Vec<usize> = (0..n).collect();
        CsrMatrix {
            nrows: n,
            ncols: n,
            row_offsets,
            col_indices,
            values: vec![diag_val; n],
        }
    }

    /// Dense vector `[0, 1, ..., n - 1]`.
    fn ramp(n: usize) -> Vec<f64> {
        (0..n).map(|i| i as f64).collect()
    }

    /// Entry `(row, col)` of `mat`; positions outside the pattern read as `0.0`.
    fn entry(mat: &CsrMatrix, row: usize, col: usize) -> f64 {
        (mat.row_offsets[row]..mat.row_offsets[row + 1])
            .find(|&k| mat.col_indices[k] == col)
            .map_or(0.0, |k| mat.values[k])
    }

    #[test]
    fn pcg_solves_diagonal_system() {
        let n = 16;
        let mat = diag_matrix(n, 4.0);
        let rhs: Vec<f64> = (1..=n).map(|k| k as f64).collect();
        let mut x = vec![0.0f64; n];

        let stats = ParallelPcgSolver::default().solve(&mat, &rhs, &mut x);

        assert!(stats.converged, "PCG did not converge: {:?}", stats);
        for (i, xi) in x.iter().enumerate() {
            let expected = (i + 1) as f64 / 4.0;
            assert!(
                (xi - expected).abs() < 1e-10,
                "x[{i}] = {xi}, expected {expected}"
            );
        }
    }

    #[test]
    fn parallel_spmv_matches_sequential() {
        let mat = diag_matrix(64, 3.0);
        let x = ramp(64);
        let (mut y_par, mut y_seq) = (vec![0.0f64; 64], vec![0.0f64; 64]);

        mat.spmv_par(&x, &mut y_par);
        mat.spmv(&x, &mut y_seq);

        for (a, b) in y_par.iter().zip(y_seq.iter()) {
            assert!((a - b).abs() < 1e-14, "{a} != {b}");
        }
    }

    #[test]
    fn parallel_assembler_assembles_2d_bar() {
        // Two bar elements chained through a shared node: DOFs [0,1] and [1,2].
        // Both use the element stiffness [[1,-1],[-1,1]].
        let ke = vec![1.0, -1.0, -1.0, 1.0];
        let tasks = vec![
            AssemblyTask::new(vec![0, 1], ke.clone()),
            AssemblyTask::new(vec![1, 2], ke),
        ];

        let mat = ParallelAssembler::new(3).assemble(&tasks);
        assert_eq!(mat.nrows, 3);

        // Diagonal [1, 2, 1] plus the two upper off-diagonal couplings.
        let expected = [
            (0usize, 0usize, 1.0f64),
            (1, 1, 2.0),
            (2, 2, 1.0),
            (0, 1, -1.0),
            (1, 2, -1.0),
        ];
        for &(row, col, want) in &expected {
            assert!((entry(&mat, row, col) - want).abs() < 1e-14);
        }
    }

    #[test]
    fn assemble_rhs_sums_forces() {
        let dofs = vec![vec![0usize, 1], vec![1, 2]];
        let forces = vec![vec![1.0, 2.0], vec![3.0, 4.0]];

        let rhs = ParallelAssembler::new(3).assemble_rhs(&dofs, &forces);

        // The shared DOF 1 receives 2 + 3.
        for (k, &want) in [1.0f64, 5.0, 4.0].iter().enumerate() {
            assert!((rhs[k] - want).abs() < 1e-14);
        }
    }

    #[test]
    fn gmres_solves_diagonal_system() {
        let n = 8;
        let mat = diag_matrix(n, 2.0);
        let rhs: Vec<f64> = (1..=n).map(|k| k as f64).collect();
        let mut x = vec![0.0f64; n];

        let stats = ParallelGmresSolver::default().solve(&mat, &rhs, &mut x);

        assert!(stats.converged, "GMRES did not converge: {:?}", stats);
        for (i, xi) in x.iter().enumerate() {
            let expected = (i + 1) as f64 / 2.0;
            assert!(
                (xi - expected).abs() < 1e-6,
                "x[{i}] = {xi}, expected {expected}"
            );
        }
    }

    #[test]
    fn spmv_chunked_matches_spmv() {
        let n = 64;
        let mat = diag_matrix(n, 3.0);
        let x = ramp(n);
        let (mut y_par, mut y_chunked) = (vec![0.0f64; n], vec![0.0f64; n]);

        mat.spmv_par(&x, &mut y_par);
        mat.spmv_chunked(&x, &mut y_chunked, 16);

        for (a, b) in y_par.iter().zip(y_chunked.iter()) {
            assert!((a - b).abs() < 1e-14, "{a} != {b}");
        }
    }
}
785
786// ── AMG-preconditioned Krylov solvers ─────────────────────────────────────────
787
788use crate::solvers::amg::{
789    cycle::{AmgHierarchy, CycleKind},
790    preconditioner::Preconditioner,
791};
792
793/// PCG solver with a pluggable preconditioner.
794///
795/// The preconditioner `P` must implement [`Preconditioner`], which maps a
796/// residual `r` to a correction `z ≈ A^{-1} r`.
797#[derive(Debug)]
798pub struct PcgWithPrecond<P: Preconditioner> {
799    /// Maximum number of PCG iterations.
800    pub max_iterations: usize,
801    /// Relative residual tolerance.
802    pub tolerance: f64,
803    /// Absolute residual tolerance.
804    pub abs_tolerance: f64,
805    /// Preconditioner instance.
806    pub precond: P,
807}
808
809impl<P: Preconditioner> PcgWithPrecond<P> {
810    /// Create a new preconditioned PCG solver.
811    pub fn new(precond: P, max_iterations: usize, tolerance: f64) -> Self {
812        Self {
813            max_iterations,
814            tolerance,
815            abs_tolerance: 1e-14,
816            precond,
817        }
818    }
819
820    /// Solve `A * x = b` using preconditioned CG.
821    ///
822    /// `x` is used as the initial guess.
823    pub fn solve(&self, a: &CsrMatrix, b: &[f64], x: &mut [f64]) -> PcgStats {
824        let n = a.nrows;
825        let mut z = vec![0.0f64; n];
826
827        // r = b - A*x
828        let mut ax = vec![0.0f64; n];
829        a.spmv(x, &mut ax);
830        let mut r: Vec<f64> = b.iter().zip(ax.iter()).map(|(bi, ai)| bi - ai).collect();
831
832        self.precond.apply(&r, &mut z);
833        let mut p = z.clone();
834        let mut rz = dot_par(&r, &z);
835        let b_norm = dot_par(b, b).sqrt().max(1e-300);
836        let mut ap = vec![0.0f64; n];
837        let mut iters = 0;
838        let mut res_norm = r.iter().map(|ri| ri * ri).sum::<f64>().sqrt();
839
840        for _ in 0..self.max_iterations {
841            if res_norm / b_norm < self.tolerance {
842                break;
843            }
844            if res_norm < self.abs_tolerance {
845                break;
846            }
847
848            a.spmv_par(&p, &mut ap);
849            let pap = dot_par(&p, &ap);
850            if pap.abs() < 1e-300 {
851                break;
852            }
853            let alpha = rz / pap;
854
855            for i in 0..n {
856                x[i] += alpha * p[i];
857                r[i] -= alpha * ap[i];
858            }
859            self.precond.apply(&r, &mut z);
860            let rz_new = dot_par(&r, &z);
861            let beta = rz_new / rz.max(1e-300);
862            rz = rz_new;
863            for i in 0..n {
864                p[i] = z[i] + beta * p[i];
865            }
866            res_norm = r.iter().map(|ri| ri * ri).sum::<f64>().sqrt();
867            iters += 1;
868        }
869
870        PcgStats {
871            iterations: iters,
872            residual_norm: res_norm,
873            converged: res_norm / b_norm < self.tolerance || res_norm < self.abs_tolerance,
874        }
875    }
876}
877
878/// PCG with an AMG preconditioner — preferred for SPD systems.
879pub type PcgWithAmg = PcgWithPrecond<crate::solvers::amg::preconditioner::AmgPreconditioner>;
880
881/// GMRES with AMG preconditioning for non-symmetric or indefinite systems.
882///
883/// Uses left-preconditioning: at each Arnoldi step, apply `M^{-1}` (one AMG cycle)
884/// before adding the vector to the Krylov basis.
885pub struct GmresWithAmg {
886    /// Krylov subspace dimension (restart size).
887    pub krylov_dim: usize,
888    /// Maximum number of outer restarts.
889    pub max_restarts: usize,
890    /// Relative residual convergence tolerance.
891    pub tolerance: f64,
892    /// AMG hierarchy for preconditioning.
893    pub hierarchy: AmgHierarchy,
894}
895
896impl GmresWithAmg {
897    /// Create a new AMG-preconditioned GMRES solver.
898    pub fn new(
899        hierarchy: AmgHierarchy,
900        krylov_dim: usize,
901        max_restarts: usize,
902        tolerance: f64,
903    ) -> Self {
904        Self {
905            krylov_dim,
906            max_restarts,
907            tolerance,
908            hierarchy,
909        }
910    }
911
912    /// Solve `A * x = b` using AMG-preconditioned GMRES.
913    pub fn solve(&self, a: &CsrMatrix, b: &[f64], x: &mut [f64]) -> PcgStats {
914        let n = a.nrows;
915        let b_norm = dot_par(b, b).sqrt().max(1e-300);
916        let mut res_norm = 0.0f64;
917        let mut total_iters = 0;
918
919        let pcg = ParallelPcgSolver::new(500, 1e-10);
920        let amg_precond = crate::solvers::amg::preconditioner::AmgPreconditioner {
921            hierarchy: self.hierarchy.clone(),
922            cycle_kind: CycleKind::V,
923            pcg,
924        };
925
926        for _restart in 0..self.max_restarts {
927            // Compute r = b - A*x, then apply preconditioner: r0z = M^{-1} r
928            let mut ax = vec![0.0f64; n];
929            a.spmv(x, &mut ax);
930            let r0: Vec<f64> = b.iter().zip(ax.iter()).map(|(bi, ai)| bi - ai).collect();
931            let mut r0z = vec![0.0f64; n];
932            amg_precond.apply(&r0, &mut r0z);
933            res_norm = r0z.iter().map(|v| v * v).sum::<f64>().sqrt();
934
935            if res_norm / b_norm < self.tolerance {
936                break;
937            }
938
939            let beta = res_norm;
940            let mut q: Vec<Vec<f64>> = vec![r0z.iter().map(|v| v / beta).collect()];
941            let m = self.krylov_dim.min(n);
942            let mut h = vec![vec![0.0f64; m]; m + 1];
943            let mut cs = vec![0.0f64; m];
944            let mut sn = vec![0.0f64; m];
945            let mut e1 = vec![0.0f64; m + 1];
946            e1[0] = beta;
947
948            let mut j_stop = m;
949
950            for jj in 0..m {
951                // w = M^{-1} A q_j
952                let mut aq = vec![0.0f64; n];
953                a.spmv_par(&q[jj], &mut aq);
954                let mut w = vec![0.0f64; n];
955                amg_precond.apply(&aq, &mut w);
956
957                // Modified Gram-Schmidt
958                for ii in 0..=jj {
959                    h[ii][jj] = dot_par(&w, &q[ii]);
960                    for k in 0..n {
961                        w[k] -= h[ii][jj] * q[ii][k];
962                    }
963                }
964                h[jj + 1][jj] = w.iter().map(|v| v * v).sum::<f64>().sqrt();
965
966                // Whether or not Krylov space is exhausted, still apply Givens rotations.
967                let exact_convergence = h[jj + 1][jj] < 1e-14;
968                if !exact_convergence {
969                    let inv_norm = 1.0 / h[jj + 1][jj];
970                    q.push(w.iter().map(|v| v * inv_norm).collect());
971                }
972
973                // Apply previous Givens rotations
974                for ii in 0..jj {
975                    let tmp = cs[ii] * h[ii][jj] + sn[ii] * h[ii + 1][jj];
976                    h[ii + 1][jj] = -sn[ii] * h[ii][jj] + cs[ii] * h[ii + 1][jj];
977                    h[ii][jj] = tmp;
978                }
979
980                // New Givens rotation
981                let denom = (h[jj][jj] * h[jj][jj] + h[jj + 1][jj] * h[jj + 1][jj]).sqrt();
982                cs[jj] = if denom > 1e-300 {
983                    h[jj][jj] / denom
984                } else {
985                    1.0
986                };
987                sn[jj] = if denom > 1e-300 {
988                    h[jj + 1][jj] / denom
989                } else {
990                    0.0
991                };
992                h[jj][jj] = cs[jj] * h[jj][jj] + sn[jj] * h[jj + 1][jj];
993                h[jj + 1][jj] = 0.0;
994                e1[jj + 1] = -sn[jj] * e1[jj];
995                e1[jj] *= cs[jj];
996                res_norm = e1[jj + 1].abs();
997
998                if res_norm / b_norm < self.tolerance || exact_convergence {
999                    j_stop = jj + 1;
1000                    break;
1001                }
1002            }
1003
1004            // Back-substitution
1005            let mut y = vec![0.0f64; j_stop];
1006            for ii in (0..j_stop).rev() {
1007                y[ii] = e1[ii];
1008                for kk in (ii + 1)..j_stop {
1009                    y[ii] -= h[ii][kk] * y[kk];
1010                }
1011                if h[ii][ii].abs() > 1e-300 {
1012                    y[ii] /= h[ii][ii];
1013                }
1014            }
1015
1016            // Update solution
1017            for ii in 0..j_stop {
1018                let yi = y[ii];
1019                for k in 0..n {
1020                    x[k] += yi * q[ii][k];
1021                }
1022            }
1023
1024            total_iters += j_stop;
1025            if res_norm / b_norm < self.tolerance {
1026                break;
1027            }
1028        }
1029
1030        PcgStats {
1031            iterations: total_iters,
1032            residual_norm: res_norm,
1033            converged: res_norm / b_norm < self.tolerance,
1034        }
1035    }
1036}