celery/app/mod.rs

use colored::Colorize;
use futures::stream::StreamExt;
use log::{debug, error, info, warn};
use std::collections::HashMap;
use std::convert::TryFrom;
use std::error::Error;
use std::sync::Arc;
use tokio::select;

#[cfg(unix)]
use tokio::signal::unix::{signal, Signal, SignalKind};

use tokio::sync::mpsc::{self, UnboundedSender};
use tokio::sync::RwLock;
use tokio::time::{self, Duration};
use tokio_stream::StreamMap;

mod trace;

use crate::backend::ResultBackend;
use crate::broker::{
    broker_builder_from_url, build_and_connect, configure_task_routes, Broker, BrokerBuilder,
    Delivery,
};
use crate::error::{BrokerError, CeleryError, TraceError};
use crate::protocol::{Message, MessageContentType};
use crate::routing::Rule;
use crate::task::{AsyncResult, Signature, Task, TaskEvent, TaskOptions, TaskStatus};
use trace::{build_tracer, TraceBuilder, TracerTrait};

struct Config {
    name: String,
    hostname: String,
    broker_builder: Box<dyn BrokerBuilder>,
    broker_connection_timeout: u32,
    broker_connection_retry: bool,
    broker_connection_max_retries: u32,
    broker_connection_retry_delay: u32,
    default_queue: String,
    task_options: TaskOptions,
    task_routes: Vec<(String, String)>,
    result_backend: Option<Arc<dyn ResultBackend>>,
}

/// Used to create a [`Celery`] app with a custom configuration.
pub struct CeleryBuilder {
    config: Config,
}

impl CeleryBuilder {
    /// Get a [`CeleryBuilder`] for creating a [`Celery`] app with a custom configuration.
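    ///
    /// A minimal sketch of what building an app looks like. The app name, broker URL, and
    /// the crate-root path `celery::CeleryBuilder` here are illustrative assumptions:
    ///
    /// ```rust,ignore
    /// let app = celery::CeleryBuilder::new("my_app", "amqp://127.0.0.1:5672//")
    ///     .default_queue("celery")
    ///     .build()
    ///     .await?;
    /// ```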
    pub fn new(name: &str, broker_url: &str) -> Self {
        Self {
            config: Config {
                name: name.into(),
                hostname: format!(
                    "{}@{}",
                    name,
                    hostname::get()
                        .ok()
                        .and_then(|sys_hostname| sys_hostname.into_string().ok())
                        .unwrap_or_else(|| "unknown".into())
                ),
                broker_builder: broker_builder_from_url(broker_url),
                broker_connection_timeout: 2,
                broker_connection_retry: true,
                broker_connection_max_retries: 5,
                broker_connection_retry_delay: 5,
                default_queue: "celery".into(),
                task_options: TaskOptions::default(),
                task_routes: vec![],
                result_backend: None,
            },
        }
    }

    /// Set the node name of the app. Defaults to `"{name}@{sys hostname}"`.
    ///
    /// *This field should probably be named "nodename" to avoid confusion with the
    /// system hostname, but we're trying to be consistent with Python Celery.*
    pub fn hostname(mut self, hostname: &str) -> Self {
        self.config.hostname = hostname.into();
        self
    }

    /// Set the name of the default queue to something other than "celery".
    pub fn default_queue(mut self, queue_name: &str) -> Self {
        self.config.default_queue = queue_name.into();
        self
    }

    /// Configure a result backend implementation for storing task results.
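    ///
    /// A sketch with a hypothetical backend type (`MyBackend` is a placeholder for any
    /// type implementing [`ResultBackend`]):
    ///
    /// ```rust,ignore
    /// let builder = builder.result_backend(MyBackend::new("redis://127.0.0.1:6379/"));
    /// ```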
    pub fn result_backend<B>(mut self, backend: B) -> Self
    where
        B: ResultBackend + 'static,
    {
        self.config.result_backend = Some(Arc::new(backend));
        self
    }

    /// Set the prefetch count. The default value depends on the broker implementation,
    /// but it's recommended that you always set this to a value that works best
    /// for your application.
    ///
    /// This may take some tuning, as it depends on a lot of factors, such
    /// as whether your tasks are IO bound (higher prefetch count is better) or CPU bound (lower
    /// prefetch count is better).
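    ///
    /// A hedged sketch of setting it on the builder (the value is illustrative only):
    ///
    /// ```rust,ignore
    /// // Mostly IO-bound tasks generally tolerate a higher prefetch count.
    /// let builder = builder.prefetch_count(100);
    /// ```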
    pub fn prefetch_count(mut self, prefetch_count: u16) -> Self {
        self.config.broker_builder = self.config.broker_builder.prefetch_count(prefetch_count);
        self
    }

    /// Set the broker heartbeat. The default value depends on the broker implementation.
    pub fn heartbeat(mut self, heartbeat: Option<u16>) -> Self {
        self.config.broker_builder = self.config.broker_builder.heartbeat(heartbeat);
        self
    }

    /// Set an app-level time limit for tasks (see [`TaskOptions::time_limit`]).
    pub fn task_time_limit(mut self, task_time_limit: u32) -> Self {
        self.config.task_options.time_limit = Some(task_time_limit);
        self
    }

    /// Set an app-level hard time limit for tasks (see [`TaskOptions::hard_time_limit`]).
    ///
    /// *Note that this is really only for compatibility with Python workers*.
    /// `time_limit` and `hard_time_limit` are treated the same by Rust workers, and if both
    /// are set, the minimum of the two will be used.
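    ///
    /// For example (a sketch; the comment notes the limit a Rust worker would enforce):
    ///
    /// ```rust,ignore
    /// let builder = builder
    ///     .task_time_limit(10)
    ///     .task_hard_time_limit(5); // A Rust worker enforces min(10, 5) = 5 seconds.
    /// ```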
    pub fn task_hard_time_limit(mut self, task_hard_time_limit: u32) -> Self {
        self.config.task_options.hard_time_limit = Some(task_hard_time_limit);
        self
    }

    /// Set an app-level maximum number of retries for tasks (see [`TaskOptions::max_retries`]).
    pub fn task_max_retries(mut self, task_max_retries: u32) -> Self {
        self.config.task_options.max_retries = Some(task_max_retries);
        self
    }

    /// Set an app-level minimum retry delay for tasks (see [`TaskOptions::min_retry_delay`]).
    pub fn task_min_retry_delay(mut self, task_min_retry_delay: u32) -> Self {
        self.config.task_options.min_retry_delay = Some(task_min_retry_delay);
        self
    }

    /// Set an app-level maximum retry delay for tasks (see [`TaskOptions::max_retry_delay`]).
    pub fn task_max_retry_delay(mut self, task_max_retry_delay: u32) -> Self {
        self.config.task_options.max_retry_delay = Some(task_max_retry_delay);
        self
    }

    /// Set whether tasks should be retried for `UnexpectedError`s by default (see
    /// [`TaskOptions::retry_for_unexpected`]).
    pub fn task_retry_for_unexpected(mut self, retry_for_unexpected: bool) -> Self {
        self.config.task_options.retry_for_unexpected = Some(retry_for_unexpected);
        self
    }

    /// Set whether by default a task is acknowledged before or after execution (see
    /// [`TaskOptions::acks_late`]).
    pub fn acks_late(mut self, acks_late: bool) -> Self {
        self.config.task_options.acks_late = Some(acks_late);
        self
    }

    /// Set the default serialization format for tasks (see [`TaskOptions::content_type`]).
    pub fn task_content_type(mut self, content_type: MessageContentType) -> Self {
        self.config.task_options.content_type = Some(content_type);
        self
    }

    /// Add a routing rule.
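    ///
    /// A hedged sketch (the patterns and queue names are placeholders; rules are consulted
    /// in the order they were added):
    ///
    /// ```rust,ignore
    /// let builder = builder
    ///     .task_route("backend.*", "backend_tasks")
    ///     .task_route("*", "celery");
    /// ```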
    pub fn task_route(mut self, pattern: &str, queue: &str) -> Self {
        self.config.task_routes.push((pattern.into(), queue.into()));
        self
    }

    /// Set a timeout in seconds before giving up on establishing a connection to the broker.
    pub fn broker_connection_timeout(mut self, timeout: u32) -> Self {
        self.config.broker_connection_timeout = timeout;
        self
    }

    /// Set whether or not to automatically try to re-establish the connection to the AMQP broker.
    pub fn broker_connection_retry(mut self, retry: bool) -> Self {
        self.config.broker_connection_retry = retry;
        self
    }

    /// Set the maximum number of retries before we give up trying to re-establish the
    /// connection to the AMQP broker.
    pub fn broker_connection_max_retries(mut self, max_retries: u32) -> Self {
        self.config.broker_connection_max_retries = max_retries;
        self
    }

    /// Set the number of seconds to wait before retrying the connection to the broker.
    pub fn broker_connection_retry_delay(mut self, retry_delay: u32) -> Self {
        self.config.broker_connection_retry_delay = retry_delay;
        self
    }

    /// Construct a [`Celery`] app with the current configuration.
    pub async fn build(self) -> Result<Celery, CeleryError> {
        // Declare the default queue to the broker.
        let broker_builder = self
            .config
            .broker_builder
            .declare_queue(&self.config.default_queue);

        let (broker_builder, task_routes) =
            configure_task_routes(broker_builder, &self.config.task_routes)?;

        let broker = build_and_connect(
            broker_builder,
            self.config.broker_connection_timeout,
            if self.config.broker_connection_retry {
                self.config.broker_connection_max_retries
            } else {
                0
            },
            self.config.broker_connection_retry_delay,
        )
        .await?;

        Ok(Celery {
            name: self.config.name,
            hostname: self.config.hostname,
            broker,
            default_queue: self.config.default_queue,
            task_options: self.config.task_options,
            task_routes,
            task_trace_builders: RwLock::new(HashMap::new()),
            broker_connection_timeout: self.config.broker_connection_timeout,
            broker_connection_retry: self.config.broker_connection_retry,
            broker_connection_max_retries: self.config.broker_connection_max_retries,
            broker_connection_retry_delay: self.config.broker_connection_retry_delay,
            result_backend: self.config.result_backend.clone(),
        })
    }
}

/// A [`Celery`] app is used to produce or consume tasks asynchronously. This is the struct that is
/// created with the [`app!`](crate::app!) macro.
pub struct Celery {
    /// An arbitrary, human-readable name for the app.
    pub name: String,

    /// Node name of the app.
    pub hostname: String,

    /// The app's broker.
    pub broker: Box<dyn Broker>,

    /// The default queue to send to and receive from.
    pub default_queue: String,

    /// Default task options.
    pub task_options: TaskOptions,

    /// A vector of routing rules in the order of their importance.
    task_routes: Vec<Rule>,

    /// Mapping of task name to task tracer factory. Used to create a task tracer
    /// from an incoming message.
    task_trace_builders: RwLock<HashMap<String, TraceBuilder>>,

    broker_connection_timeout: u32,
    broker_connection_retry: bool,
    broker_connection_max_retries: u32,
    broker_connection_retry_delay: u32,
    result_backend: Option<Arc<dyn ResultBackend>>,
}

impl Celery {
    /// Returns a clone of the configured result backend, if any.
    pub fn result_backend(&self) -> Option<Arc<dyn ResultBackend>> {
        self.result_backend.clone()
    }

    /// Print a pretty ASCII art logo and configuration settings.
    ///
    /// This is useful and fun to print from a worker application right after
    /// the [`Celery`] app is initialized.
    pub async fn display_pretty(&self) {
        // Cool ASCII logo with hostname.
        let banner = format!(
            r#"
  _________________          >_<
 /  ______________ \         | |
/  /              \_\  ,---. | | ,---. ,--.--.,--. ,--.
| /   .<      >.      | .-. :| || .-. :|  .--' \  '  /
| |   (        )      \   --.| |\   --.|  |     \   /
| |    --o--o--        `----'`-' `----'`--'   .-'  /
| |  _/        \_   __                         `--'
| | / \________/ \ / /
| \    |      |   / /
 \ \_____________/ /    {}
  \_______________/
"#,
            self.hostname
        );
        println!("{}", banner.truecolor(255, 102, 0));

        // Broker.
        println!("{}", "[broker]".bold());
        println!(" {}", self.broker.safe_url());
        println!();

        // Registered tasks.
        println!("{}", "[tasks]".bold());
        for task in self.task_trace_builders.read().await.keys() {
            println!(" . {task}");
        }
        println!();
    }

    /// Send a task to a remote worker. Returns an [`AsyncResult`] with the ID of the task
    /// if it was sent successfully.
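    ///
    /// A hedged sketch, assuming a task `add` defined elsewhere with the `#[celery::task]`
    /// attribute macro (which generates the `add::new` signature constructor):
    ///
    /// ```rust,ignore
    /// let async_result = app.send_task(add::new(1, 2)).await?;
    /// ```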
    pub async fn send_task<T: Task>(
        &self,
        mut task_sig: Signature<T>,
    ) -> Result<AsyncResult, CeleryError> {
        task_sig.options.update(&self.task_options);
        let maybe_queue = task_sig.queue.take();
        let queue = maybe_queue.as_deref().unwrap_or_else(|| {
            crate::routing::route(T::NAME, &self.task_routes).unwrap_or(&self.default_queue)
        });
        let message = Message::try_from(task_sig)?;
        info!(
            "Sending task {}[{}] to {}",
            T::NAME,
            message.task_id(),
            queue,
        );
        self.broker.send(&message, queue).await?;
        Ok(AsyncResult::with_backend(
            message.task_id(),
            self.result_backend(),
        ))
    }

    /// Register a task.
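    ///
    /// A sketch of registering a macro-generated task type (again assuming a task `add`
    /// defined with the `#[celery::task]` attribute):
    ///
    /// ```rust,ignore
    /// app.register_task::<add>().await?;
    /// ```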
    pub async fn register_task<T: Task + 'static>(&self) -> Result<(), CeleryError> {
        let mut task_trace_builders = self.task_trace_builders.write().await;
        if task_trace_builders.contains_key(T::NAME) {
            Err(CeleryError::TaskRegistrationError(T::NAME.into()))
        } else {
            task_trace_builders.insert(T::NAME.into(), Box::new(build_tracer::<T>));
            debug!("Registered task {}", T::NAME);
            Ok(())
        }
    }

    async fn get_task_tracer(
        self: &Arc<Self>,
        message: Message,
        event_tx: UnboundedSender<TaskEvent>,
    ) -> Result<Box<dyn TracerTrait>, Box<dyn Error + Send + Sync + 'static>> {
        let task_trace_builders = self.task_trace_builders.read().await;
        if let Some(build_tracer) = task_trace_builders.get(&message.headers.task) {
            Ok(build_tracer(
                self.clone(),
                message,
                self.task_options,
                event_tx,
                self.hostname.clone(),
            )
            .map_err(|e| Box::new(e) as Box<dyn Error + Send + Sync + 'static>)?)
        } else {
            Err(
                Box::new(CeleryError::UnregisteredTaskError(message.headers.task))
                    as Box<dyn Error + Send + Sync + 'static>,
            )
        }
    }

    /// Tries converting a delivery into a `Message`, executing the corresponding task,
    /// and communicating with the broker.
    async fn try_handle_delivery(
        self: &Arc<Self>,
        delivery: Box<dyn Delivery>,
        event_tx: UnboundedSender<TaskEvent>,
    ) -> Result<(), Box<dyn Error + Send + Sync + 'static>> {
        // Coerce the delivery into a protocol message.
        let message = match delivery.try_deserialize_message() {
            Ok(message) => message,
            Err(e) => {
                // This is a naughty message that we can't handle, so we'll ack it with
                // the broker so it gets deleted.
                self.broker
                    .ack(delivery.as_ref())
                    .await
                    .map_err(|e| Box::new(e) as Box<dyn Error + Send + Sync + 'static>)?;
                return Err(Box::new(e));
            }
        };

        // Try deserializing the message to create a task wrapped in a task tracer.
        // (The tracer handles all of the logic of directly interacting with the task
        // to execute it and run the post-execution functions.)
        let mut tracer = match self.get_task_tracer(message, event_tx).await {
            Ok(tracer) => tracer,
            Err(e) => {
                // Even though the message metadata was okay, we failed to deserialize
                // the body of the message for some reason, so ack it with the broker
                // to delete it and return an error.
                self.broker
                    .ack(delivery.as_ref())
                    .await
                    .map_err(|e| Box::new(e) as Box<dyn Error + Send + Sync + 'static>)?;
                return Err(e);
            }
        };

        if tracer.is_delayed() {
            // Task has an ETA, so we need to increase the prefetch count so that
            // we can receive other tasks while we wait for the ETA.
            if let Err(e) = self.broker.increase_prefetch_count().await {
                // If for some reason this operation fails, we should stop tracing
                // this task and send it back to the broker to retry.
                // Otherwise we could reach the prefetch_count and end up blocking
                // other deliveries if there are a high number of messages with a
                // future ETA.
                self.broker
                    .retry(delivery.as_ref(), None)
                    .await
                    .map_err(|e| Box::new(e) as Box<dyn Error + Send + Sync + 'static>)?;
                self.broker
                    .ack(delivery.as_ref())
                    .await
                    .map_err(|e| Box::new(e) as Box<dyn Error + Send + Sync + 'static>)?;
                return Err(Box::new(e));
            };

            // Then wait for the task to be ready.
            tracer.wait().await;
        }

        // If acks_late is false, we acknowledge the message before tracing it.
        if !tracer.acks_late() {
            self.broker
                .ack(delivery.as_ref())
                .await
                .map_err(|e| Box::new(e) as Box<dyn Error + Send + Sync + 'static>)?;
        }

        // Try tracing the task now.
        // NOTE: we don't need to log errors from the trace here since the tracer
        // handles all errors at its own level or the task level. In this function
        // we only log errors at the broker and delivery level.
        if let Err(TraceError::Retry(retry_eta)) = tracer.trace().await {
            // If we get a retry error, retry the task.
            self.broker
                .retry(delivery.as_ref(), retry_eta)
                .await
                .map_err(|e| Box::new(e) as Box<dyn Error + Send + Sync + 'static>)?;
        }

        // If we haven't done so already, we have to acknowledge the message now.
        if tracer.acks_late() {
            self.broker
                .ack(delivery.as_ref())
                .await
                .map_err(|e| Box::new(e) as Box<dyn Error + Send + Sync + 'static>)?;
        }

        // If we had increased the prefetch count above due to a future ETA, we have
        // to decrease it back down to restore balance to the universe.
        if tracer.is_delayed() {
            self.broker
                .decrease_prefetch_count()
                .await
                .map_err(|e| Box::new(e) as Box<dyn Error + Send + Sync + 'static>)?;
        }

        Ok(())
    }

    /// Wraps `try_handle_delivery` to catch any and all errors that might occur.
    async fn handle_delivery(
        self: Arc<Self>,
        delivery: Box<dyn Delivery>,
        event_tx: UnboundedSender<TaskEvent>,
    ) {
        if let Err(e) = self.try_handle_delivery(delivery, event_tx).await {
            error!("{}", e);
        }
    }

    /// Close channels and connections.
    pub async fn close(&self) -> Result<(), CeleryError> {
        Ok(self.broker.close().await?)
    }

    /// Consume tasks from the default queue.
    pub async fn consume(self: &Arc<Self>) -> Result<(), CeleryError> {
        let queues = &[&self.default_queue[..]];
        Self::consume_from(self, queues).await
    }

    /// Consume tasks from any number of queues.
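    ///
    /// A hedged sketch (the queue names are placeholders); this future only resolves once
    /// the worker is shut down by a SIGINT or SIGTERM:
    ///
    /// ```rust,ignore
    /// app.consume_from(&["celery", "emails"]).await?;
    /// ```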
    pub async fn consume_from(self: &Arc<Self>, queues: &[&str]) -> Result<(), CeleryError> {
        loop {
            let result = self.clone()._consume_from(queues).await;
            if !self.broker_connection_retry {
                return result;
            }

            if let Err(err) = result {
                match err {
                    CeleryError::BrokerError(broker_err) => {
                        if broker_err.is_connection_error() {
                            error!("Broker connection failed");
                        } else {
                            return Err(CeleryError::BrokerError(broker_err));
                        }
                    }
                    _ => return Err(err),
                };
            } else {
                return result;
            }

            let mut reconnect_successful: bool = false;
            for _ in 0..self.broker_connection_max_retries {
                info!("Trying to re-establish connection with broker");
                time::sleep(Duration::from_secs(
                    self.broker_connection_retry_delay as u64,
                ))
                .await;

                match self.broker.reconnect(self.broker_connection_timeout).await {
                    Err(err) => {
                        if err.is_connection_error() {
                            continue;
                        }
                        return Err(CeleryError::BrokerError(err));
                    }
                    Ok(_) => {
                        info!("Successfully reconnected with broker");
                        reconnect_successful = true;
                        break;
                    }
                };
            }

            if !reconnect_successful {
                return Err(CeleryError::BrokerError(BrokerError::NotConnected));
            }
        }
    }

    #[allow(clippy::cognitive_complexity)]
    async fn _consume_from(self: Arc<Self>, queues: &[&str]) -> Result<(), CeleryError> {
        if queues.is_empty() {
            return Err(CeleryError::NoQueueToConsume);
        }

        info!("Consuming from {:?}", queues);

        // Stream of errors from the broker. The capacity here is arbitrary because a single
        // error from the broker should trigger this method to return early.
        let (broker_error_tx, mut broker_error_rx) = mpsc::channel::<BrokerError>(100);

        // Stream of deliveries from the queue.
        let mut stream_map = StreamMap::new();
        let mut consumer_tags = vec![];
        for queue in queues {
            let broker_error_tx = broker_error_tx.clone();

            let (consumer_tag, consumer) = self
                .broker
                .consume(
                    queue,
                    Box::new(move |e| {
                        broker_error_tx.clone().try_send(e).ok();
                    }),
                )
                .await?;
            stream_map.insert(queue, consumer);
            consumer_tags.push(consumer_tag);
        }

        // Stream of OS signals.
        let mut ender = Ender::new()?;

        // A sender and receiver for task-related events.
        // NOTE: we can use an unbounded channel since we already have backpressure
        // from the `prefetch_count` setting.
        let (task_event_tx, mut task_event_rx) = mpsc::unbounded_channel::<TaskEvent>();
        let mut pending_tasks = 0;

        // This is the main loop where we receive deliveries and pass them off
        // to be handled by spawning `self.handle_delivery`.
        // At the same time we are also listening for a SIGINT (Ctrl+C) or SIGTERM interruption.
        // If that occurs we break from this loop and move to the warm shutdown loop
        // if there are still any pending tasks (tasks being executed, not including
        // tasks being delayed due to a future ETA).
        loop {
            select! {
                maybe_delivery_result = stream_map.next() => {
                    if let Some((queue, delivery_result)) = maybe_delivery_result {
                        match delivery_result {
                            Ok(delivery) => {
                                let task_event_tx = task_event_tx.clone();
                                debug!("Received delivery from {}: {:?}", queue, delivery);
                                tokio::spawn(self.clone().handle_delivery(delivery, task_event_tx));
                            }
                            Err(e) => {
                                error!("Delivery failed: {}", e);
                            }
                        }
                    }
                },
                ending = ender.wait() => {
                    if let Ok(SigType::Interrupt) = ending {
                        warn!("Ope! Hitting Ctrl+C again will terminate all running tasks!");
                    }
                    info!("Warm shutdown...");
                    break;
                },
                maybe_task_event = task_event_rx.recv() => {
                    if let Some(event) = maybe_task_event {
                        debug!("Received task event {:?}", event);
                        match event {
                            TaskEvent::StatusChange(TaskStatus::Pending) => pending_tasks += 1,
                            TaskEvent::StatusChange(TaskStatus::Finished) => pending_tasks -= 1,
                        };
                    }
                },
                maybe_broker_error = broker_error_rx.recv() => {
                    if let Some(broker_error) = maybe_broker_error {
                        error!("{}", broker_error);
                        return Err(broker_error.into());
                    }
                }
            };
        }

        // Cancel consumers.
        for consumer_tag in consumer_tags {
            debug!("Cancelling consumer {}", consumer_tag);
            self.broker.cancel(&consumer_tag).await?;
        }

        if pending_tasks > 0 {
            // Warm shutdown loop. When there are still pending tasks we wait for them
            // to finish. We get updates about pending tasks through the `task_event_rx` channel.
            // We also watch for a second SIGINT or SIGTERM, in which case we shut down immediately.
            info!("Waiting on {} pending tasks...", pending_tasks);
            loop {
                select! {
                    ending = ender.wait() => {
                        if let Ok(SigType::Interrupt) = ending {
                            warn!("Okay fine, shutting down now. See ya!");
                            return Err(CeleryError::ForcedShutdown);
                        }
                    },
                    maybe_event = task_event_rx.recv() => {
                        if let Some(event) = maybe_event {
                            debug!("Received task event {:?}", event);
                            match event {
                                TaskEvent::StatusChange(TaskStatus::Pending) => pending_tasks += 1,
                                TaskEvent::StatusChange(TaskStatus::Finished) => pending_tasks -= 1,
                            };
                            if pending_tasks <= 0 {
                                break;
                            }
                        }
                    },
                };
            }
        }

        info!("No more pending tasks. See ya!");

        Ok(())
    }
}

#[allow(unused)]
enum SigType {
    /// Equivalent to SIGINT on Unix systems.
    Interrupt,
    /// Equivalent to SIGTERM on Unix systems.
    Terminate,
}

/// The ender listens for OS signals that should trigger a shutdown.
#[cfg(unix)]
struct Ender {
    sigint: Signal,
    sigterm: Signal,
}

#[cfg(unix)]
impl Ender {
    fn new() -> Result<Self, std::io::Error> {
        let sigint = signal(SignalKind::interrupt())?;
        let sigterm = signal(SignalKind::terminate())?;

        Ok(Ender { sigint, sigterm })
    }

    /// Waits for either an interrupt or a terminate signal.
    async fn wait(&mut self) -> Result<SigType, std::io::Error> {
        let sigtype;

        select! {
            _ = self.sigint.recv() => {
                sigtype = SigType::Interrupt
            },
            _ = self.sigterm.recv() => {
                sigtype = SigType::Terminate
            }
        }

        Ok(sigtype)
    }
}

#[cfg(windows)]
struct Ender;

#[cfg(windows)]
impl Ender {
    fn new() -> Result<Self, std::io::Error> {
        Ok(Ender)
    }

    async fn wait(&mut self) -> Result<SigType, std::io::Error> {
        tokio::signal::ctrl_c().await?;

        Ok(SigType::Interrupt)
    }
}

#[cfg(test)]
mod tests;