1use std::alloc::Allocator;
2use std::alloc::Global;
3use std::fmt::Debug;
4
5use crate::algorithms::fft::FFTAlgorithm;
6use crate::algorithms::unity_root::*;
7use crate::divisibility::{DivisibilityRing, DivisibilityRingStore};
8use crate::algorithms::fft::cooley_tuckey::CooleyTuckeyFFT;
9use crate::integer::IntegerRingStore;
10use crate::primitive_int::*;
11use crate::ring::*;
12use crate::homomorphism::*;
13use crate::rings::zn::*;
14use crate::rings::float_complex::*;
15use crate::algorithms::fft::complex_fft::*;
16use crate::seq::SwappableVectorViewMut;
17
/// The Cooley-Tukey FFT table used by [`BluesteinFFT`] to compute the cyclic convolution
/// of length `m` to which the length-`n` DFT is reduced.
type BaseFFT<R_main, R_twiddle, H, A> = CooleyTuckeyFFT<R_main, R_twiddle, H, A>;
19
/// DFT of odd length `n` via Bluestein's algorithm, i.e. by reducing it to a cyclic
/// convolution of length `m >= 2n` that is computed with a Cooley-Tukey FFT.
///
/// Twiddle factors live in the ring `R_twiddle` and are mapped into the main ring `R_main`
/// (the ring the input values live in) via the homomorphism `H`.
pub struct BluesteinFFT<R_main, R_twiddle, H, A = Global>
    where R_main: ?Sized + RingBase,
        R_twiddle: ?Sized + RingBase + DivisibilityRing,
        H: Homomorphism<R_twiddle, R_main> + Clone,
        A: Allocator + Clone
{
    // FFT table of length `m` used for the convolution; also stores the homomorphism and allocator
    m_fft_table: BaseFFT<R_main, R_twiddle, H, A>,
    // the chirp `z^(i^2 / 2)` for `i < n` (with `1/2` taken modulo `n`), zero-padded to length `m`
    // and transformed with `unordered_fft()` of `m_fft_table`
    b_unordered_fft: Vec<R_twiddle::Element>,
    // the inverse chirp `z^(-i^2 / 2)` for `i < n`, multiplied onto inputs and outputs
    twiddles: Vec<R_twiddle::Element>,
    // image of the primitive `n`-th root of unity `z` in the main ring
    root_of_unity_n: R_main::Element,
    // length of the DFT; required to be odd
    n: usize
}
46
47impl<H, A> BluesteinFFT<Complex64Base, Complex64Base, H, A>
48 where H: Homomorphism<Complex64Base, Complex64Base> + Clone,
49 A: Allocator + Clone
50{
51 pub fn for_complex_with_hom(hom: H, n: usize, tmp_mem_allocator: A) -> Self{
60 let ZZ = StaticRing::<i64>::RING;
61 let CC = Complex64::RING;
62 let n_i64: i64 = n.try_into().unwrap();
63 let log2_m = ZZ.abs_log2_ceil(&(2 * n_i64 + 1)).unwrap();
64 Self::new_with_pows_with_hom(hom, |x| CC.root_of_unity(x, 2 * n_i64), |x| CC.root_of_unity(x, 1 << log2_m), n, log2_m, tmp_mem_allocator)
65 }
66}
67
68impl<R, A> BluesteinFFT<Complex64Base, Complex64Base, Identity<R>, A>
69 where R: RingStore<Type = Complex64Base> + Clone,
70 A: Allocator + Clone
71{
72 pub fn for_complex(ring: R, n: usize, tmp_mem_allocator: A) -> Self{
76 Self::for_complex_with_hom(ring.into_identity(), n, tmp_mem_allocator)
77 }
78}
79
80impl<R, A> BluesteinFFT<R::Type, R::Type, Identity<R>, A>
81 where R: RingStore + Clone,
82 R::Type: DivisibilityRing,
83 A: Allocator + Clone
84{
85 pub fn new(ring: R, root_of_unity_2n: El<R>, root_of_unity_m: El<R>, n: usize, log2_m: usize, tmp_mem_allocator: A) -> Self {
95 Self::new_with_hom(ring.into_identity(), root_of_unity_2n, root_of_unity_m, n, log2_m, tmp_mem_allocator)
96 }
97
98 pub fn new_with_pows<F, G>(ring: R, root_of_unity_2n_pows: F, root_of_unity_m_pows: G, n: usize, log2_m: usize, tmp_mem_allocator: A) -> Self
107 where F: FnMut(i64) -> El<R>,
108 G: FnMut(i64) -> El<R>
109 {
110 Self::new_with_pows_with_hom(ring.into_identity(), root_of_unity_2n_pows, root_of_unity_m_pows, n, log2_m, tmp_mem_allocator)
111 }
112
113 pub fn for_zn(ring: R, n: usize, tmp_mem_allocator: A) -> Option<Self>
121 where R::Type: ZnRing
122 {
123 Self::for_zn_with_hom(ring.into_identity(), n, tmp_mem_allocator)
124 }
125}
126
127impl<R_main, R_twiddle, H, A> BluesteinFFT<R_main, R_twiddle, H, A>
128 where R_main: ?Sized + RingBase,
129 R_twiddle: ?Sized + RingBase + DivisibilityRing,
130 H: Homomorphism<R_twiddle, R_main> + Clone,
131 A: Allocator + Clone
132{
133 pub fn new_with_hom(hom: H, root_of_unity_2n: R_twiddle::Element, root_of_unity_m: R_twiddle::Element, n: usize, log2_m: usize, tmp_mem_allocator: A) -> Self {
148 let hom_copy = hom.clone();
149 let twiddle_ring = hom_copy.domain();
150 return Self::new_with_pows_with_hom(
151 hom,
152 |i: i64| if i >= 0 {
153 twiddle_ring.pow(twiddle_ring.clone_el(&root_of_unity_2n), i as usize % (2 * n))
154 } else {
155 twiddle_ring.invert(&twiddle_ring.pow(twiddle_ring.clone_el(&root_of_unity_2n), (-i) as usize % (2 * n))).unwrap()
156 },
157 |i: i64| if i >= 0 {
158 twiddle_ring.pow(twiddle_ring.clone_el(&root_of_unity_m), i as usize)
159 } else {
160 twiddle_ring.invert(&twiddle_ring.pow(twiddle_ring.clone_el(&root_of_unity_m), (-i) as usize)).unwrap()
161 },
162 n,
163 log2_m,
164 tmp_mem_allocator
165 );
166 }
167
168 pub fn new_with_pows_with_hom<F, G>(hom: H, mut root_of_unity_2n_pows: F, mut root_of_unity_m_pows: G, n: usize, log2_m: usize, tmp_mem_allocator: A) -> Self
182 where F: FnMut(i64) -> R_twiddle::Element,
183 G: FnMut(i64) -> R_twiddle::Element
184 {
185 let m_fft_table = CooleyTuckeyFFT::create(
186 hom,
187 &mut root_of_unity_m_pows,
188 log2_m,
189 tmp_mem_allocator
190 );
191 return Self::create(m_fft_table, |i| root_of_unity_2n_pows(2 * i), n);
192 }
193
194 pub fn for_zn_with_hom(hom: H, n: usize, tmp_mem_allocator: A) -> Option<Self>
207 where R_twiddle: ZnRing
208 {
209 let ring_as_field = hom.domain().as_field().ok().unwrap();
210 let root_of_unity_2n = ring_as_field.get_ring().unwrap_element(get_prim_root_of_unity(&ring_as_field, 2 * n)?);
211 let log2_m = StaticRing::<i64>::RING.abs_log2_ceil(&(n * 2).try_into().unwrap()).unwrap();
212 let root_of_unity_m = ring_as_field.get_ring().unwrap_element(get_prim_root_of_unity_pow2(&ring_as_field, log2_m)?);
213 return Some(Self::new_with_hom(hom, root_of_unity_2n, root_of_unity_m, n, log2_m, tmp_mem_allocator));
214 }
215
216 #[stability::unstable(feature = "enable")]
223 pub fn create<F>(m_fft_table: BaseFFT<R_main, R_twiddle, H, A>, mut root_of_unity_n_pows: F, n: usize) -> Self
224 where F: FnMut(i64) -> R_twiddle::Element
225 {
226 let hom = m_fft_table.hom().clone();
227 let m = m_fft_table.len();
228 assert!(m >= 2 * n);
229 assert!(n % 2 == 1);
230 assert!(hom.codomain().is_commutative());
231 assert!(hom.domain().get_ring().is_approximate() || is_prim_root_of_unity(hom.domain(), &root_of_unity_n_pows(1), n));
232 assert!(hom.codomain().get_ring().is_approximate() || is_prim_root_of_unity(hom.codomain(), &hom.map(root_of_unity_n_pows(1)), n));
233
234 let (twiddle_fft, old_hom) = m_fft_table.change_ring(hom.domain().identity());
235
236 let half_mod_n = (n + 1) / 2;
237 let mut b: Vec<_> = (0..n).map(|i| root_of_unity_n_pows(TryInto::<i64>::try_into(i * i * half_mod_n).unwrap())).collect();
238 b.resize_with(m, || hom.domain().zero());
239
240 twiddle_fft.unordered_fft(&mut b, hom.domain());
241
242 let twiddles = (0..n).map(|i| root_of_unity_n_pows(-TryInto::<i64>::try_into(i * i * half_mod_n).unwrap())).collect::<Vec<_>>();
243 let root_of_unity_n = hom.map(root_of_unity_n_pows(1));
244
245 return BluesteinFFT {
246 m_fft_table: twiddle_fft.change_ring(old_hom).0,
247 b_unordered_fft: b,
248 twiddles: twiddles,
249 root_of_unity_n: root_of_unity_n,
250 n: n
251 };
252 }
253
254 pub fn fft_base<V, W, const INV: bool>(&self, values: V, _buffer: W)
267 where V: SwappableVectorViewMut<R_main::Element>,
268 W: SwappableVectorViewMut<R_main::Element>
269 {
270 if INV {
271 self.unordered_inv_fft(values, self.ring());
272 } else {
273 self.unordered_fft(values, self.ring());
274 }
275 }
276
277 fn fft_base_impl<V, A2, const INV: bool>(&self, mut values: V, mut buffer: Vec<R_main::Element, A2>)
278 where V: SwappableVectorViewMut<R_main::Element>,
279 A2: Allocator
280 {
281 assert_eq!(values.len(), self.n);
282 assert_eq!(buffer.len(), self.m_fft_table.len());
283
284 let ring = self.m_fft_table.hom().codomain();
285
286 for i in 0..self.n {
288 let value = if INV {
289 values.at((self.n - i) % self.n)
290 } else {
291 values.at(i)
292 };
293 buffer[i] = self.hom().mul_ref_map(value, &self.twiddles[i]);
294 }
295 for i in self.n..self.m_fft_table.len() {
296 buffer[i] = ring.zero();
297 }
298
299 self.m_fft_table.unordered_truncated_fft(&mut buffer, self.n * 2);
300 for i in 0..self.m_fft_table.len() {
301 self.hom().mul_assign_ref_map(&mut buffer[i], &self.b_unordered_fft[i]);
302 }
303 self.m_fft_table.unordered_truncated_fft_inv(&mut buffer, self.n * 2);
304
305 let (buffer1, buffer2) = buffer[..(2 * self.n)].split_at_mut(self.n);
307 for (a, b) in buffer1.iter_mut().zip(buffer2.iter_mut()) {
308 ring.add_assign_ref(a, b);
309 }
310
311 for (i, x) in buffer.into_iter().enumerate().take(self.n) {
313 *values.at_mut(i) = self.hom().mul_ref_snd_map(x, &self.twiddles[i]);
314 }
315
316 if INV {
317 let scale = self.hom().map(self.hom().domain().checked_div(&self.hom().domain().one(), &self.hom().domain().int_hom().map(self.n.try_into().unwrap())).unwrap());
319 for i in 0..values.len() {
320 ring.mul_assign_ref(values.at_mut(i), &scale);
321 }
322 }
323 }
324
325 #[stability::unstable(feature = "enable")]
329 pub fn allocator(&self) -> &A {
330 self.m_fft_table.allocator()
331 }
332
333 #[stability::unstable(feature = "enable")]
337 pub fn ring<'a>(&'a self) -> &'a <H as Homomorphism<R_twiddle, R_main>>::CodomainStore {
338 self.hom().codomain()
339 }
340
341 #[stability::unstable(feature = "enable")]
346 pub fn hom(&self) -> &H {
347 self.m_fft_table.hom()
348 }
349}
350
351impl<R_main, R_twiddle, H, A> PartialEq for BluesteinFFT<R_main, R_twiddle, H, A>
352 where R_main: ?Sized + RingBase,
353 R_twiddle: ?Sized + RingBase + DivisibilityRing,
354 H: Homomorphism<R_twiddle, R_main> + Clone,
355 A: Allocator + Clone
356{
357 fn eq(&self, other: &Self) -> bool {
358 self.ring().get_ring() == other.ring().get_ring() &&
359 self.n == other.n &&
360 self.ring().eq_el(self.root_of_unity(self.ring()), other.root_of_unity(self.ring()))
361 }
362}
363
364impl<R_main, R_twiddle, H, A> Debug for BluesteinFFT<R_main, R_twiddle, H, A>
365 where R_main: ?Sized + RingBase + Debug,
366 R_twiddle: ?Sized + RingBase + DivisibilityRing,
367 H: Homomorphism<R_twiddle, R_main> + Clone,
368 A: Allocator + Clone
369{
370 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
371 f.debug_struct("BluesteinFFT")
372 .field("ring", &self.ring().get_ring())
373 .field("n", &self.n)
374 .field("root_of_unity_n", &self.ring().format(&self.root_of_unity(self.ring())))
375 .finish()
376 }
377}
378
379impl<R_main, R_twiddle, H, A> FFTAlgorithm<R_main> for BluesteinFFT<R_main, R_twiddle, H, A>
380 where R_main: ?Sized + RingBase,
381 R_twiddle: ?Sized + RingBase + DivisibilityRing,
382 H: Homomorphism<R_twiddle, R_main> + Clone,
383 A: Allocator + Clone
384{
385 fn len(&self) -> usize {
386 self.n
387 }
388
389 fn root_of_unity<S: RingStore<Type = R_main> + Copy>(&self, ring: S) -> &R_main::Element {
390 assert!(self.ring().get_ring() == ring.get_ring(), "unsupported ring");
391 &self.root_of_unity_n
392 }
393
394 fn unordered_fft_permutation(&self, i: usize) -> usize {
395 i
396 }
397
398 fn unordered_fft_permutation_inv(&self, i: usize) -> usize {
399 i
400 }
401
402 fn unordered_fft<V, S>(&self, values: V, ring: S)
403 where V: SwappableVectorViewMut<<R_main as RingBase>::Element>,
404 S: RingStore<Type = R_main> + Copy
405 {
406 assert!(ring.get_ring() == self.ring().get_ring(), "unsupported ring");
407 let mut buffer = Vec::with_capacity_in(self.m_fft_table.len(), self.allocator().clone());
408 buffer.extend((0..self.m_fft_table.len()).map(|_| self.ring().zero()));
409 self.fft_base_impl::<_, _, false>(values, buffer);
410 }
411
412 fn unordered_inv_fft<V, S>(&self, values: V, ring: S)
413 where V: SwappableVectorViewMut<<R_main as RingBase>::Element>,
414 S: RingStore<Type = R_main> + Copy
415 {
416 assert!(ring.get_ring() == self.ring().get_ring(), "unsupported ring");
417 let mut buffer = Vec::with_capacity_in(self.m_fft_table.len(), self.allocator().clone());
418 buffer.extend((0..self.m_fft_table.len()).map(|_| self.ring().zero()));
419 self.fft_base_impl::<_, _, true>(values, buffer);
420 }
421
422 fn fft<V, S>(&self, values: V, ring: S)
423 where V: SwappableVectorViewMut<<R_main as RingBase>::Element>,
424 S: RingStore<Type = R_main> + Copy
425 {
426 self.unordered_fft(values, ring);
427 }
428
429 fn inv_fft<V, S>(&self, values: V, ring: S)
430 where V: SwappableVectorViewMut<<R_main as RingBase>::Element>,
431 S: RingStore<Type = R_main> + Copy
432 {
433 self.unordered_inv_fft(values, ring);
434 }
435}
436
437impl<H, A> FFTErrorEstimate for BluesteinFFT<Complex64Base, Complex64Base, H, A>
438 where H: Homomorphism<Complex64Base, Complex64Base> + Clone,
439 A: Allocator + Clone
440{
441 fn expected_absolute_error(&self, input_bound: f64, input_error: f64) -> f64 {
442 let error_after_twiddling = input_error + input_bound * (root_of_unity_error() + f64::EPSILON);
443 let error_after_fft = self.m_fft_table.expected_absolute_error(input_bound, error_after_twiddling);
444 let b_bitreverse_fft_error = self.m_fft_table.expected_absolute_error(1., root_of_unity_error());
445 let new_input_bound = input_bound * self.m_fft_table.len() as f64;
447 let b_bitreverse_fft_bound = self.m_fft_table.len() as f64;
448 let error_after_mul = new_input_bound * b_bitreverse_fft_error + b_bitreverse_fft_bound * error_after_fft + f64::EPSILON * new_input_bound * b_bitreverse_fft_bound;
449 let error_after_inv_fft = self.m_fft_table.expected_absolute_error(new_input_bound * b_bitreverse_fft_bound, error_after_mul) / self.m_fft_table.len() as f64;
450 let error_end = error_after_inv_fft + new_input_bound * (root_of_unity_error() + f64::EPSILON);
451 return error_end;
452 }
453}
454
455#[cfg(test)]
456use crate::rings::zn::zn_static::*;
457
458#[test]
459fn test_fft_base() {
460 let ring = Zn::<241>::RING;
461 let fft = BluesteinFFT::new(ring, ring.int_hom().map(36), ring.int_hom().map(111), 5, 4, Global);
463 let mut values = [1, 3, 2, 0, 7];
464 let mut buffer = [0; 16];
465 fft.fft_base::<_, _, false>(&mut values, &mut buffer);
466 let expected = [13, 137, 202, 206, 170];
467 assert_eq!(expected, values);
468}
469
470#[test]
471fn test_fft_fastmul() {
472 let ring = zn_64::Zn::new(241);
473 let fastmul_ring = zn_64::ZnFastmul::new(ring).unwrap();
474 let fft = BluesteinFFT::new_with_hom(ring.can_hom(&fastmul_ring).unwrap(), fastmul_ring.int_hom().map(36), fastmul_ring.int_hom().map(111), 5, 4, Global);
475 let mut values: [_; 5] = std::array::from_fn(|i| ring.int_hom().map([1, 3, 2, 0, 7][i]));
476 fft.fft(&mut values, ring);
477 let expected: [_; 5] = std::array::from_fn(|i| ring.int_hom().map([13, 137, 202, 206, 170][i]));
478 for i in 0..values.len() {
479 assert_el_eq!(ring, expected[i], values[i]);
480 }
481}
482
483#[test]
484fn test_inv_fft_base() {
485 let ring = Zn::<241>::RING;
486 let fft = BluesteinFFT::new(ring, ring.int_hom().map(36), ring.int_hom().map(111), 5, 4, Global);
488 let values = [1, 3, 2, 0, 7];
489 let mut work = values;
490 let mut buffer = [0; 16];
491 fft.fft_base::<_, _, false>(&mut work, &mut buffer);
492 fft.fft_base::<_, _, true>(&mut work, &mut buffer);
493 assert_eq!(values, work);
494}
495
496#[test]
497fn test_approximate_fft() {
498 let CC = Complex64::RING;
499 for (p, _log2_m) in [(5, 4), (53, 7), (1009, 11)] {
500 let fft = BluesteinFFT::for_complex(&CC, p, Global);
501 let mut array = (0..p).map(|i| CC.root_of_unity(i.try_into().unwrap(), p.try_into().unwrap())).collect::<Vec<_>>();
502 fft.fft(&mut array, CC);
503 let err = fft.expected_absolute_error(1., 0.);
504 assert!(CC.is_absolute_approx_eq(array[0], CC.zero(), err));
505 assert!(CC.is_absolute_approx_eq(array[1], CC.from_f64(fft.len() as f64), err));
506 for i in 2..fft.len() {
507 assert!(CC.is_absolute_approx_eq(array[i], CC.zero(), err));
508 }
509 }
510}
511
/// Length of the DFT computed in `bench_bluestein`.
#[cfg(test)]
const BENCH_SIZE: usize = 1009;
514
515#[bench]
516fn bench_bluestein(bencher: &mut test::Bencher) {
517 let ring = zn_64::Zn::new(18597889);
518 let fastmul_ring = zn_64::ZnFastmul::new(ring).unwrap();
519 let embedding = ring.can_hom(&fastmul_ring).unwrap();
520 let ring_as_field = ring.as_field().ok().unwrap();
521 let root_of_unity = fastmul_ring.coerce(&ring, ring_as_field.get_ring().unwrap_element(get_prim_root_of_unity(&ring_as_field, 2 * BENCH_SIZE).unwrap()));
522 let fastmul_ring_as_field = fastmul_ring.as_field().ok().unwrap();
523 let fft = BluesteinFFT::new_with_hom(
524 embedding.clone(),
525 root_of_unity,
526 fastmul_ring_as_field.get_ring().unwrap_element(get_prim_root_of_unity_pow2(&fastmul_ring_as_field, 11).unwrap()),
527 BENCH_SIZE,
528 11,
529 Global
530 );
531 let data = (0..BENCH_SIZE).map(|i| ring.int_hom().map(i as i32)).collect::<Vec<_>>();
532 let mut copy = Vec::with_capacity(BENCH_SIZE);
533 bencher.iter(|| {
534 copy.clear();
535 copy.extend(data.iter().map(|x| ring.clone_el(x)));
536 fft.unordered_fft(std::hint::black_box(&mut copy[..]), &ring);
537 fft.unordered_inv_fft(std::hint::black_box(&mut copy[..]), &ring);
538 assert_el_eq!(ring, copy[0], data[0]);
539 });
540}