funspace/
traits.rs

1//! # Collection of traits that bases must implement
2use crate::enums::{BaseKind, TransformKind};
3use crate::utils::{array_resized_axis, check_array_axis};
4use ndarray::{Array, Array2, ArrayBase, Axis, Data, DataMut, Dimension, Zip};
5use num_traits::identities::Zero;
6
7/// Base super trait
8pub trait Base<T>:
9    BaseSize
10    + BaseMatOpLaplacian
11    + BaseMatOpDiffmat
12    + BaseMatOpStencil
13    + BaseElements
14    + BaseGradient<T>
15    + BaseFromOrtho<T>
16    + BaseTransform
17where
18    T: Zero + Copy,
19{
20}
21
22impl<A, T> Base<T> for A
23where
24    T: Zero + Copy,
25    A: BaseSize
26        + BaseMatOpLaplacian
27        + BaseMatOpDiffmat
28        + BaseMatOpStencil
29        + BaseElements
30        + BaseGradient<T>
31        + BaseFromOrtho<T>
32        + BaseTransform,
33{
34}
35
/// Dimensions: the length of a base in each of its representations.
pub trait BaseSize {
    /// Number of entries in physical space.
    fn len_phys(&self) -> usize;

    /// Number of coefficients in spectral space.
    fn len_spec(&self) -> usize;

    /// Number of coefficients in the orthogonal space.
    fn len_orth(&self) -> usize;
}
47
48/// Coordinates and base functions
49pub trait BaseElements {
50    /// Real valued scalar type
51    type RealNum;
52
53    /// Return kind of base
54    fn base_kind(&self) -> BaseKind;
55
56    /// Return kind of transform
57    fn transform_kind(&self) -> TransformKind;
58
59    /// Coordinates in physical space
60    fn coords(&self) -> Vec<Self::RealNum>;
61}
62
63/// Collection of differentiation matrix operators
64pub trait BaseMatOpDiffmat {
65    /// Scalar type of matrix
66    type NumType;
67
68    /// Explicit differential operator $ D $
69    ///
70    /// Matrix-based version of [`BaseGradient::gradient()`]
71    fn diffmat(&self, _deriv: usize) -> Array2<Self::NumType>;
72
73    /// Explicit inverse of differential operator $ D^* $
74    ///
75    /// Returns ``(D_pinv, I_pinv)``, where `D_pinv` is the pseudoinverse
76    /// and ``I_pinv`` the corresponding pseudoidentity matrix, such
77    /// that
78    ///
79    /// ```text
80    /// D_pinv @ D = I_pinv
81    /// ```
82    ///
83    /// Can be used as a preconditioner.
84    fn diffmat_pinv(&self, _deriv: usize) -> (Array2<Self::NumType>, Array2<Self::NumType>);
85}
86
87/// Collection of stencil matrix operators
88pub trait BaseMatOpStencil {
89    /// Scalar type of matrix
90    type NumType;
91
92    /// Transformation stencil composite -> orthogonal space
93    fn stencil(&self) -> Array2<Self::NumType>;
94
95    /// Inverse of transformation stencil
96    fn stencil_inv(&self) -> Array2<Self::NumType>;
97}
98
99/// Collection of *Laplacian* matrix operators
100pub trait BaseMatOpLaplacian {
101    /// Scalar type of matrix
102    type NumType;
103
104    /// Laplacian $ L $
105    fn laplacian(&self) -> Array2<Self::NumType>;
106
107    /// Pseudoinverse matrix of Laplacian $ L^{-1} $
108    ///
109    /// Returns pseudoinverse and pseudoidentity,i.e
110    /// ``(D_pinv, I_pinv)``
111    ///
112    /// ```text
113    /// D_pinv @ D = I_pinv
114    /// ```
115    fn laplacian_pinv(&self) -> (Array2<Self::NumType>, Array2<Self::NumType>);
116}
117
118/// # Transform from orthogonal <-> composite base
119pub trait BaseFromOrtho<T>: BaseSize
120where
121    T: Zero + Copy,
122{
123    /// ## Composite coefficients -> Orthogonal coefficients
124    fn to_ortho_slice(&self, indata: &[T], outdata: &mut [T]);
125
126    /// ## Orthogonal coefficients -> Composite coefficients
127    fn from_ortho_slice(&self, indata: &[T], outdata: &mut [T]);
128
129    /// ## Composite coefficients -> Orthogonal coefficients
130    fn to_ortho<S, D>(&self, indata: &ArrayBase<S, D>, axis: usize) -> Array<T, D>
131    where
132        S: Data<Elem = T>,
133        D: Dimension,
134    {
135        let mut outdata = array_resized_axis(indata, self.len_orth(), axis);
136        self.to_ortho_inplace(indata, &mut outdata, axis);
137        outdata
138    }
139
140    apply_along_axis!(
141        /// ## Composite coefficients -> Orthogonal coefficients
142        to_ortho_inplace,
143        T,
144        T,
145        to_ortho_slice,
146        len_spec,
147        len_orth,
148        "to_ortho"
149    );
150
151    /// ## Composite coefficients -> Orthogonal coefficients
152    fn from_ortho<S, D>(&self, indata: &ArrayBase<S, D>, axis: usize) -> Array<T, D>
153    where
154        S: Data<Elem = T>,
155        D: Dimension,
156    {
157        let mut outdata = array_resized_axis(indata, self.len_spec(), axis);
158        self.from_ortho_inplace(indata, &mut outdata, axis);
159        outdata
160    }
161
162    apply_along_axis!(
163        /// ## Composite coefficients -> Orthogonal coefficients
164        from_ortho_inplace,
165        T,
166        T,
167        from_ortho_slice,
168        len_orth,
169        len_spec,
170        "from_ortho"
171    );
172
173    /// ## Composite coefficients -> Orthogonal coefficients  (Parallel)
174    fn to_ortho_par<S, D>(&self, indata: &ArrayBase<S, D>, axis: usize) -> Array<T, D>
175    where
176        S: Data<Elem = T>,
177        D: Dimension,
178        Self: Sync,
179        T: Send + Sync,
180    {
181        let mut outdata = array_resized_axis(indata, self.len_orth(), axis);
182        self.to_ortho_inplace_par(indata, &mut outdata, axis);
183        outdata
184    }
185
186    par_apply_along_axis!(
187        /// ## Composite coefficients -> Orthogonal coefficients (Parallel)
188        to_ortho_inplace_par,
189        T,
190        T,
191        to_ortho_slice,
192        len_spec,
193        len_orth,
194        "to_ortho"
195    );
196
197    /// ## Composite coefficients -> Orthogonal coefficients  (Parallel)
198    fn from_ortho_par<S, D>(&self, indata: &ArrayBase<S, D>, axis: usize) -> Array<T, D>
199    where
200        S: Data<Elem = T>,
201        D: Dimension,
202        Self: Sync,
203        T: Send + Sync,
204    {
205        let mut outdata = array_resized_axis(indata, self.len_spec(), axis);
206        self.from_ortho_inplace_par(indata, &mut outdata, axis);
207        outdata
208    }
209
210    par_apply_along_axis!(
211        /// ## Composite coefficients -> Orthogonal coefficients (Parallel)
212        from_ortho_inplace_par,
213        T,
214        T,
215        from_ortho_slice,
216        len_orth,
217        len_spec,
218        "from_ortho"
219    );
220}
221
222/// The associated types *Physical* and *Spectral* refer
223/// to the scalar types in the physical and spectral space.
224/// For example, Fourier transforms from real-to-complex,
225/// while Chebyshev transforms from real-to-real.
226pub trait BaseTransform: BaseSize {
227    /// ## Scalar type in physical space
228    type Physical;
229
230    /// ## Scalar type in spectral space
231    type Spectral;
232
233    /// ## Physical values -> Spectral coefficients
234    ///
235    /// Transforms a one-dimensional slice.
236    fn forward_slice(&self, indata: &[Self::Physical], outdata: &mut [Self::Spectral]);
237
238    /// ## Spectral coefficients -> Physical values
239    ///
240    /// Transforms a one-dimensional slice.
241    fn backward_slice(&self, indata: &[Self::Spectral], outdata: &mut [Self::Physical]);
242
243    /// ## Physical values -> Spectral coefficients
244    fn forward<S, D>(&self, indata: &ArrayBase<S, D>, axis: usize) -> Array<Self::Spectral, D>
245    where
246        S: Data<Elem = Self::Physical>,
247        D: Dimension,
248        Self::Physical: Clone,
249        Self::Spectral: Zero + Clone + Copy,
250    {
251        let mut outdata = array_resized_axis(indata, self.len_spec(), axis);
252        self.forward_inplace(indata, &mut outdata, axis);
253        outdata
254    }
255
256    apply_along_axis!(
257        /// ## Physical values -> Spectral coefficients
258        forward_inplace,
259        Self::Physical,
260        Self::Spectral,
261        forward_slice,
262        len_phys,
263        len_spec,
264        "forward"
265    );
266
267    /// ## Spectral coefficients -> Physical values
268    fn backward<S, D>(&self, indata: &ArrayBase<S, D>, axis: usize) -> Array<Self::Physical, D>
269    where
270        S: Data<Elem = Self::Spectral>,
271        D: Dimension,
272        Self::Spectral: Clone,
273        Self::Physical: Zero + Clone + Copy,
274    {
275        let mut outdata = array_resized_axis(indata, self.len_phys(), axis);
276        self.backward_inplace(indata, &mut outdata, axis);
277        outdata
278    }
279
280    apply_along_axis!(
281        /// ## Spectral coefficients -> Physical values
282        backward_inplace,
283        Self::Spectral,
284        Self::Physical,
285        backward_slice,
286        len_spec,
287        len_phys,
288        "backward"
289    );
290
291    /// ## Physical values -> Spectral coefficients (Parallel)
292    fn forward_par<S, D>(&self, indata: &ArrayBase<S, D>, axis: usize) -> Array<Self::Spectral, D>
293    where
294        S: Data<Elem = Self::Physical>,
295        D: Dimension,
296        Self::Physical: Clone + Send + Sync,
297        Self::Spectral: Zero + Clone + Copy + Send + Sync,
298        Self: Sync,
299    {
300        let mut outdata = array_resized_axis(indata, self.len_spec(), axis);
301        self.forward_inplace_par(indata, &mut outdata, axis);
302        outdata
303    }
304
305    par_apply_along_axis!(
306        /// ## Physical values -> Spectral coefficients (Parallel)
307        forward_inplace_par,
308        Self::Physical,
309        Self::Spectral,
310        forward_slice,
311        len_phys,
312        len_spec,
313        "forward"
314    );
315
316    /// ## Spectral coefficients -> Physical values (Parallel)
317    fn backward_par<S, D>(&self, indata: &ArrayBase<S, D>, axis: usize) -> Array<Self::Physical, D>
318    where
319        S: Data<Elem = Self::Spectral>,
320        D: Dimension,
321        Self::Spectral: Clone + Send + Sync,
322        Self::Physical: Zero + Clone + Copy + Send + Sync,
323        Self: Sync,
324    {
325        let mut outdata = array_resized_axis(indata, self.len_phys(), axis);
326        self.backward_inplace_par(indata, &mut outdata, axis);
327        outdata
328    }
329
330    par_apply_along_axis!(
331        /// ## Spectral coefficients -> Physical values (Parallel)
332        backward_inplace_par,
333        Self::Spectral,
334        Self::Physical,
335        backward_slice,
336        len_spec,
337        len_phys,
338        "backward"
339    );
340}
341
342/// # Gradient
343pub trait BaseGradient<T>: BaseSize
344where
345    T: Zero + Copy,
346{
347    /// Differentiate in spectral space
348    fn gradient_slice(&self, indata: &[T], outdata: &mut [T], n_times: usize);
349
350    /// Differentiate in spectral space
351    fn gradient<S, D>(&self, indata: &ArrayBase<S, D>, n_times: usize, axis: usize) -> Array<T, D>
352    where
353        S: Data<Elem = T>,
354        D: Dimension,
355    {
356        let mut outdata = array_resized_axis(indata, self.len_orth(), axis);
357        self.gradient_inplace(indata, &mut outdata, n_times, axis);
358        outdata
359    }
360
361    /// Differentiate in spectral space
362    fn gradient_inplace<S1, S2, D>(
363        &self,
364        indata: &ArrayBase<S1, D>,
365        outdata: &mut ArrayBase<S2, D>,
366        n_times: usize,
367        axis: usize,
368    ) where
369        S1: Data<Elem = T>,
370        S2: Data<Elem = T> + DataMut,
371        D: Dimension,
372    {
373        assert!(indata.is_standard_layout());
374        assert!(outdata.is_standard_layout());
375        check_array_axis(indata, self.len_spec(), axis, "gradient");
376        check_array_axis(outdata, self.len_orth(), axis, "gradient");
377
378        let outer_axis = outdata.ndim() - 1;
379        if axis == outer_axis {
380            // Data is contiguous in memory
381            Zip::from(indata.rows())
382                .and(outdata.rows_mut())
383                .for_each(|x, mut y| {
384                    self.gradient_slice(x.as_slice().unwrap(), y.as_slice_mut().unwrap(), n_times);
385                });
386        } else {
387            // Data is *not* contiguous in memory.
388            let mut scratch: Vec<T> = vec![T::zero(); outdata.shape()[axis]];
389            Zip::from(indata.lanes(Axis(axis)))
390                .and(outdata.lanes_mut(Axis(axis)))
391                .for_each(|x, mut y| {
392                    self.gradient_slice(&x.to_vec(), &mut scratch, n_times);
393                    for (yi, si) in y.iter_mut().zip(scratch.iter()) {
394                        *yi = *si;
395                    }
396                });
397        }
398    }
399
400    /// Differentiate in spectral space
401    fn gradient_par<S, D>(
402        &self,
403        indata: &ArrayBase<S, D>,
404        n_times: usize,
405        axis: usize,
406    ) -> Array<T, D>
407    where
408        S: Data<Elem = T>,
409        D: Dimension,
410        T: Send + Sync,
411        Self: Sync,
412    {
413        let mut outdata = array_resized_axis(indata, self.len_orth(), axis);
414        self.gradient_inplace_par(indata, &mut outdata, n_times, axis);
415        outdata
416    }
417
418    /// Differentiate in spectral space
419    fn gradient_inplace_par<S1, S2, D>(
420        &self,
421        indata: &ArrayBase<S1, D>,
422        outdata: &mut ArrayBase<S2, D>,
423        n_times: usize,
424        axis: usize,
425    ) where
426        S1: Data<Elem = T>,
427        S2: Data<Elem = T> + DataMut,
428        D: Dimension,
429        T: Send + Sync,
430        Self: Sync,
431    {
432        assert!(indata.is_standard_layout());
433        assert!(outdata.is_standard_layout());
434        check_array_axis(indata, self.len_spec(), axis, "gradient");
435        check_array_axis(outdata, self.len_orth(), axis, "gradient");
436
437        let outer_axis = outdata.ndim() - 1;
438        if axis == outer_axis {
439            // Data is contiguous in memory
440            Zip::from(indata.rows())
441                .and(outdata.rows_mut())
442                .par_for_each(|x, mut y| {
443                    self.gradient_slice(x.as_slice().unwrap(), y.as_slice_mut().unwrap(), n_times);
444                });
445        } else {
446            // Data is *not* contiguous in memory.
447            let scratch_len = outdata.shape()[axis];
448            Zip::from(indata.lanes(Axis(axis)))
449                .and(outdata.lanes_mut(Axis(axis)))
450                .par_for_each(|x, mut y| {
451                    let mut scratch: Vec<T> = vec![T::zero(); scratch_len];
452                    self.gradient_slice(&x.to_vec(), &mut scratch, n_times);
453                    for (yi, si) in y.iter_mut().zip(scratch.iter()) {
454                        *yi = *si;
455                    }
456                });
457        }
458    }
459}
460
461// /// Applys function along lanes of *axis*
462// fn apply_along_axis<F, S1, S2, D, T1, T2>(
463//     &self,
464//     indata: &ArrayBase<S1, D>,
465//     outdata: &mut ArrayBase<S2, D>,
466//     axis: usize,
467//     function: &F,
468// ) where
469//     S1: Data<Elem = T1>,
470//     S2: Data<Elem = T2> + DataMut,
471//     D: Dimension,
472//     F: Fn(&Self, &[T1], &mut [T2]),
473//     T1: Clone,
474//     T2: Clone + Zero + Copy,
475// {
476//     let outer_axis = indata.ndim() - 1;
477//     if axis == outer_axis {
478//         // Data is contiguous in memory
479//         Zip::from(indata.lanes(Axis(axis)))
480//             .and(outdata.lanes_mut(Axis(axis)))
481//             .for_each(|x, mut y| {
482//                 function(self, x.as_slice().unwrap(), y.as_slice_mut().unwrap());
483//             });
484//     } else {
485//         // Data is *not* contiguous in memory.
486//         let mut scratch: Vec<T2> = vec![T2::zero(); outdata.shape()[axis]];
487//         Zip::from(indata.lanes(Axis(axis)))
488//             .and(outdata.lanes_mut(Axis(axis)))
489//             .for_each(|x, mut y| {
490//                 function(self, &x.to_vec(), &mut scratch);
491//                 for (yi, si) in y.iter_mut().zip(scratch.iter()) {
492//                     *yi = *si;
493//                 }
494//             });
495//     }
496// }