use crate::{
config::Config,
connection::{Builder, Connection},
enums::Mode,
error::Error,
};
use std::{
collections::VecDeque,
ops::{Deref, DerefMut},
sync::{Arc, Mutex},
};
#[derive(Debug)]
pub struct PooledConnection<T: Pool = Arc<dyn Pool>> {
pool: T,
conn: Option<Connection>,
}
impl<T: Pool> AsRef<Connection> for PooledConnection<T> {
fn as_ref(&self) -> &Connection {
self.conn.as_ref().unwrap()
}
}
impl<T: Pool> AsMut<Connection> for PooledConnection<T> {
fn as_mut(&mut self) -> &mut Connection {
self.conn.as_mut().unwrap()
}
}
impl<T: Pool> Drop for PooledConnection<T> {
fn drop(&mut self) {
if let Some(conn) = self.conn.take() {
self.pool.give(conn);
}
}
}
impl<T: Pool> Deref for PooledConnection<T> {
type Target = Connection;
fn deref(&self) -> &Self::Target {
self.conn.as_ref().unwrap()
}
}
impl<T: Pool> DerefMut for PooledConnection<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.conn.as_mut().unwrap()
}
}
impl<T: Pool + Clone> PooledConnection<T> {
pub fn new(pool: &T) -> Result<PooledConnection<T>, Error> {
pool.take().map(|conn| {
let conn = Some(conn);
let pool = pool.clone();
Self { pool, conn }
})
}
}
pub trait Pool {
fn mode(&self) -> Mode;
fn take(&self) -> Result<Connection, Error>;
fn give(&self, conn: Connection);
}
impl Pool for Arc<dyn Pool> {
fn mode(&self) -> Mode {
self.as_ref().mode()
}
fn take(&self) -> Result<Connection, Error> {
self.as_ref().take()
}
fn give(&self, conn: Connection) {
self.as_ref().give(conn)
}
}
impl<T: Pool> Pool for Arc<T> {
fn mode(&self) -> Mode {
self.as_ref().mode()
}
fn take(&self) -> Result<Connection, Error> {
self.as_ref().take()
}
fn give(&self, conn: Connection) {
self.as_ref().give(conn)
}
}
#[derive(Debug)]
pub struct ConfigPool {
mode: Mode,
config: Config,
pool: Mutex<VecDeque<Connection>>,
max_pool_size: usize,
}
pub type ConfigPoolRef = Arc<ConfigPool>;
pub struct ConfigPoolBuilder(ConfigPool);
impl ConfigPoolBuilder {
pub fn new(mode: Mode, config: Config) -> Self {
Self(ConfigPool {
mode,
config,
pool: Mutex::new(VecDeque::new()),
max_pool_size: usize::MAX,
})
}
pub fn set_pool(&mut self, pool: VecDeque<Connection>) -> &mut Self {
self.0.pool = Mutex::new(pool);
self
}
pub fn set_max_pool_size(&mut self, max_pool_size: usize) -> &mut Self {
self.0.max_pool_size = max_pool_size;
self
}
pub fn build(self) -> Arc<ConfigPool> {
Arc::new(self.0)
}
}
impl ConfigPool {
pub fn pool_size(&self) -> usize {
self.pool.lock().map(|pool| pool.len()).unwrap_or(0)
}
pub fn is_poisoned(&self) -> bool {
self.pool.is_poisoned()
}
}
impl Pool for ConfigPool {
fn mode(&self) -> Mode {
self.mode
}
fn take(&self) -> Result<Connection, Error> {
let from_pool = match self.pool.lock() {
Ok(mut pool) => pool.pop_front(),
Err(_) => None,
};
let conn = match from_pool {
Some(mut conn) => {
conn.set_config(self.config.clone())?;
conn
}
None => self.config.build_connection(self.mode)?,
};
Ok(conn)
}
fn give(&self, mut conn: Connection) {
let wiped = conn.wipe().is_ok();
if let Ok(mut pool) = self.pool.lock() {
if pool.len() < self.max_pool_size && wiped {
pool.push_back(conn);
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
callbacks::ClientHelloCallback,
config::{self, Config},
connection, error,
testing::TestPair,
};
#[test]
fn config_pool_single_connection() -> Result<(), Box<dyn std::error::Error>> {
let pool = ConfigPoolBuilder::new(Mode::Server, Config::default()).build();
assert_eq!(pool.pool_size(), 0);
for _ in 0..5 {
let _conn = PooledConnection::new(&pool)?;
assert_eq!(pool.pool_size(), 0);
}
assert_eq!(pool.pool_size(), 1);
Ok(())
}
#[test]
fn config_pool_multiple_connections() -> Result<(), Box<dyn std::error::Error>> {
let pool = ConfigPoolBuilder::new(Mode::Server, Config::default()).build();
let mut conns = VecDeque::new();
const COUNT: usize = 25;
for _ in 0..COUNT {
conns.push_back(PooledConnection::new(&pool)?);
assert_eq!(pool.pool_size(), 0);
}
assert_eq!(conns.len(), COUNT);
conns.clear();
assert_eq!(pool.pool_size(), COUNT);
const SUBSET_COUNT: usize = COUNT / 2;
for i in 1..=SUBSET_COUNT {
conns.push_back(PooledConnection::new(&pool)?);
assert_eq!(pool.pool_size(), COUNT - i);
}
assert_eq!(conns.len(), SUBSET_COUNT);
assert_eq!(pool.pool_size(), COUNT - SUBSET_COUNT);
conns.clear();
assert_eq!(pool.pool_size(), COUNT);
Ok(())
}
#[test]
fn config_pool_with_max_size() -> Result<(), Box<dyn std::error::Error>> {
const POOL_MAX_SIZE: usize = 11;
let mut pool = ConfigPoolBuilder::new(Mode::Server, Config::default());
pool.set_max_pool_size(POOL_MAX_SIZE);
let pool = pool.build();
let mut conns = VecDeque::new();
const COUNT: usize = 25;
for _ in 0..COUNT {
conns.push_back(PooledConnection::new(&pool)?);
}
assert_eq!(conns.len(), COUNT);
conns.clear();
assert_eq!(pool.pool_size(), POOL_MAX_SIZE);
Ok(())
}
#[test]
fn non_generic_pool() -> Result<(), Box<dyn std::error::Error>> {
let config_pool = ConfigPoolBuilder::new(Mode::Server, Config::default()).build();
let _: PooledConnection<ConfigPoolRef> = PooledConnection::new(&config_pool)?;
let pool: Arc<dyn Pool> = config_pool;
let _: PooledConnection = PooledConnection::new(&pool)?;
Ok(())
}
#[test]
fn dereferencing_pooled_connection() -> Result<(), Box<dyn std::error::Error>> {
let config_pool = ConfigPoolBuilder::new(Mode::Server, Config::default()).build();
let pooled_conn: PooledConnection<ConfigPoolRef> = PooledConnection::new(&config_pool)?;
let conn = pooled_conn.deref();
assert_eq!(pooled_conn.config(), conn.config());
let mut mut_pooled_conn: PooledConnection<ConfigPoolRef> =
PooledConnection::new(&config_pool)?;
let waker = futures_test::task::new_count_waker().0;
mut_pooled_conn.set_waker(Some(&waker))?;
assert!(mut_pooled_conn.waker().unwrap().will_wake(&waker));
Ok(())
}
#[test]
fn yielded_connection_associated_config() -> Result<(), error::Error> {
fn associated_config_has_ch_callback(conn: &connection::Connection) -> bool {
conn.config()
.unwrap()
.context()
.client_hello_callback
.is_some()
}
struct ConfigSwapCallback(Config);
impl ClientHelloCallback for ConfigSwapCallback {
fn on_client_hello(
&self,
connection: &mut Connection,
) -> crate::callbacks::ConnectionFutureResult {
connection.set_config(self.0.clone())?;
Ok(None)
}
}
let empty_config = config::Builder::new().build()?;
let mut config_with_callback = config::Builder::new();
let dead_end_callback = ConfigSwapCallback(empty_config);
config_with_callback.set_client_hello_callback(dead_end_callback)?;
let config_with_callback = config_with_callback.build()?;
let config_with_pooled_connections =
ConfigPoolBuilder::new(Mode::Server, config_with_callback).build();
let server_from_pool = config_with_pooled_connections.take()?;
let client = Connection::new_client();
let mut pair = TestPair::from_connections(client, server_from_pool);
assert!(associated_config_has_ch_callback(&pair.server));
assert!(pair.handshake().is_err());
assert!(!associated_config_has_ch_callback(&pair.server));
config_with_pooled_connections.give(pair.server);
let server_from_pool = config_with_pooled_connections.take()?;
assert!(associated_config_has_ch_callback(&server_from_pool));
config_with_pooled_connections.give(server_from_pool);
assert_eq!(config_with_pooled_connections.pool_size(), 1);
Ok(())
}
}