use rayon::ThreadPool;
use std::sync::{Arc, LazyLock, Mutex};
/// Process-wide slot holding the currently configured rayon thread pool.
///
/// The slot starts out as `None`; it is filled by `set_num_threads` or, on first
/// use, by `get_thread_pool`. The pool is kept behind an `Arc` so callers can hold
/// on to a handle after the slot has been overwritten with a different pool.
static THREAD_POOL: LazyLock<Mutex<Option<Arc<ThreadPool>>>> =
    // `Mutex::default()` wraps `Option::default()`, i.e. `None`.
    LazyLock::new(Mutex::default);
/// Replaces the shared thread pool with a freshly built pool of exactly `n` threads.
///
/// The new pool is constructed before the global lock is taken, so the lock is only
/// held for the swap itself. Handles previously obtained from `get_thread_pool` keep
/// referring to the pool they were cloned from.
///
/// # Panics
///
/// Panics if `n` is zero, if rayon fails to build the pool, or if the pool mutex
/// is poisoned.
pub fn set_num_threads(n: usize) {
    assert!(n != 0, "Number of threads must be at least 1");

    let replacement = rayon::ThreadPoolBuilder::new()
        .num_threads(n)
        .build()
        .map(Arc::new)
        .expect("Failed to build thread pool");

    *THREAD_POOL.lock().expect("Thread pool mutex poisoned") = Some(replacement);
}
/// Rebuilds the shared thread pool with one thread per CPU reported by
/// `num_cpus::get()`.
///
/// # Panics
///
/// Panics under the same conditions as `set_num_threads`.
pub fn set_max_threads() {
    let cpu_count = num_cpus::get();
    set_num_threads(cpu_count);
}
/// Alias for `set_max_threads`: rebuilds the shared pool using every CPU reported
/// by `num_cpus::get()`.
///
/// # Panics
///
/// Panics under the same conditions as `set_max_threads`.
pub fn set_ludicrous_speed() {
    // Pure delegation; this function adds no behavior of its own.
    set_max_threads()
}
/// Returns a handle to the shared thread pool, building a default pool on first use.
///
/// If no pool has been configured yet, one is created with 90% of the CPU count
/// reported by `num_cpus::get()`, rounded up and never fewer than one thread. The
/// default pool is built while the global lock is held, so concurrent first callers
/// end up sharing a single pool.
///
/// The returned `Arc` stays valid even if the global pool is later replaced by
/// `set_num_threads`; it simply keeps pointing at the pool it was cloned from.
///
/// # Panics
///
/// Panics if the pool mutex is poisoned or if rayon fails to build the default pool.
#[doc(hidden)]
pub fn get_thread_pool() -> Arc<ThreadPool> {
    let mut guard = THREAD_POOL.lock().expect("Thread pool mutex poisoned");

    // `get_or_insert_with` only runs the closure when the slot is still empty and
    // always hands back a reference to the stored pool, so no `unwrap` is needed.
    let pool = guard.get_or_insert_with(|| {
        let cpu_count = num_cpus::get();
        // Leave roughly 10% of the CPUs free by default, but always keep one thread.
        let default_threads = ((cpu_count as f64 * 0.9).ceil() as usize).max(1);
        Arc::new(
            rayon::ThreadPoolBuilder::new()
                .num_threads(default_threads)
                .build()
                .expect("Failed to build default thread pool"),
        )
    });

    Arc::clone(pool)
}
/// Returns the number of threads in the shared thread pool.
///
/// Goes through `get_thread_pool`, so calling this before any pool has been
/// configured builds the default pool as a side effect.
///
/// # Panics
///
/// Panics under the same conditions as `get_thread_pool`.
pub fn get_max_threads() -> usize {
    let pool = get_thread_pool();
    pool.current_num_threads()
}
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use serial_test::serial;

    // Tests that replace or reset the global pool are marked `#[serial]` so they do
    // not interleave with each other.

    /// Thread count the lazily built default pool should have: 90% of the CPUs
    /// reported by `num_cpus::get()`, rounded up, and never less than one.
    fn expected_default_threads() -> usize {
        ((num_cpus::get() as f64 * 0.9).ceil() as usize).max(1)
    }

    /// Clears the global slot so the next `get_thread_pool` call rebuilds the default.
    fn reset_thread_pool() {
        *THREAD_POOL.lock().expect("Thread pool mutex poisoned") = None;
    }

    #[test]
    fn test_default_thread_count() {
        let expected = expected_default_threads();
        assert!(expected >= 1);
        assert!(expected <= num_cpus::get());
    }

    #[test]
    #[serial]
    fn test_get_max_threads() {
        // Other tests install pools with up to 8 threads. Start from an empty slot so
        // the upper bound below does not depend on test ordering or the CPU count.
        reset_thread_pool();
        let threads = get_max_threads();
        assert_eq!(threads, expected_default_threads());
        assert!(threads >= 1);
        assert!(threads <= num_cpus::get());
    }

    #[test]
    #[serial]
    #[should_panic(expected = "Number of threads must be at least 1")]
    fn test_set_num_threads_zero_panics() {
        set_num_threads(0);
    }

    #[test]
    #[serial]
    fn test_set_num_threads_reinitialize() {
        set_num_threads(2);
        assert_eq!(get_max_threads(), 2);
        set_num_threads(4);
        assert_eq!(get_max_threads(), 4);
        set_num_threads(1);
        assert_eq!(get_max_threads(), 1);
    }

    #[test]
    #[serial]
    fn test_set_max_threads() {
        let num_cpus_val = num_cpus::get();
        set_max_threads();
        assert_eq!(get_max_threads(), num_cpus_val);
        // Calling it again must be harmless and yield the same size.
        set_max_threads();
        assert_eq!(get_max_threads(), num_cpus_val);
    }

    #[test]
    #[serial]
    fn test_set_ludicrous_speed() {
        let num_cpus_val = num_cpus::get();
        set_ludicrous_speed();
        assert_eq!(get_max_threads(), num_cpus_val);
        set_ludicrous_speed();
        assert_eq!(get_max_threads(), num_cpus_val);
    }

    #[test]
    fn test_get_thread_pool_returns_valid_pool() {
        let pool = get_thread_pool();
        let threads = pool.current_num_threads();
        assert!(threads >= 1);
    }

    #[test]
    #[serial]
    fn test_reinitialize_changes_thread_count() {
        set_num_threads(3);
        let pool1 = get_thread_pool();
        assert_eq!(pool1.current_num_threads(), 3);
        set_num_threads(5);
        let pool2 = get_thread_pool();
        assert_eq!(pool2.current_num_threads(), 5);
        assert_eq!(get_max_threads(), 5);
    }

    #[test]
    #[serial]
    fn test_mixed_function_reinitialization() {
        set_num_threads(2);
        assert_eq!(get_max_threads(), 2);
        set_max_threads();
        let max_threads = get_max_threads();
        // Compare against the CPU count rather than a fixed lower bound, so the test
        // also holds on single-CPU machines.
        assert_eq!(max_threads, num_cpus::get());
        set_ludicrous_speed();
        assert_eq!(get_max_threads(), max_threads);
        set_num_threads(1);
        assert_eq!(get_max_threads(), 1);
        set_max_threads();
        assert_eq!(get_max_threads(), max_threads);
    }

    #[test]
    #[serial]
    fn test_get_max_threads_reflects_set_num_threads() {
        for n in [1, 2, 4, 8] {
            set_num_threads(n);
            let actual = get_max_threads();
            assert_eq!(actual, n, "Expected {} threads, got {}", n, actual);
        }
    }

    #[test]
    #[serial]
    fn test_thread_pool_arc_cloning() {
        set_num_threads(4);
        let pool1 = get_thread_pool();
        let pool2 = get_thread_pool();
        let pool3 = get_thread_pool();
        assert_eq!(pool1.current_num_threads(), 4);
        assert_eq!(pool2.current_num_threads(), 4);
        assert_eq!(pool3.current_num_threads(), 4);
        set_num_threads(6);
        let pool4 = get_thread_pool();
        assert_eq!(pool4.current_num_threads(), 6);
        // Handles taken before the swap still refer to the old 4-thread pool.
        assert_eq!(pool1.current_num_threads(), 4);
    }
}