#![allow(clippy::doc_overindented_list_items)]
use crate::Client;
use crate::connection::Connection;
use log::info;
use std::fmt;
use std::net::IpAddr;
use std::path::Path;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use tokio::runtime;
#[doc(inline)]
pub use tokio_postgres::config::{
ChannelBinding, Host, LoadBalanceHosts, SslMode, SslNegotiation, TargetSessionAttrs,
};
use tokio_postgres::error::DbError;
use tokio_postgres::tls::{MakeTlsConnect, TlsConnect};
use tokio_postgres::{Error, Socket};
/// Connection configuration for the blocking client.
///
/// Every setting is forwarded to a wrapped [`tokio_postgres::Config`]. On top of that,
/// this type carries a callback that is passed to each connection created by
/// [`Config::connect`] and is invoked with notice values of type [`DbError`].
#[derive(Clone)]
pub struct Config {
    /// Callback handed to every connection opened by `connect`; receives `DbError` notices.
    notice_callback: Arc<dyn Fn(DbError) + Sync + Send>,
    /// Underlying async-client configuration that all setters and getters delegate to.
    config: tokio_postgres::Config,
}
impl fmt::Debug for Config {
    /// Formats the wrapped `tokio_postgres::Config`.
    ///
    /// The notice callback is a `dyn Fn` trait object, which has no `Debug`
    /// implementation, so it cannot be printed. `finish_non_exhaustive` appends `..`
    /// to make it visible that a field was left out, rather than suggesting the
    /// struct has only the one field shown.
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt.debug_struct("Config")
            .field("config", &self.config)
            .finish_non_exhaustive()
    }
}
impl Default for Config {
fn default() -> Config {
Config::new()
}
}
impl Config {
pub fn new() -> Config {
tokio_postgres::Config::new().into()
}
pub fn user(&mut self, user: &str) -> &mut Config {
self.config.user(user);
self
}
pub fn get_user(&self) -> Option<&str> {
self.config.get_user()
}
pub fn password<T>(&mut self, password: T) -> &mut Config
where
T: AsRef<[u8]>,
{
self.config.password(password);
self
}
pub fn get_password(&self) -> Option<&[u8]> {
self.config.get_password()
}
pub fn dbname(&mut self, dbname: &str) -> &mut Config {
self.config.dbname(dbname);
self
}
pub fn get_dbname(&self) -> Option<&str> {
self.config.get_dbname()
}
pub fn options(&mut self, options: &str) -> &mut Config {
self.config.options(options);
self
}
pub fn get_options(&self) -> Option<&str> {
self.config.get_options()
}
pub fn application_name(&mut self, application_name: &str) -> &mut Config {
self.config.application_name(application_name);
self
}
pub fn get_application_name(&self) -> Option<&str> {
self.config.get_application_name()
}
pub fn ssl_mode(&mut self, ssl_mode: SslMode) -> &mut Config {
self.config.ssl_mode(ssl_mode);
self
}
pub fn get_ssl_mode(&self) -> SslMode {
self.config.get_ssl_mode()
}
pub fn ssl_negotiation(&mut self, ssl_negotiation: SslNegotiation) -> &mut Config {
self.config.ssl_negotiation(ssl_negotiation);
self
}
pub fn get_ssl_negotiation(&self) -> SslNegotiation {
self.config.get_ssl_negotiation()
}
pub fn host(&mut self, host: &str) -> &mut Config {
self.config.host(host);
self
}
pub fn get_hosts(&self) -> &[Host] {
self.config.get_hosts()
}
pub fn get_hostaddrs(&self) -> &[IpAddr] {
self.config.get_hostaddrs()
}
#[cfg(unix)]
pub fn host_path<T>(&mut self, host: T) -> &mut Config
where
T: AsRef<Path>,
{
self.config.host_path(host);
self
}
pub fn hostaddr(&mut self, hostaddr: IpAddr) -> &mut Config {
self.config.hostaddr(hostaddr);
self
}
pub fn port(&mut self, port: u16) -> &mut Config {
self.config.port(port);
self
}
pub fn get_ports(&self) -> &[u16] {
self.config.get_ports()
}
pub fn connect_timeout(&mut self, connect_timeout: Duration) -> &mut Config {
self.config.connect_timeout(connect_timeout);
self
}
pub fn get_connect_timeout(&self) -> Option<&Duration> {
self.config.get_connect_timeout()
}
pub fn tcp_user_timeout(&mut self, tcp_user_timeout: Duration) -> &mut Config {
self.config.tcp_user_timeout(tcp_user_timeout);
self
}
pub fn get_tcp_user_timeout(&self) -> Option<&Duration> {
self.config.get_tcp_user_timeout()
}
pub fn keepalives(&mut self, keepalives: bool) -> &mut Config {
self.config.keepalives(keepalives);
self
}
pub fn get_keepalives(&self) -> bool {
self.config.get_keepalives()
}
pub fn keepalives_idle(&mut self, keepalives_idle: Duration) -> &mut Config {
self.config.keepalives_idle(keepalives_idle);
self
}
pub fn get_keepalives_idle(&self) -> Duration {
self.config.get_keepalives_idle()
}
pub fn keepalives_interval(&mut self, keepalives_interval: Duration) -> &mut Config {
self.config.keepalives_interval(keepalives_interval);
self
}
pub fn get_keepalives_interval(&self) -> Option<Duration> {
self.config.get_keepalives_interval()
}
pub fn keepalives_retries(&mut self, keepalives_retries: u32) -> &mut Config {
self.config.keepalives_retries(keepalives_retries);
self
}
pub fn get_keepalives_retries(&self) -> Option<u32> {
self.config.get_keepalives_retries()
}
pub fn target_session_attrs(
&mut self,
target_session_attrs: TargetSessionAttrs,
) -> &mut Config {
self.config.target_session_attrs(target_session_attrs);
self
}
pub fn get_target_session_attrs(&self) -> TargetSessionAttrs {
self.config.get_target_session_attrs()
}
pub fn channel_binding(&mut self, channel_binding: ChannelBinding) -> &mut Config {
self.config.channel_binding(channel_binding);
self
}
pub fn get_channel_binding(&self) -> ChannelBinding {
self.config.get_channel_binding()
}
pub fn load_balance_hosts(&mut self, load_balance_hosts: LoadBalanceHosts) -> &mut Config {
self.config.load_balance_hosts(load_balance_hosts);
self
}
pub fn get_load_balance_hosts(&self) -> LoadBalanceHosts {
self.config.get_load_balance_hosts()
}
pub fn notice_callback<F>(&mut self, f: F) -> &mut Config
where
F: Fn(DbError) + Send + Sync + 'static,
{
self.notice_callback = Arc::new(f);
self
}
pub fn connect<T>(&self, tls: T) -> Result<Client, Error>
where
T: MakeTlsConnect<Socket> + 'static + Send,
T::TlsConnect: Send,
T::Stream: Send,
<T::TlsConnect as TlsConnect<Socket>>::Future: Send,
{
let runtime = runtime::Builder::new_current_thread()
.enable_all()
.build()
.unwrap();
let (client, connection) = runtime.block_on(self.config.connect(tls))?;
let connection = Connection::new(runtime, connection, self.notice_callback.clone());
Ok(Client::new(connection, client))
}
}
impl FromStr for Config {
    type Err = Error;

    /// Parses `s` as a [`tokio_postgres::Config`] and wraps the result.
    ///
    /// The wrapped configuration gets the default notice callback.
    ///
    /// # Errors
    ///
    /// Propagates the parse error returned by `tokio_postgres`.
    fn from_str(s: &str) -> Result<Config, Error> {
        let inner: tokio_postgres::Config = s.parse()?;
        Ok(Config::from(inner))
    }
}
impl From<tokio_postgres::Config> for Config {
fn from(config: tokio_postgres::Config) -> Config {
Config {
config,
notice_callback: Arc::new(|notice| {
info!("{}: {}", notice.severity(), notice.message())
}),
}
}
}