use std::ops::{Deref, DerefMut};
use parking_lot::Mutex;
/// A pool of reusable `T` values shared behind a mutex.
///
/// When no idle object is available, a fresh one is built by `factory`.
/// Objects handed back are optionally passed through `reset` and are kept only
/// while fewer than `max_size` objects are idle. The pool is shareable across
/// threads when `T: Send` (the callbacks are required to be `Send + Sync`).
pub struct ObjectPool<T> {
    /// Idle objects; handed out from and returned to the end of the `Vec` (LIFO).
    pool: Mutex<Vec<T>>,
    /// Builds a new object whenever the idle list is empty.
    factory: Box<dyn Fn() -> T + Send + Sync>,
    /// Optional hook applied to every returned object before it is stored.
    reset: Option<Box<dyn Fn(&mut T) + Send + Sync>>,
    /// Returned objects are discarded instead of stored once this many are idle.
    max_size: usize,
}
impl<T> ObjectPool<T> {
    /// Idle-object cap used until [`ObjectPool::with_max_size`] overrides it.
    const DEFAULT_MAX_SIZE: usize = 1024;

    /// Creates an empty pool that builds objects with `factory` and stores
    /// returned objects unchanged.
    ///
    /// At most 1024 idle objects are retained; see [`ObjectPool::with_max_size`].
    pub fn new<F>(factory: F) -> Self
    where
        F: Fn() -> T + Send + Sync + 'static,
    {
        Self {
            pool: Mutex::new(Vec::new()),
            factory: Box::new(factory),
            reset: None,
            max_size: Self::DEFAULT_MAX_SIZE,
        }
    }

    /// Creates an empty pool like [`ObjectPool::new`], but runs `reset` on every
    /// object handed back through [`ObjectPool::put`] (which is also what a
    /// dropped [`Pooled`] guard calls) before it is stored.
    pub fn with_reset<F, R>(factory: F, reset: R) -> Self
    where
        F: Fn() -> T + Send + Sync + 'static,
        R: Fn(&mut T) + Send + Sync + 'static,
    {
        Self {
            reset: Some(Box::new(reset)),
            ..Self::new(factory)
        }
    }

    /// Sets the maximum number of idle objects the pool retains.
    ///
    /// If the pool already holds more idle objects than `max_size`, the surplus
    /// is dropped so the cap holds from this point on.
    #[must_use]
    pub fn with_max_size(mut self, max_size: usize) -> Self {
        self.max_size = max_size;
        // We own `self`, so no locking is needed to enforce the new cap.
        self.pool.get_mut().truncate(max_size);
        self
    }

    /// Tops the idle list up to `min(count, max_size)` objects using the factory.
    ///
    /// Does nothing if at least that many objects are already idle. Objects are
    /// constructed without holding the lock, so a slow or re-entrant factory
    /// neither blocks other threads nor deadlocks.
    pub fn prefill(&self, count: usize) {
        let target = count.min(self.max_size);
        // Saturating: the idle count may already exceed `target`.
        let missing = target.saturating_sub(self.available());
        if missing == 0 {
            return;
        }

        let fresh: Vec<T> = (0..missing).map(|_| (self.factory)()).collect();

        let mut pool = self.pool.lock();
        // Other threads may have returned objects while the lock was released;
        // insert only as many as still fit under the cap.
        let room = self.max_size.saturating_sub(pool.len());
        pool.extend(fresh.into_iter().take(room));
    }

    /// Checks out an object wrapped in a [`Pooled`] guard that puts it back
    /// into this pool when dropped.
    pub fn get(&self) -> Pooled<'_, T> {
        Pooled {
            pool: self,
            value: Some(self.take()),
        }
    }

    /// Removes an idle object from the pool, or builds a new one with the
    /// factory if none is idle. The caller owns the result; it only returns to
    /// the pool if passed to [`ObjectPool::put`].
    pub fn take(&self) -> T {
        // Pop in a separate statement so the guard is released before the
        // factory runs; the lock is not re-entrant.
        let recycled = self.pool.lock().pop();
        recycled.unwrap_or_else(|| (self.factory)())
    }

    /// Returns `value` to the pool.
    ///
    /// The reset hook (if any) runs first, outside the lock. If `max_size`
    /// objects are already idle, the value is dropped instead of stored.
    pub fn put(&self, mut value: T) {
        if let Some(reset) = &self.reset {
            reset(&mut value);
        }
        let mut pool = self.pool.lock();
        if pool.len() < self.max_size {
            pool.push(value);
        }
    }

    /// Number of idle objects currently stored in the pool.
    #[must_use]
    pub fn available(&self) -> usize {
        self.pool.lock().len()
    }

    /// Maximum number of idle objects the pool retains.
    #[must_use]
    pub fn max_size(&self) -> usize {
        self.max_size
    }

    /// Drops every idle object. Objects currently checked out are unaffected
    /// and may still be returned later.
    pub fn clear(&self) {
        self.pool.lock().clear();
    }
}
/// RAII guard for an object checked out of an [`ObjectPool`].
///
/// Dereferences to the pooled value and hands it back to the pool when
/// dropped, unless it was detached with [`Pooled::take`].
#[must_use = "dropping a `Pooled` immediately returns the object to its pool"]
pub struct Pooled<'a, T> {
    /// Pool the value is returned to on drop.
    pool: &'a ObjectPool<T>,
    /// The checked-out value; `None` only after it has been moved out.
    value: Option<T>,
}
impl<T> Pooled<'_, T> {
    /// Detaches the value from the pool and returns it by value.
    ///
    /// The slot is emptied first, so the guard's `Drop` has nothing to hand
    /// back and the object never returns to the pool.
    ///
    /// # Panics
    /// Panics with "Value already taken" if the slot is already empty; the
    /// by-value receiver should make this unreachable in practice.
    pub fn take(mut self) -> T {
        match self.value.take() {
            Some(inner) => inner,
            None => panic!("Value already taken"),
        }
    }
}
impl<T> Deref for Pooled<'_, T> {
    type Target = T;

    /// Borrows the checked-out object.
    ///
    /// # Panics
    /// Panics with "Value already taken" if the value has been moved out.
    fn deref(&self) -> &T {
        match &self.value {
            Some(inner) => inner,
            None => panic!("Value already taken"),
        }
    }
}
impl<T> DerefMut for Pooled<'_, T> {
    /// Mutably borrows the checked-out object.
    ///
    /// # Panics
    /// Panics with "Value already taken" if the value has been moved out.
    fn deref_mut(&mut self) -> &mut T {
        match &mut self.value {
            Some(inner) => inner,
            None => panic!("Value already taken"),
        }
    }
}
impl<T> Drop for Pooled<'_, T> {
    /// Hands the object back to its pool via `ObjectPool::put`, unless it was
    /// already moved out with `Pooled::take`.
    fn drop(&mut self) {
        let Self { pool, value } = self;
        if let Some(object) = value.take() {
            pool.put(object);
        }
    }
}
pub type VecPool<T> = ObjectPool<Vec<T>>;
impl<T: 'static> VecPool<T> {
pub fn new_vec_pool() -> Self {
ObjectPool::with_reset(Vec::new, |v| v.clear())
}
pub fn new_vec_pool_with_capacity(capacity: usize) -> Self {
ObjectPool::with_reset(move || Vec::with_capacity(capacity), |v| v.clear())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pool_basic() {
        let pool: ObjectPool<Vec<u8>> = ObjectPool::new(Vec::new);

        let mut first = pool.get();
        first.push(1);
        first.push(2);
        assert_eq!(&*first, &[1, 2]);
        drop(first);
        assert_eq!(pool.available(), 1);

        // No reset hook: the recycled vector still holds its old contents.
        let second = pool.get();
        assert_eq!(pool.available(), 0);
        assert_eq!(&*second, &[1, 2]);
    }

    #[test]
    fn test_pool_with_reset() {
        let pool: ObjectPool<Vec<u8>> = ObjectPool::with_reset(Vec::new, Vec::clear);

        let mut first = pool.get();
        first.extend_from_slice(&[1, 2]);
        drop(first);

        assert!(pool.get().is_empty());
    }

    #[test]
    fn test_pool_prefill() {
        let pool: ObjectPool<String> = ObjectPool::new(String::new);
        pool.prefill(10);
        assert_eq!(pool.available(), 10);

        let _held = pool.get();
        assert_eq!(pool.available(), 9);
    }

    #[test]
    fn test_pool_max_size() {
        let pool: ObjectPool<u64> = ObjectPool::new(|| 0).with_max_size(3);

        // prefill never goes past the cap.
        pool.prefill(10);
        assert_eq!(pool.available(), 3);

        let drained: Vec<u64> = (0..3).map(|_| pool.take()).collect();
        assert_eq!(pool.available(), 0);

        for item in drained {
            pool.put(item);
        }
        // The pool is full again, so this extra value is discarded.
        pool.put(99);
        assert_eq!(pool.available(), 3);
    }

    #[test]
    fn test_pool_take_ownership() {
        let pool: ObjectPool<String> = ObjectPool::new(String::new);

        let mut guard = pool.get();
        guard.push_str("hello");
        let owned = guard.take();

        assert_eq!(owned, "hello");
        // A detached object never returns to the pool.
        assert_eq!(pool.available(), 0);
    }

    #[test]
    fn test_pool_clear() {
        let pool: ObjectPool<u64> = ObjectPool::new(|| 0);
        pool.prefill(10);
        assert_eq!(pool.available(), 10);

        pool.clear();
        assert_eq!(pool.available(), 0);
    }

    #[test]
    fn test_vec_pool() {
        let pool: VecPool<u8> = VecPool::new_vec_pool();

        let mut buf = pool.get();
        buf.extend_from_slice(&[1, 2, 3]);
        drop(buf);

        assert!(pool.get().is_empty());
    }

    #[test]
    fn test_vec_pool_with_capacity() {
        let pool: VecPool<u8> = VecPool::new_vec_pool_with_capacity(100);
        assert!(pool.get().capacity() >= 100);
    }

    #[test]
    #[cfg(not(miri))]
    fn test_pool_thread_safety() {
        use std::sync::Arc;
        use std::thread;

        let pool: Arc<ObjectPool<Vec<u8>>> =
            Arc::new(ObjectPool::with_reset(Vec::new, Vec::clear));

        let workers: Vec<_> = (0..4)
            .map(|_| {
                let pool = Arc::clone(&pool);
                thread::spawn(move || {
                    for _ in 0..100 {
                        pool.get().push(42);
                    }
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }
        assert!(pool.available() > 0);
    }
}