feanor_math/algorithms/convolution/
mod.rsuse std::alloc::{Allocator, Global};
use std::ops::Deref;
use crate::ring::*;
use crate::seq::subvector::SubvectorView;
use crate::seq::VectorView;
use karatsuba::*;
pub mod karatsuba;
pub mod fft;
/// An algorithm computing the (discrete) convolution of two sequences of ring elements,
/// i.e. the coefficient sequence of the product of the two polynomials they represent.
pub trait ConvolutionAlgorithm<R: ?Sized + RingBase> {

    /// Computes the convolution of `lhs` and `rhs` over `ring` and stores it in `dst`.
    ///
    /// NOTE(review): the default inner-product methods of [`PreparedConvolutionAlgorithm`]
    /// call convolution routines repeatedly with the same `dst`, which suggests that results
    /// are added to `dst` rather than overwriting it — confirm against the implementations.
    fn compute_convolution<S, V1, V2>(&self, lhs: V1, rhs: V2, dst: &mut [R::Element], ring: S)
        where S: RingStore<Type = R> + Copy,
            V1: VectorView<R::Element>,
            V2: VectorView<R::Element>;

    /// Returns whether this algorithm may be used to compute convolutions over `ring`.
    fn supports_ring<S>(&self, ring: S) -> bool
        where S: RingStore<Type = R> + Copy;
}
#[stability::unstable(feature = "enable")]
pub trait PreparedConvolutionAlgorithm<R: ?Sized + RingBase>: ConvolutionAlgorithm<R> {
type PreparedConvolutionOperand;
fn prepare_convolution_operand<S, V>(&self, val: V, ring: S) -> Self::PreparedConvolutionOperand
where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>;
fn compute_convolution_lhs_prepared<S, V>(&self, lhs: &Self::PreparedConvolutionOperand, rhs: V, dst: &mut [R::Element], ring: S)
where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>;
fn compute_convolution_prepared<S>(&self, lhs: &Self::PreparedConvolutionOperand, rhs: &Self::PreparedConvolutionOperand, dst: &mut [R::Element], ring: S)
where S: RingStore<Type = R> + Copy;
fn compute_convolution_rhs_prepared<S, V>(&self, lhs: V, rhs: &Self::PreparedConvolutionOperand, dst: &mut [R::Element], ring: S)
where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>
{
assert!(ring.is_commutative());
self.compute_convolution_lhs_prepared(rhs, lhs, dst, ring);
}
fn compute_convolution_inner_product_lhs_prepared<'a, S, I, V>(&self, values: I, dst: &mut [R::Element], ring: S)
where S: RingStore<Type = R> + Copy,
I: Iterator<Item = (&'a Self::PreparedConvolutionOperand, V)>,
V: VectorView<R::Element>,
Self::PreparedConvolutionOperand: 'a
{
for (lhs, rhs) in values {
self.compute_convolution_lhs_prepared(lhs, rhs, dst, ring)
}
}
fn compute_convolution_inner_product_prepared<'a, S, I>(&self, values: I, dst: &mut [R::Element], ring: S)
where S: RingStore<Type = R> + Copy,
I: Iterator<Item = (&'a Self::PreparedConvolutionOperand, &'a Self::PreparedConvolutionOperand)>,
Self::PreparedConvolutionOperand: 'a
{
for (lhs, rhs) in values {
self.compute_convolution_prepared(lhs, rhs, dst, ring)
}
}
}
/// Forwards [`ConvolutionAlgorithm`] through any [`Deref`] wrapper (e.g. `&A`, `Box<A>`,
/// `Rc<A>`), so that a borrowed or boxed algorithm can be used wherever an algorithm is
/// expected. Every call is delegated unchanged to the pointee.
impl<R, C> ConvolutionAlgorithm<R> for C
    where R: ?Sized + RingBase,
        C: Deref,
        C::Target: ConvolutionAlgorithm<R>
{
    fn compute_convolution<S: RingStore<Type = R> + Copy, V1: VectorView<R::Element>, V2: VectorView<R::Element>>(&self, lhs: V1, rhs: V2, dst: &mut [R::Element], ring: S) {
        (**self).compute_convolution(lhs, rhs, dst, ring)
    }

    fn supports_ring<S: RingStore<Type = R> + Copy>(&self, ring: S) -> bool {
        (**self).supports_ring(ring)
    }
}
/// Forwards [`PreparedConvolutionAlgorithm`] through any [`Deref`] wrapper (e.g. `&A`,
/// `Box<A>`, `Rc<A>`), reusing the pointee's prepared-operand type.
///
/// Every method is delegated explicitly, including those that have default implementations
/// in the trait, so that overrides in the target type (e.g. a specialized inner-product
/// routine) are used instead of the generic defaults.
impl<R, C> PreparedConvolutionAlgorithm<R> for C
    where R: ?Sized + RingBase,
        C: Deref,
        C::Target: PreparedConvolutionAlgorithm<R>
{
    type PreparedConvolutionOperand = <C::Target as PreparedConvolutionAlgorithm<R>>::PreparedConvolutionOperand;

    fn prepare_convolution_operand<S, V>(&self, val: V, ring: S) -> Self::PreparedConvolutionOperand
        where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>
    {
        (**self).prepare_convolution_operand(val, ring)
    }

    fn compute_convolution_lhs_prepared<S, V>(&self, lhs: &Self::PreparedConvolutionOperand, rhs: V, dst: &mut [R::Element], ring: S)
        where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>
    {
        (**self).compute_convolution_lhs_prepared(lhs, rhs, dst, ring)
    }

    fn compute_convolution_prepared<S>(&self, lhs: &Self::PreparedConvolutionOperand, rhs: &Self::PreparedConvolutionOperand, dst: &mut [R::Element], ring: S)
        where S: RingStore<Type = R> + Copy
    {
        (**self).compute_convolution_prepared(lhs, rhs, dst, ring)
    }

    fn compute_convolution_rhs_prepared<S, V>(&self, lhs: V, rhs: &Self::PreparedConvolutionOperand, dst: &mut [R::Element], ring: S)
        where S: RingStore<Type = R> + Copy, V: VectorView<R::Element>
    {
        (**self).compute_convolution_rhs_prepared(lhs, rhs, dst, ring)
    }

    fn compute_convolution_inner_product_lhs_prepared<'a, S, I, V>(&self, values: I, dst: &mut [R::Element], ring: S)
        where S: RingStore<Type = R> + Copy,
            I: Iterator<Item = (&'a Self::PreparedConvolutionOperand, V)>,
            V: VectorView<R::Element>,
            Self::PreparedConvolutionOperand: 'a
    {
        (**self).compute_convolution_inner_product_lhs_prepared(values, dst, ring)
    }

    fn compute_convolution_inner_product_prepared<'a, S, I>(&self, values: I, dst: &mut [R::Element], ring: S)
        where S: RingStore<Type = R> + Copy,
            I: Iterator<Item = (&'a Self::PreparedConvolutionOperand, &'a Self::PreparedConvolutionOperand)>,
            Self::PreparedConvolutionOperand: 'a
    {
        (**self).compute_convolution_inner_product_prepared(values, dst, ring)
    }
}
/// Convolution algorithm backed by [`karatsuba`].
///
/// The stored allocator is handed to `karatsuba` on every convolution; defaults to [`Global`].
#[derive(Clone, Copy, Debug)]
pub struct KaratsubaAlgorithm<A: Allocator = Global> {
    allocator: A
}

/// The default convolution algorithm: [`KaratsubaAlgorithm`] using the [`Global`] allocator.
pub const STANDARD_CONVOLUTION: KaratsubaAlgorithm = KaratsubaAlgorithm::new(Global);
impl<A: Allocator> KaratsubaAlgorithm<A> {
#[stability::unstable(feature = "enable")]
pub const fn new(allocator: A) -> Self {
Self { allocator }
}
}
impl<R, A> ConvolutionAlgorithm<R> for KaratsubaAlgorithm<A>
    where R: ?Sized + RingBase,
        A: Allocator
{
    /// Delegates to [`karatsuba`]. The threshold comes from the ring's
    /// [`KaratsubaHint::karatsuba_threshold`], and this algorithm's allocator is passed along.
    fn compute_convolution<S: RingStore<Type = R>, V1: VectorView<<R as RingBase>::Element>, V2: VectorView<<R as RingBase>::Element>>(&self, lhs: V1, rhs: V2, dst: &mut[<R as RingBase>::Element], ring: S) {
        let threshold = ring.get_ring().karatsuba_threshold();
        let lhs_view = SubvectorView::new(&lhs);
        let rhs_view = SubvectorView::new(&rhs);
        karatsuba(threshold, dst, lhs_view, rhs_view, &ring, &self.allocator)
    }

    /// Always `true`: this implementation accepts every ring.
    fn supports_ring<S: RingStore<Type = R> + Copy>(&self, _ring: S) -> bool {
        true
    }
}
/// Per-ring tuning hint for [`KaratsubaAlgorithm`].
///
/// The returned value is passed as the first (threshold) argument of [`karatsuba`] in
/// `KaratsubaAlgorithm::compute_convolution`. Every [`RingBase`] implements this trait via a
/// blanket impl returning `0`; individual rings can override it through specialization.
///
/// NOTE(review): the precise meaning of the threshold is defined by `karatsuba` — presumably
/// a cutoff below which it falls back to schoolbook multiplication; confirm there.
#[stability::unstable(feature = "enable")]
pub trait KaratsubaHint: RingBase {
    /// Threshold handed to [`karatsuba`] when convolving over this ring.
    fn karatsuba_threshold(&self) -> usize;
}
/// Fallback implementation for every ring; marked `default` so that specific rings
/// can provide their own threshold via specialization.
impl<R> KaratsubaHint for R
    where R: ?Sized + RingBase
{
    /// Returns `0` unless a more specific impl overrides it.
    default fn karatsuba_threshold(&self) -> usize {
        0
    }
}
#[cfg(test)]
use test;
#[cfg(test)]
use crate::primitive_int::*;
/// Benchmarks [`karatsuba`] on two length-32 `i32` sequences `0, 1, ..., 31` with
/// threshold `10` (the "naive" configuration, as opposed to threshold `4` in
/// `bench_karatsuba_mul`).
#[bench]
fn bench_naive_mul(bencher: &mut test::Bencher) {
    let a: Vec<i32> = (0..32).collect();
    let b: Vec<i32> = (0..32).collect();
    // Output buffer reused across iterations; the product has 63 coefficients, so 64 slots suffice.
    let mut c: Vec<i32> = vec![0; 64];
    bencher.iter(|| {
        // Reset the output before every call, exactly like the original `clear` + `resize`
        // (the length never changes, so filling in place is equivalent).
        c.fill(0);
        // `black_box` hides the constant inputs from the optimizer so the work is really measured.
        karatsuba(10, &mut c[..], std::hint::black_box(&a[..]), std::hint::black_box(&b[..]), StaticRing::<i32>::RING, &Global);
        // c[31] = sum_{i=0}^{31} i * (31 - i) = 31 * sum(i) - sum(i^2)
        assert_eq!(c[31], 31 * 31 * 32 / 2 - 31 * (31 + 1) * (31 * 2 + 1) / 6);
        // c[62] = a[31] * b[31]
        assert_eq!(c[62], 31 * 31);
    });
}
/// Benchmarks [`karatsuba`] on two length-32 `i32` sequences `0, 1, ..., 31` with
/// threshold `4` (the "Karatsuba" configuration, as opposed to threshold `10` in
/// `bench_naive_mul`).
#[bench]
fn bench_karatsuba_mul(bencher: &mut test::Bencher) {
    let a: Vec<i32> = (0..32).collect();
    let b: Vec<i32> = (0..32).collect();
    // Output buffer reused across iterations; the product has 63 coefficients, so 64 slots suffice.
    let mut c: Vec<i32> = vec![0; 64];
    bencher.iter(|| {
        // Reset the output before every call, exactly like the original `clear` + `resize`
        // (the length never changes, so filling in place is equivalent).
        c.fill(0);
        // `black_box` hides the constant inputs from the optimizer so the work is really measured.
        karatsuba(4, &mut c[..], std::hint::black_box(&a[..]), std::hint::black_box(&b[..]), StaticRing::<i32>::RING, &Global);
        // c[31] = sum_{i=0}^{31} i * (31 - i) = 31 * sum(i) - sum(i^2)
        assert_eq!(c[31], 31 * 31 * 32 / 2 - 31 * (31 + 1) * (31 * 2 + 1) / 6);
        // c[62] = a[31] * b[31]
        assert_eq!(c[62], 31 * 31);
    });
}