use cubecl::device::DeviceId;
use crate::{
Backend,
distributed::CollectiveTensor,
tensor::{Device, FloatTensor},
};
use crate::distributed::{
DistributedConfig, DistributedParams, ReduceOperation, close_distributed_sync_server,
get_distributed_sync_client, start_distributed_sync_server,
};
/// A raw, shareable pointer to a backend float tensor.
///
/// Used to hand a gradient tensor to the distributed sync client
/// (see [`DistributedBackend::submit_gradient_sync`]) without moving it.
/// The pointer is dereferenced only inside the `unsafe` accessors
/// `comm_device` / `float_from_ref`, whose callers must uphold validity.
#[derive(Clone)]
pub struct TensorRef<B: Backend>(pub *mut FloatTensor<B>);
// SAFETY: NOTE(review) — raw pointers are neither `Send` nor `Sync` by
// default; these impls assert the sync client may move/share the pointer
// across threads. Soundness relies on call sites keeping the pointee alive
// and externally synchronizing access for the duration of the sync — these
// invariants are not visible in this file; confirm them at the call sites.
unsafe impl<B> Sync for TensorRef<B> where B: Backend {}
unsafe impl<B> Send for TensorRef<B> where B: Backend {}
/// Extension trait adding distributed-training hooks to a [`Backend`].
///
/// The default methods route requests through the process-wide distributed
/// sync server/client (`start_distributed_sync_server`,
/// `get_distributed_sync_client`, …). Client-based methods are no-ops when no
/// sync client has been started. `all_reduce` and `sync_collective` have no
/// generic implementation and must be overridden by backends that support
/// collectives.
pub trait DistributedBackend: Backend {
    /// Starts the distributed synchronization server for the given devices.
    fn start_communication_server(devices: &[Self::Device], config: DistributedConfig) {
        start_distributed_sync_server::<Self>(devices, config);
    }

    /// Shuts down the distributed synchronization server.
    fn close_communication_server(_device: &Self::Device) {
        close_distributed_sync_server::<Self>();
    }

    /// Registers the parameters that participate in gradient synchronization.
    ///
    /// No-op when no sync client is running.
    fn register_sync_parameters(
        _device: &Self::Device,
        distributed_params: Vec<DistributedParams>,
    ) {
        if let Some(sync_client) = get_distributed_sync_client::<Self>() {
            sync_client.register_sync_parameters(distributed_params);
        }
    }

    /// Submits a collective synchronization request for `device`.
    ///
    /// No-op when no sync client is running.
    fn submit_sync_collective(device: &Self::Device) {
        if let Some(sync_client) = get_distributed_sync_client::<Self>() {
            sync_client.submit_sync_collective(device.clone());
        }
    }

    /// Submits a gradient tensor reference for asynchronous synchronization.
    ///
    /// No-op when no sync client is running.
    fn submit_gradient_sync(tensor: TensorRef<Self>, distributed_params: DistributedParams) {
        if let Some(sync_client) = get_distributed_sync_client::<Self>() {
            sync_client.submit_gradient_sync(tensor, distributed_params);
        }
    }

    /// Performs an all-reduce of `tensor` with `_op` across `_device_ids`.
    ///
    /// # Panics
    ///
    /// The default implementation always panics; backends supporting
    /// collectives must override it.
    fn all_reduce(
        _tensor: FloatTensor<Self>,
        _op: ReduceOperation,
        _device_ids: Vec<DeviceId>,
    ) -> CollectiveTensor<Self> {
        unimplemented!("`all_reduce` is not implemented for this backend")
    }

    /// Synchronizes pending collective operations on `_device`.
    ///
    /// # Panics
    ///
    /// The default implementation always panics; backends supporting
    /// collectives must override it.
    fn sync_collective(_device: &Self::Device) {
        unimplemented!("`sync_collective` is not implemented for this backend")
    }

    /// Returns the device of the tensor behind `tensor`.
    ///
    /// # Safety
    ///
    /// `tensor.0` must point to a live, properly aligned `FloatTensor<Self>`
    /// for the duration of the call, with no concurrent mutable access.
    unsafe fn comm_device(tensor: &TensorRef<Self>) -> Device<Self> {
        // SAFETY: the caller guarantees `tensor.0` is valid and dereferenceable.
        unsafe { Self::float_device(&*tensor.0) }
    }

    /// Clones the tensor behind `tensor` out of its raw pointer.
    ///
    /// # Safety
    ///
    /// `tensor.0` must point to a live, properly aligned `FloatTensor<Self>`
    /// for the duration of the call, with no concurrent mutable access.
    unsafe fn float_from_ref(tensor: &TensorRef<Self>) -> FloatTensor<Self> {
        // SAFETY: the caller guarantees `tensor.0` is valid and dereferenceable.
        unsafe { (*tensor.0).clone() }
    }
}