use std::{
sync::Arc,
thread::JoinHandle,
};
use qubit_executor::service::{
ExecutorService,
ExecutorServiceLifecycle,
StopReport,
SubmissionError,
};
use qubit_executor::task::spi::TaskEndpointPair;
use qubit_executor::{
TaskHandle,
TrackedTask,
};
use qubit_function::{
Callable,
Runnable,
};
use super::fixed_thread_pool_builder::FixedThreadPoolBuilder;
use super::fixed_thread_pool_inner::FixedThreadPoolInner;
use super::fixed_worker::FixedWorker;
use super::fixed_worker_runtime::FixedWorkerRuntime;
use crate::{
ExecutorServiceBuilderError,
PoolJob,
ThreadPoolStats,
};
/// A thread pool that runs submitted work on a fixed number of worker threads.
///
/// `FixedThreadPool` is a thin handle around the shared `FixedThreadPoolInner`
/// state. Every worker thread spawned for the pool holds its own `Arc` to that
/// same state. Dropping this handle requests shutdown of the pool.
pub struct FixedThreadPool {
    /// Shared pool state; cloned into each worker thread at construction time.
    inner: Arc<FixedThreadPoolInner>,
}
impl FixedThreadPool {
pub(crate) fn new_with_builder(
builder: FixedThreadPoolBuilder,
) -> Result<Self, ExecutorServiceBuilderError> {
let FixedThreadPoolBuilder {
pool_size,
queue_capacity,
thread_name_prefix,
stack_size,
hooks,
} = builder;
let mut worker_runtimes = Vec::with_capacity(pool_size);
for index in 0..pool_size {
let worker_runtime = FixedWorkerRuntime::new(index);
worker_runtimes.push(worker_runtime);
}
let inner = Arc::new(FixedThreadPoolInner::with_hooks(
pool_size,
queue_capacity,
hooks,
));
let mut worker_handles = Vec::with_capacity(pool_size);
for (index, worker_runtime) in worker_runtimes.into_iter().enumerate() {
inner.reserve_worker_slot();
let worker_inner = Arc::clone(&inner);
let thread_name = format!("{}-{}", thread_name_prefix, index);
let mut builder = std::thread::Builder::new().name(thread_name);
if let Some(stack_size) = stack_size {
builder = builder.stack_size(stack_size);
}
match builder.spawn(move || FixedWorker::run(worker_inner, worker_runtime)) {
Ok(handle) => worker_handles.push(handle),
Err(source) => {
inner.rollback_worker_slot();
inner.stop_after_failed_build();
join_started_workers(worker_handles);
return Err(ExecutorServiceBuilderError::SpawnWorker {
index: Some(index),
source,
});
}
}
}
Ok(Self { inner })
}
pub fn new(pool_size: usize) -> Result<Self, ExecutorServiceBuilderError> {
Self::builder().pool_size(pool_size).build()
}
pub fn builder() -> FixedThreadPoolBuilder {
FixedThreadPoolBuilder::new()
}
pub fn pool_size(&self) -> usize {
self.inner.pool_size()
}
pub fn queued_count(&self) -> usize {
self.inner.queued_count()
}
pub fn running_count(&self) -> usize {
self.inner.running_count()
}
pub fn live_worker_count(&self) -> usize {
self.inner.state.read(|state| state.live_workers)
}
pub fn stats(&self) -> ThreadPoolStats {
self.inner.stats()
}
#[inline]
pub fn join(&self) {
self.inner.wait_until_idle();
}
}
impl Default for FixedThreadPool {
    /// Builds a pool from `FixedThreadPoolBuilder::default()`.
    ///
    /// # Panics
    ///
    /// Panics with `"failed to build default FixedThreadPool"` if building the pool
    /// from the default builder returns an error.
    fn default() -> Self {
        let defaults = FixedThreadPoolBuilder::default();
        defaults
            .build()
            .expect("failed to build default FixedThreadPool")
    }
}
impl Drop for FixedThreadPool {
    /// Requests shutdown of the pool when this handle is dropped.
    ///
    /// This calls the shared state's `shutdown`, not `stop`. No explicit wait for
    /// worker termination happens here. Worker threads hold their own `Arc` to the
    /// shared state, so that state outlives this handle.
    fn drop(&mut self) {
        let shared = &self.inner;
        shared.shutdown();
    }
}
impl ExecutorService for FixedThreadPool {
    type ResultHandle<R, E>
        = TaskHandle<R, E>
    where
        R: Send + 'static,
        E: Send + 'static;

    type TrackedHandle<R, E>
        = TrackedTask<R, E>
    where
        R: Send + 'static,
        E: Send + 'static;

    /// Enqueues a fire-and-forget task; no result handle is produced.
    ///
    /// # Errors
    ///
    /// Propagates any `SubmissionError` returned by the shared state's `submit`.
    fn submit<T, E>(&self, task: T) -> Result<(), SubmissionError>
    where
        T: Runnable<E> + Send + 'static,
        E: Send + 'static,
    {
        let job = PoolJob::detached(task);
        self.inner.submit(job)
    }

    /// Enqueues a callable and returns a `TaskHandle` for its result.
    ///
    /// # Errors
    ///
    /// Propagates any `SubmissionError` from the shared state. In that case the
    /// freshly created handle is dropped and never returned.
    fn submit_callable<C, R, E>(
        &self,
        task: C,
    ) -> Result<Self::ResultHandle<R, E>, SubmissionError>
    where
        C: Callable<R, E> + Send + 'static,
        R: Send + 'static,
        E: Send + 'static,
    {
        let endpoints = TaskEndpointPair::new();
        let (result_handle, completion) = endpoints.into_parts();
        self.inner
            .submit(PoolJob::from_task(task, completion))
            .map(|()| result_handle)
    }

    /// Enqueues a callable and returns a `TrackedTask` for its result.
    ///
    /// # Errors
    ///
    /// Propagates any `SubmissionError` from the shared state. In that case the
    /// freshly created tracked handle is dropped and never returned.
    fn submit_tracked_callable<C, R, E>(
        &self,
        task: C,
    ) -> Result<Self::TrackedHandle<R, E>, SubmissionError>
    where
        C: Callable<R, E> + Send + 'static,
        R: Send + 'static,
        E: Send + 'static,
    {
        let endpoints = TaskEndpointPair::new();
        let (tracked_handle, completion) = endpoints.into_tracked_parts();
        self.inner
            .submit(PoolJob::from_task(task, completion))
            .map(|()| tracked_handle)
    }

    /// Requests shutdown through the shared state's `shutdown`.
    fn shutdown(&self) {
        self.inner.shutdown();
    }

    /// Stops the pool through the shared state's `stop` and returns its report.
    fn stop(&self) -> StopReport {
        self.inner.stop()
    }

    /// Returns the lifecycle phase reported by the shared state.
    fn lifecycle(&self) -> ExecutorServiceLifecycle {
        self.inner.lifecycle()
    }

    /// Delegates to the shared state's `is_not_running`.
    fn is_not_running(&self) -> bool {
        self.inner.is_not_running()
    }

    /// Delegates to the shared state's `is_terminated`.
    fn is_terminated(&self) -> bool {
        self.inner.is_terminated()
    }

    /// Delegates to the shared state's `wait_for_termination`.
    fn wait_termination(&self) {
        self.inner.wait_for_termination();
    }
}
/// Joins, in order, every worker thread in `worker_handles`.
///
/// This is used on the build-failure path to wait for workers that had already been
/// spawned. Each join result is discarded, so a worker that panicked does not
/// trigger a panic here.
fn join_started_workers(worker_handles: Vec<JoinHandle<()>>) {
    worker_handles
        .into_iter()
        .map(JoinHandle::join)
        .for_each(drop);
}