1use std::alloc::{Allocator, Global};
2use std::ops::Deref;
3
4use strassen::*;
5
6use crate::matrix::*;
7use crate::ring::*;
8
9pub mod strassen;
12
#[stability::unstable(feature = "enable")]
///
/// Computes inner products `sum_i l_i * r_i` of ring elements.
///
/// Every [`RingBase`] implements this trait through a blanket implementation whose
/// methods are `default fn`s. Individual rings can therefore specialize them with a
/// faster routine. In the blanket implementation, all three methods compute the same
/// quantity: they start from zero and accumulate each product with `fma`. They differ
/// only in whether the factors of each pair are borrowed or owned.
///
pub trait ComputeInnerProduct: RingBase {
    ///
    /// Computes `sum_i l_i * r_i` over the pairs `(l_i, r_i)` yielded by `els`,
    /// where both factors are borrowed.
    ///
    fn inner_product_ref<'a, I: Iterator<Item = (&'a Self::Element, &'a Self::Element)>>(
        &self,
        els: I,
    ) -> Self::Element
    where
        Self::Element: 'a,
        Self: 'a;

    ///
    /// Computes `sum_i l_i * r_i` over the pairs `(l_i, r_i)` yielded by `els`,
    /// where the first factor is borrowed and the second one is owned.
    ///
    fn inner_product_ref_fst<'a, I: Iterator<Item = (&'a Self::Element, Self::Element)>>(
        &self,
        els: I,
    ) -> Self::Element
    where
        Self::Element: 'a,
        Self: 'a;

    ///
    /// Computes `sum_i l_i * r_i` over the pairs `(l_i, r_i)` yielded by `els`,
    /// where both factors are owned.
    ///
    fn inner_product<I: Iterator<Item = (Self::Element, Self::Element)>>(&self, els: I) -> Self::Element;
}
38
39impl<R: ?Sized + RingBase> ComputeInnerProduct for R {
40 default fn inner_product_ref_fst<'a, I: Iterator<Item = (&'a Self::Element, Self::Element)>>(
41 &self,
42 els: I,
43 ) -> Self::Element
44 where
45 Self::Element: 'a,
46 {
47 let mut result = self.zero();
48 for (l, r) in els {
49 result = self.fma(l, &r, result);
50 }
51 return result;
52 }
53
54 default fn inner_product_ref<'a, I: Iterator<Item = (&'a Self::Element, &'a Self::Element)>>(
55 &self,
56 els: I,
57 ) -> Self::Element
58 where
59 Self::Element: 'a,
60 {
61 let mut result = self.zero();
62 for (l, r) in els {
63 result = self.fma(l, r, result);
64 }
65 return result;
66 }
67
68 default fn inner_product<I: Iterator<Item = (Self::Element, Self::Element)>>(&self, els: I) -> Self::Element {
69 let mut result = self.zero();
70 for (l, r) in els {
71 result = self.fma(&l, &r, result);
72 }
73 return result;
74 }
75}
76
77#[stability::unstable(feature = "enable")]
79pub trait MatmulAlgorithm<R: ?Sized + RingBase> {
80 fn add_matmul<S, V1, V2, V3, const T1: bool, const T2: bool, const T3: bool>(
86 &self,
87 lhs: TransposableSubmatrix<V1, R::Element, T1>,
88 rhs: TransposableSubmatrix<V2, R::Element, T2>,
89 dst: TransposableSubmatrixMut<V3, R::Element, T3>,
90 ring: S,
91 ) where
92 V1: AsPointerToSlice<R::Element>,
93 V2: AsPointerToSlice<R::Element>,
94 V3: AsPointerToSlice<R::Element>,
95 S: RingStore<Type = R> + Copy;
96
97 fn matmul<S, V1, V2, V3, const T1: bool, const T2: bool, const T3: bool>(
104 &self,
105 lhs: TransposableSubmatrix<V1, R::Element, T1>,
106 rhs: TransposableSubmatrix<V2, R::Element, T2>,
107 mut dst: TransposableSubmatrixMut<V3, R::Element, T3>,
108 ring: S,
109 ) where
110 V1: AsPointerToSlice<R::Element>,
111 V2: AsPointerToSlice<R::Element>,
112 V3: AsPointerToSlice<R::Element>,
113 S: RingStore<Type = R> + Copy,
114 {
115 for i in 0..dst.row_count() {
116 for j in 0..dst.col_count() {
117 *dst.at_mut(i, j) = ring.zero();
118 }
119 }
120 self.add_matmul(lhs, rhs, dst, ring);
121 }
122}
123
124impl<R, T> MatmulAlgorithm<R> for T
125where
126 R: ?Sized + RingBase,
127 T: Deref,
128 T::Target: MatmulAlgorithm<R>,
129{
130 fn add_matmul<S, V1, V2, V3, const T1: bool, const T2: bool, const T3: bool>(
131 &self,
132 lhs: TransposableSubmatrix<V1, R::Element, T1>,
133 rhs: TransposableSubmatrix<V2, R::Element, T2>,
134 dst: TransposableSubmatrixMut<V3, R::Element, T3>,
135 ring: S,
136 ) where
137 V1: AsPointerToSlice<R::Element>,
138 V2: AsPointerToSlice<R::Element>,
139 V3: AsPointerToSlice<R::Element>,
140 S: RingStore<Type = R> + Copy,
141 {
142 (**self).add_matmul(lhs, rhs, dst, ring)
143 }
144
145 fn matmul<S, V1, V2, V3, const T1: bool, const T2: bool, const T3: bool>(
146 &self,
147 lhs: TransposableSubmatrix<V1, R::Element, T1>,
148 rhs: TransposableSubmatrix<V2, R::Element, T2>,
149 dst: TransposableSubmatrixMut<V3, R::Element, T3>,
150 ring: S,
151 ) where
152 V1: AsPointerToSlice<R::Element>,
153 V2: AsPointerToSlice<R::Element>,
154 V3: AsPointerToSlice<R::Element>,
155 S: RingStore<Type = R> + Copy,
156 {
157 (**self).matmul(lhs, rhs, dst, ring)
158 }
159}
160
#[stability::unstable(feature = "enable")]
///
/// Per-ring hint that controls how [`StrassenAlgorithm`] uses Strassen's algorithm.
///
/// The [`MatmulAlgorithm`] implementation of [`StrassenAlgorithm`] passes the value
/// returned by [`StrassenHint::strassen_threshold`] unchanged as the threshold argument
/// of `strassen`. The benchmarks in this module call that argument `threshold_log_2`
/// and use a large value (`100`) for their "naive" baseline. So it is presumably the
/// base-2 logarithm of the block size at or below which naive multiplication is used.
/// NOTE(review): confirm the exact semantics against `strassen.rs`.
///
/// Rings override this (via specialization) to tune the crossover point.
///
pub trait StrassenHint: RingBase {
    /// Returns the threshold that `StrassenAlgorithm` passes to `strassen` for this ring.
    fn strassen_threshold(&self) -> usize;
}
175
176impl<R: RingBase + ?Sized> StrassenHint for R {
177 default fn strassen_threshold(&self) -> usize { 0 }
178}
179
180#[stability::unstable(feature = "enable")]
181pub const STANDARD_MATMUL: StrassenAlgorithm = StrassenAlgorithm::new(Global);
182
183#[stability::unstable(feature = "enable")]
184#[derive(Clone, Copy)]
185pub struct StrassenAlgorithm<A: Allocator = Global> {
186 allocator: A,
187}
188
189impl<A: Allocator> StrassenAlgorithm<A> {
190 #[stability::unstable(feature = "enable")]
191 pub const fn new(allocator: A) -> Self { Self { allocator } }
192}
193
194impl<R: ?Sized + RingBase, A: Allocator> MatmulAlgorithm<R> for StrassenAlgorithm<A> {
195 fn add_matmul<S, V1, V2, V3, const T1: bool, const T2: bool, const T3: bool>(
196 &self,
197 lhs: TransposableSubmatrix<V1, R::Element, T1>,
198 rhs: TransposableSubmatrix<V2, R::Element, T2>,
199 dst: TransposableSubmatrixMut<V3, R::Element, T3>,
200 ring: S,
201 ) where
202 V1: AsPointerToSlice<R::Element>,
203 V2: AsPointerToSlice<R::Element>,
204 V3: AsPointerToSlice<R::Element>,
205 S: RingStore<Type = R> + Copy,
206 {
207 strassen::<_, _, _, _, _, T1, T2, T3>(
208 true,
209 <_ as StrassenHint>::strassen_threshold(ring.get_ring()),
210 lhs,
211 rhs,
212 dst,
213 ring,
214 &self.allocator,
215 )
216 }
217
218 fn matmul<S, V1, V2, V3, const T1: bool, const T2: bool, const T3: bool>(
219 &self,
220 lhs: TransposableSubmatrix<V1, R::Element, T1>,
221 rhs: TransposableSubmatrix<V2, R::Element, T2>,
222 dst: TransposableSubmatrixMut<V3, R::Element, T3>,
223 ring: S,
224 ) where
225 V1: AsPointerToSlice<R::Element>,
226 V2: AsPointerToSlice<R::Element>,
227 V3: AsPointerToSlice<R::Element>,
228 S: RingStore<Type = R> + Copy,
229 {
230 strassen::<_, _, _, _, _, T1, T2, T3>(
231 false,
232 <_ as StrassenHint>::strassen_threshold(ring.get_ring()),
233 lhs,
234 rhs,
235 dst,
236 ring,
237 &self.allocator,
238 )
239 }
240}
241
242#[stability::unstable(feature = "enable")]
248pub fn naive_matmul<R, V1, V2, V3, const ADD_ASSIGN: bool, const T1: bool, const T2: bool, const T3: bool>(
249 lhs: TransposableSubmatrix<V1, El<R>, T1>,
250 rhs: TransposableSubmatrix<V2, El<R>, T2>,
251 mut dst: TransposableSubmatrixMut<V3, El<R>, T3>,
252 ring: R,
253) where
254 R: RingStore + Copy,
255 V1: AsPointerToSlice<El<R>>,
256 V2: AsPointerToSlice<El<R>>,
257 V3: AsPointerToSlice<El<R>>,
258{
259 assert_eq!(lhs.row_count(), dst.row_count());
260 assert_eq!(rhs.col_count(), dst.col_count());
261 assert_eq!(lhs.col_count(), rhs.row_count());
262 for i in 0..lhs.row_count() {
263 for j in 0..rhs.col_count() {
264 let inner_prod = <_ as ComputeInnerProduct>::inner_product_ref(
265 ring.get_ring(),
266 (0..lhs.col_count()).map(|k| (lhs.at(i, k), rhs.at(k, j))),
267 );
268 if ADD_ASSIGN {
269 ring.add_assign(dst.at_mut(i, j), inner_prod);
270 } else {
271 *dst.at_mut(i, j) = inner_prod;
272 }
273 }
274 }
275}
276
277#[cfg(test)]
278use test;
279
280#[cfg(test)]
281use crate::primitive_int::*;
282
/// Number of rows and columns of the square matrices multiplied in the benchmarks.
#[cfg(test)]
const BENCH_SIZE: usize = 128;
/// Entry type of the benchmark matrices; arithmetic is done in `StaticRing::<BenchInt>::RING`.
#[cfg(test)]
type BenchInt = i64;
287
288#[bench]
289fn bench_naive_matmul(bencher: &mut test::Bencher) {
290 let lhs = OwnedMatrix::from_fn_in(
291 BENCH_SIZE,
292 BENCH_SIZE,
293 |i, j| std::hint::black_box(i as BenchInt + j as BenchInt),
294 Global,
295 );
296 let rhs = OwnedMatrix::from_fn_in(
297 BENCH_SIZE,
298 BENCH_SIZE,
299 |i, j| std::hint::black_box(i as BenchInt + j as BenchInt),
300 Global,
301 );
302 let mut result: OwnedMatrix<BenchInt> = OwnedMatrix::zero(BENCH_SIZE, BENCH_SIZE, StaticRing::<BenchInt>::RING);
303 bencher.iter(|| {
304 strassen::<_, _, _, _, _, false, false, false>(
305 false,
306 100,
307 std::hint::black_box(TransposableSubmatrix::from(lhs.data())),
308 std::hint::black_box(TransposableSubmatrix::from(rhs.data())),
309 std::hint::black_box(TransposableSubmatrixMut::from(result.data_mut())),
310 StaticRing::<BenchInt>::RING,
311 &Global,
312 );
313 assert_eq!(
314 (BENCH_SIZE * (BENCH_SIZE + 1) * (BENCH_SIZE * 2 + 1) / 6 - BENCH_SIZE * BENCH_SIZE) as BenchInt,
315 *result.at(0, 0)
316 );
317 });
318}
319
320#[bench]
321fn bench_strassen_matmul(bencher: &mut test::Bencher) {
322 let threshold_log_2 = 4;
323 let lhs = OwnedMatrix::from_fn_in(
324 BENCH_SIZE,
325 BENCH_SIZE,
326 |i, j| std::hint::black_box(i as BenchInt + j as BenchInt),
327 Global,
328 );
329 let rhs = OwnedMatrix::from_fn_in(
330 BENCH_SIZE,
331 BENCH_SIZE,
332 |i, j| std::hint::black_box(i as BenchInt + j as BenchInt),
333 Global,
334 );
335 let mut result: OwnedMatrix<BenchInt> = OwnedMatrix::zero(BENCH_SIZE, BENCH_SIZE, StaticRing::<BenchInt>::RING);
336 bencher.iter(|| {
337 strassen::<_, _, _, _, _, false, false, false>(
338 false,
339 threshold_log_2,
340 std::hint::black_box(TransposableSubmatrix::from(lhs.data())),
341 std::hint::black_box(TransposableSubmatrix::from(rhs.data())),
342 std::hint::black_box(TransposableSubmatrixMut::from(result.data_mut())),
343 StaticRing::<BenchInt>::RING,
344 &Global,
345 );
346 assert_eq!(
347 (BENCH_SIZE * (BENCH_SIZE + 1) * (BENCH_SIZE * 2 + 1) / 6 - BENCH_SIZE * BENCH_SIZE) as BenchInt,
348 *result.at(0, 0)
349 );
350 });
351}