1#[cfg(not(feature = "std"))]
9use alloc::vec;
10#[cfg(not(feature = "std"))]
11use alloc::vec::Vec;
12
13use core::sync::atomic::{AtomicBool, Ordering};
14
15#[cfg(feature = "parallel")]
16use rayon::prelude::*;
17
/// Process-wide kill-switch: when `true`, every dispatch helper in this
/// module falls back to sequential execution regardless of the requested
/// [`Par`] strategy.
static PARALLELISM_DISABLED: AtomicBool = AtomicBool::new(false);

/// Globally disables parallel execution until [`enable_global_parallelism`]
/// is called. `SeqCst` keeps the flag totally ordered across threads.
pub fn disable_global_parallelism() {
    PARALLELISM_DISABLED.store(true, Ordering::SeqCst);
}

/// Re-enables parallel execution after [`disable_global_parallelism`].
pub fn enable_global_parallelism() {
    PARALLELISM_DISABLED.store(false, Ordering::SeqCst);
}

/// Returns `true` unless parallelism has been globally disabled.
pub fn is_parallelism_enabled() -> bool {
    !PARALLELISM_DISABLED.load(Ordering::SeqCst)
}
38
/// Execution-strategy selector for the dispatch helpers in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Par {
    /// Always run sequentially on the calling thread.
    Seq,
    /// Run on rayon's global thread pool.
    #[cfg(feature = "parallel")]
    Rayon,
    /// Run with a caller-advised thread count (used by [`Par::num_threads`]
    /// when sizing work partitions).
    #[cfg(feature = "parallel")]
    RayonWith(usize),
}
51
// Manual impl: the default variant depends on whether the "parallel"
// feature is compiled in, which `#[derive(Default)]` cannot express.
#[allow(clippy::derivable_impls)]
impl Default for Par {
    fn default() -> Self {
        #[cfg(feature = "parallel")]
        {
            Par::Rayon
        }
        #[cfg(not(feature = "parallel"))]
        {
            Par::Seq
        }
    }
}
67
impl Par {
    /// `true` when this strategy will execute sequentially — either it is
    /// [`Par::Seq`] or parallelism was globally disabled via
    /// [`disable_global_parallelism`].
    #[inline]
    pub fn is_sequential(&self) -> bool {
        match self {
            Par::Seq => true,
            // Parallel variants degrade to sequential while the global
            // kill-switch is set.
            #[cfg(feature = "parallel")]
            _ => !is_parallelism_enabled(),
        }
    }

    /// Number of threads this strategy would currently use; 1 whenever
    /// execution is (or is forced to be) sequential.
    #[cfg(feature = "parallel")]
    pub fn num_threads(&self) -> usize {
        if !is_parallelism_enabled() {
            return 1;
        }

        match self {
            Par::Seq => 1,
            Par::Rayon => rayon::current_num_threads(),
            // Advisory count supplied by the caller.
            Par::RayonWith(n) => *n,
        }
    }

    /// Without the "parallel" feature everything runs on one thread.
    #[cfg(not(feature = "parallel"))]
    pub fn num_threads(&self) -> usize {
        1
    }
}
99
/// Tuning knobs deciding when a workload is large enough to parallelize.
#[derive(Debug, Clone, Copy)]
pub struct ParThreshold {
    /// Minimum total element count before parallelism is considered.
    pub min_elements: usize,
    /// Minimum elements each thread must receive for a parallel run.
    pub min_work_per_thread: usize,
}
108
109impl Default for ParThreshold {
110 fn default() -> Self {
111 ParThreshold {
112 min_elements: 4096,
113 min_work_per_thread: 256,
114 }
115 }
116}
117
118impl ParThreshold {
119 pub const fn new(min_elements: usize, min_work_per_thread: usize) -> Self {
121 ParThreshold {
122 min_elements,
123 min_work_per_thread,
124 }
125 }
126
127 #[inline]
129 pub fn should_parallelize(&self, total_work: usize, par: Par) -> bool {
130 if par.is_sequential() {
131 return false;
132 }
133
134 if total_work < self.min_elements {
135 return false;
136 }
137
138 let threads = par.num_threads();
139 if threads <= 1 {
140 return false;
141 }
142
143 total_work / threads >= self.min_work_per_thread
144 }
145}
146
/// A half-open index range `[start, end)` assigned to one worker.
#[derive(Debug, Clone, Copy)]
pub struct WorkRange {
    /// First index in the range (inclusive).
    pub start: usize,
    /// One past the last index in the range (exclusive).
    pub end: usize,
}
155
156impl WorkRange {
157 #[inline]
159 pub const fn new(start: usize, end: usize) -> Self {
160 WorkRange { start, end }
161 }
162
163 #[inline]
165 pub const fn len(&self) -> usize {
166 self.end - self.start
167 }
168
169 #[inline]
171 pub const fn is_empty(&self) -> bool {
172 self.start >= self.end
173 }
174}
175
176pub fn partition_work(total: usize, num_threads: usize) -> Vec<WorkRange> {
178 if num_threads == 0 || total == 0 {
179 return vec![];
180 }
181
182 if num_threads == 1 {
183 return vec![WorkRange::new(0, total)];
184 }
185
186 let chunk_size = total.div_ceil(num_threads);
187 let mut ranges = Vec::with_capacity(num_threads);
188
189 let mut start = 0;
190 while start < total {
191 let end = (start + chunk_size).min(total);
192 ranges.push(WorkRange::new(start, end));
193 start = end;
194 }
195
196 ranges
197}
198
/// Runs `f` over `0..total` as one or more [`WorkRange`]s: a single range on
/// the calling thread when `threshold` rejects parallelism, otherwise one
/// range per worker dispatched through rayon.
#[inline]
pub fn for_each_range<F>(total: usize, par: Par, threshold: &ParThreshold, f: F)
where
    F: Fn(WorkRange) + Send + Sync,
{
    if !threshold.should_parallelize(total, par) {
        // Below threshold (or parallelism unavailable/disabled): one range.
        f(WorkRange::new(0, total));
        return;
    }

    #[cfg(feature = "parallel")]
    {
        let ranges = partition_work(total, par.num_threads());
        ranges.into_par_iter().for_each(|range| {
            f(range);
        });
    }

    // Without "parallel", `should_parallelize` is always false and the early
    // return above is taken, so this branch is unreachable in practice.
    #[cfg(not(feature = "parallel"))]
    {
        f(WorkRange::new(0, total));
    }
}
225
/// Maps each partitioned [`WorkRange`] through `map` and folds the results
/// with `reduce`.
///
/// `identity` must be a neutral element for `reduce` (`reduce(identity, x)
/// == x`): the sequential path never touches it, while the parallel path
/// folds clones of it into partial results.
#[allow(unused_variables)] // `identity` is unused in non-"parallel" builds
pub fn map_reduce<T, Map, Reduce>(
    total: usize,
    par: Par,
    threshold: &ParThreshold,
    identity: T,
    map: Map,
    reduce: Reduce,
) -> T
where
    T: Clone + Send + Sync,
    Map: Fn(WorkRange) -> T + Send + Sync,
    Reduce: Fn(T, T) -> T + Send + Sync,
{
    if !threshold.should_parallelize(total, par) {
        // Single range covers everything; no reduction step needed.
        return map(WorkRange::new(0, total));
    }

    #[cfg(feature = "parallel")]
    {
        let ranges = partition_work(total, par.num_threads());
        ranges
            .into_par_iter()
            .map(map)
            .reduce(|| identity.clone(), reduce)
    }

    // Unreachable in practice: without "parallel", `should_parallelize`
    // always returns false and the early return above is taken.
    #[cfg(not(feature = "parallel"))]
    {
        map(WorkRange::new(0, total))
    }
}
261
/// Calls `f(i)` for every `i` in `0..total`, in parallel when `threshold`
/// deems the workload large enough; iteration order is unspecified in the
/// parallel case.
pub fn for_each_indexed<F>(total: usize, par: Par, threshold: &ParThreshold, f: F)
where
    F: Fn(usize) + Send + Sync,
{
    if !threshold.should_parallelize(total, par) {
        for i in 0..total {
            f(i);
        }
        return;
    }

    #[cfg(feature = "parallel")]
    {
        (0..total).into_par_iter().for_each(f);
    }

    // Unreachable in practice: without "parallel", `should_parallelize`
    // always returns false and the early return above is taken.
    #[cfg(not(feature = "parallel"))]
    {
        for i in 0..total {
            f(i);
        }
    }
}
286
/// Minimal execution-pool abstraction shared by the sequential and rayon
/// backends in this module.
pub trait ThreadPool: Send + Sync {
    /// Number of worker threads available to this pool.
    fn num_threads(&self) -> usize;

    /// Queues `f` for execution. Fire-and-forget for pooled backends;
    /// sequential implementations may run it inline.
    fn execute<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static;

    /// Runs `a` and `b`, potentially in parallel, returning both results.
    fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
    where
        A: FnOnce() -> RA + Send,
        B: FnOnce() -> RB + Send,
        RA: Send,
        RB: Send;

    /// Calls `f(i)` for every `i` in `range`, in unspecified order.
    fn for_each<F>(&self, range: core::ops::Range<usize>, f: F)
    where
        F: Fn(usize) + Send + Sync;

    /// Maps every index in `range` through `map` and folds the results with
    /// `reduce`. `identity` must be a neutral element for `reduce`.
    fn map_reduce<T, Map, Reduce>(
        &self,
        range: core::ops::Range<usize>,
        identity: T,
        map: Map,
        reduce: Reduce,
    ) -> T
    where
        T: Clone + Send + Sync,
        Map: Fn(usize) -> T + Send + Sync,
        Reduce: Fn(T, T) -> T + Send + Sync;
}
329
/// [`ThreadPool`] that executes everything inline on the calling thread.
#[derive(Debug, Clone, Copy, Default)]
pub struct SequentialPool;
333
334impl ThreadPool for SequentialPool {
335 #[inline]
336 fn num_threads(&self) -> usize {
337 1
338 }
339
340 #[inline]
341 fn execute<F>(&self, f: F)
342 where
343 F: FnOnce() + Send + 'static,
344 {
345 f();
346 }
347
348 #[inline]
349 fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
350 where
351 A: FnOnce() -> RA + Send,
352 B: FnOnce() -> RB + Send,
353 RA: Send,
354 RB: Send,
355 {
356 (a(), b())
357 }
358
359 fn for_each<F>(&self, range: core::ops::Range<usize>, f: F)
360 where
361 F: Fn(usize) + Send + Sync,
362 {
363 for i in range {
364 f(i);
365 }
366 }
367
368 fn map_reduce<T, Map, Reduce>(
369 &self,
370 range: core::ops::Range<usize>,
371 identity: T,
372 map: Map,
373 reduce: Reduce,
374 ) -> T
375 where
376 T: Clone + Send + Sync,
377 Map: Fn(usize) -> T + Send + Sync,
378 Reduce: Fn(T, T) -> T + Send + Sync,
379 {
380 let mut acc = identity;
381 for i in range {
382 acc = reduce(acc, map(i));
383 }
384 acc
385 }
386}
387
/// [`ThreadPool`] backed by rayon's global thread pool.
#[cfg(feature = "parallel")]
#[derive(Debug, Clone, Copy, Default)]
pub struct RayonGlobalPool;
392
#[cfg(feature = "parallel")]
impl ThreadPool for RayonGlobalPool {
    /// Size of rayon's global pool.
    #[inline]
    fn num_threads(&self) -> usize {
        rayon::current_num_threads()
    }

    /// Fire-and-forget: the task is queued on the global pool.
    #[inline]
    fn execute<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        rayon::spawn(f);
    }

    /// Delegates to `rayon::join`, which may run `a` and `b` concurrently.
    #[inline]
    fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
    where
        A: FnOnce() -> RA + Send,
        B: FnOnce() -> RB + Send,
        RA: Send,
        RB: Send,
    {
        rayon::join(a, b)
    }

    /// Parallel iteration over the index range; order is unspecified.
    fn for_each<F>(&self, range: core::ops::Range<usize>, f: F)
    where
        F: Fn(usize) + Send + Sync,
    {
        range.into_par_iter().for_each(f);
    }

    /// Parallel map + reduce; `identity` is cloned per rayon split, so it
    /// must be a neutral element for `reduce`.
    fn map_reduce<T, Map, Reduce>(
        &self,
        range: core::ops::Range<usize>,
        identity: T,
        map: Map,
        reduce: Reduce,
    ) -> T
    where
        T: Clone + Send + Sync,
        Map: Fn(usize) -> T + Send + Sync,
        Reduce: Fn(T, T) -> T + Send + Sync,
    {
        range
            .into_par_iter()
            .map(map)
            .reduce(|| identity.clone(), reduce)
    }
}
444
/// [`ThreadPool`] owning a dedicated rayon pool with its own configuration,
/// independent of rayon's global pool.
#[cfg(feature = "parallel")]
pub struct CustomRayonPool {
    pool: rayon::ThreadPool,
}
450
#[cfg(feature = "parallel")]
impl CustomRayonPool {
    /// Builds a pool with exactly `num_threads` workers.
    ///
    /// # Errors
    /// Propagates rayon's build error if the pool cannot be created.
    pub fn new(num_threads: usize) -> Result<Self, rayon::ThreadPoolBuildError> {
        let pool = rayon::ThreadPoolBuilder::new()
            .num_threads(num_threads)
            .build()?;
        Ok(CustomRayonPool { pool })
    }

    /// Alias for [`CustomRayonPool::new`], kept for readability at call sites.
    pub fn with_num_threads(n: usize) -> Result<Self, rayon::ThreadPoolBuildError> {
        Self::new(n)
    }

    /// Builds a pool from a caller-configured `rayon::ThreadPoolBuilder`.
    pub fn with_builder<F>(configure: F) -> Result<Self, rayon::ThreadPoolBuildError>
    where
        F: FnOnce(rayon::ThreadPoolBuilder) -> rayon::ThreadPoolBuilder,
    {
        let builder = rayon::ThreadPoolBuilder::new();
        let pool = configure(builder).build()?;
        Ok(CustomRayonPool { pool })
    }

    /// Runs `f` inside this pool, so rayon parallel iterators used by `f`
    /// execute on this pool's workers rather than the global pool.
    pub fn install<R, F>(&self, f: F) -> R
    where
        F: FnOnce() -> R + Send,
        R: Send,
    {
        self.pool.install(f)
    }
}
488
#[cfg(feature = "parallel")]
impl ThreadPool for CustomRayonPool {
    /// Worker count of the owned pool.
    #[inline]
    fn num_threads(&self) -> usize {
        self.pool.current_num_threads()
    }

    /// Fire-and-forget: queues the task on the owned pool.
    #[inline]
    fn execute<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        self.pool.spawn(f);
    }

    /// Runs `a` and `b`, potentially concurrently, on the owned pool.
    #[inline]
    fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
    where
        A: FnOnce() -> RA + Send,
        B: FnOnce() -> RB + Send,
        RA: Send,
        RB: Send,
    {
        self.pool.join(a, b)
    }

    /// Parallel iteration on the owned pool (`install` redirects the
    /// parallel iterator away from the global pool).
    fn for_each<F>(&self, range: core::ops::Range<usize>, f: F)
    where
        F: Fn(usize) + Send + Sync,
    {
        self.pool.install(|| {
            range.into_par_iter().for_each(f);
        });
    }

    /// Parallel map + reduce on the owned pool; `identity` is cloned per
    /// rayon split, so it must be a neutral element for `reduce`.
    fn map_reduce<T, Map, Reduce>(
        &self,
        range: core::ops::Range<usize>,
        identity: T,
        map: Map,
        reduce: Reduce,
    ) -> T
    where
        T: Clone + Send + Sync,
        Map: Fn(usize) -> T + Send + Sync,
        Reduce: Fn(T, T) -> T + Send + Sync,
    {
        self.pool.install(|| {
            range
                .into_par_iter()
                .map(map)
                .reduce(|| identity.clone(), reduce)
        })
    }
}
544
/// Pairs a borrowed [`ThreadPool`] with a [`ParThreshold`], so callers get
/// threshold-aware dispatch without repeating the size checks.
pub struct PoolScope<'a, P: ThreadPool> {
    pool: &'a P,
    threshold: ParThreshold,
}
552
553impl<'a, P: ThreadPool> PoolScope<'a, P> {
554 pub fn new(pool: &'a P) -> Self {
556 PoolScope {
557 pool,
558 threshold: ParThreshold::default(),
559 }
560 }
561
562 pub fn with_threshold(pool: &'a P, threshold: ParThreshold) -> Self {
564 PoolScope { pool, threshold }
565 }
566
567 #[inline]
569 pub fn num_threads(&self) -> usize {
570 self.pool.num_threads()
571 }
572
573 #[inline]
575 pub fn join<A, B, RA, RB>(&self, a: A, b: B) -> (RA, RB)
576 where
577 A: FnOnce() -> RA + Send,
578 B: FnOnce() -> RB + Send,
579 RA: Send,
580 RB: Send,
581 {
582 self.pool.join(a, b)
583 }
584
585 pub fn for_each<F>(&self, total: usize, f: F)
587 where
588 F: Fn(usize) + Send + Sync,
589 {
590 if total < self.threshold.min_elements || self.pool.num_threads() <= 1 {
591 for i in 0..total {
592 f(i);
593 }
594 } else {
595 self.pool.for_each(0..total, f);
596 }
597 }
598
599 pub fn for_each_range<F>(&self, total: usize, f: F)
601 where
602 F: Fn(WorkRange) + Send + Sync,
603 {
604 if total < self.threshold.min_elements || self.pool.num_threads() <= 1 {
605 f(WorkRange::new(0, total));
606 } else {
607 let ranges = partition_work(total, self.pool.num_threads());
608 for range in ranges {
609 f(range);
610 }
611 }
612 }
613
614 pub fn map_reduce<T, Map, Reduce>(
616 &self,
617 total: usize,
618 identity: T,
619 map: Map,
620 reduce: Reduce,
621 ) -> T
622 where
623 T: Clone + Send + Sync,
624 Map: Fn(usize) -> T + Send + Sync,
625 Reduce: Fn(T, T) -> T + Send + Sync,
626 {
627 if total < self.threshold.min_elements || self.pool.num_threads() <= 1 {
628 let mut acc = identity;
629 for i in 0..total {
630 acc = reduce(acc, map(i));
631 }
632 acc
633 } else {
634 self.pool.map_reduce(0..total, identity, map, reduce)
635 }
636 }
637}
638
/// Default pool for this build: rayon's global pool when the "parallel"
/// feature is enabled.
#[cfg(feature = "parallel")]
pub fn default_pool() -> RayonGlobalPool {
    RayonGlobalPool
}

/// Default pool for this build: inline sequential execution when the
/// "parallel" feature is disabled.
#[cfg(not(feature = "parallel"))]
pub fn default_pool() -> SequentialPool {
    SequentialPool
}
650
/// Runs `f` with a [`PoolScope`] over this build's default pool (rayon's
/// global pool) and returns its result.
#[cfg(feature = "parallel")]
pub fn with_default_pool<R, F>(f: F) -> R
where
    F: FnOnce(PoolScope<'_, RayonGlobalPool>) -> R,
{
    let pool = RayonGlobalPool;
    f(PoolScope::new(&pool))
}

/// Runs `f` with a [`PoolScope`] over this build's default pool (the
/// sequential pool) and returns its result.
#[cfg(not(feature = "parallel"))]
pub fn with_default_pool<R, F>(f: F) -> R
where
    F: FnOnce(PoolScope<'_, SequentialPool>) -> R,
{
    let pool = SequentialPool;
    f(PoolScope::new(&pool))
}
672
/// Builder-style configuration for constructing a worker pool.
///
/// NOTE(review): this type is not gated on `feature = "std"` even though it
/// stores a `String` and its impl calls `std::thread` — confirm whether
/// no-`std` builds (see the `alloc` imports at the top of this file) are
/// ever expected to compile this item.
#[derive(Debug, Clone, Default)]
pub struct OxiblasThreadConfig {
    /// Requested worker count; `0` means "use available parallelism".
    pub num_threads: usize,
    /// Per-thread stack size in bytes; `0` keeps the builder's default.
    pub stack_size: usize,
    /// Optional base name; workers are named `"{name}-{index}"`.
    pub thread_name: Option<String>,
}
703
impl OxiblasThreadConfig {
    /// Starts from the all-zero/empty default configuration.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the requested worker count (`0` = auto-detect).
    pub fn num_threads(mut self, n: usize) -> Self {
        self.num_threads = n;
        self
    }

    /// Sets the per-thread stack size in bytes (`0` = builder default).
    pub fn stack_size(mut self, bytes: usize) -> Self {
        self.stack_size = bytes;
        self
    }

    /// Sets the base thread name; workers become `"{name}-{index}"`.
    pub fn thread_name(mut self, name: impl Into<String>) -> Self {
        self.thread_name = Some(name.into());
        self
    }

    /// Resolves the actual thread count: an explicit non-zero setting wins,
    /// otherwise the OS-reported available parallelism (falling back to 1).
    // NOTE(review): uses `std::thread` without a `feature = "std"` gate —
    // confirm this item is std-only despite the `alloc` imports above.
    pub fn effective_threads(&self) -> usize {
        if self.num_threads == 0 {
            std::thread::available_parallelism()
                .map(|n| n.get())
                .unwrap_or(1)
        } else {
            self.num_threads
        }
    }

    /// Builds a [`CustomRayonPool`] from this configuration.
    ///
    /// # Errors
    /// Propagates rayon's build error if pool creation fails.
    #[cfg(feature = "parallel")]
    pub fn build_pool(&self) -> Result<CustomRayonPool, rayon::ThreadPoolBuildError> {
        let mut builder = rayon::ThreadPoolBuilder::new().num_threads(self.effective_threads());
        if self.stack_size > 0 {
            builder = builder.stack_size(self.stack_size);
        }
        if let Some(name) = &self.thread_name {
            // Clone so the naming closure can own the base name ('static).
            let name = name.clone();
            builder = builder.thread_name(move |i| format!("{name}-{i}"));
        }
        let pool = builder.build()?;
        Ok(CustomRayonPool { pool })
    }
}
757
/// Object-safe facade over the pool types, so one global slot can hold
/// whichever pool the build configuration provides.
#[cfg(feature = "std")]
trait AnyPool: Send + Sync {
    /// Dynamically-dispatched thread count of the held pool.
    fn num_threads_dyn(&self) -> usize;
}

#[cfg(all(feature = "std", feature = "parallel"))]
impl AnyPool for CustomRayonPool {
    fn num_threads_dyn(&self) -> usize {
        self.num_threads()
    }
}

#[cfg(feature = "std")]
impl AnyPool for SequentialPool {
    fn num_threads_dyn(&self) -> usize {
        1
    }
}

/// Process-wide pool slot; written at most once (see `set_global_thread_pool`).
#[cfg(feature = "std")]
static GLOBAL_POOL: std::sync::OnceLock<Box<dyn AnyPool>> = std::sync::OnceLock::new();
785
/// Installs `pool` as the process-wide pool.
///
/// Only the first call has any effect: `OnceLock::set` fails on later calls
/// and the error is deliberately discarded, so subsequent pools are silently
/// dropped.
#[cfg(all(feature = "std", feature = "parallel"))]
pub fn set_global_thread_pool(pool: CustomRayonPool) {
    let _ = GLOBAL_POOL.set(Box::new(pool));
}

/// Sequential-build variant of [`set_global_thread_pool`]; same
/// first-call-wins semantics.
#[cfg(all(feature = "std", not(feature = "parallel")))]
pub fn set_global_thread_pool(pool: SequentialPool) {
    let _ = GLOBAL_POOL.set(Box::new(pool));
}

/// Thread count of the installed global pool, or 1 if none was installed.
#[cfg(feature = "std")]
pub fn global_num_threads() -> usize {
    GLOBAL_POOL.get().map(|p| p.num_threads_dyn()).unwrap_or(1)
}
826
/// Runs `f` inside a temporary rayon pool of `n` threads.
///
/// If the pool cannot be built, `f` still runs — on the calling thread — so
/// callers must not rely on the thread count being exactly `n`.
#[cfg(feature = "parallel")]
pub fn with_thread_count(n: usize, f: impl FnOnce() + Send) {
    let pool = rayon::ThreadPoolBuilder::new().num_threads(n).build();
    match pool {
        Ok(p) => p.install(f),
        Err(_) => f(),
    }
}

/// Sequential fallback: without the "parallel" feature the requested count
/// is ignored and `f` runs inline.
#[cfg(not(feature = "parallel"))]
pub fn with_thread_count(_n: usize, f: impl FnOnce()) {
    f();
}
857
/// Per-worker accumulators for use inside rayon parallel sections: each
/// rayon thread updates its own slot, and the slots are folded together at
/// the end via [`ThreadLocalAccum::reduce`].
#[cfg(feature = "parallel")]
pub struct ThreadLocalAccum<T> {
    // One slot per rayon worker, indexed via `current_thread_index`.
    values: Vec<std::sync::Mutex<T>>,
}

#[cfg(feature = "parallel")]
impl<T: Clone + Send> ThreadLocalAccum<T> {
    /// Creates one accumulator per current rayon worker thread, each
    /// initialized to a clone of `identity`.
    pub fn new(identity: T) -> Self {
        let num_threads = rayon::current_num_threads();
        let values = (0..num_threads)
            .map(|_| std::sync::Mutex::new(identity.clone()))
            .collect();
        ThreadLocalAccum { values }
    }

    /// Locks and returns the slot for the current rayon worker.
    ///
    /// Threads outside the pool map to slot 0, and the modulo guards against
    /// an index beyond the slot count; distinct threads may therefore share
    /// a slot, which is safe (mutex-protected) but serializes them.
    pub fn get(&self) -> std::sync::MutexGuard<'_, T> {
        let thread_idx = rayon::current_thread_index().unwrap_or(0) % self.values.len();
        self.values[thread_idx]
            .lock()
            // A poisoned lock only means another thread panicked mid-update;
            // the partially updated value is still usable for accumulation.
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    /// Consumes the accumulator and folds all per-thread values with `f`.
    pub fn reduce<F>(self, f: F) -> T
    where
        F: Fn(T, T) -> T,
    {
        self.values
            .into_iter()
            .map(|m| {
                m.into_inner()
                    .unwrap_or_else(|poisoned| poisoned.into_inner())
            })
            .reduce(f)
            .expect("ThreadLocalAccum should have at least one value")
    }
}
907
#[cfg(test)]
mod tests {
    use super::*;

    // `partition_work` must cover every index exactly once (no overlap,
    // no gaps) when splitting evenly.
    #[test]
    fn test_partition_work() {
        let ranges = partition_work(100, 4);
        assert_eq!(ranges.len(), 4);

        let mut covered = [false; 100];
        for range in &ranges {
            for (offset, covered_elem) in covered[range.start..range.end].iter_mut().enumerate() {
                let i = range.start + offset;
                assert!(!*covered_elem, "Overlap at {}", i);
                *covered_elem = true;
            }
        }
        assert!(covered.iter().all(|&x| x), "Not all elements covered");
    }

    // Chunks may be unequal, but their lengths must sum to the total.
    #[test]
    fn test_partition_work_uneven() {
        let ranges = partition_work(10, 4);

        let total: usize = ranges.iter().map(|r| r.len()).sum();
        assert_eq!(total, 10);
    }

    // A single thread gets the whole range as one chunk.
    #[test]
    fn test_partition_work_single() {
        let ranges = partition_work(100, 1);
        assert_eq!(ranges.len(), 1);
        assert_eq!(ranges[0].start, 0);
        assert_eq!(ranges[0].end, 100);
    }

    #[test]
    fn test_threshold() {
        let threshold = ParThreshold::new(100, 10);

        assert!(!threshold.should_parallelize(50, Par::Seq));
        assert!(!threshold.should_parallelize(50, Par::default()));

        #[cfg(feature = "parallel")]
        {
            // NOTE(review): this assertion depends on the global parallelism
            // flag being enabled; it can race with `test_global_parallelism`
            // under the default multi-threaded test harness — consider
            // serializing these tests.
            assert!(threshold.should_parallelize(1000, Par::Rayon));
        }
    }

    // NOTE(review): mutates process-global state; see the race note in
    // `test_threshold` above.
    #[test]
    fn test_global_parallelism() {
        let was_enabled = is_parallelism_enabled();

        disable_global_parallelism();
        assert!(!is_parallelism_enabled());

        enable_global_parallelism();
        assert!(is_parallelism_enabled());

        // Restore the state observed on entry.
        if !was_enabled {
            disable_global_parallelism();
        }
    }

    // With Par::Seq the map is applied once to the full range.
    #[test]
    fn test_sequential_map_reduce() {
        let result = map_reduce(
            100,
            Par::Seq,
            &ParThreshold::default(),
            0usize,
            |range| range.len(),
            |a, b| a + b,
        );
        assert_eq!(result, 100);
    }

    #[test]
    fn test_sequential_pool() {
        let pool = SequentialPool;

        assert_eq!(pool.num_threads(), 1);

        let (a, b) = pool.join(|| 1 + 1, || 2 + 2);
        assert_eq!(a, 2);
        assert_eq!(b, 4);

        let sum = std::sync::atomic::AtomicUsize::new(0);
        pool.for_each(0..10, |i| {
            sum.fetch_add(i, std::sync::atomic::Ordering::SeqCst);
        });
        assert_eq!(sum.load(std::sync::atomic::Ordering::SeqCst), 45);

        let result = pool.map_reduce(0..10, 0, |i| i, |a, b| a + b);
        assert_eq!(result, 45);
    }

    #[test]
    fn test_pool_scope() {
        let pool = SequentialPool;
        let scope = PoolScope::new(&pool);

        assert_eq!(scope.num_threads(), 1);

        let result = scope.map_reduce(100, 0usize, |i| i, |a, b| a + b);
        assert_eq!(result, (0..100).sum::<usize>());

        let sum = std::sync::atomic::AtomicUsize::new(0);
        scope.for_each(10, |i| {
            sum.fetch_add(i, std::sync::atomic::Ordering::SeqCst);
        });
        assert_eq!(sum.load(std::sync::atomic::Ordering::SeqCst), 45);
    }

    #[test]
    fn test_pool_scope_with_threshold() {
        let pool = SequentialPool;
        let threshold = ParThreshold::new(50, 10);
        let scope = PoolScope::with_threshold(&pool, threshold);

        let result = scope.map_reduce(100, 0usize, |i| i, |a, b| a + b);
        assert_eq!(result, (0..100).sum::<usize>());
    }

    #[test]
    fn test_default_pool() {
        let pool = default_pool();
        assert!(pool.num_threads() >= 1);
    }

    #[test]
    fn test_with_default_pool() {
        let result = with_default_pool(|scope| scope.num_threads());
        assert!(result >= 1);
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn test_rayon_global_pool() {
        let pool = RayonGlobalPool;

        assert!(pool.num_threads() >= 1);

        let (a, b) = pool.join(|| 1 + 1, || 2 + 2);
        assert_eq!(a, 2);
        assert_eq!(b, 4);

        let result = pool.map_reduce(0..100, 0, |i| i, |a, b| a + b);
        assert_eq!(result, (0..100).sum::<usize>());
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn test_custom_rayon_pool() {
        let pool = CustomRayonPool::new(2).expect("Failed to create pool");

        assert_eq!(pool.num_threads(), 2);

        let result = pool.map_reduce(0..100, 0, |i| i, |a, b| a + b);
        assert_eq!(result, (0..100).sum::<usize>());

        let result = pool.install(|| (0..100).into_par_iter().sum::<usize>());
        assert_eq!(result, (0..100).sum());
    }

    #[test]
    fn test_thread_config_default() {
        let cfg = OxiblasThreadConfig::default();
        assert_eq!(cfg.num_threads, 0);
        assert_eq!(cfg.stack_size, 0);
        assert!(cfg.thread_name.is_none());
    }

    #[test]
    fn test_thread_config_builder() {
        let cfg = OxiblasThreadConfig::new()
            .num_threads(4)
            .stack_size(1024 * 1024)
            .thread_name("oxiblas-worker");
        assert_eq!(cfg.num_threads, 4);
        assert_eq!(cfg.stack_size, 1024 * 1024);
        assert_eq!(cfg.thread_name.as_deref(), Some("oxiblas-worker"));
    }

    // num_threads == 0 resolves to OS parallelism, which is at least 1.
    #[test]
    fn test_thread_config_effective_threads_zero() {
        let cfg = OxiblasThreadConfig::new().num_threads(0);
        assert!(cfg.effective_threads() >= 1);
    }

    #[test]
    fn test_thread_config_effective_threads_explicit() {
        let cfg = OxiblasThreadConfig::new().num_threads(3);
        assert_eq!(cfg.effective_threads(), 3);
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn test_custom_rayon_pool_with_num_threads() {
        let pool = CustomRayonPool::with_num_threads(2).expect("build pool");
        assert_eq!(pool.num_threads(), 2);
        let sum: usize = pool.map_reduce(0..50, 0, |i| i, |a, b| a + b);
        assert_eq!(sum, (0..50).sum::<usize>());
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn test_oxiblas_thread_config_build_pool() {
        let cfg = OxiblasThreadConfig::new().num_threads(2);
        let pool = cfg.build_pool().expect("build pool");
        assert_eq!(pool.num_threads(), 2);
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn test_with_thread_count() {
        // Inside `install`, rayon reports the temporary pool's size.
        with_thread_count(2, || {
            assert_eq!(rayon::current_num_threads(), 2);
        });
    }

    #[cfg(not(feature = "parallel"))]
    #[test]
    fn test_with_thread_count_sequential() {
        let mut called = false;
        with_thread_count(4, || {
            called = true;
        });
        assert!(called);
    }

    // Holds whether or not another test installed a global pool.
    #[cfg(feature = "std")]
    #[test]
    fn test_global_num_threads_default() {
        assert!(global_num_threads() >= 1);
    }
}
1169}