Skip to main content

dev_async/
shutdown.rs

1//! Graceful-shutdown verification.
2//!
3//! [`ShutdownProbe`] watches a set of components and emits one
4//! [`CheckResult`] per component plus an aggregate. A component is
5//! considered "drained" when its predicate returns `true`.
6
7use std::future::Future;
8use std::pin::Pin;
9use std::time::{Duration, Instant};
10
11use dev_report::{CheckResult, Evidence, Severity};
12
/// Type-erased drain predicate for a single component.
///
/// Every call produces a fresh boxed future. That future resolves to `true`
/// once the component has reached a clean stopped state, and to `false` while
/// it is still winding down. The shutdown probe calls the predicate repeatedly,
/// pausing between calls, until it yields `true` or the deadline elapses.
pub type DrainCheck = Box<
    dyn Fn() -> Pin<Box<dyn Future<Output = bool> + Send + 'static>> + Send + Sync + 'static,
>;

/// One named participant in a shutdown.
///
/// It pairs a label with the predicate that reports whether that participant
/// has finished draining.
pub struct ShutdownComponent {
    /// Human-readable label; appended to the probe's group in check names.
    name: String,
    /// Type-erased predicate the probe evaluates until it reports `true`.
    drain_check: DrainCheck,
}

impl ShutdownComponent {
    /// Create a component called `name` whose drained state is reported by
    /// `drain_check`.
    ///
    /// `drain_check` is invoked afresh on every poll. The future it returns
    /// MUST resolve to `true` once the component is fully drained.
    pub fn new<F, Fut>(name: impl Into<String>, drain_check: F) -> Self
    where
        F: Fn() -> Fut + Send + Sync + 'static,
        Fut: Future<Output = bool> + Send + 'static,
    {
        // Erase the concrete closure and future types so components with
        // different predicates can be stored side by side in one `Vec`.
        let erased = move || -> Pin<Box<dyn Future<Output = bool> + Send>> {
            Box::pin(drain_check())
        };
        Self {
            name: name.into(),
            drain_check: Box::new(erased),
        }
    }
}
44
45/// Polls a set of [`ShutdownComponent`]s until they all drain or the
46/// deadline elapses.
47///
48/// # Example
49///
50/// ```no_run
51/// use dev_async::shutdown::{ShutdownComponent, ShutdownProbe};
52/// use std::sync::atomic::{AtomicBool, Ordering};
53/// use std::sync::Arc;
54/// use std::time::Duration;
55///
56/// # async fn ex() {
57/// let drained = Arc::new(AtomicBool::new(true));
58/// let comp = {
59///     let drained = drained.clone();
60///     ShutdownComponent::new("worker", move || {
61///         let d = drained.clone();
62///         async move { d.load(Ordering::Relaxed) }
63///     })
64/// };
65///
66/// let probe = ShutdownProbe::new("system")
67///     .deadline(Duration::from_millis(200))
68///     .poll_interval(Duration::from_millis(10))
69///     .with_component(comp);
70///
71/// let checks = probe.run().await;
72/// assert!(!checks.is_empty());
73/// # }
74/// ```
75pub struct ShutdownProbe {
76    name: String,
77    components: Vec<ShutdownComponent>,
78    deadline: Duration,
79    poll_interval: Duration,
80}
81
82impl ShutdownProbe {
83    /// Begin building a probe with a stable name.
84    pub fn new(name: impl Into<String>) -> Self {
85        Self {
86            name: name.into(),
87            components: Vec::new(),
88            deadline: Duration::from_secs(5),
89            poll_interval: Duration::from_millis(50),
90        }
91    }
92
93    /// Maximum time to wait for the system to drain.
94    pub fn deadline(mut self, d: Duration) -> Self {
95        self.deadline = d;
96        self
97    }
98
99    /// How often to re-evaluate each component's drain predicate.
100    pub fn poll_interval(mut self, d: Duration) -> Self {
101        self.poll_interval = d;
102        self
103    }
104
105    /// Add a component to the probe.
106    pub fn with_component(mut self, component: ShutdownComponent) -> Self {
107        self.components.push(component);
108        self
109    }
110
111    /// Run the probe and return one [`CheckResult`] per component plus
112    /// an aggregate.
113    ///
114    /// Per-component verdicts:
115    /// - Drained before deadline -> `Pass` with `elapsed_ms` evidence.
116    /// - Did not drain in time -> `Fail (Error)` with `not_drained` tag.
117    ///
118    /// The aggregate verdict is `Fail` if any component failed,
119    /// otherwise `Pass`. It is the last entry in the returned vector
120    /// and tagged `aggregate`.
121    pub async fn run(self) -> Vec<CheckResult> {
122        let group = self.name;
123        let deadline = self.deadline;
124        let interval = self.poll_interval;
125        let started = Instant::now();
126        let mut results = Vec::with_capacity(self.components.len() + 1);
127        let mut failed_any = false;
128
129        for comp in self.components {
130            let comp_name = format!("async::shutdown::{group}::{}", comp.name);
131            let comp_started = Instant::now();
132            let mut drained = false;
133            loop {
134                let elapsed_total = started.elapsed();
135                if elapsed_total >= deadline {
136                    break;
137                }
138                if (comp.drain_check)().await {
139                    drained = true;
140                    break;
141                }
142                tokio::time::sleep(interval).await;
143            }
144            let elapsed = comp_started.elapsed();
145            let evidence = vec![
146                Evidence::numeric("elapsed_ms", elapsed.as_millis() as f64),
147                Evidence::numeric("deadline_ms", deadline.as_millis() as f64),
148                Evidence::numeric("poll_interval_ms", interval.as_millis() as f64),
149            ];
150            let mut check = if drained {
151                let mut c = CheckResult::pass(comp_name)
152                    .with_duration_ms(elapsed.as_millis() as u64)
153                    .with_detail(format!("drained in {elapsed:?}"));
154                c.tags = vec!["async".to_string(), "shutdown".to_string()];
155                c
156            } else {
157                failed_any = true;
158                let mut c = CheckResult::fail(comp_name, Severity::Error)
159                    .with_detail(format!("did not drain within {deadline:?}"));
160                c.tags = vec![
161                    "async".to_string(),
162                    "shutdown".to_string(),
163                    "not_drained".to_string(),
164                    "regression".to_string(),
165                ];
166                c
167            };
168            check.evidence = evidence;
169            results.push(check);
170        }
171
172        let total_elapsed = started.elapsed();
173        let aggregate_name = format!("async::shutdown::{group}");
174        let evidence = vec![
175            Evidence::numeric("components", results.len() as f64),
176            Evidence::numeric("elapsed_ms", total_elapsed.as_millis() as f64),
177            Evidence::numeric("deadline_ms", deadline.as_millis() as f64),
178        ];
179        let mut aggregate = if failed_any {
180            let mut c = CheckResult::fail(aggregate_name, Severity::Error)
181                .with_detail("one or more components did not drain");
182            c.tags = vec![
183                "async".to_string(),
184                "shutdown".to_string(),
185                "aggregate".to_string(),
186                "regression".to_string(),
187            ];
188            c
189        } else {
190            let mut c = CheckResult::pass(aggregate_name)
191                .with_duration_ms(total_elapsed.as_millis() as u64)
192                .with_detail("all components drained");
193            c.tags = vec![
194                "async".to_string(),
195                "shutdown".to_string(),
196                "aggregate".to_string(),
197            ];
198            c
199        };
200        aggregate.evidence = evidence;
201        results.push(aggregate);
202        results
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use dev_report::{EvidenceData, Verdict};
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    /// Probe named `name` with the given deadline and the 5 ms poll
    /// interval that every test in this module uses.
    fn probe_with(name: &str, deadline_ms: u64) -> ShutdownProbe {
        ShutdownProbe::new(name)
            .deadline(Duration::from_millis(deadline_ms))
            .poll_interval(Duration::from_millis(5))
    }

    #[tokio::test]
    async fn already_drained_component_passes() {
        let results = probe_with("sys", 100)
            .with_component(ShutdownComponent::new("done", || async { true }))
            .run()
            .await;
        assert_eq!(results.len(), 2);
        let (component, aggregate) = (&results[0], &results[1]);
        assert_eq!(component.verdict, Verdict::Pass);
        assert_eq!(aggregate.verdict, Verdict::Pass);
        assert!(aggregate.has_tag("aggregate"));
    }

    #[tokio::test]
    async fn never_draining_component_fails() {
        let results = probe_with("sys", 50)
            .with_component(ShutdownComponent::new("hung", || async { false }))
            .run()
            .await;
        let component = &results[0];
        assert_eq!(component.verdict, Verdict::Fail);
        assert!(component.has_tag("not_drained"));
        assert_eq!(results[1].verdict, Verdict::Fail);
    }

    #[tokio::test]
    async fn component_drains_eventually() {
        let flag = Arc::new(AtomicBool::new(false));

        // A background task flips the flag after 30ms.
        let setter = Arc::clone(&flag);
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(30)).await;
            setter.store(true, Ordering::Relaxed);
        });

        let reader = Arc::clone(&flag);
        let comp = ShutdownComponent::new("delayed", move || {
            let reader = Arc::clone(&reader);
            async move { reader.load(Ordering::Relaxed) }
        });

        let results = probe_with("sys", 200).with_component(comp).run().await;
        assert_eq!(results[0].verdict, Verdict::Pass);
    }

    #[tokio::test]
    async fn aggregate_includes_component_evidence_count() {
        let results = probe_with("multi", 50)
            .with_component(ShutdownComponent::new("a", || async { true }))
            .with_component(ShutdownComponent::new("b", || async { true }))
            .run()
            .await;
        assert_eq!(results.len(), 3); // 2 components + aggregate

        let aggregate = results.last().unwrap();
        let count = aggregate
            .evidence
            .iter()
            .find(|e| e.label == "components")
            .unwrap();
        match count.data {
            EvidenceData::Numeric(n) => assert_eq!(n, 2.0),
            _ => panic!("expected numeric"),
        }
    }
}