1#![doc = include_str!("../README.md")]
2#![deny(missing_docs)]
3
4mod drop_notifier;
5mod executor;
6mod notify;
7mod queue;
8mod selector;
9mod spsc;
10mod sync;
11
12use std::{
13 future::Future,
14 num::{NonZeroU16, NonZeroUsize},
15 sync::{atomic::AtomicU64, Arc, Mutex},
16 thread::JoinHandle,
17 time::Instant,
18};
19
20use drop_notifier::{DropListener, DropNotifier};
21use executor::block_on;
22use queue::Queue;
23use selector::select;
24
/// Re-exports of otherwise private crate internals.
///
/// Compiled only under `cfg(loom)` or `cfg(test)` and hidden from rustdoc.
#[cfg(any(loom, test))]
#[doc(hidden)]
pub mod tests {
    /// Items re-exported from the private `queue` module.
    #[doc(hidden)]
    pub mod queue {
        pub use crate::queue::{bounded, queue_count, Queue};
    }

    /// Items re-exported from the private `notify` module.
    #[doc(hidden)]
    pub mod notify {
        pub use crate::notify::{notify_count, Listener, Notify};
    }
}
38
39#[derive(Debug)]
41pub struct Config {
42 name: &'static str,
43 buffer_multiplier: usize,
44 min_threads: u16,
45 max_threads: u16,
46}
47
48impl Config {
49 pub fn new() -> Self {
51 Config {
52 name: "cpupool",
53 buffer_multiplier: 8,
54 min_threads: 1,
55 max_threads: 4,
56 }
57 }
58
59 pub fn name(mut self, name: &'static str) -> Self {
71 self.name = name;
72 self
73 }
74
75 pub fn buffer_multiplier(mut self, buffer_multiplier: usize) -> Self {
87 self.buffer_multiplier = buffer_multiplier;
88 self
89 }
90
91 pub fn min_threads(mut self, min_threads: u16) -> Self {
103 self.min_threads = min_threads;
104 self
105 }
106
107 pub fn max_threads(mut self, max_threads: u16) -> Self {
119 self.max_threads = max_threads;
120 self
121 }
122
123 pub fn build(self) -> Result<CpuPool, ConfigError> {
142 let Config {
143 name,
144 buffer_multiplier,
145 min_threads,
146 max_threads,
147 } = self;
148
149 if max_threads < min_threads {
150 return Err(ConfigError::ThreadCount);
151 }
152
153 let buffer_multiplier = buffer_multiplier
154 .try_into()
155 .map_err(|_| ConfigError::BufferMultiplier)?;
156
157 let max_threads = max_threads
158 .try_into()
159 .map_err(|_| ConfigError::MaxThreads)?;
160
161 let min_threads = min_threads
162 .try_into()
163 .map_err(|_| ConfigError::MinThreads)?;
164
165 Ok(CpuPool {
166 state: Arc::new(CpuPoolState::new(
167 name,
168 buffer_multiplier,
169 min_threads,
170 max_threads,
171 )),
172 })
173 }
174}
175
176impl Default for Config {
177 fn default() -> Self {
178 Self::new()
179 }
180}
181
/// Errors returned by [`Config::build`] when the configuration is invalid.
///
/// The type is a plain fieldless enum, so it is cheap to copy and can be compared
/// directly (for example with `assert_eq!`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ConfigError {
    /// `min_threads` was set higher than `max_threads`.
    ThreadCount,

    /// `buffer_multiplier` was set to zero.
    BufferMultiplier,

    /// `max_threads` was set to zero.
    MaxThreads,

    /// `min_threads` was set to zero.
    MinThreads,
}

impl std::fmt::Display for ConfigError {
    /// Write a short description of which setting was rejected.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::ThreadCount => "min_threads cannot be higher than max_threads",
            Self::BufferMultiplier => "buffer_multiplier cannot be zero",
            Self::MaxThreads => "max_threads cannot be zero",
            Self::MinThreads => "min_threads cannot be zero",
        };

        f.write_str(message)
    }
}

impl std::error::Error for ConfigError {}
210
/// Error produced by the future returned from [`CpuPool::spawn`] when no result is
/// received from the spawned operation.
///
/// The type is a unit struct, so it is cheap to copy and can be compared directly,
/// which lets a `Result<T, Canceled>` be used with `assert_eq!`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Canceled;

impl std::fmt::Display for Canceled {
    /// Write the fixed message reporting that the blocking operation panicked.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Blocking operation has panicked")
    }
}

impl std::error::Error for Canceled {}
222
223#[derive(Clone, Debug)]
225pub struct CpuPool {
226 state: Arc<CpuPoolState>,
227}
228
229impl CpuPool {
230 pub fn new() -> Self {
238 Config::default().build().expect("Defaults are valid")
239 }
240
241 pub fn configure() -> Config {
252 Config::default()
253 }
254
255 pub fn spawn<F, T>(&self, send_fn: F) -> impl Future<Output = Result<T, Canceled>> + '_
270 where
271 F: FnOnce() -> T + Send + 'static,
272 T: Send + 'static,
273 {
274 let (response_tx, response_rx) = spsc::channel();
275
276 let send_fn = Box::new(move || {
277 let output = (send_fn)();
278
279 match response_tx.blocking_send(output) {
280 Ok(()) => (), Err(Canceled) => tracing::warn!("receiver hung up"),
282 }
283 });
284
285 let opt = self.state.queue.try_push(send_fn);
286
287 let current_threads = self
288 .state
289 .current_threads
290 .load(std::sync::atomic::Ordering::Acquire);
291
292 let pushed = match self.state.queue.is_full_or() {
293 Ok(()) => self.push_thread(),
294 Err(len) if len > current_threads as usize => self.push_thread(),
295 Err(_) => false,
296 };
297
298 if pushed {
299 tracing::trace!("Pushed thread");
300 }
301
302 async {
303 if let Some(item) = opt {
304 self.state.queue.push(item).await;
305 }
306
307 let current_threads = self
308 .state
309 .current_threads
310 .load(std::sync::atomic::Ordering::Acquire);
311
312 match self.state.queue.is_full_or() {
313 Ok(()) => {
314 self.push_thread();
315 }
316 Err(len) if len > current_threads as usize => {
317 self.push_thread();
318 }
319 Err(len) if len < current_threads.ilog2() as usize => {
320 if let Some(thread) = self.pop_thread() {
321 thread.reap().await;
322 }
323 }
324 Err(_) => {}
325 }
326
327 response_rx.recv().await
328 }
329 }
330
331 pub async fn close(self) -> bool {
353 let Some(mut state) = Arc::into_inner(self.state) else {
354 return false;
355 };
356
357 let mut threads = state.take_threads();
358
359 for thread in &mut threads {
360 thread.signal.take();
361 }
362
363 for mut thread in threads {
364 thread.closed.listen().await;
365
366 if let Some(handle) = thread.handle.take() {
367 handle.join().expect("Thread panicked");
368 }
369 }
370
371 true
372 }
373
374 fn push_thread(&self) -> bool {
375 let current_threads = self
376 .state
377 .current_threads
378 .load(std::sync::atomic::Ordering::Acquire);
379
380 if current_threads >= u64::from(u16::from(self.state.max_threads)) {
381 tracing::trace!("At thread maximum");
382
383 return false;
384 }
385
386 if self
387 .state
388 .current_threads
389 .compare_exchange(
390 current_threads,
391 current_threads + 1,
392 std::sync::atomic::Ordering::AcqRel,
393 std::sync::atomic::Ordering::Relaxed,
394 )
395 .is_err()
396 {
397 tracing::trace!("Didn't acquire spawn authorization");
398
399 return false;
400 }
401
402 let thread_id = self
405 .state
406 .thread_id
407 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
408
409 let thread = spawn(self.state.name, thread_id, self.state.queue.clone());
410
411 self.state
412 .threads
413 .lock()
414 .expect("threads lock poison")
415 .push(thread);
416
417 true
418 }
419
420 fn pop_thread(&self) -> Option<Thread> {
421 let current_threads = self
422 .state
423 .current_threads
424 .load(std::sync::atomic::Ordering::Acquire);
425
426 if current_threads <= u64::from(u16::from(self.state.min_threads)) {
427 tracing::info!("At thread minimum");
428
429 return None;
430 }
431
432 if self
433 .state
434 .current_threads
435 .compare_exchange(
436 current_threads,
437 current_threads - 1,
438 std::sync::atomic::Ordering::AcqRel,
439 std::sync::atomic::Ordering::Relaxed,
440 )
441 .is_err()
442 {
443 tracing::trace!("Didn't acquire reap authorization");
444
445 return None;
446 }
447
448 self.state
451 .threads
452 .lock()
453 .expect("threads lock poison")
454 .pop()
455 }
456}
457
458impl Default for CpuPool {
459 fn default() -> Self {
460 Self::new()
461 }
462}
463
464type SendFn = Box<dyn FnOnce() + Send>;
465
466struct CpuPoolState {
467 name: &'static str,
468 min_threads: NonZeroU16,
469 max_threads: NonZeroU16,
470 current_threads: AtomicU64,
471 thread_id: AtomicU64,
472 queue: Queue<SendFn>,
473 threads: Mutex<ThreadVec>,
474}
475
476impl CpuPoolState {
477 fn new(
478 name: &'static str,
479 buffer_multiplier: NonZeroUsize,
480 min_threads: NonZeroU16,
481 max_threads: NonZeroU16,
482 ) -> Self {
483 let thread_capacity = usize::from(u16::from(max_threads));
484
485 let queue = queue::bounded(usize::from(buffer_multiplier).saturating_mul(thread_capacity));
486
487 let start_threads = u64::from(u16::from(min_threads));
488
489 let threads = ThreadVec::new(start_threads, thread_capacity, |i| {
490 spawn(name, i, queue.clone())
491 });
492
493 let current_threads = AtomicU64::new(start_threads);
494 let thread_id = AtomicU64::new(start_threads);
495
496 CpuPoolState {
497 name,
498 min_threads,
499 max_threads,
500 current_threads,
501 thread_id,
502 queue,
503 threads: Mutex::new(threads),
504 }
505 }
506
507 fn take_threads(&mut self) -> Vec<Thread> {
508 self.threads.lock().expect("threads lock poison").take()
509 }
510}
511
512impl std::fmt::Debug for CpuPoolState {
513 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
514 f.debug_struct("CpuPoolState")
515 .field("name", &self.name)
516 .field("min_threads", &self.min_threads)
517 .field("max_threads", &self.max_threads)
518 .finish()
519 }
520}
521
522struct ThreadVec {
523 threads: Vec<Thread>,
524}
525
526impl ThreadVec {
527 fn new<F>(start_threads: u64, max_threads: usize, spawn: F) -> Self
528 where
529 F: Fn(u64) -> Thread,
530 {
531 let mut threads = Vec::with_capacity(max_threads);
532
533 for i in 0..start_threads {
534 threads.push((spawn)(i));
535 }
536
537 Self { threads }
538 }
539
540 fn push(&mut self, thread: Thread) {
541 self.threads.push(thread);
542 }
543
544 fn pop(&mut self) -> Option<Thread> {
545 self.threads.pop()
546 }
547
548 fn take(&mut self) -> Vec<Thread> {
549 std::mem::take(&mut self.threads)
550 }
551}
552
553impl Drop for ThreadVec {
554 fn drop(&mut self) {
555 for thread in &mut self.threads {
556 thread.signal.take();
557 }
558
559 for thread in &mut self.threads {
560 if let Some(handle) = thread.handle.take() {
561 handle.join().expect("Thread panicked");
562 }
563 }
564 }
565}
566
567struct Thread {
568 handle: Option<JoinHandle<()>>,
569 signal: Option<DropNotifier>,
570 closed: DropListener,
571}
572
573impl Thread {
574 async fn reap(mut self) {
575 self.signal.take();
576
577 self.closed.listen().await;
578
579 if let Some(handle) = self.handle.take() {
580 handle.join().expect("Thread panicked");
581 }
582 }
583}
584
585fn spawn(name: &'static str, id: u64, receiver: Queue<SendFn>) -> Thread {
586 let (closed_notifier, closed_listener) = drop_notifier::notifier();
587 let (signal_notifier, signal_listener) = drop_notifier::notifier();
588
589 let handle = std::thread::Builder::new()
590 .name(format!("{name}-{id}"))
591 .spawn(move || run(name, id, receiver, signal_listener, closed_notifier))
592 .expect("Failed to spawn new thread");
593
594 Thread {
595 handle: Some(handle),
596 signal: Some(signal_notifier),
597 closed: closed_listener,
598 }
599}
600
601struct MetricsGuard {
602 name: &'static str,
603 id: u64,
604 start: Instant,
605 armed: bool,
606}
607
608impl MetricsGuard {
609 fn guard(name: &'static str, id: u64) -> Self {
610 tracing::trace!("Starting {name}-{id}");
611 metrics::counter!(format!("async-cpupool.{name}.thread.launched")).increment(1);
612
613 MetricsGuard {
614 name,
615 id,
616 start: Instant::now(),
617 armed: true,
618 }
619 }
620
621 fn disarm(mut self) {
622 self.armed = false;
623 }
624}
625
626impl Drop for MetricsGuard {
627 fn drop(&mut self) {
628 metrics::counter!(format!("async-cpupool.{}.thread.closed", self.name), "clean" => (!self.armed).to_string()).increment(1);
629 metrics::histogram!(format!("async-cpupool.{}.thread.seconds", self.name), "clean" => (!self.armed).to_string()).record(self.start.elapsed().as_secs_f64());
630 tracing::trace!("Stopping {}-{}", self.name, self.id);
631 }
632}
633
634fn run(
635 name: &'static str,
636 id: u64,
637 receiver: Queue<SendFn>,
638 signal: DropListener,
639 closed_tx: DropNotifier,
640) {
641 let guard = MetricsGuard::guard(name, id);
642
643 let mut signal = std::pin::pin!(signal.listen());
644
645 loop {
646 match block_on(select(&mut signal, receiver.pop())) {
647 selector::Either::Left(_) => break,
648 selector::Either::Right(send_fn) => invoke_send_fn(name, send_fn),
649 }
650 }
651
652 guard.disarm();
653
654 drop(closed_tx);
655}
656
657fn invoke_send_fn(name: &'static str, send_fn: SendFn) {
658 let start = Instant::now();
659 metrics::counter!(format!("async-cpupool.{name}.operation.start")).increment(1);
660
661 let res = std::panic::catch_unwind(std::panic::AssertUnwindSafe(move || {
662 (send_fn)();
663 }));
664
665 metrics::counter!(format!("async-cpupool.{name}.operation.end"), "complete" => res.is_ok().to_string()).increment(1);
666 metrics::histogram!(format!("async-cpupool.{name}.operation.seconds"), "complete" => res.is_ok().to_string()).record(start.elapsed().as_secs_f64());
667
668 if let Err(e) = res {
669 tracing::trace!("panic in spawned task: {e:?}");
670 }
671}