use anyhow::Result;
use candle_core::Tensor;
#[cfg(feature = "nccl")]
use cudarc::driver::safe::CudaDevice;
#[cfg(feature = "nccl")]
use cudarc::nccl::safe::{Comm, Id};
#[cfg(feature = "nccl")]
use std::rc::Rc;
/// Configuration describing this process's place in a multi-process
/// (multi-GPU) distributed run.
#[derive(Debug, Clone)]
pub struct DistributedConfig {
/// Total number of participating processes; 1 means single-process.
pub world_size: usize,
/// Zero-based index of this process within `world_size`.
pub rank: usize,
/// Path of the file used to hand the NCCL unique ID from rank 0 to the
/// other ranks during communicator setup.
pub comm_file: String,
}
impl Default for DistributedConfig {
fn default() -> Self {
Self {
world_size: 1,
rank: 0,
comm_file: "nccl_id.txt".to_string(),
}
}
}
impl DistributedConfig {
    /// True when more than one process participates in the run.
    pub fn is_distributed(&self) -> bool {
        1 < self.world_size
    }

    /// True for the coordinating process (rank 0), which typically owns
    /// logging, checkpointing, and ID-file creation.
    pub fn is_main_process(&self) -> bool {
        matches!(self.rank, 0)
    }
}
/// Wrapper around a `cudarc` NCCL communicator; only compiled when the
/// `nccl` feature is enabled.
#[cfg(feature = "nccl")]
pub struct NcclCommunicator {
/// Shared handle to the underlying NCCL communicator.
comm: Rc<Comm>,
/// This process's rank within the communicator.
rank: usize,
/// Total number of ranks in the communicator.
world_size: usize,
}
#[cfg(feature = "nccl")]
impl NcclCommunicator {
    /// Builds the NCCL communicator for `config.rank` of `config.world_size`.
    ///
    /// Rank 0 generates the NCCL unique ID and publishes it through
    /// `config.comm_file`; all other ranks poll for that file and read the ID
    /// back. The write goes to a `.tmp` file first and is then renamed, so
    /// readers never observe a partially written ID.
    ///
    /// # Errors
    /// Fails if the ID file cannot be written or read, the CUDA device for
    /// this rank cannot be opened, or NCCL communicator creation fails.
    pub fn new(config: &DistributedConfig) -> Result<Self> {
        use std::io::Write;
        let comm_file = std::path::PathBuf::from(&config.comm_file);
        let id = if config.rank == 0 {
            // Drop any stale ID file from a previous run so other ranks
            // cannot pick up an outdated ID.
            if comm_file.exists() {
                std::fs::remove_file(&comm_file)?;
            }
            let id = Id::new().map_err(|e| anyhow::anyhow!("Failed to create NCCL ID: {:?}", e))?;
            // Write-then-rename: the rename is atomic on POSIX, so other
            // ranks either see no file or the complete 128-byte ID.
            let tmp_file = comm_file.with_extension("tmp");
            let mut file = std::fs::File::create(&tmp_file)?;
            file.write_all(&id.internal().iter().map(|&i| i as u8).collect::<Vec<_>>())?;
            std::fs::rename(&tmp_file, &comm_file)?;
            tracing::info!("Rank 0: Created NCCL ID and wrote to {:?}", comm_file);
            id
        } else {
            tracing::info!("Rank {}: Waiting for NCCL ID file...", config.rank);
            while !comm_file.exists() {
                std::thread::sleep(std::time::Duration::from_millis(100));
            }
            // Extra safety margin after the file appears; the atomic rename
            // above should already guarantee a complete file.
            std::thread::sleep(std::time::Duration::from_millis(100));
            let data = std::fs::read(&comm_file)?;
            // The NCCL unique ID is exactly 128 bytes; anything else means a
            // truncated or foreign file.
            let internal: [i8; 128] = data
                .into_iter()
                .map(|i| i as i8)
                .collect::<Vec<_>>()
                .try_into()
                .map_err(|_| anyhow::anyhow!("Invalid NCCL ID file"))?;
            let id = Id::uninit(internal);
            tracing::info!("Rank {}: Read NCCL ID from {:?}", config.rank, comm_file);
            id
        };
        // One CUDA device per rank: device ordinal == rank (assumes one
        // process per local GPU).
        let cuda_device = CudaDevice::new(config.rank).map_err(|e| {
            anyhow::anyhow!("Failed to create CUDA device {}: {:?}", config.rank, e)
        })?;
        let comm = Comm::from_rank(cuda_device, config.rank, config.world_size, id)
            .map_err(|e| anyhow::anyhow!("Failed to create NCCL communicator: {:?}", e.0))?;
        if config.rank == 0 {
            // NOTE(review): best-effort cleanup that assumes every other rank
            // has read the ID file within 2 seconds of communicator creation.
            // A barrier-based handshake would be more robust than a sleep.
            std::thread::sleep(std::time::Duration::from_secs(2));
            if comm_file.exists() {
                let _ = std::fs::remove_file(&comm_file);
            }
        }
        tracing::info!("Rank {}: NCCL communicator initialized", config.rank);
        Ok(Self {
            comm: Rc::new(comm),
            rank: config.rank,
            world_size: config.world_size,
        })
    }

    /// Element-wise mean across all ranks: all-reduce sum, then scale by
    /// `1 / world_size`.
    pub fn all_reduce_avg(&self, tensor: &Tensor) -> Result<Tensor> {
        let reduced = self.all_reduce_sum(tensor)?;
        let avg = reduced.affine(1.0 / self.world_size as f64, 0.0)?;
        Ok(avg)
    }

    /// Element-wise sum across all ranks.
    ///
    /// The tensor is flattened into a host `Vec<f32>` (so this assumes an
    /// f32 tensor), reduced, and rebuilt with the original shape on the
    /// original device.
    /// NOTE(review): this round-trips through host memory on every call; a
    /// device-buffer all-reduce would avoid the copies — confirm the cudarc
    /// API in use supports it here.
    pub fn all_reduce_sum(&self, tensor: &Tensor) -> Result<Tensor> {
        use cudarc::nccl::safe::ReduceOp;
        // The original code also bound `tensor.storage_and_layout().0` here;
        // it was never used and only held a storage guard, so it is removed.
        let data: Vec<f32> = tensor.flatten_all()?.to_vec1()?;
        let mut output = data.clone();
        self.comm
            .all_reduce(&data, &mut output, &ReduceOp::Sum)
            .map_err(|e| anyhow::anyhow!("NCCL all-reduce failed: {:?}", e.0))?;
        let result = Tensor::from_vec(output, tensor.shape(), tensor.device())?;
        Ok(result)
    }

    /// Broadcasts the tensor from rank 0 to all ranks (root is hard-coded
    /// to 0). Same f32 / host round-trip caveats as `all_reduce_sum`.
    pub fn broadcast(&self, tensor: &Tensor) -> Result<Tensor> {
        let data: Vec<f32> = tensor.flatten_all()?.to_vec1()?;
        let mut output = data.clone();
        self.comm
            .broadcast(&data, &mut output, 0)
            .map_err(|e| anyhow::anyhow!("NCCL broadcast failed: {:?}", e.0))?;
        let result = Tensor::from_vec(output, tensor.shape(), tensor.device())?;
        Ok(result)
    }

    /// Synchronizes all ranks by running a throwaway one-element all-reduce;
    /// NCCL has no dedicated barrier primitive, so a collective serves as one.
    pub fn barrier(&self) -> Result<()> {
        let dummy = vec![0.0f32];
        let mut output = dummy.clone();
        self.comm
            .all_reduce(&dummy, &mut output, &cudarc::nccl::safe::ReduceOp::Sum)
            .map_err(|e| anyhow::anyhow!("NCCL barrier failed: {:?}", e.0))?;
        Ok(())
    }

    /// This process's rank within the communicator.
    pub fn rank(&self) -> usize {
        self.rank
    }

    /// Total number of ranks in the communicator.
    pub fn world_size(&self) -> usize {
        self.world_size
    }
}
/// Stub communicator compiled when the `nccl` feature is disabled.
/// Construction always fails; the collective methods are single-process
/// no-ops (see the impl below).
#[cfg(not(feature = "nccl"))]
pub struct NcclCommunicator {
// Kept for API parity with the NCCL-enabled variant.
rank: usize,
world_size: usize,
}
#[cfg(not(feature = "nccl"))]
impl NcclCommunicator {
    /// Always fails: this build was compiled without NCCL support.
    pub fn new(_config: &DistributedConfig) -> Result<Self> {
        Err(anyhow::anyhow!(
            "NCCL support not enabled. Build with --features nccl"
        ))
    }

    /// Single-process no-op: the "average" over one rank is the input itself.
    pub fn all_reduce_avg(&self, tensor: &Tensor) -> Result<Tensor> {
        let unchanged = tensor.clone();
        Ok(unchanged)
    }

    /// Single-process no-op: the "sum" over one rank is the input itself.
    pub fn all_reduce_sum(&self, tensor: &Tensor) -> Result<Tensor> {
        let unchanged = tensor.clone();
        Ok(unchanged)
    }

    /// Single-process no-op: broadcasting to one rank returns the input.
    pub fn broadcast(&self, tensor: &Tensor) -> Result<Tensor> {
        let unchanged = tensor.clone();
        Ok(unchanged)
    }

    /// Nothing to synchronize with: always succeeds immediately.
    pub fn barrier(&self) -> Result<()> {
        Ok(())
    }

    /// This process's rank.
    pub fn rank(&self) -> usize {
        self.rank
    }

    /// Total number of ranks.
    pub fn world_size(&self) -> usize {
        self.world_size
    }
}
/// Averages every variable in `var_map` across all ranks and writes the
/// result back in place.
///
/// NOTE(review): despite the name, this all-reduces the current *values* of
/// the variables (parameter averaging), not autograd gradients — confirm
/// that this is the intended synchronization strategy.
pub fn sync_gradients(var_map: &candle_nn::VarMap, comm: &NcclCommunicator) -> Result<()> {
    var_map.all_vars().into_iter().try_for_each(|var| {
        let averaged = comm.all_reduce_avg(var.as_tensor())?;
        var.set(&averaged)?;
        Ok(())
    })
}