1use std::io::{Read, Write};
22use std::net::TcpStream;
23
/// One participant in a ring AllReduce over TCP.
///
/// Workers form a logical ring: each worker writes to its right-hand
/// neighbour over `send_stream` and reads from its left-hand neighbour over
/// `recv_stream`.
#[derive(Debug)]
pub struct RingAllReduceWorker {
    /// Zero-based position of this worker in the ring (`< world_size`).
    rank: usize,
    /// Total number of workers in the ring (`>= 2`).
    world_size: usize,
    /// Connection to the next worker, `(rank + 1) % world_size`.
    send_stream: TcpStream,
    /// Connection from the previous worker, `(rank + world_size - 1) % world_size`.
    recv_stream: TcpStream,
}
38
39impl RingAllReduceWorker {
40 pub fn new(
48 rank: usize,
49 world_size: usize,
50 send_stream: TcpStream,
51 recv_stream: TcpStream,
52 ) -> Self {
53 assert!(world_size >= 2, "ring AllReduce requires >= 2 workers");
54 assert!(rank < world_size, "rank must be < world_size");
55 Self { rank, world_size, send_stream, recv_stream }
56 }
57
58 pub fn allreduce(&mut self, data: &mut [f32]) -> Result<(), String> {
72 contract_pre_gradient_allreduce!();
73 let n = self.world_size;
74 let d = data.len();
75
76 let chunk_size = d / n;
78 let remainder = d % n;
79 let chunks: Vec<(usize, usize)> = (0..n)
80 .map(|i| {
81 let start = i * chunk_size + i.min(remainder);
82 let len = chunk_size + usize::from(i < remainder);
83 (start, len)
84 })
85 .collect();
86
87 let max_chunk_len = chunks.iter().map(|(_, len)| *len).max().unwrap_or(0);
89 let mut send_buf = vec![0u8; max_chunk_len * 4];
90 let mut recv_buf = vec![0u8; max_chunk_len * 4];
91
92 for round in 0..(n - 1) {
95 let send_chunk_idx = (self.rank + n - round) % n;
97 let (send_start, send_len) = chunks[send_chunk_idx];
98
99 let recv_chunk_idx = (self.rank + n - round - 1) % n;
101 let (recv_start, recv_len) = chunks[recv_chunk_idx];
102
103 f32_slice_to_bytes(&data[send_start..send_start + send_len], &mut send_buf);
105
106 self.send_stream
109 .write_all(&send_buf[..send_len * 4])
110 .map_err(|e| format!("ring send error (round {round}): {e}"))?;
111 self.recv_stream
112 .read_exact(&mut recv_buf[..recv_len * 4])
113 .map_err(|e| format!("ring recv error (round {round}): {e}"))?;
114
115 for i in 0..recv_len {
117 let received =
118 f32::from_le_bytes(recv_buf[i * 4..(i + 1) * 4].try_into().expect("4 bytes"));
119 data[recv_start + i] += received;
120 }
121 }
122
123 for round in 0..(n - 1) {
126 let send_chunk_idx = (self.rank + n - round + 1) % n;
127 let (send_start, send_len) = chunks[send_chunk_idx];
128
129 let recv_chunk_idx = (self.rank + n - round) % n;
130 let (recv_start, recv_len) = chunks[recv_chunk_idx];
131
132 f32_slice_to_bytes(&data[send_start..send_start + send_len], &mut send_buf);
134
135 self.send_stream
136 .write_all(&send_buf[..send_len * 4])
137 .map_err(|e| format!("ring allgather send error (round {round}): {e}"))?;
138 self.recv_stream
139 .read_exact(&mut recv_buf[..recv_len * 4])
140 .map_err(|e| format!("ring allgather recv error (round {round}): {e}"))?;
141
142 for i in 0..recv_len {
144 data[recv_start + i] =
145 f32::from_le_bytes(recv_buf[i * 4..(i + 1) * 4].try_into().expect("4 bytes"));
146 }
147 }
148
149 let inv_n = 1.0 / n as f32;
151 for x in data.iter_mut() {
152 *x *= inv_n;
153 }
154
155 Ok(())
156 }
157}
158
/// Encodes `src` as consecutive little-endian `f32` values into the front of `dst`.
///
/// Writes exactly `src.len() * 4` bytes; any bytes of `dst` beyond that are
/// left untouched. Panics if `dst` is shorter than `src.len() * 4`.
fn f32_slice_to_bytes(src: &[f32], dst: &mut [u8]) {
    let target = &mut dst[..src.len() * 4];
    for (slot, value) in target.chunks_exact_mut(4).zip(src) {
        slot.copy_from_slice(&value.to_le_bytes());
    }
}
165
/// Replaces `data` in place with its element-wise mean with a single peer.
///
/// Sends the local buffer over `send_stream` and receives the peer's buffer,
/// which must have the same length, from `recv_stream`. Each element becomes
/// `(local + remote) * 0.5`. Both peers must call this concurrently.
///
/// The send runs on a scoped thread concurrently with the receive. If both
/// peers wrote their whole buffer before reading, they would deadlock once the
/// buffer outgrew what the kernel socket buffers can absorb.
///
/// # Errors
///
/// Returns `"pair send error: …"` or `"pair recv error: …"` if the socket
/// write or read fails. `data` is left unmodified in that case.
pub fn allreduce_pair(
    data: &mut [f32],
    send_stream: &mut TcpStream,
    recv_stream: &mut TcpStream,
) -> Result<(), String> {
    let send_buf: Vec<u8> = data.iter().flat_map(|v| v.to_le_bytes()).collect();
    let mut recv_buf = vec![0u8; send_buf.len()];

    let (sent, received) = std::thread::scope(|scope| {
        let writer = scope.spawn(|| send_stream.write_all(&send_buf));
        let received = recv_stream.read_exact(&mut recv_buf);
        // Re-raise a writer panic on this thread instead of masking it.
        let sent = writer
            .join()
            .unwrap_or_else(|panic| std::panic::resume_unwind(panic));
        (sent, received)
    });
    sent.map_err(|e| format!("pair send error: {e}"))?;
    received.map_err(|e| format!("pair recv error: {e}"))?;

    for (local, bytes) in data.iter_mut().zip(recv_buf.chunks_exact(4)) {
        let remote = f32::from_le_bytes(bytes.try_into().expect("4 bytes"));
        *local = (*local + remote) * 0.5;
    }

    Ok(())
}
195
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::TcpListener;
    use std::thread;

    /// Builds an `n`-worker ring on loopback.
    ///
    /// Worker `w` sends to worker `(w + 1) % n` and receives from worker
    /// `(w + n - 1) % n`. Workers are returned in rank order.
    fn setup_ring(n: usize) -> Vec<RingAllReduceWorker> {
        // One listener per worker; worker w's listener accepts its left neighbour.
        let listeners: Vec<TcpListener> =
            (0..n).map(|_| TcpListener::bind("127.0.0.1:0").expect("bind")).collect();
        let addrs: Vec<_> = listeners.iter().map(|l| l.local_addr().expect("addr")).collect();

        let accept_handles: Vec<_> = listeners
            .into_iter()
            .map(|listener| thread::spawn(move || listener.accept().expect("accept").0))
            .collect();

        // Every worker connects to its right neighbour's listener.
        let send_streams: Vec<TcpStream> = (0..n)
            .map(|w| {
                let stream = TcpStream::connect(addrs[(w + 1) % n]).expect("connect");
                stream.set_nodelay(true).ok();
                stream
            })
            .collect();

        let recv_streams = accept_handles.into_iter().map(|handle| {
            let stream = handle.join().expect("accept thread");
            stream.set_nodelay(true).ok();
            stream
        });

        send_streams
            .into_iter()
            .zip(recv_streams)
            .enumerate()
            .map(|(w, (send, recv))| RingAllReduceWorker::new(w, n, send, recv))
            .collect()
    }

    /// Runs `allreduce` on every worker concurrently, one thread per worker.
    /// Worker `w` reduces `inputs[w]`; the results are returned in rank order.
    fn run_ring(inputs: Vec<Vec<f32>>) -> Vec<Vec<f32>> {
        let handles: Vec<_> = setup_ring(inputs.len())
            .into_iter()
            .zip(inputs)
            .map(|(mut worker, mut data)| {
                thread::spawn(move || {
                    worker.allreduce(&mut data).expect("allreduce");
                    data
                })
            })
            .collect();
        handles.into_iter().map(|h| h.join().expect("join worker")).collect()
    }

    /// Asserts that `actual` matches `expected` element-wise within `tol`.
    fn assert_close(label: &str, actual: &[f32], expected: &[f32], tol: f32) {
        assert_eq!(actual.len(), expected.len(), "{label}: length mismatch");
        for (i, (&v, &e)) in actual.iter().zip(expected).enumerate() {
            assert!((v - e).abs() < tol, "{label}[{i}]: {v} != {e}");
        }
    }

    #[test]
    fn test_ring_allreduce_2_workers_identical() {
        let results = run_ring(vec![vec![1.0, 2.0, 3.0], vec![1.0, 2.0, 3.0]]);
        for (rank, r) in results.iter().enumerate() {
            assert_close(&format!("w{rank}"), r, &[1.0, 2.0, 3.0], 1e-6);
        }
    }

    #[test]
    fn test_ring_allreduce_2_workers_distinct() {
        let results = run_ring(vec![vec![2.0, 4.0, 6.0], vec![8.0, 6.0, 4.0]]);
        for (rank, r) in results.iter().enumerate() {
            assert_close(&format!("w{rank}"), r, &[5.0, 5.0, 5.0], 1e-6);
        }
    }

    #[test]
    fn test_ring_allreduce_3_workers() {
        let results = run_ring(vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ]);
        let expected = vec![1.0f32 / 3.0; 3];
        for (rank, r) in results.iter().enumerate() {
            assert_close(&format!("w{rank}"), r, &expected, 1e-5);
        }
    }

    #[test]
    fn test_ring_allreduce_non_divisible_length() {
        let results = run_ring(vec![
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
            vec![7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0],
            vec![0.0; 7],
        ]);
        let expected = vec![8.0f32 / 3.0; 7];
        assert_close("w0", &results[0], &expected, 1e-5);
        assert_eq!(results[0], results[1], "w0 == w1");
        assert_eq!(results[0], results[2], "w0 == w2");
    }

    #[test]
    fn test_ring_allreduce_large_vector() {
        let d = 100_000;
        let d0: Vec<f32> = (0..d).map(|i| i as f32).collect();
        let d1: Vec<f32> = (0..d).map(|i| (d - 1 - i) as f32).collect();
        let results = run_ring(vec![d0, d1]);

        let expected = vec![(d as f32 - 1.0) / 2.0; d];
        assert_close("w0", &results[0], &expected, 1e-2);
        assert_eq!(results[0], results[1], "results must be identical");
    }

    #[test]
    fn test_allreduce_pair() {
        // Two one-directional connections: a -> b and b -> a.
        let listener_a = TcpListener::bind("127.0.0.1:0").expect("bind");
        let listener_b = TcpListener::bind("127.0.0.1:0").expect("bind");
        let addr_a = listener_a.local_addr().expect("addr");
        let addr_b = listener_b.local_addr().expect("addr");

        let ha = thread::spawn(move || {
            let send = TcpStream::connect(addr_b).expect("connect to b");
            let (recv, _) = listener_a.accept().expect("accept from b");
            (send, recv)
        });

        let mut send_b = TcpStream::connect(addr_a).expect("connect to a");
        let (mut recv_b, _) = listener_b.accept().expect("accept from a");
        let (mut send_a, mut recv_a) = ha.join().expect("join");

        let mut d_a = vec![10.0f32, 20.0, 30.0];
        let mut d_b = vec![30.0f32, 20.0, 10.0];

        let hb = thread::spawn(move || {
            allreduce_pair(&mut d_b, &mut send_b, &mut recv_b).expect("pair b");
            d_b
        });

        allreduce_pair(&mut d_a, &mut send_a, &mut recv_a).expect("pair a");
        let result_b = hb.join().expect("join");

        assert_eq!(d_a, vec![20.0, 20.0, 20.0]);
        assert_eq!(result_b, vec![20.0, 20.0, 20.0]);
    }
}