math_audio_solvers/
traits.rs1use ndarray::Array1;
9use num_complex::{Complex32, Complex64};
10use num_traits::{Float, NumAssign, One, Zero, FromPrimitive, ToPrimitive};
11use std::fmt::Debug;
12use std::ops::Neg;
13
/// Scalar field abstraction used by the solvers (real or complex `f32`/`f64`).
///
/// This variant is compiled only when the `ndarray-linalg` feature is enabled.
/// It requires [`ndarray_linalg::Lapack`], and the associated `Real` type used
/// below is inherited from that bound rather than declared here.
///
/// `conj` and `from_real` are intentionally not declared in this variant.
/// Generic code in this file calls `ndarray_linalg::Scalar::conj` instead when
/// the feature is on. `from_real` is presumably covered by the `Scalar` bound
/// as well — confirm.
///
/// NOTE(review): `re`, `im` and `sqrt` may share names with methods on
/// `ndarray_linalg::Scalar`. Generic callers might need fully qualified syntax
/// (e.g. `ComplexField::re(&x)`) to avoid ambiguity — confirm.
#[cfg(feature = "ndarray-linalg")]
pub trait ComplexField:
    NumAssign
    + Clone
    + Copy
    + Send
    + Sync
    + Debug
    + Zero
    + One
    + Neg<Output = Self>
    + ndarray_linalg::Lapack
    + 'static
{
    /// Squared modulus `re² + im²`.
    fn norm_sqr(&self) -> Self::Real;

    /// Modulus `|z|`.
    ///
    /// The default is `sqrt(norm_sqr())`. Implementors may override it with a
    /// formula that does not overflow for large components.
    fn norm(&self) -> Self::Real {
        self.norm_sqr().sqrt()
    }

    /// Builds a value from its real and imaginary parts.
    fn from_re_im(re: Self::Real, im: Self::Real) -> Self;

    /// Real part.
    fn re(&self) -> Self::Real;

    /// Imaginary part.
    fn im(&self) -> Self::Real;

    /// Returns `true` when `|z| < tol`, tested as `norm_sqr() < tol * tol`.
    fn is_zero_approx(&self, tol: Self::Real) -> bool {
        self.norm_sqr() < tol * tol
    }

    /// Multiplicative inverse `1 / z`.
    fn inv(&self) -> Self;

    /// Square root.
    fn sqrt(&self) -> Self;
}
79
80#[cfg(not(feature = "ndarray-linalg"))]
81pub trait ComplexField:
82 NumAssign + Clone + Copy + Send + Sync + Debug + Zero + One + Neg<Output = Self> + 'static
83{
84 type Real: Float + NumAssign + FromPrimitive + ToPrimitive + Send + Sync + Debug + 'static;
86
87 fn conj(&self) -> Self;
89
90 fn norm_sqr(&self) -> Self::Real;
92
93 fn norm(&self) -> Self::Real {
95 self.norm_sqr().sqrt()
96 }
97
98 fn from_real(r: Self::Real) -> Self;
100
101 fn from_re_im(re: Self::Real, im: Self::Real) -> Self;
103
104 fn re(&self) -> Self::Real;
106
107 fn im(&self) -> Self::Real;
109
110 fn is_zero_approx(&self, tol: Self::Real) -> bool {
112 self.norm_sqr() < tol * tol
113 }
114
115 fn inv(&self) -> Self;
117
118 fn sqrt(&self) -> Self;
120}
121
122impl ComplexField for Complex64 {
123 #[cfg(not(feature = "ndarray-linalg"))]
124 type Real = f64;
125
126 #[inline]
127 #[cfg(not(feature = "ndarray-linalg"))]
128 fn conj(&self) -> Self {
129 Complex64::conj(self)
130 }
131
132 #[inline]
133 fn norm_sqr(&self) -> f64 {
134 self.re * self.re + self.im * self.im
135 }
136
137 #[inline]
138 #[cfg(not(feature = "ndarray-linalg"))]
139 fn from_real(r: f64) -> Self {
140 Complex64::new(r, 0.0)
141 }
142
143 #[inline]
144 fn from_re_im(re: f64, im: f64) -> Self {
145 Complex64::new(re, im)
146 }
147
148 #[inline]
149 fn re(&self) -> f64 {
150 self.re
151 }
152
153 #[inline]
154 fn im(&self) -> f64 {
155 self.im
156 }
157
158 #[inline]
159 fn inv(&self) -> Self {
160 let denom = self.norm_sqr();
161 Complex64::new(self.re / denom, -self.im / denom)
162 }
163
164 #[inline]
165 fn sqrt(&self) -> Self {
166 Complex64::sqrt(*self)
167 }
168}
169
170impl ComplexField for Complex32 {
171 #[cfg(not(feature = "ndarray-linalg"))]
172 type Real = f32;
173
174 #[inline]
175 #[cfg(not(feature = "ndarray-linalg"))]
176 fn conj(&self) -> Self {
177 Complex32::conj(self)
178 }
179
180 #[inline]
181 fn norm_sqr(&self) -> f32 {
182 self.re * self.re + self.im * self.im
183 }
184
185 #[inline]
186 #[cfg(not(feature = "ndarray-linalg"))]
187 fn from_real(r: f32) -> Self {
188 Complex32::new(r, 0.0)
189 }
190
191 #[inline]
192 fn from_re_im(re: f32, im: f32) -> Self {
193 Complex32::new(re, im)
194 }
195
196 #[inline]
197 fn re(&self) -> f32 {
198 self.re
199 }
200
201 #[inline]
202 fn im(&self) -> f32 {
203 self.im
204 }
205
206 #[inline]
207 fn inv(&self) -> Self {
208 let denom = self.norm_sqr();
209 Complex32::new(self.re / denom, -self.im / denom)
210 }
211
212 #[inline]
213 fn sqrt(&self) -> Self {
214 Complex32::sqrt(*self)
215 }
216}
217
218impl ComplexField for f64 {
219 #[cfg(not(feature = "ndarray-linalg"))]
220 type Real = f64;
221
222 #[inline]
223 #[cfg(not(feature = "ndarray-linalg"))]
224 fn conj(&self) -> Self {
225 *self
226 }
227
228 #[inline]
229 fn norm_sqr(&self) -> f64 {
230 *self * *self
231 }
232
233 #[inline]
234 #[cfg(not(feature = "ndarray-linalg"))]
235 fn from_real(r: f64) -> Self {
236 r
237 }
238
239 #[inline]
240 fn from_re_im(re: f64, _im: f64) -> Self {
241 re
242 }
243
244 #[inline]
245 fn re(&self) -> f64 {
246 *self
247 }
248
249 #[inline]
250 fn im(&self) -> f64 {
251 0.0
252 }
253
254 #[inline]
255 fn inv(&self) -> Self {
256 1.0 / *self
257 }
258
259 #[inline]
260 fn sqrt(&self) -> Self {
261 f64::sqrt(*self)
262 }
263}
264
265impl ComplexField for f32 {
266 #[cfg(not(feature = "ndarray-linalg"))]
267 type Real = f32;
268
269 #[inline]
270 #[cfg(not(feature = "ndarray-linalg"))]
271 fn conj(&self) -> Self {
272 *self
273 }
274
275 #[inline]
276 fn norm_sqr(&self) -> f32 {
277 *self * *self
278 }
279
280 #[inline]
281 #[cfg(not(feature = "ndarray-linalg"))]
282 fn from_real(r: f32) -> Self {
283 r
284 }
285
286 #[inline]
287 fn from_re_im(re: f32, _im: f32) -> Self {
288 re
289 }
290
291 #[inline]
292 fn re(&self) -> f32 {
293 *self
294 }
295
296 #[inline]
297 fn im(&self) -> f32 {
298 0.0
299 }
300
301 #[inline]
302 fn inv(&self) -> Self {
303 1.0 / *self
304 }
305
306 #[inline]
307 fn sqrt(&self) -> Self {
308 f32::sqrt(*self)
309 }
310}
311
312pub trait LinearOperator<T: ComplexField>: Send + Sync {
317 fn num_rows(&self) -> usize;
319
320 fn num_cols(&self) -> usize;
322
323 fn apply(&self, x: &Array1<T>) -> Array1<T>;
325
326 fn apply_transpose(&self, x: &Array1<T>) -> Array1<T>;
328
329 fn apply_hermitian(&self, x: &Array1<T>) -> Array1<T> {
331 let x_conj: Array1<T> = x.mapv(|v| {
336 #[cfg(feature = "ndarray-linalg")]
337 {
338 ndarray_linalg::Scalar::conj(&v)
339 }
340 #[cfg(not(feature = "ndarray-linalg"))]
341 {
342 v.conj()
343 }
344 });
345
346 let y = self.apply_transpose(&x_conj);
347
348 y.mapv(|v| {
349 #[cfg(feature = "ndarray-linalg")]
350 {
351 ndarray_linalg::Scalar::conj(&v)
352 }
353 #[cfg(not(feature = "ndarray-linalg"))]
354 {
355 v.conj()
356 }
357 })
358 }
359
360 fn is_square(&self) -> bool {
362 self.num_rows() == self.num_cols()
363 }
364}
365
366pub trait Preconditioner<T: ComplexField>: Send + Sync {
371 fn apply(&self, r: &Array1<T>) -> Array1<T>;
375}
376
/// Preconditioner that leaves its input unchanged (`M = I`).
///
/// It is a stateless unit struct, so it derives all the cheap common traits.
/// Callers can copy it, compare it and hash it freely.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct IdentityPreconditioner;
380
381impl<T: ComplexField> Preconditioner<T> for IdentityPreconditioner {
382 fn apply(&self, r: &Array1<T>) -> Array1<T> {
383 r.clone()
384 }
385}
386
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    // `Complex64`/`Complex32` have inherent `norm`, `norm_sqr`, `inv`, `conj`
    // and `sqrt` methods. In `z.method()` syntax these take precedence over
    // trait methods, so the `ComplexField` impls are exercised below with fully
    // qualified calls.

    #[test]
    fn test_complex64_field() {
        let z = Complex64::new(3.0, 4.0);
        assert_relative_eq!(ComplexField::norm_sqr(&z), 25.0);
        assert_relative_eq!(ComplexField::norm(&z), 5.0);
        assert_relative_eq!(ComplexField::re(&z), 3.0);
        assert_relative_eq!(ComplexField::im(&z), 4.0);

        #[cfg(not(feature = "ndarray-linalg"))]
        let z_conj = ComplexField::conj(&z);
        #[cfg(feature = "ndarray-linalg")]
        let z_conj = ndarray_linalg::Scalar::conj(&z);

        assert_relative_eq!(z_conj.re, 3.0);
        assert_relative_eq!(z_conj.im, -4.0);

        let z_inv = ComplexField::inv(&z);
        let product = z * z_inv;
        assert_relative_eq!(product.re, 1.0, epsilon = 1e-10);
        assert_relative_eq!(product.im, 0.0, epsilon = 1e-10);

        let w: Complex64 = ComplexField::from_re_im(1.5, -2.0);
        assert_eq!(w, Complex64::new(1.5, -2.0));

        let root = ComplexField::sqrt(&Complex64::new(-4.0, 0.0));
        assert_relative_eq!(root.re, 0.0, epsilon = 1e-12);
        assert_relative_eq!(root.im, 2.0, epsilon = 1e-12);

        assert!(ComplexField::is_zero_approx(
            &Complex64::new(1e-12, -1e-12),
            1e-10
        ));
        assert!(!ComplexField::is_zero_approx(&z, 1e-10));
    }

    #[test]
    fn test_complex32_field() {
        let z = Complex32::new(3.0, 4.0);
        assert_relative_eq!(ComplexField::norm_sqr(&z), 25.0f32);
        assert_relative_eq!(ComplexField::norm(&z), 5.0f32, epsilon = 1e-6);

        let product = z * ComplexField::inv(&z);
        assert_relative_eq!(product.re, 1.0f32, epsilon = 1e-6);
        assert_relative_eq!(product.im, 0.0f32, epsilon = 1e-6);
    }

    #[test]
    fn test_f64_field() {
        let x: f64 = 3.0;
        assert_relative_eq!(ComplexField::norm_sqr(&x), 9.0);
        assert_relative_eq!(ComplexField::norm(&x), 3.0);
        assert_relative_eq!(ComplexField::norm(&-3.0_f64), 3.0);
        assert_relative_eq!(ComplexField::re(&x), 3.0);
        assert_relative_eq!(ComplexField::im(&x), 0.0);

        #[cfg(not(feature = "ndarray-linalg"))]
        assert_relative_eq!(ComplexField::conj(&x), 3.0);
        #[cfg(feature = "ndarray-linalg")]
        assert_relative_eq!(ndarray_linalg::Scalar::conj(&x), 3.0);

        assert_relative_eq!(ComplexField::inv(&x), 1.0 / 3.0);
        assert_relative_eq!(ComplexField::sqrt(&9.0_f64), 3.0);

        // The imaginary part is dropped for real scalars.
        let r: f64 = ComplexField::from_re_im(1.5, -2.0);
        assert_relative_eq!(r, 1.5);
    }

    #[test]
    fn test_f32_field() {
        let x: f32 = 2.0;
        assert_relative_eq!(ComplexField::norm_sqr(&x), 4.0f32);
        assert_relative_eq!(ComplexField::norm(&-2.0_f32), 2.0f32);
        assert_relative_eq!(ComplexField::inv(&x), 0.5f32);
        assert_relative_eq!(ComplexField::sqrt(&4.0_f32), 2.0f32);
    }

    /// Minimal dense 2x2 operator used to check the default `apply_hermitian`.
    struct Dense2([[Complex64; 2]; 2]);

    impl LinearOperator<Complex64> for Dense2 {
        fn num_rows(&self) -> usize {
            2
        }

        fn num_cols(&self) -> usize {
            2
        }

        fn apply(&self, x: &Array1<Complex64>) -> Array1<Complex64> {
            let a = &self.0;
            Array1::from_vec(vec![
                a[0][0] * x[0] + a[0][1] * x[1],
                a[1][0] * x[0] + a[1][1] * x[1],
            ])
        }

        fn apply_transpose(&self, x: &Array1<Complex64>) -> Array1<Complex64> {
            let a = &self.0;
            Array1::from_vec(vec![
                a[0][0] * x[0] + a[1][0] * x[1],
                a[0][1] * x[0] + a[1][1] * x[1],
            ])
        }
    }

    #[test]
    fn test_default_apply_hermitian() {
        let op = Dense2([
            [Complex64::new(1.0, 1.0), Complex64::new(2.0, 0.0)],
            [Complex64::new(0.0, 3.0), Complex64::new(4.0, -1.0)],
        ]);
        assert!(op.is_square());

        let x = Array1::from_vec(vec![Complex64::new(1.0, 0.0), Complex64::new(0.0, 1.0)]);
        // Aᴴ x computed by hand: [conj(1+i) + conj(3i)·i, conj(2) + conj(4-i)·i]
        let expected = Array1::from_vec(vec![Complex64::new(4.0, -1.0), Complex64::new(1.0, 4.0)]);
        assert_eq!(op.apply_hermitian(&x), expected);
    }

    #[test]
    fn test_identity_preconditioner() {
        let precond = IdentityPreconditioner;
        let r = Array1::from_vec(vec![Complex64::new(1.0, 2.0), Complex64::new(3.0, 4.0)]);
        let y = precond.apply(&r);
        assert_eq!(r, y);
    }
}