1use std::env;
3
4use crossbeam_channel::{bounded, Receiver, SendError};
5use log::trace;
6use once_cell::sync::Lazy;
7use yastl::Pool;
8
9static NUM_THREADS: Lazy<usize> = Lazy::new(read_num_threads);
14
15pub static THREAD_POOL: Lazy<Pool> = Lazy::new(|| Pool::new(*NUM_THREADS));
20
21fn read_num_threads() -> usize {
26 env::var("EC_GPU_NUM_THREADS")
27 .ok()
28 .and_then(|num| num.parse::<usize>().ok())
29 .unwrap_or_else(num_cpus::get)
30}
31
32#[derive(Clone, Default)]
34pub struct Worker {}
35
36impl Worker {
37 pub fn new() -> Worker {
39 Worker {}
40 }
41
42 pub fn log_num_threads(&self) -> u32 {
46 log2_floor(*NUM_THREADS)
47 }
48
49 pub fn compute<F, R>(&self, f: F) -> Waiter<R>
51 where
52 F: FnOnce() -> R + Send + 'static,
53 R: Send + 'static,
54 {
55 let (sender, receiver) = bounded(1);
56
57 THREAD_POOL.spawn(move || {
58 let res = f();
59 if let Err(SendError(_)) = sender.send(res) {
64 trace!("Cannot send result");
65 }
66 });
67
68 Waiter { receiver }
69 }
70
71 pub fn scope<'a, F, R>(&self, elements: usize, f: F) -> R
76 where
77 F: FnOnce(&yastl::Scope<'a>, usize) -> R,
78 {
79 let chunk_size = if elements < *NUM_THREADS {
80 1
81 } else {
82 elements / *NUM_THREADS
83 };
84
85 THREAD_POOL.scoped(|scope| f(scope, chunk_size))
86 }
87
88 pub fn scoped<'a, F, R>(&self, f: F) -> R
90 where
91 F: FnOnce(&yastl::Scope<'a>) -> R,
92 {
93 let (sender, receiver) = bounded(1);
94 THREAD_POOL.scoped(|s| {
95 let res = f(s);
96 sender.send(res).unwrap();
97 });
98
99 receiver.recv().unwrap()
100 }
101}
102
103pub struct Waiter<T> {
105 receiver: Receiver<T>,
106}
107
108impl<T> Waiter<T> {
109 pub fn wait(&self) -> T {
111 self.receiver.recv().unwrap()
112 }
113
114 pub fn done(val: T) -> Self {
116 let (sender, receiver) = bounded(1);
117 sender.send(val).unwrap();
118
119 Waiter { receiver }
120 }
121}
122
/// Returns `floor(log2(num))`, i.e. the index of the most significant set bit of `num`.
///
/// # Panics
///
/// Panics if `num` is `0`, whose logarithm is undefined.
fn log2_floor(num: usize) -> u32 {
    assert!(num > 0);

    // Derived from the bit width instead of repeatedly shifting `1` left. The shift loop
    // overflows once `num >= 2^(BITS - 1)`: that panics in debug builds, and never
    // terminates in release builds, where the shift amount is masked.
    usize::BITS - 1 - num.leading_zeros()
}
134
#[cfg(test)]
pub mod tests {
    use super::*;

    #[test]
    fn test_log2_floor() {
        // (input, expected floor(log2(input)))
        let cases = [(1, 0), (3, 1), (4, 2), (5, 2), (6, 2), (7, 2), (8, 3)];
        for &(input, expected) in cases.iter() {
            assert_eq!(log2_floor(input), expected);
        }
    }

    #[test]
    fn test_read_num_threads() {
        let expected_default = num_cpus::get();

        // With the variable unset, the CPU count is used.
        temp_env::with_var("EC_GPU_NUM_THREADS", None::<&str>, || {
            let actual = read_num_threads();
            assert_eq!(
                actual, expected_default,
                "By default the number of threads matches the number of CPUs."
            );
        });

        // An explicit value overrides the default.
        temp_env::with_var("EC_GPU_NUM_THREADS", Some("1234"), || {
            let actual = read_num_threads();
            assert_eq!(
                actual, 1234,
                "Number of threads matches the environment variable."
            );
        });
    }
}