// feanor_math/algorithms/fft/bluestein.rs

1use std::alloc::Allocator;
2use std::alloc::Global;
3use std::fmt::Debug;
4
5use crate::algorithms::fft::FFTAlgorithm;
6use crate::algorithms::unity_root::*;
7use crate::divisibility::{DivisibilityRing, DivisibilityRingStore};
8use crate::algorithms::fft::cooley_tuckey::CooleyTuckeyFFT;
9use crate::integer::IntegerRingStore;
10use crate::primitive_int::*;
11use crate::ring::*;
12use crate::homomorphism::*;
13use crate::rings::zn::*;
14use crate::rings::float_complex::*;
15use crate::algorithms::fft::complex_fft::*;
16use crate::seq::SwappableVectorViewMut;
17
18type BaseFFT<R_main, R_twiddle, H, A> = CooleyTuckeyFFT<R_main, R_twiddle, H, A>;
19
20///
21/// Bluestein's FFT algorithm (also known as Chirp-Z-transform) to compute the Fourier
22/// transform of arbitrary length (including prime numbers).
23/// 
24/// # Convention
25/// 
26/// This implementation does not follows the standard convention for the mathematical
27/// DFT, by performing the standard/forward FFT with the inverse root of unity `z^-1`.
28/// In other words, the forward FFT computes
29/// ```text
30///   (a_0, ..., a_(N - 1)) -> (sum_j a_j z^(-ij))_i
31/// ```
32/// where `z` is the primitive `n`-th root of unity returned by [`FFTAlgorithm::root_of_unity()`].
33/// 
34pub struct BluesteinFFT<R_main, R_twiddle, H, A = Global>
35    where R_main: ?Sized + RingBase,
36        R_twiddle: ?Sized + RingBase + DivisibilityRing,
37        H: Homomorphism<R_twiddle, R_main> + Clone, 
38        A: Allocator + Clone
39{
40    m_fft_table: BaseFFT<R_main, R_twiddle, H, A>,
41    b_unordered_fft: Vec<R_twiddle::Element>,
42    twiddles: Vec<R_twiddle::Element>,
43    root_of_unity_n: R_main::Element,
44    n: usize
45}
46
47impl<H, A> BluesteinFFT<Complex64Base, Complex64Base, H, A>
48    where H: Homomorphism<Complex64Base, Complex64Base> + Clone, 
49        A: Allocator + Clone
50{
51    ///
52    /// Creates an [`BluesteinFFT`] for the complex field, using the given homomorphism
53    /// to connect the ring implementation for twiddles with the main ring implementation.
54    /// 
55    /// This function is mainly provided for parity with other rings, since in the complex case
56    /// it currently does not make much sense to use a different homomorphism than the identity.
57    /// Hence, it is simpler to use [`BluesteinFFT::for_complex()`].
58    /// 
59    pub fn for_complex_with_hom(hom: H, n: usize, tmp_mem_allocator: A) -> Self{
60        let ZZ = StaticRing::<i64>::RING;
61        let CC = Complex64::RING;
62        let n_i64: i64 = n.try_into().unwrap();
63        let log2_m = ZZ.abs_log2_ceil(&(2 * n_i64 + 1)).unwrap();
64        Self::new_with_pows_with_hom(hom, |x| CC.root_of_unity(x, 2 * n_i64), |x| CC.root_of_unity(x, 1 << log2_m), n, log2_m, tmp_mem_allocator)
65    }
66}
67
68impl<R, A> BluesteinFFT<Complex64Base, Complex64Base, Identity<R>, A>
69    where R: RingStore<Type = Complex64Base> + Clone, 
70        A: Allocator + Clone
71{
72    ///
73    /// Creates an [`BluesteinFFT`] for the complex field.
74    /// 
75    pub fn for_complex(ring: R, n: usize, tmp_mem_allocator: A) -> Self{
76        Self::for_complex_with_hom(ring.into_identity(), n, tmp_mem_allocator)
77    }
78}
79
80impl<R, A> BluesteinFFT<R::Type, R::Type, Identity<R>, A>
81    where R: RingStore + Clone,
82        R::Type: DivisibilityRing,
83        A: Allocator + Clone
84{
85    ///
86    /// Creates an [`BluesteinFFT`] for the given ring, using the given roots of unity.
87    /// 
88    /// It is necessary that `root_of_unity_2n` is a primitive `2n`-th root of unity, and
89    /// `root_of_unity_m` is a `2^log2_m`-th root of unity, where `2^log2_m >= 2n`.
90    ///  
91    /// Do not use this for approximate rings, as computing the powers of `root_of_unity`
92    /// will incur avoidable precision loss.
93    /// 
94    pub fn new(ring: R, root_of_unity_2n: El<R>, root_of_unity_m: El<R>, n: usize, log2_m: usize, tmp_mem_allocator: A) -> Self {
95        Self::new_with_hom(ring.into_identity(), root_of_unity_2n, root_of_unity_m, n, log2_m, tmp_mem_allocator)
96    }
97
98    ///
99    /// Creates an [`BluesteinFFT`] for the given ring, using the passed function to
100    /// provide the necessary roots of unity.
101    /// 
102    /// Concretely, `root_of_unity_2n_pows(i)` should return `z^i`, where `z` is a `2n`-th
103    /// primitive root of unity, and `root_of_unity_m_pows(i)` should return `w^i` where `w`
104    /// is a `2^log2_m`-th primitive root of unity, where `2^log2_m > 2n`.
105    /// 
106    pub fn new_with_pows<F, G>(ring: R, root_of_unity_2n_pows: F, root_of_unity_m_pows: G, n: usize, log2_m: usize, tmp_mem_allocator: A) -> Self 
107        where F: FnMut(i64) -> El<R>,
108            G: FnMut(i64) -> El<R>
109    {
110        Self::new_with_pows_with_hom(ring.into_identity(), root_of_unity_2n_pows, root_of_unity_m_pows, n, log2_m, tmp_mem_allocator)
111    }
112
113    ///
114    /// Creates an [`BluesteinFFT`] for a prime field, assuming it has suitable roots of
115    /// unity.
116    /// 
117    /// Concretely, this requires that the characteristic `p` is congruent to 1 modulo
118    /// `2^log2_m n`, where `2^log2_m` is the smallest power of two that is `>= 2n`.
119    /// 
120    pub fn for_zn(ring: R, n: usize, tmp_mem_allocator: A) -> Option<Self>
121        where R::Type: ZnRing
122    {
123        Self::for_zn_with_hom(ring.into_identity(), n, tmp_mem_allocator)
124    }
125}
126
127impl<R_main, R_twiddle, H, A> BluesteinFFT<R_main, R_twiddle, H, A>
128    where R_main: ?Sized + RingBase,
129        R_twiddle: ?Sized + RingBase + DivisibilityRing,
130        H: Homomorphism<R_twiddle, R_main> + Clone, 
131        A: Allocator + Clone
132{
133    ///
134    /// Creates an [`BluesteinFFT`] for the given ring, using the given roots of unity.
135    /// 
136    /// It is necessary that `root_of_unity_2n` is a primitive `2n`-th root of unity, and
137    /// `root_of_unity_m` is a `2^log2_m`-th root of unity, where `2^log2_m >= 2n`.
138    /// 
139    /// Instead of a ring, this function takes a homomorphism `R -> S`. Twiddle factors that are
140    /// precomputed will be stored as elements of `R`, while the main FFT computations will be 
141    /// performed in `S`. This allows both implicit ring conversions, and using patterns like 
142    /// [`zn_64::ZnFastmul`] to precompute some data for better performance.
143    /// 
144    /// Do not use this for approximate rings, as computing the powers of `root_of_unity`
145    /// will incur avoidable precision loss.
146    /// 
147    pub fn new_with_hom(hom: H, root_of_unity_2n: R_twiddle::Element, root_of_unity_m: R_twiddle::Element, n: usize, log2_m: usize, tmp_mem_allocator: A) -> Self {
148        let hom_copy = hom.clone();
149        let twiddle_ring = hom_copy.domain();
150        return Self::new_with_pows_with_hom(
151            hom, 
152            |i: i64| if i >= 0 {
153                twiddle_ring.pow(twiddle_ring.clone_el(&root_of_unity_2n), i as usize % (2 * n))
154            } else {
155                twiddle_ring.invert(&twiddle_ring.pow(twiddle_ring.clone_el(&root_of_unity_2n), (-i) as usize % (2 * n))).unwrap()
156            }, 
157            |i: i64| if i >= 0 {
158                twiddle_ring.pow(twiddle_ring.clone_el(&root_of_unity_m), i as usize)
159            } else {
160                twiddle_ring.invert(&twiddle_ring.pow(twiddle_ring.clone_el(&root_of_unity_m), (-i) as usize)).unwrap()
161            }, 
162            n, 
163            log2_m, 
164            tmp_mem_allocator
165        );
166    }
167
168    ///
169    /// Creates an [`BluesteinFFT`] for the given rings, using the given function to create
170    /// the necessary powers of roots of unity.
171    /// 
172    /// Concretely, `root_of_unity_2n_pows(i)` should return `z^i`, where `z` is a `2n`-th
173    /// primitive root of unity, and `root_of_unity_m_pows(i)` should return `w^i` where `w`
174    /// is a `2^log2_m`-th primitive root of unity, where `2^log2_m >= 2n`.
175    /// 
176    /// Instead of a ring, this function takes a homomorphism `R -> S`. Twiddle factors that are
177    /// precomputed will be stored as elements of `R`, while the main FFT computations will be 
178    /// performed in `S`. This allows both implicit ring conversions, and using patterns like 
179    /// [`zn_64::ZnFastmul`] to precompute some data for better performance.
180    /// 
181    pub fn new_with_pows_with_hom<F, G>(hom: H, mut root_of_unity_2n_pows: F, mut root_of_unity_m_pows: G, n: usize, log2_m: usize, tmp_mem_allocator: A) -> Self
182        where F: FnMut(i64) -> R_twiddle::Element,
183            G: FnMut(i64) -> R_twiddle::Element
184    {
185        let m_fft_table = CooleyTuckeyFFT::create(
186            hom, 
187            &mut root_of_unity_m_pows, 
188            log2_m, 
189            tmp_mem_allocator
190        );
191        return Self::create(m_fft_table, |i| root_of_unity_2n_pows(2 * i), n);
192    }
193
194    ///
195    /// Creates an [`BluesteinFFT`] for the given prime fields, assuming they have suitable
196    /// roots of unity.
197    /// 
198    /// Concretely, this requires that the characteristic `p` is congruent to 1 modulo
199    /// `2^log2_m n`, where `2^log2_m` is the smallest power of two that is `>= 2n`.
200    /// 
201    /// Instead of a ring, this function takes a homomorphism `R -> S`. Twiddle factors that are
202    /// precomputed will be stored as elements of `R`, while the main FFT computations will be 
203    /// performed in `S`. This allows both implicit ring conversions, and using patterns like 
204    /// [`zn_64::ZnFastmul`] to precompute some data for better performance.
205    /// 
206    pub fn for_zn_with_hom(hom: H, n: usize, tmp_mem_allocator: A) -> Option<Self>
207        where R_twiddle: ZnRing
208    {
209        let ring_as_field = hom.domain().as_field().ok().unwrap();
210        let root_of_unity_2n = ring_as_field.get_ring().unwrap_element(get_prim_root_of_unity(&ring_as_field, 2 * n)?);
211        let log2_m = StaticRing::<i64>::RING.abs_log2_ceil(&(n * 2).try_into().unwrap()).unwrap();
212        let root_of_unity_m = ring_as_field.get_ring().unwrap_element(get_prim_root_of_unity_pow2(&ring_as_field, log2_m)?);
213        return Some(Self::new_with_hom(hom, root_of_unity_2n, root_of_unity_m, n, log2_m, tmp_mem_allocator));
214    }
215    
216    ///
217    /// Most general way to construct a [`BluesteinFFT`].
218    /// 
219    /// This function takes a length-`m` base FFT, where `m >= 2m`, and a function `root_of_unity_pows`,
220    /// on input `i`, should return `z^i` for an `n`-th primitive root of unity `z`.
221    /// 
222    #[stability::unstable(feature = "enable")]
223    pub fn create<F>(m_fft_table: BaseFFT<R_main, R_twiddle, H, A>, mut root_of_unity_n_pows: F, n: usize) -> Self
224        where F: FnMut(i64) -> R_twiddle::Element
225    {
226        let hom = m_fft_table.hom().clone();
227        let m = m_fft_table.len();
228        assert!(m >= 2 * n);
229        assert!(n % 2 == 1);
230        assert!(hom.codomain().is_commutative());
231        assert!(hom.domain().get_ring().is_approximate() || is_prim_root_of_unity(hom.domain(), &root_of_unity_n_pows(1), n));
232        assert!(hom.codomain().get_ring().is_approximate() || is_prim_root_of_unity(hom.codomain(), &hom.map(root_of_unity_n_pows(1)), n));
233
234        let (twiddle_fft, old_hom) = m_fft_table.change_ring(hom.domain().identity());
235
236        let half_mod_n = (n + 1) / 2;
237        let mut b: Vec<_> = (0..n).map(|i| root_of_unity_n_pows(TryInto::<i64>::try_into(i * i * half_mod_n).unwrap())).collect();
238        b.resize_with(m, || hom.domain().zero());
239        
240        twiddle_fft.unordered_fft(&mut b, hom.domain());
241
242        let twiddles = (0..n).map(|i| root_of_unity_n_pows(-TryInto::<i64>::try_into(i * i * half_mod_n).unwrap())).collect::<Vec<_>>();
243        let root_of_unity_n = hom.map(root_of_unity_n_pows(1));
244
245        return BluesteinFFT { 
246            m_fft_table: twiddle_fft.change_ring(old_hom).0, 
247            b_unordered_fft: b, 
248            twiddles: twiddles, 
249            root_of_unity_n: root_of_unity_n,
250            n: n
251        };
252    }
253
254    ///
255    /// Computes the FFT of the given values using Bluestein's algorithm, using only the passed
256    /// buffer as temporary storage.
257    /// 
258    /// This will not allocate additional memory, as opposed to [`BluesteinFFT::fft()`] etc.
259    /// 
260    /// Basically, the idea is to write an FFT of any length (e.g. prime length) as a convolution,
261    /// and compute the convolution efficiently using a power-of-two FFT (e.g. with the Cooley-Tukey 
262    /// algorithm).
263    /// 
264    /// TODO: At next breaking release, make this private
265    /// 
266    pub fn fft_base<V, W, const INV: bool>(&self, values: V, _buffer: W)
267        where V: SwappableVectorViewMut<R_main::Element>, 
268            W: SwappableVectorViewMut<R_main::Element>
269    {
270        if INV {
271            self.unordered_inv_fft(values, self.ring());
272        } else {
273            self.unordered_fft(values, self.ring());
274        }
275    }
276
277    fn fft_base_impl<V, A2, const INV: bool>(&self, mut values: V, mut buffer: Vec<R_main::Element, A2>)
278        where V: SwappableVectorViewMut<R_main::Element>,
279            A2: Allocator
280    {
281        assert_eq!(values.len(), self.n);
282        assert_eq!(buffer.len(), self.m_fft_table.len());
283
284        let ring = self.m_fft_table.hom().codomain();
285
286        // set buffer to the zero-padded sequence values_i * z^(-i^2/2)
287        for i in 0..self.n {
288            let value = if INV {
289                values.at((self.n - i) % self.n)
290            } else {
291                values.at(i)
292            };
293            buffer[i] = self.hom().mul_ref_map(value, &self.twiddles[i]);
294        }
295        for i in self.n..self.m_fft_table.len() {
296            buffer[i] = ring.zero();
297        }
298 
299        self.m_fft_table.unordered_truncated_fft(&mut buffer, self.n * 2);
300        for i in 0..self.m_fft_table.len() {
301            self.hom().mul_assign_ref_map(&mut buffer[i], &self.b_unordered_fft[i]);
302        }
303        self.m_fft_table.unordered_truncated_fft_inv(&mut buffer, self.n * 2);
304        
305        // make the normal convolution into a cyclic convolution of length n by taking it modulo `x^n - 1`
306        let (buffer1, buffer2) = buffer[..(2 * self.n)].split_at_mut(self.n);
307        for (a, b) in buffer1.iter_mut().zip(buffer2.iter_mut()) {
308            ring.add_assign_ref(a, b);
309        }
310
311        // write values back, and multiply them with a twiddle factor
312        for (i, x) in buffer.into_iter().enumerate().take(self.n) {
313            *values.at_mut(i) = self.hom().mul_ref_snd_map(x, &self.twiddles[i]);
314        }
315
316        if INV {
317            // finally, scale by 1/n
318            let scale = self.hom().map(self.hom().domain().checked_div(&self.hom().domain().one(), &self.hom().domain().int_hom().map(self.n.try_into().unwrap())).unwrap());
319            for i in 0..values.len() {
320                ring.mul_assign_ref(values.at_mut(i), &scale);
321            }
322        }
323    }
324
325    ///
326    /// Returns a reference to the allocator currently used for temporary allocations by this FFT.
327    /// 
328    #[stability::unstable(feature = "enable")]
329    pub fn allocator(&self) -> &A {
330        self.m_fft_table.allocator()
331    }
332    
333    ///
334    /// Returns the ring over which this object can compute FFTs.
335    /// 
336    #[stability::unstable(feature = "enable")]
337    pub fn ring<'a>(&'a self) -> &'a <H as Homomorphism<R_twiddle, R_main>>::CodomainStore {
338        self.hom().codomain()
339    }
340
341    ///
342    /// Returns a reference to the homomorphism that is used to map the stored twiddle
343    /// factors into main ring, over which FFTs are computed.
344    /// 
345    #[stability::unstable(feature = "enable")]
346    pub fn hom(&self) -> &H {
347        self.m_fft_table.hom()
348    }
349}
350
351impl<R_main, R_twiddle, H, A> PartialEq for BluesteinFFT<R_main, R_twiddle, H, A>
352    where R_main: ?Sized + RingBase,
353        R_twiddle: ?Sized + RingBase + DivisibilityRing,
354        H: Homomorphism<R_twiddle, R_main> + Clone, 
355        A: Allocator + Clone
356{
357    fn eq(&self, other: &Self) -> bool {
358        self.ring().get_ring() == other.ring().get_ring() &&
359            self.n == other.n &&
360            self.ring().eq_el(self.root_of_unity(self.ring()), other.root_of_unity(self.ring()))
361    }
362}
363
364impl<R_main, R_twiddle, H, A> Debug for BluesteinFFT<R_main, R_twiddle, H, A>
365    where R_main: ?Sized + RingBase + Debug,
366        R_twiddle: ?Sized + RingBase + DivisibilityRing,
367        H: Homomorphism<R_twiddle, R_main> + Clone, 
368        A: Allocator + Clone
369{
370    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
371        f.debug_struct("BluesteinFFT")
372            .field("ring", &self.ring().get_ring())
373            .field("n", &self.n)
374            .field("root_of_unity_n", &self.ring().format(&self.root_of_unity(self.ring())))
375            .finish()
376    }
377}
378
379impl<R_main, R_twiddle, H, A> FFTAlgorithm<R_main> for BluesteinFFT<R_main, R_twiddle, H, A>
380    where R_main: ?Sized + RingBase,
381        R_twiddle: ?Sized + RingBase + DivisibilityRing,
382        H: Homomorphism<R_twiddle, R_main> + Clone, 
383        A: Allocator + Clone
384{
385    fn len(&self) -> usize {
386        self.n
387    }
388
389    fn root_of_unity<S: RingStore<Type = R_main> + Copy>(&self, ring: S) -> &R_main::Element {
390        assert!(self.ring().get_ring() == ring.get_ring(), "unsupported ring");
391        &self.root_of_unity_n
392    }
393
394    fn unordered_fft_permutation(&self, i: usize) -> usize {
395        i
396    }
397
398    fn unordered_fft_permutation_inv(&self, i: usize) -> usize {
399        i
400    }
401
402    fn unordered_fft<V, S>(&self, values: V, ring: S)
403        where V: SwappableVectorViewMut<<R_main as RingBase>::Element>,
404            S: RingStore<Type = R_main> + Copy 
405    {
406        assert!(ring.get_ring() == self.ring().get_ring(), "unsupported ring");
407        let mut buffer = Vec::with_capacity_in(self.m_fft_table.len(), self.allocator().clone());
408        buffer.extend((0..self.m_fft_table.len()).map(|_| self.ring().zero()));
409        self.fft_base_impl::<_, _, false>(values, buffer);
410    }
411
412    fn unordered_inv_fft<V, S>(&self, values: V, ring: S)
413        where V: SwappableVectorViewMut<<R_main as RingBase>::Element>,
414            S: RingStore<Type = R_main> + Copy 
415    {
416        assert!(ring.get_ring() == self.ring().get_ring(), "unsupported ring");
417        let mut buffer = Vec::with_capacity_in(self.m_fft_table.len(), self.allocator().clone());
418        buffer.extend((0..self.m_fft_table.len()).map(|_| self.ring().zero()));
419        self.fft_base_impl::<_, _, true>(values, buffer);
420    }
421
422    fn fft<V, S>(&self, values: V, ring: S)
423        where V: SwappableVectorViewMut<<R_main as RingBase>::Element>,
424            S: RingStore<Type = R_main> + Copy 
425    {
426        self.unordered_fft(values, ring);
427    }
428
429    fn inv_fft<V, S>(&self, values: V, ring: S)
430        where V: SwappableVectorViewMut<<R_main as RingBase>::Element>,
431            S: RingStore<Type = R_main> + Copy 
432    {
433        self.unordered_inv_fft(values, ring);
434    }
435}
436
437impl<H, A> FFTErrorEstimate for BluesteinFFT<Complex64Base, Complex64Base, H, A>
438    where H: Homomorphism<Complex64Base, Complex64Base> + Clone, 
439        A: Allocator + Clone
440{
441    fn expected_absolute_error(&self, input_bound: f64, input_error: f64) -> f64 {
442        let error_after_twiddling = input_error + input_bound * (root_of_unity_error() + f64::EPSILON);
443        let error_after_fft = self.m_fft_table.expected_absolute_error(input_bound, error_after_twiddling);
444        let b_bitreverse_fft_error = self.m_fft_table.expected_absolute_error(1., root_of_unity_error());
445        // now the values are increased by up to a factor of m, so use `input_bound * m` instead
446        let new_input_bound = input_bound * self.m_fft_table.len() as f64;
447        let b_bitreverse_fft_bound = self.m_fft_table.len() as f64;
448        let error_after_mul = new_input_bound * b_bitreverse_fft_error + b_bitreverse_fft_bound * error_after_fft + f64::EPSILON * new_input_bound * b_bitreverse_fft_bound;
449        let error_after_inv_fft = self.m_fft_table.expected_absolute_error(new_input_bound * b_bitreverse_fft_bound, error_after_mul) / self.m_fft_table.len() as f64;
450        let error_end = error_after_inv_fft + new_input_bound * (root_of_unity_error() + f64::EPSILON);
451        return error_end;
452    }
453}
454
455#[cfg(test)]
456use crate::rings::zn::zn_static::*;
457
458#[test]
459fn test_fft_base() {
460    let ring = Zn::<241>::RING;
461    // a 5-th root of unity is 91 
462    let fft = BluesteinFFT::new(ring, ring.int_hom().map(36), ring.int_hom().map(111), 5, 4, Global);
463    let mut values = [1, 3, 2, 0, 7];
464    let mut buffer = [0; 16];
465    fft.fft_base::<_, _, false>(&mut values, &mut buffer);
466    let expected = [13, 137, 202, 206, 170];
467    assert_eq!(expected, values);
468}
469
470#[test]
471fn test_fft_fastmul() {
472    let ring = zn_64::Zn::new(241);
473    let fastmul_ring = zn_64::ZnFastmul::new(ring).unwrap();
474    let fft = BluesteinFFT::new_with_hom(ring.can_hom(&fastmul_ring).unwrap(), fastmul_ring.int_hom().map(36), fastmul_ring.int_hom().map(111), 5, 4, Global);
475    let mut values: [_; 5] = std::array::from_fn(|i| ring.int_hom().map([1, 3, 2, 0, 7][i]));
476    fft.fft(&mut values, ring);
477    let expected: [_; 5] = std::array::from_fn(|i| ring.int_hom().map([13, 137, 202, 206, 170][i]));
478    for i in 0..values.len() {
479        assert_el_eq!(ring, expected[i], values[i]);
480    }
481}
482
483#[test]
484fn test_inv_fft_base() {
485    let ring = Zn::<241>::RING;
486    // a 5-th root of unity is 91 
487    let fft = BluesteinFFT::new(ring, ring.int_hom().map(36), ring.int_hom().map(111), 5, 4, Global);
488    let values = [1, 3, 2, 0, 7];
489    let mut work = values;
490    let mut buffer = [0; 16];
491    fft.fft_base::<_, _, false>(&mut work, &mut buffer);
492    fft.fft_base::<_, _, true>(&mut work, &mut buffer);
493    assert_eq!(values, work);
494}
495
496#[test]
497fn test_approximate_fft() {
498    let CC = Complex64::RING;
499    for (p, _log2_m) in [(5, 4), (53, 7), (1009, 11)] {
500        let fft = BluesteinFFT::for_complex(&CC, p, Global);
501        let mut array = (0..p).map(|i| CC.root_of_unity(i.try_into().unwrap(), p.try_into().unwrap())).collect::<Vec<_>>();
502        fft.fft(&mut array, CC);
503        let err = fft.expected_absolute_error(1., 0.);
504        assert!(CC.is_absolute_approx_eq(array[0], CC.zero(), err));
505        assert!(CC.is_absolute_approx_eq(array[1], CC.from_f64(fft.len() as f64), err));
506        for i in 2..fft.len() {
507            assert!(CC.is_absolute_approx_eq(array[i], CC.zero(), err));
508        }
509    }
510}
511
// Transform length used by `bench_bluestein`; 1009 is prime, i.e. a length that a
// plain power-of-two FFT could not handle.
#[cfg(test)]
const BENCH_SIZE: usize = 1009;
514
515#[bench]
516fn bench_bluestein(bencher: &mut test::Bencher) {
517    let ring = zn_64::Zn::new(18597889);
518    let fastmul_ring = zn_64::ZnFastmul::new(ring).unwrap();
519    let embedding = ring.can_hom(&fastmul_ring).unwrap();
520    let ring_as_field = ring.as_field().ok().unwrap();
521    let root_of_unity = fastmul_ring.coerce(&ring, ring_as_field.get_ring().unwrap_element(get_prim_root_of_unity(&ring_as_field, 2 * BENCH_SIZE).unwrap()));
522    let fastmul_ring_as_field = fastmul_ring.as_field().ok().unwrap();
523    let fft = BluesteinFFT::new_with_hom(
524        embedding.clone(), 
525        root_of_unity, 
526        fastmul_ring_as_field.get_ring().unwrap_element(get_prim_root_of_unity_pow2(&fastmul_ring_as_field, 11).unwrap()), 
527        BENCH_SIZE, 
528        11, 
529        Global
530    );
531    let data = (0..BENCH_SIZE).map(|i| ring.int_hom().map(i as i32)).collect::<Vec<_>>();
532    let mut copy = Vec::with_capacity(BENCH_SIZE);
533    bencher.iter(|| {
534        copy.clear();
535        copy.extend(data.iter().map(|x| ring.clone_el(x)));
536        fft.unordered_fft(std::hint::black_box(&mut copy[..]), &ring);
537        fft.unordered_inv_fft(std::hint::black_box(&mut copy[..]), &ring);
538        assert_el_eq!(ring, copy[0], data[0]);
539    });
540}