cubecl_common/device/handle/
mod.rs1mod base;
2
3pub use base::*;
4
5use crate::device::{DeviceService, ServerUtilitiesHandle};
6
7#[cfg(feature = "std")]
8#[allow(dead_code)]
9mod channel;
10
11#[allow(dead_code)]
12mod mutex;
13
14#[cfg(feature = "std")]
15#[allow(dead_code)]
16mod reentrant;
17
18#[cfg(all(feature = "std", multi_threading))]
19type Inner<S> = channel::ChannelDeviceHandle<S>;
20#[cfg(all(feature = "std", not(multi_threading)))]
22type Inner<S> = reentrant::ReentrantMutexDeviceHandle<S>;
23#[cfg(all(not(feature = "std"), not(multi_threading)))]
24type Inner<S> = mutex::MutexDeviceHandle<S>;
25
26pub struct DeviceHandle<S: DeviceService> {
28 handle: Inner<S>,
29}
30
31impl<S: DeviceService> Clone for DeviceHandle<S> {
32 fn clone(&self) -> Self {
33 Self {
34 handle: self.handle.clone(),
35 }
36 }
37}
38
39#[allow(missing_docs)]
40impl<S: DeviceService> DeviceHandle<S> {
41 pub const fn is_blocking() -> bool {
42 Inner::<S>::BLOCKING
43 }
44
45 pub fn insert(device_id: super::DeviceId, service: S) -> Result<Self, ServiceCreationError> {
46 Ok(Self {
47 handle: <Inner<S> as DeviceHandleSpec<S>>::insert(device_id, service)?,
48 })
49 }
50
51 pub fn new(device_id: super::DeviceId) -> Self {
52 Self {
53 handle: <Inner<S> as DeviceHandleSpec<S>>::new(device_id),
54 }
55 }
56
57 pub fn utilities(&self) -> ServerUtilitiesHandle {
58 self.handle.utilities()
59 }
60
61 pub fn submit_blocking<R: Send + 'static, T: FnOnce(&mut S) -> R + Send + 'static>(
62 &self,
63 task: T,
64 ) -> Result<R, CallError> {
65 self.handle.submit_blocking(task)
66 }
67
68 pub fn submit_blocking_scoped<'a, R: Send + 'a, T: FnOnce(&mut S) -> R + Send + 'a>(
69 &self,
70 task: T,
71 ) -> R {
72 self.handle.submit_blocking_scoped(task)
73 }
74
75 pub fn submit<T: FnOnce(&mut S) + Send + 'static>(&self, task: T) {
76 self.handle.submit(task)
77 }
78
79 pub fn flush_queue(&self) {
80 self.handle.flush_queue();
81 }
82
83 pub fn exclusive<R: Send + 'static, T: FnOnce() -> R + Send + 'static>(
84 &self,
85 task: T,
86 ) -> Result<R, CallError> {
87 self.handle.exclusive(task)
88 }
89
90 pub fn exclusive_scoped<R: Send, T: FnOnce() -> R + Send>(
91 &self,
92 task: T,
93 ) -> Result<R, CallError> {
94 self.handle.exclusive_scoped(task)
95 }
96}
97
// Runs the shared and recursive test suites against the channel backend.
// Gated on `std` as well as `test`: the `channel` module only exists with the
// `std` feature, so `cargo test` without it would otherwise fail to compile.
#[cfg(all(test, feature = "std"))]
mod tests_channel {
    // Explicit `super::` path so the alias does not depend on a glob import
    // inside the included test files.
    type DeviceHandle<S> = super::channel::ChannelDeviceHandle<S>;

    include!("./tests.rs");
    include!("./tests_recursive.rs");
}
105
// Runs the shared test suite against the mutex backend. That backend is
// compiled with or without `std`, so only `test` gates this module. The
// recursive suite is not included here.
#[cfg(test)]
mod tests_mutex {
    type DeviceHandle<S> = super::mutex::MutexDeviceHandle<S>;

    include!("./tests.rs");
}
112
// Runs the shared and recursive test suites against the reentrant-mutex
// backend. Gated on `std` as well as `test`: the `reentrant` module only exists
// with the `std` feature, so `cargo test` without it would otherwise fail to
// compile.
#[cfg(all(test, feature = "std"))]
mod tests_reentrant {
    // Explicit `super::` path so the alias does not depend on a glob import
    // inside the included test files.
    type DeviceHandle<S> = super::reentrant::ReentrantMutexDeviceHandle<S>;

    include!("./tests.rs");
    include!("./tests_recursive.rs");
}