1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::mpsc::{channel, Receiver, Sender};
use std::sync::{Arc, Mutex};
use once_cell::sync::Lazy;
use crate::runtime::thread_local::RuntimeThreadLocal;
use crate::runtime::work::Work;
/// Refers to the global runtime. The runtime is responsible for running all CUDA operations in a
/// dedicated thread.
///
/// Note that this object should not be used by callers because each thread gets its own delegate
/// object to communicate with the runtime.
///
/// # Usage
///
/// Each thread should get its own [`RuntimeThreadLocal`] object, which acts as delegate object.
///
/// Use `Runtime::thread_local` to get the thread local object:
///
/// ```ignore
/// let runtime = RUNTIME.lock().unwrap().thread_local();
/// ```
pub(super) static RUNTIME: Lazy<Mutex<Runtime>> = Lazy::new(|| Mutex::new(Runtime::new()));
/// Runtime object that holds the runtime thread and a channel
/// to send jobs onto the worker queue.
pub struct Runtime {
join_handle: Option<std::thread::JoinHandle<()>>,
run_flag: Arc<AtomicBool>,
work_tx: Sender<Work>,
}
impl Runtime {
/// Acquire a thread local delegate for the runtime.
pub(super) fn thread_local(&self) -> RuntimeThreadLocal {
RuntimeThreadLocal::from_sender(self.work_tx.clone())
}
/// Create runtime.
fn new() -> Self {
let run_flag = Arc::new(AtomicBool::new(true));
let (work_tx, work_rx) = channel::<Work>();
let join_handle = std::thread::spawn({
let run_flag = run_flag.clone();
move || Self::worker(run_flag, work_rx)
});
Runtime {
join_handle: Some(join_handle),
run_flag,
work_tx,
}
}
/// Worker loop. Receives jobs from the worker queue and executes them until [`run_flag`]
/// becomes `false`.
///
/// # Arguments
///
/// * `run_flag` - Atomic flag that indicates whether the worker should continue running.
/// * `work_rx` - Receives work to execute.
fn worker(run_flag: Arc<AtomicBool>, work_rx: Receiver<Work>) {
while run_flag.load(Ordering::Relaxed) {
match work_rx.recv() {
Ok(work) => work.run(),
Err(_) => break,
}
}
}
}
impl Drop for Runtime {
fn drop(&mut self) {
self.run_flag.store(false, Ordering::Relaxed);
// Put dummy workload into the queue to trigger the loop to continue and encounted the
// `run_flag` that is now false, then stop. Note that if this fails, it means the underlying
// channel is broken. It is not a problem, since that must mean the worker already quit
// before, and it will join immediatly.
let _ = self.work_tx.send(Work::new(|| {}));
if let Some(join_handle) = self.join_handle.take() {
join_handle
.join()
.expect("failed to join on runtime thread");
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_drop() {
let runtime = Runtime::new();
std::thread::sleep(std::time::Duration::from_millis(10));
drop(runtime);
}
#[test]
fn test_it_does_work() {
let runtime = Runtime::new();
let (tx, rx) = std::sync::mpsc::channel();
assert!(runtime
.thread_local()
.enqueue(Work::new(move || {
assert!(tx.send(true).is_ok());
}))
.is_ok());
assert!(matches!(
rx.recv_timeout(std::time::Duration::from_millis(100)),
Ok(true),
));
}
}