feanor_math/algorithms/convolution/
mod.rs

1use std::alloc::{Allocator, Global};
2use std::ops::Deref;
3
4use crate::ring::*;
5use crate::seq::subvector::SubvectorView;
6use crate::seq::*;
7
8use karatsuba::*;
9
///
/// Contains an optimized implementation of Karatsuba's algorithm for computing convolutions.
/// 
pub mod karatsuba;

///
/// Contains an implementation of computing convolutions using complex floating-point FFTs.
/// 
pub mod fft;

///
/// Contains an implementation of computing convolutions using NTTs, i.e. FFTs over
/// a finite field that has suitable roots of unity.
/// 
pub mod ntt;

///
/// Contains an implementation of computing convolutions by considering them modulo
/// various primes that are either smaller or allow for suitable roots of unity.
/// 
pub mod rns;
31
///
/// Trait for objects that can compute a convolution over some ring.
/// 
/// # Example
/// ```rust
/// # use std::cmp::{min, max};
/// # use feanor_math::ring::*;
/// # use feanor_math::primitive_int::*;
/// # use feanor_math::seq::*;
/// # use feanor_math::algorithms::convolution::*;
/// struct NaiveConvolution;
/// // we support all rings!
/// impl<R: ?Sized + RingBase> ConvolutionAlgorithm<R> for NaiveConvolution {
///     fn compute_convolution<S: RingStore<Type = R>, V1: VectorView<R::Element>, V2: VectorView<R::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [R::Element], ring: S) {
///         for i in 0..(lhs.len() + rhs.len() - 1) {
///             for j in max(0, i as isize - rhs.len() as isize + 1)..min(lhs.len() as isize, i as isize + 1) {
///                 ring.add_assign(&mut dst[i], ring.mul_ref(lhs.at(j as usize), rhs.at(i - j as usize)));
///             }
///         }
///     }
///     fn supports_ring<S: RingStore<Type = R>>(&self, _: S) -> bool
///         where S: Copy
///     { true }
/// }
/// let lhs = [1, 2, 3, 4, 5];
/// let rhs = [2, 3, 4, 5, 6];
/// let mut expected = [0; 10];
/// let mut actual = [0; 10];
/// STANDARD_CONVOLUTION.compute_convolution(lhs, rhs, &mut expected, StaticRing::<i64>::RING);
/// NaiveConvolution.compute_convolution(lhs, rhs, &mut actual, StaticRing::<i64>::RING);
/// assert_eq!(expected, actual);
/// ```
/// 
pub trait ConvolutionAlgorithm<R: ?Sized + RingBase> {

    ///
    /// Additional data associated to a list of ring elements, which can be used to
    /// compute a convolution where this list is one of the operands faster.
    ///
    /// For more details, see [`ConvolutionAlgorithm::prepare_convolution_operand()`].
    /// Note that a `PreparedConvolutionOperand` can only be used for convolutions
    /// with the same list of values it was created for.
    /// 
    #[stability::unstable(feature = "enable")]
    type PreparedConvolutionOperand = ();

    ///
    /// Elementwise adds the convolution of `lhs` and `rhs` to `dst`.
    /// 
    /// In other words, computes `dst[i] += sum_j lhs[j] * rhs[i - j]` for all `i`, where
    /// `j` runs through all nonnegative integers for which `lhs[j]` and `rhs[i - j]` are defined,
    /// i.e. not out-of-bounds.
    /// 
    /// In particular, only the first `lhs.len() + rhs.len() - 1` entries of `dst` can ever
    /// be modified. However, to allow for more efficient implementations, callers are
    /// nevertheless required to provide `dst.len() >= lhs.len() + rhs.len()`.
    /// 
    /// # Panic
    /// 
    /// Panics if `dst` is shorter than `lhs.len() + rhs.len() - 1`. May panic if `dst` is shorter
    /// than `lhs.len() + rhs.len()`.
    /// 
    fn compute_convolution<S: RingStore<Type = R> + Copy, V1: VectorView<R::Element>, V2: VectorView<R::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [R::Element], ring: S);

    ///
    /// Returns whether this convolution algorithm supports computations of 
    /// the given ring.
    /// 
    /// Note that most algorithms will support all rings of type `R`. However in some cases,
    /// e.g. for finite fields, required data might only be precomputed for some moduli,
    /// and thus only these will be supported.
    /// 
    fn supports_ring<S: RingStore<Type = R> + Copy>(&self, ring: S) -> bool;

    ///
    /// Takes an input list of values and computes an opaque [`ConvolutionAlgorithm::PreparedConvolutionOperand`],
    /// which can be used to compute future convolutions with this list of values faster.
    /// 
    /// Although the [`ConvolutionAlgorithm::PreparedConvolutionOperand`] does not have any explicit reference
    /// to the list of values it was created for, passing it to [`ConvolutionAlgorithm::compute_convolution_prepared()`]
    /// with another list of values will give erroneous results.
    /// 
    /// # Length-dependence when preparing a convolution
    /// 
    /// For some algorithms, different data is required to speed up the convolution with an operand, depending on the 
    /// length of the other operand. For example, for FFT-based convolutions, the prepared data would consist of the
    /// Fourier transform of the list of values, zero-padded to a length that can store the complete result of the
    /// (future) convolution.
    /// 
    /// To handle this, implementations can make use of the `length_hint`, which - if given - should be an upper bound
    /// to the length of the output of any future convolution that uses the given operand. Alternatively, implementations
    /// are encouraged to not compute any data during [`ConvolutionAlgorithm::prepare_convolution_operand()`],
    /// but initialize an object with interior mutability, and use it to cache data computed during
    /// [`ConvolutionAlgorithm::compute_convolution_prepared()`].
    /// 
    /// TODO: At next breaking release, remove the default implementation
    /// 
    /// # Example
    /// 
    /// ```rust
    /// # use feanor_math::ring::*;
    /// # use feanor_math::algorithms::convolution::*;
    /// # use feanor_math::algorithms::convolution::ntt::*;
    /// # use feanor_math::rings::zn::*;
    /// # use feanor_math::rings::zn::zn_64::*;
    /// # use feanor_math::rings::finite::*;
    /// let ring = Zn::new(65537);
    /// let convolution = NTTConvolution::new(ring);
    /// let lhs = ring.elements().take(10).collect::<Vec<_>>();
    /// let rhs = ring.elements().take(10).collect::<Vec<_>>();
    /// // "standard" use
    /// let mut expected = (0..19).map(|_| ring.zero()).collect::<Vec<_>>();
    /// convolution.compute_convolution(&lhs, &rhs, &mut expected, ring);
    /// 
    /// // "prepared" variant
    /// let lhs_prep = convolution.prepare_convolution_operand(&lhs, None, ring);
    /// let rhs_prep = convolution.prepare_convolution_operand(&rhs, None, ring);
    /// let mut actual = (0..19).map(|_| ring.zero()).collect::<Vec<_>>();
    /// // this will now be faster than `convolution.compute_convolution()`
    /// convolution.compute_convolution_prepared(&lhs, Some(&lhs_prep), &rhs, Some(&rhs_prep), &mut actual, ring);
    /// println!("{:?}, {:?}", actual.iter().map(|x| ring.format(x)).collect::<Vec<_>>(), expected.iter().map(|x| ring.format(x)).collect::<Vec<_>>());
    /// assert!(expected.iter().zip(actual.iter()).all(|(l, r)| ring.eq_el(l, r)));
    /// ```
    /// 
    #[stability::unstable(feature = "enable")]
    fn prepare_convolution_operand<S, V>(&self, _val: V, _length_hint: Option<usize>, _ring: S) -> Self::PreparedConvolutionOperand
        where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>
    {
        // Specialization trick (requires nightly `specialization`): the default impl can
        // only produce a value when `PreparedConvolutionOperand` is the default `()`; any
        // implementor that overrides the associated type must also override this method,
        // which the `default fn` below enforces with a panic at runtime.
        struct ProduceUnitType;
        trait ProduceValue<T> {
            fn produce() -> T;
        }
        impl<T> ProduceValue<T> for ProduceUnitType {
            default fn produce() -> T {
                panic!("if you specialize ConvolutionAlgorithm::PreparedConvolutionOperand, you must also specialize ConvolutionAlgorithm::prepare_convolution_operand()")
            }
        }
        impl ProduceValue<()> for ProduceUnitType {
            fn produce() -> () {}
        }
        return <ProduceUnitType as ProduceValue<Self::PreparedConvolutionOperand>>::produce();
    }
    
    ///
    /// Elementwise adds the convolution of `lhs` and `rhs` to `dst`. If provided, the given
    /// prepared convolution operands are used for a faster computation.
    /// 
    /// When called with `None` as both the prepared convolution operands, this is exactly
    /// equivalent to [`ConvolutionAlgorithm::compute_convolution()`].
    /// 
    #[stability::unstable(feature = "enable")]
    fn compute_convolution_prepared<S, V1, V2>(&self, lhs: V1, _lhs_prep: Option<&Self::PreparedConvolutionOperand>, rhs: V2, _rhs_prep: Option<&Self::PreparedConvolutionOperand>, dst: &mut [R::Element], ring: S)
        where S: RingStore<Type = R> + Copy, V1: VectorView<R::Element>, V2: VectorView<R::Element>
    {
        self.compute_convolution(lhs, rhs, dst, ring)
    }

    ///
    /// Computes a convolution for each tuple in the given sequence, and sums the result of each convolution
    /// to `dst`.
    /// 
    /// In other words, writing `values[l] = (lhs_l, _, rhs_l, _)`, this computes
    /// `dst[k] += sum_l sum_(i + j = k) lhs_l[i] * rhs_l[j]`.
    /// It can be faster than computing each convolution separately via
    /// [`ConvolutionAlgorithm::compute_convolution_prepared()`].
    /// 
    #[stability::unstable(feature = "enable")]
    fn compute_convolution_sum<'a, S, I, V1, V2>(&self, values: I, dst: &mut [R::Element], ring: S) 
        where S: RingStore<Type = R> + Copy, 
            I: ExactSizeIterator<Item = (V1, Option<&'a Self::PreparedConvolutionOperand>, V2, Option<&'a Self::PreparedConvolutionOperand>)>,
            V1: VectorView<R::Element>,
            V2: VectorView<R::Element>,
            Self: 'a,
            R: 'a
    {
        // default: no batching, just accumulate the convolutions one at a time
        for (lhs, lhs_prep, rhs, rhs_prep) in values {
            self.compute_convolution_prepared(lhs, lhs_prep, rhs, rhs_prep, dst, ring)
        }
    }
}
210
211impl<'a, R, C> ConvolutionAlgorithm<R> for C
212    where R: ?Sized + RingBase,
213        C: Deref,
214        C::Target: ConvolutionAlgorithm<R>
215{
216    type PreparedConvolutionOperand = <C::Target as ConvolutionAlgorithm<R>>::PreparedConvolutionOperand;
217
218    fn compute_convolution<S: RingStore<Type = R> + Copy, V1: VectorView<R::Element>, V2: VectorView<R::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [R::Element], ring: S) {
219        (**self).compute_convolution(lhs, rhs, dst, ring)
220    }
221
222    fn supports_ring<S: RingStore<Type = R> + Copy>(&self, ring: S) -> bool {
223        (**self).supports_ring(ring)
224    }
225
226    fn prepare_convolution_operand<S, V>(&self, val: V, len_hint: Option<usize>, ring: S) -> Self::PreparedConvolutionOperand
227        where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>
228    {
229        (**self).prepare_convolution_operand(val, len_hint, ring)
230    }
231
232    fn compute_convolution_prepared<S, V1, V2>(&self, lhs: V1, lhs_prep: Option<&Self::PreparedConvolutionOperand>, rhs: V2, rhs_prep: Option<&Self::PreparedConvolutionOperand>, dst: &mut [R::Element], ring: S)
233        where S: RingStore<Type = R> + Copy, V1: VectorView<R::Element>, V2: VectorView<R::Element>
234    {
235        (**self).compute_convolution_prepared(lhs, lhs_prep, rhs, rhs_prep, dst, ring);
236    }
237
238    fn compute_convolution_sum<'b, S, I, V1, V2>(&self, values: I, dst: &mut [R::Element], ring: S) 
239        where S: RingStore<Type = R> + Copy, 
240            I: ExactSizeIterator<Item = (V1, Option<&'b Self::PreparedConvolutionOperand>, V2, Option<&'b Self::PreparedConvolutionOperand>)>,
241            V1: VectorView<R::Element>,
242            V2: VectorView<R::Element>,
243            Self: 'b,
244            R: 'b
245    {
246        (**self).compute_convolution_sum(values, dst, ring);
247    }
248}
249
///
/// Implementation of convolutions that uses Karatsuba's algorithm
/// with a threshold defined by [`KaratsubaHint`].
/// 
#[derive(Clone, Copy, Debug)]
pub struct KaratsubaAlgorithm<A: Allocator = Global> {
    // allocator used for the temporary buffers of the Karatsuba recursion
    allocator: A
}
258
///
/// Good default algorithm for computing convolutions, using Karatsuba's algorithm
/// with a threshold defined by [`KaratsubaHint`] and the global allocator for
/// temporary storage.
/// 
pub const STANDARD_CONVOLUTION: KaratsubaAlgorithm = KaratsubaAlgorithm::new(Global);
264
impl<A: Allocator> KaratsubaAlgorithm<A> {
    
    ///
    /// Creates a new [`KaratsubaAlgorithm`] that uses the given allocator for the
    /// temporary buffers required during the Karatsuba recursion.
    /// 
    #[stability::unstable(feature = "enable")]
    pub const fn new(allocator: A) -> Self {
        Self { allocator }
    }
}
272
impl<R: ?Sized + RingBase, A: Allocator> ConvolutionAlgorithm<R> for KaratsubaAlgorithm<A> {

    // NOTE(review): this impl omits the `+ Copy` bound on `S` that the trait declares;
    // this is the permissive direction and is accepted by the compiler, but consider
    // matching the trait signature for consistency.
    fn compute_convolution<S: RingStore<Type = R>, V1: VectorView<<R as RingBase>::Element>, V2: VectorView<<R as RingBase>::Element>>(&self, lhs: V1, rhs: V2, dst: &mut[<R as RingBase>::Element], ring: S) {
        // delegate to the Karatsuba routine; the naive-multiplication threshold is
        // customizable per ring via the `KaratsubaHint` trait
        karatsuba(ring.get_ring().karatsuba_threshold(), dst, SubvectorView::new(&lhs), SubvectorView::new(&rhs), &ring, &self.allocator)
    }

    // Karatsuba only requires ring arithmetic, so every ring is supported
    fn supports_ring<S: RingStore<Type = R> + Copy>(&self, _ring: S) -> bool {
        true
    }
}
283
///
/// Trait to allow rings to customize the parameters with which [`KaratsubaAlgorithm`] will
/// compute convolutions over the ring.
/// 
#[stability::unstable(feature = "enable")]
pub trait KaratsubaHint: RingBase {

    ///
    /// Defines the threshold from which on [`KaratsubaAlgorithm`] will use the Karatsuba algorithm.
    /// 
    /// Concretely, when this returns `k`, [`KaratsubaAlgorithm`] will reduce the 
    /// convolution down to ones on slices of size `2^k`, and compute their convolution naively. The default
    /// value is `0`, but if the considered rings have fast multiplication (compared to addition), then setting
    /// it higher may result in a performance gain.
    /// 
    fn karatsuba_threshold(&self) -> usize;
}
301
impl<R: RingBase + ?Sized> KaratsubaHint for R {

    // blanket default (specializable via `default fn`): always recurse down to
    // size-1 slices, i.e. use the plain Karatsuba recursion everywhere
    default fn karatsuba_threshold(&self) -> usize {
        0
    }
}
308
309#[cfg(test)]
310use test;
311#[cfg(test)]
312use crate::primitive_int::*;
313
314#[bench]
315fn bench_naive_mul(bencher: &mut test::Bencher) {
316    let a: Vec<i32> = (0..32).collect();
317    let b: Vec<i32> = (0..32).collect();
318    let mut c: Vec<i32> = (0..64).collect();
319    bencher.iter(|| {
320        c.clear();
321        c.resize(64, 0);
322        karatsuba(10, &mut c[..], &a[..], &b[..], StaticRing::<i32>::RING, &Global);
323        assert_eq!(c[31], 31 * 31 * 32 / 2 - 31 * (31 + 1) * (31 * 2 + 1) / 6);
324        assert_eq!(c[62], 31 * 31);
325    });
326}
327
328#[bench]
329fn bench_karatsuba_mul(bencher: &mut test::Bencher) {
330    let a: Vec<i32> = (0..32).collect();
331    let b: Vec<i32> = (0..32).collect();
332    let mut c: Vec<i32> = (0..64).collect();
333    bencher.iter(|| {
334        c.clear();
335        c.resize(64, 0);
336        karatsuba(4, &mut c[..], &a[..], &b[..], StaticRing::<i32>::RING, &Global);
337        assert_eq!(c[31], 31 * 31 * 32 / 2 - 31 * (31 + 1) * (31 * 2 + 1) / 6);
338        assert_eq!(c[62], 31 * 31);
339    });
340}
341
342
#[allow(missing_docs)]
#[cfg(any(test, feature = "generic_tests"))]
pub mod generic_tests {
    use std::cmp::min;
    use crate::homomorphism::*;
    use crate::ring::*;
    use super::*;

    ///
    /// Generic test that checks `convolution` correctly computes convolutions over `ring`,
    /// using test vectors `lhs[i] = i * scale` and `rhs[j] = j * scale` of various lengths.
    /// Checks both a zero-initialized output buffer and a pre-filled one (since
    /// `compute_convolution` adds to `dst`), then also runs the prepared-operand tests.
    /// 
    pub fn test_convolution<C, R>(convolution: C, ring: R, scale: El<R>)
        where C: ConvolutionAlgorithm<R::Type>,
            R: RingStore
    {
        for lhs_len in [2, 3, 4, 15] {
            for rhs_len in [2, 3, 4, 15, 31, 32, 33] {
                let lhs = (0..lhs_len).map(|i| ring.mul_ref_snd(ring.int_hom().map(i), &scale)).collect::<Vec<_>>();
                let rhs = (0..rhs_len).map(|i| ring.mul_ref_snd(ring.int_hom().map(i), &scale)).collect::<Vec<_>>();
                // closed form of `sum_(a + b = i) a * b` where `a` runs over valid indices
                // of `lhs` and `b` over valid indices of `rhs` (scaled by `scale^2` below);
                // NOTE(review): the `if` condition is always true, so the `else` branch is
                // dead - harmless, since the formula already evaluates to 0 at the last,
                // unused index `lhs_len + rhs_len - 1`
                let expected = (0..(lhs_len + rhs_len)).map(|i| if i < lhs_len + rhs_len {
                    min(i, lhs_len - 1) * (min(i, lhs_len - 1) + 1) * (3 * i - 2 * min(i, lhs_len - 1) - 1) / 6 - 
                    (i - 1 - min(i, rhs_len - 1)) * (i - min(i, rhs_len - 1)) * (i + 2 * min(i, rhs_len - 1) + 1) / 6
                } else {
                    0
                }).map(|x| ring.mul(ring.int_hom().map(x), ring.pow(ring.clone_el(&scale), 2))).collect::<Vec<_>>();
    
                // case 1: output buffer initialized to zero
                let mut actual = Vec::new();
                actual.resize_with((lhs_len + rhs_len) as usize, || ring.zero());
                convolution.compute_convolution(&lhs, &rhs, &mut actual, &ring);
                for i in 0..(lhs_len + rhs_len) {
                    assert_el_eq!(&ring, &expected[i as usize], &actual[i as usize]);
                }

                // case 2: output buffer pre-filled with `i^2 * scale^2`; since
                // `compute_convolution` adds to `dst`, the expected values gain an `i * i` term
                let expected = (0..(lhs_len + rhs_len)).map(|i| if i < lhs_len + rhs_len {
                    i * i +
                    min(i, lhs_len - 1) * (min(i, lhs_len - 1) + 1) * (3 * i - 2 * min(i, lhs_len - 1) - 1) / 6 - 
                    (i - 1 - min(i, rhs_len - 1)) * (i - min(i, rhs_len - 1)) * (i + 2 * min(i, rhs_len - 1) + 1) / 6
                } else {
                    0
                }).map(|x| ring.mul(ring.int_hom().map(x), ring.pow(ring.clone_el(&scale), 2))).collect::<Vec<_>>();
    
                let mut actual = Vec::new();
                actual.extend((0..(lhs_len + rhs_len)).map(|i| ring.mul(ring.int_hom().map(i * i), ring.pow(ring.clone_el(&scale), 2))));
                convolution.compute_convolution(&lhs, &rhs, &mut actual, &ring);
                for i in 0..(lhs_len + rhs_len) {
                    assert_el_eq!(&ring, &expected[i as usize], &actual[i as usize]);
                }
            }
        }
        test_prepared_convolution(convolution, ring, scale);
    }

    ///
    /// Checks the prepared-operand API: `compute_convolution_prepared` must give the same
    /// results as `compute_convolution`, for every combination of prepared/unprepared
    /// operands, and `compute_convolution_sum` must correctly accumulate multiple
    /// convolutions into the same output buffer.
    /// 
    fn test_prepared_convolution<C, R>(convolution: C, ring: R, scale: El<R>)
        where C: ConvolutionAlgorithm<R::Type>,
            R: RingStore
    {
        for lhs_len in [2, 3, 4, 14, 15] {
            for rhs_len in [2, 3, 4, 15, 31, 32, 33] {
                let lhs = (0..lhs_len).map(|i| ring.mul_ref_snd(ring.int_hom().map(i), &scale)).collect::<Vec<_>>();
                let rhs = (0..rhs_len).map(|i| ring.mul_ref_snd(ring.int_hom().map(i), &scale)).collect::<Vec<_>>();
                // same closed-form expected values as in `test_convolution` (see note there
                // about the always-true `if` condition)
                let expected = (0..(lhs_len + rhs_len)).map(|i| if i < lhs_len + rhs_len {
                    min(i, lhs_len - 1) * (min(i, lhs_len - 1) + 1) * (3 * i - 2 * min(i, lhs_len - 1) - 1) / 6 - 
                    (i - 1 - min(i, rhs_len - 1)) * (i - min(i, rhs_len - 1)) * (i + 2 * min(i, rhs_len - 1) + 1) / 6
                } else {
                    0
                }).map(|x| ring.mul(ring.int_hom().map(x), ring.pow(ring.clone_el(&scale), 2))).collect::<Vec<_>>();
    
                // both operands prepared
                let mut actual = Vec::new();
                actual.resize_with((lhs_len + rhs_len) as usize, || ring.zero());
                convolution.compute_convolution_prepared(
                    &lhs,
                    Some(&convolution.prepare_convolution_operand(&lhs, None, &ring)),
                    &rhs,
                    Some(&convolution.prepare_convolution_operand(&rhs, None, &ring)),
                    &mut actual, 
                    &ring
                );
                for i in 0..(lhs_len + rhs_len) {
                    assert_el_eq!(&ring, &expected[i as usize], &actual[i as usize]);
                }

                // only left operand prepared
                let mut actual = Vec::new();
                actual.resize_with((lhs_len + rhs_len) as usize, || ring.zero());
                convolution.compute_convolution_prepared(
                    &lhs,
                    Some(&convolution.prepare_convolution_operand(&lhs, None, &ring)),
                    &rhs,
                    None,
                    &mut actual, 
                    &ring
                );
                for i in 0..(lhs_len + rhs_len) {
                    assert_el_eq!(&ring, &expected[i as usize], &actual[i as usize]);
                }
                
                // only right operand prepared
                let mut actual = Vec::new();
                actual.resize_with((lhs_len + rhs_len) as usize, || ring.zero());
                convolution.compute_convolution_prepared(
                    &lhs,
                    None,
                    &rhs,
                    Some(&convolution.prepare_convolution_operand(&rhs, None, &ring)),
                    &mut actual, 
                    &ring
                );
                for i in 0..(lhs_len + rhs_len) {
                    assert_el_eq!(&ring, &expected[i as usize], &actual[i as usize]);
                }

                // convolution sum of `lhs * rhs` and `rhs * lhs`; by commutativity the
                // result must be twice the expected convolution
                let mut actual = Vec::new();
                actual.resize_with((lhs_len + rhs_len) as usize, || ring.zero());
                let data = [
                    (&lhs[..], Some(convolution.prepare_convolution_operand(&lhs, None, &ring)), &rhs[..], Some(convolution.prepare_convolution_operand(&rhs, None, &ring))),
                    (&rhs[..], None, &lhs[..], None)
                ];
                convolution.compute_convolution_sum(
                    data.as_fn().map_fn(|(l, l_prep, r, r_prep): &(_, _, _, _)| (l, l_prep.as_ref(), r, r_prep.as_ref())).iter(),
                    &mut actual, 
                    &ring
                );
                for i in 0..(lhs_len + rhs_len) {
                    assert_el_eq!(&ring, &ring.add_ref(&expected[i as usize], &expected[i as usize]), &actual[i as usize]);
                }

                // same, but with a different mix of prepared/unprepared operands
                let mut actual = Vec::new();
                actual.resize_with((lhs_len + rhs_len) as usize, || ring.zero());
                let data = [
                    (&lhs[..], Some(convolution.prepare_convolution_operand(&lhs, None, &ring)), &rhs[..], None),
                    (&rhs[..], None, &lhs[..], Some(convolution.prepare_convolution_operand(&lhs, None, &ring)))
                ];
                convolution.compute_convolution_sum(
                    data.as_fn().map_fn(|(l, l_prep, r, r_prep)| (l, l_prep.as_ref(), r, r_prep.as_ref())).iter(),
                    &mut actual, 
                    &ring
                );
                for i in 0..(lhs_len + rhs_len) {
                    assert_el_eq!(&ring, &ring.add_ref(&expected[i as usize], &expected[i as usize]), &actual[i as usize]);
                }
            }
        }
    }
}
481}