Skip to main content

baracuda_driver/
device.rs

1//! Physical-GPU query and enumeration.
2
3use core::ffi::c_char;
4
5use baracuda_cuda_sys::types::CUdevice_attribute as Attr;
6use baracuda_cuda_sys::{driver, CUdevice};
7
8use crate::error::{check, Result};
9use crate::init::init;
10
11/// A CUDA device (a physical GPU, or a logical slice of one under MIG).
12///
13/// Cheap `Copy` type — it's just an ordinal.
14#[derive(Copy, Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash)]
15pub struct Device(pub(crate) CUdevice);
16
17impl Device {
18    /// Number of CUDA devices visible to the process.
19    pub fn count() -> Result<u32> {
20        init()?;
21        let d = driver()?;
22        let cu = d.cu_device_get_count()?;
23        let mut n: core::ffi::c_int = 0;
24        // SAFETY: `out` points to a writable i32.
25        check(unsafe { cu(&mut n) })?;
26        Ok(n as u32)
27    }
28
29    /// Retrieve the device with the given ordinal.
30    pub fn get(ordinal: u32) -> Result<Self> {
31        init()?;
32        let d = driver()?;
33        let cu = d.cu_device_get()?;
34        let mut dev = CUdevice::default();
35        // SAFETY: `dev` points to a writable CUdevice; the cast is widening on 64-bit.
36        check(unsafe { cu(&mut dev, ordinal as core::ffi::c_int) })?;
37        Ok(Self(dev))
38    }
39
40    /// All visible devices, in ordinal order.
41    pub fn all() -> Result<Vec<Self>> {
42        let count = Self::count()?;
43        (0..count).map(Self::get).collect()
44    }
45
46    /// Raw ordinal (`0`, `1`, ...).
47    #[inline]
48    pub fn ordinal(&self) -> i32 {
49        self.0 .0
50    }
51
52    /// Human-readable name, e.g. `"NVIDIA GeForce RTX 4090"`.
53    pub fn name(&self) -> Result<String> {
54        let d = driver()?;
55        let cu = d.cu_device_get_name()?;
56        let mut buf = vec![0u8; 256];
57        // SAFETY: `buf` is valid for writes of `buf.len()` bytes; the
58        // function is documented to null-terminate within the buffer.
59        check(unsafe {
60            cu(
61                buf.as_mut_ptr() as *mut c_char,
62                buf.len() as core::ffi::c_int,
63                self.0,
64            )
65        })?;
66        let nul = buf.iter().position(|&b| b == 0).unwrap_or(buf.len());
67        Ok(String::from_utf8_lossy(&buf[..nul]).into_owned())
68    }
69
70    /// Total global memory on this device, in bytes.
71    pub fn total_memory(&self) -> Result<u64> {
72        let d = driver()?;
73        let cu = d.cu_device_total_mem()?;
74        let mut bytes: usize = 0;
75        // SAFETY: writable pointer to `usize`; CUDA writes `size_t`.
76        check(unsafe { cu(&mut bytes, self.0) })?;
77        Ok(bytes as u64)
78    }
79
80    /// Compute capability as `(major, minor)`, e.g. `(9, 0)` for Hopper.
81    pub fn compute_capability(&self) -> Result<(u32, u32)> {
82        let major = self.attribute(Attr::COMPUTE_CAPABILITY_MAJOR)?;
83        let minor = self.attribute(Attr::COMPUTE_CAPABILITY_MINOR)?;
84        Ok((major as u32, minor as u32))
85    }
86
87    /// Multiprocessor count (SM count).
88    pub fn multiprocessor_count(&self) -> Result<u32> {
89        Ok(self.attribute(Attr::MULTIPROCESSOR_COUNT)? as u32)
90    }
91
92    /// Warp size in threads (almost always 32).
93    pub fn warp_size(&self) -> Result<u32> {
94        Ok(self.attribute(Attr::WARP_SIZE)? as u32)
95    }
96
97    /// Query an arbitrary `CUdevice_attribute`. See
98    /// [`baracuda_cuda_sys::types::CUdevice_attribute`] for the full list.
99    pub fn attribute(&self, attr: i32) -> Result<i32> {
100        let d = driver()?;
101        let cu = d.cu_device_get_attribute()?;
102        let mut val: core::ffi::c_int = 0;
103        // SAFETY: writable i32; `attr` is a valid CUDA attribute selector
104        // (caller supplied, but CUDA returns an error for invalid selectors
105        // rather than UB).
106        check(unsafe { cu(&mut val, attr, self.0) })?;
107        Ok(val)
108    }
109
110    /// The raw `CUdevice` handle. Use with care.
111    #[inline]
112    pub fn as_raw(&self) -> CUdevice {
113        self.0
114    }
115
116    /// Return the device's 16-byte UUID.
117    pub fn uuid(&self) -> Result<[u8; 16]> {
118        let d = driver()?;
119        let cu = d.cu_device_get_uuid()?;
120        let mut out = [0u8; 16];
121        check(unsafe { cu(out.as_mut_ptr(), self.0) })?;
122        Ok(out)
123    }
124
125    /// Return the device's Windows LUID and 32-bit device-node mask
126    /// (Windows only; Linux returns zeros).
127    pub fn luid(&self) -> Result<([u8; 8], u32)> {
128        let d = driver()?;
129        let cu = d.cu_device_get_luid()?;
130        let mut luid = [0i8; 8];
131        let mut mask: core::ffi::c_uint = 0;
132        check(unsafe { cu(luid.as_mut_ptr(), &mut mask, self.0) })?;
133        Ok((luid.map(|b| b as u8), mask))
134    }
135
136    /// Query a peer-to-peer attribute between `self` (as source) and
137    /// `peer` (as destination). Pass a constant from
138    /// [`baracuda_cuda_sys::types::CUdevice_P2PAttribute`].
139    pub fn p2p_attribute(&self, peer: &Device, attr: i32) -> Result<i32> {
140        let d = driver()?;
141        let cu = d.cu_device_get_p2p_attribute()?;
142        let mut v: core::ffi::c_int = 0;
143        check(unsafe { cu(&mut v, attr, self.0, peer.0) })?;
144        Ok(v)
145    }
146
147    /// Query whether this device supports a given exec-affinity type
148    /// (e.g. SM-count partitioning at context-creation time).
149    pub fn exec_affinity_support(&self, affinity_type: i32) -> Result<bool> {
150        let d = driver()?;
151        let cu = d.cu_device_get_exec_affinity_support()?;
152        let mut v: core::ffi::c_int = 0;
153        check(unsafe { cu(&mut v, affinity_type, self.0) })?;
154        Ok(v != 0)
155    }
156
157    /// `true` if this device can directly access allocations on `peer`.
158    /// Peer access still requires a matching `Context::enable_peer_access`
159    /// on the accessing side before kernels can dereference peer pointers.
160    pub fn can_access_peer(&self, peer: &Device) -> Result<bool> {
161        let d = driver()?;
162        let cu = d.cu_device_can_access_peer()?;
163        let mut out: core::ffi::c_int = 0;
164        check(unsafe { cu(&mut out, self.0, peer.0) })?;
165        Ok(out != 0)
166    }
167
168    /// Query the primary-context state for this device.
169    /// Returns `(flags, active)` — `flags` is the same bitmask
170    /// [`crate::Context::with_flags`] takes, `active` is `true` if some
171    /// caller currently holds a retained primary-context reference.
172    pub fn primary_ctx_state(&self) -> Result<(u32, bool)> {
173        let d = driver()?;
174        let cu = d.cu_device_primary_ctx_get_state()?;
175        let mut flags: core::ffi::c_uint = 0;
176        let mut active: core::ffi::c_int = 0;
177        check(unsafe { cu(self.0, &mut flags, &mut active) })?;
178        Ok((flags, active != 0))
179    }
180
181    /// Set the flags used when the primary context is later created.
182    /// Returns `CUDA_ERROR_PRIMARY_CONTEXT_ACTIVE` if the primary context
183    /// already exists; reset with `Context::reset_primary` first.
184    pub fn set_primary_ctx_flags(&self, flags: u32) -> Result<()> {
185        let d = driver()?;
186        let cu = d.cu_device_primary_ctx_set_flags()?;
187        check(unsafe { cu(self.0, flags) })
188    }
189}