#![warn(missing_docs, unreachable_pub)]
use std::borrow::Cow;
use std::collections::HashMap;
use std::ops::{Deref, DerefMut};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex, RwLock, Weak};
use async_trait::async_trait;
use futures::FutureExt;
use log::{info, warn};
use tokio::spawn;
use tokio_postgres::{
tls::MakeTlsConnect, tls::TlsConnect, types::Type, Client as PgClient, Config as PgConfig,
Error, IsolationLevel, Socket, Statement, Transaction as PgTransaction,
TransactionBuilder as PgTransactionBuilder,
};
/// Configuration types; `Config`, `ManagerConfig` and `RecyclingMethod` are
/// re-exported at the crate root below.
pub mod config;
// Re-exported so callers can name these types through this crate.
pub use crate::config::{Config, ManagerConfig, RecyclingMethod};
pub use deadpool::managed::PoolConfig;
pub use deadpool::Runtime;
/// Connection pool whose objects are created and recycled by [`Manager`].
pub type Pool = deadpool::managed::Pool<Manager>;
/// Error type of [`Pool`] operations, with `tokio_postgres::Error` as the
/// backend error.
pub type PoolError = deadpool::managed::PoolError<tokio_postgres::Error>;
/// Pooled object managed by [`Manager`]; the managed value is a
/// [`ClientWrapper`] (`Manager::Type`).
pub type Client = deadpool::managed::Object<Manager>;
// Result type of `Manager::recycle`.
type RecycleResult = deadpool::managed::RecycleResult<Error>;
// Error type produced by `Manager::recycle`.
type RecycleError = deadpool::managed::RecycleError<Error>;
// Re-export of the underlying driver crate.
pub use tokio_postgres;
/// `deadpool` manager that creates and recycles `tokio_postgres` connections,
/// handing them out wrapped in [`ClientWrapper`].
pub struct Manager {
// Manager settings; `recycling_method` is read in `recycle`.
config: ManagerConfig,
// Connection parameters passed to the connector for every new client.
pg_config: PgConfig,
// Type-erased connector that owns the TLS implementation.
connect: Box<dyn Connect>,
/// Weak registry of the statement caches of clients created by this
/// manager (attached in `create`, removed in `detach`). Use it to clear or
/// invalidate cached statements on all of those connections at once.
pub statement_caches: StatementCaches,
}
impl Manager {
pub fn new<T>(pg_config: tokio_postgres::Config, tls: T) -> Manager
where
T: MakeTlsConnect<Socket> + Clone + Sync + Send + 'static,
T::Stream: Sync + Send,
T::TlsConnect: Sync + Send,
<T::TlsConnect as TlsConnect<Socket>>::Future: Send,
{
Self::from_config(pg_config, tls, ManagerConfig::default())
}
pub fn from_config<T>(
pg_config: tokio_postgres::Config,
tls: T,
config: ManagerConfig,
) -> Manager
where
T: MakeTlsConnect<Socket> + Clone + Sync + Send + 'static,
T::Stream: Sync + Send,
T::TlsConnect: Sync + Send,
<T::TlsConnect as TlsConnect<Socket>>::Future: Send,
{
Manager {
config,
pg_config,
connect: Box::new(ConnectImpl { tls }),
statement_caches: StatementCaches::default(),
}
}
}
#[async_trait]
impl deadpool::managed::Manager for Manager {
    type Type = ClientWrapper;
    type Error = Error;

    /// Opens a new connection, wraps it in a [`ClientWrapper`] and registers
    /// the wrapper's statement cache in `statement_caches`.
    async fn create(&self) -> Result<ClientWrapper, Error> {
        let wrapper = self
            .connect
            .connect(&self.pg_config)
            .await
            .map(ClientWrapper::new)?;
        self.statement_caches.attach(&wrapper.statement_cache);
        Ok(wrapper)
    }

    /// Decides whether `client` may be handed out again.
    ///
    /// Fails with `RecycleError::Message` if the connection is closed.
    /// Otherwise, if `recycling_method.query()` yields SQL, that SQL is run
    /// with `simple_query`, and any error is converted and returned. Both
    /// failure paths log at `info` level under the `deadpool.postgres` target.
    async fn recycle(&self, client: &mut ClientWrapper) -> RecycleResult {
        if client.is_closed() {
            info!(target: "deadpool.postgres", "Connection could not be recycled: Connection closed");
            return Err(RecycleError::Message("Connection closed".to_string()));
        }
        let sql = match self.config.recycling_method.query() {
            Some(sql) => sql,
            // No recycling query configured: an open connection is reusable.
            None => return Ok(()),
        };
        if let Err(err) = client.simple_query(sql).await {
            info!(target: "deadpool.postgres", "Connection could not be recycled: {}", err);
            return Err(err.into());
        }
        Ok(())
    }

    /// Unregisters the client's statement cache from `statement_caches`.
    ///
    /// NOTE(review): invoked by deadpool when an object is detached from the
    /// pool; confirm which code paths call it in the deadpool version in use.
    fn detach(&self, object: &mut ClientWrapper) {
        let cache = &object.statement_cache;
        self.statement_caches.detach(cache);
    }
}
/// Object-safe connector that lets [`Manager`] open connections without
/// being generic over the TLS implementation.
#[async_trait::async_trait]
trait Connect: Send + Sync {
    /// Connects with `pg_config` and returns the resulting client.
    async fn connect(&self, pg_config: &PgConfig) -> Result<PgClient, Error>;
}
/// [`Connect`] implementation that owns a cloneable TLS connector `T`.
///
/// The trait bounds live only on the `impl Connect for ConnectImpl<T>` block,
/// which is where they are needed. They are not repeated on the struct.
struct ConnectImpl<T> {
    // Cloned for every new connection.
    tls: T,
}
#[async_trait::async_trait]
impl<T> Connect for ConnectImpl<T>
where
T: MakeTlsConnect<Socket> + Clone + Sync + Send + 'static,
T::Stream: Sync + Send,
T::TlsConnect: Sync + Send,
<T::TlsConnect as TlsConnect<Socket>>::Future: Send,
{
async fn connect(&self, pg_config: &PgConfig) -> Result<PgClient, Error> {
let (client, connection) = pg_config.connect(self.tls.clone()).await?;
let connection = connection.map(|r| {
if let Err(e) = r {
warn!(target: "deadpool.postgres", "Connection error: {}", e);
}
});
spawn(connection);
Ok(client)
}
}
/// Registry of the [`StatementCache`]s of the clients created by a `Manager`.
///
/// Only [`Weak`] references are stored, so the registry never keeps a cache
/// alive on its own. [`clear`](StatementCaches::clear) and
/// [`remove`](StatementCaches::remove) apply to every cache that is still alive.
#[derive(Default)]
pub struct StatementCaches {
    caches: Mutex<Vec<Weak<StatementCache>>>,
}

impl StatementCaches {
    /// Registers `cache` so that later `clear`/`remove` calls reach it.
    ///
    /// Entries whose cache has already been dropped are pruned first. If a
    /// client is ever dropped without `detach` being called, its `Weak` would
    /// otherwise stay in the list forever. That would grow the vector and keep
    /// the cache's allocation (though not its contents) alive.
    fn attach(&self, cache: &Arc<StatementCache>) {
        let cache = Arc::downgrade(cache);
        let mut caches = self.caches.lock().unwrap();
        caches.retain(|sc| sc.strong_count() > 0);
        caches.push(cache);
    }

    /// Unregisters `cache`, pruning any entries whose cache is already dropped.
    fn detach(&self, cache: &Arc<StatementCache>) {
        let cache = Arc::downgrade(cache);
        self.caches
            .lock()
            .unwrap()
            .retain(|sc| !sc.ptr_eq(&cache) && sc.strong_count() > 0);
    }

    /// Clears every registered cache that is still alive.
    ///
    /// # Panics
    ///
    /// Panics if the registry mutex is poisoned.
    pub fn clear(&self) {
        let caches = self.caches.lock().unwrap();
        for cache in caches.iter().filter_map(Weak::upgrade) {
            cache.clear();
        }
    }

    /// Removes the statement for `query`/`types` from every registered cache
    /// that is still alive.
    ///
    /// # Panics
    ///
    /// Panics if the registry mutex is poisoned.
    pub fn remove(&self, query: &str, types: &[Type]) {
        let caches = self.caches.lock().unwrap();
        for cache in caches.iter().filter_map(Weak::upgrade) {
            cache.remove(query, types);
        }
    }
}
/// Map key identifying a prepared statement by its SQL text and parameter
/// types.
///
/// `Cow` lets lookups use borrowed data (`StatementCache::get` builds
/// `Cow::Borrowed` keys), while keys stored in the map own their data
/// (`StatementCacheKey<'static>`).
#[derive(Hash, Eq, PartialEq)]
struct StatementCacheKey<'a> {
query: Cow<'a, str>,
types: Cow<'a, [Type]>,
}
/// Cache of prepared statements for one client, keyed by query text and
/// parameter types. Each `ClientWrapper` creates its own.
pub struct StatementCache {
// Cached statements, behind an `RwLock` so lookups take only a read lock.
map: RwLock<HashMap<StatementCacheKey<'static>, Statement>>,
// Number of entries in `map`. It is updated while `map`'s write lock is
// held, so `size()` can read it without locking.
size: AtomicUsize,
}
impl StatementCache {
fn new() -> StatementCache {
StatementCache {
map: RwLock::new(HashMap::new()),
size: AtomicUsize::new(0),
}
}
pub fn size(&self) -> usize {
self.size.load(Ordering::Relaxed)
}
pub fn clear(&self) {
let mut map = self.map.write().unwrap();
map.clear();
self.size.store(0, Ordering::Relaxed);
}
pub fn remove(&self, query: &str, types: &[Type]) -> Option<Statement> {
let key = StatementCacheKey {
query: Cow::Owned(query.to_owned()),
types: Cow::Owned(types.to_owned()),
};
let mut map = self.map.write().unwrap();
let removed = map.remove(&key);
if removed.is_some() {
self.size.fetch_sub(1, Ordering::Relaxed);
}
removed
}
fn get(&self, query: &str, types: &[Type]) -> Option<Statement> {
let key = StatementCacheKey {
query: Cow::Borrowed(query),
types: Cow::Borrowed(types),
};
self.map
.read()
.unwrap()
.get(&key)
.map(|stmt| stmt.to_owned())
}
fn insert(&self, query: &str, types: &[Type], stmt: Statement) {
let key = StatementCacheKey {
query: Cow::Owned(query.to_owned()),
types: Cow::Owned(types.to_owned()),
};
let mut map = self.map.write().unwrap();
if map.insert(key, stmt).is_none() {
self.size.fetch_add(1, Ordering::Relaxed);
}
}
pub async fn prepare(&self, client: &PgClient, query: &str) -> Result<Statement, Error> {
self.prepare_typed(client, query, &[]).await
}
pub async fn prepare_typed(
&self,
client: &PgClient,
query: &str,
types: &[Type],
) -> Result<Statement, Error> {
match self.get(query, types) {
Some(statement) => Ok(statement),
None => {
let stmt = client.prepare_typed(query, types).await?;
self.insert(query, types, stmt.clone());
Ok(stmt)
}
}
}
}
/// A `tokio_postgres` client paired with its own [`StatementCache`].
///
/// Dereferences to [`PgClient`], so the client's methods can be called on it
/// directly.
pub struct ClientWrapper {
client: PgClient,
/// Prepared statements cached for this client. The `Arc` is cloned into
/// every `Transaction`/`TransactionBuilder` started from it.
pub statement_cache: Arc<StatementCache>,
}
impl ClientWrapper {
pub fn new(client: PgClient) -> Self {
Self {
client,
statement_cache: Arc::new(StatementCache::new()),
}
}
pub async fn prepare_cached(&self, query: &str) -> Result<Statement, Error> {
self.statement_cache.prepare(&self.client, query).await
}
pub async fn prepare_typed_cached(
&self,
query: &str,
types: &[Type],
) -> Result<Statement, Error> {
self.statement_cache
.prepare_typed(&self.client, query, types)
.await
}
pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
Ok(Transaction {
txn: PgClient::transaction(&mut self.client).await?,
statement_cache: self.statement_cache.clone(),
})
}
pub fn build_transaction(&mut self) -> TransactionBuilder {
TransactionBuilder {
builder: self.client.build_transaction(),
statement_cache: self.statement_cache.clone(),
}
}
}
impl Deref for ClientWrapper {
type Target = PgClient;
fn deref(&self) -> &PgClient {
&self.client
}
}
impl DerefMut for ClientWrapper {
fn deref_mut(&mut self) -> &mut PgClient {
&mut self.client
}
}
/// Wrapper around a `tokio_postgres` transaction that shares the statement
/// cache of the connection it was started from.
pub struct Transaction<'a> {
txn: PgTransaction<'a>,
statement_cache: Arc<StatementCache>,
}
impl<'a> Transaction<'a> {
pub async fn prepare_cached(&self, query: &str) -> Result<Statement, Error> {
self.statement_cache.prepare(&self.client(), query).await
}
pub async fn prepare_typed_cached(
&self,
query: &str,
types: &[Type],
) -> Result<Statement, Error> {
self.statement_cache
.prepare_typed(&self.client(), query, types)
.await
}
pub async fn commit(self) -> Result<(), Error> {
self.txn.commit().await
}
pub async fn rollback(self) -> Result<(), Error> {
self.txn.rollback().await
}
pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
Ok(Transaction {
txn: PgTransaction::transaction(&mut self.txn).await?,
statement_cache: self.statement_cache.clone(),
})
}
pub async fn savepoint<I>(&mut self, name: I) -> Result<Transaction<'_>, Error>
where
I: Into<String>,
{
Ok(Transaction {
txn: PgTransaction::savepoint(&mut self.txn, name).await?,
statement_cache: self.statement_cache.clone(),
})
}
}
impl<'a> Deref for Transaction<'a> {
type Target = PgTransaction<'a>;
fn deref(&self) -> &PgTransaction<'a> {
&self.txn
}
}
impl<'a> DerefMut for Transaction<'a> {
fn deref_mut(&mut self) -> &mut PgTransaction<'a> {
&mut self.txn
}
}
/// Builder for a [`Transaction`] that carries the originating client's
/// statement cache through to `start`.
pub struct TransactionBuilder<'a> {
builder: PgTransactionBuilder<'a>,
statement_cache: Arc<StatementCache>,
}
impl<'a> TransactionBuilder<'a> {
pub fn isolation_level(self, isolation_level: IsolationLevel) -> Self {
Self {
builder: self.builder.isolation_level(isolation_level),
statement_cache: self.statement_cache,
}
}
pub fn read_only(self, read_only: bool) -> Self {
Self {
builder: self.builder.read_only(read_only),
statement_cache: self.statement_cache,
}
}
pub fn deferrable(self, deferrable: bool) -> Self {
Self {
builder: self.builder.deferrable(deferrable),
statement_cache: self.statement_cache,
}
}
pub async fn start(self) -> Result<Transaction<'a>, Error> {
Ok(Transaction {
txn: self.builder.start().await?,
statement_cache: self.statement_cache,
})
}
}
impl<'a> Deref for TransactionBuilder<'a> {
type Target = PgTransactionBuilder<'a>;
fn deref(&self) -> &Self::Target {
&self.builder
}
}
impl<'a> DerefMut for TransactionBuilder<'a> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.builder
}
}