ndarray_conv/conv_fft/fft.rs

//! Provides FFT-related functionality for convolution operations.
//!
//! This module includes the `Processor` struct, which manages FFT
//! planning and execution using the `rustfft` and `realfft` crates.
5
6use ndarray::{Array, ArrayBase, DataMut, Dim, IntoDimension, Ix, RemoveAxis};
7use num::Complex;
8use rustfft::FftNum;
9
10/// Manages FFT planning and execution for convolution operations.
11pub struct Processor<T: FftNum> {
12    rp: realfft::RealFftPlanner<T>,
13    rp_origin_len: usize,
14    cp: rustfft::FftPlanner<T>,
15}
16
17impl<T: FftNum> Default for Processor<T> {
18    fn default() -> Self {
19        Self {
20            rp: Default::default(),
21            rp_origin_len: Default::default(),
22            cp: rustfft::FftPlanner::new(),
23        }
24    }
25}
26
27impl<T: FftNum> Processor<T> {
28    /// Creates a scratch buffer for FFT operations.
29    ///
30    /// This method calculates the required size for the scratch buffer based on
31    /// the input dimensions and creates a vector with uninitialized memory to use
32    /// as the scratch buffer for the real and complex FFTs.
33    ///
34    /// # Arguments
35    ///
36    #[allow(clippy::uninit_vec)]
37    pub fn get_scratch<const N: usize>(&mut self, input_dim: [usize; N]) -> Vec<Complex<T>> {
38        // needs to check backward len
39        let mut output_shape = input_dim;
40        let rp = self.rp.plan_fft_forward(output_shape[N - 1]);
41        let rp_len = rp.get_scratch_len();
42
43        output_shape[N - 1] = rp.complex_len();
44        let cp_len = output_shape
45            .iter()
46            .take(N - 1)
47            .map(|&dim| self.cp.plan_fft_forward(dim).get_inplace_scratch_len())
48            .max()
49            .unwrap_or(0);
50
51        // avoid init mem
52        let mut scratch = Vec::with_capacity(rp_len.max(cp_len));
53        unsafe { scratch.set_len(rp_len.max(cp_len)) };
54
55        scratch
56    }
57
58    /// Performs a forward FFT on the given input array.
59    ///
60    /// This method computes the forward Fast Fourier Transform of the input array using
61    /// a real-to-complex FFT in the last dimension and complex-to-complex FFTs in the other dimensions.
62    ///
63    /// # Arguments
64    ///
65    /// *   `input`: A mutable reference to the input array.
66    ///
67    pub fn forward<S: DataMut<Elem = T>, const N: usize>(
68        &mut self,
69        input: &mut ArrayBase<S, Dim<[Ix; N]>>,
70    ) -> Array<Complex<T>, Dim<[Ix; N]>>
71    where
72        Dim<[Ix; N]>: RemoveAxis,
73        [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>,
74    {
75        let raw_dim: [usize; N] = std::array::from_fn(|i| input.raw_dim()[i]);
76
77        let rp = self.rp.plan_fft_forward(raw_dim[N - 1]);
78        self.rp_origin_len = rp.len();
79
80        let mut output_shape = raw_dim;
81        output_shape[N - 1] = rp.complex_len();
82        let mut output = Array::zeros(output_shape);
83
84        for (mut input, mut output) in input.rows_mut().into_iter().zip(output.rows_mut()) {
85            rp.process(
86                input.as_slice_mut().unwrap(),
87                output.as_slice_mut().unwrap(),
88            )
89            .unwrap();
90        }
91
92        let mut axes: [usize; N] = std::array::from_fn(|i| i);
93        axes.rotate_right(1);
94        for _ in 0..N - 1 {
95            output_shape.rotate_right(1);
96
97            // transpose takes a lot of time
98            // this method is very slow
99            // input = Array::from_shape_vec(
100            //     raw_dim.into_dimension(),
101            //     input.permuted_axes(axes).iter().copied().collect(),
102            // )
103            // .unwrap();
104
105            let mut buffer = Array::uninit(output_shape.into_dimension());
106            buffer.zip_mut_with(&output.permuted_axes(axes), |transpose, &origin| {
107                transpose.write(origin);
108            });
109            output = unsafe { buffer.assume_init() };
110
111            let cp = self.cp.plan_fft_forward(output_shape[N - 1]);
112            cp.process(output.as_slice_mut().unwrap());
113        }
114
115        output
116    }
117
118    /// Performs an inverse FFT on the given input array.
119    ///
120    /// This method computes the inverse Fast Fourier Transform of the input array using
121    /// a complex-to-real FFT in the last dimension and complex-to-complex FFTs in the other dimensions.
122    ///
123    /// # Arguments
124    ///
125    /// *   `input`: The input array.
126    pub fn backward<const N: usize>(
127        &mut self,
128        mut input: Array<Complex<T>, Dim<[Ix; N]>>,
129    ) -> Array<T, Dim<[Ix; N]>>
130    where
131        Dim<[Ix; N]>: RemoveAxis,
132        [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>,
133    {
134        // at this time, the raw_dim has been routate_left by N - 1 times
135        let mut raw_dim: [usize; N] = std::array::from_fn(|i| input.raw_dim()[i]);
136
137        let rp = self.rp.plan_fft_inverse(self.rp_origin_len);
138
139        let mut axes: [usize; N] = std::array::from_fn(|i| i);
140        axes.rotate_left(1);
141        for _ in 0..N - 1 {
142            let cp = self.cp.plan_fft_inverse(raw_dim[N - 1]);
143            cp.process(input.as_slice_mut().unwrap());
144
145            raw_dim.rotate_left(1);
146
147            let mut buffer = Array::uninit(raw_dim.into_dimension());
148            buffer.zip_mut_with(&input.permuted_axes(axes), |transpose, &origin| {
149                transpose.write(origin);
150            });
151            input = unsafe { buffer.assume_init() };
152        }
153
154        let mut output_shape = input.raw_dim();
155        output_shape[N - 1] = self.rp_origin_len;
156        let mut output = Array::zeros(output_shape);
157
158        for (mut input, mut output) in input.rows_mut().into_iter().zip(output.rows_mut()) {
159            let _ = rp.process(
160                input.as_slice_mut().unwrap(),
161                output.as_slice_mut().unwrap(),
162            );
163        }
164
165        let len = T::from_usize(output.len()).unwrap();
166        output.map_mut(|x| *x = x.div(len));
167        output
168    }
169
170    /// Performs a forward FFT on the given input array using a scratch buffer.
171    ///
172    /// This method computes the forward Fast Fourier Transform of the input array using
173    /// a real-to-complex FFT in the last dimension and complex-to-complex FFTs in the other dimensions.
174    /// It uses the given scratch buffer for FFT calculations, potentially improving performance
175    /// for multiple FFT calls.
176    ///
177    /// # Arguments
178    ///
179    /// *   `input`: A mutable reference to the input array.
180    pub fn forward_with_scratch<S: DataMut<Elem = T>, const N: usize>(
181        &mut self,
182        input: &mut ArrayBase<S, Dim<[Ix; N]>>,
183        scratch: &mut Vec<Complex<T>>,
184    ) -> Array<Complex<T>, Dim<[Ix; N]>>
185    where
186        Dim<[Ix; N]>: RemoveAxis,
187        [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>,
188    {
189        let raw_dim: [usize; N] = std::array::from_fn(|i| input.raw_dim()[i]);
190
191        let rp = self.rp.plan_fft_forward(raw_dim[N - 1]);
192        self.rp_origin_len = rp.len();
193
194        let mut output_shape = raw_dim;
195        output_shape[N - 1] = rp.complex_len();
196        let mut output = Array::zeros(output_shape);
197
198        for (mut input, mut output) in input.rows_mut().into_iter().zip(output.rows_mut()) {
199            rp.process_with_scratch(
200                input.as_slice_mut().unwrap(),
201                output.as_slice_mut().unwrap(),
202                scratch,
203            )
204            .unwrap();
205        }
206
207        let mut axes: [usize; N] = std::array::from_fn(|i| i);
208        axes.rotate_right(1);
209        for _ in 0..N - 1 {
210            output_shape.rotate_right(1);
211
212            // transpose takes a lot of time
213            // this method is very slow
214            // input = Array::from_shape_vec(
215            //     raw_dim.into_dimension(),
216            //     input.permuted_axes(axes).iter().copied().collect(),
217            // )
218            // .unwrap();
219
220            let mut buffer = Array::uninit(output_shape.into_dimension());
221            buffer.zip_mut_with(&output.permuted_axes(axes), |transpose, &origin| {
222                transpose.write(origin);
223            });
224            output = unsafe { buffer.assume_init() };
225
226            let cp = self.cp.plan_fft_forward(output_shape[N - 1]);
227            cp.process_with_scratch(output.as_slice_mut().unwrap(), scratch);
228        }
229
230        output
231    }
232
233    /// Performs an inverse FFT on the given input array using a scratch buffer.
234    ///
235    /// This method computes the inverse Fast Fourier Transform of the input array using
236    /// a complex-to-real FFT in the last dimension and complex-to-complex FFTs in the other dimensions.
237    /// It uses the given scratch buffer for FFT calculations, potentially improving performance
238    /// for multiple FFT calls.
239    ///
240    /// # Arguments
241    ///
242    /// *   `input`: The input array.
243    pub fn backward_with_scratch<const N: usize>(
244        &mut self,
245        mut input: Array<Complex<T>, Dim<[Ix; N]>>,
246        scratch: &mut Vec<Complex<T>>,
247    ) -> Array<T, Dim<[Ix; N]>>
248    where
249        Dim<[Ix; N]>: RemoveAxis,
250        [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>,
251    {
252        // at this time, the raw_dim has been routate_left by N - 1 times
253        let mut raw_dim: [usize; N] = std::array::from_fn(|i| input.raw_dim()[i]);
254
255        let rp = self.rp.plan_fft_inverse(self.rp_origin_len);
256
257        let mut axes: [usize; N] = std::array::from_fn(|i| i);
258        axes.rotate_left(1);
259        for _ in 0..N - 1 {
260            let cp = self.cp.plan_fft_inverse(raw_dim[N - 1]);
261            cp.process_with_scratch(input.as_slice_mut().unwrap(), scratch);
262
263            raw_dim.rotate_left(1);
264
265            let mut buffer = Array::uninit(raw_dim.into_dimension());
266            buffer.zip_mut_with(&input.permuted_axes(axes), |transpose, &origin| {
267                transpose.write(origin);
268            });
269            input = unsafe { buffer.assume_init() };
270        }
271
272        let mut output_shape = input.raw_dim();
273        output_shape[N - 1] = self.rp_origin_len;
274        let mut output = Array::zeros(output_shape);
275
276        for (mut input, mut output) in input.rows_mut().into_iter().zip(output.rows_mut()) {
277            let _ = rp.process_with_scratch(
278                input.as_slice_mut().unwrap(),
279                output.as_slice_mut().unwrap(),
280                scratch,
281            );
282        }
283
284        let len = T::from_usize(output.len()).unwrap();
285        output.map_mut(|x| *x = x.div(len));
286        output
287    }
288}
289
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{array, Axis};

    /// Asserts that two real arrays have the same shape and agree element-wise
    /// within `tol`.
    fn assert_all_close<D: ndarray::Dimension>(
        actual: &Array<f64, D>,
        expected: &Array<f64, D>,
        tol: f64,
    ) {
        assert_eq!(actual.shape(), expected.shape());
        for (a, e) in actual.iter().zip(expected.iter()) {
            assert!((a - e).abs() <= tol, "{a} != {e} (tolerance {tol})");
        }
    }

    #[test]
    fn index_axis() {
        let a = array![[1, 2, 3], [4, 5, 6]];

        assert_eq!(a.index_axis(Axis(0), 1), array![4, 5, 6]);
        assert_eq!(a.index_axis(Axis(1), 2), array![3, 6]);
    }

    #[test]
    fn transpose() {
        let original = array![
            [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
            [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]
        ];
        let mut raw_dim: [usize; 3] = std::array::from_fn(|i| original.shape()[i]);

        let mut axes = [0, 1, 2];
        axes.rotate_right(1);

        // One rotation moves the last axis to the front.
        raw_dim.rotate_right(1);
        let mut a = Array::from_shape_vec(
            raw_dim,
            original.clone().permuted_axes(axes).iter().copied().collect(),
        )
        .unwrap();
        assert_eq!(a.shape(), raw_dim.as_slice());
        assert_eq!(a[[3, 1, 2]], original[[1, 2, 3]]);

        // Two more rotations complete the cycle and restore the original array.
        for _ in 0..2 {
            raw_dim.rotate_right(1);
            a = Array::from_shape_vec(raw_dim, a.permuted_axes(axes).iter().copied().collect())
                .unwrap();
            assert_eq!(a.shape(), raw_dim.as_slice());
        }
        assert_eq!(a, original);
    }

    #[test]
    fn test_forward_backward() {
        let original = array![
            [[1., 2., 3., 4.], [5., 6., 7., 8.], [9., 10., 11., 12.]],
            [[13., 14., 15., 16.], [17., 18., 19., 20.], [21., 22., 23., 24.]]
        ];
        let mut p = Processor::<f64>::default();

        // `forward` may clobber its input, so transform a copy.
        let a_fft = p.forward(&mut original.clone());

        // Last axis shrinks to 4 / 2 + 1 = 3, then all axes rotate left by one.
        assert_eq!(a_fft.shape(), &[3, 3, 2]);
        // The zero-frequency bin holds the sum of all samples (1 + ... + 24).
        assert!((a_fft[[0, 0, 0]] - Complex::new(300.0, 0.0)).norm() < 1e-9);

        let roundtrip = p.backward(a_fft);
        assert_all_close(&roundtrip, &original, 1e-9);
    }

    #[test]
    fn test_forward_backward_low_rank() {
        let mut p = Processor::<f64>::default();

        let line = array![1., -2., 3., 0.5, 7.];
        let spectrum = p.forward(&mut line.clone());
        assert_eq!(spectrum.shape(), &[3]);
        let roundtrip = p.backward(spectrum);
        assert_all_close(&roundtrip, &line, 1e-9);

        let plane = array![
            [1., 2., 3., 4., 5.],
            [-1., 0., 1., 0., -1.],
            [2., 2., 2., 2., 2.]
        ];
        let spectrum = p.forward(&mut plane.clone());
        assert_eq!(spectrum.shape(), &[3, 3]);
        let roundtrip = p.backward(spectrum);
        assert_all_close(&roundtrip, &plane, 1e-9);
    }

    #[test]
    fn test_forward_backward_with_scratch() {
        let original = array![[1., 2., 3., 4., 5.], [6., 7., 8., 9., 10.]];
        let mut p = Processor::<f64>::default();
        let mut scratch = p.get_scratch([2, 5]);

        let spectrum = p.forward_with_scratch(&mut original.clone(), &mut scratch);
        let roundtrip = p.backward_with_scratch(spectrum, &mut scratch);

        assert_all_close(&roundtrip, &original, 1e-9);
    }

    #[test]
    fn test_forward_backward_complex() {
        let expected = array![[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]]
            .map(|&v| Complex::new(v as f32, 0.0));
        let mut arr = expected.clone();
        let mut fft = rustfft::FftPlanner::new();

        // forward
        let row_forward = fft.plan_fft_forward(arr.shape()[1]);
        for mut row in arr.rows_mut() {
            row_forward.process(row.as_slice_mut().unwrap());
        }

        // transpose
        let mut arr = Array::from_shape_vec(
            [arr.shape()[1], arr.shape()[0]],
            arr.permuted_axes([1, 0]).iter().copied().collect(),
        )
        .unwrap();

        let row_forward = fft.plan_fft_forward(arr.shape()[1]);
        for mut row in arr.rows_mut() {
            row_forward.process(row.as_slice_mut().unwrap());
        }

        // rustfft does not normalise; divide by the 4 * 4 samples.
        arr /= Complex::new(16.0, 0.0);

        // backward
        let row_backward = fft.plan_fft_inverse(arr.shape()[1]);
        for mut row in arr.rows_mut() {
            row_backward.process(row.as_slice_mut().unwrap());
        }

        // transpose
        let mut arr = Array::from_shape_vec(
            [arr.shape()[1], arr.shape()[0]],
            arr.permuted_axes([1, 0]).iter().copied().collect(),
        )
        .unwrap();

        let row_backward = fft.plan_fft_inverse(arr.shape()[1]);
        for mut row in arr.rows_mut() {
            row_backward.process(row.as_slice_mut().unwrap());
        }

        assert_eq!(arr.shape(), expected.shape());
        for (a, e) in arr.iter().zip(expected.iter()) {
            assert!((*a - *e).norm() < 1e-4, "{a} != {e}");
        }
    }
}