use std::sync::mpsc::{self, Receiver, Sender};
use std::sync::{Arc, Mutex};
use std::thread::{self, JoinHandle};
use std::time::{Duration, Instant};
use crate::error::{IoError, Result};
/// Tuning knobs for a [`ThreadPool`]: worker counts per group plus queue /
/// lifetime settings.
#[derive(Debug, Clone)]
pub struct ThreadPoolConfig {
    /// Number of worker threads servicing the I/O task queue.
    pub io_threads: usize,
    /// Number of worker threads servicing the CPU task queue.
    pub cpu_threads: usize,
    /// Intended cap on queued tasks.
    /// NOTE(review): not enforced anywhere in this file — `submit` queues
    /// unconditionally; confirm whether enforcement lives elsewhere.
    pub max_queue_size: usize,
    /// Intended idle keep-alive for workers.
    /// NOTE(review): unused by the visible code (workers poll on a fixed
    /// 100 ms timeout instead).
    pub keep_alive: Duration,
    /// Intended work-stealing toggle.
    /// NOTE(review): unused by the visible code.
    pub work_stealing: bool,
}
impl Default for ThreadPoolConfig {
    /// Builds a configuration sized from the machine's available parallelism
    /// (falling back to 4 when detection fails).
    ///
    /// The core count is split evenly between the I/O and CPU groups, with a
    /// guaranteed minimum of one thread per group: the previous `cores / 2`
    /// evaluated to 0 on single-core machines, producing a pool with no
    /// workers at all, so submitted tasks were queued but never executed.
    fn default() -> Self {
        let available_cores = thread::available_parallelism()
            .map(|n| n.get())
            .unwrap_or(4);
        // `.max(1)` guards the single-core case where `cores / 2 == 0`.
        let per_group = (available_cores / 2).max(1);
        Self {
            io_threads: per_group,
            cpu_threads: per_group,
            max_queue_size: 1000,
            keep_alive: Duration::from_secs(60),
            work_stealing: true,
        }
    }
}
/// Classifies a task so `ThreadPool::submit` can route it to the matching
/// worker group.
///
/// `Eq` is derived alongside `PartialEq` (the comparison is total for a
/// field-less enum; deriving only `PartialEq` trips clippy's
/// `derive_partial_eq_without_eq`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WorkType {
    /// Routed to the I/O worker group.
    IO,
    /// Routed to the CPU worker group.
    CPU,
}
/// A single unit of work queued on the pool.
pub struct WorkItem {
    /// Which worker group should execute this task.
    pub work_type: WorkType,
    /// The task body; an `Err` return is tallied into
    /// `ThreadPoolStats::tasks_failed` by the worker loop.
    pub task: Box<dyn FnOnce() -> Result<()> + Send>,
    /// Optional caller-assigned identifier; never read by the visible code.
    pub task_id: Option<u64>,
}
/// Aggregate counters for a [`ThreadPool`], shared behind `Arc<Mutex<_>>`.
#[derive(Debug, Clone, Default)]
pub struct ThreadPoolStats {
    /// Total tasks handed to `submit` (both queues).
    pub tasks_submitted: u64,
    /// Tasks whose closure returned `Ok`.
    pub tasks_completed: u64,
    /// Tasks whose closure returned `Err`.
    pub tasks_failed: u64,
    /// Sum of per-task wall-clock execution time, in milliseconds.
    pub total_execution_time_ms: f64,
    /// `total_execution_time_ms` divided by finished tasks.
    pub avg_execution_time_ms: f64,
    /// Approximate number of tasks not yet finished.
    pub current_queue_size: usize,
    /// High-water mark observed for `current_queue_size`.
    pub max_queue_size_reached: usize,
    /// Approximate number of workers currently executing a task.
    pub active_threads: usize,
}
/// A thread pool with two independent worker groups — one for I/O-bound and
/// one for CPU-bound tasks — each draining its own `mpsc` queue.
pub struct ThreadPool {
    /// Workers draining the I/O queue.
    io_workers: Vec<Worker>,
    /// Workers draining the CPU queue.
    cpu_workers: Vec<Worker>,
    /// Producer side of the I/O queue; dropped on shutdown to disconnect workers.
    io_sender: Sender<WorkItem>,
    /// Producer side of the CPU queue; dropped on shutdown to disconnect workers.
    cpu_sender: Sender<WorkItem>,
    #[allow(dead_code)]
    config: ThreadPoolConfig,
    /// Counters shared with every worker thread.
    stats: Arc<Mutex<ThreadPoolStats>>,
    /// Flag polled by workers between receives; set once by `shutdown`.
    shutdown: Arc<Mutex<bool>>,
}
/// Handle to one spawned worker thread.
struct Worker {
    #[allow(dead_code)]
    id: usize,
    /// Join handle; taken (`None`) once the thread has been joined during
    /// `ThreadPool::shutdown`.
    thread: Option<JoinHandle<()>>,
}
impl ThreadPool {
pub fn new(config: ThreadPoolConfig) -> Self {
let (io_sender, io_receiver) = mpsc::channel();
let (cpu_sender, cpu_receiver) = mpsc::channel();
let stats = Arc::new(Mutex::new(ThreadPoolStats::default()));
let shutdown = Arc::new(Mutex::new(false));
let io_receiver = Arc::new(Mutex::new(io_receiver));
let mut io_workers = Vec::with_capacity(config.io_threads);
for id in 0..config.io_threads {
let receiver = Arc::clone(&io_receiver);
let stats_clone = Arc::clone(&stats);
let shutdown_clone = Arc::clone(&shutdown);
let thread = thread::spawn(move || {
Self::worker_loop(id, receiver, stats_clone, shutdown_clone, WorkType::IO)
});
io_workers.push(Worker {
id,
thread: Some(thread),
});
}
let cpu_receiver = Arc::new(Mutex::new(cpu_receiver));
let mut cpu_workers = Vec::with_capacity(config.cpu_threads);
for id in 0..config.cpu_threads {
let receiver = Arc::clone(&cpu_receiver);
let stats_clone = Arc::clone(&stats);
let shutdown_clone = Arc::clone(&shutdown);
let thread = thread::spawn(move || {
Self::worker_loop(id, receiver, stats_clone, shutdown_clone, WorkType::CPU)
});
cpu_workers.push(Worker {
id,
thread: Some(thread),
});
}
Self {
io_workers,
cpu_workers,
io_sender,
cpu_sender,
config,
stats,
shutdown,
}
}
pub fn submit<F>(&self, worktype: WorkType, task: F) -> Result<()>
where
F: FnOnce() -> Result<()> + Send + 'static,
{
let work_item = WorkItem {
work_type: worktype,
task: Box::new(task),
task_id: None,
};
{
let mut stats = self.stats.lock().expect("Operation failed");
stats.tasks_submitted += 1;
}
match worktype {
WorkType::IO => {
self.io_sender.send(work_item).map_err(|_| {
IoError::Other("Failed to submit I/O task: thread pool shut down".to_string())
})?;
}
WorkType::CPU => {
self.cpu_sender.send(work_item).map_err(|_| {
IoError::Other("Failed to submit CPU task: thread pool shut down".to_string())
})?;
}
}
Ok(())
}
pub fn submit_batch<F>(&self, worktype: WorkType, tasks: Vec<F>) -> Result<()>
where
F: FnOnce() -> Result<()> + Send + 'static,
{
for task in tasks {
self.submit(worktype, task)?;
}
Ok(())
}
pub fn parallel_map<T, F, R>(
&self,
items: Vec<T>,
_work_type: WorkType,
func: F,
) -> Result<Vec<R>>
where
T: Send + 'static,
F: Fn(T) -> R + Send + Sync + 'static,
R: Send + 'static + std::fmt::Debug,
{
use std::sync::mpsc;
let func = Arc::new(func);
let (sender, receiver) = mpsc::channel();
let mut handles = Vec::new();
let num_items = items.len();
for (index, item) in items.into_iter().enumerate() {
let func_clone = Arc::clone(&func);
let sender_clone = sender.clone();
let handle = thread::spawn(move || {
let result = func_clone(item);
let _ = sender_clone.send((index, result));
});
handles.push(handle);
}
drop(sender);
let mut results: Vec<Option<R>> = (0..num_items).map(|_| None).collect();
for _ in 0..num_items {
match receiver.recv() {
Ok((index, result)) => {
results[index] = Some(result);
}
Err(_) => {
return Err(IoError::Other(
"Failed to receive result from worker thread".to_string(),
))
}
}
}
for handle in handles {
handle
.join()
.map_err(|_| IoError::Other("Thread panicked".to_string()))?;
}
let final_results: Result<Vec<R>> = results
.into_iter()
.enumerate()
.map(|(i, opt)| {
opt.ok_or_else(|| IoError::Other(format!("Missing result for item {}", i)))
})
.collect();
final_results
}
pub fn get_stats(&self) -> ThreadPoolStats {
self.stats.lock().expect("Operation failed").clone()
}
pub fn pending_tasks(&self) -> usize {
0
}
pub fn wait_for_completion(&self) -> Result<()> {
thread::sleep(Duration::from_millis(100));
Ok(())
}
pub fn shutdown(mut self) -> Result<()> {
{
let mut shutdown = self.shutdown.lock().expect("Operation failed");
*shutdown = true;
}
drop(self.io_sender);
drop(self.cpu_sender);
for worker in &mut self.io_workers {
if let Some(thread) = worker.thread.take() {
thread
.join()
.map_err(|_| IoError::Other("Failed to join I/O worker thread".to_string()))?;
}
}
for worker in &mut self.cpu_workers {
if let Some(thread) = worker.thread.take() {
thread
.join()
.map_err(|_| IoError::Other("Failed to join CPU worker thread".to_string()))?;
}
}
Ok(())
}
fn worker_loop(
id: usize,
receiver: Arc<Mutex<Receiver<WorkItem>>>,
stats: Arc<Mutex<ThreadPoolStats>>,
shutdown: Arc<Mutex<bool>>,
worker_type: WorkType,
) {
loop {
if *shutdown.lock().expect("Operation failed") {
break;
}
let work_item = {
let receiver = receiver.lock().expect("Operation failed");
receiver.recv_timeout(Duration::from_millis(100))
};
match work_item {
Ok(item) => {
let start_time = Instant::now();
let result = (item.task)();
let execution_time = start_time.elapsed().as_millis() as f64;
{
let mut stats_guard = stats.lock().expect("Operation failed");
match result {
Ok(_) => {
stats_guard.tasks_completed += 1;
}
Err(_) => {
stats_guard.tasks_failed += 1;
}
}
stats_guard.total_execution_time_ms += execution_time;
let total_tasks = stats_guard.tasks_completed + stats_guard.tasks_failed;
if total_tasks > 0 {
stats_guard.avg_execution_time_ms =
stats_guard.total_execution_time_ms / total_tasks as f64;
}
}
}
Err(mpsc::RecvTimeoutError::Timeout) => {
continue;
}
Err(mpsc::RecvTimeoutError::Disconnected) => {
break;
}
}
}
println!("Worker {id} ({worker_type:?}) shutting down");
}
}
/// Lazily-initialized process-wide pool shared by the free functions below.
static GLOBAL_THREAD_POOL: std::sync::OnceLock<ThreadPool> = std::sync::OnceLock::new();
/// Installs a pool built from `config` as the process-wide global pool.
///
/// Returns `true` if this call performed the initialization, `false` if the
/// global pool was already set — the previous `let _ = …` silently discarded
/// that outcome, hiding double-initialization from callers. (Adding the bool
/// return is backward-compatible: existing statement-position calls still
/// compile.) When the pool was already set, the freshly built pool is dropped;
/// its workers exit once their channel senders are gone.
#[allow(dead_code)]
pub fn init_global_thread_pool(config: ThreadPoolConfig) -> bool {
    GLOBAL_THREAD_POOL.set(ThreadPool::new(config)).is_ok()
}
/// Returns the process-wide pool, creating it with
/// `ThreadPoolConfig::default()` on first use.
#[allow(dead_code)]
pub fn global_thread_pool() -> &'static ThreadPool {
    GLOBAL_THREAD_POOL.get_or_init(|| ThreadPool::new(ThreadPoolConfig::default()))
}
/// Convenience wrapper: submits `task` to the global pool's `work_type` queue.
///
/// # Errors
/// Propagates the submission error if the global pool has been shut down.
#[allow(dead_code)]
pub fn execute<F>(work_type: WorkType, task: F) -> Result<()>
where
    F: FnOnce() -> Result<()> + Send + 'static,
{
    global_thread_pool().submit(work_type, task)
}
/// Derives a configuration from the detected core count: a smaller I/O share
/// on small machines, half the cores otherwise, with the remainder for CPU
/// work.
///
/// Fixes a single-core bug: previously `cpu_threads = cores - io_threads`
/// evaluated to `1 - 1 = 0`, so CPU-queue tasks would never run (and this
/// file's own `test_optimal_config` asserts `cpu_threads > 0`).
#[allow(dead_code)]
pub fn optimal_config() -> ThreadPoolConfig {
    let available_cores = thread::available_parallelism()
        .map(|n| n.get())
        .unwrap_or(4);
    let io_threads = match available_cores {
        0..=2 => 1,
        3..=4 => 2,
        n => n / 2,
    };
    // `.max(1)` guarantees at least one CPU worker on single-core machines.
    let cpu_threads = (available_cores - io_threads).max(1);
    ThreadPoolConfig {
        io_threads,
        cpu_threads,
        max_queue_size: available_cores * 100,
        keep_alive: Duration::from_secs(30),
        work_stealing: available_cores > 2,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// The pool spawns exactly the configured number of workers per group.
    #[test]
    fn test_thread_pool_creation() {
        let config = ThreadPoolConfig::default();
        let pool = ThreadPool::new(config.clone());
        assert_eq!(pool.io_workers.len(), config.io_threads);
        assert_eq!(pool.cpu_workers.len(), config.cpu_threads);
    }

    /// A submitted task is accepted and reflected in the submission counter.
    /// Only `tasks_submitted` is asserted (not the atomic counter), so the
    /// fixed sleep is not a correctness dependency — just settle time.
    #[test]
    fn test_task_submission() {
        let pool = ThreadPool::new(ThreadPoolConfig::default());
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = Arc::clone(&counter);
        let result = pool.submit(WorkType::CPU, move || {
            counter_clone.fetch_add(1, Ordering::SeqCst);
            Ok(())
        });
        assert!(result.is_ok());
        thread::sleep(Duration::from_millis(100));
        let stats = pool.get_stats();
        assert!(stats.tasks_submitted > 0);
    }

    /// Batch submission queues every task; the counter check is exact because
    /// `tasks_submitted` is bumped per submission regardless of execution.
    #[test]
    fn test_batch_submission() {
        let pool = ThreadPool::new(ThreadPoolConfig::default());
        let counter = Arc::new(AtomicUsize::new(0));
        let tasks: Vec<_> = (0..10)
            .map(|_| {
                let counter_clone = Arc::clone(&counter);
                move || {
                    counter_clone.fetch_add(1, Ordering::SeqCst);
                    Ok(())
                }
            })
            .collect();
        let result = pool.submit_batch(WorkType::CPU, tasks);
        assert!(result.is_ok());
        thread::sleep(Duration::from_millis(200));
        let stats = pool.get_stats();
        assert_eq!(stats.tasks_submitted, 10);
    }

    /// `optimal_config` must never produce an empty worker group.
    #[test]
    fn test_optimal_config() {
        let config = optimal_config();
        assert!(config.io_threads > 0);
        assert!(config.cpu_threads > 0);
        assert!(config.max_queue_size > 0);
    }

    /// The lazily-created global pool accepts task submissions.
    #[test]
    fn test_global_thread_pool() {
        let result = execute(WorkType::CPU, || {
            println!("Global thread pool test");
            Ok(())
        });
        assert!(result.is_ok());
    }
}