1use std::future::Future;
8use std::ops::Deref;
9use std::pin::Pin;
10use std::task::{Context, Poll};
11use tokio::task::JoinHandle;
12
13#[derive(Debug)]
14pub struct ChildTask<T> {
15 inner: JoinHandle<T>,
16}
17
18impl<T> Drop for ChildTask<T> {
19 fn drop(&mut self) {
20 self.inner.abort()
21 }
22}
23
24impl<T> Future for ChildTask<T> {
25 type Output = <JoinHandle<T> as Future>::Output;
26 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
27 Pin::new(&mut self.inner).poll(cx)
28 }
29}
30
31impl<T> From<JoinHandle<T>> for ChildTask<T> {
32 fn from(inner: JoinHandle<T>) -> Self {
33 Self { inner }
34 }
35}
36
37impl<T> Deref for ChildTask<T> {
38 type Target = JoinHandle<T>;
39 fn deref(&self) -> &Self::Target {
40 &self.inner
41 }
42}
43
// These tests rely on a single `yield_now()` giving other spawned tasks the
// chance to run (or to be torn down after an abort). That holds on the
// current-thread runtime `#[tokio::test]` uses by default.
#[cfg(test)]
mod tests {
    use super::ChildTask;
    use futures_util::future::pending;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use tokio::task::{yield_now, JoinHandle};

    /// Sets its shared flag when dropped, so a test can observe whether the
    /// future owning it has been torn down.
    struct Sentry(Arc<AtomicBool>);

    impl Drop for Sentry {
        fn drop(&mut self) {
            self.0.store(true, Ordering::SeqCst);
        }
    }

    /// Creates a sentry together with the flag it sets when dropped.
    fn new_sentry() -> (Sentry, Arc<AtomicBool>) {
        let dropped = Arc::new(AtomicBool::new(false));
        (Sentry(Arc::clone(&dropped)), dropped)
    }

    /// Returns whether the sentry associated with `flag` has been dropped.
    fn is_dropped(flag: &AtomicBool) -> bool {
        flag.load(Ordering::SeqCst)
    }

    /// Spawns a task that holds `sentry` and never finishes on its own, so the
    /// sentry can only be dropped by cancelling the task.
    fn spawn_holding(sentry: Sentry) -> JoinHandle<()> {
        tokio::spawn(async move {
            let _sentry = sentry;
            pending::<()>().await
        })
    }

    #[tokio::test]
    async fn drop_while_not_waiting_for_join() {
        let (sentry, dropped) = new_sentry();
        let task = ChildTask::from(spawn_holding(sentry));
        yield_now().await;
        assert!(!is_dropped(&dropped));
        drop(task);
        yield_now().await;
        assert!(is_dropped(&dropped));
    }

    #[tokio::test]
    async fn drop_while_waiting_for_join() {
        let (sentry, dropped) = new_sentry();
        // The outer task owns the `ChildTask` while awaiting it; aborting the
        // outer task drops that `ChildTask`, which must abort the inner task.
        let outer = tokio::spawn(async move {
            ChildTask::from(spawn_holding(sentry)).await.unwrap()
        });
        yield_now().await;
        assert!(!is_dropped(&dropped));
        outer.abort();
        yield_now().await;
        assert!(is_dropped(&dropped));
        assert!(outer.await.unwrap_err().is_cancelled());
    }

    #[tokio::test]
    async fn no_drop_only_join() {
        let task = ChildTask::from(tokio::spawn(async {
            yield_now().await;
            5
        }));
        assert_eq!(task.await.unwrap(), 5);
    }

    #[tokio::test]
    async fn manually_abort_before_drop() {
        let (sentry, dropped) = new_sentry();
        let task = ChildTask::from(spawn_holding(sentry));
        yield_now().await;
        assert!(!is_dropped(&dropped));
        // `abort` is reached through `Deref<Target = JoinHandle<_>>`.
        task.abort();
        yield_now().await;
        assert!(is_dropped(&dropped));
    }

    #[tokio::test]
    async fn manually_abort_then_join() {
        let (sentry, dropped) = new_sentry();
        let task = ChildTask::from(spawn_holding(sentry));
        yield_now().await;
        assert!(!is_dropped(&dropped));
        task.abort();
        yield_now().await;
        // `is_err()` alone would also accept a panicked task; require that
        // the join error specifically reports cancellation.
        assert!(task.await.unwrap_err().is_cancelled());
        assert!(is_dropped(&dropped));
    }
}