rstsr_openblas/driver_impl/cblas/blas3/
gemm.rs1use crate::lapack_ffi;
2use crate::DeviceBLAS;
3use duplicate::duplicate_item;
4use num::Complex;
5use rstsr_blas_traits::blas3::gemm::*;
6use rstsr_common::prelude::*;
7
8#[duplicate_item(
9 T cblas_func ;
10 [f32] [cblas_sgemm];
11 [f64] [cblas_dgemm];
12)]
13impl GEMMDriverAPI<T> for DeviceBLAS {
14 unsafe fn driver_gemm(
15 order: FlagOrder,
16 transa: FlagTrans,
17 transb: FlagTrans,
18 m: usize,
19 n: usize,
20 k: usize,
21 alpha: T,
22 a: *const T,
23 lda: usize,
24 b: *const T,
25 ldb: usize,
26 beta: T,
27 c: *mut T,
28 ldc: usize,
29 ) {
30 lapack_ffi::cblas::cblas_func(
31 order.into(),
32 transa.into(),
33 transb.into(),
34 m as _,
35 n as _,
36 k as _,
37 alpha,
38 a,
39 lda as _,
40 b,
41 ldb as _,
42 beta,
43 c,
44 ldc as _,
45 );
46 }
47}
48
49#[duplicate_item(
50 T cblas_func ;
51 [Complex<f32>] [cblas_cgemm];
52 [Complex<f64>] [cblas_zgemm];
53)]
54impl GEMMDriverAPI<T> for DeviceBLAS {
55 unsafe fn driver_gemm(
56 order: FlagOrder,
57 transa: FlagTrans,
58 transb: FlagTrans,
59 m: usize,
60 n: usize,
61 k: usize,
62 alpha: T,
63 a: *const T,
64 lda: usize,
65 b: *const T,
66 ldb: usize,
67 beta: T,
68 c: *mut T,
69 ldc: usize,
70 ) {
71 lapack_ffi::cblas::cblas_func(
72 order.into(),
73 transa.into(),
74 transb.into(),
75 m as _,
76 n as _,
77 k as _,
78 &alpha as *const _ as *const _,
79 a as *const _,
80 lda as _,
81 b as *const _,
82 ldb as _,
83 &beta as *const _ as *const _,
84 c as *mut _,
85 ldc as _,
86 );
87 }
88}