celery/task/
mod.rs

1//! Provides the [`Task`] trait as well as options for configuring tasks.
2
3use async_trait::async_trait;
4use chrono::{DateTime, Utc};
5use rand::distributions::{Distribution, Uniform};
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8use std::time::{SystemTime, UNIX_EPOCH};
9
10use crate::error::TaskError;
11
12mod async_result;
13mod options;
14mod request;
15mod signature;
16
17pub use async_result::AsyncResult;
18pub use options::TaskOptions;
19pub use request::Request;
20pub use signature::Signature;
21
22/// The return type for a task.
23pub type TaskResult<R> = Result<R, TaskError>;
24
25#[doc(hidden)]
26pub trait AsTaskResult {
27    type Returns: Send + Sync + std::fmt::Debug + ResultValue + for<'de> Deserialize<'de>;
28}
29
30impl<R> AsTaskResult for TaskResult<R>
31where
32    R: Send + Sync + std::fmt::Debug + ResultValue + for<'de> Deserialize<'de>,
33{
34    type Returns = R;
35}
36
37/// Helper trait for converting task return values into JSON for storage.
38pub trait ResultValue {
39    fn to_json_value(&self) -> Result<Value, serde_json::Error>;
40}
41
42impl<T> ResultValue for T
43where
44    T: Serialize,
45{
46    fn to_json_value(&self) -> Result<Value, serde_json::Error> {
47        serde_json::to_value(self)
48    }
49}
50
51/// A `Task` represents a unit of work that a `Celery` app can produce or consume.
52///
53/// The recommended way to create tasks is through the [`task`](../attr.task.html) attribute macro, not by directly implementing
54/// this trait. For more information see the [tasks chapter](https://rusty-celery.github.io/guide/defining-tasks.html)
55/// in the Rusty Celery Book.
56#[async_trait]
57pub trait Task: Send + Sync + std::marker::Sized {
58    /// The name of the task. When a task is registered it will be registered with this name.
59    const NAME: &'static str;
60
61    /// For compatability with Python tasks. This keeps track of the order
62    /// of arguments for the task so that the task can be called from Python with
63    /// positional arguments.
64    const ARGS: &'static [&'static str];
65
66    /// Default task options.
67    const DEFAULTS: TaskOptions = TaskOptions {
68        time_limit: None,
69        hard_time_limit: None,
70        expires: None,
71        max_retries: None,
72        min_retry_delay: None,
73        max_retry_delay: None,
74        retry_for_unexpected: None,
75        acks_late: None,
76        content_type: None,
77    };
78
79    /// The parameters of the task.
80    type Params: Clone + Send + Sync + Serialize + for<'de> Deserialize<'de>;
81
82    /// The return type of the task.
83    type Returns: Send + Sync + std::fmt::Debug + ResultValue + for<'de> Deserialize<'de>;
84
85    /// Used to initialize a task instance from a request.
86    fn from_request(request: Request<Self>, options: TaskOptions) -> Self;
87
88    /// Get a reference to the request used to create this task instance.
89    fn request(&self) -> &Request<Self>;
90
91    /// Get a reference to the task's configuration options.
92    ///
93    /// This is a product of both app-level task options and the options configured specifically
94    /// for the given task. Options specified at the *task*-level take priority over options
95    /// specified at the app level. So, if the task was defined like this:
96    ///
97    /// ```rust
98    /// # use celery::prelude::*;
99    /// #[celery::task(time_limit = 3)]
100    /// fn add(x: i32, y: i32) -> TaskResult<i32> {
101    ///     Ok(x + y)
102    /// }
103    /// ```
104    ///
105    /// But the `Celery` app was built with a `task_time_limit` of 5, then
106    /// `Task::options().time_limit` would be `Some(3)`.
107    fn options(&self) -> &TaskOptions;
108
109    /// This function defines how a task executes.
110    async fn run(&self, params: Self::Params) -> TaskResult<Self::Returns>;
111
112    /// Callback that will run after a task fails.
113    #[allow(unused_variables)]
114    async fn on_failure(&self, err: &TaskError) {}
115
116    /// Callback that will run after a task completes successfully.
117    #[allow(unused_variables)]
118    async fn on_success(&self, returned: &Self::Returns) {}
119
120    /// Returns the registered name of the task.
121    fn name(&self) -> &'static str {
122        Self::NAME
123    }
124
125    /// This can be called from within a task function to trigger a retry in `countdown` seconds.
126    fn retry_with_countdown(&self, countdown: u32) -> TaskResult<Self::Returns> {
127        let eta = match SystemTime::now().duration_since(UNIX_EPOCH) {
128            Ok(now) => {
129                let now_secs = now.as_secs() as u32;
130                let now_millis = now.subsec_millis();
131                let eta_secs = now_secs + countdown;
132                Some(DateTime::<Utc>::from_naive_utc_and_offset(
133                    DateTime::from_timestamp(eta_secs as i64, now_millis * 1000)
134                        .map(|dt| dt.naive_utc())
135                        .ok_or_else(|| {
136                            TaskError::UnexpectedError(format!(
137                                "Invalid countdown seconds {countdown}",
138                            ))
139                        })?,
140                    Utc,
141                ))
142            }
143            Err(_) => None,
144        };
145        Err(TaskError::Retry(eta))
146    }
147
148    /// This can be called from within a task function to trigger a retry at the specified `eta`.
149    fn retry_with_eta(&self, eta: DateTime<Utc>) -> TaskResult<Self::Returns> {
150        Err(TaskError::Retry(Some(eta)))
151    }
152
153    /// Get a future ETA at which time the task should be retried. By default this
154    /// uses a capped exponential backoff strategy.
155    fn retry_eta(&self) -> Option<DateTime<Utc>> {
156        let retries = self.request().retries;
157        let delay_secs = std::cmp::min(
158            2u32.checked_pow(retries)
159                .unwrap_or_else(|| self.max_retry_delay()),
160            self.max_retry_delay(),
161        );
162        let delay_secs = std::cmp::max(delay_secs, self.min_retry_delay());
163        let between = Uniform::from(0..1000);
164        let mut rng = rand::thread_rng();
165        let delay_millis = between.sample(&mut rng);
166        match SystemTime::now().duration_since(UNIX_EPOCH) {
167            Ok(now) => {
168                let now_secs = now.as_secs() as u32;
169                let now_millis = now.subsec_millis();
170                let eta_secs = now_secs + delay_secs;
171                let eta_millis = now_millis + delay_millis;
172                DateTime::from_timestamp(eta_secs as i64, eta_millis * 1000)
173                    .map(|dt| dt.naive_utc())
174                    .map(|eta| DateTime::<Utc>::from_naive_utc_and_offset(eta, Utc))
175            }
176            Err(_) => None,
177        }
178    }
179
180    fn retry_for_unexpected(&self) -> bool {
181        Self::DEFAULTS
182            .retry_for_unexpected
183            .or(self.options().retry_for_unexpected)
184            .unwrap_or(true)
185    }
186
187    fn time_limit(&self) -> Option<u32> {
188        self.request().time_limit.or_else(|| {
189            // Take min or `time_limit` and `hard_time_limit`.
190            let time_limit = Self::DEFAULTS.time_limit.or(self.options().time_limit);
191            let hard_time_limit = Self::DEFAULTS
192                .hard_time_limit
193                .or(self.options().hard_time_limit);
194            match (time_limit, hard_time_limit) {
195                (Some(t1), Some(t2)) => Some(std::cmp::min(t1, t2)),
196                (Some(t1), None) => Some(t1),
197                (None, Some(t2)) => Some(t2),
198                _ => None,
199            }
200        })
201    }
202
203    fn max_retries(&self) -> Option<u32> {
204        Self::DEFAULTS.max_retries.or(self.options().max_retries)
205    }
206
207    fn min_retry_delay(&self) -> u32 {
208        Self::DEFAULTS
209            .min_retry_delay
210            .or(self.options().min_retry_delay)
211            .unwrap_or(0)
212    }
213
214    fn max_retry_delay(&self) -> u32 {
215        Self::DEFAULTS
216            .max_retry_delay
217            .or(self.options().max_retry_delay)
218            .unwrap_or(3600)
219    }
220
221    fn acks_late(&self) -> bool {
222        Self::DEFAULTS
223            .acks_late
224            .or(self.options().acks_late)
225            .unwrap_or(false)
226    }
227}
228
229#[derive(Clone, Debug)]
230pub(crate) enum TaskEvent {
231    StatusChange(TaskStatus),
232}
233
/// Coarse status of a task, carried by task status-change events.
#[derive(Debug, Clone)]
pub(crate) enum TaskStatus {
    /// The task has not finished yet (presumably received but not yet done — TODO confirm).
    Pending,
    /// The task has finished.
    Finished,
}
239
240/// Extension methods for `Result` types within a task body.
241///
242/// These methods can be used to convert a `Result<T, E>` to a `Result<T, TaskError>` with the
243/// appropriate [`TaskError`] variant. The trait has a blanket implementation for any error type that implements
244/// [`std::error::Error`](https://doc.rust-lang.org/std/error/trait.Error.html).
245///
246/// # Examples
247///
248/// ```rust
249/// # use celery::prelude::*;
250/// fn do_some_io() -> Result<(), std::io::Error> {
251///     unimplemented!()
252/// }
253///
254/// #[celery::task]
255/// fn fallible_io_task() -> TaskResult<()> {
256///     do_some_io().with_expected_err(|| "IO error")?;
257///     Ok(())
258/// }
259/// ```
260pub trait TaskResultExt<T, E, F, C> {
261    /// Convert the error type to a [`TaskError::ExpectedError`].
262    fn with_expected_err(self, f: F) -> Result<T, TaskError>;
263
264    /// Convert the error type to a [`TaskError::UnexpectedError`].
265    fn with_unexpected_err(self, f: F) -> Result<T, TaskError>;
266}
267
268impl<T, E, F, C> TaskResultExt<T, E, F, C> for Result<T, E>
269where
270    E: std::error::Error,
271    F: FnOnce() -> C,
272    C: std::fmt::Display + Send + Sync + 'static,
273{
274    fn with_expected_err(self, f: F) -> Result<T, TaskError> {
275        self.map_err(|e| TaskError::ExpectedError(format!("{} ➥ Cause: {:?}", f(), e)))
276    }
277
278    fn with_unexpected_err(self, f: F) -> Result<T, TaskError> {
279        self.map_err(|e| TaskError::UnexpectedError(format!("{} ➥ Cause: {:?}", f(), e)))
280    }
281}