use json::JsonValue;
/// Connection settings for a PostgreSQL client pool.
///
/// Built either from a JSON object (`Config::new`) or from a
/// `postgresql://` / `postgres://` URL (`Config::from_url`).
///
/// `Debug` is implemented by hand instead of derived so that the password is
/// never written out by `{:?}` formatting (logs, panic messages, `dbg!`).
#[derive(Clone)]
pub struct Config {
    /// Debug flag (`debug` JSON key or `?debug=true` URL parameter).
    pub debug: bool,
    /// Role name used to authenticate.
    pub username: String,
    /// Password for `username`; redacted in `Debug` output.
    pub userpass: String,
    /// Name of the database to connect to.
    pub database: String,
    /// Server host name or address.
    pub hostname: String,
    /// Server port; the constructors replace values outside `1..=65535` with 5432.
    pub hostport: i32,
    /// Client character set name (e.g. `"UTF8"`).
    pub charset: String,
    /// Upper bound on pooled connections.
    pub pool_max: u32,
    /// SSL mode string (e.g. `"disable"`, `"require"`).
    pub sslmode: String,
}

impl std::fmt::Debug for Config {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Show whether a password is configured, but never its value.
        let userpass = if self.userpass.is_empty() {
            ""
        } else {
            "<redacted>"
        };
        f.debug_struct("Config")
            .field("debug", &self.debug)
            .field("username", &self.username)
            .field("userpass", &userpass)
            .field("database", &self.database)
            .field("hostname", &self.hostname)
            .field("hostport", &self.hostport)
            .field("charset", &self.charset)
            .field("pool_max", &self.pool_max)
            .field("sslmode", &self.sslmode)
            .finish()
    }
}
impl Config {
    /// Port used when none is given or the given one is not a valid TCP port.
    const DEFAULT_PORT: i32 = 5432;

    /// Builds a configuration from a JSON object.
    ///
    /// Missing or wrongly-typed keys fall back to defaults: `debug = false`,
    /// `username = "postgres"`, `userpass = "111111"`, `database = "postgres"`,
    /// `hostname = "localhost"`, `hostport = 5432`, `charset = "UTF8"`,
    /// `pool_max = 5`, `sslmode = "disable"`. A `hostport` outside
    /// `1..=65535` is replaced by 5432, and NUL characters are stripped from
    /// every string field.
    pub fn new(config: &JsonValue) -> Config {
        let text = |key: &str, default: &str| -> String {
            Self::sanitize(config[key].as_str().unwrap_or(default))
        };
        Self {
            debug: config["debug"].as_bool().unwrap_or(false),
            username: text("username", "postgres"),
            userpass: text("userpass", "111111"),
            database: text("database", "postgres"),
            hostname: text("hostname", "localhost"),
            hostport: Self::normalize_port(config["hostport"].as_i32()),
            charset: text("charset", "UTF8"),
            pool_max: config["pool_max"].as_u32().unwrap_or(5),
            sslmode: text("sslmode", "disable"),
        }
    }

    /// Returns the `"hostname:hostport"` address string.
    ///
    /// NOTE(review): an IPv6 hostname must already be bracketed (`[::1]`) for
    /// the result to be unambiguous.
    pub fn url(&self) -> String {
        format!("{}:{}", self.hostname, self.hostport)
    }

    /// Returns `"hostname:hostport:database"`, a key identifying one pool target.
    pub fn pool_key(&self) -> String {
        format!("{}:{}:{}", self.hostname, self.hostport, self.database)
    }

    /// Parses a `postgresql://` or `postgres://` connection URI.
    ///
    /// Shape: `postgres[ql]://[user[:password]@][host][:port][/database][?key=value&...]`.
    /// Recognised query keys are `debug` (`"true"` enables it), `charset`,
    /// `pool_max` and `sslmode`; other keys are ignored.
    ///
    /// The authority ends at the first `/`, and the last `@` inside it
    /// separates the credentials. Every component is percent-decoded after
    /// splitting. Reserved characters inside credentials (`/`, `?`, `%`, and
    /// `:` in the user name) must therefore be percent-encoded. An IPv6 host
    /// must be bracketed (`[::1]:5432`) and is kept with its brackets.
    ///
    /// Empty or absent user, host and database fall back to `"postgres"`,
    /// `"localhost"` and `"postgres"`. An absent, unparseable or out-of-range
    /// port falls back to 5432. NUL characters are stripped after decoding, so
    /// `%00` cannot smuggle one in.
    ///
    /// # Errors
    /// Returns `PgsqlError::Config` if `url` does not start with
    /// `postgresql://` or `postgres://`.
    pub fn from_url(url: &str) -> Result<Config, crate::error::PgsqlError> {
        let s = url
            .strip_prefix("postgresql://")
            .or_else(|| url.strip_prefix("postgres://"))
            .ok_or_else(|| {
                crate::error::PgsqlError::Config(
                    "URL must start with postgresql:// or postgres://".into(),
                )
            })?;
        let (main, query) = s.split_once('?').unwrap_or((s, ""));

        // Split the path off first: '@' may legitimately appear in the database
        // name, but never in a host, so only the authority is searched for it.
        let (authority, path) = main
            .split_once('/')
            .map_or((main, None), |(a, p)| (a, Some(p)));
        let (userinfo, host_port) = match authority.rfind('@') {
            Some(i) => (&authority[..i], &authority[i + 1..]),
            None => ("", authority),
        };
        let (user_raw, pass_raw) = userinfo.split_once(':').unwrap_or((userinfo, ""));
        let (host_raw, port_raw) = Self::split_host_port(host_port);

        let qparams: std::collections::BTreeMap<String, String> = query
            .split('&')
            .filter_map(|pair| pair.split_once('='))
            .map(|(k, v)| (Self::percent_decode(k), Self::percent_decode(v)))
            .collect();

        // Decode before sanitizing so an encoded NUL (%00) is removed as well.
        let decode = |raw: &str| Self::sanitize(&Self::percent_decode(raw));
        let or_default = |raw: &str, default: &str| {
            let value = decode(raw);
            if value.is_empty() {
                default.to_string()
            } else {
                value
            }
        };

        Ok(Config {
            debug: qparams.get("debug").is_some_and(|v| v == "true"),
            username: or_default(user_raw, "postgres"),
            userpass: decode(pass_raw),
            database: or_default(path.unwrap_or(""), "postgres"),
            hostname: or_default(host_raw, "localhost"),
            hostport: Self::normalize_port(port_raw.and_then(|p| p.parse().ok())),
            charset: Self::sanitize(qparams.get("charset").map_or("UTF8", |v| v.as_str())),
            pool_max: qparams
                .get("pool_max")
                .and_then(|v| v.parse().ok())
                .unwrap_or(5),
            sslmode: Self::sanitize(qparams.get("sslmode").map_or("disable", |v| v.as_str())),
        })
    }

    /// Returns `port` if it is in `1..=65535`, otherwise `DEFAULT_PORT`.
    fn normalize_port(port: Option<i32>) -> i32 {
        port.filter(|p| (1..=65535).contains(p))
            .unwrap_or(Self::DEFAULT_PORT)
    }

    /// Removes NUL characters from a connection parameter.
    ///
    /// NOTE(review): presumably because protocol strings are NUL-terminated
    /// downstream — confirm.
    fn sanitize(s: &str) -> String {
        s.replace('\0', "")
    }

    /// Splits `host[:port]`, ignoring colons inside a bracketed IPv6 literal.
    fn split_host_port(s: &str) -> (&str, Option<&str>) {
        // Only a colon after the closing ']' of "[v6addr]" can start the port.
        let search_from = if s.starts_with('[') {
            s.find(']').map_or(0, |i| i + 1)
        } else {
            0
        };
        match s[search_from..].rfind(':') {
            Some(i) => (&s[..search_from + i], Some(&s[search_from + i + 1..])),
            None => (s, None),
        }
    }

    /// Decodes `%XX` escapes (RFC 3986 §2.1).
    ///
    /// Malformed escapes are kept verbatim, and decoded bytes that are not
    /// valid UTF-8 become U+FFFD.
    fn percent_decode(s: &str) -> String {
        let bytes = s.as_bytes();
        let mut out = Vec::with_capacity(bytes.len());
        let mut i = 0;
        while i < bytes.len() {
            let escaped = match bytes.get(i..i + 3) {
                Some(&[b'%', hi, lo]) => char::from(hi)
                    .to_digit(16)
                    .zip(char::from(lo).to_digit(16)),
                _ => None,
            };
            if let Some((hi, lo)) = escaped {
                // Both digits are < 16, so the value always fits in a byte.
                out.push((hi * 16 + lo) as u8);
                i += 3;
            } else {
                out.push(bytes[i]);
                i += 1;
            }
        }
        String::from_utf8_lossy(&out).into_owned()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `Config::new` on a JSON literal and reports the resulting port.
    fn port_for(value: json::JsonValue) -> i32 {
        Config::new(&value).hostport
    }

    /// Parses a URL that the calling test expects to be accepted.
    fn parse(url: &str) -> Config {
        Config::from_url(url).expect("URL should parse")
    }

    #[test]
    fn hostport_out_of_range_falls_back_to_default() {
        assert_eq!(port_for(json::object! { "hostport": 99999 }), 5432);
    }

    #[test]
    fn hostport_zero_falls_back_to_default() {
        assert_eq!(port_for(json::object! { "hostport": 0 }), 5432);
    }

    #[test]
    fn hostport_negative_falls_back_to_default() {
        assert_eq!(port_for(json::object! { "hostport": -1 }), 5432);
    }

    #[test]
    fn hostport_valid_is_kept() {
        assert_eq!(port_for(json::object! { "hostport": 5433 }), 5433);
    }

    #[test]
    fn sslmode_defaults_to_disable() {
        let empty = json::object! {};
        assert_eq!(Config::new(&empty).sslmode, "disable");
    }

    #[test]
    fn sslmode_parsed_from_config() {
        let value = json::object! { "sslmode": "require" };
        assert_eq!(Config::new(&value).sslmode, "require");
    }

    #[test]
    fn from_url_full() {
        let cfg = parse(
            "postgresql://admin:secret@db.example.com:5433/mydb?sslmode=require&charset=UTF8",
        );
        let credentials = (cfg.username.as_str(), cfg.userpass.as_str());
        assert_eq!(credentials, ("admin", "secret"));
        let endpoint = (cfg.hostname.as_str(), cfg.hostport, cfg.database.as_str());
        assert_eq!(endpoint, ("db.example.com", 5433, "mydb"));
        let options = (cfg.sslmode.as_str(), cfg.charset.as_str());
        assert_eq!(options, ("require", "UTF8"));
    }

    #[test]
    fn from_url_minimal() {
        let cfg = parse("postgresql://localhost/testdb");
        let observed = (
            cfg.hostname.as_str(),
            cfg.hostport,
            cfg.database.as_str(),
            cfg.username.as_str(),
        );
        assert_eq!(observed, ("localhost", 5432, "testdb", "postgres"));
    }

    #[test]
    fn from_url_postgres_scheme() {
        let cfg = parse("postgres://u:p@host:1234/db");
        let observed = (cfg.username.as_str(), cfg.hostname.as_str(), cfg.hostport);
        assert_eq!(observed, ("u", "host", 1234));
    }

    #[test]
    fn from_url_invalid_scheme() {
        let outcome = Config::from_url("mysql://localhost/db");
        assert!(outcome.is_err(), "mysql:// must be rejected");
    }
}