oxicuda_driver/loader.rs
1//! Dynamic CUDA driver library loader.
2//!
3//! This module is the architectural foundation of `oxicuda-driver`. It locates
4//! and loads the CUDA driver shared library (`libcuda.so` on Linux,
5//! `nvcuda.dll` on Windows) **at runtime** via [`libloading`], so that no CUDA
6//! SDK is required at build time.
7//!
8//! # Platform support
9//!
10//! | Platform | Library names tried | Notes |
11//! |----------|----------------------------------|----------------------------------|
12//! | Linux | `libcuda.so.1`, `libcuda.so` | Installed by NVIDIA driver |
13//! | Windows | `nvcuda.dll` | Ships with the display driver |
14//! | macOS | — | Returns `UnsupportedPlatform` |
15//!
16//! # Usage
17//!
18//! Application code should **not** interact with [`DriverApi`] directly.
19//! Instead, call [`try_driver`] to obtain a reference to the lazily-
20//! initialised global singleton:
21//!
22//! ```rust,no_run
23//! # use oxicuda_driver::loader::try_driver;
24//! let api = try_driver()?;
25//! // api.cu_init, api.cu_device_get, …
26//! # Ok::<(), oxicuda_driver::error::CudaError>(())
27//! ```
28//!
29//! The singleton is stored in a [`OnceLock`] so that the (relatively
30//! expensive) `dlopen` + symbol resolution only happens once, and all
31//! subsequent accesses are a single atomic load.
32
33use std::ffi::{c_char, c_int, c_void};
34use std::sync::OnceLock;
35
36use libloading::Library;
37
38use crate::error::{CudaError, CudaResult, DriverLoadError};
39use crate::ffi::*;
40
41// ---------------------------------------------------------------------------
42// Global singleton
43// ---------------------------------------------------------------------------
44
45/// Global singleton for the driver API function table.
46///
47/// Initialised lazily on the first call to [`try_driver`].
48static DRIVER: OnceLock<Result<DriverApi, DriverLoadError>> = OnceLock::new();
49
50// ---------------------------------------------------------------------------
51// load_sym! helper macro
52// ---------------------------------------------------------------------------
53
/// Resolve a *required* symbol from the CUDA driver library and reinterpret
/// it as the function-pointer type expected at the use site.
///
/// Expands to an expression that evaluates to the function pointer. If the
/// lookup fails, it early-returns (via `?`) a
/// [`DriverLoadError::SymbolNotFound`] carrying the symbol name and the
/// loader's error text. It may therefore only be used inside a function
/// returning `Result<_, DriverLoadError>`, or an error type that
/// `DriverLoadError` converts into.
///
/// `$name` must be a string literal *without* a trailing NUL. The macro
/// appends one at compile time so `Library::get` can borrow the name
/// directly instead of allocating a temporary `CString` for every lookup.
///
/// # Safety
///
/// The macro wraps its `unsafe` operations internally. The caller must still
/// ensure that the symbol named by `$name` really has the ABI of the
/// function-pointer type the result is assigned to. A mismatch is undefined
/// behaviour once the pointer is called.
#[cfg(not(target_os = "macos"))]
macro_rules! load_sym {
    ($lib:expr, $name:literal) => {{
        // Look the symbol up as the most general function-pointer type. The
        // concrete signature is applied by the transmute below. The name is
        // NUL-terminated at compile time so the lookup does not allocate.
        let sym = unsafe { $lib.get::<unsafe extern "C" fn()>(concat!($name, "\0").as_bytes()) }
            .map_err(|e| DriverLoadError::SymbolNotFound {
                symbol: $name,
                reason: e.to_string(),
            })?;
        // SAFETY: we trust that the CUDA driver exports the symbol with the
        // ABI described by the target type. That type is inferred from the
        // place this expression is assigned to (a DriverApi field). Spelling
        // it out would mean repeating the function-pointer type at every call
        // site, so the annotation lint is suppressed here. The `let` exists
        // only so the attribute has a statement to attach to.
        #[allow(clippy::missing_transmute_annotations)]
        let fn_ptr = unsafe { std::mem::transmute(*sym) };
        fn_ptr
    }};
}
83
/// Resolve an *optional* symbol from the CUDA driver library.
///
/// Expands to `Some(fn_ptr)` when the symbol is exported, or `None` when the
/// lookup fails. This is meant for API entry points that may not be present
/// in older driver versions. A failed lookup is not an error: it is logged at
/// `debug` level together with the loader's error text, and `None` is
/// produced.
///
/// `$name` must be a string literal *without* a trailing NUL. The macro
/// appends one at compile time so the lookup does not allocate a temporary
/// `CString`.
///
/// # Safety
///
/// Same requirement as [`load_sym!`]: the symbol must have the ABI of the
/// function-pointer type held by the `Option` the result is assigned to.
#[cfg(not(target_os = "macos"))]
macro_rules! load_sym_optional {
    ($lib:expr, $name:literal) => {{
        match unsafe { $lib.get::<unsafe extern "C" fn()>(concat!($name, "\0").as_bytes()) } {
            Ok(sym) => {
                // SAFETY: the target type is inferred from the DriverApi field
                // this value is assigned to. Suppressing the lint avoids
                // repeating the function-pointer type at every call site.
                #[allow(clippy::missing_transmute_annotations)]
                let fp = unsafe { std::mem::transmute(*sym) };
                Some(fp)
            }
            Err(err) => {
                // Best effort: a missing optional entry point is expected on
                // older drivers. Keep the loader's diagnostic for
                // troubleshooting instead of discarding it.
                tracing::debug!("optional symbol not found: {}: {}", $name, err);
                None
            }
        }
    }};
}
110
111// ---------------------------------------------------------------------------
112// DriverApi
113// ---------------------------------------------------------------------------
114
115/// Complete function-pointer table for the CUDA Driver API.
116///
117/// An instance of this struct is produced by [`DriverApi::load`] and kept
118/// alive for the lifetime of the process inside the `DRIVER` singleton.
119/// The embedded [`Library`] handle ensures the shared object is not unloaded.
120///
121/// # Function pointer groups
122///
123/// The fields are organised into logical groups mirroring the CUDA Driver API
124/// documentation:
125///
126/// * **Initialisation** — [`cu_init`](Self::cu_init)
127/// * **Device management** — `cu_device_*`
128/// * **Context management** — `cu_ctx_*`
129/// * **Module management** — `cu_module_*`
130/// * **Memory management** — `cu_mem_*`, `cu_memcpy_*`, `cu_memset_*`
131/// * **Stream management** — `cu_stream_*`
132/// * **Event management** — `cu_event_*`
133/// * **Kernel launch** — [`cu_launch_kernel`](Self::cu_launch_kernel)
134/// * **Occupancy queries** — `cu_occupancy_*`
135pub struct DriverApi {
136 // Keep the shared library handle alive.
137 _lib: Library,
138
139 // -- Initialisation ----------------------------------------------------
140 /// `cuInit(flags) -> CUresult`
141 ///
142 /// Initialises the CUDA driver API. Must be called before any other
143 /// driver function. Passing `0` for *flags* is the only documented
144 /// value.
145 pub cu_init: unsafe extern "C" fn(flags: u32) -> CUresult,
146
147 // -- Version query -------------------------------------------------------
148 /// `cuDriverGetVersion(driverVersion*) -> CUresult`
149 ///
150 /// Returns the CUDA driver version as `major*1000 + minor*10`.
151 pub cu_driver_get_version: unsafe extern "C" fn(version: *mut c_int) -> CUresult,
152
153 // -- Device management -------------------------------------------------
154 /// `cuDeviceGet(device*, ordinal) -> CUresult`
155 ///
156 /// Returns a handle to a compute device.
157 pub cu_device_get: unsafe extern "C" fn(device: *mut CUdevice, ordinal: c_int) -> CUresult,
158
159 /// `cuDeviceGetCount(count*) -> CUresult`
160 ///
161 /// Returns the number of compute-capable devices.
162 pub cu_device_get_count: unsafe extern "C" fn(count: *mut c_int) -> CUresult,
163
164 /// `cuDeviceGetName(name*, len, dev) -> CUresult`
165 ///
166 /// Returns an ASCII string identifying the device.
167 pub cu_device_get_name:
168 unsafe extern "C" fn(name: *mut c_char, len: c_int, dev: CUdevice) -> CUresult,
169
170 /// `cuDeviceGetAttribute(pi*, attrib, dev) -> CUresult`
171 ///
172 /// Returns information about the device.
173 pub cu_device_get_attribute:
174 unsafe extern "C" fn(pi: *mut c_int, attrib: CUdevice_attribute, dev: CUdevice) -> CUresult,
175
176 /// `cuDeviceTotalMem_v2(bytes*, dev) -> CUresult`
177 ///
178 /// Returns the total amount of memory on the device.
179 pub cu_device_total_mem_v2: unsafe extern "C" fn(bytes: *mut usize, dev: CUdevice) -> CUresult,
180
181 /// `cuDeviceCanAccessPeer(canAccessPeer*, dev, peerDev) -> CUresult`
182 ///
183 /// Queries if a device may directly access a peer device's memory.
184 pub cu_device_can_access_peer:
185 unsafe extern "C" fn(can_access: *mut c_int, dev: CUdevice, peer_dev: CUdevice) -> CUresult,
186
187 // -- Primary context management ----------------------------------------
188 /// `cuDevicePrimaryCtxRetain(pctx*, dev) -> CUresult`
189 ///
190 /// Retains the primary context on the device, creating it if necessary.
191 pub cu_device_primary_ctx_retain:
192 unsafe extern "C" fn(pctx: *mut CUcontext, dev: CUdevice) -> CUresult,
193
194 /// `cuDevicePrimaryCtxRelease_v2(dev) -> CUresult`
195 ///
196 /// Releases the primary context on the device.
197 pub cu_device_primary_ctx_release_v2: unsafe extern "C" fn(dev: CUdevice) -> CUresult,
198
199 /// `cuDevicePrimaryCtxSetFlags_v2(dev, flags) -> CUresult`
200 ///
201 /// Sets flags for the primary context.
202 pub cu_device_primary_ctx_set_flags_v2:
203 unsafe extern "C" fn(dev: CUdevice, flags: u32) -> CUresult,
204
205 /// `cuDevicePrimaryCtxGetState(dev, flags*, active*) -> CUresult`
206 ///
207 /// Returns the state (flags and active status) of the primary context.
208 pub cu_device_primary_ctx_get_state:
209 unsafe extern "C" fn(dev: CUdevice, flags: *mut u32, active: *mut c_int) -> CUresult,
210
211 /// `cuDevicePrimaryCtxReset_v2(dev) -> CUresult`
212 ///
213 /// Resets the primary context on the device.
214 pub cu_device_primary_ctx_reset_v2: unsafe extern "C" fn(dev: CUdevice) -> CUresult,
215
216 // -- Context management ------------------------------------------------
217 /// `cuCtxCreate_v2(pctx*, flags, dev) -> CUresult`
218 ///
219 /// Creates a new CUDA context and associates it with the calling thread.
220 pub cu_ctx_create_v2:
221 unsafe extern "C" fn(pctx: *mut CUcontext, flags: u32, dev: CUdevice) -> CUresult,
222
223 /// `cuCtxDestroy_v2(ctx) -> CUresult`
224 ///
225 /// Destroys a CUDA context.
226 pub cu_ctx_destroy_v2: unsafe extern "C" fn(ctx: CUcontext) -> CUresult,
227
228 /// `cuCtxSetCurrent(ctx) -> CUresult`
229 ///
230 /// Binds the specified CUDA context to the calling CPU thread.
231 pub cu_ctx_set_current: unsafe extern "C" fn(ctx: CUcontext) -> CUresult,
232
233 /// `cuCtxGetCurrent(pctx*) -> CUresult`
234 ///
235 /// Returns the CUDA context bound to the calling CPU thread.
236 pub cu_ctx_get_current: unsafe extern "C" fn(pctx: *mut CUcontext) -> CUresult,
237
238 /// `cuCtxSynchronize() -> CUresult`
239 ///
240 /// Blocks until the device has completed all preceding requested tasks.
241 pub cu_ctx_synchronize: unsafe extern "C" fn() -> CUresult,
242
243 // -- Module management -------------------------------------------------
244 /// `cuModuleLoadData(module*, image*) -> CUresult`
245 ///
246 /// Loads a module from a PTX or cubin image in host memory.
247 pub cu_module_load_data:
248 unsafe extern "C" fn(module: *mut CUmodule, image: *const c_void) -> CUresult,
249
250 /// `cuModuleLoadDataEx(module*, image*, numOptions, options*, optionValues*) -> CUresult`
251 ///
252 /// Loads a module with JIT compiler options.
253 pub cu_module_load_data_ex: unsafe extern "C" fn(
254 module: *mut CUmodule,
255 image: *const c_void,
256 num_options: u32,
257 options: *mut CUjit_option,
258 option_values: *mut *mut c_void,
259 ) -> CUresult,
260
261 /// `cuModuleGetFunction(hfunc*, hmod, name*) -> CUresult`
262 ///
263 /// Returns a handle to a function within a module.
264 pub cu_module_get_function: unsafe extern "C" fn(
265 hfunc: *mut CUfunction,
266 hmod: CUmodule,
267 name: *const c_char,
268 ) -> CUresult,
269
270 /// `cuModuleUnload(hmod) -> CUresult`
271 ///
272 /// Unloads a module from the current context.
273 pub cu_module_unload: unsafe extern "C" fn(hmod: CUmodule) -> CUresult,
274
275 // -- Memory management -------------------------------------------------
276 /// `cuMemAlloc_v2(dptr*, bytesize) -> CUresult`
277 ///
278 /// Allocates device memory.
279 pub cu_mem_alloc_v2: unsafe extern "C" fn(dptr: *mut CUdeviceptr, bytesize: usize) -> CUresult,
280
281 /// `cuMemFree_v2(dptr) -> CUresult`
282 ///
283 /// Frees device memory.
284 pub cu_mem_free_v2: unsafe extern "C" fn(dptr: CUdeviceptr) -> CUresult,
285
286 /// `cuMemcpyHtoD_v2(dst, src*, bytesize) -> CUresult`
287 ///
288 /// Copies data from host memory to device memory.
289 pub cu_memcpy_htod_v2:
290 unsafe extern "C" fn(dst: CUdeviceptr, src: *const c_void, bytesize: usize) -> CUresult,
291
292 /// `cuMemcpyDtoH_v2(dst*, src, bytesize) -> CUresult`
293 ///
294 /// Copies data from device memory to host memory.
295 pub cu_memcpy_dtoh_v2:
296 unsafe extern "C" fn(dst: *mut c_void, src: CUdeviceptr, bytesize: usize) -> CUresult,
297
298 /// `cuMemcpyDtoD_v2(dst, src, bytesize) -> CUresult`
299 ///
300 /// Copies data from device memory to device memory.
301 pub cu_memcpy_dtod_v2:
302 unsafe extern "C" fn(dst: CUdeviceptr, src: CUdeviceptr, bytesize: usize) -> CUresult,
303
304 /// `cuMemcpyHtoDAsync_v2(dst, src*, bytesize, stream) -> CUresult`
305 ///
306 /// Asynchronously copies data from host to device memory.
307 pub cu_memcpy_htod_async_v2: unsafe extern "C" fn(
308 dst: CUdeviceptr,
309 src: *const c_void,
310 bytesize: usize,
311 stream: CUstream,
312 ) -> CUresult,
313
314 /// `cuMemcpyDtoHAsync_v2(dst*, src, bytesize, stream) -> CUresult`
315 ///
316 /// Asynchronously copies data from device to host memory.
317 pub cu_memcpy_dtoh_async_v2: unsafe extern "C" fn(
318 dst: *mut c_void,
319 src: CUdeviceptr,
320 bytesize: usize,
321 stream: CUstream,
322 ) -> CUresult,
323
324 /// `cuMemAllocHost_v2(pp*, bytesize) -> CUresult`
325 ///
326 /// Allocates page-locked (pinned) host memory.
327 pub cu_mem_alloc_host_v2:
328 unsafe extern "C" fn(pp: *mut *mut c_void, bytesize: usize) -> CUresult,
329
330 /// `cuMemFreeHost(p*) -> CUresult`
331 ///
332 /// Frees page-locked host memory.
333 pub cu_mem_free_host: unsafe extern "C" fn(p: *mut c_void) -> CUresult,
334
335 /// `cuMemAllocManaged(dptr*, bytesize, flags) -> CUresult`
336 ///
337 /// Allocates unified memory accessible from both host and device.
338 pub cu_mem_alloc_managed:
339 unsafe extern "C" fn(dptr: *mut CUdeviceptr, bytesize: usize, flags: u32) -> CUresult,
340
341 /// `cuMemsetD8_v2(dst, value, count) -> CUresult`
342 ///
343 /// Sets device memory to a value (byte granularity).
344 pub cu_memset_d8_v2:
345 unsafe extern "C" fn(dst: CUdeviceptr, value: u8, count: usize) -> CUresult,
346
347 /// `cuMemsetD32_v2(dst, value, count) -> CUresult`
348 ///
349 /// Sets device memory to a value (32-bit granularity).
350 pub cu_memset_d32_v2:
351 unsafe extern "C" fn(dst: CUdeviceptr, value: u32, count: usize) -> CUresult,
352
353 /// `cuMemGetInfo_v2(free*, total*) -> CUresult`
354 ///
355 /// Returns free and total memory for the current context's device.
356 pub cu_mem_get_info_v2: unsafe extern "C" fn(free: *mut usize, total: *mut usize) -> CUresult,
357
358 /// `cuMemHostRegister_v2(p*, bytesize, flags) -> CUresult`
359 ///
360 /// Registers an existing host memory range for use by CUDA.
361 pub cu_mem_host_register_v2:
362 unsafe extern "C" fn(p: *mut c_void, bytesize: usize, flags: u32) -> CUresult,
363
364 /// `cuMemHostUnregister(p*) -> CUresult`
365 ///
366 /// Unregisters a memory range that was registered with cuMemHostRegister.
367 pub cu_mem_host_unregister: unsafe extern "C" fn(p: *mut c_void) -> CUresult,
368
369 /// `cuMemHostGetDevicePointer_v2(pdptr*, p*, flags) -> CUresult`
370 ///
371 /// Returns the device pointer mapped to a registered host pointer.
372 pub cu_mem_host_get_device_pointer_v2:
373 unsafe extern "C" fn(pdptr: *mut CUdeviceptr, p: *mut c_void, flags: u32) -> CUresult,
374
375 /// `cuPointerGetAttribute(data*, attribute, ptr) -> CUresult`
376 ///
377 /// Returns information about a pointer.
378 pub cu_pointer_get_attribute:
379 unsafe extern "C" fn(data: *mut c_void, attribute: u32, ptr: CUdeviceptr) -> CUresult,
380
381 /// `cuMemAdvise(devPtr, count, advice, device) -> CUresult`
382 ///
383 /// Advises the unified memory subsystem about usage patterns.
384 pub cu_mem_advise: unsafe extern "C" fn(
385 dev_ptr: CUdeviceptr,
386 count: usize,
387 advice: u32,
388 device: CUdevice,
389 ) -> CUresult,
390
391 /// `cuMemPrefetchAsync(devPtr, count, dstDevice, hStream) -> CUresult`
392 ///
393 /// Prefetches unified memory to the specified device.
394 pub cu_mem_prefetch_async: unsafe extern "C" fn(
395 dev_ptr: CUdeviceptr,
396 count: usize,
397 dst_device: CUdevice,
398 hstream: CUstream,
399 ) -> CUresult,
400
401 // -- Stream management -------------------------------------------------
402 /// `cuStreamCreate(phStream*, flags) -> CUresult`
403 ///
404 /// Creates a stream.
405 pub cu_stream_create: unsafe extern "C" fn(phstream: *mut CUstream, flags: u32) -> CUresult,
406
407 /// `cuStreamCreateWithPriority(phStream*, flags, priority) -> CUresult`
408 ///
409 /// Creates a stream with the given priority.
410 pub cu_stream_create_with_priority:
411 unsafe extern "C" fn(phstream: *mut CUstream, flags: u32, priority: c_int) -> CUresult,
412
413 /// `cuStreamDestroy_v2(hStream) -> CUresult`
414 ///
415 /// Destroys a stream.
416 pub cu_stream_destroy_v2: unsafe extern "C" fn(hstream: CUstream) -> CUresult,
417
418 /// `cuStreamSynchronize(hStream) -> CUresult`
419 ///
420 /// Waits until a stream's tasks are completed.
421 pub cu_stream_synchronize: unsafe extern "C" fn(hstream: CUstream) -> CUresult,
422
423 /// `cuStreamWaitEvent(hStream, hEvent, flags) -> CUresult`
424 ///
425 /// Makes all future work submitted to the stream wait for the event.
426 pub cu_stream_wait_event:
427 unsafe extern "C" fn(hstream: CUstream, hevent: CUevent, flags: u32) -> CUresult,
428
429 /// `cuStreamQuery(hStream) -> CUresult`
430 ///
431 /// Returns `CUDA_SUCCESS` if all operations in the stream have completed,
432 /// `CUDA_ERROR_NOT_READY` if still pending.
433 pub cu_stream_query: unsafe extern "C" fn(hstream: CUstream) -> CUresult,
434
435 /// `cuStreamGetPriority(hStream, priority*) -> CUresult`
436 ///
437 /// Query the priority of `hStream`.
438 pub cu_stream_get_priority:
439 unsafe extern "C" fn(hstream: CUstream, priority: *mut std::ffi::c_int) -> CUresult,
440
441 /// `cuStreamGetFlags(hStream, flags*) -> CUresult`
442 ///
443 /// Query the flags of `hStream`.
444 pub cu_stream_get_flags: unsafe extern "C" fn(hstream: CUstream, flags: *mut u32) -> CUresult,
445
446 // -- Event management --------------------------------------------------
447 /// `cuEventCreate(phEvent*, flags) -> CUresult`
448 ///
449 /// Creates an event.
450 pub cu_event_create: unsafe extern "C" fn(phevent: *mut CUevent, flags: u32) -> CUresult,
451
452 /// `cuEventDestroy_v2(hEvent) -> CUresult`
453 ///
454 /// Destroys an event.
455 pub cu_event_destroy_v2: unsafe extern "C" fn(hevent: CUevent) -> CUresult,
456
457 /// `cuEventRecord(hEvent, hStream) -> CUresult`
458 ///
459 /// Records an event in a stream.
460 pub cu_event_record: unsafe extern "C" fn(hevent: CUevent, hstream: CUstream) -> CUresult,
461
462 /// `cuEventQuery(hEvent) -> CUresult`
463 ///
464 /// Queries the status of an event. Returns `CUDA_SUCCESS` if complete,
465 /// `CUDA_ERROR_NOT_READY` if still pending.
466 pub cu_event_query: unsafe extern "C" fn(hevent: CUevent) -> CUresult,
467
468 /// `cuEventSynchronize(hEvent) -> CUresult`
469 ///
470 /// Waits until an event completes.
471 pub cu_event_synchronize: unsafe extern "C" fn(hevent: CUevent) -> CUresult,
472
473 /// `cuEventElapsedTime(pMilliseconds*, hStart, hEnd) -> CUresult`
474 ///
475 /// Computes the elapsed time between two events.
476 pub cu_event_elapsed_time:
477 unsafe extern "C" fn(pmilliseconds: *mut f32, hstart: CUevent, hend: CUevent) -> CUresult,
478
479 // -- Kernel launch -----------------------------------------------------
480
481 // -- Peer memory access ------------------------------------------------
482 /// `cuMemcpyPeer(dstDevice, dstContext, srcDevice, srcContext, count) -> CUresult`
483 ///
484 /// Copies device memory between two primary contexts.
485 pub cu_memcpy_peer: unsafe extern "C" fn(
486 dst_device: u64,
487 dst_ctx: CUcontext,
488 src_device: u64,
489 src_ctx: CUcontext,
490 count: usize,
491 ) -> CUresult,
492
493 /// `cuMemcpyPeerAsync(..., hStream) -> CUresult`
494 ///
495 /// Asynchronous cross-device copy.
496 pub cu_memcpy_peer_async: unsafe extern "C" fn(
497 dst_device: u64,
498 dst_ctx: CUcontext,
499 src_device: u64,
500 src_ctx: CUcontext,
501 count: usize,
502 stream: CUstream,
503 ) -> CUresult,
504
505 /// `cuCtxEnablePeerAccess(peerContext, flags) -> CUresult`
506 ///
507 /// Enables peer access between two contexts.
508 pub cu_ctx_enable_peer_access:
509 unsafe extern "C" fn(peer_context: CUcontext, flags: u32) -> CUresult,
510
511 /// `cuCtxDisablePeerAccess(peerContext) -> CUresult`
512 ///
513 /// Disables peer access to a context.
514 pub cu_ctx_disable_peer_access: unsafe extern "C" fn(peer_context: CUcontext) -> CUresult,
515 /// `cuLaunchKernel(f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY,
516 /// blockDimZ, sharedMemBytes, hStream, kernelParams**, extra**) -> CUresult`
517 ///
518 /// Launches a CUDA kernel.
519 #[allow(clippy::type_complexity)]
520 pub cu_launch_kernel: unsafe extern "C" fn(
521 f: CUfunction,
522 grid_dim_x: u32,
523 grid_dim_y: u32,
524 grid_dim_z: u32,
525 block_dim_x: u32,
526 block_dim_y: u32,
527 block_dim_z: u32,
528 shared_mem_bytes: u32,
529 hstream: CUstream,
530 kernel_params: *mut *mut c_void,
531 extra: *mut *mut c_void,
532 ) -> CUresult,
533
534 /// `cuLaunchCooperativeKernel(f, gridDimX, gridDimY, gridDimZ, blockDimX,
535 /// blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams**) -> CUresult`
536 ///
537 /// Launches a cooperative CUDA kernel (CUDA 9.0+).
538 #[allow(clippy::type_complexity)]
539 pub cu_launch_cooperative_kernel: unsafe extern "C" fn(
540 f: CUfunction,
541 grid_dim_x: u32,
542 grid_dim_y: u32,
543 grid_dim_z: u32,
544 block_dim_x: u32,
545 block_dim_y: u32,
546 block_dim_z: u32,
547 shared_mem_bytes: u32,
548 hstream: CUstream,
549 kernel_params: *mut *mut c_void,
550 ) -> CUresult,
551
552 /// `cuLaunchCooperativeKernelMultiDevice(launchParamsList*, numDevices,
553 /// flags) -> CUresult`
554 ///
555 /// Launches a cooperative kernel across multiple devices (CUDA 9.0+).
556 pub cu_launch_cooperative_kernel_multi_device: unsafe extern "C" fn(
557 launch_params_list: *mut c_void,
558 num_devices: u32,
559 flags: u32,
560 ) -> CUresult,
561
562 // -- Occupancy ---------------------------------------------------------
563 /// `cuOccupancyMaxActiveBlocksPerMultiprocessor(numBlocks*, func, blockSize,
564 /// dynamicSMemSize) -> CUresult`
565 ///
566 /// Returns the number of the maximum active blocks per streaming
567 /// multiprocessor.
568 pub cu_occupancy_max_active_blocks_per_multiprocessor: unsafe extern "C" fn(
569 num_blocks: *mut c_int,
570 func: CUfunction,
571 block_size: c_int,
572 dynamic_smem_size: usize,
573 ) -> CUresult,
574
575 /// `cuOccupancyMaxPotentialBlockSize(minGridSize*, blockSize*, func,
576 /// blockSizeToDynamicSMemSize, dynamicSMemSize, blockSizeLimit) -> CUresult`
577 ///
578 /// Suggests a launch configuration with reasonable occupancy.
579 #[allow(clippy::type_complexity)]
580 pub cu_occupancy_max_potential_block_size: unsafe extern "C" fn(
581 min_grid_size: *mut c_int,
582 block_size: *mut c_int,
583 func: CUfunction,
584 block_size_to_dynamic_smem_size: Option<unsafe extern "C" fn(c_int) -> usize>,
585 dynamic_smem_size: usize,
586 block_size_limit: c_int,
587 ) -> CUresult,
588
589 /// `cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags(numBlocks*, func,
590 /// blockSize, dynamicSMemSize, flags) -> CUresult`
591 ///
592 /// Like `cuOccupancyMaxActiveBlocksPerMultiprocessor` but with flags
593 /// to control caching behaviour (CUDA 9.0+).
594 pub cu_occupancy_max_active_blocks_per_multiprocessor_with_flags:
595 unsafe extern "C" fn(
596 num_blocks: *mut c_int,
597 func: CUfunction,
598 block_size: c_int,
599 dynamic_smem_size: usize,
600 flags: u32,
601 ) -> CUresult,
602
603 // -- Memory management (optional) -----------------------------------------
604 /// `cuMemcpyDtoDAsync_v2(dst, src, bytesize, stream) -> CUresult`
605 ///
606 /// Asynchronously copies data from device memory to device memory.
607 pub cu_memcpy_dtod_async_v2: Option<
608 unsafe extern "C" fn(
609 dst: CUdeviceptr,
610 src: CUdeviceptr,
611 bytesize: usize,
612 stream: CUstream,
613 ) -> CUresult,
614 >,
615
616 /// `cuMemsetD16_v2(dst, value, count) -> CUresult`
617 ///
618 /// Sets device memory to a value (16-bit granularity).
619 pub cu_memset_d16_v2:
620 Option<unsafe extern "C" fn(dst: CUdeviceptr, value: u16, count: usize) -> CUresult>,
621
622 /// `cuMemsetD32Async(dst, value, count, stream) -> CUresult`
623 ///
624 /// Asynchronously sets device memory to a value (32-bit granularity).
625 pub cu_memset_d32_async: Option<
626 unsafe extern "C" fn(
627 dst: CUdeviceptr,
628 value: u32,
629 count: usize,
630 stream: CUstream,
631 ) -> CUresult,
632 >,
633
634 // -- Context management (optional) ----------------------------------------
635 /// `cuCtxGetLimit(value*, limit) -> CUresult`
636 ///
637 /// Returns the value of a context limit.
638 pub cu_ctx_get_limit: Option<unsafe extern "C" fn(value: *mut usize, limit: u32) -> CUresult>,
639
640 /// `cuCtxSetLimit(limit, value) -> CUresult`
641 ///
642 /// Sets a context limit.
643 pub cu_ctx_set_limit: Option<unsafe extern "C" fn(limit: u32, value: usize) -> CUresult>,
644
645 /// `cuCtxGetCacheConfig(config*) -> CUresult`
646 ///
647 /// Returns the current cache configuration for the context.
648 pub cu_ctx_get_cache_config: Option<unsafe extern "C" fn(config: *mut u32) -> CUresult>,
649
650 /// `cuCtxSetCacheConfig(config) -> CUresult`
651 ///
652 /// Sets the cache configuration for the current context.
653 pub cu_ctx_set_cache_config: Option<unsafe extern "C" fn(config: u32) -> CUresult>,
654
655 /// `cuCtxGetSharedMemConfig(config*) -> CUresult`
656 ///
657 /// Returns the shared memory configuration for the context.
658 pub cu_ctx_get_shared_mem_config: Option<unsafe extern "C" fn(config: *mut u32) -> CUresult>,
659
660 /// `cuCtxSetSharedMemConfig(config) -> CUresult`
661 ///
662 /// Sets the shared memory configuration for the current context.
663 pub cu_ctx_set_shared_mem_config: Option<unsafe extern "C" fn(config: u32) -> CUresult>,
664
665 // -- Event with flags (optional, CUDA 11.1+) ------------------------------
666 /// `cuEventRecordWithFlags(hEvent, hStream, flags) -> CUresult`
667 ///
668 /// Records an event in a stream with additional flags (CUDA 11.1+).
669 /// Falls back to `cu_event_record` when `None`.
670 pub cu_event_record_with_flags:
671 Option<unsafe extern "C" fn(hevent: CUevent, hstream: CUstream, flags: u32) -> CUresult>,
672
673 // -- Function attributes (optional) ---------------------------------------
674 /// `cuFuncGetAttribute(value*, attrib, func) -> CUresult`
675 ///
676 /// Returns information about a function.
677 pub cu_func_get_attribute: Option<
678 unsafe extern "C" fn(value: *mut c_int, attrib: c_int, func: CUfunction) -> CUresult,
679 >,
680
681 /// `cuFuncSetCacheConfig(func, config) -> CUresult`
682 ///
683 /// Sets the cache configuration for a device function.
684 pub cu_func_set_cache_config:
685 Option<unsafe extern "C" fn(func: CUfunction, config: u32) -> CUresult>,
686
687 /// `cuFuncSetSharedMemConfig(func, config) -> CUresult`
688 ///
689 /// Sets the shared memory configuration for a device function.
690 pub cu_func_set_shared_mem_config:
691 Option<unsafe extern "C" fn(func: CUfunction, config: u32) -> CUresult>,
692
693 /// `cuFuncSetAttribute(func, attrib, value) -> CUresult`
694 ///
695 /// Sets an attribute value for a device function.
696 pub cu_func_set_attribute:
697 Option<unsafe extern "C" fn(func: CUfunction, attrib: c_int, value: c_int) -> CUresult>,
698
699 // -- Profiler (optional) --------------------------------------------------
700 /// `cuProfilerStart() -> CUresult`
701 ///
702 /// Starts the CUDA profiler.
703 pub cu_profiler_start: Option<unsafe extern "C" fn() -> CUresult>,
704
705 /// `cuProfilerStop() -> CUresult`
706 ///
707 /// Stops the CUDA profiler.
708 pub cu_profiler_stop: Option<unsafe extern "C" fn() -> CUresult>,
709
710 // -- CUDA 12.x extended launch (optional) ---------------------------------
711 /// `cuLaunchKernelEx(config*, f, kernelParams**, extra**) -> CUresult`
712 ///
713 /// Extended kernel launch with cluster dimensions and other CUDA 12.0+
714 /// attributes. Available only when the driver is CUDA 12.0 or newer.
715 ///
716 /// When `None`, fall back to [`cu_launch_kernel`](Self::cu_launch_kernel).
717 #[allow(clippy::type_complexity)]
718 pub cu_launch_kernel_ex: Option<
719 unsafe extern "C" fn(
720 config: *const CuLaunchConfig,
721 f: CUfunction,
722 kernel_params: *mut *mut std::ffi::c_void,
723 extra: *mut *mut std::ffi::c_void,
724 ) -> CUresult,
725 >,
726
727 /// `cuTensorMapEncodeTiled(tensorMap*, ...) -> CUresult`
728 ///
729 /// Creates a TMA tensor map descriptor for tiled access patterns.
730 /// Available on CUDA 12.0+ with sm_90+ (Hopper/Blackwell).
731 ///
732 /// When `None`, TMA is not supported by the loaded driver.
733 #[allow(clippy::type_complexity)]
734 pub cu_tensor_map_encode_tiled: Option<
735 unsafe extern "C" fn(
736 tensor_map: *mut std::ffi::c_void,
737 tensor_data_type: u32,
738 tensor_rank: u32,
739 global_address: *mut std::ffi::c_void,
740 global_dim: *const u64,
741 global_strides: *const u64,
742 box_dim: *const u32,
743 element_strides: *const u32,
744 interleave: u32,
745 swizzle: u32,
746 l2_promotion: u32,
747 oob_fill: u32,
748 ) -> CUresult,
749 >,
750
751 // -- CUDA 12.8+ extended API (optional) -----------------------------------
752 /// `cuTensorMapEncodeTiledMemref(tensorMap*, ...) -> CUresult`
753 ///
754 /// Extended TMA encoding using memref descriptors (CUDA 12.8+,
755 /// Blackwell sm_100/sm_120). When `None`, fall back to
756 /// [`cu_tensor_map_encode_tiled`](Self::cu_tensor_map_encode_tiled).
757 #[allow(clippy::type_complexity)]
758 pub cu_tensor_map_encode_tiled_memref: Option<
759 unsafe extern "C" fn(
760 tensor_map: *mut c_void,
761 tensor_data_type: u32,
762 tensor_rank: u32,
763 global_address: *mut c_void,
764 global_dim: *const u64,
765 global_strides: *const u64,
766 box_dim: *const u32,
767 element_strides: *const u32,
768 interleave: u32,
769 swizzle: u32,
770 l2_promotion: u32,
771 oob_fill: u32,
772 flags: u64,
773 ) -> CUresult,
774 >,
775
776 /// `cuKernelGetLibrary(pLib*, kernel) -> CUresult`
777 ///
778 /// Returns the library handle that owns a given kernel handle
779 /// (CUDA 12.8+). When `None`, the driver does not support the JIT
780 /// library API.
781 pub cu_kernel_get_library:
782 Option<unsafe extern "C" fn(p_lib: *mut CUlibrary, kernel: CUkernel) -> CUresult>,
783
784 /// `cuMulticastGetGranularity(granularity*, desc*, option) -> CUresult`
785 ///
786 /// Queries the recommended memory granularity for an NVLink multicast
787 /// object (CUDA 12.8+). When `None`, multicast memory is not supported.
788 pub cu_multicast_get_granularity: Option<
789 unsafe extern "C" fn(granularity: *mut usize, desc: *const c_void, option: u32) -> CUresult,
790 >,
791
792 /// `cuMulticastCreate(mcHandle*, desc*) -> CUresult`
793 ///
794 /// Creates an NVLink multicast object for cross-GPU broadcast memory
795 /// (CUDA 12.8+). When `None`, multicast memory is not supported.
796 pub cu_multicast_create: Option<
797 unsafe extern "C" fn(mc_handle: *mut CUmulticastObject, desc: *const c_void) -> CUresult,
798 >,
799
800 /// `cuMulticastAddDevice(mcHandle, dev) -> CUresult`
801 ///
802 /// Adds a device to an NVLink multicast group (CUDA 12.8+). When
803 /// `None`, multicast memory is not supported.
804 pub cu_multicast_add_device:
805 Option<unsafe extern "C" fn(mc_handle: CUmulticastObject, dev: CUdevice) -> CUresult>,
806
807 /// `cuMemcpyBatchAsync(dsts*, srcs*, sizes*, count, flags, stream) -> CUresult`
808 ///
809 /// Issues *count* asynchronous memory copies (H2D, D2H, or D2D) in a
810 /// single driver call (CUDA 12.8+). When `None`, issue individual
811 /// `cuMemcpyAsync` calls as a fallback.
812 #[allow(clippy::type_complexity)]
813 pub cu_memcpy_batch_async: Option<
814 unsafe extern "C" fn(
815 dsts: *const *mut c_void,
816 srcs: *const *const c_void,
817 sizes: *const usize,
818 count: u64,
819 flags: u64,
820 stream: CUstream,
821 ) -> CUresult,
822 >,
823
824 // -- Texture / Surface memory (optional) ----------------------------------
825 /// `cuArrayCreate_v2(pHandle*, pAllocateArray*) -> CUresult`
826 ///
827 /// Allocates a 1-D or 2-D CUDA array. When `None`, CUDA array allocation
828 /// is not supported by the loaded driver.
829 pub cu_array_create_v2: Option<
830 unsafe extern "C" fn(
831 p_handle: *mut CUarray,
832 p_allocate_array: *const CUDA_ARRAY_DESCRIPTOR,
833 ) -> CUresult,
834 >,
835
836 /// `cuArrayDestroy(hArray) -> CUresult`
837 ///
838 /// Frees a CUDA array previously allocated by `cuArrayCreate_v2`.
839 pub cu_array_destroy: Option<unsafe extern "C" fn(h_array: CUarray) -> CUresult>,
840
841 /// `cuArrayGetDescriptor_v2(pArrayDescriptor*, hArray) -> CUresult`
842 ///
843 /// Returns the descriptor of a 1-D or 2-D CUDA array.
844 pub cu_array_get_descriptor_v2: Option<
845 unsafe extern "C" fn(
846 p_array_descriptor: *mut CUDA_ARRAY_DESCRIPTOR,
847 h_array: CUarray,
848 ) -> CUresult,
849 >,
850
851 /// `cuArray3DCreate_v2(pHandle*, pAllocateArray*) -> CUresult`
852 ///
853 /// Allocates a 3-D CUDA array (also supports layered and cubemap arrays).
854 pub cu_array3d_create_v2: Option<
855 unsafe extern "C" fn(
856 p_handle: *mut CUarray,
857 p_allocate_array: *const CUDA_ARRAY3D_DESCRIPTOR,
858 ) -> CUresult,
859 >,
860
861 /// `cuArray3DGetDescriptor_v2(pArrayDescriptor*, hArray) -> CUresult`
862 ///
863 /// Returns the descriptor of a 3-D CUDA array.
864 pub cu_array3d_get_descriptor_v2: Option<
865 unsafe extern "C" fn(
866 p_array_descriptor: *mut CUDA_ARRAY3D_DESCRIPTOR,
867 h_array: CUarray,
868 ) -> CUresult,
869 >,
870
871 /// `cuMemcpyHtoA_v2(dstArray, dstOffset, srcHost*, ByteCount) -> CUresult`
872 ///
873 /// Synchronously copies host memory into a CUDA array.
874 pub cu_memcpy_htoa_v2: Option<
875 unsafe extern "C" fn(
876 dst_array: CUarray,
877 dst_offset: usize,
878 src_host: *const c_void,
879 byte_count: usize,
880 ) -> CUresult,
881 >,
882
883 /// `cuMemcpyAtoH_v2(dstHost*, srcArray, srcOffset, ByteCount) -> CUresult`
884 ///
885 /// Synchronously copies data from a CUDA array into host memory.
886 pub cu_memcpy_atoh_v2: Option<
887 unsafe extern "C" fn(
888 dst_host: *mut c_void,
889 src_array: CUarray,
890 src_offset: usize,
891 byte_count: usize,
892 ) -> CUresult,
893 >,
894
895 /// `cuMemcpyHtoAAsync_v2(dstArray, dstOffset, srcHost*, byteCount, stream) -> CUresult`
896 ///
897 /// Asynchronously copies host memory into a CUDA array on a stream.
898 pub cu_memcpy_htoa_async_v2: Option<
899 unsafe extern "C" fn(
900 dst_array: CUarray,
901 dst_offset: usize,
902 src_host: *const c_void,
903 byte_count: usize,
904 stream: CUstream,
905 ) -> CUresult,
906 >,
907
908 /// `cuMemcpyAtoHAsync_v2(dstHost*, srcArray, srcOffset, byteCount, stream) -> CUresult`
909 ///
910 /// Asynchronously copies data from a CUDA array into host memory on a stream.
911 pub cu_memcpy_atoh_async_v2: Option<
912 unsafe extern "C" fn(
913 dst_host: *mut c_void,
914 src_array: CUarray,
915 src_offset: usize,
916 byte_count: usize,
917 stream: CUstream,
918 ) -> CUresult,
919 >,
920
921 /// `cuTexObjectCreate(pTexObject*, pResDesc*, pTexDesc*, pResViewDesc*) -> CUresult`
922 ///
923 /// Creates a texture object from a resource descriptor, texture descriptor,
924 /// and optional resource-view descriptor (CUDA 5.0+).
925 pub cu_tex_object_create: Option<
926 unsafe extern "C" fn(
927 p_tex_object: *mut CUtexObject,
928 p_res_desc: *const CUDA_RESOURCE_DESC,
929 p_tex_desc: *const CUDA_TEXTURE_DESC,
930 p_res_view_desc: *const CUDA_RESOURCE_VIEW_DESC,
931 ) -> CUresult,
932 >,
933
934 /// `cuTexObjectDestroy(texObject) -> CUresult`
935 ///
936 /// Destroys a texture object created by `cuTexObjectCreate`.
937 pub cu_tex_object_destroy: Option<unsafe extern "C" fn(tex_object: CUtexObject) -> CUresult>,
938
939 /// `cuTexObjectGetResourceDesc(pResDesc*, texObject) -> CUresult`
940 ///
941 /// Returns the resource descriptor of a texture object.
942 pub cu_tex_object_get_resource_desc: Option<
943 unsafe extern "C" fn(
944 p_res_desc: *mut CUDA_RESOURCE_DESC,
945 tex_object: CUtexObject,
946 ) -> CUresult,
947 >,
948
949 /// `cuSurfObjectCreate(pSurfObject*, pResDesc*) -> CUresult`
950 ///
951 /// Creates a surface object from a resource descriptor (CUDA 5.0+).
952 /// The resource type must be `Array` (surface-capable CUDA arrays only).
953 pub cu_surf_object_create: Option<
954 unsafe extern "C" fn(
955 p_surf_object: *mut CUsurfObject,
956 p_res_desc: *const CUDA_RESOURCE_DESC,
957 ) -> CUresult,
958 >,
959
960 /// `cuSurfObjectDestroy(surfObject) -> CUresult`
961 ///
962 /// Destroys a surface object created by `cuSurfObjectCreate`.
963 pub cu_surf_object_destroy: Option<unsafe extern "C" fn(surf_object: CUsurfObject) -> CUresult>,
964
965 // -- JIT linker (optional) ----------------------------------------------
966 /// `cuLinkCreate_v2(numOptions, options*, optionValues**, stateOut*) -> CUresult`
967 ///
968 /// Creates a pending JIT linker invocation. When `None`, the driver does
969 /// not expose the linker API.
970 pub cu_link_create: Option<
971 unsafe extern "C" fn(
972 num_options: u32,
973 options: *mut CUjit_option,
974 option_values: *mut *mut c_void,
975 state_out: *mut CUlinkState,
976 ) -> CUresult,
977 >,
978
979 /// `cuLinkAddData_v2(state, type, data*, size, name*, numOptions, options*, optionValues**) -> CUresult`
980 ///
981 /// Adds an input PTX/cubin/fatbin to a pending linker invocation. When
982 /// `None`, the driver does not expose the linker API.
983 #[allow(clippy::type_complexity)]
984 pub cu_link_add_data: Option<
985 unsafe extern "C" fn(
986 state: CUlinkState,
987 input_type: CUjitInputType,
988 data: *mut c_void,
989 size: usize,
990 name: *const c_char,
991 num_options: u32,
992 options: *mut CUjit_option,
993 option_values: *mut *mut c_void,
994 ) -> CUresult,
995 >,
996
997 /// `cuLinkComplete(state, cubinOut**, sizeOut*) -> CUresult`
998 ///
999 /// Finalises a JIT linker invocation and returns the resulting cubin
1000 /// pointer / size. When `None`, the driver does not expose the linker
1001 /// API.
1002 pub cu_link_complete: Option<
1003 unsafe extern "C" fn(
1004 state: CUlinkState,
1005 cubin_out: *mut *mut c_void,
1006 size_out: *mut usize,
1007 ) -> CUresult,
1008 >,
1009
1010 /// `cuLinkDestroy(state) -> CUresult`
1011 ///
1012 /// Destroys a linker state previously created by `cuLinkCreate`. When
1013 /// `None`, the driver does not expose the linker API.
1014 pub cu_link_destroy: Option<unsafe extern "C" fn(state: CUlinkState) -> CUresult>,
1015
1016 // -- 2-D memory copy (optional) -----------------------------------------
1017 /// `cuMemcpy2D_v2(pCopy*) -> CUresult`
1018 ///
1019 /// Performs a 2-D memory copy described by [`CUDA_MEMCPY2D`]. When
1020 /// `None`, fall back to issuing per-row 1-D `cuMemcpyXXX_v2` calls.
1021 pub cu_memcpy_2d: Option<unsafe extern "C" fn(p_copy: *const CUDA_MEMCPY2D) -> CUresult>,
1022
1023 // -- Virtual memory management (optional, CUDA 11.2+) -------------------
1024 /// `cuMemAddressReserve(ptr*, size, alignment, addr, flags) -> CUresult`
1025 ///
1026 /// Reserves a contiguous range of virtual addresses on the device for
1027 /// later mapping by `cuMemMap`. When `None`, the VMM API is not
1028 /// supported.
1029 pub cu_mem_address_reserve: Option<
1030 unsafe extern "C" fn(
1031 ptr: *mut CUdeviceptr,
1032 size: usize,
1033 alignment: usize,
1034 addr: CUdeviceptr,
1035 flags: u64,
1036 ) -> CUresult,
1037 >,
1038
1039 /// `cuMemAddressFree(ptr, size) -> CUresult`
1040 ///
1041 /// Releases a virtual-address range previously obtained from
1042 /// `cuMemAddressReserve`. When `None`, the VMM API is not supported.
1043 pub cu_mem_address_free:
1044 Option<unsafe extern "C" fn(ptr: CUdeviceptr, size: usize) -> CUresult>,
1045
1046 /// `cuMemCreate(handle*, size, prop*, flags) -> CUresult`
1047 ///
1048 /// Creates a new generic VMM allocation handle. When `None`, the VMM
1049 /// API is not supported.
1050 pub cu_mem_create: Option<
1051 unsafe extern "C" fn(
1052 handle: *mut CUmemGenericAllocationHandle,
1053 size: usize,
1054 prop: *const CUmemAllocationProp,
1055 flags: u64,
1056 ) -> CUresult,
1057 >,
1058
1059 /// `cuMemRelease(handle) -> CUresult`
1060 ///
1061 /// Releases a generic VMM allocation handle. When `None`, the VMM
1062 /// API is not supported.
1063 pub cu_mem_release:
1064 Option<unsafe extern "C" fn(handle: CUmemGenericAllocationHandle) -> CUresult>,
1065
1066 /// `cuMemMap(ptr, size, offset, handle, flags) -> CUresult`
1067 ///
1068 /// Maps a VMM allocation onto a previously reserved virtual address
1069 /// range. When `None`, the VMM API is not supported.
1070 pub cu_mem_map: Option<
1071 unsafe extern "C" fn(
1072 ptr: CUdeviceptr,
1073 size: usize,
1074 offset: usize,
1075 handle: CUmemGenericAllocationHandle,
1076 flags: u64,
1077 ) -> CUresult,
1078 >,
1079
1080 /// `cuMemUnmap(ptr, size) -> CUresult`
1081 ///
1082 /// Unmaps a VMM allocation from a virtual address range. When `None`,
1083 /// the VMM API is not supported.
1084 pub cu_mem_unmap: Option<unsafe extern "C" fn(ptr: CUdeviceptr, size: usize) -> CUresult>,
1085
1086 /// `cuMemSetAccess(ptr, size, desc*, count) -> CUresult`
1087 ///
1088 /// Sets per-location access permissions for a VMM mapping. When `None`,
1089 /// the VMM API is not supported.
1090 pub cu_mem_set_access: Option<
1091 unsafe extern "C" fn(
1092 ptr: CUdeviceptr,
1093 size: usize,
1094 desc: *const CUmemAccessDesc,
1095 count: usize,
1096 ) -> CUresult,
1097 >,
1098
1099 // -- Stream-ordered memory pools (optional, CUDA 11.2+) -----------------
1100 /// `cuMemPoolCreate(pool*, poolProps*) -> CUresult`
1101 ///
1102 /// Creates a stream-ordered memory pool. When `None`, the memory pool
1103 /// API is not supported.
1104 pub cu_mem_pool_create: Option<
1105 unsafe extern "C" fn(
1106 pool: *mut CUmemoryPool,
1107 pool_props: *const CUmemPoolProps,
1108 ) -> CUresult,
1109 >,
1110
1111 /// `cuMemPoolDestroy(pool) -> CUresult`
1112 ///
1113 /// Destroys a stream-ordered memory pool. When `None`, the memory pool
1114 /// API is not supported.
1115 pub cu_mem_pool_destroy: Option<unsafe extern "C" fn(pool: CUmemoryPool) -> CUresult>,
1116
1117 /// `cuMemAllocFromPoolAsync(dptr*, bytesize, pool, hStream) -> CUresult`
1118 ///
1119 /// Asynchronously allocates memory from a pool on a stream. When `None`,
1120 /// the memory pool API is not supported.
1121 pub cu_mem_alloc_from_pool_async: Option<
1122 unsafe extern "C" fn(
1123 dptr: *mut CUdeviceptr,
1124 bytesize: usize,
1125 pool: CUmemoryPool,
1126 hstream: CUstream,
1127 ) -> CUresult,
1128 >,
1129
1130 /// `cuMemFreeAsync(dptr, hStream) -> CUresult`
1131 ///
1132 /// Asynchronously frees memory on a stream. When `None`, the
1133 /// stream-ordered memory API is not supported.
1134 pub cu_mem_free_async:
1135 Option<unsafe extern "C" fn(dptr: CUdeviceptr, hstream: CUstream) -> CUresult>,
1136
1137 /// `cuMemAllocAsync(dptr*, bytesize, hStream) -> CUresult`
1138 ///
1139 /// Asynchronously allocates memory from the current context's default
1140 /// pool on a stream. When `None`, the stream-ordered memory API is not
1141 /// supported.
1142 pub cu_mem_alloc_async: Option<
1143 unsafe extern "C" fn(
1144 dptr: *mut CUdeviceptr,
1145 bytesize: usize,
1146 hstream: CUstream,
1147 ) -> CUresult,
1148 >,
1149
1150 /// `cuMemPoolTrimTo(pool, minBytesToKeep) -> CUresult`
1151 ///
1152 /// Releases freed memory back to the OS, keeping at least
1153 /// `minBytesToKeep` bytes reserved. When `None`, the memory pool API
1154 /// is not supported.
1155 pub cu_mem_pool_trim_to:
1156 Option<unsafe extern "C" fn(pool: CUmemoryPool, min_bytes_to_keep: usize) -> CUresult>,
1157
1158 /// `cuMemPoolSetAttribute(pool, attr, value*) -> CUresult`
1159 ///
1160 /// Sets a writable attribute on a memory pool. When `None`, the memory
1161 /// pool API is not supported.
1162 pub cu_mem_pool_set_attribute: Option<
1163 unsafe extern "C" fn(
1164 pool: CUmemoryPool,
1165 attr: CUmemPoolAttribute,
1166 value: *mut c_void,
1167 ) -> CUresult,
1168 >,
1169
1170 /// `cuMemPoolGetAttribute(pool, attr, value*) -> CUresult`
1171 ///
1172 /// Reads an attribute from a memory pool. When `None`, the memory pool
1173 /// API is not supported.
1174 pub cu_mem_pool_get_attribute: Option<
1175 unsafe extern "C" fn(
1176 pool: CUmemoryPool,
1177 attr: CUmemPoolAttribute,
1178 value: *mut c_void,
1179 ) -> CUresult,
1180 >,
1181
1182 /// `cuMemPoolSetAccess(pool, map*, count) -> CUresult`
1183 ///
1184 /// Controls the per-device visibility of allocations from a memory pool.
1185 /// When `None`, the memory pool API is not supported.
1186 pub cu_mem_pool_set_access: Option<
1187 unsafe extern "C" fn(
1188 pool: CUmemoryPool,
1189 map: *const CUmemAccessDesc,
1190 count: usize,
1191 ) -> CUresult,
1192 >,
1193
1194 /// `cuDeviceGetDefaultMemPool(pool*, dev) -> CUresult`
1195 ///
1196 /// Returns the default stream-ordered memory pool of a device. When
1197 /// `None`, the memory pool API is not supported.
1198 pub cu_device_get_default_mem_pool:
1199 Option<unsafe extern "C" fn(pool: *mut CUmemoryPool, dev: CUdevice) -> CUresult>,
1200
1201 // -- CUDA Graph API (optional, CUDA 10.0+) ------------------------------
1202 /// `cuGraphCreate(phGraph*, flags) -> CUresult`
1203 ///
1204 /// Creates an empty CUDA graph. When `None`, the graph API is not
1205 /// supported by the loaded driver.
1206 pub cu_graph_create:
1207 Option<unsafe extern "C" fn(ph_graph: *mut CUgraph, flags: u32) -> CUresult>,
1208
1209 /// `cuGraphDestroy(hGraph) -> CUresult`
1210 ///
1211 /// Destroys a CUDA graph. When `None`, the graph API is not supported.
1212 pub cu_graph_destroy: Option<unsafe extern "C" fn(h_graph: CUgraph) -> CUresult>,
1213
1214 /// `cuGraphAddKernelNode(phGraphNode*, hGraph, dependencies*, numDependencies, nodeParams*) -> CUresult`
1215 ///
1216 /// Adds a kernel-launch node to a graph. When `None`, the graph API is
1217 /// not supported.
1218 pub cu_graph_add_kernel_node: Option<
1219 unsafe extern "C" fn(
1220 ph_graph_node: *mut CUgraphNode,
1221 h_graph: CUgraph,
1222 dependencies: *const CUgraphNode,
1223 num_dependencies: usize,
1224 node_params: *const CUDA_KERNEL_NODE_PARAMS,
1225 ) -> CUresult,
1226 >,
1227
1228 /// `cuGraphAddMemcpyNode(phGraphNode*, hGraph, dependencies*, numDependencies, copyParams*, ctx) -> CUresult`
1229 ///
1230 /// Adds a memory-copy node to a graph. When `None`, the graph API is
1231 /// not supported.
1232 pub cu_graph_add_memcpy_node: Option<
1233 unsafe extern "C" fn(
1234 ph_graph_node: *mut CUgraphNode,
1235 h_graph: CUgraph,
1236 dependencies: *const CUgraphNode,
1237 num_dependencies: usize,
1238 copy_params: *const CUDA_MEMCPY3D,
1239 ctx: CUcontext,
1240 ) -> CUresult,
1241 >,
1242
1243 /// `cuGraphAddMemsetNode(phGraphNode*, hGraph, dependencies*, numDependencies, memsetParams*, ctx) -> CUresult`
1244 ///
1245 /// Adds a memset node to a graph. When `None`, the graph API is not
1246 /// supported.
1247 pub cu_graph_add_memset_node: Option<
1248 unsafe extern "C" fn(
1249 ph_graph_node: *mut CUgraphNode,
1250 h_graph: CUgraph,
1251 dependencies: *const CUgraphNode,
1252 num_dependencies: usize,
1253 memset_params: *const CUDA_MEMSET_NODE_PARAMS,
1254 ctx: CUcontext,
1255 ) -> CUresult,
1256 >,
1257
1258 /// `cuGraphAddEmptyNode(phGraphNode*, hGraph, dependencies*, numDependencies) -> CUresult`
1259 ///
1260 /// Adds an empty (no-op) node to a graph. When `None`, the graph API
1261 /// is not supported.
1262 pub cu_graph_add_empty_node: Option<
1263 unsafe extern "C" fn(
1264 ph_graph_node: *mut CUgraphNode,
1265 h_graph: CUgraph,
1266 dependencies: *const CUgraphNode,
1267 num_dependencies: usize,
1268 ) -> CUresult,
1269 >,
1270
1271 /// `cuGraphInstantiateWithFlags(phGraphExec*, hGraph, flags) -> CUresult`
1272 ///
1273 /// Instantiates a graph into an executable form (CUDA 11.4+). When
1274 /// `None`, fall back to [`cu_graph_instantiate`](Self::cu_graph_instantiate).
1275 pub cu_graph_instantiate_with_flags: Option<
1276 unsafe extern "C" fn(
1277 ph_graph_exec: *mut CUgraphExec,
1278 h_graph: CUgraph,
1279 flags: u64,
1280 ) -> CUresult,
1281 >,
1282
1283 /// `cuGraphInstantiate_v2(phGraphExec*, hGraph, phErrorNode*, logBuffer*, bufferSize) -> CUresult`
1284 ///
1285 /// Instantiates a graph into an executable form (legacy signature).
1286 /// When `None`, the graph API is not supported.
1287 pub cu_graph_instantiate: Option<
1288 unsafe extern "C" fn(
1289 ph_graph_exec: *mut CUgraphExec,
1290 h_graph: CUgraph,
1291 ph_error_node: *mut CUgraphNode,
1292 log_buffer: *mut c_char,
1293 buffer_size: usize,
1294 ) -> CUresult,
1295 >,
1296
1297 /// `cuGraphExecDestroy(hGraphExec) -> CUresult`
1298 ///
1299 /// Destroys an executable graph. When `None`, the graph API is not
1300 /// supported.
1301 pub cu_graph_exec_destroy: Option<unsafe extern "C" fn(h_graph_exec: CUgraphExec) -> CUresult>,
1302
1303 /// `cuGraphLaunch(hGraphExec, hStream) -> CUresult`
1304 ///
1305 /// Submits an executable graph to a stream. When `None`, the graph API
1306 /// is not supported.
1307 pub cu_graph_launch:
1308 Option<unsafe extern "C" fn(h_graph_exec: CUgraphExec, h_stream: CUstream) -> CUresult>,
1309}
1310
1311// SAFETY: All fields are plain function pointers (which are Send + Sync) and
1312// the Library handle is kept alive but never mutated.
1313unsafe impl Send for DriverApi {}
1314unsafe impl Sync for DriverApi {}
1315
1316// ---------------------------------------------------------------------------
1317// DriverApi — construction
1318// ---------------------------------------------------------------------------
1319
1320impl DriverApi {
1321 /// Attempt to dynamically load the CUDA driver shared library and resolve
1322 /// every required symbol.
1323 ///
1324 /// # Platform behaviour
1325 ///
1326 /// * **macOS** — immediately returns [`DriverLoadError::UnsupportedPlatform`].
1327 /// * **Linux** — tries `libcuda.so.1` then `libcuda.so`.
1328 /// * **Windows** — tries `nvcuda.dll`.
1329 ///
1330 /// # Errors
1331 ///
1332 /// * [`DriverLoadError::UnsupportedPlatform`] on macOS.
1333 /// * [`DriverLoadError::LibraryNotFound`] if none of the candidate library
1334 /// names could be opened.
1335 /// * [`DriverLoadError::SymbolNotFound`] if a required CUDA entry point is
1336 /// missing from the loaded library.
1337 pub fn load() -> Result<Self, DriverLoadError> {
1338 // macOS: CUDA is not and will not be supported.
1339 #[cfg(target_os = "macos")]
1340 {
1341 Err(DriverLoadError::UnsupportedPlatform)
1342 }
1343
1344 // Linux library search order.
1345 #[cfg(target_os = "linux")]
1346 let lib_names: &[&str] = &["libcuda.so.1", "libcuda.so"];
1347
1348 // Windows library search order.
1349 #[cfg(target_os = "windows")]
1350 let lib_names: &[&str] = &["nvcuda.dll"];
1351
1352 #[cfg(not(target_os = "macos"))]
1353 {
1354 let lib = Self::load_library(lib_names)?;
1355 let api = Self::load_symbols(lib)?;
1356 // `cuInit(0)` must be called before any other CUDA driver API.
1357 // This mirrors what `libcudart` does internally on the first CUDA
1358 // Runtime call. We call it unconditionally here so that all
1359 // `try_driver()` callers get a fully initialised driver without
1360 // each needing to call `cuInit` themselves.
1361 //
1362 // SAFETY: `api.cu_init` was just resolved from the shared library.
1363 // Passing flags=0 is the only documented value.
1364 let rc = unsafe { (api.cu_init)(0) };
1365 if rc != 0 {
1366 // Propagate the error; the OnceLock will store this Err and
1367 // return CudaError::NotInitialized on every subsequent
1368 // try_driver() call — matching behaviour on no-GPU machines.
1369 return Err(DriverLoadError::InitializationFailed { code: rc });
1370 }
1371 Ok(api)
1372 }
1373 }
1374
1375 /// Try each candidate library name in order, returning the first that
1376 /// loads successfully.
1377 ///
1378 /// # Errors
1379 ///
1380 /// Returns [`DriverLoadError::LibraryNotFound`] if **all** candidates
1381 /// fail to load, capturing the last OS-level error message.
1382 #[cfg(not(target_os = "macos"))]
1383 fn load_library(names: &[&str]) -> Result<Library, DriverLoadError> {
1384 let mut last_error = String::new();
1385 for name in names {
1386 // SAFETY: loading a shared library has side-effects (running its
1387 // init routines), but the CUDA driver library is designed for
1388 // this.
1389 match unsafe { Library::new(*name) } {
1390 Ok(lib) => {
1391 tracing::debug!("loaded CUDA driver library: {name}");
1392 return Ok(lib);
1393 }
1394 Err(e) => {
1395 tracing::debug!("failed to load {name}: {e}");
1396 last_error = e.to_string();
1397 }
1398 }
1399 }
1400
1401 Err(DriverLoadError::LibraryNotFound {
1402 candidates: names.iter().map(|s| (*s).to_string()).collect(),
1403 last_error,
1404 })
1405 }
1406
1407 /// Resolve every required CUDA driver symbol from the loaded library and
1408 /// assemble the [`DriverApi`] function table.
1409 ///
1410 /// # Errors
1411 ///
1412 /// Returns [`DriverLoadError::SymbolNotFound`] if any symbol cannot be
1413 /// resolved.
1414 #[cfg(not(target_os = "macos"))]
1415 fn load_symbols(lib: Library) -> Result<Self, DriverLoadError> {
1416 Ok(Self {
1417 // -- Initialisation ------------------------------------------------
1418 cu_init: load_sym!(lib, "cuInit"),
1419
1420 // -- Version query -------------------------------------------------
1421 cu_driver_get_version: load_sym!(lib, "cuDriverGetVersion"),
1422
1423 // -- Device management ---------------------------------------------
1424 cu_device_get: load_sym!(lib, "cuDeviceGet"),
1425 cu_device_get_count: load_sym!(lib, "cuDeviceGetCount"),
1426 cu_device_get_name: load_sym!(lib, "cuDeviceGetName"),
1427 cu_device_get_attribute: load_sym!(lib, "cuDeviceGetAttribute"),
1428 cu_device_total_mem_v2: load_sym!(lib, "cuDeviceTotalMem_v2"),
1429 cu_device_can_access_peer: load_sym!(lib, "cuDeviceCanAccessPeer"),
1430
1431 // -- Primary context management ------------------------------------
1432 cu_device_primary_ctx_retain: load_sym!(lib, "cuDevicePrimaryCtxRetain"),
1433 cu_device_primary_ctx_release_v2: load_sym!(lib, "cuDevicePrimaryCtxRelease_v2"),
1434 cu_device_primary_ctx_set_flags_v2: load_sym!(lib, "cuDevicePrimaryCtxSetFlags_v2"),
1435 cu_device_primary_ctx_get_state: load_sym!(lib, "cuDevicePrimaryCtxGetState"),
1436 cu_device_primary_ctx_reset_v2: load_sym!(lib, "cuDevicePrimaryCtxReset_v2"),
1437
1438 // -- Context management --------------------------------------------
1439 cu_ctx_create_v2: load_sym!(lib, "cuCtxCreate_v2"),
1440 cu_ctx_destroy_v2: load_sym!(lib, "cuCtxDestroy_v2"),
1441 cu_ctx_set_current: load_sym!(lib, "cuCtxSetCurrent"),
1442 cu_ctx_get_current: load_sym!(lib, "cuCtxGetCurrent"),
1443 cu_ctx_synchronize: load_sym!(lib, "cuCtxSynchronize"),
1444
1445 // -- Module management ---------------------------------------------
1446 cu_module_load_data: load_sym!(lib, "cuModuleLoadData"),
1447 cu_module_load_data_ex: load_sym!(lib, "cuModuleLoadDataEx"),
1448 cu_module_get_function: load_sym!(lib, "cuModuleGetFunction"),
1449 cu_module_unload: load_sym!(lib, "cuModuleUnload"),
1450
1451 // -- Memory management ---------------------------------------------
1452 cu_mem_alloc_v2: load_sym!(lib, "cuMemAlloc_v2"),
1453 cu_mem_free_v2: load_sym!(lib, "cuMemFree_v2"),
1454 cu_memcpy_htod_v2: load_sym!(lib, "cuMemcpyHtoD_v2"),
1455 cu_memcpy_dtoh_v2: load_sym!(lib, "cuMemcpyDtoH_v2"),
1456 cu_memcpy_dtod_v2: load_sym!(lib, "cuMemcpyDtoD_v2"),
1457 cu_memcpy_htod_async_v2: load_sym!(lib, "cuMemcpyHtoDAsync_v2"),
1458 cu_memcpy_dtoh_async_v2: load_sym!(lib, "cuMemcpyDtoHAsync_v2"),
1459 cu_mem_alloc_host_v2: load_sym!(lib, "cuMemAllocHost_v2"),
1460 cu_mem_free_host: load_sym!(lib, "cuMemFreeHost"),
1461 cu_mem_alloc_managed: load_sym!(lib, "cuMemAllocManaged"),
1462 cu_memset_d8_v2: load_sym!(lib, "cuMemsetD8_v2"),
1463 cu_memset_d32_v2: load_sym!(lib, "cuMemsetD32_v2"),
1464 cu_mem_get_info_v2: load_sym!(lib, "cuMemGetInfo_v2"),
1465 cu_mem_host_register_v2: load_sym!(lib, "cuMemHostRegister_v2"),
1466 cu_mem_host_unregister: load_sym!(lib, "cuMemHostUnregister"),
1467 cu_mem_host_get_device_pointer_v2: load_sym!(lib, "cuMemHostGetDevicePointer_v2"),
1468 cu_pointer_get_attribute: load_sym!(lib, "cuPointerGetAttribute"),
1469 cu_mem_advise: load_sym!(lib, "cuMemAdvise"),
1470 cu_mem_prefetch_async: load_sym!(lib, "cuMemPrefetchAsync"),
1471
1472 // -- Stream management ---------------------------------------------
1473 cu_stream_create: load_sym!(lib, "cuStreamCreate"),
1474 cu_stream_create_with_priority: load_sym!(lib, "cuStreamCreateWithPriority"),
1475 cu_stream_destroy_v2: load_sym!(lib, "cuStreamDestroy_v2"),
1476 cu_stream_synchronize: load_sym!(lib, "cuStreamSynchronize"),
1477 cu_stream_wait_event: load_sym!(lib, "cuStreamWaitEvent"),
1478 cu_stream_query: load_sym!(lib, "cuStreamQuery"),
1479 cu_stream_get_priority: load_sym!(lib, "cuStreamGetPriority"),
1480 cu_stream_get_flags: load_sym!(lib, "cuStreamGetFlags"),
1481
1482 // -- Event management ----------------------------------------------
1483 cu_event_create: load_sym!(lib, "cuEventCreate"),
1484 cu_event_destroy_v2: load_sym!(lib, "cuEventDestroy_v2"),
1485 cu_event_record: load_sym!(lib, "cuEventRecord"),
1486 cu_event_query: load_sym!(lib, "cuEventQuery"),
1487 cu_event_synchronize: load_sym!(lib, "cuEventSynchronize"),
1488 cu_event_elapsed_time: load_sym!(lib, "cuEventElapsedTime"),
1489 cu_event_record_with_flags: load_sym_optional!(lib, "cuEventRecordWithFlags"),
1490
1491 // -- Peer memory access -------------------------------------------
1492 cu_memcpy_peer: load_sym!(lib, "cuMemcpyPeer"),
1493 cu_memcpy_peer_async: load_sym!(lib, "cuMemcpyPeerAsync"),
1494 cu_ctx_enable_peer_access: load_sym!(lib, "cuCtxEnablePeerAccess"),
1495 cu_ctx_disable_peer_access: load_sym!(lib, "cuCtxDisablePeerAccess"),
1496
1497 // -- Kernel launch -------------------------------------------------
1498 cu_launch_kernel: load_sym!(lib, "cuLaunchKernel"),
1499 cu_launch_cooperative_kernel: load_sym!(lib, "cuLaunchCooperativeKernel"),
1500 cu_launch_cooperative_kernel_multi_device: load_sym!(
1501 lib,
1502 "cuLaunchCooperativeKernelMultiDevice"
1503 ),
1504
1505 // -- Occupancy -----------------------------------------------------
1506 cu_occupancy_max_active_blocks_per_multiprocessor: load_sym!(
1507 lib,
1508 "cuOccupancyMaxActiveBlocksPerMultiprocessor"
1509 ),
1510 cu_occupancy_max_potential_block_size: load_sym!(
1511 lib,
1512 "cuOccupancyMaxPotentialBlockSize"
1513 ),
1514 cu_occupancy_max_active_blocks_per_multiprocessor_with_flags: load_sym!(
1515 lib,
1516 "cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags"
1517 ),
1518
1519 // -- Memory management (optional) ---------------------------------
1520 cu_memcpy_dtod_async_v2: load_sym_optional!(lib, "cuMemcpyDtoDAsync_v2"),
1521 cu_memset_d16_v2: load_sym_optional!(lib, "cuMemsetD16_v2"),
1522 cu_memset_d32_async: load_sym_optional!(lib, "cuMemsetD32Async"),
1523
1524 // -- Context management (optional) --------------------------------
1525 cu_ctx_get_limit: load_sym_optional!(lib, "cuCtxGetLimit"),
1526 cu_ctx_set_limit: load_sym_optional!(lib, "cuCtxSetLimit"),
1527 cu_ctx_get_cache_config: load_sym_optional!(lib, "cuCtxGetCacheConfig"),
1528 cu_ctx_set_cache_config: load_sym_optional!(lib, "cuCtxSetCacheConfig"),
1529 cu_ctx_get_shared_mem_config: load_sym_optional!(lib, "cuCtxGetSharedMemConfig"),
1530 cu_ctx_set_shared_mem_config: load_sym_optional!(lib, "cuCtxSetSharedMemConfig"),
1531
1532 // -- Function attributes (optional) -------------------------------
1533 cu_func_get_attribute: load_sym_optional!(lib, "cuFuncGetAttribute"),
1534 cu_func_set_cache_config: load_sym_optional!(lib, "cuFuncSetCacheConfig"),
1535 cu_func_set_shared_mem_config: load_sym_optional!(lib, "cuFuncSetSharedMemConfig"),
1536 cu_func_set_attribute: load_sym_optional!(lib, "cuFuncSetAttribute"),
1537
1538 // -- Profiler (optional) ------------------------------------------
1539 cu_profiler_start: load_sym_optional!(lib, "cuProfilerStart"),
1540 cu_profiler_stop: load_sym_optional!(lib, "cuProfilerStop"),
1541
1542 // -- CUDA 12.x extended launch (optional) -------------------------
1543 cu_launch_kernel_ex: load_sym_optional!(lib, "cuLaunchKernelEx"),
1544 cu_tensor_map_encode_tiled: load_sym_optional!(lib, "cuTensorMapEncodeTiled"),
1545
1546 // -- CUDA 12.8+ extended API (optional) ---------------------------
1547 cu_tensor_map_encode_tiled_memref: load_sym_optional!(
1548 lib,
1549 "cuTensorMapEncodeTiledMemref"
1550 ),
1551 cu_kernel_get_library: load_sym_optional!(lib, "cuKernelGetLibrary"),
1552 cu_multicast_get_granularity: load_sym_optional!(lib, "cuMulticastGetGranularity"),
1553 cu_multicast_create: load_sym_optional!(lib, "cuMulticastCreate"),
1554 cu_multicast_add_device: load_sym_optional!(lib, "cuMulticastAddDevice"),
1555 cu_memcpy_batch_async: load_sym_optional!(lib, "cuMemcpyBatchAsync"),
1556
1557 // -- Texture / Surface memory (optional) ---------------------------
1558 cu_array_create_v2: load_sym_optional!(lib, "cuArrayCreate_v2"),
1559 cu_array_destroy: load_sym_optional!(lib, "cuArrayDestroy"),
1560 cu_array_get_descriptor_v2: load_sym_optional!(lib, "cuArrayGetDescriptor_v2"),
1561 cu_array3d_create_v2: load_sym_optional!(lib, "cuArray3DCreate_v2"),
1562 cu_array3d_get_descriptor_v2: load_sym_optional!(lib, "cuArray3DGetDescriptor_v2"),
1563 cu_memcpy_htoa_v2: load_sym_optional!(lib, "cuMemcpyHtoA_v2"),
1564 cu_memcpy_atoh_v2: load_sym_optional!(lib, "cuMemcpyAtoH_v2"),
1565 cu_memcpy_htoa_async_v2: load_sym_optional!(lib, "cuMemcpyHtoAAsync_v2"),
1566 cu_memcpy_atoh_async_v2: load_sym_optional!(lib, "cuMemcpyAtoHAsync_v2"),
1567 cu_tex_object_create: load_sym_optional!(lib, "cuTexObjectCreate"),
1568 cu_tex_object_destroy: load_sym_optional!(lib, "cuTexObjectDestroy"),
1569 cu_tex_object_get_resource_desc: load_sym_optional!(lib, "cuTexObjectGetResourceDesc"),
1570 cu_surf_object_create: load_sym_optional!(lib, "cuSurfObjectCreate"),
1571 cu_surf_object_destroy: load_sym_optional!(lib, "cuSurfObjectDestroy"),
1572
1573 // -- JIT linker (optional) ----------------------------------------
1574 cu_link_create: load_sym_optional!(lib, "cuLinkCreate_v2"),
1575 cu_link_add_data: load_sym_optional!(lib, "cuLinkAddData_v2"),
1576 cu_link_complete: load_sym_optional!(lib, "cuLinkComplete"),
1577 cu_link_destroy: load_sym_optional!(lib, "cuLinkDestroy"),
1578
1579 // -- 2-D memory copy (optional) -----------------------------------
1580 cu_memcpy_2d: load_sym_optional!(lib, "cuMemcpy2D_v2"),
1581
1582 // -- VMM (optional, CUDA 11.2+) -----------------------------------
1583 cu_mem_address_reserve: load_sym_optional!(lib, "cuMemAddressReserve"),
1584 cu_mem_address_free: load_sym_optional!(lib, "cuMemAddressFree"),
1585 cu_mem_create: load_sym_optional!(lib, "cuMemCreate"),
1586 cu_mem_release: load_sym_optional!(lib, "cuMemRelease"),
1587 cu_mem_map: load_sym_optional!(lib, "cuMemMap"),
1588 cu_mem_unmap: load_sym_optional!(lib, "cuMemUnmap"),
1589 cu_mem_set_access: load_sym_optional!(lib, "cuMemSetAccess"),
1590
1591 // -- Stream-ordered memory pools (optional, CUDA 11.2+) -----------
1592 cu_mem_pool_create: load_sym_optional!(lib, "cuMemPoolCreate"),
1593 cu_mem_pool_destroy: load_sym_optional!(lib, "cuMemPoolDestroy"),
1594 cu_mem_alloc_from_pool_async: load_sym_optional!(lib, "cuMemAllocFromPoolAsync"),
1595 cu_mem_free_async: load_sym_optional!(lib, "cuMemFreeAsync"),
1596 cu_mem_alloc_async: load_sym_optional!(lib, "cuMemAllocAsync"),
1597 cu_mem_pool_trim_to: load_sym_optional!(lib, "cuMemPoolTrimTo"),
1598 cu_mem_pool_set_attribute: load_sym_optional!(lib, "cuMemPoolSetAttribute"),
1599 cu_mem_pool_get_attribute: load_sym_optional!(lib, "cuMemPoolGetAttribute"),
1600 cu_mem_pool_set_access: load_sym_optional!(lib, "cuMemPoolSetAccess"),
1601 cu_device_get_default_mem_pool: load_sym_optional!(lib, "cuDeviceGetDefaultMemPool"),
1602
1603 // -- CUDA Graph API (optional, CUDA 10.0+) ------------------------
1604 cu_graph_create: load_sym_optional!(lib, "cuGraphCreate"),
1605 cu_graph_destroy: load_sym_optional!(lib, "cuGraphDestroy"),
1606 cu_graph_add_kernel_node: load_sym_optional!(lib, "cuGraphAddKernelNode"),
1607 cu_graph_add_memcpy_node: load_sym_optional!(lib, "cuGraphAddMemcpyNode"),
1608 cu_graph_add_memset_node: load_sym_optional!(lib, "cuGraphAddMemsetNode"),
1609 cu_graph_add_empty_node: load_sym_optional!(lib, "cuGraphAddEmptyNode"),
1610 cu_graph_instantiate_with_flags: load_sym_optional!(lib, "cuGraphInstantiateWithFlags"),
1611 cu_graph_instantiate: load_sym_optional!(lib, "cuGraphInstantiate_v2"),
1612 cu_graph_exec_destroy: load_sym_optional!(lib, "cuGraphExecDestroy"),
1613 cu_graph_launch: load_sym_optional!(lib, "cuGraphLaunch"),
1614
1615 // Keep the library handle alive.
1616 _lib: lib,
1617 })
1618 }
1619}
1620
1621// ---------------------------------------------------------------------------
1622// Global accessor
1623// ---------------------------------------------------------------------------
1624
1625/// Get a reference to the lazily-loaded CUDA driver API function table.
1626///
1627/// On the first call, this function dynamically loads the CUDA shared library
1628/// and resolves all required symbols. Subsequent calls return the cached
1629/// result with only an atomic load.
1630///
1631/// # Errors
1632///
1633/// Returns [`CudaError::NotInitialized`] if the driver could not be loaded —
1634/// for instance, on macOS, or on a system without an NVIDIA GPU driver
1635/// installed.
1636///
1637/// # Examples
1638///
1639/// ```rust,no_run
1640/// # use oxicuda_driver::loader::try_driver;
1641/// let api = try_driver()?;
1642/// let result = unsafe { (api.cu_init)(0) };
1643/// # Ok::<(), oxicuda_driver::error::CudaError>(())
1644/// ```
1645pub fn try_driver() -> CudaResult<&'static DriverApi> {
1646 let result = DRIVER.get_or_init(DriverApi::load);
1647 match result {
1648 Ok(api) => Ok(api),
1649 Err(_) => Err(CudaError::NotInitialized),
1650 }
1651}
1652
1653// ---------------------------------------------------------------------------
1654// Tests
1655// ---------------------------------------------------------------------------
1656
#[cfg(test)]
mod tests {
    use super::*;

    /// On macOS, loading should always fail with `UnsupportedPlatform`.
    #[cfg(target_os = "macos")]
    #[test]
    fn load_returns_unsupported_on_macos() {
        let Err(err) = DriverApi::load() else {
            panic!("expected Err on macOS");
        };
        assert!(
            matches!(err, DriverLoadError::UnsupportedPlatform),
            "expected UnsupportedPlatform, got {err:?}"
        );
    }

    /// On macOS, where no CUDA driver exists, `try_driver` should return
    /// `Err(NotInitialized)`.
    #[cfg(target_os = "macos")]
    #[test]
    fn try_driver_returns_not_initialized_on_macos() {
        let Err(err) = try_driver() else {
            panic!("expected Err on macOS");
        };
        assert!(
            matches!(err, CudaError::NotInitialized),
            "expected NotInitialized, got {err:?}"
        );
    }

    // -----------------------------------------------------------------------
    // CUDA 12.8+ DriverApi field-presence tests
    //
    // These are compile-time checks. Each closure below type-checks only if
    // the named field exists on `DriverApi` and supports `.is_none()`, as an
    // `Option` does. The closures are never called, so the tests need no
    // driver or GPU and never invoke a function pointer. They do NOT verify
    // the exact function-pointer signature stored inside the `Option`.
    // -----------------------------------------------------------------------

    /// Verify that the `cu_tensor_map_encode_tiled_memref` field exists and
    /// is an `Option`. Optional entry points are presumably left `None` when
    /// the installed driver does not export the symbol.
    #[test]
    fn driver_v12_8_api_fields_present() {
        // Expected C signature of the entry point, spelled out for reference.
        // NOTE: this alias is not compared against the field's declared type;
        // it only documents the intended signature.
        type TensorMapEncodeTiledFn = unsafe extern "C" fn(
            tensor_map: *mut std::ffi::c_void,
            tensor_data_type: u32,
            tensor_rank: u32,
            global_address: *mut std::ffi::c_void,
            global_dim: *const u64,
            global_strides: *const u64,
            box_dim: *const u32,
            element_strides: *const u32,
            interleave: u32,
            swizzle: u32,
            l2_promotion: u32,
            oob_fill: u32,
            flags: u64,
        ) -> CUresult;
        let _signature: Option<TensorMapEncodeTiledFn> = None;

        // Compiles only if the field exists and is an `Option`.
        let _field_exists = |api: &DriverApi| api.cu_tensor_map_encode_tiled_memref.is_none();
    }

    /// Verify that the CUDA 12.8+ multicast fields (`cu_multicast_create`,
    /// `cu_multicast_add_device`, `cu_multicast_get_granularity`) exist and
    /// are `Option`s.
    #[test]
    fn driver_v12_8_multicast_fields_present() {
        let _probe_create = |api: &DriverApi| api.cu_multicast_create.is_none();
        let _probe_add = |api: &DriverApi| api.cu_multicast_add_device.is_none();
        let _probe_gran = |api: &DriverApi| api.cu_multicast_get_granularity.is_none();
    }

    /// Verify that the CUDA 12.8+ `cu_memcpy_batch_async` field exists and is
    /// an `Option`.
    #[test]
    fn driver_v12_8_batch_memcpy_field_present() {
        let _probe = |api: &DriverApi| api.cu_memcpy_batch_async.is_none();
    }

    /// Verify that the CUDA 12.8+ `cu_kernel_get_library` field exists and is
    /// an `Option`.
    #[test]
    fn driver_v12_8_kernel_get_library_field_present() {
        let _probe = |api: &DriverApi| api.cu_kernel_get_library.is_none();
    }

    // -----------------------------------------------------------------------
    // Extended Driver API field-presence tests
    // (JIT linker, 2-D copy, VMM, stream-ordered pools, CUDA Graphs)
    //
    // Same compile-time technique as above: never-called closures that only
    // type-check if each field exists and is an `Option`.
    // -----------------------------------------------------------------------

    /// All four `cuLink*` JIT linker fields are present and `Option`.
    #[test]
    fn driver_link_api_fields_present() {
        let _probe_create = |api: &DriverApi| api.cu_link_create.is_none();
        let _probe_add = |api: &DriverApi| api.cu_link_add_data.is_none();
        let _probe_complete = |api: &DriverApi| api.cu_link_complete.is_none();
        let _probe_destroy = |api: &DriverApi| api.cu_link_destroy.is_none();
    }

    /// The `cu_memcpy_2d` field (resolved from `cuMemcpy2D_v2`) is present
    /// and `Option`.
    #[test]
    fn driver_memcpy_2d_field_present() {
        let _probe = |api: &DriverApi| api.cu_memcpy_2d.is_none();
    }

    /// All seven VMM (CUDA 11.2+) fields are present and `Option`.
    #[test]
    fn driver_vmm_api_fields_present() {
        let _probe_reserve = |api: &DriverApi| api.cu_mem_address_reserve.is_none();
        let _probe_free = |api: &DriverApi| api.cu_mem_address_free.is_none();
        let _probe_create = |api: &DriverApi| api.cu_mem_create.is_none();
        let _probe_release = |api: &DriverApi| api.cu_mem_release.is_none();
        let _probe_map = |api: &DriverApi| api.cu_mem_map.is_none();
        let _probe_unmap = |api: &DriverApi| api.cu_mem_unmap.is_none();
        let _probe_set_access = |api: &DriverApi| api.cu_mem_set_access.is_none();
    }

    /// The four core memory-pool (CUDA 11.2+) fields are present and `Option`.
    #[test]
    fn driver_mem_pool_api_fields_present() {
        let _probe_create = |api: &DriverApi| api.cu_mem_pool_create.is_none();
        let _probe_destroy = |api: &DriverApi| api.cu_mem_pool_destroy.is_none();
        let _probe_alloc = |api: &DriverApi| api.cu_mem_alloc_from_pool_async.is_none();
        let _probe_free = |api: &DriverApi| api.cu_mem_free_async.is_none();
    }

    /// The extended stream-ordered memory-pool fields are present and `Option`.
    #[test]
    fn driver_stream_ordered_pool_extended_fields_present() {
        let _probe_alloc_async = |api: &DriverApi| api.cu_mem_alloc_async.is_none();
        let _probe_trim = |api: &DriverApi| api.cu_mem_pool_trim_to.is_none();
        let _probe_set_attr = |api: &DriverApi| api.cu_mem_pool_set_attribute.is_none();
        let _probe_get_attr = |api: &DriverApi| api.cu_mem_pool_get_attribute.is_none();
        let _probe_set_access = |api: &DriverApi| api.cu_mem_pool_set_access.is_none();
        let _probe_default = |api: &DriverApi| api.cu_device_get_default_mem_pool.is_none();
    }

    /// All ten CUDA Graph API fields are present and `Option`.
    #[test]
    fn driver_graph_api_fields_present() {
        let _probe_create = |api: &DriverApi| api.cu_graph_create.is_none();
        let _probe_destroy = |api: &DriverApi| api.cu_graph_destroy.is_none();
        let _probe_kernel = |api: &DriverApi| api.cu_graph_add_kernel_node.is_none();
        let _probe_memcpy = |api: &DriverApi| api.cu_graph_add_memcpy_node.is_none();
        let _probe_memset = |api: &DriverApi| api.cu_graph_add_memset_node.is_none();
        let _probe_empty = |api: &DriverApi| api.cu_graph_add_empty_node.is_none();
        let _probe_inst_flags = |api: &DriverApi| api.cu_graph_instantiate_with_flags.is_none();
        let _probe_inst = |api: &DriverApi| api.cu_graph_instantiate.is_none();
        let _probe_exec_destroy = |api: &DriverApi| api.cu_graph_exec_destroy.is_none();
        let _probe_launch = |api: &DriverApi| api.cu_graph_launch.is_none();
    }
}