tract_linalg/
multithread.rs1use std::cell::RefCell;
2#[cfg(feature = "multithread-mm")]
3use std::sync::atomic::{AtomicUsize, Ordering};
4#[allow(unused_imports)]
5use std::sync::{Arc, Mutex};
6
7#[cfg(feature = "multithread-mm")]
8use rayon::{ThreadPool, ThreadPoolBuilder};
9
/// How tract schedules the work it hands to an executor.
///
/// [`Executor::SingleThread`] is the default. The pool-backed variants are
/// only compiled in when the `multithread-mm` feature is enabled.
#[derive(Default, Clone, Debug)]
pub enum Executor {
    /// Default variant; carries no thread pool.
    #[default]
    SingleThread,

    /// Holds a dedicated rayon `ThreadPool` behind an `Arc`, so cloning the
    /// executor bumps a reference count instead of creating a new pool.
    #[cfg(feature = "multithread-mm")]
    MultiThread(Arc<ThreadPool>),

    /// Carries no pool of its own. Judging by the name, this presumably defers
    /// to rayon's global pool; confirm at the use sites.
    #[cfg(feature = "multithread-mm")]
    RayonGlobal,
}
27
28impl Executor {
29 #[cfg(feature = "multithread-mm")]
30 pub fn multithread(n: usize) -> Executor {
31 Executor::multithread_with_name(n, "tract-default")
32 }
33
34 #[cfg(feature = "multithread-mm")]
35 pub fn multithread_with_name(n: usize, name: &str) -> Executor {
36 let name = name.to_string();
37 let pool = ThreadPoolBuilder::new()
38 .thread_name(move |n| format!("{name}-{n}"))
39 .num_threads(n)
40 .build()
41 .unwrap();
42 Executor::MultiThread(Arc::new(pool))
43 }
44}
45
46static DEFAULT_EXECUTOR: Mutex<Executor> = Mutex::new(Executor::SingleThread);
47
48thread_local! {
49 static TLS_EXECUTOR_OVERRIDE: RefCell<Option<Executor>> = Default::default();
50}
51
52pub fn current_tract_executor() -> Executor {
53 if let Some(over_ride) = TLS_EXECUTOR_OVERRIDE.with_borrow(|tls| tls.clone()) {
54 over_ride
55 } else {
56 DEFAULT_EXECUTOR.lock().unwrap().clone()
57 }
58}
59
60pub fn set_default_executor(executor: Executor) {
61 *DEFAULT_EXECUTOR.lock().unwrap() = executor;
62}
63
64pub fn multithread_tract_scope<R, F: FnOnce() -> R>(pool: Executor, f: F) -> R {
65 let previous = TLS_EXECUTOR_OVERRIDE.replace(Some(pool));
66 let result = f();
67 TLS_EXECUTOR_OVERRIDE.set(previous);
68 result
69}
70
/// Tunable panel count read by [`current_threading_panel_threshold`] and
/// written by [`set_threading_panel_threshold`]. Initial value: 64.
///
/// NOTE(review): presumably the minimum number of panels before work is split
/// across threads; confirm at the call sites.
#[cfg(feature = "multithread-mm")]
static THREADING_PANEL_THRESHOLD: AtomicUsize = AtomicUsize::new(64);

/// Returns the current threading panel threshold.
///
/// Uses a `Relaxed` load. The value is read on its own and is not used to
/// order any other memory access.
#[cfg(feature = "multithread-mm")]
pub fn current_threading_panel_threshold() -> usize {
    THREADING_PANEL_THRESHOLD.load(Ordering::Relaxed)
}

/// Overwrites the threading panel threshold with `panels`.
///
/// Uses a `Relaxed` store. Other threads observe the new value eventually,
/// with no ordering relative to other writes.
#[cfg(feature = "multithread-mm")]
pub fn set_threading_panel_threshold(panels: usize) {
    THREADING_PANEL_THRESHOLD.store(panels, Ordering::Relaxed)
}
93}