async_cuda/runtime/
thread_local.rs

1use std::sync::mpsc::Sender;
2
3use once_cell::sync::Lazy;
4
5use crate::error::Error;
6use crate::runtime::execution::RUNTIME;
7use crate::runtime::work::Work;
8
9thread_local! {
10    /// Thread-local runtime delegate.
11    ///
12    /// This object serves as the per-thread reference to the [`RUNTIME`] that can be used to
13    /// enqueue work on the runtime thread.
14    ///
15    /// # Usage
16    ///
17    /// ```ignore
18    /// assert!(
19    ///     RUNTIME_THREAD_LOCAL.with(|runtime|
20    ///         runtime.enqueue(Work::new(|| ()))
21    ///     ).is_ok()
22    /// )
23    /// ```
24    pub(super) static RUNTIME_THREAD_LOCAL: Lazy<RuntimeThreadLocal> = Lazy::new(|| {
25        RUNTIME.lock().unwrap().thread_local()
26    });
27}
28
29/// Per-thread delegate for global runtime.
30pub struct RuntimeThreadLocal(Sender<Work>);
31
32impl RuntimeThreadLocal {
33    /// Initialize [`RuntimeThreadLocal`] from [`Sender`] that allows the delegate to send work to
34    /// the actual [`crate::runtime::execution::Runtime`].
35    ///
36    /// # Arguments
37    ///
38    /// * `sender` - Sender through which work can be sent to runtime.
39    pub(super) fn from_sender(sender: Sender<Work>) -> Self {
40        RuntimeThreadLocal(sender)
41    }
42
43    /// Enqueue work on runtime.
44    ///
45    /// # Arguments
46    ///
47    /// * `function` - Unit of work in function closure to enqueue.
48    pub(super) fn enqueue(&self, function: Work) -> Result<(), Error> {
49        self.0.send(function).map_err(|_| Error::Runtime)
50    }
51}
52
53/// Enqueue work on the runtime without caring about the return value. This is useful in situations
54/// where work must be performed but the result does not matter. For example, when destorying CUDA
55/// object as part of dropping an object.
56///
57/// # Arguments
58///
59/// * `f` - Function closure to execute on runtime.
60///
61/// # Example
62///
63/// ```ignore
64/// enqueue_decoupled(move || {
65///     // ...
66/// });
67/// ```
68#[inline]
69pub fn enqueue_decoupled(f: impl FnOnce() + Send + 'static) {
70    let f = Box::new(f);
71    RUNTIME_THREAD_LOCAL
72        .with(|runtime| runtime.enqueue(Work::new(f)))
73        .expect("runtime broken")
74}
75
#[cfg(test)]
mod tests {
    use super::*;

    /// How long a test waits for the runtime to execute the enqueued closure.
    const RECV_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(100);

    /// Assert that the closure running on the runtime reported back `true` within the timeout.
    fn assert_signalled(rx: &std::sync::mpsc::Receiver<bool>) {
        assert_eq!(rx.recv_timeout(RECV_TIMEOUT), Ok(true));
    }

    #[test]
    fn test_enqueue_works() {
        let (tx, rx) = std::sync::mpsc::channel();
        let outcome = RUNTIME_THREAD_LOCAL.with(|runtime| {
            runtime.enqueue(Work::new(move || {
                assert!(tx.send(true).is_ok());
            }))
        });
        assert!(outcome.is_ok());
        assert_signalled(&rx);
    }

    #[test]
    fn test_enqueue_decoupled_works() {
        let (tx, rx) = std::sync::mpsc::channel();
        enqueue_decoupled(move || assert!(tx.send(true).is_ok()));
        assert_signalled(&rx);
    }
}