feanor_math/algorithms/convolution/
fft.rs

1use std::cmp::max;
2use std::alloc::{Allocator, Global};
3use std::marker::PhantomData;
4
5use crate::cow::*;
6use crate::algorithms::fft::complex_fft::FFTErrorEstimate;
7use crate::algorithms::fft::cooley_tuckey::CooleyTuckeyFFT;
8use crate::algorithms::fft::FFTAlgorithm;
9use crate::lazy::LazyVec;
10use crate::primitive_int::StaticRingBase;
11use crate::integer::*;
12use crate::ring::*;
13use crate::seq::*;
14use crate::primitive_int::*;
15use crate::homomorphism::*;
16use crate::rings::float_complex::*;
17use crate::rings::zn::*;
18
19use super::ConvolutionAlgorithm;
20
/// Shorthand for the complex floating-point ring `Complex64::RING`, in which all FFT arithmetic in this file is done.
21const CC: Complex64 = Complex64::RING;
22
23#[stability::unstable(feature = "enable")]
24pub struct FFTConvolution<A = Global> {
25    allocator: A,
26    fft_tables: LazyVec<CooleyTuckeyFFT<Complex64Base, Complex64Base, Identity<Complex64>>>
27}
28
29#[stability::unstable(feature = "enable")]
30pub struct PreparedConvolutionOperand<R, A = Global>
31    where R: ?Sized + RingBase,
32        A: Allocator + Clone
33{
34    ring: PhantomData<Box<R>>,
35    fft_data: LazyVec<Vec<El<Complex64>, A>>,
36    log2_data_size: usize
37}
38
39impl<A> FFTConvolution<A>
40    where A: Allocator + Clone
41{
42    #[stability::unstable(feature = "enable")]
43    pub fn new_with(allocator: A) -> Self {
44        Self {
45            allocator: allocator,
46            fft_tables: LazyVec::new()
47        }
48    }
49
50    fn get_fft_table(&self, log2_len: usize) -> &CooleyTuckeyFFT<Complex64Base, Complex64Base, Identity<Complex64>> {
51        return self.fft_tables.get_or_init(log2_len, || CooleyTuckeyFFT::for_complex(CC, log2_len));
52    }
53
54    fn get_fft_data<'a, R, V, ToInt>(
55        &self,
56        data: V,
57        data_prep: Option<&'a PreparedConvolutionOperand<R, A>>,
58        _ring: &R,
59        log2_len: usize,
60        mut to_int: ToInt,
61        log2_el_size: Option<usize>
62    ) -> MyCow<'a, Vec<El<Complex64>, A>>
63        where R: ?Sized + RingBase,
64            V: VectorView<R::Element>,
65            ToInt: FnMut(&R::Element) -> i64
66    {
67        let log2_data_size = if let Some(log2_data_size) = log2_el_size {
68            if let Some(data_prep) = data_prep {
69                assert_eq!(log2_data_size, data_prep.log2_data_size);
70            }
71            log2_data_size 
72        } else {
73            data.as_iter().map(|x| StaticRing::<i64>::RING.abs_log2_ceil(&to_int(x)).unwrap_or(0)).max().unwrap()
74        };
75        assert!(data.len() <= (1 << log2_len));
76        assert!(self.has_sufficient_precision(log2_len, log2_data_size));
77
78        let mut compute_result = || {
79            let mut result = Vec::with_capacity_in(1 << log2_len, self.allocator.clone());
80            result.extend(data.as_iter().map(|x| Complex64::RING.from_f64(to_int(x) as f64)));
81            result.resize(1 << log2_len, Complex64::RING.zero());
82            self.get_fft_table(log2_len).unordered_fft(&mut result, Complex64::RING);
83            return result;
84        };
85
86        return if let Some(data_prep) = data_prep {
87            MyCow::Borrowed(data_prep.fft_data.get_or_init(log2_len, compute_result))
88        } else {
89            MyCow::Owned(compute_result())
90        }
91    }
92
93    #[stability::unstable(feature = "enable")]
94    pub fn has_sufficient_precision(&self, log2_len: usize, log2_input_size: usize) -> bool {
95        self.max_sum_len(log2_len, log2_input_size) > 0
96    }
97    
98    fn max_sum_len(&self, log2_len: usize, log2_input_size: usize) -> usize {
99        let fft_table = self.get_fft_table(log2_len);
100        let input_size = 2f64.powi(log2_input_size.try_into().unwrap());
101        (0.5 / fft_table.expected_absolute_error(input_size * input_size, input_size * input_size * f64::EPSILON + fft_table.expected_absolute_error(input_size, 0.))).floor() as usize
102    }
103
104    fn prepare_convolution_impl<R, V, ToInt>(
105        &self,
106        data: V,
107        ring: &R,
108        length_hint: Option<usize>,
109        mut to_int: ToInt,
110        ring_log2_el_size: Option<usize>
111    ) -> PreparedConvolutionOperand<R, A>
112        where R: ?Sized + RingBase,
113            V: VectorView<R::Element>,
114            ToInt: FnMut(&R::Element) -> i64
115    {
116        let log2_data_size = if let Some(log2_data_size) = ring_log2_el_size {
117            log2_data_size 
118        } else {
119            data.as_iter().map(|x| StaticRing::<i64>::RING.abs_log2_ceil(&to_int(x)).unwrap_or(0)).max().unwrap()
120        };
121        let result = PreparedConvolutionOperand {
122            fft_data: LazyVec::new(),
123            ring: PhantomData,
124            log2_data_size: log2_data_size
125        };
126        // if a length-hint is given, initialize the corresponding length entry;
127        // this might avoid confusing performance characteristics when the user does
128        // not expect lazy behavior
129        if let Some(len) = length_hint {
130            let log2_len = StaticRing::<i64>::RING.abs_log2_ceil(&len.try_into().unwrap()).unwrap();
131            _ = self.get_fft_data(data, Some(&result), ring, log2_len, to_int, ring_log2_el_size);
132        }
133        return result;
134    }
135
136    fn compute_convolution_impl<R, V1, V2, ToInt, FromInt>(
137        &self,
138        lhs: V1,
139        lhs_prep: Option<&PreparedConvolutionOperand<R, A>>,
140        rhs: V2,
141        rhs_prep: Option<&PreparedConvolutionOperand<R, A>>,
142        dst: &mut [R::Element],
143        ring: &R,
144        mut to_int: ToInt,
145        mut from_int: FromInt,
146        ring_log2_el_size: Option<usize>
147    )
148        where R: ?Sized + RingBase,
149            V1: VectorView<R::Element>,
150            V2: VectorView<R::Element>,
151            ToInt: FnMut(&R::Element) -> i64,
152            FromInt: FnMut(i64) -> R::Element
153    {
154        if lhs.len() == 0 || rhs.len() == 0 {
155            return;
156        }
157        let len = lhs.len() + rhs.len() - 1;
158        assert!(dst.len() >= len);
159        let log2_len = StaticRing::<i64>::RING.abs_log2_ceil(&len.try_into().unwrap()).unwrap();
160
161        let mut lhs_fft = self.get_fft_data(lhs, lhs_prep, ring, log2_len, &mut to_int, ring_log2_el_size);
162        let mut rhs_fft = self.get_fft_data(rhs, rhs_prep, ring, log2_len, &mut to_int, ring_log2_el_size);
163        if rhs_fft.is_owned() {
164            std::mem::swap(&mut lhs_fft, &mut rhs_fft);
165        }
166        let lhs_fft: &mut Vec<El<Complex64>, A> = lhs_fft.to_mut();
167
168        for i in 0..(1 << log2_len) {
169            CC.mul_assign(&mut lhs_fft[i], rhs_fft[i]);
170        }
171
172        self.get_fft_table(log2_len).unordered_inv_fft(&mut *lhs_fft, CC);
173
174        for i in 0..len {
175            let result = CC.closest_gaussian_int(lhs_fft[i]);
176            debug_assert_eq!(0, result.1);
177            ring.add_assign(&mut dst[i], from_int(result.0));
178        }
179    }
180
181    fn compute_convolution_sum_impl<'a, R, I, V1, V2, ToInt, FromInt>(
182        &self,
183        data: I,
184        dst: &mut [R::Element],
185        ring: &R,
186        mut to_int: ToInt,
187        mut from_int: FromInt,
188        ring_log2_el_size: Option<usize>
189    )
190        where R: ?Sized + RingBase,
191            I: Iterator<Item = (V1, Option<&'a PreparedConvolutionOperand<R, A>>, V2, Option<&'a PreparedConvolutionOperand<R, A>>)>,
192            V1: VectorView<R::Element>,
193            V2: VectorView<R::Element>,
194            ToInt: FnMut(&R::Element) -> i64,
195            FromInt: FnMut(i64) -> R::Element,
196            Self: 'a,
197            R: 'a
198    {
199        let len = dst.len();
200        let log2_len = StaticRing::<i64>::RING.abs_log2_ceil(&len.try_into().unwrap()).unwrap();
201        let mut buffer = Vec::with_capacity_in(1 << log2_len, self.allocator.clone());
202        buffer.resize(1 << log2_len, CC.zero());
203
204        let mut count_since_last_reduction = 0;
205        let mut current_max_sum_len = usize::MAX;
206        let mut current_log2_data_size = if let Some(log2_data_size) = ring_log2_el_size {
207            log2_data_size
208        } else {
209            0
210        };
211        for (lhs, lhs_prep, rhs, rhs_prep) in data {
212            if lhs.len() == 0 || rhs.len() == 0 {
213                continue;
214            }
215            assert!(lhs.len() + rhs.len() - 1 <= dst.len());
216
217            if ring_log2_el_size.is_none() {
218                current_log2_data_size = max(
219                    current_log2_data_size,
220                    lhs.as_iter().chain(rhs.as_iter()).map(|x| StaticRing::<i64>::RING.abs_log2_ceil(&to_int(x)).unwrap_or(0)).max().unwrap()
221                );
222                current_max_sum_len = self.max_sum_len(log2_len, current_log2_data_size);
223            }
224            assert!(current_max_sum_len > 0);
225            
226            if count_since_last_reduction > current_max_sum_len {
227                count_since_last_reduction = 0;
228                self.get_fft_table(log2_len).unordered_inv_fft(&mut *buffer, CC);
229                for i in 0..len {
230                    let result = CC.closest_gaussian_int(buffer[i]);
231                    debug_assert_eq!(0, result.1);
232                    ring.add_assign(&mut dst[i], from_int(result.0));
233                }
234                for i in 0..(1 << log2_len) {
235                    buffer[i] = CC.zero();
236                }
237            }
238            
239            let mut lhs_fft = self.get_fft_data(lhs, lhs_prep, ring, log2_len, &mut to_int, ring_log2_el_size);
240            let mut rhs_fft = self.get_fft_data(rhs, rhs_prep, ring, log2_len, &mut to_int, ring_log2_el_size);
241            if rhs_fft.is_owned() {
242                std::mem::swap(&mut lhs_fft, &mut rhs_fft);
243            }
244            let lhs_fft: &mut Vec<El<Complex64>, A> = lhs_fft.to_mut();
245            for i in 0..(1 << log2_len) {
246                CC.mul_assign(&mut lhs_fft[i], rhs_fft[i]);
247                CC.add_assign(&mut buffer[i], lhs_fft[i]);
248            }
249            count_since_last_reduction += 1;
250        }
251        self.get_fft_table(log2_len).unordered_inv_fft(&mut *buffer, CC);
252        for i in 0..len {
253            let result = CC.closest_gaussian_int(buffer[i]);
254            debug_assert_eq!(0, result.1);
255            ring.add_assign(&mut dst[i], from_int(result.0));
256        }
257    }
258}
259
260fn to_int_int<I>(ring: I) -> impl use<I> + Fn(&El<I>) -> i64
261    where I: RingStore, I::Type: IntegerRing
262{
263    move |x| int_cast(ring.clone_el(x), StaticRing::<i64>::RING, &ring)
264}
265
266fn from_int_int<I>(ring: I) -> impl use<I> + Fn(i64) -> El<I>
267    where I: RingStore, I::Type: IntegerRing
268{
269    move |x| int_cast(x, &ring, StaticRing::<i64>::RING)
270}
271
272fn to_int_zn<R>(ring: R) -> impl use<R> + Fn(&El<R>) -> i64
273    where R: RingStore, R::Type: ZnRing
274{
275    move |x| int_cast(ring.smallest_lift(ring.clone_el(x)), StaticRing::<i64>::RING, ring.integer_ring())
276}
277
278fn from_int_zn<R>(ring: R) -> impl use<R> + Fn(i64) -> El<R>
279    where R: RingStore, R::Type: ZnRing
280{
281    let hom = ring.can_hom(ring.integer_ring()).unwrap().into_raw_hom();
282    move |x| ring.get_ring().map_in(ring.integer_ring().get_ring(), int_cast(x, ring.integer_ring(), StaticRing::<i64>::RING), &hom)
283}
284
285impl<A> Clone for FFTConvolution<A>
286    where A: Allocator + Clone
287{
288    fn clone(&self) -> Self {
289        Self {
290            allocator: self.allocator.clone(),
291            fft_tables: self.fft_tables.clone()
292        }
293    }
294}
295
296impl<A> From<FFTConvolutionZn<A>> for FFTConvolution<A>
297    where A: Allocator
298{
299    fn from(value: FFTConvolutionZn<A>) -> Self {
300        value.base
301    }
302}
303
304impl<'a, A> From<&'a FFTConvolutionZn<A>> for &'a FFTConvolution<A>
305    where A: Allocator
306{
307    fn from(value: &'a FFTConvolutionZn<A>) -> Self {
308        &value.base
309    }
310}
311
312impl<A> From<FFTConvolution<A>> for FFTConvolutionZn<A>
313    where A: Allocator
314{
315    fn from(value: FFTConvolution<A>) -> Self {
316        FFTConvolutionZn { base: value }
317    }
318}
319
320impl<'a, A> From<&'a FFTConvolution<A>> for &'a FFTConvolutionZn<A>
321    where A: Allocator
322{
323    fn from(value: &'a FFTConvolution<A>) -> Self {
324        unsafe { std::mem::transmute(value) }
325    }
326}
327
328#[stability::unstable(feature = "enable")]
329#[repr(transparent)]
330pub struct FFTConvolutionZn<A = Global> {
331    base: FFTConvolution<A>
332}
333
334impl<A> Clone for FFTConvolutionZn<A>
335    where A: Allocator + Clone
336{
337    fn clone(&self) -> Self {
338        Self { base: self.base.clone() }
339    }
340}
341
342impl<R, A> ConvolutionAlgorithm<R> for FFTConvolutionZn<A>
343    where R: ?Sized + ZnRing + CanHomFrom<StaticRingBase<i64>>,
344        A: Allocator + Clone
345{
346    type PreparedConvolutionOperand = PreparedConvolutionOperand<R, A>;
347
348    fn compute_convolution<S: RingStore<Type = R>, V1: VectorView<R::Element>, V2: VectorView<R::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [R::Element], ring: S) {
349        self.base.compute_convolution_impl(
350            lhs,
351            None,
352            rhs,
353            None,
354            dst,
355            ring.get_ring(),
356            to_int_zn(&ring),
357            from_int_zn(&ring),
358            Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap())
359        )
360    }
361
362    fn supports_ring<S: RingStore<Type = R> + Copy>(&self, _ring: S) -> bool {
363        true
364    }
365
366    fn prepare_convolution_operand<S, V>(&self, val: V, len_hint: Option<usize>, ring: S) -> Self::PreparedConvolutionOperand
367        where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>
368    {
369        self.base.prepare_convolution_impl(
370            val,
371            ring.get_ring(),
372            len_hint,
373            to_int_zn(&ring),
374            Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap())
375        )
376    }
377
378    fn compute_convolution_prepared<S, V1, V2>(&self, lhs: V1, lhs_prep: Option<&Self::PreparedConvolutionOperand>, rhs: V2, rhs_prep: Option<&Self::PreparedConvolutionOperand>, dst: &mut [R::Element], ring: S)
379        where S: RingStore<Type = R> + Copy, V1: VectorView<R::Element>, V2: VectorView<R::Element>
380    {
381        self.base.compute_convolution_impl(
382            lhs,
383            lhs_prep,
384            rhs,
385            rhs_prep,
386            dst,
387            ring.get_ring(),
388            to_int_zn(&ring),
389            from_int_zn(&ring),
390            Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap())
391        )
392    }
393
394    fn compute_convolution_sum<'a, S, I, V1, V2>(&self, values: I, dst: &mut [R::Element], ring: S) 
395        where S: RingStore<Type = R> + Copy, 
396            I: Iterator<Item = (V1, Option<&'a Self::PreparedConvolutionOperand>, V2, Option<&'a Self::PreparedConvolutionOperand>)>,
397            V1: VectorView<R::Element>,
398            V2: VectorView<R::Element>,
399            Self: 'a,
400            R: 'a,
401            Self::PreparedConvolutionOperand: 'a
402    {
403        self.base.compute_convolution_sum_impl(
404            values,
405            dst,
406            ring.get_ring(),
407            to_int_zn(&ring),
408            from_int_zn(&ring),
409            Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap())
410        )
411    }
412}
413
414impl<I, A> ConvolutionAlgorithm<I> for FFTConvolution<A>
415    where I: ?Sized + IntegerRing,
416        A: Allocator + Clone
417{
418    type PreparedConvolutionOperand = PreparedConvolutionOperand<I, A>;
419
420    fn compute_convolution<S: RingStore<Type = I>, V1: VectorView<I::Element>, V2: VectorView<I::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [I::Element], ring: S) {
421        self.compute_convolution_impl(
422            lhs,
423            None,
424            rhs,
425            None,
426            dst,
427            ring.get_ring(),
428            to_int_int(&ring),
429            from_int_int(&ring),
430            None
431        )
432    }
433
434    fn supports_ring<S: RingStore<Type = I> + Copy>(&self, _ring: S) -> bool {
435        true
436    }
437
438    fn prepare_convolution_operand<S, V>(&self, val: V, len_hint: Option<usize>, ring: S) -> Self::PreparedConvolutionOperand
439        where S: RingStore<Type = I> + Copy, V: VectorView<I::Element>
440    {
441        self.prepare_convolution_impl(
442            val,
443            ring.get_ring(),
444            len_hint,
445            to_int_int(&ring),
446            None
447        )
448    }
449
450    fn compute_convolution_prepared<S, V1, V2>(&self, lhs: V1, lhs_prep: Option<&Self::PreparedConvolutionOperand>, rhs: V2, rhs_prep: Option<&Self::PreparedConvolutionOperand>, dst: &mut [I::Element], ring: S)
451        where S: RingStore<Type = I> + Copy, V1: VectorView<I::Element>, V2: VectorView<I::Element>
452    {
453        self.compute_convolution_impl(
454            lhs,
455            lhs_prep,
456            rhs,
457            rhs_prep,
458            dst,
459            ring.get_ring(),
460            to_int_int(&ring),
461            from_int_int(&ring),
462            None
463        )
464    }
465
466    fn compute_convolution_sum<'a, S, J, V1, V2>(&self, values: J, dst: &mut [I::Element], ring: S) 
467        where S: RingStore<Type = I> + Copy, 
468            J: Iterator<Item = (V1, Option<&'a Self::PreparedConvolutionOperand>, V2, Option<&'a Self::PreparedConvolutionOperand>)>,
469            V1: VectorView<I::Element>,
470            V2: VectorView<I::Element>,
471            Self: 'a,
472            I: 'a,
473            Self::PreparedConvolutionOperand: 'a
474    {
475        self.compute_convolution_sum_impl(
476            values,
477            dst,
478            ring.get_ring(),
479            to_int_int(&ring),
480            from_int_int(&ring),
481            None
482        )
483    }
484}
485
486#[cfg(test)]
487use crate::rings::finite::FiniteRingStore;
488#[cfg(test)]
489use crate::rings::zn::zn_64::Zn;
490
491#[test]
492fn test_convolution_zn() {
493    let convolution: FFTConvolutionZn = FFTConvolution::new_with(Global).into();
494    let ring = Zn::new(17 * 257);
495
496    super::generic_tests::test_convolution(&convolution, &ring, ring.one());
497}
498
499#[test]
500fn test_convolution_int() {
501    let convolution: FFTConvolution = FFTConvolution::new_with(Global);
502    let ring = StaticRing::<i64>::RING;
503
504    super::generic_tests::test_convolution(&convolution, &ring, ring.one());
505}
506
507#[test]
508#[should_panic(expected = "precision")]
509fn test_fft_convolution_not_enough_precision() {
510    let convolution_algorithm: FFTConvolutionZn = FFTConvolution::new_with(Global).into();
511
512    let ring = Zn::new(1099511627791);
513    let lhs = ring.elements().take(1024).collect::<Vec<_>>();
514    let rhs = ring.elements().take(1024).collect::<Vec<_>>();
515    let mut actual = (0..(lhs.len() + rhs.len())).map(|_| ring.zero()).collect::<Vec<_>>();
516
517    convolution_algorithm.compute_convolution(&lhs, &rhs, &mut actual, &ring);
518}