rstsr_openblas/
threading.rs

//! Threading control for OpenBLAS: backend detection and thread-count management.
2
3use crate::prelude_dev::*;
4#[cfg(any(feature = "openmp", feature = "dynamic_loading"))]
5use core::ffi::c_int;
6use rstsr_blas_traits::prelude_dev::*;
7use std::sync::Mutex;
8
9use rstsr_openblas_ffi::cblas::{OPENBLAS_OPENMP, OPENBLAS_SEQUENTIAL, OPENBLAS_THREAD};
10
11/* #region required openmp ffi */
12
13/* #endregion */
14
15/* #region parallel scheme */
16
/// Threading backend reported by the linked OpenBLAS library.
///
/// Values of this type are produced by [`get_parallel`], which maps the integer returned
/// by `openblas_get_parallel` onto these variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OpenBLASParallel {
    /// The library reported `OPENBLAS_SEQUENTIAL`.
    Sequential,
    /// The library reported `OPENBLAS_THREAD`.
    Thread,
    /// The library reported `OPENBLAS_OPENMP`.
    OpenMP,
}
23
24pub fn get_parallel() -> OpenBLASParallel {
25    unsafe {
26        match rstsr_openblas_ffi::cblas::openblas_get_parallel().try_into().unwrap() {
27            OPENBLAS_SEQUENTIAL => OpenBLASParallel::Sequential,
28            OPENBLAS_THREAD => OpenBLASParallel::Thread,
29            OPENBLAS_OPENMP => {
30                if cfg!(any(feature = "openmp", feature = "dynamic_loading")) {
31                    OpenBLASParallel::OpenMP
32                } else {
33                    panic!(concat!(
34                        "OpenMP is not enabled in `rstsr-openblas-ffi`, but detected using shared library `libopenblas` compiled with OpenMP.\n",
35                        "Please either\n",
36                        "- enable feature `dynamic_loading` when building `rstsr-openblas` and rebuild this crate, and everything will be determined at runtime;\n",
37                        "- enable feature `openmp` when building `rstsr-openblas` and rebuild this crate, with OpenMP library linked;\n",
38                        "- run with libopenblas compiled with pthread (rebuild `rstsr-openblas-ffi` is not required in this case).",
39                    ))
40                }
41            },
42            _ => panic!("Unknown parallelism type"),
43        }
44    }
45}
46
47/* #endregion */
48
49/* #region threading number control */
50
/// Cached knowledge about the OpenBLAS threading backend, used to route thread-count
/// calls to the matching API.
struct OpenBLASConfig {
    /// Raw value returned by `openblas_get_parallel`, stored after the first successful
    /// query; `None` until then.
    parallel: Option<u32>,
}
54
55impl OpenBLASConfig {
56    fn set_num_threads(&mut self, n: usize) {
57        unsafe {
58            match self.get_parallel() {
59                OPENBLAS_THREAD => rstsr_openblas_ffi::cblas::openblas_set_num_threads(n as i32),
60                #[cfg(any(feature = "openmp", feature = "dynamic_loading"))]
61                OPENBLAS_OPENMP => rstsr_openblas_ffi::cblas::omp_set_num_threads(n as c_int),
62                _ => (),
63            }
64        }
65    }
66
67    fn get_num_threads(&mut self) -> usize {
68        unsafe {
69            match self.get_parallel() {
70                OPENBLAS_THREAD => rstsr_openblas_ffi::cblas::openblas_get_num_threads() as usize,
71                #[cfg(any(feature = "openmp", feature = "dynamic_loading"))]
72                OPENBLAS_OPENMP => rstsr_openblas_ffi::cblas::omp_get_max_threads() as usize,
73                _ => 1,
74            }
75        }
76    }
77
78    fn get_parallel(&mut self) -> u32 {
79        match self.parallel {
80            Some(p) => p,
81            None => {
82                let p = unsafe { rstsr_openblas_ffi::cblas::openblas_get_parallel() } as u32;
83                if cfg!(any(feature = "openmp", feature = "dynamic_loading")) {
84                    self.parallel = Some(p);
85                    p
86                } else {
87                    panic!(concat!(
88                        "OpenMP is not enabled in `rstsr-openblas-ffi`, but detected using shared library `libopenblas` compiled with OpenMP.\n",
89                        "Please either\n",
90                        "- enable feature `dynamic_loading` when building `rstsr-openblas` and rebuild this crate, and everything will be determined at runtime;\n",
91                        "- enable feature `openmp` when building `rstsr-openblas` and rebuild this crate, with OpenMP library linked;\n",
92                        "- run with libopenblas compiled with pthread (rebuild `rstsr-openblas-ffi` is not required in this case).",
93                    ))
94                }
95            },
96        }
97    }
98}
99
// Process-wide configuration shared by the public thread-count functions in this module;
// they lock this mutex for each query or update.
static INTERNAL: Mutex<OpenBLASConfig> = Mutex::new(OpenBLASConfig { parallel: None });
101
102/// Set number of threads for OpenBLAS.
103///
104/// This function should be safe to call from multiple threads.
105pub fn set_num_threads(n: usize) {
106    INTERNAL.lock().unwrap().set_num_threads(n);
107}
108
109pub fn get_num_threads() -> usize {
110    INTERNAL.lock().unwrap().get_num_threads()
111}
112
113pub fn with_num_threads<F, R>(nthreads: usize, f: F) -> R
114where
115    F: FnOnce() -> R,
116{
117    let n = get_num_threads();
118    set_num_threads(nthreads);
119    let r = f();
120    set_num_threads(n);
121    return r;
122}
123
124/* #endregion */
125
126/* #region trait impl */
127
/// Exposes the module-level thread controls through `BlasThreadAPI`.
///
/// Every method forwards to the function of the corresponding name in
/// `crate::threading`; `self` is not consulted.
impl BlasThreadAPI for DeviceBLAS {
    /// Returns the result of `crate::threading::get_num_threads`.
    fn get_blas_num_threads(&self) -> usize {
        crate::threading::get_num_threads()
    }

    /// Forwards `nthreads` to `crate::threading::set_num_threads`.
    fn set_blas_num_threads(&self, nthreads: usize) {
        crate::threading::set_num_threads(nthreads);
    }

    /// Runs `f` via `crate::threading::with_num_threads` and returns its result.
    fn with_blas_num_threads<T>(&self, nthreads: usize, f: impl FnOnce() -> T) -> T {
        crate::threading::with_num_threads(nthreads, f)
    }
}
141
142/* #endregion */