use crate::{
stores::Store, timers::DelayQueueTimer, utils::processor_types::SharedStore,
worker::processor_types::SyncFn, Job, JobState, JobToken, KioError, KioResult, Queue,
};
use crate::utils::main_loop;
use derive_more::Debug;
use futures::future::{Future, FutureExt};
use serde::{de::DeserializeOwned, Serialize};
use std::sync::Arc;
use uuid::Uuid;
mod metrics;
mod worker_opts;
pub use metrics::*;
use crate::error::WorkerError;
use crate::events::EventParameters;
use crate::Counter;
use arc_swap::ArcSwapOption;
use hdrhistogram::Histogram;
use tokio::{sync::Notify, task::JoinHandle};
use tokio_metrics::TaskMonitor;
use tokio_util::{sync::CancellationToken, task::TaskTracker};
/// Bookkeeping kept for every in-flight job: the job itself, its token, the
/// (swappable) handle of the task processing it, a Tokio task monitor, and an
/// `hdrhistogram` histogram (presumably timing metrics — TODO confirm).
type JobMeta<D, R, P> = (Job<D, R, P>, JobToken, TaskHandle, TaskMonitor, Histogram<u64>);
use crossbeam::atomic::AtomicCell;
use dashmap::DashMap;
/// Join handle of a spawned task that resolves to a `KioResult<()>`.
pub type Task = JoinHandle<KioResult<()>>;
/// Atomically swappable slot that may or may not hold a [`Task`].
pub type TaskHandle = ArcSwapOption<Task>;
/// Reference-counted [`TaskHandle`], so every clone observes the same slot.
pub type SharedTaskHandle = Arc<ArcSwapOption<Task>>;
/// Shared concurrent map from a `u64` key to the bookkeeping of an in-flight job.
pub type JobMap<D, R, P> = Arc<DashMap<u64, JobMeta<D, R, P>>>;
/// Tracker for the tasks spawned while processing jobs.
pub type ProcessingQueue = TaskTracker;
use derive_more::IsVariant;
pub use worker_opts::WorkerOpts;
/// Lifecycle state of a [`Worker`].
///
/// Stored in an `AtomicCell` and advanced from `Idle` to `Active` with a
/// compare-and-swap in `Worker::run`. The discriminants are spelled out because
/// the enum is `#[repr(u8)]`; they equal the implicit, declaration-order values.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, IsVariant)]
#[repr(u8)]
pub enum WorkerState {
    /// Set by `Worker::run` when the main loop is started.
    Active = 0,
    /// Initial state of a freshly constructed worker.
    #[default]
    Idle = 1,
    /// `Worker::run` refuses to start a worker in this state.
    Closed = 2,
}
#[cfg(feature = "tracing")]
use tracing::{debug, instrument, warn, Instrument, Span};
pub use worker_opts::MIN_DELAY_MS_LIMIT;
/// A job worker bound to a [`Queue`].
///
/// All shared state lives behind `Arc`s (or `Arc`-backed types), so clones of a
/// `Worker` observe and control the same underlying worker. The field order is
/// kept stable because the derived `Debug` output lists fields in declaration order.
#[derive(Clone, Debug)]
pub struct Worker<D, R, P, S> {
    /// Identifier of this worker, generated with `Uuid::new_v4` at construction.
    pub id: Uuid,
    /// Tracing span used as the parent of the main loop and of `close`.
    #[cfg(feature = "tracing")]
    resource_span: Span,
    /// The queue this worker was created from; event-listener methods delegate to it.
    queue: Arc<Queue<D, R, P, S>>,
    /// In-flight job bookkeeping, shared with the timers and the main loop.
    /// NOTE(review): presumably keyed by job id — confirm.
    jobs_in_progress: JobMap<D, R, P>,
    /// User-supplied job processor; excluded from `Debug` output.
    #[debug(skip)]
    processor: WorkerCallback<D, R, P, S>,
    /// Options the worker was built with (`WorkerOpts::default()` when none were given).
    pub opts: WorkerOpts,
    /// Cancelled by `close` to stop the main loop; also shared with the timers.
    cancellation_token: Arc<CancellationToken>,
    /// Current lifecycle state, shared with the timers and the main loop.
    pub state: Arc<AtomicCell<WorkerState>>,
    /// Tracker for the worker's spawned tasks; closed by `close`, and its `len()`
    /// is reported as the number of running tasks during shutdown.
    processing: ProcessingQueue,
    /// Flag shared between the timers and the main loop.
    /// NOTE(review): presumably pauses the delay-queue timers — confirm.
    timer_pauser: Arc<AtomicCell<bool>>,
    /// Delay-queue timer driver built at construction and cloned into the main loop.
    timers: DelayQueueTimer<D, R, P, S>,
    /// Counter handed to the main loop.
    /// NOTE(review): semantics not visible in this file — confirm against `main_loop`.
    block_until: Counter,
    /// Counter handed to the main loop; presumably the number of jobs currently running.
    active_job_count: Arc<AtomicCell<usize>>,
    /// The queue's `worker_notifier`, shared with the timers and the main loop.
    continue_notifier: Arc<Notify>,
    /// Handle of the spawned main-loop task: stored by `run`, waited on by `close`.
    main_task: SharedTaskHandle,
}
use crate::utils::processor_types;
use processor_types::Callback;
pub type WorkerCallback<D, R, P, S> = Callback<D, R, P, S>;
impl<
        D: Clone + DeserializeOwned + 'static + Send + Sync + Serialize,
        R: Clone + DeserializeOwned + 'static + Serialize + Send + Sync,
        P: Clone + DeserializeOwned + 'static + Send + Sync + Serialize,
        S: Clone + Store<D, R, P> + Send + 'static + Sync,
    > Worker<D, R, P, S>
{
    /// Creates a worker whose processor is a synchronous closure.
    ///
    /// `processor` is called with the shared store and a [`Job`] and returns
    /// `Result<R, E>`. `worker_opts` falls back to `WorkerOpts::default()` when
    /// `None`; if the resulting options have `autorun` set, the worker is
    /// started immediately via [`Worker::run`].
    ///
    /// # Errors
    ///
    /// Returns the error from [`Worker::run`] when `autorun` is enabled and the
    /// worker cannot be started.
    ///
    /// # Panics
    ///
    /// With `autorun` enabled, panics when called outside a Tokio runtime,
    /// because the main loop is started with `tokio::spawn`.
    #[track_caller]
    pub fn new_sync<C, E>(
        queue: &Queue<D, R, P, S>,
        processor: C,
        worker_opts: Option<WorkerOpts>,
    ) -> KioResult<Self>
    where
        KioError: From<E>,
        C: Fn(SharedStore<S>, Job<D, R, P>) -> Result<R, E> + Send + Sync + 'static,
        P: Send + Sync + 'static,
        R: Send + Sync + 'static,
        D: Send + Sync + 'static,
        S: Sync + Store<D, R, P> + Send + 'static,
        E: std::error::Error + Send + 'static,
    {
        Self::new::<C, SyncFn<C, D, R, P, S, E>, E>(queue, processor, worker_opts)
    }

    /// Creates a worker whose processor is an async closure.
    ///
    /// `processor` is called with the shared store and a [`Job`] and returns a
    /// future resolving to `Result<R, E>`. Option handling and `autorun` behave
    /// exactly as in [`Worker::new_sync`].
    ///
    /// # Errors
    ///
    /// Returns the error from [`Worker::run`] when `autorun` is enabled and the
    /// worker cannot be started.
    ///
    /// # Panics
    ///
    /// With `autorun` enabled, panics when called outside a Tokio runtime.
    #[track_caller]
    pub fn new_async<C, Fut, E>(
        queue: &Queue<D, R, P, S>,
        processor: C,
        worker_opts: Option<WorkerOpts>,
    ) -> KioResult<Self>
    where
        KioError: From<E>,
        C: Fn(SharedStore<S>, Job<D, R, P>) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<R, E>> + Send + 'static,
        P: Send + Sync + 'static,
        R: Send + Sync + 'static,
        S: Sync + Store<D, R, P> + Send + 'static,
        D: Send + Sync + 'static,
        E: std::error::Error + Send + 'static,
    {
        use processor_types::AsyncFn;
        Self::new::<C, AsyncFn<C, D, R, P, S, E>, E>(queue, processor, worker_opts)
    }

    /// Shared constructor behind [`Worker::new_sync`] and [`Worker::new_async`].
    ///
    /// Converts `processor` into a [`Callback`] and creates the shared state
    /// (lifecycle cell, cancellation token, in-progress job map, task tracker).
    /// It wires that state together with the queue's worker notifier into a
    /// [`DelayQueueTimer`]. The worker is started right away when `opts.autorun`
    /// is set.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`Worker::run`] when `autorun` is enabled.
    #[track_caller]
    fn new<C, F, E>(
        queue: &Queue<D, R, P, S>,
        processor: C,
        worker_opts: Option<WorkerOpts>,
    ) -> KioResult<Self>
    where
        KioError: From<E>,
        C: Into<F>,
        Callback<D, R, P, S>: From<F>,
        P: Send + Sync + 'static,
        R: Send + Sync + 'static,
        D: Send + Sync + 'static,
        S: Store<D, R, P> + Send + Sync + 'static,
        E: std::error::Error + Send + 'static,
    {
        let queue = Arc::new(queue.clone());
        let jobs_in_progress: JobMap<_, _, _> = Arc::new(DashMap::new());
        let f: F = processor.into();
        let callback = Callback::from(f);
        let id = Uuid::new_v4();
        let opts = worker_opts.unwrap_or_default();
        let cancellation_token: Arc<CancellationToken> = Arc::default();
        // Shared with the queue: `close` wakes waiters through `queue.worker_notifier`.
        let continue_notifier = queue.worker_notifier.clone();
        let state: Arc<AtomicCell<WorkerState>> = Arc::default();
        let timer_pauser: Arc<AtomicCell<bool>> = Arc::default();
        let processing = TaskTracker::new();
        let timers = DelayQueueTimer::new(
            jobs_in_progress.clone(),
            id,
            opts,
            queue.clone(),
            cancellation_token.clone(),
            state.clone(),
            continue_notifier.clone(),
            timer_pauser.clone(),
            processing.clone(),
        );
        #[cfg(feature = "tracing")]
        let resource_span = {
            let callback_type = match &callback {
                Callback::Async(_) => "Async",
                Callback::Sync(_) => "Sync",
            };
            // `#[track_caller]` on this fn and on the public constructors makes
            // this the user's call site rather than a location in this file.
            let location = std::panic::Location::caller().to_string();
            let queue_name = queue.name();
            let worker_type = format!(
                "{}-Worker({},{queue_name})",
                callback_type,
                id.as_u64_pair().0,
            );
            tracing::info_span!(parent: None, "", worker_type, ?location)
        };
        let main_task = Arc::default();
        let worker = Self {
            state,
            timer_pauser,
            main_task,
            #[cfg(feature = "tracing")]
            resource_span,
            timers,
            continue_notifier,
            block_until: Arc::default(),
            opts,
            id,
            queue,
            jobs_in_progress,
            processing,
            processor: callback,
            cancellation_token,
            active_job_count: Arc::default(),
        };
        if worker.opts.autorun {
            worker.run()?;
        }
        Ok(worker)
    }

    /// Returns `true` while the state is `Active` and the cancellation token
    /// has not been cancelled.
    #[must_use]
    pub fn is_running(&self) -> bool {
        self.state.load().is_active() && !self.cancellation_token.is_cancelled()
    }

    /// Returns `true` if the worker is in the `Idle` state (e.g. constructed
    /// but never started).
    #[must_use]
    pub fn is_idle(&self) -> bool {
        self.state.load().is_idle()
    }

    /// Starts the worker's main loop on the current Tokio runtime.
    ///
    /// The state moves from `Idle` to `Active` in a single compare-and-swap,
    /// so concurrent callers cannot start two main loops.
    ///
    /// # Errors
    ///
    /// - `WorkerError::WorkerAlreadyClosed` if the worker is `Closed` or its
    ///   cancellation token has been cancelled.
    /// - `WorkerError::WorkerAlreadyRunningWithId` if the worker is already
    ///   `Active` and has not been cancelled.
    ///
    /// # Panics
    ///
    /// Panics when called outside a Tokio runtime (`tokio::spawn`).
    pub fn run(&self) -> KioResult<()> {
        if let Err(current) = self
            .state
            .compare_exchange(WorkerState::Idle, WorkerState::Active)
        {
            // Any state other than `Idle` means the worker must not be started
            // again. Every such case returns an error, so no path can fall
            // through and spawn a second main loop.
            let error = if current.is_closed() || self.cancellation_token.is_cancelled() {
                WorkerError::WorkerAlreadyClosed(self.id)
            } else {
                WorkerError::WorkerAlreadyRunningWithId(self.id)
            };
            return Err(error.into());
        }
        #[cfg(not(feature = "tracing"))]
        let params = (
            self.id,
            self.cancellation_token.clone(),
            self.processing.clone(),
            self.opts,
            self.block_until.clone(),
            self.jobs_in_progress.clone(),
            self.active_job_count.clone(),
            self.processor.clone(),
            self.queue.clone(),
            self.state.clone(),
            self.continue_notifier.clone(),
            self.timers.clone(),
            self.timer_pauser.clone(),
        );
        #[cfg(feature = "tracing")]
        let params = (
            self.resource_span.clone(),
            self.id,
            self.cancellation_token.clone(),
            self.processing.clone(),
            self.opts,
            self.block_until.clone(),
            self.jobs_in_progress.clone(),
            self.active_job_count.clone(),
            self.processor.clone(),
            self.queue.clone(),
            self.state.clone(),
            self.continue_notifier.clone(),
            self.timers.clone(),
            self.timer_pauser.clone(),
        );
        #[cfg(feature = "tracing")]
        let main = main_loop(params).instrument(self.resource_span.clone());
        #[cfg(not(feature = "tracing"))]
        let main = main_loop(params);
        let main_task = tokio::spawn(main.boxed());
        self.main_task.swap(Some(main_task.into()));
        Ok(())
    }

    /// Returns `true` once the cancellation token is cancelled or the state is `Closed`.
    #[must_use]
    pub fn closed(&self) -> bool {
        self.cancellation_token.is_cancelled() || self.state.load().is_closed()
    }

    /// Stops a running worker and blocks until its main task has finished.
    ///
    /// Does nothing unless [`Worker::is_running`] is `true`. Otherwise it runs
    /// these steps in order:
    /// 1. Closes the task tracker.
    /// 2. Resumes the queue's workers, wakes waiters on the queue's worker
    ///    notifier and clears `pause_workers`.
    /// 3. Cancels the worker's token.
    /// 4. Waits for the main-loop task to finish.
    ///
    /// This method is synchronous and blocks the calling thread while it waits.
    /// NOTE(review): if it is called on the thread that drives a Tokio
    /// *current-thread* runtime, the main task cannot make progress while that
    /// thread is blocked here. Prefer calling it from a multi-thread runtime or
    /// from a non-runtime thread.
    #[cfg_attr(feature="tracing", instrument(parent = &self.resource_span, skip(self)))]
    pub fn close(&self) {
        if !self.is_running() {
            return;
        }
        #[cfg(feature = "tracing")]
        debug!(
            "cancel the worker's engine_loop, current_state: {:#?}",
            self.state.load()
        );
        self.processing.close();
        self.queue.resume_workers();
        self.queue.worker_notifier.notify_waiters();
        self.queue.pause_workers.store(false);
        self.cancellation_token.cancel();
        if let Some(handle) = self.main_task.load_full() {
            #[cfg(feature = "tracing")]
            {
                let running_tasks = self.processing.len();
                warn!("waiting for all {running_tasks} tasks to complete or abort");
            }
            // `close` cannot `.await` the handle, so it polls `is_finished`.
            // The previous empty loop body pinned a CPU core for the whole
            // shutdown. Yielding lets other threads, such as Tokio workers
            // finishing the main loop, run in the meantime.
            while !handle.is_finished() {
                std::thread::yield_now();
            }
        }
    }

    /// Registers `callback` for `event` by delegating to the worker's queue
    /// (no worker id is passed along).
    ///
    /// Returns the `Uuid` produced by the queue, presumably the id to pass to
    /// [`Worker::remove_event_listener`].
    pub fn on<F, C>(&self, event: JobState, callback: C) -> Uuid
    where
        C: Fn(EventParameters<R, P>) -> F + Send + Sync + 'static,
        F: Future<Output = ()> + Send + Sync + 'static,
    {
        self.queue.on(event, callback)
    }

    /// Registers `callback` for every event by delegating to the worker's queue.
    ///
    /// Returns the `Uuid` produced by the queue.
    pub fn on_all_events<F, C>(&self, callback: C) -> Uuid
    where
        C: Fn(EventParameters<R, P>) -> F + Send + Sync + 'static,
        F: Future<Output = ()> + Send + Sync + 'static,
    {
        self.queue.on_all_events(callback)
    }

    /// Removes the listener with `id` from the worker's queue, returning
    /// whatever the queue reports (`None` presumably means "not found").
    #[must_use]
    pub fn remove_event_listener(&self, id: Uuid) -> Option<Uuid> {
        self.queue.remove_event_listener(id)
    }
}