1use std::alloc::Allocator;
2use std::alloc::Global;
3use std::fmt::Debug;
4
5use crate::algorithms::fft::FFTAlgorithm;
6use crate::algorithms::unity_root::*;
7use crate::divisibility::{DivisibilityRing, DivisibilityRingStore};
8use crate::algorithms::fft::cooley_tuckey::CooleyTuckeyFFT;
9use crate::integer::IntegerRingStore;
10use crate::primitive_int::*;
11use crate::ring::*;
12use crate::homomorphism::*;
13use crate::rings::zn::*;
14use crate::rings::float_complex::*;
15use crate::algorithms::fft::complex_fft::*;
16use crate::seq::SwappableVectorViewMut;
17
/// The power-of-two Cooley-Tukey FFT that [`BluesteinFFT`] uses internally to evaluate
/// the convolution of length `m >= 2n` on which Bluestein's algorithm is based.
type BaseFFT<R_main, R_twiddle, H, A> = CooleyTuckeyFFT<R_main, R_twiddle, H, A>;
19
/// Bluestein's algorithm for computing a Fourier transform of odd length `n`.
///
/// The transform is reduced to a convolution, which is evaluated with a power-of-two
/// Cooley-Tukey FFT of length `m >= 2n`. All arithmetic happens in the ring `R_main`
/// (the codomain of `H`). The precomputed twiddle factors are stored in `R_twiddle`
/// (the domain of `H`) and are mapped into `R_main` during multiplication.
pub struct BluesteinFFT<R_main, R_twiddle, H, A = Global>
    where R_main: ?Sized + RingBase,
        R_twiddle: ?Sized + RingBase + DivisibilityRing,
        H: Homomorphism<R_twiddle, R_main> + Clone,
        A: Allocator + Clone
{
    /// Power-of-two FFT of length `m >= 2n`, used to compute the convolution.
    m_fft_table: BaseFFT<R_main, R_twiddle, H, A>,
    /// The chirp `b[i] = z^(i^2 (n + 1) / 2)` for `i < n`, zero-padded to length `m`
    /// and transformed by `m_fft_table` (kept in its unordered output order).
    b_unordered_fft: Vec<R_twiddle::Element>,
    /// The inverse chirp factors `z^(-i^2 (n + 1) / 2)` for `i < n`.
    twiddles: Vec<R_twiddle::Element>,
    /// The `n`-th root of unity `z` of the transform, mapped into `R_main`.
    root_of_unity_n: R_main::Element,
    /// The (odd) length of the transform.
    n: usize
}
46
47impl<H, A> BluesteinFFT<Complex64Base, Complex64Base, H, A>
48 where H: Homomorphism<Complex64Base, Complex64Base> + Clone,
49 A: Allocator + Clone
50{
51 pub fn for_complex_with_hom(hom: H, n: usize, tmp_mem_allocator: A) -> Self{
60 let ZZ = StaticRing::<i64>::RING;
61 let CC = Complex64::RING;
62 let n_i64: i64 = n.try_into().unwrap();
63 let log2_m = ZZ.abs_log2_ceil(&(2 * n_i64 + 1)).unwrap();
64 Self::new_with_pows_with_hom(hom, |x| CC.root_of_unity(x, 2 * n_i64), |x| CC.root_of_unity(x, 1 << log2_m), n, log2_m, tmp_mem_allocator)
65 }
66}
67
68impl<R, A> BluesteinFFT<Complex64Base, Complex64Base, Identity<R>, A>
69 where R: RingStore<Type = Complex64Base> + Clone,
70 A: Allocator + Clone
71{
72 pub fn for_complex(ring: R, n: usize, tmp_mem_allocator: A) -> Self{
76 Self::for_complex_with_hom(ring.into_identity(), n, tmp_mem_allocator)
77 }
78}
79
80impl<R, A> BluesteinFFT<R::Type, R::Type, Identity<R>, A>
81 where R: RingStore + Clone,
82 R::Type: DivisibilityRing,
83 A: Allocator + Clone
84{
85 pub fn new(ring: R, root_of_unity_2n: El<R>, root_of_unity_m: El<R>, n: usize, log2_m: usize, tmp_mem_allocator: A) -> Self {
95 Self::new_with_hom(ring.into_identity(), root_of_unity_2n, root_of_unity_m, n, log2_m, tmp_mem_allocator)
96 }
97
98 pub fn new_with_pows<F, G>(ring: R, root_of_unity_2n_pows: F, root_of_unity_m_pows: G, n: usize, log2_m: usize, tmp_mem_allocator: A) -> Self
107 where F: FnMut(i64) -> El<R>,
108 G: FnMut(i64) -> El<R>
109 {
110 Self::new_with_pows_with_hom(ring.into_identity(), root_of_unity_2n_pows, root_of_unity_m_pows, n, log2_m, tmp_mem_allocator)
111 }
112
113 pub fn for_zn(ring: R, n: usize, tmp_mem_allocator: A) -> Option<Self>
121 where R::Type: ZnRing
122 {
123 Self::for_zn_with_hom(ring.into_identity(), n, tmp_mem_allocator)
124 }
125}
126
127impl<R_main, R_twiddle, H, A> BluesteinFFT<R_main, R_twiddle, H, A>
128 where R_main: ?Sized + RingBase,
129 R_twiddle: ?Sized + RingBase + DivisibilityRing,
130 H: Homomorphism<R_twiddle, R_main> + Clone,
131 A: Allocator + Clone
132{
133 pub fn new_with_hom(hom: H, root_of_unity_2n: R_twiddle::Element, root_of_unity_m: R_twiddle::Element, n: usize, log2_m: usize, tmp_mem_allocator: A) -> Self {
148 let hom_copy = hom.clone();
149 let twiddle_ring = hom_copy.domain();
150 return Self::new_with_pows_with_hom(
151 hom,
152 |i: i64| if i >= 0 {
153 twiddle_ring.pow(twiddle_ring.clone_el(&root_of_unity_2n), i as usize % (2 * n))
154 } else {
155 twiddle_ring.invert(&twiddle_ring.pow(twiddle_ring.clone_el(&root_of_unity_2n), (-i) as usize % (2 * n))).unwrap()
156 },
157 |i: i64| if i >= 0 {
158 twiddle_ring.pow(twiddle_ring.clone_el(&root_of_unity_m), i as usize)
159 } else {
160 twiddle_ring.invert(&twiddle_ring.pow(twiddle_ring.clone_el(&root_of_unity_m), (-i) as usize)).unwrap()
161 },
162 n,
163 log2_m,
164 tmp_mem_allocator
165 );
166 }
167
168 pub fn new_with_pows_with_hom<F, G>(hom: H, mut root_of_unity_2n_pows: F, mut root_of_unity_m_pows: G, n: usize, log2_m: usize, tmp_mem_allocator: A) -> Self
182 where F: FnMut(i64) -> R_twiddle::Element,
183 G: FnMut(i64) -> R_twiddle::Element
184 {
185 let m_fft_table = CooleyTuckeyFFT::create(
186 hom,
187 &mut root_of_unity_m_pows,
188 log2_m,
189 tmp_mem_allocator
190 );
191 return Self::create(m_fft_table, |i| root_of_unity_2n_pows(2 * i), n);
192 }
193
194 pub fn for_zn_with_hom(hom: H, n: usize, tmp_mem_allocator: A) -> Option<Self>
207 where R_twiddle: ZnRing
208 {
209 let root_of_unity_2n = get_prim_root_of_unity_zn(hom.domain(), 2 * n)?;
210 let log2_m = StaticRing::<i64>::RING.abs_log2_ceil(&(n * 2).try_into().unwrap()).unwrap();
211 let root_of_unity_m = get_prim_root_of_unity_zn(hom.domain(), 1 << log2_m)?;
212 return Some(Self::new_with_hom(hom, root_of_unity_2n, root_of_unity_m, n, log2_m, tmp_mem_allocator));
213 }
214
215 #[stability::unstable(feature = "enable")]
222 pub fn create<F>(m_fft_table: BaseFFT<R_main, R_twiddle, H, A>, mut root_of_unity_n_pows: F, n: usize) -> Self
223 where F: FnMut(i64) -> R_twiddle::Element
224 {
225 let hom = m_fft_table.hom().clone();
226 let m = m_fft_table.len();
227 assert!(m >= 2 * n);
228 assert!(n % 2 == 1);
229 assert!(hom.codomain().is_commutative());
230 assert!(hom.domain().get_ring().is_approximate() || is_prim_root_of_unity(hom.domain(), &root_of_unity_n_pows(1), n));
231 assert!(hom.codomain().get_ring().is_approximate() || is_prim_root_of_unity(hom.codomain(), &hom.map(root_of_unity_n_pows(1)), n));
232
233 let (twiddle_fft, old_hom) = m_fft_table.change_ring(hom.domain().identity());
234
235 let half_mod_n = (n + 1) / 2;
236 let mut b: Vec<_> = (0..n).map(|i| root_of_unity_n_pows(TryInto::<i64>::try_into(i * i * half_mod_n).unwrap())).collect();
237 b.resize_with(m, || hom.domain().zero());
238
239 twiddle_fft.unordered_fft(&mut b, hom.domain());
240
241 let twiddles = (0..n).map(|i| root_of_unity_n_pows(-TryInto::<i64>::try_into(i * i * half_mod_n).unwrap())).collect::<Vec<_>>();
242 let root_of_unity_n = hom.map(root_of_unity_n_pows(1));
243
244 return BluesteinFFT {
245 m_fft_table: twiddle_fft.change_ring(old_hom).0,
246 b_unordered_fft: b,
247 twiddles: twiddles,
248 root_of_unity_n: root_of_unity_n,
249 n: n
250 };
251 }
252
253 pub fn fft_base<V, W, const INV: bool>(&self, values: V, _buffer: W)
266 where V: SwappableVectorViewMut<R_main::Element>,
267 W: SwappableVectorViewMut<R_main::Element>
268 {
269 if INV {
270 self.unordered_inv_fft(values, self.ring());
271 } else {
272 self.unordered_fft(values, self.ring());
273 }
274 }
275
276 fn fft_base_impl<V, A2, const INV: bool>(&self, mut values: V, mut buffer: Vec<R_main::Element, A2>)
277 where V: SwappableVectorViewMut<R_main::Element>,
278 A2: Allocator
279 {
280 assert_eq!(values.len(), self.n);
281 assert_eq!(buffer.len(), self.m_fft_table.len());
282
283 let ring = self.m_fft_table.hom().codomain();
284
285 for i in 0..self.n {
287 let value = if INV {
288 values.at((self.n - i) % self.n)
289 } else {
290 values.at(i)
291 };
292 buffer[i] = self.hom().mul_ref_map(value, &self.twiddles[i]);
293 }
294 for i in self.n..self.m_fft_table.len() {
295 buffer[i] = ring.zero();
296 }
297
298 self.m_fft_table.unordered_truncated_fft(&mut buffer, self.n * 2);
299 for i in 0..self.m_fft_table.len() {
300 self.hom().mul_assign_ref_map(&mut buffer[i], &self.b_unordered_fft[i]);
301 }
302 self.m_fft_table.unordered_truncated_fft_inv(&mut buffer, self.n * 2);
303
304 let (buffer1, buffer2) = buffer[..(2 * self.n)].split_at_mut(self.n);
306 for (a, b) in buffer1.iter_mut().zip(buffer2.iter_mut()) {
307 ring.add_assign_ref(a, b);
308 }
309
310 for (i, x) in buffer.into_iter().enumerate().take(self.n) {
312 *values.at_mut(i) = self.hom().mul_ref_snd_map(x, &self.twiddles[i]);
313 }
314
315 if INV {
316 let scale = self.hom().map(self.hom().domain().checked_div(&self.hom().domain().one(), &self.hom().domain().int_hom().map(self.n.try_into().unwrap())).unwrap());
318 for i in 0..values.len() {
319 ring.mul_assign_ref(values.at_mut(i), &scale);
320 }
321 }
322 }
323
324 #[stability::unstable(feature = "enable")]
328 pub fn allocator(&self) -> &A {
329 self.m_fft_table.allocator()
330 }
331
332 #[stability::unstable(feature = "enable")]
336 pub fn ring<'a>(&'a self) -> &'a <H as Homomorphism<R_twiddle, R_main>>::CodomainStore {
337 self.hom().codomain()
338 }
339
340 #[stability::unstable(feature = "enable")]
345 pub fn hom(&self) -> &H {
346 self.m_fft_table.hom()
347 }
348}
349
350impl<R_main, R_twiddle, H, A> PartialEq for BluesteinFFT<R_main, R_twiddle, H, A>
351 where R_main: ?Sized + RingBase,
352 R_twiddle: ?Sized + RingBase + DivisibilityRing,
353 H: Homomorphism<R_twiddle, R_main> + Clone,
354 A: Allocator + Clone
355{
356 fn eq(&self, other: &Self) -> bool {
357 self.ring().get_ring() == other.ring().get_ring() &&
358 self.n == other.n &&
359 self.ring().eq_el(self.root_of_unity(self.ring()), other.root_of_unity(self.ring()))
360 }
361}
362
363impl<R_main, R_twiddle, H, A> Debug for BluesteinFFT<R_main, R_twiddle, H, A>
364 where R_main: ?Sized + RingBase + Debug,
365 R_twiddle: ?Sized + RingBase + DivisibilityRing,
366 H: Homomorphism<R_twiddle, R_main> + Clone,
367 A: Allocator + Clone
368{
369 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
370 f.debug_struct("BluesteinFFT")
371 .field("ring", &self.ring().get_ring())
372 .field("n", &self.n)
373 .field("root_of_unity_n", &self.ring().format(&self.root_of_unity(self.ring())))
374 .finish()
375 }
376}
377
378impl<R_main, R_twiddle, H, A> FFTAlgorithm<R_main> for BluesteinFFT<R_main, R_twiddle, H, A>
379 where R_main: ?Sized + RingBase,
380 R_twiddle: ?Sized + RingBase + DivisibilityRing,
381 H: Homomorphism<R_twiddle, R_main> + Clone,
382 A: Allocator + Clone
383{
384 fn len(&self) -> usize {
385 self.n
386 }
387
388 fn root_of_unity<S: RingStore<Type = R_main> + Copy>(&self, ring: S) -> &R_main::Element {
389 assert!(self.ring().get_ring() == ring.get_ring(), "unsupported ring");
390 &self.root_of_unity_n
391 }
392
393 fn unordered_fft_permutation(&self, i: usize) -> usize {
394 i
395 }
396
397 fn unordered_fft_permutation_inv(&self, i: usize) -> usize {
398 i
399 }
400
401 fn unordered_fft<V, S>(&self, values: V, ring: S)
402 where V: SwappableVectorViewMut<<R_main as RingBase>::Element>,
403 S: RingStore<Type = R_main> + Copy
404 {
405 assert!(ring.get_ring() == self.ring().get_ring(), "unsupported ring");
406 let mut buffer = Vec::with_capacity_in(self.m_fft_table.len(), self.allocator().clone());
407 buffer.extend((0..self.m_fft_table.len()).map(|_| self.ring().zero()));
408 self.fft_base_impl::<_, _, false>(values, buffer);
409 }
410
411 fn unordered_inv_fft<V, S>(&self, values: V, ring: S)
412 where V: SwappableVectorViewMut<<R_main as RingBase>::Element>,
413 S: RingStore<Type = R_main> + Copy
414 {
415 assert!(ring.get_ring() == self.ring().get_ring(), "unsupported ring");
416 let mut buffer = Vec::with_capacity_in(self.m_fft_table.len(), self.allocator().clone());
417 buffer.extend((0..self.m_fft_table.len()).map(|_| self.ring().zero()));
418 self.fft_base_impl::<_, _, true>(values, buffer);
419 }
420
421 fn fft<V, S>(&self, values: V, ring: S)
422 where V: SwappableVectorViewMut<<R_main as RingBase>::Element>,
423 S: RingStore<Type = R_main> + Copy
424 {
425 self.unordered_fft(values, ring);
426 }
427
428 fn inv_fft<V, S>(&self, values: V, ring: S)
429 where V: SwappableVectorViewMut<<R_main as RingBase>::Element>,
430 S: RingStore<Type = R_main> + Copy
431 {
432 self.unordered_inv_fft(values, ring);
433 }
434}
435
436impl<H, A> FFTErrorEstimate for BluesteinFFT<Complex64Base, Complex64Base, H, A>
437 where H: Homomorphism<Complex64Base, Complex64Base> + Clone,
438 A: Allocator + Clone
439{
440 fn expected_absolute_error(&self, input_bound: f64, input_error: f64) -> f64 {
441 let error_after_twiddling = input_error + input_bound * (root_of_unity_error() + f64::EPSILON);
442 let error_after_fft = self.m_fft_table.expected_absolute_error(input_bound, error_after_twiddling);
443 let b_bitreverse_fft_error = self.m_fft_table.expected_absolute_error(1., root_of_unity_error());
444 let new_input_bound = input_bound * self.m_fft_table.len() as f64;
446 let b_bitreverse_fft_bound = self.m_fft_table.len() as f64;
447 let error_after_mul = new_input_bound * b_bitreverse_fft_error + b_bitreverse_fft_bound * error_after_fft + f64::EPSILON * new_input_bound * b_bitreverse_fft_bound;
448 let error_after_inv_fft = self.m_fft_table.expected_absolute_error(new_input_bound * b_bitreverse_fft_bound, error_after_mul) / self.m_fft_table.len() as f64;
449 let error_end = error_after_inv_fft + new_input_bound * (root_of_unity_error() + f64::EPSILON);
450 return error_end;
451 }
452}
453
454#[cfg(test)]
455use crate::rings::zn::zn_static::*;
456
457#[test]
458fn test_fft_base() {
459 let ring = Zn::<241>::RING;
460 let fft = BluesteinFFT::new(ring, ring.int_hom().map(36), ring.int_hom().map(111), 5, 4, Global);
462 let mut values = [1, 3, 2, 0, 7];
463 let mut buffer = [0; 16];
464 fft.fft_base::<_, _, false>(&mut values, &mut buffer);
465 let expected = [13, 137, 202, 206, 170];
466 assert_eq!(expected, values);
467}
468
469#[test]
470fn test_fft_fastmul() {
471 let ring = zn_64::Zn::new(241);
472 let fastmul_ring = zn_64::ZnFastmul::new(ring).unwrap();
473 let fft = BluesteinFFT::new_with_hom(ring.can_hom(&fastmul_ring).unwrap(), fastmul_ring.int_hom().map(36), fastmul_ring.int_hom().map(111), 5, 4, Global);
474 let mut values: [_; 5] = std::array::from_fn(|i| ring.int_hom().map([1, 3, 2, 0, 7][i]));
475 fft.fft(&mut values, ring);
476 let expected: [_; 5] = std::array::from_fn(|i| ring.int_hom().map([13, 137, 202, 206, 170][i]));
477 for i in 0..values.len() {
478 assert_el_eq!(ring, expected[i], values[i]);
479 }
480}
481
482#[test]
483fn test_inv_fft_base() {
484 let ring = Zn::<241>::RING;
485 let fft = BluesteinFFT::new(ring, ring.int_hom().map(36), ring.int_hom().map(111), 5, 4, Global);
487 let values = [1, 3, 2, 0, 7];
488 let mut work = values;
489 let mut buffer = [0; 16];
490 fft.fft_base::<_, _, false>(&mut work, &mut buffer);
491 fft.fft_base::<_, _, true>(&mut work, &mut buffer);
492 assert_eq!(values, work);
493}
494
495#[test]
496fn test_approximate_fft() {
497 let CC = Complex64::RING;
498 for (p, _log2_m) in [(5, 4), (53, 7), (1009, 11)] {
499 let fft = BluesteinFFT::for_complex(&CC, p, Global);
500 let mut array = (0..p).map(|i| CC.root_of_unity(i.try_into().unwrap(), p.try_into().unwrap())).collect::<Vec<_>>();
501 fft.fft(&mut array, CC);
502 let err = fft.expected_absolute_error(1., 0.);
503 assert!(CC.is_absolute_approx_eq(array[0], CC.zero(), err));
504 assert!(CC.is_absolute_approx_eq(array[1], CC.from_f64(fft.len() as f64), err));
505 for i in 2..fft.len() {
506 assert!(CC.is_absolute_approx_eq(array[i], CC.zero(), err));
507 }
508 }
509}
510
#[cfg(test)]
// Transform length used by `bench_bluestein`. 1009 is an odd prime, so it cannot be
// handled by a power-of-two FFT directly.
const BENCH_SIZE: usize = 1009;
513
514#[bench]
515fn bench_bluestein(bencher: &mut test::Bencher) {
516 let ring = zn_64::Zn::new(18597889);
517 let fastmul_ring = zn_64::ZnFastmul::new(ring).unwrap();
518 let embedding = ring.can_hom(&fastmul_ring).unwrap();
519 let root_of_unity = fastmul_ring.coerce(&ring, get_prim_root_of_unity_zn(&ring, 2 * BENCH_SIZE).unwrap());
520 let fft = BluesteinFFT::new_with_hom(
521 embedding.clone(),
522 root_of_unity,
523 get_prim_root_of_unity_zn(&fastmul_ring, 1 << 11).unwrap(),
524 BENCH_SIZE,
525 11,
526 Global
527 );
528 let data = (0..BENCH_SIZE).map(|i| ring.int_hom().map(i as i32)).collect::<Vec<_>>();
529 let mut copy = Vec::with_capacity(BENCH_SIZE);
530 bencher.iter(|| {
531 copy.clear();
532 copy.extend(data.iter().map(|x| ring.clone_el(x)));
533 fft.unordered_fft(std::hint::black_box(&mut copy[..]), &ring);
534 fft.unordered_inv_fft(std::hint::black_box(&mut copy[..]), &ring);
535 assert_el_eq!(ring, copy[0], data[0]);
536 });
537}