/// Default pool size of 10 connections; presumably intended as the `max_size` argument to the
/// pool constructors in this module (`get`, `get_tls`, `get_async`, `get_tls_async`).
///
/// NOTE(review): this constant is a `usize`, but those constructors take `max_size: u32`, so
/// callers must convert it (e.g. `DEFAULT_POOL_SIZE as u32`). Consider whether the types should
/// agree.
pub const DEFAULT_POOL_SIZE: usize = 10;
cfg_sync! {
    pub use r2d2::Error as r2d2Error;
    pub use self::sync_impls::Pool;

    /// Build a blocking (r2d2) pool of plain-TCP, non-TLS connections to `host:port`.
    ///
    /// `max_size` is the maximum number of connections the pool may hold.
    ///
    /// # Errors
    /// Returns the `r2d2Error` produced by r2d2's pool builder if the pool cannot be built.
    ///
    /// NOTE(review): r2d2 documents `Builder::max_size` as panicking when given `0`. Its error
    /// type cannot be constructed here, so callers must pass a non-zero `max_size`.
    pub fn get(host: impl ToString, port: u16, max_size: u32) -> Result<Pool, r2d2Error> {
        let builder = Pool::builder().max_size(max_size);
        let manager = ConnectionManager::new_notls(host.to_string(), port);
        builder.build(manager)
    }
}
cfg_sync_ssl_any! {
    pub use self::sync_impls::TlsPool;

    /// Build a blocking (r2d2) pool of TLS connections to `host:port`.
    ///
    /// `cert` is the certificate value stored in the connection manager and handed to every new
    /// TLS connection. `max_size` is the maximum number of connections the pool may hold.
    ///
    /// # Errors
    /// Returns the `r2d2Error` produced by r2d2's pool builder if the pool cannot be built.
    ///
    /// NOTE(review): r2d2 documents `Builder::max_size` as panicking when given `0`, so callers
    /// must pass a non-zero `max_size`.
    pub fn get_tls(host: impl ToString, port: u16, cert: impl ToString, max_size: u32) -> Result<TlsPool, r2d2Error> {
        let builder = TlsPool::builder().max_size(max_size);
        let manager = ConnectionManager::new_tls(host.to_string(), port, cert);
        builder.build(manager)
    }
}
cfg_async! {
    pub use bb8::RunError as bb8Error;
    pub use self::async_impls::Pool as AsyncPool;
    use crate::error::Error;

    /// Build an async (bb8) pool of plain-TCP, non-TLS connections to `host:port`.
    ///
    /// `max_size` is the maximum number of connections the pool may hold and must be non-zero.
    ///
    /// # Errors
    /// * `Error::ConfigurationError` if `max_size` is `0`.
    /// * Any error bb8's `Builder::build` reports while creating the pool.
    pub async fn get_async(host: impl ToString, port: u16, max_size: u32) -> Result<AsyncPool, Error> {
        // bb8's `Builder::max_size` asserts `max_size > 0`. Reject 0 up front so a bad argument
        // is reported to the caller as an `Err` instead of panicking inside library code.
        if max_size == 0 {
            return Err(Error::ConfigurationError(
                "Expected a non-zero `max_size` for the connection pool",
            ));
        }
        AsyncPool::builder()
            .max_size(max_size)
            .build(ConnectionManager::new_notls(host.to_string(), port))
            .await
    }
}
cfg_async_ssl_any! {
    pub use self::async_impls::TlsPool as AsyncTlsPool;

    /// Build an async (bb8) pool of TLS connections to `host:port`.
    ///
    /// `cert` is the certificate value stored in the connection manager and handed to every new
    /// TLS connection. `max_size` is the maximum number of connections the pool may hold and
    /// must be non-zero.
    ///
    /// # Errors
    /// * `ConfigurationError` if `max_size` is `0`.
    /// * Any error bb8's `Builder::build` reports while creating the pool.
    // The error type is written with its full path so this cfg-gated block does not rely on an
    // import made inside a differently-gated macro block.
    pub async fn get_tls_async(host: impl ToString, port: u16, cert: impl ToString, max_size: u32) -> Result<AsyncTlsPool, crate::error::Error> {
        // bb8's `Builder::max_size` asserts `max_size > 0`. Reject 0 up front so a bad argument
        // is reported to the caller as an `Err` instead of panicking inside library code.
        if max_size == 0 {
            return Err(crate::error::Error::ConfigurationError(
                "Expected a non-zero `max_size` for the connection pool",
            ));
        }
        AsyncTlsPool::builder()
            .max_size(max_size)
            .build(ConnectionManager::new_tls(host.to_string(), port, cert))
            .await
    }
}
use core::marker::PhantomData;
/// Connection manager handed to the r2d2/bb8 pool builders.
///
/// It stores the server address (`host`, `port`) and an optional TLS certificate (`cert`) used
/// when opening each pooled connection. `C` is the connection type to create. It is tracked only
/// through `PhantomData`; no `C` value is ever stored.
pub struct ConnectionManager<C> {
    host: String,
    port: u16,
    cert: Option<String>,
    _m: PhantomData<C>,
}

// `Clone` and `Debug` are implemented by hand rather than derived. `#[derive]` would add
// `C: Clone` / `C: Debug` bounds even though no `C` is stored. That would make, for example, a
// manager for a non-`Clone` connection type impossible to clone or print.
impl<C> Clone for ConnectionManager<C> {
    fn clone(&self) -> Self {
        Self {
            host: self.host.clone(),
            port: self.port,
            cert: self.cert.clone(),
            _m: PhantomData,
        }
    }
}

impl<C> std::fmt::Debug for ConnectionManager<C> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same field layout the derive produced; `PhantomData<C>` is `Debug` for every `C`.
        f.debug_struct("ConnectionManager")
            .field("host", &self.host)
            .field("port", &self.port)
            .field("cert", &self.cert)
            .field("_m", &self._m)
            .finish()
    }
}
impl<C> ConnectionManager<C> {
fn _new(host: String, port: u16, cert: Option<String>) -> Self {
Self {
host,
port,
cert,
_m: PhantomData,
}
}
}
impl<C> ConnectionManager<C> {
    /// Create a manager for plain (non-TLS) connections to `host:port`; no certificate is stored.
    pub fn new_notls(host: impl ToString, port: u16) -> ConnectionManager<C> {
        let host = host.to_string();
        Self::_new(host, port, None)
    }

    /// Create a manager for TLS connections to `host:port`, storing `cert` as the certificate
    /// value that TLS connections are opened with.
    pub fn new_tls(host: impl ToString, port: u16, cert: impl ToString) -> ConnectionManager<C> {
        let host = host.to_string();
        let cert = Some(cert.to_string());
        Self::_new(host, port, cert)
    }
}
#[cfg(feature = "sync")]
#[cfg_attr(docsrs, doc(cfg(feature = "sync")))]
mod sync_impls {
    // Blocking pool support: r2d2 pool aliases plus the `ManageConnection` glue that lets
    // `ConnectionManager` open and health-check blocking connections.
    use super::ConnectionManager;
    use crate::sync::Connection as SyncConnection;
    cfg_sync_ssl_any! {
        use crate::sync::TlsConnection as SyncTlsConnection;
    }
    use crate::{
        error::{Error, SkyhashError},
        Element, Query, SkyQueryResult, SkyResult,
    };
    use r2d2::ManageConnection;

    /// r2d2 pool of plain-TCP connections.
    pub type Pool = r2d2::Pool<ConnectionManager<SyncConnection>>;
    cfg_sync_ssl_any! {
        /// r2d2 pool of TLS connections.
        pub type TlsPool = r2d2::Pool<ConnectionManager<SyncTlsConnection>>;
    }

    /// A blocking connection type that `ConnectionManager` knows how to open and query.
    pub trait PoolableConnection: Send + Sync + Sized {
        /// Open a new connection to `host:port`, using `tls_cert` if the type needs one.
        fn get_connection(host: &str, port: u16, tls_cert: Option<&String>) -> SkyResult<Self>;
        /// Send `q` over this connection and return the server's reply.
        fn run_query(&mut self, q: Query) -> SkyQueryResult;
    }

    impl PoolableConnection for SyncConnection {
        fn get_connection(host: &str, port: u16, _tls_cert: Option<&String>) -> SkyResult<Self> {
            // Plain connections never look at the certificate.
            Ok(Self::new(host, port)?)
        }
        fn run_query(&mut self, q: Query) -> SkyQueryResult {
            self.run_query_raw(&q)
        }
    }

    cfg_sync_ssl_any! {
        impl PoolableConnection for SyncTlsConnection {
            fn get_connection(host: &str, port: u16, tls_cert: Option<&String>) -> SkyResult<Self> {
                // A TLS connection cannot be opened without the manager's certificate.
                let cert = tls_cert.ok_or(Error::ConfigurationError(
                    "Expected TLS certificate in `ConnectionManager`",
                ))?;
                Ok(Self::new(host, port, cert)?)
            }
            fn run_query(&mut self, q: Query) -> SkyQueryResult {
                self.run_query_raw(&q)
            }
        }
    }

    impl<C: PoolableConnection + 'static> ManageConnection for ConnectionManager<C> {
        type Connection = C;
        type Error = Error;

        fn connect(&self) -> Result<Self::Connection, Self::Error> {
            C::get_connection(self.host.as_str(), self.port, self.cert.as_ref())
        }

        fn is_valid(&self, con: &mut Self::Connection) -> Result<(), Self::Error> {
            // Health check: send `HEYA` and accept only the string reply `HEY!`.
            let reply = con.run_query(crate::query!("HEYA"))?;
            match reply {
                Element::String(greeting) if greeting.eq("HEY!") => Ok(()),
                _ => Err(Error::SkyError(SkyhashError::UnexpectedResponse)),
            }
        }

        fn has_broken(&self, _: &mut Self::Connection) -> bool {
            // This manager never reports a connection as broken here.
            false
        }
    }
}
#[cfg(feature = "aio")]
#[cfg_attr(docsrs, doc(cfg(feature = "aio")))]
mod async_impls {
    // Async pool support: bb8 pool aliases plus the `ManageConnection` glue that lets
    // `ConnectionManager` open and health-check async connections.
    use super::ConnectionManager;
    cfg_async_ssl_any! {
        use crate::aio::TlsConnection as AsyncTlsConnection;
    }
    use crate::{
        aio::Connection as AsyncConnection,
        error::{Error, SkyhashError},
        Element, Query, SkyQueryResult, SkyResult,
    };
    use async_trait::async_trait;
    use bb8::ManageConnection;

    /// bb8 pool of plain-TCP connections.
    pub type Pool = bb8::Pool<ConnectionManager<AsyncConnection>>;
    cfg_async_ssl_any! {
        /// bb8 pool of TLS connections.
        pub type TlsPool = bb8::Pool<ConnectionManager<AsyncTlsConnection>>;
    }

    /// An async connection type that `ConnectionManager` knows how to open and query.
    #[async_trait]
    pub trait PoolableConnection: Send + Sync + Sized {
        /// Open a new connection to `host:port`, using `tls_cert` if the type needs one.
        async fn get_connection(
            host: &str,
            port: u16,
            tls_cert: Option<&String>,
        ) -> SkyResult<Self>;
        /// Send `q` over this connection and return the server's reply.
        async fn run_query(&mut self, q: Query) -> SkyQueryResult;
    }

    #[async_trait]
    impl PoolableConnection for AsyncConnection {
        async fn get_connection(
            host: &str,
            port: u16,
            _tls_cert: Option<&String>,
        ) -> SkyResult<Self> {
            // Plain connections never look at the certificate.
            Ok(AsyncConnection::new(host, port).await?)
        }
        async fn run_query(&mut self, q: Query) -> SkyQueryResult {
            self.run_query_raw(&q).await
        }
    }

    cfg_async_ssl_any! {
        #[async_trait]
        impl PoolableConnection for AsyncTlsConnection {
            async fn get_connection(
                host: &str,
                port: u16,
                tls_cert: Option<&String>,
            ) -> SkyResult<Self> {
                // A TLS connection cannot be opened without the manager's certificate.
                let cert = tls_cert.ok_or(Error::ConfigurationError(
                    "Expected TLS certificate in `ConnectionManager`",
                ))?;
                Ok(AsyncTlsConnection::new(host, port, cert).await?)
            }
            async fn run_query(&mut self, q: Query) -> SkyQueryResult {
                self.run_query_raw(&q).await
            }
        }
    }

    #[async_trait]
    impl<C: PoolableConnection + 'static> ManageConnection for ConnectionManager<C> {
        type Connection = C;
        type Error = Error;

        async fn connect(&self) -> Result<Self::Connection, Self::Error> {
            C::get_connection(self.host.as_str(), self.port, self.cert.as_ref()).await
        }

        async fn is_valid(&self, con: &mut Self::Connection) -> Result<(), Self::Error> {
            // Health check: send `HEYA` and accept only the string reply `HEY!`.
            let reply = con.run_query(crate::query!("HEYA")).await?;
            match reply {
                Element::String(greeting) if greeting.eq("HEY!") => Ok(()),
                _ => Err(Error::SkyError(SkyhashError::UnexpectedResponse)),
            }
        }

        fn has_broken(&self, _: &mut Self::Connection) -> bool {
            // This manager never reports a connection as broken here.
            false
        }
    }
}