use std::io::{Read, Write};
use std::net::TcpStream;
/// One participant in a ring AllReduce over TCP.
///
/// Workers form a directed ring: each one writes to its right-hand neighbour and reads
/// from its left-hand neighbour. The reduction schedule in the `impl` depends on that
/// exact topology.
#[derive(Debug)]
pub struct RingAllReduceWorker {
    /// This worker's position in the ring; always `< world_size`.
    rank: usize,
    /// Number of workers in the ring; always `>= 2`.
    world_size: usize,
    /// Must be connected to the worker with rank `(rank + 1) % world_size`.
    send_stream: TcpStream,
    /// Must be connected to the worker with rank `(rank + world_size - 1) % world_size`.
    recv_stream: TcpStream,
}
impl RingAllReduceWorker {
    /// Creates the worker occupying position `rank` in a ring of `world_size` workers.
    ///
    /// `send_stream` must be connected to the worker with rank `(rank + 1) % world_size`.
    /// `recv_stream` must be connected to the worker with rank
    /// `(rank + world_size - 1) % world_size`. [`allreduce`](Self::allreduce) relies on
    /// this topology.
    ///
    /// # Panics
    /// Panics if `world_size < 2` or `rank >= world_size`.
    pub fn new(
        rank: usize,
        world_size: usize,
        send_stream: TcpStream,
        recv_stream: TcpStream,
    ) -> Self {
        assert!(world_size >= 2, "ring AllReduce requires >= 2 workers");
        assert!(rank < world_size, "rank must be < world_size");
        Self { rank, world_size, send_stream, recv_stream }
    }

    /// Replaces `data`, in place, with the element-wise **mean** of every worker's `data`.
    ///
    /// Despite the name, the result is the mean, not the sum. The ring-reduced sum is
    /// scaled by `1 / world_size` at the end. Every worker in the ring must call this
    /// concurrently with a slice of the same length, because lengths are not exchanged.
    ///
    /// The vector is split into `world_size` contiguous chunks; the first
    /// `len % world_size` chunks are one element longer. The algorithm runs in two phases:
    ///
    /// 1. Reduce-scatter (`world_size - 1` rounds). Afterwards each worker holds one fully
    ///    summed chunk.
    /// 2. Allgather (another `world_size - 1` rounds). This circulates those finished chunks
    ///    so that every worker ends with the complete sum.
    ///
    /// # Errors
    /// Returns `Err` with a message naming the phase (`ring` / `ring allgather`), the
    /// direction (`send` / `recv`) and the round when a stream operation fails. `data` is
    /// then left in an unspecified, partially reduced state.
    pub fn allreduce(&mut self, data: &mut [f32]) -> Result<(), String> {
        let n = self.world_size;
        let chunks = Self::chunk_bounds(data.len(), n);
        let max_chunk_len = chunks.iter().map(|&(_, len)| len).max().unwrap_or(0);
        let mut send_buf = vec![0u8; max_chunk_len * 4];
        let mut recv_buf = vec![0u8; max_chunk_len * 4];

        // Reduce-scatter. In round r this worker forwards chunk (rank - r) and adds the
        // left neighbour's partial sum of chunk (rank - r - 1). After n - 1 rounds it holds
        // the complete sum of chunk (rank + 1) mod n.
        for round in 0..n - 1 {
            let (send_start, send_len) = chunks[(self.rank + n - round) % n];
            let (recv_start, recv_len) = chunks[(self.rank + n - round - 1) % n];
            f32_slice_to_bytes(&data[send_start..send_start + send_len], &mut send_buf);
            self.exchange(
                &send_buf[..send_len * 4],
                &mut recv_buf[..recv_len * 4],
                "ring",
                round,
            )?;
            for (dst, bytes) in data[recv_start..recv_start + recv_len]
                .iter_mut()
                .zip(recv_buf.chunks_exact(4))
            {
                *dst += Self::decode_f32(bytes);
            }
        }

        // Allgather. In round r this worker forwards chunk (rank + 1 - r), starting with the
        // chunk it finished reducing. It overwrites chunk (rank - r) with the finished copy
        // from its left neighbour.
        for round in 0..n - 1 {
            let (send_start, send_len) = chunks[(self.rank + n - round + 1) % n];
            let (recv_start, recv_len) = chunks[(self.rank + n - round) % n];
            f32_slice_to_bytes(&data[send_start..send_start + send_len], &mut send_buf);
            self.exchange(
                &send_buf[..send_len * 4],
                &mut recv_buf[..recv_len * 4],
                "ring allgather",
                round,
            )?;
            for (dst, bytes) in data[recv_start..recv_start + recv_len]
                .iter_mut()
                .zip(recv_buf.chunks_exact(4))
            {
                *dst = Self::decode_f32(bytes);
            }
        }

        // Each chunk's sum was computed once and then copied verbatim, so every worker
        // holds bit-identical sums. Scaling them identically keeps the results identical.
        let inv_n = 1.0 / n as f32;
        data.iter_mut().for_each(|x| *x *= inv_n);
        Ok(())
    }

    /// Sends `outgoing` on `send_stream` while filling `incoming` from `recv_stream`.
    ///
    /// Both directions must overlap. Every worker sends in the same round, so if each
    /// finished `write_all` before starting `read_exact`, a chunk larger than the kernel
    /// socket buffers would leave every worker blocked in `write_all` at once (deadlock).
    /// The write therefore runs on a scoped thread while the current thread reads. This
    /// costs one thread spawn per round.
    ///
    /// A send failure is reported in preference to a receive failure.
    fn exchange(
        &mut self,
        outgoing: &[u8],
        incoming: &mut [u8],
        phase: &str,
        round: usize,
    ) -> Result<(), String> {
        let Self { send_stream, recv_stream, .. } = self;
        let (sent, received) = std::thread::scope(|s| {
            let sender = s.spawn(move || send_stream.write_all(outgoing));
            let received = recv_stream.read_exact(incoming);
            let sent = sender
                .join()
                .unwrap_or_else(|payload| std::panic::resume_unwind(payload));
            (sent, received)
        });
        sent.map_err(|e| format!("{phase} send error (round {round}): {e}"))?;
        received.map_err(|e| format!("{phase} recv error (round {round}): {e}"))
    }

    /// Splits `len` elements into `parts` contiguous `(start, len)` ranges.
    ///
    /// Range lengths differ by at most one; the first `len % parts` ranges carry the
    /// extra element.
    fn chunk_bounds(len: usize, parts: usize) -> Vec<(usize, usize)> {
        let base = len / parts;
        let remainder = len % parts;
        (0..parts)
            .map(|i| (i * base + i.min(remainder), base + usize::from(i < remainder)))
            .collect()
    }

    /// Decodes one little-endian `f32` from a 4-byte slice.
    fn decode_f32(bytes: &[u8]) -> f32 {
        f32::from_le_bytes(bytes.try_into().expect("chunks_exact(4) yields 4-byte slices"))
    }
}
/// Serializes `src` into the front of `dst` as consecutive little-endian `f32` values.
///
/// Exactly `4 * src.len()` bytes are written; any bytes of `dst` beyond that are left
/// untouched.
///
/// # Panics
/// Panics if `dst` is shorter than `4 * src.len()` bytes.
fn f32_slice_to_bytes(src: &[f32], dst: &mut [u8]) {
    let byte_count = src.len() * 4;
    let target = &mut dst[..byte_count];
    for (slot, value) in target.chunks_exact_mut(4).zip(src) {
        slot.copy_from_slice(&value.to_le_bytes());
    }
}
/// Averages `data` element-wise with a single peer, in place.
///
/// The local values are sent on `send_stream` and the peer's values are read from
/// `recv_stream`. Each element is then replaced with `(local + remote) * 0.5`. Both peers
/// must call this concurrently with slices of the same length. The length is not
/// transmitted, so a mismatch shows up as a read error, a hang, or a desynchronized stream.
///
/// The send runs on a scoped thread concurrently with the receive. If both peers wrote
/// their whole payload before reading, a payload larger than the TCP socket buffers would
/// leave both blocked in `write_all` forever.
///
/// # Errors
/// Returns `"pair send error: …"` or `"pair recv error: …"` when the corresponding stream
/// operation fails. A send failure is reported in preference to a receive failure. `data`
/// is left unchanged in either case.
pub fn allreduce_pair(
    data: &mut [f32],
    send_stream: &mut TcpStream,
    recv_stream: &mut TcpStream,
) -> Result<(), String> {
    let byte_len = data.len() * 4;
    let mut send_buf = vec![0u8; byte_len];
    let mut recv_buf = vec![0u8; byte_len];
    f32_slice_to_bytes(data, &mut send_buf);

    let (sent, received) = std::thread::scope(|s| {
        let outgoing: &[u8] = &send_buf;
        let sender = s.spawn(move || send_stream.write_all(outgoing));
        let received = recv_stream.read_exact(&mut recv_buf);
        let sent = sender
            .join()
            .unwrap_or_else(|payload| std::panic::resume_unwind(payload));
        (sent, received)
    });
    sent.map_err(|e| format!("pair send error: {e}"))?;
    received.map_err(|e| format!("pair recv error: {e}"))?;

    for (x, bytes) in data.iter_mut().zip(recv_buf.chunks_exact(4)) {
        let remote =
            f32::from_le_bytes(bytes.try_into().expect("chunks_exact(4) yields 4-byte slices"));
        *x = (*x + remote) * 0.5;
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::TcpListener;
    use std::thread;

    /// Builds an `n`-worker ring over loopback TCP.
    ///
    /// Worker `w` sends to worker `(w + 1) % n` and receives from worker `(w + n - 1) % n`.
    fn setup_ring(n: usize) -> Vec<RingAllReduceWorker> {
        let listeners: Vec<TcpListener> =
            (0..n).map(|_| TcpListener::bind("127.0.0.1:0").expect("bind")).collect();
        let addrs: Vec<_> = listeners.iter().map(|l| l.local_addr().expect("addr")).collect();
        // Connecting before accepting is fine: the kernel completes the handshake and
        // queues each connection in its listener's backlog.
        let send_streams: Vec<TcpStream> = (0..n)
            .map(|w| {
                let stream = TcpStream::connect(addrs[(w + 1) % n]).expect("connect");
                stream.set_nodelay(true).ok();
                stream
            })
            .collect();
        // The stream accepted on listener `w` comes from worker `w - 1`, i.e. it is
        // worker `w`'s receive side.
        let recv_streams: Vec<TcpStream> = listeners
            .iter()
            .map(|listener| {
                let (stream, _) = listener.accept().expect("accept");
                stream.set_nodelay(true).ok();
                stream
            })
            .collect();
        send_streams
            .into_iter()
            .zip(recv_streams)
            .enumerate()
            .map(|(w, (send, recv))| RingAllReduceWorker::new(w, n, send, recv))
            .collect()
    }

    /// Runs `allreduce` on a fresh ring with one thread per worker.
    ///
    /// Returns each worker's output in rank order; the ring size is `inputs.len()`.
    fn run_ring(inputs: Vec<Vec<f32>>) -> Vec<Vec<f32>> {
        let workers = setup_ring(inputs.len());
        let handles: Vec<_> = workers
            .into_iter()
            .zip(inputs)
            .map(|(mut worker, mut data)| {
                thread::spawn(move || {
                    worker.allreduce(&mut data).expect("allreduce");
                    data
                })
            })
            .collect();
        handles.into_iter().map(|h| h.join().expect("join worker")).collect()
    }

    /// Asserts that every worker's result matches `expected` element-wise within `tol`.
    fn assert_all_close(results: &[Vec<f32>], expected: &[f32], tol: f32) {
        for (rank, result) in results.iter().enumerate() {
            assert_eq!(result.len(), expected.len(), "w{rank}: length changed");
            for (i, (&v, &e)) in result.iter().zip(expected).enumerate() {
                assert!((v - e).abs() < tol, "w{rank}[{i}]: {v} != {e}");
            }
        }
    }

    #[test]
    fn test_ring_allreduce_2_workers_identical() {
        let results = run_ring(vec![vec![1.0, 2.0, 3.0]; 2]);
        assert_all_close(&results, &[1.0, 2.0, 3.0], 1e-6);
    }

    #[test]
    fn test_ring_allreduce_2_workers_distinct() {
        let results = run_ring(vec![vec![2.0, 4.0, 6.0], vec![8.0, 6.0, 4.0]]);
        assert_all_close(&results, &[5.0, 5.0, 5.0], 1e-6);
    }

    #[test]
    fn test_ring_allreduce_3_workers() {
        let results = run_ring(vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ]);
        assert_all_close(&results, &[1.0 / 3.0; 3], 1e-5);
    }

    #[test]
    fn test_ring_allreduce_non_divisible_length() {
        let results = run_ring(vec![
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0],
            vec![7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0],
            vec![0.0; 7],
        ]);
        assert_all_close(&results, &[8.0 / 3.0; 7], 1e-5);
        assert_eq!(results[0], results[1], "w0 == w1");
        assert_eq!(results[0], results[2], "w0 == w2");
    }

    #[test]
    fn test_ring_allreduce_fewer_elements_than_workers() {
        // 2 elements across 4 workers: two chunks hold one element, two are empty.
        let inputs: Vec<Vec<f32>> = (0..4).map(|w| vec![w as f32, 1.0]).collect();
        let results = run_ring(inputs);
        assert_all_close(&results, &[1.5, 1.0], 1e-6);
    }

    #[test]
    fn test_ring_allreduce_empty() {
        let results = run_ring(vec![Vec::new(); 3]);
        assert!(results.iter().all(Vec::is_empty));
    }

    #[test]
    fn test_ring_allreduce_large_vector() {
        let d = 100_000;
        let d0: Vec<f32> = (0..d).map(|i| i as f32).collect();
        let d1: Vec<f32> = (0..d).map(|i| (d - 1 - i) as f32).collect();
        let results = run_ring(vec![d0, d1]);
        let expected = vec![(d as f32 - 1.0) / 2.0; d];
        assert_all_close(&results, &expected, 1e-2);
        assert_eq!(results[0], results[1], "results must be identical");
    }

    #[test]
    #[should_panic(expected = "rank must be < world_size")]
    fn test_new_rejects_out_of_range_rank() {
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind");
        let send = TcpStream::connect(listener.local_addr().expect("addr")).expect("connect");
        let (recv, _) = listener.accept().expect("accept");
        let _worker = RingAllReduceWorker::new(2, 2, send, recv);
    }

    #[test]
    fn test_allreduce_pair() {
        let listener_a = TcpListener::bind("127.0.0.1:0").expect("bind");
        let listener_b = TcpListener::bind("127.0.0.1:0").expect("bind");
        // A sends to B's listener and B sends to A's listener.
        let mut send_a = TcpStream::connect(listener_b.local_addr().expect("addr"))
            .expect("connect to b");
        let mut send_b = TcpStream::connect(listener_a.local_addr().expect("addr"))
            .expect("connect to a");
        let (mut recv_a, _) = listener_a.accept().expect("accept from b");
        let (mut recv_b, _) = listener_b.accept().expect("accept from a");
        let mut d_a = vec![10.0f32, 20.0, 30.0];
        let mut d_b = vec![30.0f32, 20.0, 10.0];
        let hb = thread::spawn(move || {
            allreduce_pair(&mut d_b, &mut send_b, &mut recv_b).expect("pair b");
            d_b
        });
        allreduce_pair(&mut d_a, &mut send_a, &mut recv_a).expect("pair a");
        let result_b = hb.join().expect("join");
        assert_eq!(d_a, vec![20.0, 20.0, 20.0]);
        assert_eq!(result_b, vec![20.0, 20.0, 20.0]);
    }
}