1use crate::error::PseudoPoolError;
2use crate::Result;
3use std::collections::HashMap;
4use std::marker::PhantomData;
5use std::sync::atomic::{AtomicBool, Ordering};
6use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
7use std::time::Duration;
8use uuid::Uuid;
9
/// Upper bound on how long `Pool::checkout_blocking` waits for a returned lease
/// before it rescans the pool for an available entry.
const POOL_POLLING_TIMEOUT: Duration = Duration::from_millis(5_000);
11
12type PoolEntryId = Uuid;
13
14struct PoolEntry<T> {
15 pool_entry_id: PoolEntryId,
16 payload: Arc<RwLock<T>>,
17}
18
19impl<T> PoolEntry<T> {
20 fn new(payload: T) -> Self {
21 let pool_entry_id = Uuid::new_v4();
22
23 Self {
24 pool_entry_id,
25 payload: Arc::new(RwLock::new(payload)),
26 }
27 }
28}
29
30impl<T> Clone for PoolEntry<T> {
31 fn clone(&self) -> Self {
32 Self {
33 pool_entry_id: self.pool_entry_id,
34 payload: self.payload.clone(),
35 }
36 }
37}
38
39pub struct ExternalPoolEntry<T> {
42 pool_entry: PoolEntry<T>,
43 notifier: crossbeam_channel::Sender<PoolEntryId>,
44 phantom: PhantomData<()>,
46}
47
48impl<T> ExternalPoolEntry<T> {
49 fn new(pool_entry: PoolEntry<T>, notifier: crossbeam_channel::Sender<PoolEntryId>) -> Self {
50 ExternalPoolEntry {
51 pool_entry,
52 notifier,
53 phantom: PhantomData,
54 }
55 }
56
57 pub fn get_payload(&self) -> RwLockReadGuard<T> {
59 self.pool_entry.payload.read().unwrap()
60 }
61
62 pub fn get_payload_mut(&mut self) -> RwLockWriteGuard<T> {
66 self.pool_entry.payload.write().unwrap()
67 }
68}
69
70impl<T> Drop for ExternalPoolEntry<T> {
72 fn drop(&mut self) {
73 let id = self.pool_entry.pool_entry_id;
74 self.notifier.send(id).unwrap()
75 }
76}
77
78struct InternalPoolEntry<T> {
79 pool_entry: PoolEntry<T>,
80 in_use: AtomicBool,
81}
82
83impl<T> InternalPoolEntry<T> {
84 fn new(payload: T) -> Self {
85 Self {
86 pool_entry: PoolEntry::new(payload),
87 in_use: AtomicBool::new(false),
88 }
89 }
90}
91
92pub struct Pool<T> {
96 map: HashMap<PoolEntryId, InternalPoolEntry<T>>,
97 notification_sender: crossbeam_channel::Sender<PoolEntryId>,
98 notification_receiver: crossbeam_channel::Receiver<PoolEntryId>,
99}
100
101impl<T> Pool<T> {
102 pub fn new() -> Self {
104 let (notification_sender, notification_receiver) = crossbeam_channel::unbounded();
105 Self {
106 map: HashMap::new(),
107 notification_sender,
108 notification_receiver,
109 }
110 }
111
112 pub fn new_from_iterable<V: IntoIterator<Item = T>>(vec: V) -> Self {
114 let mut pool = Self::new();
115 pool.extend_entries(vec);
116 pool
117 }
118
119 pub fn add_entry(&mut self, payload: T) {
121 let entry = InternalPoolEntry::new(payload);
122 self.map.insert(entry.pool_entry.pool_entry_id, entry);
123 }
124
125 pub fn extend_entries<V: IntoIterator<Item = T>>(&mut self, vec: V) {
127 for payload in vec {
128 self.add_entry(payload);
129 }
130 }
131
132 fn get_external_entry(&mut self, entry: PoolEntry<T>) -> ExternalPoolEntry<T> {
133 ExternalPoolEntry::new(entry, self.notification_sender.clone())
134 }
135
136 pub fn checkout_blocking(&mut self) -> Result<ExternalPoolEntry<T>> {
139 loop {
140 if let Some(entry) = self.try_checkout() {
141 return Ok(entry);
142 }
143
144 let entry_id = self
145 .notification_receiver
146 .recv_timeout(POOL_POLLING_TIMEOUT);
147
148 if let Ok(entry_id) = entry_id {
149 self.checkin(entry_id)?;
150 } else {
151 }
153 }
154 }
155
156 pub fn try_checkout(&mut self) -> Option<ExternalPoolEntry<T>> {
159 self.process_checkins();
160 for (_, entry) in self.map.iter_mut() {
161 if let Ok(in_use) =
162 entry
163 .in_use
164 .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
165 {
166 assert!(!in_use);
167 let pool_entry = entry.pool_entry.clone();
168 return Some(self.get_external_entry(pool_entry));
169 }
170 }
171 None
172 }
173
174 fn process_checkins(&mut self) {
175 if self.notification_receiver.is_empty() {
176 return;
177 }
178 loop {
179 let entry_id = self.notification_receiver.try_recv();
180 if let Ok(entry_id) = entry_id {
181 self.checkin(entry_id).unwrap()
182 } else {
183 return;
184 }
185 }
186 }
187
188 fn checkin(&mut self, entry_id: PoolEntryId) -> Result<()> {
189 let entry = self.map.get(&entry_id);
190 if let Some(entry) = entry {
191 entry.in_use.store(false, Ordering::Release);
192 Ok(())
193 } else {
194 Err(PseudoPoolError::InvalidCheckin(entry_id))
195 }
196 }
197
198 pub fn update_leases(&mut self) -> usize {
201 self.process_checkins();
202 self.leases()
203 }
204
205 pub fn leases(&self) -> usize {
208 self.map
209 .iter()
210 .filter(|(_, entry)| !entry.in_use.load(Ordering::Acquire))
211 .count()
212 }
213}
214
215impl<T> Default for Pool<T> {
216 fn default() -> Self {
217 Self::new()
218 }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises `try_checkout` together with the lazy check-in accounting:
    /// `leases` only reads the flags, while `update_leases` (and
    /// `try_checkout`) first apply returns queued by dropped leases.
    #[test]
    fn test_non_blocking() {
        let mut pool = Pool::new();
        pool.add_entry(String::from("test"));
        pool.add_entry(String::from("test2"));
        pool.add_entry(String::from("test3"));
        pool.add_entry(String::from("test4"));
        pool.add_entry(String::from("test5"));

        assert_eq!(5, pool.leases());

        let l1 = pool.try_checkout().unwrap();
        assert_eq!(pool.leases(), 4);
        let l2 = pool.try_checkout().unwrap();
        assert_eq!(pool.update_leases(), 3);
        assert_ne!(*l1.get_payload(), *l2.get_payload());
        drop(l1);
        // The return is only queued on the channel; `leases` does not apply it...
        assert_eq!(pool.leases(), 3);
        // ...but `update_leases` does.
        assert_eq!(pool.update_leases(), 4);
        assert_eq!(pool.leases(), 4);
        let l1a = pool.try_checkout().unwrap();
        assert_eq!(pool.leases(), 3);
        let l2_value = (*l2.get_payload()).clone();
        drop(l2);
        assert_eq!(pool.update_leases(), 4);
        // `try_checkout` takes the first free entry in map iteration order. No
        // entries were added since l2 was handed out, so that order is
        // unchanged and l2's entry is again the first free one.
        let l2a = pool.try_checkout().unwrap();
        assert_eq!(*l2a.get_payload(), l2_value);
        assert_eq!(pool.leases(), 3);
        let l3 = pool.try_checkout().unwrap();
        assert_ne!(*l3.get_payload(), l2_value);
        assert_eq!(pool.leases(), 2);
        let l4 = pool.try_checkout().unwrap();
        assert_eq!(pool.leases(), 1);
        let l5 = pool.try_checkout().unwrap();
        assert_ne!(*l5.get_payload(), *l4.get_payload());
        assert_eq!(pool.leases(), 0);
        // Everything is leased out, so a non-blocking checkout yields nothing.
        let l0 = pool.try_checkout();
        assert!(l0.is_none());
        let l1a_value = (*l1a.get_payload()).clone();
        drop(l1a);
        // Still 0: the return has not been processed yet.
        assert_eq!(pool.leases(), 0);
        // `try_checkout` processes the queued return first, so it can hand
        // l1a's (now only free) entry straight back out.
        let l1_returns = pool.try_checkout().unwrap();
        assert_eq!(pool.leases(), 0);
        assert_eq!(*l1_returns.get_payload(), l1a_value);
    }

    /// Exercises `checkout_blocking` along with the bulk constructors. A free
    /// entry exists at every call, so none of the calls actually has to wait.
    #[test]
    fn test_blocking() {
        let mut pool = Pool::new_from_iterable(vec![String::from("test1"), String::from("test2")]);
        pool.extend_entries(vec![String::from("test3"), String::from("test4")]);
        assert_eq!(pool.leases(), 4);
        let l1 = pool.checkout_blocking().unwrap();
        assert_eq!(pool.update_leases(), 3);
        let l2 = pool.checkout_blocking().unwrap();
        assert_eq!(pool.update_leases(), 2);
        let _l3 = pool.checkout_blocking().unwrap();
        assert_eq!(pool.update_leases(), 1);
        assert_ne!(*l1.get_payload(), *l2.get_payload());
        drop(l1);
        assert_eq!(pool.update_leases(), 2);
        let _l1a = pool.checkout_blocking().unwrap();
        assert_eq!(pool.update_leases(), 1);
        let _l4 = pool.checkout_blocking().unwrap();
        assert_eq!(pool.update_leases(), 0);
    }
}