ndarray_conv/conv_fft/processor/mod.rs

1use std::marker::PhantomData;
2
3use ndarray::{Array, ArrayBase, DataMut, Dim, IntoDimension, Ix, RemoveAxis};
4use num::Complex;
5use rustfft::FftNum;
6
7pub mod complex;
8pub mod real;
9
10pub trait ConvFftNum: FftNum {}
11
12macro_rules! impl_conv_fft_num {
13    ($($t:ty),*) => {
14        $(impl ConvFftNum for $t {})*
15    };
16}
17
18impl_conv_fft_num!(i8, i16, i32, i64, i128, isize, f32, f64);
19
20pub fn get<T: FftNum, InElem: GetProcessor<T, InElem>>() -> impl Processor<T, InElem> {
21    InElem::get_processor()
22}
23
24pub trait Processor<T: FftNum, InElem: GetProcessor<T, InElem>> {
25    fn forward<S: DataMut<Elem = InElem>, const N: usize>(
26        &mut self,
27        input: &mut ArrayBase<S, Dim<[Ix; N]>>,
28    ) -> Array<Complex<T>, Dim<[Ix; N]>>
29    where
30        Dim<[Ix; N]>: RemoveAxis,
31        [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>;
32
33    fn backward<const N: usize>(
34        &mut self,
35        input: &mut Array<Complex<T>, Dim<[Ix; N]>>,
36    ) -> Array<InElem, Dim<[Ix; N]>>
37    where
38        Dim<[Ix; N]>: RemoveAxis,
39        [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>;
40}
41
42pub trait GetProcessor<T: FftNum, InElem>
43where
44    InElem: GetProcessor<T, InElem>,
45{
46    fn get_processor() -> impl Processor<T, InElem>;
47}
48
49impl<T: ConvFftNum> GetProcessor<T, T> for T {
50    fn get_processor() -> impl Processor<T, T> {
51        real::Processor::<T>::default()
52    }
53}
54
55impl<T: FftNum> GetProcessor<T, Complex<T>> for Complex<T> {
56    fn get_processor() -> impl Processor<T, Complex<T>> {
57        complex::Processor::<T>::default()
58    }
59}