use std::sync::{Arc, Mutex};
use crate::thread_state::ensure_thread_shard_key;
/// Contract for values that can be scrubbed and handed out again.
///
/// Implementors restore themselves to a reusable state in [`Resettable::reset`];
/// what "reusable" means is up to the implementing type.
#[doc(hidden)]
pub trait Resettable {
    /// Clears per-use state in place so the value can be reused.
    fn reset(&mut self);
}
/// One mutex-guarded free list of boxed values, bounded by `capacity`.
pub(crate) struct Pool<T>
where
    T: Resettable + Default + Send + 'static,
{
    /// Idle values waiting to be handed out again; used as a stack (push/pop at the end).
    items: Mutex<Vec<Box<T>>>,
    /// Upper bound on how many idle values `items` may hold.
    capacity: usize,
}
impl<T: Resettable + Default + Send + 'static> Pool<T> {
    /// Creates an empty pool that retains at most `capacity` idle values.
    ///
    /// No storage is reserved up front; the backing vector grows as values
    /// are returned.
    fn new(capacity: usize) -> Self {
        let items = Mutex::new(Vec::new());
        Self { items, capacity }
    }

    /// Hands out a boxed value, reusing an idle one when that is possible
    /// without waiting.
    ///
    /// The mutex is only *tried*: if the lock cannot be taken right now, or
    /// the free list is empty, a fresh `T::default()` is boxed instead. This
    /// method therefore always returns a value and never blocks on the lock.
    fn try_take(self: &Arc<Self>) -> Box<T> {
        // The guard lives only inside the closure, so the fallback allocation
        // below runs with the lock already released.
        let recycled = self
            .items
            .try_lock()
            .ok()
            .and_then(|mut stack| stack.pop());

        match recycled {
            Some(boxed) => boxed,
            None => Box::new(T::default()),
        }
    }

    /// Offers `value` back to the free list on a best-effort basis.
    ///
    /// The value is kept only when the lock is immediately available and
    /// fewer than `capacity` values are currently idle; in every other case
    /// it is simply dropped. The value is stored as-is: this method does not
    /// call `reset` on it.
    fn try_return(&self, value: Box<T>) {
        // A failed `try_lock` is deliberately ignored: giving up the value is
        // preferred over waiting for the lock.
        if let Ok(mut stack) = self.items.try_lock() {
            if stack.len() < self.capacity {
                stack.push(value);
            }
        }
    }
}
/// A pool split into several independently locked [`Pool`] shards.
pub(crate) struct ObjectPool<T>
where
    T: Resettable + Default + Send + 'static,
{
    /// The shards; the constructor always creates a power-of-two number of them.
    shards: Vec<Arc<Pool<T>>>,
    /// `shards.len() - 1`, so `key & shard_mask` is always a valid shard index.
    shard_mask: u64,
}
impl<T: Resettable + Default + Send + 'static> ObjectPool<T> {
pub fn new(shard_count: usize, per_shard_capacity: usize) -> Arc<Self> {
let n = shard_count.max(1).next_power_of_two();
let shards = (0..n)
.map(|_| Arc::new(Pool::new(per_shard_capacity)))
.collect();
Arc::new(Self {
shards,
shard_mask: (n as u64) - 1,
})
}
pub fn acquire(&self) -> ReuseRef<T> {
let key = ensure_thread_shard_key();
let shard = &self.shards[(key & self.shard_mask) as usize];
let boxed = shard.try_take();
ReuseRef {
value: Some(boxed),
pool: Arc::clone(shard),
}
}
}
/// Owning handle to a pooled value.
///
/// Dereferences to `T`; on drop the value is reset and offered back to the
/// shard it is associated with.
pub struct ReuseRef<T>
where
    T: Resettable + Default + Send + 'static,
{
    /// The checked-out value; `None` only once `Drop` has taken it out.
    value: Option<Box<T>>,
    /// Shard the value is offered back to on drop.
    pool: Arc<Pool<T>>,
}
/// Shared access to the checked-out value.
impl<T: Resettable + Default + Send + 'static> std::ops::Deref for ReuseRef<T> {
    type Target = T;

    /// Borrows the wrapped value.
    ///
    /// # Panics
    ///
    /// Panics if the value has already been taken out of the handle.
    #[inline]
    // `expect` is intentional: as far as this module shows, the slot is only
    // emptied inside `Drop`, so `None` here would be an invariant violation.
    #[allow(clippy::expect_used)]
    fn deref(&self) -> &T {
        let inner: Option<&T> = self.value.as_deref();
        inner.expect("ReuseRef value taken before drop")
    }
}
/// Exclusive access to the checked-out value.
impl<T: Resettable + Default + Send + 'static> std::ops::DerefMut for ReuseRef<T> {
    /// Mutably borrows the wrapped value.
    ///
    /// # Panics
    ///
    /// Panics if the value has already been taken out of the handle.
    #[inline]
    // `expect` is intentional: as far as this module shows, the slot is only
    // emptied inside `Drop`, so `None` here would be an invariant violation.
    #[allow(clippy::expect_used)]
    fn deref_mut(&mut self) -> &mut T {
        let inner: Option<&mut T> = self.value.as_deref_mut();
        inner.expect("ReuseRef value taken before drop")
    }
}
/// Cloning yields an independent handle tied to the same shard.
impl<T: Resettable + Default + Send + 'static + Clone> Clone for ReuseRef<T> {
    /// Clones the wrapped value into a newly allocated box.
    ///
    /// The copy is not taken from the pool, but it shares the original's
    /// shard handle and so is offered to that shard when it is dropped.
    fn clone(&self) -> Self {
        let source: &T = self;
        let copy = Box::new(source.clone());
        let pool = Arc::clone(&self.pool);

        Self {
            value: Some(copy),
            pool,
        }
    }
}
/// Returns the value to its shard when the handle goes away.
impl<T: Resettable + Default + Send + 'static> Drop for ReuseRef<T> {
    /// Resets the wrapped value and offers it back to the shard.
    ///
    /// Does nothing if the value has already been taken out. The reset runs
    /// before the value is handed to `try_return`, which may still discard it.
    fn drop(&mut self) {
        if let Some(mut boxed) = self.value.take() {
            boxed.reset();
            self.pool.try_return(boxed);
        }
    }
}
/// Formats exactly like the wrapped value; the handle itself adds nothing.
impl<T: Resettable + Default + Send + 'static + std::fmt::Debug> std::fmt::Debug for ReuseRef<T> {
    /// Delegates to the `Debug` implementation of the wrapped `T`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let inner: &T = self;
        <T as std::fmt::Debug>::fmt(inner, f)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test payload: a byte buffer plus an optional counter that is bumped
    /// each time a `Buf` is actually dropped.
    #[derive(Debug, Default, Clone)]
    struct Buf {
        bytes: Vec<u8>,
        drops: Option<Arc<AtomicUsize>>,
    }

    impl Resettable for Buf {
        /// Clears only the bytes; the drop counter stays attached.
        fn reset(&mut self) {
            self.bytes.clear();
        }
    }

    impl Drop for Buf {
        fn drop(&mut self) {
            if let Some(counter) = self.drops.as_ref() {
                counter.fetch_add(1, Ordering::Relaxed);
            }
        }
    }

    /// Attaches `counter` to the buffer behind `r` so its drop becomes observable.
    fn install_counter(r: &mut ReuseRef<Buf>, counter: &Arc<AtomicUsize>) {
        r.drops = Some(Arc::clone(counter));
    }

    /// Number of idle values currently held by `shard`.
    fn pooled_len(shard: &Pool<Buf>) -> usize {
        shard.items.lock().unwrap().len()
    }

    #[test]
    fn pool_starts_empty_and_grows() {
        let pool = ObjectPool::<Buf>::new(1, 4);
        let shard = &pool.shards[0];
        assert_eq!(pooled_len(shard), 0);

        let mut lease = pool.acquire();
        lease.bytes.extend_from_slice(b"x");
        drop(lease);

        assert_eq!(pooled_len(shard), 1);
    }

    #[test]
    fn acquire_then_drop_returns_to_shard() {
        let pool = ObjectPool::<Buf>::new(1, 4);
        let counter = Arc::new(AtomicUsize::new(0));

        let mut lease = pool.acquire();
        install_counter(&mut lease, &counter);
        lease.bytes.extend_from_slice(b"hello");
        assert_eq!(lease.bytes, b"hello");
        drop(lease);

        // The buffer went back to the shard instead of being destroyed.
        assert_eq!(counter.load(Ordering::Relaxed), 0);

        let next = pool.acquire();
        assert_eq!(next.bytes, b"");
    }

    #[test]
    fn full_shard_drops_overflow() {
        let pool = ObjectPool::<Buf>::new(1, 2);
        let counter = Arc::new(AtomicUsize::new(0));

        let mut leases = Vec::new();
        for _ in 0..4 {
            leases.push(pool.acquire());
        }
        leases
            .iter_mut()
            .for_each(|lease| install_counter(lease, &counter));

        // Four values come back to a shard that holds two: two must be dropped.
        drop(leases);
        assert_eq!(counter.load(Ordering::Relaxed), 2);
    }

    #[test]
    fn contended_acquire_still_returns_on_drop() {
        let pool = ObjectPool::<Buf>::new(1, 4);
        let shard = Arc::clone(&pool.shards[0]);

        // Acquire while the shard lock is held, then release the lock.
        let lease = {
            let _guard = shard.items.lock().unwrap();
            pool.acquire()
        };

        drop(lease);
        assert_eq!(pooled_len(&shard), 1);
    }
}