Skip to main content

feanor_math/algorithms/fft/
bluestein.rs

1use std::alloc::Allocator;
2use std::alloc::Global;
3use std::fmt::Debug;
4
5use crate::algorithms::fft::FFTAlgorithm;
6use crate::algorithms::unity_root::*;
7use crate::divisibility::{DivisibilityRing, DivisibilityRingStore};
8use crate::algorithms::fft::cooley_tuckey::CooleyTuckeyFFT;
9use crate::integer::IntegerRingStore;
10use crate::primitive_int::*;
11use crate::ring::*;
12use crate::homomorphism::*;
13use crate::rings::zn::*;
14use crate::rings::float_complex::*;
15use crate::algorithms::fft::complex_fft::*;
16use crate::seq::SwappableVectorViewMut;
17
18type BaseFFT<R_main, R_twiddle, H, A> = CooleyTuckeyFFT<R_main, R_twiddle, H, A>;
19
20///
21/// Bluestein's FFT algorithm (also known as Chirp-Z-transform) to compute the Fourier
22/// transform of arbitrary length (including prime numbers).
23/// 
24/// # Convention
25/// 
26/// This implementation does not follows the standard convention for the mathematical
27/// DFT, by performing the standard/forward FFT with the inverse root of unity `z^-1`.
28/// In other words, the forward FFT computes
29/// ```text
30///   (a_0, ..., a_(N - 1)) -> (sum_j a_j z^(-ij))_i
31/// ```
32/// where `z` is the primitive `n`-th root of unity returned by [`FFTAlgorithm::root_of_unity()`].
33/// 
34pub struct BluesteinFFT<R_main, R_twiddle, H, A = Global>
35    where R_main: ?Sized + RingBase,
36        R_twiddle: ?Sized + RingBase + DivisibilityRing,
37        H: Homomorphism<R_twiddle, R_main> + Clone, 
38        A: Allocator + Clone
39{
40    m_fft_table: BaseFFT<R_main, R_twiddle, H, A>,
41    b_unordered_fft: Vec<R_twiddle::Element>,
42    twiddles: Vec<R_twiddle::Element>,
43    root_of_unity_n: R_main::Element,
44    n: usize
45}
46
47impl<H, A> BluesteinFFT<Complex64Base, Complex64Base, H, A>
48    where H: Homomorphism<Complex64Base, Complex64Base> + Clone, 
49        A: Allocator + Clone
50{
51    ///
52    /// Creates an [`BluesteinFFT`] for the complex field, using the given homomorphism
53    /// to connect the ring implementation for twiddles with the main ring implementation.
54    /// 
55    /// This function is mainly provided for parity with other rings, since in the complex case
56    /// it currently does not make much sense to use a different homomorphism than the identity.
57    /// Hence, it is simpler to use [`BluesteinFFT::for_complex()`].
58    /// 
59    pub fn for_complex_with_hom(hom: H, n: usize, tmp_mem_allocator: A) -> Self{
60        let ZZ = StaticRing::<i64>::RING;
61        let CC = Complex64::RING;
62        let n_i64: i64 = n.try_into().unwrap();
63        let log2_m = ZZ.abs_log2_ceil(&(2 * n_i64 + 1)).unwrap();
64        Self::new_with_pows_with_hom(hom, |x| CC.root_of_unity(x, 2 * n_i64), |x| CC.root_of_unity(x, 1 << log2_m), n, log2_m, tmp_mem_allocator)
65    }
66}
67
68impl<R, A> BluesteinFFT<Complex64Base, Complex64Base, Identity<R>, A>
69    where R: RingStore<Type = Complex64Base> + Clone, 
70        A: Allocator + Clone
71{
72    ///
73    /// Creates an [`BluesteinFFT`] for the complex field.
74    /// 
75    pub fn for_complex(ring: R, n: usize, tmp_mem_allocator: A) -> Self{
76        Self::for_complex_with_hom(ring.into_identity(), n, tmp_mem_allocator)
77    }
78}
79
80impl<R, A> BluesteinFFT<R::Type, R::Type, Identity<R>, A>
81    where R: RingStore + Clone,
82        R::Type: DivisibilityRing,
83        A: Allocator + Clone
84{
85    ///
86    /// Creates an [`BluesteinFFT`] for the given ring, using the given roots of unity.
87    /// 
88    /// It is necessary that `root_of_unity_2n` is a primitive `2n`-th root of unity, and
89    /// `root_of_unity_m` is a `2^log2_m`-th root of unity, where `2^log2_m >= 2n`.
90    ///  
91    /// Do not use this for approximate rings, as computing the powers of `root_of_unity`
92    /// will incur avoidable precision loss.
93    /// 
94    pub fn new(ring: R, root_of_unity_2n: El<R>, root_of_unity_m: El<R>, n: usize, log2_m: usize, tmp_mem_allocator: A) -> Self {
95        Self::new_with_hom(ring.into_identity(), root_of_unity_2n, root_of_unity_m, n, log2_m, tmp_mem_allocator)
96    }
97
98    ///
99    /// Creates an [`BluesteinFFT`] for the given ring, using the passed function to
100    /// provide the necessary roots of unity.
101    /// 
102    /// Concretely, `root_of_unity_2n_pows(i)` should return `z^i`, where `z` is a `2n`-th
103    /// primitive root of unity, and `root_of_unity_m_pows(i)` should return `w^i` where `w`
104    /// is a `2^log2_m`-th primitive root of unity, where `2^log2_m > 2n`.
105    /// 
106    pub fn new_with_pows<F, G>(ring: R, root_of_unity_2n_pows: F, root_of_unity_m_pows: G, n: usize, log2_m: usize, tmp_mem_allocator: A) -> Self 
107        where F: FnMut(i64) -> El<R>,
108            G: FnMut(i64) -> El<R>
109    {
110        Self::new_with_pows_with_hom(ring.into_identity(), root_of_unity_2n_pows, root_of_unity_m_pows, n, log2_m, tmp_mem_allocator)
111    }
112
113    ///
114    /// Creates an [`BluesteinFFT`] for a prime field, assuming it has suitable roots of
115    /// unity.
116    /// 
117    /// Concretely, this requires that the characteristic `p` is congruent to 1 modulo
118    /// `2^log2_m n`, where `2^log2_m` is the smallest power of two that is `>= 2n`.
119    /// 
120    pub fn for_zn(ring: R, n: usize, tmp_mem_allocator: A) -> Option<Self>
121        where R::Type: ZnRing
122    {
123        Self::for_zn_with_hom(ring.into_identity(), n, tmp_mem_allocator)
124    }
125}
126
127impl<R_main, R_twiddle, H, A> BluesteinFFT<R_main, R_twiddle, H, A>
128    where R_main: ?Sized + RingBase,
129        R_twiddle: ?Sized + RingBase + DivisibilityRing,
130        H: Homomorphism<R_twiddle, R_main> + Clone, 
131        A: Allocator + Clone
132{
133    ///
134    /// Creates an [`BluesteinFFT`] for the given ring, using the given roots of unity.
135    /// 
136    /// It is necessary that `root_of_unity_2n` is a primitive `2n`-th root of unity, and
137    /// `root_of_unity_m` is a `2^log2_m`-th root of unity, where `2^log2_m >= 2n`.
138    /// 
139    /// Instead of a ring, this function takes a homomorphism `R -> S`. Twiddle factors that are
140    /// precomputed will be stored as elements of `R`, while the main FFT computations will be 
141    /// performed in `S`. This allows both implicit ring conversions, and using patterns like 
142    /// [`zn_64::ZnFastmul`] to precompute some data for better performance.
143    /// 
144    /// Do not use this for approximate rings, as computing the powers of `root_of_unity`
145    /// will incur avoidable precision loss.
146    /// 
147    pub fn new_with_hom(hom: H, root_of_unity_2n: R_twiddle::Element, root_of_unity_m: R_twiddle::Element, n: usize, log2_m: usize, tmp_mem_allocator: A) -> Self {
148        let hom_copy = hom.clone();
149        let twiddle_ring = hom_copy.domain();
150        return Self::new_with_pows_with_hom(
151            hom, 
152            |i: i64| if i >= 0 {
153                twiddle_ring.pow(twiddle_ring.clone_el(&root_of_unity_2n), i as usize % (2 * n))
154            } else {
155                twiddle_ring.invert(&twiddle_ring.pow(twiddle_ring.clone_el(&root_of_unity_2n), (-i) as usize % (2 * n))).unwrap()
156            }, 
157            |i: i64| if i >= 0 {
158                twiddle_ring.pow(twiddle_ring.clone_el(&root_of_unity_m), i as usize)
159            } else {
160                twiddle_ring.invert(&twiddle_ring.pow(twiddle_ring.clone_el(&root_of_unity_m), (-i) as usize)).unwrap()
161            }, 
162            n, 
163            log2_m, 
164            tmp_mem_allocator
165        );
166    }
167
168    ///
169    /// Creates an [`BluesteinFFT`] for the given rings, using the given function to create
170    /// the necessary powers of roots of unity.
171    /// 
172    /// Concretely, `root_of_unity_2n_pows(i)` should return `z^i`, where `z` is a `2n`-th
173    /// primitive root of unity, and `root_of_unity_m_pows(i)` should return `w^i` where `w`
174    /// is a `2^log2_m`-th primitive root of unity, where `2^log2_m >= 2n`.
175    /// 
176    /// Instead of a ring, this function takes a homomorphism `R -> S`. Twiddle factors that are
177    /// precomputed will be stored as elements of `R`, while the main FFT computations will be 
178    /// performed in `S`. This allows both implicit ring conversions, and using patterns like 
179    /// [`zn_64::ZnFastmul`] to precompute some data for better performance.
180    /// 
181    pub fn new_with_pows_with_hom<F, G>(hom: H, mut root_of_unity_2n_pows: F, mut root_of_unity_m_pows: G, n: usize, log2_m: usize, tmp_mem_allocator: A) -> Self
182        where F: FnMut(i64) -> R_twiddle::Element,
183            G: FnMut(i64) -> R_twiddle::Element
184    {
185        let m_fft_table = CooleyTuckeyFFT::create(
186            hom, 
187            &mut root_of_unity_m_pows, 
188            log2_m, 
189            tmp_mem_allocator
190        );
191        return Self::create(m_fft_table, |i| root_of_unity_2n_pows(2 * i), n);
192    }
193
194    ///
195    /// Creates an [`BluesteinFFT`] for the given prime fields, assuming they have suitable
196    /// roots of unity.
197    /// 
198    /// Concretely, this requires that the characteristic `p` is congruent to 1 modulo
199    /// `2^log2_m n`, where `2^log2_m` is the smallest power of two that is `>= 2n`.
200    /// 
201    /// Instead of a ring, this function takes a homomorphism `R -> S`. Twiddle factors that are
202    /// precomputed will be stored as elements of `R`, while the main FFT computations will be 
203    /// performed in `S`. This allows both implicit ring conversions, and using patterns like 
204    /// [`zn_64::ZnFastmul`] to precompute some data for better performance.
205    /// 
206    pub fn for_zn_with_hom(hom: H, n: usize, tmp_mem_allocator: A) -> Option<Self>
207        where R_twiddle: ZnRing
208    {
209        let root_of_unity_2n = get_prim_root_of_unity_zn(hom.domain(), 2 * n)?;
210        let log2_m = StaticRing::<i64>::RING.abs_log2_ceil(&(n * 2).try_into().unwrap()).unwrap();
211        let root_of_unity_m = get_prim_root_of_unity_zn(hom.domain(), 1 << log2_m)?;
212        return Some(Self::new_with_hom(hom, root_of_unity_2n, root_of_unity_m, n, log2_m, tmp_mem_allocator));
213    }
214    
215    ///
216    /// Most general way to construct a [`BluesteinFFT`].
217    /// 
218    /// This function takes a length-`m` base FFT, where `m >= 2m`, and a function `root_of_unity_pows`,
219    /// on input `i`, should return `z^i` for an `n`-th primitive root of unity `z`.
220    /// 
221    #[stability::unstable(feature = "enable")]
222    pub fn create<F>(m_fft_table: BaseFFT<R_main, R_twiddle, H, A>, mut root_of_unity_n_pows: F, n: usize) -> Self
223        where F: FnMut(i64) -> R_twiddle::Element
224    {
225        let hom = m_fft_table.hom().clone();
226        let m = m_fft_table.len();
227        assert!(m >= 2 * n);
228        assert!(n % 2 == 1);
229        assert!(hom.codomain().is_commutative());
230        assert!(hom.domain().get_ring().is_approximate() || is_prim_root_of_unity(hom.domain(), &root_of_unity_n_pows(1), n));
231        assert!(hom.codomain().get_ring().is_approximate() || is_prim_root_of_unity(hom.codomain(), &hom.map(root_of_unity_n_pows(1)), n));
232
233        let (twiddle_fft, old_hom) = m_fft_table.change_ring(hom.domain().identity());
234
235        let half_mod_n = (n + 1) / 2;
236        let mut b: Vec<_> = (0..n).map(|i| root_of_unity_n_pows(TryInto::<i64>::try_into(i * i * half_mod_n).unwrap())).collect();
237        b.resize_with(m, || hom.domain().zero());
238        
239        twiddle_fft.unordered_fft(&mut b, hom.domain());
240
241        let twiddles = (0..n).map(|i| root_of_unity_n_pows(-TryInto::<i64>::try_into(i * i * half_mod_n).unwrap())).collect::<Vec<_>>();
242        let root_of_unity_n = hom.map(root_of_unity_n_pows(1));
243
244        return BluesteinFFT { 
245            m_fft_table: twiddle_fft.change_ring(old_hom).0, 
246            b_unordered_fft: b, 
247            twiddles: twiddles, 
248            root_of_unity_n: root_of_unity_n,
249            n: n
250        };
251    }
252
253    ///
254    /// Computes the FFT of the given values using Bluestein's algorithm, using only the passed
255    /// buffer as temporary storage.
256    /// 
257    /// This will not allocate additional memory, as opposed to [`BluesteinFFT::fft()`] etc.
258    /// 
259    /// Basically, the idea is to write an FFT of any length (e.g. prime length) as a convolution,
260    /// and compute the convolution efficiently using a power-of-two FFT (e.g. with the Cooley-Tukey 
261    /// algorithm).
262    /// 
263    /// TODO: At next breaking release, make this private
264    /// 
265    pub fn fft_base<V, W, const INV: bool>(&self, values: V, _buffer: W)
266        where V: SwappableVectorViewMut<R_main::Element>, 
267            W: SwappableVectorViewMut<R_main::Element>
268    {
269        if INV {
270            self.unordered_inv_fft(values, self.ring());
271        } else {
272            self.unordered_fft(values, self.ring());
273        }
274    }
275
276    fn fft_base_impl<V, A2, const INV: bool>(&self, mut values: V, mut buffer: Vec<R_main::Element, A2>)
277        where V: SwappableVectorViewMut<R_main::Element>,
278            A2: Allocator
279    {
280        assert_eq!(values.len(), self.n);
281        assert_eq!(buffer.len(), self.m_fft_table.len());
282
283        let ring = self.m_fft_table.hom().codomain();
284
285        // set buffer to the zero-padded sequence values_i * z^(-i^2/2)
286        for i in 0..self.n {
287            let value = if INV {
288                values.at((self.n - i) % self.n)
289            } else {
290                values.at(i)
291            };
292            buffer[i] = self.hom().mul_ref_map(value, &self.twiddles[i]);
293        }
294        for i in self.n..self.m_fft_table.len() {
295            buffer[i] = ring.zero();
296        }
297 
298        self.m_fft_table.unordered_truncated_fft(&mut buffer, self.n * 2);
299        for i in 0..self.m_fft_table.len() {
300            self.hom().mul_assign_ref_map(&mut buffer[i], &self.b_unordered_fft[i]);
301        }
302        self.m_fft_table.unordered_truncated_fft_inv(&mut buffer, self.n * 2);
303        
304        // make the normal convolution into a cyclic convolution of length n by taking it modulo `x^n - 1`
305        let (buffer1, buffer2) = buffer[..(2 * self.n)].split_at_mut(self.n);
306        for (a, b) in buffer1.iter_mut().zip(buffer2.iter_mut()) {
307            ring.add_assign_ref(a, b);
308        }
309
310        // write values back, and multiply them with a twiddle factor
311        for (i, x) in buffer.into_iter().enumerate().take(self.n) {
312            *values.at_mut(i) = self.hom().mul_ref_snd_map(x, &self.twiddles[i]);
313        }
314
315        if INV {
316            // finally, scale by 1/n
317            let scale = self.hom().map(self.hom().domain().checked_div(&self.hom().domain().one(), &self.hom().domain().int_hom().map(self.n.try_into().unwrap())).unwrap());
318            for i in 0..values.len() {
319                ring.mul_assign_ref(values.at_mut(i), &scale);
320            }
321        }
322    }
323
324    ///
325    /// Returns a reference to the allocator currently used for temporary allocations by this FFT.
326    /// 
327    #[stability::unstable(feature = "enable")]
328    pub fn allocator(&self) -> &A {
329        self.m_fft_table.allocator()
330    }
331    
332    ///
333    /// Returns the ring over which this object can compute FFTs.
334    /// 
335    #[stability::unstable(feature = "enable")]
336    pub fn ring<'a>(&'a self) -> &'a <H as Homomorphism<R_twiddle, R_main>>::CodomainStore {
337        self.hom().codomain()
338    }
339
340    ///
341    /// Returns a reference to the homomorphism that is used to map the stored twiddle
342    /// factors into main ring, over which FFTs are computed.
343    /// 
344    #[stability::unstable(feature = "enable")]
345    pub fn hom(&self) -> &H {
346        self.m_fft_table.hom()
347    }
348}
349
350impl<R_main, R_twiddle, H, A> PartialEq for BluesteinFFT<R_main, R_twiddle, H, A>
351    where R_main: ?Sized + RingBase,
352        R_twiddle: ?Sized + RingBase + DivisibilityRing,
353        H: Homomorphism<R_twiddle, R_main> + Clone, 
354        A: Allocator + Clone
355{
356    fn eq(&self, other: &Self) -> bool {
357        self.ring().get_ring() == other.ring().get_ring() &&
358            self.n == other.n &&
359            self.ring().eq_el(self.root_of_unity(self.ring()), other.root_of_unity(self.ring()))
360    }
361}
362
363impl<R_main, R_twiddle, H, A> Debug for BluesteinFFT<R_main, R_twiddle, H, A>
364    where R_main: ?Sized + RingBase + Debug,
365        R_twiddle: ?Sized + RingBase + DivisibilityRing,
366        H: Homomorphism<R_twiddle, R_main> + Clone, 
367        A: Allocator + Clone
368{
369    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
370        f.debug_struct("BluesteinFFT")
371            .field("ring", &self.ring().get_ring())
372            .field("n", &self.n)
373            .field("root_of_unity_n", &self.ring().format(&self.root_of_unity(self.ring())))
374            .finish()
375    }
376}
377
378impl<R_main, R_twiddle, H, A> FFTAlgorithm<R_main> for BluesteinFFT<R_main, R_twiddle, H, A>
379    where R_main: ?Sized + RingBase,
380        R_twiddle: ?Sized + RingBase + DivisibilityRing,
381        H: Homomorphism<R_twiddle, R_main> + Clone, 
382        A: Allocator + Clone
383{
384    fn len(&self) -> usize {
385        self.n
386    }
387
388    fn root_of_unity<S: RingStore<Type = R_main> + Copy>(&self, ring: S) -> &R_main::Element {
389        assert!(self.ring().get_ring() == ring.get_ring(), "unsupported ring");
390        &self.root_of_unity_n
391    }
392
393    fn unordered_fft_permutation(&self, i: usize) -> usize {
394        i
395    }
396
397    fn unordered_fft_permutation_inv(&self, i: usize) -> usize {
398        i
399    }
400
401    fn unordered_fft<V, S>(&self, values: V, ring: S)
402        where V: SwappableVectorViewMut<<R_main as RingBase>::Element>,
403            S: RingStore<Type = R_main> + Copy 
404    {
405        assert!(ring.get_ring() == self.ring().get_ring(), "unsupported ring");
406        let mut buffer = Vec::with_capacity_in(self.m_fft_table.len(), self.allocator().clone());
407        buffer.extend((0..self.m_fft_table.len()).map(|_| self.ring().zero()));
408        self.fft_base_impl::<_, _, false>(values, buffer);
409    }
410
411    fn unordered_inv_fft<V, S>(&self, values: V, ring: S)
412        where V: SwappableVectorViewMut<<R_main as RingBase>::Element>,
413            S: RingStore<Type = R_main> + Copy 
414    {
415        assert!(ring.get_ring() == self.ring().get_ring(), "unsupported ring");
416        let mut buffer = Vec::with_capacity_in(self.m_fft_table.len(), self.allocator().clone());
417        buffer.extend((0..self.m_fft_table.len()).map(|_| self.ring().zero()));
418        self.fft_base_impl::<_, _, true>(values, buffer);
419    }
420
421    fn fft<V, S>(&self, values: V, ring: S)
422        where V: SwappableVectorViewMut<<R_main as RingBase>::Element>,
423            S: RingStore<Type = R_main> + Copy 
424    {
425        self.unordered_fft(values, ring);
426    }
427
428    fn inv_fft<V, S>(&self, values: V, ring: S)
429        where V: SwappableVectorViewMut<<R_main as RingBase>::Element>,
430            S: RingStore<Type = R_main> + Copy 
431    {
432        self.unordered_inv_fft(values, ring);
433    }
434}
435
436impl<H, A> FFTErrorEstimate for BluesteinFFT<Complex64Base, Complex64Base, H, A>
437    where H: Homomorphism<Complex64Base, Complex64Base> + Clone, 
438        A: Allocator + Clone
439{
440    fn expected_absolute_error(&self, input_bound: f64, input_error: f64) -> f64 {
441        let error_after_twiddling = input_error + input_bound * (root_of_unity_error() + f64::EPSILON);
442        let error_after_fft = self.m_fft_table.expected_absolute_error(input_bound, error_after_twiddling);
443        let b_bitreverse_fft_error = self.m_fft_table.expected_absolute_error(1., root_of_unity_error());
444        // now the values are increased by up to a factor of m, so use `input_bound * m` instead
445        let new_input_bound = input_bound * self.m_fft_table.len() as f64;
446        let b_bitreverse_fft_bound = self.m_fft_table.len() as f64;
447        let error_after_mul = new_input_bound * b_bitreverse_fft_error + b_bitreverse_fft_bound * error_after_fft + f64::EPSILON * new_input_bound * b_bitreverse_fft_bound;
448        let error_after_inv_fft = self.m_fft_table.expected_absolute_error(new_input_bound * b_bitreverse_fft_bound, error_after_mul) / self.m_fft_table.len() as f64;
449        let error_end = error_after_inv_fft + new_input_bound * (root_of_unity_error() + f64::EPSILON);
450        return error_end;
451    }
452}
453
454#[cfg(test)]
455use crate::rings::zn::zn_static::*;
456
457#[test]
458fn test_fft_base() {
459    let ring = Zn::<241>::RING;
460    // a 5-th root of unity is 91 
461    let fft = BluesteinFFT::new(ring, ring.int_hom().map(36), ring.int_hom().map(111), 5, 4, Global);
462    let mut values = [1, 3, 2, 0, 7];
463    let mut buffer = [0; 16];
464    fft.fft_base::<_, _, false>(&mut values, &mut buffer);
465    let expected = [13, 137, 202, 206, 170];
466    assert_eq!(expected, values);
467}
468
469#[test]
470fn test_fft_fastmul() {
471    let ring = zn_64::Zn::new(241);
472    let fastmul_ring = zn_64::ZnFastmul::new(ring).unwrap();
473    let fft = BluesteinFFT::new_with_hom(ring.can_hom(&fastmul_ring).unwrap(), fastmul_ring.int_hom().map(36), fastmul_ring.int_hom().map(111), 5, 4, Global);
474    let mut values: [_; 5] = std::array::from_fn(|i| ring.int_hom().map([1, 3, 2, 0, 7][i]));
475    fft.fft(&mut values, ring);
476    let expected: [_; 5] = std::array::from_fn(|i| ring.int_hom().map([13, 137, 202, 206, 170][i]));
477    for i in 0..values.len() {
478        assert_el_eq!(ring, expected[i], values[i]);
479    }
480}
481
482#[test]
483fn test_inv_fft_base() {
484    let ring = Zn::<241>::RING;
485    // a 5-th root of unity is 91 
486    let fft = BluesteinFFT::new(ring, ring.int_hom().map(36), ring.int_hom().map(111), 5, 4, Global);
487    let values = [1, 3, 2, 0, 7];
488    let mut work = values;
489    let mut buffer = [0; 16];
490    fft.fft_base::<_, _, false>(&mut work, &mut buffer);
491    fft.fft_base::<_, _, true>(&mut work, &mut buffer);
492    assert_eq!(values, work);
493}
494
495#[test]
496fn test_approximate_fft() {
497    let CC = Complex64::RING;
498    for (p, _log2_m) in [(5, 4), (53, 7), (1009, 11)] {
499        let fft = BluesteinFFT::for_complex(&CC, p, Global);
500        let mut array = (0..p).map(|i| CC.root_of_unity(i.try_into().unwrap(), p.try_into().unwrap())).collect::<Vec<_>>();
501        fft.fft(&mut array, CC);
502        let err = fft.expected_absolute_error(1., 0.);
503        assert!(CC.is_absolute_approx_eq(array[0], CC.zero(), err));
504        assert!(CC.is_absolute_approx_eq(array[1], CC.from_f64(fft.len() as f64), err));
505        for i in 2..fft.len() {
506            assert!(CC.is_absolute_approx_eq(array[i], CC.zero(), err));
507        }
508    }
509}
510
511#[cfg(test)]
512const BENCH_SIZE: usize = 1009;
513
514#[bench]
515fn bench_bluestein(bencher: &mut test::Bencher) {
516    let ring = zn_64::Zn::new(18597889);
517    let fastmul_ring = zn_64::ZnFastmul::new(ring).unwrap();
518    let embedding = ring.can_hom(&fastmul_ring).unwrap();
519    let root_of_unity = fastmul_ring.coerce(&ring, get_prim_root_of_unity_zn(&ring, 2 * BENCH_SIZE).unwrap());
520    let fft = BluesteinFFT::new_with_hom(
521        embedding.clone(), 
522        root_of_unity, 
523        get_prim_root_of_unity_zn(&fastmul_ring, 1 << 11).unwrap(), 
524        BENCH_SIZE, 
525        11, 
526        Global
527    );
528    let data = (0..BENCH_SIZE).map(|i| ring.int_hom().map(i as i32)).collect::<Vec<_>>();
529    let mut copy = Vec::with_capacity(BENCH_SIZE);
530    bencher.iter(|| {
531        copy.clear();
532        copy.extend(data.iter().map(|x| ring.clone_el(x)));
533        fft.unordered_fft(std::hint::black_box(&mut copy[..]), &ring);
534        fft.unordered_inv_fft(std::hint::black_box(&mut copy[..]), &ring);
535        assert_el_eq!(ring, copy[0], data[0]);
536    });
537}