use rayon::prelude::*;
use std::sync::Once;
/// One-shot guard: the global rayon pool is configured at most once per process.
static POOL_INIT: Once = Once::new();

/// Builds the global rayon thread pool from the runtime configuration, at most once.
///
/// Concurrent first callers block until initialisation finishes (`Once::call_once`).
/// Worker threads are named `rlx-rayon-<index>`.
fn ensure_pool() {
    POOL_INIT.call_once(|| {
        let config = crate::config::RuntimeConfig::global();
        // Clamp to at least one worker. Rayon treats `num_threads(0)` as "choose
        // automatically", so without this a configured 0 would not mean one worker.
        let workers = config.pool_workers.max(1);
        let builder = rayon::ThreadPoolBuilder::new()
            .num_threads(workers)
            .thread_name(|i| format!("rlx-rayon-{i}"));
        // `build_global` fails if the global pool already exists. Rayon lazily
        // creates one with default settings on first use. That error is ignored on
        // purpose: the existing pool is kept, and `pool_workers` has no effect then.
        let _ = builder.build_global();
    });
}
/// Reports how many worker threads rayon will use for parallel work started here.
///
/// This first makes sure the configured global pool has been set up (see `ensure_pool`).
/// Per `rayon::current_num_threads`, the result depends on where it is called from:
/// - inside a rayon pool, it is that pool's thread count;
/// - anywhere else, it is the global pool's thread count.
pub fn num_threads() -> usize {
    // Initialise first so the query reflects the configured pool, not rayon's default.
    ensure_pool();
    rayon::current_num_threads()
}
/// Runs `f` in parallel over contiguous chunks of the index range `0..total`.
///
/// `f` is called as `f(offset, count)`, once per chunk. The chunks are non-empty and
/// disjoint, and together they cover `0..total` exactly once.
///
/// Splitting rules:
/// - At most `num_threads()` chunks are produced.
/// - The range is only split when every chunk except possibly the last can hold at
///   least `min_per_thread` indices. A `min_per_thread` of 0 is treated as 1.
/// - When no split is worthwhile, `f(0, total)` runs directly on the calling thread.
/// - `total == 0` is a no-op, and `f` is never called.
pub fn par_for<F: Fn(usize, usize) + Sync>(total: usize, min_per_thread: usize, f: &F) {
    if total == 0 {
        return;
    }
    let grain = min_per_thread.max(1);
    // `num_threads()` also initialises the configured global pool, so no separate
    // `ensure_pool()` call is needed. It runs on both the inline and the parallel path.
    let n_threads = (total / grain).max(1).min(num_threads());
    if n_threads <= 1 {
        f(0, total);
        return;
    }
    // `chunk >= grain`, because `n_threads <= total / grain`.
    let chunk = total.div_ceil(n_threads);
    // Ceiling-sized chunks may cover the range in fewer than `n_threads` pieces
    // (e.g. total = 5, n_threads = 4 -> chunk = 2 -> 3 pieces). Dispatch only the
    // non-empty pieces instead of scheduling tasks that would do nothing.
    let pieces = total.div_ceil(chunk);
    (0..pieces).into_par_iter().for_each(|t| {
        // `off < total` for every piece, so neither line below can overflow or
        // underflow. (The old `off + chunk` could overflow for totals near usize::MAX.)
        let off = t * chunk;
        f(off, chunk.min(total - off));
    });
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};

    #[test]
    fn par_for_sums_correctly() {
        let data = vec![1.0f32; 10_000];
        let total = AtomicU64::new(0);
        par_for(data.len(), 100, &|off, cnt| {
            let partial: f32 = data[off..off + cnt].iter().sum();
            // Sums of at most 10_000 ones are exact in f32 (< 2^24), so this cast is
            // lossless. The total can therefore be checked exactly.
            total.fetch_add(partial as u64, Ordering::Relaxed);
        });
        assert_eq!(total.load(Ordering::Relaxed), 10_000);
    }

    #[test]
    fn par_for_small_is_sequential() {
        let sum = AtomicUsize::new(0);
        par_for(10, 100, &|off, cnt| {
            sum.fetch_add(cnt, Ordering::Relaxed);
            assert_eq!(off, 0);
            assert_eq!(off + cnt, 10);
        });
        assert_eq!(sum.load(Ordering::Relaxed), 10);
    }

    #[test]
    fn par_for_zero_total_never_calls() {
        let calls = AtomicUsize::new(0);
        par_for(0, 1, &|_, _| {
            calls.fetch_add(1, Ordering::Relaxed);
        });
        assert_eq!(calls.load(Ordering::Relaxed), 0);
    }

    #[test]
    fn par_for_covers_each_index_once() {
        // Includes uneven splits (e.g. 9 items) and a zero grain, which is treated as 1.
        for &(n, grain) in &[(1usize, 1usize), (5, 1), (9, 1), (1000, 0), (4097, 256)] {
            let hits: Vec<AtomicUsize> = (0..n).map(|_| AtomicUsize::new(0)).collect();
            par_for(n, grain, &|off, cnt| {
                assert!(cnt > 0, "empty chunk at off={off}");
                for h in &hits[off..off + cnt] {
                    h.fetch_add(1, Ordering::Relaxed);
                }
            });
            assert!(
                hits.iter().all(|h| h.load(Ordering::Relaxed) == 1),
                "n={n} grain={grain}"
            );
        }
    }

    #[test]
    fn par_for_exact_sum_many_dispatches() {
        for &n in &[256usize, 1024, 4097] {
            let sum = AtomicUsize::new(0);
            par_for(n, 256, &|off, cnt| {
                sum.fetch_add(cnt, Ordering::Relaxed);
                assert!(off + cnt <= n);
            });
            assert_eq!(sum.load(Ordering::Relaxed), n);
        }
    }

    #[test]
    fn par_for_concurrent_callers_isolated() {
        std::thread::scope(|s| {
            for t in 0..4 {
                s.spawn(move || {
                    let n = 4096 + t * 17;
                    let sum = AtomicUsize::new(0);
                    par_for(n, 128, &|off, cnt| {
                        sum.fetch_add(cnt, Ordering::Relaxed);
                        assert!(off + cnt <= n);
                    });
                    assert_eq!(sum.load(Ordering::Relaxed), n);
                });
            }
        });
    }
}