1use std::alloc::{Allocator, Global};
2use std::cmp::max;
3use std::marker::PhantomData;
4
5use super::ConvolutionAlgorithm;
6use crate::algorithms::fft::FFTAlgorithm;
7use crate::algorithms::fft::complex_fft::FFTErrorEstimate;
8use crate::algorithms::fft::cooley_tuckey::CooleyTuckeyFFT;
9use crate::cow::*;
10use crate::homomorphism::*;
11use crate::integer::*;
12use crate::lazy::LazyVec;
13use crate::primitive_int::{StaticRingBase, *};
14use crate::ring::*;
15use crate::rings::float_complex::*;
16use crate::rings::zn::*;
17use crate::seq::*;
18
/// Shorthand for the `f64`-based complex number ring in which all FFTs of this module are computed.
const CC: Complex64 = Complex64::RING;
20
/// Convolution algorithm that multiplies sequences by means of complex floating-point FFTs.
///
/// The FFT tables are created on first use and cached per transform length.
#[stability::unstable(feature = "enable")]
pub struct FFTConvolution<A = Global> {
    /// Allocator used for the FFT buffers created by this object.
    allocator: A,
    /// Lazily created FFT tables, indexed by the base-2 logarithm of the transform length.
    fft_tables: LazyVec<CooleyTuckeyFFT<Complex64Base, Complex64Base, Identity<Complex64>>>,
}
26
/// Cache attached to a convolution operand, storing its FFTs so that they can be
/// reused across multiple convolutions.
#[stability::unstable(feature = "enable")]
pub struct PreparedConvolutionOperand<R, A = Global>
where
    R: ?Sized + RingBase,
    A: Allocator + Clone,
{
    /// Ties the operand to the ring type `R`; no ring data is stored.
    ring: PhantomData<Box<R>>,
    /// Lazily computed transforms of the operand, indexed by the base-2 logarithm of the FFT length.
    fft_data: LazyVec<Vec<El<Complex64>, A>>,
    /// Base-2 logarithm of the size bound on the operand's integer entries that was
    /// determined when the operand was prepared.
    log2_data_size: usize,
}
37
38impl FFTConvolution<Global> {
39 #[stability::unstable(feature = "enable")]
40 pub fn new() -> Self { Self::new_with_alloc(Global) }
41}
42
43impl<A> FFTConvolution<A>
44where
45 A: Allocator + Clone,
46{
47 #[stability::unstable(feature = "enable")]
48 pub fn new_with_alloc(allocator: A) -> Self {
49 Self {
50 allocator,
51 fft_tables: LazyVec::new(),
52 }
53 }
54
55 fn get_fft_table(&self, log2_len: usize) -> &CooleyTuckeyFFT<Complex64Base, Complex64Base, Identity<Complex64>> {
56 return self
57 .fft_tables
58 .get_or_init(log2_len, || CooleyTuckeyFFT::for_complex(CC, log2_len));
59 }
60
61 fn get_fft_data<'a, R, V, ToInt>(
62 &self,
63 data: V,
64 data_prep: Option<&'a PreparedConvolutionOperand<R, A>>,
65 _ring: &R,
66 log2_len: usize,
67 mut to_int: ToInt,
68 log2_el_size: Option<usize>,
69 ) -> MyCow<'a, Vec<El<Complex64>, A>>
70 where
71 R: ?Sized + RingBase,
72 V: VectorView<R::Element>,
73 ToInt: FnMut(&R::Element) -> i64,
74 {
75 let log2_data_size = if let Some(log2_data_size) = log2_el_size {
76 if let Some(data_prep) = data_prep {
77 assert_eq!(log2_data_size, data_prep.log2_data_size);
78 }
79 log2_data_size
80 } else {
81 data.as_iter()
82 .map(|x| StaticRing::<i64>::RING.abs_log2_ceil(&to_int(x)).unwrap_or(0))
83 .max()
84 .unwrap()
85 };
86 assert!(data.len() <= (1 << log2_len));
87 assert!(self.has_sufficient_precision(log2_len, log2_data_size));
88
89 let mut compute_result = || {
90 let mut result = Vec::with_capacity_in(1 << log2_len, self.allocator.clone());
91 result.extend(data.as_iter().map(|x| Complex64::RING.from_f64(to_int(x) as f64)));
92 result.resize(1 << log2_len, Complex64::RING.zero());
93 self.get_fft_table(log2_len).unordered_fft(&mut result, Complex64::RING);
94 return result;
95 };
96
97 return if let Some(data_prep) = data_prep {
98 MyCow::Borrowed(data_prep.fft_data.get_or_init(log2_len, compute_result))
99 } else {
100 MyCow::Owned(compute_result())
101 };
102 }
103
104 #[stability::unstable(feature = "enable")]
105 pub fn has_sufficient_precision(&self, log2_len: usize, log2_input_size: usize) -> bool {
106 self.max_sum_len(log2_len, log2_input_size) > 0
107 }
108
109 fn max_sum_len(&self, log2_len: usize, log2_input_size: usize) -> usize {
110 let fft_table = self.get_fft_table(log2_len);
111 let input_size = 2.0f64.powi(log2_input_size.try_into().unwrap());
112 (0.5 / fft_table.expected_absolute_error(
113 input_size * input_size,
114 input_size * input_size * f64::EPSILON + fft_table.expected_absolute_error(input_size, 0.0),
115 ))
116 .floor() as usize
117 }
118
119 fn prepare_convolution_impl<R, V, ToInt>(
120 &self,
121 data: V,
122 ring: &R,
123 length_hint: Option<usize>,
124 mut to_int: ToInt,
125 ring_log2_el_size: Option<usize>,
126 ) -> PreparedConvolutionOperand<R, A>
127 where
128 R: ?Sized + RingBase,
129 V: VectorView<R::Element>,
130 ToInt: FnMut(&R::Element) -> i64,
131 {
132 let log2_data_size = if let Some(log2_data_size) = ring_log2_el_size {
133 log2_data_size
134 } else {
135 data.as_iter()
136 .map(|x| StaticRing::<i64>::RING.abs_log2_ceil(&to_int(x)).unwrap_or(0))
137 .max()
138 .unwrap()
139 };
140 let result = PreparedConvolutionOperand {
141 fft_data: LazyVec::new(),
142 ring: PhantomData,
143 log2_data_size,
144 };
145 if let Some(len) = length_hint {
149 let log2_len = StaticRing::<i64>::RING.abs_log2_ceil(&len.try_into().unwrap()).unwrap();
150 _ = self.get_fft_data(data, Some(&result), ring, log2_len, to_int, ring_log2_el_size);
151 }
152 return result;
153 }
154
155 fn compute_convolution_impl<R, V1, V2, ToInt, FromInt>(
156 &self,
157 lhs: V1,
158 lhs_prep: Option<&PreparedConvolutionOperand<R, A>>,
159 rhs: V2,
160 rhs_prep: Option<&PreparedConvolutionOperand<R, A>>,
161 dst: &mut [R::Element],
162 ring: &R,
163 mut to_int: ToInt,
164 mut from_int: FromInt,
165 ring_log2_el_size: Option<usize>,
166 ) where
167 R: ?Sized + RingBase,
168 V1: VectorView<R::Element>,
169 V2: VectorView<R::Element>,
170 ToInt: FnMut(&R::Element) -> i64,
171 FromInt: FnMut(i64) -> R::Element,
172 {
173 if lhs.len() == 0 || rhs.len() == 0 {
174 return;
175 }
176 let len = lhs.len() + rhs.len() - 1;
177 assert!(dst.len() >= len);
178 let log2_len = StaticRing::<i64>::RING.abs_log2_ceil(&len.try_into().unwrap()).unwrap();
179
180 let mut lhs_fft = self.get_fft_data(lhs, lhs_prep, ring, log2_len, &mut to_int, ring_log2_el_size);
181 let mut rhs_fft = self.get_fft_data(rhs, rhs_prep, ring, log2_len, &mut to_int, ring_log2_el_size);
182 if rhs_fft.is_owned() {
183 std::mem::swap(&mut lhs_fft, &mut rhs_fft);
184 }
185 let lhs_fft: &mut Vec<El<Complex64>, A> = lhs_fft.to_mut();
186
187 for i in 0..(1 << log2_len) {
188 CC.mul_assign(&mut lhs_fft[i], rhs_fft[i]);
189 }
190
191 self.get_fft_table(log2_len).unordered_inv_fft(&mut *lhs_fft, CC);
192
193 for i in 0..len {
194 let result = CC.closest_gaussian_int(lhs_fft[i]);
195 debug_assert_eq!(0, result.1);
196 ring.add_assign(&mut dst[i], from_int(result.0));
197 }
198 }
199
200 fn compute_convolution_sum_impl<'a, R, I, V1, V2, ToInt, FromInt>(
201 &self,
202 data: I,
203 dst: &mut [R::Element],
204 ring: &R,
205 mut to_int: ToInt,
206 mut from_int: FromInt,
207 ring_log2_el_size: Option<usize>,
208 ) where
209 R: ?Sized + RingBase,
210 I: Iterator<
211 Item = (
212 V1,
213 Option<&'a PreparedConvolutionOperand<R, A>>,
214 V2,
215 Option<&'a PreparedConvolutionOperand<R, A>>,
216 ),
217 >,
218 V1: VectorView<R::Element>,
219 V2: VectorView<R::Element>,
220 ToInt: FnMut(&R::Element) -> i64,
221 FromInt: FnMut(i64) -> R::Element,
222 Self: 'a,
223 R: 'a,
224 {
225 let len = dst.len();
226 let log2_len = StaticRing::<i64>::RING.abs_log2_ceil(&len.try_into().unwrap()).unwrap();
227 let mut buffer = Vec::with_capacity_in(1 << log2_len, self.allocator.clone());
228 buffer.resize(1 << log2_len, CC.zero());
229
230 let mut count_since_last_reduction = 0;
231 let mut current_max_sum_len = usize::MAX;
232 let mut current_log2_data_size = ring_log2_el_size.unwrap_or_default();
233 for (lhs, lhs_prep, rhs, rhs_prep) in data {
234 if lhs.len() == 0 || rhs.len() == 0 {
235 continue;
236 }
237 assert!(lhs.len() + rhs.len() - 1 <= dst.len());
238
239 if ring_log2_el_size.is_none() {
240 current_log2_data_size = max(
241 current_log2_data_size,
242 lhs.as_iter()
243 .chain(rhs.as_iter())
244 .map(|x| StaticRing::<i64>::RING.abs_log2_ceil(&to_int(x)).unwrap_or(0))
245 .max()
246 .unwrap(),
247 );
248 current_max_sum_len = self.max_sum_len(log2_len, current_log2_data_size);
249 }
250 assert!(current_max_sum_len > 0);
251
252 if count_since_last_reduction > current_max_sum_len {
253 count_since_last_reduction = 0;
254 self.get_fft_table(log2_len).unordered_inv_fft(&mut *buffer, CC);
255 for i in 0..len {
256 let result = CC.closest_gaussian_int(buffer[i]);
257 debug_assert_eq!(0, result.1);
258 ring.add_assign(&mut dst[i], from_int(result.0));
259 }
260 for i in 0..(1 << log2_len) {
261 buffer[i] = CC.zero();
262 }
263 }
264
265 let mut lhs_fft = self.get_fft_data(lhs, lhs_prep, ring, log2_len, &mut to_int, ring_log2_el_size);
266 let mut rhs_fft = self.get_fft_data(rhs, rhs_prep, ring, log2_len, &mut to_int, ring_log2_el_size);
267 if rhs_fft.is_owned() {
268 std::mem::swap(&mut lhs_fft, &mut rhs_fft);
269 }
270 let lhs_fft: &mut Vec<El<Complex64>, A> = lhs_fft.to_mut();
271 for i in 0..(1 << log2_len) {
272 CC.mul_assign(&mut lhs_fft[i], rhs_fft[i]);
273 CC.add_assign(&mut buffer[i], lhs_fft[i]);
274 }
275 count_since_last_reduction += 1;
276 }
277 self.get_fft_table(log2_len).unordered_inv_fft(&mut *buffer, CC);
278 for i in 0..len {
279 let result = CC.closest_gaussian_int(buffer[i]);
280 debug_assert_eq!(0, result.1);
281 ring.add_assign(&mut dst[i], from_int(result.0));
282 }
283 }
284}
285
286fn to_int_int<I>(ring: I) -> impl use<I> + Fn(&El<I>) -> i64
287where
288 I: RingStore,
289 I::Type: IntegerRing,
290{
291 move |x| int_cast(ring.clone_el(x), StaticRing::<i64>::RING, &ring)
292}
293
294fn from_int_int<I>(ring: I) -> impl use<I> + Fn(i64) -> El<I>
295where
296 I: RingStore,
297 I::Type: IntegerRing,
298{
299 move |x| int_cast(x, &ring, StaticRing::<i64>::RING)
300}
301
302fn to_int_zn<R>(ring: R) -> impl use<R> + Fn(&El<R>) -> i64
303where
304 R: RingStore,
305 R::Type: ZnRing,
306{
307 move |x| {
308 int_cast(
309 ring.smallest_lift(ring.clone_el(x)),
310 StaticRing::<i64>::RING,
311 ring.integer_ring(),
312 )
313 }
314}
315
316fn from_int_zn<R>(ring: R) -> impl use<R> + Fn(i64) -> El<R>
317where
318 R: RingStore,
319 R::Type: ZnRing,
320{
321 let hom = ring.can_hom(ring.integer_ring()).unwrap().into_raw_hom();
322 move |x| {
323 ring.get_ring().map_in(
324 ring.integer_ring().get_ring(),
325 int_cast(x, ring.integer_ring(), StaticRing::<i64>::RING),
326 &hom,
327 )
328 }
329}
330
331impl<A> Clone for FFTConvolution<A>
332where
333 A: Allocator + Clone,
334{
335 fn clone(&self) -> Self {
336 Self {
337 allocator: self.allocator.clone(),
338 fft_tables: self.fft_tables.clone(),
339 }
340 }
341}
342
343impl<A> From<FFTConvolutionZn<A>> for FFTConvolution<A>
344where
345 A: Allocator,
346{
347 fn from(value: FFTConvolutionZn<A>) -> Self { value.base }
348}
349
350impl<'a, A> From<&'a FFTConvolutionZn<A>> for &'a FFTConvolution<A>
351where
352 A: Allocator,
353{
354 fn from(value: &'a FFTConvolutionZn<A>) -> Self { &value.base }
355}
356
357impl<A> From<FFTConvolution<A>> for FFTConvolutionZn<A>
358where
359 A: Allocator,
360{
361 fn from(value: FFTConvolution<A>) -> Self { FFTConvolutionZn { base: value } }
362}
363
364impl<'a, A> From<&'a FFTConvolution<A>> for &'a FFTConvolutionZn<A>
365where
366 A: Allocator,
367{
368 fn from(value: &'a FFTConvolution<A>) -> Self { unsafe { std::mem::transmute(value) } }
369}
370
/// Wrapper around [`FFTConvolution`] that implements [`ConvolutionAlgorithm`] for rings `Z/nZ`.
///
/// The type is `#[repr(transparent)]`; the conversion from `&FFTConvolution` to
/// `&FFTConvolutionZn` in this file relies on that layout guarantee.
#[stability::unstable(feature = "enable")]
#[repr(transparent)]
pub struct FFTConvolutionZn<A = Global> {
    /// The wrapped convolution object that performs the actual FFT computations.
    base: FFTConvolution<A>,
}
376
377impl<A> Clone for FFTConvolutionZn<A>
378where
379 A: Allocator + Clone,
380{
381 fn clone(&self) -> Self {
382 Self {
383 base: self.base.clone(),
384 }
385 }
386}
387
388impl<R, A> ConvolutionAlgorithm<R> for FFTConvolutionZn<A>
389where
390 R: ?Sized + ZnRing + CanHomFrom<StaticRingBase<i64>>,
391 A: Allocator + Clone,
392{
393 type PreparedConvolutionOperand = PreparedConvolutionOperand<R, A>;
394
395 fn compute_convolution<S: RingStore<Type = R>, V1: VectorView<R::Element>, V2: VectorView<R::Element>>(
396 &self,
397 lhs: V1,
398 rhs: V2,
399 dst: &mut [R::Element],
400 ring: S,
401 ) {
402 self.base.compute_convolution_impl(
403 lhs,
404 None,
405 rhs,
406 None,
407 dst,
408 ring.get_ring(),
409 to_int_zn(&ring),
410 from_int_zn(&ring),
411 Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap()),
412 )
413 }
414
415 fn supports_ring<S: RingStore<Type = R> + Copy>(&self, _ring: S) -> bool { true }
416
417 fn prepare_convolution_operand<S, V>(
418 &self,
419 val: V,
420 len_hint: Option<usize>,
421 ring: S,
422 ) -> Self::PreparedConvolutionOperand
423 where
424 S: RingStore<Type = R> + Copy,
425 V: VectorView<R::Element>,
426 {
427 self.base.prepare_convolution_impl(
428 val,
429 ring.get_ring(),
430 len_hint,
431 to_int_zn(&ring),
432 Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap()),
433 )
434 }
435
436 fn compute_convolution_prepared<S, V1, V2>(
437 &self,
438 lhs: V1,
439 lhs_prep: Option<&Self::PreparedConvolutionOperand>,
440 rhs: V2,
441 rhs_prep: Option<&Self::PreparedConvolutionOperand>,
442 dst: &mut [R::Element],
443 ring: S,
444 ) where
445 S: RingStore<Type = R> + Copy,
446 V1: VectorView<R::Element>,
447 V2: VectorView<R::Element>,
448 {
449 self.base.compute_convolution_impl(
450 lhs,
451 lhs_prep,
452 rhs,
453 rhs_prep,
454 dst,
455 ring.get_ring(),
456 to_int_zn(&ring),
457 from_int_zn(&ring),
458 Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap()),
459 )
460 }
461
462 fn compute_convolution_sum<'a, S, I, V1, V2>(&self, values: I, dst: &mut [R::Element], ring: S)
463 where
464 S: RingStore<Type = R> + Copy,
465 I: Iterator<
466 Item = (
467 V1,
468 Option<&'a Self::PreparedConvolutionOperand>,
469 V2,
470 Option<&'a Self::PreparedConvolutionOperand>,
471 ),
472 >,
473 V1: VectorView<R::Element>,
474 V2: VectorView<R::Element>,
475 Self: 'a,
476 R: 'a,
477 Self::PreparedConvolutionOperand: 'a,
478 {
479 self.base.compute_convolution_sum_impl(
480 values,
481 dst,
482 ring.get_ring(),
483 to_int_zn(&ring),
484 from_int_zn(&ring),
485 Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap()),
486 )
487 }
488}
489
490impl<I, A> ConvolutionAlgorithm<I> for FFTConvolution<A>
491where
492 I: ?Sized + IntegerRing,
493 A: Allocator + Clone,
494{
495 type PreparedConvolutionOperand = PreparedConvolutionOperand<I, A>;
496
497 fn compute_convolution<S: RingStore<Type = I>, V1: VectorView<I::Element>, V2: VectorView<I::Element>>(
498 &self,
499 lhs: V1,
500 rhs: V2,
501 dst: &mut [I::Element],
502 ring: S,
503 ) {
504 self.compute_convolution_impl(
505 lhs,
506 None,
507 rhs,
508 None,
509 dst,
510 ring.get_ring(),
511 to_int_int(&ring),
512 from_int_int(&ring),
513 None,
514 )
515 }
516
517 fn supports_ring<S: RingStore<Type = I> + Copy>(&self, _ring: S) -> bool { true }
518
519 fn prepare_convolution_operand<S, V>(
520 &self,
521 val: V,
522 len_hint: Option<usize>,
523 ring: S,
524 ) -> Self::PreparedConvolutionOperand
525 where
526 S: RingStore<Type = I> + Copy,
527 V: VectorView<I::Element>,
528 {
529 self.prepare_convolution_impl(val, ring.get_ring(), len_hint, to_int_int(&ring), None)
530 }
531
532 fn compute_convolution_prepared<S, V1, V2>(
533 &self,
534 lhs: V1,
535 lhs_prep: Option<&Self::PreparedConvolutionOperand>,
536 rhs: V2,
537 rhs_prep: Option<&Self::PreparedConvolutionOperand>,
538 dst: &mut [I::Element],
539 ring: S,
540 ) where
541 S: RingStore<Type = I> + Copy,
542 V1: VectorView<I::Element>,
543 V2: VectorView<I::Element>,
544 {
545 self.compute_convolution_impl(
546 lhs,
547 lhs_prep,
548 rhs,
549 rhs_prep,
550 dst,
551 ring.get_ring(),
552 to_int_int(&ring),
553 from_int_int(&ring),
554 None,
555 )
556 }
557
558 fn compute_convolution_sum<'a, S, J, V1, V2>(&self, values: J, dst: &mut [I::Element], ring: S)
559 where
560 S: RingStore<Type = I> + Copy,
561 J: Iterator<
562 Item = (
563 V1,
564 Option<&'a Self::PreparedConvolutionOperand>,
565 V2,
566 Option<&'a Self::PreparedConvolutionOperand>,
567 ),
568 >,
569 V1: VectorView<I::Element>,
570 V2: VectorView<I::Element>,
571 Self: 'a,
572 I: 'a,
573 Self::PreparedConvolutionOperand: 'a,
574 {
575 self.compute_convolution_sum_impl(
576 values,
577 dst,
578 ring.get_ring(),
579 to_int_int(&ring),
580 from_int_int(&ring),
581 None,
582 )
583 }
584}
585
586#[cfg(test)]
587use crate::rings::finite::FiniteRingStore;
588#[cfg(test)]
589use crate::rings::zn::zn_64::Zn;
590
/// Runs the generic convolution test suite over the ring `Z/(17 * 257)Z`.
#[test]
fn test_convolution_zn() {
    let ring = Zn::new(17 * 257);
    let algorithm: FFTConvolutionZn = FFTConvolution::new().into();

    super::generic_tests::test_convolution(&algorithm, &ring, ring.one());
}
598
/// Runs the generic convolution test suite over the integers, represented as `i64`.
#[test]
fn test_convolution_int() {
    let ring = StaticRing::<i64>::RING;
    let algorithm: FFTConvolution = FFTConvolution::new();

    super::generic_tests::test_convolution(&algorithm, &ring, ring.one());
}
606
/// Convolving two length-1024 operands modulo a large modulus must be rejected with a panic
/// that mentions the precision, instead of silently producing a wrong result.
#[test]
#[should_panic(expected = "precision")]
fn test_fft_convolution_not_enough_precision() {
    let ring = Zn::new(1099511627791);
    let algorithm: FFTConvolutionZn = FFTConvolution::new().into();

    let lhs: Vec<_> = ring.elements().take(1024).collect();
    let rhs: Vec<_> = ring.elements().take(1024).collect();
    let mut actual: Vec<_> = std::iter::repeat_with(|| ring.zero())
        .take(lhs.len() + rhs.len())
        .collect();

    algorithm.compute_convolution(&lhs, &rhs, &mut actual, &ring);
}