// async_reify/traced.rs

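//! [`TracedFuture`]: a future wrapper that records poll events.
//!
//! Every poll of a [`TracedFuture`] appends a [`PollEvent`] to a [`Trace`],
//! so the `Pending`/`Ready` history of an async computation can be inspected
//! after it has run.
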
use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex, PoisonError};
use std::task::{Context, Poll};
use std::time::{Duration, Instant};

/// The outcome of a single poll.
///
/// `Cancelled` is recorded when a [`TracedFuture`] is dropped after being
/// polled at least once but before completing (its last recorded event is
/// `Pending` and no `Ready` event was ever emitted). A future that is
/// dropped without ever being polled records nothing.
///
/// This is a plain fieldless enum, so it is `Copy` and `Hash`: it can be
/// compared and passed around without cloning, and used as a set/map key.
///
/// # Examples
///
/// ```
/// use async_reify::PollResult;
///
/// let ready = PollResult::Ready;
/// let pending = PollResult::Pending;
/// assert_ne!(ready, pending);
///
/// // `Copy`: `ready` is still usable after being assigned elsewhere.
/// let copy = ready;
/// assert_eq!(copy, ready);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum PollResult {
    /// The future returned `Poll::Pending`.
    Pending,
    /// The future returned `Poll::Ready`.
    Ready,
    /// The future was dropped before completing.
    Cancelled,
}

34/// A recorded poll event.
35///
36/// `offset` is the elapsed time since the start of the [`Trace`] this
37/// event belongs to, expressed as a [`Duration`]. This makes events
38/// serializable and replayable across processes.
39///
40/// # Examples
41///
42/// ```
43/// use async_reify::{PollEvent, PollResult};
44/// use std::time::Duration;
45///
46/// let event = PollEvent {
47///     step: 0,
48///     offset: Duration::from_micros(150),
49///     result: PollResult::Ready,
50///     label: None,
51/// };
52/// assert_eq!(event.step, 0);
53/// ```
54#[derive(Debug, Clone)]
55#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
56pub struct PollEvent {
57    /// Sequential poll index (0-based).
58    pub step: usize,
59    /// Time elapsed since the start of the parent [`Trace`].
60    pub offset: Duration,
61    /// Whether the poll returned Ready, Pending, or was Cancelled by drop.
62    pub result: PollResult,
63    /// Optional label for this await point.
64    pub label: Option<String>,
65}
66
67/// Collected trace from a [`TracedFuture`].
68///
69/// A `Trace` holds the recorded events plus a reference [`Instant`] used
70/// to compute event offsets. The reference instant is not serialized;
71/// when a `Trace` is deserialized the `start` field is reset to the
72/// deserialization time and only the per-event `offset` values are
73/// authoritative.
74///
75/// # Examples
76///
77/// ```
78/// use async_reify::Trace;
79///
80/// let trace = Trace::new();
81/// assert!(trace.events.is_empty());
82/// ```
83#[derive(Debug, Clone)]
84#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
85pub struct Trace {
86    /// All poll events in order.
87    pub events: Vec<PollEvent>,
88    /// Reference instant against which each event's `offset` was measured.
89    /// Not serialized: only the offsets persist across (de)serialization.
90    #[cfg_attr(feature = "serde", serde(skip, default = "Instant::now"))]
91    pub start: Instant,
92}
93
94impl Trace {
95    /// Construct an empty trace anchored at the current instant.
96    ///
97    /// # Examples
98    ///
99    /// ```
100    /// use async_reify::Trace;
101    ///
102    /// let t = Trace::new();
103    /// assert!(t.events.is_empty());
104    /// ```
105    pub fn new() -> Self {
106        Self {
107            events: Vec::new(),
108            start: Instant::now(),
109        }
110    }
111
112    /// Return a fresh shared, mutex-protected trace suitable for use
113    /// with [`LabeledFuture`](crate::LabeledFuture).
114    ///
115    /// # Examples
116    ///
117    /// ```
118    /// use async_reify::Trace;
119    ///
120    /// let shared = Trace::shared();
121    /// assert!(shared.lock().unwrap().events.is_empty());
122    /// ```
123    pub fn shared() -> Arc<Mutex<Trace>> {
124        Arc::new(Mutex::new(Trace::new()))
125    }
126
127    /// Append an event to the trace, computing its `offset` from the
128    /// trace's reference instant.
129    pub(crate) fn push(&mut self, result: PollResult, label: Option<String>) {
130        let step = self.events.len();
131        let offset = Instant::now().saturating_duration_since(self.start);
132        self.events.push(PollEvent {
133            step,
134            offset,
135            result,
136            label,
137        });
138    }
139}
140
141impl Default for Trace {
142    fn default() -> Self {
143        Trace::new()
144    }
145}
146
147/// A future wrapper that records each poll as a [`PollEvent`].
148///
149/// Use [`TracedFuture::run`] for a convenient way to execute a future
150/// and collect its trace.
151///
152/// If the wrapped future is dropped before it completes, a final event
153/// with [`PollResult::Cancelled`] is appended to the trace.
154///
155/// # Examples
156///
157/// ```
158/// use async_reify::TracedFuture;
159///
160/// # tokio_test::block_on(async {
161/// let (val, trace) = TracedFuture::run(async { 1 + 1 }).await;
162/// assert_eq!(val, 2);
163/// assert!(!trace.events.is_empty());
164/// # });
165/// ```
166pub struct TracedFuture<F> {
167    inner: Pin<Box<F>>,
168    trace: Arc<Mutex<Trace>>,
169    label: Option<String>,
170    completed: bool,
171}
172
173impl<F: Future> TracedFuture<F> {
174    /// Create a new traced future wrapping `inner`.
175    ///
176    /// # Examples
177    ///
178    /// ```
179    /// use async_reify::TracedFuture;
180    ///
181    /// let traced = TracedFuture::new(async { 42 });
182    /// ```
183    pub fn new(inner: F) -> Self {
184        Self {
185            inner: Box::pin(inner),
186            trace: Trace::shared(),
187            label: None,
188            completed: false,
189        }
190    }
191
192    /// Create a new traced future with a label.
193    ///
194    /// # Examples
195    ///
196    /// ```
197    /// use async_reify::TracedFuture;
198    ///
199    /// let traced = TracedFuture::with_label(async { 42 }, "my_step");
200    /// ```
201    pub fn with_label(inner: F, label: &str) -> Self {
202        Self {
203            inner: Box::pin(inner),
204            trace: Trace::shared(),
205            label: Some(label.to_string()),
206            completed: false,
207        }
208    }
209
210    /// Run the future to completion, returning the result and the trace.
211    ///
212    /// This is a convenience wrapper that polls the future through a
213    /// [`TracedFuture`] and collects all events.
214    ///
215    /// # Examples
216    ///
217    /// ```
218    /// use async_reify::{TracedFuture, PollResult};
219    ///
220    /// # tokio_test::block_on(async {
221    /// let (result, trace) = TracedFuture::run(async { "hello" }).await;
222    /// assert_eq!(result, "hello");
223    /// assert!(matches!(trace.events.last().unwrap().result, PollResult::Ready));
224    /// # });
225    /// ```
226    pub async fn run(inner: F) -> (F::Output, Trace) {
227        let trace = Trace::shared();
228        let traced = TracedFuture {
229            inner: Box::pin(inner),
230            trace: trace.clone(),
231            label: None,
232            completed: false,
233        };
234        let result = traced.await;
235        let trace = Arc::try_unwrap(trace)
236            .expect("trace arc should have single owner")
237            .into_inner()
238            .expect("trace mutex should not be poisoned");
239        (result, trace)
240    }
241}
242
243impl<F: Future> Future for TracedFuture<F> {
244    type Output = F::Output;
245
246    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
247        let this = self.get_mut();
248        let poll_result = this.inner.as_mut().poll(cx);
249
250        let result = match &poll_result {
251            Poll::Pending => PollResult::Pending,
252            Poll::Ready(_) => PollResult::Ready,
253        };
254
255        if matches!(result, PollResult::Ready) {
256            this.completed = true;
257        }
258
259        this.trace
260            .lock()
261            .expect("trace mutex should not be poisoned")
262            .push(result, this.label.clone());
263
264        poll_result
265    }
266}
267
268impl<F> Drop for TracedFuture<F> {
269    fn drop(&mut self) {
270        if !self.completed {
271            if let Ok(mut trace) = self.trace.lock() {
272                let last_was_pending = trace
273                    .events
274                    .last()
275                    .is_some_and(|e| matches!(e.result, PollResult::Pending));
276                if last_was_pending {
277                    trace.push(PollResult::Cancelled, self.label.clone());
278                }
279            }
280        }
281    }
282}
283
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn trace_immediate_future() {
        let (val, trace) = TracedFuture::run(async { 42 }).await;
        assert_eq!(val, 42);
        assert_eq!(trace.events.len(), 1);
        assert_eq!(trace.events[0].result, PollResult::Ready);
        assert_eq!(trace.events[0].step, 0);
    }

    #[tokio::test]
    async fn trace_multi_step() {
        let (val, trace) = TracedFuture::run(async {
            tokio::task::yield_now().await;
            tokio::task::yield_now().await;
            99
        })
        .await;
        assert_eq!(val, 99);
        assert!(trace.events.len() >= 3);
        assert_eq!(trace.events.last().unwrap().result, PollResult::Ready);
    }

    #[tokio::test]
    async fn steps_are_sequential_and_offsets_monotonic() {
        let (_, trace) = TracedFuture::run(async {
            tokio::task::yield_now().await;
            tokio::task::yield_now().await;
        })
        .await;
        for (index, event) in trace.events.iter().enumerate() {
            assert_eq!(event.step, index);
        }
        assert!(trace
            .events
            .windows(2)
            .all(|pair| pair[0].offset <= pair[1].offset));
    }

    #[tokio::test]
    async fn completed_future_records_no_cancellation() {
        let (_, trace) = TracedFuture::run(async {
            tokio::task::yield_now().await;
        })
        .await;
        assert!(trace
            .events
            .iter()
            .all(|e| e.result != PollResult::Cancelled));
    }

    #[tokio::test]
    async fn with_label() {
        let traced = TracedFuture::with_label(async { 1 }, "test_step");
        let trace = traced.trace.clone();
        let _ = traced.await;
        let trace = trace.lock().unwrap();
        assert_eq!(trace.events[0].label.as_deref(), Some("test_step"));
    }

    #[test]
    fn dropped_unpolled_future_records_nothing() {
        let traced = TracedFuture::new(async { 5 });
        let trace = traced.trace.clone();
        drop(traced);
        assert!(trace.lock().unwrap().events.is_empty());
    }

    #[test]
    fn dropped_pending_future_is_cancelled() {
        let trace_shared = Trace::shared();

        // A future that yields once (recording Pending) and would then return.
        // It is dropped after the first poll to simulate cancellation.
        struct YieldOnce {
            yielded: bool,
        }
        impl Future for YieldOnce {
            type Output = ();
            fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
                if self.yielded {
                    Poll::Ready(())
                } else {
                    self.yielded = true;
                    cx.waker().wake_by_ref();
                    Poll::Pending
                }
            }
        }

        let mut traced = TracedFuture {
            inner: Box::pin(YieldOnce { yielded: false }),
            trace: trace_shared.clone(),
            label: Some("drop_me".into()),
            completed: false,
        };

        // Manually poll once (records Pending) without driving to completion.
        let waker = futures_task::noop_waker();
        let mut cx = Context::from_waker(&waker);
        let _ = Pin::new(&mut traced).poll(&mut cx);
        drop(traced);

        let trace = trace_shared.lock().unwrap();
        assert_eq!(
            trace.events.len(),
            2,
            "expected Pending then Cancelled, got {:?}",
            trace.events
        );
        assert_eq!(trace.events[0].result, PollResult::Pending);
        let cancelled = &trace.events[1];
        assert_eq!(cancelled.result, PollResult::Cancelled);
        assert_eq!(cancelled.step, 1);
        assert_eq!(cancelled.label.as_deref(), Some("drop_me"));
    }

    #[cfg(feature = "serde")]
    #[tokio::test]
    async fn trace_round_trip_serde() {
        let (_, trace) = TracedFuture::run(async {
            tokio::task::yield_now().await;
            7
        })
        .await;
        let json = serde_json::to_string(&trace).expect("serialize");
        let restored: Trace = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(restored.events.len(), trace.events.len());
        for (a, b) in trace.events.iter().zip(restored.events.iter()) {
            assert_eq!(a.step, b.step);
            assert_eq!(a.offset, b.offset);
            assert_eq!(a.result, b.result);
            assert_eq!(a.label, b.label);
        }
    }
}