// apalis_core/worker/mod.rs

1use crate::backend::Backend;
2use crate::error::{BoxDynError, Error};
3use crate::layers::extensions::Data;
4use crate::monitor::shutdown::Shutdown;
5use crate::request::Request;
6use crate::service_fn::FromRequest;
7use crate::task::task_id::TaskId;
8use call_all::CallAllUnordered;
9use futures::future::{join, select, BoxFuture};
10use futures::stream::BoxStream;
11use futures::{Future, FutureExt, Stream, StreamExt};
12use pin_project_lite::pin_project;
13use serde::{Deserialize, Serialize};
14use std::fmt::Debug;
15use std::fmt::{self, Display};
16use std::ops::{Deref, DerefMut};
17use std::pin::Pin;
18use std::str::FromStr;
19use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
20use std::sync::{Arc, Mutex, RwLock};
21use std::task::{Context as TaskCtx, Poll, Waker};
22use thiserror::Error;
23use tower::{Layer, Service, ServiceBuilder};
24
25mod call_all;
26
27/// A worker name wrapper usually used by Worker builder
28#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
29pub struct WorkerId {
30    name: String,
31}
32
33/// An event handler for [`Worker`]
34pub type EventHandler = Arc<RwLock<Option<Box<dyn Fn(Worker<Event>) + Send + Sync>>>>;
35
36impl FromStr for WorkerId {
37    type Err = ();
38
39    fn from_str(s: &str) -> Result<Self, Self::Err> {
40        Ok(WorkerId { name: s.to_owned() })
41    }
42}
43
44impl Display for WorkerId {
45    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46        f.write_str(self.name())?;
47        Ok(())
48    }
49}
50
51impl WorkerId {
52    /// Build a new worker ref
53    pub fn new<T: AsRef<str>>(name: T) -> Self {
54        Self {
55            name: name.as_ref().to_string(),
56        }
57    }
58
59    /// Get the name of the worker
60    pub fn name(&self) -> &str {
61        &self.name
62    }
63}
64
65/// Events emitted by a worker
66#[derive(Debug)]
67pub enum Event {
68    /// Worker started
69    Start,
70    /// Worker got a job
71    Engage(TaskId),
72    /// Worker is idle, stream has no new request for now
73    Idle,
74    /// A custom event
75    Custom(String),
76    /// Worker encountered an error
77    Error(BoxDynError),
78    /// Worker stopped
79    Stop,
80    /// Worker completed all pending tasks
81    Exit,
82}
83
84impl fmt::Display for Worker<Event> {
85    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
86        let event_description = match &self.state {
87            Event::Start => "Worker started".to_string(),
88            Event::Engage(task_id) => format!("Worker engaged with Task ID: {}", task_id),
89            Event::Idle => "Worker is idle".to_string(),
90            Event::Custom(msg) => format!("Custom event: {}", msg),
91            Event::Error(err) => format!("Worker encountered an error: {}", err),
92            Event::Stop => "Worker stopped".to_string(),
93            Event::Exit => "Worker completed all pending tasks and exited".to_string(),
94        };
95
96        write!(f, "Worker [{}]: {}", self.id.name, event_description)
97    }
98}
99
100/// Possible errors that can occur when starting a worker.
101#[derive(Error, Debug, Clone)]
102pub enum WorkerError {
103    /// An error occurred while processing a job.
104    #[error("Failed to process job: {0}")]
105    ProcessingError(String),
106    /// An error occurred in the worker's service.
107    #[error("Service error: {0}")]
108    ServiceError(String),
109    /// An error occurred while trying to start the worker.
110    #[error("Failed to start worker: {0}")]
111    StartError(String),
112}
113
114/// A worker that is ready for running
115pub struct Ready<S, P> {
116    service: S,
117    backend: P,
118    pub(crate) shutdown: Option<Shutdown>,
119    pub(crate) event_handler: EventHandler,
120}
121
122impl<S, P> fmt::Debug for Ready<S, P>
123where
124    S: fmt::Debug,
125    P: fmt::Debug,
126{
127    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
128        f.debug_struct("Ready")
129            .field("service", &self.service)
130            .field("backend", &self.backend)
131            .field("shutdown", &self.shutdown)
132            .field("event_handler", &"...") // Avoid dumping potentially sensitive or verbose data
133            .finish()
134    }
135}
136
137impl<S, P> Clone for Ready<S, P>
138where
139    S: Clone,
140    P: Clone,
141{
142    fn clone(&self) -> Self {
143        Ready {
144            service: self.service.clone(),
145            backend: self.backend.clone(),
146            shutdown: self.shutdown.clone(),
147            event_handler: self.event_handler.clone(),
148        }
149    }
150}
151
152impl<S, P> Ready<S, P> {
153    /// Build a worker that is ready for execution
154    pub fn new(service: S, poller: P) -> Self {
155        Ready {
156            service,
157            backend: poller,
158            shutdown: None,
159            event_handler: EventHandler::default(),
160        }
161    }
162}
163
164/// Represents a generic [Worker] that can be in many different states
165#[derive(Debug, Clone, Serialize)]
166pub struct Worker<T> {
167    pub(crate) id: WorkerId,
168    pub(crate) state: T,
169}
170
171impl<T> Worker<T> {
172    /// Create a new worker instance
173    pub fn new(id: WorkerId, state: T) -> Self {
174        Self { id, state }
175    }
176
177    /// Get the inner state
178    pub fn inner(&self) -> &T {
179        &self.state
180    }
181
182    /// Get the worker id
183    pub fn id(&self) -> &WorkerId {
184        &self.id
185    }
186}
187
188impl<T> Deref for Worker<T> {
189    type Target = T;
190    fn deref(&self) -> &Self::Target {
191        &self.state
192    }
193}
194
195impl<T> DerefMut for Worker<T> {
196    fn deref_mut(&mut self) -> &mut Self::Target {
197        &mut self.state
198    }
199}
200
201impl Worker<Context> {
202    /// Allows workers to emit events
203    pub fn emit(&self, event: Event) -> bool {
204        if let Some(handler) = self.state.event_handler.read().unwrap().as_ref() {
205            handler(Worker {
206                id: self.id().clone(),
207                state: event,
208            });
209            return true;
210        }
211        false
212    }
213    /// Start running the worker
214    pub fn start(&self) {
215        self.state.running.store(true, Ordering::Relaxed);
216        self.state.is_ready.store(true, Ordering::Release);
217        self.emit(Event::Start);
218    }
219}
220
221impl<Req, Ctx> FromRequest<Request<Req, Ctx>> for Worker<Context> {
222    fn from_request(req: &Request<Req, Ctx>) -> Result<Self, Error> {
223        req.parts.data.get_checked().cloned()
224    }
225}
226
227impl<S, P> Worker<Ready<S, P>> {
228    /// Add an event handler to the worker
229    pub fn on_event<F: Fn(Worker<Event>) + Send + Sync + 'static>(self, f: F) -> Self {
230        let _ = self.event_handler.write().map(|mut res| {
231            let _ = res.insert(Box::new(f));
232        });
233        self
234    }
235
236    fn poll_jobs<Svc, Stm, Req, Ctx>(
237        worker: Worker<Context>,
238        service: Svc,
239        stream: Stm,
240    ) -> BoxStream<'static, ()>
241    where
242        Svc: Service<Request<Req, Ctx>> + Send + 'static,
243        Stm: Stream<Item = Result<Option<Request<Req, Ctx>>, Error>> + Send + Unpin + 'static,
244        Req: Send + 'static,
245        Svc::Future: Send,
246        Svc::Error: Send + 'static + Into<BoxDynError>,
247        Ctx: Send + 'static,
248    {
249        let w = worker.clone();
250        let stream = stream.filter_map(move |result| {
251            let worker = worker.clone();
252
253            async move {
254                match result {
255                    Ok(Some(request)) => {
256                        worker.emit(Event::Engage(request.parts.task_id.clone()));
257                        Some(request)
258                    }
259                    Ok(None) => {
260                        worker.emit(Event::Idle);
261                        None
262                    }
263                    Err(err) => {
264                        worker.emit(Event::Error(Box::new(err)));
265                        None
266                    }
267                }
268            }
269        });
270        let stream = CallAllUnordered::new(service, stream).map(move |res| {
271            if let Err(error) = res {
272                let error = error.into();
273                if let Some(Error::MissingData(_)) = error.downcast_ref::<Error>() {
274                    w.stop();
275                }
276                w.emit(Event::Error(error));
277            }
278        });
279        stream.boxed()
280    }
281    /// Start a worker
282    pub fn run<Req, Ctx>(self) -> Runnable
283    where
284        S: Service<Request<Req, Ctx>> + 'static,
285        P: Backend<Request<Req, Ctx>> + 'static,
286        Req: Send + 'static,
287        S::Error: Send + 'static + Into<BoxDynError>,
288        P::Stream: Unpin + Send + 'static,
289        P::Layer: Layer<S>,
290        <P::Layer as Layer<S>>::Service: Service<Request<Req, Ctx>> + Send,
291        <<P::Layer as Layer<S>>::Service as Service<Request<Req, Ctx>>>::Future: Send,
292        <<P::Layer as Layer<S>>::Service as Service<Request<Req, Ctx>>>::Error:
293            Send + Into<BoxDynError>,
294        Ctx: Send + 'static,
295    {
296        fn type_name_of_val<T>(_t: &T) -> &'static str {
297            std::any::type_name::<T>()
298        }
299        let service = self.state.service;
300        let worker_id = self.id;
301        let ctx = Context {
302            running: Arc::default(),
303            task_count: Arc::default(),
304            waker: Arc::default(),
305            shutdown: self.state.shutdown,
306            event_handler: self.state.event_handler.clone(),
307            is_ready: Arc::default(),
308            service: type_name_of_val(&service).to_owned(),
309        };
310        let worker = Worker {
311            id: worker_id.clone(),
312            state: ctx.clone(),
313        };
314        let backend = self.state.backend;
315
316        let poller = backend.poll(&worker);
317        let stream = poller.stream;
318        let heartbeat = poller.heartbeat.boxed();
319        let layer = poller.layer;
320        let service = ServiceBuilder::new()
321            .layer(TrackerLayer::new(worker.state.clone()))
322            .layer(ReadinessLayer::new(worker.state.is_ready.clone()))
323            .layer(Data::new(worker.clone()))
324            .layer(layer)
325            .service(service);
326
327        Runnable {
328            poller: Self::poll_jobs(worker.clone(), service, stream),
329            heartbeat,
330            worker,
331            running: false,
332        }
333    }
334}
335
336/// A `Runnable` represents a unit of work that manages a worker's lifecycle and execution flow.
337///
338/// The `Runnable` struct is responsible for coordinating the core tasks of a worker, such as polling for jobs,
339/// maintaining heartbeats, and tracking its running state. It integrates various components required for
340/// the worker to operate effectively within an asynchronous runtime.
341#[must_use = "A Runnable must be awaited of no jobs will be consumed"]
342pub struct Runnable {
343    poller: BoxStream<'static, ()>,
344    heartbeat: BoxFuture<'static, ()>,
345    worker: Worker<Context>,
346    running: bool,
347}
348
349impl Runnable {
350    /// Returns a handle to the worker, allowing control and functionality like stopping
351    pub fn get_handle(&self) -> Worker<Context> {
352        self.worker.clone()
353    }
354}
355
356impl fmt::Debug for Runnable {
357    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
358        f.debug_struct("Runnable")
359            .field("poller", &"<stream>")
360            .field("heartbeat", &"<future>")
361            .field("worker", &self.worker)
362            .field("running", &self.running)
363            .finish()
364    }
365}
366
367impl Future for Runnable {
368    type Output = ();
369
370    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
371        let this = self.get_mut();
372        let poller = &mut this.poller;
373        let heartbeat = &mut this.heartbeat;
374        let worker = &mut this.worker;
375
376        let poller_future = async { while (poller.next().await).is_some() {} };
377
378        if !this.running {
379            worker.start();
380            this.running = true;
381        }
382        let combined = Box::pin(join(poller_future, heartbeat.as_mut()));
383
384        let mut combined = select(
385            combined,
386            worker.state.clone().map(|_| worker.emit(Event::Stop)),
387        )
388        .boxed();
389        match Pin::new(&mut combined).poll(cx) {
390            Poll::Ready(_) => {
391                worker.emit(Event::Exit);
392                Poll::Ready(())
393            }
394            Poll::Pending => Poll::Pending,
395        }
396    }
397}
398
399/// Stores the Workers context
400#[derive(Clone, Default)]
401pub struct Context {
402    task_count: Arc<AtomicUsize>,
403    waker: Arc<Mutex<Option<Waker>>>,
404    running: Arc<AtomicBool>,
405    shutdown: Option<Shutdown>,
406    event_handler: EventHandler,
407    is_ready: Arc<AtomicBool>,
408    service: String,
409}
410
411impl fmt::Debug for Context {
412    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
413        f.debug_struct("WorkerContext")
414            .field("shutdown", &["Shutdown handle"])
415            .field("task_count", &self.task_count)
416            .field("running", &self.running)
417            .field("service", &self.service)
418            .finish()
419    }
420}
421
422pin_project! {
423    /// A future tracked by the worker
424    pub struct Tracked<F> {
425        ctx: Context,
426        #[pin]
427        task: F,
428    }
429}
430
431impl<F: Future> Future for Tracked<F> {
432    type Output = F::Output;
433
434    fn poll(self: Pin<&mut Self>, cx: &mut TaskCtx<'_>) -> Poll<F::Output> {
435        let this = self.project();
436
437        match this.task.poll(cx) {
438            res @ Poll::Ready(_) => {
439                this.ctx.end_task();
440                res
441            }
442            Poll::Pending => Poll::Pending,
443        }
444    }
445}
446
447impl Context {
448    /// Start a task that is tracked by the worker
449    pub fn track<F: Future>(&self, task: F) -> Tracked<F> {
450        self.start_task();
451        Tracked {
452            ctx: self.clone(),
453            task,
454        }
455    }
456
457    /// Calling this function triggers shutting down the worker while waiting for any tasks to complete
458    pub fn stop(&self) {
459        self.running.store(false, Ordering::Relaxed);
460        self.wake()
461    }
462
463    fn start_task(&self) {
464        self.task_count.fetch_add(1, Ordering::Relaxed);
465    }
466
467    fn end_task(&self) {
468        if self.task_count.fetch_sub(1, Ordering::Relaxed) == 1 {
469            self.wake();
470        }
471    }
472
473    pub(crate) fn wake(&self) {
474        if let Ok(waker) = self.waker.lock() {
475            if let Some(waker) = &*waker {
476                waker.wake_by_ref();
477            }
478        }
479    }
480
481    /// Returns whether the worker is running
482    pub fn is_running(&self) -> bool {
483        self.running.load(Ordering::Relaxed)
484    }
485
486    /// Returns the current futures in the worker domain
487    /// This include futures spawned via `worker.track`
488    pub fn task_count(&self) -> usize {
489        self.task_count.load(Ordering::Relaxed)
490    }
491
492    /// Returns whether the worker has pending tasks
493    pub fn has_pending_tasks(&self) -> bool {
494        self.task_count.load(Ordering::Relaxed) > 0
495    }
496
497    /// Is the shutdown token called
498    pub fn is_shutting_down(&self) -> bool {
499        self.shutdown
500            .as_ref()
501            .map(|s| !self.is_running() || s.is_shutting_down())
502            .unwrap_or(!self.is_running())
503    }
504
505    fn add_waker(&self, cx: &mut TaskCtx<'_>) {
506        if let Ok(mut waker_guard) = self.waker.lock() {
507            if waker_guard
508                .as_ref()
509                .map_or(true, |stored_waker| !stored_waker.will_wake(cx.waker()))
510            {
511                *waker_guard = Some(cx.waker().clone());
512            }
513        }
514    }
515
516    /// Checks if the stored waker matches the current one.
517    fn has_recent_waker(&self, cx: &TaskCtx<'_>) -> bool {
518        if let Ok(waker_guard) = self.waker.lock() {
519            if let Some(stored_waker) = &*waker_guard {
520                return stored_waker.will_wake(cx.waker());
521            }
522        }
523        false
524    }
525
526    /// Returns if the worker is ready to consume new tasks
527    pub fn is_ready(&self) -> bool {
528        self.is_ready.load(Ordering::Acquire) && !self.is_shutting_down()
529    }
530
531    /// Get the type of service
532    pub fn get_service(&self) -> &String {
533        &self.service
534    }
535}
536
537impl Future for Context {
538    type Output = ();
539
540    fn poll(self: Pin<&mut Self>, cx: &mut TaskCtx<'_>) -> Poll<()> {
541        let task_count = self.task_count.load(Ordering::Relaxed);
542        if self.is_shutting_down() && task_count == 0 {
543            Poll::Ready(())
544        } else {
545            if !self.has_recent_waker(cx) {
546                self.add_waker(cx);
547            }
548            Poll::Pending
549        }
550    }
551}
552
553#[derive(Debug, Clone)]
554struct TrackerLayer {
555    ctx: Context,
556}
557
558impl TrackerLayer {
559    fn new(ctx: Context) -> Self {
560        Self { ctx }
561    }
562}
563
564impl<S> Layer<S> for TrackerLayer {
565    type Service = TrackerService<S>;
566
567    fn layer(&self, service: S) -> Self::Service {
568        TrackerService {
569            ctx: self.ctx.clone(),
570            service,
571        }
572    }
573}
574#[derive(Debug, Clone)]
575struct TrackerService<S> {
576    ctx: Context,
577    service: S,
578}
579
580impl<S, Req, Ctx> Service<Request<Req, Ctx>> for TrackerService<S>
581where
582    S: Service<Request<Req, Ctx>>,
583{
584    type Response = S::Response;
585    type Error = S::Error;
586    type Future = Tracked<S::Future>;
587
588    fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
589        self.service.poll_ready(cx)
590    }
591
592    fn call(&mut self, request: Request<Req, Ctx>) -> Self::Future {
593        request.parts.attempt.increment();
594        self.ctx.track(self.service.call(request))
595    }
596}
597
/// Layer that mirrors the inner service's readiness into a shared flag.
#[derive(Clone)]
struct ReadinessLayer {
    /// Flag shared with the worker context.
    is_ready: Arc<AtomicBool>,
}

impl ReadinessLayer {
    /// Creates the layer around the given readiness flag.
    fn new(is_ready: Arc<AtomicBool>) -> Self {
        ReadinessLayer { is_ready }
    }
}
608
609impl<S> Layer<S> for ReadinessLayer {
610    type Service = ReadinessService<S>;
611
612    fn layer(&self, inner: S) -> Self::Service {
613        ReadinessService {
614            inner,
615            is_ready: self.is_ready.clone(),
616        }
617    }
618}
619
/// Service wrapper that records the inner service's readiness in a shared flag.
struct ReadinessService<S> {
    /// The wrapped service.
    inner: S,
    /// Shared readiness flag updated on every `poll_ready`.
    is_ready: Arc<AtomicBool>,
}
624
625impl<S, Request> Service<Request> for ReadinessService<S>
626where
627    S: Service<Request>,
628{
629    type Response = S::Response;
630    type Error = S::Error;
631    type Future = S::Future;
632
633    fn poll_ready(&mut self, cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> {
634        // Delegate poll_ready to the inner service
635        let result = self.inner.poll_ready(cx);
636        // Update the readiness state based on the result
637        match &result {
638            Poll::Ready(Ok(_)) => self.is_ready.store(true, Ordering::Release),
639            Poll::Pending | Poll::Ready(Err(_)) => self.is_ready.store(false, Ordering::Release),
640        }
641
642        result
643    }
644
645    fn call(&mut self, req: Request) -> Self::Future {
646        self.inner.call(req)
647    }
648}
649
#[cfg(test)]
mod tests {
    use std::{ops::Deref, sync::atomic::AtomicUsize};

    use crate::{
        builder::{WorkerBuilder, WorkerFactoryFn},
        layers::extensions::Data,
        memory::MemoryStorage,
        mq::MessageQueue,
    };

    use super::*;

    const ITEMS: u32 = 100;

    #[test]
    fn it_parses_worker_names() {
        let cases = ["worker", "worker-0", "complex&*-worker-name-0"];
        for raw in cases.iter().copied() {
            let parsed = WorkerId::from_str(raw).unwrap();
            let expected = WorkerId {
                name: raw.to_string(),
            };
            assert_eq!(parsed, expected);
        }
    }

    #[tokio::test]
    async fn it_works() {
        let storage = MemoryStorage::new();
        let mut producer = storage.clone();

        // Feed ITEMS jobs into the queue from a separate task.
        tokio::spawn(async move {
            for job in 0..ITEMS {
                producer.enqueue(job).await.unwrap();
            }
        });

        #[derive(Clone, Debug, Default)]
        struct Count(Arc<AtomicUsize>);

        impl Deref for Count {
            type Target = Arc<AtomicUsize>;
            fn deref(&self) -> &Self::Target {
                &self.0
            }
        }

        // Counts every job and stops the worker when the last one arrives.
        async fn task(job: u32, count: Data<Count>, worker: Worker<Context>) {
            count.fetch_add(1, Ordering::Relaxed);
            if job == ITEMS - 1 {
                worker.stop();
            }
        }

        WorkerBuilder::new("rango-tango")
            .data(Count::default())
            .backend(storage)
            .build_fn(task)
            .run()
            .await;
    }
}