Skip to main content

async_reify/
labeled.rs

1//! [`LabeledFuture`]: attach source-location labels to await points.
2
3use crate::traced::{PollResult, Trace};
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::{Arc, Mutex};
7use std::task::{Context, Poll};
8
9/// A future that records poll events with a source label into a shared
10/// [`Trace`].
11///
12/// Unlike [`TracedFuture`](crate::TracedFuture), `LabeledFuture` pushes
13/// events into a caller-supplied [`Trace`] (wrapped in `Arc<Mutex<_>>`),
14/// allowing multiple labeled futures to contribute to a single trace
15/// with a shared time origin.
16///
17/// If the wrapped future is dropped before completing, a final
18/// [`PollResult::Cancelled`] event is recorded.
19///
20/// # Examples
21///
22/// ```
23/// use async_reify::{LabeledFuture, Trace};
24///
25/// # tokio_test::block_on(async {
26/// let trace = Trace::shared();
27/// let fut = LabeledFuture::new(async { 42 }, "fetch_data", trace.clone());
28/// let val = fut.await;
29/// assert_eq!(val, 42);
30/// assert_eq!(
31///     trace.lock().unwrap().events[0].label.as_deref(),
32///     Some("fetch_data"),
33/// );
34/// # });
35/// ```
36pub struct LabeledFuture<F> {
37    inner: Pin<Box<F>>,
38    label: String,
39    trace: Arc<Mutex<Trace>>,
40    completed: bool,
41}
42
43impl<F: Future> LabeledFuture<F> {
44    /// Create a labeled future that logs to the shared [`Trace`].
45    ///
46    /// # Examples
47    ///
48    /// ```
49    /// use async_reify::{LabeledFuture, Trace};
50    ///
51    /// let trace = Trace::shared();
52    /// let _fut = LabeledFuture::new(async { 1 }, "step_1", trace);
53    /// ```
54    pub fn new(inner: F, label: &str, trace: Arc<Mutex<Trace>>) -> Self {
55        Self {
56            inner: Box::pin(inner),
57            label: label.to_string(),
58            trace,
59            completed: false,
60        }
61    }
62}
63
64impl<F: Future> Future for LabeledFuture<F> {
65    type Output = F::Output;
66
67    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
68        let this = self.get_mut();
69        let poll_result = this.inner.as_mut().poll(cx);
70
71        let result = match &poll_result {
72            Poll::Pending => PollResult::Pending,
73            Poll::Ready(_) => PollResult::Ready,
74        };
75
76        if matches!(result, PollResult::Ready) {
77            this.completed = true;
78        }
79
80        this.trace
81            .lock()
82            .expect("trace mutex should not be poisoned")
83            .push(result, Some(this.label.clone()));
84
85        poll_result
86    }
87}
88
89impl<F> Drop for LabeledFuture<F> {
90    fn drop(&mut self) {
91        if !self.completed {
92            if let Ok(mut trace) = self.trace.lock() {
93                let last_was_pending = trace
94                    .events
95                    .last()
96                    .is_some_and(|e| matches!(e.result, PollResult::Pending));
97                if last_was_pending {
98                    trace.push(PollResult::Cancelled, Some(self.label.clone()));
99                }
100            }
101        }
102    }
103}
104
/// Helper macro to create a [`LabeledFuture`] with automatic source labeling.
///
/// The label has the form `"<expression> @ <file>:<line>"`:
/// - `<expression>` is the source text of the future expression.
/// - `<file>:<line>` is the location where the macro is invoked, not where
///   it is defined.
///
/// The label text is assembled at compile time with `concat!`, so no
/// formatting work happens at runtime.
///
/// The trace argument is `.clone()`d, so an `Arc<Mutex<Trace>>` variable
/// (or a reference to one) remains usable afterwards. The macro does not
/// await anything itself; `.await` the returned future. A trailing comma is
/// accepted.
///
/// # Examples
///
/// ```
/// use async_reify::{labeled_await, Trace};
///
/// # tokio_test::block_on(async {
/// let trace = Trace::shared();
/// let val = labeled_await!(async { 42 }, trace).await;
/// assert_eq!(val, 42);
/// let label = trace.lock().unwrap().events[0].label.as_ref().unwrap().clone();
/// assert!(label.contains("labeled.rs")); // the invoking source file
/// # });
/// ```
#[macro_export]
macro_rules! labeled_await {
    ($fut:expr, $trace:expr $(,)?) => {
        $crate::LabeledFuture::new(
            $fut,
            concat!(stringify!($fut), " @ ", file!(), ":", line!()),
            $trace.clone(),
        )
    };
}
129
#[cfg(test)]
mod tests {
    use super::*;
    use std::task::{Wake, Waker};

    /// Waker that ignores wake-ups; lets tests drive a future by hand.
    struct NoopWake;

    impl Wake for NoopWake {
        fn wake(self: Arc<Self>) {}
    }

    /// Future that returns `Pending` on its first poll and `Ready(())` after.
    struct PendingOnce {
        polled: bool,
    }

    impl Future for PendingOnce {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
            if self.polled {
                Poll::Ready(())
            } else {
                self.polled = true;
                Poll::Pending
            }
        }
    }

    #[tokio::test]
    async fn labeled_future_records_label() {
        let trace = Trace::shared();
        let fut = LabeledFuture::new(async { "hello" }, "greet_step", trace.clone());
        let val = fut.await;
        assert_eq!(val, "hello");
        let trace = trace.lock().unwrap();
        assert!(!trace.events.is_empty());
        assert_eq!(
            trace.events.last().unwrap().label.as_deref(),
            Some("greet_step")
        );
    }

    #[tokio::test]
    async fn multiple_labeled_futures_share_log() {
        let trace = Trace::shared();

        let fut1 = LabeledFuture::new(async { 1 }, "step_a", trace.clone());
        let _ = fut1.await;

        let fut2 = LabeledFuture::new(async { 2 }, "step_b", trace.clone());
        let _ = fut2.await;

        let trace = trace.lock().unwrap();
        assert_eq!(trace.events.len(), 2);
        assert_eq!(trace.events[0].label.as_deref(), Some("step_a"));
        assert_eq!(trace.events[1].label.as_deref(), Some("step_b"));
    }

    #[tokio::test]
    async fn labeled_await_macro() {
        let trace = Trace::shared();
        let val = labeled_await!(async { 42 }, trace).await;
        assert_eq!(val, 42);
        let trace = trace.lock().unwrap();
        assert!(!trace.events.is_empty());
        let label = trace.events[0].label.as_ref().unwrap();
        assert!(label.contains("labeled.rs"));
    }

    /// A future that was polled (and returned `Pending`) and is then dropped
    /// must leave a final `Cancelled` event under its own label.
    #[test]
    fn dropping_pending_future_records_cancellation() {
        let trace = Trace::shared();
        let waker = Waker::from(Arc::new(NoopWake));
        let mut cx = Context::from_waker(&waker);

        let mut fut = LabeledFuture::new(PendingOnce { polled: false }, "slow_step", trace.clone());
        assert!(Pin::new(&mut fut).poll(&mut cx).is_pending());
        drop(fut);

        let trace = trace.lock().unwrap();
        assert_eq!(trace.events.len(), 2);
        assert!(matches!(trace.events[0].result, PollResult::Pending));
        assert!(matches!(trace.events[1].result, PollResult::Cancelled));
        assert_eq!(trace.events[1].label.as_deref(), Some("slow_step"));
    }

    /// A future that is dropped without ever being polled records nothing.
    #[test]
    fn dropping_unpolled_future_records_nothing() {
        let trace = Trace::shared();
        drop(LabeledFuture::new(async { 1 }, "never_polled", trace.clone()));
        assert!(trace.lock().unwrap().events.is_empty());
    }
}
174}