Skip to main content

oxicuda_runtime/
device.rs

1//! Device management — `cudaGetDeviceCount`, `cudaSetDevice`, `cudaGetDevice`,
2//! `cudaGetDeviceProperties`, `cudaDeviceSynchronize`, `cudaDeviceReset`.
3//!
4//! The CUDA Runtime maintains a *per-thread* current device.  This module
5//! stores that state in a thread-local and exposes the standard Runtime API
6//! surface on top of the underlying driver API.
7
8use std::cell::Cell;
9use std::ffi::c_int;
10
11use oxicuda_driver::loader::try_driver;
12
13use crate::error::{CudaRtError, CudaRtResult};
14
15// ─── Thread-local current device ─────────────────────────────────────────────
16
thread_local! {
    /// The device ordinal bound to the calling thread.
    ///
    /// `None` means no device has been selected yet on this thread.  The cell
    /// is written by `set_device` once the driver calls succeed, read by
    /// `get_device`, and cleared again by `device_reset`.
    static CURRENT_DEVICE: Cell<Option<c_int>> = const { Cell::new(None) };
}
22
23// ─── cudaDeviceProp ───────────────────────────────────────────────────────────
24
/// Subset of `cudaDeviceProp` exposed by the Runtime API.
///
/// Field names intentionally match the CUDA Runtime documentation so that code
/// written against `cudaGetDeviceProperties` compiles with minimal changes.
///
/// Instances are filled in by [`CudaDeviceProp::from_device`], which reads the
/// name, the total memory and one driver attribute per remaining field.
#[derive(Debug, Clone)]
pub struct CudaDeviceProp {
    /// ASCII name of the device, e.g. `"NVIDIA GeForce RTX 4090"`.
    pub name: String,
    /// Total amount of global memory in bytes.
    pub total_global_mem: usize,
    /// Shared memory per block in bytes.
    pub shared_mem_per_block: usize,
    /// 32-bit registers per block.
    pub regs_per_block: u32,
    /// Warp size (threads per warp).
    pub warp_size: u32,
    /// Maximum pitch in bytes for `cudaMallocPitch`.
    pub mem_pitch: usize,
    /// Maximum number of threads per block.
    pub max_threads_per_block: u32,
    /// Maximum size of each dimension of a block `[x, y, z]`.
    pub max_threads_dim: [u32; 3],
    /// Maximum size of each dimension of a grid `[x, y, z]`.
    pub max_grid_size: [u32; 3],
    /// Clock frequency in kilohertz.
    pub clock_rate: u32,
    /// Total constant memory available on device in bytes.
    pub total_const_mem: usize,
    /// Major revision number of device's compute capability.
    pub major: u32,
    /// Minor revision number of device's compute capability.
    pub minor: u32,
    /// Alignment requirement for textures (in bytes).
    pub texture_alignment: usize,
    /// Pitch alignment requirement for texture references (in bytes).
    pub texture_pitch_alignment: usize,
    /// `true` if device can concurrently copy and execute a kernel.
    pub device_overlap: bool,
    /// Number of multiprocessors on device.
    pub multi_processor_count: u32,
    /// `true` if device has ECC support enabled.
    pub ecc_enabled: bool,
    /// `true` if device is an integrated (on-chip) GPU.
    pub integrated: bool,
    /// `true` if device can map host memory.
    pub can_map_host_memory: bool,
    /// `true` if device supports unified virtual addressing.
    pub unified_addressing: bool,
    /// Peak memory clock frequency in kilohertz.
    pub memory_clock_rate: u32,
    /// Global memory bus width in bits.
    pub memory_bus_width: u32,
    /// Size of the L2 cache in bytes (0 if not applicable).
    pub l2_cache_size: u32,
    /// Maximum number of resident threads per multiprocessor.
    pub max_threads_per_multi_processor: u32,
    /// `true` if the device supports stream priorities.
    pub stream_priorities_supported: bool,
    /// Shared memory per multiprocessor in bytes.
    pub shared_mem_per_multiprocessor: usize,
    /// 32-bit registers per multiprocessor.
    pub regs_per_multiprocessor: u32,
    /// `true` if the device supports allocating managed memory.
    pub managed_memory: bool,
    /// `true` if the device is on a multi-GPU board.
    pub is_multi_gpu_board: bool,
    /// Unique identifier for a group of devices on the same multi-GPU board.
    pub multi_gpu_board_group_id: u32,
    /// `true` if the link between the device and the host supports native
    /// atomic operations.
    pub host_native_atomic_supported: bool,
    /// `true` if the device supports Cooperative Launch.
    pub cooperative_launch: bool,
    /// `true` if the device supports Multi-Device Cooperative Launch.
    pub cooperative_multi_device_launch: bool,
    /// Maximum number of blocks per multiprocessor.
    pub max_blocks_per_multi_processor: u32,
    /// Per-device maximum shared memory per block, in bytes, usable by
    /// special opt-in (the larger limit a kernel must explicitly request).
    pub shared_mem_per_block_optin: usize,
    /// `true` if device supports cluster launch.
    pub cluster_launch: bool,
}
106
107impl CudaDeviceProp {
108    /// Construct a `CudaDeviceProp` by querying device attributes via the
109    /// CUDA Driver API.
110    ///
111    /// # Errors
112    ///
113    /// Returns an error if the driver is not loaded or any attribute query
114    /// fails.
115    pub fn from_device(ordinal: c_int) -> CudaRtResult<Self> {
116        let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
117
118        // Helper: query one attribute; map driver error to runtime error.
119        let attr = |a: oxicuda_driver::ffi::CUdevice_attribute| -> CudaRtResult<u32> {
120            let mut v: c_int = 0;
121            // SAFETY: FFI; driver validates the attribute enum.
122            let rc = unsafe { (api.cu_device_get_attribute)(&raw mut v, a, ordinal) };
123            if rc != 0 {
124                return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidDevice));
125            }
126            Ok(v as u32)
127        };
128
129        use oxicuda_driver::ffi::CUdevice_attribute as A;
130
131        // Device name
132        let mut name_buf = [0u8; 256];
133        // SAFETY: FFI; name_buf is valid and len matches.
134        unsafe {
135            (api.cu_device_get_name)(
136                name_buf.as_mut_ptr() as *mut std::ffi::c_char,
137                name_buf.len() as c_int,
138                ordinal,
139            );
140        }
141        let name = {
142            let nul = name_buf
143                .iter()
144                .position(|&b| b == 0)
145                .unwrap_or(name_buf.len());
146            String::from_utf8_lossy(&name_buf[..nul]).into_owned()
147        };
148
149        // Total global memory
150        let mut total_global_mem: usize = 0;
151        // SAFETY: FFI; pointer is valid.
152        unsafe {
153            (api.cu_device_total_mem_v2)(&raw mut total_global_mem, ordinal);
154        }
155
156        Ok(Self {
157            name,
158            total_global_mem,
159            shared_mem_per_block: attr(A::MaxSharedMemoryPerBlock)? as usize,
160            regs_per_block: attr(A::MaxRegistersPerBlock)?,
161            warp_size: attr(A::WarpSize)?,
162            mem_pitch: attr(A::MaxPitch)? as usize,
163            max_threads_per_block: attr(A::MaxThreadsPerBlock)?,
164            max_threads_dim: [
165                attr(A::MaxBlockDimX)?,
166                attr(A::MaxBlockDimY)?,
167                attr(A::MaxBlockDimZ)?,
168            ],
169            max_grid_size: [
170                attr(A::MaxGridDimX)?,
171                attr(A::MaxGridDimY)?,
172                attr(A::MaxGridDimZ)?,
173            ],
174            clock_rate: attr(A::ClockRate)?,
175            total_const_mem: attr(A::TotalConstantMemory)? as usize,
176            major: attr(A::ComputeCapabilityMajor)?,
177            minor: attr(A::ComputeCapabilityMinor)?,
178            texture_alignment: attr(A::TextureAlignment)? as usize,
179            texture_pitch_alignment: attr(A::TexturePitchAlignment)? as usize,
180            device_overlap: attr(A::GpuOverlap)? != 0,
181            multi_processor_count: attr(A::MultiprocessorCount)?,
182            ecc_enabled: attr(A::EccEnabled)? != 0,
183            integrated: attr(A::Integrated)? != 0,
184            can_map_host_memory: attr(A::CanMapHostMemory)? != 0,
185            unified_addressing: attr(A::UnifiedAddressing)? != 0,
186            memory_clock_rate: attr(A::MemoryClockRate)?,
187            memory_bus_width: attr(A::GlobalMemoryBusWidth)?,
188            l2_cache_size: attr(A::L2CacheSize)?,
189            max_threads_per_multi_processor: attr(A::MaxThreadsPerMultiprocessor)?,
190            stream_priorities_supported: attr(A::StreamPrioritiesSupported)? != 0,
191            shared_mem_per_multiprocessor: attr(A::MaxSharedMemoryPerMultiprocessor)? as usize,
192            regs_per_multiprocessor: attr(A::MaxRegistersPerMultiprocessor)?,
193            managed_memory: attr(A::ManagedMemory)? != 0,
194            is_multi_gpu_board: attr(A::IsMultiGpuBoard)? != 0,
195            multi_gpu_board_group_id: attr(A::MultiGpuBoardGroupId)?,
196            host_native_atomic_supported: attr(A::HostNativeAtomicSupported)? != 0,
197            cooperative_launch: attr(A::CooperativeLaunch)? != 0,
198            cooperative_multi_device_launch: attr(A::CooperativeMultiDeviceLaunch)? != 0,
199            max_blocks_per_multi_processor: attr(A::MaxBlocksPerMultiprocessor)?,
200            shared_mem_per_block_optin: attr(A::MaxSharedMemoryPerBlockOptin)? as usize,
201            cluster_launch: attr(A::ClusterLaunch)? != 0,
202        })
203    }
204}
205
206// ─── Public API ───────────────────────────────────────────────────────────────
207
208/// Returns the number of CUDA-capable devices.
209///
210/// Mirrors `cudaGetDeviceCount`.
211///
212/// # Errors
213///
214/// Returns [`CudaRtError::DriverNotAvailable`] if the CUDA driver is not
215/// installed, or [`CudaRtError::NoGpu`] on systems with zero CUDA devices.
216pub fn get_device_count() -> CudaRtResult<u32> {
217    let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
218    let mut count: c_int = 0;
219    // SAFETY: FFI; count is a valid stack-allocated i32. cuInit(0) was called
220    // during DriverApi::load(), so the driver is guaranteed to be initialised.
221    let rc = unsafe { (api.cu_device_get_count)(&raw mut count) };
222    if rc != 0 {
223        return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::NoGpu));
224    }
225    if count == 0 {
226        return Err(CudaRtError::NoGpu);
227    }
228    Ok(count as u32)
229}
230
231/// Selects `device` as the current CUDA device for the calling thread.
232///
233/// Mirrors `cudaSetDevice`.
234///
235/// # Errors
236///
237/// Returns [`CudaRtError::InvalidDevice`] if `device >= get_device_count()`,
238/// or [`CudaRtError::DriverNotAvailable`] if the driver is absent.
239pub fn set_device(device: u32) -> CudaRtResult<()> {
240    let count = get_device_count()?;
241    if device >= count {
242        return Err(CudaRtError::InvalidDevice);
243    }
244    let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
245    // Retain the primary context for this device (creates it if necessary) and
246    // make it the current context on the calling thread.  This is what the real
247    // `cudaSetDevice` does internally; without it, driver API calls that need
248    // a context (cuMemAlloc, cuLaunchKernel, …) fail with DeviceUninitialized.
249    let mut ctx = oxicuda_driver::ffi::CUcontext::default();
250    // SAFETY: FFI; device is a valid ordinal (checked above).
251    let rc = unsafe { (api.cu_device_primary_ctx_retain)(&raw mut ctx, device as c_int) };
252    if rc != 0 {
253        return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidDevice));
254    }
255    // SAFETY: ctx is a valid primary context handle.
256    let rc = unsafe { (api.cu_ctx_set_current)(ctx) };
257    if rc != 0 {
258        return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidDevice));
259    }
260    CURRENT_DEVICE.with(|cell| cell.set(Some(device as c_int)));
261    Ok(())
262}
263
264/// Returns the ordinal of the current CUDA device for the calling thread.
265///
266/// Mirrors `cudaGetDevice`.
267///
268/// # Errors
269///
270/// Returns [`CudaRtError::DeviceNotSet`] if no device has been selected.
271pub fn get_device() -> CudaRtResult<u32> {
272    CURRENT_DEVICE.with(|cell| {
273        cell.get()
274            .map(|d| d as u32)
275            .ok_or(CudaRtError::DeviceNotSet)
276    })
277}
278
279/// Returns properties of the specified device.
280///
281/// Mirrors `cudaGetDeviceProperties`.
282///
283/// # Errors
284///
285/// Propagates driver errors or returns [`CudaRtError::InvalidDevice`] for
286/// out-of-range ordinals.
287pub fn get_device_properties(device: u32) -> CudaRtResult<CudaDeviceProp> {
288    let count = get_device_count()?;
289    if device >= count {
290        return Err(CudaRtError::InvalidDevice);
291    }
292    CudaDeviceProp::from_device(device as c_int)
293}
294
295/// Blocks until all preceding tasks in the current device's context complete.
296///
297/// Mirrors `cudaDeviceSynchronize`.
298///
299/// # Errors
300///
301/// Returns [`CudaRtError::DeviceNotSet`] if no device is selected, or
302/// a driver error if synchronization fails.
303pub fn device_synchronize() -> CudaRtResult<()> {
304    let _device = get_device()?;
305    let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
306    // SAFETY: FFI; driver's current context is valid.
307    unsafe { (api.cu_ctx_synchronize)() };
308    Ok(())
309}
310
311/// Explicitly destroys and cleans up all resources associated with the current
312/// device in the current process.
313///
314/// Mirrors `cudaDeviceReset`.
315///
316/// # Errors
317///
318/// Propagates driver errors.
319pub fn device_reset() -> CudaRtResult<()> {
320    let _device = get_device()?;
321    let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
322    // The Runtime API implements reset by resetting the primary context.
323    // We obtain and then release the primary context to force a reset.
324    let mut _dev: c_int = 0;
325    CURRENT_DEVICE.with(|cell| {
326        if let Some(d) = cell.get() {
327            _dev = d;
328        }
329    });
330    // SAFETY: FFI; _dev is a valid device ordinal.
331    unsafe { (api.cu_device_primary_ctx_reset_v2)(_dev) };
332    // Forget the device binding for this thread.
333    CURRENT_DEVICE.with(|cell| cell.set(None));
334    Ok(())
335}
336
337/// Returns the compute capability as a `(major, minor)` pair.
338///
339/// Convenience helper built on top of [`get_device_properties`].
340///
341/// # Errors
342///
343/// Propagates errors from `get_device_properties`.
344pub fn get_compute_capability(device: u32) -> CudaRtResult<(u32, u32)> {
345    let props = get_device_properties(device)?;
346    Ok((props.major, props.minor))
347}
348
349// ─── Tests ───────────────────────────────────────────────────────────────────
350
#[cfg(test)]
mod tests {
    use super::*;

    /// A thread that never called `set_device` must report `DeviceNotSet`.
    #[test]
    #[cfg(not(feature = "gpu-tests"))]
    fn get_device_without_set_errors() {
        assert!(matches!(get_device(), Err(CudaRtError::DeviceNotSet)));
    }

    /// Round-trips the thread-local binding when a GPU is present; otherwise
    /// only checks that no device is reported as bound.
    #[test]
    fn set_device_persists_in_thread() {
        match get_device_count() {
            Ok(device_count) => {
                // Real GPU: bind device 0 and read the binding back.
                set_device(0).expect("set_device(0) failed");
                assert_eq!(get_device().unwrap(), 0);
                // The first ordinal past the end must be rejected.
                assert!(matches!(
                    set_device(device_count),
                    Err(CudaRtError::InvalidDevice)
                ));
            }
            Err(CudaRtError::DriverNotAvailable | CudaRtError::NoGpu) => {
                // No usable GPU: nothing can have been bound on this thread.
                assert!(get_device().is_err());
            }
            Err(other) => panic!("unexpected error: {other}"),
        }
    }

    /// `from_code` must map the device-specific status codes.
    #[test]
    fn from_code_round_trip() {
        let no_device = CudaRtError::from_code(100);
        assert_eq!(no_device, Some(CudaRtError::NoDevice));

        let invalid_device = CudaRtError::from_code(101);
        assert_eq!(invalid_device, Some(CudaRtError::InvalidDevice));
    }
}