1use std::alloc::{Allocator, Global};
2use std::ops::Deref;
3
4use crate::ring::*;
5use crate::seq::subvector::SubvectorView;
6use crate::seq::*;
7
8use karatsuba::*;
9
10pub mod karatsuba;
14
15pub mod fft;
19
20pub mod ntt;
25
26pub mod rns;
31
32pub trait ConvolutionAlgorithm<R: ?Sized + RingBase> {
66
67    #[stability::unstable(feature = "enable")]
76    type PreparedConvolutionOperand = ();
77
78    fn compute_convolution<S: RingStore<Type = R> + Copy, V1: VectorView<R::Element>, V2: VectorView<R::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [R::Element], ring: S);
95
96    fn supports_ring<S: RingStore<Type = R> + Copy>(&self, ring: S) -> bool;
105
106    #[stability::unstable(feature = "enable")]
157    fn prepare_convolution_operand<S, V>(&self, _val: V, _length_hint: Option<usize>, _ring: S) -> Self::PreparedConvolutionOperand
158        where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>
159    {
160        struct ProduceUnitType;
161        trait ProduceValue<T> {
162            fn produce() -> T;
163        }
164        impl<T> ProduceValue<T> for ProduceUnitType {
165            default fn produce() -> T {
166                panic!("if you specialize ConvolutionAlgorithm::PreparedConvolutionOperand, you must also specialize ConvolutionAlgorithm::prepare_convolution_operand()")
167            }
168        }
169        impl ProduceValue<()> for ProduceUnitType {
170            fn produce() -> () {}
171        }
172        return <ProduceUnitType as ProduceValue<Self::PreparedConvolutionOperand>>::produce();
173    }
174    
175    #[stability::unstable(feature = "enable")]
183    fn compute_convolution_prepared<S, V1, V2>(&self, lhs: V1, _lhs_prep: Option<&Self::PreparedConvolutionOperand>, rhs: V2, _rhs_prep: Option<&Self::PreparedConvolutionOperand>, dst: &mut [R::Element], ring: S)
184        where S: RingStore<Type = R> + Copy, V1: VectorView<R::Element>, V2: VectorView<R::Element>
185    {
186        self.compute_convolution(lhs, rhs, dst, ring)
187    }
188
189    #[stability::unstable(feature = "enable")]
197    fn compute_convolution_sum<'a, S, I, V1, V2>(&self, values: I, dst: &mut [R::Element], ring: S) 
198        where S: RingStore<Type = R> + Copy, 
199            I: ExactSizeIterator<Item = (V1, Option<&'a Self::PreparedConvolutionOperand>, V2, Option<&'a Self::PreparedConvolutionOperand>)>,
200            V1: VectorView<R::Element>,
201            V2: VectorView<R::Element>,
202            Self: 'a,
203            R: 'a
204    {
205        for (lhs, lhs_prep, rhs, rhs_prep) in values {
206            self.compute_convolution_prepared(lhs, lhs_prep, rhs, rhs_prep, dst, ring)
207        }
208    }
209}
210
211impl<'a, R, C> ConvolutionAlgorithm<R> for C
212    where R: ?Sized + RingBase,
213        C: Deref,
214        C::Target: ConvolutionAlgorithm<R>
215{
216    type PreparedConvolutionOperand = <C::Target as ConvolutionAlgorithm<R>>::PreparedConvolutionOperand;
217
218    fn compute_convolution<S: RingStore<Type = R> + Copy, V1: VectorView<R::Element>, V2: VectorView<R::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [R::Element], ring: S) {
219        (**self).compute_convolution(lhs, rhs, dst, ring)
220    }
221
222    fn supports_ring<S: RingStore<Type = R> + Copy>(&self, ring: S) -> bool {
223        (**self).supports_ring(ring)
224    }
225
226    fn prepare_convolution_operand<S, V>(&self, val: V, len_hint: Option<usize>, ring: S) -> Self::PreparedConvolutionOperand
227        where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>
228    {
229        (**self).prepare_convolution_operand(val, len_hint, ring)
230    }
231
232    fn compute_convolution_prepared<S, V1, V2>(&self, lhs: V1, lhs_prep: Option<&Self::PreparedConvolutionOperand>, rhs: V2, rhs_prep: Option<&Self::PreparedConvolutionOperand>, dst: &mut [R::Element], ring: S)
233        where S: RingStore<Type = R> + Copy, V1: VectorView<R::Element>, V2: VectorView<R::Element>
234    {
235        (**self).compute_convolution_prepared(lhs, lhs_prep, rhs, rhs_prep, dst, ring);
236    }
237
238    fn compute_convolution_sum<'b, S, I, V1, V2>(&self, values: I, dst: &mut [R::Element], ring: S) 
239        where S: RingStore<Type = R> + Copy, 
240            I: ExactSizeIterator<Item = (V1, Option<&'b Self::PreparedConvolutionOperand>, V2, Option<&'b Self::PreparedConvolutionOperand>)>,
241            V1: VectorView<R::Element>,
242            V2: VectorView<R::Element>,
243            Self: 'b,
244            R: 'b
245    {
246        (**self).compute_convolution_sum(values, dst, ring);
247    }
248}
249
250#[derive(Clone, Copy, Debug)]
255pub struct KaratsubaAlgorithm<A: Allocator = Global> {
256    allocator: A
257}
258
259pub const STANDARD_CONVOLUTION: KaratsubaAlgorithm = KaratsubaAlgorithm::new(Global);
264
265impl<A: Allocator> KaratsubaAlgorithm<A> {
266    
267    #[stability::unstable(feature = "enable")]
268    pub const fn new(allocator: A) -> Self {
269        Self { allocator }
270    }
271}
272
273impl<R: ?Sized + RingBase, A: Allocator> ConvolutionAlgorithm<R> for KaratsubaAlgorithm<A> {
274
275    fn compute_convolution<S: RingStore<Type = R>, V1: VectorView<<R as RingBase>::Element>, V2: VectorView<<R as RingBase>::Element>>(&self, lhs: V1, rhs: V2, dst: &mut[<R as RingBase>::Element], ring: S) {
276        karatsuba(ring.get_ring().karatsuba_threshold(), dst, SubvectorView::new(&lhs), SubvectorView::new(&rhs), &ring, &self.allocator)
277    }
278
279    fn supports_ring<S: RingStore<Type = R> + Copy>(&self, _ring: S) -> bool {
280        true
281    }
282}
283
#[stability::unstable(feature = "enable")]
///
/// Lets a ring choose the threshold that [`KaratsubaAlgorithm`] passes as first argument
/// to `karatsuba()` when computing convolutions over this ring.
///
/// Every ring implements this trait through a blanket implementation returning `0`. That
/// implementation uses `default fn`, so individual rings can override it via specialization.
/// Presumably the threshold controls the operand size below which the recursion stops and
/// naive multiplication is used instead (TODO confirm against `karatsuba()`).
///
pub trait KaratsubaHint: RingBase {

    ///
    /// Returns the threshold that is passed as first argument to `karatsuba()`.
    ///
    fn karatsuba_threshold(&self) -> usize;
}
301
302impl<R: RingBase + ?Sized> KaratsubaHint for R {
303
304    default fn karatsuba_threshold(&self) -> usize {
305        0
306    }
307}
308
309#[cfg(test)]
310use test;
311#[cfg(test)]
312use crate::primitive_int::*;
313
314#[bench]
315fn bench_naive_mul(bencher: &mut test::Bencher) {
316    let a: Vec<i32> = (0..32).collect();
317    let b: Vec<i32> = (0..32).collect();
318    let mut c: Vec<i32> = (0..64).collect();
319    bencher.iter(|| {
320        c.clear();
321        c.resize(64, 0);
322        karatsuba(10, &mut c[..], &a[..], &b[..], StaticRing::<i32>::RING, &Global);
323        assert_eq!(c[31], 31 * 31 * 32 / 2 - 31 * (31 + 1) * (31 * 2 + 1) / 6);
324        assert_eq!(c[62], 31 * 31);
325    });
326}
327
328#[bench]
329fn bench_karatsuba_mul(bencher: &mut test::Bencher) {
330    let a: Vec<i32> = (0..32).collect();
331    let b: Vec<i32> = (0..32).collect();
332    let mut c: Vec<i32> = (0..64).collect();
333    bencher.iter(|| {
334        c.clear();
335        c.resize(64, 0);
336        karatsuba(4, &mut c[..], &a[..], &b[..], StaticRing::<i32>::RING, &Global);
337        assert_eq!(c[31], 31 * 31 * 32 / 2 - 31 * (31 + 1) * (31 * 2 + 1) / 6);
338        assert_eq!(c[62], 31 * 31);
339    });
340}
341
342
#[allow(missing_docs)]
#[cfg(any(test, feature = "generic_tests"))]
pub mod generic_tests {
    use std::cmp::min;
    use crate::homomorphism::*;
    use crate::ring::*;
    use super::*;

    ///
    /// Returns the coefficient of `X^i` in the product of `sum_{j < lhs_len} j X^j` and
    /// `sum_{j < rhs_len} j X^j`, i.e. the sum of `j * (i - j)` over all `j` with
    /// `0 <= j < lhs_len` and `0 <= i - j < rhs_len`.
    ///
    /// Valid for `0 <= i <= lhs_len + rhs_len - 1` (where the last coefficient is `0`),
    /// `lhs_len >= 1` and `rhs_len >= 1`.
    ///
    fn convolution_coefficient(i: i32, lhs_len: i32, rhs_len: i32) -> i32 {
        // `sum_{j = 0}^{m} j (i - j)` with `m = min(i, lhs_len - 1)` contains all valid `j`,
        // but also those with `i - j >= rhs_len`...
        let m = min(i, lhs_len - 1);
        // ...so subtract `sum_{j = 0}^{i - rhs_len} j (i - j)`, which vanishes if `i < rhs_len`
        let k = min(i, rhs_len - 1);
        m * (m + 1) * (3 * i - 2 * m - 1) / 6 - (i - 1 - k) * (i - k) * (i + 2 * k + 1) / 6
    }

    /// Returns `[0 * scale, 1 * scale, ..., (len - 1) * scale]`.
    fn scaled_range<R: RingStore>(ring: &R, len: i32, scale: &El<R>) -> Vec<El<R>> {
        (0..len).map(|i| ring.mul_ref_snd(ring.int_hom().map(i), scale)).collect()
    }

    /// Returns the vector with entries `(offset(i) + convolution_coefficient(i, lhs_len, rhs_len)) * scale_sqr`
    /// for `i` in `0..(lhs_len + rhs_len)`.
    fn expected_result<R, F>(ring: &R, lhs_len: i32, rhs_len: i32, scale_sqr: &El<R>, offset: F) -> Vec<El<R>>
        where R: RingStore, F: Fn(i32) -> i32
    {
        (0..(lhs_len + rhs_len))
            .map(|i| ring.mul_ref_snd(ring.int_hom().map(offset(i) + convolution_coefficient(i, lhs_len, rhs_len)), scale_sqr))
            .collect()
    }

    /// Returns a vector of `len` zeros of `ring`.
    fn zero_vec<R: RingStore>(ring: &R, len: usize) -> Vec<El<R>> {
        let mut result = Vec::new();
        result.resize_with(len, || ring.zero());
        result
    }

    /// Asserts that `expected` and `actual` have the same length and agree elementwise in `ring`.
    fn assert_all_el_eq<R: RingStore>(ring: &R, expected: &[El<R>], actual: &[El<R>]) {
        assert_eq!(expected.len(), actual.len());
        for i in 0..expected.len() {
            assert_el_eq!(ring, &expected[i], &actual[i]);
        }
    }

    ///
    /// Checks that `convolution` adds the convolution of the two operands to `dst`, for
    /// operands `[0, scale, 2 scale, ...]` of various lengths, and then runs the same checks
    /// for the prepared-operand methods.
    ///
    pub fn test_convolution<C, R>(convolution: C, ring: R, scale: El<R>)
        where C: ConvolutionAlgorithm<R::Type>,
            R: RingStore
    {
        let scale_sqr = ring.pow(ring.clone_el(&scale), 2);
        for lhs_len in [2, 3, 4, 15] {
            for rhs_len in [2, 3, 4, 15, 31, 32, 33] {
                let lhs = scaled_range(&ring, lhs_len, &scale);
                let rhs = scaled_range(&ring, rhs_len, &scale);
                let result_len = (lhs_len + rhs_len) as usize;

                // starting from zero, `dst` must afterwards contain exactly the convolution
                let expected = expected_result(&ring, lhs_len, rhs_len, &scale_sqr, |_| 0);
                let mut actual = zero_vec(&ring, result_len);
                convolution.compute_convolution(&lhs, &rhs, &mut actual, &ring);
                assert_all_el_eq(&ring, &expected, &actual);

                // the convolution must be added to the previous content of `dst`
                let expected = expected_result(&ring, lhs_len, rhs_len, &scale_sqr, |i| i * i);
                let mut actual = (0..(lhs_len + rhs_len)).map(|i| ring.mul_ref_snd(ring.int_hom().map(i * i), &scale_sqr)).collect::<Vec<_>>();
                convolution.compute_convolution(&lhs, &rhs, &mut actual, &ring);
                assert_all_el_eq(&ring, &expected, &actual);
            }
        }
        test_prepared_convolution(convolution, ring, scale);
    }

    ///
    /// Checks `compute_convolution_prepared()` with every combination of prepared and
    /// unprepared operands, and `compute_convolution_sum()` on sums of two convolutions.
    ///
    fn test_prepared_convolution<C, R>(convolution: C, ring: R, scale: El<R>)
        where C: ConvolutionAlgorithm<R::Type>,
            R: RingStore
    {
        let scale_sqr = ring.pow(ring.clone_el(&scale), 2);
        for lhs_len in [2, 3, 4, 14, 15] {
            for rhs_len in [2, 3, 4, 15, 31, 32, 33] {
                let lhs = scaled_range(&ring, lhs_len, &scale);
                let rhs = scaled_range(&ring, rhs_len, &scale);
                let result_len = (lhs_len + rhs_len) as usize;
                let expected = expected_result(&ring, lhs_len, rhs_len, &scale_sqr, |_| 0);
                // the sums below add `lhs * rhs` and `rhs * lhs`, i.e. twice the convolution
                let expected_doubled = expected.iter().map(|x| ring.add_ref(x, x)).collect::<Vec<_>>();

                // both operands prepared
                let mut actual = zero_vec(&ring, result_len);
                convolution.compute_convolution_prepared(
                    &lhs,
                    Some(&convolution.prepare_convolution_operand(&lhs, None, &ring)),
                    &rhs,
                    Some(&convolution.prepare_convolution_operand(&rhs, None, &ring)),
                    &mut actual, 
                    &ring
                );
                assert_all_el_eq(&ring, &expected, &actual);

                // only `lhs` prepared
                let mut actual = zero_vec(&ring, result_len);
                convolution.compute_convolution_prepared(
                    &lhs,
                    Some(&convolution.prepare_convolution_operand(&lhs, None, &ring)),
                    &rhs,
                    None,
                    &mut actual, 
                    &ring
                );
                assert_all_el_eq(&ring, &expected, &actual);

                // only `rhs` prepared
                let mut actual = zero_vec(&ring, result_len);
                convolution.compute_convolution_prepared(
                    &lhs,
                    None,
                    &rhs,
                    Some(&convolution.prepare_convolution_operand(&rhs, None, &ring)),
                    &mut actual, 
                    &ring
                );
                assert_all_el_eq(&ring, &expected, &actual);

                // sum of a fully prepared and a fully unprepared convolution
                let mut actual = zero_vec(&ring, result_len);
                let data = [
                    (&lhs[..], Some(convolution.prepare_convolution_operand(&lhs, None, &ring)), &rhs[..], Some(convolution.prepare_convolution_operand(&rhs, None, &ring))),
                    (&rhs[..], None, &lhs[..], None)
                ];
                convolution.compute_convolution_sum(
                    data.as_fn().map_fn(|(l, l_prep, r, r_prep): &(_, _, _, _)| (l, l_prep.as_ref(), r, r_prep.as_ref())).iter(),
                    &mut actual, 
                    &ring
                );
                assert_all_el_eq(&ring, &expected_doubled, &actual);

                // sum of two convolutions with one prepared operand each
                let mut actual = zero_vec(&ring, result_len);
                let data = [
                    (&lhs[..], Some(convolution.prepare_convolution_operand(&lhs, None, &ring)), &rhs[..], None),
                    (&rhs[..], None, &lhs[..], Some(convolution.prepare_convolution_operand(&lhs, None, &ring)))
                ];
                convolution.compute_convolution_sum(
                    data.as_fn().map_fn(|(l, l_prep, r, r_prep)| (l, l_prep.as_ref(), r, r_prep.as_ref())).iter(),
                    &mut actual, 
                    &ring
                );
                assert_all_el_eq(&ring, &expected_doubled, &actual);
            }
        }
    }
}