ndarray_conv/conv_fft/mod.rs

//! Provides FFT-accelerated convolution operations.
//!
//! This module offers the `ConvFFTExt` trait, which extends `ndarray`
//! with FFT-based convolution methods.

use ndarray::{
    Array, ArrayBase, Data, Dim, IntoDimension, Ix, RawData, RemoveAxis, SliceArg, SliceInfo,
    SliceInfoElem,
};
use num::traits::NumAssign;
use rustfft::FftNum;

use crate::{dilation::IntoKernelWithDilation, ConvMode, PaddingMode};

mod good_size;
mod padding;
mod processor;

// pub use fft::Processor;
pub use processor::{get as get_processor, GetProcessor, Processor};

// /// Represents a "baked" convolution operation.
// ///
// /// This struct holds pre-computed data for performing FFT-accelerated
// /// convolutions, including the FFT size, FFT processor, scratch space,
// /// and padding information. It's designed to optimize repeated
// /// convolutions with the same kernel and settings.
// pub struct Baked<T, SK, const N: usize>
// where
//     T: NumAssign + Debug + Copy,
//     SK: RawData,
// {
//     fft_size: [usize; N],
//     fft_processor: impl Processor<T>,
//     scratch: Vec<Complex<T>>,
//     cm: ExplicitConv<N>,
//     padding_mode: PaddingMode<N, T>,
//     kernel_raw_dim_with_dilation: [usize; N],
//     pds_raw_dim: [usize; N],
//     kernel_pd: Array<T, Dim<[Ix; N]>>,
//     _sk_hint: PhantomData<SK>,
// }

/// Extends `ndarray`'s `ArrayBase` with FFT-accelerated convolution operations.
///
/// This trait adds the `conv_fft` and `conv_fft_with_processor` methods to `ArrayBase`,
/// enabling efficient FFT-based convolutions on N-dimensional arrays.
///
/// # Type Parameters
///
/// *   `T`: The numeric type used internally for FFT operations. Must be a floating-point type that implements `FftNum`.
/// *   `InElem`: The element type of the input arrays. Can be real (`T`) or complex (`Complex<T>`).
/// *   `S`: The data storage type of the input array.
/// *   `SK`: The data storage type of the kernel array.
///
/// # Methods
///
/// *   `conv_fft`: Performs an FFT-accelerated convolution with default settings.
/// *   `conv_fft_with_processor`: Performs an FFT-accelerated convolution using a provided `Processor` instance, allowing FFT plans to be reused across multiple convolutions for better performance.
///
/// # Example
///
/// ```rust
/// use ndarray::prelude::*;
/// use ndarray_conv::{ConvFFTExt, ConvMode, PaddingMode};
///
/// let arr = array![[1., 2.], [3., 4.]];
/// let kernel = array![[1., 0.], [0., 1.]];
/// let result = arr.conv_fft(&kernel, ConvMode::Same, PaddingMode::Zeros).unwrap();
/// ```
///
/// # Notes
///
/// FFT-based convolutions are generally faster for larger kernels but may have higher overhead for smaller kernels.
/// Use standard convolution (`ConvExt::conv`) for small kernels or when working with integer types.
///
/// # Performance Tips
///
/// For repeated convolutions with different data but the same kernel and settings, consider using
/// `conv_fft_with_processor` to reuse the FFT planner and avoid redundant setup overhead.
pub trait ConvFFTExt<'a, T, InElem, S, SK, const N: usize>
where
    T: NumAssign + Copy + FftNum,
    InElem: processor::GetProcessor<T, InElem> + Copy + NumAssign,
    S: RawData,
    SK: RawData,
{
    /// Performs an FFT-accelerated convolution operation.
    ///
    /// This method convolves the input array with a given kernel using FFT,
    /// which is typically faster for larger kernels.
    ///
    /// # Arguments
    ///
    /// * `kernel`: The convolution kernel. Can be a reference to an array, or an array with dilation settings.
    /// * `conv_mode`: The convolution mode (`Full`, `Same`, `Valid`, `Custom`, `Explicit`).
    /// * `padding_mode`: The padding mode (`Zeros`, `Const`, `Reflect`, `Replicate`, `Circular`, `Custom`, `Explicit`).
    ///
    /// # Returns
    ///
    /// Returns `Ok(Array<InElem, Dim<[Ix; N]>>)` containing the convolution result, or an `Err(Error<N>)` if the operation fails.
    ///
    /// # Example
    ///
    /// ```rust
    /// use ndarray::array;
    /// use ndarray_conv::{ConvFFTExt, ConvMode, PaddingMode};
    ///
    /// let input = array![[1.0, 2.0], [3.0, 4.0]];
    /// let kernel = array![[1.0, 0.0], [0.0, 1.0]];
    /// let result = input.conv_fft(&kernel, ConvMode::Same, PaddingMode::Zeros).unwrap();
    /// ```
    fn conv_fft(
        &self,
        kernel: impl IntoKernelWithDilation<'a, SK, N>,
        conv_mode: ConvMode<N>,
        padding_mode: PaddingMode<N, InElem>,
    ) -> Result<Array<InElem, Dim<[Ix; N]>>, crate::Error<N>>;

    /// Performs an FFT-accelerated convolution using a provided processor.
    ///
    /// This method is useful when performing multiple convolutions, as it allows
    /// reusing the FFT planner and avoiding redundant initialization overhead.
    ///
    /// # Arguments
    ///
    /// * `kernel`: The convolution kernel.
    /// * `conv_mode`: The convolution mode.
    /// * `padding_mode`: The padding mode.
    /// * `fft_processor`: A mutable reference to an FFT processor instance.
    ///
    /// # Returns
    ///
    /// Returns `Ok(Array<InElem, Dim<[Ix; N]>>)` containing the convolution result, or an `Err(Error<N>)` if the operation fails.
    ///
    /// # Example
    ///
    /// ```rust
    /// use ndarray::array;
    /// use ndarray_conv::{ConvFFTExt, ConvMode, PaddingMode, get_fft_processor};
    ///
    /// let input1 = array![[1.0, 2.0], [3.0, 4.0]];
    /// let input2 = array![[5.0, 6.0], [7.0, 8.0]];
    /// let kernel = array![[1.0, 0.0], [0.0, 1.0]];
    ///
    /// // Reuse the same processor for multiple convolutions
    /// let mut proc = get_fft_processor::<f32, f32>();
    /// let result1 = input1.conv_fft_with_processor(&kernel, ConvMode::Same, PaddingMode::Zeros, &mut proc).unwrap();
    /// let result2 = input2.conv_fft_with_processor(&kernel, ConvMode::Same, PaddingMode::Zeros, &mut proc).unwrap();
    /// ```
    fn conv_fft_with_processor(
        &self,
        kernel: impl IntoKernelWithDilation<'a, SK, N>,
        conv_mode: ConvMode<N>,
        padding_mode: PaddingMode<N, InElem>,
        fft_processor: &mut impl Processor<T, InElem>,
    ) -> Result<Array<InElem, Dim<[Ix; N]>>, crate::Error<N>>;

    // fn conv_fft_bake(
    //     &self,
    //     kernel: impl IntoKernelWithDilation<'a, SK, N>,
    //     conv_mode: ConvMode<N>,
    //     padding_mode: PaddingMode<N, T>,
    // ) -> Result<Baked<T, SK, N>, crate::Error<N>>;

    // fn conv_fft_with_baked(&self, baked: &mut Baked<T, SK, N>) -> Array<T, Dim<[Ix; N]>>;
}

impl<'a, T, InElem, S, SK, const N: usize> ConvFFTExt<'a, T, InElem, S, SK, N>
    for ArrayBase<S, Dim<[Ix; N]>>
where
    T: NumAssign + FftNum,
    InElem: processor::GetProcessor<T, InElem> + NumAssign + Copy + 'a,
    S: Data<Elem = InElem> + 'a,
    SK: Data<Elem = InElem> + 'a,
    [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>,
    SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>:
        SliceArg<Dim<[Ix; N]>, OutDim = Dim<[Ix; N]>>,
    Dim<[Ix; N]>: RemoveAxis,
{
    // fn conv_fft_bake(
    //     &self,
    //     kernel: impl IntoKernelWithDilation<'a, SK, N>,
    //     conv_mode: ConvMode<N>,
    //     padding_mode: PaddingMode<N, T>,
    // ) -> Result<Baked<T, SK, N>, crate::Error<N>> {
    //     let mut fft_processor = Processor::default();

    //     let kwd = kernel.into_kernel_with_dilation();

    //     let data_raw_dim = self.raw_dim();
    //     if self.shape().iter().product::<usize>() == 0 {
    //         return Err(crate::Error::DataShape(data_raw_dim));
    //     }

    //     let kernel_raw_dim = kwd.kernel.raw_dim();
    //     if kwd.kernel.shape().iter().product::<usize>() == 0 {
    //         return Err(crate::Error::DataShape(kernel_raw_dim));
    //     }

    //     let kernel_raw_dim_with_dilation: [usize; N] =
    //         std::array::from_fn(|i| kernel_raw_dim[i] * kwd.dilation[i] - kwd.dilation[i] + 1);

    //     let cm = conv_mode.unfold(&kwd);

    //     let pds_raw_dim: [usize; N] =
    //         std::array::from_fn(|i| (data_raw_dim[i] + cm.padding[i][0] + cm.padding[i][1]));
    //     if !(0..N).all(|i| kernel_raw_dim_with_dilation[i] <= pds_raw_dim[i]) {
    //         return Err(crate::Error::MismatchShape(
    //             conv_mode,
    //             kernel_raw_dim_with_dilation,
    //         ));
    //     }

    //     let fft_size = good_size::compute::<N>(&std::array::from_fn(|i| {
    //         pds_raw_dim[i].max(kernel_raw_dim_with_dilation[i])
    //     }));

    //     let scratch = fft_processor.get_scratch(fft_size);

    //     let kernel_pd = padding::kernel(kwd, fft_size);

    //     Ok(Baked {
    //         fft_size,
    //         fft_processor,
    //         scratch,
    //         cm,
    //         padding_mode,
    //         kernel_raw_dim_with_dilation,
    //         pds_raw_dim,
    //         kernel_pd,
    //         _sk_hint: PhantomData,
    //     })
    // }

    // fn conv_fft_with_baked(&self, baked: &mut Baked<T, SK, N>) -> Array<T, Dim<[Ix; N]>> {
    //     let Baked {
    //         scratch,
    //         fft_processor,
    //         fft_size,
    //         cm,
    //         padding_mode,
    //         kernel_pd,
    //         kernel_raw_dim_with_dilation,
    //         pds_raw_dim,
    //         _sk_hint,
    //     } = baked;

    //     let mut data_pd = padding::data(self, *padding_mode, cm.padding, *fft_size);

    //     let mut data_pd_fft = fft_processor.forward_with_scratch(&mut data_pd, scratch);
    //     let kernel_pd_fft = fft_processor.forward_with_scratch(kernel_pd, scratch);

    //     data_pd_fft.zip_mut_with(&kernel_pd_fft, |d, k| *d *= *k);
    //     // let mul_spec = data_pd_fft * kernel_pd_fft;

    //     let output = fft_processor.backward(data_pd_fft);

    //     output.slice_move(unsafe {
    //         SliceInfo::new(std::array::from_fn(|i| SliceInfoElem::Slice {
    //             start: kernel_raw_dim_with_dilation[i] as isize - 1,
    //             end: Some((pds_raw_dim[i]) as isize),
    //             step: cm.strides[i] as isize,
    //         }))
    //         .unwrap()
    //     })
    // }

    fn conv_fft(
        &self,
        kernel: impl IntoKernelWithDilation<'a, SK, N>,
        conv_mode: ConvMode<N>,
        padding_mode: PaddingMode<N, InElem>,
    ) -> Result<Array<InElem, Dim<[Ix; N]>>, crate::Error<N>> {
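        // Build a default processor for this element type and delegate to
        // `conv_fft_with_processor`.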
        let mut p = InElem::get_processor();
        self.conv_fft_with_processor(kernel, conv_mode, padding_mode, &mut p)
    }

    fn conv_fft_with_processor(
        &self,
        kernel: impl IntoKernelWithDilation<'a, SK, N>,
        conv_mode: ConvMode<N>,
        padding_mode: PaddingMode<N, InElem>,
        fft_processor: &mut impl Processor<T, InElem>,
    ) -> Result<Array<InElem, Dim<[Ix; N]>>, crate::Error<N>> {
        let kwd = kernel.into_kernel_with_dilation();

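        // Reject empty data or kernel arrays up front.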
        let data_raw_dim = self.raw_dim();
        if self.shape().iter().product::<usize>() == 0 {
            return Err(crate::Error::DataShape(data_raw_dim));
        }

        let kernel_raw_dim = kwd.kernel.raw_dim();
        if kwd.kernel.shape().iter().product::<usize>() == 0 {
            return Err(crate::Error::DataShape(kernel_raw_dim));
        }

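        // Effective kernel extent per axis once dilation is applied: (k - 1) * d + 1.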
        let kernel_raw_dim_with_dilation: [usize; N] =
            std::array::from_fn(|i| kernel_raw_dim[i] * kwd.dilation[i] - kwd.dilation[i] + 1);

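        // Resolve the convolution mode into explicit per-axis padding and strides.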
        let cm = conv_mode.unfold(&kwd);

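        // Size of the padded data per axis; the dilated kernel must fit inside it.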
        let pds_raw_dim: [usize; N] =
            std::array::from_fn(|i| data_raw_dim[i] + cm.padding[i][0] + cm.padding[i][1]);
        if !(0..N).all(|i| kernel_raw_dim_with_dilation[i] <= pds_raw_dim[i]) {
            return Err(crate::Error::MismatchShape(
                conv_mode,
                kernel_raw_dim_with_dilation,
            ));
        }

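        // Pick an FFT-friendly transform size that covers both the padded data
        // and the dilated kernel on every axis.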
        let fft_size = good_size::compute::<N>(&std::array::from_fn(|i| {
            pds_raw_dim[i].max(kernel_raw_dim_with_dilation[i])
        }));

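        // Pad the data (using the requested border mode) and the kernel out to
        // the common FFT size.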
        let mut data_pd = padding::data(self, padding_mode, cm.padding, fft_size);
        let mut kernel_pd = padding::kernel(kwd, fft_size);

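        // Convolution theorem: transform both operands, multiply their spectra
        // element-wise, then transform back.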
        let mut data_pd_fft = fft_processor.forward(&mut data_pd);
        let kernel_pd_fft = fft_processor.forward(&mut kernel_pd);

        data_pd_fft.zip_mut_with(&kernel_pd_fft, |d, k| *d *= *k);
        // let mul_spec = data_pd_fft * kernel_pd_fft;

        let output = fft_processor.backward(&mut data_pd_fft);

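        // Crop the full linear-convolution result down to the output window
        // implied by the convolution mode, applying the requested strides.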
        let output = output.slice_move(unsafe {
            SliceInfo::new(std::array::from_fn(|i| SliceInfoElem::Slice {
                start: kernel_raw_dim_with_dilation[i] as isize - 1,
                end: Some((pds_raw_dim[i]) as isize),
                step: cm.strides[i] as isize,
            }))
            .unwrap()
        });

        Ok(output)
    }
}

#[cfg(test)]
mod tests;