use std::path::PathBuf;
use std::time::Duration;
use crate::error::{Error, Result};
/// TLS negotiation mode, selected with the `sslmode` URI parameter.
///
/// [`Config::parse`] maps the strings `disable`, `prefer`, `require`,
/// `verify-ca` and `verify-full` onto these variants (libpq's value names).
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum SslMode {
    /// `sslmode=disable`.
    Disable,
    /// `sslmode=prefer`; this is the default.
    #[default]
    Prefer,
    /// `sslmode=require`.
    Require,
    /// `sslmode=verify-ca`.
    VerifyCa,
    /// `sslmode=verify-full`.
    VerifyFull,
}
#[derive(Debug, Clone)]
pub struct Config {
pub(crate) hosts: Vec<(String, u16)>,
pub(crate) database: String,
pub(crate) user: String,
pub(crate) password: Option<String>,
pub(crate) ssl_mode: SslMode,
pub(crate) application_name: Option<String>,
pub(crate) connect_timeout: Duration,
pub(crate) statement_timeout: Option<Duration>,
pub(crate) _keepalive: Option<Duration>,
pub(crate) _keepalive_idle: Option<Duration>,
pub(crate) target_session_attrs: TargetSessionAttrs,
pub(crate) _extra_float_digits: Option<i32>,
pub(crate) load_balance_hosts: LoadBalanceHosts,
pub(crate) ssl_client_cert: Option<std::path::PathBuf>,
pub(crate) ssl_client_key: Option<std::path::PathBuf>,
pub(crate) ssl_direct: bool,
pub(crate) channel_binding: ChannelBinding,
}
/// Channel-binding policy, selected with the `channel_binding` URI parameter.
///
/// [`Config::parse`] accepts `prefer`, `require` and `disable`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ChannelBinding {
    /// `channel_binding=prefer`; this is the default.
    #[default]
    Prefer,
    /// `channel_binding=require`.
    Require,
    /// `channel_binding=disable`.
    Disable,
}
/// Required session kind, selected with the `target_session_attrs` URI parameter.
///
/// [`Config::parse`] accepts `any`, `read-write` and `read-only`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum TargetSessionAttrs {
    /// `target_session_attrs=any`; this is the default.
    #[default]
    Any,
    /// `target_session_attrs=read-write`.
    ReadWrite,
    /// `target_session_attrs=read-only`.
    ReadOnly,
}
/// Host-selection strategy, selected with the `load_balance_hosts` URI parameter.
///
/// [`Config::parse`] accepts `disable` and `random`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum LoadBalanceHosts {
    /// `load_balance_hosts=disable`; this is the default.
    #[default]
    Disable,
    /// `load_balance_hosts=random`.
    Random,
}
impl Config {
pub fn parse(s: &str) -> Result<Self> {
let s = s.trim();
let without_scheme = s
.strip_prefix("postgres://")
.or_else(|| s.strip_prefix("postgresql://"))
.ok_or_else(|| {
Error::Config(
"connection string must start with postgres:// or postgresql://".into(),
)
})?;
let (userinfo, rest) = match without_scheme.split_once('@') {
Some((ui, rest)) => (Some(ui), rest),
None => (None, without_scheme),
};
let (user, password) = match userinfo {
Some(ui) => match ui.split_once(':') {
Some((u, p)) => (percent_decode(u)?, Some(percent_decode(p)?)),
None => (percent_decode(ui)?, None),
},
None => (String::new(), None),
};
let (hostport, db_and_params) = match rest.split_once('/') {
Some((hp, rest)) => (hp, Some(rest)),
None => (rest, None),
};
let mut hosts: Vec<(String, u16)> = Vec::new();
if hostport.is_empty() {
} else {
for entry in hostport.split(',') {
let (h, p) = match entry.rsplit_once(':') {
Some((h, p)) => {
let port: u16 = p
.parse()
.map_err(|_| Error::Config(format!("invalid port: {p}")))?;
(h.to_string(), port)
}
None => (entry.to_string(), 5432),
};
hosts.push((h, p));
}
}
let (database, params_str) = match db_and_params {
Some(dp) => match dp.split_once('?') {
Some((db, params)) => (percent_decode(db)?, Some(params.to_string())),
None => (percent_decode(dp)?, None),
},
None => (String::new(), None),
};
let mut config = ConfigBuilder::new();
for (h, p) in &hosts {
config = config.host_port(h.clone(), *p);
}
config = config.database(database).user(user);
if let Some(pw) = password {
config = config.password(pw);
}
if let Some(params) = params_str {
for param in params.split('&') {
let (key, value) = param
.split_once('=')
.ok_or_else(|| Error::Config(format!("invalid parameter: {param}")))?;
let value = percent_decode(value)?;
match key {
"sslmode" => {
config = config.ssl_mode(match value.as_str() {
"disable" => SslMode::Disable,
"prefer" => SslMode::Prefer,
"require" => SslMode::Require,
"verify-ca" => SslMode::VerifyCa,
"verify-full" => SslMode::VerifyFull,
_ => return Err(Error::Config(format!("invalid sslmode: {value}"))),
});
}
"application_name" => {
config = config.application_name(value);
}
"connect_timeout" => {
let secs: u64 = value.parse().map_err(|_| {
Error::Config(format!("invalid connect_timeout: {value}"))
})?;
config = config.connect_timeout(Duration::from_secs(secs));
}
"statement_timeout" => {
let secs: u64 = value.parse().map_err(|_| {
Error::Config(format!("invalid statement_timeout: {value}"))
})?;
config = config.statement_timeout(Duration::from_secs(secs));
}
"target_session_attrs" => {
config = config.target_session_attrs(match value.as_str() {
"any" => TargetSessionAttrs::Any,
"read-write" => TargetSessionAttrs::ReadWrite,
"read-only" => TargetSessionAttrs::ReadOnly,
_ => {
return Err(Error::Config(format!(
"invalid target_session_attrs: {value}"
)))
}
});
}
"sslcert" => {
config = config.ssl_client_cert(PathBuf::from(value));
}
"sslkey" => {
config = config.ssl_client_key(PathBuf::from(value));
}
"ssldirect" | "sslnegotiation" => {
let direct = match value.as_str() {
"true" | "direct" => true,
"false" | "postgres" => false,
_ => return Err(Error::Config(format!("invalid {key}: {value}"))),
};
config = config.ssl_direct(direct);
}
"channel_binding" => {
config = config.channel_binding(match value.as_str() {
"prefer" => ChannelBinding::Prefer,
"require" => ChannelBinding::Require,
"disable" => ChannelBinding::Disable,
_ => {
return Err(Error::Config(format!(
"invalid channel_binding: {value}"
)))
}
});
}
"load_balance_hosts" => {
config = config.load_balance_hosts(match value.as_str() {
"disable" => LoadBalanceHosts::Disable,
"random" => LoadBalanceHosts::Random,
_ => {
return Err(Error::Config(format!(
"invalid load_balance_hosts: {value}"
)))
}
});
}
"host" => {
config = config.host_port(value, 5432);
}
_ => {
}
}
}
}
Ok(config.build())
}
pub fn builder() -> ConfigBuilder {
ConfigBuilder::new()
}
pub fn host(&self) -> &str {
self.hosts.first().map_or("localhost", |(h, _)| h.as_str())
}
pub fn port(&self) -> u16 {
self.hosts.first().map_or(5432, |(_, p)| *p)
}
pub fn hosts(&self) -> &[(String, u16)] {
&self.hosts
}
pub fn load_balance_hosts(&self) -> LoadBalanceHosts {
self.load_balance_hosts
}
pub fn target_session_attrs(&self) -> TargetSessionAttrs {
self.target_session_attrs
}
pub fn database(&self) -> &str {
&self.database
}
pub fn user(&self) -> &str {
&self.user
}
pub fn password(&self) -> Option<&str> {
self.password.as_deref()
}
pub fn ssl_mode(&self) -> SslMode {
self.ssl_mode
}
pub fn application_name(&self) -> Option<&str> {
self.application_name.as_deref()
}
pub fn connect_timeout(&self) -> Duration {
self.connect_timeout
}
pub fn statement_timeout(&self) -> Option<Duration> {
self.statement_timeout
}
pub fn ssl_client_cert(&self) -> Option<&std::path::Path> {
self.ssl_client_cert.as_deref()
}
pub fn ssl_client_key(&self) -> Option<&std::path::Path> {
self.ssl_client_key.as_deref()
}
pub fn ssl_direct(&self) -> bool {
self.ssl_direct
}
pub fn channel_binding(&self) -> ChannelBinding {
self.channel_binding
}
}
/// Chainable builder for [`Config`]; obtain one with [`Config::builder`].
///
/// `Debug` is implemented by hand so that the password is never printed.
#[derive(Clone)]
#[must_use = "builder methods consume `self`; use the returned builder or call `build()`"]
pub struct ConfigBuilder {
    /// Hosts in insertion order. `None` means "use `default_port`". That is
    /// resolved in `build`, so a later `port` call only affects hosts added
    /// without an explicit port.
    hosts: Vec<(String, Option<u16>)>,
    /// Port for hosts added via `host` and for the implicit `localhost`.
    default_port: u16,
    database: String,
    user: String,
    password: Option<String>,
    ssl_mode: SslMode,
    application_name: Option<String>,
    connect_timeout: Duration,
    statement_timeout: Option<Duration>,
    keepalive: Option<Duration>,
    keepalive_idle: Option<Duration>,
    target_session_attrs: TargetSessionAttrs,
    extra_float_digits: Option<i32>,
    load_balance_hosts: LoadBalanceHosts,
    ssl_client_cert: Option<PathBuf>,
    ssl_client_key: Option<PathBuf>,
    ssl_direct: bool,
    channel_binding: ChannelBinding,
}

impl std::fmt::Debug for ConfigBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ConfigBuilder")
            .field("hosts", &self.hosts)
            .field("default_port", &self.default_port)
            .field("database", &self.database)
            .field("user", &self.user)
            // Only reveal whether a password is set, never its value.
            .field("password", &self.password.as_ref().map(|_| "<redacted>"))
            .field("ssl_mode", &self.ssl_mode)
            .field("application_name", &self.application_name)
            .field("connect_timeout", &self.connect_timeout)
            .field("statement_timeout", &self.statement_timeout)
            .field("keepalive", &self.keepalive)
            .field("keepalive_idle", &self.keepalive_idle)
            .field("target_session_attrs", &self.target_session_attrs)
            .field("extra_float_digits", &self.extra_float_digits)
            .field("load_balance_hosts", &self.load_balance_hosts)
            .field("ssl_client_cert", &self.ssl_client_cert)
            .field("ssl_client_key", &self.ssl_client_key)
            .field("ssl_direct", &self.ssl_direct)
            .field("channel_binding", &self.channel_binding)
            .finish()
    }
}

impl ConfigBuilder {
    /// Creates a builder with these defaults:
    /// - no hosts and default port 5432;
    /// - empty database and user, no password;
    /// - a 10 s connect timeout and a 60 s keepalive;
    /// - `extra_float_digits` of 3;
    /// - every enum at its `Default` variant.
    fn new() -> Self {
        Self {
            hosts: Vec::new(),
            default_port: 5432,
            database: String::new(),
            user: String::new(),
            password: None,
            ssl_mode: SslMode::default(),
            application_name: None,
            connect_timeout: Duration::from_secs(10),
            statement_timeout: None,
            keepalive: Some(Duration::from_secs(60)),
            keepalive_idle: None,
            target_session_attrs: TargetSessionAttrs::default(),
            extra_float_digits: Some(3),
            load_balance_hosts: LoadBalanceHosts::default(),
            ssl_client_cert: None,
            ssl_client_key: None,
            ssl_direct: false,
            channel_binding: ChannelBinding::default(),
        }
    }

    /// Adds a host that uses the default port.
    ///
    /// The default port is the value in effect at [`build`](Self::build)
    /// time, so a later [`port`](Self::port) call still applies.
    pub fn host(mut self, host: impl Into<String>) -> Self {
        self.hosts.push((host.into(), None));
        self
    }

    /// Adds a host with an explicit port that [`port`](Self::port) never overrides.
    pub fn host_port(mut self, host: impl Into<String>, port: u16) -> Self {
        self.hosts.push((host.into(), Some(port)));
        self
    }

    /// Sets the default port.
    ///
    /// It is used by every host added via [`host`](Self::host), whether
    /// added before or after this call, and by the implicit `localhost`.
    pub fn port(mut self, port: u16) -> Self {
        self.default_port = port;
        self
    }

    /// Sets the host-selection strategy.
    pub fn load_balance_hosts(mut self, strategy: LoadBalanceHosts) -> Self {
        self.load_balance_hosts = strategy;
        self
    }

    /// Sets the database name.
    pub fn database(mut self, database: impl Into<String>) -> Self {
        self.database = database.into();
        self
    }

    /// Sets the user name.
    pub fn user(mut self, user: impl Into<String>) -> Self {
        self.user = user.into();
        self
    }

    /// Sets the password.
    pub fn password(mut self, password: impl Into<String>) -> Self {
        self.password = Some(password.into());
        self
    }

    /// Sets the TLS negotiation mode.
    pub fn ssl_mode(mut self, ssl_mode: SslMode) -> Self {
        self.ssl_mode = ssl_mode;
        self
    }

    /// Sets the application name.
    pub fn application_name(mut self, name: impl Into<String>) -> Self {
        self.application_name = Some(name.into());
        self
    }

    /// Sets the connect timeout (default: 10 s).
    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
        self.connect_timeout = timeout;
        self
    }

    /// Sets the statement timeout (default: none).
    pub fn statement_timeout(mut self, timeout: Duration) -> Self {
        self.statement_timeout = Some(timeout);
        self
    }

    /// Sets the keepalive interval stored on the built `Config` (default: 60 s).
    pub fn keepalive(mut self, interval: Duration) -> Self {
        self.keepalive = Some(interval);
        self
    }

    /// Sets the required session kind.
    pub fn target_session_attrs(mut self, attrs: TargetSessionAttrs) -> Self {
        self.target_session_attrs = attrs;
        self
    }

    /// Sets the client certificate path.
    pub fn ssl_client_cert(mut self, path: impl Into<PathBuf>) -> Self {
        self.ssl_client_cert = Some(path.into());
        self
    }

    /// Sets the client key path.
    pub fn ssl_client_key(mut self, path: impl Into<PathBuf>) -> Self {
        self.ssl_client_key = Some(path.into());
        self
    }

    /// Enables or disables direct TLS negotiation (default: disabled).
    pub fn ssl_direct(mut self, direct: bool) -> Self {
        self.ssl_direct = direct;
        self
    }

    /// Sets the channel-binding policy.
    pub fn channel_binding(mut self, binding: ChannelBinding) -> Self {
        self.channel_binding = binding;
        self
    }

    /// Finalizes the configuration.
    ///
    /// Hosts added without a port receive the current default port. If no
    /// host was added, a single `localhost` entry on the default port is used.
    pub fn build(self) -> Config {
        let default_port = self.default_port;
        let mut hosts: Vec<(String, u16)> = self
            .hosts
            .into_iter()
            .map(|(host, port)| (host, port.unwrap_or(default_port)))
            .collect();
        if hosts.is_empty() {
            hosts.push(("localhost".to_string(), default_port));
        }
        Config {
            hosts,
            database: self.database,
            user: self.user,
            password: self.password,
            ssl_mode: self.ssl_mode,
            application_name: self.application_name,
            connect_timeout: self.connect_timeout,
            statement_timeout: self.statement_timeout,
            _keepalive: self.keepalive,
            _keepalive_idle: self.keepalive_idle,
            target_session_attrs: self.target_session_attrs,
            _extra_float_digits: self.extra_float_digits,
            load_balance_hosts: self.load_balance_hosts,
            ssl_client_cert: self.ssl_client_cert,
            ssl_client_key: self.ssl_client_key,
            ssl_direct: self.ssl_direct,
            channel_binding: self.channel_binding,
        }
    }
}
fn percent_decode(s: &str) -> Result<String> {
let mut result = String::with_capacity(s.len());
let mut chars = s.as_bytes().iter();
while let Some(&b) = chars.next() {
if b == b'%' {
let hi = chars
.next()
.ok_or_else(|| Error::Config("incomplete percent encoding".into()))?;
let lo = chars
.next()
.ok_or_else(|| Error::Config("incomplete percent encoding".into()))?;
let byte = hex_digit(*hi)? << 4 | hex_digit(*lo)?;
result.push(byte as char);
} else {
result.push(b as char);
}
}
Ok(result)
}
fn hex_digit(b: u8) -> Result<u8> {
match b {
b'0'..=b'9' => Ok(b - b'0'),
b'a'..=b'f' => Ok(b - b'a' + 10),
b'A'..=b'F' => Ok(b - b'A' + 10),
_ => Err(Error::Config(format!("invalid hex digit: {}", b as char))),
}
}