use std::alloc::{Allocator, Global};
use crate::ring::*;
use crate::seq::subvector::SubvectorView;
use crate::seq::VectorView;
use karatsuba::*;
pub mod karatsuba;
/// Strategy for computing the convolution of two sequences of ring elements.
///
/// For `lhs = (a_0, ..., a_{n-1})` and `rhs = (b_0, ..., b_{m-1})` the convolution is the
/// sequence whose `k`-th entry is `sum_{i + j = k} a_i * b_j`, i.e. the coefficient sequence of
/// the product of the two corresponding polynomials.
///
/// NOTE(review): whether an implementation overwrites `dst` or adds into it, and the minimum
/// length it requires of `dst` (presumably `lhs.len() + rhs.len() - 1`), is decided by the
/// implementation - confirm before relying on either.
#[stability::unstable(feature = "enable")]
pub trait ConvolutionAlgorithm<R: ?Sized + RingBase> {
    /// Computes the convolution of `lhs` and `rhs` over `ring`, writing the result to `dst`.
    fn compute_convolution<V1, V2>(&self, lhs: V1, rhs: V2, dst: &mut [R::Element], ring: &R)
    where
        V1: VectorView<R::Element>,
        V2: VectorView<R::Element>;
}
/// A shared reference to a convolution algorithm is itself a convolution algorithm: every call
/// is forwarded unchanged to the referenced value.
impl<R, C> ConvolutionAlgorithm<R> for &C
where
    R: ?Sized + RingBase,
    C: ConvolutionAlgorithm<R>,
{
    fn compute_convolution<V1, V2>(&self, lhs: V1, rhs: V2, dst: &mut [R::Element], ring: &R)
    where
        V1: VectorView<R::Element>,
        V2: VectorView<R::Element>,
    {
        // `self` is `&&C`; dereferencing once yields the `&C` receiver of the inner impl.
        <C as ConvolutionAlgorithm<R>>::compute_convolution(*self, lhs, rhs, dst, ring)
    }
}
/// Convolution algorithm that delegates to Karatsuba's divide-and-conquer multiplication
/// (`karatsuba`).
///
/// The stored allocator is handed to `karatsuba` on every call; presumably it is used for the
/// algorithm's scratch memory (TODO confirm in `karatsuba`). It defaults to `Global`.
#[stability::unstable(feature = "enable")]
#[derive(Clone, Copy, Debug)]
pub struct KaratsubaAlgorithm<A: Allocator = Global> {
    /// Allocator forwarded to `karatsuba`.
    allocator: A
}
#[stability::unstable(feature = "enable")]
pub const STANDARD_CONVOLUTION: KaratsubaAlgorithm = KaratsubaAlgorithm::new(Global);
impl<A: Allocator> KaratsubaAlgorithm<A> {
#[stability::unstable(feature = "enable")]
pub const fn new(allocator: A) -> Self {
Self { allocator }
}
}
/// Runs `karatsuba` with the threshold the ring suggests through `KaratsubaHint`.
impl<R, A> ConvolutionAlgorithm<R> for KaratsubaAlgorithm<A>
where
    R: ?Sized + RingBase,
    A: Allocator,
{
    fn compute_convolution<V1, V2>(&self, lhs: V1, rhs: V2, dst: &mut [R::Element], ring: &R)
    where
        V1: VectorView<R::Element>,
        V2: VectorView<R::Element>,
    {
        // Per-ring tuning parameter; every ring has one thanks to the blanket `KaratsubaHint` impl.
        let threshold = ring.karatsuba_threshold();
        // `karatsuba` operates on `SubvectorView`s borrowed from the caller's views.
        let lhs_view = SubvectorView::new(&lhs);
        let rhs_view = SubvectorView::new(&rhs);
        karatsuba(threshold, dst, lhs_view, rhs_view, RingRef::new(ring), &self.allocator)
    }
}
/// Per-ring tuning hint for the Karatsuba convolution.
///
/// `KaratsubaAlgorithm` passes the value returned by `karatsuba_threshold` as the first argument
/// of `karatsuba`. Every `RingBase` has a blanket (specializable) implementation, so individual
/// rings only need to override it when a different value performs better.
#[stability::unstable(feature = "enable")]
pub trait KaratsubaHint: RingBase {
/// Threshold handed to `karatsuba` when convolving sequences over this ring.
///
/// NOTE(review): the benchmarks in this file call `karatsuba` with 10 (named "naive") and 4
/// (named "karatsuba") on inputs of length 32, which suggests the value is the base-2 logarithm
/// of the input size below which schoolbook multiplication is used - confirm in `karatsuba`.
fn karatsuba_threshold(&self) -> usize;
}
/// Blanket implementation so that every ring has a Karatsuba threshold; specific rings may
/// override it through specialization.
impl<R> KaratsubaHint for R
where
    R: RingBase + ?Sized,
{
    /// Fallback used by rings that do not specialize this method: always `0`.
    default fn karatsuba_threshold(&self) -> usize {
        0
    }
}
#[cfg(test)]
use test;
#[cfg(test)]
use crate::primitive_int::*;
/// Shared body of the multiplication benchmarks below.
///
/// Repeatedly convolves `0..32` with itself via `karatsuba` using the given `threshold`, and
/// checks two entries of the result on every iteration. `std::hint::black_box` hides the
/// constant inputs from the optimizer so the measured work cannot be folded away.
#[cfg(test)]
fn bench_mul_with_threshold(bencher: &mut test::Bencher, threshold: usize) {
    let a: Vec<i32> = (0..32).collect();
    let b: Vec<i32> = (0..32).collect();
    let mut c: Vec<i32> = vec![0; 64];
    bencher.iter(|| {
        // Start every iteration from an all-zero output buffer.
        // NOTE(review): presumably because `karatsuba` accumulates into `dst` - confirm.
        c.fill(0);
        karatsuba(
            std::hint::black_box(threshold),
            &mut c[..],
            std::hint::black_box(&a[..]),
            std::hint::black_box(&b[..]),
            StaticRing::<i32>::RING,
            &Global,
        );
        // c[31] = sum_{i=0}^{31} i * (31 - i) = 31 * sum(i) - sum(i^2)
        assert_eq!(c[31], 31 * 31 * 32 / 2 - 31 * (31 + 1) * (31 * 2 + 1) / 6);
        // c[62] only receives the product of the two highest coefficients.
        assert_eq!(c[62], 31 * 31);
    });
}

/// Benchmarks `karatsuba` with threshold 10, the configuration this file labels "naive".
#[bench]
fn bench_naive_mul(bencher: &mut test::Bencher) {
    bench_mul_with_threshold(bencher, 10);
}

/// Benchmarks `karatsuba` with threshold 4, the configuration this file labels "karatsuba".
#[bench]
fn bench_karatsuba_mul(bencher: &mut test::Bencher) {
    bench_mul_with_threshold(bencher, 4);
}