1use colored::Colorize;
2use futures::stream::StreamExt;
3use log::{debug, error, info, warn};
4use std::collections::HashMap;
5use std::convert::TryFrom;
6use std::error::Error;
7use std::sync::Arc;
8use tokio::select;
9
10#[cfg(unix)]
11use tokio::signal::unix::{signal, Signal, SignalKind};
12
13use tokio::sync::mpsc::{self, UnboundedSender};
14use tokio::sync::RwLock;
15use tokio::time::{self, Duration};
16use tokio_stream::StreamMap;
17
18mod trace;
19
20use crate::backend::ResultBackend;
21use crate::broker::{
22 broker_builder_from_url, build_and_connect, configure_task_routes, Broker, BrokerBuilder,
23 Delivery,
24};
25use crate::error::{BrokerError, CeleryError, TraceError};
26use crate::protocol::{Message, MessageContentType};
27use crate::routing::Rule;
28use crate::task::{AsyncResult, Signature, Task, TaskEvent, TaskOptions, TaskStatus};
29use trace::{build_tracer, TraceBuilder, TracerTrait};
30
31struct Config {
32 name: String,
33 hostname: String,
34 broker_builder: Box<dyn BrokerBuilder>,
35 broker_connection_timeout: u32,
36 broker_connection_retry: bool,
37 broker_connection_max_retries: u32,
38 broker_connection_retry_delay: u32,
39 default_queue: String,
40 task_options: TaskOptions,
41 task_routes: Vec<(String, String)>,
42 result_backend: Option<Arc<dyn ResultBackend>>,
43}
44
45pub struct CeleryBuilder {
47 config: Config,
48}
49
50impl CeleryBuilder {
51 pub fn new(name: &str, broker_url: &str) -> Self {
53 Self {
54 config: Config {
55 name: name.into(),
56 hostname: format!(
57 "{}@{}",
58 name,
59 hostname::get()
60 .ok()
61 .and_then(|sys_hostname| sys_hostname.into_string().ok())
62 .unwrap_or_else(|| "unknown".into())
63 ),
64 broker_builder: broker_builder_from_url(broker_url),
65 broker_connection_timeout: 2,
66 broker_connection_retry: true,
67 broker_connection_max_retries: 5,
68 broker_connection_retry_delay: 5,
69 default_queue: "celery".into(),
70 task_options: TaskOptions::default(),
71 task_routes: vec![],
72 result_backend: None,
73 },
74 }
75 }
76
77 pub fn hostname(mut self, hostname: &str) -> Self {
82 self.config.hostname = hostname.into();
83 self
84 }
85
86 pub fn default_queue(mut self, queue_name: &str) -> Self {
88 self.config.default_queue = queue_name.into();
89 self
90 }
91
92 pub fn result_backend<B>(mut self, backend: B) -> Self
94 where
95 B: ResultBackend + 'static,
96 {
97 self.config.result_backend = Some(Arc::new(backend));
98 self
99 }
100
101 pub fn prefetch_count(mut self, prefetch_count: u16) -> Self {
109 self.config.broker_builder = self.config.broker_builder.prefetch_count(prefetch_count);
110 self
111 }
112
113 pub fn heartbeat(mut self, heartbeat: Option<u16>) -> Self {
115 self.config.broker_builder = self.config.broker_builder.heartbeat(heartbeat);
116 self
117 }
118
119 pub fn task_time_limit(mut self, task_time_limit: u32) -> Self {
121 self.config.task_options.time_limit = Some(task_time_limit);
122 self
123 }
124
125 pub fn task_hard_time_limit(mut self, task_hard_time_limit: u32) -> Self {
131 self.config.task_options.hard_time_limit = Some(task_hard_time_limit);
132 self
133 }
134
135 pub fn task_max_retries(mut self, task_max_retries: u32) -> Self {
137 self.config.task_options.max_retries = Some(task_max_retries);
138 self
139 }
140
141 pub fn task_min_retry_delay(mut self, task_min_retry_delay: u32) -> Self {
143 self.config.task_options.min_retry_delay = Some(task_min_retry_delay);
144 self
145 }
146
147 pub fn task_max_retry_delay(mut self, task_max_retry_delay: u32) -> Self {
149 self.config.task_options.max_retry_delay = Some(task_max_retry_delay);
150 self
151 }
152
153 pub fn task_retry_for_unexpected(mut self, retry_for_unexpected: bool) -> Self {
156 self.config.task_options.retry_for_unexpected = Some(retry_for_unexpected);
157 self
158 }
159
160 pub fn acks_late(mut self, acks_late: bool) -> Self {
163 self.config.task_options.acks_late = Some(acks_late);
164 self
165 }
166
167 pub fn task_content_type(mut self, content_type: MessageContentType) -> Self {
169 self.config.task_options.content_type = Some(content_type);
170 self
171 }
172
173 pub fn task_route(mut self, pattern: &str, queue: &str) -> Self {
175 self.config.task_routes.push((pattern.into(), queue.into()));
176 self
177 }
178
179 pub fn broker_connection_timeout(mut self, timeout: u32) -> Self {
181 self.config.broker_connection_timeout = timeout;
182 self
183 }
184
185 pub fn broker_connection_retry(mut self, retry: bool) -> Self {
187 self.config.broker_connection_retry = retry;
188 self
189 }
190
191 pub fn broker_connection_max_retries(mut self, max_retries: u32) -> Self {
194 self.config.broker_connection_max_retries = max_retries;
195 self
196 }
197
198 pub fn broker_connection_retry_delay(mut self, retry_delay: u32) -> Self {
200 self.config.broker_connection_retry_delay = retry_delay;
201 self
202 }
203
204 pub async fn build(self) -> Result<Celery, CeleryError> {
206 let broker_builder = self
208 .config
209 .broker_builder
210 .declare_queue(&self.config.default_queue);
211
212 let (broker_builder, task_routes) =
213 configure_task_routes(broker_builder, &self.config.task_routes)?;
214
215 let broker = build_and_connect(
216 broker_builder,
217 self.config.broker_connection_timeout,
218 if self.config.broker_connection_retry {
219 self.config.broker_connection_max_retries
220 } else {
221 0
222 },
223 self.config.broker_connection_retry_delay,
224 )
225 .await?;
226
227 Ok(Celery {
228 name: self.config.name,
229 hostname: self.config.hostname,
230 broker,
231 default_queue: self.config.default_queue,
232 task_options: self.config.task_options,
233 task_routes,
234 task_trace_builders: RwLock::new(HashMap::new()),
235 broker_connection_timeout: self.config.broker_connection_timeout,
236 broker_connection_retry: self.config.broker_connection_retry,
237 broker_connection_max_retries: self.config.broker_connection_max_retries,
238 broker_connection_retry_delay: self.config.broker_connection_retry_delay,
239 result_backend: self.config.result_backend.clone(),
240 })
241 }
242}
243
244pub struct Celery {
247 pub name: String,
249
250 pub hostname: String,
252
253 pub broker: Box<dyn Broker>,
255
256 pub default_queue: String,
258
259 pub task_options: TaskOptions,
261
262 task_routes: Vec<Rule>,
264
265 task_trace_builders: RwLock<HashMap<String, TraceBuilder>>,
268
269 broker_connection_timeout: u32,
270 broker_connection_retry: bool,
271 broker_connection_max_retries: u32,
272 broker_connection_retry_delay: u32,
273 result_backend: Option<Arc<dyn ResultBackend>>,
274}
275
276impl Celery {
277 pub fn result_backend(&self) -> Option<Arc<dyn ResultBackend>> {
279 self.result_backend.clone()
280 }
281
282 pub async fn display_pretty(&self) {
287 let banner = format!(
289 r#"
290 _________________ >_<
291 / ______________ \ | |
292/ / \_\ ,---. | | ,---. ,--.--.,--. ,--.
293| / .< >. | .-. :| || .-. :| .--' \ ' /
294| | ( ) \ --.| |\ --.| | \ /
295| | --o--o-- `----'`-' `----'`--' .-' /
296| | _/ \_ __ `--'
297| | / \________/ \ / /
298| \ | | / /
299 \ \_____________/ / {}
300 \_______________/
301"#,
302 self.hostname
303 );
304 println!("{}", banner.truecolor(255, 102, 0));
305
306 println!("{}", "[broker]".bold());
308 println!(" {}", self.broker.safe_url());
309 println!();
310
311 println!("{}", "[tasks]".bold());
313 for task in self.task_trace_builders.read().await.keys() {
314 println!(" . {task}");
315 }
316 println!();
317 }
318
319 pub async fn send_task<T: Task>(
322 &self,
323 mut task_sig: Signature<T>,
324 ) -> Result<AsyncResult, CeleryError> {
325 task_sig.options.update(&self.task_options);
326 let maybe_queue = task_sig.queue.take();
327 let queue = maybe_queue.as_deref().unwrap_or_else(|| {
328 crate::routing::route(T::NAME, &self.task_routes).unwrap_or(&self.default_queue)
329 });
330 let message = Message::try_from(task_sig)?;
331 info!(
332 "Sending task {}[{}] to {}",
333 T::NAME,
334 message.task_id(),
335 queue,
336 );
337 self.broker.send(&message, queue).await?;
338 Ok(AsyncResult::with_backend(
339 message.task_id(),
340 self.result_backend(),
341 ))
342 }
343
344 pub async fn register_task<T: Task + 'static>(&self) -> Result<(), CeleryError> {
346 let mut task_trace_builders = self.task_trace_builders.write().await;
347 if task_trace_builders.contains_key(T::NAME) {
348 Err(CeleryError::TaskRegistrationError(T::NAME.into()))
349 } else {
350 task_trace_builders.insert(T::NAME.into(), Box::new(build_tracer::<T>));
351 debug!("Registered task {}", T::NAME);
352 Ok(())
353 }
354 }
355
356 async fn get_task_tracer(
357 self: &Arc<Self>,
358 message: Message,
359 event_tx: UnboundedSender<TaskEvent>,
360 ) -> Result<Box<dyn TracerTrait>, Box<dyn Error + Send + Sync + 'static>> {
361 let task_trace_builders = self.task_trace_builders.read().await;
362 if let Some(build_tracer) = task_trace_builders.get(&message.headers.task) {
363 Ok(build_tracer(
364 self.clone(),
365 message,
366 self.task_options,
367 event_tx,
368 self.hostname.clone(),
369 )
370 .map_err(|e| Box::new(e) as Box<dyn Error + Send + Sync + 'static>)?)
371 } else {
372 Err(
373 Box::new(CeleryError::UnregisteredTaskError(message.headers.task))
374 as Box<dyn Error + Send + Sync + 'static>,
375 )
376 }
377 }
378
379 async fn try_handle_delivery(
382 self: &Arc<Self>,
383 delivery: Box<dyn Delivery>,
384 event_tx: UnboundedSender<TaskEvent>,
385 ) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
386 let message = match delivery.try_deserialize_message() {
388 Ok(message) => message,
389 Err(e) => {
390 self.broker
393 .ack(delivery.as_ref())
394 .await
395 .map_err(|e| Box::new(e) as Box<dyn Error + Send + Sync + 'static>)?;
396 return Err(Box::new(e));
397 }
398 };
399
400 let mut tracer = match self.get_task_tracer(message, event_tx).await {
404 Ok(tracer) => tracer,
405 Err(e) => {
406 self.broker
410 .ack(delivery.as_ref())
411 .await
412 .map_err(|e| Box::new(e) as Box<dyn Error + Send + Sync + 'static>)?;
413 return Err(e);
414 }
415 };
416
417 if tracer.is_delayed() {
418 if let Err(e) = self.broker.increase_prefetch_count().await {
421 self.broker
427 .retry(delivery.as_ref(), None)
428 .await
429 .map_err(|e| Box::new(e) as Box<dyn Error + Send + Sync + 'static>)?;
430 self.broker
431 .ack(delivery.as_ref())
432 .await
433 .map_err(|e| Box::new(e) as Box<dyn Error + Send + Sync + 'static>)?;
434 return Err(Box::new(e));
435 };
436
437 tracer.wait().await;
439 }
440
441 if !tracer.acks_late() {
443 self.broker
444 .ack(delivery.as_ref())
445 .await
446 .map_err(|e| Box::new(e) as Box<dyn Error + Send + Sync + 'static>)?;
447 }
448
449 if let Err(TraceError::Retry(retry_eta)) = tracer.trace().await {
454 self.broker
456 .retry(delivery.as_ref(), retry_eta)
457 .await
458 .map_err(|e| Box::new(e) as Box<dyn Error + Send + Sync + 'static>)?;
459 }
460
461 if tracer.acks_late() {
463 self.broker
464 .ack(delivery.as_ref())
465 .await
466 .map_err(|e| Box::new(e) as Box<dyn Error + Send + Sync + 'static>)?;
467 }
468
469 if tracer.is_delayed() {
472 self.broker
473 .decrease_prefetch_count()
474 .await
475 .map_err(|e| Box::new(e) as Box<dyn Error + Send + Sync + 'static>)?;
476 }
477
478 Ok(())
479 }
480
481 async fn handle_delivery(
483 self: Arc<Self>,
484 delivery: Box<dyn Delivery>,
485 event_tx: UnboundedSender<TaskEvent>,
486 ) {
487 if let Err(e) = self.try_handle_delivery(delivery, event_tx).await {
488 error!("{}", e);
489 }
490 }
491
492 pub async fn close(&self) -> Result<(), CeleryError> {
494 Ok(self.broker.close().await?)
495 }
496
497 pub async fn consume(self: &Arc<Self>) -> Result<(), CeleryError> {
499 let queues = &[&self.default_queue.clone()[..]];
500 Self::consume_from(self, queues).await
501 }
502
503 pub async fn consume_from(self: &Arc<Self>, queues: &[&str]) -> Result<(), CeleryError> {
505 loop {
506 let result = self.clone()._consume_from(queues).await;
507 if !self.broker_connection_retry {
508 return result;
509 }
510
511 if let Err(err) = result {
512 match err {
513 CeleryError::BrokerError(broker_err) => {
514 if broker_err.is_connection_error() {
515 error!("Broker connection failed");
516 } else {
517 return Err(CeleryError::BrokerError(broker_err));
518 }
519 }
520 _ => return Err(err),
521 };
522 } else {
523 return result;
524 }
525
526 let mut reconnect_successful: bool = false;
527 for _ in 0..self.broker_connection_max_retries {
528 info!("Trying to re-establish connection with broker");
529 time::sleep(Duration::from_secs(
530 self.broker_connection_retry_delay as u64,
531 ))
532 .await;
533
534 match self.broker.reconnect(self.broker_connection_timeout).await {
535 Err(err) => {
536 if err.is_connection_error() {
537 continue;
538 }
539 return Err(CeleryError::BrokerError(err));
540 }
541 Ok(_) => {
542 info!("Successfully reconnected with broker");
543 reconnect_successful = true;
544 break;
545 }
546 };
547 }
548
549 if !reconnect_successful {
550 return Err(CeleryError::BrokerError(BrokerError::NotConnected));
551 }
552 }
553 }
554
555 #[allow(clippy::cognitive_complexity)]
556 async fn _consume_from(self: Arc<Self>, queues: &[&str]) -> Result<(), CeleryError> {
557 if queues.is_empty() {
558 return Err(CeleryError::NoQueueToConsume);
559 }
560
561 info!("Consuming from {:?}", queues);
562
563 let (broker_error_tx, mut broker_error_rx) = mpsc::channel::<BrokerError>(100);
566
567 let mut stream_map = StreamMap::new();
569 let mut consumer_tags = vec![];
570 for queue in queues {
571 let broker_error_tx = broker_error_tx.clone();
572
573 let (consumer_tag, consumer) = self
574 .broker
575 .consume(
576 queue,
577 Box::new(move |e| {
578 broker_error_tx.clone().try_send(e).ok();
579 }),
580 )
581 .await?;
582 stream_map.insert(queue, consumer);
583 consumer_tags.push(consumer_tag);
584 }
585
586 let mut ender = Ender::new()?;
588
589 let (task_event_tx, mut task_event_rx) = mpsc::unbounded_channel::<TaskEvent>();
593 let mut pending_tasks = 0;
594
595 loop {
602 select! {
603 maybe_delivery_result = stream_map.next() => {
604 if let Some((queue, delivery_result)) = maybe_delivery_result {
605 match delivery_result {
606 Ok(delivery) => {
607 let task_event_tx = task_event_tx.clone();
608 debug!("Received delivery from {}: {:?}", queue, delivery);
609 tokio::spawn(self.clone().handle_delivery(delivery, task_event_tx));
610 }
611 Err(e) => {
612 error!("Deliver failed: {}", e);
613 }
614 }
615 }
616 },
617 ending = ender.wait() => {
618 if let Ok(SigType::Interrupt) = ending {
619 warn!("Ope! Hitting Ctrl+C again will terminate all running tasks!");
620 }
621 info!("Warm shutdown...");
622 break;
623 },
624 maybe_task_event = task_event_rx.recv() => {
625 if let Some(event) = maybe_task_event {
626 debug!("Received task event {:?}", event);
627 match event {
628 TaskEvent::StatusChange(TaskStatus::Pending) => pending_tasks += 1,
629 TaskEvent::StatusChange(TaskStatus::Finished) => pending_tasks -= 1,
630 };
631 }
632 },
633 maybe_broker_error = broker_error_rx.recv() => {
634 if let Some(broker_error) = maybe_broker_error {
635 error!("{}", broker_error);
636 return Err(broker_error.into());
637 }
638 }
639 };
640 }
641
642 for consumer_tag in consumer_tags {
644 debug!("Cancelling consumer {}", consumer_tag);
645 self.broker.cancel(&consumer_tag).await?;
646 }
647
648 if pending_tasks > 0 {
649 info!("Waiting on {} pending tasks...", pending_tasks);
653 loop {
654 select! {
655 ending = ender.wait() => {
656 if let Ok(SigType::Interrupt) = ending {
657 warn!("Okay fine, shutting down now. See ya!");
658 return Err(CeleryError::ForcedShutdown);
659 }
660 },
661 maybe_event = task_event_rx.recv() => {
662 if let Some(event) = maybe_event {
663 debug!("Received task event {:?}", event);
664 match event {
665 TaskEvent::StatusChange(TaskStatus::Pending) => pending_tasks += 1,
666 TaskEvent::StatusChange(TaskStatus::Finished) => pending_tasks -= 1,
667 };
668 if pending_tasks <= 0 {
669 break;
670 }
671 }
672 },
673 };
674 }
675 }
676
677 info!("No more pending tasks. See ya!");
678
679 Ok(())
680 }
681}
682
/// Kind of shutdown signal reported by `Ender::wait`.
// Only the Unix `Ender` ever produces `Terminate`; on other targets the variant is
// never constructed, so the dead-code lint is silenced there and nowhere else.
#[cfg_attr(not(unix), allow(dead_code))]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SigType {
    /// SIGINT on Unix, Ctrl+C on Windows.
    Interrupt,
    /// SIGTERM (Unix only).
    Terminate,
}
690
691#[cfg(unix)]
693struct Ender {
694 sigint: Signal,
695 sigterm: Signal,
696}
697
698#[cfg(unix)]
699impl Ender {
700 fn new() -> Result<Self, std::io::Error> {
701 let sigint = signal(SignalKind::interrupt())?;
702 let sigterm = signal(SignalKind::terminate())?;
703
704 Ok(Ender { sigint, sigterm })
705 }
706
707 async fn wait(&mut self) -> Result<SigType, std::io::Error> {
709 let sigtype;
710
711 select! {
712 _ = self.sigint.recv() => {
713 sigtype = SigType::Interrupt
714 },
715 _ = self.sigterm.recv() => {
716 sigtype = SigType::Terminate
717 }
718 }
719
720 Ok(sigtype)
721 }
722}
723
/// Windows shutdown listener; only Ctrl+C is observed.
#[cfg(windows)]
struct Ender;

#[cfg(windows)]
impl Ender {
    /// Creates the listener; nothing needs registering up front on Windows.
    fn new() -> Result<Self, std::io::Error> {
        Ok(Self)
    }

    /// Resolves on the next Ctrl+C, which is always reported as `SigType::Interrupt`.
    ///
    /// # Errors
    /// Propagates the I/O error from `tokio::signal::ctrl_c`.
    async fn wait(&mut self) -> Result<SigType, std::io::Error> {
        tokio::signal::ctrl_c()
            .await
            .map(|()| SigType::Interrupt)
    }
}
739
740#[cfg(test)]
741mod tests;