1use crate::device::Device;
2use std::collections::HashMap;
3use std::sync::Arc;
4
/// Parallelism layout describing how the ranks of a device mesh are partitioned.
#[derive(Debug, Clone)]
pub enum MeshTopology {
    /// No partitioning; `validate` accepts any world size for this variant.
    DataParallel,
    /// Ranks are split into tensor-parallel groups.
    TensorParallel {
        /// Number of ranks per tensor-parallel group.
        degree: usize,
        /// Caller-supplied mesh shape. It is stored but not read by
        /// `expected_world_size` or `validate`.
        mesh_shape: (usize, usize),
    },
    /// Ranks are organised into pipeline stages.
    PipelineParallel {
        /// Number of pipeline stages.
        stages: usize,
    },
    /// Combination of tensor, pipeline and data parallelism.
    Mixed {
        /// Tensor-parallel degree.
        tp: usize,
        /// Pipeline-parallel degree.
        pp: usize,
        /// Data-parallel degree.
        dp: usize,
    },
}

impl MeshTopology {
    /// Returns the number of devices this topology needs at minimum.
    ///
    /// `DataParallel` reports `1`, `TensorParallel` its `degree`,
    /// `PipelineParallel` its `stages` and `Mixed` the product `tp * pp * dp`.
    /// The product saturates at `usize::MAX` instead of overflowing.
    pub fn expected_world_size(&self) -> usize {
        match *self {
            MeshTopology::DataParallel => 1,
            MeshTopology::TensorParallel { degree, .. } => degree,
            MeshTopology::PipelineParallel { stages } => stages,
            MeshTopology::Mixed { tp, pp, dp } => tp.saturating_mul(pp).saturating_mul(dp),
        }
    }

    /// Checks that `world_size` devices are enough to host this topology.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when
    /// * any parallel dimension (`degree`, `stages`, `tp`, `pp`, `dp`) is zero,
    /// * the product `tp * pp * dp` does not fit in `usize`, or
    /// * `world_size` is smaller than the number of devices required.
    pub fn validate(&self, world_size: usize) -> Result<(), String> {
        match *self {
            MeshTopology::DataParallel => Ok(()),
            MeshTopology::TensorParallel { degree, .. } => {
                // A zero degree can never be "too large" for the world size,
                // so it has to be rejected explicitly.
                if degree == 0 {
                    return Err("Tensor parallel degree must be at least 1, got 0".to_string());
                }
                if world_size < degree {
                    return Err(format!(
                        "Tensor parallel degree {} requires at least {} devices, got {}",
                        degree, degree, world_size
                    ));
                }
                Ok(())
            }
            MeshTopology::PipelineParallel { stages } => {
                if stages == 0 {
                    return Err("Pipeline parallel requires at least 1 stage, got 0".to_string());
                }
                if world_size < stages {
                    return Err(format!(
                        "Pipeline parallel requires {} stages, got {} devices",
                        stages, world_size
                    ));
                }
                Ok(())
            }
            MeshTopology::Mixed { tp, pp, dp } => {
                if tp == 0 || pp == 0 || dp == 0 {
                    return Err(format!(
                        "Mixed parallelism (TP={}, PP={}, DP={}) requires every degree to be at least 1",
                        tp, pp, dp
                    ));
                }
                // Checked arithmetic: a wrapped product could otherwise make an
                // impossible layout look satisfiable.
                let required = tp
                    .checked_mul(pp)
                    .and_then(|partial| partial.checked_mul(dp))
                    .ok_or_else(|| {
                        format!(
                            "Mixed parallelism (TP={}, PP={}, DP={}) overflows the device count",
                            tp, pp, dp
                        )
                    })?;
                if world_size < required {
                    return Err(format!(
                        "Mixed parallelism (TP={}, PP={}, DP={}) requires {} devices, got {}",
                        tp, pp, dp, required, world_size
                    ));
                }
                Ok(())
            }
        }
    }
}
71
72#[derive(Debug, Clone)]
73pub struct ProcessGroup {
74 pub name: String,
75 pub ranks: Vec<usize>,
76 pub backend: GroupBackend,
77}
78
/// Communication backend tag stored on a `ProcessGroup`.
#[derive(Debug, Clone, Copy)]
pub enum GroupBackend {
    /// NCCL backend tag.
    Nccl,
    /// Gloo backend tag.
    Gloo,
    /// MPI backend tag.
    MPI,
    /// Mock backend tag.
    Mock,
}
86
87#[derive(Debug, Clone)]
92pub struct DeviceMesh {
93 devices: Arc<Vec<Arc<Device>>>,
95
96 pub world_size: usize,
98
99 pub rank: usize,
101
102 pub topology: MeshTopology,
104
105 groups: Arc<HashMap<String, ProcessGroup>>,
107
108 pub comm: Option<Arc<dyn MeshComm + Send + Sync>>,
110}
111
112impl DeviceMesh {
113 pub fn new(devices: Vec<Device>) -> Self {
115 let world_size = devices.len();
116 let devices: Vec<Arc<Device>> = devices.into_iter().map(Arc::new).collect();
117
118 let topology = MeshTopology::DataParallel;
120
121 let mut groups = HashMap::new();
123 groups.insert(
124 "world".to_string(),
125 ProcessGroup {
126 name: "world".to_string(),
127 ranks: (0..world_size).collect(),
128 backend: GroupBackend::Mock,
129 },
130 );
131
132 Self {
133 devices: Arc::new(devices),
134 world_size,
135 rank: 0, topology,
137 groups: Arc::new(groups),
138 comm: None,
139 }
140 }
141
142 pub fn new_with_mock_comm(devices: Vec<Device>, rank: usize) -> Self {
144 use crate::mock_comm::MockComm;
145
146 let world_size = devices.len();
147 let mut mesh = Self::new(devices);
148 mesh.rank = rank;
149 mesh.comm = Some(Arc::new(MockComm::new(rank, world_size)));
150 mesh
151 }
152
153 #[cfg(feature = "nccl")]
165 pub fn new_with_nccl(
166 devices: Vec<Device>,
167 nccl_id: &crate::nccl_comm::cudarc::nccl::Id,
168 rank: usize,
169 ) -> Result<Self, String> {
170 use crate::nccl_comm::NcclComm;
171 use cudarc::driver::CudaDevice;
172
173 let world_size = devices.len();
174
175 let cuda_device = CudaDevice::new(rank)
177 .map_err(|e| format!("Failed to get CUDA device {}: {:?}", rank, e))?;
178
179 let nccl_comm = NcclComm::new(cuda_device, nccl_id, rank, world_size)?;
180
181 let mut mesh = Self::new(devices);
182 mesh.rank = rank;
183 mesh.comm = Some(Arc::new(nccl_comm));
184 Ok(mesh)
185 }
186
187 pub fn with_topology(devices: Vec<Device>, topology: MeshTopology) -> Result<Self, String> {
189 let world_size = devices.len();
190
191 topology.validate(world_size)?;
193
194 let devices: Vec<Arc<Device>> = devices.into_iter().map(Arc::new).collect();
195
196 let mut groups = HashMap::new();
197
198 groups.insert(
200 "world".to_string(),
201 ProcessGroup {
202 name: "world".to_string(),
203 ranks: (0..world_size).collect(),
204 backend: GroupBackend::Mock,
205 },
206 );
207
208 match &topology {
210 MeshTopology::TensorParallel { degree, .. } => {
211 for i in 0..world_size / degree {
213 let start = i * degree;
214 let ranks: Vec<usize> = (start..start + degree).collect();
215 groups.insert(
216 format!("tp_{}", i),
217 ProcessGroup {
218 name: format!("tp_{}", i),
219 ranks,
220 backend: GroupBackend::Nccl,
221 },
222 );
223 }
224 }
225 MeshTopology::PipelineParallel { stages } => {
226 for stage in 0..*stages {
228 groups.insert(
229 format!("pp_stage_{}", stage),
230 ProcessGroup {
231 name: format!("pp_stage_{}", stage),
232 ranks: vec![stage],
233 backend: GroupBackend::Gloo,
234 },
235 );
236 }
237 }
238 MeshTopology::Mixed { tp, pp, dp } => {
239 let tp_size = *tp;
241 for dp_idx in 0..*dp {
242 for pp_idx in 0..*pp {
243 let base = (dp_idx * pp + pp_idx) * tp_size;
244 let ranks: Vec<usize> = (base..base + tp_size).collect();
245 groups.insert(
246 format!("tp_dp{}_pp{}", dp_idx, pp_idx),
247 ProcessGroup {
248 name: format!("tp_dp{}_pp{}", dp_idx, pp_idx),
249 ranks,
250 backend: GroupBackend::Nccl,
251 },
252 );
253 }
254 }
255 }
256 _ => {}
257 }
258
259 Ok(Self {
260 devices: Arc::new(devices),
261 world_size,
262 rank: 0,
263 topology,
264 groups: Arc::new(groups),
265 comm: None,
266 })
267 }
268
269 pub fn set_rank(&mut self, rank: usize) -> Result<(), String> {
271 if rank >= self.world_size {
272 return Err(format!(
273 "Rank {} out of bounds for world size {}",
274 rank, self.world_size
275 ));
276 }
277 self.rank = rank;
278 Ok(())
279 }
280
281 pub fn get_device(&self, rank: usize) -> Option<Arc<Device>> {
283 self.devices.get(rank).cloned()
284 }
285
286 pub fn local_device(&self) -> Option<Arc<Device>> {
288 self.get_device(self.rank)
289 }
290
291 pub fn all_devices(&self) -> Arc<Vec<Arc<Device>>> {
293 self.devices.clone()
294 }
295
296 pub fn devices_by_backend(&self, backend: crate::device::DeviceBackend) -> Vec<Arc<Device>> {
298 self.devices
299 .iter()
300 .filter(|d| std::mem::discriminant(&d.backend) == std::mem::discriminant(&backend))
301 .cloned()
302 .collect()
303 }
304
305 pub fn devices_in_group(&self, group_name: &str) -> Result<Vec<Arc<Device>>, String> {
307 let group = self
308 .groups
309 .get(group_name)
310 .ok_or_else(|| format!("Group '{}' not found", group_name))?;
311
312 Ok(group
313 .ranks
314 .iter()
315 .filter_map(|&rank| self.get_device(rank))
316 .collect())
317 }
318
319 pub fn add_group(
321 &mut self,
322 name: String,
323 ranks: Vec<usize>,
324 backend: GroupBackend,
325 ) -> Result<(), String> {
326 for &rank in &ranks {
328 if rank >= self.world_size {
329 return Err(format!("Rank {} out of bounds", rank));
330 }
331 }
332
333 let groups = Arc::make_mut(&mut self.groups);
334 groups.insert(
335 name.clone(),
336 ProcessGroup {
337 name,
338 ranks,
339 backend,
340 },
341 );
342
343 Ok(())
344 }
345
346 pub fn get_group(&self, name: &str) -> Option<&ProcessGroup> {
348 self.groups.get(name)
349 }
350
351 pub fn group_names(&self) -> Vec<String> {
353 self.groups.keys().cloned().collect()
354 }
355
356 pub fn in_group(&self, group_name: &str) -> bool {
358 self.groups
359 .get(group_name)
360 .map(|g| g.ranks.contains(&self.rank))
361 .unwrap_or(false)
362 }
363
364 pub fn group_rank(&self, group_name: &str) -> Option<usize> {
366 self.groups
367 .get(group_name)
368 .and_then(|g| g.ranks.iter().position(|&r| r == self.rank))
369 }
370
371 pub fn set_comm(&mut self, comm: Arc<dyn MeshComm + Send + Sync>) {
373 self.comm = Some(comm);
374 }
375
376 pub fn total_memory_mb(&self) -> u64 {
378 self.devices.iter().map(|d| d.memory_mb).sum()
379 }
380
381 pub fn total_compute_units(&self) -> u32 {
383 self.devices.iter().map(|d| d.compute_units).sum()
384 }
385
386 pub fn reshape(&mut self, new_topology: MeshTopology) -> Result<(), String> {
388 new_topology.validate(self.world_size)?;
389 self.topology = new_topology;
390 Ok(())
391 }
392
393 pub fn stats(&self) -> MeshStats {
395 let backend_counts = self.count_backends();
396
397 MeshStats {
398 world_size: self.world_size,
399 total_memory_mb: self.total_memory_mb(),
400 total_compute_units: self.total_compute_units(),
401 backend_distribution: backend_counts,
402 group_count: self.groups.len(),
403 }
404 }
405
406 fn count_backends(&self) -> HashMap<String, usize> {
407 let mut counts = HashMap::new();
408 for device in self.devices.iter() {
409 let backend_name = format!("{:?}", device.backend);
410 *counts.entry(backend_name).or_insert(0) += 1;
411 }
412 counts
413 }
414}
415
/// Aggregate figures describing a device mesh at one point in time.
#[derive(Debug, Clone)]
pub struct MeshStats {
    /// Number of ranks in the mesh.
    pub world_size: usize,
    /// Sum of the devices' `memory_mb` values.
    pub total_memory_mb: u64,
    /// Sum of the devices' `compute_units` values.
    pub total_compute_units: u32,
    /// Device count per backend name.
    pub backend_distribution: HashMap<String, usize>,
    /// Number of process groups registered on the mesh.
    pub group_count: usize,
}
425
426pub trait MeshComm: std::fmt::Debug {
428 fn all_reduce(
430 &self,
431 buf: &mut [u8],
432 dtype: DType,
433 op: ReduceOp,
434 group: &str,
435 ) -> Result<(), String>;
436
437 fn all_gather(
439 &self,
440 local: &[u8],
441 out: &mut [u8],
442 dtype: DType,
443 group: &str,
444 ) -> Result<(), String>;
445
446 fn broadcast(&self, buf: &mut [u8], root_rank: usize, group: &str) -> Result<(), String>;
448
449 fn reduce_scatter(
451 &self,
452 buf: &mut [u8],
453 out: &mut [u8],
454 op: ReduceOp,
455 group: &str,
456 ) -> Result<(), String>;
457
458 fn barrier(&self, group: &str) -> Result<(), String>;
460
461 fn send(&self, buf: &[u8], dest_rank: usize) -> Result<(), String>;
463
464 fn recv(&self, buf: &mut [u8], src_rank: usize) -> Result<(), String>;
466}
467
/// Element types understood by the mesh communication interface.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DType {
    /// 4-byte element (see `size_bytes`).
    Float32,
    /// 2-byte element.
    Float16,
    /// 2-byte element.
    BFloat16,
    /// 4-byte element.
    Int32,
    /// 8-byte element.
    Int64,
    /// 1-byte element.
    UInt8,
}

impl DType {
    /// Returns the size of one element of this type in bytes.
    pub fn size_bytes(&self) -> usize {
        match self {
            DType::UInt8 => 1,
            DType::Float16 | DType::BFloat16 => 2,
            DType::Float32 | DType::Int32 => 4,
            DType::Int64 => 8,
        }
    }
}
491
/// Reduction operator passed to the reducing operations of `MeshComm`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReduceOp {
    /// Sum reduction.
    Sum,
    /// Product reduction.
    Product,
    /// Minimum reduction.
    Min,
    /// Maximum reduction.
    Max,
    /// Average reduction.
    Average,
}
501
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::DeviceBackend;

    /// Builds `count` CUDA test devices with ids `0..count`, each reporting
    /// 16000 MB of memory and 80 compute units.
    fn create_test_devices(count: usize) -> Vec<Device> {
        let mut devices = Vec::with_capacity(count);
        for index in 0..count {
            devices.push(Device {
                id: index,
                name: format!("GPU_{}", index),
                backend: DeviceBackend::Cuda,
                memory_mb: 16000,
                compute_units: 80,
                pci_bus_id: None,
                partition_id: None,
                driver_version: None,
                compute_capability: None,
                utilization_gpu_pct: None,
                temperature_c: None,
                supports_fp16: true,
                supports_int8: true,
                cuda_version: Some("12.0".to_string()),
            });
        }
        devices
    }

    /// A freshly built mesh has one rank per device, sits at rank 0 and owns
    /// the "world" group.
    #[test]
    fn test_mesh_creation() {
        let mesh = DeviceMesh::new(create_test_devices(4));

        assert_eq!(mesh.world_size, 4);
        assert_eq!(mesh.rank, 0);
        assert!(mesh.get_group("world").is_some());
    }

    /// Eight devices with tensor-parallel degree 4 yield the groups `tp_0`
    /// and `tp_1`.
    #[test]
    fn test_tensor_parallel_topology() {
        let layout = MeshTopology::TensorParallel {
            degree: 4,
            mesh_shape: (2, 4),
        };

        let mesh = DeviceMesh::with_topology(create_test_devices(8), layout).unwrap();
        assert_eq!(mesh.world_size, 8);

        for group_name in ["tp_0", "tp_1"] {
            assert!(mesh.get_group(group_name).is_some());
        }
    }

    /// A custom group can be registered and read back with its ranks intact.
    #[test]
    fn test_group_operations() {
        let mut mesh = DeviceMesh::new(create_test_devices(4));

        mesh.add_group("custom".to_string(), vec![0, 2], GroupBackend::Gloo)
            .unwrap();

        let custom = mesh.get_group("custom");
        assert!(custom.is_some());
        assert_eq!(custom.unwrap().ranks, vec![0, 2]);
    }

    /// Statistics aggregate memory and compute units over all devices.
    #[test]
    fn test_mesh_stats() {
        let mesh = DeviceMesh::new(create_test_devices(4));

        let stats = mesh.stats();
        assert_eq!(stats.world_size, 4);
        assert_eq!(stats.total_memory_mb, 64000);
        assert_eq!(stats.total_compute_units, 320);
    }
}