burn_router/client/
base.rs1use crate::{RouterTensor, RunnerChannel};
2use alloc::boxed::Box;
3use alloc::vec::Vec;
4use burn_backend::{
5 DType, TensorData,
6 backend::{DeviceId, DeviceOps, ExecutionError},
7};
8use burn_ir::{OperationIr, TensorId, TensorIr};
9use burn_std::future::DynFut;
10use core::ops::DerefMut;
11use hashbrown::HashMap;
12use spin::Mutex;
13
14pub type Client<R> = <R as RunnerChannel>::Client;
16pub(crate) static CLIENTS: RunnerClientLocator = RunnerClientLocator::new();
17
18type Key = (core::any::TypeId, DeviceId);
19
20pub trait RunnerClient: Clone + Send + Sync + Sized {
22 type Device: DeviceOps;
24
25 fn register_op(&self, op: OperationIr);
27 fn register(&self, op: OperationIr) -> Vec<RouterTensor<Self>> {
31 let out = op
32 .outputs()
33 .map(|output| {
34 RouterTensor::new(output.id, output.shape.clone(), output.dtype, self.clone())
35 })
36 .collect();
37 self.register_op(op);
38
39 out
40 }
41 fn read_tensor_async(&self, tensor: TensorIr) -> DynFut<Result<TensorData, ExecutionError>>;
43 fn sync(&self) -> Result<(), ExecutionError>;
45 fn create_empty_handle(&self) -> TensorId;
47 fn register_tensor_data(&self, data: TensorData) -> RouterTensor<Self>;
49 fn device(&self) -> Self::Device;
51 fn seed(&self, seed: u64);
53 fn supports_dtype(&self, dtype: DType) -> bool;
55}
56
57pub(crate) struct RunnerClientLocator {
58 clients: Mutex<Option<HashMap<Key, Box<dyn core::any::Any + Send>>>>,
59}
60
61pub fn get_client<R: RunnerChannel>(device: &R::Device) -> Client<R> {
63 CLIENTS.client::<R>(device)
64}
65
66fn new_client<R: RunnerChannel>(device: &R::Device) -> Client<R> {
70 R::init_client(device)
71}
72
73impl RunnerClientLocator {
74 pub const fn new() -> Self {
76 Self {
77 clients: Mutex::new(None),
78 }
79 }
80
81 pub fn client<R: RunnerChannel + 'static>(&self, device: &R::Device) -> Client<R> {
85 let device_id = device.id();
86 let client_id = (core::any::TypeId::of::<R>(), device_id);
87 let mut clients = self.clients.lock();
88
89 if clients.is_none() {
90 let client = new_client::<R>(device);
91 Self::register_inner::<R>(client_id, client, &mut clients);
92 }
93
94 match clients.deref_mut() {
95 Some(clients) => match clients.get(&client_id) {
96 Some(client) => {
97 let client: &Client<R> = client.downcast_ref().unwrap();
98 client.clone()
99 }
100 None => {
101 let client = new_client::<R>(device);
102 let any = Box::new(client.clone());
103 clients.insert(client_id, any);
104 client
105 }
106 },
107 _ => unreachable!(),
108 }
109 }
110
111 fn register_inner<R: RunnerChannel + 'static>(
112 key: Key,
113 client: Client<R>,
114 clients: &mut Option<HashMap<Key, Box<dyn core::any::Any + Send>>>,
115 ) {
116 if clients.is_none() {
117 *clients = Some(HashMap::new());
118 }
119
120 if let Some(clients) = clients {
121 if clients.contains_key(&key) {
122 panic!("Client already created for device {key:?}");
123 }
124
125 clients.insert(key, Box::new(client));
126 }
127 }
128}