// oxicuda_driver/lib.rs
1//! # OxiCUDA Driver
2//!
3//! **Dynamic, safe Rust bindings for the NVIDIA CUDA Driver API.**
4//!
5//! `oxicuda-driver` provides a zero-SDK-dependency wrapper around the CUDA
6//! Driver API.  Unlike traditional CUDA crate approaches that require the
7//! CUDA Toolkit (or at least its headers and link stubs) to be present at
8//! **build time**, this crate loads the driver shared library entirely at
9//! **runtime** via [`libloading`](https://crates.io/crates/libloading).
10//!
11//! ## Zero build-time dependency
12//!
13//! No `cuda.h`, no `libcuda.so` symlink, no `nvcc` — the crate compiles on
14//! any Rust toolchain.  The actual GPU driver is discovered and loaded the
15//! first time you call [`try_driver()`] or [`init()`].
16//!
17//! ## Runtime library loading
18//!
//! | Platform | Library searched                                                 |
//! |----------|------------------------------------------------------------------|
//! | Linux    | `libcuda.so`, `libcuda.so.1`                                     |
//! | Windows  | `nvcuda.dll`                                                     |
//! | macOS    | *(returns `UnsupportedPlatform` — NVIDIA dropped macOS support)* |
24//!
25//! ## Key types
26//!
27//! | Type          | Description                                    |
28//! |---------------|------------------------------------------------|
29//! | [`Device`]    | A CUDA-capable GPU discovered on the system    |
30//! | [`Context`]   | Owns a CUDA context bound to a device          |
//! | [`Stream`]    | Asynchronous command queue within a context    |
32//! | [`Event`]     | Timing / synchronisation marker on a stream    |
33//! | [`Module`]    | Loaded PTX or cubin containing kernel code     |
34//! | [`Function`]  | A single kernel entry point inside a module    |
35//! | [`CudaError`] | Strongly-typed driver error code               |
36//!
37//! ## Quick start
38//!
39//! ```rust,no_run
40//! use oxicuda_driver::prelude::*;
41//!
42//! // Initialise the CUDA driver (loads libcuda at runtime).
43//! init()?;
44//!
45//! // Pick the best available GPU and create a context.
46//! let dev = Device::get(0)?;
47//! let _ctx = Context::new(&dev)?;
48//!
49//! // Load a PTX module and look up a kernel.
50//! let module = Module::from_ptx("ptx_source")?;
51//! let kernel = module.get_function("vector_add")?;
52//! # Ok::<(), oxicuda_driver::CudaError>(())
53//! ```
54
// Crate-wide lint policy: warn on undocumented public items and on all
// default clippy lints.
#![warn(missing_docs)]
#![warn(clippy::all)]
// Deliberate lint suppressions.  NOTE(review): presumed reasons below — each
// should eventually carry its own confirmed justification:
// - module_name_repetitions: wrapper types echo their module (`context::Context`).
// - missing_safety_doc: large FFI surface; safety docs are still being filled in.
// - too_many_arguments: driver entry points mirror the C API's arity.
// - macro_metavars_in_unsafe: FFI call macros expand metavariables inside `unsafe`.
#![allow(clippy::module_name_repetitions)]
#![allow(clippy::missing_safety_doc)]
#![allow(clippy::too_many_arguments)]
#![allow(clippy::macro_metavars_in_unsafe)]
61
62// ---------------------------------------------------------------------------
63// Module declarations
64// ---------------------------------------------------------------------------
65
// One file per subsystem.  Most public types are re-exported at the crate
// root (see below), so downstream code rarely needs these paths directly.
pub mod context; // `Context` — owns a CUDA context bound to a device.
pub mod context_config; // `CacheConfig` / `SharedMemConfig`.
pub mod cooperative_launch; // Cooperative single- and multi-device launches.
pub mod debug; // `DebugSession`, kernel debugger, memory/NaN checkers.
pub mod device; // `Device` discovery, attributes, peer access.
pub mod error; // `CudaError`, `CudaResult`, the `check` helper.
pub mod event; // `Event` — timing / synchronisation marker on a stream.
pub mod ffi; // Raw driver-API types and constants.
pub mod function_attr; // (not re-exported at the crate root)
pub mod graph; // `Graph` / `GraphExec` and stream capture.
pub mod link; // JIT linker: `Linker`, `LinkedModule`, options.
pub mod loader; // Runtime discovery/loading of libcuda (`try_driver`).
pub mod memory_info; // (not re-exported at the crate root)
pub mod module; // `Module` / `Function` — PTX & cubin kernel loading.
pub mod multi_gpu; // `DevicePool`.
pub mod nvlink_topology; // `GpuTopology`, NVLink version/topology queries.
pub mod occupancy; // (not re-exported at the crate root)
pub mod occupancy_ext; // (not re-exported at the crate root)
pub mod primary_context; // `PrimaryContext`.
pub mod profiler; // `ProfilerGuard`.
pub mod stream; // `Stream` — asynchronous command queue.
pub mod stream_ordered_alloc; // Stream-ordered (`cuMemAllocAsync`-era) allocation.
pub mod tma; // (not re-exported at the crate root)
89
90// ---------------------------------------------------------------------------
91// Re-exports — error handling
92// ---------------------------------------------------------------------------
93
94pub use error::{CudaError, CudaResult, DriverLoadError, check};
95
96// ---------------------------------------------------------------------------
97// Re-exports — FFI types and constants
98// ---------------------------------------------------------------------------
99
100pub use ffi::{
101    CU_TRSF_DISABLE_TRILINEAR_OPTIMIZATION, CU_TRSF_NORMALIZED_COORDINATES,
102    CU_TRSF_READ_AS_INTEGER, CU_TRSF_SRGB, CUDA_ARRAY_DESCRIPTOR, CUDA_ARRAY3D_CUBEMAP,
103    CUDA_ARRAY3D_DESCRIPTOR, CUDA_ARRAY3D_LAYERED, CUDA_ARRAY3D_SURFACE_LDST,
104    CUDA_ARRAY3D_TEXTURE_GATHER, CUDA_MEMCPY2D, CUDA_RESOURCE_DESC, CUDA_RESOURCE_VIEW_DESC,
105    CUDA_TEXTURE_DESC, CUaddress_mode, CUarray, CUarray_format, CUcontext, CUdevice,
106    CUdevice_attribute, CUdeviceptr, CUevent, CUfilter_mode, CUfunction, CUfunction_attribute,
107    CUjit_option, CUjitInputType, CUkernel, CUlibrary, CUlimit, CUlinkState, CUmemAccessDesc,
108    CUmemAccessFlags, CUmemAllocationHandleType, CUmemAllocationProp, CUmemAllocationType,
109    CUmemGenericAllocationHandle, CUmemLocation, CUmemLocationType, CUmemPoolProps, CUmemoryPool,
110    CUmemorytype, CUmipmappedArray, CUmodule, CUmulticastObject, CUpointer_attribute,
111    CUresourceViewFormat, CUresourcetype, CUstream, CUsurfObject, CUsurfref, CUtexObject, CUtexref,
112    CuLaunchAttribute, CuLaunchAttributeClusterDim, CuLaunchAttributeId, CuLaunchAttributeValue,
113    CuLaunchConfig, CudaResourceDescArray, CudaResourceDescLinear, CudaResourceDescMipmap,
114    CudaResourceDescPitch2d, CudaResourceDescRes,
115};
116
117// ---------------------------------------------------------------------------
118// Re-exports — high-level safe wrappers
119// ---------------------------------------------------------------------------
120
121pub use context::Context;
122pub use context_config::{CacheConfig, SharedMemConfig};
123pub use cooperative_launch::{
124    CooperativeLaunchConfig, CooperativeLaunchSupport, DeviceLaunchConfig,
125    MultiDeviceCooperativeLaunchConfig, cooperative_launch, cooperative_launch_multi_device,
126};
127pub use debug::{DebugLevel, DebugSession, KernelDebugger, MemoryChecker, NanInfChecker};
128pub use device::{Device, DeviceInfo, best_device, can_access_peer, driver_version, list_devices};
129pub use event::Event;
130pub use graph::{Graph, GraphExec, GraphNode, MemcpyDirection, StreamCapture};
131pub use link::{
132    FallbackStrategy, LinkInputType, LinkedModule, Linker, LinkerOptions, OptimizationLevel,
133};
134pub use loader::try_driver;
135pub use module::{Function, JitDiagnostic, JitLog, JitOptions, JitSeverity, Module};
136pub use multi_gpu::DevicePool;
137pub use nvlink_topology::{GpuTopology, NvLinkVersion, TopologyTree, TopologyType};
138pub use primary_context::PrimaryContext;
139pub use profiler::ProfilerGuard;
140pub use stream::Stream;
141pub use stream_ordered_alloc::{
142    StreamAllocation, StreamMemoryPool, StreamOrderedAllocConfig, stream_alloc, stream_free,
143};
144
145// ---------------------------------------------------------------------------
146// Driver initialisation
147// ---------------------------------------------------------------------------
148
149/// Initialise the CUDA driver API.
150///
151/// This must be called before any other driver function.  It is safe to call
152/// multiple times; subsequent calls are no-ops inside the driver itself.
153///
154/// Internally this loads the shared library (if not already cached) and
155/// invokes `cuInit(0)`.
156///
157/// # Errors
158///
159/// Returns [`CudaError::NotInitialized`] if the CUDA driver library cannot be
160/// loaded, or another [`CudaError`] variant if `cuInit` reports a failure.
161pub fn init() -> CudaResult<()> {
162    let driver = loader::try_driver()?;
163    error::check(unsafe { (driver.cu_init)(0) })
164}
165
166// ---------------------------------------------------------------------------
167// Prelude — convenient glob import
168// ---------------------------------------------------------------------------
169
/// Convenient glob import for common OxiCUDA Driver types.
///
/// Brings the safe wrappers, error types, and free functions most programs
/// touch into scope with a single `use`:
///
/// ```rust
/// use oxicuda_driver::prelude::*;
/// ```
pub mod prelude {
    // Names re-exported from the crate root — types first, then free
    // functions; each group is kept alphabetical.
    pub use crate::{
        CacheConfig, Context, CooperativeLaunchConfig, CooperativeLaunchSupport, CudaError,
        CudaResult, DebugLevel, DebugSession, Device, DeviceLaunchConfig, DevicePool, Event,
        FallbackStrategy, Function, GpuTopology, Graph, GraphExec, GraphNode, KernelDebugger,
        LinkInputType, LinkedModule, Linker, LinkerOptions, MemcpyDirection, Module,
        MultiDeviceCooperativeLaunchConfig, NvLinkVersion, OptimizationLevel, PrimaryContext,
        ProfilerGuard, SharedMemConfig, Stream, StreamAllocation, StreamCapture, StreamMemoryPool,
        StreamOrderedAllocConfig, TopologyTree, TopologyType, can_access_peer, cooperative_launch,
        cooperative_launch_multi_device, driver_version, init, stream_alloc, stream_free,
        try_driver,
    };
}
188
189// ---------------------------------------------------------------------------
190// Compile-time feature flags
191// ---------------------------------------------------------------------------
192
/// Compile-time feature availability.
///
/// Each constant is resolved with `cfg!` at compile time, so callers can use
/// them in ordinary `if` expressions instead of `#[cfg]` attributes.
pub mod features {
    /// Whether GPU tests are enabled (`--features gpu-tests`).
    pub const HAS_GPU_TESTS: bool = cfg!(feature = "gpu-tests");
}
198
199// ---------------------------------------------------------------------------
200// CPU-only tests for driver infrastructure
201// ---------------------------------------------------------------------------
202
#[cfg(test)]
mod driver_infra_tests {
    // -----------------------------------------------------------------------
    // Task 2 — Multi-threaded context migration (F3)
    //
    // Models the context-stack push/pop discipline with plain Rust data
    // structures; no GPU or driver library is involved.
    // -----------------------------------------------------------------------

    /// Simulate 4 threads each pushing and popping a "context ID" to/from a
    /// thread-local stack, then verifying all results are collected correctly.
    ///
    /// Mirrors the logical structure of context push/pop across threads
    /// (`cuCtxPushCurrent` / `cuCtxPopCurrent`) without a real CUDA driver.
    #[test]
    fn context_push_pop_thread_safety() {
        use std::sync::{Arc, Mutex};
        use std::thread;

        let collected: Arc<Mutex<Vec<(u32, u32)>>> = Arc::new(Mutex::new(Vec::new()));

        // Spawn all workers first, join afterwards.  Each worker owns a
        // private two-entry "context stack" and reports its top entry.
        let workers: Vec<_> = (0..4u32)
            .map(|thread_id| {
                let sink = Arc::clone(&collected);
                thread::spawn(move || {
                    let base = thread_id * 100;
                    let ctx_stack = vec![base, base + 1];
                    // Pop semantics: the most recently pushed ID is last.
                    let top = ctx_stack.last().copied().unwrap_or(0);
                    sink.lock().expect("results lock failed").push((thread_id, top));
                })
            })
            .collect();

        for worker in workers {
            worker.join().expect("thread panicked");
        }

        let results = collected.lock().expect("final lock failed");
        assert_eq!(results.len(), 4, "all 4 threads must contribute a result");

        // Each thread must have observed `base + 1` on top of its own stack.
        for &(thread_id, top) in results.iter() {
            let expected_top = thread_id * 100 + 1;
            assert_eq!(
                top, expected_top,
                "thread {thread_id}: expected top {expected_top}, got {top}"
            );
        }
    }

    // -----------------------------------------------------------------------
    // Task 3 — Scope-exit / Drop resource release under OOM (F10)
    //
    // Drop impls must run for everything that was constructed even when a
    // later allocation fails (simulated OOM), and in Rust's LIFO order.
    // -----------------------------------------------------------------------

    /// `Drop` is invoked for every resource that was successfully constructed,
    /// even when a subsequent allocation would fail (simulated OOM).
    #[test]
    fn drop_counter_tracks_resource_release() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicUsize, Ordering};

        // Counts its own destruction via a shared atomic.
        struct Tracked(Arc<AtomicUsize>);

        impl Drop for Tracked {
            fn drop(&mut self) {
                self.0.fetch_add(1, Ordering::SeqCst);
            }
        }

        let drops = Arc::new(AtomicUsize::new(0));

        {
            let _first = Tracked(Arc::clone(&drops));
            let _second = Tracked(Arc::clone(&drops));
            // A third resource is deliberately never created (the simulated
            // OOM); the first two must still be alive at this point.
            assert_eq!(
                drops.load(Ordering::SeqCst),
                0,
                "resources must not be dropped before scope exit"
            );
        }

        // Leaving the block destroys both live resources.
        assert_eq!(
            drops.load(Ordering::SeqCst),
            2,
            "both resources must be dropped at scope exit"
        );
    }

    /// Rust drops local variables in **reverse declaration order** (LIFO).
    /// This test verifies that invariant for RAII guard types.
    #[test]
    fn drop_order_is_lifo() {
        use std::sync::{Arc, Mutex};

        let trace: Arc<Mutex<Vec<u32>>> = Arc::new(Mutex::new(Vec::new()));

        // Records its `id` into the shared trace when destroyed.
        struct Guard {
            id: u32,
            order: Arc<Mutex<Vec<u32>>>,
        }

        impl Drop for Guard {
            fn drop(&mut self) {
                self.order.lock().expect("order lock failed").push(self.id);
            }
        }

        {
            let _a = Guard { id: 1, order: Arc::clone(&trace) };
            let _b = Guard { id: 2, order: Arc::clone(&trace) };
            let _c = Guard { id: 3, order: Arc::clone(&trace) };
        }

        let observed = trace.lock().expect("final order lock failed");
        assert_eq!(
            *observed,
            vec![3, 2, 1],
            "CUDA RAII guards must be released in LIFO order"
        );
    }

    // -----------------------------------------------------------------------
    // Task 4 — Driver version negotiation (NVIDIA Driver 525 / 535 / 550 / 560)
    //
    // `cuDriverGetVersion` encodes the CUDA version as an integer
    // `1000 * major + 10 * minor` (e.g. 12020 → CUDA 12.2).  These tests pin
    // the decode arithmetic and the version-gating thresholds used throughout
    // OxiCUDA without requiring a real driver.
    // -----------------------------------------------------------------------

    /// Decode a `cuDriverGetVersion`-style integer into
    /// `(major, raw_minor)`, where `raw_minor` is the remainder mod 1000
    /// (i.e. the scaled minor component, e.g. 20 for CUDA 12.2).
    fn split_version(cuda_version: i32) -> (i32, i32) {
        (cuda_version / 1000, cuda_version % 1000)
    }

    /// NVIDIA Driver 525 ships with CUDA 12.0.  Verify the parse of 12000.
    #[test]
    fn driver_version_parsing_cuda_12_0() {
        // cuDriverGetVersion returns 12000 for CUDA 12.0 (driver 525).
        let (major, minor) = split_version(12000);
        assert_eq!(major, 12, "major version mismatch");
        assert_eq!(minor, 0, "minor version mismatch");
    }

    /// NVIDIA Driver 535 ships with CUDA 12.2.  Verify the parse of 12020.
    #[test]
    fn driver_version_parsing_cuda_12_2() {
        assert_eq!(split_version(12020), (12, 20));
    }

    /// NVIDIA Driver 550 ships with CUDA 12.4.  Verify the parse of 12040.
    #[test]
    fn driver_version_parsing_cuda_12_4() {
        assert_eq!(split_version(12040), (12, 40));
    }

    /// NVIDIA Driver 560 ships with CUDA 12.6.  Verify the parse of 12060.
    #[test]
    fn driver_version_parsing_cuda_12_6() {
        assert_eq!(split_version(12060), (12, 60));
    }

    /// OxiCUDA requires CUDA 11.2+ (`cuMemAllocAsync` availability).
    /// Verify that the set of supported versions all meet the minimum and
    /// that older versions are correctly rejected.
    #[test]
    fn driver_version_minimum_requirement() {
        // cuMemAllocAsync was introduced in CUDA 11.2 (version integer 11020).
        let min_required: i32 = 11020;

        for v in [11020, 11040, 12000, 12060, 12080] {
            assert!(
                v >= min_required,
                "CUDA version {v} should be supported (>= {min_required})"
            );
        }

        for v in [10020, 11010] {
            assert!(
                v < min_required,
                "CUDA version {v} should NOT be supported (< {min_required})"
            );
        }
    }

    /// CUDA 12.8 (version 12080) introduces `cuMemcpyBatchAsync`.
    /// Verify the feature-gating arithmetic.
    #[test]
    fn driver_cuda_12_8_features_available() {
        const CUDA_12_8_GATE: i32 = 12080;

        // 12.8 itself sits exactly on the gate.
        let cuda_128: i32 = 12080;
        assert!(
            cuda_128 >= CUDA_12_8_GATE,
            "CUDA 12.8 must support cuMemcpyBatchAsync"
        );

        // 12.0 predates the batch-copy API.
        let cuda_120: i32 = 12000;
        assert!(
            cuda_120 < CUDA_12_8_GATE,
            "CUDA 12.0 must NOT support cuMemcpyBatchAsync"
        );
    }

    /// Verify the complete NVIDIA-driver-version → CUDA-version mapping used
    /// in OxiCUDA's version negotiation table.
    #[test]
    fn driver_nvidia_to_cuda_version_mapping() {
        // (nvidia_driver, cuda_version_int) pairs from the negotiation table.
        let table = [
            (525u32, 12000i32), // Driver 525 → CUDA 12.0
            (535, 12020),       // Driver 535 → CUDA 12.2
            (550, 12040),       // Driver 550 → CUDA 12.4
            (560, 12060),       // Driver 560 → CUDA 12.6
        ];

        for (nvidia_driver, cuda_version) in table {
            let (major, minor) = split_version(cuda_version);
            // Sanity: everything in the table is CUDA 12.x.
            assert_eq!(major, 12, "driver {nvidia_driver}: expected CUDA 12.x");
            // The scaled minor component must be a multiple of 10.
            assert_eq!(
                minor % 10,
                0,
                "driver {nvidia_driver}: minor {minor} is not a multiple of 10"
            );
            // None of these drivers reach the CUDA 12.8 feature gate.
            let has_12_8_features = cuda_version >= 12080;
            assert!(
                !has_12_8_features,
                "driver {nvidia_driver} (CUDA {major}.{:02}) should NOT have 12.8+ features",
                minor / 10
            );
        }
    }
}
475}