use std::{
collections::VecDeque,
mem::ManuallyDrop,
ops::{Deref, DerefMut},
sync::{Arc, Mutex},
};
/// A pool of reusable `T` values whose idle objects sit in a [`Mutex`]-guarded queue.
///
/// Idle objects form a FIFO queue: checkouts take from the front and returned
/// objects are pushed to the back. Objects are handed out through
/// [`PooledRef`] (borrows the pool) or [`PooledArc`] (holds an `Arc` to the
/// pool); dropping either handle gives the object back unless the queue
/// already holds `capacity` idle objects, in which case the object is dropped.
#[derive(Debug)]
pub struct ObjectPool<T> {
    /// Idle objects waiting to be reused.
    queue: Mutex<VecDeque<T>>,
    /// Maximum number of idle objects retained when handles are dropped;
    /// `None` means unbounded. The constructors do not clamp `initial_size`
    /// to this value.
    capacity: Option<usize>,
}

impl<T> ObjectPool<T> {
    /// Creates a pool pre-filled with `initial_size` objects, each built by
    /// [`AsPooled::create`] from a clone of `args`.
    ///
    /// `capacity` limits how many idle objects are kept when handles are
    /// dropped (`None` means no limit). It is not applied to `initial_size`,
    /// so a new pool may start out holding more than `capacity` objects.
    pub fn new<A>(args: A, initial_size: usize, capacity: Option<usize>) -> Self
    where
        T: AsPooled<A>,
        A: Clone,
    {
        let queue = (0..initial_size).map(|_| T::create(args.clone())).collect();
        Self {
            queue: Mutex::new(queue),
            capacity,
        }
    }

    /// Fallible counterpart of [`ObjectPool::new`] that builds the initial
    /// objects with [`TryAsPooled::try_create`].
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `try_create`. Any objects created
    /// before the failure are dropped and no pool is returned.
    pub fn try_new<A>(
        args: A,
        initial_size: usize,
        capacity: Option<usize>,
    ) -> Result<Self, T::Error>
    where
        T: TryAsPooled<A>,
        A: Clone,
    {
        let queue = (0..initial_size)
            .map(|_| T::try_create(args.clone()))
            .collect::<Result<_, _>>()?;
        Ok(Self {
            queue: Mutex::new(queue),
            capacity,
        })
    }

    /// Checks out an object prepared with `args` and returns a handle that
    /// borrows the pool.
    ///
    /// The front idle object is reused through [`TryAsPooled::try_modify`].
    /// If the pool is empty, a new object is built with
    /// [`TryAsPooled::try_create`].
    ///
    /// # Errors
    ///
    /// Returns the error from `try_modify` or `try_create`. A reused object
    /// whose `try_modify` fails is dropped rather than returned to the pool.
    pub fn try_get_ref<A>(&self, args: A) -> Result<PooledRef<'_, T>, T::Error>
    where
        T: TryAsPooled<A>,
    {
        let item = self.try_get_or_create(args)?;
        Ok(PooledRef {
            item: ManuallyDrop::new(item),
            parent: self,
        })
    }

    /// Infallible counterpart of [`ObjectPool::try_get_ref`]. It uses
    /// [`AsPooled::modify`] on a reused object or [`AsPooled::create`] when
    /// the pool is empty.
    pub fn get_ref<A>(&self, args: A) -> PooledRef<'_, T>
    where
        T: AsPooled<A>,
    {
        let item = self.get_or_create(args);
        PooledRef {
            item: ManuallyDrop::new(item),
            parent: self,
        }
    }

    /// Like [`ObjectPool::try_get_ref`], but the returned handle owns its own
    /// `Arc` to the pool instead of borrowing `self`.
    ///
    /// # Errors
    ///
    /// Same as [`ObjectPool::try_get_ref`].
    pub fn try_get<A>(self: &Arc<Self>, args: A) -> Result<PooledArc<T>, T::Error>
    where
        T: TryAsPooled<A>,
    {
        let item = self.try_get_or_create(args)?;
        Ok(PooledArc {
            item: ManuallyDrop::new(item),
            // Only bumps the reference count; the handle shares this pool.
            parent: Arc::clone(self),
        })
    }

    /// Infallible counterpart of [`ObjectPool::try_get`].
    pub fn get<A>(self: &Arc<Self>, args: A) -> PooledArc<T>
    where
        T: AsPooled<A>,
    {
        let item = self.get_or_create(args);
        PooledArc {
            item: ManuallyDrop::new(item),
            // Only bumps the reference count; the handle shares this pool.
            parent: Arc::clone(self),
        }
    }

    /// Returns the number of idle objects currently in the pool.
    ///
    /// The lock is released before this returns, so under concurrent use the
    /// value may already be stale.
    pub fn len(&self) -> usize {
        self.lock().len()
    }

    /// Returns `true` when the pool currently holds no idle objects.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Pops the front idle object and re-initialises it with `args`, or
    /// builds a new one if the pool is empty.
    fn try_get_or_create<A>(&self, args: A) -> Result<T, T::Error>
    where
        T: TryAsPooled<A>,
    {
        // The guard is a temporary dropped at the end of this statement, so
        // `try_modify` / `try_create` run without the lock held and a panic in
        // them cannot poison the pool.
        let maybe = self.lock().pop_front();
        if let Some(mut item) = maybe {
            // On error `item` is dropped here: a possibly half-modified object
            // is never handed out or put back.
            item.try_modify(args)?;
            Ok(item)
        } else {
            T::try_create(args)
        }
    }

    /// Pops the front idle object and re-initialises it with `args`, or
    /// builds a new one if the pool is empty.
    fn get_or_create<A>(&self, args: A) -> T
    where
        T: AsPooled<A>,
    {
        // The guard is a temporary dropped at the end of this statement, so
        // `modify` / `create` run without the lock held.
        let maybe = self.lock().pop_front();
        if let Some(mut item) = maybe {
            item.modify(args);
            item
        } else {
            T::create(args)
        }
    }

    /// Locks the idle queue, recovering the guard if the mutex is poisoned.
    ///
    /// The poison is cleared instead of being propagated. Within this module
    /// the lock is held only around queue operations (`pop_front`,
    /// `push_back`, `len`), never while running `T`'s trait hooks or
    /// destructor.
    fn lock(&self) -> std::sync::MutexGuard<'_, VecDeque<T>> {
        match self.queue.lock() {
            Ok(guard) => guard,
            Err(poisoned) => {
                self.queue.clear_poison();
                poisoned.into_inner()
            }
        }
    }
}
/// Fallible protocol for creating and recycling objects held by an [`ObjectPool`].
///
/// The same argument type `A` is used both to build a fresh object and to
/// re-initialise one taken from the pool. The pool calls these methods
/// without holding its internal lock.
pub trait TryAsPooled<A>
where
    // Needed because `try_create` returns `Self` by value inside a `Result`.
    Self: Sized,
{
    /// Error returned when creating or re-initialising an object fails.
    type Error;
    /// Builds a new object from `args`.
    fn try_create(args: A) -> Result<Self, Self::Error>;
    /// Re-initialises an object taken from the pool before it is handed out.
    ///
    /// On `Err`, the pool drops the object instead of handing it out or
    /// putting it back.
    fn try_modify(&mut self, args: A) -> Result<(), Self::Error>;
}
/// Infallible protocol for creating and recycling objects held by an [`ObjectPool`].
///
/// This is like [`TryAsPooled`], but construction and re-initialisation
/// cannot fail. The pool calls these methods without holding its internal
/// lock. If `modify` panics, the object has already been removed from the
/// pool and is dropped during unwinding.
pub trait AsPooled<A> {
    /// Builds a new object from `args`.
    fn create(args: A) -> Self;
    /// Re-initialises an object taken from the pool before it is handed out.
    fn modify(&mut self, args: A);
}
/// Requests a buffer of a given length without specifying its contents.
///
/// Used as the pooling argument for `Vec<T>`. Value semantics are derived,
/// so requests can be compared, hashed and defaulted (`len == 0`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct Undef {
    /// Requested number of elements.
    pub len: usize,
}

impl Undef {
    /// Creates a request for `len` elements.
    ///
    /// This is a `const fn`, so requests can be declared as constants.
    pub const fn new(len: usize) -> Self {
        Self { len }
    }
}
impl<T> AsPooled<Undef> for Vec<T>
where
T: Default + Clone,
{
fn create(undef: Undef) -> Self {
vec![T::default(); undef.len]
}
fn modify(&mut self, undef: Undef) {
self.resize(undef.len, T::default())
}
}
/// Handle to an object checked out of an [`ObjectPool`], borrowing the pool.
///
/// Dereferences to `T`. On drop, the object goes back to the pool if it
/// holds fewer than `capacity` idle objects. Otherwise the object is dropped
/// after the pool's lock has been released.
#[derive(Debug)]
pub struct PooledRef<'a, T> {
    /// The checked-out object, wrapped so `Drop` can move it out.
    item: ManuallyDrop<T>,
    /// Pool the object is returned to.
    parent: &'a ObjectPool<T>,
}

impl<T> Drop for PooledRef<'_, T> {
    fn drop(&mut self) {
        let limit = self.parent.capacity.unwrap_or(usize::MAX);
        let mut queue = self.parent.lock();
        if queue.len() >= limit {
            // Pool is full. Release the lock *before* running `T`'s destructor
            // so a panicking destructor cannot poison the pool's mutex.
            std::mem::drop(queue);
            // SAFETY: `self.item` is still initialised, and this is the only
            // place it is dropped; nothing touches it afterwards.
            unsafe { ManuallyDrop::drop(&mut self.item) };
            return;
        }
        // SAFETY: `self.item` is still initialised. We are inside `Drop`, so
        // the emptied slot is never observed again.
        let recycled = unsafe { ManuallyDrop::take(&mut self.item) };
        queue.push_back(recycled);
    }
}

impl<T> Deref for PooledRef<'_, T> {
    type Target = T;
    fn deref(&self) -> &T {
        self.item.deref()
    }
}

impl<T> DerefMut for PooledRef<'_, T> {
    fn deref_mut(&mut self) -> &mut T {
        self.item.deref_mut()
    }
}
/// Handle to an object checked out of an [`ObjectPool`] that keeps the pool
/// alive through its own `Arc`.
///
/// Dereferences to `T`. On drop, the object goes back to the pool if it
/// holds fewer than `capacity` idle objects. Otherwise the object is dropped
/// after the pool's lock has been released.
#[derive(Debug)]
pub struct PooledArc<T> {
    /// The checked-out object, wrapped so `Drop` can move it out.
    item: ManuallyDrop<T>,
    /// Shared owner of the pool the object is returned to.
    parent: Arc<ObjectPool<T>>,
}

impl<T> Drop for PooledArc<T> {
    fn drop(&mut self) {
        let mut idle = self.parent.lock();
        let has_room = idle.len() < self.parent.capacity.unwrap_or(usize::MAX);
        if has_room {
            // SAFETY: `self.item` is still initialised. We are inside `Drop`,
            // so the emptied slot is never observed again.
            let recycled = unsafe { ManuallyDrop::take(&mut self.item) };
            idle.push_back(recycled);
        } else {
            // Unlock first so a panicking `T` destructor cannot poison the pool.
            std::mem::drop(idle);
            // SAFETY: `self.item` is still initialised, and this is the only
            // place it is dropped; nothing touches it afterwards.
            unsafe { ManuallyDrop::drop(&mut self.item) };
        }
    }
}

impl<T> Deref for PooledArc<T> {
    type Target = T;
    fn deref(&self) -> &T {
        Deref::deref(&self.item)
    }
}

impl<T> DerefMut for PooledArc<T> {
    fn deref_mut(&mut self) -> &mut T {
        DerefMut::deref_mut(&mut self.item)
    }
}
/// Either a standalone `T` or a `T` checked out of an [`ObjectPool`].
///
/// Both variants dereference to `T`, so callers can use them the same way.
/// Only the `Pooled` variant hands its object back to a pool when dropped.
#[derive(Debug)]
pub enum PoolOption<T> {
    /// A value not associated with any pool.
    NonPooled(T),
    /// A value checked out through [`ObjectPool::get`] or [`ObjectPool::try_get`].
    Pooled(PooledArc<T>),
}

impl<T> PoolOption<T> {
    /// Wraps an already-built value without involving a pool.
    pub fn non_pooled(item: T) -> Self {
        Self::NonPooled(item)
    }

    /// Builds a standalone value with [`TryAsPooled::try_create`].
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `try_create`.
    pub fn try_non_pooled_create<A>(args: A) -> Result<Self, T::Error>
    where
        T: TryAsPooled<A>,
    {
        T::try_create(args).map(Self::NonPooled)
    }

    /// Builds a standalone value with [`AsPooled::create`].
    pub fn non_pooled_create<A>(args: A) -> Self
    where
        T: AsPooled<A>,
    {
        let fresh = T::create(args);
        Self::NonPooled(fresh)
    }

    /// Checks out a value from `pool` via [`ObjectPool::get`].
    pub fn pooled<A>(pool: &Arc<ObjectPool<T>>, args: A) -> Self
    where
        T: AsPooled<A>,
    {
        let handle = pool.get(args);
        Self::Pooled(handle)
    }

    /// Checks out a value from `pool` via [`ObjectPool::try_get`].
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `try_get`.
    pub fn try_pooled<A>(pool: &Arc<ObjectPool<T>>, args: A) -> Result<Self, T::Error>
    where
        T: TryAsPooled<A>,
    {
        pool.try_get(args).map(Self::Pooled)
    }

    /// Returns `true` when the value came from a pool.
    pub fn is_pooled(&self) -> bool {
        matches!(self, Self::Pooled(_))
    }

    /// Returns `true` when the value is standalone.
    pub fn is_non_pooled(&self) -> bool {
        !self.is_pooled()
    }
}

impl<T> Deref for PoolOption<T> {
    type Target = T;
    fn deref(&self) -> &T {
        match self {
            Self::NonPooled(owned) => owned,
            Self::Pooled(handle) => handle.deref(),
        }
    }
}

impl<T> DerefMut for PoolOption<T> {
    fn deref_mut(&mut self) -> &mut T {
        match self {
            Self::NonPooled(owned) => owned,
            Self::Pooled(handle) => handle.deref_mut(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Pool payload whose destructor can be told to panic via `panic_on_drop`.
    #[derive(Debug)]
    struct TestItem {
        value: Box<u32>,
        panic_on_drop: bool,
    }

    impl TestItem {
        fn new(value: u32) -> Self {
            Self {
                value: Box::new(value),
                panic_on_drop: false,
            }
        }
    }

    impl AsPooled<u32> for TestItem {
        fn create(value: u32) -> Self {
            TestItem::new(value)
        }
        fn modify(&mut self, value: u32) {
            *self.value = value;
            self.panic_on_drop = false;
        }
    }

    // Negative inputs fail the `i32 -> u32` conversion, exercising error paths.
    impl TryAsPooled<i32> for TestItem {
        type Error = ();
        fn try_create(value: i32) -> Result<Self, Self::Error> {
            match value.try_into() {
                Ok(v) => Ok(TestItem::new(v)),
                Err(_) => Err(()),
            }
        }
        fn try_modify(&mut self, value: i32) -> Result<(), Self::Error> {
            match value.try_into() {
                Ok(v) => {
                    *self.value = v;
                    self.panic_on_drop = false;
                    Ok(())
                }
                Err(_) => Err(()),
            }
        }
    }

    impl Drop for TestItem {
        fn drop(&mut self) {
            if self.panic_on_drop {
                panic!("panicking on drop");
            }
        }
    }

    // Argument whose every trait hook panics. Used to check that the pool
    // never holds its lock while calling into user code.
    struct TestPanic;

    impl AsPooled<TestPanic> for TestItem {
        fn create(_: TestPanic) -> Self {
            panic!("panicking on create")
        }
        fn modify(&mut self, _: TestPanic) {
            panic!("panicking on modify")
        }
    }

    impl TryAsPooled<TestPanic> for TestItem {
        type Error = ();
        fn try_create(_: TestPanic) -> Result<Self, Self::Error> {
            panic!("panicking on try_create")
        }
        fn try_modify(&mut self, _: TestPanic) -> Result<(), Self::Error> {
            panic!("panicking on try_modify")
        }
    }

    #[test]
    fn test_pool_basic_tests() {
        let pool = ObjectPool::<TestItem>::new(42, 2, None);
        assert_eq!(pool.len(), 2);

        // Pre-built objects are handed out first...
        let item1 = pool.get_ref(100);
        assert_eq!(*item1.value, 100);
        assert_eq!(pool.len(), 1);
        let item2 = pool.get_ref(200);
        assert_eq!(*item2.value, 200);
        assert_eq!(pool.len(), 0);

        // ...then new ones are created once the pool is empty.
        let item = pool.get_ref(300);
        assert_eq!(*item.value, 300);
        assert_eq!(pool.len(), 0);

        {
            let item = pool.get_ref(400);
            assert_eq!(*item.value, 400);
            assert_eq!(pool.len(), 0);
        }
        // Dropping a handle returns its object to the (unbounded) pool.
        assert_eq!(pool.len(), 1);

        {
            let item_a = pool.get_ref(500);
            assert_eq!(*item_a.value, 500);
            assert_eq!(pool.len(), 0);
            let item_b = pool.get_ref(600);
            assert_eq!(*item_b.value, 600);
            assert_eq!(pool.len(), 0);
        }
        assert_eq!(pool.len(), 2);

        let pool = ObjectPool::<TestItem>::new(42, 1, None);
        let item = pool.get_ref(100);
        assert_eq!(*item.value, 100);
    }

    #[test]
    fn test_pool_basic_tests_with_try() {
        let pool_result = ObjectPool::<TestItem>::try_new(-1, 2, Some(100));
        assert!(
            pool_result.is_err(),
            "Pool creation should fail with negative args"
        );

        let pool = ObjectPool::<TestItem>::try_new(42, 2, None).unwrap();
        assert_eq!(pool.len(), 2);

        let item1 = pool.try_get_ref(100).unwrap();
        assert_eq!(*item1.value, 100);
        assert_eq!(pool.len(), 1);
        let item2 = pool.try_get_ref(200).unwrap();
        assert_eq!(*item2.value, 200);
        assert_eq!(pool.len(), 0);

        let item = pool.try_get_ref(300).unwrap();
        assert_eq!(*item.value, 300);
        assert_eq!(pool.len(), 0);

        {
            let item = pool.try_get_ref(400).unwrap();
            assert_eq!(*item.value, 400);
            assert_eq!(pool.len(), 0);
        }
        assert_eq!(pool.len(), 1);

        {
            let item_a = pool.try_get_ref(500).unwrap();
            assert_eq!(*item_a.value, 500);
            assert_eq!(pool.len(), 0);
            let item_b = pool.try_get_ref(600).unwrap();
            assert_eq!(*item_b.value, 600);
            assert_eq!(pool.len(), 0);
        }
        assert_eq!(pool.len(), 2);

        let pool = ObjectPool::<TestItem>::try_new(42, 1, Some(100)).unwrap();
        let item = pool.try_get_ref(100).unwrap();
        assert_eq!(*item.value, 100);
    }

    #[test]
    fn test_pool_with_arc() {
        let pool = &Arc::new(ObjectPool::<TestItem>::new(42, 1, None));
        let item = pool.get(100);
        assert_eq!(*item.value, 100);
        assert_eq!(pool.len(), 0);
        let item = pool.get(200);
        assert_eq!(*item.value, 200);
        assert_eq!(pool.len(), 0);

        {
            let item = pool.get(400);
            assert_eq!(*item.value, 400);
            assert_eq!(pool.len(), 0);
        }
        assert_eq!(pool.len(), 1);

        // Borrowing handles and `Arc` handles share the same queue.
        let item = pool.try_get_ref(400).unwrap();
        assert_eq!(*item.value, 400);
        assert_eq!(pool.len(), 0);
        let item = pool.try_get(500).unwrap();
        assert_eq!(*item.value, 500);
    }

    #[test]
    fn test_pool_max_capacity_ref() {
        let pool = ObjectPool::<TestItem>::new(42, 1, Some(1));
        assert_eq!(pool.len(), 1);
        assert!(!pool.is_empty());
        assert_eq!(pool.len(), pool.capacity.unwrap());

        {
            let item = pool.get_ref(100);
            assert_eq!(pool.len(), 0);
            assert!(pool.is_empty());
            assert!(pool.len() < pool.capacity.unwrap());
            assert_eq!(*item.value, 100);
        }
        assert_eq!(pool.len(), 1);
        assert_eq!(pool.len(), pool.capacity.unwrap());

        {
            let item1 = pool.get_ref(100);
            assert_eq!(pool.len(), 0);
            let item2 = pool.get_ref(200);
            assert_eq!(pool.len(), 0);
            let item3 = pool.get_ref(300);
            assert_eq!(pool.len(), 0);
            assert!(*item1.value == 100 && *item2.value == 200 && *item3.value == 300);
        }
        // Only one of the three returned objects fits; the others are dropped.
        assert_eq!(pool.len(), pool.capacity.unwrap());
        assert_eq!(pool.len(), 1);
    }

    #[test]
    fn test_pool_max_capacity_pooled_item() {
        let pool = &Arc::new(ObjectPool::<TestItem>::new(42, 1, Some(1)));
        assert_eq!(pool.len(), 1);
        assert_eq!(pool.len(), pool.capacity.unwrap());

        {
            let item = pool.get(100);
            assert_eq!(pool.len(), 0);
            assert!(pool.len() < pool.capacity.unwrap());
            assert_eq!(*item.value, 100);
        }
        assert_eq!(pool.len(), 1);
        assert_eq!(pool.len(), pool.capacity.unwrap());

        {
            let item1 = pool.get(100);
            assert_eq!(pool.len(), 0);
            let item2 = pool.get(200);
            assert_eq!(pool.len(), 0);
            let item3 = pool.get(300);
            assert_eq!(pool.len(), 0);
            assert!(*item1.value == 100 && *item2.value == 200 && *item3.value == 300);
        }
        // Only one of the three returned objects fits; the others are dropped.
        assert_eq!(pool.len(), pool.capacity.unwrap());
        assert_eq!(pool.len(), 1);
    }

    #[test]
    fn test_pool_options() {
        let item = PoolOption::non_pooled(TestItem::new(42));
        assert_eq!(*item.value, 42);
        assert!(item.is_non_pooled());
        assert!(!item.is_pooled());

        let item = PoolOption::<TestItem>::non_pooled_create(100);
        assert_eq!(*item.value, 100);
        assert!(item.is_non_pooled());

        let pool = Arc::new(ObjectPool::<TestItem>::new(42, 1, None));
        let item = PoolOption::pooled(&pool, 100);
        assert_eq!(*item.value, 100);
        assert!(item.is_pooled());
        assert!(!item.is_non_pooled());

        let item = PoolOption::try_pooled(&pool, 100).unwrap();
        assert_eq!(*item.value, 100);
        assert!(item.is_pooled());

        let item = PoolOption::<TestItem>::try_non_pooled_create(200).unwrap();
        assert_eq!(*item.value, 200);
        assert!(item.is_non_pooled());
        assert!(!item.is_pooled());

        let item_result = PoolOption::<TestItem>::try_non_pooled_create(-200);
        assert!(
            item_result.is_err(),
            "Creating non-pooled item with negative args should fail"
        );
    }

    #[test]
    fn test_pool_ref_deref_mut() {
        // `PooledRef`
        let pool = ObjectPool::<TestItem>::new(42, 1, None);
        let item = pool.get_ref(100);
        assert_eq!(*item.value, 100);
        let mut item = pool.get_ref(100);
        *item.value = 200;
        assert_eq!(*item.value, 200);
        let item_ref: &TestItem = &item;
        assert_eq!(*item_ref.value, 200);
        let item_ref_mut: &mut TestItem = &mut item;
        assert_eq!(*item_ref_mut.value, 200);
        *item_ref_mut.value = 300;
        assert_eq!(*item_ref_mut.value, 300);

        // `PooledArc` (previously this used `get_ref`, so `PooledArc`'s
        // `DerefMut` was never exercised directly).
        let pool = &Arc::new(ObjectPool::<TestItem>::new(42, 1, None));
        let mut item = pool.get(100);
        assert_eq!(*item.value, 100);
        *item.value = 200;
        assert_eq!(*item.value, 200);

        // `PoolOption::Pooled`
        let pool = Arc::new(ObjectPool::<TestItem>::new(42, 1, None));
        let mut item = PoolOption::pooled(&pool, 100);
        assert_eq!(*item.value, 100);
        *item.value = 200;
        assert_eq!(*item.value, 200);

        // `PoolOption::NonPooled`
        let mut item = PoolOption::non_pooled(TestItem::new(42));
        assert_eq!(*item.value, 42);
        *item.value = 100;
        assert_eq!(*item.value, 100);
    }

    #[test]
    fn test_try_modify_error_discards_item() {
        // A reused object whose `try_modify` fails is dropped, not re-pooled.
        let pool = ObjectPool::<TestItem>::new(0u32, 1, None);
        assert!(pool.try_get_ref(-1).is_err());
        assert_eq!(pool.len(), 0);

        // With an empty pool the error comes from `try_create` instead.
        let pool = Arc::new(ObjectPool::<TestItem>::new(0u32, 0, None));
        assert!(pool.try_get(-1).is_err());
        assert_eq!(pool.len(), 0);
    }

    #[test]
    fn test_vec_pool_reuses_buffers() {
        let pool = ObjectPool::<Vec<u8>>::new(Undef::new(4), 1, Some(1));
        {
            let buf = pool.get_ref(Undef::new(8));
            assert_eq!(buf.len(), 8);
            assert_eq!(pool.len(), 0);
        }
        assert_eq!(pool.len(), 1);
        let buf = pool.get_ref(Undef::new(2));
        assert_eq!(buf.len(), 2);
        assert_eq!(pool.len(), 0);
    }

    // Asserts that a caught panic payload is a `&'static str` containing `contains`.
    fn check_error(err: &dyn std::any::Any, contains: &str) {
        match err.downcast_ref::<&'static str>() {
            Some(msg) => assert!(
                msg.contains(contains),
                "failed: message \"{}\" does not contain \"{}\"",
                msg,
                contains
            ),
            None => panic!("incorrect downcast type"),
        }
    }

    #[test]
    fn test_panic_during_create() {
        let pool = ObjectPool::<TestItem>::new(0u32, 0, Some(1));
        let err = std::panic::catch_unwind(|| {
            let _ = pool.get_ref(TestPanic);
        })
        .unwrap_err();
        check_error(&*err, "panicking on create");
        assert!(
            !pool.queue.is_poisoned(),
            "lock should be released while calling trait implementations"
        );
        assert_eq!(pool.len(), 0);
    }

    #[test]
    fn test_panic_during_try_create() {
        let pool = ObjectPool::<TestItem>::new(0u32, 0, Some(1));
        let err = std::panic::catch_unwind(|| {
            let _ = pool.try_get_ref(TestPanic);
        })
        .unwrap_err();
        check_error(&*err, "panicking on try_create");
        assert!(
            !pool.queue.is_poisoned(),
            "lock should be released while calling trait implementations"
        );
        assert_eq!(pool.len(), 0);
    }

    #[test]
    fn test_panic_during_modify() {
        let pool = ObjectPool::<TestItem>::new(0u32, 0, Some(1));
        let _ = pool.get_ref(0u32);
        assert_eq!(pool.len(), 1);
        let err = std::panic::catch_unwind(|| {
            let _ = pool.get_ref(TestPanic);
        })
        .unwrap_err();
        check_error(&*err, "panicking on modify");
        assert!(
            !pool.queue.is_poisoned(),
            "lock should be released while calling trait implementations"
        );
        assert_eq!(
            pool.len(),
            0,
            "we should not return a potentially torn object to the pool"
        );
    }

    #[test]
    fn test_panic_during_try_modify() {
        let pool = ObjectPool::<TestItem>::new(0u32, 0, Some(1));
        let _ = pool.get_ref(0u32);
        assert_eq!(pool.len(), 1);
        let err = std::panic::catch_unwind(|| {
            let _ = pool.try_get_ref(TestPanic);
        })
        .unwrap_err();
        check_error(&*err, "panicking on try_modify");
        assert!(
            !pool.queue.is_poisoned(),
            "lock should be released while calling trait implementations"
        );
        assert_eq!(
            pool.len(),
            0,
            "we should not return a potentially torn object to the pool"
        );
    }

    #[test]
    fn test_panic_during_drop_ref() {
        let pool = ObjectPool::<TestItem>::new(0u32, 0, Some(1));
        let mut a = pool.get_ref(0u32);
        // Fill the pool to capacity so `a` is destroyed rather than re-pooled.
        let _ = pool.get_ref(1u32);
        assert_eq!(pool.len(), 1);
        a.panic_on_drop = true;
        let err = std::panic::catch_unwind(move || std::mem::drop(a)).unwrap_err();
        check_error(&*err, "panicking on drop");
        assert!(
            !pool.queue.is_poisoned(),
            "lock should be released while calling object drop"
        );
        assert_eq!(pool.len(), 1);
    }

    #[test]
    fn test_panic_during_drop_arc() {
        let pool = Arc::new(ObjectPool::<TestItem>::new(0u32, 0, Some(1)));
        let mut a = pool.get(0u32);
        // Fill the pool to capacity so `a` is destroyed rather than re-pooled.
        let _ = pool.get(1u32);
        assert_eq!(pool.len(), 1);
        a.panic_on_drop = true;
        let err = std::panic::catch_unwind(move || std::mem::drop(a)).unwrap_err();
        check_error(&*err, "panicking on drop");
        assert!(
            !pool.queue.is_poisoned(),
            "lock should be released while calling object drop"
        );
        assert_eq!(pool.len(), 1);
    }

    #[test]
    fn test_panic_recovery() {
        let pool = ObjectPool::<TestItem>::new(0u32, 1, Some(1));
        let err = std::panic::catch_unwind(|| {
            let _guard = pool.queue.lock();
            panic!("yeet");
        })
        .unwrap_err();
        check_error(&*err, "yeet");
        assert!(pool.queue.is_poisoned());
        let _ = pool.get_ref(1u32);
        assert!(!pool.queue.is_poisoned(), "poison should be cleared");
    }

    #[test]
    fn test_undef() {
        let mut x: Vec<f32> = Vec::<f32>::create(Undef::new(10));
        assert_eq!(x.len(), 10);
        x.modify(Undef::new(0));
        assert_eq!(x.len(), 0);
        x.modify(Undef::new(20));
        assert_eq!(x.len(), 20);
    }
}