use std::alloc::{Allocator, Global};
use std::ops::Deref;
use strassen::*;
use crate::matrix::*;
use crate::ring::*;
pub mod strassen;
#[stability::unstable(feature = "enable")]
pub trait ComputeInnerProduct: RingBase {
fn inner_product_ref<'a, I: Iterator<Item = (&'a Self::Element, &'a Self::Element)>>(
&self,
els: I,
) -> Self::Element
where
Self::Element: 'a,
Self: 'a;
fn inner_product_ref_fst<'a, I: Iterator<Item = (&'a Self::Element, Self::Element)>>(
&self,
els: I,
) -> Self::Element
where
Self::Element: 'a,
Self: 'a;
fn inner_product<I: Iterator<Item = (Self::Element, Self::Element)>>(&self, els: I) -> Self::Element;
}
impl<R: ?Sized + RingBase> ComputeInnerProduct for R {
default fn inner_product_ref_fst<'a, I: Iterator<Item = (&'a Self::Element, Self::Element)>>(
&self,
els: I,
) -> Self::Element
where
Self::Element: 'a,
{
let mut result = self.zero();
for (l, r) in els {
result = self.fma(l, &r, result);
}
return result;
}
default fn inner_product_ref<'a, I: Iterator<Item = (&'a Self::Element, &'a Self::Element)>>(
&self,
els: I,
) -> Self::Element
where
Self::Element: 'a,
{
let mut result = self.zero();
for (l, r) in els {
result = self.fma(l, r, result);
}
return result;
}
default fn inner_product<I: Iterator<Item = (Self::Element, Self::Element)>>(&self, els: I) -> Self::Element {
let mut result = self.zero();
for (l, r) in els {
result = self.fma(&l, &r, result);
}
return result;
}
}
#[stability::unstable(feature = "enable")]
pub trait MatmulAlgorithm<R: ?Sized + RingBase> {
fn add_matmul<S, V1, V2, V3, const T1: bool, const T2: bool, const T3: bool>(
&self,
lhs: TransposableSubmatrix<V1, R::Element, T1>,
rhs: TransposableSubmatrix<V2, R::Element, T2>,
dst: TransposableSubmatrixMut<V3, R::Element, T3>,
ring: S,
) where
V1: AsPointerToSlice<R::Element>,
V2: AsPointerToSlice<R::Element>,
V3: AsPointerToSlice<R::Element>,
S: RingStore<Type = R> + Copy;
fn matmul<S, V1, V2, V3, const T1: bool, const T2: bool, const T3: bool>(
&self,
lhs: TransposableSubmatrix<V1, R::Element, T1>,
rhs: TransposableSubmatrix<V2, R::Element, T2>,
mut dst: TransposableSubmatrixMut<V3, R::Element, T3>,
ring: S,
) where
V1: AsPointerToSlice<R::Element>,
V2: AsPointerToSlice<R::Element>,
V3: AsPointerToSlice<R::Element>,
S: RingStore<Type = R> + Copy,
{
for i in 0..dst.row_count() {
for j in 0..dst.col_count() {
*dst.at_mut(i, j) = ring.zero();
}
}
self.add_matmul(lhs, rhs, dst, ring);
}
}
impl<R, T> MatmulAlgorithm<R> for T
where
R: ?Sized + RingBase,
T: Deref,
T::Target: MatmulAlgorithm<R>,
{
fn add_matmul<S, V1, V2, V3, const T1: bool, const T2: bool, const T3: bool>(
&self,
lhs: TransposableSubmatrix<V1, R::Element, T1>,
rhs: TransposableSubmatrix<V2, R::Element, T2>,
dst: TransposableSubmatrixMut<V3, R::Element, T3>,
ring: S,
) where
V1: AsPointerToSlice<R::Element>,
V2: AsPointerToSlice<R::Element>,
V3: AsPointerToSlice<R::Element>,
S: RingStore<Type = R> + Copy,
{
(**self).add_matmul(lhs, rhs, dst, ring)
}
fn matmul<S, V1, V2, V3, const T1: bool, const T2: bool, const T3: bool>(
&self,
lhs: TransposableSubmatrix<V1, R::Element, T1>,
rhs: TransposableSubmatrix<V2, R::Element, T2>,
dst: TransposableSubmatrixMut<V3, R::Element, T3>,
ring: S,
) where
V1: AsPointerToSlice<R::Element>,
V2: AsPointerToSlice<R::Element>,
V3: AsPointerToSlice<R::Element>,
S: RingStore<Type = R> + Copy,
{
(**self).matmul(lhs, rhs, dst, ring)
}
}
#[stability::unstable(feature = "enable")]
pub trait StrassenHint: RingBase {
fn strassen_threshold(&self) -> usize;
}
impl<R: RingBase + ?Sized> StrassenHint for R {
default fn strassen_threshold(&self) -> usize { 0 }
}
#[stability::unstable(feature = "enable")]
pub const STANDARD_MATMUL: StrassenAlgorithm = StrassenAlgorithm::new(Global);
#[stability::unstable(feature = "enable")]
#[derive(Clone, Copy)]
pub struct StrassenAlgorithm<A: Allocator = Global> {
allocator: A,
}
impl<A: Allocator> StrassenAlgorithm<A> {
#[stability::unstable(feature = "enable")]
pub const fn new(allocator: A) -> Self { Self { allocator } }
}
impl<R: ?Sized + RingBase, A: Allocator> MatmulAlgorithm<R> for StrassenAlgorithm<A> {
fn add_matmul<S, V1, V2, V3, const T1: bool, const T2: bool, const T3: bool>(
&self,
lhs: TransposableSubmatrix<V1, R::Element, T1>,
rhs: TransposableSubmatrix<V2, R::Element, T2>,
dst: TransposableSubmatrixMut<V3, R::Element, T3>,
ring: S,
) where
V1: AsPointerToSlice<R::Element>,
V2: AsPointerToSlice<R::Element>,
V3: AsPointerToSlice<R::Element>,
S: RingStore<Type = R> + Copy,
{
strassen::<_, _, _, _, _, T1, T2, T3>(
true,
<_ as StrassenHint>::strassen_threshold(ring.get_ring()),
lhs,
rhs,
dst,
ring,
&self.allocator,
)
}
fn matmul<S, V1, V2, V3, const T1: bool, const T2: bool, const T3: bool>(
&self,
lhs: TransposableSubmatrix<V1, R::Element, T1>,
rhs: TransposableSubmatrix<V2, R::Element, T2>,
dst: TransposableSubmatrixMut<V3, R::Element, T3>,
ring: S,
) where
V1: AsPointerToSlice<R::Element>,
V2: AsPointerToSlice<R::Element>,
V3: AsPointerToSlice<R::Element>,
S: RingStore<Type = R> + Copy,
{
strassen::<_, _, _, _, _, T1, T2, T3>(
false,
<_ as StrassenHint>::strassen_threshold(ring.get_ring()),
lhs,
rhs,
dst,
ring,
&self.allocator,
)
}
}
#[stability::unstable(feature = "enable")]
pub fn naive_matmul<R, V1, V2, V3, const ADD_ASSIGN: bool, const T1: bool, const T2: bool, const T3: bool>(
lhs: TransposableSubmatrix<V1, El<R>, T1>,
rhs: TransposableSubmatrix<V2, El<R>, T2>,
mut dst: TransposableSubmatrixMut<V3, El<R>, T3>,
ring: R,
) where
R: RingStore + Copy,
V1: AsPointerToSlice<El<R>>,
V2: AsPointerToSlice<El<R>>,
V3: AsPointerToSlice<El<R>>,
{
assert_eq!(lhs.row_count(), dst.row_count());
assert_eq!(rhs.col_count(), dst.col_count());
assert_eq!(lhs.col_count(), rhs.row_count());
for i in 0..lhs.row_count() {
for j in 0..rhs.col_count() {
let inner_prod = <_ as ComputeInnerProduct>::inner_product_ref(
ring.get_ring(),
(0..lhs.col_count()).map(|k| (lhs.at(i, k), rhs.at(k, j))),
);
if ADD_ASSIGN {
ring.add_assign(dst.at_mut(i, j), inner_prod);
} else {
*dst.at_mut(i, j) = inner_prod;
}
}
}
}
#[cfg(test)]
use test;
#[cfg(test)]
use crate::primitive_int::*;
#[cfg(test)]
const BENCH_SIZE: usize = 128;
#[cfg(test)]
type BenchInt = i64;
#[bench]
fn bench_naive_matmul(bencher: &mut test::Bencher) {
let lhs = OwnedMatrix::from_fn_in(
BENCH_SIZE,
BENCH_SIZE,
|i, j| std::hint::black_box(i as BenchInt + j as BenchInt),
Global,
);
let rhs = OwnedMatrix::from_fn_in(
BENCH_SIZE,
BENCH_SIZE,
|i, j| std::hint::black_box(i as BenchInt + j as BenchInt),
Global,
);
let mut result: OwnedMatrix<BenchInt> = OwnedMatrix::zero(BENCH_SIZE, BENCH_SIZE, StaticRing::<BenchInt>::RING);
bencher.iter(|| {
strassen::<_, _, _, _, _, false, false, false>(
false,
100,
std::hint::black_box(TransposableSubmatrix::from(lhs.data())),
std::hint::black_box(TransposableSubmatrix::from(rhs.data())),
std::hint::black_box(TransposableSubmatrixMut::from(result.data_mut())),
StaticRing::<BenchInt>::RING,
&Global,
);
assert_eq!(
(BENCH_SIZE * (BENCH_SIZE + 1) * (BENCH_SIZE * 2 + 1) / 6 - BENCH_SIZE * BENCH_SIZE) as BenchInt,
*result.at(0, 0)
);
});
}
#[bench]
fn bench_strassen_matmul(bencher: &mut test::Bencher) {
let threshold_log_2 = 4;
let lhs = OwnedMatrix::from_fn_in(
BENCH_SIZE,
BENCH_SIZE,
|i, j| std::hint::black_box(i as BenchInt + j as BenchInt),
Global,
);
let rhs = OwnedMatrix::from_fn_in(
BENCH_SIZE,
BENCH_SIZE,
|i, j| std::hint::black_box(i as BenchInt + j as BenchInt),
Global,
);
let mut result: OwnedMatrix<BenchInt> = OwnedMatrix::zero(BENCH_SIZE, BENCH_SIZE, StaticRing::<BenchInt>::RING);
bencher.iter(|| {
strassen::<_, _, _, _, _, false, false, false>(
false,
threshold_log_2,
std::hint::black_box(TransposableSubmatrix::from(lhs.data())),
std::hint::black_box(TransposableSubmatrix::from(rhs.data())),
std::hint::black_box(TransposableSubmatrixMut::from(result.data_mut())),
StaticRing::<BenchInt>::RING,
&Global,
);
assert_eq!(
(BENCH_SIZE * (BENCH_SIZE + 1) * (BENCH_SIZE * 2 + 1) / 6 - BENCH_SIZE * BENCH_SIZE) as BenchInt,
*result.at(0, 0)
);
});
}