apalis_core/worker/
context.rs

//! Worker context and task tracking.
//!
//! [`WorkerContext`] is responsible for managing
//! the execution lifecycle of a worker, tracking tasks, handling shutdown, and emitting
//! lifecycle events. It also provides [`Tracked`] for wrapping and monitoring asynchronous
//! tasks within the worker domain.
//!
//! ## Lifecycle
//! A `WorkerContext` goes through distinct phases of operation:
//!
//! - **Pending**: Created via [`WorkerContext::new`] and must be explicitly started.
//! - **Running**: Activated by calling [`WorkerContext::start`]. The worker becomes ready to accept and track tasks.
//! - **Paused**: Temporarily halted via [`WorkerContext::pause`]. New tasks are blocked from execution.
//! - **Resumed**: Brought back to `Running` using [`WorkerContext::resume`].
//! - **Stopped**: Finalized via [`WorkerContext::stop`]. The worker shuts down gracefully, allowing tracked tasks to complete.
//!
//! The `WorkerContext` itself implements [`Future`], and can be `.await`ed — it resolves
//! once the worker is shut down and all tasks have completed.
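//!
//! A minimal sketch of the lifecycle (not compiled here; `MyService` stands in for
//! the worker's service type and the import paths are assumed):
//!
//! ```rust,ignore
//! use apalis_core::{error::WorkerError, worker::context::WorkerContext};
//!
//! async fn run_worker() -> Result<(), WorkerError> {
//!     let mut ctx = WorkerContext::new::<MyService>("email-worker");
//!     ctx.start()?;  // Pending -> Running
//!     ctx.pause()?;  // Running -> Paused
//!     ctx.resume()?; // Paused  -> Running
//!     ctx.stop()?;   // request a graceful shutdown
//!     ctx.await      // resolves once all tracked tasks have completed
//! }
//! ```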
//!
//! ## Task Management
//! Asynchronous tasks can be tracked using [`WorkerContext::track`], which wraps a future
//! in a [`Tracked`] type. This ensures:
//! - Task count is incremented before execution and decremented on completion
//! - The worker is woken when the last task completes, allowing a pending shutdown to finish
//!
//! Use [`task_count`](WorkerContext::task_count) and [`has_pending_tasks`](WorkerContext::has_pending_tasks) to inspect
//! ongoing task state.
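//!
//! A short sketch of task tracking (assuming a running `ctx: WorkerContext`):
//!
//! ```rust,ignore
//! // The task count is incremented as soon as `track` is called...
//! let tracked = ctx.track(async {
//!     // ... some unit of work ...
//! });
//! assert!(ctx.has_pending_tasks());
//!
//! // ...and decremented once the tracked future completes and is dropped.
//! tracked.await;
//! ```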
//!
//! ## Shutdown Semantics
//! The worker is considered shutting down if:
//! - `stop()` has been called
//! - A shutdown signal (if configured) has been triggered
//!
//! Once shutdown begins, no new tasks should be accepted. Internally, a stored [`Waker`] is
//! used to drive progress toward shutdown completion.
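//!
//! A sketch of a polling loop that respects these semantics (assuming a running
//! `ctx: WorkerContext` that is polling some backend):
//!
//! ```rust,ignore
//! while !ctx.is_shutting_down() {
//!     // fetch new work and hand it to `ctx.track(...)` here
//! }
//! // Shutdown has begun: accept no more tasks and wait for in-flight ones to
//! // drain. Awaiting the context resolves once `task_count()` reaches zero.
//! ctx.await?;
//! ```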
//!
//! ## Event Handling
//! Worker lifecycle events (e.g., `Start`, `Stop`) can be emitted using [`WorkerContext::emit`].
//! Custom handlers can be registered via [`WorkerContext::wrap_listener`] to hook into these transitions.
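//!
//! A sketch of a custom listener (the wildcard arm covers any other event variants):
//!
//! ```rust,ignore
//! ctx.wrap_listener(|worker, event| match event {
//!     Event::Start => println!("worker {} started", worker.name()),
//!     Event::Stop => println!("worker {} stopped", worker.name()),
//!     _ => {}
//! });
//! ```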
//!
//! ## Request Integration
//! `WorkerContext` implements [`FromRequest`] so it can be extracted automatically in request
//! handlers when using a compatible framework or service layer.
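//!
//! A sketch of extraction in a handler (`Email`, `Error`, and the handler wiring are
//! hypothetical; the exact signature depends on the task-function layer in use):
//!
//! ```rust,ignore
//! async fn send_email(job: Email, worker: WorkerContext) -> Result<(), Error> {
//!     if worker.is_shutting_down() {
//!         // e.g. skip long-running work so shutdown is not delayed
//!     }
//!     // ... send the email ...
//!     Ok(())
//! }
//! ```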
//!
//! ## Types
//! - [`WorkerContext`] — shared state container for a worker
//! - [`Tracked`] — future wrapper for task lifecycle tracking
use std::{
    any::type_name,
    fmt,
    future::Future,
    pin::Pin,
    sync::{
        Arc, Mutex,
        atomic::{AtomicBool, AtomicUsize, Ordering},
    },
    task::{Context, Poll, Waker},
};

// `info!`/`error!` are used below for lifecycle logging; this assumes they are
// provided by the crate's `tracing` dependency.
use tracing::{error, info};

use crate::{
    error::{WorkerError, WorkerStateError},
    monitor::shutdown::Shutdown,
    task::{Task, data::MissingDataError},
    task_fn::FromRequest,
    worker::{
        event::{Event, EventListener, RawEventListener},
        state::{InnerWorkerState, WorkerState},
    },
};

/// Utility for managing a worker's context
///
/// A worker context is created for each worker thread and is responsible for managing
/// the worker's state, task tracking, and event handling.
///
/// **Tip**: All fields are wrapped inside [`Arc`] so it should be cheap to clone
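///
/// For example (a sketch; `MyService` stands in for the service type):
///
/// ```rust,ignore
/// let mut a = WorkerContext::new::<MyService>("worker-1");
/// let b = a.clone();
/// a.start()?;
/// assert!(b.is_running()); // clones share state through the inner `Arc`s
/// ```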
#[derive(Clone)]
pub struct WorkerContext {
    pub(super) name: Arc<String>,
    task_count: Arc<AtomicUsize>,
    waker: Arc<Mutex<Option<Waker>>>,
    state: Arc<WorkerState>,
    pub(crate) shutdown: Option<Shutdown>,
    event_handler: EventListener,
    pub(super) is_ready: Arc<AtomicBool>,
    pub(super) service: &'static str,
}

impl fmt::Debug for WorkerContext {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("WorkerContext")
            .field("shutdown", &["Shutdown handle"])
            .field("task_count", &self.task_count)
            .field("state", &self.state.load(Ordering::SeqCst))
            .field("service", &self.service)
            .field("is_ready", &self.is_ready)
            .finish()
    }
}

/// A future tracked by the worker
#[pin_project::pin_project(PinnedDrop)]
#[derive(Debug)]
pub struct Tracked<F> {
    ctx: WorkerContext,
    #[pin]
    task: F,
}

impl<F: Future> Future for Tracked<F> {
    type Output = F::Output;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<F::Output> {
        let this = self.project();

        match this.task.poll(cx) {
            res @ Poll::Ready(_) => res,
            Poll::Pending => Poll::Pending,
        }
    }
}

#[pin_project::pinned_drop]
impl<F> PinnedDrop for Tracked<F> {
    fn drop(self: Pin<&mut Self>) {
        self.ctx.end_task();
    }
}

impl WorkerContext {
    /// Create a new worker context
    #[must_use]
    pub fn new<S>(name: &str) -> Self {
        Self {
            name: Arc::new(name.to_owned()),
            service: type_name::<S>(),
            task_count: Default::default(),
            waker: Default::default(),
            state: Default::default(),
            shutdown: Default::default(),
            event_handler: Arc::new(Box::new(|_, _| {
                // noop
            })),
            is_ready: Default::default(),
        }
    }

    /// Get the worker name
    #[must_use]
    pub fn name(&self) -> &String {
        &self.name
    }

    /// Start running the worker
    pub fn start(&mut self) -> Result<(), WorkerError> {
        let current_state = self.state.load(Ordering::SeqCst);
        if current_state != InnerWorkerState::Pending {
            return Err(WorkerError::StateError(WorkerStateError::AlreadyStarted));
        }
        self.state
            .store(InnerWorkerState::Running, Ordering::SeqCst);
        self.is_ready.store(false, Ordering::SeqCst);
        self.emit(&Event::Start);
        info!("Worker {} started", self.name());
        Ok(())
    }

    /// Restart the worker by resetting it to the `Pending` state
    pub fn restart(&mut self) -> Result<(), WorkerError> {
        self.state
            .store(InnerWorkerState::Pending, Ordering::SeqCst);
        self.is_ready.store(false, Ordering::SeqCst);
        info!("Worker {} restarted", self.name());
        Ok(())
    }

    /// Track a future so it counts toward the worker's active tasks
    pub fn track<F: Future>(&self, task: F) -> Tracked<F> {
        self.start_task();
        Tracked {
            ctx: self.clone(),
            task,
        }
    }

    /// Pauses a worker, preventing any new jobs from being polled
    pub fn pause(&self) -> Result<(), WorkerError> {
        if !self.is_running() {
            return Err(WorkerError::StateError(WorkerStateError::NotRunning));
        }
        self.state.store(InnerWorkerState::Paused, Ordering::SeqCst);
        info!("Worker {} paused", self.name());
        Ok(())
    }

    /// Resume a worker that is paused
    pub fn resume(&self) -> Result<(), WorkerError> {
        if !self.is_paused() {
            return Err(WorkerError::StateError(WorkerStateError::NotPaused));
        }
        if self.is_shutting_down() {
            return Err(WorkerError::StateError(WorkerStateError::ShuttingDown));
        }
        self.state
            .store(InnerWorkerState::Running, Ordering::SeqCst);
        info!("Worker {} resumed", self.name());
        Ok(())
    }

    /// Trigger a graceful shutdown of the worker, allowing tracked tasks to complete
    pub fn stop(&self) -> Result<(), WorkerError> {
        let current_state = self.state.load(Ordering::SeqCst);
        if current_state == InnerWorkerState::Pending {
            return Err(WorkerError::StateError(WorkerStateError::NotStarted));
        }
        self.state
            .store(InnerWorkerState::Stopped, Ordering::SeqCst);
        self.wake();
        self.emit_ref(&Event::Stop);
        info!("Worker {} stopped", self.name());
        Ok(())
    }

    /// Checks if the worker is ready to consume new tasks
    #[must_use]
    pub fn is_ready(&self) -> bool {
        self.is_running() && !self.is_shutting_down() && self.is_ready.load(Ordering::SeqCst)
    }

    /// Get the type name of the service
    #[must_use]
    pub fn get_service(&self) -> &str {
        self.service
    }

    /// Checks whether the worker is running
    #[must_use]
    pub fn is_running(&self) -> bool {
        self.state.load(Ordering::SeqCst) == InnerWorkerState::Running
    }

    /// Checks whether the worker is pending
    #[must_use]
    pub fn is_pending(&self) -> bool {
        self.state.load(Ordering::SeqCst) == InnerWorkerState::Pending
    }

    /// Checks whether the worker is paused
    #[must_use]
    pub fn is_paused(&self) -> bool {
        self.state.load(Ordering::SeqCst) == InnerWorkerState::Paused
    }

    /// Returns the number of futures currently tracked by the worker.
    /// This includes futures spawned via `worker.track`
    #[must_use]
    pub fn task_count(&self) -> usize {
        self.task_count.load(Ordering::Relaxed)
    }

    /// Checks whether the worker has pending tasks
    #[must_use]
    pub fn has_pending_tasks(&self) -> bool {
        self.task_count.load(Ordering::Relaxed) > 0
    }

    /// Checks whether the worker is shutting down, either because it is no longer
    /// running or because the shutdown signal (if configured) has been triggered
    #[must_use]
    pub fn is_shutting_down(&self) -> bool {
        self.shutdown
            .as_ref()
            .map(|s| !self.is_running() || s.is_shutting_down())
            .unwrap_or(!self.is_running())
    }

    /// Allows workers to emit events
    pub fn emit(&mut self, event: &Event) {
        self.emit_ref(event);
    }

    fn emit_ref(&self, event: &Event) {
        let handler = self.event_handler.as_ref();
        handler(self, event);
    }

    /// Wraps the event listener with a new function
    pub fn wrap_listener<F: Fn(&Self, &Event) + Send + Sync + 'static>(&mut self, f: F) {
        let cur = self.event_handler.clone();
        let new: RawEventListener = Box::new(move |ctx, ev| {
            f(ctx, ev);
            cur(ctx, ev);
        });
        self.event_handler = Arc::new(new);
    }

    pub(crate) fn add_waker(&self, cx: &Context<'_>) {
        if let Ok(mut waker_guard) = self.waker.lock() {
            if waker_guard
                .as_ref()
                .map_or(true, |stored_waker| !stored_waker.will_wake(cx.waker()))
            {
                *waker_guard = Some(cx.waker().clone());
            }
        }
    }

    /// Checks if the stored waker matches the current one.
    fn has_recent_waker(&self, cx: &Context<'_>) -> bool {
        if let Ok(waker_guard) = self.waker.lock() {
            if let Some(stored_waker) = &*waker_guard {
                return stored_waker.will_wake(cx.waker());
            }
        }
        false
    }

    fn start_task(&self) {
        self.task_count.fetch_add(1, Ordering::Relaxed);
    }

    fn end_task(&self) {
        if self.task_count.fetch_sub(1, Ordering::Relaxed) == 1 {
            self.wake();
        }
    }

    pub(crate) fn wake(&self) {
        if let Ok(waker) = self.waker.lock() {
            if let Some(waker) = &*waker {
                waker.wake_by_ref();
            }
        }
    }
}

impl Future for WorkerContext {
    type Output = Result<(), WorkerError>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let task_count = self.task_count.load(Ordering::Relaxed);
        let state = self.state.load(Ordering::SeqCst);
        if state == InnerWorkerState::Pending {
            return Poll::Ready(Err(WorkerError::StateError(WorkerStateError::NotStarted)));
        }
        if self.is_shutting_down() && task_count == 0 {
            Poll::Ready(Ok(()))
        } else {
            if !self.has_recent_waker(cx) {
                self.add_waker(cx);
            }
            Poll::Pending
        }
    }
}

impl<Args: Sync, Ctx: Sync, IdType: Sync + Send> FromRequest<Task<Args, Ctx, IdType>>
    for WorkerContext
{
    type Error = MissingDataError;
    async fn from_request(task: &Task<Args, Ctx, IdType>) -> Result<Self, Self::Error> {
        task.parts.data.get_checked().cloned()
    }
}

impl Drop for WorkerContext {
    fn drop(&mut self) {
        if Arc::strong_count(&self.state) > 1 {
            // There are still other references to this context, so we shouldn't log a warning.
            return;
        }
        if self.is_running() {
            error!(
                "Worker '{}' is being dropped while running with `{}` tasks. Consider calling stop() before dropping.",
                self.name(),
                self.task_count()
            );
        }
    }
}