use std::io::{Read, Write};
use std::net::{TcpListener, TcpStream};
use std::sync::mpsc::{self, Receiver, Sender};
use std::sync::{Arc, Mutex};
use ferrotorch_core::FerrotorchResult;
use crate::error::DistributedError;
/// Point-to-point messaging interface implemented by the distributed backends.
///
/// The `Send + Sync` bound lets a single backend instance be shared across
/// threads (for example behind an `Arc`).
pub trait Backend: Send + Sync {
/// This process's rank within the group.
fn rank(&self) -> usize;
/// Number of ranks in the group.
fn world_size(&self) -> usize;
/// Send `data` to `dst_rank` as a single message.
fn send(&self, data: &[u8], dst_rank: usize) -> FerrotorchResult<()>;
/// Receive one message from `src_rank` into `dst`.
///
/// NOTE(review): the implementations in this module treat a message whose
/// length differs from `dst.len()` as an error rather than a partial read.
fn recv(&self, dst: &mut [u8], src_rank: usize) -> FerrotorchResult<()>;
/// Block until every rank in the group has reached the barrier.
fn barrier(&self) -> FerrotorchResult<()>;
}
pub struct TcpBackend {
rank: usize,
world_size: usize,
connections: Vec<Mutex<TcpStream>>,
}
impl TcpBackend {
pub fn new(rank: usize, world_size: usize, master_addr: &str) -> FerrotorchResult<Self> {
if world_size < 2 {
return Err(DistributedError::InvalidWorldSize { world_size }.into());
}
if rank >= world_size {
return Err(DistributedError::InvalidRank { rank, world_size }.into());
}
let mut peer_streams: Vec<Option<TcpStream>> = (0..world_size).map(|_| None).collect();
if rank == 0 {
let listener = TcpListener::bind(master_addr).map_err(|e| DistributedError::Io {
message: format!("rank 0 bind {master_addr}: {e}"),
})?;
for _ in 1..world_size {
let (mut stream, _addr) =
listener.accept().map_err(|e| DistributedError::Io {
message: format!("rank 0 accept: {e}"),
})?;
let mut rank_buf = [0u8; 8];
stream
.read_exact(&mut rank_buf)
.map_err(|e| DistributedError::Io {
message: format!("rank 0 read peer rank: {e}"),
})?;
let peer_rank = u64::from_le_bytes(rank_buf) as usize;
if peer_rank >= world_size || peer_rank == 0 {
return Err(DistributedError::InvalidRank {
rank: peer_rank,
world_size,
}
.into());
}
peer_streams[peer_rank] = Some(stream);
}
} else {
let mut stream =
TcpStream::connect(master_addr).map_err(|e| DistributedError::Io {
message: format!("rank {rank} connect to {master_addr}: {e}"),
})?;
stream
.write_all(&(rank as u64).to_le_bytes())
.map_err(|e| DistributedError::Io {
message: format!("rank {rank} announce: {e}"),
})?;
peer_streams[0] = Some(stream);
}
let connections: Vec<Mutex<TcpStream>> = peer_streams
.into_iter()
.enumerate()
.map(|(i, opt)| {
if i == rank {
Mutex::new(
opt.unwrap_or_else(|| {
TcpStream::connect("127.0.0.1:0").unwrap_or_else(|_| {
panic!("cannot create placeholder stream")
})
}),
)
} else if let Some(s) = opt {
Mutex::new(s)
} else {
panic!(
"rank {rank} has no direct connection to rank {i} (star topology limitation)"
);
}
})
.collect();
Ok(Self {
rank,
world_size,
connections,
})
}
}
impl Backend for TcpBackend {
fn rank(&self) -> usize {
self.rank
}
fn world_size(&self) -> usize {
self.world_size
}
fn send(&self, data: &[u8], dst_rank: usize) -> FerrotorchResult<()> {
if dst_rank == self.rank {
return Err(DistributedError::SelfSend { rank: self.rank }.into());
}
if dst_rank >= self.world_size {
return Err(DistributedError::InvalidRank {
rank: dst_rank,
world_size: self.world_size,
}
.into());
}
let mut stream = self.connections[dst_rank]
.lock()
.map_err(|e| DistributedError::LockPoisoned {
message: format!("send to rank {dst_rank}: {e}"),
})?;
let len_bytes = (data.len() as u64).to_le_bytes();
stream
.write_all(&len_bytes)
.map_err(|e| DistributedError::Io {
message: format!("send len to rank {dst_rank}: {e}"),
})?;
stream
.write_all(data)
.map_err(|e| DistributedError::Io {
message: format!("send data to rank {dst_rank}: {e}"),
})?;
stream.flush().map_err(|e| DistributedError::Io {
message: format!("flush to rank {dst_rank}: {e}"),
})?;
Ok(())
}
fn recv(&self, dst: &mut [u8], src_rank: usize) -> FerrotorchResult<()> {
if src_rank == self.rank {
return Err(DistributedError::SelfSend { rank: self.rank }.into());
}
if src_rank >= self.world_size {
return Err(DistributedError::InvalidRank {
rank: src_rank,
world_size: self.world_size,
}
.into());
}
let mut stream = self.connections[src_rank]
.lock()
.map_err(|e| DistributedError::LockPoisoned {
message: format!("recv from rank {src_rank}: {e}"),
})?;
let mut len_bytes = [0u8; 8];
stream
.read_exact(&mut len_bytes)
.map_err(|e| DistributedError::Io {
message: format!("recv len from rank {src_rank}: {e}"),
})?;
let len = u64::from_le_bytes(len_bytes) as usize;
if len != dst.len() {
return Err(DistributedError::SizeMismatch {
expected: dst.len(),
got: len,
}
.into());
}
stream
.read_exact(dst)
.map_err(|e| DistributedError::Io {
message: format!("recv data from rank {src_rank}: {e}"),
})?;
Ok(())
}
fn barrier(&self) -> FerrotorchResult<()> {
let tag = [0u8; 1];
if self.rank == 0 {
let mut buf = [0u8; 1];
for r in 1..self.world_size {
self.recv(&mut buf, r)?;
}
for r in 1..self.world_size {
self.send(&tag, r)?;
}
} else {
self.send(&tag, 0)?;
let mut buf = [0u8; 1];
self.recv(&mut buf, 0)?;
}
Ok(())
}
}
/// Square matrix of in-process channels: entry `[src][dst]` carries messages
/// from rank `src` to rank `dst`. Both channel halves sit behind mutexes so the
/// whole matrix can be shared through an `Arc` by every rank's backend.
type ChannelMatrix = Arc<Vec<Vec<(Mutex<Sender<Vec<u8>>>, Mutex<Receiver<Vec<u8>>>)>>>;

/// In-process [`Backend`] whose ranks exchange messages over `std::sync::mpsc`
/// channels instead of sockets.
pub struct SimulatedBackend {
    rank: usize,
    world_size: usize,
    channels: ChannelMatrix,
}

impl SimulatedBackend {
    /// Create `world_size` backends, one per rank and in rank order, all sharing
    /// a single fully connected channel matrix.
    ///
    /// # Errors
    ///
    /// Returns `DistributedError::InvalidWorldSize` when `world_size` is 0.
    pub fn create_group(world_size: usize) -> FerrotorchResult<Vec<Self>> {
        if world_size == 0 {
            return Err(DistributedError::InvalidWorldSize { world_size }.into());
        }

        let grid = (0..world_size)
            .map(|_src| {
                (0..world_size)
                    .map(|_dst| {
                        let (sender, receiver) = mpsc::channel::<Vec<u8>>();
                        (Mutex::new(sender), Mutex::new(receiver))
                    })
                    .collect::<Vec<_>>()
            })
            .collect::<Vec<_>>();
        let channels: ChannelMatrix = Arc::new(grid);

        let mut members = Vec::with_capacity(world_size);
        for rank in 0..world_size {
            members.push(Self {
                rank,
                world_size,
                channels: Arc::clone(&channels),
            });
        }
        Ok(members)
    }
}
impl Backend for SimulatedBackend {
    fn rank(&self) -> usize {
        self.rank
    }

    fn world_size(&self) -> usize {
        self.world_size
    }

    /// Queue a copy of `data` on the channel from this rank to `dst_rank`.
    ///
    /// There is no self-send check: sending to this rank's own index queues the
    /// message for a later `recv` from the same rank.
    fn send(&self, data: &[u8], dst_rank: usize) -> FerrotorchResult<()> {
        let world_size = self.world_size;
        if dst_rank >= world_size {
            return Err(DistributedError::InvalidRank {
                rank: dst_rank,
                world_size,
            }
            .into());
        }
        let (outbox, _) = &self.channels[self.rank][dst_rank];
        let sender = outbox.lock().map_err(|e| DistributedError::LockPoisoned {
            message: format!("send channel lock rank {} -> {dst_rank}: {e}", self.rank),
        })?;
        match sender.send(data.to_vec()) {
            Ok(()) => Ok(()),
            Err(e) => Err(DistributedError::ChannelClosed {
                message: format!("send rank {} -> {dst_rank}: {e}", self.rank),
            }
            .into()),
        }
    }

    /// Block until a message from `src_rank` arrives, then copy it into `dst`.
    ///
    /// The message is consumed even when its length does not match `dst.len()`.
    fn recv(&self, dst: &mut [u8], src_rank: usize) -> FerrotorchResult<()> {
        let world_size = self.world_size;
        if src_rank >= world_size {
            return Err(DistributedError::InvalidRank {
                rank: src_rank,
                world_size,
            }
            .into());
        }
        let (_, inbox) = &self.channels[src_rank][self.rank];
        let receiver = inbox.lock().map_err(|e| DistributedError::LockPoisoned {
            message: format!("recv channel lock rank {src_rank} -> {}: {e}", self.rank),
        })?;
        let message = receiver
            .recv()
            .map_err(|e| DistributedError::ChannelClosed {
                message: format!("recv rank {src_rank} -> {}: {e}", self.rank),
            })?;
        if message.len() == dst.len() {
            dst.copy_from_slice(&message);
            Ok(())
        } else {
            Err(DistributedError::SizeMismatch {
                expected: dst.len(),
                got: message.len(),
            }
            .into())
        }
    }

    /// Rank 0 collects a 1-byte token from every other rank, then sends one back
    /// to each; every other rank sends its token and waits for the reply.
    fn barrier(&self) -> FerrotorchResult<()> {
        let token = [0u8; 1];
        let mut scratch = [0u8; 1];
        if self.rank != 0 {
            self.send(&token, 0)?;
            return self.recv(&mut scratch, 0);
        }
        (1..self.world_size).try_for_each(|peer| self.recv(&mut scratch, peer))?;
        (1..self.world_size).try_for_each(|peer| self.send(&token, peer))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Create a simulated group and wrap each rank in an `Arc` so it can be
    /// moved into worker threads.
    fn shared_group(world_size: usize) -> Vec<Arc<SimulatedBackend>> {
        SimulatedBackend::create_group(world_size)
            .expect("valid world size")
            .into_iter()
            .map(Arc::new)
            .collect()
    }

    #[test]
    fn test_simulated_send_recv() {
        let group = shared_group(2);
        let sending_rank = Arc::clone(&group[0]);
        let sender = thread::spawn(move || {
            sending_rank.send(&[1, 2, 3, 4], 1).expect("send from rank 0");
        });

        let mut received = [0u8; 4];
        group[1].recv(&mut received, 0).expect("recv on rank 1");
        sender.join().expect("sender thread panicked");

        assert_eq!(received, [1, 2, 3, 4]);
    }

    #[test]
    fn test_simulated_barrier() {
        let workers: Vec<_> = shared_group(4)
            .into_iter()
            .map(|backend| thread::spawn(move || backend.barrier().expect("barrier")))
            .collect();
        for worker in workers {
            worker.join().expect("barrier thread panicked");
        }
    }

    #[test]
    fn test_simulated_rank_world_size() {
        let group = SimulatedBackend::create_group(3).unwrap();
        assert_eq!(group.len(), 3);
        for (expected_rank, backend) in group.iter().enumerate() {
            assert_eq!(backend.rank(), expected_rank);
        }
        assert_eq!(group[0].world_size(), 3);
    }

    #[test]
    fn test_invalid_world_size() {
        assert!(SimulatedBackend::create_group(0).is_err());
    }

    #[test]
    fn test_send_to_invalid_rank() {
        let group = SimulatedBackend::create_group(2).unwrap();
        assert!(group[0].send(&[1], 5).is_err());
    }
}