use crate::connection::ConnectionConfig;
use crate::{Result, WireError};
use std::path::{Component, Path, PathBuf};
use zeroize::Zeroizing;
/// Splits a `host[:port]` authority fragment into a host name and a port.
///
/// Two shapes are accepted:
///
/// * `[ipv6]` or `[ipv6]:port` — the brackets are removed from the returned host.
/// * `host` or `host:port` — the text is split at the first `:`.
///
/// When no port is given the PostgreSQL default `5432` is used.
///
/// # Errors
///
/// Returns [`WireError::Config`] when a `[` is never closed, when anything
/// other than `:port` follows the closing `]`, or when the port is not a
/// valid `u16`.
fn split_host_port(host_port: &str) -> Result<(String, u16)> {
    const DEFAULT_PORT: u16 = 5432;

    if let Some(after_bracket) = host_port.strip_prefix('[') {
        let (host, rest) = after_bracket
            .split_once(']')
            .ok_or_else(|| WireError::Config("unclosed '[' in IPv6 address".into()))?;
        let port = match rest.strip_prefix(':') {
            Some(port_str) => port_str
                .parse::<u16>()
                .map_err(|_| WireError::Config("invalid port in IPv6 address".into()))?,
            None if rest.is_empty() => DEFAULT_PORT,
            // Text such as `[::1]5432` or `[::1]junk` used to be accepted and
            // silently mapped to the default port; reject it instead.
            None => {
                return Err(WireError::Config(
                    "unexpected characters after ']' in IPv6 address".into(),
                ))
            }
        };
        Ok((host.to_string(), port))
    } else if let Some((host, port_str)) = host_port.split_once(':') {
        let port = port_str
            .parse::<u16>()
            .map_err(|_| WireError::Config("invalid port".into()))?;
        Ok((host.to_string(), port))
    } else {
        Ok((host_port.to_string(), DEFAULT_PORT))
    }
}
/// Upper bound, in bytes, on the length of a Unix socket directory path.
const MAX_SOCKET_DIR_BYTES: usize = 4096;

/// Checks that `dir` is acceptable as a Unix socket directory.
///
/// The checks are purely lexical (the filesystem is not consulted) and are
/// applied in this order:
///
/// 1. the path is at most [`MAX_SOCKET_DIR_BYTES`] bytes long,
/// 2. the path is absolute,
/// 3. the path contains no `..` component.
///
/// # Errors
///
/// Returns [`WireError::Config`] describing the first check that fails.
pub fn validate_socket_dir(dir: &str) -> Result<()> {
    let byte_len = dir.len();
    let path = Path::new(dir);

    // Pick the message for the first violated rule, if any.
    let problem = if byte_len > MAX_SOCKET_DIR_BYTES {
        Some(format!(
            "Unix socket directory path is too long ({} bytes, max {MAX_SOCKET_DIR_BYTES})",
            byte_len
        ))
    } else if path.is_relative() {
        Some(format!(
            "Unix socket directory must be an absolute path (got {dir:?})"
        ))
    } else if path
        .components()
        .any(|part| matches!(part, Component::ParentDir))
    {
        Some(format!(
            "Unix socket directory must not contain '..' components (got {dir:?})"
        ))
    } else {
        None
    };

    match problem {
        Some(message) => Err(WireError::Config(message)),
        None => Ok(()),
    }
}
/// Connection parameters extracted from a PostgreSQL connection URI.
///
/// The `Debug` implementation is written by hand so that the password is
/// never included in debug output.
#[derive(Clone)]
pub struct ConnectionInfo {
    /// Whether the connection goes over TCP or a Unix domain socket.
    pub transport: TransportType,
    /// Server host name or address; `None` for Unix socket connections.
    pub host: Option<String>,
    /// Server port. For Unix sockets it is the number embedded in the socket
    /// file name.
    pub port: Option<u16>,
    /// Full path of the Unix socket file; `None` for TCP connections.
    pub unix_socket: Option<PathBuf>,
    /// Name of the database to connect to.
    pub database: String,
    /// Name of the user to connect as.
    pub user: String,
    /// Password, if the URI carried one. Wrapped in `Zeroizing`.
    pub password: Option<Zeroizing<String>>,
}

impl std::fmt::Debug for ConnectionInfo {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ConnectionInfo")
            .field("transport", &self.transport)
            .field("host", &self.host)
            .field("port", &self.port)
            .field("unix_socket", &self.unix_socket)
            .field("database", &self.database)
            .field("user", &self.user)
            // Show only whether a password is present, never its contents.
            .field("password", &self.password.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
/// Kind of transport a connection string selects.
///
/// Marked `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum TransportType {
    /// Connect over TCP to a host and port.
    Tcp,
    /// Connect through a Unix domain socket file.
    Unix,
}
/// Returns the first well-known PostgreSQL socket directory that exists.
///
/// The candidates are probed in order: `/run/postgresql`,
/// `/var/run/postgresql`, then `/tmp`. Yields `None` when none of them is a
/// directory on this machine.
fn resolve_default_socket_dir() -> Option<String> {
    const CANDIDATES: [&str; 3] = ["/run/postgresql", "/var/run/postgresql", "/tmp"];

    CANDIDATES
        .iter()
        .copied()
        .find(|candidate| Path::new(candidate).is_dir())
        .map(String::from)
}
/// Looks up `param` in a URL query string and returns its raw value.
///
/// Leading `?` characters are ignored. The string is split on `&`, and each
/// piece is split at its first `=`; pieces without `=` are skipped. The value
/// of the first piece whose key equals `param` is returned as-is (no
/// percent-decoding is performed). Returns `None` when no key matches.
pub fn parse_query_param(query_string: &str, param: &str) -> Option<String> {
    query_string
        .trim_start_matches('?')
        .split('&')
        .filter_map(|piece| piece.split_once('='))
        .find(|(name, _)| *name == param)
        .map(|(_, raw_value)| raw_value.to_string())
}
/// Builds the path of the PostgreSQL socket file for `port` inside
/// `socket_dir`, i.e. `<socket_dir>/.s.PGSQL.<port>`.
///
/// The directory text is used verbatim; it is not normalized or validated.
pub fn construct_socket_path(socket_dir: &str, port: u16) -> PathBuf {
    let port_text = port.to_string();
    let socket_file = [socket_dir, "/.s.PGSQL.", port_text.as_str()].concat();
    PathBuf::from(socket_file)
}
impl ConnectionInfo {
    /// Parses a `postgres://` or `postgresql://` connection URI.
    ///
    /// When the text after the scheme starts with `/` (empty authority, e.g.
    /// `postgres:///mydb`) the URI is treated as a Unix socket connection;
    /// otherwise it is parsed as `[user[:password]@]host[:port][/database]`.
    ///
    /// # Errors
    ///
    /// Returns [`WireError::Config`] when the scheme is missing or when the
    /// remainder cannot be parsed (invalid port, invalid socket directory,
    /// no socket directory found, ...).
    pub fn parse(s: &str) -> Result<Self> {
        let rest = s
            .strip_prefix("postgres://")
            .or_else(|| s.strip_prefix("postgresql://"))
            .ok_or_else(|| {
                WireError::Config("connection string must start with postgres://".into())
            })?;

        if rest.starts_with('/') {
            Self::parse_unix(rest)
        } else {
            Self::parse_tcp(rest)
        }
    }

    /// Parses the part after the scheme of a Unix socket URI:
    /// `/[database][?host=<socket dir>&port=<port>]`.
    ///
    /// * The database defaults to the current OS user name when the path is
    ///   empty.
    /// * `port` defaults to `5432`; a value that is not a valid `u16` is an
    ///   error rather than being silently replaced by the default.
    /// * `host` names the socket directory and must pass
    ///   [`validate_socket_dir`]; when absent, the well-known default
    ///   directories are probed.
    ///
    /// The user is always the current OS user name and no password is set.
    fn parse_unix(rest: &str) -> Result<Self> {
        // The query part keeps its leading '?'; `parse_query_param` strips it.
        let (path, query_string) = match rest.find('?') {
            Some(q_pos) => rest.split_at(q_pos),
            None => (rest, ""),
        };

        let path = path.trim_start_matches('/');
        let database = if path.is_empty() {
            whoami::username()
        } else {
            path.to_string()
        };

        let port = match parse_query_param(query_string, "port") {
            Some(raw) => raw.parse::<u16>().map_err(|_| {
                WireError::Config(format!("invalid port in query string (got {raw:?})"))
            })?,
            None => 5432,
        };

        let socket_dir = match parse_query_param(query_string, "host") {
            Some(custom_dir) => {
                validate_socket_dir(&custom_dir)?;
                custom_dir
            }
            None => resolve_default_socket_dir().ok_or_else(|| {
                WireError::Config(
                    "could not locate Unix socket directory. Set host query parameter explicitly."
                        .into(),
                )
            })?,
        };

        Ok(Self {
            transport: TransportType::Unix,
            host: None,
            port: Some(port),
            unix_socket: Some(construct_socket_path(&socket_dir, port)),
            database,
            user: whoami::username(),
            password: None,
        })
    }

    /// Parses the part after the scheme of a TCP URI:
    /// `[user[:password]@]host[:port][/database][?query]`.
    ///
    /// * Credentials are everything before the first `@`, split at the first
    ///   `:`. Without credentials the user is the current OS user name.
    /// * Any `?query` suffix after the credentials is discarded so it cannot
    ///   end up in the host or database name.
    /// * A missing or empty database defaults to the current OS user name.
    /// * Host and port are split by [`split_host_port`].
    fn parse_tcp(rest: &str) -> Result<Self> {
        let (auth, rest) = match rest.split_once('@') {
            Some((auth, after)) => (Some(auth), after),
            None => (None, rest),
        };

        let (user, password) = match auth {
            Some(auth) => match auth.split_once(':') {
                Some((user, pass)) => (
                    user.to_string(),
                    Some(Zeroizing::new(pass.to_string())),
                ),
                None => (auth.to_string(), None),
            },
            None => (whoami::username(), None),
        };

        // Strip the query string only after the credentials have been removed,
        // so a '?' inside a password is left untouched.
        let rest = rest.split_once('?').map_or(rest, |(before, _)| before);

        let (host_port, database) = match rest.split_once('/') {
            Some((hp, db)) if !db.is_empty() => (hp, db.to_string()),
            // `host/` names no database: same default as the Unix form.
            Some((hp, _)) => (hp, whoami::username()),
            None => (rest, whoami::username()),
        };

        let (host, port) = split_host_port(host_port)?;

        Ok(Self {
            transport: TransportType::Tcp,
            host: Some(host),
            port: Some(port),
            unix_socket: None,
            database,
            user,
            password,
        })
    }

    /// Builds a [`ConnectionConfig`] from the database, user and (when
    /// present) password of this value.
    pub fn to_config(&self) -> ConnectionConfig {
        let mut config = ConnectionConfig::new(&self.database, &self.user);
        if let Some(ref password) = self.password {
            config = config.password(password.as_str());
        }
        config
    }
}
#[cfg(test)]
mod tests;