use std::{env, fmt, time::Duration};
#[cfg(feature = "serde")]
use serde_1 as serde;
use tokio_postgres::{
config::{
ChannelBinding as PgChannelBinding, SslMode as PgSslMode,
TargetSessionAttrs as PgTargetSessionAttrs,
},
tls::{MakeTlsConnect, TlsConnect},
Socket,
};
use crate::{CreatePoolError, PoolBuilder, Runtime};
use super::{Pool, PoolConfig};
/// Configuration from which a PostgreSQL connection pool is built.
///
/// Every property is optional at the type level; [`Config::get_pg_config`]
/// decides which ones are actually required and which get a fallback.
#[derive(Debug, Default, Clone)]
#[cfg_attr(
    feature = "serde",
    derive(serde::Serialize, serde::Deserialize),
    serde(crate = "serde_1")
)]
pub struct Config {
    /// User name. When unset, [`Config::get_pg_config`] falls back to the
    /// `USER` environment variable if it is available.
    pub user: Option<String>,
    /// Password used for authentication.
    pub password: Option<String>,
    /// Database name. [`Config::get_pg_config`] returns a [`ConfigError`]
    /// when this is unset or an empty string.
    pub dbname: Option<String>,
    /// Options string handed to `tokio_postgres::Config::options`.
    pub options: Option<String>,
    /// Application name handed to `tokio_postgres::Config::application_name`.
    pub application_name: Option<String>,
    /// SSL mode, converted into the `tokio_postgres` equivalent.
    pub ssl_mode: Option<SslMode>,
    /// A single host. It is registered before the entries of `hosts`.
    pub host: Option<String>,
    /// Additional hosts, registered in order after `host`.
    pub hosts: Option<Vec<String>>,
    /// A single port. It is registered before the entries of `ports`.
    pub port: Option<u16>,
    /// Additional ports, registered in order after `port`.
    pub ports: Option<Vec<u16>>,
    /// Connect timeout handed to `tokio_postgres::Config::connect_timeout`.
    pub connect_timeout: Option<Duration>,
    /// Keepalive switch handed to `tokio_postgres::Config::keepalives`.
    pub keepalives: Option<bool>,
    /// Keepalive idle time handed to `tokio_postgres::Config::keepalives_idle`.
    pub keepalives_idle: Option<Duration>,
    /// Requested session attributes of the target server.
    pub target_session_attrs: Option<TargetSessionAttrs>,
    /// Requested channel binding behaviour.
    pub channel_binding: Option<ChannelBinding>,
    /// Manager configuration; [`Config::get_manager_config`] substitutes the
    /// default when this is unset.
    pub manager: Option<ManagerConfig>,
    /// Pool configuration; [`Config::get_pool_config`] substitutes the
    /// default when this is unset.
    pub pool: Option<PoolConfig>,
}
/// Errors reported when a [`Config`] cannot be turned into a usable
/// connection configuration.
#[derive(Debug, Clone, Copy)]
pub enum ConfigError {
    /// The `dbname` configuration property was not found.
    DbnameMissing,
    /// The `dbname` configuration property contains an empty string.
    DbnameEmpty,
}

impl fmt::Display for ConfigError {
    /// Writes the human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The messages carry no format arguments, so pick the text first
        // and emit it in a single write.
        let message = match *self {
            Self::DbnameMissing => "configuration property \"dbname\" not found",
            Self::DbnameEmpty => "configuration property \"dbname\" contains an empty string",
        };
        f.write_str(message)
    }
}

impl std::error::Error for ConfigError {}
impl Config {
    /// Creates a [`Config`] with every property unset (same as
    /// [`Config::default`]).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Builds a [`Pool`] from this configuration.
    ///
    /// `runtime` is applied to the builder only when it is `Some`; `tls` is
    /// the TLS connector handed to the connection manager.
    ///
    /// # Errors
    ///
    /// Returns [`CreatePoolError::Config`] when [`Config::builder`] rejects
    /// the configuration and [`CreatePoolError::Build`] when building the
    /// pool itself fails.
    pub fn create_pool<T>(&self, runtime: Option<Runtime>, tls: T) -> Result<Pool, CreatePoolError>
    where
        T: MakeTlsConnect<Socket> + Clone + Sync + Send + 'static,
        T::Stream: Sync + Send,
        T::TlsConnect: Sync + Send,
        <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
    {
        let mut builder = self.builder(tls).map_err(CreatePoolError::Config)?;
        if let Some(runtime) = runtime {
            builder = builder.runtime(runtime);
        }
        builder.build().map_err(CreatePoolError::Build)
    }

    /// Creates a [`PoolBuilder`] pre-populated from this configuration.
    ///
    /// The manager is created from [`Config::get_pg_config`],
    /// [`Config::get_manager_config`] and the given `tls` connector; the
    /// builder is then given [`Config::get_pool_config`].
    ///
    /// # Errors
    ///
    /// Returns a [`ConfigError`] when [`Config::get_pg_config`] fails.
    pub fn builder<T>(&self, tls: T) -> Result<PoolBuilder, ConfigError>
    where
        T: MakeTlsConnect<Socket> + Clone + Sync + Send + 'static,
        T::Stream: Sync + Send,
        T::TlsConnect: Sync + Send,
        <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
    {
        let pg_config = self.get_pg_config()?;
        let manager_config = self.get_manager_config();
        let manager = crate::Manager::from_config(pg_config, tls, manager_config);
        let pool_config = self.get_pool_config();
        Ok(Pool::builder(manager).config(pool_config))
    }

    /// Translates this configuration into a [`tokio_postgres::Config`].
    ///
    /// Fallbacks applied here:
    /// - `user` falls back to the `USER` environment variable when unset.
    /// - When neither `host` nor `hosts` is set, the socket directories
    ///   `/run/postgresql`, `/var/run/postgresql` and `/tmp` are used on
    ///   unix targets, and `127.0.0.1` on every other target.
    ///
    /// # Errors
    ///
    /// Returns [`ConfigError::DbnameMissing`] when `dbname` is unset and
    /// [`ConfigError::DbnameEmpty`] when it is set to an empty string.
    // The `tokio_postgres::Config` setters are called for their effect only;
    // whatever they return is intentionally discarded.
    #[allow(unused_results)]
    pub fn get_pg_config(&self) -> Result<tokio_postgres::Config, ConfigError> {
        let mut cfg = tokio_postgres::Config::new();

        if let Some(user) = &self.user {
            cfg.user(user.as_str());
        } else if let Ok(user) = env::var("USER") {
            cfg.user(user.as_str());
        }
        if let Some(password) = &self.password {
            cfg.password(password);
        }

        // `dbname` is mandatory. An absent value is "missing", a present but
        // blank value is "empty" -- matching the `Display` text of each
        // variant.
        match self.dbname.as_deref() {
            None => return Err(ConfigError::DbnameMissing),
            Some("") => return Err(ConfigError::DbnameEmpty),
            Some(dbname) => cfg.dbname(dbname),
        };

        if let Some(options) = &self.options {
            cfg.options(options.as_str());
        }
        if let Some(application_name) = &self.application_name {
            cfg.application_name(application_name.as_str());
        }

        // The single `host` is registered first, followed by `hosts`.
        if let Some(host) = &self.host {
            cfg.host(host.as_str());
        }
        if let Some(hosts) = &self.hosts {
            for host in hosts {
                cfg.host(host.as_str());
            }
        }
        if self.host.is_none() && self.hosts.is_none() {
            // No host configured: use the socket directories on unix ...
            #[cfg(unix)]
            {
                cfg.host_path("/run/postgresql");
                cfg.host_path("/var/run/postgresql");
                cfg.host_path("/tmp");
            }
            // ... and the loopback address everywhere else.
            #[cfg(not(unix))]
            cfg.host("127.0.0.1");
        }

        // The single `port` is registered first, followed by `ports`.
        if let Some(port) = self.port {
            cfg.port(port);
        }
        if let Some(ports) = &self.ports {
            for port in ports {
                cfg.port(*port);
            }
        }

        if let Some(connect_timeout) = self.connect_timeout {
            cfg.connect_timeout(connect_timeout);
        }
        if let Some(keepalives) = self.keepalives {
            cfg.keepalives(keepalives);
        }
        if let Some(keepalives_idle) = self.keepalives_idle {
            cfg.keepalives_idle(keepalives_idle);
        }
        if let Some(mode) = self.ssl_mode {
            cfg.ssl_mode(mode.into());
        }
        // These two properties were previously accepted but never forwarded,
        // so setting them had no effect on the resulting connection config.
        if let Some(target_session_attrs) = self.target_session_attrs {
            cfg.target_session_attrs(target_session_attrs.into());
        }
        if let Some(channel_binding) = self.channel_binding {
            cfg.channel_binding(channel_binding.into());
        }

        Ok(cfg)
    }

    /// Returns the configured [`ManagerConfig`], or its default when none is
    /// set.
    #[must_use]
    pub fn get_manager_config(&self) -> ManagerConfig {
        self.manager.clone().unwrap_or_default()
    }

    /// Returns the configured [`PoolConfig`], or its default when none is
    /// set.
    #[must_use]
    pub fn get_pool_config(&self) -> PoolConfig {
        self.pool.unwrap_or_default()
    }
}
/// Selects which SQL statement, if any, [`RecyclingMethod::query`] yields
/// for a connection that is being recycled.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(
    feature = "serde",
    derive(serde::Serialize, serde::Deserialize),
    serde(crate = "serde_1")
)]
pub enum RecyclingMethod {
    /// No statement at all; this is the default.
    Fast,
    /// An empty statement.
    Verified,
    /// A fixed batch of statements that resets session state (see
    /// `DISCARD_SQL`).
    Clean,
    /// A caller-supplied SQL string.
    Custom(String),
}

impl Default for RecyclingMethod {
    /// Returns [`RecyclingMethod::Fast`].
    fn default() -> Self {
        RecyclingMethod::Fast
    }
}

impl RecyclingMethod {
    /// Statement batch used by [`RecyclingMethod::Clean`]. The trailing
    /// backslashes join the lines into one single-line string.
    const DISCARD_SQL: &'static str = "\
        CLOSE ALL; \
        SET SESSION AUTHORIZATION DEFAULT; \
        RESET ALL; \
        UNLISTEN *; \
        SELECT pg_advisory_unlock_all(); \
        DISCARD TEMP; \
        DISCARD SEQUENCES;\
    ";

    /// Returns the SQL associated with this recycling method, or `None`
    /// when the method does not involve a statement.
    pub fn query(&self) -> Option<&str> {
        let sql = match self {
            // The only variant without a statement.
            Self::Fast => return None,
            Self::Verified => "",
            Self::Clean => Self::DISCARD_SQL,
            Self::Custom(custom) => custom.as_str(),
        };
        Some(sql)
    }
}
/// Configuration for the connection manager.
#[derive(Debug, Default, Clone)]
#[cfg_attr(
    feature = "serde",
    derive(serde::Serialize, serde::Deserialize),
    serde(crate = "serde_1")
)]
pub struct ManagerConfig {
    /// How connections are recycled; defaults to the default
    /// [`RecyclingMethod`].
    pub recycling_method: RecyclingMethod,
}
/// Session attributes that can be requested of the target server.
///
/// Converts into the `tokio_postgres` type of the same name via `From`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(
    feature = "serde",
    derive(serde::Serialize, serde::Deserialize),
    serde(crate = "serde_1")
)]
#[non_exhaustive]
pub enum TargetSessionAttrs {
    /// Maps to the `Any` variant of the `tokio_postgres` type.
    Any,
    /// Maps to the `ReadWrite` variant of the `tokio_postgres` type.
    ReadWrite,
}
impl From<TargetSessionAttrs> for PgTargetSessionAttrs {
fn from(attrs: TargetSessionAttrs) -> Self {
match attrs {
TargetSessionAttrs::Any => Self::Any,
TargetSessionAttrs::ReadWrite => Self::ReadWrite,
}
}
}
/// SSL mode that can be set on a [`Config`].
///
/// Converts into the `tokio_postgres` type of the same name via `From`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(
    feature = "serde",
    derive(serde::Serialize, serde::Deserialize),
    serde(crate = "serde_1")
)]
#[non_exhaustive]
pub enum SslMode {
    /// Maps to the `Disable` variant of the `tokio_postgres` type.
    Disable,
    /// Maps to the `Prefer` variant of the `tokio_postgres` type.
    Prefer,
    /// Maps to the `Require` variant of the `tokio_postgres` type.
    Require,
}
impl From<SslMode> for PgSslMode {
fn from(mode: SslMode) -> Self {
match mode {
SslMode::Disable => Self::Disable,
SslMode::Prefer => Self::Prefer,
SslMode::Require => Self::Require,
}
}
}
/// Channel binding behaviour that can be set on a [`Config`].
///
/// Converts into the `tokio_postgres` type of the same name via `From`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(
    feature = "serde",
    derive(serde::Serialize, serde::Deserialize),
    serde(crate = "serde_1")
)]
#[non_exhaustive]
pub enum ChannelBinding {
    /// Maps to the `Disable` variant of the `tokio_postgres` type.
    Disable,
    /// Maps to the `Prefer` variant of the `tokio_postgres` type.
    Prefer,
    /// Maps to the `Require` variant of the `tokio_postgres` type.
    Require,
}
impl From<ChannelBinding> for PgChannelBinding {
fn from(cb: ChannelBinding) -> Self {
match cb {
ChannelBinding::Disable => Self::Disable,
ChannelBinding::Prefer => Self::Prefer,
ChannelBinding::Require => Self::Require,
}
}
}