use rustorch::autograd::Variable;
use rustorch::distributed::*;
use rustorch::error::RusTorchResult;
use rustorch::nn::Linear;
use rustorch::tensor::Tensor;
use std::time::Duration;
/// Checks that a single-process TCP process group can be created and queried.
///
/// Any process group left behind by another test is torn down first, mirroring
/// the other tests in this file. When initialization fails the test prints the
/// error and returns early instead of failing.
#[test]
fn test_distributed_initialization() {
    // Start from a clean state; the result is ignored because there may be
    // nothing to destroy.
    let _ = destroy_process_group();

    std::env::set_var("RANK", "0");
    std::env::set_var("WORLD_SIZE", "1");
    std::env::set_var("MASTER_ADDR", "localhost");
    std::env::set_var("MASTER_PORT", "29500");

    let result = init_process_group(
        DistributedBackend::TCP,
        Some("tcp://localhost:29500"),
        Some(1),
        Some(0),
        Some(Duration::from_secs(30)),
    );
    if result.is_err() {
        println!(
            "Skipping distributed initialization test - not supported in this environment: {:?}",
            result
        );
        return;
    }

    // World size 1 / rank 0 were passed explicitly above.
    assert!(is_initialized());
    assert_eq!(api::get_rank(), 0);
    assert_eq!(api::get_world_size(), 1);

    let _ = destroy_process_group();
}
/// Wraps a `Linear` layer via `wrap_module` and checks the shape of a forward pass.
///
/// The test returns early (after printing the reason) when the process group
/// cannot be initialized or the wrapper cannot be created.
#[test]
fn test_ddp_wrapper() {
    let _ = destroy_process_group();

    for (key, value) in [
        ("RANK", "0"),
        ("WORLD_SIZE", "1"),
        ("MASTER_ADDR", "localhost"),
        ("MASTER_PORT", "29500"),
    ] {
        std::env::set_var(key, value);
    }

    let init_outcome = init_process_group(DistributedBackend::TCP, None, None, None, None);
    if init_outcome.is_err() {
        println!(
            "Skipping DDP wrapper test - initialization failed: {:?}",
            init_outcome
        );
        return;
    }

    let layer: Linear<f32> = Linear::new(10, 5);
    let ddp = match wrap_module(layer, Some(vec![0])) {
        Ok(wrapped) => wrapped,
        // Bind the whole `Result` so the printed text keeps its `Err(..)` wrapper.
        failure @ Err(_) => {
            println!(
                "Skipping DDP wrapper test - DDP not available: {:?}",
                failure
            );
            destroy_process_group().ok();
            return;
        }
    };

    // A batch of 2 samples with 10 features each.
    let batch = Variable::new(Tensor::randn(&[2, 10]), false);
    let output_var = match ddp.forward(&batch) {
        Ok(var) => var,
        failure @ Err(_) => panic!("DDP forward pass failed: {:?}", failure),
    };

    let output_data = output_var.data();
    let output_guard = output_data.read().unwrap();
    assert_eq!(output_guard.shape(), &[2, 5]);

    let _ = destroy_process_group();
}
/// Runs a sum all-reduce on a 3x3 tensor of ones in a single-process group.
///
/// Only checks that the call can be made; a failed all-reduce is reported as a
/// skip rather than a test failure.
#[test]
fn test_all_reduce_operation() {
    let _ = destroy_process_group();

    for (key, value) in [
        ("RANK", "0"),
        ("WORLD_SIZE", "1"),
        ("MASTER_ADDR", "localhost"),
        ("MASTER_PORT", "29501"),
    ] {
        std::env::set_var(key, value);
    }

    let init_outcome = init_process_group(DistributedBackend::TCP, None, None, None, None);
    if init_outcome.is_err() {
        println!(
            "Skipping test - distributed initialization failed: {:?}",
            init_outcome
        );
        return;
    }

    let mut ones: Tensor<f32> = Tensor::ones(&[3, 3]);
    let reduce_outcome = all_reduce(&mut ones, ReduceOp::Sum, None, false);
    if reduce_outcome.is_err() {
        println!("Skipping all-reduce operation test - distributed backend not available in CI");
    }

    // Both the success and the skip path end by tearing the group down.
    let _ = destroy_process_group();
}
/// Broadcasts a random 2x2 tensor from rank 0 in a single-process group.
///
/// Only checks that the call can be made; a failed broadcast is reported as a
/// skip rather than a test failure. The tensor contents are not compared, so no
/// copy of the original data is kept.
#[test]
fn test_broadcast_operation() {
    let _ = destroy_process_group();

    std::env::set_var("RANK", "0");
    std::env::set_var("WORLD_SIZE", "1");
    std::env::set_var("MASTER_ADDR", "localhost");
    std::env::set_var("MASTER_PORT", "29502");

    let init_result = init_process_group(DistributedBackend::TCP, None, None, None, None);
    if init_result.is_err() {
        println!(
            "Skipping test - distributed initialization failed: {:?}",
            init_result
        );
        return;
    }

    let mut tensor: Tensor<f32> = Tensor::randn(&[2, 2]);
    let result = broadcast(&mut tensor, 0, None, false);
    if result.is_err() {
        println!("Skipping broadcast operation test - distributed backend not available in CI");
        destroy_process_group().ok();
        return;
    }

    let _ = destroy_process_group();
}
/// Runs a forward pass through a wrapped `Linear` layer and then calls
/// `sync_gradients`, failing if synchronization returns an error.
///
/// Initialization or wrapper-creation failures are printed and treated as a skip.
#[test]
fn test_gradient_synchronization() {
    let _ = destroy_process_group();

    for (key, value) in [
        ("RANK", "0"),
        ("WORLD_SIZE", "1"),
        ("MASTER_ADDR", "localhost"),
        ("MASTER_PORT", "29503"),
    ] {
        std::env::set_var(key, value);
    }

    let init_outcome = init_process_group(DistributedBackend::TCP, None, None, None, None);
    if init_outcome.is_err() {
        println!(
            "Skipping test - distributed initialization failed: {:?}",
            init_outcome
        );
        return;
    }

    let layer: Linear<f32> = Linear::new(5, 3);
    let wrapped = wrap_module(layer, Some(vec![0]));
    let ddp = match wrapped {
        Err(e) => {
            println!(
                "Skipping gradient synchronization test - DDP not available: {:?}",
                e
            );
            destroy_process_group().ok();
            return;
        }
        Ok(module) => module,
    };

    // A batch of 2 samples with 5 features each.
    let batch = Variable::new(Tensor::randn(&[2, 5]), false);
    let _forward_output = ddp.forward(&batch).unwrap();

    let sync_outcome = ddp.sync_gradients();
    if sync_outcome.is_err() {
        panic!("Gradient synchronization failed: {:?}", sync_outcome);
    }

    let _ = destroy_process_group();
}
/// Times a sum all-reduce for three tensor shapes and prints each duration.
///
/// Nothing is asserted about the timings; a shape whose all-reduce fails is
/// reported as skipped and the loop moves on to the next one.
#[test]
fn test_distributed_performance() {
    let _ = destroy_process_group();

    for (key, value) in [
        ("RANK", "0"),
        ("WORLD_SIZE", "1"),
        ("MASTER_ADDR", "localhost"),
        ("MASTER_PORT", "29504"),
    ] {
        std::env::set_var(key, value);
    }

    let init_outcome = init_process_group(DistributedBackend::TCP, None, None, None, None);
    if init_outcome.is_err() {
        println!(
            "Skipping test - distributed initialization failed: {:?}",
            init_outcome
        );
        return;
    }

    let shapes = [[100, 100], [1000, 100], [1000, 1000]];
    for shape in shapes {
        let mut tensor: Tensor<f32> = Tensor::randn(&shape);

        // Only the all-reduce call itself is inside the timed region.
        let timer = std::time::Instant::now();
        let outcome = all_reduce(&mut tensor, ReduceOp::Sum, None, false);
        let elapsed = timer.elapsed();

        match outcome {
            Err(_) => println!("Skipping performance test for size {:?} - distributed operation not available in CI", shape),
            Ok(_) => println!("All-reduce for {:?}: {:?}", shape, elapsed),
        }
    }

    let _ = destroy_process_group();
}
/// Exercises the API without an initialized process group.
///
/// After tearing down any existing group, the test asserts that nothing is
/// initialized, prints a warning if an all-reduce nevertheless succeeds, and
/// asserts that creating a group from an empty rank list is rejected.
#[test]
fn test_distributed_error_handling() {
    let _ = destroy_process_group();
    assert!(!is_initialized());

    let mut ones: Tensor<f32> = Tensor::ones(&[2, 2]);
    let reduce_outcome = all_reduce(&mut ones, ReduceOp::Sum, None, false);
    // A success here is only reported, not treated as a failure.
    if let Ok(_) = &reduce_outcome {
        println!(
            "Warning: All-reduce succeeded without initialization - unexpected in test environment"
        );
    }

    let empty_group = new_group(vec![], None, None);
    assert!(empty_group.is_err(), "Empty group creation should fail");
}
/// Checks the configuration returned by `NCCLOps::get_optimal_config` for two
/// argument pairs (only compiled with the `nccl` feature).
///
/// `(8, 32.0)` is expected to enable compression with a 50 MB bucket, while
/// `(2, 8.0)` is expected to disable compression with a 25 MB bucket.
#[cfg(feature = "nccl")]
#[test]
fn test_nccl_specific_features() {
    use rustorch::distributed::nccl_integration::{NCCLOps, NCCLOptimizations};

    let larger_setup = NCCLOps::get_optimal_config(8, 32.0);
    assert!(larger_setup.compression_enabled);
    assert_eq!(larger_setup.bucket_size_mb, 50);

    let smaller_setup = NCCLOps::get_optimal_config(2, 8.0);
    assert!(!smaller_setup.compression_enabled);
    assert_eq!(smaller_setup.bucket_size_mb, 25);
}
/// Constructs a `MultiGpuValidator<f32>` and prints how many devices it reports.
///
/// The only assertion is that construction succeeds; the device count is
/// printed but not checked.
#[test]
fn test_multi_gpu_validation() {
    let created: RusTorchResult<MultiGpuValidator<f32>> = MultiGpuValidator::new();
    assert!(created.is_ok());

    let validator = created.unwrap();
    let devices = validator.get_devices();
    let gpu_count = devices.len();
    println!("Detected {} GPUs", gpu_count);
}
/// Simulates one training step: forward pass through a wrapped `Linear` layer,
/// output-shape check, and gradient synchronization.
///
/// Initialization or wrapper-creation failures are printed and treated as a
/// skip (`Ok(())`).
///
/// # Errors
///
/// Returns the error from the forward pass, from `sync_gradients`, or from
/// `destroy_process_group`. The process group is torn down before any of these
/// errors is propagated, so a failing step does not leave it initialized.
#[test]
fn test_distributed_training_scenario() -> RusTorchResult<()> {
    let _ = destroy_process_group();

    std::env::set_var("RANK", "0");
    std::env::set_var("WORLD_SIZE", "1");
    std::env::set_var("MASTER_ADDR", "localhost");
    std::env::set_var("MASTER_PORT", "29505");

    if let Err(e) = init_process_group(DistributedBackend::TCP, None, None, None, None) {
        println!(
            "Skipping distributed training scenario test - initialization failed in CI: {:?}",
            e
        );
        return Ok(());
    }

    let model: Linear<f32> = Linear::new(784, 10);
    let ddp_model = match wrap_module(model, Some(vec![0])) {
        Ok(ddp) => ddp,
        Err(e) => {
            println!(
                "Skipping training scenario - DDP wrapper creation failed: {:?}",
                e
            );
            destroy_process_group().ok();
            return Ok(());
        }
    };

    let batch_size = 32;
    let input = Variable::new(Tensor::randn(&[batch_size, 784]), false);
    // Placeholder target: no loss is computed in this scenario.
    let _target: Variable<f32> = Variable::new(Tensor::randn(&[batch_size, 10]), false);

    // Run the fallible part in a closure so that an early `?` return cannot
    // skip the process-group teardown below.
    let step = (|| -> RusTorchResult<()> {
        let output = ddp_model.forward(&input)?;
        let output_data = output.data();
        let output_guard = output_data.read().unwrap();
        assert_eq!(output_guard.shape(), &[batch_size, 10]);
        ddp_model.sync_gradients()?;
        Ok(())
    })();

    // Always tear down, then report the step error first, the teardown error second.
    let teardown = destroy_process_group();
    step?;
    teardown?;
    Ok(())
}