use std::collections::VecDeque;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::{Duration, Instant};
use parking_lot::Mutex;
use tokio::sync::Notify;
use super::auto_eject::{AutoEject, AutoEjectState};
use super::NetError;
/// Tuning knobs for a [`ConnPool`].
#[derive(Debug, Clone)]
pub struct ConnPoolConfig {
    /// Upper bound on idle plus checked-out connections.
    pub max_connections: usize,
    /// Connect failures tolerated before the server is ejected
    /// (`ConnPool::new` raises 0 to 1 before handing it to auto-eject).
    pub server_failure_limit: u32,
    /// Retry timeout in milliseconds given to the auto-eject tracker; also the
    /// cap (at least one second) for the reconnect backoff.
    pub server_retry_timeout_ms: u64,
    /// Whether auto-eject is enabled.
    pub auto_eject: bool,
}

impl Default for ConnPoolConfig {
    /// One connection, eject after 3 failures, 30 s retry timeout, auto-eject on.
    fn default() -> Self {
        ConnPoolConfig {
            auto_eject: true,
            server_retry_timeout_ms: 30_000,
            server_failure_limit: 3,
            max_connections: 1,
        }
    }
}
/// Boxed `Send` future that resolves to a freshly opened connection.
pub type ConnFuture<C> = Pin<Box<dyn Future<Output = Result<C, NetError>> + Send + 'static>>;

/// Source of new connections for a [`ConnPool`].
///
/// Any `Send + Sync + 'static` closure `Fn() -> impl Future<Output =
/// Result<C, NetError>> + Send + 'static` implements this trait through the
/// blanket impl that follows.
pub trait ConnFactory<C>: Send + Sync + 'static {
    /// Returns a future that opens one connection.
    fn connect(&self) -> ConnFuture<C>;
}

impl<C, F, Fut> ConnFactory<C> for F
where
    F: Fn() -> Fut + Send + Sync + 'static,
    Fut: Future<Output = Result<C, NetError>> + Send + 'static,
{
    fn connect(&self) -> ConnFuture<C> {
        // Call the closure, then erase the concrete future type behind a box.
        let pending: Fut = (self)();
        Box::pin(pending)
    }
}
/// Mutable pool state shared by every clone of a `ConnPool` (kept behind its mutex).
struct PoolInner<C> {
    /// Configuration the pool was created with.
    cfg: ConnPoolConfig,
    /// Connections ready for reuse; pushed at the back, taken from the front.
    idle: VecDeque<C>,
    /// Capacity slots in use outside `idle`; `in_flight + idle.len()` is
    /// compared against the configured `max_connections` by `ConnPool::get`.
    in_flight: usize,
    /// Failure tracker that decides when the server counts as ejected.
    auto_eject: AutoEject,
    /// Reconnect backoff, updated on every connect success/failure.
    /// NOTE(review): the delay it computes is discarded by `ConnPool::get` —
    /// confirm whether it was meant to throttle reconnect attempts.
    backoff: Backoff,
    /// Set by `ConnPool::shutdown`; once true, `get` fails with
    /// `PoolShutdown` and returned connections are dropped, not pooled.
    shutdown: bool,
}
/// Exponential reconnect backoff.
///
/// The first failure (since construction or the last success) yields one
/// second. Each further failure doubles the delay. Every step is clamped to
/// `max`.
#[derive(Debug, Clone)]
struct Backoff {
    /// Delay returned by the most recent `record_failure`; zero after a reset.
    current: Duration,
    /// Upper bound for `current`.
    max: Duration,
}

impl Backoff {
    /// Delay used for the first failure in a run.
    const INITIAL: Duration = Duration::from_secs(1);

    /// Creates a backoff in the reset state whose delays never exceed `max`.
    fn new(max: Duration) -> Self {
        Self {
            current: Duration::ZERO,
            max,
        }
    }

    /// Records a failure and returns the delay to wait before retrying.
    fn record_failure(&mut self) -> Duration {
        let next = if self.current.is_zero() {
            Self::INITIAL
        } else {
            self.current.saturating_mul(2)
        };
        // Clamp every step, including the first. A `max` below one second
        // must still be honoured, and `max == 0` must not oscillate between
        // 1 s and 0.
        self.current = next.min(self.max);
        self.current
    }

    /// Resets the backoff so the next failure starts again at the initial delay.
    fn record_success(&mut self) {
        self.current = Duration::ZERO;
    }
}
/// Cheaply clonable handle to a shared connection pool.
///
/// Clones share the pool state and the wakeup `Notify`. Each clone holds its
/// own `Option<Arc<..>>` to the factory, so `set_factory` on one clone does
/// not affect clones made earlier.
pub struct ConnPool<C> {
    factory: Option<Arc<dyn ConnFactory<C>>>,
    state: Arc<Mutex<PoolInner<C>>>,
    notify: Arc<Notify>,
}

impl<C> Clone for ConnPool<C> {
    // Written by hand: `#[derive(Clone)]` would demand `C: Clone`, yet only
    // the reference counts are bumped here.
    fn clone(&self) -> Self {
        let factory = self.factory.as_ref().map(Arc::clone);
        let state = Arc::clone(&self.state);
        let notify = Arc::clone(&self.notify);
        ConnPool {
            factory,
            state,
            notify,
        }
    }
}
impl<C: Send + 'static> ConnPool<C> {
#[must_use]
pub fn new(cfg: ConnPoolConfig) -> Self {
let auto_eject = AutoEject::new(
cfg.auto_eject,
cfg.server_failure_limit.max(1),
Duration::from_millis(cfg.server_retry_timeout_ms),
);
let max_backoff = Duration::from_millis(cfg.server_retry_timeout_ms.max(1_000));
Self {
factory: None,
state: Arc::new(Mutex::new(PoolInner {
cfg,
idle: VecDeque::new(),
in_flight: 0,
auto_eject,
backoff: Backoff::new(max_backoff),
shutdown: false,
})),
notify: Arc::new(Notify::new()),
}
}
pub fn with_factory<F>(cfg: ConnPoolConfig, factory: F) -> Self
where
F: ConnFactory<C>,
{
let mut pool = Self::new(cfg);
pool.factory = Some(Arc::new(factory));
pool
}
pub fn set_factory<F: ConnFactory<C>>(&mut self, factory: F) {
self.factory = Some(Arc::new(factory));
}
#[must_use]
pub fn config(&self) -> ConnPoolConfig {
self.state.lock().cfg.clone()
}
#[must_use]
pub fn idle_count(&self) -> usize {
self.state.lock().idle.len()
}
#[must_use]
pub fn in_flight(&self) -> usize {
self.state.lock().in_flight
}
#[must_use]
pub fn is_ejected(&self, now: Instant) -> bool {
let mut g = self.state.lock();
g.auto_eject.record_attempt(now) == AutoEjectState::Ejected
}
#[must_use]
pub fn auto_eject(&self) -> AutoEject {
self.state.lock().auto_eject.clone()
}
pub fn shutdown(&self) {
{
let mut g = self.state.lock();
g.shutdown = true;
g.idle.clear();
}
self.notify.notify_waiters();
}
pub async fn get(&self) -> Result<ConnHandle<C>, NetError> {
loop {
let waiter = {
let mut g = self.state.lock();
if g.shutdown {
return Err(NetError::PoolShutdown);
}
if let Some(conn) = g.idle.pop_front() {
g.in_flight += 1;
return Ok(ConnHandle {
pool: self.clone(),
inner: Some(conn),
});
}
if g.in_flight + g.idle.len() >= g.cfg.max_connections {
true
} else {
let now = Instant::now();
if g.auto_eject.record_attempt(now) == AutoEjectState::Ejected {
return Err(NetError::Ejected);
}
false
}
};
if waiter {
self.notify.notified().await;
continue;
}
let factory = self
.factory
.as_ref()
.ok_or(NetError::PoolExhausted)?
.clone();
match factory.connect().await {
Ok(conn) => {
let mut g = self.state.lock();
g.in_flight += 1;
g.auto_eject.record_success(Instant::now());
g.backoff.record_success();
return Ok(ConnHandle {
pool: self.clone(),
inner: Some(conn),
});
}
Err(err) => {
let ejected;
{
let mut g = self.state.lock();
let now = Instant::now();
ejected = g.auto_eject.record_failure(now) == AutoEjectState::Ejected;
let _ = g.backoff.record_failure();
}
if ejected {
return Err(NetError::Ejected);
}
return Err(err);
}
}
}
}
fn return_conn(&self, conn: C) {
let mut g = self.state.lock();
if g.in_flight > 0 {
g.in_flight -= 1;
}
if !g.shutdown && g.idle.len() + g.in_flight < g.cfg.max_connections {
g.idle.push_back(conn);
}
drop(g);
self.notify.notify_one();
}
fn drop_conn(&self) {
let mut g = self.state.lock();
if g.in_flight > 0 {
g.in_flight -= 1;
}
drop(g);
self.notify.notify_one();
}
}
impl<C> std::fmt::Debug for ConnPool<C> {
    /// Summarises pool state without printing any connection, so `C` does
    /// not need to implement `Debug`.
    ///
    /// Takes the state lock (a non-reentrant `parking_lot` mutex). Do not
    /// format a pool while already holding that lock.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let g = self.state.lock();
        f.debug_struct("ConnPool")
            .field("cfg", &g.cfg)
            .field("idle", &g.idle.len())
            .field("in_flight", &g.in_flight)
            .field("shutdown", &g.shutdown)
            .field("auto_eject_failures", &g.auto_eject.failure_count())
            .field("factory_installed", &self.factory.is_some())
            .field("notify", &"<tokio::sync::Notify>")
            .finish()
    }
}
/// Guard for a connection checked out of a [`ConnPool`].
///
/// Dropping the handle, or calling [`ConnHandle::release`], gives the
/// connection back to the pool. [`ConnHandle::discard`] drops it instead.
pub struct ConnHandle<C: Send + 'static> {
    pool: ConnPool<C>,
    inner: Option<C>,
}

impl<C: Send + 'static> std::fmt::Debug for ConnHandle<C> {
    /// Shows only whether the handle still owns its connection; `C` itself
    /// is never formatted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let alive = self.inner.is_some();
        let mut out = f.debug_struct("ConnHandle");
        out.field("alive", &alive);
        out.finish()
    }
}
impl<C: Send + 'static> ConnHandle<C> {
    /// Borrows the pooled connection.
    ///
    /// # Panics
    ///
    /// Only if the handle no longer owns a connection. That cannot happen
    /// through this API: the connection is taken out only by consuming
    /// methods and by `Drop`.
    pub fn get(&self) -> &C {
        Option::as_ref(&self.inner).expect("invariant: handle is alive")
    }

    /// Mutably borrows the pooled connection.
    ///
    /// # Panics
    ///
    /// Under the same (unreachable) condition as [`ConnHandle::get`].
    pub fn get_mut(&mut self) -> &mut C {
        Option::as_mut(&mut self.inner).expect("invariant: handle is alive")
    }

    /// Gives the connection back to the pool right away.
    ///
    /// This does exactly what dropping the handle does; it only makes the
    /// hand-back explicit at the call site.
    pub fn release(self) {
        drop(self);
    }

    /// Closes the connection instead of pooling it and frees its slot.
    pub fn discard(mut self) {
        // Drop the connection first, then tell the pool the slot is free.
        // `Drop` later finds `inner == None` and does nothing.
        drop(self.inner.take());
        self.pool.drop_conn();
    }
}
impl<C: Send + 'static> Drop for ConnHandle<C> {
    /// Returns a still-owned connection to the pool. After `discard`, the
    /// handle owns nothing and there is nothing to return.
    fn drop(&mut self) {
        let Some(conn) = self.inner.take() else {
            return;
        };
        self.pool.return_conn(conn);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Released connections are reused instead of opening new ones.
    #[tokio::test]
    async fn round_trip_basic() {
        // The factory hands out increasing ids (0, 1, ...), so reuse is observable.
        let counter = Arc::new(AtomicUsize::new(0));
        let c2 = Arc::clone(&counter);
        let pool: ConnPool<usize> = ConnPool::with_factory(
            ConnPoolConfig {
                max_connections: 2,
                ..ConnPoolConfig::default()
            },
            move || {
                let c = Arc::clone(&c2);
                async move {
                    let id = c.fetch_add(1, Ordering::Relaxed);
                    Ok::<usize, NetError>(id)
                }
            },
        );
        // Two simultaneous checkouts need two distinct connections.
        let h1 = pool.get().await.unwrap();
        let h2 = pool.get().await.unwrap();
        assert_ne!(h1.get(), h2.get());
        h1.release();
        // `h1`'s connection (id 0) is idle now and is handed out again.
        let h3 = pool.get().await.unwrap();
        assert_eq!(*h3.get(), 0);
        h3.release();
        h2.release();
    }

    /// At `max_connections`, a further `get` waits until a handle is dropped.
    #[tokio::test]
    async fn max_connections_blocks_until_release() {
        let pool: ConnPool<u32> = ConnPool::with_factory(
            ConnPoolConfig {
                max_connections: 1,
                ..ConnPoolConfig::default()
            },
            || async { Ok::<u32, NetError>(7) },
        );
        let h = pool.get().await.unwrap();
        let pool2 = pool.clone();
        let waiter = tokio::spawn(async move {
            let h2 = pool2.get().await.unwrap();
            assert_eq!(*h2.get(), 7);
        });
        // Let the spawned task run far enough to park on the full pool.
        tokio::task::yield_now().await;
        assert!(!waiter.is_finished());
        // Dropping `h` returns its connection and wakes the waiter.
        drop(h);
        waiter.await.unwrap();
    }

    /// With `server_failure_limit: 2`, the first failure surfaces the
    /// factory's error and the second trips auto-eject.
    #[tokio::test]
    async fn auto_eject_after_consecutive_failures() {
        let pool: ConnPool<u8> = ConnPool::with_factory(
            ConnPoolConfig {
                max_connections: 1,
                server_failure_limit: 2,
                server_retry_timeout_ms: 50,
                auto_eject: true,
            },
            || async {
                Err::<u8, NetError>(NetError::Io(std::io::Error::new(
                    std::io::ErrorKind::ConnectionRefused,
                    "test",
                )))
            },
        );
        match pool.get().await {
            Err(NetError::Io(_)) => {}
            other => panic!("expected io error, got {other:?}"),
        }
        match pool.get().await {
            Err(NetError::Ejected) => {}
            other => panic!("expected eject, got {other:?}"),
        }
        // Issued immediately, presumably still inside the 50 ms retry
        // timeout, so the server is expected to remain ejected.
        match pool.get().await {
            Err(NetError::Ejected) => {}
            other => panic!("expected eject, got {other:?}"),
        }
    }

    /// `shutdown` wakes a `get` that is parked waiting for capacity.
    #[tokio::test]
    async fn shutdown_unblocks_waiters() {
        let pool: ConnPool<u32> = ConnPool::with_factory(
            ConnPoolConfig {
                max_connections: 1,
                ..ConnPoolConfig::default()
            },
            || async { Ok::<u32, NetError>(1) },
        );
        // Hold the only slot so the spawned `get` has to wait.
        let _h = pool.get().await.unwrap();
        let pool2 = pool.clone();
        let w = tokio::spawn(async move { pool2.get().await });
        tokio::task::yield_now().await;
        pool.shutdown();
        assert!(matches!(w.await.unwrap(), Err(NetError::PoolShutdown)));
    }
}