Skip to main content

cuda_rust_wasm/runtime/
multi_gpu.rs

1//! Multi-GPU support for device enumeration and peer-to-peer operations
2//!
3//! Provides multi-device management, peer-to-peer memory access,
4//! and workload distribution across multiple GPU devices.
5
6use crate::{Result, runtime_error};
7use crate::runtime::device::{Device, DeviceProperties, BackendType};
8use std::sync::Arc;
9
10/// Multi-GPU context for managing multiple devices
11pub struct MultiGpuContext {
12    /// Available devices
13    devices: Vec<Arc<Device>>,
14    /// Active device index
15    active_device: usize,
16    /// Peer access matrix (devices[i] can access devices[j])
17    peer_access: Vec<Vec<bool>>,
18}
19
20impl MultiGpuContext {
21    /// Create a multi-GPU context by enumerating all available devices
22    pub fn new() -> Result<Self> {
23        let mut devices = Vec::new();
24
25        // Get the default device (always available)
26        let default_device = Device::get_default()?;
27        devices.push(default_device);
28
29        // Probe for additional devices based on backend
30        // In a real implementation, this would enumerate CUDA devices
31        // via cuDeviceGetCount or hipGetDeviceCount
32        let additional = Self::probe_additional_devices();
33        devices.extend(additional);
34
35        let device_count = devices.len();
36        let peer_access = vec![vec![false; device_count]; device_count];
37
38        let mut ctx = Self {
39            devices,
40            active_device: 0,
41            peer_access,
42        };
43
44        // Enable peer access where supported (same backend type)
45        ctx.setup_peer_access();
46        Ok(ctx)
47    }
48
49    /// Get number of available devices
50    pub fn device_count(&self) -> usize {
51        self.devices.len()
52    }
53
54    /// Get a device by index
55    pub fn device(&self, index: usize) -> Result<&Arc<Device>> {
56        self.devices.get(index).ok_or_else(|| {
57            runtime_error!("Device index {} out of range (have {})", index, self.devices.len())
58        })
59    }
60
61    /// Get the active device
62    pub fn active_device(&self) -> &Arc<Device> {
63        &self.devices[self.active_device]
64    }
65
66    /// Get the active device index
67    pub fn active_device_index(&self) -> usize {
68        self.active_device
69    }
70
71    /// Set the active device
72    pub fn set_device(&mut self, index: usize) -> Result<()> {
73        if index >= self.devices.len() {
74            return Err(runtime_error!(
75                "Device index {} out of range (have {})",
76                index, self.devices.len()
77            ));
78        }
79        self.active_device = index;
80        Ok(())
81    }
82
83    /// Check if peer access is enabled between two devices
84    pub fn can_access_peer(&self, src: usize, dst: usize) -> Result<bool> {
85        if src >= self.devices.len() || dst >= self.devices.len() {
86            return Err(runtime_error!("Device index out of range"));
87        }
88        Ok(self.peer_access[src][dst])
89    }
90
91    /// Enable peer access between two devices
92    pub fn enable_peer_access(&mut self, src: usize, dst: usize) -> Result<()> {
93        if src >= self.devices.len() || dst >= self.devices.len() {
94            return Err(runtime_error!("Device index out of range"));
95        }
96        if src == dst {
97            return Ok(()); // Self-access is always allowed
98        }
99
100        // Check backend compatibility
101        let src_backend = self.devices[src].backend();
102        let dst_backend = self.devices[dst].backend();
103        if src_backend != dst_backend {
104            return Err(runtime_error!(
105                "Cannot enable peer access between different backends ({:?} and {:?})",
106                src_backend, dst_backend
107            ));
108        }
109
110        self.peer_access[src][dst] = true;
111        self.peer_access[dst][src] = true;
112        Ok(())
113    }
114
115    /// Disable peer access between two devices
116    pub fn disable_peer_access(&mut self, src: usize, dst: usize) -> Result<()> {
117        if src >= self.devices.len() || dst >= self.devices.len() {
118            return Err(runtime_error!("Device index out of range"));
119        }
120        self.peer_access[src][dst] = false;
121        self.peer_access[dst][src] = false;
122        Ok(())
123    }
124
125    /// Get properties for all devices
126    pub fn all_properties(&self) -> Vec<&DeviceProperties> {
127        self.devices.iter().map(|d| d.properties()).collect()
128    }
129
130    /// Distribute a 1D range across all devices (simple round-robin)
131    pub fn distribute_range(&self, total: usize) -> Vec<DeviceRange> {
132        let n = self.devices.len();
133        let chunk = total / n;
134        let remainder = total % n;
135
136        let mut ranges = Vec::with_capacity(n);
137        let mut offset = 0;
138
139        for i in 0..n {
140            let len = chunk + if i < remainder { 1 } else { 0 };
141            ranges.push(DeviceRange {
142                device_index: i,
143                offset,
144                length: len,
145            });
146            offset += len;
147        }
148
149        ranges
150    }
151
152    /// Probe for additional GPU devices beyond the default
153    fn probe_additional_devices() -> Vec<Arc<Device>> {
154        // Probe nvidia-smi for multi-GPU systems
155        if let Ok(output) = std::process::Command::new("nvidia-smi")
156            .args(["--query-gpu=count", "--format=csv,noheader,nounits"])
157            .output()
158        {
159            if output.status.success() {
160                let stdout = String::from_utf8_lossy(&output.stdout);
161                if let Ok(count) = stdout.trim().parse::<usize>() {
162                    if count > 1 {
163                        // Return additional virtual devices
164                        let mut additional = Vec::new();
165                        for id in 1..count {
166                            if let Ok(dev) = Device::get_by_id(id) {
167                                additional.push(dev);
168                            }
169                        }
170                        return additional;
171                    }
172                }
173            }
174        }
175        Vec::new()
176    }
177
178    /// Setup peer access based on backend compatibility
179    fn setup_peer_access(&mut self) {
180        let n = self.devices.len();
181        for i in 0..n {
182            self.peer_access[i][i] = true; // Self-access always allowed
183            for j in (i + 1)..n {
184                let same_backend = self.devices[i].backend() == self.devices[j].backend();
185                if same_backend {
186                    self.peer_access[i][j] = true;
187                    self.peer_access[j][i] = true;
188                }
189            }
190        }
191    }
192}
193
194impl Default for MultiGpuContext {
195    fn default() -> Self {
196        Self::new().unwrap_or_else(|_| {
197            // Fallback: single device with no peer access
198            Self {
199                devices: vec![Device::get_default().expect("default device should be available")],
200                active_device: 0,
201                peer_access: vec![vec![true]],
202            }
203        })
204    }
205}
206
/// Describes a contiguous range of work assigned to a single device.
///
/// The range covers the elements `offset..offset + length` of the overall
/// 1D workload.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DeviceRange {
    /// Index of the device this range is assigned to.
    pub device_index: usize,
    /// Index of the first element in the range.
    pub offset: usize,
    /// Number of elements in the range (may be zero).
    pub length: usize,
}
214
215/// Peer-to-peer memory copy placeholder
216pub fn memcpy_peer(
217    _dst_device: usize,
218    _src_device: usize,
219    _size: usize,
220) -> Result<()> {
221    // In CPU emulation mode, all memory is shared, so peer copy is a no-op
222    Ok(())
223}
224
#[cfg(test)]
mod tests {
    use super::*;

    /// Enumerates devices, failing the calling test if enumeration fails.
    fn context() -> MultiGpuContext {
        MultiGpuContext::new().unwrap()
    }

    #[test]
    fn test_multi_gpu_creation() {
        let ctx = context();
        assert!(ctx.device_count() >= 1);
        assert_eq!(ctx.active_device_index(), 0);
    }

    #[test]
    fn test_device_access() {
        // The default device always sits at index 0.
        assert!(context().device(0).is_ok());
    }

    #[test]
    fn test_device_out_of_range() {
        assert!(context().device(999).is_err());
    }

    #[test]
    fn test_set_active_device() {
        let mut ctx = context();
        assert!(ctx.set_device(0).is_ok());
        assert_eq!(ctx.active_device_index(), 0);
    }

    #[test]
    fn test_set_device_out_of_range() {
        let mut ctx = context();
        assert!(ctx.set_device(999).is_err());
    }

    #[test]
    fn test_self_peer_access() {
        let self_access = context().can_access_peer(0, 0).unwrap();
        assert!(self_access);
    }

    #[test]
    fn test_distribute_range() {
        let ranges = context().distribute_range(100);
        assert!(!ranges.is_empty());

        // The chunk lengths must add up to the requested total.
        let covered: usize = ranges.iter().map(|range| range.length).sum();
        assert_eq!(covered, 100);
    }

    #[test]
    fn test_distribute_range_uneven() {
        let ctx = context();
        // Deliberately not a multiple of the device count.
        let total = ctx.device_count() * 10 + 3;
        let ranges = ctx.distribute_range(total);

        let covered: usize = ranges.iter().map(|range| range.length).sum();
        assert_eq!(covered, total);

        // Each chunk must start exactly where the previous one ended.
        let mut expected_offset = 0;
        for range in &ranges {
            assert_eq!(range.offset, expected_offset);
            expected_offset += range.length;
        }
    }

    #[test]
    fn test_all_properties() {
        let ctx = context();
        assert_eq!(ctx.all_properties().len(), ctx.device_count());
    }

    #[test]
    fn test_memcpy_peer() {
        // The placeholder copy always succeeds.
        assert!(memcpy_peer(0, 0, 1024).is_ok());
    }

    #[test]
    fn test_default_context() {
        let ctx = MultiGpuContext::default();
        assert!(ctx.device_count() >= 1);
    }
}