use std::env;
use crossbeam_channel::{bounded, Receiver, SendError};
use log::trace;
use once_cell::sync::Lazy;
use yastl::Pool;
/// Size of the global pool, computed by `read_num_threads` the first time it is accessed.
static NUM_THREADS: Lazy<usize> = Lazy::new(read_num_threads);

/// Process-wide thread pool, created lazily with `NUM_THREADS` threads.
pub static THREAD_POOL: Lazy<Pool> = Lazy::new(|| {
    let thread_count: usize = *Lazy::force(&NUM_THREADS);
    Pool::new(thread_count)
});
/// Determines how many threads the global pool should use.
///
/// Reads the `EC_GPU_NUM_THREADS` environment variable. If it is unset, is not a valid
/// `usize`, or is `0`, this falls back to `num_cpus::get()`.
///
/// Zero is rejected because the thread count is later used as a divisor
/// (`Worker::scope`) and passed to `log2_floor`, both of which panic on zero.
fn read_num_threads() -> usize {
    env::var("EC_GPU_NUM_THREADS")
        .ok()
        .and_then(|num| num.parse::<usize>().ok())
        // A zero-sized pool cannot do any work and would cause panics downstream.
        .filter(|&num| num > 0)
        .unwrap_or_else(num_cpus::get)
}
/// Stateless handle for submitting work to the global `THREAD_POOL`.
///
/// `Worker` has no fields; every method dispatches onto the shared, lazily created pool.
#[derive(Clone, Debug, Default)]
pub struct Worker {}

impl Worker {
    /// Creates a new worker handle.
    pub fn new() -> Worker {
        Worker {}
    }

    /// Returns `floor(log2(n))`, where `n` is the configured number of pool threads.
    ///
    /// # Panics
    ///
    /// Panics if the configured thread count is zero.
    pub fn log_num_threads(&self) -> u32 {
        log2_floor(*NUM_THREADS)
    }

    /// Submits `f` to the thread pool and returns a [`Waiter`] that yields its result.
    ///
    /// If the returned `Waiter` is dropped before the job finishes, the result cannot be
    /// delivered. It is discarded and a trace-level message is logged.
    pub fn compute<F, R>(&self, f: F) -> Waiter<R>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        // Capacity 1: the job sends exactly one value, so the send never blocks.
        let (sender, receiver) = bounded(1);
        THREAD_POOL.spawn(move || {
            let res = f();
            // Fails only when the receiver (the `Waiter`) was dropped; that is not an error.
            if let Err(SendError(_)) = sender.send(res) {
                trace!("Cannot send result");
            }
        });
        Waiter { receiver }
    }

    /// Runs `f` inside a pool scope and passes it a suggested chunk size for splitting
    /// `elements` items across the pool's threads.
    ///
    /// The chunk size is `elements / n` (rounded down) for `n` threads. It is `1` when
    /// there are fewer elements than threads, including when `elements == 0`.
    ///
    /// # Panics
    ///
    /// Panics if the configured thread count is zero (division by zero).
    pub fn scope<'a, F, R>(&self, elements: usize, f: F) -> R
    where
        F: FnOnce(&yastl::Scope<'a>, usize) -> R,
    {
        let chunk_size = if elements < *NUM_THREADS {
            1
        } else {
            elements / *NUM_THREADS
        };
        THREAD_POOL.scoped(|scope| f(scope, chunk_size))
    }

    /// Runs `f` inside a pool scope and returns whatever `f` returns.
    pub fn scoped<'a, F, R>(&self, f: F) -> R
    where
        F: FnOnce(&yastl::Scope<'a>) -> R,
    {
        // `Pool::scoped` already returns the closure's value (as `scope` above relies on),
        // so no channel round-trip (and no `unwrap`) is needed to get the result out.
        THREAD_POOL.scoped(f)
    }
}
/// Handle to a single value that is being computed (see `Worker::compute`) or is
/// already available (see [`Waiter::done`]).
pub struct Waiter<T> {
    receiver: Receiver<T>,
}

impl<T> Waiter<T> {
    /// Blocks until the value is available, then returns it.
    ///
    /// # Panics
    ///
    /// Panics if the sending side was dropped without delivering a value. This happens,
    /// for example, when the job computing the value panicked. It also happens when
    /// `wait` is called a second time, because only one value is ever sent.
    pub fn wait(&self) -> T {
        self.receiver
            .recv()
            .expect("Waiter::wait: no result available (job panicked or result already taken)")
    }

    /// Creates a `Waiter` that already holds `val`.
    pub fn done(val: T) -> Self {
        let (sender, receiver) = bounded(1);
        // A fresh bounded(1) channel whose receiver is alive always has room for one value.
        sender
            .send(val)
            .expect("Waiter::done: sending into a fresh bounded(1) channel cannot fail");
        Waiter { receiver }
    }
}
/// Returns `floor(log2(num))`, i.e. the index of the highest set bit of `num`.
///
/// # Panics
///
/// Panics if `num` is zero, since `log2(0)` is undefined.
fn log2_floor(num: usize) -> u32 {
    assert!(num > 0);
    // Derived from the bit width rather than a shift loop: the old `1 << (pow + 1)`
    // overflowed for inputs at or above `2^(usize::BITS - 1)`.
    (usize::BITS - 1) - num.leading_zeros()
}
#[cfg(test)]
pub mod tests {
    use super::*;

    /// `log2_floor` rounds down to the exponent of the largest power of two <= input.
    #[test]
    fn test_log2_floor() {
        let cases: [(usize, u32); 7] = [(1, 0), (3, 1), (4, 2), (5, 2), (6, 2), (7, 2), (8, 3)];
        for &(input, expected) in cases.iter() {
            assert_eq!(log2_floor(input), expected);
        }
    }

    /// The thread count follows `EC_GPU_NUM_THREADS`, falling back to the CPU count.
    #[test]
    fn test_read_num_threads() {
        let expected_default = num_cpus::get();

        temp_env::with_var("EC_GPU_NUM_THREADS", None::<&str>, || {
            let actual = read_num_threads();
            assert_eq!(
                actual, expected_default,
                "By default the number of threads matches the number of CPUs."
            );
        });

        temp_env::with_var("EC_GPU_NUM_THREADS", Some("1234"), || {
            let actual = read_num_threads();
            assert_eq!(
                actual, 1234,
                "Number of threads matches the environment variable."
            );
        });
    }
}