use scirs2_core::parallel_ops::*;
use std::sync::{Arc, Mutex, OnceLock};
/// Process-wide thread pool configuration: set exactly once via
/// [`init_thread_pool`], mutated through [`update_thread_pool_config`], and
/// snapshotted by [`get_thread_pool_config`]. `OnceLock` guarantees
/// single initialization; the inner `Mutex` permits later in-place updates.
static THREAD_POOL_CONFIG: OnceLock<Arc<Mutex<ThreadPoolConfig>>> = OnceLock::new();
/// User-tunable settings for the worker thread pool.
#[derive(Debug, Clone)]
pub struct ThreadPoolConfig {
/// Number of worker threads; `None` defers to the backend's default.
pub num_threads: Option<usize>,
/// Per-thread stack size in bytes; `None` uses the platform default.
pub stack_size: Option<usize>,
/// Prefix used when naming spawned worker threads.
pub thread_name_prefix: String,
/// Whether workers should be pinned to CPU cores.
/// NOTE(review): not applied anywhere in this module — presumably consumed
/// by a pool builder elsewhere; confirm before relying on it.
pub pin_threads: bool,
}
impl Default for ThreadPoolConfig {
fn default() -> Self {
Self {
num_threads: None, stack_size: Some(8 * 1024 * 1024), thread_name_prefix: "scirs2-worker".to_string(),
pin_threads: false,
}
}
}
/// Installs `config` as the process-wide thread pool configuration.
///
/// # Errors
/// Returns an error string if the configuration was already initialized.
#[allow(dead_code)]
pub fn init_thread_pool(config: ThreadPoolConfig) -> Result<(), String> {
    let shared = Arc::new(Mutex::new(config));
    match THREAD_POOL_CONFIG.set(shared) {
        Ok(()) => Ok(()),
        Err(_) => Err("Thread pool already initialized".to_string()),
    }
}
#[allow(dead_code)]
pub fn get_thread_pool_config() -> ThreadPoolConfig {
THREAD_POOL_CONFIG
.get()
.map(|config| config.lock().expect("Operation failed").clone())
.unwrap_or_default()
}
/// Applies `update_fn` to the global configuration in place.
///
/// The closure runs while the configuration lock is held, so it must not
/// call back into this module's config accessors.
///
/// # Errors
/// Returns an error string if [`init_thread_pool`] has not been called yet.
#[allow(dead_code)]
pub fn update_thread_pool_config<F>(update_fn: F) -> Result<(), String>
where
    F: FnOnce(&mut ThreadPoolConfig),
{
    // Fix: the parameter was named `_updatefn`, whose leading underscore
    // signals "intentionally unused" — but it is used. Renamed to `update_fn`.
    let shared = THREAD_POOL_CONFIG
        .get()
        .ok_or_else(|| "Thread pool not initialized".to_string())?;
    // A poisoned lock means a writer panicked; treated as a fatal bug.
    let mut config = shared.lock().expect("Operation failed");
    update_fn(&mut config);
    Ok(())
}
/// Identifying metadata for a single worker thread.
#[derive(Debug, Clone)]
pub struct WorkerInfo {
/// Index of this worker within its pool.
pub thread_id: usize,
/// Total number of workers in the pool.
pub num_workers: usize,
/// CPU core this worker is pinned to, if any.
pub cpu_affinity: Option<usize>,
}
thread_local! {
/// Per-thread worker metadata; `None` until [`set_worker_info`] registers
/// this thread, and read back by [`current_worker_info`].
static WORKER_INFO: std::cell::RefCell<Option<WorkerInfo>> = const { std::cell::RefCell::new(None) };
}
/// Returns a clone of this thread's registered [`WorkerInfo`], or `None`
/// if [`set_worker_info`] has not been called on this thread.
#[allow(dead_code)]
pub fn current_worker_info() -> Option<WorkerInfo> {
    WORKER_INFO.with(|cell| cell.borrow().as_ref().cloned())
}
/// Registers `info` as the current thread's worker metadata, replacing any
/// previously stored value.
#[allow(dead_code)]
pub fn set_worker_info(info: WorkerInfo) {
    WORKER_INFO.with(|cell| {
        // `replace` swaps in the new value; the old Option is dropped here.
        cell.replace(Some(info));
    });
}
/// Extension methods for parallel iterators to control thread usage.
///
/// NOTE(review): declaration only — no implementation is visible in this
/// file, so the docs below describe the apparent intent from the
/// signatures; confirm against the implementing type.
#[allow(dead_code)]
pub trait ParallelIteratorExt: ParallelIterator {
/// Presumably caps this iterator's execution at `numthreads` workers.
fn with_threads(self, numthreads: usize) -> Self;
/// Presumably registers `init` to run on each worker thread before work starts.
fn with_thread_init<F>(self, init: F) -> Self
where
F: Fn() + Send + Sync + 'static;
}
/// Parallel mutation helpers for array-like containers of `T`.
///
/// NOTE(review): `D` is unused in these signatures — presumably a dimension
/// type for ndarray-style arrays; confirm at the impl site.
#[allow(dead_code)]
pub trait ThreadPoolArrayExt<T, D> {
/// Applies `f` to every element in place.
fn par_map_inplace<F>(&mut self, f: F)
where
F: Fn(&mut T) + Send + Sync;
/// Applies `f` to mutable chunks of up to `chunksize` elements.
fn par_chunks_mut<F>(&mut self, chunksize: usize, f: F)
where
F: Fn(&mut [T]) + Send + Sync;
}
/// Execution context that snapshots the global [`ThreadPoolConfig`] at
/// construction time.
#[allow(dead_code)]
pub struct ThreadPoolContext {
// Snapshot of the global config taken when the context was created;
// currently stored but not consulted by the execute_* methods.
config: ThreadPoolConfig,
}
impl ThreadPoolContext {
    /// Creates a context capturing the current global configuration.
    pub fn new() -> Self {
        Self {
            config: get_thread_pool_config(),
        }
    }

    /// Runs `operation` and returns its result.
    ///
    /// NOTE(review): currently executes on the calling thread; the captured
    /// `config` is not yet applied to any pool.
    pub fn execute_parallel<F, R>(&self, operation: F) -> R
    where
        F: FnOnce() -> R + Send,
        R: Send,
    {
        operation()
    }

    /// Runs `operation` and returns its result.
    ///
    /// NOTE(review): `num_threads` is currently ignored — the operation runs
    /// without resizing any pool. The dead query of the global thread count
    /// (bound to an unused `_prev_threads`) has been removed.
    pub fn execute_with_threads<F, R>(&self, num_threads: usize, operation: F) -> R
    where
        F: FnOnce() -> R + Send,
        R: Send,
    {
        // Explicitly discarded until pool resizing is wired up.
        let _ = num_threads;
        operation()
    }
}

impl Default for ThreadPoolContext {
    /// Equivalent to [`ThreadPoolContext::new`]
    /// (satisfies `clippy::new_without_default`).
    fn default() -> Self {
        Self::new()
    }
}
/// Bookkeeping for a thread pool whose target size adapts to measured load.
///
/// Tracks only the desired thread count; it does not itself spawn or stop
/// threads.
#[allow(dead_code)]
pub struct AdaptiveThreadPool {
    /// Lower bound on the target thread count.
    min_threads: usize,
    /// Upper bound on the target thread count.
    max_threads: usize,
    /// Current target thread count, shared across clones/owners.
    current_threads: Arc<Mutex<usize>>,
    /// Load fraction above which the pool grows; it shrinks below half of it.
    load_threshold: f64,
}

impl AdaptiveThreadPool {
    /// Creates a tracker starting at `min_threads`, with a 0.8 load threshold.
    ///
    /// Fix: the first parameter was named `_min_threads` — the underscore
    /// prefix signals "intentionally unused", yet the value is used twice.
    pub fn new(min_threads: usize, max_threads: usize) -> Self {
        Self {
            min_threads,
            max_threads,
            current_threads: Arc::new(Mutex::new(min_threads)),
            load_threshold: 0.8,
        }
    }

    /// Grows the target by one when `current_load` exceeds the threshold and
    /// shrinks it by one when the load falls below half the threshold; the
    /// count never leaves `[min_threads, max_threads]`.
    pub fn adjust_threads(&self, current_load: f64) {
        let mut threads = self.current_threads.lock().expect("Operation failed");
        if current_load > self.load_threshold && *threads < self.max_threads {
            *threads = (*threads + 1).min(self.max_threads);
        } else if current_load < self.load_threshold * 0.5 && *threads > self.min_threads {
            *threads = (*threads - 1).max(self.min_threads);
        }
    }

    /// Returns the current target thread count.
    pub fn current_thread_count(&self) -> usize {
        *self.current_threads.lock().expect("Operation failed")
    }
}
/// A set of LIFO work queues: a consumer whose own queue is empty steals
/// from the back of its siblings' queues.
#[allow(dead_code)]
pub struct WorkStealingQueue<T> {
    // One lock-protected stack per logical worker.
    queues: Vec<Arc<Mutex<Vec<T>>>>,
}

impl<T: Send> WorkStealingQueue<T> {
    /// Creates `num_queues` empty queues.
    ///
    /// Fix: parameter was `_numqueues` (and the local `_queues`) — the
    /// underscore prefix signals "intentionally unused", yet both are used.
    pub fn new(num_queues: usize) -> Self {
        let queues = (0..num_queues)
            .map(|_| Arc::new(Mutex::new(Vec::new())))
            .collect();
        Self { queues }
    }

    /// Pushes `item` onto queue `queue_id`.
    ///
    /// NOTE(review): an out-of-range `queue_id` silently drops the item
    /// (original behavior, preserved); callers must pass a valid index.
    pub fn push(&self, queue_id: usize, item: T) {
        if let Some(queue) = self.queues.get(queue_id) {
            queue.lock().expect("Operation failed").push(item);
        }
    }

    /// Pops from the back of queue `queue_id`, falling back to stealing from
    /// the back of the first other non-empty queue. Returns `None` when every
    /// queue is empty.
    pub fn pop(&self, queue_id: usize) -> Option<T> {
        // Fast path: our own queue.
        if let Some(queue) = self.queues.get(queue_id) {
            if let Some(item) = queue.lock().expect("Operation failed").pop() {
                return Some(item);
            }
        }
        // Steal path: scan siblings in index order.
        for (i, queue) in self.queues.iter().enumerate() {
            if i != queue_id {
                if let Some(item) = queue.lock().expect("Operation failed").pop() {
                    return Some(item);
                }
            }
        }
        None
    }
}
/// Placeholder hook for pushing the stored configuration into the parallel
/// ops backend; currently it only reads the configured thread count and
/// discards it.
#[allow(dead_code)]
pub fn configure_parallel_ops() {
    if let Some(num_threads) = get_thread_pool_config().num_threads {
        let _ = num_threads;
    }
}
#[cfg(test)]
mod tests {
use super::*;
// Field values round-trip through struct construction.
#[test]
fn test_thread_pool_config() {
let config = ThreadPoolConfig {
num_threads: Some(4),
stack_size: Some(4 * 1024 * 1024),
thread_name_prefix: "test-worker".to_string(),
pin_threads: true,
};
assert_eq!(config.num_threads, Some(4));
assert_eq!(config.thread_name_prefix, "test-worker");
}
// The target grows one step under high load (0.9 > 0.8) and shrinks one
// step under low load (0.3 < 0.4), never leaving [min, max].
#[test]
fn test_adaptive_thread_pool() {
let pool = AdaptiveThreadPool::new(2, 8);
assert_eq!(pool.current_thread_count(), 2);
pool.adjust_threads(0.9);
assert_eq!(pool.current_thread_count(), 3);
pool.adjust_threads(0.3);
assert_eq!(pool.current_thread_count(), 2);
}
// Local pops are LIFO; an empty local queue steals from siblings.
#[test]
fn test_work_stealing_queue() {
let queue: WorkStealingQueue<i32> = WorkStealingQueue::new(2);
queue.push(0, 1);
queue.push(0, 2);
assert_eq!(queue.pop(0), Some(2));
assert_eq!(queue.pop(1), Some(1)); // queue 1 empty -> steals from queue 0
assert_eq!(queue.pop(0), None);
assert_eq!(queue.pop(1), None);
}
}