tokio_task_pool/
lib.rs

1#![ doc = include_str!( concat!( env!( "CARGO_MANIFEST_DIR" ), "/", "README.md" ) ) ]
2#[cfg(feature = "log")]
3use log::error;
4use std::fmt;
5use std::future::Future;
6use std::sync::Arc;
7use std::time::Duration;
8use tokio::sync::Semaphore;
9use tokio::task::JoinHandle;
10
/// Result of submitting a future of type `T` to a [`Pool`].
///
/// The outer `Result` is the outcome of the spawn itself. On success it carries the
/// [`JoinHandle`] of the spawned task; the value produced by that task is in turn a `Result`
/// holding either the future's output or an [`Error`].
pub type SpawnResult<T> = Result<JoinHandle<Result<<T as Future>::Output, Error>>, Error>;
12
/// Task ID, can be created from `&'static str` or `String`.
///
/// The two variants differ only in how the text is stored (borrowed `'static` slice vs.
/// owned buffer). Equality compares the ID text itself, so the same ID built from a literal
/// and from a `String` compares equal.
#[derive(Debug, Clone)]
pub enum TaskId {
    /// ID backed by a `'static` string slice.
    Static(&'static str),
    /// ID backed by an owned `String`.
    Owned(String),
}

impl PartialEq for TaskId {
    /// Compares the ID text, ignoring which variant stores it.
    fn eq(&self, other: &Self) -> bool {
        // Local accessor so this impl does not rely on helpers defined elsewhere.
        fn text(id: &TaskId) -> &str {
            match id {
                TaskId::Static(s) => s,
                TaskId::Owned(s) => s.as_str(),
            }
        }
        text(self) == text(other)
    }
}

impl Eq for TaskId {}
19
20impl From<&'static str> for TaskId {
21    #[inline]
22    fn from(s: &'static str) -> Self {
23        Self::Static(s)
24    }
25}
26
27impl From<String> for TaskId {
28    #[inline]
29    fn from(s: String) -> Self {
30        Self::Owned(s)
31    }
32}
33
34impl TaskId {
35    #[inline]
36    fn as_str(&self) -> &str {
37        match self {
38            TaskId::Static(v) => v,
39            TaskId::Owned(s) => s.as_str(),
40        }
41    }
42}
43
/// Task
///
/// Bundles a future with an optional custom ID and an optional run timeout. Built with
/// `Task::new` and the `with_*` builder methods, then handed to the pool for spawning.
pub struct Task<T>
where
    T: Future + Send + 'static,
    T::Output: Send + 'static,
{
    // Optional ID; it is placed into `Error::RunTimeout` when the task's run times out.
    id: Option<TaskId>,
    // Optional per-task run timeout; when set it is used instead of the pool's default.
    timeout: Option<Duration>,
    // The future to run.
    future: T,
}
56
57impl<T> Task<T>
58where
59    T: Future + Send + 'static,
60    T::Output: Send + 'static,
61{
62    #[inline]
63    pub fn new(future: T) -> Self {
64        Self {
65            id: None,
66            timeout: None,
67            future,
68        }
69    }
70    #[inline]
71    pub fn with_id<I: Into<TaskId>>(mut self, id: I) -> Self {
72        self.id = Some(id.into());
73        self
74    }
75    #[inline]
76    pub fn with_timeout(mut self, timeout: Duration) -> Self {
77        self.timeout = Some(timeout);
78        self
79    }
80}
81
/// Errors reported by the pool, either when spawning a task or as the result of a task.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum Error {
    /// A task permit could not be acquired within the pool's spawn timeout.
    SpawnTimeout,
    /// The task did not finish within its run timeout; carries the task ID, if one was set.
    RunTimeout(Option<TaskId>),
    /// Acquiring a permit from the pool's semaphore failed.
    // NOTE(review): the name misspells "Semaphore"; it is a public variant, so renaming it
    // would break callers.
    SpawnSemaphoneAcquireError,
}
88
89impl fmt::Display for Error {
90    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
91        match self {
92            Error::SpawnTimeout => write!(f, "task spawn timeout"),
93            Error::RunTimeout(id) => {
94                if let Some(i) = id {
95                    write!(f, "task {} run timeout", i.as_str())
96                } else {
97                    write!(f, "task run timeout")
98                }
99            }
100            Error::SpawnSemaphoneAcquireError => write!(f, "task spawn semaphore error"),
101        }
102    }
103}
104
// Marks `Error` as a standard error type; the impl body is empty, so every trait method
// keeps its default implementation.
impl std::error::Error for Error {}
106
107impl From<tokio::sync::AcquireError> for Error {
108    fn from(_: tokio::sync::AcquireError) -> Self {
109        Self::SpawnSemaphoneAcquireError
110    }
111}
112
/// Task pool
///
/// Spawns futures as tokio tasks, optionally limiting how many run at once and how long
/// each may wait for a slot or run.
#[derive(Debug)]
pub struct Pool {
    // Optional pool ID; used as the prefix of logged run-timeout messages.
    id: Option<Arc<String>>,
    // Maximum time to wait for a free permit when spawning (only consulted when bounded).
    spawn_timeout: Option<Duration>,
    // Default run timeout for tasks that do not set their own.
    run_timeout: Option<Duration>,
    // Semaphore handing out task permits; `None` for an unbounded pool.
    limiter: Option<Arc<Semaphore>>,
    // Number of permits the pool was created with; `None` for an unbounded pool.
    capacity: Option<usize>,
    // Whether run timeouts are reported through `log::error!`.
    #[cfg(feature = "log")]
    logging_enabled: bool,
}
124
125impl Default for Pool {
126    fn default() -> Self {
127        Self::unbounded()
128    }
129}
130
131impl Pool {
132    /// Creates a bounded pool (recommended)
133    pub fn bounded(capacity: usize) -> Self {
134        Self {
135            id: None,
136            spawn_timeout: None,
137            run_timeout: None,
138            limiter: Some(Arc::new(Semaphore::new(capacity))),
139            capacity: Some(capacity),
140            #[cfg(feature = "log")]
141            logging_enabled: true,
142        }
143    }
144    /// Creates an unbounded pool
145    pub fn unbounded() -> Self {
146        Self {
147            id: None,
148            spawn_timeout: None,
149            run_timeout: None,
150            limiter: None,
151            capacity: None,
152            #[cfg(feature = "log")]
153            logging_enabled: true,
154        }
155    }
156    pub fn with_id<I: Into<String>>(mut self, id: I) -> Self {
157        self.id.replace(Arc::new(id.into()));
158        self
159    }
160    pub fn id(&self) -> Option<&str> {
161        self.id.as_deref().map(String::as_str)
162    }
163    /// Sets spawn timeout
164    ///
165    /// (ignored for unbounded)
166    #[inline]
167    pub fn with_spawn_timeout(mut self, timeout: Duration) -> Self {
168        self.spawn_timeout = Some(timeout);
169        self
170    }
171    /// Sets the default task run timeout
172    #[inline]
173    pub fn with_run_timeout(mut self, timeout: Duration) -> Self {
174        self.run_timeout = Some(timeout);
175        self
176    }
177    /// Sets both spawn and run timeouts
178    #[inline]
179    pub fn with_timeout(self, timeout: Duration) -> Self {
180        self.with_spawn_timeout(timeout).with_run_timeout(timeout)
181    }
182    #[cfg(feature = "log")]
183    /// Disables internal error logging_enabled
184    #[inline]
185    pub fn with_no_logging_enabled(mut self) -> Self {
186        self.logging_enabled = false;
187        self
188    }
189    /// Returns pool capacity
190    #[inline]
191    pub fn capacity(&self) -> Option<usize> {
192        self.capacity
193    }
194    /// Returns pool available task permits
195    #[inline]
196    pub fn available_permits(&self) -> Option<usize> {
197        self.limiter.as_ref().map(|v| v.available_permits())
198    }
199    /// Returns pool busy task permits
200    #[inline]
201    pub fn busy_permits(&self) -> Option<usize> {
202        self.limiter
203            .as_ref()
204            .map(|v| self.capacity.unwrap_or_default() - v.available_permits())
205    }
206    /// Spawns a future
207    #[inline]
208    pub fn spawn<T>(&self, future: T) -> impl Future<Output = SpawnResult<T>> + '_
209    where
210        T: Future + Send + 'static,
211        T::Output: Send + 'static,
212    {
213        self.spawn_task(Task::new(future))
214    }
215    /// Spawns a future with a custom timeout
216    #[inline]
217    pub fn spawn_with_timeout<T>(
218        &self,
219        future: T,
220        timeout: Duration,
221    ) -> impl Future<Output = SpawnResult<T>> + '_
222    where
223        T: Future + Send + 'static,
224        T::Output: Send + 'static,
225    {
226        self.spawn_task(Task::new(future).with_timeout(timeout))
227    }
228    /// Spawns a task (a future which can have a custom ID and timeout)
229    pub async fn spawn_task<T>(&self, task: Task<T>) -> SpawnResult<T>
230    where
231        T: Future + Send + 'static,
232        T::Output: Send + 'static,
233    {
234        #[cfg(feature = "log")]
235        let id = self.id.as_ref().cloned();
236        let perm = if let Some(ref limiter) = self.limiter {
237            if let Some(spawn_timeout) = self.spawn_timeout {
238                Some(
239                    tokio::time::timeout(spawn_timeout, limiter.clone().acquire_owned())
240                        .await
241                        .map_err(|_| Error::SpawnTimeout)??,
242                )
243            } else {
244                Some(limiter.clone().acquire_owned().await?)
245            }
246        } else {
247            None
248        };
249        if let Some(rtimeout) = task.timeout.or(self.run_timeout) {
250            #[cfg(feature = "log")]
251            let logging_enabled = self.logging_enabled;
252            Ok(tokio::spawn(async move {
253                let _p = perm;
254                if let Ok(v) = tokio::time::timeout(rtimeout, task.future).await {
255                    Ok(v)
256                } else {
257                    let e = Error::RunTimeout(task.id);
258                    #[cfg(feature = "log")]
259                    if logging_enabled {
260                        error!("{}: {}", id.as_deref().map_or("", |v| v.as_str()), e);
261                    }
262                    Err(e)
263                }
264            }))
265        } else {
266            Ok(tokio::spawn(async move {
267                let _p = perm;
268                Ok(task.future.await)
269            }))
270        }
271    }
272}
273
#[cfg(test)]
mod test {
    use super::{Error, Pool};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::sync::mpsc::channel;
    use tokio::time::sleep;

    /// Every task spawned on a bounded pool runs to completion.
    #[tokio::test]
    async fn test_spawn() {
        let pool = Pool::bounded(5);
        let counter = Arc::new(AtomicUsize::new(0));
        let mut handles = Vec::with_capacity(5);
        for _ in 1..=5 {
            let counter_c = counter.clone();
            let handle = pool
                .spawn(async move {
                    sleep(Duration::from_secs(2)).await;
                    counter_c.fetch_add(1, Ordering::SeqCst);
                })
                .await
                .unwrap();
            handles.push(handle);
        }
        // Await the join handles rather than sleeping a fixed time, so the assertion does
        // not depend on scheduling.
        for handle in handles {
            assert_eq!(handle.await.unwrap(), Ok(()));
        }
        assert_eq!(counter.load(Ordering::SeqCst), 5);
    }

    /// Spawning on a full pool fails with `SpawnTimeout` once the spawn timeout elapses.
    #[tokio::test]
    async fn test_spawn_timeout() {
        let pool = Pool::bounded(5).with_spawn_timeout(Duration::from_secs(1));
        for _ in 1..=5 {
            let (tx, mut rx) = channel(1);
            pool.spawn(async move {
                tx.send(()).await.unwrap();
                sleep(Duration::from_secs(2)).await;
            })
            .await
            .unwrap();
            // Wait until the task has really started before spawning the next one.
            rx.recv().await;
        }
        // All five tasks are now sleeping while holding their permits.
        assert_eq!(pool.available_permits(), Some(0));
        assert_eq!(pool.busy_permits(), Some(5));
        let overflow = pool
            .spawn(async move {
                sleep(Duration::from_secs(2)).await;
            })
            .await;
        assert_eq!(overflow.err(), Some(Error::SpawnTimeout));
    }

    /// A task that runs longer than the pool's run timeout is cancelled and reports
    /// `RunTimeout`; the faster tasks finish normally.
    #[tokio::test]
    async fn test_run_timeout() {
        let pool = Pool::bounded(5).with_run_timeout(Duration::from_secs(2));
        let counter = Arc::new(AtomicUsize::new(0));
        let mut handles = Vec::with_capacity(5);
        for i in 1..=5 {
            let counter_c = counter.clone();
            let handle = pool
                .spawn(async move {
                    sleep(Duration::from_secs(if i == 5 { 3 } else { 1 })).await;
                    counter_c.fetch_add(1, Ordering::SeqCst);
                })
                .await
                .unwrap();
            handles.push(handle);
        }
        for (idx, handle) in handles.into_iter().enumerate() {
            let outcome = handle.await.unwrap();
            if idx == 4 {
                // Only the fifth task sleeps past the 2 s run timeout.
                assert_eq!(outcome, Err(Error::RunTimeout(None)));
            } else {
                assert_eq!(outcome, Ok(()));
            }
        }
        assert_eq!(counter.load(Ordering::SeqCst), 4);
    }
}
338}