graphos_common/memory/
pool.rs1use std::ops::{Deref, DerefMut};
7
8use parking_lot::Mutex;
9
10pub struct ObjectPool<T> {
16 pool: Mutex<Vec<T>>,
18 factory: Box<dyn Fn() -> T + Send + Sync>,
20 reset: Option<Box<dyn Fn(&mut T) + Send + Sync>>,
22 max_size: usize,
24}
25
26impl<T> ObjectPool<T> {
27 pub fn new<F>(factory: F) -> Self
29 where
30 F: Fn() -> T + Send + Sync + 'static,
31 {
32 Self {
33 pool: Mutex::new(Vec::new()),
34 factory: Box::new(factory),
35 reset: None,
36 max_size: 1024,
37 }
38 }
39
40 pub fn with_reset<F, R>(factory: F, reset: R) -> Self
45 where
46 F: Fn() -> T + Send + Sync + 'static,
47 R: Fn(&mut T) + Send + Sync + 'static,
48 {
49 Self {
50 pool: Mutex::new(Vec::new()),
51 factory: Box::new(factory),
52 reset: Some(Box::new(reset)),
53 max_size: 1024,
54 }
55 }
56
57 #[must_use]
61 pub fn with_max_size(mut self, max_size: usize) -> Self {
62 self.max_size = max_size;
63 self
64 }
65
66 pub fn prefill(&self, count: usize) {
68 let mut pool = self.pool.lock();
69 let to_add = count
70 .saturating_sub(pool.len())
71 .min(self.max_size - pool.len());
72 for _ in 0..to_add {
73 pool.push((self.factory)());
74 }
75 }
76
77 pub fn get(&self) -> Pooled<'_, T> {
82 let value = self.pool.lock().pop().unwrap_or_else(|| (self.factory)());
83 Pooled {
84 pool: self,
85 value: Some(value),
86 }
87 }
88
89 pub fn take(&self) -> T {
93 self.pool.lock().pop().unwrap_or_else(|| (self.factory)())
94 }
95
96 pub fn put(&self, mut value: T) {
100 if let Some(ref reset) = self.reset {
101 reset(&mut value);
102 }
103
104 let mut pool = self.pool.lock();
105 if pool.len() < self.max_size {
106 pool.push(value);
107 }
108 }
110
111 #[must_use]
113 pub fn available(&self) -> usize {
114 self.pool.lock().len()
115 }
116
117 #[must_use]
119 pub fn max_size(&self) -> usize {
120 self.max_size
121 }
122
123 pub fn clear(&self) {
125 self.pool.lock().clear();
126 }
127}
128
129pub struct Pooled<'a, T> {
131 pool: &'a ObjectPool<T>,
132 value: Option<T>,
133}
134
135impl<T> Pooled<'_, T> {
136 pub fn take(mut self) -> T {
138 self.value.take().expect("Value already taken")
139 }
140}
141
142impl<T> Deref for Pooled<'_, T> {
143 type Target = T;
144
145 fn deref(&self) -> &Self::Target {
146 self.value.as_ref().expect("Value already taken")
147 }
148}
149
150impl<T> DerefMut for Pooled<'_, T> {
151 fn deref_mut(&mut self) -> &mut Self::Target {
152 self.value.as_mut().expect("Value already taken")
153 }
154}
155
156impl<T> Drop for Pooled<'_, T> {
157 fn drop(&mut self) {
158 if let Some(value) = self.value.take() {
159 self.pool.put(value);
160 }
161 }
162}
163
164pub type VecPool<T> = ObjectPool<Vec<T>>;
166
167impl<T: 'static> VecPool<T> {
168 pub fn new_vec_pool() -> Self {
170 ObjectPool::with_reset(Vec::new, |v| v.clear())
171 }
172
173 pub fn new_vec_pool_with_capacity(capacity: usize) -> Self {
175 ObjectPool::with_reset(move || Vec::with_capacity(capacity), |v| v.clear())
176 }
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    /// Without a reset hook, a returned object is reused with its contents intact.
    #[test]
    fn test_pool_basic() {
        let pool: ObjectPool<Vec<u8>> = ObjectPool::new(Vec::new);

        let mut first = pool.get();
        first.push(1);
        first.push(2);
        assert_eq!(*first, [1, 2]);

        // Dropping the guard hands the vector back.
        drop(first);
        assert_eq!(pool.available(), 1);

        // The next `get` reuses that same vector.
        let second = pool.get();
        assert_eq!(pool.available(), 0);

        // No reset hook was installed, so the old contents are still there.
        assert_eq!(*second, [1, 2]);
    }

    /// With a reset hook, a returned object is cleaned before it is reused.
    #[test]
    fn test_pool_with_reset() {
        let pool: ObjectPool<Vec<u8>> = ObjectPool::with_reset(Vec::new, Vec::clear);

        let mut first = pool.get();
        first.push(1);
        first.push(2);
        drop(first);

        let second = pool.get();
        assert!(second.is_empty());
    }

    /// `prefill` stocks the pool, and `get` draws from that stock.
    #[test]
    fn test_pool_prefill() {
        let pool: ObjectPool<String> = ObjectPool::new(String::new);

        pool.prefill(10);
        assert_eq!(pool.available(), 10);

        // Keep the guard alive so the object stays checked out.
        let _held = pool.get();
        assert_eq!(pool.available(), 9);
    }

    /// Neither `prefill` nor `put` grows the pool past its cap.
    #[test]
    fn test_pool_max_size() {
        let pool: ObjectPool<u64> = ObjectPool::new(|| 0).with_max_size(3);

        pool.prefill(10);
        assert_eq!(pool.available(), 3);

        let taken: Vec<u64> = (0..3).map(|_| pool.take()).collect();
        assert_eq!(pool.available(), 0);

        for value in taken {
            pool.put(value);
        }
        // One more than the cap: this value is dropped instead of stored.
        pool.put(99);
        assert_eq!(pool.available(), 3);
    }

    /// `Pooled::take` extracts the object, so nothing is handed back.
    #[test]
    fn test_pool_take_ownership() {
        let pool: ObjectPool<String> = ObjectPool::new(String::new);

        let mut guard = pool.get();
        guard.push_str("hello");

        let owned = guard.take();
        assert_eq!(owned, "hello");
        assert_eq!(pool.available(), 0);
    }

    /// `clear` empties the pool.
    #[test]
    fn test_pool_clear() {
        let pool: ObjectPool<u64> = ObjectPool::new(|| 0);

        pool.prefill(10);
        assert_eq!(pool.available(), 10);

        pool.clear();
        assert_eq!(pool.available(), 0);
    }

    /// Vectors from a `VecPool` come back cleared.
    #[test]
    fn test_vec_pool() {
        let pool: VecPool<u8> = VecPool::new_vec_pool();

        let mut buffer = pool.get();
        buffer.extend_from_slice(&[1, 2, 3]);
        drop(buffer);

        let reused = pool.get();
        assert!(reused.is_empty());
    }

    /// Freshly built vectors honour the requested capacity.
    #[test]
    fn test_vec_pool_with_capacity() {
        let pool: VecPool<u8> = VecPool::new_vec_pool_with_capacity(100);

        let buffer = pool.get();
        assert!(buffer.capacity() >= 100);
    }

    /// Several threads can check objects in and out of one shared pool.
    #[test]
    fn test_pool_thread_safety() {
        use std::sync::Arc;
        use std::thread;

        let shared = Arc::new(ObjectPool::<Vec<u8>>::with_reset(Vec::new, Vec::clear));

        let mut workers = Vec::with_capacity(4);
        for _ in 0..4 {
            let pool = Arc::clone(&shared);
            workers.push(thread::spawn(move || {
                for _ in 0..100 {
                    // The guard is dropped at the end of each iteration,
                    // handing the vector back to the pool.
                    let mut buffer = pool.get();
                    buffer.push(42);
                }
            }));
        }

        for worker in workers {
            worker.join().unwrap();
        }

        assert!(shared.available() > 0);
    }
}