1use crate::{channel::ComputeChannel, client::ComputeClient, server::ComputeServer};
2use core::ops::DerefMut;
3use hashbrown::HashMap;
4
5pub struct ComputeRuntime<Device, Server: ComputeServer, Channel> {
8 clients: spin::Mutex<Option<HashMap<Device, ComputeClient<Server, Channel>>>>,
9}
10
/// Mode in which compute work is executed; [`ExecutionMode::Checked`] is the default.
///
/// NOTE(review): the variant docs were not visible here. Presumably `Checked`
/// keeps runtime safety checks enabled and `Unchecked` skips them — confirm
/// against the backends that consume this value.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum ExecutionMode {
    /// Default mode (see the type-level note about its exact meaning).
    #[default]
    Checked,
    /// Non-default mode (see the type-level note about its exact meaning).
    Unchecked,
}
20
21pub use cubecl_common::benchmark::{TimestampsError, TimestampsResult};
22
23impl<Device, Server, Channel> Default for ComputeRuntime<Device, Server, Channel>
24where
25 Device: core::hash::Hash + PartialEq + Eq + Clone + core::fmt::Debug,
26 Server: ComputeServer,
27 Channel: ComputeChannel<Server>,
28{
29 fn default() -> Self {
30 Self::new()
31 }
32}
33
34impl<Device, Server, Channel> ComputeRuntime<Device, Server, Channel>
35where
36 Device: core::hash::Hash + PartialEq + Eq + Clone + core::fmt::Debug,
37 Server: ComputeServer,
38 Channel: ComputeChannel<Server>,
39{
40 pub const fn new() -> Self {
42 Self {
43 clients: spin::Mutex::new(None),
44 }
45 }
46
47 pub fn client<Init>(&self, device: &Device, init: Init) -> ComputeClient<Server, Channel>
51 where
52 Init: Fn() -> ComputeClient<Server, Channel>,
53 {
54 let mut clients = self.clients.lock();
55
56 if clients.is_none() {
57 Self::register_inner(device, init(), &mut clients);
58 }
59
60 match clients.deref_mut() {
61 Some(clients) => match clients.get(device) {
62 Some(client) => client.clone(),
63 None => {
64 let client = init();
65 clients.insert(device.clone(), client.clone());
66 client
67 }
68 },
69 _ => unreachable!(),
70 }
71 }
72
73 pub fn register(&self, device: &Device, client: ComputeClient<Server, Channel>) {
84 let mut clients = self.clients.lock();
85
86 Self::register_inner(device, client, &mut clients);
87 }
88
89 fn register_inner(
90 device: &Device,
91 client: ComputeClient<Server, Channel>,
92 clients: &mut Option<HashMap<Device, ComputeClient<Server, Channel>>>,
93 ) {
94 if clients.is_none() {
95 *clients = Some(HashMap::new());
96 }
97
98 if let Some(clients) = clients {
99 if clients.contains_key(device) {
100 panic!("Client already created for device {:?}", device);
101 }
102
103 clients.insert(device.clone(), client);
104 }
105 }
106}