feanor_math/algorithms/convolution/fft.rs

use std::alloc::{Allocator, Global};
use std::marker::PhantomData;

use crate::algorithms::fft::complex_fft::FFTErrorEstimate;
use crate::algorithms::fft::cooley_tuckey::CooleyTuckeyFFT;
use crate::algorithms::fft::FFTAlgorithm;
use crate::lazy::LazyVec;
use crate::primitive_int::StaticRingBase;
use crate::integer::*;
use crate::ring::*;
use crate::seq::*;
use crate::primitive_int::*;
use crate::homomorphism::*;
use crate::rings::float_complex::*;
use crate::rings::zn::*;

use super::{ConvolutionAlgorithm, PreparedConvolutionAlgorithm, PreparedConvolutionOperation};

const ZZ: StaticRing<i64> = StaticRing::RING;
const CC: Complex64 = Complex64::RING;

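///
/// A [`ConvolutionAlgorithm`] that computes convolutions over an integer ring by mapping
/// the inputs to complex floating-point numbers, multiplying their FFTs and rounding the
/// result back to integers. FFT tables for each power-of-two length are created lazily
/// and cached for reuse; for rings `Z/nZ`, use the wrapper [`FFTConvolutionZn`].
///
/// Since all arithmetic is done in `f64`, results are only exact when the inputs are small
/// enough; the implementation asserts this via
/// [`FFTConvolution::has_sufficient_precision()`] and panics otherwise.
///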
#[stability::unstable(feature = "enable")]
pub struct FFTConvolution<A = Global> {
    allocator: A,
    fft_tables: LazyVec<CooleyTuckeyFFT<Complex64Base, Complex64Base, Identity<Complex64>>>
}

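///
/// A convolution operand that has already been lifted to floating-point values and
/// transformed by the FFT, so that it can be reused across multiple convolutions.
///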
#[stability::unstable(feature = "enable")]
pub struct PreparedConvolutionOperand<R, A = Global>
    where R: ?Sized + RingBase,
        A: Allocator + Clone
{
    ring: PhantomData<Box<R>>,
    original_data: Vec<f64, A>,
    fft_data: Vec<El<Complex64>, A>
}

impl<A> FFTConvolution<A>
    where A: Allocator + Clone
{
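    ///
    /// Creates a new [`FFTConvolution`] that uses the given allocator for all temporary
    /// and prepared-operand buffers.
    ///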
    #[stability::unstable(feature = "enable")]
    pub fn new_with(allocator: A) -> Self {
        Self {
            allocator: allocator,
            fft_tables: LazyVec::new()
        }
    }

    fn get_fft_table(&self, log2_len: usize) -> &CooleyTuckeyFFT<Complex64Base, Complex64Base, Identity<Complex64>> {
        return self.fft_tables.get_or_init(log2_len, || CooleyTuckeyFFT::for_complex(CC, log2_len));
    }

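    ///
    /// Checks whether a convolution using length-`2^log2_len` FFTs on inputs of absolute
    /// value at most `2^log2_input_size` is estimated to be computed exactly, i.e. whether
    /// the expected floating-point error stays below `0.5`, so that rounding to the nearest
    /// integer recovers the true result.
    ///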
    #[stability::unstable(feature = "enable")]
    pub fn has_sufficient_precision(&self, log2_len: usize, log2_input_size: usize) -> bool {
        let fft_table = self.get_fft_table(log2_len);
        let input_size = 2f64.powi(log2_input_size as i32);
        fft_table.expected_absolute_error(input_size * input_size, input_size * input_size * f64::EPSILON + fft_table.expected_absolute_error(input_size, 0.)) < 0.5
    }

    /// Takes two operands that are already in (unordered) FFT representation, multiplies
    /// them pointwise, applies the inverse FFT and rounds the first `target_len`
    /// coefficients to the closest integers.
    fn compute_convolution_impl(&self, mut lhs: Vec<El<Complex64>, A>, rhs: &[El<Complex64>], target_len: usize) -> impl Iterator<Item = i64> {
        let log2_n = ZZ.abs_log2_ceil(&(lhs.len() as i64)).unwrap();
        assert_eq!(lhs.len(), 1 << log2_n);
        assert_eq!(rhs.len(), 1 << log2_n);

        for i in 0..(1 << log2_n) {
            CC.mul_assign(&mut lhs[i], rhs[i]);
        }
        self.get_fft_table(log2_n).unordered_inv_fft(&mut lhs[..], CC);
        (0..target_len).map(move |i| {
            let x = CC.closest_gaussian_int(lhs[i]);
            debug_assert!(x.1 == 0);
            return x.0;
        })
    }

    /// Zero-pads the given data to length `2^log2_n`, maps it to complex floating-point
    /// numbers and computes its (unordered) FFT. Returns the bound `log2_data_size` on the
    /// input size that was used for the precision check, together with the FFT data.
    fn prepare_convolution_impl<V>(&self, data: V, log2_n: usize, log2_data_size: Option<usize>) -> (usize, Vec<El<Complex64>, A>)
        where V: VectorFn<f64>
    {
        assert!(data.len() <= 1 << log2_n);
        let log2_data_size = if let Some(log2_data_size) = log2_data_size {
            log2_data_size
        } else {
            data.iter().map(|x| x.abs()).max_by(f64::total_cmp).unwrap().log2().ceil() as usize
        };
        assert!(self.has_sufficient_precision(log2_n, log2_data_size));

        let mut fft_data = Vec::with_capacity_in(1 << log2_n, self.allocator.clone());
        fft_data.extend(data.iter().map(|x| CC.from_f64(x)));
        fft_data.resize(1 << log2_n, CC.zero());
        let fft = self.get_fft_table(log2_n);
        fft.unordered_fft(&mut fft_data[..], CC);
        return (log2_data_size, fft_data);
    }
}

impl<A> Clone for FFTConvolution<A>
    where A: Allocator + Clone
{
    fn clone(&self) -> Self {
        Self {
            allocator: self.allocator.clone(),
            fft_tables: self.fft_tables.clone()
        }
    }
}

impl<A> From<FFTConvolutionZn<A>> for FFTConvolution<A>
    where A: Allocator
{
    fn from(value: FFTConvolutionZn<A>) -> Self {
        value.base
    }
}

impl<'a, A> From<&'a FFTConvolutionZn<A>> for &'a FFTConvolution<A>
    where A: Allocator
{
    fn from(value: &'a FFTConvolutionZn<A>) -> Self {
        &value.base
    }
}

impl<A> From<FFTConvolution<A>> for FFTConvolutionZn<A>
    where A: Allocator
{
    fn from(value: FFTConvolution<A>) -> Self {
        FFTConvolutionZn { base: value }
    }
}

impl<'a, A> From<&'a FFTConvolution<A>> for &'a FFTConvolutionZn<A>
    where A: Allocator
{
    fn from(value: &'a FFTConvolution<A>) -> Self {
        // SAFETY: `FFTConvolutionZn<A>` is a `#[repr(transparent)]` wrapper around
        // `FFTConvolution<A>`, so the two reference types have identical layout.
        unsafe { std::mem::transmute(value) }
    }
}

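///
/// A [`ConvolutionAlgorithm`] for [`ZnRing`]s that lifts the elements to their smallest
/// integer representatives, convolves them via [`FFTConvolution`] and maps the result
/// back into the ring.
///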
#[stability::unstable(feature = "enable")]
#[repr(transparent)]
pub struct FFTConvolutionZn<A = Global> {
    base: FFTConvolution<A>
}

impl<A> Clone for FFTConvolutionZn<A>
    where A: Allocator + Clone
{
    fn clone(&self) -> Self {
        Self { base: self.base.clone() }
    }
}

impl<R, A> ConvolutionAlgorithm<R> for FFTConvolutionZn<A>
    where R: ?Sized + ZnRing + CanHomFrom<StaticRingBase<i64>>,
        A: Allocator + Clone
{
    fn compute_convolution<S: RingStore<Type = R>, V1: VectorView<R::Element>, V2: VectorView<R::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [R::Element], ring: S) {
        if lhs.len() == 0 || rhs.len() == 0 {
            return;
        }
        let log2_data_size = ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap();
        let log2_n = ZZ.abs_log2_ceil(&((lhs.len() + rhs.len()) as i64)).unwrap();
        let lhs_prep = self.base.prepare_convolution_impl(lhs.clone_ring_els(&ring).map_fn(|x| int_cast(ring.smallest_lift(x), ZZ, ring.integer_ring()) as f64), log2_n, Some(log2_data_size)).1;
        let rhs_prep = self.base.prepare_convolution_impl(rhs.clone_ring_els(&ring).map_fn(|x| int_cast(ring.smallest_lift(x), ZZ, ring.integer_ring()) as f64), log2_n, Some(log2_data_size)).1;
        let hom = ring.can_hom(&ZZ).unwrap();
        for (i, x) in self.base.compute_convolution_impl(lhs_prep, &rhs_prep, lhs.len() + rhs.len() - 1).enumerate() {
            ring.add_assign(&mut dst[i], hom.map(x));
        }
    }

    fn supports_ring<S: RingStore<Type = R> + Copy>(&self, _ring: S) -> bool {
        true
    }

    fn specialize_prepared_convolution<F>(function: F) -> Result<F::Output, F>
        where F: PreparedConvolutionOperation<Self, R>
    {
        Ok(function.execute())
    }
}

impl<I, A> ConvolutionAlgorithm<I> for FFTConvolution<A>
    where I: ?Sized + IntegerRing,
        A: Allocator + Clone
{
    fn compute_convolution<S: RingStore<Type = I>, V1: VectorView<I::Element>, V2: VectorView<I::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [I::Element], ring: S) {
        if lhs.len() == 0 || rhs.len() == 0 {
            return;
        }
        let log2_n = ZZ.abs_log2_ceil(&((lhs.len() + rhs.len()) as i64)).unwrap();
        let lhs_prep = self.prepare_convolution_impl(lhs.clone_ring_els(&ring).map_fn(|x| int_cast(x, ZZ, &ring) as f64), log2_n, None).1;
        let rhs_prep = self.prepare_convolution_impl(rhs.clone_ring_els(&ring).map_fn(|x| int_cast(x, ZZ, &ring) as f64), log2_n, None).1;
        for (i, x) in self.compute_convolution_impl(lhs_prep, &rhs_prep, lhs.len() + rhs.len() - 1).enumerate() {
            ring.add_assign(&mut dst[i], int_cast(x, &ring, ZZ));
        }
    }

    fn supports_ring<S: RingStore<Type = I> + Copy>(&self, _ring: S) -> bool {
        true
    }

    fn specialize_prepared_convolution<F>(function: F) -> Result<F::Output, F>
        where F: PreparedConvolutionOperation<Self, I>
    {
        Ok(function.execute())
    }
}

impl<I, A> PreparedConvolutionAlgorithm<I> for FFTConvolution<A>
    where I: ?Sized + IntegerRing,
        A: Allocator + Clone
{
    type PreparedConvolutionOperand = PreparedConvolutionOperand<I, A>;

    fn prepare_convolution_operand<S, V>(&self, val: V, ring: S) -> Self::PreparedConvolutionOperand
        where S: RingStore<Type = I> + Copy, V: VectorView<I::Element>
    {
        let log2_n_in = ZZ.abs_log2_ceil(&(val.len() as i64)).unwrap();
        // prepare the FFT at twice the rounded-up input length, so that a later convolution
        // with an operand of at most the same length still fits without recomputing the FFT
        let log2_n_out = log2_n_in + 1;
        let mut original_data = Vec::new_in(self.allocator.clone());
        original_data.extend(val.clone_ring_els(&ring).iter().map(|x| int_cast(x, ZZ, &ring) as f64));
        let (_log2_data_size, fft_data) = self.prepare_convolution_impl(original_data.copy_els(), log2_n_out, None);
        return PreparedConvolutionOperand {
            fft_data: fft_data,
            original_data: original_data,
            ring: PhantomData
        };
    }

    fn compute_convolution_lhs_prepared<S, V>(&self, lhs: &Self::PreparedConvolutionOperand, rhs: V, dst: &mut [I::Element], ring: S)
        where S: RingStore<Type = I> + Copy, V: VectorView<I::Element>
    {
        assert!(ring.is_commutative());
        let log2_lhs = ZZ.abs_log2_ceil(&(lhs.fft_data.len() as i64)).unwrap();
        assert_eq!(lhs.fft_data.len(), 1 << log2_lhs);
        let target_len = lhs.original_data.len() + rhs.len() - 1;
        let log2_target_len = ZZ.abs_log2_ceil(&(target_len as i64)).unwrap().max(log2_lhs);
        let els = if log2_target_len > log2_lhs {
            assert!(target_len <= 1 << log2_target_len);
            let lhs_prep = self.prepare_convolution_impl(lhs.original_data.copy_els(), log2_target_len, None).1;
            let rhs_prep = self.prepare_convolution_impl(rhs.clone_ring_els(&ring).map_fn(|x| int_cast(x, ZZ, &ring) as f64), log2_target_len, None).1;
            self.compute_convolution_impl(lhs_prep, &rhs_prep, target_len)
        } else {
            assert!(log2_lhs == log2_target_len || log2_lhs == log2_target_len + 1);
            assert!(target_len <= 1 << log2_lhs);
            self.compute_convolution_impl(
                self.prepare_convolution_impl(rhs.clone_ring_els(ring).map_fn(|x| int_cast(x, ZZ, &ring) as f64), log2_lhs, None).1,
                &lhs.fft_data,
                target_len
            )
        };
        for (i, x) in els.enumerate() {
            ring.add_assign(&mut dst[i], int_cast(x, ring, ZZ));
        }
    }

    fn compute_convolution_prepared<S>(&self, lhs: &Self::PreparedConvolutionOperand, rhs: &Self::PreparedConvolutionOperand, dst: &mut [I::Element], ring: S)
        where S: RingStore<Type = I> + Copy
    {
        assert!(ring.is_commutative());
        let log2_lhs = ZZ.abs_log2_ceil(&(lhs.fft_data.len() as i64)).unwrap();
        assert_eq!(1 << log2_lhs, lhs.fft_data.len());
        let log2_rhs = ZZ.abs_log2_ceil(&(rhs.fft_data.len() as i64)).unwrap();
        assert_eq!(1 << log2_rhs, rhs.fft_data.len());
        let target_len = lhs.original_data.len() + rhs.original_data.len() - 1;
        assert!(target_len <= 1 << log2_lhs || target_len <= 1 << log2_rhs);
        let els = match log2_lhs.cmp(&log2_rhs) {
            std::cmp::Ordering::Equal => self.compute_convolution_impl(lhs.fft_data.clone(), &rhs.fft_data, target_len),
            std::cmp::Ordering::Greater => self.compute_convolution_impl(self.prepare_convolution_impl(rhs.original_data.copy_els(), log2_lhs, None).1, &lhs.fft_data, target_len),
            std::cmp::Ordering::Less => self.compute_convolution_impl(self.prepare_convolution_impl(lhs.original_data.copy_els(), log2_rhs, None).1, &rhs.fft_data, target_len)
        };
        for (i, x) in els.enumerate() {
            ring.add_assign(&mut dst[i], int_cast(x, ring, ZZ));
        }
    }
}

impl<R, A> PreparedConvolutionAlgorithm<R> for FFTConvolutionZn<A>
    where R: ?Sized + ZnRing + CanHomFrom<StaticRingBase<i64>>,
        A: Allocator + Clone
{
    type PreparedConvolutionOperand = PreparedConvolutionOperand<R, A>;

    fn prepare_convolution_operand<S, V>(&self, val: V, ring: S) -> Self::PreparedConvolutionOperand
        where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>
    {
        let log2_data_size = ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap();
        let log2_n_in = ZZ.abs_log2_ceil(&(val.len() as i64)).unwrap();
        let log2_n_out = log2_n_in + 1;
        let mut original_data = Vec::new_in(self.base.allocator.clone());
        original_data.extend(val.clone_ring_els(&ring).iter().map(|x| int_cast(ring.smallest_lift(x), ZZ, ring.integer_ring()) as f64));
        let (_log2_data_size, fft_data) = self.base.prepare_convolution_impl(original_data.copy_els(), log2_n_out, Some(log2_data_size));
        return PreparedConvolutionOperand {
            fft_data: fft_data,
            original_data: original_data,
            ring: PhantomData
        };
    }

    fn compute_convolution_lhs_prepared<S, V>(&self, lhs: &Self::PreparedConvolutionOperand, rhs: V, dst: &mut [R::Element], ring: S)
        where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>
    {
        assert!(ring.is_commutative());
        let log2_data_size = ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap();
        let log2_lhs = ZZ.abs_log2_ceil(&(lhs.fft_data.len() as i64)).unwrap();
        assert_eq!(lhs.fft_data.len(), 1 << log2_lhs);
        let target_len = lhs.original_data.len() + rhs.len() - 1;
        let log2_target_len = ZZ.abs_log2_ceil(&(target_len as i64)).unwrap().max(log2_lhs);
        let els = if log2_target_len > log2_lhs {
            assert!(target_len <= 1 << log2_target_len);
            let lhs_prep = self.base.prepare_convolution_impl(lhs.original_data.copy_els(), log2_target_len, Some(log2_data_size)).1;
            let rhs_prep = self.base.prepare_convolution_impl(rhs.clone_ring_els(&ring).map_fn(|x| int_cast(ring.smallest_lift(x), ZZ, ring.integer_ring()) as f64), log2_target_len, Some(log2_data_size)).1;
            self.base.compute_convolution_impl(lhs_prep, &rhs_prep, target_len)
        } else {
            assert!(log2_lhs == log2_target_len || log2_lhs == log2_target_len + 1);
            assert!(target_len <= 1 << log2_lhs);
            self.base.compute_convolution_impl(
                self.base.prepare_convolution_impl(rhs.clone_ring_els(ring).map_fn(|x| int_cast(ring.smallest_lift(x), ZZ, ring.integer_ring()) as f64), log2_lhs, Some(log2_data_size)).1,
                &lhs.fft_data,
                target_len
            )
        };
        let hom = ring.can_hom(&ZZ).unwrap();
        for (i, x) in els.enumerate() {
            ring.add_assign(&mut dst[i], hom.map(x));
        }
    }

    fn compute_convolution_prepared<S>(&self, lhs: &Self::PreparedConvolutionOperand, rhs: &Self::PreparedConvolutionOperand, dst: &mut [R::Element], ring: S)
        where S: RingStore<Type = R> + Copy
    {
        assert!(ring.is_commutative());
        let log2_data_size = ring.integer_ring().abs_log2_ceil(ring.modulus()).unwrap();
        let log2_lhs = ZZ.abs_log2_ceil(&(lhs.fft_data.len() as i64)).unwrap();
        assert_eq!(1 << log2_lhs, lhs.fft_data.len());
        let log2_rhs = ZZ.abs_log2_ceil(&(rhs.fft_data.len() as i64)).unwrap();
        assert_eq!(1 << log2_rhs, rhs.fft_data.len());
        let target_len = lhs.original_data.len() + rhs.original_data.len() - 1;
        assert!(target_len <= 1 << log2_lhs || target_len <= 1 << log2_rhs);
        let els = match log2_lhs.cmp(&log2_rhs) {
            std::cmp::Ordering::Equal => self.base.compute_convolution_impl(lhs.fft_data.clone(), &rhs.fft_data, target_len),
            std::cmp::Ordering::Greater => self.base.compute_convolution_impl(self.base.prepare_convolution_impl(rhs.original_data.copy_els(), log2_lhs, Some(log2_data_size)).1, &lhs.fft_data, target_len),
            std::cmp::Ordering::Less => self.base.compute_convolution_impl(self.base.prepare_convolution_impl(lhs.original_data.copy_els(), log2_rhs, Some(log2_data_size)).1, &rhs.fft_data, target_len)
        };
        let hom = ring.can_hom(&ZZ).unwrap();
        for (i, x) in els.enumerate() {
            ring.add_assign(&mut dst[i], hom.map(x));
        }
    }
}

#[cfg(test)]
use crate::rings::finite::FiniteRingStore;
#[cfg(test)]
use crate::rings::zn::zn_64::Zn;

#[test]
fn test_convolution_zn() {
    let convolution: FFTConvolutionZn = FFTConvolution::new_with(Global).into();
    let ring = Zn::new(17 * 257);

    super::generic_tests::test_convolution(&convolution, &ring, ring.one());
    super::generic_tests::test_prepared_convolution(&convolution, &ring, ring.one());
}

#[test]
fn test_convolution_int() {
    let convolution: FFTConvolution = FFTConvolution::new_with(Global);
    let ring = StaticRing::<i64>::RING;

    super::generic_tests::test_convolution(&convolution, &ring, ring.one());
    super::generic_tests::test_prepared_convolution(&convolution, &ring, ring.one());
}

#[test]
#[should_panic(expected = "precision")]
fn test_fft_convolution_not_enough_precision() {
    let convolution_algorithm: FFTConvolutionZn = FFTConvolution::new_with(Global).into();

    let ring = Zn::new(1099511627791);
    let lhs = ring.elements().take(1024).collect::<Vec<_>>();
    let rhs = ring.elements().take(1024).collect::<Vec<_>>();
    let mut actual = (0..(lhs.len() + rhs.len())).map(|_| ring.zero()).collect::<Vec<_>>();

    convolution_algorithm.compute_convolution(&lhs, &rhs, &mut actual, &ring);
}
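
// An added illustrative sketch (not part of the original test suite): checks a small
// integer convolution against a hand-computed result, using the same `compute_convolution`
// API and `StaticRing::<i64>::RING` setup as the generic tests above. Note that the
// convolution is added onto the (zero-initialized) destination buffer.
#[test]
fn test_convolution_int_small_example() {
    let convolution: FFTConvolution = FFTConvolution::new_with(Global);
    let ring = StaticRing::<i64>::RING;

    // (1 + 2x + 3x^2) * (4 + 5x) = 4 + 13x + 22x^2 + 15x^3
    let lhs = vec![1, 2, 3];
    let rhs = vec![4, 5];
    let mut actual = vec![0; lhs.len() + rhs.len()];
    convolution.compute_convolution(&lhs, &rhs, &mut actual, &ring);
    assert_eq!(vec![4, 13, 22, 15, 0], actual);
}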