use crate::job::{JobFifo, JobRef, StackJob};
use crate::latch::{AsCoreLatch, CoreLatch, CountLatch, Latch, LatchRef, LockLatch, SpinLatch};
use crate::log::Event::*;
use crate::log::Logger;
use crate::sleep::Sleep;
use crate::unwind;
use crate::{
    ErrorKind, ExitHandler, PanicHandler, StartHandler, ThreadPoolBuildError, ThreadPoolBuilder,
};
use crossbeam_deque::{Injector, Steal, Stealer, Worker};
use std::cell::Cell;
use std::collections::hash_map::DefaultHasher;
use std::fmt;
use std::hash::Hasher;
use std::io;
use std::mem;
use std::ptr;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, Once};
use std::thread;
use std::usize;

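/// Thread builder used for customization via
/// `ThreadPoolBuilder::spawn_handler`.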
pub struct ThreadBuilder {
    name: Option<String>,
    stack_size: Option<usize>,
    worker: Worker<JobRef>,
    stealer: Stealer<JobRef>,
    registry: Arc<Registry>,
    index: usize,
}

impl ThreadBuilder {
    /// Gets the index of this thread in the pool, within `0..num_threads`.
    pub fn index(&self) -> usize {
        self.index
    }

    /// Gets the string that was specified with `ThreadPoolBuilder::name()`.
    pub fn name(&self) -> Option<&str> {
        self.name.as_deref()
    }

    /// Gets the value that was specified with `ThreadPoolBuilder::stack_size()`.
    pub fn stack_size(&self) -> Option<usize> {
        self.stack_size
    }

    /// Executes the main loop for this thread. This will not return until
    /// the thread pool is dropped.
    pub fn run(self) {
        unsafe { main_loop(self) }
    }
}

impl fmt::Debug for ThreadBuilder {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ThreadBuilder")
            .field("pool", &self.registry.id())
            .field("index", &self.index)
            .field("name", &self.name)
            .field("stack_size", &self.stack_size)
            .finish()
    }
}

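/// Generalized trait for spawning a thread in the `Registry`.
///
/// This trait is pub-in-private -- E0445 forces us to make it public,
/// but we don't actually want to expose these details in the API.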
pub trait ThreadSpawn {
    private_decl! {}

    /// Spawns a thread with the `ThreadBuilder` parameters, and then
    /// calls `ThreadBuilder::run()`.
    fn spawn(&mut self, thread: ThreadBuilder) -> io::Result<()>;
}

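/// Spawns a thread in the "normal" way with `std::thread::Builder`.
///
/// This type is pub-in-private -- E0445 forces us to make it public,
/// but we don't actually want to expose these details in the API.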
#[derive(Debug, Default)]
pub struct DefaultSpawn;

impl ThreadSpawn for DefaultSpawn {
    private_impl! {}

    fn spawn(&mut self, thread: ThreadBuilder) -> io::Result<()> {
        let mut b = thread::Builder::new();
        if let Some(name) = thread.name() {
            b = b.name(name.to_owned());
        }
        if let Some(stack_size) = thread.stack_size() {
            b = b.stack_size(stack_size);
        }
        b.spawn(|| thread.run())?;
        Ok(())
    }
}

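/// Spawns a thread with a user's custom callback.
///
/// This type is pub-in-private -- E0445 forces us to make it public,
/// but we don't actually want to expose these details in the API.
///
/// A minimal sketch of how a user ends up here: `spawn_handler` wraps the
/// given closure in a `CustomSpawn`, and the closure only needs to start a
/// thread that eventually calls `ThreadBuilder::run()`:
///
/// ```ignore
/// let pool = ThreadPoolBuilder::new()
///     .spawn_handler(|thread| {
///         std::thread::spawn(|| thread.run());
///         Ok(())
///     })
///     .build()?;
/// ```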
#[derive(Debug)]
pub struct CustomSpawn<F>(F);

impl<F> CustomSpawn<F>
where
    F: FnMut(ThreadBuilder) -> io::Result<()>,
{
    pub(super) fn new(spawn: F) -> Self {
        CustomSpawn(spawn)
    }
}

impl<F> ThreadSpawn for CustomSpawn<F>
where
    F: FnMut(ThreadBuilder) -> io::Result<()>,
{
    private_impl! {}

    #[inline]
    fn spawn(&mut self, thread: ThreadBuilder) -> io::Result<()> {
        (self.0)(thread)
    }
}

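/// The central data structure for a thread pool: it holds the stealer half
/// of each worker's deque, the shared injector queue for jobs arriving from
/// outside the pool, and the sleep state used to wake idle workers.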
pub(super) struct Registry {
    logger: Logger,
    thread_infos: Vec<ThreadInfo>,
    sleep: Sleep,
    injected_jobs: Injector<JobRef>,
    broadcasts: Mutex<Vec<Worker<JobRef>>>,
    panic_handler: Option<Box<PanicHandler>>,
    start_handler: Option<Box<StartHandler>>,
    exit_handler: Option<Box<ExitHandler>>,

    /// When this counter reaches 0, all work on the registry must be
    /// complete and the worker threads may shut down. It starts at 1
    /// (the reference held by the owning `ThreadPool`, or by the process
    /// itself for the global pool) and is incremented for work such as
    /// `spawn()` that must keep the registry alive while it is pending.
    terminate_count: AtomicUsize,
}

static mut THE_REGISTRY: Option<Arc<Registry>> = None;
static THE_REGISTRY_SET: Once = Once::new();

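/// Starts the worker threads (if that has not already happened). If
/// you do not wish to initialize the worker threads yourself, then we
/// will arrange to do so on your behalf.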
pub(super) fn global_registry() -> &'static Arc<Registry> {
    set_global_registry(default_global_registry)
        .or_else(|err| unsafe { THE_REGISTRY.as_ref().ok_or(err) })
        .expect("The global thread pool has not been initialized.")
}

/// Starts the worker threads (if that has not already happened) with
/// the given builder.
pub(super) fn init_global_registry<S>(
    builder: ThreadPoolBuilder<S>,
) -> Result<&'static Arc<Registry>, ThreadPoolBuildError>
where
    S: ThreadSpawn,
{
    set_global_registry(|| Registry::new(builder))
}

/// Starts the worker threads (if that has not already happened)
/// by creating a registry with the given callback.
fn set_global_registry<F>(registry: F) -> Result<&'static Arc<Registry>, ThreadPoolBuildError>
where
    F: FnOnce() -> Result<Arc<Registry>, ThreadPoolBuildError>,
{
    let mut result = Err(ThreadPoolBuildError::new(
        ErrorKind::GlobalPoolAlreadyInitialized,
    ));

    THE_REGISTRY_SET.call_once(|| {
        result = registry()
            .map(|registry: Arc<Registry>| unsafe { &*THE_REGISTRY.get_or_insert(registry) })
    });

    result
}

fn default_global_registry() -> Result<Arc<Registry>, ThreadPoolBuildError> {
    let result = Registry::new(ThreadPoolBuilder::new());

    // If we're running in an environment that doesn't support threads at
    // all, we can fall back to using the current thread alone. This is
    // crude, and probably won't work for non-blocking calls like `spawn`,
    // but a lot of stuff does work fine.
    let unsupported = matches!(&result, Err(e) if e.is_unsupported());
    if unsupported && WorkerThread::current().is_null() {
        let builder = ThreadPoolBuilder::new()
            .num_threads(1)
            .spawn_handler(|thread| {
                // Rather than starting a new thread, we're just taking over
                // the current thread *without* running the main loop, so we
                // can still return from here. The `WorkerThread` is leaked,
                // but we never shut down the global pool anyway.
                let worker_thread = Box::leak(Box::new(WorkerThread::from(thread)));
                let registry = &*worker_thread.registry;
                let index = worker_thread.index;

                unsafe {
                    WorkerThread::set_current(worker_thread);

                    // let registry know we are ready to do work
                    Latch::set(&registry.thread_infos[index].primed);
                }

                Ok(())
            });

        let fallback_result = Registry::new(builder);
        if fallback_result.is_ok() {
            return fallback_result;
        }
    }

    result
}

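/// Drop guard that terminates the registry's workers; `Registry::new` uses
/// it to clean up if spawning a worker thread fails partway through, and
/// defuses it with `mem::forget` on success.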
struct Terminator<'a>(&'a Arc<Registry>);

impl<'a> Drop for Terminator<'a> {
    fn drop(&mut self) {
        self.0.terminate()
    }
}

impl Registry {
    pub(super) fn new<S>(
        mut builder: ThreadPoolBuilder<S>,
    ) -> Result<Arc<Self>, ThreadPoolBuildError>
    where
        S: ThreadSpawn,
    {
        // Soft-limit the number of threads that we can actually support.
        let n_threads = Ord::min(builder.get_num_threads(), crate::max_num_threads());

        let breadth_first = builder.get_breadth_first();

        let (workers, stealers): (Vec<_>, Vec<_>) = (0..n_threads)
            .map(|_| {
                let worker = if breadth_first {
                    Worker::new_fifo()
                } else {
                    Worker::new_lifo()
                };

                let stealer = worker.stealer();
                (worker, stealer)
            })
            .unzip();

        let (broadcasts, broadcast_stealers): (Vec<_>, Vec<_>) = (0..n_threads)
            .map(|_| {
                let worker = Worker::new_fifo();
                let stealer = worker.stealer();
                (worker, stealer)
            })
            .unzip();

        let logger = Logger::new(n_threads);
        let registry = Arc::new(Registry {
            logger: logger.clone(),
            thread_infos: stealers.into_iter().map(ThreadInfo::new).collect(),
            sleep: Sleep::new(logger, n_threads),
            injected_jobs: Injector::new(),
            broadcasts: Mutex::new(broadcasts),
            terminate_count: AtomicUsize::new(1),
            panic_handler: builder.take_panic_handler(),
            start_handler: builder.take_start_handler(),
            exit_handler: builder.take_exit_handler(),
        });

        // If we return early or panic, make sure to terminate existing threads.
        let t1000 = Terminator(&registry);

        for (index, (worker, stealer)) in workers.into_iter().zip(broadcast_stealers).enumerate() {
            let thread = ThreadBuilder {
                name: builder.get_thread_name(index),
                stack_size: builder.get_stack_size(),
                registry: Arc::clone(&registry),
                worker,
                stealer,
                index,
            };
            if let Err(e) = builder.get_spawn_handler().spawn(thread) {
                return Err(ThreadPoolBuildError::new(ErrorKind::IOError(e)));
            }
        }

        // Returning normally now, without termination.
        mem::forget(t1000);

        Ok(registry)
    }

    /// Returns the registry of the current worker thread, or the global
    /// registry if this is not a worker thread.
    pub(super) fn current() -> Arc<Registry> {
        unsafe {
            let worker_thread = WorkerThread::current();
            let registry = if worker_thread.is_null() {
                global_registry()
            } else {
                &(*worker_thread).registry
            };
            Arc::clone(registry)
        }
    }

    /// Returns the number of threads in the current registry.  This
    /// is better than `Registry::current().num_threads()` because it
    /// avoids incrementing the `Arc`.
    pub(super) fn current_num_threads() -> usize {
        unsafe {
            let worker_thread = WorkerThread::current();
            if worker_thread.is_null() {
                global_registry().num_threads()
            } else {
                (*worker_thread).registry.num_threads()
            }
        }
    }

    /// Returns the current `WorkerThread` if it's part of this `Registry`.
    pub(super) fn current_thread(&self) -> Option<&WorkerThread> {
        unsafe {
            let worker = WorkerThread::current().as_ref()?;
            if worker.registry().id() == self.id() {
                Some(worker)
            } else {
                None
            }
        }
    }

    /// Returns an opaque identifier for this registry.
    pub(super) fn id(&self) -> RegistryId {
        // We can rely on `self` not to change since we only ever create
        // registries that are boxed up in an `Arc` (see `new()` above).
        RegistryId {
            addr: self as *const Self as usize,
        }
    }

    #[inline]
    pub(super) fn log(&self, event: impl FnOnce() -> crate::log::Event) {
        self.logger.log(event)
    }

    pub(super) fn num_threads(&self) -> usize {
        self.thread_infos.len()
    }

    pub(super) fn catch_unwind(&self, f: impl FnOnce()) {
        if let Err(err) = unwind::halt_unwinding(f) {
            // If there is no handler, or if the handler itself panics,
            // then we abort.
            let abort_guard = unwind::AbortIfPanic;
            if let Some(ref handler) = self.panic_handler {
                handler(err);
                mem::forget(abort_guard);
            }
        }
    }

    /// Waits for the worker threads to get up and running.  This is
    /// meant to be used for benchmarking purposes, primarily, so that
    /// you can get more consistent numbers by having everything
    /// "ready to go".
    pub(super) fn wait_until_primed(&self) {
        for info in &self.thread_infos {
            info.primed.wait();
        }
    }

    /// Waits for the worker threads to stop. This is used for testing
    /// -- so we can check that termination actually works.
    #[cfg(test)]
    pub(super) fn wait_until_stopped(&self) {
        for info in &self.thread_infos {
            info.stopped.wait();
        }
    }

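    /// Push a job into the given `registry`. If we are running on a
    /// worker thread for the registry, this will push onto the
    /// deque. Otherwise, it will inject from the outside (which is
    /// slower).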
    pub(super) fn inject_or_push(&self, job_ref: JobRef) {
        let worker_thread = WorkerThread::current();
        unsafe {
            if !worker_thread.is_null() && (*worker_thread).registry().id() == self.id() {
                (*worker_thread).push(job_ref);
            } else {
                self.inject(&[job_ref]);
            }
        }
    }

    /// Push a job into the "external jobs" queue; it will be taken by
    /// whatever worker has nothing to do. Use this if you know that
    /// you are not on a worker of this registry.
    pub(super) fn inject(&self, injected_jobs: &[JobRef]) {
        self.log(|| JobsInjected {
            count: injected_jobs.len(),
        });

        // It should not be possible for `state.terminate` to be true
        // here. It is only set to true when the user creates (and
        // drops) a `ThreadPool`; and, in that case, they cannot be
        // calling `inject()` later, since they dropped their
        // `ThreadPool`.
        debug_assert_ne!(
            self.terminate_count.load(Ordering::Acquire),
            0,
            "inject() sees state.terminate as true"
        );

        let queue_was_empty = self.injected_jobs.is_empty();

        for &job_ref in injected_jobs {
            self.injected_jobs.push(job_ref);
        }

        self.sleep
            .new_injected_jobs(usize::MAX, injected_jobs.len() as u32, queue_was_empty);
    }

    fn has_injected_job(&self) -> bool {
        !self.injected_jobs.is_empty()
    }

    fn pop_injected_job(&self, worker_index: usize) -> Option<JobRef> {
        loop {
            match self.injected_jobs.steal() {
                Steal::Success(job) => {
                    self.log(|| JobUninjected {
                        worker: worker_index,
                    });
                    return Some(job);
                }
                Steal::Empty => return None,
                Steal::Retry => {}
            }
        }
    }

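    /// Push a job into each thread's own "external jobs" queue; it will be
    /// executed only on that thread, when it has nothing else to do locally,
    /// before it tries to steal other work.
    ///
    /// **Panics** if not given exactly as many jobs as there are threads.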
    pub(super) fn inject_broadcast(&self, injected_jobs: impl ExactSizeIterator<Item = JobRef>) {
        assert_eq!(self.num_threads(), injected_jobs.len());
        self.log(|| JobBroadcast {
            count: self.num_threads(),
        });
        {
            let broadcasts = self.broadcasts.lock().unwrap();

            // It should not be possible for `state.terminate` to be true
            // here. It is only set to true when the user creates (and
            // drops) a `ThreadPool`; and, in that case, they cannot be
            // calling `inject_broadcast()` later, since they dropped their
            // `ThreadPool`.
            debug_assert_ne!(
                self.terminate_count.load(Ordering::Acquire),
                0,
                "inject_broadcast() sees state.terminate as true"
            );

            assert_eq!(broadcasts.len(), injected_jobs.len());
            for (worker, job_ref) in broadcasts.iter().zip(injected_jobs) {
                worker.push(job_ref);
            }
        }
        for i in 0..self.num_threads() {
            self.sleep.notify_worker_latch_is_set(i);
        }
    }

    /// If already in a worker-thread of this registry, just execute `op`.
    /// Otherwise, inject `op` in this thread-pool. Either way, block until
    /// `op` completes and return its return value. If `op` panics, that
    /// panic will be propagated as well. The second argument indicates
    /// `true` if injection was performed, `false` if executed directly.
    pub(super) fn in_worker<OP, R>(&self, op: OP) -> R
    where
        OP: FnOnce(&WorkerThread, bool) -> R + Send,
        R: Send,
    {
        unsafe {
            let worker_thread = WorkerThread::current();
            if worker_thread.is_null() {
                self.in_worker_cold(op)
            } else if (*worker_thread).registry().id() != self.id() {
                self.in_worker_cross(&*worker_thread, op)
            } else {
                // Perfectly valid to give them a `&T`: this is the
                // current thread, so we know the data structure won't be
                // invalidated until we return.
                op(&*worker_thread, false)
            }
        }
    }

    #[cold]
    unsafe fn in_worker_cold<OP, R>(&self, op: OP) -> R
    where
        OP: FnOnce(&WorkerThread, bool) -> R + Send,
        R: Send,
    {
        thread_local!(static LOCK_LATCH: LockLatch = LockLatch::new());

        LOCK_LATCH.with(|l| {
            // This thread isn't a member of *any* thread pool, so just block.
            debug_assert!(WorkerThread::current().is_null());
            let job = StackJob::new(
                |injected| {
                    let worker_thread = WorkerThread::current();
                    assert!(injected && !worker_thread.is_null());
                    op(&*worker_thread, true)
                },
                LatchRef::new(l),
            );
            self.inject(&[job.as_job_ref()]);
            job.latch.wait_and_reset(); // Make sure we can use the same latch again next time.

            // flush accumulated logs as we exit the thread
            self.logger.log(|| Flush);

            job.into_result()
        })
    }

    #[cold]
    unsafe fn in_worker_cross<OP, R>(&self, current_thread: &WorkerThread, op: OP) -> R
    where
        OP: FnOnce(&WorkerThread, bool) -> R + Send,
        R: Send,
    {
        // This thread is a worker of a *different* pool: inject the job
        // here, and let the current thread keep busy in its own pool while
        // it waits for the operation to complete.
        debug_assert!(current_thread.registry().id() != self.id());
        let latch = SpinLatch::cross(current_thread);
        let job = StackJob::new(
            |injected| {
                let worker_thread = WorkerThread::current();
                assert!(injected && !worker_thread.is_null());
                op(&*worker_thread, true)
            },
            latch,
        );
        self.inject(&[job.as_job_ref()]);
        current_thread.wait_until(&job.latch);
        job.into_result()
    }

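    /// Increments the terminate counter. This increment should be
    /// balanced by a call to `terminate`, which will decrement. This
    /// is used when spawning asynchronous work, which needs to
    /// prevent the registry from terminating so long as it is active.
    ///
    /// Note that blocking functions such as `join` and `scope` do not
    /// need to concern themselves with this fn; their context is
    /// responsible for ensuring the current thread-pool will not
    /// terminate until they return.
    ///
    /// The global thread-pool always has an outstanding reference
    /// (the initial one). Custom thread-pools have one outstanding
    /// reference that is dropped when the `ThreadPool` is dropped:
    /// since installing the thread-pool blocks until any joins/scopes
    /// complete, this ensures that joins/scopes are covered.
    ///
    /// The exception is `::spawn()`, which can create a job outside
    /// of any blocking scope. In that case, the job itself holds a
    /// terminate count and is responsible for invoking `terminate()`
    /// when finished.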
    pub(super) fn increment_terminate_count(&self) {
        let previous = self.terminate_count.fetch_add(1, Ordering::AcqRel);
        debug_assert!(previous != 0, "registry ref count incremented from zero");
        assert!(
            previous != std::usize::MAX,
            "overflow in registry ref count"
        );
    }

    /// Signals that the thread-pool which owns this registry has been
    /// dropped. The worker threads will gradually terminate, once any
    /// extant work is completed.
    pub(super) fn terminate(&self) {
        if self.terminate_count.fetch_sub(1, Ordering::AcqRel) == 1 {
            for (i, thread_info) in self.thread_infos.iter().enumerate() {
                unsafe { CountLatch::set_and_tickle_one(&thread_info.terminate, self, i) };
            }
        }
    }

    /// Notify the worker that the latch they are sleeping on has been set.
    pub(super) fn notify_worker_latch_is_set(&self, target_worker_index: usize) {
        self.sleep.notify_worker_latch_is_set(target_worker_index);
    }
}

#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub(super) struct RegistryId {
    addr: usize,
}

struct ThreadInfo {
    /// Latch set once thread has started and we are entering into the
    /// main loop. Used to wait for worker threads to become primed,
    /// primarily of interest for benchmarking.
    primed: LockLatch,

    /// Latch is set once worker thread has completed. Used to wait
    /// until workers have stopped; only used for tests.
    stopped: LockLatch,

    /// The latch used to signal that termination has been requested.
    /// This latch is *set* by the `terminate` method, and *waited on*
    /// by worker threads.
    terminate: CountLatch,

    /// the "stealer" half of the worker's deque
    stealer: Stealer<JobRef>,
}

impl ThreadInfo {
    fn new(stealer: Stealer<JobRef>) -> ThreadInfo {
        ThreadInfo {
            primed: LockLatch::new(),
            stopped: LockLatch::new(),
            terminate: CountLatch::new(),
            stealer,
        }
    }
}

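/// Per-thread state for a worker: the "worker" half of its deque, a stealer
/// for its share of broadcast jobs, and a handle back to the owning registry.
/// It is published to the rest of the thread via a thread-local pointer
/// (see `WORKER_THREAD_STATE` below).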
pub(super) struct WorkerThread {
    /// the "worker" half of our local deque
    worker: Worker<JobRef>,

    /// the "stealer" to take broadcast jobs from the registry
    stealer: Stealer<JobRef>,

    /// local queue used for `spawn_fifo` indirection
    fifo: JobFifo,

    index: usize,

    /// A weak random number generator.
    rng: XorShift64Star,

    registry: Arc<Registry>,
}

// This is a bit sketchy, but basically: the WorkerThread is
// allocated on the stack of the worker on entry and stored into this
// thread local variable. So it will remain valid at least until the
// worker is gone, which is good enough for our purposes.
thread_local! {
    static WORKER_THREAD_STATE: Cell<*const WorkerThread> = Cell::new(ptr::null());
}

impl From<ThreadBuilder> for WorkerThread {
    fn from(thread: ThreadBuilder) -> Self {
        Self {
            worker: thread.worker,
            stealer: thread.stealer,
            fifo: JobFifo::new(),
            index: thread.index,
            rng: XorShift64Star::new(),
            registry: thread.registry,
        }
    }
}

impl Drop for WorkerThread {
    fn drop(&mut self) {
        // Undo `set_current`
        WORKER_THREAD_STATE.with(|t| {
            assert!(t.get().eq(&(self as *const _)));
            t.set(ptr::null());
        });
    }
}

impl WorkerThread {
    /// Gets the `WorkerThread` for the current thread; returns NULL if
    /// this is not a worker thread. This pointer is valid anywhere on
    /// the current thread.
    #[inline]
    pub(super) fn current() -> *const WorkerThread {
        WORKER_THREAD_STATE.with(Cell::get)
    }

    /// Sets `self` as the worker thread for the current thread.
    /// This is done during worker thread startup.
    unsafe fn set_current(thread: *const WorkerThread) {
        WORKER_THREAD_STATE.with(|t| {
            assert!(t.get().is_null());
            t.set(thread);
        });
    }

    /// Returns the registry that owns this worker thread.
    #[inline]
    pub(super) fn registry(&self) -> &Arc<Registry> {
        &self.registry
    }

    #[inline]
    pub(super) fn log(&self, event: impl FnOnce() -> crate::log::Event) {
        self.registry.logger.log(event)
    }

    /// Our index amongst the worker threads (ranges from `0..self.num_threads()`).
    #[inline]
    pub(super) fn index(&self) -> usize {
        self.index
    }

    #[inline]
    pub(super) unsafe fn push(&self, job: JobRef) {
        self.log(|| JobPushed { worker: self.index });
        let queue_was_empty = self.worker.is_empty();
        self.worker.push(job);
        self.registry
            .sleep
            .new_internal_jobs(self.index, 1, queue_was_empty);
    }

    #[inline]
    pub(super) unsafe fn push_fifo(&self, job: JobRef) {
        self.push(self.fifo.push(job));
    }

    #[inline]
    pub(super) fn local_deque_is_empty(&self) -> bool {
        self.worker.is_empty()
    }

    /// Attempts to obtain a "local" job -- typically this means
    /// popping from the top of the stack, though if we are configured
    /// for breadth-first execution, it would mean dequeuing from the
    /// bottom.
    #[inline]
    pub(super) unsafe fn take_local_job(&self) -> Option<JobRef> {
        let popped_job = self.worker.pop();

        if popped_job.is_some() {
            self.log(|| JobPopped { worker: self.index });
            return popped_job;
        }

        // The local deque is empty: check the broadcast queue assigned to us.
        loop {
            match self.stealer.steal() {
                Steal::Success(job) => return Some(job),
                Steal::Empty => return None,
                Steal::Retry => {}
            }
        }
    }

    fn has_injected_job(&self) -> bool {
        !self.stealer.is_empty() || self.registry.has_injected_job()
    }

    /// Wait until the latch is set. Try to keep busy by popping and
    /// stealing tasks as necessary.
    #[inline]
    pub(super) unsafe fn wait_until<L: AsCoreLatch + ?Sized>(&self, latch: &L) {
        let latch = latch.as_core_latch();
        if !latch.probe() {
            self.wait_until_cold(latch);
        }
    }

    #[cold]
    unsafe fn wait_until_cold(&self, latch: &CoreLatch) {
        // the code below should swallow all panics and hence never
        // unwind; but if something goes wrong, we want to abort,
        // because otherwise other code in rayon may assume that the
        // latch has been signaled, and that can lead to random memory
        // accesses, which would be *very bad*
        let abort_guard = unwind::AbortIfPanic;

        let mut idle_state = self.registry.sleep.start_looking(self.index, latch);
        while !latch.probe() {
            // Try to find some work to do. We give preference first
            // to things in our local deque, then in other workers'
            // deques, and finally to injected jobs from the
            // outside. The idea is to finish what we started before
            // we take on something new.
            if let Some(job) = self
                .take_local_job()
                .or_else(|| self.steal())
                .or_else(|| self.registry.pop_injected_job(self.index))
            {
                self.registry.sleep.work_found(idle_state);
                self.execute(job);
                idle_state = self.registry.sleep.start_looking(self.index, latch);
            } else {
                self.registry
                    .sleep
                    .no_work_found(&mut idle_state, latch, || self.has_injected_job())
            }
        }

        // If we were sleepy, we are not anymore: the latch we were
        // waiting on has been set, so we count that as "work found".
        self.registry.sleep.work_found(idle_state);

        self.log(|| ThreadSawLatchSet {
            worker: self.index,
            latch_addr: latch.addr(),
        });
        mem::forget(abort_guard); // successful execution, do not abort
    }

    #[inline]
    pub(super) unsafe fn execute(&self, job: JobRef) {
        job.execute();
    }

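    /// Try to steal a single job and return it.
    ///
    /// This should only be done as a last resort, when there is no
    /// local work to do.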
    unsafe fn steal(&self) -> Option<JobRef> {
        // we only steal when we don't have any work to do locally
        debug_assert!(self.local_deque_is_empty());

        // otherwise, try to steal
        let thread_infos = &self.registry.thread_infos.as_slice();
        let num_threads = thread_infos.len();
        if num_threads <= 1 {
            return None;
        }

        loop {
            let mut retry = false;
            // Start from a random victim so contention is spread out.
            let start = self.rng.next_usize(num_threads);
            let job = (start..num_threads)
                .chain(0..start)
                .filter(move |&i| i != self.index)
                .find_map(|victim_index| {
                    let victim = &thread_infos[victim_index];
                    match victim.stealer.steal() {
                        Steal::Success(job) => {
                            self.log(|| JobStolen {
                                worker: self.index,
                                victim: victim_index,
                            });
                            Some(job)
                        }
                        Steal::Empty => None,
                        Steal::Retry => {
                            retry = true;
                            None
                        }
                    }
                });
            if job.is_some() || !retry {
                return job;
            }
        }
    }
}

unsafe fn main_loop(thread: ThreadBuilder) {
    let worker_thread = &WorkerThread::from(thread);
    WorkerThread::set_current(worker_thread);
    let registry = &*worker_thread.registry;
    let index = worker_thread.index;

    // let registry know we are ready to do work
    Latch::set(&registry.thread_infos[index].primed);

    // Worker threads should not panic. If they do, just abort, as the
    // internal state of the threadpool is corrupted. Note that if
    // **user code** panics, we should catch that and redirect.
    let abort_guard = unwind::AbortIfPanic;

    // Inform a user callback that we started a thread.
    if let Some(ref handler) = registry.start_handler {
        registry.catch_unwind(|| handler(index));
    }

    let my_terminate_latch = &registry.thread_infos[index].terminate;
    worker_thread.log(|| ThreadStart {
        worker: index,
        terminate_addr: my_terminate_latch.as_core_latch().addr(),
    });
    worker_thread.wait_until(my_terminate_latch);

    // Should not be any work left in our queue.
    debug_assert!(worker_thread.take_local_job().is_none());

    // let registry know we are done
    Latch::set(&registry.thread_infos[index].stopped);

    // Normal termination, do not abort.
    mem::forget(abort_guard);

    worker_thread.log(|| ThreadTerminate { worker: index });

    // Inform a user callback that we exited a thread.
    if let Some(ref handler) = registry.exit_handler {
        registry.catch_unwind(|| handler(index));
        // We're already exiting the thread, there's nothing else to do.
    }
}

/// If already in a worker-thread, just execute `op`. Otherwise,
/// execute `op` in the default thread-pool. Either way, block until
/// `op` completes and return its return value. If `op` panics, that
/// panic will be propagated as well. The second argument indicates
/// `true` if injection was performed, `false` if executed directly.
pub(super) fn in_worker<OP, R>(op: OP) -> R
where
    OP: FnOnce(&WorkerThread, bool) -> R + Send,
    R: Send,
{
    unsafe {
        let owner_thread = WorkerThread::current();
        if !owner_thread.is_null() {
            // Perfectly valid to give them a `&T`: this is the
            // current thread, so we know the data structure won't be
            // invalidated until we return.
            op(&*owner_thread, false)
        } else {
            global_registry().in_worker(op)
        }
    }
}

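/// [xorshift*] is a fast pseudorandom number generator which will
/// even tolerate weak seeding, as long as it's not zero.
///
/// [xorshift*]: https://en.wikipedia.org/wiki/Xorshift#xorshift*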
struct XorShift64Star {
    state: Cell<u64>,
}

impl XorShift64Star {
    fn new() -> Self {
        // Any non-zero seed will do -- this uses just a hash of a
        // global counter, incremented once per generator.
        let mut seed = 0;
        while seed == 0 {
            let mut hasher = DefaultHasher::new();
            static COUNTER: AtomicUsize = AtomicUsize::new(0);
            hasher.write_usize(COUNTER.fetch_add(1, Ordering::Relaxed));
            seed = hasher.finish();
        }

        XorShift64Star {
            state: Cell::new(seed),
        }
    }

    fn next(&self) -> u64 {
        let mut x = self.state.get();
        debug_assert_ne!(x, 0);
        x ^= x >> 12;
        x ^= x << 25;
        x ^= x >> 27;
        self.state.set(x);
        x.wrapping_mul(0x2545_f491_4f6c_dd1d)
    }

    /// Return a value from `0..n`.
    fn next_usize(&self, n: usize) -> usize {
        (self.next() % n as u64) as usize
    }
}