1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
use crate::*;

#[cfg(not(target_arch = "wasm32"))]
use std::thread;
#[cfg(target_arch = "wasm32")]
use wasm_thread as thread;

/// A Geese threadpool that runs work both on the main thread and on a set of background
/// threads that it owns. Dropping the pool signals those background threads to stop.
#[derive(Debug)]
pub struct HardwareThreadPool {
    /// Synchronization state shared with every background worker thread.
    inner: Arc<HardwareThreadPoolInner>,
}

impl HardwareThreadPool {
    /// Builds a threadpool backed by `background_threads` dedicated worker threads.
    ///
    /// Passing `0` spawns no workers, so the pool behaves as a single-threaded
    /// (main thread only) threadpool.
    #[inline(always)]
    pub fn new(background_threads: usize) -> Self {
        // The returned value is the only handle, so the live-handle count starts at one.
        let shared = Arc::new(HardwareThreadPoolInner {
            handle_count: AtomicUsize::new(1),
            ..Default::default()
        });
        Self::spawn_workers(&shared, background_threads);
        Self { inner: shared }
    }

    /// Launches `background_threads` threads. Each one loops on the shared state's `run`
    /// until the handle count reaches zero. Their join handles are not kept.
    #[inline(always)]
    fn spawn_workers(inner: &Arc<HardwareThreadPoolInner>, background_threads: usize) {
        (0..background_threads).for_each(|_| {
            let worker_state = Arc::clone(inner);
            thread::spawn(move || worker_state.run());
        });
    }
}

impl Default for HardwareThreadPool {
    #[inline(always)]
    fn default() -> Self {
        Self::new(0)
    }
}

impl Drop for HardwareThreadPool {
    #[inline(always)]
    fn drop(&mut self) {
        self.inner.decrement_counter();
    }
}

impl GeeseThreadPool for HardwareThreadPool {
    fn set_callback(&self, callback: Option<Arc<dyn Fn() + Send + Sync>>) {
        self.inner.set_callback(callback);
    }
}

/// Synchronization state shared between a threadpool handle and its worker threads.
#[derive(Default)]
struct HardwareThreadPoolInner {
    /// Work callback that idle threads invoke, or `None` when no work is available.
    callback: wasm_sync::Mutex<Option<Arc<dyn Fn() + Send + Sync>>>,
    /// Count of live threadpool handles; worker loops stop once it reaches zero.
    handle_count: AtomicUsize,
    /// Signaled whenever `callback` or `handle_count` changes.
    on_changed: wasm_sync::Condvar,
}

impl HardwareThreadPoolInner {
    /// Joins this threadpool once.
    ///
    /// If a work callback is installed, it is invoked (outside the lock). Otherwise the caller
    /// blocks until the callback or handle count changes. If no callback is installed and no
    /// handles remain, this returns immediately instead of waiting, so a worker can never
    /// park after the last handle has been dropped.
    ///
    /// # Panics
    ///
    /// Panics if the callback lock is poisoned, or if the callback itself panics.
    #[inline(always)]
    pub fn join(&self) {
        let guard = self
            .callback
            .lock()
            .expect("Could not acquire callback lock.");
        if let Some(callback) = &*guard {
            // Release the lock before running the work so other threads can join concurrently.
            let to_run = callback.clone();
            drop(guard);
            to_run();
        } else if self.handle_count.load(Ordering::Acquire) > 0 {
            // Re-check the handle count while holding the lock. `decrement_counter` updates it
            // under this same lock, so either we observe zero here, or we are already parked on
            // the condvar (which releases the lock atomically) when its `notify_all` fires.
            // Without this check, a decrement that lands between `run`'s unlocked check and
            // this wait would be missed and the worker would sleep forever.
            drop(self.on_changed.wait(guard));
        }
    }

    /// Dedicates the calling thread to this threadpool, repeatedly joining with the pool
    /// until the handle count reaches zero.
    #[inline(always)]
    pub fn run(&self) {
        while self.handle_count.load(Ordering::Acquire) > 0 {
            self.join();
        }
    }

    /// Installs (or, with `None`, clears) the work callback and notifies all waiting threads.
    ///
    /// # Panics
    ///
    /// Panics if the callback lock is poisoned.
    #[inline(always)]
    pub fn set_callback(&self, callback: Option<Arc<dyn Fn() + Send + Sync>>) {
        *self
            .callback
            .lock()
            .expect("Could not acquire callback lock.") = callback;
        self.on_changed.notify_all();
    }

    /// Decrements the handle counter and notifies all waiting threads.
    ///
    /// The decrement is performed while holding the callback lock so that it cannot slip in
    /// between a worker's handle-count check and its condvar wait (see `join`).
    #[inline(always)]
    pub fn decrement_counter(&self) {
        {
            // Bind the whole `LockResult`: it owns the guard in both the `Ok` and poisoned
            // cases, so the lock is held either way and this path (run from `Drop`) never
            // panics.
            let _guard = self.callback.lock();
            self.handle_count.fetch_sub(1, Ordering::Release);
        }
        self.on_changed.notify_all();
    }
}

impl std::fmt::Debug for HardwareThreadPoolInner {
    /// Formats the handle count and condition variable. The callback is omitted because
    /// `dyn Fn` has no `Debug` representation.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("HardwareThreadPoolInner");
        builder.field("handle_count", &self.handle_count);
        builder.field("on_changed", &self.on_changed);
        builder.finish()
    }
}