Skip to main content

baracuda_cuda_sys/
driver.rs

1//! The [`Driver`] singleton: a lazily-loaded handle to `libcuda` with a
2//! cached, version-aware function-pointer table.
3//!
4//! Typical use from a safe crate:
5//!
6//! ```no_run
7//! use baracuda_cuda_sys::driver;
8//! let d = driver()?;
9//! let cu_init = d.cu_init()?;
10//! // SAFETY: we just resolved the symbol; calling it matches the CUDA ABI.
11//! unsafe { cu_init(0) };
12//! # Ok::<(), baracuda_core::LoaderError>(())
13//! ```
14//!
15//! Function names are normalized from the C `cuInit` / `cuDeviceGet` form
16//! to `snake_case` (`cu_init`, `cu_device_get`) so they don't clash with
17//! Rust's naming lints.
18
19use core::ffi::c_char;
20use std::ptr;
21use std::sync::OnceLock;
22
23use baracuda_core::{platform, stream_mode, Library, LoaderError};
24use baracuda_types::StreamMode;
25
26use crate::functions::*;
27use crate::status::CUresult;
28
/// Flag passed to `cuGetProcAddress` in [`StreamMode::Legacy`]
/// (`CU_GET_PROC_ADDRESS_LEGACY_STREAM` in `cuda.h`).
const CU_GET_PROC_ADDRESS_LEGACY_STREAM: u64 = 1;
/// Flag passed to `cuGetProcAddress` in [`StreamMode::PerThread`]
/// (`CU_GET_PROC_ADDRESS_PER_THREAD_DEFAULT_STREAM` in `cuda.h`).
const CU_GET_PROC_ADDRESS_PER_THREAD_DEFAULT_STREAM: u64 = 2;

/// `CU_GET_PROC_ADDRESS_VERSION_NOT_SUFFICIENT` from `cuda.h`
/// (`CUdriverProcAddressQueryResult`). Reported through the `symbolStatus`
/// out-parameter of `cuGetProcAddress` when the symbol exists, but only in
/// newer CUDA versions than the one we asked for.
///
/// The enumerators are `SUCCESS = 0`, `SYMBOL_NOT_FOUND = 1`,
/// `VERSION_NOT_SUFFICIENT = 2`; this is not a `CUresult` error code.
// Not referenced in the part of this file that was reviewed; kept so the
// query-result value is on hand for the resolver.
#[allow(dead_code)]
const CU_GET_PROC_ADDRESS_VERSION_NOT_SUFFICIENT: i32 = 2;
39
// Generates the `Driver` function-pointer table.
//
// Input is a list of entries of the form
//
//     fn <accessor> as "<symbol>": <PFN type>;
//
// optionally preceded by attributes / doc comments. For every entry the
// macro emits:
//   * a private `OnceLock<PFN type>` field named `<accessor>` on `Driver`;
//   * a public method `<accessor>(&self) -> Result<PFN type, LoaderError>`
//     that resolves `<symbol>` on first use and returns the cached pointer
//     afterwards.
//
// It also emits `Driver::empty` (all slots unresolved) and a `Debug` impl
// that prints only the `lib` field.
//
// The expansion relies on `Library`, `LoaderError`, `OnceLock`,
// `PFN_cuGetProcAddress` and an `unsafe fn resolve` on `Driver` being
// available at the invocation site.
macro_rules! driver_fns {
    ($(
        $(#[$attr:meta])*
        fn $name:ident as $sym:literal : $pfn:ty;
    )*) => {
        /// Lazily-resolved CUDA Driver API function-pointer table.
        // Field names are taken verbatim from the macro input, so the lint is
        // silenced in case an entry is not written in snake_case.
        #[allow(non_snake_case)]
        pub struct Driver {
            lib: Library,
            get_proc_address: OnceLock<PFN_cuGetProcAddress>,
            $(
                $name: OnceLock<$pfn>,
            )*
        }

        // Hand-written so only `lib` is printed; the per-symbol slots are
        // elided via `finish_non_exhaustive`.
        impl core::fmt::Debug for Driver {
            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                f.debug_struct("Driver")
                    .field("lib", &self.lib)
                    .finish_non_exhaustive()
            }
        }

        impl Driver {
            /// Build a table around `lib` with every slot still unresolved.
            fn empty(lib: Library) -> Self {
                Self {
                    lib,
                    get_proc_address: OnceLock::new(),
                    $(
                        $name: OnceLock::new(),
                    )*
                }
            }

            $(
                $(#[$attr])*
                // Method names come from the macro input as well (see the
                // note on the struct).
                #[allow(non_snake_case)]
                #[doc = concat!("Resolve `", $sym, "` and return the cached function pointer.")]
                pub fn $name(&self) -> Result<$pfn, LoaderError> {
                    // Fast path: the symbol was already resolved.
                    if let Some(&cached) = self.$name.get() {
                        return Ok(cached);
                    }

                    // SAFETY: `resolve` is defined outside this macro. The
                    // assumption made here is that its contract is "the
                    // requested pointer type must match the symbol's ABI",
                    // which each `driver_fns!` entry upholds by pairing
                    // `$sym` with its `$pfn` by hand.
                    // NOTE(review): confirm against `Driver::resolve`'s
                    // `# Safety` section.
                    let resolved: $pfn = unsafe { self.resolve($sym)? };

                    // A failed resolution is not cached (the `?` above
                    // returns early), so a later call retries. On success,
                    // if a concurrent caller filled the slot first, return
                    // the stored pointer so every caller sees the same value.
                    Ok(*self.$name.get_or_init(|| resolved))
                }
            )*
        }
    };
}
88
89// Naming convention: symbols we've pinned to a specific ABI (our PFN_*
90// signature matches a particular `_vN` variant) use the explicit versioned
91// name so the resolver goes through dlsym and the driver cannot upgrade us
92// to a newer, incompatible ABI. Symbols we haven't pinned use the base
93// name so cuGetProcAddress can pick the best variant for the installed
94// driver (and the _ptsz variant under per-thread-default-stream mode).
95driver_fns! {
96    // Initialization & version
97    fn cu_init as "cuInit": PFN_cuInit;
98    fn cu_driver_get_version as "cuDriverGetVersion": PFN_cuDriverGetVersion;
99
100    // Errors (note: the strings returned are owned by the driver; do not free)
101    fn cu_get_error_name as "cuGetErrorName": PFN_cuGetErrorName;
102    fn cu_get_error_string as "cuGetErrorString": PFN_cuGetErrorString;
103
104    // Device
105    fn cu_device_get_count as "cuDeviceGetCount": PFN_cuDeviceGetCount;
106    fn cu_device_get as "cuDeviceGet": PFN_cuDeviceGet;
107    fn cu_device_get_name as "cuDeviceGetName": PFN_cuDeviceGetName;
108    fn cu_device_get_attribute as "cuDeviceGetAttribute": PFN_cuDeviceGetAttribute;
109    fn cu_device_total_mem as "cuDeviceTotalMem_v2": PFN_cuDeviceTotalMem;
110
111    // Context — pinned to _v2 because cuCtxCreate_v3 (CUDA 11.4) / _v4 (12.5)
112    // have extra parameters and would be returned by cuGetProcAddress.
113    fn cu_ctx_create as "cuCtxCreate_v2": PFN_cuCtxCreate;
114    fn cu_ctx_destroy as "cuCtxDestroy_v2": PFN_cuCtxDestroy;
115    fn cu_ctx_get_current as "cuCtxGetCurrent": PFN_cuCtxGetCurrent;
116    fn cu_ctx_set_current as "cuCtxSetCurrent": PFN_cuCtxSetCurrent;
117    fn cu_ctx_push_current as "cuCtxPushCurrent_v2": PFN_cuCtxPushCurrent;
118    fn cu_ctx_pop_current as "cuCtxPopCurrent_v2": PFN_cuCtxPopCurrent;
119    fn cu_ctx_synchronize as "cuCtxSynchronize": PFN_cuCtxSynchronize;
120
121    // Primary context
122    fn cu_device_primary_ctx_retain as "cuDevicePrimaryCtxRetain": PFN_cuDevicePrimaryCtxRetain;
123    fn cu_device_primary_ctx_release as "cuDevicePrimaryCtxRelease_v2": PFN_cuDevicePrimaryCtxRelease;
124    fn cu_device_primary_ctx_reset as "cuDevicePrimaryCtxReset_v2": PFN_cuDevicePrimaryCtxReset;
125
126    // Memory — pinned to _v2 (64-bit addresses, stable since CUDA 3.2).
127    fn cu_mem_alloc as "cuMemAlloc_v2": PFN_cuMemAlloc;
128    fn cu_mem_free as "cuMemFree_v2": PFN_cuMemFree;
129    fn cu_memcpy_htod as "cuMemcpyHtoD_v2": PFN_cuMemcpyHtoD;
130    fn cu_memcpy_dtoh as "cuMemcpyDtoH_v2": PFN_cuMemcpyDtoH;
131    fn cu_memcpy_dtod as "cuMemcpyDtoD_v2": PFN_cuMemcpyDtoD;
132    fn cu_memcpy_htod_async as "cuMemcpyHtoDAsync_v2": PFN_cuMemcpyHtoDAsync;
133    fn cu_memcpy_dtoh_async as "cuMemcpyDtoHAsync_v2": PFN_cuMemcpyDtoHAsync;
134    fn cu_memset_d8 as "cuMemsetD8_v2": PFN_cuMemsetD8;
135    fn cu_memset_d32 as "cuMemsetD32_v2": PFN_cuMemsetD32;
136
137    // Stream
138    fn cu_stream_create as "cuStreamCreate": PFN_cuStreamCreate;
139    fn cu_stream_destroy as "cuStreamDestroy_v2": PFN_cuStreamDestroy;
140    fn cu_stream_synchronize as "cuStreamSynchronize": PFN_cuStreamSynchronize;
141    fn cu_stream_query as "cuStreamQuery": PFN_cuStreamQuery;
142    fn cu_stream_wait_event as "cuStreamWaitEvent": PFN_cuStreamWaitEvent;
143
144    // Event
145    fn cu_event_create as "cuEventCreate": PFN_cuEventCreate;
146    fn cu_event_destroy as "cuEventDestroy_v2": PFN_cuEventDestroy;
147    fn cu_event_record as "cuEventRecord": PFN_cuEventRecord;
148    fn cu_event_synchronize as "cuEventSynchronize": PFN_cuEventSynchronize;
149    fn cu_event_query as "cuEventQuery": PFN_cuEventQuery;
150    fn cu_event_elapsed_time as "cuEventElapsedTime": PFN_cuEventElapsedTime;
151
152    // Module / kernel
153    fn cu_module_load_data as "cuModuleLoadData": PFN_cuModuleLoadData;
154    fn cu_module_unload as "cuModuleUnload": PFN_cuModuleUnload;
155    fn cu_module_get_function as "cuModuleGetFunction": PFN_cuModuleGetFunction;
156    fn cu_launch_kernel as "cuLaunchKernel": PFN_cuLaunchKernel;
157
158    // Stream capture / graphs — _v2 of cuStreamBeginCapture pins the 2-arg signature
159    // (which is what CUDA 10.1+ shipped; the older 1-arg form is deprecated).
160    fn cu_stream_begin_capture as "cuStreamBeginCapture_v2": PFN_cuStreamBeginCapture;
161    fn cu_stream_end_capture as "cuStreamEndCapture": PFN_cuStreamEndCapture;
162    fn cu_stream_is_capturing as "cuStreamIsCapturing": PFN_cuStreamIsCapturing;
163    fn cu_graph_create as "cuGraphCreate": PFN_cuGraphCreate;
164    fn cu_graph_destroy as "cuGraphDestroy": PFN_cuGraphDestroy;
165    fn cu_graph_instantiate_with_flags as "cuGraphInstantiateWithFlags": PFN_cuGraphInstantiateWithFlags;
166    fn cu_graph_launch as "cuGraphLaunch": PFN_cuGraphLaunch;
167    fn cu_graph_exec_destroy as "cuGraphExecDestroy": PFN_cuGraphExecDestroy;
168    fn cu_graph_get_nodes as "cuGraphGetNodes": PFN_cuGraphGetNodes;
169
170    // Stream-ordered memory allocation (CUDA 11.2+, available at our 11.4 floor).
171    fn cu_mem_alloc_async as "cuMemAllocAsync": PFN_cuMemAllocAsync;
172    fn cu_mem_free_async as "cuMemFreeAsync": PFN_cuMemFreeAsync;
173
174    // ---- Wave 1 additions ----
175
176    // Occupancy
177    fn cu_occupancy_max_active_blocks_per_multiprocessor as "cuOccupancyMaxActiveBlocksPerMultiprocessor": PFN_cuOccupancyMaxActiveBlocksPerMultiprocessor;
178    fn cu_occupancy_max_active_blocks_per_multiprocessor_with_flags as "cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags": PFN_cuOccupancyMaxActiveBlocksPerMultiprocessorWithFlags;
179    fn cu_occupancy_max_potential_block_size as "cuOccupancyMaxPotentialBlockSize": PFN_cuOccupancyMaxPotentialBlockSize;
180    fn cu_occupancy_available_dynamic_smem_per_block as "cuOccupancyAvailableDynamicSMemPerBlock": PFN_cuOccupancyAvailableDynamicSMemPerBlock;
181
182    // Unified memory
183    fn cu_mem_alloc_managed as "cuMemAllocManaged": PFN_cuMemAllocManaged;
184    fn cu_mem_advise as "cuMemAdvise": PFN_cuMemAdvise;
185    fn cu_mem_prefetch_async as "cuMemPrefetchAsync": PFN_cuMemPrefetchAsync;
186    fn cu_mem_get_info as "cuMemGetInfo_v2": PFN_cuMemGetInfo;
187
188    // Context queries/config
189    fn cu_ctx_get_device as "cuCtxGetDevice": PFN_cuCtxGetDevice;
190    fn cu_ctx_get_api_version as "cuCtxGetApiVersion": PFN_cuCtxGetApiVersion;
191    fn cu_ctx_get_flags as "cuCtxGetFlags": PFN_cuCtxGetFlags;
192    fn cu_ctx_get_limit as "cuCtxGetLimit": PFN_cuCtxGetLimit;
193    fn cu_ctx_set_limit as "cuCtxSetLimit": PFN_cuCtxSetLimit;
194    fn cu_ctx_get_cache_config as "cuCtxGetCacheConfig": PFN_cuCtxGetCacheConfig;
195    fn cu_ctx_set_cache_config as "cuCtxSetCacheConfig": PFN_cuCtxSetCacheConfig;
196    fn cu_ctx_get_stream_priority_range as "cuCtxGetStreamPriorityRange": PFN_cuCtxGetStreamPriorityRange;
197
198    // Peer
199    fn cu_device_can_access_peer as "cuDeviceCanAccessPeer": PFN_cuDeviceCanAccessPeer;
200    fn cu_ctx_enable_peer_access as "cuCtxEnablePeerAccess": PFN_cuCtxEnablePeerAccess;
201    fn cu_ctx_disable_peer_access as "cuCtxDisablePeerAccess": PFN_cuCtxDisablePeerAccess;
202
203    // Pointer attributes
204    fn cu_pointer_get_attribute as "cuPointerGetAttribute": PFN_cuPointerGetAttribute;
205
206    // Stream priority + host func
207    fn cu_stream_create_with_priority as "cuStreamCreateWithPriority": PFN_cuStreamCreateWithPriority;
208    fn cu_stream_get_priority as "cuStreamGetPriority": PFN_cuStreamGetPriority;
209    fn cu_stream_get_flags as "cuStreamGetFlags": PFN_cuStreamGetFlags;
210    fn cu_stream_get_ctx as "cuStreamGetCtx": PFN_cuStreamGetCtx;
211    fn cu_launch_host_func as "cuLaunchHostFunc": PFN_cuLaunchHostFunc;
212
213    // Event flags
214    fn cu_event_record_with_flags as "cuEventRecordWithFlags": PFN_cuEventRecordWithFlags;
215
216    // Primary-context state
217    fn cu_device_primary_ctx_get_state as "cuDevicePrimaryCtxGetState": PFN_cuDevicePrimaryCtxGetState;
218    fn cu_device_primary_ctx_set_flags as "cuDevicePrimaryCtxSetFlags_v2": PFN_cuDevicePrimaryCtxSetFlags;
219
220    // ---- Wave 2 ----
221    fn cu_func_get_attribute as "cuFuncGetAttribute": PFN_cuFuncGetAttribute;
222    fn cu_func_set_attribute as "cuFuncSetAttribute": PFN_cuFuncSetAttribute;
223    fn cu_module_get_global as "cuModuleGetGlobal_v2": PFN_cuModuleGetGlobal;
224    fn cu_module_load_data_ex as "cuModuleLoadDataEx": PFN_cuModuleLoadDataEx;
225
226    // ---- Wave 3: extensible launch + library management (CUDA 12.0+) ----
227    fn cu_launch_kernel_ex as "cuLaunchKernelEx": PFN_cuLaunchKernelEx;
228    fn cu_library_load_data as "cuLibraryLoadData": PFN_cuLibraryLoadData;
229    fn cu_library_unload as "cuLibraryUnload": PFN_cuLibraryUnload;
230    fn cu_library_get_kernel as "cuLibraryGetKernel": PFN_cuLibraryGetKernel;
231    fn cu_library_get_global as "cuLibraryGetGlobal": PFN_cuLibraryGetGlobal;
232    fn cu_kernel_get_function as "cuKernelGetFunction": PFN_cuKernelGetFunction;
233
234    // ---- Wave 4: 2D alloc + memcpy ----
235    fn cu_mem_alloc_pitch as "cuMemAllocPitch_v2": PFN_cuMemAllocPitch;
236    fn cu_memcpy_2d as "cuMemcpy2D_v2": PFN_cuMemcpy2D;
237    fn cu_memcpy_2d_async as "cuMemcpy2DAsync_v2": PFN_cuMemcpy2DAsync;
238
239    // ---- Wave 5: explicit graph node construction ----
240    // `cuGraphAddKernelNode_v2` is pinned because our CUDA_KERNEL_NODE_PARAMS
241    // matches the v2 shape (kern + ctx fields). The _v2 suffix routes this
242    // through dlsym so the driver can't upgrade us to a future ABI.
243    fn cu_graph_add_kernel_node as "cuGraphAddKernelNode_v2": PFN_cuGraphAddKernelNode;
244    fn cu_graph_add_empty_node as "cuGraphAddEmptyNode": PFN_cuGraphAddEmptyNode;
245    fn cu_graph_add_memset_node as "cuGraphAddMemsetNode": PFN_cuGraphAddMemsetNode;
246    fn cu_graph_destroy_node as "cuGraphDestroyNode": PFN_cuGraphDestroyNode;
247    fn cu_graph_clone as "cuGraphClone": PFN_cuGraphClone;
248
249    // ---- Wave 6: arrays, textures, surfaces ----
250    fn cu_array_create as "cuArrayCreate_v2": PFN_cuArrayCreate;
251    fn cu_array_destroy as "cuArrayDestroy": PFN_cuArrayDestroy;
252    fn cu_tex_object_create as "cuTexObjectCreate": PFN_cuTexObjectCreate;
253    fn cu_tex_object_destroy as "cuTexObjectDestroy": PFN_cuTexObjectDestroy;
254    fn cu_surf_object_create as "cuSurfObjectCreate": PFN_cuSurfObjectCreate;
255    fn cu_surf_object_destroy as "cuSurfObjectDestroy": PFN_cuSurfObjectDestroy;
256
257    // ---- Wave 7: virtual memory management (VMM) ----
258    fn cu_mem_address_reserve as "cuMemAddressReserve": PFN_cuMemAddressReserve;
259    fn cu_mem_address_free as "cuMemAddressFree": PFN_cuMemAddressFree;
260    fn cu_mem_create as "cuMemCreate": PFN_cuMemCreate;
261    fn cu_mem_release as "cuMemRelease": PFN_cuMemRelease;
262    fn cu_mem_map as "cuMemMap": PFN_cuMemMap;
263    fn cu_mem_unmap as "cuMemUnmap": PFN_cuMemUnmap;
264    fn cu_mem_set_access as "cuMemSetAccess": PFN_cuMemSetAccess;
265    fn cu_mem_get_allocation_granularity as "cuMemGetAllocationGranularity": PFN_cuMemGetAllocationGranularity;
266
267    // ---- Wave 8: memory pools ----
268    fn cu_mem_pool_create as "cuMemPoolCreate": PFN_cuMemPoolCreate;
269    fn cu_mem_pool_destroy as "cuMemPoolDestroy": PFN_cuMemPoolDestroy;
270    fn cu_mem_pool_set_attribute as "cuMemPoolSetAttribute": PFN_cuMemPoolSetAttribute;
271    fn cu_mem_pool_get_attribute as "cuMemPoolGetAttribute": PFN_cuMemPoolGetAttribute;
272    fn cu_mem_pool_trim_to as "cuMemPoolTrimTo": PFN_cuMemPoolTrimTo;
273    fn cu_mem_pool_set_access as "cuMemPoolSetAccess": PFN_cuMemPoolSetAccess;
274    fn cu_mem_pool_get_access as "cuMemPoolGetAccess": PFN_cuMemPoolGetAccess;
275    fn cu_mem_alloc_from_pool_async as "cuMemAllocFromPoolAsync": PFN_cuMemAllocFromPoolAsync;
276    fn cu_device_get_default_mem_pool as "cuDeviceGetDefaultMemPool": PFN_cuDeviceGetDefaultMemPool;
277    fn cu_device_get_mem_pool as "cuDeviceGetMemPool": PFN_cuDeviceGetMemPool;
278    fn cu_device_set_mem_pool as "cuDeviceSetMemPool": PFN_cuDeviceSetMemPool;
279    fn cu_mem_pool_export_to_shareable_handle as "cuMemPoolExportToShareableHandle": PFN_cuMemPoolExportToShareableHandle;
280    fn cu_mem_pool_import_from_shareable_handle as "cuMemPoolImportFromShareableHandle": PFN_cuMemPoolImportFromShareableHandle;
281    fn cu_mem_pool_export_pointer as "cuMemPoolExportPointer": PFN_cuMemPoolExportPointer;
282    fn cu_mem_pool_import_pointer as "cuMemPoolImportPointer": PFN_cuMemPoolImportPointer;
283
284    // ---- Wave 9: external memory / semaphore interop ----
285    fn cu_import_external_memory as "cuImportExternalMemory": PFN_cuImportExternalMemory;
286    fn cu_destroy_external_memory as "cuDestroyExternalMemory": PFN_cuDestroyExternalMemory;
287    fn cu_external_memory_get_mapped_buffer as "cuExternalMemoryGetMappedBuffer": PFN_cuExternalMemoryGetMappedBuffer;
288    fn cu_external_memory_get_mapped_mipmapped_array as "cuExternalMemoryGetMappedMipmappedArray": PFN_cuExternalMemoryGetMappedMipmappedArray;
289    fn cu_import_external_semaphore as "cuImportExternalSemaphore": PFN_cuImportExternalSemaphore;
290    fn cu_destroy_external_semaphore as "cuDestroyExternalSemaphore": PFN_cuDestroyExternalSemaphore;
291    fn cu_signal_external_semaphores_async as "cuSignalExternalSemaphoresAsync": PFN_cuSignalExternalSemaphoresAsync;
292    fn cu_wait_external_semaphores_async as "cuWaitExternalSemaphoresAsync": PFN_cuWaitExternalSemaphoresAsync;
293
294    // ---- Wave 10: 3D memcpy + 3D arrays + mipmapped arrays ----
295    fn cu_array_3d_create as "cuArray3DCreate_v2": PFN_cuArray3DCreate;
296    fn cu_array_3d_get_descriptor as "cuArray3DGetDescriptor_v2": PFN_cuArray3DGetDescriptor;
297    fn cu_memcpy_3d as "cuMemcpy3D_v2": PFN_cuMemcpy3D;
298    fn cu_memcpy_3d_async as "cuMemcpy3DAsync_v2": PFN_cuMemcpy3DAsync;
299    fn cu_mipmapped_array_create as "cuMipmappedArrayCreate": PFN_cuMipmappedArrayCreate;
300    fn cu_mipmapped_array_destroy as "cuMipmappedArrayDestroy": PFN_cuMipmappedArrayDestroy;
301    fn cu_mipmapped_array_get_level as "cuMipmappedArrayGetLevel": PFN_cuMipmappedArrayGetLevel;
302
303    // ---- Wave 11: pinned host memory ----
304    fn cu_mem_alloc_host as "cuMemAllocHost_v2": PFN_cuMemAllocHost;
305    fn cu_mem_free_host as "cuMemFreeHost": PFN_cuMemFreeHost;
306    fn cu_mem_host_alloc as "cuMemHostAlloc": PFN_cuMemHostAlloc;
307    fn cu_mem_host_register as "cuMemHostRegister_v2": PFN_cuMemHostRegister;
308    fn cu_mem_host_unregister as "cuMemHostUnregister": PFN_cuMemHostUnregister;
309    fn cu_mem_host_get_device_pointer as "cuMemHostGetDevicePointer_v2": PFN_cuMemHostGetDevicePointer;
310    fn cu_mem_host_get_flags as "cuMemHostGetFlags": PFN_cuMemHostGetFlags;
311
312    // ---- Wave 12: full graph node builders + edit ----
313    fn cu_graph_add_memcpy_node as "cuGraphAddMemcpyNode": PFN_cuGraphAddMemcpyNode;
314    fn cu_graph_add_host_node as "cuGraphAddHostNode": PFN_cuGraphAddHostNode;
315    fn cu_graph_add_child_graph_node as "cuGraphAddChildGraphNode": PFN_cuGraphAddChildGraphNode;
316    fn cu_graph_add_event_record_node as "cuGraphAddEventRecordNode": PFN_cuGraphAddEventRecordNode;
317    fn cu_graph_add_event_wait_node as "cuGraphAddEventWaitNode": PFN_cuGraphAddEventWaitNode;
318    fn cu_graph_add_external_semaphores_signal_node as "cuGraphAddExternalSemaphoresSignalNode": PFN_cuGraphAddExternalSemaphoresSignalNode;
319    fn cu_graph_add_external_semaphores_wait_node as "cuGraphAddExternalSemaphoresWaitNode": PFN_cuGraphAddExternalSemaphoresWaitNode;
320    // Node-param get/set pinned to v2 variants (match our v2 struct shape).
321    fn cu_graph_kernel_node_get_params as "cuGraphKernelNodeGetParams_v2": PFN_cuGraphKernelNodeGetParams;
322    fn cu_graph_kernel_node_set_params as "cuGraphKernelNodeSetParams_v2": PFN_cuGraphKernelNodeSetParams;
323    fn cu_graph_memcpy_node_get_params as "cuGraphMemcpyNodeGetParams": PFN_cuGraphMemcpyNodeGetParams;
324    fn cu_graph_memcpy_node_set_params as "cuGraphMemcpyNodeSetParams": PFN_cuGraphMemcpyNodeSetParams;
325    fn cu_graph_memset_node_get_params as "cuGraphMemsetNodeGetParams": PFN_cuGraphMemsetNodeGetParams;
326    fn cu_graph_memset_node_set_params as "cuGraphMemsetNodeSetParams": PFN_cuGraphMemsetNodeSetParams;
327    fn cu_graph_node_get_type as "cuGraphNodeGetType": PFN_cuGraphNodeGetType;
328    fn cu_graph_node_get_dependencies as "cuGraphNodeGetDependencies": PFN_cuGraphNodeGetDependencies;
329    fn cu_graph_node_get_dependent_nodes as "cuGraphNodeGetDependentNodes": PFN_cuGraphNodeGetDependentNodes;
330    fn cu_graph_get_edges as "cuGraphGetEdges": PFN_cuGraphGetEdges;
331    fn cu_graph_add_dependencies as "cuGraphAddDependencies": PFN_cuGraphAddDependencies;
332    fn cu_graph_remove_dependencies as "cuGraphRemoveDependencies": PFN_cuGraphRemoveDependencies;
333    fn cu_graph_exec_kernel_node_set_params as "cuGraphExecKernelNodeSetParams_v2": PFN_cuGraphExecKernelNodeSetParams;
334    fn cu_graph_exec_memcpy_node_set_params as "cuGraphExecMemcpyNodeSetParams": PFN_cuGraphExecMemcpyNodeSetParams;
335    fn cu_graph_exec_memset_node_set_params as "cuGraphExecMemsetNodeSetParams": PFN_cuGraphExecMemsetNodeSetParams;
336    fn cu_graph_exec_host_node_set_params as "cuGraphExecHostNodeSetParams": PFN_cuGraphExecHostNodeSetParams;
337
338    // ---- Wave 13: stream extras ----
339    fn cu_stream_get_id as "cuStreamGetId": PFN_cuStreamGetId;
340    fn cu_stream_copy_attributes as "cuStreamCopyAttributes": PFN_cuStreamCopyAttributes;
341    fn cu_stream_get_attribute as "cuStreamGetAttribute": PFN_cuStreamGetAttribute;
342    fn cu_stream_set_attribute as "cuStreamSetAttribute": PFN_cuStreamSetAttribute;
343    fn cu_stream_attach_mem_async as "cuStreamAttachMemAsync": PFN_cuStreamAttachMemAsync;
344    fn cu_stream_get_capture_info as "cuStreamGetCaptureInfo_v2": PFN_cuStreamGetCaptureInfo;
345    fn cu_stream_update_capture_dependencies as "cuStreamUpdateCaptureDependencies": PFN_cuStreamUpdateCaptureDependencies;
346
347    // ---- Wave 14: misc memcpy variants ----
348    fn cu_memcpy_dtod_async as "cuMemcpyDtoDAsync_v2": PFN_cuMemcpyDtoDAsync;
349    fn cu_memcpy_peer as "cuMemcpyPeer": PFN_cuMemcpyPeer;
350    fn cu_memcpy_peer_async as "cuMemcpyPeerAsync": PFN_cuMemcpyPeerAsync;
351    fn cu_memcpy as "cuMemcpy": PFN_cuMemcpy;
352    fn cu_memcpy_async as "cuMemcpyAsync": PFN_cuMemcpyAsync;
353    fn cu_memcpy_atoh as "cuMemcpyAtoH_v2": PFN_cuMemcpyAtoH;
354    fn cu_memcpy_htoa as "cuMemcpyHtoA_v2": PFN_cuMemcpyHtoA;
355    fn cu_memcpy_atod as "cuMemcpyAtoD_v2": PFN_cuMemcpyAtoD;
356    fn cu_memcpy_dtoa as "cuMemcpyDtoA_v2": PFN_cuMemcpyDtoA;
357    fn cu_memcpy_atoa as "cuMemcpyAtoA_v2": PFN_cuMemcpyAtoA;
358    fn cu_memset_d16 as "cuMemsetD16_v2": PFN_cuMemsetD16;
359    fn cu_memset_d8_async as "cuMemsetD8Async": PFN_cuMemsetD8Async;
360    fn cu_memset_d16_async as "cuMemsetD16Async": PFN_cuMemsetD16Async;
361    fn cu_memset_d32_async as "cuMemsetD32Async": PFN_cuMemsetD32Async;
362    fn cu_memset_d2d8 as "cuMemsetD2D8_v2": PFN_cuMemsetD2D8;
363    fn cu_memset_d2d16 as "cuMemsetD2D16_v2": PFN_cuMemsetD2D16;
364    fn cu_memset_d2d32 as "cuMemsetD2D32_v2": PFN_cuMemsetD2D32;
365
366    // ---- Wave 15: range + pointer attrs ----
367    fn cu_mem_range_get_attribute as "cuMemRangeGetAttribute": PFN_cuMemRangeGetAttribute;
368    fn cu_mem_range_get_attributes as "cuMemRangeGetAttributes": PFN_cuMemRangeGetAttributes;
369    fn cu_pointer_get_attributes as "cuPointerGetAttributes": PFN_cuPointerGetAttributes;
370    fn cu_pointer_set_attribute as "cuPointerSetAttribute": PFN_cuPointerSetAttribute;
371
372    // ---- Wave 16: tensor maps (Hopper TMA) ----
373    fn cu_tensor_map_encode_tiled as "cuTensorMapEncodeTiled": PFN_cuTensorMapEncodeTiled;
374    fn cu_tensor_map_encode_im2col as "cuTensorMapEncodeIm2col": PFN_cuTensorMapEncodeIm2col;
375    fn cu_tensor_map_replace_address as "cuTensorMapReplaceAddress": PFN_cuTensorMapReplaceAddress;
376
377    // ---- Wave 17: green contexts (CUDA 12.4+) ----
378    fn cu_device_get_dev_resource as "cuDeviceGetDevResource": PFN_cuDeviceGetDevResource;
379    fn cu_dev_sm_resource_split_by_count as "cuDevSmResourceSplitByCount": PFN_cuDevSmResourceSplitByCount;
380    fn cu_dev_resource_generate_desc as "cuDevResourceGenerateDesc": PFN_cuDevResourceGenerateDesc;
381    fn cu_green_ctx_create as "cuGreenCtxCreate": PFN_cuGreenCtxCreate;
382    fn cu_green_ctx_destroy as "cuGreenCtxDestroy": PFN_cuGreenCtxDestroy;
383    fn cu_ctx_from_green_ctx as "cuCtxFromGreenCtx": PFN_cuCtxFromGreenCtx;
384    fn cu_green_ctx_get_dev_resource as "cuGreenCtxGetDevResource": PFN_cuGreenCtxGetDevResource;
385    fn cu_green_ctx_stream_create as "cuGreenCtxStreamCreate": PFN_cuGreenCtxStreamCreate;
386
387    // ---- Wave 18: multicast objects ----
388    fn cu_multicast_create as "cuMulticastCreate": PFN_cuMulticastCreate;
389    fn cu_multicast_add_device as "cuMulticastAddDevice": PFN_cuMulticastAddDevice;
390    fn cu_multicast_bind_mem as "cuMulticastBindMem": PFN_cuMulticastBindMem;
391    fn cu_multicast_bind_addr as "cuMulticastBindAddr": PFN_cuMulticastBindAddr;
392    fn cu_multicast_unbind as "cuMulticastUnbind": PFN_cuMulticastUnbind;
393    fn cu_multicast_get_granularity as "cuMulticastGetGranularity": PFN_cuMulticastGetGranularity;
394
395    // ---- Wave 19: conditional + switch graph nodes ----
396    fn cu_graph_add_node as "cuGraphAddNode_v2": PFN_cuGraphAddNode;
397    fn cu_graph_node_set_params as "cuGraphNodeSetParams": PFN_cuGraphNodeSetParams;
398    fn cu_graph_conditional_handle_create as "cuGraphConditionalHandleCreate": PFN_cuGraphConditionalHandleCreate;
399
400    // ---- Wave 20: IPC ----
401    fn cu_ipc_get_event_handle as "cuIpcGetEventHandle": PFN_cuIpcGetEventHandle;
402    fn cu_ipc_open_event_handle as "cuIpcOpenEventHandle": PFN_cuIpcOpenEventHandle;
403    fn cu_ipc_get_mem_handle as "cuIpcGetMemHandle": PFN_cuIpcGetMemHandle;
404    fn cu_ipc_open_mem_handle as "cuIpcOpenMemHandle_v2": PFN_cuIpcOpenMemHandle;
405    fn cu_ipc_close_mem_handle as "cuIpcCloseMemHandle": PFN_cuIpcCloseMemHandle;
406
407    // ---- Wave 21: kernel attrs extension (CUDA 12+) ----
408    fn cu_kernel_get_attribute as "cuKernelGetAttribute": PFN_cuKernelGetAttribute;
409    fn cu_kernel_set_attribute as "cuKernelSetAttribute": PFN_cuKernelSetAttribute;
410    fn cu_kernel_get_name as "cuKernelGetName": PFN_cuKernelGetName;
411    fn cu_kernel_set_cache_config as "cuKernelSetCacheConfig": PFN_cuKernelSetCacheConfig;
412    fn cu_kernel_get_library as "cuKernelGetLibrary": PFN_cuKernelGetLibrary;
413    fn cu_kernel_get_param_info as "cuKernelGetParamInfo": PFN_cuKernelGetParamInfo;
414
415    // ---- Wave 22: user objects ----
416    fn cu_user_object_create as "cuUserObjectCreate": PFN_cuUserObjectCreate;
417    fn cu_user_object_retain as "cuUserObjectRetain": PFN_cuUserObjectRetain;
418    fn cu_user_object_release as "cuUserObjectRelease": PFN_cuUserObjectRelease;
419    fn cu_graph_retain_user_object as "cuGraphRetainUserObject": PFN_cuGraphRetainUserObject;
420    fn cu_graph_release_user_object as "cuGraphReleaseUserObject": PFN_cuGraphReleaseUserObject;
421
422    // ---- Wave 23: misc extras ----
423    fn cu_profiler_start as "cuProfilerStart": PFN_cuProfilerStart;
424    fn cu_profiler_stop as "cuProfilerStop": PFN_cuProfilerStop;
425    fn cu_func_get_module as "cuFuncGetModule": PFN_cuFuncGetModule;
426    fn cu_func_get_name as "cuFuncGetName": PFN_cuFuncGetName;
427    fn cu_func_get_param_info as "cuFuncGetParamInfo": PFN_cuFuncGetParamInfo;
428    fn cu_graph_debug_dot_print as "cuGraphDebugDotPrint": PFN_cuGraphDebugDotPrint;
429    fn cu_ctx_get_id as "cuCtxGetId": PFN_cuCtxGetId;
430    fn cu_module_get_loading_mode as "cuModuleGetLoadingMode": PFN_cuModuleGetLoadingMode;
431    fn cu_device_get_uuid as "cuDeviceGetUuid_v2": PFN_cuDeviceGetUuid;
432    fn cu_device_get_luid as "cuDeviceGetLuid": PFN_cuDeviceGetLuid;
433    fn cu_logs_register_callback as "cuLogsRegisterCallback": PFN_cuLogsRegisterCallback;
434    fn cu_logs_unregister_callback as "cuLogsUnregisterCallback": PFN_cuLogsUnregisterCallback;
435    fn cu_logs_current as "cuLogsCurrent": PFN_cuLogsCurrent;
436    fn cu_logs_dump_to_file as "cuLogsDumpToFile": PFN_cuLogsDumpToFile;
437    fn cu_logs_dump_to_memory as "cuLogsDumpToMemory": PFN_cuLogsDumpToMemory;
438
439    // ---- Wave 24: graph memory nodes + graph-exec update ----
440    fn cu_graph_add_mem_alloc_node as "cuGraphAddMemAllocNode": PFN_cuGraphAddMemAllocNode;
441    fn cu_graph_mem_alloc_node_get_params as "cuGraphMemAllocNodeGetParams": PFN_cuGraphMemAllocNodeGetParams;
442    fn cu_graph_add_mem_free_node as "cuGraphAddMemFreeNode": PFN_cuGraphAddMemFreeNode;
443    fn cu_graph_mem_free_node_get_params as "cuGraphMemFreeNodeGetParams": PFN_cuGraphMemFreeNodeGetParams;
444    fn cu_device_graph_mem_trim as "cuDeviceGraphMemTrim": PFN_cuDeviceGraphMemTrim;
445    fn cu_device_get_graph_mem_attribute as "cuDeviceGetGraphMemAttribute": PFN_cuDeviceGetGraphMemAttribute;
446    fn cu_device_set_graph_mem_attribute as "cuDeviceSetGraphMemAttribute": PFN_cuDeviceSetGraphMemAttribute;
447    fn cu_graph_add_batch_mem_op_node as "cuGraphAddBatchMemOpNode": PFN_cuGraphAddBatchMemOpNode;
448    fn cu_graph_batch_mem_op_node_get_params as "cuGraphBatchMemOpNodeGetParams": PFN_cuGraphBatchMemOpNodeGetParams;
449    fn cu_graph_batch_mem_op_node_set_params as "cuGraphBatchMemOpNodeSetParams": PFN_cuGraphBatchMemOpNodeSetParams;
450    fn cu_graph_exec_batch_mem_op_node_set_params as "cuGraphExecBatchMemOpNodeSetParams": PFN_cuGraphExecBatchMemOpNodeSetParams;
451    fn cu_graph_exec_update as "cuGraphExecUpdate_v2": PFN_cuGraphExecUpdate;
452
453    // ---- Wave 25: stream memory ops ----
454    fn cu_stream_write_value_32 as "cuStreamWriteValue32_v2": PFN_cuStreamWriteValue32;
455    fn cu_stream_write_value_64 as "cuStreamWriteValue64_v2": PFN_cuStreamWriteValue64;
456    fn cu_stream_wait_value_32 as "cuStreamWaitValue32_v2": PFN_cuStreamWaitValue32;
457    fn cu_stream_wait_value_64 as "cuStreamWaitValue64_v2": PFN_cuStreamWaitValue64;
458    fn cu_stream_batch_mem_op as "cuStreamBatchMemOp_v2": PFN_cuStreamBatchMemOp;
459
460    // ---- Wave 27: v2 advise/prefetch + VMM reverse lookups ----
461    fn cu_mem_prefetch_async_v2 as "cuMemPrefetchAsync_v2": PFN_cuMemPrefetchAsyncV2;
462    fn cu_mem_advise_v2 as "cuMemAdvise_v2": PFN_cuMemAdviseV2;
463    fn cu_mem_map_array_async as "cuMemMapArrayAsync": PFN_cuMemMapArrayAsync;
464    fn cu_mem_get_handle_for_address_range as "cuMemGetHandleForAddressRange": PFN_cuMemGetHandleForAddressRange;
465    fn cu_mem_retain_allocation_handle as "cuMemRetainAllocationHandle": PFN_cuMemRetainAllocationHandle;
466    fn cu_mem_get_allocation_properties_from_handle as "cuMemGetAllocationPropertiesFromHandle": PFN_cuMemGetAllocationPropertiesFromHandle;
467    fn cu_mem_export_to_shareable_handle as "cuMemExportToShareableHandle": PFN_cuMemExportToShareableHandle;
468    fn cu_mem_import_from_shareable_handle as "cuMemImportFromShareableHandle": PFN_cuMemImportFromShareableHandle;
469    fn cu_mem_get_access as "cuMemGetAccess": PFN_cuMemGetAccess;
470
471    // ---- Wave 28: medium-value consolidated ----
472    fn cu_array_get_descriptor as "cuArrayGetDescriptor_v2": PFN_cuArrayGetDescriptor;
473    fn cu_array_get_sparse_properties as "cuArrayGetSparseProperties": PFN_cuArrayGetSparseProperties;
474    fn cu_mipmapped_array_get_sparse_properties as "cuMipmappedArrayGetSparseProperties": PFN_cuMipmappedArrayGetSparseProperties;
475    fn cu_array_get_memory_requirements as "cuArrayGetMemoryRequirements": PFN_cuArrayGetMemoryRequirements;
476    fn cu_mipmapped_array_get_memory_requirements as "cuMipmappedArrayGetMemoryRequirements": PFN_cuMipmappedArrayGetMemoryRequirements;
477    fn cu_array_get_plane as "cuArrayGetPlane": PFN_cuArrayGetPlane;
478    fn cu_ctx_record_event as "cuCtxRecordEvent": PFN_cuCtxRecordEvent;
479    fn cu_ctx_wait_event as "cuCtxWaitEvent": PFN_cuCtxWaitEvent;
480    fn cu_device_get_p2p_attribute as "cuDeviceGetP2PAttribute": PFN_cuDeviceGetP2PAttribute;
481    fn cu_device_get_exec_affinity_support as "cuDeviceGetExecAffinitySupport": PFN_cuDeviceGetExecAffinitySupport;
482    fn cu_flush_gpudirect_rdma_writes as "cuFlushGPUDirectRDMAWrites": PFN_cuFlushGPUDirectRDMAWrites;
483    fn cu_coredump_get_attribute as "cuCoredumpGetAttribute": PFN_cuCoredumpGetAttribute;
484    fn cu_coredump_get_attribute_global as "cuCoredumpGetAttributeGlobal": PFN_cuCoredumpGetAttributeGlobal;
485    fn cu_coredump_set_attribute as "cuCoredumpSetAttribute": PFN_cuCoredumpSetAttribute;
486    fn cu_coredump_set_attribute_global as "cuCoredumpSetAttributeGlobal": PFN_cuCoredumpSetAttributeGlobal;
487    fn cu_library_get_unified_function as "cuLibraryGetUnifiedFunction": PFN_cuLibraryGetUnifiedFunction;
488    fn cu_library_get_module as "cuLibraryGetModule": PFN_cuLibraryGetModule;
489    fn cu_library_get_kernel_count as "cuLibraryGetKernelCount": PFN_cuLibraryGetKernelCount;
490    fn cu_library_enumerate_kernels as "cuLibraryEnumerateKernels": PFN_cuLibraryEnumerateKernels;
491    fn cu_library_get_managed as "cuLibraryGetManaged": PFN_cuLibraryGetManaged;
492
493    // ---- Wave 29: graphics core + OpenGL ----
494    fn cu_graphics_unregister_resource as "cuGraphicsUnregisterResource": PFN_cuGraphicsUnregisterResource;
495    fn cu_graphics_map_resources as "cuGraphicsMapResources": PFN_cuGraphicsMapResources;
496    fn cu_graphics_unmap_resources as "cuGraphicsUnmapResources": PFN_cuGraphicsUnmapResources;
497    fn cu_graphics_resource_get_mapped_pointer as "cuGraphicsResourceGetMappedPointer_v2": PFN_cuGraphicsResourceGetMappedPointer;
498    fn cu_graphics_resource_get_mapped_mipmapped_array as "cuGraphicsResourceGetMappedMipmappedArray": PFN_cuGraphicsResourceGetMappedMipmappedArray;
499    fn cu_graphics_sub_resource_get_mapped_array as "cuGraphicsSubResourceGetMappedArray": PFN_cuGraphicsSubResourceGetMappedArray;
500    fn cu_graphics_resource_set_map_flags as "cuGraphicsResourceSetMapFlags_v2": PFN_cuGraphicsResourceSetMapFlags;
501    fn cu_gl_get_devices as "cuGLGetDevices_v2": PFN_cuGLGetDevices;
502    fn cu_graphics_gl_register_buffer as "cuGraphicsGLRegisterBuffer": PFN_cuGraphicsGLRegisterBuffer;
503    fn cu_graphics_gl_register_image as "cuGraphicsGLRegisterImage": PFN_cuGraphicsGLRegisterImage;
504    fn cu_gl_ctx_create as "cuGLCtxCreate_v2": PFN_cuGLCtxCreate;
505    fn cu_gl_init as "cuGLInit": PFN_cuGLInit;
506
507    // ---- Wave 30: Direct3D 9 / 10 / 11 ----
508    fn cu_d3d9_get_device as "cuD3D9GetDevice": PFN_cuD3D9GetDevice;
509    fn cu_d3d9_get_devices as "cuD3D9GetDevices": PFN_cuD3D9GetDevices;
510    fn cu_graphics_d3d9_register_resource as "cuGraphicsD3D9RegisterResource": PFN_cuGraphicsD3D9RegisterResource;
511    fn cu_d3d10_get_device as "cuD3D10GetDevice": PFN_cuD3D10GetDevice;
512    fn cu_d3d10_get_devices as "cuD3D10GetDevices": PFN_cuD3D10GetDevices;
513    fn cu_graphics_d3d10_register_resource as "cuGraphicsD3D10RegisterResource": PFN_cuGraphicsD3D10RegisterResource;
514    fn cu_d3d11_get_device as "cuD3D11GetDevice": PFN_cuD3D11GetDevice;
515    fn cu_d3d11_get_devices as "cuD3D11GetDevices": PFN_cuD3D11GetDevices;
516    fn cu_graphics_d3d11_register_resource as "cuGraphicsD3D11RegisterResource": PFN_cuGraphicsD3D11RegisterResource;
517
518    // ---- Wave 31: VDPAU + EGL + NvSci (Jetson) ----
519    // VDPAU (Linux)
520    fn cu_vdpau_get_device as "cuVDPAUGetDevice": PFN_cuVDPAUGetDevice;
521    fn cu_vdpau_ctx_create as "cuVDPAUCtxCreate_v2": PFN_cuVDPAUCtxCreate;
522    fn cu_graphics_vdpau_register_video_surface as "cuGraphicsVDPAURegisterVideoSurface": PFN_cuGraphicsVDPAURegisterVideoSurface;
523    fn cu_graphics_vdpau_register_output_surface as "cuGraphicsVDPAURegisterOutputSurface": PFN_cuGraphicsVDPAURegisterOutputSurface;
524
525    // EGL (Jetson / cross-platform)
526    fn cu_graphics_egl_register_image as "cuGraphicsEGLRegisterImage": PFN_cuGraphicsEGLRegisterImage;
527    fn cu_graphics_resource_get_mapped_egl_frame as "cuGraphicsResourceGetMappedEglFrame": PFN_cuGraphicsResourceGetMappedEglFrame;
528    fn cu_event_create_from_egl_sync as "cuEventCreateFromEGLSync": PFN_cuEventCreateFromEGLSync;
529    fn cu_egl_stream_consumer_connect as "cuEGLStreamConsumerConnect": PFN_cuEGLStreamConsumerConnect;
530    fn cu_egl_stream_consumer_disconnect as "cuEGLStreamConsumerDisconnect": PFN_cuEGLStreamConsumerDisconnect;
531    fn cu_egl_stream_consumer_acquire_frame as "cuEGLStreamConsumerAcquireFrame": PFN_cuEGLStreamConsumerAcquireFrame;
532    fn cu_egl_stream_consumer_release_frame as "cuEGLStreamConsumerReleaseFrame": PFN_cuEGLStreamConsumerReleaseFrame;
533    fn cu_egl_stream_producer_connect as "cuEGLStreamProducerConnect": PFN_cuEGLStreamProducerConnect;
534    fn cu_egl_stream_producer_disconnect as "cuEGLStreamProducerDisconnect": PFN_cuEGLStreamProducerDisconnect;
535    fn cu_egl_stream_producer_present_frame as "cuEGLStreamProducerPresentFrame": PFN_cuEGLStreamProducerPresentFrame;
536    fn cu_egl_stream_producer_return_frame as "cuEGLStreamProducerReturnFrame": PFN_cuEGLStreamProducerReturnFrame;
537
538    // NvSci (Jetson / DRIVE)
539    fn cu_device_get_nv_sci_sync_attributes as "cuDeviceGetNvSciSyncAttributes": PFN_cuDeviceGetNvSciSyncAttributes;
540}
541
542impl Driver {
543    /// Resolve `cuGetProcAddress` via `dlsym` (the only symbol we cannot
544    /// resolve through itself).
545    fn cu_get_proc_address(&self) -> Result<PFN_cuGetProcAddress, LoaderError> {
546        if let Some(&p) = self.get_proc_address.get() {
547            return Ok(p);
548        }
549        let p: PFN_cuGetProcAddress = unsafe { self.resolve_via_dlsym("cuGetProcAddress")? };
550        let _ = self.get_proc_address.set(p);
551        Ok(p)
552    }
553
554    /// Resolve `symbol` and transmute into the caller-specified fn-pointer
555    /// type. `cuGetProcAddress` is resolved via `dlsym` (it's our bootstrap).
556    /// Symbol names with an explicit version suffix like `_v2` / `_v3` are
557    /// also resolved via `dlsym` — because `cuGetProcAddress` with a
558    /// version-suffixed base name would _still_ do version-dispatch and
559    /// might return a newer ABI than our `PFN_*` signature expects.
560    /// Everything else goes through `cuGetProcAddress`, which transparently
561    /// picks `_ptsz` / `_ptds` variants based on the configured stream mode.
562    ///
563    /// # Safety
564    ///
565    /// `T` must be a function-pointer type whose signature matches the C
566    /// declaration of `symbol` in `cuda.h`. The symbol names we pass come
567    /// from the macro above and are checked against the NVIDIA docs.
568    unsafe fn resolve<T: Copy>(&self, symbol: &'static str) -> Result<T, LoaderError> { unsafe {
569        if symbol == "cuGetProcAddress" || has_version_suffix(symbol) {
570            return self.resolve_via_dlsym(symbol);
571        }
572        self.resolve_via_get_proc_address(symbol)
573    }}
574
575    /// Direct `dlsym` / `GetProcAddress`; used for `cuGetProcAddress` itself.
576    unsafe fn resolve_via_dlsym<T: Copy>(&self, symbol: &'static str) -> Result<T, LoaderError> { unsafe {
577        debug_assert_eq!(
578            core::mem::size_of::<T>(),
579            core::mem::size_of::<*mut ()>(),
580            "Driver::resolve_via_dlsym<T>: T must be a function-pointer type",
581        );
582        let raw: *mut () = self.lib.raw_symbol(symbol)?;
583        Ok(core::mem::transmute_copy::<*mut (), T>(&raw))
584    }}
585
586    /// Cached driver-reported CUDA version. Probed once via a direct
587    /// `dlsym` on `cuDriverGetVersion`. Falls back to the baracuda floor
588    /// (CUDA 11.4) if the probe fails, which is safe — `cuGetProcAddress`
589    /// treats the version as a minimum.
590    fn detected_cuda_version(&self) -> core::ffi::c_int {
591        use std::sync::OnceLock;
592        static CACHED: OnceLock<core::ffi::c_int> = OnceLock::new();
593        *CACHED.get_or_init(|| {
594            let raw: *mut () = match unsafe { self.lib.raw_symbol("cuDriverGetVersion") } {
595                Ok(p) => p,
596                Err(_) => return baracuda_types::CudaVersion::FLOOR.raw() as core::ffi::c_int,
597            };
598            type Fn = unsafe extern "C" fn(*mut core::ffi::c_int) -> CUresult;
599            // SAFETY: `cuDriverGetVersion` has a stable signature.
600            let f: Fn = unsafe { core::mem::transmute_copy::<*mut (), Fn>(&raw) };
601            let mut v: core::ffi::c_int = 0;
602            match unsafe { f(&mut v) } {
603                CUresult::SUCCESS if v > 0 => v,
604                _ => baracuda_types::CudaVersion::FLOOR.raw() as core::ffi::c_int,
605            }
606        })
607    }
608
609    /// Resolve `symbol` through `cuGetProcAddress`, passing the process's
610    /// configured stream mode as the flags argument.
611    unsafe fn resolve_via_get_proc_address<T: Copy>(
612        &self,
613        symbol: &'static str,
614    ) -> Result<T, LoaderError> { unsafe {
615        debug_assert_eq!(
616            core::mem::size_of::<T>(),
617            core::mem::size_of::<*mut ()>(),
618            "Driver::resolve_via_get_proc_address<T>: T must be a function-pointer type",
619        );
620        let gpa = self.cu_get_proc_address()?;
621        let flags = match stream_mode::get() {
622            StreamMode::Legacy => CU_GET_PROC_ADDRESS_LEGACY_STREAM,
623            StreamMode::PerThread => CU_GET_PROC_ADDRESS_PER_THREAD_DEFAULT_STREAM,
624        };
625        let c_sym: Vec<u8> = symbol.bytes().chain(std::iter::once(0)).collect();
626        let mut pfn: *mut core::ffi::c_void = ptr::null_mut();
627
628        // Two-stage resolution:
629        //   1. Query at the baracuda floor (CUDA 11.4). This pins the ABI
630        //      for symbols that have `_v2`/`_v3` variants — we want the
631        //      11.4-era shape because that's what our `PFN_*` type matches.
632        //   2. If the driver returns NOT_FOUND / VERSION_NOT_SUFFICIENT,
633        //      retry at the driver's own reported version. That catches
634        //      CUDA-12+ additions (cuLibraryLoadData, cuLaunchKernelEx)
635        //      without upgrading older symbols to their newer ABIs.
636        let floor = baracuda_types::CudaVersion::FLOOR.raw() as core::ffi::c_int;
637        let mut res = gpa(c_sym.as_ptr() as *const c_char, &mut pfn, floor, flags);
638        if res != CUresult::SUCCESS || pfn.is_null() {
639            let installed = self.detected_cuda_version();
640            if installed > floor {
641                pfn = ptr::null_mut();
642                res = gpa(c_sym.as_ptr() as *const c_char, &mut pfn, installed, flags);
643            }
644        }
645        if res != CUresult::SUCCESS || pfn.is_null() {
646            return Err(LoaderError::SymbolNotFound {
647                library: "cuda-driver",
648                symbol,
649            });
650        }
651        Ok(core::mem::transmute_copy::<*mut core::ffi::c_void, T>(&pfn))
652    }}
653}
654
/// Reports whether `sym` carries an explicit pin in its name: a stream-mode
/// suffix (`_ptsz` / `_ptds`) or a version suffix of the form `_v<digits>`.
///
/// Pinned names are looked up with `dlsym` so the driver cannot silently
/// substitute a newer ABI for the one the name asks for.
fn has_version_suffix(sym: &str) -> bool {
    let stream_pinned = ["_ptsz", "_ptds"]
        .iter()
        .any(|suffix| sym.ends_with(*suffix));
    if stream_pinned {
        return true;
    }
    // Only the text after the *last* `_v` counts: it must be a non-empty
    // run of ASCII digits reaching the end of the name.
    match sym.rsplit_once("_v") {
        Some((_, digits)) => !digits.is_empty() && digits.bytes().all(|b| b.is_ascii_digit()),
        None => false,
    }
}
669
670/// Lazily-initialized process-wide Driver singleton.
671///
672/// The first successful call to [`driver`] caches the [`Driver`] in a
673/// `OnceLock`; subsequent calls return the same `&'static Driver`. If the
674/// first call fails (no libcuda, unsupported platform), the error is
675/// **not** memoized — a later call with a different environment has a
676/// chance of succeeding.
677pub fn driver() -> Result<&'static Driver, LoaderError> {
678    static DRIVER: OnceLock<Driver> = OnceLock::new();
679    if let Some(d) = DRIVER.get() {
680        return Ok(d);
681    }
682    let lib = Library::open("cuda-driver", platform::driver_library_candidates())?;
683    let d = Driver::empty(lib);
684    // If another thread raced us, our `d` is dropped; `DRIVER.get().unwrap()`
685    // still returns the thread-that-won's instance.
686    let _ = DRIVER.set(d);
687    Ok(DRIVER.get().expect("OnceLock set or lost race"))
688}
689
#[cfg(test)]
mod tests {
    use super::driver;

    /// Smoke test for the singleton accessor.
    ///
    /// Whether CUDA is installed on the test machine is unknown (CI usually
    /// has none), so either `Ok` or a `LoaderError` is acceptable. The only
    /// thing checked is that the call returns a `Result` without panicking.
    #[test]
    fn driver_singleton_returns_loader_error_without_cuda() {
        let outcome = driver();
        drop(outcome);
    }
}