backie/
worker.rs

1use crate::catch_unwind::CatchUnwindFuture;
2use crate::errors::{AsyncQueueError, BackieError};
3use crate::runnable::BackgroundTask;
4use crate::store::TaskStore;
5use crate::task::{CurrentTask, Task, TaskState};
6use crate::{QueueConfig, RetentionMode};
7use futures::future::FutureExt;
8use futures::select;
9use std::collections::BTreeMap;
10use std::future::Future;
11use std::pin::Pin;
12use std::sync::Arc;
13
14pub type ExecuteTaskFn<AppData> = Arc<
15    dyn Fn(
16            CurrentTask,
17            serde_json::Value,
18            AppData,
19        ) -> Pin<Box<dyn Future<Output = Result<(), TaskExecError>> + Send>>
20        + Send
21        + Sync,
22>;
23
/// Shared factory producing the application data given to a task.
///
/// The worker calls it once per task execution.
pub type StateFn<AppData> = Arc<dyn Fn() -> AppData + Sync + Send>;
25
26#[derive(Debug, thiserror::Error)]
27pub enum TaskExecError {
28    #[error("Task deserialization failed: {0}")]
29    TaskDeserializationFailed(#[from] serde_json::Error),
30
31    #[error("Task execution failed: {0}")]
32    ExecutionFailed(String),
33
34    #[error("Task panicked with: {0}")]
35    Panicked(String),
36}
37
38pub(crate) fn runnable<BT>(
39    task_info: CurrentTask,
40    payload: serde_json::Value,
41    app_context: BT::AppData,
42) -> Pin<Box<dyn Future<Output = Result<(), TaskExecError>> + Send>>
43where
44    BT: BackgroundTask,
45{
46    Box::pin(async move {
47        let background_task: BT = serde_json::from_value(payload)?;
48        match background_task.run(task_info, app_context).await {
49            Ok(_) => Ok(()),
50            Err(err) => Err(TaskExecError::ExecutionFailed(format!("{:?}", err))),
51        }
52    })
53}
54
55/// Worker that executes tasks.
56pub struct Worker<AppData, S>
57where
58    AppData: Clone + Send + 'static,
59    S: TaskStore + Clone,
60{
61    store: S,
62
63    config: QueueConfig,
64
65    task_registry: BTreeMap<String, ExecuteTaskFn<AppData>>,
66
67    app_data_fn: StateFn<AppData>,
68
69    /// Notification for the worker to stop.
70    shutdown: Option<tokio::sync::watch::Receiver<()>>,
71}
72
73impl<AppData, S> Worker<AppData, S>
74where
75    AppData: Clone + Send + 'static,
76    S: TaskStore + Clone,
77{
78    pub(crate) fn new(
79        store: S,
80        config: QueueConfig,
81        task_registry: BTreeMap<String, ExecuteTaskFn<AppData>>,
82        app_data_fn: StateFn<AppData>,
83        shutdown: Option<tokio::sync::watch::Receiver<()>>,
84    ) -> Self {
85        Self {
86            store,
87            config,
88            task_registry,
89            app_data_fn,
90            shutdown,
91        }
92    }
93
94    pub(crate) async fn run_tasks(&mut self) -> Result<(), BackieError> {
95        let registered_task_names = self.task_registry.keys().cloned().collect::<Vec<_>>();
96        loop {
97            // Check if has to stop before pulling next task
98            if let Some(ref shutdown) = self.shutdown {
99                if shutdown.has_changed()? {
100                    return Ok(());
101                }
102            };
103
104            match self
105                .store
106                .pull_next_task(
107                    &self.config.name,
108                    self.config.execution_timeout,
109                    &registered_task_names,
110                )
111                .await?
112            {
113                Some(task) => {
114                    self.run(task).await?;
115                }
116                None => {
117                    // Listen to watchable future
118                    // All that until a max timeout
119                    match &mut self.shutdown {
120                        Some(recv) => {
121                            // Listen to watchable future
122                            // All that until a max timeout
123                            select! {
124                                _ = recv.changed().fuse() => {
125                                    log::info!("Shutting down worker");
126                                    return Ok(());
127                                }
128                                _ = tokio::time::sleep(self.config.pull_interval).fuse() => {}
129                            }
130                        }
131                        None => {
132                            tokio::time::sleep(self.config.pull_interval).await;
133                        }
134                    };
135                }
136            };
137        }
138    }
139
140    async fn run(&self, task: Task) -> Result<(), BackieError> {
141        let task_info = CurrentTask::new(&task);
142        let runnable_task_caller = self
143            .task_registry
144            .get(&task.task_name)
145            .ok_or_else(|| AsyncQueueError::TaskNotRegistered(task.task_name.clone()))?;
146
147        // catch panics
148        let result: Result<(), TaskExecError> = CatchUnwindFuture::create({
149            let task_payload = task.payload.clone();
150            let app_data = (self.app_data_fn)();
151            let runnable_task_caller = runnable_task_caller.clone();
152            async move { runnable_task_caller(task_info, task_payload, app_data).await }
153        })
154        .await
155        .and_then(|result| {
156            result?;
157            Ok(())
158        });
159
160        match &result {
161            Ok(_) => self.finalize_task(task, result).await?,
162            Err(error) => {
163                if task.retries < task.max_retries {
164                    let backoff = task.backoff_mode().next_attempt(task.retries);
165
166                    log::debug!(
167                        "Task {} failed to run and will be retried in {} seconds",
168                        task.id,
169                        backoff.as_secs()
170                    );
171
172                    let error_message = format!("{}", error);
173
174                    self.store
175                        .schedule_task_retry(task.id, backoff, &error_message)
176                        .await?;
177                } else {
178                    log::debug!("Task {} failed and reached the maximum retries", task.id);
179                    self.finalize_task(task, result).await?;
180                }
181            }
182        }
183        Ok(())
184    }
185
186    async fn finalize_task(
187        &self,
188        task: Task,
189        result: Result<(), TaskExecError>,
190    ) -> Result<(), BackieError> {
191        match self.config.retention_mode {
192            RetentionMode::KeepAll => match result {
193                Ok(_) => {
194                    self.store.set_task_state(task.id, TaskState::Done).await?;
195                    log::debug!("Task {} done and kept in the database", task.id);
196                }
197                Err(error) => {
198                    log::debug!("Task {} failed and kept in the database", task.id);
199                    self.store
200                        .set_task_state(task.id, TaskState::Failed(format!("{}", error)))
201                        .await?;
202                }
203            },
204            RetentionMode::RemoveAll => {
205                log::debug!("Task {} finalized and deleted from the database", task.id);
206                self.store.remove_task(task.id).await?;
207            }
208            RetentionMode::RemoveDone => match result {
209                Ok(_) => {
210                    log::debug!("Task {} done and deleted from the database", task.id);
211                    self.store.remove_task(task.id).await?;
212                }
213                Err(error) => {
214                    log::debug!("Task {} failed and kept in the database", task.id);
215                    self.store
216                        .set_task_state(task.id, TaskState::Failed(format!("{}", error)))
217                        .await?;
218                }
219            },
220        };
221
222        Ok(())
223    }
224}
225
#[cfg(test)]
mod async_worker_tests {
    //! Background-task fixtures for exercising the worker: tasks that always
    //! succeed, tasks that always fail, and several distinct task names.

    use super::*;
    use async_trait::async_trait;
    use serde::{Deserialize, Serialize};

    /// Error returned by the failing fixtures.
    #[derive(thiserror::Error, Debug)]
    enum TaskError {
        #[error("Something went wrong")]
        SomethingWrong,

        #[error("{0}")]
        Custom(String),
    }

    /// Task that always succeeds.
    #[derive(Serialize, Deserialize)]
    struct WorkerAsyncTask {
        pub number: u16,
    }

    #[async_trait]
    impl BackgroundTask for WorkerAsyncTask {
        const TASK_NAME: &'static str = "WorkerAsyncTask";
        type AppData = ();
        type Error = ();

        async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), Self::Error> {
            Ok(())
        }
    }

    /// Task that always succeeds, registered under its own name.
    #[derive(Serialize, Deserialize)]
    struct WorkerAsyncTaskSchedule {
        pub number: u16,
    }

    #[async_trait]
    impl BackgroundTask for WorkerAsyncTaskSchedule {
        const TASK_NAME: &'static str = "WorkerAsyncTaskSchedule";
        type AppData = ();
        type Error = ();

        async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), Self::Error> {
            Ok(())
        }

        // A `cron` override used to live here. It returned a one-shot schedule
        // one second in the future (`Scheduled::ScheduleOnce`) and is
        // currently disabled.
    }

    /// Task that always fails with a message built from `number`.
    ///
    /// It declares `MAX_RETRIES = 0`.
    #[derive(Serialize, Deserialize)]
    struct AsyncFailedTask {
        pub number: u16,
    }

    #[async_trait]
    impl BackgroundTask for AsyncFailedTask {
        const TASK_NAME: &'static str = "AsyncFailedTask";
        const MAX_RETRIES: i32 = 0;
        type AppData = ();
        type Error = TaskError;

        async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), Self::Error> {
            Err(TaskError::Custom(format!("number {} is wrong :(", self.number)))
        }
    }

    /// Task that always fails with [`TaskError::SomethingWrong`].
    ///
    /// It uses the default retry budget.
    #[derive(Serialize, Deserialize, Clone)]
    struct AsyncRetryTask {}

    #[async_trait]
    impl BackgroundTask for AsyncRetryTask {
        const TASK_NAME: &'static str = "AsyncRetryTask";
        type AppData = ();
        type Error = TaskError;

        async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), Self::Error> {
            Err(TaskError::SomethingWrong)
        }
    }

    /// First of two trivially succeeding task types with distinct names.
    #[derive(Serialize, Deserialize)]
    struct AsyncTaskType1 {}

    #[async_trait]
    impl BackgroundTask for AsyncTaskType1 {
        const TASK_NAME: &'static str = "AsyncTaskType1";
        type AppData = ();
        type Error = ();

        async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), Self::Error> {
            Ok(())
        }
    }

    /// Second of two trivially succeeding task types with distinct names.
    #[derive(Serialize, Deserialize)]
    struct AsyncTaskType2 {}

    #[async_trait]
    impl BackgroundTask for AsyncTaskType2 {
        const TASK_NAME: &'static str = "AsyncTaskType2";
        type AppData = ();
        type Error = ();

        async fn run(&self, _task: CurrentTask, _data: Self::AppData) -> Result<(), Self::Error> {
            Ok(())
        }
    }
}