async_inspect/runtime/
tokio.rs

1//! Tokio runtime integration
2//!
3//! This module provides automatic tracking for Tokio tasks.
4
5use crate::inspector::Inspector;
6use crate::instrument::{clear_current_task_id, set_current_task_id};
7use crate::task::TaskId;
8use std::future::Future;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11use std::time::Instant;
12
13/// Spawn a task with automatic tracking
14///
15/// This is a drop-in replacement for `tokio::spawn()` that automatically
16/// tracks the spawned task.
17///
18/// # Examples
19///
20/// ```rust,ignore
21/// use async_inspect::runtime::tokio::spawn_tracked;
22///
23/// spawn_tracked("background_task", async {
24///     // Your code here - automatically tracked!
25///     println!("Task running");
26/// });
27/// ```
28pub fn spawn_tracked<F, T>(name: T, future: F) -> tokio::task::JoinHandle<F::Output>
29where
30    F: Future + Send + 'static,
31    F::Output: Send + 'static,
32    T: Into<String>,
33{
34    use crate::instrument::current_task_id;
35
36    let task_name = name.into();
37
38    // Check if there's a parent task
39    let task_id = if let Some(parent_id) = current_task_id() {
40        Inspector::global().register_child_task(task_name, parent_id)
41    } else {
42        Inspector::global().register_task(task_name)
43    };
44
45    tokio::spawn(async move {
46        // Set task context for this task
47        set_current_task_id(task_id);
48
49        // Wrap execution to track completion
50        let result = future.await;
51
52        // Mark as completed
53        Inspector::global().task_completed(task_id);
54
55        // Clear context
56        clear_current_task_id();
57
58        result
59    })
60}
61
62/// A future wrapper that automatically tracks execution
63///
64/// This wrapper tracks polls, completion, and can be used with any future.
65pub struct TrackedFuture<F> {
66    future: F,
67    task_id: TaskId,
68    started: bool,
69    poll_start: Option<Instant>,
70}
71
72impl<F> TrackedFuture<F> {
73    /// Create a new tracked future
74    pub fn new(future: F, name: String) -> Self {
75        let task_id = Inspector::global().register_task(name);
76
77        Self {
78            future,
79            task_id,
80            started: false,
81            poll_start: None,
82        }
83    }
84
85    /// Get the task ID
86    pub fn task_id(&self) -> TaskId {
87        self.task_id
88    }
89}
90
91impl<F: Future> Future for TrackedFuture<F> {
92    type Output = F::Output;
93
94    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
95        // SAFETY: We don't move the future
96        let this = unsafe { self.get_unchecked_mut() };
97
98        // Set task context
99        set_current_task_id(this.task_id);
100
101        // Record poll start
102        if !this.started {
103            this.started = true;
104        }
105
106        let poll_start = Instant::now();
107        this.poll_start = Some(poll_start);
108
109        Inspector::global().poll_started(this.task_id);
110
111        // Poll the inner future
112        // SAFETY: We're pinning the projection
113        let result = unsafe { Pin::new_unchecked(&mut this.future).poll(cx) };
114
115        // Record poll end
116        let poll_duration = poll_start.elapsed();
117        Inspector::global().poll_ended(this.task_id, poll_duration);
118
119        match result {
120            Poll::Ready(output) => {
121                // Task completed
122                Inspector::global().task_completed(this.task_id);
123                clear_current_task_id();
124                Poll::Ready(output)
125            }
126            Poll::Pending => {
127                // Still pending
128                Poll::Pending
129            }
130        }
131    }
132}
133
134/// Extension trait for futures to enable `.inspect()` syntax
135///
136/// # Examples
137///
138/// ```rust,ignore
139/// use async_inspect::runtime::tokio::InspectExt;
140///
141/// let result = fetch_data()
142///     .inspect("fetch_data")
143///     .await;
144/// ```
145pub trait InspectExt: Future + Sized {
146    /// Wrap this future with automatic tracking
147    fn inspect(self, name: impl Into<String>) -> TrackedFuture<Self> {
148        TrackedFuture::new(self, name.into())
149    }
150
151    /// Spawn this future on Tokio with tracking
152    fn spawn_tracked(self, name: impl Into<String>) -> tokio::task::JoinHandle<Self::Output>
153    where
154        Self: Send + 'static,
155        Self::Output: Send + 'static,
156    {
157        spawn_tracked(name, self)
158    }
159}
160
161// Implement for all futures
162impl<F: Future> InspectExt for F {}
163
164/// Spawn a local task with automatic tracking (for !Send futures)
165///
166/// This is similar to `spawn_tracked` but for `!Send` futures on a `LocalSet`.
167///
168/// # Examples
169///
170/// ```rust,ignore
171/// use async_inspect::runtime::tokio::spawn_local_tracked;
172///
173/// tokio::task::LocalSet::new().run_until(async {
174///     spawn_local_tracked("local_task", async {
175///         // !Send future
176///     });
177/// }).await;
178/// ```
179#[cfg(feature = "tokio")]
180pub fn spawn_local_tracked<F, T>(name: T, future: F) -> tokio::task::JoinHandle<F::Output>
181where
182    F: Future + 'static,
183    F::Output: 'static,
184    T: Into<String>,
185{
186    let task_name = name.into();
187    let task_id = Inspector::global().register_task(task_name);
188
189    tokio::task::spawn_local(async move {
190        set_current_task_id(task_id);
191
192        let result = future.await;
193
194        Inspector::global().task_completed(task_id);
195        clear_current_task_id();
196
197        result
198    })
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// `spawn_tracked` returns the future's output and registers the task by name.
    #[tokio::test]
    async fn test_spawn_tracked() {
        let value = spawn_tracked("test_spawn_tracked_task", async { 42 })
            .await
            .unwrap();
        assert_eq!(value, 42);

        // The spawned task must appear in the inspector's registry.
        let registered = Inspector::global().get_all_tasks();
        assert!(registered
            .iter()
            .any(|task| task.name == "test_spawn_tracked_task"));
    }

    /// `.inspect()` passes the output through and registers the task by name.
    #[tokio::test]
    async fn test_inspect_ext() {
        async fn example_task() -> i32 {
            tokio::time::sleep(Duration::from_millis(10)).await;
            123
        }

        let output = example_task().inspect("test_inspect_ext_task").await;
        assert_eq!(output, 123);

        let registered = Inspector::global().get_all_tasks();
        assert!(registered
            .iter()
            .any(|task| task.name == "test_inspect_ext_task"));
    }

    /// A `TrackedFuture` that yields at least once records a non-zero poll count.
    #[tokio::test]
    async fn test_tracked_future() {
        let tracked = TrackedFuture::new(
            async {
                tokio::time::sleep(Duration::from_millis(5)).await;
                "done"
            },
            String::from("test_tracked_future_task"),
        );
        let id = tracked.task_id();

        assert_eq!(tracked.await, "done");

        let info = Inspector::global().get_task(id).unwrap();
        assert!(info.poll_count > 0);
    }

    /// Several concurrently spawned tasks each return their own result and are registered.
    #[tokio::test]
    async fn test_spawn_tracked_multiple() {
        let handles: Vec<_> = (0_usize..5)
            .map(|n| spawn_tracked(format!("test_multi_task_{n}"), async move { n * 2 }))
            .collect();

        let mut results = Vec::with_capacity(handles.len());
        for handle in handles {
            results.push(handle.await.unwrap());
        }
        let expected: Vec<usize> = (0..5).map(|n| n * 2).collect();
        assert_eq!(results, expected);

        // Every task name must be present in a single registry snapshot.
        let registered = Inspector::global().get_all_tasks();
        let is_registered = |name: &str| registered.iter().any(|task| task.name == name);
        for n in 0..5 {
            assert!(is_registered(&format!("test_multi_task_{n}")));
        }
    }
}