use anyhow::Result;
use candle_core::Tensor;
#[cfg(feature = "nccl")]
use cudarc::driver::safe::{CudaContext, CudaStream};
#[cfg(feature = "nccl")]
use cudarc::nccl::safe::{Comm, Id};
#[cfg(feature = "nccl")]
use std::sync::Arc;
/// Configuration for multi-process distributed training.
///
/// `PartialEq`/`Eq` are derived so configs can be compared in tests and
/// deduplicated; all fields are plain data, so the derives are free.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DistributedConfig {
    /// Total number of participating processes.
    pub world_size: usize,
    /// This process's rank; rank 0 is the main process.
    pub rank: usize,
    /// Path of the rendezvous file used to exchange the NCCL unique ID at startup.
    pub comm_file: String,
}
impl Default for DistributedConfig {
fn default() -> Self {
Self {
world_size: 1,
rank: 0,
comm_file: "nccl_id.txt".to_string(),
}
}
}
impl DistributedConfig {
    /// True when more than one process participates (i.e. collectives are needed).
    pub fn is_distributed(&self) -> bool {
        self.world_size >= 2
    }

    /// True for rank 0, which handles rendezvous-file creation and cleanup.
    pub fn is_main_process(&self) -> bool {
        matches!(self.rank, 0)
    }
}
/// Handle bundling an NCCL communicator with the CUDA stream its collectives
/// and host<->device copies run on. Built only with `--features nccl`.
#[cfg(feature = "nccl")]
pub struct NcclCommunicator {
    // NCCL communicator shared by all collective operations.
    comm: Comm,
    // CUDA stream used for staging copies and synchronization.
    stream: Arc<CudaStream>,
    // This process's rank within the communicator.
    rank: usize,
    // Total number of participating processes.
    world_size: usize,
}
#[cfg(feature = "nccl")]
impl NcclCommunicator {
    /// Initializes the NCCL communicator via a file-based rendezvous.
    ///
    /// Rank 0 generates a fresh NCCL unique ID and publishes it through
    /// `config.comm_file` (written to a `.tmp` file first, then renamed, so
    /// readers never observe a partial write). Every other rank polls until the
    /// file exists, reads the ID back, and joins the communicator.
    ///
    /// # Errors
    /// Fails if the ID file cannot be written/read, the CUDA context cannot be
    /// created, or NCCL communicator initialization fails.
    pub fn new(config: &DistributedConfig) -> Result<Self> {
        use std::io::Write;
        let comm_file = std::path::PathBuf::from(&config.comm_file);
        let id = if config.rank == 0 {
            // Remove a stale ID file from a previous run before publishing a new one.
            if comm_file.exists() {
                std::fs::remove_file(&comm_file)?;
            }
            let id = Id::new().map_err(|e| anyhow::anyhow!("Failed to create NCCL ID: {:?}", e))?;
            // Write-then-rename so other ranks never read a half-written ID.
            let tmp_file = comm_file.with_extension("tmp");
            let mut file = std::fs::File::create(&tmp_file)?;
            // The internal ID is a fixed array of i8; reinterpret each as u8 for writing.
            file.write_all(&id.internal().iter().map(|&i| i as u8).collect::<Vec<_>>())?;
            std::fs::rename(&tmp_file, &comm_file)?;
            tracing::info!("Rank 0: Created NCCL ID and wrote to {:?}", comm_file);
            id
        } else {
            tracing::info!("Rank {}: Waiting for NCCL ID file...", config.rank);
            // Poll until rank 0 publishes the ID file.
            while !comm_file.exists() {
                std::thread::sleep(std::time::Duration::from_millis(100));
            }
            // Extra grace period after the file appears. The atomic rename above
            // should already prevent partial reads; kept as a safety margin.
            std::thread::sleep(std::time::Duration::from_millis(100));
            let data = std::fs::read(&comm_file)?;
            // Convert the 128 stored bytes back into the i8 array NCCL expects;
            // any other file length is rejected as corrupt.
            let internal: [i8; 128] = data
                .into_iter()
                .map(|i| i as i8)
                .collect::<Vec<_>>()
                .try_into()
                .map_err(|_| anyhow::anyhow!("Invalid NCCL ID file"))?;
            let id = Id::uninit(internal);
            tracing::info!("Rank {}: Read NCCL ID from {:?}", config.rank, comm_file);
            id
        };
        // NOTE(review): the GPU ordinal is hard-coded to 0, so every rank on the
        // same node opens device 0 — presumably each rank is isolated via
        // CUDA_VISIBLE_DEVICES or runs on its own node. Confirm deployment setup.
        let gpu_ordinal = 0;
        let ctx = CudaContext::new(gpu_ordinal).map_err(|e| {
            anyhow::anyhow!("Failed to create CUDA context {}: {:?}", gpu_ordinal, e)
        })?;
        let stream = ctx.default_stream();
        let comm = Comm::from_rank(stream.clone(), config.rank, config.world_size, id)
            .map_err(|e| anyhow::anyhow!("Failed to create NCCL communicator: {:?}", e.0))?;
        tracing::info!("Rank {}: NCCL communicator initialized", config.rank);
        if config.rank == 0 {
            // Best-effort cleanup of the rendezvous file. The sleep gives slower
            // ranks time to finish reading; removal errors are deliberately ignored.
            std::thread::sleep(std::time::Duration::from_millis(500));
            if comm_file.exists() {
                let _ = std::fs::remove_file(&comm_file);
            }
        }
        Ok(Self {
            comm,
            stream,
            rank: config.rank,
            world_size: config.world_size,
        })
    }

    /// Element-wise mean of `tensor` across all ranks: an all-reduce sum
    /// followed by division by `world_size`.
    pub fn all_reduce_avg(&self, tensor: &Tensor) -> Result<Tensor> {
        let reduced = self.all_reduce_sum(tensor)?;
        // affine(scale, bias): x * (1/world_size) + 0.
        let avg = reduced.affine(1.0 / self.world_size as f64, 0.0)?;
        Ok(avg)
    }

    /// Element-wise sum of `tensor` across all ranks.
    ///
    /// The tensor is staged through host memory: flattened to a `Vec<f32>`,
    /// copied to the GPU, reduced with NCCL, copied back, and reshaped to the
    /// input shape. NOTE(review): `to_vec1::<f32>` assumes an f32 tensor —
    /// other dtypes will fail here; confirm callers only pass f32.
    pub fn all_reduce_sum(&self, tensor: &Tensor) -> Result<Tensor> {
        use cudarc::nccl::safe::ReduceOp;
        let data: Vec<f32> = tensor.flatten_all()?.to_vec1()?;
        let len = data.len();
        let gpu_data = self
            .stream
            .clone_htod(&data)
            .map_err(|e| anyhow::anyhow!("Failed to copy data to GPU: {:?}", e))?;
        let mut gpu_output = self
            .stream
            .alloc_zeros::<f32>(len)
            .map_err(|e| anyhow::anyhow!("Failed to allocate GPU buffer: {:?}", e))?;
        self.comm
            .all_reduce(&gpu_data, &mut gpu_output, &ReduceOp::Sum)
            .map_err(|e| anyhow::anyhow!("NCCL all-reduce failed: {:?}", e.0))?;
        // Wait for the collective before reading the result back to the host.
        self.stream
            .synchronize()
            .map_err(|e| anyhow::anyhow!("Stream sync failed: {:?}", e))?;
        let output = self
            .stream
            .clone_dtoh(&gpu_output)
            .map_err(|e| anyhow::anyhow!("Failed to copy data from GPU: {:?}", e))?;
        // Rebuild a tensor with the original shape on the original device.
        let result = Tensor::from_vec(output, tensor.shape(), tensor.device())?;
        Ok(result)
    }

    /// Broadcasts rank 0's `tensor` to all ranks; every rank returns the result
    /// (rank 0 receives a copy of its own data).
    ///
    /// NOTE(review): non-root ranks still flatten their local tensor to host
    /// memory just to size the receive buffer — the values themselves are unused.
    pub fn broadcast(&self, tensor: &Tensor) -> Result<Tensor> {
        let data: Vec<f32> = tensor.flatten_all()?.to_vec1()?;
        let len = data.len();
        // Only the root (rank 0) supplies a send buffer.
        let gpu_data = if self.rank == 0 {
            Some(
                self.stream
                    .clone_htod(&data)
                    .map_err(|e| anyhow::anyhow!("Failed to copy data to GPU: {:?}", e))?,
            )
        } else {
            None
        };
        let mut gpu_output = self
            .stream
            .alloc_zeros::<f32>(len)
            .map_err(|e| anyhow::anyhow!("Failed to allocate GPU buffer: {:?}", e))?;
        // Broadcast from root rank 0 into every rank's output buffer.
        self.comm
            .broadcast(gpu_data.as_ref(), &mut gpu_output, 0)
            .map_err(|e| anyhow::anyhow!("NCCL broadcast failed: {:?}", e.0))?;
        self.stream
            .synchronize()
            .map_err(|e| anyhow::anyhow!("Stream sync failed: {:?}", e))?;
        let output = self
            .stream
            .clone_dtoh(&gpu_output)
            .map_err(|e| anyhow::anyhow!("Failed to copy data from GPU: {:?}", e))?;
        let result = Tensor::from_vec(output, tensor.shape(), tensor.device())?;
        Ok(result)
    }

    /// Blocks until every rank reaches this point.
    ///
    /// Implemented as a one-element all-reduce because NCCL exposes no
    /// dedicated barrier primitive; the stream synchronize makes completion
    /// visible to the host.
    pub fn barrier(&self) -> Result<()> {
        use cudarc::nccl::safe::ReduceOp;
        let dummy = [0.0f32];
        let gpu_dummy = self
            .stream
            .clone_htod(&dummy)
            .map_err(|e| anyhow::anyhow!("Failed to copy data to GPU: {:?}", e))?;
        let mut gpu_output = self
            .stream
            .alloc_zeros::<f32>(1)
            .map_err(|e| anyhow::anyhow!("Failed to allocate GPU buffer: {:?}", e))?;
        self.comm
            .all_reduce(&gpu_dummy, &mut gpu_output, &ReduceOp::Sum)
            .map_err(|e| anyhow::anyhow!("NCCL barrier failed: {:?}", e.0))?;
        self.stream
            .synchronize()
            .map_err(|e| anyhow::anyhow!("Stream sync failed: {:?}", e))?;
        Ok(())
    }

    /// This process's rank within the communicator.
    pub fn rank(&self) -> usize {
        self.rank
    }

    /// Total number of participating processes.
    pub fn world_size(&self) -> usize {
        self.world_size
    }

    /// Drains the CUDA stream before the communicator is dropped.
    pub fn finalize(self) -> Result<()> {
        self.stream
            .synchronize()
            .map_err(|e| anyhow::anyhow!("Stream sync failed: {:?}", e))?;
        Ok(())
    }
}
/// Fallback stub compiled when the `nccl` feature is disabled. Mirrors the
/// NCCL-enabled struct's accessors; in practice it is never constructed,
/// because the stub's `new` always bails.
#[cfg(not(feature = "nccl"))]
pub struct NcclCommunicator {
    // Kept so `rank()` / `world_size()` match the NCCL-enabled API.
    rank: usize,
    world_size: usize,
}
#[cfg(not(feature = "nccl"))]
impl NcclCommunicator {
pub fn new(_config: &DistributedConfig) -> Result<Self> {
anyhow::bail!("NCCL support not enabled. Build with --features nccl")
}
pub fn all_reduce_avg(&self, tensor: &Tensor) -> Result<Tensor> {
Ok(tensor.clone())
}
pub fn all_reduce_sum(&self, tensor: &Tensor) -> Result<Tensor> {
Ok(tensor.clone())
}
pub fn broadcast(&self, tensor: &Tensor) -> Result<Tensor> {
Ok(tensor.clone())
}
pub fn barrier(&self) -> Result<()> {
Ok(())
}
pub fn rank(&self) -> usize {
self.rank
}
pub fn world_size(&self) -> usize {
self.world_size
}
pub fn finalize(self) -> Result<()> {
Ok(())
}
}
fn sync_vars(
var_map: &candle_nn::VarMap,
op: impl FnOnce(&Tensor) -> Result<Tensor>,
) -> Result<()> {
use candle_core::Shape;
let vars: Vec<candle_core::Var> = var_map.all_vars();
if vars.is_empty() {
return Ok(());
}
let mut shapes: Vec<Shape> = Vec::with_capacity(vars.len());
let mut flat_data: Vec<f32> = Vec::new();
for var in &vars {
let tensor = var.as_tensor();
shapes.push(tensor.shape().clone());
let data: Vec<f32> = tensor.flatten_all()?.to_vec1()?;
flat_data.extend(data);
}
let device = vars[0].as_tensor().device();
let len = flat_data.len();
let flat_tensor = Tensor::from_vec(flat_data, len, device)?;
let synced = op(&flat_tensor)?;
let synced_data: Vec<f32> = synced.to_vec1()?;
let mut offset = 0;
for (var, shape) in vars.iter().zip(shapes.iter()) {
let size = shape.elem_count();
let data = &synced_data[offset..offset + size];
let tensor = Tensor::from_vec(data.to_vec(), shape.dims(), device)?;
var.set(&tensor)?;
offset += size;
}
Ok(())
}
/// Broadcasts rank 0's parameters to every rank so all replicas start identical.
pub fn sync_model(var_map: &candle_nn::VarMap, comm: &NcclCommunicator) -> Result<()> {
    sync_vars(var_map, |flat| comm.broadcast(flat))
}
/// Averages the values stored in `var_map` across all ranks (all-reduce mean).
pub fn sync_gradients(var_map: &candle_nn::VarMap, comm: &NcclCommunicator) -> Result<()> {
    sync_vars(var_map, |flat| comm.all_reduce_avg(flat))
}