math_audio_solvers/
traits.rs1use ndarray::Array1;
9use num_complex::{Complex32, Complex64};
10use num_traits::{Float, NumAssign, One, Zero};
11use num_traits::{FromPrimitive, ToPrimitive};
12use std::fmt::Debug;
13use std::ops::Neg;
14
15pub trait ComplexField:
25 NumAssign + Clone + Copy + Send + Sync + Debug + Zero + One + Neg<Output = Self> + 'static
26{
27 type Real: Float + NumAssign + FromPrimitive + ToPrimitive + Send + Sync + Debug + 'static;
29
30 fn conj(&self) -> Self;
32
33 fn norm_sqr(&self) -> Self::Real;
35
36 fn norm(&self) -> Self::Real {
38 self.norm_sqr().sqrt()
39 }
40
41 fn from_real(r: Self::Real) -> Self;
43
44 fn from_re_im(re: Self::Real, im: Self::Real) -> Self;
46
47 fn re(&self) -> Self::Real;
49
50 fn im(&self) -> Self::Real;
52
53 fn is_zero_approx(&self, tol: Self::Real) -> bool {
55 self.norm_sqr() < tol * tol
56 }
57
58 fn inv(&self) -> Self;
60
61 fn sqrt(&self) -> Self;
63
64 fn vec_dot(x: &Array1<Self>, y: &Array1<Self>) -> Self {
74 let mut sum = Self::zero();
75 for (xi, yi) in x.iter().zip(y.iter()) {
76 sum += xi.conj() * *yi;
77 }
78 sum
79 }
80
81 fn vec_norm_sqr(x: &Array1<Self>) -> Self::Real {
83 let mut sum = Self::Real::zero();
84 for xi in x.iter() {
85 sum += xi.norm_sqr();
86 }
87 sum
88 }
89
90 fn vec_axpy(alpha: Self, x: &Array1<Self>, y: &mut Array1<Self>) {
92 for (xi, yi) in x.iter().zip(y.iter_mut()) {
93 *yi += alpha * *xi;
94 }
95 }
96
97 fn vec_scale(x: &mut Array1<Self>, alpha: Self) {
99 for xi in x.iter_mut() {
100 *xi *= alpha;
101 }
102 }
103}
104
105impl ComplexField for Complex64 {
106 type Real = f64;
107
108 #[inline]
109 fn conj(&self) -> Self {
110 Complex64::conj(self)
111 }
112
113 #[inline]
114 fn norm_sqr(&self) -> f64 {
115 self.re * self.re + self.im * self.im
116 }
117
118 #[inline]
119 fn from_real(r: f64) -> Self {
120 Complex64::new(r, 0.0)
121 }
122
123 #[inline]
124 fn from_re_im(re: f64, im: f64) -> Self {
125 Complex64::new(re, im)
126 }
127
128 #[inline]
129 fn re(&self) -> f64 {
130 self.re
131 }
132
133 #[inline]
134 fn im(&self) -> f64 {
135 self.im
136 }
137
138 #[inline]
139 fn inv(&self) -> Self {
140 let denom = self.norm_sqr();
141 Complex64::new(self.re / denom, -self.im / denom)
142 }
143
144 #[inline]
145 fn sqrt(&self) -> Self {
146 Complex64::sqrt(*self)
147 }
148}
149
150impl ComplexField for Complex32 {
151 type Real = f32;
152
153 #[inline]
154 fn conj(&self) -> Self {
155 Complex32::conj(self)
156 }
157
158 #[inline]
159 fn norm_sqr(&self) -> f32 {
160 self.re * self.re + self.im * self.im
161 }
162
163 #[inline]
164 fn from_real(r: f32) -> Self {
165 Complex32::new(r, 0.0)
166 }
167
168 #[inline]
169 fn from_re_im(re: f32, im: f32) -> Self {
170 Complex32::new(re, im)
171 }
172
173 #[inline]
174 fn re(&self) -> f32 {
175 self.re
176 }
177
178 #[inline]
179 fn im(&self) -> f32 {
180 self.im
181 }
182
183 #[inline]
184 fn inv(&self) -> Self {
185 let denom = self.norm_sqr();
186 Complex32::new(self.re / denom, -self.im / denom)
187 }
188
189 #[inline]
190 fn sqrt(&self) -> Self {
191 Complex32::sqrt(*self)
192 }
193}
194
195impl ComplexField for f64 {
196 type Real = f64;
197
198 #[inline]
199 fn conj(&self) -> Self {
200 *self
201 }
202
203 #[inline]
204 fn norm_sqr(&self) -> f64 {
205 *self * *self
206 }
207
208 #[inline]
209 fn from_real(r: f64) -> Self {
210 r
211 }
212
213 #[inline]
214 fn from_re_im(re: f64, _im: f64) -> Self {
215 re
216 }
217
218 #[inline]
219 fn re(&self) -> f64 {
220 *self
221 }
222
223 #[inline]
224 fn im(&self) -> f64 {
225 0.0
226 }
227
228 #[inline]
229 fn inv(&self) -> Self {
230 1.0 / *self
231 }
232
233 #[inline]
234 fn sqrt(&self) -> Self {
235 f64::sqrt(*self)
236 }
237
238 #[inline]
241 fn vec_dot(x: &Array1<Self>, y: &Array1<Self>) -> Self {
242 x.dot(y)
243 }
244
245 #[inline]
246 fn vec_norm_sqr(x: &Array1<Self>) -> Self {
247 x.dot(x)
248 }
249
250 #[inline]
251 fn vec_axpy(alpha: Self, x: &Array1<Self>, y: &mut Array1<Self>) {
252 y.scaled_add(alpha, x);
253 }
254
255 #[inline]
256 fn vec_scale(x: &mut Array1<Self>, alpha: Self) {
257 x.mapv_inplace(|v| v * alpha);
258 }
259}
260
261impl ComplexField for f32 {
262 type Real = f32;
263
264 #[inline]
265 fn conj(&self) -> Self {
266 *self
267 }
268
269 #[inline]
270 fn norm_sqr(&self) -> f32 {
271 *self * *self
272 }
273
274 #[inline]
275 fn from_real(r: f32) -> Self {
276 r
277 }
278
279 #[inline]
280 fn from_re_im(re: f32, _im: f32) -> Self {
281 re
282 }
283
284 #[inline]
285 fn re(&self) -> f32 {
286 *self
287 }
288
289 #[inline]
290 fn im(&self) -> f32 {
291 0.0
292 }
293
294 #[inline]
295 fn inv(&self) -> Self {
296 1.0 / *self
297 }
298
299 #[inline]
300 fn sqrt(&self) -> Self {
301 f32::sqrt(*self)
302 }
303
304 #[inline]
307 fn vec_dot(x: &Array1<Self>, y: &Array1<Self>) -> Self {
308 x.dot(y)
309 }
310
311 #[inline]
312 fn vec_norm_sqr(x: &Array1<Self>) -> Self {
313 x.dot(x)
314 }
315
316 #[inline]
317 fn vec_axpy(alpha: Self, x: &Array1<Self>, y: &mut Array1<Self>) {
318 y.scaled_add(alpha, x);
319 }
320
321 #[inline]
322 fn vec_scale(x: &mut Array1<Self>, alpha: Self) {
323 x.mapv_inplace(|v| v * alpha);
324 }
325}
326
327pub trait LinearOperator<T: ComplexField>: Send + Sync {
332 fn num_rows(&self) -> usize;
334
335 fn num_cols(&self) -> usize;
337
338 fn apply(&self, x: &Array1<T>) -> Array1<T>;
340
341 fn apply_transpose(&self, x: &Array1<T>) -> Array1<T>;
343
344 fn apply_hermitian(&self, x: &Array1<T>) -> Array1<T> {
346 let x_conj: Array1<T> = x.mapv(|v| v.conj());
347 let y = self.apply_transpose(&x_conj);
348 y.mapv(|v| v.conj())
349 }
350
351 fn is_square(&self) -> bool {
353 self.num_rows() == self.num_cols()
354 }
355}
356
/// Terminal state of an iterative solve; carried by
/// `SolverError::ConvergenceError`.
///
/// NOTE(review): the precise meaning of each variant is fixed by the solvers
/// that report it; the one-line descriptions below follow the variant names.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SolverStatus {
    /// The solve converged.
    Converged,
    /// The iteration limit was reached first.
    MaxIterationsReached,
    /// The method broke down (presumably a vanishing internal quantity).
    Breakdown,
    /// Progress stalled.
    Stagnated,
    /// The iteration diverged.
    Diverged,
}
371
372#[derive(Debug, thiserror::Error)]
374pub enum SolverError {
375 #[error("Solver failed to converge: {status:?}")]
376 ConvergenceError {
377 status: SolverStatus,
378 iterations: usize,
379 residual: f64,
380 },
381 #[error("Linear operator dimension mismatch: expected {expected}, got {got}")]
382 DimensionMismatch { expected: usize, got: usize },
383}
384
385pub trait Preconditioner<T: ComplexField>: Send + Sync {
390 fn apply(&self, r: &Array1<T>) -> Array1<T>;
394}
395
396#[derive(Clone, Debug, Default)]
398pub struct IdentityPreconditioner;
399
400impl<T: ComplexField> Preconditioner<T> for IdentityPreconditioner {
401 fn apply(&self, r: &Array1<T>) -> Array1<T> {
402 r.clone()
403 }
404}
405
#[cfg(test)]
mod tests {
    //! Unit tests for the `ComplexField` impls and default trait methods.
    //!
    //! `Complex64`/`Complex32` have inherent methods named like the
    //! `ComplexField` methods (`conj`, `norm_sqr`, `norm`, `inv`, `sqrt`), and
    //! inherent methods win method resolution. The complex tests therefore use
    //! fully-qualified `<T as ComplexField>::…` calls so they exercise the
    //! trait impls rather than `num_complex` itself.

    use super::*;
    use approx::assert_relative_eq;

    /// Diagonal operator `diag(d)`; it is its own transpose.
    struct Diagonal(Array1<Complex64>);

    impl LinearOperator<Complex64> for Diagonal {
        fn num_rows(&self) -> usize {
            self.0.len()
        }

        fn num_cols(&self) -> usize {
            self.0.len()
        }

        fn apply(&self, x: &Array1<Complex64>) -> Array1<Complex64> {
            &self.0 * x
        }

        fn apply_transpose(&self, x: &Array1<Complex64>) -> Array1<Complex64> {
            &self.0 * x
        }
    }

    #[test]
    fn test_complex64_field() {
        let z = Complex64::new(3.0, 4.0);
        assert_relative_eq!(<Complex64 as ComplexField>::norm_sqr(&z), 25.0);
        assert_relative_eq!(<Complex64 as ComplexField>::norm(&z), 5.0);

        let z_conj = <Complex64 as ComplexField>::conj(&z);
        assert_relative_eq!(z_conj.re, 3.0);
        assert_relative_eq!(z_conj.im, -4.0);

        let z_inv = <Complex64 as ComplexField>::inv(&z);
        let product = z * z_inv;
        assert_relative_eq!(product.re, 1.0, epsilon = 1e-10);
        assert_relative_eq!(product.im, 0.0, epsilon = 1e-10);

        let root = <Complex64 as ComplexField>::sqrt(&Complex64::new(-4.0, 0.0));
        assert_relative_eq!(root.re, 0.0, epsilon = 1e-12);
        assert_relative_eq!(root.im, 2.0, epsilon = 1e-12);
    }

    #[test]
    fn test_complex32_field() {
        let z = Complex32::new(3.0, 4.0);
        assert_relative_eq!(<Complex32 as ComplexField>::norm_sqr(&z), 25.0);
        assert_relative_eq!(<Complex32 as ComplexField>::norm(&z), 5.0);
        assert_relative_eq!(<Complex32 as ComplexField>::re(&z), 3.0);
        assert_relative_eq!(<Complex32 as ComplexField>::im(&z), 4.0);

        let product = z * <Complex32 as ComplexField>::inv(&z);
        assert_relative_eq!(product.re, 1.0, epsilon = 1e-6);
        assert_relative_eq!(product.im, 0.0, epsilon = 1e-6);
    }

    #[test]
    fn test_f64_field() {
        let x: f64 = 3.0;
        assert_relative_eq!(x.norm_sqr(), 9.0);
        assert_relative_eq!(x.norm(), 3.0);
        assert_relative_eq!(x.conj(), 3.0);
        assert_relative_eq!(x.inv(), 1.0 / 3.0);
        assert_relative_eq!((-3.0_f64).norm(), 3.0);
    }

    #[test]
    fn test_f32_field() {
        let x: f32 = -3.0;
        assert_relative_eq!(x.norm_sqr(), 9.0);
        assert_relative_eq!(x.norm(), 3.0);
        assert_relative_eq!(x.inv(), -1.0 / 3.0);
        assert_relative_eq!(ComplexField::re(&x), -3.0);
        assert_relative_eq!(ComplexField::im(&x), 0.0);
    }

    #[test]
    fn test_vector_kernels() {
        // conj(i)·i = 1 and conj(3+4i)·(3+4i) = 25, so <v, v> = 26 + 0i.
        let v = Array1::from_vec(vec![Complex64::new(0.0, 1.0), Complex64::new(3.0, 4.0)]);
        assert_eq!(
            <Complex64 as ComplexField>::vec_dot(&v, &v),
            Complex64::new(26.0, 0.0)
        );
        assert_relative_eq!(<Complex64 as ComplexField>::vec_norm_sqr(&v), 26.0);

        let x = Array1::from_vec(vec![1.0_f64, 2.0]);
        let mut y = Array1::from_vec(vec![10.0_f64, 20.0]);
        assert_relative_eq!(<f64 as ComplexField>::vec_dot(&x, &x), 5.0);
        <f64 as ComplexField>::vec_axpy(2.0, &x, &mut y);
        assert_eq!(y, Array1::from_vec(vec![12.0, 24.0]));
        <f64 as ComplexField>::vec_scale(&mut y, 0.5);
        assert_eq!(y, Array1::from_vec(vec![6.0, 12.0]));
    }

    #[test]
    fn test_apply_hermitian_default() {
        let op = Diagonal(Array1::from_vec(vec![
            Complex64::new(1.0, 2.0),
            Complex64::new(0.0, -3.0),
        ]));
        let x = Array1::from_vec(vec![Complex64::new(1.0, 1.0), Complex64::new(2.0, 0.0)]);

        // For a diagonal operator, Aᴴ x = conj(d) .* x.
        let conj_diag = op.0.mapv(|d| d.conj());
        let expected = &conj_diag * &x;
        assert_eq!(op.apply_hermitian(&x), expected);
        assert!(op.is_square());
    }

    #[test]
    fn test_identity_preconditioner() {
        let precond = IdentityPreconditioner;
        let r = Array1::from_vec(vec![Complex64::new(1.0, 2.0), Complex64::new(3.0, 4.0)]);
        let y = precond.apply(&r);
        assert_eq!(r, y);
    }
}
443}