cubecl_runtime/
base.rs

1use crate::{channel::ComputeChannel, client::ComputeClient, server::ComputeServer};
2use core::ops::DerefMut;
3use hashbrown::HashMap;
4
5/// The compute type has the responsibility to retrieve the correct compute client based on the
6/// given device.
7pub struct ComputeRuntime<Device, Server: ComputeServer, Channel> {
8    clients: spin::Mutex<Option<HashMap<Device, ComputeClient<Server, Channel>>>>,
9}
10
/// Selects how kernels are executed.
///
/// Defaults to [`ExecutionMode::Checked`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum ExecutionMode {
    /// Checked execution: the safe mode, and the default.
    #[default]
    Checked,
    /// Unchecked execution: kernels run in this mode are unsafe.
    Unchecked,
}
20
21pub use cubecl_common::benchmark::{TimestampsError, TimestampsResult};
22
23impl<Device, Server, Channel> Default for ComputeRuntime<Device, Server, Channel>
24where
25    Device: core::hash::Hash + PartialEq + Eq + Clone + core::fmt::Debug,
26    Server: ComputeServer,
27    Channel: ComputeChannel<Server>,
28{
29    fn default() -> Self {
30        Self::new()
31    }
32}
33
34impl<Device, Server, Channel> ComputeRuntime<Device, Server, Channel>
35where
36    Device: core::hash::Hash + PartialEq + Eq + Clone + core::fmt::Debug,
37    Server: ComputeServer,
38    Channel: ComputeChannel<Server>,
39{
40    /// Create a new compute.
41    pub const fn new() -> Self {
42        Self {
43            clients: spin::Mutex::new(None),
44        }
45    }
46
47    /// Get the compute client for the given device.
48    ///
49    /// Provide the init function to create a new client if it isn't already initialized.
50    pub fn client<Init>(&self, device: &Device, init: Init) -> ComputeClient<Server, Channel>
51    where
52        Init: Fn() -> ComputeClient<Server, Channel>,
53    {
54        let mut clients = self.clients.lock();
55
56        if clients.is_none() {
57            Self::register_inner(device, init(), &mut clients);
58        }
59
60        match clients.deref_mut() {
61            Some(clients) => match clients.get(device) {
62                Some(client) => client.clone(),
63                None => {
64                    let client = init();
65                    clients.insert(device.clone(), client.clone());
66                    client
67                }
68            },
69            _ => unreachable!(),
70        }
71    }
72
73    /// Register the compute client for the given device.
74    ///
75    /// # Note
76    ///
77    /// This function is mostly useful when the creation of the compute client can't be done
78    /// synchronously and require special context.
79    ///
80    /// # Panics
81    ///
82    /// If a client is already registered for the given device.
83    pub fn register(&self, device: &Device, client: ComputeClient<Server, Channel>) {
84        let mut clients = self.clients.lock();
85
86        Self::register_inner(device, client, &mut clients);
87    }
88
89    fn register_inner(
90        device: &Device,
91        client: ComputeClient<Server, Channel>,
92        clients: &mut Option<HashMap<Device, ComputeClient<Server, Channel>>>,
93    ) {
94        if clients.is_none() {
95            *clients = Some(HashMap::new());
96        }
97
98        if let Some(clients) = clients {
99            if clients.contains_key(device) {
100                panic!("Client already created for device {:?}", device);
101            }
102
103            clients.insert(device.clone(), client);
104        }
105    }
106}