// kapsl_hal/device_mesh.rs

1use crate::device::Device;
2use std::collections::HashMap;
3use std::sync::Arc;
4
/// Logical parallelism layout of a device mesh.
///
/// The sizes stored in each variant are checked against the number of
/// available devices by [`MeshTopology::validate`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MeshTopology {
    /// Pure data parallelism; accepted for any number of devices.
    DataParallel,
    /// Tensor parallelism.
    TensorParallel {
        /// Number of ranks in one tensor-parallel group.
        degree: usize,
        /// Two-dimensional mesh shape. The methods on this type do not
        /// consult it.
        mesh_shape: (usize, usize),
    },
    /// Pipeline parallelism.
    PipelineParallel {
        /// Number of pipeline stages.
        stages: usize,
    },
    /// Combination of tensor, pipeline and data parallelism.
    Mixed {
        /// Tensor-parallel degree.
        tp: usize,
        /// Pipeline-parallel degree.
        pp: usize,
        /// Data-parallel degree.
        dp: usize,
    },
}

impl MeshTopology {
    /// Get the expected world size for this topology.
    ///
    /// `DataParallel` works with any size and reports `1`. For `Mixed` the
    /// product `tp * pp * dp` saturates at `usize::MAX` rather than
    /// overflowing.
    pub fn expected_world_size(&self) -> usize {
        match *self {
            MeshTopology::DataParallel => 1,
            MeshTopology::TensorParallel { degree, .. } => degree,
            MeshTopology::PipelineParallel { stages } => stages,
            MeshTopology::Mixed { tp, pp, dp } => tp.saturating_mul(pp).saturating_mul(dp),
        }
    }

    /// Validate that `world_size` devices are enough for this topology.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when
    /// * a parallelism dimension is zero (a zero tensor-parallel degree would
    ///   otherwise cause a division by zero when groups are derived from it),
    /// * the `Mixed` product `tp * pp * dp` does not fit in `usize`, or
    /// * `world_size` is smaller than the number of devices required.
    pub fn validate(&self, world_size: usize) -> Result<(), String> {
        match *self {
            // Any size works.
            MeshTopology::DataParallel => Ok(()),
            MeshTopology::TensorParallel { degree, .. } => {
                if degree == 0 {
                    Err("Tensor parallel degree must be at least 1".to_string())
                } else if world_size < degree {
                    Err(format!(
                        "Tensor parallel degree {} requires at least {} devices, got {}",
                        degree, degree, world_size
                    ))
                } else {
                    Ok(())
                }
            }
            MeshTopology::PipelineParallel { stages } => {
                if stages == 0 {
                    Err("Pipeline parallel requires at least 1 stage".to_string())
                } else if world_size < stages {
                    Err(format!(
                        "Pipeline parallel requires {} stages, got {} devices",
                        stages, world_size
                    ))
                } else {
                    Ok(())
                }
            }
            MeshTopology::Mixed { tp, pp, dp } => {
                if tp == 0 || pp == 0 || dp == 0 {
                    return Err(format!(
                        "Mixed parallelism (TP={}, PP={}, DP={}) requires every dimension to be at least 1",
                        tp, pp, dp
                    ));
                }
                // Checked arithmetic: a wrapped product could make an
                // impossible configuration look satisfiable.
                let required = tp
                    .checked_mul(pp)
                    .and_then(|n| n.checked_mul(dp))
                    .ok_or_else(|| {
                        format!(
                            "Mixed parallelism (TP={}, PP={}, DP={}) overflows the device count",
                            tp, pp, dp
                        )
                    })?;
                if world_size < required {
                    Err(format!(
                        "Mixed parallelism (TP={}, PP={}, DP={}) requires {} devices, got {}",
                        tp, pp, dp, required, world_size
                    ))
                } else {
                    Ok(())
                }
            }
        }
    }
}
71
/// A named subset of mesh ranks together with the backend chosen for it.
#[derive(Debug, Clone)]
pub struct ProcessGroup {
    /// Group name; `DeviceMesh` also uses it as the key of its group table.
    pub name: String,
    /// Global ranks that belong to the group. `DeviceMesh::group_rank`
    /// reports a rank's position in this list as its group-local rank.
    pub ranks: Vec<usize>,
    /// Communication backend selected for the group.
    pub backend: GroupBackend,
}
78
/// Communication backend associated with a process group.
///
/// Derives `PartialEq`/`Eq` so backends can be compared directly, matching
/// the other field-less enums in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GroupBackend {
    /// NVIDIA GPUs.
    Nccl,
    /// CPU / generic.
    Gloo,
    /// HPC environments.
    MPI,
    /// Testing.
    Mock,
}
86
/// Device Mesh for distributed ML inference
///
/// Manages a set of devices arranged in a logical mesh topology for
/// data/tensor/pipeline parallelism. The device list and the group table are
/// held behind `Arc`, so cloning a mesh shares them instead of copying them.
#[derive(Debug, Clone)]
pub struct DeviceMesh {
    /// Devices ordered by mesh coordinate and looked up by rank
    /// (each device is wrapped in its own `Arc` so it can be handed out cheaply)
    devices: Arc<Vec<Arc<Device>>>,

    /// Total number of devices in the mesh
    pub world_size: usize,

    /// This process's global rank.
    ///
    /// NOTE: the field is public, so it can be assigned directly without the
    /// bounds check that `set_rank` performs.
    pub rank: usize,

    /// Mesh topology configuration
    pub topology: MeshTopology,

    /// Process groups (map group name -> group definition)
    groups: Arc<HashMap<String, ProcessGroup>>,

    /// Communication backend handle; `None` until a backend is attached
    pub comm: Option<Arc<dyn MeshComm + Send + Sync>>,
}
111
112impl DeviceMesh {
113    /// Create a new device mesh from a list of devices
114    pub fn new(devices: Vec<Device>) -> Self {
115        let world_size = devices.len();
116        let devices: Vec<Arc<Device>> = devices.into_iter().map(Arc::new).collect();
117
118        // Default to DataParallel topology
119        let topology = MeshTopology::DataParallel;
120
121        // Default groups: "world" contains all ranks
122        let mut groups = HashMap::new();
123        groups.insert(
124            "world".to_string(),
125            ProcessGroup {
126                name: "world".to_string(),
127                ranks: (0..world_size).collect(),
128                backend: GroupBackend::Mock,
129            },
130        );
131
132        Self {
133            devices: Arc::new(devices),
134            world_size,
135            rank: 0, // Default to rank 0 for single-node
136            topology,
137            groups: Arc::new(groups),
138            comm: None,
139        }
140    }
141
142    /// Create a new device mesh with a mock communication backend for testing
143    pub fn new_with_mock_comm(devices: Vec<Device>, rank: usize) -> Self {
144        use crate::mock_comm::MockComm;
145
146        let world_size = devices.len();
147        let mut mesh = Self::new(devices);
148        mesh.rank = rank;
149        mesh.comm = Some(Arc::new(MockComm::new(rank, world_size)));
150        mesh
151    }
152
153    /// Create a new device mesh with NCCL communication backend for real GPU communication
154    ///
155    /// # Arguments
156    /// * `devices` - List of devices in the mesh
157    /// * `nccl_id` - NCCL unique ID (generate with `NcclComm::generate_id()` on rank 0)
158    /// * `rank` - This process's rank
159    ///
160    /// # Requirements
161    /// - CUDA device available
162    /// - NCCL library installed
163    /// - All ranks must call this with the same `nccl_id`
164    #[cfg(feature = "nccl")]
165    pub fn new_with_nccl(
166        devices: Vec<Device>,
167        nccl_id: &crate::nccl_comm::cudarc::nccl::Id,
168        rank: usize,
169    ) -> Result<Self, String> {
170        use crate::nccl_comm::NcclComm;
171        use cudarc::driver::CudaDevice;
172
173        let world_size = devices.len();
174
175        // Get the CUDA device for this rank
176        let cuda_device = CudaDevice::new(rank)
177            .map_err(|e| format!("Failed to get CUDA device {}: {:?}", rank, e))?;
178
179        let nccl_comm = NcclComm::new(cuda_device, nccl_id, rank, world_size)?;
180
181        let mut mesh = Self::new(devices);
182        mesh.rank = rank;
183        mesh.comm = Some(Arc::new(nccl_comm));
184        Ok(mesh)
185    }
186
187    /// Create a mesh with specific topology
188    pub fn with_topology(devices: Vec<Device>, topology: MeshTopology) -> Result<Self, String> {
189        let world_size = devices.len();
190
191        // Validate topology against world size
192        topology.validate(world_size)?;
193
194        let devices: Vec<Arc<Device>> = devices.into_iter().map(Arc::new).collect();
195
196        let mut groups = HashMap::new();
197
198        // Create world group
199        groups.insert(
200            "world".to_string(),
201            ProcessGroup {
202                name: "world".to_string(),
203                ranks: (0..world_size).collect(),
204                backend: GroupBackend::Mock,
205            },
206        );
207
208        // Create topology-specific groups
209        match &topology {
210            MeshTopology::TensorParallel { degree, .. } => {
211                // Create TP groups
212                for i in 0..world_size / degree {
213                    let start = i * degree;
214                    let ranks: Vec<usize> = (start..start + degree).collect();
215                    groups.insert(
216                        format!("tp_{}", i),
217                        ProcessGroup {
218                            name: format!("tp_{}", i),
219                            ranks,
220                            backend: GroupBackend::Nccl,
221                        },
222                    );
223                }
224            }
225            MeshTopology::PipelineParallel { stages } => {
226                // Each stage is a group
227                for stage in 0..*stages {
228                    groups.insert(
229                        format!("pp_stage_{}", stage),
230                        ProcessGroup {
231                            name: format!("pp_stage_{}", stage),
232                            ranks: vec![stage],
233                            backend: GroupBackend::Gloo,
234                        },
235                    );
236                }
237            }
238            MeshTopology::Mixed { tp, pp, dp } => {
239                // Create TP groups
240                let tp_size = *tp;
241                for dp_idx in 0..*dp {
242                    for pp_idx in 0..*pp {
243                        let base = (dp_idx * pp + pp_idx) * tp_size;
244                        let ranks: Vec<usize> = (base..base + tp_size).collect();
245                        groups.insert(
246                            format!("tp_dp{}_pp{}", dp_idx, pp_idx),
247                            ProcessGroup {
248                                name: format!("tp_dp{}_pp{}", dp_idx, pp_idx),
249                                ranks,
250                                backend: GroupBackend::Nccl,
251                            },
252                        );
253                    }
254                }
255            }
256            _ => {}
257        }
258
259        Ok(Self {
260            devices: Arc::new(devices),
261            world_size,
262            rank: 0,
263            topology,
264            groups: Arc::new(groups),
265            comm: None,
266        })
267    }
268
269    /// Set the rank for this process
270    pub fn set_rank(&mut self, rank: usize) -> Result<(), String> {
271        if rank >= self.world_size {
272            return Err(format!(
273                "Rank {} out of bounds for world size {}",
274                rank, self.world_size
275            ));
276        }
277        self.rank = rank;
278        Ok(())
279    }
280
281    /// Get device by rank (memory efficient - returns Arc clone)
282    pub fn get_device(&self, rank: usize) -> Option<Arc<Device>> {
283        self.devices.get(rank).cloned()
284    }
285
286    /// Get the local device for this process
287    pub fn local_device(&self) -> Option<Arc<Device>> {
288        self.get_device(self.rank)
289    }
290
291    /// Get all devices (returns Arc to avoid cloning the entire Vec)
292    pub fn all_devices(&self) -> Arc<Vec<Arc<Device>>> {
293        self.devices.clone()
294    }
295
296    /// Get devices for a specific backend type (memory efficient)
297    pub fn devices_by_backend(&self, backend: crate::device::DeviceBackend) -> Vec<Arc<Device>> {
298        self.devices
299            .iter()
300            .filter(|d| std::mem::discriminant(&d.backend) == std::mem::discriminant(&backend))
301            .cloned()
302            .collect()
303    }
304
305    /// Get devices in a specific process group
306    pub fn devices_in_group(&self, group_name: &str) -> Result<Vec<Arc<Device>>, String> {
307        let group = self
308            .groups
309            .get(group_name)
310            .ok_or_else(|| format!("Group '{}' not found", group_name))?;
311
312        Ok(group
313            .ranks
314            .iter()
315            .filter_map(|&rank| self.get_device(rank))
316            .collect())
317    }
318
319    /// Add a custom process group
320    pub fn add_group(
321        &mut self,
322        name: String,
323        ranks: Vec<usize>,
324        backend: GroupBackend,
325    ) -> Result<(), String> {
326        // Validate ranks
327        for &rank in &ranks {
328            if rank >= self.world_size {
329                return Err(format!("Rank {} out of bounds", rank));
330            }
331        }
332
333        let groups = Arc::make_mut(&mut self.groups);
334        groups.insert(
335            name.clone(),
336            ProcessGroup {
337                name,
338                ranks,
339                backend,
340            },
341        );
342
343        Ok(())
344    }
345
346    /// Get a process group by name
347    pub fn get_group(&self, name: &str) -> Option<&ProcessGroup> {
348        self.groups.get(name)
349    }
350
351    /// List all group names
352    pub fn group_names(&self) -> Vec<String> {
353        self.groups.keys().cloned().collect()
354    }
355
356    /// Check if this rank is in a specific group
357    pub fn in_group(&self, group_name: &str) -> bool {
358        self.groups
359            .get(group_name)
360            .map(|g| g.ranks.contains(&self.rank))
361            .unwrap_or(false)
362    }
363
364    /// Get rank within a group (local rank)
365    pub fn group_rank(&self, group_name: &str) -> Option<usize> {
366        self.groups
367            .get(group_name)
368            .and_then(|g| g.ranks.iter().position(|&r| r == self.rank))
369    }
370
371    /// Set the communication backend
372    pub fn set_comm(&mut self, comm: Arc<dyn MeshComm + Send + Sync>) {
373        self.comm = Some(comm);
374    }
375
376    /// Get total memory across all devices
377    pub fn total_memory_mb(&self) -> u64 {
378        self.devices.iter().map(|d| d.memory_mb).sum()
379    }
380
381    /// Get total compute units across all devices
382    pub fn total_compute_units(&self) -> u32 {
383        self.devices.iter().map(|d| d.compute_units).sum()
384    }
385
386    /// Reshape the mesh to a different topology
387    pub fn reshape(&mut self, new_topology: MeshTopology) -> Result<(), String> {
388        new_topology.validate(self.world_size)?;
389        self.topology = new_topology;
390        Ok(())
391    }
392
393    /// Get mesh statistics
394    pub fn stats(&self) -> MeshStats {
395        let backend_counts = self.count_backends();
396
397        MeshStats {
398            world_size: self.world_size,
399            total_memory_mb: self.total_memory_mb(),
400            total_compute_units: self.total_compute_units(),
401            backend_distribution: backend_counts,
402            group_count: self.groups.len(),
403        }
404    }
405
406    fn count_backends(&self) -> HashMap<String, usize> {
407        let mut counts = HashMap::new();
408        for device in self.devices.iter() {
409            let backend_name = format!("{:?}", device.backend);
410            *counts.entry(backend_name).or_insert(0) += 1;
411        }
412        counts
413    }
414}
415
/// Statistics about the device mesh, as produced by `DeviceMesh::stats`
#[derive(Debug, Clone)]
pub struct MeshStats {
    /// Number of devices in the mesh
    pub world_size: usize,
    /// Sum of `memory_mb` over all devices
    pub total_memory_mb: u64,
    /// Sum of `compute_units` over all devices
    pub total_compute_units: u32,
    /// Device count per backend, keyed by the backend's `Debug` rendering
    pub backend_distribution: HashMap<String, usize>,
    /// Number of process groups registered on the mesh
    pub group_count: usize,
}
425
/// Communication operations for distributed execution
///
/// All buffers are raw bytes. Collective operations name their process group
/// by string; every method reports failure as an error message.
pub trait MeshComm: std::fmt::Debug {
    /// All-reduce operation: reduce values across all ranks in a group
    ///
    /// `buf` is both input and output. `dtype` presumably tells the backend
    /// how to interpret the bytes and `op` selects the reduction.
    fn all_reduce(
        &self,
        buf: &mut [u8],
        dtype: DType,
        op: ReduceOp,
        group: &str,
    ) -> Result<(), String>;

    /// All-gather: gather data from all ranks
    ///
    /// `local` is this rank's contribution and `out` receives the result.
    /// NOTE(review): the required size of `out` is not stated here — confirm
    /// against the implementations.
    fn all_gather(
        &self,
        local: &[u8],
        out: &mut [u8],
        dtype: DType,
        group: &str,
    ) -> Result<(), String>;

    /// Broadcast from root rank to all ranks in group
    ///
    /// Takes no `DType`; the buffer is handled as plain bytes.
    fn broadcast(&self, buf: &mut [u8], root_rank: usize, group: &str) -> Result<(), String>;

    /// Reduce-scatter: reduce and distribute results
    ///
    /// NOTE(review): unlike `all_reduce`, this takes a `ReduceOp` but no
    /// `DType`, so the element type must be known to the implementation some
    /// other way — confirm.
    fn reduce_scatter(
        &self,
        buf: &mut [u8],
        out: &mut [u8],
        op: ReduceOp,
        group: &str,
    ) -> Result<(), String>;

    /// Barrier synchronization
    fn barrier(&self, group: &str) -> Result<(), String>;

    /// Send data to a specific rank
    ///
    /// Point-to-point: takes no group name.
    fn send(&self, buf: &[u8], dest_rank: usize) -> Result<(), String>;

    /// Receive data from a specific rank
    ///
    /// Point-to-point: takes no group name.
    fn recv(&self, buf: &mut [u8], src_rank: usize) -> Result<(), String>;
}
467
/// Element type attached to the byte buffers of communication operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DType {
    /// 32-bit float (4 bytes).
    Float32,
    /// 16-bit float (2 bytes).
    Float16,
    /// bfloat16 (2 bytes).
    BFloat16,
    /// 32-bit integer (4 bytes).
    Int32,
    /// 64-bit integer (8 bytes).
    Int64,
    /// Unsigned byte (1 byte).
    UInt8,
}

impl DType {
    /// Width of a single element of this type, in bytes.
    pub fn size_bytes(&self) -> usize {
        // Grouped by width rather than by declaration order.
        match self {
            DType::UInt8 => 1,
            DType::Float16 | DType::BFloat16 => 2,
            DType::Float32 | DType::Int32 => 4,
            DType::Int64 => 8,
        }
    }
}
491
/// Reduction operations
///
/// Passed to the reducing collectives of `MeshComm` (`all_reduce` and
/// `reduce_scatter`) to choose how contributions are combined.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReduceOp {
    /// Add the contributions.
    Sum,
    /// Multiply the contributions.
    Product,
    /// Keep the minimum.
    Min,
    /// Keep the maximum.
    Max,
    /// Average the contributions.
    Average,
}
501
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::DeviceBackend;

    /// Build `count` identical CUDA test devices with ids `0..count`.
    fn create_test_devices(count: usize) -> Vec<Device> {
        (0..count)
            .map(|i| Device {
                id: i,
                name: format!("GPU_{}", i),
                backend: DeviceBackend::Cuda,
                memory_mb: 16000,
                compute_units: 80,
                pci_bus_id: None,
                partition_id: None,
                driver_version: None,
                compute_capability: None,
                utilization_gpu_pct: None,
                temperature_c: None,
                supports_fp16: true,
                supports_int8: true,
                cuda_version: Some("12.0".to_string()),
            })
            .collect()
    }

    #[test]
    fn test_mesh_creation() {
        let devices = create_test_devices(4);
        let mesh = DeviceMesh::new(devices);

        assert_eq!(mesh.world_size, 4);
        assert_eq!(mesh.rank, 0);
        assert!(mesh.comm.is_none());
        assert_eq!(mesh.get_group("world").unwrap().ranks, vec![0, 1, 2, 3]);
    }

    #[test]
    fn test_tensor_parallel_topology() {
        let devices = create_test_devices(8);
        let topology = MeshTopology::TensorParallel {
            degree: 4,
            mesh_shape: (2, 4),
        };

        let mesh = DeviceMesh::with_topology(devices, topology).unwrap();
        assert_eq!(mesh.world_size, 8);

        // Should have 2 TP groups (8 devices / 4 degree)
        assert_eq!(mesh.get_group("tp_0").unwrap().ranks, vec![0, 1, 2, 3]);
        assert_eq!(mesh.get_group("tp_1").unwrap().ranks, vec![4, 5, 6, 7]);
        assert!(mesh.get_group("tp_2").is_none());
    }

    #[test]
    fn test_pipeline_parallel_topology() {
        let devices = create_test_devices(3);
        let mesh =
            DeviceMesh::with_topology(devices, MeshTopology::PipelineParallel { stages: 3 })
                .unwrap();

        for stage in 0..3 {
            let group = mesh.get_group(&format!("pp_stage_{}", stage)).unwrap();
            assert_eq!(group.ranks, vec![stage]);
        }
        assert!(mesh.get_group("pp_stage_3").is_none());
    }

    #[test]
    fn test_mixed_topology() {
        let devices = create_test_devices(4);
        let topology = MeshTopology::Mixed {
            tp: 2,
            pp: 2,
            dp: 1,
        };
        let mesh = DeviceMesh::with_topology(devices, topology).unwrap();

        assert_eq!(mesh.get_group("tp_dp0_pp0").unwrap().ranks, vec![0, 1]);
        assert_eq!(mesh.get_group("tp_dp0_pp1").unwrap().ranks, vec![2, 3]);
    }

    #[test]
    fn test_topology_validation() {
        let tp = MeshTopology::TensorParallel {
            degree: 4,
            mesh_shape: (2, 2),
        };
        assert!(tp.validate(4).is_ok());
        assert!(tp.validate(3).is_err());

        let pp = MeshTopology::PipelineParallel { stages: 3 };
        assert!(pp.validate(3).is_ok());
        assert!(pp.validate(2).is_err());

        let mixed = MeshTopology::Mixed {
            tp: 2,
            pp: 2,
            dp: 2,
        };
        assert_eq!(mixed.expected_world_size(), 8);
        assert!(mixed.validate(8).is_ok());
        assert!(mixed.validate(7).is_err());

        assert!(MeshTopology::DataParallel.validate(0).is_ok());

        // Too few devices must be rejected at construction time as well.
        assert!(DeviceMesh::with_topology(create_test_devices(2), tp).is_err());
    }

    #[test]
    fn test_set_rank_bounds() {
        let mut mesh = DeviceMesh::new(create_test_devices(4));

        assert!(mesh.set_rank(3).is_ok());
        assert_eq!(mesh.rank, 3);

        // An out-of-range rank is refused and the previous rank is kept.
        assert!(mesh.set_rank(4).is_err());
        assert_eq!(mesh.rank, 3);
    }

    #[test]
    fn test_group_operations() {
        let devices = create_test_devices(4);
        let mut mesh = DeviceMesh::new(devices);

        // Add custom group
        mesh.add_group("custom".to_string(), vec![0, 2], GroupBackend::Gloo)
            .unwrap();

        assert!(mesh.get_group("custom").is_some());
        assert_eq!(mesh.get_group("custom").unwrap().ranks, vec![0, 2]);

        let mut names = mesh.group_names();
        names.sort();
        assert_eq!(names, vec!["custom".to_string(), "world".to_string()]);

        // Ranks outside the mesh are rejected and nothing is registered.
        assert!(mesh
            .add_group("bad".to_string(), vec![1, 4], GroupBackend::Gloo)
            .is_err());
        assert!(mesh.get_group("bad").is_none());
    }

    #[test]
    fn test_group_membership() {
        let mut mesh = DeviceMesh::new(create_test_devices(4));
        mesh.add_group("custom".to_string(), vec![0, 2], GroupBackend::Gloo)
            .unwrap();

        mesh.set_rank(2).unwrap();
        assert!(mesh.in_group("custom"));
        assert_eq!(mesh.group_rank("custom"), Some(1));
        assert_eq!(mesh.group_rank("world"), Some(2));

        mesh.set_rank(1).unwrap();
        assert!(!mesh.in_group("custom"));
        assert_eq!(mesh.group_rank("custom"), None);

        assert!(!mesh.in_group("missing"));
        assert_eq!(mesh.group_rank("missing"), None);
    }

    #[test]
    fn test_device_lookup() {
        let mesh = DeviceMesh::new(create_test_devices(2));

        assert_eq!(mesh.get_device(1).unwrap().id, 1);
        assert!(mesh.get_device(2).is_none());
        assert_eq!(mesh.local_device().unwrap().id, 0);
        assert_eq!(mesh.all_devices().len(), 2);
        assert_eq!(mesh.devices_by_backend(DeviceBackend::Cuda).len(), 2);

        assert_eq!(mesh.devices_in_group("world").unwrap().len(), 2);
        assert!(mesh.devices_in_group("missing").is_err());
    }

    #[test]
    fn test_mesh_stats() {
        let devices = create_test_devices(4);
        let mesh = DeviceMesh::new(devices);

        let stats = mesh.stats();
        assert_eq!(stats.world_size, 4);
        assert_eq!(stats.total_memory_mb, 64000); // 4 * 16000
        assert_eq!(stats.total_compute_units, 320); // 4 * 80
        assert_eq!(stats.group_count, 1); // only "world"

        let cuda_key = format!("{:?}", DeviceBackend::Cuda);
        assert_eq!(stats.backend_distribution.get(&cuda_key), Some(&4));
    }
}
577}