use crate::{RouterTensor, RunnerChannel};
use alloc::boxed::Box;
use alloc::vec::Vec;
use burn_backend::{
DType, TensorData,
backend::{DeviceId, DeviceOps, ExecutionError},
};
use burn_ir::{OperationIr, TensorId, TensorIr};
use burn_std::future::DynFut;
use core::ops::DerefMut;
use hashbrown::HashMap;
use spin::Mutex;
/// Cache key for a runner client: the `TypeId` of the runner channel plus the device id.
type Key = (core::any::TypeId, DeviceId);

/// Client type associated with the runner channel `R`.
pub type Client<R> = <R as RunnerChannel>::Client;

/// Crate-wide client registry; every call to `get_client` goes through this locator.
pub(crate) static CLIENTS: RunnerClientLocator = RunnerClientLocator::new();
/// A client that receives operations and tensor data on behalf of a router backend.
pub trait RunnerClient: Clone + Send + Sync + Sized {
    /// Device type this client is associated with.
    type Device: DeviceOps;

    /// Hands the operation `op` to the client (required; no default implementation).
    fn register_op(&self, op: OperationIr);

    /// Registers `op` and returns one [`RouterTensor`] per output of the operation.
    ///
    /// Each returned tensor carries the output's id, a copy of its shape, its dtype and a
    /// clone of this client. The handles are built while `op` is still borrowed; only then
    /// is `op` moved into [`RunnerClient::register_op`].
    fn register(&self, op: OperationIr) -> Vec<RouterTensor<Self>> {
        let mut handles = Vec::new();
        for ir in op.outputs() {
            handles.push(RouterTensor::new(
                ir.id,
                ir.shape.clone(),
                ir.dtype,
                self.clone(),
            ));
        }

        self.register_op(op);
        handles
    }

    /// Returns a future that resolves to the data of `tensor`, or an [`ExecutionError`].
    fn read_tensor_async(&self, tensor: TensorIr) -> DynFut<Result<TensorData, ExecutionError>>;

    /// Synchronizes the client; failures are reported as an [`ExecutionError`].
    fn sync(&self) -> Result<(), ExecutionError>;

    /// Returns a new [`TensorId`] (presumably for a tensor with no data yet — confirm in impls).
    fn create_empty_handle(&self) -> TensorId;

    /// Registers `data` with the client and returns a tensor handle for it.
    fn register_tensor_data(&self, data: TensorData) -> RouterTensor<Self>;

    /// Returns the device this client is associated with.
    fn device(&self) -> Self::Device;

    /// Sets the seed (presumably for random operations — confirm in implementations).
    fn seed(&self, seed: u64);

    /// Reports how `dtype` can be used on this client.
    fn dtype_usage(&self, dtype: DType) -> burn_backend::DTypeUsageSet;
}
/// Type-erased map from a [`Key`] to the boxed client created for that key.
type ClientMap = HashMap<Key, Box<dyn core::any::Any + Send>>;

/// Cache of runner clients keyed by runner-channel type and device id.
///
/// The map is behind a `spin::Mutex` and starts out as `None`, which lets the locator be
/// built in a `const` context.
pub(crate) struct RunnerClientLocator {
    clients: Mutex<Option<ClientMap>>,
}
pub fn get_client<R: RunnerChannel>(device: &R::Device) -> Client<R> {
CLIENTS.client::<R>(device)
}
/// Builds a fresh client for `device` through the channel's `init_client` hook.
fn new_client<R: RunnerChannel>(device: &R::Device) -> Client<R> {
    <R as RunnerChannel>::init_client(device)
}
impl RunnerClientLocator {
    /// Creates an empty locator.
    ///
    /// The map itself is allocated lazily on the first lookup, so this constructor stays
    /// `const` and can initialize a `static`.
    pub const fn new() -> Self {
        Self {
            clients: Mutex::new(None),
        }
    }

    /// Returns the client of runner channel `R` for `device`, creating it on first use.
    ///
    /// Clients are cached under `(TypeId::of::<R>(), device.id())`. A missing client is built
    /// with `R::init_client` while the locator's lock is held. As a result, concurrent callers
    /// asking for the same key never create two clients. The caller receives a clone of the
    /// cached client.
    ///
    /// NOTE: `spin::Mutex` is not reentrant, so an `init_client` implementation that calls
    /// back into this locator would spin forever.
    ///
    /// # Panics
    ///
    /// Only if the map's type invariant is broken. Every entry stored under a key whose
    /// `TypeId` is `R` holds a `Client<R>`, which is guaranteed because this method is the
    /// only writer.
    pub fn client<R: RunnerChannel + 'static>(&self, device: &R::Device) -> Client<R> {
        let client_id: Key = (core::any::TypeId::of::<R>(), device.id());

        let mut guard = self.clients.lock();
        // Lazily allocate the map on the very first request.
        let clients = guard.deref_mut().get_or_insert_with(HashMap::new);

        // Single lookup: insert a freshly created client only if none is cached yet.
        let entry = clients
            .entry(client_id)
            .or_insert_with(|| Box::new(new_client::<R>(device)));

        let client: &Client<R> = entry
            .downcast_ref()
            .expect("client stored under TypeId::of::<R>() must be a Client<R>");
        client.clone()
    }
}