use super::range::{
    FixedRangeFactory, Range, RangeFactory, RangeOrchestrator, WorkStealingRangeFactory,
};
use super::sync::{make_lending_group, Borrower, Lender, WorkerState};
use super::util::LifetimeParameterized;
use crate::iter::Accumulator;
use crate::macros::{log_debug, log_error, log_warn};
use crossbeam_utils::CachePadded;
#[cfg(all(
    not(miri),
    any(
        target_os = "android",
        target_os = "dragonfly",
        target_os = "freebsd",
        target_os = "linux"
    )
))]
use nix::{
    sched::{sched_setaffinity, CpuSet},
    unistd::Pid,
};
use std::convert::TryFrom;
use std::marker::PhantomData;
use std::num::NonZeroUsize;
use std::ops::ControlFlow;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::thread::JoinHandle;

/// Policy to determine the number of worker threads in a [`ThreadPool`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ThreadCount {
    /// Use the number of threads reported by
    /// [`std::thread::available_parallelism()`].
    AvailableParallelism,
    /// Use the given (non-zero) number of threads.
    Count(NonZeroUsize),
}

impl TryFrom<usize> for ThreadCount {
    type Error = <NonZeroUsize as TryFrom<usize>>::Error;

    /// Converts a non-zero count into [`ThreadCount::Count`]; zero yields an
    /// error. For example, `ThreadCount::try_from(4)` is equivalent to
    /// `ThreadCount::Count(NonZeroUsize::new(4).unwrap())`.
    fn try_from(thread_count: usize) -> Result<Self, Self::Error> {
        let count = NonZeroUsize::try_from(thread_count)?;
        Ok(ThreadCount::Count(count))
    }
}

/// Strategy to distribute the input range among worker threads.
#[derive(Clone, Copy)]
pub enum RangeStrategy {
    /// Each thread processes a fixed slice of the input.
    Fixed,
    /// Threads can steal work from each other once their own slice is done.
    WorkStealing,
}

/// Policy to pin worker threads to CPUs.
#[derive(Clone, Copy)]
pub enum CpuPinningPolicy {
    /// Don't pin threads to CPUs.
    No,
    /// Pin each thread to a CPU if the platform supports it; otherwise log a
    /// warning and continue unpinned.
    IfSupported,
    /// Pin each thread to a CPU, panicking if pinning isn't supported or
    /// fails.
    Always,
}

/// Builder for a [`ThreadPool`].
pub struct ThreadPoolBuilder {
    /// Number of worker threads to spawn.
    pub num_threads: ThreadCount,
    /// Strategy to distribute work items among threads.
    pub range_strategy: RangeStrategy,
    /// Policy to pin worker threads to CPUs.
    pub cpu_pinning: CpuPinningPolicy,
}

impl ThreadPoolBuilder {
    /// Spawns a thread pool with this builder's configuration.
    pub fn build(&self) -> ThreadPool {
        ThreadPool::new(self)
    }
}
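// A minimal usage sketch of the builder, mirroring the tests at the bottom of
// this file (the parallel-iterator adaptors come from `crate::iter`):
//
//     use crate::iter::{IntoParallelRefSource, ParallelIteratorExt, ParallelSourceExt};
//
//     let mut thread_pool = ThreadPoolBuilder {
//         num_threads: ThreadCount::AvailableParallelism,
//         range_strategy: RangeStrategy::WorkStealing,
//         cpu_pinning: CpuPinningPolicy::No,
//     }
//     .build();
//
//     let input = [1, 2, 3, 4, 5];
//     let sum = input
//         .par_iter()
//         .with_thread_pool(&mut thread_pool)
//         .sum::<i32>();
//     assert_eq!(sum, 15);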

/// A thread pool that can execute parallel pipelines. The worker threads are
/// spawned when the pool is built and joined when it is dropped.
pub struct ThreadPool {
    inner: ThreadPoolEnum,
}

impl ThreadPool {
    /// Creates a new thread pool from the given builder configuration.
    fn new(builder: &ThreadPoolBuilder) -> Self {
        Self {
            inner: ThreadPoolEnum::new(builder),
        }
    }

    /// Returns the number of worker threads in this thread pool.
    pub fn num_threads(&self) -> NonZeroUsize {
        self.inner.num_threads()
    }

    /// Processes an input of the given length in parallel, with support for
    /// short-circuiting: when `process_item` returns [`ControlFlow::Break`],
    /// a shared upper bound is lowered so that all workers skip the items
    /// beyond it.
    pub(crate) fn upper_bounded_pipeline<Output: Send, Accum>(
        &mut self,
        input_len: usize,
        init: impl Fn() -> Accum + Sync,
        process_item: impl Fn(Accum, usize) -> ControlFlow<Accum, Accum> + Sync,
        finalize: impl Fn(Accum) -> Output + Sync,
        reduce: impl Fn(Output, Output) -> Output,
    ) -> Output {
        self.inner
            .upper_bounded_pipeline(input_len, init, process_item, finalize, reduce)
    }

    /// Processes an input of the given length in parallel, accumulating each
    /// worker's items with `accum` and combining the per-thread outputs with
    /// `reduce`.
    pub(crate) fn iter_pipeline<Output: Send>(
        &mut self,
        input_len: usize,
        accum: impl Accumulator<usize, Output> + Sync,
        reduce: impl Accumulator<Output, Output>,
    ) -> Output {
        self.inner.iter_pipeline(input_len, accum, reduce)
    }
}
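// A hypothetical in-crate sketch of how short-circuiting maps onto
// `upper_bounded_pipeline` (the `find_index` helper below is illustrative,
// not part of the actual API): search for the first index whose element
// matches, breaking to lower the shared bound so other workers stop early.
//
//     fn find_index(pool: &mut ThreadPool, haystack: &[i32], needle: i32) -> Option<usize> {
//         pool.upper_bounded_pipeline(
//             haystack.len(),
//             || None,
//             |acc, i| {
//                 if haystack[i] == needle {
//                     ControlFlow::Break(Some(i))
//                 } else {
//                     ControlFlow::Continue(acc)
//                 }
//             },
//             |acc| acc,
//             // Keep the smallest matching index across threads.
//             |a, b| match (a, b) {
//                 (Some(x), Some(y)) => Some(x.min(y)),
//                 (x, None) => x,
//                 (None, y) => y,
//             },
//         )
//     }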

/// Dispatches to the concrete [`ThreadPoolImpl`] instantiation for the
/// selected range strategy.
enum ThreadPoolEnum {
    Fixed(ThreadPoolImpl<FixedRangeFactory>),
    WorkStealing(ThreadPoolImpl<WorkStealingRangeFactory>),
}

impl ThreadPoolEnum {
    /// Creates a new thread pool from the given builder configuration.
    fn new(builder: &ThreadPoolBuilder) -> Self {
        let num_threads: NonZeroUsize = match builder.num_threads {
            ThreadCount::AvailableParallelism => std::thread::available_parallelism()
                .expect("Getting the available parallelism failed"),
            ThreadCount::Count(count) => count,
        };
        let num_threads: usize = num_threads.into();
        match builder.range_strategy {
            RangeStrategy::Fixed => ThreadPoolEnum::Fixed(ThreadPoolImpl::new(
                num_threads,
                FixedRangeFactory::new(num_threads),
                builder.cpu_pinning,
            )),
            RangeStrategy::WorkStealing => ThreadPoolEnum::WorkStealing(ThreadPoolImpl::new(
                num_threads,
                WorkStealingRangeFactory::new(num_threads),
                builder.cpu_pinning,
            )),
        }
    }

    /// Returns the number of worker threads in this thread pool.
    fn num_threads(&self) -> NonZeroUsize {
        match self {
            ThreadPoolEnum::Fixed(inner) => inner.num_threads(),
            ThreadPoolEnum::WorkStealing(inner) => inner.num_threads(),
        }
    }

    fn upper_bounded_pipeline<Output: Send, Accum>(
        &mut self,
        input_len: usize,
        init: impl Fn() -> Accum + Sync,
        process_item: impl Fn(Accum, usize) -> ControlFlow<Accum, Accum> + Sync,
        finalize: impl Fn(Accum) -> Output + Sync,
        reduce: impl Fn(Output, Output) -> Output,
    ) -> Output {
        match self {
            ThreadPoolEnum::Fixed(inner) => {
                inner.upper_bounded_pipeline(input_len, init, process_item, finalize, reduce)
            }
            ThreadPoolEnum::WorkStealing(inner) => {
                inner.upper_bounded_pipeline(input_len, init, process_item, finalize, reduce)
            }
        }
    }

    fn iter_pipeline<Output: Send>(
        &mut self,
        input_len: usize,
        accum: impl Accumulator<usize, Output> + Sync,
        reduce: impl Accumulator<Output, Output>,
    ) -> Output {
        match self {
            ThreadPoolEnum::Fixed(inner) => inner.iter_pipeline(input_len, accum, reduce),
            ThreadPoolEnum::WorkStealing(inner) => inner.iter_pipeline(input_len, accum, reduce),
        }
    }
}

struct ThreadPoolImpl<F: RangeFactory> {
    /// Handles to the worker threads in the pool.
    threads: Vec<WorkerThreadHandle>,
    /// Orchestrator for the work ranges distributed to the threads.
    range_orchestrator: F::Orchestrator,
    /// Lender end of the channel through which pipelines are lent to the
    /// worker threads.
    pipeline: Lender<DynLifetimeSyncPipeline<F::Range>>,
}

/// Handle to a worker thread.
struct WorkerThreadHandle {
    handle: JoinHandle<()>,
}

impl<F: RangeFactory> ThreadPoolImpl<F> {
    /// Spawns `num_threads` worker threads, each owning a [`ThreadContext`]
    /// with its work range and the borrower end of the pipeline channel.
    fn new(num_threads: usize, range_factory: F, cpu_pinning: CpuPinningPolicy) -> Self
    where
        F::Range: Send + 'static,
    {
        let (lender, borrowers) = make_lending_group(num_threads);

        // On platforms where pinning isn't implemented (or under Miri), apply
        // the pinning policy eagerly on the main thread: warn or panic.
        #[cfg(any(
            miri,
            not(any(
                target_os = "android",
                target_os = "dragonfly",
                target_os = "freebsd",
                target_os = "linux"
            ))
        ))]
        match cpu_pinning {
            CpuPinningPolicy::No => (),
            CpuPinningPolicy::IfSupported => {
                log_warn!("Pinning threads to CPUs is not implemented on this platform.")
            }
            CpuPinningPolicy::Always => {
                panic!("Pinning threads to CPUs is not implemented on this platform.")
            }
        }

        let threads = borrowers
            .into_iter()
            .enumerate()
            .map(|(id, borrower)| {
                let mut context = ThreadContext {
                    id,
                    range: range_factory.range(id),
                    pipeline: borrower,
                };
                WorkerThreadHandle {
                    handle: std::thread::spawn(move || {
                        // Pin this worker thread to CPU `id` if requested.
                        #[cfg(all(
                            not(miri),
                            any(
                                target_os = "android",
                                target_os = "dragonfly",
                                target_os = "freebsd",
                                target_os = "linux"
                            )
                        ))]
                        match cpu_pinning {
                            CpuPinningPolicy::No => (),
                            CpuPinningPolicy::IfSupported => {
                                let mut cpu_set = CpuSet::new();
                                if let Err(_e) = cpu_set.set(id) {
                                    log_warn!("Failed to set CPU affinity for thread #{id}: {_e}");
                                } else if let Err(_e) =
                                    sched_setaffinity(Pid::from_raw(0), &cpu_set)
                                {
                                    log_warn!("Failed to set CPU affinity for thread #{id}: {_e}");
                                } else {
                                    log_debug!("Pinned thread #{id} to CPU #{id}");
                                }
                            }
                            CpuPinningPolicy::Always => {
                                let mut cpu_set = CpuSet::new();
                                if let Err(e) = cpu_set.set(id) {
                                    panic!("Failed to set CPU affinity for thread #{id}: {e}");
                                } else if let Err(e) = sched_setaffinity(Pid::from_raw(0), &cpu_set)
                                {
                                    panic!("Failed to set CPU affinity for thread #{id}: {e}");
                                } else {
                                    log_debug!("Pinned thread #{id} to CPU #{id}");
                                }
                            }
                        }
                        context.run()
                    }),
                }
            })
            .collect();
        log_debug!("[main thread] Spawned threads");

        Self {
            threads,
            range_orchestrator: range_factory.orchestrator(),
            pipeline: lender,
        }
    }

    /// Returns the number of worker threads in this thread pool.
    fn num_threads(&self) -> NonZeroUsize {
        self.threads.len().try_into().unwrap()
    }

    /// Runs the upper-bounded pipeline on the worker threads: each worker
    /// folds the items of its range with `process_item`, finalizes its
    /// accumulator into an output, and the per-thread outputs are then
    /// reduced. A worker that breaks lowers the shared atomic bound so that
    /// all workers skip the remaining items above it.
    fn upper_bounded_pipeline<Output: Send, Accum>(
        &mut self,
        input_len: usize,
        init: impl Fn() -> Accum + Sync,
        process_item: impl Fn(Accum, usize) -> ControlFlow<Accum, Accum> + Sync,
        finalize: impl Fn(Accum) -> Output + Sync,
        reduce: impl Fn(Output, Output) -> Output,
    ) -> Output {
        self.range_orchestrator.reset_ranges(input_len);

        let num_threads = self.threads.len();
        let outputs = (0..num_threads)
            .map(|_| Mutex::new(None))
            .collect::<Arc<[_]>>();
        // The bound starts at `usize::MAX`, i.e. no upper bound.
        let bound = AtomicUsize::new(usize::MAX);

        self.pipeline.lend(&UpperBoundedPipelineImpl {
            bound: CachePadded::new(bound),
            outputs: outputs.clone(),
            init,
            process_item,
            finalize,
        });

        outputs
            .iter()
            .map(move |output| output.lock().unwrap().take().unwrap())
            .reduce(reduce)
            .unwrap()
    }

    /// Runs the iterator pipeline on the worker threads: each worker
    /// accumulates the items of its range with `accum`, and the per-thread
    /// outputs are then combined with `reduce`.
    fn iter_pipeline<Output: Send>(
        &mut self,
        input_len: usize,
        accum: impl Accumulator<usize, Output> + Sync,
        reduce: impl Accumulator<Output, Output>,
    ) -> Output {
        self.range_orchestrator.reset_ranges(input_len);

        let num_threads = self.threads.len();
        let outputs = (0..num_threads)
            .map(|_| Mutex::new(None))
            .collect::<Arc<[_]>>();

        self.pipeline.lend(&IterPipelineImpl {
            outputs: outputs.clone(),
            accum,
        });

        reduce.accumulate(
            outputs
                .iter()
                .map(move |output| output.lock().unwrap().take().unwrap()),
        )
    }
}
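// A hypothetical in-crate sketch of `iter_pipeline`, assuming an
// `Accumulator<I, O>` implementation shaped like the calls above (the
// `SumAcc` type is illustrative only, not part of the crate):
//
//     struct SumAcc;
//     impl Accumulator<usize, usize> for SumAcc {
//         fn accumulate(&self, iter: impl Iterator<Item = usize>) -> usize {
//             iter.sum()
//         }
//     }
//
//     // Sums the indices 0..100 across all worker threads; each worker sums
//     // its own range, then the per-thread sums are summed again.
//     let total = pool.iter_pipeline(100, SumAcc, SumAcc);
//     assert_eq!(total, 4950);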

impl<F: RangeFactory> Drop for ThreadPoolImpl<F> {
    /// Signals all worker threads to finish and joins them.
    #[allow(clippy::single_match, clippy::unused_enumerate_index)]
    fn drop(&mut self) {
        self.pipeline.finish_workers();

        log_debug!("[main thread] Joining threads in the pool...");
        for (_i, t) in self.threads.drain(..).enumerate() {
            let result = t.handle.join();
            match result {
                Ok(_) => log_debug!("[main thread] Thread {_i} joined with result: {result:?}"),
                Err(_) => log_error!("[main thread] Thread {_i} joined with result: {result:?}"),
            }
        }
        log_debug!("[main thread] Joined threads.");

        #[cfg(feature = "log_parallelism")]
        self.range_orchestrator.print_statistics();
    }
}

/// A unit of work that a worker thread runs over its range of items.
trait Pipeline<R: Range> {
    fn run(&self, worker_id: usize, range: &R);
}

/// Marker type implementing [`LifetimeParameterized`], so that pipelines of
/// any (arbitrarily short) lifetime can be lent to the worker threads as
/// `dyn Pipeline<R> + Sync` trait objects.
struct DynLifetimeSyncPipeline<R: Range>(PhantomData<R>);

impl<R: Range> LifetimeParameterized for DynLifetimeSyncPipeline<R> {
    type T<'a> = dyn Pipeline<R> + Sync + 'a;
}

struct UpperBoundedPipelineImpl<
    Output,
    Accum,
    Init: Fn() -> Accum,
    ProcessItem: Fn(Accum, usize) -> ControlFlow<Accum, Accum>,
    Finalize: Fn(Accum) -> Output,
> {
    /// Shared upper bound on the indices left to process, lowered whenever a
    /// worker breaks. Cache-padded to avoid false sharing between workers.
    bound: CachePadded<AtomicUsize>,
    /// One output slot per worker thread.
    outputs: Arc<[Mutex<Option<Output>>]>,
    init: Init,
    process_item: ProcessItem,
    finalize: Finalize,
}

impl<R, Output, Accum, Init, ProcessItem, Finalize> Pipeline<R>
    for UpperBoundedPipelineImpl<Output, Accum, Init, ProcessItem, Finalize>
where
    R: Range,
    Init: Fn() -> Accum,
    ProcessItem: Fn(Accum, usize) -> ControlFlow<Accum, Accum>,
    Finalize: Fn(Accum) -> Output,
{
    fn run(&self, worker_id: usize, range: &R) {
        let mut accumulator = (self.init)();
        for i in range.upper_bounded_iter(&self.bound) {
            let acc = (self.process_item)(accumulator, i);
            accumulator = match acc {
                ControlFlow::Continue(acc) => acc,
                ControlFlow::Break(acc) => {
                    // Lower the shared bound so that all workers (including
                    // this one) stop yielding indices beyond this item.
                    self.bound.fetch_min(i, Ordering::Relaxed);
                    acc
                }
            };
        }
        let output = (self.finalize)(accumulator);
        *self.outputs[worker_id].lock().unwrap() = Some(output);
    }
}
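// Worked example of the bound interplay: with two workers, if worker #0
// breaks at index 5, it lowers `bound` from `usize::MAX` to 5 via
// `fetch_min`; worker #1's `upper_bounded_iter` then observes the new bound
// and stops yielding indices past it, so both workers reach `finalize`
// without visiting the rest of the input. Indices below the bound are still
// yielded, so an even earlier break (say at index 3) can lower it further.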

struct IterPipelineImpl<Output, Accum: Accumulator<usize, Output>> {
    /// One output slot per worker thread.
    outputs: Arc<[Mutex<Option<Output>>]>,
    accum: Accum,
}

impl<R, Output, Accum> Pipeline<R> for IterPipelineImpl<Output, Accum>
where
    R: Range,
    Accum: Accumulator<usize, Output>,
{
    fn run(&self, worker_id: usize, range: &R) {
        let output = self.accum.accumulate(range.iter());
        *self.outputs[worker_id].lock().unwrap() = Some(output);
    }
}

/// Context owned by a worker thread.
struct ThreadContext<R: Range> {
    /// Thread index, in `0..num_threads`.
    id: usize,
    /// Range of items that this thread processes.
    range: R,
    /// Borrower end of the channel through which pipelines are received.
    pipeline: Borrower<DynLifetimeSyncPipeline<R>>,
}

impl<R: Range> ThreadContext<R> {
    /// Main loop of a worker thread: repeatedly borrow the next pipeline and
    /// run it over this thread's range, until the pool signals termination.
    fn run(&mut self) {
        loop {
            match self.pipeline.borrow(|pipeline| {
                pipeline.run(self.id, &self.range);
            }) {
                WorkerState::Finished => break,
                WorkerState::Ready => continue,
            }
        }
    }
}
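// Lifecycle sketch of the lend/borrow protocol as used above: the main
// thread calls `Lender::lend()` once per parallel pipeline, each worker's
// `Borrower::borrow()` runs that pipeline and yields `WorkerState::Ready`
// for the next round, and `finish_workers()` (called in `Drop`) makes
// `borrow()` return `WorkerState::Finished`, breaking the worker loop.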

#[cfg(test)]
mod test {
    use super::*;
    use crate::iter::{IntoParallelRefSource, ParallelIteratorExt, ParallelSourceExt};

    #[test]
    fn test_thread_count_try_from_usize() {
        assert!(ThreadCount::try_from(0).is_err());
        assert_eq!(
            ThreadCount::try_from(1),
            Ok(ThreadCount::Count(NonZeroUsize::try_from(1).unwrap()))
        );
    }

    #[test]
    fn test_build_thread_pool_available_parallelism() {
        let mut thread_pool = ThreadPoolBuilder {
            num_threads: ThreadCount::AvailableParallelism,
            range_strategy: RangeStrategy::Fixed,
            cpu_pinning: CpuPinningPolicy::No,
        }
        .build();

        let input = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let sum = input
            .par_iter()
            .with_thread_pool(&mut thread_pool)
            .sum::<i32>();

        assert_eq!(sum, 5 * 11);
    }

    #[test]
    fn test_build_thread_pool_fixed_thread_count() {
        let mut thread_pool = ThreadPoolBuilder {
            num_threads: ThreadCount::try_from(4).unwrap(),
            range_strategy: RangeStrategy::Fixed,
            cpu_pinning: CpuPinningPolicy::No,
        }
        .build();

        let input = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let sum = input
            .par_iter()
            .with_thread_pool(&mut thread_pool)
            .sum::<i32>();

        assert_eq!(sum, 5 * 11);
    }

    #[test]
    fn test_build_thread_pool_cpu_pinning_if_supported() {
        let mut thread_pool = ThreadPoolBuilder {
            num_threads: ThreadCount::AvailableParallelism,
            range_strategy: RangeStrategy::Fixed,
            cpu_pinning: CpuPinningPolicy::IfSupported,
        }
        .build();

        let input = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let sum = input
            .par_iter()
            .with_thread_pool(&mut thread_pool)
            .sum::<i32>();

        assert_eq!(sum, 5 * 11);
    }

    #[cfg(all(
        not(miri),
        any(
            target_os = "android",
            target_os = "dragonfly",
            target_os = "freebsd",
            target_os = "linux"
        )
    ))]
    #[test]
    fn test_build_thread_pool_cpu_pinning_always() {
        let mut thread_pool = ThreadPoolBuilder {
            num_threads: ThreadCount::AvailableParallelism,
            range_strategy: RangeStrategy::Fixed,
            cpu_pinning: CpuPinningPolicy::Always,
        }
        .build();

        let input = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let sum = input
            .par_iter()
            .with_thread_pool(&mut thread_pool)
            .sum::<i32>();

        assert_eq!(sum, 5 * 11);
    }

    #[cfg(any(
        miri,
        not(any(
            target_os = "android",
            target_os = "dragonfly",
            target_os = "freebsd",
            target_os = "linux"
        ))
    ))]
    #[test]
    #[should_panic = "Pinning threads to CPUs is not implemented on this platform."]
    fn test_build_thread_pool_cpu_pinning_always_not_supported() {
        ThreadPoolBuilder {
            num_threads: ThreadCount::AvailableParallelism,
            range_strategy: RangeStrategy::Fixed,
            cpu_pinning: CpuPinningPolicy::Always,
        }
        .build();
    }

    #[test]
    fn test_num_threads() {
        let thread_pool = ThreadPoolBuilder {
            num_threads: ThreadCount::AvailableParallelism,
            range_strategy: RangeStrategy::Fixed,
            cpu_pinning: CpuPinningPolicy::No,
        }
        .build();
        assert_eq!(
            thread_pool.num_threads(),
            std::thread::available_parallelism().unwrap()
        );

        let thread_pool = ThreadPoolBuilder {
            num_threads: ThreadCount::try_from(4).unwrap(),
            range_strategy: RangeStrategy::Fixed,
            cpu_pinning: CpuPinningPolicy::No,
        }
        .build();
        assert_eq!(
            thread_pool.num_threads(),
            NonZeroUsize::try_from(4).unwrap()
        );
    }
}