ndarray_conv/conv_fft/mod.rs
1//! Provides FFT-accelerated convolution operations.
2//!
3//! This module offers the `ConvFFTExt` trait, which extends `ndarray`
4//! with FFT-based convolution methods.
5
6use ndarray::{
7 Array, ArrayBase, Data, Dim, IntoDimension, Ix, RawData, RemoveAxis, SliceArg, SliceInfo,
8 SliceInfoElem,
9};
10use num::traits::NumAssign;
11use rustfft::FftNum;
12
13use crate::{dilation::IntoKernelWithDilation, ConvMode, PaddingMode};
14
15mod good_size;
16mod padding;
17mod processor;
18
19// pub use fft::Processor;
20pub use processor::{get as get_processor, GetProcessor, MaybeSync, Processor};
21
22// /// Represents a "baked" convolution operation.
23// ///
24// /// This struct holds pre-computed data for performing FFT-accelerated
25// /// convolutions, including the FFT size, FFT processor, scratch space,
26// /// and padding information. It's designed to optimize repeated
27// /// convolutions with the same kernel and settings.
28// pub struct Baked<T, SK, const N: usize>
29// where
30// T: NumAssign + Debug + Copy,
31// SK: RawData,
32// {
33// fft_size: [usize; N],
34// fft_processor: impl Processor<T>,
35// scratch: Vec<Complex<T>>,
36// cm: ExplicitConv<N>,
37// padding_mode: PaddingMode<N, T>,
38// kernel_raw_dim_with_dilation: [usize; N],
39// pds_raw_dim: [usize; N],
40// kernel_pd: Array<T, Dim<[Ix; N]>>,
41// _sk_hint: PhantomData<SK>,
42// }
43
/// Extends `ndarray`'s `ArrayBase` with FFT-accelerated convolution operations.
///
/// This trait adds the `conv_fft` and `conv_fft_with_processor` methods to `ArrayBase`
/// (plus `conv_fft_par` when the `rayon` feature is enabled), enabling efficient
/// FFT-based convolutions on N-dimensional arrays.
///
/// # Type Parameters
///
/// * `'a`: The lifetime handed to `IntoKernelWithDilation` for the kernel argument.
/// * `T`: The numeric type used internally for FFT operations. Must be a floating-point type that implements `FftNum`.
/// * `InElem`: The element type of the input arrays. Can be real (`T`) or complex (`Complex<T>`).
/// * `S`: The data storage type of the input array.
/// * `SK`: The data storage type of the kernel array.
/// * `N`: The number of dimensions of the input, the kernel and the result.
///
/// # Methods
///
/// * `conv_fft`: Performs an FFT-accelerated convolution with default settings.
/// * `conv_fft_with_processor`: Performs an FFT-accelerated convolution using a provided `Processor` instance, allowing for reuse of FFT plans across multiple convolutions for better performance.
/// * `conv_fft_par` (feature `rayon`): Performs the same convolution with the work spread over rayon's thread pool.
///
/// # Example
///
/// ```rust
/// use ndarray::prelude::*;
/// use ndarray_conv::{ConvFFTExt, ConvMode, PaddingMode};
///
/// let arr = array![[1., 2.], [3., 4.]];
/// let kernel = array![[1., 0.], [0., 1.]];
/// let result = arr.conv_fft(&kernel, ConvMode::Same, PaddingMode::Zeros).unwrap();
/// ```
///
/// # Notes
///
/// FFT-based convolutions are generally faster for larger kernels but may have higher overhead for smaller kernels.
/// Use standard convolution (`ConvExt::conv`) for small kernels or when working with integer types.
///
/// # Performance Tips
///
/// For repeated convolutions with different data but the same kernel and settings, consider using
/// `conv_fft_with_processor` to reuse the FFT planner and avoid redundant setup overhead.
pub trait ConvFFTExt<'a, T, InElem, S, SK, const N: usize>
where
    T: NumAssign + Copy + FftNum,
    InElem: processor::GetProcessor<T, InElem> + Copy + NumAssign,
    S: RawData,
    SK: RawData,
{
    /// Performs an FFT-accelerated convolution operation.
    ///
    /// This method convolves the input array with a given kernel using FFT,
    /// which is typically faster for larger kernels.
    ///
    /// # Arguments
    ///
    /// * `kernel`: The convolution kernel. Can be a reference to an array, or an array with dilation settings.
    /// * `conv_mode`: The convolution mode (`Full`, `Same`, `Valid`, `Custom`, `Explicit`).
    /// * `padding_mode`: The padding mode (`Zeros`, `Const`, `Reflect`, `Replicate`, `Circular`, `Custom`, `Explicit`).
    ///
    /// # Returns
    ///
    /// Returns `Ok(Array<InElem, Dim<[Ix; N]>>)` containing the convolution result, or an `Err(Error<N>)` if the operation fails.
    ///
    /// # Errors
    ///
    /// The implementation provided in this module returns `Error::DataShape` when the
    /// input array or the kernel contains zero elements, and `Error::MismatchShape`
    /// when the dilated kernel is larger than the padded input along any axis.
    ///
    /// # Example
    ///
    /// ```rust
    /// use ndarray::array;
    /// use ndarray_conv::{ConvFFTExt, ConvMode, PaddingMode};
    ///
    /// let input = array![[1.0, 2.0], [3.0, 4.0]];
    /// let kernel = array![[1.0, 0.0], [0.0, 1.0]];
    /// let result = input.conv_fft(&kernel, ConvMode::Same, PaddingMode::Zeros).unwrap();
    /// ```
    fn conv_fft(
        &self,
        kernel: impl IntoKernelWithDilation<'a, SK, N>,
        conv_mode: ConvMode<N>,
        padding_mode: PaddingMode<N, InElem>,
    ) -> Result<Array<InElem, Dim<[Ix; N]>>, crate::Error<N>>;

    /// Performs an FFT-accelerated convolution using a provided processor.
    ///
    /// This method is useful when performing multiple convolutions, as it allows
    /// reusing the FFT planner and avoiding redundant initialization overhead.
    ///
    /// # Arguments
    ///
    /// * `kernel`: The convolution kernel.
    /// * `conv_mode`: The convolution mode.
    /// * `padding_mode`: The padding mode.
    /// * `fft_processor`: A mutable reference to an FFT processor instance.
    ///
    /// # Returns
    ///
    /// Returns `Ok(Array<InElem, Dim<[Ix; N]>>)` containing the convolution result, or an `Err(Error<N>)` if the operation fails.
    ///
    /// # Errors
    ///
    /// Same conditions as `conv_fft`: the implementation in this module reports
    /// `Error::DataShape` for an empty input or kernel and `Error::MismatchShape`
    /// when the dilated kernel does not fit inside the padded input.
    ///
    /// # Example
    ///
    /// ```rust
    /// use ndarray::array;
    /// use ndarray_conv::{ConvFFTExt, ConvMode, PaddingMode, get_fft_processor};
    ///
    /// let input1 = array![[1.0, 2.0], [3.0, 4.0]];
    /// let input2 = array![[5.0, 6.0], [7.0, 8.0]];
    /// let kernel = array![[1.0, 0.0], [0.0, 1.0]];
    ///
    /// // Reuse the same processor for multiple convolutions
    /// let mut proc = get_fft_processor::<f32, f32>();
    /// let result1 = input1.conv_fft_with_processor(&kernel, ConvMode::Same, PaddingMode::Zeros, &mut proc).unwrap();
    /// let result2 = input2.conv_fft_with_processor(&kernel, ConvMode::Same, PaddingMode::Zeros, &mut proc).unwrap();
    /// ```
    fn conv_fft_with_processor(
        &self,
        kernel: impl IntoKernelWithDilation<'a, SK, N>,
        conv_mode: ConvMode<N>,
        padding_mode: PaddingMode<N, InElem>,
        fft_processor: &mut impl Processor<T, InElem>,
    ) -> Result<Array<InElem, Dim<[Ix; N]>>, crate::Error<N>>;

    /// Performs a parallel FFT-accelerated convolution operation using rayon.
    ///
    /// The implementation in this module runs "pad + forward FFT" of the data
    /// concurrently with "pad + forward FFT" of the kernel via `rayon::join`,
    /// multiplies the two spectra with a parallel iterator, and passes `true` as
    /// the second argument of the processor's `forward`/`backward` calls
    /// (presumably the processor's own parallel switch — confirm in `processor`).
    /// This can significantly speed up convolution on large arrays.
    ///
    /// Arguments, return value and error conditions are the same as for `conv_fft`.
    ///
    /// Only available when the `rayon` feature is enabled.
    #[cfg(feature = "rayon")]
    fn conv_fft_par(
        &self,
        kernel: impl IntoKernelWithDilation<'a, SK, N>,
        conv_mode: ConvMode<N>,
        padding_mode: PaddingMode<N, InElem>,
    ) -> Result<Array<InElem, Dim<[Ix; N]>>, crate::Error<N>>;

    // Planned "baked" API (not implemented; kept commented out as a design sketch):
    //
    // fn conv_fft_bake(
    //     &self,
    //     kernel: impl IntoKernelWithDilation<'a, SK, N>,
    //     conv_mode: ConvMode<N>,
    //     padding_mode: PaddingMode<N, T>,
    // ) -> Result<Baked<T, SK, N>, crate::Error<N>>;

    // fn conv_fft_with_baked(&self, baked: &mut Baked<T, SK, N>) -> Array<T, Dim<[Ix; N]>>;
}
184
185fn conv_fft_proc_impl<'a, T, InElem, S, SK, const N: usize>(
186 data: &ArrayBase<S, Dim<[Ix; N]>>,
187 kernel: impl IntoKernelWithDilation<'a, SK, N>,
188 conv_mode: ConvMode<N>,
189 padding_mode: PaddingMode<N, InElem>,
190 fft_processor: &mut impl Processor<T, InElem>,
191 #[cfg_attr(not(feature = "rayon"), allow(unused_variables))] parallel: bool,
192) -> Result<Array<InElem, Dim<[Ix; N]>>, crate::Error<N>>
193where
194 T: NumAssign + FftNum,
195 InElem: processor::GetProcessor<T, InElem> + NumAssign + Copy + MaybeSync + 'a,
196 S: Data<Elem = InElem> + MaybeSync + 'a,
197 SK: Data<Elem = InElem> + MaybeSync + 'a,
198 [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>,
199 SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>:
200 SliceArg<Dim<[Ix; N]>, OutDim = Dim<[Ix; N]>>,
201 Dim<[Ix; N]>: RemoveAxis,
202{
203 let kwd = kernel.into_kernel_with_dilation();
204
205 let data_raw_dim = data.raw_dim();
206 if data.shape().iter().product::<usize>() == 0 {
207 return Err(crate::Error::DataShape(data_raw_dim));
208 }
209
210 let kernel_raw_dim = kwd.kernel.raw_dim();
211 if kwd.kernel.shape().iter().product::<usize>() == 0 {
212 return Err(crate::Error::DataShape(kernel_raw_dim));
213 }
214
215 let kernel_raw_dim_with_dilation: [usize; N] =
216 std::array::from_fn(|i| kernel_raw_dim[i] * kwd.dilation[i] - kwd.dilation[i] + 1);
217
218 let cm = conv_mode.unfold(&kwd);
219
220 let pds_raw_dim: [usize; N] =
221 std::array::from_fn(|i| data_raw_dim[i] + cm.padding[i][0] + cm.padding[i][1]);
222 if !(0..N).all(|i| kernel_raw_dim_with_dilation[i] <= pds_raw_dim[i]) {
223 return Err(crate::Error::MismatchShape(
224 conv_mode,
225 kernel_raw_dim_with_dilation,
226 ));
227 }
228
229 let fft_size = good_size::compute::<N>(&std::array::from_fn(|i| {
230 pds_raw_dim[i].max(kernel_raw_dim_with_dilation[i])
231 }));
232
233 // Parallel path: padding and forward FFT are folded into a single rayon::join,
234 // so data-padding+data-FFT runs concurrently with kernel-padding+kernel-FFT.
235 // The spectral multiply is parallelised too.
236 // Requires InElem/S/SK: MaybeSync so that shared references can cross thread
237 // boundaries inside rayon::join closures (no-op when rayon is disabled).
238 //
239 // Serial path: sequential padding then FFT as before.
240 #[cfg(feature = "rayon")]
241 let output = if parallel {
242 let (mut data_fft, kern_fft) = rayon::join(
243 || {
244 let mut pd = padding::data(data, padding_mode, cm.padding, fft_size);
245 fft_processor.forward(&mut pd, true)
246 },
247 || {
248 let mut pk = padding::kernel(kwd, fft_size);
249 let mut p = InElem::get_processor();
250 p.forward(&mut pk, true)
251 },
252 );
253 {
254 use rayon::prelude::*;
255 data_fft
256 .as_slice_mut()
257 .unwrap()
258 .par_iter_mut()
259 .zip(kern_fft.as_slice().unwrap().par_iter())
260 .for_each(|(d, k)| *d *= *k);
261 }
262 fft_processor.backward(&mut data_fft, true)
263 } else {
264 let mut data_pd = padding::data(data, padding_mode, cm.padding, fft_size);
265 let mut kernel_pd = padding::kernel(kwd, fft_size);
266 let mut data_fft = fft_processor.forward(&mut data_pd, false);
267 let kern_fft = fft_processor.forward(&mut kernel_pd, false);
268 data_fft.zip_mut_with(&kern_fft, |d, k| *d *= *k);
269 fft_processor.backward(&mut data_fft, false)
270 };
271
272 #[cfg(not(feature = "rayon"))]
273 let output = {
274 let mut data_pd = padding::data(data, padding_mode, cm.padding, fft_size);
275 let mut kernel_pd = padding::kernel(kwd, fft_size);
276 let mut data_fft = fft_processor.forward(&mut data_pd, false);
277 let kern_fft = fft_processor.forward(&mut kernel_pd, false);
278 data_fft.zip_mut_with(&kern_fft, |d, k| *d *= *k);
279 fft_processor.backward(&mut data_fft, false)
280 };
281
282 let output = output.slice_move(unsafe {
283 SliceInfo::new(std::array::from_fn(|i| SliceInfoElem::Slice {
284 start: kernel_raw_dim_with_dilation[i] as isize - 1,
285 end: Some((pds_raw_dim[i]) as isize),
286 step: cm.strides[i] as isize,
287 }))
288 .unwrap()
289 });
290
291 Ok(output)
292}
293
294impl<'a, T, InElem, S, SK, const N: usize> ConvFFTExt<'a, T, InElem, S, SK, N>
295 for ArrayBase<S, Dim<[Ix; N]>>
296where
297 T: NumAssign + FftNum,
298 InElem: processor::GetProcessor<T, InElem> + NumAssign + Copy + MaybeSync + 'a,
299 S: Data<Elem = InElem> + MaybeSync + 'a,
300 SK: Data<Elem = InElem> + MaybeSync + 'a,
301 [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>,
302 SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>:
303 SliceArg<Dim<[Ix; N]>, OutDim = Dim<[Ix; N]>>,
304 Dim<[Ix; N]>: RemoveAxis,
305{
306 // fn conv_fft_bake(
307 // &self,
308 // kernel: impl IntoKernelWithDilation<'a, SK, N>,
309 // conv_mode: ConvMode<N>,
310 // padding_mode: PaddingMode<N, T>,
311 // ) -> Result<Baked<T, SK, N>, crate::Error<N>> {
312 // let mut fft_processor = Processor::default();
313
314 // let kwd = kernel.into_kernel_with_dilation();
315
316 // let data_raw_dim = self.raw_dim();
317 // if self.shape().iter().product::<usize>() == 0 {
318 // return Err(crate::Error::DataShape(data_raw_dim));
319 // }
320
321 // let kernel_raw_dim = kwd.kernel.raw_dim();
322 // if kwd.kernel.shape().iter().product::<usize>() == 0 {
323 // return Err(crate::Error::DataShape(kernel_raw_dim));
324 // }
325
326 // let kernel_raw_dim_with_dilation: [usize; N] =
327 // std::array::from_fn(|i| kernel_raw_dim[i] * kwd.dilation[i] - kwd.dilation[i] + 1);
328
329 // let cm = conv_mode.unfold(&kwd);
330
331 // let pds_raw_dim: [usize; N] =
332 // std::array::from_fn(|i| (data_raw_dim[i] + cm.padding[i][0] + cm.padding[i][1]));
333 // if !(0..N).all(|i| kernel_raw_dim_with_dilation[i] <= pds_raw_dim[i]) {
334 // return Err(crate::Error::MismatchShape(
335 // conv_mode,
336 // kernel_raw_dim_with_dilation,
337 // ));
338 // }
339
340 // let fft_size = good_size::compute::<N>(&std::array::from_fn(|i| {
341 // pds_raw_dim[i].max(kernel_raw_dim_with_dilation[i])
342 // }));
343
344 // let scratch = fft_processor.get_scratch(fft_size);
345
346 // let kernel_pd = padding::kernel(kwd, fft_size);
347
348 // Ok(Baked {
349 // fft_size,
350 // fft_processor,
351 // scratch,
352 // cm,
353 // padding_mode,
354 // kernel_raw_dim_with_dilation,
355 // pds_raw_dim,
356 // kernel_pd,
357 // _sk_hint: PhantomData,
358 // })
359 // }
360
361 // fn conv_fft_with_baked(&self, baked: &mut Baked<T, SK, N>) -> Array<T, Dim<[Ix; N]>> {
362 // let Baked {
363 // scratch,
364 // fft_processor,
365 // fft_size,
366 // cm,
367 // padding_mode,
368 // kernel_pd,
369 // kernel_raw_dim_with_dilation,
370 // pds_raw_dim,
371 // _sk_hint,
372 // } = baked;
373
374 // let mut data_pd = padding::data(self, *padding_mode, cm.padding, *fft_size);
375
376 // let mut data_pd_fft = fft_processor.forward_with_scratch(&mut data_pd, scratch);
377 // let kernel_pd_fft = fft_processor.forward_with_scratch(kernel_pd, scratch);
378
379 // data_pd_fft.zip_mut_with(&kernel_pd_fft, |d, k| *d *= *k);
380 // // let mul_spec = data_pd_fft * kernel_pd_fft;
381
382 // let output = fft_processor.backward(data_pd_fft);
383
384 // output.slice_move(unsafe {
385 // SliceInfo::new(std::array::from_fn(|i| SliceInfoElem::Slice {
386 // start: kernel_raw_dim_with_dilation[i] as isize - 1,
387 // end: Some((pds_raw_dim[i]) as isize),
388 // step: cm.strides[i] as isize,
389 // }))
390 // .unwrap()
391 // })
392 // }
393
394 fn conv_fft(
395 &self,
396 kernel: impl IntoKernelWithDilation<'a, SK, N>,
397 conv_mode: ConvMode<N>,
398 padding_mode: PaddingMode<N, InElem>,
399 ) -> Result<Array<InElem, Dim<[Ix; N]>>, crate::Error<N>> {
400 let mut p = InElem::get_processor();
401 conv_fft_proc_impl(self, kernel, conv_mode, padding_mode, &mut p, false)
402 }
403
404 fn conv_fft_with_processor(
405 &self,
406 kernel: impl IntoKernelWithDilation<'a, SK, N>,
407 conv_mode: ConvMode<N>,
408 padding_mode: PaddingMode<N, InElem>,
409 fft_processor: &mut impl Processor<T, InElem>,
410 ) -> Result<Array<InElem, Dim<[Ix; N]>>, crate::Error<N>> {
411 conv_fft_proc_impl(self, kernel, conv_mode, padding_mode, fft_processor, false)
412 }
413
414 #[cfg(feature = "rayon")]
415 fn conv_fft_par(
416 &self,
417 kernel: impl IntoKernelWithDilation<'a, SK, N>,
418 conv_mode: ConvMode<N>,
419 padding_mode: PaddingMode<N, InElem>,
420 ) -> Result<Array<InElem, Dim<[Ix; N]>>, crate::Error<N>> {
421 let mut p = InElem::get_processor();
422 conv_fft_proc_impl(self, kernel, conv_mode, padding_mode, &mut p, true)
423 }
424
425
426}
427
428#[cfg(test)]
429mod tests;