Skip to main content

ibverbs_rs/network/
config.rs

1use bon::Builder;
2use serde::{Deserialize, Serialize};
3use std::ops::Deref;
4use thiserror::Error;
5
6/// A validated network topology describing all nodes that participate in RDMA communication.
7///
8/// Nodes are sorted by rank and indexed via [`Deref<Target = [NodeConfig]>`](std::ops::Deref).
9/// Build one with [`NetworkConfig::builder`].
10#[derive(Debug, Clone)]
11pub struct NetworkConfig {
12    hosts: Vec<NodeConfig>,
13}
14
15/// Configuration for a single node in the network.
16#[derive(Debug, Clone, Builder, Serialize, Deserialize)]
17#[builder(on(String, into))]
18pub struct NodeConfig {
19    /// Network hostname or IP address.
20    pub hostname: String,
21    /// TCP port used for the initial endpoint exchange.
22    pub port: u16,
23    /// Name of the RDMA device to use (e.g. `"mlx5_0"`).
24    pub ibdev: String,
25    /// Unique rank identifier, must be sequential starting from 0.
26    pub rankid: usize,
27    /// Optional human-readable label.
28    #[builder(default)]
29    #[serde(skip_serializing_if = "String::is_empty", default)]
30    pub comment: String,
31}
32
33impl NetworkConfig {
34    /// Returns a [`RawNetworkConfig`] builder for constructing a network topology.
35    pub fn builder() -> RawNetworkConfig {
36        RawNetworkConfig { hosts: vec![] }
37    }
38}
39
40/// An error returned by [`RawNetworkConfig::build`] when the configuration is invalid.
41#[derive(Debug, Copy, Clone, Error)]
42pub enum NetworkConfigError {
43    /// No nodes were added to the configuration.
44    #[error("Empty network")]
45    EmptyNetwork,
46    /// The lowest rank present is not `0`. Ranks must be a contiguous sequence
47    /// starting at zero.
48    #[error("First rank id is not zero")]
49    FirstRankNotZero,
50    /// There is a gap in the rank sequence. `gap_rank` is the first missing rank.
51    #[error("Ranks are non sequential, {gap_rank} is missing")]
52    NonSequentialRanks { gap_rank: usize },
53    /// The same rank appears more than once. `dup_rank` is the repeated rank.
54    #[error("Rank {dup_rank} appears multiple times")]
55    DuplicatedRank { dup_rank: usize },
56}
57
58/// An unvalidated network configuration. Add nodes with [`add_node`](Self::add_node),
59/// then call [`build`](Self::build) to validate and produce a [`NetworkConfig`].
60///
61/// # JSON format
62///
63/// `RawNetworkConfig` implements `Serialize`/`Deserialize` and can be loaded from JSON:
64///
65/// ```json
66/// {
67///   "hosts": [
68///     { "hostname": "node1", "port": 10000, "ibdev": "mlx5_0", "rankid": 0 },
69///     { "hostname": "node2", "port": 10000, "ibdev": "mlx5_0", "rankid": 1 }
70///   ]
71/// }
72/// ```
73///
74/// The optional `comment` field is omitted from serialization when empty.
75#[derive(Debug, Clone, Serialize, Deserialize, Default)]
76pub struct RawNetworkConfig {
77    hosts: Vec<NodeConfig>,
78}
79
80impl RawNetworkConfig {
81    /// Appends a node to the configuration.
82    pub fn add_node(mut self, node: NodeConfig) -> Self {
83        self.hosts.push(node);
84        self
85    }
86
87    /// Truncates the node list to at most `num_nodes` entries.
88    pub fn truncate(mut self, num_nodes: usize) -> Self {
89        self.hosts.truncate(num_nodes);
90        self
91    }
92
93    /// Validates and builds the [`NetworkConfig`].
94    ///
95    /// Ranks must be unique, sequential, and start at 0. Nodes are sorted by rank.
96    pub fn build(mut self) -> Result<NetworkConfig, NetworkConfigError> {
97        self.hosts.sort_by_key(|n| n.rankid);
98
99        // Network cannot be empty
100        if self.hosts.is_empty() {
101            return Err(NetworkConfigError::EmptyNetwork);
102        }
103
104        // Rank ids must start at 0
105        if self.hosts.first().map(|h| h.rankid) != Some(0) {
106            return Err(NetworkConfigError::FirstRankNotZero);
107        }
108
109        for i in 1..self.hosts.len() {
110            let node_config = &self.hosts[i];
111
112            // Rank ids must be unique
113            if node_config.rankid == self.hosts[i - 1].rankid {
114                return Err(NetworkConfigError::DuplicatedRank {
115                    dup_rank: node_config.rankid,
116                });
117            }
118
119            // Rank ids must be sequential
120            if node_config.rankid != i {
121                return Err(NetworkConfigError::NonSequentialRanks { gap_rank: i });
122            }
123        }
124
125        Ok(NetworkConfig { hosts: self.hosts })
126    }
127}
128
129impl Deref for NetworkConfig {
130    type Target = [NodeConfig];
131
132    fn deref(&self) -> &Self::Target {
133        self.hosts.as_slice()
134    }
135}
136
137impl<'a> IntoIterator for &'a NetworkConfig {
138    type Item = &'a NodeConfig;
139    type IntoIter = std::slice::Iter<'a, NodeConfig>;
140
141    fn into_iter(self) -> Self::IntoIter {
142        self.iter()
143    }
144}
145
146impl NetworkConfig {
147    /// Returns the total number of nodes in the network.
148    pub fn world_size(&self) -> usize {
149        self.hosts.len()
150    }
151}
152
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a node with the given fields and an empty comment.
    fn node(hostname: &str, port: u16, ibdev: &str, rankid: usize) -> NodeConfig {
        NodeConfig {
            hostname: hostname.to_string(),
            port,
            ibdev: ibdev.to_string(),
            rankid,
            comment: String::new(),
        }
    }

    /// Wraps the given nodes in an unvalidated configuration, preserving their order.
    fn raw(hosts: Vec<NodeConfig>) -> RawNetworkConfig {
        RawNetworkConfig { hosts }
    }

    #[test]
    fn valid_network_config() {
        let config = raw(vec![
            node("tdeb02", 10000, "mlx5_0", 0),
            node("tdeb02", 10001, "mlx5_0", 1),
        ])
        .build()
        .unwrap();

        assert_eq!(config.len(), 2);
        assert_eq!(config[0].rankid, 0);
        assert_eq!(config[1].rankid, 1);
    }

    #[test]
    fn valid_network_config_out_of_order() {
        let config = raw(vec![
            node("node2", 10001, "mlx5_0", 1),
            node("node1", 10000, "mlx5_0", 0),
            node("node3", 10002, "mlx5_0", 2),
        ])
        .build()
        .unwrap();

        // The built configuration must be ordered by rank, not by insertion.
        assert_eq!(config[0].rankid, 0);
        assert_eq!(config[0].hostname, "node1");
        assert_eq!(config[1].rankid, 1);
        assert_eq!(config[1].hostname, "node2");
        assert_eq!(config[2].rankid, 2);
        assert_eq!(config[2].hostname, "node3");
    }

    #[test]
    fn empty_node_config() {
        let outcome = raw(Vec::new()).build();
        assert!(matches!(outcome, Err(NetworkConfigError::EmptyNetwork)));
    }

    #[test]
    fn single_node_config() {
        let config = raw(vec![node("single", 8080, "mlx5_1", 0)])
            .build()
            .unwrap();

        assert_eq!(config.len(), 1);
        assert_eq!(config[0].rankid, 0);
    }

    #[test]
    fn missing_rank_zero() {
        let outcome = raw(vec![
            node("node1", 10000, "mlx5_0", 1),
            node("node2", 10001, "mlx5_0", 2),
        ])
        .build();

        assert!(matches!(
            outcome,
            Err(NetworkConfigError::FirstRankNotZero)
        ));
    }

    #[test]
    fn non_sequential_ranks() {
        // Rank 1 is missing.
        let outcome = raw(vec![
            node("node1", 10000, "mlx5_0", 0),
            node("node2", 10001, "mlx5_0", 2),
        ])
        .build();

        assert!(matches!(
            outcome,
            Err(NetworkConfigError::NonSequentialRanks { gap_rank: 1 })
        ));
    }

    #[test]
    fn non_sequential_ranks_before_duplicate() {
        // Ranks 1 and 2 are missing and rank 3 is duplicated.
        // The gap comes first in rank order (1 < 3), so it is the reported error.
        let outcome = raw(vec![
            node("node1", 10000, "mlx5_0", 0),
            node("node2", 10001, "mlx5_0", 3),
            node("node3", 10002, "mlx5_0", 3),
        ])
        .build();

        assert!(matches!(
            outcome,
            Err(NetworkConfigError::NonSequentialRanks { gap_rank: 1 })
        ));
    }

    #[test]
    fn duplicate_ranks_before_non_sequential() {
        // Rank 1 is duplicated and rank 2 is missing.
        // The duplicate comes first in rank order (1 < 3), so it is the reported error.
        let outcome = raw(vec![
            node("node1", 10000, "mlx5_0", 0),
            node("node2", 10001, "mlx5_0", 1),
            node("node3", 10002, "mlx5_0", 1),
            node("node4", 10003, "mlx5_0", 3),
        ])
        .build();

        assert!(matches!(
            outcome,
            Err(NetworkConfigError::DuplicatedRank { dup_rank: 1 })
        ));
    }

    #[test]
    fn deref_access() {
        let config = raw(vec![
            node("test1", 9000, "mlx5_0", 0),
            node("test2", 9001, "mlx5_0", 1),
        ])
        .build()
        .unwrap();

        // Slice methods are reachable through the Deref implementation.
        assert_eq!(config.len(), 2);
        assert_eq!(config[0].hostname, "test1");
        assert_eq!(config[1].hostname, "test2");
        assert_eq!(config.first().unwrap().port, 9000);
        assert_eq!(config.last().unwrap().port, 9001);

        // Iteration visits the nodes in rank order.
        let mut hostnames: Vec<&String> = Vec::new();
        for entry in config.iter() {
            hostnames.push(&entry.hostname);
        }
        assert_eq!(hostnames, vec!["test1", "test2"]);
    }
}