// async_cuda_core/runtime/execution.rs

use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::{channel, Receiver, Sender};
use std::sync::{Arc, Mutex};

use once_cell::sync::Lazy;

use crate::runtime::thread_local::RuntimeThreadLocal;
use crate::runtime::work::Work;

10/// Refers to the global runtime. The runtime is responsible for running all CUDA operations in a
11/// dedicated thread.
12///
13/// Note that this object should not be used by callers because each thread gets its own delegate
14/// object to communicate with the runtime.
15///
16/// # Usage
17///
18/// Each thread should get its own [`RuntimeThreadLocal`] object, which acts as delegate object.
19///
20/// Use `Runtime::thread_local` to get the thread local object:
21///
22/// ```ignore
23/// let runtime = RUNTIME.lock().unwrap().thread_local();
24/// ```
25pub(super) static RUNTIME: Lazy<Mutex<Runtime>> = Lazy::new(|| Mutex::new(Runtime::new()));
26
27/// Runtime object that holds the runtime thread and a channel
28/// to send jobs onto the worker queue.
29pub struct Runtime {
30    join_handle: Option<std::thread::JoinHandle<()>>,
31    run_flag: Arc<AtomicBool>,
32    work_tx: Sender<Work>,
33}
34
35impl Runtime {
36    /// Acquire a thread local delegate for the runtime.
37    pub(super) fn thread_local(&self) -> RuntimeThreadLocal {
38        RuntimeThreadLocal::from_sender(self.work_tx.clone())
39    }
40
41    /// Create runtime.
42    fn new() -> Self {
43        let run_flag = Arc::new(AtomicBool::new(true));
44        let (work_tx, work_rx) = channel::<Work>();
45
46        let join_handle = std::thread::spawn({
47            let run_flag = run_flag.clone();
48            move || Self::worker(run_flag, work_rx)
49        });
50
51        Runtime {
52            join_handle: Some(join_handle),
53            run_flag,
54            work_tx,
55        }
56    }
57
58    /// Worker loop. Receives jobs from the worker queue and executes them until [`run_flag`]
59    /// becomes `false`.
60    ///
61    /// # Arguments
62    ///
63    /// * `run_flag` - Atomic flag that indicates whether the worker should continue running.
64    /// * `work_rx` - Receives work to execute.
65    fn worker(run_flag: Arc<AtomicBool>, work_rx: Receiver<Work>) {
66        while run_flag.load(Ordering::Relaxed) {
67            match work_rx.recv() {
68                Ok(work) => work.run(),
69                Err(_) => break,
70            }
71        }
72    }
73}
74
75impl Drop for Runtime {
76    fn drop(&mut self) {
77        self.run_flag.store(false, Ordering::Relaxed);
78
79        // Put dummy workload into the queue to trigger the loop to continue and encounted the
80        // `run_flag` that is now false, then stop. Note that if this fails, it means the underlying
81        // channel is broken. It is not a problem, since that must mean the worker already quit
82        // before, and it will join immediatly.
83        let _ = self.work_tx.send(Work::new(|| {}));
84
85        if let Some(join_handle) = self.join_handle.take() {
86            join_handle
87                .join()
88                .expect("failed to join on runtime thread");
89        }
90    }
91}
92
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Dropping a freshly started runtime shuts its worker thread down and joins it.
    #[test]
    fn test_drop() {
        let runtime = Runtime::new();
        // Give the worker a moment to start up before tearing the runtime down.
        std::thread::sleep(Duration::from_millis(10));
        drop(runtime);
    }

    /// A job enqueued through a thread local delegate is executed by the runtime.
    #[test]
    fn test_it_does_work() {
        let runtime = Runtime::new();
        let (done_tx, done_rx) = std::sync::mpsc::channel();

        let job = Work::new(move || {
            assert!(done_tx.send(true).is_ok());
        });
        let enqueued = runtime.thread_local().enqueue(job).is_ok();
        assert!(enqueued);

        let outcome = done_rx.recv_timeout(Duration::from_millis(100));
        assert!(matches!(outcome, Ok(true)));
    }
}