// oxicuda_driver/multi_gpu.rs
//! Multi-GPU context management with per-device context pools.
//!
//! When working with multiple GPUs, it is common to maintain one CUDA
//! context per device and dispatch work across them.  [`DevicePool`]
//! automates context lifecycle management and provides scheduling
//! helpers (round-robin, best-available) for multi-GPU workloads.
//!
//! # Thread safety
//!
//! [`DevicePool`] is `Send + Sync`.  Each context is wrapped in an
//! [`Arc<Context>`] so it can be shared across threads.  The caller is
//! responsible for calling [`Context::set_current`] on the appropriate
//! thread before issuing driver calls.
//!
//! # Example
//!
//! ```rust,no_run
//! use oxicuda_driver::multi_gpu::DevicePool;
//!
//! oxicuda_driver::init()?;
//! let pool = DevicePool::new()?;
//! println!("managing {} devices", pool.device_count());
//!
//! for (dev, ctx) in pool.iter() {
//!     ctx.set_current()?;
//!     println!("device {}: {}", dev.ordinal(), dev.name()?);
//! }
//! # Ok::<(), oxicuda_driver::error::CudaError>(())
//! ```

use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};

use crate::context::Context;
use crate::device::Device;
use crate::error::{CudaError, CudaResult};

// ---------------------------------------------------------------------------
// DevicePool
// ---------------------------------------------------------------------------

/// Per-device context pool for multi-GPU management.
///
/// Maintains a mapping from device ordinals to contexts, with thread-safe
/// access for multi-threaded workloads.  Each device gets exactly one
/// context, created with default scheduling flags.
///
/// # Round-robin scheduling
///
/// The [`next_device`](DevicePool::next_device) method implements
/// round-robin device selection using an atomic counter, making it safe
/// to call from multiple threads without locking.
///
/// # Best-available scheduling
///
/// The [`best_available_device`](DevicePool::best_available_device) method
/// selects the device with the most total memory.  In a future release,
/// this may query free memory at runtime when the driver supports it.
pub struct DevicePool {
    /// Ordered list of (device, context) pairs.  The order here fixes both
    /// the iteration order of `iter()` and the round-robin dispatch order.
    entries: Vec<(Device, Arc<Context>)>,
    /// Monotonically increasing ticket counter for round-robin scheduling;
    /// reduced modulo `entries.len()` on each `next_device` call.
    round_robin: AtomicUsize,
}
// SAFETY: All fields are Send+Sync:
// - `entries` is a Vec of (Device, Arc<Context>); Device is Copy+Send+Sync,
//   Arc<Context> is Send (Context is Send).
// - AtomicUsize is Send+Sync.
//
// NOTE(review): the `Sync` impl lets `&DevicePool` — and therefore
// `&Arc<Context>` / `&Context` — be shared across threads, which is only
// sound if `Context` is itself `Sync`; the justification above only argues
// `Context: Send`.  Confirm against `Context`'s definition.  Conversely, if
// `Context` really is both `Send` and `Sync`, the compiler auto-derives
// `Send + Sync` for `DevicePool` and these manual unsafe impls are redundant
// and could be deleted.
unsafe impl Send for DevicePool {}
unsafe impl Sync for DevicePool {}
73impl DevicePool {
74    /// Creates a new pool with contexts for all available devices.
75    ///
76    /// Enumerates every CUDA-capable device and creates one context per
77    /// device.  The contexts are created with default scheduling flags
78    /// ([`crate::context::flags::SCHED_AUTO`]).
79    ///
80    /// # Errors
81    ///
82    /// * [`CudaError::NoDevice`] if no CUDA devices are available.
83    /// * Other driver errors from device enumeration or context creation.
84    pub fn new() -> CudaResult<Self> {
85        let devices = crate::device::list_devices()?;
86        if devices.is_empty() {
87            return Err(CudaError::NoDevice);
88        }
89        Self::with_devices(&devices)
90    }
91
92    /// Creates a pool with contexts for specific devices.
93    ///
94    /// One context is created per device in the provided slice.  The
95    /// ordering in the slice determines the iteration and round-robin
96    /// order.
97    ///
98    /// # Errors
99    ///
100    /// * [`CudaError::InvalidValue`] if the device slice is empty.
101    /// * Other driver errors from context creation.
102    pub fn with_devices(devices: &[Device]) -> CudaResult<Self> {
103        if devices.is_empty() {
104            return Err(CudaError::InvalidValue);
105        }
106        let mut entries = Vec::with_capacity(devices.len());
107        for dev in devices {
108            let ctx = Context::new(dev)?;
109            entries.push((*dev, Arc::new(ctx)));
110        }
111        Ok(Self {
112            entries,
113            round_robin: AtomicUsize::new(0),
114        })
115    }
116
117    /// Returns the context for the given device ordinal.
118    ///
119    /// Searches the pool for a device whose ordinal matches the given
120    /// value.
121    ///
122    /// # Errors
123    ///
124    /// Returns [`CudaError::InvalidDevice`] if no device with the given
125    /// ordinal is in the pool.
126    pub fn context(&self, device_ordinal: i32) -> CudaResult<&Arc<Context>> {
127        self.entries
128            .iter()
129            .find(|(dev, _)| dev.ordinal() == device_ordinal)
130            .map(|(_, ctx)| ctx)
131            .ok_or(CudaError::InvalidDevice)
132    }
133
134    /// Returns the number of devices in the pool.
135    #[inline]
136    pub fn device_count(&self) -> usize {
137        self.entries.len()
138    }
139
140    /// Returns the device with the most total memory.
141    ///
142    /// This is a heuristic for selecting the "best" device when you want
143    /// to maximise available memory.  For real-time free-memory queries,
144    /// use `cuMemGetInfo` (once it is wired into the driver API).
145    ///
146    /// # Errors
147    ///
148    /// Returns an error if memory queries fail.
149    pub fn best_available_device(&self) -> CudaResult<Device> {
150        let mut best_dev = self.entries[0].0;
151        let mut best_mem: usize = 0;
152        for (dev, _ctx) in &self.entries {
153            let mem = dev.total_memory()?;
154            if mem > best_mem {
155                best_mem = mem;
156                best_dev = *dev;
157            }
158        }
159        Ok(best_dev)
160    }
161
162    /// Selects a device using round-robin scheduling.
163    ///
164    /// Each call advances an internal atomic counter and returns the
165    /// next device in sequence.  This is safe to call concurrently from
166    /// multiple threads.
167    ///
168    /// # Errors
169    ///
170    /// This method is infallible for a properly constructed pool, but
171    /// returns `CudaResult` for API consistency.
172    pub fn next_device(&self) -> CudaResult<Device> {
173        let idx = self.round_robin.fetch_add(1, Ordering::Relaxed) % self.entries.len();
174        Ok(self.entries[idx].0)
175    }
176
177    /// Iterates over all (device, context) pairs in pool order.
178    pub fn iter(&self) -> impl Iterator<Item = (&Device, &Arc<Context>)> {
179        self.entries.iter().map(|(dev, ctx)| (dev, ctx))
180    }
181
182    /// Returns the context for the device at the given pool index.
183    ///
184    /// Pool indices are 0-based and correspond to the order in which
185    /// devices were added to the pool.
186    ///
187    /// # Errors
188    ///
189    /// Returns [`CudaError::InvalidValue`] if the index is out of bounds.
190    pub fn context_at(&self, index: usize) -> CudaResult<&Arc<Context>> {
191        self.entries
192            .get(index)
193            .map(|(_, ctx)| ctx)
194            .ok_or(CudaError::InvalidValue)
195    }
196
197    /// Returns the device at the given pool index.
198    ///
199    /// # Errors
200    ///
201    /// Returns [`CudaError::InvalidValue`] if the index is out of bounds.
202    pub fn device_at(&self, index: usize) -> CudaResult<Device> {
203        self.entries
204            .get(index)
205            .map(|(dev, _)| *dev)
206            .ok_or(CudaError::InvalidValue)
207    }
208}
210impl std::fmt::Debug for DevicePool {
211    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
212        f.debug_struct("DevicePool")
213            .field("device_count", &self.entries.len())
214            .field(
215                "devices",
216                &self
217                    .entries
218                    .iter()
219                    .map(|(d, _)| d.ordinal())
220                    .collect::<Vec<_>>(),
221            )
222            .finish()
223    }
224}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    // These tests must run without a CUDA driver (e.g. on macOS), so they
    // cover error paths and driver-independent logic only.

    #[test]
    fn pool_with_empty_devices_returns_error() {
        let outcome = DevicePool::with_devices(&[]);
        assert!(outcome.is_err());
        assert_eq!(outcome.err(), Some(CudaError::InvalidValue));
    }

    #[test]
    fn pool_new_returns_error_without_driver() {
        // With no driver installed `new()` errors; with one, it succeeds.
        // Either outcome is acceptable — we only require that it not panic.
        let _ = DevicePool::new();
    }

    #[test]
    fn device_pool_debug_format() {
        // Smoke-check Debug formatting of a placeholder; a real DevicePool
        // cannot be constructed without a driver.
        let rendered = format!("{:?}", "DevicePool placeholder");
        assert!(!rendered.is_empty());
    }

    #[test]
    fn round_robin_counter_wraps() {
        // Exercise the modulo-wrap arithmetic in isolation.
        let ticket = AtomicUsize::new(0);
        let size = 3;
        for want in [0, 1, 2, 0, 1, 2, 0] {
            assert_eq!(ticket.fetch_add(1, Ordering::Relaxed) % size, want);
        }
    }

    #[test]
    fn round_robin_single_device() {
        let ticket = AtomicUsize::new(0);
        let size = 1;
        for _ in 0..10 {
            assert_eq!(ticket.fetch_add(1, Ordering::Relaxed) % size, 0);
        }
    }

    #[test]
    fn context_at_out_of_bounds_returns_error() {
        // A DevicePool needs a GPU to build, so we only confirm that the
        // error variant used for out-of-bounds access exists and maps to
        // the expected raw code.
        assert_eq!(CudaError::InvalidValue.as_raw(), 1);
    }

    #[cfg(feature = "gpu-tests")]
    #[test]
    fn pool_with_real_devices() {
        crate::init().ok();
        if let Ok(pool) = DevicePool::new() {
            assert!(pool.device_count() > 0);
            let dev = pool.next_device().expect("next_device failed");
            assert!(pool.context(dev.ordinal()).is_ok());
            let best = pool.best_available_device().expect("best_available failed");
            assert!(best.total_memory().is_ok());
            // Every device in the pool should answer basic queries.
            for (d, _c) in pool.iter() {
                assert!(d.name().is_ok());
            }
        }
    }
}