feanor_math/algorithms/matmul/
mod.rs

1
2use strassen::*;
3
4use crate::matrix::*;
5use crate::ring::*;
6
7use std::alloc::Allocator;
8use std::alloc::Global;
9use std::ops::Deref;
10
11pub mod strassen;
12
13///
14/// Trait to allow rings to provide specialized implementations for inner products, i.e.
15/// the sums `sum_i a[i] * b[i]`.
16/// 
17#[stability::unstable(feature = "enable")]
18pub trait ComputeInnerProduct: RingBase {
19
20    ///
21    /// Computes the inner product `sum_i lhs[i] * rhs[i]`.
22    /// 
23    fn inner_product_ref<'a, I: Iterator<Item = (&'a Self::Element, &'a Self::Element)>>(&self, els: I) -> Self::Element
24        where Self::Element: 'a,
25            Self: 'a;
26
27    ///
28    /// Computes the inner product `sum_i lhs[i] * rhs[i]`.
29    /// 
30    fn inner_product_ref_fst<'a, I: Iterator<Item = (&'a Self::Element, Self::Element)>>(&self, els: I) -> Self::Element
31        where Self::Element: 'a,
32            Self: 'a;
33
34    ///
35    /// Computes the inner product `sum_i lhs[i] * rhs[i]`.
36    /// 
37    fn inner_product<I: Iterator<Item = (Self::Element, Self::Element)>>(&self, els: I) -> Self::Element;
38}
39
40impl<R: ?Sized + RingBase> ComputeInnerProduct for R {
41
42    default fn inner_product_ref_fst<'a, I: Iterator<Item = (&'a Self::Element, Self::Element)>>(&self, els: I) -> Self::Element
43        where Self::Element: 'a
44    {
45        self.inner_product(els.map(|(l, r)| (self.clone_el(l), r)))
46    }
47
48    default fn inner_product_ref<'a, I: Iterator<Item = (&'a Self::Element, &'a Self::Element)>>(&self, els: I) -> Self::Element
49        where Self::Element: 'a
50    {
51        self.inner_product_ref_fst(els.map(|(l, r)| (l, self.clone_el(r))))
52    }
53
54    default fn inner_product<I: Iterator<Item = (Self::Element, Self::Element)>>(&self, els: I) -> Self::Element {
55        self.sum(els.map(|(l, r)| self.mul(l, r)))
56    }
57}
58
59///
60/// Trait for objects that can compute a matrix multiplications over some ring.
61/// 
62#[stability::unstable(feature = "enable")]
63pub trait MatmulAlgorithm<R: ?Sized + RingBase> {
64
65    ///
66    /// Computes the matrix product of `lhs` and `rhs`, and adds the result to `dst`.
67    /// 
68    /// This requires that `lhs` is a `nxk` matrix, `rhs` is a `kxm` matrix and `dst` is a `nxm` matrix.
69    /// In this case, the function concretely computes `dst[i, j] += sum_l lhs[i, l] * rhs[l, j]` where
70    /// `l` runs from `0` to `k - 1`.
71    /// 
72    fn add_matmul<S, V1, V2, V3, const T1: bool, const T2: bool, const T3: bool>(&self, lhs: TransposableSubmatrix<V1, R::Element, T1>, rhs: TransposableSubmatrix<V2, R::Element, T2>, dst: TransposableSubmatrixMut<V3, R::Element, T3>, ring: S)
73        where V1: AsPointerToSlice<R::Element>,
74            V2: AsPointerToSlice<R::Element>,
75            V3: AsPointerToSlice<R::Element>,
76            S: RingStore<Type = R> + Copy;
77         
78    ///
79    /// Computes the matrix product of `lhs` and `rhs`, and stores the result in `dst`.
80    /// 
81    /// This requires that `lhs` is a `nxk` matrix, `rhs` is a `kxm` matrix and `dst` is a `nxm` matrix.
82    /// In this case, the function concretely computes `dst[i, j] = sum_l lhs[i, l] * rhs[l, j]` where
83    /// `l` runs from `0` to `k - 1`.
84    ///    
85    fn matmul<S, V1, V2, V3, const T1: bool, const T2: bool, const T3: bool>(&self, lhs: TransposableSubmatrix<V1, R::Element, T1>, rhs: TransposableSubmatrix<V2, R::Element, T2>, mut dst: TransposableSubmatrixMut<V3, R::Element, T3>, ring: S)
86        where V1: AsPointerToSlice<R::Element>,
87            V2: AsPointerToSlice<R::Element>,
88            V3: AsPointerToSlice<R::Element>,
89            S: RingStore<Type = R> + Copy
90    {
91        for i in 0..dst.row_count() {
92            for j in 0..dst.col_count() {
93                *dst.at_mut(i, j) = ring.zero();
94            }
95        }
96        self.add_matmul(lhs, rhs, dst, ring);
97    }
98}
99
100impl<R, T> MatmulAlgorithm<R> for T
101    where R: ?Sized + RingBase,
102        T: Deref,
103        T::Target: MatmulAlgorithm<R>
104{
105    fn add_matmul<S, V1, V2, V3, const T1: bool, const T2: bool, const T3: bool>(&self, lhs: TransposableSubmatrix<V1, R::Element, T1>, rhs: TransposableSubmatrix<V2, R::Element, T2>, dst: TransposableSubmatrixMut<V3, R::Element, T3>, ring: S)
106        where V1: AsPointerToSlice<R::Element>,
107            V2: AsPointerToSlice<R::Element>,
108            V3: AsPointerToSlice<R::Element>,
109            S: RingStore<Type = R> + Copy
110    {
111        (**self).add_matmul(lhs, rhs, dst, ring)
112    }
113
114    fn matmul<S, V1, V2, V3, const T1: bool, const T2: bool, const T3: bool>(&self, lhs: TransposableSubmatrix<V1, R::Element, T1>, rhs: TransposableSubmatrix<V2, R::Element, T2>, dst: TransposableSubmatrixMut<V3, R::Element, T3>, ring: S)
115        where V1: AsPointerToSlice<R::Element>,
116            V2: AsPointerToSlice<R::Element>,
117            V3: AsPointerToSlice<R::Element>,
118            S: RingStore<Type = R> + Copy
119    {
120        (**self).matmul(lhs, rhs, dst, ring)
121    }
122}
123
124///
125/// Trait to allow rings to customize the parameters with which [`StrassenAlgorithm`] will
126/// compute matrix multiplications.
127/// 
128#[stability::unstable(feature = "enable")]
129pub trait StrassenHint: RingBase {
130
131    ///
132    /// Define a threshold from which on [`StrassenAlgorithm`] will use the Strassen algorithm.
133    /// 
134    /// Concretely, when this returns `k`, [`StrassenAlgorithm`] will reduce the 
135    /// matrix multipliction down to `2^k x 2^k` matrices using Strassen's algorithm,
136    /// and then use naive matmul for the rest.
137    /// 
138    /// The value is `0`, but if the considered rings have fast multiplication (compared to addition), 
139    /// then setting this higher may result in a performance gain.
140    /// 
141    fn strassen_threshold(&self) -> usize;
142}
143
144impl<R: RingBase + ?Sized> StrassenHint for R {
145
146    default fn strassen_threshold(&self) -> usize {
147        0
148    }
149}
150
151#[stability::unstable(feature = "enable")]
152pub const STANDARD_MATMUL: StrassenAlgorithm = StrassenAlgorithm::new(Global);
153
154#[stability::unstable(feature = "enable")]
155#[derive(Clone, Copy)]
156pub struct StrassenAlgorithm<A: Allocator = Global> {
157    allocator: A
158}
159
160impl<A: Allocator> StrassenAlgorithm<A> {
161
162    #[stability::unstable(feature = "enable")]
163    pub const fn new(allocator: A) -> Self {
164        Self { allocator }
165    }
166}
167
168impl<R: ?Sized + RingBase, A: Allocator> MatmulAlgorithm<R> for StrassenAlgorithm<A> {
169
170    fn add_matmul<S, V1, V2, V3, const T1: bool, const T2: bool, const T3: bool>(
171        &self,
172        lhs: TransposableSubmatrix<V1, R::Element, T1>,
173        rhs: TransposableSubmatrix<V2, R::Element, T2>,
174        dst: TransposableSubmatrixMut<V3, R::Element, T3>,
175        ring: S
176    )
177        where V1: AsPointerToSlice<R::Element>,
178            V2: AsPointerToSlice<R::Element>,
179            V3: AsPointerToSlice<R::Element>,
180            S: RingStore<Type = R> + Copy
181    {
182        strassen::<_, _, _, _, _, T1, T2, T3>(true, <_ as StrassenHint>::strassen_threshold(ring.get_ring()), lhs, rhs, dst, ring, &self.allocator)
183    }
184
185    fn matmul<S, V1, V2, V3, const T1: bool, const T2: bool, const T3: bool>(
186        &self,
187        lhs: TransposableSubmatrix<V1, R::Element, T1>,
188        rhs: TransposableSubmatrix<V2, R::Element, T2>,
189        dst: TransposableSubmatrixMut<V3, R::Element, T3>,
190        ring: S
191    )
192        where V1: AsPointerToSlice<R::Element>,
193            V2: AsPointerToSlice<R::Element>,
194            V3: AsPointerToSlice<R::Element>,
195            S: RingStore<Type = R> + Copy
196    {
197        strassen::<_, _, _, _, _, T1, T2, T3>(false, <_ as StrassenHint>::strassen_threshold(ring.get_ring()), lhs, rhs, dst, ring, &self.allocator)
198    }
199}
200
201#[cfg(test)]
202use test;
203#[cfg(test)]
204use crate::primitive_int::*;
205
206#[cfg(test)]
207const BENCH_SIZE: usize = 128;
208#[cfg(test)]
209type BenchInt = i64;
210
211#[bench]
212fn bench_naive_matmul(bencher: &mut test::Bencher) {
213    let lhs = OwnedMatrix::from_fn_in(BENCH_SIZE, BENCH_SIZE, |i, j| std::hint::black_box(i as BenchInt + j as BenchInt), Global);
214    let rhs = OwnedMatrix::from_fn_in(BENCH_SIZE, BENCH_SIZE, |i, j| std::hint::black_box(i as BenchInt + j as BenchInt), Global);
215    let mut result: OwnedMatrix<BenchInt> = OwnedMatrix::zero(BENCH_SIZE, BENCH_SIZE, StaticRing::<BenchInt>::RING);
216    bencher.iter(|| {
217        strassen::<_, _, _, _, _, false, false, false>(
218            false, 
219            100, 
220            TransposableSubmatrix::from(lhs.data()), 
221            TransposableSubmatrix::from(rhs.data()), 
222            TransposableSubmatrixMut::from(result.data_mut()), 
223            StaticRing::<BenchInt>::RING, 
224            &Global
225        );
226        assert_eq!((BENCH_SIZE * (BENCH_SIZE + 1) * (BENCH_SIZE * 2 + 1) / 6 - BENCH_SIZE * BENCH_SIZE) as BenchInt, *result.at(0, 0));
227    });
228}
229
230#[bench]
231fn bench_strassen_matmul(bencher: &mut test::Bencher) {
232    let threshold_log_2 = 4;
233    let lhs = OwnedMatrix::from_fn_in(BENCH_SIZE, BENCH_SIZE, |i, j| std::hint::black_box(i as BenchInt + j as BenchInt), Global);
234    let rhs = OwnedMatrix::from_fn_in(BENCH_SIZE, BENCH_SIZE, |i, j| std::hint::black_box(i as BenchInt + j as BenchInt), Global);
235    let mut result: OwnedMatrix<BenchInt> = OwnedMatrix::zero(BENCH_SIZE, BENCH_SIZE, StaticRing::<BenchInt>::RING);
236    bencher.iter(|| {
237        strassen::<_, _, _, _, _, false, false, false>(
238            false, 
239            threshold_log_2, 
240            TransposableSubmatrix::from(lhs.data()), 
241            TransposableSubmatrix::from(rhs.data()), 
242            TransposableSubmatrixMut::from(result.data_mut()), 
243            StaticRing::<BenchInt>::RING, 
244            &Global
245        );
246        assert_eq!((BENCH_SIZE * (BENCH_SIZE + 1) * (BENCH_SIZE * 2 + 1) / 6 - BENCH_SIZE * BENCH_SIZE) as BenchInt, *result.at(0, 0));
247    });
248}