use std::time::Duration;
use crate::error::DistributedError;
use super::error::GlooResult;
/// Point-to-point transport that the ring and tree collectives in this module
/// are written against.
///
/// The `Sync` supertrait allows a `&dyn RingTransport` to be shared with the
/// scoped sender thread spawned by `send_recv`.
pub(super) trait RingTransport: Sync {
/// This participant's rank. The collectives use it as a position on the
/// ring, reducing it modulo `world_size()`.
fn rank(&self) -> usize;
/// Number of participants in the group.
fn world_size(&self) -> usize;
/// Sends `data` to the peer whose rank is `dst`.
fn send(&self, data: &[u8], dst: usize) -> GlooResult<()>;
/// Receives a message from the peer whose rank is `src` into the buffer `dst`.
///
/// Note the naming: here `dst` is the destination *buffer*, whereas in
/// `send` it is the destination *rank*.
///
/// NOTE(review): callers in this module size `dst` to the exact message they
/// expect; presumably implementations fill it completely and fail once
/// `timeout` elapses -- confirm against the concrete transports.
fn recv(&self, dst: &mut [u8], src: usize, timeout: Duration) -> GlooResult<()>;
}
/// Returns `(prev, next)`: the ranks directly before and after `rank` on a
/// ring of `world_size` participants, wrapping around at both ends.
///
/// Debug builds assert that `world_size >= 2`.
fn ring_neighbours(rank: usize, world_size: usize) -> (usize, usize) {
    debug_assert!(world_size >= 2);
    // Moving back one position is expressed as moving forward by
    // `world_size - 1`, which keeps the arithmetic unsigned.
    let forward_by = |offset: usize| (rank + offset) % world_size;
    (forward_by(world_size - 1), forward_by(1))
}
/// In-place ring all-reduce (element-wise sum) over a buffer of little-endian
/// `f32` values stored as raw bytes.
///
/// The buffer is split into `world_size` chunks. The exchange runs in two
/// phases of `world_size - 1` steps each; in every step a rank sends one chunk
/// to its `next` neighbour while receiving another from its `prev` neighbour:
///
/// 1. reduce-scatter: the received chunk is added into the local copy, so that
///    afterwards each rank holds one fully summed chunk;
/// 2. all-gather: the received chunk overwrites the local copy, circulating
///    the summed chunks until every rank has all of them.
///
/// Returns `Ok(())` without communicating when `world_size` is 1 or `buf` is
/// empty.
///
/// # Errors
///
/// * `DistributedError::SizeMismatch` if `buf.len()` is not a multiple of the
///   size of `f32`.
/// * Any error reported by the transport during a step.
pub(super) fn ring_allreduce_sum_f32_bytes(
    transport: &dyn RingTransport,
    buf: &mut [u8],
    timeout: Duration,
) -> GlooResult<()> {
    const ELEM_SIZE: usize = std::mem::size_of::<f32>();

    let rank = transport.rank();
    let world_size = transport.world_size();
    // A lone rank or an empty buffer leaves nothing to exchange.
    if world_size == 1 || buf.is_empty() {
        return Ok(());
    }
    let trailing = buf.len() % ELEM_SIZE;
    if trailing != 0 {
        return Err(DistributedError::SizeMismatch {
            expected: buf.len() - trailing,
            got: buf.len(),
        });
    }

    // Byte range of every chunk, indexed by chunk number.
    let spans: Vec<_> = chunk_ranges(buf.len() / ELEM_SIZE, world_size)
        .into_iter()
        .map(|(lo, hi)| bytes_of(lo)..bytes_of(hi))
        .collect();
    let (prev, next) = ring_neighbours(rank, world_size);
    let hops = world_size - 1;
    // Chunk numbers are taken modulo the ring size.
    let span_of = |chunk: usize| spans[chunk % world_size].clone();

    // Phase 1: reduce-scatter. Sum each incoming chunk into the local copy.
    for step in 0..hops {
        let outgoing = span_of(rank + world_size - step);
        let incoming = span_of(rank + world_size - step - 1);
        let payload = buf[outgoing].to_vec();
        let mut scratch = vec![0u8; incoming.len()];
        send_recv(transport, &payload, next, &mut scratch, prev, timeout)?;
        accumulate_f32_inplace(&mut buf[incoming], &scratch);
    }

    // Phase 2: all-gather. Incoming chunks are already fully reduced, so they
    // simply replace the local copy.
    for step in 0..hops {
        let outgoing = span_of(rank + world_size + 1 - step);
        let incoming = span_of(rank + world_size - step);
        let payload = buf[outgoing].to_vec();
        let mut scratch = vec![0u8; incoming.len()];
        send_recv(transport, &payload, next, &mut scratch, prev, timeout)?;
        buf[incoming].copy_from_slice(&scratch);
    }

    Ok(())
}
/// Splits `total_elems` elements into `world_size` contiguous, half-open
/// `(lo, hi)` element ranges, returned in chunk order.
///
/// Chunk `i` runs from `floor(i * total_elems / world_size)` to
/// `floor((i + 1) * total_elems / world_size)`. The ranges therefore abut,
/// cover `0..total_elems` exactly once, and differ in length by at most one
/// element; some are empty when `total_elems < world_size`.
///
/// Returns an empty vector when `world_size` is 0.
fn chunk_ranges(total_elems: usize, world_size: usize) -> Vec<(usize, usize)> {
    if world_size == 0 {
        return Vec::new();
    }
    let quotient = total_elems / world_size;
    let remainder = total_elems % world_size;
    // floor(i * total / world) == i * quotient + floor(i * remainder / world).
    // Computed this way the result never exceeds `total_elems` and the only
    // intermediate product is below `world_size^2`, whereas the direct product
    // `i * total_elems` can overflow `usize` for large buffers.
    let boundary = |i: usize| i * quotient + i * remainder / world_size;
    (0..world_size)
        .map(|i| (boundary(i), boundary(i + 1)))
        .collect()
}
/// Converts an `f32` element index (or element count) into the matching byte
/// offset (or byte length).
const fn bytes_of(elem_idx: usize) -> usize {
    const F32_WIDTH: usize = std::mem::size_of::<f32>();
    F32_WIDTH * elem_idx
}
/// Adds `src` into `dst` element-wise, treating both byte slices as packed
/// little-endian `f32` values; each sum is written back into `dst`.
///
/// Debug builds assert that the slices have equal length and that the length
/// is a whole number of `f32`s.
fn accumulate_f32_inplace(dst: &mut [u8], src: &[u8]) {
    const WIDTH: usize = std::mem::size_of::<f32>();
    debug_assert_eq!(dst.len(), src.len());
    debug_assert_eq!(dst.len() % WIDTH, 0);

    // `chunks_exact` only ever yields `WIDTH`-byte pieces, so the conversion
    // to a fixed-size array cannot fail.
    let decode =
        |raw: &[u8]| -> f32 { f32::from_le_bytes(raw.try_into().expect("4-byte chunk")) };

    dst.chunks_exact_mut(WIDTH)
        .zip(src.chunks_exact(WIDTH))
        .for_each(|(acc, addend)| {
            let sum = decode(&*acc) + decode(addend);
            acc.copy_from_slice(&sum.to_le_bytes());
        });
}
/// Performs one ring step: sends `send_bytes` to rank `next` while receiving
/// into `recv_bytes` from rank `prev`.
///
/// The send runs on a scoped helper thread and the receive on the calling
/// thread, so the two halves of the exchange proceed concurrently. The helper
/// is always joined before returning.
///
/// # Errors
///
/// Checked in this order:
/// * `DistributedError::Io` if the sending thread panicked;
/// * the error returned by `send`, if any;
/// * the error returned by `recv`, if any.
fn send_recv(
    transport: &dyn RingTransport,
    send_bytes: &[u8],
    next: usize,
    recv_bytes: &mut [u8],
    prev: usize,
    timeout: Duration,
) -> GlooResult<()> {
    std::thread::scope(|scope| {
        let sender = scope.spawn(move || transport.send(send_bytes, next));
        let received = transport.recv(recv_bytes, prev, timeout);
        match sender.join() {
            // A send failure takes precedence over a receive failure.
            Ok(sent) => sent.and(received),
            Err(_panic) => Err(DistributedError::Io {
                message: "gloo_native ring send worker panicked".to_string(),
            }),
        }
    })
}
/// Broadcasts `buf` from rank `root` to every other rank along an implicit
/// binary tree.
///
/// Ranks are rotated so that `root` sits at tree position 0; the node at
/// position `p` has its parent at `(p - 1) / 2` and its children at `2p + 1`
/// and `2p + 2`. Every non-root rank first receives `buf` from its parent,
/// then each rank forwards `buf` to whichever of its children exist. The
/// payload is passed along as opaque bytes.
///
/// Returns `Ok(())` immediately when `world_size` is 1.
///
/// # Errors
///
/// * `DistributedError::InvalidRank` if `root >= world_size`.
/// * Any error reported by the transport.
pub(super) fn tree_broadcast_f32_bytes(
    transport: &dyn RingTransport,
    buf: &mut [u8],
    root: usize,
    timeout: Duration,
) -> GlooResult<()> {
    let rank = transport.rank();
    let world_size = transport.world_size();
    if world_size == 1 {
        return Ok(());
    }
    if root >= world_size {
        return Err(DistributedError::InvalidRank {
            rank: root,
            world_size,
        });
    }

    // Maps a tree position back to the real rank that occupies it.
    let to_rank = |tree_pos: usize| (tree_pos + root) % world_size;
    let my_pos = (rank + world_size - root) % world_size;

    // Every position except 0 (the root) has a parent to receive from.
    if let Some(above) = my_pos.checked_sub(1) {
        transport.recv(buf, to_rank(above / 2), timeout)?;
    }

    let first_child = 2 * my_pos + 1;
    for child_pos in (first_child..=first_child + 1).filter(|&pos| pos < world_size) {
        transport.send(buf, to_rank(child_pos))?;
    }
    Ok(())
}
/// Barrier implemented as two laps of a one-byte token around the ring.
///
/// In each lap rank 0 sends the token to `next` and then waits for it to come
/// back from `prev`; every other rank waits for the token from `prev` and then
/// forwards it to `next`. The first lap can only return to rank 0 once every
/// rank has joined; the second lap passes that knowledge on to the remaining
/// ranks before any of them returns.
///
/// Returns `Ok(())` immediately when `world_size` is 1.
///
/// # Errors
///
/// Propagates the first error reported by the transport.
pub(super) fn ring_barrier(transport: &dyn RingTransport, timeout: Duration) -> GlooResult<()> {
    const LAPS: usize = 2;

    let rank = transport.rank();
    let world_size = transport.world_size();
    if world_size == 1 {
        return Ok(());
    }
    let (prev, next) = ring_neighbours(rank, world_size);
    let token = [0u8; 1];

    for _lap in 0..LAPS {
        let mut inbox = [0u8; 1];
        if rank == 0 {
            // Rank 0 starts the lap and waits for the token to return.
            transport.send(&token, next)?;
            transport.recv(&mut inbox, prev, timeout)?;
        } else {
            // Everyone else relays the token once it arrives.
            transport.recv(&mut inbox, prev, timeout)?;
            transport.send(&token, next)?;
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::mpsc::{Receiver, Sender, channel};

    /// Deadline handed to every collective under test.
    const TEST_TIMEOUT: Duration = Duration::from_secs(5);

    type ChannelPair = (Mutex<Sender<Vec<u8>>>, Mutex<Receiver<Vec<u8>>>);
    type ChannelMatrix = Vec<Vec<ChannelPair>>;

    /// In-memory stand-in for the network: `inner[src][dst]` is a dedicated
    /// channel carrying messages from rank `src` to rank `dst`.
    struct Channels {
        inner: ChannelMatrix,
    }

    impl Channels {
        /// Creates one channel for every ordered `(src, dst)` pair of ranks.
        fn new(world_size: usize) -> Self {
            let inner = (0..world_size)
                .map(|_src| {
                    (0..world_size)
                        .map(|_dst| {
                            let (tx, rx) = channel();
                            (Mutex::new(tx), Mutex::new(rx))
                        })
                        .collect()
                })
                .collect();
            Self { inner }
        }
    }

    /// A single rank's view of the shared channel matrix.
    struct RankView<'a> {
        my_rank: usize,
        world_size: usize,
        ch: &'a Channels,
    }

    impl RingTransport for RankView<'_> {
        fn rank(&self) -> usize {
            self.my_rank
        }

        fn world_size(&self) -> usize {
            self.world_size
        }

        fn send(&self, data: &[u8], dst: usize) -> GlooResult<()> {
            self.ch.inner[self.my_rank][dst]
                .0
                .lock()
                .unwrap()
                .send(data.to_vec())
                .map_err(|e| DistributedError::ChannelClosed {
                    message: format!("test send {} -> {dst}: {e}", self.my_rank),
                })?;
            Ok(())
        }

        fn recv(&self, dst: &mut [u8], src: usize, timeout: Duration) -> GlooResult<()> {
            // Honour the timeout so that a deadlocked collective fails the
            // test instead of hanging the whole test run.
            let v = self.ch.inner[src][self.my_rank]
                .1
                .lock()
                .unwrap()
                .recv_timeout(timeout)
                .map_err(|e| DistributedError::ChannelClosed {
                    message: format!("test recv {src} -> {}: {e}", self.my_rank),
                })?;
            if v.len() != dst.len() {
                return Err(DistributedError::SizeMismatch {
                    expected: dst.len(),
                    got: v.len(),
                });
            }
            dst.copy_from_slice(&v);
            Ok(())
        }
    }

    fn floats_to_le_bytes(xs: &[f32]) -> Vec<u8> {
        let mut out = Vec::with_capacity(xs.len() * 4);
        for &x in xs {
            out.extend_from_slice(&x.to_le_bytes());
        }
        out
    }

    fn le_bytes_to_floats(bs: &[u8]) -> Vec<f32> {
        bs.chunks_exact(4)
            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
            .collect()
    }

    /// Runs `body` once per rank, each on its own scoped thread with a fresh
    /// `RankView` and that rank's buffer. The number of ranks is `bufs.len()`.
    /// A panic on any rank is re-raised on the calling thread.
    fn run_ranks<F>(bufs: &mut [Vec<u8>], body: F)
    where
        F: Fn(&RankView<'_>, &mut [u8]) + Sync,
    {
        let world_size = bufs.len();
        let channels = Channels::new(world_size);
        std::thread::scope(|s| {
            let mut handles = Vec::new();
            for (my_rank, buf) in bufs.iter_mut().enumerate() {
                let ch = &channels;
                let body = &body;
                handles.push(s.spawn(move || {
                    let view = RankView {
                        my_rank,
                        world_size,
                        ch,
                    };
                    body(&view, buf.as_mut_slice());
                }));
            }
            for h in handles {
                h.join().unwrap();
            }
        });
    }

    /// All-reduces `inputs` (one vector per rank) and returns what every rank
    /// ends up holding, indexed by rank.
    fn allreduce_all_ranks(inputs: &[Vec<f32>]) -> Vec<Vec<f32>> {
        let mut bufs: Vec<Vec<u8>> = inputs.iter().map(|x| floats_to_le_bytes(x)).collect();
        run_ranks(&mut bufs, |view, buf| {
            ring_allreduce_sum_f32_bytes(view, buf, TEST_TIMEOUT).expect("allreduce");
        });
        bufs.iter().map(|b| le_bytes_to_floats(b)).collect()
    }

    #[test]
    fn ring_allreduce_two_ranks_sum() {
        let inputs = [vec![1.0f32, 2.0, 3.0], vec![4.0f32, 5.0, 6.0]];
        let expected = vec![5.0f32, 7.0, 9.0];
        for (i, got) in allreduce_all_ranks(&inputs).into_iter().enumerate() {
            assert_eq!(got, expected, "rank {i}");
        }
    }

    #[test]
    fn ring_allreduce_four_ranks_sum_with_uneven_chunks() {
        // Seven elements over four ranks forces chunks of unequal length.
        let inputs: Vec<Vec<f32>> = [1.0f32, 2.0, 4.0, 8.0]
            .iter()
            .map(|&v| vec![v; 7])
            .collect();
        let expected = vec![15.0f32; 7];
        for (i, got) in allreduce_all_ranks(&inputs).into_iter().enumerate() {
            assert_eq!(got, expected, "rank {i}");
        }
    }

    #[test]
    fn ring_allreduce_three_ranks_sum() {
        let inputs = [
            vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![10.0f32, 20.0, 30.0, 40.0, 50.0, 60.0],
            vec![100.0f32, 200.0, 300.0, 400.0, 500.0, 600.0],
        ];
        let expected = vec![111.0f32, 222.0, 333.0, 444.0, 555.0, 666.0];
        for (i, got) in allreduce_all_ranks(&inputs).into_iter().enumerate() {
            assert_eq!(got, expected, "rank {i}");
        }
    }

    #[test]
    fn tree_broadcast_distributes_from_root() {
        let world_size: usize = 4;
        let root: usize = 1;
        let payload = floats_to_le_bytes(&[42.0, 43.0, 44.0]);
        // Only the root starts with the payload; everyone else starts zeroed.
        let mut bufs: Vec<Vec<u8>> = (0..world_size)
            .map(|r| {
                if r == root {
                    payload.clone()
                } else {
                    vec![0u8; payload.len()]
                }
            })
            .collect();
        run_ranks(&mut bufs, |view, buf| {
            tree_broadcast_f32_bytes(view, buf, root, TEST_TIMEOUT).expect("broadcast");
        });
        for (rank, buf) in bufs.iter().enumerate() {
            let got = le_bytes_to_floats(buf);
            assert_eq!(got, vec![42.0, 43.0, 44.0], "rank {rank}");
        }
    }

    #[test]
    fn ring_barrier_serialises_all_ranks() {
        let world_size: usize = 4;
        let entered = AtomicUsize::new(0);
        // The barrier carries no payload; the buffers only fix the rank count.
        let mut no_payload: Vec<Vec<u8>> = vec![Vec::new(); world_size];
        run_ranks(&mut no_payload, |view, _buf| {
            let rank = view.my_rank;
            entered.fetch_add(1, Ordering::SeqCst);
            ring_barrier(view, TEST_TIMEOUT).expect("barrier");
            // Nobody may leave the barrier before everybody has entered it.
            let n = entered.load(Ordering::SeqCst);
            assert_eq!(
                n, world_size,
                "rank {rank}: expected all {world_size} to be entered, saw {n}"
            );
        });
    }

    #[test]
    fn chunk_ranges_balanced() {
        let r = chunk_ranges(8, 4);
        assert_eq!(r, vec![(0, 2), (2, 4), (4, 6), (6, 8)]);
    }

    #[test]
    fn chunk_ranges_unbalanced_cover_all_elements_exactly_once() {
        let r = chunk_ranges(7, 4);
        assert_eq!(r, vec![(0, 1), (1, 3), (3, 5), (5, 7)]);
        assert_eq!(r.first().unwrap().0, 0);
        assert_eq!(r.last().unwrap().1, 7);
        for w in r.windows(2) {
            assert_eq!(w[0].1, w[1].0, "chunk boundaries must abut");
        }
        let total: usize = r.iter().map(|(lo, hi)| hi - lo).sum();
        assert_eq!(total, 7);
    }

    #[test]
    fn ring_neighbours_wrap_around() {
        assert_eq!(ring_neighbours(0, 4), (3, 1));
        assert_eq!(ring_neighbours(3, 4), (2, 0));
        assert_eq!(ring_neighbours(2, 4), (1, 3));
    }
}