use ferrotorch_core::FerrotorchResult;
use ferrotorch_gpu::{GpuFloat, GpuTensor, tensor_to_cpu, tensor_to_gpu};
use crate::backend::Backend;
use crate::collective::{ReduceOp, allreduce, broadcast};
use crate::error::DistributedError;
#[cfg(feature = "nccl")]
use crate::nccl_backend::{NcclBackend, reduce_op_to_nccl};
#[cfg(feature = "nccl")]
use crate::nccl_sys::NcclDataType;
/// Host round-trip allreduce: copies `tensor` to the CPU, runs the CPU
/// [`allreduce`] over `backend` with `op`, and uploads the reduced result back
/// onto `tensor`'s device.
///
/// # Errors
/// Errors from the device-to-host copy and from the CPU collective are
/// propagated with `?`. A failed host-to-device upload is reported as
/// `FerrotorchError::InvalidArgument` with a `gpu_allreduce:` prefix.
fn cpu_path_allreduce<T: GpuFloat>(
    tensor: &GpuTensor<T>,
    backend: &dyn Backend,
    op: ReduceOp,
) -> FerrotorchResult<GpuTensor<T>> {
    let host_input = tensor_to_cpu(tensor)?;
    let host_reduced = allreduce(&host_input, backend, op)?;
    // The upload error type differs from ours, so wrap it with context.
    tensor_to_gpu(&host_reduced, tensor.device()).map_err(|err| {
        ferrotorch_core::FerrotorchError::InvalidArgument {
            message: format!("gpu_allreduce: CPU->GPU transfer failed: {err}"),
        }
    })
}
/// Host round-trip broadcast: copies `tensor` to the CPU, runs the CPU
/// [`broadcast`] from rank `root` over `backend`, and uploads the result back
/// onto `tensor`'s device.
///
/// # Errors
/// Errors from the device-to-host copy and from the CPU collective (including
/// any root validation it performs) are propagated with `?`. A failed
/// host-to-device upload is reported as `FerrotorchError::InvalidArgument`
/// with a `gpu_broadcast:` prefix.
fn cpu_path_broadcast<T: GpuFloat>(
    tensor: &GpuTensor<T>,
    backend: &dyn Backend,
    root: usize,
) -> FerrotorchResult<GpuTensor<T>> {
    let on_host = tensor_to_cpu(tensor)?;
    let from_root = broadcast(&on_host, backend, root)?;
    // The upload error type differs from ours, so wrap it with context.
    tensor_to_gpu(&from_root, tensor.device()).map_err(|err| {
        ferrotorch_core::FerrotorchError::InvalidArgument {
            message: format!("gpu_broadcast: CPU->GPU transfer failed: {err}"),
        }
    })
}
/// Maps the element type `T` to the matching NCCL data type.
///
/// Only `f32` (`Float32`) and `f64` (`Float64`) are supported. The check uses
/// `TypeId` because `GpuFloat` itself does not expose an NCCL type tag.
///
/// # Errors
/// Returns `DistributedError::UnsupportedOp`, converted into the crate error,
/// naming `T` for any other element type.
#[cfg(feature = "nccl")]
fn nccl_dtype_of<T: GpuFloat>() -> FerrotorchResult<NcclDataType> {
    use std::any::TypeId;

    let requested = TypeId::of::<T>();
    let mapped = if requested == TypeId::of::<f32>() {
        Some(NcclDataType::Float32)
    } else if requested == TypeId::of::<f64>() {
        Some(NcclDataType::Float64)
    } else {
        None
    };

    // The error (and its `type_name` formatting) is only built when needed.
    mapped.ok_or_else(|| {
        DistributedError::UnsupportedOp {
            message: format!(
                "NCCL fast path does not cover dtype {} — only f32/f64 are wired",
                std::any::type_name::<T>(),
            ),
        }
        .into()
    })
}
/// NCCL fast path for `gpu_allreduce`: reduces a device-side copy of `tensor`
/// in place across every rank of `nccl`'s communicator, using `op`.
///
/// The caller's `tensor` is never written. The reduced values are returned in
/// the freshly cloned buffer.
///
/// # Errors
/// - `UnsupportedOp` if `T` is neither `f32` nor `f64` (from `nccl_dtype_of`).
/// - `FerrotorchError::InvalidArgument` if the device-to-device clone fails.
/// - Any error reported by `NcclBackend::allreduce_raw` or
///   `NcclBackend::synchronize`.
#[cfg(feature = "nccl")]
fn nccl_path_allreduce<T: GpuFloat>(
    tensor: &GpuTensor<T>,
    nccl: &NcclBackend,
    op: ReduceOp,
) -> FerrotorchResult<GpuTensor<T>> {
    // Reject unsupported dtypes before paying for the device copy.
    let elem_type = nccl_dtype_of::<T>()?;
    let reduction = reduce_op_to_nccl(&op);
    let n_elems = tensor.numel();

    // The collective runs in place, so operate on a private copy.
    let reduced = tensor.try_clone().map_err(|err| {
        ferrotorch_core::FerrotorchError::InvalidArgument {
            message: format!("gpu_allreduce: D2D clone failed: {err}"),
        }
    })?;
    let buf = reduced.cu_device_ptr() as *mut std::ffi::c_void;

    // SAFETY: `buf` is the device pointer of `reduced`, a buffer owned solely
    // by this function and cloned from `tensor`. It is assumed to hold
    // `n_elems` (= `tensor.numel()`) elements of `T`. `elem_type` is the NCCL
    // type derived from `T`. NCCL allows in-place allreduce
    // (sendbuff == recvbuff).
    // NOTE(review): this also assumes the D2D copy inside `try_clone` is
    // ordered before the NCCL launch (same stream or synchronous clone).
    // Confirm.
    unsafe { nccl.allreduce_raw(buf.cast_const(), buf, n_elems, elem_type, reduction) }?;

    // Wait for the collective before handing the buffer back to the caller.
    nccl.synchronize()?;
    Ok(reduced)
}
/// NCCL fast path for `gpu_broadcast`: broadcasts rank `root`'s values into a
/// device-side copy of `tensor` on every rank of `nccl`'s communicator.
///
/// The caller's `tensor` is never written. The broadcast values are returned
/// in the freshly cloned buffer.
///
/// # Errors
/// - `UnsupportedOp` if `T` is neither `f32` nor `f64` (from `nccl_dtype_of`).
/// - `DistributedError::InvalidRank` if `root >= world_size`, or if `root`
///   does not fit in the `i32` NCCL expects.
/// - `FerrotorchError::InvalidArgument` if the device-to-device clone fails.
/// - Any error reported by `NcclBackend::broadcast_raw` or
///   `NcclBackend::synchronize`.
#[cfg(feature = "nccl")]
fn nccl_path_broadcast<T: GpuFloat>(
    tensor: &GpuTensor<T>,
    nccl: &NcclBackend,
    root: usize,
) -> FerrotorchResult<GpuTensor<T>> {
    let elem_type = nccl_dtype_of::<T>()?;
    let n_elems = tensor.numel();

    // Validate `root` before any device work. Both failure modes share the
    // same error shape.
    let world_size = nccl.world_size();
    let invalid_root = || DistributedError::InvalidRank {
        rank: root,
        world_size,
    };
    if root >= world_size {
        return Err(invalid_root().into());
    }
    let nccl_root = i32::try_from(root).map_err(|_| invalid_root())?;

    // The collective runs in place, so operate on a private copy.
    let received = tensor.try_clone().map_err(|err| {
        ferrotorch_core::FerrotorchError::InvalidArgument {
            message: format!("gpu_broadcast: D2D clone failed: {err}"),
        }
    })?;
    let buf = received.cu_device_ptr() as *mut std::ffi::c_void;

    // SAFETY: `buf` is the device pointer of `received`, a buffer owned solely
    // by this function and cloned from `tensor`. It is assumed to hold
    // `n_elems` (= `tensor.numel()`) elements of `T`. `elem_type` is the NCCL
    // type derived from `T`, and `nccl_root` was bounds-checked against the
    // world size above. NCCL allows in-place broadcast (sendbuff == recvbuff).
    // NOTE(review): this also assumes the D2D copy inside `try_clone` is
    // ordered before the NCCL launch (same stream or synchronous clone).
    // Confirm.
    unsafe { nccl.broadcast_raw(buf.cast_const(), buf, n_elems, elem_type, nccl_root) }?;

    // Wait for the collective before handing the buffer back to the caller.
    nccl.synchronize()?;
    Ok(received)
}
pub fn gpu_allreduce<T: GpuFloat>(
tensor: &GpuTensor<T>,
backend: &dyn Backend,
op: ReduceOp,
) -> FerrotorchResult<GpuTensor<T>> {
#[cfg(feature = "nccl")]
if let Some(nccl) = backend.as_nccl_backend() {
return nccl_path_allreduce(tensor, nccl, op);
}
if std::env::var("FERROTORCH_ENABLE_GPU_FALLBACK").is_ok() {
tracing::warn!(
target: "ferrotorch::gpu_fallback",
collective = "allreduce",
"GPU collective is using host round-trip (Gloo-equivalent slow path). \
Unset FERROTORCH_ENABLE_GPU_FALLBACK to make this an error instead.",
);
return cpu_path_allreduce(tensor, backend, op);
}
Err(DistributedError::UnsupportedOp {
message: "gpu_allreduce requires the `nccl` feature for GPU-native operation. \
Set FERROTORCH_ENABLE_GPU_FALLBACK=1 to enable the host round-trip \
(Gloo-equivalent) fallback instead."
.into(),
}
.into())
}
pub fn gpu_broadcast<T: GpuFloat>(
tensor: &GpuTensor<T>,
backend: &dyn Backend,
root: usize,
) -> FerrotorchResult<GpuTensor<T>> {
#[cfg(feature = "nccl")]
if let Some(nccl) = backend.as_nccl_backend() {
return nccl_path_broadcast(tensor, nccl, root);
}
if std::env::var("FERROTORCH_ENABLE_GPU_FALLBACK").is_ok() {
tracing::warn!(
target: "ferrotorch::gpu_fallback",
collective = "broadcast",
"GPU collective is using host round-trip (Gloo-equivalent slow path). \
Unset FERROTORCH_ENABLE_GPU_FALLBACK to make this an error instead.",
);
return cpu_path_broadcast(tensor, backend, root);
}
Err(DistributedError::UnsupportedOp {
message: "gpu_broadcast requires the `nccl` feature for GPU-native operation. \
Set FERROTORCH_ENABLE_GPU_FALLBACK=1 to enable the host round-trip \
(Gloo-equivalent) fallback instead."
.into(),
}
.into())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::SimulatedBackend;
    use ferrotorch_gpu::{GpuDevice, tensor_to_gpu as t2g};
    use std::sync::Arc;
    use std::thread;

    /// Uploads `data`, laid out as `shape`, to CUDA device 0.
    fn gpu_from_slice(data: &[f32], shape: &[usize]) -> GpuTensor<f32> {
        let cpu = ferrotorch_core::from_slice(data, shape).unwrap();
        let device = GpuDevice::new(0).unwrap();
        t2g(&cpu, &device).unwrap()
    }

    /// Asserts that `got` equals `want` element-wise within 1e-6.
    ///
    /// Every failure message is prefixed with `ctx` (e.g. `"rank 1"`), so a
    /// multi-rank failure identifies the rank that diverged.
    fn assert_all_close(got: impl AsRef<[f32]>, want: &[f32], ctx: &str) {
        let got = got.as_ref();
        assert_eq!(got.len(), want.len(), "{ctx}: length mismatch");
        for (i, (g, w)) in got.iter().zip(want).enumerate() {
            assert!((g - w).abs() < 1e-6, "{ctx}[{i}]: expected {w}, got {g}");
        }
    }

    #[test]
    #[ignore = "tracking issue #1135 (replaces closed #668): opt-in fallback now required; test must be updated once NCCL wiring or env-var harness is in place"]
    fn test_gpu_allreduce_sum_2_ranks() {
        let group = SimulatedBackend::create_group(2).unwrap();
        // Keep every backend alive (via `arcs`) until all rank threads join.
        let arcs: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();
        let handles: Vec<_> = arcs
            .iter()
            .cloned()
            .map(|b| {
                thread::spawn(move || {
                    let rank = b.rank();
                    let data: Vec<f32> = if rank == 0 {
                        vec![1.0, 2.0, 3.0]
                    } else {
                        vec![4.0, 5.0, 6.0]
                    };
                    let gpu_tensor = gpu_from_slice(&data, &[3]);
                    let result = gpu_allreduce(&gpu_tensor, b.as_ref(), ReduceOp::Sum).unwrap();
                    let cpu = result.cpu().unwrap();
                    assert_all_close(cpu.data().unwrap(), &[5.0, 7.0, 9.0], &format!("rank {rank}"));
                })
            })
            .collect();
        for h in handles {
            h.join().unwrap();
        }
    }

    #[test]
    #[ignore = "tracking issue #1135 (replaces closed #668): opt-in fallback now required; test must be updated once NCCL wiring or env-var harness is in place"]
    fn test_gpu_allreduce_mean_2_ranks() {
        let group = SimulatedBackend::create_group(2).unwrap();
        // Keep every backend alive (via `arcs`) until all rank threads join.
        let arcs: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();
        let handles: Vec<_> = arcs
            .iter()
            .cloned()
            .map(|b| {
                thread::spawn(move || {
                    let rank = b.rank();
                    let data: Vec<f32> = if rank == 0 {
                        vec![2.0, 4.0]
                    } else {
                        vec![6.0, 8.0]
                    };
                    let gpu_tensor = gpu_from_slice(&data, &[2]);
                    let result = gpu_allreduce(&gpu_tensor, b.as_ref(), ReduceOp::Mean).unwrap();
                    let cpu = result.cpu().unwrap();
                    assert_all_close(cpu.data().unwrap(), &[4.0, 6.0], &format!("rank {rank}"));
                })
            })
            .collect();
        for h in handles {
            h.join().unwrap();
        }
    }

    #[test]
    #[ignore = "tracking issue #1135 (replaces closed #668): opt-in fallback now required; test must be updated once NCCL wiring or env-var harness is in place"]
    fn test_gpu_broadcast_from_rank_0() {
        let group = SimulatedBackend::create_group(2).unwrap();
        // Keep every backend alive (via `arcs`) until all rank threads join.
        let arcs: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();
        let handles: Vec<_> = arcs
            .iter()
            .cloned()
            .map(|b| {
                thread::spawn(move || {
                    let rank = b.rank();
                    let data: Vec<f32> = if rank == 0 {
                        vec![42.0, 99.0]
                    } else {
                        vec![0.0, 0.0]
                    };
                    let gpu_tensor = gpu_from_slice(&data, &[2]);
                    let result = gpu_broadcast(&gpu_tensor, b.as_ref(), 0).unwrap();
                    let cpu = result.cpu().unwrap();
                    assert_all_close(cpu.data().unwrap(), &[42.0, 99.0], &format!("rank {rank}"));
                })
            })
            .collect();
        for h in handles {
            h.join().unwrap();
        }
    }

    #[test]
    #[ignore = "tracking issue #1135 (replaces closed #668): opt-in fallback now required; test must be updated once NCCL wiring or env-var harness is in place"]
    fn test_gpu_allreduce_single_rank() {
        let group = SimulatedBackend::create_group(1).unwrap();
        let gpu_tensor = gpu_from_slice(&[1.0, 2.0, 3.0], &[3]);
        let result = gpu_allreduce(&gpu_tensor, &group[0], ReduceOp::Sum).unwrap();
        let cpu = result.cpu().unwrap();
        let out = cpu.data().unwrap();
        // A single-rank sum is the identity, so exact equality is expected.
        assert_eq!(out, &[1.0, 2.0, 3.0]);
    }

    #[test]
    #[ignore = "tracking issue #1135 (replaces closed #668): opt-in fallback now required; test must be updated once NCCL wiring or env-var harness is in place"]
    fn test_gpu_allreduce_preserves_shape() {
        let group = SimulatedBackend::create_group(1).unwrap();
        let gpu_tensor = gpu_from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let result = gpu_allreduce(&gpu_tensor, &group[0], ReduceOp::Sum).unwrap();
        assert_eq!(result.shape(), &[2, 3]);
    }

    #[test]
    #[ignore = "tracking issue #1135 (replaces closed #668): opt-in fallback now required; test must be updated once NCCL wiring or env-var harness is in place"]
    fn test_gpu_broadcast_invalid_root() {
        let group = SimulatedBackend::create_group(2).unwrap();
        let gpu_tensor = gpu_from_slice(&[1.0, 2.0], &[2]);
        let result = gpu_broadcast(&gpu_tensor, &group[0], 5);
        assert!(result.is_err());
    }

    #[cfg(feature = "nccl")]
    #[test]
    #[ignore = "requires NCCL (libnccl2) and a CUDA device — exercises the GpuTensor → NcclBackend wiring landed in #1135"]
    fn gpu_allreduce_dispatches_to_nccl_in_single_rank_mode() {
        use crate::nccl_backend::NcclBackend;
        use crate::nccl_sys::get_unique_id;
        let unique_id = get_unique_id().expect("NCCL unique ID generation");
        let nccl = NcclBackend::new(0, 1, unique_id).expect("NcclBackend init");
        let input = [1.5_f32, -2.5, 3.5, 0.0];
        let gpu_tensor = gpu_from_slice(&input, &[4]);
        let result = gpu_allreduce(&gpu_tensor, &nccl, ReduceOp::Sum)
            .expect("single-rank NCCL allreduce dispatch must succeed");
        assert_eq!(result.shape(), &[4]);
        let cpu = result.cpu().expect("result to CPU");
        // With one rank, Sum must return the input unchanged.
        assert_all_close(cpu.data().expect("flat data"), &input, "allreduce");
    }

    #[cfg(feature = "nccl")]
    #[test]
    #[ignore = "requires NCCL (libnccl2) and a CUDA device — exercises the GpuTensor → NcclBackend wiring landed in #1135"]
    fn gpu_broadcast_dispatches_to_nccl_in_single_rank_mode() {
        use crate::nccl_backend::NcclBackend;
        use crate::nccl_sys::get_unique_id;
        let unique_id = get_unique_id().expect("NCCL unique ID generation");
        let nccl = NcclBackend::new(0, 1, unique_id).expect("NcclBackend init");
        let input = [42.0_f32, 99.0, -7.5];
        let gpu_tensor = gpu_from_slice(&input, &[3]);
        let result = gpu_broadcast(&gpu_tensor, &nccl, 0)
            .expect("single-rank NCCL broadcast dispatch must succeed");
        assert_eq!(result.shape(), &[3]);
        let cpu = result.cpu().expect("result to CPU");
        // With one rank (which is also the root), broadcast is the identity.
        assert_all_close(cpu.data().expect("flat data"), &input, "broadcast");
    }
}