Skip to main content

oxicuda_driver/
loader.rs

1//! Dynamic CUDA driver library loader.
2//!
3//! This module is the architectural foundation of `oxicuda-driver`. It locates
4//! and loads the CUDA driver shared library (`libcuda.so` on Linux,
5//! `nvcuda.dll` on Windows) **at runtime** via [`libloading`], so that no CUDA
6//! SDK is required at build time.
7//!
8//! # Platform support
9//!
10//! | Platform | Library names tried              | Notes                            |
11//! |----------|----------------------------------|----------------------------------|
12//! | Linux    | `libcuda.so.1`, `libcuda.so`     | Installed by NVIDIA driver       |
13//! | Windows  | `nvcuda.dll`                     | Ships with the display driver    |
14//! | macOS    | —                                | Returns `UnsupportedPlatform`    |
15//!
16//! # Usage
17//!
18//! Application code should **not** interact with [`DriverApi`] directly.
//! Instead, call [`try_driver`] to obtain a reference to the
//! lazily-initialised global singleton:
21//!
22//! ```rust,no_run
23//! # use oxicuda_driver::loader::try_driver;
24//! let api = try_driver()?;
25//! // api.cu_init, api.cu_device_get, …
26//! # Ok::<(), oxicuda_driver::error::CudaError>(())
27//! ```
28//!
29//! The singleton is stored in a [`OnceLock`] so that the (relatively
30//! expensive) `dlopen` + symbol resolution only happens once, and all
31//! subsequent accesses are a single atomic load.
32
33use std::ffi::{c_char, c_int, c_void};
34use std::sync::OnceLock;
35
36use libloading::Library;
37
38use crate::error::{CudaError, CudaResult, DriverLoadError};
39use crate::ffi::*;
40
41// ---------------------------------------------------------------------------
42// Global singleton
43// ---------------------------------------------------------------------------
44
45/// Global singleton for the driver API function table.
46///
47/// Initialised lazily on the first call to [`try_driver`].
48static DRIVER: OnceLock<Result<DriverApi, DriverLoadError>> = OnceLock::new();
49
50// ---------------------------------------------------------------------------
51// load_sym! helper macro
52// ---------------------------------------------------------------------------
53
/// Resolve a mandatory driver entry point and cast it to the function-pointer
/// type expected by the surrounding expression.
///
/// The symbol is looked up by its literal name.  When the lookup fails, the
/// expansion leaves the enclosing function through `?` with
/// [`DriverLoadError::SymbolNotFound`], carrying the symbol name and the
/// loader's error text.
///
/// # Safety
///
/// Nothing verifies that the exported symbol really has the signature of the
/// destination type: the call site must pair every symbol name with a
/// function-pointer type of matching ABI.
#[cfg(not(target_os = "macos"))]
macro_rules! load_sym {
    ($lib:expr, $name:literal) => {{
        // `Library::get` takes the symbol name as a byte slice.  The symbol is
        // first requested as the most general function-pointer type; the
        // concrete signature is only applied by the transmute further down.
        let lookup = unsafe { $lib.get::<unsafe extern "C" fn()>($name.as_bytes()) };
        let symbol = lookup.map_err(|err| DriverLoadError::SymbolNotFound {
            symbol: $name,
            reason: err.to_string(),
        })?;
        // SAFETY: the CUDA driver is trusted to export this symbol with the
        // ABI of the destination type.  That type is inferred from wherever
        // the macro's value is assigned, so annotating the transmute would
        // mean repeating the function-pointer type at every call site — hence
        // the lint is silenced instead.
        #[allow(clippy::missing_transmute_annotations)]
        let typed = unsafe { std::mem::transmute(*symbol) };
        typed
    }};
}
83
/// Resolve an optional driver entry point.
///
/// Evaluates to `Some(fn_ptr)` when the shared library exports the symbol and
/// to `None` when the lookup fails; a failed lookup is only reported through a
/// `tracing` debug event.  Intended for API entry points that older driver
/// versions may not provide.
///
/// # Safety
///
/// Same safety requirements as [`load_sym!`].
#[cfg(not(target_os = "macos"))]
macro_rules! load_sym_optional {
    ($lib:expr, $name:literal) => {{
        let lookup = unsafe { $lib.get::<unsafe extern "C" fn()>($name.as_bytes()) };
        if let Ok(symbol) = lookup {
            // SAFETY: the destination type is inferred from wherever the
            // macro's value is assigned.  Silencing the lint avoids spelling
            // out the function-pointer type at every call site.
            #[allow(clippy::missing_transmute_annotations)]
            let typed = unsafe { std::mem::transmute(*symbol) };
            Some(typed)
        } else {
            // A missing optional symbol is not an error; just leave a trace.
            tracing::debug!(concat!("optional symbol not found: ", $name));
            None
        }
    }};
}
110
111// ---------------------------------------------------------------------------
112// DriverApi
113// ---------------------------------------------------------------------------
114
115/// Complete function-pointer table for the CUDA Driver API.
116///
117/// An instance of this struct is produced by [`DriverApi::load`] and kept
118/// alive for the lifetime of the process inside the `DRIVER` singleton.
119/// The embedded [`Library`] handle ensures the shared object is not unloaded.
120///
121/// # Function pointer groups
122///
123/// The fields are organised into logical groups mirroring the CUDA Driver API
124/// documentation:
125///
126/// * **Initialisation** — [`cu_init`](Self::cu_init)
127/// * **Device management** — `cu_device_*`
128/// * **Context management** — `cu_ctx_*`
129/// * **Module management** — `cu_module_*`
130/// * **Memory management** — `cu_mem_*`, `cu_memcpy_*`, `cu_memset_*`
131/// * **Stream management** — `cu_stream_*`
132/// * **Event management** — `cu_event_*`
133/// * **Kernel launch** — [`cu_launch_kernel`](Self::cu_launch_kernel)
134/// * **Occupancy queries** — `cu_occupancy_*`
135pub struct DriverApi {
136    // Keep the shared library handle alive.
137    _lib: Library,
138
139    // -- Initialisation ----------------------------------------------------
140    /// `cuInit(flags) -> CUresult`
141    ///
142    /// Initialises the CUDA driver API.  Must be called before any other
143    /// driver function.  Passing `0` for *flags* is the only documented
144    /// value.
145    pub cu_init: unsafe extern "C" fn(flags: u32) -> CUresult,
146
147    // -- Version query -------------------------------------------------------
148    /// `cuDriverGetVersion(driverVersion*) -> CUresult`
149    ///
150    /// Returns the CUDA driver version as `major*1000 + minor*10`.
151    pub cu_driver_get_version: unsafe extern "C" fn(version: *mut c_int) -> CUresult,
152
153    // -- Device management -------------------------------------------------
154    /// `cuDeviceGet(device*, ordinal) -> CUresult`
155    ///
156    /// Returns a handle to a compute device.
157    pub cu_device_get: unsafe extern "C" fn(device: *mut CUdevice, ordinal: c_int) -> CUresult,
158
159    /// `cuDeviceGetCount(count*) -> CUresult`
160    ///
161    /// Returns the number of compute-capable devices.
162    pub cu_device_get_count: unsafe extern "C" fn(count: *mut c_int) -> CUresult,
163
164    /// `cuDeviceGetName(name*, len, dev) -> CUresult`
165    ///
166    /// Returns an ASCII string identifying the device.
167    pub cu_device_get_name:
168        unsafe extern "C" fn(name: *mut c_char, len: c_int, dev: CUdevice) -> CUresult,
169
170    /// `cuDeviceGetAttribute(pi*, attrib, dev) -> CUresult`
171    ///
172    /// Returns information about the device.
173    pub cu_device_get_attribute:
174        unsafe extern "C" fn(pi: *mut c_int, attrib: CUdevice_attribute, dev: CUdevice) -> CUresult,
175
176    /// `cuDeviceTotalMem_v2(bytes*, dev) -> CUresult`
177    ///
178    /// Returns the total amount of memory on the device.
179    pub cu_device_total_mem_v2: unsafe extern "C" fn(bytes: *mut usize, dev: CUdevice) -> CUresult,
180
181    /// `cuDeviceCanAccessPeer(canAccessPeer*, dev, peerDev) -> CUresult`
182    ///
183    /// Queries if a device may directly access a peer device's memory.
184    pub cu_device_can_access_peer:
185        unsafe extern "C" fn(can_access: *mut c_int, dev: CUdevice, peer_dev: CUdevice) -> CUresult,
186
187    // -- Primary context management ----------------------------------------
188    /// `cuDevicePrimaryCtxRetain(pctx*, dev) -> CUresult`
189    ///
190    /// Retains the primary context on the device, creating it if necessary.
191    pub cu_device_primary_ctx_retain:
192        unsafe extern "C" fn(pctx: *mut CUcontext, dev: CUdevice) -> CUresult,
193
194    /// `cuDevicePrimaryCtxRelease_v2(dev) -> CUresult`
195    ///
196    /// Releases the primary context on the device.
197    pub cu_device_primary_ctx_release_v2: unsafe extern "C" fn(dev: CUdevice) -> CUresult,
198
199    /// `cuDevicePrimaryCtxSetFlags_v2(dev, flags) -> CUresult`
200    ///
201    /// Sets flags for the primary context.
202    pub cu_device_primary_ctx_set_flags_v2:
203        unsafe extern "C" fn(dev: CUdevice, flags: u32) -> CUresult,
204
205    /// `cuDevicePrimaryCtxGetState(dev, flags*, active*) -> CUresult`
206    ///
207    /// Returns the state (flags and active status) of the primary context.
208    pub cu_device_primary_ctx_get_state:
209        unsafe extern "C" fn(dev: CUdevice, flags: *mut u32, active: *mut c_int) -> CUresult,
210
211    /// `cuDevicePrimaryCtxReset_v2(dev) -> CUresult`
212    ///
213    /// Resets the primary context on the device.
214    pub cu_device_primary_ctx_reset_v2: unsafe extern "C" fn(dev: CUdevice) -> CUresult,
215
216    // -- Context management ------------------------------------------------
217    /// `cuCtxCreate_v2(pctx*, flags, dev) -> CUresult`
218    ///
219    /// Creates a new CUDA context and associates it with the calling thread.
220    pub cu_ctx_create_v2:
221        unsafe extern "C" fn(pctx: *mut CUcontext, flags: u32, dev: CUdevice) -> CUresult,
222
223    /// `cuCtxDestroy_v2(ctx) -> CUresult`
224    ///
225    /// Destroys a CUDA context.
226    pub cu_ctx_destroy_v2: unsafe extern "C" fn(ctx: CUcontext) -> CUresult,
227
228    /// `cuCtxSetCurrent(ctx) -> CUresult`
229    ///
230    /// Binds the specified CUDA context to the calling CPU thread.
231    pub cu_ctx_set_current: unsafe extern "C" fn(ctx: CUcontext) -> CUresult,
232
233    /// `cuCtxGetCurrent(pctx*) -> CUresult`
234    ///
235    /// Returns the CUDA context bound to the calling CPU thread.
236    pub cu_ctx_get_current: unsafe extern "C" fn(pctx: *mut CUcontext) -> CUresult,
237
238    /// `cuCtxSynchronize() -> CUresult`
239    ///
240    /// Blocks until the device has completed all preceding requested tasks.
241    pub cu_ctx_synchronize: unsafe extern "C" fn() -> CUresult,
242
243    // -- Module management -------------------------------------------------
244    /// `cuModuleLoadData(module*, image*) -> CUresult`
245    ///
246    /// Loads a module from a PTX or cubin image in host memory.
247    pub cu_module_load_data:
248        unsafe extern "C" fn(module: *mut CUmodule, image: *const c_void) -> CUresult,
249
250    /// `cuModuleLoadDataEx(module*, image*, numOptions, options*, optionValues*) -> CUresult`
251    ///
252    /// Loads a module with JIT compiler options.
253    pub cu_module_load_data_ex: unsafe extern "C" fn(
254        module: *mut CUmodule,
255        image: *const c_void,
256        num_options: u32,
257        options: *mut CUjit_option,
258        option_values: *mut *mut c_void,
259    ) -> CUresult,
260
261    /// `cuModuleGetFunction(hfunc*, hmod, name*) -> CUresult`
262    ///
263    /// Returns a handle to a function within a module.
264    pub cu_module_get_function: unsafe extern "C" fn(
265        hfunc: *mut CUfunction,
266        hmod: CUmodule,
267        name: *const c_char,
268    ) -> CUresult,
269
270    /// `cuModuleUnload(hmod) -> CUresult`
271    ///
272    /// Unloads a module from the current context.
273    pub cu_module_unload: unsafe extern "C" fn(hmod: CUmodule) -> CUresult,
274
275    // -- Memory management -------------------------------------------------
276    /// `cuMemAlloc_v2(dptr*, bytesize) -> CUresult`
277    ///
278    /// Allocates device memory.
279    pub cu_mem_alloc_v2: unsafe extern "C" fn(dptr: *mut CUdeviceptr, bytesize: usize) -> CUresult,
280
281    /// `cuMemFree_v2(dptr) -> CUresult`
282    ///
283    /// Frees device memory.
284    pub cu_mem_free_v2: unsafe extern "C" fn(dptr: CUdeviceptr) -> CUresult,
285
286    /// `cuMemcpyHtoD_v2(dst, src*, bytesize) -> CUresult`
287    ///
288    /// Copies data from host memory to device memory.
289    pub cu_memcpy_htod_v2:
290        unsafe extern "C" fn(dst: CUdeviceptr, src: *const c_void, bytesize: usize) -> CUresult,
291
292    /// `cuMemcpyDtoH_v2(dst*, src, bytesize) -> CUresult`
293    ///
294    /// Copies data from device memory to host memory.
295    pub cu_memcpy_dtoh_v2:
296        unsafe extern "C" fn(dst: *mut c_void, src: CUdeviceptr, bytesize: usize) -> CUresult,
297
298    /// `cuMemcpyDtoD_v2(dst, src, bytesize) -> CUresult`
299    ///
300    /// Copies data from device memory to device memory.
301    pub cu_memcpy_dtod_v2:
302        unsafe extern "C" fn(dst: CUdeviceptr, src: CUdeviceptr, bytesize: usize) -> CUresult,
303
304    /// `cuMemcpyHtoDAsync_v2(dst, src*, bytesize, stream) -> CUresult`
305    ///
306    /// Asynchronously copies data from host to device memory.
307    pub cu_memcpy_htod_async_v2: unsafe extern "C" fn(
308        dst: CUdeviceptr,
309        src: *const c_void,
310        bytesize: usize,
311        stream: CUstream,
312    ) -> CUresult,
313
314    /// `cuMemcpyDtoHAsync_v2(dst*, src, bytesize, stream) -> CUresult`
315    ///
316    /// Asynchronously copies data from device to host memory.
317    pub cu_memcpy_dtoh_async_v2: unsafe extern "C" fn(
318        dst: *mut c_void,
319        src: CUdeviceptr,
320        bytesize: usize,
321        stream: CUstream,
322    ) -> CUresult,
323
324    /// `cuMemAllocHost_v2(pp*, bytesize) -> CUresult`
325    ///
326    /// Allocates page-locked (pinned) host memory.
327    pub cu_mem_alloc_host_v2:
328        unsafe extern "C" fn(pp: *mut *mut c_void, bytesize: usize) -> CUresult,
329
330    /// `cuMemFreeHost(p*) -> CUresult`
331    ///
332    /// Frees page-locked host memory.
333    pub cu_mem_free_host: unsafe extern "C" fn(p: *mut c_void) -> CUresult,
334
335    /// `cuMemAllocManaged(dptr*, bytesize, flags) -> CUresult`
336    ///
337    /// Allocates unified memory accessible from both host and device.
338    pub cu_mem_alloc_managed:
339        unsafe extern "C" fn(dptr: *mut CUdeviceptr, bytesize: usize, flags: u32) -> CUresult,
340
341    /// `cuMemsetD8_v2(dst, value, count) -> CUresult`
342    ///
343    /// Sets device memory to a value (byte granularity).
344    pub cu_memset_d8_v2:
345        unsafe extern "C" fn(dst: CUdeviceptr, value: u8, count: usize) -> CUresult,
346
347    /// `cuMemsetD32_v2(dst, value, count) -> CUresult`
348    ///
349    /// Sets device memory to a value (32-bit granularity).
350    pub cu_memset_d32_v2:
351        unsafe extern "C" fn(dst: CUdeviceptr, value: u32, count: usize) -> CUresult,
352
353    /// `cuMemGetInfo_v2(free*, total*) -> CUresult`
354    ///
355    /// Returns free and total memory for the current context's device.
356    pub cu_mem_get_info_v2: unsafe extern "C" fn(free: *mut usize, total: *mut usize) -> CUresult,
357
358    /// `cuMemHostRegister_v2(p*, bytesize, flags) -> CUresult`
359    ///
360    /// Registers an existing host memory range for use by CUDA.
361    pub cu_mem_host_register_v2:
362        unsafe extern "C" fn(p: *mut c_void, bytesize: usize, flags: u32) -> CUresult,
363
364    /// `cuMemHostUnregister(p*) -> CUresult`
365    ///
366    /// Unregisters a memory range that was registered with cuMemHostRegister.
367    pub cu_mem_host_unregister: unsafe extern "C" fn(p: *mut c_void) -> CUresult,
368
369    /// `cuMemHostGetDevicePointer_v2(pdptr*, p*, flags) -> CUresult`
370    ///
371    /// Returns the device pointer mapped to a registered host pointer.
372    pub cu_mem_host_get_device_pointer_v2:
373        unsafe extern "C" fn(pdptr: *mut CUdeviceptr, p: *mut c_void, flags: u32) -> CUresult,
374
375    /// `cuPointerGetAttribute(data*, attribute, ptr) -> CUresult`
376    ///
377    /// Returns information about a pointer.
378    pub cu_pointer_get_attribute:
379        unsafe extern "C" fn(data: *mut c_void, attribute: u32, ptr: CUdeviceptr) -> CUresult,
380
381    /// `cuMemAdvise(devPtr, count, advice, device) -> CUresult`
382    ///
383    /// Advises the unified memory subsystem about usage patterns.
384    pub cu_mem_advise: unsafe extern "C" fn(
385        dev_ptr: CUdeviceptr,
386        count: usize,
387        advice: u32,
388        device: CUdevice,
389    ) -> CUresult,
390
391    /// `cuMemPrefetchAsync(devPtr, count, dstDevice, hStream) -> CUresult`
392    ///
393    /// Prefetches unified memory to the specified device.
394    pub cu_mem_prefetch_async: unsafe extern "C" fn(
395        dev_ptr: CUdeviceptr,
396        count: usize,
397        dst_device: CUdevice,
398        hstream: CUstream,
399    ) -> CUresult,
400
401    // -- Stream management -------------------------------------------------
402    /// `cuStreamCreate(phStream*, flags) -> CUresult`
403    ///
404    /// Creates a stream.
405    pub cu_stream_create: unsafe extern "C" fn(phstream: *mut CUstream, flags: u32) -> CUresult,
406
407    /// `cuStreamCreateWithPriority(phStream*, flags, priority) -> CUresult`
408    ///
409    /// Creates a stream with the given priority.
410    pub cu_stream_create_with_priority:
411        unsafe extern "C" fn(phstream: *mut CUstream, flags: u32, priority: c_int) -> CUresult,
412
413    /// `cuStreamDestroy_v2(hStream) -> CUresult`
414    ///
415    /// Destroys a stream.
416    pub cu_stream_destroy_v2: unsafe extern "C" fn(hstream: CUstream) -> CUresult,
417
418    /// `cuStreamSynchronize(hStream) -> CUresult`
419    ///
420    /// Waits until a stream's tasks are completed.
421    pub cu_stream_synchronize: unsafe extern "C" fn(hstream: CUstream) -> CUresult,
422
423    /// `cuStreamWaitEvent(hStream, hEvent, flags) -> CUresult`
424    ///
425    /// Makes all future work submitted to the stream wait for the event.
426    pub cu_stream_wait_event:
427        unsafe extern "C" fn(hstream: CUstream, hevent: CUevent, flags: u32) -> CUresult,
428
429    /// `cuStreamQuery(hStream) -> CUresult`
430    ///
431    /// Returns `CUDA_SUCCESS` if all operations in the stream have completed,
432    /// `CUDA_ERROR_NOT_READY` if still pending.
433    pub cu_stream_query: unsafe extern "C" fn(hstream: CUstream) -> CUresult,
434
435    /// `cuStreamGetPriority(hStream, priority*) -> CUresult`
436    ///
437    /// Query the priority of `hStream`.
438    pub cu_stream_get_priority:
439        unsafe extern "C" fn(hstream: CUstream, priority: *mut std::ffi::c_int) -> CUresult,
440
441    /// `cuStreamGetFlags(hStream, flags*) -> CUresult`
442    ///
443    /// Query the flags of `hStream`.
444    pub cu_stream_get_flags: unsafe extern "C" fn(hstream: CUstream, flags: *mut u32) -> CUresult,
445
446    // -- Event management --------------------------------------------------
447    /// `cuEventCreate(phEvent*, flags) -> CUresult`
448    ///
449    /// Creates an event.
450    pub cu_event_create: unsafe extern "C" fn(phevent: *mut CUevent, flags: u32) -> CUresult,
451
452    /// `cuEventDestroy_v2(hEvent) -> CUresult`
453    ///
454    /// Destroys an event.
455    pub cu_event_destroy_v2: unsafe extern "C" fn(hevent: CUevent) -> CUresult,
456
457    /// `cuEventRecord(hEvent, hStream) -> CUresult`
458    ///
459    /// Records an event in a stream.
460    pub cu_event_record: unsafe extern "C" fn(hevent: CUevent, hstream: CUstream) -> CUresult,
461
462    /// `cuEventQuery(hEvent) -> CUresult`
463    ///
464    /// Queries the status of an event. Returns `CUDA_SUCCESS` if complete,
465    /// `CUDA_ERROR_NOT_READY` if still pending.
466    pub cu_event_query: unsafe extern "C" fn(hevent: CUevent) -> CUresult,
467
468    /// `cuEventSynchronize(hEvent) -> CUresult`
469    ///
470    /// Waits until an event completes.
471    pub cu_event_synchronize: unsafe extern "C" fn(hevent: CUevent) -> CUresult,
472
473    /// `cuEventElapsedTime(pMilliseconds*, hStart, hEnd) -> CUresult`
474    ///
475    /// Computes the elapsed time between two events.
476    pub cu_event_elapsed_time:
477        unsafe extern "C" fn(pmilliseconds: *mut f32, hstart: CUevent, hend: CUevent) -> CUresult,
478
479    // -- Kernel launch -----------------------------------------------------
480
481    // -- Peer memory access ------------------------------------------------
482    /// `cuMemcpyPeer(dstDevice, dstContext, srcDevice, srcContext, count) -> CUresult`
483    ///
484    /// Copies device memory between two primary contexts.
485    pub cu_memcpy_peer: unsafe extern "C" fn(
486        dst_device: u64,
487        dst_ctx: CUcontext,
488        src_device: u64,
489        src_ctx: CUcontext,
490        count: usize,
491    ) -> CUresult,
492
493    /// `cuMemcpyPeerAsync(..., hStream) -> CUresult`
494    ///
495    /// Asynchronous cross-device copy.
496    pub cu_memcpy_peer_async: unsafe extern "C" fn(
497        dst_device: u64,
498        dst_ctx: CUcontext,
499        src_device: u64,
500        src_ctx: CUcontext,
501        count: usize,
502        stream: CUstream,
503    ) -> CUresult,
504
505    /// `cuCtxEnablePeerAccess(peerContext, flags) -> CUresult`
506    ///
507    /// Enables peer access between two contexts.
508    pub cu_ctx_enable_peer_access:
509        unsafe extern "C" fn(peer_context: CUcontext, flags: u32) -> CUresult,
510
511    /// `cuCtxDisablePeerAccess(peerContext) -> CUresult`
512    ///
513    /// Disables peer access to a context.
514    pub cu_ctx_disable_peer_access: unsafe extern "C" fn(peer_context: CUcontext) -> CUresult,
515    /// `cuLaunchKernel(f, gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY,
516    ///   blockDimZ, sharedMemBytes, hStream, kernelParams**, extra**) -> CUresult`
517    ///
518    /// Launches a CUDA kernel.
519    #[allow(clippy::type_complexity)]
520    pub cu_launch_kernel: unsafe extern "C" fn(
521        f: CUfunction,
522        grid_dim_x: u32,
523        grid_dim_y: u32,
524        grid_dim_z: u32,
525        block_dim_x: u32,
526        block_dim_y: u32,
527        block_dim_z: u32,
528        shared_mem_bytes: u32,
529        hstream: CUstream,
530        kernel_params: *mut *mut c_void,
531        extra: *mut *mut c_void,
532    ) -> CUresult,
533
534    /// `cuLaunchCooperativeKernel(f, gridDimX, gridDimY, gridDimZ, blockDimX,
535    ///   blockDimY, blockDimZ, sharedMemBytes, hStream, kernelParams**) -> CUresult`
536    ///
537    /// Launches a cooperative CUDA kernel (CUDA 9.0+).
538    #[allow(clippy::type_complexity)]
539    pub cu_launch_cooperative_kernel: unsafe extern "C" fn(
540        f: CUfunction,
541        grid_dim_x: u32,
542        grid_dim_y: u32,
543        grid_dim_z: u32,
544        block_dim_x: u32,
545        block_dim_y: u32,
546        block_dim_z: u32,
547        shared_mem_bytes: u32,
548        hstream: CUstream,
549        kernel_params: *mut *mut c_void,
550    ) -> CUresult,
551
552    /// `cuLaunchCooperativeKernelMultiDevice(launchParamsList*, numDevices,
553    ///   flags) -> CUresult`
554    ///
555    /// Launches a cooperative kernel across multiple devices (CUDA 9.0+).
556    pub cu_launch_cooperative_kernel_multi_device: unsafe extern "C" fn(
557        launch_params_list: *mut c_void,
558        num_devices: u32,
559        flags: u32,
560    ) -> CUresult,
561
562    // -- Occupancy ---------------------------------------------------------
563    /// `cuOccupancyMaxActiveBlocksPerMultiprocessor(numBlocks*, func, blockSize,
564    ///   dynamicSMemSize) -> CUresult`
565    ///
566    /// Returns the number of the maximum active blocks per streaming
567    /// multiprocessor.
568    pub cu_occupancy_max_active_blocks_per_multiprocessor: unsafe extern "C" fn(
569        num_blocks: *mut c_int,
570        func: CUfunction,
571        block_size: c_int,
572        dynamic_smem_size: usize,
573    ) -> CUresult,
574
575    /// `cuOccupancyMaxPotentialBlockSize(minGridSize*, blockSize*, func,
576    ///   blockSizeToDynamicSMemSize, dynamicSMemSize, blockSizeLimit) -> CUresult`
577    ///
578    /// Suggests a launch configuration with reasonable occupancy.
579    #[allow(clippy::type_complexity)]
580    pub cu_occupancy_max_potential_block_size: unsafe extern "C" fn(
581        min_grid_size: *mut c_int,
582        block_size: *mut c_int,
583        func: CUfunction,
584        block_size_to_dynamic_smem_size: Option<unsafe extern "C" fn(c_int) -> usize>,
585        dynamic_smem_size: usize,
586        block_size_limit: c_int,
587    ) -> CUresult,
588
589    /// `cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags(numBlocks*, func,
590    ///   blockSize, dynamicSMemSize, flags) -> CUresult`
591    ///
592    /// Like `cuOccupancyMaxActiveBlocksPerMultiprocessor` but with flags
593    /// to control caching behaviour (CUDA 9.0+).
594    pub cu_occupancy_max_active_blocks_per_multiprocessor_with_flags:
595        unsafe extern "C" fn(
596            num_blocks: *mut c_int,
597            func: CUfunction,
598            block_size: c_int,
599            dynamic_smem_size: usize,
600            flags: u32,
601        ) -> CUresult,
602
603    // -- Memory management (optional) -----------------------------------------
604    /// `cuMemcpyDtoDAsync_v2(dst, src, bytesize, stream) -> CUresult`
605    ///
606    /// Asynchronously copies data from device memory to device memory.
607    pub cu_memcpy_dtod_async_v2: Option<
608        unsafe extern "C" fn(
609            dst: CUdeviceptr,
610            src: CUdeviceptr,
611            bytesize: usize,
612            stream: CUstream,
613        ) -> CUresult,
614    >,
615
616    /// `cuMemsetD16_v2(dst, value, count) -> CUresult`
617    ///
618    /// Sets device memory to a value (16-bit granularity).
619    pub cu_memset_d16_v2:
620        Option<unsafe extern "C" fn(dst: CUdeviceptr, value: u16, count: usize) -> CUresult>,
621
622    /// `cuMemsetD32Async(dst, value, count, stream) -> CUresult`
623    ///
624    /// Asynchronously sets device memory to a value (32-bit granularity).
625    pub cu_memset_d32_async: Option<
626        unsafe extern "C" fn(
627            dst: CUdeviceptr,
628            value: u32,
629            count: usize,
630            stream: CUstream,
631        ) -> CUresult,
632    >,
633
634    // -- Context management (optional) ----------------------------------------
635    /// `cuCtxGetLimit(value*, limit) -> CUresult`
636    ///
637    /// Returns the value of a context limit.
638    pub cu_ctx_get_limit: Option<unsafe extern "C" fn(value: *mut usize, limit: u32) -> CUresult>,
639
640    /// `cuCtxSetLimit(limit, value) -> CUresult`
641    ///
642    /// Sets a context limit.
643    pub cu_ctx_set_limit: Option<unsafe extern "C" fn(limit: u32, value: usize) -> CUresult>,
644
645    /// `cuCtxGetCacheConfig(config*) -> CUresult`
646    ///
647    /// Returns the current cache configuration for the context.
648    pub cu_ctx_get_cache_config: Option<unsafe extern "C" fn(config: *mut u32) -> CUresult>,
649
650    /// `cuCtxSetCacheConfig(config) -> CUresult`
651    ///
652    /// Sets the cache configuration for the current context.
653    pub cu_ctx_set_cache_config: Option<unsafe extern "C" fn(config: u32) -> CUresult>,
654
655    /// `cuCtxGetSharedMemConfig(config*) -> CUresult`
656    ///
657    /// Returns the shared memory configuration for the context.
658    pub cu_ctx_get_shared_mem_config: Option<unsafe extern "C" fn(config: *mut u32) -> CUresult>,
659
660    /// `cuCtxSetSharedMemConfig(config) -> CUresult`
661    ///
662    /// Sets the shared memory configuration for the current context.
663    pub cu_ctx_set_shared_mem_config: Option<unsafe extern "C" fn(config: u32) -> CUresult>,
664
665    // -- Event with flags (optional, CUDA 11.1+) ------------------------------
666    /// `cuEventRecordWithFlags(hEvent, hStream, flags) -> CUresult`
667    ///
668    /// Records an event in a stream with additional flags (CUDA 11.1+).
669    /// Falls back to `cu_event_record` when `None`.
670    pub cu_event_record_with_flags:
671        Option<unsafe extern "C" fn(hevent: CUevent, hstream: CUstream, flags: u32) -> CUresult>,
672
673    // -- Function attributes (optional) ---------------------------------------
674    /// `cuFuncGetAttribute(value*, attrib, func) -> CUresult`
675    ///
676    /// Returns information about a function.
677    pub cu_func_get_attribute: Option<
678        unsafe extern "C" fn(value: *mut c_int, attrib: c_int, func: CUfunction) -> CUresult,
679    >,
680
681    /// `cuFuncSetCacheConfig(func, config) -> CUresult`
682    ///
683    /// Sets the cache configuration for a device function.
684    pub cu_func_set_cache_config:
685        Option<unsafe extern "C" fn(func: CUfunction, config: u32) -> CUresult>,
686
687    /// `cuFuncSetSharedMemConfig(func, config) -> CUresult`
688    ///
689    /// Sets the shared memory configuration for a device function.
690    pub cu_func_set_shared_mem_config:
691        Option<unsafe extern "C" fn(func: CUfunction, config: u32) -> CUresult>,
692
693    /// `cuFuncSetAttribute(func, attrib, value) -> CUresult`
694    ///
695    /// Sets an attribute value for a device function.
696    pub cu_func_set_attribute:
697        Option<unsafe extern "C" fn(func: CUfunction, attrib: c_int, value: c_int) -> CUresult>,
698
699    // -- Profiler (optional) --------------------------------------------------
700    /// `cuProfilerStart() -> CUresult`
701    ///
702    /// Starts the CUDA profiler.
703    pub cu_profiler_start: Option<unsafe extern "C" fn() -> CUresult>,
704
705    /// `cuProfilerStop() -> CUresult`
706    ///
707    /// Stops the CUDA profiler.
708    pub cu_profiler_stop: Option<unsafe extern "C" fn() -> CUresult>,
709
710    // -- CUDA 12.x extended launch (optional) ---------------------------------
711    /// `cuLaunchKernelEx(config*, f, kernelParams**, extra**) -> CUresult`
712    ///
713    /// Extended kernel launch with cluster dimensions and other CUDA 12.0+
714    /// attributes. Available only when the driver is CUDA 12.0 or newer.
715    ///
716    /// When `None`, fall back to [`cu_launch_kernel`](Self::cu_launch_kernel).
717    #[allow(clippy::type_complexity)]
718    pub cu_launch_kernel_ex: Option<
719        unsafe extern "C" fn(
720            config: *const CuLaunchConfig,
721            f: CUfunction,
722            kernel_params: *mut *mut std::ffi::c_void,
723            extra: *mut *mut std::ffi::c_void,
724        ) -> CUresult,
725    >,
726
727    /// `cuTensorMapEncodeTiled(tensorMap*, ...) -> CUresult`
728    ///
729    /// Creates a TMA tensor map descriptor for tiled access patterns.
730    /// Available on CUDA 12.0+ with sm_90+ (Hopper/Blackwell).
731    ///
732    /// When `None`, TMA is not supported by the loaded driver.
733    #[allow(clippy::type_complexity)]
734    pub cu_tensor_map_encode_tiled: Option<
735        unsafe extern "C" fn(
736            tensor_map: *mut std::ffi::c_void,
737            tensor_data_type: u32,
738            tensor_rank: u32,
739            global_address: *mut std::ffi::c_void,
740            global_dim: *const u64,
741            global_strides: *const u64,
742            box_dim: *const u32,
743            element_strides: *const u32,
744            interleave: u32,
745            swizzle: u32,
746            l2_promotion: u32,
747            oob_fill: u32,
748        ) -> CUresult,
749    >,
750
751    // -- CUDA 12.8+ extended API (optional) -----------------------------------
752    /// `cuTensorMapEncodeTiledMemref(tensorMap*, ...) -> CUresult`
753    ///
754    /// Extended TMA encoding using memref descriptors (CUDA 12.8+,
755    /// Blackwell sm_100/sm_120). When `None`, fall back to
756    /// [`cu_tensor_map_encode_tiled`](Self::cu_tensor_map_encode_tiled).
757    #[allow(clippy::type_complexity)]
758    pub cu_tensor_map_encode_tiled_memref: Option<
759        unsafe extern "C" fn(
760            tensor_map: *mut c_void,
761            tensor_data_type: u32,
762            tensor_rank: u32,
763            global_address: *mut c_void,
764            global_dim: *const u64,
765            global_strides: *const u64,
766            box_dim: *const u32,
767            element_strides: *const u32,
768            interleave: u32,
769            swizzle: u32,
770            l2_promotion: u32,
771            oob_fill: u32,
772            flags: u64,
773        ) -> CUresult,
774    >,
775
776    /// `cuKernelGetLibrary(pLib*, kernel) -> CUresult`
777    ///
778    /// Returns the library handle that owns a given kernel handle
779    /// (CUDA 12.8+). When `None`, the driver does not support the JIT
780    /// library API.
781    pub cu_kernel_get_library:
782        Option<unsafe extern "C" fn(p_lib: *mut CUlibrary, kernel: CUkernel) -> CUresult>,
783
784    /// `cuMulticastGetGranularity(granularity*, desc*, option) -> CUresult`
785    ///
786    /// Queries the recommended memory granularity for an NVLink multicast
787    /// object (CUDA 12.8+). When `None`, multicast memory is not supported.
788    pub cu_multicast_get_granularity: Option<
789        unsafe extern "C" fn(granularity: *mut usize, desc: *const c_void, option: u32) -> CUresult,
790    >,
791
792    /// `cuMulticastCreate(mcHandle*, desc*) -> CUresult`
793    ///
794    /// Creates an NVLink multicast object for cross-GPU broadcast memory
795    /// (CUDA 12.8+). When `None`, multicast memory is not supported.
796    pub cu_multicast_create: Option<
797        unsafe extern "C" fn(mc_handle: *mut CUmulticastObject, desc: *const c_void) -> CUresult,
798    >,
799
800    /// `cuMulticastAddDevice(mcHandle, dev) -> CUresult`
801    ///
802    /// Adds a device to an NVLink multicast group (CUDA 12.8+). When
803    /// `None`, multicast memory is not supported.
804    pub cu_multicast_add_device:
805        Option<unsafe extern "C" fn(mc_handle: CUmulticastObject, dev: CUdevice) -> CUresult>,
806
807    /// `cuMemcpyBatchAsync(dsts*, srcs*, sizes*, count, flags, stream) -> CUresult`
808    ///
809    /// Issues *count* asynchronous memory copies (H2D, D2H, or D2D) in a
810    /// single driver call (CUDA 12.8+). When `None`, issue individual
811    /// `cuMemcpyAsync` calls as a fallback.
812    #[allow(clippy::type_complexity)]
813    pub cu_memcpy_batch_async: Option<
814        unsafe extern "C" fn(
815            dsts: *const *mut c_void,
816            srcs: *const *const c_void,
817            sizes: *const usize,
818            count: u64,
819            flags: u64,
820            stream: CUstream,
821        ) -> CUresult,
822    >,
823
824    // -- Texture / Surface memory (optional) ----------------------------------
825    /// `cuArrayCreate_v2(pHandle*, pAllocateArray*) -> CUresult`
826    ///
827    /// Allocates a 1-D or 2-D CUDA array. When `None`, CUDA array allocation
828    /// is not supported by the loaded driver.
829    pub cu_array_create_v2: Option<
830        unsafe extern "C" fn(
831            p_handle: *mut CUarray,
832            p_allocate_array: *const CUDA_ARRAY_DESCRIPTOR,
833        ) -> CUresult,
834    >,
835
836    /// `cuArrayDestroy(hArray) -> CUresult`
837    ///
838    /// Frees a CUDA array previously allocated by `cuArrayCreate_v2`.
839    pub cu_array_destroy: Option<unsafe extern "C" fn(h_array: CUarray) -> CUresult>,
840
841    /// `cuArrayGetDescriptor_v2(pArrayDescriptor*, hArray) -> CUresult`
842    ///
843    /// Returns the descriptor of a 1-D or 2-D CUDA array.
844    pub cu_array_get_descriptor_v2: Option<
845        unsafe extern "C" fn(
846            p_array_descriptor: *mut CUDA_ARRAY_DESCRIPTOR,
847            h_array: CUarray,
848        ) -> CUresult,
849    >,
850
851    /// `cuArray3DCreate_v2(pHandle*, pAllocateArray*) -> CUresult`
852    ///
853    /// Allocates a 3-D CUDA array (also supports layered and cubemap arrays).
854    pub cu_array3d_create_v2: Option<
855        unsafe extern "C" fn(
856            p_handle: *mut CUarray,
857            p_allocate_array: *const CUDA_ARRAY3D_DESCRIPTOR,
858        ) -> CUresult,
859    >,
860
861    /// `cuArray3DGetDescriptor_v2(pArrayDescriptor*, hArray) -> CUresult`
862    ///
863    /// Returns the descriptor of a 3-D CUDA array.
864    pub cu_array3d_get_descriptor_v2: Option<
865        unsafe extern "C" fn(
866            p_array_descriptor: *mut CUDA_ARRAY3D_DESCRIPTOR,
867            h_array: CUarray,
868        ) -> CUresult,
869    >,
870
871    /// `cuMemcpyHtoA_v2(dstArray, dstOffset, srcHost*, ByteCount) -> CUresult`
872    ///
873    /// Synchronously copies host memory into a CUDA array.
874    pub cu_memcpy_htoa_v2: Option<
875        unsafe extern "C" fn(
876            dst_array: CUarray,
877            dst_offset: usize,
878            src_host: *const c_void,
879            byte_count: usize,
880        ) -> CUresult,
881    >,
882
883    /// `cuMemcpyAtoH_v2(dstHost*, srcArray, srcOffset, ByteCount) -> CUresult`
884    ///
885    /// Synchronously copies data from a CUDA array into host memory.
886    pub cu_memcpy_atoh_v2: Option<
887        unsafe extern "C" fn(
888            dst_host: *mut c_void,
889            src_array: CUarray,
890            src_offset: usize,
891            byte_count: usize,
892        ) -> CUresult,
893    >,
894
895    /// `cuMemcpyHtoAAsync_v2(dstArray, dstOffset, srcHost*, byteCount, stream) -> CUresult`
896    ///
897    /// Asynchronously copies host memory into a CUDA array on a stream.
898    pub cu_memcpy_htoa_async_v2: Option<
899        unsafe extern "C" fn(
900            dst_array: CUarray,
901            dst_offset: usize,
902            src_host: *const c_void,
903            byte_count: usize,
904            stream: CUstream,
905        ) -> CUresult,
906    >,
907
908    /// `cuMemcpyAtoHAsync_v2(dstHost*, srcArray, srcOffset, byteCount, stream) -> CUresult`
909    ///
910    /// Asynchronously copies data from a CUDA array into host memory on a stream.
911    pub cu_memcpy_atoh_async_v2: Option<
912        unsafe extern "C" fn(
913            dst_host: *mut c_void,
914            src_array: CUarray,
915            src_offset: usize,
916            byte_count: usize,
917            stream: CUstream,
918        ) -> CUresult,
919    >,
920
921    /// `cuTexObjectCreate(pTexObject*, pResDesc*, pTexDesc*, pResViewDesc*) -> CUresult`
922    ///
923    /// Creates a texture object from a resource descriptor, texture descriptor,
924    /// and optional resource-view descriptor (CUDA 5.0+).
925    pub cu_tex_object_create: Option<
926        unsafe extern "C" fn(
927            p_tex_object: *mut CUtexObject,
928            p_res_desc: *const CUDA_RESOURCE_DESC,
929            p_tex_desc: *const CUDA_TEXTURE_DESC,
930            p_res_view_desc: *const CUDA_RESOURCE_VIEW_DESC,
931        ) -> CUresult,
932    >,
933
934    /// `cuTexObjectDestroy(texObject) -> CUresult`
935    ///
936    /// Destroys a texture object created by `cuTexObjectCreate`.
937    pub cu_tex_object_destroy: Option<unsafe extern "C" fn(tex_object: CUtexObject) -> CUresult>,
938
939    /// `cuTexObjectGetResourceDesc(pResDesc*, texObject) -> CUresult`
940    ///
941    /// Returns the resource descriptor of a texture object.
942    pub cu_tex_object_get_resource_desc: Option<
943        unsafe extern "C" fn(
944            p_res_desc: *mut CUDA_RESOURCE_DESC,
945            tex_object: CUtexObject,
946        ) -> CUresult,
947    >,
948
949    /// `cuSurfObjectCreate(pSurfObject*, pResDesc*) -> CUresult`
950    ///
951    /// Creates a surface object from a resource descriptor (CUDA 5.0+).
952    /// The resource type must be `Array` (surface-capable CUDA arrays only).
953    pub cu_surf_object_create: Option<
954        unsafe extern "C" fn(
955            p_surf_object: *mut CUsurfObject,
956            p_res_desc: *const CUDA_RESOURCE_DESC,
957        ) -> CUresult,
958    >,
959
960    /// `cuSurfObjectDestroy(surfObject) -> CUresult`
961    ///
962    /// Destroys a surface object created by `cuSurfObjectCreate`.
963    pub cu_surf_object_destroy: Option<unsafe extern "C" fn(surf_object: CUsurfObject) -> CUresult>,
964
965    // -- JIT linker (optional) ----------------------------------------------
966    /// `cuLinkCreate_v2(numOptions, options*, optionValues**, stateOut*) -> CUresult`
967    ///
968    /// Creates a pending JIT linker invocation.  When `None`, the driver does
969    /// not expose the linker API.
970    pub cu_link_create: Option<
971        unsafe extern "C" fn(
972            num_options: u32,
973            options: *mut CUjit_option,
974            option_values: *mut *mut c_void,
975            state_out: *mut CUlinkState,
976        ) -> CUresult,
977    >,
978
979    /// `cuLinkAddData_v2(state, type, data*, size, name*, numOptions, options*, optionValues**) -> CUresult`
980    ///
981    /// Adds an input PTX/cubin/fatbin to a pending linker invocation.  When
982    /// `None`, the driver does not expose the linker API.
983    #[allow(clippy::type_complexity)]
984    pub cu_link_add_data: Option<
985        unsafe extern "C" fn(
986            state: CUlinkState,
987            input_type: CUjitInputType,
988            data: *mut c_void,
989            size: usize,
990            name: *const c_char,
991            num_options: u32,
992            options: *mut CUjit_option,
993            option_values: *mut *mut c_void,
994        ) -> CUresult,
995    >,
996
997    /// `cuLinkComplete(state, cubinOut**, sizeOut*) -> CUresult`
998    ///
999    /// Finalises a JIT linker invocation and returns the resulting cubin
1000    /// pointer / size.  When `None`, the driver does not expose the linker
1001    /// API.
1002    pub cu_link_complete: Option<
1003        unsafe extern "C" fn(
1004            state: CUlinkState,
1005            cubin_out: *mut *mut c_void,
1006            size_out: *mut usize,
1007        ) -> CUresult,
1008    >,
1009
1010    /// `cuLinkDestroy(state) -> CUresult`
1011    ///
1012    /// Destroys a linker state previously created by `cuLinkCreate`.  When
1013    /// `None`, the driver does not expose the linker API.
1014    pub cu_link_destroy: Option<unsafe extern "C" fn(state: CUlinkState) -> CUresult>,
1015
1016    // -- 2-D memory copy (optional) -----------------------------------------
1017    /// `cuMemcpy2D_v2(pCopy*) -> CUresult`
1018    ///
1019    /// Performs a 2-D memory copy described by [`CUDA_MEMCPY2D`].  When
1020    /// `None`, fall back to issuing per-row 1-D `cuMemcpyXXX_v2` calls.
1021    pub cu_memcpy_2d: Option<unsafe extern "C" fn(p_copy: *const CUDA_MEMCPY2D) -> CUresult>,
1022
1023    // -- Virtual memory management (optional, CUDA 11.2+) -------------------
1024    /// `cuMemAddressReserve(ptr*, size, alignment, addr, flags) -> CUresult`
1025    ///
1026    /// Reserves a contiguous range of virtual addresses on the device for
1027    /// later mapping by `cuMemMap`.  When `None`, the VMM API is not
1028    /// supported.
1029    pub cu_mem_address_reserve: Option<
1030        unsafe extern "C" fn(
1031            ptr: *mut CUdeviceptr,
1032            size: usize,
1033            alignment: usize,
1034            addr: CUdeviceptr,
1035            flags: u64,
1036        ) -> CUresult,
1037    >,
1038
1039    /// `cuMemAddressFree(ptr, size) -> CUresult`
1040    ///
1041    /// Releases a virtual-address range previously obtained from
1042    /// `cuMemAddressReserve`.  When `None`, the VMM API is not supported.
1043    pub cu_mem_address_free:
1044        Option<unsafe extern "C" fn(ptr: CUdeviceptr, size: usize) -> CUresult>,
1045
1046    /// `cuMemCreate(handle*, size, prop*, flags) -> CUresult`
1047    ///
1048    /// Creates a new generic VMM allocation handle.  When `None`, the VMM
1049    /// API is not supported.
1050    pub cu_mem_create: Option<
1051        unsafe extern "C" fn(
1052            handle: *mut CUmemGenericAllocationHandle,
1053            size: usize,
1054            prop: *const CUmemAllocationProp,
1055            flags: u64,
1056        ) -> CUresult,
1057    >,
1058
1059    /// `cuMemRelease(handle) -> CUresult`
1060    ///
1061    /// Releases a generic VMM allocation handle.  When `None`, the VMM
1062    /// API is not supported.
1063    pub cu_mem_release:
1064        Option<unsafe extern "C" fn(handle: CUmemGenericAllocationHandle) -> CUresult>,
1065
1066    /// `cuMemMap(ptr, size, offset, handle, flags) -> CUresult`
1067    ///
1068    /// Maps a VMM allocation onto a previously reserved virtual address
1069    /// range.  When `None`, the VMM API is not supported.
1070    pub cu_mem_map: Option<
1071        unsafe extern "C" fn(
1072            ptr: CUdeviceptr,
1073            size: usize,
1074            offset: usize,
1075            handle: CUmemGenericAllocationHandle,
1076            flags: u64,
1077        ) -> CUresult,
1078    >,
1079
1080    /// `cuMemUnmap(ptr, size) -> CUresult`
1081    ///
1082    /// Unmaps a VMM allocation from a virtual address range.  When `None`,
1083    /// the VMM API is not supported.
1084    pub cu_mem_unmap: Option<unsafe extern "C" fn(ptr: CUdeviceptr, size: usize) -> CUresult>,
1085
1086    /// `cuMemSetAccess(ptr, size, desc*, count) -> CUresult`
1087    ///
1088    /// Sets per-location access permissions for a VMM mapping.  When `None`,
1089    /// the VMM API is not supported.
1090    pub cu_mem_set_access: Option<
1091        unsafe extern "C" fn(
1092            ptr: CUdeviceptr,
1093            size: usize,
1094            desc: *const CUmemAccessDesc,
1095            count: usize,
1096        ) -> CUresult,
1097    >,
1098
1099    // -- Stream-ordered memory pools (optional, CUDA 11.2+) -----------------
1100    /// `cuMemPoolCreate(pool*, poolProps*) -> CUresult`
1101    ///
1102    /// Creates a stream-ordered memory pool.  When `None`, the memory pool
1103    /// API is not supported.
1104    pub cu_mem_pool_create: Option<
1105        unsafe extern "C" fn(
1106            pool: *mut CUmemoryPool,
1107            pool_props: *const CUmemPoolProps,
1108        ) -> CUresult,
1109    >,
1110
1111    /// `cuMemPoolDestroy(pool) -> CUresult`
1112    ///
1113    /// Destroys a stream-ordered memory pool.  When `None`, the memory pool
1114    /// API is not supported.
1115    pub cu_mem_pool_destroy: Option<unsafe extern "C" fn(pool: CUmemoryPool) -> CUresult>,
1116
1117    /// `cuMemAllocFromPoolAsync(dptr*, bytesize, pool, hStream) -> CUresult`
1118    ///
1119    /// Asynchronously allocates memory from a pool on a stream.  When `None`,
1120    /// the memory pool API is not supported.
1121    pub cu_mem_alloc_from_pool_async: Option<
1122        unsafe extern "C" fn(
1123            dptr: *mut CUdeviceptr,
1124            bytesize: usize,
1125            pool: CUmemoryPool,
1126            hstream: CUstream,
1127        ) -> CUresult,
1128    >,
1129
1130    /// `cuMemFreeAsync(dptr, hStream) -> CUresult`
1131    ///
1132    /// Asynchronously frees memory on a stream.  When `None`, the
1133    /// stream-ordered memory API is not supported.
1134    pub cu_mem_free_async:
1135        Option<unsafe extern "C" fn(dptr: CUdeviceptr, hstream: CUstream) -> CUresult>,
1136
1137    /// `cuMemAllocAsync(dptr*, bytesize, hStream) -> CUresult`
1138    ///
1139    /// Asynchronously allocates memory from the current context's default
1140    /// pool on a stream.  When `None`, the stream-ordered memory API is not
1141    /// supported.
1142    pub cu_mem_alloc_async: Option<
1143        unsafe extern "C" fn(
1144            dptr: *mut CUdeviceptr,
1145            bytesize: usize,
1146            hstream: CUstream,
1147        ) -> CUresult,
1148    >,
1149
1150    /// `cuMemPoolTrimTo(pool, minBytesToKeep) -> CUresult`
1151    ///
1152    /// Releases freed memory back to the OS, keeping at least
1153    /// `minBytesToKeep` bytes reserved.  When `None`, the memory pool API
1154    /// is not supported.
1155    pub cu_mem_pool_trim_to:
1156        Option<unsafe extern "C" fn(pool: CUmemoryPool, min_bytes_to_keep: usize) -> CUresult>,
1157
1158    /// `cuMemPoolSetAttribute(pool, attr, value*) -> CUresult`
1159    ///
1160    /// Sets a writable attribute on a memory pool.  When `None`, the memory
1161    /// pool API is not supported.
1162    pub cu_mem_pool_set_attribute: Option<
1163        unsafe extern "C" fn(
1164            pool: CUmemoryPool,
1165            attr: CUmemPoolAttribute,
1166            value: *mut c_void,
1167        ) -> CUresult,
1168    >,
1169
1170    /// `cuMemPoolGetAttribute(pool, attr, value*) -> CUresult`
1171    ///
1172    /// Reads an attribute from a memory pool.  When `None`, the memory pool
1173    /// API is not supported.
1174    pub cu_mem_pool_get_attribute: Option<
1175        unsafe extern "C" fn(
1176            pool: CUmemoryPool,
1177            attr: CUmemPoolAttribute,
1178            value: *mut c_void,
1179        ) -> CUresult,
1180    >,
1181
1182    /// `cuMemPoolSetAccess(pool, map*, count) -> CUresult`
1183    ///
1184    /// Controls the per-device visibility of allocations from a memory pool.
1185    /// When `None`, the memory pool API is not supported.
1186    pub cu_mem_pool_set_access: Option<
1187        unsafe extern "C" fn(
1188            pool: CUmemoryPool,
1189            map: *const CUmemAccessDesc,
1190            count: usize,
1191        ) -> CUresult,
1192    >,
1193
1194    /// `cuDeviceGetDefaultMemPool(pool*, dev) -> CUresult`
1195    ///
1196    /// Returns the default stream-ordered memory pool of a device.  When
1197    /// `None`, the memory pool API is not supported.
1198    pub cu_device_get_default_mem_pool:
1199        Option<unsafe extern "C" fn(pool: *mut CUmemoryPool, dev: CUdevice) -> CUresult>,
1200
1201    // -- CUDA Graph API (optional, CUDA 10.0+) ------------------------------
1202    /// `cuGraphCreate(phGraph*, flags) -> CUresult`
1203    ///
1204    /// Creates an empty CUDA graph.  When `None`, the graph API is not
1205    /// supported by the loaded driver.
1206    pub cu_graph_create:
1207        Option<unsafe extern "C" fn(ph_graph: *mut CUgraph, flags: u32) -> CUresult>,
1208
1209    /// `cuGraphDestroy(hGraph) -> CUresult`
1210    ///
1211    /// Destroys a CUDA graph.  When `None`, the graph API is not supported.
1212    pub cu_graph_destroy: Option<unsafe extern "C" fn(h_graph: CUgraph) -> CUresult>,
1213
1214    /// `cuGraphAddKernelNode(phGraphNode*, hGraph, dependencies*, numDependencies, nodeParams*) -> CUresult`
1215    ///
1216    /// Adds a kernel-launch node to a graph.  When `None`, the graph API is
1217    /// not supported.
1218    pub cu_graph_add_kernel_node: Option<
1219        unsafe extern "C" fn(
1220            ph_graph_node: *mut CUgraphNode,
1221            h_graph: CUgraph,
1222            dependencies: *const CUgraphNode,
1223            num_dependencies: usize,
1224            node_params: *const CUDA_KERNEL_NODE_PARAMS,
1225        ) -> CUresult,
1226    >,
1227
1228    /// `cuGraphAddMemcpyNode(phGraphNode*, hGraph, dependencies*, numDependencies, copyParams*, ctx) -> CUresult`
1229    ///
1230    /// Adds a memory-copy node to a graph.  When `None`, the graph API is
1231    /// not supported.
1232    pub cu_graph_add_memcpy_node: Option<
1233        unsafe extern "C" fn(
1234            ph_graph_node: *mut CUgraphNode,
1235            h_graph: CUgraph,
1236            dependencies: *const CUgraphNode,
1237            num_dependencies: usize,
1238            copy_params: *const CUDA_MEMCPY3D,
1239            ctx: CUcontext,
1240        ) -> CUresult,
1241    >,
1242
1243    /// `cuGraphAddMemsetNode(phGraphNode*, hGraph, dependencies*, numDependencies, memsetParams*, ctx) -> CUresult`
1244    ///
1245    /// Adds a memset node to a graph.  When `None`, the graph API is not
1246    /// supported.
1247    pub cu_graph_add_memset_node: Option<
1248        unsafe extern "C" fn(
1249            ph_graph_node: *mut CUgraphNode,
1250            h_graph: CUgraph,
1251            dependencies: *const CUgraphNode,
1252            num_dependencies: usize,
1253            memset_params: *const CUDA_MEMSET_NODE_PARAMS,
1254            ctx: CUcontext,
1255        ) -> CUresult,
1256    >,
1257
1258    /// `cuGraphAddEmptyNode(phGraphNode*, hGraph, dependencies*, numDependencies) -> CUresult`
1259    ///
1260    /// Adds an empty (no-op) node to a graph.  When `None`, the graph API
1261    /// is not supported.
1262    pub cu_graph_add_empty_node: Option<
1263        unsafe extern "C" fn(
1264            ph_graph_node: *mut CUgraphNode,
1265            h_graph: CUgraph,
1266            dependencies: *const CUgraphNode,
1267            num_dependencies: usize,
1268        ) -> CUresult,
1269    >,
1270
1271    /// `cuGraphInstantiateWithFlags(phGraphExec*, hGraph, flags) -> CUresult`
1272    ///
1273    /// Instantiates a graph into an executable form (CUDA 11.4+).  When
1274    /// `None`, fall back to [`cu_graph_instantiate`](Self::cu_graph_instantiate).
1275    pub cu_graph_instantiate_with_flags: Option<
1276        unsafe extern "C" fn(
1277            ph_graph_exec: *mut CUgraphExec,
1278            h_graph: CUgraph,
1279            flags: u64,
1280        ) -> CUresult,
1281    >,
1282
1283    /// `cuGraphInstantiate_v2(phGraphExec*, hGraph, phErrorNode*, logBuffer*, bufferSize) -> CUresult`
1284    ///
1285    /// Instantiates a graph into an executable form (legacy signature).
1286    /// When `None`, the graph API is not supported.
1287    pub cu_graph_instantiate: Option<
1288        unsafe extern "C" fn(
1289            ph_graph_exec: *mut CUgraphExec,
1290            h_graph: CUgraph,
1291            ph_error_node: *mut CUgraphNode,
1292            log_buffer: *mut c_char,
1293            buffer_size: usize,
1294        ) -> CUresult,
1295    >,
1296
1297    /// `cuGraphExecDestroy(hGraphExec) -> CUresult`
1298    ///
1299    /// Destroys an executable graph.  When `None`, the graph API is not
1300    /// supported.
1301    pub cu_graph_exec_destroy: Option<unsafe extern "C" fn(h_graph_exec: CUgraphExec) -> CUresult>,
1302
1303    /// `cuGraphLaunch(hGraphExec, hStream) -> CUresult`
1304    ///
1305    /// Submits an executable graph to a stream.  When `None`, the graph API
1306    /// is not supported.
1307    pub cu_graph_launch:
1308        Option<unsafe extern "C" fn(h_graph_exec: CUgraphExec, h_stream: CUstream) -> CUresult>,
1309}
1310
1311// SAFETY: All fields are plain function pointers (which are Send + Sync) and
1312// the Library handle is kept alive but never mutated.
1313unsafe impl Send for DriverApi {}
1314unsafe impl Sync for DriverApi {}
1315
1316// ---------------------------------------------------------------------------
1317// DriverApi — construction
1318// ---------------------------------------------------------------------------
1319
1320impl DriverApi {
1321    /// Attempt to dynamically load the CUDA driver shared library and resolve
1322    /// every required symbol.
1323    ///
1324    /// # Platform behaviour
1325    ///
1326    /// * **macOS** — immediately returns [`DriverLoadError::UnsupportedPlatform`].
1327    /// * **Linux** — tries `libcuda.so.1` then `libcuda.so`.
1328    /// * **Windows** — tries `nvcuda.dll`.
1329    ///
1330    /// # Errors
1331    ///
1332    /// * [`DriverLoadError::UnsupportedPlatform`] on macOS.
1333    /// * [`DriverLoadError::LibraryNotFound`] if none of the candidate library
1334    ///   names could be opened.
1335    /// * [`DriverLoadError::SymbolNotFound`] if a required CUDA entry point is
1336    ///   missing from the loaded library.
1337    pub fn load() -> Result<Self, DriverLoadError> {
1338        // macOS: CUDA is not and will not be supported.
1339        #[cfg(target_os = "macos")]
1340        {
1341            Err(DriverLoadError::UnsupportedPlatform)
1342        }
1343
1344        // Linux library search order.
1345        #[cfg(target_os = "linux")]
1346        let lib_names: &[&str] = &["libcuda.so.1", "libcuda.so"];
1347
1348        // Windows library search order.
1349        #[cfg(target_os = "windows")]
1350        let lib_names: &[&str] = &["nvcuda.dll"];
1351
1352        #[cfg(not(target_os = "macos"))]
1353        {
1354            let lib = Self::load_library(lib_names)?;
1355            let api = Self::load_symbols(lib)?;
1356            // `cuInit(0)` must be called before any other CUDA driver API.
1357            // This mirrors what `libcudart` does internally on the first CUDA
1358            // Runtime call. We call it unconditionally here so that all
1359            // `try_driver()` callers get a fully initialised driver without
1360            // each needing to call `cuInit` themselves.
1361            //
1362            // SAFETY: `api.cu_init` was just resolved from the shared library.
1363            // Passing flags=0 is the only documented value.
1364            let rc = unsafe { (api.cu_init)(0) };
1365            if rc != 0 {
1366                // Propagate the error; the OnceLock will store this Err and
1367                // return CudaError::NotInitialized on every subsequent
1368                // try_driver() call — matching behaviour on no-GPU machines.
1369                return Err(DriverLoadError::InitializationFailed { code: rc });
1370            }
1371            Ok(api)
1372        }
1373    }
1374
1375    /// Try each candidate library name in order, returning the first that
1376    /// loads successfully.
1377    ///
1378    /// # Errors
1379    ///
1380    /// Returns [`DriverLoadError::LibraryNotFound`] if **all** candidates
1381    /// fail to load, capturing the last OS-level error message.
1382    #[cfg(not(target_os = "macos"))]
1383    fn load_library(names: &[&str]) -> Result<Library, DriverLoadError> {
1384        let mut last_error = String::new();
1385        for name in names {
1386            // SAFETY: loading a shared library has side-effects (running its
1387            // init routines), but the CUDA driver library is designed for
1388            // this.
1389            match unsafe { Library::new(*name) } {
1390                Ok(lib) => {
1391                    tracing::debug!("loaded CUDA driver library: {name}");
1392                    return Ok(lib);
1393                }
1394                Err(e) => {
1395                    tracing::debug!("failed to load {name}: {e}");
1396                    last_error = e.to_string();
1397                }
1398            }
1399        }
1400
1401        Err(DriverLoadError::LibraryNotFound {
1402            candidates: names.iter().map(|s| (*s).to_string()).collect(),
1403            last_error,
1404        })
1405    }
1406
1407    /// Resolve every required CUDA driver symbol from the loaded library and
1408    /// assemble the [`DriverApi`] function table.
1409    ///
1410    /// # Errors
1411    ///
1412    /// Returns [`DriverLoadError::SymbolNotFound`] if any symbol cannot be
1413    /// resolved.
1414    #[cfg(not(target_os = "macos"))]
1415    fn load_symbols(lib: Library) -> Result<Self, DriverLoadError> {
1416        Ok(Self {
1417            // -- Initialisation ------------------------------------------------
1418            cu_init: load_sym!(lib, "cuInit"),
1419
1420            // -- Version query -------------------------------------------------
1421            cu_driver_get_version: load_sym!(lib, "cuDriverGetVersion"),
1422
1423            // -- Device management ---------------------------------------------
1424            cu_device_get: load_sym!(lib, "cuDeviceGet"),
1425            cu_device_get_count: load_sym!(lib, "cuDeviceGetCount"),
1426            cu_device_get_name: load_sym!(lib, "cuDeviceGetName"),
1427            cu_device_get_attribute: load_sym!(lib, "cuDeviceGetAttribute"),
1428            cu_device_total_mem_v2: load_sym!(lib, "cuDeviceTotalMem_v2"),
1429            cu_device_can_access_peer: load_sym!(lib, "cuDeviceCanAccessPeer"),
1430
1431            // -- Primary context management ------------------------------------
1432            cu_device_primary_ctx_retain: load_sym!(lib, "cuDevicePrimaryCtxRetain"),
1433            cu_device_primary_ctx_release_v2: load_sym!(lib, "cuDevicePrimaryCtxRelease_v2"),
1434            cu_device_primary_ctx_set_flags_v2: load_sym!(lib, "cuDevicePrimaryCtxSetFlags_v2"),
1435            cu_device_primary_ctx_get_state: load_sym!(lib, "cuDevicePrimaryCtxGetState"),
1436            cu_device_primary_ctx_reset_v2: load_sym!(lib, "cuDevicePrimaryCtxReset_v2"),
1437
1438            // -- Context management --------------------------------------------
1439            cu_ctx_create_v2: load_sym!(lib, "cuCtxCreate_v2"),
1440            cu_ctx_destroy_v2: load_sym!(lib, "cuCtxDestroy_v2"),
1441            cu_ctx_set_current: load_sym!(lib, "cuCtxSetCurrent"),
1442            cu_ctx_get_current: load_sym!(lib, "cuCtxGetCurrent"),
1443            cu_ctx_synchronize: load_sym!(lib, "cuCtxSynchronize"),
1444
1445            // -- Module management ---------------------------------------------
1446            cu_module_load_data: load_sym!(lib, "cuModuleLoadData"),
1447            cu_module_load_data_ex: load_sym!(lib, "cuModuleLoadDataEx"),
1448            cu_module_get_function: load_sym!(lib, "cuModuleGetFunction"),
1449            cu_module_unload: load_sym!(lib, "cuModuleUnload"),
1450
1451            // -- Memory management ---------------------------------------------
1452            cu_mem_alloc_v2: load_sym!(lib, "cuMemAlloc_v2"),
1453            cu_mem_free_v2: load_sym!(lib, "cuMemFree_v2"),
1454            cu_memcpy_htod_v2: load_sym!(lib, "cuMemcpyHtoD_v2"),
1455            cu_memcpy_dtoh_v2: load_sym!(lib, "cuMemcpyDtoH_v2"),
1456            cu_memcpy_dtod_v2: load_sym!(lib, "cuMemcpyDtoD_v2"),
1457            cu_memcpy_htod_async_v2: load_sym!(lib, "cuMemcpyHtoDAsync_v2"),
1458            cu_memcpy_dtoh_async_v2: load_sym!(lib, "cuMemcpyDtoHAsync_v2"),
1459            cu_mem_alloc_host_v2: load_sym!(lib, "cuMemAllocHost_v2"),
1460            cu_mem_free_host: load_sym!(lib, "cuMemFreeHost"),
1461            cu_mem_alloc_managed: load_sym!(lib, "cuMemAllocManaged"),
1462            cu_memset_d8_v2: load_sym!(lib, "cuMemsetD8_v2"),
1463            cu_memset_d32_v2: load_sym!(lib, "cuMemsetD32_v2"),
1464            cu_mem_get_info_v2: load_sym!(lib, "cuMemGetInfo_v2"),
1465            cu_mem_host_register_v2: load_sym!(lib, "cuMemHostRegister_v2"),
1466            cu_mem_host_unregister: load_sym!(lib, "cuMemHostUnregister"),
1467            cu_mem_host_get_device_pointer_v2: load_sym!(lib, "cuMemHostGetDevicePointer_v2"),
1468            cu_pointer_get_attribute: load_sym!(lib, "cuPointerGetAttribute"),
1469            cu_mem_advise: load_sym!(lib, "cuMemAdvise"),
1470            cu_mem_prefetch_async: load_sym!(lib, "cuMemPrefetchAsync"),
1471
1472            // -- Stream management ---------------------------------------------
1473            cu_stream_create: load_sym!(lib, "cuStreamCreate"),
1474            cu_stream_create_with_priority: load_sym!(lib, "cuStreamCreateWithPriority"),
1475            cu_stream_destroy_v2: load_sym!(lib, "cuStreamDestroy_v2"),
1476            cu_stream_synchronize: load_sym!(lib, "cuStreamSynchronize"),
1477            cu_stream_wait_event: load_sym!(lib, "cuStreamWaitEvent"),
1478            cu_stream_query: load_sym!(lib, "cuStreamQuery"),
1479            cu_stream_get_priority: load_sym!(lib, "cuStreamGetPriority"),
1480            cu_stream_get_flags: load_sym!(lib, "cuStreamGetFlags"),
1481
1482            // -- Event management ----------------------------------------------
1483            cu_event_create: load_sym!(lib, "cuEventCreate"),
1484            cu_event_destroy_v2: load_sym!(lib, "cuEventDestroy_v2"),
1485            cu_event_record: load_sym!(lib, "cuEventRecord"),
1486            cu_event_query: load_sym!(lib, "cuEventQuery"),
1487            cu_event_synchronize: load_sym!(lib, "cuEventSynchronize"),
1488            cu_event_elapsed_time: load_sym!(lib, "cuEventElapsedTime"),
1489            cu_event_record_with_flags: load_sym_optional!(lib, "cuEventRecordWithFlags"),
1490
1491            // -- Peer memory access -------------------------------------------
1492            cu_memcpy_peer: load_sym!(lib, "cuMemcpyPeer"),
1493            cu_memcpy_peer_async: load_sym!(lib, "cuMemcpyPeerAsync"),
1494            cu_ctx_enable_peer_access: load_sym!(lib, "cuCtxEnablePeerAccess"),
1495            cu_ctx_disable_peer_access: load_sym!(lib, "cuCtxDisablePeerAccess"),
1496
1497            // -- Kernel launch -------------------------------------------------
1498            cu_launch_kernel: load_sym!(lib, "cuLaunchKernel"),
1499            cu_launch_cooperative_kernel: load_sym!(lib, "cuLaunchCooperativeKernel"),
1500            cu_launch_cooperative_kernel_multi_device: load_sym!(
1501                lib,
1502                "cuLaunchCooperativeKernelMultiDevice"
1503            ),
1504
1505            // -- Occupancy -----------------------------------------------------
1506            cu_occupancy_max_active_blocks_per_multiprocessor: load_sym!(
1507                lib,
1508                "cuOccupancyMaxActiveBlocksPerMultiprocessor"
1509            ),
1510            cu_occupancy_max_potential_block_size: load_sym!(
1511                lib,
1512                "cuOccupancyMaxPotentialBlockSize"
1513            ),
1514            cu_occupancy_max_active_blocks_per_multiprocessor_with_flags: load_sym!(
1515                lib,
1516                "cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags"
1517            ),
1518
1519            // -- Memory management (optional) ---------------------------------
1520            cu_memcpy_dtod_async_v2: load_sym_optional!(lib, "cuMemcpyDtoDAsync_v2"),
1521            cu_memset_d16_v2: load_sym_optional!(lib, "cuMemsetD16_v2"),
1522            cu_memset_d32_async: load_sym_optional!(lib, "cuMemsetD32Async"),
1523
1524            // -- Context management (optional) --------------------------------
1525            cu_ctx_get_limit: load_sym_optional!(lib, "cuCtxGetLimit"),
1526            cu_ctx_set_limit: load_sym_optional!(lib, "cuCtxSetLimit"),
1527            cu_ctx_get_cache_config: load_sym_optional!(lib, "cuCtxGetCacheConfig"),
1528            cu_ctx_set_cache_config: load_sym_optional!(lib, "cuCtxSetCacheConfig"),
1529            cu_ctx_get_shared_mem_config: load_sym_optional!(lib, "cuCtxGetSharedMemConfig"),
1530            cu_ctx_set_shared_mem_config: load_sym_optional!(lib, "cuCtxSetSharedMemConfig"),
1531
1532            // -- Function attributes (optional) -------------------------------
1533            cu_func_get_attribute: load_sym_optional!(lib, "cuFuncGetAttribute"),
1534            cu_func_set_cache_config: load_sym_optional!(lib, "cuFuncSetCacheConfig"),
1535            cu_func_set_shared_mem_config: load_sym_optional!(lib, "cuFuncSetSharedMemConfig"),
1536            cu_func_set_attribute: load_sym_optional!(lib, "cuFuncSetAttribute"),
1537
1538            // -- Profiler (optional) ------------------------------------------
1539            cu_profiler_start: load_sym_optional!(lib, "cuProfilerStart"),
1540            cu_profiler_stop: load_sym_optional!(lib, "cuProfilerStop"),
1541
1542            // -- CUDA 12.x extended launch (optional) -------------------------
1543            cu_launch_kernel_ex: load_sym_optional!(lib, "cuLaunchKernelEx"),
1544            cu_tensor_map_encode_tiled: load_sym_optional!(lib, "cuTensorMapEncodeTiled"),
1545
1546            // -- CUDA 12.8+ extended API (optional) ---------------------------
1547            cu_tensor_map_encode_tiled_memref: load_sym_optional!(
1548                lib,
1549                "cuTensorMapEncodeTiledMemref"
1550            ),
1551            cu_kernel_get_library: load_sym_optional!(lib, "cuKernelGetLibrary"),
1552            cu_multicast_get_granularity: load_sym_optional!(lib, "cuMulticastGetGranularity"),
1553            cu_multicast_create: load_sym_optional!(lib, "cuMulticastCreate"),
1554            cu_multicast_add_device: load_sym_optional!(lib, "cuMulticastAddDevice"),
1555            cu_memcpy_batch_async: load_sym_optional!(lib, "cuMemcpyBatchAsync"),
1556
1557            // -- Texture / Surface memory (optional) ---------------------------
1558            cu_array_create_v2: load_sym_optional!(lib, "cuArrayCreate_v2"),
1559            cu_array_destroy: load_sym_optional!(lib, "cuArrayDestroy"),
1560            cu_array_get_descriptor_v2: load_sym_optional!(lib, "cuArrayGetDescriptor_v2"),
1561            cu_array3d_create_v2: load_sym_optional!(lib, "cuArray3DCreate_v2"),
1562            cu_array3d_get_descriptor_v2: load_sym_optional!(lib, "cuArray3DGetDescriptor_v2"),
1563            cu_memcpy_htoa_v2: load_sym_optional!(lib, "cuMemcpyHtoA_v2"),
1564            cu_memcpy_atoh_v2: load_sym_optional!(lib, "cuMemcpyAtoH_v2"),
1565            cu_memcpy_htoa_async_v2: load_sym_optional!(lib, "cuMemcpyHtoAAsync_v2"),
1566            cu_memcpy_atoh_async_v2: load_sym_optional!(lib, "cuMemcpyAtoHAsync_v2"),
1567            cu_tex_object_create: load_sym_optional!(lib, "cuTexObjectCreate"),
1568            cu_tex_object_destroy: load_sym_optional!(lib, "cuTexObjectDestroy"),
1569            cu_tex_object_get_resource_desc: load_sym_optional!(lib, "cuTexObjectGetResourceDesc"),
1570            cu_surf_object_create: load_sym_optional!(lib, "cuSurfObjectCreate"),
1571            cu_surf_object_destroy: load_sym_optional!(lib, "cuSurfObjectDestroy"),
1572
1573            // -- JIT linker (optional) ----------------------------------------
1574            cu_link_create: load_sym_optional!(lib, "cuLinkCreate_v2"),
1575            cu_link_add_data: load_sym_optional!(lib, "cuLinkAddData_v2"),
1576            cu_link_complete: load_sym_optional!(lib, "cuLinkComplete"),
1577            cu_link_destroy: load_sym_optional!(lib, "cuLinkDestroy"),
1578
1579            // -- 2-D memory copy (optional) -----------------------------------
1580            cu_memcpy_2d: load_sym_optional!(lib, "cuMemcpy2D_v2"),
1581
1582            // -- VMM (optional, CUDA 11.2+) -----------------------------------
1583            cu_mem_address_reserve: load_sym_optional!(lib, "cuMemAddressReserve"),
1584            cu_mem_address_free: load_sym_optional!(lib, "cuMemAddressFree"),
1585            cu_mem_create: load_sym_optional!(lib, "cuMemCreate"),
1586            cu_mem_release: load_sym_optional!(lib, "cuMemRelease"),
1587            cu_mem_map: load_sym_optional!(lib, "cuMemMap"),
1588            cu_mem_unmap: load_sym_optional!(lib, "cuMemUnmap"),
1589            cu_mem_set_access: load_sym_optional!(lib, "cuMemSetAccess"),
1590
1591            // -- Stream-ordered memory pools (optional, CUDA 11.2+) -----------
1592            cu_mem_pool_create: load_sym_optional!(lib, "cuMemPoolCreate"),
1593            cu_mem_pool_destroy: load_sym_optional!(lib, "cuMemPoolDestroy"),
1594            cu_mem_alloc_from_pool_async: load_sym_optional!(lib, "cuMemAllocFromPoolAsync"),
1595            cu_mem_free_async: load_sym_optional!(lib, "cuMemFreeAsync"),
1596            cu_mem_alloc_async: load_sym_optional!(lib, "cuMemAllocAsync"),
1597            cu_mem_pool_trim_to: load_sym_optional!(lib, "cuMemPoolTrimTo"),
1598            cu_mem_pool_set_attribute: load_sym_optional!(lib, "cuMemPoolSetAttribute"),
1599            cu_mem_pool_get_attribute: load_sym_optional!(lib, "cuMemPoolGetAttribute"),
1600            cu_mem_pool_set_access: load_sym_optional!(lib, "cuMemPoolSetAccess"),
1601            cu_device_get_default_mem_pool: load_sym_optional!(lib, "cuDeviceGetDefaultMemPool"),
1602
1603            // -- CUDA Graph API (optional, CUDA 10.0+) ------------------------
1604            cu_graph_create: load_sym_optional!(lib, "cuGraphCreate"),
1605            cu_graph_destroy: load_sym_optional!(lib, "cuGraphDestroy"),
1606            cu_graph_add_kernel_node: load_sym_optional!(lib, "cuGraphAddKernelNode"),
1607            cu_graph_add_memcpy_node: load_sym_optional!(lib, "cuGraphAddMemcpyNode"),
1608            cu_graph_add_memset_node: load_sym_optional!(lib, "cuGraphAddMemsetNode"),
1609            cu_graph_add_empty_node: load_sym_optional!(lib, "cuGraphAddEmptyNode"),
1610            cu_graph_instantiate_with_flags: load_sym_optional!(lib, "cuGraphInstantiateWithFlags"),
1611            cu_graph_instantiate: load_sym_optional!(lib, "cuGraphInstantiate_v2"),
1612            cu_graph_exec_destroy: load_sym_optional!(lib, "cuGraphExecDestroy"),
1613            cu_graph_launch: load_sym_optional!(lib, "cuGraphLaunch"),
1614
1615            // Keep the library handle alive.
1616            _lib: lib,
1617        })
1618    }
1619}
1620
1621// ---------------------------------------------------------------------------
1622// Global accessor
1623// ---------------------------------------------------------------------------
1624
1625/// Get a reference to the lazily-loaded CUDA driver API function table.
1626///
1627/// On the first call, this function dynamically loads the CUDA shared library
1628/// and resolves all required symbols.  Subsequent calls return the cached
1629/// result with only an atomic load.
1630///
1631/// # Errors
1632///
1633/// Returns [`CudaError::NotInitialized`] if the driver could not be loaded —
1634/// for instance, on macOS, or on a system without an NVIDIA GPU driver
1635/// installed.
1636///
1637/// # Examples
1638///
1639/// ```rust,no_run
1640/// # use oxicuda_driver::loader::try_driver;
1641/// let api = try_driver()?;
1642/// let result = unsafe { (api.cu_init)(0) };
1643/// # Ok::<(), oxicuda_driver::error::CudaError>(())
1644/// ```
1645pub fn try_driver() -> CudaResult<&'static DriverApi> {
1646    let result = DRIVER.get_or_init(DriverApi::load);
1647    match result {
1648        Ok(api) => Ok(api),
1649        Err(_) => Err(CudaError::NotInitialized),
1650    }
1651}
1652
1653// ---------------------------------------------------------------------------
1654// Tests
1655// ---------------------------------------------------------------------------
1656
#[cfg(test)]
mod tests {
    use super::*;

    // Compile-time probe: expands to a closure that reads each listed
    // `DriverApi` field and calls `is_none()` on it.  The closure is never
    // invoked, so no `DriverApi` value (and therefore no loaded driver) is
    // needed; the test fails to *compile* if a field is missing or has no
    // `is_none()` method.
    macro_rules! probe_optional_fields {
        ($($field:ident),+ $(,)?) => {{
            let _probe = |api: &DriverApi| {
                $( let _ = api.$field.is_none(); )+
            };
        }};
    }

    /// `DriverApi::load` must fail with `UnsupportedPlatform` on macOS.
    #[cfg(target_os = "macos")]
    #[test]
    fn load_returns_unsupported_on_macos() {
        let Err(err) = DriverApi::load() else {
            panic!("expected Err on macOS");
        };
        assert!(
            matches!(err, DriverLoadError::UnsupportedPlatform),
            "expected UnsupportedPlatform, got {err:?}"
        );
    }

    /// `try_driver` must surface the macOS load failure as
    /// `CudaError::NotInitialized`.
    #[cfg(target_os = "macos")]
    #[test]
    fn try_driver_returns_not_initialized_on_macos() {
        let Err(err) = try_driver() else {
            panic!("expected Err on macOS");
        };
        assert!(
            matches!(err, CudaError::NotInitialized),
            "expected NotInitialized, got {err:?}"
        );
    }

    // -----------------------------------------------------------------------
    // Task 1 — CUDA 12.8+ DriverApi field-presence tests
    //
    // Each test only has to compile: it names `DriverApi` fields inside a
    // closure that is never called, so no GPU or driver library is touched
    // and no function pointer is ever invoked.
    // -----------------------------------------------------------------------

    /// The `cu_tensor_map_encode_tiled_memref` field exists and exposes
    /// `is_none()`.
    #[test]
    fn driver_v12_8_api_fields_present() {
        use std::ffi::c_void;

        // Shape of a tensor-map encode entry point.  It is only checked for
        // being a valid `Option` payload; it is not compared against the
        // declared type of the field probed below.
        type TensorMapEncodeTiledFn = unsafe extern "C" fn(
            tensor_map: *mut c_void,
            tensor_data_type: u32,
            tensor_rank: u32,
            global_address: *mut c_void,
            global_dim: *const u64,
            global_strides: *const u64,
            box_dim: *const u32,
            element_strides: *const u32,
            interleave: u32,
            swizzle: u32,
            l2_promotion: u32,
            oob_fill: u32,
            flags: u64,
        ) -> CUresult;
        let _unset: Option<TensorMapEncodeTiledFn> = None;

        probe_optional_fields!(cu_tensor_map_encode_tiled_memref);
    }

    /// The three multicast fields (CUDA 12.8+) exist and expose `is_none()`.
    #[test]
    fn driver_v12_8_multicast_fields_present() {
        probe_optional_fields!(
            cu_multicast_create,
            cu_multicast_add_device,
            cu_multicast_get_granularity,
        );
    }

    /// The `cu_memcpy_batch_async` field (CUDA 12.8+ batch memcpy) exists and
    /// exposes `is_none()`.
    #[test]
    fn driver_v12_8_batch_memcpy_field_present() {
        probe_optional_fields!(cu_memcpy_batch_async);
    }

    /// The `cu_kernel_get_library` field (CUDA 12.8+) exists and exposes
    /// `is_none()`.
    #[test]
    fn driver_v12_8_kernel_get_library_field_present() {
        probe_optional_fields!(cu_kernel_get_library);
    }

    // -----------------------------------------------------------------------
    // Wave 1 — Extended Driver API field-presence tests
    // -----------------------------------------------------------------------

    /// The four `cuLink*` JIT linker fields exist and expose `is_none()`.
    #[test]
    fn driver_link_api_fields_present() {
        probe_optional_fields!(
            cu_link_create,
            cu_link_add_data,
            cu_link_complete,
            cu_link_destroy,
        );
    }

    /// The `cu_memcpy_2d` field (`cuMemcpy2D_v2`) exists and exposes
    /// `is_none()`.
    #[test]
    fn driver_memcpy_2d_field_present() {
        probe_optional_fields!(cu_memcpy_2d);
    }

    /// The seven VMM (CUDA 11.2+) fields exist and expose `is_none()`.
    #[test]
    fn driver_vmm_api_fields_present() {
        probe_optional_fields!(
            cu_mem_address_reserve,
            cu_mem_address_free,
            cu_mem_create,
            cu_mem_release,
            cu_mem_map,
            cu_mem_unmap,
            cu_mem_set_access,
        );
    }

    /// The four core memory-pool (CUDA 11.2+) fields exist and expose
    /// `is_none()`.
    #[test]
    fn driver_mem_pool_api_fields_present() {
        probe_optional_fields!(
            cu_mem_pool_create,
            cu_mem_pool_destroy,
            cu_mem_alloc_from_pool_async,
            cu_mem_free_async,
        );
    }

    /// The extended stream-ordered memory-pool fields exist and expose
    /// `is_none()`.
    #[test]
    fn driver_stream_ordered_pool_extended_fields_present() {
        probe_optional_fields!(
            cu_mem_alloc_async,
            cu_mem_pool_trim_to,
            cu_mem_pool_set_attribute,
            cu_mem_pool_get_attribute,
            cu_mem_pool_set_access,
            cu_device_get_default_mem_pool,
        );
    }

    /// The ten CUDA Graph API fields exist and expose `is_none()`.
    #[test]
    fn driver_graph_api_fields_present() {
        probe_optional_fields!(
            cu_graph_create,
            cu_graph_destroy,
            cu_graph_add_kernel_node,
            cu_graph_add_memcpy_node,
            cu_graph_add_memset_node,
            cu_graph_add_empty_node,
            cu_graph_instantiate_with_flags,
            cu_graph_instantiate,
            cu_graph_exec_destroy,
            cu_graph_launch,
        );
    }
}