rustfft/algorithm/
radix4.rs

1use std::sync::Arc;
2
3use num_complex::Complex;
4
5use crate::algorithm::butterflies::{
6    Butterfly1, Butterfly16, Butterfly2, Butterfly32, Butterfly4, Butterfly8,
7};
8use crate::algorithm::radixn::butterfly_4;
9use crate::array_utils::bitreversed_transpose;
10use crate::{common::FftNum, twiddles, FftDirection};
11use crate::{Direction, Fft, Length};
12
/// FFT algorithm optimized for power-of-two sizes
///
/// [`Radix4::new`] accepts any power-of-two length. [`Radix4::new_with_base`] builds an instance of length
/// `4^k * base_fft.len()` on top of a caller-supplied base FFT.
///
/// ~~~
/// // Computes a forward FFT of size 4096
/// use rustfft::algorithm::Radix4;
/// use rustfft::{Fft, FftDirection};
/// use rustfft::num_complex::Complex;
///
/// let mut buffer = vec![Complex{ re: 0.0f32, im: 0.0f32 }; 4096];
///
/// let fft = Radix4::new(4096, FftDirection::Forward);
/// fft.process(&mut buffer);
/// ~~~
pub struct Radix4<T> {
    /// Twiddle factors for every cross-FFT layer, packed back to back starting with the bottom (smallest) layer.
    twiddles: Box<[Complex<T>]>,

    /// FFT applied to the reordered data before any radix-4 cross-FFT layer runs.
    base_fft: Arc<dyn Fft<T>>,
    /// Cached `base_fft.len()`.
    base_len: usize,

    /// Total FFT length: `base_len * 4^k`.
    len: usize,
    /// Direction of the transform, taken from `base_fft`.
    direction: FftDirection,
    /// Scratch length reported for in-place processing (computed in `new_with_base`).
    inplace_scratch_len: usize,
    /// Scratch length reported for out-of-place processing; zero when the input buffer is large enough to
    /// double as scratch for `base_fft`.
    outofplace_scratch_len: usize,
    /// Scratch length reported for immutable-input processing; equal to `base_fft`'s in-place scratch length.
    immut_scratch_len: usize,
}
39
40impl<T: FftNum> Radix4<T> {
41    /// Preallocates necessary arrays and precomputes necessary data to efficiently compute the power-of-two FFT
42    pub fn new(len: usize, direction: FftDirection) -> Self {
43        assert!(
44            len.is_power_of_two(),
45            "Radix4 algorithm requires a power-of-two input size. Got {}",
46            len
47        );
48
49        // figure out which base length we're going to use
50        let exponent = len.trailing_zeros();
51        let (base_exponent, base_fft) = match exponent {
52            0 => (0, Arc::new(Butterfly1::new(direction)) as Arc<dyn Fft<T>>),
53            1 => (1, Arc::new(Butterfly2::new(direction)) as Arc<dyn Fft<T>>),
54            2 => (2, Arc::new(Butterfly4::new(direction)) as Arc<dyn Fft<T>>),
55            3 => (3, Arc::new(Butterfly8::new(direction)) as Arc<dyn Fft<T>>),
56            _ => {
57                if exponent % 2 == 1 {
58                    (5, Arc::new(Butterfly32::new(direction)) as Arc<dyn Fft<T>>)
59                } else {
60                    (4, Arc::new(Butterfly16::new(direction)) as Arc<dyn Fft<T>>)
61                }
62            }
63        };
64
65        Self::new_with_base((exponent - base_exponent) / 2, base_fft)
66    }
67
68    /// Constructs a Radix4 instance which computes FFTs of length `4^k * base_fft.len()`
69    pub fn new_with_base(k: u32, base_fft: Arc<dyn Fft<T>>) -> Self {
70        let base_len = base_fft.len();
71        let len = base_len * (1 << (k * 2));
72
73        let direction = base_fft.fft_direction();
74
75        // precompute the twiddle factors this algorithm will use.
76        // we're doing the same precomputation of twiddle factors as the mixed radix algorithm where width=4 and height=len/4
77        // but mixed radix only does one step and then calls itself recusrively, and this algorithm does every layer all the way down
78        // so we're going to pack all the "layers" of twiddle factors into a single array, starting with the bottom layer and going up
79        const ROW_COUNT: usize = 4;
80        let mut cross_fft_len = base_len;
81        let mut twiddle_factors = Vec::with_capacity(len * 2);
82        while cross_fft_len < len {
83            let num_columns = cross_fft_len;
84            cross_fft_len *= ROW_COUNT;
85
86            for i in 0..num_columns {
87                for k in 1..ROW_COUNT {
88                    let twiddle = twiddles::compute_twiddle(i * k, cross_fft_len, direction);
89                    twiddle_factors.push(twiddle);
90                }
91            }
92        }
93
94        let base_inplace_scratch = base_fft.get_inplace_scratch_len();
95        let inplace_scratch_len = if base_inplace_scratch > cross_fft_len {
96            cross_fft_len + base_inplace_scratch
97        } else {
98            cross_fft_len
99        };
100        let outofplace_scratch_len = if base_inplace_scratch > len {
101            base_inplace_scratch
102        } else {
103            0
104        };
105
106        Self {
107            twiddles: twiddle_factors.into_boxed_slice(),
108
109            base_fft,
110            base_len,
111
112            len,
113            direction,
114
115            inplace_scratch_len,
116            outofplace_scratch_len,
117            immut_scratch_len: base_inplace_scratch,
118        }
119    }
120
121    fn inplace_scratch_len(&self) -> usize {
122        self.inplace_scratch_len
123    }
124    fn outofplace_scratch_len(&self) -> usize {
125        self.outofplace_scratch_len
126    }
127    fn immut_scratch_len(&self) -> usize {
128        self.immut_scratch_len
129    }
130
131    fn perform_fft_immut(
132        &self,
133        input: &[Complex<T>],
134        output: &mut [Complex<T>],
135        scratch: &mut [Complex<T>],
136    ) {
137        // copy the data into the output vector
138        if self.len() == self.base_len {
139            output.copy_from_slice(input);
140        } else {
141            bitreversed_transpose::<Complex<T>, 4>(self.base_len, input, output);
142        }
143
144        self.base_fft.process_with_scratch(output, scratch);
145
146        // cross-FFTs
147        const ROW_COUNT: usize = 4;
148        let mut cross_fft_len = self.base_len;
149        let mut layer_twiddles: &[Complex<T>] = &self.twiddles;
150
151        let butterfly4 = Butterfly4::new(self.direction);
152
153        while cross_fft_len < output.len() {
154            let num_columns = cross_fft_len;
155            cross_fft_len *= ROW_COUNT;
156
157            for data in output.chunks_exact_mut(cross_fft_len) {
158                unsafe { butterfly_4(data, layer_twiddles, num_columns, &butterfly4) }
159            }
160
161            // skip past all the twiddle factors used in this layer
162            let twiddle_offset = num_columns * (ROW_COUNT - 1);
163            layer_twiddles = &layer_twiddles[twiddle_offset..];
164        }
165    }
166
167    fn perform_fft_out_of_place(
168        &self,
169        input: &mut [Complex<T>],
170        output: &mut [Complex<T>],
171        scratch: &mut [Complex<T>],
172    ) {
173        // copy the data into the output vector
174        if self.len() == self.base_len {
175            output.copy_from_slice(input);
176        } else {
177            bitreversed_transpose::<Complex<T>, 4>(self.base_len, input, output);
178        }
179
180        // Base-level FFTs
181        let base_scratch = if scratch.len() > 0 { scratch } else { input };
182        self.base_fft.process_with_scratch(output, base_scratch);
183
184        // cross-FFTs
185        const ROW_COUNT: usize = 4;
186        let mut cross_fft_len = self.base_len;
187        let mut layer_twiddles: &[Complex<T>] = &self.twiddles;
188
189        let butterfly4 = Butterfly4::new(self.direction);
190
191        while cross_fft_len < output.len() {
192            let num_columns = cross_fft_len;
193            cross_fft_len *= ROW_COUNT;
194
195            for data in output.chunks_exact_mut(cross_fft_len) {
196                unsafe { butterfly_4(data, layer_twiddles, num_columns, &butterfly4) }
197            }
198
199            // skip past all the twiddle factors used in this layer
200            let twiddle_offset = num_columns * (ROW_COUNT - 1);
201            layer_twiddles = &layer_twiddles[twiddle_offset..];
202        }
203    }
204}
// NOTE(review): `boilerplate_fft_oop!` is defined elsewhere in the crate. It presumably generates the
// `Fft`/`Length`/`Direction` trait impls for `Radix4` on top of the private `perform_fft_*` and `*_scratch_len`
// methods, with the closure supplying the FFT length of an instance — confirm against the macro definition.
boilerplate_fft_oop!(Radix4, |this: &Radix4<_>| this.len);
206
#[cfg(test)]
mod unit_tests {
    use super::*;
    use crate::test_utils::{check_fft_algorithm, construct_base};

    /// Runs `check_fft_algorithm` on `Radix4::new` for every power-of-two length from 1 through 128,
    /// forward first and then inverse.
    #[test]
    fn test_radix4_with_length() {
        for len in (0..8).map(|pow| 1 << pow) {
            for &direction in &[FftDirection::Forward, FftDirection::Inverse] {
                let fft = Radix4::new(len, direction);
                check_fft_algorithm::<f32>(&fft, len, direction);
            }
        }
    }

    /// Runs `Radix4::new_with_base` over bases built by `construct_base(1..=9, ..)` with 0 to 3 radix-4 layers
    /// stacked on top, in both directions.
    #[test]
    fn test_radix4_with_base() {
        for base_len in 1..=9 {
            let forward_base = construct_base(base_len, FftDirection::Forward);
            let inverse_base = construct_base(base_len, FftDirection::Inverse);

            (0..4).for_each(|k| {
                test_radix4(k, Arc::clone(&forward_base));
                test_radix4(k, Arc::clone(&inverse_base));
            });
        }
    }

    /// Builds a `Radix4` with `k` layers over `base_fft` and checks it at length `base_fft.len() * 4^k`.
    fn test_radix4(k: u32, base_fft: Arc<dyn Fft<f64>>) {
        // read everything we need from the base before it is moved into the Radix4 instance
        let direction = base_fft.fft_direction();
        let expected_len = base_fft.len() * 4usize.pow(k);

        let fft = Radix4::new_with_base(k, base_fft);

        check_fft_algorithm::<f64>(&fft, expected_len, direction);
    }
}