ndarray_conv/conv_fft/mod.rs
1//! Provides FFT-accelerated convolution operations.
2//!
3//! This module offers the `ConvFFTExt` trait, which extends `ndarray`
4//! with FFT-based convolution methods.
5
use std::{fmt::Debug, marker::PhantomData};

use ndarray::{
    Array, ArrayBase, Axis, Data, Dim, IntoDimension, Ix, RawData, RemoveAxis, Slice, SliceArg,
    SliceInfo, SliceInfoElem,
};
use num::{traits::NumAssign, Complex};
use rustfft::FftNum;

use crate::{conv::ExplicitConv, dilation::IntoKernelWithDilation, ConvMode, PaddingMode};
16
17mod fft;
18mod good_size;
19mod padding;
20
21pub use fft::Processor;
22
23/// Represents a "baked" convolution operation.
24///
25/// This struct holds pre-computed data for performing FFT-accelerated
26/// convolutions, including the FFT size, FFT processor, scratch space,
27/// and padding information. It's designed to optimize repeated
28/// convolutions with the same kernel and settings.
29pub struct Baked<T, SK, const N: usize>
30where
31 T: NumAssign + Debug + FftNum,
32 SK: RawData,
33{
34 fft_size: [usize; N],
35 fft_processor: Processor<T>,
36 scratch: Vec<Complex<T>>,
37 cm: ExplicitConv<N>,
38 padding_mode: PaddingMode<N, T>,
39 kernel_raw_dim_with_dilation: [usize; N],
40 pds_raw_dim: [usize; N],
41 kernel_pd: Array<T, Dim<[Ix; N]>>,
42 _sk_hint: PhantomData<SK>,
43}
44
45/// Extends `ndarray`'s `ArrayBase` with FFT-accelerated convolution operations.
46///
47/// This trait adds the `conv_fft` and `conv_fft_with_processor` methods to `ArrayBase`,
48/// enabling efficient FFT-based convolutions on N-dimensional arrays.
49///
50/// # Type Parameters
51///
52/// * `T`: The numeric type of the array elements. Must be a floating-point type that implements `FftNum`.
53/// * `S`: The data storage type of the input array.
54/// * `SK`: The data storage type of the kernel array.
55///
56/// # Methods
57///
58/// * `conv_fft`: Performs an FFT-accelerated convolution with default settings.
59/// * `conv_fft_with_processor`: Performs an FFT-accelerated convolution using a provided `Processor` instance, allowing for reuse of FFT plans and scratch space.
60/// * `conv_fft_bake`: Precomputes and stores necessary data for performing repeated convolutions in the form of `Baked`.
61/// * `conv_fft_with_baked`: Performs a convolution with the provided `Baked` data.
62///
63/// # Example
64///
65/// ```rust
66/// use ndarray::prelude::*;
67/// use ndarray_conv::{ConvFFTExt, ConvMode, PaddingMode};
68///
69/// let arr = array![[1., 2.], [3., 4.]];
70/// let kernel = array![[1., 0.], [0., 1.]];
71/// let result = arr.conv_fft(&kernel, ConvMode::Same, PaddingMode::Zeros).unwrap();
72/// ```
73///
74/// # Notes
75///
76/// FFT-based convolutions are generally faster for larger kernels but may have higher overhead for smaller kernels.
77pub trait ConvFFTExt<'a, T, S, SK, const N: usize>
78where
79 T: FftNum + NumAssign,
80 S: RawData,
81 SK: RawData,
82{
83 fn conv_fft(
84 &self,
85 kernel: impl IntoKernelWithDilation<'a, SK, N>,
86 conv_mode: ConvMode<N>,
87 padding_mode: PaddingMode<N, T>,
88 ) -> Result<Array<T, Dim<[Ix; N]>>, crate::Error<N>>;
89
90 fn conv_fft_with_processor(
91 &self,
92 kernel: impl IntoKernelWithDilation<'a, SK, N>,
93 conv_mode: ConvMode<N>,
94 padding_mode: PaddingMode<N, T>,
95 fft_processor: &mut Processor<T>,
96 ) -> Result<Array<T, Dim<[Ix; N]>>, crate::Error<N>>;
97
98 // fn conv_fft_bake(
99 // &self,
100 // kernel: impl IntoKernelWithDilation<'a, SK, N>,
101 // conv_mode: ConvMode<N>,
102 // padding_mode: PaddingMode<N, T>,
103 // ) -> Result<Baked<T, SK, N>, crate::Error<N>>;
104
105 // fn conv_fft_with_baked(&self, baked: &mut Baked<T, SK, N>) -> Array<T, Dim<[Ix; N]>>;
106}
107
108impl<'a, T, S, SK, const N: usize> ConvFFTExt<'a, T, S, SK, N> for ArrayBase<S, Dim<[Ix; N]>>
109where
110 T: NumAssign + FftNum,
111 S: Data<Elem = T> + 'a,
112 SK: Data<Elem = T> + 'a,
113 [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>,
114 SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>:
115 SliceArg<Dim<[Ix; N]>, OutDim = Dim<[Ix; N]>>,
116 Dim<[Ix; N]>: RemoveAxis,
117{
118 // fn conv_fft_bake(
119 // &self,
120 // kernel: impl IntoKernelWithDilation<'a, SK, N>,
121 // conv_mode: ConvMode<N>,
122 // padding_mode: PaddingMode<N, T>,
123 // ) -> Result<Baked<T, SK, N>, crate::Error<N>> {
124 // let mut fft_processor = Processor::default();
125
126 // let kwd = kernel.into_kernel_with_dilation();
127
128 // let data_raw_dim = self.raw_dim();
129 // if self.shape().iter().product::<usize>() == 0 {
130 // return Err(crate::Error::DataShape(data_raw_dim));
131 // }
132
133 // let kernel_raw_dim = kwd.kernel.raw_dim();
134 // if kwd.kernel.shape().iter().product::<usize>() == 0 {
135 // return Err(crate::Error::DataShape(kernel_raw_dim));
136 // }
137
138 // let kernel_raw_dim_with_dilation: [usize; N] =
139 // std::array::from_fn(|i| kernel_raw_dim[i] * kwd.dilation[i] - kwd.dilation[i] + 1);
140
141 // let cm = conv_mode.unfold(&kwd);
142
143 // let pds_raw_dim: [usize; N] =
144 // std::array::from_fn(|i| (data_raw_dim[i] + cm.padding[i][0] + cm.padding[i][1]));
145 // if !(0..N).all(|i| kernel_raw_dim_with_dilation[i] <= pds_raw_dim[i]) {
146 // return Err(crate::Error::MismatchShape(
147 // conv_mode,
148 // kernel_raw_dim_with_dilation,
149 // ));
150 // }
151
152 // let fft_size = good_size::compute::<N>(&std::array::from_fn(|i| {
153 // pds_raw_dim[i].max(kernel_raw_dim_with_dilation[i])
154 // }));
155
156 // let scratch = fft_processor.get_scratch(fft_size);
157
158 // let kernel_pd = padding::kernel(kwd, fft_size);
159
160 // Ok(Baked {
161 // fft_size,
162 // fft_processor,
163 // scratch,
164 // cm,
165 // padding_mode,
166 // kernel_raw_dim_with_dilation,
167 // pds_raw_dim,
168 // kernel_pd,
169 // _sk_hint: PhantomData,
170 // })
171 // }
172
173 // fn conv_fft_with_baked(&self, baked: &mut Baked<T, SK, N>) -> Array<T, Dim<[Ix; N]>> {
174 // let Baked {
175 // scratch,
176 // fft_processor,
177 // fft_size,
178 // cm,
179 // padding_mode,
180 // kernel_pd,
181 // kernel_raw_dim_with_dilation,
182 // pds_raw_dim,
183 // _sk_hint,
184 // } = baked;
185
186 // let mut data_pd = padding::data(self, *padding_mode, cm.padding, *fft_size);
187
188 // let mut data_pd_fft = fft_processor.forward_with_scratch(&mut data_pd, scratch);
189 // let kernel_pd_fft = fft_processor.forward_with_scratch(kernel_pd, scratch);
190
191 // data_pd_fft.zip_mut_with(&kernel_pd_fft, |d, k| *d *= *k);
192 // // let mul_spec = data_pd_fft * kernel_pd_fft;
193
194 // let output = fft_processor.backward(data_pd_fft);
195
196 // output.slice_move(unsafe {
197 // SliceInfo::new(std::array::from_fn(|i| SliceInfoElem::Slice {
198 // start: kernel_raw_dim_with_dilation[i] as isize - 1,
199 // end: Some((pds_raw_dim[i]) as isize),
200 // step: cm.strides[i] as isize,
201 // }))
202 // .unwrap()
203 // })
204 // }
205
206 fn conv_fft(
207 &self,
208 kernel: impl IntoKernelWithDilation<'a, SK, N>,
209 conv_mode: ConvMode<N>,
210 padding_mode: PaddingMode<N, T>,
211 ) -> Result<Array<T, Dim<[Ix; N]>>, crate::Error<N>> {
212 let mut p = Processor::default();
213 self.conv_fft_with_processor(kernel, conv_mode, padding_mode, &mut p)
214 }
215
216 fn conv_fft_with_processor(
217 &self,
218 kernel: impl IntoKernelWithDilation<'a, SK, N>,
219 conv_mode: ConvMode<N>,
220 padding_mode: PaddingMode<N, T>,
221 fft_processor: &mut Processor<T>,
222 ) -> Result<Array<T, Dim<[Ix; N]>>, crate::Error<N>> {
223 let kwd = kernel.into_kernel_with_dilation();
224
225 let data_raw_dim = self.raw_dim();
226 if self.shape().iter().product::<usize>() == 0 {
227 return Err(crate::Error::DataShape(data_raw_dim));
228 }
229
230 let kernel_raw_dim = kwd.kernel.raw_dim();
231 if kwd.kernel.shape().iter().product::<usize>() == 0 {
232 return Err(crate::Error::DataShape(kernel_raw_dim));
233 }
234
235 let kernel_raw_dim_with_dilation: [usize; N] =
236 std::array::from_fn(|i| kernel_raw_dim[i] * kwd.dilation[i] - kwd.dilation[i] + 1);
237
238 let cm = conv_mode.unfold(&kwd);
239
240 let pds_raw_dim: [usize; N] =
241 std::array::from_fn(|i| (data_raw_dim[i] + cm.padding[i][0] + cm.padding[i][1]));
242 if !(0..N).all(|i| kernel_raw_dim_with_dilation[i] <= pds_raw_dim[i]) {
243 return Err(crate::Error::MismatchShape(
244 conv_mode,
245 kernel_raw_dim_with_dilation,
246 ));
247 }
248
249 let fft_size = good_size::compute::<N>(&std::array::from_fn(|i| {
250 pds_raw_dim[i].max(kernel_raw_dim_with_dilation[i])
251 }));
252
253 let mut data_pd = padding::data(self, padding_mode, cm.padding, fft_size);
254 let mut kernel_pd = padding::kernel(kwd, fft_size);
255
256 let mut data_pd_fft = fft_processor.forward(&mut data_pd);
257 let kernel_pd_fft = fft_processor.forward(&mut kernel_pd);
258
259 data_pd_fft.zip_mut_with(&kernel_pd_fft, |d, k| *d *= *k);
260 // let mul_spec = data_pd_fft * kernel_pd_fft;
261
262 let output = fft_processor.backward(data_pd_fft);
263
264 let output = output.slice_move(unsafe {
265 SliceInfo::new(std::array::from_fn(|i| SliceInfoElem::Slice {
266 start: kernel_raw_dim_with_dilation[i] as isize - 1,
267 end: Some((pds_raw_dim[i]) as isize),
268 step: cm.strides[i] as isize,
269 }))
270 .unwrap()
271 });
272
273 Ok(output)
274 }
275}
276
277#[cfg(test)]
278mod tests {
279 use ndarray::array;
280
281 use crate::{dilation::WithDilation, ConvExt, ReverseKernel};
282
283 use super::*;
284
285 #[test]
286 fn correct_size() {
287 let arr = array![[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]];
288 let kernel = array![[1, 0], [3, 1]];
289
290 let res_normal = arr
291 .conv(&kernel, ConvMode::Same, PaddingMode::Replicate)
292 .unwrap();
293 // dbg!(res_normal);
294
295 let res_fft = arr
296 .map(|&x| x as f64)
297 // The padding does not matter here, it is only used to calculate the correct size
298 .conv_fft(
299 &kernel.map(|&x| x as f64),
300 ConvMode::Same,
301 PaddingMode::Replicate,
302 )
303 .unwrap()
304 .map(|x| x.round() as i32);
305 // dbg!(res_fft);
306
307 assert_eq!(res_normal, res_fft);
308 }
309
310 #[test]
311 fn conv_fft() {
312 let arr = array![[[1, 2], [3, 4]], [[5, 6], [7, 8]]];
313 let kernel = array![
314 [[1, 1, 1], [1, 1, 1], [1, 1, 1]],
315 [[1, 1, 1], [1, 1, 1], [1, 1, 1]],
316 ];
317
318 let res_normal = arr
319 .conv(&kernel, ConvMode::Same, PaddingMode::Zeros)
320 .unwrap();
321 // dbg!(res_normal);
322
323 let res_fft = arr
324 .map(|&x| x as f32)
325 .conv_fft(
326 &kernel.map(|&x| x as f32),
327 ConvMode::Same,
328 PaddingMode::Zeros,
329 )
330 .unwrap()
331 .map(|x| x.round() as i32);
332 // dbg!(res_fft);
333
334 assert_eq!(res_normal, res_fft);
335
336 //
337
338 let arr = array![[1, 2], [3, 4]];
339 let kernel = array![[1, 0], [3, 1]];
340
341 let res_normal = arr
342 .conv(
343 kernel.with_dilation(2).no_reverse(),
344 ConvMode::Custom {
345 padding: [3, 3],
346 strides: [2, 2],
347 },
348 PaddingMode::Replicate,
349 )
350 .unwrap();
351 // dbg!(res_normal);
352
353 let res_fft = arr
354 .map(|&x| x as f64)
355 .conv_fft(
356 kernel.map(|&x| x as f64).with_dilation(2).no_reverse(),
357 ConvMode::Custom {
358 padding: [3, 3],
359 strides: [2, 2],
360 },
361 PaddingMode::Replicate,
362 )
363 .unwrap()
364 .map(|x| x.round() as i32);
365 // dbg!(res_fft);
366
367 assert_eq!(res_normal, res_fft);
368
369 //
370
371 let arr = array![1, 2, 3, 4, 5, 6];
372 let kernel = array![1, 1, 1, 1];
373
374 let res_normal = arr
375 .conv(kernel.with_dilation(2), ConvMode::Same, PaddingMode::Zeros)
376 .unwrap();
377 // dbg!(&res_normal);
378
379 let res_fft = arr
380 .map(|&x| x as f32)
381 .conv_fft(
382 kernel.map(|&x| x as f32).with_dilation(2),
383 ConvMode::Same,
384 PaddingMode::Zeros,
385 )
386 .unwrap()
387 .map(|x| x.round() as i32);
388 // dbg!(res_fft);
389
390 assert_eq!(res_normal, res_fft);
391 }
392
393 #[test]
394 fn test_conv_fft_circular() {
395 use crate::*;
396 use ndarray::Array1;
397
398 let arr: Array1<f32> =
399 array![0.0, 0.1, 0.3, 0.4, 0.0, 0.1, 0.3, 0.4, 0.0, 0.1, 0.3, 0.4, 0.0, 0.1, 0.3, 0.4];
400 let kernel: Array1<f32> = array![0.1, 0.3, 0.6, 0.3, 0.1];
401
402 arr.conv(&kernel, crate::ConvMode::Same, crate::PaddingMode::Circular)
403 .unwrap()
404 .iter()
405 .zip(
406 arr.conv_fft(&kernel, crate::ConvMode::Same, crate::PaddingMode::Circular)
407 .unwrap(),
408 )
409 .for_each(|(a, b)| assert!((a - b).abs() < 1e-6));
410 }
411}