// lz4/threadpool.rs
//! Fixed-size work-stealing thread pool.
//!
//! Provides the same API as the `TPool` family of functions from the LZ4
//! reference implementation (`threadpool.c`), backed by `rayon::ThreadPool`
//! rather than pthreads / Windows IOCP.  Bounded-queue / blocking-submit
//! semantics are preserved via a `crossbeam_channel::bounded` semaphore channel.
//!

use crossbeam_channel::{bounded, Receiver, Sender};
use rayon::ThreadPool as RayonPool;
use std::sync::{Arc, Condvar, Mutex};

// ---------------------------------------------------------------------------
// Job type — mirrors `TPool_job` from the C source.
// ---------------------------------------------------------------------------
/// A unit of work: a boxed closure that has already captured its argument.
/// This is the Rust equivalent of the C `void (*fn)(void*)` + `void* arg`
/// pair; `Send + 'static` is required because the closure runs on a worker
/// thread that may outlive the submitting stack frame.
type JobFn = Box<dyn FnOnce() + Send + 'static>;

// ---------------------------------------------------------------------------
// Internal shared state that workers and submitters both access.
// ---------------------------------------------------------------------------
/// Mutex-protected bookkeeping shared between submitters and workers.
struct PoolState {
    /// Number of jobs submitted but not yet finished.  Incremented by
    /// `submit_job` before spawning; decremented by the worker after the job
    /// runs; `jobs_completed` blocks until it reaches zero.
    pending: usize,
}

25/// Thread pool handle — equivalent to `TPool*` in the C API.
26///
27/// `TPool_create` → `TPool::new`
28/// `TPool_free`   → `Drop for TPool`  (joins workers automatically)
29/// `TPool_submitJob` → `TPool::submit_job`
30/// `TPool_jobsCompleted` → `TPool::jobs_completed`
31pub struct TPool {
32    /// rayon thread pool that executes jobs.
33    pool: Arc<RayonPool>,
34    /// Bounded channel used as a semaphore: the sender slot limits how many
35    /// jobs can be in-flight simultaneously (queue_size + nb_threads slots).
36    /// Submitters acquire a slot before posting; workers release it on finish.
37    slot_tx: Sender<()>,
38    slot_rx: Receiver<()>,
39    /// Shared counter of pending jobs plus a condvar for `jobs_completed`.
40    state: Arc<(Mutex<PoolState>, Condvar)>,
41}
42
43impl TPool {
44    /// `TPool_create(nbThreads, queueSize)` — returns `None` on failure.
45    ///
46    /// *nbThreads* must be ≥ 1, *queueSize* must be ≥ 1.
47    /// The C code allocates one extra queue slot to distinguish full vs. empty;
48    /// here `crossbeam_channel::bounded(queue_size + nb_threads)` plays the
49    /// same role as the semaphore initialised to `queueSize + nbWorkers` in the
50    /// Windows implementation.
51    pub fn new(nb_threads: usize, queue_size: usize) -> Option<Self> {
52        if nb_threads < 1 || queue_size < 1 {
53            return None;
54        }
55        let pool = rayon::ThreadPoolBuilder::new()
56            .num_threads(nb_threads)
57            .build()
58            .ok()?;
59
60        // Total slots = queue_size + nb_threads mirrors the Windows semaphore.
61        let capacity = queue_size + nb_threads;
62        let (slot_tx, slot_rx) = bounded(capacity);
63        // Pre-fill the channel so that `slot_rx.recv()` acts as "wait for a
64        // free slot" (i.e. we send tokens to represent free slots).
65        for _ in 0..capacity {
66            slot_tx.send(()).ok()?;
67        }
68
69        let state = Arc::new((Mutex::new(PoolState { pending: 0 }), Condvar::new()));
70
71        Some(TPool {
72            pool: Arc::new(pool),
73            slot_tx,
74            slot_rx,
75            state,
76        })
77    }
78
79    /// `TPool_submitJob(ctx, job_function, arg)` — may block if queue is full.
80    ///
81    /// In C the caller passes a raw `void (*fn)(void*)` + `void* arg`.
82    /// In Rust the equivalent is a `Box<dyn FnOnce() + Send>` closure that
83    /// has already captured its argument, eliminating the `void*` anti-pattern.
84    pub fn submit_job(&self, job: JobFn) {
85        // Block until a slot is available (mirrors `WaitForSingleObject` on the
86        // semaphore in the Windows path, or `pthread_cond_wait` in POSIX path).
87        self.slot_rx.recv().expect("threadpool slot channel closed");
88
89        // Increment pending count before spawning so `jobs_completed` cannot
90        // observe zero between submit and actual execution start.
91        {
92            let (lock, _cvar) = &*self.state;
93            let mut s = lock.lock().unwrap();
94            s.pending += 1;
95        }
96
97        let state = Arc::clone(&self.state);
98        let slot_tx = self.slot_tx.clone();
99        self.pool.spawn(move || {
100            job();
101
102            // Release the slot and decrement pending count.
103            let (lock, cvar) = &*state;
104            let mut s = lock.lock().unwrap();
105            s.pending -= 1;
106            if s.pending == 0 {
107                cvar.notify_all();
108            }
109            // Return the semaphore token.
110            let _ = slot_tx.send(());
111        });
112    }
113
114    /// `TPool_jobsCompleted(ctx)` — blocks until all submitted jobs have finished.
115    ///
116    /// Does NOT shut down the pool; it can accept further jobs afterwards,
117    /// identical to the C semantics.
118    pub fn jobs_completed(&self) {
119        let (lock, cvar) = &*self.state;
120        let mut s = lock.lock().unwrap();
121        while s.pending > 0 {
122            s = cvar.wait(s).unwrap();
123        }
124    }
125}
126
127impl Drop for TPool {
128    /// `TPool_free` — waits for all running jobs to finish then tears down the
129    /// rayon pool.  rayon's `ThreadPool` already joins workers on drop, so we
130    /// only need to ensure no jobs are still in-flight first.
131    fn drop(&mut self) {
132        self.jobs_completed();
133        // rayon::ThreadPool::drop joins all worker threads automatically.
134    }
135}