1use std::alloc::Allocator;
2use std::alloc::Global;
3
4use crate::algorithms::fft::FFTAlgorithm;
5use crate::algorithms::unity_root::is_prim_root_of_unity;
6use crate::divisibility::DivisibilityRing;
7use crate::divisibility::DivisibilityRingStore;
8use crate::integer::IntegerRingStore;
9use crate::primitive_int::*;
10use crate::ring::*;
11use crate::homomorphism::*;
12use crate::algorithms;
13use crate::rings::zn::*;use crate::rings::float_complex::*;
14use crate::algorithms::fft::complex_fft::*;
15use crate::seq::SwappableVectorViewMut;
16
17#[derive(Debug)]
22pub struct BluesteinFFT<R_main, R_twiddle, H, A = Global>
23 where R_main: ?Sized + RingBase,
24 R_twiddle: ?Sized + RingBase + DivisibilityRing,
25 H: Homomorphism<R_twiddle, R_main>,
26 A: Allocator + Clone
27{
28 hom: H,
29 m_fft_table: algorithms::fft::cooley_tuckey::CooleyTuckeyFFT<R_main, R_twiddle, H>,
30 tmp_mem_allocator: A,
31 b_bitreverse_fft: Vec<R_twiddle::Element>,
39 inv_root_of_unity_2n: Vec<R_twiddle::Element>,
41 root_of_unity_n: R_main::Element,
43 n: usize
44}
45
46impl<H, A> BluesteinFFT<Complex64Base, Complex64Base, H, A>
47 where H: Homomorphism<Complex64Base, Complex64Base> + Clone,
48 A: Allocator + Clone
49{
50 pub fn for_complex_with_hom(hom: H, n: usize, tmp_mem_allocator: A) -> Self{
64 let ZZ = StaticRing::<i64>::RING;
65 let CC = Complex64::RING;
66 let log2_m = ZZ.abs_log2_ceil(&(2 * n as i64 + 1)).unwrap();
67 Self::new_with_pows_with_hom(hom, |x| CC.root_of_unity(x, 2 * n as i64), |x| CC.root_of_unity(x, 1 << log2_m), n, log2_m, tmp_mem_allocator)
68 }
69}
70
71impl<R, A> BluesteinFFT<Complex64Base, Complex64Base, Identity<R>, A>
72 where R: RingStore<Type = Complex64Base> + Clone,
73 A: Allocator + Clone
74{
75 pub fn for_complex(ring: R, n: usize, tmp_mem_allocator: A) -> Self{
84 Self::for_complex_with_hom(ring.into_identity(), n, tmp_mem_allocator)
85 }
86}
87
88impl<R, A> BluesteinFFT<R::Type, R::Type, Identity<R>, A>
89 where R: RingStore + Clone,
90 R::Type: DivisibilityRing,
91 A: Allocator + Clone
92{
93 pub fn new(ring: R, root_of_unity_2n: El<R>, root_of_unity_m: El<R>, n: usize, log2_m: usize, tmp_mem_allocator: A) -> Self {
109 Self::new_with_hom(ring.into_identity(), root_of_unity_2n, root_of_unity_m, n, log2_m, tmp_mem_allocator)
110 }
111
112 pub fn new_with_pows<F, G>(ring: R, root_of_unity_2n_pows: F, root_of_unity_m_pows: G, n: usize, log2_m: usize, tmp_mem_allocator: A) -> Self
126 where F: FnMut(i64) -> El<R>,
127 G: FnMut(i64) -> El<R>
128 {
129 Self::new_with_pows_with_hom(ring.into_identity(), root_of_unity_2n_pows, root_of_unity_m_pows, n, log2_m, tmp_mem_allocator)
130 }
131
132 pub fn for_zn(ring: R, n: usize, tmp_mem_allocator: A) -> Option<Self>
145 where R::Type: ZnRing
146 {
147 Self::for_zn_with_hom(ring.into_identity(), n, tmp_mem_allocator)
148 }
149}
150
151impl<R_main, R_twiddle, H, A> BluesteinFFT<R_main, R_twiddle, H, A>
152 where R_main: ?Sized + RingBase,
153 R_twiddle: ?Sized + RingBase + DivisibilityRing,
154 H: Homomorphism<R_twiddle, R_main> + Clone,
155 A: Allocator + Clone
156{
157 pub fn new_with_hom(hom: H, root_of_unity_2n: R_twiddle::Element, root_of_unity_m: R_twiddle::Element, n: usize, log2_m: usize, tmp_mem_allocator: A) -> Self {
177 assert!((1 << log2_m) >= 2 * n + 1);
181 assert!(!hom.domain().get_ring().is_approximate());
182 assert!(is_prim_root_of_unity(hom.domain(), &root_of_unity_2n, 2 * n));
183 assert!(is_prim_root_of_unity(hom.codomain(), &hom.map_ref(&root_of_unity_2n), 2 * n));
184 assert!(is_prim_root_of_unity(hom.domain(), &root_of_unity_m, 1 << log2_m));
185 assert!(is_prim_root_of_unity(hom.codomain(), &hom.map_ref(&root_of_unity_m), 1 << log2_m));
186
187 let root_of_unity_2n_pows = |x: i64| if x >= 0 {
188 hom.domain().pow(hom.domain().clone_el(&root_of_unity_2n), x as usize % (2 * n))
189 } else {
190 hom.domain().invert(&hom.domain().pow(hom.domain().clone_el(&root_of_unity_2n), (-x) as usize % (2 * n))).unwrap()
191 };
192
193 let mut b = Self::create_b_array(hom.domain().get_ring(), root_of_unity_2n_pows, n, 1 << log2_m);
194 let inv_root_of_unity_2n = (0..n).map(|i| root_of_unity_2n_pows(-((i * i) as i64))).collect::<Vec<_>>();
195 let root_of_unity_n = hom.codomain().pow(hom.map_ref(&root_of_unity_2n), 2);
196
197 let m_fft_table_base = algorithms::fft::cooley_tuckey::CooleyTuckeyFFT::new(hom.domain(), hom.domain().clone_el(&root_of_unity_m), log2_m);
198 m_fft_table_base.unordered_fft(&mut b[..], &hom.domain());
199 drop(m_fft_table_base);
200
201 let m_fft_table = algorithms::fft::cooley_tuckey::CooleyTuckeyFFT::new_with_hom(hom.clone(), root_of_unity_m, log2_m);
203 return BluesteinFFT {
204 m_fft_table: m_fft_table,
205 b_bitreverse_fft: b,
206 inv_root_of_unity_2n: inv_root_of_unity_2n,
207 root_of_unity_n: root_of_unity_n,
208 n: n,
209 tmp_mem_allocator: tmp_mem_allocator,
210 hom: hom
211 };
212 }
213
214 fn create_b_array<F>(ring: &R_twiddle, mut root_of_unity_2n_pows: F, n: usize, m: usize) -> Vec<R_twiddle::Element>
215 where F: FnMut(i64) -> R_twiddle::Element
216 {
217 let mut b = (0..m).map(|_| ring.zero()).collect::<Vec<_>>();
218 b[0] = ring.one();
219 for i in 1..n {
220 b[i] = root_of_unity_2n_pows((i * i) as i64 % (2 * n) as i64);
221 b[m - i] = ring.clone_el(&b[i]);
222 }
223 return b;
224 }
225
226 pub fn new_with_pows_with_hom<F, G>(hom: H, mut root_of_unity_2n_pows: F, mut root_of_unity_m_pows: G, n: usize, log2_m: usize, tmp_mem_allocator: A) -> Self
245 where F: FnMut(i64) -> R_twiddle::Element,
246 G: FnMut(i64) -> R_twiddle::Element
247 {
248 assert!((1 << log2_m) >= 2 * n + 1);
250 assert!(hom.domain().get_ring().is_approximate() || is_prim_root_of_unity(hom.domain(), &root_of_unity_2n_pows(1), 2 * n));
251 assert!(hom.codomain().get_ring().is_approximate() || is_prim_root_of_unity(hom.codomain(), &hom.map(root_of_unity_2n_pows(1)), 2 * n));
252 assert!(hom.domain().get_ring().is_approximate() || is_prim_root_of_unity(hom.domain(), &root_of_unity_m_pows(1), 1 << log2_m));
253 assert!(hom.codomain().get_ring().is_approximate() || is_prim_root_of_unity(hom.codomain(), &hom.map(root_of_unity_m_pows(1)), 1 << log2_m));
254
255 let mut b = Self::create_b_array(hom.domain().get_ring(), &mut root_of_unity_2n_pows, n, 1 << log2_m);
256 let inv_root_of_unity_2n = (0..n).map(|i| root_of_unity_2n_pows(-((i * i) as i64))).collect::<Vec<_>>();
257 let root_of_unity_n = hom.map(root_of_unity_2n_pows(2));
258
259 let m_fft_table_base = algorithms::fft::cooley_tuckey::CooleyTuckeyFFT::new_with_pows(hom.domain(), &mut root_of_unity_m_pows, log2_m);
260 m_fft_table_base.unordered_fft(&mut b[..], &hom.domain());
261 drop(m_fft_table_base);
262
263 let m_fft_table = algorithms::fft::cooley_tuckey::CooleyTuckeyFFT::new_with_pows_with_hom(hom.clone(), root_of_unity_m_pows, log2_m);
265 return BluesteinFFT {
266 m_fft_table: m_fft_table,
267 b_bitreverse_fft: b,
268 inv_root_of_unity_2n: inv_root_of_unity_2n,
269 root_of_unity_n: root_of_unity_n,
270 n: n,
271 tmp_mem_allocator: tmp_mem_allocator,
272 hom: hom
273 };
274 }
275
276 pub fn for_zn_with_hom(hom: H, n: usize, tmp_mem_allocator: A) -> Option<Self>
294 where R_twiddle: ZnRing
295 {
296 let ZZ = StaticRing::<i64>::RING;
297 let log2_m = ZZ.abs_log2_ceil(&(2 * n as i64 + 1)).unwrap();
298 let ring_as_field = hom.domain().as_field().ok().unwrap();
299 let root_of_unity_2n = ring_as_field.get_ring().unwrap_element(algorithms::unity_root::get_prim_root_of_unity(&ring_as_field, 2 * n)?);
300 let root_of_unity_m = ring_as_field.get_ring().unwrap_element(algorithms::unity_root::get_prim_root_of_unity_pow2(&ring_as_field, log2_m)?);
301 drop(ring_as_field);
302 Some(Self::new_with_hom(hom, root_of_unity_2n, root_of_unity_m, n, log2_m, tmp_mem_allocator))
303 }
304}
305
306impl<R_main, R_twiddle, H, A> BluesteinFFT<R_main, R_twiddle, H, A>
307 where R_main: ?Sized + RingBase,
308 R_twiddle: ?Sized + RingBase + DivisibilityRing,
309 H: Homomorphism<R_twiddle, R_main>,
310 A: Allocator + Clone
311{
312 pub fn fft_base<V, W, const INV: bool>(&self, mut values: V, mut buffer: W)
323 where V: SwappableVectorViewMut<R_main::Element>,
324 W: SwappableVectorViewMut<R_main::Element>
325 {
326 assert_eq!(values.len(), self.n);
327 assert_eq!(buffer.len(), self.m_fft_table.len());
328
329 let ring = self.hom.codomain();
330
331 for i in 0..self.n {
333 let value = if INV {
334 values.at((self.n - i) % self.n)
335 } else {
336 values.at(i)
337 };
338 *buffer.at_mut(i) = ring.clone_el(value);
339 self.hom.mul_assign_ref_map(buffer.at_mut(i), &self.inv_root_of_unity_2n[i]);
340 }
341 for i in self.n..self.m_fft_table.len() {
342 *buffer.at_mut(i) = ring.zero();
343 }
344
345 self.m_fft_table.unordered_fft(&mut buffer, &self.ring());
347 for i in 0..self.m_fft_table.len() {
348 self.hom.mul_assign_ref_map(buffer.at_mut(i), &self.b_bitreverse_fft[i]);
349 }
350 self.m_fft_table.unordered_inv_fft(&mut buffer, &self.ring());
351
352 for i in 0..self.n {
354 *values.at_mut(i) = std::mem::replace(buffer.at_mut(i), ring.zero());
355 self.hom.mul_assign_ref_map(values.at_mut(i), &self.inv_root_of_unity_2n[i]);
356 }
357
358 if INV {
359 let scale = self.hom.map(self.hom.domain().checked_div(&self.hom.domain().one(), &self.hom.domain().int_hom().map(self.n as i32)).unwrap());
361 for i in 0..values.len() {
362 ring.mul_assign_ref(values.at_mut(i), &scale);
363 }
364 }
365 }
366
367 fn ring<'a>(&'a self) -> &'a <H as Homomorphism<R_twiddle, R_main>>::CodomainStore {
368 self.hom.codomain()
369 }
370
371}
372
373impl<R_main, R_twiddle, H, A> PartialEq for BluesteinFFT<R_main, R_twiddle, H, A>
374 where R_main: ?Sized + RingBase,
375 R_twiddle: ?Sized + RingBase + DivisibilityRing,
376 H: Homomorphism<R_twiddle, R_main>,
377 A: Allocator + Clone
378{
379 fn eq(&self, other: &Self) -> bool {
380 self.ring().get_ring() == other.ring().get_ring() &&
381 self.n == other.n &&
382 self.ring().eq_el(self.root_of_unity(self.ring()), other.root_of_unity(self.ring()))
383 }
384}
385
386impl<R_main, R_twiddle, H, A> FFTAlgorithm<R_main> for BluesteinFFT<R_main, R_twiddle, H, A>
387 where R_main: ?Sized + RingBase,
388 R_twiddle: ?Sized + RingBase + DivisibilityRing,
389 H: Homomorphism<R_twiddle, R_main>,
390 A: Allocator + Clone
391{
392 fn len(&self) -> usize {
393 self.n
394 }
395
396 fn root_of_unity<S: RingStore<Type = R_main> + Copy>(&self, ring: S) -> &R_main::Element {
397 assert!(self.ring().get_ring() == ring.get_ring(), "unsupported ring");
398 &self.root_of_unity_n
399 }
400
401 fn unordered_fft_permutation(&self, i: usize) -> usize {
402 i
403 }
404
405 fn unordered_fft_permutation_inv(&self, i: usize) -> usize {
406 i
407 }
408
409 fn unordered_fft<V, S>(&self, values: V, ring: S)
410 where V: SwappableVectorViewMut<<R_main as RingBase>::Element>,
411 S: RingStore<Type = R_main> + Copy
412 {
413 assert!(ring.get_ring() == self.ring().get_ring(), "unsupported ring");
414 let mut buffer = Vec::with_capacity_in(self.m_fft_table.len(), self.tmp_mem_allocator.clone());
415 buffer.extend((0..self.m_fft_table.len()).map(|_| self.hom.codomain().zero()));
416 self.fft_base::<_, _, false>(values, &mut buffer[..]);
417 }
418
419 fn unordered_inv_fft<V, S>(&self, values: V, ring: S)
420 where V: SwappableVectorViewMut<<R_main as RingBase>::Element>,
421 S: RingStore<Type = R_main> + Copy
422 {
423 assert!(ring.get_ring() == self.ring().get_ring(), "unsupported ring");
424 let mut buffer = Vec::with_capacity_in(self.m_fft_table.len(), self.tmp_mem_allocator.clone());
425 buffer.extend((0..self.m_fft_table.len()).map(|_| self.hom.codomain().zero()));
426 self.fft_base::<_, _, true>(values, &mut buffer[..]);
427 }
428
429 fn fft<V, S>(&self, values: V, ring: S)
430 where V: SwappableVectorViewMut<<R_main as RingBase>::Element>,
431 S: RingStore<Type = R_main> + Copy
432 {
433 self.unordered_fft(values, ring);
434 }
435
436 fn inv_fft<V, S>(&self, values: V, ring: S)
437 where V: SwappableVectorViewMut<<R_main as RingBase>::Element>,
438 S: RingStore<Type = R_main> + Copy
439 {
440 self.unordered_inv_fft(values, ring);
441 }
442}
443
444impl<H, A> FFTErrorEstimate for BluesteinFFT<Complex64Base, Complex64Base, H, A>
445 where H: Homomorphism<Complex64Base, Complex64Base>,
446 A: Allocator + Clone
447{
448 fn expected_absolute_error(&self, input_bound: f64, input_error: f64) -> f64 {
449 let error_after_twiddling = input_error + input_bound * (root_of_unity_error() + f64::EPSILON);
450 let error_after_fft = self.m_fft_table.expected_absolute_error(input_bound, error_after_twiddling);
451 let b_bitreverse_fft_error = self.m_fft_table.expected_absolute_error(1., root_of_unity_error());
452 let new_input_bound = input_bound * self.m_fft_table.len() as f64;
454 let b_bitreverse_fft_bound = self.m_fft_table.len() as f64;
455 let error_after_mul = new_input_bound * b_bitreverse_fft_error + b_bitreverse_fft_bound * error_after_fft + f64::EPSILON * new_input_bound * b_bitreverse_fft_bound;
456 let error_after_inv_fft = self.m_fft_table.expected_absolute_error(new_input_bound * b_bitreverse_fft_bound, error_after_mul) / self.m_fft_table.len() as f64;
457 let error_end = error_after_inv_fft + new_input_bound * (root_of_unity_error() + f64::EPSILON);
458 return error_end;
459 }
460}
461
462#[cfg(test)]
463use crate::rings::zn::zn_static::*;
464
465#[test]
466fn test_fft_base() {
467 let ring = Zn::<241>::RING;
468 let fft = BluesteinFFT::new(ring, ring.int_hom().map(36), ring.int_hom().map(111), 5, 4, Global);
470 let mut values = [1, 3, 2, 0, 7];
471 let mut buffer = [0; 16];
472 fft.fft_base::<_, _, false>(&mut values, &mut buffer);
473 let expected = [13, 137, 202, 206, 170];
474 assert_eq!(expected, values);
475}
476
477#[test]
478fn test_fft_fastmul() {
479 let ring = zn_64::Zn::new(241);
480 let fastmul_ring = zn_64::ZnFastmul::new(ring).unwrap();
481 let fft = BluesteinFFT::new_with_hom(ring.can_hom(&fastmul_ring).unwrap(), fastmul_ring.int_hom().map(36), fastmul_ring.int_hom().map(111), 5, 4, Global);
482 let mut values: [_; 5] = std::array::from_fn(|i| ring.int_hom().map([1, 3, 2, 0, 7][i]));
483 fft.fft(&mut values, ring);
484 let expected: [_; 5] = std::array::from_fn(|i| ring.int_hom().map([13, 137, 202, 206, 170][i]));
485 for i in 0..values.len() {
486 assert_el_eq!(ring, expected[i], values[i]);
487 }
488}
489
490#[test]
491fn test_inv_fft_base() {
492 let ring = Zn::<241>::RING;
493 let fft = BluesteinFFT::new(ring, ring.int_hom().map(36), ring.int_hom().map(111), 5, 4, Global);
495 let values = [1, 3, 2, 0, 7];
496 let mut work = values;
497 let mut buffer = [0; 16];
498 fft.fft_base::<_, _, false>(&mut work, &mut buffer);
499 fft.fft_base::<_, _, true>(&mut work, &mut buffer);
500 assert_eq!(values, work);
501}
502
503#[test]
504fn test_approximate_fft() {
505 let CC = Complex64::RING;
506 for (p, _log2_m) in [(5, 4), (53, 7), (1009, 11)] {
507 let fft = BluesteinFFT::for_complex(&CC, p, Global);
508 let mut array = (0..p).map(|i| CC.root_of_unity(i as i64, p as i64)).collect::<Vec<_>>();
509 fft.fft(&mut array, CC);
510 let err = fft.expected_absolute_error(1., 0.);
511 assert!(CC.is_absolute_approx_eq(array[0], CC.zero(), err));
512 assert!(CC.is_absolute_approx_eq(array[1], CC.from_f64(fft.len() as f64), err));
513 for i in 2..fft.len() {
514 assert!(CC.is_absolute_approx_eq(array[i], CC.zero(), err));
515 }
516 }
517}
518
519#[cfg(test)]
520const BENCH_SIZE: usize = 1009;
521
522#[bench]
523fn bench_bluestein(bencher: &mut test::Bencher) {
524 let ring = zn_64::Zn::new(18597889);
525 let fastmul_ring = zn_64::ZnFastmul::new(ring).unwrap();
526 let embedding = ring.can_hom(&fastmul_ring).unwrap();
527 let ring_as_field = ring.as_field().ok().unwrap();
528 let root_of_unity = fastmul_ring.coerce(&ring, ring_as_field.get_ring().unwrap_element(algorithms::unity_root::get_prim_root_of_unity(&ring_as_field, 2 * BENCH_SIZE).unwrap()));
529 let fastmul_ring_as_field = fastmul_ring.as_field().ok().unwrap();
530 let fft = BluesteinFFT::new_with_hom(
531 embedding.clone(),
532 root_of_unity,
533 fastmul_ring_as_field.get_ring().unwrap_element(algorithms::unity_root::get_prim_root_of_unity_pow2(&fastmul_ring_as_field, 11).unwrap()),
534 BENCH_SIZE,
535 11,
536 Global
537 );
538 let data = (0..BENCH_SIZE).map(|i| ring.int_hom().map(i as i32)).collect::<Vec<_>>();
539 let mut copy = Vec::with_capacity(BENCH_SIZE);
540 bencher.iter(|| {
541 copy.clear();
542 copy.extend(data.iter().map(|x| ring.clone_el(x)));
543 fft.unordered_fft(std::hint::black_box(&mut copy[..]), &ring);
544 fft.unordered_inv_fft(std::hint::black_box(&mut copy[..]), &ring);
545 assert_el_eq!(ring, copy[0], data[0]);
546 });
547}