1use crate::network::config::{NetworkConfig, NodeConfig};
2use ExchangeError::*;
3use log::{debug, warn};
4use serde::de::DeserializeOwned;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::fmt::Debug;
8use std::ops::Range;
9use std::time::Duration;
10use thiserror::Error;
11use tokio::io::{AsyncReadExt, AsyncWriteExt};
12use tokio::net::{TcpListener, TcpStream};
13use tokio::time::timeout;
14
/// Failures that can surface while exchanging data between ranks.
#[derive(Debug, Error)]
pub enum ExchangeError {
    /// A rank was looked up in the network configuration and no node was found for it.
    #[error("Rank {rank} not in network")]
    InvalidRank { rank: usize },
    /// Encoding the outgoing message to JSON, or decoding an incoming one, failed.
    #[error("Error serializing/deserializing data ({0})")]
    SerdeError(#[from] serde_json::Error),
    /// An I/O operation reported an error.
    #[error("Error during IO operation ({0})")]
    IoError(#[from] std::io::Error),
    /// The encoded message is longer than the `u32` length prefix can describe; the payload
    /// carries the encoded size in bytes.
    #[error("Encoded message size {0} exceeds u32::MAX and cannot be framed")]
    MessageTooLarge(usize),
    /// The exchange did not complete within the configured timeout.
    #[error("Exchange timed out")]
    Timeout,
}
36
/// Tunables for one all-to-all exchange.
///
/// The type derives `Debug`, `Clone`, `PartialEq` and `Eq` so that callers can log it, keep a
/// copy per worker and compare configurations in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExchangeConfig {
    /// Upper bound on the duration of the whole exchange.
    pub exchange_timeout: Duration,
    /// Pause between two connection attempts to the same peer.
    pub retry_delay: Duration,
}

impl Default for ExchangeConfig {
    /// Returns a configuration with a 60 second overall timeout and a 1 second retry delay.
    fn default() -> Self {
        Self {
            exchange_timeout: Duration::from_secs(60),
            retry_delay: Duration::from_millis(1000),
        }
    }
}
53
/// Stateless entry point for the exchange routines: the struct has no fields and its
/// functionality is exposed through associated functions.
pub struct Exchanger {}
59
/// Message sent over the wire during an exchange: the sender's rank plus its payload.
///
/// `write_stream` encodes it as JSON and `read_stream` decodes it again.
#[derive(Debug, Deserialize, Serialize)]
struct ExchangeMessage<T> {
    /// Rank the sender reports for itself.
    rank: usize,
    /// Payload contributed by the sender.
    data: T,
}
65
66impl Exchanger {
67 pub fn await_exchange_all<T: Serialize + DeserializeOwned + Clone>(
72 rank: usize,
73 network: &NetworkConfig,
74 data: &T,
75 config: &ExchangeConfig,
76 ) -> Result<Vec<T>, ExchangeError> {
77 tokio::runtime::Builder::new_current_thread()
78 .enable_all()
79 .build()?
80 .block_on(async move {
81 timeout(
82 config.exchange_timeout,
83 Self::exchange_all(rank, network, data, config),
84 )
85 .await
86 .unwrap_or(Err(Timeout))
87 })
88 }
89
90 async fn exchange_all<T: Serialize + DeserializeOwned + Clone>(
91 rank: usize,
92 network: &NetworkConfig,
93 data: &T,
94 config: &ExchangeConfig,
95 ) -> Result<Vec<T>, ExchangeError> {
96 let self_node = network.get(rank).ok_or(InvalidRank { rank })?;
97 let lower_ranks = 0..self_node.rankid;
98 let greater_ranks = (self_node.rankid + 1)..(network.world_size());
99
100 debug!(
101 "Exchanging from {}:\n\tlower nodes -> {lower_ranks:?}\n\thigher nodes -> {greater_ranks:?}",
102 self_node.rankid,
103 );
104
105 debug!("Serving exchange...");
107 let lower_nodes_data = Self::exchange_all_serve(data, self_node, lower_ranks).await?;
108 debug!("Done serving");
109
110 debug!("Connecting exchange...");
112 let greater_nodes_data =
113 Self::exchange_all_connect(data, self_node, greater_ranks, network, config).await?;
114 debug!("Done connecting");
115
116 Ok(lower_nodes_data
117 .into_iter()
118 .chain(std::iter::once(data.to_owned()))
119 .chain(greater_nodes_data)
120 .collect())
121 }
122
123 async fn exchange_all_serve<T: Serialize + DeserializeOwned>(
124 data: &T,
125 self_node: &NodeConfig,
126 remote_ranks: Range<usize>,
127 ) -> Result<Vec<T>, ExchangeError> {
128 let server = TcpListener::bind((self_node.hostname.as_str(), self_node.port)).await?;
129 let mut received = HashMap::new();
130
131 while received.len() < remote_ranks.len() {
132 let (mut stream, _) = server.accept().await?;
133 Self::exchange_serve(
134 data,
135 self_node.rankid,
136 remote_ranks.clone(),
137 &mut stream,
138 &mut received,
139 )
140 .await?;
141 }
142
143 Ok(remote_ranks
145 .map(|rank| {
146 received
147 .remove(&rank)
148 .expect("rank should have been inserted by the exchange loop above")
149 })
150 .collect())
151 }
152
153 async fn exchange_all_connect<T: Serialize + DeserializeOwned>(
154 data: &T,
155 self_node: &NodeConfig,
156 remote_ranks: Range<usize>,
157 network: &NetworkConfig,
158 config: &ExchangeConfig,
159 ) -> Result<Vec<T>, ExchangeError> {
160 let mut received = HashMap::new();
161
162 for remote_rank in remote_ranks.clone() {
163 let remote_node = network
164 .get(remote_rank)
165 .ok_or(InvalidRank { rank: remote_rank })?;
166
167 let mut stream;
168 loop {
169 if let Ok(s) =
170 TcpStream::connect((remote_node.hostname.as_str(), remote_node.port)).await
171 {
172 stream = s;
173 break;
174 }
175 tokio::time::sleep(config.retry_delay).await;
176 }
177
178 Self::exchange_connect(
179 data,
180 self_node.rankid,
181 remote_ranks.clone(),
182 &mut stream,
183 &mut received,
184 )
185 .await?;
186 }
187
188 Ok(remote_ranks
190 .map(|rank| {
191 received
192 .remove(&rank)
193 .expect("rank should have been inserted by the exchange loop above")
194 })
195 .collect())
196 }
197
198 async fn exchange_serve<T: Serialize + DeserializeOwned>(
199 data: &T,
200 self_rank: usize,
201 remote_ranks: Range<usize>,
202 stream: &mut TcpStream,
203 received: &mut HashMap<usize, T>,
204 ) -> Result<(), ExchangeError> {
205 Self::write_stream(self_rank, data, stream).await?;
207
208 let incoming_data = Self::read_stream::<T>(stream).await?;
210 Self::insert_if_valid(incoming_data, received, remote_ranks.clone());
211
212 Ok(())
213 }
214
215 async fn exchange_connect<T: Serialize + DeserializeOwned>(
216 data: &T,
217 self_rank: usize,
218 remote_ranks: Range<usize>,
219 stream: &mut TcpStream,
220 received: &mut HashMap<usize, T>,
221 ) -> Result<(), ExchangeError> {
222 let incoming_data = Self::read_stream::<T>(stream).await?;
224 Self::insert_if_valid(incoming_data, received, remote_ranks.clone());
225
226 Self::write_stream(self_rank, data, stream).await?;
228
229 Ok(())
230 }
231
232 fn insert_if_valid<T: Serialize + DeserializeOwned>(
233 incoming_data: ExchangeMessage<T>,
234 received: &mut HashMap<usize, T>,
235 valid_range: Range<usize>,
236 ) -> bool {
237 if valid_range.contains(&incoming_data.rank) {
239 let out = received.insert(incoming_data.rank, incoming_data.data);
241 if out.is_some() {
242 warn!("Duplicate exchange from {}", incoming_data.rank,);
244 }
245 debug!("Exchange progress -> {}", received.len());
246 true
247 } else {
248 warn!("Invalid rank incoming exchange {}", incoming_data.rank);
250 false
251 }
252 }
253
254 async fn read_stream<T: DeserializeOwned>(
255 stream: &mut (impl AsyncReadExt + Unpin),
256 ) -> Result<ExchangeMessage<T>, ExchangeError> {
257 let mut size_buf = [0u8; size_of::<u32>()];
258 stream.read_exact(&mut size_buf[..]).await?;
259 let msg_size = u32::from_be_bytes(size_buf);
260
261 let mut msg_buf = vec![0u8; msg_size as usize];
262 stream.read_exact(&mut msg_buf[..]).await?;
263 Ok(serde_json::from_slice(&msg_buf)?)
264 }
265
266 async fn write_stream<T: Serialize>(
267 rank: usize,
268 data: &T,
269 stream: &mut (impl AsyncWriteExt + Unpin),
270 ) -> Result<(), ExchangeError> {
271 let encoded = serde_json::to_vec(&ExchangeMessage { rank, data })?;
272 let len = u32::try_from(encoded.len()).map_err(|_| MessageTooLarge(encoded.len()))?;
273 stream.write_all(len.to_be_bytes().as_ref()).await?;
274 stream.write_all(encoded.as_slice()).await?;
275 Ok(())
276 }
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Drives `f` to completion on a throwaway current-thread runtime.
    fn run_async<F: std::future::Future>(f: F) -> F::Output {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        runtime.block_on(f)
    }

    /// Frames `data` into an in-memory pipe, closes the writing end and decodes the frame.
    async fn round_trip<T: Serialize + DeserializeOwned>(
        rank: usize,
        data: &T,
    ) -> ExchangeMessage<T> {
        let (mut writer, mut reader) = tokio::io::duplex(1024);
        Exchanger::write_stream(rank, data, &mut writer)
            .await
            .unwrap();
        drop(writer);

        Exchanger::read_stream(&mut reader).await.unwrap()
    }

    /// Reports whether decoding a `String` message from `bytes` fails.
    fn read_fails(bytes: &[u8]) -> bool {
        run_async(async {
            let mut reader = bytes;
            Exchanger::read_stream::<String>(&mut reader).await.is_err()
        })
    }

    #[test]
    fn write_read_round_trip_string() {
        let payload = "test data".to_string();
        let msg = run_async(round_trip(7, &payload));

        assert_eq!(msg.rank, 7);
        assert_eq!(msg.data, "test data");
    }

    #[test]
    fn write_read_round_trip_struct() {
        #[derive(Debug, PartialEq, Serialize, Deserialize)]
        struct Endpoint {
            lid: u16,
            qpn: u32,
            psn: u32,
        }

        let endpoint = Endpoint {
            lid: 1,
            qpn: 0x1234,
            psn: 0xABCD,
        };
        let msg = run_async(round_trip(3, &endpoint));

        assert_eq!(msg.rank, 3);
        assert_eq!(msg.data, endpoint);
    }

    #[test]
    fn write_read_round_trip_vec() {
        let data = vec![1u64, 2, 3, 4, 5];
        let msg = run_async(round_trip(0, &data));

        assert_eq!(msg.rank, 0);
        assert_eq!(msg.data, data);
    }

    #[test]
    fn read_stream_rejects_truncated_length() {
        // Only two of the four length bytes are available.
        assert!(read_fails(&[0u8, 1]));
    }

    #[test]
    fn read_stream_rejects_truncated_body() {
        // The prefix announces 100 bytes but only two follow.
        let mut data = Vec::new();
        data.extend_from_slice(&100u32.to_be_bytes());
        data.extend_from_slice(&[0u8, 1]);

        assert!(read_fails(&data));
    }

    #[test]
    fn insert_if_valid_accepts_valid_rank() {
        let mut received = HashMap::new();
        let msg = ExchangeMessage {
            rank: 2,
            data: "hello".to_string(),
        };

        let stored = Exchanger::insert_if_valid(msg, &mut received, 0..5);

        assert!(stored);
        assert_eq!(received.get(&2).map(String::as_str), Some("hello"));
    }

    #[test]
    fn insert_if_valid_rejects_out_of_range() {
        let mut received = HashMap::new();
        let msg = ExchangeMessage {
            rank: 10,
            data: "hello".to_string(),
        };

        let stored = Exchanger::insert_if_valid(msg, &mut received, 0..5);

        assert!(!stored);
        assert!(received.is_empty());
    }

    #[test]
    fn insert_if_valid_overwrites_duplicate() {
        let mut received = HashMap::from([(2, "first".to_string())]);
        let msg = ExchangeMessage {
            rank: 2,
            data: "second".to_string(),
        };

        let stored = Exchanger::insert_if_valid(msg, &mut received, 0..5);

        assert!(stored);
        assert_eq!(received.get(&2).map(String::as_str), Some("second"));
    }

    /// Builds a loopback network with one node per port; the rank id is the port's index.
    fn make_network(ports: &[u16]) -> NetworkConfig {
        ports
            .iter()
            .enumerate()
            .fold(NetworkConfig::builder(), |builder, (rankid, &port)| {
                builder.add_node(
                    NodeConfig::builder()
                        .hostname("127.0.0.1")
                        .port(port)
                        .ibdev("test0")
                        .rankid(rankid)
                        .build(),
                )
            })
            .build()
            .unwrap()
    }

    /// Runs one exchange per port on its own thread and checks that every rank ends up with
    /// the payloads of all ranks, in rank order.
    fn assert_full_exchange(ports: &[u16]) {
        let network = make_network(ports);
        let world = ports.len();

        let workers: Vec<_> = (0..world)
            .map(|rank| {
                let net = network.clone();
                std::thread::spawn(move || {
                    Exchanger::await_exchange_all(
                        rank,
                        &net,
                        &format!("from_{rank}"),
                        &ExchangeConfig::default(),
                    )
                })
            })
            .collect();

        let expected: Vec<String> = (0..world).map(|i| format!("from_{i}")).collect();
        for worker in workers {
            assert_eq!(worker.join().unwrap().unwrap(), expected);
        }
    }

    #[test]
    fn two_node_exchange() {
        assert_full_exchange(&[41100, 41101]);
    }

    #[test]
    fn three_node_exchange() {
        assert_full_exchange(&[41200, 41201, 41202]);
    }
}