use async_trait::async_trait;
use dashmap::DashMap;
use pravega_client_shared::PravegaNodeUri;
use snafu::Snafu;
use std::fmt;
use std::ops::{Deref, DerefMut};
use uuid::Uuid;
/// Errors produced by the connection pool and returned by [`Manager`] implementations.
#[derive(Debug, Snafu)]
#[snafu(visibility = "pub(crate)")]
pub enum ConnectionPoolError {
/// A connection to `endpoint` could not be established.
///
/// `error_msg` presumably carries the underlying cause (TODO confirm against the
/// `Manager` implementations); note it is not part of the `Display` output.
#[snafu(display("Could not establish connection to endpoint: {}", endpoint))]
EstablishConnection { endpoint: String, error_msg: String },
/// The internal pool holds no idle connection for the requested endpoint.
///
/// [`ConnectionPool::get_connection`] treats this as a signal to establish a new
/// connection rather than surfacing it to its caller.
#[snafu(display("No available connection in the internal pool"))]
NoAvailableConnection {},
}
/// Creates and validates the connections held by a [`ConnectionPool`].
#[async_trait]
pub trait Manager {
/// The connection type stored in and handed out by the pool.
type Conn: Send + Sized;
/// Opens a new connection to `endpoint`.
///
/// Called by [`ConnectionPool::get_connection`] when no reusable idle connection exists
/// for the endpoint; an error is propagated to that caller unchanged.
async fn establish_connection(&self, endpoint: PravegaNodeUri)
-> Result<Self::Conn, ConnectionPoolError>;
/// Returns whether a previously pooled connection may be handed out again.
///
/// Connections for which this returns `false` are discarded by
/// [`ConnectionPool::get_connection`].
fn is_valid(&self, conn: &Self::Conn) -> bool;
/// Maximum number of idle connections the pool keeps for a single endpoint.
///
/// Connections returned to the pool beyond this limit are dropped instead of stored.
fn get_max_connections(&self) -> u32;
/// Name of this manager, shown in the pool's `Debug` output.
fn name(&self) -> String;
}
/// A pool of reusable connections, grouped by endpoint.
///
/// Connections are created and validated by the supplied [`Manager`]. Idle connections are
/// stored per endpoint and handed out again by [`ConnectionPool::get_connection`].
pub struct ConnectionPool<M>
where
M: Manager,
{
/// Creates and validates connections.
manager: M,
/// Idle connections, keyed by endpoint.
managed_pool: ManagedPool<M::Conn>,
}
impl<M> ConnectionPool<M>
where
M: Manager,
{
pub fn new(manager: M) -> Self {
let managed_pool = ManagedPool::new(manager.get_max_connections());
ConnectionPool {
manager,
managed_pool,
}
}
pub async fn get_connection(
&self,
endpoint: PravegaNodeUri,
) -> Result<PooledConnection<'_, M::Conn>, ConnectionPoolError> {
loop {
match self.managed_pool.get_connection(endpoint.clone()) {
Ok(internal_conn) => {
let conn = internal_conn.conn;
if self.manager.is_valid(&conn) {
return Ok(PooledConnection {
uuid: internal_conn.uuid,
inner: Some(conn),
endpoint,
pool: &self.managed_pool,
});
}
}
Err(_e) => {
let conn = self.manager.establish_connection(endpoint.clone()).await?;
return Ok(PooledConnection {
uuid: Uuid::new_v4(),
inner: Some(conn),
endpoint,
pool: &self.managed_pool,
});
}
}
}
}
pub fn pool_len(&self, endpoint: &PravegaNodeUri) -> usize {
self.managed_pool.pool_len(endpoint)
}
}
impl<M: Manager> fmt::Debug for ConnectionPool<M> {
    /// Shows the manager's name followed by the state of the managed pool.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("ConnectionPool");
        let manager_name = self.manager.name();
        builder.field("managed pool name", &manager_name);
        builder.field("managed pool", &self.managed_pool);
        builder.finish()
    }
}
/// Storage for idle connections, one [`InternalPool`] per endpoint.
struct ManagedPool<T: Sized + Send> {
/// Idle connections keyed by endpoint.
map: DashMap<PravegaNodeUri, InternalPool<T>>,
/// Upper bound on idle connections kept for a single endpoint.
max_connections: u32,
}
impl<T: Sized + Send> ManagedPool<T> {
    /// Creates an empty pool that keeps at most `max_connections` idle connections per
    /// endpoint.
    pub fn new(max_connections: u32) -> Self {
        ManagedPool {
            map: DashMap::new(),
            max_connections,
        }
    }

    /// Stores `connection` as idle for `endpoint`.
    ///
    /// If the endpoint already holds `max_connections` idle connections, `connection` is
    /// dropped instead of being stored.
    fn add_connection(&self, endpoint: PravegaNodeUri, connection: InternalConn<T>) {
        let mut internal = self.map.entry(endpoint).or_insert_with(InternalPool::new);
        // Compare as `usize`: widening the `u32` limit is lossless on 32/64-bit targets,
        // whereas narrowing `len()` to `u32` could silently truncate.
        if internal.conns.len() < self.max_connections as usize {
            internal.conns.push(connection);
        }
    }

    /// Removes and returns the most recently stored idle connection for `endpoint`.
    ///
    /// This is a plain lookup. It does not insert an empty per-endpoint pool for endpoints
    /// that have never held an idle connection.
    ///
    /// # Errors
    /// Returns [`ConnectionPoolError::NoAvailableConnection`] when no idle connection is
    /// stored for `endpoint`.
    fn get_connection(
        &self,
        endpoint: PravegaNodeUri,
    ) -> Result<InternalConn<T>, ConnectionPoolError> {
        self.map
            .get_mut(&endpoint)
            .and_then(|mut internal| internal.conns.pop())
            .ok_or(ConnectionPoolError::NoAvailableConnection {})
    }

    /// Number of idle connections stored for `endpoint` (0 if none have ever been stored).
    fn pool_len(&self, endpoint: &PravegaNodeUri) -> usize {
        self.map.get(endpoint).map_or(0, |pool| pool.conns.len())
    }
}
impl<T: Sized + Send> fmt::Debug for ManagedPool<T> {
    /// Shows every per-endpoint pool together with the per-endpoint idle limit.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("ManagedPool");
        out.field("internal map", &self.map);
        out.field("max connection", &self.max_connections);
        out.finish()
    }
}
/// A pooled connection together with the identifier it keeps across check-outs.
struct InternalConn<T> {
/// Identifier assigned when the connection was first established; carried unchanged each
/// time the connection goes back into the pool and is taken out again.
uuid: Uuid,
/// The underlying connection.
conn: T,
}
/// Idle connections for a single endpoint.
struct InternalPool<T> {
/// Used as a stack: returned connections are pushed and check-outs pop the most recent one.
conns: Vec<InternalConn<T>>,
}
impl<T: Send + Sized> InternalPool<T> {
    /// Creates a pool with no idle connections.
    fn new() -> Self {
        let conns = Vec::new();
        InternalPool { conns }
    }
}
impl<T> fmt::Debug for InternalPool<T> {
    /// Reports only how many idle connections are held; `T` is not required to be `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let idle = self.conns.len();
        f.debug_struct("InternalPool").field("pool size", &idle).finish()
    }
}
/// A connection checked out from a [`ConnectionPool`].
///
/// Dereferences to the underlying connection and, when dropped, gives it back to the pool
/// it was taken from.
pub struct PooledConnection<'a, T: Send + Sized> {
/// Identifier of the connection, carried back into the pool on drop.
uuid: Uuid,
/// Endpoint under which the connection is returned to the pool.
endpoint: PravegaNodeUri,
/// Constructed as `Some` by `ConnectionPool::get_connection`; taken in `Drop`.
inner: Option<T>,
/// Pool that receives the connection on drop.
pool: &'a ManagedPool<T>,
}
impl<T: Send + Sized> fmt::Debug for PooledConnection<'_, T> {
    /// Formats the connection as its identifier only.
    ///
    /// Delegates directly to the identifier's `Debug` impl, so flags such as width
    /// are honoured.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let id = &self.uuid;
        fmt::Debug::fmt(id, f)
    }
}
impl<T: Send + Sized> Drop for PooledConnection<'_, T> {
    /// Gives the wrapped connection back to the pool it was checked out from.
    ///
    /// Within this module `inner` is always constructed as `Some`, and it is only taken
    /// here. The `if let` is therefore defensive: panicking inside `drop` aborts the process
    /// if the drop runs during unwinding, so a missing connection is simply skipped.
    fn drop(&mut self) {
        if let Some(conn) = self.inner.take() {
            self.pool.add_connection(
                self.endpoint.clone(),
                InternalConn {
                    uuid: self.uuid,
                    conn,
                },
            );
        }
    }
}
impl<T: Send + Sized> Deref for PooledConnection<'_, T> {
    type Target = T;

    /// Borrows the underlying connection.
    ///
    /// # Panics
    /// Panics if `inner` is `None`.
    fn deref(&self) -> &T {
        match self.inner {
            Some(ref conn) => conn,
            None => panic!("borrow inner connection"),
        }
    }
}
impl<T: Send + Sized> DerefMut for PooledConnection<'_, T> {
    /// Mutably borrows the underlying connection.
    ///
    /// # Panics
    /// Panics if `inner` is `None`.
    fn deref_mut(&mut self) -> &mut T {
        match self.inner {
            Some(ref mut conn) => conn,
            None => panic!("mutably borrow inner connection"),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use tokio::time::Duration;

    /// Stand-in connection type with no state.
    struct FooConnection {}

    /// Test manager: connecting always succeeds and every connection is considered valid.
    struct FooManager {
        max_connections_in_pool: u32,
    }

    #[async_trait]
    impl Manager for FooManager {
        type Conn = FooConnection;

        async fn establish_connection(
            &self,
            _endpoint: PravegaNodeUri,
        ) -> Result<Self::Conn, ConnectionPoolError> {
            Ok(FooConnection {})
        }

        fn is_valid(&self, _conn: &Self::Conn) -> bool {
            true
        }

        fn get_max_connections(&self) -> u32 {
            self.max_connections_in_pool
        }

        fn name(&self) -> String {
            String::from("foo")
        }
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_connection_pool_basic() {
        let pool = ConnectionPool::new(FooManager {
            max_connections_in_pool: 2,
        });
        let endpoint = PravegaNodeUri::from("127.0.0.1:1000".to_string());

        // Nothing is pooled before the first request.
        assert_eq!(pool.pool_len(&endpoint), 0);

        let connection = pool
            .get_connection(endpoint.clone())
            .await
            .expect("get connection");
        // A checked-out connection is not counted as idle.
        assert_eq!(pool.pool_len(&endpoint), 0);

        // Dropping the guard puts the connection back.
        drop(connection);
        assert_eq!(pool.pool_len(&endpoint), 1);
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_connection_pool_size() {
        const MAX_CONNECTION: u32 = 2;
        let pool = Arc::new(ConnectionPool::new(FooManager {
            max_connections_in_pool: MAX_CONNECTION,
        }));
        let endpoint = PravegaNodeUri::from("127.0.0.1:1234".to_string());

        // Ten concurrent borrowers each hold a connection for a while. This forces the pool
        // to create more connections than it is allowed to keep idle.
        let handles: Vec<_> = (0..10)
            .map(|_| {
                let task_pool = Arc::clone(&pool);
                let task_endpoint = endpoint.clone();
                tokio::spawn(async move {
                    let _connection = task_pool
                        .get_connection(task_endpoint)
                        .await
                        .expect("get connection");
                    tokio::time::sleep(Duration::from_millis(500)).await;
                })
            })
            .collect();

        // Await the tasks in reverse spawn order.
        for handle in handles.into_iter().rev() {
            handle.await.expect("handle should work");
        }

        // Only MAX_CONNECTION of the returned connections are retained.
        assert_eq!(pool.pool_len(&endpoint) as u32, MAX_CONNECTION);
    }
}