#![doc = include_str!("../README.md")]
#![cfg_attr(docsrs, feature(doc_cfg))]
#![deny(
nonstandard_style,
rustdoc::broken_intra_doc_links,
rustdoc::private_intra_doc_links
)]
#![warn(
deprecated_in_future,
missing_copy_implementations,
missing_debug_implementations,
missing_docs,
unreachable_pub,
unused_import_braces,
unused_labels,
unused_lifetimes,
unused_qualifications,
unused_results
)]
mod config;
mod deadpool;
mod generic_client;
pub use deadpool::managed::Metrics;
pub use deadpool::managed::PoolConfig;
pub use deadpool::managed::Timeouts;
use std::future::Future;
use std::{
borrow::Cow,
collections::HashMap,
fmt,
ops::{Deref, DerefMut},
pin::Pin,
sync::{
atomic::{AtomicUsize, Ordering},
Arc, Mutex, RwLock, Weak,
},
};
use simple_pg_client::{
tls::MakeTlsConnect, tls::TlsConnect, types::Type, Client as PgClient, Config as PgConfig,
Error, IsolationLevel, Socket, Statement,
};
use tokio::{spawn, task::JoinHandle};
pub use simple_pg_client;
use crate::deadpool::managed::RecycleError;
pub use self::config::{
ChannelBinding, Config, ConfigError, ManagerConfig, RecyclingMethod, SslMode,
TargetSessionAttrs,
};
pub use deadpool::managed::{BuildError, Hook, HookError, Pool, PoolBuilder, PoolConn, PoolError};
/// Error type for pool creation, parameterized with this crate's `ConfigError`.
pub type CreatePoolError = deadpool::managed::CreatePoolError<ConfigError>;
/// A connection handed out by the pool (alias of `PoolConn`).
pub type Client = PoolConn;
// Result type returned by `Manager::recycle`.
type RecycleResult = deadpool::managed::RecycleResult;
/// Creates new PostgreSQL connections and decides whether returned ones can be
/// reused.
///
/// It holds the connection parameters, a type-erased connector that owns the
/// TLS setup, and a registry of the statement caches of the connections it
/// created.
pub struct Manager {
    // Recycling behaviour; `recycle` reads `recycling_method` from it.
    config: ManagerConfig,
    // Connection parameters handed to the connector for every new connection.
    pg_config: PgConfig,
    // Boxed connector so `Manager` is not generic over the TLS type.
    connect: Box<dyn Connect>,
    /// Statement caches of the connections created by this manager. Use
    /// `StatementCaches::clear` / `StatementCaches::remove` to invalidate
    /// cached statements on all of them at once.
    pub statement_caches: StatementCaches,
}
impl Manager {
pub fn new<T>(pg_config: simple_pg_client::Config, tls: T) -> Self
where
T: MakeTlsConnect<Socket> + Clone + Sync + Send + 'static,
T::Stream: Sync + Send,
T::TlsConnect: Sync + Send,
<T::TlsConnect as TlsConnect<Socket>>::Future: Send,
{
Self::from_config(pg_config, tls, ManagerConfig::default())
}
pub fn from_config<T>(
pg_config: simple_pg_client::Config,
tls: T,
config: ManagerConfig,
) -> Self
where
T: MakeTlsConnect<Socket> + Clone + Sync + Send + 'static,
T::Stream: Sync + Send,
T::TlsConnect: Sync + Send,
<T::TlsConnect as TlsConnect<Socket>>::Future: Send,
{
Self {
config,
pg_config,
connect: Box::new(ConnectImpl { tls }),
statement_caches: StatementCaches::default(),
}
}
}
impl fmt::Debug for Manager {
    /// Shows the configuration, connection parameters and statement caches.
    /// The boxed connector is skipped because `Connect` has no `Debug` bound.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            config,
            pg_config,
            statement_caches,
            ..
        } = self;
        f.debug_struct("Manager")
            .field("config", config)
            .field("pg_config", pg_config)
            .field("statement_caches", statement_caches)
            .finish()
    }
}
impl Manager {
    /// Opens a new connection, wraps it in a `Conn` and registers the new
    /// connection's statement cache in `self.statement_caches`.
    ///
    /// # Errors
    ///
    /// Returns the connector's error if establishing the connection fails.
    async fn create(&self) -> Result<Conn, Error> {
        let (pg_client, task) = self.connect.connect(&self.pg_config).await?;
        let conn = Conn::new(pg_client, task);
        self.statement_caches.attach(&conn.statement_cache);
        Ok(conn)
    }

    /// Decides whether `client` may be handed out again.
    ///
    /// A closed connection is rejected. Otherwise the query configured by the
    /// recycling method, if any, is run, and the connection is rejected if that
    /// query fails. Both rejections are logged as warnings.
    async fn recycle(&self, client: &mut Conn, _: &Metrics) -> RecycleResult {
        if client.is_closed() {
            log::warn!(target: "deadpool.postgres", "Connection could not be recycled: Connection closed");
            return Err(RecycleError::StaticMessage("Connection closed"));
        }
        let sql = match self.config.recycling_method.query() {
            Some(sql) => sql,
            None => return Ok(()),
        };
        if let Err(e) = client.simple_query(sql).await {
            log::warn!(target: "deadpool.postgres", "Connection could not be recycled: {}", e);
            return Err(e.into());
        }
        Ok(())
    }

    /// Removes `object`'s statement cache from `self.statement_caches`.
    fn detach(&self, object: &mut Conn) {
        let cache = &object.statement_cache;
        self.statement_caches.detach(cache);
    }
}
// A pinned, heap-allocated, `Send` future. It lets the object-safe `Connect`
// trait return a future without naming its concrete type.
type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
// Object-safe "open a connection" abstraction. `Manager` stores it as
// `Box<dyn Connect>` so it does not have to be generic over the TLS type.
trait Connect: Sync + Send {
    // Connects using `pg_config` and yields the client together with the
    // `JoinHandle` of the task that drives the connection (in `ConnectImpl`
    // that task is created with `tokio::spawn`).
    fn connect(
        &self,
        pg_config: &PgConfig,
    ) -> BoxFuture<'_, Result<(PgClient, JoinHandle<()>), Error>>;
}
/// `Connect` implementation that opens connections using a cloneable TLS
/// connector `T`.
///
/// The trait bounds required to actually connect are declared on the
/// `Connect` impl rather than on the struct, so the struct definition itself
/// places no requirements on `T`. This is the idiomatic placement, and it
/// avoids repeating the bounds wherever the type is named.
struct ConnectImpl<T> {
    /// TLS connector; cloned once per new connection.
    tls: T,
}
impl<T> Connect for ConnectImpl<T>
where
    T: MakeTlsConnect<Socket> + Clone + Sync + Send + 'static,
    T::Stream: Sync + Send,
    T::TlsConnect: Sync + Send,
    <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
{
    /// Connects with clones of the TLS connector and `pg_config`, then spawns
    /// a Tokio task that drives the connection and logs a warning if the
    /// connection ends with an error.
    fn connect(
        &self,
        pg_config: &PgConfig,
    ) -> BoxFuture<'_, Result<(PgClient, JoinHandle<()>), Error>> {
        // Owned copies: the returned future may only borrow `self`, not `pg_config`.
        let tls = self.tls.clone();
        let config = pg_config.clone();
        Box::pin(async move {
            let (client, connection) = config.connect(tls).await?;
            let driver = async move {
                if let Err(e) = connection.await {
                    log::warn!(target: "deadpool.postgres", "Connection error: {}", e);
                }
            };
            Ok((client, spawn(driver)))
        })
    }
}
/// Registry of the statement caches of the connections created by a `Manager`.
///
/// Caches are held through `Weak` references, so the registry never keeps a
/// connection's cache alive by itself.
#[derive(Default, Debug)]
pub struct StatementCaches {
    caches: Mutex<Vec<Weak<StatementCache>>>,
}

impl StatementCaches {
    /// Registers `cache`.
    ///
    /// Entries whose cache has already been dropped (for example when a `Conn`
    /// goes away without `detach` being called) are pruned here. Otherwise
    /// such dead `Weak` entries would accumulate and the list would grow
    /// without bound.
    fn attach(&self, cache: &Arc<StatementCache>) {
        let cache = Arc::downgrade(cache);
        let mut caches = self.caches.lock().unwrap();
        // A dead `Weak` can never be upgraded again, so removing it is unobservable.
        caches.retain(|sc| sc.strong_count() > 0);
        caches.push(cache);
    }

    /// Unregisters `cache`, matching entries by pointer identity.
    fn detach(&self, cache: &Arc<StatementCache>) {
        let cache = Arc::downgrade(cache);
        self.caches.lock().unwrap().retain(|sc| !sc.ptr_eq(&cache));
    }

    /// Clears every registered statement cache that is still alive.
    pub fn clear(&self) {
        let caches = self.caches.lock().unwrap();
        for cache in caches.iter() {
            if let Some(cache) = cache.upgrade() {
                cache.clear();
            }
        }
    }

    /// Removes the statement cached for `query` / `types` from every
    /// registered statement cache that is still alive.
    pub fn remove(&self, query: &str, types: &[Type]) {
        let caches = self.caches.lock().unwrap();
        for cache in caches.iter() {
            if let Some(cache) = cache.upgrade() {
                drop(cache.remove(query, types));
            }
        }
    }
}
impl fmt::Debug for StatementCache {
    /// Formats as `StatementCache { size: N }`. Only the number of cached
    /// statements is shown; the statement map itself is omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The debug name must match the type being formatted.
        f.debug_struct("StatementCache")
            .field("size", &self.size)
            .finish()
    }
}
// Map key identifying a prepared statement by its SQL text and parameter types.
// `Cow` lets lookups (`StatementCache::get`) use borrowed data, while keys
// stored in the map own their data (`StatementCacheKey<'static>`).
#[derive(Debug, Eq, Hash, PartialEq)]
struct StatementCacheKey<'a> {
    query: Cow<'a, str>,
    types: Cow<'a, [Type]>,
}
/// Cache of prepared statements keyed by query text and parameter types.
///
/// Each `Conn` owns one, created empty in `Conn::new`.
pub struct StatementCache {
    // Cached statements, guarded for concurrent readers and writers.
    map: RwLock<HashMap<StatementCacheKey<'static>, Statement>>,
    // Number of entries in `map`. It is updated alongside the map and read by
    // `size()` without taking the lock.
    size: AtomicUsize,
}
impl StatementCache {
    /// Creates an empty cache.
    fn new() -> Self {
        Self {
            map: RwLock::new(HashMap::new()),
            size: AtomicUsize::new(0),
        }
    }

    /// Returns the number of cached statements.
    pub fn size(&self) -> usize {
        self.size.load(Ordering::Relaxed)
    }

    /// Drops every cached statement and resets the size to zero.
    pub fn clear(&self) {
        let mut entries = self.map.write().unwrap();
        entries.clear();
        // Still under the write lock, so the counter stays in step with the map.
        self.size.store(0, Ordering::Relaxed);
    }

    /// Removes and returns the statement cached for `query` / `types`, if any.
    pub fn remove(&self, query: &str, types: &[Type]) -> Option<Statement> {
        // Lookup through `&mut HashMap` cannot shorten the key lifetime the way
        // the shared-reference lookup in `get` can, so an owned key is built here.
        let key = StatementCacheKey {
            query: Cow::Owned(String::from(query)),
            types: Cow::Owned(types.to_vec()),
        };
        let mut entries = self.map.write().unwrap();
        entries.remove(&key).map(|stmt| {
            let _ = self.size.fetch_sub(1, Ordering::Relaxed);
            stmt
        })
    }

    /// Looks up the statement cached for `query` / `types` using a borrowed key.
    fn get(&self, query: &str, types: &[Type]) -> Option<Statement> {
        let key = StatementCacheKey {
            query: Cow::Borrowed(query),
            types: Cow::Borrowed(types),
        };
        let entries = self.map.read().unwrap();
        entries.get(&key).cloned()
    }

    /// Caches `stmt` under `query` / `types`. The size only grows when the key
    /// was not already present.
    fn insert(&self, query: &str, types: &[Type], stmt: Statement) {
        let key = StatementCacheKey {
            query: Cow::Owned(String::from(query)),
            types: Cow::Owned(types.to_vec()),
        };
        let mut entries = self.map.write().unwrap();
        let previous = entries.insert(key, stmt);
        if previous.is_none() {
            let _ = self.size.fetch_add(1, Ordering::Relaxed);
        }
    }

    /// Returns the cached statement for `query` with no explicit parameter
    /// types, preparing and caching it on a miss.
    ///
    /// # Errors
    ///
    /// Returns the client's error if preparing fails.
    pub async fn prepare(&self, client: &PgClient, query: &str) -> Result<Statement, Error> {
        self.prepare_typed(client, query, &[]).await
    }

    /// Returns the cached statement for `query` / `types`, preparing it on
    /// `client` and caching it on a miss.
    ///
    /// The map lock is not held across the prepare `.await`. Two concurrent
    /// misses may therefore both prepare; `insert` then just replaces the entry
    /// without double-counting.
    ///
    /// # Errors
    ///
    /// Returns the client's error if preparing fails; nothing is cached then.
    pub async fn prepare_typed(
        &self,
        client: &PgClient,
        query: &str,
        types: &[Type],
    ) -> Result<Statement, Error> {
        if let Some(cached) = self.get(query, types) {
            return Ok(cached);
        }
        let prepared = client.prepare_typed(query, types).await?;
        self.insert(query, types, prepared.clone());
        Ok(prepared)
    }
}
/// A PostgreSQL client together with the task driving its connection and its
/// own prepared-statement cache.
///
/// `Conn` dereferences to the underlying client.
#[derive(Debug)]
pub struct Conn {
    // The wrapped client; exposed through `Deref`/`DerefMut`.
    client: PgClient,
    // Background connection task; `Drop for Conn` aborts it.
    conn_task: JoinHandle<()>,
    /// Prepared statements cached for this connection.
    pub statement_cache: Arc<StatementCache>,
}
/// A transaction on a `Conn`. When nested, `commit`/`rollback` act on a
/// savepoint (`sp_<depth>`) instead.
///
/// Dropping it hands end-of-transaction handling to
/// `simple_pg_client::Transaction`'s drop logic.
pub struct Transaction<'a> {
    client: &'a mut Conn,
    // Transaction depth to return to when this transaction ends. A value of 0
    // means this is the outermost transaction (`commit`/`rollback` then issue
    // plain COMMIT/ROLLBACK).
    returning_transaction_depth: u16,
    // Set once `commit` or `rollback` has started.
    done: bool,
}
impl<'a> fmt::Debug for Transaction<'a> {
    /// Formats the connection, the depth this transaction returns to, and
    /// whether it has finished.
    ///
    /// `returning_transaction_depth` is included because it decides whether
    /// `commit`/`rollback` act on the whole transaction or on a savepoint.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("Transaction")
            .field("client", &self.client)
            .field("returning_transaction_depth", &self.returning_transaction_depth)
            .field("done", &self.done)
            .finish()
    }
}
impl<'a> Deref for Transaction<'a> {
    type Target = Conn;
    /// Gives shared access to the connection the transaction runs on.
    fn deref(&self) -> &Self::Target {
        &*self.client
    }
}
impl<'a> DerefMut for Transaction<'a> {
    /// Gives mutable access to the connection the transaction runs on.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut *self.client
    }
}
impl<'a> Drop for Transaction<'a> {
    // Delegates end-of-transaction handling to `simple_pg_client::Transaction`.
    // A client-level transaction is rebuilt from this wrapper's state
    // (`done` flag and returning depth) and dropped immediately, so its own
    // `Drop` runs here. The client-level value created in `new_with_config` /
    // `from` was `mem::forget`-ed, which means this is the only place that drop
    // logic runs.
    // NOTE(review): what happens when `done` is false (presumably a rollback) is
    // defined in simple_pg_client — confirm there.
    fn drop(&mut self) {
        drop(simple_pg_client::Transaction {
            client: self.client,
            returning_transaction_depth: self.returning_transaction_depth,
            done: self.done,
        })
    }
}
/// Options for `Conn::transaction_configured`. Each field is passed to the
/// client's transaction builder.
#[derive(Debug, Clone, Copy)]
pub struct TransactionConfig {
    /// Isolation level of the transaction.
    pub isolation_level: IsolationLevel,
    /// Whether the transaction is started read-only.
    pub read_only: bool,
    /// Whether the transaction is started as deferrable.
    pub deferrable: bool,
}
impl Default for TransactionConfig {
    /// Read-committed, read-write, not deferrable.
    fn default() -> Self {
        let isolation_level = IsolationLevel::ReadCommitted;
        Self {
            isolation_level,
            read_only: false,
            deferrable: false,
        }
    }
}
impl<'a> Transaction<'a> {
    /// Commits the transaction, or releases its savepoint when it is nested.
    ///
    /// A nested transaction (`returning_transaction_depth > 0`) issues
    /// `RELEASE sp_<depth>`; the outermost one issues `COMMIT`. On success the
    /// client's `transaction_depth` is reset to `returning_transaction_depth`,
    /// exactly as `rollback` does. This keeps the connection's recorded nesting
    /// level correct for the next `transaction()` call.
    ///
    /// # Errors
    ///
    /// Returns the client's error if the statement fails. `done` is set before
    /// the statement is sent, so the drop handler sees `done = true` either way.
    pub async fn commit(mut self) -> Result<(), Error> {
        self.done = true;
        let query = if self.returning_transaction_depth > 0 {
            format!("RELEASE sp_{}", self.returning_transaction_depth)
        } else {
            "COMMIT".to_string()
        };
        self.client.batch_execute(&query).await?;
        // Mirror `rollback`: return to the depth this transaction started from.
        self.client.transaction_depth = self.returning_transaction_depth;
        Ok(())
    }

    /// Rolls the transaction back, or rolls back to its savepoint when nested.
    ///
    /// On success the client's `transaction_depth` is reset to
    /// `returning_transaction_depth`.
    ///
    /// # Errors
    ///
    /// Returns the client's error if the statement fails. `done` is set before
    /// the statement is sent.
    pub async fn rollback(mut self) -> Result<(), Error> {
        self.done = true;
        let query = if self.returning_transaction_depth > 0 {
            format!("ROLLBACK TO sp_{}", self.returning_transaction_depth)
        } else {
            "ROLLBACK".to_string()
        };
        self.client.batch_execute(&query).await?;
        self.client.transaction_depth = self.returning_transaction_depth;
        Ok(())
    }

    /// Starts a transaction on `client` with the given isolation level,
    /// read-only and deferrable settings.
    ///
    /// Debug builds assert that no transaction is open on `client`. The
    /// client-level transaction returned by the builder is `mem::forget`-ed:
    /// this wrapper's `Drop` impl takes over its end-of-life handling.
    ///
    /// # Errors
    ///
    /// Returns the client's error if starting the transaction fails.
    async fn new_with_config(
        client: &mut Conn,
        TransactionConfig {
            isolation_level,
            read_only,
            deferrable,
        }: TransactionConfig,
    ) -> Result<Transaction<'_>, Error> {
        debug_assert!(client.transaction_depth == 0);
        let tx = (**client)
            .build_transaction()
            .isolation_level(isolation_level)
            .deferrable(deferrable)
            .read_only(read_only)
            .start()
            .await?;
        std::mem::forget(tx);
        Ok(Transaction {
            returning_transaction_depth: client.transaction_depth.saturating_sub(1),
            client,
            done: false,
        })
    }

    /// Starts a transaction on `client` via the client's `transaction()`.
    /// Presumably this opens a savepoint when a transaction is already open;
    /// that choice is made by simple_pg_client.
    ///
    /// As in `new_with_config`, the client-level transaction is `mem::forget`-ed
    /// and this wrapper's `Drop` impl takes over its end-of-life handling.
    ///
    /// # Errors
    ///
    /// Returns the client's error if starting the transaction fails.
    async fn from(client: &mut Conn) -> Result<Transaction<'_>, Error> {
        let tx = (**client).transaction().await?;
        std::mem::forget(tx);
        debug_assert!(client.transaction_depth > 0);
        Ok(Transaction {
            returning_transaction_depth: client.transaction_depth.saturating_sub(1),
            client,
            done: false,
        })
    }
}
impl Conn {
    /// Wraps `client` and the `JoinHandle` of the task driving its connection,
    /// giving the connection a new, empty statement cache.
    #[must_use]
    pub fn new(client: PgClient, conn_task: JoinHandle<()>) -> Self {
        Self {
            client,
            conn_task,
            statement_cache: Arc::new(StatementCache::new()),
        }
    }

    /// Returns a prepared statement for `query` with no explicit parameter
    /// types. The statement comes from this connection's cache when present;
    /// otherwise it is prepared and then cached.
    ///
    /// # Errors
    ///
    /// Returns the client's error if the statement must be prepared and
    /// preparing fails.
    pub async fn prepare_cached(&self, query: &str) -> Result<Statement, Error> {
        // Delegate to the cache, as `prepare_typed_cached` does, so both share
        // one lookup/prepare/insert implementation.
        self.statement_cache.prepare(&self.client, query).await
    }

    /// Returns a prepared statement for `query` with the given parameter
    /// `types`. The statement comes from this connection's cache when present;
    /// otherwise it is prepared and then cached.
    ///
    /// # Errors
    ///
    /// Returns the client's error if the statement must be prepared and
    /// preparing fails.
    pub async fn prepare_typed_cached(
        &self,
        query: &str,
        types: &[Type],
    ) -> Result<Statement, Error> {
        self.statement_cache
            .prepare_typed(&self.client, query, types)
            .await
    }

    /// Starts a transaction on this connection using the client's
    /// `transaction()`.
    ///
    /// # Errors
    ///
    /// Returns the client's error if starting the transaction fails.
    pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
        Transaction::from(self).await
    }

    /// Starts a transaction with the settings in `config`. Debug builds assert
    /// that no transaction is already open.
    ///
    /// # Errors
    ///
    /// Returns the client's error if starting the transaction fails.
    pub async fn transaction_configured(
        &mut self,
        config: TransactionConfig,
    ) -> Result<Transaction<'_>, Error> {
        Transaction::new_with_config(self, config).await
    }
}
impl Deref for Conn {
    type Target = PgClient;
    /// Exposes the wrapped client's API directly on `Conn`.
    fn deref(&self) -> &PgClient {
        let Self { client, .. } = self;
        client
    }
}
impl DerefMut for Conn {
    /// Exposes the wrapped client mutably.
    fn deref_mut(&mut self) -> &mut PgClient {
        let Self { client, .. } = self;
        client
    }
}
impl Drop for Conn {
fn drop(&mut self) {
self.conn_task.abort()
}
}