Skip to main content

ndarray_conv/conv_fft/
mod.rs

1//! Provides FFT-accelerated convolution operations.
2//!
3//! This module offers the `ConvFFTExt` trait, which extends `ndarray`
4//! with FFT-based convolution methods.
5
6use ndarray::{
7    Array, ArrayBase, Data, Dim, IntoDimension, Ix, RawData, RemoveAxis, SliceArg, SliceInfo,
8    SliceInfoElem,
9};
10use num::traits::NumAssign;
11use rustfft::FftNum;
12
13use crate::{dilation::IntoKernelWithDilation, ConvMode, PaddingMode};
14
15mod good_size;
16mod padding;
17mod processor;
18
19// pub use fft::Processor;
20pub use processor::{get as get_processor, GetProcessor, MaybeSync, Processor};
21
22// /// Represents a "baked" convolution operation.
23// ///
24// /// This struct holds pre-computed data for performing FFT-accelerated
25// /// convolutions, including the FFT size, FFT processor, scratch space,
26// /// and padding information. It's designed to optimize repeated
27// /// convolutions with the same kernel and settings.
28// pub struct Baked<T, SK, const N: usize>
29// where
30//     T: NumAssign + Debug + Copy,
31//     SK: RawData,
32// {
33//     fft_size: [usize; N],
34//     fft_processor: impl Processor<T>,
35//     scratch: Vec<Complex<T>>,
36//     cm: ExplicitConv<N>,
37//     padding_mode: PaddingMode<N, T>,
38//     kernel_raw_dim_with_dilation: [usize; N],
39//     pds_raw_dim: [usize; N],
40//     kernel_pd: Array<T, Dim<[Ix; N]>>,
41//     _sk_hint: PhantomData<SK>,
42// }
43
44/// Extends `ndarray`'s `ArrayBase` with FFT-accelerated convolution operations.
45///
46/// This trait adds the `conv_fft` and `conv_fft_with_processor` methods to `ArrayBase`,
47/// enabling efficient FFT-based convolutions on N-dimensional arrays.
48///
49/// # Type Parameters
50///
51/// *   `T`: The numeric type used internally for FFT operations. Must be a floating-point type that implements `FftNum`.
52/// *   `InElem`: The element type of the input arrays. Can be real (`T`) or complex (`Complex<T>`).
53/// *   `S`: The data storage type of the input array.
54/// *   `SK`: The data storage type of the kernel array.
55///
56/// # Methods
57///
58/// *   `conv_fft`: Performs an FFT-accelerated convolution with default settings.
59/// *   `conv_fft_with_processor`: Performs an FFT-accelerated convolution using a provided `Processor` instance, allowing for reuse of FFT plans across multiple convolutions for better performance.
60///
61/// # Example
62///
63/// ```rust
64/// use ndarray::prelude::*;
65/// use ndarray_conv::{ConvFFTExt, ConvMode, PaddingMode};
66///
67/// let arr = array![[1., 2.], [3., 4.]];
68/// let kernel = array![[1., 0.], [0., 1.]];
69/// let result = arr.conv_fft(&kernel, ConvMode::Same, PaddingMode::Zeros).unwrap();
70/// ```
71///
72/// # Notes
73///
74/// FFT-based convolutions are generally faster for larger kernels but may have higher overhead for smaller kernels.
75/// Use standard convolution (`ConvExt::conv`) for small kernels or when working with integer types.
76///
77/// # Performance Tips
78///
79/// For repeated convolutions with different data but the same kernel and settings, consider using
80/// `conv_fft_with_processor` to reuse the FFT planner and avoid redundant setup overhead.
81pub trait ConvFFTExt<'a, T, InElem, S, SK, const N: usize>
82where
83    T: NumAssign + Copy + FftNum,
84    InElem: processor::GetProcessor<T, InElem> + Copy + NumAssign,
85    S: RawData,
86    SK: RawData,
87{
88    /// Performs an FFT-accelerated convolution operation.
89    ///
90    /// This method convolves the input array with a given kernel using FFT,
91    /// which is typically faster for larger kernels.
92    ///
93    /// # Arguments
94    ///
95    /// * `kernel`: The convolution kernel. Can be a reference to an array, or an array with dilation settings.
96    /// * `conv_mode`: The convolution mode (`Full`, `Same`, `Valid`, `Custom`, `Explicit`).
97    /// * `padding_mode`: The padding mode (`Zeros`, `Const`, `Reflect`, `Replicate`, `Circular`, `Custom`, `Explicit`).
98    ///
99    /// # Returns
100    ///
101    /// Returns `Ok(Array<InElem, Dim<[Ix; N]>>)` containing the convolution result, or an `Err(Error<N>)` if the operation fails.
102    ///
103    /// # Example
104    ///
105    /// ```rust
106    /// use ndarray::array;
107    /// use ndarray_conv::{ConvFFTExt, ConvMode, PaddingMode};
108    ///
109    /// let input = array![[1.0, 2.0], [3.0, 4.0]];
110    /// let kernel = array![[1.0, 0.0], [0.0, 1.0]];
111    /// let result = input.conv_fft(&kernel, ConvMode::Same, PaddingMode::Zeros).unwrap();
112    /// ```
113    fn conv_fft(
114        &self,
115        kernel: impl IntoKernelWithDilation<'a, SK, N>,
116        conv_mode: ConvMode<N>,
117        padding_mode: PaddingMode<N, InElem>,
118    ) -> Result<Array<InElem, Dim<[Ix; N]>>, crate::Error<N>>;
119
120    /// Performs an FFT-accelerated convolution using a provided processor.
121    ///
122    /// This method is useful when performing multiple convolutions, as it allows
123    /// reusing the FFT planner and avoiding redundant initialization overhead.
124    ///
125    /// # Arguments
126    ///
127    /// * `kernel`: The convolution kernel.
128    /// * `conv_mode`: The convolution mode.
129    /// * `padding_mode`: The padding mode.
130    /// * `fft_processor`: A mutable reference to an FFT processor instance.
131    ///
132    /// # Returns
133    ///
134    /// Returns `Ok(Array<InElem, Dim<[Ix; N]>>)` containing the convolution result, or an `Err(Error<N>)` if the operation fails.
135    ///
136    /// # Example
137    ///
138    /// ```rust
139    /// use ndarray::array;
140    /// use ndarray_conv::{ConvFFTExt, ConvMode, PaddingMode, get_fft_processor};
141    ///
142    /// let input1 = array![[1.0, 2.0], [3.0, 4.0]];
143    /// let input2 = array![[5.0, 6.0], [7.0, 8.0]];
144    /// let kernel = array![[1.0, 0.0], [0.0, 1.0]];
145    ///
146    /// // Reuse the same processor for multiple convolutions
147    /// let mut proc = get_fft_processor::<f32, f32>();
148    /// let result1 = input1.conv_fft_with_processor(&kernel, ConvMode::Same, PaddingMode::Zeros, &mut proc).unwrap();
149    /// let result2 = input2.conv_fft_with_processor(&kernel, ConvMode::Same, PaddingMode::Zeros, &mut proc).unwrap();
150    /// ```
151    fn conv_fft_with_processor(
152        &self,
153        kernel: impl IntoKernelWithDilation<'a, SK, N>,
154        conv_mode: ConvMode<N>,
155        padding_mode: PaddingMode<N, InElem>,
156        fft_processor: &mut impl Processor<T, InElem>,
157    ) -> Result<Array<InElem, Dim<[Ix; N]>>, crate::Error<N>>;
158
159    /// Performs a parallel FFT-accelerated convolution operation using rayon.
160    ///
161    /// This method uses rayon to parallelize the FFT row operations, which can
162    /// significantly speed up convolution on large arrays.
163    ///
164    /// Only available when the `rayon` feature is enabled.
165    #[cfg(feature = "rayon")]
166    fn conv_fft_par(
167        &self,
168        kernel: impl IntoKernelWithDilation<'a, SK, N>,
169        conv_mode: ConvMode<N>,
170        padding_mode: PaddingMode<N, InElem>,
171    ) -> Result<Array<InElem, Dim<[Ix; N]>>, crate::Error<N>>;
172
173
174
175    // fn conv_fft_bake(
176    //     &self,
177    //     kernel: impl IntoKernelWithDilation<'a, SK, N>,
178    //     conv_mode: ConvMode<N>,
179    //     padding_mode: PaddingMode<N, T>,
180    // ) -> Result<Baked<T, SK, N>, crate::Error<N>>;
181
182    // fn conv_fft_with_baked(&self, baked: &mut Baked<T, SK, N>) -> Array<T, Dim<[Ix; N]>>;
183}
184
185fn conv_fft_proc_impl<'a, T, InElem, S, SK, const N: usize>(
186    data: &ArrayBase<S, Dim<[Ix; N]>>,
187    kernel: impl IntoKernelWithDilation<'a, SK, N>,
188    conv_mode: ConvMode<N>,
189    padding_mode: PaddingMode<N, InElem>,
190    fft_processor: &mut impl Processor<T, InElem>,
191    #[cfg_attr(not(feature = "rayon"), allow(unused_variables))] parallel: bool,
192) -> Result<Array<InElem, Dim<[Ix; N]>>, crate::Error<N>>
193where
194    T: NumAssign + FftNum,
195    InElem: processor::GetProcessor<T, InElem> + NumAssign + Copy + MaybeSync + 'a,
196    S: Data<Elem = InElem> + MaybeSync + 'a,
197    SK: Data<Elem = InElem> + MaybeSync + 'a,
198    [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>,
199    SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>:
200        SliceArg<Dim<[Ix; N]>, OutDim = Dim<[Ix; N]>>,
201    Dim<[Ix; N]>: RemoveAxis,
202{
203    let kwd = kernel.into_kernel_with_dilation();
204
205    let data_raw_dim = data.raw_dim();
206    if data.shape().iter().product::<usize>() == 0 {
207        return Err(crate::Error::DataShape(data_raw_dim));
208    }
209
210    let kernel_raw_dim = kwd.kernel.raw_dim();
211    if kwd.kernel.shape().iter().product::<usize>() == 0 {
212        return Err(crate::Error::DataShape(kernel_raw_dim));
213    }
214
215    let kernel_raw_dim_with_dilation: [usize; N] =
216        std::array::from_fn(|i| kernel_raw_dim[i] * kwd.dilation[i] - kwd.dilation[i] + 1);
217
218    let cm = conv_mode.unfold(&kwd);
219
220    let pds_raw_dim: [usize; N] =
221        std::array::from_fn(|i| data_raw_dim[i] + cm.padding[i][0] + cm.padding[i][1]);
222    if !(0..N).all(|i| kernel_raw_dim_with_dilation[i] <= pds_raw_dim[i]) {
223        return Err(crate::Error::MismatchShape(
224            conv_mode,
225            kernel_raw_dim_with_dilation,
226        ));
227    }
228
229    let fft_size = good_size::compute::<N>(&std::array::from_fn(|i| {
230        pds_raw_dim[i].max(kernel_raw_dim_with_dilation[i])
231    }));
232
233    // Parallel path: padding and forward FFT are folded into a single rayon::join,
234    // so data-padding+data-FFT runs concurrently with kernel-padding+kernel-FFT.
235    // The spectral multiply is parallelised too.
236    // Requires InElem/S/SK: MaybeSync so that shared references can cross thread
237    // boundaries inside rayon::join closures (no-op when rayon is disabled).
238    //
239    // Serial path: sequential padding then FFT as before.
240    #[cfg(feature = "rayon")]
241    let output = if parallel {
242        let (mut data_fft, kern_fft) = rayon::join(
243            || {
244                let mut pd = padding::data(data, padding_mode, cm.padding, fft_size);
245                fft_processor.forward(&mut pd, true)
246            },
247            || {
248                let mut pk = padding::kernel(kwd, fft_size);
249                let mut p = InElem::get_processor();
250                p.forward(&mut pk, true)
251            },
252        );
253        {
254            use rayon::prelude::*;
255            data_fft
256                .as_slice_mut()
257                .unwrap()
258                .par_iter_mut()
259                .zip(kern_fft.as_slice().unwrap().par_iter())
260                .for_each(|(d, k)| *d *= *k);
261        }
262        fft_processor.backward(&mut data_fft, true)
263    } else {
264        let mut data_pd = padding::data(data, padding_mode, cm.padding, fft_size);
265        let mut kernel_pd = padding::kernel(kwd, fft_size);
266        let mut data_fft = fft_processor.forward(&mut data_pd, false);
267        let kern_fft = fft_processor.forward(&mut kernel_pd, false);
268        data_fft.zip_mut_with(&kern_fft, |d, k| *d *= *k);
269        fft_processor.backward(&mut data_fft, false)
270    };
271
272    #[cfg(not(feature = "rayon"))]
273    let output = {
274        let mut data_pd = padding::data(data, padding_mode, cm.padding, fft_size);
275        let mut kernel_pd = padding::kernel(kwd, fft_size);
276        let mut data_fft = fft_processor.forward(&mut data_pd, false);
277        let kern_fft = fft_processor.forward(&mut kernel_pd, false);
278        data_fft.zip_mut_with(&kern_fft, |d, k| *d *= *k);
279        fft_processor.backward(&mut data_fft, false)
280    };
281
282    let output = output.slice_move(unsafe {
283        SliceInfo::new(std::array::from_fn(|i| SliceInfoElem::Slice {
284            start: kernel_raw_dim_with_dilation[i] as isize - 1,
285            end: Some((pds_raw_dim[i]) as isize),
286            step: cm.strides[i] as isize,
287        }))
288        .unwrap()
289    });
290
291    Ok(output)
292}
293
294impl<'a, T, InElem, S, SK, const N: usize> ConvFFTExt<'a, T, InElem, S, SK, N>
295    for ArrayBase<S, Dim<[Ix; N]>>
296where
297    T: NumAssign + FftNum,
298    InElem: processor::GetProcessor<T, InElem> + NumAssign + Copy + MaybeSync + 'a,
299    S: Data<Elem = InElem> + MaybeSync + 'a,
300    SK: Data<Elem = InElem> + MaybeSync + 'a,
301    [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>,
302    SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>:
303        SliceArg<Dim<[Ix; N]>, OutDim = Dim<[Ix; N]>>,
304    Dim<[Ix; N]>: RemoveAxis,
305{
306    // fn conv_fft_bake(
307    //     &self,
308    //     kernel: impl IntoKernelWithDilation<'a, SK, N>,
309    //     conv_mode: ConvMode<N>,
310    //     padding_mode: PaddingMode<N, T>,
311    // ) -> Result<Baked<T, SK, N>, crate::Error<N>> {
312    //     let mut fft_processor = Processor::default();
313
314    //     let kwd = kernel.into_kernel_with_dilation();
315
316    //     let data_raw_dim = self.raw_dim();
317    //     if self.shape().iter().product::<usize>() == 0 {
318    //         return Err(crate::Error::DataShape(data_raw_dim));
319    //     }
320
321    //     let kernel_raw_dim = kwd.kernel.raw_dim();
322    //     if kwd.kernel.shape().iter().product::<usize>() == 0 {
323    //         return Err(crate::Error::DataShape(kernel_raw_dim));
324    //     }
325
326    //     let kernel_raw_dim_with_dilation: [usize; N] =
327    //         std::array::from_fn(|i| kernel_raw_dim[i] * kwd.dilation[i] - kwd.dilation[i] + 1);
328
329    //     let cm = conv_mode.unfold(&kwd);
330
331    //     let pds_raw_dim: [usize; N] =
332    //         std::array::from_fn(|i| (data_raw_dim[i] + cm.padding[i][0] + cm.padding[i][1]));
333    //     if !(0..N).all(|i| kernel_raw_dim_with_dilation[i] <= pds_raw_dim[i]) {
334    //         return Err(crate::Error::MismatchShape(
335    //             conv_mode,
336    //             kernel_raw_dim_with_dilation,
337    //         ));
338    //     }
339
340    //     let fft_size = good_size::compute::<N>(&std::array::from_fn(|i| {
341    //         pds_raw_dim[i].max(kernel_raw_dim_with_dilation[i])
342    //     }));
343
344    //     let scratch = fft_processor.get_scratch(fft_size);
345
346    //     let kernel_pd = padding::kernel(kwd, fft_size);
347
348    //     Ok(Baked {
349    //         fft_size,
350    //         fft_processor,
351    //         scratch,
352    //         cm,
353    //         padding_mode,
354    //         kernel_raw_dim_with_dilation,
355    //         pds_raw_dim,
356    //         kernel_pd,
357    //         _sk_hint: PhantomData,
358    //     })
359    // }
360
361    // fn conv_fft_with_baked(&self, baked: &mut Baked<T, SK, N>) -> Array<T, Dim<[Ix; N]>> {
362    //     let Baked {
363    //         scratch,
364    //         fft_processor,
365    //         fft_size,
366    //         cm,
367    //         padding_mode,
368    //         kernel_pd,
369    //         kernel_raw_dim_with_dilation,
370    //         pds_raw_dim,
371    //         _sk_hint,
372    //     } = baked;
373
374    //     let mut data_pd = padding::data(self, *padding_mode, cm.padding, *fft_size);
375
376    //     let mut data_pd_fft = fft_processor.forward_with_scratch(&mut data_pd, scratch);
377    //     let kernel_pd_fft = fft_processor.forward_with_scratch(kernel_pd, scratch);
378
379    //     data_pd_fft.zip_mut_with(&kernel_pd_fft, |d, k| *d *= *k);
380    //     // let mul_spec = data_pd_fft * kernel_pd_fft;
381
382    //     let output = fft_processor.backward(data_pd_fft);
383
384    //     output.slice_move(unsafe {
385    //         SliceInfo::new(std::array::from_fn(|i| SliceInfoElem::Slice {
386    //             start: kernel_raw_dim_with_dilation[i] as isize - 1,
387    //             end: Some((pds_raw_dim[i]) as isize),
388    //             step: cm.strides[i] as isize,
389    //         }))
390    //         .unwrap()
391    //     })
392    // }
393
394    fn conv_fft(
395        &self,
396        kernel: impl IntoKernelWithDilation<'a, SK, N>,
397        conv_mode: ConvMode<N>,
398        padding_mode: PaddingMode<N, InElem>,
399    ) -> Result<Array<InElem, Dim<[Ix; N]>>, crate::Error<N>> {
400        let mut p = InElem::get_processor();
401        conv_fft_proc_impl(self, kernel, conv_mode, padding_mode, &mut p, false)
402    }
403
404    fn conv_fft_with_processor(
405        &self,
406        kernel: impl IntoKernelWithDilation<'a, SK, N>,
407        conv_mode: ConvMode<N>,
408        padding_mode: PaddingMode<N, InElem>,
409        fft_processor: &mut impl Processor<T, InElem>,
410    ) -> Result<Array<InElem, Dim<[Ix; N]>>, crate::Error<N>> {
411        conv_fft_proc_impl(self, kernel, conv_mode, padding_mode, fft_processor, false)
412    }
413
414    #[cfg(feature = "rayon")]
415    fn conv_fft_par(
416        &self,
417        kernel: impl IntoKernelWithDilation<'a, SK, N>,
418        conv_mode: ConvMode<N>,
419        padding_mode: PaddingMode<N, InElem>,
420    ) -> Result<Array<InElem, Dim<[Ix; N]>>, crate::Error<N>> {
421        let mut p = InElem::get_processor();
422        conv_fft_proc_impl(self, kernel, conv_mode, padding_mode, &mut p, true)
423    }
424
425
426}
427
428#[cfg(test)]
429mod tests;