math_audio_solvers/
traits.rs

1//! Core traits for linear algebra operations
2//!
3//! This module defines the fundamental abstractions used throughout the solver library:
4//! - [`ComplexField`]: Trait for scalar types (complex and real numbers)
5//! - [`LinearOperator`]: Trait for matrix-like objects that can perform matrix-vector products
6//! - [`Preconditioner`]: Trait for preconditioning operations
7
8use ndarray::Array1;
9use num_complex::{Complex32, Complex64};
10use num_traits::{Float, NumAssign, One, Zero, FromPrimitive, ToPrimitive};
11use std::fmt::Debug;
12use std::ops::Neg;
13
/// Trait for scalar types that can be used in linear algebra operations.
///
/// This trait abstracts over real and complex number types, providing
/// a unified interface for norm computation, component access and
/// construction from real/imaginary parts.
///
/// This variant is compiled when the `ndarray-linalg` feature is enabled.
/// The associated `Real` type is inherited from `ndarray_linalg::Scalar`
/// (a supertrait of `ndarray_linalg::Lapack`). Conjugation and
/// construction from a real value are *not* declared here in this build;
/// use the `ndarray_linalg::Scalar` methods of the same names, e.g.
/// `ndarray_linalg::Scalar::conj(&z)`.
///
/// `re`, `im` and `sqrt` have same-named methods on `ndarray_linalg::Scalar`.
/// In generic code, prefer fully qualified calls such as
/// `ComplexField::re(&z)` to avoid method-resolution ambiguity.
///
/// # Implementations
///
/// Provided for:
/// - `Complex64` (default for most acoustic applications)
/// - `Complex32` (for memory-constrained applications)
/// - `f64` (for real-valued problems)
/// - `f32` (for real-valued, memory-constrained applications)
#[cfg(feature = "ndarray-linalg")]
pub trait ComplexField:
    NumAssign
    + Clone
    + Copy
    + Send
    + Sync
    + Debug
    + Zero
    + One
    + Neg<Output = Self>
    + ndarray_linalg::Lapack
    + 'static
{
    // `type Real` is inherited from `ndarray_linalg::Scalar` via `Lapack`.

    /// Squared magnitude |z|²
    fn norm_sqr(&self) -> Self::Real;

    /// Magnitude |z|
    fn norm(&self) -> Self::Real {
        // Fully qualified: `Self::Real` is bounded by both `Float` and
        // `ndarray_linalg::Scalar`, and both traits define `sqrt`.
        Float::sqrt(self.norm_sqr())
    }

    /// Create from real and imaginary parts
    fn from_re_im(re: Self::Real, im: Self::Real) -> Self;

    /// Real part
    fn re(&self) -> Self::Real;

    /// Imaginary part
    fn im(&self) -> Self::Real;

    /// Check if this is approximately zero, i.e. `|z|² < tol²`
    fn is_zero_approx(&self, tol: Self::Real) -> bool {
        self.norm_sqr() < tol * tol
    }

    /// Multiplicative inverse (1/z)
    fn inv(&self) -> Self;

    /// Square root
    fn sqrt(&self) -> Self;
}
79
80#[cfg(not(feature = "ndarray-linalg"))]
81pub trait ComplexField:
82    NumAssign + Clone + Copy + Send + Sync + Debug + Zero + One + Neg<Output = Self> + 'static
83{
84    /// The real number type underlying this field
85    type Real: Float + NumAssign + FromPrimitive + ToPrimitive + Send + Sync + Debug + 'static;
86
87    /// Complex conjugate
88    fn conj(&self) -> Self;
89
90    /// Squared magnitude |z|²
91    fn norm_sqr(&self) -> Self::Real;
92
93    /// Magnitude |z|
94    fn norm(&self) -> Self::Real {
95        self.norm_sqr().sqrt()
96    }
97
98    /// Create from a real value
99    fn from_real(r: Self::Real) -> Self;
100
101    /// Create from real and imaginary parts
102    fn from_re_im(re: Self::Real, im: Self::Real) -> Self;
103
104    /// Real part
105    fn re(&self) -> Self::Real;
106
107    /// Imaginary part
108    fn im(&self) -> Self::Real;
109
110    /// Check if this is approximately zero
111    fn is_zero_approx(&self, tol: Self::Real) -> bool {
112        self.norm_sqr() < tol * tol
113    }
114
115    /// Multiplicative inverse (1/z)
116    fn inv(&self) -> Self;
117
118    /// Square root
119    fn sqrt(&self) -> Self;
120}
121
122impl ComplexField for Complex64 {
123    #[cfg(not(feature = "ndarray-linalg"))]
124    type Real = f64;
125
126    #[inline]
127    #[cfg(not(feature = "ndarray-linalg"))]
128    fn conj(&self) -> Self {
129        Complex64::conj(self)
130    }
131
132    #[inline]
133    fn norm_sqr(&self) -> f64 {
134        self.re * self.re + self.im * self.im
135    }
136
137    #[inline]
138    #[cfg(not(feature = "ndarray-linalg"))]
139    fn from_real(r: f64) -> Self {
140        Complex64::new(r, 0.0)
141    }
142
143    #[inline]
144    fn from_re_im(re: f64, im: f64) -> Self {
145        Complex64::new(re, im)
146    }
147
148    #[inline]
149    fn re(&self) -> f64 {
150        self.re
151    }
152
153    #[inline]
154    fn im(&self) -> f64 {
155        self.im
156    }
157
158    #[inline]
159    fn inv(&self) -> Self {
160        let denom = self.norm_sqr();
161        Complex64::new(self.re / denom, -self.im / denom)
162    }
163
164    #[inline]
165    fn sqrt(&self) -> Self {
166        Complex64::sqrt(*self)
167    }
168}
169
170impl ComplexField for Complex32 {
171    #[cfg(not(feature = "ndarray-linalg"))]
172    type Real = f32;
173
174    #[inline]
175    #[cfg(not(feature = "ndarray-linalg"))]
176    fn conj(&self) -> Self {
177        Complex32::conj(self)
178    }
179
180    #[inline]
181    fn norm_sqr(&self) -> f32 {
182        self.re * self.re + self.im * self.im
183    }
184
185    #[inline]
186    #[cfg(not(feature = "ndarray-linalg"))]
187    fn from_real(r: f32) -> Self {
188        Complex32::new(r, 0.0)
189    }
190
191    #[inline]
192    fn from_re_im(re: f32, im: f32) -> Self {
193        Complex32::new(re, im)
194    }
195
196    #[inline]
197    fn re(&self) -> f32 {
198        self.re
199    }
200
201    #[inline]
202    fn im(&self) -> f32 {
203        self.im
204    }
205
206    #[inline]
207    fn inv(&self) -> Self {
208        let denom = self.norm_sqr();
209        Complex32::new(self.re / denom, -self.im / denom)
210    }
211
212    #[inline]
213    fn sqrt(&self) -> Self {
214        Complex32::sqrt(*self)
215    }
216}
217
218impl ComplexField for f64 {
219    #[cfg(not(feature = "ndarray-linalg"))]
220    type Real = f64;
221
222    #[inline]
223    #[cfg(not(feature = "ndarray-linalg"))]
224    fn conj(&self) -> Self {
225        *self
226    }
227
228    #[inline]
229    fn norm_sqr(&self) -> f64 {
230        *self * *self
231    }
232
233    #[inline]
234    #[cfg(not(feature = "ndarray-linalg"))]
235    fn from_real(r: f64) -> Self {
236        r
237    }
238
239    #[inline]
240    fn from_re_im(re: f64, _im: f64) -> Self {
241        re
242    }
243
244    #[inline]
245    fn re(&self) -> f64 {
246        *self
247    }
248
249    #[inline]
250    fn im(&self) -> f64 {
251        0.0
252    }
253
254    #[inline]
255    fn inv(&self) -> Self {
256        1.0 / *self
257    }
258
259    #[inline]
260    fn sqrt(&self) -> Self {
261        f64::sqrt(*self)
262    }
263}
264
265impl ComplexField for f32 {
266    #[cfg(not(feature = "ndarray-linalg"))]
267    type Real = f32;
268
269    #[inline]
270    #[cfg(not(feature = "ndarray-linalg"))]
271    fn conj(&self) -> Self {
272        *self
273    }
274
275    #[inline]
276    fn norm_sqr(&self) -> f32 {
277        *self * *self
278    }
279
280    #[inline]
281    #[cfg(not(feature = "ndarray-linalg"))]
282    fn from_real(r: f32) -> Self {
283        r
284    }
285
286    #[inline]
287    fn from_re_im(re: f32, _im: f32) -> Self {
288        re
289    }
290
291    #[inline]
292    fn re(&self) -> f32 {
293        *self
294    }
295
296    #[inline]
297    fn im(&self) -> f32 {
298        0.0
299    }
300
301    #[inline]
302    fn inv(&self) -> Self {
303        1.0 / *self
304    }
305
306    #[inline]
307    fn sqrt(&self) -> Self {
308        f32::sqrt(*self)
309    }
310}
311
312/// Trait for linear operators (matrices) that can perform matrix-vector products.
313///
314/// This abstraction allows solvers to work with dense matrices, sparse matrices,
315/// and matrix-free operators (e.g., FMM) interchangeably.
316pub trait LinearOperator<T: ComplexField>: Send + Sync {
317    /// Number of rows in the operator
318    fn num_rows(&self) -> usize;
319
320    /// Number of columns in the operator
321    fn num_cols(&self) -> usize;
322
323    /// Apply the operator: y = A * x
324    fn apply(&self, x: &Array1<T>) -> Array1<T>;
325
326    /// Apply the transpose: y = A^T * x
327    fn apply_transpose(&self, x: &Array1<T>) -> Array1<T>;
328
329    /// Apply the Hermitian (conjugate transpose): y = A^H * x
330    fn apply_hermitian(&self, x: &Array1<T>) -> Array1<T> {
331        // Default implementation: conjugate(A^T * conj(x))
332        // Note: x.mapv(|v| v.conj()) uses scalar conjugation.
333        // If ComplexField does not have conj(), this relies on Scalar::conj().
334        // However, mapv takes a closure.
335        let x_conj: Array1<T> = x.mapv(|v| {
336            #[cfg(feature = "ndarray-linalg")]
337            {
338                ndarray_linalg::Scalar::conj(&v)
339            }
340            #[cfg(not(feature = "ndarray-linalg"))]
341            {
342                v.conj()
343            }
344        });
345
346        let y = self.apply_transpose(&x_conj);
347
348        y.mapv(|v| {
349            #[cfg(feature = "ndarray-linalg")]
350            {
351                ndarray_linalg::Scalar::conj(&v)
352            }
353            #[cfg(not(feature = "ndarray-linalg"))]
354            {
355                v.conj()
356            }
357        })
358    }
359
360    /// Check if the operator is square
361    fn is_square(&self) -> bool {
362        self.num_rows() == self.num_cols()
363    }
364}
365
366/// Trait for preconditioners used in iterative solvers.
367///
368/// A preconditioner M approximates A^(-1), so that M*A is better conditioned
369/// than A alone. This accelerates convergence of iterative methods.
370pub trait Preconditioner<T: ComplexField>: Send + Sync {
371    /// Apply the preconditioner: y = M * r
372    ///
373    /// This should approximate solving A * y = r
374    fn apply(&self, r: &Array1<T>) -> Array1<T>;
375}
376
377/// Identity preconditioner (no preconditioning)
378#[derive(Clone, Debug, Default)]
379pub struct IdentityPreconditioner;
380
381impl<T: ComplexField> Preconditioner<T> for IdentityPreconditioner {
382    fn apply(&self, r: &Array1<T>) -> Array1<T> {
383        r.clone()
384    }
385}
386
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    // `Complex64`/`Complex32` have *inherent* `norm_sqr`, `norm`, `inv`, ...
    // methods from num-complex, and inherent methods take priority in
    // method-call syntax. Call through the trait explicitly (`ComplexField::x(&z)`)
    // so these tests exercise the `ComplexField` impls defined in this module.

    #[test]
    fn test_complex64_field() {
        let z = Complex64::new(3.0, 4.0);
        assert_relative_eq!(ComplexField::norm_sqr(&z), 25.0);
        assert_relative_eq!(ComplexField::norm(&z), 5.0);
        assert_relative_eq!(ComplexField::re(&z), 3.0);
        assert_relative_eq!(ComplexField::im(&z), 4.0);

        // With `ndarray-linalg`, conjugation comes from `Scalar` rather than
        // `ComplexField`.
        #[cfg(not(feature = "ndarray-linalg"))]
        let z_conj = ComplexField::conj(&z);
        #[cfg(feature = "ndarray-linalg")]
        let z_conj = ndarray_linalg::Scalar::conj(&z);

        assert_relative_eq!(z_conj.re, 3.0);
        assert_relative_eq!(z_conj.im, -4.0);

        let z_inv = ComplexField::inv(&z);
        let product = z * z_inv;
        assert_relative_eq!(product.re, 1.0, epsilon = 1e-10);
        assert_relative_eq!(product.im, 0.0, epsilon = 1e-10);

        let built = <Complex64 as ComplexField>::from_re_im(1.0, -2.0);
        assert_relative_eq!(built.re, 1.0);
        assert_relative_eq!(built.im, -2.0);

        let root = ComplexField::sqrt(&Complex64::new(-4.0, 0.0));
        assert_relative_eq!(root.re, 0.0, epsilon = 1e-12);
        assert_relative_eq!(root.im, 2.0, epsilon = 1e-12);

        assert!(ComplexField::is_zero_approx(&Complex64::new(1e-12, 0.0), 1e-9));
        assert!(!ComplexField::is_zero_approx(&Complex64::new(1e-3, 0.0), 1e-9));
    }

    #[test]
    fn test_complex32_field() {
        let z = Complex32::new(3.0, 4.0);
        assert_relative_eq!(ComplexField::norm_sqr(&z), 25.0);
        assert_relative_eq!(ComplexField::norm(&z), 5.0, epsilon = 1e-6);

        let product = z * ComplexField::inv(&z);
        assert_relative_eq!(product.re, 1.0, epsilon = 1e-6);
        assert_relative_eq!(product.im, 0.0, epsilon = 1e-6);
    }

    #[test]
    fn test_f64_field() {
        let x: f64 = 3.0;
        assert_relative_eq!(ComplexField::norm_sqr(&x), 9.0);
        assert_relative_eq!(ComplexField::norm(&x), 3.0);
        assert_relative_eq!(ComplexField::norm(&-x), 3.0);
        assert_relative_eq!(ComplexField::im(&x), 0.0);
        assert_relative_eq!(<f64 as ComplexField>::from_re_im(2.0, 5.0), 2.0);

        #[cfg(not(feature = "ndarray-linalg"))]
        assert_relative_eq!(ComplexField::conj(&x), 3.0);
        #[cfg(feature = "ndarray-linalg")]
        assert_relative_eq!(ndarray_linalg::Scalar::conj(&x), 3.0);

        assert_relative_eq!(ComplexField::inv(&x), 1.0 / 3.0);
    }

    #[test]
    fn test_f32_field() {
        let y: f32 = 2.0;
        assert_relative_eq!(ComplexField::norm_sqr(&y), 4.0);
        assert_relative_eq!(ComplexField::norm(&-y), 2.0);
        assert_relative_eq!(ComplexField::inv(&y), 0.5);
        assert_relative_eq!(ComplexField::sqrt(&4.0_f32), 2.0);
    }

    #[test]
    fn test_identity_preconditioner() {
        let precond = IdentityPreconditioner;
        let r = Array1::from_vec(vec![Complex64::new(1.0, 2.0), Complex64::new(3.0, 4.0)]);
        let y = precond.apply(&r);
        assert_eq!(r, y);
    }
}