1use bon::Builder;
2use serde::{Deserialize, Serialize};
3use std::ops::Deref;
4use thiserror::Error;
5
6#[derive(Debug, Clone)]
11pub struct NetworkConfig {
12 hosts: Vec<NodeConfig>,
13}
14
15#[derive(Debug, Clone, Builder, Serialize, Deserialize)]
17#[builder(on(String, into))]
18pub struct NodeConfig {
19 pub hostname: String,
21 pub port: u16,
23 pub ibdev: String,
25 pub rankid: usize,
27 #[builder(default)]
29 #[serde(skip_serializing_if = "String::is_empty", default)]
30 pub comment: String,
31}
32
33impl NetworkConfig {
34 pub fn builder() -> RawNetworkConfig {
36 RawNetworkConfig { hosts: vec![] }
37 }
38}
39
40#[derive(Debug, Copy, Clone, Error)]
42pub enum NetworkConfigError {
43 #[error("Empty network")]
45 EmptyNetwork,
46 #[error("First rank id is not zero")]
49 FirstRankNotZero,
50 #[error("Ranks are non sequential, {gap_rank} is missing")]
52 NonSequentialRanks { gap_rank: usize },
53 #[error("Rank {dup_rank} appears multiple times")]
55 DuplicatedRank { dup_rank: usize },
56}
57
58#[derive(Debug, Clone, Serialize, Deserialize, Default)]
76pub struct RawNetworkConfig {
77 hosts: Vec<NodeConfig>,
78}
79
80impl RawNetworkConfig {
81 pub fn add_node(mut self, node: NodeConfig) -> Self {
83 self.hosts.push(node);
84 self
85 }
86
87 pub fn truncate(mut self, num_nodes: usize) -> Self {
89 self.hosts.truncate(num_nodes);
90 self
91 }
92
93 pub fn build(mut self) -> Result<NetworkConfig, NetworkConfigError> {
97 self.hosts.sort_by_key(|n| n.rankid);
98
99 if self.hosts.is_empty() {
101 return Err(NetworkConfigError::EmptyNetwork);
102 }
103
104 if self.hosts.first().map(|h| h.rankid) != Some(0) {
106 return Err(NetworkConfigError::FirstRankNotZero);
107 }
108
109 for i in 1..self.hosts.len() {
110 let node_config = &self.hosts[i];
111
112 if node_config.rankid == self.hosts[i - 1].rankid {
114 return Err(NetworkConfigError::DuplicatedRank {
115 dup_rank: node_config.rankid,
116 });
117 }
118
119 if node_config.rankid != i {
121 return Err(NetworkConfigError::NonSequentialRanks { gap_rank: i });
122 }
123 }
124
125 Ok(NetworkConfig { hosts: self.hosts })
126 }
127}
128
129impl Deref for NetworkConfig {
130 type Target = [NodeConfig];
131
132 fn deref(&self) -> &Self::Target {
133 self.hosts.as_slice()
134 }
135}
136
137impl<'a> IntoIterator for &'a NetworkConfig {
138 type Item = &'a NodeConfig;
139 type IntoIter = std::slice::Iter<'a, NodeConfig>;
140
141 fn into_iter(self) -> Self::IntoIter {
142 self.iter()
143 }
144}
145
146impl NetworkConfig {
147 pub fn world_size(&self) -> usize {
149 self.hosts.len()
150 }
151}
152
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a node on device `mlx5_0` with an empty comment.
    fn node(hostname: &str, port: u16, rankid: usize) -> NodeConfig {
        NodeConfig {
            hostname: hostname.to_string(),
            port,
            ibdev: "mlx5_0".to_string(),
            rankid,
            comment: String::new(),
        }
    }

    /// Wraps `hosts` in an unvalidated configuration without going through
    /// the builder methods.
    fn raw(hosts: Vec<NodeConfig>) -> RawNetworkConfig {
        RawNetworkConfig { hosts }
    }

    #[test]
    fn valid_network_config() {
        let config = raw(vec![node("tdeb02", 10000, 0), node("tdeb02", 10001, 1)])
            .build()
            .unwrap();

        assert_eq!(config.len(), 2);
        assert_eq!(config[0].rankid, 0);
        assert_eq!(config[1].rankid, 1);
    }

    #[test]
    fn valid_network_config_out_of_order() {
        let config = raw(vec![
            node("node2", 10001, 1),
            node("node1", 10000, 0),
            node("node3", 10002, 2),
        ])
        .build()
        .unwrap();

        // The hosts must come back sorted by rank regardless of input order.
        assert_eq!(config[0].rankid, 0);
        assert_eq!(config[0].hostname, "node1");
        assert_eq!(config[1].rankid, 1);
        assert_eq!(config[1].hostname, "node2");
        assert_eq!(config[2].rankid, 2);
        assert_eq!(config[2].hostname, "node3");
    }

    #[test]
    fn empty_node_config() {
        let outcome = raw(vec![]).build();
        assert!(matches!(outcome, Err(NetworkConfigError::EmptyNetwork)));
    }

    #[test]
    fn single_node_config() {
        let only = NodeConfig {
            ibdev: "mlx5_1".to_string(),
            ..node("single", 8080, 0)
        };

        let config = raw(vec![only]).build().unwrap();
        assert_eq!(config.len(), 1);
        assert_eq!(config[0].rankid, 0);
    }

    #[test]
    fn missing_rank_zero() {
        let outcome = raw(vec![node("node1", 10000, 1), node("node2", 10001, 2)]).build();
        assert!(matches!(outcome, Err(NetworkConfigError::FirstRankNotZero)));
    }

    #[test]
    fn non_sequential_ranks() {
        // Rank 1 is missing.
        let outcome = raw(vec![node("node1", 10000, 0), node("node2", 10001, 2)]).build();
        assert!(matches!(
            outcome,
            Err(NetworkConfigError::NonSequentialRanks { gap_rank: 1 })
        ));
    }

    #[test]
    fn non_sequential_ranks_before_duplicate() {
        // The gap at rank 1 is met before the duplicated rank 3.
        let outcome = raw(vec![
            node("node1", 10000, 0),
            node("node2", 10001, 3),
            node("node3", 10002, 3),
        ])
        .build();
        assert!(matches!(
            outcome,
            Err(NetworkConfigError::NonSequentialRanks { gap_rank: 1 })
        ));
    }

    #[test]
    fn duplicate_ranks_before_non_sequential() {
        // The duplicated rank 1 is met before the gap at rank 2.
        let outcome = raw(vec![
            node("node1", 10000, 0),
            node("node2", 10001, 1),
            node("node3", 10002, 1),
            node("node4", 10003, 3),
        ])
        .build();
        assert!(matches!(
            outcome,
            Err(NetworkConfigError::DuplicatedRank { dup_rank: 1 })
        ));
    }

    #[test]
    fn deref_access() {
        let config = raw(vec![node("test1", 9000, 0), node("test2", 9001, 1)])
            .build()
            .unwrap();

        // Slice methods are reachable through `Deref`.
        assert_eq!(config.len(), 2);
        assert_eq!(config[0].hostname, "test1");
        assert_eq!(config[1].hostname, "test2");
        assert_eq!(config.first().unwrap().port, 9000);
        assert_eq!(config.last().unwrap().port, 9001);

        let hostnames: Vec<&String> = config.iter().map(|n| &n.hostname).collect();
        assert_eq!(hostnames, vec!["test1", "test2"]);
    }
}