use deadpool::managed::{Object, Timeouts};
use std::{
sync::{
Arc,
atomic::{AtomicU64, Ordering},
},
time::Duration,
};
use toasty_core::driver::{Capability, Driver};
use tokio::{sync::Notify, task::JoinHandle};
use super::connection_task::{ConnectionHandle, ConnectionOperation};
use crate::engine::Engine;
/// Returns the maximum pool size that `deadpool` uses when none is configured.
fn get_default_pool_max_size() -> usize {
    let defaults = deadpool::managed::PoolConfig::default();
    defaults.max_size
}
/// How long the background sweep waits for a ping reply before treating the
/// connection as unhealthy.
const DEFAULT_SWEEP_PING_TIMEOUT: Duration = Duration::from_millis(5_000);
#[derive(Debug, Clone)]
pub(crate) struct PoolConfig {
pub(crate) max_size: usize,
pub(crate) timeouts: Timeouts,
pub(crate) health_check_interval: Option<Duration>,
pub(crate) pre_ping: bool,
pub(crate) max_connection_lifetime: Option<Duration>,
pub(crate) max_connection_idle_time: Option<Duration>,
}
impl Default for PoolConfig {
fn default() -> Self {
Self {
max_size: get_default_pool_max_size(),
timeouts: Default::default(),
health_check_interval: Some(Duration::from_secs(60)),
pre_ping: false,
max_connection_lifetime: None,
max_connection_idle_time: None,
}
}
}
/// Pool of database connections backed by `deadpool`, plus an optional
/// background health-check (sweep) task.
#[derive(Debug)]
pub struct Pool {
    /// Underlying `deadpool` pool.
    inner: deadpool::managed::Pool<Manager>,
    /// Driver capabilities, captured when the pool is built.
    capability: &'static Capability,
    /// Handle to the background sweep task, present only if one was spawned.
    sweep_task: Option<JoinHandle<()>>,
}

impl Drop for Pool {
    /// Aborts the background sweep task, if one was spawned.
    fn drop(&mut self) {
        let Some(task) = self.sweep_task.take() else {
            return;
        };
        task.abort();
    }
}
impl Pool {
    /// Builds a connection pool around `driver`.
    ///
    /// The effective maximum pool size is `config.max_size`, lowered to the
    /// driver's `max_connections()` when the driver reports a smaller limit.
    /// When `config.health_check_interval` is a non-zero duration, a background
    /// sweep task is spawned; it is aborted when the pool is dropped. A zero
    /// interval disables the sweep (with a warning) instead of spawning a task
    /// that would panic, since `tokio::time::interval` rejects a zero period.
    ///
    /// # Errors
    ///
    /// Returns a connection-pool error if the underlying `deadpool` pool fails
    /// to build.
    ///
    /// # Panics
    ///
    /// Spawning the sweep task uses `tokio::spawn`, which panics when called
    /// outside a Tokio runtime.
    pub(crate) fn new(
        driver: impl Driver,
        engine: Engine,
        config: PoolConfig,
    ) -> crate::Result<Self> {
        let capability = driver.capability();

        // Never exceed the driver's own connection limit, if it reports one.
        let driver_cap = driver.max_connections();
        let effective_max = match driver_cap {
            Some(cap) if cap < config.max_size => {
                tracing::warn!(
                    requested = config.max_size,
                    cap,
                    "driver caps max pool size below the requested value; using driver cap"
                );
                cap
            }
            _ => config.max_size,
        };

        // Shared between the manager (and through it every connection task) and
        // the sweep task, so connections can request an early sweep.
        let sweep_waker = Arc::new(SweepWaker::new());
        let inner = deadpool::managed::Pool::builder(Manager {
            driver: Box::new(driver),
            engine,
            sweep_waker: sweep_waker.clone(),
            pre_ping: config.pre_ping,
            max_connection_lifetime: config.max_connection_lifetime,
            max_connection_idle_time: config.max_connection_idle_time,
        })
        .runtime(deadpool::Runtime::Tokio1)
        .max_size(effective_max)
        .timeouts(config.timeouts)
        .build()
        .map_err(|e| {
            tracing::error!(error = %e, "failed to build connection pool");
            toasty_core::Error::connection_pool(e)
        })?;

        let sweep_task = match config.health_check_interval {
            // `tokio::time::interval` panics on a zero period, which would kill
            // the sweep task on its first poll. Treat zero as "disabled".
            Some(interval) if interval.is_zero() => {
                tracing::warn!(
                    "health_check_interval is zero; background health checks are disabled"
                );
                None
            }
            Some(interval) => {
                let task = SweepTask {
                    pool: inner.clone(),
                    waker: sweep_waker,
                    interval,
                    last_serviced: 0,
                };
                Some(tokio::spawn(task.run()))
            }
            None => None,
        };

        Ok(Self {
            inner,
            capability,
            sweep_task,
        })
    }

    /// Checks out a connection from the pool, honoring the pool's configured
    /// timeouts, and wraps it together with `shared`.
    ///
    /// # Errors
    ///
    /// Returns a connection-pool error if no connection could be obtained.
    pub(crate) async fn get(&self, shared: Arc<super::Shared>) -> crate::Result<super::Connection> {
        let connection = self.inner.get().await.map_err(|e| {
            tracing::error!(error = %e, "failed to acquire connection from pool");
            toasty_core::Error::connection_pool(e)
        })?;
        Ok(super::Connection {
            inner: connection,
            shared,
        })
    }

    /// Returns the driver this pool opens connections with.
    pub fn driver(&self) -> &dyn Driver {
        self.inner.manager().driver.as_ref()
    }

    /// Returns the driver's capabilities, captured when the pool was built.
    pub fn capability(&self) -> &'static Capability {
        self.capability
    }

    /// Returns a snapshot of the pool's current size and occupancy.
    pub fn status(&self) -> PoolStatus {
        let s = self.inner.status();
        PoolStatus {
            max_size: s.max_size,
            size: s.size,
            available: s.available,
            waiting: s.waiting,
        }
    }
}
/// `deadpool` manager that opens connections through the driver and decides
/// whether idle connections may be handed out again.
pub(super) struct Manager {
    /// Driver used to open new database connections.
    driver: Box<dyn Driver>,
    /// Engine handed to each spawned connection task.
    engine: Engine,
    /// Wake handle passed to every connection task; the sweep task holds the
    /// same `Arc`.
    sweep_waker: Arc<SweepWaker>,
    /// When `true`, ping each idle connection before handing it out.
    pre_ping: bool,
    /// Discard connections older than this when they are recycled.
    max_connection_lifetime: Option<Duration>,
    /// Discard connections whose `Metrics::last_used` exceeds this when they
    /// are recycled.
    max_connection_idle_time: Option<Duration>,
}

impl std::fmt::Debug for Manager {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `engine` and `sweep_waker` are intentionally omitted;
        // `finish_non_exhaustive` makes the omission visible in the output.
        f.debug_struct("Manager")
            .field("driver", &self.driver)
            .field("pre_ping", &self.pre_ping)
            .field("max_connection_lifetime", &self.max_connection_lifetime)
            .field("max_connection_idle_time", &self.max_connection_idle_time)
            .finish_non_exhaustive()
    }
}
impl deadpool::managed::Manager for Manager {
type Type = ConnectionHandle;
type Error = crate::Error;
async fn create(&self) -> Result<Self::Type, Self::Error> {
tracing::debug!("creating new pooled connection");
let connection = self.driver.connect().await.inspect_err(|e| {
tracing::error!(error = %e, "failed to create database connection");
})?;
Ok(ConnectionHandle::spawn(
connection,
self.engine.clone(),
self.sweep_waker.clone(),
))
}
async fn recycle(
&self,
obj: &mut Self::Type,
metrics: &deadpool::managed::Metrics,
) -> deadpool::managed::RecycleResult<Self::Error> {
if let Some(max) = self.max_connection_lifetime
&& metrics.age() >= max
{
tracing::debug!(?max, "discarding pooled connection past max lifetime");
return Err(deadpool::managed::RecycleError::message(
"connection exceeded max lifetime",
));
}
if let Some(max) = self.max_connection_idle_time
&& metrics.last_used() >= max
{
tracing::debug!(?max, "discarding pooled connection past max idle time");
return Err(deadpool::managed::RecycleError::message(
"connection exceeded max idle time",
));
}
if obj.in_tx.is_closed() || obj.is_finished() {
tracing::debug!("discarding dead pooled connection");
return Err(deadpool::managed::RecycleError::message(
"background task is no longer running",
));
}
if self.pre_ping {
let (tx, rx) = tokio::sync::oneshot::channel();
if obj.in_tx.send(ConnectionOperation::Ping { tx }).is_err() {
tracing::debug!("pre-ping channel closed; discarding pooled connection");
return Err(deadpool::managed::RecycleError::message(
"background task is no longer running",
));
}
match rx.await {
Ok(Ok(())) => {}
Ok(Err(err)) => {
tracing::debug!(error = %err, "pre-ping failed; discarding pooled connection");
return Err(deadpool::managed::RecycleError::Backend(err));
}
Err(_) => {
tracing::debug!(
"pre-ping response channel dropped; discarding pooled connection"
);
return Err(deadpool::managed::RecycleError::message(
"background task exited during pre-ping",
));
}
}
}
tracing::trace!("recycling pooled connection");
Ok(())
}
}
/// Wake-up signal that lets connections request an early sweep.
///
/// `requests` counts every wake request ever issued. The sweep task compares
/// it against the count it last serviced, so it can ignore a leftover
/// notification that an earlier sweep already covered.
pub(crate) struct SweepWaker {
    requests: AtomicU64,
    notify: Notify,
}

impl SweepWaker {
    /// Creates a waker with no recorded requests.
    fn new() -> Self {
        let requests = AtomicU64::new(0);
        let notify = Notify::new();
        Self { requests, notify }
    }

    /// Records one wake request, then notifies the sweep task.
    ///
    /// The counter is bumped before notifying, so that a sweep woken by this
    /// call is intended to observe the new count.
    pub(crate) fn wake(&self) {
        let _previous = self.requests.fetch_add(1, Ordering::Relaxed);
        self.notify.notify_one();
    }
}
/// Background task that health-checks idle pooled connections.
///
/// Every `interval` it pings one idle connection. It escalates to a full sweep
/// of all idle connections when that ping fails or when a [`SweepWaker`] wake
/// request arrives that no completed sweep has covered yet.
struct SweepTask {
    /// Clone of the pool handle whose idle connections are checked.
    pool: deadpool::managed::Pool<Manager>,
    /// Shared wake signal; `waker.requests` counts wake requests issued so far.
    waker: Arc<SweepWaker>,
    /// Period between periodic checks.
    interval: Duration,
    /// Value of `waker.requests` snapshotted at the start of the last
    /// completed full sweep.
    last_serviced: u64,
}

impl SweepTask {
    /// Main loop: runs periodic checks and wake-triggered full sweeps.
    ///
    /// Loops forever once started; the owning [`Pool`] aborts the task on drop.
    /// Returns immediately if `interval` is zero, because `tokio::time::interval`
    /// panics on a zero period.
    async fn run(mut self) {
        if self.interval.is_zero() {
            tracing::warn!("sweep interval is zero; background health checks are disabled");
            return;
        }
        let mut ticker = tokio::time::interval(self.interval);
        ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
        // The first tick completes immediately. Consume it so the first periodic
        // check happens one full interval after startup.
        ticker.tick().await;
        loop {
            tokio::select! {
                _ = ticker.tick() => {
                    self.periodic_iteration().await;
                }
                _ = self.waker.notify.notified() => {
                    // `Notify` can hold a stored permit from a wake that an earlier
                    // full sweep already covered; the counter filters those out.
                    if self.waker.requests.load(Ordering::Relaxed) <= self.last_serviced {
                        tracing::trace!("sweep notify already serviced; skipping");
                        continue;
                    }
                    tracing::debug!("sweep woken by observed connection_lost; escalating");
                    self.escalate().await;
                }
            }
        }
    }

    /// Pings a single idle connection. A healthy one is returned to the pool; a
    /// failing one is detached from the pool and a full sweep follows.
    async fn periodic_iteration(&mut self) {
        if self.pool.status().available == 0 {
            return;
        }
        let Some(conn) = self.try_get_idle().await else {
            return;
        };
        if Self::ping_conn(&conn).await {
            // Dropping the `Object` hands the connection back to the pool.
            drop(conn);
        } else {
            // `Object::take` detaches the connection so it is not returned.
            let _ = Object::take(conn);
            self.escalate().await;
        }
    }

    /// Pings up to `available` idle connections and discards those that fail.
    ///
    /// Healthy connections are held until the loop ends, so this loop does not
    /// get the same connection back from the pool. Afterwards
    /// `last_serviced` is set to the request count snapshotted before the sweep.
    async fn escalate(&mut self) {
        // Snapshot *before* sweeping: wakes that arrive mid-sweep stay above
        // `last_serviced` and trigger another sweep.
        let snap = self.waker.requests.load(Ordering::Relaxed);
        let budget = self.pool.status().available;
        let mut healthy = Vec::with_capacity(budget);
        for _ in 0..budget {
            let Some(conn) = self.try_get_idle().await else {
                break;
            };
            if Self::ping_conn(&conn).await {
                healthy.push(conn);
            } else {
                let _ = Object::take(conn);
            }
        }
        // Return all healthy connections to the pool at once.
        drop(healthy);
        self.last_serviced = snap;
    }

    /// Tries to check out a connection without waiting for a free slot.
    ///
    /// `deadpool`'s recycle step (and thus `Manager::recycle`) still runs, with
    /// the pool's own recycle timeout.
    /// NOTE(review): a zero `create` timeout only fails a connect that does not
    /// complete on its first poll — confirm this is how new connections are
    /// meant to be avoided here.
    async fn try_get_idle(&self) -> Option<Object<Manager>> {
        let timeouts = Timeouts {
            wait: Some(Duration::ZERO),
            create: Some(Duration::ZERO),
            recycle: self.pool.timeouts().recycle,
        };
        self.pool.timeout_get(&timeouts).await.ok()
    }

    /// Sends a ping through the connection's background task and waits up to
    /// `DEFAULT_SWEEP_PING_TIMEOUT` for the reply.
    ///
    /// Returns `true` only on a successful ping. Every failure path is logged at
    /// debug level.
    async fn ping_conn(handle: &ConnectionHandle) -> bool {
        let (tx, rx) = tokio::sync::oneshot::channel();
        if handle.in_tx.send(ConnectionOperation::Ping { tx }).is_err() {
            tracing::debug!("sweep ping channel closed");
            return false;
        }
        match tokio::time::timeout(DEFAULT_SWEEP_PING_TIMEOUT, rx).await {
            Ok(Ok(Ok(()))) => true,
            Ok(Ok(Err(err))) => {
                tracing::debug!(error = %err, "sweep ping failed");
                false
            }
            Ok(Err(_)) => {
                tracing::debug!("sweep ping response channel dropped");
                false
            }
            Err(_) => {
                tracing::debug!("sweep ping timed out");
                false
            }
        }
    }
}
/// Point-in-time snapshot of pool occupancy, as returned by [`Pool::status`].
///
/// Derives `PartialEq`/`Eq` so callers can compare snapshots directly.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct PoolStatus {
    /// Maximum number of connections the pool may hold.
    pub max_size: usize,
    /// Number of connections the pool currently manages.
    pub size: usize,
    /// Number of idle connections available for checkout.
    pub available: usize,
    /// Number of callers currently waiting for a connection.
    pub waiting: usize,
}