cuda_rust_wasm/runtime/
multi_gpu.rs1use crate::{Result, runtime_error};
7use crate::runtime::device::{Device, DeviceProperties, BackendType};
8use std::sync::Arc;
9
10pub struct MultiGpuContext {
12 devices: Vec<Arc<Device>>,
14 active_device: usize,
16 peer_access: Vec<Vec<bool>>,
18}
19
20impl MultiGpuContext {
21 pub fn new() -> Result<Self> {
23 let mut devices = Vec::new();
24
25 let default_device = Device::get_default()?;
27 devices.push(default_device);
28
29 let additional = Self::probe_additional_devices();
33 devices.extend(additional);
34
35 let device_count = devices.len();
36 let peer_access = vec![vec![false; device_count]; device_count];
37
38 let mut ctx = Self {
39 devices,
40 active_device: 0,
41 peer_access,
42 };
43
44 ctx.setup_peer_access();
46 Ok(ctx)
47 }
48
49 pub fn device_count(&self) -> usize {
51 self.devices.len()
52 }
53
54 pub fn device(&self, index: usize) -> Result<&Arc<Device>> {
56 self.devices.get(index).ok_or_else(|| {
57 runtime_error!("Device index {} out of range (have {})", index, self.devices.len())
58 })
59 }
60
61 pub fn active_device(&self) -> &Arc<Device> {
63 &self.devices[self.active_device]
64 }
65
66 pub fn active_device_index(&self) -> usize {
68 self.active_device
69 }
70
71 pub fn set_device(&mut self, index: usize) -> Result<()> {
73 if index >= self.devices.len() {
74 return Err(runtime_error!(
75 "Device index {} out of range (have {})",
76 index, self.devices.len()
77 ));
78 }
79 self.active_device = index;
80 Ok(())
81 }
82
83 pub fn can_access_peer(&self, src: usize, dst: usize) -> Result<bool> {
85 if src >= self.devices.len() || dst >= self.devices.len() {
86 return Err(runtime_error!("Device index out of range"));
87 }
88 Ok(self.peer_access[src][dst])
89 }
90
91 pub fn enable_peer_access(&mut self, src: usize, dst: usize) -> Result<()> {
93 if src >= self.devices.len() || dst >= self.devices.len() {
94 return Err(runtime_error!("Device index out of range"));
95 }
96 if src == dst {
97 return Ok(()); }
99
100 let src_backend = self.devices[src].backend();
102 let dst_backend = self.devices[dst].backend();
103 if src_backend != dst_backend {
104 return Err(runtime_error!(
105 "Cannot enable peer access between different backends ({:?} and {:?})",
106 src_backend, dst_backend
107 ));
108 }
109
110 self.peer_access[src][dst] = true;
111 self.peer_access[dst][src] = true;
112 Ok(())
113 }
114
115 pub fn disable_peer_access(&mut self, src: usize, dst: usize) -> Result<()> {
117 if src >= self.devices.len() || dst >= self.devices.len() {
118 return Err(runtime_error!("Device index out of range"));
119 }
120 self.peer_access[src][dst] = false;
121 self.peer_access[dst][src] = false;
122 Ok(())
123 }
124
125 pub fn all_properties(&self) -> Vec<&DeviceProperties> {
127 self.devices.iter().map(|d| d.properties()).collect()
128 }
129
130 pub fn distribute_range(&self, total: usize) -> Vec<DeviceRange> {
132 let n = self.devices.len();
133 let chunk = total / n;
134 let remainder = total % n;
135
136 let mut ranges = Vec::with_capacity(n);
137 let mut offset = 0;
138
139 for i in 0..n {
140 let len = chunk + if i < remainder { 1 } else { 0 };
141 ranges.push(DeviceRange {
142 device_index: i,
143 offset,
144 length: len,
145 });
146 offset += len;
147 }
148
149 ranges
150 }
151
152 fn probe_additional_devices() -> Vec<Arc<Device>> {
154 if let Ok(output) = std::process::Command::new("nvidia-smi")
156 .args(["--query-gpu=count", "--format=csv,noheader,nounits"])
157 .output()
158 {
159 if output.status.success() {
160 let stdout = String::from_utf8_lossy(&output.stdout);
161 if let Ok(count) = stdout.trim().parse::<usize>() {
162 if count > 1 {
163 let mut additional = Vec::new();
165 for id in 1..count {
166 if let Ok(dev) = Device::get_by_id(id) {
167 additional.push(dev);
168 }
169 }
170 return additional;
171 }
172 }
173 }
174 }
175 Vec::new()
176 }
177
178 fn setup_peer_access(&mut self) {
180 let n = self.devices.len();
181 for i in 0..n {
182 self.peer_access[i][i] = true; for j in (i + 1)..n {
184 let same_backend = self.devices[i].backend() == self.devices[j].backend();
185 if same_backend {
186 self.peer_access[i][j] = true;
187 self.peer_access[j][i] = true;
188 }
189 }
190 }
191 }
192}
193
194impl Default for MultiGpuContext {
195 fn default() -> Self {
196 Self::new().unwrap_or_else(|_| {
197 Self {
199 devices: vec![Device::get_default().expect("default device should be available")],
200 active_device: 0,
201 peer_access: vec![vec![true]],
202 }
203 })
204 }
205}
206
/// A contiguous slice of a workload assigned to one device.
///
/// The type is a plain bundle of indices, so it is cheap to copy and can be
/// compared directly (for example in assertions).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DeviceRange {
    /// Index of the device this range is assigned to.
    pub device_index: usize,
    /// Position of the first item of this range within the whole workload.
    pub offset: usize,
    /// Number of items in this range.
    pub length: usize,
}
214
215pub fn memcpy_peer(
217 _dst_device: usize,
218 _src_device: usize,
219 _size: usize,
220) -> Result<()> {
221 Ok(())
223}
224
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the context each test starts from; panics if construction fails.
    fn fresh_context() -> MultiGpuContext {
        MultiGpuContext::new().unwrap()
    }

    #[test]
    fn test_multi_gpu_creation() {
        let ctx = fresh_context();
        assert!(ctx.device_count() >= 1);
        assert_eq!(ctx.active_device_index(), 0);
    }

    #[test]
    fn test_device_access() {
        let ctx = fresh_context();
        assert!(ctx.device(0).is_ok());
    }

    #[test]
    fn test_device_out_of_range() {
        let ctx = fresh_context();
        assert!(ctx.device(999).is_err());
    }

    #[test]
    fn test_set_active_device() {
        let mut ctx = fresh_context();
        let outcome = ctx.set_device(0);
        assert!(outcome.is_ok());
        assert_eq!(ctx.active_device_index(), 0);
    }

    #[test]
    fn test_set_device_out_of_range() {
        let mut ctx = fresh_context();
        let outcome = ctx.set_device(999);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_self_peer_access() {
        let ctx = fresh_context();
        let self_access = ctx.can_access_peer(0, 0).unwrap();
        assert!(self_access);
    }

    #[test]
    fn test_distribute_range() {
        let ctx = fresh_context();
        let ranges = ctx.distribute_range(100);
        assert!(!ranges.is_empty());

        // The pieces must add back up to the requested total.
        let mut covered = 0usize;
        for range in &ranges {
            covered += range.length;
        }
        assert_eq!(covered, 100);
    }

    #[test]
    fn test_distribute_range_uneven() {
        let ctx = fresh_context();
        // Three leftover items force an uneven split whenever there are more
        // than three devices or the count does not divide evenly.
        let total = ctx.device_count() * 10 + 3;
        let ranges = ctx.distribute_range(total);

        let covered: usize = ranges.iter().map(|range| range.length).sum();
        assert_eq!(covered, total);

        // Each range must begin exactly where the previous one ended.
        let mut expected_offset = 0;
        for range in ranges.iter() {
            assert_eq!(range.offset, expected_offset);
            expected_offset += range.length;
        }
    }

    #[test]
    fn test_all_properties() {
        let ctx = fresh_context();
        assert_eq!(ctx.all_properties().len(), ctx.device_count());
    }

    #[test]
    fn test_memcpy_peer() {
        let outcome = memcpy_peer(0, 0, 1024);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_default_context() {
        let ctx = MultiGpuContext::default();
        assert!(ctx.device_count() >= 1);
    }
}