tract_linalg/
multithread.rs

1use std::cell::RefCell;
2#[allow(unused_imports)]
3use std::sync::{Arc, Mutex};
4
5#[cfg(feature = "multithread-mm")]
6use rayon::{ThreadPool, ThreadPoolBuilder};
7
/// Handle describing which executor is in use.
///
/// The default is [`Executor::SingleThread`]. When the crate is built with the
/// `multithread-mm` feature, a second variant carries a shared rayon thread
/// pool. Cloning is cheap in both cases: the pool sits behind an `Arc`, so a
/// clone only bumps a reference count.
#[derive(Debug, Clone, Default)]
pub enum Executor {
    /// No thread pool is attached.
    #[default]
    SingleThread,
    /// A reference-counted rayon thread pool.
    #[cfg(feature = "multithread-mm")]
    MultiThread(Arc<ThreadPool>),
}
15
16impl Executor {
17    #[cfg(feature = "multithread-mm")]
18    pub fn multithread(n: usize) -> Executor {
19        Executor::multithread_with_name(n, "tract-default")
20    }
21
22    #[cfg(feature = "multithread-mm")]
23    pub fn multithread_with_name(n: usize, name: &str) -> Executor {
24        let name = name.to_string();
25        let pool = ThreadPoolBuilder::new()
26            .thread_name(move |n| format!("{name}-{n}"))
27            .num_threads(n)
28            .build()
29            .unwrap();
30        Executor::MultiThread(Arc::new(pool))
31    }
32}
33
34static DEFAULT_EXECUTOR: Mutex<Executor> = Mutex::new(Executor::SingleThread);
35
36thread_local! {
37    static TLS_EXECUTOR_OVERRIDE: RefCell<Option<Executor>> = Default::default();
38}
39
40pub fn current_tract_executor() -> Executor {
41    if let Some(over_ride) = TLS_EXECUTOR_OVERRIDE.with_borrow(|tls| tls.clone()) {
42        over_ride
43    } else {
44        DEFAULT_EXECUTOR.lock().unwrap().clone()
45    }
46}
47
48pub fn set_default_executor(executor: Executor) {
49    *DEFAULT_EXECUTOR.lock().unwrap() = executor;
50}
51
52pub fn multithread_tract_scope<R, F: FnOnce() -> R>(pool: Executor, f: F) -> R {
53    let previous = TLS_EXECUTOR_OVERRIDE.replace(Some(pool));
54    let result = f();
55    TLS_EXECUTOR_OVERRIDE.set(previous);
56    result
57}