burn-cubecl 0.21.0-pre.3

Generic backend that can be compiled just-in-time to any shader language target
use burn_backend::{
    DeviceOps,
    distributed::{DistributedBackend, ReduceOperation, TensorRef},
    tensor::Device,
};

use crate::{BoolElement, CubeBackend, CubeRuntime, FloatElement, IntElement};

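// Distributed collectives for the CubeCL backend: all-reduce across devices
// and collective synchronization.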
impl<R, F, I, BT> DistributedBackend for CubeBackend<R, F, I, BT>
where
    R: CubeRuntime,
    F: FloatElement,
    I: IntElement,
    BT: BoolElement,
{
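    /// Performs an all-reduce over the given tensors, writing the reduced
    /// result back into each tensor's own handle.
    ///
    /// # Safety
    /// Each `TensorRef` wraps a raw pointer; the caller must guarantee that
    /// every pointed-to tensor stays alive, and is not mutated elsewhere,
    /// for the duration of the call.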
    unsafe fn all_reduce_in_place(tensors: Vec<TensorRef<Self>>, op: ReduceOperation) {
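        // SAFETY: the caller upholds the contract above, so each raw tensor
        // pointer can be dereferenced to a shared reference.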
        let tensors = tensors.iter().map(|t| unsafe { &*t.0 }).collect::<Vec<_>>();
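        // Collect the device ids of every participant; the server needs the
        // full set to run the collective.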
        let all_ids = tensors.iter().map(|t| t.device.id()).collect::<Vec<_>>();

        for tensor in tensors {
            let device = &tensor.device;
            let client = R::client(device);

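            // Map burn's reduce op onto the corresponding cubecl server op.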
            let op = match op {
                ReduceOperation::Sum => cubecl::server::ReduceOperation::Sum,
                ReduceOperation::Mean => cubecl::server::ReduceOperation::Mean,
            };

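            // In-place reduce: the tensor's handle serves as both the input
            // and the output of the collective.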
            client.all_reduce(
                tensor.handle.clone(),
                tensor.handle.clone(),
                tensor.dtype.into(),
                all_ids.clone(),
                op,
            );
        }
    }

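    /// Synchronizes collective operations for the given device.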
    fn sync_collective(device: &Device<Self>) {
        let client = R::client(device);
        client.sync_collective();
    }
}
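
For intuition, here is a minimal, self-contained sketch of the semantics an in-place all-reduce provides, written against plain Vec<f32> buffers rather than the cubecl API (the ReduceOperation enum and all_reduce_in_place function below are hypothetical stand-ins, not crate items): after the call, every participant holds the same element-wise sum or mean of all participants' buffers, written back in place.

// Conceptual sketch only: mirrors the Sum/Mean semantics of an in-place
// all-reduce across participants, with no devices or servers involved.
#[derive(Clone, Copy)]
enum ReduceOperation {
    Sum,
    Mean,
}

fn all_reduce_in_place(buffers: &mut [Vec<f32>], op: ReduceOperation) {
    let n = buffers.len() as f32;
    let len = buffers[0].len();

    // Element-wise sum across all participants.
    let mut combined = vec![0.0f32; len];
    for buf in buffers.iter() {
        for (acc, x) in combined.iter_mut().zip(buf) {
            *acc += *x;
        }
    }

    // Mean divides the sum by the number of participants.
    if let ReduceOperation::Mean = op {
        for acc in combined.iter_mut() {
            *acc /= n;
        }
    }

    // Every participant receives the same reduced result, in place.
    for buf in buffers.iter_mut() {
        buf.copy_from_slice(&combined);
    }
}

fn main() {
    let mut buffers = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
    all_reduce_in_place(&mut buffers, ReduceOperation::Mean);
    assert_eq!(buffers, vec![vec![2.0, 3.0], vec![2.0, 3.0]]);
}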