bounded_taskpool/taskpool.rs
use std::{future::Future, pin::Pin, sync::Arc};

use std::num::NonZeroUsize;
use tokio::{
    sync::{Semaphore, mpsc::Sender},
    task::JoinHandle,
};

use super::TaskPoolError;

/// Inner task wraps either a future to execute or a special control state
enum InnerTask {
    /// Task to execute
    Task(Pin<Box<dyn Future<Output = ()> + Send>>),
    /// Stop the task pool
    Stop,
}

/// Inner handle wraps either a task handle or a special control state
enum InnerHandle {
    /// Handle which needs a completion check
    Handle(JoinHandle<()>),
    /// No more handles to wait for
    Stop,
}

/// Bounded pool of tasks
#[derive(Clone, Debug)]
pub struct TaskPool {
    /// Ordering queue used to schedule tasks while keeping their order
    ordering_queue: Sender<InnerTask>,
    /// Number of allowed parallel tasks
    concurrency: usize,
    /// Size of the backpressure queue
    queue_size: usize,
}

impl TaskPool {
    /// Create a new task pool with a given concurrency and queue size (for
    /// backpressure).
    ///
    /// ```rust
    /// use std::num::NonZeroUsize;
    /// use std::time::Duration;
    /// use taskpool::TaskPool;
    ///
    /// #[tokio::main]
    /// async fn main() {
    ///     let (pool, drained) = TaskPool::new(
    ///         NonZeroUsize::new(4).expect("non-zero concurrency"),
    ///         NonZeroUsize::new(64).expect("non-zero queue size"),
    ///     );
    ///
    ///     for i in 0..10 {
    ///         pool.spawn_with_timeout(
    ///             async move {
    ///                 println!("job {i}");
    ///                 tokio::time::sleep(Duration::from_millis(50)).await;
    ///             },
    ///             Duration::from_millis(250),
    ///         )
    ///         .await
    ///         .expect("schedule task");
    ///     }
    ///
    ///     pool.trigger_stop().await.expect("stop accepted");
    ///     drained.await.expect("pool drained");
    /// }
    /// ```
    #[must_use]
    pub fn new(
        concurrency: NonZeroUsize,
        queue_size: NonZeroUsize,
    ) -> (Self, tokio::sync::oneshot::Receiver<()>) {
        let concurrency = concurrency.get();
        let queue_size = queue_size.get();
        let sem = Arc::new(Semaphore::new(concurrency));

        // Allocate a bounded queue to schedule tasks in order
        let (queue_tx, mut queue_rx) = tokio::sync::mpsc::channel::<InnerTask>(queue_size);
        let (stop_tx, stop_rx) = tokio::sync::oneshot::channel::<()>();

        // Pop the queue and schedule tasks
        let sem_cpy = Arc::clone(&sem);
        let (ack_tx, mut ack_rx) = tokio::sync::mpsc::unbounded_channel::<InnerHandle>();

        // Main queue: consume tasks in order and spawn each one once a
        // semaphore permit is available
        tokio::spawn(async move {
            while let Some(inner_task) = queue_rx.recv().await {
                let task = match inner_task {
                    InnerTask::Task(task) => task,
                    InnerTask::Stop => break,
                };

                let guard_res = Arc::clone(&sem_cpy).acquire_owned().await;
                let guard = match guard_res {
                    Ok(guard) => guard,
                    Err(err) => {
                        tracing::error!("failed to acquire semaphore, skipping task: {err}");
                        continue;
                    }
                };

                let job = tokio::spawn(async move {
                    let _guard = guard;
                    task.await;
                });
                if let Err(err) = ack_tx.send(InnerHandle::Handle(job)) {
                    tracing::warn!("issue occurred while sending task handle: {err}");
                }
            }

            // Just so we know when to stop consuming handles
            if let Err(err) = ack_tx.send(InnerHandle::Stop) {
                tracing::warn!("issue occurred while sending task handle: {err}");
            }
        });

        // Acknowledgment queue: wait on every spawned job so we know when the
        // pool is fully drained
        tokio::spawn(async move {
            // Consume all remaining jobs, to ensure all of them have been
            // executed.
            while let Some(inner_handle) = ack_rx.recv().await {
                match inner_handle {
                    InnerHandle::Handle(ack_handle) => {
                        if let Err(err) = ack_handle.await {
                            tracing::warn!("issue occurred while waiting for task: {err}");
                        }
                    }
                    InnerHandle::Stop => break,
                }
            }

            // The queue has been fully drained, and all tasks have been executed
            if let Err(()) = stop_tx.send(()) {
                tracing::warn!(
                    "issue occurred while trying to trigger the end of the task pool drain"
                );
            }
        });

        (
            Self {
                ordering_queue: queue_tx,
                concurrency,
                queue_size,
            },
            stop_rx,
        )
    }

    /// Get the concurrency of the pool
    #[must_use]
    pub const fn concurrency(&self) -> usize {
        self.concurrency
    }

    /// Get the queue size of the pool
    #[must_use]
    pub const fn queue_size(&self) -> usize {
        self.queue_size
    }

    /// Spawn a task in the pool.
    /// If the task can't be inserted because the queue is full, this waits
    /// until a slot becomes available.
    /// It's usually better to use `spawn_with_timeout`, to avoid blocking
    /// indefinitely.
    ///
    /// # Errors
    ///
    /// Returns an error if the task fails to schedule (e.g., the channel is closed).
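    ///
    /// A minimal usage sketch (the `taskpool::TaskPool` path follows the
    /// example on [`TaskPool::new`] and may differ in your crate layout):
    ///
    /// ```rust
    /// use std::num::NonZeroUsize;
    /// use taskpool::TaskPool;
    ///
    /// #[tokio::main]
    /// async fn main() {
    ///     let (pool, drained) = TaskPool::new(
    ///         NonZeroUsize::new(2).expect("non-zero concurrency"),
    ///         NonZeroUsize::new(8).expect("non-zero queue size"),
    ///     );
    ///
    ///     // Waits for queue space if the pool is saturated.
    ///     pool.spawn(async {
    ///         println!("hello from the pool");
    ///     })
    ///     .await
    ///     .expect("schedule task");
    ///
    ///     pool.trigger_stop().await.expect("stop accepted");
    ///     drained.await.expect("pool drained");
    /// }
    /// ```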
    pub async fn spawn(
        &self,
        cb: impl Future<Output = ()> + Send + 'static,
    ) -> Result<(), TaskPoolError> {
        let pinned_cb = Box::pin(cb);

        match self.ordering_queue.send(InnerTask::Task(pinned_cb)).await {
            Ok(()) => Ok(()),
            Err(_) => Err(TaskPoolError::FailedToSend),
        }
    }

    /// Spawn a task in the pool with the given timeout.
    /// The timeout applies to insertion into the channel queue, not to the
    /// task itself.
    ///
    /// # Errors
    ///
    /// Returns an error if the task fails to schedule (e.g., timeout or channel closed).
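    ///
    /// A short sketch of handling the timeout variant (assuming
    /// `TaskPoolError` is exported next to `TaskPool`, as the crate path from
    /// the example on [`TaskPool::new`] suggests):
    ///
    /// ```rust
    /// use std::num::NonZeroUsize;
    /// use std::time::Duration;
    /// use taskpool::{TaskPool, TaskPoolError};
    ///
    /// #[tokio::main]
    /// async fn main() {
    ///     let (pool, _drained) = TaskPool::new(
    ///         NonZeroUsize::new(1).expect("non-zero concurrency"),
    ///         NonZeroUsize::new(1).expect("non-zero queue size"),
    ///     );
    ///
    ///     match pool
    ///         .spawn_with_timeout(async {}, Duration::from_millis(10))
    ///         .await
    ///     {
    ///         Ok(()) => println!("scheduled"),
    ///         Err(TaskPoolError::SendTimeout) => println!("queue stayed full past the timeout"),
    ///         Err(_) => println!("failed to schedule"),
    ///     }
    /// }
    /// ```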
    pub async fn spawn_with_timeout(
        &self,
        cb: impl Future<Output = ()> + Send + 'static,
        timeout: std::time::Duration,
    ) -> Result<(), TaskPoolError> {
        let pinned_cb = Box::pin(cb);

        match self
            .ordering_queue
            .send_timeout(InnerTask::Task(pinned_cb), timeout)
            .await
        {
            Ok(()) => Ok(()),
            Err(tokio::sync::mpsc::error::SendTimeoutError::Timeout(_)) => {
                // The timeout error carries the rejected callback, so its type
                // includes the callback type. Returning it would force extra
                // bounds (such as `Sync`) on the callback; by discarding it
                // here, callers can pass anonymous callbacks (like `async {}`)
                // directly.
                Err(TaskPoolError::SendTimeout)
            }
            Err(tokio::sync::mpsc::error::SendTimeoutError::Closed(_)) => {
                // The discarded value is the `Pin<Box<dyn Future + Send>>`.
                // It's neither `Sync` nor `Debug`, so we can't easily include
                // it in the error; we simply acknowledge the channel was closed.
                Err(TaskPoolError::ChannelClosed)
            }
        }
    }

    /// Trigger a stop by adding a special event to the pool.
    /// Anything queued after this marker will not be consumed, and
    /// consumption of the queue stops at the marker.
    ///
    /// # Errors
    ///
    /// Returns an error if the stop marker fails to schedule.
    pub async fn trigger_stop(&self) -> Result<(), TaskPoolError> {
        // The only possible inner error is "channel closed", so instead of
        // depending on the error's return type we map it manually. That way
        // we don't rely on the inner callback being constrained by `Sync`.
        match self.ordering_queue.send(InnerTask::Stop).await {
            Ok(()) => Ok(()),
            Err(_) => Err(TaskPoolError::ChannelClosed),
        }
    }
}
235}