feanor_math/algorithms/convolution/
mod.rs

1use std::alloc::{Allocator, Global};
2use std::marker::PhantomData;
3use std::ops::Deref;
4
5use crate::ring::*;
6use crate::seq::subvector::SubvectorView;
7use crate::seq::VectorView;
8
9use karatsuba::*;
10
///
/// Contains an optimized implementation of Karatsuba's for computing convolutions
/// 
pub mod karatsuba;

///
/// Contains an implementation of computing convolutions using complex FFTs.
/// 
pub mod fft;

///
/// Contains an implementation of computing convolutions using number-theoretic
/// transforms (NTTs).
/// 
pub mod ntt;

///
/// Contains an implementation of computing convolutions based on residue number
/// systems (RNS).
/// 
pub mod rns;
24
///
/// Trait for objects that can compute a convolution over some ring.
/// 
/// # Example
/// ```
/// # use std::cmp::{min, max};
/// # use feanor_math::ring::*;
/// # use feanor_math::primitive_int::*;
/// # use feanor_math::seq::*;
/// # use feanor_math::algorithms::convolution::*;
/// struct NaiveConvolution;
/// // we support all rings!
/// impl<R: ?Sized + RingBase> ConvolutionAlgorithm<R> for NaiveConvolution {
///     fn compute_convolution<S: RingStore<Type = R>, V1: VectorView<R::Element>, V2: VectorView<R::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [R::Element], ring: S) {
///         for i in 0..(lhs.len() + rhs.len() - 1) {
///             for j in max(0, i as isize - rhs.len() as isize + 1)..min(lhs.len() as isize, i as isize + 1) {
///                 ring.add_assign(&mut dst[i], ring.mul_ref(lhs.at(j as usize), rhs.at(i - j as usize)));
///             }
///         }
///     }
///     fn supports_ring<S: RingStore<Type = R>>(&self, _: S) -> bool
///         where S: Copy
///     { true }
/// }
/// let lhs = [1, 2, 3, 4, 5];
/// let rhs = [2, 3, 4, 5, 6];
/// let mut expected = [0; 10];
/// let mut actual = [0; 10];
/// STANDARD_CONVOLUTION.compute_convolution(lhs, rhs, &mut expected, StaticRing::<i64>::RING);
/// NaiveConvolution.compute_convolution(lhs, rhs, &mut actual, StaticRing::<i64>::RING);
/// assert_eq!(expected, actual);
/// ```
/// 
pub trait ConvolutionAlgorithm<R: ?Sized + RingBase> {

    ///
    /// Elementwise adds the convolution of `lhs` and `rhs` to `dst`.
    /// 
    /// In other words, computes `dst[i] += sum_j lhs[j] * rhs[i - j]` for all `i`, where
    /// `j` runs through all non-negative integers for which `lhs[j]` and `rhs[i - j]` are
    /// defined, i.e. not out-of-bounds.
    /// 
    /// In particular, the result only affects the entries `dst[0]` to
    /// `dst[lhs.len() + rhs.len() - 2]`, so a buffer of length `lhs.len() + rhs.len() - 1`
    /// would mathematically suffice. However, to allow for more efficient implementations,
    /// it is instead required that `dst.len() >= lhs.len() + rhs.len()`.
    /// 
    /// # Panic
    /// 
    /// Panics if `dst` is shorter than `lhs.len() + rhs.len() - 1`. May panic if `dst` is shorter
    /// than `lhs.len() + rhs.len()`.
    /// 
    fn compute_convolution<S: RingStore<Type = R> + Copy, V1: VectorView<R::Element>, V2: VectorView<R::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [R::Element], ring: S);

    ///
    /// Returns whether this convolution algorithm supports computations of 
    /// the given ring.
    /// 
    /// Note that most algorithms will support all rings of type `R`. However in some cases,
    /// e.g. for finite fields, required data might only be precomputed for some moduli,
    /// and thus only these will be supported.
    /// 
    fn supports_ring<S: RingStore<Type = R> + Copy>(&self, ring: S) -> bool;

    ///
    /// If this convolution implements [`PreparedConvolutionAlgorithm`], then
    /// the given function is called and its result is returned. Otherwise,
    /// `Err` is returned.
    /// 
    #[stability::unstable(feature = "enable")]
    fn specialize_prepared_convolution<F>(function: F) -> Result<F::Output, F>
        where F: PreparedConvolutionOperation<Self, R>
    {
        // default: assume `Self` is not a `PreparedConvolutionAlgorithm` and hand
        // the operation back unexecuted; prepared implementations override this
        Err(function)
    }
}
100
///
/// Operation that only makes sense when `C` implements [`PreparedConvolutionAlgorithm`].
/// Used together with [`ConvolutionAlgorithm::specialize_prepared_convolution()`] as
/// a workaround for specialization.
/// 
#[stability::unstable(feature = "enable")]
pub trait PreparedConvolutionOperation<C: ?Sized, R: ?Sized + RingBase> {

    // Result type of `execute()`
    type Output;

    ///
    /// Runs the operation; callable only once `C` is known to implement
    /// [`PreparedConvolutionAlgorithm`].
    /// 
    fn execute(self) -> Self::Output
        where C: PreparedConvolutionAlgorithm<R>;
}
114
///
/// Trait for convolution algorithms that can "prepare" one (or both) operands in advance
/// by computing additional data, and then use this data to perform the actual convolution
/// more efficiently.
/// 
#[stability::unstable(feature = "enable")]
pub trait PreparedConvolutionAlgorithm<R: ?Sized + RingBase>: ConvolutionAlgorithm<R> {

    ///
    /// Data precomputed for a single operand, which can be reused across multiple
    /// convolutions involving that operand.
    /// 
    type PreparedConvolutionOperand;

    ///
    /// Computes the prepared data for the operand `val`, w.r.t. the given ring.
    /// 
    fn prepare_convolution_operand<S, V>(&self, val: V, ring: S) -> Self::PreparedConvolutionOperand
        where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>;

    ///
    /// As [`ConvolutionAlgorithm::compute_convolution()`], but takes the left-hand
    /// operand in prepared form.
    /// 
    fn compute_convolution_lhs_prepared<S, V>(&self, lhs: &Self::PreparedConvolutionOperand, rhs: V, dst: &mut [R::Element], ring: S)
        where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>;

    ///
    /// As [`ConvolutionAlgorithm::compute_convolution()`], but takes both operands
    /// in prepared form.
    /// 
    fn compute_convolution_prepared<S>(&self, lhs: &Self::PreparedConvolutionOperand, rhs: &Self::PreparedConvolutionOperand, dst: &mut [R::Element], ring: S)
        where S: RingStore<Type = R> + Copy;

    ///
    /// As [`ConvolutionAlgorithm::compute_convolution()`], but takes the right-hand
    /// operand in prepared form.
    /// 
    /// The default implementation swaps the operands and delegates to
    /// [`PreparedConvolutionAlgorithm::compute_convolution_lhs_prepared()`].
    /// 
    /// # Panic
    /// 
    /// The default implementation panics if the given ring is not commutative,
    /// since then swapping the operands would change the result.
    /// 
    fn compute_convolution_rhs_prepared<S, V>(&self, lhs: V, rhs: &Self::PreparedConvolutionOperand, dst: &mut [R::Element], ring: S)
        where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>
    {
        assert!(ring.is_commutative());
        self.compute_convolution_lhs_prepared(rhs, lhs, dst, ring);
    }

    ///
    /// Elementwise adds `sum_k conv(lhs_k, rhs_k)` to `dst`, where the pairs
    /// `(lhs_k, rhs_k)` are given by `values` and each `lhs_k` is in prepared form.
    /// 
    /// The default implementation just computes the convolutions of the pairs
    /// one after the other; implementations may be able to do better.
    /// 
    fn compute_convolution_inner_product_lhs_prepared<'a, S, I, V>(&self, values: I, dst: &mut [R::Element], ring: S) 
        where S: RingStore<Type = R> + Copy, 
            I: Iterator<Item = (&'a Self::PreparedConvolutionOperand, V)>,
            V: VectorView<R::Element>,
            Self: 'a,
            R: 'a,
            Self::PreparedConvolutionOperand: 'a
    {
        for (lhs, rhs) in values {
            self.compute_convolution_lhs_prepared(lhs, rhs, dst, ring)
        }
    }

    ///
    /// Elementwise adds `sum_k conv(lhs_k, rhs_k)` to `dst`, where the pairs
    /// `(lhs_k, rhs_k)` are given by `values` and both operands are in prepared form.
    /// 
    /// The default implementation just computes the convolutions of the pairs
    /// one after the other; implementations may be able to do better.
    /// 
    fn compute_convolution_inner_product_prepared<'a, S, I>(&self, values: I, dst: &mut [R::Element], ring: S) 
        where S: RingStore<Type = R> + Copy, 
            I: Iterator<Item = (&'a Self::PreparedConvolutionOperand, &'a Self::PreparedConvolutionOperand)>,
            Self::PreparedConvolutionOperand: 'a,
            Self: 'a,
            R: 'a,
    {
        for (lhs, rhs) in values {
            self.compute_convolution_prepared(lhs, rhs, dst, ring)
        }
    }
}
166
167impl<'a, R, C> ConvolutionAlgorithm<R> for C
168    where R: ?Sized + RingBase,
169        C: Deref,
170        C::Target: ConvolutionAlgorithm<R>
171{
172    fn compute_convolution<S: RingStore<Type = R> + Copy, V1: VectorView<R::Element>, V2: VectorView<R::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [R::Element], ring: S) {
173        (**self).compute_convolution(lhs, rhs, dst, ring)
174    }
175
176    fn supports_ring<S: RingStore<Type = R> + Copy>(&self, ring: S) -> bool {
177        (**self).supports_ring(ring)
178    }
179
180    fn specialize_prepared_convolution<F>(function: F) -> Result<F::Output, F>
181        where F: PreparedConvolutionOperation<Self, R>
182    {
183        struct CallFunction<F, C, R>
184            where R: ?Sized + RingBase,
185                C: Deref,
186                C::Target: ConvolutionAlgorithm<R>,
187                F: PreparedConvolutionOperation<C, R>
188        {
189            convolution: PhantomData<Box<C>>,
190            ring: PhantomData<Box<R>>,
191            function: F
192        }
193        impl<F, C, R> PreparedConvolutionOperation<C::Target, R> for CallFunction<F, C, R>
194            where R: ?Sized + RingBase,
195                C: Deref,
196                C::Target: ConvolutionAlgorithm<R>,
197                F: PreparedConvolutionOperation<C, R>
198        {
199            type Output = F::Output;
200
201            fn execute(self) -> Self::Output
202                where C::Target:PreparedConvolutionAlgorithm<R>
203            {
204                self.function.execute()
205            }
206        }
207        return <C::Target as ConvolutionAlgorithm<R>>::specialize_prepared_convolution::<CallFunction<F, C, R>>(CallFunction {
208            function: function,
209            ring: PhantomData,
210            convolution: PhantomData
211        }).map_err(|f| f.function);
212    }
213}
214
215impl<'a, R, C> PreparedConvolutionAlgorithm<R> for C
216    where R: ?Sized + RingBase,
217        C: Deref,
218        C::Target: PreparedConvolutionAlgorithm<R>
219{
220    type PreparedConvolutionOperand = <C::Target as PreparedConvolutionAlgorithm<R>>::PreparedConvolutionOperand;
221
222    fn prepare_convolution_operand<S, V>(&self, val: V, ring: S) -> Self::PreparedConvolutionOperand
223        where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>
224    {
225        (**self).prepare_convolution_operand(val, ring)
226    }
227
228    fn compute_convolution_lhs_prepared<S, V>(&self, lhs: &Self::PreparedConvolutionOperand, rhs: V, dst: &mut [R::Element], ring: S)
229        where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>
230    {
231        (**self).compute_convolution_lhs_prepared(lhs, rhs, dst, ring)
232    }
233
234    fn compute_convolution_prepared<S>(&self, lhs: &Self::PreparedConvolutionOperand, rhs: &Self::PreparedConvolutionOperand, dst: &mut [R::Element], ring: S)
235        where S: RingStore<Type = R> + Copy
236    {
237        (**self).compute_convolution_prepared(lhs, rhs, dst, ring)
238    }
239
240    fn compute_convolution_rhs_prepared<S, V>(&self, lhs: V, rhs: &Self::PreparedConvolutionOperand, dst: &mut [R::Element], ring: S)
241        where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>
242    {
243        (**self).compute_convolution_rhs_prepared(lhs, rhs, dst, ring)
244    }
245
246    fn compute_convolution_inner_product_lhs_prepared<'b, S, I, V>(&self, values: I, dst: &mut [R::Element], ring: S) 
247        where S: RingStore<Type = R> + Copy, 
248            I: Iterator<Item = (&'b Self::PreparedConvolutionOperand, V)>,
249            V: VectorView<R::Element>,
250            Self: 'b,
251            R: 'b,
252            Self::PreparedConvolutionOperand: 'b
253    {
254        (**self).compute_convolution_inner_product_lhs_prepared(values, dst, ring)
255    }
256
257    fn compute_convolution_inner_product_prepared<'b, S, I>(&self, values: I, dst: &mut [R::Element], ring: S) 
258        where S: RingStore<Type = R> + Copy, 
259            I: Iterator<Item = (&'b Self::PreparedConvolutionOperand, &'b Self::PreparedConvolutionOperand)>,
260            Self: 'b,
261            R: 'b,
262            Self::PreparedConvolutionOperand: 'b
263    {
264        (**self).compute_convolution_inner_product_prepared(values, dst, ring)
265    }
266}
267
///
/// Implementation of convolutions that uses Karatsuba's algorithm
/// with a threshold defined by [`KaratsubaHint`].
/// 
#[derive(Clone, Copy, Debug)]
pub struct KaratsubaAlgorithm<A: Allocator = Global> {
    // allocator handed to `karatsuba()` for any temporary allocations it makes
    allocator: A
}
276
///
/// Good default algorithm for computing convolutions, using Karatsuba's algorithm
/// with a threshold defined by [`KaratsubaHint`].
/// 
/// Uses the global allocator [`Global`] for temporary allocations.
/// 
pub const STANDARD_CONVOLUTION: KaratsubaAlgorithm = KaratsubaAlgorithm::new(Global);
282
impl<A: Allocator> KaratsubaAlgorithm<A> {
    
    ///
    /// Creates a new [`KaratsubaAlgorithm`] that uses the given allocator
    /// for temporary allocations.
    /// 
    #[stability::unstable(feature = "enable")]
    pub const fn new(allocator: A) -> Self {
        Self { allocator }
    }
}
290
impl<R: ?Sized + RingBase, A: Allocator> ConvolutionAlgorithm<R> for KaratsubaAlgorithm<A> {

    fn compute_convolution<S: RingStore<Type = R>, V1: VectorView<<R as RingBase>::Element>, V2: VectorView<<R as RingBase>::Element>>(&self, lhs: V1, rhs: V2, dst: &mut[<R as RingBase>::Element], ring: S) {
        // delegate to the `karatsuba` module, with the naive-multiplication
        // threshold chosen by the ring via the `KaratsubaHint` specialization
        karatsuba(ring.get_ring().karatsuba_threshold(), dst, SubvectorView::new(&lhs), SubvectorView::new(&rhs), &ring, &self.allocator)
    }

    // Karatsuba works over every ring
    fn supports_ring<S: RingStore<Type = R> + Copy>(&self, _ring: S) -> bool {
        true
    }
}
301
///
/// Trait to allow rings to customize the parameters with which [`KaratsubaAlgorithm`] will
/// compute convolutions over the ring.
/// 
/// A blanket implementation (via specialization) provides the default threshold `0`
/// for all rings; individual rings can override it.
/// 
#[stability::unstable(feature = "enable")]
pub trait KaratsubaHint: RingBase {

    ///
    /// Define a threshold from which on [`KaratsubaAlgorithm`] will use the Karatsuba algorithm.
    /// 
    /// Concretely, when this returns `k`, [`KaratsubaAlgorithm`] will reduce the 
    /// convolution down to ones on slices of size `2^k`, and compute their convolution naively. The default
    /// value is `0`, but if the considered rings have fast multiplication (compared to addition), then setting
    /// it higher may result in a performance gain.
    /// 
    fn karatsuba_threshold(&self) -> usize;
}
319
impl<R: RingBase + ?Sized> KaratsubaHint for R {

    // Blanket default: threshold 0, i.e. recurse with Karatsuba all the way down
    // to size-1 (= 2^0) slices. `default fn` requires the unstable specialization
    // feature and allows individual rings to override this value.
    default fn karatsuba_threshold(&self) -> usize {
        0
    }
}
326
327#[cfg(test)]
328use test;
329#[cfg(test)]
330use crate::primitive_int::*;
331
332#[bench]
333fn bench_naive_mul(bencher: &mut test::Bencher) {
334    let a: Vec<i32> = (0..32).collect();
335    let b: Vec<i32> = (0..32).collect();
336    let mut c: Vec<i32> = (0..64).collect();
337    bencher.iter(|| {
338        c.clear();
339        c.resize(64, 0);
340        karatsuba(10, &mut c[..], &a[..], &b[..], StaticRing::<i32>::RING, &Global);
341        assert_eq!(c[31], 31 * 31 * 32 / 2 - 31 * (31 + 1) * (31 * 2 + 1) / 6);
342        assert_eq!(c[62], 31 * 31);
343    });
344}
345
346#[bench]
347fn bench_karatsuba_mul(bencher: &mut test::Bencher) {
348    let a: Vec<i32> = (0..32).collect();
349    let b: Vec<i32> = (0..32).collect();
350    let mut c: Vec<i32> = (0..64).collect();
351    bencher.iter(|| {
352        c.clear();
353        c.resize(64, 0);
354        karatsuba(4, &mut c[..], &a[..], &b[..], StaticRing::<i32>::RING, &Global);
355        assert_eq!(c[31], 31 * 31 * 32 / 2 - 31 * (31 + 1) * (31 * 2 + 1) / 6);
356        assert_eq!(c[62], 31 * 31);
357    });
358}
359
360
#[allow(missing_docs)]
#[cfg(any(test, feature = "generic_tests"))]
pub mod generic_tests {
    use std::cmp::min;
    use crate::homomorphism::*;
    use crate::ring::*;
    use super::*;

    ///
    /// Checks that `convolution` correctly computes convolutions over the given
    /// ring, panicking (via failed assertions) otherwise.
    /// 
    /// Convolves the sequences `0, scale, 2 * scale, ...` of various lengths and
    /// compares against a closed-form expression for the expected coefficients
    /// `sum_j j * (i - j) * scale^2`. The second part additionally checks that the
    /// convolution is elementwise *added* to the previous content of `dst`, as
    /// required by [`ConvolutionAlgorithm::compute_convolution()`].
    /// 
    pub fn test_convolution<C, R>(convolution: C, ring: R, scale: El<R>)
        where C: ConvolutionAlgorithm<R::Type>,
            R: RingStore
    {
        for lhs_len in [2, 3, 4, 15] {
            for rhs_len in [2, 3, 4, 15, 31, 32, 33] {
                // operands are (0, 1, 2, ...) scaled by `scale`
                let lhs = (0..lhs_len).map(|i| ring.mul_ref_snd(ring.int_hom().map(i), &scale)).collect::<Vec<_>>();
                let rhs = (0..rhs_len).map(|i| ring.mul_ref_snd(ring.int_hom().map(i), &scale)).collect::<Vec<_>>();
                // closed-form value of sum_j j * (i - j) over all j with
                // 0 <= j < lhs_len and 0 <= i - j < rhs_len;
                // NOTE(review): the condition `i < lhs_len + rhs_len` is always true,
                // so the else-branch is dead - the formula itself already evaluates
                // to 0 at index lhs_len + rhs_len - 1
                let expected = (0..(lhs_len + rhs_len)).map(|i| if i < lhs_len + rhs_len {
                    min(i, lhs_len - 1) * (min(i, lhs_len - 1) + 1) * (3 * i - 2 * min(i, lhs_len - 1) - 1) / 6 - 
                    (i - 1 - min(i, rhs_len - 1)) * (i - min(i, rhs_len - 1)) * (i + 2 * min(i, rhs_len - 1) + 1) / 6
                } else {
                    0
                }).map(|x| ring.mul(ring.int_hom().map(x), ring.pow(ring.clone_el(&scale), 2))).collect::<Vec<_>>();
    
                // convolution into a zero-initialized buffer
                let mut actual = Vec::new();
                actual.resize_with((lhs_len + rhs_len) as usize, || ring.zero());
                convolution.compute_convolution(&lhs, &rhs, &mut actual, &ring);
                for i in 0..(lhs_len + rhs_len) {
                    assert_el_eq!(&ring, &expected[i as usize], &actual[i as usize]);
                }

                // now prefill `dst` with i^2 * scale^2 and check that the convolution
                // is added on top, i.e. expected values are shifted by i^2 * scale^2
                let expected = (0..(lhs_len + rhs_len)).map(|i| if i < lhs_len + rhs_len {
                    i * i +
                    min(i, lhs_len - 1) * (min(i, lhs_len - 1) + 1) * (3 * i - 2 * min(i, lhs_len - 1) - 1) / 6 - 
                    (i - 1 - min(i, rhs_len - 1)) * (i - min(i, rhs_len - 1)) * (i + 2 * min(i, rhs_len - 1) + 1) / 6
                } else {
                    0
                }).map(|x| ring.mul(ring.int_hom().map(x), ring.pow(ring.clone_el(&scale), 2))).collect::<Vec<_>>();
    
                let mut actual = Vec::new();
                actual.extend((0..(lhs_len + rhs_len)).map(|i| ring.mul(ring.int_hom().map(i * i), ring.pow(ring.clone_el(&scale), 2))));
                convolution.compute_convolution(&lhs, &rhs, &mut actual, &ring);
                for i in 0..(lhs_len + rhs_len) {
                    assert_el_eq!(&ring, &expected[i as usize], &actual[i as usize]);
                }
            }
        }
    }

    ///
    /// Checks that the prepared-operand methods of `convolution` compute correct
    /// convolutions over the given ring, panicking (via failed assertions) otherwise.
    /// 
    /// Uses the same operands and expected values as [`test_convolution()`], and
    /// exercises [`PreparedConvolutionAlgorithm::compute_convolution_prepared()`],
    /// [`PreparedConvolutionAlgorithm::compute_convolution_lhs_prepared()`] as well
    /// as both inner-product variants.
    /// 
    #[stability::unstable(feature = "enable")]
    pub fn test_prepared_convolution<C, R>(convolution: C, ring: R, scale: El<R>)
        where C: PreparedConvolutionAlgorithm<R::Type>,
            R: RingStore
    {
        for lhs_len in [2, 3, 4, 14, 15] {
            for rhs_len in [2, 3, 4, 15, 31, 32, 33] {
                let lhs = (0..lhs_len).map(|i| ring.mul_ref_snd(ring.int_hom().map(i), &scale)).collect::<Vec<_>>();
                let rhs = (0..rhs_len).map(|i| ring.mul_ref_snd(ring.int_hom().map(i), &scale)).collect::<Vec<_>>();
                // same closed-form expected values as in `test_convolution()`;
                // the `if` condition is likewise always true here
                let expected = (0..(lhs_len + rhs_len)).map(|i| if i < lhs_len + rhs_len {
                    min(i, lhs_len - 1) * (min(i, lhs_len - 1) + 1) * (3 * i - 2 * min(i, lhs_len - 1) - 1) / 6 - 
                    (i - 1 - min(i, rhs_len - 1)) * (i - min(i, rhs_len - 1)) * (i + 2 * min(i, rhs_len - 1) + 1) / 6
                } else {
                    0
                }).map(|x| ring.mul(ring.int_hom().map(x), ring.pow(ring.clone_el(&scale), 2))).collect::<Vec<_>>();
    
                // both operands prepared
                let mut actual = Vec::new();
                actual.resize_with((lhs_len + rhs_len) as usize, || ring.zero());
                convolution.compute_convolution_prepared(
                    &convolution.prepare_convolution_operand(&lhs, &ring),
                    &convolution.prepare_convolution_operand(&rhs, &ring),
                    &mut actual, 
                    &ring
                );
                for i in 0..(lhs_len + rhs_len) {
                    assert_el_eq!(&ring, &expected[i as usize], &actual[i as usize]);
                }

                // only the left operand prepared
                let mut actual = Vec::new();
                actual.resize_with((lhs_len + rhs_len) as usize, || ring.zero());
                convolution.compute_convolution_lhs_prepared(
                    &convolution.prepare_convolution_operand(&lhs, &ring),
                    &rhs,
                    &mut actual, 
                    &ring
                );
                for i in 0..(lhs_len + rhs_len) {
                    assert_el_eq!(&ring, &expected[i as usize], &actual[i as usize]);
                }

                // inner product of prepared pairs; the second pair `[1] * [1]`
                // contributes 1 to index 0 only
                let mut actual = Vec::new();
                actual.resize_with((lhs_len + rhs_len) as usize, || ring.zero());
                let data = [
                    (convolution.prepare_convolution_operand(&lhs, &ring), convolution.prepare_convolution_operand(&rhs, &ring)),
                    (convolution.prepare_convolution_operand(&[ring.one()], &ring), convolution.prepare_convolution_operand(&[ring.one()], &ring))
                ];
                convolution.compute_convolution_inner_product_prepared(
                    data.iter().map(|(l, r)| (l, r)),
                    &mut actual, 
                    &ring
                );
                assert_el_eq!(&ring, ring.add_ref_fst(&expected[0], ring.one()), &actual[0]);
                for i in 1..(lhs_len + rhs_len) {
                    assert_el_eq!(&ring, &expected[i as usize], &actual[i as usize]);
                }

                // inner product with only the left operands prepared
                let mut actual = Vec::new();
                actual.resize_with((lhs_len + rhs_len) as usize, || ring.zero());
                let data = [
                    (convolution.prepare_convolution_operand(&lhs, &ring), rhs),
                    (convolution.prepare_convolution_operand(&[ring.one()], &ring), vec![ring.one()])
                ];
                convolution.compute_convolution_inner_product_lhs_prepared(
                    data.iter().map(|(l, r)| (l, r)),
                    &mut actual, 
                    &ring
                );
                assert_el_eq!(&ring, ring.add_ref_fst(&expected[0], ring.one()), &actual[0]);
                for i in 1..(lhs_len + rhs_len) {
                    assert_el_eq!(&ring, &expected[i as usize], &actual[i as usize]);
                }
            }
        }
    }
}