use std::sync::OnceLock;
use ferray_core::error::{FerrayError, FerrayResult};
use rayon::ThreadPool;
/// Write-once slot holding the thread pool configured via `set_num_threads`.
///
/// `OnceLock` allows a single successful `set`; later attempts are rejected.
// NOTE(review): nothing in the visible part of this file reads this slot —
// confirm where the stored pool is actually consumed.
static GLOBAL_POOL: OnceLock<ThreadPool> = OnceLock::new();
/// Lazily-initialized, mutex-guarded cache of thread pools used by
/// `with_num_threads`.
///
/// The map key is the requested thread count `n`; the value is a shared
/// (`Arc`) handle to the pool built for that count. Entries are inserted by
/// `with_num_threads` and are never removed in the visible code.
static POOL_CACHE: std::sync::LazyLock<
std::sync::Mutex<std::collections::HashMap<usize, std::sync::Arc<ThreadPool>>>,
> = std::sync::LazyLock::new(|| std::sync::Mutex::new(std::collections::HashMap::new()));
/// Builds a rayon thread pool with `n` threads and stores it in `GLOBAL_POOL`.
///
/// The slot can be filled only once: the first successful call wins and every
/// later call fails without touching the stored pool.
///
/// `n` is forwarded unchanged to `rayon::ThreadPoolBuilder::num_threads`
/// (per rayon's documentation, `0` lets rayon pick the thread count).
///
/// # Errors
///
/// Returns `FerrayError::invalid_value` when
/// * the pool has already been initialized, or
/// * rayon fails to build the pool.
pub fn set_num_threads(n: usize) -> FerrayResult<()> {
    // Fail fast: without this check a repeated call would build a complete
    // pool just to discard it when `set` is rejected below.
    if GLOBAL_POOL.get().is_some() {
        return Err(FerrayError::invalid_value(
            "ferray thread pool already initialized",
        ));
    }

    let pool = rayon::ThreadPoolBuilder::new()
        .num_threads(n)
        .build()
        .map_err(|e| FerrayError::invalid_value(format!("failed to create thread pool: {e}")))?;

    // `set` remains the authoritative guard: another thread may have
    // initialized the slot between the check above and this point.
    GLOBAL_POOL
        .set(pool)
        .map_err(|_| FerrayError::invalid_value("ferray thread pool already initialized"))
}
/// Runs `f` inside a rayon thread pool built with `n` threads and returns its
/// result.
///
/// Pools are cached in `POOL_CACHE`, keyed by `n`: the first call for a given
/// `n` builds the pool, later calls reuse it. The cache lock is released
/// before `f` is installed, so `f` does not run while the lock is held.
///
/// # Errors
///
/// Returns `FerrayError::invalid_value` when the cache mutex is poisoned or
/// when rayon fails to build a new pool.
pub fn with_num_threads<F, R>(n: usize, f: F) -> FerrayResult<R>
where
    F: FnOnce() -> R + Send,
    R: Send,
{
    let pool = {
        // The guard lives only for this block; it is released before `install`.
        let mut cache = POOL_CACHE
            .lock()
            .map_err(|e| FerrayError::invalid_value(format!("pool cache lock poisoned: {e}")))?;

        let cached = cache.get(&n).cloned();
        match cached {
            Some(shared) => shared,
            None => {
                let built = rayon::ThreadPoolBuilder::new()
                    .num_threads(n)
                    .build()
                    .map_err(|e| {
                        FerrayError::invalid_value(format!(
                            "failed to create cached thread pool: {e}"
                        ))
                    })?;
                let shared = std::sync::Arc::new(built);
                cache.insert(n, std::sync::Arc::clone(&shared));
                shared
            }
        }
    };

    Ok(pool.install(f))
}
// NOTE(review): none of these thresholds is used in the visible part of this
// file. Judging by the names they are presumably minimum element counts at
// which the corresponding kind of operation switches to parallel execution —
// confirm against the call sites.

/// Threshold for element-wise operations (100,000).
pub const PARALLEL_THRESHOLD_ELEMENTWISE: usize = 100_000;
/// Threshold for compute-heavy operations (50,000).
pub const PARALLEL_THRESHOLD_COMPUTE: usize = 50_000;
/// Threshold for reductions (10,000).
pub const PARALLEL_THRESHOLD_REDUCTION: usize = 10_000;
/// Threshold for sorting (100,000).
pub const PARALLEL_THRESHOLD_SORT: usize = 100_000;
#[cfg(test)]
mod tests {
    use super::*;

    /// Calling `with_num_threads` repeatedly with the same count must leave
    /// exactly one cached pool, stored under that count.
    #[test]
    fn with_num_threads_caches_pool_ac8() {
        (0..100).for_each(|_| {
            let value = with_num_threads(2, || 42).unwrap();
            assert_eq!(value, 42);
        });

        let pools = POOL_CACHE.lock().unwrap();
        assert_eq!(pools.len(), 1);
        assert!(pools.contains_key(&2));
    }

    /// Pins the published threshold values so accidental edits are caught.
    #[test]
    fn threshold_constants() {
        let expectations = [
            (PARALLEL_THRESHOLD_ELEMENTWISE, 100_000),
            (PARALLEL_THRESHOLD_COMPUTE, 50_000),
            (PARALLEL_THRESHOLD_REDUCTION, 10_000),
            (PARALLEL_THRESHOLD_SORT, 100_000),
        ];
        for (actual, wanted) in expectations {
            assert_eq!(actual, wanted);
        }
    }
}