// apalis_core/worker/context.rs

//! Worker context and task tracking.
//!
//! [`WorkerContext`] is responsible for managing the execution lifecycle of a worker,
//! tracking tasks, handling shutdown, and emitting lifecycle events. It also provides
//! [`Tracked`] for wrapping and monitoring asynchronous tasks within the worker domain.
//!
//! ## Lifecycle
//! A `WorkerContext` goes through distinct phases of operation:
//!
//! - **Pending**: Created via [`WorkerContext::new`] and must be explicitly started.
//! - **Running**: Activated by calling [`WorkerContext::start`]. The worker becomes ready to accept and track tasks.
//! - **Paused**: Temporarily halted via [`WorkerContext::pause`]. New tasks are blocked from execution.
//! - **Resumed**: Brought back to `Running` using [`WorkerContext::resume`].
//! - **Stopped**: Finalized via [`WorkerContext::stop`]. The worker shuts down gracefully, allowing tracked tasks to complete.
//!
//! The `WorkerContext` itself implements [`Future`] and can be `.await`ed — it resolves
//! once the worker is shut down and all tracked tasks have completed.
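//!
//! For illustration, a minimal sketch of driving the lifecycle (the service type
//! `MyService` and the surrounding `?`-friendly async context are assumptions, not
//! part of this module):
//!
//! ```rust,ignore
//! let mut worker = WorkerContext::new::<MyService>("email-worker");
//! worker.start()?;  // Pending -> Running
//! worker.pause()?;  // Running -> Paused; new tasks are blocked
//! worker.resume()?; // Paused -> Running
//! worker.stop()?;   // begin a graceful shutdown
//! worker.await?;    // resolves once all tracked tasks have completed
//! ```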
//!
//! ## Task Management
//! Asynchronous tasks can be tracked using [`WorkerContext::track`], which wraps a future
//! in a [`Tracked`] type. This ensures:
//! - The task count is incremented before execution and decremented when the tracked
//!   future is dropped (on completion or cancellation)
//! - The worker is woken once the last task finishes, so a pending shutdown can complete
//!
//! Use [`task_count`](WorkerContext::task_count) and [`has_pending_tasks`](WorkerContext::has_pending_tasks) to inspect
//! ongoing task state.
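//!
//! A sketch of tracking a task (assumes an executor such as Tokio; `handle_job` is a
//! hypothetical async function):
//!
//! ```rust,ignore
//! // `track` increments the task count up front; the count is decremented
//! // when the returned `Tracked` future is dropped.
//! let tracked = ctx.track(handle_job());
//! tokio::spawn(tracked);
//! assert!(ctx.has_pending_tasks());
//! ```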
//!
//! ## Shutdown Semantics
//! The worker is considered shutting down if:
//! - `stop()` has been called
//! - A shutdown signal (if configured) has been triggered
//!
//! Once shutdown begins, no new tasks should be accepted. Internally, a stored [`Waker`] is
//! used to drive progress toward shutdown completion.
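//!
//! A polling loop would typically consult this state before picking up new work
//! (sketch only; fetching and tracking of tasks is elided):
//!
//! ```rust,ignore
//! while !ctx.is_shutting_down() {
//!     // fetch the next task and run it via `ctx.track(...)`
//! }
//! ctx.await?; // wait for in-flight tasks to drain
//! ```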
//!
//! ## Event Handling
//! Worker lifecycle events (e.g., `Start`, `Stop`) can be emitted using [`WorkerContext::emit`].
//! Custom handlers can be registered via [`WorkerContext::wrap_listener`] to hook into these transitions.
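//!
//! For example, a listener could be layered on like this (sketch; the body of the
//! closure is illustrative):
//!
//! ```rust,ignore
//! ctx.wrap_listener(|worker, _event| {
//!     // runs before the previously registered handler
//!     println!("worker {} emitted an event", worker.name());
//! });
//! ```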
//!
//! ## Request Integration
//! `WorkerContext` implements [`FromRequest`] so it can be extracted automatically in request
//! handlers when using a compatible framework or service layer.
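//!
//! A sketch of a task function that receives the context via extraction (the `Email`
//! argument type and the handler wiring are assumptions):
//!
//! ```rust,ignore
//! async fn send_email(job: Email, worker: WorkerContext) {
//!     if worker.is_shutting_down() {
//!         return; // skip new work during shutdown
//!     }
//!     // ... do the actual sending
//! }
//! ```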
//!
//! ## Types
//! - [`WorkerContext`] — shared state container for a worker
//! - [`Tracked`] — future wrapper for task lifecycle tracking
use std::{
    any::type_name,
    fmt,
    future::Future,
    pin::Pin,
    sync::{
        atomic::{AtomicBool, AtomicUsize, Ordering},
        Arc, Mutex,
    },
    task::{Context, Poll, Waker},
};

use crate::{
    error::{WorkerError, WorkerStateError},
    monitor::shutdown::Shutdown,
    task::{data::MissingDataError, Task},
    task_fn::FromRequest,
    worker::{
        event::{Event, EventHandler},
        state::{InnerWorkerState, WorkerState},
    },
};

/// Utility for managing a worker's context
///
/// A worker context is created for each worker thread and is responsible for managing
/// the worker's state, task tracking, and event handling.
///
/// **Tip**: All fields are wrapped in [`Arc`], so cloning is cheap.
#[derive(Clone)]
pub struct WorkerContext {
    pub(super) name: Arc<String>,
    task_count: Arc<AtomicUsize>,
    waker: Arc<Mutex<Option<Waker>>>,
    state: Arc<WorkerState>,
    pub(crate) shutdown: Option<Shutdown>,
    event_handler: EventHandler,
    pub(super) is_ready: Arc<AtomicBool>,
    pub(super) service: &'static str,
}

impl fmt::Debug for WorkerContext {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("WorkerContext")
            .field("shutdown", &["Shutdown handle"])
            .field("task_count", &self.task_count)
            .field("state", &self.state.load(Ordering::SeqCst))
            .field("service", &self.service)
            .field("is_ready", &self.is_ready)
            .finish()
    }
}

/// A future tracked by the worker
#[pin_project::pin_project(PinnedDrop)]
#[derive(Debug)]
pub struct Tracked<F> {
    ctx: WorkerContext,
    #[pin]
    task: F,
}

impl<F: Future> Future for Tracked<F> {
    type Output = F::Output;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<F::Output> {
        let this = self.project();
        // Forward to the inner future; task-count bookkeeping happens in the
        // `PinnedDrop` impl below, which runs on completion or cancellation.
        this.task.poll(cx)
    }
}

#[pin_project::pinned_drop]
impl<F> PinnedDrop for Tracked<F> {
    fn drop(self: Pin<&mut Self>) {
        // Decrement the task count; if this was the last task, the worker
        // is woken so a pending shutdown can complete.
        self.ctx.end_task();
    }
}

impl WorkerContext {
    /// Create a new worker context for a service of type `S`
    pub fn new<S>(name: &str) -> Self {
        Self {
            name: Arc::new(name.to_owned()),
            service: type_name::<S>(),
            task_count: Default::default(),
            waker: Default::default(),
            state: Default::default(),
            shutdown: Default::default(),
            event_handler: Arc::new(Box::new(|_, _| {
                // noop
            })),
            is_ready: Default::default(),
        }
    }

    /// Get the worker name
    pub fn name(&self) -> &String {
        &self.name
    }

    /// Start running the worker
    pub fn start(&mut self) -> Result<(), WorkerError> {
        let current_state = self.state.load(Ordering::SeqCst);
        if current_state != InnerWorkerState::Pending {
            return Err(WorkerError::StateError(WorkerStateError::AlreadyStarted));
        }
        self.state
            .store(InnerWorkerState::Running, Ordering::SeqCst);
        self.is_ready.store(false, Ordering::SeqCst);
        self.emit(&Event::Start);
        info!("Worker {} started", self.name());
        Ok(())
    }

    /// Reset the worker back to `Pending` so it can be started again
    pub fn restart(&mut self) -> Result<(), WorkerError> {
        self.state
            .store(InnerWorkerState::Pending, Ordering::SeqCst);
        self.is_ready.store(false, Ordering::SeqCst);
        info!("Worker {} restarted", self.name());
        Ok(())
    }

    /// Start a task that is tracked by the worker
    pub fn track<F: Future>(&self, task: F) -> Tracked<F> {
        self.start_task();
        Tracked {
            ctx: self.clone(),
            task,
        }
    }

    /// Pauses a worker, preventing any new jobs from being polled
    pub fn pause(&self) -> Result<(), WorkerError> {
        if !self.is_running() {
            return Err(WorkerError::StateError(WorkerStateError::NotRunning));
        }
        self.state.store(InnerWorkerState::Paused, Ordering::SeqCst);
        info!("Worker {} paused", self.name());
        Ok(())
    }

    /// Resume a worker that is paused
    pub fn resume(&self) -> Result<(), WorkerError> {
        if !self.is_paused() {
            return Err(WorkerError::StateError(WorkerStateError::NotPaused));
        }
        if self.is_shutting_down() {
            return Err(WorkerError::StateError(WorkerStateError::ShuttingDown));
        }
        self.state
            .store(InnerWorkerState::Running, Ordering::SeqCst);
        info!("Worker {} resumed", self.name());
        Ok(())
    }

    /// Trigger a graceful shutdown: the worker stops accepting new tasks, while
    /// already-tracked tasks are allowed to complete
    pub fn stop(&self) -> Result<(), WorkerError> {
        let current_state = self.state.load(Ordering::SeqCst);
        if current_state == InnerWorkerState::Pending {
            return Err(WorkerError::StateError(WorkerStateError::NotStarted));
        }
        self.state
            .store(InnerWorkerState::Stopped, Ordering::SeqCst);
        self.wake();
        self.emit_ref(&Event::Stop);
        info!("Worker {} stopped", self.name());
        Ok(())
    }

    /// Returns whether the worker is ready to consume new tasks
    pub fn is_ready(&self) -> bool {
        self.is_running() && !self.is_shutting_down() && self.is_ready.load(Ordering::SeqCst)
    }

    /// Get the type name of the service
    pub fn get_service(&self) -> &str {
        self.service
    }

    /// Returns whether the worker is running
    pub fn is_running(&self) -> bool {
        self.state.load(Ordering::SeqCst) == InnerWorkerState::Running
    }

    /// Returns whether the worker is pending
    pub fn is_pending(&self) -> bool {
        self.state.load(Ordering::SeqCst) == InnerWorkerState::Pending
    }

    /// Returns whether the worker is paused
    pub fn is_paused(&self) -> bool {
        self.state.load(Ordering::SeqCst) == InnerWorkerState::Paused
    }

    /// Returns the number of futures currently tracked by the worker.
    /// This includes futures spawned via [`WorkerContext::track`]
    pub fn task_count(&self) -> usize {
        self.task_count.load(Ordering::Relaxed)
    }

    /// Returns whether the worker has pending tasks
    pub fn has_pending_tasks(&self) -> bool {
        self.task_count.load(Ordering::Relaxed) > 0
    }

    /// Returns whether the worker is shutting down
    pub fn is_shutting_down(&self) -> bool {
        // A worker that is not running counts as shutting down; when a shutdown
        // signal is configured, it is consulted as well.
        self.shutdown
            .as_ref()
            .map(|s| !self.is_running() || s.is_shutting_down())
            .unwrap_or(!self.is_running())
    }

    /// Allows workers to emit events
    pub fn emit(&mut self, event: &Event) {
        self.emit_ref(event);
    }

    fn emit_ref(&self, event: &Event) {
        let handler = self.event_handler.as_ref();
        handler(self, event);
    }

    /// Wraps the current event listener with a new handler.
    /// The new handler runs first, followed by the previously registered one.
    pub fn wrap_listener<F: Fn(&WorkerContext, &Event) + Send + Sync + 'static>(&mut self, f: F) {
        let cur = self.event_handler.clone();
        let new: Box<dyn Fn(&WorkerContext, &Event) + Send + Sync + 'static> =
            Box::new(move |ctx, ev| {
                f(ctx, ev);
                cur(ctx, ev);
            });
        self.event_handler = Arc::new(new);
    }

    pub(crate) fn add_waker(&self, cx: &mut Context<'_>) {
        if let Ok(mut waker_guard) = self.waker.lock() {
            // Only clone the waker if the stored one would not wake the current task.
            if waker_guard
                .as_ref()
                .map_or(true, |stored_waker| !stored_waker.will_wake(cx.waker()))
            {
                *waker_guard = Some(cx.waker().clone());
            }
        }
    }

    /// Checks if the stored waker matches the current one.
    fn has_recent_waker(&self, cx: &Context<'_>) -> bool {
        if let Ok(waker_guard) = self.waker.lock() {
            if let Some(stored_waker) = &*waker_guard {
                return stored_waker.will_wake(cx.waker());
            }
        }
        false
    }

    fn start_task(&self) {
        self.task_count.fetch_add(1, Ordering::Relaxed);
    }

    fn end_task(&self) {
        // `fetch_sub` returns the previous count; 1 means this was the last task.
        if self.task_count.fetch_sub(1, Ordering::Relaxed) == 1 {
            self.wake();
        }
    }

    pub(crate) fn wake(&self) {
        if let Ok(waker) = self.waker.lock() {
            if let Some(waker) = &*waker {
                waker.wake_by_ref();
            }
        }
    }
}

impl Future for WorkerContext {
    type Output = Result<(), WorkerError>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let task_count = self.task_count.load(Ordering::Relaxed);
        let state = self.state.load(Ordering::SeqCst);
        if state == InnerWorkerState::Pending {
            return Poll::Ready(Err(WorkerError::StateError(WorkerStateError::NotStarted)));
        }
        if self.is_shutting_down() && task_count == 0 {
            Poll::Ready(Ok(()))
        } else {
            if !self.has_recent_waker(cx) {
                self.add_waker(cx);
            }
            Poll::Pending
        }
    }
}

impl<Args: Sync, Ctx: Sync, IdType: Sync + Send> FromRequest<Task<Args, Ctx, IdType>>
    for WorkerContext
{
    type Error = MissingDataError;
    async fn from_request(task: &Task<Args, Ctx, IdType>) -> Result<Self, Self::Error> {
        task.parts.data.get_checked().cloned()
    }
}

impl Drop for WorkerContext {
    fn drop(&mut self) {
        if Arc::strong_count(&self.state) > 1 {
            // There are still other references to this context, so we shouldn't log a warning.
            return;
        }
        if self.is_running() {
            eprintln!(
                "Worker '{}' is being dropped while running with `{}` tasks. Consider calling stop() before dropping.",
                self.name(),
                self.task_count()
            );
        }
    }
}