use super::{get_distributed_state, DistributedBackend, ProcessGroup, ReduceOp};
use crate::error::{RusTorchError, RusTorchResult};
use crate::nn::Module;
use crate::tensor::Tensor;
use num_traits::Float;
use std::sync::{Arc, Mutex};
use std::time::Duration;
pub fn init_process_group(
backend: DistributedBackend,
init_method: Option<&str>,
world_size: Option<usize>,
rank: Option<usize>,
timeout: Option<Duration>,
) -> RusTorchResult<()> {
let (master_addr, master_port) = if let Some(init_method) = init_method {
parse_init_method(init_method)?
} else {
get_master_info_from_env()?
};
let world_size = world_size
.or_else(|| {
std::env::var("WORLD_SIZE")
.ok()
.and_then(|s| s.parse().ok())
})
.ok_or_else(|| RusTorchError::distributed("WORLD_SIZE not specified"))?;
let rank = rank
.or_else(|| std::env::var("RANK").ok().and_then(|s| s.parse().ok()))
.ok_or_else(|| RusTorchError::distributed("RANK not specified"))?;
let process_group = ProcessGroup::new(rank, world_size, backend, master_addr, master_port);
if let Some(_timeout) = timeout {
}
process_group.init()?;
let state = get_distributed_state();
let mut state_guard = state.lock().unwrap();
state_guard.set_process_group(process_group);
Ok(())
}
pub fn destroy_process_group() -> RusTorchResult<()> {
super::finalize()
}
pub fn get_rank() -> usize {
super::get_rank().unwrap_or(0)
}
pub fn get_world_size() -> usize {
super::get_world_size().unwrap_or(1)
}
pub fn is_initialized() -> bool {
super::is_available()
}
pub fn all_reduce<T: Float + Send + Sync + 'static>(
tensor: &mut Tensor<T>,
op: ReduceOp,
group: Option<&ProcessGroup>,
async_op: bool,
) -> RusTorchResult<Option<DistributedRequest>> {
if !is_initialized() {
return Err(RusTorchError::distributed("Process group not initialized"));
}
let _ = (group, async_op);
super::common::CommonOps::default_all_reduce(tensor, op)?;
Ok(None) }
pub fn all_gather<T: Float + Send + Sync + 'static>(
tensor_list: &mut Vec<Tensor<T>>,
tensor: &Tensor<T>,
group: Option<&ProcessGroup>,
async_op: bool,
) -> RusTorchResult<Option<DistributedRequest>> {
if !is_initialized() {
return Err(RusTorchError::distributed("Process group not initialized"));
}
let world_size = get_world_size();
let gathered = super::common::CommonOps::default_all_gather(tensor, world_size)?;
*tensor_list = gathered;
let _ = (group, async_op);
Ok(None)
}
pub fn broadcast<T: Float + Send + Sync + 'static>(
tensor: &mut Tensor<T>,
src: usize,
group: Option<&ProcessGroup>,
async_op: bool,
) -> RusTorchResult<Option<DistributedRequest>> {
if !is_initialized() {
return Err(RusTorchError::distributed("Process group not initialized"));
}
super::common::CommonOps::default_broadcast(tensor, src)?;
let _ = (group, async_op);
Ok(None)
}
pub fn reduce<T: Float + Send + Sync + 'static>(
tensor: &mut Tensor<T>,
dst: usize,
op: ReduceOp,
group: Option<&ProcessGroup>,
async_op: bool,
) -> RusTorchResult<Option<DistributedRequest>> {
if !is_initialized() {
return Err(RusTorchError::distributed("Process group not initialized"));
}
if get_rank() == dst {
super::common::CommonOps::default_all_reduce(tensor, op)?;
}
let _ = (group, async_op);
Ok(None)
}
pub fn scatter<T: Float + Send + Sync + 'static>(
tensor: &mut Tensor<T>,
scatter_list: Option<&[Tensor<T>]>,
src: usize,
group: Option<&ProcessGroup>,
async_op: bool,
) -> RusTorchResult<Option<DistributedRequest>> {
if !is_initialized() {
return Err(RusTorchError::distributed("Process group not initialized"));
}
if let Some(tensors) = scatter_list {
if get_rank() == src {
let scattered = super::common::CommonOps::default_gather(
&tensors[get_rank()],
get_world_size(),
src,
)?;
if !scattered.is_empty() {
*tensor = scattered[get_rank()].clone();
}
}
}
let _ = (group, async_op);
Ok(None)
}
pub fn gather<T: Float + Send + Sync + 'static>(
tensor: &Tensor<T>,
gather_list: Option<&mut Vec<Tensor<T>>>,
dst: usize,
group: Option<&ProcessGroup>,
async_op: bool,
) -> RusTorchResult<Option<DistributedRequest>> {
if !is_initialized() {
return Err(RusTorchError::distributed("Process group not initialized"));
}
if get_rank() == dst {
if let Some(list) = gather_list {
let gathered = super::common::CommonOps::default_gather(tensor, get_world_size(), dst)?;
*list = gathered;
}
}
let _ = (group, async_op);
Ok(None)
}
pub fn barrier(
group: Option<&ProcessGroup>,
async_op: bool,
) -> RusTorchResult<Option<DistributedRequest>> {
if !is_initialized() {
return Err(RusTorchError::distributed("Process group not initialized"));
}
let _ = (group, async_op);
Ok(None)
}
#[derive(Debug)]
pub struct DistributedRequest {
completed: Arc<Mutex<bool>>,
}
impl DistributedRequest {
pub fn new() -> Self {
Self {
completed: Arc::new(Mutex::new(false)),
}
}
pub fn wait(&self) -> RusTorchResult<()> {
let mut completed = self.completed.lock().unwrap();
*completed = true;
Ok(())
}
pub fn is_completed(&self) -> bool {
*self.completed.lock().unwrap()
}
}
impl Default for DistributedRequest {
fn default() -> Self {
Self::new()
}
}
pub struct Group {
process_group: ProcessGroup,
}
impl Group {
pub fn new(ranks: Vec<usize>) -> RusTorchResult<Self> {
if ranks.is_empty() {
return Err(RusTorchError::distributed("Empty ranks list"));
}
let process_group = ProcessGroup::new(
ranks[0],
ranks.len(),
DistributedBackend::Gloo,
"localhost".to_string(),
29500,
);
Ok(Self { process_group })
}
pub fn size(&self) -> usize {
self.process_group.world_size
}
pub fn rank(&self) -> usize {
self.process_group.rank
}
}
pub fn new_group(
ranks: Vec<usize>,
timeout: Option<Duration>,
backend: Option<DistributedBackend>,
) -> RusTorchResult<Group> {
let _ = (timeout, backend); Group::new(ranks)
}
pub mod monitoring {
use super::*;
pub fn get_communication_stats() -> RusTorchResult<CommunicationStats> {
Ok(CommunicationStats::new())
}
#[derive(Debug, Clone)]
pub struct CommunicationStats {
pub total_bytes_sent: u64,
pub total_bytes_received: u64,
pub total_operations: u64,
pub average_latency_ms: f64,
}
impl CommunicationStats {
pub fn new() -> Self {
Self {
total_bytes_sent: 0,
total_bytes_received: 0,
total_operations: 0,
average_latency_ms: 0.0,
}
}
}
impl Default for CommunicationStats {
fn default() -> Self {
Self::new()
}
}
}
fn parse_init_method(init_method: &str) -> RusTorchResult<(String, u16)> {
if init_method.starts_with("tcp://") {
let addr_port = init_method.strip_prefix("tcp://").unwrap();
if let Some((addr, port_str)) = addr_port.split_once(':') {
let port = port_str
.parse::<u16>()
.map_err(|_| RusTorchError::distributed("Invalid port in init_method"))?;
Ok((addr.to_string(), port))
} else {
Err(RusTorchError::distributed("Invalid init_method format"))
}
} else {
Err(RusTorchError::distributed("Unsupported init_method scheme"))
}
}
fn get_master_info_from_env() -> RusTorchResult<(String, u16)> {
let addr = std::env::var("MASTER_ADDR")
.map_err(|_| RusTorchError::distributed("MASTER_ADDR not set"))?;
let port = std::env::var("MASTER_PORT")
.map_err(|_| RusTorchError::distributed("MASTER_PORT not set"))?
.parse::<u16>()
.map_err(|_| RusTorchError::distributed("Invalid MASTER_PORT"))?;
Ok((addr, port))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_init_method() {
let result = parse_init_method("tcp://127.0.0.1:29500");
assert!(result.is_ok());
let (addr, port) = result.unwrap();
assert_eq!(addr, "127.0.0.1");
assert_eq!(port, 29500);
}
#[test]
fn test_parse_init_method_invalid() {
let result = parse_init_method("invalid://127.0.0.1:29500");
assert!(result.is_err());
let result = parse_init_method("tcp://127.0.0.1");
assert!(result.is_err());
}
#[test]
fn test_distributed_request() {
let req = DistributedRequest::new();
assert!(!req.is_completed());
assert!(req.wait().is_ok());
assert!(req.is_completed());
}
#[test]
fn test_group_creation() {
let ranks = vec![0, 1, 2, 3];
let group = Group::new(ranks);
assert!(group.is_ok());
let group = group.unwrap();
assert_eq!(group.size(), 4);
assert_eq!(group.rank(), 0);
}
#[test]
fn test_group_creation_empty() {
let ranks = vec![];
let group = Group::new(ranks);
assert!(group.is_err());
}
}