Skip to main content

feanor_math/algorithms/convolution/
fft.rs

1use std::alloc::{Allocator, Global};
2use std::cmp::max;
3use std::marker::PhantomData;
4
5use super::ConvolutionAlgorithm;
6use crate::algorithms::fft::FFTAlgorithm;
7use crate::algorithms::fft::complex_fft::FFTErrorEstimate;
8use crate::algorithms::fft::cooley_tuckey::CooleyTuckeyFFT;
9use crate::cow::*;
10use crate::homomorphism::*;
11use crate::integer::*;
12use crate::lazy::LazyVec;
13use crate::primitive_int::{StaticRingBase, *};
14use crate::ring::*;
15use crate::rings::float_complex::*;
16use crate::rings::zn::*;
17use crate::seq::*;
18
19const CC: Complex64 = Complex64::RING;
20
21#[stability::unstable(feature = "enable")]
22pub struct FFTConvolution<A = Global> {
23    allocator: A,
24    fft_tables: LazyVec<CooleyTuckeyFFT<Complex64Base, Complex64Base, Identity<Complex64>>>,
25}
26
27#[stability::unstable(feature = "enable")]
28pub struct PreparedConvolutionOperand<R, A = Global>
29where
30    R: ?Sized + RingBase,
31    A: Allocator + Clone,
32{
33    ring: PhantomData<Box<R>>,
34    fft_data: LazyVec<Vec<El<Complex64>, A>>,
35    log2_data_size: usize,
36}
37
38impl FFTConvolution<Global> {
39    #[stability::unstable(feature = "enable")]
40    pub fn new() -> Self { Self::new_with_alloc(Global) }
41}
42
43impl<A> FFTConvolution<A>
44where
45    A: Allocator + Clone,
46{
47    #[stability::unstable(feature = "enable")]
48    pub fn new_with_alloc(allocator: A) -> Self {
49        Self {
50            allocator,
51            fft_tables: LazyVec::new(),
52        }
53    }
54
55    fn get_fft_table(&self, log2_len: usize) -> &CooleyTuckeyFFT<Complex64Base, Complex64Base, Identity<Complex64>> {
56        return self
57            .fft_tables
58            .get_or_init(log2_len, || CooleyTuckeyFFT::for_complex(CC, log2_len));
59    }
60
61    fn get_fft_data<'a, R, V, ToInt>(
62        &self,
63        data: V,
64        data_prep: Option<&'a PreparedConvolutionOperand<R, A>>,
65        _ring: &R,
66        log2_len: usize,
67        mut to_int: ToInt,
68        log2_el_size: Option<usize>,
69    ) -> MyCow<'a, Vec<El<Complex64>, A>>
70    where
71        R: ?Sized + RingBase,
72        V: VectorView<R::Element>,
73        ToInt: FnMut(&R::Element) -> i64,
74    {
75        let log2_data_size = if let Some(log2_data_size) = log2_el_size {
76            if let Some(data_prep) = data_prep {
77                assert_eq!(log2_data_size, data_prep.log2_data_size);
78            }
79            log2_data_size
80        } else {
81            data.as_iter()
82                .map(|x| StaticRing::<i64>::RING.abs_log2_ceil(&to_int(x)).unwrap_or(0))
83                .max()
84                .unwrap()
85        };
86        assert!(data.len() <= (1 << log2_len));
87        assert!(self.has_sufficient_precision(log2_len, log2_data_size));
88
89        let mut compute_result = || {
90            let mut result = Vec::with_capacity_in(1 << log2_len, self.allocator.clone());
91            result.extend(data.as_iter().map(|x| Complex64::RING.from_f64(to_int(x) as f64)));
92            result.resize(1 << log2_len, Complex64::RING.zero());
93            self.get_fft_table(log2_len).unordered_fft(&mut result, Complex64::RING);
94            return result;
95        };
96
97        return if let Some(data_prep) = data_prep {
98            MyCow::Borrowed(data_prep.fft_data.get_or_init(log2_len, compute_result))
99        } else {
100            MyCow::Owned(compute_result())
101        };
102    }
103
104    #[stability::unstable(feature = "enable")]
105    pub fn has_sufficient_precision(&self, log2_len: usize, log2_input_size: usize) -> bool {
106        self.max_sum_len(log2_len, log2_input_size) > 0
107    }
108
109    fn max_sum_len(&self, log2_len: usize, log2_input_size: usize) -> usize {
110        let fft_table = self.get_fft_table(log2_len);
111        let input_size = 2.0f64.powi(log2_input_size.try_into().unwrap());
112        (0.5 / fft_table.expected_absolute_error(
113            input_size * input_size,
114            input_size * input_size * f64::EPSILON + fft_table.expected_absolute_error(input_size, 0.0),
115        ))
116        .floor() as usize
117    }
118
119    fn prepare_convolution_impl<R, V, ToInt>(
120        &self,
121        data: V,
122        ring: &R,
123        length_hint: Option<usize>,
124        mut to_int: ToInt,
125        ring_log2_el_size: Option<usize>,
126    ) -> PreparedConvolutionOperand<R, A>
127    where
128        R: ?Sized + RingBase,
129        V: VectorView<R::Element>,
130        ToInt: FnMut(&R::Element) -> i64,
131    {
132        let log2_data_size = if let Some(log2_data_size) = ring_log2_el_size {
133            log2_data_size
134        } else {
135            data.as_iter()
136                .map(|x| StaticRing::<i64>::RING.abs_log2_ceil(&to_int(x)).unwrap_or(0))
137                .max()
138                .unwrap()
139        };
140        let result = PreparedConvolutionOperand {
141            fft_data: LazyVec::new(),
142            ring: PhantomData,
143            log2_data_size,
144        };
145        // if a length-hint is given, initialize the corresponding length entry;
146        // this might avoid confusing performance characteristics when the user does
147        // not expect lazy behavior
148        if let Some(len) = length_hint {
149            let log2_len = StaticRing::<i64>::RING.abs_log2_ceil(&len.try_into().unwrap()).unwrap();
150            _ = self.get_fft_data(data, Some(&result), ring, log2_len, to_int, ring_log2_el_size);
151        }
152        return result;
153    }
154
155    fn compute_convolution_impl<R, V1, V2, ToInt, FromInt>(
156        &self,
157        lhs: V1,
158        lhs_prep: Option<&PreparedConvolutionOperand<R, A>>,
159        rhs: V2,
160        rhs_prep: Option<&PreparedConvolutionOperand<R, A>>,
161        dst: &mut [R::Element],
162        ring: &R,
163        mut to_int: ToInt,
164        mut from_int: FromInt,
165        ring_log2_el_size: Option<usize>,
166    ) where
167        R: ?Sized + RingBase,
168        V1: VectorView<R::Element>,
169        V2: VectorView<R::Element>,
170        ToInt: FnMut(&R::Element) -> i64,
171        FromInt: FnMut(i64) -> R::Element,
172    {
173        if lhs.len() == 0 || rhs.len() == 0 {
174            return;
175        }
176        let len = lhs.len() + rhs.len() - 1;
177        assert!(dst.len() >= len);
178        let log2_len = StaticRing::<i64>::RING.abs_log2_ceil(&len.try_into().unwrap()).unwrap();
179
180        let mut lhs_fft = self.get_fft_data(lhs, lhs_prep, ring, log2_len, &mut to_int, ring_log2_el_size);
181        let mut rhs_fft = self.get_fft_data(rhs, rhs_prep, ring, log2_len, &mut to_int, ring_log2_el_size);
182        if rhs_fft.is_owned() {
183            std::mem::swap(&mut lhs_fft, &mut rhs_fft);
184        }
185        let lhs_fft: &mut Vec<El<Complex64>, A> = lhs_fft.to_mut();
186
187        for i in 0..(1 << log2_len) {
188            CC.mul_assign(&mut lhs_fft[i], rhs_fft[i]);
189        }
190
191        self.get_fft_table(log2_len).unordered_inv_fft(&mut *lhs_fft, CC);
192
193        for i in 0..len {
194            let result = CC.closest_gaussian_int(lhs_fft[i]);
195            debug_assert_eq!(0, result.1);
196            ring.add_assign(&mut dst[i], from_int(result.0));
197        }
198    }
199
200    fn compute_convolution_sum_impl<'a, R, I, V1, V2, ToInt, FromInt>(
201        &self,
202        data: I,
203        dst: &mut [R::Element],
204        ring: &R,
205        mut to_int: ToInt,
206        mut from_int: FromInt,
207        ring_log2_el_size: Option<usize>,
208    ) where
209        R: ?Sized + RingBase,
210        I: Iterator<
211            Item = (
212                V1,
213                Option<&'a PreparedConvolutionOperand<R, A>>,
214                V2,
215                Option<&'a PreparedConvolutionOperand<R, A>>,
216            ),
217        >,
218        V1: VectorView<R::Element>,
219        V2: VectorView<R::Element>,
220        ToInt: FnMut(&R::Element) -> i64,
221        FromInt: FnMut(i64) -> R::Element,
222        Self: 'a,
223        R: 'a,
224    {
225        let len = dst.len();
226        let log2_len = StaticRing::<i64>::RING.abs_log2_ceil(&len.try_into().unwrap()).unwrap();
227        let mut buffer = Vec::with_capacity_in(1 << log2_len, self.allocator.clone());
228        buffer.resize(1 << log2_len, CC.zero());
229
230        let mut count_since_last_reduction = 0;
231        let mut current_max_sum_len = usize::MAX;
232        let mut current_log2_data_size = ring_log2_el_size.unwrap_or_default();
233        for (lhs, lhs_prep, rhs, rhs_prep) in data {
234            if lhs.len() == 0 || rhs.len() == 0 {
235                continue;
236            }
237            assert!(lhs.len() + rhs.len() - 1 <= dst.len());
238
239            if ring_log2_el_size.is_none() {
240                current_log2_data_size = max(
241                    current_log2_data_size,
242                    lhs.as_iter()
243                        .chain(rhs.as_iter())
244                        .map(|x| StaticRing::<i64>::RING.abs_log2_ceil(&to_int(x)).unwrap_or(0))
245                        .max()
246                        .unwrap(),
247                );
248                current_max_sum_len = self.max_sum_len(log2_len, current_log2_data_size);
249            }
250            assert!(current_max_sum_len > 0);
251
252            if count_since_last_reduction > current_max_sum_len {
253                count_since_last_reduction = 0;
254                self.get_fft_table(log2_len).unordered_inv_fft(&mut *buffer, CC);
255                for i in 0..len {
256                    let result = CC.closest_gaussian_int(buffer[i]);
257                    debug_assert_eq!(0, result.1);
258                    ring.add_assign(&mut dst[i], from_int(result.0));
259                }
260                for i in 0..(1 << log2_len) {
261                    buffer[i] = CC.zero();
262                }
263            }
264
265            let mut lhs_fft = self.get_fft_data(lhs, lhs_prep, ring, log2_len, &mut to_int, ring_log2_el_size);
266            let mut rhs_fft = self.get_fft_data(rhs, rhs_prep, ring, log2_len, &mut to_int, ring_log2_el_size);
267            if rhs_fft.is_owned() {
268                std::mem::swap(&mut lhs_fft, &mut rhs_fft);
269            }
270            let lhs_fft: &mut Vec<El<Complex64>, A> = lhs_fft.to_mut();
271            for i in 0..(1 << log2_len) {
272                CC.mul_assign(&mut lhs_fft[i], rhs_fft[i]);
273                CC.add_assign(&mut buffer[i], lhs_fft[i]);
274            }
275            count_since_last_reduction += 1;
276        }
277        self.get_fft_table(log2_len).unordered_inv_fft(&mut *buffer, CC);
278        for i in 0..len {
279            let result = CC.closest_gaussian_int(buffer[i]);
280            debug_assert_eq!(0, result.1);
281            ring.add_assign(&mut dst[i], from_int(result.0));
282        }
283    }
284}
285
286fn to_int_int<I>(ring: I) -> impl use<I> + Fn(&El<I>) -> i64
287where
288    I: RingStore,
289    I::Type: IntegerRing,
290{
291    move |x| int_cast(ring.clone_el(x), StaticRing::<i64>::RING, &ring)
292}
293
294fn from_int_int<I>(ring: I) -> impl use<I> + Fn(i64) -> El<I>
295where
296    I: RingStore,
297    I::Type: IntegerRing,
298{
299    move |x| int_cast(x, &ring, StaticRing::<i64>::RING)
300}
301
302fn to_int_zn<R>(ring: R) -> impl use<R> + Fn(&El<R>) -> i64
303where
304    R: RingStore,
305    R::Type: ZnRing,
306{
307    move |x| {
308        int_cast(
309            ring.smallest_lift(ring.clone_el(x)),
310            StaticRing::<i64>::RING,
311            ring.integer_ring(),
312        )
313    }
314}
315
316fn from_int_zn<R>(ring: R) -> impl use<R> + Fn(i64) -> El<R>
317where
318    R: RingStore,
319    R::Type: ZnRing,
320{
321    let hom = ring.can_hom(ring.integer_ring()).unwrap().into_raw_hom();
322    move |x| {
323        ring.get_ring().map_in(
324            ring.integer_ring().get_ring(),
325            int_cast(x, ring.integer_ring(), StaticRing::<i64>::RING),
326            &hom,
327        )
328    }
329}
330
331impl<A> Clone for FFTConvolution<A>
332where
333    A: Allocator + Clone,
334{
335    fn clone(&self) -> Self {
336        Self {
337            allocator: self.allocator.clone(),
338            fft_tables: self.fft_tables.clone(),
339        }
340    }
341}
342
343impl<A> From<FFTConvolutionZn<A>> for FFTConvolution<A>
344where
345    A: Allocator,
346{
347    fn from(value: FFTConvolutionZn<A>) -> Self { value.base }
348}
349
350impl<'a, A> From<&'a FFTConvolutionZn<A>> for &'a FFTConvolution<A>
351where
352    A: Allocator,
353{
354    fn from(value: &'a FFTConvolutionZn<A>) -> Self { &value.base }
355}
356
357impl<A> From<FFTConvolution<A>> for FFTConvolutionZn<A>
358where
359    A: Allocator,
360{
361    fn from(value: FFTConvolution<A>) -> Self { FFTConvolutionZn { base: value } }
362}
363
364impl<'a, A> From<&'a FFTConvolution<A>> for &'a FFTConvolutionZn<A>
365where
366    A: Allocator,
367{
368    fn from(value: &'a FFTConvolution<A>) -> Self { unsafe { std::mem::transmute(value) } }
369}
370
371#[stability::unstable(feature = "enable")]
372#[repr(transparent)]
373pub struct FFTConvolutionZn<A = Global> {
374    base: FFTConvolution<A>,
375}
376
377impl<A> Clone for FFTConvolutionZn<A>
378where
379    A: Allocator + Clone,
380{
381    fn clone(&self) -> Self {
382        Self {
383            base: self.base.clone(),
384        }
385    }
386}
387
388impl<R, A> ConvolutionAlgorithm<R> for FFTConvolutionZn<A>
389where
390    R: ?Sized + ZnRing + CanHomFrom<StaticRingBase<i64>>,
391    A: Allocator + Clone,
392{
393    type PreparedConvolutionOperand = PreparedConvolutionOperand<R, A>;
394
395    fn compute_convolution<S: RingStore<Type = R>, V1: VectorView<R::Element>, V2: VectorView<R::Element>>(
396        &self,
397        lhs: V1,
398        rhs: V2,
399        dst: &mut [R::Element],
400        ring: S,
401    ) {
402        self.base.compute_convolution_impl(
403            lhs,
404            None,
405            rhs,
406            None,
407            dst,
408            ring.get_ring(),
409            to_int_zn(&ring),
410            from_int_zn(&ring),
411            Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap()),
412        )
413    }
414
415    fn supports_ring<S: RingStore<Type = R> + Copy>(&self, _ring: S) -> bool { true }
416
417    fn prepare_convolution_operand<S, V>(
418        &self,
419        val: V,
420        len_hint: Option<usize>,
421        ring: S,
422    ) -> Self::PreparedConvolutionOperand
423    where
424        S: RingStore<Type = R> + Copy,
425        V: VectorView<R::Element>,
426    {
427        self.base.prepare_convolution_impl(
428            val,
429            ring.get_ring(),
430            len_hint,
431            to_int_zn(&ring),
432            Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap()),
433        )
434    }
435
436    fn compute_convolution_prepared<S, V1, V2>(
437        &self,
438        lhs: V1,
439        lhs_prep: Option<&Self::PreparedConvolutionOperand>,
440        rhs: V2,
441        rhs_prep: Option<&Self::PreparedConvolutionOperand>,
442        dst: &mut [R::Element],
443        ring: S,
444    ) where
445        S: RingStore<Type = R> + Copy,
446        V1: VectorView<R::Element>,
447        V2: VectorView<R::Element>,
448    {
449        self.base.compute_convolution_impl(
450            lhs,
451            lhs_prep,
452            rhs,
453            rhs_prep,
454            dst,
455            ring.get_ring(),
456            to_int_zn(&ring),
457            from_int_zn(&ring),
458            Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap()),
459        )
460    }
461
462    fn compute_convolution_sum<'a, S, I, V1, V2>(&self, values: I, dst: &mut [R::Element], ring: S)
463    where
464        S: RingStore<Type = R> + Copy,
465        I: Iterator<
466            Item = (
467                V1,
468                Option<&'a Self::PreparedConvolutionOperand>,
469                V2,
470                Option<&'a Self::PreparedConvolutionOperand>,
471            ),
472        >,
473        V1: VectorView<R::Element>,
474        V2: VectorView<R::Element>,
475        Self: 'a,
476        R: 'a,
477        Self::PreparedConvolutionOperand: 'a,
478    {
479        self.base.compute_convolution_sum_impl(
480            values,
481            dst,
482            ring.get_ring(),
483            to_int_zn(&ring),
484            from_int_zn(&ring),
485            Some(ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap()),
486        )
487    }
488}
489
490impl<I, A> ConvolutionAlgorithm<I> for FFTConvolution<A>
491where
492    I: ?Sized + IntegerRing,
493    A: Allocator + Clone,
494{
495    type PreparedConvolutionOperand = PreparedConvolutionOperand<I, A>;
496
497    fn compute_convolution<S: RingStore<Type = I>, V1: VectorView<I::Element>, V2: VectorView<I::Element>>(
498        &self,
499        lhs: V1,
500        rhs: V2,
501        dst: &mut [I::Element],
502        ring: S,
503    ) {
504        self.compute_convolution_impl(
505            lhs,
506            None,
507            rhs,
508            None,
509            dst,
510            ring.get_ring(),
511            to_int_int(&ring),
512            from_int_int(&ring),
513            None,
514        )
515    }
516
517    fn supports_ring<S: RingStore<Type = I> + Copy>(&self, _ring: S) -> bool { true }
518
519    fn prepare_convolution_operand<S, V>(
520        &self,
521        val: V,
522        len_hint: Option<usize>,
523        ring: S,
524    ) -> Self::PreparedConvolutionOperand
525    where
526        S: RingStore<Type = I> + Copy,
527        V: VectorView<I::Element>,
528    {
529        self.prepare_convolution_impl(val, ring.get_ring(), len_hint, to_int_int(&ring), None)
530    }
531
532    fn compute_convolution_prepared<S, V1, V2>(
533        &self,
534        lhs: V1,
535        lhs_prep: Option<&Self::PreparedConvolutionOperand>,
536        rhs: V2,
537        rhs_prep: Option<&Self::PreparedConvolutionOperand>,
538        dst: &mut [I::Element],
539        ring: S,
540    ) where
541        S: RingStore<Type = I> + Copy,
542        V1: VectorView<I::Element>,
543        V2: VectorView<I::Element>,
544    {
545        self.compute_convolution_impl(
546            lhs,
547            lhs_prep,
548            rhs,
549            rhs_prep,
550            dst,
551            ring.get_ring(),
552            to_int_int(&ring),
553            from_int_int(&ring),
554            None,
555        )
556    }
557
558    fn compute_convolution_sum<'a, S, J, V1, V2>(&self, values: J, dst: &mut [I::Element], ring: S)
559    where
560        S: RingStore<Type = I> + Copy,
561        J: Iterator<
562            Item = (
563                V1,
564                Option<&'a Self::PreparedConvolutionOperand>,
565                V2,
566                Option<&'a Self::PreparedConvolutionOperand>,
567            ),
568        >,
569        V1: VectorView<I::Element>,
570        V2: VectorView<I::Element>,
571        Self: 'a,
572        I: 'a,
573        Self::PreparedConvolutionOperand: 'a,
574    {
575        self.compute_convolution_sum_impl(
576            values,
577            dst,
578            ring.get_ring(),
579            to_int_int(&ring),
580            from_int_int(&ring),
581            None,
582        )
583    }
584}
585
586#[cfg(test)]
587use crate::rings::finite::FiniteRingStore;
588#[cfg(test)]
589use crate::rings::zn::zn_64::Zn;
590
591#[test]
592fn test_convolution_zn() {
593    let convolution: FFTConvolutionZn = FFTConvolution::new().into();
594    let ring = Zn::new(17 * 257);
595
596    super::generic_tests::test_convolution(&convolution, &ring, ring.one());
597}
598
599#[test]
600fn test_convolution_int() {
601    let convolution: FFTConvolution = FFTConvolution::new();
602    let ring = StaticRing::<i64>::RING;
603
604    super::generic_tests::test_convolution(&convolution, &ring, ring.one());
605}
606
607#[test]
608#[should_panic(expected = "precision")]
609fn test_fft_convolution_not_enough_precision() {
610    let convolution_algorithm: FFTConvolutionZn = FFTConvolution::new().into();
611
612    let ring = Zn::new(1099511627791);
613    let lhs = ring.elements().take(1024).collect::<Vec<_>>();
614    let rhs = ring.elements().take(1024).collect::<Vec<_>>();
615    let mut actual = (0..(lhs.len() + rhs.len())).map(|_| ring.zero()).collect::<Vec<_>>();
616
617    convolution_algorithm.compute_convolution(&lhs, &rhs, &mut actual, &ring);
618}