1use crate::error::{CudaError, CudaResult};
4use crate::ffi::cublasHandle_t;
5use crate::stream::CudaStream;
6use crate::tensor::CudaTensor;
7
/// Thin RAII wrapper around a cuBLAS library handle.
///
/// With the `cuda` feature enabled, the handle is created via
/// `cublasCreate_v2` and released on `Drop`. Without the feature the handle
/// is a null stub and every operation runs a scalar CPU fallback instead.
pub struct CuBlas {
    // Only dereferenced through FFI when the `cuda` feature is active, so
    // CPU-only builds would otherwise warn that the field is never read.
    #[allow(dead_code)]
    handle: cublasHandle_t,
}
13
impl CuBlas {
    /// Creates a cuBLAS context.
    ///
    /// With the `cuda` feature this calls `cublasCreate_v2` and surfaces a
    /// non-zero status as `CudaError::CublasError(status)`. Without the
    /// feature a stub with a null handle is returned so the API remains
    /// usable on machines without CUDA (all entry points then use their CPU
    /// fallbacks and never touch the handle).
    pub fn new() -> CudaResult<Self> {
        #[cfg(feature = "cuda")]
        {
            let mut handle: cublasHandle_t = std::ptr::null_mut();

            unsafe {
                // Status 0 corresponds to success; anything else is mapped
                // to CublasError carrying the raw status code.
                let status = ffi::cublasCreate_v2(&mut handle);
                if status != 0 {
                    return Err(CudaError::CublasError(status));
                }
            }

            Ok(CuBlas { handle })
        }

        #[cfg(not(feature = "cuda"))]
        {
            // CPU-only build: the null handle is never dereferenced.
            Ok(CuBlas {
                handle: std::ptr::null_mut(),
            })
        }
    }

    /// Directs subsequent cuBLAS calls on this handle to `stream`.
    ///
    /// A no-op (always `Ok`) when built without the `cuda` feature, hence
    /// the `allow(unused_variables)` for that configuration.
    #[cfg_attr(not(feature = "cuda"), allow(unused_variables))]
    pub fn set_stream(&self, stream: &CudaStream) -> CudaResult<()> {
        #[cfg(feature = "cuda")]
        unsafe {
            let status = ffi::cublasSetStream_v2(self.handle, stream.handle());
            if status != 0 {
                return Err(CudaError::CublasError(status));
            }
        }
        Ok(())
    }

    /// Single-precision GEMM: `C = alpha * op(A) * op(B) + beta * C`.
    ///
    /// Follows BLAS conventions: `op(A)` is `m x k`, `op(B)` is `k x n`,
    /// `C` is `m x n`, and `lda`/`ldb`/`ldc` are leading dimensions. Storage
    /// is column-major — the CPU fallback below indexes `i + j * ldc`,
    /// matching the cuBLAS convention.
    ///
    /// * `trans_a` / `trans_b` — transpose `A` / `B` before multiplying.
    ///
    /// NOTE(review): this dereferences raw pointers but is not an
    /// `unsafe fn`; the caller must guarantee the buffers are valid for the
    /// given shapes (device memory with `cuda`, host memory for the
    /// fallback). The fallback also assumes non-negative dims/strides —
    /// `as usize` would wrap otherwise — confirm callers uphold this.
    pub fn sgemm(
        &self,
        trans_a: bool,
        trans_b: bool,
        m: i32,
        n: i32,
        k: i32,
        alpha: f32,
        a: *const f32,
        lda: i32,
        b: *const f32,
        ldb: i32,
        beta: f32,
        c: *mut f32,
        ldc: i32,
    ) -> CudaResult<()> {
        #[cfg(feature = "cuda")]
        unsafe {
            // Map the boolean flags onto the cuBLAS operation enums.
            let op_a = if trans_a { CUBLAS_OP_T } else { CUBLAS_OP_N };
            let op_b = if trans_b { CUBLAS_OP_T } else { CUBLAS_OP_N };

            let status = ffi::cublasSgemm_v2(
                self.handle,
                op_a,
                op_b,
                m, n, k,
                &alpha,
                a, lda,
                b, ldb,
                &beta,
                c, ldc,
            );

            if status != 0 {
                return Err(CudaError::CublasError(status));
            }
        }

        #[cfg(not(feature = "cuda"))]
        {
            // Naive triple-loop reference implementation (column-major).
            unsafe {
                for i in 0..m as usize {
                    for j in 0..n as usize {
                        // BLAS semantics: when beta == 0, C is write-only
                        // and must not be read (it may be uninitialized).
                        let mut sum = if beta != 0.0 {
                            beta * *c.add(i + j * ldc as usize)
                        } else {
                            0.0
                        };

                        for l in 0..k as usize {
                            // For op = T the roles of row/column swap in the
                            // index expression; both stay column-major.
                            let a_idx = if trans_a { l + i * lda as usize } else { i + l * lda as usize };
                            let b_idx = if trans_b { j + l * ldb as usize } else { l + j * ldb as usize };
                            sum += alpha * *a.add(a_idx) * *b.add(b_idx);
                        }

                        *c.add(i + j * ldc as usize) = sum;
                    }
                }
            }
        }

        Ok(())
    }

    /// Single-precision AXPY: `y[i] = alpha * x[i] + y[i]` for `n` elements
    /// read with strides `incx` / `incy`.
    ///
    /// NOTE(review): the CPU fallback casts the increments to `usize`, so
    /// unlike full BLAS it does not support negative strides — confirm no
    /// caller relies on them.
    pub fn saxpy(
        &self,
        n: i32,
        alpha: f32,
        x: *const f32,
        incx: i32,
        y: *mut f32,
        incy: i32,
    ) -> CudaResult<()> {
        #[cfg(feature = "cuda")]
        unsafe {
            let status = ffi::cublasSaxpy_v2(
                self.handle,
                n,
                &alpha,
                x, incx,
                y, incy,
            );

            if status != 0 {
                return Err(CudaError::CublasError(status));
            }
        }

        #[cfg(not(feature = "cuda"))]
        unsafe {
            for i in 0..n as usize {
                let xi = *x.add(i * incx as usize);
                let yi = y.add(i * incy as usize);
                *yi = alpha * xi + *yi;
            }
        }

        Ok(())
    }

    /// Single-precision dot product of `n` strided elements of `x` and `y`.
    ///
    /// Returns the accumulated sum; the cuBLAS path writes the result back
    /// to host memory through the `&mut result` out-parameter.
    pub fn sdot(
        &self,
        n: i32,
        x: *const f32,
        incx: i32,
        y: *const f32,
        incy: i32,
    ) -> CudaResult<f32> {
        let mut result: f32 = 0.0;

        #[cfg(feature = "cuda")]
        unsafe {
            let status = ffi::cublasSdot_v2(
                self.handle,
                n,
                x, incx,
                y, incy,
                &mut result,
            );

            if status != 0 {
                return Err(CudaError::CublasError(status));
            }
        }

        #[cfg(not(feature = "cuda"))]
        unsafe {
            for i in 0..n as usize {
                result += *x.add(i * incx as usize) * *y.add(i * incy as usize);
            }
        }

        Ok(result)
    }

    /// Euclidean (L2) norm of `n` strided elements of `x`.
    ///
    /// NOTE(review): the CPU fallback is a plain sum-of-squares + sqrt and
    /// can overflow to `inf` for large magnitudes, where BLAS `snrm2`
    /// implementations typically scale to avoid that — acceptable for a
    /// stub, but worth confirming if the fallback is used in production.
    pub fn snrm2(
        &self,
        n: i32,
        x: *const f32,
        incx: i32,
    ) -> CudaResult<f32> {
        #[cfg(feature = "cuda")]
        {
            let mut result: f32 = 0.0;
            unsafe {
                let status = ffi::cublasSnrm2_v2(
                    self.handle,
                    n,
                    x, incx,
                    &mut result,
                );

                if status != 0 {
                    return Err(CudaError::CublasError(status));
                }
            }
            Ok(result)
        }

        #[cfg(not(feature = "cuda"))]
        unsafe {
            let mut sum = 0.0f32;
            for i in 0..n as usize {
                let xi = *x.add(i * incx as usize);
                sum += xi * xi;
            }
            Ok(sum.sqrt())
        }
    }

    /// In-place scale: `x[i] = alpha * x[i]` for `n` strided elements.
    pub fn sscal(
        &self,
        n: i32,
        alpha: f32,
        x: *mut f32,
        incx: i32,
    ) -> CudaResult<()> {
        #[cfg(feature = "cuda")]
        unsafe {
            let status = ffi::cublasSscal_v2(
                self.handle,
                n,
                &alpha,
                x, incx,
            );

            if status != 0 {
                return Err(CudaError::CublasError(status));
            }
        }

        #[cfg(not(feature = "cuda"))]
        unsafe {
            for i in 0..n as usize {
                let xi = x.add(i * incx as usize);
                *xi = alpha * *xi;
            }
        }

        Ok(())
    }

    /// 2-D matrix multiply on row-major tensors: returns `C = A x B`.
    ///
    /// # Errors
    /// `CudaError::InvalidValue` if either tensor is not 2-D or the inner
    /// dimensions disagree; any error propagated from `sgemm`.
    pub fn matmul(&self, a: &CudaTensor, b: &CudaTensor) -> CudaResult<CudaTensor> {
        let a_dims = a.dims();
        let b_dims = b.dims();

        if a_dims.len() != 2 || b_dims.len() != 2 {
            return Err(CudaError::InvalidValue("matmul requires 2D tensors".into()));
        }

        // A is [m, k] and B must be [k, n] for the product C = [m, n].
        let m = a_dims[0] as i32;
        let k = a_dims[1] as i32;
        let n = b_dims[1] as i32;

        if k != b_dims[0] as i32 {
            return Err(CudaError::InvalidValue(format!(
                "Matrix dimensions don't match: [{}, {}] x [{}, {}]",
                m, k, b_dims[0], n
            )));
        }

        // beta = 0 below means sgemm never reads C, so zeroing is strictly
        // belt-and-braces; `zeros` also performs the output allocation.
        let mut c = CudaTensor::zeros(&[m as usize, n as usize], a.device_id())?;

        // cuBLAS is column-major while these tensors are row-major. Passing
        // the operands swapped with swapped dims computes C^T = B^T * A^T in
        // column-major terms, which is exactly C = A * B when the buffers
        // are reinterpreted as row-major — no explicit transposes needed.
        self.sgemm(
            false, false,
            n, m, k,
            1.0,
            b.as_ptr() as *const f32, n,
            a.as_ptr() as *const f32, k,
            0.0,
            c.as_mut_ptr() as *mut f32, n,
        )?;

        Ok(c)
    }
}
315
316impl Drop for CuBlas {
317 fn drop(&mut self) {
318 #[cfg(feature = "cuda")]
319 if !self.handle.is_null() {
320 unsafe {
321 let _ = ffi::cublasDestroy_v2(self.handle);
322 }
323 }
324 }
325}
326
// SAFETY: `CuBlas` only stores an opaque handle (a raw pointer), which is
// what blocks the auto-traits. NOTE(review): sharing one cuBLAS handle
// across threads relies on the cuBLAS library's documented thread-safety —
// confirm against the cuBLAS version in use before depending on `Sync`.
unsafe impl Send for CuBlas {}
unsafe impl Sync for CuBlas {}