Skip to main content

feanor_math/algorithms/fft/
bluestein.rs

1use std::alloc::{Allocator, Global};
2use std::fmt::Debug;
3
4use crate::algorithms::fft::FFTAlgorithm;
5use crate::algorithms::fft::complex_fft::*;
6use crate::algorithms::fft::cooley_tuckey::CooleyTuckeyFFT;
7use crate::algorithms::unity_root::*;
8use crate::divisibility::{DivisibilityRing, DivisibilityRingStore};
9use crate::homomorphism::*;
10use crate::integer::IntegerRingStore;
11use crate::primitive_int::*;
12use crate::ring::*;
13use crate::rings::float_complex::*;
14use crate::rings::zn::*;
15use crate::seq::SwappableVectorViewMut;
16
17type BaseFFT<R_main, R_twiddle, H, A> = CooleyTuckeyFFT<R_main, R_twiddle, H, A>;
18
/// Bluestein's FFT algorithm (also known as Chirp-Z-transform) to compute the Fourier
/// transform of arbitrary length (including prime numbers).
///
/// Currently, the length `n` is required to be odd; this is checked when the object is
/// created.
///
/// # Convention
///
/// This implementation does not follow the standard convention for the mathematical
/// DFT, since it performs the standard/forward FFT with the inverse root of unity `z^-1`.
/// In other words, the forward FFT computes
/// ```text
///   (a_0, ..., a_(n - 1)) -> (sum_j a_j z^(-ij))_i
/// ```
/// where `z` is the primitive `n`-th root of unity returned by [`FFTAlgorithm::root_of_unity()`].
pub struct BluesteinFFT<R_main, R_twiddle, H, A = Global>
where
    R_main: ?Sized + RingBase,
    R_twiddle: ?Sized + RingBase + DivisibilityRing,
    H: Homomorphism<R_twiddle, R_main> + Clone,
    A: Allocator + Clone,
{
    /// Power-of-two length FFT (of length `m >= 2n`) used to compute the convolution.
    m_fft_table: BaseFFT<R_main, R_twiddle, H, A>,
    /// The chirp `b_i = z^(i^2/2)` for `i < n`, zero-padded to length `m` and then
    /// transformed by `m_fft_table` (output in unordered order).
    b_unordered_fft: Vec<R_twiddle::Element>,
    /// The inverse chirp `z^(-i^2/2)` for `i = 0, ..., n - 1`; used both before and after
    /// the convolution.
    twiddles: Vec<R_twiddle::Element>,
    /// The primitive `n`-th root of unity `z`, mapped into the main ring.
    root_of_unity_n: R_main::Element,
    /// The length of the transform.
    n: usize,
}
44
45impl<H, A> BluesteinFFT<Complex64Base, Complex64Base, H, A>
46where
47    H: Homomorphism<Complex64Base, Complex64Base> + Clone,
48    A: Allocator + Clone,
49{
50    /// Creates an [`BluesteinFFT`] for the complex field, using the given homomorphism
51    /// to connect the ring implementation for twiddles with the main ring implementation.
52    ///
53    /// This function is mainly provided for parity with other rings, since in the complex case
54    /// it currently does not make much sense to use a different homomorphism than the identity.
55    /// Hence, it is simpler to use [`BluesteinFFT::for_complex()`].
56    pub fn for_complex_with_hom(hom: H, n: usize, tmp_mem_allocator: A) -> Self {
57        let ZZ = StaticRing::<i64>::RING;
58        let CC = Complex64::RING;
59        let n_i64: i64 = n.try_into().unwrap();
60        let log2_m = ZZ.abs_log2_ceil(&(2 * n_i64 + 1)).unwrap();
61        Self::new_with_pows_with_hom(
62            hom,
63            |x| CC.root_of_unity(x, 2 * n_i64),
64            |x| CC.root_of_unity(x, 1 << log2_m),
65            n,
66            log2_m,
67            tmp_mem_allocator,
68        )
69    }
70}
71
72impl<R, A> BluesteinFFT<Complex64Base, Complex64Base, Identity<R>, A>
73where
74    R: RingStore<Type = Complex64Base> + Clone,
75    A: Allocator + Clone,
76{
77    /// Creates an [`BluesteinFFT`] for the complex field.
78    pub fn for_complex(ring: R, n: usize, tmp_mem_allocator: A) -> Self {
79        Self::for_complex_with_hom(ring.into_identity(), n, tmp_mem_allocator)
80    }
81}
82
83impl<R, A> BluesteinFFT<R::Type, R::Type, Identity<R>, A>
84where
85    R: RingStore + Clone,
86    R::Type: DivisibilityRing,
87    A: Allocator + Clone,
88{
89    /// Creates an [`BluesteinFFT`] for the given ring, using the given roots of unity.
90    ///
91    /// It is necessary that `root_of_unity_2n` is a primitive `2n`-th root of unity, and
92    /// `root_of_unity_m` is a `2^log2_m`-th root of unity, where `2^log2_m >= 2n`.
93    ///  
94    /// Do not use this for approximate rings, as computing the powers of `root_of_unity`
95    /// will incur avoidable precision loss.
96    pub fn new(
97        ring: R,
98        root_of_unity_2n: El<R>,
99        root_of_unity_m: El<R>,
100        n: usize,
101        log2_m: usize,
102        tmp_mem_allocator: A,
103    ) -> Self {
104        Self::new_with_hom(
105            ring.into_identity(),
106            root_of_unity_2n,
107            root_of_unity_m,
108            n,
109            log2_m,
110            tmp_mem_allocator,
111        )
112    }
113
114    /// Creates an [`BluesteinFFT`] for the given ring, using the passed function to
115    /// provide the necessary roots of unity.
116    ///
117    /// Concretely, `root_of_unity_2n_pows(i)` should return `z^i`, where `z` is a `2n`-th
118    /// primitive root of unity, and `root_of_unity_m_pows(i)` should return `w^i` where `w`
119    /// is a `2^log2_m`-th primitive root of unity, where `2^log2_m > 2n`.
120    pub fn new_with_pows<F, G>(
121        ring: R,
122        root_of_unity_2n_pows: F,
123        root_of_unity_m_pows: G,
124        n: usize,
125        log2_m: usize,
126        tmp_mem_allocator: A,
127    ) -> Self
128    where
129        F: FnMut(i64) -> El<R>,
130        G: FnMut(i64) -> El<R>,
131    {
132        Self::new_with_pows_with_hom(
133            ring.into_identity(),
134            root_of_unity_2n_pows,
135            root_of_unity_m_pows,
136            n,
137            log2_m,
138            tmp_mem_allocator,
139        )
140    }
141
142    /// Creates an [`BluesteinFFT`] for a prime field, assuming it has suitable roots of
143    /// unity.
144    ///
145    /// Concretely, this requires that the characteristic `p` is congruent to 1 modulo
146    /// `2^log2_m n`, where `2^log2_m` is the smallest power of two that is `>= 2n`.
147    pub fn for_zn(ring: R, n: usize, tmp_mem_allocator: A) -> Option<Self>
148    where
149        R::Type: ZnRing,
150    {
151        Self::for_zn_with_hom(ring.into_identity(), n, tmp_mem_allocator)
152    }
153}
154
155impl<R_main, R_twiddle, H, A> BluesteinFFT<R_main, R_twiddle, H, A>
156where
157    R_main: ?Sized + RingBase,
158    R_twiddle: ?Sized + RingBase + DivisibilityRing,
159    H: Homomorphism<R_twiddle, R_main> + Clone,
160    A: Allocator + Clone,
161{
162    /// Creates an [`BluesteinFFT`] for the given ring, using the given roots of unity.
163    ///
164    /// It is necessary that `root_of_unity_2n` is a primitive `2n`-th root of unity, and
165    /// `root_of_unity_m` is a `2^log2_m`-th root of unity, where `2^log2_m >= 2n`.
166    ///
167    /// Instead of a ring, this function takes a homomorphism `R -> S`. Twiddle factors that are
168    /// precomputed will be stored as elements of `R`, while the main FFT computations will be
169    /// performed in `S`. This allows both implicit ring conversions, and using patterns like
170    /// [`zn_64::ZnFastmul`] to precompute some data for better performance.
171    ///
172    /// Do not use this for approximate rings, as computing the powers of `root_of_unity`
173    /// will incur avoidable precision loss.
174    pub fn new_with_hom(
175        hom: H,
176        root_of_unity_2n: R_twiddle::Element,
177        root_of_unity_m: R_twiddle::Element,
178        n: usize,
179        log2_m: usize,
180        tmp_mem_allocator: A,
181    ) -> Self {
182        let hom_copy = hom.clone();
183        let twiddle_ring = hom_copy.domain();
184        return Self::new_with_pows_with_hom(
185            hom,
186            |i: i64| {
187                if i >= 0 {
188                    twiddle_ring.pow(twiddle_ring.clone_el(&root_of_unity_2n), i as usize % (2 * n))
189                } else {
190                    twiddle_ring
191                        .invert(&twiddle_ring.pow(twiddle_ring.clone_el(&root_of_unity_2n), (-i) as usize % (2 * n)))
192                        .unwrap()
193                }
194            },
195            |i: i64| {
196                if i >= 0 {
197                    twiddle_ring.pow(twiddle_ring.clone_el(&root_of_unity_m), i as usize)
198                } else {
199                    twiddle_ring
200                        .invert(&twiddle_ring.pow(twiddle_ring.clone_el(&root_of_unity_m), (-i) as usize))
201                        .unwrap()
202                }
203            },
204            n,
205            log2_m,
206            tmp_mem_allocator,
207        );
208    }
209
210    /// Creates an [`BluesteinFFT`] for the given rings, using the given function to create
211    /// the necessary powers of roots of unity.
212    ///
213    /// Concretely, `root_of_unity_2n_pows(i)` should return `z^i`, where `z` is a `2n`-th
214    /// primitive root of unity, and `root_of_unity_m_pows(i)` should return `w^i` where `w`
215    /// is a `2^log2_m`-th primitive root of unity, where `2^log2_m >= 2n`.
216    ///
217    /// Instead of a ring, this function takes a homomorphism `R -> S`. Twiddle factors that are
218    /// precomputed will be stored as elements of `R`, while the main FFT computations will be
219    /// performed in `S`. This allows both implicit ring conversions, and using patterns like
220    /// [`zn_64::ZnFastmul`] to precompute some data for better performance.
221    pub fn new_with_pows_with_hom<F, G>(
222        hom: H,
223        mut root_of_unity_2n_pows: F,
224        mut root_of_unity_m_pows: G,
225        n: usize,
226        log2_m: usize,
227        tmp_mem_allocator: A,
228    ) -> Self
229    where
230        F: FnMut(i64) -> R_twiddle::Element,
231        G: FnMut(i64) -> R_twiddle::Element,
232    {
233        let m_fft_table = CooleyTuckeyFFT::create(hom, &mut root_of_unity_m_pows, log2_m, tmp_mem_allocator);
234        return Self::create(m_fft_table, |i| root_of_unity_2n_pows(2 * i), n);
235    }
236
237    /// Creates an [`BluesteinFFT`] for the given prime fields, assuming they have suitable
238    /// roots of unity.
239    ///
240    /// Concretely, this requires that the characteristic `p` is congruent to 1 modulo
241    /// `2^log2_m n`, where `2^log2_m` is the smallest power of two that is `>= 2n`.
242    ///
243    /// Instead of a ring, this function takes a homomorphism `R -> S`. Twiddle factors that are
244    /// precomputed will be stored as elements of `R`, while the main FFT computations will be
245    /// performed in `S`. This allows both implicit ring conversions, and using patterns like
246    /// [`zn_64::ZnFastmul`] to precompute some data for better performance.
247    pub fn for_zn_with_hom(hom: H, n: usize, tmp_mem_allocator: A) -> Option<Self>
248    where
249        R_twiddle: ZnRing,
250    {
251        let root_of_unity_2n = get_prim_root_of_unity_zn(hom.domain(), 2 * n)?;
252        let log2_m = StaticRing::<i64>::RING
253            .abs_log2_ceil(&(n * 2).try_into().unwrap())
254            .unwrap();
255        let root_of_unity_m = get_prim_root_of_unity_zn(hom.domain(), 1 << log2_m)?;
256        return Some(Self::new_with_hom(
257            hom,
258            root_of_unity_2n,
259            root_of_unity_m,
260            n,
261            log2_m,
262            tmp_mem_allocator,
263        ));
264    }
265
266    /// Most general way to construct a [`BluesteinFFT`].
267    ///
268    /// This function takes a length-`m` base FFT, where `m >= 2m`, and a function
269    /// `root_of_unity_pows`, on input `i`, should return `z^i` for an `n`-th primitive root of
270    /// unity `z`.
271    #[stability::unstable(feature = "enable")]
272    pub fn create<F>(m_fft_table: BaseFFT<R_main, R_twiddle, H, A>, mut root_of_unity_n_pows: F, n: usize) -> Self
273    where
274        F: FnMut(i64) -> R_twiddle::Element,
275    {
276        let hom = m_fft_table.hom().clone();
277        let m = m_fft_table.len();
278        assert!(m >= 2 * n);
279        assert!(n % 2 == 1);
280        assert!(hom.codomain().is_commutative());
281        assert!(
282            hom.domain().get_ring().is_approximate()
283                || is_prim_root_of_unity(hom.domain(), &root_of_unity_n_pows(1), n)
284        );
285        assert!(
286            hom.codomain().get_ring().is_approximate()
287                || is_prim_root_of_unity(hom.codomain(), &hom.map(root_of_unity_n_pows(1)), n)
288        );
289
290        let (twiddle_fft, old_hom) = m_fft_table.change_ring(hom.domain().identity());
291
292        let half_mod_n = n.div_ceil(2);
293        let mut b: Vec<_> = (0..n)
294            .map(|i| root_of_unity_n_pows(TryInto::<i64>::try_into(i * i * half_mod_n).unwrap()))
295            .collect();
296        b.resize_with(m, || hom.domain().zero());
297
298        twiddle_fft.unordered_fft(&mut b, hom.domain());
299
300        let twiddles = (0..n)
301            .map(|i| root_of_unity_n_pows(-TryInto::<i64>::try_into(i * i * half_mod_n).unwrap()))
302            .collect::<Vec<_>>();
303        let root_of_unity_n = hom.map(root_of_unity_n_pows(1));
304
305        return BluesteinFFT {
306            m_fft_table: twiddle_fft.change_ring(old_hom).0,
307            b_unordered_fft: b,
308            twiddles,
309            root_of_unity_n,
310            n,
311        };
312    }
313
314    /// Computes the FFT of the given values using Bluestein's algorithm, using only the passed
315    /// buffer as temporary storage.
316    ///
317    /// This will not allocate additional memory, as opposed to [`BluesteinFFT::fft()`] etc.
318    ///
319    /// Basically, the idea is to write an FFT of any length (e.g. prime length) as a convolution,
320    /// and compute the convolution efficiently using a power-of-two FFT (e.g. with the Cooley-Tukey
321    /// algorithm).
322    ///
323    /// TODO: At next breaking release, make this private
324    pub fn fft_base<V, W, const INV: bool>(&self, values: V, _buffer: W)
325    where
326        V: SwappableVectorViewMut<R_main::Element>,
327        W: SwappableVectorViewMut<R_main::Element>,
328    {
329        if INV {
330            self.unordered_inv_fft(values, self.ring());
331        } else {
332            self.unordered_fft(values, self.ring());
333        }
334    }
335
336    fn fft_base_impl<V, A2, const INV: bool>(&self, mut values: V, mut buffer: Vec<R_main::Element, A2>)
337    where
338        V: SwappableVectorViewMut<R_main::Element>,
339        A2: Allocator,
340    {
341        assert_eq!(values.len(), self.n);
342        assert_eq!(buffer.len(), self.m_fft_table.len());
343
344        let ring = self.m_fft_table.hom().codomain();
345
346        // set buffer to the zero-padded sequence values_i * z^(-i^2/2)
347        for i in 0..self.n {
348            let value = if INV {
349                values.at((self.n - i) % self.n)
350            } else {
351                values.at(i)
352            };
353            buffer[i] = self.hom().mul_ref_map(value, &self.twiddles[i]);
354        }
355        for i in self.n..self.m_fft_table.len() {
356            buffer[i] = ring.zero();
357        }
358
359        self.m_fft_table.unordered_truncated_fft(&mut buffer, self.n * 2);
360        for i in 0..self.m_fft_table.len() {
361            self.hom().mul_assign_ref_map(&mut buffer[i], &self.b_unordered_fft[i]);
362        }
363        self.m_fft_table.unordered_truncated_fft_inv(&mut buffer, self.n * 2);
364
365        // make the normal convolution into a cyclic convolution of length n by taking it modulo
366        // `x^n - 1`
367        let (buffer1, buffer2) = buffer[..(2 * self.n)].split_at_mut(self.n);
368        for (a, b) in buffer1.iter_mut().zip(buffer2.iter_mut()) {
369            ring.add_assign_ref(a, b);
370        }
371
372        // write values back, and multiply them with a twiddle factor
373        for (i, x) in buffer.into_iter().enumerate().take(self.n) {
374            *values.at_mut(i) = self.hom().mul_ref_snd_map(x, &self.twiddles[i]);
375        }
376
377        if INV {
378            // finally, scale by 1/n
379            let scale = self.hom().map(
380                self.hom()
381                    .domain()
382                    .checked_div(
383                        &self.hom().domain().one(),
384                        &self.hom().domain().int_hom().map(self.n.try_into().unwrap()),
385                    )
386                    .unwrap(),
387            );
388            for i in 0..values.len() {
389                ring.mul_assign_ref(values.at_mut(i), &scale);
390            }
391        }
392    }
393
394    /// Returns a reference to the allocator currently used for temporary allocations by this FFT.
395    #[stability::unstable(feature = "enable")]
396    pub fn allocator(&self) -> &A { self.m_fft_table.allocator() }
397
398    /// Returns the ring over which this object can compute FFTs.
399    #[stability::unstable(feature = "enable")]
400    pub fn ring<'a>(&'a self) -> &'a <H as Homomorphism<R_twiddle, R_main>>::CodomainStore { self.hom().codomain() }
401
402    /// Returns a reference to the homomorphism that is used to map the stored twiddle
403    /// factors into main ring, over which FFTs are computed.
404    #[stability::unstable(feature = "enable")]
405    pub fn hom(&self) -> &H { self.m_fft_table.hom() }
406}
407
408impl<R_main, R_twiddle, H, A> PartialEq for BluesteinFFT<R_main, R_twiddle, H, A>
409where
410    R_main: ?Sized + RingBase,
411    R_twiddle: ?Sized + RingBase + DivisibilityRing,
412    H: Homomorphism<R_twiddle, R_main> + Clone,
413    A: Allocator + Clone,
414{
415    fn eq(&self, other: &Self) -> bool {
416        self.ring().get_ring() == other.ring().get_ring()
417            && self.n == other.n
418            && self
419                .ring()
420                .eq_el(self.root_of_unity(self.ring()), other.root_of_unity(self.ring()))
421    }
422}
423
424impl<R_main, R_twiddle, H, A> Debug for BluesteinFFT<R_main, R_twiddle, H, A>
425where
426    R_main: ?Sized + RingBase + Debug,
427    R_twiddle: ?Sized + RingBase + DivisibilityRing,
428    H: Homomorphism<R_twiddle, R_main> + Clone,
429    A: Allocator + Clone,
430{
431    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
432        f.debug_struct("BluesteinFFT")
433            .field("ring", &self.ring().get_ring())
434            .field("n", &self.n)
435            .field("root_of_unity_n", &self.ring().format(self.root_of_unity(self.ring())))
436            .finish()
437    }
438}
439
440impl<R_main, R_twiddle, H, A> FFTAlgorithm<R_main> for BluesteinFFT<R_main, R_twiddle, H, A>
441where
442    R_main: ?Sized + RingBase,
443    R_twiddle: ?Sized + RingBase + DivisibilityRing,
444    H: Homomorphism<R_twiddle, R_main> + Clone,
445    A: Allocator + Clone,
446{
447    fn len(&self) -> usize { self.n }
448
449    fn root_of_unity<S: RingStore<Type = R_main> + Copy>(&self, ring: S) -> &R_main::Element {
450        assert!(self.ring().get_ring() == ring.get_ring(), "unsupported ring");
451        &self.root_of_unity_n
452    }
453
454    fn unordered_fft_permutation(&self, i: usize) -> usize { i }
455
456    fn unordered_fft_permutation_inv(&self, i: usize) -> usize { i }
457
458    fn unordered_fft<V, S>(&self, values: V, ring: S)
459    where
460        V: SwappableVectorViewMut<<R_main as RingBase>::Element>,
461        S: RingStore<Type = R_main> + Copy,
462    {
463        assert!(ring.get_ring() == self.ring().get_ring(), "unsupported ring");
464        let mut buffer = Vec::with_capacity_in(self.m_fft_table.len(), self.allocator().clone());
465        buffer.extend((0..self.m_fft_table.len()).map(|_| self.ring().zero()));
466        self.fft_base_impl::<_, _, false>(values, buffer);
467    }
468
469    fn unordered_inv_fft<V, S>(&self, values: V, ring: S)
470    where
471        V: SwappableVectorViewMut<<R_main as RingBase>::Element>,
472        S: RingStore<Type = R_main> + Copy,
473    {
474        assert!(ring.get_ring() == self.ring().get_ring(), "unsupported ring");
475        let mut buffer = Vec::with_capacity_in(self.m_fft_table.len(), self.allocator().clone());
476        buffer.extend((0..self.m_fft_table.len()).map(|_| self.ring().zero()));
477        self.fft_base_impl::<_, _, true>(values, buffer);
478    }
479
480    fn fft<V, S>(&self, values: V, ring: S)
481    where
482        V: SwappableVectorViewMut<<R_main as RingBase>::Element>,
483        S: RingStore<Type = R_main> + Copy,
484    {
485        self.unordered_fft(values, ring);
486    }
487
488    fn inv_fft<V, S>(&self, values: V, ring: S)
489    where
490        V: SwappableVectorViewMut<<R_main as RingBase>::Element>,
491        S: RingStore<Type = R_main> + Copy,
492    {
493        self.unordered_inv_fft(values, ring);
494    }
495}
496
497impl<H, A> FFTErrorEstimate for BluesteinFFT<Complex64Base, Complex64Base, H, A>
498where
499    H: Homomorphism<Complex64Base, Complex64Base> + Clone,
500    A: Allocator + Clone,
501{
502    fn expected_absolute_error(&self, input_bound: f64, input_error: f64) -> f64 {
503        let error_after_twiddling = input_error + input_bound * (root_of_unity_error() + f64::EPSILON);
504        let error_after_fft = self
505            .m_fft_table
506            .expected_absolute_error(input_bound, error_after_twiddling);
507        let b_bitreverse_fft_error = self.m_fft_table.expected_absolute_error(1.0, root_of_unity_error());
508        // now the values are increased by up to a factor of m, so use `input_bound * m` instead
509        let new_input_bound = input_bound * self.m_fft_table.len() as f64;
510        let b_bitreverse_fft_bound = self.m_fft_table.len() as f64;
511        let error_after_mul = new_input_bound * b_bitreverse_fft_error
512            + b_bitreverse_fft_bound * error_after_fft
513            + f64::EPSILON * new_input_bound * b_bitreverse_fft_bound;
514        let error_after_inv_fft = self
515            .m_fft_table
516            .expected_absolute_error(new_input_bound * b_bitreverse_fft_bound, error_after_mul)
517            / self.m_fft_table.len() as f64;
518        let error_end = error_after_inv_fft + new_input_bound * (root_of_unity_error() + f64::EPSILON);
519        return error_end;
520    }
521}
522
523#[cfg(test)]
524use crate::rings::zn::zn_static::*;
525
526#[test]
527fn test_fft_base() {
528    let ring = Zn::<241>::RING;
529    // a 5-th root of unity is 91
530    let fft = BluesteinFFT::new(ring, ring.int_hom().map(36), ring.int_hom().map(111), 5, 4, Global);
531    let mut values = [1, 3, 2, 0, 7];
532    let mut buffer = [0; 16];
533    fft.fft_base::<_, _, false>(&mut values, &mut buffer);
534    let expected = [13, 137, 202, 206, 170];
535    assert_eq!(expected, values);
536}
537
538#[test]
539fn test_fft_fastmul() {
540    let ring = zn_64::Zn::new(241);
541    let fastmul_ring = zn_64::ZnFastmul::new(ring).unwrap();
542    let fft = BluesteinFFT::new_with_hom(
543        ring.can_hom(&fastmul_ring).unwrap(),
544        fastmul_ring.int_hom().map(36),
545        fastmul_ring.int_hom().map(111),
546        5,
547        4,
548        Global,
549    );
550    let mut values: [_; 5] = std::array::from_fn(|i| ring.int_hom().map([1, 3, 2, 0, 7][i]));
551    fft.fft(&mut values, ring);
552    let expected: [_; 5] = std::array::from_fn(|i| ring.int_hom().map([13, 137, 202, 206, 170][i]));
553    for i in 0..values.len() {
554        assert_el_eq!(ring, expected[i], values[i]);
555    }
556}
557
558#[test]
559fn test_inv_fft_base() {
560    let ring = Zn::<241>::RING;
561    // a 5-th root of unity is 91
562    let fft = BluesteinFFT::new(ring, ring.int_hom().map(36), ring.int_hom().map(111), 5, 4, Global);
563    let values = [1, 3, 2, 0, 7];
564    let mut work = values;
565    let mut buffer = [0; 16];
566    fft.fft_base::<_, _, false>(&mut work, &mut buffer);
567    fft.fft_base::<_, _, true>(&mut work, &mut buffer);
568    assert_eq!(values, work);
569}
570
571#[test]
572fn test_approximate_fft() {
573    let CC = Complex64::RING;
574    for (p, _log2_m) in [(5, 4), (53, 7), (1009, 11)] {
575        let fft = BluesteinFFT::for_complex(&CC, p, Global);
576        let mut array = (0..p)
577            .map(|i| CC.root_of_unity(i.try_into().unwrap(), p.try_into().unwrap()))
578            .collect::<Vec<_>>();
579        fft.fft(&mut array, CC);
580        let err = fft.expected_absolute_error(1.0, 0.0);
581        assert!(CC.is_absolute_approx_eq(array[0], CC.zero(), err));
582        assert!(CC.is_absolute_approx_eq(array[1], CC.from_f64(fft.len() as f64), err));
583        for i in 2..fft.len() {
584            assert!(CC.is_absolute_approx_eq(array[i], CC.zero(), err));
585        }
586    }
587}
588
/// Length of the FFT used in the benchmark; `1009` is prime, which is exactly the case
/// Bluestein's algorithm is meant for.
#[cfg(test)]
const BENCH_SIZE: usize = 1_009;
591
592#[bench]
593fn bench_bluestein(bencher: &mut test::Bencher) {
594    let ring = zn_64::Zn::new(18597889);
595    let fastmul_ring = zn_64::ZnFastmul::new(ring).unwrap();
596    let embedding = ring.can_hom(&fastmul_ring).unwrap();
597    let root_of_unity = fastmul_ring.coerce(&ring, get_prim_root_of_unity_zn(&ring, 2 * BENCH_SIZE).unwrap());
598    let fft = BluesteinFFT::new_with_hom(
599        embedding.clone(),
600        root_of_unity,
601        get_prim_root_of_unity_zn(&fastmul_ring, 1 << 11).unwrap(),
602        BENCH_SIZE,
603        11,
604        Global,
605    );
606    let data = (0..BENCH_SIZE)
607        .map(|i| ring.int_hom().map(i as i32))
608        .collect::<Vec<_>>();
609    let mut copy = Vec::with_capacity(BENCH_SIZE);
610    bencher.iter(|| {
611        copy.clear();
612        copy.extend(data.iter().map(|x| ring.clone_el(x)));
613        fft.unordered_fft(std::hint::black_box(&mut copy[..]), &ring);
614        fft.unordered_inv_fft(std::hint::black_box(&mut copy[..]), &ring);
615        assert_el_eq!(ring, copy[0], data[0]);
616    });
617}