Skip to main content

async_rt/
tracker.rs

1use crate::{Executor, ExecutorBlocking, JoinHandle};
2use std::fmt::Debug;
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::atomic::AtomicUsize;
6use std::sync::Arc;
7use std::task::{Context, Poll};
8
/// Executor wrapper that tracks how many of the tasks spawned through it are still alive.
///
/// Each future passed to [`Executor::spawn`] and each closure passed to
/// [`ExecutorBlocking::spawn_blocking`] on this wrapper increments a shared counter before it
/// is handed to the inner executor. The counter is decremented again when that
/// future/closure is dropped.
///
/// Note that there is no guarantee the runtime drops a future as soon as it finishes;
/// therefore, [`TrackerExecutor::count`] should only be used for approximate statistics, not
/// exact numbers. Additionally, tasks spawned directly on the underlying runtime are not
/// tracked, only those spawned through this struct.
pub struct TrackerExecutor<E> {
    /// Executor that actually runs the tasks.
    executor: E,
    /// Number of tracked tasks not yet dropped; a clone is handed to every task wrapper.
    counter: Arc<AtomicUsize>,
}

impl<E> Debug for TrackerExecutor<E> {
    /// Formats the tracker together with its current task count.
    ///
    /// `E` is not required to implement `Debug`, so the wrapped executor is omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TrackerExecutor")
            .field("count", &self.counter.load(std::sync::atomic::Ordering::Relaxed))
            .finish()
    }
}

// No trait bound is needed here: constructing the wrapper and reading the counter never touch
// the inner executor, so callers are not forced to satisfy `Executor` just to build one.
impl<E> TrackerExecutor<E> {
    /// Wraps `executor`, starting with zero tracked tasks.
    pub fn new(executor: E) -> Self {
        Self {
            executor,
            counter: Arc::default(),
        }
    }

    /// Approximate number of tasks spawned through this tracker that have not been dropped yet.
    ///
    /// A `Relaxed` load is sufficient: the value is a statistic and is not used to
    /// synchronize access to any other data.
    pub fn count(&self) -> usize {
        self.counter.load(std::sync::atomic::Ordering::Relaxed)
    }
}
39
/// Future wrapper that keeps a shared task counter incremented for as long as it is alive.
///
/// The counter is incremented in [`FutureCounter::new`] and decremented in `Drop`, so the
/// wrapped task is counted from construction until the wrapper is dropped.
struct FutureCounter<F> {
    /// The wrapped future; it is re-pinned in `poll`, which is why `F: Unpin` is required there.
    future: F,
    /// Counter shared with the spawning tracker.
    counter: Arc<AtomicUsize>,
}

impl<F> FutureCounter<F> {
    /// Wraps `future` and registers it in `counter`.
    pub fn new(future: F, counter: Arc<AtomicUsize>) -> Self {
        // SeqCst matches the paired decrement in `Drop` below.
        counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        Self { future, counter }
    }
}

// No `'static` bound: polling never needs it, and dropping it lets borrowed futures be wrapped.
impl<F> Future for FutureCounter<F>
where
    F: Future + Unpin,
{
    type Output = F::Output;

    /// Delegates to the wrapped future; counting happens only on construction and drop.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `F: Unpin` (hence `Self: Unpin`) lets us re-pin the field without unsafe projection.
        Pin::new(&mut self.future).poll(cx)
    }
}

impl<F> Drop for FutureCounter<F> {
    /// Unregisters the wrapped task from the shared counter.
    fn drop(&mut self) {
        self.counter
            .fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
    }
}
68
69impl<E: Executor> Executor for TrackerExecutor<E> {
70    fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
71    where
72        F: Future + Send + 'static,
73        F::Output: Send + 'static,
74    {
75        let counter = self.counter.clone();
76        let future = Box::pin(future);
77        let future = FutureCounter::new(future, counter);
78        self.executor.spawn(future)
79    }
80}
81
82impl<E: ExecutorBlocking> ExecutorBlocking for TrackerExecutor<E> {
83    fn spawn_blocking<F, R>(&self, f: F) -> JoinHandle<R>
84    where
85        F: FnOnce() -> R + Send + 'static,
86        R: Send + 'static,
87    {
88
89        struct AtomicCounterDrop(Arc<AtomicUsize>);
90
91        impl AtomicCounterDrop {
92            pub fn new(counter: Arc<AtomicUsize>) -> Self {
93                counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
94                Self(counter)
95            }
96        }
97
98        impl Drop for AtomicCounterDrop {
99            fn drop(&mut self) {
100                self.0.fetch_sub(1, std::sync::atomic::Ordering::SeqCst);
101            }
102        }
103
104        let counter = AtomicCounterDrop::new(self.counter.clone());
105
106        self.executor.spawn_blocking(move || {
107            let _counter = counter;
108            f()
109        })
110    }
111}
112
#[cfg(test)]
mod tests {
    use super::TrackerExecutor;
    use crate::rt::tokio::TokioExecutor;
    use crate::Executor;

    /// A never-completing task is counted while alive. After it is aborted and the runtime
    /// is given a chance to run, its future is dropped and the count returns to zero.
    #[tokio::test]
    async fn test_tracker_executor() {
        let tracker = TrackerExecutor::new(TokioExecutor);

        let task = tracker.spawn(futures::future::pending::<()>());
        assert_eq!(tracker.count(), 1);

        task.abort();
        // Yield once so the runtime can process the abort and drop the task's future.
        crate::task::yield_now().await;
        assert_eq!(tracker.count(), 0);
    }
}