//! `bounded_taskpool/taskpool.rs` — a bounded pool of asynchronous tasks.
use std::{future::Future, pin::Pin, sync::Arc};

use std::num::NonZeroUsize;
use tokio::{
    sync::{Semaphore, mpsc::Sender},
    task::JoinHandle,
};

use super::TaskPoolError;
10
/// Message sent over the ordering queue: either a job to run or a
/// sentinel telling the scheduler to stop consuming.
enum InnerTask {
    /// A boxed future to execute on the pool.
    Task(Pin<Box<dyn Future<Output = ()> + Send>>),
    /// Stop the task pool: the scheduler breaks out of its drain loop,
    /// so anything queued behind this marker is dropped unexecuted.
    Stop,
}
18
/// Message sent over the acknowledgment queue: either a spawned task's
/// join handle to await, or a sentinel marking that no more will arrive.
enum InnerHandle {
    /// Handle of a spawned task which must be awaited for completion.
    Handle(JoinHandle<()>),
    /// Sentinel: the scheduler has stopped; no more handles to wait for.
    Stop,
}
26
/// Bounded pool of tasks.
///
/// Tasks are started in submission order, with at most `concurrency`
/// running in parallel; the bounded `ordering_queue` gives producers
/// backpressure once `queue_size` tasks are waiting. Cloning the pool
/// yields another sender to the same underlying queue.
#[derive(Clone, Debug)]
pub struct TaskPool {
    /// Ordering queue to schedule tasks while keeping the order
    ordering_queue: Sender<InnerTask>,
    /// Number of allowed parallel tasks
    concurrency: usize,
    /// Size of the backpressure queue
    queue_size: usize,
}
37
38impl TaskPool {
39    /// Create a new taskpool with a given concurrency and queue size (for
40    /// backpressure).
41    ///
42    /// ```rust
43    /// use std::num::NonZeroUsize;
44    /// use std::time::Duration;
45    /// use taskpool::TaskPool;
46    ///
47    /// #[tokio::main]
48    /// async fn main() {
49    ///     let (pool, drained) = TaskPool::new(
50    ///         NonZeroUsize::new(4).expect("non-zero concurrency"),
51    ///         NonZeroUsize::new(64).expect("non-zero queue size"),
52    ///     );
53    ///
54    ///     for i in 0..10 {
55    ///         pool.spawn_with_timeout(
56    ///             async move {
57    ///                 println!("job {i}");
58    ///                 tokio::time::sleep(Duration::from_millis(50)).await;
59    ///             },
60    ///             Duration::from_millis(250),
61    ///         )
62    ///         .await
63    ///         .expect("schedule task");
64    ///     }
65    ///
66    ///     pool.trigger_stop().await.expect("stop accepted");
67    ///     drained.await.expect("pool drained");
68    /// }
69    /// ```
70    #[must_use]
71    pub fn new(
72        concurrency: NonZeroUsize,
73        queue_size: NonZeroUsize,
74    ) -> (Self, tokio::sync::oneshot::Receiver<()>) {
75        let concurrency = concurrency.get();
76        let queue_size = queue_size.get();
77        let sem = Arc::new(Semaphore::new(concurrency));
78
79        // Allocate a huge queue to schedule tasks in order
80        let (queue_tx, mut queue_rx) = tokio::sync::mpsc::channel::<InnerTask>(queue_size);
81        let (stop_tx, stop_rx) = tokio::sync::oneshot::channel::<()>();
82
83        // Depop the queue and schedule tasks
84        let sem_cpy = Arc::clone(&sem);
85        let (ack_tx, mut ack_rx) = tokio::sync::mpsc::unbounded_channel::<InnerHandle>();
86
87        // Main queue
88        tokio::spawn(async move {
89            while let Some(inner_task) = queue_rx.recv().await {
90                let task = match inner_task {
91                    InnerTask::Task(task) => task,
92                    InnerTask::Stop => break,
93                };
94
95                let guard_res = Arc::clone(&sem_cpy).acquire_owned().await;
96                let guard = match guard_res {
97                    Ok(guard) => guard,
98                    Err(err) => {
99                        tracing::error!("failed to acquire semaphore, skipping task: {err}");
100                        continue;
101                    }
102                };
103
104                let job = tokio::spawn(async move {
105                    let _guard = guard;
106                    task.await;
107                });
108                if let Err(err) = ack_tx.send(InnerHandle::Handle(job)) {
109                    tracing::warn!("issue occured while send task handle: {err}");
110                }
111            }
112
113            // Just so we know when to stop consuming handles
114            if let Err(err) = ack_tx.send(InnerHandle::Stop) {
115                tracing::warn!("issue occured while send task handle: {err}");
116            }
117        });
118
119        // Acknowledgment queue
120        tokio::spawn(async move {
121            // Try to consume all remaining jobs, to ensure all of them has been
122            // executed.
123            while let Some(inner_handle) = ack_rx.recv().await {
124                match inner_handle {
125                    InnerHandle::Handle(ack_handle) => {
126                        if let Err(err) = ack_handle.await {
127                            tracing::warn!("issue occured while waiting for task: {err}");
128                        }
129                    }
130                    InnerHandle::Stop => break,
131                }
132            }
133
134            // The queue has been fully drained, and all tasks has been executed
135            if let Err(()) = stop_tx.send(()) {
136                tracing::warn!(
137                    "issue occured while trying to trigger the end of the task pool drain"
138                );
139            }
140        });
141
142        (
143            Self {
144                ordering_queue: queue_tx,
145                concurrency,
146                queue_size,
147            },
148            stop_rx,
149        )
150    }
151
152    /// Get the concurrency of the pool
153    #[must_use]
154    pub const fn concurrency(&self) -> usize {
155        self.concurrency
156    }
157
158    /// Get the queue size of the pool
159    #[must_use]
160    pub const fn queue_size(&self) -> usize {
161        self.queue_size
162    }
163
164    /// Spawn a task in the pool with the default timeout.
165    /// If a task can't be inserted, it will wait until it can.
166    /// It's usually better to use `spawn_with_timeout`, to avoid locking.
167    ///
168    /// # Errors
169    ///
170    /// Returns an error if fails to schedule (e.g., timeout or channel closed).
171    pub async fn spawn(
172        &self,
173        cb: impl Future<Output = ()> + Send + 'static,
174    ) -> Result<(), TaskPoolError> {
175        let pinned_cb = Box::pin(cb);
176
177        match self.ordering_queue.send(InnerTask::Task(pinned_cb)).await {
178            Ok(()) => Ok(()),
179            Err(_) => Err(TaskPoolError::FailedToSend),
180        }
181    }
182
183    /// Spawn a task in the pool with the given timeout.
184    /// Timeout is not for the task, but for channel queue insertion.
185    ///
186    /// # Errors
187    ///
188    /// Returns an error if fails to schedule (e.g., timeout or channel closed).
189    pub async fn spawn_with_timeout(
190        &self,
191        cb: impl Future<Output = ()> + Send + 'static,
192        timeout: std::time::Duration,
193    ) -> Result<(), TaskPoolError> {
194        let pinned_cb = Box::pin(cb);
195
196        match self
197            .ordering_queue
198            .send_timeout(InnerTask::Task(pinned_cb), timeout)
199            .await
200        {
201            Ok(()) => Ok(()),
202            Err(tokio::sync::mpsc::error::SendTimeoutError::Timeout(_)) => {
203                // Because the return type of the timeout include the given
204                // callback type, it forces us to be constraint by "Sync",
205                // without that, it's possible to directly send anonymous
206                // callback (like async{}).
207                Err(TaskPoolError::SendTimeout)
208            }
209            Err(tokio::sync::mpsc::error::SendTimeoutError::Closed(_)) => {
210                // _task_future is the Pin<Box<dyn Future + Send>>.
211                // It's not Sync and not Debug, so we can't easily include it in the error.
212                // We simply acknowledge the channel was closed.
213                Err(TaskPoolError::ChannelClosed)
214            }
215        }
216    }
217
218    /// Trigger a stop by adding a special event in the pool.
219    /// Everything after this marker will not be consumed and the consuming of
220    /// the queue will be stopped.
221    ///
222    /// # Errors
223    ///
224    /// Returns an error if fails to schedule.
225    pub async fn trigger_stop(&self) -> Result<(), TaskPoolError> {
226        // The only inner error possible is "channel close".
227        // So let's not depend on the return type and write the same thing
228        // manually. By doing that, we're not relying on the inner callback to
229        // be constraint by Sync.
230        match self.ordering_queue.send(InnerTask::Stop).await {
231            Ok(()) => Ok(()),
232            Err(_) => Err(TaskPoolError::ChannelClosed),
233        }
234    }
235}