1use colored::Colorize;
2use futures::stream::StreamExt;
3use log::{debug, error, info, warn};
4use std::collections::HashMap;
5use std::convert::TryFrom;
6use std::error::Error;
7use std::sync::Arc;
8use tokio::select;
9
10#[cfg(unix)]
11use tokio::signal::unix::{signal, Signal, SignalKind};
12
13use tokio::sync::mpsc::{self, UnboundedSender};
14use tokio::sync::RwLock;
15use tokio::time::{self, Duration};
16use tokio_stream::StreamMap;
17
18mod trace;
19
20use crate::broker::{
21 broker_builder_from_url, build_and_connect, configure_task_routes, Broker, BrokerBuilder,
22 Delivery,
23};
24use crate::error::{BrokerError, CeleryError, TraceError};
25use crate::protocol::{Message, MessageContentType};
26use crate::routing::Rule;
27use crate::task::{AsyncResult, Signature, Task, TaskEvent, TaskOptions, TaskStatus};
28use trace::{build_tracer, TraceBuilder, TracerTrait};
29
30struct Config {
31 name: String,
32 hostname: String,
33 broker_builder: Box<dyn BrokerBuilder>,
34 broker_connection_timeout: u32,
35 broker_connection_retry: bool,
36 broker_connection_max_retries: u32,
37 broker_connection_retry_delay: u32,
38 default_queue: String,
39 task_options: TaskOptions,
40 task_routes: Vec<(String, String)>,
41}
42
43pub struct CeleryBuilder {
45 config: Config,
46}
47
48impl CeleryBuilder {
49 pub fn new(name: &str, broker_url: &str) -> Self {
51 Self {
52 config: Config {
53 name: name.into(),
54 hostname: format!(
55 "{}@{}",
56 name,
57 hostname::get()
58 .ok()
59 .and_then(|sys_hostname| sys_hostname.into_string().ok())
60 .unwrap_or_else(|| "unknown".into())
61 ),
62 broker_builder: broker_builder_from_url(broker_url),
63 broker_connection_timeout: 2,
64 broker_connection_retry: true,
65 broker_connection_max_retries: 5,
66 broker_connection_retry_delay: 5,
67 default_queue: "celery".into(),
68 task_options: TaskOptions::default(),
69 task_routes: vec![],
70 },
71 }
72 }
73
74 pub fn hostname(mut self, hostname: &str) -> Self {
79 self.config.hostname = hostname.into();
80 self
81 }
82
83 pub fn default_queue(mut self, queue_name: &str) -> Self {
85 self.config.default_queue = queue_name.into();
86 self
87 }
88
89 pub fn prefetch_count(mut self, prefetch_count: u16) -> Self {
97 self.config.broker_builder = self.config.broker_builder.prefetch_count(prefetch_count);
98 self
99 }
100
101 pub fn heartbeat(mut self, heartbeat: Option<u16>) -> Self {
103 self.config.broker_builder = self.config.broker_builder.heartbeat(heartbeat);
104 self
105 }
106
107 pub fn task_time_limit(mut self, task_time_limit: u32) -> Self {
109 self.config.task_options.time_limit = Some(task_time_limit);
110 self
111 }
112
113 pub fn task_hard_time_limit(mut self, task_hard_time_limit: u32) -> Self {
119 self.config.task_options.hard_time_limit = Some(task_hard_time_limit);
120 self
121 }
122
123 pub fn task_max_retries(mut self, task_max_retries: u32) -> Self {
125 self.config.task_options.max_retries = Some(task_max_retries);
126 self
127 }
128
129 pub fn task_min_retry_delay(mut self, task_min_retry_delay: u32) -> Self {
131 self.config.task_options.min_retry_delay = Some(task_min_retry_delay);
132 self
133 }
134
135 pub fn task_max_retry_delay(mut self, task_max_retry_delay: u32) -> Self {
137 self.config.task_options.max_retry_delay = Some(task_max_retry_delay);
138 self
139 }
140
141 pub fn task_retry_for_unexpected(mut self, retry_for_unexpected: bool) -> Self {
144 self.config.task_options.retry_for_unexpected = Some(retry_for_unexpected);
145 self
146 }
147
148 pub fn acks_late(mut self, acks_late: bool) -> Self {
151 self.config.task_options.acks_late = Some(acks_late);
152 self
153 }
154
155 pub fn task_content_type(mut self, content_type: MessageContentType) -> Self {
157 self.config.task_options.content_type = Some(content_type);
158 self
159 }
160
161 pub fn task_route(mut self, pattern: &str, queue: &str) -> Self {
163 self.config.task_routes.push((pattern.into(), queue.into()));
164 self
165 }
166
167 pub fn broker_connection_timeout(mut self, timeout: u32) -> Self {
169 self.config.broker_connection_timeout = timeout;
170 self
171 }
172
173 pub fn broker_connection_retry(mut self, retry: bool) -> Self {
175 self.config.broker_connection_retry = retry;
176 self
177 }
178
179 pub fn broker_connection_max_retries(mut self, max_retries: u32) -> Self {
182 self.config.broker_connection_max_retries = max_retries;
183 self
184 }
185
186 pub fn broker_connection_retry_delay(mut self, retry_delay: u32) -> Self {
188 self.config.broker_connection_retry_delay = retry_delay;
189 self
190 }
191
192 pub async fn build(self) -> Result<Celery, CeleryError> {
194 let broker_builder = self
196 .config
197 .broker_builder
198 .declare_queue(&self.config.default_queue);
199
200 let (broker_builder, task_routes) =
201 configure_task_routes(broker_builder, &self.config.task_routes)?;
202
203 let broker = build_and_connect(
204 broker_builder,
205 self.config.broker_connection_timeout,
206 if self.config.broker_connection_retry {
207 self.config.broker_connection_max_retries
208 } else {
209 0
210 },
211 self.config.broker_connection_retry_delay,
212 )
213 .await?;
214
215 Ok(Celery {
216 name: self.config.name,
217 hostname: self.config.hostname,
218 broker,
219 default_queue: self.config.default_queue,
220 task_options: self.config.task_options,
221 task_routes,
222 task_trace_builders: RwLock::new(HashMap::new()),
223 broker_connection_timeout: self.config.broker_connection_timeout,
224 broker_connection_retry: self.config.broker_connection_retry,
225 broker_connection_max_retries: self.config.broker_connection_max_retries,
226 broker_connection_retry_delay: self.config.broker_connection_retry_delay,
227 })
228 }
229}
230
231pub struct Celery {
234 pub name: String,
236
237 pub hostname: String,
239
240 pub broker: Box<dyn Broker>,
242
243 pub default_queue: String,
245
246 pub task_options: TaskOptions,
248
249 task_routes: Vec<Rule>,
251
252 task_trace_builders: RwLock<HashMap<String, TraceBuilder>>,
255
256 broker_connection_timeout: u32,
257 broker_connection_retry: bool,
258 broker_connection_max_retries: u32,
259 broker_connection_retry_delay: u32,
260}
261
262impl Celery {
263 pub async fn display_pretty(&self) {
268 let banner = format!(
270 r#"
271 _________________ >_<
272 / ______________ \ | |
273/ / \_\ ,---. | | ,---. ,--.--.,--. ,--.
274| / .< >. | .-. :| || .-. :| .--' \ ' /
275| | ( ) \ --.| |\ --.| | \ /
276| | --o--o-- `----'`-' `----'`--' .-' /
277| | _/ \_ __ `--'
278| | / \________/ \ / /
279| \ | | / /
280 \ \_____________/ / {}
281 \_______________/
282"#,
283 self.hostname
284 );
285 println!("{}", banner.truecolor(255, 102, 0));
286
287 println!("{}", "[broker]".bold());
289 println!(" {}", self.broker.safe_url());
290 println!();
291
292 println!("{}", "[tasks]".bold());
294 for task in self.task_trace_builders.read().await.keys() {
295 println!(" . {task}");
296 }
297 println!();
298 }
299
300 pub async fn send_task<T: Task>(
303 &self,
304 mut task_sig: Signature<T>,
305 ) -> Result<AsyncResult, CeleryError> {
306 task_sig.options.update(&self.task_options);
307 let maybe_queue = task_sig.queue.take();
308 let queue = maybe_queue.as_deref().unwrap_or_else(|| {
309 crate::routing::route(T::NAME, &self.task_routes).unwrap_or(&self.default_queue)
310 });
311 let message = Message::try_from(task_sig)?;
312 info!(
313 "Sending task {}[{}] to {}",
314 T::NAME,
315 message.task_id(),
316 queue,
317 );
318 self.broker.send(&message, queue).await?;
319 Ok(AsyncResult::new(message.task_id()))
320 }
321
322 pub async fn register_task<T: Task + 'static>(&self) -> Result<(), CeleryError> {
324 let mut task_trace_builders = self.task_trace_builders.write().await;
325 if task_trace_builders.contains_key(T::NAME) {
326 Err(CeleryError::TaskRegistrationError(T::NAME.into()))
327 } else {
328 task_trace_builders.insert(T::NAME.into(), Box::new(build_tracer::<T>));
329 debug!("Registered task {}", T::NAME);
330 Ok(())
331 }
332 }
333
334 async fn get_task_tracer(
335 self: &Arc<Self>,
336 message: Message,
337 event_tx: UnboundedSender<TaskEvent>,
338 ) -> Result<Box<dyn TracerTrait>, Box<dyn Error + Send + Sync + 'static>> {
339 let task_trace_builders = self.task_trace_builders.read().await;
340 if let Some(build_tracer) = task_trace_builders.get(&message.headers.task) {
341 Ok(build_tracer(
342 self.clone(),
343 message,
344 self.task_options,
345 event_tx,
346 self.hostname.clone(),
347 )
348 .map_err(|e| Box::new(e) as Box<dyn Error + Send + Sync + 'static>)?)
349 } else {
350 Err(
351 Box::new(CeleryError::UnregisteredTaskError(message.headers.task))
352 as Box<dyn Error + Send + Sync + 'static>,
353 )
354 }
355 }
356
357 async fn try_handle_delivery(
360 self: &Arc<Self>,
361 delivery: Box<dyn Delivery>,
362 event_tx: UnboundedSender<TaskEvent>,
363 ) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
364 let message = match delivery.try_deserialize_message() {
366 Ok(message) => message,
367 Err(e) => {
368 self.broker
371 .ack(delivery.as_ref())
372 .await
373 .map_err(|e| Box::new(e) as Box<dyn Error + Send + Sync + 'static>)?;
374 return Err(Box::new(e));
375 }
376 };
377
378 let mut tracer = match self.get_task_tracer(message, event_tx).await {
382 Ok(tracer) => tracer,
383 Err(e) => {
384 self.broker
388 .ack(delivery.as_ref())
389 .await
390 .map_err(|e| Box::new(e) as Box<dyn Error + Send + Sync + 'static>)?;
391 return Err(e);
392 }
393 };
394
395 if tracer.is_delayed() {
396 if let Err(e) = self.broker.increase_prefetch_count().await {
399 self.broker
405 .retry(delivery.as_ref(), None)
406 .await
407 .map_err(|e| Box::new(e) as Box<dyn Error + Send + Sync + 'static>)?;
408 self.broker
409 .ack(delivery.as_ref())
410 .await
411 .map_err(|e| Box::new(e) as Box<dyn Error + Send + Sync + 'static>)?;
412 return Err(Box::new(e));
413 };
414
415 tracer.wait().await;
417 }
418
419 if !tracer.acks_late() {
421 self.broker
422 .ack(delivery.as_ref())
423 .await
424 .map_err(|e| Box::new(e) as Box<dyn Error + Send + Sync + 'static>)?;
425 }
426
427 if let Err(TraceError::Retry(retry_eta)) = tracer.trace().await {
432 self.broker
434 .retry(delivery.as_ref(), retry_eta)
435 .await
436 .map_err(|e| Box::new(e) as Box<dyn Error + Send + Sync + 'static>)?;
437 }
438
439 if tracer.acks_late() {
441 self.broker
442 .ack(delivery.as_ref())
443 .await
444 .map_err(|e| Box::new(e) as Box<dyn Error + Send + Sync + 'static>)?;
445 }
446
447 if tracer.is_delayed() {
450 self.broker
451 .decrease_prefetch_count()
452 .await
453 .map_err(|e| Box::new(e) as Box<dyn Error + Send + Sync + 'static>)?;
454 }
455
456 Ok(())
457 }
458
459 async fn handle_delivery(
461 self: Arc<Self>,
462 delivery: Box<dyn Delivery>,
463 event_tx: UnboundedSender<TaskEvent>,
464 ) {
465 if let Err(e) = self.try_handle_delivery(delivery, event_tx).await {
466 error!("{}", e);
467 }
468 }
469
470 pub async fn close(&self) -> Result<(), CeleryError> {
472 Ok(self.broker.close().await?)
473 }
474
475 pub async fn consume(self: &Arc<Self>) -> Result<(), CeleryError> {
477 let queues = &[&self.default_queue.clone()[..]];
478 Self::consume_from(self, queues).await
479 }
480
481 pub async fn consume_from(self: &Arc<Self>, queues: &[&str]) -> Result<(), CeleryError> {
483 loop {
484 let result = self.clone()._consume_from(queues).await;
485 if !self.broker_connection_retry {
486 return result;
487 }
488
489 if let Err(err) = result {
490 match err {
491 CeleryError::BrokerError(broker_err) => {
492 if broker_err.is_connection_error() {
493 error!("Broker connection failed");
494 } else {
495 return Err(CeleryError::BrokerError(broker_err));
496 }
497 }
498 _ => return Err(err),
499 };
500 } else {
501 return result;
502 }
503
504 let mut reconnect_successful: bool = false;
505 for _ in 0..self.broker_connection_max_retries {
506 info!("Trying to re-establish connection with broker");
507 time::sleep(Duration::from_secs(
508 self.broker_connection_retry_delay as u64,
509 ))
510 .await;
511
512 match self.broker.reconnect(self.broker_connection_timeout).await {
513 Err(err) => {
514 if err.is_connection_error() {
515 continue;
516 }
517 return Err(CeleryError::BrokerError(err));
518 }
519 Ok(_) => {
520 info!("Successfully reconnected with broker");
521 reconnect_successful = true;
522 break;
523 }
524 };
525 }
526
527 if !reconnect_successful {
528 return Err(CeleryError::BrokerError(BrokerError::NotConnected));
529 }
530 }
531 }
532
533 #[allow(clippy::cognitive_complexity)]
534 async fn _consume_from(self: Arc<Self>, queues: &[&str]) -> Result<(), CeleryError> {
535 if queues.is_empty() {
536 return Err(CeleryError::NoQueueToConsume);
537 }
538
539 info!("Consuming from {:?}", queues);
540
541 let (broker_error_tx, mut broker_error_rx) = mpsc::channel::<BrokerError>(100);
544
545 let mut stream_map = StreamMap::new();
547 let mut consumer_tags = vec![];
548 for queue in queues {
549 let broker_error_tx = broker_error_tx.clone();
550
551 let (consumer_tag, consumer) = self
552 .broker
553 .consume(
554 queue,
555 Box::new(move |e| {
556 broker_error_tx.clone().try_send(e).ok();
557 }),
558 )
559 .await?;
560 stream_map.insert(queue, consumer);
561 consumer_tags.push(consumer_tag);
562 }
563
564 let mut ender = Ender::new()?;
566
567 let (task_event_tx, mut task_event_rx) = mpsc::unbounded_channel::<TaskEvent>();
571 let mut pending_tasks = 0;
572
573 loop {
580 select! {
581 maybe_delivery_result = stream_map.next() => {
582 if let Some((queue, delivery_result)) = maybe_delivery_result {
583 match delivery_result {
584 Ok(delivery) => {
585 let task_event_tx = task_event_tx.clone();
586 debug!("Received delivery from {}: {:?}", queue, delivery);
587 tokio::spawn(self.clone().handle_delivery(delivery, task_event_tx));
588 }
589 Err(e) => {
590 error!("Deliver failed: {}", e);
591 }
592 }
593 }
594 },
595 ending = ender.wait() => {
596 if let Ok(SigType::Interrupt) = ending {
597 warn!("Ope! Hitting Ctrl+C again will terminate all running tasks!");
598 }
599 info!("Warm shutdown...");
600 break;
601 },
602 maybe_task_event = task_event_rx.recv() => {
603 if let Some(event) = maybe_task_event {
604 debug!("Received task event {:?}", event);
605 match event {
606 TaskEvent::StatusChange(TaskStatus::Pending) => pending_tasks += 1,
607 TaskEvent::StatusChange(TaskStatus::Finished) => pending_tasks -= 1,
608 };
609 }
610 },
611 maybe_broker_error = broker_error_rx.recv() => {
612 if let Some(broker_error) = maybe_broker_error {
613 error!("{}", broker_error);
614 return Err(broker_error.into());
615 }
616 }
617 };
618 }
619
620 for consumer_tag in consumer_tags {
622 debug!("Cancelling consumer {}", consumer_tag);
623 self.broker.cancel(&consumer_tag).await?;
624 }
625
626 if pending_tasks > 0 {
627 info!("Waiting on {} pending tasks...", pending_tasks);
631 loop {
632 select! {
633 ending = ender.wait() => {
634 if let Ok(SigType::Interrupt) = ending {
635 warn!("Okay fine, shutting down now. See ya!");
636 return Err(CeleryError::ForcedShutdown);
637 }
638 },
639 maybe_event = task_event_rx.recv() => {
640 if let Some(event) = maybe_event {
641 debug!("Received task event {:?}", event);
642 match event {
643 TaskEvent::StatusChange(TaskStatus::Pending) => pending_tasks += 1,
644 TaskEvent::StatusChange(TaskStatus::Finished) => pending_tasks -= 1,
645 };
646 if pending_tasks <= 0 {
647 break;
648 }
649 }
650 },
651 };
652 }
653 }
654
655 info!("No more pending tasks. See ya!");
656
657 Ok(())
658 }
659}
660
/// Which shutdown request `Ender::wait` observed.
///
/// Only the Unix `Ender` ever produces `Terminate`. The Windows one only
/// listens for Ctrl+C, so the dead-code allowance is scoped to non-Unix targets
/// rather than silencing every `unused` lint everywhere.
#[cfg_attr(not(unix), allow(dead_code))]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SigType {
    /// SIGINT on Unix, Ctrl+C on Windows.
    Interrupt,
    /// SIGTERM (Unix only).
    Terminate,
}
668
669#[cfg(unix)]
671struct Ender {
672 sigint: Signal,
673 sigterm: Signal,
674}
675
676#[cfg(unix)]
677impl Ender {
678 fn new() -> Result<Self, std::io::Error> {
679 let sigint = signal(SignalKind::interrupt())?;
680 let sigterm = signal(SignalKind::terminate())?;
681
682 Ok(Ender { sigint, sigterm })
683 }
684
685 async fn wait(&mut self) -> Result<SigType, std::io::Error> {
687 let sigtype;
688
689 select! {
690 _ = self.sigint.recv() => {
691 sigtype = SigType::Interrupt
692 },
693 _ = self.sigterm.recv() => {
694 sigtype = SigType::Terminate
695 }
696 }
697
698 Ok(sigtype)
699 }
700}
701
/// Waits for a console shutdown request (Ctrl+C) on Windows.
#[cfg(windows)]
struct Ender;

#[cfg(windows)]
impl Ender {
    /// Create the waiter. It cannot fail; the `Result` matches the Unix `Ender::new`.
    fn new() -> Result<Self, std::io::Error> {
        Ok(Self)
    }

    /// Resolve with `SigType::Interrupt` once Ctrl+C is received.
    ///
    /// # Errors
    /// Propagates the error from `tokio::signal::ctrl_c`.
    async fn wait(&mut self) -> Result<SigType, std::io::Error> {
        tokio::signal::ctrl_c()
            .await
            .map(|()| SigType::Interrupt)
    }
}
717
718#[cfg(test)]
719mod tests;