Skip to main content

dev_async/
cancellation_safety.rs

1//! Cancellation-safety verification.
2//!
3//! A future is *cancellation-safe* (per the tokio definition) if
4//! dropping it mid-poll leaves observable state unchanged compared
5//! to never having polled it. Cancellation-unsafety is a common
6//! source of data corruption: a `select!` arm that wins drops the
7//! losing arms, and if those losing arms had already partially
8//! completed visible work (writing bytes, advancing a cursor,
9//! holding a lock), the system is now in an inconsistent state.
10//!
//! [`check_cancel_safe`] drives a future until a fixed deadline. If
//! the deadline fires first, the future is dropped and a
//! caller-supplied predicate is asked whether the observable state
//! is still consistent. The verdict reports the predicate's answer.
15//!
16//! ## What this catches
17//!
18//! - Futures that buffer writes but flush mid-await.
19//! - Futures that consume from a stream and acknowledge before yielding.
20//! - State machines that advance internally then await on the next stage.
21//!
22//! ## What this does NOT catch
23//!
24//! - Cancellation issues that only manifest under specific schedules.
25//! - Issues in nested futures the test doesn't directly inspect.
26//!
27//! Treat the verdict as a strong signal, not a proof.
28
29use std::future::Future;
30use std::time::{Duration, Instant};
31
32use dev_report::{CheckResult, Evidence, Severity};
33
34/// Drive `fut` for at most `cancel_at` duration, then drop it. After
35/// the drop, run `assert_safe` to ask the caller whether the
36/// observable state is still consistent. Emit a [`CheckResult`].
37///
38/// Verdicts:
39/// - `assert_safe()` returns `true` -> `Pass` with `cancellation_safe`.
40/// - `assert_safe()` returns `false` -> `Fail (Critical)` with
41///   `cancellation_unsafe` + `regression` tags. State has been
42///   corrupted by the cancellation.
43/// - `fut` completes before `cancel_at` -> `Skip` with detail
44///   ("future completed before cancellation"). The check did not
45///   exercise cancellation; tighten `cancel_at` or run a slower
46///   future.
47///
48/// `assert_safe` is invoked synchronously after the future has been
49/// dropped. It must not panic.
50///
51/// # Example
52///
53/// ```no_run
54/// use dev_async::cancellation_safety::check_cancel_safe;
55/// use std::sync::atomic::{AtomicUsize, Ordering};
56/// use std::sync::Arc;
57/// use std::time::Duration;
58///
59/// # async fn ex() {
60/// let counter = Arc::new(AtomicUsize::new(0));
61/// let c2 = counter.clone();
62///
63/// let check = check_cancel_safe(
64///     "buffered_write",
65///     Duration::from_millis(20),
66///     async move {
67///         c2.fetch_add(1, Ordering::SeqCst);
68///         tokio::time::sleep(Duration::from_secs(1)).await;
69///         c2.fetch_add(1, Ordering::SeqCst); // never reached if cancelled
70///     },
71///     || counter.load(Ordering::SeqCst) <= 1,
72/// ).await;
73///
74/// assert!(check.has_tag("async"));
75/// # }
76/// ```
77pub async fn check_cancel_safe<F, Fut, AssertFn>(
78    name: impl Into<String>,
79    cancel_at: Duration,
80    fut: Fut,
81    assert_safe: AssertFn,
82) -> CheckResult
83where
84    Fut: Future<Output = F>,
85    AssertFn: FnOnce() -> bool,
86{
87    let name = name.into();
88    let started = Instant::now();
89    let result = tokio::time::timeout(cancel_at, fut).await;
90    let elapsed = started.elapsed();
91
92    let evidence_base = vec![
93        Evidence::numeric("cancel_at_ms", cancel_at.as_millis() as f64),
94        Evidence::numeric("elapsed_ms", elapsed.as_millis() as f64),
95    ];
96
97    match result {
98        Ok(_completed) => {
99            let mut c = CheckResult::skip(format!("async::{name}")).with_detail(
100                "future completed before cancellation; check did not exercise drop path",
101            );
102            c.tags = vec!["async".to_string(), "cancellation_check".to_string()];
103            c.evidence = evidence_base;
104            c
105        }
106        Err(_elapsed) => {
107            // The future was dropped. Now assess state.
108            let safe = assert_safe();
109            if safe {
110                let mut c = CheckResult::pass(format!("async::{name}"))
111                    .with_duration_ms(elapsed.as_millis() as u64)
112                    .with_detail("future cancelled at deadline; state predicate held");
113                c.tags = vec!["async".to_string(), "cancellation_safe".to_string()];
114                c.evidence = evidence_base;
115                c
116            } else {
117                let mut c = CheckResult::fail(format!("async::{name}"), Severity::Critical)
118                    .with_detail("state predicate failed after future was cancelled mid-poll");
119                c.tags = vec![
120                    "async".to_string(),
121                    "cancellation_unsafe".to_string(),
122                    "regression".to_string(),
123                ];
124                c.evidence = evidence_base;
125                c
126            }
127        }
128    }
129}
130
#[cfg(test)]
mod tests {
    use super::*;
    use dev_report::Verdict;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// A future that finishes inside the deadline never reaches the drop
    /// path, so the check reports `Skip`.
    #[tokio::test]
    async fn future_that_completes_yields_skip() {
        let deadline = Duration::from_secs(1);
        let outcome = check_cancel_safe("fast", deadline, async {}, || true).await;

        assert_eq!(outcome.verdict, Verdict::Skip);
        assert!(outcome.has_tag("cancellation_check"));
    }

    /// Work that only becomes visible after the await leaves no trace when
    /// the future is cancelled, so the predicate holds and the check passes.
    #[tokio::test]
    async fn cancellation_with_safe_state_passes() {
        let hits = Arc::new(AtomicUsize::new(0));
        let writer = Arc::clone(&hits);

        // Buffered work increments counter only on completion.
        let deferred_commit = async move {
            tokio::time::sleep(Duration::from_secs(1)).await;
            writer.fetch_add(1, Ordering::SeqCst);
        };
        let untouched = || hits.load(Ordering::SeqCst) == 0;

        let outcome = check_cancel_safe(
            "buffered_write",
            Duration::from_millis(20),
            deferred_commit,
            untouched,
        )
        .await;

        assert_eq!(outcome.verdict, Verdict::Pass);
        assert!(outcome.has_tag("cancellation_safe"));
    }

    /// A side effect performed before the first await survives cancellation,
    /// so the predicate fails and the check reports a critical failure.
    #[tokio::test]
    async fn cancellation_with_unsafe_state_fails() {
        let hits = Arc::new(AtomicUsize::new(0));
        let writer = Arc::clone(&hits);

        // BAD: increments counter before the await — visible
        // even if the future is cancelled.
        let early_commit = async move {
            writer.fetch_add(1, Ordering::SeqCst);
            tokio::time::sleep(Duration::from_secs(1)).await;
        };
        let untouched = || hits.load(Ordering::SeqCst) == 0;

        let outcome = check_cancel_safe(
            "early_commit",
            Duration::from_millis(20),
            early_commit,
            untouched,
        )
        .await;

        assert_eq!(outcome.verdict, Verdict::Fail);
        assert_eq!(outcome.severity, Some(Severity::Critical));
        assert!(outcome.has_tag("cancellation_unsafe"));
        assert!(outcome.has_tag("regression"));
    }

    /// Both timing measurements are attached to the result as evidence.
    #[tokio::test]
    async fn evidence_includes_cancel_at_and_elapsed() {
        let slow = async {
            tokio::time::sleep(Duration::from_secs(1)).await;
        };
        let outcome = check_cancel_safe("x", Duration::from_millis(50), slow, || true).await;

        let has_label = |wanted: &str| outcome.evidence.iter().any(|e| e.label.as_str() == wanted);
        assert!(has_label("cancel_at_ms"));
        assert!(has_label("elapsed_ms"));
    }
}