ndarray_conv/conv_fft/processor/
mod.rs1use std::marker::PhantomData;
2
3use ndarray::{Array, ArrayBase, DataMut, Dim, IntoDimension, Ix, RemoveAxis};
4use num::Complex;
5use rustfft::FftNum;
6
7pub mod complex;
8pub mod real;
9
10pub trait ConvFftNum: FftNum {}
11
12macro_rules! impl_conv_fft_num {
13 ($($t:ty),*) => {
14 $(impl ConvFftNum for $t {})*
15 };
16}
17
18impl_conv_fft_num!(i8, i16, i32, i64, i128, isize, f32, f64);
19
20pub fn get<T: FftNum, InElem: GetProcessor<T, InElem>>() -> impl Processor<T, InElem> {
21 InElem::get_processor()
22}
23
24pub trait Processor<T: FftNum, InElem: GetProcessor<T, InElem>> {
25 fn forward<S: DataMut<Elem = InElem>, const N: usize>(
26 &mut self,
27 input: &mut ArrayBase<S, Dim<[Ix; N]>>,
28 ) -> Array<Complex<T>, Dim<[Ix; N]>>
29 where
30 Dim<[Ix; N]>: RemoveAxis,
31 [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>;
32
33 fn backward<const N: usize>(
34 &mut self,
35 input: &mut Array<Complex<T>, Dim<[Ix; N]>>,
36 ) -> Array<InElem, Dim<[Ix; N]>>
37 where
38 Dim<[Ix; N]>: RemoveAxis,
39 [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>;
40}
41
42pub trait GetProcessor<T: FftNum, InElem>
43where
44 InElem: GetProcessor<T, InElem>,
45{
46 fn get_processor() -> impl Processor<T, InElem>;
47}
48
49impl<T: ConvFftNum> GetProcessor<T, T> for T {
50 fn get_processor() -> impl Processor<T, T> {
51 real::Processor::<T>::default()
52 }
53}
54
55impl<T: FftNum> GetProcessor<T, Complex<T>> for Complex<T> {
56 fn get_processor() -> impl Processor<T, Complex<T>> {
57 complex::Processor::<T>::default()
58 }
59}