solvers/preconditioners/
diagonal.rs1use crate::sparse::CsrMatrix;
9use crate::traits::{ComplexField, Preconditioner};
10use ndarray::Array1;
11use num_traits::FromPrimitive;
12
13#[cfg(feature = "rayon")]
14use rayon::prelude::*;
15
16#[derive(Debug, Clone)]
20pub struct DiagonalPreconditioner<T: ComplexField> {
21 inv_diag: Array1<T>,
23}
24
25impl<T: ComplexField> DiagonalPreconditioner<T> {
26 pub fn from_csr(matrix: &CsrMatrix<T>) -> Self {
28 let diag = matrix.diagonal();
29 let inv_diag = diag.mapv(|d| {
30 if d.norm() > T::Real::from_f64(1e-30).unwrap() {
31 d.inv()
32 } else {
33 T::one()
34 }
35 });
36 Self { inv_diag }
37 }
38
39 pub fn from_diagonal(diag: &Array1<T>) -> Self {
41 let inv_diag = diag.mapv(|d| {
42 if d.norm() > T::Real::from_f64(1e-30).unwrap() {
43 d.inv()
44 } else {
45 T::one()
46 }
47 });
48 Self { inv_diag }
49 }
50
51 pub fn from_inverse_diagonal(inv_diag: Array1<T>) -> Self {
53 Self { inv_diag }
54 }
55}
56
57impl<T: ComplexField> Preconditioner<T> for DiagonalPreconditioner<T> {
58 fn apply(&self, r: &Array1<T>) -> Array1<T> {
59 #[cfg(feature = "rayon")]
60 {
61 if r.len() >= 1000 {
62 return self.apply_parallel(r);
63 }
64 }
65 self.apply_sequential(r)
66 }
67}
68
69impl<T: ComplexField> DiagonalPreconditioner<T> {
70 fn apply_sequential(&self, r: &Array1<T>) -> Array1<T> {
71 r.iter()
72 .zip(self.inv_diag.iter())
73 .map(|(&ri, &di)| ri * di)
74 .collect()
75 }
76
77 #[cfg(feature = "rayon")]
78 fn apply_parallel(&self, r: &Array1<T>) -> Array1<T>
79 where
80 T: Send + Sync,
81 {
82 let r_slice = r.as_slice().expect("Array should be contiguous");
83 let inv_slice = self
84 .inv_diag
85 .as_slice()
86 .expect("Array should be contiguous");
87
88 let results: Vec<T> = r_slice
89 .par_iter()
90 .zip(inv_slice.par_iter())
91 .map(|(&ri, &di)| ri * di)
92 .collect();
93
94 Array1::from_vec(results)
95 }
96}
97
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::array;
    use num_complex::Complex64;

    #[test]
    fn test_diagonal_preconditioner() {
        let diag = array![
            Complex64::new(2.0, 0.0),
            Complex64::new(4.0, 0.0),
            Complex64::new(1.0, 0.0)
        ];

        let precond = DiagonalPreconditioner::from_diagonal(&diag);

        let r = array![
            Complex64::new(2.0, 0.0),
            Complex64::new(8.0, 0.0),
            Complex64::new(3.0, 0.0)
        ];

        let result = precond.apply(&r);

        assert_relative_eq!(result[0].re, 1.0, epsilon = 1e-10);
        assert_relative_eq!(result[1].re, 2.0, epsilon = 1e-10);
        assert_relative_eq!(result[2].re, 3.0, epsilon = 1e-10);
    }

    #[test]
    fn test_diagonal_from_csr() {
        use crate::sparse::CsrMatrix;

        let dense = array![
            [Complex64::new(4.0, 0.0), Complex64::new(1.0, 0.0)],
            [Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)],
        ];

        let matrix = CsrMatrix::from_dense(&dense, 1e-15);
        let precond = DiagonalPreconditioner::from_csr(&matrix);

        let r = array![Complex64::new(4.0, 0.0), Complex64::new(4.0, 0.0)];
        let result = precond.apply(&r);

        assert_relative_eq!(result[0].re, 1.0, epsilon = 1e-10);
        assert_relative_eq!(result[1].re, 2.0, epsilon = 1e-10);
    }

    /// A zero diagonal entry is not inverted; that component passes through
    /// unchanged (multiplied by one).
    #[test]
    fn test_zero_diagonal_falls_back_to_identity() {
        let diag = array![Complex64::new(0.0, 0.0), Complex64::new(2.0, 0.0)];
        let precond = DiagonalPreconditioner::from_diagonal(&diag);

        let r = array![Complex64::new(5.0, 0.0), Complex64::new(4.0, 0.0)];
        let result = precond.apply(&r);

        assert_relative_eq!(result[0].re, 5.0, epsilon = 1e-10);
        assert_relative_eq!(result[1].re, 2.0, epsilon = 1e-10);
    }

    /// A purely imaginary diagonal must produce the complex reciprocal:
    /// 2 / (2i) = -i.
    #[test]
    fn test_complex_diagonal() {
        let diag = array![Complex64::new(0.0, 2.0)];
        let precond = DiagonalPreconditioner::from_diagonal(&diag);

        let r = array![Complex64::new(2.0, 0.0)];
        let result = precond.apply(&r);

        assert_relative_eq!(result[0].re, 0.0, epsilon = 1e-10);
        assert_relative_eq!(result[0].im, -1.0, epsilon = 1e-10);
    }

    /// Large inputs (which take the parallel path when the `rayon` feature is
    /// enabled) must still yield the plain elementwise product, in order.
    #[test]
    fn test_large_input_matches_elementwise_product() {
        let n = 1500;
        let diag = Array1::from_elem(n, Complex64::new(2.0, 0.0));
        let precond = DiagonalPreconditioner::from_diagonal(&diag);

        let r: Array1<Complex64> = (0..n).map(|i| Complex64::new(i as f64, 1.0)).collect();
        let result = precond.apply(&r);

        assert_eq!(result.len(), n);
        for (i, value) in result.iter().enumerate() {
            assert_relative_eq!(value.re, i as f64 / 2.0, epsilon = 1e-10);
            assert_relative_eq!(value.im, 0.5, epsilon = 1e-10);
        }
    }
}