use std::collections::HashMap;
use std::ops::Deref;
use std::sync::{Arc, Mutex, MutexGuard, PoisonError};
use tokio::sync::{OwnedSemaphorePermit, RwLock, Semaphore};
use crate::{RuntimeAdapter, RuntimeKind};
/// Shared, thread-safe callback the pool invokes whenever it needs another runtime
/// handle for the kind it was registered under.
type RuntimeFactory = Arc<dyn Fn() -> Arc<dyn RuntimeAdapter> + Sync + Send>;
/// Tuning knobs for a [`RuntimePool`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PoolConfig {
    /// Number of permits in each kind's semaphore, i.e. how many guards for one kind
    /// can be checked out at the same time. When `warm_on_create` is set this is also
    /// the number of runtimes built up front. A value of 0 means no permit can ever be
    /// obtained.
    pub max_size: usize,
    /// Build `max_size` runtimes as soon as a factory is registered instead of
    /// creating them lazily on first use.
    pub warm_on_create: bool,
}

impl Default for PoolConfig {
    /// Ten slots per kind, created lazily.
    fn default() -> Self {
        Self {
            max_size: 10,
            warm_on_create: false,
        }
    }
}
/// One slot in a per-kind pool.
struct PoolItem {
    /// The pooled runtime handle.
    runtime: Arc<dyn RuntimeAdapter>,
    /// Outstanding borrows of this slot; only slots at 0 are handed out.
    borrow_count: usize,
}
/// Everything the pool keeps for one registered [`RuntimeKind`].
struct PoolState {
    /// Every runtime created so far for this kind, idle or checked out.
    items: Vec<PoolItem>,
    /// Builds runtimes for this kind (during warm-up and when `acquire` finds no
    /// idle slot).
    factory: RuntimeFactory,
}
pub struct RuntimePool {
config: PoolConfig,
pools: RwLock<HashMap<RuntimeKind, PoolState>>,
semaphores: RwLock<HashMap<RuntimeKind, Arc<Semaphore>>>,
}
impl RuntimePool {
pub fn new(config: PoolConfig) -> Self {
RuntimePool {
config,
pools: RwLock::new(HashMap::new()),
semaphores: RwLock::new(HashMap::new()),
}
}
pub fn default_config() -> Self {
Self::new(PoolConfig::default())
}
pub async fn register_factory<F>(&self, kind: RuntimeKind, factory: F)
where
F: Fn() -> Arc<dyn RuntimeAdapter> + Send + Sync + 'static,
{
let factory = Arc::new(factory);
let semaphore = Arc::new(Semaphore::new(self.config.max_size));
{
let mut semaphores = self.semaphores.write().await;
semaphores.insert(kind, semaphore);
}
let items = if self.config.warm_on_create {
(0..self.config.max_size)
.map(|_| PoolItem {
runtime: factory(),
borrow_count: 0,
})
.collect()
} else {
Vec::new()
};
let state = PoolState { items, factory };
{
let mut pools = self.pools.write().await;
pools.insert(kind, state);
}
}
pub async fn acquire(self: &Arc<Self>, kind: RuntimeKind) -> Option<PoolGuard> {
let semaphore = {
let semaphores = self.semaphores.read().await;
semaphores.get(&kind).cloned()?
};
let permit = semaphore.acquire_owned().await.ok()?;
let mut pools = self.pools.write().await;
let state = pools.get_mut(&kind)?;
let mut found_idx = None;
for (idx, item) in state.items.iter_mut().enumerate() {
if item.borrow_count == 0 {
item.borrow_count += 1;
found_idx = Some(idx);
break;
}
}
let idx = match found_idx {
Some(idx) => idx,
None => {
let runtime = (state.factory)();
state.items.push(PoolItem {
runtime,
borrow_count: 1,
});
state.items.len() - 1
}
};
let runtime = state.items[idx].runtime.clone();
Some(PoolGuard {
pool: self.clone(),
kind,
idx,
runtime,
_permit: permit,
})
}
pub async fn try_acquire(self: &Arc<Self>, kind: RuntimeKind) -> Option<PoolGuard> {
let semaphore = {
let semaphores = self.semaphores.read().await;
semaphores.get(&kind).cloned()?
};
let permit = semaphore.try_acquire_owned().ok()?;
let mut pools = self.pools.write().await;
let state = pools.get_mut(&kind)?;
for (idx, item) in state.items.iter_mut().enumerate() {
if item.borrow_count == 0 {
item.borrow_count += 1;
let runtime = item.runtime.clone();
return Some(PoolGuard {
pool: self.clone(),
kind,
idx,
runtime,
_permit: permit,
});
}
}
None
}
fn return_runtime(&self, kind: RuntimeKind, idx: usize) {
if let Ok(mut pools) = self.pools.try_write() {
if let Some(state) = pools.get_mut(&kind) {
if let Some(item) = state.items.get_mut(idx) {
item.borrow_count = item.borrow_count.saturating_sub(1);
}
}
}
}
pub async fn stats(&self, kind: RuntimeKind) -> PoolStats {
let pools = self.pools.read().await;
if let Some(state) = pools.get(&kind) {
let total = state.items.len();
let in_use = state.items.iter().map(|i| i.borrow_count).sum();
PoolStats {
total,
available: total.saturating_sub(in_use),
in_use,
}
} else {
PoolStats::default()
}
}
pub async fn clear(&self, kind: RuntimeKind) {
let mut pools = self.pools.write().await;
pools.remove(&kind);
}
pub async fn clear_all(&self) {
let mut pools = self.pools.write().await;
pools.clear();
}
}
impl Default for RuntimePool {
fn default() -> Self {
Self::default_config()
}
}
/// Handle to a runtime checked out of a [`RuntimePool`].
///
/// While alive it holds one semaphore permit for its kind. Dropping it hands the slot
/// back to the pool and then releases the permit, so discarding a guard immediately
/// is almost always a mistake.
#[must_use = "dropping a PoolGuard immediately returns the runtime to the pool"]
pub struct PoolGuard {
    pool: Arc<RuntimePool>,
    kind: RuntimeKind,
    idx: usize,
    runtime: Arc<dyn RuntimeAdapter>,
    _permit: OwnedSemaphorePermit,
}
impl PoolGuard {
    /// Returns another shared handle to the checked-out runtime.
    ///
    /// The returned `Arc` may outlive the guard, but dropping the guard still hands
    /// the slot back to the pool, so the runtime may then also be given to the next
    /// borrower.
    pub fn runtime(&self) -> Arc<dyn RuntimeAdapter> {
        Arc::clone(&self.runtime)
    }
}
impl Deref for PoolGuard {
type Target = Arc<dyn RuntimeAdapter>;
fn deref(&self) -> &Self::Target {
&self.runtime
}
}
impl Drop for PoolGuard {
    /// Hands the slot back to the pool. Rust runs this before dropping the guard's
    /// fields, so the bookkeeping in `return_runtime` happens before `_permit`
    /// releases the semaphore permit.
    fn drop(&mut self) {
        let (kind, slot) = (self.kind, self.idx);
        self.pool.return_runtime(kind, slot);
    }
}
/// Usage snapshot for one kind, as returned by [`RuntimePool::stats`].
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct PoolStats {
    /// Runtimes created so far, idle or checked out.
    pub total: usize,
    /// `total - in_use`, saturating at zero.
    pub available: usize,
    /// Sum of outstanding borrows across all runtimes of the kind.
    pub in_use: usize,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::create_runtime;

    /// Builds a shared pool with a factory for `RuntimeKind::Local` registered.
    async fn local_pool(max_size: usize, warm_on_create: bool) -> Arc<RuntimePool> {
        let pool = Arc::new(RuntimePool::new(PoolConfig {
            max_size,
            warm_on_create,
        }));
        pool.register_factory(RuntimeKind::Local, || {
            create_runtime(RuntimeKind::Local).unwrap()
        })
        .await;
        pool
    }

    #[test]
    fn test_pool_config_default() {
        let config = PoolConfig::default();
        assert_eq!(config.max_size, 10);
        assert!(!config.warm_on_create);
    }

    #[tokio::test]
    async fn test_pool_creation() {
        let pool = RuntimePool::default_config();
        let stats = pool.stats(RuntimeKind::Local).await;
        assert_eq!(stats.total, 0);
    }

    #[tokio::test]
    async fn test_pool_stats_default() {
        let stats = PoolStats::default();
        assert_eq!(stats.total, 0);
        assert_eq!(stats.available, 0);
        assert_eq!(stats.in_use, 0);
    }

    #[tokio::test]
    async fn test_pool_register_and_acquire() {
        let pool = local_pool(2, false).await;
        let guard = pool.acquire(RuntimeKind::Local).await;
        assert!(guard.is_some());
        let stats = pool.stats(RuntimeKind::Local).await;
        assert_eq!(stats.in_use, 1);
    }

    #[tokio::test]
    async fn test_pool_release() {
        let pool = local_pool(2, false).await;
        {
            let _guard = pool.acquire(RuntimeKind::Local).await.unwrap();
            let stats = pool.stats(RuntimeKind::Local).await;
            assert_eq!(stats.in_use, 1);
        }
        let stats = pool.stats(RuntimeKind::Local).await;
        assert_eq!(stats.available, 1);
    }

    #[tokio::test]
    async fn test_pool_warm_on_create() {
        let pool = local_pool(3, true).await;
        let stats = pool.stats(RuntimeKind::Local).await;
        assert_eq!(stats.total, 3);
        assert_eq!(stats.available, 3);
    }

    #[tokio::test]
    async fn test_pool_multiple_acquire() {
        let pool = local_pool(2, false).await;
        let g1 = pool.acquire(RuntimeKind::Local).await;
        let g2 = pool.acquire(RuntimeKind::Local).await;
        assert!(g1.is_some());
        assert!(g2.is_some());
        let stats = pool.stats(RuntimeKind::Local).await;
        assert_eq!(stats.in_use, 2);
        assert_eq!(stats.available, 0);
    }

    #[tokio::test]
    async fn test_acquire_unregistered_kind_returns_none() {
        let pool = Arc::new(RuntimePool::default_config());
        assert!(pool.acquire(RuntimeKind::Local).await.is_none());
        assert!(pool.try_acquire(RuntimeKind::Local).await.is_none());
    }

    #[tokio::test]
    async fn test_released_runtime_is_reused() {
        let pool = local_pool(2, false).await;
        drop(pool.acquire(RuntimeKind::Local).await.unwrap());
        let _guard = pool.acquire(RuntimeKind::Local).await.unwrap();
        let stats = pool.stats(RuntimeKind::Local).await;
        assert_eq!(stats.total, 1);
        assert_eq!(stats.in_use, 1);
    }

    #[tokio::test]
    async fn test_try_acquire_respects_max_size() {
        let pool = local_pool(1, true).await;
        let first = pool.try_acquire(RuntimeKind::Local).await;
        assert!(first.is_some());
        assert!(pool.try_acquire(RuntimeKind::Local).await.is_none());
        drop(first);
        assert!(pool.try_acquire(RuntimeKind::Local).await.is_some());
    }

    #[tokio::test]
    async fn test_clear_unregisters_kind() {
        let pool = local_pool(2, true).await;
        pool.clear(RuntimeKind::Local).await;
        assert_eq!(pool.stats(RuntimeKind::Local).await.total, 0);
        assert!(pool.acquire(RuntimeKind::Local).await.is_none());
    }
}