1#![doc = include_str!("../README.md")]
2#![cfg_attr(docsrs, feature(doc_cfg))]
3#![deny(
4 nonstandard_style,
5 rust_2018_idioms,
6 rustdoc::broken_intra_doc_links,
7 rustdoc::private_intra_doc_links
8)]
9#![forbid(non_ascii_idents, unsafe_code)]
10#![warn(
11 deprecated_in_future,
12 missing_copy_implementations,
13 missing_debug_implementations,
14 missing_docs,
15 unreachable_pub,
16 unused_import_braces,
17 unused_labels,
18 unused_lifetimes,
19 unused_qualifications,
20 unused_results
21)]
22
23mod config;
24mod generic_client;
25
26use std::{
27 borrow::Cow,
28 collections::HashMap,
29 fmt,
30 future::Future,
31 ops::{Deref, DerefMut},
32 pin::Pin,
33 sync::{
34 atomic::{AtomicUsize, Ordering},
35 Arc, Mutex, RwLock, Weak,
36 },
37};
38
39use deadpool::managed;
40#[cfg(not(target_arch = "wasm32"))]
41use tokio::spawn;
42use tokio::task::JoinHandle;
43use tokio_postgres::{
44 types::Type, Client as PgClient, Config as PgConfig, Error, IsolationLevel, Statement,
45 Transaction as PgTransaction, TransactionBuilder as PgTransactionBuilder,
46};
47
48#[cfg(not(target_arch = "wasm32"))]
49use tokio_postgres::{
50 tls::{MakeTlsConnect, TlsConnect},
51 Socket,
52};
53
54pub use tokio_postgres;
55
56pub use self::config::{
57 ChannelBinding, Config, ConfigError, LoadBalanceHosts, ManagerConfig, RecyclingMethod, SslMode,
58 TargetSessionAttrs,
59};
60
61pub use self::generic_client::GenericClient;
62
63pub use deadpool::managed::reexports::*;
64deadpool::managed_reexports!(
65 "tokio_postgres",
66 Manager,
67 managed::Object<Manager>,
68 Error,
69 ConfigError
70);
71
/// Boxed, pinned, `Send` future; the return type of [`Connect::connect`].
type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;

/// Pooled PostgreSQL client handed out by the pool; an alias of [`Object`].
pub type Client = Object;

// Shorthands for deadpool's recycle result/error types with `tokio_postgres::Error`
// as the backend error type.
type RecycleResult = managed::RecycleResult<Error>;
type RecycleError = managed::RecycleError<Error>;
79
/// Creates and recycles PostgreSQL connections on behalf of the pool
/// (implements deadpool's `managed::Manager` trait).
pub struct Manager {
    /// Manager settings; `recycle` reads its `recycling_method`.
    config: ManagerConfig,
    /// `tokio_postgres` configuration handed to `connect` for every new connection.
    pg_config: PgConfig,
    /// Strategy used to establish new connections.
    connect: Box<dyn Connect>,
    /// Statement caches of the connections created by this manager. A cache is
    /// attached in `create` and removed again in `detach`.
    pub statement_caches: StatementCaches,
}
90
91impl Manager {
92 #[cfg(not(target_arch = "wasm32"))]
93 pub fn new<T>(pg_config: tokio_postgres::Config, tls: T) -> Self
96 where
97 T: MakeTlsConnect<Socket> + Clone + Sync + Send + 'static,
98 T::Stream: Sync + Send,
99 T::TlsConnect: Sync + Send,
100 <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
101 {
102 Self::from_config(pg_config, tls, ManagerConfig::default())
103 }
104
105 #[cfg(not(target_arch = "wasm32"))]
106 pub fn from_config<T>(pg_config: tokio_postgres::Config, tls: T, config: ManagerConfig) -> Self
109 where
110 T: MakeTlsConnect<Socket> + Clone + Sync + Send + 'static,
111 T::Stream: Sync + Send,
112 T::TlsConnect: Sync + Send,
113 <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
114 {
115 Self::from_connect(pg_config, ConfigConnectImpl { tls }, config)
116 }
117
118 pub fn from_connect(
121 pg_config: tokio_postgres::Config,
122 connect: impl Connect + 'static,
123 config: ManagerConfig,
124 ) -> Self {
125 Self {
126 config,
127 pg_config,
128 connect: Box::new(connect),
129 statement_caches: StatementCaches::default(),
130 }
131 }
132}
133
134impl fmt::Debug for Manager {
135 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
136 f.debug_struct("Manager")
137 .field("config", &self.config)
138 .field("pg_config", &self.pg_config)
139 .field("statement_caches", &self.statement_caches)
141 .finish()
142 }
143}
144
145impl managed::Manager for Manager {
146 type Type = ClientWrapper;
147 type Error = Error;
148
149 async fn create(&self) -> Result<ClientWrapper, Error> {
150 let (client, conn_task) = self.connect.connect(&self.pg_config).await?;
151 let client_wrapper = ClientWrapper::new(client, conn_task);
152 self.statement_caches
153 .attach(&client_wrapper.statement_cache);
154 Ok(client_wrapper)
155 }
156
157 async fn recycle(&self, client: &mut ClientWrapper, _: &Metrics) -> RecycleResult {
158 if client.is_closed() {
159 tracing::warn!(target: "deadpool.postgres", "Connection could not be recycled: Connection closed");
160 return Err(RecycleError::message("Connection closed"));
161 }
162 match self.config.recycling_method.query() {
163 Some(sql) => match client.simple_query(sql).await {
164 Ok(_) => Ok(()),
165 Err(e) => {
166 tracing::warn!(target: "deadpool.postgres", "Connection could not be recycled: {}", e);
167 Err(e.into())
168 }
169 },
170 None => Ok(()),
171 }
172 }
173
174 fn detach(&self, object: &mut ClientWrapper) {
175 self.statement_caches.detach(&object.statement_cache);
176 }
177}
178
/// Strategy for establishing a new PostgreSQL connection; the [`Manager`] calls it
/// whenever the pool needs a fresh connection.
pub trait Connect: Sync + Send {
    /// Establishes a new connection described by `pg_config`.
    ///
    /// On success, this returns the [`PgClient`] together with a [`JoinHandle`] to the
    /// background task associated with the connection. The built-in
    /// `ConfigConnectImpl` spawns the task that drives the `tokio_postgres`
    /// connection. [`ClientWrapper`] aborts this task when it is dropped.
    ///
    /// # Errors
    ///
    /// Returns an [`Error`] if the connection cannot be established.
    fn connect(
        &self,
        pg_config: &PgConfig,
    ) -> BoxFuture<'_, Result<(PgClient, JoinHandle<()>), Error>>;
}
190
#[cfg(not(target_arch = "wasm32"))]
#[derive(Debug)]
/// [`Connect`] implementation that connects via `tokio_postgres::Config::connect`
/// with the given TLS connector. The resulting connection future is driven by a
/// task started with `tokio::spawn`.
pub struct ConfigConnectImpl<T>
where
    T: MakeTlsConnect<Socket> + Clone + Sync + Send + 'static,
    T::Stream: Sync + Send,
    T::TlsConnect: Sync + Send,
    <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
{
    /// TLS connector; a clone of it is used for every new connection.
    pub tls: T,
}
205
206#[cfg(not(target_arch = "wasm32"))]
207impl<T> Connect for ConfigConnectImpl<T>
208where
209 T: MakeTlsConnect<Socket> + Clone + Sync + Send + 'static,
210 T::Stream: Sync + Send,
211 T::TlsConnect: Sync + Send,
212 <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
213{
214 fn connect(
215 &self,
216 pg_config: &PgConfig,
217 ) -> BoxFuture<'_, Result<(PgClient, JoinHandle<()>), Error>> {
218 let tls = self.tls.clone();
219 let pg_config = pg_config.clone();
220 Box::pin(async move {
221 let fut = pg_config.connect(tls);
222 let (client, connection) = fut.await?;
223 let conn_task = spawn(async move {
224 if let Err(e) = connection.await {
225 tracing::warn!(target: "deadpool.postgres", "Connection error: {}", e);
226 }
227 });
228 Ok((client, conn_task))
229 })
230 }
231}
232
233#[derive(Default, Debug)]
236pub struct StatementCaches {
237 caches: Mutex<Vec<Weak<StatementCache>>>,
238}
239
240impl StatementCaches {
241 fn attach(&self, cache: &Arc<StatementCache>) {
242 let cache = Arc::downgrade(cache);
243 self.caches.lock().unwrap().push(cache);
244 }
245
246 fn detach(&self, cache: &Arc<StatementCache>) {
247 let cache = Arc::downgrade(cache);
248 self.caches.lock().unwrap().retain(|sc| !sc.ptr_eq(&cache));
249 }
250
251 pub fn clear(&self) {
254 let caches = self.caches.lock().unwrap();
255 for cache in caches.iter() {
256 if let Some(cache) = cache.upgrade() {
257 cache.clear();
258 }
259 }
260 }
261
262 pub fn remove(&self, query: &str, types: &[Type]) {
265 let caches = self.caches.lock().unwrap();
266 for cache in caches.iter() {
267 if let Some(cache) = cache.upgrade() {
268 drop(cache.remove(query, types));
269 }
270 }
271 }
272}
273
274impl fmt::Debug for StatementCache {
275 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
276 f.debug_struct("ClientWrapper")
277 .field("size", &self.size)
279 .finish()
280 }
281}
282
/// Key of a statement-cache entry: the SQL text plus the explicit parameter
/// types the statement was prepared with.
///
/// `Cow` lets lookups (`StatementCache::get`) borrow the caller's `&str`/`&[Type]`.
/// Stored keys (`StatementCacheKey<'static>`) own their data.
#[derive(Debug, Eq, Hash, PartialEq)]
struct StatementCacheKey<'a> {
    query: Cow<'a, str>,
    types: Cow<'a, [Type]>,
}
290
291pub struct StatementCache {
311 map: RwLock<HashMap<StatementCacheKey<'static>, Statement>>,
312 size: AtomicUsize,
313}
314
315impl StatementCache {
316 fn new() -> Self {
317 Self {
318 map: RwLock::new(HashMap::new()),
319 size: AtomicUsize::new(0),
320 }
321 }
322
323 pub fn size(&self) -> usize {
325 self.size.load(Ordering::Relaxed)
326 }
327
328 pub fn clear(&self) {
334 let mut map = self.map.write().unwrap();
335 map.clear();
336 self.size.store(0, Ordering::Relaxed);
337 }
338
339 pub fn remove(&self, query: &str, types: &[Type]) -> Option<Statement> {
346 let key = StatementCacheKey {
347 query: Cow::Owned(query.to_owned()),
348 types: Cow::Owned(types.to_owned()),
349 };
350 let mut map = self.map.write().unwrap();
351 let removed = map.remove(&key);
352 if removed.is_some() {
353 let _ = self.size.fetch_sub(1, Ordering::Relaxed);
354 }
355 removed
356 }
357
358 fn get(&self, query: &str, types: &[Type]) -> Option<Statement> {
360 let key = StatementCacheKey {
361 query: Cow::Borrowed(query),
362 types: Cow::Borrowed(types),
363 };
364 self.map.read().unwrap().get(&key).map(ToOwned::to_owned)
365 }
366
367 fn insert(&self, query: &str, types: &[Type], stmt: Statement) {
369 let key = StatementCacheKey {
370 query: Cow::Owned(query.to_owned()),
371 types: Cow::Owned(types.to_owned()),
372 };
373 let mut map = self.map.write().unwrap();
374 if map.insert(key, stmt).is_none() {
375 let _ = self.size.fetch_add(1, Ordering::Relaxed);
376 }
377 }
378
379 pub async fn prepare(&self, client: &PgClient, query: &str) -> Result<Statement, Error> {
384 self.prepare_typed(client, query, &[]).await
385 }
386
387 pub async fn prepare_typed(
392 &self,
393 client: &PgClient,
394 query: &str,
395 types: &[Type],
396 ) -> Result<Statement, Error> {
397 match self.get(query, types) {
398 Some(statement) => Ok(statement),
399 None => {
400 let stmt = client.prepare_typed(query, types).await?;
401 self.insert(query, types, stmt.clone());
402 Ok(stmt)
403 }
404 }
405 }
406}
407
/// Wrapper around a [`tokio_postgres::Client`] that owns the connection's
/// background task handle and a per-connection [`StatementCache`].
#[derive(Debug)]
pub struct ClientWrapper {
    /// Original [`PgClient`].
    client: PgClient,

    /// Handle of the connection task returned by `Connect::connect`; it is aborted
    /// when this wrapper is dropped.
    conn_task: JoinHandle<()>,

    /// [`StatementCache`] of this connection, used by the `*_cached` prepare methods
    /// and shared with transactions started from this client.
    pub statement_cache: Arc<StatementCache>,
}
421
422impl ClientWrapper {
423 #[must_use]
426 pub fn new(client: PgClient, conn_task: JoinHandle<()>) -> Self {
427 Self {
428 client,
429 conn_task,
430 statement_cache: Arc::new(StatementCache::new()),
431 }
432 }
433
434 pub async fn prepare_cached(&self, query: &str) -> Result<Statement, Error> {
437 self.statement_cache.prepare(&self.client, query).await
438 }
439
440 pub async fn prepare_typed_cached(
443 &self,
444 query: &str,
445 types: &[Type],
446 ) -> Result<Statement, Error> {
447 self.statement_cache
448 .prepare_typed(&self.client, query, types)
449 .await
450 }
451
452 #[allow(unused_lifetimes)] pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
456 Ok(Transaction {
457 txn: PgClient::transaction(&mut self.client).await?,
458 statement_cache: self.statement_cache.clone(),
459 })
460 }
461
462 pub fn build_transaction(&mut self) -> TransactionBuilder<'_> {
465 TransactionBuilder {
466 builder: self.client.build_transaction(),
467 statement_cache: self.statement_cache.clone(),
468 }
469 }
470}
471
472impl Deref for ClientWrapper {
473 type Target = PgClient;
474
475 fn deref(&self) -> &PgClient {
476 &self.client
477 }
478}
479
480impl DerefMut for ClientWrapper {
481 fn deref_mut(&mut self) -> &mut PgClient {
482 &mut self.client
483 }
484}
485
486impl Drop for ClientWrapper {
487 fn drop(&mut self) {
488 self.conn_task.abort()
489 }
490}
491
/// Wrapper around a [`tokio_postgres::Transaction`] that carries the
/// connection's [`StatementCache`], so cached preparation also works inside the
/// transaction.
pub struct Transaction<'a> {
    /// Original [`PgTransaction`].
    txn: PgTransaction<'a>,

    /// [`StatementCache`] shared with the [`ClientWrapper`] (or enclosing
    /// transaction) this transaction was started from.
    pub statement_cache: Arc<StatementCache>,
}
501
502impl<'a> Transaction<'a> {
503 pub async fn prepare_cached(&self, query: &str) -> Result<Statement, Error> {
506 self.statement_cache.prepare(self.client(), query).await
507 }
508
509 pub async fn prepare_typed_cached(
512 &self,
513 query: &str,
514 types: &[Type],
515 ) -> Result<Statement, Error> {
516 self.statement_cache
517 .prepare_typed(self.client(), query, types)
518 .await
519 }
520
521 pub async fn commit(self) -> Result<(), Error> {
523 self.txn.commit().await
524 }
525
526 pub async fn rollback(self) -> Result<(), Error> {
528 self.txn.rollback().await
529 }
530
531 #[allow(unused_lifetimes)] pub async fn transaction(&mut self) -> Result<Transaction<'_>, Error> {
535 Ok(Transaction {
536 txn: PgTransaction::transaction(&mut self.txn).await?,
537 statement_cache: self.statement_cache.clone(),
538 })
539 }
540
541 #[allow(unused_lifetimes)] pub async fn savepoint<I>(&mut self, name: I) -> Result<Transaction<'_>, Error>
545 where
546 I: Into<String>,
547 {
548 Ok(Transaction {
549 txn: PgTransaction::savepoint(&mut self.txn, name).await?,
550 statement_cache: self.statement_cache.clone(),
551 })
552 }
553}
554
555impl<'a> fmt::Debug for Transaction<'a> {
556 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
557 f.debug_struct("Transaction")
558 .field("statement_cache", &self.statement_cache)
560 .finish()
561 }
562}
563
564impl<'a> Deref for Transaction<'a> {
565 type Target = PgTransaction<'a>;
566
567 fn deref(&self) -> &PgTransaction<'a> {
568 &self.txn
569 }
570}
571
572impl<'a> DerefMut for Transaction<'a> {
573 fn deref_mut(&mut self) -> &mut PgTransaction<'a> {
574 &mut self.txn
575 }
576}
577
/// Wrapper around a [`tokio_postgres::TransactionBuilder`] that hands the
/// connection's [`StatementCache`] to the [`Transaction`] it starts.
#[must_use = "builder does nothing itself, use `.start()` to use it"]
pub struct TransactionBuilder<'a> {
    /// Original [`PgTransactionBuilder`].
    builder: PgTransactionBuilder<'a>,

    /// [`StatementCache`] handed to the [`Transaction`] created by `start`.
    statement_cache: Arc<StatementCache>,
}
588
589impl<'a> TransactionBuilder<'a> {
590 pub fn isolation_level(self, isolation_level: IsolationLevel) -> Self {
594 Self {
595 builder: self.builder.isolation_level(isolation_level),
596 statement_cache: self.statement_cache,
597 }
598 }
599
600 pub fn read_only(self, read_only: bool) -> Self {
604 Self {
605 builder: self.builder.read_only(read_only),
606 statement_cache: self.statement_cache,
607 }
608 }
609
610 pub fn deferrable(self, deferrable: bool) -> Self {
619 Self {
620 builder: self.builder.deferrable(deferrable),
621 statement_cache: self.statement_cache,
622 }
623 }
624
625 pub async fn start(self) -> Result<Transaction<'a>, Error> {
632 Ok(Transaction {
633 txn: self.builder.start().await?,
634 statement_cache: self.statement_cache,
635 })
636 }
637}
638
639impl<'a> fmt::Debug for TransactionBuilder<'a> {
640 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
641 f.debug_struct("TransactionBuilder")
642 .field("statement_cache", &self.statement_cache)
644 .finish()
645 }
646}
647
648impl<'a> Deref for TransactionBuilder<'a> {
649 type Target = PgTransactionBuilder<'a>;
650
651 fn deref(&self) -> &Self::Target {
652 &self.builder
653 }
654}
655
656impl<'a> DerefMut for TransactionBuilder<'a> {
657 fn deref_mut(&mut self) -> &mut Self::Target {
658 &mut self.builder
659 }
660}