async_cpupool/
lib.rs

1#![doc = include_str!("../README.md")]
2#![deny(missing_docs)]
3
4mod drop_notifier;
5mod executor;
6mod notify;
7mod queue;
8mod selector;
9mod spsc;
10mod sync;
11
12use std::{
13    future::Future,
14    num::{NonZeroU16, NonZeroUsize},
15    sync::{atomic::AtomicU64, Arc, Mutex},
16    thread::JoinHandle,
17    time::Instant,
18};
19
20use drop_notifier::{DropListener, DropNotifier};
21use executor::block_on;
22use queue::Queue;
23use selector::select;
24
#[cfg(any(loom, test))]
#[doc(hidden)]
pub mod tests {
    //! Internal items re-exported only for `loom` and `test` builds.

    #[doc(hidden)]
    pub mod queue {
        pub use crate::queue::Queue;
        pub use crate::queue::{bounded, queue_count};
    }

    #[doc(hidden)]
    pub mod notify {
        pub use crate::notify::notify_count;
        pub use crate::notify::{Listener, Notify};
    }
}
38
39/// Configuration builder for the CpuPool
40#[derive(Debug)]
41pub struct Config {
42    name: &'static str,
43    buffer_multiplier: usize,
44    min_threads: u16,
45    max_threads: u16,
46}
47
48impl Config {
49    /// Create a new configuration builder with the default configuration
50    pub fn new() -> Self {
51        Config {
52            name: "cpupool",
53            buffer_multiplier: 8,
54            min_threads: 1,
55            max_threads: 4,
56        }
57    }
58
59    /// Set the name for the CpuPool
60    ///
61    /// This is used for setting the names of spawned threads
62    ///
63    /// default: `"cpupool"`
64    ///
65    /// Example:
66    /// ```rust
67    /// # use async_cpupool::Config;
68    /// Config::new().name("sig-pool");
69    /// ```
70    pub fn name(mut self, name: &'static str) -> Self {
71        self.name = name;
72        self
73    }
74
75    /// Set the multiplier for the internal queue's buffer size
76    ///
77    /// This value must be at least 1. the buffer's size will be equal to `max_threads * buffer_multiplier`
78    ///
79    /// default: `8`
80    ///
81    /// Example:
82    /// ```rust
83    /// # use async_cpupool::Config;
84    /// Config::new().buffer_multiplier(4);
85    /// ```
86    pub fn buffer_multiplier(mut self, buffer_multiplier: usize) -> Self {
87        self.buffer_multiplier = buffer_multiplier;
88        self
89    }
90
91    /// Set the minimum allowed number of running threads
92    ///
93    /// When there is little work to do, threads will be reaped until just this number remain
94    ///
95    /// default: `1`
96    ///
97    /// Example:
98    /// ```rust
99    /// # use async_cpupool::Config;
100    /// Config::new().min_threads(2);
101    /// ```
102    pub fn min_threads(mut self, min_threads: u16) -> Self {
103        self.min_threads = min_threads;
104        self
105    }
106
107    /// Set the maximum allowed number of running threads
108    ///
109    /// When the threadpool is under load, threads will be spawned until this limit is reached
110    ///
111    /// default: `4`
112    ///
113    /// Example:
114    /// ```rust
115    /// # use async_cpupool::Config;
116    /// Config::new().max_threads(16);
117    /// ```
118    pub fn max_threads(mut self, max_threads: u16) -> Self {
119        self.max_threads = max_threads;
120        self
121    }
122
123    /// Create a CpuPool with the given configuration, spawning `min_threads` threads
124    ///
125    /// This will error if `min_threads` is greater than `max_threads`, or if `buffer_multiplier`,
126    /// `max_threads`, or `min_threads` are `0`
127    ///
128    /// Example:
129    /// ```rust
130    /// # use async_cpupool::Config;
131    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
132    /// let pool = Config::new()
133    ///     .name("sig-pool")
134    ///     .min_threads(4)
135    ///     .max_threads(16)
136    ///     .buffer_multiplier(2)
137    ///     .build()?;
138    /// # Ok(())
139    /// # }
140    /// ```
141    pub fn build(self) -> Result<CpuPool, ConfigError> {
142        let Config {
143            name,
144            buffer_multiplier,
145            min_threads,
146            max_threads,
147        } = self;
148
149        if max_threads < min_threads {
150            return Err(ConfigError::ThreadCount);
151        }
152
153        let buffer_multiplier = buffer_multiplier
154            .try_into()
155            .map_err(|_| ConfigError::BufferMultiplier)?;
156
157        let max_threads = max_threads
158            .try_into()
159            .map_err(|_| ConfigError::MaxThreads)?;
160
161        let min_threads = min_threads
162            .try_into()
163            .map_err(|_| ConfigError::MinThreads)?;
164
165        Ok(CpuPool {
166            state: Arc::new(CpuPoolState::new(
167                name,
168                buffer_multiplier,
169                min_threads,
170                max_threads,
171            )),
172        })
173    }
174}
175
176impl Default for Config {
177    fn default() -> Self {
178        Self::new()
179    }
180}
181
/// Errors created by invalid configuration of the CpuPool
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConfigError {
    /// The configured maximum threads value is lower than the configured minimum threads value
    ThreadCount,

    /// The buffer_multiplier is 0
    BufferMultiplier,

    /// The max_threads value is 0
    MaxThreads,

    /// The min_threads value is 0
    MinThreads,
}

impl std::fmt::Display for ConfigError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::ThreadCount => "min_threads cannot be higher than max_threads",
            Self::BufferMultiplier => "buffer_multiplier cannot be zero",
            Self::MaxThreads => "max_threads cannot be zero",
            Self::MinThreads => "min_threads cannot be zero",
        };

        f.write_str(message)
    }
}

impl std::error::Error for ConfigError {}
210
/// The blocking operation was canceled due to a panic
///
/// This is the error half of the result produced by [`CpuPool::spawn`]'s future.
#[derive(Debug)]
pub struct Canceled;

impl std::error::Error for Canceled {}

impl std::fmt::Display for Canceled {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Blocking operation has panicked")
    }
}
222
223/// The CPUPool handle
224#[derive(Clone, Debug)]
225pub struct CpuPool {
226    state: Arc<CpuPoolState>,
227}
228
229impl CpuPool {
230    /// Create a new CpuPool with the default configuration
231    ///
232    /// Example:
233    /// ```rust
234    /// # use async_cpupool::CpuPool;
235    /// let pool = CpuPool::new();
236    /// ```
237    pub fn new() -> Self {
238        Config::default().build().expect("Defaults are valid")
239    }
240
241    /// Create a configuration builder to customize the CpuPool
242    ///
243    /// Example:
244    /// ```rust
245    /// # use async_cpupool::CpuPool;
246    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
247    /// let pool = CpuPool::configure().build()?;
248    /// # Ok(())
249    /// # }
250    /// ```
251    pub fn configure() -> Config {
252        Config::default()
253    }
254
255    /// Spawn a blocking operation on the CpuPool
256    ///
257    /// Example:
258    /// ```rust
259    /// # use async_cpupool::CpuPool;
260    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
261    /// # smol::block_on(async {
262    /// let pool = CpuPool::new();
263    ///
264    /// pool.spawn(|| std::thread::sleep(std::time::Duration::from_secs(3))).await?;
265    /// # Ok(())
266    /// # })
267    /// # }
268    /// ```
269    pub fn spawn<F, T>(&self, send_fn: F) -> impl Future<Output = Result<T, Canceled>> + '_
270    where
271        F: FnOnce() -> T + Send + 'static,
272        T: Send + 'static,
273    {
274        let (response_tx, response_rx) = spsc::channel();
275
276        let send_fn = Box::new(move || {
277            let output = (send_fn)();
278
279            match response_tx.blocking_send(output) {
280                Ok(()) => (), // sent
281                Err(Canceled) => tracing::warn!("receiver hung up"),
282            }
283        });
284
285        let opt = self.state.queue.try_push(send_fn);
286
287        let current_threads = self
288            .state
289            .current_threads
290            .load(std::sync::atomic::Ordering::Acquire);
291
292        let pushed = match self.state.queue.is_full_or() {
293            Ok(()) => self.push_thread(),
294            Err(len) if len > current_threads as usize => self.push_thread(),
295            Err(_) => false,
296        };
297
298        if pushed {
299            tracing::trace!("Pushed thread");
300        }
301
302        async {
303            if let Some(item) = opt {
304                self.state.queue.push(item).await;
305            }
306
307            let current_threads = self
308                .state
309                .current_threads
310                .load(std::sync::atomic::Ordering::Acquire);
311
312            match self.state.queue.is_full_or() {
313                Ok(()) => {
314                    self.push_thread();
315                }
316                Err(len) if len > current_threads as usize => {
317                    self.push_thread();
318                }
319                Err(len) if len < current_threads.ilog2() as usize => {
320                    if let Some(thread) = self.pop_thread() {
321                        thread.reap().await;
322                    }
323                }
324                Err(_) => {}
325            }
326
327            response_rx.recv().await
328        }
329    }
330
331    /// Attempt to close the CpuPool
332    ///
333    /// This operation returns `true` when the pool was succesfully closed, or `false` if there
334    /// exist other references to the pool, preventing closure.
335    ///
336    /// It is not required to call close to close a CpuPool. CpuPools will automatically close
337    /// themselves when all clones are dropped. This is simply a method to integrate better with
338    /// async runtimes.
339    /// Example:
340    /// ```rust
341    /// # use async_cpupool::CpuPool;
342    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
343    /// # smol::block_on(async {
344    /// let pool = CpuPool::new();
345    ///
346    /// let closed = pool.close().await;
347    /// assert!(closed);
348    /// # Ok(())
349    /// # })
350    /// # }
351    /// ```
352    pub async fn close(self) -> bool {
353        let Some(mut state) = Arc::into_inner(self.state) else {
354            return false;
355        };
356
357        let mut threads = state.take_threads();
358
359        for thread in &mut threads {
360            thread.signal.take();
361        }
362
363        for mut thread in threads {
364            thread.closed.listen().await;
365
366            if let Some(handle) = thread.handle.take() {
367                handle.join().expect("Thread panicked");
368            }
369        }
370
371        true
372    }
373
374    fn push_thread(&self) -> bool {
375        let current_threads = self
376            .state
377            .current_threads
378            .load(std::sync::atomic::Ordering::Acquire);
379
380        if current_threads >= u64::from(u16::from(self.state.max_threads)) {
381            tracing::trace!("At thread maximum");
382
383            return false;
384        }
385
386        if self
387            .state
388            .current_threads
389            .compare_exchange(
390                current_threads,
391                current_threads + 1,
392                std::sync::atomic::Ordering::AcqRel,
393                std::sync::atomic::Ordering::Relaxed,
394            )
395            .is_err()
396        {
397            tracing::trace!("Didn't acquire spawn authorization");
398
399            return false;
400        }
401
402        // we updated the count, so we have authorization to spawn a new thread
403
404        let thread_id = self
405            .state
406            .thread_id
407            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
408
409        let thread = spawn(self.state.name, thread_id, self.state.queue.clone());
410
411        self.state
412            .threads
413            .lock()
414            .expect("threads lock poison")
415            .push(thread);
416
417        true
418    }
419
420    fn pop_thread(&self) -> Option<Thread> {
421        let current_threads = self
422            .state
423            .current_threads
424            .load(std::sync::atomic::Ordering::Acquire);
425
426        if current_threads <= u64::from(u16::from(self.state.min_threads)) {
427            tracing::info!("At thread minimum");
428
429            return None;
430        }
431
432        if self
433            .state
434            .current_threads
435            .compare_exchange(
436                current_threads,
437                current_threads - 1,
438                std::sync::atomic::Ordering::AcqRel,
439                std::sync::atomic::Ordering::Relaxed,
440            )
441            .is_err()
442        {
443            tracing::trace!("Didn't acquire reap authorization");
444
445            return None;
446        }
447
448        // we updated the count, so we have authorization to reap a thread
449
450        self.state
451            .threads
452            .lock()
453            .expect("threads lock poison")
454            .pop()
455    }
456}
457
458impl Default for CpuPool {
459    fn default() -> Self {
460        Self::new()
461    }
462}
463
464type SendFn = Box<dyn FnOnce() + Send>;
465
466struct CpuPoolState {
467    name: &'static str,
468    min_threads: NonZeroU16,
469    max_threads: NonZeroU16,
470    current_threads: AtomicU64,
471    thread_id: AtomicU64,
472    queue: Queue<SendFn>,
473    threads: Mutex<ThreadVec>,
474}
475
476impl CpuPoolState {
477    fn new(
478        name: &'static str,
479        buffer_multiplier: NonZeroUsize,
480        min_threads: NonZeroU16,
481        max_threads: NonZeroU16,
482    ) -> Self {
483        let thread_capacity = usize::from(u16::from(max_threads));
484
485        let queue = queue::bounded(usize::from(buffer_multiplier).saturating_mul(thread_capacity));
486
487        let start_threads = u64::from(u16::from(min_threads));
488
489        let threads = ThreadVec::new(start_threads, thread_capacity, |i| {
490            spawn(name, i, queue.clone())
491        });
492
493        let current_threads = AtomicU64::new(start_threads);
494        let thread_id = AtomicU64::new(start_threads);
495
496        CpuPoolState {
497            name,
498            min_threads,
499            max_threads,
500            current_threads,
501            thread_id,
502            queue,
503            threads: Mutex::new(threads),
504        }
505    }
506
507    fn take_threads(&mut self) -> Vec<Thread> {
508        self.threads.lock().expect("threads lock poison").take()
509    }
510}
511
512impl std::fmt::Debug for CpuPoolState {
513    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
514        f.debug_struct("CpuPoolState")
515            .field("name", &self.name)
516            .field("min_threads", &self.min_threads)
517            .field("max_threads", &self.max_threads)
518            .finish()
519    }
520}
521
522struct ThreadVec {
523    threads: Vec<Thread>,
524}
525
526impl ThreadVec {
527    fn new<F>(start_threads: u64, max_threads: usize, spawn: F) -> Self
528    where
529        F: Fn(u64) -> Thread,
530    {
531        let mut threads = Vec::with_capacity(max_threads);
532
533        for i in 0..start_threads {
534            threads.push((spawn)(i));
535        }
536
537        Self { threads }
538    }
539
540    fn push(&mut self, thread: Thread) {
541        self.threads.push(thread);
542    }
543
544    fn pop(&mut self) -> Option<Thread> {
545        self.threads.pop()
546    }
547
548    fn take(&mut self) -> Vec<Thread> {
549        std::mem::take(&mut self.threads)
550    }
551}
552
553impl Drop for ThreadVec {
554    fn drop(&mut self) {
555        for thread in &mut self.threads {
556            thread.signal.take();
557        }
558
559        for thread in &mut self.threads {
560            if let Some(handle) = thread.handle.take() {
561                handle.join().expect("Thread panicked");
562            }
563        }
564    }
565}
566
567struct Thread {
568    handle: Option<JoinHandle<()>>,
569    signal: Option<DropNotifier>,
570    closed: DropListener,
571}
572
573impl Thread {
574    async fn reap(mut self) {
575        self.signal.take();
576
577        self.closed.listen().await;
578
579        if let Some(handle) = self.handle.take() {
580            handle.join().expect("Thread panicked");
581        }
582    }
583}
584
585fn spawn(name: &'static str, id: u64, receiver: Queue<SendFn>) -> Thread {
586    let (closed_notifier, closed_listener) = drop_notifier::notifier();
587    let (signal_notifier, signal_listener) = drop_notifier::notifier();
588
589    let handle = std::thread::Builder::new()
590        .name(format!("{name}-{id}"))
591        .spawn(move || run(name, id, receiver, signal_listener, closed_notifier))
592        .expect("Failed to spawn new thread");
593
594    Thread {
595        handle: Some(handle),
596        signal: Some(signal_notifier),
597        closed: closed_listener,
598    }
599}
600
601struct MetricsGuard {
602    name: &'static str,
603    id: u64,
604    start: Instant,
605    armed: bool,
606}
607
608impl MetricsGuard {
609    fn guard(name: &'static str, id: u64) -> Self {
610        tracing::trace!("Starting {name}-{id}");
611        metrics::counter!(format!("async-cpupool.{name}.thread.launched")).increment(1);
612
613        MetricsGuard {
614            name,
615            id,
616            start: Instant::now(),
617            armed: true,
618        }
619    }
620
621    fn disarm(mut self) {
622        self.armed = false;
623    }
624}
625
626impl Drop for MetricsGuard {
627    fn drop(&mut self) {
628        metrics::counter!(format!("async-cpupool.{}.thread.closed", self.name), "clean" => (!self.armed).to_string()).increment(1);
629        metrics::histogram!(format!("async-cpupool.{}.thread.seconds", self.name), "clean" => (!self.armed).to_string()).record(self.start.elapsed().as_secs_f64());
630        tracing::trace!("Stopping {}-{}", self.name, self.id);
631    }
632}
633
634fn run(
635    name: &'static str,
636    id: u64,
637    receiver: Queue<SendFn>,
638    signal: DropListener,
639    closed_tx: DropNotifier,
640) {
641    let guard = MetricsGuard::guard(name, id);
642
643    let mut signal = std::pin::pin!(signal.listen());
644
645    loop {
646        match block_on(select(&mut signal, receiver.pop())) {
647            selector::Either::Left(_) => break,
648            selector::Either::Right(send_fn) => invoke_send_fn(name, send_fn),
649        }
650    }
651
652    guard.disarm();
653
654    drop(closed_tx);
655}
656
657fn invoke_send_fn(name: &'static str, send_fn: SendFn) {
658    let start = Instant::now();
659    metrics::counter!(format!("async-cpupool.{name}.operation.start")).increment(1);
660
661    let res = std::panic::catch_unwind(std::panic::AssertUnwindSafe(move || {
662        (send_fn)();
663    }));
664
665    metrics::counter!(format!("async-cpupool.{name}.operation.end"), "complete" => res.is_ok().to_string()).increment(1);
666    metrics::histogram!(format!("async-cpupool.{name}.operation.seconds"), "complete" => res.is_ok().to_string()).record(start.elapsed().as_secs_f64());
667
668    if let Err(e) = res {
669        tracing::trace!("panic in spawned task: {e:?}");
670    }
671}