Skip to main content

feanor_math/algorithms/fft/
cooley_tuckey.rs

1use std::alloc::{Allocator, Global};
2use std::fmt::Debug;
3use std::ops::Range;
4
5use super::complex_fft::*;
6use crate::algorithms::fft::*;
7use crate::algorithms::unity_root::*;
8use crate::divisibility::{DivisibilityRing, DivisibilityRingStore};
9use crate::homomorphism::*;
10use crate::rings::float_complex::*;
11use crate::rings::zn::*;
12use crate::seq::{SwappableVectorViewMut, VectorViewMut};
13
/// An optimized implementation of the Cooley-Tukey FFT algorithm, to compute
/// the Fourier transform of an array with power-of-two length.
///
/// # Example
/// ```rust
/// # use feanor_math::assert_el_eq;
/// # use feanor_math::ring::*;
/// # use feanor_math::algorithms::fft::*;
/// # use feanor_math::rings::zn::*;
/// # use feanor_math::algorithms::fft::cooley_tuckey::*;
/// // this ring has a 256-th primitive root of unity
/// let ring = zn_64::Zn::new(257);
/// let fft_table = CooleyTuckeyFFT::for_zn(ring, 8).unwrap();
/// let mut data = [ring.one()]
///     .into_iter()
///     .chain((0..255).map(|_| ring.zero()))
///     .collect::<Vec<_>>();
/// fft_table.unordered_fft(&mut data, &ring);
/// assert_el_eq!(ring, ring.one(), data[0]);
/// assert_el_eq!(ring, ring.one(), data[1]);
/// ```
///
/// # Convention
///
/// This implementation follows the standard convention for the mathematical
/// DFT, by performing the standard/forward FFT with the inverse root of unity `z^-1`.
/// In other words, the forward FFT computes
/// ```text
///   (a_0, ..., a_(n - 1)) -> (sum_j a_j z^(-ij))_i
/// ```
/// as demonstrated by
/// ```rust
/// # use feanor_math::assert_el_eq;
/// # use feanor_math::ring::*;
/// # use feanor_math::algorithms::fft::*;
/// # use feanor_math::rings::zn::*;
/// # use feanor_math::algorithms::fft::cooley_tuckey::*;
/// # use feanor_math::homomorphism::*;
/// # use feanor_math::divisibility::*;
/// // this ring has a 4-th primitive root of unity
/// let ring = zn_64::Zn::new(5);
/// let root_of_unity = ring.int_hom().map(2);
/// let fft_table = CooleyTuckeyFFT::new(ring, root_of_unity, 2);
/// let mut data = [ring.one(), ring.one(), ring.zero(), ring.zero()];
/// fft_table.fft(&mut data, ring);
/// let inv_root_of_unity = ring.invert(&root_of_unity).unwrap();
/// assert_el_eq!(ring, ring.add(ring.one(), inv_root_of_unity), data[1]);
/// ```
///
/// # On optimizations
///
/// I tried my best to make this as fast as possible in general, with special focus
/// on the Number-theoretic transform case. I did not implement the following
/// optimizations, for the following reasons:
///  - Larger butterflies: This would improve data locality, but decrease twiddle locality (or
///    increase arithmetic operation count). Since I focused mainly on the `Z/nZ` case, where the
///    twiddles are larger than the ring elements (since they have additional data to speed up
///    multiplications), this is not sensible.
///  - The same reasoning applies to a SplitRadix approach, which only actually decreases the total
///    number of operations if multiplication-by-`i` is free.
pub struct CooleyTuckeyFFT<R_main, R_twiddle, H, A = Global>
where
    R_main: ?Sized + RingBase,
    R_twiddle: ?Sized + RingBase,
    H: Homomorphism<R_twiddle, R_main>,
    A: Allocator,
{
    // maps the stored twiddle factors (elements of `R_twiddle`) into the main ring `R_main`
    hom: H,
    // the `2^log2_n`-th root of unity `z`, stored as an element of the main ring
    root_of_unity: R_main::Element,
    // the FFT has length `n = 2^log2_n`
    log2_n: usize,
    // stores the powers of `root_of_unity^-1` in special bitreversed order;
    // the list at index `log2_step` has `2^log2_step` entries
    root_of_unity_list: Vec<Vec<R_twiddle::Element>>,
    // stores the powers of `root_of_unity` in special bitreversed order;
    // the list at index `log2_step` has `2^log2_step` entries
    inv_root_of_unity_list: Vec<Vec<R_twiddle::Element>>,
    // allocator handed in on creation; see `create()` for its intended use
    allocator: A,
    // the inverse of `2`, as an element of the twiddle ring
    two_inv: R_twiddle::Element,
    // the inverse of `n = 2^log2_n`, as an element of the twiddle ring
    n_inv: R_twiddle::Element,
}
92
/// Reverses the order of the `bits` least significant bits of `index`.
///
/// The input `index` is expected to have no bits set above the lowest `bits`
/// bits. For `bits == 0`, the result is always `0`.
pub fn bitreverse(index: usize, bits: usize) -> usize {
    // Reversing the whole word moves the interesting bits to the top; shifting
    // them back down by the unused width leaves exactly the reversed low bits.
    let unused_width = usize::BITS - bits as u32;
    match index.reverse_bits().checked_shr(unused_width) {
        Some(reversed) => reversed,
        // a shift by the full word width (i.e. `bits == 0`) leaves nothing
        None => 0,
    }
}
99
100#[inline(never)]
101fn butterfly_loop<T, S, F>(
102    log2_n: usize,
103    data: &mut [T],
104    butterfly_range: Range<usize>,
105    stride_range: Range<usize>,
106    log2_step: usize,
107    twiddles: &[S],
108    butterfly: F,
109) where
110    F: Fn(&mut T, &mut T, &S) + Clone,
111{
112    assert_eq!(1 << log2_n, data.len());
113    assert!(log2_step < log2_n);
114
115    // the coefficients of a group of inputs have this distance to each other
116    let stride = 1 << (log2_n - log2_step - 1);
117    assert!(stride_range.start <= stride_range.end);
118    assert!(stride_range.end <= stride);
119
120    // how many butterflies we compute within each group
121    assert!(butterfly_range.start <= butterfly_range.end);
122    assert!(butterfly_range.end <= (1 << log2_step));
123    assert!(butterfly_range.end <= twiddles.len());
124
125    let current_data = &mut data[(stride_range.start + butterfly_range.start * 2 * stride)..];
126    let stride_range_len = stride_range.end - stride_range.start;
127
128    if stride == 1 && stride_range_len == 1 {
129        for (twiddle, butterfly_data) in twiddles[butterfly_range]
130            .iter()
131            .zip(current_data.as_chunks_mut::<2>().0.iter_mut())
132        {
133            let [a, b] = butterfly_data.each_mut();
134            butterfly(a, b, twiddle);
135        }
136    } else if stride_range_len >= 1 {
137        for (twiddle, butterfly_data) in twiddles[butterfly_range]
138            .iter()
139            .zip(current_data.chunks_mut(2 * stride))
140        {
141            let (first, second) = butterfly_data[..(stride + stride_range_len)].split_at_mut(stride);
142            let (first_chunks, first_rem) = first[..stride_range_len].as_chunks_mut::<4>();
143            let (second_chunks, second_rem) = second.as_chunks_mut::<4>();
144            for (a, b) in first_chunks.iter_mut().zip(second_chunks.iter_mut()) {
145                butterfly(&mut a[0], &mut b[0], twiddle);
146                butterfly(&mut a[1], &mut b[1], twiddle);
147                butterfly(&mut a[2], &mut b[2], twiddle);
148                butterfly(&mut a[3], &mut b[3], twiddle);
149            }
150            for (a, b) in first_rem.iter_mut().zip(second_rem.iter_mut()) {
151                butterfly(a, b, twiddle);
152            }
153        }
154    }
155}
156
157impl<R_main, H> CooleyTuckeyFFT<R_main, Complex64Base, H, Global>
158where
159    R_main: ?Sized + RingBase,
160    H: Homomorphism<Complex64Base, R_main>,
161{
162    /// Creates an [`CooleyTuckeyFFT`] for the complex field, using the given homomorphism
163    /// to connect the ring implementation for twiddles with the main ring implementation.
164    ///
165    /// This function is mainly provided for parity with other rings, since in the complex case
166    /// it currently does not make much sense to use a different homomorphism than the identity.
167    /// Hence, it is simpler to use [`CooleyTuckeyFFT::for_complex()`].
168    pub fn for_complex_with_hom(hom: H, log2_n: usize) -> Self {
169        let CC = *hom.domain().get_ring();
170        Self::new_with_pows_with_hom(hom, |i| CC.root_of_unity(i, 1 << log2_n), log2_n)
171    }
172}
173
174impl<R> CooleyTuckeyFFT<Complex64Base, Complex64Base, Identity<R>, Global>
175where
176    R: RingStore<Type = Complex64Base>,
177{
178    /// Creates an [`CooleyTuckeyFFT`] for the complex field.
179    pub fn for_complex(ring: R, log2_n: usize) -> Self { Self::for_complex_with_hom(ring.into_identity(), log2_n) }
180}
181
182impl<R> CooleyTuckeyFFT<R::Type, R::Type, Identity<R>, Global>
183where
184    R: RingStore,
185    R::Type: DivisibilityRing,
186{
187    /// Creates an [`CooleyTuckeyFFT`] for the given ring, using the given root of unity.
188    ///
189    /// Do not use this for approximate rings, as computing the powers of `root_of_unity`
190    /// will incur avoidable precision loss.
191    pub fn new(ring: R, root_of_unity: El<R>, log2_n: usize) -> Self {
192        Self::new_with_hom(ring.into_identity(), root_of_unity, log2_n)
193    }
194
195    /// Creates an [`CooleyTuckeyFFT`] for the given ring, using the passed function to
196    /// provide the necessary roots of unity.
197    ///
198    /// Concretely, `root_of_unity_pow(i)` should return `z^i`, where `z` is a `2^log2_n`-th
199    /// primitive root of unity.
200    pub fn new_with_pows<F>(ring: R, root_of_unity_pow: F, log2_n: usize) -> Self
201    where
202        F: FnMut(i64) -> El<R>,
203    {
204        Self::new_with_pows_with_hom(ring.into_identity(), root_of_unity_pow, log2_n)
205    }
206
207    /// Creates an [`CooleyTuckeyFFT`] for a prime field, assuming it has a characteristic
208    /// congruent to 1 modulo `2^log2_n`.
209    pub fn for_zn(ring: R, log2_n: usize) -> Option<Self>
210    where
211        R::Type: ZnRing,
212    {
213        Self::for_zn_with_hom(ring.into_identity(), log2_n)
214    }
215}
216
217impl<R_main, R_twiddle, H> CooleyTuckeyFFT<R_main, R_twiddle, H, Global>
218where
219    R_main: ?Sized + RingBase,
220    R_twiddle: ?Sized + RingBase + DivisibilityRing,
221    H: Homomorphism<R_twiddle, R_main>,
222{
223    /// Creates an [`CooleyTuckeyFFT`] for the given rings, using the given root of unity.
224    ///
225    /// Instead of a ring, this function takes a homomorphism `R -> S`. Twiddle factors that are
226    /// precomputed will be stored as elements of `R`, while the main FFT computations will be
227    /// performed in `S`. This allows both implicit ring conversions, and using patterns like
228    /// [`zn_64::ZnFastmul`] to precompute some data for better performance.
229    ///
230    /// Do not use this for approximate rings, as computing the powers of `root_of_unity`
231    /// will incur avoidable precision loss.
232    pub fn new_with_hom(hom: H, root_of_unity: R_twiddle::Element, log2_n: usize) -> Self {
233        let ring = hom.domain();
234        let root_of_unity_pow = |i: i64| {
235            if i >= 0 {
236                ring.pow(ring.clone_el(&root_of_unity), i as usize)
237            } else {
238                ring.invert(&ring.pow(ring.clone_el(&root_of_unity), (-i) as usize))
239                    .unwrap()
240            }
241        };
242        let result = CooleyTuckeyFFT::create(&hom, root_of_unity_pow, log2_n, Global);
243
244        return CooleyTuckeyFFT {
245            root_of_unity_list: result.root_of_unity_list,
246            inv_root_of_unity_list: result.inv_root_of_unity_list,
247            two_inv: result.two_inv,
248            n_inv: result.n_inv,
249            root_of_unity: result.root_of_unity,
250            log2_n: result.log2_n,
251            allocator: result.allocator,
252            hom,
253        };
254    }
255
256    /// Creates an [`CooleyTuckeyFFT`] for the given rings, using the given function to create
257    /// the necessary powers of roots of unity.
258    ///
259    /// Concretely, `root_of_unity_pow(i)` should return `z^i`, where `z` is a `2^log2_n`-th
260    /// primitive root of unity.
261    ///
262    /// Instead of a ring, this function takes a homomorphism `R -> S`. Twiddle factors that are
263    /// precomputed will be stored as elements of `R`, while the main FFT computations will be
264    /// performed in `S`. This allows both implicit ring conversions, and using patterns like
265    /// [`zn_64::ZnFastmul`] to precompute some data for better performance.
266    pub fn new_with_pows_with_hom<F>(hom: H, root_of_unity_pow: F, log2_n: usize) -> Self
267    where
268        F: FnMut(i64) -> R_twiddle::Element,
269    {
270        Self::create(hom, root_of_unity_pow, log2_n, Global)
271    }
272
273    /// Creates an [`CooleyTuckeyFFT`] for the given prime fields, assuming they have
274    /// a characteristic congruent to 1 modulo `2^log2_n`.
275    ///
276    /// Instead of a ring, this function takes a homomorphism `R -> S`. Twiddle factors that are
277    /// precomputed will be stored as elements of `R`, while the main FFT computations will be
278    /// performed in `S`. This allows both implicit ring conversions, and using patterns like
279    /// [`zn_64::ZnFastmul`] to precompute some data for better performance.
280    pub fn for_zn_with_hom(hom: H, log2_n: usize) -> Option<Self>
281    where
282        R_twiddle: ZnRing,
283    {
284        let root_of_unity = get_prim_root_of_unity_zn(hom.domain(), 1 << log2_n)?;
285        Some(Self::new_with_hom(hom, root_of_unity, log2_n))
286    }
287}
288
289impl<R_main, R_twiddle, H, A> PartialEq for CooleyTuckeyFFT<R_main, R_twiddle, H, A>
290where
291    R_main: ?Sized + RingBase,
292    R_twiddle: ?Sized + RingBase + DivisibilityRing,
293    H: Homomorphism<R_twiddle, R_main>,
294    A: Allocator,
295{
296    fn eq(&self, other: &Self) -> bool {
297        self.ring().get_ring() == other.ring().get_ring()
298            && self.log2_n == other.log2_n
299            && self
300                .ring()
301                .eq_el(self.root_of_unity(self.ring()), other.root_of_unity(self.ring()))
302    }
303}
304
305impl<R_main, R_twiddle, H, A> Debug for CooleyTuckeyFFT<R_main, R_twiddle, H, A>
306where
307    R_main: ?Sized + RingBase + Debug,
308    R_twiddle: ?Sized + RingBase + DivisibilityRing,
309    H: Homomorphism<R_twiddle, R_main>,
310    A: Allocator,
311{
312    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
313        f.debug_struct("CooleyTuckeyFFT")
314            .field("ring", &self.ring().get_ring())
315            .field("n", &(1 << self.log2_n))
316            .field("root_of_unity", &self.ring().format(self.root_of_unity(self.ring())))
317            .finish()
318    }
319}
320
321impl<R_main, R_twiddle, H, A> Clone for CooleyTuckeyFFT<R_main, R_twiddle, H, A>
322where
323    R_main: ?Sized + RingBase,
324    R_twiddle: ?Sized + RingBase + DivisibilityRing,
325    H: Homomorphism<R_twiddle, R_main> + Clone,
326    A: Allocator + Clone,
327{
328    fn clone(&self) -> Self {
329        Self {
330            two_inv: self.hom.domain().clone_el(&self.two_inv),
331            n_inv: self.hom.domain().clone_el(&self.n_inv),
332            hom: self.hom.clone(),
333            inv_root_of_unity_list: self
334                .inv_root_of_unity_list
335                .iter()
336                .map(|list| list.iter().map(|x| self.hom.domain().clone_el(x)).collect())
337                .collect(),
338            root_of_unity_list: self
339                .root_of_unity_list
340                .iter()
341                .map(|list| list.iter().map(|x| self.hom.domain().clone_el(x)).collect())
342                .collect(),
343            root_of_unity: self.hom.codomain().clone_el(&self.root_of_unity),
344            log2_n: self.log2_n,
345            allocator: self.allocator.clone(),
346        }
347    }
348}
349
/// A helper trait that defines the Cooley-Tukey butterfly operation.
/// It is default-implemented for all rings, but for increased FFT performance, some rings
/// might wish to provide a specialization.
///
/// The type parameter `S` is the ring in which the twiddle factors are stored; `Self` is
/// the ring in which the butterflies are computed.
///
/// # Why not a subtrait of [`Homomorphism`]?
///
/// With the current design, making this a subtrait of [`Homomorphism`] would
/// indeed be the conceptually most fitting choice. It would allow specializing on
/// the twiddle ring, the main ring and the inclusion.
///
/// Unfortunately, there is a technical issue: With the current `min_specialization`,
/// we can only specialize on concrete types. If this is a subtrait of [`Homomorphism`], this
/// means we can only specialize on, say, `CanHom<ZnFastmul, Zn>`, which then does not give a
/// specialization for `CanHom<&ZnFastmul, Zn>` - in other words, we would specialize on
/// the [`RingStore`], and not on the [`RingBase`] as we should. Hence, we'll keep this
/// suboptimal design until full specialization works.
pub trait CooleyTuckeyButterfly<S>: RingBase
where
    S: ?Sized + RingBase,
{
    /// Should compute `(values[i1], values[i2]) := (values[i1] + twiddle * values[i2], values[i1] -
    /// twiddle * values[i2])`.
    ///
    /// The twiddle factor is an element of `S`, and `hom` is used to multiply it with
    /// elements of this ring.
    ///
    /// It is guaranteed that the input elements are either outputs of
    /// [`CooleyTuckeyButterfly::butterfly()`] or of [`CooleyTuckeyButterfly::prepare_for_fft()`].
    ///
    /// Deprecated in favor of [`CooleyTuckeyButterfly::butterfly_new()`].
    #[deprecated]
    fn butterfly<V: VectorViewMut<Self::Element>, H: Homomorphism<S, Self>>(
        &self,
        hom: H,
        values: &mut V,
        twiddle: &S::Element,
        i1: usize,
        i2: usize,
    );

    /// Should compute `(values[i1], values[i2]) := (values[i1] + values[i2], (values[i1] -
    /// values[i2]) * twiddle)`
    ///
    /// The twiddle factor is an element of `S`, and `hom` is used to multiply it with
    /// elements of this ring.
    ///
    /// It is guaranteed that the input elements are either outputs of
    /// [`CooleyTuckeyButterfly::inv_butterfly()`] or of
    /// [`CooleyTuckeyButterfly::prepare_for_inv_fft()`].
    ///
    /// Deprecated in favor of [`CooleyTuckeyButterfly::inv_butterfly_new()`].
    #[deprecated]
    fn inv_butterfly<V: VectorViewMut<Self::Element>, H: Homomorphism<S, Self>>(
        &self,
        hom: H,
        values: &mut V,
        twiddle: &S::Element,
        i1: usize,
        i2: usize,
    );

    /// Should compute `(x, y) := (x + twiddle * y, x - twiddle * y)`.
    ///
    /// Unlike [`CooleyTuckeyButterfly::butterfly()`], this is an associated function
    /// without `self`; the ring is available as the codomain of `hom`.
    ///
    /// It is guaranteed that the input elements are either outputs of
    /// [`CooleyTuckeyButterfly::butterfly_new()`] or of
    /// [`CooleyTuckeyButterfly::prepare_for_fft()`].
    fn butterfly_new<H: Homomorphism<S, Self>>(
        hom: H,
        x: &mut Self::Element,
        y: &mut Self::Element,
        twiddle: &S::Element,
    );

    /// Should compute `(x, y) := (x + y, (x - y) * twiddle)`
    ///
    /// Unlike [`CooleyTuckeyButterfly::inv_butterfly()`], this is an associated function
    /// without `self`; the ring is available as the codomain of `hom`.
    ///
    /// It is guaranteed that the input elements are either outputs of
    /// [`CooleyTuckeyButterfly::inv_butterfly_new()`] or of
    /// [`CooleyTuckeyButterfly::prepare_for_inv_fft()`].
    fn inv_butterfly_new<H: Homomorphism<S, Self>>(
        hom: H,
        x: &mut Self::Element,
        y: &mut Self::Element,
        twiddle: &S::Element,
    );

    /// Possibly pre-processes elements before the FFT starts. Here you can
    /// bring ring elements into a certain form, and assume during
    /// [`CooleyTuckeyButterfly::butterfly_new()`] that the inputs are in this form.
    ///
    /// The default implementation does nothing.
    #[inline(always)]
    fn prepare_for_fft(&self, _value: &mut Self::Element) {}

    /// Possibly pre-processes elements before the inverse FFT starts. Here you can
    /// bring ring elements into a certain form, and assume during
    /// [`CooleyTuckeyButterfly::inv_butterfly_new()`] that the inputs are in this form.
    ///
    /// The default implementation does nothing.
    #[inline(always)]
    fn prepare_for_inv_fft(&self, _value: &mut Self::Element) {}
}
441
442#[allow(deprecated)]
443impl<R, S> CooleyTuckeyButterfly<S> for R
444where
445    S: ?Sized + RingBase,
446    R: ?Sized + RingBase,
447{
448    #[inline(always)]
449    default fn butterfly<V: VectorViewMut<Self::Element>, H: Homomorphism<S, Self>>(
450        &self,
451        hom: H,
452        values: &mut V,
453        twiddle: &<S as RingBase>::Element,
454        i1: usize,
455        i2: usize,
456    ) {
457        hom.mul_assign_ref_map(values.at_mut(i2), twiddle);
458        let new_a = self.add_ref(values.at(i1), values.at(i2));
459        let a = std::mem::replace(values.at_mut(i1), new_a);
460        self.sub_self_assign(values.at_mut(i2), a);
461    }
462
463    #[inline(always)]
464    #[allow(deprecated)]
465    default fn butterfly_new<H: Homomorphism<S, Self>>(
466        hom: H,
467        x: &mut Self::Element,
468        y: &mut Self::Element,
469        twiddle: &S::Element,
470    ) {
471        let mut values = [hom.codomain().clone_el(x), hom.codomain().clone_el(y)];
472        <Self as CooleyTuckeyButterfly<S>>::butterfly(hom.codomain().get_ring(), &hom, &mut values, twiddle, 0, 1);
473        [*x, *y] = values;
474    }
475
476    #[inline(always)]
477    default fn inv_butterfly<V: VectorViewMut<Self::Element>, H: Homomorphism<S, Self>>(
478        &self,
479        hom: H,
480        values: &mut V,
481        twiddle: &<S as RingBase>::Element,
482        i1: usize,
483        i2: usize,
484    ) {
485        let new_a = self.add_ref(values.at(i1), values.at(i2));
486        let a = std::mem::replace(values.at_mut(i1), new_a);
487        self.sub_self_assign(values.at_mut(i2), a);
488        hom.mul_assign_ref_map(values.at_mut(i2), twiddle);
489    }
490
491    #[inline(always)]
492    #[allow(deprecated)]
493    default fn inv_butterfly_new<H: Homomorphism<S, Self>>(
494        hom: H,
495        x: &mut Self::Element,
496        y: &mut Self::Element,
497        twiddle: &S::Element,
498    ) {
499        let mut values = [hom.codomain().clone_el(x), hom.codomain().clone_el(y)];
500        <Self as CooleyTuckeyButterfly<S>>::inv_butterfly(hom.codomain().get_ring(), &hom, &mut values, twiddle, 0, 1);
501        [*x, *y] = values;
502    }
503
504    #[inline(always)]
505    default fn prepare_for_fft(&self, _value: &mut Self::Element) {}
506
507    #[inline(always)]
508    default fn prepare_for_inv_fft(&self, _value: &mut Self::Element) {}
509}
510
511impl<R_main, R_twiddle, H, A> CooleyTuckeyFFT<R_main, R_twiddle, H, A>
512where
513    R_main: ?Sized + RingBase,
514    R_twiddle: ?Sized + RingBase + DivisibilityRing,
515    H: Homomorphism<R_twiddle, R_main>,
516    A: Allocator,
517{
518    /// Most general way to create a [`CooleyTuckeyFFT`].
519    ///
520    /// This is currently the same as [`CooleyTuckeyFFT::new_with_pows_with_hom()`], except
521    /// that it additionally accepts an allocator, which is used to copy the input data in
522    /// cases where the input data layout is not optimal for the algorithm.
523    #[stability::unstable(feature = "enable")]
524    pub fn create<F>(hom: H, mut root_of_unity_pow: F, log2_n: usize, allocator: A) -> Self
525    where
526        F: FnMut(i64) -> R_twiddle::Element,
527    {
528        let ring = hom.domain();
529        assert!(ring.is_commutative());
530        assert!(ring.get_ring().is_approximate() || is_prim_root_of_unity_pow2(&ring, &root_of_unity_pow(1), log2_n));
531        assert!(
532            hom.codomain().get_ring().is_approximate()
533                || is_prim_root_of_unity_pow2(&hom.codomain(), &hom.map(root_of_unity_pow(1)), log2_n)
534        );
535
536        let root_of_unity_list = Self::create_root_of_unity_list(|i| root_of_unity_pow(-i), log2_n);
537        let inv_root_of_unity_list = Self::create_root_of_unity_list(&mut root_of_unity_pow, log2_n);
538        let root_of_unity = root_of_unity_pow(1);
539
540        let store_twiddle_ring = root_of_unity_list.len();
541        CooleyTuckeyFFT {
542            root_of_unity_list: root_of_unity_list.into_iter().take(store_twiddle_ring).collect(),
543            inv_root_of_unity_list: inv_root_of_unity_list.into_iter().take(store_twiddle_ring).collect(),
544            two_inv: hom.domain().invert(&hom.domain().int_hom().map(2)).unwrap(),
545            n_inv: hom.domain().invert(&hom.domain().int_hom().map(1 << log2_n)).unwrap(),
546            root_of_unity: hom.map(root_of_unity),
547            hom,
548            log2_n,
549            allocator,
550        }
551    }
552
553    /// Replaces the ring that this object can compute FFTs over, assuming that the current
554    /// twiddle factors can be mapped into the new ring with the given homomorphism.
555    ///
556    /// In particular, this function does not recompute twiddles, but uses a different
557    /// homomorphism to map the current twiddles into a new ring. Hence, it is extremely
558    /// cheap.
559    #[stability::unstable(feature = "enable")]
560    pub fn change_ring<R_new: ?Sized + RingBase, H_new: Homomorphism<R_twiddle, R_new>>(
561        self,
562        new_hom: H_new,
563    ) -> (CooleyTuckeyFFT<R_new, R_twiddle, H_new, A>, H) {
564        let ring = new_hom.codomain();
565        let root_of_unity = if self.log2_n == 0 {
566            new_hom.codomain().one()
567        } else {
568            new_hom.map_ref(&self.inv_root_of_unity_list[self.log2_n - 1][bitreverse(1, self.log2_n - 1)])
569        };
570        assert!(ring.is_commutative());
571        assert!(ring.get_ring().is_approximate() || is_prim_root_of_unity_pow2(&ring, &root_of_unity, self.log2_n));
572
573        return (
574            CooleyTuckeyFFT {
575                root_of_unity_list: self.root_of_unity_list,
576                inv_root_of_unity_list: self.inv_root_of_unity_list,
577                two_inv: self.two_inv,
578                n_inv: self.n_inv,
579                root_of_unity,
580                hom: new_hom,
581                log2_n: self.log2_n,
582                allocator: self.allocator,
583            },
584            self.hom,
585        );
586    }
587
588    fn create_root_of_unity_list<F>(mut root_of_unity_pow: F, log2_n: usize) -> Vec<Vec<R_twiddle::Element>>
589    where
590        F: FnMut(i64) -> R_twiddle::Element,
591    {
592        let mut twiddles: Vec<Vec<R_twiddle::Element>> = (0..log2_n).map(|_| Vec::new()).collect();
593        for log2_step in 0..log2_n {
594            let butterfly_count = 1 << log2_step;
595            for i in 0..butterfly_count {
596                twiddles[log2_step].push(root_of_unity_pow(bitreverse(i, log2_n - 1) as i64));
597            }
598        }
599        return twiddles;
600    }
601
602    /// Returns the ring over which this object can compute FFTs.
603    pub fn ring(&self) -> &<H as Homomorphism<R_twiddle, R_main>>::CodomainStore { self.hom.codomain() }
604
605    /// Computes the main butterfly step, either forward or backward (without division by two).
606    ///
607    /// The forward butterfly is
608    /// ```text
609    ///   (a, b) -> (a + twiddle * b, a - twiddle * b)
610    /// ```
611    /// The backward butterfly is
612    /// ```text
613    ///   (u, v) -> (u + v, twiddle * (u - v))
614    /// ```
615    ///
616    /// The `#[inline(never)]` here is absolutely important for performance!
617    /// No idea why...
618    #[inline(never)]
619    fn butterfly_step_main<const INV: bool, const IS_PREPARED: bool>(
620        &self,
621        data: &mut [R_main::Element],
622        butterfly_range: Range<usize>,
623        stride_range: Range<usize>,
624        log2_step: usize,
625    ) {
626        let twiddles = if INV {
627            &self.inv_root_of_unity_list[log2_step]
628        } else {
629            &self.root_of_unity_list[log2_step]
630        };
631        // let start = std::time::Instant::now();
632        let butterfly = |a: &mut _, b: &mut _, twiddle: &_| {
633            if INV {
634                if !IS_PREPARED {
635                    <R_main as CooleyTuckeyButterfly<R_twiddle>>::prepare_for_inv_fft(self.ring().get_ring(), a);
636                    <R_main as CooleyTuckeyButterfly<R_twiddle>>::prepare_for_inv_fft(self.ring().get_ring(), b);
637                }
638                <R_main as CooleyTuckeyButterfly<R_twiddle>>::inv_butterfly_new(&self.hom, a, b, twiddle);
639            } else {
640                if !IS_PREPARED {
641                    <R_main as CooleyTuckeyButterfly<R_twiddle>>::prepare_for_fft(self.ring().get_ring(), a);
642                    <R_main as CooleyTuckeyButterfly<R_twiddle>>::prepare_for_fft(self.ring().get_ring(), b);
643                }
644                <R_main as CooleyTuckeyButterfly<R_twiddle>>::butterfly_new(&self.hom, a, b, twiddle);
645            }
646        };
647        butterfly_loop(
648            self.log2_n,
649            data,
650            butterfly_range,
651            stride_range,
652            log2_step,
653            twiddles,
654            butterfly,
655        );
656        // let end = std::time::Instant::now();
657        // BUTTERFLY_TIMES[log2_step].fetch_add((end - start).as_micros() as usize,
658        // std::sync::atomic::Ordering::Relaxed);
659    }
660
661    /// The definitions are
662    /// ```text
663    ///   u = a/2 + twiddle * b/2,
664    ///   v = a/2 - twiddle * b/2
665    /// ```
666    #[inline(never)]
667    fn butterfly_ub_from_ab(
668        &self,
669        data: &mut [R_main::Element],
670        butterfly_range: Range<usize>,
671        stride_range: Range<usize>,
672        log2_step: usize,
673    ) {
674        butterfly_loop(
675            self.log2_n,
676            data,
677            butterfly_range,
678            stride_range,
679            log2_step,
680            &self.root_of_unity_list[log2_step],
681            |a, b, twiddle| {
682                *a = self.hom.mul_ref_snd_map(
683                    self.ring().add_ref_fst(a, self.hom.mul_ref_map(b, twiddle)),
684                    &self.two_inv,
685                );
686            },
687        );
688    }
689
690    /// The definitions are
691    /// ```text
692    ///   u = a/2 + twiddle * b/2,
693    ///   v = a/2 - twiddle * b/2
694    /// ```
695    #[inline(never)]
696    fn butterfly_uv_from_ub(
697        &self,
698        data: &mut [R_main::Element],
699        butterfly_range: Range<usize>,
700        stride_range: Range<usize>,
701        log2_step: usize,
702    ) {
703        butterfly_loop(
704            self.log2_n,
705            data,
706            butterfly_range,
707            stride_range,
708            log2_step,
709            &self.root_of_unity_list[log2_step],
710            |a, b, twiddle| {
711                *b = self.ring().sub_ref_fst(a, self.hom.mul_ref_map(b, twiddle));
712            },
713        );
714    }
715
716    /// The definitions are
717    /// ```text
718    ///   u = a/2 + twiddle * b/2,
719    ///   v = a/2 - twiddle * b/2
720    /// ```
721    #[inline(never)]
722    fn butterfly_ab_from_ub(
723        &self,
724        data: &mut [R_main::Element],
725        butterfly_range: Range<usize>,
726        stride_range: Range<usize>,
727        log2_step: usize,
728    ) {
729        butterfly_loop(
730            self.log2_n,
731            data,
732            butterfly_range,
733            stride_range,
734            log2_step,
735            &self.root_of_unity_list[log2_step],
736            |a, b, twiddle| {
737                *a = self.ring().add_ref(a, a);
738                self.ring().sub_assign(a, self.hom.mul_ref_map(b, twiddle));
739            },
740        );
741    }
742
743    /// Returns a reference to the allocator currently used for temporary allocations by this FFT.
744    #[stability::unstable(feature = "enable")]
745    pub fn allocator(&self) -> &A { &self.allocator }
746
747    /// Replaces the allocator used for temporary allocations by this FFT.
748    #[stability::unstable(feature = "enable")]
749    pub fn with_allocator<A_new: Allocator>(self, allocator: A_new) -> CooleyTuckeyFFT<R_main, R_twiddle, H, A_new> {
750        CooleyTuckeyFFT {
751            root_of_unity_list: self.root_of_unity_list,
752            inv_root_of_unity_list: self.inv_root_of_unity_list,
753            two_inv: self.two_inv,
754            n_inv: self.n_inv,
755            root_of_unity: self.root_of_unity,
756            hom: self.hom,
757            log2_n: self.log2_n,
758            allocator,
759        }
760    }
761
762    /// Returns a reference to the homomorphism that is used to map the stored twiddle
763    /// factors into main ring, over which FFTs are computed.
764    #[stability::unstable(feature = "enable")]
765    pub fn hom(&self) -> &H { &self.hom }
766
767    /// Computes the unordered, truncated FFT.
768    ///
769    /// The truncated FFT is the standard DFT, applied to a list for which only the first
770    /// `nonzero_entries` entries are nonzero, and the (bitreversed) result truncated to
771    /// length `nonzero_entries`.
772    ///
773    /// Therefore, this function is equivalent to the following pseudocode
774    /// ```text
775    /// data[nonzero_entries..] = 0;
776    /// unordered_fft(data);
777    /// data[nonzero_entries] = unspecified;
778    /// ```
779    ///
780    /// It can be inverted using [`CooleyTuckey::unordered_truncated_fft_inv()`].
781    #[stability::unstable(feature = "enable")]
782    pub fn unordered_truncated_fft(&self, data: &mut [R_main::Element], nonzero_entries: usize) {
783        assert_eq!(self.len(), data.len());
784        assert!(nonzero_entries > self.len() / 2);
785        assert!(nonzero_entries <= self.len());
786        for i in nonzero_entries..self.len() {
787            debug_assert!(self.ring().get_ring().is_approximate() || self.ring().is_zero(&data[i]));
788        }
789
790        for i in 0..data.len() {
791            <R_main as CooleyTuckeyButterfly<R_twiddle>>::prepare_for_fft(self.ring().get_ring(), &mut data[i]);
792        }
793        for log2_step in 0..self.log2_n {
794            let stride = 1 << (self.log2_n - log2_step - 1);
795            let butterfly_count = nonzero_entries.div_ceil(2 * stride);
796            self.butterfly_step_main::<false, true>(data, 0..butterfly_count, 0..stride, log2_step);
797        }
798    }
799
800    /// Computes the inverse of the unordered, truncated FFT.
801    ///
802    /// The truncated FFT is the standard DFT, applied to a list for which only the first
803    /// `nonzero_entries` entries are nonzero, and the (bitreversed) result truncated to
804    /// length `nonzero_entries`. Therefore, this function computes a list of `nonzero_entries`
805    /// many values, followed by zeros, whose DFT agrees with the input on the first
806    /// `nonzero_entries` many elements.
807    #[stability::unstable(feature = "enable")]
808    pub fn unordered_truncated_fft_inv(&self, data: &mut [R_main::Element], nonzero_entries: usize) {
809        assert_eq!(self.len(), data.len());
810        assert!(nonzero_entries > self.len() / 2);
811        assert!(nonzero_entries <= self.len());
812
813        for i in 0..data.len() {
814            <R_main as CooleyTuckeyButterfly<R_twiddle>>::prepare_for_inv_fft(self.ring().get_ring(), &mut data[i]);
815        }
816        for log2_step in (0..self.log2_n).rev() {
817            let stride = 1 << (self.log2_n - log2_step - 1);
818            let current_block = nonzero_entries / (2 * stride);
819            self.butterfly_step_main::<true, true>(data, 0..current_block, 0..stride, log2_step);
820        }
821        if nonzero_entries < (1 << self.log2_n) {
822            for i in nonzero_entries..(1 << self.log2_n) {
823                data[i] = self.ring().zero();
824            }
825            for log2_step in 0..self.log2_n {
826                let stride = 1 << (self.log2_n - log2_step - 1);
827                let current_block = nonzero_entries / (2 * stride);
828                let known_area = nonzero_entries % (2 * stride);
829                if known_area >= stride {
830                    self.butterfly_uv_from_ub(
831                        data,
832                        current_block..(current_block + 1),
833                        (known_area - stride)..stride,
834                        log2_step,
835                    );
836                } else {
837                    self.butterfly_ub_from_ab(data, current_block..(current_block + 1), known_area..stride, log2_step);
838                }
839            }
840            for log2_step in (0..self.log2_n).rev() {
841                let stride = 1 << (self.log2_n - log2_step - 1);
842                let current_block = nonzero_entries / (2 * stride);
843                let known_area = nonzero_entries % (2 * stride);
844                if known_area >= stride {
845                    self.butterfly_step_main::<true, false>(
846                        data,
847                        current_block..(current_block + 1),
848                        0..stride,
849                        log2_step,
850                    );
851                } else {
852                    self.butterfly_ab_from_ub(data, current_block..(current_block + 1), 0..stride, log2_step);
853                }
854            }
855        }
856        for i in 0..(1 << self.log2_n) {
857            self.hom.mul_assign_ref_map(&mut data[i], &self.n_inv);
858        }
859    }
860
861    /// Permutes the given list of length `n` according to `values[bitreverse(i, log2(n))] =
862    /// values[i]`. This is exactly the permutation that is implicitly applied by
863    /// [`CooleyTuckeyFFT::unordered_fft()`].
864    pub fn bitreverse_permute_inplace<V, T>(&self, mut values: V)
865    where
866        V: SwappableVectorViewMut<T>,
867    {
868        assert!(values.len() == 1 << self.log2_n);
869        for i in 0..(1 << self.log2_n) {
870            if bitreverse(i, self.log2_n) < i {
871                values.swap(i, bitreverse(i, self.log2_n));
872            }
873        }
874    }
875}
876
877impl<R_main, R_twiddle, H, A> FFTAlgorithm<R_main> for CooleyTuckeyFFT<R_main, R_twiddle, H, A>
878where
879    R_main: ?Sized + RingBase,
880    R_twiddle: ?Sized + RingBase + DivisibilityRing,
881    H: Homomorphism<R_twiddle, R_main>,
882    A: Allocator,
883{
884    fn len(&self) -> usize { 1 << self.log2_n }
885
886    fn root_of_unity<S: Copy + RingStore<Type = R_main>>(&self, ring: S) -> &R_main::Element {
887        assert!(ring.get_ring() == self.ring().get_ring(), "unsupported ring");
888        &self.root_of_unity
889    }
890
891    fn unordered_fft_permutation(&self, i: usize) -> usize { bitreverse(i, self.log2_n) }
892
893    fn unordered_fft_permutation_inv(&self, i: usize) -> usize { bitreverse(i, self.log2_n) }
894
895    fn fft<V, S>(&self, mut values: V, ring: S)
896    where
897        V: SwappableVectorViewMut<<R_main as RingBase>::Element>,
898        S: RingStore<Type = R_main> + Copy,
899    {
900        assert!(ring.get_ring() == self.ring().get_ring(), "unsupported ring");
901        assert_eq!(self.len(), values.len());
902        self.unordered_fft(&mut values, ring);
903        self.bitreverse_permute_inplace(&mut values);
904    }
905
906    fn inv_fft<V, S>(&self, mut values: V, ring: S)
907    where
908        V: SwappableVectorViewMut<<R_main as RingBase>::Element>,
909        S: RingStore<Type = R_main> + Copy,
910    {
911        assert!(ring.get_ring() == self.ring().get_ring(), "unsupported ring");
912        assert_eq!(self.len(), values.len());
913        self.bitreverse_permute_inplace(&mut values);
914        self.unordered_inv_fft(&mut values, ring);
915    }
916
917    fn unordered_fft<V, S>(&self, mut values: V, ring: S)
918    where
919        V: SwappableVectorViewMut<<R_main as RingBase>::Element>,
920        S: RingStore<Type = R_main> + Copy,
921    {
922        assert!(ring.get_ring() == self.ring().get_ring(), "unsupported ring");
923        assert_eq!(self.len(), values.len());
924        if let Some(data) = values.as_slice_mut() {
925            self.unordered_truncated_fft(data, 1 << self.log2_n);
926        } else {
927            let mut data = Vec::with_capacity_in(1 << self.log2_n, &self.allocator);
928            data.extend(values.clone_ring_els(ring).iter());
929            self.unordered_truncated_fft(&mut data, 1 << self.log2_n);
930            for (i, x) in data.into_iter().enumerate() {
931                *values.at_mut(i) = x;
932            }
933        }
934    }
935
936    fn unordered_inv_fft<V, S>(&self, mut values: V, ring: S)
937    where
938        V: SwappableVectorViewMut<<R_main as RingBase>::Element>,
939        S: RingStore<Type = R_main> + Copy,
940    {
941        assert!(ring.get_ring() == self.ring().get_ring(), "unsupported ring");
942        assert_eq!(self.len(), values.len());
943        if let Some(data) = values.as_slice_mut() {
944            self.unordered_truncated_fft_inv(data, 1 << self.log2_n);
945        } else {
946            let mut data = Vec::with_capacity_in(1 << self.log2_n, &self.allocator);
947            data.extend(values.clone_ring_els(ring).iter());
948            self.unordered_truncated_fft_inv(&mut data, 1 << self.log2_n);
949            for (i, x) in data.into_iter().enumerate() {
950                *values.at_mut(i) = x;
951            }
952        }
953    }
954}
955
956impl<H, A> FFTErrorEstimate for CooleyTuckeyFFT<Complex64Base, Complex64Base, H, A>
957where
958    H: Homomorphism<Complex64Base, Complex64Base>,
959    A: Allocator,
960{
961    fn expected_absolute_error(&self, input_bound: f64, input_error: f64) -> f64 {
962        // the butterfly performs a multiplication with a root of unity, and an addition
963        let multiply_absolute_error = input_bound * root_of_unity_error() + input_bound * f64::EPSILON;
964        let addition_absolute_error = input_bound * f64::EPSILON;
965        let butterfly_absolute_error = multiply_absolute_error + addition_absolute_error;
966        // the operator inf-norm of the FFT is its length
967        return 2.0 * self.len() as f64 * butterfly_absolute_error + self.len() as f64 * input_error;
968    }
969}
970
971#[cfg(test)]
972use crate::field::*;
973#[cfg(test)]
974use crate::primitive_int::*;
975#[cfg(test)]
976use crate::rings::finite::FiniteRingStore;
977#[cfg(test)]
978use crate::rings::zn::zn_big;
979#[cfg(test)]
980use crate::rings::zn::zn_static;
981#[cfg(test)]
982use crate::rings::zn::zn_static::Fp;
983
/// Length-4 FFT over `F_5`, compared against hand-computed values.
#[test]
fn test_bitreverse_fft_inplace_basic() {
    let ring = Fp::<5>::RING;
    let z = ring.int_hom().map(2);
    // the table is built from `z^-1`; the forward transform uses the inverse of its root
    let fft = CooleyTuckeyFFT::new(ring, ring.div(&1, &z), 2);
    let mut values = [1, 0, 0, 1];
    // expected result in natural order
    let expected = [2, 4, 0, 3];
    let mut bitreverse_expected = [0; 4];
    for (i, slot) in bitreverse_expected.iter_mut().enumerate() {
        *slot = expected[bitreverse(i, 2)];
    }

    fft.unordered_fft(&mut values, ring);
    assert_eq!(values, bitreverse_expected);
}
999
/// Length-16 FFT over `F_17`, compared against precomputed values.
#[test]
fn test_bitreverse_fft_inplace_advanced() {
    let ring = Fp::<17>::RING;
    let z = ring.int_hom().map(3);
    let fft = CooleyTuckeyFFT::new(ring, z, 4);
    let mut values = [1, 0, 0, 0, 1, 0, 0, 0, 4, 3, 2, 1, 4, 3, 2, 1];
    // expected result in natural order
    let expected = [5, 2, 0, 11, 5, 4, 0, 6, 6, 13, 0, 1, 7, 6, 0, 1];
    let mut bitreverse_expected = [0; 16];
    for (i, slot) in bitreverse_expected.iter_mut().enumerate() {
        *slot = expected[bitreverse(i, 4)];
    }

    fft.unordered_fft(&mut values, ring);
    assert_eq!(values, bitreverse_expected);
}
1015
/// Checks that `unordered_fft_permutation()` describes the output order of
/// `unordered_fft()`, using the polynomial `X + X^4` as input.
#[test]
fn test_unordered_fft_permutation() {
    let ring = Fp::<17>::RING;
    let fft = CooleyTuckeyFFT::for_zn(&ring, 4).unwrap();
    // coefficients of `X + X^4`
    let mut values = [0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];
    let mut expected = [0; 16];
    for (i, slot) in expected.iter_mut().enumerate() {
        // entry `i` is the evaluation at `zeta^(16 - k)`, with `k` the permuted index
        let power_of_zeta = ring.pow(*fft.root_of_unity(&ring), 16 - fft.unordered_fft_permutation(i));
        *slot = ring.add(power_of_zeta, ring.pow(power_of_zeta, 4));
    }
    fft.unordered_fft(&mut values, ring);
    assert_eq!(expected, values);
}
1029
/// Round trip: the unordered inverse FFT must undo the unordered FFT.
#[test]
fn test_bitreverse_inv_fft_inplace() {
    let ring = Fp::<17>::RING;
    let fft = CooleyTuckeyFFT::for_zn(&ring, 4).unwrap();
    let values: [u64; 16] = [1, 2, 3, 2, 1, 0, 17 - 1, 17 - 2, 17 - 1, 0, 1, 2, 3, 4, 5, 6];

    let mut work = values;
    fft.unordered_fft(&mut work, ring);
    fft.unordered_inv_fft(&mut work, ring);

    assert_eq!(&work, &values);
}
1040
/// Checks for every admissible truncation length `k` that the truncated FFT agrees with
/// the full FFT on the first `k` entries, and that the truncated inverse FFT recovers
/// the input.
#[test]
fn test_truncated_fft() {
    let ring = Fp::<17>::RING;
    let fft = CooleyTuckeyFFT::new(ring, 2, 3);

    // only the first five entries are nonzero, so every `k` in `5..=8` is admissible
    let data = [2, 3, 0, 1, 1, 0, 0, 0];
    let mut complete_fft = data;
    fft.unordered_fft(&mut complete_fft, ring);
    for k in 5..=8 {
        let mut truncated_fft = data;
        fft.unordered_truncated_fft(&mut truncated_fft, k);
        assert_eq!(
            &complete_fft[..k],
            &truncated_fft[..k],
            "truncated FFT differs from the full FFT for k = {}",
            k
        );

        fft.unordered_truncated_fft_inv(&mut truncated_fft, k);
        assert_eq!(data, truncated_fft, "truncated inverse FFT is not a round trip for k = {}", k);
    }
}
1059
/// The root of unity chosen by `for_zn()` for length 16 must have 8-th power `-1`.
#[test]
fn test_for_zn() {
    {
        let ring = Fp::<17>::RING;
        let fft = CooleyTuckeyFFT::for_zn(ring, 4).unwrap();
        let root_pow_8 = ring.pow(fft.root_of_unity, 8);
        assert!(ring.is_neg_one(&root_pow_8));
    }
    {
        let ring = Fp::<97>::RING;
        let fft = CooleyTuckeyFFT::for_zn(ring, 4).unwrap();
        let root_pow_8 = ring.pow(fft.root_of_unity, 8);
        assert!(ring.is_neg_one(&root_pow_8));
    }
}
1070
/// One benchmark iteration: copies `data` into the reusable buffer `copy`, runs the
/// unordered FFT followed by the unordered inverse FFT on it, and checks that the first
/// entry came back unchanged.
#[cfg(test)]
fn run_fft_bench_round<R, S, H>(fft: &CooleyTuckeyFFT<S, R, H>, data: &[S::Element], copy: &mut Vec<S::Element>)
where
    R: ZnRing,
    S: ZnRing,
    H: Homomorphism<R, S>,
{
    // reuse the allocation of `copy` across benchmark rounds
    copy.clear();
    copy.extend(data.iter().map(|x| fft.ring().clone_el(x)));
    fft.unordered_fft(&mut copy[..], &fft.ring());
    fft.unordered_inv_fft(&mut copy[..], &fft.ring());
    assert_el_eq!(fft.ring(), copy[0], data[0]);
}
1084
/// Base-2 logarithm of the FFT length used by the benchmarks in this file, i.e. they
/// transform `2^13 = 8192` elements.
#[cfg(test)]
const BENCH_SIZE_LOG2: usize = 13;
1087
1088#[bench]
1089fn bench_fft_zn_big(bencher: &mut test::Bencher) {
1090    let ring = zn_big::Zn::new(StaticRing::<i128>::RING, 1073872897);
1091    let fft = CooleyTuckeyFFT::for_zn(&ring, BENCH_SIZE_LOG2).unwrap();
1092    let data = (0..(1 << BENCH_SIZE_LOG2))
1093        .map(|i| ring.int_hom().map(i))
1094        .collect::<Vec<_>>();
1095    let mut copy = Vec::with_capacity(1 << BENCH_SIZE_LOG2);
1096    bencher.iter(|| run_fft_bench_round(&fft, &data, &mut copy));
1097}
1098
1099#[bench]
1100fn bench_fft_zn_64(bencher: &mut test::Bencher) {
1101    let ring = zn_64::Zn::new(1073872897);
1102    let fft = CooleyTuckeyFFT::for_zn(&ring, BENCH_SIZE_LOG2).unwrap();
1103    let data = (0..(1 << BENCH_SIZE_LOG2))
1104        .map(|i| ring.int_hom().map(i))
1105        .collect::<Vec<_>>();
1106    let mut copy = Vec::with_capacity(1 << BENCH_SIZE_LOG2);
1107    bencher.iter(|| run_fft_bench_round(&fft, &data, &mut copy));
1108}
1109
1110#[bench]
1111fn bench_fft_zn_64_fastmul(bencher: &mut test::Bencher) {
1112    let ring = zn_64::Zn::new(1073872897);
1113    let fastmul_ring = zn_64::ZnFastmul::new(ring).unwrap();
1114    let fft = CooleyTuckeyFFT::for_zn_with_hom(ring.into_can_hom(fastmul_ring).ok().unwrap(), BENCH_SIZE_LOG2).unwrap();
1115    let data = (0..(1 << BENCH_SIZE_LOG2))
1116        .map(|i| ring.int_hom().map(i))
1117        .collect::<Vec<_>>();
1118    let mut copy = Vec::with_capacity(1 << BENCH_SIZE_LOG2);
1119    bencher.iter(|| run_fft_bench_round(&fft, &data, &mut copy));
1120}
1121
/// Floating point FFT of the sequence of all `n`-th roots of unity: the result must be
/// `n` at index 1 and zero everywhere else, up to the error bound reported by
/// `expected_absolute_error()`.
#[test]
fn test_approximate_fft() {
    let CC = Complex64::RING;
    for log2_n in [4, 7, 11, 15] {
        let fft = CooleyTuckeyFFT::new_with_pows(CC, |x| CC.root_of_unity(x, 1 << log2_n), log2_n);
        let mut array: Vec<_> = (0..(1 << log2_n))
            .map(|i| CC.root_of_unity(i.try_into().unwrap(), 1 << log2_n))
            .collect();
        fft.fft(&mut array, CC);

        // inputs have absolute value 1 and are treated as exact
        let err = fft.expected_absolute_error(1.0, 0.0);
        assert!(CC.is_absolute_approx_eq(array[0], CC.zero(), err));
        assert!(CC.is_absolute_approx_eq(array[1], CC.from_f64(fft.len() as f64), err));
        for &entry in &array[2..] {
            assert!(CC.is_absolute_approx_eq(entry, CC.zero(), err));
        }
    }
}
1139
/// Degenerate case `log2_n = 0`: FFT and inverse FFT of length 1 leave the single value
/// unchanged, and both index permutations map 0 to 0.
#[test]
fn test_size_1_fft() {
    let ring = Fp::<17>::RING;
    let fft = CooleyTuckeyFFT::for_zn(&ring, 0).unwrap().change_ring(ring.identity()).0;
    let values: [u64; 1] = [3];

    let mut work = values;
    fft.unordered_fft(&mut work, ring);
    assert_eq!(&work, &values);
    fft.unordered_inv_fft(&mut work, ring);
    assert_eq!(&work, &values);

    assert_eq!(0, fft.unordered_fft_permutation(0));
    assert_eq!(0, fft.unordered_fft_permutation_inv(0));
}
1156
/// Generic test for implementations of `CooleyTuckeyButterfly`.
///
/// For every pair `(a, b)` of the given elements of `ring`, it checks that
///  - `butterfly_new()` with twiddle `t` maps `(a, b)` to `(a + t b, a - t b)`,
///  - `inv_butterfly_new()` with twiddle `t` maps `(a, b)` to `(a + b, (a - b) t)`,
///  - applying the respective other function with twiddle `t^-1` afterwards yields
///    `(2a, 2b)`.
///
/// The twiddle factor `test_twiddle` lives in `base` and must be invertible there;
/// `ring` must have a canonical homomorphism from `base`.
#[cfg(any(test, feature = "generic_tests"))]
pub fn generic_test_cooley_tuckey_butterfly<R: RingStore, S: RingStore, I: Iterator<Item = El<R>>>(
    ring: R,
    base: S,
    edge_case_elements: I,
    test_twiddle: &El<S>,
) where
    R::Type: CanHomFrom<S::Type>,
    S::Type: DivisibilityRing,
{
    let test_inv_twiddle = base.invert(test_twiddle).unwrap();
    let elements: Vec<_> = edge_case_elements.collect();
    let hom = ring.can_hom(&base).unwrap();

    for a in &elements {
        for b in &elements {
            let twice_a = ring.int_hom().mul_ref_fst_map(a, 2);
            let twice_b = ring.int_hom().mul_ref_fst_map(b, 2);
            let twiddled_b = ring.mul_ref_fst(b, hom.map_ref(test_twiddle));

            // forward butterfly, undone (up to the factor 2) by the inverse butterfly
            let mut x = ring.clone_el(a);
            let mut y = ring.clone_el(b);
            <R::Type as CooleyTuckeyButterfly<S::Type>>::butterfly_new(&hom, &mut x, &mut y, test_twiddle);
            assert_el_eq!(ring, ring.add_ref(a, &twiddled_b), &x);
            assert_el_eq!(ring, ring.sub_ref(a, &twiddled_b), &y);

            <R::Type as CooleyTuckeyButterfly<S::Type>>::inv_butterfly_new(&hom, &mut x, &mut y, &test_inv_twiddle);
            assert_el_eq!(ring, &twice_a, &x);
            assert_el_eq!(ring, &twice_b, &y);

            // inverse butterfly, undone (up to the factor 2) by the forward butterfly
            let mut x = ring.clone_el(a);
            let mut y = ring.clone_el(b);
            <R::Type as CooleyTuckeyButterfly<S::Type>>::inv_butterfly_new(&hom, &mut x, &mut y, test_twiddle);
            assert_el_eq!(ring, ring.add_ref(a, b), &x);
            assert_el_eq!(ring, ring.mul(ring.sub_ref(a, b), hom.map_ref(test_twiddle)), &y);

            <R::Type as CooleyTuckeyButterfly<S::Type>>::butterfly_new(&hom, &mut x, &mut y, &test_inv_twiddle);
            assert_el_eq!(ring, &twice_a, &x);
            assert_el_eq!(ring, &twice_b, &y);
        }
    }
}
1201
/// Runs the generic butterfly test over `F_17`, on all field elements.
#[test]
fn test_butterfly() {
    let twiddle = get_prim_root_of_unity_pow2(zn_static::F17, 4).unwrap();
    generic_test_cooley_tuckey_butterfly(zn_static::F17, zn_static::F17, zn_static::F17.elements(), &twiddle);
}