Skip to main content

burn_router/client/
base.rs

1use crate::{RouterTensor, RunnerChannel};
2use alloc::boxed::Box;
3use alloc::vec::Vec;
4use burn_backend::{
5    DType, TensorData,
6    backend::{DeviceId, DeviceOps, ExecutionError},
7};
8use burn_ir::{OperationIr, TensorId, TensorIr};
9use burn_std::future::DynFut;
10use core::ops::DerefMut;
11use hashbrown::HashMap;
12use spin::Mutex;
13
/// Type alias for `<R as RunnerChannel>::Client`.
pub type Client<R> = <R as RunnerChannel>::Client;
/// Crate-wide locator through which runner clients are looked up (see `get_client`).
pub(crate) static CLIENTS: RunnerClientLocator = RunnerClientLocator::new();

/// Lookup key of the locator: the `TypeId` of the runner channel paired with the device id.
type Key = (core::any::TypeId, DeviceId);
19
20/// Define how to interact with the runner.
21pub trait RunnerClient: Clone + Send + Sync + Sized {
22    /// Device type.
23    type Device: DeviceOps;
24
25    /// Register a new tensor operation to be executed by the (runner) server.
26    fn register_op(&self, op: OperationIr);
27    /// Register a new tensor operation to be executed by the (runner) server.
28    ///
29    /// Returns the new (uninitialized) output tensor(s) generated by the registered operation.
30    fn register(&self, op: OperationIr) -> Vec<RouterTensor<Self>> {
31        let out = op
32            .outputs()
33            .map(|output| {
34                RouterTensor::new(output.id, output.shape.clone(), output.dtype, self.clone())
35            })
36            .collect();
37        self.register_op(op);
38
39        out
40    }
41    /// Read the values contained by a tensor.
42    fn read_tensor_async(&self, tensor: TensorIr) -> DynFut<Result<TensorData, ExecutionError>>;
43    /// Sync the runner, ensure that all computations are finished.
44    fn sync(&self) -> Result<(), ExecutionError>;
45    /// Create a new (uninitialized) empty tensor and returns its corresponding [tensor id](TensorId).
46    fn create_empty_handle(&self) -> TensorId;
47    /// Create a new [RouterTensor] from the tensor data.
48    fn register_tensor_data(&self, data: TensorData) -> RouterTensor<Self>;
49    /// Get the current device used by all operations handled by this client.
50    fn device(&self) -> Self::Device;
51    /// Seed the runner.
52    fn seed(&self, seed: u64);
53    /// Whether the type is supported.
54    fn supports_dtype(&self, dtype: DType) -> bool;
55}
56
/// Registry of runner clients, one entry per (runner channel type, device id) pair.
pub(crate) struct RunnerClientLocator {
    /// Type-erased clients stored behind a spin lock.
    ///
    /// Starts out as `None` (see `new`); the map itself is allocated by the first registration.
    clients: Mutex<Option<HashMap<Key, Box<dyn core::any::Any + Send>>>>,
}
60
61/// Get the client for the given device
62pub fn get_client<R: RunnerChannel>(device: &R::Device) -> Client<R> {
63    CLIENTS.client::<R>(device)
64}
65
66/// Initialize a new client for the given device.
67///
68/// If a (global) seed was previously set, the client seed is set.
69fn new_client<R: RunnerChannel>(device: &R::Device) -> Client<R> {
70    R::init_client(device)
71}
72
73impl RunnerClientLocator {
74    /// Create a new client locator.
75    pub const fn new() -> Self {
76        Self {
77            clients: Mutex::new(None),
78        }
79    }
80
81    /// Get the runner client for the given device.
82    ///
83    /// If a client isn't already initialized, it is created.
84    pub fn client<R: RunnerChannel + 'static>(&self, device: &R::Device) -> Client<R> {
85        let device_id = device.id();
86        let client_id = (core::any::TypeId::of::<R>(), device_id);
87        let mut clients = self.clients.lock();
88
89        if clients.is_none() {
90            let client = new_client::<R>(device);
91            Self::register_inner::<R>(client_id, client, &mut clients);
92        }
93
94        match clients.deref_mut() {
95            Some(clients) => match clients.get(&client_id) {
96                Some(client) => {
97                    let client: &Client<R> = client.downcast_ref().unwrap();
98                    client.clone()
99                }
100                None => {
101                    let client = new_client::<R>(device);
102                    let any = Box::new(client.clone());
103                    clients.insert(client_id, any);
104                    client
105                }
106            },
107            _ => unreachable!(),
108        }
109    }
110
111    fn register_inner<R: RunnerChannel + 'static>(
112        key: Key,
113        client: Client<R>,
114        clients: &mut Option<HashMap<Key, Box<dyn core::any::Any + Send>>>,
115    ) {
116        if clients.is_none() {
117            *clients = Some(HashMap::new());
118        }
119
120        if let Some(clients) = clients {
121            if clients.contains_key(&key) {
122                panic!("Client already created for device {key:?}");
123            }
124
125            clients.insert(key, Box::new(client));
126        }
127    }
128}