// feanor_math/algorithms/convolution/fft.rs

1use std::cmp::max;
2use std::alloc::{Allocator, Global};
3use std::marker::PhantomData;
4
5use crate::cow::*;
6use crate::algorithms::fft::complex_fft::FFTErrorEstimate;
7use crate::algorithms::fft::cooley_tuckey::CooleyTuckeyFFT;
8use crate::algorithms::fft::FFTAlgorithm;
9use crate::lazy::LazyVec;
10use crate::primitive_int::StaticRingBase;
11use crate::integer::*;
12use crate::ring::*;
13use crate::seq::*;
14use crate::primitive_int::*;
15use crate::homomorphism::*;
16use crate::rings::float_complex::*;
17use crate::rings::zn::*;
18
19use super::ConvolutionAlgorithm;
20
/// Shorthand for the ring of `f64`-based complex numbers in which all FFTs of this module are computed.
const CC: Complex64 = Complex64::RING;
22
23#[stability::unstable(feature = "enable")]
24pub struct FFTConvolution<A = Global> {
25    allocator: A,
26    fft_tables: LazyVec<CooleyTuckeyFFT<Complex64Base, Complex64Base, Identity<Complex64>>>
27}
28
29#[stability::unstable(feature = "enable")]
30pub struct PreparedConvolutionOperand<R, A = Global>
31    where R: ?Sized + RingBase,
32        A: Allocator + Clone
33{
34    ring: PhantomData<Box<R>>,
35    fft_data: LazyVec<Vec<El<Complex64>, A>>,
36    log2_data_size: usize
37}
38
39impl FFTConvolution<Global> {
40    
41    #[stability::unstable(feature = "enable")]
42    pub fn new() -> Self {
43        Self::new_with_alloc(Global)
44    }
45}
46
47impl<A> FFTConvolution<A>
48    where A: Allocator + Clone
49{
50    #[stability::unstable(feature = "enable")]
51    pub fn new_with_alloc(allocator: A) -> Self {
52        Self {
53            allocator: allocator,
54            fft_tables: LazyVec::new()
55        }
56    }
57
58    fn get_fft_table(&self, log2_len: usize) -> &CooleyTuckeyFFT<Complex64Base, Complex64Base, Identity<Complex64>> {
59        return self.fft_tables.get_or_init(log2_len, || CooleyTuckeyFFT::for_complex(CC, log2_len));
60    }
61
62    fn get_fft_data<'a, R, V, ToInt>(
63        &self,
64        data: V,
65        data_prep: Option<&'a PreparedConvolutionOperand<R, A>>,
66        _ring: &R,
67        log2_len: usize,
68        mut to_int: ToInt,
69        log2_el_size: Option<usize>
70    ) -> MyCow<'a, Vec<El<Complex64>, A>>
71        where R: ?Sized + RingBase,
72            V: VectorView<R::Element>,
73            ToInt: FnMut(&R::Element) -> i64
74    {
75        let log2_data_size = if let Some(log2_data_size) = log2_el_size {
76            if let Some(data_prep) = data_prep {
77                assert_eq!(log2_data_size, data_prep.log2_data_size);
78            }
79            log2_data_size 
80        } else {
81            data.as_iter().map(|x| StaticRing::<i64>::RING.abs_log2_ceil(&to_int(x)).unwrap_or(0)).max().unwrap()
82        };
83        assert!(data.len() <= (1 << log2_len));
84        assert!(self.has_sufficient_precision(log2_len, log2_data_size));
85
86        let mut compute_result = || {
87            let mut result = Vec::with_capacity_in(1 << log2_len, self.allocator.clone());
88            result.extend(data.as_iter().map(|x| Complex64::RING.from_f64(to_int(x) as f64)));
89            result.resize(1 << log2_len, Complex64::RING.zero());
90            self.get_fft_table(log2_len).unordered_fft(&mut result, Complex64::RING);
91            return result;
92        };
93
94        return if let Some(data_prep) = data_prep {
95            MyCow::Borrowed(data_prep.fft_data.get_or_init(log2_len, compute_result))
96        } else {
97            MyCow::Owned(compute_result())
98        }
99    }
100
101    #[stability::unstable(feature = "enable")]
102    pub fn has_sufficient_precision(&self, log2_len: usize, log2_input_size: usize) -> bool {
103        self.max_sum_len(log2_len, log2_input_size) > 0
104    }
105    
106    fn max_sum_len(&self, log2_len: usize, log2_input_size: usize) -> usize {
107        let fft_table = self.get_fft_table(log2_len);
108        let input_size = 2f64.powi(log2_input_size.try_into().unwrap());
109        (0.5 / fft_table.expected_absolute_error(input_size * input_size, input_size * input_size * f64::EPSILON + fft_table.expected_absolute_error(input_size, 0.))).floor() as usize
110    }
111
112    fn prepare_convolution_impl<R, V, ToInt>(
113        &self,
114        data: V,
115        ring: &R,
116        length_hint: Option<usize>,
117        mut to_int: ToInt,
118        ring_log2_el_size: Option<usize>
119    ) -> PreparedConvolutionOperand<R, A>
120        where R: ?Sized + RingBase,
121            V: VectorView<R::Element>,
122            ToInt: FnMut(&R::Element) -> i64
123    {
124        let log2_data_size = if let Some(log2_data_size) = ring_log2_el_size {
125            log2_data_size 
126        } else {
127            data.as_iter().map(|x| StaticRing::<i64>::RING.abs_log2_ceil(&to_int(x)).unwrap_or(0)).max().unwrap()
128        };
129        let result = PreparedConvolutionOperand {
130            fft_data: LazyVec::new(),
131            ring: PhantomData,
132            log2_data_size: log2_data_size
133        };
134        // if a length-hint is given, initialize the corresponding length entry;
135        // this might avoid confusing performance characteristics when the user does
136        // not expect lazy behavior
137        if let Some(len) = length_hint {
138            let log2_len = StaticRing::<i64>::RING.abs_log2_ceil(&len.try_into().unwrap()).unwrap();
139            _ = self.get_fft_data(data, Some(&result), ring, log2_len, to_int, ring_log2_el_size);
140        }
141        return result;
142    }
143
144    fn compute_convolution_impl<R, V1, V2, ToInt, FromInt>(
145        &self,
146        lhs: V1,
147        lhs_prep: Option<&PreparedConvolutionOperand<R, A>>,
148        rhs: V2,
149        rhs_prep: Option<&PreparedConvolutionOperand<R, A>>,
150        dst: &mut [R::Element],
151        ring: &R,
152        mut to_int: ToInt,
153        mut from_int: FromInt,
154        ring_log2_el_size: Option<usize>
155    )
156        where R: ?Sized + RingBase,
157            V1: VectorView<R::Element>,
158            V2: VectorView<R::Element>,
159            ToInt: FnMut(&R::Element) -> i64,
160            FromInt: FnMut(i64) -> R::Element
161    {
162        if lhs.len() == 0 || rhs.len() == 0 {
163            return;
164        }
165        let len = lhs.len() + rhs.len() - 1;
166        assert!(dst.len() >= len);
167        let log2_len = StaticRing::<i64>::RING.abs_log2_ceil(&len.try_into().unwrap()).unwrap();
168
169        let mut lhs_fft = self.get_fft_data(lhs, lhs_prep, ring, log2_len, &mut to_int, ring_log2_el_size);
170        let mut rhs_fft = self.get_fft_data(rhs, rhs_prep, ring, log2_len, &mut to_int, ring_log2_el_size);
171        if rhs_fft.is_owned() {
172            std::mem::swap(&mut lhs_fft, &mut rhs_fft);
173        }
174        let lhs_fft: &mut Vec<El<Complex64>, A> = lhs_fft.to_mut();
175
176        for i in 0..(1 << log2_len) {
177            CC.mul_assign(&mut lhs_fft[i], rhs_fft[i]);
178        }
179
180        self.get_fft_table(log2_len).unordered_inv_fft(&mut *lhs_fft, CC);
181
182        for i in 0..len {
183            let result = CC.closest_gaussian_int(lhs_fft[i]);
184            debug_assert_eq!(0, result.1);
185            ring.add_assign(&mut dst[i], from_int(result.0));
186        }
187    }
188
189    fn compute_convolution_sum_impl<'a, R, I, V1, V2, ToInt, FromInt>(
190        &self,
191        data: I,
192        dst: &mut [R::Element],
193        ring: &R,
194        mut to_int: ToInt,
195        mut from_int: FromInt,
196        ring_log2_el_size: Option<usize>
197    )
198        where R: ?Sized + RingBase,
199            I: Iterator<Item = (V1, Option<&'a PreparedConvolutionOperand<R, A>>, V2, Option<&'a PreparedConvolutionOperand<R, A>>)>,
200            V1: VectorView<R::Element>,
201            V2: VectorView<R::Element>,
202            ToInt: FnMut(&R::Element) -> i64,
203            FromInt: FnMut(i64) -> R::Element,
204            Self: 'a,
205            R: 'a
206    {
207        let len = dst.len();
208        let log2_len = StaticRing::<i64>::RING.abs_log2_ceil(&len.try_into().unwrap()).unwrap();
209        let mut buffer = Vec::with_capacity_in(1 << log2_len, self.allocator.clone());
210        buffer.resize(1 << log2_len, CC.zero());
211
212        let mut count_since_last_reduction = 0;
213        let mut current_max_sum_len = usize::MAX;
214        let mut current_log2_data_size = if let Some(log2_data_size) = ring_log2_el_size {
215            log2_data_size
216        } else {
217            0
218        };
219        for (lhs, lhs_prep, rhs, rhs_prep) in data {
220            if lhs.len() == 0 || rhs.len() == 0 {
221                continue;
222            }
223            assert!(lhs.len() + rhs.len() - 1 <= dst.len());
224
225            if ring_log2_el_size.is_none() {
226                current_log2_data_size = max(
227                    current_log2_data_size,
228                    lhs.as_iter().chain(rhs.as_iter()).map(|x| StaticRing::<i64>::RING.abs_log2_ceil(&to_int(x)).unwrap_or(0)).max().unwrap()
229                );
230                current_max_sum_len = self.max_sum_len(log2_len, current_log2_data_size);
231            }
232            assert!(current_max_sum_len > 0);
233            
234            if count_since_last_reduction > current_max_sum_len {
235                count_since_last_reduction = 0;
236                self.get_fft_table(log2_len).unordered_inv_fft(&mut *buffer, CC);
237                for i in 0..len {
238                    let result = CC.closest_gaussian_int(buffer[i]);
239                    debug_assert_eq!(0, result.1);
240                    ring.add_assign(&mut dst[i], from_int(result.0));
241                }
242                for i in 0..(1 << log2_len) {
243                    buffer[i] = CC.zero();
244                }
245            }
246            
247            let mut lhs_fft = self.get_fft_data(lhs, lhs_prep, ring, log2_len, &mut to_int, ring_log2_el_size);
248            let mut rhs_fft = self.get_fft_data(rhs, rhs_prep, ring, log2_len, &mut to_int, ring_log2_el_size);
249            if rhs_fft.is_owned() {
250                std::mem::swap(&mut lhs_fft, &mut rhs_fft);
251            }
252            let lhs_fft: &mut Vec<El<Complex64>, A> = lhs_fft.to_mut();
253            for i in 0..(1 << log2_len) {
254                CC.mul_assign(&mut lhs_fft[i], rhs_fft[i]);
255                CC.add_assign(&mut buffer[i], lhs_fft[i]);
256            }
257            count_since_last_reduction += 1;
258        }
259        self.get_fft_table(log2_len).unordered_inv_fft(&mut *buffer, CC);
260        for i in 0..len {
261            let result = CC.closest_gaussian_int(buffer[i]);
262            debug_assert_eq!(0, result.1);
263            ring.add_assign(&mut dst[i], from_int(result.0));
264        }
265    }
266}
267
268fn to_int_int<I>(ring: I) -> impl use<I> + Fn(&El<I>) -> i64
269    where I: RingStore, I::Type: IntegerRing
270{
271    move |x| int_cast(ring.clone_el(x), StaticRing::<i64>::RING, &ring)
272}
273
274fn from_int_int<I>(ring: I) -> impl use<I> + Fn(i64) -> El<I>
275    where I: RingStore, I::Type: IntegerRing
276{
277    move |x| int_cast(x, &ring, StaticRing::<i64>::RING)
278}
279
280fn to_int_zn<R>(ring: R) -> impl use<R> + Fn(&El<R>) -> i64
281    where R: RingStore, R::Type: ZnRing
282{
283    move |x| int_cast(ring.smallest_lift(ring.clone_el(x)), StaticRing::<i64>::RING, ring.integer_ring())
284}
285
286fn from_int_zn<R>(ring: R) -> impl use<R> + Fn(i64) -> El<R>
287    where R: RingStore, R::Type: ZnRing
288{
289    let hom = ring.can_hom(ring.integer_ring()).unwrap().into_raw_hom();
290    move |x| ring.get_ring().map_in(ring.integer_ring().get_ring(), int_cast(x, ring.integer_ring(), StaticRing::<i64>::RING), &hom)
291}
292
293impl<A> Clone for FFTConvolution<A>
294    where A: Allocator + Clone
295{
296    fn clone(&self) -> Self {
297        Self {
298            allocator: self.allocator.clone(),
299            fft_tables: self.fft_tables.clone()
300        }
301    }
302}
303
304impl<A> From<FFTConvolutionZn<A>> for FFTConvolution<A>
305    where A: Allocator
306{
307    fn from(value: FFTConvolutionZn<A>) -> Self {
308        value.base
309    }
310}
311
312impl<'a, A> From<&'a FFTConvolutionZn<A>> for &'a FFTConvolution<A>
313    where A: Allocator
314{
315    fn from(value: &'a FFTConvolutionZn<A>) -> Self {
316        &value.base
317    }
318}
319
320impl<A> From<FFTConvolution<A>> for FFTConvolutionZn<A>
321    where A: Allocator
322{
323    fn from(value: FFTConvolution<A>) -> Self {
324        FFTConvolutionZn { base: value }
325    }
326}
327
328impl<'a, A> From<&'a FFTConvolution<A>> for &'a FFTConvolutionZn<A>
329    where A: Allocator
330{
331    fn from(value: &'a FFTConvolution<A>) -> Self {
332        unsafe { std::mem::transmute(value) }
333    }
334}
335
336#[stability::unstable(feature = "enable")]
337#[repr(transparent)]
338pub struct FFTConvolutionZn<A = Global> {
339    base: FFTConvolution<A>
340}
341
342impl<A> Clone for FFTConvolutionZn<A>
343    where A: Allocator + Clone
344{
345    fn clone(&self) -> Self {
346        Self { base: self.base.clone() }
347    }
348}
349
350impl<R, A> ConvolutionAlgorithm<R> for FFTConvolutionZn<A>
351    where R: ?Sized + ZnRing + CanHomFrom<StaticRingBase<i64>>,
352        A: Allocator + Clone
353{
354    type PreparedConvolutionOperand = PreparedConvolutionOperand<R, A>;
355
356    fn compute_convolution<S: RingStore<Type = R>, V1: VectorView<R::Element>, V2: VectorView<R::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [R::Element], ring: S) {
357        self.base.compute_convolution_impl(
358            lhs,
359            None,
360            rhs,
361            None,
362            dst,
363            ring.get_ring(),
364            to_int_zn(&ring),
365            from_int_zn(&ring),
366            Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap())
367        )
368    }
369
370    fn supports_ring<S: RingStore<Type = R> + Copy>(&self, _ring: S) -> bool {
371        true
372    }
373
374    fn prepare_convolution_operand<S, V>(&self, val: V, len_hint: Option<usize>, ring: S) -> Self::PreparedConvolutionOperand
375        where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>
376    {
377        self.base.prepare_convolution_impl(
378            val,
379            ring.get_ring(),
380            len_hint,
381            to_int_zn(&ring),
382            Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap())
383        )
384    }
385
386    fn compute_convolution_prepared<S, V1, V2>(&self, lhs: V1, lhs_prep: Option<&Self::PreparedConvolutionOperand>, rhs: V2, rhs_prep: Option<&Self::PreparedConvolutionOperand>, dst: &mut [R::Element], ring: S)
387        where S: RingStore<Type = R> + Copy, V1: VectorView<R::Element>, V2: VectorView<R::Element>
388    {
389        self.base.compute_convolution_impl(
390            lhs,
391            lhs_prep,
392            rhs,
393            rhs_prep,
394            dst,
395            ring.get_ring(),
396            to_int_zn(&ring),
397            from_int_zn(&ring),
398            Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap())
399        )
400    }
401
402    fn compute_convolution_sum<'a, S, I, V1, V2>(&self, values: I, dst: &mut [R::Element], ring: S) 
403        where S: RingStore<Type = R> + Copy, 
404            I: Iterator<Item = (V1, Option<&'a Self::PreparedConvolutionOperand>, V2, Option<&'a Self::PreparedConvolutionOperand>)>,
405            V1: VectorView<R::Element>,
406            V2: VectorView<R::Element>,
407            Self: 'a,
408            R: 'a,
409            Self::PreparedConvolutionOperand: 'a
410    {
411        self.base.compute_convolution_sum_impl(
412            values,
413            dst,
414            ring.get_ring(),
415            to_int_zn(&ring),
416            from_int_zn(&ring),
417            Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap())
418        )
419    }
420}
421
422impl<I, A> ConvolutionAlgorithm<I> for FFTConvolution<A>
423    where I: ?Sized + IntegerRing,
424        A: Allocator + Clone
425{
426    type PreparedConvolutionOperand = PreparedConvolutionOperand<I, A>;
427
428    fn compute_convolution<S: RingStore<Type = I>, V1: VectorView<I::Element>, V2: VectorView<I::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [I::Element], ring: S) {
429        self.compute_convolution_impl(
430            lhs,
431            None,
432            rhs,
433            None,
434            dst,
435            ring.get_ring(),
436            to_int_int(&ring),
437            from_int_int(&ring),
438            None
439        )
440    }
441
442    fn supports_ring<S: RingStore<Type = I> + Copy>(&self, _ring: S) -> bool {
443        true
444    }
445
446    fn prepare_convolution_operand<S, V>(&self, val: V, len_hint: Option<usize>, ring: S) -> Self::PreparedConvolutionOperand
447        where S: RingStore<Type = I> + Copy, V: VectorView<I::Element>
448    {
449        self.prepare_convolution_impl(
450            val,
451            ring.get_ring(),
452            len_hint,
453            to_int_int(&ring),
454            None
455        )
456    }
457
458    fn compute_convolution_prepared<S, V1, V2>(&self, lhs: V1, lhs_prep: Option<&Self::PreparedConvolutionOperand>, rhs: V2, rhs_prep: Option<&Self::PreparedConvolutionOperand>, dst: &mut [I::Element], ring: S)
459        where S: RingStore<Type = I> + Copy, V1: VectorView<I::Element>, V2: VectorView<I::Element>
460    {
461        self.compute_convolution_impl(
462            lhs,
463            lhs_prep,
464            rhs,
465            rhs_prep,
466            dst,
467            ring.get_ring(),
468            to_int_int(&ring),
469            from_int_int(&ring),
470            None
471        )
472    }
473
474    fn compute_convolution_sum<'a, S, J, V1, V2>(&self, values: J, dst: &mut [I::Element], ring: S) 
475        where S: RingStore<Type = I> + Copy, 
476            J: Iterator<Item = (V1, Option<&'a Self::PreparedConvolutionOperand>, V2, Option<&'a Self::PreparedConvolutionOperand>)>,
477            V1: VectorView<I::Element>,
478            V2: VectorView<I::Element>,
479            Self: 'a,
480            I: 'a,
481            Self::PreparedConvolutionOperand: 'a
482    {
483        self.compute_convolution_sum_impl(
484            values,
485            dst,
486            ring.get_ring(),
487            to_int_int(&ring),
488            from_int_int(&ring),
489            None
490        )
491    }
492}
493
494#[cfg(test)]
495use crate::rings::finite::FiniteRingStore;
496#[cfg(test)]
497use crate::rings::zn::zn_64::Zn;
498
#[test]
fn test_convolution_zn() {
    // runs the generic convolution test suite over Z/(17 * 257)Z
    let convolution = FFTConvolutionZn::from(FFTConvolution::new());
    let ring = Zn::new(17 * 257);
    super::generic_tests::test_convolution(&convolution, &ring, ring.one());
}
506
#[test]
fn test_convolution_int() {
    // runs the generic convolution test suite over the `i64` integers
    let ring = StaticRing::<i64>::RING;
    let convolution = FFTConvolution::new();
    super::generic_tests::test_convolution(&convolution, &ring, ring.one());
}
514
#[test]
#[should_panic(expected = "precision")]
fn test_fft_convolution_not_enough_precision() {
    // a modulus of about 2^40 with operands of length 1024 exceeds what `f64` can represent exactly
    let convolution_algorithm = FFTConvolutionZn::from(FFTConvolution::new());
    let ring = Zn::new(1099511627791);

    let lhs: Vec<_> = ring.elements().take(1024).collect();
    let rhs: Vec<_> = ring.elements().take(1024).collect();
    let output_len = lhs.len() + rhs.len();
    let mut actual: Vec<_> = (0..output_len).map(|_| ring.zero()).collect();

    convolution_algorithm.compute_convolution(&lhs, &rhs, &mut actual, &ring);
}