1use std::{
5 error::Error,
6 sync::atomic::{AtomicBool, Ordering},
7 time::Duration,
8};
9
10use futures::{FutureExt, SinkExt, StreamExt, future::pending};
11use log::{Level, log};
12use tokio::sync::{
13 mpsc,
14 watch::{self, error::SendError},
15};
16
17use crate::event::{SentStatus, StatusReceiver, StatusSender, TaskStatus};
18
19#[cfg(not(target_arch = "wasm32"))]
20use tokio::time::{sleep, timeout};
21
22#[cfg(target_arch = "wasm32")]
23use wasmtimer::tokio::{sleep, timeout};
24
/// Shutdown timer, in seconds, used by `TaskManager::default()`.
const DEFAULT_SHUTDOWN_TIMER_SECS: u64 = 5;

/// Boxed, thread-safe error value passed from tasks back to the manager.
pub(crate) type SentError = Box<dyn Error + Send + Sync>;
/// Sending half of an unbounded channel carrying `SentError` values.
type ErrorSender = mpsc::UnboundedSender<SentError>;
/// Receiving half of an unbounded channel carrying `SentError` values.
type ErrorReceiver = mpsc::UnboundedReceiver<SentError>;
30
/// Returns an owned copy of `name`, or `"unknown"` when no name is set.
///
/// Takes `&Option<String>` (rather than `Option<&str>`) because it is invoked from the
/// `#[error(...)]` attribute of `TaskError`, which hands over a reference to the field.
fn try_recover_name(name: &Option<String>) -> String {
    name.as_deref().unwrap_or("unknown").to_owned()
}
38
/// Errors produced by this module itself, as opposed to errors reported by tasks.
#[derive(thiserror::Error, Debug)]
enum TaskError {
    /// Sent on the drop channel when a `TaskClient` is dropped before it has
    /// observed a shutdown signal. `shutdown_name` is the name of that client, if any.
    #[error("Task '{}' halted unexpectedly", try_recover_name(.shutdown_name))]
    UnexpectedHalt { shutdown_name: Option<String> },
}
44
/// Owns the shutdown-notification, error and status channels whose other ends are
/// handed to every `TaskClient` created through `subscribe`.
#[deprecated(note = "use ShutdownManager instead")]
#[derive(Debug)]
pub struct TaskManager {
    /// Optional name; used as the prefix of the names given to subscribed clients.
    name: Option<String>,

    /// Sender used by `signal_shutdown` to notify every subscribed client.
    notify_tx: watch::Sender<()>,
    /// Receiver that is cloned for each new client. Kept in an `Option` so the manager
    /// can drop its own copy before awaiting `notify_tx.closed()` in the wait methods.
    notify_rx: Option<watch::Receiver<()>>,
    /// Number of seconds `wait_for_shutdown` waits before it stops waiting for tasks.
    #[cfg_attr(target_arch = "wasm32", allow(dead_code))]
    shutdown_timer_secs: u64,

    /// Cloned into each client; clients use it in `send_we_stopped` to report an error.
    task_return_error_tx: ErrorSender,
    /// Receiving end for reported errors; `None` once it has been taken by a waiter.
    task_return_error_rx: Option<ErrorReceiver>,

    /// Cloned into each client; a client dropped before seeing shutdown reports on it.
    task_drop_tx: ErrorSender,
    /// Receiving end for drop notifications; `None` once it has been taken by a waiter.
    task_drop_rx: Option<ErrorReceiver>,

    /// Cloned into each client; clients use it in `send_status_msg`.
    task_status_tx: StatusSender,
    /// Receiving end for status messages; taken by `start_status_listener`.
    task_status_rx: Option<StatusReceiver>,
}
75
76#[allow(deprecated)]
77impl Default for TaskManager {
78 fn default() -> Self {
79 let (notify_tx, notify_rx) = watch::channel(());
80 let (task_halt_tx, task_halt_rx) = mpsc::unbounded_channel();
81 let (task_drop_tx, task_drop_rx) = mpsc::unbounded_channel();
82 let (task_status_tx, task_status_rx) = futures::channel::mpsc::channel(128);
85 Self {
86 name: None,
87 notify_tx,
88 notify_rx: Some(notify_rx),
89 shutdown_timer_secs: DEFAULT_SHUTDOWN_TIMER_SECS,
90 task_return_error_tx: task_halt_tx,
91 task_return_error_rx: Some(task_halt_rx),
92 task_drop_tx,
93 task_drop_rx: Some(task_drop_rx),
94 task_status_tx,
95 task_status_rx: Some(task_status_rx),
96 }
97 }
98}
99
100#[allow(deprecated)]
101#[allow(clippy::expect_used)]
102impl TaskManager {
103 pub fn new(shutdown_timer_secs: u64) -> Self {
104 Self {
105 shutdown_timer_secs,
106 ..Default::default()
107 }
108 }
109
110 #[must_use]
111 pub fn named<S: Into<String>>(mut self, name: S) -> Self {
112 self.name = Some(name.into());
113 self
114 }
115
116 #[cfg(not(target_arch = "wasm32"))]
117 pub async fn catch_interrupt(&mut self) -> Result<(), SentError> {
118 let res = crate::wait_for_signal_and_error(self).await;
119
120 log::info!("Sending shutdown");
121 self.signal_shutdown().ok();
122
123 log::info!("Waiting for tasks to finish... (Press ctrl-c to force)");
124 self.wait_for_shutdown().await;
125
126 res
127 }
128
129 pub fn subscribe(&self) -> TaskClient {
130 let task_client = TaskClient::new(
131 self.notify_rx
132 .as_ref()
133 .expect("Unable to subscribe to shutdown notifier that is already shutdown")
134 .clone(),
135 self.task_return_error_tx.clone(),
136 self.task_drop_tx.clone(),
137 self.task_status_tx.clone(),
138 );
139
140 if let Some(name) = &self.name {
141 task_client.named(format!("{name}-child"))
142 } else {
143 task_client
144 }
145 }
146
147 pub fn subscribe_named<S: Into<String>>(&self, suffix: S) -> TaskClient {
148 let task_client = self.subscribe();
149 let suffix = suffix.into();
150 let child_name = if let Some(base) = &self.name {
151 format!("{base}-{suffix}")
152 } else {
153 format!("unknown-{suffix}")
154 };
155 task_client.named(child_name)
156 }
157
158 pub fn signal_shutdown(&self) -> Result<(), SendError<()>> {
159 self.notify_tx.send(())
160 }
161
162 pub async fn start_status_listener(
163 &mut self,
164 mut sender: StatusSender,
165 start_status: TaskStatus,
166 ) {
167 if let Err(msg) = sender.send(Box::new(start_status)).await {
170 log::error!("Error sending status message: {msg}");
171 };
172
173 if let Some(mut task_status_rx) = self.task_status_rx.take() {
174 log::info!("Starting status message listener");
175 crate::spawn::spawn_future(async move {
176 loop {
177 if let Some(msg) = task_status_rx.next().await {
178 log::trace!("Got msg: {msg}");
179 if let Err(msg) = sender.send(msg).await {
180 log::error!("Error sending status message: {msg}");
181 }
182 } else {
183 log::trace!("Stopping since channel closed");
184 break;
185 }
186 }
187 log::debug!("Status listener: Exiting");
188 });
189 }
190 }
191
192 #[cfg(not(target_arch = "wasm32"))]
194 pub(crate) fn task_return_error_rx(&mut self) -> ErrorReceiver {
195 self.task_return_error_rx
196 .take()
197 .expect("unable to get error channel: attempt to wait twice?")
198 }
199
200 #[cfg(not(target_arch = "wasm32"))]
201 pub(crate) fn task_drop_rx(&mut self) -> ErrorReceiver {
202 self.task_drop_rx
203 .take()
204 .expect("unable to get task drop channel: attempt to wait twice?")
205 }
206
207 pub async fn wait_for_error(&mut self) -> Option<SentError> {
208 let mut error_rx = self
209 .task_return_error_rx
210 .take()
211 .expect("Unable to wait for error: attempt to wait twice?");
212 let mut drop_rx = self
213 .task_drop_rx
214 .take()
215 .expect("Unable to wait for error: attempt to wait twice?");
216
217 let drop_rx = drop_rx.recv().then(|msg| async move {
220 sleep(Duration::from_millis(50)).await;
221 msg
222 });
223
224 tokio::select! {
225 msg = error_rx.recv() => msg,
226 msg = drop_rx => msg
227 }
228 }
229
230 pub async fn wait_for_graceful_shutdown(&mut self) {
231 if let Some(notify_rx) = self.notify_rx.take() {
232 drop(notify_rx);
233 }
234 self.notify_tx.closed().await
235 }
236
237 pub async fn wait_for_shutdown(&mut self) {
238 log::debug!("Waiting for shutdown");
239 if let Some(notify_rx) = self.notify_rx.take() {
240 drop(notify_rx);
241 }
242
243 #[cfg(not(target_arch = "wasm32"))]
244 let interrupt_future = tokio::signal::ctrl_c();
245
246 #[cfg(target_arch = "wasm32")]
247 let interrupt_future = futures::future::pending::<()>();
248
249 let wait_future = sleep(Duration::from_secs(self.shutdown_timer_secs));
250
251 tokio::select! {
252 _ = self.notify_tx.closed() => {
253 log::info!("All registered tasks succesfully shutdown");
254 },
255 _ = interrupt_future => {
256 log::info!("Forcing shutdown");
257 }
258 _ = wait_future => {
259 log::info!("Timeout reached, forcing shutdown");
260 },
261 }
262 }
263}
264
/// Per-task handle that listens for the shutdown notification and can report
/// errors and status messages over the channels it was created with.
#[derive(Debug)]
#[deprecated(note = "use ShutdownToken instead")]
pub struct TaskClient {
    /// Optional name, used in log output and in the error sent on unexpected drop.
    name: Option<String>,

    /// Set to `true` once this client has observed the shutdown notification
    /// (see `recv` and `is_shutdown_poll`).
    shutdown: AtomicBool,

    /// Receiver on which the shutdown notification arrives.
    notify: watch::Receiver<()>,

    /// Used by `send_we_stopped` to report an error.
    return_error: ErrorSender,

    /// Used in `Drop` to report that the client was dropped before seeing shutdown.
    drop_error: ErrorSender,

    /// Used by `send_status_msg` to send status messages.
    status_msg: StatusSender,

    /// Controls whether the client is a dummy and whether it reports on drop.
    mode: ClientOperatingMode,
}
295
296#[allow(deprecated)]
297impl Clone for TaskClient {
298 fn clone(&self) -> Self {
299 let name = if let Some(name) = &self.name {
301 if name != Self::OVERFLOW_NAME && name.len() < Self::MAX_NAME_LENGTH {
302 Some(format!("{name}-child"))
303 } else {
304 Some(Self::OVERFLOW_NAME.to_string())
305 }
306 } else {
307 None
308 };
309
310 log::debug!("Cloned task client: {name:?}");
311
312 TaskClient {
313 name,
314 shutdown: AtomicBool::new(self.shutdown.load(Ordering::Relaxed)),
315 notify: self.notify.clone(),
316 return_error: self.return_error.clone(),
317 drop_error: self.drop_error.clone(),
318 status_msg: self.status_msg.clone(),
319 mode: self.mode.clone(),
320 }
321 }
322}
323
324#[allow(deprecated)]
325impl TaskClient {
326 const MAX_NAME_LENGTH: usize = 128;
327 const OVERFLOW_NAME: &'static str = "reached maximum TaskClient children name depth";
328
329 const SHUTDOWN_TIMEOUT_WAITING_FOR_SIGNAL_ON_EXIT: Duration = Duration::from_secs(10);
330
331 fn new(
332 notify: watch::Receiver<()>,
333 return_error: ErrorSender,
334 drop_error: ErrorSender,
335 status_msg: StatusSender,
336 ) -> TaskClient {
337 TaskClient {
338 name: None,
339 shutdown: AtomicBool::new(false),
340 notify,
341 return_error,
342 drop_error,
343 status_msg,
344 mode: ClientOperatingMode::Listening,
345 }
346 }
347
348 pub fn fork<S: Into<String>>(&self, child_suffix: S) -> Self {
350 let mut child = self.clone();
351 let suffix = child_suffix.into();
352 let child_name = if let Some(base) = &self.name {
353 format!("{base}-{suffix}")
354 } else {
355 format!("unknown-{suffix}")
356 };
357
358 log::debug!("Forked task client: {child_name}");
359 child.name = Some(child_name);
360 child
361 }
362
363 fn log<S: Into<String>>(&self, level: Level, msg: S) {
367 let msg = msg.into();
368
369 let target = &if let Some(name) = &self.name {
370 format!("TaskClient-{name}")
371 } else {
372 "unnamed-TaskClient".to_string()
373 };
374
375 log!(target: target, level, "{}", format_args!("[{target}] {msg}"))
376 }
377
378 #[must_use]
379 pub fn named<S: Into<String>>(mut self, name: S) -> Self {
380 self.name = Some(name.into());
381 self
382 }
383
384 #[must_use]
385 pub fn with_suffix<S: Into<String>>(self, suffix: S) -> Self {
386 let suffix = suffix.into();
387 let name = if let Some(base) = &self.name {
388 format!("{base}-{suffix}")
389 } else {
390 format!("unknown-{suffix}")
391 };
392 log::debug!("Renamed task client: {name}");
393 self.named(name)
394 }
395
396 pub fn dummy() -> TaskClient {
398 let (_notify_tx, notify_rx) = watch::channel(());
399 let (task_halt_tx, _task_halt_rx) = mpsc::unbounded_channel();
400 let (task_drop_tx, _task_drop_rx) = mpsc::unbounded_channel();
401 let (task_status_tx, _task_status_rx) = futures::channel::mpsc::channel(128);
402 TaskClient {
403 name: None,
404 shutdown: AtomicBool::new(false),
405 notify: notify_rx,
406 return_error: task_halt_tx,
407 drop_error: task_drop_tx,
408 status_msg: task_status_tx,
409 mode: ClientOperatingMode::Dummy,
410 }
411 }
412
413 pub fn is_dummy(&self) -> bool {
414 self.mode.is_dummy()
415 }
416
417 pub fn is_shutdown(&self) -> bool {
418 if self.mode.is_dummy() {
419 false
420 } else {
421 self.shutdown.load(Ordering::Relaxed)
422 }
423 }
424
425 pub async fn recv(&mut self) {
426 if self.mode.is_dummy() {
427 return pending().await;
428 }
429 if self.shutdown.load(Ordering::Relaxed) {
430 return;
431 }
432 let _ = self.notify.changed().await;
433 self.shutdown.store(true, Ordering::Relaxed);
434 }
435
436 pub async fn recv_with_delay(&mut self) {
437 self.recv()
438 .then(|msg| async move {
439 sleep(Duration::from_secs(2)).await;
440 msg
441 })
442 .await
443 }
444
445 #[allow(clippy::panic)]
447 pub async fn recv_timeout(&mut self) {
448 if self.mode.is_dummy() {
449 return pending().await;
450 }
451
452 if let Err(timeout) = timeout(
453 Self::SHUTDOWN_TIMEOUT_WAITING_FOR_SIGNAL_ON_EXIT,
454 self.recv(),
455 )
456 .await
457 {
458 self.log(Level::Error, "Task stopped without shutdown called");
459 panic!("{:?}: {timeout}", self.name)
460 }
461 }
462
463 pub fn is_shutdown_poll(&self) -> bool {
464 if self.mode.is_dummy() {
465 return false;
466 }
467 if self.shutdown.load(Ordering::Relaxed) {
468 return true;
469 }
470 match self.notify.has_changed() {
471 Ok(has_changed) => {
472 if has_changed {
473 self.shutdown.store(true, Ordering::Relaxed);
474 }
475 has_changed
476 }
477 Err(err) => {
478 self.log(Level::Error, format!("Polling shutdown failed: {err}"));
479 self.log(Level::Error, "Assuming this means we should shutdown...");
480
481 true
482 }
483 }
484 }
485
486 pub fn disarm(&mut self) {
490 self.mode.set_should_not_signal_on_drop();
491 }
492
493 pub fn rearm(&mut self) {
494 self.mode.set_should_signal_on_drop();
495 }
496
497 pub fn send_we_stopped(&mut self, err: SentError) {
498 if self.mode.is_dummy() {
499 return;
500 }
501
502 self.log(Level::Trace, format!("Notifying we stopped: {err}"));
503
504 if self.return_error.send(err).is_err() {
505 self.log(Level::Error, "failed to send back error message");
506 }
507 }
508
509 pub fn send_status_msg(&mut self, msg: SentStatus) {
510 if self.mode.is_dummy() {
511 return;
512 }
513 self.status_msg.try_send(msg).ok();
516 }
517}
518
519#[allow(deprecated)]
520impl Drop for TaskClient {
521 fn drop(&mut self) {
522 if !self.mode.should_signal_on_drop() {
523 self.log(
524 Level::Trace,
525 "the task client is getting dropped but instructed to not signal: this is expected during client shutdown",
526 );
527 return;
528 } else {
529 self.log(
530 Level::Debug,
531 "the task client is getting dropped: this is expected during client shutdown",
532 );
533 }
534
535 if !self.is_shutdown_poll() {
536 self.log(Level::Trace, "Notifying stop on unexpected drop");
537
538 self.drop_error
540 .send(Box::new(TaskError::UnexpectedHalt {
541 shutdown_name: self.name.clone(),
542 }))
543 .ok();
544 }
545 }
546}
547
/// Operating mode of a task client: whether it is a dummy, and whether it reports
/// on drop.
#[derive(Clone, Debug, PartialEq, Eq)]
enum ClientOperatingMode {
    /// Regular mode: the client signals when it is dropped.
    Listening,
    /// Regular mode, except that the client does not signal when it is dropped.
    ListeningButDontReportHalt,
    /// Dummy mode: never signals on drop and cannot be switched to another mode.
    Dummy,
}

impl ClientOperatingMode {
    /// Returns `true` only for `Dummy`.
    fn is_dummy(&self) -> bool {
        matches!(self, Self::Dummy)
    }

    /// Returns `true` only for `Listening`.
    fn should_signal_on_drop(&self) -> bool {
        match self {
            Self::Listening => true,
            Self::Dummy | Self::ListeningButDontReportHalt => false,
        }
    }

    /// Switches to `Listening`; a `Dummy` mode is left untouched.
    fn set_should_signal_on_drop(&mut self) {
        if !self.is_dummy() {
            *self = Self::Listening;
        }
    }

    /// Switches to `ListeningButDontReportHalt`; a `Dummy` mode is left untouched.
    fn set_should_not_signal_on_drop(&mut self) {
        if !self.is_dummy() {
            *self = Self::ListeningButDontReportHalt;
        }
    }
}
586
/// Either an owned `TaskManager` or a `TaskClient`, behind one common interface for
/// naming and for obtaining further clients.
#[deprecated]
#[allow(deprecated)]
#[derive(Debug)]
pub enum TaskHandle {
    /// Wraps a `TaskManager` owned by this handle.
    Internal(TaskManager),

    /// Wraps a `TaskClient`; presumably the manager it belongs to lives elsewhere.
    External(TaskClient),
}
597
598#[allow(deprecated)]
599impl From<TaskManager> for TaskHandle {
600 fn from(value: TaskManager) -> Self {
601 TaskHandle::Internal(value)
602 }
603}
604
605#[allow(deprecated)]
606impl From<TaskClient> for TaskHandle {
607 fn from(value: TaskClient) -> Self {
608 TaskHandle::External(value)
609 }
610}
611
612#[allow(deprecated)]
613impl Default for TaskHandle {
614 fn default() -> Self {
615 TaskHandle::Internal(TaskManager::default())
616 }
617}
618
619#[allow(deprecated)]
620impl TaskHandle {
621 #[must_use]
622 pub fn name_if_unnamed<S: Into<String>>(self, name: S) -> Self {
623 match self {
624 TaskHandle::Internal(task_manager) => {
625 if task_manager.name.is_none() {
626 TaskHandle::Internal(task_manager.named(name))
627 } else {
628 TaskHandle::Internal(task_manager)
629 }
630 }
631 TaskHandle::External(task_client) => {
632 if task_client.name.is_none() {
633 TaskHandle::External(task_client.named(name))
634 } else {
635 TaskHandle::External(task_client)
636 }
637 }
638 }
639 }
640
641 #[must_use]
642 pub fn named<S: Into<String>>(self, name: S) -> Self {
643 match self {
644 TaskHandle::Internal(task_manager) => TaskHandle::Internal(task_manager.named(name)),
645 TaskHandle::External(task_client) => TaskHandle::External(task_client.named(name)),
646 }
647 }
648
649 pub fn fork<S: Into<String>>(&self, child_suffix: S) -> TaskClient {
650 match self {
651 TaskHandle::External(shutdown) => shutdown.fork(child_suffix),
652 TaskHandle::Internal(shutdown) => shutdown.subscribe_named(child_suffix),
653 }
654 }
655
656 pub fn get_handle(&self) -> TaskClient {
657 match self {
658 TaskHandle::External(shutdown) => shutdown.clone(),
659 TaskHandle::Internal(shutdown) => shutdown.subscribe(),
660 }
661 }
662
663 pub fn try_into_task_manager(self) -> Option<TaskManager> {
664 match self {
665 TaskHandle::External(_) => None,
666 TaskHandle::Internal(shutdown) => Some(shutdown),
667 }
668 }
669
670 #[cfg(not(target_arch = "wasm32"))]
671 pub async fn wait_for_shutdown(self) -> Result<(), SentError> {
672 match self {
673 TaskHandle::Internal(mut task_manager) => task_manager.catch_interrupt().await,
674 TaskHandle::External(mut task_client) => {
675 task_client.recv().await;
676 Ok(())
677 }
678 }
679 }
680}
681
#[cfg(test)]
mod tests {
    use super::*;

    /// A subscribed client's `recv` completes once the manager signals shutdown.
    #[tokio::test]
    #[allow(deprecated)]
    async fn signal_shutdown() {
        let manager = TaskManager::default();
        let mut client = manager.subscribe();

        let task = tokio::spawn(async move {
            client.recv().await;
            42
        });

        manager.signal_shutdown().unwrap();
        assert_eq!(task.await.unwrap(), 42);
    }
}