use std::env;
use std::fmt;
use std::path::Path;
use std::time::Duration;
use deadpool::managed::PoolConfig;
use tokio_postgres::config::{
ChannelBinding as PgChannelBinding, SslMode as PgSslMode,
TargetSessionAttrs as PgTargetSessionAttrs,
};
use tokio_postgres::tls::{MakeTlsConnect, TlsConnect};
use tokio_postgres::Socket;
use crate::Pool;
/// Error produced when a [`Config`] cannot be turned into a
/// `tokio_postgres::Config` or a [`Pool`].
#[derive(Debug)]
pub enum ConfigError {
    /// Human-readable description of what is wrong with the configuration.
    Message(String),
}

impl fmt::Display for ConfigError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Write the message verbatim (no padding/width handling), exactly
            // like `write!(f, "{}", message)` would.
            Self::Message(message) => f.write_str(message),
        }
    }
}

// Implemented as `From` rather than `Into`: the standard blanket impl still
// provides `Into`, and `?` can now convert this error automatically.
#[cfg(feature = "config")]
impl From<ConfigError> for ::config_crate::ConfigError {
    fn from(error: ConfigError) -> Self {
        match error {
            ConfigError::Message(message) => Self::Message(message),
        }
    }
}

impl std::error::Error for ConfigError {}
#[derive(Debug, Copy, Clone, PartialEq)]
#[cfg_attr(feature = "config", derive(serde::Deserialize))]
#[non_exhaustive]
pub enum TargetSessionAttrs {
Any,
ReadWrite,
}
impl Into<PgTargetSessionAttrs> for TargetSessionAttrs {
fn into(self) -> PgTargetSessionAttrs {
match self {
Self::Any => PgTargetSessionAttrs::Any,
Self::ReadWrite => PgTargetSessionAttrs::ReadWrite,
}
}
}
#[derive(Debug, Copy, Clone, PartialEq)]
#[cfg_attr(feature = "config", derive(serde::Deserialize))]
#[non_exhaustive]
pub enum SslMode {
Disable,
Prefer,
Require,
}
impl Into<PgSslMode> for SslMode {
fn into(self) -> PgSslMode {
match self {
Self::Disable => PgSslMode::Disable,
Self::Prefer => PgSslMode::Prefer,
Self::Require => PgSslMode::Require,
}
}
}
#[derive(Debug, Copy, Clone, PartialEq)]
#[cfg_attr(feature = "config", derive(serde::Deserialize))]
#[non_exhaustive]
pub enum ChannelBinding {
Disable,
Prefer,
Require,
}
impl Into<PgChannelBinding> for ChannelBinding {
fn into(self) -> PgChannelBinding {
match self {
Self::Disable => PgChannelBinding::Disable,
Self::Prefer => PgChannelBinding::Prefer,
Self::Require => PgChannelBinding::Require,
}
}
}
/// Selects which SQL, if any, is associated with recycling a pooled
/// connection; see [`RecyclingMethod::query`].
///
/// The default is [`RecyclingMethod::Verified`].
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "config", derive(serde::Deserialize))]
pub enum RecyclingMethod {
    /// No query at all: [`RecyclingMethod::query`] returns `None`.
    Fast,
    /// An empty query: [`RecyclingMethod::query`] returns `Some("")`.
    Verified,
    /// The built-in session-reset batch of statements (`DISCARD_SQL`).
    Clean,
    /// A caller-supplied SQL string, returned unchanged.
    Custom(String),
}

// Session-reset batch used by `RecyclingMethod::Clean`. It follows the
// expansion of PostgreSQL's `DISCARD ALL` but omits `DEALLOCATE ALL` and
// `DISCARD PLANS` — presumably so prepared statements cached for the
// connection stay valid. NOTE(review): confirm that rationale.
const DISCARD_SQL: &str = "
CLOSE ALL;
SET SESSION AUTHORIZATION DEFAULT;
RESET ALL;
UNLISTEN *;
SELECT pg_advisory_unlock_all();
DISCARD TEMP;
DISCARD SEQUENCES;
";

impl RecyclingMethod {
    /// Returns the SQL associated with this method, or `None` for
    /// [`RecyclingMethod::Fast`].
    ///
    /// NOTE(review): how `None` versus `Some("")` is acted upon is decided by
    /// the caller (presumably the manager's recycle step) — confirm there.
    pub fn query(&self) -> Option<&str> {
        match self {
            Self::Fast => None,
            Self::Verified => Some(""),
            Self::Clean => Some(DISCARD_SQL),
            Self::Custom(sql) => Some(sql.as_str()),
        }
    }
}

impl Default for RecyclingMethod {
    /// Returns [`RecyclingMethod::Verified`].
    fn default() -> Self {
        Self::Verified
    }
}
/// Settings handed to `crate::Manager::from_config` when a pool is built
/// through [`Config::create_pool`].
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "config", derive(serde::Deserialize))]
pub struct ManagerConfig {
    /// Recycling strategy; the derived `Default` yields
    /// [`RecyclingMethod::Verified`].
    pub recycling_method: RecyclingMethod,
}
/// Connection and pool settings, translated into a `tokio_postgres::Config`
/// by [`Config::get_pg_config`] and into a [`Pool`] by [`Config::create_pool`].
///
/// Every field is optional; [`Config::new`] / `Default` leave them all `None`.
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "config", derive(serde::Deserialize))]
pub struct Config {
    /// Login user; when `None`, the `USER` environment variable is tried.
    pub user: Option<String>,
    /// Login password.
    pub password: Option<String>,
    /// Database name; required and must not be empty.
    pub dbname: Option<String>,
    /// Command-line options passed to the server.
    pub options: Option<String>,
    /// Application name reported to the server.
    pub application_name: Option<String>,
    /// TLS policy to request.
    pub ssl_mode: Option<SslMode>,
    /// A single host; added before the entries of `hosts`.
    pub host: Option<String>,
    /// Additional hosts. If both `host` and `hosts` are `None`, a local
    /// platform default is used instead.
    pub hosts: Option<Vec<String>>,
    /// A single port; added before the entries of `ports`.
    pub port: Option<u16>,
    /// Additional ports.
    pub ports: Option<Vec<u16>>,
    /// Timeout for establishing a connection.
    pub connect_timeout: Option<Duration>,
    /// Whether TCP keepalives are enabled.
    pub keepalives: Option<bool>,
    /// Idle time before TCP keepalives are sent.
    pub keepalives_idle: Option<Duration>,
    /// Session properties required of the server.
    pub target_session_attrs: Option<TargetSessionAttrs>,
    /// Channel-binding policy for authentication.
    pub channel_binding: Option<ChannelBinding>,
    /// Manager settings; `ManagerConfig::default()` is used when `None`.
    pub manager: Option<ManagerConfig>,
    /// Pool settings; `PoolConfig::default()` is used when `None`.
    pub pool: Option<PoolConfig>,
}
impl Config {
    /// Creates a configuration with every field set to `None`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Reads a [`Config`] from environment variables that start with
    /// `prefix`, via `config::Environment::with_prefix`.
    ///
    /// # Errors
    ///
    /// Returns the `config` crate's error if merging the environment or
    /// deserializing it into [`Config`] fails.
    #[deprecated(
        since = "0.5.5",
        note = "Please embed this structure in your own config structure and use `config::Config` directly."
    )]
    #[cfg(feature = "config")]
    pub fn from_env(prefix: &str) -> Result<Self, ::config_crate::ConfigError> {
        use ::config_crate::Environment;
        let mut cfg = ::config_crate::Config::new();
        cfg.merge(Environment::with_prefix(prefix))?;
        cfg.try_into()
    }

    /// Builds a [`Pool`] from this configuration, using `tls` to create
    /// TLS connectors.
    ///
    /// # Errors
    ///
    /// Returns [`ConfigError`] if [`Config::get_pg_config`] rejects the
    /// configuration.
    pub fn create_pool<T>(&self, tls: T) -> Result<Pool, ConfigError>
    where
        T: MakeTlsConnect<Socket> + Clone + Sync + Send + 'static,
        T::Stream: Sync + Send,
        T::TlsConnect: Sync + Send,
        <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
    {
        let pg_config = self.get_pg_config()?;
        let manager_config = self.get_manager_config();
        let manager = crate::Manager::from_config(pg_config, tls, manager_config);
        let pool_config = self.get_pool_config();
        Ok(Pool::from_config(manager, pool_config))
    }

    /// Translates this configuration into a `tokio_postgres::Config`.
    ///
    /// * `user` falls back to the `USER` environment variable when unset.
    /// * `host` and then every entry of `hosts` are added. Only when both are
    ///   `None` is a platform default added (see `add_default_host`).
    /// * `port` is added before the entries of `ports`.
    /// * `ssl_mode`, `target_session_attrs` and `channel_binding` are
    ///   converted to their `tokio_postgres` equivalents when set.
    ///
    /// # Errors
    ///
    /// Returns [`ConfigError::Message`] if `dbname` is `None` or empty.
    pub fn get_pg_config(&self) -> Result<tokio_postgres::Config, ConfigError> {
        let mut cfg = tokio_postgres::Config::new();
        if let Some(user) = &self.user {
            cfg.user(user.as_str());
        } else if let Ok(user) = env::var("USER") {
            cfg.user(user.as_str());
        }
        if let Some(password) = &self.password {
            cfg.password(password.as_str());
        }
        match &self.dbname {
            None => {
                return Err(ConfigError::Message(
                    "configuration property \"dbname\" not found".to_string(),
                ));
            }
            Some(dbname) if dbname.is_empty() => {
                return Err(ConfigError::Message(
                    "configuration property \"dbname\" contains an empty string".to_string(),
                ));
            }
            Some(dbname) => {
                cfg.dbname(dbname.as_str());
            }
        }
        if let Some(options) = &self.options {
            cfg.options(options.as_str());
        }
        if let Some(application_name) = &self.application_name {
            cfg.application_name(application_name.as_str());
        }
        if let Some(ssl_mode) = self.ssl_mode {
            cfg.ssl_mode(ssl_mode.into());
        }
        if let Some(host) = &self.host {
            cfg.host(host.as_str());
        }
        if let Some(hosts) = &self.hosts {
            for host in hosts {
                cfg.host(host.as_str());
            }
        }
        // The default is only a fallback: adding it next to an explicitly
        // configured `host` would let connections silently reach a local
        // server instead of the requested one.
        if self.host.is_none() && self.hosts.is_none() {
            Self::add_default_host(&mut cfg);
        }
        if let Some(port) = self.port {
            cfg.port(port);
        }
        if let Some(ports) = &self.ports {
            for &port in ports {
                cfg.port(port);
            }
        }
        if let Some(connect_timeout) = self.connect_timeout {
            cfg.connect_timeout(connect_timeout);
        }
        if let Some(keepalives) = self.keepalives {
            cfg.keepalives(keepalives);
        }
        if let Some(keepalives_idle) = self.keepalives_idle {
            cfg.keepalives_idle(keepalives_idle);
        }
        if let Some(target_session_attrs) = self.target_session_attrs {
            cfg.target_session_attrs(target_session_attrs.into());
        }
        if let Some(channel_binding) = self.channel_binding {
            cfg.channel_binding(channel_binding.into());
        }
        Ok(cfg)
    }

    /// Adds the platform default host used when neither `host` nor `hosts`
    /// is configured.
    ///
    /// On Unix this is the first existing directory among `/run/postgresql`
    /// and `/var/run/postgresql`, otherwise `/tmp`, passed to `host_path`.
    /// On other platforms it is `127.0.0.1`.
    fn add_default_host(cfg: &mut tokio_postgres::Config) {
        #[cfg(unix)]
        {
            let socket_dir = if Path::new("/run/postgresql").exists() {
                "/run/postgresql"
            } else if Path::new("/var/run/postgresql").exists() {
                "/var/run/postgresql"
            } else {
                "/tmp"
            };
            cfg.host_path(socket_dir);
        }
        #[cfg(not(unix))]
        {
            cfg.host("127.0.0.1");
        }
    }

    /// Returns the configured [`ManagerConfig`], or its default when unset.
    pub fn get_manager_config(&self) -> ManagerConfig {
        self.manager.clone().unwrap_or_default()
    }

    /// Returns the configured `PoolConfig`, or its default when unset.
    pub fn get_pool_config(&self) -> PoolConfig {
        self.pool.clone().unwrap_or_default()
    }
}