1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
#![warn(missing_docs)]
//! Entangled provides a thread pool based on the `async-executor` crate to spawn async futures on.
//!
//! Its main selling point is the "scoped spawn" functionality: futures spawned from the calling
//! thread may borrow from that thread's stack, and the calling thread joins them before the scope
//! returns.
//!
//! # Example
//!
//! ```
//! use entangled::*;
//! use std::sync::atomic::*;
//!
//! let pool = ThreadPool::new(
//!     ThreadPoolDescriptor::default()
//! ).expect("can't create task pool");
//!
//! let counter = AtomicI32::new(0);
//! let ref_counter = &counter;
//!
//! pool.scope(|scope| {
//!     for _ in 0..10 {
//!         scope.spawn(async {
//!             ref_counter.fetch_add(1, Ordering::Relaxed);
//!         });
//!     }
//! });
//!
//! assert_eq!(counter.load(Ordering::Relaxed), 10);
//! ```

use std::sync::Arc;

#[doc(no_inline)]
pub use async_executor::Task;

/// Describes how a `ThreadPool` should be created.
///
/// All fields are public; use `..Default::default()` to override only some of them.
pub struct ThreadPoolDescriptor {
    /// Spawns at most n threads for the thread pool. Default: 2.
    pub num_threads: usize,

    /// The stack size of the spawned threads. Default: 2 MiB.
    pub stack_size: usize,

    /// Name of the threads. Threads will be named:
    /// {thread_name} ({thread index}), e.g. "Thread pool (0)"
    /// Default: "Thread pool"
    pub thread_name: String,

    /// Closure invoked on worker thread start. Closure parameter contains the index of the created thread.
    /// Default: None
    pub start_handler: Option<Box<dyn Fn(usize) + Send + Sync>>,

    /// Closure invoked on worker thread exit. Closure parameter contains the index of the created thread.
    /// Default: None
    pub exit_handler: Option<Box<dyn Fn(usize) + Send + Sync>>,
}

// `Debug` cannot be derived because the boxed handler closures are not `Debug`; the manual
// impl prints the plain fields and only whether each handler is set.
impl std::fmt::Debug for ThreadPoolDescriptor {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ThreadPoolDescriptor")
            .field("num_threads", &self.num_threads)
            .field("stack_size", &self.stack_size)
            .field("thread_name", &self.thread_name)
            .field(
                "start_handler",
                &self.start_handler.as_ref().map(|_| "Fn(usize)"),
            )
            .field(
                "exit_handler",
                &self.exit_handler.as_ref().map(|_| "Fn(usize)"),
            )
            .finish()
    }
}

impl Default for ThreadPoolDescriptor {
    fn default() -> Self {
        Self {
            num_threads: 2,
            stack_size: 2 * 1024 * 1024,
            thread_name: "Thread pool".to_owned(),
            start_handler: None,
            exit_handler: None,
        }
    }
}

/// Owns the worker threads of a `ThreadPool`.
///
/// The pool itself is cloneable and holds this value behind an `Arc`, so the threads are
/// shut down and joined once, when this value is dropped.
#[derive(Debug)]
struct ThreadPoolInner {
    /// Join handles of the worker threads; joined when this value is dropped.
    threads: Vec<std::thread::JoinHandle<()>>,
    /// Sender half of the shutdown channel. Nothing is ever sent on it; closing it is the
    /// signal that makes the workers' pending `recv` return an error.
    shutdown_tx: async_channel::Sender<()>,
}

impl Drop for ThreadPoolInner {
    /// Signals shutdown to the worker threads and waits for each of them to finish.
    fn drop(&mut self) {
        // Closing the channel is the shutdown signal the workers are waiting on.
        self.shutdown_tx.close();

        let workers = std::mem::take(&mut self.threads);
        for worker in workers {
            let outcome = worker.join();

            // Surface a worker panic, unless this thread is already unwinding (a second
            // panic during unwinding would abort the process).
            if std::thread::panicking() {
                continue;
            }
            outcome.expect("the task thread panicked while executing");
        }
    }
}

/// A thread pool for executing futures.
///
/// Drives given futures to completion.
///
/// Both fields are `Arc`s, so a clone shares the same executor and the same worker threads
/// as the pool it was cloned from.
#[derive(Debug, Clone)]
pub struct ThreadPool {
    /// Executor that futures are spawned onto; the worker threads run it.
    executor: Arc<async_executor::Executor<'static>>,
    /// Shared ownership of the worker threads and their shutdown channel.
    inner: Arc<ThreadPoolInner>,
}

impl ThreadPool {
    /// Create a new `ThreadPool`. Thread pools can be freely cloned.
    ///
    /// Spawns `descriptor.num_threads` worker threads. Each worker calls the optional
    /// `start_handler`, runs the shared executor until the pool is shut down, and then calls
    /// the optional `exit_handler`.
    ///
    /// # Errors
    ///
    /// Returns the `std::io::Error` from `std::thread::Builder::spawn` if a worker thread
    /// cannot be created.
    ///
    /// # How to provide a custom handler
    ///
    /// ```rust
    /// use entangled::*;
    ///
    /// let descriptor = ThreadPoolDescriptor {
    ///     num_threads: 0,
    ///     start_handler: Some(Box::new(|index| {
    ///         println!("Thread {} is starting", index);
    ///     })),
    ///     ..Default::default()
    /// };
    ///
    /// let pool = ThreadPool::new(descriptor).unwrap();
    /// ```
    ///
    pub fn new(descriptor: ThreadPoolDescriptor) -> Result<Self, std::io::Error> {
        // Nothing is ever sent on this channel; closing the sender is the shutdown signal.
        let (shutdown_tx, shutdown_rx) = async_channel::unbounded::<()>();

        let executor = Arc::new(async_executor::Executor::new());
        let mut threads = Vec::with_capacity(descriptor.num_threads);

        // Shared so that every worker can reach the (non-cloneable) handlers.
        let descriptor = Arc::new(descriptor);

        for i in 0..descriptor.num_threads {
            let thread_descriptor = Arc::clone(&descriptor);
            let thread_executor = Arc::clone(&executor);
            let thread_shutdown_rx = shutdown_rx.clone();

            let thread = std::thread::Builder::new()
                .name(format!("{} ({})", descriptor.thread_name, i))
                .stack_size(descriptor.stack_size)
                .spawn(move || {
                    if let Some(start_handler) = &thread_descriptor.start_handler {
                        start_handler(i);
                    }

                    // Run the executor on this thread until the shutdown channel is closed.
                    // `run` only builds a future, so it has to be driven *before* the exit
                    // handler is called; otherwise the handler would fire at thread start.
                    let shutdown_result = futures_lite::future::block_on(
                        thread_executor.run(thread_shutdown_rx.recv()),
                    );

                    if let Some(exit_handler) = &thread_descriptor.exit_handler {
                        exit_handler(i);
                    }

                    // Nothing is ever sent, so `recv` can only finish with an
                    // `async_channel::RecvError` once the channel has been closed.
                    shutdown_result.unwrap_err();
                })?;

            threads.push(thread);
        }

        Ok(Self {
            executor,
            inner: Arc::new(ThreadPoolInner {
                threads,
                shutdown_tx,
            }),
        })
    }

    /// Creates a "fork-join" scope `s` and invokes the closure with a reference to `s`.
    /// This closure can then spawn futures into `s`. When the closure returns, it will block
    /// until all futures that have been spawned into `s` complete.
    ///
    /// In general, spawned futures may access stack data in place that outlives the scope itself.
    /// Other data must be fully owned by the spawned future.
    ///
    /// Returns the outputs of the spawned futures in the order in which they were spawned.
    /// The calling thread helps to run the executor while it waits, so this also completes on
    /// a pool that has no worker threads.
    pub fn scope<'scope, S, R>(&self, s: S) -> Vec<R>
    where
        S: FnOnce(&mut Scope<'scope, R>) + 'scope + Send,
        R: Send + 'static,
    {
        // SAFETY: this re-labels the executor's lifetime as `'scope` so that futures borrowing
        // from the caller's stack can be spawned onto it. It relies on every spawned task
        // being driven to completion below, before this function returns.
        // NOTE(review): if `s` or a task panics, the remaining tasks are dropped rather than
        // awaited — confirm that a worker cannot still be polling them at that point.
        let executor = &*self.executor;
        let executor: &'scope async_executor::Executor = unsafe { std::mem::transmute(executor) };

        let mut scope = Scope {
            executor,
            spawned_tasks: Vec::new(),
        };

        // Let the caller spawn its futures; the handles are collected in `scope`.
        s(&mut scope);

        let spawned_tasks = scope.spawned_tasks;
        if spawned_tasks.is_empty() {
            // Nothing to do.
            return Vec::new();
        }

        // Await the tasks in spawn order so that the outputs line up with the `spawn` calls.
        let join_all = async move {
            let mut results = Vec::with_capacity(spawned_tasks.len());
            for task in spawned_tasks {
                results.push(task.await);
            }
            results
        };

        // Drive the executor on the calling thread until every task has finished. Unlike a
        // bare `block_on` on a task, this makes progress even without worker threads, and
        // unlike a `poll_once`/`try_tick` loop it parks instead of spinning when idle.
        futures_lite::future::block_on(self.executor.run(join_all))
    }

    /// Spawns a static future onto the thread pool. The returned `Task` is a future. It can also be
    /// cancelled and "detached" allowing it to continue running without having to be polled by the
    /// end-user.
    pub fn spawn<T>(
        &self,
        future: impl futures_lite::Future<Output = T> + Send + 'static,
    ) -> async_executor::Task<T>
    where
        T: Send + 'static,
    {
        self.executor.spawn(future)
    }
}

/// Scopes the execution of futures.
///
/// A `&mut Scope` is what the closure given to `ThreadPool::scope` receives. Every future
/// spawned through it is recorded here so that `ThreadPool::scope` can wait for it and
/// collect its output.
#[derive(Debug)]
pub struct Scope<'scope, R> {
    /// Executor the scoped futures are spawned onto.
    executor: &'scope async_executor::Executor<'scope>,
    /// Handles of the tasks spawned so far, in spawn order.
    spawned_tasks: Vec<async_executor::Task<R>>,
}

impl<'scope, T: Send + 'scope> Scope<'scope, T> {
    /// Spawns the future `f` onto the pool as part of this scope.
    ///
    /// This mirrors the standard library's scoped spawning, but for futures: the future is
    /// guaranteed to have finished before the enclosing stack frame goes away, so it may
    /// borrow from that frame directly.
    ///
    /// The guarantee comes from the parent thread joining all child futures before the scope
    /// exits; the task handle is stored here for exactly that purpose.
    pub fn spawn<Fut>(&mut self, f: Fut)
    where
        Fut: futures_lite::Future<Output = T> + 'scope + Send,
    {
        self.spawned_tasks.push(self.executor.spawn(f));
    }
}

#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicI32, Ordering};

    use super::*;

    /// Scoped futures may borrow from the test's stack, and `scope` returns one output per
    /// spawned future.
    #[test]
    pub fn test_scoped_spawn() {
        let pool = ThreadPool::new(ThreadPoolDescriptor::default()).unwrap();

        let boxed = Box::new(100);
        let boxed_ref = &*boxed;

        let counter = Arc::new(AtomicI32::new(0));

        let outputs = pool.scope(|scope| {
            for _ in 0..100 {
                let count_clone = counter.clone();
                scope.spawn(async move {
                    if *boxed_ref != 100 {
                        panic!("expected 100")
                    } else {
                        count_clone.fetch_add(1, Ordering::Relaxed);
                        *boxed_ref
                    }
                });
            }
        });

        for output in &outputs {
            assert_eq!(*output, 100);
        }

        assert_eq!(outputs.len(), 100);
        assert_eq!(counter.load(Ordering::Relaxed), 100);
    }

    /// The start and exit handlers are each invoked once per worker thread.
    #[test]
    pub fn test_custom_handler() {
        let start_counter = Arc::new(AtomicI32::new(0));
        let thread_start_counter = start_counter.clone();

        let exit_counter = Arc::new(AtomicI32::new(0));
        let thread_exit_counter = exit_counter.clone();

        let pool = ThreadPool::new(ThreadPoolDescriptor {
            num_threads: 5,
            start_handler: Some(Box::new(move |_| {
                thread_start_counter.fetch_add(1, Ordering::SeqCst);
            })),
            exit_handler: Some(Box::new(move |_| {
                thread_exit_counter.fetch_add(1, Ordering::SeqCst);
            })),
            ..Default::default()
        })
        .unwrap();

        // Dropping the only handle closes the shutdown channel and joins every worker, so
        // all handlers have run once `drop` returns; no timed sleep is needed.
        drop(pool);

        assert_eq!(start_counter.load(Ordering::SeqCst), 5);
        assert_eq!(exit_counter.load(Ordering::SeqCst), 5);
    }

    /// A `'static` future spawned on the pool yields its output through the returned task.
    #[test]
    pub fn test_task_spawn() {
        let pool = ThreadPool::new(ThreadPoolDescriptor::default()).unwrap();

        let task = pool.spawn(async { 42 });

        assert_eq!(futures_lite::future::block_on(task), 42);
    }
}