1use std::cmp::max;
2use std::alloc::{Allocator, Global};
3use std::marker::PhantomData;
4
5use crate::cow::*;
6use crate::algorithms::fft::complex_fft::FFTErrorEstimate;
7use crate::algorithms::fft::cooley_tuckey::CooleyTuckeyFFT;
8use crate::algorithms::fft::FFTAlgorithm;
9use crate::lazy::LazyVec;
10use crate::primitive_int::StaticRingBase;
11use crate::integer::*;
12use crate::ring::*;
13use crate::seq::*;
14use crate::primitive_int::*;
15use crate::homomorphism::*;
16use crate::rings::float_complex::*;
17use crate::rings::zn::*;
18
19use super::ConvolutionAlgorithm;
20
/// Shorthand for the field of complex numbers, represented by pairs of `f64`s.
const CC: Complex64 = Complex64::RING;
22
/// Convolution algorithm that maps the inputs to complex floating-point numbers
/// and multiplies them via a complex-valued FFT.
///
/// The result is recovered by rounding back to integers, so it is only correct
/// as long as `f64` precision suffices for the given input lengths and sizes;
/// this is checked at runtime via an assertion (see
/// [`FFTConvolution::has_sufficient_precision()`]).
#[stability::unstable(feature = "enable")]
pub struct FFTConvolution<A = Global> {
    // used for all temporary buffers and cached FFT data
    allocator: A,
    // lazily created FFT tables, indexed by the log2 of the transform length
    fft_tables: LazyVec<CooleyTuckeyFFT<Complex64Base, Complex64Base, Identity<Complex64>>>
}
28
/// Precomputed data associated to a single convolution operand, consisting of
/// cached Fourier transforms of the operand for various transform lengths.
///
/// Used by [`FFTConvolution`] and [`FFTConvolutionZn`] to avoid recomputing
/// the forward FFT of an operand that takes part in multiple convolutions.
#[stability::unstable(feature = "enable")]
pub struct PreparedConvolutionOperand<R, A = Global>
    where R: ?Sized + RingBase,
        A: Allocator + Clone
{
    // the operand logically belongs to a ring of type `R`, but no ring value is stored
    ring: PhantomData<Box<R>>,
    // cached Fourier transforms of the operand, indexed by log2 of the transform length
    fft_data: LazyVec<Vec<El<Complex64>, A>>,
    // upper bound on the bitlength of the operand's entries (mapped to integers)
    log2_data_size: usize
}
38
39impl<A> FFTConvolution<A>
40 where A: Allocator + Clone
41{
42 #[stability::unstable(feature = "enable")]
43 pub fn new_with(allocator: A) -> Self {
44 Self {
45 allocator: allocator,
46 fft_tables: LazyVec::new()
47 }
48 }
49
50 fn get_fft_table(&self, log2_len: usize) -> &CooleyTuckeyFFT<Complex64Base, Complex64Base, Identity<Complex64>> {
51 return self.fft_tables.get_or_init(log2_len, || CooleyTuckeyFFT::for_complex(CC, log2_len));
52 }
53
54 fn get_fft_data<'a, R, V, ToInt>(
55 &self,
56 data: V,
57 data_prep: Option<&'a PreparedConvolutionOperand<R, A>>,
58 _ring: &R,
59 log2_len: usize,
60 mut to_int: ToInt,
61 log2_el_size: Option<usize>
62 ) -> MyCow<'a, Vec<El<Complex64>, A>>
63 where R: ?Sized + RingBase,
64 V: VectorView<R::Element>,
65 ToInt: FnMut(&R::Element) -> i64
66 {
67 let log2_data_size = if let Some(log2_data_size) = log2_el_size {
68 if let Some(data_prep) = data_prep {
69 assert_eq!(log2_data_size, data_prep.log2_data_size);
70 }
71 log2_data_size
72 } else {
73 data.as_iter().map(|x| StaticRing::<i64>::RING.abs_log2_ceil(&to_int(x)).unwrap_or(0)).max().unwrap()
74 };
75 assert!(data.len() <= (1 << log2_len));
76 assert!(self.has_sufficient_precision(log2_len, log2_data_size));
77
78 let mut compute_result = || {
79 let mut result = Vec::with_capacity_in(1 << log2_len, self.allocator.clone());
80 result.extend(data.as_iter().map(|x| Complex64::RING.from_f64(to_int(x) as f64)));
81 result.resize(1 << log2_len, Complex64::RING.zero());
82 self.get_fft_table(log2_len).unordered_fft(&mut result, Complex64::RING);
83 return result;
84 };
85
86 return if let Some(data_prep) = data_prep {
87 MyCow::Borrowed(data_prep.fft_data.get_or_init(log2_len, compute_result))
88 } else {
89 MyCow::Owned(compute_result())
90 }
91 }
92
93 #[stability::unstable(feature = "enable")]
94 pub fn has_sufficient_precision(&self, log2_len: usize, log2_input_size: usize) -> bool {
95 self.max_sum_len(log2_len, log2_input_size) > 0
96 }
97
98 fn max_sum_len(&self, log2_len: usize, log2_input_size: usize) -> usize {
99 let fft_table = self.get_fft_table(log2_len);
100 let input_size = 2f64.powi(log2_input_size.try_into().unwrap());
101 (0.5 / fft_table.expected_absolute_error(input_size * input_size, input_size * input_size * f64::EPSILON + fft_table.expected_absolute_error(input_size, 0.))).floor() as usize
102 }
103
104 fn prepare_convolution_impl<R, V, ToInt>(
105 &self,
106 data: V,
107 ring: &R,
108 length_hint: Option<usize>,
109 mut to_int: ToInt,
110 ring_log2_el_size: Option<usize>
111 ) -> PreparedConvolutionOperand<R, A>
112 where R: ?Sized + RingBase,
113 V: VectorView<R::Element>,
114 ToInt: FnMut(&R::Element) -> i64
115 {
116 let log2_data_size = if let Some(log2_data_size) = ring_log2_el_size {
117 log2_data_size
118 } else {
119 data.as_iter().map(|x| StaticRing::<i64>::RING.abs_log2_ceil(&to_int(x)).unwrap_or(0)).max().unwrap()
120 };
121 let result = PreparedConvolutionOperand {
122 fft_data: LazyVec::new(),
123 ring: PhantomData,
124 log2_data_size: log2_data_size
125 };
126 if let Some(len) = length_hint {
130 let log2_len = StaticRing::<i64>::RING.abs_log2_ceil(&len.try_into().unwrap()).unwrap();
131 _ = self.get_fft_data(data, Some(&result), ring, log2_len, to_int, ring_log2_el_size);
132 }
133 return result;
134 }
135
136 fn compute_convolution_impl<R, V1, V2, ToInt, FromInt>(
137 &self,
138 lhs: V1,
139 lhs_prep: Option<&PreparedConvolutionOperand<R, A>>,
140 rhs: V2,
141 rhs_prep: Option<&PreparedConvolutionOperand<R, A>>,
142 dst: &mut [R::Element],
143 ring: &R,
144 mut to_int: ToInt,
145 mut from_int: FromInt,
146 ring_log2_el_size: Option<usize>
147 )
148 where R: ?Sized + RingBase,
149 V1: VectorView<R::Element>,
150 V2: VectorView<R::Element>,
151 ToInt: FnMut(&R::Element) -> i64,
152 FromInt: FnMut(i64) -> R::Element
153 {
154 if lhs.len() == 0 || rhs.len() == 0 {
155 return;
156 }
157 let len = lhs.len() + rhs.len() - 1;
158 assert!(dst.len() >= len);
159 let log2_len = StaticRing::<i64>::RING.abs_log2_ceil(&len.try_into().unwrap()).unwrap();
160
161 let mut lhs_fft = self.get_fft_data(lhs, lhs_prep, ring, log2_len, &mut to_int, ring_log2_el_size);
162 let mut rhs_fft = self.get_fft_data(rhs, rhs_prep, ring, log2_len, &mut to_int, ring_log2_el_size);
163 if rhs_fft.is_owned() {
164 std::mem::swap(&mut lhs_fft, &mut rhs_fft);
165 }
166 let lhs_fft: &mut Vec<El<Complex64>, A> = lhs_fft.to_mut();
167
168 for i in 0..(1 << log2_len) {
169 CC.mul_assign(&mut lhs_fft[i], rhs_fft[i]);
170 }
171
172 self.get_fft_table(log2_len).unordered_inv_fft(&mut *lhs_fft, CC);
173
174 for i in 0..len {
175 let result = CC.closest_gaussian_int(lhs_fft[i]);
176 debug_assert_eq!(0, result.1);
177 ring.add_assign(&mut dst[i], from_int(result.0));
178 }
179 }
180
181 fn compute_convolution_sum_impl<'a, R, I, V1, V2, ToInt, FromInt>(
182 &self,
183 data: I,
184 dst: &mut [R::Element],
185 ring: &R,
186 mut to_int: ToInt,
187 mut from_int: FromInt,
188 ring_log2_el_size: Option<usize>
189 )
190 where R: ?Sized + RingBase,
191 I: Iterator<Item = (V1, Option<&'a PreparedConvolutionOperand<R, A>>, V2, Option<&'a PreparedConvolutionOperand<R, A>>)>,
192 V1: VectorView<R::Element>,
193 V2: VectorView<R::Element>,
194 ToInt: FnMut(&R::Element) -> i64,
195 FromInt: FnMut(i64) -> R::Element,
196 Self: 'a,
197 R: 'a
198 {
199 let len = dst.len();
200 let log2_len = StaticRing::<i64>::RING.abs_log2_ceil(&len.try_into().unwrap()).unwrap();
201 let mut buffer = Vec::with_capacity_in(1 << log2_len, self.allocator.clone());
202 buffer.resize(1 << log2_len, CC.zero());
203
204 let mut count_since_last_reduction = 0;
205 let mut current_max_sum_len = usize::MAX;
206 let mut current_log2_data_size = if let Some(log2_data_size) = ring_log2_el_size {
207 log2_data_size
208 } else {
209 0
210 };
211 for (lhs, lhs_prep, rhs, rhs_prep) in data {
212 if lhs.len() == 0 || rhs.len() == 0 {
213 continue;
214 }
215 assert!(lhs.len() + rhs.len() - 1 <= dst.len());
216
217 if ring_log2_el_size.is_none() {
218 current_log2_data_size = max(
219 current_log2_data_size,
220 lhs.as_iter().chain(rhs.as_iter()).map(|x| StaticRing::<i64>::RING.abs_log2_ceil(&to_int(x)).unwrap_or(0)).max().unwrap()
221 );
222 current_max_sum_len = self.max_sum_len(log2_len, current_log2_data_size);
223 }
224 assert!(current_max_sum_len > 0);
225
226 if count_since_last_reduction > current_max_sum_len {
227 count_since_last_reduction = 0;
228 self.get_fft_table(log2_len).unordered_inv_fft(&mut *buffer, CC);
229 for i in 0..len {
230 let result = CC.closest_gaussian_int(buffer[i]);
231 debug_assert_eq!(0, result.1);
232 ring.add_assign(&mut dst[i], from_int(result.0));
233 }
234 for i in 0..(1 << log2_len) {
235 buffer[i] = CC.zero();
236 }
237 }
238
239 let mut lhs_fft = self.get_fft_data(lhs, lhs_prep, ring, log2_len, &mut to_int, ring_log2_el_size);
240 let mut rhs_fft = self.get_fft_data(rhs, rhs_prep, ring, log2_len, &mut to_int, ring_log2_el_size);
241 if rhs_fft.is_owned() {
242 std::mem::swap(&mut lhs_fft, &mut rhs_fft);
243 }
244 let lhs_fft: &mut Vec<El<Complex64>, A> = lhs_fft.to_mut();
245 for i in 0..(1 << log2_len) {
246 CC.mul_assign(&mut lhs_fft[i], rhs_fft[i]);
247 CC.add_assign(&mut buffer[i], lhs_fft[i]);
248 }
249 count_since_last_reduction += 1;
250 }
251 self.get_fft_table(log2_len).unordered_inv_fft(&mut *buffer, CC);
252 for i in 0..len {
253 let result = CC.closest_gaussian_int(buffer[i]);
254 debug_assert_eq!(0, result.1);
255 ring.add_assign(&mut dst[i], from_int(result.0));
256 }
257 }
258}
259
260fn to_int_int<I>(ring: I) -> impl use<I> + Fn(&El<I>) -> i64
261 where I: RingStore, I::Type: IntegerRing
262{
263 move |x| int_cast(ring.clone_el(x), StaticRing::<i64>::RING, &ring)
264}
265
266fn from_int_int<I>(ring: I) -> impl use<I> + Fn(i64) -> El<I>
267 where I: RingStore, I::Type: IntegerRing
268{
269 move |x| int_cast(x, &ring, StaticRing::<i64>::RING)
270}
271
272fn to_int_zn<R>(ring: R) -> impl use<R> + Fn(&El<R>) -> i64
273 where R: RingStore, R::Type: ZnRing
274{
275 move |x| int_cast(ring.smallest_lift(ring.clone_el(x)), StaticRing::<i64>::RING, ring.integer_ring())
276}
277
278fn from_int_zn<R>(ring: R) -> impl use<R> + Fn(i64) -> El<R>
279 where R: RingStore, R::Type: ZnRing
280{
281 let hom = ring.can_hom(ring.integer_ring()).unwrap().into_raw_hom();
282 move |x| ring.get_ring().map_in(ring.integer_ring().get_ring(), int_cast(x, ring.integer_ring(), StaticRing::<i64>::RING), &hom)
283}
284
285impl<A> Clone for FFTConvolution<A>
286 where A: Allocator + Clone
287{
288 fn clone(&self) -> Self {
289 Self {
290 allocator: self.allocator.clone(),
291 fft_tables: self.fft_tables.clone()
292 }
293 }
294}
295
296impl<A> From<FFTConvolutionZn<A>> for FFTConvolution<A>
297 where A: Allocator
298{
299 fn from(value: FFTConvolutionZn<A>) -> Self {
300 value.base
301 }
302}
303
304impl<'a, A> From<&'a FFTConvolutionZn<A>> for &'a FFTConvolution<A>
305 where A: Allocator
306{
307 fn from(value: &'a FFTConvolutionZn<A>) -> Self {
308 &value.base
309 }
310}
311
312impl<A> From<FFTConvolution<A>> for FFTConvolutionZn<A>
313 where A: Allocator
314{
315 fn from(value: FFTConvolution<A>) -> Self {
316 FFTConvolutionZn { base: value }
317 }
318}
319
/// Reinterprets a reference to an [`FFTConvolution`] as a reference to an
/// [`FFTConvolutionZn`] without copying.
impl<'a, A> From<&'a FFTConvolution<A>> for &'a FFTConvolutionZn<A>
    where A: Allocator
{
    fn from(value: &'a FFTConvolution<A>) -> Self {
        // SAFETY: `FFTConvolutionZn<A>` is `#[repr(transparent)]` over
        // `FFTConvolution<A>`, so both references have identical layout
        unsafe { std::mem::transmute(value) }
    }
}
327
/// Variant of [`FFTConvolution`] that computes convolutions over rings `Z/nZ`,
/// by lifting elements to integers and reducing the result back.
///
/// Must remain `#[repr(transparent)]` over [`FFTConvolution`]: the reference
/// conversion `From<&FFTConvolution<A>> for &FFTConvolutionZn<A>` relies on
/// this for its `transmute`.
#[stability::unstable(feature = "enable")]
#[repr(transparent)]
pub struct FFTConvolutionZn<A = Global> {
    base: FFTConvolution<A>
}
333
334impl<A> Clone for FFTConvolutionZn<A>
335 where A: Allocator + Clone
336{
337 fn clone(&self) -> Self {
338 Self { base: self.base.clone() }
339 }
340}
341
342impl<R, A> ConvolutionAlgorithm<R> for FFTConvolutionZn<A>
343 where R: ?Sized + ZnRing + CanHomFrom<StaticRingBase<i64>>,
344 A: Allocator + Clone
345{
346 type PreparedConvolutionOperand = PreparedConvolutionOperand<R, A>;
347
348 fn compute_convolution<S: RingStore<Type = R>, V1: VectorView<R::Element>, V2: VectorView<R::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [R::Element], ring: S) {
349 self.base.compute_convolution_impl(
350 lhs,
351 None,
352 rhs,
353 None,
354 dst,
355 ring.get_ring(),
356 to_int_zn(&ring),
357 from_int_zn(&ring),
358 Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap())
359 )
360 }
361
362 fn supports_ring<S: RingStore<Type = R> + Copy>(&self, _ring: S) -> bool {
363 true
364 }
365
366 fn prepare_convolution_operand<S, V>(&self, val: V, len_hint: Option<usize>, ring: S) -> Self::PreparedConvolutionOperand
367 where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>
368 {
369 self.base.prepare_convolution_impl(
370 val,
371 ring.get_ring(),
372 len_hint,
373 to_int_zn(&ring),
374 Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap())
375 )
376 }
377
378 fn compute_convolution_prepared<S, V1, V2>(&self, lhs: V1, lhs_prep: Option<&Self::PreparedConvolutionOperand>, rhs: V2, rhs_prep: Option<&Self::PreparedConvolutionOperand>, dst: &mut [R::Element], ring: S)
379 where S: RingStore<Type = R> + Copy, V1: VectorView<R::Element>, V2: VectorView<R::Element>
380 {
381 self.base.compute_convolution_impl(
382 lhs,
383 lhs_prep,
384 rhs,
385 rhs_prep,
386 dst,
387 ring.get_ring(),
388 to_int_zn(&ring),
389 from_int_zn(&ring),
390 Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap())
391 )
392 }
393
394 fn compute_convolution_sum<'a, S, I, V1, V2>(&self, values: I, dst: &mut [R::Element], ring: S)
395 where S: RingStore<Type = R> + Copy,
396 I: Iterator<Item = (V1, Option<&'a Self::PreparedConvolutionOperand>, V2, Option<&'a Self::PreparedConvolutionOperand>)>,
397 V1: VectorView<R::Element>,
398 V2: VectorView<R::Element>,
399 Self: 'a,
400 R: 'a,
401 Self::PreparedConvolutionOperand: 'a
402 {
403 self.base.compute_convolution_sum_impl(
404 values,
405 dst,
406 ring.get_ring(),
407 to_int_zn(&ring),
408 from_int_zn(&ring),
409 Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap())
410 )
411 }
412}
413
414impl<I, A> ConvolutionAlgorithm<I> for FFTConvolution<A>
415 where I: ?Sized + IntegerRing,
416 A: Allocator + Clone
417{
418 type PreparedConvolutionOperand = PreparedConvolutionOperand<I, A>;
419
420 fn compute_convolution<S: RingStore<Type = I>, V1: VectorView<I::Element>, V2: VectorView<I::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [I::Element], ring: S) {
421 self.compute_convolution_impl(
422 lhs,
423 None,
424 rhs,
425 None,
426 dst,
427 ring.get_ring(),
428 to_int_int(&ring),
429 from_int_int(&ring),
430 None
431 )
432 }
433
434 fn supports_ring<S: RingStore<Type = I> + Copy>(&self, _ring: S) -> bool {
435 true
436 }
437
438 fn prepare_convolution_operand<S, V>(&self, val: V, len_hint: Option<usize>, ring: S) -> Self::PreparedConvolutionOperand
439 where S: RingStore<Type = I> + Copy, V: VectorView<I::Element>
440 {
441 self.prepare_convolution_impl(
442 val,
443 ring.get_ring(),
444 len_hint,
445 to_int_int(&ring),
446 None
447 )
448 }
449
450 fn compute_convolution_prepared<S, V1, V2>(&self, lhs: V1, lhs_prep: Option<&Self::PreparedConvolutionOperand>, rhs: V2, rhs_prep: Option<&Self::PreparedConvolutionOperand>, dst: &mut [I::Element], ring: S)
451 where S: RingStore<Type = I> + Copy, V1: VectorView<I::Element>, V2: VectorView<I::Element>
452 {
453 self.compute_convolution_impl(
454 lhs,
455 lhs_prep,
456 rhs,
457 rhs_prep,
458 dst,
459 ring.get_ring(),
460 to_int_int(&ring),
461 from_int_int(&ring),
462 None
463 )
464 }
465
466 fn compute_convolution_sum<'a, S, J, V1, V2>(&self, values: J, dst: &mut [I::Element], ring: S)
467 where S: RingStore<Type = I> + Copy,
468 J: Iterator<Item = (V1, Option<&'a Self::PreparedConvolutionOperand>, V2, Option<&'a Self::PreparedConvolutionOperand>)>,
469 V1: VectorView<I::Element>,
470 V2: VectorView<I::Element>,
471 Self: 'a,
472 I: 'a,
473 Self::PreparedConvolutionOperand: 'a
474 {
475 self.compute_convolution_sum_impl(
476 values,
477 dst,
478 ring.get_ring(),
479 to_int_int(&ring),
480 from_int_int(&ring),
481 None
482 )
483 }
484}
485
486#[cfg(test)]
487use crate::rings::finite::FiniteRingStore;
488#[cfg(test)]
489use crate::rings::zn::zn_64::Zn;
490
491#[test]
492fn test_convolution_zn() {
493 let convolution: FFTConvolutionZn = FFTConvolution::new_with(Global).into();
494 let ring = Zn::new(17 * 257);
495
496 super::generic_tests::test_convolution(&convolution, &ring, ring.one());
497}
498
499#[test]
500fn test_convolution_int() {
501 let convolution: FFTConvolution = FFTConvolution::new_with(Global);
502 let ring = StaticRing::<i64>::RING;
503
504 super::generic_tests::test_convolution(&convolution, &ring, ring.one());
505}
506
507#[test]
508#[should_panic(expected = "precision")]
509fn test_fft_convolution_not_enough_precision() {
510 let convolution_algorithm: FFTConvolutionZn = FFTConvolution::new_with(Global).into();
511
512 let ring = Zn::new(1099511627791);
513 let lhs = ring.elements().take(1024).collect::<Vec<_>>();
514 let rhs = ring.elements().take(1024).collect::<Vec<_>>();
515 let mut actual = (0..(lhs.len() + rhs.len())).map(|_| ring.zero()).collect::<Vec<_>>();
516
517 convolution_algorithm.compute_convolution(&lhs, &rhs, &mut actual, &ring);
518}