ndarray_conv/conv_fft/
mod.rs

1//! Provides FFT-accelerated convolution operations.
2//!
3//! This module offers the `ConvFFTExt` trait, which extends `ndarray`
4//! with FFT-based convolution methods.
5
6use std::fmt::Debug;
7
8use ndarray::{
9    Array, ArrayBase, Data, Dim, IntoDimension, Ix, RawData, RemoveAxis, SliceArg, SliceInfo,
10    SliceInfoElem,
11};
12use num::traits::NumAssign;
13use rustfft::FftNum;
14
15use crate::{dilation::IntoKernelWithDilation, ConvMode, PaddingMode};
16
17mod good_size;
18mod padding;
19mod processor;
20
21// pub use fft::Processor;
22pub use processor::{get as get_processor, GetProcessor, Processor};
23
24// /// Represents a "baked" convolution operation.
25// ///
26// /// This struct holds pre-computed data for performing FFT-accelerated
27// /// convolutions, including the FFT size, FFT processor, scratch space,
28// /// and padding information. It's designed to optimize repeated
29// /// convolutions with the same kernel and settings.
30// pub struct Baked<T, SK, const N: usize>
31// where
32//     T: NumAssign + Debug + Copy,
33//     SK: RawData,
34// {
35//     fft_size: [usize; N],
36//     fft_processor: impl Processor<T>,
37//     scratch: Vec<Complex<T>>,
38//     cm: ExplicitConv<N>,
39//     padding_mode: PaddingMode<N, T>,
40//     kernel_raw_dim_with_dilation: [usize; N],
41//     pds_raw_dim: [usize; N],
42//     kernel_pd: Array<T, Dim<[Ix; N]>>,
43//     _sk_hint: PhantomData<SK>,
44// }
45
46/// Extends `ndarray`'s `ArrayBase` with FFT-accelerated convolution operations.
47///
48/// This trait adds the `conv_fft` and `conv_fft_with_processor` methods to `ArrayBase`,
49/// enabling efficient FFT-based convolutions on N-dimensional arrays.
50///
51/// # Type Parameters
52///
53/// *   `T`: The numeric type used internally for FFT operations. Must be a floating-point type that implements `FftNum`.
54/// *   `InElem`: The element type of the input arrays. Can be real (`T`) or complex (`Complex<T>`).
55/// *   `S`: The data storage type of the input array.
56/// *   `SK`: The data storage type of the kernel array.
57///
58/// # Methods
59///
60/// *   `conv_fft`: Performs an FFT-accelerated convolution with default settings.
61/// *   `conv_fft_with_processor`: Performs an FFT-accelerated convolution using a provided `Processor` instance, allowing for reuse of FFT plans across multiple convolutions for better performance.
62///
63/// # Example
64///
65/// ```rust
66/// use ndarray::prelude::*;
67/// use ndarray_conv::{ConvFFTExt, ConvMode, PaddingMode};
68///
69/// let arr = array![[1., 2.], [3., 4.]];
70/// let kernel = array![[1., 0.], [0., 1.]];
71/// let result = arr.conv_fft(&kernel, ConvMode::Same, PaddingMode::Zeros).unwrap();
72/// ```
73///
74/// # Notes
75///
76/// FFT-based convolutions are generally faster for larger kernels but may have higher overhead for smaller kernels.
77/// Use standard convolution (`ConvExt::conv`) for small kernels or when working with integer types.
78///
79/// # Performance Tips
80///
81/// For repeated convolutions with different data but the same kernel and settings, consider using
82/// `conv_fft_with_processor` to reuse the FFT planner and avoid redundant setup overhead.
83pub trait ConvFFTExt<'a, T, InElem, S, SK, const N: usize>
84where
85    T: NumAssign + Copy + FftNum,
86    InElem: processor::GetProcessor<T, InElem> + Copy + NumAssign,
87    S: RawData,
88    SK: RawData,
89{
90    /// Performs an FFT-accelerated convolution operation.
91    ///
92    /// This method convolves the input array with a given kernel using FFT,
93    /// which is typically faster for larger kernels.
94    ///
95    /// # Arguments
96    ///
97    /// * `kernel`: The convolution kernel. Can be a reference to an array, or an array with dilation settings.
98    /// * `conv_mode`: The convolution mode (`Full`, `Same`, `Valid`, `Custom`, `Explicit`).
99    /// * `padding_mode`: The padding mode (`Zeros`, `Const`, `Reflect`, `Replicate`, `Circular`, `Custom`, `Explicit`).
100    ///
101    /// # Returns
102    ///
103    /// Returns `Ok(Array<InElem, Dim<[Ix; N]>>)` containing the convolution result, or an `Err(Error<N>)` if the operation fails.
104    ///
105    /// # Example
106    ///
107    /// ```rust
108    /// use ndarray::array;
109    /// use ndarray_conv::{ConvFFTExt, ConvMode, PaddingMode};
110    ///
111    /// let input = array![[1.0, 2.0], [3.0, 4.0]];
112    /// let kernel = array![[1.0, 0.0], [0.0, 1.0]];
113    /// let result = input.conv_fft(&kernel, ConvMode::Same, PaddingMode::Zeros).unwrap();
114    /// ```
115    fn conv_fft(
116        &self,
117        kernel: impl IntoKernelWithDilation<'a, SK, N>,
118        conv_mode: ConvMode<N>,
119        padding_mode: PaddingMode<N, InElem>,
120    ) -> Result<Array<InElem, Dim<[Ix; N]>>, crate::Error<N>>;
121
122    /// Performs an FFT-accelerated convolution using a provided processor.
123    ///
124    /// This method is useful when performing multiple convolutions, as it allows
125    /// reusing the FFT planner and avoiding redundant initialization overhead.
126    ///
127    /// # Arguments
128    ///
129    /// * `kernel`: The convolution kernel.
130    /// * `conv_mode`: The convolution mode.
131    /// * `padding_mode`: The padding mode.
132    /// * `fft_processor`: A mutable reference to an FFT processor instance.
133    ///
134    /// # Returns
135    ///
136    /// Returns `Ok(Array<InElem, Dim<[Ix; N]>>)` containing the convolution result, or an `Err(Error<N>)` if the operation fails.
137    ///
138    /// # Example
139    ///
140    /// ```rust
141    /// use ndarray::array;
142    /// use ndarray_conv::{ConvFFTExt, ConvMode, PaddingMode, get_fft_processor};
143    ///
144    /// let input1 = array![[1.0, 2.0], [3.0, 4.0]];
145    /// let input2 = array![[5.0, 6.0], [7.0, 8.0]];
146    /// let kernel = array![[1.0, 0.0], [0.0, 1.0]];
147    ///
148    /// // Reuse the same processor for multiple convolutions
149    /// let mut proc = get_fft_processor::<f32, f32>();
150    /// let result1 = input1.conv_fft_with_processor(&kernel, ConvMode::Same, PaddingMode::Zeros, &mut proc).unwrap();
151    /// let result2 = input2.conv_fft_with_processor(&kernel, ConvMode::Same, PaddingMode::Zeros, &mut proc).unwrap();
152    /// ```
153    fn conv_fft_with_processor(
154        &self,
155        kernel: impl IntoKernelWithDilation<'a, SK, N>,
156        conv_mode: ConvMode<N>,
157        padding_mode: PaddingMode<N, InElem>,
158        fft_processor: &mut impl Processor<T, InElem>,
159    ) -> Result<Array<InElem, Dim<[Ix; N]>>, crate::Error<N>>;
160
161    // fn conv_fft_bake(
162    //     &self,
163    //     kernel: impl IntoKernelWithDilation<'a, SK, N>,
164    //     conv_mode: ConvMode<N>,
165    //     padding_mode: PaddingMode<N, T>,
166    // ) -> Result<Baked<T, SK, N>, crate::Error<N>>;
167
168    // fn conv_fft_with_baked(&self, baked: &mut Baked<T, SK, N>) -> Array<T, Dim<[Ix; N]>>;
169}
170
171impl<'a, T, InElem, S, SK, const N: usize> ConvFFTExt<'a, T, InElem, S, SK, N>
172    for ArrayBase<S, Dim<[Ix; N]>>
173where
174    T: NumAssign + FftNum,
175    InElem: processor::GetProcessor<T, InElem> + NumAssign + Copy + Debug,
176    S: Data<Elem = InElem> + 'a,
177    SK: Data<Elem = InElem> + 'a,
178    [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>,
179    SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>:
180        SliceArg<Dim<[Ix; N]>, OutDim = Dim<[Ix; N]>>,
181    Dim<[Ix; N]>: RemoveAxis,
182{
183    // fn conv_fft_bake(
184    //     &self,
185    //     kernel: impl IntoKernelWithDilation<'a, SK, N>,
186    //     conv_mode: ConvMode<N>,
187    //     padding_mode: PaddingMode<N, T>,
188    // ) -> Result<Baked<T, SK, N>, crate::Error<N>> {
189    //     let mut fft_processor = Processor::default();
190
191    //     let kwd = kernel.into_kernel_with_dilation();
192
193    //     let data_raw_dim = self.raw_dim();
194    //     if self.shape().iter().product::<usize>() == 0 {
195    //         return Err(crate::Error::DataShape(data_raw_dim));
196    //     }
197
198    //     let kernel_raw_dim = kwd.kernel.raw_dim();
199    //     if kwd.kernel.shape().iter().product::<usize>() == 0 {
200    //         return Err(crate::Error::DataShape(kernel_raw_dim));
201    //     }
202
203    //     let kernel_raw_dim_with_dilation: [usize; N] =
204    //         std::array::from_fn(|i| kernel_raw_dim[i] * kwd.dilation[i] - kwd.dilation[i] + 1);
205
206    //     let cm = conv_mode.unfold(&kwd);
207
208    //     let pds_raw_dim: [usize; N] =
209    //         std::array::from_fn(|i| (data_raw_dim[i] + cm.padding[i][0] + cm.padding[i][1]));
210    //     if !(0..N).all(|i| kernel_raw_dim_with_dilation[i] <= pds_raw_dim[i]) {
211    //         return Err(crate::Error::MismatchShape(
212    //             conv_mode,
213    //             kernel_raw_dim_with_dilation,
214    //         ));
215    //     }
216
217    //     let fft_size = good_size::compute::<N>(&std::array::from_fn(|i| {
218    //         pds_raw_dim[i].max(kernel_raw_dim_with_dilation[i])
219    //     }));
220
221    //     let scratch = fft_processor.get_scratch(fft_size);
222
223    //     let kernel_pd = padding::kernel(kwd, fft_size);
224
225    //     Ok(Baked {
226    //         fft_size,
227    //         fft_processor,
228    //         scratch,
229    //         cm,
230    //         padding_mode,
231    //         kernel_raw_dim_with_dilation,
232    //         pds_raw_dim,
233    //         kernel_pd,
234    //         _sk_hint: PhantomData,
235    //     })
236    // }
237
238    // fn conv_fft_with_baked(&self, baked: &mut Baked<T, SK, N>) -> Array<T, Dim<[Ix; N]>> {
239    //     let Baked {
240    //         scratch,
241    //         fft_processor,
242    //         fft_size,
243    //         cm,
244    //         padding_mode,
245    //         kernel_pd,
246    //         kernel_raw_dim_with_dilation,
247    //         pds_raw_dim,
248    //         _sk_hint,
249    //     } = baked;
250
251    //     let mut data_pd = padding::data(self, *padding_mode, cm.padding, *fft_size);
252
253    //     let mut data_pd_fft = fft_processor.forward_with_scratch(&mut data_pd, scratch);
254    //     let kernel_pd_fft = fft_processor.forward_with_scratch(kernel_pd, scratch);
255
256    //     data_pd_fft.zip_mut_with(&kernel_pd_fft, |d, k| *d *= *k);
257    //     // let mul_spec = data_pd_fft * kernel_pd_fft;
258
259    //     let output = fft_processor.backward(data_pd_fft);
260
261    //     output.slice_move(unsafe {
262    //         SliceInfo::new(std::array::from_fn(|i| SliceInfoElem::Slice {
263    //             start: kernel_raw_dim_with_dilation[i] as isize - 1,
264    //             end: Some((pds_raw_dim[i]) as isize),
265    //             step: cm.strides[i] as isize,
266    //         }))
267    //         .unwrap()
268    //     })
269    // }
270
271    fn conv_fft(
272        &self,
273        kernel: impl IntoKernelWithDilation<'a, SK, N>,
274        conv_mode: ConvMode<N>,
275        padding_mode: PaddingMode<N, InElem>,
276    ) -> Result<Array<InElem, Dim<[Ix; N]>>, crate::Error<N>> {
277        let mut p = InElem::get_processor();
278        self.conv_fft_with_processor(kernel, conv_mode, padding_mode, &mut p)
279    }
280
281    fn conv_fft_with_processor(
282        &self,
283        kernel: impl IntoKernelWithDilation<'a, SK, N>,
284        conv_mode: ConvMode<N>,
285        padding_mode: PaddingMode<N, InElem>,
286        fft_processor: &mut impl Processor<T, InElem>,
287    ) -> Result<Array<InElem, Dim<[Ix; N]>>, crate::Error<N>> {
288        let kwd = kernel.into_kernel_with_dilation();
289
290        let data_raw_dim = self.raw_dim();
291        if self.shape().iter().product::<usize>() == 0 {
292            return Err(crate::Error::DataShape(data_raw_dim));
293        }
294
295        let kernel_raw_dim = kwd.kernel.raw_dim();
296        if kwd.kernel.shape().iter().product::<usize>() == 0 {
297            return Err(crate::Error::DataShape(kernel_raw_dim));
298        }
299
300        let kernel_raw_dim_with_dilation: [usize; N] =
301            std::array::from_fn(|i| kernel_raw_dim[i] * kwd.dilation[i] - kwd.dilation[i] + 1);
302
303        let cm = conv_mode.unfold(&kwd);
304
305        let pds_raw_dim: [usize; N] =
306            std::array::from_fn(|i| data_raw_dim[i] + cm.padding[i][0] + cm.padding[i][1]);
307        if !(0..N).all(|i| kernel_raw_dim_with_dilation[i] <= pds_raw_dim[i]) {
308            return Err(crate::Error::MismatchShape(
309                conv_mode,
310                kernel_raw_dim_with_dilation,
311            ));
312        }
313
314        let fft_size = good_size::compute::<N>(&std::array::from_fn(|i| {
315            pds_raw_dim[i].max(kernel_raw_dim_with_dilation[i])
316        }));
317
318        let mut data_pd = padding::data(self, padding_mode, cm.padding, fft_size);
319        let mut kernel_pd = padding::kernel(kwd, fft_size);
320
321        let mut data_pd_fft = fft_processor.forward(&mut data_pd);
322        let kernel_pd_fft = fft_processor.forward(&mut kernel_pd);
323
324        data_pd_fft.zip_mut_with(&kernel_pd_fft, |d, k| *d *= *k);
325        // let mul_spec = data_pd_fft * kernel_pd_fft;
326
327        let output = fft_processor.backward(&mut data_pd_fft);
328
329        let output = output.slice_move(unsafe {
330            SliceInfo::new(std::array::from_fn(|i| SliceInfoElem::Slice {
331                start: kernel_raw_dim_with_dilation[i] as isize - 1,
332                end: Some((pds_raw_dim[i]) as isize),
333                step: cm.strides[i] as isize,
334            }))
335            .unwrap()
336        });
337
338        Ok(output)
339    }
340}
341
342#[cfg(test)]
343mod tests;