use crate::{channel::ComputeChannel, client::ComputeClient, server::ComputeServer};
use core::ops::DerefMut;
use hashbrown::HashMap;
/// Registry holding one [`ComputeClient`] per `Device`, guarded by a spin lock.
///
/// The map is created lazily on first use; storing `None` initially lets the
/// runtime be built by a `const fn` (presumably so it can back a `static`).
pub struct ComputeRuntime<Device, Server, Channel>
where
    Server: ComputeServer,
{
    /// `None` until a client is first requested or registered.
    clients: spin::Mutex<Option<HashMap<Device, ComputeClient<Server, Channel>>>>,
}
/// Selects between a checked and an unchecked execution mode.
///
/// `Checked` is the default.
// NOTE(review): what "checked" guards (e.g. bounds checks) is not visible here — confirm with users.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum ExecutionMode {
    /// Default mode.
    #[default]
    Checked,
    /// Opt-out of the checked behavior.
    Unchecked,
}
/// An empty runtime with no registered clients, identical to [`ComputeRuntime::new`].
impl<
        Device: core::hash::Hash + PartialEq + Eq + Clone + core::fmt::Debug,
        Server: ComputeServer,
        Channel: ComputeChannel<Server>,
    > Default for ComputeRuntime<Device, Server, Channel>
{
    fn default() -> Self {
        // Delegate so both constructors always produce the same initial state.
        Self::new()
    }
}
impl<Device, Server, Channel> ComputeRuntime<Device, Server, Channel>
where
    Device: core::hash::Hash + PartialEq + Eq + Clone + core::fmt::Debug,
    Server: ComputeServer,
    Channel: ComputeChannel<Server>,
{
    /// Create an empty runtime with no registered clients.
    ///
    /// This is a `const fn`, so it can be used to initialize a `static`.
    pub const fn new() -> Self {
        Self {
            clients: spin::Mutex::new(None),
        }
    }

    /// Return the client for `device`, creating and caching it with `init` on first request.
    ///
    /// `init` is called at most once per call, and only when no client is cached for
    /// `device`. It runs while the runtime's internal `spin::Mutex` is held; that lock is
    /// not reentrant, so `init` must not call back into this runtime or it will spin forever.
    pub fn client<Init>(&self, device: &Device, init: Init) -> ComputeClient<Server, Channel>
    where
        Init: FnOnce() -> ComputeClient<Server, Channel>,
    {
        let mut guard = self.clients.lock();
        // Lazily create the map on first use; `new` only stores `None`.
        let clients = guard.deref_mut().get_or_insert_with(HashMap::new);

        if let Some(client) = clients.get(device) {
            return client.clone();
        }

        // Cache miss: build the client once, cache it, and hand back a clone.
        let client = init();
        clients.insert(device.clone(), client.clone());
        client
    }

    /// Register `client` as the client for `device`.
    ///
    /// # Panics
    ///
    /// Panics if a client has already been created or registered for `device`.
    pub fn register(&self, device: &Device, client: ComputeClient<Server, Channel>) {
        let mut clients = self.clients.lock();
        Self::register_inner(device, client, &mut clients);
    }

    /// Insert `client` for `device` into `clients`, creating the map if it does not exist yet.
    ///
    /// # Panics
    ///
    /// Panics if `clients` already holds an entry for `device`; the existing entry is
    /// left untouched.
    fn register_inner(
        device: &Device,
        client: ComputeClient<Server, Channel>,
        clients: &mut Option<HashMap<Device, ComputeClient<Server, Channel>>>,
    ) {
        let clients = clients.get_or_insert_with(HashMap::new);
        assert!(
            !clients.contains_key(device),
            "Client already created for device {:?}",
            device
        );
        clients.insert(device.clone(), client);
    }
}