1use crate::ffi::{self, blas_int, c_char};
2use crate::util::*;
3use derive_builder::Builder;
4use ndarray::prelude::*;
5
6pub trait GEMMNum: BLASFloat {
9 unsafe fn gemm(
10 transa: *const c_char,
11 transb: *const c_char,
12 m: *const blas_int,
13 n: *const blas_int,
14 k: *const blas_int,
15 alpha: *const Self,
16 a: *const Self,
17 lda: *const blas_int,
18 b: *const Self,
19 ldb: *const blas_int,
20 beta: *const Self,
21 c: *mut Self,
22 ldc: *const blas_int,
23 );
24}
25
26macro_rules! impl_func {
27 ($type: ty, $func: ident) => {
28 impl GEMMNum for $type {
29 unsafe fn gemm(
30 transa: *const c_char,
31 transb: *const c_char,
32 m: *const blas_int,
33 n: *const blas_int,
34 k: *const blas_int,
35 alpha: *const Self,
36 a: *const Self,
37 lda: *const blas_int,
38 b: *const Self,
39 ldb: *const blas_int,
40 beta: *const Self,
41 c: *mut Self,
42 ldc: *const blas_int,
43 ) {
44 ffi::$func(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
45 }
46 }
47 };
48}
49
50impl_func!(f32, sgemm_);
51impl_func!(f64, dgemm_);
52impl_func!(c32, cgemm_);
53impl_func!(c64, zgemm_);
54
55pub struct GEMM_Driver<'a, 'b, 'c, F>
60where
61 F: GEMMNum,
62{
63 transa: c_char,
64 transb: c_char,
65 m: blas_int,
66 n: blas_int,
67 k: blas_int,
68 alpha: F,
69 a: ArrayView2<'a, F>,
70 lda: blas_int,
71 b: ArrayView2<'b, F>,
72 ldb: blas_int,
73 beta: F,
74 c: ArrayOut2<'c, F>,
75 ldc: blas_int,
76}
77
78impl<'a, 'b, 'c, F> BLASDriver<'c, F, Ix2> for GEMM_Driver<'a, 'b, 'c, F>
79where
80 F: GEMMNum,
81{
82 fn run_blas(self) -> Result<ArrayOut2<'c, F>, BLASError> {
83 let Self { transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, mut c, ldc } = self;
84 let a_ptr = a.as_ptr();
85 let b_ptr = b.as_ptr();
86 let c_ptr = c.get_data_mut_ptr();
87
88 if m == 0 || n == 0 {
91 return Ok(c.clone_to_view_mut());
92 } else if k == 0 {
93 if beta == F::zero() {
94 c.view_mut().fill(F::zero());
95 } else if beta != F::one() {
96 c.view_mut().mapv_inplace(|v| v * beta);
97 }
98 return Ok(c.clone_to_view_mut());
99 }
100
101 unsafe {
102 F::gemm(&transa, &transb, &m, &n, &k, &alpha, a_ptr, &lda, b_ptr, &ldb, &beta, c_ptr, &ldc);
103 }
104 return Ok(c.clone_to_view_mut());
105 }
106}
107
108#[derive(Builder)]
113#[builder(pattern = "owned", build_fn(error = "BLASError"), no_std)]
114pub struct GEMM_<'a, 'b, 'c, F>
115where
116 F: GEMMNum,
117{
118 pub a: ArrayView2<'a, F>,
119 pub b: ArrayView2<'b, F>,
120
121 #[builder(setter(into, strip_option), default = "None")]
122 pub c: Option<ArrayViewMut2<'c, F>>,
123 #[builder(setter(into), default = "F::one()")]
124 pub alpha: F,
125 #[builder(setter(into), default = "F::zero()")]
126 pub beta: F,
127 #[builder(setter(into), default = "BLASNoTrans")]
128 pub transa: BLASTranspose,
129 #[builder(setter(into), default = "BLASNoTrans")]
130 pub transb: BLASTranspose,
131 #[builder(setter(into, strip_option), default = "None")]
132 pub layout: Option<BLASLayout>,
133}
134
135impl<'a, 'b, 'c, F> BLASBuilder_<'c, F, Ix2> for GEMM_<'a, 'b, 'c, F>
136where
137 F: GEMMNum,
138{
139 fn driver(self) -> Result<GEMM_Driver<'a, 'b, 'c, F>, BLASError> {
140 let Self { a, b, c, alpha, beta, transa, transb, layout } = self;
141
142 assert_eq!(layout, Some(BLASColMajor));
144 assert!(a.is_fpref() && b.is_fpref());
145
146 let (m, k) = match transa {
148 BLASNoTrans => (a.len_of(Axis(0)), a.len_of(Axis(1))),
149 BLASTrans | BLASConjTrans => (a.len_of(Axis(1)), a.len_of(Axis(0))),
150 _ => blas_invalid!(transa)?,
151 };
152 let n = match transb {
153 BLASNoTrans => b.len_of(Axis(1)),
154 BLASTrans | BLASConjTrans => b.len_of(Axis(0)),
155 _ => blas_invalid!(transb)?,
156 };
157 let lda = a.stride_of(Axis(1));
158 let ldb = b.stride_of(Axis(1));
159
160 match transb {
162 BLASNoTrans => blas_assert_eq!(b.len_of(Axis(0)), k, InvalidDim)?,
163 BLASTrans | BLASConjTrans => blas_assert_eq!(b.len_of(Axis(1)), k, InvalidDim)?,
164 _ => blas_invalid!(transb)?,
165 }
166
167 let c = match c {
169 Some(c) => {
170 blas_assert_eq!(c.dim(), (m, n), InvalidDim)?;
171 if c.view().is_fpref() {
172 ArrayOut2::ViewMut(c)
173 } else {
174 let c_buffer = c.view().to_col_layout()?.into_owned();
175 ArrayOut2::ToBeCloned(c, c_buffer)
176 }
177 },
178 None => ArrayOut2::Owned(Array2::zeros((m, n).f())),
179 };
180 let ldc = c.view().stride_of(Axis(1));
181
182 let driver = GEMM_Driver {
184 transa: transa.try_into()?,
185 transb: transb.try_into()?,
186 m: m.try_into()?,
187 n: n.try_into()?,
188 k: k.try_into()?,
189 alpha,
190 a,
191 lda: lda.try_into()?,
192 b,
193 ldb: ldb.try_into()?,
194 beta,
195 c,
196 ldc: ldc.try_into()?,
197 };
198 return Ok(driver);
199 }
200}
201
202pub type GEMM<'a, 'b, 'c, F> = GEMM_Builder<'a, 'b, 'c, F>;
207pub type SGEMM<'a, 'b, 'c> = GEMM<'a, 'b, 'c, f32>;
208pub type DGEMM<'a, 'b, 'c> = GEMM<'a, 'b, 'c, f64>;
209pub type CGEMM<'a, 'b, 'c> = GEMM<'a, 'b, 'c, c32>;
210pub type ZGEMM<'a, 'b, 'c> = GEMM<'a, 'b, 'c, c64>;
211
212impl<'a, 'b, 'c, F> BLASBuilder<'c, F, Ix2> for GEMM_Builder<'a, 'b, 'c, F>
213where
214 F: GEMMNum,
215{
216 fn run(self) -> Result<ArrayOut2<'c, F>, BLASError> {
217 let GEMM_ { a, b, c, alpha, beta, transa, transb, layout } = self.build()?;
219 let at = a.t();
220 let bt = b.t();
221
222 let layout_a = get_layout_array2(&a);
223 let layout_b = get_layout_array2(&b);
224 let layout_c = c.as_ref().map(|c| get_layout_array2(&c.view()));
225
226 let layout = get_layout_row_preferred(&[layout, layout_c], &[layout_a, layout_b]);
227 if layout == BLASColMajor {
228 let (transa, a_cow) = flip_trans_fpref(transa, &a, &at, false)?;
230 let (transb, b_cow) = flip_trans_fpref(transb, &b, &bt, false)?;
231 let obj = GEMM_ {
232 a: a_cow.view(),
233 b: b_cow.view(),
234 c,
235 alpha,
236 beta,
237 transa,
238 transb,
239 layout: Some(BLASColMajor),
240 };
241 return obj.driver()?.run_blas();
242 } else if layout == BLASRowMajor {
243 let (transa, a_cow) = flip_trans_cpref(transa, &a, &at, false)?;
245 let (transb, b_cow) = flip_trans_cpref(transb, &b, &bt, false)?;
246 let obj = GEMM_ {
247 a: b_cow.t(),
248 b: a_cow.t(),
249 c: c.map(|c| c.reversed_axes()),
250 alpha,
251 beta,
252 transa: transb,
253 transb: transa,
254 layout: Some(BLASColMajor),
255 };
256 return Ok(obj.driver()?.run_blas()?.reversed_axes());
257 } else {
258 return blas_raise!(RuntimeError, "This is designed not to execuate this line.");
259 }
260 }
261}
262
263