cubecl_runtime/
base.rs

1use crate::{channel::ComputeChannel, client::ComputeClient, server::ComputeServer};
2use core::ops::DerefMut;
3use hashbrown::HashMap;
4
5/// The compute type has the responsibility to retrieve the correct compute client based on the
6/// given device.
7pub struct ComputeRuntime<Device, Server: ComputeServer, Channel> {
8    clients: spin::Mutex<Option<HashMap<Device, ComputeClient<Server, Channel>>>>,
9}
10
11impl<Device, Server, Channel> Default for ComputeRuntime<Device, Server, Channel>
12where
13    Device: core::hash::Hash + PartialEq + Eq + Clone + core::fmt::Debug,
14    Server: ComputeServer,
15    Channel: ComputeChannel<Server>,
16{
17    fn default() -> Self {
18        Self::new()
19    }
20}
21
22impl<Device, Server, Channel> ComputeRuntime<Device, Server, Channel>
23where
24    Device: core::hash::Hash + PartialEq + Eq + Clone + core::fmt::Debug,
25    Server: ComputeServer,
26    Channel: ComputeChannel<Server>,
27{
28    /// Create a new compute.
29    pub const fn new() -> Self {
30        Self {
31            clients: spin::Mutex::new(None),
32        }
33    }
34
35    /// Get the compute client for the given device.
36    ///
37    /// Provide the init function to create a new client if it isn't already initialized.
38    pub fn client<Init>(&self, device: &Device, init: Init) -> ComputeClient<Server, Channel>
39    where
40        Init: Fn() -> ComputeClient<Server, Channel>,
41    {
42        let mut clients = self.clients.lock();
43
44        if clients.is_none() {
45            Self::register_inner(device, init(), &mut clients);
46        }
47
48        match clients.deref_mut() {
49            Some(clients) => match clients.get(device) {
50                Some(client) => client.clone(),
51                None => {
52                    let client = init();
53                    clients.insert(device.clone(), client.clone());
54                    client
55                }
56            },
57            _ => unreachable!(),
58        }
59    }
60
61    /// Register the compute client for the given device.
62    ///
63    /// # Note
64    ///
65    /// This function is mostly useful when the creation of the compute client can't be done
66    /// synchronously and require special context.
67    ///
68    /// # Panics
69    ///
70    /// If a client is already registered for the given device.
71    pub fn register(&self, device: &Device, client: ComputeClient<Server, Channel>) {
72        let mut clients = self.clients.lock();
73
74        Self::register_inner(device, client, &mut clients);
75    }
76
77    fn register_inner(
78        device: &Device,
79        client: ComputeClient<Server, Channel>,
80        clients: &mut Option<HashMap<Device, ComputeClient<Server, Channel>>>,
81    ) {
82        if clients.is_none() {
83            *clients = Some(HashMap::new());
84        }
85
86        if let Some(clients) = clients {
87            if clients.contains_key(device) {
88                panic!("Client already created for device {device:?}");
89            }
90
91            clients.insert(device.clone(), client);
92        }
93    }
94}