use std::time::Duration;
use ferrotorch_core::FerrotorchResult;
use crate::backend::Backend;
use crate::error::DistributedError;
pub(super) mod collectives;
pub(super) mod connect;
pub(super) mod error;
pub(super) mod transport;
use self::collectives::{
RingTransport, ring_allreduce_sum_f32_bytes, ring_barrier, tree_broadcast_f32_bytes,
};
use self::connect::{PeerConn, PeerStreams, RendezvousConfig, rendezvous};
use self::error::GlooResult;
use self::transport::{recv_msg_into, send_msg, with_read_timeout};
pub use self::connect::RendezvousConfig as GlooRendezvousConfig;
/// Timeout (60 seconds) used by the operations in this file that do not take an
/// explicit timeout: `ring_allreduce_sum_f32`, `tree_broadcast_f32`,
/// `Backend::recv` and `Backend::barrier`.
pub const DEFAULT_GLOO_TIMEOUT: Duration = Duration::from_secs(60);
/// State of one rank in a group: its identity plus the connections to its
/// peers produced by `rendezvous`.
#[derive(Debug)]
pub struct GlooBackendInner {
// This process's rank, copied from the rendezvous configuration.
rank: usize,
// Total number of ranks, copied from the rendezvous configuration.
world_size: usize,
// Peer connections, indexed by peer rank; a slot may be empty (see `conn`).
connections: PeerStreams,
}
impl GlooBackendInner {
    /// Runs the rendezvous described by `cfg` and stores the resulting peer
    /// connections together with the rank and world size taken from `cfg`.
    ///
    /// # Errors
    /// Propagates any error returned by `rendezvous`.
    pub fn new(cfg: &RendezvousConfig) -> GlooResult<Self> {
        let connections = rendezvous(cfg)?;
        Ok(Self {
            rank: cfg.rank,
            world_size: cfg.world_size,
            connections,
        })
    }

    /// Looks up the connection to `peer`.
    ///
    /// # Errors
    /// * `SelfSend` when `peer` is this rank (checked first).
    /// * `InvalidRank` when `peer >= world_size`.
    /// * `NoConnection` when the slot for `peer` is empty.
    fn conn(&self, peer: usize) -> GlooResult<&PeerConn> {
        if peer == self.rank {
            return Err(DistributedError::SelfSend { rank: self.rank });
        }
        if peer >= self.world_size {
            return Err(DistributedError::InvalidRank {
                rank: peer,
                world_size: self.world_size,
            });
        }
        // NOTE(review): direct indexing assumes `connections` has one slot per
        // rank (`world_size` entries) — confirm against `rendezvous`.
        self.connections[peer]
            .as_ref()
            .ok_or(DistributedError::NoConnection { rank: peer })
    }

    /// Sends `data` to rank `dst` through that peer's writer, holding the
    /// writer lock for the duration of the send.
    ///
    /// # Errors
    /// Any error from `conn`, `LockPoisoned` if the writer mutex is poisoned,
    /// or whatever `send_msg` reports.
    fn send_inner(&self, data: &[u8], dst: usize) -> GlooResult<()> {
        let conn = self.conn(dst)?;
        let mut stream = conn
            .writer
            .lock()
            .map_err(|e| DistributedError::LockPoisoned {
                message: format!("gloo_native send rank {} -> {dst}: {e}", self.rank),
            })?;
        send_msg(&mut stream, data)
    }

    /// Receives a message from rank `src` into the buffer `dst`, holding that
    /// peer's reader lock while `recv_msg_into` runs under `with_read_timeout`
    /// with the given `timeout`.
    ///
    /// Note that here `dst` is the destination *buffer*, not a rank.
    ///
    /// # Errors
    /// Any error from `conn`, `LockPoisoned` if the reader mutex is poisoned,
    /// or whatever the timed receive reports.
    fn recv_inner(&self, dst: &mut [u8], src: usize, timeout: Duration) -> GlooResult<()> {
        let conn = self.conn(src)?;
        let mut stream = conn
            .reader
            .lock()
            .map_err(|e| DistributedError::LockPoisoned {
                message: format!("gloo_native recv rank {src} -> {}: {e}", self.rank),
            })?;
        with_read_timeout(&mut stream, timeout, |s| recv_msg_into(s, dst))
    }

    /// Views an `f32` slice as the bytes that back it (same memory, native
    /// byte order). Single home for the one `unsafe` cast the collectives need.
    fn f32s_as_bytes_mut(data: &mut [f32]) -> &mut [u8] {
        let byte_len = std::mem::size_of_val(data);
        // SAFETY: the pointer and `byte_len` describe exactly the memory owned
        // by `data`, which stays exclusively borrowed for as long as the
        // returned slice lives. `u8` has alignment 1, `f32` contains no padding
        // bytes, and every bit pattern is a valid `f32`, so neither reading nor
        // writing through the byte view can create an invalid value.
        unsafe { std::slice::from_raw_parts_mut(data.as_mut_ptr().cast::<u8>(), byte_len) }
    }

    /// In-place sum all-reduce of `data` across the group, using
    /// `DEFAULT_GLOO_TIMEOUT`.
    ///
    /// # Errors
    /// See [`Self::ring_allreduce_sum_f32_with_timeout`].
    pub fn ring_allreduce_sum_f32(&self, data: &mut [f32]) -> FerrotorchResult<()> {
        self.ring_allreduce_sum_f32_with_timeout(data, DEFAULT_GLOO_TIMEOUT)
    }

    /// In-place sum all-reduce of `data` across the group.
    ///
    /// The slice is handed to `ring_allreduce_sum_f32_bytes` as raw bytes;
    /// `timeout` is forwarded unchanged.
    ///
    /// # Errors
    /// Returns the collective's error converted into the crate-wide error type.
    pub fn ring_allreduce_sum_f32_with_timeout(
        &self,
        data: &mut [f32],
        timeout: Duration,
    ) -> FerrotorchResult<()> {
        ring_allreduce_sum_f32_bytes(self, Self::f32s_as_bytes_mut(data), timeout)
            .map_err(Into::into)
    }

    /// Broadcasts `data` from rank `root` to every rank, using
    /// `DEFAULT_GLOO_TIMEOUT`.
    ///
    /// # Errors
    /// See [`Self::tree_broadcast_f32_with_timeout`].
    pub fn tree_broadcast_f32(&self, data: &mut [f32], root: usize) -> FerrotorchResult<()> {
        self.tree_broadcast_f32_with_timeout(data, root, DEFAULT_GLOO_TIMEOUT)
    }

    /// Broadcasts `data` from rank `root` to every rank, overwriting `data` on
    /// the other ranks.
    ///
    /// The slice is handed to `tree_broadcast_f32_bytes` as raw bytes;
    /// `timeout` is forwarded unchanged.
    ///
    /// # Errors
    /// Returns the collective's error converted into the crate-wide error type.
    pub fn tree_broadcast_f32_with_timeout(
        &self,
        data: &mut [f32],
        root: usize,
        timeout: Duration,
    ) -> FerrotorchResult<()> {
        tree_broadcast_f32_bytes(self, Self::f32s_as_bytes_mut(data), root, timeout)
            .map_err(Into::into)
    }
}
// Transport hooks for the routines in `collectives`; presumably this is what
// lets `self` be passed to `ring_allreduce_sum_f32_bytes`, `ring_barrier` and
// `tree_broadcast_f32_bytes`. Every method is a direct delegation.
impl RingTransport for GlooBackendInner {
/// This process's rank.
fn rank(&self) -> usize {
self.rank
}
/// Number of ranks in the group.
fn world_size(&self) -> usize {
self.world_size
}
/// Sends `data` to rank `dst` via `send_inner`.
fn send(&self, data: &[u8], dst: usize) -> GlooResult<()> {
self.send_inner(data, dst)
}
/// Receives from rank `src` into the buffer `dst` via `recv_inner`, with the given `timeout`.
fn recv(&self, dst: &mut [u8], src: usize, timeout: Duration) -> GlooResult<()> {
self.recv_inner(dst, src, timeout)
}
}
// `Backend` implementation: thin wrappers over `send_inner` / `recv_inner` /
// `ring_barrier` that convert the backend-local error with `Into::into`.
impl Backend for GlooBackendInner {
/// This process's rank.
fn rank(&self) -> usize {
self.rank
}
/// Number of ranks in the group.
fn world_size(&self) -> usize {
self.world_size
}
/// Sends `data` to `dst_rank`.
fn send(&self, data: &[u8], dst_rank: usize) -> FerrotorchResult<()> {
self.send_inner(data, dst_rank).map_err(Into::into)
}
/// Receives from `src_rank` into `dst`, using `DEFAULT_GLOO_TIMEOUT`.
fn recv(&self, dst: &mut [u8], src_rank: usize) -> FerrotorchResult<()> {
self.recv_inner(dst, src_rank, DEFAULT_GLOO_TIMEOUT)
.map_err(Into::into)
}
/// Receives from `src_rank` into `dst`, using the caller-supplied `timeout`.
fn recv_timeout(
&self,
dst: &mut [u8],
src_rank: usize,
timeout: Duration,
) -> FerrotorchResult<()> {
self.recv_inner(dst, src_rank, timeout).map_err(Into::into)
}
/// Runs `ring_barrier` over this backend with `DEFAULT_GLOO_TIMEOUT`.
fn barrier(&self) -> FerrotorchResult<()> {
ring_barrier(self, DEFAULT_GLOO_TIMEOUT).map_err(Into::into)
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4, TcpListener};
use std::sync::Arc;
use std::thread;
/// Brings up a `world_size`-rank group on loopback, constructing each backend
/// on its own thread, and returns the backends ordered by rank.
fn spawn_group(world_size: usize) -> Vec<Arc<GlooBackendInner>> {
    // Ask the OS for a free loopback port, remember its address, then release
    // it so the group can use that address as the master address.
    let probe = TcpListener::bind("127.0.0.1:0").expect("probe bind");
    let master_addr = probe.local_addr().expect("local_addr").to_string();
    drop(probe);

    // Every rank must be started before any of them is joined.
    let mut workers = Vec::with_capacity(world_size);
    for rank in 0..world_size {
        let master_addr = master_addr.clone();
        let worker = thread::spawn(move || {
            let cfg = RendezvousConfig {
                master_addr,
                rank,
                world_size,
                bind_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)),
            };
            Arc::new(GlooBackendInner::new(&cfg).expect("backend"))
        });
        workers.push(worker);
    }

    // Join in spawn order so index == rank.
    let mut group = Vec::with_capacity(world_size);
    for worker in workers {
        group.push(worker.join().expect("join"));
    }
    group
}
/// Point-to-point check: rank 0 sends four bytes to rank 1 through the
/// `Backend` API and rank 1 must receive exactly those bytes.
#[test]
fn full_mesh_send_recv_two_ranks() {
    const PAYLOAD: [u8; 4] = [7, 8, 9, 10];

    let group = spawn_group(2);
    let tx_backend = Arc::clone(&group[0]);
    let rx_backend = Arc::clone(&group[1]);

    // Send from a helper thread while this thread performs the receive.
    let sender = thread::spawn(move || {
        <GlooBackendInner as Backend>::send(&tx_backend, &PAYLOAD, 1).expect("send 0->1");
    });

    let mut received = [0u8; 4];
    <GlooBackendInner as Backend>::recv(&rx_backend, &mut received, 0).expect("recv 1<-0");
    sender.join().expect("join");

    assert_eq!(received, PAYLOAD);
}
/// Two ranks all-reduce four floats each; both must end up with the
/// element-wise sum.
#[test]
fn ring_allreduce_over_real_tcp_two_ranks() {
    let group = spawn_group(2);
    let (rank0, rank1) = (&group[0], &group[1]);

    let mut lhs = vec![1.0f32, 2.0, 3.0, 4.0];
    let mut rhs = vec![10.0f32, 20.0, 30.0, 40.0];
    let expected = vec![11.0f32, 22.0, 33.0, 44.0];

    // Scoped threads can borrow the backends directly; each thread owns its
    // input buffer and hands it back once the collective has finished.
    let (out0, out1) = thread::scope(|s| {
        let t0 = s.spawn(move || {
            rank0.ring_allreduce_sum_f32(&mut lhs).expect("ar 0");
            lhs
        });
        let t1 = s.spawn(move || {
            rank1.ring_allreduce_sum_f32(&mut rhs).expect("ar 1");
            rhs
        });
        let first = t0.join().unwrap();
        let second = t1.join().unwrap();
        (first, second)
    });

    assert_eq!(out0, expected, "rank 0");
    assert_eq!(out1, expected, "rank 1");
}
/// Four ranks all-reduce 13 floats each (a length not divisible by the rank
/// count); every rank must end up with the element-wise sum.
#[test]
fn ring_allreduce_over_real_tcp_four_ranks() {
    const RANKS: u32 = 4;
    const LEN: u32 = 13;

    /// Input value held by `rank` at position `idx`.
    fn cell(rank: u32, idx: u32) -> f32 {
        (rank * 100 + idx) as f32
    }

    let group = spawn_group(4);
    let inputs: Vec<Vec<f32>> = (0..RANKS)
        .map(|rank| (0..LEN).map(|idx| cell(rank, idx)).collect())
        .collect();
    let expected: Vec<f32> = (0..LEN)
        .map(|idx| (0..RANKS).map(|rank| cell(rank, idx)).sum())
        .collect();

    thread::scope(|s| {
        // Collecting first guarantees every rank is running before any join.
        let workers: Vec<_> = group
            .iter()
            .zip(inputs)
            .map(|(backend, mut data)| {
                s.spawn(move || {
                    backend.ring_allreduce_sum_f32(&mut data).expect("allreduce");
                    data
                })
            })
            .collect();

        for worker in workers {
            let got = worker.join().unwrap();
            assert_eq!(got, expected);
        }
    });
}
/// Rank 2 broadcasts three floats to a four-rank group; every rank (the root
/// included) must hold the root's payload afterwards.
#[test]
fn tree_broadcast_over_real_tcp_four_ranks() {
    const ROOT: usize = 2;

    let group = spawn_group(4);
    let payload = vec![100.5f32, 200.25, 300.125];
    let source = &payload;

    thread::scope(|s| {
        // Collecting first guarantees every rank is running before any join.
        let workers: Vec<_> = group
            .iter()
            .enumerate()
            .map(|(rank, backend)| {
                s.spawn(move || {
                    // Only the root starts with real data; the rest start zeroed.
                    let mut data = if rank == ROOT {
                        source.clone()
                    } else {
                        vec![0.0f32; 3]
                    };
                    backend.tree_broadcast_f32(&mut data, ROOT).expect("broadcast");
                    data
                })
            })
            .collect();

        for worker in workers {
            let got = worker.join().unwrap();
            assert_eq!(got, vec![100.5f32, 200.25, 300.125]);
        }
    });
}
/// All three ranks enter the barrier concurrently; each call must return `Ok`.
#[test]
fn barrier_over_real_tcp_three_ranks() {
    let group = spawn_group(3);

    thread::scope(|s| {
        // Collecting first guarantees every rank has entered before any join.
        let workers: Vec<_> = group
            .iter()
            .map(|backend| {
                s.spawn(move || {
                    <GlooBackendInner as Backend>::barrier(backend).expect("barrier");
                })
            })
            .collect();

        workers.into_iter().for_each(|worker| worker.join().unwrap());
    });
}
/// `RendezvousConfig::from_env` must read the PyTorch-style `MASTER_ADDR`,
/// `MASTER_PORT`, `RANK` and `WORLD_SIZE` variables, joining address and port
/// into `master_addr`.
#[test]
fn rendezvous_config_from_env_reads_pytorch_vars() {
    const VARS: [(&str, &str); 4] = [
        ("MASTER_ADDR", "127.0.0.1"),
        ("MASTER_PORT", "29501"),
        ("RANK", "2"),
        ("WORLD_SIZE", "4"),
    ];

    /// Removes every variable in `VARS` when dropped, so a failing assertion
    /// below cannot leak them into the rest of the test process.
    struct EnvCleanup;

    impl Drop for EnvCleanup {
        fn drop(&mut self) {
            for (key, _) in VARS {
                // SAFETY: same requirement as for `set_var` below — no other
                // thread may access the environment concurrently.
                unsafe { std::env::remove_var(key) };
            }
        }
    }

    // Armed before the first `set_var` so cleanup also covers a partial setup.
    let _cleanup = EnvCleanup;
    for (key, value) in VARS {
        // SAFETY: mutating the environment is only sound while no other thread
        // reads or writes it.
        // NOTE(review): the default test harness runs tests on several threads,
        // so this is not guaranteed here — confirm nothing else in this test
        // binary touches the environment, or serialize this test.
        unsafe { std::env::set_var(key, value) };
    }

    let cfg = RendezvousConfig::from_env().expect("from_env");
    assert_eq!(cfg.master_addr, "127.0.0.1:29501");
    assert_eq!(cfg.rank, 2);
    assert_eq!(cfg.world_size, 4);
}
}