use std::sync::Arc;
use std::time::Duration;
pub use deadpool::Status as PoolStatus;
use deadpool::managed::Pool as DeadPool;
use crate::config::DaemonServer;
use crate::pool::builder::{ParameterLogging, PoolBuilder};
use crate::pool::manager::JobManager;
use crate::pool::routing::Registry;
/// Owns the join handle of a spawned tokio task and aborts that task when dropped
/// (see the `Drop` impl for this type).
pub(crate) struct ReaperGuard {
/// Handle of the task that is aborted once this guard goes out of scope.
pub(crate) handle: tokio::task::JoinHandle<()>,
}
impl Drop for ReaperGuard {
/// Requests cancellation of the wrapped task so it does not outlive the guard.
fn drop(&mut self) {
// `abort` only signals cancellation; it does not wait for the task to finish.
self.handle.abort();
}
}
/// Cloneable handle bundling a deadpool-managed pool of [`JobManager`] objects with the
/// routing [`Registry`] consulted before a pool checkout is attempted.
#[derive(Clone)]
pub struct Pool {
/// Underlying deadpool pool; checkouts and `status()` go through it.
pub(crate) inner: DeadPool<JobManager>,
/// Routing registry queried for idle and least-busy jobs.
pub(crate) registry: Arc<Registry>,
/// Timeout reported in `Error::PoolExhausted` when a checkout times out
/// (`None` is reported as a zero duration).
pub(crate) acquire_timeout: Option<Duration>,
// Only read under the `tracing` feature, hence the conditional `dead_code` allowance.
#[cfg_attr(not(feature = "tracing"), allow(dead_code))]
/// Selects what `execute_with` records about query parameters on its tracing span.
pub(crate) parameter_logging: ParameterLogging,
/// Shared guard whose `Drop` aborts its task; because it sits behind an `Arc`, the abort
/// happens only when the last clone holding it is dropped.
/// NOTE(review): presumably wraps the idle-reaper task — confirm against the builder.
pub(crate) _reaper: Option<Arc<ReaperGuard>>,
}
/// In-flight count at or above which `Pool::pick_unsaturated` no longer considers a job.
const SATURATION_THRESHOLD: u32 = 32;
impl Pool {
/// Returns the first of the registry's least-busy candidates whose in-flight count is
/// still below [`SATURATION_THRESHOLD`], or `None` when every candidate is saturated.
///
/// At most `min(current pool size, 8)` candidates are requested from the registry.
fn pick_unsaturated(&self) -> Option<std::sync::Arc<crate::Job>> {
    // Never ask the registry for more candidates than the pool currently holds.
    let scan_width = self.inner.status().size.min(8);
    self.registry
        .least_busy(scan_width)
        .into_iter()
        .find(|candidate| candidate.in_flight() < SATURATION_THRESHOLD)
}
/// Starts a [`PoolBuilder`] for the given daemon server description.
///
/// Accepts anything convertible into `Arc<DaemonServer>`.
pub fn builder(server: impl Into<Arc<DaemonServer>>) -> PoolBuilder {
    let shared_server: Arc<DaemonServer> = server.into();
    PoolBuilder::new(shared_server)
}
/// Runs `sql` on a job chosen through three routing tiers, tried in order:
///
/// 1. `try_idle` — a job returned by the registry's `peek_idle`;
/// 2. `least_busy_scan` — the job returned by `pick_unsaturated`;
/// 3. `fair_queue` — a pool checkout obtained through `get_or_timeout`.
///
/// The winning tier is recorded on the current tracing span and in the routing counter
/// when the respective features are enabled.
///
/// # Errors
///
/// Returns whatever `Job::execute` reports, or the checkout error from
/// `get_or_timeout` when the third tier is reached and the checkout fails.
#[cfg_attr(
    feature = "tracing",
    tracing::instrument(skip(self), fields(sql = %sql, tier = tracing::field::Empty))
)]
pub async fn execute(&self, sql: &str) -> crate::Result<crate::query::Rows> {
    use crate::Job;

    /// Tags the current span and bumps the routing counter for the chosen tier.
    // `tier` is unused when neither observability feature is compiled in.
    #[cfg_attr(
        not(any(feature = "tracing", feature = "metrics")),
        allow(unused_variables)
    )]
    fn note_tier(tier: &'static str) {
        #[cfg(feature = "tracing")]
        tracing::Span::current().record("tier", tier);
        #[cfg(feature = "metrics")]
        metrics::counter!(
            crate::observability::POOL_ROUTING_TIER_WINS_TOTAL,
            "tier" => tier,
        )
        .increment(1);
    }

    #[cfg(feature = "metrics")]
    emit_pool_status_gauges(&self.inner.status());

    // Tier 1: a job the registry reports as idle.
    if let Some(idle_job) = self.registry.peek_idle() {
        note_tier("try_idle");
        return Job::execute(&idle_job, sql).await;
    }

    // Tier 2: a busy job that is still below the saturation threshold.
    if let Some(shared_job) = self.pick_unsaturated() {
        note_tier("least_busy_scan");
        return Job::execute(&shared_job, sql).await;
    }

    // Tier 3: wait for a checkout. The tier is noted before the wait, so it is
    // recorded even when the checkout subsequently fails.
    note_tier("fair_queue");
    let checked_out = self.get_or_timeout().await?;
    Job::execute(&checked_out, sql).await
}
/// Runs `sql` with bound `params` on a job chosen through the same three routing tiers
/// as `execute`: `try_idle`, then `least_busy_scan`, then `fair_queue`.
///
/// With the `tracing` feature enabled, `parameter_logging` decides what is recorded on
/// the span: nothing extra, the JSON type name of each parameter, or the full parameter
/// values. The parameter count is always part of the span fields.
///
/// # Errors
///
/// Returns whatever `Job::execute_with` reports, or the checkout error from
/// `get_or_timeout` when the third tier is reached and the checkout fails.
#[cfg_attr(
    feature = "tracing",
    tracing::instrument(
        skip(self, params),
        fields(
            sql = %sql,
            param_count = params.len(),
            tier = tracing::field::Empty,
            param_types = tracing::field::Empty,
            params = tracing::field::Empty,
        ),
    )
)]
pub async fn execute_with(
    &self,
    sql: &str,
    params: &[serde_json::Value],
) -> crate::Result<crate::query::Rows> {
    use crate::Job;

    /// Tags the current span and bumps the routing counter for the chosen tier.
    // `tier` is unused when neither observability feature is compiled in.
    #[cfg_attr(
        not(any(feature = "tracing", feature = "metrics")),
        allow(unused_variables)
    )]
    fn note_tier(tier: &'static str) {
        #[cfg(feature = "tracing")]
        tracing::Span::current().record("tier", tier);
        #[cfg(feature = "metrics")]
        metrics::counter!(
            crate::observability::POOL_ROUTING_TIER_WINS_TOTAL,
            "tier" => tier,
        )
        .increment(1);
    }

    #[cfg(feature = "tracing")]
    match self.parameter_logging {
        ParameterLogging::None => {}
        ParameterLogging::TypesAndCount => {
            // Record only the JSON kind of each parameter, never its value.
            let mut kinds: Vec<&'static str> = Vec::with_capacity(params.len());
            for value in params {
                kinds.push(match value {
                    serde_json::Value::Null => "Null",
                    serde_json::Value::Bool(_) => "Bool",
                    serde_json::Value::Number(_) => "Number",
                    serde_json::Value::String(_) => "String",
                    serde_json::Value::Array(_) => "Array",
                    serde_json::Value::Object(_) => "Object",
                });
            }
            tracing::Span::current().record("param_types", tracing::field::debug(&kinds));
        }
        ParameterLogging::Full => {
            tracing::Span::current().record("params", tracing::field::debug(params));
        }
    }

    #[cfg(feature = "metrics")]
    emit_pool_status_gauges(&self.inner.status());

    // Tier 1: a job the registry reports as idle.
    if let Some(idle_job) = self.registry.peek_idle() {
        note_tier("try_idle");
        return Job::execute_with(&idle_job, sql, params).await;
    }

    // Tier 2: a busy job that is still below the saturation threshold.
    if let Some(shared_job) = self.pick_unsaturated() {
        note_tier("least_busy_scan");
        return Job::execute_with(&shared_job, sql, params).await;
    }

    // Tier 3: wait for a checkout. The tier is noted before the wait, so it is
    // recorded even when the checkout subsequently fails.
    note_tier("fair_queue");
    let checked_out = self.get_or_timeout().await?;
    Job::execute_with(&checked_out, sql, params).await
}
/// Checks a job out of the pool and wraps it in a `Reserved` handle.
///
/// When `tracing` and/or `metrics` are enabled, the time spent waiting for the checkout
/// is measured in microseconds (saturating at `u64::MAX`) and reported on the span
/// field `acquired_in_micros`, in the acquire-latency histogram, and alongside an
/// increment of the reserved-acquired counter.
///
/// # Errors
///
/// Propagates the checkout error from `get_or_timeout`; nothing is recorded in that case.
#[cfg_attr(
    feature = "tracing",
    tracing::instrument(skip(self), fields(acquired_in_micros = tracing::field::Empty))
)]
pub async fn acquire(&self) -> crate::Result<crate::pool::reserved::Reserved> {
    #[cfg(any(feature = "tracing", feature = "metrics"))]
    let wait_began = std::time::Instant::now();

    let checked_out = self.get_or_timeout().await?;

    #[cfg(any(feature = "tracing", feature = "metrics"))]
    {
        // `as_micros` yields a u128; saturate instead of truncating.
        let waited_micros = u64::try_from(wait_began.elapsed().as_micros()).unwrap_or(u64::MAX);

        #[cfg(feature = "tracing")]
        tracing::Span::current().record("acquired_in_micros", waited_micros);

        #[cfg(feature = "metrics")]
        {
            #[allow(clippy::cast_precision_loss)]
            let waited_f64 = waited_micros as f64;
            metrics::histogram!(crate::observability::POOL_ACQUIRE_LATENCY_MICROS)
                .record(waited_f64);
            metrics::counter!(crate::observability::POOL_RESERVED_ACQUIRED_TOTAL).increment(1);
        }
    }

    Ok(crate::pool::reserved::Reserved::new(checked_out))
}
/// Returns the underlying deadpool pool's current status snapshot.
#[must_use]
pub fn status(&self) -> PoolStatus {
self.inner.status()
}
async fn get_or_timeout(
&self,
) -> crate::Result<deadpool::managed::Object<crate::pool::manager::JobManager>> {
use deadpool::managed::PoolError;
Box::pin(self.inner.get()).await.map_err(|e| match e {
PoolError::Timeout(_) => crate::Error::PoolExhausted {
timeout: self.acquire_timeout.unwrap_or_default(),
},
PoolError::Backend(b) => b,
other => crate::Error::Internal(format!("pool: {other}")),
})
}
}
/// Publishes the `size`, `available` and `waiting` fields of a pool status snapshot as
/// gauges. Only compiled with the `metrics` feature.
#[cfg(feature = "metrics")]
fn emit_pool_status_gauges(s: &PoolStatus) {
    use crate::observability::{POOL_AVAILABLE, POOL_SIZE, POOL_WAITING};

    // Gauges take f64; precision loss on these integer counts is accepted.
    #[allow(clippy::cast_precision_loss)]
    let (size, available, waiting) = (s.size as f64, s.available as f64, s.waiting as f64);

    metrics::gauge!(POOL_SIZE).set(size);
    metrics::gauge!(POOL_AVAILABLE).set(available);
    metrics::gauge!(POOL_WAITING).set(waiting);
}
/// Derives a sweep period from an idle timeout.
///
/// Returns `None` for a zero timeout; otherwise a quarter of the timeout, clamped to the
/// range of one second to one minute.
pub(crate) fn reaper_period(idle_timeout: Duration) -> Option<Duration> {
    const SHORTEST: Duration = Duration::from_secs(1);
    const LONGEST: Duration = Duration::from_secs(60);

    (!idle_timeout.is_zero()).then(|| (idle_timeout / 4).clamp(SHORTEST, LONGEST))
}
/// Spawns a background task that periodically evicts idle objects from `inner`.
///
/// Every `period` the task upgrades a weak reference to the pool and retains only the
/// objects whose `last_used()` duration is below `idle_timeout`. The task holds a strong
/// reference only for the duration of one sweep, and exits on its own once the pool can
/// no longer be upgraded.
///
/// With the `metrics` feature, the number of evicted objects is added to the idle-reaped
/// counter after each sweep that removed something.
///
/// `period` must be non-zero: `tokio::time::interval` panics on a zero period, which
/// would end the spawned task. Must be called from within a tokio runtime, as required
/// by `tokio::spawn`.
pub(crate) fn spawn_idle_reaper(
    inner: &DeadPool<JobManager>,
    idle_timeout: Duration,
    period: Duration,
) -> tokio::task::JoinHandle<()> {
    let weak = inner.weak();
    tokio::spawn(async move {
        let mut interval = tokio::time::interval(period);
        // The default `Burst` behaviour replays every missed tick back-to-back after a
        // stall, which would run redundant sweeps; `Delay` schedules one tick and then
        // resumes the normal spacing from there.
        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
        // The first tick of an interval completes immediately; consume it so the first
        // sweep happens one full period after spawning.
        interval.tick().await;
        loop {
            interval.tick().await;
            let Some(pool) = weak.upgrade() else {
                // The pool is gone, so there is nothing left to reap.
                return;
            };
            let result = pool.retain(|_, object_metrics| object_metrics.last_used() < idle_timeout);
            #[cfg(feature = "metrics")]
            if !result.removed.is_empty() {
                // Saturate rather than truncate, matching the conversion used in `acquire`.
                let reaped = u64::try_from(result.removed.len()).unwrap_or(u64::MAX);
                metrics::counter!(crate::observability::POOL_IDLE_REAPED_TOTAL).increment(reaped);
            }
            #[cfg(not(feature = "metrics"))]
            let _ = result;
            // Release the strong reference before sleeping so this task never keeps the
            // pool alive between sweeps.
            drop(pool);
        }
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a whole-second `Duration`.
    fn secs(n: u64) -> Duration {
        Duration::from_secs(n)
    }

    #[test]
    fn reaper_period_quarters_short_timeouts_clamped_to_one_second() {
        // 2s / 4 = 500ms, which is below the one-second floor.
        let period = reaper_period(secs(2));
        assert_eq!(period, Some(secs(1)));
    }

    #[test]
    fn reaper_period_quarters_medium_timeouts_unclamped() {
        // 60s / 4 = 15s, which lies inside the clamp range.
        let period = reaper_period(secs(60));
        assert_eq!(period, Some(secs(15)));
    }

    #[test]
    fn reaper_period_clamps_long_timeouts_to_one_minute() {
        // 3600s / 4 = 900s, which is above the one-minute ceiling.
        let period = reaper_period(secs(3600));
        assert_eq!(period, Some(secs(60)));
    }

    #[test]
    fn reaper_period_returns_none_for_zero() {
        let period = reaper_period(Duration::ZERO);
        assert_eq!(period, None);
    }
}