Skip to main content

cubecl_common/device/handle/
mod.rs

1mod base;
2
3pub use base::*;
4
5use crate::device::{DeviceId, DeviceService, ServerUtilitiesHandle};
6
7#[cfg(feature = "std")]
8#[allow(dead_code)]
9mod channel;
10
11#[allow(dead_code)]
12mod mutex;
13
14#[cfg(feature = "std")]
15#[allow(dead_code)]
16mod reentrant;
17
18#[cfg(all(feature = "std", multi_threading))]
19type Inner<S> = channel::ChannelDeviceHandle<S>;
20// type Inner<S> = mutex::MutexDeviceHandle<S>;
21#[cfg(all(feature = "std", not(multi_threading)))]
22type Inner<S> = reentrant::ReentrantMutexDeviceHandle<S>;
23#[cfg(all(not(feature = "std"), not(multi_threading)))]
24type Inner<S> = mutex::MutexDeviceHandle<S>;
25
26/// TODO: Docs
27pub struct DeviceHandle<S: DeviceService> {
28    handle: Inner<S>,
29}
30
31impl<S: DeviceService> Clone for DeviceHandle<S> {
32    fn clone(&self) -> Self {
33        Self {
34            handle: self.handle.clone(),
35        }
36    }
37}
38
39#[allow(missing_docs)]
40impl<S: DeviceService> DeviceHandle<S> {
41    pub const fn is_blocking() -> bool {
42        Inner::<S>::BLOCKING
43    }
44
45    pub fn insert(device_id: super::DeviceId, service: S) -> Result<Self, ServiceCreationError> {
46        Ok(Self {
47            handle: <Inner<S> as DeviceHandleSpec<S>>::insert(device_id, service)?,
48        })
49    }
50
51    pub fn new(device_id: super::DeviceId) -> Self {
52        Self {
53            handle: <Inner<S> as DeviceHandleSpec<S>>::new(device_id),
54        }
55    }
56
57    pub fn device_id(&self) -> DeviceId {
58        self.handle.device_id()
59    }
60
61    pub fn utilities(&self) -> ServerUtilitiesHandle {
62        self.handle.utilities()
63    }
64
65    pub fn submit_blocking<'a, R: Send, T: FnOnce(&mut S) -> R + Send + 'a>(
66        &self,
67        task: T,
68    ) -> Result<R, CallError> {
69        self.handle.submit_blocking(task)
70    }
71
72    pub fn submit<T: FnOnce(&mut S) + Send + 'static>(&self, task: T) {
73        self.handle.submit(task)
74    }
75
76    pub fn flush_queue(&self) {
77        self.handle.flush_queue();
78    }
79
80    pub fn exclusive<R: Send, T: FnOnce() -> R + Send>(&self, task: T) -> Result<R, CallError> {
81        self.handle.exclusive(task)
82    }
83}
84
// The channel backend is only compiled with the `std` feature, so its tests must be
// gated the same way; otherwise `cargo test` without `std` fails to resolve `channel`.
#[cfg(all(test, feature = "std"))]
mod tests_channel {
    type DeviceHandle<S> = super::channel::ChannelDeviceHandle<S>;

    include!("./tests.rs");
    include!("./tests_recursive.rs");
}
92
#[cfg(test)]
mod tests_mutex {
    // `mutex` is declared without a `std` gate, so this module needs only `test`.
    // Unlike the channel and re-entrant variants, `tests_recursive.rs` is not included
    // here (presumably because the plain mutex backend is not re-entrant).
    type DeviceHandle<S> = super::mutex::MutexDeviceHandle<S>;

    include!("./tests.rs");
}
99
// The re-entrant backend is only compiled with the `std` feature, so its tests must be
// gated the same way; otherwise `cargo test` without `std` fails to resolve `reentrant`.
#[cfg(all(test, feature = "std"))]
mod tests_reentrant {
    type DeviceHandle<S> = super::reentrant::ReentrantMutexDeviceHandle<S>;

    include!("./tests.rs");
    include!("./tests_recursive.rs");
}