Skip to main content

ibverbs_rs/network/
tcp_exchanger.rs

1use crate::network::config::{NetworkConfig, NodeConfig};
2use ExchangeError::*;
3use log::{debug, warn};
4use serde::de::DeserializeOwned;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::fmt::Debug;
8use std::ops::Range;
9use std::time::Duration;
10use thiserror::Error;
11use tokio::io::{AsyncReadExt, AsyncWriteExt};
12use tokio::net::{TcpListener, TcpStream};
13use tokio::time::timeout;
14
/// An error that can occur during a TCP-based endpoint exchange.
///
/// This is the error type returned by [`Exchanger::await_exchange_all`].
#[derive(Debug, Error)]
pub enum ExchangeError {
    /// A rank referenced during the exchange is not present in the [`NetworkConfig`].
    ///
    /// Produced both when looking up this node's own rank and when looking up a peer
    /// rank before connecting to it.
    #[error("Rank {rank} not in network")]
    InvalidRank { rank: usize },
    /// A message could not be serialized or deserialized.
    ///
    /// Exchange messages are encoded with `serde_json`, hence the wrapped error type.
    #[error("Error serializing/deserializing data ({0})")]
    SerdeError(#[from] serde_json::Error),
    /// An underlying TCP I/O operation failed.
    ///
    /// Also produced (via `?`) when the Tokio runtime used by the blocking entry point
    /// cannot be built.
    #[error("Error during IO operation ({0})")]
    IoError(#[from] std::io::Error),
    /// The serialized message exceeds `u32::MAX` bytes and cannot be framed
    /// with the 4-byte length prefix used by the wire protocol. The field
    /// contains the actual encoded size in bytes.
    #[error("Encoded message size {0} exceeds u32::MAX and cannot be framed")]
    MessageTooLarge(usize),
    /// The exchange did not complete within [`ExchangeConfig::exchange_timeout`].
    #[error("Exchange timed out")]
    Timeout,
}
36
/// Configuration for a TCP exchange operation.
///
/// The [`Default`] configuration allows 60 seconds for the whole exchange and retries
/// connections to peers that are not listening yet once per second.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExchangeConfig {
    /// Maximum time to wait for the entire exchange to complete.
    pub exchange_timeout: Duration,
    /// Delay between connection retries when a remote peer is not yet listening.
    pub retry_delay: Duration,
}

impl Default for ExchangeConfig {
    fn default() -> Self {
        Self {
            exchange_timeout: Duration::from_secs(60),
            retry_delay: Duration::from_millis(1000),
        }
    }
}
53
/// TCP-based all-to-all data exchange between nodes in a network.
///
/// Used during setup to exchange RDMA endpoint information (e.g. [`QueuePairEndpoint`](crate::ibverbs::queue_pair::builder::QueuePairEndpoint))
/// between peers over TCP before RDMA channels are established.
///
/// The type holds no state; it only groups associated functions, with
/// [`Exchanger::await_exchange_all`] as the public entry point. Each node accepts
/// connections from all lower ranks on its own configured hostname/port, then connects
/// to every higher rank. For any pair of nodes, the higher rank listens and the lower
/// rank dials.
pub struct Exchanger {}
59
/// Message sent in each direction of a pairwise exchange.
///
/// On the wire it is JSON-encoded and preceded by a 4-byte big-endian length prefix
/// (see `Exchanger::write_stream` / `Exchanger::read_stream`).
#[derive(Debug, Deserialize, Serialize)]
struct ExchangeMessage<T> {
    // Rank of the sending node; the receiver uses it to place `data` in the result.
    rank: usize,
    // The payload being exchanged.
    data: T,
}
65
66impl Exchanger {
67    /// Exchanges `data` with all other nodes in the network, blocking until complete or timeout.
68    ///
69    /// Returns a `Vec<T>` indexed by rank, where the entry at this node's rank is a clone
70    /// of the local `data` and all other entries come from the corresponding remote nodes.
71    pub fn await_exchange_all<T: Serialize + DeserializeOwned + Clone>(
72        rank: usize,
73        network: &NetworkConfig,
74        data: &T,
75        config: &ExchangeConfig,
76    ) -> Result<Vec<T>, ExchangeError> {
77        tokio::runtime::Builder::new_current_thread()
78            .enable_all()
79            .build()?
80            .block_on(async move {
81                timeout(
82                    config.exchange_timeout,
83                    Self::exchange_all(rank, network, data, config),
84                )
85                .await
86                .unwrap_or(Err(Timeout))
87            })
88    }
89
90    async fn exchange_all<T: Serialize + DeserializeOwned + Clone>(
91        rank: usize,
92        network: &NetworkConfig,
93        data: &T,
94        config: &ExchangeConfig,
95    ) -> Result<Vec<T>, ExchangeError> {
96        let self_node = network.get(rank).ok_or(InvalidRank { rank })?;
97        let lower_ranks = 0..self_node.rankid;
98        let greater_ranks = (self_node.rankid + 1)..(network.world_size());
99
100        debug!(
101            "Exchanging from {}:\n\tlower nodes -> {lower_ranks:?}\n\thigher nodes -> {greater_ranks:?}",
102            self_node.rankid,
103        );
104
105        // Exchange server to lower nodes
106        debug!("Serving exchange...");
107        let lower_nodes_data = Self::exchange_all_serve(data, self_node, lower_ranks).await?;
108        debug!("Done serving");
109
110        // Exchange connect to greater nodes
111        debug!("Connecting exchange...");
112        let greater_nodes_data =
113            Self::exchange_all_connect(data, self_node, greater_ranks, network, config).await?;
114        debug!("Done connecting");
115
116        Ok(lower_nodes_data
117            .into_iter()
118            .chain(std::iter::once(data.to_owned()))
119            .chain(greater_nodes_data)
120            .collect())
121    }
122
123    async fn exchange_all_serve<T: Serialize + DeserializeOwned>(
124        data: &T,
125        self_node: &NodeConfig,
126        remote_ranks: Range<usize>,
127    ) -> Result<Vec<T>, ExchangeError> {
128        let server = TcpListener::bind((self_node.hostname.as_str(), self_node.port)).await?;
129        let mut received = HashMap::new();
130
131        while received.len() < remote_ranks.len() {
132            let (mut stream, _) = server.accept().await?;
133            Self::exchange_serve(
134                data,
135                self_node.rankid,
136                remote_ranks.clone(),
137                &mut stream,
138                &mut received,
139            )
140            .await?;
141        }
142
143        // Iterating on a map directly is O(capacity) so iterate with indices instead
144        Ok(remote_ranks
145            .map(|rank| {
146                received
147                    .remove(&rank)
148                    .expect("rank should have been inserted by the exchange loop above")
149            })
150            .collect())
151    }
152
153    async fn exchange_all_connect<T: Serialize + DeserializeOwned>(
154        data: &T,
155        self_node: &NodeConfig,
156        remote_ranks: Range<usize>,
157        network: &NetworkConfig,
158        config: &ExchangeConfig,
159    ) -> Result<Vec<T>, ExchangeError> {
160        let mut received = HashMap::new();
161
162        for remote_rank in remote_ranks.clone() {
163            let remote_node = network
164                .get(remote_rank)
165                .ok_or(InvalidRank { rank: remote_rank })?;
166
167            let mut stream;
168            loop {
169                if let Ok(s) =
170                    TcpStream::connect((remote_node.hostname.as_str(), remote_node.port)).await
171                {
172                    stream = s;
173                    break;
174                }
175                tokio::time::sleep(config.retry_delay).await;
176            }
177
178            Self::exchange_connect(
179                data,
180                self_node.rankid,
181                remote_ranks.clone(),
182                &mut stream,
183                &mut received,
184            )
185            .await?;
186        }
187
188        // Iterating on a map directly is O(capacity) so iterate with indices instead
189        Ok(remote_ranks
190            .map(|rank| {
191                received
192                    .remove(&rank)
193                    .expect("rank should have been inserted by the exchange loop above")
194            })
195            .collect())
196    }
197
198    async fn exchange_serve<T: Serialize + DeserializeOwned>(
199        data: &T,
200        self_rank: usize,
201        remote_ranks: Range<usize>,
202        stream: &mut TcpStream,
203        received: &mut HashMap<usize, T>,
204    ) -> Result<(), ExchangeError> {
205        // Send self data
206        Self::write_stream(self_rank, data, stream).await?;
207
208        // Read incoming data
209        let incoming_data = Self::read_stream::<T>(stream).await?;
210        Self::insert_if_valid(incoming_data, received, remote_ranks.clone());
211
212        Ok(())
213    }
214
215    async fn exchange_connect<T: Serialize + DeserializeOwned>(
216        data: &T,
217        self_rank: usize,
218        remote_ranks: Range<usize>,
219        stream: &mut TcpStream,
220        received: &mut HashMap<usize, T>,
221    ) -> Result<(), ExchangeError> {
222        // Read incoming data
223        let incoming_data = Self::read_stream::<T>(stream).await?;
224        Self::insert_if_valid(incoming_data, received, remote_ranks.clone());
225
226        // Send self data
227        Self::write_stream(self_rank, data, stream).await?;
228
229        Ok(())
230    }
231
232    fn insert_if_valid<T: Serialize + DeserializeOwned>(
233        incoming_data: ExchangeMessage<T>,
234        received: &mut HashMap<usize, T>,
235        valid_range: Range<usize>,
236    ) -> bool {
237        // Validate rank is in range
238        if valid_range.contains(&incoming_data.rank) {
239            // Insert incoming data to map
240            let out = received.insert(incoming_data.rank, incoming_data.data);
241            if out.is_some() {
242                // Warn if config already received for the specified rank id
243                warn!("Duplicate exchange from {}", incoming_data.rank,);
244            }
245            debug!("Exchange progress -> {}", received.len());
246            true
247        } else {
248            // Warn if exchange from invalid rank received
249            warn!("Invalid rank incoming exchange {}", incoming_data.rank);
250            false
251        }
252    }
253
254    async fn read_stream<T: DeserializeOwned>(
255        stream: &mut (impl AsyncReadExt + Unpin),
256    ) -> Result<ExchangeMessage<T>, ExchangeError> {
257        let mut size_buf = [0u8; size_of::<u32>()];
258        stream.read_exact(&mut size_buf[..]).await?;
259        let msg_size = u32::from_be_bytes(size_buf);
260
261        let mut msg_buf = vec![0u8; msg_size as usize];
262        stream.read_exact(&mut msg_buf[..]).await?;
263        Ok(serde_json::from_slice(&msg_buf)?)
264    }
265
266    async fn write_stream<T: Serialize>(
267        rank: usize,
268        data: &T,
269        stream: &mut (impl AsyncWriteExt + Unpin),
270    ) -> Result<(), ExchangeError> {
271        let encoded = serde_json::to_vec(&ExchangeMessage { rank, data })?;
272        let len = u32::try_from(encoded.len()).map_err(|_| MessageTooLarge(encoded.len()))?;
273        stream.write_all(len.to_be_bytes().as_ref()).await?;
274        stream.write_all(encoded.as_slice()).await?;
275        Ok(())
276    }
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::time::Duration;

    fn run_async<F: std::future::Future>(f: F) -> F::Output {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap()
            .block_on(f)
    }

    #[test]
    fn write_read_round_trip_string() {
        run_async(async {
            let (mut writer, mut reader) = tokio::io::duplex(1024);
            Exchanger::write_stream(7, &"test data".to_string(), &mut writer)
                .await
                .unwrap();
            drop(writer);

            let msg: ExchangeMessage<String> = Exchanger::read_stream(&mut reader).await.unwrap();
            assert_eq!(msg.rank, 7);
            assert_eq!(msg.data, "test data");
        });
    }

    #[test]
    fn write_read_round_trip_struct() {
        #[derive(Debug, PartialEq, Serialize, Deserialize)]
        struct Endpoint {
            lid: u16,
            qpn: u32,
            psn: u32,
        }

        run_async(async {
            let endpoint = Endpoint {
                lid: 1,
                qpn: 0x1234,
                psn: 0xABCD,
            };

            let (mut writer, mut reader) = tokio::io::duplex(1024);
            Exchanger::write_stream(3, &endpoint, &mut writer)
                .await
                .unwrap();
            drop(writer);

            let msg: ExchangeMessage<Endpoint> = Exchanger::read_stream(&mut reader).await.unwrap();
            assert_eq!(msg.rank, 3);
            assert_eq!(msg.data, endpoint);
        });
    }

    #[test]
    fn write_read_round_trip_vec() {
        run_async(async {
            let data = vec![1u64, 2, 3, 4, 5];

            let (mut writer, mut reader) = tokio::io::duplex(1024);
            Exchanger::write_stream(0, &data, &mut writer)
                .await
                .unwrap();
            drop(writer);

            let msg: ExchangeMessage<Vec<u64>> = Exchanger::read_stream(&mut reader).await.unwrap();
            assert_eq!(msg.rank, 0);
            assert_eq!(msg.data, data);
        });
    }

    #[test]
    fn read_stream_rejects_truncated_length() {
        run_async(async {
            let data = [0u8, 1];
            let mut reader = &data[..];
            assert!(Exchanger::read_stream::<String>(&mut reader).await.is_err());
        });
    }

    #[test]
    fn read_stream_rejects_truncated_body() {
        run_async(async {
            let mut data = Vec::new();
            data.extend_from_slice(&100u32.to_be_bytes());
            data.extend_from_slice(&[0u8, 1]);
            let mut reader = &data[..];
            assert!(Exchanger::read_stream::<String>(&mut reader).await.is_err());
        });
    }

    #[test]
    fn insert_if_valid_accepts_valid_rank() {
        let mut received = HashMap::new();
        let msg = ExchangeMessage {
            rank: 2,
            data: "hello".to_string(),
        };
        assert!(Exchanger::insert_if_valid(msg, &mut received, 0..5));
        assert_eq!(received.get(&2).unwrap(), "hello");
    }

    #[test]
    fn insert_if_valid_rejects_out_of_range() {
        let mut received = HashMap::new();
        let msg = ExchangeMessage {
            rank: 10,
            data: "hello".to_string(),
        };
        assert!(!Exchanger::insert_if_valid(msg, &mut received, 0..5));
        assert!(received.is_empty());
    }

    #[test]
    fn insert_if_valid_overwrites_duplicate() {
        let mut received = HashMap::new();
        received.insert(2, "first".to_string());
        let msg = ExchangeMessage {
            rank: 2,
            data: "second".to_string(),
        };
        assert!(Exchanger::insert_if_valid(msg, &mut received, 0..5));
        assert_eq!(received.get(&2).unwrap(), "second");
    }

    fn make_network(ports: &[u16]) -> NetworkConfig {
        let mut builder = NetworkConfig::builder();
        for (i, &port) in ports.iter().enumerate() {
            builder = builder.add_node(
                NodeConfig::builder()
                    .hostname("127.0.0.1")
                    .port(port)
                    .ibdev("test0")
                    .rankid(i)
                    .build(),
            );
        }
        builder.build().unwrap()
    }

    /// Exchange settings for loopback tests.
    ///
    /// Local peers come up within milliseconds, so connection attempts are retried quickly
    /// instead of waiting out the default one-second delay.
    fn test_config() -> ExchangeConfig {
        ExchangeConfig {
            retry_delay: Duration::from_millis(10),
            ..ExchangeConfig::default()
        }
    }

    /// Runs a full exchange with one thread per rank on the given loopback ports.
    ///
    /// Checks that every rank ends up with `from_<rank>` for all ranks, in rank order.
    fn assert_full_exchange(ports: &[u16]) {
        let network = make_network(ports);

        let handles: Vec<_> = (0..ports.len())
            .map(|rank| {
                let net = network.clone();
                std::thread::spawn(move || {
                    Exchanger::await_exchange_all(
                        rank,
                        &net,
                        &format!("from_{rank}"),
                        &test_config(),
                    )
                })
            })
            .collect();

        let expected: Vec<String> = (0..ports.len()).map(|i| format!("from_{i}")).collect();
        for handle in handles {
            assert_eq!(handle.join().unwrap().unwrap(), expected);
        }
    }

    #[test]
    fn two_node_exchange() {
        assert_full_exchange(&[41100, 41101]);
    }

    #[test]
    fn three_node_exchange() {
        assert_full_exchange(&[41200, 41201, 41202]);
    }
}