1use std::future::Future;
8use std::pin::Pin;
9use std::time::{Duration, Instant};
10
11use dev_report::{CheckResult, Evidence, Severity};
12
/// Type-erased drain check: every call produces a fresh boxed future that
/// resolves to `true` once the component reports it has finished draining.
pub type DrainCheck =
    Box<dyn Fn() -> Pin<Box<dyn Future<Output = bool> + Send>> + Send + Sync>;

/// A named component whose drain status is polled by a `ShutdownProbe`.
pub struct ShutdownComponent {
    /// Component name; the probe appends it to its group name when naming checks.
    name: String,
    /// Check invoked once per poll attempt.
    drain_check: DrainCheck,
}

impl ShutdownComponent {
    /// Builds a component called `name` from an async drain-check closure.
    ///
    /// `drain_check` may be called many times; each call must return a future
    /// resolving to `true` once draining is complete. Both the closure and the
    /// futures it returns are boxed, so components built from different
    /// closure types can be stored side by side.
    pub fn new<F, Fut>(name: impl Into<String>, drain_check: F) -> Self
    where
        F: Fn() -> Fut + Send + Sync + 'static,
        Fut: Future<Output = bool> + Send + 'static,
    {
        let name = name.into();
        Self {
            name,
            // The explicit return type drives the unsized coercion of the
            // pinned concrete future into the boxed trait object.
            drain_check: Box::new(move || -> Pin<Box<dyn Future<Output = bool> + Send>> {
                Box::pin(drain_check())
            }),
        }
    }
}
44
45pub struct ShutdownProbe {
76 name: String,
77 components: Vec<ShutdownComponent>,
78 deadline: Duration,
79 poll_interval: Duration,
80}
81
82impl ShutdownProbe {
83 pub fn new(name: impl Into<String>) -> Self {
85 Self {
86 name: name.into(),
87 components: Vec::new(),
88 deadline: Duration::from_secs(5),
89 poll_interval: Duration::from_millis(50),
90 }
91 }
92
93 pub fn deadline(mut self, d: Duration) -> Self {
95 self.deadline = d;
96 self
97 }
98
99 pub fn poll_interval(mut self, d: Duration) -> Self {
101 self.poll_interval = d;
102 self
103 }
104
105 pub fn with_component(mut self, component: ShutdownComponent) -> Self {
107 self.components.push(component);
108 self
109 }
110
111 pub async fn run(self) -> Vec<CheckResult> {
122 let group = self.name;
123 let deadline = self.deadline;
124 let interval = self.poll_interval;
125 let started = Instant::now();
126 let mut results = Vec::with_capacity(self.components.len() + 1);
127 let mut failed_any = false;
128
129 for comp in self.components {
130 let comp_name = format!("async::shutdown::{group}::{}", comp.name);
131 let comp_started = Instant::now();
132 let mut drained = false;
133 loop {
134 let elapsed_total = started.elapsed();
135 if elapsed_total >= deadline {
136 break;
137 }
138 if (comp.drain_check)().await {
139 drained = true;
140 break;
141 }
142 tokio::time::sleep(interval).await;
143 }
144 let elapsed = comp_started.elapsed();
145 let evidence = vec![
146 Evidence::numeric("elapsed_ms", elapsed.as_millis() as f64),
147 Evidence::numeric("deadline_ms", deadline.as_millis() as f64),
148 Evidence::numeric("poll_interval_ms", interval.as_millis() as f64),
149 ];
150 let mut check = if drained {
151 let mut c = CheckResult::pass(comp_name)
152 .with_duration_ms(elapsed.as_millis() as u64)
153 .with_detail(format!("drained in {elapsed:?}"));
154 c.tags = vec!["async".to_string(), "shutdown".to_string()];
155 c
156 } else {
157 failed_any = true;
158 let mut c = CheckResult::fail(comp_name, Severity::Error)
159 .with_detail(format!("did not drain within {deadline:?}"));
160 c.tags = vec![
161 "async".to_string(),
162 "shutdown".to_string(),
163 "not_drained".to_string(),
164 "regression".to_string(),
165 ];
166 c
167 };
168 check.evidence = evidence;
169 results.push(check);
170 }
171
172 let total_elapsed = started.elapsed();
173 let aggregate_name = format!("async::shutdown::{group}");
174 let evidence = vec![
175 Evidence::numeric("components", results.len() as f64),
176 Evidence::numeric("elapsed_ms", total_elapsed.as_millis() as f64),
177 Evidence::numeric("deadline_ms", deadline.as_millis() as f64),
178 ];
179 let mut aggregate = if failed_any {
180 let mut c = CheckResult::fail(aggregate_name, Severity::Error)
181 .with_detail("one or more components did not drain");
182 c.tags = vec![
183 "async".to_string(),
184 "shutdown".to_string(),
185 "aggregate".to_string(),
186 "regression".to_string(),
187 ];
188 c
189 } else {
190 let mut c = CheckResult::pass(aggregate_name)
191 .with_duration_ms(total_elapsed.as_millis() as u64)
192 .with_detail("all components drained");
193 c.tags = vec![
194 "async".to_string(),
195 "shutdown".to_string(),
196 "aggregate".to_string(),
197 ];
198 c
199 };
200 aggregate.evidence = evidence;
201 results.push(aggregate);
202 results
203 }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use dev_report::Verdict;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    /// Probe for `group` with the given deadline and a fast 5 ms poll interval.
    fn fast_probe(group: &str, deadline_ms: u64) -> ShutdownProbe {
        ShutdownProbe::new(group)
            .deadline(Duration::from_millis(deadline_ms))
            .poll_interval(Duration::from_millis(5))
    }

    #[tokio::test]
    async fn already_drained_component_passes() {
        let results = fast_probe("sys", 100)
            .with_component(ShutdownComponent::new("done", || async { true }))
            .run()
            .await;
        assert_eq!(results.len(), 2);
        let (component, aggregate) = (&results[0], &results[1]);
        assert_eq!(component.verdict, Verdict::Pass);
        assert_eq!(aggregate.verdict, Verdict::Pass);
        assert!(aggregate.has_tag("aggregate"));
    }

    #[tokio::test]
    async fn never_draining_component_fails() {
        let results = fast_probe("sys", 50)
            .with_component(ShutdownComponent::new("hung", || async { false }))
            .run()
            .await;
        let (component, aggregate) = (&results[0], &results[1]);
        assert_eq!(component.verdict, Verdict::Fail);
        assert!(component.has_tag("not_drained"));
        assert_eq!(aggregate.verdict, Verdict::Fail);
    }

    #[tokio::test]
    async fn component_drains_eventually() {
        let flag = Arc::new(AtomicBool::new(false));

        // Flip the flag shortly after the probe starts polling.
        let setter = Arc::clone(&flag);
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(30)).await;
            setter.store(true, Ordering::Relaxed);
        });

        let reader = Arc::clone(&flag);
        let delayed = ShutdownComponent::new("delayed", move || {
            let f = Arc::clone(&reader);
            async move { f.load(Ordering::Relaxed) }
        });

        let results = fast_probe("sys", 200).with_component(delayed).run().await;
        assert_eq!(results[0].verdict, Verdict::Pass);
    }

    #[tokio::test]
    async fn aggregate_includes_component_evidence_count() {
        let results = fast_probe("multi", 50)
            .with_component(ShutdownComponent::new("a", || async { true }))
            .with_component(ShutdownComponent::new("b", || async { true }))
            .run()
            .await;
        assert_eq!(results.len(), 3);

        let aggregate = results.last().expect("aggregate result is always last");
        let components = aggregate
            .evidence
            .iter()
            .find(|e| e.label == "components")
            .expect("aggregate carries a `components` evidence entry");
        match components.data {
            dev_report::EvidenceData::Numeric(n) => assert_eq!(n, 2.0),
            _ => panic!("expected numeric"),
        }
    }
}