use std::future::Future;
use std::sync::Arc;
use std::thread;
use astrelis_core::profiling::profile_function;
use async_executor::{Executor, Task};
/// A fixed set of worker threads that drive a shared [`Executor`].
///
/// Futures handed to `spawn` are queued on the executor and polled by the
/// worker threads started in `new`.
pub struct TaskPool {
/// Executor shared (via `Arc`) with every worker thread.
executor: Arc<Executor<'static>>,
/// Join handles of the worker threads; `shutdown` joins them.
threads: Vec<thread::JoinHandle<()>>,
/// Stop flag polled by each worker loop; set by `shutdown` and on drop.
shutdown: Arc<std::sync::atomic::AtomicBool>,
}
impl TaskPool {
pub fn new(num_threads: usize) -> Self {
profile_function!();
assert!(num_threads > 0, "TaskPool must have at least one thread");
let executor = Arc::new(Executor::new());
let shutdown = Arc::new(std::sync::atomic::AtomicBool::new(false));
let mut threads = Vec::with_capacity(num_threads);
for i in 0..num_threads {
let exec = executor.clone();
let shutdown_flag = shutdown.clone();
let handle = thread::Builder::new()
.name(format!("astrelis-task-{}", i))
.spawn(move || {
while !shutdown_flag.load(std::sync::atomic::Ordering::Relaxed) {
if !exec.try_tick() {
thread::sleep(std::time::Duration::from_millis(1));
}
}
})
.expect("Failed to spawn task pool thread");
threads.push(handle);
}
tracing::debug!("TaskPool created with {} threads", num_threads);
Self {
executor,
threads,
shutdown,
}
}
pub fn with_num_cpus() -> Self {
Self::new(num_cpus::get())
}
pub fn default_threads() -> Self {
let num_threads = (num_cpus::get().saturating_sub(1)).max(1);
Self::new(num_threads)
}
pub fn spawn<T>(&self, future: impl Future<Output = T> + Send + 'static) -> Task<T>
where
T: Send + 'static,
{
profile_function!();
self.executor.spawn(future)
}
pub fn thread_count(&self) -> usize {
self.threads.len()
}
pub fn shutdown(mut self) {
profile_function!();
tracing::debug!("Shutting down TaskPool with {} threads", self.threads.len());
self.shutdown
.store(true, std::sync::atomic::Ordering::Relaxed);
let threads = std::mem::take(&mut self.threads);
for handle in threads {
if let Err(e) = handle.join() {
tracing::error!("Task pool thread panicked: {:?}", e);
}
}
tracing::debug!("TaskPool shutdown complete");
}
}
impl Default for TaskPool {
fn default() -> Self {
Self::default_threads()
}
}
/// Raises the stop flag when the pool goes out of scope.
///
/// Only the flag is set here; the worker threads are not joined. Use
/// `TaskPool::shutdown` to wait for them to exit.
impl Drop for TaskPool {
    fn drop(&mut self) {
        use std::sync::atomic::Ordering;

        let stop_flag = &self.shutdown;
        stop_flag.store(true, Ordering::Relaxed);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A pool reports exactly the number of workers it was asked to start.
    #[test]
    fn test_task_pool_creation() {
        let two_workers = TaskPool::new(2);
        assert_eq!(two_workers.thread_count(), 2);
    }

    /// A spawned future's output can be retrieved by blocking on its task.
    #[test]
    fn test_spawn_and_await() {
        let pool = TaskPool::new(2);
        let answer = pollster::block_on(pool.spawn(async { 42 }));
        assert_eq!(answer, 42);
    }

    /// Several tasks spawned up front each yield their own result, in order.
    #[test]
    fn test_multiple_tasks() {
        let pool = TaskPool::new(4);

        // Queue every task before waiting on any of them.
        let mut handles = Vec::with_capacity(10);
        for n in 0..10 {
            handles.push(pool.spawn(async move { n * 2 }));
        }

        let mut doubled = Vec::with_capacity(handles.len());
        for handle in handles {
            doubled.push(pollster::block_on(handle));
        }

        assert_eq!(doubled, vec![0, 2, 4, 6, 8, 10, 12, 14, 16, 18]);
    }

    /// The default sizing yields between one worker and one per CPU.
    #[test]
    fn test_default_threads() {
        let pool = TaskPool::default_threads();
        let workers = pool.thread_count();
        assert!(workers >= 1);
        assert!(workers <= num_cpus::get());
    }

    /// Asking for zero workers trips the constructor's assertion.
    #[test]
    #[should_panic(expected = "TaskPool must have at least one thread")]
    fn test_zero_threads_panics() {
        TaskPool::new(0);
    }

    /// Explicit shutdown returns even while task handles are still alive.
    #[test]
    fn test_shutdown() {
        let pool = TaskPool::new(2);
        let _first = pool.spawn(async { 1 });
        let _second = pool.spawn(async { 2 });
        pool.shutdown();
    }
}