#![allow(unexpected_cfgs)]
use crate::backend::{Backend, BackendConfig, BackendType, MockBackend};
use crate::{TorshDistributedError, TorshResult};
use parking_lot::RwLock;
use std::sync::Arc;
/// Unsigned 32-bit alias used for the `rank` value carried by a [`ProcessGroup`].
pub type Rank = u32;
/// Unsigned 32-bit alias used for the `world_size` value carried by a [`ProcessGroup`].
pub type WorldSize = u32;
/// Bundles an initialized communication backend with the rank, world size and
/// master endpoint that were supplied when the group was created.
pub struct ProcessGroup {
/// Backend trait object, shared via `Arc` and guarded by a `parking_lot::RwLock`.
backend: Arc<RwLock<Box<dyn Backend>>>,
/// Rank supplied to `ProcessGroup::new`; returned by `rank()`.
rank: Rank,
/// World size supplied to `ProcessGroup::new`; returned by `world_size()`.
world_size: WorldSize,
/// Master address string copied from the `master_addr` argument of `new`.
// Stored by `new` but never read in the visible code, hence the lint allowance.
#[allow(dead_code)]
master_addr: String,
/// Master port copied from the `master_port` argument of `new`.
// Stored by `new` but never read in the visible code, hence the lint allowance.
#[allow(dead_code)]
master_port: u16,
}
impl ProcessGroup {
    /// Builds a process group backed by the requested backend type.
    ///
    /// The backend is obtained from `create_backend`, initialized with
    /// `BackendConfig::default()`, and then wrapped in `Arc<RwLock<..>>`.
    /// `master_addr` and `master_port` are copied into the returned value.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `create_backend` or by the backend's
    /// `init` call.
    pub async fn new(
        backend_type: BackendType,
        rank: Rank,
        world_size: WorldSize,
        master_addr: &str,
        master_port: u16,
    ) -> TorshResult<Self> {
        let mut transport = create_backend(backend_type, rank, world_size)?;

        // NOTE(review): the backend is initialized with the default config only;
        // `master_addr` / `master_port` are not forwarded to it here — confirm
        // whether that is intentional.
        transport.init(BackendConfig::default()).await?;

        Ok(Self {
            backend: Arc::new(RwLock::new(transport)),
            rank,
            world_size,
            master_addr: master_addr.to_owned(),
            master_port,
        })
    }

    /// Returns the rank this group was created with.
    pub fn rank(&self) -> Rank {
        self.rank
    }

    /// Returns the world size this group was created with.
    pub fn world_size(&self) -> WorldSize {
        self.world_size
    }

    /// Reports the backend type, as answered by the wrapped backend.
    ///
    /// Takes a read lock on the backend for the duration of the call.
    pub fn backend_type(&self) -> BackendType {
        let guard = self.backend.read();
        guard.backend_type()
    }

    /// Gives access to the shared, lock-guarded backend handle.
    pub fn backend(&self) -> &Arc<RwLock<Box<dyn Backend>>> {
        &self.backend
    }
}
/// Constructs the boxed backend for `backend_type`.
///
/// * `Gloo` always yields a `MockBackend` tagged as `Gloo`.
/// * `Nccl` / `Mpi` yield a `MockBackend` tagged with the matching type when the
///   `nccl` / `mpi` cargo feature is enabled; otherwise they produce a
///   `feature_not_available` error.
/// * `Custom(_)` always produces a `feature_not_available` error.
fn create_backend(
    backend_type: BackendType,
    rank: Rank,
    world_size: WorldSize,
) -> TorshResult<Box<dyn Backend>> {
    // Every arm that succeeds builds the same mock, differing only in its tag.
    let mock = |kind: BackendType| -> Box<dyn Backend> {
        Box::new(MockBackend::with_backend_type(rank, world_size, kind))
    };

    match backend_type {
        #[cfg(feature = "nccl")]
        BackendType::Nccl => Ok(mock(BackendType::Nccl)),
        #[cfg(not(feature = "nccl"))]
        BackendType::Nccl => Err(TorshDistributedError::feature_not_available(
            "NCCL backend",
            "nccl",
        )),

        #[cfg(feature = "mpi")]
        BackendType::Mpi => Ok(mock(BackendType::Mpi)),
        #[cfg(not(feature = "mpi"))]
        BackendType::Mpi => Err(TorshDistributedError::feature_not_available(
            "MPI backend",
            "mpi",
        )),

        BackendType::Gloo => Ok(mock(BackendType::Gloo)),

        BackendType::Custom(label) => Err(TorshDistributedError::feature_not_available(
            format!("Custom backend: {}", label),
            "custom backend implementation",
        )),
    }
}