// ndarray_conv/conv_fft/mod.rs

//! Provides FFT-accelerated convolution operations.
//!
//! This module offers the `ConvFFTExt` trait, which extends `ndarray`
//! with FFT-based convolution methods.
5
6use std::{fmt::Debug, marker::PhantomData};
7
8use ndarray::{
9    Array, ArrayBase, Data, Dim, IntoDimension, Ix, RawData, RemoveAxis, SliceArg, SliceInfo,
10    SliceInfoElem,
11};
12use num::{traits::NumAssign, Complex};
13use rustfft::FftNum;
14
15use crate::{conv::ExplicitConv, dilation::IntoKernelWithDilation, ConvMode, PaddingMode};
16
17mod fft;
18mod good_size;
19mod padding;
20
21pub use fft::Processor;
22
23/// Represents a "baked" convolution operation.
24///
25/// This struct holds pre-computed data for performing FFT-accelerated
26/// convolutions, including the FFT size, FFT processor, scratch space,
27/// and padding information. It's designed to optimize repeated
28/// convolutions with the same kernel and settings.
29pub struct Baked<T, SK, const N: usize>
30where
31    T: NumAssign + Debug + FftNum,
32    SK: RawData,
33{
34    fft_size: [usize; N],
35    fft_processor: Processor<T>,
36    scratch: Vec<Complex<T>>,
37    cm: ExplicitConv<N>,
38    padding_mode: PaddingMode<N, T>,
39    kernel_raw_dim_with_dilation: [usize; N],
40    pds_raw_dim: [usize; N],
41    kernel_pd: Array<T, Dim<[Ix; N]>>,
42    _sk_hint: PhantomData<SK>,
43}
44
45/// Extends `ndarray`'s `ArrayBase` with FFT-accelerated convolution operations.
46///
47/// This trait adds the `conv_fft` and `conv_fft_with_processor` methods to `ArrayBase`,
48/// enabling efficient FFT-based convolutions on N-dimensional arrays.
49///
50/// # Type Parameters
51///
52/// *   `T`: The numeric type of the array elements. Must be a floating-point type that implements `FftNum`.
53/// *   `S`: The data storage type of the input array.
54/// *   `SK`: The data storage type of the kernel array.
55///
56/// # Methods
57///
58/// *   `conv_fft`: Performs an FFT-accelerated convolution with default settings.
59/// *   `conv_fft_with_processor`: Performs an FFT-accelerated convolution using a provided `Processor` instance, allowing for reuse of FFT plans and scratch space.
60/// *   `conv_fft_bake`: Precomputes and stores necessary data for performing repeated convolutions in the form of `Baked`.
61/// *   `conv_fft_with_baked`: Performs a convolution with the provided `Baked` data.
62///
63/// # Example
64///
65/// ```rust
66/// use ndarray::prelude::*;
67/// use ndarray_conv::{ConvFFTExt, ConvMode, PaddingMode};
68///
69/// let arr = array![[1., 2.], [3., 4.]];
70/// let kernel = array![[1., 0.], [0., 1.]];
71/// let result = arr.conv_fft(&kernel, ConvMode::Same, PaddingMode::Zeros).unwrap();
72/// ```
73///
74/// # Notes
75///
76/// FFT-based convolutions are generally faster for larger kernels but may have higher overhead for smaller kernels.
77pub trait ConvFFTExt<'a, T, S, SK, const N: usize>
78where
79    T: FftNum + NumAssign,
80    S: RawData,
81    SK: RawData,
82{
83    fn conv_fft(
84        &self,
85        kernel: impl IntoKernelWithDilation<'a, SK, N>,
86        conv_mode: ConvMode<N>,
87        padding_mode: PaddingMode<N, T>,
88    ) -> Result<Array<T, Dim<[Ix; N]>>, crate::Error<N>>;
89
90    fn conv_fft_with_processor(
91        &self,
92        kernel: impl IntoKernelWithDilation<'a, SK, N>,
93        conv_mode: ConvMode<N>,
94        padding_mode: PaddingMode<N, T>,
95        fft_processor: &mut Processor<T>,
96    ) -> Result<Array<T, Dim<[Ix; N]>>, crate::Error<N>>;
97
98    // fn conv_fft_bake(
99    //     &self,
100    //     kernel: impl IntoKernelWithDilation<'a, SK, N>,
101    //     conv_mode: ConvMode<N>,
102    //     padding_mode: PaddingMode<N, T>,
103    // ) -> Result<Baked<T, SK, N>, crate::Error<N>>;
104
105    // fn conv_fft_with_baked(&self, baked: &mut Baked<T, SK, N>) -> Array<T, Dim<[Ix; N]>>;
106}
107
108impl<'a, T, S, SK, const N: usize> ConvFFTExt<'a, T, S, SK, N> for ArrayBase<S, Dim<[Ix; N]>>
109where
110    T: NumAssign + FftNum,
111    S: Data<Elem = T> + 'a,
112    SK: Data<Elem = T> + 'a,
113    [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>,
114    SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>:
115        SliceArg<Dim<[Ix; N]>, OutDim = Dim<[Ix; N]>>,
116    Dim<[Ix; N]>: RemoveAxis,
117{
118    // fn conv_fft_bake(
119    //     &self,
120    //     kernel: impl IntoKernelWithDilation<'a, SK, N>,
121    //     conv_mode: ConvMode<N>,
122    //     padding_mode: PaddingMode<N, T>,
123    // ) -> Result<Baked<T, SK, N>, crate::Error<N>> {
124    //     let mut fft_processor = Processor::default();
125
126    //     let kwd = kernel.into_kernel_with_dilation();
127
128    //     let data_raw_dim = self.raw_dim();
129    //     if self.shape().iter().product::<usize>() == 0 {
130    //         return Err(crate::Error::DataShape(data_raw_dim));
131    //     }
132
133    //     let kernel_raw_dim = kwd.kernel.raw_dim();
134    //     if kwd.kernel.shape().iter().product::<usize>() == 0 {
135    //         return Err(crate::Error::DataShape(kernel_raw_dim));
136    //     }
137
138    //     let kernel_raw_dim_with_dilation: [usize; N] =
139    //         std::array::from_fn(|i| kernel_raw_dim[i] * kwd.dilation[i] - kwd.dilation[i] + 1);
140
141    //     let cm = conv_mode.unfold(&kwd);
142
143    //     let pds_raw_dim: [usize; N] =
144    //         std::array::from_fn(|i| (data_raw_dim[i] + cm.padding[i][0] + cm.padding[i][1]));
145    //     if !(0..N).all(|i| kernel_raw_dim_with_dilation[i] <= pds_raw_dim[i]) {
146    //         return Err(crate::Error::MismatchShape(
147    //             conv_mode,
148    //             kernel_raw_dim_with_dilation,
149    //         ));
150    //     }
151
152    //     let fft_size = good_size::compute::<N>(&std::array::from_fn(|i| {
153    //         pds_raw_dim[i].max(kernel_raw_dim_with_dilation[i])
154    //     }));
155
156    //     let scratch = fft_processor.get_scratch(fft_size);
157
158    //     let kernel_pd = padding::kernel(kwd, fft_size);
159
160    //     Ok(Baked {
161    //         fft_size,
162    //         fft_processor,
163    //         scratch,
164    //         cm,
165    //         padding_mode,
166    //         kernel_raw_dim_with_dilation,
167    //         pds_raw_dim,
168    //         kernel_pd,
169    //         _sk_hint: PhantomData,
170    //     })
171    // }
172
173    // fn conv_fft_with_baked(&self, baked: &mut Baked<T, SK, N>) -> Array<T, Dim<[Ix; N]>> {
174    //     let Baked {
175    //         scratch,
176    //         fft_processor,
177    //         fft_size,
178    //         cm,
179    //         padding_mode,
180    //         kernel_pd,
181    //         kernel_raw_dim_with_dilation,
182    //         pds_raw_dim,
183    //         _sk_hint,
184    //     } = baked;
185
186    //     let mut data_pd = padding::data(self, *padding_mode, cm.padding, *fft_size);
187
188    //     let mut data_pd_fft = fft_processor.forward_with_scratch(&mut data_pd, scratch);
189    //     let kernel_pd_fft = fft_processor.forward_with_scratch(kernel_pd, scratch);
190
191    //     data_pd_fft.zip_mut_with(&kernel_pd_fft, |d, k| *d *= *k);
192    //     // let mul_spec = data_pd_fft * kernel_pd_fft;
193
194    //     let output = fft_processor.backward(data_pd_fft);
195
196    //     output.slice_move(unsafe {
197    //         SliceInfo::new(std::array::from_fn(|i| SliceInfoElem::Slice {
198    //             start: kernel_raw_dim_with_dilation[i] as isize - 1,
199    //             end: Some((pds_raw_dim[i]) as isize),
200    //             step: cm.strides[i] as isize,
201    //         }))
202    //         .unwrap()
203    //     })
204    // }
205
206    fn conv_fft(
207        &self,
208        kernel: impl IntoKernelWithDilation<'a, SK, N>,
209        conv_mode: ConvMode<N>,
210        padding_mode: PaddingMode<N, T>,
211    ) -> Result<Array<T, Dim<[Ix; N]>>, crate::Error<N>> {
212        let mut p = Processor::default();
213        self.conv_fft_with_processor(kernel, conv_mode, padding_mode, &mut p)
214    }
215
216    fn conv_fft_with_processor(
217        &self,
218        kernel: impl IntoKernelWithDilation<'a, SK, N>,
219        conv_mode: ConvMode<N>,
220        padding_mode: PaddingMode<N, T>,
221        fft_processor: &mut Processor<T>,
222    ) -> Result<Array<T, Dim<[Ix; N]>>, crate::Error<N>> {
223        let kwd = kernel.into_kernel_with_dilation();
224
225        let data_raw_dim = self.raw_dim();
226        if self.shape().iter().product::<usize>() == 0 {
227            return Err(crate::Error::DataShape(data_raw_dim));
228        }
229
230        let kernel_raw_dim = kwd.kernel.raw_dim();
231        if kwd.kernel.shape().iter().product::<usize>() == 0 {
232            return Err(crate::Error::DataShape(kernel_raw_dim));
233        }
234
235        let kernel_raw_dim_with_dilation: [usize; N] =
236            std::array::from_fn(|i| kernel_raw_dim[i] * kwd.dilation[i] - kwd.dilation[i] + 1);
237
238        let cm = conv_mode.unfold(&kwd);
239
240        let pds_raw_dim: [usize; N] =
241            std::array::from_fn(|i| (data_raw_dim[i] + cm.padding[i][0] + cm.padding[i][1]));
242        if !(0..N).all(|i| kernel_raw_dim_with_dilation[i] <= pds_raw_dim[i]) {
243            return Err(crate::Error::MismatchShape(
244                conv_mode,
245                kernel_raw_dim_with_dilation,
246            ));
247        }
248
249        let fft_size = good_size::compute::<N>(&std::array::from_fn(|i| {
250            pds_raw_dim[i].max(kernel_raw_dim_with_dilation[i])
251        }));
252
253        let mut data_pd = padding::data(self, padding_mode, cm.padding, fft_size);
254        let mut kernel_pd = padding::kernel(kwd, fft_size);
255
256        let mut data_pd_fft = fft_processor.forward(&mut data_pd);
257        let kernel_pd_fft = fft_processor.forward(&mut kernel_pd);
258
259        data_pd_fft.zip_mut_with(&kernel_pd_fft, |d, k| *d *= *k);
260        // let mul_spec = data_pd_fft * kernel_pd_fft;
261
262        let output = fft_processor.backward(data_pd_fft);
263
264        let output = output.slice_move(unsafe {
265            SliceInfo::new(std::array::from_fn(|i| SliceInfoElem::Slice {
266                start: kernel_raw_dim_with_dilation[i] as isize - 1,
267                end: Some((pds_raw_dim[i]) as isize),
268                step: cm.strides[i] as isize,
269            }))
270            .unwrap()
271        });
272
273        Ok(output)
274    }
275}
276
#[cfg(test)]
mod tests {
    use ndarray::array;

    use crate::{dilation::WithDilation, ConvExt, ReverseKernel};

    use super::*;

    /// The FFT path must agree with the direct path (values and shape) on a
    /// non-square 6x2 input in `Same` mode.
    #[test]
    fn correct_size() {
        let arr = array![[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]];
        let kernel = array![[1, 0], [3, 1]];

        let res_normal = arr
            .conv(&kernel, ConvMode::Same, PaddingMode::Replicate)
            .unwrap();

        let res_fft = arr
            .map(|&x| x as f64)
            // The padding does not matter here, it is only used to calculate the correct size
            .conv_fft(
                &kernel.map(|&x| x as f64),
                ConvMode::Same,
                PaddingMode::Replicate,
            )
            .unwrap()
            .map(|x| x.round() as i32);

        assert_eq!(res_normal, res_fft);
    }

    /// Compares the FFT path against the direct path for several
    /// dimensionalities, dilations and convolution modes.
    #[test]
    fn conv_fft() {
        // 3-D data, 2x3x3 kernel, `Same` mode with zero padding.
        let arr = array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]];
        let kernel = array![
            [[1, 1, 1], [1, 1, 1], [1, 1, 1]],
            [[1, 1, 1], [1, 1, 1], [1, 1, 1]],
        ];

        let res_normal = arr
            .conv(&kernel, ConvMode::Same, PaddingMode::Zeros)
            .unwrap();

        let res_fft = arr
            .map(|&x| x as f32)
            .conv_fft(
                &kernel.map(|&x| x as f32),
                ConvMode::Same,
                PaddingMode::Zeros,
            )
            .unwrap()
            .map(|x| x.round() as i32);

        assert_eq!(res_normal, res_fft);

        // 2-D data, dilated non-reversed kernel, custom padding and strides.
        let arr = array![[1, 2], [3, 4]];
        let kernel = array![[1, 0], [3, 1]];

        let res_normal = arr
            .conv(
                kernel.with_dilation(2).no_reverse(),
                ConvMode::Custom {
                    padding: [3, 3],
                    strides: [2, 2],
                },
                PaddingMode::Replicate,
            )
            .unwrap();

        let res_fft = arr
            .map(|&x| x as f64)
            .conv_fft(
                kernel.map(|&x| x as f64).with_dilation(2).no_reverse(),
                ConvMode::Custom {
                    padding: [3, 3],
                    strides: [2, 2],
                },
                PaddingMode::Replicate,
            )
            .unwrap()
            .map(|x| x.round() as i32);

        assert_eq!(res_normal, res_fft);

        // 1-D data, dilated kernel, `Same` mode with zero padding.
        let arr = array![1, 2, 3, 4, 5, 6];
        let kernel = array![1, 1, 1, 1];

        let res_normal = arr
            .conv(kernel.with_dilation(2), ConvMode::Same, PaddingMode::Zeros)
            .unwrap();

        let res_fft = arr
            .map(|&x| x as f32)
            .conv_fft(
                kernel.map(|&x| x as f32).with_dilation(2),
                ConvMode::Same,
                PaddingMode::Zeros,
            )
            .unwrap()
            .map(|x| x.round() as i32);

        assert_eq!(res_normal, res_fft);
    }

    /// Circular padding: the FFT path must match the direct path within a
    /// small floating-point tolerance.
    #[test]
    fn test_conv_fft_circular() {
        use ndarray::Array1;

        let arr: Array1<f32> =
            array![0.0, 0.1, 0.3, 0.4, 0.0, 0.1, 0.3, 0.4, 0.0, 0.1, 0.3, 0.4, 0.0, 0.1, 0.3, 0.4];
        let kernel: Array1<f32> = array![0.1, 0.3, 0.6, 0.3, 0.1];

        let expected = arr
            .conv(&kernel, ConvMode::Same, PaddingMode::Circular)
            .unwrap();
        let actual = arr
            .conv_fft(&kernel, ConvMode::Same, PaddingMode::Circular)
            .unwrap();

        // `zip` stops at the shorter side, so a size mismatch would otherwise
        // go unnoticed.
        assert_eq!(expected.shape(), actual.shape());

        expected
            .iter()
            .zip(actual.iter())
            .for_each(|(a, b)| assert!((a - b).abs() < 1e-6));
    }
}