use dashmap::DashMap;
use std::hash::Hash;
use std::marker::PhantomData;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Mutex;
use tracing::{debug, info};
/// Default entry-count threshold checked during idle cleanup. A pool can never
/// hold more than `usize::MAX` entries, so by default the threshold is never exceeded.
const DEFAULT_POOL_CAPACITY: usize = usize::MAX;
/// Default time an entry may go unused before idle cleanup may evict it (ten minutes).
const DEFAULT_POOL_IDLE_TIMEOUT: Duration = Duration::from_secs(10 * 60);
/// RAII token that counts one in-flight request against a shared counter.
///
/// Constructing the guard increments the counter; dropping it decrements the
/// counter again. Guards are handed out by `Entry::request_guard`, so the
/// count is only accurate while the guard is kept alive. Discarding it
/// immediately un-counts the request, hence `#[must_use]`.
#[must_use = "the request is only counted while the guard is alive; bind it to a variable"]
pub struct RequestGuard {
    /// Counter shared with the entry this guard was created from.
    active_requests: Arc<AtomicUsize>,
}

impl RequestGuard {
    /// Increments `active_requests` and returns a guard that undoes the
    /// increment when dropped.
    fn new(active_requests: Arc<AtomicUsize>) -> Self {
        active_requests.fetch_add(1, Ordering::SeqCst);
        Self { active_requests }
    }
}

impl Drop for RequestGuard {
    /// Releases the request counted by [`RequestGuard::new`].
    fn drop(&mut self) {
        self.active_requests.fetch_sub(1, Ordering::SeqCst);
    }
}
/// A pooled client plus the bookkeeping used to decide whether it may be evicted.
///
/// Cloning an `Entry` clones `client`. The clone shares the request counter
/// and the last-activity timestamp with the original, because both are behind `Arc`.
#[derive(Clone)]
pub struct Entry<T> {
    /// The pooled client.
    pub client: T,
    /// Number of live [`RequestGuard`]s created from this entry or its clones.
    active_requests: Arc<AtomicUsize>,
    /// Creation time, or the last instant passed to `set_actived_at`.
    actived_at: Arc<std::sync::Mutex<Instant>>,
}

impl<T> Entry<T> {
    /// Wraps `client` with a zero request count and an activity timestamp of "now".
    fn new(client: T) -> Self {
        Self {
            client,
            active_requests: Arc::new(AtomicUsize::new(0)),
            actived_at: Arc::new(std::sync::Mutex::new(Instant::now())),
        }
    }

    /// Counts one in-flight request against this entry until the returned guard
    /// is dropped. While any guard is alive, `has_active_requests` returns `true`.
    #[must_use = "the request is only counted while the guard is alive"]
    pub fn request_guard(&self) -> RequestGuard {
        RequestGuard::new(Arc::clone(&self.active_requests))
    }

    /// Records `actived_at` as the last time this entry was used.
    fn set_actived_at(&self, actived_at: Instant) {
        *self.lock_actived_at() = actived_at;
    }

    /// Returns `true` while at least one request guard for this entry is alive.
    fn has_active_requests(&self) -> bool {
        self.active_requests.load(Ordering::SeqCst) > 0
    }

    /// Time elapsed since the recorded activity timestamp (saturates at zero).
    fn idle_duration(&self) -> Duration {
        self.lock_actived_at().elapsed()
    }

    /// Locks the activity timestamp, recovering the value if the lock is poisoned.
    ///
    /// The critical sections only read or overwrite a `Copy` `Instant`, so a
    /// poisoned lock can never hold a half-written value. Continuing to use it
    /// is safe and avoids turning an unrelated panic into a second one.
    fn lock_actived_at(&self) -> std::sync::MutexGuard<'_, Instant> {
        self.actived_at
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
/// Asynchronously creates clients of type `T` for addresses of type `A`.
///
/// The pool calls [`Factory::make_client`] only when it has no entry for the
/// requested key.
#[tonic::async_trait]
pub trait Factory<A, T> {
    /// Creates a new client for `addr`, or returns [`Factory::Error`] on failure.
    async fn make_client(&self, addr: &A) -> Result<T, Self::Error>;

    /// Error produced when [`Factory::make_client`] cannot create a client.
    type Error;
}
/// A keyed cache of clients created on demand by a [`Factory`].
///
/// Entries are looked up by `K`. On a miss, the factory builds a client from an
/// address `A`. Idle entries are swept opportunistically from [`Pool::entry`].
pub struct Pool<K, A, T, F> {
    /// Creates clients for keys that are not yet pooled.
    factory: F,
    /// Pooled entries, keyed by `K`.
    clients: Arc<DashMap<K, Entry<T>>>,
    /// Entry-count threshold checked during idle cleanup.
    capacity: usize,
    /// How long an entry may stay unused before idle cleanup may remove it.
    idle_timeout: Duration,
    /// Timestamp of the most recent idle-cleanup round.
    cleanup_at: Arc<Mutex<Instant>>,
    // `A` is only ever borrowed (`&A`) and never stored. `fn() -> A` keeps the
    // pool covariant in `A` without making `Send`, `Sync` or drop-check depend
    // on `A`, so a non-`Send` address type no longer prevents sharing the pool.
    _phantom: PhantomData<fn() -> A>,
}
/// Configures and constructs a [`Pool`].
pub struct Builder<K, A, T, F> {
    /// Factory that the built pool will use to create clients.
    factory: F,
    /// Entry-count threshold passed on to the pool.
    capacity: usize,
    /// Idle timeout passed on to the pool.
    idle_timeout: Duration,
    // The builder stores no `K`, `A` or `T`. `fn() -> (..)` records the type
    // parameters covariantly without tying auto traits (`Send`/`Sync`) or
    // drop-check to them.
    _phantom: PhantomData<fn() -> (K, A, T)>,
}
impl<K, A, T, F> Builder<K, A, T, F>
where
    K: Clone + Eq + Hash + std::fmt::Display,
    T: Clone,
    F: Factory<A, T>,
{
    /// Starts a builder around `factory`, using `DEFAULT_POOL_CAPACITY` and
    /// `DEFAULT_POOL_IDLE_TIMEOUT`.
    pub fn new(factory: F) -> Self {
        let capacity = DEFAULT_POOL_CAPACITY;
        let idle_timeout = DEFAULT_POOL_IDLE_TIMEOUT;
        Self {
            factory,
            capacity,
            idle_timeout,
            _phantom: PhantomData,
        }
    }

    /// Sets the entry-count threshold the pool checks during idle cleanup.
    pub fn capacity(self, capacity: usize) -> Self {
        Self { capacity, ..self }
    }

    /// Sets how long an entry may go unused before idle cleanup may remove it.
    pub fn idle_timeout(self, idle_timeout: Duration) -> Self {
        Self {
            idle_timeout,
            ..self
        }
    }

    /// Consumes the builder and returns an empty pool with the configured settings.
    pub fn build(self) -> Pool<K, A, T, F> {
        let Self {
            factory,
            capacity,
            idle_timeout,
            ..
        } = self;
        Pool {
            factory,
            clients: Arc::new(DashMap::new()),
            capacity,
            idle_timeout,
            cleanup_at: Arc::new(Mutex::new(Instant::now())),
            _phantom: PhantomData,
        }
    }
}
impl<K, A, T, F> Pool<K, A, T, F>
where
    K: Clone + Eq + Hash + std::fmt::Display,
    A: Clone + Eq + std::fmt::Display,
    T: Clone,
    F: Factory<A, T>,
{
    /// Returns the pooled entry for `key`, creating a client from `addr` on a miss.
    ///
    /// An opportunistic idle sweep runs first. The returned entry's activity
    /// timestamp is refreshed to "now".
    ///
    /// # Errors
    ///
    /// Returns the factory's error when a new client is needed and
    /// `make_client` fails. Nothing is inserted in that case.
    pub async fn entry(&self, key: &K, addr: &A) -> Result<Entry<T>, F::Error> {
        self.cleanup_idle_entries().await;
        if let Some(entry) = self.clients.get(key) {
            debug!("reusing client: {}", key);
            entry.set_actived_at(Instant::now());
            return Ok(entry.value().clone());
        }
        debug!("creating client: {}", key);
        // No map guard is held here, so a slow `make_client` does not block
        // other keys that share this shard.
        let client = self.factory.make_client(addr).await?;
        // Another task may have inserted `key` while we awaited. In that case
        // the existing entry wins and the client built above is dropped; the
        // closure avoids allocating an `Entry` that would be thrown away.
        let entry = self
            .clients
            .entry(key.clone())
            .or_insert_with(|| Entry::new(client));
        entry.set_actived_at(Instant::now());
        Ok(entry.value().clone())
    }

    /// Removes `key` from the pool unless its entry still has active requests.
    pub async fn remove_entry(&self, key: &K) {
        self.clients
            .remove_if(key, |_, entry| !entry.has_active_requests());
    }

    /// Evicts idle entries, running at most once per `idle_timeout / 2`.
    ///
    /// Entries with active requests are always kept. Otherwise an entry is
    /// kept only if it was used within `idle_timeout` and the pool is not
    /// above `capacity`. When the pool is above capacity, every entry without
    /// active requests is removed.
    async fn cleanup_idle_entries(&self) {
        let now = Instant::now();
        {
            let mut cleanup_at = self.cleanup_at.lock().await;
            let interval = self.idle_timeout / 2;
            if now.duration_since(*cleanup_at) < interval {
                debug!("avoid hot cleanup");
                return;
            }
            // Claim this round while the lock is still held. Concurrent callers
            // then see the new timestamp and skip, instead of all passing the
            // check above and sweeping the map at the same time.
            *cleanup_at = now;
        }
        let exceeds_capacity = self.clients.len() > self.capacity;
        self.clients.retain(|key, entry| {
            let has_active_requests = entry.has_active_requests();
            let idle_duration = entry.idle_duration();
            let is_recent = idle_duration <= self.idle_timeout;
            let should_retain = has_active_requests || (!exceeds_capacity && is_recent);
            if !should_retain {
                info!(
                    "removing idle client: {}, exceeds_capacity: {}, idle_duration: {}s",
                    key,
                    exceeds_capacity,
                    idle_duration.as_secs(),
                );
            }
            should_retain
        });
    }

    /// Number of entries currently pooled.
    pub async fn size(&self) -> usize {
        self.clients.len()
    }

    /// Removes every entry, including entries that still have active requests.
    pub async fn clear(&self) {
        self.clients.clear();
    }
}