// ghostflow_cuda/blas.rs

1//! cuBLAS wrapper for linear algebra operations
2
3use crate::error::{CudaError, CudaResult};
4use crate::ffi::cublasHandle_t;
5use crate::stream::CudaStream;
6use crate::tensor::CudaTensor;
7
8/// cuBLAS handle wrapper
/// cuBLAS handle wrapper.
///
/// Owns a raw `cublasHandle_t` created in [`CuBlas::new`] and released in the
/// `Drop` impl at the bottom of this file. When the crate is built without the
/// `cuda` feature the handle is a null pointer and every operation uses the
/// CPU fallback paths instead.
pub struct CuBlas {
    // Raw cuBLAS context handle; null in non-`cuda` builds, hence the
    // dead_code allow (the field is only read under cfg(feature = "cuda")).
    #[allow(dead_code)]
    handle: cublasHandle_t,
}
13
impl CuBlas {
    /// Create a new cuBLAS handle.
    ///
    /// With the `cuda` feature this calls `cublasCreate_v2`; a non-zero status
    /// is surfaced as [`CudaError::CublasError`] carrying the raw status code.
    /// Without the feature a null-handle stub is returned so the CPU fallback
    /// paths in the methods below can still be used.
    pub fn new() -> CudaResult<Self> {
        #[cfg(feature = "cuda")]
        {
            let mut handle: cublasHandle_t = std::ptr::null_mut();
            
            // NOTE(review): `ffi::cublasCreate_v2` is referenced unqualified,
            // but only `crate::ffi::cublasHandle_t` is imported at the top of
            // this file — confirm a `use crate::ffi;` (or equivalent) is in
            // scope, otherwise the `cuda` build will not compile.
            unsafe {
                let status = ffi::cublasCreate_v2(&mut handle);
                if status != 0 {
                    return Err(CudaError::CublasError(status));
                }
            }
            
            Ok(CuBlas { handle })
        }
        
        #[cfg(not(feature = "cuda"))]
        {
            Ok(CuBlas {
                handle: std::ptr::null_mut(),
            })
        }
    }

    /// Associate a CUDA stream with subsequent cuBLAS calls on this handle.
    ///
    /// Always returns `Ok(())` when built without the `cuda` feature (the
    /// `stream` argument is then intentionally unused, hence the allow).
    ///
    /// # Errors
    /// [`CudaError::CublasError`] if `cublasSetStream_v2` reports a non-zero
    /// status.
    #[cfg_attr(not(feature = "cuda"), allow(unused_variables))]
    pub fn set_stream(&self, stream: &CudaStream) -> CudaResult<()> {
        #[cfg(feature = "cuda")]
        unsafe {
            let status = ffi::cublasSetStream_v2(self.handle, stream.handle());
            if status != 0 {
                return Err(CudaError::CublasError(status));
            }
        }
        Ok(())
    }

    /// SGEMM: C = alpha * op(A) * op(B) + beta * C
    /// 
    /// This is the core matrix multiplication operation. Matrices follow
    /// cuBLAS's column-major convention; the CPU fallback below uses the same
    /// layout (see its index math), so both paths are interchangeable.
    /// 
    /// # Arguments
    /// * `trans_a` - Whether to transpose A
    /// * `trans_b` - Whether to transpose B
    /// * `m` - Number of rows of op(A) and C
    /// * `n` - Number of columns of op(B) and C
    /// * `k` - Number of columns of op(A) and rows of op(B)
    /// * `alpha` - Scalar multiplier for A*B
    /// * `a` - Matrix A
    /// * `lda` - Leading dimension of A
    /// * `b` - Matrix B
    /// * `ldb` - Leading dimension of B
    /// * `beta` - Scalar multiplier for C
    /// * `c` - Matrix C (output)
    /// * `ldc` - Leading dimension of C
    ///
    /// NOTE(review): this is a safe `fn` that either dereferences the raw
    /// pointers (CPU path) or hands them to cuBLAS (CUDA path, where they must
    /// be *device* pointers). Callers must guarantee the buffers are valid for
    /// the stated dimensions; consider making this an `unsafe fn` with a
    /// documented `# Safety` contract.
    pub fn sgemm(
        &self,
        trans_a: bool,
        trans_b: bool,
        m: i32,
        n: i32,
        k: i32,
        alpha: f32,
        a: *const f32,
        lda: i32,
        b: *const f32,
        ldb: i32,
        beta: f32,
        c: *mut f32,
        ldc: i32,
    ) -> CudaResult<()> {
        #[cfg(feature = "cuda")]
        unsafe {
            // NOTE(review): CUBLAS_OP_T / CUBLAS_OP_N are used unqualified —
            // confirm they are imported into this module (e.g. from `ffi`).
            let op_a = if trans_a { CUBLAS_OP_T } else { CUBLAS_OP_N };
            let op_b = if trans_b { CUBLAS_OP_T } else { CUBLAS_OP_N };
            
            let status = ffi::cublasSgemm_v2(
                self.handle,
                op_a,
                op_b,
                m, n, k,
                &alpha,
                a, lda,
                b, ldb,
                &beta,
                c, ldc,
            );
            
            if status != 0 {
                return Err(CudaError::CublasError(status));
            }
        }
        
        #[cfg(not(feature = "cuda"))]
        {
            // CPU fallback - naive implementation
            unsafe {
                for i in 0..m as usize {
                    for j in 0..n as usize {
                        // BLAS convention: when beta == 0, C is not read at
                        // all, so uninitialized/NaN C cannot poison the sum.
                        let mut sum = if beta != 0.0 {
                            beta * *c.add(i + j * ldc as usize)
                        } else {
                            0.0
                        };
                        
                        for l in 0..k as usize {
                            // Column-major addressing: element (r, c) of a
                            // matrix with leading dimension ld lives at
                            // r + c * ld; transposition swaps the two roles.
                            let a_idx = if trans_a { l + i * lda as usize } else { i + l * lda as usize };
                            let b_idx = if trans_b { j + l * ldb as usize } else { l + j * ldb as usize };
                            sum += alpha * *a.add(a_idx) * *b.add(b_idx);
                        }
                        
                        *c.add(i + j * ldc as usize) = sum;
                    }
                }
            }
        }
        
        Ok(())
    }

    /// SAXPY: y = alpha * x + y
    ///
    /// `n` elements are processed with strides `incx`/`incy`.
    ///
    /// NOTE(review): the CPU fallback computes offsets as `i * inc as usize`,
    /// so negative increments (legal in BLAS) are not supported off the CUDA
    /// path.
    pub fn saxpy(
        &self,
        n: i32,
        alpha: f32,
        x: *const f32,
        incx: i32,
        y: *mut f32,
        incy: i32,
    ) -> CudaResult<()> {
        #[cfg(feature = "cuda")]
        unsafe {
            let status = ffi::cublasSaxpy_v2(
                self.handle,
                n,
                &alpha,
                x, incx,
                y, incy,
            );
            
            if status != 0 {
                return Err(CudaError::CublasError(status));
            }
        }
        
        #[cfg(not(feature = "cuda"))]
        unsafe {
            for i in 0..n as usize {
                let xi = *x.add(i * incx as usize);
                let yi = y.add(i * incy as usize);
                *yi = alpha * xi + *yi;
            }
        }
        
        Ok(())
    }

    /// SDOT: result = x . y
    ///
    /// Dot product of two strided vectors of `n` elements. On the CUDA path
    /// the result is written back to host memory by cuBLAS; the fallback
    /// accumulates in a plain f32 (same precision as the BLAS routine).
    pub fn sdot(
        &self,
        n: i32,
        x: *const f32,
        incx: i32,
        y: *const f32,
        incy: i32,
    ) -> CudaResult<f32> {
        let mut result: f32 = 0.0;
        
        #[cfg(feature = "cuda")]
        unsafe {
            let status = ffi::cublasSdot_v2(
                self.handle,
                n,
                x, incx,
                y, incy,
                &mut result,
            );
            
            if status != 0 {
                return Err(CudaError::CublasError(status));
            }
        }
        
        #[cfg(not(feature = "cuda"))]
        unsafe {
            for i in 0..n as usize {
                result += *x.add(i * incx as usize) * *y.add(i * incy as usize);
            }
        }
        
        Ok(result)
    }

    /// SNRM2: result = ||x||_2
    ///
    /// Euclidean norm of a strided vector. The CPU fallback uses the naive
    /// sum-of-squares formulation (no overflow/underflow scaling, unlike
    /// reference BLAS implementations).
    pub fn snrm2(
        &self,
        n: i32,
        x: *const f32,
        incx: i32,
    ) -> CudaResult<f32> {
        #[cfg(feature = "cuda")]
        {
            let mut result: f32 = 0.0;
            unsafe {
                let status = ffi::cublasSnrm2_v2(
                    self.handle,
                    n,
                    x, incx,
                    &mut result,
                );
                
                if status != 0 {
                    return Err(CudaError::CublasError(status));
                }
            }
            Ok(result)
        }
        
        #[cfg(not(feature = "cuda"))]
        unsafe {
            let mut sum = 0.0f32;
            for i in 0..n as usize {
                let xi = *x.add(i * incx as usize);
                sum += xi * xi;
            }
            Ok(sum.sqrt())
        }
    }

    /// SSCAL: x = alpha * x
    ///
    /// Scales `n` elements of a strided vector in place.
    pub fn sscal(
        &self,
        n: i32,
        alpha: f32,
        x: *mut f32,
        incx: i32,
    ) -> CudaResult<()> {
        #[cfg(feature = "cuda")]
        unsafe {
            let status = ffi::cublasSscal_v2(
                self.handle,
                n,
                &alpha,
                x, incx,
            );
            
            if status != 0 {
                return Err(CudaError::CublasError(status));
            }
        }
        
        #[cfg(not(feature = "cuda"))]
        unsafe {
            for i in 0..n as usize {
                let xi = x.add(i * incx as usize);
                *xi = alpha * *xi;
            }
        }
        
        Ok(())
    }

    /// Matrix multiplication for CudaTensors: C = A x B.
    ///
    /// Both inputs must be 2-D; `a` is `[m, k]`, `b` is `[k, n]`, and the
    /// result is a freshly zero-initialized `[m, n]` tensor allocated on `a`'s
    /// device.
    ///
    /// # Errors
    /// * [`CudaError::InvalidValue`] if either tensor is not 2-D or the inner
    ///   dimensions disagree.
    /// * Any error propagated from tensor allocation or [`CuBlas::sgemm`].
    pub fn matmul(&self, a: &CudaTensor, b: &CudaTensor) -> CudaResult<CudaTensor> {
        let a_dims = a.dims();
        let b_dims = b.dims();
        
        if a_dims.len() != 2 || b_dims.len() != 2 {
            return Err(CudaError::InvalidValue("matmul requires 2D tensors".into()));
        }
        
        let m = a_dims[0] as i32;
        let k = a_dims[1] as i32;
        let n = b_dims[1] as i32;
        
        if k != b_dims[0] as i32 {
            return Err(CudaError::InvalidValue(format!(
                "Matrix dimensions don't match: [{}, {}] x [{}, {}]",
                m, k, b_dims[0], n
            )));
        }
        
        // Create output tensor
        let mut c = CudaTensor::zeros(&[m as usize, n as usize], a.device_id())?;
        
        // cuBLAS uses column-major, so we compute C^T = B^T * A^T
        // which gives us C in row-major. With both tensors stored row-major,
        // passing B first with leading dimension n, A second with leading
        // dimension k, and no transposes yields C^T column-major == C
        // row-major with leading dimension n.
        self.sgemm(
            false, false,
            n, m, k,
            1.0,
            b.as_ptr() as *const f32, n,
            a.as_ptr() as *const f32, k,
            0.0,
            c.as_mut_ptr() as *mut f32, n,
        )?;
        
        Ok(c)
    }
}
315
316impl Drop for CuBlas {
317    fn drop(&mut self) {
318        #[cfg(feature = "cuda")]
319        if !self.handle.is_null() {
320            unsafe {
321                let _ = ffi::cublasDestroy_v2(self.handle);
322            }
323        }
324    }
325}
326
// SAFETY: NOTE(review) — these impls assert the raw cublasHandle_t may be
// moved to and shared between threads. NVIDIA documents the cuBLAS v2 API as
// thread-safe even when multiple host threads use the same handle, but
// confirm that callers coordinate set_stream vs. concurrent compute calls,
// since stream association is per-handle state.
unsafe impl Send for CuBlas {}
unsafe impl Sync for CuBlas {}