deadpool_postgres/
config.rs

1//! Configuration used for [`Pool`] creation.
2
3use std::{env, fmt, net::IpAddr, str::FromStr, time::Duration};
4
5use tokio_postgres::config::{
6    ChannelBinding as PgChannelBinding, LoadBalanceHosts as PgLoadBalanceHosts,
7    SslMode as PgSslMode, TargetSessionAttrs as PgTargetSessionAttrs,
8};
9
10#[cfg(not(target_arch = "wasm32"))]
11use super::Pool;
12#[cfg(not(target_arch = "wasm32"))]
13use crate::{CreatePoolError, PoolBuilder, Runtime};
14#[cfg(not(target_arch = "wasm32"))]
15use tokio_postgres::{
16    tls::{MakeTlsConnect, TlsConnect},
17    Socket,
18};
19
20use super::PoolConfig;
21
22/// Configuration object.
23///
24/// # Example (from environment)
25///
26/// By enabling the `serde` feature you can read the configuration using the
27/// [`config`](https://crates.io/crates/config) crate as following:
28/// ```env
29/// PG__HOST=pg.example.com
30/// PG__USER=john_doe
31/// PG__PASSWORD=topsecret
32/// PG__DBNAME=example
33/// PG__POOL__MAX_SIZE=16
34/// PG__POOL__TIMEOUTS__WAIT__SECS=5
35/// PG__POOL__TIMEOUTS__WAIT__NANOS=0
36/// ```
37/// ```rust
38/// #[derive(serde::Deserialize, serde::Serialize)]
39/// struct Config {
40///     pg: deadpool_postgres::Config,
41/// }
42/// impl Config {
43///     pub fn from_env() -> Result<Self, config::ConfigError> {
44///         let mut cfg = config::Config::builder()
45///            .add_source(config::Environment::default().separator("__"))
46///            .build()?;
47///            cfg.try_deserialize()
48///     }
49/// }
50/// ```
51#[derive(Clone, Debug, Default)]
52#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
53pub struct Config {
54    /// Initialize the configuration by parsing the URL first.
55    /// **Note**: All the other options override settings defined
56    /// by the URL except for the `host` and `hosts` options which
57    /// are additive!
58    pub url: Option<String>,
59    /// See [`tokio_postgres::Config::user`].
60    pub user: Option<String>,
61    /// See [`tokio_postgres::Config::password`].
62    pub password: Option<String>,
63    /// See [`tokio_postgres::Config::dbname`].
64    pub dbname: Option<String>,
65    /// See [`tokio_postgres::Config::options`].
66    pub options: Option<String>,
67    /// See [`tokio_postgres::Config::application_name`].
68    pub application_name: Option<String>,
69    /// See [`tokio_postgres::Config::ssl_mode`].
70    pub ssl_mode: Option<SslMode>,
71    /// This is similar to [`Config::hosts`] but only allows one host to be
72    /// specified.
73    ///
74    /// Unlike [`tokio_postgres::Config`] this structure differentiates between
75    /// one host and more than one host. This makes it possible to store this
76    /// configuration in an environment variable.
77    ///
78    /// See [`tokio_postgres::Config::host`].
79    pub host: Option<String>,
80    /// See [`tokio_postgres::Config::host`].
81    pub hosts: Option<Vec<String>>,
82    /// See [`tokio_postgres::Config::hostaddr`].
83    pub hostaddr: Option<IpAddr>,
84    /// See [`tokio_postgres::Config::hostaddr`].
85    pub hostaddrs: Option<Vec<IpAddr>>,
86    /// This is similar to [`Config::ports`] but only allows one port to be
87    /// specified.
88    ///
89    /// Unlike [`tokio_postgres::Config`] this structure differentiates between
90    /// one port and more than one port. This makes it possible to store this
91    /// configuration in an environment variable.
92    ///
93    /// See [`tokio_postgres::Config::port`].
94    pub port: Option<u16>,
95    /// See [`tokio_postgres::Config::port`].
96    pub ports: Option<Vec<u16>>,
97    /// See [`tokio_postgres::Config::connect_timeout`].
98    pub connect_timeout: Option<Duration>,
99    /// See [`tokio_postgres::Config::keepalives`].
100    pub keepalives: Option<bool>,
101    #[cfg(not(target_arch = "wasm32"))]
102    /// See [`tokio_postgres::Config::keepalives_idle`].
103    pub keepalives_idle: Option<Duration>,
104    /// See [`tokio_postgres::Config::target_session_attrs`].
105    pub target_session_attrs: Option<TargetSessionAttrs>,
106    /// See [`tokio_postgres::Config::channel_binding`].
107    pub channel_binding: Option<ChannelBinding>,
108    /// See [`tokio_postgres::Config::load_balance_hosts`].
109    pub load_balance_hosts: Option<LoadBalanceHosts>,
110
111    /// [`Manager`] configuration.
112    ///
113    /// [`Manager`]: super::Manager
114    pub manager: Option<ManagerConfig>,
115
116    /// [`Pool`] configuration.
117    pub pool: Option<PoolConfig>,
118}
119
120/// This error is returned if there is something wrong with the configuration
121#[derive(Debug)]
122pub enum ConfigError {
123    /// This variant is returned if the `url` is invalid
124    InvalidUrl(tokio_postgres::Error),
125    /// This variant is returned if the `dbname` is missing from the config
126    DbnameMissing,
127    /// This variant is returned if the `dbname` contains an empty string
128    DbnameEmpty,
129}
130
131impl fmt::Display for ConfigError {
132    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
133        match self {
134            Self::InvalidUrl(e) => write!(f, "configuration property \"url\" is invalid: {}", e),
135            Self::DbnameMissing => write!(f, "configuration property \"dbname\" not found"),
136            Self::DbnameEmpty => write!(
137                f,
138                "configuration property \"dbname\" contains an empty string",
139            ),
140        }
141    }
142}
143
144impl std::error::Error for ConfigError {}
145
146impl Config {
147    /// Create a new [`Config`] instance with default values. This function is
148    /// identical to [`Config::default()`].
149    #[must_use]
150    pub fn new() -> Self {
151        Self::default()
152    }
153
154    #[cfg(not(target_arch = "wasm32"))]
155    /// Creates a new [`Pool`] using this [`Config`].
156    ///
157    /// # Errors
158    ///
159    /// See [`CreatePoolError`] for details.
160    pub fn create_pool<T>(&self, runtime: Option<Runtime>, tls: T) -> Result<Pool, CreatePoolError>
161    where
162        T: MakeTlsConnect<Socket> + Clone + Sync + Send + 'static,
163        T::Stream: Sync + Send,
164        T::TlsConnect: Sync + Send,
165        <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
166    {
167        let mut builder = self.builder(tls).map_err(CreatePoolError::Config)?;
168        if let Some(runtime) = runtime {
169            builder = builder.runtime(runtime);
170        }
171        builder.build().map_err(CreatePoolError::Build)
172    }
173
174    #[cfg(not(target_arch = "wasm32"))]
175    /// Creates a new [`PoolBuilder`] using this [`Config`].
176    ///
177    /// # Errors
178    ///
179    /// See [`ConfigError`] and [`tokio_postgres::Error`] for details.
180    pub fn builder<T>(&self, tls: T) -> Result<PoolBuilder, ConfigError>
181    where
182        T: MakeTlsConnect<Socket> + Clone + Sync + Send + 'static,
183        T::Stream: Sync + Send,
184        T::TlsConnect: Sync + Send,
185        <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
186    {
187        let pg_config = self.get_pg_config()?;
188        let manager_config = self.get_manager_config();
189        let manager = crate::Manager::from_config(pg_config, tls, manager_config);
190        let pool_config = self.get_pool_config();
191        Ok(Pool::builder(manager).config(pool_config))
192    }
193
194    /// Returns [`tokio_postgres::Config`] which can be used to connect to
195    /// the database.
196    #[allow(unused_results)]
197    pub fn get_pg_config(&self) -> Result<tokio_postgres::Config, ConfigError> {
198        let mut cfg = if let Some(url) = &self.url {
199            tokio_postgres::Config::from_str(url).map_err(ConfigError::InvalidUrl)?
200        } else {
201            tokio_postgres::Config::new()
202        };
203        if let Some(user) = self.user.as_ref().filter(|s| !s.is_empty()) {
204            cfg.user(user.as_str());
205        }
206        if !cfg.get_user().map_or(false, |u| !u.is_empty()) {
207            if let Ok(user) = env::var("USER") {
208                cfg.user(&user);
209            }
210        }
211        if let Some(password) = &self.password {
212            cfg.password(password);
213        }
214        if let Some(dbname) = self.dbname.as_ref().filter(|s| !s.is_empty()) {
215            cfg.dbname(dbname);
216        }
217        match cfg.get_dbname() {
218            None => {
219                return Err(ConfigError::DbnameMissing);
220            }
221            Some("") => {
222                return Err(ConfigError::DbnameEmpty);
223            }
224            _ => {}
225        }
226        if let Some(options) = &self.options {
227            cfg.options(options.as_str());
228        }
229        if let Some(application_name) = &self.application_name {
230            cfg.application_name(application_name.as_str());
231        }
232        if let Some(host) = &self.host {
233            cfg.host(host.as_str());
234        }
235        if let Some(hosts) = &self.hosts {
236            for host in hosts.iter() {
237                cfg.host(host.as_str());
238            }
239        }
240        if cfg.get_hosts().is_empty() {
241            // Systems that support it default to unix domain sockets.
242            #[cfg(unix)]
243            {
244                cfg.host_path("/run/postgresql");
245                cfg.host_path("/var/run/postgresql");
246                cfg.host_path("/tmp");
247            }
248            // Windows and other systems use 127.0.0.1 instead.
249            #[cfg(not(unix))]
250            cfg.host("127.0.0.1");
251        }
252        if let Some(hostaddr) = self.hostaddr {
253            cfg.hostaddr(hostaddr);
254        }
255        if let Some(hostaddrs) = &self.hostaddrs {
256            for hostaddr in hostaddrs {
257                cfg.hostaddr(*hostaddr);
258            }
259        }
260        if let Some(port) = self.port {
261            cfg.port(port);
262        }
263        if let Some(ports) = &self.ports {
264            for port in ports.iter() {
265                cfg.port(*port);
266            }
267        }
268        if let Some(connect_timeout) = self.connect_timeout {
269            cfg.connect_timeout(connect_timeout);
270        }
271        if let Some(keepalives) = self.keepalives {
272            cfg.keepalives(keepalives);
273        }
274        #[cfg(not(target_arch = "wasm32"))]
275        if let Some(keepalives_idle) = self.keepalives_idle {
276            cfg.keepalives_idle(keepalives_idle);
277        }
278        if let Some(mode) = self.ssl_mode {
279            cfg.ssl_mode(mode.into());
280        }
281        Ok(cfg)
282    }
283
284    /// Returns [`ManagerConfig`] which can be used to construct a
285    /// [`deadpool::managed::Pool`] instance.
286    #[must_use]
287    pub fn get_manager_config(&self) -> ManagerConfig {
288        self.manager.clone().unwrap_or_default()
289    }
290
291    /// Returns [`deadpool::managed::PoolConfig`] which can be used to construct
292    /// a [`deadpool::managed::Pool`] instance.
293    #[must_use]
294    pub fn get_pool_config(&self) -> PoolConfig {
295        self.pool.unwrap_or_default()
296    }
297}
298
/// Possible methods of how a connection is recycled.
///
/// The default is [`Fast`] which does not check the connection health or
/// perform any clean-up queries. [`Verified`] additionally executes a test
/// query before a connection is reused.
///
/// [`Fast`]: RecyclingMethod::Fast
/// [`Verified`]: RecyclingMethod::Verified
#[derive(Clone, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub enum RecyclingMethod {
    /// Only run [`Client::is_closed()`][1] when recycling existing connections.
    ///
    /// Unless you have special needs this is a safe choice.
    ///
    /// [1]: tokio_postgres::Client::is_closed
    Fast,

    /// Run [`Client::is_closed()`][1] and execute a test query.
    ///
    /// This is slower, but guarantees that the database connection is ready to
    /// be used. Normally, [`Client::is_closed()`][1] should be enough to filter
    /// out bad connections. Under some circumstances, however (e.g. hard-closed
    /// network connections), [`Client::is_closed()`][1] can return `false`
    /// while the connection is dead. You will then receive an error on your
    /// first query.
    ///
    /// [1]: tokio_postgres::Client::is_closed
    Verified,

    /// Like [`Verified`] query method, but instead use the following sequence
    /// of statements which guarantees a pristine connection:
    /// ```sql
    /// CLOSE ALL;
    /// SET SESSION AUTHORIZATION DEFAULT;
    /// RESET ALL;
    /// UNLISTEN *;
    /// SELECT pg_advisory_unlock_all();
    /// DISCARD TEMP;
    /// DISCARD SEQUENCES;
    /// ```
    ///
    /// This is similar to calling `DISCARD ALL`, but doesn't call
    /// `DEALLOCATE ALL` and `DISCARD PLAN`, so that the statement cache is not
    /// rendered ineffective.
    ///
    /// [`Verified`]: RecyclingMethod::Verified
    Clean,

    /// Like [`Verified`] but allows to specify a custom SQL to be executed.
    ///
    /// [`Verified`]: RecyclingMethod::Verified
    Custom(String),
}

impl Default for RecyclingMethod {
    fn default() -> Self {
        Self::Fast
    }
}

impl RecyclingMethod {
    /// SQL executed by [`RecyclingMethod::Clean`].
    const DISCARD_SQL: &'static str = "\
        CLOSE ALL; \
        SET SESSION AUTHORIZATION DEFAULT; \
        RESET ALL; \
        UNLISTEN *; \
        SELECT pg_advisory_unlock_all(); \
        DISCARD TEMP; \
        DISCARD SEQUENCES;\
    ";

    /// Returns the SQL query to be executed when recycling a connection.
    ///
    /// Return value by variant:
    /// - [`RecyclingMethod::Fast`]: `None`, because no query is run.
    /// - [`RecyclingMethod::Verified`]: an empty string.
    /// - [`RecyclingMethod::Clean`]: the discard sequence.
    /// - [`RecyclingMethod::Custom`]: the user-supplied SQL.
    #[must_use]
    pub fn query(&self) -> Option<&str> {
        match self {
            Self::Fast => None,
            Self::Verified => Some(""),
            Self::Clean => Some(Self::DISCARD_SQL),
            Self::Custom(sql) => Some(sql.as_str()),
        }
    }
}
380
381/// Configuration object for a [`Manager`].
382///
383/// This currently only makes it possible to specify which [`RecyclingMethod`]
384/// should be used when retrieving existing objects from the [`Pool`].
385///
386/// [`Manager`]: super::Manager
387#[derive(Clone, Debug, Default)]
388#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
389pub struct ManagerConfig {
390    /// Method of how a connection is recycled. See [`RecyclingMethod`].
391    pub recycling_method: RecyclingMethod,
392}
393
394/// Properties required of a session.
395///
396/// This is a 1:1 copy of the [`PgTargetSessionAttrs`] enumeration.
397/// This is duplicated here in order to add support for the
398/// [`serde::Deserialize`] trait which is required for the [`serde`] support.
399#[derive(Clone, Copy, Debug, Eq, PartialEq)]
400#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
401#[non_exhaustive]
402pub enum TargetSessionAttrs {
403    /// No special properties are required.
404    Any,
405
406    /// The session must allow writes.
407    ReadWrite,
408}
409
410impl From<TargetSessionAttrs> for PgTargetSessionAttrs {
411    fn from(attrs: TargetSessionAttrs) -> Self {
412        match attrs {
413            TargetSessionAttrs::Any => Self::Any,
414            TargetSessionAttrs::ReadWrite => Self::ReadWrite,
415        }
416    }
417}
418
419/// TLS configuration.
420///
421/// This is a 1:1 copy of the [`PgSslMode`] enumeration.
422/// This is duplicated here in order to add support for the
423/// [`serde::Deserialize`] trait which is required for the [`serde`] support.
424#[derive(Clone, Copy, Debug, Eq, PartialEq)]
425#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
426#[non_exhaustive]
427pub enum SslMode {
428    /// Do not use TLS.
429    Disable,
430
431    /// Attempt to connect with TLS but allow sessions without.
432    Prefer,
433
434    /// Require the use of TLS.
435    Require,
436}
437
438impl From<SslMode> for PgSslMode {
439    fn from(mode: SslMode) -> Self {
440        match mode {
441            SslMode::Disable => Self::Disable,
442            SslMode::Prefer => Self::Prefer,
443            SslMode::Require => Self::Require,
444        }
445    }
446}
447
448/// Channel binding configuration.
449///
450/// This is a 1:1 copy of the [`PgChannelBinding`] enumeration.
451/// This is duplicated here in order to add support for the
452/// [`serde::Deserialize`] trait which is required for the [`serde`] support.
453#[derive(Clone, Copy, Debug, Eq, PartialEq)]
454#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
455#[non_exhaustive]
456pub enum ChannelBinding {
457    /// Do not use channel binding.
458    Disable,
459
460    /// Attempt to use channel binding but allow sessions without.
461    Prefer,
462
463    /// Require the use of channel binding.
464    Require,
465}
466
467impl From<ChannelBinding> for PgChannelBinding {
468    fn from(cb: ChannelBinding) -> Self {
469        match cb {
470            ChannelBinding::Disable => Self::Disable,
471            ChannelBinding::Prefer => Self::Prefer,
472            ChannelBinding::Require => Self::Require,
473        }
474    }
475}
476
477/// Load balancing configuration.
478///
479/// This is a 1:1 copy of the [`PgLoadBalanceHosts`] enumeration.
480/// This is duplicated here in order to add support for the
481/// [`serde::Deserialize`] trait which is required for the [`serde`] support.
482#[derive(Debug, Copy, Clone, PartialEq, Eq)]
483#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
484#[non_exhaustive]
485pub enum LoadBalanceHosts {
486    /// Make connection attempts to hosts in the order provided.
487    Disable,
488    /// Make connection attempts to hosts in a random order.
489    Random,
490}
491
492impl From<LoadBalanceHosts> for PgLoadBalanceHosts {
493    fn from(cb: LoadBalanceHosts) -> Self {
494        match cb {
495            LoadBalanceHosts::Disable => Self::Disable,
496            LoadBalanceHosts::Random => Self::Random,
497        }
498    }
499}