app_task/
lib.rs

1use std::fmt::Display;
2use std::sync::Arc;
3
4use backoff_strategy::constant_time::ConstantTimeBackoff;
5use backoff_strategy::{BackoffStrategy, DefaultStrategyFactory, StrategyFactory};
6use futures::Future;
7use tokio::task::JoinHandle;
8
9pub mod backoff_strategy;
10
/// Spawns application tasks with `tokio::spawn` and retries failed attempts after a backoff.
///
/// `T` is the application state; it is held behind an [`Arc`] and a clone of that `Arc` is
/// passed to every attempt of every spawned task. `SF` is the factory asked for a backoff
/// strategy on each `spawn_task` call; it defaults to
/// `DefaultStrategyFactory<ConstantTimeBackoff>`.
pub struct TaskRunner<T, SF = DefaultStrategyFactory<ConstantTimeBackoff>> {
    /// Shared application state handed to each task attempt.
    app: Arc<T>,
    /// Factory whose `create_strategy` is called once per `spawn_task` call.
    backoff_strategy: SF,
}
15
16impl<T> TaskRunner<T> {
17    pub fn new(app: Arc<T>) -> Self {
18        Self {
19            app,
20            backoff_strategy: DefaultStrategyFactory::new(),
21        }
22    }
23}
24
25impl<T, SF> TaskRunner<T, SF>
26where
27    T: Send + Sync + 'static,
28    SF: StrategyFactory,
29{
30    pub fn with_default_strategy<NS>(self) -> TaskRunner<T, DefaultStrategyFactory<NS>>
31    where
32        NS: StrategyFactory,
33    {
34        TaskRunner {
35            app: self.app,
36            backoff_strategy: DefaultStrategyFactory::new(),
37        }
38    }
39
40    pub fn with_strategy<NSF>(self, backoff_strategy: NSF) -> TaskRunner<T, NSF> {
41        TaskRunner {
42            app: self.app,
43            backoff_strategy,
44        }
45    }
46}
47
48impl<T, SF> TaskRunner<T, SF>
49where
50    T: Send + Sync + 'static,
51    SF: StrategyFactory,
52{
53    /// Spawns a task that will run until it returns Ok(()) or panics.
54    /// If the task returns an error, it will be logged and the task will be retried with a backoff.
55    ///
56    /// If the task panics, the panic output will be returned as an error.
57    pub fn spawn_task<S, C, F, E>(&self, label: S, task: C) -> JoinHandle<()>
58    where
59        S: ToString,
60        C: Fn(Arc<T>) -> F + Send + Sync + 'static,
61        F: Future<Output = Result<(), E>> + Send + 'static,
62        E: Display + Send + Sync,
63    {
64        let app = self.app.clone();
65        let label = label.to_string();
66
67        let mut backoff_strategy = self.backoff_strategy.create_strategy();
68
69        tokio::spawn(async move {
70            loop {
71                tracing::info!(task_label = label, "Running task");
72
73                let result = task(app.clone()).await;
74
75                match result {
76                    Ok(()) => {
77                        tracing::info!(task_label = label, "Task finished");
78                        break;
79                    }
80                    Err(err) => {
81                        tracing::error!(task_label = label, error = %err, "Task failed");
82                        backoff_strategy.add_failure();
83                        tokio::time::sleep(backoff_strategy.next_backoff()).await;
84                    }
85                }
86            }
87        })
88    }
89}