1use ferrotorch_core::{FerrotorchError, FerrotorchResult};
32
/// A logical n-dimensional grid of ranks.
///
/// Ranks are mapped to grid coordinates in row-major order: the last
/// dimension varies fastest.
#[derive(Clone, Debug)]
pub struct DeviceMesh {
    /// Extent of each mesh dimension. The constructors guarantee this is
    /// non-empty and contains no zero entries.
    shape: Vec<usize>,
    /// Optional human-readable label per dimension; when present it has
    /// exactly one entry per element of `shape`.
    dim_names: Option<Vec<String>>,
}
43
44impl DeviceMesh {
45 pub fn new(shape: Vec<usize>, world_size: usize) -> FerrotorchResult<Self> {
49 if shape.is_empty() {
50 return Err(FerrotorchError::InvalidArgument {
51 message: "DeviceMesh: shape must be non-empty".into(),
52 });
53 }
54 let prod: usize = shape.iter().product::<usize>().max(1);
55 if prod != world_size {
56 return Err(FerrotorchError::InvalidArgument {
57 message: format!("DeviceMesh: shape product {prod} != world_size {world_size}"),
58 });
59 }
60 for (i, &d) in shape.iter().enumerate() {
61 if d == 0 {
62 return Err(FerrotorchError::InvalidArgument {
63 message: format!("DeviceMesh: dim {i} is 0"),
64 });
65 }
66 }
67 Ok(Self {
68 shape,
69 dim_names: None,
70 })
71 }
72
73 pub fn new_with_names(
76 shape: Vec<usize>,
77 dim_names: Vec<String>,
78 world_size: usize,
79 ) -> FerrotorchResult<Self> {
80 if dim_names.len() != shape.len() {
81 return Err(FerrotorchError::InvalidArgument {
82 message: format!(
83 "DeviceMesh: dim_names.len()={} != shape.len()={}",
84 dim_names.len(),
85 shape.len()
86 ),
87 });
88 }
89 let mut mesh = Self::new(shape, world_size)?;
90 mesh.dim_names = Some(dim_names);
91 Ok(mesh)
92 }
93
94 pub fn shape(&self) -> &[usize] {
96 &self.shape
97 }
98
99 pub fn dim_names(&self) -> Option<&[String]> {
101 self.dim_names.as_deref()
102 }
103
104 pub fn ndim(&self) -> usize {
106 self.shape.len()
107 }
108
109 pub fn size(&self) -> usize {
111 self.shape.iter().product::<usize>().max(1)
112 }
113
114 pub fn dim_index(&self, name: &str) -> FerrotorchResult<usize> {
116 let names = self
117 .dim_names
118 .as_ref()
119 .ok_or(FerrotorchError::InvalidArgument {
120 message: "DeviceMesh: no dim names registered".into(),
121 })?;
122 names
123 .iter()
124 .position(|n| n == name)
125 .ok_or(FerrotorchError::InvalidArgument {
126 message: format!("DeviceMesh: dim name '{name}' not found"),
127 })
128 }
129
130 pub fn coords(&self, rank: usize) -> FerrotorchResult<Vec<usize>> {
133 if rank >= self.size() {
134 return Err(FerrotorchError::InvalidArgument {
135 message: format!(
136 "DeviceMesh: rank {rank} out of range for mesh size {}",
137 self.size()
138 ),
139 });
140 }
141 let mut out = vec![0usize; self.shape.len()];
142 let mut r = rank;
143 for i in (0..self.shape.len()).rev() {
144 out[i] = r % self.shape[i];
145 r /= self.shape[i];
146 }
147 Ok(out)
148 }
149
150 pub fn rank_of(&self, coords: &[usize]) -> FerrotorchResult<usize> {
152 if coords.len() != self.shape.len() {
153 return Err(FerrotorchError::InvalidArgument {
154 message: format!(
155 "DeviceMesh: coords len {} != ndim {}",
156 coords.len(),
157 self.shape.len()
158 ),
159 });
160 }
161 let mut rank = 0usize;
162 for (i, &c) in coords.iter().enumerate() {
163 if c >= self.shape[i] {
164 return Err(FerrotorchError::InvalidArgument {
165 message: format!(
166 "DeviceMesh: coord[{i}] = {c} out of range for dim size {}",
167 self.shape[i]
168 ),
169 });
170 }
171 rank = rank * self.shape[i] + c;
172 }
173 Ok(rank)
174 }
175
176 pub fn ranks_along_dim(&self, dim: usize, rank: usize) -> FerrotorchResult<Vec<usize>> {
182 if dim >= self.shape.len() {
183 return Err(FerrotorchError::InvalidArgument {
184 message: format!(
185 "DeviceMesh: dim {dim} out of range for ndim {}",
186 self.shape.len()
187 ),
188 });
189 }
190 let mut coords = self.coords(rank)?;
191 let mut ranks = Vec::with_capacity(self.shape[dim]);
192 for d in 0..self.shape[dim] {
193 coords[dim] = d;
194 ranks.push(self.rank_of(&coords)?);
195 }
196 Ok(ranks)
197 }
198
199 pub fn groups_along_dim(&self, dim: usize) -> FerrotorchResult<Vec<Vec<usize>>> {
204 if dim >= self.shape.len() {
205 return Err(FerrotorchError::InvalidArgument {
206 message: format!(
207 "DeviceMesh: dim {dim} out of range for ndim {}",
208 self.shape.len()
209 ),
210 });
211 }
212 let world = self.size();
213 let mut groups: Vec<Vec<usize>> = Vec::new();
214 let mut seen = vec![false; world];
215 for rank in 0..world {
216 if seen[rank] {
217 continue;
218 }
219 let g = self.ranks_along_dim(dim, rank)?;
220 for &r in &g {
221 seen[r] = true;
222 }
223 groups.push(g);
224 }
225 Ok(groups)
226 }
227}
228
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that a validation failure surfaced as `InvalidArgument`.
    fn assert_invalid_argument(err: FerrotorchError) {
        assert!(matches!(err, FerrotorchError::InvalidArgument { .. }));
    }

    #[test]
    fn mesh_shape_must_match_world_size() {
        assert_invalid_argument(DeviceMesh::new(vec![2, 3], 5).unwrap_err());
    }

    #[test]
    fn mesh_zero_dim_rejected() {
        assert_invalid_argument(DeviceMesh::new(vec![2, 0], 0).unwrap_err());
    }

    #[test]
    fn mesh_coords_roundtrip_2d() {
        let mesh = DeviceMesh::new(vec![2, 4], 8).unwrap();
        for r in 0..8usize {
            let c = mesh.coords(r).unwrap();
            // Row-major layout: the last dimension varies fastest.
            assert_eq!((c[0], c[1]), (r / 4, r % 4));
            assert_eq!(mesh.rank_of(&c).unwrap(), r);
        }
    }

    #[test]
    fn mesh_ranks_along_dim_returns_correct_axis() {
        let mesh = DeviceMesh::new(vec![2, 4], 8).unwrap();
        // Rank 5 sits at (1, 1): walking dim 0 varies the row, dim 1 the column.
        let expected: [(usize, Vec<usize>); 2] = [(0, vec![1, 5]), (1, vec![4, 5, 6, 7])];
        for (dim, ranks) in expected.iter() {
            assert_eq!(&mesh.ranks_along_dim(*dim, 5).unwrap(), ranks);
        }
    }

    #[test]
    fn mesh_groups_along_dim_partition_world() {
        let mesh = DeviceMesh::new(vec![2, 4], 8).unwrap();

        let along_0 = mesh.groups_along_dim(0).unwrap();
        assert_eq!(along_0.len(), 4);
        assert!(along_0.iter().all(|g| g.len() == 2));
        let mut covered: Vec<usize> = along_0.concat();
        covered.sort_unstable();
        assert_eq!(covered, (0..8).collect::<Vec<usize>>());

        let along_1 = mesh.groups_along_dim(1).unwrap();
        assert_eq!(along_1.len(), 2);
        assert_eq!(along_1[0], [0, 1, 2, 3]);
        assert_eq!(along_1[1], [4, 5, 6, 7]);
    }

    #[test]
    fn mesh_with_dim_names_resolve_index() {
        let names: Vec<String> = ["dp", "tp"].iter().map(|s| s.to_string()).collect();
        let mesh = DeviceMesh::new_with_names(vec![2, 4], names, 8).unwrap();
        for (expected, name) in ["dp", "tp"].iter().enumerate() {
            assert_eq!(mesh.dim_index(name).unwrap(), expected);
        }
        assert!(mesh.dim_index("missing").is_err());
    }

    #[test]
    fn mesh_new_with_names_rejects_mismatched_lengths() {
        let single_name = vec!["only_one_name".to_string()];
        assert_invalid_argument(
            DeviceMesh::new_with_names(vec![2, 4], single_name, 8).unwrap_err(),
        );
    }

    #[test]
    fn mesh_oob_rank_errors() {
        let mesh = DeviceMesh::new(vec![2, 2], 4).unwrap();
        assert!(mesh.coords(4).is_err());
    }

    #[test]
    fn mesh_oob_coord_errors() {
        let mesh = DeviceMesh::new(vec![2, 2], 4).unwrap();
        assert_invalid_argument(mesh.rank_of(&[0, 5]).unwrap_err());
    }

    #[test]
    fn mesh_3d_correctness() {
        let mesh = DeviceMesh::new(vec![2, 2, 3], 12).unwrap();
        // 7 = 1*(2*3) + 0*3 + 1 in row-major order.
        let point = vec![1, 0, 1];
        assert_eq!(mesh.coords(7).unwrap(), point);
        assert_eq!(mesh.rank_of(&point).unwrap(), 7);
        assert_eq!(mesh.ranks_along_dim(2, 7).unwrap(), vec![6, 7, 8]);
    }
}