use std::process::{Command, Stdio};
use std::time::Duration;
use tokio::time::timeout;
use torsh_core::Result;
use torsh_distributed::{
backend::BackendType,
backend::ReduceOp,
collectives::{all_reduce, barrier},
init_process_group, ProcessGroup,
};
use torsh_tensor::creation::{full, ones};
use torsh_tensor::Tensor;
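
/// Configuration for a simulated multi-node test run.
///
/// The total world size is `num_nodes * processes_per_node`. Workers
/// rendezvous at `master_addr:master_port`; `timeout_secs` bounds the
/// collective operations in the tests below.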
#[derive(Debug, Clone)]
pub struct MultiNodeTestConfig {
pub num_nodes: u32,
pub processes_per_node: u32,
pub master_addr: String,
pub master_port: u16,
pub timeout_secs: u64,
}
impl Default for MultiNodeTestConfig {
fn default() -> Self {
Self {
num_nodes: 2,
processes_per_node: 2,
master_addr: "127.0.0.1".to_string(),
master_port: 29500,
timeout_secs: 30,
}
}
}
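
/// Spawns and supervises worker processes for a multi-node test.
///
/// A minimal usage sketch (assuming a `test_multi_node_worker` test target
/// exists for `spawn_workers` to launch):
///
/// ```ignore
/// let mut coordinator = MultiNodeTestCoordinator::new(MultiNodeTestConfig::default());
/// coordinator.spawn_workers()?;
/// coordinator.wait_for_completion()?;
/// ```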
pub struct MultiNodeTestCoordinator {
config: MultiNodeTestConfig,
processes: Vec<std::process::Child>,
}
impl MultiNodeTestCoordinator {
pub fn new(config: MultiNodeTestConfig) -> Self {
Self {
config,
processes: Vec::new(),
}
}
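
    /// Launches one worker process per (node, local rank) pair via
    /// `cargo test`, passing the rendezvous configuration to each worker
    /// through `TORSH_*` environment variables.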
pub fn spawn_workers(&mut self) -> Result<()> {
let world_size = self.config.num_nodes * self.config.processes_per_node;
for node_id in 0..self.config.num_nodes {
for local_rank in 0..self.config.processes_per_node {
let global_rank = node_id * self.config.processes_per_node + local_rank;
let child = Command::new("cargo")
.args(["test", "--test", "test_multi_node_worker"])
.env("TORSH_RANK", global_rank.to_string())
.env("TORSH_WORLD_SIZE", world_size.to_string())
.env("TORSH_MASTER_ADDR", &self.config.master_addr)
.env("TORSH_MASTER_PORT", self.config.master_port.to_string())
.env("TORSH_NODE_ID", node_id.to_string())
.env("TORSH_LOCAL_RANK", local_rank.to_string())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()
.map_err(|e| {
torsh_core::TorshError::Other(format!("Failed to spawn worker: {}", e))
})?;
self.processes.push(child);
}
}
Ok(())
}
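
    /// Waits for every spawned worker to exit, failing on the first worker
    /// that reports a non-zero exit status.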
pub fn wait_for_completion(&mut self) -> Result<()> {
let processes = std::mem::take(&mut self.processes);
for process in processes {
let output = process.wait_with_output().map_err(|e| {
torsh_core::TorshError::Other(format!("Process wait failed: {}", e))
})?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
return Err(torsh_core::TorshError::Other(format!(
"Worker process failed: {}",
stderr
)));
}
}
Ok(())
}
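
    /// Kills and reaps any workers that are still running.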
pub fn cleanup(&mut self) {
for process in &mut self.processes {
let _ = process.kill();
let _ = process.wait();
}
self.processes.clear();
}
}
impl Drop for MultiNodeTestCoordinator {
fn drop(&mut self) {
self.cleanup();
}
}
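
/// Entry point executed by each spawned worker process.
///
/// Reads rank, world size, and rendezvous address from the `TORSH_*`
/// environment variables, joins the process group, and validates a Sum
/// all-reduce across all ranks.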
pub async fn multi_node_worker() -> Result<()> {
    // Read the distributed configuration from the environment, falling back
    // to single-process defaults when a variable is absent or malformed.
    let rank: u32 = std::env::var("TORSH_RANK")
        .unwrap_or_else(|_| "0".to_string())
        .parse()
        .unwrap_or(0);
    let world_size: u32 = std::env::var("TORSH_WORLD_SIZE")
        .unwrap_or_else(|_| "1".to_string())
        .parse()
        .unwrap_or(1);
    let master_addr =
        std::env::var("TORSH_MASTER_ADDR").unwrap_or_else(|_| "127.0.0.1".to_string());
    let master_port: u16 = std::env::var("TORSH_MASTER_PORT")
        .unwrap_or_else(|_| "29500".to_string())
        .parse()
        .unwrap_or(29500);
let pg = init_process_group(
BackendType::Gloo,
rank,
world_size,
&master_addr,
master_port,
)
.await?;
barrier(&pg).await?;
    // Each rank contributes a tensor filled with (rank + 1); after a Sum
    // all-reduce every element should equal 1 + 2 + ... + world_size.
    let mut tensor = full::<f32>(&[4, 4], rank as f32 + 1.0)?;
    all_reduce(&mut tensor, ReduceOp::Sum, &pg).await?;
    let expected_sum = (1..=world_size).sum::<u32>() as f32;
    let data = tensor.to_vec()?;
    for value in data {
        assert!(
            (value - expected_sum).abs() < 1e-5_f32,
            "AllReduce result incorrect: got {}, expected {}",
            value,
            expected_sum
        );
    }
barrier(&pg).await?;
Ok(())
}
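
/// Two ranks initialized in the same process synchronize at a barrier.
/// This relies on the backend supporting an in-process rendezvous
/// (e.g. a mock transport); a real two-node run would use separate
/// processes via `MultiNodeTestCoordinator`.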
#[tokio::test]
async fn test_two_node_communication() -> Result<()> {
let config = MultiNodeTestConfig {
num_nodes: 2,
processes_per_node: 1,
timeout_secs: 30,
..Default::default()
};
let pg1 = init_process_group(BackendType::Gloo, 0, 2, "127.0.0.1", 29501).await?;
let pg2 = init_process_group(BackendType::Gloo, 1, 2, "127.0.0.1", 29501).await?;
    // `timeout` guards against a hung barrier; the final `?` also surfaces
    // backend errors instead of silently discarding them.
    let barrier_result = timeout(Duration::from_secs(config.timeout_secs), async {
        barrier(&pg1).await?;
        barrier(&pg2).await?;
        Result::Ok(())
    })
    .await;
    barrier_result.expect("Barrier synchronization timed out")?;
    Ok(())
}
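
/// Simulates gradient aggregation across four in-process "nodes". With the
/// mock backend the all-reduce is a no-op, so each gradient is expected to
/// retain its original value (see the assertion message below).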
#[tokio::test]
async fn test_multi_node_gradient_aggregation() -> Result<()> {
let world_size = 4;
let mut process_groups = Vec::new();
for rank in 0..world_size {
let pg =
init_process_group(BackendType::Gloo, rank, world_size, "127.0.0.1", 29502).await?;
process_groups.push(pg);
}
let mut gradients: Vec<Tensor> = Vec::new();
for i in 0..world_size as usize {
let tensor = full::<f32>(&[10, 10], i as f32 + 1.0)?;
gradients.push(tensor);
}
for (i, gradient) in gradients.iter_mut().enumerate() {
all_reduce(gradient, ReduceOp::Sum, &process_groups[i]).await?;
}
for (i, gradient) in gradients.iter().enumerate() {
let expected_value = (i + 1) as f32;
let data: Vec<f32> = gradient.to_vec()?;
for &value in &data {
assert!(
(value - expected_value).abs() < 1e-5_f32,
"Gradient should retain original value with mock: got {}, expected {}",
value,
expected_value
);
}
}
Ok(())
}
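
/// Drops one rank's process group to simulate a node failure, then checks
/// that the surviving ranks can still pass a barrier within a timeout.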
#[tokio::test]
async fn test_node_failure_simulation() -> Result<()> {
let world_size = 3;
let mut process_groups: Vec<Option<ProcessGroup>> = Vec::new();
for rank in 0..world_size {
let pg = init_process_group(BackendType::Gloo, rank, world_size, "127.0.0.1", 29503)
.await
.ok();
process_groups.push(pg);
}
    // Simulate rank 1 failing by dropping its process group.
    process_groups[1] = None;
if let (Some(pg0), Some(pg2)) = (&process_groups[0], &process_groups[2]) {
        // The surviving ranks should still be able to synchronize within
        // the timeout; propagate any inner backend error.
        let result = timeout(Duration::from_secs(5), async {
            barrier(pg0).await?;
            barrier(pg2).await?;
            Result::Ok(())
        })
        .await;
        result.expect("Communication timed out after simulated node failure")?;
    }
Ok(())
}
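
/// Starts a two-rank group, then has a third rank join with an expanded
/// world size on the same rendezvous port.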
#[tokio::test]
async fn test_dynamic_node_joining() -> Result<()> {
let initial_world_size = 2;
let initial_pg_futures: Vec<_> = (0..initial_world_size)
.map(|rank| {
init_process_group(
BackendType::Gloo,
rank,
initial_world_size,
"127.0.0.1",
29504,
)
})
.collect();
let initial_pg_results = futures_util::future::join_all(initial_pg_futures).await;
let initial_pgs: Vec<ProcessGroup> = initial_pg_results
.into_iter()
.collect::<std::result::Result<Vec<_>, _>>()
.map_err(|e| {
torsh_core::TorshError::Other(format!("Failed to initialize process groups: {:?}", e))
})?;
for pg in &initial_pgs {
barrier(pg).await?;
}
    // A third rank joins with an expanded world size on the same rendezvous
    // port; the original group was created with world_size = 2, so this
    // relies on the backend tolerating the mismatch.
    let expanded_world_size = 3;
let new_node_pg = init_process_group(
BackendType::Gloo,
2,
expanded_world_size,
"127.0.0.1",
29504,
)
.await?;
barrier(&new_node_pg).await?;
Ok(())
}
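
/// Runs three rounds of all-reduce across four in-process ranks and checks
/// element-wise agreement between all ranks after each round.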
#[tokio::test]
async fn test_cross_node_data_consistency() -> Result<()> {
let world_size = 4;
let pg_futures: Vec<_> = (0..world_size)
.map(|rank| init_process_group(BackendType::Gloo, rank, world_size, "127.0.0.1", 29505))
.collect();
let pg_results = futures_util::future::join_all(pg_futures).await;
let process_groups: Vec<ProcessGroup> = pg_results
.into_iter()
.collect::<std::result::Result<Vec<_>, _>>()
.map_err(|e| {
torsh_core::TorshError::Other(format!("Failed to initialize process groups: {:?}", e))
})?;
let initial_data = ones::<f32>(&[5, 5])?.mul_scalar(2.0)?;
let mut node_data: Vec<Tensor> = vec![initial_data; world_size as usize];
for round in 0..3 {
for (i, data) in node_data.iter_mut().enumerate() {
all_reduce(data, ReduceOp::Sum, &process_groups[i]).await?;
}
let reference_data = node_data[0].to_vec()?;
for (node_idx, data) in node_data.iter().enumerate() {
let node_data_vec = data.to_vec()?;
assert_eq!(
reference_data.len(),
node_data_vec.len(),
"Data length mismatch at node {} round {}",
node_idx,
round
);
for (i, (&ref_val, &node_val)) in
reference_data.iter().zip(node_data_vec.iter()).enumerate()
{
assert!(
                    (ref_val - node_val).abs() < 1e-6_f32,
"Data inconsistency at node {} element {} round {}: {} vs {}",
node_idx,
i,
round,
ref_val,
node_val
);
}
}
}
Ok(())
}