use crate::enums::{BaseKind, TransformKind};
use crate::utils::{array_resized_axis, check_array_axis};
use ndarray::{Array, Array2, ArrayBase, Axis, Data, DataMut, Dimension, Zip};
use num_traits::identities::Zero;
/// Umbrella trait bundling every capability a base provides.
///
/// `Base<T>` declares no items of its own. It only requires all of the component traits
/// (size queries, metadata, transforms, orthogonal-space conversion, gradients and the
/// matrix operators). A blanket implementation in this module makes every type that
/// satisfies all of them a `Base<T>` automatically.
pub trait Base<T>:
    BaseSize
    + BaseElements
    + BaseTransform
    + BaseFromOrtho<T>
    + BaseGradient<T>
    + BaseMatOpDiffmat
    + BaseMatOpStencil
    + BaseMatOpLaplacian
where
    T: Zero + Copy,
{
}
/// Blanket implementation that makes [`Base`] behave as a trait alias.
///
/// Any type implementing every component trait (for the same scalar `T`) is a `Base<T>`.
impl<B, T> Base<T> for B
where
    B: BaseSize
        + BaseElements
        + BaseTransform
        + BaseFromOrtho<T>
        + BaseGradient<T>
        + BaseMatOpDiffmat
        + BaseMatOpStencil
        + BaseMatOpLaplacian,
    T: Zero + Copy,
{
}
/// Lengths of a base's representations along the axis it acts on.
pub trait BaseSize {
    /// Length in spectral space.
    ///
    /// [`BaseTransform::forward`] and [`BaseFromOrtho::from_ortho`] size their output
    /// along `axis` with this value.
    fn len_spec(&self) -> usize;

    /// Length in physical space.
    ///
    /// [`BaseTransform::backward`] sizes its output along `axis` with this value.
    fn len_phys(&self) -> usize;

    /// Length in the orthogonal space.
    ///
    /// [`BaseFromOrtho::to_ortho`] and [`BaseGradient::gradient`] size their output along
    /// `axis` with this value.
    fn len_orth(&self) -> usize;
}
/// Descriptive metadata about a base.
pub trait BaseElements {
    /// Element type of the vector returned by [`BaseElements::coords`].
    type RealNum;

    /// Returns this base's coordinate values.
    ///
    /// NOTE(review): presumably the physical-space grid points. Confirm with implementors.
    fn coords(&self) -> Vec<Self::RealNum>;

    /// Returns the [`BaseKind`] tag identifying this base.
    fn base_kind(&self) -> BaseKind;

    /// Returns the [`TransformKind`] tag identifying this base's transform.
    fn transform_kind(&self) -> TransformKind;
}
/// Differentiation expressed as explicit matrices.
pub trait BaseMatOpDiffmat {
    /// Scalar type of the returned matrices.
    type NumType;

    /// Returns a pair of matrices associated with the pseudo-inverse of
    /// [`BaseMatOpDiffmat::diffmat`] for the given derivative order.
    ///
    /// NOTE(review): the meaning of each element of the pair is implementor-defined and is
    /// not visible here. Confirm before relying on it.
    fn diffmat_pinv(&self, _deriv: usize) -> (Array2<Self::NumType>, Array2<Self::NumType>);

    /// Returns the differentiation matrix for derivative order `_deriv`.
    fn diffmat(&self, _deriv: usize) -> Array2<Self::NumType>;
}
/// Stencil operator of a base, expressed as explicit matrices.
pub trait BaseMatOpStencil {
    /// Scalar type of the returned matrices.
    type NumType;

    /// Returns the matrix paired with [`BaseMatOpStencil::stencil`].
    ///
    /// NOTE(review): presumably its (pseudo-)inverse. Confirm with implementors.
    fn stencil_inv(&self) -> Array2<Self::NumType>;

    /// Returns the stencil matrix.
    fn stencil(&self) -> Array2<Self::NumType>;
}
/// Laplacian operator of a base, expressed as explicit matrices.
pub trait BaseMatOpLaplacian {
    /// Scalar type of the returned matrices.
    type NumType;

    /// Returns a pair of matrices associated with the pseudo-inverse of
    /// [`BaseMatOpLaplacian::laplacian`].
    ///
    /// NOTE(review): the meaning of each element of the pair is implementor-defined and is
    /// not visible here. Confirm before relying on it.
    fn laplacian_pinv(&self) -> (Array2<Self::NumType>, Array2<Self::NumType>);

    /// Returns the Laplacian matrix.
    fn laplacian(&self) -> Array2<Self::NumType>;
}
/// Conversion between a base's own coefficients and the orthogonal space.
///
/// Implementors supply the two 1-D lane kernels. The provided methods lift them to
/// n-dimensional arrays along a chosen `axis`, via the `apply_along_axis!` and
/// `par_apply_along_axis!` macros.
pub trait BaseFromOrtho<T>: BaseSize
where
    T: Zero + Copy,
{
    /// Converts one contiguous lane into orthogonal-space coefficients.
    fn to_ortho_slice(&self, indata: &[T], outdata: &mut [T]);

    /// Converts one contiguous lane of orthogonal-space coefficients back.
    fn from_ortho_slice(&self, indata: &[T], outdata: &mut [T]);

    /// Applies [`BaseFromOrtho::to_ortho_slice`] along `axis` and returns a new array.
    ///
    /// The output has length [`BaseSize::len_orth`] along `axis`.
    fn to_ortho<S, D>(&self, indata: &ArrayBase<S, D>, axis: usize) -> Array<T, D>
    where
        S: Data<Elem = T>,
        D: Dimension,
    {
        let n_orth = self.len_orth();
        let mut ortho = array_resized_axis(indata, n_orth, axis);
        self.to_ortho_inplace(indata, &mut ortho, axis);
        ortho
    }

    // Generates `to_ortho_inplace`: applies `to_ortho_slice` lane by lane, from length
    // `len_spec` to length `len_orth`, into a caller-provided array.
    apply_along_axis!(
        to_ortho_inplace,
        T,
        T,
        to_ortho_slice,
        len_spec,
        len_orth,
        "to_ortho"
    );

    /// Applies [`BaseFromOrtho::from_ortho_slice`] along `axis` and returns a new array.
    ///
    /// The output has length [`BaseSize::len_spec`] along `axis`.
    fn from_ortho<S, D>(&self, indata: &ArrayBase<S, D>, axis: usize) -> Array<T, D>
    where
        S: Data<Elem = T>,
        D: Dimension,
    {
        let n_spec = self.len_spec();
        let mut spectral = array_resized_axis(indata, n_spec, axis);
        self.from_ortho_inplace(indata, &mut spectral, axis);
        spectral
    }

    // Generates `from_ortho_inplace`: applies `from_ortho_slice` lane by lane, from length
    // `len_orth` to length `len_spec`, into a caller-provided array.
    apply_along_axis!(
        from_ortho_inplace,
        T,
        T,
        from_ortho_slice,
        len_orth,
        len_spec,
        "from_ortho"
    );

    /// Variant of [`BaseFromOrtho::to_ortho`] backed by `to_ortho_inplace_par`.
    ///
    /// It requires `Self: Sync` and `T: Send + Sync`, presumably so that lanes can be
    /// processed in parallel.
    fn to_ortho_par<S, D>(&self, indata: &ArrayBase<S, D>, axis: usize) -> Array<T, D>
    where
        S: Data<Elem = T>,
        D: Dimension,
        Self: Sync,
        T: Send + Sync,
    {
        let n_orth = self.len_orth();
        let mut ortho = array_resized_axis(indata, n_orth, axis);
        self.to_ortho_inplace_par(indata, &mut ortho, axis);
        ortho
    }

    // Generates `to_ortho_inplace_par`, the `par_apply_along_axis!` counterpart of
    // `to_ortho_inplace`.
    par_apply_along_axis!(
        to_ortho_inplace_par,
        T,
        T,
        to_ortho_slice,
        len_spec,
        len_orth,
        "to_ortho"
    );

    /// Variant of [`BaseFromOrtho::from_ortho`] backed by `from_ortho_inplace_par`.
    ///
    /// It requires `Self: Sync` and `T: Send + Sync`, presumably so that lanes can be
    /// processed in parallel.
    fn from_ortho_par<S, D>(&self, indata: &ArrayBase<S, D>, axis: usize) -> Array<T, D>
    where
        S: Data<Elem = T>,
        D: Dimension,
        Self: Sync,
        T: Send + Sync,
    {
        let n_spec = self.len_spec();
        let mut spectral = array_resized_axis(indata, n_spec, axis);
        self.from_ortho_inplace_par(indata, &mut spectral, axis);
        spectral
    }

    // Generates `from_ortho_inplace_par`, the `par_apply_along_axis!` counterpart of
    // `from_ortho_inplace`.
    par_apply_along_axis!(
        from_ortho_inplace_par,
        T,
        T,
        from_ortho_slice,
        len_orth,
        len_spec,
        "from_ortho"
    );
}
/// Transform between physical-space values and spectral-space coefficients.
///
/// Implementors supply the 1-D lane kernels [`BaseTransform::forward_slice`] and
/// [`BaseTransform::backward_slice`]. The provided methods lift them to n-dimensional
/// arrays along a chosen `axis`, via the `apply_along_axis!` and `par_apply_along_axis!`
/// macros.
pub trait BaseTransform: BaseSize {
    /// Element type of physical-space data.
    type Physical;
    /// Element type of spectral-space coefficients.
    type Spectral;

    /// Transforms one contiguous lane of physical values into spectral coefficients.
    fn forward_slice(&self, indata: &[Self::Physical], outdata: &mut [Self::Spectral]);

    /// Transforms one contiguous lane of spectral coefficients into physical values.
    fn backward_slice(&self, indata: &[Self::Spectral], outdata: &mut [Self::Physical]);

    /// Applies [`BaseTransform::forward_slice`] along `axis` and returns a new array.
    ///
    /// The output has length [`BaseSize::len_spec`] along `axis`.
    fn forward<S, D>(&self, indata: &ArrayBase<S, D>, axis: usize) -> Array<Self::Spectral, D>
    where
        S: Data<Elem = Self::Physical>,
        D: Dimension,
        Self::Physical: Clone,
        Self::Spectral: Zero + Clone + Copy,
    {
        let n_spec = self.len_spec();
        let mut spectral = array_resized_axis(indata, n_spec, axis);
        self.forward_inplace(indata, &mut spectral, axis);
        spectral
    }

    // Generates `forward_inplace`: applies `forward_slice` lane by lane, from length
    // `len_phys` to length `len_spec`, into a caller-provided array.
    apply_along_axis!(
        forward_inplace,
        Self::Physical,
        Self::Spectral,
        forward_slice,
        len_phys,
        len_spec,
        "forward"
    );

    /// Applies [`BaseTransform::backward_slice`] along `axis` and returns a new array.
    ///
    /// The output has length [`BaseSize::len_phys`] along `axis`.
    fn backward<S, D>(&self, indata: &ArrayBase<S, D>, axis: usize) -> Array<Self::Physical, D>
    where
        S: Data<Elem = Self::Spectral>,
        D: Dimension,
        Self::Spectral: Clone,
        Self::Physical: Zero + Clone + Copy,
    {
        let n_phys = self.len_phys();
        let mut physical = array_resized_axis(indata, n_phys, axis);
        self.backward_inplace(indata, &mut physical, axis);
        physical
    }

    // Generates `backward_inplace`: applies `backward_slice` lane by lane, from length
    // `len_spec` to length `len_phys`, into a caller-provided array.
    apply_along_axis!(
        backward_inplace,
        Self::Spectral,
        Self::Physical,
        backward_slice,
        len_spec,
        len_phys,
        "backward"
    );

    /// Variant of [`BaseTransform::forward`] backed by `forward_inplace_par`.
    ///
    /// It requires `Send + Sync` elements and `Self: Sync`, presumably so that lanes can be
    /// processed in parallel.
    fn forward_par<S, D>(&self, indata: &ArrayBase<S, D>, axis: usize) -> Array<Self::Spectral, D>
    where
        S: Data<Elem = Self::Physical>,
        D: Dimension,
        Self::Physical: Clone + Send + Sync,
        Self::Spectral: Zero + Clone + Copy + Send + Sync,
        Self: Sync,
    {
        let n_spec = self.len_spec();
        let mut spectral = array_resized_axis(indata, n_spec, axis);
        self.forward_inplace_par(indata, &mut spectral, axis);
        spectral
    }

    // Generates `forward_inplace_par`, the `par_apply_along_axis!` counterpart of
    // `forward_inplace`.
    par_apply_along_axis!(
        forward_inplace_par,
        Self::Physical,
        Self::Spectral,
        forward_slice,
        len_phys,
        len_spec,
        "forward"
    );

    /// Variant of [`BaseTransform::backward`] backed by `backward_inplace_par`.
    ///
    /// It requires `Send + Sync` elements and `Self: Sync`, presumably so that lanes can be
    /// processed in parallel.
    fn backward_par<S, D>(&self, indata: &ArrayBase<S, D>, axis: usize) -> Array<Self::Physical, D>
    where
        S: Data<Elem = Self::Spectral>,
        D: Dimension,
        Self::Spectral: Clone + Send + Sync,
        Self::Physical: Zero + Clone + Copy + Send + Sync,
        Self: Sync,
    {
        let n_phys = self.len_phys();
        let mut physical = array_resized_axis(indata, n_phys, axis);
        self.backward_inplace_par(indata, &mut physical, axis);
        physical
    }

    // Generates `backward_inplace_par`, the `par_apply_along_axis!` counterpart of
    // `backward_inplace`.
    par_apply_along_axis!(
        backward_inplace_par,
        Self::Spectral,
        Self::Physical,
        backward_slice,
        len_spec,
        len_phys,
        "backward"
    );
}
/// Gradient operator of a base, lifted from 1-D lanes to n-dimensional arrays.
///
/// Implementors provide [`BaseGradient::gradient_slice`]. The provided methods apply it to
/// every 1-D lane along a chosen `axis`. Input lanes are validated against
/// [`BaseSize::len_spec`] and output lanes against [`BaseSize::len_orth`].
pub trait BaseGradient<T>: BaseSize
where
    T: Zero + Copy,
{
    /// Applies the gradient operator to a single contiguous lane.
    ///
    /// `n_times` is forwarded unchanged from the array methods.
    /// NOTE(review): presumably the derivative order. Confirm with implementors.
    fn gradient_slice(&self, indata: &[T], outdata: &mut [T], n_times: usize);

    /// Applies [`BaseGradient::gradient_slice`] along `axis` and returns a new array.
    ///
    /// The output is allocated with length [`BaseSize::len_orth`] along `axis`.
    ///
    /// # Panics
    ///
    /// Under the same conditions as [`BaseGradient::gradient_inplace`].
    fn gradient<S, D>(&self, indata: &ArrayBase<S, D>, n_times: usize, axis: usize) -> Array<T, D>
    where
        S: Data<Elem = T>,
        D: Dimension,
    {
        let mut outdata = array_resized_axis(indata, self.len_orth(), axis);
        self.gradient_inplace(indata, &mut outdata, n_times, axis);
        outdata
    }

    /// Applies [`BaseGradient::gradient_slice`] to every lane of `indata` along `axis`,
    /// writing into the matching lane of `outdata`.
    ///
    /// # Panics
    ///
    /// - If `indata` or `outdata` is not in standard layout.
    /// - From `check_array_axis`, presumably if `indata` does not have length `len_spec()`
    ///   along `axis` or `outdata` does not have length `len_orth()` along `axis`.
    fn gradient_inplace<S1, S2, D>(
        &self,
        indata: &ArrayBase<S1, D>,
        outdata: &mut ArrayBase<S2, D>,
        n_times: usize,
        axis: usize,
    ) where
        S1: Data<Elem = T>,
        S2: Data<Elem = T> + DataMut,
        D: Dimension,
    {
        assert!(
            indata.is_standard_layout(),
            "gradient: indata must be in standard layout"
        );
        assert!(
            outdata.is_standard_layout(),
            "gradient: outdata must be in standard layout"
        );
        check_array_axis(indata, self.len_spec(), axis, "gradient");
        check_array_axis(outdata, self.len_orth(), axis, "gradient");
        // `axis + 1 == ndim` is equivalent to `axis == ndim - 1` but cannot underflow.
        if axis + 1 == outdata.ndim() {
            // Last axis of a standard-layout array: every row is contiguous, so the lanes
            // can be handed to the kernel directly without copying.
            Zip::from(indata.rows())
                .and(outdata.rows_mut())
                .for_each(|x, mut y| {
                    let src = x
                        .as_slice()
                        .expect("rows of a standard-layout array are contiguous");
                    let dst = y
                        .as_slice_mut()
                        .expect("rows of a standard-layout array are contiguous");
                    self.gradient_slice(src, dst, n_times);
                });
        } else {
            // Lanes along any other axis are strided. Each lane is gathered into a
            // contiguous buffer, transformed, and scattered back. Both buffers are
            // allocated once and reused for every lane.
            let mut buf_in: Vec<T> = vec![T::zero(); indata.shape()[axis]];
            let mut buf_out: Vec<T> = vec![T::zero(); outdata.shape()[axis]];
            Zip::from(indata.lanes(Axis(axis)))
                .and(outdata.lanes_mut(Axis(axis)))
                .for_each(|x, mut y| {
                    for (b, &v) in buf_in.iter_mut().zip(x.iter()) {
                        *b = v;
                    }
                    self.gradient_slice(&buf_in, &mut buf_out, n_times);
                    for (yi, &v) in y.iter_mut().zip(buf_out.iter()) {
                        *yi = v;
                    }
                });
        }
    }

    /// Variant of [`BaseGradient::gradient`] that dispatches through
    /// [`BaseGradient::gradient_inplace_par`].
    ///
    /// # Panics
    ///
    /// Under the same conditions as [`BaseGradient::gradient_inplace_par`].
    fn gradient_par<S, D>(
        &self,
        indata: &ArrayBase<S, D>,
        n_times: usize,
        axis: usize,
    ) -> Array<T, D>
    where
        S: Data<Elem = T>,
        D: Dimension,
        T: Send + Sync,
        Self: Sync,
    {
        let mut outdata = array_resized_axis(indata, self.len_orth(), axis);
        self.gradient_inplace_par(indata, &mut outdata, n_times, axis);
        outdata
    }

    /// Variant of [`BaseGradient::gradient_inplace`] that processes lanes with
    /// `Zip::par_for_each`.
    ///
    /// # Panics
    ///
    /// Under the same conditions as [`BaseGradient::gradient_inplace`].
    fn gradient_inplace_par<S1, S2, D>(
        &self,
        indata: &ArrayBase<S1, D>,
        outdata: &mut ArrayBase<S2, D>,
        n_times: usize,
        axis: usize,
    ) where
        S1: Data<Elem = T>,
        S2: Data<Elem = T> + DataMut,
        D: Dimension,
        T: Send + Sync,
        Self: Sync,
    {
        assert!(
            indata.is_standard_layout(),
            "gradient: indata must be in standard layout"
        );
        assert!(
            outdata.is_standard_layout(),
            "gradient: outdata must be in standard layout"
        );
        check_array_axis(indata, self.len_spec(), axis, "gradient");
        check_array_axis(outdata, self.len_orth(), axis, "gradient");
        // `axis + 1 == ndim` is equivalent to `axis == ndim - 1` but cannot underflow.
        if axis + 1 == outdata.ndim() {
            // Last axis of a standard-layout array: rows are contiguous, so no copying.
            Zip::from(indata.rows())
                .and(outdata.rows_mut())
                .par_for_each(|x, mut y| {
                    let src = x
                        .as_slice()
                        .expect("rows of a standard-layout array are contiguous");
                    let dst = y
                        .as_slice_mut()
                        .expect("rows of a standard-layout array are contiguous");
                    self.gradient_slice(src, dst, n_times);
                });
        } else {
            // Strided lanes: gather, transform, scatter. The `par_for_each` closure cannot
            // mutate shared buffers, so each lane allocates its own scratch space.
            let scratch_len = outdata.shape()[axis];
            Zip::from(indata.lanes(Axis(axis)))
                .and(outdata.lanes_mut(Axis(axis)))
                .par_for_each(|x, mut y| {
                    let mut scratch: Vec<T> = vec![T::zero(); scratch_len];
                    self.gradient_slice(&x.to_vec(), &mut scratch, n_times);
                    for (yi, &v) in y.iter_mut().zip(scratch.iter()) {
                        *yi = v;
                    }
                });
        }
    }
}