mdarray_linalg_blas/matmul/
context.rs

1use num_traits::{One, Zero};
2use std::mem::MaybeUninit;
3
4use cblas_sys::{CBLAS_SIDE, CBLAS_UPLO};
5use mdarray::{DSlice, DTensor, Dense, DynRank, Layout, Slice, Tensor, tensor};
6use num_complex::ComplexFloat;
7
8use mdarray_linalg::matmul::{
9    Axes, MatMul, MatMulBuilder, Side, ContractBuilder, Triangle, Type, _contract,
10};
11
12use super::scalar::BlasScalar;
13use super::simple::{gemm, gemm_uninit, hemm_uninit, symm_uninit, trmm};
14
15use crate::Blas;
16
17struct BlasMatMulBuilder<'a, T, La, Lb>
18where
19    La: Layout,
20    Lb: Layout,
21{
22    alpha: T,
23    a: &'a DSlice<T, 2, La>,
24    b: &'a DSlice<T, 2, Lb>,
25}
26
27struct BlasContractBuilder<'a, T, La, Lb>
28where
29    La: Layout,
30    Lb: Layout,
31{
32    alpha: T,
33    a: &'a Slice<T, DynRank, La>,
34    b: &'a Slice<T, DynRank, Lb>,
35    axes: Axes,
36}
37
38impl<'a, T, La, Lb> MatMulBuilder<'a, T, La, Lb> for BlasMatMulBuilder<'a, T, La, Lb>
39where
40    La: Layout,
41    Lb: Layout,
42    T: BlasScalar + ComplexFloat + Zero + One,
43    // i8: Into<T::Real>,
44    // T::Real: Into<T>,
45{
46    fn parallelize(self) -> Self {
47        self
48    }
49
50    fn scale(mut self, factor: T) -> Self {
51        self.alpha = self.alpha * factor;
52        self
53    }
54
55    fn eval(self) -> DTensor<T, 2> {
56        let (m, _) = *self.a.shape();
57        let (_, n) = *self.b.shape();
58        let c = tensor![[MaybeUninit::<T>::uninit(); n]; m];
59        gemm_uninit::<T, La, Lb, Dense>(self.alpha, self.a, self.b, T::zero(), c)
60        // formerly 0.into().into() instead of T::zero() but
61        // propagating the associated bounds was causing a lot of
62        // trouble
63    }
64
65    fn overwrite<Lc: Layout>(self, c: &mut DSlice<T, 2, Lc>) {
66        gemm(self.alpha, self.a, self.b, T::zero(), c);
67    }
68
69    fn add_to<Lc: Layout>(self, c: &mut DSlice<T, 2, Lc>) {
70        gemm(self.alpha, self.a, self.b, T::one(), c);
71    }
72
73    fn add_to_scaled<Lc: Layout>(self, c: &mut DSlice<T, 2, Lc>, beta: T) {
74        gemm(self.alpha, self.a, self.b, beta, c);
75    }
76
77    fn special(self, lr: Side, type_of_matrix: Type, tr: Triangle) -> DTensor<T, 2> {
78        let (m, _) = *self.a.shape();
79        let (_, n) = *self.b.shape();
80        let c = tensor![[MaybeUninit::<T>::uninit(); n]; m];
81        let cblas_side = match lr {
82            Side::Left => CBLAS_SIDE::CblasLeft,
83            Side::Right => CBLAS_SIDE::CblasRight,
84        };
85        let cblas_triangle = match tr {
86            Triangle::Lower => CBLAS_UPLO::CblasLower,
87            Triangle::Upper => CBLAS_UPLO::CblasUpper,
88        };
89        match type_of_matrix {
90            Type::Her => hemm_uninit::<T, La, Lb, Dense>(
91                self.alpha,
92                self.a,
93                self.b,
94                T::zero(),
95                c,
96                cblas_side,
97                cblas_triangle,
98            ),
99            Type::Sym => symm_uninit::<T, La, Lb, Dense>(
100                self.alpha,
101                self.a,
102                self.b,
103                T::zero(),
104                c,
105                cblas_side,
106                cblas_triangle,
107            ),
108            Type::Tri => {
109                let mut b_copy = DTensor::<T, 2>::from_elem(*self.b.shape(), T::zero());
110                b_copy.assign(self.b);
111                trmm(self.alpha, self.a, &mut b_copy, cblas_side, cblas_triangle);
112                b_copy
113            }
114        }
115    }
116}
117
118impl<'a, T, La, Lb> ContractBuilder<'a, T, La, Lb> for BlasContractBuilder<'a, T, La, Lb>
119where
120    La: Layout,
121    Lb: Layout,
122    T: BlasScalar + ComplexFloat + Zero + One,
123{
124    fn scale(mut self, factor: T) -> Self {
125        self.alpha = self.alpha * factor;
126        self
127    }
128
129    fn eval(self) -> Tensor<T> {
130        _contract(Blas, self.a, self.b, self.axes, self.alpha)
131    }
132
133    fn overwrite(self, _c: &mut Slice<T>) {
134        todo!()
135    }
136}
137
138impl<T> MatMul<T> for Blas
139where
140    T: BlasScalar + ComplexFloat,
141    // i8: Into<T::Real>,
142    // T::Real: Into<T>,
143{
144    fn matmul<'a, La, Lb>(
145        &self,
146        a: &'a DSlice<T, 2, La>,
147        b: &'a DSlice<T, 2, Lb>,
148    ) -> impl MatMulBuilder<'a, T, La, Lb>
149    where
150        La: Layout,
151        Lb: Layout,
152    {
153        BlasMatMulBuilder {
154            alpha: T::one(),
155            a,
156            b,
157        }
158    }
159
160    /// Contracts all axes of the first tensor with all axes of the second tensor.
161    fn contract_all<'a, La, Lb>(
162        &self,
163        a: &'a Slice<T, DynRank, La>,
164        b: &'a Slice<T, DynRank, Lb>,
165    ) -> impl ContractBuilder<'a, T, La, Lb>
166    where
167        T: 'a,
168        La: Layout,
169        Lb: Layout,
170    {
171        BlasContractBuilder {
172            alpha: T::one(),
173            a,
174            b,
175            axes: Axes::All,
176        }
177    }
178
179    /// Contracts the last `n` axes of the first tensor with the first `n` axes of the second tensor.
180    /// # Example
181    /// For two matrices (2D tensors), `contract_n(1)` performs standard matrix multiplication.
182    fn contract_n<'a, La: Layout, Lb: Layout>(
183        &self,
184        a: &'a Slice<T, DynRank, La>,
185        b: &'a Slice<T, DynRank, Lb>,
186        n: usize,
187    ) -> impl ContractBuilder<'a, T, La, Lb>
188    where
189        T: 'a,
190    {
191        BlasContractBuilder {
192            alpha: T::one(),
193            a,
194            b,
195            axes: Axes::LastFirst { k: (n) },
196        }
197    }
198
199    /// Specifies exactly which axes to contract_all.
200    /// # Example
201    /// `specific([1, 2], [3, 4])` contracts axis 1 and 2 of `a`
202    /// with axes 3 and 4 of `b`.
203    fn contract<'a, La: Layout, Lb: Layout>(
204        &self,
205        a: &'a Slice<T, DynRank, La>,
206        b: &'a Slice<T, DynRank, Lb>,
207        axes_a: impl Into<Box<[usize]>>,
208        axes_b: impl Into<Box<[usize]>>,
209    ) -> impl ContractBuilder<'a, T, La, Lb>
210    where
211        T: 'a,
212    {
213        BlasContractBuilder {
214            alpha: T::one(),
215            a,
216            b,
217            axes: Axes::Specific(axes_a.into(), axes_b.into()),
218        }
219    }
220}