1use num_traits::{One, Zero};
2use std::mem::MaybeUninit;
3
4use cblas_sys::{CBLAS_SIDE, CBLAS_UPLO};
5use mdarray::{DSlice, DTensor, Dense, DynRank, Layout, Slice, Tensor, tensor};
6use num_complex::ComplexFloat;
7
8use mdarray_linalg::matmul::{
9 Axes, MatMul, MatMulBuilder, Side, ContractBuilder, Triangle, Type, _contract,
10};
11
12use super::scalar::BlasScalar;
13use super::simple::{gemm, gemm_uninit, hemm_uninit, symm_uninit, trmm};
14
15use crate::Blas;
16
17struct BlasMatMulBuilder<'a, T, La, Lb>
18where
19 La: Layout,
20 Lb: Layout,
21{
22 alpha: T,
23 a: &'a DSlice<T, 2, La>,
24 b: &'a DSlice<T, 2, Lb>,
25}
26
27struct BlasContractBuilder<'a, T, La, Lb>
28where
29 La: Layout,
30 Lb: Layout,
31{
32 alpha: T,
33 a: &'a Slice<T, DynRank, La>,
34 b: &'a Slice<T, DynRank, Lb>,
35 axes: Axes,
36}
37
38impl<'a, T, La, Lb> MatMulBuilder<'a, T, La, Lb> for BlasMatMulBuilder<'a, T, La, Lb>
39where
40 La: Layout,
41 Lb: Layout,
42 T: BlasScalar + ComplexFloat + Zero + One,
43 {
46 fn parallelize(self) -> Self {
47 self
48 }
49
50 fn scale(mut self, factor: T) -> Self {
51 self.alpha = self.alpha * factor;
52 self
53 }
54
55 fn eval(self) -> DTensor<T, 2> {
56 let (m, _) = *self.a.shape();
57 let (_, n) = *self.b.shape();
58 let c = tensor![[MaybeUninit::<T>::uninit(); n]; m];
59 gemm_uninit::<T, La, Lb, Dense>(self.alpha, self.a, self.b, T::zero(), c)
60 }
64
65 fn overwrite<Lc: Layout>(self, c: &mut DSlice<T, 2, Lc>) {
66 gemm(self.alpha, self.a, self.b, T::zero(), c);
67 }
68
69 fn add_to<Lc: Layout>(self, c: &mut DSlice<T, 2, Lc>) {
70 gemm(self.alpha, self.a, self.b, T::one(), c);
71 }
72
73 fn add_to_scaled<Lc: Layout>(self, c: &mut DSlice<T, 2, Lc>, beta: T) {
74 gemm(self.alpha, self.a, self.b, beta, c);
75 }
76
77 fn special(self, lr: Side, type_of_matrix: Type, tr: Triangle) -> DTensor<T, 2> {
78 let (m, _) = *self.a.shape();
79 let (_, n) = *self.b.shape();
80 let c = tensor![[MaybeUninit::<T>::uninit(); n]; m];
81 let cblas_side = match lr {
82 Side::Left => CBLAS_SIDE::CblasLeft,
83 Side::Right => CBLAS_SIDE::CblasRight,
84 };
85 let cblas_triangle = match tr {
86 Triangle::Lower => CBLAS_UPLO::CblasLower,
87 Triangle::Upper => CBLAS_UPLO::CblasUpper,
88 };
89 match type_of_matrix {
90 Type::Her => hemm_uninit::<T, La, Lb, Dense>(
91 self.alpha,
92 self.a,
93 self.b,
94 T::zero(),
95 c,
96 cblas_side,
97 cblas_triangle,
98 ),
99 Type::Sym => symm_uninit::<T, La, Lb, Dense>(
100 self.alpha,
101 self.a,
102 self.b,
103 T::zero(),
104 c,
105 cblas_side,
106 cblas_triangle,
107 ),
108 Type::Tri => {
109 let mut b_copy = DTensor::<T, 2>::from_elem(*self.b.shape(), T::zero());
110 b_copy.assign(self.b);
111 trmm(self.alpha, self.a, &mut b_copy, cblas_side, cblas_triangle);
112 b_copy
113 }
114 }
115 }
116}
117
118impl<'a, T, La, Lb> ContractBuilder<'a, T, La, Lb> for BlasContractBuilder<'a, T, La, Lb>
119where
120 La: Layout,
121 Lb: Layout,
122 T: BlasScalar + ComplexFloat + Zero + One,
123{
124 fn scale(mut self, factor: T) -> Self {
125 self.alpha = self.alpha * factor;
126 self
127 }
128
129 fn eval(self) -> Tensor<T> {
130 _contract(Blas, self.a, self.b, self.axes, self.alpha)
131 }
132
133 fn overwrite(self, _c: &mut Slice<T>) {
134 todo!()
135 }
136}
137
138impl<T> MatMul<T> for Blas
139where
140 T: BlasScalar + ComplexFloat,
141 {
144 fn matmul<'a, La, Lb>(
145 &self,
146 a: &'a DSlice<T, 2, La>,
147 b: &'a DSlice<T, 2, Lb>,
148 ) -> impl MatMulBuilder<'a, T, La, Lb>
149 where
150 La: Layout,
151 Lb: Layout,
152 {
153 BlasMatMulBuilder {
154 alpha: T::one(),
155 a,
156 b,
157 }
158 }
159
160 fn contract_all<'a, La, Lb>(
162 &self,
163 a: &'a Slice<T, DynRank, La>,
164 b: &'a Slice<T, DynRank, Lb>,
165 ) -> impl ContractBuilder<'a, T, La, Lb>
166 where
167 T: 'a,
168 La: Layout,
169 Lb: Layout,
170 {
171 BlasContractBuilder {
172 alpha: T::one(),
173 a,
174 b,
175 axes: Axes::All,
176 }
177 }
178
179 fn contract_n<'a, La: Layout, Lb: Layout>(
183 &self,
184 a: &'a Slice<T, DynRank, La>,
185 b: &'a Slice<T, DynRank, Lb>,
186 n: usize,
187 ) -> impl ContractBuilder<'a, T, La, Lb>
188 where
189 T: 'a,
190 {
191 BlasContractBuilder {
192 alpha: T::one(),
193 a,
194 b,
195 axes: Axes::LastFirst { k: (n) },
196 }
197 }
198
199 fn contract<'a, La: Layout, Lb: Layout>(
204 &self,
205 a: &'a Slice<T, DynRank, La>,
206 b: &'a Slice<T, DynRank, Lb>,
207 axes_a: impl Into<Box<[usize]>>,
208 axes_b: impl Into<Box<[usize]>>,
209 ) -> impl ContractBuilder<'a, T, La, Lb>
210 where
211 T: 'a,
212 {
213 BlasContractBuilder {
214 alpha: T::one(),
215 a,
216 b,
217 axes: Axes::Specific(axes_a.into(), axes_b.into()),
218 }
219 }
220}