rstsr_openblas/driver_impl/cblas/blas3/
gemm.rs

1use crate::lapack_ffi;
2use crate::DeviceBLAS;
3use duplicate::duplicate_item;
4use num::Complex;
5use rstsr_blas_traits::blas3::gemm::*;
6use rstsr_common::prelude::*;
7
8#[duplicate_item(
9    T     cblas_func  ;
10   [f32] [cblas_sgemm];
11   [f64] [cblas_dgemm];
12)]
13impl GEMMDriverAPI<T> for DeviceBLAS {
14    unsafe fn driver_gemm(
15        order: FlagOrder,
16        transa: FlagTrans,
17        transb: FlagTrans,
18        m: usize,
19        n: usize,
20        k: usize,
21        alpha: T,
22        a: *const T,
23        lda: usize,
24        b: *const T,
25        ldb: usize,
26        beta: T,
27        c: *mut T,
28        ldc: usize,
29    ) {
30        lapack_ffi::cblas::cblas_func(
31            order.into(),
32            transa.into(),
33            transb.into(),
34            m as _,
35            n as _,
36            k as _,
37            alpha,
38            a,
39            lda as _,
40            b,
41            ldb as _,
42            beta,
43            c,
44            ldc as _,
45        );
46    }
47}
48
49#[duplicate_item(
50    T              cblas_func  ;
51   [Complex<f32>] [cblas_cgemm];
52   [Complex<f64>] [cblas_zgemm];
53)]
54impl GEMMDriverAPI<T> for DeviceBLAS {
55    unsafe fn driver_gemm(
56        order: FlagOrder,
57        transa: FlagTrans,
58        transb: FlagTrans,
59        m: usize,
60        n: usize,
61        k: usize,
62        alpha: T,
63        a: *const T,
64        lda: usize,
65        b: *const T,
66        ldb: usize,
67        beta: T,
68        c: *mut T,
69        ldc: usize,
70    ) {
71        lapack_ffi::cblas::cblas_func(
72            order.into(),
73            transa.into(),
74            transb.into(),
75            m as _,
76            n as _,
77            k as _,
78            &alpha as *const _ as *const _,
79            a as *const _,
80            lda as _,
81            b as *const _,
82            ldb as _,
83            &beta as *const _ as *const _,
84            c as *mut _,
85            ldc as _,
86        );
87    }
88}