use std::io::{Read, Write};
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4, TcpListener, TcpStream};
use std::sync::Mutex;
use std::time::Duration;
use crate::error::DistributedError;
use super::error::GlooResult;
/// Size in bytes of one encoded peer advertisement: the 4 IPv4 octets followed
/// by the port as 2 little-endian bytes.
const PEER_AD_BYTES: usize = 6;
/// Total time a worker keeps retrying its connection to the master before
/// giving up.
const RENDEZVOUS_RETRY_TIMEOUT: Duration = Duration::from_secs(30);
/// Pause between two consecutive worker connection attempts to the master.
const RENDEZVOUS_RETRY_INTERVAL: Duration = Duration::from_millis(50);
/// An established TCP connection to one peer, exposed as two separately
/// locked stream handles so that sending and receiving take different mutexes.
#[derive(Debug)]
pub(super) struct PeerConn {
/// Handle used for writing to the peer.
pub(super) writer: Mutex<TcpStream>,
/// Handle used for reading from the peer; when built through `from_stream`
/// this is a `try_clone` of the same socket as `writer`.
pub(super) reader: Mutex<TcpStream>,
}
impl PeerConn {
    /// Wraps a connected stream into a `PeerConn`.
    ///
    /// The socket is duplicated with `try_clone`; the duplicate becomes the
    /// read handle and the original stream becomes the write handle.
    ///
    /// # Errors
    /// Returns `DistributedError::Io` when the socket cannot be cloned.
    fn from_stream(stream: TcpStream) -> GlooResult<Self> {
        match stream.try_clone() {
            Ok(read_half) => Ok(Self {
                writer: Mutex::new(stream),
                reader: Mutex::new(read_half),
            }),
            Err(e) => Err(DistributedError::Io {
                message: format!("gloo_native try_clone read half: {e}"),
            }),
        }
    }
}
/// Per-rank connection table produced by rendezvous: entry `i` holds the
/// connection to rank `i`, and the local rank's own entry is `None`.
pub(super) type PeerStreams = Vec<Option<PeerConn>>;
/// Parameters one process needs to take part in the rendezvous.
#[derive(Debug, Clone)]
pub struct RendezvousConfig {
/// `host:port` of the master endpoint: rank 0 binds a listener on it and every
/// other rank connects to it.
pub master_addr: String,
/// This process's rank; must be smaller than `world_size`.
pub rank: usize,
/// Total number of participating ranks; rendezvous rejects values below 2.
pub world_size: usize,
/// Address the per-rank peer listener binds to. The resulting local address
/// must be IPv4, because peer advertisements only encode IPv4 addresses.
pub bind_addr: SocketAddr,
}
impl RendezvousConfig {
    /// Builds the rendezvous configuration from the process environment.
    ///
    /// Reads `MASTER_ADDR`, `MASTER_PORT`, `RANK` and `WORLD_SIZE`. The master
    /// address is stored as `"{MASTER_ADDR}:{MASTER_PORT}"` and the peer
    /// listener address is fixed to the IPv4 loopback with an OS-assigned port.
    ///
    /// # Errors
    /// Returns `DistributedError::Io` when:
    /// - any of the four variables cannot be read (unset or not valid Unicode),
    /// - `RANK` or `WORLD_SIZE` is not a valid `usize`,
    /// - `MASTER_PORT` is not a valid `u16` port number.
    pub fn from_env() -> GlooResult<Self> {
        /// Fetches one required environment variable.
        fn env(key: &str) -> GlooResult<String> {
            std::env::var(key).map_err(|_| DistributedError::Io {
                message: format!("gloo_native rendezvous: env var `{key}` is not set"),
            })
        }

        /// Parses `raw` as `T`, naming the offending variable on failure.
        fn parse_env<T>(key: &str, raw: &str) -> GlooResult<T>
        where
            T: std::str::FromStr,
            T::Err: std::fmt::Display,
        {
            raw.parse::<T>().map_err(|e| DistributedError::Io {
                message: format!("gloo_native rendezvous: env `{key}` parse: {e}"),
            })
        }

        let master_addr_host = env("MASTER_ADDR")?;
        let master_port = env("MASTER_PORT")?;
        let rank_raw = env("RANK")?;
        let world_size_raw = env("WORLD_SIZE")?;

        let rank = parse_env::<usize>("RANK", &rank_raw)?;
        let world_size = parse_env::<usize>("WORLD_SIZE", &world_size_raw)?;
        // Reject a malformed port here rather than letting it surface later as
        // an opaque bind failure on rank 0 or, on workers, only after the
        // connect retry loop has run out its full timeout. The raw string is
        // still what gets stored, so valid input yields the same address.
        parse_env::<u16>("MASTER_PORT", &master_port)?;

        Ok(Self {
            master_addr: format!("{master_addr_host}:{master_port}"),
            rank,
            world_size,
            bind_addr: SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0)),
        })
    }
}
/// Runs the complete rendezvous for this rank and returns its peer connections.
///
/// Steps, in order:
/// 1. validate `world_size` and `rank`,
/// 2. bind the peer listener on `cfg.bind_addr` and encode its local address,
/// 3. exchange advertisements through the master (rank 0 runs the master side,
///    every other rank runs the worker side),
/// 4. build the pairwise connections from the resulting peer table.
///
/// # Errors
/// - `DistributedError::InvalidWorldSize` when `world_size < 2`.
/// - `DistributedError::InvalidRank` when `rank >= world_size`.
/// - `DistributedError::Io` for bind/address failures, plus whatever the
///   exchange and mesh steps return.
pub(super) fn rendezvous(cfg: &RendezvousConfig) -> GlooResult<PeerStreams> {
    let (rank, world_size) = (cfg.rank, cfg.world_size);
    if world_size < 2 {
        return Err(DistributedError::InvalidWorldSize { world_size });
    }
    if rank >= world_size {
        return Err(DistributedError::InvalidRank { rank, world_size });
    }

    // The listener is bound before talking to the master so that the address
    // we advertise is already accepting connections.
    let peer_listener = match TcpListener::bind(cfg.bind_addr) {
        Ok(listener) => listener,
        Err(e) => {
            return Err(DistributedError::Io {
                message: format!("gloo_native rank {} bind peer listener: {e}", cfg.rank),
            });
        }
    };
    let my_addr = match peer_listener.local_addr() {
        Ok(addr) => addr,
        Err(e) => {
            return Err(DistributedError::Io {
                message: format!("gloo_native rank {} local_addr: {e}", cfg.rank),
            });
        }
    };
    let my_ad = encode_peer_ad(&my_addr)?;

    let peer_table = match rank {
        0 => run_master(cfg, my_ad),
        _ => run_worker(cfg, my_ad),
    }?;

    form_full_mesh(cfg, &peer_listener, &peer_table)
}
/// Serializes an IPv4 socket address into a fixed-size peer advertisement.
///
/// Layout: bytes `0..4` are the address octets, the remaining bytes are the
/// port in little-endian order.
///
/// # Errors
/// Returns `DistributedError::Io` for IPv6 addresses, which are not supported.
fn encode_peer_ad(addr: &SocketAddr) -> GlooResult<[u8; PEER_AD_BYTES]> {
    if let SocketAddr::V4(v4) = addr {
        let mut ad = [0u8; PEER_AD_BYTES];
        let (ip_part, port_part) = ad.split_at_mut(4);
        ip_part.copy_from_slice(&v4.ip().octets());
        port_part.copy_from_slice(&v4.port().to_le_bytes());
        return Ok(ad);
    }
    Err(DistributedError::Io {
        message: "gloo_native rendezvous: IPv6 bind_addr is not supported".to_string(),
    })
}
/// Reconstructs the IPv4 socket address stored in a peer advertisement.
///
/// The first four bytes are the address octets; the last two are the port in
/// little-endian order.
fn decode_peer_ad(buf: [u8; PEER_AD_BYTES]) -> SocketAddr {
    let [a, b, c, d, port_lo, port_hi] = buf;
    let port = u16::from_le_bytes([port_lo, port_hi]);
    SocketAddrV4::new(Ipv4Addr::new(a, b, c, d), port).into()
}
/// Master (rank 0) side of the advertisement exchange.
///
/// Binds `cfg.master_addr`, accepts one connection per non-zero rank, reads
/// each worker's rank (8 bytes, little-endian `u64`) and advertisement, and
/// then sends the complete table — one advertisement per rank, in rank order —
/// back over every accepted connection.
///
/// `my_ad` is the master's own advertisement and fills slot 0. The caller is
/// expected to have validated `cfg.world_size` already (`rendezvous` requires
/// it to be at least 2).
///
/// # Errors
/// - `DistributedError::Io` on bind/accept/read/write failures, or when the
///   same rank registers more than once.
/// - `DistributedError::InvalidRank` when a worker announces rank 0 or a rank
///   outside `0..world_size`.
fn run_master(
    cfg: &RendezvousConfig,
    my_ad: [u8; PEER_AD_BYTES],
) -> GlooResult<Vec<[u8; PEER_AD_BYTES]>> {
    let listener = TcpListener::bind(&cfg.master_addr).map_err(|e| DistributedError::Io {
        message: format!("gloo_native rank 0 bind {}: {e}", cfg.master_addr),
    })?;

    let mut peer_table = vec![[0u8; PEER_AD_BYTES]; cfg.world_size];
    peer_table[0] = my_ad;
    // Tracks which ranks have already registered so a duplicate announcement
    // cannot overwrite one slot while leaving another zero-filled.
    let mut registered = vec![false; cfg.world_size];
    registered[0] = true;

    let mut worker_conns: Vec<TcpStream> = Vec::with_capacity(cfg.world_size - 1);
    for _ in 1..cfg.world_size {
        let (mut stream, _) = listener.accept().map_err(|e| DistributedError::Io {
            message: format!("gloo_native rank 0 accept: {e}"),
        })?;

        let mut rank_buf = [0u8; 8];
        stream
            .read_exact(&mut rank_buf)
            .map_err(|e| DistributedError::Io {
                message: format!("gloo_native rank 0 read peer rank: {e}"),
            })?;
        // A wire value that does not fit in `usize` can never be a valid rank.
        // Saturate it so the range check below rejects it, instead of letting
        // an `as` cast wrap it into range on 32-bit targets.
        let peer_rank = <usize as std::convert::TryFrom<u64>>::try_from(u64::from_le_bytes(
            rank_buf,
        ))
        .unwrap_or(usize::MAX);
        if peer_rank == 0 || peer_rank >= cfg.world_size {
            return Err(DistributedError::InvalidRank {
                rank: peer_rank,
                world_size: cfg.world_size,
            });
        }
        if registered[peer_rank] {
            return Err(DistributedError::Io {
                message: format!(
                    "gloo_native rank 0: rank {peer_rank} registered more than once"
                ),
            });
        }
        registered[peer_rank] = true;

        let mut peer_ad = [0u8; PEER_AD_BYTES];
        stream
            .read_exact(&mut peer_ad)
            .map_err(|e| DistributedError::Io {
                message: format!("gloo_native rank 0 read peer ad: {e}"),
            })?;
        peer_table[peer_rank] = peer_ad;
        worker_conns.push(stream);
    }

    // Every worker receives the same flattened table.
    let flat: Vec<u8> = peer_table.iter().flatten().copied().collect();
    for mut conn in worker_conns {
        conn.write_all(&flat).map_err(|e| DistributedError::Io {
            message: format!("gloo_native rank 0 broadcast peer table: {e}"),
        })?;
        conn.flush().map_err(|e| DistributedError::Io {
            message: format!("gloo_native rank 0 flush peer table: {e}"),
        })?;
    }
    Ok(peer_table)
}
/// Worker (non-zero rank) side of the advertisement exchange.
///
/// Connects to `cfg.master_addr`, retrying every `RENDEZVOUS_RETRY_INTERVAL`
/// until `RENDEZVOUS_RETRY_TIMEOUT` has elapsed. It then sends this rank
/// (8 bytes, little-endian `u64`) followed by `my_ad`, and finally reads back
/// the full peer table: `world_size` advertisements in rank order.
///
/// # Errors
/// Returns `DistributedError::Io` when the connection cannot be established
/// (retries exhausted, or `master_addr` is not a usable socket address) or
/// when any read/write on the master connection fails.
fn run_worker(
    cfg: &RendezvousConfig,
    my_ad: [u8; PEER_AD_BYTES],
) -> GlooResult<Vec<[u8; PEER_AD_BYTES]>> {
    let deadline = std::time::Instant::now() + RENDEZVOUS_RETRY_TIMEOUT;
    let mut stream = loop {
        match TcpStream::connect(&cfg.master_addr) {
            Ok(s) => break s,
            // `InvalidInput` means the address string itself is unusable (for
            // example a malformed port); retrying cannot fix that, so only
            // other failures are retried until the deadline.
            Err(ref e)
                if e.kind() != std::io::ErrorKind::InvalidInput
                    && std::time::Instant::now() < deadline =>
            {
                std::thread::sleep(RENDEZVOUS_RETRY_INTERVAL);
            }
            Err(e) => {
                return Err(DistributedError::Io {
                    message: format!(
                        "gloo_native rank {} connect to {} (after retries): {e}",
                        cfg.rank, cfg.master_addr,
                    ),
                });
            }
        }
    };

    stream
        .write_all(&(cfg.rank as u64).to_le_bytes())
        .map_err(|e| DistributedError::Io {
            message: format!("gloo_native rank {} announce rank: {e}", cfg.rank),
        })?;
    stream.write_all(&my_ad).map_err(|e| DistributedError::Io {
        message: format!("gloo_native rank {} announce ad: {e}", cfg.rank),
    })?;
    stream.flush().map_err(|e| DistributedError::Io {
        message: format!("gloo_native rank {} flush announce: {e}", cfg.rank),
    })?;

    // The master replies only once every rank has registered, so this read
    // blocks until the whole group has checked in.
    let mut flat = vec![0u8; cfg.world_size * PEER_AD_BYTES];
    stream
        .read_exact(&mut flat)
        .map_err(|e| DistributedError::Io {
            message: format!("gloo_native rank {} read peer table: {e}", cfg.rank),
        })?;

    let peer_table: Vec<[u8; PEER_AD_BYTES]> = flat
        .chunks_exact(PEER_AD_BYTES)
        .map(|chunk| {
            let mut ad = [0u8; PEER_AD_BYTES];
            ad.copy_from_slice(chunk);
            ad
        })
        .collect();
    Ok(peer_table)
}
/// Establishes one TCP connection between this rank and every other rank.
///
/// Higher ranks dial lower ranks: this rank first accepts one connection on
/// `peer_listener` per rank above it, reading each dialer's rank (8 bytes,
/// little-endian `u64`). It then connects to every rank below it at the
/// address decoded from `peer_table` and announces its own rank the same way.
///
/// `peer_table` must hold one advertisement per rank, indexed by rank.
///
/// Returns a table indexed by rank whose own-rank entry is `None`.
///
/// # Errors
/// - `DistributedError::Io` on accept/connect/read/write/clone failures, or
///   when two inbound connections announce the same rank.
/// - `DistributedError::InvalidRank` when an inbound connection announces a
///   rank that is not strictly between `cfg.rank` and `cfg.world_size`.
fn form_full_mesh(
    cfg: &RendezvousConfig,
    peer_listener: &TcpListener,
    peer_table: &[[u8; PEER_AD_BYTES]],
) -> GlooResult<PeerStreams> {
    let mut streams: Vec<Option<TcpStream>> = (0..cfg.world_size).map(|_| None).collect();

    // Inbound half: one connection from every higher rank, in arrival order.
    for _ in (cfg.rank + 1)..cfg.world_size {
        let (mut stream, _) = peer_listener.accept().map_err(|e| DistributedError::Io {
            message: format!("gloo_native rank {} accept full-mesh peer: {e}", cfg.rank,),
        })?;
        let mut rank_buf = [0u8; 8];
        stream
            .read_exact(&mut rank_buf)
            .map_err(|e| DistributedError::Io {
                message: format!(
                    "gloo_native rank {} read full-mesh peer rank: {e}",
                    cfg.rank,
                ),
            })?;
        // A wire value that does not fit in `usize` can never be a valid rank.
        // Saturate it so the range check below rejects it, instead of letting
        // an `as` cast wrap it into range on 32-bit targets.
        let peer_rank = <usize as std::convert::TryFrom<u64>>::try_from(u64::from_le_bytes(
            rank_buf,
        ))
        .unwrap_or(usize::MAX);
        if peer_rank <= cfg.rank || peer_rank >= cfg.world_size {
            return Err(DistributedError::InvalidRank {
                rank: peer_rank,
                world_size: cfg.world_size,
            });
        }
        // A repeated rank would replace the earlier stream and leave some
        // other peer's slot empty, so it is reported as an error.
        let slot = &mut streams[peer_rank];
        if slot.is_some() {
            return Err(DistributedError::Io {
                message: format!(
                    "gloo_native rank {}: full-mesh peer rank {peer_rank} connected more than once",
                    cfg.rank,
                ),
            });
        }
        *slot = Some(stream);
    }

    // Outbound half: dial every lower rank and tell it who we are.
    for (peer, ad) in peer_table[..cfg.rank].iter().enumerate() {
        let peer_addr = decode_peer_ad(*ad);
        let mut stream = TcpStream::connect(peer_addr).map_err(|e| DistributedError::Io {
            message: format!(
                "gloo_native rank {} connect full-mesh peer {peer} at {peer_addr}: {e}",
                cfg.rank,
            ),
        })?;
        stream
            .write_all(&(cfg.rank as u64).to_le_bytes())
            .map_err(|e| DistributedError::Io {
                message: format!(
                    "gloo_native rank {} announce to full-mesh peer {peer}: {e}",
                    cfg.rank,
                ),
            })?;
        stream.flush().map_err(|e| DistributedError::Io {
            message: format!("gloo_native rank {} flush to peer {peer}: {e}", cfg.rank),
        })?;
        streams[peer] = Some(stream);
    }

    // Split every stream into its reader/writer pair; our own slot stays empty.
    streams
        .into_iter()
        .enumerate()
        .map(|(i, opt)| {
            if i == cfg.rank {
                Ok(None)
            } else {
                opt.map(PeerConn::from_stream).transpose()
            }
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// IPv4 loopback with an OS-assigned port, used for every peer listener.
    fn loopback_any_port() -> SocketAddr {
        SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0))
    }

    /// Config that never reaches the network because validation rejects it first.
    fn unvalidated_cfg(rank: usize, world_size: usize) -> RendezvousConfig {
        RendezvousConfig {
            master_addr: "127.0.0.1:0".to_string(),
            rank,
            world_size,
            bind_addr: loopback_any_port(),
        }
    }

    /// Runs `world_size` ranks on separate threads against a fresh loopback
    /// master address and returns each rank's connection table in rank order.
    fn run_n_rank_rendezvous(world_size: usize) -> Vec<PeerStreams> {
        // Borrow a free port from the OS, then release it for rank 0 to bind.
        let probe = TcpListener::bind("127.0.0.1:0").expect("probe bind");
        let master_addr = probe.local_addr().expect("local_addr").to_string();
        drop(probe);

        // All ranks must be running before any of them is joined.
        let mut workers = Vec::with_capacity(world_size);
        for rank in 0..world_size {
            let master_addr = master_addr.clone();
            workers.push(thread::spawn(move || {
                let cfg = RendezvousConfig {
                    master_addr,
                    rank,
                    world_size,
                    bind_addr: loopback_any_port(),
                };
                rendezvous(&cfg).expect("rendezvous")
            }));
        }
        workers
            .into_iter()
            .map(|worker| worker.join().expect("join"))
            .collect()
    }

    #[test]
    fn rendezvous_full_mesh_n2() {
        let conns = run_n_rank_rendezvous(2);
        assert_eq!(conns.len(), 2);
        let (rank0, rank1) = (&conns[0], &conns[1]);
        assert!(rank0[0].is_none());
        assert!(rank0[1].is_some());
        assert!(rank1[0].is_some());
        assert!(rank1[1].is_none());
    }

    #[test]
    fn rendezvous_full_mesh_n4_all_slots_filled() {
        let conns = run_n_rank_rendezvous(4);
        assert_eq!(conns.len(), 4);
        for (rank, slots) in conns.iter().enumerate() {
            assert_eq!(slots.len(), 4);
            for (peer, slot) in slots.iter().enumerate() {
                if peer != rank {
                    assert!(
                        slot.is_some(),
                        "rank {rank}: slot for peer {peer} must be Some"
                    );
                    continue;
                }
                assert!(slot.is_none(), "rank {rank}: self-slot must be None");
            }
        }
    }

    #[test]
    fn rendezvous_rejects_world_size_below_two() {
        let cfg = unvalidated_cfg(0, 1);
        match rendezvous(&cfg).expect_err("must reject world_size=1") {
            DistributedError::InvalidWorldSize { world_size } => assert_eq!(world_size, 1),
            other => panic!("expected InvalidWorldSize, got {other:?}"),
        }
    }

    #[test]
    fn rendezvous_rejects_rank_out_of_range() {
        let cfg = unvalidated_cfg(5, 4);
        match rendezvous(&cfg).expect_err("must reject rank >= world_size") {
            DistributedError::InvalidRank { rank, world_size } => {
                assert_eq!(rank, 5);
                assert_eq!(world_size, 4);
            }
            other => panic!("expected InvalidRank, got {other:?}"),
        }
    }

    #[test]
    fn peer_ad_round_trip() {
        let addr: SocketAddr = SocketAddrV4::new(Ipv4Addr::LOCALHOST, 51234).into();
        let back = decode_peer_ad(encode_peer_ad(&addr).expect("encode"));
        assert_eq!(addr, back);
    }
}