zenoh_task/
lib.rs

1//
2// Copyright (c) 2024 ZettaScale Technology
3//
4// This program and the accompanying materials are made available under the
5// terms of the Eclipse Public License 2.0 which is available at
6// http://www.eclipse.org/legal/epl-2.0, or the Apache License, Version 2.0
7// which is available at https://www.apache.org/licenses/LICENSE-2.0.
8//
9// SPDX-License-Identifier: EPL-2.0 OR Apache-2.0
10//
11// Contributors:
12//   ZettaScale Zenoh Team, <zenoh@zettascale.tech>
13//
14
15//! ⚠️ WARNING ⚠️
16//!
17//! This module is intended for Zenoh's internal use.
18//!
19//! [Click here for Zenoh's documentation](https://docs.rs/zenoh/latest/zenoh)
20
21use std::{future::Future, time::Duration};
22
23use futures::future::FutureExt;
24use tokio::task::JoinHandle;
25use tokio_util::{sync::CancellationToken, task::TaskTracker};
26use zenoh_core::{ResolveFuture, Wait};
27use zenoh_runtime::ZRuntime;
28
29#[derive(Clone)]
30pub struct TaskController {
31    tracker: TaskTracker,
32    token: CancellationToken,
33}
34
35impl Default for TaskController {
36    fn default() -> Self {
37        TaskController {
38            tracker: TaskTracker::new(),
39            token: CancellationToken::new(),
40        }
41    }
42}
43
44impl TaskController {
45    /// Spawns a task that can be later terminated by call to [`TaskController::terminate_all()`].
46    /// Task output is ignored.
47    pub fn spawn_abortable<F, T>(&self, future: F) -> JoinHandle<()>
48    where
49        F: Future<Output = T> + Send + 'static,
50        T: Send + 'static,
51    {
52        #[cfg(feature = "tracing-instrument")]
53        let future = tracing::Instrument::instrument(future, tracing::Span::current());
54
55        let token = self.token.child_token();
56        let task = async move {
57            tokio::select! {
58                _ = token.cancelled() => {},
59                _ = future => {}
60            }
61        };
62
63        self.tracker.spawn(task)
64    }
65
66    /// Spawns a task using a specified runtime that can be later terminated by call to [`TaskController::terminate_all()`].
67    /// Task output is ignored.
68    pub fn spawn_abortable_with_rt<F, T>(&self, rt: ZRuntime, future: F) -> JoinHandle<()>
69    where
70        F: Future<Output = T> + Send + 'static,
71        T: Send + 'static,
72    {
73        #[cfg(feature = "tracing-instrument")]
74        let future = tracing::Instrument::instrument(future, tracing::Span::current());
75
76        let token = self.token.child_token();
77        let task = async move {
78            tokio::select! {
79                _ = token.cancelled() => {},
80                _ = future => {}
81            }
82        };
83        self.tracker.spawn_on(task, &rt)
84    }
85
86    pub fn get_cancellation_token(&self) -> CancellationToken {
87        self.token.child_token()
88    }
89
90    /// Spawns a task that can be cancelled via cancellation of a token obtained by [`TaskController::get_cancellation_token()`],
91    /// or that can run to completion in finite amount of time.
92    /// It can be later terminated by call to [`TaskController::terminate_all()`].
93    pub fn spawn<F, T>(&self, future: F) -> JoinHandle<()>
94    where
95        F: Future<Output = T> + Send + 'static,
96        T: Send + 'static,
97    {
98        #[cfg(feature = "tracing-instrument")]
99        let future = tracing::Instrument::instrument(future, tracing::Span::current());
100
101        self.tracker.spawn(future.map(|_f| ()))
102    }
103
104    /// Spawns a task that can be cancelled via cancellation of a token obtained by [`TaskController::get_cancellation_token()`],
105    /// or that can run to completion in finite amount of time, using a specified runtime.
106    /// It can be later aborted by call to [`TaskController::terminate_all()`].
107    pub fn spawn_with_rt<F, T>(&self, rt: ZRuntime, future: F) -> JoinHandle<()>
108    where
109        F: Future<Output = T> + Send + 'static,
110        T: Send + 'static,
111    {
112        #[cfg(feature = "tracing-instrument")]
113        let future = tracing::Instrument::instrument(future, tracing::Span::current());
114
115        self.tracker.spawn_on(future.map(|_f| ()), &rt)
116    }
117
118    /// Attempts tp terminate all previously spawned tasks
119    /// The caller must ensure that all tasks spawned with [`TaskController::spawn()`]
120    /// or [`TaskController::spawn_with_rt()`] can yield in finite amount of time either because they will run to completion
121    /// or due to cancellation of token acquired via [`TaskController::get_cancellation_token()`].
122    /// Tasks spawned with [`TaskController::spawn_abortable()`] or [`TaskController::spawn_abortable_with_rt()`] will be aborted (i.e. terminated upon next await call).
123    /// The call blocks until all tasks yield or timeout duration expires.
124    /// Returns 0 in case of success, number of non terminated tasks otherwise.
125    pub fn terminate_all(&self, timeout: Duration) -> usize {
126        ResolveFuture::new(async move {
127            if tokio::time::timeout(timeout, self.terminate_all_async())
128                .await
129                .is_err()
130            {
131                tracing::error!("Failed to terminate {} tasks", self.tracker.len());
132            }
133            self.tracker.len()
134        })
135        .wait()
136    }
137
138    /// Async version of [`TaskController::terminate_all()`].
139    pub async fn terminate_all_async(&self) {
140        self.tracker.close();
141        self.token.cancel();
142        self.tracker.wait().await
143    }
144}
145
146pub struct TerminatableTask {
147    handle: Option<JoinHandle<()>>,
148    token: CancellationToken,
149}
150
151impl Drop for TerminatableTask {
152    fn drop(&mut self) {
153        self.terminate(std::time::Duration::from_secs(10));
154    }
155}
156
157impl TerminatableTask {
158    pub fn create_cancellation_token() -> CancellationToken {
159        CancellationToken::new()
160    }
161
162    /// Spawns a task that can be later terminated by [`TerminatableTask::terminate()`].
163    /// Prior to termination attempt the specified cancellation token will be cancelled.
164    pub fn spawn<F, T>(rt: ZRuntime, future: F, token: CancellationToken) -> TerminatableTask
165    where
166        F: Future<Output = T> + Send + 'static,
167        T: Send + 'static,
168    {
169        TerminatableTask {
170            handle: Some(rt.spawn(future.map(|_f| ()))),
171            token,
172        }
173    }
174
175    /// Spawns a task that can be later aborted by [`TerminatableTask::terminate()`].
176    pub fn spawn_abortable<F, T>(rt: ZRuntime, future: F) -> TerminatableTask
177    where
178        F: Future<Output = T> + Send + 'static,
179        T: Send + 'static,
180    {
181        let token = CancellationToken::new();
182        let token2 = token.clone();
183        let task = async move {
184            tokio::select! {
185                _ = token2.cancelled() => {},
186                _ = future => {}
187            }
188        };
189
190        TerminatableTask {
191            handle: Some(rt.spawn(task)),
192            token,
193        }
194    }
195
196    /// Attempts to terminate the task.
197    /// Returns true if task completed / aborted within timeout duration, false otherwise.
198    pub fn terminate(&mut self, timeout: Duration) -> bool {
199        ResolveFuture::new(async move {
200            if tokio::time::timeout(timeout, self.terminate_async())
201                .await
202                .is_err()
203            {
204                tracing::error!("Failed to terminate the task");
205                return false;
206            };
207            true
208        })
209        .wait()
210    }
211
212    /// Async version of [`TerminatableTask::terminate()`].
213    pub async fn terminate_async(&mut self) {
214        self.token.cancel();
215        if let Some(handle) = self.handle.take() {
216            let _ = handle.await;
217        }
218    }
219}