1use std::error::Error;
4use std::fmt;
5use std::future::Future;
6use std::time::Duration;
7
8use tokio::sync::watch;
9use tokio::time::Instant;
10
/// Failure reasons reported by the timeout and cancellation helpers in this
/// module.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum AsyncControlError {
    /// The operation did not finish before its timeout or deadline.
    TimedOut,
    /// The operation was abandoned because its cancellation token fired.
    Cancelled,
}

impl AsyncControlError {
    /// Static, human-readable description rendered by the `Display` impl.
    fn message(&self) -> &'static str {
        match self {
            Self::TimedOut => "operation timed out",
            Self::Cancelled => "operation was cancelled",
        }
    }
}

impl fmt::Display for AsyncControlError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `write_str` deliberately ignores width/fill flags, matching the
        // plain message output callers already see.
        formatter.write_str(self.message())
    }
}

// Neither variant wraps an underlying cause, so the default `source()`
// (which returns `None`) is kept.
impl Error for AsyncControlError {}
31
32#[derive(Debug, Clone)]
39pub struct CancellationSource {
40 sender: watch::Sender<bool>,
41}
42
43impl CancellationSource {
44 #[must_use]
57 pub fn new() -> (Self, CancellationToken) {
58 let (sender, receiver) = watch::channel(false);
59 (Self { sender }, CancellationToken { receiver })
60 }
61
62 #[must_use]
67 pub fn token(&self) -> CancellationToken {
68 CancellationToken {
69 receiver: self.sender.subscribe(),
70 }
71 }
72
73 pub fn cancel(&self) {
78 let _ = self.sender.send(true);
79 }
80
81 #[must_use]
83 pub fn is_cancelled(&self) -> bool {
84 *self.sender.borrow()
85 }
86}
87
88#[derive(Debug, Clone)]
93pub struct CancellationToken {
94 receiver: watch::Receiver<bool>,
95}
96
97impl CancellationToken {
98 #[must_use]
100 pub fn is_cancelled(&self) -> bool {
101 *self.receiver.borrow()
102 }
103
104 pub async fn cancelled(&mut self) {
111 if *self.receiver.borrow() {
112 return;
113 }
114
115 loop {
116 if self.receiver.changed().await.is_err() {
117 return;
118 }
119 if *self.receiver.borrow_and_update() {
120 return;
121 }
122 }
123 }
124}
125
126#[derive(Debug, Clone)]
131pub struct ShutdownTrigger {
132 source: CancellationSource,
133}
134
135impl ShutdownTrigger {
136 pub fn shutdown(&self) {
138 self.source.cancel();
139 }
140
141 #[must_use]
143 pub fn signal(&self) -> ShutdownSignal {
144 ShutdownSignal {
145 token: self.source.token(),
146 }
147 }
148
149 #[must_use]
151 pub fn is_shutdown_requested(&self) -> bool {
152 self.source.is_cancelled()
153 }
154}
155
156#[derive(Debug, Clone)]
161pub struct ShutdownSignal {
162 token: CancellationToken,
163}
164
165impl ShutdownSignal {
166 #[must_use]
168 pub fn is_shutdown_requested(&self) -> bool {
169 self.token.is_cancelled()
170 }
171
172 pub async fn wait(&mut self) {
174 self.token.cancelled().await;
175 }
176}
177
178#[must_use]
180pub fn shutdown_signal() -> (ShutdownTrigger, ShutdownSignal) {
181 let (source, token) = CancellationSource::new();
182 (ShutdownTrigger { source }, ShutdownSignal { token })
183}
184
185pub async fn with_timeout<F, T>(duration: Duration, future: F) -> Result<T, AsyncControlError>
195where
196 F: Future<Output = T>,
197{
198 tokio::time::timeout(duration, future)
199 .await
200 .map_err(|_| AsyncControlError::TimedOut)
201}
202
203pub async fn with_deadline<F, T>(deadline: Instant, future: F) -> Result<T, AsyncControlError>
213where
214 F: Future<Output = T>,
215{
216 tokio::time::timeout_at(deadline, future)
217 .await
218 .map_err(|_| AsyncControlError::TimedOut)
219}
220
221pub async fn run_until_cancelled<F, T>(
231 mut token: CancellationToken,
232 future: F,
233) -> Result<T, AsyncControlError>
234where
235 F: Future<Output = T>,
236{
237 tokio::select! {
238 biased;
239 _ = token.cancelled() => Err(AsyncControlError::Cancelled),
240 value = future => Ok(value),
241 }
242}
243
244pub async fn with_timeout_or_cancel<F, T>(
254 duration: Duration,
255 mut token: CancellationToken,
256 future: F,
257) -> Result<T, AsyncControlError>
258where
259 F: Future<Output = T>,
260{
261 tokio::select! {
262 biased;
263 _ = token.cancelled() => Err(AsyncControlError::Cancelled),
264 result = tokio::time::timeout(duration, future) => {
265 result.map_err(|_| AsyncControlError::TimedOut)
266 }
267 }
268}
269
#[cfg(test)]
mod tests {
    use std::future::pending;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use tokio::sync::Notify;
    use tokio::time::{Duration, Instant, sleep};

    use super::*;

    /// Bumps a shared counter when dropped, so a test can prove that a
    /// future, and everything it owned, was actually torn down.
    struct DropTracker {
        drops: Arc<AtomicUsize>,
    }

    impl Drop for DropTracker {
        fn drop(&mut self) {
            self.drops.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Yields `7`, but only after one (possibly simulated) second.
    async fn slow_seven() -> i32 {
        sleep(Duration::from_secs(1)).await;
        7
    }

    #[test]
    fn async_control_error_formats_public_error_messages() {
        let cases = [
            (AsyncControlError::TimedOut, "operation timed out"),
            (AsyncControlError::Cancelled, "operation was cancelled"),
        ];
        for (error, expected) in cases {
            assert_eq!(error.to_string(), expected);
            assert!(error.source().is_none());
        }
    }

    #[tokio::test]
    async fn with_timeout_returns_value_before_deadline() {
        let outcome = with_timeout(Duration::from_secs(1), async { 7 }).await;
        assert_eq!(outcome, Ok(7));
    }

    #[tokio::test(start_paused = true)]
    async fn with_timeout_reports_elapsed_operation() {
        let outcome = with_timeout(Duration::from_millis(10), slow_seven()).await;
        assert_eq!(outcome, Err(AsyncControlError::TimedOut));
    }

    #[tokio::test(start_paused = true)]
    async fn with_deadline_reports_elapsed_operation() {
        let deadline = Instant::now() + Duration::from_millis(10);
        let outcome = with_deadline(deadline, slow_seven()).await;
        assert_eq!(outcome, Err(AsyncControlError::TimedOut));
    }

    #[tokio::test(start_paused = true)]
    async fn with_timeout_or_cancel_reports_timeout_when_token_is_idle() {
        // `_source` (not `_`) keeps the sender alive; dropping it would count
        // as cancellation.
        let (_source, token) = CancellationSource::new();

        let outcome =
            with_timeout_or_cancel(Duration::from_millis(10), token, slow_seven()).await;
        assert_eq!(outcome, Err(AsyncControlError::TimedOut));
    }

    #[tokio::test]
    async fn run_until_cancelled_returns_value_before_cancellation() {
        // `_source` (not `_`) keeps the sender alive for the whole test.
        let (_source, token) = CancellationSource::new();

        let outcome = run_until_cancelled(token, async { 7 }).await;
        assert_eq!(outcome, Ok(7));
    }

    #[tokio::test]
    async fn cancellation_token_completes_when_all_sources_are_dropped() {
        let (source, mut token) = CancellationSource::new();
        drop(source);

        token.cancelled().await;
        // Losing every source wakes the waiter without setting the flag.
        assert!(!token.is_cancelled());
    }

    #[tokio::test]
    async fn run_until_cancelled_reports_cancelled_when_source_is_dropped() {
        let (source, token) = CancellationSource::new();
        drop(source);

        let outcome = run_until_cancelled(token, pending::<()>()).await;
        assert_eq!(outcome, Err(AsyncControlError::Cancelled));
    }

    #[tokio::test]
    async fn run_until_cancelled_reports_cancellation_and_drops_future() {
        let (source, token) = CancellationSource::new();
        let drops = Arc::new(AtomicUsize::new(0));
        let started = Arc::new(Notify::new());

        let operation = {
            let drops = Arc::clone(&drops);
            let started = Arc::clone(&started);
            async move {
                let _tracker = DropTracker { drops };
                started.notify_one();
                pending::<()>().await;
                7
            }
        };
        let task = tokio::spawn(run_until_cancelled(token, operation));

        // Cancel only once the operation is known to be in flight, so the
        // drop count proves a running future was torn down.
        started.notified().await;
        source.cancel();

        let outcome = task.await.unwrap();
        assert_eq!(outcome, Err(AsyncControlError::Cancelled));
        assert_eq!(drops.load(Ordering::SeqCst), 1);
    }

    #[tokio::test(start_paused = true)]
    async fn with_timeout_or_cancel_prefers_cancellation() {
        let (source, token) = CancellationSource::new();
        source.cancel();

        let outcome =
            with_timeout_or_cancel(Duration::from_millis(10), token, slow_seven()).await;
        assert_eq!(outcome, Err(AsyncControlError::Cancelled));
    }

    #[tokio::test]
    async fn with_timeout_or_cancel_reports_cancelled_when_source_is_dropped() {
        let (source, token) = CancellationSource::new();
        drop(source);

        let outcome =
            with_timeout_or_cancel(Duration::from_secs(1), token, pending::<()>()).await;
        assert_eq!(outcome, Err(AsyncControlError::Cancelled));
    }

    #[tokio::test]
    async fn shutdown_signal_notifies_all_listeners() {
        let (trigger, first) = shutdown_signal();
        let second = trigger.signal();

        // Each listener waits for shutdown, then reports what it observed.
        let listen = |mut signal: ShutdownSignal| {
            tokio::spawn(async move {
                signal.wait().await;
                signal.is_shutdown_requested()
            })
        };
        let first_task = listen(first);
        let second_task = listen(second);

        trigger.shutdown();

        assert!(first_task.await.unwrap());
        assert!(second_task.await.unwrap());
        assert!(trigger.is_shutdown_requested());
    }

    #[tokio::test]
    async fn shutdown_signal_waits_until_trigger_is_dropped() {
        let (trigger, mut signal) = shutdown_signal();
        drop(trigger);

        signal.wait().await;
        assert!(!signal.is_shutdown_requested());
    }
}