Skip to main content

grafeo_common/memory/
pool.rs

//! Object pool for reusing frequently allocated types.
//!
//! When the same type of object is created and destroyed repeatedly (for
//! example, temporary buffers during query execution), a pool avoids the
//! repeated allocation cost: objects are handed back to the pool (optionally
//! reset first) instead of being freed.
7
8use std::ops::{Deref, DerefMut};
9
10use parking_lot::Mutex;
11
12/// A thread-safe object pool for reusing allocations.
13///
14/// Use [`get()`](Self::get) to grab an object (created fresh if the pool is
15/// empty). When you drop the returned [`Pooled`] wrapper, the object goes
16/// back to the pool for reuse.
17///
18/// # Examples
19///
20/// ```
21/// use grafeo_common::memory::ObjectPool;
22///
23/// // Pool of vectors that get cleared on return
24/// let pool = ObjectPool::with_reset(Vec::<u8>::new, |v| v.clear());
25///
26/// let mut buf = pool.get();
27/// buf.extend_from_slice(&[1, 2, 3]);
28/// // buf is returned to pool when dropped, and cleared
29/// ```
30pub struct ObjectPool<T> {
31    /// The pool of available objects.
32    pool: Mutex<Vec<T>>,
33    /// Factory function to create new objects.
34    factory: Box<dyn Fn() -> T + Send + Sync>,
35    /// Optional reset function called when returning objects to the pool.
36    reset: Option<Box<dyn Fn(&mut T) + Send + Sync>>,
37    /// Maximum pool size.
38    max_size: usize,
39}
40
41impl<T> ObjectPool<T> {
42    /// Creates a new object pool with the given factory function.
43    pub fn new<F>(factory: F) -> Self
44    where
45        F: Fn() -> T + Send + Sync + 'static,
46    {
47        Self {
48            pool: Mutex::new(Vec::new()),
49            factory: Box::new(factory),
50            reset: None,
51            max_size: 1024,
52        }
53    }
54
55    /// Creates a new object pool with a factory and reset function.
56    ///
57    /// The reset function is called when an object is returned to the pool,
58    /// allowing you to clear or reinitialize the object for reuse.
59    pub fn with_reset<F, R>(factory: F, reset: R) -> Self
60    where
61        F: Fn() -> T + Send + Sync + 'static,
62        R: Fn(&mut T) + Send + Sync + 'static,
63    {
64        Self {
65            pool: Mutex::new(Vec::new()),
66            factory: Box::new(factory),
67            reset: Some(Box::new(reset)),
68            max_size: 1024,
69        }
70    }
71
72    /// Sets the maximum pool size.
73    ///
74    /// Objects returned when the pool is at capacity will be dropped instead.
75    #[must_use]
76    pub fn with_max_size(mut self, max_size: usize) -> Self {
77        self.max_size = max_size;
78        self
79    }
80
81    /// Pre-populates the pool with `count` objects.
82    pub fn prefill(&self, count: usize) {
83        let mut pool = self.pool.lock();
84        let to_add = count
85            .saturating_sub(pool.len())
86            .min(self.max_size - pool.len());
87        for _ in 0..to_add {
88            pool.push((self.factory)());
89        }
90    }
91
92    /// Takes an object from the pool, creating a new one if necessary.
93    ///
94    /// Returns a `Pooled` wrapper that will return the object to the pool
95    /// when dropped.
96    pub fn get(&self) -> Pooled<'_, T> {
97        let value = self.pool.lock().pop().unwrap_or_else(|| (self.factory)());
98        Pooled {
99            pool: self,
100            value: Some(value),
101        }
102    }
103
104    /// Takes an object from the pool without wrapping it.
105    ///
106    /// The caller is responsible for returning the object via `put()` if desired.
107    pub fn take(&self) -> T {
108        self.pool.lock().pop().unwrap_or_else(|| (self.factory)())
109    }
110
111    /// Returns an object to the pool.
112    ///
113    /// If the pool is at capacity, the object is dropped instead.
114    pub fn put(&self, mut value: T) {
115        if let Some(ref reset) = self.reset {
116            reset(&mut value);
117        }
118
119        let mut pool = self.pool.lock();
120        if pool.len() < self.max_size {
121            pool.push(value);
122        }
123        // Otherwise, value is dropped
124    }
125
126    /// Returns the current number of objects in the pool.
127    #[must_use]
128    pub fn available(&self) -> usize {
129        self.pool.lock().len()
130    }
131
132    /// Returns the maximum pool size.
133    #[must_use]
134    pub fn max_size(&self) -> usize {
135        self.max_size
136    }
137
138    /// Clears all objects from the pool.
139    pub fn clear(&self) {
140        self.pool.lock().clear();
141    }
142}
143
144/// A borrowed object from the pool - returns automatically when dropped.
145///
146/// Use [`take()`](Self::take) if you need to keep the object instead of
147/// returning it to the pool.
148pub struct Pooled<'a, T> {
149    pool: &'a ObjectPool<T>,
150    value: Option<T>,
151}
152
153impl<T> Pooled<'_, T> {
154    /// Takes ownership of the inner value, preventing it from being returned to the pool.
155    pub fn take(mut self) -> T {
156        self.value.take().expect("Value already taken")
157    }
158}
159
160impl<T> Deref for Pooled<'_, T> {
161    type Target = T;
162
163    fn deref(&self) -> &Self::Target {
164        self.value.as_ref().expect("Value already taken")
165    }
166}
167
168impl<T> DerefMut for Pooled<'_, T> {
169    fn deref_mut(&mut self) -> &mut Self::Target {
170        self.value.as_mut().expect("Value already taken")
171    }
172}
173
174impl<T> Drop for Pooled<'_, T> {
175    fn drop(&mut self) {
176        if let Some(value) = self.value.take() {
177            self.pool.put(value);
178        }
179    }
180}
181
182/// A specialized pool for `Vec<T>` that clears vectors on return.
183pub type VecPool<T> = ObjectPool<Vec<T>>;
184
185impl<T: 'static> VecPool<T> {
186    /// Creates a new vector pool.
187    pub fn new_vec_pool() -> Self {
188        ObjectPool::with_reset(Vec::new, |v| v.clear())
189    }
190
191    /// Creates a new vector pool with pre-allocated capacity.
192    pub fn new_vec_pool_with_capacity(capacity: usize) -> Self {
193        ObjectPool::with_reset(move || Vec::with_capacity(capacity), |v| v.clear())
194    }
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;

    #[test]
    fn test_pool_basic() {
        let pool: ObjectPool<Vec<u8>> = ObjectPool::new(Vec::new);

        // Check an object out and write into it.
        let mut first = pool.get();
        first.push(1);
        first.push(2);
        assert_eq!(first.as_slice(), &[1, 2]);

        // Dropping the guard hands the vector back.
        drop(first);
        assert_eq!(pool.available(), 1);

        // The next checkout reuses it, emptying the pool...
        let second = pool.get();
        assert_eq!(pool.available(), 0);

        // ...and, with no reset hook, the old contents survive.
        assert_eq!(second.as_slice(), &[1, 2]);
    }

    #[test]
    fn test_pool_with_reset() {
        let pool: ObjectPool<Vec<u8>> = ObjectPool::with_reset(Vec::new, Vec::clear);

        {
            let mut scratch = pool.get();
            scratch.push(1);
            scratch.push(2);
        } // handed back (and cleared) here

        // The recycled vector must come back empty.
        assert!(pool.get().is_empty());
    }

    #[test]
    fn test_pool_prefill() {
        let pool: ObjectPool<String> = ObjectPool::new(String::new);

        pool.prefill(10);
        assert_eq!(pool.available(), 10);

        // Keep the guard alive: dropping it would put the object straight back.
        let _held = pool.get();
        assert_eq!(pool.available(), 9);
    }

    #[test]
    fn test_pool_max_size() {
        let pool: ObjectPool<u64> = ObjectPool::new(|| 0).with_max_size(3);

        // Prefill is capped at the configured limit.
        pool.prefill(10);
        assert_eq!(pool.available(), 3);

        let drained: Vec<u64> = (0..3).map(|_| pool.take()).collect();
        assert_eq!(pool.available(), 0);

        // Returning more objects than the limit drops the surplus (the 99).
        drained
            .into_iter()
            .chain(std::iter::once(99))
            .for_each(|item| pool.put(item));
        assert_eq!(pool.available(), 3);
    }

    #[test]
    fn test_pool_take_ownership() {
        let pool: ObjectPool<String> = ObjectPool::new(String::new);

        let mut guard = pool.get();
        guard.push_str("hello");

        // Taking the value detaches it from the pool for good.
        let owned: String = guard.take();
        assert_eq!(owned, "hello");
        assert_eq!(pool.available(), 0);
    }

    #[test]
    fn test_pool_clear() {
        let pool: ObjectPool<u64> = ObjectPool::new(|| 0);

        pool.prefill(10);
        assert_eq!(pool.available(), 10);

        pool.clear();
        assert_eq!(pool.available(), 0);
    }

    #[test]
    fn test_vec_pool() {
        let pool: VecPool<u8> = VecPool::new_vec_pool();

        {
            let mut recycled = pool.get();
            recycled.extend_from_slice(&[1, 2, 3]);
        }

        // Vectors from `new_vec_pool` are cleared on return.
        assert!(pool.get().is_empty());
    }

    #[test]
    fn test_vec_pool_with_capacity() {
        let pool: VecPool<u8> = VecPool::new_vec_pool_with_capacity(100);

        assert!(pool.get().capacity() >= 100);
    }

    #[test]
    fn test_pool_thread_safety() {
        let pool: Arc<ObjectPool<Vec<u8>>> =
            Arc::new(ObjectPool::with_reset(Vec::new, Vec::clear));

        let workers: Vec<_> = (0..4)
            .map(|_| {
                let shared = Arc::clone(&pool);
                thread::spawn(move || {
                    for _ in 0..100 {
                        // The temporary guard goes back to the pool at the
                        // end of each iteration.
                        shared.get().push(42);
                    }
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }

        // At least one vector should have made it back into the pool.
        assert!(pool.available() > 0);
    }
}