1use ndarray::{Array, ArrayBase, DataMut, Dim, IntoDimension, Ix, RemoveAxis};
7use num::Complex;
8use rustfft::FftNum;
9
10pub struct Processor<T: FftNum> {
12 rp: realfft::RealFftPlanner<T>,
13 rp_origin_len: usize,
14 cp: rustfft::FftPlanner<T>,
15}
16
17impl<T: FftNum> Default for Processor<T> {
18 fn default() -> Self {
19 Self {
20 rp: Default::default(),
21 rp_origin_len: Default::default(),
22 cp: rustfft::FftPlanner::new(),
23 }
24 }
25}
26
27impl<T: FftNum> Processor<T> {
28 #[allow(clippy::uninit_vec)]
37 pub fn get_scratch<const N: usize>(&mut self, input_dim: [usize; N]) -> Vec<Complex<T>> {
38 let mut output_shape = input_dim;
40 let rp = self.rp.plan_fft_forward(output_shape[N - 1]);
41 let rp_len = rp.get_scratch_len();
42
43 output_shape[N - 1] = rp.complex_len();
44 let cp_len = output_shape
45 .iter()
46 .take(N - 1)
47 .map(|&dim| self.cp.plan_fft_forward(dim).get_inplace_scratch_len())
48 .max()
49 .unwrap_or(0);
50
51 let mut scratch = Vec::with_capacity(rp_len.max(cp_len));
53 unsafe { scratch.set_len(rp_len.max(cp_len)) };
54
55 scratch
56 }
57
58 pub fn forward<S: DataMut<Elem = T>, const N: usize>(
68 &mut self,
69 input: &mut ArrayBase<S, Dim<[Ix; N]>>,
70 ) -> Array<Complex<T>, Dim<[Ix; N]>>
71 where
72 Dim<[Ix; N]>: RemoveAxis,
73 [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>,
74 {
75 let raw_dim: [usize; N] = std::array::from_fn(|i| input.raw_dim()[i]);
76
77 let rp = self.rp.plan_fft_forward(raw_dim[N - 1]);
78 self.rp_origin_len = rp.len();
79
80 let mut output_shape = raw_dim;
81 output_shape[N - 1] = rp.complex_len();
82 let mut output = Array::zeros(output_shape);
83
84 for (mut input, mut output) in input.rows_mut().into_iter().zip(output.rows_mut()) {
85 rp.process(
86 input.as_slice_mut().unwrap(),
87 output.as_slice_mut().unwrap(),
88 )
89 .unwrap();
90 }
91
92 let mut axes: [usize; N] = std::array::from_fn(|i| i);
93 axes.rotate_right(1);
94 for _ in 0..N - 1 {
95 output_shape.rotate_right(1);
96
97 let mut buffer = Array::uninit(output_shape.into_dimension());
106 buffer.zip_mut_with(&output.permuted_axes(axes), |transpose, &origin| {
107 transpose.write(origin);
108 });
109 output = unsafe { buffer.assume_init() };
110
111 let cp = self.cp.plan_fft_forward(output_shape[N - 1]);
112 cp.process(output.as_slice_mut().unwrap());
113 }
114
115 output
116 }
117
118 pub fn backward<const N: usize>(
127 &mut self,
128 mut input: Array<Complex<T>, Dim<[Ix; N]>>,
129 ) -> Array<T, Dim<[Ix; N]>>
130 where
131 Dim<[Ix; N]>: RemoveAxis,
132 [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>,
133 {
134 let mut raw_dim: [usize; N] = std::array::from_fn(|i| input.raw_dim()[i]);
136
137 let rp = self.rp.plan_fft_inverse(self.rp_origin_len);
138
139 let mut axes: [usize; N] = std::array::from_fn(|i| i);
140 axes.rotate_left(1);
141 for _ in 0..N - 1 {
142 let cp = self.cp.plan_fft_inverse(raw_dim[N - 1]);
143 cp.process(input.as_slice_mut().unwrap());
144
145 raw_dim.rotate_left(1);
146
147 let mut buffer = Array::uninit(raw_dim.into_dimension());
148 buffer.zip_mut_with(&input.permuted_axes(axes), |transpose, &origin| {
149 transpose.write(origin);
150 });
151 input = unsafe { buffer.assume_init() };
152 }
153
154 let mut output_shape = input.raw_dim();
155 output_shape[N - 1] = self.rp_origin_len;
156 let mut output = Array::zeros(output_shape);
157
158 for (mut input, mut output) in input.rows_mut().into_iter().zip(output.rows_mut()) {
159 let _ = rp.process(
160 input.as_slice_mut().unwrap(),
161 output.as_slice_mut().unwrap(),
162 );
163 }
164
165 let len = T::from_usize(output.len()).unwrap();
166 output.map_mut(|x| *x = x.div(len));
167 output
168 }
169
170 pub fn forward_with_scratch<S: DataMut<Elem = T>, const N: usize>(
181 &mut self,
182 input: &mut ArrayBase<S, Dim<[Ix; N]>>,
183 scratch: &mut Vec<Complex<T>>,
184 ) -> Array<Complex<T>, Dim<[Ix; N]>>
185 where
186 Dim<[Ix; N]>: RemoveAxis,
187 [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>,
188 {
189 let raw_dim: [usize; N] = std::array::from_fn(|i| input.raw_dim()[i]);
190
191 let rp = self.rp.plan_fft_forward(raw_dim[N - 1]);
192 self.rp_origin_len = rp.len();
193
194 let mut output_shape = raw_dim;
195 output_shape[N - 1] = rp.complex_len();
196 let mut output = Array::zeros(output_shape);
197
198 for (mut input, mut output) in input.rows_mut().into_iter().zip(output.rows_mut()) {
199 rp.process_with_scratch(
200 input.as_slice_mut().unwrap(),
201 output.as_slice_mut().unwrap(),
202 scratch,
203 )
204 .unwrap();
205 }
206
207 let mut axes: [usize; N] = std::array::from_fn(|i| i);
208 axes.rotate_right(1);
209 for _ in 0..N - 1 {
210 output_shape.rotate_right(1);
211
212 let mut buffer = Array::uninit(output_shape.into_dimension());
221 buffer.zip_mut_with(&output.permuted_axes(axes), |transpose, &origin| {
222 transpose.write(origin);
223 });
224 output = unsafe { buffer.assume_init() };
225
226 let cp = self.cp.plan_fft_forward(output_shape[N - 1]);
227 cp.process_with_scratch(output.as_slice_mut().unwrap(), scratch);
228 }
229
230 output
231 }
232
233 pub fn backward_with_scratch<const N: usize>(
244 &mut self,
245 mut input: Array<Complex<T>, Dim<[Ix; N]>>,
246 scratch: &mut Vec<Complex<T>>,
247 ) -> Array<T, Dim<[Ix; N]>>
248 where
249 Dim<[Ix; N]>: RemoveAxis,
250 [Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>,
251 {
252 let mut raw_dim: [usize; N] = std::array::from_fn(|i| input.raw_dim()[i]);
254
255 let rp = self.rp.plan_fft_inverse(self.rp_origin_len);
256
257 let mut axes: [usize; N] = std::array::from_fn(|i| i);
258 axes.rotate_left(1);
259 for _ in 0..N - 1 {
260 let cp = self.cp.plan_fft_inverse(raw_dim[N - 1]);
261 cp.process_with_scratch(input.as_slice_mut().unwrap(), scratch);
262
263 raw_dim.rotate_left(1);
264
265 let mut buffer = Array::uninit(raw_dim.into_dimension());
266 buffer.zip_mut_with(&input.permuted_axes(axes), |transpose, &origin| {
267 transpose.write(origin);
268 });
269 input = unsafe { buffer.assume_init() };
270 }
271
272 let mut output_shape = input.raw_dim();
273 output_shape[N - 1] = self.rp_origin_len;
274 let mut output = Array::zeros(output_shape);
275
276 for (mut input, mut output) in input.rows_mut().into_iter().zip(output.rows_mut()) {
277 let _ = rp.process_with_scratch(
278 input.as_slice_mut().unwrap(),
279 output.as_slice_mut().unwrap(),
280 scratch,
281 );
282 }
283
284 let len = T::from_usize(output.len()).unwrap();
285 output.map_mut(|x| *x = x.div(len));
286 output
287 }
288}
289
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{array, Axis};

    #[test]
    fn index_axis() {
        let a = array![[1, 2, 3], [4, 5, 6]];

        assert_eq!(a.index_axis(Axis(0), 0), array![1, 2, 3]);
        assert_eq!(a.index_axis(Axis(0), 1), array![4, 5, 6]);
        assert_eq!(a.index_axis(Axis(1), 0), array![1, 4]);
        assert_eq!(a.index_axis(Axis(1), 1), array![2, 5]);
        assert_eq!(a.index_axis(Axis(1), 2), array![3, 6]);
    }

    /// Rotating the axes of a rank-3 array right by one (the permutation the forward
    /// transform applies between passes) three times must reproduce the original array.
    #[test]
    fn transpose() {
        let original = array![
            [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
            [[13, 14, 15, 16], [17, 18, 19, 20], [21, 22, 23, 24]]
        ];
        let mut shape: [usize; 3] = std::array::from_fn(|i| original.shape()[i]);
        let mut axes: [usize; 3] = [0, 1, 2];
        axes.rotate_right(1);

        let mut a = original.clone();
        for step in 0..3 {
            shape.rotate_right(1);
            a = Array::from_shape_vec(shape, a.permuted_axes(axes).iter().copied().collect())
                .unwrap();
            if step == 0 {
                // After one rotation: new[x, y, z] == old[y, z, x].
                assert_eq!(a.shape(), &[4, 2, 3]);
                assert_eq!(a[[3, 0, 2]], original[[0, 2, 3]]);
            }
        }
        assert_eq!(a, original);
    }

    #[test]
    fn test_forward_backward() {
        let original = array![
            [[1., 2., 3.], [4., 5., 6.]],
            [[7., 8., 9.], [10., 11., 12.]]
        ];
        // `forward` may use its input as scratch space, so transform a copy.
        let mut a = original.clone();
        let mut p = Processor::default();

        let a_fft = p.forward(&mut a);
        // The zero-frequency bin sits at the origin and equals the sum of all samples.
        assert!((a_fft[[0, 0, 0]] - Complex::new(78.0, 0.0)).norm() < 1e-9);

        let restored = p.backward(a_fft);
        assert_eq!(restored.shape(), original.shape());
        for (got, want) in restored.iter().zip(original.iter()) {
            assert!((got - want).abs() < 1e-9, "{} != {}", got, want);
        }
    }

    /// `forward` returns the input axes rotated left by one, with the real axis shortened to
    /// `len / 2 + 1` bins: `[2, 3, 4]` becomes `[3, 3, 2]`.
    #[test]
    fn forward_output_layout() {
        let original = Array::from_shape_fn((2, 3, 4), |(i, j, k)| (i * 12 + j * 4 + k) as f64);
        let mut a = original.clone();
        let mut p = Processor::default();

        let spectrum = p.forward(&mut a);
        assert_eq!(spectrum.shape(), &[3, 3, 2]);
        assert!((spectrum[[0, 0, 0]] - Complex::new(276.0, 0.0)).norm() < 1e-9);

        let restored = p.backward(spectrum);
        assert_eq!(restored.shape(), original.shape());
        for (got, want) in restored.iter().zip(original.iter()) {
            assert!((got - want).abs() < 1e-9, "{} != {}", got, want);
        }
    }

    #[test]
    fn test_forward_backward_with_scratch() {
        let original = Array::from_shape_fn((2, 3, 4), |(i, j, k)| (i * 12 + j * 4 + k) as f64);
        let mut a = original.clone();
        let mut p = Processor::default();
        let mut scratch = p.get_scratch([2, 3, 4]);

        let spectrum = p.forward_with_scratch(&mut a, &mut scratch);
        assert_eq!(spectrum.shape(), &[3, 3, 2]);
        assert!((spectrum[[0, 0, 0]] - Complex::new(276.0, 0.0)).norm() < 1e-9);

        let restored = p.backward_with_scratch(spectrum, &mut scratch);
        assert_eq!(restored.shape(), original.shape());
        for (got, want) in restored.iter().zip(original.iter()) {
            assert!((got - want).abs() < 1e-9, "{} != {}", got, want);
        }
    }

    /// Reference round trip done by hand with rustfft: row FFTs, transpose, row FFTs, scale,
    /// then the same steps inverted must reproduce the input.
    #[test]
    fn test_forward_backward_complex() {
        let original = array![[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]]
            .map(|&v| Complex::new(v as f32, 0.0));
        let mut fft = rustfft::FftPlanner::new();

        let mut arr = original.clone();
        let row_forward = fft.plan_fft_forward(arr.shape()[1]);
        for mut row in arr.rows_mut() {
            row_forward.process(row.as_slice_mut().unwrap());
        }

        let mut arr = Array::from_shape_vec(
            [arr.shape()[1], arr.shape()[0]],
            arr.permuted_axes([1, 0]).iter().copied().collect(),
        )
        .unwrap();

        let row_forward = fft.plan_fft_forward(arr.shape()[1]);
        for mut row in arr.rows_mut() {
            row_forward.process(row.as_slice_mut().unwrap());
        }

        // rustfft does not normalise; divide by the total element count (4 x 4).
        arr /= Complex::new(16.0, 0.0);

        let row_backward = fft.plan_fft_inverse(arr.shape()[1]);
        for mut row in arr.rows_mut() {
            row_backward.process(row.as_slice_mut().unwrap());
        }

        let mut arr = Array::from_shape_vec(
            [arr.shape()[1], arr.shape()[0]],
            arr.permuted_axes([1, 0]).iter().copied().collect(),
        )
        .unwrap();

        let row_backward = fft.plan_fft_inverse(arr.shape()[1]);
        for mut row in arr.rows_mut() {
            row_backward.process(row.as_slice_mut().unwrap());
        }

        assert_eq!(arr.shape(), original.shape());
        for (got, want) in arr.iter().zip(original.iter()) {
            assert!((got - want).norm() < 1e-5, "{:?} != {:?}", got, want);
        }
    }
}