async_cuda/runtime/thread_local.rs
1use std::sync::mpsc::Sender;
2
3use once_cell::sync::Lazy;
4
5use crate::error::Error;
6use crate::runtime::execution::RUNTIME;
7use crate::runtime::work::Work;
8
9thread_local! {
10 /// Thread-local runtime delegate.
11 ///
12 /// This object serves as the per-thread reference to the [`RUNTIME`] that can be used to
13 /// enqueue work on the runtime thread.
14 ///
15 /// # Usage
16 ///
17 /// ```ignore
18 /// assert!(
19 /// RUNTIME_THREAD_LOCAL.with(|runtime|
20 /// runtime.enqueue(Work::new(|| ()))
21 /// ).is_ok()
22 /// )
23 /// ```
24 pub(super) static RUNTIME_THREAD_LOCAL: Lazy<RuntimeThreadLocal> = Lazy::new(|| {
25 RUNTIME.lock().unwrap().thread_local()
26 });
27}
28
29/// Per-thread delegate for global runtime.
30pub struct RuntimeThreadLocal(Sender<Work>);
31
32impl RuntimeThreadLocal {
33 /// Initialize [`RuntimeThreadLocal`] from [`Sender`] that allows the delegate to send work to
34 /// the actual [`crate::runtime::execution::Runtime`].
35 ///
36 /// # Arguments
37 ///
38 /// * `sender` - Sender through which work can be sent to runtime.
39 pub(super) fn from_sender(sender: Sender<Work>) -> Self {
40 RuntimeThreadLocal(sender)
41 }
42
43 /// Enqueue work on runtime.
44 ///
45 /// # Arguments
46 ///
47 /// * `function` - Unit of work in function closure to enqueue.
48 pub(super) fn enqueue(&self, function: Work) -> Result<(), Error> {
49 self.0.send(function).map_err(|_| Error::Runtime)
50 }
51}
52
53/// Enqueue work on the runtime without caring about the return value. This is useful in situations
54/// where work must be performed but the result does not matter. For example, when destorying CUDA
55/// object as part of dropping an object.
56///
57/// # Arguments
58///
59/// * `f` - Function closure to execute on runtime.
60///
61/// # Example
62///
63/// ```ignore
64/// enqueue_decoupled(move || {
65/// // ...
66/// });
67/// ```
68#[inline]
69pub fn enqueue_decoupled(f: impl FnOnce() + Send + 'static) {
70 let f = Box::new(f);
71 RUNTIME_THREAD_LOCAL
72 .with(|runtime| runtime.enqueue(Work::new(f)))
73 .expect("runtime broken")
74}
75
#[cfg(test)]
mod tests {
    use std::sync::mpsc;
    use std::time::Duration;

    use super::*;

    /// Longest time a test waits for the enqueued closure to report back.
    const REPLY_TIMEOUT: Duration = Duration::from_millis(100);

    /// Work enqueued through the thread-local handle is accepted and then executed.
    #[test]
    fn test_enqueue_works() {
        let (tx, rx) = mpsc::channel();
        let report = move || {
            assert!(tx.send(true).is_ok());
        };

        let enqueued = RUNTIME_THREAD_LOCAL.with(|runtime| runtime.enqueue(Work::new(report)));
        assert!(enqueued.is_ok());

        let reply = rx.recv_timeout(REPLY_TIMEOUT);
        assert!(matches!(reply, Ok(true)));
    }

    /// Work enqueued through `enqueue_decoupled` is executed.
    #[test]
    fn test_enqueue_decoupled_works() {
        let (tx, rx) = mpsc::channel();
        enqueue_decoupled(move || {
            assert!(tx.send(true).is_ok());
        });

        let reply = rx.recv_timeout(REPLY_TIMEOUT);
        assert!(matches!(reply, Ok(true)));
    }
}