1#![doc = include_str!("../README.md")]
2
3mod sync;
4#[cfg(test)]
5mod tests;
6
7use smallvec::SmallVec;
8use std::collections::VecDeque;
9use std::future::Future;
10use std::sync::Arc;
11use sync::*;
12use std::ops::{Deref, DerefMut};
13use std::pin::Pin;
14
15pub struct Pool<T: 'static + Sync + Send, S: Sync> {
16 state: Mutex<(usize, S)>,
17 pool: Arc<Mutex<VecDeque<PoolEntry<T>>>>,
18 used: Arc<Semaphore>,
19 max: usize,
20 transformer: Box<dyn Fn(&mut S, &mut PoolTransformer<T>) + Send + Sync + 'static>,
21}
22
23pub struct PoolGuard<'a, T: 'static + Sync + Send, S: Sync> {
24 pool: &'a Pool<T, S>,
25 id: usize,
26 inner: OwnedMutexGuard<T>,
27 permit: Option<OwnedSemaphorePermit>,
28}
29
30pub struct PoolTransformer<'a, T: 'static + Sync + Send> {
31 spawn: SmallVec<[Pin<Box<dyn Future<Output = T> + Send +'a>>; 4]>,
32}
33
/// One slot of the pool: an id plus the shared, lockable value.
struct PoolEntry<T>
where
    T: 'static + Sync + Send,
{
    /// Id assigned from the counter stored in `Pool::state`.
    id: usize,
    /// The pooled value; locked while the entry is checked out.
    mutex: Arc<Mutex<T>>,
}
38
39impl<T: 'static + Sync + Send, S: Sync> Pool<T, S> {
40 pub fn new<'a>(
56 max: usize,
57 state: S,
58 transform: impl Fn(&mut S, &mut PoolTransformer<T>) + Sync + Send +'static,
59 ) -> Self {
60 if max == 0 {
61 panic!("max pool size is not allowed to be 0");
62 }
63 Self {
64 state: Mutex::new((0, state)),
65 pool: Arc::new(Mutex::new(VecDeque::with_capacity(max))),
66 used: Arc::new(Semaphore::new(0)),
67 max,
68 transformer: Box::new(transform),
69 }
70 }
71
72 pub async fn get(&self) -> PoolGuard<'_, T, S> {
73 let permit = match self.used.clone().try_acquire_owned() {
74 Ok(permit) => permit,
75 Err(_) => {
76 self.try_spawn_new().await;
77 self.used.clone()
78 .acquire_owned()
79 .await
80 .expect("Semaphore should not be closed")
81 }
82 };
83 let mut llock = self.pool.lock().await;
84 let entry = match llock.pop_back() {
85 Some(entry) => entry,
86 None => panic!("Obtaining a pool entry should not fail"),
87 };
88
89 let PoolEntry { id, mutex } = entry.clone();
90 llock.push_front(entry);
91
92 PoolGuard {
93 pool: self,
94 inner: match mutex.try_lock_owned() {
95 Ok(lock) => lock,
96 Err(_) => panic!("Invalid pool list order"),
97 },
98 id,
99 permit: Some(permit),
100 }
101 }
102
103 async fn try_spawn_new(&self) {
104 if { self.pool.lock().await.len() } >= self.max {
105 return;
106 }
107 if let Ok(mut guard) = self.state.try_lock() {
108 let (id, state) = &mut *guard;
109 let mut transformer = PoolTransformer {
110 spawn: SmallVec::new(),
111 };
112 (self.transformer)(state, &mut transformer);
113 let mut llock = self.pool.lock().await;
114 let len = transformer.spawn.len();
115 for item in transformer.spawn {
116 let item = item.await;
117 *id += 1;
118 llock.push_back(PoolEntry {
119 id: *id,
120 mutex: Arc::new(Mutex::new(item)),
121 });
122 }
123 drop(llock);
124 self.used.add_permits(len);
125 }
126 }
127}
128
129impl<'a, T: 'static + Sync + Send> PoolTransformer<'a, T> {
130 pub fn spawn(&mut self, future: impl Future<Output = T> + Send + 'a) {
131 self.spawn.push(Box::pin(future));
132 }
133}
134
135impl<'a, T: 'static + Sync + Send, S: Sync> Deref for PoolGuard<'a, T, S> {
136 type Target = T;
137
138 fn deref(&self) -> &Self::Target {
139 &*self.inner
140 }
141}
142
143impl<'a, T: 'static + Sync + Send, S: Sync> DerefMut for PoolGuard<'a, T, S> {
144 fn deref_mut(&mut self) -> &mut Self::Target {
145 &mut *self.inner
146 }
147}
148
149impl<'a, T: 'static + Sync + Send, S: Sync> Drop for PoolGuard<'a, T, S> {
150 fn drop(&mut self) {
151 let permit = self.permit.take().unwrap();
152 let mutex = self.pool.pool.clone();
153 let id = self.id;
154
155 tokio::runtime::Handle::current().spawn(async move {
156 let mut llock = mutex.lock().await;
157 let index = (0usize..)
158 .zip(llock.iter())
159 .find(|item| item.1.id == id)
160 .map(|item| item.0)
161 .unwrap();
162 let item = llock.remove(index).unwrap();
163 llock.push_back(item);
164 drop(permit);
165 });
166 }
167}
168
169impl<T: 'static + Sync + Send> Clone for PoolEntry<T> {
170 fn clone(&self) -> Self {
171 Self {
172 id: self.id,
173 mutex: self.mutex.clone(),
174 }
175 }
176}