1use crate::enums::{BaseKind, TransformKind};
3use crate::utils::{array_resized_axis, check_array_axis};
4use ndarray::{Array, Array2, ArrayBase, Axis, Data, DataMut, Dimension, Zip};
5use num_traits::identities::Zero;
6
7pub trait Base<T>:
9 BaseSize
10 + BaseMatOpLaplacian
11 + BaseMatOpDiffmat
12 + BaseMatOpStencil
13 + BaseElements
14 + BaseGradient<T>
15 + BaseFromOrtho<T>
16 + BaseTransform
17where
18 T: Zero + Copy,
19{
20}
21
22impl<A, T> Base<T> for A
23where
24 T: Zero + Copy,
25 A: BaseSize
26 + BaseMatOpLaplacian
27 + BaseMatOpDiffmat
28 + BaseMatOpStencil
29 + BaseElements
30 + BaseGradient<T>
31 + BaseFromOrtho<T>
32 + BaseTransform,
33{
34}
35
/// Lane lengths of a base along the axis it acts on.
///
/// The provided methods of the transform, ortho-conversion and gradient traits
/// in this module use these lengths to size freshly allocated output arrays.
pub trait BaseSize {
    /// Length used for lanes of `BaseTransform::Physical` data
    /// (e.g. the output size of `BaseTransform::backward`).
    fn len_phys(&self) -> usize;

    /// Length used for lanes of spectral data
    /// (e.g. the output size of `BaseTransform::forward`).
    fn len_spec(&self) -> usize;

    /// Length used for lanes of "ortho" data
    /// (e.g. the output size of `BaseFromOrtho::to_ortho` and `BaseGradient::gradient`).
    fn len_orth(&self) -> usize;
}
47
48pub trait BaseElements {
50 type RealNum;
52
53 fn base_kind(&self) -> BaseKind;
55
56 fn transform_kind(&self) -> TransformKind;
58
59 fn coords(&self) -> Vec<Self::RealNum>;
61}
62
63pub trait BaseMatOpDiffmat {
65 type NumType;
67
68 fn diffmat(&self, _deriv: usize) -> Array2<Self::NumType>;
72
73 fn diffmat_pinv(&self, _deriv: usize) -> (Array2<Self::NumType>, Array2<Self::NumType>);
85}
86
87pub trait BaseMatOpStencil {
89 type NumType;
91
92 fn stencil(&self) -> Array2<Self::NumType>;
94
95 fn stencil_inv(&self) -> Array2<Self::NumType>;
97}
98
99pub trait BaseMatOpLaplacian {
101 type NumType;
103
104 fn laplacian(&self) -> Array2<Self::NumType>;
106
107 fn laplacian_pinv(&self) -> (Array2<Self::NumType>, Array2<Self::NumType>);
116}
117
118pub trait BaseFromOrtho<T>: BaseSize
120where
121 T: Zero + Copy,
122{
123 fn to_ortho_slice(&self, indata: &[T], outdata: &mut [T]);
125
126 fn from_ortho_slice(&self, indata: &[T], outdata: &mut [T]);
128
129 fn to_ortho<S, D>(&self, indata: &ArrayBase<S, D>, axis: usize) -> Array<T, D>
131 where
132 S: Data<Elem = T>,
133 D: Dimension,
134 {
135 let mut outdata = array_resized_axis(indata, self.len_orth(), axis);
136 self.to_ortho_inplace(indata, &mut outdata, axis);
137 outdata
138 }
139
140 apply_along_axis!(
141 to_ortho_inplace,
143 T,
144 T,
145 to_ortho_slice,
146 len_spec,
147 len_orth,
148 "to_ortho"
149 );
150
151 fn from_ortho<S, D>(&self, indata: &ArrayBase<S, D>, axis: usize) -> Array<T, D>
153 where
154 S: Data<Elem = T>,
155 D: Dimension,
156 {
157 let mut outdata = array_resized_axis(indata, self.len_spec(), axis);
158 self.from_ortho_inplace(indata, &mut outdata, axis);
159 outdata
160 }
161
162 apply_along_axis!(
163 from_ortho_inplace,
165 T,
166 T,
167 from_ortho_slice,
168 len_orth,
169 len_spec,
170 "from_ortho"
171 );
172
173 fn to_ortho_par<S, D>(&self, indata: &ArrayBase<S, D>, axis: usize) -> Array<T, D>
175 where
176 S: Data<Elem = T>,
177 D: Dimension,
178 Self: Sync,
179 T: Send + Sync,
180 {
181 let mut outdata = array_resized_axis(indata, self.len_orth(), axis);
182 self.to_ortho_inplace_par(indata, &mut outdata, axis);
183 outdata
184 }
185
186 par_apply_along_axis!(
187 to_ortho_inplace_par,
189 T,
190 T,
191 to_ortho_slice,
192 len_spec,
193 len_orth,
194 "to_ortho"
195 );
196
197 fn from_ortho_par<S, D>(&self, indata: &ArrayBase<S, D>, axis: usize) -> Array<T, D>
199 where
200 S: Data<Elem = T>,
201 D: Dimension,
202 Self: Sync,
203 T: Send + Sync,
204 {
205 let mut outdata = array_resized_axis(indata, self.len_spec(), axis);
206 self.from_ortho_inplace_par(indata, &mut outdata, axis);
207 outdata
208 }
209
210 par_apply_along_axis!(
211 from_ortho_inplace_par,
213 T,
214 T,
215 from_ortho_slice,
216 len_orth,
217 len_spec,
218 "from_ortho"
219 );
220}
221
222pub trait BaseTransform: BaseSize {
227 type Physical;
229
230 type Spectral;
232
233 fn forward_slice(&self, indata: &[Self::Physical], outdata: &mut [Self::Spectral]);
237
238 fn backward_slice(&self, indata: &[Self::Spectral], outdata: &mut [Self::Physical]);
242
243 fn forward<S, D>(&self, indata: &ArrayBase<S, D>, axis: usize) -> Array<Self::Spectral, D>
245 where
246 S: Data<Elem = Self::Physical>,
247 D: Dimension,
248 Self::Physical: Clone,
249 Self::Spectral: Zero + Clone + Copy,
250 {
251 let mut outdata = array_resized_axis(indata, self.len_spec(), axis);
252 self.forward_inplace(indata, &mut outdata, axis);
253 outdata
254 }
255
256 apply_along_axis!(
257 forward_inplace,
259 Self::Physical,
260 Self::Spectral,
261 forward_slice,
262 len_phys,
263 len_spec,
264 "forward"
265 );
266
267 fn backward<S, D>(&self, indata: &ArrayBase<S, D>, axis: usize) -> Array<Self::Physical, D>
269 where
270 S: Data<Elem = Self::Spectral>,
271 D: Dimension,
272 Self::Spectral: Clone,
273 Self::Physical: Zero + Clone + Copy,
274 {
275 let mut outdata = array_resized_axis(indata, self.len_phys(), axis);
276 self.backward_inplace(indata, &mut outdata, axis);
277 outdata
278 }
279
280 apply_along_axis!(
281 backward_inplace,
283 Self::Spectral,
284 Self::Physical,
285 backward_slice,
286 len_spec,
287 len_phys,
288 "backward"
289 );
290
291 fn forward_par<S, D>(&self, indata: &ArrayBase<S, D>, axis: usize) -> Array<Self::Spectral, D>
293 where
294 S: Data<Elem = Self::Physical>,
295 D: Dimension,
296 Self::Physical: Clone + Send + Sync,
297 Self::Spectral: Zero + Clone + Copy + Send + Sync,
298 Self: Sync,
299 {
300 let mut outdata = array_resized_axis(indata, self.len_spec(), axis);
301 self.forward_inplace_par(indata, &mut outdata, axis);
302 outdata
303 }
304
305 par_apply_along_axis!(
306 forward_inplace_par,
308 Self::Physical,
309 Self::Spectral,
310 forward_slice,
311 len_phys,
312 len_spec,
313 "forward"
314 );
315
316 fn backward_par<S, D>(&self, indata: &ArrayBase<S, D>, axis: usize) -> Array<Self::Physical, D>
318 where
319 S: Data<Elem = Self::Spectral>,
320 D: Dimension,
321 Self::Spectral: Clone + Send + Sync,
322 Self::Physical: Zero + Clone + Copy + Send + Sync,
323 Self: Sync,
324 {
325 let mut outdata = array_resized_axis(indata, self.len_phys(), axis);
326 self.backward_inplace_par(indata, &mut outdata, axis);
327 outdata
328 }
329
330 par_apply_along_axis!(
331 backward_inplace_par,
333 Self::Spectral,
334 Self::Physical,
335 backward_slice,
336 len_spec,
337 len_phys,
338 "backward"
339 );
340}
341
342pub trait BaseGradient<T>: BaseSize
344where
345 T: Zero + Copy,
346{
347 fn gradient_slice(&self, indata: &[T], outdata: &mut [T], n_times: usize);
349
350 fn gradient<S, D>(&self, indata: &ArrayBase<S, D>, n_times: usize, axis: usize) -> Array<T, D>
352 where
353 S: Data<Elem = T>,
354 D: Dimension,
355 {
356 let mut outdata = array_resized_axis(indata, self.len_orth(), axis);
357 self.gradient_inplace(indata, &mut outdata, n_times, axis);
358 outdata
359 }
360
361 fn gradient_inplace<S1, S2, D>(
363 &self,
364 indata: &ArrayBase<S1, D>,
365 outdata: &mut ArrayBase<S2, D>,
366 n_times: usize,
367 axis: usize,
368 ) where
369 S1: Data<Elem = T>,
370 S2: Data<Elem = T> + DataMut,
371 D: Dimension,
372 {
373 assert!(indata.is_standard_layout());
374 assert!(outdata.is_standard_layout());
375 check_array_axis(indata, self.len_spec(), axis, "gradient");
376 check_array_axis(outdata, self.len_orth(), axis, "gradient");
377
378 let outer_axis = outdata.ndim() - 1;
379 if axis == outer_axis {
380 Zip::from(indata.rows())
382 .and(outdata.rows_mut())
383 .for_each(|x, mut y| {
384 self.gradient_slice(x.as_slice().unwrap(), y.as_slice_mut().unwrap(), n_times);
385 });
386 } else {
387 let mut scratch: Vec<T> = vec![T::zero(); outdata.shape()[axis]];
389 Zip::from(indata.lanes(Axis(axis)))
390 .and(outdata.lanes_mut(Axis(axis)))
391 .for_each(|x, mut y| {
392 self.gradient_slice(&x.to_vec(), &mut scratch, n_times);
393 for (yi, si) in y.iter_mut().zip(scratch.iter()) {
394 *yi = *si;
395 }
396 });
397 }
398 }
399
400 fn gradient_par<S, D>(
402 &self,
403 indata: &ArrayBase<S, D>,
404 n_times: usize,
405 axis: usize,
406 ) -> Array<T, D>
407 where
408 S: Data<Elem = T>,
409 D: Dimension,
410 T: Send + Sync,
411 Self: Sync,
412 {
413 let mut outdata = array_resized_axis(indata, self.len_orth(), axis);
414 self.gradient_inplace_par(indata, &mut outdata, n_times, axis);
415 outdata
416 }
417
418 fn gradient_inplace_par<S1, S2, D>(
420 &self,
421 indata: &ArrayBase<S1, D>,
422 outdata: &mut ArrayBase<S2, D>,
423 n_times: usize,
424 axis: usize,
425 ) where
426 S1: Data<Elem = T>,
427 S2: Data<Elem = T> + DataMut,
428 D: Dimension,
429 T: Send + Sync,
430 Self: Sync,
431 {
432 assert!(indata.is_standard_layout());
433 assert!(outdata.is_standard_layout());
434 check_array_axis(indata, self.len_spec(), axis, "gradient");
435 check_array_axis(outdata, self.len_orth(), axis, "gradient");
436
437 let outer_axis = outdata.ndim() - 1;
438 if axis == outer_axis {
439 Zip::from(indata.rows())
441 .and(outdata.rows_mut())
442 .par_for_each(|x, mut y| {
443 self.gradient_slice(x.as_slice().unwrap(), y.as_slice_mut().unwrap(), n_times);
444 });
445 } else {
446 let scratch_len = outdata.shape()[axis];
448 Zip::from(indata.lanes(Axis(axis)))
449 .and(outdata.lanes_mut(Axis(axis)))
450 .par_for_each(|x, mut y| {
451 let mut scratch: Vec<T> = vec![T::zero(); scratch_len];
452 self.gradient_slice(&x.to_vec(), &mut scratch, n_times);
453 for (yi, si) in y.iter_mut().zip(scratch.iter()) {
454 *yi = *si;
455 }
456 });
457 }
458 }
459}
460
461