1use crate::traced::{PollResult, Trace};
4use std::future::Future;
5use std::pin::Pin;
6use std::sync::{Arc, Mutex};
7use std::task::{Context, Poll};
8
9pub struct LabeledFuture<F> {
37 inner: Pin<Box<F>>,
38 label: String,
39 trace: Arc<Mutex<Trace>>,
40 completed: bool,
41}
42
43impl<F: Future> LabeledFuture<F> {
44 pub fn new(inner: F, label: &str, trace: Arc<Mutex<Trace>>) -> Self {
55 Self {
56 inner: Box::pin(inner),
57 label: label.to_string(),
58 trace,
59 completed: false,
60 }
61 }
62}
63
64impl<F: Future> Future for LabeledFuture<F> {
65 type Output = F::Output;
66
67 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
68 let this = self.get_mut();
69 let poll_result = this.inner.as_mut().poll(cx);
70
71 let result = match &poll_result {
72 Poll::Pending => PollResult::Pending,
73 Poll::Ready(_) => PollResult::Ready,
74 };
75
76 if matches!(result, PollResult::Ready) {
77 this.completed = true;
78 }
79
80 this.trace
81 .lock()
82 .expect("trace mutex should not be poisoned")
83 .push(result, Some(this.label.clone()));
84
85 poll_result
86 }
87}
88
89impl<F> Drop for LabeledFuture<F> {
90 fn drop(&mut self) {
91 if !self.completed {
92 if let Ok(mut trace) = self.trace.lock() {
93 let last_was_pending = trace
94 .events
95 .last()
96 .is_some_and(|e| matches!(e.result, PollResult::Pending));
97 if last_was_pending {
98 trace.push(PollResult::Cancelled, Some(self.label.clone()));
99 }
100 }
101 }
102 }
103}
104
/// Wraps a future in a `LabeledFuture` labeled with its own source text and call site.
///
/// The label has the form `"<source text of $fut> @ <file>:<line>"`. It is assembled
/// entirely at compile time, with `line!()` resolving to the line of the
/// `labeled_await!` invocation. `$trace` is `.clone()`d, so pass the shared trace
/// handle (presumably an `Arc<Mutex<Trace>>`). The returned wrapper still has to be
/// awaited, e.g. `labeled_await!(fetch(), trace).await`.
///
/// NOTE(review): the expansion names `$crate::LabeledFuture`, so the type must be
/// reachable at the crate root.
#[macro_export]
macro_rules! labeled_await {
    ($fut:expr, $trace:expr) => {
        $crate::LabeledFuture::new(
            $fut,
            concat!(stringify!($fut), " @ ", file!(), ":", line!()),
            $trace.clone(),
        )
    };
}
129
#[cfg(test)]
mod tests {
    use super::*;

    /// Polls `fut` exactly once on the current task and reports whether it was pending.
    async fn poll_once_is_pending<F: Future + Unpin>(fut: &mut F) -> bool {
        std::future::poll_fn(|cx| Poll::Ready(Pin::new(&mut *fut).poll(cx).is_pending())).await
    }

    #[tokio::test]
    async fn labeled_future_records_label() {
        let trace = Trace::shared();
        let fut = LabeledFuture::new(async { "hello" }, "greet_step", trace.clone());
        let val = fut.await;
        assert_eq!(val, "hello");
        let trace = trace.lock().unwrap();
        assert!(!trace.events.is_empty());
        assert_eq!(
            trace.events.last().unwrap().label.as_deref(),
            Some("greet_step")
        );
    }

    #[tokio::test]
    async fn multiple_labeled_futures_share_log() {
        let trace = Trace::shared();

        let fut1 = LabeledFuture::new(async { 1 }, "step_a", trace.clone());
        let _ = fut1.await;

        let fut2 = LabeledFuture::new(async { 2 }, "step_b", trace.clone());
        let _ = fut2.await;

        let trace = trace.lock().unwrap();
        assert_eq!(trace.events.len(), 2);
        assert_eq!(trace.events[0].label.as_deref(), Some("step_a"));
        assert_eq!(trace.events[1].label.as_deref(), Some("step_b"));
    }

    #[tokio::test]
    async fn labeled_await_macro() {
        let trace = Trace::shared();
        let val = labeled_await!(async { 42 }, trace).await;
        assert_eq!(val, 42);
        let trace = trace.lock().unwrap();
        assert!(!trace.events.is_empty());
        let label = trace.events[0].label.as_ref().unwrap();
        // The macro embeds its call site's `file!()`; compare against this file's
        // own path instead of a hard-coded name so the test survives renames.
        assert!(label.contains(" @ "));
        assert!(label.contains(file!()));
    }

    #[tokio::test]
    async fn dropping_pending_future_records_cancellation() {
        let trace = Trace::shared();
        let mut fut =
            LabeledFuture::new(std::future::pending::<()>(), "stalled", trace.clone());

        // Observe `Pending` once, then abandon the future without completing it.
        assert!(poll_once_is_pending(&mut fut).await);
        drop(fut);

        let trace = trace.lock().unwrap();
        assert_eq!(trace.events.len(), 2);
        assert!(matches!(trace.events[0].result, PollResult::Pending));
        assert!(matches!(trace.events[1].result, PollResult::Cancelled));
        assert_eq!(trace.events[1].label.as_deref(), Some("stalled"));
    }

    #[tokio::test]
    async fn dropping_unpolled_future_records_nothing() {
        let trace = Trace::shared();
        drop(LabeledFuture::new(
            std::future::pending::<()>(),
            "idle",
            trace.clone(),
        ));
        assert!(trace.lock().unwrap().events.is_empty());
    }
}