use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::{Arc, OnceLock};
use thiserror::Error;
use tokio::sync::RwLock;
/// Errors returned by process setup, communicator construction and the (placeholder)
/// collective operations in this module.
#[derive(Error, Debug, Clone)]
pub enum ProcessError {
/// A global query (`rank()`, `size()`, `finalize`) ran before `init()` succeeded.
#[error("Process not initialized - call init() first")]
NotInitialized,
/// `init()` ran after the world communicator had already been set.
#[error("Process already initialized")]
AlreadyInitialized,
/// A rank was out of range, or not a member of the group/communicator it was looked up in.
#[error("Invalid rank {rank}, must be < {size}")]
InvalidRank { rank: usize, size: usize },
/// A communicator or group size was invalid (e.g. an empty group).
#[error("Invalid communicator size: {0}")]
InvalidSize(usize),
/// Transport-level failure. NOTE(review): not constructed in this file — presumably
/// used by networking code elsewhere.
#[error("Network error: {0}")]
NetworkError(String),
/// Invalid configuration, e.g. an address that could not be built during `init()`.
#[error("Configuration error: {0}")]
ConfigError(String),
/// Barrier failure. NOTE(review): not constructed in this file.
#[error("Barrier failed: {0}")]
BarrierFailed(String),
/// Wraps failures while building a sub-communicator in `split`/`create_group`.
#[error("Split operation failed: {0}")]
SplitFailed(String),
/// Generic communication failure. NOTE(review): not constructed in this file.
#[error("Communication error: {0}")]
CommunicationError(String),
}
/// Identity of one process within a communicator.
///
/// Values returned by `ProcessInfo::new` satisfy `rank < size`. The fields are public,
/// so a struct literal can bypass that check.
#[derive(Debug, Clone)]
pub struct ProcessInfo {
/// Zero-based rank of this process; rank 0 is the root (see `is_root`).
pub rank: usize,
/// Number of processes in the communicator this info belongs to.
pub size: usize,
/// Socket address of this process (`init` stores its bind address here).
pub addr: SocketAddr,
/// Host name of this process (`init` reads `HOSTNAME`, then `COMPUTERNAME`).
pub hostname: String,
}
impl ProcessInfo {
    /// Creates a validated description of one process in a communicator of `size` processes.
    ///
    /// # Errors
    ///
    /// * [`ProcessError::InvalidSize`] if `size` is zero.
    /// * [`ProcessError::InvalidRank`] if `rank >= size`.
    pub fn new(
        rank: usize,
        size: usize,
        addr: SocketAddr,
        hostname: String,
    ) -> Result<Self, ProcessError> {
        // The size must be checked first: when `size == 0` every rank is out of range,
        // so checking the rank first made the `InvalidSize` branch unreachable.
        if size == 0 {
            return Err(ProcessError::InvalidSize(size));
        }
        if rank >= size {
            return Err(ProcessError::InvalidRank { rank, size });
        }
        Ok(Self {
            rank,
            size,
            addr,
            hostname,
        })
    }

    /// Returns `true` if this process has rank 0.
    pub fn is_root(&self) -> bool {
        self.rank == 0
    }
}
/// Ordered set of ranks forming a (sub-)group, with lookups in both directions.
///
/// A rank's position in `ranks` is its local rank. `local_to_global` and
/// `global_to_local` translate between that local rank and the rank's number in the
/// parent numbering; for groups built by `Communicator::create_group`, that is the
/// parent communicator's ranks. `ProcessGroup::new` fills the maps. All fields are
/// public, so mutating them directly can make the maps disagree.
#[derive(Debug, Clone)]
pub struct ProcessGroup {
/// Member ranks in parent numbering, indexed by local rank.
pub ranks: Vec<usize>,
/// Local rank -> parent ("global") rank.
pub local_to_global: HashMap<usize, usize>,
/// Parent ("global") rank -> local rank.
pub global_to_local: HashMap<usize, usize>,
}
impl ProcessGroup {
    /// Builds a group whose local rank `i` is `ranks[i]`.
    ///
    /// # Errors
    ///
    /// * [`ProcessError::InvalidSize`] if `ranks` is empty.
    /// * [`ProcessError::ConfigError`] if a rank appears more than once. A duplicate
    ///   would overwrite its reverse mapping and leave `local_to_global` and
    ///   `global_to_local` inconsistent.
    pub fn new(ranks: Vec<usize>) -> Result<Self, ProcessError> {
        if ranks.is_empty() {
            return Err(ProcessError::InvalidSize(0));
        }
        let mut local_to_global = HashMap::with_capacity(ranks.len());
        let mut global_to_local = HashMap::with_capacity(ranks.len());
        for (local_rank, &global_rank) in ranks.iter().enumerate() {
            // `insert` returns the previous value, which exposes duplicates in one lookup.
            if global_to_local.insert(global_rank, local_rank).is_some() {
                return Err(ProcessError::ConfigError(format!(
                    "Duplicate rank {} in process group",
                    global_rank
                )));
            }
            local_to_global.insert(local_rank, global_rank);
        }
        Ok(Self {
            ranks,
            local_to_global,
            global_to_local,
        })
    }

    /// Number of members in the group.
    pub fn size(&self) -> usize {
        self.ranks.len()
    }

    /// Translates a local rank into the corresponding parent ("global") rank.
    ///
    /// # Errors
    ///
    /// [`ProcessError::InvalidRank`] if `local_rank` is not a local rank of this group.
    pub fn local_to_global_rank(&self, local_rank: usize) -> Result<usize, ProcessError> {
        self.local_to_global
            .get(&local_rank)
            .copied()
            .ok_or_else(|| ProcessError::InvalidRank {
                rank: local_rank,
                size: self.size(),
            })
    }

    /// Translates a parent ("global") rank into its local rank in this group.
    ///
    /// # Errors
    ///
    /// [`ProcessError::InvalidRank`] if `global_rank` is not a member. The reported
    /// `size` is the group size, even though a non-member may be numerically smaller.
    pub fn global_to_local_rank(&self, global_rank: usize) -> Result<usize, ProcessError> {
        self.global_to_local
            .get(&global_rank)
            .copied()
            .ok_or_else(|| ProcessError::InvalidRank {
                rank: global_rank,
                size: self.size(),
            })
    }

    /// Returns `true` if `global_rank` is a member of this group.
    pub fn contains(&self, global_rank: usize) -> bool {
        self.global_to_local.contains_key(&global_rank)
    }
}
/// Handle to a group of cooperating processes.
///
/// Cloning is cheap. All state lives behind `Arc`, so clones share the same info,
/// group, address table and barrier counter.
#[derive(Clone)]
pub struct Communicator {
/// This process's identity (rank and size in this communicator).
info: Arc<ProcessInfo>,
/// Member ranks of this communicator, in parent numbering.
group: Arc<ProcessGroup>,
/// Rank -> socket address table consulted by `address()`; may be incomplete.
addresses: Arc<HashMap<usize, SocketAddr>>,
/// Number of `barrier()` calls since the last reset; shared by all clones.
barrier_counter: Arc<RwLock<usize>>,
}
impl Communicator {
    /// Builds a communicator from already-validated parts.
    ///
    /// `addresses` is looked up by [`Communicator::address`] using this communicator's
    /// own rank numbering (the same numbering as [`Communicator::rank`]). It may be
    /// incomplete; lookups for missing ranks then fail.
    ///
    /// # Errors
    ///
    /// Currently never fails. The `Result` return type is kept for API stability.
    pub fn new(
        info: ProcessInfo,
        group: ProcessGroup,
        addresses: HashMap<usize, SocketAddr>,
    ) -> Result<Self, ProcessError> {
        Ok(Self {
            info: Arc::new(info),
            group: Arc::new(group),
            addresses: Arc::new(addresses),
            barrier_counter: Arc::new(RwLock::new(0)),
        })
    }

    /// Rank of this process within this communicator.
    pub fn rank(&self) -> usize {
        self.info.rank
    }

    /// Number of processes in this communicator.
    pub fn size(&self) -> usize {
        self.info.size
    }

    /// Identity of this process (rank, size, address, host name).
    pub fn process_info(&self) -> &ProcessInfo {
        &self.info
    }

    /// Member ranks of this communicator, in parent numbering.
    pub fn group(&self) -> &ProcessGroup {
        &self.group
    }

    /// Returns `true` if this process has rank 0 in this communicator.
    pub fn is_root(&self) -> bool {
        self.info.is_root()
    }

    /// Looks up the socket address recorded for `rank` of this communicator.
    ///
    /// # Errors
    ///
    /// [`ProcessError::InvalidRank`] if the address table has no entry for `rank`.
    pub fn address(&self, rank: usize) -> Result<SocketAddr, ProcessError> {
        self.addresses
            .get(&rank)
            .copied()
            .ok_or_else(|| ProcessError::InvalidRank {
                rank,
                size: self.size(),
            })
    }

    /// Process-local placeholder for a barrier.
    ///
    /// Increments a counter shared by this communicator and its clones, and resets it
    /// once it reaches `size()`. It does not communicate with or wait for any other
    /// process, so it provides no cross-process synchronization. Always returns `Ok(())`.
    pub async fn barrier(&self) -> Result<(), ProcessError> {
        let mut arrivals = self.barrier_counter.write().await;
        *arrivals += 1;
        if *arrivals >= self.size() {
            *arrivals = 0;
        }
        Ok(())
    }

    /// Splits this communicator by `color`.
    ///
    /// NOTE(review): this is a local placeholder, not a collective split. `color` is
    /// currently ignored. The result is always a singleton communicator containing only
    /// this process (rank 0, size 1). Its address table holds this process's own
    /// address, so `address(0)` resolves.
    ///
    /// # Errors
    ///
    /// Construction failures are wrapped in [`ProcessError::SplitFailed`].
    pub async fn split(&self, color: usize) -> Result<Communicator, ProcessError> {
        // Unused until a real collective split exists; bound here to silence the warning.
        let _ = color;
        let new_group = ProcessGroup::new(vec![self.rank()])
            .map_err(|e| ProcessError::SplitFailed(e.to_string()))?;
        let new_info = ProcessInfo::new(
            0,
            new_group.size(),
            self.info.addr,
            self.info.hostname.clone(),
        )
        .map_err(|e| ProcessError::SplitFailed(e.to_string()))?;
        // Previously empty, which made `address(rank())` fail on the result.
        let new_addresses = HashMap::from([(0_usize, self.info.addr)]);
        Communicator::new(new_info, new_group, new_addresses)
            .map_err(|e| ProcessError::SplitFailed(e.to_string()))
    }

    /// Creates a communicator from the listed ranks of this communicator.
    ///
    /// The position of a rank in `ranks` becomes its rank in the new communicator. The
    /// address table is re-keyed to that new numbering, so `address()` on the result
    /// agrees with its `rank()`/`size()`. Ranks with no known address here are omitted
    /// from the new table.
    ///
    /// # Errors
    ///
    /// * [`ProcessError::InvalidRank`] if any rank is `>= self.size()`, or if this
    ///   process is not listed in `ranks`.
    /// * [`ProcessError::SplitFailed`] if the group or process info cannot be built
    ///   (e.g. `ranks` is empty).
    pub async fn create_group(&self, ranks: &[usize]) -> Result<Communicator, ProcessError> {
        if let Some(&rank) = ranks.iter().find(|&&r| r >= self.size()) {
            return Err(ProcessError::InvalidRank {
                rank,
                size: self.size(),
            });
        }
        let new_group = ProcessGroup::new(ranks.to_vec())
            .map_err(|e| ProcessError::SplitFailed(e.to_string()))?;
        let new_rank = new_group.global_to_local_rank(self.rank())?;
        let new_info = ProcessInfo::new(
            new_rank,
            new_group.size(),
            self.info.addr,
            self.info.hostname.clone(),
        )
        .map_err(|e| ProcessError::SplitFailed(e.to_string()))?;
        // Key by the *new* local rank. Keying by the parent rank (as before) made
        // `address(rank())` on the new communicator look up the wrong entry.
        let new_addresses: HashMap<usize, SocketAddr> = ranks
            .iter()
            .enumerate()
            .filter_map(|(local, &parent)| self.address(parent).ok().map(|addr| (local, addr)))
            .collect();
        Communicator::new(new_info, new_group, new_addresses)
            .map_err(|e| ProcessError::SplitFailed(e.to_string()))
    }
}
impl std::fmt::Debug for Communicator {
    /// Prints a compact summary: this process's rank, the communicator size, and
    /// whether this process is the root. Addresses and group members are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut summary = f.debug_struct("Communicator");
        summary.field("rank", &self.rank());
        summary.field("size", &self.size());
        summary.field("is_root", &self.is_root());
        summary.finish()
    }
}
/// The communicator spanning all `NUMRS2_SIZE` ranks, as returned by `init()`.
/// It is a plain alias of [`Communicator`].
pub type WorldCommunicator = Communicator;
/// Process-wide world communicator. `init()` sets it at most once (a `OnceLock` cannot
/// be reset); `rank()`, `size()` and `finalize()` read it.
static GLOBAL_WORLD: OnceLock<WorldCommunicator> = OnceLock::new();
/// Reads environment variable `key` and parses it as `T`.
///
/// Returns `Ok(None)` when the variable is unset. A variable that is set but is not
/// valid Unicode, or does not parse, is reported as [`ProcessError::ConfigError`]
/// instead of being silently replaced by a default.
fn env_parse<T>(key: &str) -> Result<Option<T>, ProcessError>
where
    T: std::str::FromStr,
    T::Err: std::fmt::Display,
{
    match std::env::var(key) {
        Ok(raw) => raw.trim().parse().map(Some).map_err(|e| {
            ProcessError::ConfigError(format!("Invalid value {:?} for {}: {}", raw, key, e))
        }),
        Err(std::env::VarError::NotPresent) => Ok(None),
        Err(e) => Err(ProcessError::ConfigError(format!(
            "Invalid value for {}: {}",
            key, e
        ))),
    }
}

/// Default loopback address for `rank`: `127.0.0.1:(5000 + rank)`.
///
/// Returns [`ProcessError::ConfigError`] if the port would not fit in a `u16`, rather
/// than truncating (`as u16`) or overflowing.
fn default_local_addr(rank: usize) -> Result<SocketAddr, ProcessError> {
    const BASE_PORT: u16 = 5000;
    let port = u16::try_from(rank)
        .ok()
        .and_then(|r| BASE_PORT.checked_add(r))
        .ok_or_else(|| {
            ProcessError::ConfigError(format!(
                "Invalid address: default port for rank {} exceeds {}",
                rank,
                u16::MAX
            ))
        })?;
    Ok(SocketAddr::from(([127, 0, 0, 1], port)))
}

/// Initializes the process-wide world communicator from environment variables.
///
/// Recognized variables:
/// * `NUMRS2_RANK`: this process's rank (default `0`).
/// * `NUMRS2_SIZE`: number of processes (default `1`).
/// * `NUMRS2_BIND_ADDR`: this process's socket address (default `127.0.0.1:(5000 + rank)`).
/// * `HOSTNAME`, then `COMPUTERNAME`: host name (default `"localhost"`).
///
/// The address table assumes every rank `r` is reachable at `127.0.0.1:(5000 + r)`.
/// `NUMRS2_MASTER_ADDR` is not read.
///
/// # Errors
///
/// * [`ProcessError::AlreadyInitialized`] if a world communicator already exists.
/// * [`ProcessError::ConfigError`] if one of the numeric/address variables above is set
///   but malformed, or if a default port would exceed `u16::MAX`.
/// * [`ProcessError::InvalidRank`] / [`ProcessError::InvalidSize`] from rank/size validation.
pub async fn init() -> Result<WorldCommunicator, ProcessError> {
    if GLOBAL_WORLD.get().is_some() {
        return Err(ProcessError::AlreadyInitialized);
    }
    let rank: usize = env_parse("NUMRS2_RANK")?.unwrap_or(0);
    let size: usize = env_parse("NUMRS2_SIZE")?.unwrap_or(1);
    // Derive the default only when no override is given, so a valid override is never
    // rejected because of the default's port arithmetic.
    let bind_addr: SocketAddr = match env_parse("NUMRS2_BIND_ADDR")? {
        Some(addr) => addr,
        None => default_local_addr(rank)?,
    };
    let hostname = std::env::var("HOSTNAME")
        .or_else(|_| std::env::var("COMPUTERNAME"))
        .unwrap_or_else(|_| "localhost".to_string());
    let info = ProcessInfo::new(rank, size, bind_addr, hostname)?;
    let group = ProcessGroup::new((0..size).collect())?;
    let addresses = (0..size)
        .map(|r| default_local_addr(r).map(|addr| (r, addr)))
        .collect::<Result<HashMap<_, _>, _>>()?;
    let world = Communicator::new(info, group, addresses)?;
    // `set` can still fail if a concurrent caller got past the check above first.
    GLOBAL_WORLD
        .set(world.clone())
        .map_err(|_| ProcessError::AlreadyInitialized)?;
    Ok(world)
}
/// Shuts down the environment started by `init()`.
///
/// Currently this only verifies that `init()` has run. The passed communicator is
/// dropped, but the global world is not cleared, so a later `init()` still reports
/// `AlreadyInitialized`.
///
/// # Errors
///
/// [`ProcessError::NotInitialized`] if `init()` has not succeeded.
pub async fn finalize(_world: WorldCommunicator) -> Result<(), ProcessError> {
    match GLOBAL_WORLD.get() {
        Some(_) => Ok(()),
        None => Err(ProcessError::NotInitialized),
    }
}
/// Rank of this process in the world communicator.
///
/// # Errors
///
/// [`ProcessError::NotInitialized`] if `init()` has not succeeded yet.
pub fn rank() -> Result<usize, ProcessError> {
    let Some(world) = GLOBAL_WORLD.get() else {
        return Err(ProcessError::NotInitialized);
    };
    Ok(world.rank())
}
/// Number of processes in the world communicator.
///
/// # Errors
///
/// [`ProcessError::NotInitialized`] if `init()` has not succeeded yet.
pub fn size() -> Result<usize, ProcessError> {
    match GLOBAL_WORLD.get() {
        Some(world) => Ok(world.size()),
        None => Err(ProcessError::NotInitialized),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Loopback address shared by the fixtures below.
    fn loopback() -> SocketAddr {
        "127.0.0.1:5000".parse().expect("Valid address")
    }

    #[test]
    fn test_process_info() {
        let info =
            ProcessInfo::new(0, 4, loopback(), "localhost".to_string()).expect("Valid info");
        assert!(info.is_root());
        assert_eq!((info.rank, info.size), (0, 4));
        assert_eq!(info.hostname, "localhost");
    }

    #[test]
    fn test_process_info_invalid_rank() {
        let result = ProcessInfo::new(5, 4, loopback(), "localhost".to_string());
        assert!(result.is_err());
        let Err(ProcessError::InvalidRank { rank, size }) = result else {
            panic!("Expected InvalidRank error");
        };
        assert_eq!((rank, size), (5, 4));
    }

    #[test]
    fn test_process_group() {
        let group = ProcessGroup::new(vec![0, 2, 4, 6]).expect("Valid group");
        assert_eq!(group.size(), 4);
        for (local, global) in [(0, 0), (1, 2)] {
            assert_eq!(group.local_to_global_rank(local).expect("Valid"), global);
        }
        assert_eq!(group.global_to_local_rank(4).expect("Valid"), 2);
        for member in [0, 4] {
            assert!(group.contains(member));
        }
        for outsider in [1, 3] {
            assert!(!group.contains(outsider));
        }
    }

    #[test]
    fn test_process_group_empty() {
        assert!(ProcessGroup::new(Vec::new()).is_err());
    }

    #[tokio::test]
    async fn test_communicator_creation() {
        let info =
            ProcessInfo::new(0, 4, loopback(), "localhost".to_string()).expect("Valid info");
        let group = ProcessGroup::new((0..4).collect()).expect("Valid group");
        let comm = Communicator::new(info, group, HashMap::new()).expect("Valid communicator");
        assert!(comm.is_root());
        assert_eq!((comm.rank(), comm.size()), (0, 4));
    }
}