use ferrotorch_core::FerrotorchResult;
use ferrotorch_gpu::{GpuFloat, GpuTensor, tensor_to_cpu, tensor_to_gpu};
use crate::backend::Backend;
use crate::collective::{ReduceOp, allreduce, broadcast};
use crate::error::DistributedError;
/// Host round-trip implementation of an all-reduce on a GPU tensor.
///
/// The tensor is downloaded to the CPU, reduced across `backend` with the CPU
/// `allreduce` collective using `op`, and the result is uploaded back to the
/// device that `tensor` lives on.
///
/// # Errors
///
/// Errors from the device-to-host copy and from the CPU collective are
/// propagated unchanged. A failed host-to-device copy is reported as
/// `FerrotorchError::InvalidArgument` with a `gpu_allreduce:` prefix.
fn cpu_path_allreduce<T: GpuFloat>(
    tensor: &GpuTensor<T>,
    backend: &dyn Backend,
    op: ReduceOp,
) -> FerrotorchResult<GpuTensor<T>> {
    // Stage 1: device -> host.
    let host = tensor_to_cpu(tensor)?;
    // Stage 2: run the collective on the host copy.
    let host = allreduce(&host, backend, op)?;
    // Stage 3: host -> the same device the input came from.
    let on_device = tensor_to_gpu(&host, tensor.device()).map_err(|err| {
        ferrotorch_core::FerrotorchError::InvalidArgument {
            message: format!("gpu_allreduce: CPU->GPU transfer failed: {err}"),
        }
    })?;
    Ok(on_device)
}
/// Host round-trip implementation of a broadcast on a GPU tensor.
///
/// The tensor is downloaded to the CPU, broadcast from rank `root` across
/// `backend` with the CPU `broadcast` collective, and the result is uploaded
/// back to the device that `tensor` lives on.
///
/// # Errors
///
/// Errors from the device-to-host copy and from the CPU collective are
/// propagated unchanged. A failed host-to-device copy is reported as
/// `FerrotorchError::InvalidArgument` with a `gpu_broadcast:` prefix.
fn cpu_path_broadcast<T: GpuFloat>(
    tensor: &GpuTensor<T>,
    backend: &dyn Backend,
    root: usize,
) -> FerrotorchResult<GpuTensor<T>> {
    // Stage 1: device -> host.
    let host = tensor_to_cpu(tensor)?;
    // Stage 2: every rank receives root's copy.
    let host = broadcast(&host, backend, root)?;
    // Stage 3: host -> the same device the input came from.
    let on_device = tensor_to_gpu(&host, tensor.device()).map_err(|err| {
        ferrotorch_core::FerrotorchError::InvalidArgument {
            message: format!("gpu_broadcast: CPU->GPU transfer failed: {err}"),
        }
    })?;
    Ok(on_device)
}
pub fn gpu_allreduce<T: GpuFloat>(
tensor: &GpuTensor<T>,
backend: &dyn Backend,
op: ReduceOp,
) -> FerrotorchResult<GpuTensor<T>> {
#[cfg(feature = "nccl")]
let () = ();
if std::env::var("FERROTORCH_ENABLE_GPU_FALLBACK").is_ok() {
tracing::warn!(
target: "ferrotorch::gpu_fallback",
collective = "allreduce",
"GPU collective is using host round-trip (Gloo-equivalent slow path). \
Unset FERROTORCH_ENABLE_GPU_FALLBACK to make this an error instead.",
);
return cpu_path_allreduce(tensor, backend, op);
}
Err(DistributedError::UnsupportedOp {
message: "gpu_allreduce requires the `nccl` feature for GPU-native operation. \
Set FERROTORCH_ENABLE_GPU_FALLBACK=1 to enable the host round-trip \
(Gloo-equivalent) fallback instead."
.into(),
}
.into())
}
pub fn gpu_broadcast<T: GpuFloat>(
tensor: &GpuTensor<T>,
backend: &dyn Backend,
root: usize,
) -> FerrotorchResult<GpuTensor<T>> {
#[cfg(feature = "nccl")]
let () = ();
if std::env::var("FERROTORCH_ENABLE_GPU_FALLBACK").is_ok() {
tracing::warn!(
target: "ferrotorch::gpu_fallback",
collective = "broadcast",
"GPU collective is using host round-trip (Gloo-equivalent slow path). \
Unset FERROTORCH_ENABLE_GPU_FALLBACK to make this an error instead.",
);
return cpu_path_broadcast(tensor, backend, root);
}
Err(DistributedError::UnsupportedOp {
message: "gpu_broadcast requires the `nccl` feature for GPU-native operation. \
Set FERROTORCH_ENABLE_GPU_FALLBACK=1 to enable the host round-trip \
(Gloo-equivalent) fallback instead."
.into(),
}
.into())
}
// Tests for the GPU collective wrappers.
//
// Every test builds its inputs with `gpu_from_slice`, which unwraps
// `GpuDevice::new(0)`, so these presumably need a GPU at ordinal 0 to run.
// All of them are `#[ignore]`d. The host round-trip is opt-in, so
// `gpu_allreduce` / `gpu_broadcast` return an error unless
// FERROTORCH_ENABLE_GPU_FALLBACK is set in the process environment, and
// nothing here sets it.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::SimulatedBackend;
    use ferrotorch_gpu::{GpuDevice, tensor_to_gpu as t2g};
    use std::sync::Arc;
    use std::thread;

    // Builds an f32 tensor on the CPU from `data`/`shape` and uploads it to
    // device 0. Panics on any failure, which is acceptable in test code.
    fn gpu_from_slice(data: &[f32], shape: &[usize]) -> GpuTensor<f32> {
        let cpu = ferrotorch_core::from_slice(data, shape).unwrap();
        let device = GpuDevice::new(0).unwrap();
        t2g(&cpu, &device).unwrap()
    }

    // Two simulated ranks, one thread each, contribute [1,2,3] and [4,5,6].
    // Every rank must observe the element-wise sum [5,7,9].
    #[test]
    #[ignore = "tracking issue #668: opt-in fallback now required; test must be updated once NCCL wiring or env-var harness is in place"]
    fn test_gpu_allreduce_sum_2_ranks() {
        let group = SimulatedBackend::create_group(2).unwrap();
        // Wrap each rank's backend in an Arc so it can be moved into its thread.
        let arcs: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();
        let handles: Vec<_> = arcs
            .iter()
            .cloned()
            .map(|b| {
                thread::spawn(move || {
                    let rank = b.rank();
                    let data: Vec<f32> = if rank == 0 {
                        vec![1.0, 2.0, 3.0]
                    } else {
                        vec![4.0, 5.0, 6.0]
                    };
                    let gt = gpu_from_slice(&data, &[3]);
                    let result = gpu_allreduce(&gt, b.as_ref(), ReduceOp::Sum).unwrap();
                    let cpu = result.cpu().unwrap();
                    let out = cpu.data().unwrap();
                    assert_eq!(out.len(), 3);
                    assert!(
                        (out[0] - 5.0).abs() < 1e-6,
                        "rank {rank}: expected 5.0, got {}",
                        out[0]
                    );
                    assert!(
                        (out[1] - 7.0).abs() < 1e-6,
                        "rank {rank}: expected 7.0, got {}",
                        out[1]
                    );
                    assert!(
                        (out[2] - 9.0).abs() < 1e-6,
                        "rank {rank}: expected 9.0, got {}",
                        out[2]
                    );
                })
            })
            .collect();
        // Joining surfaces any assertion panic from a rank thread.
        for h in handles {
            h.join().unwrap();
        }
    }

    // Two ranks contribute [2,4] and [6,8]. With ReduceOp::Mean every rank
    // must observe the element-wise average [4,6].
    #[test]
    #[ignore = "tracking issue #668: opt-in fallback now required; test must be updated once NCCL wiring or env-var harness is in place"]
    fn test_gpu_allreduce_mean_2_ranks() {
        let group = SimulatedBackend::create_group(2).unwrap();
        let arcs: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();
        let handles: Vec<_> = arcs
            .iter()
            .cloned()
            .map(|b| {
                thread::spawn(move || {
                    let rank = b.rank();
                    let data: Vec<f32> = if rank == 0 {
                        vec![2.0, 4.0]
                    } else {
                        vec![6.0, 8.0]
                    };
                    let gt = gpu_from_slice(&data, &[2]);
                    let result = gpu_allreduce(&gt, b.as_ref(), ReduceOp::Mean).unwrap();
                    let cpu = result.cpu().unwrap();
                    let out = cpu.data().unwrap();
                    assert!(
                        (out[0] - 4.0).abs() < 1e-6,
                        "rank {rank}: expected 4.0, got {}",
                        out[0]
                    );
                    assert!(
                        (out[1] - 6.0).abs() < 1e-6,
                        "rank {rank}: expected 6.0, got {}",
                        out[1]
                    );
                })
            })
            .collect();
        for h in handles {
            h.join().unwrap();
        }
    }

    // Rank 0 holds [42,99] and rank 1 holds zeros. After broadcasting from
    // root 0, both ranks must hold [42,99].
    #[test]
    #[ignore = "tracking issue #668: opt-in fallback now required; test must be updated once NCCL wiring or env-var harness is in place"]
    fn test_gpu_broadcast_from_rank_0() {
        let group = SimulatedBackend::create_group(2).unwrap();
        let arcs: Vec<Arc<SimulatedBackend>> = group.into_iter().map(Arc::new).collect();
        let handles: Vec<_> = arcs
            .iter()
            .cloned()
            .map(|b| {
                thread::spawn(move || {
                    let rank = b.rank();
                    let data: Vec<f32> = if rank == 0 {
                        vec![42.0, 99.0]
                    } else {
                        vec![0.0, 0.0]
                    };
                    let gt = gpu_from_slice(&data, &[2]);
                    let result = gpu_broadcast(&gt, b.as_ref(), 0).unwrap();
                    let cpu = result.cpu().unwrap();
                    let out = cpu.data().unwrap();
                    assert!(
                        (out[0] - 42.0).abs() < 1e-6,
                        "rank {rank}: expected 42.0, got {}",
                        out[0]
                    );
                    assert!(
                        (out[1] - 99.0).abs() < 1e-6,
                        "rank {rank}: expected 99.0, got {}",
                        out[1]
                    );
                })
            })
            .collect();
        for h in handles {
            h.join().unwrap();
        }
    }

    // With a single rank, a Sum all-reduce must return the input unchanged.
    #[test]
    #[ignore = "tracking issue #668: opt-in fallback now required; test must be updated once NCCL wiring or env-var harness is in place"]
    fn test_gpu_allreduce_single_rank() {
        let group = SimulatedBackend::create_group(1).unwrap();
        let gt = gpu_from_slice(&[1.0, 2.0, 3.0], &[3]);
        let result = gpu_allreduce(&gt, &group[0], ReduceOp::Sum).unwrap();
        let cpu = result.cpu().unwrap();
        let out = cpu.data().unwrap();
        assert_eq!(out, &[1.0, 2.0, 3.0]);
    }

    // The round-trip must preserve a multi-dimensional shape ([2, 3]).
    #[test]
    #[ignore = "tracking issue #668: opt-in fallback now required; test must be updated once NCCL wiring or env-var harness is in place"]
    fn test_gpu_allreduce_preserves_shape() {
        let group = SimulatedBackend::create_group(1).unwrap();
        let gt = gpu_from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let result = gpu_allreduce(&gt, &group[0], ReduceOp::Sum).unwrap();
        assert_eq!(result.shape(), &[2, 3]);
    }

    // Broadcasting from root 5 in a 2-rank group should fail.
    // NOTE(review): when the fallback is disabled, gpu_broadcast returns Err
    // before `root` is ever inspected, so this assertion passes without
    // exercising root validation. With the fallback enabled it presumably
    // relies on the CPU `broadcast` rejecting the out-of-range root; confirm.
    #[test]
    #[ignore = "tracking issue #668: opt-in fallback now required; test must be updated once NCCL wiring or env-var harness is in place"]
    fn test_gpu_broadcast_invalid_root() {
        let group = SimulatedBackend::create_group(2).unwrap();
        let gt = gpu_from_slice(&[1.0, 2.0], &[2]);
        let result = gpu_broadcast(&gt, &group[0], 5);
        assert!(result.is_err());
    }
}