// blas_array2/blas3/gemm.rs

1use crate::ffi::{self, blas_int, c_char};
2use crate::util::*;
3use derive_builder::Builder;
4use ndarray::prelude::*;
5
6/* #region BLAS func */
7
8pub trait GEMMNum: BLASFloat {
9    unsafe fn gemm(
10        transa: *const c_char,
11        transb: *const c_char,
12        m: *const blas_int,
13        n: *const blas_int,
14        k: *const blas_int,
15        alpha: *const Self,
16        a: *const Self,
17        lda: *const blas_int,
18        b: *const Self,
19        ldb: *const blas_int,
20        beta: *const Self,
21        c: *mut Self,
22        ldc: *const blas_int,
23    );
24}
25
26macro_rules! impl_func {
27    ($type: ty, $func: ident) => {
28        impl GEMMNum for $type {
29            unsafe fn gemm(
30                transa: *const c_char,
31                transb: *const c_char,
32                m: *const blas_int,
33                n: *const blas_int,
34                k: *const blas_int,
35                alpha: *const Self,
36                a: *const Self,
37                lda: *const blas_int,
38                b: *const Self,
39                ldb: *const blas_int,
40                beta: *const Self,
41                c: *mut Self,
42                ldc: *const blas_int,
43            ) {
44                ffi::$func(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc);
45            }
46        }
47    };
48}
49
50impl_func!(f32, sgemm_);
51impl_func!(f64, dgemm_);
52impl_func!(c32, cgemm_);
53impl_func!(c64, zgemm_);
54
55/* #endregion */
56
57/* #region BLAS driver */
58
59pub struct GEMM_Driver<'a, 'b, 'c, F>
60where
61    F: GEMMNum,
62{
63    transa: c_char,
64    transb: c_char,
65    m: blas_int,
66    n: blas_int,
67    k: blas_int,
68    alpha: F,
69    a: ArrayView2<'a, F>,
70    lda: blas_int,
71    b: ArrayView2<'b, F>,
72    ldb: blas_int,
73    beta: F,
74    c: ArrayOut2<'c, F>,
75    ldc: blas_int,
76}
77
78impl<'a, 'b, 'c, F> BLASDriver<'c, F, Ix2> for GEMM_Driver<'a, 'b, 'c, F>
79where
80    F: GEMMNum,
81{
82    fn run_blas(self) -> Result<ArrayOut2<'c, F>, BLASError> {
83        let Self { transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, mut c, ldc } = self;
84        let a_ptr = a.as_ptr();
85        let b_ptr = b.as_ptr();
86        let c_ptr = c.get_data_mut_ptr();
87
88        // assuming dimension checks has been performed
89        // unconditionally return Ok if output does not contain anything
90        if m == 0 || n == 0 {
91            return Ok(c.clone_to_view_mut());
92        } else if k == 0 {
93            if beta == F::zero() {
94                c.view_mut().fill(F::zero());
95            } else if beta != F::one() {
96                c.view_mut().mapv_inplace(|v| v * beta);
97            }
98            return Ok(c.clone_to_view_mut());
99        }
100
101        unsafe {
102            F::gemm(&transa, &transb, &m, &n, &k, &alpha, a_ptr, &lda, b_ptr, &ldb, &beta, c_ptr, &ldc);
103        }
104        return Ok(c.clone_to_view_mut());
105    }
106}
107
108/* #endregion */
109
110/* #region BLAS builder */
111
112#[derive(Builder)]
113#[builder(pattern = "owned", build_fn(error = "BLASError"), no_std)]
114pub struct GEMM_<'a, 'b, 'c, F>
115where
116    F: GEMMNum,
117{
118    pub a: ArrayView2<'a, F>,
119    pub b: ArrayView2<'b, F>,
120
121    #[builder(setter(into, strip_option), default = "None")]
122    pub c: Option<ArrayViewMut2<'c, F>>,
123    #[builder(setter(into), default = "F::one()")]
124    pub alpha: F,
125    #[builder(setter(into), default = "F::zero()")]
126    pub beta: F,
127    #[builder(setter(into), default = "BLASNoTrans")]
128    pub transa: BLASTranspose,
129    #[builder(setter(into), default = "BLASNoTrans")]
130    pub transb: BLASTranspose,
131    #[builder(setter(into, strip_option), default = "None")]
132    pub layout: Option<BLASLayout>,
133}
134
135impl<'a, 'b, 'c, F> BLASBuilder_<'c, F, Ix2> for GEMM_<'a, 'b, 'c, F>
136where
137    F: GEMMNum,
138{
139    fn driver(self) -> Result<GEMM_Driver<'a, 'b, 'c, F>, BLASError> {
140        let Self { a, b, c, alpha, beta, transa, transb, layout } = self;
141
142        // only fortran-preferred (col-major) is accepted in inner wrapper
143        assert_eq!(layout, Some(BLASColMajor));
144        assert!(a.is_fpref() && b.is_fpref());
145
146        // initialize intent(hide)
147        let (m, k) = match transa {
148            BLASNoTrans => (a.len_of(Axis(0)), a.len_of(Axis(1))),
149            BLASTrans | BLASConjTrans => (a.len_of(Axis(1)), a.len_of(Axis(0))),
150            _ => blas_invalid!(transa)?,
151        };
152        let n = match transb {
153            BLASNoTrans => b.len_of(Axis(1)),
154            BLASTrans | BLASConjTrans => b.len_of(Axis(0)),
155            _ => blas_invalid!(transb)?,
156        };
157        let lda = a.stride_of(Axis(1));
158        let ldb = b.stride_of(Axis(1));
159
160        // perform check
161        match transb {
162            BLASNoTrans => blas_assert_eq!(b.len_of(Axis(0)), k, InvalidDim)?,
163            BLASTrans | BLASConjTrans => blas_assert_eq!(b.len_of(Axis(1)), k, InvalidDim)?,
164            _ => blas_invalid!(transb)?,
165        }
166
167        // optional intent(out)
168        let c = match c {
169            Some(c) => {
170                blas_assert_eq!(c.dim(), (m, n), InvalidDim)?;
171                if c.view().is_fpref() {
172                    ArrayOut2::ViewMut(c)
173                } else {
174                    let c_buffer = c.view().to_col_layout()?.into_owned();
175                    ArrayOut2::ToBeCloned(c, c_buffer)
176                }
177            },
178            None => ArrayOut2::Owned(Array2::zeros((m, n).f())),
179        };
180        let ldc = c.view().stride_of(Axis(1));
181
182        // finalize
183        let driver = GEMM_Driver {
184            transa: transa.try_into()?,
185            transb: transb.try_into()?,
186            m: m.try_into()?,
187            n: n.try_into()?,
188            k: k.try_into()?,
189            alpha,
190            a,
191            lda: lda.try_into()?,
192            b,
193            ldb: ldb.try_into()?,
194            beta,
195            c,
196            ldc: ldc.try_into()?,
197        };
198        return Ok(driver);
199    }
200}
201
202/* #endregion */
203
204/* #region BLAS wrapper */
205
206pub type GEMM<'a, 'b, 'c, F> = GEMM_Builder<'a, 'b, 'c, F>;
207pub type SGEMM<'a, 'b, 'c> = GEMM<'a, 'b, 'c, f32>;
208pub type DGEMM<'a, 'b, 'c> = GEMM<'a, 'b, 'c, f64>;
209pub type CGEMM<'a, 'b, 'c> = GEMM<'a, 'b, 'c, c32>;
210pub type ZGEMM<'a, 'b, 'c> = GEMM<'a, 'b, 'c, c64>;
211
212impl<'a, 'b, 'c, F> BLASBuilder<'c, F, Ix2> for GEMM_Builder<'a, 'b, 'c, F>
213where
214    F: GEMMNum,
215{
216    fn run(self) -> Result<ArrayOut2<'c, F>, BLASError> {
217        // initialize
218        let GEMM_ { a, b, c, alpha, beta, transa, transb, layout } = self.build()?;
219        let at = a.t();
220        let bt = b.t();
221
222        let layout_a = get_layout_array2(&a);
223        let layout_b = get_layout_array2(&b);
224        let layout_c = c.as_ref().map(|c| get_layout_array2(&c.view()));
225
226        let layout = get_layout_row_preferred(&[layout, layout_c], &[layout_a, layout_b]);
227        if layout == BLASColMajor {
228            // F-contiguous: C = op(A) op(B)
229            let (transa, a_cow) = flip_trans_fpref(transa, &a, &at, false)?;
230            let (transb, b_cow) = flip_trans_fpref(transb, &b, &bt, false)?;
231            let obj = GEMM_ {
232                a: a_cow.view(),
233                b: b_cow.view(),
234                c,
235                alpha,
236                beta,
237                transa,
238                transb,
239                layout: Some(BLASColMajor),
240            };
241            return obj.driver()?.run_blas();
242        } else if layout == BLASRowMajor {
243            // C-contiguous: C' = op(B') op(A')
244            let (transa, a_cow) = flip_trans_cpref(transa, &a, &at, false)?;
245            let (transb, b_cow) = flip_trans_cpref(transb, &b, &bt, false)?;
246            let obj = GEMM_ {
247                a: b_cow.t(),
248                b: a_cow.t(),
249                c: c.map(|c| c.reversed_axes()),
250                alpha,
251                beta,
252                transa: transb,
253                transb: transa,
254                layout: Some(BLASColMajor),
255            };
256            return Ok(obj.driver()?.run_blas()?.reversed_axes());
257        } else {
258            return blas_raise!(RuntimeError, "This is designed not to execuate this line.");
259        }
260    }
261}
262
263/* #endregion */