// feanor_math/algorithms/fft/cooley_tuckey.rs

1use crate::algorithms::unity_root::*;
2use crate::divisibility::{DivisibilityRingStore, DivisibilityRing};
3use crate::rings::zn::*;
4use crate::seq::SwappableVectorViewMut;
5use crate::ring::*;
6use crate::seq::VectorViewMut;
7use crate::homomorphism::*;
8use crate::algorithms::fft::*;
9use crate::rings::float_complex::*;
10use super::complex_fft::*;
11
12///
13/// An optimized implementation of the Cooley-Tuckey FFT algorithm, to compute
14/// the Fourier transform of an array with power-of-two length.
15/// 
16/// # Example
17/// ```
18/// # use feanor_math::assert_el_eq;
19/// # use feanor_math::ring::*;
20/// # use feanor_math::algorithms::fft::*;
21/// # use feanor_math::rings::zn::*;
22/// # use feanor_math::algorithms::fft::cooley_tuckey::*;
23/// // this ring has a 256-th primitive root of unity
24/// let ring = zn_64::Zn::new(257);
25/// let fft_table = CooleyTuckeyFFT::for_zn(ring, 8).unwrap();
26/// let mut data = [ring.one()].into_iter().chain((0..255).map(|_| ring.zero())).collect::<Vec<_>>();
27/// fft_table.unordered_fft(&mut data, &ring);
28/// assert_el_eq!(ring, ring.one(), data[0]);
29/// assert_el_eq!(ring, ring.one(), data[1]);
30/// ```
31/// 
32#[derive(Debug)]
33pub struct CooleyTuckeyFFT<R_main, R_twiddle, H> 
34    where R_main: ?Sized + RingBase,
35        R_twiddle: ?Sized + RingBase,
36        H: Homomorphism<R_twiddle, R_main>
37{
38    hom: H,
39    root_of_unity: R_main::Element,
40    log2_n: usize,
41    // stores the powers of root_of_unity in special bitreversed order
42    root_of_unity_list: Vec<R_twiddle::Element>,
43    // stores the powers of inv_root_of_unity in special bitreversed order
44    inv_root_of_unity_list: Vec<R_twiddle::Element>
45}
46
///
/// Reverses the order of the least significant `bits` bits of `index`.
/// 
/// Bits of `index` at positions `>= bits` do not influence the result. For `bits == 0`
/// the result is `0`.
/// 
pub fn bitreverse(index: usize, bits: usize) -> usize {
    // Reversing the whole word moves the `bits` interesting bits to the top; shifting
    // right by the number of remaining positions brings them back down. For `bits == 0`
    // the shift equals the full word width, which `checked_shr` rejects, giving 0.
    let unused_high_bits = usize::BITS - bits as u32;
    index
        .reverse_bits()
        .checked_shr(unused_high_bits)
        .unwrap_or_default()
}
55
56impl<R_main, H> CooleyTuckeyFFT<R_main, Complex64Base, H> 
57    where R_main: ?Sized + RingBase,
58        H: Homomorphism<Complex64Base, R_main>
59{
60    ///
61    /// Creates an [`CooleyTuckeyFFT`] for the complex field, using the given homomorphism
62    /// to connect the ring implementation for twiddles with the main ring implementation.
63    /// 
64    /// This function is mainly provided for parity with other rings, since in the complex case
65    /// it currently does not make much sense to use a different homomorphism than the identity.
66    /// Hence, it is simpler to use [`CooleyTuckeyFFT::for_complex()`].
67    /// 
68    pub fn for_complex_with_hom(hom: H, log2_n: usize) -> Self {
69        let CC = *hom.domain().get_ring();
70        Self::new_with_pows_with_hom(hom, |i| CC.root_of_unity(i, 1 << log2_n), log2_n)
71    }
72}
73
74impl<R> CooleyTuckeyFFT<Complex64Base, Complex64Base, Identity<R>> 
75    where R: RingStore<Type = Complex64Base>
76{
77    ///
78    /// Creates an [`CooleyTuckeyFFT`] for the complex field.
79    /// 
80    pub fn for_complex(ring: R, log2_n: usize) -> Self {
81        Self::for_complex_with_hom(ring.into_identity(), log2_n)
82    }
83}
84
85impl<R> CooleyTuckeyFFT<R::Type, R::Type, Identity<R>> 
86    where R: RingStore,
87        R::Type: DivisibilityRing
88{
89    ///
90    /// Creates an [`CooleyTuckeyFFT`] for the given ring, using the given root of unity
91    /// as base. Do not use this for approximate rings, as computing the powers of `root_of_unity`
92    /// will incur avoidable precision loss.
93    /// 
94    pub fn new(ring: R, root_of_unity: El<R>, log2_n: usize) -> Self {
95        Self::new_with_hom(ring.into_identity(), root_of_unity, log2_n)
96    }
97
98    ///
99    /// Creates an [`CooleyTuckeyFFT`] for the given ring, using the passed function to
100    /// provide the necessary roots of unity.
101    /// 
102    /// Concretely, `root_of_unity_pow(i)` should return `z^i`, where `z` is a `2^log2_n`-th
103    /// primitive root of unity.
104    /// 
105    pub fn new_with_pows<F>(ring: R, root_of_unity_pow: F, log2_n: usize) -> Self 
106        where F: FnMut(i64) -> El<R>
107    {
108        Self::new_with_pows_with_hom(ring.into_identity(), root_of_unity_pow, log2_n)
109    }
110
111    ///
112    /// Creates an [`CooleyTuckeyFFT`] for a prime field, assuming it has a characteristic
113    /// congruent to 1 modulo `2^log2_n`.
114    /// 
115    pub fn for_zn(ring: R, log2_n: usize) -> Option<Self>
116        where R::Type: ZnRing
117    {
118        Self::for_zn_with_hom(ring.into_identity(), log2_n)
119    }
120}
121
122impl<R_main, R_twiddle, H> CooleyTuckeyFFT<R_main, R_twiddle, H> 
123    where R_main: ?Sized + RingBase,
124        R_twiddle: ?Sized + RingBase + DivisibilityRing,
125        H: Homomorphism<R_twiddle, R_main>
126{
127    ///
128    /// Creates an [`CooleyTuckeyFFT`] for the given rings, using the given root of unity.
129    /// 
130    /// Instead of a ring, this function takes a homomorphism `R -> S`. Twiddle factors that are
131    /// precomputed will be stored as elements of `R`, while the main FFT computations will be 
132    /// performed in `S`. This allows both implicit ring conversions, and using patterns like 
133    /// [`zn_64::ZnFastmul`] to precompute some data for better performance.
134    /// 
135    /// Do not use this for approximate rings, as computing the powers of `root_of_unity`
136    /// will incur avoidable precision loss.
137    /// 
138    pub fn new_with_hom(hom: H, root_of_unity: R_twiddle::Element, log2_n: usize) -> Self {
139        let ring = hom.domain();
140        let mut root_of_unity_pow = |i: i64| if i >= 0 {
141            ring.pow(ring.clone_el(&root_of_unity), i as usize)
142        } else {
143            ring.invert(&ring.pow(ring.clone_el(&root_of_unity), (-i) as usize)).unwrap()
144        };
145
146        // cannot call new_with_mem_and_pows() because of borrowing conflict
147        assert!(ring.is_commutative());
148        assert!(!hom.domain().get_ring().is_approximate());
149        assert!(is_prim_root_of_unity_pow2(&ring, &root_of_unity_pow(1), log2_n));
150        assert!(is_prim_root_of_unity_pow2(&hom.codomain(), &hom.map(root_of_unity_pow(1)), log2_n));
151
152        let root_of_unity_list = Self::create_root_of_unity_list(ring.get_ring(), &mut root_of_unity_pow, log2_n);
153        let inv_root_of_unity_list = Self::create_root_of_unity_list(ring.get_ring(), |i| root_of_unity_pow(-i), log2_n);
154        let root_of_unity = root_of_unity_pow(1);
155
156        CooleyTuckeyFFT {
157            root_of_unity: hom.map(root_of_unity), 
158            hom, 
159            log2_n, 
160            root_of_unity_list, 
161            inv_root_of_unity_list
162        }
163    }
164
165    ///
166    /// Creates an [`CooleyTuckeyFFT`] for the given rings, using the given function to create
167    /// the necessary powers of roots of unity. This is the most generic way to create [`CooleyTuckeyFFT`].
168    /// 
169    /// Concretely, `root_of_unity_pow(i)` should return `z^i`, where `z` is a `2^log2_n`-th
170    /// primitive root of unity.
171    /// 
172    /// Instead of a ring, this function takes a homomorphism `R -> S`. Twiddle factors that are
173    /// precomputed will be stored as elements of `R`, while the main FFT computations will be 
174    /// performed in `S`. This allows both implicit ring conversions, and using patterns like 
175    /// [`zn_64::ZnFastmul`] to precompute some data for better performance.
176    /// 
177    pub fn new_with_pows_with_hom<F>(hom: H, mut root_of_unity_pow: F, log2_n: usize) -> Self 
178        where F: FnMut(i64) -> R_twiddle::Element
179    {
180        let ring = hom.domain();
181        assert!(ring.is_commutative());
182        assert!(log2_n > 0);
183        assert!(ring.get_ring().is_approximate() || is_prim_root_of_unity_pow2(&ring, &root_of_unity_pow(1), log2_n));
184        assert!(hom.codomain().get_ring().is_approximate() || is_prim_root_of_unity_pow2(&hom.codomain(), &hom.map(root_of_unity_pow(1)), log2_n));
185
186        let root_of_unity_list = Self::create_root_of_unity_list(ring.get_ring(), &mut root_of_unity_pow, log2_n);
187        let inv_root_of_unity_list = Self::create_root_of_unity_list(ring.get_ring(), |i| root_of_unity_pow(-i), log2_n);
188        let root_of_unity = root_of_unity_pow(1);
189        
190        CooleyTuckeyFFT {
191            root_of_unity: hom.map(root_of_unity), 
192            hom, 
193            log2_n, 
194            root_of_unity_list, 
195            inv_root_of_unity_list
196        }
197    }
198
199    fn create_root_of_unity_list<F>(ring: &R_twiddle, mut root_of_unity_pow: F, log2_n: usize) -> Vec<R_twiddle::Element>
200        where F: FnMut(i64) -> R_twiddle::Element
201    {
202        // in fact, we could choose this to have only length `(1 << log2_n) - 1`, but a power of two length is probably faster
203        let mut root_of_unity_list = (0..(1 << log2_n)).map(|_| ring.zero()).collect::<Vec<_>>();
204        let mut index = 0;
205        for s in 0..log2_n {
206            let m = 1 << s;
207            let log2_group_size = log2_n - s;
208            for i_bitreverse in (0..(1 << log2_group_size)).step_by(2) {
209                let current_twiddle = root_of_unity_pow(m * bitreverse(i_bitreverse, log2_group_size) as i64);
210                root_of_unity_list[index] = current_twiddle;
211                index += 1;
212            }
213        }
214        assert_eq!(index, (1 << log2_n) - 1);
215        return root_of_unity_list;
216    }
217
218    ///
219    /// Creates an [`CooleyTuckeyFFT`] for the given prime fields, assuming they have
220    /// a characteristic congruent to 1 modulo `2^log2_n`.
221    /// 
222    /// Instead of a ring, this function takes a homomorphism `R -> S`. Twiddle factors that are
223    /// precomputed will be stored as elements of `R`, while the main FFT computations will be 
224    /// performed in `S`. This allows both implicit ring conversions, and using patterns like 
225    /// [`zn_64::ZnFastmul`] to precompute some data for better performance.
226    /// 
227    pub fn for_zn_with_hom(hom: H, log2_n: usize) -> Option<Self>
228        where R_twiddle: ZnRing
229    {
230        let ring_as_field = hom.domain().as_field().ok().unwrap();
231        let root_of_unity = ring_as_field.get_ring().unwrap_element(get_prim_root_of_unity_pow2(&ring_as_field, log2_n)?);
232        drop(ring_as_field);
233        Some(Self::new_with_hom(hom, root_of_unity, log2_n))
234    }
235
236    pub fn bitreverse_permute_inplace<V, T>(&self, mut values: V) 
237        where V: SwappableVectorViewMut<T>
238    {
239        assert!(values.len() == 1 << self.log2_n);
240        for i in 0..(1 << self.log2_n) {
241            if bitreverse(i, self.log2_n) < i {
242                values.swap(i, bitreverse(i, self.log2_n));
243            }
244        }
245    }
246}
247
248impl<R_main, R_twiddle, H> PartialEq for CooleyTuckeyFFT<R_main, R_twiddle, H> 
249    where R_main: ?Sized + RingBase,
250        R_twiddle: ?Sized + RingBase + DivisibilityRing,
251        H: Homomorphism<R_twiddle, R_main>
252{
253    fn eq(&self, other: &Self) -> bool {
254        self.ring().get_ring() == other.ring().get_ring() &&
255            self.log2_n == other.log2_n &&
256            self.ring().eq_el(self.root_of_unity(self.ring()), other.root_of_unity(self.ring()))
257    }
258}
259
260impl<R_main, R_twiddle, H> Clone for CooleyTuckeyFFT<R_main, R_twiddle, H> 
261    where R_main: ?Sized + RingBase,
262        R_twiddle: ?Sized + RingBase + DivisibilityRing,
263        H: Homomorphism<R_twiddle, R_main> + Clone
264{
265    fn clone(&self) -> Self {
266        Self {
267            hom: self.hom.clone(),
268            inv_root_of_unity_list: self.inv_root_of_unity_list.iter().map(|x| self.hom.domain().clone_el(x)).collect(),
269            root_of_unity: self.hom.codomain().clone_el(&self.root_of_unity),
270            log2_n: self.log2_n,
271            root_of_unity_list: self.root_of_unity_list.iter().map(|x| self.hom.domain().clone_el(x)).collect()
272        }
273    }
274}
275
276///
277/// A helper trait that defines the Cooley-Tuckey butterfly operation.
278/// It is default-implemented for all rings, but for increase FFT performance, some rings
279/// might wish to provide a specialization.
280/// 
281pub trait CooleyTuckeyButterfly<S>: RingBase
282    where S: ?Sized + RingBase
283{
284    ///
285    /// Should compute `(values[i1], values[i2]) := (values[i1] + twiddle * values[i2], values[i1] - twiddle * values[i2])`
286    /// 
287    fn butterfly<V: VectorViewMut<Self::Element>, H: Homomorphism<S, Self>>(&self, hom: H, values: &mut V, twiddle: &S::Element, i1: usize, i2: usize);
288
289    ///
290    /// Should compute `(values[i1], values[i2]) := (values[i1] + values[i2], (values[i1] - values[i2]) * twiddle)`
291    /// 
292    fn inv_butterfly<V: VectorViewMut<Self::Element>, H: Homomorphism<S, Self>>(&self, hom: H, values: &mut V, twiddle: &S::Element, i1: usize, i2: usize);
293}
294
295impl<R, S> CooleyTuckeyButterfly<S> for R
296    where S: ?Sized + RingBase, R: ?Sized + RingBase
297{
298    #[inline(always)]
299    default fn butterfly<V: VectorViewMut<Self::Element>, H: Homomorphism<S, Self>>(&self, hom: H, values: &mut V, twiddle: &<S as RingBase>::Element, i1: usize, i2: usize) {
300        hom.mul_assign_ref_map(values.at_mut(i2), twiddle);
301        let new_a = self.add_ref(values.at(i1), values.at(i2));
302        let a = std::mem::replace(values.at_mut(i1), new_a);
303        self.sub_self_assign(values.at_mut(i2), a);
304    }
305
306    #[inline(always)]
307    default fn inv_butterfly<V: VectorViewMut<Self::Element>, H: Homomorphism<S, Self>>(&self, hom: H, values: &mut V, twiddle: &<S as RingBase>::Element, i1: usize, i2: usize) {
308        let new_a = self.add_ref(values.at(i1), values.at(i2));
309        let a = std::mem::replace(values.at_mut(i1), new_a);
310        self.sub_self_assign(values.at_mut(i2), a);
311        hom.mul_assign_ref_map(values.at_mut(i2), twiddle);
312    }
313}
314
315impl<R_main, R_twiddle, H> CooleyTuckeyFFT<R_main, R_twiddle, H> 
316    where R_main: ?Sized + RingBase,
317        R_twiddle: ?Sized + RingBase + DivisibilityRing,
318        H: Homomorphism<R_twiddle, R_main>
319{
320    fn ring<'a>(&'a self) -> &'a <H as Homomorphism<R_twiddle, R_main>>::CodomainStore {
321        self.hom.codomain()
322    }
323
324    ///
325    /// Almost computes the FFT, but skips the first butterfly level.
326    /// 
327    /// This is designed for use by Bluestein's FFT, since it knows that
328    /// half the input is zero, which means that the first butterfly is trivial.
329    /// 
330    /// It also does not include division by `n` in the inverse case.
331    /// 
332    /// Note however that I had some weird performance results when actually
333    /// using this during the bluestein transform. More concretely, local
334    /// microbenchmarks were faster, but there was a significant slowdown
335    /// when using it "in practice" in my HE library.
336    /// 
337    #[allow(unused)]
338    pub(super) fn unordered_fft_skip_first_butterfly<V, const INV: bool>(&self, values: &mut V)
339        where V: VectorViewMut<R_main::Element> 
340    {
341        assert!(values.len() == (1 << self.log2_n));
342
343        let hom = &self.hom;
344        let R = hom.codomain();
345
346        for step in 1..self.log2_n {
347
348            let (log2_m, log2_group_size_half) = if !INV {
349                (self.log2_n - step - 1, step)  
350            } else {
351                (step - 1, self.log2_n - step)
352            };
353            let group_size_half = 1 << log2_group_size_half;
354            let m = 1 << log2_m;
355            let two_m = 2 << log2_m;
356            const UNROLL_COUNT: usize = 4;
357
358
359            if group_size_half < UNROLL_COUNT {
360
361                for k in 0..m {
362                    let mut root_of_unity_index = (1 << self.log2_n) - 2 * group_size_half;
363                    let mut index1 = k;
364        
365                    for _ in 0..group_size_half {
366        
367                        if !INV {
368                            let current_twiddle = &self.inv_root_of_unity_list[root_of_unity_index];
369                            R.get_ring().butterfly(hom, values, current_twiddle, index1, index1 + m);
370                        } else {
371                            let current_twiddle = &self.root_of_unity_list[root_of_unity_index];
372                            R.get_ring().inv_butterfly(hom, values, current_twiddle, index1, index1 + m);
373                        }
374                        root_of_unity_index += 1;
375                        index1 += two_m;
376        
377                    }
378                }
379
380            } else {
381
382                for k in 0..m {
383                    let mut root_of_unity_index = (1 << self.log2_n) - 2 * group_size_half;
384                    let mut index1 = k;
385        
386                    for _ in (0..group_size_half).step_by(UNROLL_COUNT) {
387                        for _ in 0..UNROLL_COUNT {
388        
389                            if !INV {
390                                let current_twiddle = &self.inv_root_of_unity_list[root_of_unity_index];
391                                R.get_ring().butterfly(hom, values, current_twiddle, index1, index1 + m);
392                            } else {
393                                let current_twiddle = &self.root_of_unity_list[root_of_unity_index];
394                                R.get_ring().inv_butterfly(hom, values, current_twiddle, index1, index1 + m);
395                            }
396                            root_of_unity_index += 1;
397                            index1 += two_m;
398        
399                        }
400                    }
401                }
402            }
403        }
404    }
405
406    ///
407    /// Optimized implementation of the inplace Cooley-Tuckey FFT algorithm.
408    /// Note that setting `INV = true` will perform an inverse fourier transform,
409    /// except that the division by `n` is not included.
410    /// 
411    /// I added #[inline(never)] to make profiling this easy, it does not have
412    /// any noticable impact on performance.
413    /// 
414    #[inline(never)]
415    fn unordered_fft_dispatch<V, const INV: bool>(&self, values: &mut V)
416        where V: VectorViewMut<R_main::Element> 
417    {
418        assert!(values.len() == (1 << self.log2_n));
419
420        let hom = &self.hom;
421        let R = hom.codomain();
422
423        for step in 0..self.log2_n {
424
425            let (log2_m, log2_group_size_half) = if !INV {
426                (self.log2_n - step - 1, step)  
427            } else {
428                (step, self.log2_n - step - 1)
429            };
430            let group_size_half = 1 << log2_group_size_half;
431            let m = 1 << log2_m;
432            let two_m = 2 << log2_m;
433            const UNROLL_COUNT: usize = 4;
434
435            if group_size_half < UNROLL_COUNT {
436
437                for k in 0..(1 << log2_m) {
438
439                    let mut root_of_unity_index = (1 << self.log2_n) - 2 * group_size_half;
440
441                    // 
442                    // we want to compute a bitreverse_fft_inplace for `v_k, v_(k + m), v_(k + 2m), ..., v_(k + n - m)`;
443                    // call this sequence a1
444                    //
445                    // we already have a bitreverse fft of `v_k, v_(k + 2m), v_(k + 4m), ..., v_(k + n - 2m) `
446                    // and `v_(k + m), v_(k + 3m), v_(k + 5m), ..., v_(k + n - m)` in the corresponding entries;
447                    // call these sequences a1 and a2
448                    //
449                    // Note that a1_i is stored in `(k + 2m * bitrev(i, n/m))` and a2_i in `(k + m + 2m * bitrev(i, n/m))`;
450                    // We want to store a_i in `(k + m + m * bitrev(i, 2n/m))`
451                    //
452                    for i_bitreverse in 0..group_size_half {
453                        //
454                        // we want to compute `(a_i, a_(i + group_size/2)) = (a1_i + z^i a2_i, a1_i - z^i a2_i)`
455                        //
456                        // in bitreverse order, have
457                        // `i_bitreverse     = bitrev(i, group_size) = 2 bitrev(i, group_size/2)` and
458                        // `i_bitreverse + 1 = bitrev(i + group_size/2, group_size) = 2 bitrev(i, group_size/2) + 1`
459                        //
460                        let index1 = i_bitreverse * two_m + k;
461                        let index2 = index1 + m;
462    
463                        if !INV {
464                            let current_twiddle = &self.inv_root_of_unity_list[root_of_unity_index];
465                            R.get_ring().butterfly(hom, values, current_twiddle, index1, index2);
466                        } else {
467                            let current_twiddle = &self.root_of_unity_list[root_of_unity_index];
468                            R.get_ring().inv_butterfly(hom, values, current_twiddle, index1, index2);
469                        }
470                        root_of_unity_index += 1;
471                    }
472                }
473
474            } else {
475            
476                // same but loop is unrolled
477
478                for k in 0..m {
479
480                    let mut root_of_unity_index = (1 << self.log2_n) - 2 * group_size_half;
481                    let mut index1 = k;
482
483                    for _ in (0..group_size_half).step_by(UNROLL_COUNT) {
484                        for _ in 0..UNROLL_COUNT {
485
486                            if !INV {
487                                let current_twiddle = &self.inv_root_of_unity_list[root_of_unity_index];
488                                R.get_ring().butterfly(hom, values, current_twiddle, index1, index1 + m);
489                            } else {
490                                let current_twiddle = &self.root_of_unity_list[root_of_unity_index];
491                                R.get_ring().inv_butterfly(hom, values, current_twiddle, index1, index1 + m);
492                            }
493                            root_of_unity_index += 1;
494                            index1 += two_m;
495
496                        }
497                    }
498                }
499            }
500        }
501    }
502}
503
504impl<R_main, R_twiddle, H> FFTAlgorithm<R_main> for CooleyTuckeyFFT<R_main, R_twiddle, H> 
505    where R_main: ?Sized + RingBase,
506        R_twiddle: ?Sized + RingBase + DivisibilityRing,
507        H: Homomorphism<R_twiddle, R_main>
508{
509    fn len(&self) -> usize {
510        1 << self.log2_n
511    }
512
513    fn root_of_unity<S: Copy + RingStore<Type = R_main>>(&self, ring: S) -> &R_main::Element {
514        assert!(ring.get_ring() == self.ring().get_ring(), "unsupported ring");
515        &self.root_of_unity
516    }
517
518    fn unordered_fft_permutation(&self, i: usize) -> usize {
519        bitreverse(i, self.log2_n)
520    }
521
522    fn unordered_fft_permutation_inv(&self, i: usize) -> usize {
523        bitreverse(i, self.log2_n)
524    }
525
526    fn fft<V, S>(&self, mut values: V, ring: S)
527        where V: SwappableVectorViewMut<<R_main as RingBase>::Element>,
528            S: RingStore<Type = R_main> + Copy 
529    {
530        assert!(ring.get_ring() == self.ring().get_ring(), "unsupported ring");
531        self.unordered_fft(&mut values, ring);
532        self.bitreverse_permute_inplace(&mut values);
533    }
534
535    fn inv_fft<V, S>(&self, mut values: V, ring: S)
536        where V: SwappableVectorViewMut<<R_main as RingBase>::Element>,
537            S: RingStore<Type = R_main> + Copy 
538    {
539        assert!(ring.get_ring() == self.ring().get_ring(), "unsupported ring");
540        self.bitreverse_permute_inplace(&mut values);
541        self.unordered_inv_fft(&mut values, ring);
542    }
543
544    fn unordered_fft<V, S>(&self, mut values: V, ring: S)
545        where V: SwappableVectorViewMut<<R_main as RingBase>::Element>,
546            S: RingStore<Type = R_main> + Copy 
547    {
548        assert!(ring.get_ring() == self.ring().get_ring(), "unsupported ring");
549        self.unordered_fft_dispatch::<V, false>(&mut values);
550    }
551    
552    fn unordered_inv_fft<V, S>(&self, mut values: V, ring: S)
553        where V: SwappableVectorViewMut<<R_main as RingBase>::Element>,
554            S: RingStore<Type = R_main> + Copy 
555    {
556        assert!(ring.get_ring() == self.ring().get_ring(), "unsupported ring");
557        self.unordered_fft_dispatch::<V, true>(&mut values);
558        let inv = self.hom.domain().invert(&self.hom.domain().int_hom().map(1 << self.log2_n)).unwrap();
559        for i in 0..values.len() {
560            self.hom.mul_assign_ref_map(values.at_mut(i), &inv);
561        }
562    }
563}
564
565impl<H> FFTErrorEstimate for CooleyTuckeyFFT<Complex64Base, Complex64Base, H> 
566    where H: Homomorphism<Complex64Base, Complex64Base>
567{
568    fn expected_absolute_error(&self, input_bound: f64, input_error: f64) -> f64 {
569        // each butterfly doubles the error, and then adds up to 
570        let butterfly_absolute_error = input_bound * (root_of_unity_error() + f64::EPSILON);
571        // the operator inf-norm of the FFT is its length
572        return 2. * self.len() as f64 * butterfly_absolute_error + self.len() as f64 * input_error;
573    }
574}
575
576#[cfg(test)]
577use crate::primitive_int::*;
578#[cfg(test)]
579use crate::rings::zn::zn_static::Fp;
580#[cfg(test)]
581use crate::rings::zn::zn_big;
582#[cfg(test)]
583use crate::rings::zn::zn_static;
584#[cfg(test)]
585use crate::field::*;
586#[cfg(test)]
587use crate::rings::finite::FiniteRingStore;
588
#[test]
fn test_bitreverse_fft_inplace_basic() {
    let ring = Fp::<5>::RING;
    // 2^-1 = 3 is a primitive 4-th root of unity in F5 (3^2 = -1)
    let root_of_unity = ring.div(&1, &ring.int_hom().map(2));
    let fft = CooleyTuckeyFFT::new(ring, root_of_unity, 2);
    let expected: [u64; 4] = [2, 4, 0, 3];
    let bitreverse_expected: [u64; 4] = std::array::from_fn(|i| expected[bitreverse(i, 2)]);
    let mut values = [1, 0, 0, 1];

    fft.unordered_fft(&mut values, ring);
    assert_eq!(values, bitreverse_expected);
}
604
#[test]
fn test_bitreverse_fft_inplace_advanced() {
    let ring = Fp::<17>::RING;
    // 3 generates the multiplicative group of F17, hence is a primitive 16-th root of unity
    let fft = CooleyTuckeyFFT::new(ring, ring.int_hom().map(3), 4);
    let expected: [u64; 16] = [5, 2, 0, 11, 5, 4, 0, 6, 6, 13, 0, 1, 7, 6, 0, 1];
    let bitreverse_expected: [u64; 16] = std::array::from_fn(|i| expected[bitreverse(i, 4)]);
    let mut values = [1, 0, 0, 0, 1, 0, 0, 0, 4, 3, 2, 1, 4, 3, 2, 1];

    fft.unordered_fft(&mut values, ring);
    assert_eq!(values, bitreverse_expected);
}
620
#[test]
fn test_unordered_fft_permutation() {
    let ring = Fp::<17>::RING;
    let fft = CooleyTuckeyFFT::for_zn(&ring, 4).unwrap();
    let zeta = *fft.root_of_unity(&ring);
    // the input is X + X^4, so entry i must hold zeta^-j + zeta^-4j for j = unordered_fft_permutation(i)
    let expected: [u64; 16] = std::array::from_fn(|i| {
        let power_of_zeta = ring.pow(zeta, 16 - fft.unordered_fft_permutation(i));
        ring.add(power_of_zeta, ring.pow(power_of_zeta, 4))
    });
    let mut values = [0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0];

    fft.unordered_fft(&mut values, ring);
    assert_eq!(expected, values);
}
634
#[test]
fn test_bitreverse_inv_fft_inplace() {
    let ring = Fp::<17>::RING;
    let fft = CooleyTuckeyFFT::for_zn(&ring, 4).unwrap();
    let original: [u64; 16] = [1, 2, 3, 2, 1, 0, 17 - 1, 17 - 2, 17 - 1, 0, 1, 2, 3, 4, 5, 6];
    // a forward transform followed by the inverse transform must reproduce the input
    let mut roundtrip = original;
    fft.unordered_fft(&mut roundtrip, ring);
    fft.unordered_inv_fft(&mut roundtrip, ring);
    assert_eq!(&roundtrip, &original);
}
645
#[test]
fn test_for_zn() {
    // a primitive 16-th root of unity z satisfies z^8 = -1
    let f17 = Fp::<17>::RING;
    let fft17 = CooleyTuckeyFFT::for_zn(f17, 4).unwrap();
    let z17_pow_8 = f17.pow(fft17.root_of_unity, 8);
    assert!(f17.is_neg_one(&z17_pow_8));

    let f97 = Fp::<97>::RING;
    let fft97 = CooleyTuckeyFFT::for_zn(f97, 4).unwrap();
    let z97_pow_8 = f97.pow(fft97.root_of_unity, 8);
    assert!(f97.is_neg_one(&z97_pow_8));
}
656
/// One benchmark iteration: copies `data` into `copy`, runs a forward and an inverse
/// unordered transform on it, and checks the first entry survived the round trip.
#[cfg(test)]
fn run_fft_bench_round<R, S, H>(fft: &CooleyTuckeyFFT<S, R, H>, data: &Vec<S::Element>, copy: &mut Vec<S::Element>)
    where R: ZnRing, S: ZnRing, H: Homomorphism<R, S>
{
    let ring = fft.ring();
    copy.clear();
    for x in data.iter() {
        copy.push(ring.clone_el(x));
    }
    fft.unordered_fft(&mut copy[..], &ring);
    fft.unordered_inv_fft(&mut copy[..], &ring);
    assert_el_eq!(ring, copy[0], data[0]);
}
667
/// Base-2 logarithm of the transform length used by the benchmarks below.
#[cfg(test)]
const BENCH_SIZE_LOG2: usize = 13;
670
671#[bench]
672fn bench_fft_zn_big(bencher: &mut test::Bencher) {
673    let ring = zn_big::Zn::new(StaticRing::<i128>::RING, 1073872897);
674    let fft = CooleyTuckeyFFT::for_zn(&ring, BENCH_SIZE_LOG2).unwrap();
675    let data = (0..(1 << BENCH_SIZE_LOG2)).map(|i| ring.int_hom().map(i)).collect::<Vec<_>>();
676    let mut copy = Vec::with_capacity(1 << BENCH_SIZE_LOG2);
677    bencher.iter(|| {
678        run_fft_bench_round(&fft, &data, &mut copy)
679    });
680}
681
682#[bench]
683fn bench_fft_zn_64(bencher: &mut test::Bencher) {
684    let ring = zn_64::Zn::new(1073872897);
685    let fft = CooleyTuckeyFFT::for_zn(&ring, BENCH_SIZE_LOG2).unwrap();
686    let data = (0..(1 << BENCH_SIZE_LOG2)).map(|i| ring.int_hom().map(i)).collect::<Vec<_>>();
687    let mut copy = Vec::with_capacity(1 << BENCH_SIZE_LOG2);
688    bencher.iter(|| {
689        run_fft_bench_round(&fft, &data, &mut copy)
690    });
691}
692
693#[bench]
694fn bench_fft_zn_64_fastmul(bencher: &mut test::Bencher) {
695    let ring = zn_64::Zn::new(1073872897);
696    let fastmul_ring = zn_64::ZnFastmul::new(ring).unwrap();
697    let fft = CooleyTuckeyFFT::for_zn_with_hom(ring.into_can_hom(fastmul_ring).ok().unwrap(), BENCH_SIZE_LOG2).unwrap();
698    let data = (0..(1 << BENCH_SIZE_LOG2)).map(|i| ring.int_hom().map(i)).collect::<Vec<_>>();
699    let mut copy = Vec::with_capacity(1 << BENCH_SIZE_LOG2);
700    bencher.iter(|| {
701        run_fft_bench_round(&fft, &data, &mut copy)
702    });
703}
704
#[test]
fn test_approximate_fft() {
    let CC = Complex64::RING;
    for log2_n in [4, 7, 11, 15] {
        let fft = CooleyTuckeyFFT::new_with_pows(CC, |x| CC.root_of_unity(x, 1 << log2_n), log2_n);
        // input z^0, z^1, ..., z^(n-1): the transform is n at index 1 and 0 everywhere else
        let mut array = (0..(1 << log2_n)).map(|i| CC.root_of_unity(i as i64, 1 << log2_n)).collect::<Vec<_>>();
        fft.fft(&mut array, CC);
        let tolerance = fft.expected_absolute_error(1., 0.);
        for (i, value) in array.iter().enumerate() {
            let expected = if i == 1 { CC.from_f64(fft.len() as f64) } else { CC.zero() };
            assert!(CC.is_absolute_approx_eq(*value, expected, tolerance));
        }
    }
}
720
#[test]
fn test_size_1_fft() {
    let ring = Fp::<17>::RING;
    // a transform of length 1 must be the identity, in both directions
    let fft = CooleyTuckeyFFT::for_zn(&ring, 0).unwrap();
    let original: [u64; 1] = [3];
    let mut work = original;

    fft.unordered_fft(&mut work, ring);
    assert_eq!(&work, &original);
    fft.unordered_inv_fft(&mut work, ring);
    assert_eq!(&work, &original);

    assert_eq!(0, fft.unordered_fft_permutation(0));
    assert_eq!(0, fft.unordered_fft_permutation_inv(0));
}
734
///
/// Checks an implementation of [`CooleyTuckeyButterfly`] for `ring` on all pairs of
/// `edge_case_elements`. For both index orders it verifies that `butterfly` yields
/// `(x + t y, x - t y)`, and that a subsequent `inv_butterfly` with `t^-1` yields
/// twice the original entries.
/// 
#[cfg(any(test, feature = "generic_tests"))]
pub fn generic_test_cooley_tuckey_butterfly<R: RingStore, S: RingStore, I: Iterator<Item = El<R>>>(ring: R, base: S, edge_case_elements: I, test_twiddle: &El<S>)
    where R::Type: CanHomFrom<S::Type>,
        S::Type: DivisibilityRing
{
    let test_inv_twiddle = base.invert(test_twiddle).unwrap();
    let elements: Vec<_> = edge_case_elements.collect();
    let hom = ring.can_hom(&base).unwrap();

    for a in &elements {
        for b in &elements {
            for (i1, i2) in [(0, 1), (1, 0)] {
                // `vector` always starts as [a, b]; x and y are the entries at i1 and i2
                let (x, y) = if i1 == 0 { (a, b) } else { (b, a) };
                let mut vector = [ring.clone_el(a), ring.clone_el(b)];

                ring.get_ring().butterfly(&hom, &mut vector, test_twiddle, i1, i2);
                assert_el_eq!(ring, ring.add_ref_fst(x, ring.mul_ref_fst(y, hom.map_ref(test_twiddle))), &vector[i1]);
                assert_el_eq!(ring, ring.sub_ref_fst(x, ring.mul_ref_fst(y, hom.map_ref(test_twiddle))), &vector[i2]);

                ring.get_ring().inv_butterfly(&hom, &mut vector, &test_inv_twiddle, i1, i2);
                assert_el_eq!(ring, ring.int_hom().mul_ref_fst_map(a, 2), &vector[0]);
                assert_el_eq!(ring, ring.int_hom().mul_ref_fst_map(b, 2), &vector[1]);
            }
        }
    }
}
767
#[test]
fn test_butterfly() {
    let field = zn_static::F17;
    // a primitive 16-th root of unity as the test twiddle
    let twiddle = get_prim_root_of_unity_pow2(field, 4).unwrap();
    generic_test_cooley_tuckey_butterfly(field, field, field.elements(), &twiddle);
}