1use std::future::Future;
19use std::pin::Pin;
20use std::task::{Context, Poll};
21use std::time::{Duration, Instant};
22
23use dev_report::{CheckResult, Evidence, Severity};
24
25pub async fn detect_blocking<F, T>(
52 name: impl Into<String>,
53 max_no_yield: Duration,
54 fut: F,
55) -> (CheckResult, T)
56where
57 F: Future<Output = T>,
58{
59 let name = name.into();
60 let started = Instant::now();
61 let monitor = BlockingMonitor::new(fut, max_no_yield);
62 tokio::pin!(monitor);
63 let value = monitor.as_mut().await;
64 let elapsed = started.elapsed();
65 let max_observed = monitor.max_observed_no_yield();
66
67 let evidence = vec![
68 Evidence::numeric("elapsed_ms", elapsed.as_millis() as f64),
69 Evidence::numeric("max_no_yield_ms", max_observed.as_millis() as f64),
70 Evidence::numeric("threshold_ms", max_no_yield.as_millis() as f64),
71 ];
72
73 let check = if max_observed > max_no_yield {
74 let mut c =
75 CheckResult::warn(format!("async::{name}"), Severity::Warning).with_detail(format!(
76 "longest non-yielding poll was {:?}, exceeds threshold {:?}",
77 max_observed, max_no_yield
78 ));
79 c.tags = vec!["async".to_string(), "blocking_suspected".to_string()];
80 c.evidence = evidence;
81 c
82 } else {
83 let mut c = CheckResult::pass(format!("async::{name}"))
84 .with_duration_ms(elapsed.as_millis() as u64);
85 c.tags = vec!["async".to_string()];
86 c.evidence = evidence;
87 c
88 };
89 (check, value)
90}
91
92pin_project_lite::pin_project! {
93 struct BlockingMonitor<F: Future> {
94 #[pin]
95 inner: F,
96 threshold: Duration,
97 max_observed: Duration,
98 }
99}
100
101impl<F: Future> BlockingMonitor<F> {
102 fn new(inner: F, threshold: Duration) -> Self {
103 Self {
104 inner,
105 threshold,
106 max_observed: Duration::ZERO,
107 }
108 }
109
110 fn max_observed_no_yield(self: Pin<&mut Self>) -> Duration {
111 *self.project().max_observed
112 }
113}
114
115impl<F: Future> Future for BlockingMonitor<F> {
116 type Output = F::Output;
117
118 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<F::Output> {
119 let this = self.project();
120 let started = Instant::now();
121 let result = this.inner.poll(cx);
122 let elapsed = started.elapsed();
123 if elapsed > *this.max_observed {
124 *this.max_observed = elapsed;
125 }
126 result
127 }
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use dev_report::Verdict;

    /// A future that completes on its first, trivial poll stays under the
    /// threshold, and its output is handed back unchanged.
    #[tokio::test]
    async fn fast_future_passes() {
        let threshold = Duration::from_millis(50);
        let (result, output) = detect_blocking("op", threshold, async { 42 }).await;
        assert_eq!(result.verdict, Verdict::Pass);
        assert_eq!(output, 42);
    }

    /// A synchronous sleep keeps a single poll busy for at least 20ms, which
    /// must trip a 5ms threshold.
    #[tokio::test]
    async fn long_blocking_section_warns() {
        let threshold = Duration::from_millis(5);
        let blocking_work = async {
            // Deliberately blocks the executor thread: this is exactly the
            // behaviour the detector is meant to flag.
            std::thread::sleep(Duration::from_millis(20));
        };
        let (result, _) = detect_blocking("op", threshold, blocking_work).await;
        assert_eq!(result.verdict, Verdict::Warn);
        assert!(result.has_tag("blocking_suspected"));
    }

    /// The report always carries the observed maximum and the threshold.
    #[tokio::test]
    async fn evidence_includes_max_no_yield() {
        let threshold = Duration::from_millis(50);
        let (result, _) = detect_blocking("op", threshold, async {}).await;
        let has_label = |wanted: &str| {
            result
                .evidence
                .iter()
                .any(|item| item.label.as_str() == wanted)
        };
        assert!(has_label("max_no_yield_ms"));
        assert!(has_label("threshold_ms"));
    }
}