abort_on_drop/
lib.rs

//! This crate provides `ChildTask`, a wrapper around Tokio's `JoinHandle` that aborts the task when it is dropped.
//! A `ChildTask` can still be awaited to join the child task, and abort-on-drop still triggers while it is being awaited.
//!
//! For example, suppose task A holds a `ChildTask` for task B but is busy doing something else, and task B is
//! awaiting a `ChildTask` for task C. Aborting A then also aborts both B and C.
6
7use std::future::Future;
8use std::ops::Deref;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11use tokio::task::JoinHandle;
12
13#[derive(Debug)]
14pub struct ChildTask<T> {
15    inner: JoinHandle<T>,
16}
17
18impl<T> Drop for ChildTask<T> {
19    fn drop(&mut self) {
20        self.inner.abort()
21    }
22}
23
24impl<T> Future for ChildTask<T> {
25    type Output = <JoinHandle<T> as Future>::Output;
26    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
27        Pin::new(&mut self.inner).poll(cx)
28    }
29}
30
31impl<T> From<JoinHandle<T>> for ChildTask<T> {
32    fn from(inner: JoinHandle<T>) -> Self {
33        Self { inner }
34    }
35}
36
37impl<T> Deref for ChildTask<T> {
38    type Target = JoinHandle<T>;
39    fn deref(&self) -> &Self::Target {
40        &self.inner
41    }
42}
43
#[cfg(test)]
mod tests {
    use super::ChildTask;
    use futures_util::future::pending;
    use std::sync::{Arc, RwLock};
    use tokio::task::{yield_now, JoinHandle};

    /// Sets the shared flag to `true` when dropped. A test can thereby observe that the
    /// future owning it was dropped.
    struct Sentry(Arc<RwLock<bool>>);

    impl Drop for Sentry {
        fn drop(&mut self) {
            *self.0.write().unwrap() = true
        }
    }

    /// Spawns a task that takes ownership of `sentry` and never completes on its own.
    fn spawn_pending(sentry: Sentry) -> JoinHandle<()> {
        tokio::spawn(async move {
            let _sentry = sentry;
            pending::<()>().await
        })
    }

    // NOTE(review): these tests call `yield_now` once after aborting and expect the aborted
    // task to have been dropped by the time control returns. That relies on the scheduling
    // of the default `#[tokio::test]` runtime.

    /// Dropping an idle `ChildTask` (not being awaited) aborts its task.
    #[tokio::test]
    async fn drop_while_not_waiting_for_join() {
        let dropped = Arc::new(RwLock::new(false));
        let task = ChildTask::from(spawn_pending(Sentry(dropped.clone())));
        yield_now().await;
        assert!(!*dropped.read().unwrap());
        drop(task);
        yield_now().await;
        assert!(*dropped.read().unwrap());
    }

    /// Aborting a parent task that is awaiting a `ChildTask` also aborts the child.
    #[tokio::test]
    async fn drop_while_waiting_for_join() {
        let dropped = Arc::new(RwLock::new(false));
        let sentry = Sentry(dropped.clone());
        let handle = tokio::spawn(async move {
            ChildTask::from(spawn_pending(sentry)).await.unwrap()
        });
        yield_now().await;
        assert!(!*dropped.read().unwrap());
        handle.abort();
        yield_now().await;
        assert!(*dropped.read().unwrap());
        // The parent never finishes on its own, so it must have ended by cancellation.
        assert!(handle.await.unwrap_err().is_cancelled());
    }

    /// A `ChildTask` that is simply awaited yields the task's output.
    #[tokio::test]
    async fn no_drop_only_join() {
        assert_eq!(
            ChildTask::from(tokio::spawn(async {
                yield_now().await;
                5
            }))
            .await
            .unwrap(),
            5
        )
    }

    /// `abort` (reached through `Deref`) cancels the task while the `ChildTask` is still alive.
    #[tokio::test]
    async fn manually_abort_before_drop() {
        let dropped = Arc::new(RwLock::new(false));
        let task = ChildTask::from(spawn_pending(Sentry(dropped.clone())));
        yield_now().await;
        assert!(!*dropped.read().unwrap());
        task.abort();
        yield_now().await;
        assert!(*dropped.read().unwrap());
    }

    /// Joining a manually aborted task reports cancellation, not a panic.
    #[tokio::test]
    async fn manually_abort_then_join() {
        let dropped = Arc::new(RwLock::new(false));
        let task = ChildTask::from(spawn_pending(Sentry(dropped.clone())));
        yield_now().await;
        assert!(!*dropped.read().unwrap());
        task.abort();
        yield_now().await;
        // `is_err()` alone would also accept a panicked task; require cancellation.
        assert!(task.await.unwrap_err().is_cancelled());
    }
}
134}