use std::alloc::{Allocator, Global};
use std::marker::PhantomData;

use crate::algorithms::fft::complex_fft::FFTErrorEstimate;
use crate::algorithms::fft::cooley_tuckey::CooleyTuckeyFFT;
use crate::algorithms::fft::FFTAlgorithm;
use crate::lazy::LazyVec;
use crate::primitive_int::StaticRingBase;
use crate::integer::*;
use crate::ring::*;
use crate::seq::*;
use crate::primitive_int::*;
use crate::homomorphism::*;
use crate::rings::float_complex::*;
use crate::rings::zn::*;

use super::{ConvolutionAlgorithm, PreparedConvolutionAlgorithm, PreparedConvolutionOperation};

const ZZ: StaticRing<i64> = StaticRing::RING;
const CC: Complex64 = Complex64::RING;

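/// A [`ConvolutionAlgorithm`] that computes convolutions of integers (or, via
/// [`FFTConvolutionZn`], of elements of `Z/nZ`) by embedding them into the complex numbers
/// and using a floating-point FFT ([`CooleyTuckeyFFT`] over [`Complex64`]). FFT tables are
/// created lazily and cached per power-of-two length.
///
/// Since all computations are done with `f64` arithmetic, the result is only guaranteed to
/// be exact if the inputs are small enough; see [`FFTConvolution::has_sufficient_precision()`].
///
/// # Example
///
/// A usage sketch, mirroring the tests at the end of this file (marked `ignore`, since the
/// exact import paths depend on how this module is exposed by the crate):
/// ```ignore
/// let convolution: FFTConvolutionZn = FFTConvolution::new_with(Global).into();
/// let ring = Zn::new(17 * 257);
/// let lhs = vec![ring.one(); 4];
/// let rhs = vec![ring.one(); 4];
/// // the convolution is added to `dst`, so it must be long enough and initialized to zero
/// let mut dst = (0..(lhs.len() + rhs.len())).map(|_| ring.zero()).collect::<Vec<_>>();
/// convolution.compute_convolution(&lhs, &rhs, &mut dst, &ring);
/// ```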
#[stability::unstable(feature = "enable")]
pub struct FFTConvolution<A = Global> {
    allocator: A,
    fft_tables: LazyVec<CooleyTuckeyFFT<Complex64Base, Complex64Base, Identity<Complex64>>>
}

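/// An operand that has been prepared for convolution with [`FFTConvolution`] or
/// [`FFTConvolutionZn`]: it stores both the original values (lifted to `f64`) and their
/// Fourier transform, so that the transform can be reused across multiple convolutions,
/// or recomputed at a larger length when a longer output is required.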
#[stability::unstable(feature = "enable")]
pub struct PreparedConvolutionOperand<R, A = Global>
    where R: ?Sized + RingBase,
        A: Allocator + Clone
{
    ring: PhantomData<Box<R>>,
    original_data: Vec<f64, A>,
    fft_data: Vec<El<Complex64>, A>
}

impl<A> FFTConvolution<A>
    where A: Allocator + Clone
{
    #[stability::unstable(feature = "enable")]
    pub fn new_with(allocator: A) -> Self {
        Self {
            allocator: allocator,
            fft_tables: LazyVec::new()
        }
    }

    fn get_fft_table(&self, log2_len: usize) -> &CooleyTuckeyFFT<Complex64Base, Complex64Base, Identity<Complex64>> {
        return self.fft_tables.get_or_init(log2_len, || CooleyTuckeyFFT::for_complex(CC, log2_len));
    }

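    /// Returns whether a convolution of length `2^log2_len`, with inputs bounded in absolute
    /// value by `2^log2_input_size`, can be computed exactly with `f64` arithmetic, i.e. whether
    /// the expected absolute error of the floating-point FFT stays below `0.5` so that rounding
    /// recovers the correct integer result.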
    #[stability::unstable(feature = "enable")]
    pub fn has_sufficient_precision(&self, log2_len: usize, log2_input_size: usize) -> bool {
        let fft_table = self.get_fft_table(log2_len);
        let input_size = 2f64.powi(log2_input_size as i32);
        fft_table.expected_absolute_error(input_size * input_size, input_size * input_size * f64::EPSILON + fft_table.expected_absolute_error(input_size, 0.)) < 0.5
    }

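    /// Multiplies the two given (unordered) Fourier transforms pointwise, computes the inverse
    /// FFT and rounds the first `target_len` coefficients to the nearest integers.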
    fn compute_convolution_impl(&self, mut lhs: Vec<El<Complex64>, A>, rhs: &[El<Complex64>], target_len: usize) -> impl Iterator<Item = i64> {
        let log2_n = ZZ.abs_log2_ceil(&(lhs.len() as i64)).unwrap();
        assert_eq!(lhs.len(), 1 << log2_n);
        assert_eq!(rhs.len(), 1 << log2_n);

        for i in 0..(1 << log2_n) {
            CC.mul_assign(&mut lhs[i], rhs[i]);
        }
        self.get_fft_table(log2_n).unordered_inv_fft(&mut lhs[..], CC);
        (0..target_len).map(move |i| {
            let x = CC.closest_gaussian_int(lhs[i]);
            debug_assert!(x.1 == 0);
            return x.0;
        })
    }

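    /// Embeds `data` into the complex numbers, pads it to length `2^log2_n` and computes its
    /// (unordered) FFT. If `log2_data_size` is not given, it is estimated from the largest
    /// absolute value occurring in `data`. Panics if `f64` precision is insufficient for a
    /// convolution of this length and input size, see [`FFTConvolution::has_sufficient_precision()`].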
    fn prepare_convolution_impl<V>(&self, data: V, log2_n: usize, log2_data_size: Option<usize>) -> (usize, Vec<El<Complex64>, A>)
        where V: VectorFn<f64>
    {
        assert!(data.len() <= 1 << log2_n);
        let log2_data_size = if let Some(log2_data_size) = log2_data_size {
            log2_data_size
        } else {
            data.iter().map(|x| x.abs()).max_by(f64::total_cmp).unwrap().log2().ceil() as usize
        };
        assert!(self.has_sufficient_precision(log2_n, log2_data_size));

        let mut fft_data = Vec::with_capacity_in(1 << log2_n, self.allocator.clone());
        fft_data.extend(data.iter().map(|x| CC.from_f64(x)));
        fft_data.resize(1 << log2_n, CC.zero());
        let fft = self.get_fft_table(log2_n);
        fft.unordered_fft(&mut fft_data[..], CC);
        return (log2_data_size, fft_data);
    }
}

impl<A> Clone for FFTConvolution<A>
    where A: Allocator + Clone
{
    fn clone(&self) -> Self {
        Self {
            allocator: self.allocator.clone(),
            fft_tables: self.fft_tables.clone()
        }
    }
}

impl<A> From<FFTConvolutionZn<A>> for FFTConvolution<A>
    where A: Allocator
{
    fn from(value: FFTConvolutionZn<A>) -> Self {
        value.base
    }
}

impl<'a, A> From<&'a FFTConvolutionZn<A>> for &'a FFTConvolution<A>
    where A: Allocator
{
    fn from(value: &'a FFTConvolutionZn<A>) -> Self {
        &value.base
    }
}

impl<A> From<FFTConvolution<A>> for FFTConvolutionZn<A>
    where A: Allocator
{
    fn from(value: FFTConvolution<A>) -> Self {
        FFTConvolutionZn { base: value }
    }
}

impl<'a, A> From<&'a FFTConvolution<A>> for &'a FFTConvolutionZn<A>
    where A: Allocator
{
    fn from(value: &'a FFTConvolution<A>) -> Self {
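        // this is valid since FFTConvolutionZn is a #[repr(transparent)] wrapper around FFTConvolution,
        // so the two reference types have identical layout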
        unsafe { std::mem::transmute(value) }
    }
}

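/// Variant of [`FFTConvolution`] for quotient rings `Z/nZ`: operands are lifted to their
/// smallest integer representatives, convolved over the integers using the floating-point
/// FFT, and the result is mapped back into the ring.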
#[stability::unstable(feature = "enable")]
#[repr(transparent)]
pub struct FFTConvolutionZn<A = Global> {
    base: FFTConvolution<A>
}

impl<A> Clone for FFTConvolutionZn<A>
    where A: Allocator + Clone
{
    fn clone(&self) -> Self {
        Self { base: self.base.clone() }
    }
}

impl<R, A> ConvolutionAlgorithm<R> for FFTConvolutionZn<A>
    where R: ?Sized + ZnRing + CanHomFrom<StaticRingBase<i64>>,
        A: Allocator + Clone
{
    fn compute_convolution<S: RingStore<Type = R>, V1: VectorView<R::Element>, V2: VectorView<R::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [R::Element], ring: S) {
        if lhs.len() == 0 || rhs.len() == 0 {
            return;
        }
        let log2_data_size = ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap();
        let log2_n = ZZ.abs_log2_ceil(&((lhs.len() + rhs.len()) as i64)).unwrap();
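        // lift both operands to their smallest integer representatives and convolve over the integers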
        let lhs_prep = self.base.prepare_convolution_impl(lhs.clone_ring_els(&ring).map_fn(|x| int_cast(ring.smallest_lift(x), ZZ, ring.integer_ring()) as f64), log2_n, Some(log2_data_size)).1;
        let rhs_prep = self.base.prepare_convolution_impl(rhs.clone_ring_els(&ring).map_fn(|x| int_cast(ring.smallest_lift(x), ZZ, ring.integer_ring()) as f64), log2_n, Some(log2_data_size)).1;
        let hom = ring.can_hom(&ZZ).unwrap();
        for (i, x) in self.base.compute_convolution_impl(lhs_prep, &rhs_prep, lhs.len() + rhs.len() - 1).enumerate() {
            ring.add_assign(&mut dst[i], hom.map(x));
        }
    }

    fn supports_ring<S: RingStore<Type = R> + Copy>(&self, _ring: S) -> bool {
        true
    }

    fn specialize_prepared_convolution<F>(function: F) -> Result<F::Output, F>
        where F: PreparedConvolutionOperation<Self, R>
    {
        Ok(function.execute())
    }
}

impl<I, A> ConvolutionAlgorithm<I> for FFTConvolution<A>
    where I: ?Sized + IntegerRing,
        A: Allocator + Clone
{
    fn compute_convolution<S: RingStore<Type = I>, V1: VectorView<I::Element>, V2: VectorView<I::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [I::Element], ring: S) {
        if lhs.len() == 0 || rhs.len() == 0 {
            return;
        }
        let log2_n = ZZ.abs_log2_ceil(&((lhs.len() + rhs.len()) as i64)).unwrap();
        let lhs_prep = self.prepare_convolution_impl(lhs.clone_ring_els(&ring).map_fn(|x| int_cast(x, ZZ, &ring) as f64), log2_n, None).1;
        let rhs_prep = self.prepare_convolution_impl(rhs.clone_ring_els(&ring).map_fn(|x| int_cast(x, ZZ, &ring) as f64), log2_n, None).1;
        for (i, x) in self.compute_convolution_impl(lhs_prep, &rhs_prep, lhs.len() + rhs.len() - 1).enumerate() {
            ring.add_assign(&mut dst[i], int_cast(x, &ring, ZZ));
        }
    }

    fn supports_ring<S: RingStore<Type = I> + Copy>(&self, _ring: S) -> bool {
        true
    }

    fn specialize_prepared_convolution<F>(function: F) -> Result<F::Output, F>
        where F: PreparedConvolutionOperation<Self, I>
    {
        Ok(function.execute())
    }
}

impl<I, A> PreparedConvolutionAlgorithm<I> for FFTConvolution<A>
    where I: ?Sized + IntegerRing,
        A: Allocator + Clone
{
    type PreparedConvolutionOperand = PreparedConvolutionOperand<I, A>;

    fn prepare_convolution_operand<S, V>(&self, val: V, ring: S) -> Self::PreparedConvolutionOperand
        where S: RingStore<Type = I> + Copy, V: VectorView<I::Element>
    {
        let log2_n_in = ZZ.abs_log2_ceil(&(val.len() as i64)).unwrap();
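        // transform at twice the next power of two, so that the prepared FFT can accommodate
        // products with a second operand of (at most) the same length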
        let log2_n_out = log2_n_in + 1;
        let mut original_data = Vec::new_in(self.allocator.clone());
        original_data.extend(val.clone_ring_els(&ring).iter().map(|x| int_cast(x, ZZ, &ring) as f64));
        let (_log2_data_size, fft_data) = self.prepare_convolution_impl(original_data.copy_els(), log2_n_out, None);
        return PreparedConvolutionOperand {
            fft_data: fft_data,
            original_data: original_data,
            ring: PhantomData
        };
    }

    fn compute_convolution_lhs_prepared<S, V>(&self, lhs: &Self::PreparedConvolutionOperand, rhs: V, dst: &mut [I::Element], ring: S)
        where S: RingStore<Type = I> + Copy, V: VectorView<I::Element>
    {
        assert!(ring.is_commutative());
        let log2_lhs = ZZ.abs_log2_ceil(&(lhs.fft_data.len() as i64)).unwrap();
        assert_eq!(lhs.fft_data.len(), 1 << log2_lhs);
        let target_len = lhs.original_data.len() + rhs.len() - 1;
        let log2_target_len = ZZ.abs_log2_ceil(&(target_len as i64)).unwrap().max(log2_lhs);
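        // if the prepared FFT of lhs is too short for the result, redo the FFT of lhs at the
        // larger length from the stored original data; otherwise transform rhs at lhs' length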
        let els = if log2_target_len > log2_lhs {
            assert!(target_len <= 1 << log2_target_len);
            let lhs_prep = self.prepare_convolution_impl(lhs.original_data.copy_els(), log2_target_len, None).1;
            let rhs_prep = self.prepare_convolution_impl(rhs.clone_ring_els(&ring).map_fn(|x| int_cast(x, ZZ, &ring) as f64), log2_target_len, None).1;
            self.compute_convolution_impl(lhs_prep, &rhs_prep, target_len)
        } else {
            assert!(log2_lhs == log2_target_len || log2_lhs == log2_target_len + 1);
            assert!(target_len <= 1 << log2_lhs);
            self.compute_convolution_impl(
                self.prepare_convolution_impl(rhs.clone_ring_els(ring).map_fn(|x| int_cast(x, ZZ, &ring) as f64), log2_lhs, None).1,
                &lhs.fft_data,
                target_len
            )
        };
        for (i, x) in els.enumerate() {
            ring.add_assign(&mut dst[i], int_cast(x, ring, ZZ));
        }
    }

    fn compute_convolution_prepared<S>(&self, lhs: &Self::PreparedConvolutionOperand, rhs: &Self::PreparedConvolutionOperand, dst: &mut [I::Element], ring: S)
        where S: RingStore<Type = I> + Copy
    {
        assert!(ring.is_commutative());
        let log2_lhs = ZZ.abs_log2_ceil(&(lhs.fft_data.len() as i64)).unwrap();
        assert_eq!(1 << log2_lhs, lhs.fft_data.len());
        let log2_rhs = ZZ.abs_log2_ceil(&(rhs.fft_data.len() as i64)).unwrap();
        assert_eq!(1 << log2_rhs, rhs.fft_data.len());
        let target_len = lhs.original_data.len() + rhs.original_data.len() - 1;
        assert!(target_len <= 1 << log2_lhs || target_len <= 1 << log2_rhs);
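        // if the two operands were prepared at different lengths, redo the FFT of the shorter
        // one (from its stored original data) at the length of the longer one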
        let els = match log2_lhs.cmp(&log2_rhs) {
            std::cmp::Ordering::Equal => self.compute_convolution_impl(lhs.fft_data.clone(), &rhs.fft_data, target_len),
            std::cmp::Ordering::Greater => self.compute_convolution_impl(self.prepare_convolution_impl(rhs.original_data.copy_els(), log2_lhs, None).1, &lhs.fft_data, target_len),
            std::cmp::Ordering::Less => self.compute_convolution_impl(self.prepare_convolution_impl(lhs.original_data.copy_els(), log2_rhs, None).1, &rhs.fft_data, target_len)
        };
        for (i, x) in els.enumerate() {
            ring.add_assign(&mut dst[i], int_cast(x, ring, ZZ));
        }
    }
}

impl<R, A> PreparedConvolutionAlgorithm<R> for FFTConvolutionZn<A>
    where R: ?Sized + ZnRing + CanHomFrom<StaticRingBase<i64>>,
        A: Allocator + Clone
{
    type PreparedConvolutionOperand = PreparedConvolutionOperand<R, A>;

    fn prepare_convolution_operand<S, V>(&self, val: V, ring: S) -> Self::PreparedConvolutionOperand
        where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>
    {
        let log2_data_size = ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap();
        let log2_n_in = ZZ.abs_log2_ceil(&(val.len() as i64)).unwrap();
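        // as in the integer case, transform at twice the next power of two so that the prepared
        // FFT can accommodate products with a second operand of (at most) the same length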
        let log2_n_out = log2_n_in + 1;
        let mut original_data = Vec::new_in(self.base.allocator.clone());
        original_data.extend(val.clone_ring_els(&ring).iter().map(|x| int_cast(ring.smallest_lift(x), ZZ, ring.integer_ring()) as f64));
        let (_log2_data_size, fft_data) = self.base.prepare_convolution_impl(original_data.copy_els(), log2_n_out, Some(log2_data_size));
        return PreparedConvolutionOperand {
            fft_data: fft_data,
            original_data: original_data,
            ring: PhantomData
        };
    }

    fn compute_convolution_lhs_prepared<S, V>(&self, lhs: &Self::PreparedConvolutionOperand, rhs: V, dst: &mut [R::Element], ring: S)
        where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>
    {
        assert!(ring.is_commutative());
        let log2_data_size = ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap();
        let log2_lhs = ZZ.abs_log2_ceil(&(lhs.fft_data.len() as i64)).unwrap();
        assert_eq!(lhs.fft_data.len(), 1 << log2_lhs);
        let target_len = lhs.original_data.len() + rhs.len() - 1;
        let log2_target_len = ZZ.abs_log2_ceil(&(target_len as i64)).unwrap().max(log2_lhs);
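        // as in the integer case: if the prepared FFT of lhs is too short for the result, redo
        // it at the larger length from the stored original data; otherwise transform rhs at lhs' length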
        let els = if log2_target_len > log2_lhs {
            assert!(target_len <= 1 << log2_target_len);
            let lhs_prep = self.base.prepare_convolution_impl(lhs.original_data.copy_els(), log2_target_len, Some(log2_data_size)).1;
            let rhs_prep = self.base.prepare_convolution_impl(rhs.clone_ring_els(&ring).map_fn(|x| int_cast(ring.smallest_lift(x), ZZ, ring.integer_ring()) as f64), log2_target_len, Some(log2_data_size)).1;
            self.base.compute_convolution_impl(lhs_prep, &rhs_prep, target_len)
        } else {
            assert!(log2_lhs == log2_target_len || log2_lhs == log2_target_len + 1);
            assert!(target_len <= 1 << log2_lhs);
            self.base.compute_convolution_impl(
                self.base.prepare_convolution_impl(rhs.clone_ring_els(ring).map_fn(|x| int_cast(ring.smallest_lift(x), ZZ, ring.integer_ring()) as f64), log2_lhs, Some(log2_data_size)).1,
                &lhs.fft_data,
                target_len
            )
        };
        let hom = ring.can_hom(&ZZ).unwrap();
        for (i, x) in els.enumerate() {
            ring.add_assign(&mut dst[i], hom.map(x));
        }
    }

    fn compute_convolution_prepared<S>(&self, lhs: &Self::PreparedConvolutionOperand, rhs: &Self::PreparedConvolutionOperand, dst: &mut [R::Element], ring: S)
        where S: RingStore<Type = R> + Copy
    {
        assert!(ring.is_commutative());
        let log2_data_size = ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap();
        let log2_lhs = ZZ.abs_log2_ceil(&(lhs.fft_data.len() as i64)).unwrap();
        assert_eq!(1 << log2_lhs, lhs.fft_data.len());
        let log2_rhs = ZZ.abs_log2_ceil(&(rhs.fft_data.len() as i64)).unwrap();
        assert_eq!(1 << log2_rhs, rhs.fft_data.len());
        let target_len = lhs.original_data.len() + rhs.original_data.len() - 1;
        assert!(target_len <= 1 << log2_lhs || target_len <= 1 << log2_rhs);
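        // if the two operands were prepared at different lengths, redo the FFT of the shorter
        // one (from its stored original data) at the length of the longer one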
        let els = match log2_lhs.cmp(&log2_rhs) {
            std::cmp::Ordering::Equal => self.base.compute_convolution_impl(lhs.fft_data.clone(), &rhs.fft_data, target_len),
            std::cmp::Ordering::Greater => self.base.compute_convolution_impl(self.base.prepare_convolution_impl(rhs.original_data.copy_els(), log2_lhs, Some(log2_data_size)).1, &lhs.fft_data, target_len),
            std::cmp::Ordering::Less => self.base.compute_convolution_impl(self.base.prepare_convolution_impl(lhs.original_data.copy_els(), log2_rhs, Some(log2_data_size)).1, &rhs.fft_data, target_len)
        };
        let hom = ring.can_hom(&ZZ).unwrap();
        for (i, x) in els.enumerate() {
            ring.add_assign(&mut dst[i], hom.map(x));
        }
    }
}

#[cfg(test)]
use crate::rings::finite::FiniteRingStore;
#[cfg(test)]
use crate::rings::zn::zn_64::Zn;

#[test]
fn test_convolution_zn() {
    let convolution: FFTConvolutionZn = FFTConvolution::new_with(Global).into();
    let ring = Zn::new(17 * 257);

    super::generic_tests::test_convolution(&convolution, &ring, ring.one());
    super::generic_tests::test_prepared_convolution(&convolution, &ring, ring.one());
}

#[test]
fn test_convolution_int() {
    let convolution: FFTConvolution = FFTConvolution::new_with(Global);
    let ring = StaticRing::<i64>::RING;

    super::generic_tests::test_convolution(&convolution, &ring, ring.one());
    super::generic_tests::test_prepared_convolution(&convolution, &ring, ring.one());
}

#[test]
#[should_panic(expected = "precision")]
fn test_fft_convolution_not_enough_precision() {
    let convolution_algorithm: FFTConvolutionZn = FFTConvolution::new_with(Global).into();

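    // a modulus of roughly 2^40 with length-1024 operands exceeds what the f64-based error
    // bound can certify, so the precision assertion in prepare_convolution_impl should fire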
    let ring = Zn::new(1099511627791);
    let lhs = ring.elements().take(1024).collect::<Vec<_>>();
    let rhs = ring.elements().take(1024).collect::<Vec<_>>();
    let mut actual = (0..(lhs.len() + rhs.len())).map(|_| ring.zero()).collect::<Vec<_>>();

    convolution_algorithm.compute_convolution(&lhs, &rhs, &mut actual, &ring);
}