1
2use strassen::*;
3
4use crate::matrix::*;
5use crate::ring::*;
6
7use std::alloc::Allocator;
8use std::alloc::Global;
9use std::ops::Deref;
10
11pub mod strassen;
16
17#[stability::unstable(feature = "enable")]
22pub trait ComputeInnerProduct: RingBase {
23
24 fn inner_product_ref<'a, I: Iterator<Item = (&'a Self::Element, &'a Self::Element)>>(&self, els: I) -> Self::Element
28 where Self::Element: 'a,
29 Self: 'a;
30
31 fn inner_product_ref_fst<'a, I: Iterator<Item = (&'a Self::Element, Self::Element)>>(&self, els: I) -> Self::Element
35 where Self::Element: 'a,
36 Self: 'a;
37
38 fn inner_product<I: Iterator<Item = (Self::Element, Self::Element)>>(&self, els: I) -> Self::Element;
42}
43
44impl<R: ?Sized + RingBase> ComputeInnerProduct for R {
45
46 default fn inner_product_ref_fst<'a, I: Iterator<Item = (&'a Self::Element, Self::Element)>>(&self, els: I) -> Self::Element
47 where Self::Element: 'a
48 {
49 let mut result = self.zero();
50 for (l, r) in els {
51 result = self.fma(l, &r, result);
52 }
53 return result;
54 }
55
56 default fn inner_product_ref<'a, I: Iterator<Item = (&'a Self::Element, &'a Self::Element)>>(&self, els: I) -> Self::Element
57 where Self::Element: 'a
58 {
59 let mut result = self.zero();
60 for (l, r) in els {
61 result = self.fma(l, r, result);
62 }
63 return result;
64 }
65
66 default fn inner_product<I: Iterator<Item = (Self::Element, Self::Element)>>(&self, els: I) -> Self::Element {
67 let mut result = self.zero();
68 for (l, r) in els {
69 result = self.fma(&l, &r, result);
70 }
71 return result;
72 }
73}
74
75#[stability::unstable(feature = "enable")]
79pub trait MatmulAlgorithm<R: ?Sized + RingBase> {
80
81 fn add_matmul<S, V1, V2, V3, const T1: bool, const T2: bool, const T3: bool>(&self, lhs: TransposableSubmatrix<V1, R::Element, T1>, rhs: TransposableSubmatrix<V2, R::Element, T2>, dst: TransposableSubmatrixMut<V3, R::Element, T3>, ring: S)
89 where V1: AsPointerToSlice<R::Element>,
90 V2: AsPointerToSlice<R::Element>,
91 V3: AsPointerToSlice<R::Element>,
92 S: RingStore<Type = R> + Copy;
93
94 fn matmul<S, V1, V2, V3, const T1: bool, const T2: bool, const T3: bool>(&self, lhs: TransposableSubmatrix<V1, R::Element, T1>, rhs: TransposableSubmatrix<V2, R::Element, T2>, mut dst: TransposableSubmatrixMut<V3, R::Element, T3>, ring: S)
102 where V1: AsPointerToSlice<R::Element>,
103 V2: AsPointerToSlice<R::Element>,
104 V3: AsPointerToSlice<R::Element>,
105 S: RingStore<Type = R> + Copy
106 {
107 for i in 0..dst.row_count() {
108 for j in 0..dst.col_count() {
109 *dst.at_mut(i, j) = ring.zero();
110 }
111 }
112 self.add_matmul(lhs, rhs, dst, ring);
113 }
114}
115
116impl<R, T> MatmulAlgorithm<R> for T
117 where R: ?Sized + RingBase,
118 T: Deref,
119 T::Target: MatmulAlgorithm<R>
120{
121 fn add_matmul<S, V1, V2, V3, const T1: bool, const T2: bool, const T3: bool>(&self, lhs: TransposableSubmatrix<V1, R::Element, T1>, rhs: TransposableSubmatrix<V2, R::Element, T2>, dst: TransposableSubmatrixMut<V3, R::Element, T3>, ring: S)
122 where V1: AsPointerToSlice<R::Element>,
123 V2: AsPointerToSlice<R::Element>,
124 V3: AsPointerToSlice<R::Element>,
125 S: RingStore<Type = R> + Copy
126 {
127 (**self).add_matmul(lhs, rhs, dst, ring)
128 }
129
130 fn matmul<S, V1, V2, V3, const T1: bool, const T2: bool, const T3: bool>(&self, lhs: TransposableSubmatrix<V1, R::Element, T1>, rhs: TransposableSubmatrix<V2, R::Element, T2>, dst: TransposableSubmatrixMut<V3, R::Element, T3>, ring: S)
131 where V1: AsPointerToSlice<R::Element>,
132 V2: AsPointerToSlice<R::Element>,
133 V3: AsPointerToSlice<R::Element>,
134 S: RingStore<Type = R> + Copy
135 {
136 (**self).matmul(lhs, rhs, dst, ring)
137 }
138}
139
///
/// Trait through which a ring tells [`StrassenAlgorithm`] which threshold to pass to
/// `strassen()` when multiplying matrices over this ring.
///
/// NOTE(review): the exact meaning of the threshold is defined by `strassen()`, which is
/// not visible here. The benchmarks in this file name the corresponding argument
/// `threshold_log_2` and use a large value (`100`) for "naive" multiplication. This
/// suggests the threshold is the base-2 logarithm of the block size below which
/// Strassen's recursion stops. Confirm against `strassen()`.
///
#[stability::unstable(feature = "enable")]
pub trait StrassenHint: RingBase {

    ///
    /// Returns the threshold that [`StrassenAlgorithm`] passes to `strassen()` for this ring.
    ///
    /// Every ring gets a default through a blanket `default fn` implementation. Rings can
    /// override it via specialization.
    ///
    fn strassen_threshold(&self) -> usize;
}
159
///
/// Default threshold for every ring that does not specialize `strassen_threshold()`.
///
/// NOTE(review): if the threshold is a log2 block size below which naive multiplication
/// takes over, as the `threshold_log_2` naming in the benchmarks suggests, then `0` makes
/// `STANDARD_MATMUL` recurse down to 1x1 blocks for every such ring. That is usually much
/// slower than naive multiplication, so confirm against `strassen()` that this default is
/// intended.
///
impl<R: RingBase + ?Sized> StrassenHint for R {

    default fn strassen_threshold(&self) -> usize {
        0
    }
}
166
167#[stability::unstable(feature = "enable")]
168pub const STANDARD_MATMUL: StrassenAlgorithm = StrassenAlgorithm::new(Global);
169
170#[stability::unstable(feature = "enable")]
171#[derive(Clone, Copy)]
172pub struct StrassenAlgorithm<A: Allocator = Global> {
173 allocator: A
174}
175
176impl<A: Allocator> StrassenAlgorithm<A> {
177
178 #[stability::unstable(feature = "enable")]
179 pub const fn new(allocator: A) -> Self {
180 Self { allocator }
181 }
182}
183
184impl<R: ?Sized + RingBase, A: Allocator> MatmulAlgorithm<R> for StrassenAlgorithm<A> {
185
186 fn add_matmul<S, V1, V2, V3, const T1: bool, const T2: bool, const T3: bool>(
187 &self,
188 lhs: TransposableSubmatrix<V1, R::Element, T1>,
189 rhs: TransposableSubmatrix<V2, R::Element, T2>,
190 dst: TransposableSubmatrixMut<V3, R::Element, T3>,
191 ring: S
192 )
193 where V1: AsPointerToSlice<R::Element>,
194 V2: AsPointerToSlice<R::Element>,
195 V3: AsPointerToSlice<R::Element>,
196 S: RingStore<Type = R> + Copy
197 {
198 strassen::<_, _, _, _, _, T1, T2, T3>(true, <_ as StrassenHint>::strassen_threshold(ring.get_ring()), lhs, rhs, dst, ring, &self.allocator)
199 }
200
201 fn matmul<S, V1, V2, V3, const T1: bool, const T2: bool, const T3: bool>(
202 &self,
203 lhs: TransposableSubmatrix<V1, R::Element, T1>,
204 rhs: TransposableSubmatrix<V2, R::Element, T2>,
205 dst: TransposableSubmatrixMut<V3, R::Element, T3>,
206 ring: S
207 )
208 where V1: AsPointerToSlice<R::Element>,
209 V2: AsPointerToSlice<R::Element>,
210 V3: AsPointerToSlice<R::Element>,
211 S: RingStore<Type = R> + Copy
212 {
213 strassen::<_, _, _, _, _, T1, T2, T3>(false, <_ as StrassenHint>::strassen_threshold(ring.get_ring()), lhs, rhs, dst, ring, &self.allocator)
214 }
215}
216
217#[stability::unstable(feature = "enable")]
225pub fn naive_matmul<R, V1, V2, V3, const ADD_ASSIGN: bool, const T1: bool, const T2: bool, const T3: bool>(
226 lhs: TransposableSubmatrix<V1, El<R>, T1>,
227 rhs: TransposableSubmatrix<V2, El<R>, T2>,
228 mut dst: TransposableSubmatrixMut<V3, El<R>, T3>,
229 ring: R
230)
231 where R: RingStore + Copy,
232 V1: AsPointerToSlice<El<R>>,
233 V2: AsPointerToSlice<El<R>>,
234 V3: AsPointerToSlice<El<R>>
235{
236 assert_eq!(lhs.row_count(), dst.row_count());
237 assert_eq!(rhs.col_count(), dst.col_count());
238 assert_eq!(lhs.col_count(), rhs.row_count());
239 for i in 0..lhs.row_count() {
240 for j in 0..rhs.col_count() {
241 let inner_prod = <_ as ComputeInnerProduct>::inner_product_ref(ring.get_ring(), (0..lhs.col_count()).map(|k| (lhs.at(i, k), rhs.at(k, j))));
242 if ADD_ASSIGN {
243 ring.add_assign(dst.at_mut(i, j), inner_prod);
244 } else {
245 *dst.at_mut(i, j) = inner_prod;
246 }
247 }
248 }
249}
250
251#[cfg(test)]
252use test;
253#[cfg(test)]
254use crate::primitive_int::*;
255
/// Side length of the square matrices multiplied in the benchmarks of this file.
#[cfg(test)]
const BENCH_SIZE: usize = 128;
/// Entry type of the benchmark matrices. With entries `i + j <= 254` and `128` summands
/// per product entry, all results stay far below `i64::MAX`.
#[cfg(test)]
type BenchInt = i64;
260
261#[bench]
262fn bench_naive_matmul(bencher: &mut test::Bencher) {
263 let lhs = OwnedMatrix::from_fn_in(BENCH_SIZE, BENCH_SIZE, |i, j| std::hint::black_box(i as BenchInt + j as BenchInt), Global);
264 let rhs = OwnedMatrix::from_fn_in(BENCH_SIZE, BENCH_SIZE, |i, j| std::hint::black_box(i as BenchInt + j as BenchInt), Global);
265 let mut result: OwnedMatrix<BenchInt> = OwnedMatrix::zero(BENCH_SIZE, BENCH_SIZE, StaticRing::<BenchInt>::RING);
266 bencher.iter(|| {
267 strassen::<_, _, _, _, _, false, false, false>(
268 false,
269 100,
270 std::hint::black_box(TransposableSubmatrix::from(lhs.data())),
271 std::hint::black_box(TransposableSubmatrix::from(rhs.data())),
272 std::hint::black_box(TransposableSubmatrixMut::from(result.data_mut())),
273 StaticRing::<BenchInt>::RING,
274 &Global
275 );
276 assert_eq!((BENCH_SIZE * (BENCH_SIZE + 1) * (BENCH_SIZE * 2 + 1) / 6 - BENCH_SIZE * BENCH_SIZE) as BenchInt, *result.at(0, 0));
277 });
278}
279
280#[bench]
281fn bench_strassen_matmul(bencher: &mut test::Bencher) {
282 let threshold_log_2 = 4;
283 let lhs = OwnedMatrix::from_fn_in(BENCH_SIZE, BENCH_SIZE, |i, j| std::hint::black_box(i as BenchInt + j as BenchInt), Global);
284 let rhs = OwnedMatrix::from_fn_in(BENCH_SIZE, BENCH_SIZE, |i, j| std::hint::black_box(i as BenchInt + j as BenchInt), Global);
285 let mut result: OwnedMatrix<BenchInt> = OwnedMatrix::zero(BENCH_SIZE, BENCH_SIZE, StaticRing::<BenchInt>::RING);
286 bencher.iter(|| {
287 strassen::<_, _, _, _, _, false, false, false>(
288 false,
289 threshold_log_2,
290 std::hint::black_box(TransposableSubmatrix::from(lhs.data())),
291 std::hint::black_box(TransposableSubmatrix::from(rhs.data())),
292 std::hint::black_box(TransposableSubmatrixMut::from(result.data_mut())),
293 StaticRing::<BenchInt>::RING,
294 &Global
295 );
296 assert_eq!((BENCH_SIZE * (BENCH_SIZE + 1) * (BENCH_SIZE * 2 + 1) / 6 - BENCH_SIZE * BENCH_SIZE) as BenchInt, *result.at(0, 0));
297 });
298}