use std::borrow::Cow;
#[cfg(unix)]
use std::ffi::OsStr;
use std::net::IpAddr;
use std::ops::Deref;
#[cfg(unix)]
use std::os::unix::ffi::OsStrExt;
#[cfg(unix)]
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::time::Duration;
use std::{fmt, iter, mem, str};
use crate::error::PgWireClientError;
/// Value of the `target_session_attrs` connection parameter.
///
/// `Config::new` starts with `TargetSessionAttrs::Any`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum TargetSessionAttrs {
/// Written `any` in a connection string.
Any,
/// Written `read-write` in a connection string.
ReadWrite,
/// Written `read-only` in a connection string.
ReadOnly,
}
/// Value of the `sslmode` connection parameter.
///
/// `Config::new` starts with `SslMode::Prefer`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum SslMode {
/// Written `disable` in a connection string.
Disable,
/// Written `prefer` in a connection string.
Prefer,
/// Written `require` in a connection string.
Require,
}
/// Value of the `sslnegotiation` connection parameter.
///
/// `Config::new` starts with `SslNegotiation::Postgres`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum SslNegotiation {
/// Written `postgres` in a connection string.
Postgres,
/// Written `direct` in a connection string.
Direct,
}
/// Value of the `channel_binding` connection parameter.
///
/// `Config::new` starts with `ChannelBinding::Prefer`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum ChannelBinding {
/// Written `disable` in a connection string.
Disable,
/// Written `prefer` in a connection string.
Prefer,
/// Written `require` in a connection string.
Require,
}
/// Value of the `load_balance_hosts` connection parameter.
///
/// `Config::new` starts with `LoadBalanceHosts::Disable`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum LoadBalanceHosts {
/// Written `disable` in a connection string.
Disable,
/// Written `random` in a connection string.
Random,
}
/// A single connection target stored in `Config`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum Host {
    /// A host given by name (or any other string not treated as a path).
    Tcp(String),
    /// A filesystem path; only available on unix targets.
    #[cfg(unix)]
    Unix(PathBuf),
}

impl Host {
    /// Returns an owned copy of the host name for `Tcp` targets and `None`
    /// for path-based targets.
    pub(crate) fn get_hostname(&self) -> Option<String> {
        match self {
            Self::Tcp(name) => Some(name.to_owned()),
            #[cfg(unix)]
            Self::Unix(_) => None,
        }
    }
}
/// Connection configuration: credentials, targets and per-connection options.
///
/// `Debug` is implemented by hand (not derived) so the password bytes are
/// never printed.
#[derive(Clone, PartialEq, Eq)]
pub struct Config {
/// User name, if one was set.
pub(crate) user: Option<String>,
/// Password as raw bytes, if one was set.
pub(crate) password: Option<Vec<u8>>,
/// Database name, if one was set.
pub(crate) dbname: Option<String>,
/// Value of the `options` parameter, if one was set.
pub(crate) options: Option<String>,
/// Value of the `application_name` parameter, if one was set.
pub(crate) application_name: Option<String>,
/// Value of the `sslmode` parameter.
pub(crate) ssl_mode: SslMode,
/// Value of the `sslnegotiation` parameter.
pub(crate) ssl_negotiation: SslNegotiation,
/// Hosts, in the order they were added.
pub(crate) host: Vec<Host>,
/// IP addresses, in the order they were added.
pub(crate) hostaddr: Vec<IpAddr>,
/// Ports, in the order they were added.
pub(crate) port: Vec<u16>,
/// Value of the `connect_timeout` parameter, if one was set.
pub(crate) connect_timeout: Option<Duration>,
/// Value of the `tcp_user_timeout` parameter, if one was set.
pub(crate) tcp_user_timeout: Option<Duration>,
/// Whether keepalives are enabled; `Config::new` sets this to `true`.
pub(crate) keepalives: bool,
/// Keepalive idle time, interval and retry count; not compiled for wasm32.
#[cfg(not(target_arch = "wasm32"))]
pub(crate) keepalive_config: KeepaliveConfig,
/// Value of the `target_session_attrs` parameter.
pub(crate) target_session_attrs: TargetSessionAttrs,
/// Value of the `channel_binding` parameter.
pub(crate) channel_binding: ChannelBinding,
/// Value of the `load_balance_hosts` parameter.
pub(crate) load_balance_hosts: LoadBalanceHosts,
}
impl Default for Config {
fn default() -> Config {
Config::new()
}
}
impl Config {
    /// Builds a configuration with no credentials, hosts, addresses or ports.
    ///
    /// Defaults: `SslMode::Prefer`, `SslNegotiation::Postgres`, keepalives
    /// enabled with a two-hour idle time, `TargetSessionAttrs::Any`,
    /// `ChannelBinding::Prefer` and `LoadBalanceHosts::Disable`.
    pub fn new() -> Config {
        Self {
            user: None,
            password: None,
            dbname: None,
            options: None,
            application_name: None,
            ssl_mode: SslMode::Prefer,
            ssl_negotiation: SslNegotiation::Postgres,
            host: Vec::new(),
            hostaddr: Vec::new(),
            port: Vec::new(),
            connect_timeout: None,
            tcp_user_timeout: None,
            keepalives: true,
            #[cfg(not(target_arch = "wasm32"))]
            keepalive_config: KeepaliveConfig {
                // Two hours, expressed in seconds.
                idle: Duration::from_secs(2 * 60 * 60),
                interval: None,
                retries: None,
            },
            target_session_attrs: TargetSessionAttrs::Any,
            channel_binding: ChannelBinding::Prefer,
            load_balance_hosts: LoadBalanceHosts::Disable,
        }
    }

    /// Sets the user name.
    pub fn user(&mut self, user: impl Into<String>) -> &mut Config {
        let user: String = user.into();
        self.user = Some(user);
        self
    }

    /// Returns the user name, if one was set.
    pub fn get_user(&self) -> Option<&str> {
        self.user.as_deref()
    }

    /// Sets the password; the bytes are copied into the configuration.
    pub fn password<T>(&mut self, password: T) -> &mut Config
    where
        T: AsRef<[u8]>,
    {
        let bytes: &[u8] = password.as_ref();
        self.password = Some(Vec::from(bytes));
        self
    }

    /// Returns the password bytes, if a password was set.
    pub fn get_password(&self) -> Option<&[u8]> {
        self.password.as_deref()
    }

    /// Sets the database name.
    pub fn dbname(&mut self, dbname: impl Into<String>) -> &mut Config {
        let dbname: String = dbname.into();
        self.dbname = Some(dbname);
        self
    }

    /// Returns the database name, if one was set.
    pub fn get_dbname(&self) -> Option<&str> {
        self.dbname.as_deref()
    }

    /// Sets the `options` string.
    pub fn options(&mut self, options: impl Into<String>) -> &mut Config {
        let options: String = options.into();
        self.options = Some(options);
        self
    }

    /// Returns the `options` string, if one was set.
    pub fn get_options(&self) -> Option<&str> {
        self.options.as_deref()
    }

    /// Sets the application name.
    pub fn application_name(&mut self, application_name: impl Into<String>) -> &mut Config {
        let application_name: String = application_name.into();
        self.application_name = Some(application_name);
        self
    }

    /// Returns the application name, if one was set.
    pub fn get_application_name(&self) -> Option<&str> {
        self.application_name.as_deref()
    }

    /// Sets the SSL mode.
    pub fn ssl_mode(&mut self, ssl_mode: SslMode) -> &mut Config {
        self.ssl_mode = ssl_mode;
        self
    }

    /// Returns the SSL mode.
    pub fn get_ssl_mode(&self) -> SslMode {
        self.ssl_mode
    }

    /// Sets the SSL negotiation method.
    pub fn ssl_negotiation(&mut self, ssl_negotiation: SslNegotiation) -> &mut Config {
        self.ssl_negotiation = ssl_negotiation;
        self
    }

    /// Returns the SSL negotiation method.
    pub fn get_ssl_negotiation(&self) -> SslNegotiation {
        self.ssl_negotiation
    }

    /// Appends a host.
    ///
    /// On unix targets a value starting with `/` is stored as a path via
    /// [`Config::host_path`]; everything else is stored as a `Host::Tcp` name.
    pub fn host(&mut self, host: impl Into<String>) -> &mut Config {
        let name: String = host.into();
        #[cfg(unix)]
        {
            if name.starts_with('/') {
                return self.host_path(name);
            }
        }
        self.host.push(Host::Tcp(name));
        self
    }

    /// Returns the hosts in insertion order.
    pub(crate) fn get_hosts(&self) -> &[Host] {
        self.host.as_slice()
    }

    /// Returns the IP addresses in insertion order.
    pub fn get_hostaddrs(&self) -> &[IpAddr] {
        Deref::deref(&self.hostaddr)
    }

    /// Appends a path-based host (unix targets only).
    #[cfg(unix)]
    pub fn host_path<T>(&mut self, host: T) -> &mut Config
    where
        T: AsRef<Path>,
    {
        let path: &Path = host.as_ref();
        self.host.push(Host::Unix(path.to_owned()));
        self
    }

    /// Appends an IP address.
    pub fn hostaddr(&mut self, hostaddr: IpAddr) -> &mut Config {
        self.hostaddr.push(hostaddr);
        self
    }

    /// Appends a port.
    pub fn port(&mut self, port: u16) -> &mut Config {
        self.port.push(port);
        self
    }

    /// Returns the ports in insertion order.
    pub fn get_ports(&self) -> &[u16] {
        self.port.as_slice()
    }

    /// Sets the connect timeout.
    pub fn connect_timeout(&mut self, connect_timeout: Duration) -> &mut Config {
        self.connect_timeout = Some(connect_timeout);
        self
    }

    /// Returns the connect timeout, if one was set.
    pub fn get_connect_timeout(&self) -> Option<&Duration> {
        self.connect_timeout.as_ref()
    }

    /// Sets the TCP user timeout.
    pub fn tcp_user_timeout(&mut self, tcp_user_timeout: Duration) -> &mut Config {
        self.tcp_user_timeout = Some(tcp_user_timeout);
        self
    }

    /// Returns the TCP user timeout, if one was set.
    pub fn get_tcp_user_timeout(&self) -> Option<&Duration> {
        self.tcp_user_timeout.as_ref()
    }

    /// Enables or disables keepalives.
    pub fn keepalives(&mut self, keepalives: bool) -> &mut Config {
        self.keepalives = keepalives;
        self
    }

    /// Returns whether keepalives are enabled.
    pub fn get_keepalives(&self) -> bool {
        self.keepalives
    }

    /// Sets the keepalive idle time.
    #[cfg(not(target_arch = "wasm32"))]
    pub fn keepalives_idle(&mut self, keepalives_idle: Duration) -> &mut Config {
        self.keepalive_config.idle = keepalives_idle;
        self
    }

    /// Returns the keepalive idle time.
    #[cfg(not(target_arch = "wasm32"))]
    pub fn get_keepalives_idle(&self) -> Duration {
        self.keepalive_config.idle
    }

    /// Sets the keepalive interval.
    #[cfg(not(target_arch = "wasm32"))]
    pub fn keepalives_interval(&mut self, keepalives_interval: Duration) -> &mut Config {
        self.keepalive_config.interval = Some(keepalives_interval);
        self
    }

    /// Returns the keepalive interval, if one was set.
    #[cfg(not(target_arch = "wasm32"))]
    pub fn get_keepalives_interval(&self) -> Option<Duration> {
        self.keepalive_config.interval
    }

    /// Sets the keepalive retry count.
    #[cfg(not(target_arch = "wasm32"))]
    pub fn keepalives_retries(&mut self, keepalives_retries: u32) -> &mut Config {
        self.keepalive_config.retries = Some(keepalives_retries);
        self
    }

    /// Returns the keepalive retry count, if one was set.
    #[cfg(not(target_arch = "wasm32"))]
    pub fn get_keepalives_retries(&self) -> Option<u32> {
        self.keepalive_config.retries
    }

    /// Sets the target session attributes.
    pub fn target_session_attrs(
        &mut self,
        target_session_attrs: TargetSessionAttrs,
    ) -> &mut Config {
        self.target_session_attrs = target_session_attrs;
        self
    }

    /// Returns the target session attributes.
    pub fn get_target_session_attrs(&self) -> TargetSessionAttrs {
        self.target_session_attrs
    }

    /// Sets the channel binding behavior.
    pub fn channel_binding(&mut self, channel_binding: ChannelBinding) -> &mut Config {
        self.channel_binding = channel_binding;
        self
    }

    /// Returns the channel binding behavior.
    pub fn get_channel_binding(&self) -> ChannelBinding {
        self.channel_binding
    }

    /// Sets the host load balancing behavior.
    pub fn load_balance_hosts(&mut self, load_balance_hosts: LoadBalanceHosts) -> &mut Config {
        self.load_balance_hosts = load_balance_hosts;
        self
    }

    /// Returns the host load balancing behavior.
    pub fn get_load_balance_hosts(&self) -> LoadBalanceHosts {
        self.load_balance_hosts
    }

    /// Parses `value` as a signed whole number of seconds for parameter `key`.
    ///
    /// Returns `Ok(None)` for zero and negative values so the caller leaves
    /// the current setting untouched.
    ///
    /// # Errors
    ///
    /// `InvalidConfig(key)` when `value` is not an `i64`.
    fn positive_secs(key: &str, value: &str) -> Result<Option<Duration>, PgWireClientError> {
        let secs = value
            .parse::<i64>()
            .map_err(|_| PgWireClientError::InvalidConfig(key.into()))?;
        // The cast cannot wrap: it only runs for strictly positive values.
        Ok((secs > 0).then(|| Duration::from_secs(secs as u64)))
    }

    /// Applies one `key`/`value` pair taken from a connection string.
    ///
    /// List-valued keys (`host`, `hostaddr`, `port`) are split on `,` and
    /// appended element by element; an empty port element means 5432.
    ///
    /// # Errors
    ///
    /// `InvalidConfig` carrying the key name when the value cannot be parsed,
    /// and `UnknownConfig` carrying the key when the key is not recognised.
    fn param(&mut self, key: &str, value: &str) -> Result<(), PgWireClientError> {
        let invalid = |name: &str| PgWireClientError::InvalidConfig(name.into());

        match key {
            "user" => self.user = Some(value.to_owned()),
            "password" => self.password = Some(value.as_bytes().to_vec()),
            "dbname" => self.dbname = Some(value.to_owned()),
            "options" => self.options = Some(value.to_owned()),
            "application_name" => self.application_name = Some(value.to_owned()),
            "sslmode" => {
                self.ssl_mode = match value {
                    "disable" => SslMode::Disable,
                    "prefer" => SslMode::Prefer,
                    "require" => SslMode::Require,
                    _ => return Err(invalid("sslmode")),
                };
            }
            "sslnegotiation" => {
                self.ssl_negotiation = match value {
                    "postgres" => SslNegotiation::Postgres,
                    "direct" => SslNegotiation::Direct,
                    _ => return Err(invalid("sslnegotiation")),
                };
            }
            "host" => {
                value.split(',').for_each(|name| {
                    self.host(name);
                });
            }
            "hostaddr" => {
                // Addresses parsed before a bad element stay appended.
                for piece in value.split(',') {
                    let addr: IpAddr = piece.parse().map_err(|_| invalid("hostaddr"))?;
                    self.hostaddr.push(addr);
                }
            }
            "port" => {
                for piece in value.split(',') {
                    let port: u16 = match piece {
                        "" => 5432,
                        digits => digits.parse().map_err(|_| invalid("port"))?,
                    };
                    self.port.push(port);
                }
            }
            "connect_timeout" => {
                if let Some(timeout) = Self::positive_secs("connect_timeout", value)? {
                    self.connect_timeout = Some(timeout);
                }
            }
            "tcp_user_timeout" => {
                if let Some(timeout) = Self::positive_secs("tcp_user_timeout", value)? {
                    self.tcp_user_timeout = Some(timeout);
                }
            }
            #[cfg(not(target_arch = "wasm32"))]
            "keepalives" => {
                let flag = value.parse::<u64>().map_err(|_| invalid("keepalives"))?;
                self.keepalives = flag != 0;
            }
            #[cfg(not(target_arch = "wasm32"))]
            "keepalives_idle" => {
                if let Some(idle) = Self::positive_secs("keepalives_idle", value)? {
                    self.keepalive_config.idle = idle;
                }
            }
            #[cfg(not(target_arch = "wasm32"))]
            "keepalives_interval" => {
                if let Some(interval) = Self::positive_secs("keepalives_interval", value)? {
                    self.keepalive_config.interval = Some(interval);
                }
            }
            #[cfg(not(target_arch = "wasm32"))]
            "keepalives_retries" => {
                let retries = value
                    .parse::<u32>()
                    .map_err(|_| invalid("keepalives_retries"))?;
                self.keepalive_config.retries = Some(retries);
            }
            "target_session_attrs" => {
                self.target_session_attrs = match value {
                    "any" => TargetSessionAttrs::Any,
                    "read-write" => TargetSessionAttrs::ReadWrite,
                    "read-only" => TargetSessionAttrs::ReadOnly,
                    _ => return Err(invalid("target_session_attrs")),
                };
            }
            "channel_binding" => {
                self.channel_binding = match value {
                    "disable" => ChannelBinding::Disable,
                    "prefer" => ChannelBinding::Prefer,
                    "require" => ChannelBinding::Require,
                    _ => return Err(invalid("channel_binding")),
                };
            }
            "load_balance_hosts" => {
                self.load_balance_hosts = match value {
                    "disable" => LoadBalanceHosts::Disable,
                    "random" => LoadBalanceHosts::Random,
                    _ => return Err(invalid("load_balance_hosts")),
                };
            }
            unknown => {
                return Err(PgWireClientError::UnknownConfig(String::from(unknown)));
            }
        }

        Ok(())
    }
}
impl FromStr for Config {
    type Err = PgWireClientError;

    /// Parses a connection string.
    ///
    /// Input starting with a recognised URL prefix is handled by `UrlParser`;
    /// anything else is handed to the key/value `Parser`.
    fn from_str(s: &str) -> Result<Config, PgWireClientError> {
        if let Some(from_url) = UrlParser::parse(s)? {
            return Ok(from_url);
        }
        Parser::parse(s)
    }
}
impl fmt::Debug for Config {
    /// Formats every configuration field, replacing a set password with `_`.
    ///
    /// The keepalive tuning fields are only printed on targets where they
    /// exist (everything except wasm32).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        /// Stand-in printed instead of the password bytes.
        struct Redaction {}

        impl fmt::Debug for Redaction {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                write!(f, "_")
            }
        }

        let mut config_dbg = &mut f.debug_struct("Config");
        config_dbg = config_dbg
            .field("user", &self.user)
            // Keep the Some/None distinction but never the secret itself.
            .field("password", &self.password.as_ref().map(|_| Redaction {}))
            .field("dbname", &self.dbname)
            .field("options", &self.options)
            .field("application_name", &self.application_name)
            .field("ssl_mode", &self.ssl_mode)
            .field("ssl_negotiation", &self.ssl_negotiation)
            .field("host", &self.host)
            .field("hostaddr", &self.hostaddr)
            .field("port", &self.port)
            .field("connect_timeout", &self.connect_timeout)
            .field("tcp_user_timeout", &self.tcp_user_timeout)
            .field("keepalives", &self.keepalives);

        #[cfg(not(target_arch = "wasm32"))]
        {
            config_dbg = config_dbg
                .field("keepalives_idle", &self.keepalive_config.idle)
                .field("keepalives_interval", &self.keepalive_config.interval)
                .field("keepalives_retries", &self.keepalive_config.retries);
        }

        config_dbg
            .field("target_session_attrs", &self.target_session_attrs)
            .field("channel_binding", &self.channel_binding)
            .field("load_balance_hosts", &self.load_balance_hosts)
            .finish()
    }
}
/// Cursor over a `key = value` style connection string.
struct Parser<'a> {
/// The full input; slices handed out by the parser borrow from it.
s: &'a str,
/// Peekable `(byte offset, char)` iterator over `s`; its position is the cursor.
it: iter::Peekable<str::CharIndices<'a>>,
}
impl<'a> Parser<'a> {
    /// Parses a whole `key = value ...` connection string into a `Config`,
    /// applying the pairs in the order they appear.
    fn parse(s: &'a str) -> Result<Config, PgWireClientError> {
        let mut parser = Self {
            s,
            it: s.char_indices().peekable(),
        };
        let mut config = Config::new();
        loop {
            match parser.parameter()? {
                Some((key, value)) => config.param(key, &value)?,
                None => return Ok(config),
            }
        }
    }

    /// Advances past any whitespace at the cursor.
    fn skip_ws(&mut self) {
        self.take_while(char::is_whitespace);
    }

    /// Consumes characters while `f` accepts them and returns the consumed
    /// slice of the input (empty when nothing matched or at end of input).
    fn take_while<F>(&mut self, f: F) -> &'a str
    where
        F: Fn(char) -> bool,
    {
        let start = match self.it.peek() {
            Some(&(offset, _)) => offset,
            None => return "",
        };
        while self.it.next_if(|&(_, c)| f(c)).is_some() {}
        // The slice ends at the first rejected character or at end of input.
        let end = self.it.peek().map_or(self.s.len(), |&(offset, _)| offset);
        &self.s[start..end]
    }

    /// Consumes the next character, which must be `target`.
    ///
    /// # Errors
    ///
    /// `InvalidConfig` when a different character is found or the input ended.
    fn eat(&mut self, target: char) -> Result<(), PgWireClientError> {
        let (i, c) = match self.it.next() {
            Some(found) => found,
            None => return Err(PgWireClientError::InvalidConfig("unexpected EOF".into())),
        };
        if c == target {
            return Ok(());
        }
        let m = format!("unexpected character at byte {i}: expected `{target}` but got `{c}`");
        Err(PgWireClientError::InvalidConfig(m))
    }

    /// Consumes the next character only when it equals `target`; reports
    /// whether it did.
    fn eat_if(&mut self, target: char) -> bool {
        self.it.next_if(|&(_, c)| c == target).is_some()
    }

    /// Reads a keyword: everything up to whitespace or `=`. Returns `None`
    /// when no keyword characters are available.
    fn keyword(&mut self) -> Option<&'a str> {
        let word = self.take_while(|c| !(c.is_whitespace() || c == '='));
        Some(word).filter(|w| !w.is_empty())
    }

    /// Reads a value, either single-quoted or bare.
    fn value(&mut self) -> Result<String, PgWireClientError> {
        if !self.eat_if('\'') {
            return self.simple_value();
        }
        let quoted = self.quoted_value()?;
        self.eat('\'')?;
        Ok(quoted)
    }

    /// Reads a bare value up to the next unescaped whitespace.
    ///
    /// A backslash makes the following character literal; a trailing
    /// backslash with nothing after it contributes nothing.
    ///
    /// # Errors
    ///
    /// `InvalidConfig("unexpected EOF")` when the resulting value is empty.
    fn simple_value(&mut self) -> Result<String, PgWireClientError> {
        let mut out = String::new();
        while let Some((_, c)) = self.it.next_if(|&(_, c)| !c.is_whitespace()) {
            let literal = if c == '\\' {
                self.it.next().map(|(_, escaped)| escaped)
            } else {
                Some(c)
            };
            out.extend(literal);
        }
        if out.is_empty() {
            Err(PgWireClientError::InvalidConfig("unexpected EOF".into()))
        } else {
            Ok(out)
        }
    }

    /// Reads the inside of a quoted value, stopping in front of the closing
    /// quote (which is left for the caller to consume).
    ///
    /// A backslash makes the following character literal, including `'`.
    ///
    /// # Errors
    ///
    /// `InvalidConfig` when the input ends before a closing quote.
    fn quoted_value(&mut self) -> Result<String, PgWireClientError> {
        let mut out = String::new();
        while let Some((_, c)) = self.it.next_if(|&(_, c)| c != '\'') {
            let literal = if c == '\\' {
                self.it.next().map(|(_, escaped)| escaped)
            } else {
                Some(c)
            };
            out.extend(literal);
        }
        // The loop stops either in front of a quote or because input ran out.
        if self.it.peek().is_some() {
            Ok(out)
        } else {
            Err(PgWireClientError::InvalidConfig(
                "unterminated quoted connection parameter value".into(),
            ))
        }
    }

    /// Reads the next `keyword = value` pair, or `None` when only whitespace
    /// (or nothing) remains.
    fn parameter(&mut self) -> Result<Option<(&'a str, String)>, PgWireClientError> {
        self.skip_ws();
        let keyword = match self.keyword() {
            Some(word) => word,
            None => return Ok(None),
        };
        self.skip_ws();
        self.eat('=')?;
        self.skip_ws();
        self.value().map(|value| Some((keyword, value)))
    }
}
/// Parser state for URL-style connection strings.
struct UrlParser<'a> {
/// The part of the URL not yet consumed; each parsing step shrinks it from the front.
s: &'a str,
/// The configuration being filled in as the URL is consumed.
config: Config,
}
impl<'a> UrlParser<'a> {
    /// Parses a `postgres://` / `postgresql://` URL.
    ///
    /// Returns `Ok(None)` when `s` does not start with either prefix, so the
    /// caller can fall back to another format.
    fn parse(s: &'a str) -> Result<Option<Config>, PgWireClientError> {
        let s = match Self::remove_url_prefix(s) {
            Some(s) => s,
            None => return Ok(None),
        };
        let mut parser = UrlParser {
            s,
            config: Config::new(),
        };
        // Each step consumes its section from the front of `parser.s`.
        parser.parse_credentials()?;
        parser.parse_host()?;
        parser.parse_path()?;
        parser.parse_params()?;
        Ok(Some(parser.config))
    }

    /// Strips a recognised URL scheme prefix, or returns `None`.
    fn remove_url_prefix(s: &str) -> Option<&str> {
        ["postgres://", "postgresql://"]
            .iter()
            .find_map(|prefix| s.strip_prefix(*prefix))
    }

    /// Splits off and returns everything before the first character in `end`,
    /// leaving the delimiter at the front of the remaining input. Returns
    /// `None` (consuming nothing) when no delimiter is present.
    fn take_until(&mut self, end: &[char]) -> Option<&'a str> {
        let pos = self.s.find(end)?;
        let (head, tail) = self.s.split_at(pos);
        self.s = tail;
        Some(head)
    }

    /// Returns all remaining input and leaves the parser empty.
    fn take_all(&mut self) -> &'a str {
        mem::take(&mut self.s)
    }

    /// Drops the first byte of the remaining input.
    ///
    /// Callers only invoke this right after confirming the input starts with
    /// a single-byte delimiter, so the slice cannot panic.
    fn eat_byte(&mut self) {
        self.s = &self.s[1..];
    }

    /// Parses the optional `user[:password]@` section.
    ///
    /// The user is percent-decoded as UTF-8; the password is percent-decoded
    /// as raw bytes.
    // NOTE(review): the `@` search covers the whole remaining URL, so an
    // unescaped `@` in the path or query is taken as the end of the
    // credentials — confirm whether that is intended.
    fn parse_credentials(&mut self) -> Result<(), PgWireClientError> {
        let creds = match self.take_until(&['@']) {
            Some(creds) => creds,
            None => return Ok(()),
        };
        self.eat_byte();

        let (user, password) = match creds.split_once(':') {
            Some((user, password)) => (user, Some(password)),
            None => (creds, None),
        };
        let user = self.decode(user)?;
        self.config.user(user);
        if let Some(password) = password {
            let password = Cow::from(percent_encoding::percent_decode(password.as_bytes()));
            self.config.password(password);
        }
        Ok(())
    }

    /// Parses the comma-separated `host[:port]` list that ends at `/`, `?`
    /// or the end of the URL.
    ///
    /// A chunk may be written as `[host]` or `[host]:port` (for example an
    /// IPv6 literal). A chunk without a port gets port 5432.
    ///
    /// # Errors
    ///
    /// `InvalidConfig("host")` for an unclosed `[` or for unexpected text
    /// after `]`; decoding and port errors are propagated.
    fn parse_host(&mut self) -> Result<(), PgWireClientError> {
        let host = match self.take_until(&['/', '?']) {
            Some(host) => host,
            None => self.take_all(),
        };
        if host.is_empty() {
            return Ok(());
        }

        for chunk in host.split(',') {
            let (host, port) = if chunk.starts_with('[') {
                let idx = match chunk.find(']') {
                    Some(idx) => idx,
                    None => return Err(PgWireClientError::InvalidConfig("host".into())),
                };
                let host = &chunk[1..idx];
                let remaining = &chunk[idx + 1..];
                let port = if let Some(port) = remaining.strip_prefix(':') {
                    Some(port)
                } else if remaining.is_empty() {
                    None
                } else {
                    return Err(PgWireClientError::InvalidConfig("host".into()));
                };
                (host, port)
            } else {
                match chunk.split_once(':') {
                    Some((host, port)) => (host, Some(port)),
                    None => (chunk, None),
                }
            };

            self.host_param(host)?;
            let port = self.decode(port.unwrap_or("5432"))?;
            self.config.param("port", &port)?;
        }
        Ok(())
    }

    /// Parses the optional `/dbname` section; an empty name is ignored.
    fn parse_path(&mut self) -> Result<(), PgWireClientError> {
        if !self.s.starts_with('/') {
            return Ok(());
        }
        self.eat_byte();

        let dbname = match self.take_until(&['?']) {
            Some(dbname) => dbname,
            None => self.take_all(),
        };
        if !dbname.is_empty() {
            self.config.dbname(self.decode(dbname)?);
        }
        Ok(())
    }

    /// Parses the optional `?key=value&key=value` section.
    ///
    /// `host` values go through `host_param`; every other pair is
    /// percent-decoded and handed to `Config::param`.
    ///
    /// # Errors
    ///
    /// `InvalidConfig("unterminated parameter")` when a key has no `=`.
    fn parse_params(&mut self) -> Result<(), PgWireClientError> {
        if !self.s.starts_with('?') {
            return Ok(());
        }
        self.eat_byte();

        while !self.s.is_empty() {
            let key = match self.take_until(&['=']) {
                Some(key) => self.decode(key)?,
                None => {
                    return Err(PgWireClientError::InvalidConfig(
                        "unterminated parameter".into(),
                    ));
                }
            };
            self.eat_byte();

            let value = match self.take_until(&['&']) {
                Some(value) => {
                    self.eat_byte();
                    value
                }
                None => self.take_all(),
            };

            if key == "host" {
                self.host_param(value)?;
            } else {
                let value = self.decode(value)?;
                self.config.param(&key, &value)?;
            }
        }
        Ok(())
    }

    /// Adds the percent-encoded host value `s` (unix targets).
    ///
    /// A decoded value starting with `/` is stored as one path and may hold
    /// bytes that are not valid UTF-8. Any other value must be UTF-8 and is
    /// handled exactly like a `host` parameter, so a comma-separated list
    /// adds several hosts — the same result the non-unix variant produces.
    #[cfg(unix)]
    fn host_param(&mut self, s: &str) -> Result<(), PgWireClientError> {
        let decoded = Cow::from(percent_encoding::percent_decode(s.as_bytes()));
        if decoded.first() == Some(&b'/') {
            self.config.host_path(OsStr::from_bytes(&decoded));
            Ok(())
        } else {
            let decoded =
                str::from_utf8(&decoded).map_err(PgWireClientError::InvalidUtf8ConfigValue)?;
            self.config.param("host", decoded)
        }
    }

    /// Adds the percent-encoded host value `s` (non-unix targets) by
    /// decoding it and treating it as a `host` parameter.
    #[cfg(not(unix))]
    fn host_param(&mut self, s: &str) -> Result<(), PgWireClientError> {
        let s = self.decode(s)?;
        self.config.param("host", &s)
    }

    /// Percent-decodes `s`, requiring the result to be valid UTF-8.
    ///
    /// # Errors
    ///
    /// `InvalidUtf8ConfigValue` when the decoded bytes are not UTF-8.
    fn decode(&self, s: &'a str) -> Result<Cow<'a, str>, PgWireClientError> {
        percent_encoding::percent_decode(s.as_bytes())
            .decode_utf8()
            .map_err(PgWireClientError::InvalidUtf8ConfigValue)
    }
}
/// Keepalive tuning values stored inside `Config`.
#[derive(Clone, PartialEq, Eq)]
pub(crate) struct KeepaliveConfig {
/// Set through `keepalives_idle`; `Config::new` uses two hours.
pub idle: Duration,
/// Set through `keepalives_interval`; `None` until configured.
pub interval: Option<Duration>,
/// Set through `keepalives_retries`; `None` until configured.
pub retries: Option<u32>,
}
#[cfg(test)]
mod tests {
    use std::net::IpAddr;

    use super::*;

    /// A key/value string populates user, dbname, hosts, addresses and ports.
    #[test]
    fn test_simple_parsing() {
        let s = "user=pass_user dbname=postgres host=host1,host2 hostaddr=127.0.0.1,127.0.0.2 port=26257";
        let config = s.parse::<Config>().unwrap();
        assert_eq!(Some("pass_user"), config.get_user());
        assert_eq!(Some("postgres"), config.get_dbname());
        assert_eq!(
            [
                Host::Tcp("host1".to_string()),
                Host::Tcp("host2".to_string())
            ],
            config.get_hosts(),
        );
        assert_eq!(
            [
                "127.0.0.1".parse::<IpAddr>().unwrap(),
                "127.0.0.2".parse::<IpAddr>().unwrap()
            ],
            config.get_hostaddrs(),
        );
        // A single port value is stored once, not once per host.
        assert_eq!([26257u16], config.get_ports());
    }

    /// A malformed `hostaddr` element makes the whole parse fail.
    #[test]
    fn test_invalid_hostaddr_parsing() {
        let s = "user=pass_user dbname=postgres host=host1 hostaddr=127.0.0 port=26257";
        assert!(s.parse::<Config>().is_err());
    }
}