1
2use strassen::*;
3
4use crate::matrix::*;
5use crate::ring::*;
6
7use std::alloc::Allocator;
8use std::alloc::Global;
9use std::ops::Deref;
10
11pub mod strassen;
12
13#[stability::unstable(feature = "enable")]
18pub trait ComputeInnerProduct: RingBase {
19
20 fn inner_product_ref<'a, I: Iterator<Item = (&'a Self::Element, &'a Self::Element)>>(&self, els: I) -> Self::Element
24 where Self::Element: 'a,
25 Self: 'a;
26
27 fn inner_product_ref_fst<'a, I: Iterator<Item = (&'a Self::Element, Self::Element)>>(&self, els: I) -> Self::Element
31 where Self::Element: 'a,
32 Self: 'a;
33
34 fn inner_product<I: Iterator<Item = (Self::Element, Self::Element)>>(&self, els: I) -> Self::Element;
38}
39
40impl<R: ?Sized + RingBase> ComputeInnerProduct for R {
41
42 default fn inner_product_ref_fst<'a, I: Iterator<Item = (&'a Self::Element, Self::Element)>>(&self, els: I) -> Self::Element
43 where Self::Element: 'a
44 {
45 self.inner_product(els.map(|(l, r)| (self.clone_el(l), r)))
46 }
47
48 default fn inner_product_ref<'a, I: Iterator<Item = (&'a Self::Element, &'a Self::Element)>>(&self, els: I) -> Self::Element
49 where Self::Element: 'a
50 {
51 self.inner_product_ref_fst(els.map(|(l, r)| (l, self.clone_el(r))))
52 }
53
54 default fn inner_product<I: Iterator<Item = (Self::Element, Self::Element)>>(&self, els: I) -> Self::Element {
55 self.sum(els.map(|(l, r)| self.mul(l, r)))
56 }
57}
58
59#[stability::unstable(feature = "enable")]
63pub trait MatmulAlgorithm<R: ?Sized + RingBase> {
64
65 fn add_matmul<S, V1, V2, V3, const T1: bool, const T2: bool, const T3: bool>(&self, lhs: TransposableSubmatrix<V1, R::Element, T1>, rhs: TransposableSubmatrix<V2, R::Element, T2>, dst: TransposableSubmatrixMut<V3, R::Element, T3>, ring: S)
73 where V1: AsPointerToSlice<R::Element>,
74 V2: AsPointerToSlice<R::Element>,
75 V3: AsPointerToSlice<R::Element>,
76 S: RingStore<Type = R> + Copy;
77
78 fn matmul<S, V1, V2, V3, const T1: bool, const T2: bool, const T3: bool>(&self, lhs: TransposableSubmatrix<V1, R::Element, T1>, rhs: TransposableSubmatrix<V2, R::Element, T2>, mut dst: TransposableSubmatrixMut<V3, R::Element, T3>, ring: S)
86 where V1: AsPointerToSlice<R::Element>,
87 V2: AsPointerToSlice<R::Element>,
88 V3: AsPointerToSlice<R::Element>,
89 S: RingStore<Type = R> + Copy
90 {
91 for i in 0..dst.row_count() {
92 for j in 0..dst.col_count() {
93 *dst.at_mut(i, j) = ring.zero();
94 }
95 }
96 self.add_matmul(lhs, rhs, dst, ring);
97 }
98}
99
100impl<R, T> MatmulAlgorithm<R> for T
101 where R: ?Sized + RingBase,
102 T: Deref,
103 T::Target: MatmulAlgorithm<R>
104{
105 fn add_matmul<S, V1, V2, V3, const T1: bool, const T2: bool, const T3: bool>(&self, lhs: TransposableSubmatrix<V1, R::Element, T1>, rhs: TransposableSubmatrix<V2, R::Element, T2>, dst: TransposableSubmatrixMut<V3, R::Element, T3>, ring: S)
106 where V1: AsPointerToSlice<R::Element>,
107 V2: AsPointerToSlice<R::Element>,
108 V3: AsPointerToSlice<R::Element>,
109 S: RingStore<Type = R> + Copy
110 {
111 (**self).add_matmul(lhs, rhs, dst, ring)
112 }
113
114 fn matmul<S, V1, V2, V3, const T1: bool, const T2: bool, const T3: bool>(&self, lhs: TransposableSubmatrix<V1, R::Element, T1>, rhs: TransposableSubmatrix<V2, R::Element, T2>, dst: TransposableSubmatrixMut<V3, R::Element, T3>, ring: S)
115 where V1: AsPointerToSlice<R::Element>,
116 V2: AsPointerToSlice<R::Element>,
117 V3: AsPointerToSlice<R::Element>,
118 S: RingStore<Type = R> + Copy
119 {
120 (**self).matmul(lhs, rhs, dst, ring)
121 }
122}
123
124#[stability::unstable(feature = "enable")]
129pub trait StrassenHint: RingBase {
130
131 fn strassen_threshold(&self) -> usize;
142}
143
144impl<R: RingBase + ?Sized> StrassenHint for R {
145
146 default fn strassen_threshold(&self) -> usize {
147 0
148 }
149}
150
151#[stability::unstable(feature = "enable")]
152pub const STANDARD_MATMUL: StrassenAlgorithm = StrassenAlgorithm::new(Global);
153
154#[stability::unstable(feature = "enable")]
155#[derive(Clone, Copy)]
156pub struct StrassenAlgorithm<A: Allocator = Global> {
157 allocator: A
158}
159
160impl<A: Allocator> StrassenAlgorithm<A> {
161
162 #[stability::unstable(feature = "enable")]
163 pub const fn new(allocator: A) -> Self {
164 Self { allocator }
165 }
166}
167
168impl<R: ?Sized + RingBase, A: Allocator> MatmulAlgorithm<R> for StrassenAlgorithm<A> {
169
170 fn add_matmul<S, V1, V2, V3, const T1: bool, const T2: bool, const T3: bool>(
171 &self,
172 lhs: TransposableSubmatrix<V1, R::Element, T1>,
173 rhs: TransposableSubmatrix<V2, R::Element, T2>,
174 dst: TransposableSubmatrixMut<V3, R::Element, T3>,
175 ring: S
176 )
177 where V1: AsPointerToSlice<R::Element>,
178 V2: AsPointerToSlice<R::Element>,
179 V3: AsPointerToSlice<R::Element>,
180 S: RingStore<Type = R> + Copy
181 {
182 strassen::<_, _, _, _, _, T1, T2, T3>(true, <_ as StrassenHint>::strassen_threshold(ring.get_ring()), lhs, rhs, dst, ring, &self.allocator)
183 }
184
185 fn matmul<S, V1, V2, V3, const T1: bool, const T2: bool, const T3: bool>(
186 &self,
187 lhs: TransposableSubmatrix<V1, R::Element, T1>,
188 rhs: TransposableSubmatrix<V2, R::Element, T2>,
189 dst: TransposableSubmatrixMut<V3, R::Element, T3>,
190 ring: S
191 )
192 where V1: AsPointerToSlice<R::Element>,
193 V2: AsPointerToSlice<R::Element>,
194 V3: AsPointerToSlice<R::Element>,
195 S: RingStore<Type = R> + Copy
196 {
197 strassen::<_, _, _, _, _, T1, T2, T3>(false, <_ as StrassenHint>::strassen_threshold(ring.get_ring()), lhs, rhs, dst, ring, &self.allocator)
198 }
199}
200
201#[cfg(test)]
202use test;
203#[cfg(test)]
204use crate::primitive_int::*;
205
/// Number of rows and of columns of the square matrices used by the
/// benchmarks in this file.
#[cfg(test)]
const BENCH_SIZE: usize = 128;
/// Entry type of the matrices used by the benchmarks in this file.
#[cfg(test)]
type BenchInt = i64;
210
211#[bench]
212fn bench_naive_matmul(bencher: &mut test::Bencher) {
213 let lhs = OwnedMatrix::from_fn_in(BENCH_SIZE, BENCH_SIZE, |i, j| std::hint::black_box(i as BenchInt + j as BenchInt), Global);
214 let rhs = OwnedMatrix::from_fn_in(BENCH_SIZE, BENCH_SIZE, |i, j| std::hint::black_box(i as BenchInt + j as BenchInt), Global);
215 let mut result: OwnedMatrix<BenchInt> = OwnedMatrix::zero(BENCH_SIZE, BENCH_SIZE, StaticRing::<BenchInt>::RING);
216 bencher.iter(|| {
217 strassen::<_, _, _, _, _, false, false, false>(
218 false,
219 100,
220 TransposableSubmatrix::from(lhs.data()),
221 TransposableSubmatrix::from(rhs.data()),
222 TransposableSubmatrixMut::from(result.data_mut()),
223 StaticRing::<BenchInt>::RING,
224 &Global
225 );
226 assert_eq!((BENCH_SIZE * (BENCH_SIZE + 1) * (BENCH_SIZE * 2 + 1) / 6 - BENCH_SIZE * BENCH_SIZE) as BenchInt, *result.at(0, 0));
227 });
228}
229
230#[bench]
231fn bench_strassen_matmul(bencher: &mut test::Bencher) {
232 let threshold_log_2 = 4;
233 let lhs = OwnedMatrix::from_fn_in(BENCH_SIZE, BENCH_SIZE, |i, j| std::hint::black_box(i as BenchInt + j as BenchInt), Global);
234 let rhs = OwnedMatrix::from_fn_in(BENCH_SIZE, BENCH_SIZE, |i, j| std::hint::black_box(i as BenchInt + j as BenchInt), Global);
235 let mut result: OwnedMatrix<BenchInt> = OwnedMatrix::zero(BENCH_SIZE, BENCH_SIZE, StaticRing::<BenchInt>::RING);
236 bencher.iter(|| {
237 strassen::<_, _, _, _, _, false, false, false>(
238 false,
239 threshold_log_2,
240 TransposableSubmatrix::from(lhs.data()),
241 TransposableSubmatrix::from(rhs.data()),
242 TransposableSubmatrixMut::from(result.data_mut()),
243 StaticRing::<BenchInt>::RING,
244 &Global
245 );
246 assert_eq!((BENCH_SIZE * (BENCH_SIZE + 1) * (BENCH_SIZE * 2 + 1) / 6 - BENCH_SIZE * BENCH_SIZE) as BenchInt, *result.at(0, 0));
247 });
248}