Skip to main content

feanor_math/algorithms/fft/
radix3.rs

1use std::alloc::{Allocator, Global};
2
3use crate::algorithms::fft::FFTAlgorithm;
4use crate::algorithms::fft::complex_fft::*;
5use crate::algorithms::unity_root::{get_prim_root_of_unity_zn, is_prim_root_of_unity};
6use crate::divisibility::{DivisibilityRing, DivisibilityRingStore};
7use crate::homomorphism::*;
8use crate::primitive_int::StaticRing;
9use crate::ring::*;
10use crate::rings::float_complex::Complex64Base;
11use crate::rings::zn::*;
12use crate::seq::{SwappableVectorViewMut, VectorFn};
13
14/// Implementation of the Cooley-Tukey FFT algorithm for power-of-three lengths.
15#[stability::unstable(feature = "enable")]
16pub struct CooleyTukeyRadix3FFT<R_main, R_twiddle, H, A = Global>
17where
18    R_main: ?Sized + RingBase,
19    R_twiddle: ?Sized + RingBase + DivisibilityRing,
20    H: Homomorphism<R_twiddle, R_main>,
21    A: Allocator,
22{
23    log3_n: usize,
24    hom: H,
25    twiddles: Vec<Vec<R_twiddle::Element>>,
26    inv_twiddles: Vec<Vec<R_twiddle::Element>>,
27    third_root_of_unity: R_twiddle::Element,
28    main_root_of_unity: R_main::Element,
29    allocator: A,
30}
31
32const ZZ: StaticRing<i64> = StaticRing::RING;
33
34#[inline(never)]
35fn butterfly_loop<T, S, F>(log3_n: usize, data: &mut [T], step: usize, twiddles: &[S], butterfly: F)
36where
37    F: Fn(&mut T, &mut T, &mut T, &S, &S) + Clone,
38{
39    assert_eq!(ZZ.pow(3, log3_n) as usize, data.len());
40    assert!(step < log3_n);
41    assert_eq!(2 * ZZ.pow(3, step) as usize, twiddles.len());
42
43    let stride = ZZ.pow(3, log3_n - step - 1) as usize;
44    assert!(data.len().is_multiple_of(3 * stride));
45    assert_eq!(twiddles.as_chunks::<2>().0.len(), data.chunks_mut(3 * stride).len());
46
47    if stride == 1 {
48        for ([twiddle1, twiddle2], butterfly_data) in twiddles
49            .as_chunks::<2>()
50            .0
51            .iter()
52            .zip(data.as_chunks_mut::<3>().0.iter_mut())
53        {
54            let [a, b, c] = butterfly_data.each_mut();
55            butterfly(a, b, c, twiddle1, twiddle2);
56        }
57    } else {
58        for ([twiddle1, twiddle2], butterfly_data) in
59            twiddles.as_chunks::<2>().0.iter().zip(data.chunks_mut(3 * stride))
60        {
61            let (first, rest) = butterfly_data.split_at_mut(stride);
62            let (second, third) = rest.split_at_mut(stride);
63            for ((a, b), c) in first.iter_mut().zip(second.iter_mut()).zip(third.iter_mut()) {
64                butterfly(a, b, c, twiddle1, twiddle2);
65            }
66        }
67    }
68}
69
70fn threeadic_reverse(mut number: usize, log3_n: usize) -> usize {
71    debug_assert!((number as i64) < ZZ.pow(3, log3_n));
72    let mut result = 0;
73    for _ in 0..log3_n {
74        let (quo, rem) = (number / 3, number % 3);
75        result = 3 * result + rem;
76        number = quo;
77    }
78    assert_eq!(0, number);
79    return result;
80}
81
82impl<R> CooleyTukeyRadix3FFT<R::Type, R::Type, Identity<R>>
83where
84    R: RingStore,
85    R::Type: DivisibilityRing,
86{
87    /// Creates an [`CooleyTukeyRadix3FFT`] for a prime field, assuming it has a characteristic
88    /// congruent to 1 modulo `3^lo32_n`.
89    #[stability::unstable(feature = "enable")]
90    pub fn for_zn(ring: R, log3_n: usize) -> Option<Self>
91    where
92        R::Type: ZnRing,
93    {
94        let n = ZZ.pow(3, log3_n);
95        let root_of_unity = get_prim_root_of_unity_zn(&ring, n as usize)?;
96        return Some(Self::new_with_hom(ring.into_identity(), root_of_unity, log3_n));
97    }
98}
99
100impl<R_main, R_twiddle, H> CooleyTukeyRadix3FFT<R_main, R_twiddle, H>
101where
102    R_main: ?Sized + RingBase,
103    R_twiddle: ?Sized + RingBase + DivisibilityRing,
104    H: Homomorphism<R_twiddle, R_main>,
105{
106    /// Creates an [`CooleyTukeyRadix3FFT`] for the given rings, using the given root of unity.
107    ///
108    /// Instead of a ring, this function takes a homomorphism `R -> S`. Twiddle factors that are
109    /// precomputed will be stored as elements of `R`, while the main FFT computations will be
110    /// performed in `S`. This allows both implicit ring conversions, and using patterns like
111    /// [`zn_64::ZnFastmul`] to precompute some data for better performance.
112    ///
113    /// Do not use this for approximate rings, as computing the powers of `root_of_unity`
114    /// will incur avoidable precision loss.
115    #[stability::unstable(feature = "enable")]
116    pub fn new_with_hom(hom: H, zeta: R_twiddle::Element, log3_n: usize) -> Self {
117        let ring = hom.domain();
118        let pow_zeta = |i: i64| {
119            if i < 0 {
120                ring.invert(&ring.pow(ring.clone_el(&zeta), (-i).try_into().unwrap()))
121                    .unwrap()
122            } else {
123                ring.pow(ring.clone_el(&zeta), i.try_into().unwrap())
124            }
125        };
126        let result = CooleyTukeyRadix3FFT::create(&hom, pow_zeta, log3_n, Global);
127        return Self {
128            allocator: result.allocator,
129            inv_twiddles: result.inv_twiddles,
130            log3_n: result.log3_n,
131            main_root_of_unity: result.main_root_of_unity,
132            third_root_of_unity: result.third_root_of_unity,
133            twiddles: result.twiddles,
134            hom,
135        };
136    }
137}
138
139impl<R_main, R_twiddle, H, A> CooleyTukeyRadix3FFT<R_main, R_twiddle, H, A>
140where
141    R_main: ?Sized + RingBase,
142    R_twiddle: ?Sized + RingBase + DivisibilityRing,
143    H: Homomorphism<R_twiddle, R_main>,
144    A: Allocator,
145{
146    /// Most general way to create a [`CooleyTukeyRadix3FFT`].
147    ///
148    /// The given closure should, on input `i`, return `z^i` for a primitive
149    /// `3^log3_n`-th root of unity. The given allocator is used to copy the input
150    /// data in cases where the input data layout is not optimal for the algorithm
151    #[stability::unstable(feature = "enable")]
152    pub fn create<F>(hom: H, mut root_of_unity_pow: F, log3_n: usize, allocator: A) -> Self
153    where
154        F: FnMut(i64) -> R_twiddle::Element,
155    {
156        let n = ZZ.pow(3, log3_n) as usize;
157        assert!(
158            hom.codomain().get_ring().is_approximate()
159                || is_prim_root_of_unity(hom.codomain(), &hom.map(root_of_unity_pow(1)), n)
160        );
161
162        return Self {
163            main_root_of_unity: hom.map(root_of_unity_pow(1)),
164            log3_n,
165            twiddles: Self::create_twiddle_list(hom.domain(), log3_n, &mut root_of_unity_pow),
166            inv_twiddles: Self::create_inv_twiddle_list(hom.domain(), log3_n, &mut root_of_unity_pow),
167            third_root_of_unity: root_of_unity_pow(2 * n as i64 / 3),
168            hom,
169            allocator,
170        };
171    }
172
173    /// Replaces the ring that this object can compute FFTs over, assuming that the current
174    /// twiddle factors can be mapped into the new ring with the given homomorphism.
175    ///
176    /// In particular, this function does not recompute twiddles, but uses a different
177    /// homomorphism to map the current twiddles into a new ring. Hence, it is extremely
178    /// cheap.
179    #[stability::unstable(feature = "enable")]
180    pub fn change_ring<R_new: ?Sized + RingBase, H_new: Homomorphism<R_twiddle, R_new>>(
181        self,
182        new_hom: H_new,
183    ) -> (CooleyTukeyRadix3FFT<R_new, R_twiddle, H_new, A>, H) {
184        let ring = new_hom.codomain();
185        let root_of_unity = if self.log3_n == 0 {
186            new_hom.codomain().one()
187        } else if self.log3_n == 1 {
188            let root_of_unity = self
189                .hom
190                .domain()
191                .pow(self.hom.domain().clone_el(&self.third_root_of_unity), 2);
192            debug_assert!(self.ring().eq_el(
193                &self.hom.map_ref(&root_of_unity),
194                self.root_of_unity(self.hom.codomain())
195            ));
196            new_hom.map(root_of_unity)
197        } else {
198            let root_of_unity = &self.inv_twiddles[self.log3_n - 1][threeadic_reverse(1, self.log3_n - 1)];
199            debug_assert!(self.ring().eq_el(
200                &self.hom.map_ref(root_of_unity),
201                self.root_of_unity(self.hom.codomain())
202            ));
203            new_hom.map_ref(root_of_unity)
204        };
205        assert!(ring.is_commutative());
206        assert!(ring.get_ring().is_approximate() || is_prim_root_of_unity(&ring, &root_of_unity, self.len()));
207
208        return (
209            CooleyTukeyRadix3FFT {
210                twiddles: self.twiddles,
211                inv_twiddles: self.inv_twiddles,
212                main_root_of_unity: root_of_unity,
213                third_root_of_unity: self.third_root_of_unity,
214                hom: new_hom,
215                log3_n: self.log3_n,
216                allocator: self.allocator,
217            },
218            self.hom,
219        );
220    }
221
222    /// Returns the ring over which this object can compute FFTs.
223    #[stability::unstable(feature = "enable")]
224    pub fn ring<'a>(&'a self) -> &'a <H as Homomorphism<R_twiddle, R_main>>::CodomainStore { self.hom.codomain() }
225
226    /// Returns a reference to the allocator currently used for temporary allocations by this FFT.
227    #[stability::unstable(feature = "enable")]
228    pub fn allocator(&self) -> &A { &self.allocator }
229
230    /// Replaces the allocator used for temporary allocations by this FFT.
231    #[stability::unstable(feature = "enable")]
232    pub fn with_allocator<A_new: Allocator>(
233        self,
234        allocator: A_new,
235    ) -> CooleyTukeyRadix3FFT<R_main, R_twiddle, H, A_new> {
236        CooleyTukeyRadix3FFT {
237            twiddles: self.twiddles,
238            inv_twiddles: self.inv_twiddles,
239            main_root_of_unity: self.main_root_of_unity,
240            third_root_of_unity: self.third_root_of_unity,
241            hom: self.hom,
242            log3_n: self.log3_n,
243            allocator,
244        }
245    }
246
247    /// Returns a reference to the homomorphism that is used to map the stored twiddle
248    /// factors into main ring, over which FFTs are computed.
249    #[stability::unstable(feature = "enable")]
250    pub fn hom<'a>(&'a self) -> &'a H { &self.hom }
251
252    fn create_twiddle_list<F>(ring: &H::DomainStore, log3_n: usize, mut pow_zeta: F) -> Vec<Vec<R_twiddle::Element>>
253    where
254        F: FnMut(i64) -> R_twiddle::Element,
255    {
256        let n = ZZ.pow(3, log3_n);
257        let third_root_of_unity = pow_zeta(-(n / 3));
258        let mut result: Vec<_> = (0..log3_n).map(|_| Vec::new()).collect();
259        for i in 0..log3_n {
260            let current = &mut result[i];
261            for j in 0..ZZ.pow(3, i) {
262                let base_twiddle = pow_zeta(-(threeadic_reverse(j as usize, log3_n - 1) as i64));
263                current.push(ring.clone_el(&base_twiddle));
264                current.push(ring.pow(ring.mul_ref_snd(base_twiddle, &third_root_of_unity), 2));
265            }
266        }
267        return result;
268    }
269
270    fn create_inv_twiddle_list<F>(ring: &H::DomainStore, log3_n: usize, mut pow_zeta: F) -> Vec<Vec<R_twiddle::Element>>
271    where
272        F: FnMut(i64) -> R_twiddle::Element,
273    {
274        let mut result: Vec<_> = (0..log3_n).map(|_| Vec::new()).collect();
275        for i in 0..log3_n {
276            let current = &mut result[i];
277            for j in 0..ZZ.pow(3, i) {
278                let base_twiddle = pow_zeta(threeadic_reverse(j as usize, log3_n - 1) as i64);
279                current.push(ring.clone_el(&base_twiddle));
280                current.push(ring.pow(base_twiddle, 2));
281            }
282        }
283        return result;
284    }
285
286    fn butterfly_step_main<const INV: bool>(&self, data: &mut [R_main::Element], step: usize) {
287        let twiddles = if INV {
288            &self.inv_twiddles[step]
289        } else {
290            &self.twiddles[step]
291        };
292        let third_root_of_unity = &self.third_root_of_unity;
293        // let start = std::time::Instant::now();
294        butterfly_loop(self.log3_n, data, step, twiddles, |x, y, z, twiddle1, twiddle2| {
295            if INV {
296                <R_main as CooleyTukeyRadix3Butterfly<R_twiddle>>::inv_butterfly(
297                    &self.hom,
298                    x,
299                    y,
300                    z,
301                    third_root_of_unity,
302                    twiddle1,
303                    twiddle2,
304                )
305            } else {
306                <R_main as CooleyTukeyRadix3Butterfly<R_twiddle>>::butterfly(
307                    &self.hom,
308                    x,
309                    y,
310                    z,
311                    third_root_of_unity,
312                    twiddle1,
313                    twiddle2,
314                )
315            }
316        });
317        // let end = std::time::Instant::now();
318        // BUTTERFLY_RADIX3_TIMES[step].fetch_add((end - start).as_micros() as usize,
319        // std::sync::atomic::Ordering::Relaxed);
320    }
321
322    fn fft_impl(&self, data: &mut [R_main::Element]) {
323        for i in 0..data.len() {
324            <R_main as CooleyTukeyRadix3Butterfly<R_twiddle>>::prepare_for_fft(
325                self.hom.codomain().get_ring(),
326                &mut data[i],
327            );
328        }
329        for step in 0..self.log3_n {
330            self.butterfly_step_main::<false>(data, step);
331        }
332    }
333
334    fn inv_fft_impl(&self, data: &mut [R_main::Element]) {
335        for i in 0..data.len() {
336            <R_main as CooleyTukeyRadix3Butterfly<R_twiddle>>::prepare_for_inv_fft(
337                self.hom.codomain().get_ring(),
338                &mut data[i],
339            );
340        }
341        for step in (0..self.log3_n).rev() {
342            self.butterfly_step_main::<true>(data, step);
343        }
344        let n_inv = self
345            .hom
346            .domain()
347            .invert(&self.hom.domain().int_hom().map(self.len() as i32))
348            .unwrap();
349        for i in 0..data.len() {
350            self.hom.mul_assign_ref_map(&mut data[i], &n_inv);
351        }
352    }
353}
354
355impl<R_main, R_twiddle, H, A> Clone for CooleyTukeyRadix3FFT<R_main, R_twiddle, H, A>
356where
357    R_main: ?Sized + RingBase,
358    R_twiddle: ?Sized + RingBase + DivisibilityRing,
359    H: Homomorphism<R_twiddle, R_main> + Clone,
360    A: Allocator + Clone,
361{
362    fn clone(&self) -> Self {
363        Self {
364            hom: self.hom.clone(),
365            inv_twiddles: self
366                .inv_twiddles
367                .iter()
368                .map(|list| list.iter().map(|x| self.hom.domain().clone_el(x)).collect())
369                .collect(),
370            twiddles: self
371                .twiddles
372                .iter()
373                .map(|list| list.iter().map(|x| self.hom.domain().clone_el(x)).collect())
374                .collect(),
375            main_root_of_unity: self.hom.codomain().clone_el(&self.main_root_of_unity),
376            third_root_of_unity: self.hom.domain().clone_el(&self.third_root_of_unity),
377            log3_n: self.log3_n,
378            allocator: self.allocator.clone(),
379        }
380    }
381}
382
383impl<R_main, R_twiddle, H, A> FFTAlgorithm<R_main> for CooleyTukeyRadix3FFT<R_main, R_twiddle, H, A>
384where
385    R_main: ?Sized + RingBase,
386    R_twiddle: ?Sized + RingBase + DivisibilityRing,
387    H: Homomorphism<R_twiddle, R_main>,
388    A: Allocator,
389{
390    fn len(&self) -> usize {
391        if self.log3_n == 0 {
392            return 1;
393        }
394        let result = self.twiddles[self.log3_n - 1].len() / 2 * 3;
395        debug_assert_eq!(ZZ.pow(3, self.log3_n) as usize, result);
396        return result;
397    }
398
399    fn unordered_fft<V, S>(&self, mut values: V, ring: S)
400    where
401        V: SwappableVectorViewMut<R_main::Element>,
402        S: RingStore<Type = R_main> + Copy,
403    {
404        assert!(ring.get_ring() == self.hom.codomain().get_ring(), "unsupported ring");
405        assert_eq!(self.len(), values.len());
406        if let Some(data) = values.as_slice_mut() {
407            self.fft_impl(data);
408        } else {
409            let mut data = Vec::with_capacity_in(self.len(), &self.allocator);
410            data.extend(values.clone_ring_els(ring).iter());
411            self.fft_impl(&mut data);
412            for (i, x) in data.into_iter().enumerate() {
413                *values.at_mut(i) = x;
414            }
415        }
416    }
417
418    fn unordered_inv_fft<V, S>(&self, mut values: V, ring: S)
419    where
420        V: SwappableVectorViewMut<R_main::Element>,
421        S: RingStore<Type = R_main> + Copy,
422    {
423        assert!(ring.get_ring() == self.hom.codomain().get_ring(), "unsupported ring");
424        assert_eq!(self.len(), values.len());
425        if let Some(data) = values.as_slice_mut() {
426            self.inv_fft_impl(data);
427        } else {
428            let mut data = Vec::with_capacity_in(self.len(), &self.allocator);
429            data.extend(values.clone_ring_els(ring).iter());
430            self.inv_fft_impl(&mut data);
431            for (i, x) in data.into_iter().enumerate() {
432                *values.at_mut(i) = x;
433            }
434        }
435    }
436
437    fn root_of_unity<S>(&self, ring: S) -> &R_main::Element
438    where
439        S: RingStore<Type = R_main> + Copy,
440    {
441        assert!(ring.get_ring() == self.hom.codomain().get_ring(), "unsupported ring");
442        &self.main_root_of_unity
443    }
444
445    fn unordered_fft_permutation(&self, i: usize) -> usize { threeadic_reverse(i, self.log3_n) }
446
447    fn unordered_fft_permutation_inv(&self, i: usize) -> usize { threeadic_reverse(i, self.log3_n) }
448}
449
450#[stability::unstable(feature = "enable")]
451pub trait CooleyTukeyRadix3Butterfly<S: ?Sized + RingBase>: RingBase {
452    /// Should compute `(a, b, c) := (a + t b + t^2 c, a + t z b + t^2 z^2 c, a + t z^2 b + t^2 z
453    /// c)`.
454    ///
455    /// Here `z` is a third root of unity (i.e. `z^2 + z + 1 = 0`) and `t` is the twiddle factor.
456    /// The function should be given `z, t, t^2 z^2`.
457    ///
458    /// It is guaranteed that the input elements are either outputs of
459    /// [`CooleyTukeyRadix3Butterfly::butterfly()`] or of
460    /// [`CooleyTukeyRadix3Butterfly::prepare_for_fft()`].
461    fn butterfly<H: Homomorphism<S, Self>>(
462        hom: H,
463        a: &mut Self::Element,
464        b: &mut Self::Element,
465        c: &mut Self::Element,
466        z: &S::Element,
467        t: &S::Element,
468        t_sqr_z_sqr: &S::Element,
469    );
470
471    /// Should compute `(a, b, c) := (a + b + c, t (a + z^2 b + z c), t^2 (a + z b + z^2 c))`.
472    ///
473    /// Here `z` is a third root of unity (i.e. `z^2 + z + 1 = 0`) and `t` is the twiddle factor.
474    /// The function should be given `z, t, t^2`.
475    ///
476    /// It is guaranteed that the input elements are either outputs of
477    /// [`CooleyTukeyRadix3Butterfly::inv_butterfly()`] or of
478    /// [`CooleyTukeyRadix3Butterfly::prepare_for_inv_fft()`].
479    fn inv_butterfly<H: Homomorphism<S, Self>>(
480        hom: H,
481        a: &mut Self::Element,
482        b: &mut Self::Element,
483        c: &mut Self::Element,
484        z: &S::Element,
485        t: &S::Element,
486        t_sqr: &S::Element,
487    );
488
489    /// Possibly pre-processes elements before the FFT starts. Here you can
490    /// bring ring element into a certain form, and assume during
491    /// [`CooleyTukeyRadix3Butterfly::butterfly()`] that the inputs are in this form.
492    #[inline(always)]
493    fn prepare_for_fft(&self, _value: &mut Self::Element) {}
494
495    /// Possibly pre-processes elements before the inverse FFT starts. Here you can
496    /// bring ring element into a certain form, and assume during
497    /// [`CooleyTukeyRadix3Butterfly::inv_butterfly()`] that the inputs are in this form.
498    #[inline(always)]
499    fn prepare_for_inv_fft(&self, _value: &mut Self::Element) {}
500}
501
502impl<R: ?Sized + RingBase, S: ?Sized + RingBase> CooleyTukeyRadix3Butterfly<S> for R {
503    default fn butterfly<H: Homomorphism<S, Self>>(
504        hom: H,
505        a: &mut Self::Element,
506        b: &mut Self::Element,
507        c: &mut Self::Element,
508        z: &S::Element,
509        t: &S::Element,
510        t_sqr_z_sqr: &S::Element,
511    ) {
512        let ring = hom.codomain();
513        hom.mul_assign_ref_map(b, t); // this is now `t b`
514        hom.mul_assign_ref_map(c, t_sqr_z_sqr); // this is now `t^2 z^2 c`
515        let b_ = hom.mul_ref_map(b, z); // this is now `t z b`
516        let c_ = hom.mul_ref_map(c, z); // this is now `t^2 c z`
517        let s1 = ring.add_ref(b, &c_); // this is now `t b + t^2 c`
518        let s2 = ring.add_ref(&b_, c); // this is now `t z b + t^2 z^2 c`
519        let s3 = ring.add_ref(&s1, &s2); // this is now `-(t z^2 b + t^2 z c)`
520        *b = ring.add_ref_fst(a, s2); // this is now `a + t z b + t^2 z^2 c`
521        *c = ring.sub_ref_fst(a, s3); // this is now `a + t z^2 b + t^2 z c`
522        ring.add_assign(a, s1); // this is now `a + t b + t^2 c`
523    }
524
525    default fn inv_butterfly<H: Homomorphism<S, Self>>(
526        hom: H,
527        a: &mut Self::Element,
528        b: &mut Self::Element,
529        c: &mut Self::Element,
530        z: &S::Element,
531        t: &S::Element,
532        t_sqr: &S::Element,
533    ) {
534        let ring = hom.codomain();
535        let b_ = hom.mul_ref_map(b, z); // this is now `z b`
536        let s1 = ring.add_ref(b, c); // this is now `b + c`
537        let s2 = ring.add_ref(&b_, c); // this is now `z b + c`
538        let s2_ = hom.mul_ref_snd_map(s2, z); // this is now `z^2 b + z c`
539        let s3 = ring.add_ref(&s1, &s2_); // this is now `-(z b + z^2 c)`
540        *b = ring.add_ref(a, &s2_); // this is now `a + z^2 b + z c`
541        *c = ring.sub_ref(a, &s3); // this is now `a + z b + z^2 c`
542        ring.add_assign(a, s1); // this is now `a + b + c`
543        hom.mul_assign_ref_map(b, t); // this is now `t (a + z^2 b + z c)`
544        hom.mul_assign_ref_map(c, t_sqr); // this is now `t^2 (a + z b + z^2 c`
545    }
546
547    /// Possibly pre-processes elements before the FFT starts. Here you can bring ring element
548    /// into a certain form, and assume during [`CooleyTukeyRadix3Butterfly::butterfly()`]
549    /// that the inputs are in this form.
550    #[inline(always)]
551    default fn prepare_for_fft(&self, _value: &mut Self::Element) {}
552
553    /// Possibly pre-processes elements before the inverse FFT starts. Here you can bring ring
554    /// element into a certain form, and assume during
555    /// [`CooleyTukeyRadix3Butterfly::inv_butterfly()`] that the inputs are in this form.
556    #[inline(always)]
557    default fn prepare_for_inv_fft(&self, _value: &mut Self::Element) {}
558}
559
560impl<H, A> FFTErrorEstimate for CooleyTukeyRadix3FFT<Complex64Base, Complex64Base, H, A>
561where
562    H: Homomorphism<Complex64Base, Complex64Base>,
563    A: Allocator,
564{
565    fn expected_absolute_error(&self, input_bound: f64, input_error: f64) -> f64 {
566        // the butterfly performs two multiplications with roots of unity, and then two additions
567        let multiply_absolute_error = 2.0 * input_bound * root_of_unity_error() + input_bound * f64::EPSILON;
568        let addition_absolute_error = 2.0 * input_bound * f64::EPSILON;
569        let butterfly_absolute_error = multiply_absolute_error + addition_absolute_error;
570        // the operator inf-norm of the FFT is its length
571        return 2.0 * self.len() as f64 * butterfly_absolute_error + self.len() as f64 * input_error;
572    }
573}
574
575#[cfg(test)]
576use std::array::from_fn;
577
578#[cfg(test)]
579use crate::rings::finite::FiniteRingStore;
580#[cfg(test)]
581use crate::rings::zn::zn_64::*;
582#[cfg(test)]
583use crate::rings::zn::zn_static::Fp;
584
#[test]
fn test_radix3_butterflies() {
    const LEN: usize = 27;
    let log3_n = 3;
    let ring = Zn::new(109);
    let ring_fastmul = ZnFastmul::new(ring).unwrap();
    let int_hom = ring.int_hom();
    let zeta = int_hom.map(97);
    let zeta_inv = ring.invert(&zeta).unwrap();
    let fft = CooleyTukeyRadix3FFT::new_with_hom(
        ring.into_can_hom(ring_fastmul).ok().unwrap(),
        ring_fastmul.coerce(&ring, zeta),
        log3_n,
    );

    // the input is `0, 1, ..., LEN - 1`
    let data: [_; LEN] = from_fn(|j| int_hom.map(j as i32));

    // value at index `value_idx` of the length-`3^step` DFT of the subsequence starting at `group_idx`
    let expected_std_order = |step: usize, group_idx: usize, value_idx: usize| {
        let group_stride = ZZ.pow(3, log3_n - step);
        ring.sum((0..ZZ.pow(3, step)).map(|k| {
            let offset = (k * group_stride) as usize;
            ring.mul(ring.pow(zeta_inv, value_idx * offset), data[group_idx + offset])
        }))
    };
    // the whole array as expected after `step` butterfly steps
    let expected_threeadic_reverse = |step: usize| -> [ZnEl; LEN] {
        from_fn(|idx| {
            let group_count = ZZ.pow(3, log3_n - step) as usize;
            expected_std_order(step, idx % group_count, threeadic_reverse(idx / group_count, step))
        })
    };

    let begin = expected_threeadic_reverse(0);
    for (a, e) in data.iter().zip(begin.iter()) {
        assert_el_eq!(ring, a, e);
    }

    let mut actual = data;
    for step in 0..log3_n {
        fft.butterfly_step_main::<false>(&mut actual, step);
        let expected = expected_threeadic_reverse(step + 1);
        for (a, e) in actual.iter().zip(expected.iter()) {
            assert_el_eq!(ring, a, e);
        }
    }
}
633
#[test]
fn test_radix3_inv_fft() {
    let log3_n = 3;
    let ring = Zn::new(109);
    let ring_fastmul = ZnFastmul::new(ring).unwrap();
    let zeta = ring.int_hom().map(97);
    let hom = ring.into_can_hom(ring_fastmul).ok().unwrap();
    let fft = CooleyTukeyRadix3FFT::new_with_hom(hom, ring_fastmul.coerce(&ring, zeta), log3_n);

    // a forward transform followed by the inverse transform must be the identity
    let data: Vec<_> = (0..ZZ.pow(3, log3_n)).map(|x| ring.int_hom().map(x as i32)).collect();
    let mut actual = data.clone();
    fft.unordered_fft(&mut actual, &ring);
    fft.unordered_inv_fft(&mut actual, &ring);

    for (expected, got) in data.iter().zip(actual.iter()) {
        assert_el_eq!(&ring, expected, got);
    }
}
657
#[test]
fn test_size_1_fft() {
    let ring = Fp::<17>::RING;
    let (fft, _) = CooleyTukeyRadix3FFT::for_zn(&ring, 0).unwrap().change_ring(ring.identity());
    let values: [u64; 1] = [3];
    let mut work = values;

    // a length-1 FFT (and its inverse) must leave the input unchanged
    fft.unordered_fft(&mut work, ring);
    assert_eq!(&work, &values);
    fft.unordered_inv_fft(&mut work, ring);
    assert_eq!(&work, &values);

    assert_eq!(0, fft.unordered_fft_permutation(0));
    assert_eq!(0, fft.unordered_fft_permutation_inv(0));
}
674
#[cfg(any(test, feature = "generic_tests"))]
pub mod generic_tests {
    use super::*;

    /// Tests the [`CooleyTukeyRadix3Butterfly`] implementation of `R::Type` against its
    /// specification, for all triples `(a, b, c)` of the given edge-case elements.
    ///
    /// With `t = test_twiddle` and `z = test_zeta`, this checks two things:
    /// - `butterfly()` yields `(a + t b + t^2 c, a + t z b + t^2 z^2 c, a + t z^2 b + t^2 z c)`;
    /// - applying `inv_butterfly()` with twiddle `t^-1` afterwards yields `(3a, 3b, 3c)`.
    ///
    /// # Panics
    ///
    /// If `1 + z + z^2 != 0` in `base`, if `test_twiddle` is not invertible, if there is no
    /// canonical homomorphism `base -> ring`, or if any check fails.
    pub fn test_cooley_tuckey_radix3_butterfly<R: RingStore, S: RingStore, I: Iterator<Item = El<R>>>(
        ring: R,
        base: S,
        edge_case_elements: I,
        test_zeta: &El<S>,
        test_twiddle: &El<S>,
    ) where
        R::Type: CanHomFrom<S::Type>,
        S::Type: DivisibilityRing,
    {
        let zeta_sqr = base.pow(base.clone_el(test_zeta), 2);
        assert!(base.is_zero(&base.sum([base.one(), base.clone_el(test_zeta), zeta_sqr])));
        let test_inv_twiddle = base.invert(test_twiddle).unwrap();
        let elements: Vec<_> = edge_case_elements.collect();
        let hom = ring.can_hom(&base).unwrap();

        // precomputed factors that the butterflies take as additional inputs
        let twiddle_sqr_zeta_sqr = base.pow(base.mul_ref(test_twiddle, test_zeta), 2);
        let inv_twiddle_sqr = base.pow(base.clone_el(&test_inv_twiddle), 2);

        // `a + t b + t^2 c`, evaluated as `a + t (b + t c)`
        let expected = |a: &El<R>, b: &El<R>, c: &El<R>, t: &El<R>| {
            ring.add_ref_fst(a, ring.mul_ref_snd(ring.add_ref_fst(b, ring.mul_ref(c, t)), t))
        };
        let prepare_fwd = |x: &mut El<R>| {
            <R::Type as CooleyTukeyRadix3Butterfly<S::Type>>::prepare_for_fft(ring.get_ring(), x)
        };
        let prepare_inv = |x: &mut El<R>| {
            <R::Type as CooleyTukeyRadix3Butterfly<S::Type>>::prepare_for_inv_fft(ring.get_ring(), x)
        };

        for a in &elements {
            for b in &elements {
                for c in &elements {
                    let [mut x, mut y, mut z] = [a, b, c].map(|e| ring.clone_el(e));

                    prepare_fwd(&mut x);
                    prepare_fwd(&mut y);
                    prepare_fwd(&mut z);
                    <R::Type as CooleyTukeyRadix3Butterfly<S::Type>>::butterfly(
                        &hom,
                        &mut x,
                        &mut y,
                        &mut z,
                        test_zeta,
                        test_twiddle,
                        &twiddle_sqr_zeta_sqr,
                    );
                    // the outputs are evaluations at `t`, `t z` and `t z^2`
                    let mut t = hom.map_ref(test_twiddle);
                    assert_el_eq!(ring, expected(a, b, c, &t), &x);
                    ring.mul_assign(&mut t, hom.map_ref(test_zeta));
                    assert_el_eq!(ring, expected(a, b, c, &t), &y);
                    ring.mul_assign(&mut t, hom.map_ref(test_zeta));
                    assert_el_eq!(ring, expected(a, b, c, &t), &z);

                    prepare_inv(&mut x);
                    prepare_inv(&mut y);
                    prepare_inv(&mut z);
                    <R::Type as CooleyTukeyRadix3Butterfly<S::Type>>::inv_butterfly(
                        &hom,
                        &mut x,
                        &mut y,
                        &mut z,
                        test_zeta,
                        &test_inv_twiddle,
                        &inv_twiddle_sqr,
                    );
                    // the round trip multiplies by the butterfly length 3
                    assert_el_eq!(ring, ring.int_hom().mul_ref_fst_map(a, 3), &x);
                    assert_el_eq!(ring, ring.int_hom().mul_ref_fst_map(b, 3), &y);
                    assert_el_eq!(ring, ring.int_hom().mul_ref_fst_map(c, 3), &z);
                }
            }
        }
    }
}
753
#[test]
fn test_butterfly() {
    let field = Fp::<109>::RING;
    // 63 satisfies `1 + 63 + 63^2 = 0 mod 109`; 97 serves as twiddle factor
    let (zeta, twiddle) = (63, 97);
    generic_tests::test_cooley_tuckey_radix3_butterfly(field, field, field.elements().step_by(10), &zeta, &twiddle);
}