dev_async/cancellation_safety.rs
1//! Cancellation-safety verification.
2//!
3//! A future is *cancellation-safe* (per the tokio definition) if
4//! dropping it mid-poll leaves observable state unchanged compared
5//! to never having polled it. Cancellation-unsafety is a common
6//! source of data corruption: a `select!` arm that wins drops the
7//! losing arms, and if those losing arms had already partially
8//! completed visible work (writing bytes, advancing a cursor,
9//! holding a lock), the system is now in an inconsistent state.
10//!
//! [`check_cancel_safe`] runs a future until a fixed deadline, drops
//! it if it is still pending at that point, and then asks the caller
//! to verify the observable state. The verdict reports whether the
//! post-cancel state matches a "safe" predicate.
15//!
16//! ## What this catches
17//!
18//! - Futures that buffer writes but flush mid-await.
19//! - Futures that consume from a stream and acknowledge before yielding.
20//! - State machines that advance internally then await on the next stage.
21//!
22//! ## What this does NOT catch
23//!
24//! - Cancellation issues that only manifest under specific schedules.
25//! - Issues in nested futures the test doesn't directly inspect.
26//!
27//! Treat the verdict as a strong signal, not a proof.
28
29use std::future::Future;
30use std::time::{Duration, Instant};
31
32use dev_report::{CheckResult, Evidence, Severity};
33
34/// Drive `fut` for at most `cancel_at` duration, then drop it. After
35/// the drop, run `assert_safe` to ask the caller whether the
36/// observable state is still consistent. Emit a [`CheckResult`].
37///
38/// Verdicts:
39/// - `assert_safe()` returns `true` -> `Pass` with `cancellation_safe`.
40/// - `assert_safe()` returns `false` -> `Fail (Critical)` with
41/// `cancellation_unsafe` + `regression` tags. State has been
42/// corrupted by the cancellation.
43/// - `fut` completes before `cancel_at` -> `Skip` with detail
44/// ("future completed before cancellation"). The check did not
45/// exercise cancellation; tighten `cancel_at` or run a slower
46/// future.
47///
48/// `assert_safe` is invoked synchronously after the future has been
49/// dropped. It must not panic.
50///
51/// # Example
52///
53/// ```no_run
54/// use dev_async::cancellation_safety::check_cancel_safe;
55/// use std::sync::atomic::{AtomicUsize, Ordering};
56/// use std::sync::Arc;
57/// use std::time::Duration;
58///
59/// # async fn ex() {
60/// let counter = Arc::new(AtomicUsize::new(0));
61/// let c2 = counter.clone();
62///
63/// let check = check_cancel_safe(
64/// "buffered_write",
65/// Duration::from_millis(20),
66/// async move {
67/// c2.fetch_add(1, Ordering::SeqCst);
68/// tokio::time::sleep(Duration::from_secs(1)).await;
69/// c2.fetch_add(1, Ordering::SeqCst); // never reached if cancelled
70/// },
71/// || counter.load(Ordering::SeqCst) <= 1,
72/// ).await;
73///
74/// assert!(check.has_tag("async"));
75/// # }
76/// ```
77pub async fn check_cancel_safe<F, Fut, AssertFn>(
78 name: impl Into<String>,
79 cancel_at: Duration,
80 fut: Fut,
81 assert_safe: AssertFn,
82) -> CheckResult
83where
84 Fut: Future<Output = F>,
85 AssertFn: FnOnce() -> bool,
86{
87 let name = name.into();
88 let started = Instant::now();
89 let result = tokio::time::timeout(cancel_at, fut).await;
90 let elapsed = started.elapsed();
91
92 let evidence_base = vec![
93 Evidence::numeric("cancel_at_ms", cancel_at.as_millis() as f64),
94 Evidence::numeric("elapsed_ms", elapsed.as_millis() as f64),
95 ];
96
97 match result {
98 Ok(_completed) => {
99 let mut c = CheckResult::skip(format!("async::{name}")).with_detail(
100 "future completed before cancellation; check did not exercise drop path",
101 );
102 c.tags = vec!["async".to_string(), "cancellation_check".to_string()];
103 c.evidence = evidence_base;
104 c
105 }
106 Err(_elapsed) => {
107 // The future was dropped. Now assess state.
108 let safe = assert_safe();
109 if safe {
110 let mut c = CheckResult::pass(format!("async::{name}"))
111 .with_duration_ms(elapsed.as_millis() as u64)
112 .with_detail("future cancelled at deadline; state predicate held");
113 c.tags = vec!["async".to_string(), "cancellation_safe".to_string()];
114 c.evidence = evidence_base;
115 c
116 } else {
117 let mut c = CheckResult::fail(format!("async::{name}"), Severity::Critical)
118 .with_detail("state predicate failed after future was cancelled mid-poll");
119 c.tags = vec![
120 "async".to_string(),
121 "cancellation_unsafe".to_string(),
122 "regression".to_string(),
123 ];
124 c.evidence = evidence_base;
125 c
126 }
127 }
128 }
129}
130
#[cfg(test)]
mod tests {
    use super::*;
    use dev_report::Verdict;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Sleep length used by the slow futures; far longer than any
    /// deadline passed below, so those futures are always cancelled.
    const LONG_SLEEP: Duration = Duration::from_secs(1);

    /// A future that finishes before the deadline never reaches the
    /// drop path, so the check reports `Skip`.
    #[tokio::test]
    async fn future_that_completes_yields_skip() {
        let deadline = Duration::from_secs(1);
        let outcome = check_cancel_safe("fast", deadline, async {}, || true).await;

        assert_eq!(outcome.verdict, Verdict::Skip);
        assert!(outcome.has_tag("cancellation_check"));
    }

    /// Work that only becomes visible after the final await leaves no
    /// trace when cancelled, so the predicate holds and the check passes.
    #[tokio::test]
    async fn cancellation_with_safe_state_passes() {
        let commits = Arc::new(AtomicUsize::new(0));
        let writer = Arc::clone(&commits);

        // Buffered work increments the counter only on completion.
        let buffered = async move {
            tokio::time::sleep(LONG_SLEEP).await;
            writer.fetch_add(1, Ordering::SeqCst);
        };
        let untouched = || commits.load(Ordering::SeqCst) == 0;

        let outcome =
            check_cancel_safe("buffered_write", Duration::from_millis(20), buffered, untouched)
                .await;

        assert_eq!(outcome.verdict, Verdict::Pass);
        assert!(outcome.has_tag("cancellation_safe"));
    }

    /// Work committed before the first await stays visible after the
    /// future is cancelled, so the predicate fails with critical severity.
    #[tokio::test]
    async fn cancellation_with_unsafe_state_fails() {
        let commits = Arc::new(AtomicUsize::new(0));
        let writer = Arc::clone(&commits);

        // BAD: increments the counter before the await, so the change
        // is visible even if the future is cancelled.
        let eager = async move {
            writer.fetch_add(1, Ordering::SeqCst);
            tokio::time::sleep(LONG_SLEEP).await;
        };
        let untouched = || commits.load(Ordering::SeqCst) == 0;

        let outcome =
            check_cancel_safe("early_commit", Duration::from_millis(20), eager, untouched).await;

        assert_eq!(outcome.verdict, Verdict::Fail);
        assert_eq!(outcome.severity, Some(Severity::Critical));
        assert!(outcome.has_tag("cancellation_unsafe"));
        assert!(outcome.has_tag("regression"));
    }

    /// The result exposes both the requested deadline and the measured
    /// elapsed time as evidence entries.
    #[tokio::test]
    async fn evidence_includes_cancel_at_and_elapsed() {
        let sleeper = async {
            tokio::time::sleep(LONG_SLEEP).await;
        };
        let outcome = check_cancel_safe("x", Duration::from_millis(50), sleeper, || true).await;

        let has_label = |wanted: &str| {
            outcome
                .evidence
                .iter()
                .any(|entry| entry.label.as_str() == wanted)
        };
        assert!(has_label("cancel_at_ms"));
        assert!(has_label("elapsed_ms"));
    }
}