Skip to main content

ferrotorch_distributed/
device_mesh.rs

//! `DeviceMesh` — multi-dimensional rank layout. (#591)
//!
//! Mirrors `torch.distributed.DeviceMesh`. A mesh is an n-dimensional
//! arrangement of ranks: e.g. for 2-D parallelism with 8 GPUs split into
//! 2-way data parallel × 4-way tensor parallel, the mesh has shape `[2, 4]`
//! with ranks `[[0, 1, 2, 3], [4, 5, 6, 7]]`.
//!
//! The mesh exposes:
//! - A rank's coordinates within the mesh (`coords()`) and the inverse
//!   mapping from coordinates back to a rank (`rank_of()`).
//! - The list of ranks along a dimension that share the same coords on
//!   every other dim (`ranks_along_dim()`), and the partition of the whole
//!   world into such groups (`groups_along_dim()`) — used to construct
//!   sub-groups for collective ops scoped to one parallelism axis.
//!
//! Sub-group / sub-backend creation is a separate concern handled by
//! [`SubBackend`](crate::backend::SubBackend); this module is
//! infrastructure-agnostic and only maintains the index math.
//!
//! ## REQ status (per `.design/ferrotorch-distributed/device_mesh.md`)
//!
//! | REQ | Status | Evidence |
//! |---|---|---|
//! | REQ-1 (DeviceMesh struct) | SHIPPED | `pub struct DeviceMesh { shape, dim_names }` in `device_mesh.rs`; consumer `dtensor.rs` stores `mesh: DeviceMesh` field on `DTensor`. |
//! | REQ-2 (validating constructor) | SHIPPED | `pub fn new` in `device_mesh.rs` with three validation gates; consumer crate-root re-export at `lib.rs`. |
//! | REQ-3 (new_with_names) | SHIPPED | `pub fn new_with_names` in `device_mesh.rs`; consumer crate-root re-export at `lib.rs`. |
//! | REQ-4 (coords / rank_of) | SHIPPED | `pub fn coords` / `pub fn rank_of` in `device_mesh.rs`; consumer `pub fn ranks_along_dim` in same file calls both (production use). |
//! | REQ-5 (ranks_along_dim) | SHIPPED | `pub fn ranks_along_dim` in `device_mesh.rs`; consumer `pub fn groups_along_dim` in same file invokes it. |
//! | REQ-6 (groups_along_dim) | SHIPPED | `pub fn groups_along_dim` in `device_mesh.rs`; consumer crate-root re-export at `lib.rs` — boundary API for production training scripts building per-axis `SubBackend`s. |
//! | REQ-7 (dim_index) | SHIPPED | `pub fn dim_index` in `device_mesh.rs`; consumer crate-root re-export at `lib.rs`. |
//! | REQ-8 (accessors) | SHIPPED | `pub fn shape` / `pub fn dim_names` / `pub fn ndim` / `pub fn size` in `device_mesh.rs`; consumer `dtensor.rs` calls `mesh.ndim()` for placement-count validation. |
30
31use ferrotorch_core::{FerrotorchError, FerrotorchResult};
32
33/// An n-D arrangement of ranks. The product of `shape` must equal the
34/// world size; ranks are laid out in row-major order (last dim varies
35/// fastest).
36#[derive(Debug, Clone)]
37pub struct DeviceMesh {
38    shape: Vec<usize>,
39    /// Names for each mesh dimension, e.g. `["dp", "tp"]`. Optional —
40    /// callers can pass `None` to skip naming.
41    dim_names: Option<Vec<String>>,
42}
43
44impl DeviceMesh {
45    /// Create a mesh with the given shape. `world_size` must equal the
46    /// product of `shape` (verified eagerly so misconfigured launches
47    /// fail loudly instead of silently splitting wrong).
48    pub fn new(shape: Vec<usize>, world_size: usize) -> FerrotorchResult<Self> {
49        if shape.is_empty() {
50            return Err(FerrotorchError::InvalidArgument {
51                message: "DeviceMesh: shape must be non-empty".into(),
52            });
53        }
54        let prod: usize = shape.iter().product::<usize>().max(1);
55        if prod != world_size {
56            return Err(FerrotorchError::InvalidArgument {
57                message: format!("DeviceMesh: shape product {prod} != world_size {world_size}"),
58            });
59        }
60        for (i, &d) in shape.iter().enumerate() {
61            if d == 0 {
62                return Err(FerrotorchError::InvalidArgument {
63                    message: format!("DeviceMesh: dim {i} is 0"),
64                });
65            }
66        }
67        Ok(Self {
68            shape,
69            dim_names: None,
70        })
71    }
72
73    /// Variant of [`new`] that also attaches names to each dim.
74    /// `dim_names.len()` must match `shape.len()`.
75    pub fn new_with_names(
76        shape: Vec<usize>,
77        dim_names: Vec<String>,
78        world_size: usize,
79    ) -> FerrotorchResult<Self> {
80        if dim_names.len() != shape.len() {
81            return Err(FerrotorchError::InvalidArgument {
82                message: format!(
83                    "DeviceMesh: dim_names.len()={} != shape.len()={}",
84                    dim_names.len(),
85                    shape.len()
86                ),
87            });
88        }
89        let mut mesh = Self::new(shape, world_size)?;
90        mesh.dim_names = Some(dim_names);
91        Ok(mesh)
92    }
93
94    /// Mesh shape (`[dp, tp, ...]`).
95    pub fn shape(&self) -> &[usize] {
96        &self.shape
97    }
98
99    /// Optional dim names.
100    pub fn dim_names(&self) -> Option<&[String]> {
101        self.dim_names.as_deref()
102    }
103
104    /// Dimensionality of the mesh.
105    pub fn ndim(&self) -> usize {
106        self.shape.len()
107    }
108
109    /// Total number of ranks in the mesh (= world_size).
110    pub fn size(&self) -> usize {
111        self.shape.iter().product::<usize>().max(1)
112    }
113
114    /// Resolve a dim name to its index.
115    pub fn dim_index(&self, name: &str) -> FerrotorchResult<usize> {
116        let names = self
117            .dim_names
118            .as_ref()
119            .ok_or(FerrotorchError::InvalidArgument {
120                message: "DeviceMesh: no dim names registered".into(),
121            })?;
122        names
123            .iter()
124            .position(|n| n == name)
125            .ok_or(FerrotorchError::InvalidArgument {
126                message: format!("DeviceMesh: dim name '{name}' not found"),
127            })
128    }
129
130    /// Convert a rank to its multi-dim coordinate within the mesh.
131    /// Row-major: the last dim varies fastest.
132    pub fn coords(&self, rank: usize) -> FerrotorchResult<Vec<usize>> {
133        if rank >= self.size() {
134            return Err(FerrotorchError::InvalidArgument {
135                message: format!(
136                    "DeviceMesh: rank {rank} out of range for mesh size {}",
137                    self.size()
138                ),
139            });
140        }
141        let mut out = vec![0usize; self.shape.len()];
142        let mut r = rank;
143        for i in (0..self.shape.len()).rev() {
144            out[i] = r % self.shape[i];
145            r /= self.shape[i];
146        }
147        Ok(out)
148    }
149
150    /// Inverse of [`coords`]: convert a coordinate to its rank.
151    pub fn rank_of(&self, coords: &[usize]) -> FerrotorchResult<usize> {
152        if coords.len() != self.shape.len() {
153            return Err(FerrotorchError::InvalidArgument {
154                message: format!(
155                    "DeviceMesh: coords len {} != ndim {}",
156                    coords.len(),
157                    self.shape.len()
158                ),
159            });
160        }
161        let mut rank = 0usize;
162        for (i, &c) in coords.iter().enumerate() {
163            if c >= self.shape[i] {
164                return Err(FerrotorchError::InvalidArgument {
165                    message: format!(
166                        "DeviceMesh: coord[{i}] = {c} out of range for dim size {}",
167                        self.shape[i]
168                    ),
169                });
170            }
171            rank = rank * self.shape[i] + c;
172        }
173        Ok(rank)
174    }
175
176    /// All ranks along `dim` that share `rank`'s coordinates on every
177    /// other dim. Useful for constructing per-axis collective groups
178    /// (e.g. one TP group per data-parallel slice).
179    ///
180    /// Returns the ranks in increasing-coord order on `dim`.
181    pub fn ranks_along_dim(&self, dim: usize, rank: usize) -> FerrotorchResult<Vec<usize>> {
182        if dim >= self.shape.len() {
183            return Err(FerrotorchError::InvalidArgument {
184                message: format!(
185                    "DeviceMesh: dim {dim} out of range for ndim {}",
186                    self.shape.len()
187                ),
188            });
189        }
190        let mut coords = self.coords(rank)?;
191        let mut ranks = Vec::with_capacity(self.shape[dim]);
192        for d in 0..self.shape[dim] {
193            coords[dim] = d;
194            ranks.push(self.rank_of(&coords)?);
195        }
196        Ok(ranks)
197    }
198
199    /// All sub-groups along `dim`: a partitioning of the world into
200    /// disjoint groups of `shape[dim]` ranks each, such that every
201    /// group consists of ranks differing only on `dim`. Useful for
202    /// building sub-backends in bulk.
203    pub fn groups_along_dim(&self, dim: usize) -> FerrotorchResult<Vec<Vec<usize>>> {
204        if dim >= self.shape.len() {
205            return Err(FerrotorchError::InvalidArgument {
206                message: format!(
207                    "DeviceMesh: dim {dim} out of range for ndim {}",
208                    self.shape.len()
209                ),
210            });
211        }
212        let world = self.size();
213        let mut groups: Vec<Vec<usize>> = Vec::new();
214        let mut seen = vec![false; world];
215        for rank in 0..world {
216            if seen[rank] {
217                continue;
218            }
219            let g = self.ranks_along_dim(dim, rank)?;
220            for &r in &g {
221                seen[r] = true;
222            }
223            groups.push(g);
224        }
225        Ok(groups)
226    }
227}
228
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn mesh_shape_must_match_world_size() {
        let err = DeviceMesh::new(vec![2, 3], 5).unwrap_err();
        assert!(matches!(err, FerrotorchError::InvalidArgument { .. }));
    }

    #[test]
    fn mesh_zero_dim_rejected() {
        let err = DeviceMesh::new(vec![2, 0], 0).unwrap_err();
        assert!(matches!(err, FerrotorchError::InvalidArgument { .. }));
    }

    #[test]
    fn mesh_empty_shape_rejected() {
        let err = DeviceMesh::new(vec![], 1).unwrap_err();
        assert!(matches!(err, FerrotorchError::InvalidArgument { .. }));
    }

    #[test]
    fn mesh_accessors_report_layout() {
        let m = DeviceMesh::new(vec![2, 3], 6).unwrap();
        assert_eq!(m.shape(), &[2, 3]);
        assert_eq!(m.ndim(), 2);
        assert_eq!(m.size(), 6);
        assert!(m.dim_names().is_none());
    }

    #[test]
    fn mesh_coords_roundtrip_2d() {
        let m = DeviceMesh::new(vec![2, 4], 8).unwrap();
        // 2x4 layout in row-major:
        //   [[0, 1, 2, 3],
        //    [4, 5, 6, 7]]
        for rank in 0..8 {
            let coords = m.coords(rank).unwrap();
            assert_eq!(coords[0], rank / 4);
            assert_eq!(coords[1], rank % 4);
            assert_eq!(m.rank_of(&coords).unwrap(), rank);
        }
    }

    #[test]
    fn mesh_ranks_along_dim_returns_correct_axis() {
        let m = DeviceMesh::new(vec![2, 4], 8).unwrap();
        // rank 5 has coords (1, 1).
        // along dim 0 (data-parallel axis): same col → ranks (0, 1) and (1, 1) = [1, 5]
        // along dim 1 (tensor-parallel axis): same row → ranks 4..=7
        assert_eq!(m.ranks_along_dim(0, 5).unwrap(), vec![1, 5]);
        assert_eq!(m.ranks_along_dim(1, 5).unwrap(), vec![4, 5, 6, 7]);
    }

    #[test]
    fn mesh_groups_along_dim_partition_world() {
        let m = DeviceMesh::new(vec![2, 4], 8).unwrap();
        // Along dim 0: 4 groups of 2 (one per col).
        let g0 = m.groups_along_dim(0).unwrap();
        assert_eq!(g0.len(), 4);
        for g in &g0 {
            assert_eq!(g.len(), 2);
        }
        // Sorted union covers every rank exactly once.
        let mut all: Vec<usize> = g0.iter().flatten().copied().collect();
        all.sort_unstable();
        assert_eq!(all, (0..8).collect::<Vec<_>>());

        // Along dim 1: 2 groups of 4 (one per row).
        let g1 = m.groups_along_dim(1).unwrap();
        assert_eq!(g1.len(), 2);
        assert_eq!(g1[0], vec![0, 1, 2, 3]);
        assert_eq!(g1[1], vec![4, 5, 6, 7]);
    }

    #[test]
    fn mesh_with_dim_names_resolve_index() {
        let m = DeviceMesh::new_with_names(vec![2, 4], vec!["dp".to_string(), "tp".to_string()], 8)
            .unwrap();
        assert_eq!(m.dim_index("dp").unwrap(), 0);
        assert_eq!(m.dim_index("tp").unwrap(), 1);
        assert!(m.dim_index("missing").is_err());
    }

    #[test]
    fn mesh_dim_names_accessor_returns_names() {
        let m = DeviceMesh::new_with_names(vec![2, 2], vec!["dp".to_string(), "tp".to_string()], 4)
            .unwrap();
        let expected = ["dp".to_string(), "tp".to_string()];
        assert_eq!(m.dim_names(), Some(&expected[..]));
    }

    #[test]
    fn mesh_dim_index_without_names_errors() {
        let m = DeviceMesh::new(vec![2, 2], 4).unwrap();
        let err = m.dim_index("dp").unwrap_err();
        assert!(matches!(err, FerrotorchError::InvalidArgument { .. }));
    }

    #[test]
    fn mesh_new_with_names_rejects_mismatched_lengths() {
        let err = DeviceMesh::new_with_names(vec![2, 4], vec!["only_one_name".to_string()], 8)
            .unwrap_err();
        assert!(matches!(err, FerrotorchError::InvalidArgument { .. }));
    }

    #[test]
    fn mesh_oob_rank_errors() {
        let m = DeviceMesh::new(vec![2, 2], 4).unwrap();
        assert!(m.coords(4).is_err());
        // Out-of-range ranks are also rejected when building axis groups.
        assert!(m.ranks_along_dim(0, 4).is_err());
    }

    #[test]
    fn mesh_oob_coord_errors() {
        let m = DeviceMesh::new(vec![2, 2], 4).unwrap();
        let err = m.rank_of(&[0, 5]).unwrap_err();
        assert!(matches!(err, FerrotorchError::InvalidArgument { .. }));
    }

    #[test]
    fn mesh_rank_of_wrong_len_errors() {
        let m = DeviceMesh::new(vec![2, 2], 4).unwrap();
        assert!(m.rank_of(&[1]).is_err());
        assert!(m.rank_of(&[0, 0, 0]).is_err());
    }

    #[test]
    fn mesh_dim_out_of_range_errors() {
        let m = DeviceMesh::new(vec![2, 2], 4).unwrap();
        assert!(m.ranks_along_dim(2, 0).is_err());
        assert!(m.groups_along_dim(2).is_err());
    }

    #[test]
    fn mesh_1d_single_group() {
        let m = DeviceMesh::new(vec![4], 4).unwrap();
        assert_eq!(m.coords(3).unwrap(), vec![3]);
        assert_eq!(m.rank_of(&[2]).unwrap(), 2);
        assert_eq!(m.groups_along_dim(0).unwrap(), vec![vec![0, 1, 2, 3]]);
    }

    #[test]
    fn mesh_3d_correctness() {
        // 2x2x3 = 12 ranks, row-major.
        let m = DeviceMesh::new(vec![2, 2, 3], 12).unwrap();
        // rank 7: 7 = 1*6 + 0*3 + 1 → coords (1, 0, 1)
        assert_eq!(m.coords(7).unwrap(), vec![1, 0, 1]);
        assert_eq!(m.rank_of(&[1, 0, 1]).unwrap(), 7);
        // along innermost dim from rank 7: same (1, 0, *) → 6, 7, 8
        assert_eq!(m.ranks_along_dim(2, 7).unwrap(), vec![6, 7, 8]);
    }

    #[test]
    fn mesh_3d_groups_along_middle_dim() {
        // 2x2x3: each group pairs ranks that differ only in the middle coord;
        // groups are ordered by their smallest rank.
        let m = DeviceMesh::new(vec![2, 2, 3], 12).unwrap();
        assert_eq!(
            m.groups_along_dim(1).unwrap(),
            vec![
                vec![0, 3],
                vec![1, 4],
                vec![2, 5],
                vec![6, 9],
                vec![7, 10],
                vec![8, 11],
            ]
        );
    }
}