Skip to main content

feanor_math/algorithms/matmul/
mod.rs

1use std::alloc::{Allocator, Global};
2use std::ops::Deref;
3
4use strassen::*;
5
6use crate::matrix::*;
7use crate::ring::*;
8
9/// Contains [`strassen::strassen()`], an implementation of Strassen's algorithm
10/// for matrix multiplication.
11pub mod strassen;
12
13/// Trait to allow rings to provide specialized implementations for inner products, i.e.
14/// the sums `sum_i a[i] * b[i]`.
15#[stability::unstable(feature = "enable")]
16pub trait ComputeInnerProduct: RingBase {
17    /// Computes the inner product `sum_i lhs[i] * rhs[i]`.
18    fn inner_product_ref<'a, I: Iterator<Item = (&'a Self::Element, &'a Self::Element)>>(
19        &self,
20        els: I,
21    ) -> Self::Element
22    where
23        Self::Element: 'a,
24        Self: 'a;
25
26    /// Computes the inner product `sum_i lhs[i] * rhs[i]`.
27    fn inner_product_ref_fst<'a, I: Iterator<Item = (&'a Self::Element, Self::Element)>>(
28        &self,
29        els: I,
30    ) -> Self::Element
31    where
32        Self::Element: 'a,
33        Self: 'a;
34
35    /// Computes the inner product `sum_i lhs[i] * rhs[i]`.
36    fn inner_product<I: Iterator<Item = (Self::Element, Self::Element)>>(&self, els: I) -> Self::Element;
37}
38
39impl<R: ?Sized + RingBase> ComputeInnerProduct for R {
40    default fn inner_product_ref_fst<'a, I: Iterator<Item = (&'a Self::Element, Self::Element)>>(
41        &self,
42        els: I,
43    ) -> Self::Element
44    where
45        Self::Element: 'a,
46    {
47        let mut result = self.zero();
48        for (l, r) in els {
49            result = self.fma(l, &r, result);
50        }
51        return result;
52    }
53
54    default fn inner_product_ref<'a, I: Iterator<Item = (&'a Self::Element, &'a Self::Element)>>(
55        &self,
56        els: I,
57    ) -> Self::Element
58    where
59        Self::Element: 'a,
60    {
61        let mut result = self.zero();
62        for (l, r) in els {
63            result = self.fma(l, r, result);
64        }
65        return result;
66    }
67
68    default fn inner_product<I: Iterator<Item = (Self::Element, Self::Element)>>(&self, els: I) -> Self::Element {
69        let mut result = self.zero();
70        for (l, r) in els {
71            result = self.fma(&l, &r, result);
72        }
73        return result;
74    }
75}
76
77/// Trait for objects that can compute a matrix multiplications over some ring.
78#[stability::unstable(feature = "enable")]
79pub trait MatmulAlgorithm<R: ?Sized + RingBase> {
80    /// Computes the matrix product of `lhs` and `rhs`, and adds the result to `dst`.
81    ///
82    /// This requires that `lhs` is a `nxk` matrix, `rhs` is a `kxm` matrix and `dst` is a `nxm`
83    /// matrix. In this case, the function concretely computes `dst[i, j] += sum_l lhs[i, l] *
84    /// rhs[l, j]` where `l` runs from `0` to `k - 1`.
85    fn add_matmul<S, V1, V2, V3, const T1: bool, const T2: bool, const T3: bool>(
86        &self,
87        lhs: TransposableSubmatrix<V1, R::Element, T1>,
88        rhs: TransposableSubmatrix<V2, R::Element, T2>,
89        dst: TransposableSubmatrixMut<V3, R::Element, T3>,
90        ring: S,
91    ) where
92        V1: AsPointerToSlice<R::Element>,
93        V2: AsPointerToSlice<R::Element>,
94        V3: AsPointerToSlice<R::Element>,
95        S: RingStore<Type = R> + Copy;
96
97    /// Computes the matrix product of `lhs` and `rhs`, and stores the result in `dst`.
98    ///
99    /// This requires that `lhs` is a `nxk` matrix, `rhs` is a `kxm` matrix and `dst` is a `nxm`
100    /// matrix. In this case, the function concretely computes `dst[i, j] = sum_l lhs[i, l] *
101    /// rhs[l, j]` where `l` runs from `0` to `k - 1`.
102    ///    
103    fn matmul<S, V1, V2, V3, const T1: bool, const T2: bool, const T3: bool>(
104        &self,
105        lhs: TransposableSubmatrix<V1, R::Element, T1>,
106        rhs: TransposableSubmatrix<V2, R::Element, T2>,
107        mut dst: TransposableSubmatrixMut<V3, R::Element, T3>,
108        ring: S,
109    ) where
110        V1: AsPointerToSlice<R::Element>,
111        V2: AsPointerToSlice<R::Element>,
112        V3: AsPointerToSlice<R::Element>,
113        S: RingStore<Type = R> + Copy,
114    {
115        for i in 0..dst.row_count() {
116            for j in 0..dst.col_count() {
117                *dst.at_mut(i, j) = ring.zero();
118            }
119        }
120        self.add_matmul(lhs, rhs, dst, ring);
121    }
122}
123
124impl<R, T> MatmulAlgorithm<R> for T
125where
126    R: ?Sized + RingBase,
127    T: Deref,
128    T::Target: MatmulAlgorithm<R>,
129{
130    fn add_matmul<S, V1, V2, V3, const T1: bool, const T2: bool, const T3: bool>(
131        &self,
132        lhs: TransposableSubmatrix<V1, R::Element, T1>,
133        rhs: TransposableSubmatrix<V2, R::Element, T2>,
134        dst: TransposableSubmatrixMut<V3, R::Element, T3>,
135        ring: S,
136    ) where
137        V1: AsPointerToSlice<R::Element>,
138        V2: AsPointerToSlice<R::Element>,
139        V3: AsPointerToSlice<R::Element>,
140        S: RingStore<Type = R> + Copy,
141    {
142        (**self).add_matmul(lhs, rhs, dst, ring)
143    }
144
145    fn matmul<S, V1, V2, V3, const T1: bool, const T2: bool, const T3: bool>(
146        &self,
147        lhs: TransposableSubmatrix<V1, R::Element, T1>,
148        rhs: TransposableSubmatrix<V2, R::Element, T2>,
149        dst: TransposableSubmatrixMut<V3, R::Element, T3>,
150        ring: S,
151    ) where
152        V1: AsPointerToSlice<R::Element>,
153        V2: AsPointerToSlice<R::Element>,
154        V3: AsPointerToSlice<R::Element>,
155        S: RingStore<Type = R> + Copy,
156    {
157        (**self).matmul(lhs, rhs, dst, ring)
158    }
159}
160
161/// Trait to allow rings to customize the parameters with which [`StrassenAlgorithm`] will
162/// compute matrix multiplications.
163#[stability::unstable(feature = "enable")]
164pub trait StrassenHint: RingBase {
165    /// Define a threshold from which on [`StrassenAlgorithm`] will use the Strassen algorithm.
166    ///
167    /// Concretely, when this returns `k`, [`StrassenAlgorithm`] will reduce the
168    /// matrix multipliction down to `2^k x 2^k` matrices using Strassen's algorithm,
169    /// and then use naive matmul for the rest.
170    ///
171    /// The value is `0`, but if the considered rings have fast multiplication (compared to
172    /// addition), then setting this higher may result in a performance gain.
173    fn strassen_threshold(&self) -> usize;
174}
175
176impl<R: RingBase + ?Sized> StrassenHint for R {
177    default fn strassen_threshold(&self) -> usize { 0 }
178}
179
180#[stability::unstable(feature = "enable")]
181pub const STANDARD_MATMUL: StrassenAlgorithm = StrassenAlgorithm::new(Global);
182
183#[stability::unstable(feature = "enable")]
184#[derive(Clone, Copy)]
185pub struct StrassenAlgorithm<A: Allocator = Global> {
186    allocator: A,
187}
188
189impl<A: Allocator> StrassenAlgorithm<A> {
190    #[stability::unstable(feature = "enable")]
191    pub const fn new(allocator: A) -> Self { Self { allocator } }
192}
193
194impl<R: ?Sized + RingBase, A: Allocator> MatmulAlgorithm<R> for StrassenAlgorithm<A> {
195    fn add_matmul<S, V1, V2, V3, const T1: bool, const T2: bool, const T3: bool>(
196        &self,
197        lhs: TransposableSubmatrix<V1, R::Element, T1>,
198        rhs: TransposableSubmatrix<V2, R::Element, T2>,
199        dst: TransposableSubmatrixMut<V3, R::Element, T3>,
200        ring: S,
201    ) where
202        V1: AsPointerToSlice<R::Element>,
203        V2: AsPointerToSlice<R::Element>,
204        V3: AsPointerToSlice<R::Element>,
205        S: RingStore<Type = R> + Copy,
206    {
207        strassen::<_, _, _, _, _, T1, T2, T3>(
208            true,
209            <_ as StrassenHint>::strassen_threshold(ring.get_ring()),
210            lhs,
211            rhs,
212            dst,
213            ring,
214            &self.allocator,
215        )
216    }
217
218    fn matmul<S, V1, V2, V3, const T1: bool, const T2: bool, const T3: bool>(
219        &self,
220        lhs: TransposableSubmatrix<V1, R::Element, T1>,
221        rhs: TransposableSubmatrix<V2, R::Element, T2>,
222        dst: TransposableSubmatrixMut<V3, R::Element, T3>,
223        ring: S,
224    ) where
225        V1: AsPointerToSlice<R::Element>,
226        V2: AsPointerToSlice<R::Element>,
227        V3: AsPointerToSlice<R::Element>,
228        S: RingStore<Type = R> + Copy,
229    {
230        strassen::<_, _, _, _, _, T1, T2, T3>(
231            false,
232            <_ as StrassenHint>::strassen_threshold(ring.get_ring()),
233            lhs,
234            rhs,
235            dst,
236            ring,
237            &self.allocator,
238        )
239    }
240}
241
242/// Computes `dst = lhs * rhs` if `ADD_ASSIGN = false` and `dst += lhs * rhs` if `ADD_ASSIGN =
243/// true`, using the standard cubic formula for matrix multiplication.
244///
245/// This implementation is very simple and not very optimized. Usually it is used as a fallback
246/// for more sophisticated implementations.
247#[stability::unstable(feature = "enable")]
248pub fn naive_matmul<R, V1, V2, V3, const ADD_ASSIGN: bool, const T1: bool, const T2: bool, const T3: bool>(
249    lhs: TransposableSubmatrix<V1, El<R>, T1>,
250    rhs: TransposableSubmatrix<V2, El<R>, T2>,
251    mut dst: TransposableSubmatrixMut<V3, El<R>, T3>,
252    ring: R,
253) where
254    R: RingStore + Copy,
255    V1: AsPointerToSlice<El<R>>,
256    V2: AsPointerToSlice<El<R>>,
257    V3: AsPointerToSlice<El<R>>,
258{
259    assert_eq!(lhs.row_count(), dst.row_count());
260    assert_eq!(rhs.col_count(), dst.col_count());
261    assert_eq!(lhs.col_count(), rhs.row_count());
262    for i in 0..lhs.row_count() {
263        for j in 0..rhs.col_count() {
264            let inner_prod = <_ as ComputeInnerProduct>::inner_product_ref(
265                ring.get_ring(),
266                (0..lhs.col_count()).map(|k| (lhs.at(i, k), rhs.at(k, j))),
267            );
268            if ADD_ASSIGN {
269                ring.add_assign(dst.at_mut(i, j), inner_prod);
270            } else {
271                *dst.at_mut(i, j) = inner_prod;
272            }
273        }
274    }
275}
276
277#[cfg(test)]
278use test;
279
280#[cfg(test)]
281use crate::primitive_int::*;
282
/// Side length (`1 << 7 == 128`) of the square matrices multiplied in the benchmarks.
#[cfg(test)]
const BENCH_SIZE: usize = 1 << 7;
/// Integer type of the matrix entries used in the benchmarks.
#[cfg(test)]
type BenchInt = i64;
287
288#[bench]
289fn bench_naive_matmul(bencher: &mut test::Bencher) {
290    let lhs = OwnedMatrix::from_fn_in(
291        BENCH_SIZE,
292        BENCH_SIZE,
293        |i, j| std::hint::black_box(i as BenchInt + j as BenchInt),
294        Global,
295    );
296    let rhs = OwnedMatrix::from_fn_in(
297        BENCH_SIZE,
298        BENCH_SIZE,
299        |i, j| std::hint::black_box(i as BenchInt + j as BenchInt),
300        Global,
301    );
302    let mut result: OwnedMatrix<BenchInt> = OwnedMatrix::zero(BENCH_SIZE, BENCH_SIZE, StaticRing::<BenchInt>::RING);
303    bencher.iter(|| {
304        strassen::<_, _, _, _, _, false, false, false>(
305            false,
306            100,
307            std::hint::black_box(TransposableSubmatrix::from(lhs.data())),
308            std::hint::black_box(TransposableSubmatrix::from(rhs.data())),
309            std::hint::black_box(TransposableSubmatrixMut::from(result.data_mut())),
310            StaticRing::<BenchInt>::RING,
311            &Global,
312        );
313        assert_eq!(
314            (BENCH_SIZE * (BENCH_SIZE + 1) * (BENCH_SIZE * 2 + 1) / 6 - BENCH_SIZE * BENCH_SIZE) as BenchInt,
315            *result.at(0, 0)
316        );
317    });
318}
319
320#[bench]
321fn bench_strassen_matmul(bencher: &mut test::Bencher) {
322    let threshold_log_2 = 4;
323    let lhs = OwnedMatrix::from_fn_in(
324        BENCH_SIZE,
325        BENCH_SIZE,
326        |i, j| std::hint::black_box(i as BenchInt + j as BenchInt),
327        Global,
328    );
329    let rhs = OwnedMatrix::from_fn_in(
330        BENCH_SIZE,
331        BENCH_SIZE,
332        |i, j| std::hint::black_box(i as BenchInt + j as BenchInt),
333        Global,
334    );
335    let mut result: OwnedMatrix<BenchInt> = OwnedMatrix::zero(BENCH_SIZE, BENCH_SIZE, StaticRing::<BenchInt>::RING);
336    bencher.iter(|| {
337        strassen::<_, _, _, _, _, false, false, false>(
338            false,
339            threshold_log_2,
340            std::hint::black_box(TransposableSubmatrix::from(lhs.data())),
341            std::hint::black_box(TransposableSubmatrix::from(rhs.data())),
342            std::hint::black_box(TransposableSubmatrixMut::from(result.data_mut())),
343            StaticRing::<BenchInt>::RING,
344            &Global,
345        );
346        assert_eq!(
347            (BENCH_SIZE * (BENCH_SIZE + 1) * (BENCH_SIZE * 2 + 1) / 6 - BENCH_SIZE * BENCH_SIZE) as BenchInt,
348            *result.at(0, 0)
349        );
350    });
351}