use crate::{Client, RUNTIME};
use futures::{executor, FutureExt};
use log::error;
use std::fmt;
use std::future::Future;
use std::path::Path;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::{mpsc, Arc};
use std::time::Duration;
#[doc(inline)]
pub use tokio_postgres::config::{ChannelBinding, SslMode, TargetSessionAttrs};
use tokio_postgres::tls::{MakeTlsConnect, TlsConnect};
use tokio_postgres::{Error, Socket};
/// Shared, thread-safe callback that takes ownership of a boxed future and is
/// responsible for running it.
type Spawner = Arc<dyn Fn(Pin<Box<dyn Future<Output = ()> + Send>>) + Sync + Send>;

/// Connection configuration for the synchronous client.
///
/// Wraps a `tokio_postgres::Config` together with an optional custom spawner
/// that, when present, is used instead of the crate's shared runtime.
#[derive(Clone)]
pub struct Config {
    /// The underlying async-driver configuration all settings are forwarded to.
    config: tokio_postgres::Config,
    /// Optional custom executor hook; `None` means "use the shared runtime".
    spawner: Option<Spawner>,
}
impl fmt::Debug for Config {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("Config")
.field("config", &self.config)
.finish()
}
}
impl Default for Config {
fn default() -> Config {
Config::new()
}
}
impl Config {
    /// Creates a configuration wrapping a fresh `tokio_postgres::Config`, with
    /// no custom spawner.
    pub fn new() -> Config {
        Config {
            config: tokio_postgres::Config::new(),
            spawner: None,
        }
    }

    /// Sets the user; forwards to `tokio_postgres::Config::user`.
    pub fn user(&mut self, user: &str) -> &mut Config {
        self.config.user(user);
        self
    }

    /// Sets the password; forwards to `tokio_postgres::Config::password`.
    pub fn password<T>(&mut self, password: T) -> &mut Config
    where
        T: AsRef<[u8]>,
    {
        self.config.password(password);
        self
    }

    /// Sets the database name; forwards to `tokio_postgres::Config::dbname`.
    pub fn dbname(&mut self, dbname: &str) -> &mut Config {
        self.config.dbname(dbname);
        self
    }

    /// Sets the command-line options string; forwards to
    /// `tokio_postgres::Config::options`.
    pub fn options(&mut self, options: &str) -> &mut Config {
        self.config.options(options);
        self
    }

    /// Sets the application name; forwards to
    /// `tokio_postgres::Config::application_name`.
    pub fn application_name(&mut self, application_name: &str) -> &mut Config {
        self.config.application_name(application_name);
        self
    }

    /// Sets the SSL mode; forwards to `tokio_postgres::Config::ssl_mode`.
    pub fn ssl_mode(&mut self, ssl_mode: SslMode) -> &mut Config {
        self.config.ssl_mode(ssl_mode);
        self
    }

    /// Passes a host name to `tokio_postgres::Config::host` (see that method
    /// for how repeated calls are treated).
    pub fn host(&mut self, host: &str) -> &mut Config {
        self.config.host(host);
        self
    }

    /// Passes a Unix socket directory path to
    /// `tokio_postgres::Config::host_path`. Only available on Unix targets.
    #[cfg(unix)]
    pub fn host_path<T>(&mut self, host: T) -> &mut Config
    where
        T: AsRef<Path>,
    {
        self.config.host_path(host);
        self
    }

    /// Passes a port to `tokio_postgres::Config::port`.
    pub fn port(&mut self, port: u16) -> &mut Config {
        self.config.port(port);
        self
    }

    /// Sets the connect timeout; forwards to
    /// `tokio_postgres::Config::connect_timeout`.
    pub fn connect_timeout(&mut self, connect_timeout: Duration) -> &mut Config {
        self.config.connect_timeout(connect_timeout);
        self
    }

    /// Enables or disables TCP keepalives; forwards to
    /// `tokio_postgres::Config::keepalives`.
    pub fn keepalives(&mut self, keepalives: bool) -> &mut Config {
        self.config.keepalives(keepalives);
        self
    }

    /// Sets the keepalive idle time; forwards to
    /// `tokio_postgres::Config::keepalives_idle`.
    pub fn keepalives_idle(&mut self, keepalives_idle: Duration) -> &mut Config {
        self.config.keepalives_idle(keepalives_idle);
        self
    }

    /// Sets the required target session attributes; forwards to
    /// `tokio_postgres::Config::target_session_attrs`.
    pub fn target_session_attrs(
        &mut self,
        target_session_attrs: TargetSessionAttrs,
    ) -> &mut Config {
        self.config.target_session_attrs(target_session_attrs);
        self
    }

    /// Sets the channel binding behavior; forwards to
    /// `tokio_postgres::Config::channel_binding`.
    pub fn channel_binding(&mut self, channel_binding: ChannelBinding) -> &mut Config {
        self.config.channel_binding(channel_binding);
        self
    }

    /// Installs a custom function used to run futures instead of the crate's
    /// shared `RUNTIME`.
    ///
    /// [`Config::connect`] hands it both the connection-establishment future
    /// and the long-lived background connection future. The spawner must
    /// drive the establishment future to completion: `connect` blocks the
    /// calling thread until that future reports its result.
    pub fn spawner<F>(&mut self, spawn: F) -> &mut Config
    where
        F: Fn(Pin<Box<dyn Future<Output = ()> + Send>>) + 'static + Sync + Send,
    {
        self.spawner = Some(Arc::new(spawn));
        self
    }

    /// Opens a connection and returns a synchronous [`Client`].
    ///
    /// If a spawner was configured, the connection-establishment future is
    /// given to it, and the calling thread blocks on a channel until the
    /// future sends back its result. Otherwise the future is driven with
    /// `executor::block_on` inside the shared `RUNTIME`'s context.
    ///
    /// The background connection future is then spawned the same way, either
    /// on the spawner or via `RUNTIME.spawn`. An error that future ends with
    /// is logged through `log::error!`, not returned to the caller.
    ///
    /// # Errors
    ///
    /// Returns the `tokio_postgres` error if establishing the connection fails.
    ///
    /// # Panics
    ///
    /// Panics if a configured spawner drops the establishment future without
    /// running it to completion, since no result can ever arrive.
    pub fn connect<T>(&self, tls: T) -> Result<Client, Error>
    where
        T: MakeTlsConnect<Socket> + 'static + Send,
        T::TlsConnect: Send,
        T::Stream: Send,
        <T::TlsConnect as TlsConnect<Socket>>::Future: Send,
    {
        let (client, connection) = match &self.spawner {
            Some(spawn) => {
                let (tx, rx) = mpsc::channel();
                // The spawner takes a boxed `'static` future, so it cannot
                // borrow `self.config`; move an owned copy in instead.
                let config = self.config.clone();
                let connect = async move {
                    let r = config.connect(tls).await;
                    // `rx` is held by `connect` until a value arrives, so a
                    // failed send means the caller has already unwound.
                    let _ = tx.send(r);
                };
                spawn(Box::pin(connect));
                rx.recv()
                    .expect("spawner dropped the connection future before it completed")?
            }
            None => {
                let connect = self.config.connect(tls);
                RUNTIME.handle().enter(|| executor::block_on(connect))?
            }
        };

        // Nobody awaits the background connection, so its final error can
        // only be logged.
        let connection = connection.map(|r| {
            if let Err(e) = r {
                error!("postgres connection error: {}", e)
            }
        });
        self.spawn_background(Box::pin(connection));

        Ok(Client::from(client))
    }

    /// Runs `future` in the background: on the configured spawner if one was
    /// set, otherwise on the shared `RUNTIME`.
    fn spawn_background(&self, future: Pin<Box<dyn Future<Output = ()> + Send>>) {
        match &self.spawner {
            Some(spawn) => spawn(future),
            None => {
                // Detached on purpose; the value returned by `spawn` is unused.
                RUNTIME.spawn(future);
            }
        }
    }
}
impl FromStr for Config {
    type Err = Error;

    /// Parses `s` with `tokio_postgres::Config`'s parser and wraps the result
    /// via the `From<tokio_postgres::Config>` conversion.
    fn from_str(s: &str) -> Result<Config, Error> {
        let parsed: tokio_postgres::Config = s.parse()?;
        Ok(parsed.into())
    }
}
impl From<tokio_postgres::Config> for Config {
    /// Wraps an existing `tokio_postgres::Config` with no custom spawner.
    fn from(config: tokio_postgres::Config) -> Self {
        Self {
            spawner: None,
            config,
        }
    }
}