1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107
use std::sync::mpsc::Sender;
use once_cell::sync::Lazy;
use crate::error::Error;
use crate::runtime::execution::RUNTIME;
use crate::runtime::work::Work;
thread_local! {
    /// Per-thread handle onto the global [`RUNTIME`].
    ///
    /// The first time a thread touches this value, the global runtime is locked and asked for a
    /// fresh [`RuntimeThreadLocal`] via `thread_local()`. Later accesses from the same thread
    /// reuse that handle. Use it to push [`Work`] onto the runtime thread.
    ///
    /// Panics on first access from a thread if locking [`RUNTIME`] fails (e.g. a poisoned
    /// mutex — assuming `RUNTIME` wraps `std::sync::Mutex`; TODO confirm).
    ///
    /// # Usage
    ///
    /// ```ignore
    /// let result = RUNTIME_THREAD_LOCAL.with(|runtime| runtime.enqueue(Work::new(|| ())));
    /// assert!(result.is_ok());
    /// ```
    // NOTE(review): `thread_local!` statics are already initialized lazily per thread, so the
    // inner `Lazy` only adds an indirection; it is kept to preserve the static's type.
    pub(super) static RUNTIME_THREAD_LOCAL: Lazy<RuntimeThreadLocal> =
        Lazy::new(|| RUNTIME.lock().unwrap().thread_local());
}
/// Per-thread delegate for the global runtime.
///
/// Wraps the sending half of an [`std::sync::mpsc`] channel through which [`Work`] items are
/// handed to the runtime.
#[derive(Debug)]
pub struct RuntimeThreadLocal(Sender<Work>);

impl RuntimeThreadLocal {
    /// Initialize [`RuntimeThreadLocal`] from a [`Sender`] that allows the delegate to send work
    /// to the actual [`crate::runtime::execution::Runtime`].
    ///
    /// # Arguments
    ///
    /// * `sender` - Sender through which work can be sent to runtime.
    pub(super) fn from_sender(sender: Sender<Work>) -> Self {
        Self(sender)
    }

    /// Enqueue work on the runtime.
    ///
    /// # Arguments
    ///
    /// * `function` - Unit of work in function closure to enqueue.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Runtime`] if the work could not be sent. An mpsc send only fails when
    /// the receiving end of the channel has been dropped (presumably meaning the runtime has
    /// shut down).
    pub(super) fn enqueue(&self, function: Work) -> Result<(), Error> {
        // The `SendError` hands the unsent `Work` back; it is dropped here because callers only
        // need to know that the runtime is unreachable.
        self.0.send(function).map_err(|_| Error::Runtime)
    }
}
/// Enqueue work on the runtime without caring about its return value.
///
/// This is useful when work must be performed but its result does not matter. An example is
/// destroying a CUDA object as part of dropping another object.
///
/// # Arguments
///
/// * `f` - Function closure to execute on runtime.
///
/// # Panics
///
/// Panics with `"runtime broken"` if the thread-local runtime delegate fails to enqueue the
/// work. Also panics if called after this thread's thread-local storage has been destroyed,
/// since [`std::thread::LocalKey::with`] panics in that case. If this is called from a `Drop`
/// implementation while the thread is already unwinding, such a panic aborts the process.
///
/// # Example
///
/// ```ignore
/// enqueue_decoupled(move || {
///     // ...
/// });
/// ```
#[inline]
pub fn enqueue_decoupled(f: impl FnOnce() + Send + 'static) {
    let f = Box::new(f);
    RUNTIME_THREAD_LOCAL
        .with(|runtime| runtime.enqueue(Work::new(f)))
        .expect("runtime broken")
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;

    /// Upper bound on how long a test waits for enqueued work to report back.
    ///
    /// The first access from a thread may trigger initialization of the global runtime, which
    /// can be slow. The previous 100 ms bound was therefore flaky. `recv_timeout` returns as soon
    /// as the message arrives, so a generous bound costs nothing when the test passes.
    const RECV_TIMEOUT: Duration = Duration::from_secs(5);

    #[test]
    fn test_enqueue_works() {
        let (tx, rx) = std::sync::mpsc::channel();
        let enqueued = RUNTIME_THREAD_LOCAL.with(|runtime| {
            runtime.enqueue(Work::new(move || {
                assert!(tx.send(true).is_ok());
            }))
        });
        assert!(enqueued.is_ok());
        assert_eq!(rx.recv_timeout(RECV_TIMEOUT), Ok(true));
    }

    #[test]
    fn test_enqueue_decoupled_works() {
        let (tx, rx) = std::sync::mpsc::channel();
        enqueue_decoupled(move || {
            assert!(tx.send(true).is_ok());
        });
        assert_eq!(rx.recv_timeout(RECV_TIMEOUT), Ok(true));
    }
}