1use rayon::prelude::*;
24use std::sync::Once;
25
26static POOL_INIT: Once = Once::new();
27
28fn ensure_pool() {
29 POOL_INIT.call_once(|| {
30 let cfg = crate::config::RuntimeConfig::global();
31 let n = cfg.pool_workers.max(1);
32 let _ = rayon::ThreadPoolBuilder::new()
33 .num_threads(n)
34 .thread_name(|i| format!("rlx-rayon-{i}"))
35 .build_global();
36 });
37}
38
39pub fn num_threads() -> usize {
41 ensure_pool();
42 rayon::current_num_threads()
43}
44
45#[inline]
51pub fn par_for<F: Fn(usize, usize) + Sync>(total: usize, min_per_thread: usize, f: &F) {
52 if total == 0 {
53 return;
54 }
55 ensure_pool();
56 let grain = min_per_thread.max(1);
57 let n_threads = (total / grain).max(1).min(num_threads());
58 if n_threads <= 1 {
59 f(0, total);
60 return;
61 }
62 let chunk = total.div_ceil(n_threads);
63 (0..n_threads).into_par_iter().for_each(|t| {
64 let off = t * chunk;
65 if off < total {
66 f(off, (off + chunk).min(total) - off);
67 }
68 });
69}
70
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};

    /// Summing a slice of ones chunk by chunk must yield exactly the slice length.
    #[test]
    fn par_for_sums_correctly() {
        let data = vec![1.0f32; 10_000];
        let total = AtomicU64::new(0);

        par_for(data.len(), 100, &|off, cnt| {
            let partial: f32 = data[off..off + cnt].iter().sum();
            // Every partial sum is a whole number far below 2^24, so it is represented
            // exactly in f32 and the conversion to an integer loses nothing.
            total.fetch_add(partial as u64, Ordering::Relaxed);
        });

        assert_eq!(total.load(Ordering::Relaxed), 10_000);
    }

    /// Work smaller than the grain size must be handled by a single call over the full range.
    #[test]
    fn par_for_small_is_sequential() {
        let sum = AtomicUsize::new(0);
        let calls = AtomicUsize::new(0);
        par_for(10, 100, &|off, cnt| {
            calls.fetch_add(1, Ordering::Relaxed);
            sum.fetch_add(cnt, Ordering::Relaxed);
            assert_eq!((off, cnt), (0, 10));
        });
        assert_eq!(calls.load(Ordering::Relaxed), 1);
        assert_eq!(sum.load(Ordering::Relaxed), 10);
    }

    /// Chunk lengths must add up to the requested total for sizes at, above and off the grain.
    #[test]
    fn par_for_exact_sum_many_dispatches() {
        for &n in &[256usize, 1024, 4097] {
            let sum = AtomicUsize::new(0);
            par_for(n, 256, &|off, cnt| {
                sum.fetch_add(cnt, Ordering::Relaxed);
                assert!(off + cnt <= n);
            });
            assert_eq!(sum.load(Ordering::Relaxed), n);
        }
    }

    /// Several threads dispatching at the same time must each see only their own chunks.
    #[test]
    fn par_for_concurrent_callers_isolated() {
        std::thread::scope(|s| {
            for t in 0..4 {
                s.spawn(move || {
                    let n = 4096 + t * 17;
                    let sum = AtomicUsize::new(0);
                    par_for(n, 128, &|off, cnt| {
                        sum.fetch_add(cnt, Ordering::Relaxed);
                        assert!(off + cnt <= n);
                    });
                    assert_eq!(sum.load(Ordering::Relaxed), n);
                });
            }
        });
    }
}