rustfft/
array_utils.rs

1use crate::common::RadixFactor;
2use crate::Complex;
3use crate::FftNum;
4use std::ops::{Deref, DerefMut};
5
/// Transposes a flattened 2D array of `width * height` elements from `input` into `output`.
///
/// The element at column `x`, row `y` is read from `input[x + y * width]` and written to
/// `output[y + x * height]`.
///
/// No loop tiling is done here: benchmarking shows that tiling isn't effective for small
/// arrays (in the range of 50x50 or smaller).
///
/// # Safety
///
/// Every access is unchecked. Both `input` and `output` must contain at least
/// `width * height` elements.
pub unsafe fn transpose_small<T: Copy>(width: usize, height: usize, input: &[T], output: &mut [T]) {
    for col in 0..width {
        // All elements of one input column land contiguously in the output, starting here
        let out_base = col * height;
        for row in 0..height {
            let source = input.get_unchecked(col + row * width);
            *output.get_unchecked_mut(out_base + row) = *source;
        }
    }
}
19
/// Reinterprets a shared slice of `T` as a shared slice of `U`, keeping the same start
/// pointer and the same element count.
///
/// # Safety
///
/// The element count is carried over unchanged, so the caller must guarantee that the
/// memory behind `slice` is valid when read as `slice.len()` values of `U` (compatible
/// size, alignment and bit patterns), as required by `std::slice::from_raw_parts`.
#[allow(unused)]
pub unsafe fn workaround_transmute<T, U>(slice: &[T]) -> &[U] {
    std::slice::from_raw_parts(slice.as_ptr().cast::<U>(), slice.len())
}
/// Reinterprets a mutable slice of `T` as a mutable slice of `U`, keeping the same start
/// pointer and the same element count.
///
/// # Safety
///
/// The element count is carried over unchanged, so the caller must guarantee that the
/// memory behind `slice` is valid when accessed as `slice.len()` values of `U` (compatible
/// size, alignment and bit patterns), as required by `std::slice::from_raw_parts_mut`.
#[allow(unused)]
pub unsafe fn workaround_transmute_mut<T, U>(slice: &mut [T]) -> &mut [U] {
    let element_count = slice.len();
    std::slice::from_raw_parts_mut(slice.as_mut_ptr().cast::<U>(), element_count)
}
32
33pub(crate) trait LoadStore<T: FftNum>: DerefMut {
34    unsafe fn load(&self, idx: usize) -> Complex<T>;
35    unsafe fn store(&mut self, val: Complex<T>, idx: usize);
36}
37
38impl<T: FftNum> LoadStore<T> for &mut [Complex<T>] {
39    #[inline(always)]
40    unsafe fn load(&self, idx: usize) -> Complex<T> {
41        debug_assert!(idx < self.len());
42        *self.get_unchecked(idx)
43    }
44    #[inline(always)]
45    unsafe fn store(&mut self, val: Complex<T>, idx: usize) {
46        debug_assert!(idx < self.len());
47        *self.get_unchecked_mut(idx) = val;
48    }
49}
50impl<T: FftNum, const N: usize> LoadStore<T> for &mut [Complex<T>; N] {
51    #[inline(always)]
52    unsafe fn load(&self, idx: usize) -> Complex<T> {
53        debug_assert!(idx < self.len());
54        *self.get_unchecked(idx)
55    }
56    #[inline(always)]
57    unsafe fn store(&mut self, val: Complex<T>, idx: usize) {
58        debug_assert!(idx < self.len());
59        *self.get_unchecked_mut(idx) = val;
60    }
61}
62
63pub(crate) struct DoubleBuf<'a, T> {
64    pub input: &'a [Complex<T>],
65    pub output: &'a mut [Complex<T>],
66}
67impl<'a, T> Deref for DoubleBuf<'a, T> {
68    type Target = [Complex<T>];
69    fn deref(&self) -> &Self::Target {
70        self.input
71    }
72}
73impl<'a, T> DerefMut for DoubleBuf<'a, T> {
74    fn deref_mut(&mut self) -> &mut Self::Target {
75        self.output
76    }
77}
78impl<'a, T: FftNum> LoadStore<T> for DoubleBuf<'a, T> {
79    #[inline(always)]
80    unsafe fn load(&self, idx: usize) -> Complex<T> {
81        debug_assert!(idx < self.input.len());
82        *self.input.get_unchecked(idx)
83    }
84    #[inline(always)]
85    unsafe fn store(&mut self, val: Complex<T>, idx: usize) {
86        debug_assert!(idx < self.output.len());
87        *self.output.get_unchecked_mut(idx) = val;
88    }
89}
90
91pub(crate) trait Load<T: FftNum>: Deref {
92    unsafe fn load(&self, idx: usize) -> Complex<T>;
93}
94
95impl<T: FftNum> Load<T> for &[Complex<T>] {
96    #[inline(always)]
97    unsafe fn load(&self, idx: usize) -> Complex<T> {
98        debug_assert!(idx < self.len());
99        *self.get_unchecked(idx)
100    }
101}
102impl<T: FftNum, const N: usize> Load<T> for &[Complex<T>; N] {
103    #[inline(always)]
104    unsafe fn load(&self, idx: usize) -> Complex<T> {
105        debug_assert!(idx < self.len());
106        *self.get_unchecked(idx)
107    }
108}
109
#[cfg(test)]
mod unit_tests {
    use super::*;
    use crate::test_utils::random_signal;
    use num_complex::Complex;
    use num_traits::Zero;

    /// Runs `transpose_small` on every width/height combination in `1..16` and checks each
    /// element against the documented index mapping.
    #[test]
    fn test_transpose() {
        for width in 1..16usize {
            for height in 1..16usize {
                let len = width * height;

                let input: Vec<Complex<f32>> = random_signal(len);
                let mut output: Vec<Complex<f32>> = vec![Complex::zero(); len];

                unsafe { transpose_small(width, height, &input, &mut output) };

                for x in 0..width {
                    for y in 0..height {
                        let expected = input[x + y * width];
                        let actual = output[y + x * height];
                        assert_eq!(expected, actual, "x = {}, y = {}", x, y);
                    }
                }
            }
        }
    }
}
145
/// Loops over exact chunks of the provided buffer, calling `chunk_fn` on each one.
///
/// Very similar in semantics to `ChunksExactMut`, but generates smaller code and requires
/// no modulo operations.
///
/// Returns `Ok(())` if every element ended up in a chunk, and `Err(())` if there was a
/// remainder. The remainder is never passed to `chunk_fn`.
///
/// A `chunk_size` of zero can never consume any element, so `chunk_fn` is not called at
/// all in that case: the result is `Ok(())` for an empty buffer and `Err(())` otherwise.
pub fn iter_chunks<T>(
    mut buffer: &mut [T],
    chunk_size: usize,
    mut chunk_fn: impl FnMut(&mut [T]),
) -> Result<(), ()> {
    // Without this guard the loop below would never terminate: `len >= 0` is always true
    // and splitting off zero elements never shrinks the buffer.
    if chunk_size == 0 {
        return if buffer.is_empty() { Ok(()) } else { Err(()) };
    }

    // Loop over the buffer, splitting off chunk_size at a time, and calling chunk_fn on each
    while buffer.len() >= chunk_size {
        let (head, tail) = buffer.split_at_mut(chunk_size);
        buffer = tail;

        chunk_fn(head);
    }

    // Data still left in the buffer is an unwanted remainder, which we report to the caller
    if buffer.is_empty() {
        Ok(())
    } else {
        Err(())
    }
}
168
/// Loops over exact zipped chunks of the two provided buffers, calling `chunk_fn` on each
/// pair of chunks.
///
/// Very similar in semantics to `ChunksExactMut.zip(ChunksExactMut)`, but generates smaller
/// code and requires no modulo operations.
///
/// Returns `Ok(())` if every element of both buffers ended up in a chunk. Returns `Err(())`
/// if the buffers have different lengths or if there was a remainder. When the lengths
/// differ, the chunks that fit into the shorter length are still processed before the
/// error is returned.
///
/// A `chunk_size` of zero can never consume any element, so `chunk_fn` is not called at
/// all in that case: the result is `Ok(())` only if both buffers are empty.
pub fn iter_chunks_zipped<T>(
    mut buffer1: &mut [T],
    mut buffer2: &mut [T],
    chunk_size: usize,
    mut chunk_fn: impl FnMut(&mut [T], &mut [T]),
) -> Result<(), ()> {
    // If the two buffers aren't the same size, record the fact that they're different,
    // then snip both to the shorter length. Trimming whichever side is longer means a
    // longer second buffer is detected as well as a longer first buffer.
    let uneven = buffer1.len() != buffer2.len();
    let common_len = buffer1.len().min(buffer2.len());
    buffer1 = &mut buffer1[..common_len];
    buffer2 = &mut buffer2[..common_len];

    // Without this guard the loop below would never terminate: `len >= 0` is always true
    // and splitting off zero elements never shrinks the buffers.
    if chunk_size == 0 {
        return if !uneven && common_len == 0 {
            Ok(())
        } else {
            Err(())
        };
    }

    // Both slices now have the same length, so checking one of them is enough.
    // Split off chunk_size at a time from each, and call chunk_fn on the pair.
    while buffer1.len() >= chunk_size {
        let (head1, tail1) = buffer1.split_at_mut(chunk_size);
        buffer1 = tail1;

        let (head2, tail2) = buffer2.split_at_mut(chunk_size);
        buffer2 = tail2;

        chunk_fn(head1, head2);
    }

    // We have a remainder if the two buffers were uneven to start with, or if there's
    // still data left over, and in either case we report that to the caller
    if !uneven && buffer1.is_empty() {
        Ok(())
    } else {
        Err(())
    }
}
206
207// Utility to help reorder data as a part of computing RadixD FFTs. Conceputally, it works like a transpose, but with the column indexes bit-reversed.
208// Use a lookup table to avoid repeating the slow bit reverse operations.
209// Unrolling the outer loop by a factor D helps speed things up.
210// const parameter D (for Divisor) determines the divisor to use for the "bit reverse", and how much to unroll. `input.len() / height` must be a power of D.
211pub fn bitreversed_transpose<T: Copy, const D: usize>(
212    height: usize,
213    input: &[T],
214    output: &mut [T],
215) {
216    let width = input.len() / height;
217
218    // Let's make sure the arguments are ok
219    assert!(D > 1 && input.len() % height == 0 && input.len() == output.len());
220
221    let strided_width = width / D;
222    let rev_digits = if D.is_power_of_two() {
223        let width_bits = width.trailing_zeros();
224        let d_bits = D.trailing_zeros();
225
226        // verify that width is a power of d
227        assert!(width_bits % d_bits == 0);
228        width_bits / d_bits
229    } else {
230        compute_logarithm::<D>(width).unwrap()
231    };
232
233    for x in 0..strided_width {
234        let mut i = 0;
235        let x_fwd = [(); D].map(|_| {
236            let value = D * x + i;
237            i += 1;
238            value
239        }); // If we had access to rustc 1.63, we could use std::array::from_fn instead
240        let x_rev = x_fwd.map(|x| reverse_bits::<D>(x, rev_digits));
241
242        // Assert that the the bit reversed indices will not exceed the length of the output.
243        // The highest index the loop reaches is: (x_rev[n] + 1)*height - 1
244        // The last element of the data is at index: width*height - 1
245        // Thus it is sufficient to assert that x_rev[n]<width.
246        for r in x_rev {
247            assert!(r < width);
248        }
249        for y in 0..height {
250            for (fwd, rev) in x_fwd.iter().zip(x_rev.iter()) {
251                let input_index = *fwd + y * width;
252                let output_index = y + *rev * height;
253
254                unsafe {
255                    let temp = *input.get_unchecked(input_index);
256                    *output.get_unchecked_mut(output_index) = temp;
257                }
258            }
259        }
260    }
261}
262
/// Reverses the lowest `rev_digits` base-`D` digits of `value`.
///
/// `value` is divided by `D` exactly `rev_digits` times, and each remainder is appended as
/// the next (less significant) digit of the result. Any higher digits of `value` are dropped.
///
/// When `D` is a power of 2, this is exactly equal (implementation and assembly-wise) to a
/// bit reversal. When `D` is not a power of 2, think of this function as a logical
/// equivalent to a bit reversal.
///
/// # Panics
///
/// Panics if `D < 2`.
pub fn reverse_bits<const D: usize>(value: usize, rev_digits: u32) -> usize {
    assert!(D > 1);

    // State is (digits reversed so far, digits of `value` still to be consumed)
    let (reversed, _unconsumed) = (0..rev_digits).fold((0usize, value), |(acc, rest), _| {
        (acc * D + rest % D, rest / D)
    });
    reversed
}
277
/// Computes `n` such that `D ^ n == value`.
///
/// Returns `Some(n)` if `value` is a perfect power of `D` (this includes `value == 1`, for
/// which `n` is 0). Returns `None` if it is not, and also if `value` is zero or `D < 2`.
pub fn compute_logarithm<const D: usize>(value: usize) -> Option<u32> {
    if D < 2 || value == 0 {
        return None;
    }

    // Strip factors of D one at a time, counting how many were removed
    let mut exponent: u32 = 0;
    let mut remaining = value;
    loop {
        if remaining % D != 0 {
            break;
        }
        remaining /= D;
        exponent += 1;
    }

    // A perfect power of D is reduced all the way down to 1
    match remaining {
        1 => Some(exponent),
        _ => None,
    }
}
298
299pub(crate) struct TransposeFactor {
300    pub factor: RadixFactor,
301    pub count: u8,
302}
303
304// Utility to help reorder data as a part of computing RadixD FFTs. Conceputally, it works like a transpose, but with the column indexes bit-reversed.
305// Use a lookup table to avoid repeating the slow bit reverse operations.
306// Unrolling the outer loop by a factor D helps speed things up.
307// const parameter D (for Divisor) determines how much to unroll. `input.len() / height` must divisible by D.
308pub(crate) fn factor_transpose<T: Copy, const D: usize>(
309    height: usize,
310    input: &[T],
311    output: &mut [T],
312    factors: &[TransposeFactor],
313) {
314    let width = input.len() / height;
315
316    // Let's make sure the arguments are ok
317    assert!(width % D == 0 && D > 1 && input.len() % width == 0 && input.len() == output.len());
318
319    let strided_width = width / D;
320    for x in 0..strided_width {
321        let mut i = 0;
322        let x_fwd = [(); D].map(|_| {
323            let value = D * x + i;
324            i += 1;
325            value
326        }); // If we had access to rustc 1.63, we could use std::array::from_fn instead
327        let x_rev = x_fwd.map(|x| reverse_remainders(x, factors));
328
329        // Assert that the the bit reversed indices will not exceed the length of the output.
330        // The highest index the loop reaches is: (x_rev[n] + 1)*height - 1
331        // The last element of the data is at index: width*height - 1
332        // Thus it is sufficient to assert that x_rev[n]<width.
333        for r in x_rev {
334            assert!(r < width);
335        }
336        for y in 0..height {
337            for (fwd, rev) in x_fwd.iter().zip(x_rev.iter()) {
338                let input_index = *fwd + y * width;
339                let output_index = y + *rev * height;
340
341                unsafe {
342                    let temp = *input.get_unchecked(input_index);
343                    *output.get_unchecked_mut(output_index) = temp;
344                }
345            }
346        }
347    }
348}
349
350// Divide `value` by the provided array of factors, and push the remainders into a new number
351// When all of the provided factors are 2, this is exactly equal to a bit reversal
352// When some of the factors are not 2, think of this as a "generalization" of a bit reversal, to something like a "Remainder reversal".
353pub(crate) fn reverse_remainders(value: usize, factors: &[TransposeFactor]) -> usize {
354    let mut result: usize = 0;
355    let mut value = value;
356    for f in factors.iter() {
357        match f.factor {
358            RadixFactor::Factor2 => {
359                for _ in 0..f.count {
360                    result = (result * 2) + (value % 2);
361                    value = value / 2;
362                }
363            }
364            RadixFactor::Factor3 => {
365                for _ in 0..f.count {
366                    result = (result * 3) + (value % 3);
367                    value = value / 3;
368                }
369            }
370            RadixFactor::Factor4 => {
371                for _ in 0..f.count {
372                    result = (result * 4) + (value % 4);
373                    value = value / 4;
374                }
375            }
376            RadixFactor::Factor5 => {
377                for _ in 0..f.count {
378                    result = (result * 5) + (value % 5);
379                    value = value / 5;
380                }
381            }
382            RadixFactor::Factor6 => {
383                for _ in 0..f.count {
384                    result = (result * 6) + (value % 6);
385                    value = value / 6;
386                }
387            }
388            RadixFactor::Factor7 => {
389                for _ in 0..f.count {
390                    result = (result * 7) + (value % 7);
391                    value = value / 7;
392                }
393            }
394        }
395    }
396    result
397}