burn_router/channel/
base.rs1use alloc::string::String;
2use burn_backend::{DType, Element, Shape, backend::DeviceOps};
3use burn_ir::TensorIr;
4
5use crate::{MultiBackendBridge, RouterTensor, RunnerClient, get_client};
6
7pub type TensorHandle<Br> = <Br as MultiBackendBridge>::TensorHandle;
9
10pub trait RunnerChannel: Clone + Send + Sync + 'static + Sized {
12 type Device: DeviceOps;
14 type Bridge: MultiBackendBridge<Device = Self::Device>;
16 type Client: RunnerClient<Device = Self::Device>;
18 type FloatElem: Element;
20 type IntElem: Element;
22 type BoolElem: Element;
24
25 fn name(device: &Self::Device) -> String;
27
28 fn init_client(device: &Self::Device) -> Self::Client;
30
31 fn get_tensor_handle(tensor: &TensorIr, client: &Self::Client) -> TensorHandle<Self::Bridge>;
33
34 fn register_tensor(
36 client: &Self::Client,
37 handle: TensorHandle<Self::Bridge>,
38 shape: Shape,
39 dtype: DType,
40 ) -> RouterTensor<Self::Client>;
41
42 fn change_client_backend(
44 tensor: RouterTensor<Self::Client>,
45 device: &Self::Device, ) -> RouterTensor<Self::Client> {
47 let original_client = tensor.client.clone();
49 let desc = tensor.into_ir();
50 let mut handle = Self::get_tensor_handle(&desc, &original_client);
51
52 if desc.dtype.is_float() {
53 handle = Self::Bridge::change_backend_float(handle, desc.shape.clone(), device);
54 } else if desc.dtype.is_int() {
55 handle = Self::Bridge::change_backend_int(handle, desc.shape.clone(), device);
56 } else if desc.dtype.is_bool() {
57 handle = Self::Bridge::change_backend_bool(handle, desc.shape.clone(), device);
58 } else {
59 unimplemented!()
60 }
61
62 let target_client = get_client::<Self>(device);
64 Self::register_tensor(&target_client, handle, desc.shape, desc.dtype)
65 }
66}