// apalis_core/worker/context.rs

//! Worker context and task tracking.
//!
//! [`WorkerContext`] is responsible for managing
//! the execution lifecycle of a worker, tracking tasks, handling shutdown, and emitting
//! lifecycle events. It also provides [`Tracked`] for wrapping and monitoring asynchronous
//! tasks within the worker domain.
//!
//! ## Lifecycle
//! A `WorkerContext` goes through distinct phases of operation:
//!
//! - **Pending**: Created via [`WorkerContext::new`] and must be explicitly started.
//! - **Running**: Activated by calling [`WorkerContext::start`]. The worker becomes ready to accept and track tasks.
//! - **Paused**: Temporarily halted via [`WorkerContext::pause`]. New tasks are blocked from execution.
//! - **Resumed**: Brought back to `Running` using [`WorkerContext::resume`].
//! - **Stopped**: Finalized via [`WorkerContext::stop`]. The worker shuts down gracefully, allowing tracked tasks to complete.
//!
//! The `WorkerContext` itself implements [`Future`], and can be `.await`ed: it resolves
//! once the worker is shut down and all tasks have completed.
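//!
//! A minimal sketch of driving the lifecycle by hand (the module path in the `use`
//! lines and the `()` stand-in for the service type are assumptions for illustration):
//!
//! ```ignore
//! use apalis_core::error::WorkerError;
//! use apalis_core::worker::context::WorkerContext;
//!
//! async fn drive_lifecycle() -> Result<(), WorkerError> {
//!     let mut ctx = WorkerContext::new::<()>("email-worker");
//!     ctx.start()?;      // Pending -> Running
//!     ctx.pause()?;      // Running -> Paused: no new tasks are polled
//!     ctx.resume()?;     // Paused -> Running
//!     ctx.stop()?;       // request a graceful shutdown
//!     ctx.clone().await  // resolves once all tracked tasks have finished
//! }
//! ```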
//!
//! ## Task Management
//! Asynchronous tasks can be tracked using [`WorkerContext::track`], which wraps a future
//! in a [`Tracked`] type. This ensures:
//! - The task count is incremented before execution and decremented on completion
//! - The worker is woken once the last tracked task completes, so a pending shutdown can finish
//!
//! Use [`task_count`](WorkerContext::task_count) and [`has_pending_tasks`](WorkerContext::has_pending_tasks) to inspect
//! ongoing task state.
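//!
//! A sketch of wrapping a future so the worker accounts for it (module path assumed):
//!
//! ```ignore
//! use apalis_core::worker::context::WorkerContext;
//!
//! async fn track_one(ctx: &WorkerContext) {
//!     // Wrap the future so the worker's task count reflects it.
//!     let tracked = ctx.track(async {
//!         // ... do some work ...
//!     });
//!     assert!(ctx.has_pending_tasks());
//!
//!     // The count is decremented when the tracked future completes and is dropped.
//!     tracked.await;
//!     assert_eq!(ctx.task_count(), 0);
//! }
//! ```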
//!
//! ## Shutdown Semantics
//! The worker is considered shutting down if:
//! - `stop()` has been called
//! - A shutdown signal (if configured) has been triggered
//!
//! Once shutdown begins, no new tasks should be accepted. Internally, a stored [`Waker`] is
//! used to drive progress toward shutdown completion.
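//!
//! A sketch of a polling loop that respects shutdown (the sleep call assumes a Tokio
//! runtime and is purely illustrative):
//!
//! ```ignore
//! use std::time::Duration;
//! use apalis_core::worker::context::WorkerContext;
//!
//! async fn poll_until_shutdown(ctx: WorkerContext) {
//!     while !ctx.is_shutting_down() {
//!         if ctx.is_ready() {
//!             // fetch the next task and run it via `ctx.track(...)`
//!         }
//!         tokio::time::sleep(Duration::from_millis(100)).await;
//!     }
//!     // No new work is accepted past this point; tracked tasks finish on their own.
//! }
//! ```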
//!
//! ## Event Handling
//! Worker lifecycle events (e.g., `Start`, `Stop`) can be emitted using [`WorkerContext::emit`].
//! Custom handlers can be registered via [`WorkerContext::wrap_listener`] to hook into these transitions.
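//!
//! A sketch of layering a listener on top of the existing handler (the `println!`
//! body is illustrative):
//!
//! ```ignore
//! use apalis_core::worker::context::WorkerContext;
//! use apalis_core::worker::event::Event;
//!
//! fn install_listener(ctx: &mut WorkerContext) {
//!     // The new closure runs before the previously registered handler.
//!     ctx.wrap_listener(|worker, event| {
//!         if matches!(event, Event::Start | Event::Stop) {
//!             println!("worker {} changed state", worker.name());
//!         }
//!     });
//! }
//! ```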
//!
//! ## Request Integration
//! `WorkerContext` implements [`FromRequest`] so it can be extracted automatically in request
//! handlers when using a compatible framework or service layer.
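//!
//! A sketch of a handler receiving the context as an extractor argument (this assumes a
//! `task_fn`-style service layer that resolves extra arguments via [`FromRequest`]; the
//! `u32` job type is illustrative):
//!
//! ```ignore
//! use apalis_core::worker::context::WorkerContext;
//!
//! async fn send_email(job: u32, worker: WorkerContext) {
//!     // `WorkerContext` is cloned out of the task's data by its `FromRequest` impl.
//!     println!("handling job {} on worker {}", job, worker.name());
//! }
//! ```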
//!
//! ## Types
//! - [`WorkerContext`] — shared state container for a worker
//! - [`Tracked`] — future wrapper for task lifecycle tracking
use std::{
    any::type_name,
    fmt,
    future::Future,
    pin::Pin,
    sync::{
        Arc, Mutex,
        atomic::{AtomicBool, AtomicUsize, Ordering},
    },
    task::{Context, Poll, Waker},
};

use crate::{
    error::{WorkerError, WorkerStateError},
    monitor::shutdown::Shutdown,
    task::{Task, data::MissingDataError},
    task_fn::FromRequest,
    worker::{
        event::{Event, EventListener, RawEventListener},
        state::{InnerWorkerState, WorkerState},
    },
};

/// Utility for managing a worker's context
///
/// A worker context is created for each worker thread and is responsible for managing
/// the worker's state, task tracking, and event handling.
///
/// **Tip**: All fields are wrapped inside [`Arc`], so it should be cheap to clone
#[derive(Clone)]
pub struct WorkerContext {
    pub(super) name: Arc<String>,
    task_count: Arc<AtomicUsize>,
    waker: Arc<Mutex<Option<Waker>>>,
    state: Arc<WorkerState>,
    pub(crate) shutdown: Option<Shutdown>,
    event_handler: EventListener,
    pub(super) is_ready: Arc<AtomicBool>,
    pub(super) service: &'static str,
}

impl fmt::Debug for WorkerContext {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("WorkerContext")
            .field("shutdown", &["Shutdown handle"])
            .field("task_count", &self.task_count)
            .field("state", &self.state.load(Ordering::SeqCst))
            .field("service", &self.service)
            .field("is_ready", &self.is_ready)
            .finish()
    }
}

/// A future tracked by the worker
#[pin_project::pin_project(PinnedDrop)]
#[derive(Debug)]
pub struct Tracked<F> {
    ctx: WorkerContext,
    #[pin]
    task: F,
}

impl<F: Future> Future for Tracked<F> {
    type Output = F::Output;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<F::Output> {
        let this = self.project();

        match this.task.poll(cx) {
            res @ Poll::Ready(_) => res,
            Poll::Pending => Poll::Pending,
        }
    }
}

#[pin_project::pinned_drop]
impl<F> PinnedDrop for Tracked<F> {
    fn drop(self: Pin<&mut Self>) {
        self.ctx.end_task();
    }
}

impl WorkerContext {
    /// Create a new worker context
    #[must_use]
    pub fn new<S>(name: &str) -> Self {
        Self {
            name: Arc::new(name.to_owned()),
            service: type_name::<S>(),
            task_count: Default::default(),
            waker: Default::default(),
            state: Default::default(),
            shutdown: Default::default(),
            event_handler: Arc::new(Box::new(|_, _| {
                // noop
            })),
            is_ready: Default::default(),
        }
    }

    /// Get the worker name
    #[must_use]
    pub fn name(&self) -> &String {
        &self.name
    }

    /// Start running the worker
    pub fn start(&mut self) -> Result<(), WorkerError> {
        let current_state = self.state.load(Ordering::SeqCst);
        if current_state != InnerWorkerState::Pending {
            return Err(WorkerError::StateError(WorkerStateError::AlreadyStarted));
        }
        self.state
            .store(InnerWorkerState::Running, Ordering::SeqCst);
        self.is_ready.store(false, Ordering::SeqCst);
        self.emit(&Event::Start);
        info!("Worker {} started", self.name());
        Ok(())
    }

    /// Restart running the worker
    pub fn restart(&mut self) -> Result<(), WorkerError> {
        self.state
            .store(InnerWorkerState::Pending, Ordering::SeqCst);
        self.is_ready.store(false, Ordering::SeqCst);
        info!("Worker {} restarted", self.name());
        Ok(())
    }

    /// Start a task that is tracked by the worker
    pub fn track<F: Future>(&self, task: F) -> Tracked<F> {
        self.start_task();
        Tracked {
            ctx: self.clone(),
            task,
        }
    }

    /// Pauses a worker, preventing any new jobs from being polled
    pub fn pause(&self) -> Result<(), WorkerError> {
        if !self.is_running() {
            return Err(WorkerError::StateError(WorkerStateError::NotRunning));
        }
        self.state.store(InnerWorkerState::Paused, Ordering::SeqCst);
        info!("Worker {} paused", self.name());
        Ok(())
    }

    /// Resume a worker that is paused
    pub fn resume(&self) -> Result<(), WorkerError> {
        if !self.is_paused() {
            return Err(WorkerError::StateError(WorkerStateError::NotPaused));
        }
        if self.is_shutting_down() {
            return Err(WorkerError::StateError(WorkerStateError::ShuttingDown));
        }
        self.state
            .store(InnerWorkerState::Running, Ordering::SeqCst);
        self.wake();
        info!("Worker {} resumed", self.name());
        Ok(())
    }

    /// Triggers a graceful shutdown of the worker, allowing any tracked tasks to complete
    pub fn stop(&self) -> Result<(), WorkerError> {
        let current_state = self.state.load(Ordering::SeqCst);
        if current_state == InnerWorkerState::Pending {
            return Err(WorkerError::StateError(WorkerStateError::NotStarted));
        }
        self.state
            .store(InnerWorkerState::Stopped, Ordering::SeqCst);
        self.wake();
        self.emit_ref(&Event::Stop);
        info!("Worker {} stopped", self.name());
        Ok(())
    }

    /// Checks if the worker is ready to consume new tasks
    #[must_use]
    pub fn is_ready(&self) -> bool {
        self.is_running() && !self.is_shutting_down() && self.is_ready.load(Ordering::SeqCst)
    }

    /// Get the type of service
    #[must_use]
    pub fn get_service(&self) -> &str {
        self.service
    }

    /// Checks whether the worker is running
    #[must_use]
    pub fn is_running(&self) -> bool {
        self.state.load(Ordering::SeqCst) == InnerWorkerState::Running
    }

    /// Checks whether the worker is pending
    #[must_use]
    pub fn is_pending(&self) -> bool {
        self.state.load(Ordering::SeqCst) == InnerWorkerState::Pending
    }

    /// Checks whether the worker is paused
    #[must_use]
    pub fn is_paused(&self) -> bool {
        self.state.load(Ordering::SeqCst) == InnerWorkerState::Paused
    }

    /// Checks whether the worker has been stopped
    #[must_use]
    pub fn is_stopped(&self) -> bool {
        self.state.load(Ordering::SeqCst) == InnerWorkerState::Stopped
    }

    /// Returns the number of futures currently tracked by the worker.
    /// This includes futures spawned via `worker.track`
    #[must_use]
    pub fn task_count(&self) -> usize {
        self.task_count.load(Ordering::Relaxed)
    }

    /// Checks whether the worker has pending tasks
    #[must_use]
    pub fn has_pending_tasks(&self) -> bool {
        self.task_count.load(Ordering::Relaxed) > 0
    }

    /// Checks whether the worker is shutting down, either because `stop()` was called
    /// or a configured shutdown signal was triggered
    #[must_use]
    pub fn is_shutting_down(&self) -> bool {
        self.is_stopped() || self.shutdown.as_ref().is_some_and(|s| s.is_shutting_down())
    }

    /// Allows workers to emit events
    pub fn emit(&mut self, event: &Event) {
        self.emit_ref(event);
    }

    fn emit_ref(&self, event: &Event) {
        let handler = self.event_handler.as_ref();
        handler(self, event);
    }

    /// Wraps the event listener with a new function
    pub fn wrap_listener<F: Fn(&Self, &Event) + Send + Sync + 'static>(&mut self, f: F) {
        let cur = self.event_handler.clone();
        let new: RawEventListener = Box::new(move |ctx, ev| {
            f(ctx, ev);
            cur(ctx, ev);
        });
        self.event_handler = Arc::new(new);
    }

    pub(crate) fn add_waker(&self, cx: &Context<'_>) {
        if let Ok(mut waker_guard) = self.waker.lock() {
            if waker_guard
                .as_ref()
                .is_none_or(|stored_waker| !stored_waker.will_wake(cx.waker()))
            {
                *waker_guard = Some(cx.waker().clone());
            }
        }
    }

    /// Checks if the stored waker matches the current one.
    fn has_recent_waker(&self, cx: &Context<'_>) -> bool {
        if let Ok(waker_guard) = self.waker.lock() {
            if let Some(stored_waker) = &*waker_guard {
                return stored_waker.will_wake(cx.waker());
            }
        }
        false
    }

    fn start_task(&self) {
        self.task_count.fetch_add(1, Ordering::Relaxed);
    }

    fn end_task(&self) {
        if self.task_count.fetch_sub(1, Ordering::Relaxed) == 1 {
            self.wake();
        }
    }

    pub(crate) fn wake(&self) {
        if let Ok(waker) = self.waker.lock() {
            if let Some(waker) = &*waker {
                waker.wake_by_ref();
            }
        }
    }
}

impl Future for WorkerContext {
    type Output = Result<(), WorkerError>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let task_count = self.task_count.load(Ordering::Relaxed);
        let state = self.state.load(Ordering::SeqCst);
        if state == InnerWorkerState::Pending {
            return Poll::Ready(Err(WorkerError::StateError(WorkerStateError::NotStarted)));
        }
        if self.is_shutting_down() && task_count == 0 {
            Poll::Ready(Ok(()))
        } else {
            if !self.has_recent_waker(cx) {
                self.add_waker(cx);
            }
            Poll::Pending
        }
    }
}

impl<Args: Sync, Ctx: Sync, IdType: Sync + Send> FromRequest<Task<Args, Ctx, IdType>>
    for WorkerContext
{
    type Error = MissingDataError;
    async fn from_request(task: &Task<Args, Ctx, IdType>) -> Result<Self, Self::Error> {
        task.parts.data.get_checked().cloned()
    }
}

impl Drop for WorkerContext {
    fn drop(&mut self) {
        if Arc::strong_count(&self.state) > 1 {
            // There are still other references to this context, so we shouldn't log a warning.
            return;
        }
        if self.is_running() && self.has_pending_tasks() {
            error!(
                "Worker '{}' is being dropped while running with `{}` tasks. Consider calling stop() before dropping.",
                self.name(),
                self.task_count()
            );
        }
    }
}

#[cfg(test)]
mod tests {
    use crate::{
        backend::memory::MemoryStorage, error::BoxDynError, worker::builder::WorkerBuilder,
    };
    use std::time::Duration;

    use super::*;

    #[tokio::test]
    async fn test_worker_state_transitions() {
        let backend = MemoryStorage::<u32>::new();

        let worker = WorkerBuilder::new("test-worker")
            .backend(backend)
            .build(|_task: u32| async { Ok::<_, BoxDynError>(()) });

        let mut ctx = WorkerContext::new::<()>("test-worker");
        let ctx_handle = ctx.clone();

        let worker_handle = tokio::spawn(async move { worker.run_with_ctx(&mut ctx).await });
        tokio::time::sleep(Duration::from_millis(50)).await;

        // Initial state: worker should be running
        assert!(ctx_handle.is_running());
        assert!(!ctx_handle.is_shutting_down());
        assert!(!ctx_handle.is_stopped());

        // Pause the worker
        ctx_handle.pause().unwrap();
        assert!(ctx_handle.is_paused());
        assert!(
            !ctx_handle.is_shutting_down(),
            "Paused worker should NOT be considered shutting down"
        );

        // Resume the worker
        ctx_handle.resume().unwrap();
        assert!(ctx_handle.is_running());
        assert!(!ctx_handle.is_paused());

        // Stop the worker
        ctx_handle.stop().unwrap();
        assert!(ctx_handle.is_stopped());
        assert!(ctx_handle.is_shutting_down());

        // Try to resume a stopped worker (should fail with NotPaused error since state is Stopped)
        assert!(
            matches!(
                ctx_handle.resume(),
                Err(WorkerError::StateError(WorkerStateError::NotPaused))
            ),
            "Resuming a stopped worker should fail with NotPaused error"
        );

        worker_handle.await.unwrap().unwrap();
    }
}