use super::connection::{Floating, Idle, Live};
use crate::connection::ConnectOptions;
use crate::connection::Connection;
use crate::database::Database;
use crate::error::Error;
use crate::pool::{deadline_as_timeout, CloseEvent, PoolOptions};
use crossbeam_queue::ArrayQueue;
use futures_intrusive::sync::{Semaphore, SemaphoreReleaser};
use std::cmp;
use std::future::Future;
use std::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering};
use std::sync::Arc;
use crate::pool::options::PoolConnectionMetadata;
use std::time::{Duration, Instant};
/// Number of permits `PoolInner::close()` releases on the semaphore to wake tasks blocked on it.
///
/// Half of `usize::MAX` (a one-bit right shift), so that adding `max_connections` on top of it,
/// as `close()` does when it later acquires every permit, stays within `usize`.
const WAKE_ALL_PERMITS: usize = usize::MAX >> 1;
/// Shared state behind a connection pool, handled as `Arc<PoolInner<DB>>`.
pub(crate) struct PoolInner<DB: Database> {
    /// Options passed to `connect()` each time a new connection is opened.
    pub(super) connect_options: <DB::Connection as Connection>::Options,
    /// Connections waiting to be acquired; created with capacity `max_connections` in `new_arc()`.
    pub(super) idle_conns: ArrayQueue<Idle<DB>>,
    /// Permits gating access to connections; starts with `max_connections` permits.
    /// `close()` additionally releases `WAKE_ALL_PERMITS` to wake waiters.
    pub(super) semaphore: Semaphore,
    /// Number of connections counted against `max_connections`: raised by
    /// `try_increment_size()`, lowered when an uncancelled `DecrementSizeGuard` is dropped.
    pub(super) size: AtomicU32,
    /// Running count of connections in `idle_conns`, maintained alongside pushes and pops.
    pub(super) num_idle: AtomicUsize,
    /// Set to `true` by `close()`; read through `is_closed()`.
    is_closed: AtomicBool,
    /// Notified by `close()`; `close_event()` hands out listeners on it.
    pub(super) on_closed: event_listener::Event,
    /// The options this pool was built with.
    pub(super) options: PoolOptions<DB>,
}
impl<DB: Database> PoolInner<DB> {
pub(super) fn new_arc(
options: PoolOptions<DB>,
connect_options: <DB::Connection as Connection>::Options,
) -> Arc<Self> {
let capacity = options.max_connections as usize;
let _ = capacity
.checked_add(WAKE_ALL_PERMITS)
.expect("max_connections exceeds max capacity of the pool");
let pool = Self {
connect_options,
idle_conns: ArrayQueue::new(capacity),
semaphore: Semaphore::new(options.fair, capacity),
size: AtomicU32::new(0),
num_idle: AtomicUsize::new(0),
is_closed: AtomicBool::new(false),
on_closed: event_listener::Event::new(),
options,
};
let pool = Arc::new(pool);
spawn_maintenance_tasks(&pool);
pool
}
pub(super) fn size(&self) -> u32 {
self.size.load(Ordering::Acquire)
}
pub(super) fn num_idle(&self) -> usize {
self.num_idle.load(Ordering::Acquire)
}
pub(super) fn is_closed(&self) -> bool {
self.is_closed.load(Ordering::Acquire)
}
pub(super) fn close<'a>(self: &'a Arc<Self>) -> impl Future<Output = ()> + 'a {
let already_closed = self.is_closed.swap(true, Ordering::AcqRel);
if !already_closed {
self.semaphore.release(WAKE_ALL_PERMITS);
self.on_closed.notify(usize::MAX);
}
async move {
while let Some(idle) = self.idle_conns.pop() {
let _ = idle.live.float((*self).clone()).close().await;
}
let _permits = self
.semaphore
.acquire(WAKE_ALL_PERMITS + (self.options.max_connections as usize))
.await;
while let Some(idle) = self.idle_conns.pop() {
let _ = idle.live.float((*self).clone()).close().await;
}
}
}
pub(crate) fn close_event(&self) -> CloseEvent {
CloseEvent {
listener: (!self.is_closed()).then(|| self.on_closed.listen()),
}
}
#[inline]
pub(super) fn try_acquire(self: &Arc<Self>) -> Option<Floating<DB, Idle<DB>>> {
if self.is_closed() {
return None;
}
let permit = self.semaphore.try_acquire(1)?;
self.pop_idle(permit).ok()
}
fn pop_idle<'a>(
self: &'a Arc<Self>,
permit: SemaphoreReleaser<'a>,
) -> Result<Floating<DB, Idle<DB>>, SemaphoreReleaser<'a>> {
if let Some(idle) = self.idle_conns.pop() {
self.num_idle.fetch_sub(1, Ordering::AcqRel);
Ok(Floating::from_idle(idle, (*self).clone(), permit))
} else {
Err(permit)
}
}
pub(super) fn release(&self, floating: Floating<DB, Live<DB>>) {
let Floating { inner: idle, guard } = floating.into_idle();
if !self.idle_conns.push(idle).is_ok() {
panic!("BUG: connection queue overflow in release()");
}
guard.release_permit();
self.num_idle.fetch_add(1, Ordering::AcqRel);
}
pub(super) fn try_increment_size<'a>(
self: &'a Arc<Self>,
permit: SemaphoreReleaser<'a>,
) -> Result<DecrementSizeGuard<DB>, SemaphoreReleaser<'a>> {
match self
.size
.fetch_update(Ordering::AcqRel, Ordering::Acquire, |size| {
size.checked_add(1)
.filter(|size| size <= &self.options.max_connections)
}) {
Ok(_) => Ok(DecrementSizeGuard::from_permit((*self).clone(), permit)),
Err(_) => Err(permit),
}
}
pub(super) async fn acquire(self: &Arc<Self>) -> Result<Floating<DB, Live<DB>>, Error> {
if self.is_closed() {
return Err(Error::PoolClosed);
}
let deadline = Instant::now() + self.options.acquire_timeout;
sqlx_rt::timeout(
self.options.acquire_timeout,
async {
loop {
let permit = self.semaphore.acquire(1).await;
if self.is_closed() {
return Err(Error::PoolClosed);
}
let guard = match self.pop_idle(permit) {
Ok(conn) => match check_idle_conn(conn, &self.options).await {
Ok(live) => return Ok(live),
Err(guard) => guard,
},
Err(permit) => if let Ok(guard) = self.try_increment_size(permit) {
guard
} else {
log::debug!("woke but was unable to acquire idle connection or open new one; retrying");
continue;
}
};
return self.connect(deadline, guard).await;
}
}
)
.await
.map_err(|_| Error::PoolTimedOut)?
}
pub(super) async fn connect(
self: &Arc<Self>,
deadline: Instant,
guard: DecrementSizeGuard<DB>,
) -> Result<Floating<DB, Live<DB>>, Error> {
if self.is_closed() {
return Err(Error::PoolClosed);
}
let mut backoff = Duration::from_millis(10);
let max_backoff = deadline_as_timeout::<DB>(deadline)? / 5;
loop {
let timeout = deadline_as_timeout::<DB>(deadline)?;
match sqlx_rt::timeout(timeout, self.connect_options.connect()).await {
Ok(Ok(mut raw)) => {
let meta = PoolConnectionMetadata {
age: Duration::ZERO,
idle_for: Duration::ZERO,
};
let res = if let Some(callback) = &self.options.after_connect {
callback(&mut raw, meta).await
} else {
Ok(())
};
match res {
Ok(()) => return Ok(Floating::new_live(raw, guard)),
Err(e) => {
log::error!("error returned from after_connect: {:?}", e);
let _ = raw.close_hard().await;
}
}
}
Ok(Err(Error::Io(e))) if e.kind() == std::io::ErrorKind::ConnectionRefused => (),
Ok(Err(Error::Database(error))) if error.is_transient_in_connect_phase() => (),
Ok(Err(e)) => return Err(e),
Err(_) => return Err(Error::PoolTimedOut),
}
sqlx_rt::sleep(backoff).await;
backoff = cmp::min(backoff * 2, max_backoff);
}
}
pub async fn try_min_connections(self: &Arc<Self>, deadline: Instant) -> Result<(), Error> {
macro_rules! unwrap_or_return {
($expr:expr) => {
match $expr {
Some(val) => val,
None => return Ok(()),
}
};
}
while self.size() < self.options.min_connections {
let permit = unwrap_or_return!(self.semaphore.try_acquire(1));
let guard = unwrap_or_return!(self.try_increment_size(permit).ok());
self.release(self.connect(deadline, guard).await?);
}
Ok(())
}
pub async fn min_connections_maintenance(self: &Arc<Self>, deadline: Option<Instant>) {
let deadline = deadline.unwrap_or_else(|| {
Instant::now() + Duration::from_secs(300)
});
match self.try_min_connections(deadline).await {
Ok(()) => (),
Err(Error::PoolClosed) => (),
Err(Error::PoolTimedOut) => {
log::debug!("unable to complete `min_connections` maintenance before deadline")
}
Err(e) => log::debug!("error while maintaining min_connections: {:?}", e),
}
}
}
/// Reports whether `live` has existed for longer than the pool's `max_lifetime`.
///
/// Always `false` when no `max_lifetime` is configured.
fn is_beyond_max_lifetime<DB: Database>(live: &Live<DB>, options: &PoolOptions<DB>) -> bool {
    match options.max_lifetime {
        Some(limit) => live.created_at.elapsed() > limit,
        None => false,
    }
}
/// Reports whether `idle` has sat in the idle queue for longer than the pool's `idle_timeout`.
///
/// Always `false` when no `idle_timeout` is configured.
fn is_beyond_idle_timeout<DB: Database>(idle: &Idle<DB>, options: &PoolOptions<DB>) -> bool {
    if let Some(timeout) = options.idle_timeout {
        idle.idle_since.elapsed() > timeout
    } else {
        false
    }
}
/// Decides whether an idle connection may be handed out.
///
/// A connection is rejected, in this order:
/// * when it is past `max_lifetime` (closed gracefully);
/// * when `test_before_acquire` is set and a ping fails (closed hard);
/// * when the `before_acquire` hook returns `Ok(false)` (closed gracefully);
/// * when the `before_acquire` hook returns an error (closed hard).
///
/// On rejection the `DecrementSizeGuard` produced by closing is returned in `Err`, so the
/// caller can reuse the size slot. Otherwise the connection is returned as live.
async fn check_idle_conn<DB: Database>(
    mut conn: Floating<DB, Idle<DB>>,
    options: &PoolOptions<DB>,
) -> Result<Floating<DB, Live<DB>>, DecrementSizeGuard<DB>> {
    if is_beyond_max_lifetime(&conn, options) {
        return Err(conn.close().await);
    }

    if options.test_before_acquire {
        if let Err(e) = conn.ping().await {
            log::info!("ping on idle connection returned error: {}", e);
            // A failed ping means the link is broken; don't attempt a graceful close.
            return Err(conn.close_hard().await);
        }
    }

    // With no hook configured, the connection is accepted as-is.
    let verdict = match &options.before_acquire {
        None => Ok(true),
        Some(hook) => {
            let meta = conn.metadata();
            hook(&mut conn.live.raw, meta).await
        }
    };

    match verdict {
        Ok(true) => Ok(conn.into_live()),
        Ok(false) => Err(conn.close().await),
        Err(error) => {
            log::warn!("error from `before_acquire`: {}", error);
            Err(conn.close_hard().await)
        }
    }
}
/// Spawns the background task(s) that keep the pool healthy.
///
/// With neither `max_lifetime` nor `idle_timeout` configured, only a one-shot
/// `min_connections` top-up is spawned, and only when `min_connections > 0`.
///
/// Otherwise a periodic task runs every `min(max_lifetime, idle_timeout)`. Each round it tops
/// the pool up to `min_connections`, sleeps until the next run, then reaps expired idle
/// connections. The periodic task is cancelled when the pool is closed.
///
/// The tasks hold only a `Weak` reference and upgrade it for the duration of each unit of
/// work. An otherwise-dropped pool (and all of its connections) is therefore never kept alive
/// by its own maintenance task. The periodic task exits once the pool is gone, even if
/// `close()` was never called.
fn spawn_maintenance_tasks<DB: Database>(pool: &Arc<PoolInner<DB>>) {
    let pool_weak = Arc::downgrade(pool);

    let period = match (pool.options.max_lifetime, pool.options.idle_timeout) {
        (Some(it), None) | (None, Some(it)) => it,

        (Some(a), Some(b)) => cmp::min(a, b),

        (None, None) => {
            if pool.options.min_connections > 0 {
                sqlx_rt::spawn(async move {
                    if let Some(pool) = pool_weak.upgrade() {
                        pool.min_connections_maintenance(None).await;
                    }
                });
            }

            return;
        }
    };

    // Obtained here so the spawned task does not need a strong reference just to listen.
    let mut close_event = pool.close_event();

    sqlx_rt::spawn(async move {
        // Immediately cancel this future if the pool is closed.
        let _ = close_event
            .do_until(async {
                while let Some(pool) = pool_weak.upgrade() {
                    if pool.is_closed() {
                        return;
                    }

                    let next_run = Instant::now() + period;

                    pool.min_connections_maintenance(Some(next_run)).await;

                    // Don't keep the pool alive while sleeping.
                    drop(pool);

                    if let Some(duration) = next_run.checked_duration_since(Instant::now()) {
                        sqlx_rt::sleep(duration).await;
                    } else {
                        sqlx_rt::yield_now().await;
                    }

                    // The last handle to the pool may have been dropped while we slept.
                    let pool = match pool_weak.upgrade() {
                        Some(pool) => pool,
                        None => return,
                    };

                    // The reaper runs only after the first sleep, never right away.
                    if !pool.idle_conns.is_empty() {
                        do_reap(&pool).await;
                    }
                }
            })
            .await;
    });
}
/// Closes idle connections that have outlived `idle_timeout` or `max_lifetime`.
///
/// At most `size() - min_connections` connections are examined (saturating at zero). Each one
/// is taken out of the idle queue with `try_acquire()`. The still-fresh connections are
/// released back to the pool before any expired connection is closed.
async fn do_reap<DB: Database>(pool: &Arc<PoolInner<DB>>) {
    let options = &pool.options;
    let budget = pool.size().saturating_sub(options.min_connections);

    let mut expired = Vec::new();
    let mut fresh = Vec::new();

    for _ in 0..budget {
        let conn = match pool.try_acquire() {
            Some(conn) => conn,
            None => continue,
        };

        if is_beyond_idle_timeout(&conn, options) || is_beyond_max_lifetime(&conn, options) {
            expired.push(conn);
        } else {
            fresh.push(conn);
        }
    }

    // Hand the healthy connections back first, then close the expired ones.
    fresh
        .into_iter()
        .for_each(|conn| pool.release(conn.into_live()));

    for conn in expired {
        let _ = conn.close().await;
    }
}
/// Accounts for one unit of `PoolInner::size` plus one semaphore permit.
///
/// When dropped without having been cancelled, it decrements the pool size and releases the
/// permit.
pub(in crate::pool) struct DecrementSizeGuard<DB: Database> {
    /// The pool whose `size` and `semaphore` are adjusted when this guard is dropped.
    pub(crate) pool: Arc<PoolInner<DB>>,
    /// When `true` (set by `cancel()`), dropping the guard has no effect on the pool.
    cancelled: bool,
}
impl<DB: Database> DecrementSizeGuard<DB> {
    /// Creates an armed guard.
    ///
    /// Unless cancelled, dropping it decrements `pool.size` and releases one semaphore permit.
    pub fn new_permit(pool: Arc<PoolInner<DB>>) -> Self {
        let cancelled = false;
        Self { pool, cancelled }
    }

    /// Takes over the permit held by `permit`.
    ///
    /// The releaser is disarmed so it will not return the permit on drop; the returned guard
    /// becomes responsible for it.
    pub fn from_permit(pool: Arc<PoolInner<DB>>, mut permit: SemaphoreReleaser<'_>) -> Self {
        permit.disarm();
        Self::new_permit(pool)
    }

    /// Gives one permit back to the semaphore while leaving `pool.size` untouched.
    fn release_permit(self) {
        let semaphore = &self.pool.semaphore;
        semaphore.release(1);
        self.cancel();
    }

    /// Disarms the guard, so dropping it neither decrements the size nor releases a permit.
    pub fn cancel(mut self) {
        self.cancelled = true;
    }
}
impl<DB: Database> Drop for DecrementSizeGuard<DB> {
    /// Unless the guard was cancelled, gives back the size slot and the permit it accounts for.
    fn drop(&mut self) {
        if self.cancelled {
            return;
        }

        let pool = &self.pool;
        pool.size.fetch_sub(1, Ordering::AcqRel);
        pool.semaphore.release(1);
    }
}