ec_gpu_gen/
threadpool.rs

1//! An interface for dealing with the kinds of parallel computations involved.
2use std::env;
3
4use crossbeam_channel::{bounded, Receiver, SendError};
5use log::trace;
6use once_cell::sync::Lazy;
7use yastl::Pool;
8
9/// The number of threads the thread pool should use.
10///
11/// By default it's equal to the number of CPUs, but it can be changed with the
12/// `EC_GPU_NUM_THREADS` environment variable.
13static NUM_THREADS: Lazy<usize> = Lazy::new(read_num_threads);
14
15/// The thread pool that is used for the computations.
16///
17/// By default, it's size is equal to the number of CPUs. It can be set to a different value with
18/// the `EC_GPU_NUM_THREADS` environment variable.
19pub static THREAD_POOL: Lazy<Pool> = Lazy::new(|| Pool::new(*NUM_THREADS));
20
21/// Returns the number of threads.
22///
23/// The number can be set with the `EC_GPU_NUM_THREADS` environment variable. If it isn't set, it
24/// defaults to the number of CPUs the system has.
25fn read_num_threads() -> usize {
26    env::var("EC_GPU_NUM_THREADS")
27        .ok()
28        .and_then(|num| num.parse::<usize>().ok())
29        .unwrap_or_else(num_cpus::get)
30}
31
/// A worker operates on a pool of threads.
///
/// The worker has no state of its own; work is dispatched to the global [`THREAD_POOL`].
#[derive(Clone, Debug, Default)]
pub struct Worker {}
35
36impl Worker {
37    /// Returns a new worker.
38    pub fn new() -> Worker {
39        Worker {}
40    }
41
42    /// Returns binary logarithm (floored) of the number of threads.
43    ///
44    /// This means, the number of threads is `2^log_num_threads()`.
45    pub fn log_num_threads(&self) -> u32 {
46        log2_floor(*NUM_THREADS)
47    }
48
49    /// Executes a function in a thread and returns a [`Waiter`] immediately.
50    pub fn compute<F, R>(&self, f: F) -> Waiter<R>
51    where
52        F: FnOnce() -> R + Send + 'static,
53        R: Send + 'static,
54    {
55        let (sender, receiver) = bounded(1);
56
57        THREAD_POOL.spawn(move || {
58            let res = f();
59            // Best effort. We run it in a separate thread, so the receiver might not exist
60            // anymore, but that's OK. It only means that we are not interested in the result.
61            // A message is logged though, as concurrency issues are hard to debug and this might
62            // help in such cases.
63            if let Err(SendError(_)) = sender.send(res) {
64                trace!("Cannot send result");
65            }
66        });
67
68        Waiter { receiver }
69    }
70
71    /// Executes a function and returns the result once it is finished.
72    ///
73    /// The function gets the [`yastl::Scope`] as well as the `chunk_size` as parameters. THe
74    /// `chunk_size` is number of elements per thread.
75    pub fn scope<'a, F, R>(&self, elements: usize, f: F) -> R
76    where
77        F: FnOnce(&yastl::Scope<'a>, usize) -> R,
78    {
79        let chunk_size = if elements < *NUM_THREADS {
80            1
81        } else {
82            elements / *NUM_THREADS
83        };
84
85        THREAD_POOL.scoped(|scope| f(scope, chunk_size))
86    }
87
88    /// Executes the passed in function, and returns the result once it is finished.
89    pub fn scoped<'a, F, R>(&self, f: F) -> R
90    where
91        F: FnOnce(&yastl::Scope<'a>) -> R,
92    {
93        let (sender, receiver) = bounded(1);
94        THREAD_POOL.scoped(|s| {
95            let res = f(s);
96            sender.send(res).unwrap();
97        });
98
99        receiver.recv().unwrap()
100    }
101}
102
103/// A future that is waiting for a result.
104pub struct Waiter<T> {
105    receiver: Receiver<T>,
106}
107
108impl<T> Waiter<T> {
109    /// Wait for the result.
110    pub fn wait(&self) -> T {
111        self.receiver.recv().unwrap()
112    }
113
114    /// One off sending.
115    pub fn done(val: T) -> Self {
116        let (sender, receiver) = bounded(1);
117        sender.send(val).unwrap();
118
119        Waiter { receiver }
120    }
121}
122
/// Returns `floor(log2(num))`, i.e. the largest `k` such that `2^k <= num`.
///
/// # Panics
///
/// Panics if `num` is zero, as the logarithm of zero is undefined.
fn log2_floor(num: usize) -> u32 {
    assert!(num > 0);

    // The index of the most significant set bit is exactly floor(log2(num)). Computing it from
    // `leading_zeros` replaces a shift loop whose `1 << (pow + 1)` overflowed for inputs at or
    // above `2^(usize::BITS - 1)`.
    usize::BITS - 1 - num.leading_zeros()
}
134
#[cfg(test)]
pub mod tests {
    use super::*;

    #[test]
    fn test_log2_floor() {
        // (input, expected floor(log2(input)))
        let cases: [(usize, u32); 7] = [(1, 0), (3, 1), (4, 2), (5, 2), (6, 2), (7, 2), (8, 3)];
        for &(input, expected) in cases.iter() {
            assert_eq!(log2_floor(input), expected);
        }
    }

    #[test]
    fn test_read_num_threads() {
        const VAR: &str = "EC_GPU_NUM_THREADS";

        let expected_default = num_cpus::get();
        temp_env::with_var(VAR, None::<&str>, || {
            let actual = read_num_threads();
            assert_eq!(
                actual, expected_default,
                "By default the number of threads matches the number of CPUs."
            );
        });

        temp_env::with_var(VAR, Some("1234"), || {
            let actual = read_num_threads();
            assert_eq!(
                actual, 1234,
                "Number of threads matches the environment variable."
            );
        });
    }
}
167        });
168    }
169}