use alloc::string::String;
use burn_backend::{DType, Element, Shape, backend::DeviceOps};
use burn_ir::TensorIr;
use crate::{MultiBackendBridge, RouterTensor, RunnerClient, get_client};
/// Shorthand for `<Br as MultiBackendBridge>::TensorHandle`, the tensor handle type
/// associated with the bridge `Br`.
pub type TensorHandle<Br> = <Br as MultiBackendBridge>::TensorHandle;
/// Ties together a device type, a multi-backend bridge and a runner client, and
/// provides the operations needed to obtain, register and relocate tensors between
/// clients.
pub trait RunnerChannel: Clone + Send + Sync + 'static + Sized {
    /// Device type shared by [`Self::Bridge`] and [`Self::Client`].
    type Device: DeviceOps;
    /// Bridge whose `change_backend_*` functions convert tensor handles for a target device.
    type Bridge: MultiBackendBridge<Device = Self::Device>;
    /// Client type that tensors are registered with.
    type Client: RunnerClient<Device = Self::Device>;
    /// Float element type of the channel.
    type FloatElem: Element;
    /// Int element type of the channel.
    type IntElem: Element;
    /// Bool element type of the channel.
    type BoolElem: Element;

    /// Returns the name associated with `device`.
    fn name(device: &Self::Device) -> String;

    /// Creates the client for `device`.
    fn init_client(device: &Self::Device) -> Self::Client;

    /// Retrieves the bridge tensor handle for the tensor described by `tensor`, using `client`.
    fn get_tensor_handle(tensor: &TensorIr, client: &Self::Client) -> TensorHandle<Self::Bridge>;

    /// Registers `handle` with `client` and returns the corresponding router tensor
    /// for the given `shape` and `dtype`.
    fn register_tensor(
        client: &Self::Client,
        handle: TensorHandle<Self::Bridge>,
        shape: Shape,
        dtype: DType,
    ) -> RouterTensor<Self::Client>;

    /// Moves `tensor` to the client obtained for the target `device`.
    ///
    /// The handle is fetched from the tensor's current client, converted through
    /// [`Self::Bridge`] according to the tensor's dtype category (float, int or bool),
    /// and then registered on the client returned by `get_client` for `device`.
    ///
    /// # Panics
    ///
    /// Panics if the tensor's dtype is reported as neither float, int nor bool.
    fn change_client_backend(
        tensor: RouterTensor<Self::Client>,
        device: &Self::Device,
    ) -> RouterTensor<Self::Client> {
        // The client must be cloned before `into_ir` consumes the tensor.
        let original_client = tensor.client.clone();
        let desc = tensor.into_ir();
        let handle = Self::get_tensor_handle(&desc, &original_client);

        // The shape is cloned for the bridge because it is needed again at registration.
        let handle = if desc.dtype.is_float() {
            Self::Bridge::change_backend_float(handle, desc.shape.clone(), device)
        } else if desc.dtype.is_int() {
            Self::Bridge::change_backend_int(handle, desc.shape.clone(), device)
        } else if desc.dtype.is_bool() {
            Self::Bridge::change_backend_bool(handle, desc.shape.clone(), device)
        } else {
            unimplemented!(
                "changing the client backend is only supported for float, int and bool tensors, got dtype {:?}",
                desc.dtype
            )
        };

        let target_client = get_client::<Self>(device);
        Self::register_tensor(&target_client, handle, desc.shape, desc.dtype)
    }
}