1#![allow(non_camel_case_types)]
7#![allow(non_snake_case)]
8
9use std::os::raw::{c_int, c_void, c_char};
10
/// Status code returned by every CUDA runtime call (`cudaSuccess` == 0).
pub type cudaError_t = c_int;

/// Opaque handle to a CUDA stream (a pointer in the C API; NULL is the default stream).
pub type cudaStream_t = *mut c_void;

/// Opaque handle to a CUDA event, used for synchronization and timing.
pub type cudaEvent_t = *mut c_void;

/// Opaque handle to a cuBLAS library context.
pub type cublasHandle_t = *mut c_void;

/// Status code returned by cuBLAS calls (`CUBLAS_STATUS_SUCCESS` == 0).
pub type cublasStatus_t = c_int;

/// Transpose-mode selector for cuBLAS GEMM-family routines.
pub type cublasOperation_t = c_int;

/// No transpose: op(X) = X.
pub const CUBLAS_OP_N: cublasOperation_t = 0;
/// Transpose: op(X) = X^T.
pub const CUBLAS_OP_T: cublasOperation_t = 1;
/// Conjugate transpose: op(X) = X^H (same as `CUBLAS_OP_T` for real data).
pub const CUBLAS_OP_C: cublasOperation_t = 2;
31
/// Direction of a `cudaMemcpy`/`cudaMemcpyAsync` transfer.
///
/// Mirrors the C `cudaMemcpyKind` enum; the discriminant values must match
/// the CUDA runtime header exactly. Derives added so the value can be
/// copied, compared, and debug-printed like any other FFI enum.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum cudaMemcpyKind {
    cudaMemcpyHostToHost = 0,
    cudaMemcpyHostToDevice = 1,
    cudaMemcpyDeviceToHost = 2,
    cudaMemcpyDeviceToDevice = 3,
    /// Direction inferred from the pointer values (requires unified virtual addressing).
    cudaMemcpyDefault = 4,
}
41
/// Rust mirror of the CUDA runtime's `cudaDeviceProp` structure, filled in
/// by `cudaGetDeviceProperties`.
///
/// NOTE(review): because this is read via FFI, the `#[repr(C)]` field order,
/// types, and count must match the `cudaDeviceProp` layout of the *linked*
/// CUDA runtime exactly. Newer toolkits insert fields (e.g. `uuid`/`luid`
/// after `name` in CUDA 10/11+) and append fields at the end; a mismatch
/// makes `cudaGetDeviceProperties` write past or misalign these fields —
/// confirm against the toolkit version this crate is built for.
#[repr(C)]
#[derive(Debug, Clone)]
pub struct cudaDeviceProp {
    /// ASCII device name, NUL-terminated.
    pub name: [c_char; 256],
    // --- Memory sizes and limits ---
    pub totalGlobalMem: usize,
    pub sharedMemPerBlock: usize,
    pub regsPerBlock: c_int,
    pub warpSize: c_int,
    pub memPitch: usize,
    // --- Launch-configuration limits ---
    pub maxThreadsPerBlock: c_int,
    pub maxThreadsDim: [c_int; 3],
    pub maxGridSize: [c_int; 3],
    pub clockRate: c_int,
    pub totalConstMem: usize,
    // Compute capability, e.g. major=8, minor=6 for an Ampere GA10x part.
    pub major: c_int,
    pub minor: c_int,
    pub textureAlignment: usize,
    pub texturePitchAlignment: usize,
    pub deviceOverlap: c_int,
    pub multiProcessorCount: c_int,
    pub kernelExecTimeoutEnabled: c_int,
    pub integrated: c_int,
    pub canMapHostMemory: c_int,
    pub computeMode: c_int,
    // --- Texture limits ---
    pub maxTexture1D: c_int,
    pub maxTexture1DMipmap: c_int,
    pub maxTexture1DLinear: c_int,
    pub maxTexture2D: [c_int; 2],
    pub maxTexture2DMipmap: [c_int; 2],
    pub maxTexture2DLinear: [c_int; 3],
    pub maxTexture2DGather: [c_int; 2],
    pub maxTexture3D: [c_int; 3],
    pub maxTexture3DAlt: [c_int; 3],
    pub maxTextureCubemap: c_int,
    pub maxTexture1DLayered: [c_int; 2],
    pub maxTexture2DLayered: [c_int; 3],
    pub maxTextureCubemapLayered: [c_int; 2],
    // --- Surface limits ---
    pub maxSurface1D: c_int,
    pub maxSurface2D: [c_int; 2],
    pub maxSurface3D: [c_int; 3],
    pub maxSurface1DLayered: [c_int; 2],
    pub maxSurface2DLayered: [c_int; 3],
    pub maxSurfaceCubemap: c_int,
    pub maxSurfaceCubemapLayered: [c_int; 2],
    pub surfaceAlignment: usize,
    // --- Capability flags (0/1 booleans in the C API) ---
    pub concurrentKernels: c_int,
    pub ECCEnabled: c_int,
    pub pciBusID: c_int,
    pub pciDeviceID: c_int,
    pub pciDomainID: c_int,
    pub tccDriver: c_int,
    pub asyncEngineCount: c_int,
    pub unifiedAddressing: c_int,
    pub memoryClockRate: c_int,
    pub memoryBusWidth: c_int,
    pub l2CacheSize: c_int,
    pub persistingL2CacheMaxSize: c_int,
    pub maxThreadsPerMultiProcessor: c_int,
    pub streamPrioritiesSupported: c_int,
    pub globalL1CacheSupported: c_int,
    pub localL1CacheSupported: c_int,
    pub sharedMemPerMultiprocessor: usize,
    pub regsPerMultiprocessor: c_int,
    pub managedMemory: c_int,
    pub isMultiGpuBoard: c_int,
    pub multiGpuBoardGroupID: c_int,
    pub hostNativeAtomicSupported: c_int,
    pub singleToDoublePrecisionPerfRatio: c_int,
    pub pageableMemoryAccess: c_int,
    pub concurrentManagedAccess: c_int,
    pub computePreemptionSupported: c_int,
    pub canUseHostPointerForRegisteredMem: c_int,
    pub cooperativeLaunch: c_int,
    pub cooperativeMultiDeviceLaunch: c_int,
    pub sharedMemPerBlockOptin: usize,
    pub pageableMemoryAccessUsesHostPageTables: c_int,
    pub directManagedMemAccessFromHost: c_int,
    pub maxBlocksPerMultiProcessor: c_int,
    pub accessPolicyMaxWindowSize: c_int,
    pub reservedSharedMemPerBlock: usize,
}
124
125impl Default for cudaDeviceProp {
126 fn default() -> Self {
127 unsafe { std::mem::zeroed() }
128 }
129}
130
// Raw FFI bindings to the CUDA runtime library (`libcudart`).
//
// Every function returns a `cudaError_t`; 0 (`cudaSuccess`) means success.
// These are thin declarations only — the safety contract (valid pointers,
// correct sizes, live handles) rests entirely on the caller.
#[cfg(feature = "cuda")]
#[link(name = "cudart")]
extern "C" {
    // --- Device management ---
    pub fn cudaGetDeviceCount(count: *mut c_int) -> cudaError_t;
    pub fn cudaSetDevice(device: c_int) -> cudaError_t;
    pub fn cudaGetDevice(device: *mut c_int) -> cudaError_t;
    /// NOTE(review): correctness depends on the local `cudaDeviceProp`
    /// mirror matching the linked toolkit's struct layout — see the note
    /// on the struct definition.
    pub fn cudaGetDeviceProperties(prop: *mut cudaDeviceProp, device: c_int) -> cudaError_t;
    pub fn cudaDeviceSynchronize() -> cudaError_t;
    pub fn cudaDeviceReset() -> cudaError_t;

    // --- Memory management ---
    pub fn cudaMalloc(devPtr: *mut *mut c_void, size: usize) -> cudaError_t;
    pub fn cudaFree(devPtr: *mut c_void) -> cudaError_t;
    pub fn cudaMemcpy(dst: *mut c_void, src: *const c_void, count: usize, kind: cudaMemcpyKind) -> cudaError_t;
    pub fn cudaMemcpyAsync(dst: *mut c_void, src: *const c_void, count: usize, kind: cudaMemcpyKind, stream: cudaStream_t) -> cudaError_t;
    pub fn cudaMemset(devPtr: *mut c_void, value: c_int, count: usize) -> cudaError_t;
    pub fn cudaMemsetAsync(devPtr: *mut c_void, value: c_int, count: usize, stream: cudaStream_t) -> cudaError_t;
    pub fn cudaMemGetInfo(free: *mut usize, total: *mut usize) -> cudaError_t;

    // --- Stream management ---
    pub fn cudaStreamCreate(pStream: *mut cudaStream_t) -> cudaError_t;
    pub fn cudaStreamDestroy(stream: cudaStream_t) -> cudaError_t;
    pub fn cudaStreamSynchronize(stream: cudaStream_t) -> cudaError_t;
    pub fn cudaStreamQuery(stream: cudaStream_t) -> cudaError_t;

    // --- Event management (synchronization and timing) ---
    pub fn cudaEventCreate(event: *mut cudaEvent_t) -> cudaError_t;
    pub fn cudaEventDestroy(event: cudaEvent_t) -> cudaError_t;
    pub fn cudaEventRecord(event: cudaEvent_t, stream: cudaStream_t) -> cudaError_t;
    pub fn cudaEventSynchronize(event: cudaEvent_t) -> cudaError_t;
    pub fn cudaEventElapsedTime(ms: *mut f32, start: cudaEvent_t, end: cudaEvent_t) -> cudaError_t;

    // --- Error handling ---
    pub fn cudaGetErrorString(error: cudaError_t) -> *const c_char;
    pub fn cudaGetLastError() -> cudaError_t;
    pub fn cudaPeekAtLastError() -> cudaError_t;
}
165
// Launch wrappers for this project's hand-written CUDA kernels
// (presumably implemented in accompanying .cu sources linked in by the
// build script — confirm against the build configuration).
#[cfg(feature = "cuda")]
extern "C" {
    /// Single-precision GEMM launcher: C = alpha * A * B + beta * C
    /// on the given stream (parameter naming follows BLAS: M x K times
    /// K x N into M x N).
    /// NOTE(review): element layout (row- vs column-major) and any
    /// transpose convention are defined by the kernel source — verify there.
    pub fn launch_optimized_sgemm(
        A: *const f32, B: *const f32, C: *mut f32,
        M: c_int, N: c_int, K: c_int,
        alpha: f32, beta: f32,
        stream: cudaStream_t
    );

    /// Fused convolution + batch-norm + ReLU launcher. The parameter list
    /// (batch/channels/height/width plus per-axis kernel, stride, and pad)
    /// suggests NCHW layout and inference-mode BN using the supplied
    /// running mean/var with epsilon `eps` — confirm against the kernel.
    pub fn launch_fused_conv_bn_relu(
        input: *const f32, weight: *const f32,
        bn_weight: *const f32, bn_bias: *const f32,
        bn_mean: *const f32, bn_var: *const f32,
        output: *mut f32,
        batch: c_int, in_channels: c_int, out_channels: c_int,
        in_h: c_int, in_w: c_int, out_h: c_int, out_w: c_int,
        kernel_h: c_int, kernel_w: c_int,
        stride_h: c_int, stride_w: c_int,
        pad_h: c_int, pad_w: c_int,
        eps: f32,
        stream: cudaStream_t
    );

    /// Fused scaled-dot-product attention launcher over Q/K/V with shape
    /// parameters (batch, heads, seq_len, head_dim) and logit scale
    /// `scale` — exact tensor layout is defined by the kernel source.
    pub fn launch_fused_attention(
        Q: *const f32, K: *const f32, V: *const f32,
        output: *mut f32,
        batch: c_int, heads: c_int, seq_len: c_int, head_dim: c_int,
        scale: f32,
        stream: cudaStream_t
    );
}
198
// Raw FFI bindings to cuBLAS (`libcublas`), using the `_v2` API.
//
// All routines return a `cublasStatus_t`; 0 (`CUBLAS_STATUS_SUCCESS`)
// means success. cuBLAS uses column-major storage with explicit leading
// dimensions (`lda`/`ldb`/`ldc`), and the scalar arguments (`alpha`,
// `beta`) are passed by pointer (host memory in the default pointer mode).
#[cfg(feature = "cuda")]
#[link(name = "cublas")]
extern "C" {
    /// Creates a cuBLAS context handle.
    pub fn cublasCreate_v2(handle: *mut cublasHandle_t) -> cublasStatus_t;
    /// Destroys a handle created by `cublasCreate_v2`.
    pub fn cublasDestroy_v2(handle: cublasHandle_t) -> cublasStatus_t;
    /// Sets the CUDA stream on which subsequent calls via `handle` execute.
    pub fn cublasSetStream_v2(handle: cublasHandle_t, streamId: cudaStream_t) -> cublasStatus_t;

    /// Single-precision GEMM: C = alpha * op(A) * op(B) + beta * C, where
    /// op(..) is chosen by `transa`/`transb` (`CUBLAS_OP_N` / `CUBLAS_OP_T`).
    pub fn cublasSgemm_v2(
        handle: cublasHandle_t,
        transa: cublasOperation_t,
        transb: cublasOperation_t,
        m: c_int,
        n: c_int,
        k: c_int,
        alpha: *const f32,
        A: *const f32,
        lda: c_int,
        B: *const f32,
        ldb: c_int,
        beta: *const f32,
        C: *mut f32,
        ldc: c_int,
    ) -> cublasStatus_t;

    /// Batched SGEMM over `batchCount` independent problems; the arrays of
    /// per-problem matrix pointers must reside in device memory.
    pub fn cublasSgemmBatched(
        handle: cublasHandle_t,
        transa: cublasOperation_t,
        transb: cublasOperation_t,
        m: c_int,
        n: c_int,
        k: c_int,
        alpha: *const f32,
        Aarray: *const *const f32,
        lda: c_int,
        Barray: *const *const f32,
        ldb: c_int,
        beta: *const f32,
        Carray: *mut *mut f32,
        ldc: c_int,
        batchCount: c_int,
    ) -> cublasStatus_t;

    /// AXPY: y = alpha * x + y over `n` elements with strides `incx`/`incy`.
    pub fn cublasSaxpy_v2(
        handle: cublasHandle_t,
        n: c_int,
        alpha: *const f32,
        x: *const f32,
        incx: c_int,
        y: *mut f32,
        incy: c_int,
    ) -> cublasStatus_t;

    /// Dot product of x and y, written to `result`.
    pub fn cublasSdot_v2(
        handle: cublasHandle_t,
        n: c_int,
        x: *const f32,
        incx: c_int,
        y: *const f32,
        incy: c_int,
        result: *mut f32,
    ) -> cublasStatus_t;

    /// Euclidean (L2) norm of x, written to `result`.
    pub fn cublasSnrm2_v2(
        handle: cublasHandle_t,
        n: c_int,
        x: *const f32,
        incx: c_int,
        result: *mut f32,
    ) -> cublasStatus_t;

    /// SCAL: x = alpha * x.
    pub fn cublasSscal_v2(
        handle: cublasHandle_t,
        n: c_int,
        alpha: *const f32,
        x: *mut f32,
        incx: c_int,
    ) -> cublasStatus_t;
}
284
285#[cfg(not(feature = "cuda"))]
287pub mod stubs {
288 use super::*;
289
290 pub unsafe fn cudaGetDeviceCount(count: *mut c_int) -> cudaError_t {
291 *count = 0;
292 0 }
294
295 pub unsafe fn cudaSetDevice(_device: c_int) -> cudaError_t {
296 1 }
298
299 pub unsafe fn cudaGetDevice(device: *mut c_int) -> cudaError_t {
300 *device = -1;
301 1
302 }
303
304 pub unsafe fn cudaDeviceSynchronize() -> cudaError_t {
305 0
306 }
307
308 pub unsafe fn cudaMalloc(_devPtr: *mut *mut c_void, _size: usize) -> cudaError_t {
309 2 }
311
312 pub unsafe fn cudaFree(_devPtr: *mut c_void) -> cudaError_t {
313 0
314 }
315
316 pub unsafe fn cudaMemcpy(_dst: *mut c_void, _src: *const c_void, _count: usize, _kind: cudaMemcpyKind) -> cudaError_t {
317 0
318 }
319
320 pub unsafe fn cudaMemset(_devPtr: *mut c_void, _value: c_int, _count: usize) -> cudaError_t {
321 0
322 }
323
324 pub unsafe fn cudaStreamCreate(_pStream: *mut cudaStream_t) -> cudaError_t {
325 0
326 }
327
328 pub unsafe fn cudaStreamDestroy(_stream: cudaStream_t) -> cudaError_t {
329 0
330 }
331
332 pub unsafe fn cudaStreamSynchronize(_stream: cudaStream_t) -> cudaError_t {
333 0
334 }
335}
336
// Without the `cuda` feature, surface the stub implementations under the
// same names as the real FFI bindings so call sites compile unchanged.
#[cfg(not(feature = "cuda"))]
pub use stubs::*;
339
340pub fn check_cuda(err: cudaError_t) -> Result<(), cudaError_t> {
342 if err == 0 {
343 Ok(())
344 } else {
345 Err(err)
346 }
347}
348
349pub fn check_cublas(status: cublasStatus_t) -> Result<(), cublasStatus_t> {
351 if status == 0 {
352 Ok(())
353 } else {
354 Err(status)
355 }
356}