use std::collections::HashMap;
use std::time::Duration;
/// TLS negotiation policy for a connection.
///
/// The variant names mirror libpq's `sslmode` values. This type itself only answers two
/// questions: should TLS be attempted ([`SslMode::should_try_ssl`]), and must it succeed
/// ([`SslMode::is_required`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SslMode {
    /// TLS is not attempted. This is the default mode.
    #[default]
    Disable,
    /// TLS is attempted but not required.
    Prefer,
    /// TLS is attempted and required.
    Require,
    /// TLS is attempted and required.
    VerifyCa,
    /// TLS is attempted and required.
    VerifyFull,
}
impl SslMode {
    /// Returns `true` for every mode except [`SslMode::Disable`].
    pub const fn should_try_ssl(self) -> bool {
        match self {
            SslMode::Disable => false,
            SslMode::Prefer | SslMode::Require | SslMode::VerifyCa | SslMode::VerifyFull => true,
        }
    }
    /// Returns `true` when the mode does not allow falling back to a plaintext connection:
    /// [`SslMode::Require`], [`SslMode::VerifyCa`] and [`SslMode::VerifyFull`].
    pub const fn is_required(self) -> bool {
        match self {
            SslMode::Disable | SslMode::Prefer => false,
            SslMode::Require | SslMode::VerifyCa | SslMode::VerifyFull => true,
        }
    }
}
/// Connection settings for a PostgreSQL server, usually assembled with the builder methods
/// on `PgConfig`.
///
/// `Debug` is implemented by hand rather than derived so that the password never ends up
/// in logs or panic messages.
#[derive(Clone)]
pub struct PgConfig {
    /// Server host name or address; combined with `port` by `socket_addr`.
    pub host: String,
    /// Server TCP port.
    pub port: u16,
    /// User name, sent as the `user` startup parameter.
    pub user: String,
    /// Optional password. Shown as `"<redacted>"` in `Debug` output.
    pub password: Option<String>,
    /// Database name, sent as the `database` startup parameter.
    pub database: String,
    /// Sent as the `application_name` startup parameter when set.
    pub application_name: Option<String>,
    /// How long to wait when connecting.
    pub connect_timeout: Duration,
    /// TLS policy; see [`SslMode`].
    pub ssl_mode: SslMode,
    /// Extra startup parameters, appended after the built-in ones by `startup_params`.
    pub options: HashMap<String, String>,
}
impl std::fmt::Debug for PgConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("PgConfig")
            .field("host", &self.host)
            .field("port", &self.port)
            .field("user", &self.user)
            // Keep presence visible (Some/None) but never print the secret itself.
            .field("password", &self.password.as_ref().map(|_| "<redacted>"))
            .field("database", &self.database)
            .field("application_name", &self.application_name)
            .field("connect_timeout", &self.connect_timeout)
            .field("ssl_mode", &self.ssl_mode)
            .field("options", &self.options)
            .finish()
    }
}
impl Default for PgConfig {
fn default() -> Self {
Self {
host: "localhost".to_string(),
port: 5432,
user: String::new(),
password: None,
database: String::new(),
application_name: None,
connect_timeout: Duration::from_secs(30),
ssl_mode: SslMode::default(),
options: HashMap::new(),
}
}
}
impl PgConfig {
    /// Creates a configuration for `user` connecting to `database` on `host`.
    ///
    /// Every other field takes its value from `PgConfig::default()`.
    pub fn new(
        host: impl Into<String>,
        user: impl Into<String>,
        database: impl Into<String>,
    ) -> Self {
        Self {
            host: host.into(),
            user: user.into(),
            database: database.into(),
            ..Default::default()
        }
    }
    /// Sets the server TCP port.
    #[must_use]
    pub fn port(mut self, port: u16) -> Self {
        self.port = port;
        self
    }
    /// Sets the password.
    #[must_use]
    pub fn password(mut self, password: impl Into<String>) -> Self {
        self.password = Some(password.into());
        self
    }
    /// Sets the `application_name` startup parameter.
    #[must_use]
    pub fn application_name(mut self, name: impl Into<String>) -> Self {
        self.application_name = Some(name.into());
        self
    }
    /// Sets the connect timeout.
    #[must_use]
    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
        self.connect_timeout = timeout;
        self
    }
    /// Sets the TLS policy.
    #[must_use]
    pub fn ssl_mode(mut self, mode: SslMode) -> Self {
        self.ssl_mode = mode;
        self
    }
    /// Adds (or replaces) an extra startup parameter.
    #[must_use]
    pub fn option(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.options.insert(key.into(), value.into());
        self
    }
    /// Builds the key/value pairs to send as startup parameters.
    ///
    /// The pairs appear in this order:
    /// 1. `user`, `database` and `client_encoding` (always `UTF8`);
    /// 2. `application_name`, if set;
    /// 3. every entry of `options`, sorted by key.
    ///
    /// The options are sorted because `HashMap` iteration order varies from one process to
    /// the next, and a stable order keeps the output reproducible.
    ///
    /// Option keys are not checked against the built-in names. A key such as `user` in
    /// `options` is emitted a second time.
    pub fn startup_params(&self) -> Vec<(String, String)> {
        // Up to four fixed entries plus one per option.
        let mut params = Vec::with_capacity(4 + self.options.len());
        params.push(("user".to_string(), self.user.clone()));
        params.push(("database".to_string(), self.database.clone()));
        params.push(("client_encoding".to_string(), "UTF8".to_string()));
        if let Some(app_name) = &self.application_name {
            params.push(("application_name".to_string(), app_name.clone()));
        }
        let mut extra: Vec<(&String, &String)> = self.options.iter().collect();
        // HashMap keys are unique, so an unstable sort yields a total, deterministic order.
        extra.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
        params.extend(extra.into_iter().map(|(k, v)| (k.clone(), v.clone())));
        params
    }
    /// Formats the server address as `host:port`.
    ///
    /// A bare IPv6 literal (a host containing `:` that is not already bracketed) is written
    /// as `[host]:port`. Without the brackets, `::1:5432` would not show where the address
    /// ends and the port begins.
    pub fn socket_addr(&self) -> String {
        if self.host.contains(':') && !self.host.starts_with('[') {
            format!("[{}]:{}", self.host, self.port)
        } else {
            format!("{}:{}", self.host, self.port)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns true when `params` contains exactly the pair `(key, value)`.
    fn has_param(params: &[(String, String)], key: &str, value: &str) -> bool {
        params.iter().any(|(k, v)| k == key && v == value)
    }

    #[test]
    fn test_config_builder() {
        let cfg = PgConfig::new("localhost", "postgres", "testdb")
            .port(5433)
            .password("secret")
            .application_name("myapp")
            .connect_timeout(Duration::from_secs(10))
            .ssl_mode(SslMode::Prefer)
            .option("timezone", "UTC");

        assert_eq!(cfg.host, "localhost");
        assert_eq!(cfg.port, 5433);
        assert_eq!(cfg.user, "postgres");
        assert_eq!(cfg.database, "testdb");
        assert_eq!(cfg.password.as_deref(), Some("secret"));
        assert_eq!(cfg.application_name.as_deref(), Some("myapp"));
        assert_eq!(cfg.connect_timeout, Duration::from_secs(10));
        assert_eq!(cfg.ssl_mode, SslMode::Prefer);
        assert_eq!(cfg.options.get("timezone").map(String::as_str), Some("UTC"));
    }

    #[test]
    fn test_startup_params() {
        let params = PgConfig::new("localhost", "postgres", "testdb")
            .application_name("myapp")
            .option("timezone", "UTC")
            .startup_params();

        let expected = [
            ("user", "postgres"),
            ("database", "testdb"),
            ("client_encoding", "UTF8"),
            ("application_name", "myapp"),
            ("timezone", "UTC"),
        ];
        for (key, value) in expected {
            assert!(has_param(&params, key, value));
        }
    }

    #[test]
    fn test_socket_addr() {
        let addr = PgConfig::new("db.example.com", "user", "db")
            .port(5433)
            .socket_addr();
        assert_eq!(addr, "db.example.com:5433");
    }

    #[test]
    fn test_ssl_mode_properties() {
        // (mode, should_try_ssl, is_required)
        let table = [
            (SslMode::Disable, false, false),
            (SslMode::Prefer, true, false),
            (SslMode::Require, true, true),
            (SslMode::VerifyCa, true, true),
            (SslMode::VerifyFull, true, true),
        ];
        for (mode, tries, required) in table {
            assert_eq!(mode.should_try_ssl(), tries);
            assert_eq!(mode.is_required(), required);
        }
    }
}