rstsr_core/feature_rayon/
device.rs

1use crate::prelude_dev::*;
2
3extern crate alloc;
4use alloc::sync::Arc;
5
6pub trait DeviceRayonAPI {
7    /// Set the number of threads for the device.
8    fn set_num_threads(&mut self, num_threads: usize);
9
10    /// Get the number of threads for the device.
11    ///
12    /// This function should give the number of threads for the pool. It is not
13    /// related to whether the current work is done in parallel or serial.
14    fn get_num_threads(&self) -> usize;
15
16    /// Get the thread pool for the device.
17    ///
18    /// **Note**:
19    ///
20    /// For developers, this function should not be used directly. Instead, use
21    /// `get_current_pool` to detect whether using thread pool of its own (Some)
22    /// or using parent thread pool (None).
23    fn get_pool(&self) -> &ThreadPool;
24
25    /// Get the current thread pool for the device.
26    ///
27    /// - If in parallel worker, this returns None. This means the program should use the thread
28    ///   pool from the parent. It is important that this does not necessarily means this work
29    ///   should be done in serial.
30    /// - If not in rayon parallel worker, this returns the thread pool.
31    fn get_current_pool(&self) -> Option<&ThreadPool>;
32}
33
34/// This is base device for Parallel CPU device.
35///
36/// This device is not intended to be used directly, but to be used as a base.
37/// Possible inherited devices could be Faer or Blas.
38///
39/// This device is intended not to implement `DeviceAPI<T>`.
40#[derive(Clone, Debug)]
41pub struct DeviceCpuRayon {
42    num_threads: usize,
43    pool: Arc<ThreadPool>,
44    default_order: FlagOrder,
45}
46
47impl DeviceCpuRayon {
48    pub fn new(num_threads: usize) -> Self {
49        let pool = Arc::new(Self::generate_pool(num_threads).unwrap());
50        DeviceCpuRayon { num_threads, pool, default_order: FlagOrder::default() }
51    }
52
53    /// Generate a new thread pool with the given number of threads.
54    ///
55    /// If the number of threads is 0, the current number of threads will be used.
56    ///
57    /// Notes for developers:
58    /// - This function will still gives number of threads > 1 when inside parallelled rayon thread
59    ///   pool.
60    /// - For input number of threads 0, this function technically **DOES NOT** give thread pool
61    ///   that relates to `RAYON_NUM_THREADS`, but the number of threads of global thread pool
62    ///   instead. That is to say, the priority of number of threads is
63    ///   - The value of current number of threads, if this function is called inside a user custom
64    ///     thread pool.
65    ///   - The value you initialized the rayon's global thread pool before calling this function:
66    ///
67    ///     ```rust,ignore
68    ///     rayon::ThreadPoolBuilder::new().num_threads(xxx).build_global().unwrap()
69    ///     ```
70    ///   - The value you have declared in environmental variable `RAYON_NUM_THREADS`.
71    ///   - The number of logical CPUs on the machine.
72    fn generate_pool(n: usize) -> Result<ThreadPool> {
73        let actual_threads = if n == 0 { rayon::current_num_threads() } else { n };
74        rayon::ThreadPoolBuilder::new().num_threads(actual_threads).build().map_err(Error::from)
75    }
76}
77
78impl Default for DeviceCpuRayon {
79    fn default() -> Self {
80        DeviceCpuRayon::new(0)
81    }
82}
83
84impl DeviceBaseAPI for DeviceCpuRayon {
85    fn same_device(&self, other: &Self) -> bool {
86        self.default_order == other.default_order
87    }
88
89    fn default_order(&self) -> FlagOrder {
90        self.default_order
91    }
92
93    fn set_default_order(&mut self, order: FlagOrder) {
94        self.default_order = order;
95    }
96}
97
98impl DeviceRayonAPI for DeviceCpuRayon {
99    #[inline]
100    fn set_num_threads(&mut self, num_threads: usize) {
101        let num_threads_old = self.num_threads;
102        if num_threads_old != num_threads {
103            let pool = Self::generate_pool(num_threads).unwrap();
104            self.num_threads = num_threads;
105            self.pool = Arc::new(pool);
106        }
107    }
108
109    #[inline]
110    fn get_num_threads(&self) -> usize {
111        match self.num_threads {
112            0 => self.pool.current_num_threads(),
113            _ => self.num_threads,
114        }
115    }
116
117    #[inline]
118    fn get_pool(&self) -> &ThreadPool {
119        self.pool.as_ref()
120    }
121
122    #[inline]
123    fn get_current_pool(&self) -> Option<&ThreadPool> {
124        match rayon::current_thread_index() {
125            Some(_) => None,
126            None => Some(self.pool.as_ref()),
127        }
128    }
129}