use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::time::Duration;
use parking_lot::Mutex;
use smol::block_on;
use vortex_error::VortexExpect;
/// A handle to a resizable set of background OS threads that all drive one shared
/// [`smol::Executor`].
///
/// Both fields are reference-counted, so cloning is cheap and every clone refers to the
/// same executor and the same worker bookkeeping: resizing through one clone is visible
/// through all of them.
#[derive(Clone)]
pub struct CurrentThreadWorkerPool {
    /// Executor that every worker thread runs.
    executor: Arc<smol::Executor<'static>>,
    /// Handles of the currently running workers, shared between all clones.
    state: Arc<Mutex<PoolState>>,
}
impl CurrentThreadWorkerPool {
pub(super) fn new(executor: Arc<smol::Executor<'static>>) -> Self {
Self {
executor,
state: Arc::new(Mutex::new(PoolState::default())),
}
}
pub fn set_workers_to_available_parallelism(&self) {
let n = std::thread::available_parallelism()
.map(|n| n.get().saturating_sub(1).max(1))
.unwrap_or(1);
self.set_workers(n);
}
pub fn set_workers(&self, n: usize) {
let mut state = self.state.lock();
let current = state.workers.len();
if n > current {
for _ in current..n {
let shutdown = Arc::new(AtomicBool::new(false));
let executor = Arc::clone(&self.executor);
let shutdown_clone = Arc::clone(&shutdown);
std::thread::Builder::new()
.name("vortex-current-thread-worker".to_string())
.spawn(move || {
block_on(executor.run(async move {
while !shutdown_clone.load(Ordering::Relaxed) {
smol::Timer::after(Duration::from_millis(100)).await;
}
}))
})
.vortex_expect("Failed to spawn current thread worker");
state.workers.push(WorkerHandle { shutdown });
}
} else if n < current {
while state.workers.len() > n {
if let Some(worker) = state.workers.pop() {
worker.shutdown.store(true, Ordering::Relaxed);
}
}
}
}
pub fn worker_count(&self) -> usize {
self.state.lock().workers.len()
}
}
/// Mutable bookkeeping shared by every clone of a `CurrentThreadWorkerPool`.
///
/// Starts out empty (no workers).
#[derive(Default)]
struct PoolState {
    /// One handle per worker that has not been asked to stop, in spawn order.
    workers: Vec<WorkerHandle>,
}
/// Control handle for a single worker thread.
struct WorkerHandle {
    /// Shared with the worker's thread; storing `true` asks that thread to stop
    /// driving the executor and exit.
    shutdown: Arc<AtomicBool>,
}
impl Drop for CurrentThreadWorkerPool {
fn drop(&mut self) {
let mut state = self.state.lock();
for worker in state.workers.drain(..) {
worker.shutdown.store(true, Ordering::Relaxed);
}
}
}