use std::{env, fmt, net::IpAddr, str::FromStr, time::Duration};
use tokio_postgres::config::{
ChannelBinding as PgChannelBinding, LoadBalanceHosts as PgLoadBalanceHosts,
SslMode as PgSslMode, TargetSessionAttrs as PgTargetSessionAttrs,
};
#[cfg(not(target_arch = "wasm32"))]
use super::Pool;
#[cfg(not(target_arch = "wasm32"))]
use crate::{CreatePoolError, PoolBuilder, Runtime};
#[cfg(not(target_arch = "wasm32"))]
use tokio_postgres::{
tls::{MakeTlsConnect, TlsConnect},
Socket,
};
use super::PoolConfig;
/// Configuration object for a PostgreSQL connection pool.
///
/// Every field is optional (`Config::default()` leaves all of them `None`).
/// [`Config::get_pg_config`] turns the connection-related fields into a
/// `tokio_postgres::Config`, while `manager` and `pool` are read by
/// [`Config::get_manager_config`] and [`Config::get_pool_config`].
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct Config {
    /// Connection string, parsed with `tokio_postgres::Config::from_str`; the other
    /// connection fields are applied to the parsed result afterwards.
    pub url: Option<String>,
    /// Database user. Empty strings are ignored; if no non-empty user ends up configured
    /// (here or via `url`), the `USER` environment variable is used when available.
    pub user: Option<String>,
    /// Password, applied as-is when set.
    pub password: Option<String>,
    /// Database name. Empty strings are ignored; a non-empty database name must be
    /// provided either here or via `url`, otherwise `get_pg_config` returns an error.
    pub dbname: Option<String>,
    /// Value passed to tokio_postgres's `options` setter.
    pub options: Option<String>,
    /// Value passed to tokio_postgres's `application_name` setter.
    pub application_name: Option<String>,
    /// SSL mode (see [`SslMode`]).
    pub ssl_mode: Option<SslMode>,
    /// A single host; added before the entries of `hosts`.
    pub host: Option<String>,
    /// Additional hosts. If neither these, `host`, nor `url` yield any host, unix targets
    /// fall back to the socket directories `/run/postgresql`, `/var/run/postgresql` and
    /// `/tmp`, and other targets fall back to `127.0.0.1`.
    pub hosts: Option<Vec<String>>,
    /// A single host address; added before the entries of `hostaddrs`.
    pub hostaddr: Option<IpAddr>,
    /// Additional host addresses.
    pub hostaddrs: Option<Vec<IpAddr>>,
    /// A single port; added before the entries of `ports`.
    pub port: Option<u16>,
    /// Additional ports.
    pub ports: Option<Vec<u16>>,
    /// Value passed to tokio_postgres's `connect_timeout` setter.
    pub connect_timeout: Option<Duration>,
    /// Value passed to tokio_postgres's `keepalives` setter.
    pub keepalives: Option<bool>,
    /// Value passed to tokio_postgres's `keepalives_idle` setter (not available on `wasm32`).
    #[cfg(not(target_arch = "wasm32"))]
    pub keepalives_idle: Option<Duration>,
    /// Target session attributes (see [`TargetSessionAttrs`]).
    pub target_session_attrs: Option<TargetSessionAttrs>,
    /// Channel binding mode (see [`ChannelBinding`]).
    pub channel_binding: Option<ChannelBinding>,
    /// Host load-balancing mode (see [`LoadBalanceHosts`]).
    pub load_balance_hosts: Option<LoadBalanceHosts>,
    /// Manager configuration; `ManagerConfig::default()` is used when unset.
    pub manager: Option<ManagerConfig>,
    /// Pool configuration; `PoolConfig::default()` is used when unset.
    pub pool: Option<PoolConfig>,
}
/// Error returned by [`Config::get_pg_config`] (and therefore by `Config::builder`) when
/// the configuration cannot be turned into a `tokio_postgres::Config`.
#[derive(Debug)]
pub enum ConfigError {
    /// The `url` property could not be parsed by `tokio_postgres`.
    InvalidUrl(tokio_postgres::Error),
    /// No database name was configured (an empty `dbname` field counts as unset).
    DbnameMissing,
    /// The resolved database name is an empty string.
    DbnameEmpty,
}

impl fmt::Display for ConfigError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Only the URL variant carries a payload; all other variants render a fixed message.
        let message = match self {
            Self::InvalidUrl(err) => {
                return write!(f, "configuration property \"url\" is invalid: {}", err);
            }
            Self::DbnameMissing => "configuration property \"dbname\" not found",
            Self::DbnameEmpty => "configuration property \"dbname\" contains an empty string",
        };
        f.write_str(message)
    }
}

impl std::error::Error for ConfigError {}
impl Config {
    /// Creates a new, empty [`Config`] (every field is `None`).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a new [`Pool`] from this configuration.
    ///
    /// `tls` is handed to the connection manager; `runtime`, when `Some`, is set on the
    /// pool builder before building.
    ///
    /// # Errors
    ///
    /// - [`CreatePoolError::Config`] if [`Config::get_pg_config`] rejects the configuration;
    /// - [`CreatePoolError::Build`] if building the pool fails.
    #[cfg(not(target_arch = "wasm32"))]
    pub fn create_pool<T>(&self, runtime: Option<Runtime>, tls: T) -> Result<Pool, CreatePoolError>
    where
        T: MakeTlsConnect<Socket> + Clone + Sync + Send + 'static,
        T::Stream: Sync + Send,
        T::TlsConnect: Sync + Send,
        <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
    {
        let mut builder = self.builder(tls).map_err(CreatePoolError::Config)?;
        if let Some(runtime) = runtime {
            builder = builder.runtime(runtime);
        }
        builder.build().map_err(CreatePoolError::Build)
    }

    /// Creates a [`PoolBuilder`] for a manager built from this configuration.
    ///
    /// The manager receives [`Config::get_pg_config`], `tls` and
    /// [`Config::get_manager_config`]; the builder is pre-configured with
    /// [`Config::get_pool_config`].
    ///
    /// # Errors
    ///
    /// Returns the [`ConfigError`] produced by [`Config::get_pg_config`].
    #[cfg(not(target_arch = "wasm32"))]
    pub fn builder<T>(&self, tls: T) -> Result<PoolBuilder, ConfigError>
    where
        T: MakeTlsConnect<Socket> + Clone + Sync + Send + 'static,
        T::Stream: Sync + Send,
        T::TlsConnect: Sync + Send,
        <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
    {
        let pg_config = self.get_pg_config()?;
        let manager_config = self.get_manager_config();
        let manager = crate::Manager::from_config(pg_config, tls, manager_config);
        let pool_config = self.get_pool_config();
        Ok(Pool::builder(manager).config(pool_config))
    }

    /// Builds a [`tokio_postgres::Config`] from this configuration.
    ///
    /// If `url` is set it is parsed first, and the individually set fields are then applied
    /// to the parsed config. In addition:
    ///
    /// - empty `user` / `dbname` values are ignored;
    /// - if no non-empty user is configured afterwards, the `USER` environment variable is
    ///   used (when it is set and valid Unicode);
    /// - `host`/`hosts`, `hostaddr`/`hostaddrs` and `port`/`ports` are passed to
    ///   tokio_postgres's *adding* setters, so they extend any values parsed from `url`;
    /// - if no host is configured at all, unix targets fall back to the socket directories
    ///   `/run/postgresql`, `/var/run/postgresql` and `/tmp`; other targets use `127.0.0.1`.
    ///
    /// # Errors
    ///
    /// - [`ConfigError::InvalidUrl`] if `url` cannot be parsed;
    /// - [`ConfigError::DbnameMissing`] if no database name is configured;
    /// - [`ConfigError::DbnameEmpty`] if the resulting database name is an empty string.
    // The tokio_postgres setters return `&mut Config` for chaining; those return values are
    // intentionally discarded here.
    #[allow(unused_results)]
    pub fn get_pg_config(&self) -> Result<tokio_postgres::Config, ConfigError> {
        let mut cfg = if let Some(url) = &self.url {
            tokio_postgres::Config::from_str(url).map_err(ConfigError::InvalidUrl)?
        } else {
            tokio_postgres::Config::new()
        };

        // User: explicit non-empty value wins; otherwise fall back to $USER if still unset/empty.
        if let Some(user) = self.user.as_ref().filter(|s| !s.is_empty()) {
            cfg.user(user.as_str());
        }
        let has_user = cfg.get_user().map_or(false, |u| !u.is_empty());
        if !has_user {
            if let Ok(user) = env::var("USER") {
                cfg.user(&user);
            }
        }
        if let Some(password) = &self.password {
            cfg.password(password);
        }

        // Database name is mandatory and must not be empty.
        if let Some(dbname) = self.dbname.as_ref().filter(|s| !s.is_empty()) {
            cfg.dbname(dbname);
        }
        match cfg.get_dbname() {
            None => return Err(ConfigError::DbnameMissing),
            Some("") => return Err(ConfigError::DbnameEmpty),
            Some(_) => {}
        }

        if let Some(options) = &self.options {
            cfg.options(options.as_str());
        }
        if let Some(application_name) = &self.application_name {
            cfg.application_name(application_name.as_str());
        }

        // `host` is added before the entries of `hosts`.
        for host in self.host.iter().chain(self.hosts.iter().flatten()) {
            cfg.host(host.as_str());
        }
        if cfg.get_hosts().is_empty() {
            // No host anywhere: prefer local unix socket directories where available.
            #[cfg(unix)]
            {
                cfg.host_path("/run/postgresql");
                cfg.host_path("/var/run/postgresql");
                cfg.host_path("/tmp");
            }
            #[cfg(not(unix))]
            cfg.host("127.0.0.1");
        }
        for &hostaddr in self.hostaddr.iter().chain(self.hostaddrs.iter().flatten()) {
            cfg.hostaddr(hostaddr);
        }
        for &port in self.port.iter().chain(self.ports.iter().flatten()) {
            cfg.port(port);
        }

        if let Some(connect_timeout) = self.connect_timeout {
            cfg.connect_timeout(connect_timeout);
        }
        if let Some(keepalives) = self.keepalives {
            cfg.keepalives(keepalives);
        }
        #[cfg(not(target_arch = "wasm32"))]
        if let Some(keepalives_idle) = self.keepalives_idle {
            cfg.keepalives_idle(keepalives_idle);
        }
        if let Some(mode) = self.ssl_mode {
            cfg.ssl_mode(mode.into());
        }
        // These three fields were previously accepted but silently ignored.
        if let Some(attrs) = self.target_session_attrs {
            cfg.target_session_attrs(attrs.into());
        }
        if let Some(channel_binding) = self.channel_binding {
            cfg.channel_binding(channel_binding.into());
        }
        if let Some(load_balance_hosts) = self.load_balance_hosts {
            cfg.load_balance_hosts(load_balance_hosts.into());
        }
        Ok(cfg)
    }

    /// Returns the configured [`ManagerConfig`], or `ManagerConfig::default()` when unset.
    #[must_use]
    pub fn get_manager_config(&self) -> ManagerConfig {
        self.manager.clone().unwrap_or_default()
    }

    /// Returns the configured [`PoolConfig`], or `PoolConfig::default()` when unset.
    #[must_use]
    pub fn get_pool_config(&self) -> PoolConfig {
        self.pool.unwrap_or_default()
    }
}
/// Recycling strategy stored in [`ManagerConfig::recycling_method`].
///
/// Each variant maps to an optional SQL string via [`RecyclingMethod::query`].
#[derive(Clone, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub enum RecyclingMethod {
    /// No query: `query()` returns `None`.
    Fast,
    /// An empty query string: `query()` returns `Some("")`.
    Verified,
    /// A fixed cleanup script (the `DISCARD_SQL` constant).
    Clean,
    /// A caller-supplied SQL string, returned verbatim by `query()`.
    Custom(String),
}

impl Default for RecyclingMethod {
    /// Returns [`RecyclingMethod::Fast`].
    fn default() -> Self {
        RecyclingMethod::Fast
    }
}

impl RecyclingMethod {
    /// Cleanup script returned for [`RecyclingMethod::Clean`]. The trailing backslashes
    /// join the statements into a single line separated by single spaces.
    const DISCARD_SQL: &'static str = "\
        CLOSE ALL; \
        SET SESSION AUTHORIZATION DEFAULT; \
        RESET ALL; \
        UNLISTEN *; \
        SELECT pg_advisory_unlock_all(); \
        DISCARD TEMP; \
        DISCARD SEQUENCES;\
    ";

    /// Returns the SQL associated with this method, or `None` for [`RecyclingMethod::Fast`].
    pub fn query(&self) -> Option<&str> {
        let sql = match self {
            RecyclingMethod::Fast => return None,
            RecyclingMethod::Verified => "",
            RecyclingMethod::Clean => Self::DISCARD_SQL,
            RecyclingMethod::Custom(custom) => custom.as_str(),
        };
        Some(sql)
    }
}
/// Configuration for the connection manager.
///
/// Read from [`Config::manager`]; [`Config::get_manager_config`] falls back to
/// `ManagerConfig::default()` when that field is unset.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct ManagerConfig {
    /// Recycling strategy; defaults to [`RecyclingMethod::Fast`].
    pub recycling_method: RecyclingMethod,
}
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[non_exhaustive]
pub enum TargetSessionAttrs {
Any,
ReadWrite,
}
impl From<TargetSessionAttrs> for PgTargetSessionAttrs {
fn from(attrs: TargetSessionAttrs) -> Self {
match attrs {
TargetSessionAttrs::Any => Self::Any,
TargetSessionAttrs::ReadWrite => Self::ReadWrite,
}
}
}
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[non_exhaustive]
pub enum SslMode {
Disable,
Prefer,
Require,
}
impl From<SslMode> for PgSslMode {
fn from(mode: SslMode) -> Self {
match mode {
SslMode::Disable => Self::Disable,
SslMode::Prefer => Self::Prefer,
SslMode::Require => Self::Require,
}
}
}
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[non_exhaustive]
pub enum ChannelBinding {
Disable,
Prefer,
Require,
}
impl From<ChannelBinding> for PgChannelBinding {
fn from(cb: ChannelBinding) -> Self {
match cb {
ChannelBinding::Disable => Self::Disable,
ChannelBinding::Prefer => Self::Prefer,
ChannelBinding::Require => Self::Require,
}
}
}
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
#[non_exhaustive]
pub enum LoadBalanceHosts {
Disable,
Random,
}
impl From<LoadBalanceHosts> for PgLoadBalanceHosts {
fn from(cb: LoadBalanceHosts) -> Self {
match cb {
LoadBalanceHosts::Disable => Self::Disable,
LoadBalanceHosts::Random => Self::Random,
}
}
}