math_audio_solvers/preconditioners/
amg.rs

1//! Algebraic Multigrid (AMG) Preconditioner
2//!
3//! This module implements an algebraic multigrid preconditioner inspired by
4//! hypre's BoomerAMG, designed for better parallel scalability across CPU cores.
5//!
6//! ## Features
7//!
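//! - **Coarsening**: classical Ruge-Stüben (RS, sequential C-point selection) and
//!   parallel-friendly PMIS; HMIS is accepted but currently runs the PMIS routine
//! - **Interpolation**: direct, standard (direct with weak couplings lumped in) and
//!   extended (distance-two) interpolation operators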
10//! - **Smoothers**: Jacobi (fully parallel) and symmetric Gauss-Seidel
11//! - **V-cycle**: Standard V(ν₁, ν₂) cycling with configurable pre/post smoothing
12//!
13//! ## Scalability
14//!
15//! The AMG preconditioner scales better than ILU across multiple cores because:
16//! - Coarsening can be parallelized (PMIS is inherently parallel)
17//! - Jacobi smoothing is embarrassingly parallel
18//! - Each level's operations can be parallelized independently
19//!
20//! ## Usage
21//!
22//! ```ignore
23//! use math_audio_solvers::{AmgPreconditioner, AmgConfig, CsrMatrix};
24//!
25//! let config = AmgConfig::default();
26//! let precond = AmgPreconditioner::from_csr(&matrix, config);
27//!
28//! // Use with GMRES
29//! let z = precond.apply(&residual);
30//! ```
31
32#[cfg(any(feature = "native", feature = "wasm"))]
33use crate::parallel::{parallel_enumerate_map, parallel_map_indexed};
34use crate::sparse::CsrMatrix;
35use crate::traits::{ComplexField, Preconditioner};
36use ndarray::Array1;
37use num_traits::FromPrimitive;
38use std::collections::HashMap;
39
/// Strategy used to split each level's points into coarse (C) and fine (F) sets
/// while the AMG hierarchy is being built.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum AmgCoarsening {
    /// Classical Ruge-Stüben splitting (the default).
    ///
    /// Produces good-quality coarse grids, but its point-selection phase offers
    /// little parallelism.
    #[default]
    RugeStuben,

    /// Parallel Modified Independent Set (PMIS).
    ///
    /// Scales better in parallel; the resulting coarse grids may be somewhat larger.
    Pmis,

    /// Hybrid MIS: a PMIS first pass followed by an RS-style cleanup, balancing
    /// coarse-grid quality against parallelism.
    ///
    /// NOTE(review): `AmgPreconditioner::from_csr` currently dispatches this variant
    /// to the same routine as [`AmgCoarsening::Pmis`].
    Hmis,
}
56
/// Rule used to build the interpolation (prolongation) weights of F-points.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum AmgInterpolation {
    /// Standard interpolation (the default): weights come from the coarse
    /// neighbours a point is directly connected to.
    #[default]
    Standard,

    /// Extended interpolation: additionally follows indirect (distance-two)
    /// connections.
    ///
    /// Helps on some problem classes, at a higher setup and memory cost.
    Extended,

    /// Direct interpolation: the simplest rule, using only immediate strong
    /// connections.
    ///
    /// Cheapest to build, but convergence can suffer on hard problems.
    Direct,
}
72
/// Relaxation method applied before and after each coarse-grid correction.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum AmgSmoother {
    /// Damped Jacobi (the default): fully parallel; needs a damping factor,
    /// typically ω ≈ 0.6-0.8.
    #[default]
    Jacobi,

    /// l1-Jacobi: Jacobi with l1-norm row scaling, which is more robust.
    L1Jacobi,

    /// Symmetric Gauss-Seidel: a forward sweep followed by a backward sweep.
    ///
    /// Converges better than Jacobi but offers limited parallelism.
    SymmetricGaussSeidel,

    /// Chebyshev polynomial smoother: fully parallel and needs no damping factor.
    Chebyshev,
}
90
/// Shape of the multigrid cycle.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum AmgCycle {
    /// V-cycle (the default): each level is visited once on the way down and once
    /// on the way up.
    #[default]
    VCycle,

    /// W-cycle: coarser levels are visited twice, which costs more per cycle.
    WCycle,

    /// F-cycle: a hybrid between the V- and W-cycle.
    FCycle,
}
104
105/// Configuration for AMG preconditioner
106#[derive(Debug, Clone)]
107pub struct AmgConfig {
108    /// Coarsening algorithm
109    pub coarsening: AmgCoarsening,
110
111    /// Interpolation operator type
112    pub interpolation: AmgInterpolation,
113
114    /// Smoother for pre- and post-relaxation
115    pub smoother: AmgSmoother,
116
117    /// Cycle type (V, W, or F)
118    pub cycle: AmgCycle,
119
120    /// Strong connection threshold (default: 0.25)
121    /// Connections with |a_ij| >= θ * max_k |a_ik| are considered strong
122    pub strong_threshold: f64,
123
124    /// Maximum number of levels in the hierarchy
125    pub max_levels: usize,
126
127    /// Coarsest level size - switch to direct solve below this
128    pub coarse_size: usize,
129
130    /// Number of pre-smoothing sweeps (ν₁)
131    pub num_pre_smooth: usize,
132
133    /// Number of post-smoothing sweeps (ν₂)
134    pub num_post_smooth: usize,
135
136    /// Jacobi damping parameter (ω)
137    pub jacobi_weight: f64,
138
139    /// Truncation factor for interpolation (drop small weights)
140    pub trunc_factor: f64,
141
142    /// Maximum interpolation stencil size per row
143    pub max_interp_elements: usize,
144
145    /// Enable aggressive coarsening on first few levels
146    pub aggressive_coarsening_levels: usize,
147}
148
149impl Default for AmgConfig {
150    fn default() -> Self {
151        Self {
152            coarsening: AmgCoarsening::default(),
153            interpolation: AmgInterpolation::default(),
154            smoother: AmgSmoother::default(),
155            cycle: AmgCycle::default(),
156            strong_threshold: 0.25,
157            max_levels: 25,
158            coarse_size: 50,
159            num_pre_smooth: 1,
160            num_post_smooth: 1,
161            jacobi_weight: 0.6667, // 2/3 is optimal for Poisson
162            trunc_factor: 0.0,
163            max_interp_elements: 4,
164            aggressive_coarsening_levels: 0,
165        }
166    }
167}
168
169impl AmgConfig {
170    /// Configuration optimized for BEM systems
171    ///
172    /// BEM matrices are typically denser and less sparse than FEM,
173    /// requiring adjusted thresholds.
174    pub fn for_bem() -> Self {
175        Self {
176            strong_threshold: 0.5,           // Higher for denser BEM matrices
177            coarsening: AmgCoarsening::Pmis, // Better parallel scalability
178            smoother: AmgSmoother::L1Jacobi, // More robust for BEM
179            max_interp_elements: 6,
180            ..Default::default()
181        }
182    }
183
184    /// Configuration optimized for FEM systems
185    pub fn for_fem() -> Self {
186        Self {
187            strong_threshold: 0.25,
188            coarsening: AmgCoarsening::RugeStuben,
189            smoother: AmgSmoother::SymmetricGaussSeidel,
190            ..Default::default()
191        }
192    }
193
194    /// Configuration optimized for maximum parallel scalability
195    pub fn for_parallel() -> Self {
196        Self {
197            coarsening: AmgCoarsening::Pmis,
198            smoother: AmgSmoother::Jacobi,
199            jacobi_weight: 0.8,
200            num_pre_smooth: 2,
201            num_post_smooth: 2,
202            ..Default::default()
203        }
204    }
205
206    /// Configuration for difficult/ill-conditioned problems
207    pub fn for_difficult_problems() -> Self {
208        Self {
209            coarsening: AmgCoarsening::RugeStuben,
210            interpolation: AmgInterpolation::Extended,
211            smoother: AmgSmoother::SymmetricGaussSeidel,
212            strong_threshold: 0.25,
213            max_interp_elements: 8,
214            num_pre_smooth: 2,
215            num_post_smooth: 2,
216            ..Default::default()
217        }
218    }
219}
220
/// Classification of a point during C/F splitting.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum PointType {
    /// Not yet assigned to either set.
    Undecided,
    /// Kept on the next coarser level (C-point).
    Coarse,
    /// Interpolated from coarse points (F-point).
    Fine,
}
231
232/// Single level in the AMG hierarchy
233#[derive(Debug, Clone)]
234struct AmgLevel<T: ComplexField> {
235    /// System matrix A at this level (CSR format)
236    matrix: CsrMatrix<T>,
237
238    /// Prolongation operator P: coarse -> fine
239    prolongation: Option<CsrMatrix<T>>,
240
241    /// Restriction operator R: fine -> coarse (typically R = P^T)
242    restriction: Option<CsrMatrix<T>>,
243
244    /// Inverse diagonal for Jacobi smoothing
245    diag_inv: Array1<T>,
246
247    /// Coarse-to-fine mapping
248    coarse_to_fine: Vec<usize>,
249
250    /// Number of DOFs at this level
251    num_dofs: usize,
252}
253
254/// Algebraic Multigrid Preconditioner
255///
256/// Implements a classical AMG V-cycle preconditioner with configurable
257/// coarsening, interpolation, and smoothing strategies.
258#[derive(Debug, Clone)]
259pub struct AmgPreconditioner<T: ComplexField> {
260    /// AMG hierarchy (finest to coarsest)
261    levels: Vec<AmgLevel<T>>,
262
263    /// Configuration
264    config: AmgConfig,
265
266    /// Statistics
267    setup_time_ms: f64,
268    grid_complexity: f64,
269    operator_complexity: f64,
270}
271
272impl<T: ComplexField> AmgPreconditioner<T>
273where
274    T::Real: Sync + Send,
275{
276    /// Create AMG preconditioner from a CSR matrix
277    pub fn from_csr(matrix: &CsrMatrix<T>, config: AmgConfig) -> Self {
278        let start = std::time::Instant::now();
279
280        let mut levels = Vec::new();
281        let mut current_matrix = matrix.clone();
282
283        // Extract diagonal for first level
284        let diag_inv = Self::compute_diag_inv(&current_matrix);
285
286        levels.push(AmgLevel {
287            matrix: current_matrix.clone(),
288            prolongation: None,
289            restriction: None,
290            diag_inv,
291            coarse_to_fine: Vec::new(),
292            num_dofs: current_matrix.num_rows,
293        });
294
295        // Build hierarchy
296        for _level_idx in 0..config.max_levels - 1 {
297            let n = current_matrix.num_rows;
298            if n <= config.coarse_size {
299                break;
300            }
301
302            // Compute strength matrix
303            let strong_connections =
304                Self::compute_strength_matrix(&current_matrix, config.strong_threshold);
305
306            // Coarsening
307            let (point_types, coarse_to_fine) = match config.coarsening {
308                AmgCoarsening::RugeStuben => {
309                    Self::coarsen_ruge_stuben(&current_matrix, &strong_connections)
310                }
311                AmgCoarsening::Pmis | AmgCoarsening::Hmis => {
312                    Self::coarsen_pmis(&current_matrix, &strong_connections)
313                }
314            };
315
316            let num_coarse = coarse_to_fine.len();
317            if num_coarse == 0 || num_coarse >= n {
318                // Can't coarsen further
319                break;
320            }
321
322            // Build interpolation operator
323            let prolongation = Self::build_interpolation(
324                &current_matrix,
325                &strong_connections,
326                &point_types,
327                &coarse_to_fine,
328                &config,
329            );
330
331            // Restriction is transpose of prolongation
332            let restriction = Self::transpose_csr(&prolongation);
333
334            // Galerkin coarse grid: A_c = R * A * P
335            let coarse_matrix =
336                Self::galerkin_product(&restriction, &current_matrix, &prolongation);
337
338            // Extract diagonal for new level
339            let coarse_diag_inv = Self::compute_diag_inv(&coarse_matrix);
340
341            // Update level with P and R
342            if let Some(last) = levels.last_mut() {
343                last.prolongation = Some(prolongation);
344                last.restriction = Some(restriction);
345                last.coarse_to_fine = coarse_to_fine;
346            }
347
348            // Add coarse level
349            levels.push(AmgLevel {
350                matrix: coarse_matrix.clone(),
351                prolongation: None,
352                restriction: None,
353                diag_inv: coarse_diag_inv,
354                coarse_to_fine: Vec::new(),
355                num_dofs: num_coarse,
356            });
357
358            current_matrix = coarse_matrix;
359        }
360
361        let setup_time_ms = start.elapsed().as_secs_f64() * 1000.0;
362
363        // Compute complexities
364        let (grid_complexity, operator_complexity) = Self::compute_complexities(&levels);
365
366        Self {
367            levels,
368            config,
369            setup_time_ms,
370            grid_complexity,
371            operator_complexity,
372        }
373    }
374
375    /// Get number of levels in hierarchy
376    pub fn num_levels(&self) -> usize {
377        self.levels.len()
378    }
379
380    /// Get setup time in milliseconds
381    pub fn setup_time_ms(&self) -> f64 {
382        self.setup_time_ms
383    }
384
385    /// Get grid complexity (sum of DOFs / fine DOFs)
386    pub fn grid_complexity(&self) -> f64 {
387        self.grid_complexity
388    }
389
390    /// Get operator complexity (sum of nnz / fine nnz)
391    pub fn operator_complexity(&self) -> f64 {
392        self.operator_complexity
393    }
394
395    /// Get configuration
396    pub fn config(&self) -> &AmgConfig {
397        &self.config
398    }
399
400    /// Compute inverse diagonal for Jacobi smoothing
401    fn compute_diag_inv(matrix: &CsrMatrix<T>) -> Array1<T> {
402        let n = matrix.num_rows;
403        let mut diag_inv = Array1::from_elem(n, T::one());
404
405        for i in 0..n {
406            let diag = matrix.get(i, i);
407            let tol = T::Real::from_f64(1e-15).unwrap();
408            if diag.norm() > tol {
409                diag_inv[i] = diag.inv();
410            }
411        }
412
413        diag_inv
414    }
415
416    /// Compute strength of connection matrix
417    ///
418    /// Entry (i,j) is strong if |a_ij| >= θ * max_k!=i |a_ik|
419    fn compute_strength_matrix(matrix: &CsrMatrix<T>, theta: f64) -> Vec<Vec<usize>> {
420        let n = matrix.num_rows;
421
422        #[cfg(any(feature = "native", feature = "wasm"))]
423        {
424            parallel_map_indexed(n, |i| {
425                // Find max off-diagonal magnitude in row i
426                let mut max_off_diag = T::Real::from_f64(0.0).unwrap();
427                for (j, val) in matrix.row_entries(i) {
428                    if i != j {
429                        let norm = val.norm();
430                        if norm > max_off_diag {
431                            max_off_diag = norm;
432                        }
433                    }
434                }
435
436                let threshold = T::Real::from_f64(theta).unwrap() * max_off_diag;
437
438                // Collect strong connections
439                let mut row_strong = Vec::new();
440                for (j, val) in matrix.row_entries(i) {
441                    if i != j && val.norm() >= threshold {
442                        row_strong.push(j);
443                    }
444                }
445                row_strong
446            })
447        }
448
449        #[cfg(not(any(feature = "native", feature = "wasm")))]
450        {
451            let mut strong: Vec<Vec<usize>> = vec![Vec::new(); n];
452            for (i, row_strong) in strong.iter_mut().enumerate().take(n) {
453                // Find max off-diagonal magnitude in row i
454                let mut max_off_diag = T::Real::from_f64(0.0).unwrap();
455                for (j, val) in matrix.row_entries(i) {
456                    if i != j {
457                        let norm = val.norm();
458                        if norm > max_off_diag {
459                            max_off_diag = norm;
460                        }
461                    }
462                }
463
464                let threshold = T::Real::from_f64(theta).unwrap() * max_off_diag;
465
466                // Collect strong connections
467                for (j, val) in matrix.row_entries(i) {
468                    if i != j && val.norm() >= threshold {
469                        row_strong.push(j);
470                    }
471                }
472            }
473            strong
474        }
475    }
476
477    /// Classical Ruge-Stüben coarsening
478    fn coarsen_ruge_stuben(
479        matrix: &CsrMatrix<T>,
480        strong: &[Vec<usize>],
481    ) -> (Vec<PointType>, Vec<usize>) {
482        let n = matrix.num_rows;
483        let mut point_types = vec![PointType::Undecided; n];
484
485        // Compute influence measure λ_i = |S_i^T| (how many points strongly depend on i)
486        let mut lambda: Vec<usize> = vec![0; n];
487        for row in strong.iter().take(n) {
488            for &j in row {
489                lambda[j] += 1;
490            }
491        }
492
493        // Build priority queue (we use a simple approach: process by decreasing lambda)
494        let mut order: Vec<usize> = (0..n).collect();
495        order.sort_by(|&a, &b| lambda[b].cmp(&lambda[a]));
496
497        // First pass: select C-points
498        for &i in &order {
499            if point_types[i] != PointType::Undecided {
500                continue;
501            }
502
503            // Make i a C-point
504            point_types[i] = PointType::Coarse;
505
506            // All points that strongly depend on i become F-points
507            for j in 0..n {
508                if point_types[j] == PointType::Undecided && strong[j].contains(&i) {
509                    point_types[j] = PointType::Fine;
510                    // Update lambda for neighbors
511                    for &k in &strong[j] {
512                        if point_types[k] == PointType::Undecided {
513                            lambda[k] = lambda[k].saturating_sub(1);
514                        }
515                    }
516                }
517            }
518        }
519
520        // Ensure all remaining undecided become fine
521        for pt in &mut point_types {
522            if *pt == PointType::Undecided {
523                *pt = PointType::Fine;
524            }
525        }
526
527        // Build coarse-to-fine mapping
528        let coarse_to_fine: Vec<usize> = (0..n)
529            .filter(|&i| point_types[i] == PointType::Coarse)
530            .collect();
531
532        (point_types, coarse_to_fine)
533    }
534
535    /// Parallel Modified Independent Set (PMIS) coarsening
536    fn coarsen_pmis(matrix: &CsrMatrix<T>, strong: &[Vec<usize>]) -> (Vec<PointType>, Vec<usize>) {
537        let n = matrix.num_rows;
538        let mut point_types = vec![PointType::Undecided; n];
539
540        // Compute weights based on number of strong connections
541        let weights: Vec<f64> = (0..n)
542            .map(|i| {
543                // Weight = |S_i| + random tie-breaker
544                strong[i].len() as f64 + (i as f64 * 0.0001) % 0.001
545            })
546            .collect();
547
548        // Iterative independent set selection
549        let mut changed = true;
550        let mut iteration = 0;
551        const MAX_ITERATIONS: usize = 100;
552
553        while changed && iteration < MAX_ITERATIONS {
554            changed = false;
555            iteration += 1;
556
557            #[cfg(any(feature = "native", feature = "wasm"))]
558            {
559                // Parallel pass: determine new C-points and F-points
560                let updates: Vec<(usize, PointType)> =
561                    parallel_enumerate_map(&point_types, |i, pt| {
562                        if *pt != PointType::Undecided {
563                            return (i, *pt);
564                        }
565
566                        // Check if i has maximum weight among undecided strong neighbors
567                        let mut is_max = true;
568                        for &j in &strong[i] {
569                            if point_types[j] == PointType::Undecided && weights[j] > weights[i] {
570                                is_max = false;
571                                break;
572                            }
573                        }
574
575                        // Check if any strong neighbor is already C
576                        let has_c_neighbor = strong[i]
577                            .iter()
578                            .any(|&j| point_types[j] == PointType::Coarse);
579
580                        if has_c_neighbor {
581                            (i, PointType::Fine)
582                        } else if is_max {
583                            (i, PointType::Coarse)
584                        } else {
585                            (i, PointType::Undecided)
586                        }
587                    });
588
589                for (i, new_type) in updates {
590                    if point_types[i] != new_type {
591                        point_types[i] = new_type;
592                        changed = true;
593                    }
594                }
595            }
596
597            #[cfg(not(any(feature = "native", feature = "wasm")))]
598            {
599                // Sequential fallback
600                let old_types = point_types.clone();
601                for i in 0..n {
602                    if old_types[i] != PointType::Undecided {
603                        continue;
604                    }
605
606                    // Check if i has maximum weight among undecided strong neighbors
607                    let mut is_max = true;
608                    for &j in &strong[i] {
609                        if old_types[j] == PointType::Undecided && weights[j] > weights[i] {
610                            is_max = false;
611                            break;
612                        }
613                    }
614
615                    // Check if any strong neighbor is already C
616                    let has_c_neighbor =
617                        strong[i].iter().any(|&j| old_types[j] == PointType::Coarse);
618
619                    if has_c_neighbor {
620                        point_types[i] = PointType::Fine;
621                        changed = true;
622                    } else if is_max {
623                        point_types[i] = PointType::Coarse;
624                        changed = true;
625                    }
626                }
627            }
628        }
629
630        // Any remaining undecided become coarse
631        for pt in &mut point_types {
632            if *pt == PointType::Undecided {
633                *pt = PointType::Coarse;
634            }
635        }
636
637        // Build coarse-to-fine mapping
638        let coarse_to_fine: Vec<usize> = (0..n)
639            .filter(|&i| point_types[i] == PointType::Coarse)
640            .collect();
641
642        (point_types, coarse_to_fine)
643    }
644
645    /// Build interpolation operator P
646    fn build_interpolation(
647        matrix: &CsrMatrix<T>,
648        strong: &[Vec<usize>],
649        point_types: &[PointType],
650        coarse_to_fine: &[usize],
651        config: &AmgConfig,
652    ) -> CsrMatrix<T> {
653        let n_fine = matrix.num_rows;
654        let n_coarse = coarse_to_fine.len();
655
656        // Build fine-to-coarse mapping
657        let mut fine_to_coarse = vec![usize::MAX; n_fine];
658        for (coarse_idx, &fine_idx) in coarse_to_fine.iter().enumerate() {
659            fine_to_coarse[fine_idx] = coarse_idx;
660        }
661
662        // Build P row by row
663        let mut triplets: Vec<(usize, usize, T)> = Vec::new();
664
665        for i in 0..n_fine {
666            match point_types[i] {
667                PointType::Coarse => {
668                    // C-point: identity mapping P_ij = 1 if j = coarse_index(i)
669                    let coarse_idx = fine_to_coarse[i];
670                    triplets.push((i, coarse_idx, T::one()));
671                }
672                PointType::Fine => {
673                    // F-point: interpolate from strong C-neighbors
674                    let a_ii = matrix.get(i, i);
675
676                    // Collect strong C-neighbors
677                    let c_neighbors: Vec<usize> = strong[i]
678                        .iter()
679                        .copied()
680                        .filter(|&j| point_types[j] == PointType::Coarse)
681                        .collect();
682
683                    if c_neighbors.is_empty() {
684                        continue;
685                    }
686
687                    // Standard interpolation weights
688                    match config.interpolation {
689                        AmgInterpolation::Direct | AmgInterpolation::Standard => {
690                            let mut weights: Vec<(usize, T)> = Vec::new();
691                            let mut sum_weights = T::zero();
692
693                            for &j in &c_neighbors {
694                                let a_ij = matrix.get(i, j);
695                                let tol = T::Real::from_f64(1e-15).unwrap();
696                                if a_ii.norm() > tol {
697                                    let w = T::zero() - a_ij * a_ii.inv();
698                                    weights.push((fine_to_coarse[j], w));
699                                    sum_weights += w;
700                                }
701                            }
702
703                            // Add weak connections contribution (standard interpolation)
704                            if config.interpolation == AmgInterpolation::Standard {
705                                let mut weak_sum = T::zero();
706                                for (j, val) in matrix.row_entries(i) {
707                                    if j != i && !c_neighbors.contains(&j) {
708                                        weak_sum += val;
709                                    }
710                                }
711
712                                let tol = T::Real::from_f64(1e-15).unwrap();
713                                if sum_weights.norm() > tol && weak_sum.norm() > tol {
714                                    let scale = T::one() + weak_sum * (a_ii * sum_weights).inv();
715                                    for (_, w) in &mut weights {
716                                        *w *= scale;
717                                    }
718                                }
719                            }
720
721                            // Truncate small weights if configured
722                            if config.trunc_factor > 0.0 {
723                                let max_w = weights.iter().map(|(_, w)| w.norm()).fold(
724                                    T::Real::from_f64(0.0).unwrap(),
725                                    |a, b| {
726                                        if a > b { a } else { b }
727                                    },
728                                );
729                                let threshold =
730                                    T::Real::from_f64(config.trunc_factor).unwrap() * max_w;
731                                weights.retain(|(_, w)| w.norm() >= threshold);
732
733                                if weights.len() > config.max_interp_elements {
734                                    weights.sort_by(|a, b| {
735                                        b.1.norm().partial_cmp(&a.1.norm()).unwrap()
736                                    });
737                                    weights.truncate(config.max_interp_elements);
738                                }
739                            }
740
741                            for (coarse_idx, w) in weights {
742                                triplets.push((i, coarse_idx, w));
743                            }
744                        }
745                        AmgInterpolation::Extended => {
746                            let mut weights: Vec<(usize, T)> = Vec::new();
747
748                            // Direct C-neighbors
749                            for &j in &c_neighbors {
750                                let a_ij = matrix.get(i, j);
751                                let tol = T::Real::from_f64(1e-15).unwrap();
752                                if a_ii.norm() > tol {
753                                    let w = T::zero() - a_ij * a_ii.inv();
754                                    weights.push((fine_to_coarse[j], w));
755                                }
756                            }
757
758                            // F-neighbors contribute through their C-neighbors
759                            let f_neighbors: Vec<usize> = strong[i]
760                                .iter()
761                                .copied()
762                                .filter(|&j| point_types[j] == PointType::Fine)
763                                .collect();
764
765                            for &k in &f_neighbors {
766                                let a_ik = matrix.get(i, k);
767                                let a_kk = matrix.get(k, k);
768
769                                let tol = T::Real::from_f64(1e-15).unwrap();
770                                if a_kk.norm() < tol {
771                                    continue;
772                                }
773
774                                for &j in &strong[k] {
775                                    if point_types[j] == PointType::Coarse {
776                                        let a_kj = matrix.get(k, j);
777                                        let w = T::zero() - a_ik * a_kj * (a_ii * a_kk).inv();
778
779                                        let coarse_j = fine_to_coarse[j];
780                                        if let Some((_, existing)) =
781                                            weights.iter_mut().find(|(idx, _)| *idx == coarse_j)
782                                        {
783                                            *existing += w;
784                                        } else {
785                                            weights.push((coarse_j, w));
786                                        }
787                                    }
788                                }
789                            }
790
791                            if weights.len() > config.max_interp_elements {
792                                weights
793                                    .sort_by(|a, b| b.1.norm().partial_cmp(&a.1.norm()).unwrap());
794                                weights.truncate(config.max_interp_elements);
795                            }
796
797                            for (coarse_idx, w) in weights {
798                                triplets.push((i, coarse_idx, w));
799                            }
800                        }
801                    }
802                }
803                PointType::Undecided => {}
804            }
805        }
806
807        CsrMatrix::from_triplets(n_fine, n_coarse, triplets)
808    }
809
810    /// Transpose a CSR matrix
811    fn transpose_csr(matrix: &CsrMatrix<T>) -> CsrMatrix<T> {
812        let m = matrix.num_rows;
813        let n = matrix.num_cols;
814
815        let mut triplets: Vec<(usize, usize, T)> = Vec::new();
816        for i in 0..m {
817            for (j, val) in matrix.row_entries(i) {
818                triplets.push((j, i, val));
819            }
820        }
821
822        CsrMatrix::from_triplets(n, m, triplets)
823    }
824
825    /// Compute Galerkin coarse grid operator: A_c = R * A * P
826    fn galerkin_product(r: &CsrMatrix<T>, a: &CsrMatrix<T>, p: &CsrMatrix<T>) -> CsrMatrix<T> {
827        let ap = Self::sparse_matmul(a, p);
828        Self::sparse_matmul(r, &ap)
829    }
830
831    /// Sparse matrix multiplication
832    fn sparse_matmul(a: &CsrMatrix<T>, b: &CsrMatrix<T>) -> CsrMatrix<T> {
833        assert_eq!(a.num_cols, b.num_rows, "Matrix dimension mismatch");
834
835        let m = a.num_rows;
836        let n = b.num_cols;
837        let mut triplets: Vec<(usize, usize, T)> = Vec::new();
838
839        for i in 0..m {
840            let mut row_acc: HashMap<usize, T> = HashMap::new();
841
842            for (k, a_ik) in a.row_entries(i) {
843                for (j, b_kj) in b.row_entries(k) {
844                    *row_acc.entry(j).or_insert(T::zero()) += a_ik * b_kj;
845                }
846            }
847
848            let tol = T::Real::from_f64(1e-15).unwrap();
849            for (j, val) in row_acc {
850                if val.norm() > tol {
851                    triplets.push((i, j, val));
852                }
853            }
854        }
855
856        CsrMatrix::from_triplets(m, n, triplets)
857    }
858
859    /// Compute grid and operator complexities
860    fn compute_complexities(levels: &[AmgLevel<T>]) -> (f64, f64) {
861        if levels.is_empty() {
862            return (1.0, 1.0);
863        }
864
865        let fine_dofs = levels[0].num_dofs as f64;
866        let fine_nnz = levels[0].matrix.nnz() as f64;
867
868        let total_dofs: f64 = levels.iter().map(|l| l.num_dofs as f64).sum();
869        let total_nnz: f64 = levels.iter().map(|l| l.matrix.nnz() as f64).sum();
870
871        let grid_complexity = total_dofs / fine_dofs;
872        let operator_complexity = total_nnz / fine_nnz;
873
874        (grid_complexity, operator_complexity)
875    }
876
877    /// Apply Jacobi smoothing: x = x + ω * D^{-1} * (b - A*x)
878    fn smooth_jacobi(
879        matrix: &CsrMatrix<T>,
880        diag_inv: &Array1<T>,
881        x: &mut Array1<T>,
882        b: &Array1<T>,
883        omega: f64,
884        num_sweeps: usize,
885    ) {
886        let omega = T::from_real(T::Real::from_f64(omega).unwrap());
887        let n = x.len();
888
889        for _ in 0..num_sweeps {
890            let r = b - &matrix.matvec(x);
891
892            #[cfg(any(feature = "native", feature = "wasm"))]
893            {
894                let updates: Vec<T> = parallel_map_indexed(n, |i| omega * diag_inv[i] * r[i]);
895                for (i, delta) in updates.into_iter().enumerate() {
896                    x[i] += delta;
897                }
898            }
899
900            #[cfg(not(any(feature = "native", feature = "wasm")))]
901            {
902                for i in 0..n {
903                    x[i] += omega * diag_inv[i] * r[i];
904                }
905            }
906        }
907    }
908
909    /// Apply l1-Jacobi smoothing
910    fn smooth_l1_jacobi(
911        matrix: &CsrMatrix<T>,
912        x: &mut Array1<T>,
913        b: &Array1<T>,
914        num_sweeps: usize,
915    ) {
916        let n = x.len();
917
918        let l1_diag: Vec<T::Real> = (0..n)
919            .map(|i| {
920                let mut sum = T::Real::from_f64(0.0).unwrap();
921                for (_, val) in matrix.row_entries(i) {
922                    sum += val.norm();
923                }
924                let tol = T::Real::from_f64(1e-15).unwrap();
925                if sum > tol {
926                    sum
927                } else {
928                    T::Real::from_f64(1.0).unwrap()
929                }
930            })
931            .collect();
932
933        for _ in 0..num_sweeps {
934            let r = b - &matrix.matvec(x);
935
936            #[cfg(any(feature = "native", feature = "wasm"))]
937            {
938                let updates: Vec<T> =
939                    parallel_map_indexed(n, |i| r[i] * T::from_real(l1_diag[i]).inv());
940                for (i, delta) in updates.into_iter().enumerate() {
941                    x[i] += delta;
942                }
943            }
944
945            #[cfg(not(any(feature = "native", feature = "wasm")))]
946            {
947                for i in 0..n {
948                    x[i] += r[i] * T::from_real(l1_diag[i]).inv();
949                }
950            }
951        }
952    }
953
954    /// Apply symmetric Gauss-Seidel smoothing
955    fn smooth_sym_gauss_seidel(
956        matrix: &CsrMatrix<T>,
957        x: &mut Array1<T>,
958        b: &Array1<T>,
959        num_sweeps: usize,
960    ) {
961        let n = x.len();
962        let tol = T::Real::from_f64(1e-15).unwrap();
963
964        for _ in 0..num_sweeps {
965            // Forward sweep
966            for i in 0..n {
967                let mut sum = b[i];
968                let mut diag = T::one();
969
970                for (j, val) in matrix.row_entries(i) {
971                    if j == i {
972                        diag = val;
973                    } else {
974                        sum -= val * x[j];
975                    }
976                }
977
978                if diag.norm() > tol {
979                    x[i] = sum * diag.inv();
980                }
981            }
982
983            // Backward sweep
984            for i in (0..n).rev() {
985                let mut sum = b[i];
986                let mut diag = T::one();
987
988                for (j, val) in matrix.row_entries(i) {
989                    if j == i {
990                        diag = val;
991                    } else {
992                        sum -= val * x[j];
993                    }
994                }
995
996                if diag.norm() > tol {
997                    x[i] = sum * diag.inv();
998                }
999            }
1000        }
1001    }
1002
1003    /// Apply V-cycle
1004    fn v_cycle(&self, level: usize, x: &mut Array1<T>, b: &Array1<T>) {
1005        let lvl = &self.levels[level];
1006
1007        // Coarsest level: direct solve (or many smoothing iterations)
1008        if level == self.levels.len() - 1 || lvl.prolongation.is_none() {
1009            match self.config.smoother {
1010                AmgSmoother::Jacobi | AmgSmoother::Chebyshev => {
1011                    Self::smooth_jacobi(
1012                        &lvl.matrix,
1013                        &lvl.diag_inv,
1014                        x,
1015                        b,
1016                        self.config.jacobi_weight,
1017                        20,
1018                    );
1019                }
1020                AmgSmoother::L1Jacobi => {
1021                    Self::smooth_l1_jacobi(&lvl.matrix, x, b, 20);
1022                }
1023                AmgSmoother::SymmetricGaussSeidel => {
1024                    Self::smooth_sym_gauss_seidel(&lvl.matrix, x, b, 10);
1025                }
1026            }
1027            return;
1028        }
1029
1030        // Pre-smoothing
1031        match self.config.smoother {
1032            AmgSmoother::Jacobi | AmgSmoother::Chebyshev => {
1033                Self::smooth_jacobi(
1034                    &lvl.matrix,
1035                    &lvl.diag_inv,
1036                    x,
1037                    b,
1038                    self.config.jacobi_weight,
1039                    self.config.num_pre_smooth,
1040                );
1041            }
1042            AmgSmoother::L1Jacobi => {
1043                Self::smooth_l1_jacobi(&lvl.matrix, x, b, self.config.num_pre_smooth);
1044            }
1045            AmgSmoother::SymmetricGaussSeidel => {
1046                Self::smooth_sym_gauss_seidel(&lvl.matrix, x, b, self.config.num_pre_smooth);
1047            }
1048        }
1049
1050        // Compute residual: r = b - A*x
1051        let r = b - &lvl.matrix.matvec(x);
1052
1053        // Restrict residual to coarse grid: r_c = R * r
1054        let r_coarse = lvl.restriction.as_ref().unwrap().matvec(&r);
1055
1056        // Initialize coarse correction
1057        let n_coarse = self.levels[level + 1].num_dofs;
1058        let mut e_coarse = Array1::from_elem(n_coarse, T::zero());
1059
1060        // Recursive call
1061        self.v_cycle(level + 1, &mut e_coarse, &r_coarse);
1062
1063        // Prolongate correction: e = P * e_c
1064        let e = lvl.prolongation.as_ref().unwrap().matvec(&e_coarse);
1065
1066        // Apply correction: x = x + e
1067        *x = x.clone() + e;
1068
1069        // Post-smoothing
1070        match self.config.smoother {
1071            AmgSmoother::Jacobi | AmgSmoother::Chebyshev => {
1072                Self::smooth_jacobi(
1073                    &lvl.matrix,
1074                    &lvl.diag_inv,
1075                    x,
1076                    b,
1077                    self.config.jacobi_weight,
1078                    self.config.num_post_smooth,
1079                );
1080            }
1081            AmgSmoother::L1Jacobi => {
1082                Self::smooth_l1_jacobi(&lvl.matrix, x, b, self.config.num_post_smooth);
1083            }
1084            AmgSmoother::SymmetricGaussSeidel => {
1085                Self::smooth_sym_gauss_seidel(&lvl.matrix, x, b, self.config.num_post_smooth);
1086            }
1087        }
1088    }
1089}
1090
1091impl<T: ComplexField> Preconditioner<T> for AmgPreconditioner<T>
1092where
1093    T::Real: Sync + Send,
1094{
1095    fn apply(&self, r: &Array1<T>) -> Array1<T> {
1096        if self.levels.is_empty() {
1097            return r.clone();
1098        }
1099
1100        let n = self.levels[0].num_dofs;
1101        if r.len() != n {
1102            return r.clone();
1103        }
1104
1105        let mut z = Array1::from_elem(n, T::zero());
1106
1107        match self.config.cycle {
1108            AmgCycle::VCycle => {
1109                self.v_cycle(0, &mut z, r);
1110            }
1111            AmgCycle::WCycle => {
1112                self.v_cycle(0, &mut z, r);
1113                self.v_cycle(0, &mut z, r);
1114            }
1115            AmgCycle::FCycle => {
1116                self.v_cycle(0, &mut z, r);
1117                let residual = r - &self.levels[0].matrix.matvec(&z);
1118                let mut correction = Array1::from_elem(n, T::zero());
1119                self.v_cycle(0, &mut correction, &residual);
1120                z = z + correction;
1121            }
1122        }
1123
1124        z
1125    }
1126}
1127
/// Diagnostic information about an AMG hierarchy, as returned by
/// `AmgPreconditioner::diagnostics`.
#[derive(Debug, Clone)]
pub struct AmgDiagnostics {
    /// Number of levels in the hierarchy (length of `level_dofs` / `level_nnz`)
    pub num_levels: usize,
    /// Grid complexity stored on the preconditioner
    /// (conventionally total DOFs over all levels ÷ fine-level DOFs)
    pub grid_complexity: f64,
    /// Operator complexity stored on the preconditioner
    /// (conventionally total nonzeros over all levels ÷ fine-level nonzeros)
    pub operator_complexity: f64,
    /// Setup time in milliseconds
    pub setup_time_ms: f64,
    /// Number of unknowns on each level, finest level first
    pub level_dofs: Vec<usize>,
    /// Number of stored nonzeros of each level's operator, finest level first
    pub level_nnz: Vec<usize>,
}
1144
1145impl<T: ComplexField> AmgPreconditioner<T> {
1146    /// Get diagnostic information
1147    pub fn diagnostics(&self) -> AmgDiagnostics {
1148        AmgDiagnostics {
1149            num_levels: self.levels.len(),
1150            grid_complexity: self.grid_complexity,
1151            operator_complexity: self.operator_complexity,
1152            setup_time_ms: self.setup_time_ms,
1153            level_dofs: self.levels.iter().map(|l| l.num_dofs).collect(),
1154            level_nnz: self.levels.iter().map(|l| l.matrix.nnz()).collect(),
1155        }
1156    }
1157}
1158
#[cfg(test)]
mod tests {
    use super::*;
    use num_complex::Complex64;

    /// Create a simple 1D Laplacian matrix for testing
    fn create_1d_laplacian(n: usize) -> CsrMatrix<Complex64> {
        let mut triplets: Vec<(usize, usize, Complex64)> = Vec::new();

        for i in 0..n {
            triplets.push((i, i, Complex64::new(2.0, 0.0)));
            if i > 0 {
                triplets.push((i, i - 1, Complex64::new(-1.0, 0.0)));
            }
            if i < n - 1 {
                triplets.push((i, i + 1, Complex64::new(-1.0, 0.0)));
            }
        }

        CsrMatrix::from_triplets(n, n, triplets)
    }

    /// Dense copy of a small CSR matrix (duplicate entries are summed).
    fn to_dense(m: &CsrMatrix<Complex64>) -> Vec<Vec<Complex64>> {
        let mut dense = vec![vec![Complex64::new(0.0, 0.0); m.num_cols]; m.num_rows];
        for (i, row) in dense.iter_mut().enumerate() {
            for (j, v) in m.row_entries(i) {
                row[j] += v;
            }
        }
        dense
    }

    /// Euclidean norm of a complex vector.
    fn norm2(v: &Array1<Complex64>) -> f64 {
        v.iter().map(|c| c.norm_sqr()).sum::<f64>().sqrt()
    }

    #[test]
    fn test_amg_creation() {
        let matrix = create_1d_laplacian(100);
        let config = AmgConfig::default();

        let amg = AmgPreconditioner::from_csr(&matrix, config);

        assert!(amg.num_levels() >= 2);
        assert!(amg.grid_complexity() >= 1.0);
        assert!(amg.operator_complexity() >= 1.0);
    }

    #[test]
    fn test_amg_apply() {
        let matrix = create_1d_laplacian(50);
        let config = AmgConfig::default();
        let amg = AmgPreconditioner::from_csr(&matrix, config);

        let r = Array1::from_vec((0..50).map(|i| Complex64::new(i as f64, 0.0)).collect());

        let z = amg.apply(&r);

        assert_eq!(z.len(), r.len());

        let diff: f64 = (&z - &r).iter().map(|x| x.norm()).sum();
        assert!(diff > 1e-10, "Preconditioner should modify the vector");
    }

    #[test]
    fn test_amg_pmis_coarsening() {
        let matrix = create_1d_laplacian(100);
        let config = AmgConfig {
            coarsening: AmgCoarsening::Pmis,
            ..Default::default()
        };

        let amg = AmgPreconditioner::from_csr(&matrix, config);
        assert!(amg.num_levels() >= 2);
    }

    #[test]
    fn test_amg_different_smoothers() {
        let matrix = create_1d_laplacian(50);
        let r = Array1::from_vec((0..50).map(|i| Complex64::new(i as f64, 0.0)).collect());

        for smoother in [
            AmgSmoother::Jacobi,
            AmgSmoother::L1Jacobi,
            AmgSmoother::SymmetricGaussSeidel,
            AmgSmoother::Chebyshev,
        ] {
            let config = AmgConfig {
                smoother,
                ..Default::default()
            };
            let amg = AmgPreconditioner::from_csr(&matrix, config);

            let z = amg.apply(&r);
            assert_eq!(z.len(), r.len());
        }
    }

    #[test]
    fn test_amg_reduces_residual() {
        let n = 64;
        let matrix = create_1d_laplacian(n);
        let config = AmgConfig::default();
        let amg = AmgPreconditioner::from_csr(&matrix, config);

        let b = Array1::from_vec(
            (0..n)
                .map(|i| Complex64::new((i as f64).sin(), 0.0))
                .collect(),
        );

        let mut x = Array1::from_elem(n, Complex64::new(0.0, 0.0));

        let r0 = &b - &matrix.matvec(&x);
        let norm_r0 = norm2(&r0);

        for _ in 0..10 {
            let r = &b - &matrix.matvec(&x);
            let z = amg.apply(&r);
            x = x + z;
        }

        let rf = &b - &matrix.matvec(&x);
        let norm_rf = norm2(&rf);

        assert!(
            norm_rf < norm_r0 * 0.1,
            "AMG should significantly reduce residual: {} -> {}",
            norm_r0,
            norm_rf
        );
    }

    #[test]
    fn test_amg_all_cycles_reduce_residual() {
        let n = 64;
        let matrix = create_1d_laplacian(n);
        let b = Array1::from_vec(
            (0..n)
                .map(|i| Complex64::new((i as f64).sin(), 0.0))
                .collect(),
        );
        let norm_r0 = norm2(&b);

        for cycle in [AmgCycle::VCycle, AmgCycle::WCycle, AmgCycle::FCycle] {
            let config = AmgConfig {
                cycle,
                ..Default::default()
            };
            let amg = AmgPreconditioner::from_csr(&matrix, config);

            let mut x = Array1::from_elem(n, Complex64::new(0.0, 0.0));
            for _ in 0..10 {
                let r = &b - &matrix.matvec(&x);
                x = x + amg.apply(&r);
            }

            let norm_rf = norm2(&(&b - &matrix.matvec(&x)));
            assert!(
                norm_rf.is_finite() && norm_rf < norm_r0 * 0.1,
                "cycle should reduce residual: {} -> {}",
                norm_r0,
                norm_rf
            );
        }
    }

    #[test]
    fn test_transpose_csr_swaps_entries() {
        let a = CsrMatrix::from_triplets(
            2,
            3,
            vec![
                (0, 0, Complex64::new(1.0, 0.0)),
                (0, 2, Complex64::new(2.0, -1.0)),
                (1, 1, Complex64::new(3.0, 0.5)),
            ],
        );

        let at = AmgPreconditioner::<Complex64>::transpose_csr(&a);
        assert_eq!(at.num_rows, 3);
        assert_eq!(at.num_cols, 2);

        let (da, dat) = (to_dense(&a), to_dense(&at));
        for (i, row) in da.iter().enumerate() {
            for (j, value) in row.iter().enumerate() {
                assert_eq!(dat[j][i], *value);
            }
        }
    }

    #[test]
    fn test_sparse_matmul_matches_dense_product() {
        let a = CsrMatrix::from_triplets(
            2,
            3,
            vec![
                (0, 0, Complex64::new(1.0, 0.0)),
                (0, 1, Complex64::new(2.0, 0.0)),
                (1, 2, Complex64::new(-1.0, 1.0)),
            ],
        );
        let b = CsrMatrix::from_triplets(
            3,
            2,
            vec![
                (0, 1, Complex64::new(4.0, 0.0)),
                (1, 0, Complex64::new(1.0, -1.0)),
                (2, 0, Complex64::new(2.0, 0.0)),
                (2, 1, Complex64::new(3.0, 0.0)),
            ],
        );

        let c = AmgPreconditioner::<Complex64>::sparse_matmul(&a, &b);
        assert_eq!(c.num_rows, 2);
        assert_eq!(c.num_cols, 2);

        let (da, db, dc) = (to_dense(&a), to_dense(&b), to_dense(&c));
        for i in 0..2 {
            for j in 0..2 {
                let expected: Complex64 = (0..3).map(|k| da[i][k] * db[k][j]).sum();
                assert!(
                    (dc[i][j] - expected).norm() < 1e-12,
                    "C[{}][{}] mismatch",
                    i,
                    j
                );
            }
        }
    }

    #[test]
    fn test_diagnostics() {
        let matrix = create_1d_laplacian(100);
        let amg = AmgPreconditioner::from_csr(&matrix, AmgConfig::default());

        let diag = amg.diagnostics();

        assert!(diag.num_levels >= 2);
        assert_eq!(diag.level_dofs.len(), diag.num_levels);
        assert_eq!(diag.level_nnz.len(), diag.num_levels);
        assert!(diag.grid_complexity >= 1.0);
        assert!(diag.setup_time_ms >= 0.0);
    }
}