use crate::prelude_dev::*;
extern crate alloc;
use alloc::sync::Arc;
/// Thread-pool management interface for devices that run work on a rayon `ThreadPool`.
pub trait DeviceRayonAPI {
/// Sets the number of threads the device should use.
///
/// In the `DeviceCpuRayon` implementation in this file, a changed value
/// causes a new thread pool to be built.
fn set_num_threads(&mut self, num_threads: usize);
/// Returns the number of threads of the device.
///
/// In the `DeviceCpuRayon` implementation in this file, a stored value of
/// `0` is reported as the actual size of the device's pool.
fn get_num_threads(&self) -> usize;
/// Returns a reference to the device's thread pool.
fn get_pool(&self) -> &ThreadPool;
/// Returns the thread pool to use from the calling context, if any.
///
/// In the `DeviceCpuRayon` implementation in this file, this is `None` when
/// the caller is already running on a rayon worker thread, and `Some(pool)`
/// otherwise.
/// NOTE(review): presumably `None` tells callers to stay on the pool that is
/// already running them instead of entering the device's pool; confirm
/// against call sites.
fn get_current_pool(&self) -> Option<&ThreadPool>;
}
/// CPU device that owns a rayon thread pool.
///
/// The pool is held behind an `Arc`, so cloning the device shares the same
/// pool rather than creating a new one.
#[derive(Clone, Debug)]
pub struct DeviceCpuRayon {
// Thread count exactly as requested by the caller; `0` is stored as-is and
// only resolved to a concrete count when the pool is built or queried.
num_threads: usize,
// Thread pool of this device, reference-counted so clones share it.
pool: Arc<ThreadPool>,
// Order flag returned by `default_order()` and changed by `set_default_order()`.
default_order: FlagOrder,
}
impl DeviceCpuRayon {
pub fn new(num_threads: usize) -> Self {
let pool = Arc::new(Self::generate_pool(num_threads).unwrap());
DeviceCpuRayon { num_threads, pool, default_order: FlagOrder::default() }
}
fn generate_pool(n: usize) -> Result<ThreadPool> {
let actual_threads = if n == 0 { rayon::current_num_threads() } else { n };
rayon::ThreadPoolBuilder::new().num_threads(actual_threads).build().map_err(Error::from)
}
}
impl Default for DeviceCpuRayon {
fn default() -> Self {
DeviceCpuRayon::new(0)
}
}
impl DeviceBaseAPI for DeviceCpuRayon {
    /// Returns `true` when `other` has the same effective thread count and
    /// the same default order as `self`.
    ///
    /// The thread count is part of the comparison because devices with
    /// different thread counts own differently sized pools; treating them as
    /// the same device would leave it ambiguous which pool an operation
    /// between them should run on.
    fn same_device(&self, other: &Self) -> bool {
        // Compare effective counts (via `get_num_threads`) rather than the
        // raw field, so a device created with `0` equals one created with the
        // explicit count its pool actually resolved to.
        let same_num_threads = self.get_num_threads() == other.get_num_threads();
        let same_default_order = self.default_order == other.default_order;
        same_num_threads && same_default_order
    }

    /// Returns the order flag currently configured for this device.
    fn default_order(&self) -> FlagOrder {
        self.default_order
    }

    /// Replaces the order flag configured for this device.
    fn set_default_order(&mut self, order: FlagOrder) {
        self.default_order = order;
    }
}
impl DeviceRayonAPI for DeviceCpuRayon {
    /// Switches the device to a pool built for `num_threads`.
    ///
    /// Does nothing when `num_threads` equals the stored request, so the
    /// existing pool is kept. Otherwise a new pool is built first, and only
    /// then are the stored count and pool replaced.
    ///
    /// # Panics
    ///
    /// Panics if the new thread pool cannot be built.
    #[inline]
    fn set_num_threads(&mut self, num_threads: usize) {
        if self.num_threads == num_threads {
            return;
        }
        let rebuilt = Self::generate_pool(num_threads).unwrap();
        self.num_threads = num_threads;
        self.pool = Arc::new(rebuilt);
    }

    /// Returns the thread count of this device.
    ///
    /// A stored request of `0` is reported as the actual size of the pool;
    /// any other stored value is returned as-is.
    #[inline]
    fn get_num_threads(&self) -> usize {
        if self.num_threads == 0 {
            self.pool.current_num_threads()
        } else {
            self.num_threads
        }
    }

    /// Returns the device's own thread pool.
    #[inline]
    fn get_pool(&self) -> &ThreadPool {
        &self.pool
    }

    /// Returns the device's pool unless the caller is already on a rayon
    /// worker thread, in which case `None` is returned.
    #[inline]
    fn get_current_pool(&self) -> Option<&ThreadPool> {
        // `current_thread_index()` is `Some` only on a rayon worker thread.
        if rayon::current_thread_index().is_some() {
            None
        } else {
            Some(&*self.pool)
        }
    }
}