rstsr_openblas/
threading.rs1use crate::prelude_dev::*;
4#[cfg(any(feature = "openmp", feature = "dynamic_loading"))]
5use core::ffi::c_int;
6use rstsr_blas_traits::prelude_dev::*;
7use std::sync::Mutex;
8
9use rstsr_openblas_ffi::cblas::{OPENBLAS_OPENMP, OPENBLAS_SEQUENTIAL, OPENBLAS_THREAD};
10
/// Threading backend that the linked OpenBLAS library was built with, as reported by
/// [`get_parallel`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum OpenBLASParallel {
    /// Sequential build (`OPENBLAS_SEQUENTIAL`).
    Sequential,
    /// Built-in threading, i.e. a pthread build (`OPENBLAS_THREAD`).
    Thread,
    /// OpenMP-based threading (`OPENBLAS_OPENMP`).
    OpenMP,
}
23
24pub fn get_parallel() -> OpenBLASParallel {
25 unsafe {
26 match rstsr_openblas_ffi::cblas::openblas_get_parallel().try_into().unwrap() {
27 OPENBLAS_SEQUENTIAL => OpenBLASParallel::Sequential,
28 OPENBLAS_THREAD => OpenBLASParallel::Thread,
29 OPENBLAS_OPENMP => {
30 if cfg!(any(feature = "openmp", feature = "dynamic_loading")) {
31 OpenBLASParallel::OpenMP
32 } else {
33 panic!(concat!(
34 "OpenMP is not enabled in `rstsr-openblas-ffi`, but detected using shared library `libopenblas` compiled with OpenMP.\n",
35 "Please either\n",
36 "- enable feature `dynamic_loading` when building `rstsr-openblas` and rebuild this crate, and everything will be determined at runtime;\n",
37 "- enable feature `openmp` when building `rstsr-openblas` and rebuild this crate, with OpenMP library linked;\n",
38 "- run with libopenblas compiled with pthread (rebuild `rstsr-openblas-ffi` is not required in this case).",
39 ))
40 }
41 },
42 _ => panic!("Unknown parallelism type"),
43 }
44 }
45}
46
47struct OpenBLASConfig {
52 parallel: Option<u32>,
53}
54
55impl OpenBLASConfig {
56 fn set_num_threads(&mut self, n: usize) {
57 unsafe {
58 match self.get_parallel() {
59 OPENBLAS_THREAD => rstsr_openblas_ffi::cblas::openblas_set_num_threads(n as i32),
60 #[cfg(any(feature = "openmp", feature = "dynamic_loading"))]
61 OPENBLAS_OPENMP => rstsr_openblas_ffi::cblas::omp_set_num_threads(n as c_int),
62 _ => (),
63 }
64 }
65 }
66
67 fn get_num_threads(&mut self) -> usize {
68 unsafe {
69 match self.get_parallel() {
70 OPENBLAS_THREAD => rstsr_openblas_ffi::cblas::openblas_get_num_threads() as usize,
71 #[cfg(any(feature = "openmp", feature = "dynamic_loading"))]
72 OPENBLAS_OPENMP => rstsr_openblas_ffi::cblas::omp_get_max_threads() as usize,
73 _ => 1,
74 }
75 }
76 }
77
78 fn get_parallel(&mut self) -> u32 {
79 match self.parallel {
80 Some(p) => p,
81 None => {
82 let p = unsafe { rstsr_openblas_ffi::cblas::openblas_get_parallel() } as u32;
83 if cfg!(any(feature = "openmp", feature = "dynamic_loading")) {
84 self.parallel = Some(p);
85 p
86 } else {
87 panic!(concat!(
88 "OpenMP is not enabled in `rstsr-openblas-ffi`, but detected using shared library `libopenblas` compiled with OpenMP.\n",
89 "Please either\n",
90 "- enable feature `dynamic_loading` when building `rstsr-openblas` and rebuild this crate, and everything will be determined at runtime;\n",
91 "- enable feature `openmp` when building `rstsr-openblas` and rebuild this crate, with OpenMP library linked;\n",
92 "- run with libopenblas compiled with pthread (rebuild `rstsr-openblas-ffi` is not required in this case).",
93 ))
94 }
95 },
96 }
97 }
98}
99
100static INTERNAL: Mutex<OpenBLASConfig> = Mutex::new(OpenBLASConfig { parallel: None });
101
102pub fn set_num_threads(n: usize) {
106 INTERNAL.lock().unwrap().set_num_threads(n);
107}
108
109pub fn get_num_threads() -> usize {
110 INTERNAL.lock().unwrap().get_num_threads()
111}
112
113pub fn with_num_threads<F, R>(nthreads: usize, f: F) -> R
114where
115 F: FnOnce() -> R,
116{
117 let n = get_num_threads();
118 set_num_threads(nthreads);
119 let r = f();
120 set_num_threads(n);
121 return r;
122}
123
124impl BlasThreadAPI for DeviceBLAS {
129 fn get_blas_num_threads(&self) -> usize {
130 crate::threading::get_num_threads()
131 }
132
133 fn set_blas_num_threads(&self, nthreads: usize) {
134 crate::threading::set_num_threads(nthreads);
135 }
136
137 fn with_blas_num_threads<T>(&self, nthreads: usize, f: impl FnOnce() -> T) -> T {
138 crate::threading::with_num_threads(nthreads, f)
139 }
140}
141
142