#![deny(missing_docs)]
mod config;
mod conn;
mod error;
mod inner;
mod manage_connection;
mod queue;
use futures::channel::oneshot;
use futures::future::{self};
use futures::prelude::*;
use futures::stream;
use std::fmt;
use std::sync::Arc;
use std::time::Duration;
use tokio::time;
use tracing::{debug, debug_span, error, Instrument};
use crate::error::InternalError;
pub use crate::config::Config;
pub use conn::Conn;
pub use error::Error;
pub use manage_connection::ManageConnection;
use inner::ConnectionPool;
use queue::{Live, Queue};
/// A pool of connections that are opened, validated and recycled through a
/// [`ManageConnection`] implementation.
///
/// `Pool` is cheap to clone: every clone shares the same underlying pool state
/// and configuration through `Arc`s, so clones can be handed to other tasks.
pub struct Pool<C: ManageConnection + Send> {
    /// Shared pool state: connection counts, the idle queue and the waiter queue.
    conn_pool: Arc<ConnectionPool<C>>,
    /// Configuration the pool was created with.
    config: Arc<Config>,
}
impl<C> fmt::Debug for Pool<C>
where
    C: ManageConnection + Send + fmt::Debug,
{
    /// Formats the pool as `Pool { conn_pool: .. }`; the configuration is not
    /// included in the output.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut out = f.debug_struct("Pool");
        out.field("conn_pool", &self.conn_pool);
        out.finish()
    }
}
impl<C: ManageConnection> Pool<C> {
    /// Returns the configured minimum pool size (`Config::min_size`).
    ///
    /// [`Pool::new`] opens this many connections before returning.
    pub fn min_conns(&self) -> usize {
        self.config.min_size
    }

    /// Returns the configured maximum pool size (`Config::max_size`).
    pub fn max_conns(&self) -> usize {
        self.config.max_size
    }
}
impl<C> Clone for Pool<C>
where
C: ManageConnection,
{
fn clone(&self) -> Pool<C> {
Pool {
conn_pool: self.conn_pool.clone(),
config: self.config.clone(),
}
}
}
impl<C: ManageConnection + Send> Pool<C> {
    /// Creates a new pool backed by `manager`, concurrently opening
    /// `config.min_size` connections before returning.
    ///
    /// # Errors
    ///
    /// Returns the error of the first initial connection attempt that fails;
    /// connections opened before that point are dropped.
    ///
    /// # Panics
    ///
    /// Panics if `config.max_size` is smaller than `config.min_size`.
    pub async fn new(manager: C, config: Config) -> Result<Pool<C>, Error<C::Error>> {
        assert!(
            config.max_size >= config.min_size,
            "max_size of pool must be greater than or equal to the min_size"
        );
        let conns: stream::FuturesUnordered<_> = std::iter::repeat(&manager)
            .take(config.min_size)
            .map(|c| c.connect())
            .collect();
        let conns = conns
            .try_fold(Queue::new(config.idle_queue_size), |conns, conn| {
                conns.new_conn(Live::new(conn));
                future::ok(conns)
            })
            .await?;
        let config = Arc::new(config);
        let conn_pool = Arc::new(ConnectionPool::new(conns, manager, Arc::clone(&config)));
        Ok(Pool { conn_pool, config })
    }

    /// Checks a connection out of the pool.
    ///
    /// The order of preference is:
    /// 1. reuse an idle connection;
    /// 2. otherwise open a new one if the pool is below its maximum size;
    /// 3. otherwise wait until another connection is returned.
    ///
    /// When `test_on_check_out` is enabled, connections are validated with
    /// `ManageConnection::is_valid` before being handed out. If a connect
    /// timeout is configured, the whole operation is bounded by it.
    ///
    /// Dropping the returned future (for example through an outer timeout) is
    /// safe: a reserved slot or a connection already sent to this caller is
    /// given back to the pool.
    ///
    /// # Errors
    ///
    /// - a timeout error (converted from Tokio's `Elapsed`) when the configured
    ///   connect timeout elapses;
    /// - any error returned by the manager while opening a new connection;
    /// - an internal error if `max_size` checked-out connections in a row fail
    ///   validation, or if the channel used to wait for a returned connection
    ///   is closed unexpectedly.
    pub async fn connection(&self) -> Result<Conn<C>, Error<C::Error>> {
        async {
            if let Some(timeout) = self.config.connect_timeout {
                tokio::time::timeout(timeout, self.connect_no_timeout()).await?
            } else {
                self.connect_no_timeout().await
            }
        }
        .instrument(debug_span!(
            "l337::Pool::connection",
            pool_type = std::any::type_name::<C>(),
        ))
        .await
    }

    /// Checkout logic without the overall timeout.
    ///
    /// With `test_on_check_out` disabled this is just `try_get_connection`.
    /// Otherwise up to `max_size` connections are tried in turn. Any that fail
    /// `is_valid` are discarded and replaced in the background.
    async fn connect_no_timeout(&self) -> Result<Conn<C>, Error<C::Error>> {
        if !self.config.test_on_check_out {
            return self.try_get_connection().await;
        }
        for _ in 0..self.conn_pool.max_size() {
            let mut connection = self.try_get_connection().await?;
            match self.conn_pool.is_valid(&mut connection).await {
                Ok(()) => return Ok(connection),
                Err(error) => {
                    debug!(
                        %error,
                        "connection: found connection in pool that is no longer valid - removing from pool",
                    );
                    // NOTE(review): `forget` is assumed to stop `Conn`'s drop from
                    // returning this connection to the pool — confirm in conn.rs.
                    connection.forget();
                    self.conn_pool.conns.decrement();
                    debug!(
                        "connection count is now: {:?}",
                        self.conn_pool.conns.total()
                    );
                    self.spawn_new_future_loop();
                    // Give the background replacement a chance to connect.
                    tokio::time::sleep(Duration::from_millis(100)).await;
                }
            }
        }
        Err(Error::Internal(InternalError::AllConnectionsInvalid))
    }

    /// Takes an idle connection, opens a new one if there is capacity, or
    /// otherwise waits for one to be returned via `put_back`.
    async fn try_get_connection(&self) -> Result<Conn<C>, Error<C::Error>> {
        if let Some(conn) = self.conn_pool.conns.get() {
            debug!("connection: connection already in pool and ready to go");
            return Ok(Conn::new(conn, self.clone()));
        }

        debug!("connection: try spawn connection");
        if let Some(conn) = self.try_spawn_connection().await {
            let conn = conn?;
            debug!("connection: spawned connection");
            return Ok(Conn::new(conn, self.clone()));
        }

        let (tx, rx) = oneshot::channel();
        debug!("connection: pushing to notify of connection");
        self.conn_pool.notify_of_connection(tx);
        // From here on, dropping `waiter` (early return, error, or this future
        // being cancelled) returns any connection already sent to us.
        let mut waiter = PendingWaiter { pool: self, rx };

        // A `put_back` may have stored a connection as idle after the check
        // above but before we registered; without this re-check we would wait
        // while that connection sits unused in the idle queue.
        if let Some(conn) = self.conn_pool.conns.get() {
            debug!("connection: connection became idle while registering to wait");
            return Ok(Conn::new(conn, self.clone()));
        }

        debug!("connection: waiting for connection");
        let conn = (&mut waiter.rx).await.map_err(|_| {
            Error::Internal(InternalError::Other(
                "Connection channel was closed unexpectedly".into(),
            ))
        })?;
        debug!("connection: got connection after waiting");
        Ok(Conn::new(conn, self.clone()))
    }

    /// Opens a new connection if the pool is below its maximum size.
    ///
    /// Returns `None` when no capacity is available. Otherwise returns the
    /// connect result. The reserved slot is released if connecting fails or
    /// this future is dropped before it finishes.
    async fn try_spawn_connection(&self) -> Option<Result<Live<C::Connection>, Error<C::Error>>> {
        let slot = SlotReservation::acquire(self)?;
        debug!("try_spawn_connection: starting connection");
        match self.conn_pool.connect().await {
            Ok(conn) => {
                slot.commit();
                Some(Ok(Live::new(conn)))
            }
            Err(err) => {
                drop(slot);
                Some(Err(err))
            }
        }
    }

    /// Returns a connection to the pool.
    ///
    /// - A connection reported broken by `ManageConnection::has_broken` is
    ///   discarded: the connection count is decremented and a background task
    ///   is spawned to open a replacement. That path calls `tokio::spawn`, so
    ///   it must run inside a Tokio runtime.
    /// - Otherwise the connection is handed to a waiting caller if one is still
    ///   listening.
    /// - Failing that, it is stored in the idle queue.
    pub fn put_back(&self, mut conn: Live<C::Connection>) {
        debug!("put_back: start put back");
        if self.conn_pool.has_broken(&mut conn) {
            self.conn_pool.conns.decrement();
            debug!(
                "connection count is now: {:?}",
                self.conn_pool.conns.total()
            );
            self.spawn_new_future_loop();
            return;
        }
        while let Some(waiting) = self.conn_pool.try_waiting() {
            debug!("put_back: got a waiting connection, sending");
            match waiting.send(conn) {
                Ok(_) => return,
                // The waiter gave up (its receiver is closed); try the next one.
                Err(returned) => {
                    debug!("put_back: unable to send connection");
                    conn = returned;
                }
            }
        }
        debug!("put_back: no waiting connection, storing");
        // NOTE(review): whether a failed `store` also adjusts the total count is
        // not visible here — confirm in queue.rs.
        if self.conn_pool.conns.store(conn).is_err() {
            debug!("put_back: hit the idle connection queue limit");
        }
    }

    /// Spawns a Tokio task that opens a replacement connection and returns it
    /// to the pool, retrying every second until it succeeds.
    ///
    /// The slot is reserved before connecting so the replacement can never
    /// push the count above `max_size`. If the pool has already been refilled
    /// by someone else, the task exits without connecting.
    fn spawn_new_future_loop(&self) {
        let this = self.clone();
        tokio::spawn(async move {
            loop {
                let slot = match SlotReservation::acquire(&this) {
                    Some(slot) => slot,
                    None => {
                        debug!("spawn loop: pool already at max size, no replacement needed");
                        break;
                    }
                };
                match this.conn_pool.connect().await {
                    Ok(conn) => {
                        debug!("creating new connection from spawn loop");
                        slot.commit();
                        this.put_back(Live::new(conn));
                        break;
                    }
                    Err(err) => {
                        // Free the slot while backing off so callers may use it.
                        drop(slot);
                        error!(
                            "unable to establish new connection, trying again: {:?}",
                            err
                        );
                        time::sleep(Duration::from_secs(1)).await;
                    }
                }
            }
        });
    }

    /// Returns the pool's current connection count.
    ///
    /// This includes checked-out connections and slots reserved for
    /// connections that are still being opened.
    pub fn total_conns(&self) -> usize {
        self.conn_pool.conns.total()
    }

    /// Returns the number of connections currently idle in the pool.
    pub fn idle_conns(&self) -> usize {
        self.conn_pool.conns.idle()
    }

    /// Returns how many times storing a returned connection in the idle queue
    /// failed (logged by `put_back` as hitting the idle queue limit).
    pub fn idle_conns_push_error(&self) -> usize {
        self.conn_pool.conns.idle_push_error_count()
    }

    /// Returns the number of entries in the waiter queue.
    ///
    /// Entries left by callers that have since given up are still counted
    /// until a later `put_back` discards them.
    pub fn waiters(&self) -> usize {
        self.conn_pool.waiting.len()
    }
}

/// One reserved slot in the pool's connection count.
///
/// The slot is released on drop unless `commit` hands it over to a live
/// connection.
struct SlotReservation<'a, C: ManageConnection + Send> {
    pool: &'a Pool<C>,
    committed: bool,
}

impl<'a, C: ManageConnection + Send> SlotReservation<'a, C> {
    /// Reserves a slot, or returns `None` if the pool is already at `max_size`.
    fn acquire(pool: &'a Pool<C>) -> Option<Self> {
        pool.conn_pool
            .conns
            .safe_increment(pool.conn_pool.max_size())?;
        Some(SlotReservation {
            pool,
            committed: false,
        })
    }

    /// Keeps the slot: a live connection now accounts for it.
    fn commit(mut self) {
        self.committed = true;
    }
}

impl<C: ManageConnection + Send> Drop for SlotReservation<'_, C> {
    fn drop(&mut self) {
        if !self.committed {
            self.pool.conn_pool.conns.decrement();
        }
    }
}

/// Receiving half of a waiter registration.
///
/// On drop it closes the channel and returns any connection that was already
/// sent into it to the pool, so a cancelled or early-returning caller cannot
/// silently drop a counted connection.
struct PendingWaiter<'a, C: ManageConnection + Send> {
    pool: &'a Pool<C>,
    rx: oneshot::Receiver<Live<C::Connection>>,
}

impl<C: ManageConnection + Send> Drop for PendingWaiter<'_, C> {
    fn drop(&mut self) {
        // After `close`, every later `send` fails and hands the connection back
        // to the sender. A value sent before `close` is still retrievable here.
        self.rx.close();
        if let Ok(Some(conn)) = self.rx.try_recv() {
            self.pool.put_back(conn);
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use async_trait::async_trait;
use std::sync::{atomic::*, Arc};
use std::time::Duration;
// `thiserror::Error` is a derive macro (macro namespace). As a type, `Error`
// still resolves to the crate's `Error` brought in by `super::*`.
use thiserror::Error;
use tokio::time::timeout;
/// Test manager whose connections get sequential ids starting at 0.
#[derive(Debug)]
pub struct DummyManager {
last_id: AtomicUsize,
}
/// Fake connection.
///
/// - `valid` is shared via `Arc` so a test can invalidate it after returning
///   it to the pool.
/// - `broken` is read by `has_broken`.
#[derive(Debug)]
pub struct DummyValue {
id: usize,
valid: Arc<AtomicBool>,
broken: bool,
}
#[derive(Debug, Error)]
#[error("DummyError")]
pub struct DummyError;
impl DummyManager {
pub fn new() -> Self {
Self {
last_id: AtomicUsize::new(0),
}
}
}
#[async_trait]
impl ManageConnection for DummyManager {
type Connection = DummyValue;
type Error = DummyError;
// Always succeeds, handing out the next id; never awaits anything.
async fn connect(&self) -> Result<Self::Connection, Error<Self::Error>> {
Ok(DummyValue {
id: self.last_id.fetch_add(1, Ordering::SeqCst),
valid: Arc::new(AtomicBool::new(true)),
broken: false,
})
}
async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Error<Self::Error>> {
if conn.valid.load(Ordering::SeqCst) {
Ok(())
} else {
Err(Error::External(DummyError))
}
}
fn has_broken(&self, conn: &mut Self::Connection) -> bool {
conn.broken
}
// Panics if called; the pool code in lib.rs never calls it.
fn timed_out(&self) -> Error<Self::Error> {
unimplemented!()
}
}
#[tokio::test]
async fn test_min_max_conns() {
let mngr = DummyManager::new();
let config = Config::new().min_size(1).max_size(2);
let pool = Arc::new(Pool::new(mngr, config).await.unwrap());
assert_eq!(pool.min_conns(), 1);
assert_eq!(pool.max_conns(), 2);
}
#[tokio::test]
async fn simple_pool_creation_and_connection() {
let mngr = DummyManager::new();
let config: Config = Default::default();
let pool = Pool::new(mngr, config).await.unwrap();
let conn = pool.connection().await.unwrap();
assert!(conn.conn.is_some(), "connection is not correct type");
}
// With max_size 1 and the only connection leaked via `mem::forget`, a second
// checkout must stay pending (here: hit the 10ms outer timeout).
#[tokio::test]
async fn it_returns_a_non_resolved_future_when_over_pool_limit() {
let mngr = DummyManager::new();
let config: Config = Config::new().min_size(1).max_size(1);
let pool = Pool::new(mngr, config).await.unwrap();
::std::mem::forget(pool.connection().await.unwrap());
let result = tokio::time::timeout(Duration::from_millis(10), pool.connection()).await;
assert!(result.is_err(), "didn't timeout");
}
// Exercises the pool's own connect timeout (not an outer tokio timeout).
// Assumes the default `min_size` is <= 1, otherwise `Pool::new` would panic —
// NOTE(review): confirm against Config's defaults.
#[tokio::test]
async fn it_times_out_when_no_connections_available() {
let mngr = DummyManager::new();
let config = Config::new()
.max_size(1)
.connection_timeout(Duration::from_millis(100));
let pool = Pool::new(mngr, config).await.unwrap();
let conn1 = pool.connection().await.unwrap();
let result = pool.connection().await;
match result {
Err(Error::Internal(InternalError::TimedOut)) => {}
_ => panic!("connection should timeout"),
}
drop(conn1);
}
// Relies on `futures::join!` polling `f1` before `f2`.
// - `f1` takes the second (last) slot.
// - `f2` must then wait and time out.
#[tokio::test]
async fn it_allocates_new_connections_up_to_max_size() {
let mngr = DummyManager::new();
let config: Config = Config::new().min_size(1).max_size(2);
let pool = Pool::new(mngr, config).await.unwrap();
let connection = pool.connection().await.unwrap();
::std::mem::forget(connection);
let f1 = async {
let conn = tokio::time::timeout(Duration::from_millis(10), pool.connection())
.await
.expect("second connection timed out");
::std::mem::forget(conn);
};
let f2 = async {
let result = tokio::time::timeout(Duration::from_millis(10), pool.connection()).await;
assert!(result.is_err(), "third didn't timeout");
};
futures::join!(f1, f2);
}
// Flips the shared `valid` flag after the connection is back in the pool.
// The next checkout must not return it. This relies on `test_on_check_out`
// being enabled by default — NOTE(review): confirm against Config's defaults.
#[tokio::test]
async fn it_does_not_return_connections_that_are_invalid() {
let mngr = DummyManager::new();
let config: Config = Config::new().max_size(2).min_size(1);
let pool = Pool::new(mngr, config).await.unwrap();
let conn1 = pool.connection().await.unwrap();
let conn1_id = conn1.id;
let conn1_valid = Arc::clone(&conn1.valid);
drop(conn1);
let conn1 = pool.connection().await.unwrap();
assert_eq!(conn1.id, conn1_id);
drop(conn1);
conn1_valid.store(false, Ordering::SeqCst);
let conn2 = pool.connection().await.unwrap();
assert_ne!(
conn2.id, conn1_id,
"Conn1 was returned from the pool even though it is marked as invalid"
);
}
// Same scenario with validation disabled: the invalid connection is reused.
#[tokio::test]
async fn it_does_return_connections_that_are_invalid_if_so_configured() {
let mngr = DummyManager::new();
let config: Config = Config::new()
.max_size(2)
.min_size(1)
.test_on_check_out(false);
let pool = Pool::new(mngr, config).await.unwrap();
let conn1 = pool.connection().await.unwrap();
let conn1_id = conn1.id;
let conn1_valid = Arc::clone(&conn1.valid);
drop(conn1);
let conn1 = pool.connection().await.unwrap();
assert_eq!(conn1.id, conn1_id);
drop(conn1);
conn1_valid.store(false, Ordering::SeqCst);
let conn1 = pool.connection().await.unwrap();
assert_eq!(
conn1.id, conn1_id,
"Conn1 was not returned from the pool even though it is marked as invalid, and the pool should not check validity"
);
}
// A connection marked broken while checked out is discarded by `put_back`
// (via `has_broken`) instead of being stored for reuse.
#[tokio::test]
async fn it_does_not_return_connections_that_are_broken() {
let mngr = DummyManager::new();
let config: Config = Config::new().max_size(2).min_size(1);
let pool = Pool::new(mngr, config).await.unwrap();
let conn1 = pool.connection().await.unwrap();
let conn1_id = conn1.id;
drop(conn1);
let mut conn1 = pool.connection().await.unwrap();
assert_eq!(conn1.id, conn1_id);
conn1.broken = true;
drop(conn1);
let conn2 = pool.connection().await.unwrap();
assert_ne!(
conn2.id, conn1_id,
"Conn1 was returned from the pool even though it is marked as broken"
);
}
// Two tasks share 100 checkouts between them. Afterwards both connections
// must be accounted for and idle. (The test name has a typo: "mutliple".)
#[tokio::test]
async fn test_can_be_accessed_by_mutliple_futures_concurrently() {
let mngr = DummyManager::new();
let config = Config::new().min_size(2).max_size(2);
let pool = Arc::new(Pool::new(mngr, config).await.unwrap());
let count = Arc::new(AtomicUsize::new(0));
futures::join!(
loop_run(Arc::clone(&count), Arc::clone(&pool)),
loop_run(Arc::clone(&count), Arc::clone(&pool))
);
assert_eq!(pool.total_conns(), 2);
assert_eq!(pool.idle_conns(), 2);
}
/// Spawns a task that repeatedly checks out a connection until the shared
/// counter reaches 100.
///
/// Each connection is a temporary dropped at the end of its statement. It is
/// presumably returned to the pool by `Conn`'s drop (not visible here).
async fn loop_run(count: Arc<AtomicUsize>, pool: Arc<Pool<DummyManager>>) {
tokio::spawn(async move {
loop {
timeout(Duration::from_secs(5), pool.connection())
.await
.expect("connection timed out")
.expect("error getting connection");
let old_count = count.fetch_add(1, Ordering::SeqCst);
if old_count + 1 >= 100 {
break;
}
}
})
.await
.unwrap();
}
}