Skip to main content

math_audio_solvers/
traits.rs

//! Core traits for linear algebra operations
//!
//! This module defines the fundamental abstractions used throughout the solver library:
//! - [`ComplexField`]: Trait for scalar types (complex and real numbers)
//! - [`LinearOperator`]: Trait for matrix-like objects that can perform matrix-vector products
//! - [`Preconditioner`]: Trait for preconditioning operations
//!
//! It also provides the shared solver outcome types [`SolverStatus`] and [`SolverError`],
//! and the no-op [`IdentityPreconditioner`].
7
8use ndarray::Array1;
9use num_complex::{Complex32, Complex64};
10use num_traits::{Float, NumAssign, One, Zero};
11use num_traits::{FromPrimitive, ToPrimitive};
12use std::fmt::Debug;
13use std::ops::Neg;
14
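// Note: `Array1` is the vector type for every operation in this module. The f64/f32
// overrides of the `ComplexField` vector methods use ndarray's `.dot()`, which can route
// through BLAS (DDOT/SDOT) when the `native` feature (ndarray/blas) is enabled; their
// `.scaled_add()` / `.mapv_inplace()` calls are ndarray element-wise loops, and the
// complex types use the generic default loops.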
18
19/// Trait for scalar types that can be used in linear algebra operations.
20///
21/// This trait abstracts over real and complex number types, providing
22/// a unified interface for operations like conjugation, norm computation,
23/// and conversion from real values.
24pub trait ComplexField:
25    NumAssign + Clone + Copy + Send + Sync + Debug + Zero + One + Neg<Output = Self> + 'static
26{
27    /// The real number type underlying this field
28    type Real: Float + NumAssign + FromPrimitive + ToPrimitive + Send + Sync + Debug + 'static;
29
30    /// Complex conjugate
31    fn conj(&self) -> Self;
32
33    /// Squared magnitude |z|²
34    fn norm_sqr(&self) -> Self::Real;
35
36    /// Magnitude |z|
37    fn norm(&self) -> Self::Real {
38        self.norm_sqr().sqrt()
39    }
40
41    /// Create from a real value
42    fn from_real(r: Self::Real) -> Self;
43
44    /// Create from real and imaginary parts
45    fn from_re_im(re: Self::Real, im: Self::Real) -> Self;
46
47    /// Real part
48    fn re(&self) -> Self::Real;
49
50    /// Imaginary part
51    fn im(&self) -> Self::Real;
52
53    /// Check if this is approximately zero
54    fn is_zero_approx(&self, tol: Self::Real) -> bool {
55        self.norm_sqr() < tol * tol
56    }
57
58    /// Multiplicative inverse (1/z)
59    fn inv(&self) -> Self;
60
61    /// Square root
62    fn sqrt(&self) -> Self;
63
64    // ------------------------------------------------------------------
65    // BLAS-dispatch methods
66    //
67    // Default implementations use generic Rust loops. The f64/f32 impls
68    // override these to use ndarray operations backed by BLAS (when the
69    // `native` feature is enabled).
70    // ------------------------------------------------------------------
71
72    /// Inner product: Σ conj(x_i) * y_i
73    fn vec_dot(x: &Array1<Self>, y: &Array1<Self>) -> Self {
74        let mut sum = Self::zero();
75        for (xi, yi) in x.iter().zip(y.iter()) {
76            sum += xi.conj() * *yi;
77        }
78        sum
79    }
80
81    /// Squared vector norm: Σ |x_i|²
82    fn vec_norm_sqr(x: &Array1<Self>) -> Self::Real {
83        let mut sum = Self::Real::zero();
84        for xi in x.iter() {
85            sum += xi.norm_sqr();
86        }
87        sum
88    }
89
90    /// AXPY: y += α * x
91    fn vec_axpy(alpha: Self, x: &Array1<Self>, y: &mut Array1<Self>) {
92        for (xi, yi) in x.iter().zip(y.iter_mut()) {
93            *yi += alpha * *xi;
94        }
95    }
96
97    /// In-place scale: x *= α
98    fn vec_scale(x: &mut Array1<Self>, alpha: Self) {
99        for xi in x.iter_mut() {
100            *xi *= alpha;
101        }
102    }
103}
104
105impl ComplexField for Complex64 {
106    type Real = f64;
107
108    #[inline]
109    fn conj(&self) -> Self {
110        Complex64::conj(self)
111    }
112
113    #[inline]
114    fn norm_sqr(&self) -> f64 {
115        self.re * self.re + self.im * self.im
116    }
117
118    #[inline]
119    fn from_real(r: f64) -> Self {
120        Complex64::new(r, 0.0)
121    }
122
123    #[inline]
124    fn from_re_im(re: f64, im: f64) -> Self {
125        Complex64::new(re, im)
126    }
127
128    #[inline]
129    fn re(&self) -> f64 {
130        self.re
131    }
132
133    #[inline]
134    fn im(&self) -> f64 {
135        self.im
136    }
137
138    #[inline]
139    fn inv(&self) -> Self {
140        let denom = self.norm_sqr();
141        Complex64::new(self.re / denom, -self.im / denom)
142    }
143
144    #[inline]
145    fn sqrt(&self) -> Self {
146        Complex64::sqrt(*self)
147    }
148}
149
150impl ComplexField for Complex32 {
151    type Real = f32;
152
153    #[inline]
154    fn conj(&self) -> Self {
155        Complex32::conj(self)
156    }
157
158    #[inline]
159    fn norm_sqr(&self) -> f32 {
160        self.re * self.re + self.im * self.im
161    }
162
163    #[inline]
164    fn from_real(r: f32) -> Self {
165        Complex32::new(r, 0.0)
166    }
167
168    #[inline]
169    fn from_re_im(re: f32, im: f32) -> Self {
170        Complex32::new(re, im)
171    }
172
173    #[inline]
174    fn re(&self) -> f32 {
175        self.re
176    }
177
178    #[inline]
179    fn im(&self) -> f32 {
180        self.im
181    }
182
183    #[inline]
184    fn inv(&self) -> Self {
185        let denom = self.norm_sqr();
186        Complex32::new(self.re / denom, -self.im / denom)
187    }
188
189    #[inline]
190    fn sqrt(&self) -> Self {
191        Complex32::sqrt(*self)
192    }
193}
194
195impl ComplexField for f64 {
196    type Real = f64;
197
198    #[inline]
199    fn conj(&self) -> Self {
200        *self
201    }
202
203    #[inline]
204    fn norm_sqr(&self) -> f64 {
205        *self * *self
206    }
207
208    #[inline]
209    fn from_real(r: f64) -> Self {
210        r
211    }
212
213    #[inline]
214    fn from_re_im(re: f64, _im: f64) -> Self {
215        re
216    }
217
218    #[inline]
219    fn re(&self) -> f64 {
220        *self
221    }
222
223    #[inline]
224    fn im(&self) -> f64 {
225        0.0
226    }
227
228    #[inline]
229    fn inv(&self) -> Self {
230        1.0 / *self
231    }
232
233    #[inline]
234    fn sqrt(&self) -> Self {
235        f64::sqrt(*self)
236    }
237
238    // BLAS-accelerated overrides via ndarray (uses DDOT/DNRM2/DAXPY)
239
240    #[inline]
241    fn vec_dot(x: &Array1<Self>, y: &Array1<Self>) -> Self {
242        x.dot(y)
243    }
244
245    #[inline]
246    fn vec_norm_sqr(x: &Array1<Self>) -> Self {
247        x.dot(x)
248    }
249
250    #[inline]
251    fn vec_axpy(alpha: Self, x: &Array1<Self>, y: &mut Array1<Self>) {
252        y.scaled_add(alpha, x);
253    }
254
255    #[inline]
256    fn vec_scale(x: &mut Array1<Self>, alpha: Self) {
257        x.mapv_inplace(|v| v * alpha);
258    }
259}
260
261impl ComplexField for f32 {
262    type Real = f32;
263
264    #[inline]
265    fn conj(&self) -> Self {
266        *self
267    }
268
269    #[inline]
270    fn norm_sqr(&self) -> f32 {
271        *self * *self
272    }
273
274    #[inline]
275    fn from_real(r: f32) -> Self {
276        r
277    }
278
279    #[inline]
280    fn from_re_im(re: f32, _im: f32) -> Self {
281        re
282    }
283
284    #[inline]
285    fn re(&self) -> f32 {
286        *self
287    }
288
289    #[inline]
290    fn im(&self) -> f32 {
291        0.0
292    }
293
294    #[inline]
295    fn inv(&self) -> Self {
296        1.0 / *self
297    }
298
299    #[inline]
300    fn sqrt(&self) -> Self {
301        f32::sqrt(*self)
302    }
303
304    // BLAS-accelerated overrides via ndarray (uses SDOT/SNRM2/SAXPY)
305
306    #[inline]
307    fn vec_dot(x: &Array1<Self>, y: &Array1<Self>) -> Self {
308        x.dot(y)
309    }
310
311    #[inline]
312    fn vec_norm_sqr(x: &Array1<Self>) -> Self {
313        x.dot(x)
314    }
315
316    #[inline]
317    fn vec_axpy(alpha: Self, x: &Array1<Self>, y: &mut Array1<Self>) {
318        y.scaled_add(alpha, x);
319    }
320
321    #[inline]
322    fn vec_scale(x: &mut Array1<Self>, alpha: Self) {
323        x.mapv_inplace(|v| v * alpha);
324    }
325}
326
327/// Trait for linear operators (matrices) that can perform matrix-vector products.
328///
329/// This abstraction allows solvers to work with dense matrices, sparse matrices,
330/// and matrix-free operators (e.g., FMM) interchangeably.
331pub trait LinearOperator<T: ComplexField>: Send + Sync {
332    /// Number of rows in the operator
333    fn num_rows(&self) -> usize;
334
335    /// Number of columns in the operator
336    fn num_cols(&self) -> usize;
337
338    /// Apply the operator: y = A * x
339    fn apply(&self, x: &Array1<T>) -> Array1<T>;
340
341    /// Apply the transpose: y = A^T * x
342    fn apply_transpose(&self, x: &Array1<T>) -> Array1<T>;
343
344    /// Apply the Hermitian (conjugate transpose): y = A^H * x
345    fn apply_hermitian(&self, x: &Array1<T>) -> Array1<T> {
346        let x_conj: Array1<T> = x.mapv(|v| v.conj());
347        let y = self.apply_transpose(&x_conj);
348        y.mapv(|v| v.conj())
349    }
350
351    /// Check if the operator is square
352    fn is_square(&self) -> bool {
353        self.num_rows() == self.num_cols()
354    }
355}
356
/// Outcome reported by an iterative solver when it stops.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SolverStatus {
    /// The residual met the requested tolerance.
    Converged,
    /// The iteration budget ran out before the tolerance was met.
    MaxIterationsReached,
    /// A numerical breakdown occurred (for example, a division by zero).
    Breakdown,
    /// The iteration stopped making progress.
    Stagnated,
    /// The residual grew instead of shrinking.
    Diverged,
}
371
372/// Error information from iterative solvers
373#[derive(Debug, thiserror::Error)]
374pub enum SolverError {
375    #[error("Solver failed to converge: {status:?}")]
376    ConvergenceError {
377        status: SolverStatus,
378        iterations: usize,
379        residual: f64,
380    },
381    #[error("Linear operator dimension mismatch: expected {expected}, got {got}")]
382    DimensionMismatch { expected: usize, got: usize },
383}
384
385/// Trait for preconditioners used in iterative solvers.
386///
387/// A preconditioner M approximates A^(-1), so that M*A is better conditioned
388/// than A alone. This accelerates convergence of iterative methods.
389pub trait Preconditioner<T: ComplexField>: Send + Sync {
390    /// Apply the preconditioner: y = M * r
391    ///
392    /// This should approximate solving A * y = r
393    fn apply(&self, r: &Array1<T>) -> Array1<T>;
394}
395
/// Preconditioner with `M = I`: returns its input unchanged (no preconditioning).
#[derive(Debug, Default, Clone)]
pub struct IdentityPreconditioner;
399
400impl<T: ComplexField> Preconditioner<T> for IdentityPreconditioner {
401    fn apply(&self, r: &Array1<T>) -> Array1<T> {
402        r.clone()
403    }
404}
405
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use ndarray::Array2;

    // NOTE: `Complex64`/`Complex32` have inherent `norm_sqr`/`norm`/`conj`/`inv`/`sqrt`
    // methods that take precedence over the `ComplexField` ones under method-call
    // syntax. These tests use fully-qualified paths so they exercise the trait impls.

    #[test]
    fn test_complex64_field() {
        let z = Complex64::new(3.0, 4.0);
        assert_relative_eq!(<Complex64 as ComplexField>::norm_sqr(&z), 25.0);
        assert_relative_eq!(<Complex64 as ComplexField>::norm(&z), 5.0);
        assert_relative_eq!(<Complex64 as ComplexField>::re(&z), 3.0);
        assert_relative_eq!(<Complex64 as ComplexField>::im(&z), 4.0);

        let z_conj = <Complex64 as ComplexField>::conj(&z);
        assert_relative_eq!(z_conj.re, 3.0);
        assert_relative_eq!(z_conj.im, -4.0);

        let z_inv = <Complex64 as ComplexField>::inv(&z);
        let product = z * z_inv;
        assert_relative_eq!(product.re, 1.0, epsilon = 1e-10);
        assert_relative_eq!(product.im, 0.0, epsilon = 1e-10);

        // sqrt(3 + 4i) = 2 + i
        let root = <Complex64 as ComplexField>::sqrt(&z);
        assert_relative_eq!(root.re, 2.0, epsilon = 1e-12);
        assert_relative_eq!(root.im, 1.0, epsilon = 1e-12);

        assert_eq!(<Complex64 as ComplexField>::from_real(2.5), Complex64::new(2.5, 0.0));
        assert_eq!(
            <Complex64 as ComplexField>::from_re_im(1.0, -2.0),
            Complex64::new(1.0, -2.0)
        );
    }

    #[test]
    fn test_complex32_field() {
        let z = Complex32::new(3.0, 4.0);
        assert_relative_eq!(<Complex32 as ComplexField>::norm_sqr(&z), 25.0);
        assert_relative_eq!(<Complex32 as ComplexField>::norm(&z), 5.0);
        let product = z * <Complex32 as ComplexField>::inv(&z);
        assert_relative_eq!(product.re, 1.0, epsilon = 1e-6);
        assert_relative_eq!(product.im, 0.0, epsilon = 1e-6);
    }

    #[test]
    fn test_f64_field() {
        let x: f64 = 3.0;
        assert_relative_eq!(ComplexField::norm_sqr(&x), 9.0);
        assert_relative_eq!(ComplexField::norm(&x), 3.0);
        assert_relative_eq!(ComplexField::conj(&x), 3.0);
        assert_relative_eq!(ComplexField::inv(&x), 1.0 / 3.0);
        assert_relative_eq!(ComplexField::im(&x), 0.0);

        let negative: f64 = -3.0;
        assert_relative_eq!(ComplexField::norm(&negative), 3.0);

        // Real types drop the imaginary part.
        assert_relative_eq!(<f64 as ComplexField>::from_re_im(2.0, 5.0), 2.0);
    }

    #[test]
    fn test_f32_field() {
        let x: f32 = -2.0;
        assert_relative_eq!(ComplexField::norm_sqr(&x), 4.0);
        assert_relative_eq!(ComplexField::norm(&x), 2.0);
        assert_relative_eq!(ComplexField::inv(&x), -0.5);
    }

    #[test]
    fn test_is_zero_approx() {
        let tiny = Complex64::new(1e-8, -1e-8);
        let one = Complex64::new(1.0, 0.0);
        assert!(<Complex64 as ComplexField>::is_zero_approx(&tiny, 1e-6));
        assert!(!<Complex64 as ComplexField>::is_zero_approx(&one, 1e-6));
    }

    #[test]
    fn test_complex_vec_ops_default() {
        let mut x = Array1::from_vec(vec![Complex64::new(1.0, 2.0), Complex64::new(0.0, -1.0)]);
        let mut y = Array1::from_vec(vec![Complex64::new(3.0, 0.0), Complex64::new(2.0, 1.0)]);

        // conj(x) · y = (1 - 2i)·3 + (i)·(2 + i) = 2 - 4i
        assert_eq!(
            <Complex64 as ComplexField>::vec_dot(&x, &y),
            Complex64::new(2.0, -4.0)
        );
        // |1 + 2i|² + |-i|² = 5 + 1
        assert_relative_eq!(<Complex64 as ComplexField>::vec_norm_sqr(&x), 6.0);

        <Complex64 as ComplexField>::vec_axpy(Complex64::new(2.0, 0.0), &x, &mut y);
        assert_eq!(
            y,
            Array1::from_vec(vec![Complex64::new(5.0, 4.0), Complex64::new(2.0, -1.0)])
        );

        <Complex64 as ComplexField>::vec_scale(&mut x, Complex64::new(0.0, 1.0));
        assert_eq!(
            x,
            Array1::from_vec(vec![Complex64::new(-2.0, 1.0), Complex64::new(1.0, 0.0)])
        );
    }

    #[test]
    fn test_f64_vec_ops_override() {
        let mut x = Array1::from_vec(vec![1.0_f64, 2.0, 3.0]);
        let mut y = Array1::from_vec(vec![4.0_f64, -5.0, 6.0]);

        assert_relative_eq!(<f64 as ComplexField>::vec_dot(&x, &y), 12.0);
        assert_relative_eq!(<f64 as ComplexField>::vec_norm_sqr(&x), 14.0);

        <f64 as ComplexField>::vec_axpy(0.5, &x, &mut y);
        assert_eq!(y, Array1::from_vec(vec![4.5, -4.0, 7.5]));

        <f64 as ComplexField>::vec_scale(&mut x, -2.0);
        assert_eq!(x, Array1::from_vec(vec![-2.0, -4.0, -6.0]));
    }

    /// Minimal dense operator used to exercise the `LinearOperator` defaults.
    struct DenseOp(Array2<Complex64>);

    impl LinearOperator<Complex64> for DenseOp {
        fn num_rows(&self) -> usize {
            self.0.nrows()
        }

        fn num_cols(&self) -> usize {
            self.0.ncols()
        }

        fn apply(&self, x: &Array1<Complex64>) -> Array1<Complex64> {
            self.0.dot(x)
        }

        fn apply_transpose(&self, x: &Array1<Complex64>) -> Array1<Complex64> {
            self.0.t().dot(x)
        }
    }

    #[test]
    fn test_apply_hermitian_default() {
        let c = Complex64::new;
        let a = Array2::from_shape_vec(
            (2, 3),
            vec![
                c(1.0, 1.0),
                c(2.0, 0.0),
                c(0.0, -1.0),
                c(0.0, 2.0),
                c(3.0, -2.0),
                c(1.0, 0.0),
            ],
        )
        .expect("2x3 shape matches 6 elements");
        let op = DenseOp(a);
        assert!(!op.is_square());

        let x = Array1::from_vec(vec![c(1.0, 0.0), c(0.0, 1.0)]);
        // (Aᴴ x)_j = Σ_i conj(A[i][j]) · x_i, computed by hand.
        let expected = Array1::from_vec(vec![c(3.0, -1.0), c(0.0, 3.0), c(0.0, 2.0)]);
        assert_eq!(op.apply_hermitian(&x), expected);
    }

    #[test]
    fn test_identity_preconditioner() {
        let precond = IdentityPreconditioner;
        let r = Array1::from_vec(vec![Complex64::new(1.0, 2.0), Complex64::new(3.0, 4.0)]);
        let y = precond.apply(&r);
        assert_eq!(r, y);
    }
}
443}