async_inspect/runtime/
tokio.rs1use crate::inspector::Inspector;
6use crate::instrument::{clear_current_task_id, set_current_task_id};
7use crate::task::TaskId;
8use std::future::Future;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11use std::time::Instant;
12
13pub fn spawn_tracked<F, T>(name: T, future: F) -> tokio::task::JoinHandle<F::Output>
29where
30 F: Future + Send + 'static,
31 F::Output: Send + 'static,
32 T: Into<String>,
33{
34 use crate::instrument::current_task_id;
35
36 let task_name = name.into();
37
38 let task_id = if let Some(parent_id) = current_task_id() {
40 Inspector::global().register_child_task(task_name, parent_id)
41 } else {
42 Inspector::global().register_task(task_name)
43 };
44
45 tokio::spawn(async move {
46 set_current_task_id(task_id);
48
49 let result = future.await;
51
52 Inspector::global().task_completed(task_id);
54
55 clear_current_task_id();
57
58 result
59 })
60}
61
62pub struct TrackedFuture<F> {
66 future: F,
67 task_id: TaskId,
68 started: bool,
69 poll_start: Option<Instant>,
70}
71
72impl<F> TrackedFuture<F> {
73 pub fn new(future: F, name: String) -> Self {
75 let task_id = Inspector::global().register_task(name);
76
77 Self {
78 future,
79 task_id,
80 started: false,
81 poll_start: None,
82 }
83 }
84
85 pub fn task_id(&self) -> TaskId {
87 self.task_id
88 }
89}
90
91impl<F: Future> Future for TrackedFuture<F> {
92 type Output = F::Output;
93
94 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
95 let this = unsafe { self.get_unchecked_mut() };
97
98 set_current_task_id(this.task_id);
100
101 if !this.started {
103 this.started = true;
104 }
105
106 let poll_start = Instant::now();
107 this.poll_start = Some(poll_start);
108
109 Inspector::global().poll_started(this.task_id);
110
111 let result = unsafe { Pin::new_unchecked(&mut this.future).poll(cx) };
114
115 let poll_duration = poll_start.elapsed();
117 Inspector::global().poll_ended(this.task_id, poll_duration);
118
119 match result {
120 Poll::Ready(output) => {
121 Inspector::global().task_completed(this.task_id);
123 clear_current_task_id();
124 Poll::Ready(output)
125 }
126 Poll::Pending => {
127 Poll::Pending
129 }
130 }
131 }
132}
133
134pub trait InspectExt: Future + Sized {
146 fn inspect(self, name: impl Into<String>) -> TrackedFuture<Self> {
148 TrackedFuture::new(self, name.into())
149 }
150
151 fn spawn_tracked(self, name: impl Into<String>) -> tokio::task::JoinHandle<Self::Output>
153 where
154 Self: Send + 'static,
155 Self::Output: Send + 'static,
156 {
157 spawn_tracked(name, self)
158 }
159}
160
161impl<F: Future> InspectExt for F {}
163
/// Spawns a `!Send` `future` on the current thread's Tokio local task set as a task
/// tracked by the global [`Inspector`].
///
/// Like [`spawn_tracked`], the task is registered as a child of the current tracked
/// task when a current task id is set, and as a root task otherwise.
///
/// The future is wrapped in a [`TrackedFuture`], so the current task id is installed
/// on every poll and each poll is reported to the inspector. Other local tasks
/// interleave between this task's polls, so setting the id only once at start would
/// let it leak into them.
///
/// # Panics
///
/// Panics under the same conditions as `tokio::task::spawn_local`, e.g. when called
/// outside a `tokio::task::LocalSet` context.
#[cfg(feature = "tokio")]
pub fn spawn_local_tracked<F, T>(name: T, future: F) -> tokio::task::JoinHandle<F::Output>
where
    F: Future + 'static,
    F::Output: 'static,
    T: Into<String>,
{
    use crate::instrument::current_task_id;

    let task_name = name.into();

    // Keep parent/child linking consistent with `spawn_tracked`.
    let task_id = match current_task_id() {
        Some(parent_id) => Inspector::global().register_child_task(task_name, parent_id),
        None => Inspector::global().register_task(task_name),
    };

    tokio::task::spawn_local(TrackedFuture {
        future,
        task_id,
        started: false,
        poll_start: None,
    })
}
200
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// True when the global inspector has recorded a task called `name`.
    fn inspector_has_task(name: &str) -> bool {
        Inspector::global()
            .get_all_tasks()
            .iter()
            .any(|task| task.name == name)
    }

    #[tokio::test]
    async fn test_spawn_tracked() {
        let answer = spawn_tracked("test_spawn_tracked_task", async { 42 })
            .await
            .unwrap();

        assert_eq!(answer, 42);
        assert!(inspector_has_task("test_spawn_tracked_task"));
    }

    #[tokio::test]
    async fn test_inspect_ext() {
        async fn example_task() -> i32 {
            tokio::time::sleep(Duration::from_millis(10)).await;
            123
        }

        let value = example_task().inspect("test_inspect_ext_task").await;

        assert_eq!(value, 123);
        assert!(inspector_has_task("test_inspect_ext_task"));
    }

    #[tokio::test]
    async fn test_tracked_future() {
        let tracked = TrackedFuture::new(
            async {
                tokio::time::sleep(Duration::from_millis(5)).await;
                "done"
            },
            "test_tracked_future_task".to_string(),
        );
        let id = tracked.task_id();

        assert_eq!(tracked.await, "done");

        let snapshot = Inspector::global().get_task(id).unwrap();
        assert!(snapshot.poll_count > 0);
    }

    #[tokio::test]
    async fn test_spawn_tracked_multiple() {
        let mut handles = Vec::with_capacity(5);
        for i in 0..5usize {
            handles.push(spawn_tracked(
                format!("test_multi_task_{i}"),
                async move { i * 2 },
            ));
        }

        for (i, handle) in handles.into_iter().enumerate() {
            assert_eq!(handle.await.unwrap(), i * 2);
        }

        for i in 0..5 {
            assert!(inspector_has_task(&format!("test_multi_task_{i}")));
        }
    }
}