deadpool_postgres/
lib.rs

1#![doc = include_str!("../README.md")]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3#![deny(
4    nonstandard_style,
5    rust_2018_idioms,
6    rustdoc::broken_intra_doc_links,
7    rustdoc::private_intra_doc_links
8)]
9#![forbid(non_ascii_idents, unsafe_code)]
10#![warn(
11    deprecated_in_future,
12    missing_copy_implementations,
13    missing_debug_implementations,
14    missing_docs,
15    unreachable_pub,
16    unused_import_braces,
17    unused_labels,
18    unused_lifetimes,
19    unused_qualifications,
20    unused_results
21)]
22
23mod config;
24mod generic_client;
25
26use std::{
27    borrow::Cow,
28    collections::HashMap,
29    fmt,
30    future::Future,
31    ops::{Deref, DerefMut},
32    pin::Pin,
33    sync::{
34        atomic::{AtomicUsize, Ordering},
35        Arc, Mutex, RwLock, Weak,
36    },
37};
38
39use deadpool::managed;
40#[cfg(not(target_arch = "wasm32"))]
41use tokio::spawn;
42use tokio::task::JoinHandle;
43use tokio_postgres::{
44    types::Type, Client as PgClient, Config as PgConfig, Error, IsolationLevel, Statement,
45    Transaction as PgTransaction, TransactionBuilder as PgTransactionBuilder,
46};
47
48#[cfg(not(target_arch = "wasm32"))]
49use tokio_postgres::{
50    tls::{MakeTlsConnect, TlsConnect},
51    Socket,
52};
53
54pub use tokio_postgres;
55
56pub use self::config::{
57    ChannelBinding, Config, ConfigError, LoadBalanceHosts, ManagerConfig, RecyclingMethod, SslMode,
58    TargetSessionAttrs,
59};
60
61pub use self::generic_client::GenericClient;
62
63pub use deadpool::managed::reexports::*;
64deadpool::managed_reexports!(
65    "tokio_postgres",
66    Manager,
67    managed::Object<Manager>,
68    Error,
69    ConfigError
70);
71
72type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
73
74/// Type alias for [`Object`]
75pub type Client = Object;
76
77type RecycleResult = managed::RecycleResult<Error>;
78type RecycleError = managed::RecycleError<Error>;
79
/// [`Manager`] for creating and recycling PostgreSQL connections.
///
/// Every connection it creates is wrapped in a [`ClientWrapper`] whose
/// [`StatementCache`] is registered in `statement_caches`, so pool-wide
/// cache operations reach all live connections.
///
/// [`Manager`]: managed::Manager
pub struct Manager {
    // Selects how connections are recycled (its `recycling_method` is read
    // in `recycle`).
    config: ManagerConfig,
    // Connection parameters handed to `connect` for every new connection.
    pg_config: PgConfig,
    // Strategy that actually opens connections; boxed so any `Connect`
    // implementation can be supplied via `Manager::from_connect`.
    connect: Box<dyn Connect>,
    /// [`StatementCaches`] of [`Client`]s handed out by the [`Pool`].
    pub statement_caches: StatementCaches,
}
90
91impl Manager {
92    #[cfg(not(target_arch = "wasm32"))]
93    /// Creates a new [`Manager`] using the given [`tokio_postgres::Config`] and
94    /// `tls` connector.
95    pub fn new<T>(pg_config: tokio_postgres::Config, tls: T) -> Self
96    where
97        T: MakeTlsConnect<Socket> + Clone + Sync + Send + 'static,
98        T::Stream: Sync + Send,
99        T::TlsConnect: Sync + Send,
100        <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
101    {
102        Self::from_config(pg_config, tls, ManagerConfig::default())
103    }
104
105    #[cfg(not(target_arch = "wasm32"))]
106    /// Create a new [`Manager`] using the given [`tokio_postgres::Config`], and
107    /// `tls` connector and [`ManagerConfig`].
108    pub fn from_config<T>(pg_config: tokio_postgres::Config, tls: T, config: ManagerConfig) -> Self
109    where
110        T: MakeTlsConnect<Socket> + Clone + Sync + Send + 'static,
111        T::Stream: Sync + Send,
112        T::TlsConnect: Sync + Send,
113        <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
114    {
115        Self::from_connect(pg_config, ConfigConnectImpl { tls }, config)
116    }
117
118    /// Create a new [`Manager`] using the given [`tokio_postgres::Config`], and
119    /// `connect` impl and [`ManagerConfig`].
120    pub fn from_connect(
121        pg_config: tokio_postgres::Config,
122        connect: impl Connect + 'static,
123        config: ManagerConfig,
124    ) -> Self {
125        Self {
126            config,
127            pg_config,
128            connect: Box::new(connect),
129            statement_caches: StatementCaches::default(),
130        }
131    }
132}
133
134impl fmt::Debug for Manager {
135    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
136        f.debug_struct("Manager")
137            .field("config", &self.config)
138            .field("pg_config", &self.pg_config)
139            //.field("connect", &self.connect)
140            .field("statement_caches", &self.statement_caches)
141            .finish()
142    }
143}
144
145impl managed::Manager for Manager {
146    type Type = ClientWrapper;
147    type Error = Error;
148
149    async fn create(&self) -> Result<ClientWrapper, Error> {
150        let (client, conn_task) = self.connect.connect(&self.pg_config).await?;
151        let client_wrapper = ClientWrapper::new(client, conn_task);
152        self.statement_caches
153            .attach(&client_wrapper.statement_cache);
154        Ok(client_wrapper)
155    }
156
157    async fn recycle(&self, client: &mut ClientWrapper, _: &Metrics) -> RecycleResult {
158        if client.is_closed() {
159            tracing::warn!(target: "deadpool.postgres", "Connection could not be recycled: Connection closed");
160            return Err(RecycleError::message("Connection closed"));
161        }
162        match self.config.recycling_method.query() {
163            Some(sql) => match client.simple_query(sql).await {
164                Ok(_) => Ok(()),
165                Err(e) => {
166                    tracing::warn!(target: "deadpool.postgres", "Connection could not be recycled: {}", e);
167                    Err(e.into())
168                }
169            },
170            None => Ok(()),
171        }
172    }
173
174    fn detach(&self, object: &mut ClientWrapper) {
175        self.statement_caches.detach(&object.statement_cache);
176    }
177}
178
/// Describes a mechanism for establishing a connection to a PostgreSQL
/// server via `tokio_postgres`.
///
/// An implementation can be passed to [`Manager::from_connect()`]; the
/// manager calls it from its `create` step to open each new connection.
pub trait Connect: Sync + Send {
    /// Establishes a new `tokio_postgres` connection, returning
    /// the associated `Client` and a `JoinHandle` to a tokio task
    /// for processing the connection.
    ///
    /// The returned handle is stored in the resulting [`ClientWrapper`],
    /// which aborts the task when the wrapper is dropped.
    fn connect(
        &self,
        pg_config: &PgConfig,
    ) -> BoxFuture<'_, Result<(PgClient, JoinHandle<()>), Error>>;
}
190
#[cfg(not(target_arch = "wasm32"))]
/// Provides an implementation of [`Connect`] that establishes the connection
/// using the `tokio_postgres` configuration itself.
///
/// The struct places no bounds on `T`. The TLS requirements
/// (`MakeTlsConnect<Socket>` and friends) are enforced only by its [`Connect`]
/// implementation, which is the only place they are needed.
#[derive(Debug)]
pub struct ConfigConnectImpl<T> {
    /// The TLS connector to use for the connection.
    pub tls: T,
}
205
206#[cfg(not(target_arch = "wasm32"))]
207impl<T> Connect for ConfigConnectImpl<T>
208where
209    T: MakeTlsConnect<Socket> + Clone + Sync + Send + 'static,
210    T::Stream: Sync + Send,
211    T::TlsConnect: Sync + Send,
212    <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
213{
214    fn connect(
215        &self,
216        pg_config: &PgConfig,
217    ) -> BoxFuture<'_, Result<(PgClient, JoinHandle<()>), Error>> {
218        let tls = self.tls.clone();
219        let pg_config = pg_config.clone();
220        Box::pin(async move {
221            let fut = pg_config.connect(tls);
222            let (client, connection) = fut.await?;
223            let conn_task = spawn(async move {
224                if let Err(e) = connection.await {
225                    tracing::warn!(target: "deadpool.postgres", "Connection error: {}", e);
226                }
227            });
228            Ok((client, conn_task))
229        })
230    }
231}
232
233/// Structure holding a reference to all [`StatementCache`]s and providing
234/// access for clearing all caches and removing single statements from them.
235#[derive(Default, Debug)]
236pub struct StatementCaches {
237    caches: Mutex<Vec<Weak<StatementCache>>>,
238}
239
240impl StatementCaches {
241    fn attach(&self, cache: &Arc<StatementCache>) {
242        let cache = Arc::downgrade(cache);
243        self.caches.lock().unwrap().push(cache);
244    }
245
246    fn detach(&self, cache: &Arc<StatementCache>) {
247        let cache = Arc::downgrade(cache);
248        self.caches.lock().unwrap().retain(|sc| !sc.ptr_eq(&cache));
249    }
250
251    /// Clears [`StatementCache`] of all connections which were handed out by a
252    /// [`Manager`].
253    pub fn clear(&self) {
254        let caches = self.caches.lock().unwrap();
255        for cache in caches.iter() {
256            if let Some(cache) = cache.upgrade() {
257                cache.clear();
258            }
259        }
260    }
261
262    /// Removes statement from all caches which were handed out by a
263    /// [`Manager`].
264    pub fn remove(&self, query: &str, types: &[Type]) {
265        let caches = self.caches.lock().unwrap();
266        for cache in caches.iter() {
267            if let Some(cache) = cache.upgrade() {
268                drop(cache.remove(query, types));
269            }
270        }
271    }
272}
273
274impl fmt::Debug for StatementCache {
275    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
276        f.debug_struct("ClientWrapper")
277            //.field("map", &self.map)
278            .field("size", &self.size)
279            .finish()
280    }
281}
282
// Allows us to use owned keys in a `HashMap`, but still be able to call `get`
// with borrowed keys instead of allocating them each time.
//
// Keys stored in `StatementCache::map` are `StatementCacheKey<'static>`
// built with `Cow::Owned`; `StatementCache::get` builds a temporary key with
// `Cow::Borrowed` instead. The derived `Hash`/`Eq` look through the `Cow`s at
// the contained data, so an owned and a borrowed key with the same query and
// parameter types compare equal.
#[derive(Debug, Eq, Hash, PartialEq)]
struct StatementCacheKey<'a> {
    // SQL text of the prepared statement.
    query: Cow<'a, str>,
    // Explicit parameter types passed to `prepare_typed` (empty for `prepare`).
    types: Cow<'a, [Type]>,
}
290
/// Representation of a cache of [`Statement`]s.
///
/// [`StatementCache`] is bound to one [`Client`], and [`Statement`]s generated
/// by that [`Client`] must not be used with other [`Client`]s.
///
/// It can be used like that:
/// ```rust,ignore
/// let client = pool.get().await?;
/// let stmt = client
///     .statement_cache
///     .prepare(&client, "SELECT 1")
///     .await?;
/// let rows = client.query(&stmt, &[]).await?;
/// ...
/// ```
///
/// Normally, you probably want to use the [`ClientWrapper::prepare_cached()`]
/// and [`ClientWrapper::prepare_typed_cached()`] methods instead (or the
/// similar ones on [`Transaction`]).
pub struct StatementCache {
    // Cached statements keyed by query text plus explicit parameter types.
    // Lookups take only the read lock; inserts, removals and clears take the
    // write lock.
    map: RwLock<HashMap<StatementCacheKey<'static>, Statement>>,
    // Number of entries in `map`. It is updated while the write lock on `map`
    // is held, and `size()` reads it without locking `map`.
    size: AtomicUsize,
}
314
315impl StatementCache {
316    fn new() -> Self {
317        Self {
318            map: RwLock::new(HashMap::new()),
319            size: AtomicUsize::new(0),
320        }
321    }
322
323    /// Returns current size of this [`StatementCache`].
324    pub fn size(&self) -> usize {
325        self.size.load(Ordering::Relaxed)
326    }
327
328    /// Clears this [`StatementCache`].
329    ///
330    /// **Important:** This only clears the [`StatementCache`] of one [`Client`]
331    /// instance. If you want to clear the [`StatementCache`] of all [`Client`]s
332    /// you should be calling `pool.manager().statement_caches.clear()` instead.
333    pub fn clear(&self) {
334        let mut map = self.map.write().unwrap();
335        map.clear();
336        self.size.store(0, Ordering::Relaxed);
337    }
338
339    /// Removes a [`Statement`] from this [`StatementCache`].
340    ///
341    /// **Important:** This only removes a [`Statement`] from one [`Client`]
342    /// cache. If you want to remove a [`Statement`] from all
343    /// [`StatementCaches`] you should be calling
344    /// `pool.manager().statement_caches.remove()` instead.
345    pub fn remove(&self, query: &str, types: &[Type]) -> Option<Statement> {
346        let key = StatementCacheKey {
347            query: Cow::Owned(query.to_owned()),
348            types: Cow::Owned(types.to_owned()),
349        };
350        let mut map = self.map.write().unwrap();
351        let removed = map.remove(&key);
352        if removed.is_some() {
353            let _ = self.size.fetch_sub(1, Ordering::Relaxed);
354        }
355        removed
356    }
357
358    /// Returns a [`Statement`] from this [`StatementCache`].
359    fn get(&self, query: &str, types: &[Type]) -> Option<Statement> {
360        let key = StatementCacheKey {
361            query: Cow::Borrowed(query),
362            types: Cow::Borrowed(types),
363        };
364        self.map.read().unwrap().get(&key).map(ToOwned::to_owned)
365    }
366
367    /// Inserts a [`Statement`] into this [`StatementCache`].
368    fn insert(&self, query: &str, types: &[Type], stmt: Statement) {
369        let key = StatementCacheKey {
370            query: Cow::Owned(query.to_owned()),
371            types: Cow::Owned(types.to_owned()),
372        };
373        let mut map = self.map.write().unwrap();
374        if map.insert(key, stmt).is_none() {
375            let _ = self.size.fetch_add(1, Ordering::Relaxed);
376        }
377    }
378
379    /// Creates a new prepared [`Statement`] using this [`StatementCache`], if
380    /// possible.
381    ///
382    /// See [`tokio_postgres::Client::prepare()`].
383    pub async fn prepare(&self, client: &PgClient, query: &str) -> Result<Statement, Error> {
384        self.prepare_typed(client, query, &[]).await
385    }
386
387    /// Creates a new prepared [`Statement`] with specifying its [`Type`]s
388    /// explicitly using this [`StatementCache`], if possible.
389    ///
390    /// See [`tokio_postgres::Client::prepare_typed()`].
391    pub async fn prepare_typed(
392        &self,
393        client: &PgClient,
394        query: &str,
395        types: &[Type],
396    ) -> Result<Statement, Error> {
397        match self.get(query, types) {
398            Some(statement) => Ok(statement),
399            None => {
400                let stmt = client.prepare_typed(query, types).await?;
401                self.insert(query, types, stmt.clone());
402                Ok(stmt)
403            }
404        }
405    }
406}
407
/// Wrapper around [`tokio_postgres::Client`] with a [`StatementCache`].
///
/// Dereferences to the wrapped [`tokio_postgres::Client`], so its methods can
/// be called directly on the wrapper. Dropping the wrapper aborts the task
/// that drives the connection.
#[derive(Debug)]
pub struct ClientWrapper {
    /// Original [`PgClient`].
    client: PgClient,

    /// A handle to the connection task that should be aborted when the client
    /// wrapper is dropped.
    conn_task: JoinHandle<()>,

    /// [`StatementCache`] of this client.
    ///
    /// The same cache (a clone of this [`Arc`]) is shared with every
    /// [`Transaction`] and [`TransactionBuilder`] created from this client.
    pub statement_cache: Arc<StatementCache>,
}
421
422impl ClientWrapper {
423    /// Create a new [`ClientWrapper`] instance using the given
424    /// [`tokio_postgres::Client`] and handle to the connection task.
425    #[must_use]
426    pub fn new(client: PgClient, conn_task: JoinHandle<()>) -> Self {
427        Self {
428            client,
429            conn_task,
430            statement_cache: Arc::new(StatementCache::new()),
431        }
432    }
433
434    /// Like [`tokio_postgres::Client::prepare()`], but uses an existing
435    /// [`Statement`] from the [`StatementCache`] if possible.
436    pub async fn prepare_cached(&self, query: &str) -> Result<Statement, Error> {
437        self.statement_cache.prepare(&self.client, query).await
438    }
439
440    /// Like [`tokio_postgres::Client::prepare_typed()`], but uses an
441    /// existing [`Statement`] from the [`StatementCache`] if possible.
442    pub async fn prepare_typed_cached(
443        &self,
444        query: &str,
445        types: &[Type],
446    ) -> Result<Statement, Error> {
447        self.statement_cache
448            .prepare_typed(&self.client, query, types)
449            .await
450    }
451
452    /// Like [`tokio_postgres::Client::transaction()`], but returns a wrapped
453    /// [`Transaction`] with a [`StatementCache`].
454    #[allow(unused_lifetimes)] // false positive
455    pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
456        Ok(Transaction {
457            txn: PgClient::transaction(&mut self.client).await?,
458            statement_cache: self.statement_cache.clone(),
459        })
460    }
461
462    /// Like [`tokio_postgres::Client::build_transaction()`], but creates a
463    /// wrapped [`Transaction`] with a [`StatementCache`].
464    pub fn build_transaction(&mut self) -> TransactionBuilder<'_> {
465        TransactionBuilder {
466            builder: self.client.build_transaction(),
467            statement_cache: self.statement_cache.clone(),
468        }
469    }
470}
471
472impl Deref for ClientWrapper {
473    type Target = PgClient;
474
475    fn deref(&self) -> &PgClient {
476        &self.client
477    }
478}
479
480impl DerefMut for ClientWrapper {
481    fn deref_mut(&mut self) -> &mut PgClient {
482        &mut self.client
483    }
484}
485
486impl Drop for ClientWrapper {
487    fn drop(&mut self) {
488        self.conn_task.abort()
489    }
490}
491
/// Wrapper around [`tokio_postgres::Transaction`] with a [`StatementCache`]
/// from the [`Client`] object it was created by.
///
/// Dereferences to the wrapped [`tokio_postgres::Transaction`]. Because the
/// cache is shared with the originating client, statements prepared through
/// [`prepare_cached()`](Self::prepare_cached) land in that client's cache.
pub struct Transaction<'a> {
    /// Original [`PgTransaction`].
    txn: PgTransaction<'a>,

    /// [`StatementCache`] of this [`Transaction`].
    pub statement_cache: Arc<StatementCache>,
}
501
502impl<'a> Transaction<'a> {
503    /// Like [`tokio_postgres::Transaction::prepare()`], but uses an existing
504    /// [`Statement`] from the [`StatementCache`] if possible.
505    pub async fn prepare_cached(&self, query: &str) -> Result<Statement, Error> {
506        self.statement_cache.prepare(self.client(), query).await
507    }
508
509    /// Like [`tokio_postgres::Transaction::prepare_typed()`], but uses an
510    /// existing [`Statement`] from the [`StatementCache`] if possible.
511    pub async fn prepare_typed_cached(
512        &self,
513        query: &str,
514        types: &[Type],
515    ) -> Result<Statement, Error> {
516        self.statement_cache
517            .prepare_typed(self.client(), query, types)
518            .await
519    }
520
521    /// Like [`tokio_postgres::Transaction::commit()`].
522    pub async fn commit(self) -> Result<(), Error> {
523        self.txn.commit().await
524    }
525
526    /// Like [`tokio_postgres::Transaction::rollback()`].
527    pub async fn rollback(self) -> Result<(), Error> {
528        self.txn.rollback().await
529    }
530
531    /// Like [`tokio_postgres::Transaction::transaction()`], but returns a
532    /// wrapped [`Transaction`] with a [`StatementCache`].
533    #[allow(unused_lifetimes)] // false positive
534    pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
535        Ok(Transaction {
536            txn: PgTransaction::transaction(&mut self.txn).await?,
537            statement_cache: self.statement_cache.clone(),
538        })
539    }
540
541    /// Like [`tokio_postgres::Transaction::savepoint()`], but returns a wrapped
542    /// [`Transaction`] with a [`StatementCache`].
543    #[allow(unused_lifetimes)] // false positive
544    pub async fn savepoint<I>(&mut self, name: I) -> Result<Transaction<'_>, Error>
545    where
546        I: Into<String>,
547    {
548        Ok(Transaction {
549            txn: PgTransaction::savepoint(&mut self.txn, name).await?,
550            statement_cache: self.statement_cache.clone(),
551        })
552    }
553}
554
555impl<'a> fmt::Debug for Transaction<'a> {
556    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
557        f.debug_struct("Transaction")
558            //.field("txn", &self.txn)
559            .field("statement_cache", &self.statement_cache)
560            .finish()
561    }
562}
563
564impl<'a> Deref for Transaction<'a> {
565    type Target = PgTransaction<'a>;
566
567    fn deref(&self) -> &PgTransaction<'a> {
568        &self.txn
569    }
570}
571
572impl<'a> DerefMut for Transaction<'a> {
573    fn deref_mut(&mut self) -> &mut PgTransaction<'a> {
574        &mut self.txn
575    }
576}
577
/// Wrapper around [`tokio_postgres::TransactionBuilder`] with a
/// [`StatementCache`] from the [`Client`] object it was created by.
///
/// Obtained from [`ClientWrapper::build_transaction()`]; call
/// [`start()`](Self::start) to begin the [`Transaction`].
#[must_use = "builder does nothing itself, use `.start()` to use it"]
pub struct TransactionBuilder<'a> {
    /// Original [`PgTransactionBuilder`].
    builder: PgTransactionBuilder<'a>,

    /// [`StatementCache`] of this [`TransactionBuilder`]; it is moved into
    /// the [`Transaction`] produced by `start()`.
    statement_cache: Arc<StatementCache>,
}
588
589impl<'a> TransactionBuilder<'a> {
590    /// Sets the isolation level of the transaction.
591    ///
592    /// Like [`tokio_postgres::TransactionBuilder::isolation_level()`].
593    pub fn isolation_level(self, isolation_level: IsolationLevel) -> Self {
594        Self {
595            builder: self.builder.isolation_level(isolation_level),
596            statement_cache: self.statement_cache,
597        }
598    }
599
600    /// Sets the access mode of the transaction.
601    ///
602    /// Like [`tokio_postgres::TransactionBuilder::read_only()`].
603    pub fn read_only(self, read_only: bool) -> Self {
604        Self {
605            builder: self.builder.read_only(read_only),
606            statement_cache: self.statement_cache,
607        }
608    }
609
610    /// Sets the deferrability of the transaction.
611    ///
612    /// If the transaction is also serializable and read only, creation
613    /// of the transaction may block, but when it completes the transaction
614    /// is able to run with less overhead and a guarantee that it will not
615    /// be aborted due to serialization failure.
616    ///
617    /// Like [`tokio_postgres::TransactionBuilder::deferrable()`].
618    pub fn deferrable(self, deferrable: bool) -> Self {
619        Self {
620            builder: self.builder.deferrable(deferrable),
621            statement_cache: self.statement_cache,
622        }
623    }
624
625    /// Begins the [`Transaction`].
626    ///
627    /// The transaction will roll back by default - use the commit method
628    /// to commit it.
629    ///
630    /// Like [`tokio_postgres::TransactionBuilder::start()`].
631    pub async fn start(self) -> Result<Transaction<'a>, Error> {
632        Ok(Transaction {
633            txn: self.builder.start().await?,
634            statement_cache: self.statement_cache,
635        })
636    }
637}
638
639impl<'a> fmt::Debug for TransactionBuilder<'a> {
640    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
641        f.debug_struct("TransactionBuilder")
642            //.field("builder", &self.builder)
643            .field("statement_cache", &self.statement_cache)
644            .finish()
645    }
646}
647
648impl<'a> Deref for TransactionBuilder<'a> {
649    type Target = PgTransactionBuilder<'a>;
650
651    fn deref(&self) -> &Self::Target {
652        &self.builder
653    }
654}
655
656impl<'a> DerefMut for TransactionBuilder<'a> {
657    fn deref_mut(&mut self) -> &mut Self::Target {
658        &mut self.builder
659    }
660}