mod concurrent_limited_multimap;
use std::hash::Hash;
use std::sync::{Arc, RwLock};
use slab::Slab;
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use self::concurrent_limited_multimap::ConcurrentLimitedMultimap;
/// A keyed pool of reusable values.
///
/// Checkouts can be capped globally (see `new`) and, optionally, by any number
/// of "local" limits that callers opt into by passing the index returned from
/// `set_local_limit`.
pub struct Pool<K, I> {
    /// Idle values keyed by `K`. Every outstanding [`Item`] holds a clone of this
    /// `Arc` so it can put its value back when dropped.
    inner: Arc<ConcurrentLimitedMultimap<K, I, ahash::RandomState>>,
    /// Global cap on concurrently checked-out items; `None` for unbounded pools.
    semaphore: Option<Arc<Semaphore>>,
    /// Registered local limits, addressed by the slab index handed out by
    /// `set_local_limit`.
    local_limits: RwLock<Slab<Arc<Semaphore>>>,
}
impl<K, I> Pool<K, I>
where
K: Eq + std::hash::Hash,
{
pub fn new(capacity: usize) -> Self {
Self {
inner: Arc::new(ConcurrentLimitedMultimap::with_hasher(
capacity,
ahash::RandomState::new(),
)),
semaphore: Some(Arc::new(Semaphore::new(capacity))),
local_limits: RwLock::new(Slab::new()),
}
}
pub fn new_unbounded() -> Self {
Self {
inner: Arc::new(ConcurrentLimitedMultimap::with_hasher_unbounded(
ahash::RandomState::new(),
)),
semaphore: None,
local_limits: RwLock::new(Slab::new()),
}
}
pub fn set_local_limit(&self, limit: usize) -> usize {
let mut local_limits = self.local_limits.write().expect("local limits lock poisoned");
local_limits.insert(Arc::new(Semaphore::new(limit)))
}
pub async fn pull(&self, key: K) -> Item<K, I> {
self.pull_with_wait_local_limit(key, None).await
}
pub async fn pull_with_local_limit(&self, key: K, local_limit_index: Option<usize>) -> Option<Item<K, I>> {
let local_guard = if let Some(index) = local_limit_index {
let local_limits = self.local_limits.read().expect("local limits lock poisoned");
if let Some(semaphore) = local_limits.get(index) {
let semaphore = semaphore.clone();
drop(local_limits);
Some(semaphore.try_acquire_owned().ok()?)
} else {
None
}
} else {
None
};
let guard = if let Some(semaphore) = &self.semaphore {
Some(semaphore.clone().acquire_owned().await.expect("semaphore closed"))
} else {
None
};
let key = Arc::new(key);
let inner_value = self.inner.remove(key.clone());
Some(Item {
pool_inner: self.inner.clone(),
key: Some(key),
inner: inner_value,
_guard: guard,
_local_guard: local_guard,
})
}
#[allow(clippy::await_holding_lock)]
pub async fn pull_with_wait_local_limit(&self, key: K, local_limit_index: Option<usize>) -> Item<K, I> {
let local_guard = if let Some(index) = local_limit_index {
let local_limits = self.local_limits.read().expect("local limits lock poisoned");
if let Some(semaphore) = local_limits.get(index) {
let semaphore = semaphore.clone();
drop(local_limits); Some(semaphore.acquire_owned().await.expect("semaphore closed"))
} else {
None
}
} else {
None
};
let guard = if let Some(semaphore) = &self.semaphore {
Some(semaphore.clone().acquire_owned().await.expect("semaphore closed"))
} else {
None
};
let key = Arc::new(key);
let inner_value = self.inner.remove(key.clone());
Item {
pool_inner: self.inner.clone(),
key: Some(key),
inner: inner_value,
_guard: guard,
_local_guard: local_guard,
}
}
pub fn try_pull(&self, key: K) -> Option<Item<K, I>> {
self.try_pull_with_local_limit(key, None)
}
pub fn try_pull_with_local_limit(&self, key: K, local_limit_index: Option<usize>) -> Option<Item<K, I>> {
let local_guard = if let Some(index) = local_limit_index {
let local_limits = self.local_limits.read().expect("local limits lock poisoned");
if let Some(semaphore) = local_limits.get(index) {
let semaphore = semaphore.clone();
drop(local_limits);
Some(semaphore.try_acquire_owned().ok()?)
} else {
None
}
} else {
None
};
let guard = if let Some(semaphore) = &self.semaphore {
Some(semaphore.clone().try_acquire_owned().ok()?)
} else {
None
};
let key = Arc::new(key);
let inner_value = self.inner.remove(key.clone());
Some(Item {
pool_inner: self.inner.clone(),
key: Some(key),
inner: inner_value,
_guard: guard,
_local_guard: local_guard,
})
}
}
/// A checked-out pool entry.
///
/// Holds whatever value the pool's `remove` returned for `key` (possibly none),
/// together with the semaphore permits acquired when it was pulled. On drop, a
/// value still present in `inner` is inserted back into the pool under `key`,
/// and the permits are released along with the struct's fields.
pub struct Item<K: Eq + Hash, I> {
// Shared handle to the pool's backing map, used by `Drop` to reinsert the value.
pool_inner: Arc<ConcurrentLimitedMultimap<K, I, ahash::RandomState>>,
// Key the value was pulled under; set to `Some` at checkout and only taken by `Drop`.
key: Option<Arc<K>>,
// The pooled value: `None` if the pool had nothing for `key` or it was moved out via `take`.
inner: Option<I>,
// Global-limit permit (absent for unbounded pools).
// NOTE: fields drop in declaration order, so this permit is released before `_local_guard`;
// keep that in mind before reordering.
_guard: Option<OwnedSemaphorePermit>,
// Local-limit permit, present only when a registered local limit was requested.
_local_guard: Option<OwnedSemaphorePermit>,
}
impl<K: Eq + Hash, I> Item<K, I> {
    /// Consumes the handle and hands back the pooled value, if any.
    ///
    /// The value is moved out before the handle is dropped, so there is nothing
    /// left to reinsert and the value leaves the pool for good. The permits held
    /// by the handle are still released.
    pub fn take(mut self) -> Option<I> {
        std::mem::take(&mut self.inner)
    }

    /// Borrows the value slot; `None` when the pool had no value for the key.
    pub fn inner(&self) -> &Option<I> {
        &self.inner
    }

    /// Mutably borrows the value slot.
    ///
    /// Whatever is left here when the item is dropped goes back into the pool;
    /// leaving `None` returns nothing.
    pub fn inner_mut(&mut self) -> &mut Option<I> {
        &mut self.inner
    }
}
impl<K: Eq + Hash, I> Drop for Item<K, I> {
    /// Puts a still-held value back into the pool under the key it was pulled with.
    fn drop(&mut self) {
        let Some(value) = self.inner.take() else {
            // Empty checkout, or the value was moved out via `take`: nothing to return.
            return;
        };
        let key = self.key.take().expect("key not set");
        self.pool_inner.insert(key, value);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Free permits on a bounded pool's global semaphore.
    fn global_permits(pool: &Pool<String, u32>) -> usize {
        pool.semaphore
            .as_ref()
            .expect("pool is bounded")
            .available_permits()
    }

    /// Free permits on the local limit registered at `index`.
    fn local_permits(pool: &Pool<String, u32>, index: usize) -> usize {
        pool.local_limits.read().expect("lock poisoned")[index].available_permits()
    }

    #[tokio::test]
    async fn test_pool_new() {
        let pool = Pool::<String, u32>::new(10);
        assert_eq!(global_permits(&pool), 10);
    }

    #[tokio::test]
    async fn test_pool_new_unbounded() {
        let pool = Pool::<String, u32>::new_unbounded();
        assert!(pool.semaphore.is_none());
        // With no global limit, concurrent checkouts never fail.
        let first = pool.try_pull("key1".to_string());
        let second = pool.try_pull("key2".to_string());
        assert!(first.is_some());
        assert!(second.is_some());
    }

    #[tokio::test]
    async fn test_pool_pull_and_take() {
        let pool = Pool::<String, u32>::new(1);
        let item = pool.pull("key1".to_string()).await;
        assert!(item.take().is_none());
    }

    #[tokio::test]
    async fn test_pool_pull_and_replace() {
        let pool = Pool::<String, u32>::new(1);
        let mut item = pool.pull("key1".to_string()).await;
        *item.inner_mut() = Some(42);
        assert_eq!(item.inner(), &Some(42));
    }

    #[tokio::test]
    async fn test_dropped_item_returns_to_pool_and_take_removes_it() {
        let pool = Pool::<String, u32>::new(1);
        {
            let mut item = pool.pull("key1".to_string()).await;
            *item.inner_mut() = Some(7);
        }
        // Dropping the item reinserted its value.
        let item = pool.pull("key1".to_string()).await;
        assert_eq!(item.take(), Some(7));
        // `take` moved the value out, so nothing was reinserted.
        let item = pool.pull("key1".to_string()).await;
        assert!(item.inner().is_none());
    }

    #[tokio::test]
    async fn test_pool_eviction_behavior() {
        let pool = Pool::<String, u32>::new(2);
        {
            let mut item1 = pool.pull("key1".to_string()).await;
            *item1.inner_mut() = Some(1);
        }
        {
            let mut item2 = pool.pull("key2".to_string()).await;
            *item2.inner_mut() = Some(2);
        }
        {
            let _item1 = pool.pull("key1".to_string()).await;
        }
        {
            let mut item3 = pool.pull("key3".to_string()).await;
            *item3.inner_mut() = Some(3);
        }
        // Each probe's item is dropped at the end of the `if` condition, so the
        // value goes back before the next key is probed.
        let mut num_entries = 0;
        for key in ["key1", "key2", "key3"] {
            if pool.pull(key.to_string()).await.inner().is_some() {
                num_entries += 1;
            }
        }
        assert_eq!(num_entries, 2);
    }

    #[tokio::test]
    async fn test_pool_semaphore_limit() {
        let pool = Pool::<String, u32>::new(1);
        let item1 = pool.pull("key1".to_string()).await;
        assert_eq!(global_permits(&pool), 0);
        drop(item1);
        assert_eq!(global_permits(&pool), 1);
    }

    #[tokio::test]
    async fn test_try_pull_respects_global_capacity() {
        let pool = Pool::<String, u32>::new(1);
        let held = pool.pull("key1".to_string()).await;
        assert!(pool.try_pull("key2".to_string()).is_none());
        drop(held);
        assert!(pool.try_pull("key2".to_string()).is_some());
    }

    #[tokio::test]
    async fn test_set_and_get_local_limit() {
        let pool = Pool::<String, u32>::new(10);
        let index = pool.set_local_limit(2);
        let local_limits = pool.local_limits.read().expect("lock poisoned");
        assert!(local_limits.get(index).is_some());
        assert_eq!(local_limits[index].available_permits(), 2);
    }

    #[tokio::test]
    async fn test_pull_with_local_limit_success() {
        let pool = Pool::<String, u32>::new(10);
        let index = pool.set_local_limit(2);
        let item = pool.pull_with_local_limit("key1".to_string(), Some(index)).await;
        assert!(item.is_some());
    }

    #[tokio::test]
    async fn test_pull_with_local_limit_exhausted() {
        let pool = Pool::<String, u32>::new(10);
        let index = pool.set_local_limit(1);
        let item1 = pool.pull_with_local_limit("key1".to_string(), Some(index)).await;
        // Without this the test would pass vacuously if the first pull failed.
        assert!(item1.is_some());
        let item2 = pool.pull_with_local_limit("key2".to_string(), Some(index)).await;
        assert!(item2.is_none());
    }

    #[tokio::test]
    async fn test_pull_with_local_limit_after_release() {
        let pool = Pool::<String, u32>::new(10);
        let index = pool.set_local_limit(1);
        let item1 = pool.pull_with_local_limit("key1".to_string(), Some(index)).await;
        assert!(item1.is_some());
        drop(item1);
        let item2 = pool.pull_with_local_limit("key2".to_string(), Some(index)).await;
        assert!(item2.is_some());
    }

    #[tokio::test]
    async fn test_pull_with_invalid_local_limit_index() {
        let pool = Pool::<String, u32>::new(10);
        let item = pool.pull_with_local_limit("key1".to_string(), Some(999)).await;
        assert!(item.is_some());
    }

    #[tokio::test]
    async fn test_try_pull_with_local_limit_exhausted_keeps_global_permit() {
        let pool = Pool::<String, u32>::new(10);
        let index = pool.set_local_limit(1);
        let held = pool.try_pull_with_local_limit("key1".to_string(), Some(index));
        assert!(held.is_some());
        assert!(pool
            .try_pull_with_local_limit("key2".to_string(), Some(index))
            .is_none());
        // Only `held` owns a global permit; the failed attempt must not leak one.
        assert_eq!(global_permits(&pool), 9);
    }

    #[tokio::test]
    async fn test_pull_with_wait_local_limit_holds_local_permit() {
        let pool = Pool::<String, u32>::new(10);
        let index = pool.set_local_limit(2);
        let item = pool
            .pull_with_wait_local_limit("key1".to_string(), Some(index))
            .await;
        assert_eq!(local_permits(&pool, index), 1);
        assert_eq!(global_permits(&pool), 9);
        drop(item);
        assert_eq!(local_permits(&pool, index), 2);
        assert_eq!(global_permits(&pool), 10);
    }
}