use std::ops::{Deref, DerefMut};
use std::sync::{Arc, Mutex};
/// Lock-protected bookkeeping shared by every handle to one pool.
struct PoolState<T> {
    /// Returned (and reset) objects waiting to be handed out again.
    idle: Vec<T>,
    /// Builds a brand-new object when `idle` is empty.
    factory: Box<dyn Fn() -> T + Send + Sync>,
    /// Cleans an object before it is parked in `idle`.
    reset: Box<dyn Fn(&mut T) + Send + Sync>,
    /// Upper bound on `idle.len()`; returns beyond it are dropped.
    max_size: usize,
    /// Running count of successful `acquire_obj` calls.
    total_acquired: usize,
    /// Objects handed out by `acquire_obj` and not yet passed to `return_obj`.
    in_use: usize,
}

impl<T> PoolState<T> {
    /// Builds empty state with zeroed counters, boxing the two callbacks.
    fn new<F, R>(factory: F, reset: R, max_size: usize) -> Self
    where
        F: Fn() -> T + Send + Sync + 'static,
        R: Fn(&mut T) + Send + Sync + 'static,
    {
        Self {
            idle: Vec::new(),
            factory: Box::new(factory),
            reset: Box::new(reset),
            max_size,
            total_acquired: 0,
            in_use: 0,
        }
    }

    /// Hands out the most recently parked object, or a freshly built one.
    ///
    /// Counters are only bumped once an object has actually been obtained,
    /// so a panicking factory leaves them untouched.
    fn acquire_obj(&mut self) -> T {
        let handed_out = match self.idle.pop() {
            Some(recycled) => recycled,
            None => (self.factory)(),
        };
        self.total_acquired += 1;
        self.in_use += 1;
        handed_out
    }

    /// Takes an object back. It is reset and kept only while `idle` is below
    /// `max_size`; otherwise it is simply dropped (without being reset).
    fn return_obj(&mut self, mut obj: T) {
        // Saturating so a stray extra return can never underflow the counter.
        self.in_use = self.in_use.saturating_sub(1);
        if self.idle.len() >= self.max_size {
            return;
        }
        (self.reset)(&mut obj);
        self.idle.push(obj);
    }
}
/// A thread-safe pool of reusable `T` values.
///
/// All state lives behind an `Arc<Mutex<..>>`. Cloning an `ObjectPool` only
/// clones that `Arc`, so every clone shares the same idle list and counters.
pub struct ObjectPool<T: 'static> {
    state: Arc<Mutex<PoolState<T>>>,
}

// Hand-written instead of `#[derive(Clone)]`: the derive would add a `T: Clone`
// bound even though only the `Arc` handle is cloned, which made pools of
// non-`Clone` objects impossible to clone and share.
impl<T: 'static> Clone for ObjectPool<T> {
    fn clone(&self) -> Self {
        ObjectPool {
            state: Arc::clone(&self.state),
        }
    }
}
impl<T: 'static> ObjectPool<T> {
    /// Starts an [`ObjectPoolBuilder`] (no factory, no reset hook, default
    /// `max_size`).
    pub fn builder() -> ObjectPoolBuilder<T> {
        ObjectPoolBuilder::new()
    }

    /// Creates a pool from a `factory` that builds new objects, a `reset` hook
    /// applied to objects kept on return, and the maximum number of idle
    /// objects to retain.
    pub fn new<F, R>(factory: F, reset: R, max_size: usize) -> Self
    where
        F: Fn() -> T + Send + Sync + 'static,
        R: Fn(&mut T) + Send + Sync + 'static,
    {
        let shared = Mutex::new(PoolState::new(factory, reset, max_size));
        Self {
            state: Arc::new(shared),
        }
    }

    /// Locks the shared state. Poisoning is deliberately ignored so that a
    /// panic in a factory/reset callback (which run under this lock) does not
    /// render the pool unusable.
    fn lock_state(&self) -> std::sync::MutexGuard<'_, PoolState<T>> {
        match self.state.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        }
    }

    /// Runs `read` against the locked state and returns its result.
    fn read_state<U>(&self, read: impl FnOnce(&PoolState<T>) -> U) -> U {
        let state = self.lock_state();
        read(&*state)
    }

    /// Takes an idle object (or builds one) and wraps it in a [`PoolGuard`]
    /// that hands it back when dropped.
    ///
    /// The factory runs while the pool lock is held, so it must not call back
    /// into this pool.
    pub fn acquire(&self) -> PoolGuard<T> {
        let object = {
            let mut state = self.lock_state();
            state.acquire_obj()
        };
        PoolGuard {
            pool: Arc::clone(&self.state),
            obj: Some(object),
        }
    }

    /// Pre-builds up to `n` objects into the idle list without ever growing it
    /// past `max_size`. Warm-up does not count toward `total_acquired`.
    pub fn warm_up(&self, n: usize) {
        let mut state = self.lock_state();
        let room = state.max_size.saturating_sub(state.idle.len());
        let count = room.min(n);
        (0..count).for_each(|_| {
            let fresh = (state.factory)();
            state.idle.push(fresh);
        });
    }

    /// Number of idle objects currently parked in the pool.
    pub fn pool_size(&self) -> usize {
        self.read_state(|s| s.idle.len())
    }

    /// Number of objects currently checked out through guards.
    pub fn in_use_count(&self) -> usize {
        self.read_state(|s| s.in_use)
    }

    /// Total number of acquisitions over the pool's lifetime.
    pub fn total_acquired(&self) -> usize {
        self.read_state(|s| s.total_acquired)
    }

    /// Maximum number of idle objects the pool retains.
    pub fn max_size(&self) -> usize {
        self.read_state(|s| s.max_size)
    }
}
/// Handle to an object checked out of an [`ObjectPool`].
///
/// Dereferences to the pooled `T`. On drop the object is handed back to the
/// pool, which resets and keeps it if the idle list has room and drops it
/// otherwise.
#[must_use = "dropping a PoolGuard immediately returns the object to the pool"]
pub struct PoolGuard<T: 'static> {
    /// `Some` from construction until `Drop::drop` takes it.
    obj: Option<T>,
    /// Shared pool state the object is returned to.
    pool: Arc<Mutex<PoolState<T>>>,
}

impl<T: 'static> Deref for PoolGuard<T> {
    type Target = T;
    #[inline]
    fn deref(&self) -> &T {
        self.obj.as_ref().expect("PoolGuard already dropped")
    }
}

impl<T: 'static> DerefMut for PoolGuard<T> {
    #[inline]
    fn deref_mut(&mut self) -> &mut T {
        self.obj.as_mut().expect("PoolGuard already dropped")
    }
}

impl<T: 'static> Drop for PoolGuard<T> {
    fn drop(&mut self) {
        if let Some(obj) = self.obj.take() {
            // Recover a poisoned lock exactly like `ObjectPool::lock_state`.
            // Bailing out on `Err` would silently discard the object and skip
            // the `in_use` decrement, so after any callback panic the pool's
            // counters would drift permanently.
            let mut state = self
                .pool
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            state.return_obj(obj);
        }
    }
}
/// Fluent builder for [`ObjectPool`], obtained from [`ObjectPool::builder`].
///
/// A factory is mandatory; the reset hook is optional (defaults to a no-op)
/// and `max_size` starts at 32.
pub struct ObjectPoolBuilder<T> {
    factory: Option<Box<dyn Fn() -> T + Send + Sync>>,
    reset: Option<Box<dyn Fn(&mut T) + Send + Sync>>,
    max_size: usize,
}

impl<T: Default + 'static> ObjectPoolBuilder<T> {
    /// Uses `T::default` as the factory.
    pub fn default_factory(self) -> Self {
        self.factory(T::default)
    }
}

impl<T: 'static> ObjectPoolBuilder<T> {
    /// Empty builder: no factory, no reset hook, `max_size` of 32.
    fn new() -> Self {
        Self {
            max_size: 32,
            reset: None,
            factory: None,
        }
    }

    /// Sets the closure that builds a new object when none are idle.
    pub fn factory<F>(self, f: F) -> Self
    where
        F: Fn() -> T + Send + Sync + 'static,
    {
        Self {
            factory: Some(Box::new(f)),
            ..self
        }
    }

    /// Sets the hook applied to an object when it is kept on return.
    pub fn reset<R>(self, r: R) -> Self
    where
        R: Fn(&mut T) + Send + Sync + 'static,
    {
        Self {
            reset: Some(Box::new(r)),
            ..self
        }
    }

    /// Sets how many idle objects the pool may retain.
    pub fn max_size(self, n: usize) -> Self {
        Self { max_size: n, ..self }
    }

    /// Consumes the builder and creates the pool.
    ///
    /// # Panics
    ///
    /// Panics if no factory was configured.
    pub fn build(self) -> ObjectPool<T> {
        let ObjectPoolBuilder {
            factory,
            reset,
            max_size,
        } = self;
        let factory = factory.expect("ObjectPoolBuilder: factory is required");
        let reset: Box<dyn Fn(&mut T) + Send + Sync> = match reset {
            Some(hook) => hook,
            None => Box::new(|_: &mut T| {}),
        };
        ObjectPool::new(factory, reset, max_size)
    }
}
/// Alias of [`ObjectPool`], which is already thread-safe (its state sits
/// behind an `Arc<Mutex<..>>`).
pub type SyncObjectPool<P> = ObjectPool<P>;
/// Alias of [`PoolGuard`], the guard returned by `SyncObjectPool::acquire`.
pub type SyncPoolGuard<P> = PoolGuard<P>;
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Pool of byte buffers (capacity 64) that are cleared on return and
    /// capped at `max` idle entries.
    fn vec_pool(max: usize) -> ObjectPool<Vec<u8>> {
        ObjectPool::builder()
            .factory(|| Vec::with_capacity(64))
            .reset(|v| v.clear())
            .max_size(max)
            .build()
    }

    #[test]
    fn pool_basic_acquire_and_return() {
        let pool = vec_pool(4);
        assert_eq!(pool.pool_size(), 0);
        {
            let mut guard = pool.acquire();
            guard.extend_from_slice(b"hello");
            assert_eq!(&*guard, b"hello");
            assert_eq!(pool.in_use_count(), 1);
        }
        assert_eq!(pool.pool_size(), 1);
        assert_eq!(pool.in_use_count(), 0);
        assert_eq!(pool.total_acquired(), 1);
    }

    #[test]
    fn pool_reset_is_called() {
        let pool = vec_pool(4);
        {
            let mut guard = pool.acquire();
            guard.push(1);
            guard.push(2);
        }
        let guard = pool.acquire();
        assert!(guard.is_empty(), "reset should have cleared the vec");
    }

    #[test]
    fn pool_respects_max_size() {
        let pool = vec_pool(2);
        assert_eq!(pool.max_size(), 2);
        let g1 = pool.acquire();
        let g2 = pool.acquire();
        let g3 = pool.acquire();
        drop(g1);
        drop(g2);
        // The third return finds the idle list full and is discarded.
        drop(g3);
        assert_eq!(pool.pool_size(), 2);
        assert_eq!(pool.in_use_count(), 0);
    }

    #[test]
    fn pool_zero_capacity_never_retains() {
        let pool = vec_pool(0);
        drop(pool.acquire());
        assert_eq!(pool.pool_size(), 0);
        assert_eq!(pool.in_use_count(), 0);
        assert_eq!(pool.total_acquired(), 1);
    }

    #[test]
    fn pool_reuses_returned_object() {
        let pool = vec_pool(4);
        let first_buffer = {
            let mut guard = pool.acquire();
            guard.push(1);
            let ptr = guard.as_ptr();
            ptr
        };
        let guard = pool.acquire();
        // Same heap buffer: the returned Vec was recycled rather than rebuilt.
        assert_eq!(guard.as_ptr(), first_buffer);
        assert!(guard.capacity() >= 64);
    }

    #[test]
    fn pool_warm_up() {
        let pool = vec_pool(8);
        pool.warm_up(5);
        assert_eq!(pool.pool_size(), 5);
    }

    #[test]
    fn pool_warm_up_is_clamped_to_max_size() {
        let pool = vec_pool(3);
        pool.warm_up(10);
        assert_eq!(pool.pool_size(), 3);
        pool.warm_up(2);
        assert_eq!(pool.pool_size(), 3);
        assert_eq!(pool.total_acquired(), 0);
    }

    #[test]
    fn pool_statistics() {
        let pool = vec_pool(8);
        let _g1 = pool.acquire();
        let _g2 = pool.acquire();
        assert_eq!(pool.total_acquired(), 2);
        assert_eq!(pool.in_use_count(), 2);
    }

    #[test]
    fn pool_clones_share_state() {
        let pool = vec_pool(4);
        let handle = pool.clone();
        drop(handle.acquire());
        assert_eq!(pool.pool_size(), 1);
        assert_eq!(pool.total_acquired(), 1);
        assert_eq!(handle.pool_size(), 1);
    }

    #[test]
    fn pool_builder_default_factory() {
        let pool: ObjectPool<Vec<i32>> = ObjectPool::builder()
            .default_factory()
            .reset(|v: &mut Vec<i32>| v.clear())
            .max_size(4)
            .build();
        let guard = pool.acquire();
        assert!(guard.is_empty());
    }

    #[test]
    fn pool_builder_default_max_size() {
        let pool: ObjectPool<u8> = ObjectPool::builder().factory(|| 0).build();
        assert_eq!(pool.max_size(), 32);
    }

    #[test]
    fn pool_multiple_guards_coexist() {
        let pool = vec_pool(4);
        let g1 = pool.acquire();
        let g2 = pool.acquire();
        let g3 = pool.acquire();
        assert_eq!(pool.in_use_count(), 3);
        assert_eq!(g1.len(), 0);
        assert_eq!(g2.len(), 0);
        assert_eq!(g3.len(), 0);
        drop(g1);
        drop(g2);
        drop(g3);
        assert_eq!(pool.pool_size(), 3);
    }

    #[test]
    fn sync_pool_basic() {
        let pool = SyncObjectPool::new(|| Vec::<u8>::with_capacity(32), |v| v.clear(), 8);
        {
            let mut guard = pool.acquire();
            guard.push(99);
            assert_eq!(*guard, vec![99]);
        }
        assert_eq!(pool.pool_size(), 1);
        assert_eq!(pool.in_use_count(), 0);
        assert_eq!(pool.total_acquired(), 1);
    }

    #[test]
    fn sync_pool_multithreaded() {
        let create_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&create_count);
        let pool = SyncObjectPool::new(
            move || {
                cc.fetch_add(1, Ordering::Relaxed);
                Vec::<u8>::new()
            },
            |v| v.clear(),
            4,
        );
        let mut handles = Vec::new();
        for _ in 0..8 {
            let p = pool.clone();
            handles.push(std::thread::spawn(move || {
                let mut guard = p.acquire();
                guard.push(1);
                std::thread::sleep(std::time::Duration::from_millis(5));
            }));
        }
        for h in handles {
            h.join().expect("thread panicked");
        }
        assert_eq!(pool.total_acquired(), 8);
        assert_eq!(pool.in_use_count(), 0);
        // Every acquisition either recycled an object or built one, so the
        // factory ran at least once and at most once per acquisition.
        let created = create_count.load(Ordering::SeqCst);
        assert!((1..=8).contains(&created), "created {} objects", created);
        assert!(pool.pool_size() <= 4);
        assert!(pool.pool_size() <= created);
    }

    #[test]
    fn sync_pool_warm_up() {
        let pool = SyncObjectPool::new(|| 0_u32, |_| {}, 8);
        pool.warm_up(5);
        assert_eq!(pool.pool_size(), 5);
    }

    #[test]
    fn pool_guard_deref() {
        let pool: ObjectPool<String> = ObjectPool::builder()
            .factory(String::new)
            .reset(|s| s.clear())
            .max_size(4)
            .build();
        let mut guard = pool.acquire();
        guard.push_str("hello");
        assert_eq!(guard.as_str(), "hello");
        assert_eq!(guard.len(), 5);
    }
}