use std::pin::Pin;
use std::sync::Arc;
use deadpool::managed::{self, Manager, Metrics, RecycleError, RecycleResult};
use tokio::sync::Mutex as AsyncMutex;
use crate::async_connection::AsyncConnection;
use crate::error::{Error, Result};
use crate::CreateMode;
/// Boxed, `Send` future returned by connection hooks.
///
/// The `'a` lifetime lets the future borrow the connection it was handed.
pub type HookFuture<'a> = Pin<Box<dyn std::future::Future<Output = Result<()>> + Send + 'a>>;

/// Hook awaited on every freshly opened connection before `create` returns it
/// to the pool; an `Err` aborts that creation.
pub type AfterConnectHook =
    Arc<dyn for<'c> Fn(&'c AsyncConnection) -> HookFuture<'c> + Send + Sync + 'static>;

/// Hook awaited on a pooled connection each time it is recycled for reuse; an
/// `Err` is reported to deadpool as a recycle failure.
pub type BeforeAcquireHook =
    Arc<dyn for<'c> Fn(&'c AsyncConnection) -> HookFuture<'c> + Send + Sync + 'static>;
/// Settings used by [`ConnectionManager`] to open and maintain pooled connections.
///
/// Build one with [`PoolConfig::new`] and the chained setters. The `Debug`
/// implementation hides the password and the hook closures.
#[derive(Clone)]
pub struct PoolConfig {
    /// Server address handed to `AsyncConnection::connect` / `connect_with_auth`.
    pub endpoint: String,
    /// Database to open on `endpoint`.
    pub database: String,
    /// Mode used until one connection has been opened successfully with it;
    /// later connections are opened with `CreateMode::DoNotCreate`.
    pub create_mode: CreateMode,
    /// User name; credentials are only sent when both `user` and `password` are `Some`.
    pub user: Option<String>,
    /// Password; credentials are only sent when both `user` and `password` are `Some`.
    pub password: Option<String>,
    /// Maximum number of connections held by the pool built by `create_pool`.
    pub max_size: usize,
    /// When `true`, a pooled connection must run `SELECT 1` successfully before reuse.
    pub health_check: bool,
    /// Optional hook run on each newly opened connection.
    pub after_connect: Option<AfterConnectHook>,
    /// Optional hook run on each pooled connection as it is recycled.
    pub before_acquire: Option<BeforeAcquireHook>,
}
impl std::fmt::Debug for PoolConfig {
    /// Formats the configuration, printing `"<redacted>"` in place of a set
    /// password and `"<fn>"` in place of each installed hook.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field to `PoolConfig` forces a
        // decision here about whether (and how) it may be printed.
        let Self {
            endpoint,
            database,
            create_mode,
            user,
            password,
            max_size,
            health_check,
            after_connect,
            before_acquire,
        } = self;

        // Only the presence of secrets/closures is revealed, never their contents.
        let password = password.is_some().then_some("<redacted>");
        let after_connect = after_connect.is_some().then_some("<fn>");
        let before_acquire = before_acquire.is_some().then_some("<fn>");

        f.debug_struct("PoolConfig")
            .field("endpoint", endpoint)
            .field("database", database)
            .field("create_mode", create_mode)
            .field("user", user)
            .field("password", &password)
            .field("max_size", max_size)
            .field("health_check", health_check)
            .field("after_connect", &after_connect)
            .field("before_acquire", &before_acquire)
            .finish()
    }
}
impl PoolConfig {
pub fn new(endpoint: impl Into<String>, database: impl Into<String>) -> Self {
Self {
endpoint: endpoint.into(),
database: database.into(),
create_mode: CreateMode::DoNotCreate,
user: None,
password: None,
max_size: 16,
health_check: true,
after_connect: None,
before_acquire: None,
}
}
#[must_use]
pub fn create_mode(mut self, mode: CreateMode) -> Self {
self.create_mode = mode;
self
}
#[must_use]
pub fn auth(mut self, user: impl Into<String>, password: impl Into<String>) -> Self {
self.user = Some(user.into());
self.password = Some(password.into());
self
}
#[must_use]
pub fn max_size(mut self, size: usize) -> Self {
self.max_size = size;
self
}
#[must_use]
pub fn health_check(mut self, enabled: bool) -> Self {
self.health_check = enabled;
self
}
#[must_use]
pub fn after_connect<F>(mut self, hook: F) -> Self
where
F: Fn(&AsyncConnection) -> HookFuture<'_> + Send + Sync + 'static,
{
self.after_connect = Some(Arc::new(hook));
self
}
#[must_use]
pub fn before_acquire<F>(mut self, hook: F) -> Self
where
F: Fn(&AsyncConnection) -> HookFuture<'_> + Send + Sync + 'static,
{
self.before_acquire = Some(Arc::new(hook));
self
}
}
/// deadpool [`Manager`] that opens [`AsyncConnection`]s as described by a [`PoolConfig`].
#[derive(Debug)]
pub struct ConnectionManager {
    /// Pool settings, shared read-only with every connection-creating call.
    config: Arc<PoolConfig>,
    /// `false` until one connection has been opened successfully with
    /// `config.create_mode`; `true` afterwards. The mutex serialises that
    /// first open.
    init_lock: Arc<AsyncMutex<bool>>,
}
impl ConnectionManager {
    /// Wraps `config` in a manager that has not yet applied `config.create_mode`.
    #[must_use]
    pub fn new(config: PoolConfig) -> Self {
        let config = Arc::new(config);
        let init_lock = Arc::new(AsyncMutex::new(false));
        Self { config, init_lock }
    }

    /// Opens one connection to the configured endpoint/database using `mode`.
    ///
    /// Credentials are sent only when both `user` and `password` are set; if
    /// either is missing, the connection is opened without authentication.
    async fn open(&self, mode: CreateMode) -> Result<AsyncConnection> {
        let cfg = &*self.config;
        match (&cfg.user, &cfg.password) {
            (Some(user), Some(password)) => {
                AsyncConnection::connect_with_auth(
                    &cfg.endpoint,
                    &cfg.database,
                    mode,
                    user,
                    password,
                )
                .await
            }
            _ => AsyncConnection::connect(&cfg.endpoint, &cfg.database, mode).await,
        }
    }
}
impl Manager for ConnectionManager {
    type Type = AsyncConnection;
    type Error = Error;

    /// Opens a new connection for the pool and then runs the `after_connect` hook.
    ///
    /// The configured `create_mode` is used only until one connection has been
    /// opened successfully with it. Every later connection uses
    /// `CreateMode::DoNotCreate`, so the create mode is not re-applied by each
    /// pooled connection.
    ///
    /// # Errors
    ///
    /// Returns the error from opening the connection or from the
    /// `after_connect` hook. In the hook case the new connection is dropped.
    async fn create(&self) -> Result<AsyncConnection> {
        let conn = {
            // One lock acquisition is enough: the mutex already serialises the
            // check-and-set, so there is no need to unlock and re-check.
            let mut initialized = self.init_lock.lock().await;
            if *initialized {
                // Mode already applied: release the lock before opening so
                // later creates do not serialise on each other.
                drop(initialized);
                self.open(CreateMode::DoNotCreate).await?
            } else {
                // Hold the lock across the first open so only one task applies
                // `create_mode`; on failure the flag stays `false` and the next
                // `create` retries with the configured mode.
                let conn = self.open(self.config.create_mode).await?;
                *initialized = true;
                conn
            }
        };
        if let Some(hook) = self.config.after_connect.as_ref() {
            hook(&conn).await?;
        }
        Ok(conn)
    }

    /// Validates a previously created connection before it is reused.
    ///
    /// When `health_check` is enabled the connection must execute `SELECT 1`
    /// successfully, and then the `before_acquire` hook (if any) must succeed.
    /// Any failure is reported as `RecycleError::Backend`.
    async fn recycle(
        &self,
        conn: &mut AsyncConnection,
        _metrics: &Metrics,
    ) -> RecycleResult<Self::Error> {
        if self.config.health_check {
            conn.execute_command("SELECT 1")
                .await
                .map_err(RecycleError::Backend)?;
        }
        if let Some(hook) = self.config.before_acquire.as_ref() {
            hook(conn).await.map_err(RecycleError::Backend)?;
        }
        Ok(())
    }
}
pub type Pool = managed::Pool<ConnectionManager>;
pub type PooledConnection = managed::Object<ConnectionManager>;
pub fn create_pool(config: PoolConfig) -> Result<Pool> {
let max_size = config.max_size;
let manager = ConnectionManager::new(config);
Pool::builder(manager)
.max_size(max_size)
.build()
.map_err(|e| Error::new(format!("Failed to create pool: {e}")))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every builder setter stores its argument in the matching field.
    #[test]
    fn test_pool_config_builder() {
        let config = PoolConfig::new("localhost:7483", "test.hyper")
            .create_mode(CreateMode::CreateIfNotExists)
            .auth("user", "pass")
            .max_size(32)
            .health_check(false);
        assert_eq!(config.endpoint, "localhost:7483");
        assert_eq!(config.database, "test.hyper");
        assert_eq!(config.create_mode, CreateMode::CreateIfNotExists);
        assert_eq!(config.user, Some("user".to_string()));
        assert_eq!(config.password, Some("pass".to_string()));
        assert_eq!(config.max_size, 32);
        assert!(!config.health_check);
    }

    /// `PoolConfig::new` fills in the documented defaults.
    #[test]
    fn test_pool_config_defaults() {
        let config = PoolConfig::new("localhost:7483", "test.hyper");
        assert_eq!(config.create_mode, CreateMode::DoNotCreate);
        assert_eq!(config.user, None);
        assert_eq!(config.password, None);
        assert_eq!(config.max_size, 16);
        assert!(config.health_check);
        assert!(config.after_connect.is_none());
        assert!(config.before_acquire.is_none());
    }

    /// The password must never appear in `Debug` output.
    #[test]
    fn test_pool_config_debug_redacts_password() {
        let config = PoolConfig::new("localhost:7483", "test.hyper").auth("user", "s3cr3t-pw");
        let rendered = format!("{config:?}");
        assert!(!rendered.contains("s3cr3t-pw"));
        assert!(rendered.contains("<redacted>"));
    }

    /// Building a pool opens no connections and honours `max_size`.
    #[test]
    fn test_create_pool_uses_max_size() {
        let config = PoolConfig::new("localhost:7483", "test.hyper").max_size(4);
        let Ok(pool) = create_pool(config) else {
            panic!("create_pool should succeed for a non-zero max_size");
        };
        assert_eq!(pool.status().max_size, 4);
    }
}