use std::alloc::{Allocator, Global};
use std::ops::Deref;
use karatsuba::*;
use crate::ring::*;
use crate::seq::subvector::SubvectorView;
use crate::seq::*;
pub mod karatsuba;
pub mod fft;
pub mod ntt;
pub mod rns;
/// Trait for algorithm objects that compute the convolution of two sequences
/// over a ring `R`, i.e. `dst[k] += sum_{i + j == k} lhs[i] * rhs[j]`.
///
/// Note that every `compute_*` method *adds* the result to `dst` instead of
/// overwriting it; callers that want the plain convolution must zero `dst` first.
pub trait ConvolutionAlgorithm<R: ?Sized + RingBase> {

    /// Algorithm-specific preprocessed form of a single operand, which can be
    /// reused to speed up repeated convolutions with that operand. Defaults to
    /// `()` for algorithms without a meaningful preparation step.
    ///
    /// NOTE: associated type defaults are a nightly-only feature.
    #[stability::unstable(feature = "enable")]
    type PreparedConvolutionOperand = ();

    /// Computes the convolution of `lhs` and `rhs` over `ring` and adds the
    /// result to `dst`. `dst` must be long enough to hold every coefficient
    /// (presumably at least `lhs.len() + rhs.len() - 1` entries — confirm with
    /// the concrete implementation).
    fn compute_convolution<S: RingStore<Type = R> + Copy, V1: VectorView<R::Element>, V2: VectorView<R::Element>>(
        &self,
        lhs: V1,
        rhs: V2,
        dst: &mut [R::Element],
        ring: S,
    );

    /// Returns whether this algorithm supports computing convolutions over the
    /// given ring.
    fn supports_ring<S: RingStore<Type = R> + Copy>(&self, ring: S) -> bool;

    /// Preprocesses `_val` into a [`ConvolutionAlgorithm::PreparedConvolutionOperand`]
    /// that can later be passed to [`ConvolutionAlgorithm::compute_convolution_prepared()`].
    /// `_length_hint`, when given, presumably is the expected length of the
    /// other operand — confirm with the concrete implementation.
    #[stability::unstable(feature = "enable")]
    fn prepare_convolution_operand<S, V>(
        &self,
        _val: V,
        _length_hint: Option<usize>,
        _ring: S,
    ) -> Self::PreparedConvolutionOperand
    where
        S: RingStore<Type = R> + Copy,
        V: VectorView<R::Element>,
    {
        // Specialization trick: this default body can only produce a value of
        // the *default* operand type `()`. The specialized `ProduceValue<()>`
        // impl below returns `()`; if an implementor overrides
        // `PreparedConvolutionOperand` without also overriding this method,
        // the unspecialized `default fn` is selected at monomorphization time
        // and panics with an instructive message instead of failing to compile.
        struct ProduceUnitType;
        trait ProduceValue<T> {
            fn produce() -> T;
        }
        impl<T> ProduceValue<T> for ProduceUnitType {
            default fn produce() -> T {
                panic!(
                    "if you specialize ConvolutionAlgorithm::PreparedConvolutionOperand, you must also specialize ConvolutionAlgorithm::prepare_convolution_operand()"
                )
            }
        }
        impl ProduceValue<()> for ProduceUnitType {
            fn produce() {}
        }
        return <ProduceUnitType as ProduceValue<Self::PreparedConvolutionOperand>>::produce();
    }

    /// Same as [`ConvolutionAlgorithm::compute_convolution()`], but may use the
    /// optionally supplied prepared operands to speed up the computation.
    /// The default implementation ignores the prepared operands.
    #[stability::unstable(feature = "enable")]
    fn compute_convolution_prepared<S, V1, V2>(
        &self,
        lhs: V1,
        _lhs_prep: Option<&Self::PreparedConvolutionOperand>,
        rhs: V2,
        _rhs_prep: Option<&Self::PreparedConvolutionOperand>,
        dst: &mut [R::Element],
        ring: S,
    ) where
        S: RingStore<Type = R> + Copy,
        V1: VectorView<R::Element>,
        V2: VectorView<R::Element>,
    {
        self.compute_convolution(lhs, rhs, dst, ring)
    }

    /// Adds the sum of the convolutions of all given `(lhs, lhs_prep, rhs, rhs_prep)`
    /// pairs to `dst`. The default implementation performs the convolutions one
    /// after another; implementors may be able to share work between the pairs.
    #[stability::unstable(feature = "enable")]
    fn compute_convolution_sum<'a, S, I, V1, V2>(&self, values: I, dst: &mut [R::Element], ring: S)
    where
        S: RingStore<Type = R> + Copy,
        I: ExactSizeIterator<
            Item = (
                V1,
                Option<&'a Self::PreparedConvolutionOperand>,
                V2,
                Option<&'a Self::PreparedConvolutionOperand>,
            ),
        >,
        V1: VectorView<R::Element>,
        V2: VectorView<R::Element>,
        Self: 'a,
        R: 'a,
    {
        for (lhs, lhs_prep, rhs, rhs_prep) in values {
            self.compute_convolution_prepared(lhs, lhs_prep, rhs, rhs_prep, dst, ring)
        }
    }
}
/// Blanket implementation that forwards every method through any
/// dereferenceable wrapper (`&T`, `Box<T>`, `Rc<T>`, ...) to the
/// [`ConvolutionAlgorithm`] it points to.
///
/// Fix: the impl previously declared a lifetime parameter `'a` that was not
/// constrained by the trait reference, self type, or predicates, which is a
/// compile error (E0207); it has been removed.
impl<R, C> ConvolutionAlgorithm<R> for C
where
    R: ?Sized + RingBase,
    C: Deref,
    C::Target: ConvolutionAlgorithm<R>,
{
    // Reuse the pointee's prepared-operand type so prepared operands remain
    // interchangeable between the wrapper and the wrapped algorithm.
    type PreparedConvolutionOperand = <C::Target as ConvolutionAlgorithm<R>>::PreparedConvolutionOperand;

    fn compute_convolution<S: RingStore<Type = R> + Copy, V1: VectorView<R::Element>, V2: VectorView<R::Element>>(
        &self,
        lhs: V1,
        rhs: V2,
        dst: &mut [R::Element],
        ring: S,
    ) {
        (**self).compute_convolution(lhs, rhs, dst, ring)
    }

    fn supports_ring<S: RingStore<Type = R> + Copy>(&self, ring: S) -> bool {
        (**self).supports_ring(ring)
    }

    fn prepare_convolution_operand<S, V>(
        &self,
        val: V,
        len_hint: Option<usize>,
        ring: S,
    ) -> Self::PreparedConvolutionOperand
    where
        S: RingStore<Type = R> + Copy,
        V: VectorView<R::Element>,
    {
        (**self).prepare_convolution_operand(val, len_hint, ring)
    }

    fn compute_convolution_prepared<S, V1, V2>(
        &self,
        lhs: V1,
        lhs_prep: Option<&Self::PreparedConvolutionOperand>,
        rhs: V2,
        rhs_prep: Option<&Self::PreparedConvolutionOperand>,
        dst: &mut [R::Element],
        ring: S,
    ) where
        S: RingStore<Type = R> + Copy,
        V1: VectorView<R::Element>,
        V2: VectorView<R::Element>,
    {
        (**self).compute_convolution_prepared(lhs, lhs_prep, rhs, rhs_prep, dst, ring);
    }

    fn compute_convolution_sum<'b, S, I, V1, V2>(&self, values: I, dst: &mut [R::Element], ring: S)
    where
        S: RingStore<Type = R> + Copy,
        I: ExactSizeIterator<
            Item = (
                V1,
                Option<&'b Self::PreparedConvolutionOperand>,
                V2,
                Option<&'b Self::PreparedConvolutionOperand>,
            ),
        >,
        V1: VectorView<R::Element>,
        V2: VectorView<R::Element>,
        Self: 'b,
        R: 'b,
    {
        (**self).compute_convolution_sum(values, dst, ring);
    }
}
/// [`ConvolutionAlgorithm`] based on the Karatsuba divide-and-conquer
/// algorithm, using the given allocator for temporary buffers.
#[derive(Clone, Copy, Debug)]
pub struct KaratsubaAlgorithm<A: Allocator = Global> {
    // allocator used for the scratch space of the Karatsuba recursion
    allocator: A,
}

/// Ready-to-use default convolution: Karatsuba backed by the global allocator.
pub const STANDARD_CONVOLUTION: KaratsubaAlgorithm = KaratsubaAlgorithm::new(Global);

impl<A: Allocator> KaratsubaAlgorithm<A> {
    /// Creates a new [`KaratsubaAlgorithm`] that allocates temporary memory
    /// via `allocator`. `const` so it can initialize [`STANDARD_CONVOLUTION`].
    #[stability::unstable(feature = "enable")]
    pub const fn new(allocator: A) -> Self { Self { allocator } }
}
impl<R: ?Sized + RingBase, A: Allocator> ConvolutionAlgorithm<R> for KaratsubaAlgorithm<A> {

    /// Adds the convolution of `lhs` and `rhs` to `dst` via Karatsuba's
    /// algorithm; the recursion's base-case cutoff is taken from the ring
    /// through [`KaratsubaHint`].
    fn compute_convolution<S, V1, V2>(&self, lhs: V1, rhs: V2, dst: &mut [R::Element], ring: S)
    where
        S: RingStore<Type = R>,
        V1: VectorView<R::Element>,
        V2: VectorView<R::Element>,
    {
        let threshold = ring.get_ring().karatsuba_threshold();
        karatsuba(
            threshold,
            dst,
            SubvectorView::new(&lhs),
            SubvectorView::new(&rhs),
            &ring,
            &self.allocator,
        )
    }

    /// Karatsuba only requires ring arithmetic, so every ring is supported.
    fn supports_ring<S: RingStore<Type = R> + Copy>(&self, _ring: S) -> bool { true }
}
/// [`ConvolutionAlgorithm`] that always computes convolutions with the
/// quadratic-time schoolbook method, i.e. by summing all products
/// `lhs[i] * rhs[j]` directly.
pub struct SchoolbookConvolution;

impl<R: ?Sized + RingBase> ConvolutionAlgorithm<R> for SchoolbookConvolution {

    /// Adds the convolution of `lhs` and `rhs` to `dst` by naive summation.
    fn compute_convolution<S, V1, V2>(&self, lhs: V1, rhs: V2, dst: &mut [<R as RingBase>::Element], ring: S)
    where
        S: RingStore<Type = R> + Copy,
        V1: VectorView<<R as RingBase>::Element>,
        V2: VectorView<<R as RingBase>::Element>,
    {
        // const parameter `true` presumably selects "add to dst" semantics —
        // confirm against `naive_assign_mul`'s definition
        naive_assign_mul::<_, _, _, _, true>(dst, lhs, rhs, ring)
    }

    /// The schoolbook method works over any ring.
    fn supports_ring<S: RingStore<Type = R> + Copy>(&self, _ring: S) -> bool { true }
}
/// Hint trait that lets rings tune the base-case cutoff used by
/// [`KaratsubaAlgorithm`].
#[stability::unstable(feature = "enable")]
pub trait KaratsubaHint: RingBase {
    /// Returns the threshold below which the Karatsuba recursion should fall
    /// back to the naive schoolbook algorithm. Presumably interpreted as the
    /// log2 of the block size (the benchmarks below use `10` to force naive
    /// multiplication of length-32 inputs and `4` to force Karatsuba) —
    /// confirm against the `karatsuba` function's documentation.
    fn karatsuba_threshold(&self) -> usize;
}

impl<R: RingBase + ?Sized> KaratsubaHint for R {
    // Specializable blanket default: use plain Karatsuba from the smallest
    // block size on; rings may override (`default fn`, nightly specialization)
    // with a larger threshold if their arithmetic makes recursion unprofitable.
    default fn karatsuba_threshold(&self) -> usize { 0 }
}
#[cfg(test)]
use test;
#[cfg(test)]
use crate::primitive_int::*;
#[bench]
fn bench_naive_mul(bencher: &mut test::Bencher) {
    // With a threshold far above the input length, karatsuba() presumably
    // stays in its naive base case throughout (see KaratsubaHint).
    let lhs: Vec<i32> = (0..32).collect();
    let rhs: Vec<i32> = (0..32).collect();
    let mut out: Vec<i32> = (0..64).collect();
    bencher.iter(|| {
        // reset the accumulator to 64 zeros; results are added into it
        out.clear();
        out.resize(64, 0);
        karatsuba(10, &mut out[..], &lhs[..], &rhs[..], StaticRing::<i32>::RING, &Global);
        // spot-check coefficients k = 31 and k = 62 of sum_{i+j=k} i*j
        assert_eq!(out[31], 31 * 31 * 32 / 2 - 31 * (31 + 1) * (31 * 2 + 1) / 6);
        assert_eq!(out[62], 31 * 31);
    });
}
#[bench]
fn bench_karatsuba_mul(bencher: &mut test::Bencher) {
    // Threshold 4 is below the input length, so karatsuba() presumably
    // actually recurses here (contrast bench_naive_mul, threshold 10).
    let lhs: Vec<i32> = (0..32).collect();
    let rhs: Vec<i32> = (0..32).collect();
    let mut out: Vec<i32> = (0..64).collect();
    bencher.iter(|| {
        // reset the accumulator to 64 zeros; results are added into it
        out.clear();
        out.resize(64, 0);
        karatsuba(4, &mut out[..], &lhs[..], &rhs[..], StaticRing::<i32>::RING, &Global);
        // spot-check coefficients k = 31 and k = 62 of sum_{i+j=k} i*j
        assert_eq!(out[31], 31 * 31 * 32 / 2 - 31 * (31 + 1) * (31 * 2 + 1) / 6);
        assert_eq!(out[62], 31 * 31);
    });
}
// Run the generic convolution test suite against the schoolbook algorithm over i64.
#[test]
fn test_schoolbook_convolution() { generic_tests::test_convolution(SchoolbookConvolution, StaticRing::<i64>::RING, 1); }
#[allow(missing_docs)]
#[cfg(any(test, feature = "generic_tests"))]
pub mod generic_tests {
use std::cmp::min;
use super::*;
use crate::homomorphism::*;
pub fn test_convolution<C, R>(convolution: C, ring: R, scale: El<R>)
where
C: ConvolutionAlgorithm<R::Type>,
R: RingStore,
{
for lhs_len in [2, 3, 4, 15] {
for rhs_len in [2, 3, 4, 15, 31, 32, 33] {
let lhs = (0..lhs_len)
.map(|i| ring.mul_ref_snd(ring.int_hom().map(i), &scale))
.collect::<Vec<_>>();
let rhs = (0..rhs_len)
.map(|i| ring.mul_ref_snd(ring.int_hom().map(i), &scale))
.collect::<Vec<_>>();
let expected = (0..(lhs_len + rhs_len))
.map(|i| {
if i < lhs_len + rhs_len {
min(i, lhs_len - 1) * (min(i, lhs_len - 1) + 1) * (3 * i - 2 * min(i, lhs_len - 1) - 1) / 6
- (i - 1 - min(i, rhs_len - 1))
* (i - min(i, rhs_len - 1))
* (i + 2 * min(i, rhs_len - 1) + 1)
/ 6
} else {
0
}
})
.map(|x| ring.mul(ring.int_hom().map(x), ring.pow(ring.clone_el(&scale), 2)))
.collect::<Vec<_>>();
let mut actual = Vec::new();
actual.resize_with((lhs_len + rhs_len) as usize, || ring.zero());
convolution.compute_convolution(&lhs, &rhs, &mut actual, &ring);
for i in 0..(lhs_len + rhs_len) {
assert_el_eq!(&ring, &expected[i as usize], &actual[i as usize]);
}
let expected = (0..(lhs_len + rhs_len))
.map(|i| {
if i < lhs_len + rhs_len {
i * i
+ min(i, lhs_len - 1)
* (min(i, lhs_len - 1) + 1)
* (3 * i - 2 * min(i, lhs_len - 1) - 1)
/ 6
- (i - 1 - min(i, rhs_len - 1))
* (i - min(i, rhs_len - 1))
* (i + 2 * min(i, rhs_len - 1) + 1)
/ 6
} else {
0
}
})
.map(|x| ring.mul(ring.int_hom().map(x), ring.pow(ring.clone_el(&scale), 2)))
.collect::<Vec<_>>();
let mut actual = Vec::new();
actual.extend(
(0..(lhs_len + rhs_len))
.map(|i| ring.mul(ring.int_hom().map(i * i), ring.pow(ring.clone_el(&scale), 2))),
);
convolution.compute_convolution(&lhs, &rhs, &mut actual, &ring);
for i in 0..(lhs_len + rhs_len) {
assert_el_eq!(&ring, &expected[i as usize], &actual[i as usize]);
}
}
}
test_prepared_convolution(convolution, ring, scale);
}
fn test_prepared_convolution<C, R>(convolution: C, ring: R, scale: El<R>)
where
C: ConvolutionAlgorithm<R::Type>,
R: RingStore,
{
for lhs_len in [2, 3, 4, 14, 15] {
for rhs_len in [2, 3, 4, 15, 31, 32, 33] {
let lhs = (0..lhs_len)
.map(|i| ring.mul_ref_snd(ring.int_hom().map(i), &scale))
.collect::<Vec<_>>();
let rhs = (0..rhs_len)
.map(|i| ring.mul_ref_snd(ring.int_hom().map(i), &scale))
.collect::<Vec<_>>();
let expected = (0..(lhs_len + rhs_len))
.map(|i| {
if i < lhs_len + rhs_len {
min(i, lhs_len - 1) * (min(i, lhs_len - 1) + 1) * (3 * i - 2 * min(i, lhs_len - 1) - 1) / 6
- (i - 1 - min(i, rhs_len - 1))
* (i - min(i, rhs_len - 1))
* (i + 2 * min(i, rhs_len - 1) + 1)
/ 6
} else {
0
}
})
.map(|x| ring.mul(ring.int_hom().map(x), ring.pow(ring.clone_el(&scale), 2)))
.collect::<Vec<_>>();
let mut actual = Vec::new();
actual.resize_with((lhs_len + rhs_len) as usize, || ring.zero());
convolution.compute_convolution_prepared(
&lhs,
Some(&convolution.prepare_convolution_operand(&lhs, None, &ring)),
&rhs,
Some(&convolution.prepare_convolution_operand(&rhs, None, &ring)),
&mut actual,
&ring,
);
for i in 0..(lhs_len + rhs_len) {
assert_el_eq!(&ring, &expected[i as usize], &actual[i as usize]);
}
let mut actual = Vec::new();
actual.resize_with((lhs_len + rhs_len) as usize, || ring.zero());
convolution.compute_convolution_prepared(
&lhs,
Some(&convolution.prepare_convolution_operand(&lhs, None, &ring)),
&rhs,
None,
&mut actual,
&ring,
);
for i in 0..(lhs_len + rhs_len) {
assert_el_eq!(&ring, &expected[i as usize], &actual[i as usize]);
}
let mut actual = Vec::new();
actual.resize_with((lhs_len + rhs_len) as usize, || ring.zero());
convolution.compute_convolution_prepared(
&lhs,
None,
&rhs,
Some(&convolution.prepare_convolution_operand(&rhs, None, &ring)),
&mut actual,
&ring,
);
for i in 0..(lhs_len + rhs_len) {
assert_el_eq!(&ring, &expected[i as usize], &actual[i as usize]);
}
let mut actual = Vec::new();
actual.resize_with((lhs_len + rhs_len) as usize, || ring.zero());
let data = [
(
&lhs[..],
Some(convolution.prepare_convolution_operand(&lhs, None, &ring)),
&rhs[..],
Some(convolution.prepare_convolution_operand(&rhs, None, &ring)),
),
(&rhs[..], None, &lhs[..], None),
];
convolution.compute_convolution_sum(
data.as_fn()
.map_fn(|(l, l_prep, r, r_prep): &(_, _, _, _)| (l, l_prep.as_ref(), r, r_prep.as_ref()))
.iter(),
&mut actual,
&ring,
);
for i in 0..(lhs_len + rhs_len) {
assert_el_eq!(
&ring,
&ring.add_ref(&expected[i as usize], &expected[i as usize]),
&actual[i as usize]
);
}
let mut actual = Vec::new();
actual.resize_with((lhs_len + rhs_len) as usize, || ring.zero());
let data = [
(
&lhs[..],
Some(convolution.prepare_convolution_operand(&lhs, None, &ring)),
&rhs[..],
None,
),
(
&rhs[..],
None,
&lhs[..],
Some(convolution.prepare_convolution_operand(&lhs, None, &ring)),
),
];
convolution.compute_convolution_sum(
data.as_fn()
.map_fn(|(l, l_prep, r, r_prep)| (l, l_prep.as_ref(), r, r_prep.as_ref()))
.iter(),
&mut actual,
&ring,
);
for i in 0..(lhs_len + rhs_len) {
assert_el_eq!(
&ring,
&ring.add_ref(&expected[i as usize], &expected[i as usize]),
&actual[i as usize]
);
}
}
}
}
}