1use rayon::prelude::*;
43use std::sync::Mutex;
44
/// Sparse matrix in compressed sparse row (CSR) format.
///
/// The stored entries of row `i` live at positions `row_offsets[i]..row_offsets[i + 1]`
/// of `col_indices` (column numbers) and `values` (coefficients).
#[derive(Debug, Clone, Default)]
pub struct CsrMatrix {
    /// Number of rows.
    pub nrows: usize,
    /// Number of columns.
    pub ncols: usize,
    /// Start offset of each row into `col_indices` / `values`; expected length `nrows + 1`.
    pub row_offsets: Vec<usize>,
    /// Column index of every stored entry.
    pub col_indices: Vec<usize>,
    /// Coefficient of every stored entry, parallel to `col_indices`.
    pub values: Vec<f64>,
}
65
66impl CsrMatrix {
67 pub fn new_from_pattern(
69 nrows: usize,
70 ncols: usize,
71 row_offsets: Vec<usize>,
72 col_indices: Vec<usize>,
73 ) -> Self {
74 let nnz = col_indices.len();
75 Self {
76 nrows,
77 ncols,
78 row_offsets,
79 col_indices,
80 values: vec![0.0; nnz],
81 }
82 }
83
84 pub fn identity(n: usize) -> Self {
86 Self {
87 nrows: n,
88 ncols: n,
89 row_offsets: (0..=n).collect(),
90 col_indices: (0..n).collect(),
91 values: vec![1.0; n],
92 }
93 }
94
95 pub fn nnz(&self) -> usize {
97 self.values.len()
98 }
99
100 pub fn spmv_par(&self, x: &[f64], y: &mut [f64]) {
105 debug_assert_eq!(x.len(), self.ncols);
106 debug_assert_eq!(y.len(), self.nrows);
107
108 y.par_iter_mut().enumerate().for_each(|(i, yi)| {
109 let row_start = self.row_offsets[i];
110 let row_end = self.row_offsets[i + 1];
111 let mut sum = 0.0;
112 for k in row_start..row_end {
113 sum += self.values[k] * x[self.col_indices[k]];
114 }
115 *yi = sum;
116 });
117 }
118
119 pub fn spmv(&self, x: &[f64], y: &mut [f64]) {
121 debug_assert_eq!(x.len(), self.ncols);
122 debug_assert_eq!(y.len(), self.nrows);
123 for (i, yi) in y.iter_mut().enumerate().take(self.nrows) {
124 let row_start = self.row_offsets[i];
125 let row_end = self.row_offsets[i + 1];
126 let mut sum = 0.0;
127 for k in row_start..row_end {
128 sum += self.values[k] * x[self.col_indices[k]];
129 }
130 *yi = sum;
131 }
132 }
133
134 pub fn diagonal_preconditioner(&self) -> Vec<f64> {
138 let mut diag = vec![1.0f64; self.nrows];
139 for (i, diag_i) in diag.iter_mut().enumerate().take(self.nrows) {
140 for k in self.row_offsets[i]..self.row_offsets[i + 1] {
141 if self.col_indices[k] == i {
142 let d = self.values[k];
143 if d.abs() > 1e-15 {
144 *diag_i = 1.0 / d;
145 }
146 break;
147 }
148 }
149 }
150 diag
151 }
152
153 pub fn spmv_chunked(&self, x: &[f64], y: &mut [f64], chunk_size: usize) {
158 debug_assert_eq!(x.len(), self.ncols);
159 debug_assert_eq!(y.len(), self.nrows);
160 let chunk_size = chunk_size.max(1);
161 y.par_chunks_mut(chunk_size)
162 .enumerate()
163 .for_each(|(chunk_idx, y_chunk)| {
164 let row_start = chunk_idx * chunk_size;
165 for (k, yi) in y_chunk.iter_mut().enumerate() {
166 let row = row_start + k;
167 let rs = self.row_offsets[row];
168 let re = self.row_offsets[row + 1];
169 let mut sum = 0.0;
170 for j in rs..re {
171 sum += self.values[j] * x[self.col_indices[j]];
172 }
173 *yi = sum;
174 }
175 });
176 }
177}
178
/// One element's contribution to the global system: the global DOF numbers it touches and
/// its dense local matrix.
#[derive(Debug, Clone)]
pub struct AssemblyTask {
    /// Global DOF index for each local row/column of `ke`.
    pub global_dofs: Vec<usize>,
    /// Local element matrix, row-major: entry `(li, lj)` is `ke[li * ndof + lj]`.
    pub ke: Vec<f64>,
}

impl AssemblyTask {
    /// Creates a task from its DOF map and local matrix.
    ///
    /// In debug builds, asserts that `ke` holds exactly `ndof * ndof` entries.
    pub fn new(global_dofs: Vec<usize>, ke: Vec<f64>) -> Self {
        debug_assert_eq!(ke.len(), global_dofs.len() * global_dofs.len());
        Self { global_dofs, ke }
    }

    /// Number of local degrees of freedom (side length of `ke`).
    pub fn ndof(&self) -> usize {
        self.global_dofs.len()
    }
}
206
/// Assembles element contributions into a global sparse matrix and load vector.
#[derive(Debug, Default)]
pub struct ParallelAssembler {
    /// Number of global degrees of freedom: the matrix dimension and RHS length.
    pub ndofs: usize,
}
222
223impl ParallelAssembler {
224 pub fn new(ndofs: usize) -> Self {
226 Self { ndofs }
227 }
228
229 pub fn assemble(&self, tasks: &[AssemblyTask]) -> CsrMatrix {
238 let mut row_cols: Vec<std::collections::BTreeSet<usize>> =
241 vec![std::collections::BTreeSet::new(); self.ndofs];
242
243 for task in tasks {
244 for &row_dof in &task.global_dofs {
245 for &col_dof in &task.global_dofs {
246 row_cols[row_dof].insert(col_dof);
247 }
248 }
249 }
250
251 let mut row_offsets = vec![0usize; self.ndofs + 1];
253 let mut col_indices: Vec<usize> = Vec::new();
254 for (i, cols) in row_cols.iter().enumerate() {
255 row_offsets[i + 1] = row_offsets[i] + cols.len();
256 col_indices.extend(cols.iter().copied());
257 }
258 let nnz = col_indices.len();
259 let values = vec![0.0f64; nnz];
260
261 let row_col_to_csr: Vec<std::collections::HashMap<usize, usize>> = row_cols
263 .iter()
264 .enumerate()
265 .map(|(i, cols)| {
266 let base = row_offsets[i];
267 cols.iter()
268 .enumerate()
269 .map(|(j, &c)| (c, base + j))
270 .collect()
271 })
272 .collect();
273
274 let values_locked: Vec<Mutex<f64>> = values.into_iter().map(Mutex::new).collect();
276
277 tasks.par_iter().for_each(|task| {
278 let ndof = task.ndof();
279 for (li, &row) in task.global_dofs.iter().enumerate() {
280 for (lj, &col) in task.global_dofs.iter().enumerate() {
281 let ke_val = task.ke[li * ndof + lj];
282 if let Some(&csr_idx) = row_col_to_csr[row].get(&col) {
283 let mut guard = values_locked[csr_idx]
285 .lock()
286 .unwrap_or_else(|e| e.into_inner());
287 *guard += ke_val;
288 }
289 }
290 }
291 });
292
293 let values: Vec<f64> = values_locked
294 .into_iter()
295 .map(|m| {
296 m.into_inner()
297 .expect("mutex not poisoned after parallel assembly")
298 })
299 .collect();
300
301 CsrMatrix {
302 nrows: self.ndofs,
303 ncols: self.ndofs,
304 row_offsets,
305 col_indices,
306 values,
307 }
308 }
309
310 pub fn assemble_rhs(
314 &self,
315 element_dofs: &[Vec<usize>],
316 element_forces: &[Vec<f64>],
317 ) -> Vec<f64> {
318 let rhs_locked: Vec<Mutex<f64>> = (0..self.ndofs).map(|_| Mutex::new(0.0f64)).collect();
319 element_dofs
320 .par_iter()
321 .zip(element_forces.par_iter())
322 .for_each(|(dofs, forces)| {
323 for (&dof, &f) in dofs.iter().zip(forces.iter()) {
324 *rhs_locked[dof].lock().unwrap_or_else(|e| e.into_inner()) += f;
325 }
326 });
327 rhs_locked
328 .into_iter()
329 .map(|m| {
330 m.into_inner()
331 .expect("mutex not poisoned after parallel rhs assembly")
332 })
333 .collect()
334 }
335
336 pub fn assemble_colored(
342 &self,
343 tasks: &[AssemblyTask],
344 element_dofs: &[Vec<usize>],
345 ) -> CsrMatrix {
346 use crate::solvers::assembly_coloring::{assemble_colored_csr, color_elements};
347 let coloring = color_elements(tasks.len(), element_dofs);
348 assemble_colored_csr(self.ndofs, tasks, &coloring)
349 }
350}
351
/// Summary returned by the iterative solvers in this module.
#[derive(Debug, Clone, Copy)]
pub struct PcgStats {
    /// Number of iterations performed.
    pub iterations: usize,
    /// Final residual norm, as tracked by the solver that produced these stats.
    pub residual_norm: f64,
    /// Whether the solver's stopping criterion was satisfied.
    pub converged: bool,
}
364
/// Conjugate-gradient solver with diagonal (Jacobi) preconditioning and rayon-parallel kernels.
#[derive(Debug, Clone)]
pub struct ParallelPcgSolver {
    /// Upper bound on the number of CG iterations.
    pub max_iterations: usize,
    /// Relative stopping tolerance on `‖r‖ / ‖b‖`.
    pub tolerance: f64,
    /// Absolute stopping tolerance on `‖r‖`.
    pub abs_tolerance: f64,
}

impl Default for ParallelPcgSolver {
    /// 500 iterations, relative tolerance `1e-8`, absolute tolerance `1e-14`.
    fn default() -> Self {
        let (max_iterations, tolerance, abs_tolerance) = (500, 1e-8, 1e-14);
        Self {
            max_iterations,
            tolerance,
            abs_tolerance,
        }
    }
}
390
391impl ParallelPcgSolver {
392 pub fn new(max_iterations: usize, tolerance: f64) -> Self {
394 Self {
395 max_iterations,
396 tolerance,
397 abs_tolerance: 1e-14,
398 }
399 }
400
401 pub fn solve(&self, a: &CsrMatrix, b: &[f64], x: &mut [f64]) -> PcgStats {
407 let n = a.nrows;
408 assert_eq!(b.len(), n);
409 assert_eq!(x.len(), n);
410
411 let m_inv = a.diagonal_preconditioner();
412
413 let mut ax = vec![0.0f64; n];
415 a.spmv_par(x, &mut ax);
416 let mut r: Vec<f64> = b.iter().zip(ax.iter()).map(|(bi, ai)| bi - ai).collect();
417
418 let b_norm = dot_par(b, b).sqrt();
419 if b_norm < self.abs_tolerance {
420 return PcgStats {
421 iterations: 0,
422 residual_norm: 0.0,
423 converged: true,
424 };
425 }
426
427 let mut z: Vec<f64> = r.iter().zip(m_inv.iter()).map(|(ri, mi)| ri * mi).collect();
429
430 let mut p = z.clone();
432
433 let mut rz = dot_par(&r, &z);
434
435 let mut ap = vec![0.0f64; n];
436 let mut iters = 0;
437 let mut res_norm = r.iter().map(|ri| ri * ri).sum::<f64>().sqrt();
438
439 for _ in 0..self.max_iterations {
440 if res_norm / b_norm < self.tolerance {
441 break;
442 }
443 if res_norm < self.abs_tolerance {
444 break;
445 }
446
447 a.spmv_par(&p, &mut ap);
449
450 let pap = dot_par(&p, &ap);
451 if pap.abs() < 1e-300 {
452 break;
453 }
454 let alpha = rz / pap;
455
456 x.par_iter_mut()
458 .zip(p.par_iter())
459 .for_each(|(xi, pi)| *xi += alpha * pi);
460 r.par_iter_mut()
461 .zip(ap.par_iter())
462 .for_each(|(ri, api)| *ri -= alpha * api);
463
464 z.par_iter_mut()
466 .zip(r.par_iter().zip(m_inv.par_iter()))
467 .for_each(|(zi, (ri, mi))| *zi = ri * mi);
468
469 let rz_new = dot_par(&r, &z);
470 let beta = rz_new / rz.max(1e-300);
471 rz = rz_new;
472
473 p.par_iter_mut()
475 .zip(z.par_iter())
476 .for_each(|(pi, zi)| *pi = zi + beta * *pi);
477
478 res_norm = r.iter().map(|ri| ri * ri).sum::<f64>().sqrt();
479 iters += 1;
480 }
481
482 PcgStats {
483 iterations: iters,
484 residual_norm: res_norm,
485 converged: res_norm / b_norm.max(1e-300) < self.tolerance
486 || res_norm < self.abs_tolerance,
487 }
488 }
489}
490
491fn dot_par(a: &[f64], b: &[f64]) -> f64 {
493 a.par_iter().zip(b.par_iter()).map(|(ai, bi)| ai * bi).sum()
494}
495
/// Restarted GMRES with left diagonal (Jacobi) preconditioning and rayon-parallel kernels.
#[derive(Debug, Clone)]
pub struct ParallelGmresSolver {
    /// Maximum Krylov subspace dimension per restart cycle (capped at the system size).
    pub krylov_dim: usize,
    /// Maximum number of restart cycles.
    pub max_restarts: usize,
    /// Relative stopping tolerance on the preconditioned residual norm.
    pub tolerance: f64,
}

impl Default for ParallelGmresSolver {
    /// GMRES(30) with at most 10 restarts and relative tolerance `1e-8`.
    fn default() -> Self {
        let (krylov_dim, max_restarts, tolerance) = (30, 10, 1e-8);
        Self {
            krylov_dim,
            max_restarts,
            tolerance,
        }
    }
}
521
522impl ParallelGmresSolver {
523 pub fn solve(&self, a: &CsrMatrix, b: &[f64], x: &mut [f64]) -> PcgStats {
528 let n = a.nrows;
529 let m_inv = a.diagonal_preconditioner();
530 let mut res_norm = 0.0f64;
531 let mb: Vec<f64> = b.iter().zip(m_inv.iter()).map(|(bi, mi)| bi * mi).collect();
533 let b_norm = dot_par(&mb, &mb).sqrt().max(1e-300);
534 let mut total_iters = 0;
535
536 for _restart in 0..self.max_restarts {
537 let mut ax = vec![0.0f64; n];
539 a.spmv_par(x, &mut ax);
540 let r0: Vec<f64> = b.iter().zip(ax.iter()).map(|(bi, ai)| bi - ai).collect();
541
542 let z0: Vec<f64> = r0
544 .iter()
545 .zip(m_inv.iter())
546 .map(|(ri, mi)| ri * mi)
547 .collect();
548 let beta = dot_par(&z0, &z0).sqrt().max(1e-300);
549 res_norm = beta;
550
551 if res_norm / b_norm < self.tolerance {
552 break;
553 }
554
555 let m = self.krylov_dim.min(n);
557 let mut q: Vec<Vec<f64>> = Vec::with_capacity(m + 1);
558
559 q.push(z0.iter().map(|v| v / beta).collect());
561
562 let mut h = vec![0.0f64; (m + 1) * m];
564 let mut cs = vec![0.0f64; m];
566 let mut sn = vec![0.0f64; m];
567 let mut e1 = vec![0.0f64; m + 1];
569 e1[0] = beta;
570
571 let mut j_stop = m;
572 for j in 0..m {
573 let mut aqj = vec![0.0f64; n];
575 a.spmv_par(&q[j], &mut aqj);
576 let w: Vec<f64> = aqj.iter().zip(m_inv.iter()).map(|(v, mi)| v * mi).collect();
577
578 let mut w = w;
580 for i in 0..=j {
581 let hij = dot_par(&w, &q[i]);
582 h[i * m + j] = hij;
583 w.par_iter_mut()
584 .zip(q[i].par_iter())
585 .for_each(|(wi, qi)| *wi -= hij * qi);
586 }
587 let w_norm = dot_par(&w, &w).sqrt();
588 h[(j + 1) * m + j] = w_norm;
589
590 let exact_convergence = w_norm < 1e-14;
595 if !exact_convergence {
596 q.push(w.iter().map(|v| v / w_norm).collect());
597 }
598
599 for i in 0..j {
601 let tmp = cs[i] * h[i * m + j] + sn[i] * h[(i + 1) * m + j];
602 h[(i + 1) * m + j] = -sn[i] * h[i * m + j] + cs[i] * h[(i + 1) * m + j];
603 h[i * m + j] = tmp;
604 }
605
606 let (c, s) = givens_rotation(h[j * m + j], h[(j + 1) * m + j]);
608 cs[j] = c;
609 sn[j] = s;
610
611 h[j * m + j] = c * h[j * m + j] + s * h[(j + 1) * m + j];
612 h[(j + 1) * m + j] = 0.0;
613 e1[j + 1] = -s * e1[j];
614 e1[j] *= c;
615
616 res_norm = e1[j + 1].abs();
617 total_iters += 1;
618
619 if res_norm / b_norm < self.tolerance || exact_convergence {
621 j_stop = j + 1;
622 break;
623 }
624 }
625
626 let mut y = vec![0.0f64; j_stop];
628 for i in (0..j_stop).rev() {
629 y[i] = e1[i];
630 for k in (i + 1)..j_stop {
631 y[i] -= h[i * m + k] * y[k];
632 }
633 let hii = h[i * m + i];
634 if hii.abs() > 1e-300 {
635 y[i] /= hii;
636 }
637 }
638
639 for j in 0..j_stop {
641 let yj = y[j];
642 x.par_iter_mut()
643 .zip(q[j].par_iter())
644 .for_each(|(xi, qji)| *xi += yj * qji);
645 }
646
647 if res_norm / b_norm < self.tolerance {
648 break;
649 }
650 }
651
652 PcgStats {
653 iterations: total_iters,
654 residual_norm: res_norm,
655 converged: res_norm / b_norm < self.tolerance,
656 }
657 }
658}
659
/// Returns `(c, s)` for the plane rotation that annihilates `b` in the pair `(a, b)`.
///
/// With `r = hypot(a, b)`, the result is `c = a / r` and `s = b / r`. When `|b| < 1e-300`,
/// the identity rotation `(1, 0)` is returned instead.
fn givens_rotation(a: f64, b: f64) -> (f64, f64) {
    if b.abs() < 1e-300 {
        return (1.0, 0.0);
    }
    let radius = a.hypot(b);
    (a / radius, b / radius)
}
669
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `n × n` diagonal CSR matrix with `diag_val` on every diagonal entry.
    fn diag_matrix(n: usize, diag_val: f64) -> CsrMatrix {
        CsrMatrix {
            values: vec![diag_val; n],
            col_indices: (0..n).collect(),
            row_offsets: (0..=n).collect(),
            ncols: n,
            nrows: n,
        }
    }

    /// Asserts that `x[i]` is within `tol` of `(i + 1) / divisor` for every entry.
    fn check_scaled_ramp(x: &[f64], divisor: f64, tol: f64) {
        for (i, xi) in x.iter().enumerate() {
            let expected = (i + 1) as f64 / divisor;
            assert!(
                (xi - expected).abs() < tol,
                "x[{i}] = {xi}, expected {expected}"
            );
        }
    }

    /// Asserts element-wise agreement of two vectors to within `1e-14`.
    fn assert_all_close(lhs: &[f64], rhs: &[f64]) {
        for (a, b) in lhs.iter().zip(rhs) {
            assert!((a - b).abs() < 1e-14, "{a} != {b}");
        }
    }

    #[test]
    fn pcg_solves_diagonal_system() {
        let n = 16;
        let rhs: Vec<f64> = (1..=n).map(|k| k as f64).collect();
        let mut x = vec![0.0f64; n];
        let stats = ParallelPcgSolver::default().solve(&diag_matrix(n, 4.0), &rhs, &mut x);
        assert!(stats.converged, "PCG did not converge: {:?}", stats);
        check_scaled_ramp(&x, 4.0, 1e-10);
    }

    #[test]
    fn parallel_spmv_matches_sequential() {
        let n = 64;
        let mat = diag_matrix(n, 3.0);
        let x: Vec<f64> = (0..n).map(|i| i as f64).collect();
        let (mut y_par, mut y_seq) = (vec![0.0f64; n], vec![0.0f64; n]);
        mat.spmv_par(&x, &mut y_par);
        mat.spmv(&x, &mut y_seq);
        assert_all_close(&y_par, &y_seq);
    }

    #[test]
    fn parallel_assembler_assembles_2d_bar() {
        let bar = [1.0, -1.0, -1.0, 1.0];
        let tasks = vec![
            AssemblyTask::new(vec![0, 1], bar.to_vec()),
            AssemblyTask::new(vec![1, 2], bar.to_vec()),
        ];
        let mat = ParallelAssembler::new(3).assemble(&tasks);
        assert_eq!(mat.nrows, 3);
        let find_val = |row: usize, col: usize| -> f64 {
            (mat.row_offsets[row]..mat.row_offsets[row + 1])
                .find(|&k| mat.col_indices[k] == col)
                .map_or(0.0, |k| mat.values[k])
        };
        let expected = [
            ((0, 0), 1.0),
            ((1, 1), 2.0),
            ((2, 2), 1.0),
            ((0, 1), -1.0),
            ((1, 2), -1.0),
        ];
        for &((row, col), want) in expected.iter() {
            assert!((find_val(row, col) - want).abs() < 1e-14);
        }
    }

    #[test]
    fn assemble_rhs_sums_forces() {
        let dofs = vec![vec![0usize, 1], vec![1, 2]];
        let forces = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let rhs = ParallelAssembler::new(3).assemble_rhs(&dofs, &forces);
        for (i, want) in [1.0, 5.0, 4.0].iter().enumerate() {
            assert!((rhs[i] - want).abs() < 1e-14);
        }
    }

    #[test]
    fn gmres_solves_diagonal_system() {
        let n = 8;
        let rhs: Vec<f64> = (1..=n).map(|k| k as f64).collect();
        let mut x = vec![0.0f64; n];
        let stats = ParallelGmresSolver::default().solve(&diag_matrix(n, 2.0), &rhs, &mut x);
        assert!(stats.converged, "GMRES did not converge: {:?}", stats);
        check_scaled_ramp(&x, 2.0, 1e-6);
    }

    #[test]
    fn spmv_chunked_matches_spmv() {
        let n = 64;
        let mat = diag_matrix(n, 3.0);
        let x: Vec<f64> = (0..n).map(|i| i as f64).collect();
        let (mut y_par, mut y_chunked) = (vec![0.0f64; n], vec![0.0f64; n]);
        mat.spmv_par(&x, &mut y_par);
        mat.spmv_chunked(&x, &mut y_chunked, 16);
        assert_all_close(&y_par, &y_chunked);
    }
}
785
786use crate::solvers::amg::{
789 cycle::{AmgHierarchy, CycleKind},
790 preconditioner::Preconditioner,
791};
792
793#[derive(Debug)]
798pub struct PcgWithPrecond<P: Preconditioner> {
799 pub max_iterations: usize,
801 pub tolerance: f64,
803 pub abs_tolerance: f64,
805 pub precond: P,
807}
808
809impl<P: Preconditioner> PcgWithPrecond<P> {
810 pub fn new(precond: P, max_iterations: usize, tolerance: f64) -> Self {
812 Self {
813 max_iterations,
814 tolerance,
815 abs_tolerance: 1e-14,
816 precond,
817 }
818 }
819
820 pub fn solve(&self, a: &CsrMatrix, b: &[f64], x: &mut [f64]) -> PcgStats {
824 let n = a.nrows;
825 let mut z = vec![0.0f64; n];
826
827 let mut ax = vec![0.0f64; n];
829 a.spmv(x, &mut ax);
830 let mut r: Vec<f64> = b.iter().zip(ax.iter()).map(|(bi, ai)| bi - ai).collect();
831
832 self.precond.apply(&r, &mut z);
833 let mut p = z.clone();
834 let mut rz = dot_par(&r, &z);
835 let b_norm = dot_par(b, b).sqrt().max(1e-300);
836 let mut ap = vec![0.0f64; n];
837 let mut iters = 0;
838 let mut res_norm = r.iter().map(|ri| ri * ri).sum::<f64>().sqrt();
839
840 for _ in 0..self.max_iterations {
841 if res_norm / b_norm < self.tolerance {
842 break;
843 }
844 if res_norm < self.abs_tolerance {
845 break;
846 }
847
848 a.spmv_par(&p, &mut ap);
849 let pap = dot_par(&p, &ap);
850 if pap.abs() < 1e-300 {
851 break;
852 }
853 let alpha = rz / pap;
854
855 for i in 0..n {
856 x[i] += alpha * p[i];
857 r[i] -= alpha * ap[i];
858 }
859 self.precond.apply(&r, &mut z);
860 let rz_new = dot_par(&r, &z);
861 let beta = rz_new / rz.max(1e-300);
862 rz = rz_new;
863 for i in 0..n {
864 p[i] = z[i] + beta * p[i];
865 }
866 res_norm = r.iter().map(|ri| ri * ri).sum::<f64>().sqrt();
867 iters += 1;
868 }
869
870 PcgStats {
871 iterations: iters,
872 residual_norm: res_norm,
873 converged: res_norm / b_norm < self.tolerance || res_norm < self.abs_tolerance,
874 }
875 }
876}
877
878pub type PcgWithAmg = PcgWithPrecond<crate::solvers::amg::preconditioner::AmgPreconditioner>;
880
881pub struct GmresWithAmg {
886 pub krylov_dim: usize,
888 pub max_restarts: usize,
890 pub tolerance: f64,
892 pub hierarchy: AmgHierarchy,
894}
895
896impl GmresWithAmg {
897 pub fn new(
899 hierarchy: AmgHierarchy,
900 krylov_dim: usize,
901 max_restarts: usize,
902 tolerance: f64,
903 ) -> Self {
904 Self {
905 krylov_dim,
906 max_restarts,
907 tolerance,
908 hierarchy,
909 }
910 }
911
912 pub fn solve(&self, a: &CsrMatrix, b: &[f64], x: &mut [f64]) -> PcgStats {
914 let n = a.nrows;
915 let b_norm = dot_par(b, b).sqrt().max(1e-300);
916 let mut res_norm = 0.0f64;
917 let mut total_iters = 0;
918
919 let pcg = ParallelPcgSolver::new(500, 1e-10);
920 let amg_precond = crate::solvers::amg::preconditioner::AmgPreconditioner {
921 hierarchy: self.hierarchy.clone(),
922 cycle_kind: CycleKind::V,
923 pcg,
924 };
925
926 for _restart in 0..self.max_restarts {
927 let mut ax = vec![0.0f64; n];
929 a.spmv(x, &mut ax);
930 let r0: Vec<f64> = b.iter().zip(ax.iter()).map(|(bi, ai)| bi - ai).collect();
931 let mut r0z = vec![0.0f64; n];
932 amg_precond.apply(&r0, &mut r0z);
933 res_norm = r0z.iter().map(|v| v * v).sum::<f64>().sqrt();
934
935 if res_norm / b_norm < self.tolerance {
936 break;
937 }
938
939 let beta = res_norm;
940 let mut q: Vec<Vec<f64>> = vec![r0z.iter().map(|v| v / beta).collect()];
941 let m = self.krylov_dim.min(n);
942 let mut h = vec![vec![0.0f64; m]; m + 1];
943 let mut cs = vec![0.0f64; m];
944 let mut sn = vec![0.0f64; m];
945 let mut e1 = vec![0.0f64; m + 1];
946 e1[0] = beta;
947
948 let mut j_stop = m;
949
950 for jj in 0..m {
951 let mut aq = vec![0.0f64; n];
953 a.spmv_par(&q[jj], &mut aq);
954 let mut w = vec![0.0f64; n];
955 amg_precond.apply(&aq, &mut w);
956
957 for ii in 0..=jj {
959 h[ii][jj] = dot_par(&w, &q[ii]);
960 for k in 0..n {
961 w[k] -= h[ii][jj] * q[ii][k];
962 }
963 }
964 h[jj + 1][jj] = w.iter().map(|v| v * v).sum::<f64>().sqrt();
965
966 let exact_convergence = h[jj + 1][jj] < 1e-14;
968 if !exact_convergence {
969 let inv_norm = 1.0 / h[jj + 1][jj];
970 q.push(w.iter().map(|v| v * inv_norm).collect());
971 }
972
973 for ii in 0..jj {
975 let tmp = cs[ii] * h[ii][jj] + sn[ii] * h[ii + 1][jj];
976 h[ii + 1][jj] = -sn[ii] * h[ii][jj] + cs[ii] * h[ii + 1][jj];
977 h[ii][jj] = tmp;
978 }
979
980 let denom = (h[jj][jj] * h[jj][jj] + h[jj + 1][jj] * h[jj + 1][jj]).sqrt();
982 cs[jj] = if denom > 1e-300 {
983 h[jj][jj] / denom
984 } else {
985 1.0
986 };
987 sn[jj] = if denom > 1e-300 {
988 h[jj + 1][jj] / denom
989 } else {
990 0.0
991 };
992 h[jj][jj] = cs[jj] * h[jj][jj] + sn[jj] * h[jj + 1][jj];
993 h[jj + 1][jj] = 0.0;
994 e1[jj + 1] = -sn[jj] * e1[jj];
995 e1[jj] *= cs[jj];
996 res_norm = e1[jj + 1].abs();
997
998 if res_norm / b_norm < self.tolerance || exact_convergence {
999 j_stop = jj + 1;
1000 break;
1001 }
1002 }
1003
1004 let mut y = vec![0.0f64; j_stop];
1006 for ii in (0..j_stop).rev() {
1007 y[ii] = e1[ii];
1008 for kk in (ii + 1)..j_stop {
1009 y[ii] -= h[ii][kk] * y[kk];
1010 }
1011 if h[ii][ii].abs() > 1e-300 {
1012 y[ii] /= h[ii][ii];
1013 }
1014 }
1015
1016 for ii in 0..j_stop {
1018 let yi = y[ii];
1019 for k in 0..n {
1020 x[k] += yi * q[ii][k];
1021 }
1022 }
1023
1024 total_iters += j_stop;
1025 if res_norm / b_norm < self.tolerance {
1026 break;
1027 }
1028 }
1029
1030 PcgStats {
1031 iterations: total_iters,
1032 residual_norm: res_norm,
1033 converged: res_norm / b_norm < self.tolerance,
1034 }
1035 }
1036}