1use std::alloc::{Allocator, Global};
2use std::fmt::Debug;
3
4use crate::algorithms::fft::FFTAlgorithm;
5use crate::algorithms::fft::complex_fft::*;
6use crate::algorithms::fft::cooley_tuckey::CooleyTuckeyFFT;
7use crate::algorithms::unity_root::*;
8use crate::divisibility::{DivisibilityRing, DivisibilityRingStore};
9use crate::homomorphism::*;
10use crate::integer::IntegerRingStore;
11use crate::primitive_int::*;
12use crate::ring::*;
13use crate::rings::float_complex::*;
14use crate::rings::zn::*;
15use crate::seq::SwappableVectorViewMut;
16
17type BaseFFT<R_main, R_twiddle, H, A> = CooleyTuckeyFFT<R_main, R_twiddle, H, A>;
18
19pub struct BluesteinFFT<R_main, R_twiddle, H, A = Global>
32where
33 R_main: ?Sized + RingBase,
34 R_twiddle: ?Sized + RingBase + DivisibilityRing,
35 H: Homomorphism<R_twiddle, R_main> + Clone,
36 A: Allocator + Clone,
37{
38 m_fft_table: BaseFFT<R_main, R_twiddle, H, A>,
39 b_unordered_fft: Vec<R_twiddle::Element>,
40 twiddles: Vec<R_twiddle::Element>,
41 root_of_unity_n: R_main::Element,
42 n: usize,
43}
44
45impl<H, A> BluesteinFFT<Complex64Base, Complex64Base, H, A>
46where
47 H: Homomorphism<Complex64Base, Complex64Base> + Clone,
48 A: Allocator + Clone,
49{
50 pub fn for_complex_with_hom(hom: H, n: usize, tmp_mem_allocator: A) -> Self {
57 let ZZ = StaticRing::<i64>::RING;
58 let CC = Complex64::RING;
59 let n_i64: i64 = n.try_into().unwrap();
60 let log2_m = ZZ.abs_log2_ceil(&(2 * n_i64 + 1)).unwrap();
61 Self::new_with_pows_with_hom(
62 hom,
63 |x| CC.root_of_unity(x, 2 * n_i64),
64 |x| CC.root_of_unity(x, 1 << log2_m),
65 n,
66 log2_m,
67 tmp_mem_allocator,
68 )
69 }
70}
71
72impl<R, A> BluesteinFFT<Complex64Base, Complex64Base, Identity<R>, A>
73where
74 R: RingStore<Type = Complex64Base> + Clone,
75 A: Allocator + Clone,
76{
77 pub fn for_complex(ring: R, n: usize, tmp_mem_allocator: A) -> Self {
79 Self::for_complex_with_hom(ring.into_identity(), n, tmp_mem_allocator)
80 }
81}
82
83impl<R, A> BluesteinFFT<R::Type, R::Type, Identity<R>, A>
84where
85 R: RingStore + Clone,
86 R::Type: DivisibilityRing,
87 A: Allocator + Clone,
88{
89 pub fn new(
97 ring: R,
98 root_of_unity_2n: El<R>,
99 root_of_unity_m: El<R>,
100 n: usize,
101 log2_m: usize,
102 tmp_mem_allocator: A,
103 ) -> Self {
104 Self::new_with_hom(
105 ring.into_identity(),
106 root_of_unity_2n,
107 root_of_unity_m,
108 n,
109 log2_m,
110 tmp_mem_allocator,
111 )
112 }
113
114 pub fn new_with_pows<F, G>(
121 ring: R,
122 root_of_unity_2n_pows: F,
123 root_of_unity_m_pows: G,
124 n: usize,
125 log2_m: usize,
126 tmp_mem_allocator: A,
127 ) -> Self
128 where
129 F: FnMut(i64) -> El<R>,
130 G: FnMut(i64) -> El<R>,
131 {
132 Self::new_with_pows_with_hom(
133 ring.into_identity(),
134 root_of_unity_2n_pows,
135 root_of_unity_m_pows,
136 n,
137 log2_m,
138 tmp_mem_allocator,
139 )
140 }
141
142 pub fn for_zn(ring: R, n: usize, tmp_mem_allocator: A) -> Option<Self>
148 where
149 R::Type: ZnRing,
150 {
151 Self::for_zn_with_hom(ring.into_identity(), n, tmp_mem_allocator)
152 }
153}
154
155impl<R_main, R_twiddle, H, A> BluesteinFFT<R_main, R_twiddle, H, A>
156where
157 R_main: ?Sized + RingBase,
158 R_twiddle: ?Sized + RingBase + DivisibilityRing,
159 H: Homomorphism<R_twiddle, R_main> + Clone,
160 A: Allocator + Clone,
161{
162 pub fn new_with_hom(
175 hom: H,
176 root_of_unity_2n: R_twiddle::Element,
177 root_of_unity_m: R_twiddle::Element,
178 n: usize,
179 log2_m: usize,
180 tmp_mem_allocator: A,
181 ) -> Self {
182 let hom_copy = hom.clone();
183 let twiddle_ring = hom_copy.domain();
184 return Self::new_with_pows_with_hom(
185 hom,
186 |i: i64| {
187 if i >= 0 {
188 twiddle_ring.pow(twiddle_ring.clone_el(&root_of_unity_2n), i as usize % (2 * n))
189 } else {
190 twiddle_ring
191 .invert(&twiddle_ring.pow(twiddle_ring.clone_el(&root_of_unity_2n), (-i) as usize % (2 * n)))
192 .unwrap()
193 }
194 },
195 |i: i64| {
196 if i >= 0 {
197 twiddle_ring.pow(twiddle_ring.clone_el(&root_of_unity_m), i as usize)
198 } else {
199 twiddle_ring
200 .invert(&twiddle_ring.pow(twiddle_ring.clone_el(&root_of_unity_m), (-i) as usize))
201 .unwrap()
202 }
203 },
204 n,
205 log2_m,
206 tmp_mem_allocator,
207 );
208 }
209
210 pub fn new_with_pows_with_hom<F, G>(
222 hom: H,
223 mut root_of_unity_2n_pows: F,
224 mut root_of_unity_m_pows: G,
225 n: usize,
226 log2_m: usize,
227 tmp_mem_allocator: A,
228 ) -> Self
229 where
230 F: FnMut(i64) -> R_twiddle::Element,
231 G: FnMut(i64) -> R_twiddle::Element,
232 {
233 let m_fft_table = CooleyTuckeyFFT::create(hom, &mut root_of_unity_m_pows, log2_m, tmp_mem_allocator);
234 return Self::create(m_fft_table, |i| root_of_unity_2n_pows(2 * i), n);
235 }
236
237 pub fn for_zn_with_hom(hom: H, n: usize, tmp_mem_allocator: A) -> Option<Self>
248 where
249 R_twiddle: ZnRing,
250 {
251 let root_of_unity_2n = get_prim_root_of_unity_zn(hom.domain(), 2 * n)?;
252 let log2_m = StaticRing::<i64>::RING
253 .abs_log2_ceil(&(n * 2).try_into().unwrap())
254 .unwrap();
255 let root_of_unity_m = get_prim_root_of_unity_zn(hom.domain(), 1 << log2_m)?;
256 return Some(Self::new_with_hom(
257 hom,
258 root_of_unity_2n,
259 root_of_unity_m,
260 n,
261 log2_m,
262 tmp_mem_allocator,
263 ));
264 }
265
266 #[stability::unstable(feature = "enable")]
272 pub fn create<F>(m_fft_table: BaseFFT<R_main, R_twiddle, H, A>, mut root_of_unity_n_pows: F, n: usize) -> Self
273 where
274 F: FnMut(i64) -> R_twiddle::Element,
275 {
276 let hom = m_fft_table.hom().clone();
277 let m = m_fft_table.len();
278 assert!(m >= 2 * n);
279 assert!(n % 2 == 1);
280 assert!(hom.codomain().is_commutative());
281 assert!(
282 hom.domain().get_ring().is_approximate()
283 || is_prim_root_of_unity(hom.domain(), &root_of_unity_n_pows(1), n)
284 );
285 assert!(
286 hom.codomain().get_ring().is_approximate()
287 || is_prim_root_of_unity(hom.codomain(), &hom.map(root_of_unity_n_pows(1)), n)
288 );
289
290 let (twiddle_fft, old_hom) = m_fft_table.change_ring(hom.domain().identity());
291
292 let half_mod_n = n.div_ceil(2);
293 let mut b: Vec<_> = (0..n)
294 .map(|i| root_of_unity_n_pows(TryInto::<i64>::try_into(i * i * half_mod_n).unwrap()))
295 .collect();
296 b.resize_with(m, || hom.domain().zero());
297
298 twiddle_fft.unordered_fft(&mut b, hom.domain());
299
300 let twiddles = (0..n)
301 .map(|i| root_of_unity_n_pows(-TryInto::<i64>::try_into(i * i * half_mod_n).unwrap()))
302 .collect::<Vec<_>>();
303 let root_of_unity_n = hom.map(root_of_unity_n_pows(1));
304
305 return BluesteinFFT {
306 m_fft_table: twiddle_fft.change_ring(old_hom).0,
307 b_unordered_fft: b,
308 twiddles,
309 root_of_unity_n,
310 n,
311 };
312 }
313
314 pub fn fft_base<V, W, const INV: bool>(&self, values: V, _buffer: W)
325 where
326 V: SwappableVectorViewMut<R_main::Element>,
327 W: SwappableVectorViewMut<R_main::Element>,
328 {
329 if INV {
330 self.unordered_inv_fft(values, self.ring());
331 } else {
332 self.unordered_fft(values, self.ring());
333 }
334 }
335
336 fn fft_base_impl<V, A2, const INV: bool>(&self, mut values: V, mut buffer: Vec<R_main::Element, A2>)
337 where
338 V: SwappableVectorViewMut<R_main::Element>,
339 A2: Allocator,
340 {
341 assert_eq!(values.len(), self.n);
342 assert_eq!(buffer.len(), self.m_fft_table.len());
343
344 let ring = self.m_fft_table.hom().codomain();
345
346 for i in 0..self.n {
348 let value = if INV {
349 values.at((self.n - i) % self.n)
350 } else {
351 values.at(i)
352 };
353 buffer[i] = self.hom().mul_ref_map(value, &self.twiddles[i]);
354 }
355 for i in self.n..self.m_fft_table.len() {
356 buffer[i] = ring.zero();
357 }
358
359 self.m_fft_table.unordered_truncated_fft(&mut buffer, self.n * 2);
360 for i in 0..self.m_fft_table.len() {
361 self.hom().mul_assign_ref_map(&mut buffer[i], &self.b_unordered_fft[i]);
362 }
363 self.m_fft_table.unordered_truncated_fft_inv(&mut buffer, self.n * 2);
364
365 let (buffer1, buffer2) = buffer[..(2 * self.n)].split_at_mut(self.n);
368 for (a, b) in buffer1.iter_mut().zip(buffer2.iter_mut()) {
369 ring.add_assign_ref(a, b);
370 }
371
372 for (i, x) in buffer.into_iter().enumerate().take(self.n) {
374 *values.at_mut(i) = self.hom().mul_ref_snd_map(x, &self.twiddles[i]);
375 }
376
377 if INV {
378 let scale = self.hom().map(
380 self.hom()
381 .domain()
382 .checked_div(
383 &self.hom().domain().one(),
384 &self.hom().domain().int_hom().map(self.n.try_into().unwrap()),
385 )
386 .unwrap(),
387 );
388 for i in 0..values.len() {
389 ring.mul_assign_ref(values.at_mut(i), &scale);
390 }
391 }
392 }
393
394 #[stability::unstable(feature = "enable")]
396 pub fn allocator(&self) -> &A { self.m_fft_table.allocator() }
397
398 #[stability::unstable(feature = "enable")]
400 pub fn ring<'a>(&'a self) -> &'a <H as Homomorphism<R_twiddle, R_main>>::CodomainStore { self.hom().codomain() }
401
402 #[stability::unstable(feature = "enable")]
405 pub fn hom(&self) -> &H { self.m_fft_table.hom() }
406}
407
408impl<R_main, R_twiddle, H, A> PartialEq for BluesteinFFT<R_main, R_twiddle, H, A>
409where
410 R_main: ?Sized + RingBase,
411 R_twiddle: ?Sized + RingBase + DivisibilityRing,
412 H: Homomorphism<R_twiddle, R_main> + Clone,
413 A: Allocator + Clone,
414{
415 fn eq(&self, other: &Self) -> bool {
416 self.ring().get_ring() == other.ring().get_ring()
417 && self.n == other.n
418 && self
419 .ring()
420 .eq_el(self.root_of_unity(self.ring()), other.root_of_unity(self.ring()))
421 }
422}
423
424impl<R_main, R_twiddle, H, A> Debug for BluesteinFFT<R_main, R_twiddle, H, A>
425where
426 R_main: ?Sized + RingBase + Debug,
427 R_twiddle: ?Sized + RingBase + DivisibilityRing,
428 H: Homomorphism<R_twiddle, R_main> + Clone,
429 A: Allocator + Clone,
430{
431 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
432 f.debug_struct("BluesteinFFT")
433 .field("ring", &self.ring().get_ring())
434 .field("n", &self.n)
435 .field("root_of_unity_n", &self.ring().format(self.root_of_unity(self.ring())))
436 .finish()
437 }
438}
439
440impl<R_main, R_twiddle, H, A> FFTAlgorithm<R_main> for BluesteinFFT<R_main, R_twiddle, H, A>
441where
442 R_main: ?Sized + RingBase,
443 R_twiddle: ?Sized + RingBase + DivisibilityRing,
444 H: Homomorphism<R_twiddle, R_main> + Clone,
445 A: Allocator + Clone,
446{
447 fn len(&self) -> usize { self.n }
448
449 fn root_of_unity<S: RingStore<Type = R_main> + Copy>(&self, ring: S) -> &R_main::Element {
450 assert!(self.ring().get_ring() == ring.get_ring(), "unsupported ring");
451 &self.root_of_unity_n
452 }
453
454 fn unordered_fft_permutation(&self, i: usize) -> usize { i }
455
456 fn unordered_fft_permutation_inv(&self, i: usize) -> usize { i }
457
458 fn unordered_fft<V, S>(&self, values: V, ring: S)
459 where
460 V: SwappableVectorViewMut<<R_main as RingBase>::Element>,
461 S: RingStore<Type = R_main> + Copy,
462 {
463 assert!(ring.get_ring() == self.ring().get_ring(), "unsupported ring");
464 let mut buffer = Vec::with_capacity_in(self.m_fft_table.len(), self.allocator().clone());
465 buffer.extend((0..self.m_fft_table.len()).map(|_| self.ring().zero()));
466 self.fft_base_impl::<_, _, false>(values, buffer);
467 }
468
469 fn unordered_inv_fft<V, S>(&self, values: V, ring: S)
470 where
471 V: SwappableVectorViewMut<<R_main as RingBase>::Element>,
472 S: RingStore<Type = R_main> + Copy,
473 {
474 assert!(ring.get_ring() == self.ring().get_ring(), "unsupported ring");
475 let mut buffer = Vec::with_capacity_in(self.m_fft_table.len(), self.allocator().clone());
476 buffer.extend((0..self.m_fft_table.len()).map(|_| self.ring().zero()));
477 self.fft_base_impl::<_, _, true>(values, buffer);
478 }
479
480 fn fft<V, S>(&self, values: V, ring: S)
481 where
482 V: SwappableVectorViewMut<<R_main as RingBase>::Element>,
483 S: RingStore<Type = R_main> + Copy,
484 {
485 self.unordered_fft(values, ring);
486 }
487
488 fn inv_fft<V, S>(&self, values: V, ring: S)
489 where
490 V: SwappableVectorViewMut<<R_main as RingBase>::Element>,
491 S: RingStore<Type = R_main> + Copy,
492 {
493 self.unordered_inv_fft(values, ring);
494 }
495}
496
497impl<H, A> FFTErrorEstimate for BluesteinFFT<Complex64Base, Complex64Base, H, A>
498where
499 H: Homomorphism<Complex64Base, Complex64Base> + Clone,
500 A: Allocator + Clone,
501{
502 fn expected_absolute_error(&self, input_bound: f64, input_error: f64) -> f64 {
503 let error_after_twiddling = input_error + input_bound * (root_of_unity_error() + f64::EPSILON);
504 let error_after_fft = self
505 .m_fft_table
506 .expected_absolute_error(input_bound, error_after_twiddling);
507 let b_bitreverse_fft_error = self.m_fft_table.expected_absolute_error(1.0, root_of_unity_error());
508 let new_input_bound = input_bound * self.m_fft_table.len() as f64;
510 let b_bitreverse_fft_bound = self.m_fft_table.len() as f64;
511 let error_after_mul = new_input_bound * b_bitreverse_fft_error
512 + b_bitreverse_fft_bound * error_after_fft
513 + f64::EPSILON * new_input_bound * b_bitreverse_fft_bound;
514 let error_after_inv_fft = self
515 .m_fft_table
516 .expected_absolute_error(new_input_bound * b_bitreverse_fft_bound, error_after_mul)
517 / self.m_fft_table.len() as f64;
518 let error_end = error_after_inv_fft + new_input_bound * (root_of_unity_error() + f64::EPSILON);
519 return error_end;
520 }
521}
522
523#[cfg(test)]
524use crate::rings::zn::zn_static::*;
525
526#[test]
527fn test_fft_base() {
528 let ring = Zn::<241>::RING;
529 let fft = BluesteinFFT::new(ring, ring.int_hom().map(36), ring.int_hom().map(111), 5, 4, Global);
531 let mut values = [1, 3, 2, 0, 7];
532 let mut buffer = [0; 16];
533 fft.fft_base::<_, _, false>(&mut values, &mut buffer);
534 let expected = [13, 137, 202, 206, 170];
535 assert_eq!(expected, values);
536}
537
538#[test]
539fn test_fft_fastmul() {
540 let ring = zn_64::Zn::new(241);
541 let fastmul_ring = zn_64::ZnFastmul::new(ring).unwrap();
542 let fft = BluesteinFFT::new_with_hom(
543 ring.can_hom(&fastmul_ring).unwrap(),
544 fastmul_ring.int_hom().map(36),
545 fastmul_ring.int_hom().map(111),
546 5,
547 4,
548 Global,
549 );
550 let mut values: [_; 5] = std::array::from_fn(|i| ring.int_hom().map([1, 3, 2, 0, 7][i]));
551 fft.fft(&mut values, ring);
552 let expected: [_; 5] = std::array::from_fn(|i| ring.int_hom().map([13, 137, 202, 206, 170][i]));
553 for i in 0..values.len() {
554 assert_el_eq!(ring, expected[i], values[i]);
555 }
556}
557
558#[test]
559fn test_inv_fft_base() {
560 let ring = Zn::<241>::RING;
561 let fft = BluesteinFFT::new(ring, ring.int_hom().map(36), ring.int_hom().map(111), 5, 4, Global);
563 let values = [1, 3, 2, 0, 7];
564 let mut work = values;
565 let mut buffer = [0; 16];
566 fft.fft_base::<_, _, false>(&mut work, &mut buffer);
567 fft.fft_base::<_, _, true>(&mut work, &mut buffer);
568 assert_eq!(values, work);
569}
570
571#[test]
572fn test_approximate_fft() {
573 let CC = Complex64::RING;
574 for (p, _log2_m) in [(5, 4), (53, 7), (1009, 11)] {
575 let fft = BluesteinFFT::for_complex(&CC, p, Global);
576 let mut array = (0..p)
577 .map(|i| CC.root_of_unity(i.try_into().unwrap(), p.try_into().unwrap()))
578 .collect::<Vec<_>>();
579 fft.fft(&mut array, CC);
580 let err = fft.expected_absolute_error(1.0, 0.0);
581 assert!(CC.is_absolute_approx_eq(array[0], CC.zero(), err));
582 assert!(CC.is_absolute_approx_eq(array[1], CC.from_f64(fft.len() as f64), err));
583 for i in 2..fft.len() {
584 assert!(CC.is_absolute_approx_eq(array[i], CC.zero(), err));
585 }
586 }
587}
588
589#[cfg(test)]
590const BENCH_SIZE: usize = 1009;
591
592#[bench]
593fn bench_bluestein(bencher: &mut test::Bencher) {
594 let ring = zn_64::Zn::new(18597889);
595 let fastmul_ring = zn_64::ZnFastmul::new(ring).unwrap();
596 let embedding = ring.can_hom(&fastmul_ring).unwrap();
597 let root_of_unity = fastmul_ring.coerce(&ring, get_prim_root_of_unity_zn(&ring, 2 * BENCH_SIZE).unwrap());
598 let fft = BluesteinFFT::new_with_hom(
599 embedding.clone(),
600 root_of_unity,
601 get_prim_root_of_unity_zn(&fastmul_ring, 1 << 11).unwrap(),
602 BENCH_SIZE,
603 11,
604 Global,
605 );
606 let data = (0..BENCH_SIZE)
607 .map(|i| ring.int_hom().map(i as i32))
608 .collect::<Vec<_>>();
609 let mut copy = Vec::with_capacity(BENCH_SIZE);
610 bencher.iter(|| {
611 copy.clear();
612 copy.extend(data.iter().map(|x| ring.clone_el(x)));
613 fft.unordered_fft(std::hint::black_box(&mut copy[..]), &ring);
614 fft.unordered_inv_fft(std::hint::black_box(&mut copy[..]), &ring);
615 assert_el_eq!(ring, copy[0], data[0]);
616 });
617}