cubecl_common/device/handle/
mod.rs1mod base;
2
3pub use base::*;
4
5use crate::device::{DeviceId, DeviceService, ServerUtilitiesHandle};
6
7#[cfg(feature = "std")]
8#[allow(dead_code)]
9mod channel;
10
11#[allow(dead_code)]
12mod mutex;
13
14#[cfg(feature = "std")]
15#[allow(dead_code)]
16mod reentrant;
17
18#[cfg(all(feature = "std", multi_threading))]
19type Inner<S> = channel::ChannelDeviceHandle<S>;
20#[cfg(all(feature = "std", not(multi_threading)))]
22type Inner<S> = reentrant::ReentrantMutexDeviceHandle<S>;
23#[cfg(all(not(feature = "std"), not(multi_threading)))]
24type Inner<S> = mutex::MutexDeviceHandle<S>;
25
26pub struct DeviceHandle<S: DeviceService> {
28 handle: Inner<S>,
29}
30
31impl<S: DeviceService> Clone for DeviceHandle<S> {
32 fn clone(&self) -> Self {
33 Self {
34 handle: self.handle.clone(),
35 }
36 }
37}
38
39#[allow(missing_docs)]
40impl<S: DeviceService> DeviceHandle<S> {
41 pub const fn is_blocking() -> bool {
42 Inner::<S>::BLOCKING
43 }
44
45 pub fn insert(device_id: super::DeviceId, service: S) -> Result<Self, ServiceCreationError> {
46 Ok(Self {
47 handle: <Inner<S> as DeviceHandleSpec<S>>::insert(device_id, service)?,
48 })
49 }
50
51 pub fn new(device_id: super::DeviceId) -> Self {
52 Self {
53 handle: <Inner<S> as DeviceHandleSpec<S>>::new(device_id),
54 }
55 }
56
57 pub fn device_id(&self) -> DeviceId {
58 self.handle.device_id()
59 }
60
61 pub fn utilities(&self) -> ServerUtilitiesHandle {
62 self.handle.utilities()
63 }
64
65 pub fn submit_blocking<'a, R: Send, T: FnOnce(&mut S) -> R + Send + 'a>(
66 &self,
67 task: T,
68 ) -> Result<R, CallError> {
69 self.handle.submit_blocking(task)
70 }
71
72 pub fn submit<T: FnOnce(&mut S) + Send + 'static>(&self, task: T) {
73 self.handle.submit(task)
74 }
75
76 pub fn flush_queue(&self) {
77 self.handle.flush_queue();
78 }
79
80 pub fn exclusive<R: Send, T: FnOnce() -> R + Send>(&self, task: T) -> Result<R, CallError> {
81 self.handle.exclusive(task)
82 }
83}
84
/// Runs the shared handle test suites against `channel::ChannelDeviceHandle`.
///
/// The `channel` module is only compiled with the `std` feature, so this module must be gated
/// on it too. Otherwise, test builds without `std` fail to resolve the backend type.
#[cfg(all(test, feature = "std"))]
mod tests_channel {
    /// Backend under test, picked up by the included suites under the `DeviceHandle` name.
    type DeviceHandle<S> = super::channel::ChannelDeviceHandle<S>;

    include!("./tests.rs");
    include!("./tests_recursive.rs");
}
92
/// Runs the shared handle test suite in `tests.rs` against `mutex::MutexDeviceHandle`.
///
/// `mutex` is compiled unconditionally, so this module only needs the `test` gate.
#[cfg(test)]
mod tests_mutex {
    /// Backend under test, picked up by the included suite under the `DeviceHandle` name.
    type DeviceHandle<S> = super::mutex::MutexDeviceHandle<S>;

    include!("./tests.rs");
}
99
/// Runs the shared handle test suites against `reentrant::ReentrantMutexDeviceHandle`.
///
/// The `reentrant` module is only compiled with the `std` feature, so this module must be
/// gated on it too. Otherwise, test builds without `std` fail to resolve the backend type.
#[cfg(all(test, feature = "std"))]
mod tests_reentrant {
    /// Backend under test, picked up by the included suites under the `DeviceHandle` name.
    type DeviceHandle<S> = super::reentrant::ReentrantMutexDeviceHandle<S>;

    include!("./tests.rs");
    include!("./tests_recursive.rs");
}