Skip to main content

cubecl_common/device/handle/
mod.rs

1mod base;
2
3pub use base::*;
4
5use crate::device::{DeviceService, ServerUtilitiesHandle};
6
7#[cfg(feature = "std")]
8#[allow(dead_code)]
9mod channel;
10
11#[allow(dead_code)]
12mod mutex;
13
14#[cfg(feature = "std")]
15#[allow(dead_code)]
16mod reentrant;
17
18#[cfg(all(feature = "std", multi_threading))]
19type Inner<S> = channel::ChannelDeviceHandle<S>;
20// type Inner<S> = reentrant::ReentrantMutexDeviceHandle<S>;
21#[cfg(all(feature = "std", not(multi_threading)))]
22type Inner<S> = reentrant::ReentrantMutexDeviceHandle<S>;
23#[cfg(all(not(feature = "std"), not(multi_threading)))]
24type Inner<S> = mutex::MutexDeviceHandle<S>;
25
26/// TODO: Docs
27pub struct DeviceHandle<S: DeviceService> {
28    handle: Inner<S>,
29}
30
31impl<S: DeviceService> Clone for DeviceHandle<S> {
32    fn clone(&self) -> Self {
33        Self {
34            handle: self.handle.clone(),
35        }
36    }
37}
38
39#[allow(missing_docs)]
40impl<S: DeviceService> DeviceHandle<S> {
41    pub const fn is_blocking() -> bool {
42        Inner::<S>::BLOCKING
43    }
44
45    pub fn insert(device_id: super::DeviceId, service: S) -> Result<Self, ServiceCreationError> {
46        Ok(Self {
47            handle: <Inner<S> as DeviceHandleSpec<S>>::insert(device_id, service)?,
48        })
49    }
50
51    pub fn new(device_id: super::DeviceId) -> Self {
52        Self {
53            handle: <Inner<S> as DeviceHandleSpec<S>>::new(device_id),
54        }
55    }
56
57    pub fn utilities(&self) -> ServerUtilitiesHandle {
58        self.handle.utilities()
59    }
60
61    pub fn submit_blocking<R: Send + 'static, T: FnOnce(&mut S) -> R + Send + 'static>(
62        &self,
63        task: T,
64    ) -> Result<R, CallError> {
65        self.handle.submit_blocking(task)
66    }
67
68    pub fn submit_blocking_scoped<'a, R: Send + 'a, T: FnOnce(&mut S) -> R + Send + 'a>(
69        &self,
70        task: T,
71    ) -> R {
72        self.handle.submit_blocking_scoped(task)
73    }
74
75    pub fn submit<T: FnOnce(&mut S) + Send + 'static>(&self, task: T) {
76        self.handle.submit(task)
77    }
78
79    pub fn flush_queue(&self) {
80        self.handle.flush_queue();
81    }
82
83    pub fn exclusive<R: Send + 'static, T: FnOnce() -> R + Send + 'static>(
84        &self,
85        task: T,
86    ) -> Result<R, CallError> {
87        self.handle.exclusive(task)
88    }
89
90    pub fn exclusive_scoped<R: Send, T: FnOnce() -> R + Send>(
91        &self,
92        task: T,
93    ) -> Result<R, CallError> {
94        self.handle.exclusive_scoped(task)
95    }
96}
97
// Runs the shared test suites against the channel-based handle by aliasing `DeviceHandle`
// to it before including the test sources.
// Gated on `std` as well as `test` because the `channel` module is only compiled when the
// `std` feature is enabled; without the extra gate a no-std test build fails to resolve it.
#[cfg(all(test, feature = "std"))]
mod tests_channel {
    type DeviceHandle<S> = channel::ChannelDeviceHandle<S>;

    include!("./tests.rs");
    include!("./tests_recursive.rs");
}
105
// Runs the shared test suite in `tests.rs` against the mutex-based handle by aliasing
// `DeviceHandle` to it before including the test source.
// NOTE(review): `tests_recursive.rs` is not included here, presumably because the plain
// mutex handle is not expected to support recursive calls — confirm.
#[cfg(test)]
mod tests_mutex {
    type DeviceHandle<S> = mutex::MutexDeviceHandle<S>;

    include!("./tests.rs");
}
112
// Runs the shared test suites against the reentrant-mutex handle by aliasing `DeviceHandle`
// to it before including the test sources.
// Gated on `std` as well as `test` because the `reentrant` module is only compiled when the
// `std` feature is enabled; without the extra gate a no-std test build fails to resolve it.
#[cfg(all(test, feature = "std"))]
mod tests_reentrant {
    type DeviceHandle<S> = reentrant::ReentrantMutexDeviceHandle<S>;

    include!("./tests.rs");
    include!("./tests_recursive.rs");
}
119}