use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
/// Root of the proxy configuration, read from and written to YAML by
/// `Config::from_file` / `Config::to_file`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
/// Listener settings for the proxy itself.
pub proxy: ProxyConfig,
/// Targets keyed by name; the key is what `Config::get_target` looks up.
pub targets: HashMap<String, TargetConfig>,
/// Health-check / retry tuning. When the key is absent from the YAML, this is
/// filled with `ConnectionManagementConfig::default()`.
#[serde(default)]
pub connection_management: ConnectionManagementConfig,
}
/// Listener settings for the proxy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProxyConfig {
/// Port the proxy listens on; `Config::validate` rejects 0.
pub listen_port: u16,
/// Host/address the proxy binds to (the default config uses `127.0.0.1`).
pub listen_host: String,
/// Connection limit, 1000 when omitted (presumably concurrent client
/// connections — confirm against the listener code).
#[serde(default = "default_max_connections")]
pub max_connections: usize,
}
/// One named target (host and port), optionally reached through SSH.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TargetConfig {
/// Target host name or address.
pub host: String,
/// Target port; `Config::validate` rejects 0.
pub port: u16,
/// Optional SSH settings. When absent, or present with `enabled: false`,
/// `Config::validate` performs no SSH checks for this target.
pub ssh: Option<SshConfig>,
/// Pool tuning; `ConnectionPoolConfig::default()` when the key is omitted.
#[serde(default)]
pub connection_pool: ConnectionPoolConfig,
}
/// SSH settings for a target. Every field except `enabled` may be omitted from
/// the YAML: `Option` fields become `None`, the rest take their serde defaults.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SshConfig {
/// Whether SSH is used for this target. Required (no serde default); when
/// true, `Config::validate` requires `host` and `user`.
pub enabled: bool,
/// SSH server to connect to (a bastion host in the default config).
pub host: Option<String>,
/// SSH login user.
pub user: Option<String>,
/// Key file used for SSH authentication (its existence is not checked here).
pub key_file: Option<PathBuf>,
/// SSH port; when `None`, the port actually used is decided by code not
/// visible in this file.
pub port: Option<u16>,
/// Timeout in seconds (30 when omitted).
#[serde(default = "default_ssh_timeout")]
pub timeout_seconds: u64,
/// Whether to reconnect automatically (true when omitted).
#[serde(default = "default_auto_reconnect")]
pub auto_reconnect: bool,
/// Base delay between reconnect attempts in seconds (30 when omitted).
#[serde(default = "default_reconnect_interval")]
pub reconnect_interval_seconds: u64,
/// Maximum number of reconnect attempts (5 when omitted).
#[serde(default = "default_max_reconnect_attempts")]
pub max_reconnect_attempts: u32,
/// Backoff factor for reconnects (2.0 when omitted); presumably multiplies
/// the interval after each failed attempt — confirm against the reconnect code.
#[serde(default = "default_reconnect_backoff_multiplier")]
pub reconnect_backoff_multiplier: f64,
}
/// Per-target connection pool tuning.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConnectionPoolConfig {
/// Maximum pool size (10 when omitted).
#[serde(default = "default_pool_size")]
pub max_size: usize,
/// Timeout in seconds (30 when omitted); presumably the pool acquire/connect
/// timeout — confirm against the pool code.
#[serde(default = "default_pool_timeout")]
pub timeout_seconds: u64,
}
/// Global health-check and retry tuning shared by all targets.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConnectionManagementConfig {
/// Seconds between health checks (30 when omitted).
#[serde(default = "default_health_check_interval")]
pub health_check_interval_seconds: u64,
/// Seconds before a health check times out (5 when omitted).
#[serde(default = "default_health_check_timeout")]
pub health_check_timeout_seconds: u64,
/// Consecutive failure threshold (3 when omitted); presumably the number of
/// failed checks before a connection is treated as unhealthy — confirm.
#[serde(default = "default_max_failures")]
pub max_consecutive_failures: u32,
/// Delay before retrying, in seconds (5 when omitted).
#[serde(default = "default_retry_delay")]
pub retry_delay_seconds: u64,
}
// Default values for omitted configuration keys. Each `default_*` helper is
// referenced by a `#[serde(default = "...")]` attribute and reused by the manual
// `Default` impls, so both code paths agree on the same constant.

const DEFAULT_MAX_CONNECTIONS: usize = 1_000;
const DEFAULT_POOL_SIZE: usize = 10;
const DEFAULT_POOL_TIMEOUT_SECS: u64 = 30;
const DEFAULT_SSH_TIMEOUT_SECS: u64 = 30;
const DEFAULT_AUTO_RECONNECT: bool = true;
const DEFAULT_RECONNECT_INTERVAL_SECS: u64 = 30;
const DEFAULT_MAX_RECONNECT_ATTEMPTS: u32 = 5;
const DEFAULT_RECONNECT_BACKOFF_MULTIPLIER: f64 = 2.0;
const DEFAULT_HEALTH_CHECK_INTERVAL_SECS: u64 = 30;
const DEFAULT_HEALTH_CHECK_TIMEOUT_SECS: u64 = 5;
const DEFAULT_MAX_FAILURES: u32 = 3;
const DEFAULT_RETRY_DELAY_SECS: u64 = 5;

/// Default for `ProxyConfig::max_connections`.
fn default_max_connections() -> usize {
    DEFAULT_MAX_CONNECTIONS
}

/// Default for `ConnectionPoolConfig::max_size`.
fn default_pool_size() -> usize {
    DEFAULT_POOL_SIZE
}

/// Default for `ConnectionPoolConfig::timeout_seconds`.
fn default_pool_timeout() -> u64 {
    DEFAULT_POOL_TIMEOUT_SECS
}

/// Default for `SshConfig::timeout_seconds`.
fn default_ssh_timeout() -> u64 {
    DEFAULT_SSH_TIMEOUT_SECS
}

/// Default for `SshConfig::auto_reconnect`.
fn default_auto_reconnect() -> bool {
    DEFAULT_AUTO_RECONNECT
}

/// Default for `SshConfig::reconnect_interval_seconds`.
fn default_reconnect_interval() -> u64 {
    DEFAULT_RECONNECT_INTERVAL_SECS
}

/// Default for `SshConfig::max_reconnect_attempts`.
fn default_max_reconnect_attempts() -> u32 {
    DEFAULT_MAX_RECONNECT_ATTEMPTS
}

/// Default for `SshConfig::reconnect_backoff_multiplier`.
fn default_reconnect_backoff_multiplier() -> f64 {
    DEFAULT_RECONNECT_BACKOFF_MULTIPLIER
}

/// Default for `ConnectionManagementConfig::health_check_interval_seconds`.
fn default_health_check_interval() -> u64 {
    DEFAULT_HEALTH_CHECK_INTERVAL_SECS
}

/// Default for `ConnectionManagementConfig::health_check_timeout_seconds`.
fn default_health_check_timeout() -> u64 {
    DEFAULT_HEALTH_CHECK_TIMEOUT_SECS
}

/// Default for `ConnectionManagementConfig::max_consecutive_failures`.
fn default_max_failures() -> u32 {
    DEFAULT_MAX_FAILURES
}

/// Default for `ConnectionManagementConfig::retry_delay_seconds`.
fn default_retry_delay() -> u64 {
    DEFAULT_RETRY_DELAY_SECS
}
impl Default for ConnectionPoolConfig {
fn default() -> Self {
Self {
max_size: default_pool_size(),
timeout_seconds: default_pool_timeout(),
}
}
}
impl Default for ConnectionManagementConfig {
fn default() -> Self {
Self {
health_check_interval_seconds: default_health_check_interval(),
health_check_timeout_seconds: default_health_check_timeout(),
max_consecutive_failures: default_max_failures(),
retry_delay_seconds: default_retry_delay(),
}
}
}
impl Default for Config {
fn default() -> Self {
let mut targets = HashMap::new();
targets.insert(
"local".to_string(),
TargetConfig {
host: "localhost".to_string(),
port: 5432,
ssh: None,
connection_pool: ConnectionPoolConfig::default(),
},
);
targets.insert(
"production".to_string(),
TargetConfig {
host: "prod-db.example.com".to_string(),
port: 5432,
ssh: Some(SshConfig {
enabled: true,
host: Some("bastion.example.com".to_string()),
user: Some("dbuser".to_string()),
key_file: Some(PathBuf::from("/path/to/cert.pem")),
port: Some(22),
timeout_seconds: default_ssh_timeout(),
auto_reconnect: default_auto_reconnect(),
reconnect_interval_seconds: default_reconnect_interval(),
max_reconnect_attempts: default_max_reconnect_attempts(),
reconnect_backoff_multiplier: default_reconnect_backoff_multiplier(),
}),
connection_pool: ConnectionPoolConfig::default(),
},
);
targets.insert(
"development".to_string(),
TargetConfig {
host: "dev-db.example.com".to_string(),
port: 5432,
ssh: Some(SshConfig {
enabled: false,
host: Some("dev-bastion.example.com".to_string()),
user: Some("devuser".to_string()),
key_file: Some(PathBuf::from("/path/to/dev-cert.pem")),
port: Some(22),
timeout_seconds: default_ssh_timeout(),
auto_reconnect: default_auto_reconnect(),
reconnect_interval_seconds: default_reconnect_interval(),
max_reconnect_attempts: default_max_reconnect_attempts(),
reconnect_backoff_multiplier: default_reconnect_backoff_multiplier(),
}),
connection_pool: ConnectionPoolConfig::default(),
},
);
Self {
proxy: ProxyConfig {
listen_port: 5433,
listen_host: "127.0.0.1".to_string(),
max_connections: default_max_connections(),
},
targets,
connection_management: ConnectionManagementConfig::default(),
}
}
}
impl Config {
pub fn from_file(path: &PathBuf) -> anyhow::Result<Self> {
let content = std::fs::read_to_string(path)?;
let config: Config = serde_yaml::from_str(&content)?;
Ok(config)
}
pub fn to_file(&self, path: &PathBuf) -> anyhow::Result<()> {
let content = serde_yaml::to_string(self)?;
std::fs::write(path, content)?;
Ok(())
}
pub fn validate(&self) -> anyhow::Result<()> {
if self.proxy.listen_port == 0 {
anyhow::bail!("Proxy listen port cannot be 0");
}
if self.targets.is_empty() {
anyhow::bail!("At least one target must be configured");
}
for (name, target) in &self.targets {
if target.port == 0 {
anyhow::bail!("Target '{}' port cannot be 0", name);
}
if let Some(ssh) = &target.ssh {
if ssh.enabled {
if ssh.host.is_none() {
anyhow::bail!(
"SSH host is required when SSH is enabled for target '{}'",
name
);
}
if ssh.user.is_none() {
anyhow::bail!(
"SSH user is required when SSH is enabled for target '{}'",
name
);
}
}
}
}
Ok(())
}
pub fn get_target(&self, target_name: &str) -> anyhow::Result<&TargetConfig> {
self.targets.get(target_name).ok_or_else(|| {
anyhow::anyhow!(
"Target '{}' not found. Available targets: {}",
target_name,
self.list_targets().join(", ")
)
})
}
pub fn list_targets(&self) -> Vec<String> {
self.targets.keys().cloned().collect()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Writes `yaml` to a temporary file and loads it through `Config::from_file`,
    /// so tests exercise the same path as real config loading.
    fn load_from_yaml(yaml: &str) -> Config {
        let mut temp_file = NamedTempFile::new().unwrap();
        temp_file.write_all(yaml.as_bytes()).unwrap();
        // Ensure everything is on disk before the file is re-opened by path.
        temp_file.flush().unwrap();
        Config::from_file(&temp_file.path().to_path_buf()).unwrap()
    }

    #[test]
    fn test_config_default() {
        let config = Config::default();
        assert_eq!(config.proxy.listen_port, 5433);
        assert_eq!(config.proxy.listen_host, "127.0.0.1");
        assert_eq!(config.proxy.max_connections, 1000);
        assert!(config.targets.contains_key("local"));
        assert!(config.targets.contains_key("production"));
        assert!(config.targets.contains_key("development"));
        let local_target = config.get_target("local").unwrap();
        assert_eq!(local_target.host, "localhost");
        assert_eq!(local_target.port, 5432);
        assert!(local_target.ssh.is_none());
    }

    #[test]
    fn test_config_validation_valid() {
        let config = Config::default();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_config_validation_invalid_port() {
        let mut config = Config::default();
        config.proxy.listen_port = 0;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_config_validation_empty_targets() {
        let mut config = Config::default();
        config.targets.clear();
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_config_validation_invalid_target_port() {
        let mut config = Config::default();
        config.targets.get_mut("local").unwrap().port = 0;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_config_validation_enabled_ssh_requires_host_and_user() {
        let mut config = Config::default();
        let ssh = config
            .targets
            .get_mut("production")
            .unwrap()
            .ssh
            .as_mut()
            .unwrap();
        ssh.host = None;
        assert!(config.validate().is_err());

        let mut config = Config::default();
        let ssh = config
            .targets
            .get_mut("production")
            .unwrap()
            .ssh
            .as_mut()
            .unwrap();
        ssh.user = None;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_config_validation_disabled_ssh_is_not_checked() {
        let mut config = Config::default();
        let ssh = config
            .targets
            .get_mut("development")
            .unwrap()
            .ssh
            .as_mut()
            .unwrap();
        assert!(!ssh.enabled);
        ssh.host = None;
        ssh.user = None;
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_config_from_yaml() {
        let config = load_from_yaml(
            r#"
proxy:
  listen_port: 5434
  listen_host: "0.0.0.0"
  max_connections: 500
targets:
  test:
    host: "test.example.com"
    port: 5432
    ssh:
      enabled: true
      host: "bastion.test.com"
      user: "testuser"
      key_file: "/tmp/test.pem"
      port: 22
connection_management:
  health_check_interval_seconds: 60
  health_check_timeout_seconds: 10
"#,
        );
        assert_eq!(config.proxy.listen_port, 5434);
        assert_eq!(config.proxy.listen_host, "0.0.0.0");
        assert_eq!(config.proxy.max_connections, 500);

        let test_target = config.get_target("test").unwrap();
        assert_eq!(test_target.host, "test.example.com");
        assert_eq!(test_target.port, 5432);

        let ssh_config = test_target.ssh.as_ref().unwrap();
        assert!(ssh_config.enabled);
        assert_eq!(ssh_config.host.as_ref().unwrap(), "bastion.test.com");
        assert_eq!(ssh_config.user.as_ref().unwrap(), "testuser");
        assert_eq!(ssh_config.key_file, Some(PathBuf::from("/tmp/test.pem")));
        assert_eq!(ssh_config.port, Some(22));

        let mgmt = &config.connection_management;
        assert_eq!(mgmt.health_check_interval_seconds, 60);
        assert_eq!(mgmt.health_check_timeout_seconds, 10);
        // Keys omitted from a partially specified section fall back to their defaults.
        assert_eq!(mgmt.max_consecutive_failures, 3);
        assert_eq!(mgmt.retry_delay_seconds, 5);
    }

    #[test]
    fn test_config_round_trip_through_file() {
        let original = Config::default();
        let temp_file = NamedTempFile::new().unwrap();
        let path = temp_file.path().to_path_buf();
        original.to_file(&path).unwrap();
        let reloaded = Config::from_file(&path).unwrap();

        assert_eq!(reloaded.proxy.listen_port, original.proxy.listen_port);
        assert_eq!(reloaded.proxy.listen_host, original.proxy.listen_host);
        assert_eq!(reloaded.proxy.max_connections, original.proxy.max_connections);

        let mut reloaded_names = reloaded.list_targets();
        reloaded_names.sort();
        let mut original_names = original.list_targets();
        original_names.sort();
        assert_eq!(reloaded_names, original_names);

        let ssh = reloaded
            .get_target("production")
            .unwrap()
            .ssh
            .as_ref()
            .unwrap();
        assert!(ssh.enabled);
        assert_eq!(ssh.host.as_deref(), Some("bastion.example.com"));
        assert_eq!(ssh.key_file, Some(PathBuf::from("/path/to/cert.pem")));
        assert_eq!(ssh.reconnect_backoff_multiplier, 2.0);
        assert!(reloaded.validate().is_ok());
    }

    #[test]
    fn test_get_target_not_found() {
        let config = Config::default();
        let err = config.get_target("nonexistent").unwrap_err();
        let message = err.to_string();
        assert!(message.contains("nonexistent"));
        assert!(message.contains("local"));
    }

    #[test]
    fn test_list_targets() {
        let config = Config::default();
        let targets = config.list_targets();
        assert_eq!(targets.len(), 3);
        assert!(targets.contains(&"local".to_string()));
        assert!(targets.contains(&"production".to_string()));
        assert!(targets.contains(&"development".to_string()));
    }

    /// Exercises the serde `default = "..."` attributes by deserializing a
    /// minimal document, rather than rebuilding the struct from the same
    /// helper functions (which could never fail).
    #[test]
    fn test_ssh_config_defaults() {
        let config = load_from_yaml(
            r#"
proxy:
  listen_port: 6000
  listen_host: "127.0.0.1"
targets:
  minimal:
    host: "db.internal"
    port: 5432
    ssh:
      enabled: true
      host: "jump.internal"
      user: "ops"
"#,
        );
        assert_eq!(config.proxy.max_connections, 1000);

        let target = config.get_target("minimal").unwrap();
        assert_eq!(target.connection_pool.max_size, 10);
        assert_eq!(target.connection_pool.timeout_seconds, 30);

        let ssh_config = target.ssh.as_ref().unwrap();
        assert!(ssh_config.key_file.is_none());
        assert!(ssh_config.port.is_none());
        assert_eq!(ssh_config.timeout_seconds, 30);
        assert!(ssh_config.auto_reconnect);
        assert_eq!(ssh_config.reconnect_interval_seconds, 30);
        assert_eq!(ssh_config.max_reconnect_attempts, 5);
        assert_eq!(ssh_config.reconnect_backoff_multiplier, 2.0);

        let mgmt = &config.connection_management;
        assert_eq!(mgmt.health_check_interval_seconds, 30);
        assert_eq!(mgmt.health_check_timeout_seconds, 5);
        assert_eq!(mgmt.max_consecutive_failures, 3);
        assert_eq!(mgmt.retry_delay_seconds, 5);
    }

    #[test]
    fn test_connection_pool_defaults() {
        let pool_config = ConnectionPoolConfig::default();
        assert_eq!(pool_config.max_size, 10);
        assert_eq!(pool_config.timeout_seconds, 30);
    }

    #[test]
    fn test_connection_management_defaults() {
        let mgmt_config = ConnectionManagementConfig::default();
        assert_eq!(mgmt_config.health_check_interval_seconds, 30);
        assert_eq!(mgmt_config.health_check_timeout_seconds, 5);
        assert_eq!(mgmt_config.max_consecutive_failures, 3);
        assert_eq!(mgmt_config.retry_delay_seconds, 5);
    }
}