apool/
lib.rs

1#![doc = include_str!("../README.md")]
2
3mod sync;
4#[cfg(test)]
5mod tests;
6
7use smallvec::SmallVec;
8use std::collections::VecDeque;
9use std::future::Future;
10use std::sync::Arc;
11use sync::*;
12use std::ops::{Deref, DerefMut};
13use std::pin::Pin;
14
15pub struct Pool<T: 'static + Sync + Send, S: Sync> {
16    state: Mutex<(usize, S)>,
17    pool: Arc<Mutex<VecDeque<PoolEntry<T>>>>,
18    used: Arc<Semaphore>,
19    max: usize,
20    transformer: Box<dyn Fn(&mut S, &mut PoolTransformer<T>) + Send + Sync + 'static>,
21}
22
23pub struct PoolGuard<'a, T: 'static + Sync + Send, S: Sync> {
24    pool: &'a Pool<T, S>,
25    id: usize,
26    inner: OwnedMutexGuard<T>,
27    permit: Option<OwnedSemaphorePermit>,
28}
29
30pub struct PoolTransformer<'a, T: 'static + Sync + Send> {
31    spawn: SmallVec<[Pin<Box<dyn Future<Output = T> + Send +'a>>; 4]>,
32}
33
34struct PoolEntry<T: 'static + Sync + Send> {
35    id: usize,
36    mutex: Arc<Mutex<T>>,
37}
38
39impl<T: 'static + Sync + Send, S: Sync> Pool<T, S> {
40    ///```
41    ///use apool::Pool;
42    ///
43    ///let pool = Pool::<usize, usize>::new(
44    ///    4,
45    ///    0,
46    ///    |state, transformer| {
47    ///        let i = *state;
48    ///        *state += 1;
49    ///        transformer.spawn(async move {
50    ///            i
51    ///        });
52    ///    },
53    ///);
54    /// ```
55    pub fn new<'a>(
56        max: usize,
57        state: S,
58        transform: impl Fn(&mut S, &mut PoolTransformer<T>) + Sync + Send +'static,
59    ) -> Self {
60        if max == 0 {
61            panic!("max pool size is not allowed to be 0");
62        }
63        Self {
64            state: Mutex::new((0, state)),
65            pool: Arc::new(Mutex::new(VecDeque::with_capacity(max))),
66            used: Arc::new(Semaphore::new(0)),
67            max,
68            transformer: Box::new(transform),
69        }
70    }
71
72    pub async fn get(&self) -> PoolGuard<'_, T, S> {
73        let permit = match self.used.clone().try_acquire_owned() {
74            Ok(permit) => permit,
75            Err(_) => {
76                self.try_spawn_new().await;
77                self.used.clone()
78                    .acquire_owned()
79                    .await
80                    .expect("Semaphore should not be closed")
81            }
82        };
83        let mut llock = self.pool.lock().await;
84        let entry = match llock.pop_back() {
85            Some(entry) => entry,
86            None => panic!("Obtaining a pool entry should not fail"),
87        };
88
89        let PoolEntry { id, mutex } = entry.clone();
90        llock.push_front(entry);
91
92        PoolGuard {
93            pool: self,
94            inner: match mutex.try_lock_owned() {
95                Ok(lock) => lock,
96                Err(_) => panic!("Invalid pool list order"),
97            },
98            id,
99            permit: Some(permit),
100        }
101    }
102
103    async fn try_spawn_new(&self) {
104        if { self.pool.lock().await.len() } >= self.max {
105            return;
106        }
107        if let Ok(mut guard) = self.state.try_lock() {
108            let (id, state) = &mut *guard;
109            let mut transformer = PoolTransformer {
110                spawn: SmallVec::new(),
111            };
112            (self.transformer)(state, &mut transformer);
113            let mut llock = self.pool.lock().await;
114            let len = transformer.spawn.len();
115            for item in transformer.spawn {
116                let item = item.await;
117                *id += 1;
118                llock.push_back(PoolEntry {
119                    id: *id,
120                    mutex: Arc::new(Mutex::new(item)),
121                });
122            }
123            drop(llock);
124            self.used.add_permits(len);
125        }
126    }
127}
128
129impl<'a, T: 'static + Sync + Send> PoolTransformer<'a, T> {
130    pub fn spawn(&mut self, future: impl Future<Output = T> + Send + 'a) {
131        self.spawn.push(Box::pin(future));
132    }
133}
134
135impl<'a, T: 'static + Sync + Send, S: Sync> Deref for PoolGuard<'a, T, S> {
136    type Target = T;
137
138    fn deref(&self) -> &Self::Target {
139        &*self.inner
140    }
141}
142
143impl<'a, T: 'static + Sync + Send, S: Sync> DerefMut for PoolGuard<'a, T, S> {
144    fn deref_mut(&mut self) -> &mut Self::Target {
145        &mut *self.inner
146    }
147}
148
149impl<'a, T: 'static + Sync + Send, S: Sync> Drop for PoolGuard<'a, T, S> {
150    fn drop(&mut self) {
151        let permit = self.permit.take().unwrap();
152        let mutex = self.pool.pool.clone();
153        let id = self.id;
154
155        tokio::runtime::Handle::current().spawn(async move {
156            let mut llock = mutex.lock().await;
157            let index = (0usize..)
158                .zip(llock.iter())
159                .find(|item| item.1.id == id)
160                .map(|item| item.0)
161                .unwrap();
162            let item = llock.remove(index).unwrap();
163            llock.push_back(item);
164            drop(permit);
165        });
166    }
167}
168
169impl<T: 'static + Sync + Send> Clone for PoolEntry<T> {
170    fn clone(&self) -> Self {
171        Self {
172            id: self.id,
173            mutex: self.mutex.clone(),
174        }
175    }
176}