// ghostflow_cuda/ffi.rs

1//! CUDA FFI bindings - Real CUDA Runtime API
2//!
3//! These are the actual CUDA C API bindings. When compiled with the `cuda` feature,
4//! these will link against the real CUDA runtime library.
5
6#![allow(non_camel_case_types)]
7#![allow(non_snake_case)]
8
9use std::os::raw::{c_int, c_void, c_char};
10
/// CUDA error code (`cudaError_t` in the C API; `0` == `cudaSuccess`).
pub type cudaError_t = c_int;

/// CUDA stream handle (opaque pointer; null means the default stream).
pub type cudaStream_t = *mut c_void;

/// CUDA event handle (opaque pointer).
pub type cudaEvent_t = *mut c_void;

/// cuBLAS context handle (opaque pointer, created by `cublasCreate_v2`).
pub type cublasHandle_t = *mut c_void;

/// cuBLAS status code (`0` == `CUBLAS_STATUS_SUCCESS`).
pub type cublasStatus_t = c_int;

/// cuBLAS matrix-operation selector (`cublasOperation_t` in the C API).
pub type cublasOperation_t = c_int;

/// No transpose: op(A) = A.
pub const CUBLAS_OP_N: cublasOperation_t = 0;
/// Transpose: op(A) = A^T.
pub const CUBLAS_OP_T: cublasOperation_t = 1;
/// Conjugate transpose: op(A) = A^H. For the real-valued routines bound
/// below it behaves like `CUBLAS_OP_T`; provided for API completeness.
pub const CUBLAS_OP_C: cublasOperation_t = 2;
/// CUDA memory copy kind (`cudaMemcpyKind` in the C API).
///
/// This enum is passed *by value* to `cudaMemcpy`/`cudaMemcpyAsync`, so it
/// derives `Copy` (it is a plain `#[repr(C)]` discriminant with no payload);
/// `Debug`/`PartialEq` make it usable in logs and assertions.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum cudaMemcpyKind {
    cudaMemcpyHostToHost = 0,
    cudaMemcpyHostToDevice = 1,
    cudaMemcpyDeviceToHost = 2,
    cudaMemcpyDeviceToDevice = 3,
    /// Direction inferred from unified-addressing pointers.
    cudaMemcpyDefault = 4,
}
41
/// CUDA device properties
///
/// Rust mirror of the C `cudaDeviceProp` struct, filled in by
/// `cudaGetDeviceProperties`. Being `#[repr(C)]`, the field order and types
/// must match the C definition byte-for-byte.
///
/// NOTE(review): this layout must match the `cuda_runtime_api.h` of the
/// toolkit actually linked. CUDA 10+ inserts `uuid` (and later `luid`,
/// `luidDeviceNodeMask`, etc.) immediately after `name`; if the installed
/// headers differ from this layout, `cudaGetDeviceProperties` will write
/// past or misalign these fields — confirm against the toolkit version.
#[repr(C)]
#[derive(Debug, Clone)]
pub struct cudaDeviceProp {
    // Identification and core memory sizes
    pub name: [c_char; 256],
    pub totalGlobalMem: usize,
    pub sharedMemPerBlock: usize,
    pub regsPerBlock: c_int,
    pub warpSize: c_int,
    pub memPitch: usize,
    // Launch-geometry limits
    pub maxThreadsPerBlock: c_int,
    pub maxThreadsDim: [c_int; 3],
    pub maxGridSize: [c_int; 3],
    pub clockRate: c_int,
    pub totalConstMem: usize,
    // Compute capability (e.g. major=8, minor=6 for SM 8.6)
    pub major: c_int,
    pub minor: c_int,
    pub textureAlignment: usize,
    pub texturePitchAlignment: usize,
    pub deviceOverlap: c_int,
    pub multiProcessorCount: c_int,
    pub kernelExecTimeoutEnabled: c_int,
    pub integrated: c_int,
    pub canMapHostMemory: c_int,
    pub computeMode: c_int,
    // Texture limits
    pub maxTexture1D: c_int,
    pub maxTexture1DMipmap: c_int,
    pub maxTexture1DLinear: c_int,
    pub maxTexture2D: [c_int; 2],
    pub maxTexture2DMipmap: [c_int; 2],
    pub maxTexture2DLinear: [c_int; 3],
    pub maxTexture2DGather: [c_int; 2],
    pub maxTexture3D: [c_int; 3],
    pub maxTexture3DAlt: [c_int; 3],
    pub maxTextureCubemap: c_int,
    pub maxTexture1DLayered: [c_int; 2],
    pub maxTexture2DLayered: [c_int; 3],
    pub maxTextureCubemapLayered: [c_int; 2],
    // Surface limits
    pub maxSurface1D: c_int,
    pub maxSurface2D: [c_int; 2],
    pub maxSurface3D: [c_int; 3],
    pub maxSurface1DLayered: [c_int; 2],
    pub maxSurface2DLayered: [c_int; 3],
    pub maxSurfaceCubemap: c_int,
    pub maxSurfaceCubemapLayered: [c_int; 2],
    pub surfaceAlignment: usize,
    // Capability flags and hardware details
    pub concurrentKernels: c_int,
    pub ECCEnabled: c_int,
    pub pciBusID: c_int,
    pub pciDeviceID: c_int,
    pub pciDomainID: c_int,
    pub tccDriver: c_int,
    pub asyncEngineCount: c_int,
    pub unifiedAddressing: c_int,
    pub memoryClockRate: c_int,
    pub memoryBusWidth: c_int,
    pub l2CacheSize: c_int,
    pub persistingL2CacheMaxSize: c_int,
    pub maxThreadsPerMultiProcessor: c_int,
    pub streamPrioritiesSupported: c_int,
    pub globalL1CacheSupported: c_int,
    pub localL1CacheSupported: c_int,
    pub sharedMemPerMultiprocessor: usize,
    pub regsPerMultiprocessor: c_int,
    pub managedMemory: c_int,
    pub isMultiGpuBoard: c_int,
    pub multiGpuBoardGroupID: c_int,
    pub hostNativeAtomicSupported: c_int,
    pub singleToDoublePrecisionPerfRatio: c_int,
    pub pageableMemoryAccess: c_int,
    pub concurrentManagedAccess: c_int,
    pub computePreemptionSupported: c_int,
    pub canUseHostPointerForRegisteredMem: c_int,
    pub cooperativeLaunch: c_int,
    pub cooperativeMultiDeviceLaunch: c_int,
    pub sharedMemPerBlockOptin: usize,
    pub pageableMemoryAccessUsesHostPageTables: c_int,
    pub directManagedMemAccessFromHost: c_int,
    pub maxBlocksPerMultiProcessor: c_int,
    pub accessPolicyMaxWindowSize: c_int,
    pub reservedSharedMemPerBlock: usize,
}
124
impl Default for cudaDeviceProp {
    /// Returns an all-zero property struct (empty name, zero limits).
    fn default() -> Self {
        // SAFETY: `cudaDeviceProp` is a `#[repr(C)]` plain-data struct whose
        // fields are all integers, sizes, or fixed-size arrays of such; the
        // all-zero bit pattern is a valid value for every field.
        unsafe { std::mem::zeroed() }
    }
}
130
// CUDA Runtime API
//
// Linked against `libcudart` only when the `cuda` feature is enabled.
// Every function returns a `cudaError_t` where 0 == cudaSuccess, except
// `cudaGetErrorString`, which returns a pointer to a static C string.
#[cfg(feature = "cuda")]
#[link(name = "cudart")]
extern "C" {
    // --- Device management ---
    pub fn cudaGetDeviceCount(count: *mut c_int) -> cudaError_t;
    pub fn cudaSetDevice(device: c_int) -> cudaError_t;
    pub fn cudaGetDevice(device: *mut c_int) -> cudaError_t;
    // NOTE(review): correctness depends on the `cudaDeviceProp` layout above
    // matching the installed toolkit's header — see the struct definition.
    pub fn cudaGetDeviceProperties(prop: *mut cudaDeviceProp, device: c_int) -> cudaError_t;
    pub fn cudaDeviceSynchronize() -> cudaError_t;
    pub fn cudaDeviceReset() -> cudaError_t;

    // --- Memory management ---
    pub fn cudaMalloc(devPtr: *mut *mut c_void, size: usize) -> cudaError_t;
    pub fn cudaFree(devPtr: *mut c_void) -> cudaError_t;
    pub fn cudaMemcpy(dst: *mut c_void, src: *const c_void, count: usize, kind: cudaMemcpyKind) -> cudaError_t;
    pub fn cudaMemcpyAsync(dst: *mut c_void, src: *const c_void, count: usize, kind: cudaMemcpyKind, stream: cudaStream_t) -> cudaError_t;
    pub fn cudaMemset(devPtr: *mut c_void, value: c_int, count: usize) -> cudaError_t;
    pub fn cudaMemsetAsync(devPtr: *mut c_void, value: c_int, count: usize, stream: cudaStream_t) -> cudaError_t;
    pub fn cudaMemGetInfo(free: *mut usize, total: *mut usize) -> cudaError_t;

    // --- Streams ---
    pub fn cudaStreamCreate(pStream: *mut cudaStream_t) -> cudaError_t;
    pub fn cudaStreamDestroy(stream: cudaStream_t) -> cudaError_t;
    pub fn cudaStreamSynchronize(stream: cudaStream_t) -> cudaError_t;
    pub fn cudaStreamQuery(stream: cudaStream_t) -> cudaError_t;

    // --- Events (used for timing via cudaEventElapsedTime) ---
    pub fn cudaEventCreate(event: *mut cudaEvent_t) -> cudaError_t;
    pub fn cudaEventDestroy(event: cudaEvent_t) -> cudaError_t;
    pub fn cudaEventRecord(event: cudaEvent_t, stream: cudaStream_t) -> cudaError_t;
    pub fn cudaEventSynchronize(event: cudaEvent_t) -> cudaError_t;
    pub fn cudaEventElapsedTime(ms: *mut f32, start: cudaEvent_t, end: cudaEvent_t) -> cudaError_t;

    // --- Error handling ---
    pub fn cudaGetErrorString(error: cudaError_t) -> *const c_char;
    pub fn cudaGetLastError() -> cudaError_t;
    pub fn cudaPeekAtLastError() -> cudaError_t;
}
165
// Our custom optimized kernels from optimized_kernels.cu
//
// NOTE(review): there is no `#[link]` attribute here — these symbols are
// presumably linked in by a build script that compiles optimized_kernels.cu;
// verify build.rs emits the matching link flags.
//
// These launchers return `()`, so launch failures are only visible through
// `cudaGetLastError`/`cudaPeekAtLastError` afterwards. All take a
// `cudaStream_t` to enqueue on.
#[cfg(feature = "cuda")]
extern "C" {
    /// SGEMM: C = alpha * A * B + beta * C for an MxK by KxN product.
    pub fn launch_optimized_sgemm(
        A: *const f32, B: *const f32, C: *mut f32,
        M: c_int, N: c_int, K: c_int,
        alpha: f32, beta: f32,
        stream: cudaStream_t
    );

    /// Fused convolution + batch-norm + ReLU over NCHW-style dimensions
    /// (batch/channels/height/width parameters; `eps` is the BN epsilon).
    pub fn launch_fused_conv_bn_relu(
        input: *const f32, weight: *const f32,
        bn_weight: *const f32, bn_bias: *const f32,
        bn_mean: *const f32, bn_var: *const f32,
        output: *mut f32,
        batch: c_int, in_channels: c_int, out_channels: c_int,
        in_h: c_int, in_w: c_int, out_h: c_int, out_w: c_int,
        kernel_h: c_int, kernel_w: c_int,
        stride_h: c_int, stride_w: c_int,
        pad_h: c_int, pad_w: c_int,
        eps: f32,
        stream: cudaStream_t
    );

    /// Fused scaled-dot-product attention over (batch, heads, seq_len,
    /// head_dim) Q/K/V buffers; `scale` is the softmax pre-scale factor.
    pub fn launch_fused_attention(
        Q: *const f32, K: *const f32, V: *const f32,
        output: *mut f32,
        batch: c_int, heads: c_int, seq_len: c_int, head_dim: c_int,
        scale: f32,
        stream: cudaStream_t
    );
}
198
// cuBLAS API
//
// Linked against `libcublas` when the `cuda` feature is enabled. These are
// the `_v2` entry points of the handle-based cuBLAS API; every call returns
// a `cublasStatus_t` (0 == CUBLAS_STATUS_SUCCESS).
//
// NOTE(review): cuBLAS assumes column-major storage; row-major callers
// typically swap operands / transpose flags — confirm call sites.
#[cfg(feature = "cuda")]
#[link(name = "cublas")]
extern "C" {
    // Handle lifecycle and stream association
    pub fn cublasCreate_v2(handle: *mut cublasHandle_t) -> cublasStatus_t;
    pub fn cublasDestroy_v2(handle: cublasHandle_t) -> cublasStatus_t;
    pub fn cublasSetStream_v2(handle: cublasHandle_t, streamId: cudaStream_t) -> cublasStatus_t;

    // SGEMM: C = alpha * op(A) * op(B) + beta * C
    // alpha/beta are passed by host pointer; lda/ldb/ldc are leading dims.
    pub fn cublasSgemm_v2(
        handle: cublasHandle_t,
        transa: cublasOperation_t,
        transb: cublasOperation_t,
        m: c_int,
        n: c_int,
        k: c_int,
        alpha: *const f32,
        A: *const f32,
        lda: c_int,
        B: *const f32,
        ldb: c_int,
        beta: *const f32,
        C: *mut f32,
        ldc: c_int,
    ) -> cublasStatus_t;

    // Batched SGEMM: `batchCount` independent GEMMs; the *array arguments
    // are device arrays of device pointers, one entry per batch element.
    pub fn cublasSgemmBatched(
        handle: cublasHandle_t,
        transa: cublasOperation_t,
        transb: cublasOperation_t,
        m: c_int,
        n: c_int,
        k: c_int,
        alpha: *const f32,
        Aarray: *const *const f32,
        lda: c_int,
        Barray: *const *const f32,
        ldb: c_int,
        beta: *const f32,
        Carray: *mut *mut f32,
        ldc: c_int,
        batchCount: c_int,
    ) -> cublasStatus_t;

    // SAXPY: y = alpha * x + y (incx/incy are element strides)
    pub fn cublasSaxpy_v2(
        handle: cublasHandle_t,
        n: c_int,
        alpha: *const f32,
        x: *const f32,
        incx: c_int,
        y: *mut f32,
        incy: c_int,
    ) -> cublasStatus_t;

    // SDOT: *result = x . y
    pub fn cublasSdot_v2(
        handle: cublasHandle_t,
        n: c_int,
        x: *const f32,
        incx: c_int,
        y: *const f32,
        incy: c_int,
        result: *mut f32,
    ) -> cublasStatus_t;

    // SNRM2: *result = ||x||_2
    pub fn cublasSnrm2_v2(
        handle: cublasHandle_t,
        n: c_int,
        x: *const f32,
        incx: c_int,
        result: *mut f32,
    ) -> cublasStatus_t;

    // SSCAL: x = alpha * x
    pub fn cublasSscal_v2(
        handle: cublasHandle_t,
        n: c_int,
        alpha: *const f32,
        x: *mut f32,
        incx: c_int,
    ) -> cublasStatus_t;
}
284
285// Stub implementations when CUDA is not available
286#[cfg(not(feature = "cuda"))]
287pub mod stubs {
288    use super::*;
289    
290    pub unsafe fn cudaGetDeviceCount(count: *mut c_int) -> cudaError_t {
291        *count = 0;
292        0 // cudaSuccess
293    }
294    
295    pub unsafe fn cudaSetDevice(_device: c_int) -> cudaError_t {
296        1 // cudaErrorInvalidDevice
297    }
298    
299    pub unsafe fn cudaGetDevice(device: *mut c_int) -> cudaError_t {
300        *device = -1;
301        1
302    }
303    
304    pub unsafe fn cudaDeviceSynchronize() -> cudaError_t {
305        0
306    }
307    
308    pub unsafe fn cudaMalloc(_devPtr: *mut *mut c_void, _size: usize) -> cudaError_t {
309        2 // cudaErrorMemoryAllocation
310    }
311    
312    pub unsafe fn cudaFree(_devPtr: *mut c_void) -> cudaError_t {
313        0
314    }
315    
316    pub unsafe fn cudaMemcpy(_dst: *mut c_void, _src: *const c_void, _count: usize, _kind: cudaMemcpyKind) -> cudaError_t {
317        0
318    }
319    
320    pub unsafe fn cudaMemset(_devPtr: *mut c_void, _value: c_int, _count: usize) -> cudaError_t {
321        0
322    }
323    
324    pub unsafe fn cudaStreamCreate(_pStream: *mut cudaStream_t) -> cudaError_t {
325        0
326    }
327    
328    pub unsafe fn cudaStreamDestroy(_stream: cudaStream_t) -> cudaError_t {
329        0
330    }
331    
332    pub unsafe fn cudaStreamSynchronize(_stream: cudaStream_t) -> cudaError_t {
333        0
334    }
335}
336
337#[cfg(not(feature = "cuda"))]
338pub use stubs::*;
339
340/// Check CUDA error and convert to Result
341pub fn check_cuda(err: cudaError_t) -> Result<(), cudaError_t> {
342    if err == 0 {
343        Ok(())
344    } else {
345        Err(err)
346    }
347}
348
349/// Check cuBLAS error
350pub fn check_cublas(status: cublasStatus_t) -> Result<(), cublasStatus_t> {
351    if status == 0 {
352        Ok(())
353    } else {
354        Err(status)
355    }
356}