solvers/
traits.rs

1//! Core traits for linear algebra operations
2//!
3//! This module defines the fundamental abstractions used throughout the solver library:
4//! - [`ComplexField`]: Trait for scalar types (complex and real numbers)
5//! - [`LinearOperator`]: Trait for matrix-like objects that can perform matrix-vector products
6//! - [`Preconditioner`]: Trait for preconditioning operations
7
8use ndarray::Array1;
9use num_complex::{Complex32, Complex64};
10use num_traits::{Float, FromPrimitive, NumAssign, One, ToPrimitive, Zero};
11use std::fmt::Debug;
12use std::ops::Neg;
13
14/// Trait for scalar types that can be used in linear algebra operations.
15///
16/// This trait abstracts over real and complex number types, providing
17/// a unified interface for operations like conjugation, norm computation,
18/// and conversion from real values.
19///
20/// # Implementations
21///
22/// Provided for:
23/// - `Complex64` (default for most acoustic applications)
24/// - `Complex32` (for memory-constrained applications)
25/// - `f64` (for real-valued problems)
26/// - `f32` (for real-valued, memory-constrained applications)
27pub trait ComplexField:
28    NumAssign + Clone + Copy + Send + Sync + Debug + Zero + One + Neg<Output = Self> + 'static
29{
30    /// The real number type underlying this field
31    type Real: Float + NumAssign + FromPrimitive + ToPrimitive + Send + Sync + Debug + 'static;
32
33    /// Complex conjugate
34    fn conj(&self) -> Self;
35
36    /// Squared magnitude |z|²
37    fn norm_sqr(&self) -> Self::Real;
38
39    /// Magnitude |z|
40    fn norm(&self) -> Self::Real {
41        self.norm_sqr().sqrt()
42    }
43
44    /// Create from a real value
45    fn from_real(r: Self::Real) -> Self;
46
47    /// Create from real and imaginary parts
48    fn from_re_im(re: Self::Real, im: Self::Real) -> Self;
49
50    /// Real part
51    fn re(&self) -> Self::Real;
52
53    /// Imaginary part
54    fn im(&self) -> Self::Real;
55
56    /// Check if this is approximately zero
57    fn is_zero_approx(&self, tol: Self::Real) -> bool {
58        self.norm_sqr() < tol * tol
59    }
60
61    /// Multiplicative inverse (1/z)
62    fn inv(&self) -> Self;
63
64    /// Square root
65    fn sqrt(&self) -> Self;
66}
67
68impl ComplexField for Complex64 {
69    type Real = f64;
70
71    #[inline]
72    fn conj(&self) -> Self {
73        Complex64::conj(self)
74    }
75
76    #[inline]
77    fn norm_sqr(&self) -> f64 {
78        self.re * self.re + self.im * self.im
79    }
80
81    #[inline]
82    fn from_real(r: f64) -> Self {
83        Complex64::new(r, 0.0)
84    }
85
86    #[inline]
87    fn from_re_im(re: f64, im: f64) -> Self {
88        Complex64::new(re, im)
89    }
90
91    #[inline]
92    fn re(&self) -> f64 {
93        self.re
94    }
95
96    #[inline]
97    fn im(&self) -> f64 {
98        self.im
99    }
100
101    #[inline]
102    fn inv(&self) -> Self {
103        let denom = self.norm_sqr();
104        Complex64::new(self.re / denom, -self.im / denom)
105    }
106
107    #[inline]
108    fn sqrt(&self) -> Self {
109        Complex64::sqrt(*self)
110    }
111}
112
113impl ComplexField for Complex32 {
114    type Real = f32;
115
116    #[inline]
117    fn conj(&self) -> Self {
118        Complex32::conj(self)
119    }
120
121    #[inline]
122    fn norm_sqr(&self) -> f32 {
123        self.re * self.re + self.im * self.im
124    }
125
126    #[inline]
127    fn from_real(r: f32) -> Self {
128        Complex32::new(r, 0.0)
129    }
130
131    #[inline]
132    fn from_re_im(re: f32, im: f32) -> Self {
133        Complex32::new(re, im)
134    }
135
136    #[inline]
137    fn re(&self) -> f32 {
138        self.re
139    }
140
141    #[inline]
142    fn im(&self) -> f32 {
143        self.im
144    }
145
146    #[inline]
147    fn inv(&self) -> Self {
148        let denom = self.norm_sqr();
149        Complex32::new(self.re / denom, -self.im / denom)
150    }
151
152    #[inline]
153    fn sqrt(&self) -> Self {
154        Complex32::sqrt(*self)
155    }
156}
157
158impl ComplexField for f64 {
159    type Real = f64;
160
161    #[inline]
162    fn conj(&self) -> Self {
163        *self
164    }
165
166    #[inline]
167    fn norm_sqr(&self) -> f64 {
168        *self * *self
169    }
170
171    #[inline]
172    fn from_real(r: f64) -> Self {
173        r
174    }
175
176    #[inline]
177    fn from_re_im(re: f64, _im: f64) -> Self {
178        re
179    }
180
181    #[inline]
182    fn re(&self) -> f64 {
183        *self
184    }
185
186    #[inline]
187    fn im(&self) -> f64 {
188        0.0
189    }
190
191    #[inline]
192    fn inv(&self) -> Self {
193        1.0 / *self
194    }
195
196    #[inline]
197    fn sqrt(&self) -> Self {
198        f64::sqrt(*self)
199    }
200}
201
202impl ComplexField for f32 {
203    type Real = f32;
204
205    #[inline]
206    fn conj(&self) -> Self {
207        *self
208    }
209
210    #[inline]
211    fn norm_sqr(&self) -> f32 {
212        *self * *self
213    }
214
215    #[inline]
216    fn from_real(r: f32) -> Self {
217        r
218    }
219
220    #[inline]
221    fn from_re_im(re: f32, _im: f32) -> Self {
222        re
223    }
224
225    #[inline]
226    fn re(&self) -> f32 {
227        *self
228    }
229
230    #[inline]
231    fn im(&self) -> f32 {
232        0.0
233    }
234
235    #[inline]
236    fn inv(&self) -> Self {
237        1.0 / *self
238    }
239
240    #[inline]
241    fn sqrt(&self) -> Self {
242        f32::sqrt(*self)
243    }
244}
245
246/// Trait for linear operators (matrices) that can perform matrix-vector products.
247///
248/// This abstraction allows solvers to work with dense matrices, sparse matrices,
249/// and matrix-free operators (e.g., FMM) interchangeably.
250pub trait LinearOperator<T: ComplexField>: Send + Sync {
251    /// Number of rows in the operator
252    fn num_rows(&self) -> usize;
253
254    /// Number of columns in the operator
255    fn num_cols(&self) -> usize;
256
257    /// Apply the operator: y = A * x
258    fn apply(&self, x: &Array1<T>) -> Array1<T>;
259
260    /// Apply the transpose: y = A^T * x
261    fn apply_transpose(&self, x: &Array1<T>) -> Array1<T>;
262
263    /// Apply the Hermitian (conjugate transpose): y = A^H * x
264    fn apply_hermitian(&self, x: &Array1<T>) -> Array1<T> {
265        // Default implementation: conjugate(A^T * conj(x))
266        let x_conj: Array1<T> = x.mapv(|v| v.conj());
267        self.apply_transpose(&x_conj).mapv(|v| v.conj())
268    }
269
270    /// Check if the operator is square
271    fn is_square(&self) -> bool {
272        self.num_rows() == self.num_cols()
273    }
274}
275
276/// Trait for preconditioners used in iterative solvers.
277///
278/// A preconditioner M approximates A^(-1), so that M*A is better conditioned
279/// than A alone. This accelerates convergence of iterative methods.
280pub trait Preconditioner<T: ComplexField>: Send + Sync {
281    /// Apply the preconditioner: y = M * r
282    ///
283    /// This should approximate solving A * y = r
284    fn apply(&self, r: &Array1<T>) -> Array1<T>;
285}
286
/// Identity preconditioner (no preconditioning, `M = I`).
///
/// A zero-sized unit struct. Because it carries no data, it is also `Copy` and
/// comparable, which lets it be passed around and checked in tests without
/// cloning.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct IdentityPreconditioner;
290
291impl<T: ComplexField> Preconditioner<T> for IdentityPreconditioner {
292    fn apply(&self, r: &Array1<T>) -> Array1<T> {
293        r.clone()
294    }
295}
296
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    // `Complex64`/`Complex32` have inherent `conj`, `norm`, `norm_sqr`, `inv`
    // and `sqrt` methods, and inherent methods win over trait methods in
    // method-call syntax. The trait impls are therefore called through
    // `<T as ComplexField>::method`, so the code under test is this module's
    // implementation rather than num_complex's.

    #[test]
    fn test_complex64_field() {
        let z = Complex64::new(3.0, 4.0);
        assert_relative_eq!(<Complex64 as ComplexField>::norm_sqr(&z), 25.0);
        assert_relative_eq!(<Complex64 as ComplexField>::norm(&z), 5.0);

        let z_conj = <Complex64 as ComplexField>::conj(&z);
        assert_relative_eq!(z_conj.re, 3.0);
        assert_relative_eq!(z_conj.im, -4.0);

        let z_inv = <Complex64 as ComplexField>::inv(&z);
        let product = z * z_inv;
        assert_relative_eq!(product.re, 1.0, epsilon = 1e-10);
        assert_relative_eq!(product.im, 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_complex64_components() {
        let z = <Complex64 as ComplexField>::from_re_im(1.5, -2.5);
        assert_relative_eq!(ComplexField::re(&z), 1.5);
        assert_relative_eq!(ComplexField::im(&z), -2.5);

        let r = <Complex64 as ComplexField>::from_real(7.0);
        assert_relative_eq!(r.re, 7.0);
        assert_relative_eq!(r.im, 0.0);

        let s = <Complex64 as ComplexField>::sqrt(&Complex64::new(-4.0, 0.0));
        assert_relative_eq!(s.re, 0.0, epsilon = 1e-12);
        assert_relative_eq!(s.im, 2.0, epsilon = 1e-12);

        assert!(ComplexField::is_zero_approx(&Complex64::new(1e-12, -1e-12), 1e-10));
        assert!(!ComplexField::is_zero_approx(&Complex64::new(1.0, 0.0), 1e-10));
    }

    #[test]
    fn test_complex32_field() {
        let z = Complex32::new(3.0, 4.0);
        assert_relative_eq!(<Complex32 as ComplexField>::norm_sqr(&z), 25.0);

        let z_inv = <Complex32 as ComplexField>::inv(&z);
        assert_relative_eq!(z_inv.re, 0.12, epsilon = 1e-6);
        assert_relative_eq!(z_inv.im, -0.16, epsilon = 1e-6);
    }

    #[test]
    fn test_f64_field() {
        let x: f64 = 3.0;
        assert_relative_eq!(<f64 as ComplexField>::norm_sqr(&x), 9.0);
        assert_relative_eq!(<f64 as ComplexField>::norm(&x), 3.0);
        assert_relative_eq!(<f64 as ComplexField>::norm(&(-x)), 3.0);
        assert_relative_eq!(<f64 as ComplexField>::conj(&x), 3.0);
        assert_relative_eq!(<f64 as ComplexField>::inv(&x), 1.0 / 3.0);
        assert_relative_eq!(<f64 as ComplexField>::im(&x), 0.0);
        assert_relative_eq!(<f64 as ComplexField>::from_re_im(2.0, 5.0), 2.0);
    }

    #[test]
    fn test_f32_field() {
        let x: f32 = 4.0;
        assert_relative_eq!(<f32 as ComplexField>::norm_sqr(&x), 16.0);
        assert_relative_eq!(<f32 as ComplexField>::inv(&x), 0.25);
        assert_relative_eq!(<f32 as ComplexField>::sqrt(&x), 2.0);
    }

    #[test]
    fn test_identity_preconditioner() {
        let precond = IdentityPreconditioner;
        let r = Array1::from_vec(vec![Complex64::new(1.0, 2.0), Complex64::new(3.0, 4.0)]);
        let y = precond.apply(&r);
        assert_eq!(r, y);
    }
}
334}