use std::marker::PhantomData;
use burn_backend::Shape;
use burn_communication::ProtocolClient;
use burn_ir::TensorIr;
use burn_router::{RouterTensor, RunnerChannel, get_client};
use super::{
RemoteClient,
runner::{RemoteBridge, RemoteDevice, RemoteTensorHandle},
};
/// [`RunnerChannel`] implementation for remote devices, parameterized by the
/// communication protocol client `C`.
///
/// `C` only appears at the type level (through [`PhantomData`]); the channel
/// itself stores no protocol client value.
pub struct RemoteChannel<C>
where
    C: ProtocolClient,
{
    _p: PhantomData<C>,
}
impl<C: ProtocolClient> RunnerChannel for RemoteChannel<C> {
    type Device = RemoteDevice;
    type Bridge = RemoteBridge<C>;
    type Client = RemoteClient;
    type FloatElem = f32;
    type IntElem = i32;
    type BoolElem = u32;

    /// Returns the channel name for `device`: the device's `Debug` output
    /// prefixed with `remote-`.
    fn name(device: &Self::Device) -> String {
        format!("remote-{device:?}")
    }

    /// Creates a [`RemoteClient`] for `device` using this channel's protocol
    /// client type `C`.
    fn init_client(device: &Self::Device) -> Self::Client {
        RemoteClient::init::<C>(device.clone())
    }

    /// Pairs a tensor description with its owning client in a
    /// [`RemoteTensorHandle`].
    ///
    /// This only clones `tensor` and `client`; no other work happens here.
    fn get_tensor_handle(tensor: &TensorIr, client: &Self::Client) -> RemoteTensorHandle<C> {
        RemoteTensorHandle {
            client: client.clone(),
            tensor: tensor.clone(),
            _p: PhantomData,
        }
    }

    /// Manual tensor registration is not supported by this channel.
    ///
    /// # Panics
    ///
    /// Always panics, regardless of the arguments.
    fn register_tensor(
        _client: &Self::Client,
        _handle: RemoteTensorHandle<C>,
        _shape: Shape,
        _dtype: burn_backend::DType,
    ) -> RouterTensor<Self::Client> {
        panic!("Can't register manually a tensor on a remote channel.");
    }

    /// Moves `tensor` from its current client to the client of `target_device`.
    ///
    /// The tensor is consumed into its IR description, wrapped in a handle bound
    /// to the original client, and handed to `RemoteTensorHandle::change_backend`.
    /// NOTE(review): presumably `change_backend` performs the actual transfer
    /// between remote servers; confirm in `runner`.
    /// The returned router tensor reuses the id, shape and dtype reported by the
    /// new handle, attached to the client registered for `target_device`.
    fn change_client_backend(
        tensor: RouterTensor<Self::Client>,
        target_device: &Self::Device,
    ) -> RouterTensor<Self::Client> {
        // `into_ir` consumes the tensor, so grab its client first.
        let original_client = tensor.client.clone();
        let desc = tensor.into_ir();

        let handle = Self::get_tensor_handle(&desc, &original_client);
        let handle = handle.change_backend(target_device);

        let id = handle.tensor.id;
        let target_client = get_client::<Self>(target_device);
        RouterTensor::new(id, handle.tensor.shape, handle.tensor.dtype, target_client)
    }
}
/// Written by hand rather than derived: `#[derive(Clone)]` would add a
/// `C: Clone` bound. The channel holds no `C` value, so cloning never needs one.
impl<C: ProtocolClient> Clone for RemoteChannel<C> {
    fn clone(&self) -> Self {
        Self { _p: PhantomData }
    }
}