use std::borrow::Cow;
use std::env::var;
use std::fmt::{Display, Write};
use std::path::{Path, PathBuf};
mod connect;
mod parse;
mod pgpass;
mod ssl_mode;
use crate::{connection::LogSettings, net::CertificateInput};
pub use ssl_mode::PgSslMode;
/// Options for connecting to a PostgreSQL server.
///
/// Built with `PgConnectOptions::new()` / `new_without_pgpass()`, which seed the fields from
/// `PG*` environment variables, and then refined with the builder-style setters.
#[derive(Debug, Clone)]
pub struct PgConnectOptions {
    /// Server host name; a value starting with `/` is treated by `fetch_socket` as a
    /// Unix socket directory.
    pub(crate) host: String,
    /// Server port; also used to form the socket file name `.s.PGSQL.<port>`.
    pub(crate) port: u16,
    /// Explicit Unix socket directory; when set, `fetch_socket` prefers it over `host`.
    pub(crate) socket: Option<PathBuf>,
    /// User name to authenticate as (defaults to `postgres` when `PGUSER` is unset).
    pub(crate) username: String,
    /// Password, if any; `None` lets `apply_pgpass` try to fill it in.
    pub(crate) password: Option<String>,
    /// Database name, if any.
    pub(crate) database: Option<String>,
    /// SSL mode (from `PGSSLMODE` or the `PgSslMode` default).
    pub(crate) ssl_mode: PgSslMode,
    /// Root certificate, given either as a file path or as inline PEM bytes.
    pub(crate) ssl_root_cert: Option<CertificateInput>,
    /// Capacity of the statement cache (the constructors default it to 100).
    pub(crate) statement_cache_capacity: usize,
    /// Application name, if any (from `PGAPPNAME` or the setter).
    pub(crate) application_name: Option<String>,
    /// Logging configuration for the connection.
    pub(crate) log_settings: LogSettings,
    /// Decimal text for `extra_float_digits` (constructors default it to `"3"`);
    /// `None` means these options do not set it.
    pub(crate) extra_float_digits: Option<Cow<'static, str>>,
    /// Server command-line options string, e.g. `-c key=value -c key2=value2`;
    /// seeded verbatim from `PGOPTIONS`.
    pub(crate) options: Option<String>,
}
impl Default for PgConnectOptions {
    /// Same as [`PgConnectOptions::new`]: environment-derived options, with a missing
    /// password looked up through `apply_pgpass`.
    fn default() -> Self {
        Self::new()
    }
}
impl PgConnectOptions {
pub fn new() -> Self {
Self::new_without_pgpass().apply_pgpass()
}
pub fn new_without_pgpass() -> Self {
let port = var("PGPORT")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(5432);
let host = var("PGHOST").ok().unwrap_or_else(|| default_host(port));
let username = var("PGUSER").ok().unwrap_or_else(|| "postgres".into());
let database = var("PGDATABASE").ok();
PgConnectOptions {
port,
host,
socket: None,
username,
password: var("PGPASSWORD").ok(),
database,
ssl_root_cert: var("PGSSLROOTCERT").ok().map(CertificateInput::from),
ssl_mode: var("PGSSLMODE")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or_default(),
statement_cache_capacity: 100,
application_name: var("PGAPPNAME").ok(),
extra_float_digits: Some("3".into()),
log_settings: Default::default(),
options: var("PGOPTIONS").ok(),
}
}
pub(crate) fn apply_pgpass(mut self) -> Self {
if self.password.is_none() {
self.password = pgpass::load_password(
&self.host,
self.port,
&self.username,
self.database.as_deref(),
);
}
self
}
pub fn host(mut self, host: &str) -> Self {
self.host = host.to_owned();
self
}
pub fn port(mut self, port: u16) -> Self {
self.port = port;
self
}
pub fn socket(mut self, path: impl AsRef<Path>) -> Self {
self.socket = Some(path.as_ref().to_path_buf());
self
}
pub fn username(mut self, username: &str) -> Self {
self.username = username.to_owned();
self
}
pub fn password(mut self, password: &str) -> Self {
self.password = Some(password.to_owned());
self
}
pub fn database(mut self, database: &str) -> Self {
self.database = Some(database.to_owned());
self
}
pub fn get_database(&self) -> Option<&str> {
self.database.as_deref()
}
pub fn ssl_mode(mut self, mode: PgSslMode) -> Self {
self.ssl_mode = mode;
self
}
pub fn ssl_root_cert(mut self, cert: impl AsRef<Path>) -> Self {
self.ssl_root_cert = Some(CertificateInput::File(cert.as_ref().to_path_buf()));
self
}
pub fn ssl_root_cert_from_pem(mut self, pem_certificate: Vec<u8>) -> Self {
self.ssl_root_cert = Some(CertificateInput::Inline(pem_certificate));
self
}
pub fn statement_cache_capacity(mut self, capacity: usize) -> Self {
self.statement_cache_capacity = capacity;
self
}
pub fn application_name(mut self, application_name: &str) -> Self {
self.application_name = Some(application_name.to_owned());
self
}
pub fn extra_float_digits(mut self, extra_float_digits: impl Into<Option<i8>>) -> Self {
self.extra_float_digits = extra_float_digits.into().map(|it| it.to_string().into());
self
}
pub fn options<K, V, I>(mut self, options: I) -> Self
where
K: Display,
V: Display,
I: IntoIterator<Item = (K, V)>,
{
let options_str = self.options.get_or_insert_with(String::new);
for (k, v) in options {
if !options_str.is_empty() {
options_str.push(' ');
}
write!(options_str, "-c {}={}", k, v).expect("failed to write an option to the string");
}
self
}
pub(crate) fn fetch_socket(&self) -> Option<String> {
match self.socket {
Some(ref socket) => {
let full_path = format!("{}/.s.PGSQL.{}", socket.display(), self.port);
Some(full_path)
}
None if self.host.starts_with('/') => {
let full_path = format!("{}/.s.PGSQL.{}", self.host, self.port);
Some(full_path)
}
_ => None,
}
}
}
/// Picks the default host for `port`.
///
/// Returns the first well-known directory that contains the Unix socket file
/// `.s.PGSQL.<port>`, or `"localhost"` when none of them does.
fn default_host(port: u16) -> String {
    let socket_name = format!(".s.PGSQL.{}", port);
    // Probed in this order; the first match wins.
    let socket_dirs: [&str; 3] = ["/var/run/postgresql", "/private/tmp", "/tmp"];
    socket_dirs
        .iter()
        .find(|dir| Path::new(**dir).join(&socket_name).exists())
        .map(|&dir| dir.to_owned())
        .unwrap_or_else(|| "localhost".to_owned())
}
#[test]
fn test_options_formatting() {
    // A single pair becomes one `-c key=value` argument.
    let single = PgConnectOptions::new().options([("geqo", "off")]);
    assert_eq!(single.options.as_deref(), Some("-c geqo=off"));

    // A later call appends to the existing string, separated by a space.
    let appended = single.options([("search_path", "sqlx")]);
    assert_eq!(
        appended.options.as_deref(),
        Some("-c geqo=off -c search_path=sqlx")
    );

    // Several pairs in one call are joined the same way.
    let multi =
        PgConnectOptions::new().options([("geqo", "off"), ("statement_timeout", "5min")]);
    assert_eq!(
        multi.options.as_deref(),
        Some("-c geqo=off -c statement_timeout=5min")
    );

    // Never calling `options` leaves the field unset.
    let untouched = PgConnectOptions::new();
    assert!(untouched.options.is_none());
}