use serde::{Deserialize, Serialize};
/// How the client authenticates to the PostgreSQL server.
///
/// `Debug` is implemented by hand (not derived) so that formatting a
/// `PgAuthMethod` — or any struct that contains one — never prints the
/// plaintext password.
#[derive(Clone, Serialize, Deserialize)]
pub enum PgAuthMethod {
    /// Authenticate with a password stored inline in the configuration.
    Password { password: String },
    /// Authenticate with a secret looked up by `account`
    /// (presumably an OS keychain entry — TODO confirm against the resolver).
    Keychain { account: String },
}

impl std::fmt::Debug for PgAuthMethod {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // The secret itself is replaced with a fixed marker; only the
            // variant name is revealed.
            Self::Password { .. } => f
                .debug_struct("Password")
                .field("password", &format_args!("<redacted>"))
                .finish(),
            // The account name is a lookup key rather than the secret, so it
            // is printed exactly as the derived impl would.
            Self::Keychain { account } => f
                .debug_struct("Keychain")
                .field("account", account)
                .finish(),
        }
    }
}
/// TLS policy for the connection to the PostgreSQL server.
///
/// The variant names appear to mirror libpq's `sslmode` values
/// (`disable`, `prefer`, `require`, `verify-full`) — NOTE(review): confirm
/// how the connecting code maps each variant.
///
/// `Prefer` is the default, so configurations that omit the field
/// deserialize to `Prefer`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[derive(Serialize, Deserialize)]
pub enum PgTlsMode {
    /// Do not use TLS.
    Disable,
    /// Default policy.
    #[default]
    Prefer,
    /// Require TLS.
    Require,
    /// Require TLS with full verification.
    VerifyFull,
}
/// Reference to an SSH tunnel through which the PostgreSQL connection is made.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct SshTunnelRef {
    /// Identifier of the SSH connection to tunnel through
    /// (presumably a saved SSH connection profile — TODO confirm).
    pub ssh_connection_id: String,
    /// Host the tunnel forwards to (presumably as resolved from the SSH
    /// server's side — verify against the tunnel implementation).
    pub remote_host: String,
    /// Port on `remote_host` that the tunnel forwards to.
    pub remote_port: u16,
}
/// Everything needed to open a connection (or connection pool) to PostgreSQL.
///
/// Every field marked `#[serde(default)]` may be absent from serialized
/// input. In that case it takes its type's default: `PgTlsMode::Prefer` for
/// `tls`, and `None` for the optional settings.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct PgConfig {
    /// Server host name or address.
    pub host: String,
    /// Server TCP port.
    pub port: u16,
    /// Database name to connect to.
    pub database: String,
    /// Role/user name to authenticate as.
    pub user: String,
    /// How credentials are supplied.
    pub auth: PgAuthMethod,
    /// TLS policy; missing in input means `PgTlsMode::Prefer`.
    #[serde(default)]
    pub tls: PgTlsMode,
    /// Optional application name to report for the session.
    #[serde(default)]
    pub application_name: Option<String>,
    /// Optional SSH tunnel to connect through; `None` means a direct connection.
    #[serde(default)]
    pub ssh_tunnel: Option<SshTunnelRef>,
    /// Optional connect timeout in seconds.
    #[serde(default)]
    pub connect_timeout_secs: Option<u64>,
    /// Optional upper bound on pooled connections.
    #[serde(default)]
    pub max_pool_size: Option<u32>,
    /// Optional idle timeout for pooled connections, in seconds.
    #[serde(default)]
    pub idle_timeout_secs: Option<u64>,
    /// Optional minimum number of idle pooled connections.
    #[serde(default)]
    pub min_idle_connections: Option<u32>,
}
impl PgConfig {
    /// Builds a configuration for a PostgreSQL server on the IPv4 loopback address.
    ///
    /// The returned config:
    /// - targets `127.0.0.1:5432`;
    /// - authenticates as `user` with an empty inline password;
    /// - has TLS disabled;
    /// - uses the application name `"r-shell"` and a 10-second connect timeout.
    ///
    /// No SSH tunnel is configured, and every pool setting is left as `None`.
    pub fn local(database: impl Into<String>, user: impl Into<String>) -> Self {
        // Convert the caller-supplied values up front so the struct literal
        // below is just the fixed local defaults.
        let database: String = database.into();
        let user: String = user.into();
        let auth = PgAuthMethod::Password {
            password: String::new(),
        };

        Self {
            host: String::from("127.0.0.1"),
            port: 5432,
            database,
            user,
            auth,
            tls: PgTlsMode::Disable,
            application_name: Some(String::from("r-shell")),
            ssh_tunnel: None,
            connect_timeout_secs: Some(10),
            max_pool_size: None,
            idle_timeout_secs: None,
            min_idle_connections: None,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn local_defaults_disable_tls_and_set_app_name() {
        let cfg = PgConfig::local("mydb", "alice");
        assert_eq!(cfg.host, "127.0.0.1");
        assert_eq!(cfg.port, 5432);
        assert_eq!(cfg.database, "mydb");
        assert_eq!(cfg.user, "alice");
        assert_eq!(cfg.tls, PgTlsMode::Disable);
        assert_eq!(cfg.application_name.as_deref(), Some("r-shell"));
        assert_eq!(cfg.connect_timeout_secs, Some(10));
        assert!(cfg.ssh_tunnel.is_none());
    }

    #[test]
    fn tls_mode_default_is_prefer() {
        assert_eq!(PgTlsMode::default(), PgTlsMode::Prefer);
    }

    /// Every field must survive a JSON round trip, not just a sample of them.
    #[test]
    fn config_round_trips_through_serde() {
        let cfg = PgConfig {
            host: "db.example.com".to_string(),
            port: 5433,
            database: "app".to_string(),
            user: "svc".to_string(),
            auth: PgAuthMethod::Keychain {
                account: "postgres:profile-1".to_string(),
            },
            tls: PgTlsMode::VerifyFull,
            application_name: Some("r-shell".to_string()),
            ssh_tunnel: Some(SshTunnelRef {
                ssh_connection_id: "ssh-1".to_string(),
                remote_host: "db.internal".to_string(),
                remote_port: 5432,
            }),
            connect_timeout_secs: Some(15),
            max_pool_size: Some(10),
            idle_timeout_secs: Some(120),
            min_idle_connections: Some(0),
        };
        let json = serde_json::to_string(&cfg).expect("serialize");
        let back: PgConfig = serde_json::from_str(&json).expect("deserialize");

        assert_eq!(back.host, cfg.host);
        assert_eq!(back.port, cfg.port);
        assert_eq!(back.database, cfg.database);
        assert_eq!(back.user, cfg.user);
        // Exhaustive match so a new auth variant forces this test to be revisited.
        match &back.auth {
            PgAuthMethod::Keychain { account } => assert_eq!(account, "postgres:profile-1"),
            PgAuthMethod::Password { .. } => panic!("expected Keychain auth after round trip"),
        }
        assert_eq!(back.tls, cfg.tls);
        assert_eq!(back.application_name, cfg.application_name);

        let tunnel = back.ssh_tunnel.as_ref().expect("ssh_tunnel should survive round trip");
        assert_eq!(tunnel.ssh_connection_id, "ssh-1");
        assert_eq!(tunnel.remote_host, "db.internal");
        assert_eq!(tunnel.remote_port, 5432);

        assert_eq!(back.connect_timeout_secs, Some(15));
        assert_eq!(back.max_pool_size, Some(10));
        assert_eq!(back.idle_timeout_secs, Some(120));
        assert_eq!(back.min_idle_connections, Some(0));
    }

    /// Fields marked `#[serde(default)]` may be omitted from stored configs.
    #[test]
    fn missing_optional_fields_fall_back_to_defaults() {
        let json = r#"{
            "host": "db.example.com",
            "port": 5432,
            "database": "app",
            "user": "svc",
            "auth": { "Password": { "password": "" } }
        }"#;
        let cfg: PgConfig = serde_json::from_str(json).expect("deserialize minimal config");

        assert_eq!(cfg.tls, PgTlsMode::Prefer);
        assert!(cfg.application_name.is_none());
        assert!(cfg.ssh_tunnel.is_none());
        assert_eq!(cfg.connect_timeout_secs, None);
        assert_eq!(cfg.max_pool_size, None);
        assert_eq!(cfg.idle_timeout_secs, None);
        assert_eq!(cfg.min_idle_connections, None);
    }

    #[test]
    fn local_defaults_pool_settings_to_none() {
        let cfg = PgConfig::local("db", "u");
        assert_eq!(cfg.max_pool_size, None);
        assert_eq!(cfg.idle_timeout_secs, None);
        assert_eq!(cfg.min_idle_connections, None);
    }
}