use std::fmt::Debug;
use std::ops::{Deref, DerefMut};
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, Condvar, Mutex};
use std::time::{Duration, Instant};
#[cfg(feature = "prometheus")]
use cdk_prometheus::metrics::METRICS;
use crate::database::DatabaseConnector;
/// Errors returned by the pool while acquiring or creating a resource.
///
/// `E` is the error type of the underlying resource manager.
#[derive(Debug, thiserror::Error)]
pub enum Error<E>
where
E: std::error::Error + Send + Sync + 'static,
{
/// Locking the pool's queue, or waiting on its condition variable, reported
/// a poisoned mutex.
#[error("Internal: PoisonError")]
Poison,
/// No resource became available before the wait timed out.
#[error("Timed out waiting for a resource")]
Timeout,
/// Error produced by the resource manager. It is displayed transparently and
/// can be converted from `E` through the generated `From` implementation.
#[error(transparent)]
Resource(#[from] E),
}
/// Configuration a [`Pool`] is built from.
///
/// `Pool::new` reads both values once and caches them in the pool.
pub trait DatabaseConfig: Clone + Debug + Send + Sync {
/// Maximum number of resources that may be checked out at the same time; the
/// pool only creates a new resource while its in-use count is below this.
fn max_size(&self) -> usize;
/// Timeout used by `Pool::get` when the caller does not supply one.
fn default_timeout(&self) -> Duration;
}
/// Resource manager: tells a [`Pool`] how to create and dispose of resources.
pub trait DatabasePool: Debug {
/// Resource type handed out by the pool.
type Connection: DatabaseConnector;
/// Configuration type stored in the pool and passed to `new_resource`.
type Config: DatabaseConfig;
/// Error produced while creating a resource; surfaced to callers wrapped in
/// the pool's `Error` type.
type Error: Debug + std::error::Error + Send + Sync + 'static;
/// Creates a brand-new resource.
///
/// The pool calls this when it has no usable idle resource and its in-use
/// count is below `max_size`.
///
/// * `config` - the configuration the pool was built with.
/// * `stale` - flag (initially `false`) that the pool stores next to the
///   resource; idle resources whose flag reads `true` are discarded instead
///   of being handed out. Who sets the flag is not visible in this file —
///   presumably the connection implementation; verify against implementors.
/// * `timeout` - the timeout the caller passed to `Pool::get_timeout`.
fn new_resource(
config: &Self::Config,
stale: Arc<AtomicBool>,
timeout: Duration,
) -> Result<Self::Connection, Error<Self::Error>>;
/// Hook invoked for every idle resource when the `Pool` itself is dropped.
///
/// The default implementation does nothing beyond dropping the value.
fn drop(_resource: Self::Connection) {}
}
/// Blocking pool of reusable resources created by the resource manager `RM`.
#[derive(Debug)]
pub struct Pool<RM>
where
RM: DatabasePool,
{
/// Configuration handed to `RM::new_resource` whenever a resource is created.
config: RM::Config,
/// Idle resources, each paired with its stale flag. This mutex is also the
/// one used together with `waiter`.
queue: Mutex<Vec<(Arc<AtomicBool>, RM::Connection)>>,
/// Number of resources currently checked out (incremented on checkout,
/// decremented when a `PooledResource` is dropped).
in_use: AtomicUsize,
/// Cached `config.max_size()`.
max_size: usize,
/// Cached `config.default_timeout()`.
default_timeout: Duration,
/// Signalled whenever a resource is returned to `queue`.
waiter: Condvar,
}
/// Handle to a resource checked out of a [`Pool`].
///
/// Dereferences to the underlying connection; dropping the handle puts the
/// resource back into the pool's idle queue.
pub struct PooledResource<RM>
where
RM: DatabasePool,
{
/// The stale flag and the connection. `Some` while the handle is alive; taken
/// out when the handle is dropped.
resource: Option<(Arc<AtomicBool>, RM::Connection)>,
/// Pool the resource goes back to.
pool: Arc<Pool<RM>>,
/// Moment the resource was handed out; used to record how long it was held
/// when the `prometheus` feature is enabled.
#[cfg(feature = "prometheus")]
start_time: Instant,
}
impl<RM> Debug for PooledResource<RM>
where
    RM: DatabasePool,
{
    /// Prints the wrapped resource (stale flag and connection) behind a
    /// `Resource: ` prefix; the owning pool is intentionally left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let held = &self.resource;
        f.write_fmt(format_args!("Resource: {:?}", held))
    }
}
impl<RM> Drop for PooledResource<RM>
where
    RM: DatabasePool,
{
    /// Puts the resource back into the pool's idle queue, decrements the
    /// in-use counter and wakes one thread waiting for a resource.
    ///
    /// Never panics on a poisoned queue mutex: panicking inside `drop` while
    /// another panic is unwinding aborts the process, and it would also leave
    /// the in-use counter permanently too high.
    fn drop(&mut self) {
        if let Some(resource) = self.resource.take() {
            let _previously_in_use = {
                // The queue is a plain `Vec`, so it is still structurally valid
                // after a panic elsewhere; recover the guard instead of panicking.
                let mut idle = self
                    .pool
                    .queue
                    .lock()
                    .unwrap_or_else(std::sync::PoisonError::into_inner);
                idle.push(resource);
                // Decrement while the lock is held so a waiter cannot observe the
                // old count between its check and going to sleep.
                self.pool.in_use.fetch_sub(1, Ordering::AcqRel)
            };

            #[cfg(feature = "prometheus")]
            {
                // `fetch_sub` returns the value *before* the decrement.
                METRICS.set_db_connections_active(_previously_in_use.saturating_sub(1) as i64);
                let duration = self.start_time.elapsed().as_secs_f64();
                METRICS.record_db_operation(duration, "drop");
            }

            // Notify after the guard is released so the woken thread can take the
            // mutex immediately.
            self.pool.waiter.notify_one();
        }
    }
}
impl<RM> Deref for PooledResource<RM>
where
    RM: DatabasePool,
{
    type Target = RM::Connection;

    /// Borrows the pooled connection.
    ///
    /// # Panics
    ///
    /// Panics if the resource has already been taken out of the handle.
    fn deref(&self) -> &Self::Target {
        let (_, connection) = self.resource.as_ref().expect("resource already dropped");
        connection
    }
}
impl<RM> DerefMut for PooledResource<RM>
where
    RM: DatabasePool,
{
    /// Mutably borrows the pooled connection.
    ///
    /// # Panics
    ///
    /// Panics if the resource has already been taken out of the handle.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let (_, connection) = self.resource.as_mut().expect("resource already dropped");
        connection
    }
}
impl<RM> Pool<RM>
where
RM: DatabasePool,
{
pub fn new(config: RM::Config) -> Arc<Self> {
Arc::new(Self {
default_timeout: config.default_timeout(),
max_size: config.max_size(),
config,
queue: Default::default(),
in_use: Default::default(),
waiter: Default::default(),
})
}
#[inline(always)]
pub fn get(self: &Arc<Self>) -> Result<PooledResource<RM>, Error<RM::Error>> {
self.get_timeout(self.default_timeout)
}
fn increment_connection_counter(&self) -> usize {
let in_use = self.in_use.fetch_add(1, Ordering::AcqRel);
#[cfg(feature = "prometheus")]
{
METRICS.set_db_connections_active(in_use as i64);
}
in_use
}
#[inline(always)]
pub fn get_timeout(
self: &Arc<Self>,
timeout: Duration,
) -> Result<PooledResource<RM>, Error<RM::Error>> {
let mut resources = self.queue.lock().map_err(|_| Error::Poison)?;
let time = Instant::now();
loop {
while let Some((stale, resource)) = resources.pop() {
if !stale.load(Ordering::SeqCst) {
self.increment_connection_counter();
return Ok(PooledResource {
resource: Some((stale, resource)),
pool: self.clone(),
#[cfg(feature = "prometheus")]
start_time: Instant::now(),
});
}
}
if self.in_use.load(Ordering::Relaxed) < self.max_size {
self.increment_connection_counter();
let stale: Arc<AtomicBool> = Arc::new(false.into());
match RM::new_resource(&self.config, stale.clone(), timeout) {
Ok(new_resource) => {
return Ok(PooledResource {
resource: Some((stale, new_resource)),
pool: self.clone(),
#[cfg(feature = "prometheus")]
start_time: Instant::now(),
});
}
Err(e) => {
self.in_use.fetch_sub(1, Ordering::AcqRel);
return Err(e);
}
}
}
resources = self
.waiter
.wait_timeout(resources, timeout)
.map_err(|_| Error::Poison)
.and_then(|(lock, timeout_result)| {
if timeout_result.timed_out() {
tracing::warn!(
"Timeout waiting for the resource (pool size: {}). Waited {} ms",
self.max_size,
time.elapsed().as_millis()
);
Err(Error::Timeout)
} else {
Ok(lock)
}
})?;
}
}
}
impl<RM> Drop for Pool<RM>
where
    RM: DatabasePool,
{
    /// Hands every idle resource to `RM::drop`, then keeps waiting on the
    /// condition variable until the in-use counter reads zero, draining the
    /// queue again after each wake-up. Gives up silently if the queue mutex
    /// is (or becomes) poisoned.
    fn drop(&mut self) {
        let mut idle = match self.queue.lock() {
            Ok(guard) => guard,
            Err(_) => return,
        };

        loop {
            while let Some((_stale, connection)) = idle.pop() {
                RM::drop(connection);
            }

            if self.in_use.load(Ordering::Relaxed) == 0 {
                return;
            }

            // Some resources are still checked out: sleep until one is returned.
            idle = match self.waiter.wait(idle) {
                Ok(guard) => guard,
                Err(_) => return,
            };
        }
    }
}