use std::num::NonZeroUsize;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::thread;
/// Runs `factory` once per seed for `count` consecutive seeds and returns the
/// results in seed order.
///
/// The seed for index `i` (with `i` in `0..count`) is
/// `base_seed.wrapping_add(i as u64)`, so the sequence wraps at `u64::MAX`
/// instead of overflowing. Element `i` of the returned vector is always the
/// result for index `i`, whichever worker computed it. The output is therefore
/// independent of the number of threads used.
///
/// Work is handed out dynamically: each worker repeatedly claims the next
/// unclaimed index from a shared atomic counter. This keeps workers busy even
/// when `factory` has uneven cost per seed.
///
/// # Arguments
///
/// * `count` - number of seeds to run. `0` returns an empty vector without
///   spawning any threads.
/// * `base_seed` - seed passed to `factory` for index `0`.
/// * `threads` - number of worker threads. `None` uses
///   [`std::thread::available_parallelism`] and falls back to `1` if that
///   query fails. At most `count` workers are spawned.
/// * `factory` - invoked exactly once per index, possibly concurrently from
///   several threads.
///
/// # Panics
///
/// If `factory` panics on a worker thread, that panic is re-raised in the
/// caller with its original payload.
pub fn run_seeds<F, T>(
    count: usize,
    base_seed: u64,
    threads: Option<NonZeroUsize>,
    factory: F,
) -> Vec<T>
where
    F: Fn(u64) -> T + Sync,
    T: Send,
{
    if count == 0 {
        return Vec::new();
    }
    let thread_count = threads
        .or_else(|| thread::available_parallelism().ok())
        .map(NonZeroUsize::get)
        .unwrap_or(1)
        .min(count);

    // Next index to hand out. `Relaxed` is enough: the atomic read-modify-write
    // alone guarantees each index is claimed by exactly one worker. The
    // results reach this thread through `join`, which synchronizes.
    let cursor = AtomicUsize::new(0);
    // Borrow these so the `move` closures below copy references, not the values.
    let factory = &factory;
    let cursor = &cursor;

    let chunks: Vec<Vec<(usize, T)>> = thread::scope(|s| {
        let mut handles = Vec::with_capacity(thread_count);
        for _ in 0..thread_count {
            handles.push(s.spawn(move || {
                let mut local: Vec<(usize, T)> = Vec::new();
                loop {
                    let i = cursor.fetch_add(1, Ordering::Relaxed);
                    if i >= count {
                        break;
                    }
                    let seed = base_seed.wrapping_add(i as u64);
                    local.push((i, factory(seed)));
                }
                local
            }));
        }
        handles
            .into_iter()
            // Re-raise a worker's panic with its original payload instead of
            // replacing it with a generic message, so the caller sees the real
            // cause. `thread::scope` waits for the remaining workers and then
            // propagates this payload.
            .map(|h| {
                h.join()
                    .unwrap_or_else(|payload| std::panic::resume_unwind(payload))
            })
            .collect()
    });

    // Put each result back at its index. Every index in `0..count` was claimed
    // exactly once, so every slot ends up `Some`.
    let mut slots: Vec<Option<T>> = Vec::with_capacity(count);
    slots.resize_with(count, || None);
    for chunk in chunks {
        for (i, v) in chunk {
            debug_assert!(slots[i].is_none(), "index {} filled twice", i);
            slots[i] = Some(v);
        }
    }
    slots
        .into_iter()
        .map(|o| o.expect("all slots must be filled after workers join"))
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_count_returns_empty() {
        let results: Vec<u64> = run_seeds(0, 0, None, |s| s);
        assert!(results.is_empty());
    }

    #[test]
    fn results_come_back_in_seed_order() {
        let results = run_seeds(100, 1000, None, |seed| seed);
        assert_eq!(results.len(), 100);
        for (i, r) in results.iter().enumerate() {
            assert_eq!(*r, 1000 + i as u64);
        }
    }

    #[test]
    fn single_thread_matches_multi_thread_output() {
        let factory = |seed: u64| seed.wrapping_mul(31).rotate_left(7);
        let single = run_seeds(50, 42, NonZeroUsize::new(1), factory);
        let multi = run_seeds(50, 42, NonZeroUsize::new(8), factory);
        assert_eq!(single, multi);
    }

    #[test]
    fn thread_count_higher_than_count_is_fine() {
        let results = run_seeds(3, 0, NonZeroUsize::new(16), |seed| seed);
        assert_eq!(results, vec![0, 1, 2]);
    }

    /// Seeds are derived with `wrapping_add`, so a base near `u64::MAX`
    /// wraps to 0 instead of overflowing.
    #[test]
    fn seeds_wrap_around_at_u64_max() {
        let results = run_seeds(3, u64::MAX - 1, NonZeroUsize::new(2), |seed| seed);
        assert_eq!(results, vec![u64::MAX - 1, u64::MAX, 0]);
    }

    /// Non-`Copy` results must also be moved back into seed order.
    #[test]
    fn non_copy_results_are_ordered() {
        let results = run_seeds(20, 7, NonZeroUsize::new(4), |seed| seed.to_string());
        let expected: Vec<String> = (7..27u64).map(|s| s.to_string()).collect();
        assert_eq!(results, expected);
    }
}