1use std::cmp::max;
2use std::alloc::{Allocator, Global};
3use std::marker::PhantomData;
4
5use crate::cow::*;
6use crate::algorithms::fft::complex_fft::FFTErrorEstimate;
7use crate::algorithms::fft::cooley_tuckey::CooleyTuckeyFFT;
8use crate::algorithms::fft::FFTAlgorithm;
9use crate::lazy::LazyVec;
10use crate::primitive_int::StaticRingBase;
11use crate::integer::*;
12use crate::ring::*;
13use crate::seq::*;
14use crate::primitive_int::*;
15use crate::homomorphism::*;
16use crate::rings::float_complex::*;
17use crate::rings::zn::*;
18
19use super::ConvolutionAlgorithm;
20
/// Shorthand for the complex number ring `Complex64::RING`, which is passed to all
/// FFT and complex-arithmetic calls in this file.
const CC: Complex64 = Complex64::RING;
22
/// Convolution algorithm that works with floating-point FFTs over `Complex64`.
///
/// The FFT tables it needs are created lazily and cached, see `fft_tables`.
#[stability::unstable(feature = "enable")]
pub struct FFTConvolution<A = Global> {
    /// Allocator from which the vectors holding FFT data are allocated.
    allocator: A,
    /// Lazily created FFT tables; the entry at index `k` is the table created for `log2_len = k`.
    fft_tables: LazyVec<CooleyTuckeyFFT<Complex64Base, Complex64Base, Identity<Complex64>>>
}
28
/// A convolution operand together with a cache of its FFT transforms, so that the
/// transforms can be reused when the operand takes part in several convolutions.
#[stability::unstable(feature = "enable")]
pub struct PreparedConvolutionOperand<R, A = Global>
    where R: ?Sized + RingBase,
        A: Allocator + Clone
{
    /// Ties the operand to the ring type `R`; no ring data is stored.
    ring: PhantomData<Box<R>>,
    /// Lazily computed transforms; the entry at index `k` is the transform computed for `log2_len = k`.
    fft_data: LazyVec<Vec<El<Complex64>, A>>,
    /// Base-2 logarithm of the size bound for the operand's integer representatives,
    /// as determined when the operand was prepared.
    log2_data_size: usize
}
38
39impl FFTConvolution<Global> {
40
41 #[stability::unstable(feature = "enable")]
42 pub fn new() -> Self {
43 Self::new_with_alloc(Global)
44 }
45}
46
47impl<A> FFTConvolution<A>
48 where A: Allocator + Clone
49{
50 #[stability::unstable(feature = "enable")]
51 pub fn new_with_alloc(allocator: A) -> Self {
52 Self {
53 allocator: allocator,
54 fft_tables: LazyVec::new()
55 }
56 }
57
58 fn get_fft_table(&self, log2_len: usize) -> &CooleyTuckeyFFT<Complex64Base, Complex64Base, Identity<Complex64>> {
59 return self.fft_tables.get_or_init(log2_len, || CooleyTuckeyFFT::for_complex(CC, log2_len));
60 }
61
62 fn get_fft_data<'a, R, V, ToInt>(
63 &self,
64 data: V,
65 data_prep: Option<&'a PreparedConvolutionOperand<R, A>>,
66 _ring: &R,
67 log2_len: usize,
68 mut to_int: ToInt,
69 log2_el_size: Option<usize>
70 ) -> MyCow<'a, Vec<El<Complex64>, A>>
71 where R: ?Sized + RingBase,
72 V: VectorView<R::Element>,
73 ToInt: FnMut(&R::Element) -> i64
74 {
75 let log2_data_size = if let Some(log2_data_size) = log2_el_size {
76 if let Some(data_prep) = data_prep {
77 assert_eq!(log2_data_size, data_prep.log2_data_size);
78 }
79 log2_data_size
80 } else {
81 data.as_iter().map(|x| StaticRing::<i64>::RING.abs_log2_ceil(&to_int(x)).unwrap_or(0)).max().unwrap()
82 };
83 assert!(data.len() <= (1 << log2_len));
84 assert!(self.has_sufficient_precision(log2_len, log2_data_size));
85
86 let mut compute_result = || {
87 let mut result = Vec::with_capacity_in(1 << log2_len, self.allocator.clone());
88 result.extend(data.as_iter().map(|x| Complex64::RING.from_f64(to_int(x) as f64)));
89 result.resize(1 << log2_len, Complex64::RING.zero());
90 self.get_fft_table(log2_len).unordered_fft(&mut result, Complex64::RING);
91 return result;
92 };
93
94 return if let Some(data_prep) = data_prep {
95 MyCow::Borrowed(data_prep.fft_data.get_or_init(log2_len, compute_result))
96 } else {
97 MyCow::Owned(compute_result())
98 }
99 }
100
101 #[stability::unstable(feature = "enable")]
102 pub fn has_sufficient_precision(&self, log2_len: usize, log2_input_size: usize) -> bool {
103 self.max_sum_len(log2_len, log2_input_size) > 0
104 }
105
106 fn max_sum_len(&self, log2_len: usize, log2_input_size: usize) -> usize {
107 let fft_table = self.get_fft_table(log2_len);
108 let input_size = 2f64.powi(log2_input_size.try_into().unwrap());
109 (0.5 / fft_table.expected_absolute_error(input_size * input_size, input_size * input_size * f64::EPSILON + fft_table.expected_absolute_error(input_size, 0.))).floor() as usize
110 }
111
112 fn prepare_convolution_impl<R, V, ToInt>(
113 &self,
114 data: V,
115 ring: &R,
116 length_hint: Option<usize>,
117 mut to_int: ToInt,
118 ring_log2_el_size: Option<usize>
119 ) -> PreparedConvolutionOperand<R, A>
120 where R: ?Sized + RingBase,
121 V: VectorView<R::Element>,
122 ToInt: FnMut(&R::Element) -> i64
123 {
124 let log2_data_size = if let Some(log2_data_size) = ring_log2_el_size {
125 log2_data_size
126 } else {
127 data.as_iter().map(|x| StaticRing::<i64>::RING.abs_log2_ceil(&to_int(x)).unwrap_or(0)).max().unwrap()
128 };
129 let result = PreparedConvolutionOperand {
130 fft_data: LazyVec::new(),
131 ring: PhantomData,
132 log2_data_size: log2_data_size
133 };
134 if let Some(len) = length_hint {
138 let log2_len = StaticRing::<i64>::RING.abs_log2_ceil(&len.try_into().unwrap()).unwrap();
139 _ = self.get_fft_data(data, Some(&result), ring, log2_len, to_int, ring_log2_el_size);
140 }
141 return result;
142 }
143
144 fn compute_convolution_impl<R, V1, V2, ToInt, FromInt>(
145 &self,
146 lhs: V1,
147 lhs_prep: Option<&PreparedConvolutionOperand<R, A>>,
148 rhs: V2,
149 rhs_prep: Option<&PreparedConvolutionOperand<R, A>>,
150 dst: &mut [R::Element],
151 ring: &R,
152 mut to_int: ToInt,
153 mut from_int: FromInt,
154 ring_log2_el_size: Option<usize>
155 )
156 where R: ?Sized + RingBase,
157 V1: VectorView<R::Element>,
158 V2: VectorView<R::Element>,
159 ToInt: FnMut(&R::Element) -> i64,
160 FromInt: FnMut(i64) -> R::Element
161 {
162 if lhs.len() == 0 || rhs.len() == 0 {
163 return;
164 }
165 let len = lhs.len() + rhs.len() - 1;
166 assert!(dst.len() >= len);
167 let log2_len = StaticRing::<i64>::RING.abs_log2_ceil(&len.try_into().unwrap()).unwrap();
168
169 let mut lhs_fft = self.get_fft_data(lhs, lhs_prep, ring, log2_len, &mut to_int, ring_log2_el_size);
170 let mut rhs_fft = self.get_fft_data(rhs, rhs_prep, ring, log2_len, &mut to_int, ring_log2_el_size);
171 if rhs_fft.is_owned() {
172 std::mem::swap(&mut lhs_fft, &mut rhs_fft);
173 }
174 let lhs_fft: &mut Vec<El<Complex64>, A> = lhs_fft.to_mut();
175
176 for i in 0..(1 << log2_len) {
177 CC.mul_assign(&mut lhs_fft[i], rhs_fft[i]);
178 }
179
180 self.get_fft_table(log2_len).unordered_inv_fft(&mut *lhs_fft, CC);
181
182 for i in 0..len {
183 let result = CC.closest_gaussian_int(lhs_fft[i]);
184 debug_assert_eq!(0, result.1);
185 ring.add_assign(&mut dst[i], from_int(result.0));
186 }
187 }
188
189 fn compute_convolution_sum_impl<'a, R, I, V1, V2, ToInt, FromInt>(
190 &self,
191 data: I,
192 dst: &mut [R::Element],
193 ring: &R,
194 mut to_int: ToInt,
195 mut from_int: FromInt,
196 ring_log2_el_size: Option<usize>
197 )
198 where R: ?Sized + RingBase,
199 I: Iterator<Item = (V1, Option<&'a PreparedConvolutionOperand<R, A>>, V2, Option<&'a PreparedConvolutionOperand<R, A>>)>,
200 V1: VectorView<R::Element>,
201 V2: VectorView<R::Element>,
202 ToInt: FnMut(&R::Element) -> i64,
203 FromInt: FnMut(i64) -> R::Element,
204 Self: 'a,
205 R: 'a
206 {
207 let len = dst.len();
208 let log2_len = StaticRing::<i64>::RING.abs_log2_ceil(&len.try_into().unwrap()).unwrap();
209 let mut buffer = Vec::with_capacity_in(1 << log2_len, self.allocator.clone());
210 buffer.resize(1 << log2_len, CC.zero());
211
212 let mut count_since_last_reduction = 0;
213 let mut current_max_sum_len = usize::MAX;
214 let mut current_log2_data_size = if let Some(log2_data_size) = ring_log2_el_size {
215 log2_data_size
216 } else {
217 0
218 };
219 for (lhs, lhs_prep, rhs, rhs_prep) in data {
220 if lhs.len() == 0 || rhs.len() == 0 {
221 continue;
222 }
223 assert!(lhs.len() + rhs.len() - 1 <= dst.len());
224
225 if ring_log2_el_size.is_none() {
226 current_log2_data_size = max(
227 current_log2_data_size,
228 lhs.as_iter().chain(rhs.as_iter()).map(|x| StaticRing::<i64>::RING.abs_log2_ceil(&to_int(x)).unwrap_or(0)).max().unwrap()
229 );
230 current_max_sum_len = self.max_sum_len(log2_len, current_log2_data_size);
231 }
232 assert!(current_max_sum_len > 0);
233
234 if count_since_last_reduction > current_max_sum_len {
235 count_since_last_reduction = 0;
236 self.get_fft_table(log2_len).unordered_inv_fft(&mut *buffer, CC);
237 for i in 0..len {
238 let result = CC.closest_gaussian_int(buffer[i]);
239 debug_assert_eq!(0, result.1);
240 ring.add_assign(&mut dst[i], from_int(result.0));
241 }
242 for i in 0..(1 << log2_len) {
243 buffer[i] = CC.zero();
244 }
245 }
246
247 let mut lhs_fft = self.get_fft_data(lhs, lhs_prep, ring, log2_len, &mut to_int, ring_log2_el_size);
248 let mut rhs_fft = self.get_fft_data(rhs, rhs_prep, ring, log2_len, &mut to_int, ring_log2_el_size);
249 if rhs_fft.is_owned() {
250 std::mem::swap(&mut lhs_fft, &mut rhs_fft);
251 }
252 let lhs_fft: &mut Vec<El<Complex64>, A> = lhs_fft.to_mut();
253 for i in 0..(1 << log2_len) {
254 CC.mul_assign(&mut lhs_fft[i], rhs_fft[i]);
255 CC.add_assign(&mut buffer[i], lhs_fft[i]);
256 }
257 count_since_last_reduction += 1;
258 }
259 self.get_fft_table(log2_len).unordered_inv_fft(&mut *buffer, CC);
260 for i in 0..len {
261 let result = CC.closest_gaussian_int(buffer[i]);
262 debug_assert_eq!(0, result.1);
263 ring.add_assign(&mut dst[i], from_int(result.0));
264 }
265 }
266}
267
268fn to_int_int<I>(ring: I) -> impl use<I> + Fn(&El<I>) -> i64
269 where I: RingStore, I::Type: IntegerRing
270{
271 move |x| int_cast(ring.clone_el(x), StaticRing::<i64>::RING, &ring)
272}
273
274fn from_int_int<I>(ring: I) -> impl use<I> + Fn(i64) -> El<I>
275 where I: RingStore, I::Type: IntegerRing
276{
277 move |x| int_cast(x, &ring, StaticRing::<i64>::RING)
278}
279
280fn to_int_zn<R>(ring: R) -> impl use<R> + Fn(&El<R>) -> i64
281 where R: RingStore, R::Type: ZnRing
282{
283 move |x| int_cast(ring.smallest_lift(ring.clone_el(x)), StaticRing::<i64>::RING, ring.integer_ring())
284}
285
286fn from_int_zn<R>(ring: R) -> impl use<R> + Fn(i64) -> El<R>
287 where R: RingStore, R::Type: ZnRing
288{
289 let hom = ring.can_hom(ring.integer_ring()).unwrap().into_raw_hom();
290 move |x| ring.get_ring().map_in(ring.integer_ring().get_ring(), int_cast(x, ring.integer_ring(), StaticRing::<i64>::RING), &hom)
291}
292
293impl<A> Clone for FFTConvolution<A>
294 where A: Allocator + Clone
295{
296 fn clone(&self) -> Self {
297 Self {
298 allocator: self.allocator.clone(),
299 fft_tables: self.fft_tables.clone()
300 }
301 }
302}
303
304impl<A> From<FFTConvolutionZn<A>> for FFTConvolution<A>
305 where A: Allocator
306{
307 fn from(value: FFTConvolutionZn<A>) -> Self {
308 value.base
309 }
310}
311
312impl<'a, A> From<&'a FFTConvolutionZn<A>> for &'a FFTConvolution<A>
313 where A: Allocator
314{
315 fn from(value: &'a FFTConvolutionZn<A>) -> Self {
316 &value.base
317 }
318}
319
320impl<A> From<FFTConvolution<A>> for FFTConvolutionZn<A>
321 where A: Allocator
322{
323 fn from(value: FFTConvolution<A>) -> Self {
324 FFTConvolutionZn { base: value }
325 }
326}
327
328impl<'a, A> From<&'a FFTConvolution<A>> for &'a FFTConvolutionZn<A>
329 where A: Allocator
330{
331 fn from(value: &'a FFTConvolution<A>) -> Self {
332 unsafe { std::mem::transmute(value) }
333 }
334}
335
/// Wrapper around [`FFTConvolution`] that implements `ConvolutionAlgorithm` for
/// rings implementing `ZnRing`.
///
/// The type is `#[repr(transparent)]`; the conversion from `&FFTConvolution` to
/// `&FFTConvolutionZn` in this file relies on that.
#[stability::unstable(feature = "enable")]
#[repr(transparent)]
pub struct FFTConvolutionZn<A = Global> {
    /// The wrapped convolution, which performs all computations.
    base: FFTConvolution<A>
}
341
342impl<A> Clone for FFTConvolutionZn<A>
343 where A: Allocator + Clone
344{
345 fn clone(&self) -> Self {
346 Self { base: self.base.clone() }
347 }
348}
349
350impl<R, A> ConvolutionAlgorithm<R> for FFTConvolutionZn<A>
351 where R: ?Sized + ZnRing + CanHomFrom<StaticRingBase<i64>>,
352 A: Allocator + Clone
353{
354 type PreparedConvolutionOperand = PreparedConvolutionOperand<R, A>;
355
356 fn compute_convolution<S: RingStore<Type = R>, V1: VectorView<R::Element>, V2: VectorView<R::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [R::Element], ring: S) {
357 self.base.compute_convolution_impl(
358 lhs,
359 None,
360 rhs,
361 None,
362 dst,
363 ring.get_ring(),
364 to_int_zn(&ring),
365 from_int_zn(&ring),
366 Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap())
367 )
368 }
369
370 fn supports_ring<S: RingStore<Type = R> + Copy>(&self, _ring: S) -> bool {
371 true
372 }
373
374 fn prepare_convolution_operand<S, V>(&self, val: V, len_hint: Option<usize>, ring: S) -> Self::PreparedConvolutionOperand
375 where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>
376 {
377 self.base.prepare_convolution_impl(
378 val,
379 ring.get_ring(),
380 len_hint,
381 to_int_zn(&ring),
382 Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap())
383 )
384 }
385
386 fn compute_convolution_prepared<S, V1, V2>(&self, lhs: V1, lhs_prep: Option<&Self::PreparedConvolutionOperand>, rhs: V2, rhs_prep: Option<&Self::PreparedConvolutionOperand>, dst: &mut [R::Element], ring: S)
387 where S: RingStore<Type = R> + Copy, V1: VectorView<R::Element>, V2: VectorView<R::Element>
388 {
389 self.base.compute_convolution_impl(
390 lhs,
391 lhs_prep,
392 rhs,
393 rhs_prep,
394 dst,
395 ring.get_ring(),
396 to_int_zn(&ring),
397 from_int_zn(&ring),
398 Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap())
399 )
400 }
401
402 fn compute_convolution_sum<'a, S, I, V1, V2>(&self, values: I, dst: &mut [R::Element], ring: S)
403 where S: RingStore<Type = R> + Copy,
404 I: Iterator<Item = (V1, Option<&'a Self::PreparedConvolutionOperand>, V2, Option<&'a Self::PreparedConvolutionOperand>)>,
405 V1: VectorView<R::Element>,
406 V2: VectorView<R::Element>,
407 Self: 'a,
408 R: 'a,
409 Self::PreparedConvolutionOperand: 'a
410 {
411 self.base.compute_convolution_sum_impl(
412 values,
413 dst,
414 ring.get_ring(),
415 to_int_zn(&ring),
416 from_int_zn(&ring),
417 Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap())
418 )
419 }
420}
421
422impl<I, A> ConvolutionAlgorithm<I> for FFTConvolution<A>
423 where I: ?Sized + IntegerRing,
424 A: Allocator + Clone
425{
426 type PreparedConvolutionOperand = PreparedConvolutionOperand<I, A>;
427
428 fn compute_convolution<S: RingStore<Type = I>, V1: VectorView<I::Element>, V2: VectorView<I::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [I::Element], ring: S) {
429 self.compute_convolution_impl(
430 lhs,
431 None,
432 rhs,
433 None,
434 dst,
435 ring.get_ring(),
436 to_int_int(&ring),
437 from_int_int(&ring),
438 None
439 )
440 }
441
442 fn supports_ring<S: RingStore<Type = I> + Copy>(&self, _ring: S) -> bool {
443 true
444 }
445
446 fn prepare_convolution_operand<S, V>(&self, val: V, len_hint: Option<usize>, ring: S) -> Self::PreparedConvolutionOperand
447 where S: RingStore<Type = I> + Copy, V: VectorView<I::Element>
448 {
449 self.prepare_convolution_impl(
450 val,
451 ring.get_ring(),
452 len_hint,
453 to_int_int(&ring),
454 None
455 )
456 }
457
458 fn compute_convolution_prepared<S, V1, V2>(&self, lhs: V1, lhs_prep: Option<&Self::PreparedConvolutionOperand>, rhs: V2, rhs_prep: Option<&Self::PreparedConvolutionOperand>, dst: &mut [I::Element], ring: S)
459 where S: RingStore<Type = I> + Copy, V1: VectorView<I::Element>, V2: VectorView<I::Element>
460 {
461 self.compute_convolution_impl(
462 lhs,
463 lhs_prep,
464 rhs,
465 rhs_prep,
466 dst,
467 ring.get_ring(),
468 to_int_int(&ring),
469 from_int_int(&ring),
470 None
471 )
472 }
473
474 fn compute_convolution_sum<'a, S, J, V1, V2>(&self, values: J, dst: &mut [I::Element], ring: S)
475 where S: RingStore<Type = I> + Copy,
476 J: Iterator<Item = (V1, Option<&'a Self::PreparedConvolutionOperand>, V2, Option<&'a Self::PreparedConvolutionOperand>)>,
477 V1: VectorView<I::Element>,
478 V2: VectorView<I::Element>,
479 Self: 'a,
480 I: 'a,
481 Self::PreparedConvolutionOperand: 'a
482 {
483 self.compute_convolution_sum_impl(
484 values,
485 dst,
486 ring.get_ring(),
487 to_int_int(&ring),
488 from_int_int(&ring),
489 None
490 )
491 }
492}
493
494#[cfg(test)]
495use crate::rings::finite::FiniteRingStore;
496#[cfg(test)]
497use crate::rings::zn::zn_64::Zn;
498
/// Runs the generic convolution test suite over the ring of integers modulo `17 * 257`.
#[test]
fn test_convolution_zn() {
    let ring = Zn::new(17 * 257);
    let convolution: FFTConvolutionZn = FFTConvolutionZn::from(FFTConvolution::new());

    let scale = ring.one();
    super::generic_tests::test_convolution(&convolution, &ring, scale);
}
506
/// Runs the generic convolution test suite over the ring of `i64` integers.
#[test]
fn test_convolution_int() {
    let ring = StaticRing::<i64>::RING;
    let convolution: FFTConvolution = FFTConvolution::new();

    let scale = ring.one();
    super::generic_tests::test_convolution(&convolution, &ring, scale);
}
514
/// Checks that a convolution of two length-1024 operands modulo `1099511627791` panics
/// with a message mentioning "precision".
#[test]
#[should_panic(expected = "precision")]
fn test_fft_convolution_not_enough_precision() {
    let ring = Zn::new(1099511627791);
    let convolution_algorithm: FFTConvolutionZn = FFTConvolutionZn::from(FFTConvolution::new());

    let lhs: Vec<_> = ring.elements().take(1024).collect();
    let rhs: Vec<_> = ring.elements().take(1024).collect();
    let out_len = lhs.len() + rhs.len();
    let mut actual: Vec<_> = std::iter::repeat_with(|| ring.zero()).take(out_len).collect();

    convolution_algorithm.compute_convolution(&lhs, &rhs, &mut actual, &ring);
}