async_reify/traced.rs
1//! [`TracedFuture`]: a future wrapper that records poll events.
2
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::{Arc, Mutex};
6use std::task::{Context, Poll};
7use std::time::{Duration, Instant};
8
/// The outcome of a single poll.
///
/// `Cancelled` is recorded when a future is dropped before completing
/// (its last poll returned `Pending` and no `Ready` event was ever emitted).
///
/// The enum carries no data, so it is `Copy` and `Hash`: outcomes can be
/// passed around by value and used directly as map or set keys.
///
/// # Examples
///
/// ```
/// use async_reify::PollResult;
/// use std::collections::HashMap;
///
/// let ready = PollResult::Ready;
/// let pending = PollResult::Pending;
/// assert_ne!(ready, pending);
///
/// // `Copy + Hash` lets outcomes be tallied directly.
/// let mut counts = HashMap::new();
/// for r in [pending, ready, pending] {
///     *counts.entry(r).or_insert(0) += 1;
/// }
/// assert_eq!(counts[&PollResult::Pending], 2);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum PollResult {
    /// The future returned `Poll::Pending`.
    Pending,
    /// The future returned `Poll::Ready`.
    Ready,
    /// The future was dropped after a `Pending` poll, before completing.
    Cancelled,
}
33
34/// A recorded poll event.
35///
36/// `offset` is the elapsed time since the start of the [`Trace`] this
37/// event belongs to, expressed as a [`Duration`]. This makes events
38/// serializable and replayable across processes.
39///
40/// # Examples
41///
42/// ```
43/// use async_reify::{PollEvent, PollResult};
44/// use std::time::Duration;
45///
46/// let event = PollEvent {
47/// step: 0,
48/// offset: Duration::from_micros(150),
49/// result: PollResult::Ready,
50/// label: None,
51/// };
52/// assert_eq!(event.step, 0);
53/// ```
54#[derive(Debug, Clone)]
55#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
56pub struct PollEvent {
57 /// Sequential poll index (0-based).
58 pub step: usize,
59 /// Time elapsed since the start of the parent [`Trace`].
60 pub offset: Duration,
61 /// Whether the poll returned Ready, Pending, or was Cancelled by drop.
62 pub result: PollResult,
63 /// Optional label for this await point.
64 pub label: Option<String>,
65}
66
67/// Collected trace from a [`TracedFuture`].
68///
69/// A `Trace` holds the recorded events plus a reference [`Instant`] used
70/// to compute event offsets. The reference instant is not serialized;
71/// when a `Trace` is deserialized the `start` field is reset to the
72/// deserialization time and only the per-event `offset` values are
73/// authoritative.
74///
75/// # Examples
76///
77/// ```
78/// use async_reify::Trace;
79///
80/// let trace = Trace::new();
81/// assert!(trace.events.is_empty());
82/// ```
83#[derive(Debug, Clone)]
84#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
85pub struct Trace {
86 /// All poll events in order.
87 pub events: Vec<PollEvent>,
88 /// Reference instant against which each event's `offset` was measured.
89 /// Not serialized: only the offsets persist across (de)serialization.
90 #[cfg_attr(feature = "serde", serde(skip, default = "Instant::now"))]
91 pub start: Instant,
92}
93
94impl Trace {
95 /// Construct an empty trace anchored at the current instant.
96 ///
97 /// # Examples
98 ///
99 /// ```
100 /// use async_reify::Trace;
101 ///
102 /// let t = Trace::new();
103 /// assert!(t.events.is_empty());
104 /// ```
105 pub fn new() -> Self {
106 Self {
107 events: Vec::new(),
108 start: Instant::now(),
109 }
110 }
111
112 /// Return a fresh shared, mutex-protected trace suitable for use
113 /// with [`LabeledFuture`](crate::LabeledFuture).
114 ///
115 /// # Examples
116 ///
117 /// ```
118 /// use async_reify::Trace;
119 ///
120 /// let shared = Trace::shared();
121 /// assert!(shared.lock().unwrap().events.is_empty());
122 /// ```
123 pub fn shared() -> Arc<Mutex<Trace>> {
124 Arc::new(Mutex::new(Trace::new()))
125 }
126
127 /// Append an event to the trace, computing its `offset` from the
128 /// trace's reference instant.
129 pub(crate) fn push(&mut self, result: PollResult, label: Option<String>) {
130 let step = self.events.len();
131 let offset = Instant::now().saturating_duration_since(self.start);
132 self.events.push(PollEvent {
133 step,
134 offset,
135 result,
136 label,
137 });
138 }
139}
140
141impl Default for Trace {
142 fn default() -> Self {
143 Trace::new()
144 }
145}
146
147/// A future wrapper that records each poll as a [`PollEvent`].
148///
149/// Use [`TracedFuture::run`] for a convenient way to execute a future
150/// and collect its trace.
151///
152/// If the wrapped future is dropped before it completes, a final event
153/// with [`PollResult::Cancelled`] is appended to the trace.
154///
155/// # Examples
156///
157/// ```
158/// use async_reify::TracedFuture;
159///
160/// # tokio_test::block_on(async {
161/// let (val, trace) = TracedFuture::run(async { 1 + 1 }).await;
162/// assert_eq!(val, 2);
163/// assert!(!trace.events.is_empty());
164/// # });
165/// ```
166pub struct TracedFuture<F> {
167 inner: Pin<Box<F>>,
168 trace: Arc<Mutex<Trace>>,
169 label: Option<String>,
170 completed: bool,
171}
172
173impl<F: Future> TracedFuture<F> {
174 /// Create a new traced future wrapping `inner`.
175 ///
176 /// # Examples
177 ///
178 /// ```
179 /// use async_reify::TracedFuture;
180 ///
181 /// let traced = TracedFuture::new(async { 42 });
182 /// ```
183 pub fn new(inner: F) -> Self {
184 Self {
185 inner: Box::pin(inner),
186 trace: Trace::shared(),
187 label: None,
188 completed: false,
189 }
190 }
191
192 /// Create a new traced future with a label.
193 ///
194 /// # Examples
195 ///
196 /// ```
197 /// use async_reify::TracedFuture;
198 ///
199 /// let traced = TracedFuture::with_label(async { 42 }, "my_step");
200 /// ```
201 pub fn with_label(inner: F, label: &str) -> Self {
202 Self {
203 inner: Box::pin(inner),
204 trace: Trace::shared(),
205 label: Some(label.to_string()),
206 completed: false,
207 }
208 }
209
210 /// Run the future to completion, returning the result and the trace.
211 ///
212 /// This is a convenience wrapper that polls the future through a
213 /// [`TracedFuture`] and collects all events.
214 ///
215 /// # Examples
216 ///
217 /// ```
218 /// use async_reify::{TracedFuture, PollResult};
219 ///
220 /// # tokio_test::block_on(async {
221 /// let (result, trace) = TracedFuture::run(async { "hello" }).await;
222 /// assert_eq!(result, "hello");
223 /// assert!(matches!(trace.events.last().unwrap().result, PollResult::Ready));
224 /// # });
225 /// ```
226 pub async fn run(inner: F) -> (F::Output, Trace) {
227 let trace = Trace::shared();
228 let traced = TracedFuture {
229 inner: Box::pin(inner),
230 trace: trace.clone(),
231 label: None,
232 completed: false,
233 };
234 let result = traced.await;
235 let trace = Arc::try_unwrap(trace)
236 .expect("trace arc should have single owner")
237 .into_inner()
238 .expect("trace mutex should not be poisoned");
239 (result, trace)
240 }
241}
242
243impl<F: Future> Future for TracedFuture<F> {
244 type Output = F::Output;
245
246 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
247 let this = self.get_mut();
248 let poll_result = this.inner.as_mut().poll(cx);
249
250 let result = match &poll_result {
251 Poll::Pending => PollResult::Pending,
252 Poll::Ready(_) => PollResult::Ready,
253 };
254
255 if matches!(result, PollResult::Ready) {
256 this.completed = true;
257 }
258
259 this.trace
260 .lock()
261 .expect("trace mutex should not be poisoned")
262 .push(result, this.label.clone());
263
264 poll_result
265 }
266}
267
268impl<F> Drop for TracedFuture<F> {
269 fn drop(&mut self) {
270 if !self.completed {
271 if let Ok(mut trace) = self.trace.lock() {
272 let last_was_pending = trace
273 .events
274 .last()
275 .is_some_and(|e| matches!(e.result, PollResult::Pending));
276 if last_was_pending {
277 trace.push(PollResult::Cancelled, self.label.clone());
278 }
279 }
280 }
281 }
282}
283
#[cfg(test)]
mod tests {
    use super::*;

    /// An immediately-ready future produces exactly one `Ready` event at step 0.
    #[tokio::test]
    async fn trace_immediate_future() {
        let (val, trace) = TracedFuture::run(async { 42 }).await;
        assert_eq!(val, 42);
        assert_eq!(trace.events.len(), 1);
        assert_eq!(trace.events[0].result, PollResult::Ready);
        assert_eq!(trace.events[0].step, 0);
    }

    /// Each `yield_now` adds a `Pending` poll before the final `Ready`.
    #[tokio::test]
    async fn trace_multi_step() {
        let (val, trace) = TracedFuture::run(async {
            tokio::task::yield_now().await;
            tokio::task::yield_now().await;
            99
        })
        .await;
        assert_eq!(val, 99);
        assert!(trace.events.len() >= 3);
        assert_eq!(trace.events.last().unwrap().result, PollResult::Ready);
    }

    /// The label given to `with_label` is attached to recorded events.
    #[tokio::test]
    async fn with_label() {
        let traced = TracedFuture::with_label(async { 1 }, "test_step");
        let trace = traced.trace.clone();
        let _ = traced.await;
        let trace = trace.lock().unwrap();
        assert_eq!(trace.events[0].label.as_deref(), Some("test_step"));
    }

    /// Dropping a wrapper right after a `Pending` poll records `Cancelled`.
    #[tokio::test]
    async fn dropped_pending_future_is_cancelled() {
        let trace_shared = Trace::shared();

        // A future that yields once (recording Pending) and would then return.
        // It is dropped after the first poll to simulate cancellation.
        struct YieldOnce {
            yielded: bool,
        }
        impl Future for YieldOnce {
            type Output = ();
            fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
                if self.yielded {
                    Poll::Ready(())
                } else {
                    self.yielded = true;
                    cx.waker().wake_by_ref();
                    Poll::Pending
                }
            }
        }

        let mut traced = TracedFuture {
            inner: Box::pin(YieldOnce { yielded: false }),
            trace: trace_shared.clone(),
            label: Some("drop_me".into()),
            completed: false,
        };

        // Manually poll once (records Pending) without driving to completion.
        let waker = futures_task::noop_waker();
        let mut cx = Context::from_waker(&waker);
        let _ = Pin::new(&mut traced).poll(&mut cx);
        drop(traced);

        let trace = trace_shared.lock().unwrap();
        assert!(trace.events.iter().any(|e| e.result == PollResult::Pending));
        assert!(
            trace
                .events
                .iter()
                .any(|e| e.result == PollResult::Cancelled),
            "expected a Cancelled event after drop, got {:?}",
            trace.events
        );
        // The cancellation event carries the wrapper's label.
        assert_eq!(
            trace.events.last().and_then(|e| e.label.as_deref()),
            Some("drop_me")
        );
    }

    /// Dropping a wrapper that was never polled must not invent any events.
    #[test]
    fn unpolled_drop_records_nothing() {
        let trace_shared = Trace::shared();
        let traced = TracedFuture {
            inner: Box::pin(async { 5 }),
            trace: Arc::clone(&trace_shared),
            label: Some("never_polled".into()),
            completed: false,
        };
        drop(traced);
        assert!(trace_shared.lock().unwrap().events.is_empty());
    }

    /// A future that ran to completion is not reported as cancelled when
    /// the wrapper is dropped afterwards.
    #[tokio::test]
    async fn completed_future_is_not_cancelled() {
        let traced = TracedFuture::with_label(
            async {
                tokio::task::yield_now().await;
                3
            },
            "finishes",
        );
        let trace_shared = Arc::clone(&traced.trace);
        assert_eq!(traced.await, 3);
        let trace = trace_shared.lock().unwrap();
        assert_eq!(
            trace.events.last().map(|e| &e.result),
            Some(&PollResult::Ready)
        );
        assert!(trace
            .events
            .iter()
            .all(|e| e.result != PollResult::Cancelled));
    }

    /// `step` equals the event's index and `offset` never goes backwards.
    #[tokio::test]
    async fn steps_and_offsets_are_monotonic() {
        let (_, trace) = TracedFuture::run(async {
            tokio::task::yield_now().await;
            tokio::task::yield_now().await;
        })
        .await;
        for (index, event) in trace.events.iter().enumerate() {
            assert_eq!(event.step, index);
        }
        assert!(trace
            .events
            .windows(2)
            .all(|pair| pair[0].offset <= pair[1].offset));
    }

    /// Events survive a JSON round trip; only `start` is reset.
    #[cfg(feature = "serde")]
    #[tokio::test]
    async fn trace_round_trip_serde() {
        let (_, trace) = TracedFuture::run(async {
            tokio::task::yield_now().await;
            7
        })
        .await;
        let json = serde_json::to_string(&trace).expect("serialize");
        let restored: Trace = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(restored.events.len(), trace.events.len());
        for (a, b) in trace.events.iter().zip(restored.events.iter()) {
            assert_eq!(a.step, b.step);
            assert_eq!(a.offset, b.offset);
            assert_eq!(a.result, b.result);
            assert_eq!(a.label, b.label);
        }
    }
}
384}