celery/task/
mod.rs

1//! Provides the [`Task`] trait as well as options for configuring tasks.
2
3use async_trait::async_trait;
4use chrono::{DateTime, Utc};
5use rand::distributions::{Distribution, Uniform};
6use serde::{Deserialize, Serialize};
7use std::time::{SystemTime, UNIX_EPOCH};
8
9use crate::error::TaskError;
10
11mod async_result;
12mod options;
13mod request;
14mod signature;
15
16pub use async_result::AsyncResult;
17pub use options::TaskOptions;
18pub use request::Request;
19pub use signature::Signature;
20
21/// The return type for a task.
22pub type TaskResult<R> = Result<R, TaskError>;
23
24#[doc(hidden)]
25pub trait AsTaskResult {
26    type Returns: Send + Sync + std::fmt::Debug;
27}
28
29impl<R> AsTaskResult for TaskResult<R>
30where
31    R: Send + Sync + std::fmt::Debug,
32{
33    type Returns = R;
34}
35
36/// A `Task` represents a unit of work that a `Celery` app can produce or consume.
37///
38/// The recommended way to create tasks is through the [`task`](../attr.task.html) attribute macro, not by directly implementing
39/// this trait. For more information see the [tasks chapter](https://rusty-celery.github.io/guide/defining-tasks.html)
40/// in the Rusty Celery Book.
41#[async_trait]
42pub trait Task: Send + Sync + std::marker::Sized {
43    /// The name of the task. When a task is registered it will be registered with this name.
44    const NAME: &'static str;
45
46    /// For compatability with Python tasks. This keeps track of the order
47    /// of arguments for the task so that the task can be called from Python with
48    /// positional arguments.
49    const ARGS: &'static [&'static str];
50
51    /// Default task options.
52    const DEFAULTS: TaskOptions = TaskOptions {
53        time_limit: None,
54        hard_time_limit: None,
55        expires: None,
56        max_retries: None,
57        min_retry_delay: None,
58        max_retry_delay: None,
59        retry_for_unexpected: None,
60        acks_late: None,
61        content_type: None,
62    };
63
64    /// The parameters of the task.
65    type Params: Clone + Send + Sync + Serialize + for<'de> Deserialize<'de>;
66
67    /// The return type of the task.
68    type Returns: Send + Sync + std::fmt::Debug;
69
70    /// Used to initialize a task instance from a request.
71    fn from_request(request: Request<Self>, options: TaskOptions) -> Self;
72
73    /// Get a reference to the request used to create this task instance.
74    fn request(&self) -> &Request<Self>;
75
76    /// Get a reference to the task's configuration options.
77    ///
78    /// This is a product of both app-level task options and the options configured specifically
79    /// for the given task. Options specified at the *task*-level take priority over options
80    /// specified at the app level. So, if the task was defined like this:
81    ///
82    /// ```rust
83    /// # use celery::prelude::*;
84    /// #[celery::task(time_limit = 3)]
85    /// fn add(x: i32, y: i32) -> TaskResult<i32> {
86    ///     Ok(x + y)
87    /// }
88    /// ```
89    ///
90    /// But the `Celery` app was built with a `task_time_limit` of 5, then
91    /// `Task::options().time_limit` would be `Some(3)`.
92    fn options(&self) -> &TaskOptions;
93
94    /// This function defines how a task executes.
95    async fn run(&self, params: Self::Params) -> TaskResult<Self::Returns>;
96
97    /// Callback that will run after a task fails.
98    #[allow(unused_variables)]
99    async fn on_failure(&self, err: &TaskError) {}
100
101    /// Callback that will run after a task completes successfully.
102    #[allow(unused_variables)]
103    async fn on_success(&self, returned: &Self::Returns) {}
104
105    /// Returns the registered name of the task.
106    fn name(&self) -> &'static str {
107        Self::NAME
108    }
109
110    /// This can be called from within a task function to trigger a retry in `countdown` seconds.
111    fn retry_with_countdown(&self, countdown: u32) -> TaskResult<Self::Returns> {
112        let eta = match SystemTime::now().duration_since(UNIX_EPOCH) {
113            Ok(now) => {
114                let now_secs = now.as_secs() as u32;
115                let now_millis = now.subsec_millis();
116                let eta_secs = now_secs + countdown;
117                Some(DateTime::<Utc>::from_naive_utc_and_offset(
118                    DateTime::from_timestamp(eta_secs as i64, now_millis * 1000)
119                        .map(|dt| dt.naive_utc())
120                        .ok_or_else(|| {
121                            TaskError::UnexpectedError(format!(
122                                "Invalid countdown seconds {countdown}",
123                            ))
124                        })?,
125                    Utc,
126                ))
127            }
128            Err(_) => None,
129        };
130        Err(TaskError::Retry(eta))
131    }
132
133    /// This can be called from within a task function to trigger a retry at the specified `eta`.
134    fn retry_with_eta(&self, eta: DateTime<Utc>) -> TaskResult<Self::Returns> {
135        Err(TaskError::Retry(Some(eta)))
136    }
137
138    /// Get a future ETA at which time the task should be retried. By default this
139    /// uses a capped exponential backoff strategy.
140    fn retry_eta(&self) -> Option<DateTime<Utc>> {
141        let retries = self.request().retries;
142        let delay_secs = std::cmp::min(
143            2u32.checked_pow(retries)
144                .unwrap_or_else(|| self.max_retry_delay()),
145            self.max_retry_delay(),
146        );
147        let delay_secs = std::cmp::max(delay_secs, self.min_retry_delay());
148        let between = Uniform::from(0..1000);
149        let mut rng = rand::thread_rng();
150        let delay_millis = between.sample(&mut rng);
151        match SystemTime::now().duration_since(UNIX_EPOCH) {
152            Ok(now) => {
153                let now_secs = now.as_secs() as u32;
154                let now_millis = now.subsec_millis();
155                let eta_secs = now_secs + delay_secs;
156                let eta_millis = now_millis + delay_millis;
157                DateTime::from_timestamp(eta_secs as i64, eta_millis * 1000)
158                    .map(|dt| dt.naive_utc())
159                    .map(|eta| DateTime::<Utc>::from_naive_utc_and_offset(eta, Utc))
160            }
161            Err(_) => None,
162        }
163    }
164
165    fn retry_for_unexpected(&self) -> bool {
166        Self::DEFAULTS
167            .retry_for_unexpected
168            .or(self.options().retry_for_unexpected)
169            .unwrap_or(true)
170    }
171
172    fn time_limit(&self) -> Option<u32> {
173        self.request().time_limit.or_else(|| {
174            // Take min or `time_limit` and `hard_time_limit`.
175            let time_limit = Self::DEFAULTS.time_limit.or(self.options().time_limit);
176            let hard_time_limit = Self::DEFAULTS
177                .hard_time_limit
178                .or(self.options().hard_time_limit);
179            match (time_limit, hard_time_limit) {
180                (Some(t1), Some(t2)) => Some(std::cmp::min(t1, t2)),
181                (Some(t1), None) => Some(t1),
182                (None, Some(t2)) => Some(t2),
183                _ => None,
184            }
185        })
186    }
187
188    fn max_retries(&self) -> Option<u32> {
189        Self::DEFAULTS.max_retries.or(self.options().max_retries)
190    }
191
192    fn min_retry_delay(&self) -> u32 {
193        Self::DEFAULTS
194            .min_retry_delay
195            .or(self.options().min_retry_delay)
196            .unwrap_or(0)
197    }
198
199    fn max_retry_delay(&self) -> u32 {
200        Self::DEFAULTS
201            .max_retry_delay
202            .or(self.options().max_retry_delay)
203            .unwrap_or(3600)
204    }
205
206    fn acks_late(&self) -> bool {
207        Self::DEFAULTS
208            .acks_late
209            .or(self.options().acks_late)
210            .unwrap_or(false)
211    }
212}
213
/// An event concerning a task.
///
/// The only kind of event is currently a change of [`TaskStatus`].
#[derive(Debug, Clone)]
pub(crate) enum TaskEvent {
    /// The task's status changed to the contained value.
    StatusChange(TaskStatus),
}

/// Coarse task status carried by [`TaskEvent::StatusChange`].
#[derive(Debug, Clone)]
pub(crate) enum TaskStatus {
    /// The task is pending.
    Pending,
    /// The task is finished.
    Finished,
}
224
225/// Extension methods for `Result` types within a task body.
226///
227/// These methods can be used to convert a `Result<T, E>` to a `Result<T, TaskError>` with the
228/// appropriate [`TaskError`] variant. The trait has a blanket implementation for any error type that implements
229/// [`std::error::Error`](https://doc.rust-lang.org/std/error/trait.Error.html).
230///
231/// # Examples
232///
233/// ```rust
234/// # use celery::prelude::*;
235/// fn do_some_io() -> Result<(), std::io::Error> {
236///     unimplemented!()
237/// }
238///
239/// #[celery::task]
240/// fn fallible_io_task() -> TaskResult<()> {
241///     do_some_io().with_expected_err(|| "IO error")?;
242///     Ok(())
243/// }
244/// ```
245pub trait TaskResultExt<T, E, F, C> {
246    /// Convert the error type to a [`TaskError::ExpectedError`].
247    fn with_expected_err(self, f: F) -> Result<T, TaskError>;
248
249    /// Convert the error type to a [`TaskError::UnexpectedError`].
250    fn with_unexpected_err(self, f: F) -> Result<T, TaskError>;
251}
252
253impl<T, E, F, C> TaskResultExt<T, E, F, C> for Result<T, E>
254where
255    E: std::error::Error,
256    F: FnOnce() -> C,
257    C: std::fmt::Display + Send + Sync + 'static,
258{
259    fn with_expected_err(self, f: F) -> Result<T, TaskError> {
260        self.map_err(|e| TaskError::ExpectedError(format!("{} ➥ Cause: {:?}", f(), e)))
261    }
262
263    fn with_unexpected_err(self, f: F) -> Result<T, TaskError> {
264        self.map_err(|e| TaskError::UnexpectedError(format!("{} ➥ Cause: {:?}", f(), e)))
265    }
266}