Skip to main content

burn_router/channel/
base.rs

1use alloc::string::String;
2use burn_backend::{DType, Element, Shape, backend::DeviceOps};
3use burn_ir::TensorIr;
4
5use crate::{MultiBackendBridge, RouterTensor, RunnerClient, get_client};
6
7/// Type alias for `<Br as MultiBackendBridge>::TensorHandle`.
8pub type TensorHandle<Br> = <Br as MultiBackendBridge>::TensorHandle;
9
10/// Defines the connection channel and operations for a setup with multiple backend runner clients.
11pub trait RunnerChannel: Clone + Send + Sync + 'static + Sized {
12    /// Device type.
13    type Device: DeviceOps;
14    /// A bridge that can transfer tensors between multiple backends.
15    type Bridge: MultiBackendBridge<Device = Self::Device>;
16    /// Client type.
17    type Client: RunnerClient<Device = Self::Device>;
18    /// Float element type.
19    type FloatElem: Element;
20    /// Int element type.
21    type IntElem: Element;
22    /// Bool element type.
23    type BoolElem: Element;
24
25    /// Name of the channel.
26    fn name(device: &Self::Device) -> String;
27
28    /// Initialize a new client for the given device.
29    fn init_client(device: &Self::Device) -> Self::Client;
30
31    /// Get the tensor handle corresponding to the [tensor representation](TensorIr).
32    fn get_tensor_handle(tensor: &TensorIr, client: &Self::Client) -> TensorHandle<Self::Bridge>;
33
34    /// Create a tensor with the given handle and shape.
35    fn register_tensor(
36        client: &Self::Client,
37        handle: TensorHandle<Self::Bridge>,
38        shape: Shape,
39        dtype: DType,
40    ) -> RouterTensor<Self::Client>;
41
42    /// Change the tensor to a different client backend.
43    fn change_client_backend(
44        tensor: RouterTensor<Self::Client>,
45        device: &Self::Device, // target device
46    ) -> RouterTensor<Self::Client> {
47        // Get tensor handle from current client
48        let original_client = tensor.client.clone();
49        let desc = tensor.into_ir();
50        let mut handle = Self::get_tensor_handle(&desc, &original_client);
51
52        if desc.dtype.is_float() {
53            handle = Self::Bridge::change_backend_float(handle, desc.shape.clone(), device);
54        } else if desc.dtype.is_int() {
55            handle = Self::Bridge::change_backend_int(handle, desc.shape.clone(), device);
56        } else if desc.dtype.is_bool() {
57            handle = Self::Bridge::change_backend_bool(handle, desc.shape.clone(), device);
58        } else {
59            unimplemented!()
60        }
61
62        // Register tensor handle on target client
63        let target_client = get_client::<Self>(device);
64        Self::register_tensor(&target_client, handle, desc.shape, desc.dtype)
65    }
66}