1use std::alloc::Allocator;
2
3use crate::algorithms::fft::complex_fft::*;
4use crate::algorithms::fft::cooley_tuckey::CooleyTuckeyFFT;
5use crate::algorithms::fft::radix3::CooleyTukeyRadix3FFT;
6use crate::algorithms::fft::*;
7use crate::algorithms::unity_root::is_prim_root_of_unity;
8use crate::divisibility::DivisibilityRing;
9use crate::homomorphism::*;
10use crate::rings::float_complex::*;
11use crate::seq::subvector::SubvectorView;
12
13#[stability::unstable(feature = "enable")]
16pub struct GeneralCooleyTukeyFFT<R_main, R_twiddle, H, T1, T2>
17where
18 R_main: ?Sized + RingBase,
19 R_twiddle: ?Sized + RingBase,
20 H: Homomorphism<R_twiddle, R_main>,
21 T1: FFTAlgorithm<R_main>,
22 T2: FFTAlgorithm<R_main>,
23{
24 twiddle_factors: Vec<R_twiddle::Element>,
25 inv_twiddle_factors: Vec<R_twiddle::Element>,
26 left_table: T1,
27 right_table: T2,
28 root_of_unity: R_main::Element,
29 root_of_unity_twiddle: R_twiddle::Element,
30 hom: H,
31}
32
33impl<R, T1, T2> GeneralCooleyTukeyFFT<R::Type, R::Type, Identity<R>, T1, T2>
34where
35 R: RingStore,
36 T1: FFTAlgorithm<R::Type>,
37 T2: FFTAlgorithm<R::Type>,
38{
39 #[stability::unstable(feature = "enable")]
47 pub fn new_with_pows<F>(ring: R, root_of_unity_pows: F, left_table: T1, right_table: T2) -> Self
48 where
49 F: FnMut(i64) -> El<R>,
50 {
51 Self::new_with_pows_with_hom(ring.into_identity(), root_of_unity_pows, left_table, right_table)
52 }
53
54 #[stability::unstable(feature = "enable")]
65 pub fn new(ring: R, root_of_unity: El<R>, left_table: T1, right_table: T2) -> Self {
66 Self::new_with_hom(ring.into_identity(), root_of_unity, left_table, right_table)
67 }
68}
69
70impl<R_main, R_twiddle, H, A1, A2>
71 GeneralCooleyTukeyFFT<
72 R_main,
73 R_twiddle,
74 H,
75 CooleyTukeyRadix3FFT<R_main, R_twiddle, H, A1>,
76 CooleyTuckeyFFT<R_main, R_twiddle, H, A2>,
77 >
78where
79 R_main: ?Sized + RingBase,
80 R_twiddle: ?Sized + RingBase + DivisibilityRing,
81 H: Homomorphism<R_twiddle, R_main>,
82 A1: Allocator + Clone,
83 A2: Allocator + Clone,
84{
85 #[stability::unstable(feature = "enable")]
92 pub fn change_ring<R_new: ?Sized + RingBase, H_new: Clone + Homomorphism<R_twiddle, R_new>>(
93 self,
94 new_hom: H_new,
95 ) -> (
96 GeneralCooleyTukeyFFT<
97 R_new,
98 R_twiddle,
99 H_new,
100 CooleyTukeyRadix3FFT<R_new, R_twiddle, H_new, A1>,
101 CooleyTuckeyFFT<R_new, R_twiddle, H_new, A2>,
102 >,
103 H,
104 ) {
105 let ring = new_hom.codomain();
106 let root_of_unity = new_hom.map_ref(&self.root_of_unity_twiddle);
107 assert!(ring.is_commutative());
108 assert!(ring.get_ring().is_approximate() || is_prim_root_of_unity(&ring, &root_of_unity, self.len()));
109
110 return (
111 GeneralCooleyTukeyFFT {
112 twiddle_factors: self.twiddle_factors,
113 left_table: self.left_table.change_ring(new_hom.clone()).0,
114 right_table: self.right_table.change_ring(new_hom.clone()).0,
115 inv_twiddle_factors: self.inv_twiddle_factors,
116 root_of_unity,
117 root_of_unity_twiddle: self.root_of_unity_twiddle,
118 hom: new_hom,
119 },
120 self.hom,
121 );
122 }
123}
124
125impl<R_main, R_twiddle, H, T1, T2> GeneralCooleyTukeyFFT<R_main, R_twiddle, H, T1, T2>
126where
127 R_main: ?Sized + RingBase,
128 R_twiddle: ?Sized + RingBase,
129 H: Homomorphism<R_twiddle, R_main>,
130 T1: FFTAlgorithm<R_main>,
131 T2: FFTAlgorithm<R_main>,
132{
133 #[stability::unstable(feature = "enable")]
146 pub fn new_with_pows_with_hom<F>(hom: H, root_of_unity_pows: F, left_table: T1, right_table: T2) -> Self
147 where
148 F: FnMut(i64) -> R_twiddle::Element,
149 {
150 Self::create(hom, root_of_unity_pows, left_table, right_table)
151 }
152
153 #[stability::unstable(feature = "enable")]
156 pub fn create<F>(hom: H, mut root_of_unity_pows: F, left_table: T1, right_table: T2) -> Self
157 where
158 F: FnMut(i64) -> R_twiddle::Element,
159 {
160 let ring = hom.codomain();
161
162 assert!(ring.is_commutative());
163 assert!(
164 ring.get_ring().is_approximate()
165 || is_prim_root_of_unity(
166 ring,
167 &hom.map(root_of_unity_pows(1)),
168 left_table.len() * right_table.len()
169 )
170 );
171 assert!(
172 ring.get_ring().is_approximate()
173 || ring.eq_el(
174 &hom.map(root_of_unity_pows(right_table.len().try_into().unwrap())),
175 left_table.root_of_unity(ring)
176 )
177 );
178 assert!(
179 ring.get_ring().is_approximate()
180 || ring.eq_el(
181 &hom.map(root_of_unity_pows(left_table.len().try_into().unwrap())),
182 right_table.root_of_unity(ring)
183 )
184 );
185
186 let root_of_unity = root_of_unity_pows(1);
187 let inv_twiddle_factors = Self::create_twiddle_factors(|i| root_of_unity_pows(-i), &left_table, &right_table);
188 let twiddle_factors = Self::create_twiddle_factors(root_of_unity_pows, &left_table, &right_table);
189
190 GeneralCooleyTukeyFFT {
191 twiddle_factors,
192 inv_twiddle_factors,
193 left_table,
194 right_table,
195 root_of_unity: hom.map_ref(&root_of_unity),
196 root_of_unity_twiddle: root_of_unity,
197 hom,
198 }
199 }
200
201 #[stability::unstable(feature = "enable")]
203 pub fn left_fft_table(&self) -> &T1 { &self.left_table }
204
205 #[stability::unstable(feature = "enable")]
207 pub fn right_fft_table(&self) -> &T2 { &self.right_table }
208
209 #[stability::unstable(feature = "enable")]
212 pub fn hom<'a>(&'a self) -> &'a H { &self.hom }
213
214 #[stability::unstable(feature = "enable")]
230 pub fn new_with_hom(hom: H, root_of_unity: R_twiddle::Element, left_table: T1, right_table: T2) -> Self {
231 let len = left_table.len() * right_table.len();
232 let root_of_unity_pows = |i: i64| {
233 if i >= 0 {
234 hom.domain()
235 .pow(hom.domain().clone_el(&root_of_unity), i.try_into().unwrap())
236 } else {
237 let len_i64: i64 = len.try_into().unwrap();
238 hom.domain().pow(
239 hom.domain().clone_el(&root_of_unity),
240 (len_i64 + (i % len_i64)).try_into().unwrap(),
241 )
242 }
243 };
244 let result = GeneralCooleyTukeyFFT::create(&hom, root_of_unity_pows, left_table, right_table);
245 GeneralCooleyTukeyFFT {
246 twiddle_factors: result.twiddle_factors,
247 inv_twiddle_factors: result.inv_twiddle_factors,
248 left_table: result.left_table,
249 right_table: result.right_table,
250 root_of_unity: result.root_of_unity,
251 root_of_unity_twiddle: result.root_of_unity_twiddle,
252 hom,
253 }
254 }
255
256 fn create_twiddle_factors<F>(
257 mut root_of_unity_pows: F,
258 left_table: &T1,
259 right_table: &T2,
260 ) -> Vec<R_twiddle::Element>
261 where
262 F: FnMut(i64) -> R_twiddle::Element,
263 {
264 (0..(left_table.len() * right_table.len()))
265 .map(|i| {
266 let ri: i64 = (i % right_table.len()).try_into().unwrap();
267 let li = i / right_table.len();
268 return root_of_unity_pows(
269 TryInto::<i64>::try_into(left_table.unordered_fft_permutation(li)).unwrap() * ri,
270 );
271 })
272 .collect::<Vec<_>>()
273 }
274
275 #[stability::unstable(feature = "enable")]
277 pub fn ring<'a>(&'a self) -> &'a <H as Homomorphism<R_twiddle, R_main>>::CodomainStore { self.hom.codomain() }
278}
279
280impl<R_main, R_twiddle, H, T1, T2> PartialEq for GeneralCooleyTukeyFFT<R_main, R_twiddle, H, T1, T2>
281where
282 R_main: ?Sized + RingBase,
283 R_twiddle: ?Sized + RingBase,
284 H: Homomorphism<R_twiddle, R_main>,
285 T1: FFTAlgorithm<R_main> + PartialEq,
286 T2: FFTAlgorithm<R_main> + PartialEq,
287{
288 fn eq(&self, other: &Self) -> bool {
289 self.ring().get_ring() == other.ring().get_ring()
290 && self.left_table == other.left_table
291 && self.right_table == other.right_table
292 && self
293 .ring()
294 .eq_el(self.root_of_unity(self.ring()), other.root_of_unity(self.ring()))
295 }
296}
297
298impl<R_main, R_twiddle, H, T1, T2> FFTAlgorithm<R_main> for GeneralCooleyTukeyFFT<R_main, R_twiddle, H, T1, T2>
299where
300 R_main: ?Sized + RingBase,
301 R_twiddle: ?Sized + RingBase,
302 H: Homomorphism<R_twiddle, R_main>,
303 T1: FFTAlgorithm<R_main>,
304 T2: FFTAlgorithm<R_main>,
305{
306 fn len(&self) -> usize { self.left_table.len() * self.right_table.len() }
307
308 fn root_of_unity<S: RingStore<Type = R_main> + Copy>(&self, ring: S) -> &R_main::Element {
309 assert!(self.ring().get_ring() == ring.get_ring(), "unsupported ring");
310 &self.root_of_unity
311 }
312
313 fn unordered_fft<V, S>(&self, mut values: V, ring: S)
314 where
315 V: SwappableVectorViewMut<<R_main as RingBase>::Element>,
316 S: RingStore<Type = R_main> + Copy,
317 {
318 assert!(self.ring().get_ring() == ring.get_ring(), "unsupported ring");
319 if self.left_table.len() > 1 {
320 for i in 0..self.right_table.len() {
321 let mut v = SubvectorView::new(&mut values)
322 .restrict(i..)
323 .step_by_view(self.right_table.len());
324 self.left_table.unordered_fft(&mut v, ring);
325 }
326 for i in 0..self.len() {
327 self.hom
328 .mul_assign_ref_map(values.at_mut(i), self.inv_twiddle_factors.at(i));
329 }
330 }
331 for i in 0..self.left_table.len() {
332 let mut v = SubvectorView::new(&mut values)
333 .restrict((i * self.right_table.len())..((i + 1) * self.right_table.len()));
334 self.right_table.unordered_fft(&mut v, ring);
335 }
336 }
337
338 fn unordered_inv_fft<V, S>(&self, mut values: V, ring: S)
339 where
340 V: SwappableVectorViewMut<<R_main as RingBase>::Element>,
341 S: RingStore<Type = R_main> + Copy,
342 {
343 assert!(self.ring().get_ring() == ring.get_ring(), "unsupported ring");
344 for i in 0..self.left_table.len() {
345 let mut v = SubvectorView::new(&mut values)
346 .restrict((i * self.right_table.len())..((i + 1) * self.right_table.len()));
347 self.right_table.unordered_inv_fft(&mut v, ring);
348 }
349 if self.left_table.len() > 1 {
350 for i in 0..self.len() {
351 self.hom
352 .mul_assign_ref_map(values.at_mut(i), self.twiddle_factors.at(i));
353 debug_assert!(
354 self.ring().get_ring().is_approximate()
355 || self.hom.domain().is_one(
356 &self
357 .hom
358 .domain()
359 .mul_ref(self.twiddle_factors.at(i), self.inv_twiddle_factors.at(i))
360 )
361 );
362 }
363 for i in 0..self.right_table.len() {
364 let mut v = SubvectorView::new(&mut values)
365 .restrict(i..)
366 .step_by_view(self.right_table.len());
367 self.left_table.unordered_inv_fft(&mut v, ring);
368 }
369 }
370 }
371
372 fn unordered_fft_permutation(&self, i: usize) -> usize {
373 assert!(i < self.len());
374 self.left_table.unordered_fft_permutation(i / self.right_table.len())
375 + self.left_table.len() * self.right_table.unordered_fft_permutation(i % self.right_table.len())
376 }
377
378 fn unordered_fft_permutation_inv(&self, i: usize) -> usize {
379 assert!(i < self.len());
380 self.left_table.unordered_fft_permutation_inv(i % self.left_table.len()) * self.right_table.len()
381 + self
382 .right_table
383 .unordered_fft_permutation_inv(i / self.left_table.len())
384 }
385}
386
387impl<H, T1, T2> FFTErrorEstimate for GeneralCooleyTukeyFFT<Complex64Base, Complex64Base, H, T1, T2>
388where
389 H: Homomorphism<Complex64Base, Complex64Base>,
390 T1: FFTAlgorithm<Complex64Base> + FFTErrorEstimate,
391 T2: FFTAlgorithm<Complex64Base> + FFTErrorEstimate,
392{
393 fn expected_absolute_error(&self, input_bound: f64, input_error: f64) -> f64 {
394 let error_after_first_fft = self.left_table.expected_absolute_error(input_bound, input_error);
395 let new_input_bound = self.left_table.len() as f64 * input_bound;
396 let error_after_twiddling = error_after_first_fft + new_input_bound * (root_of_unity_error() + f64::EPSILON);
397 return self
398 .right_table
399 .expected_absolute_error(new_input_bound, error_after_twiddling);
400 }
401}
402
403#[cfg(test)]
404use std::alloc::Global;
405
406#[cfg(test)]
407use crate::algorithms::fft::bluestein::BluesteinFFT;
408#[cfg(test)]
409use crate::algorithms::unity_root::*;
410#[cfg(test)]
411use crate::rings::zn::zn_64;
412#[cfg(test)]
413use crate::rings::zn::zn_static::{Fp, Zn};
414
415#[test]
416fn test_fft_basic() {
417 let ring = Zn::<97>::RING;
418 let z = ring.int_hom().map(39);
419 let fft = GeneralCooleyTukeyFFT::new(
420 ring,
421 ring.pow(z, 16),
422 CooleyTuckeyFFT::new(ring, ring.pow(z, 48), 1),
423 BluesteinFFT::new(ring, ring.pow(z, 16), ring.pow(z, 12), 3, 3, Global),
424 );
425 let mut values = [1, 0, 0, 1, 0, 1];
426 let expected = [3, 62, 63, 96, 37, 36];
427 let mut permuted_expected = [0; 6];
428 for i in 0..6 {
429 permuted_expected[i] = expected[fft.unordered_fft_permutation(i)];
430 }
431
432 fft.unordered_fft(&mut values, ring);
433 assert_eq!(values, permuted_expected);
434}
435
436#[test]
437fn test_fft_long() {
438 let ring = Fp::<97>::RING;
439 let z = ring.int_hom().map(39);
440 let fft = GeneralCooleyTukeyFFT::new(
441 ring,
442 ring.pow(z, 4),
443 CooleyTuckeyFFT::new(ring, ring.pow(z, 12), 3),
444 BluesteinFFT::new(ring, ring.pow(z, 16), ring.pow(z, 12), 3, 3, Global),
445 );
446 let mut values = [1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 2, 2, 0, 2, 0, 1, 2, 3, 4];
447 let expected = [
448 26, 0, 75, 47, 41, 31, 28, 62, 39, 93, 53, 27, 0, 54, 74, 61, 65, 81, 63, 38, 53, 94, 89, 91,
449 ];
450 let mut permuted_expected = [0; 24];
451 for i in 0..24 {
452 permuted_expected[i] = expected[fft.unordered_fft_permutation(i)];
453 }
454
455 fft.unordered_fft(&mut values, ring);
456 assert_eq!(values, permuted_expected);
457}
458
459#[test]
460fn test_fft_unordered() {
461 let ring = Fp::<1409>::RING;
462 let z = get_prim_root_of_unity(ring, 64 * 11).unwrap();
463 let fft = GeneralCooleyTukeyFFT::new(
464 ring,
465 ring.pow(z, 4),
466 CooleyTuckeyFFT::new(ring, ring.pow(z, 44), 4),
467 BluesteinFFT::new(ring, ring.pow(z, 32), ring.pow(z, 22), 11, 5, Global),
468 );
469 const LEN: usize = 16 * 11;
470 let mut values = [0; LEN];
471 for i in 0..LEN {
472 values[i] = ring.int_hom().map(i as i32);
473 }
474 let original = values;
475
476 fft.unordered_fft(&mut values, ring);
477
478 let mut ordered_fft = [0; LEN];
479 for i in 0..LEN {
480 ordered_fft[fft.unordered_fft_permutation(i)] = values[i];
481 }
482
483 fft.unordered_inv_fft(&mut values, ring);
484 assert_eq!(values, original);
485
486 fft.inv_fft(&mut ordered_fft, ring);
487 assert_eq!(ordered_fft, original);
488}
489
490#[test]
491fn test_unordered_fft_permutation_inv() {
492 let ring = Fp::<1409>::RING;
493 let z = get_prim_root_of_unity(ring, 64 * 11).unwrap();
494 let fft = GeneralCooleyTukeyFFT::new(
495 ring,
496 ring.pow(z, 4),
497 CooleyTuckeyFFT::new(ring, ring.pow(z, 44), 4),
498 BluesteinFFT::new(ring, ring.pow(z, 32), ring.pow(z, 22), 11, 5, Global),
499 );
500 for i in 0..(16 * 11) {
501 assert_eq!(fft.unordered_fft_permutation_inv(fft.unordered_fft_permutation(i)), i);
502 assert_eq!(fft.unordered_fft_permutation(fft.unordered_fft_permutation_inv(i)), i);
503 }
504
505 let fft = GeneralCooleyTukeyFFT::new(
506 ring,
507 ring.pow(z, 4),
508 BluesteinFFT::new(ring, ring.pow(z, 32), ring.pow(z, 22), 11, 5, Global),
509 CooleyTuckeyFFT::new(ring, ring.pow(z, 44), 4),
510 );
511 for i in 0..(16 * 11) {
512 assert_eq!(fft.unordered_fft_permutation_inv(fft.unordered_fft_permutation(i)), i);
513 assert_eq!(fft.unordered_fft_permutation(fft.unordered_fft_permutation_inv(i)), i);
514 }
515}
516
517#[test]
518fn test_inv_fft() {
519 let ring = Fp::<97>::RING;
520 let z = ring.int_hom().map(39);
521 let fft = GeneralCooleyTukeyFFT::new(
522 ring,
523 ring.pow(z, 16),
524 CooleyTuckeyFFT::new(ring, ring.pow(z, 16 * 3), 1),
525 BluesteinFFT::new(ring, ring.pow(z, 16), ring.pow(z, 12), 3, 3, Global),
526 );
527 let mut values = [3, 62, 63, 96, 37, 36];
528 let expected = [1, 0, 0, 1, 0, 1];
529
530 fft.inv_fft(&mut values, ring);
531 assert_eq!(values, expected);
532}
533
534#[test]
535fn test_approximate_fft() {
536 let CC = Complex64::RING;
537 for (p, log2_n) in [(5, 3), (53, 5), (101, 8), (503, 10)] {
538 let fft = GeneralCooleyTukeyFFT::new_with_pows(
539 CC,
540 |i| CC.root_of_unity(i, TryInto::<i64>::try_into(p).unwrap() << log2_n),
541 BluesteinFFT::for_complex(CC, p, Global),
542 CooleyTuckeyFFT::for_complex(CC, log2_n),
543 );
544 let mut array = (0..(p << log2_n))
545 .map(|i| CC.root_of_unity(i.try_into().unwrap(), TryInto::<i64>::try_into(p).unwrap() << log2_n))
546 .collect::<Vec<_>>();
547 fft.fft(&mut array, CC);
548 let err = fft.expected_absolute_error(1.0, 0.0);
549 assert!(CC.is_absolute_approx_eq(array[0], CC.zero(), err));
550 assert!(CC.is_absolute_approx_eq(array[1], CC.from_f64(fft.len() as f64), err));
551 for i in 2..fft.len() {
552 assert!(CC.is_absolute_approx_eq(array[i], CC.zero(), err));
553 }
554 }
555}
556
/// Length of the second Bluestein table in `bench_factor_fft`; the benchmarked FFT
/// has length `BENCH_N1 * BENCH_N2`.
#[cfg(test)]
const BENCH_N1: usize = 31;
/// Length of the first Bluestein table in `bench_factor_fft`.
#[cfg(test)]
const BENCH_N2: usize = 601;
561
562#[bench]
563fn bench_factor_fft(bencher: &mut test::Bencher) {
564 let ring = zn_64::Zn::new(1602564097);
565 let fastmul_ring = zn_64::ZnFastmul::new(ring).unwrap();
566 let embedding = ring.can_hom(&fastmul_ring).unwrap();
567 let root_of_unity = fastmul_ring.coerce(&ring, get_prim_root_of_unity_zn(&ring, 2 * 31 * 601).unwrap());
568 let fft = GeneralCooleyTukeyFFT::new_with_hom(
569 embedding.clone(),
570 fastmul_ring.pow(root_of_unity, 2),
571 BluesteinFFT::new_with_hom(
572 embedding.clone(),
573 fastmul_ring.pow(root_of_unity, BENCH_N1),
574 get_prim_root_of_unity_zn(&fastmul_ring, 1 << 11).unwrap(),
575 BENCH_N2,
576 11,
577 Global,
578 ),
579 BluesteinFFT::new_with_hom(
580 embedding,
581 fastmul_ring.pow(root_of_unity, BENCH_N2),
582 get_prim_root_of_unity_zn(&fastmul_ring, 1 << 6).unwrap(),
583 BENCH_N1,
584 6,
585 Global,
586 ),
587 );
588 let data = (0..(BENCH_N1 * BENCH_N2))
589 .map(|i| ring.int_hom().map(i as i32))
590 .collect::<Vec<_>>();
591 let mut copy = Vec::with_capacity(BENCH_N1 * BENCH_N2);
592 bencher.iter(|| {
593 copy.clear();
594 copy.extend(data.iter().map(|x| ring.clone_el(x)));
595 fft.unordered_fft(&mut copy[..], &ring);
596 fft.unordered_inv_fft(&mut copy[..], &ring);
597 assert_el_eq!(ring, copy[0], data[0]);
598 });
599}