1use ndarray::Array1;
9use num_complex::{Complex32, Complex64};
10use num_traits::{Float, FromPrimitive, NumAssign, One, ToPrimitive, Zero};
11use std::fmt::Debug;
12use std::ops::Neg;
13
14pub trait ComplexField:
28 NumAssign + Clone + Copy + Send + Sync + Debug + Zero + One + Neg<Output = Self> + 'static
29{
30 type Real: Float + NumAssign + FromPrimitive + ToPrimitive + Send + Sync + Debug + 'static;
32
33 fn conj(&self) -> Self;
35
36 fn norm_sqr(&self) -> Self::Real;
38
39 fn norm(&self) -> Self::Real {
41 self.norm_sqr().sqrt()
42 }
43
44 fn from_real(r: Self::Real) -> Self;
46
47 fn from_re_im(re: Self::Real, im: Self::Real) -> Self;
49
50 fn re(&self) -> Self::Real;
52
53 fn im(&self) -> Self::Real;
55
56 fn is_zero_approx(&self, tol: Self::Real) -> bool {
58 self.norm_sqr() < tol * tol
59 }
60
61 fn inv(&self) -> Self;
63
64 fn sqrt(&self) -> Self;
66}
67
68impl ComplexField for Complex64 {
69 type Real = f64;
70
71 #[inline]
72 fn conj(&self) -> Self {
73 Complex64::conj(self)
74 }
75
76 #[inline]
77 fn norm_sqr(&self) -> f64 {
78 self.re * self.re + self.im * self.im
79 }
80
81 #[inline]
82 fn from_real(r: f64) -> Self {
83 Complex64::new(r, 0.0)
84 }
85
86 #[inline]
87 fn from_re_im(re: f64, im: f64) -> Self {
88 Complex64::new(re, im)
89 }
90
91 #[inline]
92 fn re(&self) -> f64 {
93 self.re
94 }
95
96 #[inline]
97 fn im(&self) -> f64 {
98 self.im
99 }
100
101 #[inline]
102 fn inv(&self) -> Self {
103 let denom = self.norm_sqr();
104 Complex64::new(self.re / denom, -self.im / denom)
105 }
106
107 #[inline]
108 fn sqrt(&self) -> Self {
109 Complex64::sqrt(*self)
110 }
111}
112
113impl ComplexField for Complex32 {
114 type Real = f32;
115
116 #[inline]
117 fn conj(&self) -> Self {
118 Complex32::conj(self)
119 }
120
121 #[inline]
122 fn norm_sqr(&self) -> f32 {
123 self.re * self.re + self.im * self.im
124 }
125
126 #[inline]
127 fn from_real(r: f32) -> Self {
128 Complex32::new(r, 0.0)
129 }
130
131 #[inline]
132 fn from_re_im(re: f32, im: f32) -> Self {
133 Complex32::new(re, im)
134 }
135
136 #[inline]
137 fn re(&self) -> f32 {
138 self.re
139 }
140
141 #[inline]
142 fn im(&self) -> f32 {
143 self.im
144 }
145
146 #[inline]
147 fn inv(&self) -> Self {
148 let denom = self.norm_sqr();
149 Complex32::new(self.re / denom, -self.im / denom)
150 }
151
152 #[inline]
153 fn sqrt(&self) -> Self {
154 Complex32::sqrt(*self)
155 }
156}
157
158impl ComplexField for f64 {
159 type Real = f64;
160
161 #[inline]
162 fn conj(&self) -> Self {
163 *self
164 }
165
166 #[inline]
167 fn norm_sqr(&self) -> f64 {
168 *self * *self
169 }
170
171 #[inline]
172 fn from_real(r: f64) -> Self {
173 r
174 }
175
176 #[inline]
177 fn from_re_im(re: f64, _im: f64) -> Self {
178 re
179 }
180
181 #[inline]
182 fn re(&self) -> f64 {
183 *self
184 }
185
186 #[inline]
187 fn im(&self) -> f64 {
188 0.0
189 }
190
191 #[inline]
192 fn inv(&self) -> Self {
193 1.0 / *self
194 }
195
196 #[inline]
197 fn sqrt(&self) -> Self {
198 f64::sqrt(*self)
199 }
200}
201
202impl ComplexField for f32 {
203 type Real = f32;
204
205 #[inline]
206 fn conj(&self) -> Self {
207 *self
208 }
209
210 #[inline]
211 fn norm_sqr(&self) -> f32 {
212 *self * *self
213 }
214
215 #[inline]
216 fn from_real(r: f32) -> Self {
217 r
218 }
219
220 #[inline]
221 fn from_re_im(re: f32, _im: f32) -> Self {
222 re
223 }
224
225 #[inline]
226 fn re(&self) -> f32 {
227 *self
228 }
229
230 #[inline]
231 fn im(&self) -> f32 {
232 0.0
233 }
234
235 #[inline]
236 fn inv(&self) -> Self {
237 1.0 / *self
238 }
239
240 #[inline]
241 fn sqrt(&self) -> Self {
242 f32::sqrt(*self)
243 }
244}
245
246pub trait LinearOperator<T: ComplexField>: Send + Sync {
251 fn num_rows(&self) -> usize;
253
254 fn num_cols(&self) -> usize;
256
257 fn apply(&self, x: &Array1<T>) -> Array1<T>;
259
260 fn apply_transpose(&self, x: &Array1<T>) -> Array1<T>;
262
263 fn apply_hermitian(&self, x: &Array1<T>) -> Array1<T> {
265 let x_conj: Array1<T> = x.mapv(|v| v.conj());
267 self.apply_transpose(&x_conj).mapv(|v| v.conj())
268 }
269
270 fn is_square(&self) -> bool {
272 self.num_rows() == self.num_cols()
273 }
274}
275
276pub trait Preconditioner<T: ComplexField>: Send + Sync {
281 fn apply(&self, r: &Array1<T>) -> Array1<T>;
285}
286
287#[derive(Clone, Debug, Default)]
289pub struct IdentityPreconditioner;
290
291impl<T: ComplexField> Preconditioner<T> for IdentityPreconditioner {
292 fn apply(&self, r: &Array1<T>) -> Array1<T> {
293 r.clone()
294 }
295}
296
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    // `Complex64` and `Complex32` have inherent `conj`, `norm`, `norm_sqr`,
    // `inv` and `sqrt` methods. In method-call syntax, inherent methods take
    // precedence over trait methods, so `z.inv()` would bypass the
    // `ComplexField` impl entirely. The trait methods are therefore called
    // with path syntax (`ComplexField::inv(&z)`).

    #[test]
    fn test_complex64_field() {
        let z = Complex64::new(3.0, 4.0);
        assert_relative_eq!(ComplexField::norm_sqr(&z), 25.0);
        assert_relative_eq!(ComplexField::norm(&z), 5.0);
        assert_relative_eq!(ComplexField::re(&z), 3.0);
        assert_relative_eq!(ComplexField::im(&z), 4.0);

        let z_conj = ComplexField::conj(&z);
        assert_relative_eq!(z_conj.re, 3.0);
        assert_relative_eq!(z_conj.im, -4.0);

        let z_inv = ComplexField::inv(&z);
        let product = z * z_inv;
        assert_relative_eq!(product.re, 1.0, epsilon = 1e-10);
        assert_relative_eq!(product.im, 0.0, epsilon = 1e-10);

        // sqrt(3 + 4i) = 2 + i
        let root = ComplexField::sqrt(&z);
        assert_relative_eq!(root.re, 2.0, epsilon = 1e-10);
        assert_relative_eq!(root.im, 1.0, epsilon = 1e-10);

        assert_eq!(
            <Complex64 as ComplexField>::from_real(2.5),
            Complex64::new(2.5, 0.0)
        );
        assert_eq!(
            <Complex64 as ComplexField>::from_re_im(1.0, -2.0),
            Complex64::new(1.0, -2.0)
        );

        assert!(ComplexField::is_zero_approx(&Complex64::new(1e-8, -1e-8), 1e-6));
        assert!(!ComplexField::is_zero_approx(&z, 1e-6));
    }

    #[test]
    fn test_complex32_field() {
        let z = Complex32::new(3.0, 4.0);
        assert_relative_eq!(ComplexField::norm_sqr(&z), 25.0);
        assert_relative_eq!(ComplexField::norm(&z), 5.0);
        assert_eq!(ComplexField::conj(&z), Complex32::new(3.0, -4.0));

        let product = z * ComplexField::inv(&z);
        assert_relative_eq!(product.re, 1.0, epsilon = 1e-5);
        assert_relative_eq!(product.im, 0.0, epsilon = 1e-5);

        assert_eq!(
            <Complex32 as ComplexField>::from_re_im(1.0, -2.0),
            Complex32::new(1.0, -2.0)
        );
    }

    #[test]
    fn test_f64_field() {
        let x: f64 = 3.0;
        assert_relative_eq!(x.norm_sqr(), 9.0);
        assert_relative_eq!(x.norm(), 3.0);
        assert_relative_eq!(ComplexField::norm(&-3.0_f64), 3.0);
        assert_relative_eq!(x.conj(), 3.0);
        assert_relative_eq!(x.inv(), 1.0 / 3.0);
        assert_relative_eq!(x.re(), 3.0);
        assert_relative_eq!(x.im(), 0.0);
        assert_relative_eq!(ComplexField::sqrt(&4.0_f64), 2.0);
        // Real types drop the imaginary component.
        assert_relative_eq!(<f64 as ComplexField>::from_re_im(1.5, 7.0), 1.5);
    }

    #[test]
    fn test_f32_field() {
        let x: f32 = 3.0;
        assert_relative_eq!(x.norm_sqr(), 9.0);
        assert_relative_eq!(ComplexField::norm(&-3.0_f32), 3.0);
        assert_relative_eq!(x.conj(), 3.0);
        assert_relative_eq!(x.inv(), 1.0 / 3.0);
        assert_relative_eq!(x.im(), 0.0);
    }

    /// Small dense row-major operator used to exercise the default methods of
    /// `LinearOperator`.
    struct DenseOperator {
        rows: Vec<Vec<Complex64>>,
    }

    impl LinearOperator<Complex64> for DenseOperator {
        fn num_rows(&self) -> usize {
            self.rows.len()
        }

        fn num_cols(&self) -> usize {
            self.rows.first().map_or(0, Vec::len)
        }

        fn apply(&self, x: &Array1<Complex64>) -> Array1<Complex64> {
            let out: Vec<Complex64> = self
                .rows
                .iter()
                .map(|row| {
                    row.iter()
                        .zip(x.iter())
                        .fold(Complex64::new(0.0, 0.0), |acc, (a, b)| acc + *a * *b)
                })
                .collect();
            Array1::from_vec(out)
        }

        fn apply_transpose(&self, x: &Array1<Complex64>) -> Array1<Complex64> {
            let out: Vec<Complex64> = (0..self.num_cols())
                .map(|j| {
                    self.rows
                        .iter()
                        .zip(x.iter())
                        .fold(Complex64::new(0.0, 0.0), |acc, (row, b)| acc + row[j] * *b)
                })
                .collect();
            Array1::from_vec(out)
        }
    }

    #[test]
    fn test_apply_hermitian_default() {
        let c = |re: f64, im: f64| Complex64::new(re, im);
        let op = DenseOperator {
            rows: vec![
                vec![c(1.0, 1.0), c(2.0, 0.0), c(0.0, 0.0)],
                vec![c(0.0, 0.0), c(0.0, -1.0), c(3.0, 2.0)],
            ],
        };
        assert_eq!(op.num_rows(), 2);
        assert_eq!(op.num_cols(), 3);
        assert!(!op.is_square());

        let x = Array1::from_vec(vec![c(1.0, 0.0), c(0.0, 1.0)]);
        // Aᴴ x = conj(A)ᵀ x, worked by hand.
        let expected = Array1::from_vec(vec![c(1.0, -1.0), c(1.0, 0.0), c(2.0, 3.0)]);
        assert_eq!(op.apply_hermitian(&x), expected);
    }

    #[test]
    fn test_identity_preconditioner() {
        let precond = IdentityPreconditioner;
        let r = Array1::from_vec(vec![Complex64::new(1.0, 2.0), Complex64::new(3.0, 4.0)]);
        let y = precond.apply(&r);
        assert_eq!(r, y);
    }
}