1use crate::catch_unwind::CatchUnwindFuture;
2use crate::errors::{AsyncQueueError, BackieError};
3use crate::runnable::BackgroundTask;
4use crate::store::TaskStore;
5use crate::task::{CurrentTask, Task, TaskState};
6use crate::{QueueConfig, RetentionMode};
7use futures::future::FutureExt;
8use futures::select;
9use std::collections::BTreeMap;
10use std::future::Future;
11use std::pin::Pin;
12use std::sync::Arc;
13
14pub type ExecuteTaskFn<AppData> = Arc<
15 dyn Fn(
16 CurrentTask,
17 serde_json::Value,
18 AppData,
19 ) -> Pin<Box<dyn Future<Output = Result<(), TaskExecError>> + Send>>
20 + Send
21 + Sync,
22>;
23
/// Shared factory that produces a fresh `AppData` value each time it is called.
pub type StateFn<AppData> = Arc<dyn Fn() -> AppData + Send + Sync>;
25
26#[derive(Debug, thiserror::Error)]
27pub enum TaskExecError {
28 #[error("Task deserialization failed: {0}")]
29 TaskDeserializationFailed(#[from] serde_json::Error),
30
31 #[error("Task execution failed: {0}")]
32 ExecutionFailed(String),
33
34 #[error("Task panicked with: {0}")]
35 Panicked(String),
36}
37
38pub(crate) fn runnable<BT>(
39 task_info: CurrentTask,
40 payload: serde_json::Value,
41 app_context: BT::AppData,
42) -> Pin<Box<dyn Future<Output = Result<(), TaskExecError>> + Send>>
43where
44 BT: BackgroundTask,
45{
46 Box::pin(async move {
47 let background_task: BT = serde_json::from_value(payload)?;
48 match background_task.run(task_info, app_context).await {
49 Ok(_) => Ok(()),
50 Err(err) => Err(TaskExecError::ExecutionFailed(format!("{:?}", err))),
51 }
52 })
53}
54
55pub struct Worker<AppData, S>
57where
58 AppData: Clone + Send + 'static,
59 S: TaskStore + Clone,
60{
61 store: S,
62
63 config: QueueConfig,
64
65 task_registry: BTreeMap<String, ExecuteTaskFn<AppData>>,
66
67 app_data_fn: StateFn<AppData>,
68
69 shutdown: Option<tokio::sync::watch::Receiver<()>>,
71}
72
73impl<AppData, S> Worker<AppData, S>
74where
75 AppData: Clone + Send + 'static,
76 S: TaskStore + Clone,
77{
78 pub(crate) fn new(
79 store: S,
80 config: QueueConfig,
81 task_registry: BTreeMap<String, ExecuteTaskFn<AppData>>,
82 app_data_fn: StateFn<AppData>,
83 shutdown: Option<tokio::sync::watch::Receiver<()>>,
84 ) -> Self {
85 Self {
86 store,
87 config,
88 task_registry,
89 app_data_fn,
90 shutdown,
91 }
92 }
93
94 pub(crate) async fn run_tasks(&mut self) -> Result<(), BackieError> {
95 let registered_task_names = self.task_registry.keys().cloned().collect::<Vec<_>>();
96 loop {
97 if let Some(ref shutdown) = self.shutdown {
99 if shutdown.has_changed()? {
100 return Ok(());
101 }
102 };
103
104 match self
105 .store
106 .pull_next_task(
107 &self.config.name,
108 self.config.execution_timeout,
109 ®istered_task_names,
110 )
111 .await?
112 {
113 Some(task) => {
114 self.run(task).await?;
115 }
116 None => {
117 match &mut self.shutdown {
120 Some(recv) => {
121 select! {
124 _ = recv.changed().fuse() => {
125 log::info!("Shutting down worker");
126 return Ok(());
127 }
128 _ = tokio::time::sleep(self.config.pull_interval).fuse() => {}
129 }
130 }
131 None => {
132 tokio::time::sleep(self.config.pull_interval).await;
133 }
134 };
135 }
136 };
137 }
138 }
139
140 async fn run(&self, task: Task) -> Result<(), BackieError> {
141 let task_info = CurrentTask::new(&task);
142 let runnable_task_caller = self
143 .task_registry
144 .get(&task.task_name)
145 .ok_or_else(|| AsyncQueueError::TaskNotRegistered(task.task_name.clone()))?;
146
147 let result: Result<(), TaskExecError> = CatchUnwindFuture::create({
149 let task_payload = task.payload.clone();
150 let app_data = (self.app_data_fn)();
151 let runnable_task_caller = runnable_task_caller.clone();
152 async move { runnable_task_caller(task_info, task_payload, app_data).await }
153 })
154 .await
155 .and_then(|result| {
156 result?;
157 Ok(())
158 });
159
160 match &result {
161 Ok(_) => self.finalize_task(task, result).await?,
162 Err(error) => {
163 if task.retries < task.max_retries {
164 let backoff = task.backoff_mode().next_attempt(task.retries);
165
166 log::debug!(
167 "Task {} failed to run and will be retried in {} seconds",
168 task.id,
169 backoff.as_secs()
170 );
171
172 let error_message = format!("{}", error);
173
174 self.store
175 .schedule_task_retry(task.id, backoff, &error_message)
176 .await?;
177 } else {
178 log::debug!("Task {} failed and reached the maximum retries", task.id);
179 self.finalize_task(task, result).await?;
180 }
181 }
182 }
183 Ok(())
184 }
185
186 async fn finalize_task(
187 &self,
188 task: Task,
189 result: Result<(), TaskExecError>,
190 ) -> Result<(), BackieError> {
191 match self.config.retention_mode {
192 RetentionMode::KeepAll => match result {
193 Ok(_) => {
194 self.store.set_task_state(task.id, TaskState::Done).await?;
195 log::debug!("Task {} done and kept in the database", task.id);
196 }
197 Err(error) => {
198 log::debug!("Task {} failed and kept in the database", task.id);
199 self.store
200 .set_task_state(task.id, TaskState::Failed(format!("{}", error)))
201 .await?;
202 }
203 },
204 RetentionMode::RemoveAll => {
205 log::debug!("Task {} finalized and deleted from the database", task.id);
206 self.store.remove_task(task.id).await?;
207 }
208 RetentionMode::RemoveDone => match result {
209 Ok(_) => {
210 log::debug!("Task {} done and deleted from the database", task.id);
211 self.store.remove_task(task.id).await?;
212 }
213 Err(error) => {
214 log::debug!("Task {} failed and kept in the database", task.id);
215 self.store
216 .set_task_state(task.id, TaskState::Failed(format!("{}", error)))
217 .await?;
218 }
219 },
220 };
221
222 Ok(())
223 }
224}
225
#[cfg(test)]
mod async_worker_tests {
    use super::*;
    use async_trait::async_trait;
    use serde::{Deserialize, Serialize};

    /// Error type returned by the deliberately failing test tasks.
    #[derive(Debug, thiserror::Error)]
    enum TaskError {
        #[error("Something went wrong")]
        SomethingWrong,

        #[error("{0}")]
        Custom(String),
    }

    /// Task that always succeeds.
    #[derive(Serialize, Deserialize)]
    struct WorkerAsyncTask {
        pub number: u16,
    }

    #[async_trait]
    impl BackgroundTask for WorkerAsyncTask {
        const TASK_NAME: &'static str = "WorkerAsyncTask";
        type AppData = ();
        type Error = ();

        async fn run(
            &self,
            _current: CurrentTask,
            _app_data: Self::AppData,
        ) -> Result<(), Self::Error> {
            Ok(())
        }
    }

    /// Succeeding task registered under a different name.
    #[derive(Serialize, Deserialize)]
    struct WorkerAsyncTaskSchedule {
        pub number: u16,
    }

    #[async_trait]
    impl BackgroundTask for WorkerAsyncTaskSchedule {
        const TASK_NAME: &'static str = "WorkerAsyncTaskSchedule";
        type AppData = ();
        type Error = ();

        async fn run(
            &self,
            _current: CurrentTask,
            _app_data: Self::AppData,
        ) -> Result<(), Self::Error> {
            Ok(())
        }
    }

    /// Task that always fails and is never retried (`MAX_RETRIES = 0`).
    #[derive(Serialize, Deserialize)]
    struct AsyncFailedTask {
        pub number: u16,
    }

    #[async_trait]
    impl BackgroundTask for AsyncFailedTask {
        const TASK_NAME: &'static str = "AsyncFailedTask";
        const MAX_RETRIES: i32 = 0;
        type AppData = ();
        type Error = TaskError;

        async fn run(
            &self,
            _current: CurrentTask,
            _app_data: Self::AppData,
        ) -> Result<(), Self::Error> {
            Err(TaskError::Custom(format!(
                "number {} is wrong :(",
                self.number
            )))
        }
    }

    /// Task that always fails and keeps the default retry policy.
    #[derive(Serialize, Deserialize, Clone)]
    struct AsyncRetryTask {}

    #[async_trait]
    impl BackgroundTask for AsyncRetryTask {
        const TASK_NAME: &'static str = "AsyncRetryTask";
        type AppData = ();
        type Error = TaskError;

        async fn run(
            &self,
            _current: CurrentTask,
            _app_data: Self::AppData,
        ) -> Result<(), Self::Error> {
            Err(TaskError::SomethingWrong)
        }
    }

    /// First of two distinct succeeding task types.
    #[derive(Serialize, Deserialize)]
    struct AsyncTaskType1 {}

    #[async_trait]
    impl BackgroundTask for AsyncTaskType1 {
        const TASK_NAME: &'static str = "AsyncTaskType1";
        type AppData = ();
        type Error = ();

        async fn run(
            &self,
            _current: CurrentTask,
            _app_data: Self::AppData,
        ) -> Result<(), Self::Error> {
            Ok(())
        }
    }

    /// Second of two distinct succeeding task types.
    #[derive(Serialize, Deserialize)]
    struct AsyncTaskType2 {}

    #[async_trait]
    impl BackgroundTask for AsyncTaskType2 {
        const TASK_NAME: &'static str = "AsyncTaskType2";
        type AppData = ();
        type Error = ();

        async fn run(
            &self,
            _current: CurrentTask,
            _app_data: Self::AppData,
        ) -> Result<(), Self::Error> {
            Ok(())
        }
    }
}