burn_compute/
compute.rs

1use crate::{channel::ComputeChannel, client::ComputeClient, server::ComputeServer};
2use core::ops::DerefMut;
3use hashbrown::HashMap;
4
5/// The compute type has the responsibility to retrieve the correct compute client based on the
6/// given device.
7pub struct ComputeRuntime<Device, Server: ComputeServer, Channel> {
8    clients: spin::Mutex<Option<HashMap<Device, ComputeClient<Server, Channel>>>>,
9}
10
11impl<Device, Server, Channel> ComputeRuntime<Device, Server, Channel>
12where
13    Device: core::hash::Hash + PartialEq + Eq + Clone + core::fmt::Debug,
14    Server: ComputeServer,
15    Channel: ComputeChannel<Server>,
16{
17    /// Create a new compute.
18    pub const fn new() -> Self {
19        Self {
20            clients: spin::Mutex::new(None),
21        }
22    }
23
24    /// Get the compute client for the given device.
25    ///
26    /// Provide the init function to create a new client if it isn't already initialized.
27    pub fn client<Init>(&self, device: &Device, init: Init) -> ComputeClient<Server, Channel>
28    where
29        Init: Fn() -> ComputeClient<Server, Channel>,
30    {
31        let mut clients = self.clients.lock();
32
33        if clients.is_none() {
34            Self::register_inner(device, init(), &mut clients);
35        }
36
37        match clients.deref_mut() {
38            Some(clients) => match clients.get(device) {
39                Some(client) => client.clone(),
40                None => {
41                    let client = init();
42                    clients.insert(device.clone(), client.clone());
43                    client
44                }
45            },
46            _ => unreachable!(),
47        }
48    }
49
50    /// Register the compute client for the given device.
51    ///
52    /// # Note
53    ///
54    /// This function is mostly useful when the creation of the compute client can't be done
55    /// synchronously and require special context.
56    ///
57    /// # Panics
58    ///
59    /// If a client is already registered for the given device.
60    pub fn register(&self, device: &Device, client: ComputeClient<Server, Channel>) {
61        let mut clients = self.clients.lock();
62
63        Self::register_inner(device, client, &mut clients);
64    }
65
66    fn register_inner(
67        device: &Device,
68        client: ComputeClient<Server, Channel>,
69        clients: &mut Option<HashMap<Device, ComputeClient<Server, Channel>>>,
70    ) {
71        if clients.is_none() {
72            *clients = Some(HashMap::new());
73        }
74
75        if let Some(clients) = clients {
76            if clients.contains_key(device) {
77                panic!("Client already created for device {:?}", device);
78            }
79
80            clients.insert(device.clone(), client);
81        }
82    }
83}