// oxicuda_runtime/device.rs
1//! Device management — `cudaGetDeviceCount`, `cudaSetDevice`, `cudaGetDevice`,
2//! `cudaGetDeviceProperties`, `cudaDeviceSynchronize`, `cudaDeviceReset`.
3//!
4//! The CUDA Runtime maintains a *per-thread* current device. This module
5//! stores that state in a thread-local and exposes the standard Runtime API
6//! surface on top of the underlying driver API.
7
8use std::cell::Cell;
9use std::ffi::c_int;
10
11use oxicuda_driver::loader::try_driver;
12
13use crate::error::{CudaRtError, CudaRtResult};
14
15// ─── Thread-local current device ─────────────────────────────────────────────
16
thread_local! {
    /// Device ordinal bound to the calling thread, mirroring the CUDA
    /// Runtime's *per-thread* "current device" state.
    /// `None` means no device has been selected yet via `set_device`.
    static CURRENT_DEVICE: Cell<Option<c_int>> = const { Cell::new(None) };
}
22
23// ─── cudaDeviceProp ───────────────────────────────────────────────────────────
24
/// Subset of `cudaDeviceProp` exposed by the Runtime API.
///
/// Field names intentionally match the CUDA Runtime documentation so that code
/// written against `cudaGetDeviceProperties` compiles with minimal changes.
/// All fields are populated from `cuDeviceGetAttribute` queries except `name`
/// (`cuDeviceGetName`) and `total_global_mem` (`cuDeviceTotalMem`).
#[derive(Debug, Clone)]
pub struct CudaDeviceProp {
    /// ASCII name of the device, e.g. `"NVIDIA GeForce RTX 4090"`.
    pub name: String,
    /// Total amount of global memory in bytes.
    pub total_global_mem: usize,
    /// Shared memory per block in bytes.
    pub shared_mem_per_block: usize,
    /// 32-bit registers per block.
    pub regs_per_block: u32,
    /// Warp size (threads per warp).
    pub warp_size: u32,
    /// Maximum pitch in bytes for `cudaMallocPitch`.
    pub mem_pitch: usize,
    /// Maximum number of threads per block.
    pub max_threads_per_block: u32,
    /// Maximum size of each dimension of a block `[x, y, z]`.
    pub max_threads_dim: [u32; 3],
    /// Maximum size of each dimension of a grid `[x, y, z]`.
    pub max_grid_size: [u32; 3],
    /// Clock frequency in kilohertz.
    pub clock_rate: u32,
    /// Total constant memory available on device in bytes.
    pub total_const_mem: usize,
    /// Major revision number of device's compute capability.
    pub major: u32,
    /// Minor revision number of device's compute capability.
    pub minor: u32,
    /// Alignment requirement for textures (in bytes).
    pub texture_alignment: usize,
    /// Pitch alignment requirement for texture references (in bytes).
    pub texture_pitch_alignment: usize,
    /// `true` if device can concurrently copy and execute a kernel.
    pub device_overlap: bool,
    /// Number of multiprocessors on device.
    pub multi_processor_count: u32,
    /// `true` if device has ECC support enabled.
    pub ecc_enabled: bool,
    /// `true` if device is an integrated (on-chip) GPU.
    pub integrated: bool,
    /// `true` if device can map host memory.
    pub can_map_host_memory: bool,
    /// `true` if device supports unified virtual addressing.
    pub unified_addressing: bool,
    /// Peak memory clock frequency in kilohertz.
    pub memory_clock_rate: u32,
    /// Global memory bus width in bits.
    pub memory_bus_width: u32,
    /// Size of the L2 cache in bytes (0 if not applicable).
    pub l2_cache_size: u32,
    /// Maximum number of resident threads per multiprocessor.
    pub max_threads_per_multi_processor: u32,
    /// Device supports stream priorities.
    pub stream_priorities_supported: bool,
    /// Shared memory per multiprocessor in bytes.
    pub shared_mem_per_multiprocessor: usize,
    /// 32-bit registers per multiprocessor.
    pub regs_per_multiprocessor: u32,
    /// Device supports allocating managed memory.
    pub managed_memory: bool,
    /// Device is on a multi-GPU board.
    pub is_multi_gpu_board: bool,
    /// Unique identifier for a group of devices on the same multi-GPU board.
    pub multi_gpu_board_group_id: u32,
    /// Link between the device and the host supports native atomic operations.
    pub host_native_atomic_supported: bool,
    /// `true` if the device supports Cooperative Launch.
    pub cooperative_launch: bool,
    /// `true` if the device supports Multi-Device Cooperative Launch.
    pub cooperative_multi_device_launch: bool,
    /// Maximum number of blocks per multiprocessor.
    pub max_blocks_per_multi_processor: u32,
    /// Maximum shared memory per block usable via the special opt-in
    /// (`cudaFuncAttributeMaxDynamicSharedMemorySize`); CUDA's
    /// `sharedMemPerBlockOptin`.
    pub shared_mem_per_block_optin: usize,
    /// `true` if device supports cluster launch.
    pub cluster_launch: bool,
}
106
107impl CudaDeviceProp {
108 /// Construct a `CudaDeviceProp` by querying device attributes via the
109 /// CUDA Driver API.
110 ///
111 /// # Errors
112 ///
113 /// Returns an error if the driver is not loaded or any attribute query
114 /// fails.
115 pub fn from_device(ordinal: c_int) -> CudaRtResult<Self> {
116 let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
117
118 // Helper: query one attribute; map driver error to runtime error.
119 let attr = |a: oxicuda_driver::ffi::CUdevice_attribute| -> CudaRtResult<u32> {
120 let mut v: c_int = 0;
121 // SAFETY: FFI; driver validates the attribute enum.
122 let rc = unsafe { (api.cu_device_get_attribute)(&raw mut v, a, ordinal) };
123 if rc != 0 {
124 return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidDevice));
125 }
126 Ok(v as u32)
127 };
128
129 use oxicuda_driver::ffi::CUdevice_attribute as A;
130
131 // Device name
132 let mut name_buf = [0u8; 256];
133 // SAFETY: FFI; name_buf is valid and len matches.
134 unsafe {
135 (api.cu_device_get_name)(
136 name_buf.as_mut_ptr() as *mut std::ffi::c_char,
137 name_buf.len() as c_int,
138 ordinal,
139 );
140 }
141 let name = {
142 let nul = name_buf
143 .iter()
144 .position(|&b| b == 0)
145 .unwrap_or(name_buf.len());
146 String::from_utf8_lossy(&name_buf[..nul]).into_owned()
147 };
148
149 // Total global memory
150 let mut total_global_mem: usize = 0;
151 // SAFETY: FFI; pointer is valid.
152 unsafe {
153 (api.cu_device_total_mem_v2)(&raw mut total_global_mem, ordinal);
154 }
155
156 Ok(Self {
157 name,
158 total_global_mem,
159 shared_mem_per_block: attr(A::MaxSharedMemoryPerBlock)? as usize,
160 regs_per_block: attr(A::MaxRegistersPerBlock)?,
161 warp_size: attr(A::WarpSize)?,
162 mem_pitch: attr(A::MaxPitch)? as usize,
163 max_threads_per_block: attr(A::MaxThreadsPerBlock)?,
164 max_threads_dim: [
165 attr(A::MaxBlockDimX)?,
166 attr(A::MaxBlockDimY)?,
167 attr(A::MaxBlockDimZ)?,
168 ],
169 max_grid_size: [
170 attr(A::MaxGridDimX)?,
171 attr(A::MaxGridDimY)?,
172 attr(A::MaxGridDimZ)?,
173 ],
174 clock_rate: attr(A::ClockRate)?,
175 total_const_mem: attr(A::TotalConstantMemory)? as usize,
176 major: attr(A::ComputeCapabilityMajor)?,
177 minor: attr(A::ComputeCapabilityMinor)?,
178 texture_alignment: attr(A::TextureAlignment)? as usize,
179 texture_pitch_alignment: attr(A::TexturePitchAlignment)? as usize,
180 device_overlap: attr(A::GpuOverlap)? != 0,
181 multi_processor_count: attr(A::MultiprocessorCount)?,
182 ecc_enabled: attr(A::EccEnabled)? != 0,
183 integrated: attr(A::Integrated)? != 0,
184 can_map_host_memory: attr(A::CanMapHostMemory)? != 0,
185 unified_addressing: attr(A::UnifiedAddressing)? != 0,
186 memory_clock_rate: attr(A::MemoryClockRate)?,
187 memory_bus_width: attr(A::GlobalMemoryBusWidth)?,
188 l2_cache_size: attr(A::L2CacheSize)?,
189 max_threads_per_multi_processor: attr(A::MaxThreadsPerMultiprocessor)?,
190 stream_priorities_supported: attr(A::StreamPrioritiesSupported)? != 0,
191 shared_mem_per_multiprocessor: attr(A::MaxSharedMemoryPerMultiprocessor)? as usize,
192 regs_per_multiprocessor: attr(A::MaxRegistersPerMultiprocessor)?,
193 managed_memory: attr(A::ManagedMemory)? != 0,
194 is_multi_gpu_board: attr(A::IsMultiGpuBoard)? != 0,
195 multi_gpu_board_group_id: attr(A::MultiGpuBoardGroupId)?,
196 host_native_atomic_supported: attr(A::HostNativeAtomicSupported)? != 0,
197 cooperative_launch: attr(A::CooperativeLaunch)? != 0,
198 cooperative_multi_device_launch: attr(A::CooperativeMultiDeviceLaunch)? != 0,
199 max_blocks_per_multi_processor: attr(A::MaxBlocksPerMultiprocessor)?,
200 shared_mem_per_block_optin: attr(A::MaxSharedMemoryPerBlockOptin)? as usize,
201 cluster_launch: attr(A::ClusterLaunch)? != 0,
202 })
203 }
204}
205
206// ─── Public API ───────────────────────────────────────────────────────────────
207
208/// Returns the number of CUDA-capable devices.
209///
210/// Mirrors `cudaGetDeviceCount`.
211///
212/// # Errors
213///
214/// Returns [`CudaRtError::DriverNotAvailable`] if the CUDA driver is not
215/// installed, or [`CudaRtError::NoGpu`] on systems with zero CUDA devices.
216pub fn get_device_count() -> CudaRtResult<u32> {
217 let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
218 let mut count: c_int = 0;
219 // SAFETY: FFI; count is a valid stack-allocated i32. cuInit(0) was called
220 // during DriverApi::load(), so the driver is guaranteed to be initialised.
221 let rc = unsafe { (api.cu_device_get_count)(&raw mut count) };
222 if rc != 0 {
223 return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::NoGpu));
224 }
225 if count == 0 {
226 return Err(CudaRtError::NoGpu);
227 }
228 Ok(count as u32)
229}
230
231/// Selects `device` as the current CUDA device for the calling thread.
232///
233/// Mirrors `cudaSetDevice`.
234///
235/// # Errors
236///
237/// Returns [`CudaRtError::InvalidDevice`] if `device >= get_device_count()`,
238/// or [`CudaRtError::DriverNotAvailable`] if the driver is absent.
239pub fn set_device(device: u32) -> CudaRtResult<()> {
240 let count = get_device_count()?;
241 if device >= count {
242 return Err(CudaRtError::InvalidDevice);
243 }
244 let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
245 // Retain the primary context for this device (creates it if necessary) and
246 // make it the current context on the calling thread. This is what the real
247 // `cudaSetDevice` does internally; without it, driver API calls that need
248 // a context (cuMemAlloc, cuLaunchKernel, …) fail with DeviceUninitialized.
249 let mut ctx = oxicuda_driver::ffi::CUcontext::default();
250 // SAFETY: FFI; device is a valid ordinal (checked above).
251 let rc = unsafe { (api.cu_device_primary_ctx_retain)(&raw mut ctx, device as c_int) };
252 if rc != 0 {
253 return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidDevice));
254 }
255 // SAFETY: ctx is a valid primary context handle.
256 let rc = unsafe { (api.cu_ctx_set_current)(ctx) };
257 if rc != 0 {
258 return Err(CudaRtError::from_code(rc).unwrap_or(CudaRtError::InvalidDevice));
259 }
260 CURRENT_DEVICE.with(|cell| cell.set(Some(device as c_int)));
261 Ok(())
262}
263
264/// Returns the ordinal of the current CUDA device for the calling thread.
265///
266/// Mirrors `cudaGetDevice`.
267///
268/// # Errors
269///
270/// Returns [`CudaRtError::DeviceNotSet`] if no device has been selected.
271pub fn get_device() -> CudaRtResult<u32> {
272 CURRENT_DEVICE.with(|cell| {
273 cell.get()
274 .map(|d| d as u32)
275 .ok_or(CudaRtError::DeviceNotSet)
276 })
277}
278
279/// Returns properties of the specified device.
280///
281/// Mirrors `cudaGetDeviceProperties`.
282///
283/// # Errors
284///
285/// Propagates driver errors or returns [`CudaRtError::InvalidDevice`] for
286/// out-of-range ordinals.
287pub fn get_device_properties(device: u32) -> CudaRtResult<CudaDeviceProp> {
288 let count = get_device_count()?;
289 if device >= count {
290 return Err(CudaRtError::InvalidDevice);
291 }
292 CudaDeviceProp::from_device(device as c_int)
293}
294
295/// Blocks until all preceding tasks in the current device's context complete.
296///
297/// Mirrors `cudaDeviceSynchronize`.
298///
299/// # Errors
300///
301/// Returns [`CudaRtError::DeviceNotSet`] if no device is selected, or
302/// a driver error if synchronization fails.
303pub fn device_synchronize() -> CudaRtResult<()> {
304 let _device = get_device()?;
305 let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
306 // SAFETY: FFI; driver's current context is valid.
307 unsafe { (api.cu_ctx_synchronize)() };
308 Ok(())
309}
310
311/// Explicitly destroys and cleans up all resources associated with the current
312/// device in the current process.
313///
314/// Mirrors `cudaDeviceReset`.
315///
316/// # Errors
317///
318/// Propagates driver errors.
319pub fn device_reset() -> CudaRtResult<()> {
320 let _device = get_device()?;
321 let api = try_driver().map_err(|_| CudaRtError::DriverNotAvailable)?;
322 // The Runtime API implements reset by resetting the primary context.
323 // We obtain and then release the primary context to force a reset.
324 let mut _dev: c_int = 0;
325 CURRENT_DEVICE.with(|cell| {
326 if let Some(d) = cell.get() {
327 _dev = d;
328 }
329 });
330 // SAFETY: FFI; _dev is a valid device ordinal.
331 unsafe { (api.cu_device_primary_ctx_reset_v2)(_dev) };
332 // Forget the device binding for this thread.
333 CURRENT_DEVICE.with(|cell| cell.set(None));
334 Ok(())
335}
336
337/// Returns the compute capability as a `(major, minor)` pair.
338///
339/// Convenience helper built on top of [`get_device_properties`].
340///
341/// # Errors
342///
343/// Propagates errors from `get_device_properties`.
344pub fn get_compute_capability(device: u32) -> CudaRtResult<(u32, u32)> {
345 let props = get_device_properties(device)?;
346 Ok((props.major, props.minor))
347}
348
349// ─── Tests ───────────────────────────────────────────────────────────────────
350
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[cfg(not(feature = "gpu-tests"))]
    fn get_device_without_set_errors() {
        // Each test runs on its own thread, so no device is bound here yet.
        assert!(matches!(get_device(), Err(CudaRtError::DeviceNotSet)));
    }

    #[test]
    fn set_device_persists_in_thread() {
        // Without guaranteed hardware we branch on the environment: the
        // thread-local round-trip is only verifiable with a real driver.
        match get_device_count() {
            Err(CudaRtError::DriverNotAvailable | CudaRtError::NoGpu) => {
                // No GPU environment — only verify DeviceNotSet.
                assert!(get_device().is_err());
            }
            Ok(n) => {
                // Real GPU: set device 0 and verify round-trip.
                set_device(0).expect("set_device(0) failed");
                assert_eq!(get_device().unwrap(), 0);
                // Out-of-range device must fail.
                assert!(matches!(set_device(n), Err(CudaRtError::InvalidDevice)));
            }
            Err(e) => panic!("unexpected error: {e}"),
        }
    }

    #[test]
    fn from_code_round_trip() {
        // Device-specific Runtime codes: 100 = cudaErrorNoDevice,
        // 101 = cudaErrorInvalidDevice.
        assert_eq!(CudaRtError::from_code(100), Some(CudaRtError::NoDevice));
        assert_eq!(CudaRtError::from_code(101), Some(CudaRtError::InvalidDevice));
    }
}
394}