feanor_math/algorithms/matmul/
mod.rs

1
2use strassen::*;
3
4use crate::matrix::*;
5use crate::ring::*;
6
7use std::alloc::Allocator;
8use std::alloc::Global;
9use std::ops::Deref;
10
///
/// Contains [`strassen::strassen()`], an implementation of Strassen's algorithm
/// for matrix multiplication.
/// 
/// It is the routine that [`StrassenAlgorithm`] in this module delegates to.
/// 
pub mod strassen;
16
///
/// Trait to allow rings to provide specialized implementations for inner products, i.e.
/// the sums `sum_i a[i] * b[i]`.
/// 
/// A default implementation is provided for every ring; it accumulates the sum
/// term by term using `fma`. Rings with a faster way of computing inner products
/// can override it via specialization.
/// 
/// The three methods only differ in whether the factors of each pair are passed
/// by reference or by value.
/// 
#[stability::unstable(feature = "enable")]
pub trait ComputeInnerProduct: RingBase {

    ///
    /// Computes the inner product `sum_i l_i * r_i`, where `(l_i, r_i)` runs through
    /// the pairs yielded by `els`. Both factors of each pair are borrowed.
    /// 
    /// For an empty iterator, the result is the empty sum, i.e. zero.
    /// 
    fn inner_product_ref<'a, I: Iterator<Item = (&'a Self::Element, &'a Self::Element)>>(&self, els: I) -> Self::Element
        where Self::Element: 'a,
            Self: 'a;

    ///
    /// Computes the inner product `sum_i l_i * r_i`, where `(l_i, r_i)` runs through
    /// the pairs yielded by `els`. The first factor of each pair is borrowed, the
    /// second one is passed by value.
    /// 
    /// For an empty iterator, the result is the empty sum, i.e. zero.
    /// 
    fn inner_product_ref_fst<'a, I: Iterator<Item = (&'a Self::Element, Self::Element)>>(&self, els: I) -> Self::Element
        where Self::Element: 'a,
            Self: 'a;

    ///
    /// Computes the inner product `sum_i l_i * r_i`, where `(l_i, r_i)` runs through
    /// the pairs yielded by `els`. Both factors of each pair are passed by value.
    /// 
    /// For an empty iterator, the result is the empty sum, i.e. zero.
    /// 
    fn inner_product<I: Iterator<Item = (Self::Element, Self::Element)>>(&self, els: I) -> Self::Element;
}
43
44impl<R: ?Sized + RingBase> ComputeInnerProduct for R {
45
46    default fn inner_product_ref_fst<'a, I: Iterator<Item = (&'a Self::Element, Self::Element)>>(&self, els: I) -> Self::Element
47        where Self::Element: 'a
48    {
49        let mut result = self.zero();
50        for (l, r) in els {
51            result = self.fma(l, &r, result);
52        }
53        return result;
54    }
55
56    default fn inner_product_ref<'a, I: Iterator<Item = (&'a Self::Element, &'a Self::Element)>>(&self, els: I) -> Self::Element
57        where Self::Element: 'a
58    {
59        let mut result = self.zero();
60        for (l, r) in els {
61            result = self.fma(l, r, result);
62        }
63        return result;
64    }
65
66    default fn inner_product<I: Iterator<Item = (Self::Element, Self::Element)>>(&self, els: I) -> Self::Element {
67        let mut result = self.zero();
68        for (l, r) in els {
69            result = self.fma(&l, &r, result);
70        }
71        return result;
72    }
73}
74
75///
76/// Trait for objects that can compute a matrix multiplications over some ring.
77/// 
78#[stability::unstable(feature = "enable")]
79pub trait MatmulAlgorithm<R: ?Sized + RingBase> {
80
81    ///
82    /// Computes the matrix product of `lhs` and `rhs`, and adds the result to `dst`.
83    /// 
84    /// This requires that `lhs` is a `nxk` matrix, `rhs` is a `kxm` matrix and `dst` is a `nxm` matrix.
85    /// In this case, the function concretely computes `dst[i, j] += sum_l lhs[i, l] * rhs[l, j]` where
86    /// `l` runs from `0` to `k - 1`.
87    /// 
88    fn add_matmul<S, V1, V2, V3, const T1: bool, const T2: bool, const T3: bool>(&self, lhs: TransposableSubmatrix<V1, R::Element, T1>, rhs: TransposableSubmatrix<V2, R::Element, T2>, dst: TransposableSubmatrixMut<V3, R::Element, T3>, ring: S)
89        where V1: AsPointerToSlice<R::Element>,
90            V2: AsPointerToSlice<R::Element>,
91            V3: AsPointerToSlice<R::Element>,
92            S: RingStore<Type = R> + Copy;
93         
94    ///
95    /// Computes the matrix product of `lhs` and `rhs`, and stores the result in `dst`.
96    /// 
97    /// This requires that `lhs` is a `nxk` matrix, `rhs` is a `kxm` matrix and `dst` is a `nxm` matrix.
98    /// In this case, the function concretely computes `dst[i, j] = sum_l lhs[i, l] * rhs[l, j]` where
99    /// `l` runs from `0` to `k - 1`.
100    ///    
101    fn matmul<S, V1, V2, V3, const T1: bool, const T2: bool, const T3: bool>(&self, lhs: TransposableSubmatrix<V1, R::Element, T1>, rhs: TransposableSubmatrix<V2, R::Element, T2>, mut dst: TransposableSubmatrixMut<V3, R::Element, T3>, ring: S)
102        where V1: AsPointerToSlice<R::Element>,
103            V2: AsPointerToSlice<R::Element>,
104            V3: AsPointerToSlice<R::Element>,
105            S: RingStore<Type = R> + Copy
106    {
107        for i in 0..dst.row_count() {
108            for j in 0..dst.col_count() {
109                *dst.at_mut(i, j) = ring.zero();
110            }
111        }
112        self.add_matmul(lhs, rhs, dst, ring);
113    }
114}
115
116impl<R, T> MatmulAlgorithm<R> for T
117    where R: ?Sized + RingBase,
118        T: Deref,
119        T::Target: MatmulAlgorithm<R>
120{
121    fn add_matmul<S, V1, V2, V3, const T1: bool, const T2: bool, const T3: bool>(&self, lhs: TransposableSubmatrix<V1, R::Element, T1>, rhs: TransposableSubmatrix<V2, R::Element, T2>, dst: TransposableSubmatrixMut<V3, R::Element, T3>, ring: S)
122        where V1: AsPointerToSlice<R::Element>,
123            V2: AsPointerToSlice<R::Element>,
124            V3: AsPointerToSlice<R::Element>,
125            S: RingStore<Type = R> + Copy
126    {
127        (**self).add_matmul(lhs, rhs, dst, ring)
128    }
129
130    fn matmul<S, V1, V2, V3, const T1: bool, const T2: bool, const T3: bool>(&self, lhs: TransposableSubmatrix<V1, R::Element, T1>, rhs: TransposableSubmatrix<V2, R::Element, T2>, dst: TransposableSubmatrixMut<V3, R::Element, T3>, ring: S)
131        where V1: AsPointerToSlice<R::Element>,
132            V2: AsPointerToSlice<R::Element>,
133            V3: AsPointerToSlice<R::Element>,
134            S: RingStore<Type = R> + Copy
135    {
136        (**self).matmul(lhs, rhs, dst, ring)
137    }
138}
139
///
/// Trait to allow rings to customize the parameters with which [`StrassenAlgorithm`] will
/// compute matrix multiplications.
/// 
#[stability::unstable(feature = "enable")]
pub trait StrassenHint: RingBase {

    ///
    /// Defines the block size at which [`StrassenAlgorithm`] stops applying Strassen's
    /// algorithm and switches to naive matrix multiplication.
    /// 
    /// Concretely, when this returns `k`, [`StrassenAlgorithm`] will reduce the 
    /// matrix multiplication down to `2^k x 2^k` matrices using Strassen's algorithm,
    /// and then use naive matmul for the rest.
    /// 
    /// The default value is `0`. Strassen's algorithm saves multiplications at the cost
    /// of additional additions, so if the considered rings have fast multiplication
    /// (compared to addition), then setting this higher may result in a performance gain.
    /// 
    fn strassen_threshold(&self) -> usize;
}
159
160impl<R: RingBase + ?Sized> StrassenHint for R {
161
162    default fn strassen_threshold(&self) -> usize {
163        0
164    }
165}
166
167#[stability::unstable(feature = "enable")]
168pub const STANDARD_MATMUL: StrassenAlgorithm = StrassenAlgorithm::new(Global);
169
///
/// A [`MatmulAlgorithm`] that multiplies matrices using [`strassen::strassen()`].
/// 
/// The point at which the recursion switches to naive multiplication is taken from
/// [`StrassenHint::strassen_threshold()`] of the ring. The stored allocator is passed
/// on to [`strassen::strassen()`].
/// 
#[stability::unstable(feature = "enable")]
#[derive(Clone, Copy)]
pub struct StrassenAlgorithm<A: Allocator = Global> {
    allocator: A
}
175
176impl<A: Allocator> StrassenAlgorithm<A> {
177
178    #[stability::unstable(feature = "enable")]
179    pub const fn new(allocator: A) -> Self {
180        Self { allocator }
181    }
182}
183
184impl<R: ?Sized + RingBase, A: Allocator> MatmulAlgorithm<R> for StrassenAlgorithm<A> {
185
186    fn add_matmul<S, V1, V2, V3, const T1: bool, const T2: bool, const T3: bool>(
187        &self,
188        lhs: TransposableSubmatrix<V1, R::Element, T1>,
189        rhs: TransposableSubmatrix<V2, R::Element, T2>,
190        dst: TransposableSubmatrixMut<V3, R::Element, T3>,
191        ring: S
192    )
193        where V1: AsPointerToSlice<R::Element>,
194            V2: AsPointerToSlice<R::Element>,
195            V3: AsPointerToSlice<R::Element>,
196            S: RingStore<Type = R> + Copy
197    {
198        strassen::<_, _, _, _, _, T1, T2, T3>(true, <_ as StrassenHint>::strassen_threshold(ring.get_ring()), lhs, rhs, dst, ring, &self.allocator)
199    }
200
201    fn matmul<S, V1, V2, V3, const T1: bool, const T2: bool, const T3: bool>(
202        &self,
203        lhs: TransposableSubmatrix<V1, R::Element, T1>,
204        rhs: TransposableSubmatrix<V2, R::Element, T2>,
205        dst: TransposableSubmatrixMut<V3, R::Element, T3>,
206        ring: S
207    )
208        where V1: AsPointerToSlice<R::Element>,
209            V2: AsPointerToSlice<R::Element>,
210            V3: AsPointerToSlice<R::Element>,
211            S: RingStore<Type = R> + Copy
212    {
213        strassen::<_, _, _, _, _, T1, T2, T3>(false, <_ as StrassenHint>::strassen_threshold(ring.get_ring()), lhs, rhs, dst, ring, &self.allocator)
214    }
215}
216
217///
218/// Computes `dst = lhs * rhs` if `ADD_ASSIGN = false` and `dst += lhs * rhs` if `ADD_ASSIGN = true`,
219/// using the standard cubic formula for matrix multiplication. 
220/// 
221/// This implementation is very simple and not very optimized. Usually it is used as a fallback
222/// for more sophisticated implementations. 
223/// 
224#[stability::unstable(feature = "enable")]
225pub fn naive_matmul<R, V1, V2, V3, const ADD_ASSIGN: bool, const T1: bool, const T2: bool, const T3: bool>(
226    lhs: TransposableSubmatrix<V1, El<R>, T1>, 
227    rhs: TransposableSubmatrix<V2, El<R>, T2>, 
228    mut dst: TransposableSubmatrixMut<V3, El<R>, T3>, 
229    ring: R
230)
231    where R: RingStore + Copy, 
232        V1: AsPointerToSlice<El<R>>,
233        V2: AsPointerToSlice<El<R>>,
234        V3: AsPointerToSlice<El<R>>
235{
236    assert_eq!(lhs.row_count(), dst.row_count());
237    assert_eq!(rhs.col_count(), dst.col_count());
238    assert_eq!(lhs.col_count(), rhs.row_count());
239    for i in 0..lhs.row_count() {
240        for j in 0..rhs.col_count() {
241            let inner_prod = <_ as ComputeInnerProduct>::inner_product_ref(ring.get_ring(), (0..lhs.col_count()).map(|k| (lhs.at(i, k), rhs.at(k, j))));
242            if ADD_ASSIGN {
243                ring.add_assign(dst.at_mut(i, j), inner_prod);
244            } else {
245                *dst.at_mut(i, j) = inner_prod;
246            }
247        }
248    }
249}
250
251#[cfg(test)]
252use test;
253#[cfg(test)]
254use crate::primitive_int::*;
255
// Side length of the square matrices multiplied in the benchmarks of this file.
#[cfg(test)]
const BENCH_SIZE: usize = 128;
// Integer type of the matrix entries used in the benchmarks of this file.
#[cfg(test)]
type BenchInt = i64;
260
261#[bench]
262fn bench_naive_matmul(bencher: &mut test::Bencher) {
263    let lhs = OwnedMatrix::from_fn_in(BENCH_SIZE, BENCH_SIZE, |i, j| std::hint::black_box(i as BenchInt + j as BenchInt), Global);
264    let rhs = OwnedMatrix::from_fn_in(BENCH_SIZE, BENCH_SIZE, |i, j| std::hint::black_box(i as BenchInt + j as BenchInt), Global);
265    let mut result: OwnedMatrix<BenchInt> = OwnedMatrix::zero(BENCH_SIZE, BENCH_SIZE, StaticRing::<BenchInt>::RING);
266    bencher.iter(|| {
267        strassen::<_, _, _, _, _, false, false, false>(
268            false, 
269            100, 
270            std::hint::black_box(TransposableSubmatrix::from(lhs.data())), 
271            std::hint::black_box(TransposableSubmatrix::from(rhs.data())), 
272            std::hint::black_box(TransposableSubmatrixMut::from(result.data_mut())), 
273            StaticRing::<BenchInt>::RING, 
274            &Global
275        );
276        assert_eq!((BENCH_SIZE * (BENCH_SIZE + 1) * (BENCH_SIZE * 2 + 1) / 6 - BENCH_SIZE * BENCH_SIZE) as BenchInt, *result.at(0, 0));
277    });
278}
279
280#[bench]
281fn bench_strassen_matmul(bencher: &mut test::Bencher) {
282    let threshold_log_2 = 4;
283    let lhs = OwnedMatrix::from_fn_in(BENCH_SIZE, BENCH_SIZE, |i, j| std::hint::black_box(i as BenchInt + j as BenchInt), Global);
284    let rhs = OwnedMatrix::from_fn_in(BENCH_SIZE, BENCH_SIZE, |i, j| std::hint::black_box(i as BenchInt + j as BenchInt), Global);
285    let mut result: OwnedMatrix<BenchInt> = OwnedMatrix::zero(BENCH_SIZE, BENCH_SIZE, StaticRing::<BenchInt>::RING);
286    bencher.iter(|| {
287        strassen::<_, _, _, _, _, false, false, false>(
288            false, 
289            threshold_log_2, 
290            std::hint::black_box(TransposableSubmatrix::from(lhs.data())), 
291            std::hint::black_box(TransposableSubmatrix::from(rhs.data())), 
292            std::hint::black_box(TransposableSubmatrixMut::from(result.data_mut())), 
293            StaticRing::<BenchInt>::RING, 
294            &Global
295        );
296        assert_eq!((BENCH_SIZE * (BENCH_SIZE + 1) * (BENCH_SIZE * 2 + 1) / 6 - BENCH_SIZE * BENCH_SIZE) as BenchInt, *result.at(0, 0));
297    });
298}