mod channel;
mod handle;
mod pool;
pub use channel::{WorkerChannel, WorkerCommand, WorkerMessage, WorkerReceiver, WorkerSender};
pub use handle::{WorkerHandle, WorkerState};
pub use pool::{Worker, WorkerPool};
use std::future::Future;
use std::pin::Pin;
#[cfg(feature = "async")]
mod shared_runtime {
    //! Lazily-created, process-wide tokio runtime used when the caller is not
    //! already running inside a tokio runtime context.

    use std::sync::{Mutex, OnceLock, PoisonError};
    use tokio::runtime::{Handle, Runtime};

    /// The shared runtime; set at most once and never replaced.
    static RUNTIME: OnceLock<Runtime> = OnceLock::new();

    /// Serializes construction of [`RUNTIME`] so that concurrent first callers
    /// do not each build (and then throw away) a full multi-threaded runtime.
    static INIT_LOCK: Mutex<()> = Mutex::new(());

    /// Returns a handle to a tokio runtime.
    ///
    /// If the calling thread is already inside a tokio runtime context, that
    /// runtime's handle is returned. Otherwise the shared multi-threaded
    /// runtime is returned, and it is created on first use with one worker
    /// thread per unit of available parallelism (4 if that cannot be queried).
    ///
    /// # Errors
    ///
    /// Returns a human-readable message if the shared runtime has to be
    /// created and tokio fails to build it. A later call will retry.
    pub fn handle() -> Result<Handle, String> {
        if let Ok(handle) = Handle::try_current() {
            return Ok(handle);
        }
        // Fast path: already initialized, no locking needed.
        if let Some(runtime) = RUNTIME.get() {
            return Ok(runtime.handle().clone());
        }

        // The lock guards no data (`()`), so a poisoned lock carries no broken
        // invariant; recover the guard and proceed.
        let _guard = INIT_LOCK.lock().unwrap_or_else(PoisonError::into_inner);

        // Another thread may have finished initialization while we waited.
        if let Some(runtime) = RUNTIME.get() {
            return Ok(runtime.handle().clone());
        }

        let runtime = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .worker_threads(
                std::thread::available_parallelism()
                    .map(|n| n.get())
                    .unwrap_or(4),
            )
            .build()
            .map_err(|e| format!("Failed to create tokio runtime: {}", e))?;

        // `RUNTIME` is only written here, under `INIT_LOCK`, and was just seen
        // empty, so this stores `runtime` rather than discarding it.
        Ok(RUNTIME.get_or_init(|| runtime).handle().clone())
    }
}
#[cfg(feature = "async")]
pub(crate) use shared_runtime::handle as get_runtime_handle;
/// A pinned, heap-allocated, type-erased future that is `Send + 'static` and
/// resolves to a value of type `T`.
pub type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send + 'static>>;
/// Shorthand for a `Result` whose error type is [`WorkerError`].
pub type WorkerResult<T> = Result<T, WorkerError>;

/// Errors that can be reported for a worker task.
///
/// The `String` payloads are appended to the variant's message by the
/// `Display` implementation.
#[derive(Debug, Clone)]
pub enum WorkerError {
    /// Displayed as "Worker task was cancelled".
    Cancelled,
    /// Displayed as "Worker task panicked: <payload>".
    Panicked(String),
    /// Displayed as "Worker channel closed".
    ChannelClosed,
    /// Displayed as "Worker task timed out".
    Timeout,
    /// Free-form error; displayed as "Worker error: <payload>".
    Custom(String),
    /// Displayed as "Failed to create tokio runtime: <payload>".
    RuntimeCreationFailed(String),
}

impl std::fmt::Display for WorkerError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Variants that carry a detail message.
            Self::Panicked(msg) => write!(f, "Worker task panicked: {}", msg),
            Self::Custom(msg) => write!(f, "Worker error: {}", msg),
            Self::RuntimeCreationFailed(msg) => {
                write!(f, "Failed to create tokio runtime: {}", msg)
            }
            // Variants with a fixed message.
            Self::Cancelled => f.write_str("Worker task was cancelled"),
            Self::ChannelClosed => f.write_str("Worker channel closed"),
            Self::Timeout => f.write_str("Worker task timed out"),
        }
    }
}

impl std::error::Error for WorkerError {}
/// Relative priority level.
///
/// Comparison follows declaration order, so `Low < Normal < High`; the
/// default is [`Priority::Normal`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum Priority {
    /// Lowest level.
    Low,
    /// Middle level; also the `Default`.
    #[default]
    Normal,
    /// Highest level.
    High,
}
/// Configuration values for a worker pool.
///
/// NOTE(review): how these fields are consumed is defined outside this file;
/// the field notes below follow the field names — confirm against the pool.
#[derive(Debug, Clone)]
pub struct WorkerConfig {
    /// Number of worker threads.
    pub threads: usize,
    /// Queue capacity (presumably the maximum number of pending jobs).
    pub queue_capacity: usize,
    /// Optional default timeout, in milliseconds per the field name.
    pub default_timeout_ms: Option<u64>,
}

impl Default for WorkerConfig {
    /// One thread per unit of available parallelism (4 if that cannot be
    /// queried), a queue capacity of 1000, and no default timeout.
    fn default() -> Self {
        let threads = match std::thread::available_parallelism() {
            Ok(count) => count.get(),
            Err(_) => 4,
        };
        Self {
            threads,
            queue_capacity: 1000,
            default_timeout_ms: None,
        }
    }
}

impl WorkerConfig {
    /// The default configuration, but with `threads` worker threads.
    ///
    /// A request for zero threads is raised to one.
    pub fn with_threads(threads: usize) -> Self {
        let threads = threads.max(1);
        Self {
            threads,
            ..Self::default()
        }
    }
}
/// Hands the closure `f` to `WorkerHandle::spawn_blocking` and returns the
/// resulting handle.
///
/// Thin convenience wrapper; how and where `f` runs, and how its `T` result is
/// retrieved, are determined by `WorkerHandle` (defined in the `handle` module).
pub fn run_blocking<F: FnOnce() -> T + Send + 'static, T: Send + 'static>(
    f: F,
) -> WorkerHandle<T> {
    WorkerHandle::spawn_blocking(f)
}
/// Hands `future` to `WorkerHandle::spawn` and returns the resulting handle.
///
/// Thin convenience wrapper; which executor drives `future` (and whether one
/// must already be running) is determined by `WorkerHandle`, defined in the
/// `handle` module — NOTE(review): confirm runtime requirements there.
pub fn spawn<F: Future<Output = T> + Send + 'static, T: Send + 'static>(
    future: F,
) -> WorkerHandle<T> {
    WorkerHandle::spawn(future)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Each `WorkerError` variant renders its expected message.
    #[test]
    fn test_worker_error_display() {
        let cases = [
            (WorkerError::Cancelled, "Worker task was cancelled"),
            (
                WorkerError::Panicked("test".to_string()),
                "Worker task panicked: test",
            ),
            (WorkerError::ChannelClosed, "Worker channel closed"),
            (WorkerError::Timeout, "Worker task timed out"),
            (
                WorkerError::Custom("error".to_string()),
                "Worker error: error",
            ),
            (
                WorkerError::RuntimeCreationFailed("failed".to_string()),
                "Failed to create tokio runtime: failed",
            ),
        ];
        for (error, expected) in cases {
            assert_eq!(format!("{}", error), expected);
        }
    }

    /// Priorities order `Low < Normal < High` and default to `Normal`.
    #[test]
    fn test_priority_ordering() {
        let (low, normal, high) = (Priority::Low, Priority::Normal, Priority::High);
        assert!(low < normal);
        assert!(normal < high);
        assert!(low < high);
        assert_eq!(normal, Priority::default());
    }

    /// The default config has at least one thread, capacity 1000, no timeout.
    #[test]
    fn test_worker_config_default() {
        let WorkerConfig {
            threads,
            queue_capacity,
            default_timeout_ms,
        } = WorkerConfig::default();
        assert!(threads >= 1);
        assert_eq!(queue_capacity, 1000);
        assert!(default_timeout_ms.is_none());
    }

    /// `with_threads` keeps the requested count, clamping zero up to one.
    #[test]
    fn test_worker_config_with_threads() {
        let four = WorkerConfig::with_threads(4);
        assert_eq!(four.threads, 4);
        assert_eq!(four.queue_capacity, 1000);

        let clamped = WorkerConfig::with_threads(0);
        assert_eq!(clamped.threads, 1);
    }

    /// Creating and immediately dropping a blocking handle must not panic.
    #[test]
    fn test_run_blocking() {
        drop(run_blocking(|| 42));
    }

    /// Repeated calls to the shared runtime accessor keep succeeding.
    #[cfg(feature = "async")]
    #[test]
    fn test_shared_runtime_handle() {
        for _ in 0..2 {
            assert!(shared_runtime::handle().is_ok());
        }
    }
}