use ndarray::{
Array, ArrayBase, Data, Dim, IntoDimension, Ix, RawData, RemoveAxis, SliceArg, SliceInfo,
SliceInfoElem,
};
use num::traits::NumAssign;
use rustfft::FftNum;
use crate::{dilation::IntoKernelWithDilation, ConvMode, PaddingMode};
mod good_size;
mod padding;
mod processor;
pub use processor::{get as get_processor, GetProcessor, MaybeSync, Processor};
/// Extension trait that adds FFT-based convolution methods to N-dimensional arrays.
///
/// Type parameters:
/// * `T` - scalar type the FFT processor works with (must be an `FftNum`).
/// * `InElem` - element type of the data, the kernel and the returned array.
/// * `S` / `SK` - storage types of the data array and of the kernel.
/// * `N` - number of dimensions.
///
/// # Errors
///
/// The implementation for `ArrayBase` in this module returns
/// `Error::DataShape` when the data or the kernel has a zero-length axis, and
/// `Error::MismatchShape` when, on some axis, the dilated kernel is larger
/// than the padded data.
pub trait ConvFFTExt<'a, T, InElem, S, SK, const N: usize>
where
T: NumAssign + Copy + FftNum,
InElem: processor::GetProcessor<T, InElem> + Copy + NumAssign,
S: RawData,
SK: RawData,
{
/// Convolves `self` with `kernel` through the FFT, using a processor
/// obtained from `InElem::get_processor()`.
///
/// `conv_mode` decides the per-axis padding and strides; `padding_mode`
/// is forwarded to the data padding step.
fn conv_fft(
&self,
kernel: impl IntoKernelWithDilation<'a, SK, N>,
conv_mode: ConvMode<N>,
padding_mode: PaddingMode<N, InElem>,
) -> Result<Array<InElem, Dim<[Ix; N]>>, crate::Error<N>>;
/// Same as [`conv_fft`](Self::conv_fft), but runs the transforms on the
/// caller-supplied `fft_processor` instead of creating a new one.
fn conv_fft_with_processor(
&self,
kernel: impl IntoKernelWithDilation<'a, SK, N>,
conv_mode: ConvMode<N>,
padding_mode: PaddingMode<N, InElem>,
fft_processor: &mut impl Processor<T, InElem>,
) -> Result<Array<InElem, Dim<[Ix; N]>>, crate::Error<N>>;
/// Variant of [`conv_fft`](Self::conv_fft) that requests the parallel code
/// path. Only available when the `rayon` feature is enabled.
#[cfg(feature = "rayon")]
fn conv_fft_par(
&self,
kernel: impl IntoKernelWithDilation<'a, SK, N>,
conv_mode: ConvMode<N>,
padding_mode: PaddingMode<N, InElem>,
) -> Result<Array<InElem, Dim<[Ix; N]>>, crate::Error<N>>;
}
/// Shared implementation behind every `ConvFFTExt` method.
///
/// The function
/// 1. rejects data or kernels that have a zero-length axis,
/// 2. computes the dilated kernel extent and the padded data extent per axis
///    and rejects kernels that exceed the padded data,
/// 3. builds padded buffers of size `fft_size` via `padding::data` and
///    `padding::kernel`,
/// 4. forward-transforms both buffers, multiplies them element-wise and
///    transforms the product back,
/// 5. slices the result from `dilated_extent - 1` up to the padded extent,
///    stepping by the strides of the unfolded convolution mode.
///
/// When the `rayon` feature is enabled and `parallel` is `true`, the two
/// forward transforms run through `rayon::join` and the element-wise product
/// uses parallel iterators. In that branch the kernel transform uses a fresh
/// processor from `InElem::get_processor()`, because `fft_processor` is
/// mutably borrowed by the data closure. The boolean handed to
/// `forward`/`backward` is `true` only in that branch (presumably a
/// "run in parallel" flag; confirm against `Processor`).
///
/// # Errors
///
/// * `Error::DataShape` - the data or the kernel has a zero-length axis.
///   NOTE(review): the kernel case also reports `DataShape`; confirm whether
///   a kernel-specific variant was intended.
/// * `Error::MismatchShape` - on some axis the dilated kernel is larger than
///   the padded data.
fn conv_fft_proc_impl<'a, T, InElem, S, SK, const N: usize>(
    data: &ArrayBase<S, Dim<[Ix; N]>>,
    kernel: impl IntoKernelWithDilation<'a, SK, N>,
    conv_mode: ConvMode<N>,
    padding_mode: PaddingMode<N, InElem>,
    fft_processor: &mut impl Processor<T, InElem>,
    #[cfg_attr(not(feature = "rayon"), allow(unused_variables))] parallel: bool,
) -> Result<Array<InElem, Dim<[Ix; N]>>, crate::Error<N>>
where
    T: NumAssign + FftNum,
    InElem: processor::GetProcessor<T, InElem> + NumAssign + Copy + MaybeSync + 'a,
    S: Data<Elem = InElem> + MaybeSync + 'a,
    SK: Data<Elem = InElem> + MaybeSync + 'a,
    [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>,
    SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>:
        SliceArg<Dim<[Ix; N]>, OutDim = Dim<[Ix; N]>>,
    Dim<[Ix; N]>: RemoveAxis,
{
    let dilated_kernel = kernel.into_kernel_with_dilation();

    // An axis of length zero makes the whole array empty; refuse such inputs.
    let data_dim = data.raw_dim();
    if data.shape().contains(&0) {
        return Err(crate::Error::DataShape(data_dim));
    }
    let kernel_dim = dilated_kernel.kernel.raw_dim();
    if dilated_kernel.kernel.shape().contains(&0) {
        return Err(crate::Error::DataShape(kernel_dim));
    }

    // Extent of the kernel on each axis once the dilation gaps are counted.
    let dilated_extent: [usize; N] = std::array::from_fn(|axis| {
        kernel_dim[axis] * dilated_kernel.dilation[axis] - dilated_kernel.dilation[axis] + 1
    });

    // Extent of the data on each axis after leading and trailing padding.
    let mode = conv_mode.unfold(&dilated_kernel);
    let padded_extent: [usize; N] = std::array::from_fn(|axis| {
        data_dim[axis] + mode.padding[axis][0] + mode.padding[axis][1]
    });

    let kernel_too_large = (0..N).any(|axis| dilated_extent[axis] > padded_extent[axis]);
    if kernel_too_large {
        return Err(crate::Error::MismatchShape(conv_mode, dilated_extent));
    }

    let min_fft_extent: [usize; N] =
        std::array::from_fn(|axis| padded_extent[axis].max(dilated_extent[axis]));
    let fft_size = good_size::compute::<N>(&min_fft_extent);

    #[cfg(feature = "rayon")]
    let full = if parallel {
        // Transform data and kernel concurrently; each closure owns its own
        // processor so that no mutable borrow is shared.
        let (mut data_spec, kernel_spec) = rayon::join(
            || {
                let mut padded = padding::data(data, padding_mode, mode.padding, fft_size);
                fft_processor.forward(&mut padded, true)
            },
            || {
                let mut padded = padding::kernel(dilated_kernel, fft_size);
                let mut kernel_processor = InElem::get_processor();
                kernel_processor.forward(&mut padded, true)
            },
        );
        {
            use rayon::prelude::*;
            let lhs = data_spec.as_slice_mut().unwrap();
            let rhs = kernel_spec.as_slice().unwrap();
            lhs.par_iter_mut()
                .zip(rhs.par_iter())
                .for_each(|(d, k)| *d *= *k);
        }
        fft_processor.backward(&mut data_spec, true)
    } else {
        let mut padded_data = padding::data(data, padding_mode, mode.padding, fft_size);
        let mut padded_kernel = padding::kernel(dilated_kernel, fft_size);
        let mut data_spec = fft_processor.forward(&mut padded_data, false);
        let kernel_spec = fft_processor.forward(&mut padded_kernel, false);
        data_spec.zip_mut_with(&kernel_spec, |d, k| *d *= *k);
        fft_processor.backward(&mut data_spec, false)
    };

    #[cfg(not(feature = "rayon"))]
    let full = {
        let mut padded_data = padding::data(data, padding_mode, mode.padding, fft_size);
        let mut padded_kernel = padding::kernel(dilated_kernel, fft_size);
        let mut data_spec = fft_processor.forward(&mut padded_data, false);
        let kernel_spec = fft_processor.forward(&mut padded_kernel, false);
        data_spec.zip_mut_with(&kernel_spec, |d, k| *d *= *k);
        fft_processor.backward(&mut data_spec, false)
    };

    // Keep, on every axis, the range [dilated_extent - 1, padded_extent)
    // sampled with the requested stride.
    let window: [SliceInfoElem; N] = std::array::from_fn(|axis| SliceInfoElem::Slice {
        start: dilated_extent[axis] as isize - 1,
        end: Some(padded_extent[axis] as isize),
        step: mode.strides[axis] as isize,
    });
    // SAFETY: `SliceInfo::new` requires `indices.as_ref()` to return the same
    // value on every call, which holds for a plain array. The array has `N`
    // `Slice` entries, matching the `N`-dimensional input and output, so the
    // consistency check behind `unwrap` cannot fail.
    let slice_info: SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>> =
        unsafe { SliceInfo::new(window) }.unwrap();

    Ok(full.slice_move(slice_info))
}
/// `ConvFFTExt` for every `ArrayBase` whose element type can supply an FFT
/// processor. All three methods delegate to `conv_fft_proc_impl`; they differ
/// only in where the processor comes from and in the `parallel` flag.
impl<'a, T, InElem, S, SK, const N: usize> ConvFFTExt<'a, T, InElem, S, SK, N>
    for ArrayBase<S, Dim<[Ix; N]>>
where
    T: NumAssign + FftNum,
    InElem: processor::GetProcessor<T, InElem> + NumAssign + Copy + MaybeSync + 'a,
    S: Data<Elem = InElem> + MaybeSync + 'a,
    SK: Data<Elem = InElem> + MaybeSync + 'a,
    [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>,
    SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>:
        SliceArg<Dim<[Ix; N]>, OutDim = Dim<[Ix; N]>>,
    Dim<[Ix; N]>: RemoveAxis,
{
    /// Serial FFT convolution with a freshly created default processor.
    fn conv_fft(
        &self,
        kernel: impl IntoKernelWithDilation<'a, SK, N>,
        conv_mode: ConvMode<N>,
        padding_mode: PaddingMode<N, InElem>,
    ) -> Result<Array<InElem, Dim<[Ix; N]>>, crate::Error<N>> {
        let mut default_processor =
            <InElem as processor::GetProcessor<T, InElem>>::get_processor();
        conv_fft_proc_impl(
            self,
            kernel,
            conv_mode,
            padding_mode,
            &mut default_processor,
            false,
        )
    }

    /// Serial FFT convolution that runs on the caller's `fft_processor`.
    fn conv_fft_with_processor(
        &self,
        kernel: impl IntoKernelWithDilation<'a, SK, N>,
        conv_mode: ConvMode<N>,
        padding_mode: PaddingMode<N, InElem>,
        fft_processor: &mut impl Processor<T, InElem>,
    ) -> Result<Array<InElem, Dim<[Ix; N]>>, crate::Error<N>> {
        let parallel = false;
        conv_fft_proc_impl(
            self,
            kernel,
            conv_mode,
            padding_mode,
            fft_processor,
            parallel,
        )
    }

    /// FFT convolution with a default processor and the parallel path
    /// requested. Compiled only with the `rayon` feature.
    #[cfg(feature = "rayon")]
    fn conv_fft_par(
        &self,
        kernel: impl IntoKernelWithDilation<'a, SK, N>,
        conv_mode: ConvMode<N>,
        padding_mode: PaddingMode<N, InElem>,
    ) -> Result<Array<InElem, Dim<[Ix; N]>>, crate::Error<N>> {
        let mut default_processor =
            <InElem as processor::GetProcessor<T, InElem>>::get_processor();
        conv_fft_proc_impl(
            self,
            kernel,
            conv_mode,
            padding_mode,
            &mut default_processor,
            true,
        )
    }
}
#[cfg(test)]
mod tests;