1use std::{future::Future, time::Duration};
22
23use futures::future::FutureExt;
24use tokio::task::JoinHandle;
25use tokio_util::{sync::CancellationToken, task::TaskTracker};
26use zenoh_core::{ResolveFuture, Wait};
27use zenoh_runtime::ZRuntime;
28
29#[derive(Clone)]
30pub struct TaskController {
31 tracker: TaskTracker,
32 token: CancellationToken,
33}
34
35impl Default for TaskController {
36 fn default() -> Self {
37 TaskController {
38 tracker: TaskTracker::new(),
39 token: CancellationToken::new(),
40 }
41 }
42}
43
44impl TaskController {
45 pub fn spawn_abortable<F, T>(&self, future: F) -> JoinHandle<()>
48 where
49 F: Future<Output = T> + Send + 'static,
50 T: Send + 'static,
51 {
52 #[cfg(feature = "tracing-instrument")]
53 let future = tracing::Instrument::instrument(future, tracing::Span::current());
54
55 let token = self.token.child_token();
56 let task = async move {
57 tokio::select! {
58 _ = token.cancelled() => {},
59 _ = future => {}
60 }
61 };
62
63 self.tracker.spawn(task)
64 }
65
66 pub fn spawn_abortable_with_rt<F, T>(&self, rt: ZRuntime, future: F) -> JoinHandle<()>
69 where
70 F: Future<Output = T> + Send + 'static,
71 T: Send + 'static,
72 {
73 #[cfg(feature = "tracing-instrument")]
74 let future = tracing::Instrument::instrument(future, tracing::Span::current());
75
76 let token = self.token.child_token();
77 let task = async move {
78 tokio::select! {
79 _ = token.cancelled() => {},
80 _ = future => {}
81 }
82 };
83 self.tracker.spawn_on(task, &rt)
84 }
85
86 pub fn get_cancellation_token(&self) -> CancellationToken {
87 self.token.child_token()
88 }
89
90 pub fn spawn<F, T>(&self, future: F) -> JoinHandle<()>
94 where
95 F: Future<Output = T> + Send + 'static,
96 T: Send + 'static,
97 {
98 #[cfg(feature = "tracing-instrument")]
99 let future = tracing::Instrument::instrument(future, tracing::Span::current());
100
101 self.tracker.spawn(future.map(|_f| ()))
102 }
103
104 pub fn spawn_with_rt<F, T>(&self, rt: ZRuntime, future: F) -> JoinHandle<()>
108 where
109 F: Future<Output = T> + Send + 'static,
110 T: Send + 'static,
111 {
112 #[cfg(feature = "tracing-instrument")]
113 let future = tracing::Instrument::instrument(future, tracing::Span::current());
114
115 self.tracker.spawn_on(future.map(|_f| ()), &rt)
116 }
117
118 pub fn terminate_all(&self, timeout: Duration) -> usize {
126 ResolveFuture::new(async move {
127 if tokio::time::timeout(timeout, self.terminate_all_async())
128 .await
129 .is_err()
130 {
131 tracing::error!("Failed to terminate {} tasks", self.tracker.len());
132 }
133 self.tracker.len()
134 })
135 .wait()
136 }
137
138 pub async fn terminate_all_async(&self) {
140 self.tracker.close();
141 self.token.cancel();
142 self.tracker.wait().await
143 }
144}
145
146pub struct TerminatableTask {
147 handle: Option<JoinHandle<()>>,
148 token: CancellationToken,
149}
150
151impl Drop for TerminatableTask {
152 fn drop(&mut self) {
153 self.terminate(std::time::Duration::from_secs(10));
154 }
155}
156
157impl TerminatableTask {
158 pub fn create_cancellation_token() -> CancellationToken {
159 CancellationToken::new()
160 }
161
162 pub fn spawn<F, T>(rt: ZRuntime, future: F, token: CancellationToken) -> TerminatableTask
165 where
166 F: Future<Output = T> + Send + 'static,
167 T: Send + 'static,
168 {
169 TerminatableTask {
170 handle: Some(rt.spawn(future.map(|_f| ()))),
171 token,
172 }
173 }
174
175 pub fn spawn_abortable<F, T>(rt: ZRuntime, future: F) -> TerminatableTask
177 where
178 F: Future<Output = T> + Send + 'static,
179 T: Send + 'static,
180 {
181 let token = CancellationToken::new();
182 let token2 = token.clone();
183 let task = async move {
184 tokio::select! {
185 _ = token2.cancelled() => {},
186 _ = future => {}
187 }
188 };
189
190 TerminatableTask {
191 handle: Some(rt.spawn(task)),
192 token,
193 }
194 }
195
196 pub fn terminate(&mut self, timeout: Duration) -> bool {
199 ResolveFuture::new(async move {
200 if tokio::time::timeout(timeout, self.terminate_async())
201 .await
202 .is_err()
203 {
204 tracing::error!("Failed to terminate the task");
205 return false;
206 };
207 true
208 })
209 .wait()
210 }
211
212 pub async fn terminate_async(&mut self) {
214 self.token.cancel();
215 if let Some(handle) = self.handle.take() {
216 let _ = handle.await;
217 }
218 }
219}