// task_collection/lib.rs
//! Types for spawning and waiting on groups of tasks.
//!
//! This crate provides the [`TaskCollection`] type for grouping spawned
//! tasks. A `TaskCollection` is created from a [`Spawner`] implementation.
//! Every task spawned through [`TaskCollection::spawn`] is tracked with
//! minimal overhead, and `await`ing the `TaskCollection` yields until all of
//! the tasks spawned through that collection have completed.
//!
//! The following example (which uses the `alloc` and `tokio` features) shows
//! how to use a `TaskCollection` to wait for spawned tasks to finish:
//!
//! ```rust
//! # use task_collection::TaskCollection;
//! # use futures::channel::mpsc;
//! # use futures::StreamExt;
//! # use core::time::Duration;
//!
//! fn main() {
//!     let runtime = tokio::runtime::Runtime::new().unwrap();
//!     let (tx, mut rx) = mpsc::unbounded::<u64>();
//!
//!     runtime.spawn(async move {
//!         (0..10).for_each(|v| {
//!             tx.unbounded_send(v).expect("Failed to send");
//!         })
//!     });
//!
//!     runtime.block_on(async {
//!         let collection = TaskCollection::new(&runtime);
//!         while let Some(val) = rx.next().await {
//!             collection.spawn(async move {
//!                 // Simulate some async work
//!                 tokio::time::sleep(Duration::from_secs(val)).await;
//!                 println!("Value {}", val);
//!             });
//!         }
//!
//!         collection.await;
//!         println!("All values printed");
//!     });
//! }
//! ```
46
47#![no_std]
48
49#[cfg(feature = "alloc")]
50extern crate alloc;
51
52use core::task::Poll;
53use core::{future::Future, sync::atomic::AtomicUsize};
54use core::{ops::Deref, sync::atomic::Ordering::SeqCst};
55
56use futures_util::task::AtomicWaker;
57
58pub struct TaskCollection<S, T> {
59    spawner: S,
60    tracker: T,
61}
62
63impl<S> TaskCollection<S, ()>
64where
65    S: Spawner,
66{
67    pub fn with_static_tracker(
68        spawner: S,
69        tracker: &'static Tracker,
70    ) -> TaskCollection<S, &'static Tracker> {
71        TaskCollection { spawner, tracker }
72    }
73
74    #[cfg(feature = "alloc")]
75    pub fn new(spawner: S) -> TaskCollection<S, alloc::sync::Arc<Tracker>> {
76        TaskCollection {
77            spawner,
78            tracker: alloc::sync::Arc::new(Tracker::new()),
79        }
80    }
81}
82
83impl<S, T> TaskCollection<S, T>
84where
85    S: Spawner,
86    T: 'static + Deref<Target = Tracker> + Clone + Send,
87{
88    pub fn spawn<F, R>(&self, future: F)
89    where
90        F: Future<Output = R> + Send + 'static,
91    {
92        let tracker = self.create_task();
93        self.spawner.spawn(async {
94            let _ = future.await;
95            core::mem::drop(tracker);
96        });
97    }
98
99    fn create_task(&self) -> Task<T> {
100        let mut current_tasks = self.tracker.active_tasks.load(SeqCst);
101
102        loop {
103            if current_tasks == usize::MAX {
104                panic!();
105            }
106
107            let new_tasks = current_tasks + 1;
108
109            let actual_current =
110                self.tracker
111                    .active_tasks
112                    .compare_and_swap(current_tasks, new_tasks, SeqCst);
113
114            if current_tasks == actual_current {
115                return Task {
116                    inner: self.tracker.clone(),
117                };
118            }
119
120            current_tasks = actual_current;
121        }
122    }
123}
124
125impl<S, T> Future for TaskCollection<S, T>
126where
127    T: core::ops::Deref<Target = Tracker>,
128{
129    type Output = ();
130
131    fn poll(
132        self: core::pin::Pin<&mut Self>,
133        cx: &mut core::task::Context<'_>,
134    ) -> core::task::Poll<Self::Output> {
135        let active_tasks = self.tracker.active_tasks.load(SeqCst);
136
137        if active_tasks == 0 {
138            Poll::Ready(())
139        } else {
140            self.tracker.waker.register(cx.waker());
141
142            let active_tasks = self.tracker.active_tasks.load(SeqCst);
143            if active_tasks == 0 {
144                Poll::Ready(())
145            } else {
146                Poll::Pending
147            }
148        }
149    }
150}
151
152struct Task<T>
153where
154    T: Deref<Target = Tracker>,
155{
156    inner: T,
157}
158
159impl<T> Drop for Task<T>
160where
161    T: Deref<Target = Tracker>,
162{
163    fn drop(&mut self) {
164        let previous = self.inner.active_tasks.fetch_sub(1, SeqCst);
165
166        if previous == 1 {
167            self.inner.waker.wake();
168        }
169    }
170}
171
172pub struct Tracker {
173    waker: AtomicWaker,
174    active_tasks: AtomicUsize,
175}
176
177impl Tracker {
178    pub const fn new() -> Tracker {
179        Tracker {
180            waker: AtomicWaker::new(),
181            active_tasks: AtomicUsize::new(0),
182        }
183    }
184}
185
/// Abstraction over an executor that runs a `'static`, `Send` future in the
/// background.
///
/// [`TaskCollection`] calls this once per tracked task. The implementations
/// bundled with this crate do not wait for the task; they detach it or discard
/// its handle.
pub trait Spawner {
    /// Hands `future` to the executor.
    fn spawn<F: Future<Output = ()> + Send + 'static>(&self, future: F);
}
191
/// [`Spawner`] for a borrowed `smol::Executor`; each task handle is detached.
#[cfg(feature = "smol")]
impl Spawner for &smol::Executor<'_> {
    fn spawn<F>(&self, future: F)
    where
        F: core::future::Future<Output = ()> + Send + 'static,
    {
        // Fully-qualified call on purpose: with a `&&Executor` receiver,
        // `self.spawn(..)` would resolve to this trait method and recurse.
        let handle = smol::Executor::spawn(self, future);
        handle.detach();
    }
}
201
/// [`Spawner`] for a borrowed Tokio `Runtime`.
#[cfg(feature = "tokio")]
impl Spawner for &tokio::runtime::Runtime {
    fn spawn<F>(&self, future: F)
    where
        F: core::future::Future<Output = ()> + Send + 'static,
    {
        // The join handle is discarded: completion is observed through the
        // collection's tracker, not by joining.
        let _ = tokio::runtime::Runtime::spawn(self, future);
    }
}
211
/// [`Spawner`] that launches tasks with `tokio::spawn`, i.e. on the Tokio
/// runtime the caller is currently running in, rather than on an explicitly
/// borrowed runtime.
///
/// # Panics
///
/// `tokio::spawn` is documented to panic when called outside the context of a
/// Tokio runtime, so spawning through this type inherits that panic.
#[cfg(feature = "tokio")]
#[derive(Debug, Clone, Copy, Default)]
pub struct GlobalTokioSpawner;

#[cfg(feature = "tokio")]
impl Spawner for GlobalTokioSpawner {
    fn spawn<F>(&self, future: F)
    where
        F: core::future::Future<Output = ()> + Send + 'static,
    {
        // The join handle is discarded; completion is tracked by the
        // collection's tracker, not by joining.
        tokio::spawn(future);
    }
}
224
/// [`Spawner`] that launches tasks with `async_std::task::spawn`.
#[cfg(feature = "async-std")]
#[derive(Debug, Clone, Copy, Default)]
pub struct AsyncStdSpawner;

#[cfg(feature = "async-std")]
impl Spawner for AsyncStdSpawner {
    fn spawn<F>(&self, future: F)
    where
        F: core::future::Future<Output = ()> + Send + 'static,
    {
        // The join handle is discarded; completion is tracked by the
        // collection's tracker, not by joining.
        async_std::task::spawn(future);
    }
}
237
#[cfg(test)]
mod tests {
    extern crate std;

    use crate::TaskCollection;
    use std::sync::atomic::{AtomicUsize, Ordering::SeqCst};
    use std::sync::Arc;
    use std::time::Duration;

    /// Per-task sleep times in milliseconds. They are deliberately out of
    /// order so tasks finish in a different order than they were spawned.
    const DELAYS_MS: [u64; 5] = [50, 30, 10, 40, 20];

    /// Generous upper bound for a whole test; hitting it means the collection
    /// never resolved.
    const TIMEOUT: Duration = Duration::from_secs(5);

    #[test]
    #[cfg(all(feature = "smol", feature = "alloc"))]
    fn test_smol() {
        use smol::future::FutureExt;

        let exec = smol::Executor::new();

        let f = async {
            let completed = Arc::new(AtomicUsize::new(0));
            let collection = TaskCollection::new(&exec);

            for &ms in &DELAYS_MS {
                let task_done = Arc::clone(&completed);
                collection.spawn(async move {
                    smol::Timer::after(Duration::from_millis(ms)).await;
                    task_done.fetch_add(1, SeqCst);
                });
            }

            collection.await;
            // The collection must not resolve before every task has finished.
            assert_eq!(completed.load(SeqCst), DELAYS_MS.len());
        };

        let timeout = async {
            smol::Timer::after(TIMEOUT).await;
            panic!("tasks did not complete within {:?}", TIMEOUT);
        };

        smol::block_on(exec.run(f.or(timeout)));
    }

    #[test]
    #[cfg(feature = "smol")]
    fn test_smol_static() {
        use smol::future::FutureExt;

        static TRACKER: crate::Tracker = crate::Tracker::new();

        let exec = smol::Executor::new();

        let f = async {
            let completed = Arc::new(AtomicUsize::new(0));
            let collection = TaskCollection::with_static_tracker(&exec, &TRACKER);

            for &ms in &DELAYS_MS {
                let task_done = Arc::clone(&completed);
                collection.spawn(async move {
                    smol::Timer::after(Duration::from_millis(ms)).await;
                    task_done.fetch_add(1, SeqCst);
                });
            }

            collection.await;
            assert_eq!(completed.load(SeqCst), DELAYS_MS.len());
        };

        let timeout = async {
            smol::Timer::after(TIMEOUT).await;
            panic!("tasks did not complete within {:?}", TIMEOUT);
        };

        smol::block_on(exec.run(f.or(timeout)));
    }

    #[test]
    #[cfg(all(feature = "tokio", feature = "alloc"))]
    fn test_tokio() {
        let runtime = tokio::runtime::Runtime::new().expect("failed to build Tokio runtime");

        let f = async {
            let completed = Arc::new(AtomicUsize::new(0));
            let collection = TaskCollection::new(&runtime);

            for &ms in &DELAYS_MS {
                let task_done = Arc::clone(&completed);
                collection.spawn(async move {
                    tokio::time::sleep(Duration::from_millis(ms)).await;
                    task_done.fetch_add(1, SeqCst);
                });
            }

            collection.await;
            assert_eq!(completed.load(SeqCst), DELAYS_MS.len());
        };

        runtime.block_on(async {
            tokio::select! {
                _ = f => (),
                _ = tokio::time::sleep(TIMEOUT) => panic!("tasks did not complete within {:?}", TIMEOUT),
            }
        });
    }

    #[test]
    #[cfg(all(feature = "async-std", feature = "alloc"))]
    fn test_async_std() {
        use crate::AsyncStdSpawner;

        let f = async {
            let completed = Arc::new(AtomicUsize::new(0));
            let collection = TaskCollection::new(AsyncStdSpawner);

            for &ms in &DELAYS_MS {
                let task_done = Arc::clone(&completed);
                collection.spawn(async move {
                    async_std::task::sleep(Duration::from_millis(ms)).await;
                    task_done.fetch_add(1, SeqCst);
                });
            }

            collection.await;
            assert_eq!(completed.load(SeqCst), DELAYS_MS.len());
        };

        async_std::task::block_on(async_std::future::timeout(TIMEOUT, f))
            .expect("tasks did not complete within the timeout");
    }
}