ndarray_conv/conv_fft/
mod.rs

1//! Provides FFT-accelerated convolution operations.
2//!
3//! This module offers the `ConvFFTExt` trait, which extends `ndarray`
4//! with FFT-based convolution methods.
5
6use std::fmt::Debug;
7
8use ndarray::{
9    Array, ArrayBase, Data, Dim, IntoDimension, Ix, RawData, RemoveAxis, SliceArg, SliceInfo,
10    SliceInfoElem,
11};
12use num::traits::NumAssign;
13use rustfft::FftNum;
14
15use crate::{dilation::IntoKernelWithDilation, ConvMode, PaddingMode};
16
17mod good_size;
18mod padding;
19mod processor;
20
21// pub use fft::Processor;
22pub use processor::{get as get_processor, GetProcessor, Processor};
23
24// /// Represents a "baked" convolution operation.
25// ///
26// /// This struct holds pre-computed data for performing FFT-accelerated
27// /// convolutions, including the FFT size, FFT processor, scratch space,
28// /// and padding information. It's designed to optimize repeated
29// /// convolutions with the same kernel and settings.
30// pub struct Baked<T, SK, const N: usize>
31// where
32//     T: NumAssign + Debug + Copy,
33//     SK: RawData,
34// {
35//     fft_size: [usize; N],
36//     fft_processor: impl Processor<T>,
37//     scratch: Vec<Complex<T>>,
38//     cm: ExplicitConv<N>,
39//     padding_mode: PaddingMode<N, T>,
40//     kernel_raw_dim_with_dilation: [usize; N],
41//     pds_raw_dim: [usize; N],
42//     kernel_pd: Array<T, Dim<[Ix; N]>>,
43//     _sk_hint: PhantomData<SK>,
44// }
45
46/// Extends `ndarray`'s `ArrayBase` with FFT-accelerated convolution operations.
47///
48/// This trait adds the `conv_fft` and `conv_fft_with_processor` methods to `ArrayBase`,
49/// enabling efficient FFT-based convolutions on N-dimensional arrays.
50///
51/// # Type Parameters
52///
53/// *   `T`: The numeric type of the array elements. Must be a floating-point type that implements `FftNum`.
54/// *   `S`: The data storage type of the input array.
55/// *   `SK`: The data storage type of the kernel array.
56///
57/// # Methods
58///
59/// *   `conv_fft`: Performs an FFT-accelerated convolution with default settings.
60/// *   `conv_fft_with_processor`: Performs an FFT-accelerated convolution using a provided `Processor` instance, allowing for reuse of FFT plans and scratch space.
61/// *   `conv_fft_bake`: Precomputes and stores necessary data for performing repeated convolutions in the form of `Baked`.
62/// *   `conv_fft_with_baked`: Performs a convolution with the provided `Baked` data.
63///
64/// # Example
65///
66/// ```rust
67/// use ndarray::prelude::*;
68/// use ndarray_conv::{ConvFFTExt, ConvMode, PaddingMode};
69///
70/// let arr = array![[1., 2.], [3., 4.]];
71/// let kernel = array![[1., 0.], [0., 1.]];
72/// let result = arr.conv_fft(&kernel, ConvMode::Same, PaddingMode::Zeros).unwrap();
73/// ```
74///
75/// # Notes
76///
77/// FFT-based convolutions are generally faster for larger kernels but may have higher overhead for smaller kernels.
78pub trait ConvFFTExt<'a, T, InElem, S, SK, const N: usize>
79where
80    T: NumAssign + Copy + FftNum,
81    InElem: processor::GetProcessor<T, InElem> + Copy + NumAssign,
82    S: RawData,
83    SK: RawData,
84{
85    fn conv_fft(
86        &self,
87        kernel: impl IntoKernelWithDilation<'a, SK, N>,
88        conv_mode: ConvMode<N>,
89        padding_mode: PaddingMode<N, InElem>,
90    ) -> Result<Array<InElem, Dim<[Ix; N]>>, crate::Error<N>>;
91
92    fn conv_fft_with_processor(
93        &self,
94        kernel: impl IntoKernelWithDilation<'a, SK, N>,
95        conv_mode: ConvMode<N>,
96        padding_mode: PaddingMode<N, InElem>,
97        fft_processor: &mut impl Processor<T, InElem>,
98    ) -> Result<Array<InElem, Dim<[Ix; N]>>, crate::Error<N>>;
99
100    // fn conv_fft_bake(
101    //     &self,
102    //     kernel: impl IntoKernelWithDilation<'a, SK, N>,
103    //     conv_mode: ConvMode<N>,
104    //     padding_mode: PaddingMode<N, T>,
105    // ) -> Result<Baked<T, SK, N>, crate::Error<N>>;
106
107    // fn conv_fft_with_baked(&self, baked: &mut Baked<T, SK, N>) -> Array<T, Dim<[Ix; N]>>;
108}
109
110impl<'a, T, InElem, S, SK, const N: usize> ConvFFTExt<'a, T, InElem, S, SK, N>
111    for ArrayBase<S, Dim<[Ix; N]>>
112where
113    T: NumAssign + FftNum,
114    InElem: processor::GetProcessor<T, InElem> + NumAssign + Copy + Debug,
115    S: Data<Elem = InElem> + 'a,
116    SK: Data<Elem = InElem> + 'a,
117    [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>,
118    SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>:
119        SliceArg<Dim<[Ix; N]>, OutDim = Dim<[Ix; N]>>,
120    Dim<[Ix; N]>: RemoveAxis,
121{
122    // fn conv_fft_bake(
123    //     &self,
124    //     kernel: impl IntoKernelWithDilation<'a, SK, N>,
125    //     conv_mode: ConvMode<N>,
126    //     padding_mode: PaddingMode<N, T>,
127    // ) -> Result<Baked<T, SK, N>, crate::Error<N>> {
128    //     let mut fft_processor = Processor::default();
129
130    //     let kwd = kernel.into_kernel_with_dilation();
131
132    //     let data_raw_dim = self.raw_dim();
133    //     if self.shape().iter().product::<usize>() == 0 {
134    //         return Err(crate::Error::DataShape(data_raw_dim));
135    //     }
136
137    //     let kernel_raw_dim = kwd.kernel.raw_dim();
138    //     if kwd.kernel.shape().iter().product::<usize>() == 0 {
139    //         return Err(crate::Error::DataShape(kernel_raw_dim));
140    //     }
141
142    //     let kernel_raw_dim_with_dilation: [usize; N] =
143    //         std::array::from_fn(|i| kernel_raw_dim[i] * kwd.dilation[i] - kwd.dilation[i] + 1);
144
145    //     let cm = conv_mode.unfold(&kwd);
146
147    //     let pds_raw_dim: [usize; N] =
148    //         std::array::from_fn(|i| (data_raw_dim[i] + cm.padding[i][0] + cm.padding[i][1]));
149    //     if !(0..N).all(|i| kernel_raw_dim_with_dilation[i] <= pds_raw_dim[i]) {
150    //         return Err(crate::Error::MismatchShape(
151    //             conv_mode,
152    //             kernel_raw_dim_with_dilation,
153    //         ));
154    //     }
155
156    //     let fft_size = good_size::compute::<N>(&std::array::from_fn(|i| {
157    //         pds_raw_dim[i].max(kernel_raw_dim_with_dilation[i])
158    //     }));
159
160    //     let scratch = fft_processor.get_scratch(fft_size);
161
162    //     let kernel_pd = padding::kernel(kwd, fft_size);
163
164    //     Ok(Baked {
165    //         fft_size,
166    //         fft_processor,
167    //         scratch,
168    //         cm,
169    //         padding_mode,
170    //         kernel_raw_dim_with_dilation,
171    //         pds_raw_dim,
172    //         kernel_pd,
173    //         _sk_hint: PhantomData,
174    //     })
175    // }
176
177    // fn conv_fft_with_baked(&self, baked: &mut Baked<T, SK, N>) -> Array<T, Dim<[Ix; N]>> {
178    //     let Baked {
179    //         scratch,
180    //         fft_processor,
181    //         fft_size,
182    //         cm,
183    //         padding_mode,
184    //         kernel_pd,
185    //         kernel_raw_dim_with_dilation,
186    //         pds_raw_dim,
187    //         _sk_hint,
188    //     } = baked;
189
190    //     let mut data_pd = padding::data(self, *padding_mode, cm.padding, *fft_size);
191
192    //     let mut data_pd_fft = fft_processor.forward_with_scratch(&mut data_pd, scratch);
193    //     let kernel_pd_fft = fft_processor.forward_with_scratch(kernel_pd, scratch);
194
195    //     data_pd_fft.zip_mut_with(&kernel_pd_fft, |d, k| *d *= *k);
196    //     // let mul_spec = data_pd_fft * kernel_pd_fft;
197
198    //     let output = fft_processor.backward(data_pd_fft);
199
200    //     output.slice_move(unsafe {
201    //         SliceInfo::new(std::array::from_fn(|i| SliceInfoElem::Slice {
202    //             start: kernel_raw_dim_with_dilation[i] as isize - 1,
203    //             end: Some((pds_raw_dim[i]) as isize),
204    //             step: cm.strides[i] as isize,
205    //         }))
206    //         .unwrap()
207    //     })
208    // }
209
210    fn conv_fft(
211        &self,
212        kernel: impl IntoKernelWithDilation<'a, SK, N>,
213        conv_mode: ConvMode<N>,
214        padding_mode: PaddingMode<N, InElem>,
215    ) -> Result<Array<InElem, Dim<[Ix; N]>>, crate::Error<N>> {
216        let mut p = InElem::get_processor();
217        self.conv_fft_with_processor(kernel, conv_mode, padding_mode, &mut p)
218    }
219
220    fn conv_fft_with_processor(
221        &self,
222        kernel: impl IntoKernelWithDilation<'a, SK, N>,
223        conv_mode: ConvMode<N>,
224        padding_mode: PaddingMode<N, InElem>,
225        fft_processor: &mut impl Processor<T, InElem>,
226    ) -> Result<Array<InElem, Dim<[Ix; N]>>, crate::Error<N>> {
227        let kwd = kernel.into_kernel_with_dilation();
228
229        let data_raw_dim = self.raw_dim();
230        if self.shape().iter().product::<usize>() == 0 {
231            return Err(crate::Error::DataShape(data_raw_dim));
232        }
233
234        let kernel_raw_dim = kwd.kernel.raw_dim();
235        if kwd.kernel.shape().iter().product::<usize>() == 0 {
236            return Err(crate::Error::DataShape(kernel_raw_dim));
237        }
238
239        let kernel_raw_dim_with_dilation: [usize; N] =
240            std::array::from_fn(|i| kernel_raw_dim[i] * kwd.dilation[i] - kwd.dilation[i] + 1);
241
242        let cm = conv_mode.unfold(&kwd);
243
244        let pds_raw_dim: [usize; N] =
245            std::array::from_fn(|i| data_raw_dim[i] + cm.padding[i][0] + cm.padding[i][1]);
246        if !(0..N).all(|i| kernel_raw_dim_with_dilation[i] <= pds_raw_dim[i]) {
247            return Err(crate::Error::MismatchShape(
248                conv_mode,
249                kernel_raw_dim_with_dilation,
250            ));
251        }
252
253        let fft_size = good_size::compute::<N>(&std::array::from_fn(|i| {
254            pds_raw_dim[i].max(kernel_raw_dim_with_dilation[i])
255        }));
256
257        let mut data_pd = padding::data(self, padding_mode, cm.padding, fft_size);
258        let mut kernel_pd = padding::kernel(kwd, fft_size);
259
260        let mut data_pd_fft = fft_processor.forward(&mut data_pd);
261        let kernel_pd_fft = fft_processor.forward(&mut kernel_pd);
262
263        data_pd_fft.zip_mut_with(&kernel_pd_fft, |d, k| *d *= *k);
264        // let mul_spec = data_pd_fft * kernel_pd_fft;
265
266        let output = fft_processor.backward(&mut data_pd_fft);
267
268        let output = output.slice_move(unsafe {
269            SliceInfo::new(std::array::from_fn(|i| SliceInfoElem::Slice {
270                start: kernel_raw_dim_with_dilation[i] as isize - 1,
271                end: Some((pds_raw_dim[i]) as isize),
272                step: cm.strides[i] as isize,
273            }))
274            .unwrap()
275        });
276
277        Ok(output)
278    }
279}
280
281#[cfg(test)]
282mod tests;