1use subvector::SubvectorView;
2
3use crate::ring::*;
4use crate::homomorphism::*;
5use crate::algorithms::fft::*;
6use crate::algorithms::fft::complex_fft::*;
7use crate::rings::float_complex::*;
8
9#[stability::unstable(feature = "enable")]
14pub struct CoprimeCooleyTuckeyFFT<R_main, R_twiddle, H, T1, T2>
15 where R_main: ?Sized + RingBase,
16 R_twiddle: ?Sized + RingBase,
17 H: Homomorphism<R_twiddle, R_main>,
18 T1: FFTAlgorithm<R_main>,
19 T2: FFTAlgorithm<R_main>
20{
21 twiddle_factors: Vec<R_twiddle::Element>,
22 inv_twiddle_factors: Vec<R_twiddle::Element>,
23 left_table: T1,
24 right_table: T2,
25 root_of_unity: R_main::Element,
26 hom: H
27}
28
29impl<R, T1, T2> CoprimeCooleyTuckeyFFT<R::Type, R::Type, Identity<R>, T1, T2>
30 where R: RingStore,
31 T1: FFTAlgorithm<R::Type>,
32 T2: FFTAlgorithm<R::Type>
33{
34 #[stability::unstable(feature = "enable")]
35 pub fn new_with_pows<F>(ring: R, root_of_unity_pows: F, left_table: T1, right_table: T2) -> Self
36 where F: FnMut(i64) -> El<R>
37 {
38 Self::new_with_pows_with_hom(ring.into_identity(), root_of_unity_pows, left_table, right_table)
39 }
40
41 #[stability::unstable(feature = "enable")]
42 pub fn new(ring: R, root_of_unity: El<R>, left_table: T1, right_table: T2) -> Self {
43 Self::new_with_hom(ring.into_identity(), root_of_unity, left_table, right_table)
44 }
45}
46
47impl<R_main, R_twiddle, H, T1, T2> CoprimeCooleyTuckeyFFT<R_main, R_twiddle, H, T1, T2>
48 where R_main: ?Sized + RingBase,
49 R_twiddle: ?Sized + RingBase,
50 H: Homomorphism<R_twiddle, R_main>,
51 T1: FFTAlgorithm<R_main>,
52 T2: FFTAlgorithm<R_main>
53{
54 #[stability::unstable(feature = "enable")]
55 pub fn new_with_pows_with_hom<F>(hom: H, mut root_of_unity_pows: F, left_table: T1, right_table: T2) -> Self
56 where F: FnMut(i64) -> R_twiddle::Element
57 {
58 let ring = hom.codomain();
59
60 assert!(ring.get_ring().is_approximate() || ring.eq_el(&hom.map(root_of_unity_pows(right_table.len() as i64)), left_table.root_of_unity(ring)));
61 assert!(ring.get_ring().is_approximate() || ring.eq_el(&hom.map(root_of_unity_pows(left_table.len() as i64)), right_table.root_of_unity(ring)));
62
63 let root_of_unity = root_of_unity_pows(1);
64 let inv_twiddle_factors = Self::create_twiddle_factors(|i| root_of_unity_pows(-i), &left_table, &right_table);
65 let twiddle_factors = Self::create_twiddle_factors(root_of_unity_pows, &left_table, &right_table);
66
67 CoprimeCooleyTuckeyFFT {
68 twiddle_factors: twiddle_factors,
69 inv_twiddle_factors: inv_twiddle_factors,
70 left_table: left_table,
71 right_table: right_table,
72 root_of_unity: hom.map(root_of_unity),
73 hom: hom
74 }
75 }
76
77 #[stability::unstable(feature = "enable")]
78 pub fn left_fft_table(&self) -> &T1 {
79 &self.left_table
80 }
81
82 #[stability::unstable(feature = "enable")]
83 pub fn right_fft_table(&self) -> &T2 {
84 &self.right_table
85 }
86
87 #[stability::unstable(feature = "enable")]
88 pub fn new_with_hom(hom: H, root_of_unity: R_twiddle::Element, left_table: T1, right_table: T2) -> Self {
89 assert!(!hom.domain().get_ring().is_approximate());
90 let ring = hom.codomain();
91
92 assert!(!ring.get_ring().is_approximate());
93
94 let len = left_table.len() * right_table.len();
95 let root_of_unity_pows = |i: i64| if i >= 0 {
96 hom.domain().pow(hom.domain().clone_el(&root_of_unity), i as usize)
97 } else {
98 hom.domain().pow(hom.domain().clone_el(&root_of_unity), (len as i64 + (i % len as i64)) as usize)
99 };
100
101 assert!(ring.eq_el(&hom.map(root_of_unity_pows(right_table.len() as i64)), left_table.root_of_unity(ring)));
102 assert!(ring.eq_el(&hom.map(root_of_unity_pows(left_table.len() as i64)), right_table.root_of_unity(ring)));
103
104 let inv_twiddle_factors = Self::create_twiddle_factors(|i| root_of_unity_pows(-i), &left_table, &right_table);
105 let twiddle_factors = Self::create_twiddle_factors(root_of_unity_pows, &left_table, &right_table);
106
107 CoprimeCooleyTuckeyFFT {
108 twiddle_factors: twiddle_factors,
109 inv_twiddle_factors: inv_twiddle_factors,
110 left_table: left_table,
111 right_table: right_table,
112 root_of_unity: hom.map(root_of_unity),
113 hom: hom
114 }
115 }
116
117 fn create_twiddle_factors<F>(mut root_of_unity_pows: F, left_table: &T1, right_table: &T2) -> Vec<R_twiddle::Element>
118 where F: FnMut(i64) -> R_twiddle::Element
119 {
120 (0..(left_table.len() * right_table.len())).map(|i| {
121 let ri = i % right_table.len();
122 let li = i / right_table.len();
123 return root_of_unity_pows(left_table.unordered_fft_permutation(li) as i64 * ri as i64);
124 }).collect::<Vec<_>>()
125 }
126
127 fn ring<'a>(&'a self) -> &'a <H as Homomorphism<R_twiddle, R_main>>::CodomainStore {
128 self.hom.codomain()
129 }
130}
131
132impl<R_main, R_twiddle, H, T1, T2> PartialEq for CoprimeCooleyTuckeyFFT<R_main, R_twiddle, H, T1, T2>
133 where R_main: ?Sized + RingBase,
134 R_twiddle: ?Sized + RingBase,
135 H: Homomorphism<R_twiddle, R_main>,
136 T1: FFTAlgorithm<R_main> + PartialEq,
137 T2: FFTAlgorithm<R_main> + PartialEq
138{
139 fn eq(&self, other: &Self) -> bool {
140 self.ring().get_ring() == other.ring().get_ring() &&
141 self.left_table == other.left_table &&
142 self.right_table == other.right_table &&
143 self.ring().eq_el(self.root_of_unity(self.ring()), other.root_of_unity(self.ring()))
144 }
145}
146
147impl<R_main, R_twiddle, H, T1, T2> FFTAlgorithm<R_main> for CoprimeCooleyTuckeyFFT<R_main, R_twiddle, H, T1, T2>
148 where R_main: ?Sized + RingBase,
149 R_twiddle: ?Sized + RingBase,
150 H: Homomorphism<R_twiddle, R_main>,
151 T1: FFTAlgorithm<R_main>,
152 T2: FFTAlgorithm<R_main>
153{
154 fn len(&self) -> usize {
155 self.left_table.len() * self.right_table.len()
156 }
157
158 fn root_of_unity<S: RingStore<Type = R_main> + Copy>(&self, ring: S) -> &R_main::Element {
159 assert!(self.ring().get_ring() == ring.get_ring(), "unsupported ring");
160 &self.root_of_unity
161 }
162
163 fn unordered_fft<V, S>(&self, mut values: V, ring: S)
164 where V: SwappableVectorViewMut<<R_main as RingBase>::Element>,
165 S: RingStore<Type = R_main> + Copy
166 {
167 assert!(self.ring().get_ring() == ring.get_ring(), "unsupported ring");
168 for i in 0..self.right_table.len() {
169 let mut v = SubvectorView::new(&mut values).restrict(i..).step_by_view(self.right_table.len());
170 self.left_table.unordered_fft(&mut v, ring);
171 }
172 for i in 0..self.len() {
173 self.hom.mul_assign_ref_map(values.at_mut(i), self.inv_twiddle_factors.at(i));
174 }
175 for i in 0..self.left_table.len() {
176 let mut v = SubvectorView::new(&mut values).restrict((i * self.right_table.len())..((i + 1) * self.right_table.len()));
177 self.right_table.unordered_fft(&mut v, ring);
178 }
179 }
180
181 fn unordered_inv_fft<V, S>(&self, mut values: V, ring: S)
182 where V: SwappableVectorViewMut<<R_main as RingBase>::Element>,
183 S: RingStore<Type = R_main> + Copy
184 {
185 assert!(self.ring().get_ring() == ring.get_ring(), "unsupported ring");
186 for i in 0..self.left_table.len() {
187 let mut v = SubvectorView::new(&mut values).restrict((i * self.right_table.len())..((i + 1) * self.right_table.len()));
188 self.right_table.unordered_inv_fft(&mut v, ring);
189 }
190 for i in 0..self.len() {
191 self.hom.mul_assign_ref_map(values.at_mut(i), self.twiddle_factors.at(i));
192 debug_assert!(self.ring().get_ring().is_approximate() || self.hom.domain().is_one(&self.hom.domain().mul_ref(self.twiddle_factors.at(i), self.inv_twiddle_factors.at(i))));
193 }
194 for i in 0..self.right_table.len() {
195 let mut v = SubvectorView::new(&mut values).restrict(i..).step_by_view(self.right_table.len());
196 self.left_table.unordered_inv_fft(&mut v, ring);
197 }
198 }
199
200 fn unordered_fft_permutation(&self, i: usize) -> usize {
201 assert!(i < self.len());
202 self.left_table.unordered_fft_permutation(i / self.right_table.len()) + self.left_table.len() * self.right_table.unordered_fft_permutation(i % self.right_table.len())
203 }
204
205 fn unordered_fft_permutation_inv(&self, i: usize) -> usize {
206 assert!(i < self.len());
207 self.left_table.unordered_fft_permutation_inv(i % self.left_table.len()) * self.right_table.len() + self.right_table.unordered_fft_permutation_inv(i / self.left_table.len())
208 }
209}
210
211impl<H, T1, T2> FFTErrorEstimate for CoprimeCooleyTuckeyFFT<Complex64Base, Complex64Base, H, T1, T2>
212 where H: Homomorphism<Complex64Base, Complex64Base>,
213 T1: FFTAlgorithm<Complex64Base> + FFTErrorEstimate,
214 T2: FFTAlgorithm<Complex64Base> + FFTErrorEstimate
215{
216 fn expected_absolute_error(&self, input_bound: f64, input_error: f64) -> f64 {
217 let error_after_first_fft = self.left_table.expected_absolute_error(input_bound, input_error);
218 let new_input_bound = self.left_table.len() as f64 * input_bound;
219 let error_after_twiddling = error_after_first_fft + new_input_bound * (root_of_unity_error() + f64::EPSILON);
220 return self.right_table.expected_absolute_error(new_input_bound, error_after_twiddling);
221 }
222}
223
224#[cfg(test)]
225use crate::rings::zn::zn_static::{Zn, Fp};
226#[cfg(test)]
227use crate::algorithms;
228#[cfg(test)]
229use crate::rings::zn::zn_64;
230#[cfg(test)]
231use crate::rings::zn::ZnRingStore;
232#[cfg(test)]
233use std::alloc::Global;
234
#[test]
fn test_fft_basic() {
    let ring = Zn::<97>::RING;
    // 39 has multiplicative order 96 modulo 97, so z^16 has order 6 = 2 * 3
    let z = ring.int_hom().map(39);
    let left = bluestein::BluesteinFFT::new(ring, ring.pow(z, 24), ring.pow(z, 12), 2, 3, Global);
    let right = bluestein::BluesteinFFT::new(ring, ring.pow(z, 16), ring.pow(z, 12), 3, 3, Global);
    let fft = CoprimeCooleyTuckeyFFT::new(ring, ring.pow(z, 16), left, right);

    let mut values = [1, 0, 0, 1, 0, 1];
    let expected = [3, 62, 63, 96, 37, 36];
    // position i of the unordered output holds the standard output index `unordered_fft_permutation(i)`
    let mut permuted_expected = [0; 6];
    for (i, entry) in permuted_expected.iter_mut().enumerate() {
        *entry = expected[fft.unordered_fft_permutation(i)];
    }

    fft.unordered_fft(&mut values, ring);
    assert_eq!(values, permuted_expected);
}
253
#[test]
fn test_fft_long() {
    let ring = Fp::<97>::RING;
    // z has order 96, so the FFT root z^4 has order 24 = 8 * 3
    let z = ring.int_hom().map(39);
    let left = bluestein::BluesteinFFT::new(ring, ring.pow(z, 6), ring.pow(z, 3), 8, 5, Global);
    let right = bluestein::BluesteinFFT::new(ring, ring.pow(z, 16), ring.pow(z, 12), 3, 3, Global);
    let fft = CoprimeCooleyTuckeyFFT::new(ring, ring.pow(z, 4), left, right);

    let mut values = [1, 0, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 2, 2, 0, 2, 0, 1, 2, 3, 4];
    let expected = [26, 0, 75, 47, 41, 31, 28, 62, 39, 93, 53, 27, 0, 54, 74, 61, 65, 81, 63, 38, 53, 94, 89, 91];
    let mut permuted_expected = [0; 24];
    for (i, entry) in permuted_expected.iter_mut().enumerate() {
        *entry = expected[fft.unordered_fft_permutation(i)];
    }

    fft.unordered_fft(&mut values, ring);
    assert_eq!(values, permuted_expected);
}
272
#[test]
fn test_fft_unordered() {
    let ring = Fp::<1409>::RING;
    let z = algorithms::unity_root::get_prim_root_of_unity(ring, 64 * 11).unwrap();
    let left = cooley_tuckey::CooleyTuckeyFFT::new(ring, ring.pow(z, 44), 4);
    let right = bluestein::BluesteinFFT::new(ring, ring.pow(z, 32), ring.pow(z, 22), 11, 5, Global);
    let fft = CoprimeCooleyTuckeyFFT::new(ring, ring.pow(z, 4), left, right);
    const LEN: usize = 16 * 11;

    let mut values = [0; LEN];
    for (i, entry) in values.iter_mut().enumerate() {
        *entry = ring.int_hom().map(i as i32);
    }
    let original = values;

    fft.unordered_fft(&mut values, ring);

    // reorder the unordered output into standard order, so that `inv_fft()` can be checked as well
    let mut ordered_fft = [0; LEN];
    for (i, entry) in values.iter().enumerate() {
        ordered_fft[fft.unordered_fft_permutation(i)] = *entry;
    }

    fft.unordered_inv_fft(&mut values, ring);
    assert_eq!(values, original);

    fft.inv_fft(&mut ordered_fft, ring);
    assert_eq!(ordered_fft, original);
}
303
304
#[test]
fn test_unordered_fft_permutation_inv() {
    let ring = Fp::<1409>::RING;
    let z = algorithms::unity_root::get_prim_root_of_unity(ring, 64 * 11).unwrap();
    let left = cooley_tuckey::CooleyTuckeyFFT::new(ring, ring.pow(z, 44), 4);
    let right = bluestein::BluesteinFFT::new(ring, ring.pow(z, 32), ring.pow(z, 22), 11, 5, Global);
    let fft = CoprimeCooleyTuckeyFFT::new(ring, ring.pow(z, 4), left, right);
    const LEN: usize = 16 * 11;

    // both compositions of the permutation and its claimed inverse must be the identity
    for i in 0..LEN {
        let image = fft.unordered_fft_permutation(i);
        assert_eq!(fft.unordered_fft_permutation_inv(image), i);
        let preimage = fft.unordered_fft_permutation_inv(i);
        assert_eq!(fft.unordered_fft_permutation(preimage), i);
    }
}
320
#[test]
fn test_inv_fft() {
    let ring = Fp::<97>::RING;
    let z = ring.int_hom().map(39);
    let left = bluestein::BluesteinFFT::new(ring, ring.pow(z, 24), ring.pow(z, 12), 2, 3, Global);
    let right = bluestein::BluesteinFFT::new(ring, ring.pow(z, 16), ring.pow(z, 12), 3, 3, Global);
    let fft = CoprimeCooleyTuckeyFFT::new(ring, ring.pow(z, 16), left, right);

    // the same values as in `test_fft_basic`, transformed in the opposite direction
    let mut values = [3, 62, 63, 96, 37, 36];
    let expected = [1, 0, 0, 1, 0, 1];

    fft.inv_fft(&mut values, ring);
    assert_eq!(values, expected);
}
335
#[test]
fn test_approximate_fft() {
    let ring = Complex64::RING;
    for (p, log2_n) in [(5, 3), (53, 5), (101, 8), (503, 10)] {
        let n = (p as i64) << log2_n;
        let fft = CoprimeCooleyTuckeyFFT::new_with_pows(
            ring,
            |i| ring.root_of_unity(i, n),
            bluestein::BluesteinFFT::for_complex(ring, p, Global),
            cooley_tuckey::CooleyTuckeyFFT::for_complex(ring, log2_n)
        );
        // The input is the sequence of powers of the root of unity. Its transform should
        // therefore be the transform length at index 1 and zero everywhere else.
        let mut array = (0..(p << log2_n)).map(|i| ring.root_of_unity(i as i64, n)).collect::<Vec<_>>();
        fft.fft(&mut array, ring);
        let err = fft.expected_absolute_error(1., 0.);
        assert!(ring.is_absolute_approx_eq(array[0], ring.zero(), err));
        assert!(ring.is_absolute_approx_eq(array[1], ring.from_f64(fft.len() as f64), err));
        for entry in &array[2..] {
            assert!(ring.is_absolute_approx_eq(*entry, ring.zero(), err));
        }
    }
}
356
/// Length of the right (row) FFT in `bench_factor_fft`.
#[cfg(test)]
const BENCH_N1: usize = 31;

/// Length of the left (column) FFT in `bench_factor_fft`.
#[cfg(test)]
const BENCH_N2: usize = 601;
361
362#[bench]
363fn bench_factor_fft(bencher: &mut test::Bencher) {
364 let ring = zn_64::Zn::new(1602564097);
365 let fastmul_ring = zn_64::ZnFastmul::new(ring).unwrap();
366 let embedding = ring.can_hom(&fastmul_ring).unwrap();
367 let ring_as_field = ring.as_field().ok().unwrap();
368 let root_of_unity = fastmul_ring.coerce(&ring, ring_as_field.get_ring().unwrap_element(algorithms::unity_root::get_prim_root_of_unity(&ring_as_field, 2 * 31 * 601).unwrap()));
369 let fastmul_ring_as_field = fastmul_ring.as_field().ok().unwrap();
370 let fft = CoprimeCooleyTuckeyFFT::new_with_hom(
371 embedding.clone(),
372 fastmul_ring.pow(root_of_unity, 2),
373 bluestein::BluesteinFFT::new_with_hom(embedding.clone(), fastmul_ring.pow(root_of_unity, BENCH_N1), fastmul_ring_as_field.get_ring().unwrap_element(algorithms::unity_root::get_prim_root_of_unity_pow2(&fastmul_ring_as_field, 11).unwrap()), BENCH_N2, 11, Global),
374 bluestein::BluesteinFFT::new_with_hom(embedding, fastmul_ring.pow(root_of_unity, BENCH_N2), fastmul_ring_as_field.get_ring().unwrap_element(algorithms::unity_root::get_prim_root_of_unity_pow2(&fastmul_ring_as_field, 6).unwrap()), BENCH_N1, 6, Global),
375 );
376 let data = (0..(BENCH_N1 * BENCH_N2)).map(|i| ring.int_hom().map(i as i32)).collect::<Vec<_>>();
377 let mut copy = Vec::with_capacity(BENCH_N1 * BENCH_N2);
378 bencher.iter(|| {
379 copy.clear();
380 copy.extend(data.iter().map(|x| ring.clone_el(x)));
381 fft.unordered_fft(&mut copy[..], &ring);
382 fft.unordered_inv_fft(&mut copy[..], &ring);
383 assert_el_eq!(ring, copy[0], data[0]);
384 });
385}