tourniquet_celery/
lib.rs

1//! [Tourniquet](https://docs.rs/tourniquet) integration with the [celery](https://docs.rs/celery)
2//! library.
3//!
4//! # Example
5//!
6//! ```rust,no_run
7//! # use celery::task::TaskResult;
8//! # use tourniquet::RoundRobin;
9//! # use tourniquet_celery::{CeleryConnector, RoundRobinExt};
10//! #
11//! #[celery::task]
12//! async fn do_work(work: String) -> TaskResult<()> {
13//!     // Some work
14//! # println!("{}", work);
15//!     Ok(())
16//! }
17//!
18//! # #[tokio::main]
19//! # async fn main() -> Result<(), Box<dyn std::error::Error>> {
20//! let rr = RoundRobin::new(
21//!     vec!["amqp://rabbit01:5672/".to_owned(), "amqp://rabbit02:5672".to_owned()],
22//!     CeleryConnector { name: "rr", routes: &[("*", "my_route")], ..Default::default() },
23//! );
24//!
25//! # let work = "foo".to_owned();
26//! rr.send_task(|| do_work::new(work.clone())).await.expect("Failed to send task");
27//! # Ok(())
28//! # }
29//! ```
30
31use std::error::Error;
32use std::fmt::{Debug, Display, Error as FmtError, Formatter};
33
34use async_trait::async_trait;
35use celery::{
36    error::BackendError::*,
37    error::BrokerError::BadRoutingPattern,
38    error::CeleryError::{self, *},
39    task::{AsyncResult, Signature, Task},
40    Celery, CeleryBuilder,
41};
42use tourniquet::{Connector, Next, RoundRobin};
43#[cfg(feature = "trace")]
44use tracing::{
45    field::{display, Empty},
46    instrument, Span,
47};
48
49/// Wrapper for [`CeleryError`](https://docs.rs/celery-rs/latest/celery/error/struct.CeleryError.html)
50/// that implements [`Next`](https://docs.rs/tourniquet/latest/tourniquet/trait.Next.html).
51pub struct RRCeleryError(CeleryError);
52
53impl Next for RRCeleryError {
54    fn is_next(&self) -> bool {
55        match self.0 {
56            BrokerError(BadRoutingPattern(_)) => false,
57            BrokerError(_) | IoError(_) | ProtocolError(_) => true,
58            BackendError(NotConfigured | Timeout | Redis(_)) => true,
59            BackendError(Serialization(_) | Pool(_) | PoolCreationError(_) | TaskFailed(_)) => {
60                false
61            }
62            NoQueueToConsume
63            | ForcedShutdown
64            | TaskRegistrationError(_)
65            | UnregisteredTaskError(_) => false,
66        }
67    }
68}
69
70impl From<CeleryError> for RRCeleryError {
71    fn from(e: CeleryError) -> Self {
72        Self(e)
73    }
74}
75
76impl From<RRCeleryError> for CeleryError {
77    fn from(e: RRCeleryError) -> Self {
78        e.0
79    }
80}
81
82impl Display for RRCeleryError {
83    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), FmtError> {
84        Display::fmt(&self.0, f)
85    }
86}
87
88impl Debug for RRCeleryError {
89    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), FmtError> {
90        Debug::fmt(&self.0, f)
91    }
92}
93
94impl Error for RRCeleryError {
95    fn source(&self) -> Option<&(dyn Error + 'static)> {
96        Some(&self.0)
97    }
98}
99
/// Ready-to-use [`Connector`] for Celery.
///
/// When the connector builds a client, each field is forwarded to the matching
/// [`CeleryBuilder`] option. See [`CeleryBuilder`]'s documentation for what each option does.
///
/// This is a basic connector for celery. It exposes only the options most commonly used by
/// celery *producers*, not consumers. Write your own [`Connector`] should you need
/// finer-grained control over the created celery client.
///
/// The `Default` value uses the app name `"celery"`, no default queue, no routes and no
/// connection timeout. Combine it with struct update syntax to override only what you need,
/// e.g. `CeleryConnector { name: "app", ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct CeleryConnector<'a> {
    /// Application name, passed as the first argument of [`CeleryBuilder::new`].
    pub name: &'a str,
    /// Queue passed to [`CeleryBuilder::default_queue`], when set.
    pub default_queue: Option<&'a str>,
    /// `(pattern, queue)` pairs, each registered in order with [`CeleryBuilder::task_route`].
    pub routes: &'a [(&'a str, &'a str)],
    /// Value passed to [`CeleryBuilder::broker_connection_timeout`], when set. Check that
    /// method's documentation for the unit (seconds in the celery versions this was written
    /// against — TODO confirm).
    pub connection_timeout: Option<u32>,
}
115
116impl<'a> Default for CeleryConnector<'a> {
117    fn default() -> Self {
118        Self { name: "celery", default_queue: None, routes: &[], connection_timeout: None }
119    }
120}
121
122#[async_trait]
123impl<'a> Connector<String, Celery, RRCeleryError> for CeleryConnector<'a> {
124    #[cfg_attr(feature = "trace", tracing::instrument(skip(self), err))]
125    async fn connect(&self, url: &String) -> Result<Celery, RRCeleryError> {
126        let mut builder = CeleryBuilder::new(self.name, url.as_ref());
127
128        if let Some(queue) = self.default_queue {
129            builder = builder.default_queue(queue);
130        }
131        for (pattern, queue) in self.routes {
132            builder = builder.task_route(pattern, queue);
133        }
134        if let Some(timeout) = self.connection_timeout {
135            builder = builder.broker_connection_timeout(timeout);
136        }
137
138        Ok(builder.build().await?)
139    }
140}
141
142#[async_trait]
143pub trait RoundRobinExt {
144    async fn send_task<T, F>(&self, task_gen: F) -> Result<AsyncResult, CeleryError>
145    where
146        T: Task + 'static,
147        F: Fn() -> Signature<T> + Send + Sync;
148}
149
150#[async_trait]
151impl<SvcSrc, Conn> RoundRobinExt for RoundRobin<SvcSrc, Celery, RRCeleryError, Conn>
152where
153    SvcSrc: Debug + Send + Sync,
154    Conn: Connector<SvcSrc, Celery, RRCeleryError> + Send + Sync,
155{
156    /// Send a Celery task.
157    ///
158    /// The `task_gen` argument returns a signature for each attempt, should each attempt hold a
159    /// different value (e.g. trace id, attempt id, timestamp, ...).
160    #[cfg_attr(
161        feature = "trace",
162        instrument(
163            fields(task_name = display(Signature::<T>::task_name()), task_id = Empty),
164            skip(self, task_gen),
165            err,
166        ),
167    )]
168    async fn send_task<T, F>(&self, task_gen: F) -> Result<AsyncResult, CeleryError>
169    where
170        T: Task + 'static,
171        F: Fn() -> Signature<T> + Send + Sync,
172    {
173        log::debug!("Sending task {}", Signature::<T>::task_name());
174
175        let task_gen = &task_gen;
176        let task =
177            self.run(|celery| async move { Ok(celery.send_task(task_gen()).await?) }).await?;
178
179        #[cfg(feature = "trace")]
180        Span::current().record("task_id", &display(&task.task_id));
181
182        Ok(task)
183    }
184}
185
186/// Shorthand type for a basic RoundRobin type using Celery
187pub type CeleryRoundRobin = RoundRobin<String, Celery, RRCeleryError, CeleryConnector<'static>>;