apalis_core/worker/
mod.rs

1use crate::backend::Backend;
2use crate::error::{BoxDynError, Error};
3use crate::layers::extensions::Data;
4use crate::monitor::shutdown::Shutdown;
5use crate::request::Request;
6use crate::service_fn::FromRequest;
7use crate::task::task_id::TaskId;
8use call_all::CallAllUnordered;
9use futures::future::{join, select, BoxFuture};
10use futures::stream::BoxStream;
11use futures::{Future, FutureExt, Stream, StreamExt};
12use pin_project_lite::pin_project;
13use serde::{Deserialize, Serialize};
14use std::fmt::Debug;
15use std::fmt::{self, Display};
16use std::ops::{Deref, DerefMut};
17use std::pin::Pin;
18use std::str::FromStr;
19use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
20use std::sync::{Arc, Mutex, RwLock};
21use std::task::{Context as TaskCtx, Poll, Waker};
22use thiserror::Error;
23use tower::{Layer, Service, ServiceBuilder};
24
25mod call_all;
26
/// A worker name wrapper usually used by Worker builder
///
/// The wrapped name is stored verbatim. An id can be created with
/// [`WorkerId::new`] or parsed from a string slice, and it displays as the
/// bare name.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct WorkerId {
    // The worker's name exactly as it was supplied; this file applies no validation.
    name: String,
}
32
/// An event handler for [`Worker`]
///
/// The callback receives a `Worker<Event>`, i.e. the id of the emitting worker
/// paired with the [`Event`] itself. The callback is optional (`Option`) and
/// lives behind `Arc<RwLock<..>>`, so every clone of the alias refers to the
/// same handler slot.
pub type EventHandler = Arc<RwLock<Option<Box<dyn Fn(Worker<Event>) + Send + Sync>>>>;
35
36impl FromStr for WorkerId {
37    type Err = ();
38
39    fn from_str(s: &str) -> Result<Self, Self::Err> {
40        Ok(WorkerId { name: s.to_owned() })
41    }
42}
43
44impl Display for WorkerId {
45    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46        f.write_str(self.name())?;
47        Ok(())
48    }
49}
50
51impl WorkerId {
52    /// Build a new worker ref
53    pub fn new<T: AsRef<str>>(name: T) -> Self {
54        Self {
55            name: name.as_ref().to_string(),
56        }
57    }
58
59    /// Get the name of the worker
60    pub fn name(&self) -> &str {
61        &self.name
62    }
63}
64
/// Events emitted by a worker
///
/// Events are delivered to the registered event handler wrapped in a
/// `Worker<Event>`, so the handler also sees the id of the emitting worker.
#[derive(Debug)]
pub enum Event {
    /// Worker started
    Start,
    /// Worker got a job; carries the id of the task that was picked up
    Engage(TaskId),
    /// Worker is idle, stream has no new request for now
    Idle,
    /// A custom event carrying a free-form message
    Custom(String),
    /// Worker encountered an error; carries the boxed error
    Error(BoxDynError),
    /// Worker stopped
    Stop,
    /// Worker completed all pending tasks
    Exit,
}
83
84impl fmt::Display for Worker<Event> {
85    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
86        let event_description = match &self.state {
87            Event::Start => "Worker started".to_string(),
88            Event::Engage(task_id) => format!("Worker engaged with Task ID: {}", task_id),
89            Event::Idle => "Worker is idle".to_string(),
90            Event::Custom(msg) => format!("Custom event: {}", msg),
91            Event::Error(err) => format!("Worker encountered an error: {}", err),
92            Event::Stop => "Worker stopped".to_string(),
93            Event::Exit => "Worker completed all pending tasks and exited".to_string(),
94        };
95
96        write!(f, "Worker [{}]: {}", self.id.name, event_description)
97    }
98}
99
/// Possible errors that can occur when starting a worker.
///
/// Each variant carries a human-readable message that is interpolated into the
/// variant's `#[error(..)]` display string.
#[derive(Error, Debug, Clone)]
pub enum WorkerError {
    /// An error occurred while processing a job.
    #[error("Failed to process job: {0}")]
    ProcessingError(String),
    /// An error occurred in the worker's service.
    #[error("Service error: {0}")]
    ServiceError(String),
    /// An error occurred while trying to start the worker.
    #[error("Failed to start worker: {0}")]
    StartError(String),
}
113
/// A worker that is ready for running
///
/// Bundles the service of type `S` that will handle requests with the backend
/// of type `P` that produces them, plus an optional shutdown handle and the
/// shared event handler slot.
pub struct Ready<S, P> {
    // The service that is called for every request.
    service: S,
    // The backend the worker polls for requests.
    backend: P,
    // Optional shutdown handle; `None` until one is attached by crate-internal code.
    pub(crate) shutdown: Option<Shutdown>,
    // Shared slot holding the optional event callback.
    pub(crate) event_handler: EventHandler,
}
121
122impl<S, P> fmt::Debug for Ready<S, P>
123where
124    S: fmt::Debug,
125    P: fmt::Debug,
126{
127    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
128        f.debug_struct("Ready")
129            .field("service", &self.service)
130            .field("backend", &self.backend)
131            .field("shutdown", &self.shutdown)
132            .field("event_handler", &"...") // Avoid dumping potentially sensitive or verbose data
133            .finish()
134    }
135}
136
137impl<S, P> Clone for Ready<S, P>
138where
139    S: Clone,
140    P: Clone,
141{
142    fn clone(&self) -> Self {
143        Ready {
144            service: self.service.clone(),
145            backend: self.backend.clone(),
146            shutdown: self.shutdown.clone(),
147            event_handler: self.event_handler.clone(),
148        }
149    }
150}
151
152impl<S, P> Ready<S, P> {
153    /// Build a worker that is ready for execution
154    pub fn new(service: S, poller: P) -> Self {
155        Ready {
156            service,
157            backend: poller,
158            shutdown: None,
159            event_handler: EventHandler::default(),
160        }
161    }
162}
163
/// Represents a generic [Worker] that can be in many different states
///
/// A worker is an id paired with a state value of type `T`; in this file `T`
/// is instantiated with `Ready<S, P>`, `Context` and `Event`. The worker
/// dereferences to its state.
#[derive(Debug, Clone, Serialize)]
pub struct Worker<T> {
    // Identifier of this worker.
    pub(crate) id: WorkerId,
    // The state-specific payload.
    pub(crate) state: T,
}
170
171impl<T> Worker<T> {
172    /// Create a new worker instance
173    pub fn new(id: WorkerId, state: T) -> Self {
174        Self { id, state }
175    }
176
177    /// Get the inner state
178    pub fn inner(&self) -> &T {
179        &self.state
180    }
181
182    /// Get the worker id
183    pub fn id(&self) -> &WorkerId {
184        &self.id
185    }
186}
187
188impl<T> Deref for Worker<T> {
189    type Target = T;
190    fn deref(&self) -> &Self::Target {
191        &self.state
192    }
193}
194
195impl<T> DerefMut for Worker<T> {
196    fn deref_mut(&mut self) -> &mut Self::Target {
197        &mut self.state
198    }
199}
200
201impl Worker<Context> {
202    /// Allows workers to emit events
203    pub fn emit(&self, event: Event) -> bool {
204        if let Some(handler) = self.state.event_handler.read().unwrap().as_ref() {
205            handler(Worker {
206                id: self.id().clone(),
207                state: event,
208            });
209            return true;
210        }
211        false
212    }
213    /// Start running the worker
214    pub fn start(&self) {
215        self.state.running.store(true, Ordering::Relaxed);
216        self.state.is_ready.store(true, Ordering::Release);
217        self.emit(Event::Start);
218    }
219}
220
221impl<Req, Ctx> FromRequest<Request<Req, Ctx>> for Worker<Context> {
222    fn from_request(req: &Request<Req, Ctx>) -> Result<Self, Error> {
223        req.parts.data.get_checked().cloned()
224    }
225}
226
impl<S, P> Worker<Ready<S, P>> {
    /// Add an event handler to the worker
    ///
    /// The handler is written into the shared handler slot, replacing whatever
    /// was stored there. If the slot's lock cannot be acquired (poisoned), the
    /// handler is dropped without being installed and no error is reported.
    pub fn on_event<F: Fn(Worker<Event>) + Send + Sync + 'static>(self, f: F) -> Self {
        // `write()` returns `Err` on a poisoned lock; that result is discarded here.
        let _ = self.event_handler.write().map(|mut res| {
            let _ = res.insert(Box::new(f));
        });
        self
    }

    /// Turns the backend's request stream into a stream that drives `service`.
    ///
    /// For every stream item an event is emitted through `worker`:
    /// `Ok(Some(request))` emits [`Event::Engage`] and forwards the request,
    /// `Ok(None)` emits [`Event::Idle`], and `Err(_)` emits [`Event::Error`];
    /// the latter two yield nothing to the service.
    ///
    /// Forwarded requests are handed to `service` via `CallAllUnordered`. When
    /// a call fails, the error is emitted as [`Event::Error`]; if it is an
    /// `Error::MissingData` the worker is stopped first.
    fn poll_jobs<Svc, Stm, Req, Res, Ctx>(
        worker: Worker<Context>,
        service: Svc,
        stream: Stm,
    ) -> BoxStream<'static, ()>
    where
        Svc: Service<Request<Req, Ctx>, Response = Res> + Send + 'static,
        Stm: Stream<Item = Result<Option<Request<Req, Ctx>>, Error>> + Send + Unpin + 'static,
        Req: Send + 'static + Sync,
        Svc::Future: Send,
        Svc::Response: 'static + Send + Sync + Serialize,
        Svc::Error: Send + Sync + 'static + Into<BoxDynError>,
        Ctx: Send + 'static + Sync,
        Res: 'static,
    {
        // `worker` moves into the filtering closure below; `w` is the handle
        // used by the result-mapping closure.
        let w = worker.clone();
        let stream = stream.filter_map(move |result| {
            // Each produced future needs its own handle because it is `async move`.
            let worker = worker.clone();

            async move {
                match result {
                    Ok(Some(request)) => {
                        worker.emit(Event::Engage(request.parts.task_id.clone()));
                        Some(request)
                    }
                    Ok(None) => {
                        worker.emit(Event::Idle);
                        None
                    }
                    Err(err) => {
                        worker.emit(Event::Error(Box::new(err)));
                        None
                    }
                }
            }
        });
        let stream = CallAllUnordered::new(service, stream).map(move |res| {
            if let Err(error) = res {
                let error = error.into();
                // Missing request data stops the worker before the error is reported.
                if let Some(Error::MissingData(_)) = error.downcast_ref::<Error>() {
                    w.stop();
                }
                w.emit(Event::Error(error));
            }
        });
        stream.boxed()
    }
    /// Start a worker
    ///
    /// Consumes the ready worker and assembles a [`Runnable`]: a fresh
    /// `Context` is created (carrying over the shutdown handle and the event
    /// handler), the backend is polled to obtain its request stream, heartbeat
    /// future and layer, and the service is wrapped with the tracking,
    /// readiness and worker-data layers together with the backend's layer.
    ///
    /// Nothing is executed until the returned [`Runnable`] is polled.
    pub fn run<Req, Res, Ctx>(self) -> Runnable
    where
        S: Service<Request<Req, Ctx>, Response = Res> + Send + 'static,
        P: Backend<Request<Req, Ctx>, Res> + 'static,
        Req: Send + 'static + Sync,
        S::Future: Send,
        S::Response: 'static + Send + Sync + Serialize,
        S::Error: Send + Sync + 'static + Into<BoxDynError>,
        P::Stream: Unpin + Send + 'static,
        P::Layer: Layer<S>,
        <P::Layer as Layer<S>>::Service: Service<Request<Req, Ctx>, Response = Res> + Send,
        <<P::Layer as Layer<S>>::Service as Service<Request<Req, Ctx>>>::Future: Send,
        <<P::Layer as Layer<S>>::Service as Service<Request<Req, Ctx>>>::Error:
            Send + Into<BoxDynError> + Sync,
        Ctx: Send + 'static + Sync,
        Res: 'static,
    {
        let worker_id = self.id;
        // Counters and flags start from their defaults (not running, no tasks, no waker).
        let ctx = Context {
            running: Arc::default(),
            task_count: Arc::default(),
            waker: Arc::default(),
            shutdown: self.state.shutdown,
            event_handler: self.state.event_handler.clone(),
            is_ready: Arc::default(),
        };
        let worker = Worker {
            id: worker_id.clone(),
            state: ctx.clone(),
        };
        let backend = self.state.backend;
        let service = self.state.service;
        // The backend hands back its request stream, heartbeat future and layer.
        let poller = backend.poll::<S>(&worker);
        let stream = poller.stream;
        let heartbeat = poller.heartbeat.boxed();
        let layer = poller.layer;
        let service = ServiceBuilder::new()
            .layer(TrackerLayer::new(worker.state.clone()))
            .layer(ReadinessLayer::new(worker.state.is_ready.clone()))
            .layer(Data::new(worker.clone()))
            .layer(layer)
            .service(service);

        Runnable {
            poller: Self::poll_jobs(worker.clone(), service, stream),
            heartbeat,
            worker,
            // Flipped to `true` by the first poll of the `Runnable`.
            running: false,
        }
    }
}
335
336/// A `Runnable` represents a unit of work that manages a worker's lifecycle and execution flow.
337///
338/// The `Runnable` struct is responsible for coordinating the core tasks of a worker, such as polling for jobs,
339/// maintaining heartbeats, and tracking its running state. It integrates various components required for
340/// the worker to operate effectively within an asynchronous runtime.
341#[must_use = "A Runnable must be awaited of no jobs will be consumed"]
342pub struct Runnable {
343    poller: BoxStream<'static, ()>,
344    heartbeat: BoxFuture<'static, ()>,
345    worker: Worker<Context>,
346    running: bool,
347}
348
349impl Runnable {
350    /// Returns a handle to the worker, allowing control and functionality like stopping
351    pub fn get_handle(&self) -> Worker<Context> {
352        self.worker.clone()
353    }
354}
355
356impl fmt::Debug for Runnable {
357    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
358        f.debug_struct("Runnable")
359            .field("poller", &"<stream>")
360            .field("heartbeat", &"<future>")
361            .field("worker", &self.worker)
362            .field("running", &self.running)
363            .finish()
364    }
365}
366
367impl Future for Runnable {
368    type Output = ();
369
370    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
371        let this = self.get_mut();
372        let poller = &mut this.poller;
373        let heartbeat = &mut this.heartbeat;
374        let worker = &mut this.worker;
375
376        let poller_future = async { while (poller.next().await).is_some() {} };
377
378        if !this.running {
379            worker.start();
380            this.running = true;
381        }
382        let combined = Box::pin(join(poller_future, heartbeat.as_mut()));
383
384        let mut combined = select(
385            combined,
386            worker.state.clone().map(|_| worker.emit(Event::Stop)),
387        )
388        .boxed();
389        match Pin::new(&mut combined).poll(cx) {
390            Poll::Ready(_) => {
391                worker.emit(Event::Exit);
392                Poll::Ready(())
393            }
394            Poll::Pending => Poll::Pending,
395        }
396    }
397}
398
/// Stores the Workers context
///
/// The counters, flags, waker slot and event handler are all held behind
/// `Arc`s, so clones of a context observe and modify the same state.
#[derive(Clone, Default)]
pub struct Context {
    // Number of tracked futures that have been started and have not finished yet.
    task_count: Arc<AtomicUsize>,
    // Waker of the task that last polled this context as a future, if any.
    waker: Arc<Mutex<Option<Waker>>>,
    // `true` between `start()` and `stop()`.
    running: Arc<AtomicBool>,
    // Optional external shutdown handle consulted by `is_shutting_down`.
    shutdown: Option<Shutdown>,
    // Shared slot holding the optional event callback.
    event_handler: EventHandler,
    // Readiness flag read by `is_ready`.
    is_ready: Arc<AtomicBool>,
}
409
410impl fmt::Debug for Context {
411    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
412        f.debug_struct("WorkerContext")
413            .field("shutdown", &["Shutdown handle"])
414            .field("task_count", &self.task_count)
415            .field("running", &self.running)
416            .finish()
417    }
418}
419
pin_project! {
    /// A future tracked by the worker
    ///
    /// Wraps an inner future together with the worker context so that the
    /// context can be told when the inner future completes.
    // NOTE(review): no drop hook is visible here, so a `Tracked` that is dropped
    // before completing would presumably never be reported as finished - confirm.
    pub struct Tracked<F> {
        // Context whose task counter this future is counted in.
        ctx: Context,
        // The wrapped future; pinned structurally.
        #[pin]
        task: F,
    }
}
428
429impl<F: Future> Future for Tracked<F> {
430    type Output = F::Output;
431
432    fn poll(self: Pin<&mut Self>, cx: &mut TaskCtx<'_>) -> Poll<F::Output> {
433        let this = self.project();
434
435        match this.task.poll(cx) {
436            res @ Poll::Ready(_) => {
437                this.ctx.end_task();
438                res
439            }
440            Poll::Pending => Poll::Pending,
441        }
442    }
443}
444
445impl Context {
446    /// Start a task that is tracked by the worker
447    pub fn track<F: Future>(&self, task: F) -> Tracked<F> {
448        self.start_task();
449        Tracked {
450            ctx: self.clone(),
451            task,
452        }
453    }
454
455    /// Calling this function triggers shutting down the worker while waiting for any tasks to complete
456    pub fn stop(&self) {
457        self.running.store(false, Ordering::Relaxed);
458        self.wake()
459    }
460
461    fn start_task(&self) {
462        self.task_count.fetch_add(1, Ordering::Relaxed);
463    }
464
465    fn end_task(&self) {
466        if self.task_count.fetch_sub(1, Ordering::Relaxed) == 1 {
467            self.wake();
468        }
469    }
470
471    pub(crate) fn wake(&self) {
472        if let Ok(waker) = self.waker.lock() {
473            if let Some(waker) = &*waker {
474                waker.wake_by_ref();
475            }
476        }
477    }
478
479    /// Returns whether the worker is running
480    pub fn is_running(&self) -> bool {
481        self.running.load(Ordering::Relaxed)
482    }
483
484    /// Returns the current futures in the worker domain
485    /// This include futures spawned via `worker.track`
486    pub fn task_count(&self) -> usize {
487        self.task_count.load(Ordering::Relaxed)
488    }
489
490    /// Returns whether the worker has pending tasks
491    pub fn has_pending_tasks(&self) -> bool {
492        self.task_count.load(Ordering::Relaxed) > 0
493    }
494
495    /// Is the shutdown token called
496    pub fn is_shutting_down(&self) -> bool {
497        self.shutdown
498            .as_ref()
499            .map(|s| !self.is_running() || s.is_shutting_down())
500            .unwrap_or(!self.is_running())
501    }
502
503    fn add_waker(&self, cx: &mut TaskCtx<'_>) {
504        if let Ok(mut waker_guard) = self.waker.lock() {
505            if waker_guard
506                .as_ref()
507                .map_or(true, |stored_waker| !stored_waker.will_wake(cx.waker()))
508            {
509                *waker_guard = Some(cx.waker().clone());
510            }
511        }
512    }
513
514    /// Checks if the stored waker matches the current one.
515    fn has_recent_waker(&self, cx: &TaskCtx<'_>) -> bool {
516        if let Ok(waker_guard) = self.waker.lock() {
517            if let Some(stored_waker) = &*waker_guard {
518                return stored_waker.will_wake(cx.waker());
519            }
520        }
521        false
522    }
523
524    /// Returns if the worker is ready to consume new tasks
525    pub fn is_ready(&self) -> bool {
526        self.is_ready.load(Ordering::Acquire) && !self.is_shutting_down()
527    }
528}
529
530impl Future for Context {
531    type Output = ();
532
533    fn poll(self: Pin<&mut Self>, cx: &mut TaskCtx<'_>) -> Poll<()> {
534        let task_count = self.task_count.load(Ordering::Relaxed);
535        if self.is_shutting_down() && task_count == 0 {
536            Poll::Ready(())
537        } else {
538            if !self.has_recent_waker(cx) {
539                self.add_waker(cx);
540            }
541            Poll::Pending
542        }
543    }
544}
545
/// Layer that wraps a service in a `TrackerService` bound to a worker context.
#[derive(Debug, Clone)]
struct TrackerLayer {
    // Context handed (as a clone) to every service this layer wraps.
    ctx: Context,
}
550
551impl TrackerLayer {
552    fn new(ctx: Context) -> Self {
553        Self { ctx }
554    }
555}
556
557impl<S> Layer<S> for TrackerLayer {
558    type Service = TrackerService<S>;
559
560    fn layer(&self, service: S) -> Self::Service {
561        TrackerService {
562            ctx: self.ctx.clone(),
563            service,
564        }
565    }
566}
/// Service wrapper whose calls return futures tracked by a worker context.
#[derive(Debug, Clone)]
struct TrackerService<S> {
    // Context that counts the in-flight calls.
    ctx: Context,
    // The wrapped service.
    service: S,
}
572
573impl<S, Req, Ctx> Service<Request<Req, Ctx>> for TrackerService<S>
574where
575    S: Service<Request<Req, Ctx>>,
576{
577    type Response = S::Response;
578    type Error = S::Error;
579    type Future = Tracked<S::Future>;
580
581    fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
582        self.service.poll_ready(cx)
583    }
584
585    fn call(&mut self, request: Request<Req, Ctx>) -> Self::Future {
586        self.ctx.track(self.service.call(request))
587    }
588}
589
/// Layer that wraps a service in a `ReadinessService` sharing a readiness flag.
#[derive(Clone)]
struct ReadinessLayer {
    // Flag shared with every service this layer wraps.
    is_ready: Arc<AtomicBool>,
}
594
595impl ReadinessLayer {
596    fn new(is_ready: Arc<AtomicBool>) -> Self {
597        Self { is_ready }
598    }
599}
600
601impl<S> Layer<S> for ReadinessLayer {
602    type Service = ReadinessService<S>;
603
604    fn layer(&self, inner: S) -> Self::Service {
605        ReadinessService {
606            inner,
607            is_ready: self.is_ready.clone(),
608        }
609    }
610}
611
/// Service wrapper that mirrors the wrapped service's readiness into a shared flag.
struct ReadinessService<S> {
    // The wrapped service.
    inner: S,
    // Set to the outcome of the most recent `poll_ready` of `inner`.
    is_ready: Arc<AtomicBool>,
}
616
617impl<S, Request> Service<Request> for ReadinessService<S>
618where
619    S: Service<Request>,
620{
621    type Response = S::Response;
622    type Error = S::Error;
623    type Future = S::Future;
624
625    fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
626        // Delegate poll_ready to the inner service
627        let result = self.inner.poll_ready(cx);
628        // Update the readiness state based on the result
629        match &result {
630            Poll::Ready(Ok(_)) => self.is_ready.store(true, Ordering::Release),
631            Poll::Pending | Poll::Ready(Err(_)) => self.is_ready.store(false, Ordering::Release),
632        }
633
634        result
635    }
636
637    fn call(&mut self, req: Request) -> Self::Future {
638        self.inner.call(req)
639    }
640}
641
#[cfg(test)]
mod tests {
    use std::{ops::Deref, sync::atomic::AtomicUsize};

    use crate::{
        builder::{WorkerBuilder, WorkerFactoryFn},
        layers::extensions::Data,
        memory::MemoryStorage,
        mq::MessageQueue,
    };

    use super::*;

    // Number of jobs pushed through the worker in `it_works`.
    const ITEMS: u32 = 100;

    /// Any string, including one with special characters, parses into an id
    /// holding exactly that string.
    #[test]
    fn it_parses_worker_names() {
        for raw in ["worker", "worker-0", "complex&*-worker-name-0"] {
            let expected = WorkerId {
                name: raw.to_string(),
            };
            assert_eq!(WorkerId::from_str(raw).unwrap(), expected);
        }
    }

    /// Runs a worker over an in-memory queue until the handler of the last
    /// enqueued job stops it.
    #[tokio::test]
    async fn it_works() {
        let storage = MemoryStorage::new();
        let mut producer = storage.clone();

        // Feed the queue from a separate task while the worker consumes it.
        tokio::spawn(async move {
            for i in 0..ITEMS {
                producer.enqueue(i).await.unwrap();
            }
        });

        #[derive(Clone, Debug, Default)]
        struct Count(Arc<AtomicUsize>);

        impl Deref for Count {
            type Target = Arc<AtomicUsize>;
            fn deref(&self) -> &Self::Target {
                &self.0
            }
        }

        async fn task(job: u32, count: Data<Count>, worker: Worker<Context>) {
            count.fetch_add(1, Ordering::Relaxed);
            // The handler of the final job shuts the worker down, which ends the test.
            if job == ITEMS - 1 {
                worker.stop();
            }
        }

        WorkerBuilder::new("rango-tango")
            .data(Count::default())
            .backend(storage)
            .build_fn(task)
            .run()
            .await;
    }
}