1#[cfg(feature = "blas")]
4#[cfg(feature = "cpu")]
5use super::cpu::{
6 api::{cblas_dgemm, cblas_sgemm},
7 Order, Transpose,
8};
9
10#[cfg(feature = "cuda")]
11use super::cuda::api::{
12 cublas::{cublasDgemm_v2, cublasOperation_t, cublasSgemm_v2, CublasHandle},
13 CUdeviceptr,
14};
15
/// GEMM (general matrix multiplication) entry points, dispatched on the element type.
///
/// The CPU methods are only compiled when both the `blas` and `cpu` features are
/// enabled; `cugemm` is only compiled with the `cuda` feature.
pub trait GenericBlas
where
    Self: Sized,
{
    /// Hands a general matrix multiplication to the CBLAS `gemm` routine for `Self`.
    ///
    /// The `f32` and `f64` implementations forward every argument unchanged to
    /// `cblas_sgemm` / `cblas_dgemm` with `alpha = 1.0` and `beta = 0.0`.
    ///
    /// Nothing is validated at this level: `m`, `n`, `k` and the leading
    /// dimensions `lda`, `ldb`, `ldc` must describe matrices that fit inside
    /// `a`, `b` and `c`. Prefer `gemm`, `gemmT` or `Tgemm`, which check the
    /// slice lengths before dispatching.
    #[cfg(feature = "blas")]
    #[cfg(feature = "cpu")]
    #[allow(clippy::too_many_arguments)]
    fn blas_gemm(
        order: Order,
        trans_a: Transpose,
        trans_b: Transpose,
        m: usize,
        n: usize,
        k: usize,
        a: &[Self],
        lda: usize,
        b: &[Self],
        ldb: usize,
        c: &mut [Self],
        ldc: usize,
    );

    /// Row-major product of an `m x k` matrix `a` and a `k x n` matrix `b`,
    /// written to the `m x n` matrix `c`.
    ///
    /// # Panics
    /// Panics if `a`, `b` or `c` holds fewer than `m * k`, `k * n` or `m * n`
    /// elements respectively; without this check the BLAS routine would access
    /// memory outside the slices.
    #[cfg(feature = "blas")]
    #[cfg(feature = "cpu")]
    #[inline]
    fn gemm(m: usize, n: usize, k: usize, a: &[Self], b: &[Self], c: &mut [Self]) {
        assert!(holds_matrix(a.len(), m, k), "gemm: `a` holds fewer than m * k elements");
        assert!(holds_matrix(b.len(), k, n), "gemm: `b` holds fewer than k * n elements");
        assert!(holds_matrix(c.len(), m, n), "gemm: `c` holds fewer than m * n elements");

        Self::blas_gemm(
            Order::RowMajor,
            Transpose::NoTrans,
            Transpose::NoTrans,
            m,
            n,
            k,
            a,
            k,
            b,
            n,
            c,
            n,
        )
    }

    /// Like `gemm`, but `b` is passed with `Transpose::Trans`: it is stored
    /// row-major as `n x k` (leading dimension `k`).
    ///
    /// # Panics
    /// Panics if `a`, `b` or `c` holds fewer than `m * k`, `n * k` or `m * n`
    /// elements respectively.
    #[cfg(feature = "blas")]
    #[cfg(feature = "cpu")]
    #[inline]
    #[allow(non_snake_case)]
    fn gemmT(m: usize, n: usize, k: usize, a: &[Self], b: &[Self], c: &mut [Self]) {
        assert!(holds_matrix(a.len(), m, k), "gemmT: `a` holds fewer than m * k elements");
        assert!(holds_matrix(b.len(), n, k), "gemmT: `b` holds fewer than n * k elements");
        assert!(holds_matrix(c.len(), m, n), "gemmT: `c` holds fewer than m * n elements");

        Self::blas_gemm(
            Order::RowMajor,
            Transpose::NoTrans,
            Transpose::Trans,
            m,
            n,
            k,
            a,
            k,
            b,
            k,
            c,
            n,
        )
    }

    /// Like `gemm`, but `a` is passed with `Transpose::Trans`: it is stored
    /// row-major as `k x m` (leading dimension `m`).
    ///
    /// # Panics
    /// Panics if `a`, `b` or `c` holds fewer than `k * m`, `k * n` or `m * n`
    /// elements respectively.
    #[cfg(feature = "blas")]
    #[cfg(feature = "cpu")]
    #[inline]
    #[allow(non_snake_case)]
    fn Tgemm(m: usize, n: usize, k: usize, a: &[Self], b: &[Self], c: &mut [Self]) {
        assert!(holds_matrix(a.len(), k, m), "Tgemm: `a` holds fewer than k * m elements");
        assert!(holds_matrix(b.len(), k, n), "Tgemm: `b` holds fewer than k * n elements");
        assert!(holds_matrix(c.len(), m, n), "Tgemm: `c` holds fewer than m * n elements");

        Self::blas_gemm(
            Order::RowMajor,
            Transpose::Trans,
            Transpose::NoTrans,
            m,
            n,
            k,
            a,
            m,
            b,
            n,
            c,
            n,
        )
    }

    /// Runs the cuBLAS `gemm` routine for `Self` on the device buffers `a`, `b`
    /// and `c`, using the given cuBLAS `handle`.
    ///
    /// `m`, `n` and `k` are the matrix dimensions; the buffer sizes are not
    /// checked here.
    ///
    /// # Errors
    /// The `f32` and `f64` implementations return the error produced by
    /// converting the cuBLAS status with `to_result()`.
    #[cfg(feature = "cuda")]
    fn cugemm(
        handle: &CublasHandle,
        m: usize,
        n: usize,
        k: usize,
        a: CUdeviceptr,
        b: CUdeviceptr,
        c: CUdeviceptr,
    ) -> crate::Result<()>;
}

/// Returns `true` when a slice of `len` elements can hold a `rows x cols` matrix.
///
/// An overflowing `rows * cols` counts as "does not fit", so the check cannot
/// be bypassed by wrap-around in release builds.
#[cfg(feature = "blas")]
#[cfg(feature = "cpu")]
#[inline]
fn holds_matrix(len: usize, rows: usize, cols: usize) -> bool {
    matches!(rows.checked_mul(cols), Some(needed) if len >= needed)
}
119
120impl GenericBlas for f32 {
121 #[cfg(feature = "blas")]
122 #[cfg(feature = "cpu")]
123 #[inline]
124 fn blas_gemm(
125 order: Order,
126 trans_a: Transpose,
127 trans_b: Transpose,
128 m: usize,
129 n: usize,
130 k: usize,
131 a: &[Self],
132 lda: usize,
133 b: &[Self],
134 ldb: usize,
135 c: &mut [Self],
136 ldc: usize,
137 ) {
138 unsafe {
139 cblas_sgemm(
140 order,
141 trans_a,
142 trans_b,
143 m,
144 n,
145 k,
146 1.0,
147 a.as_ptr(),
148 lda,
149 b.as_ptr(),
150 ldb,
151 0.0,
152 c.as_mut_ptr(),
153 ldc,
154 )
155 };
156 }
157 #[cfg(feature = "cuda")]
158 #[inline]
159 fn cugemm(
160 handle: &CublasHandle,
161 m: usize,
162 n: usize,
163 k: usize,
164 a: CUdeviceptr,
165 b: CUdeviceptr,
166 c: CUdeviceptr,
167 ) -> crate::Result<()> {
168 unsafe {
169 cublasSgemm_v2(
170 handle.0,
171 cublasOperation_t::CUBLAS_OP_N,
172 cublasOperation_t::CUBLAS_OP_N,
173 n as i32,
174 m as i32,
175 k as i32,
176 &1f32 as *const f32,
177 b as *const u64 as *const f32,
178 n as i32,
179 a as *const u64 as *const f32,
180 k as i32,
181 &0f32 as *const f32,
182 c as *mut u64 as *mut f32,
183 n as i32,
184 )
185 }
186 .to_result()?;
187 Ok(())
188 }
189}
190
191impl GenericBlas for f64 {
192 #[cfg(feature = "blas")]
193 #[cfg(feature = "cpu")]
194 #[inline]
195 fn blas_gemm(
196 order: Order,
197 trans_a: Transpose,
198 trans_b: Transpose,
199 m: usize,
200 n: usize,
201 k: usize,
202 a: &[Self],
203 lda: usize,
204 b: &[Self],
205 ldb: usize,
206 c: &mut [Self],
207 ldc: usize,
208 ) {
209 unsafe {
210 cblas_dgemm(
211 order,
212 trans_a,
213 trans_b,
214 m,
215 n,
216 k,
217 1.0,
218 a.as_ptr(),
219 lda,
220 b.as_ptr(),
221 ldb,
222 0.0,
223 c.as_mut_ptr(),
224 ldc,
225 )
226 };
227 }
228 #[cfg(feature = "cuda")]
229 #[inline]
230 fn cugemm(
231 handle: &CublasHandle,
232 m: usize,
233 n: usize,
234 k: usize,
235 a: CUdeviceptr,
236 b: CUdeviceptr,
237 c: CUdeviceptr,
238 ) -> crate::Result<()> {
239 unsafe {
240 cublasDgemm_v2(
241 handle.0,
242 cublasOperation_t::CUBLAS_OP_N,
243 cublasOperation_t::CUBLAS_OP_N,
244 n as i32,
245 m as i32,
246 k as i32,
247 &1f64 as *const f64,
248 b as *const u64 as *const f64,
249 n as i32,
250 a as *const u64 as *const f64,
251 k as i32,
252 &0f64 as *const f64,
253 c as *mut u64 as *mut f64,
254 n as i32,
255 )
256 }
257 .to_result()?;
258 Ok(())
259 }
260}