use crate::{ProxyError, Result};
use serde::{Deserialize, Serialize};
use std::path::Path;
use std::time::Duration;
/// Pooling strategy selected by `PoolModeConfig::mode`.
///
/// Variants are (de)serialized in lowercase: `"session"`, `"transaction"`,
/// `"statement"`. When the field is omitted, `Session` is used.
#[derive(Debug, Default, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum PoolingMode {
    /// The default variant.
    #[default]
    Session,
    Transaction,
    Statement,
}
/// Prepared-statement setting carried by `PoolModeConfig::prepared_statement_mode`.
///
/// Serialized in lowercase (`"disable"`, `"track"`, `"named"`); defaults to
/// `Disable`. The runtime effect of each variant is implemented outside this
/// module.
#[derive(Debug, Default, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum PreparedStatementMode {
    /// The default variant.
    #[default]
    Disable,
    Track,
    Named,
}
/// Pool-mode settings, read from the `[pool_mode]` table of the proxy config.
///
/// Every field is optional in the input. Omitted numeric/string fields fall
/// back to the `default_pool_mode_*` / `default_reset_query` helpers, and
/// omitted enum fields fall back to the enum's `Default`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct PoolModeConfig {
    /// Pooling strategy (default: `PoolingMode::Session`).
    #[serde(default)]
    pub mode: PoolingMode,
    /// Maximum pool size (default: 100).
    #[serde(default = "default_pool_mode_max_size")]
    pub max_pool_size: u32,
    /// Minimum idle connections (default: 10).
    #[serde(default = "default_pool_mode_min_idle")]
    pub min_idle: u32,
    /// Idle timeout in seconds (default: 600); see `idle_timeout()`.
    #[serde(default = "default_pool_mode_idle_timeout")]
    pub idle_timeout_secs: u64,
    /// Maximum lifetime in seconds (default: 3600); see `max_lifetime()`.
    #[serde(default = "default_pool_mode_max_lifetime")]
    pub max_lifetime_secs: u64,
    /// Acquire timeout in seconds (default: 5); see `acquire_timeout()`.
    #[serde(default = "default_pool_mode_acquire_timeout")]
    pub acquire_timeout_secs: u64,
    /// Reset query text (default: `"DISCARD ALL"`).
    #[serde(default = "default_reset_query")]
    pub reset_query: String,
    /// Prepared-statement setting (default: `PreparedStatementMode::Disable`).
    #[serde(default)]
    pub prepared_statement_mode: PreparedStatementMode,
}
/// Serde default for `PoolModeConfig::max_pool_size`.
fn default_pool_mode_max_size() -> u32 {
    100_u32
}

/// Serde default for `PoolModeConfig::min_idle`.
fn default_pool_mode_min_idle() -> u32 {
    10_u32
}

/// Serde default for `PoolModeConfig::idle_timeout_secs`: ten minutes.
fn default_pool_mode_idle_timeout() -> u64 {
    10 * 60
}

/// Serde default for `PoolModeConfig::max_lifetime_secs`: one hour.
fn default_pool_mode_max_lifetime() -> u64 {
    60 * 60
}

/// Serde default for `PoolModeConfig::acquire_timeout_secs`.
fn default_pool_mode_acquire_timeout() -> u64 {
    5_u64
}

/// Serde default for `PoolModeConfig::reset_query`.
fn default_reset_query() -> String {
    String::from("DISCARD ALL")
}
impl Default for PoolModeConfig {
fn default() -> Self {
Self {
mode: PoolingMode::default(),
max_pool_size: default_pool_mode_max_size(),
min_idle: default_pool_mode_min_idle(),
idle_timeout_secs: default_pool_mode_idle_timeout(),
max_lifetime_secs: default_pool_mode_max_lifetime(),
acquire_timeout_secs: default_pool_mode_acquire_timeout(),
reset_query: default_reset_query(),
prepared_statement_mode: PreparedStatementMode::default(),
}
}
}
impl PoolModeConfig {
pub fn session_mode() -> Self {
Self {
mode: PoolingMode::Session,
prepared_statement_mode: PreparedStatementMode::Named,
..Default::default()
}
}
pub fn transaction_mode() -> Self {
Self {
mode: PoolingMode::Transaction,
prepared_statement_mode: PreparedStatementMode::Track,
..Default::default()
}
}
pub fn statement_mode() -> Self {
Self {
mode: PoolingMode::Statement,
prepared_statement_mode: PreparedStatementMode::Disable,
..Default::default()
}
}
pub fn idle_timeout(&self) -> Duration {
Duration::from_secs(self.idle_timeout_secs)
}
pub fn max_lifetime(&self) -> Duration {
Duration::from_secs(self.max_lifetime_secs)
}
pub fn acquire_timeout(&self) -> Duration {
Duration::from_secs(self.acquire_timeout_secs)
}
}
/// Top-level proxy configuration, loadable from TOML via `ProxyConfig::from_file`.
///
/// When deserializing, the following may be omitted:
/// - `pool_mode`, `write_timeout_secs` and `plugins`, which fall back to their
///   defaults;
/// - `tls`, which becomes `None`.
///
/// Every other field is required.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ProxyConfig {
    /// Client-facing listen address (default `"0.0.0.0:5432"`).
    pub listen_address: String,
    /// Admin listen address (default `"0.0.0.0:9090"`).
    pub admin_address: String,
    /// Enables the `tr` feature configured by `tr_mode`.
    pub tr_enabled: bool,
    /// Mode for the `tr` feature.
    pub tr_mode: TrMode,
    /// `[pool]` table; `validate()` checks `min_connections <= max_connections`.
    pub pool: PoolConfig,
    /// `[pool_mode]` table; optional.
    #[serde(default)]
    pub pool_mode: PoolModeConfig,
    /// `[load_balancer]` table.
    pub load_balancer: LoadBalancerConfig,
    /// `[health]` table.
    pub health: HealthConfig,
    /// Backend nodes; `validate()` requires at least one primary.
    pub nodes: Vec<NodeConfig>,
    /// Optional TLS settings.
    pub tls: Option<TlsConfig>,
    /// Write timeout in seconds (default 30); see `write_timeout()`.
    #[serde(default = "default_write_timeout_secs")]
    pub write_timeout_secs: u64,
    /// `[plugins]` table; optional, disabled by default.
    #[serde(default)]
    pub plugins: PluginToml,
}
/// Serde default for `ProxyConfig::write_timeout_secs`: 30 seconds.
fn default_write_timeout_secs() -> u64 {
    30_u64
}
impl Default for ProxyConfig {
    /// Builds a configuration with these settings:
    /// - listens on `0.0.0.0:5432`, admin on `0.0.0.0:9090`;
    /// - `tr` enabled in `TrMode::Session`;
    /// - default sub-configs, no TLS;
    /// - no backend nodes.
    ///
    /// Because there are no nodes, the result fails `validate()` until a
    /// primary node is added.
    fn default() -> Self {
        let listen_address = String::from("0.0.0.0:5432");
        let admin_address = String::from("0.0.0.0:9090");
        Self {
            listen_address,
            admin_address,
            tr_enabled: true,
            tr_mode: TrMode::Session,
            pool: Default::default(),
            pool_mode: Default::default(),
            load_balancer: Default::default(),
            health: Default::default(),
            nodes: vec![],
            tls: None,
            write_timeout_secs: default_write_timeout_secs(),
            plugins: Default::default(),
        }
    }
}
/// Plugin settings from the `[plugins]` table; every key is optional.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct PluginToml {
    /// Master switch (default `false`).
    #[serde(default)]
    pub enabled: bool,
    /// Plugin directory (default `"/etc/heliosproxy/plugins"`).
    #[serde(default = "default_plugin_dir")]
    pub plugin_dir: String,
    /// Hot-reload switch (default `false`).
    #[serde(default)]
    pub hot_reload: bool,
    /// Memory limit in megabytes, per the field name (default 64).
    #[serde(default = "default_plugin_memory_mb")]
    pub memory_limit_mb: usize,
    /// Timeout in milliseconds, per the field name (default 100).
    #[serde(default = "default_plugin_timeout_ms")]
    pub timeout_ms: u64,
    /// Maximum number of plugins (default 20).
    #[serde(default = "default_plugin_max")]
    pub max_plugins: usize,
    /// Fuel-metering switch (default `true`).
    #[serde(default = "default_true")]
    pub fuel_metering: bool,
    /// Fuel budget (default 1_000_000); presumably only consulted when
    /// `fuel_metering` is on — confirm with the plugin runtime.
    #[serde(default = "default_plugin_fuel")]
    pub fuel_limit: u64,
    /// Optional trust root. NOTE(review): semantics are not visible in this
    /// module — confirm with the plugin loader.
    #[serde(default)]
    pub trust_root: Option<String>,
}
/// Serde default for `PluginToml::plugin_dir`.
fn default_plugin_dir() -> String {
    "/etc/heliosproxy/plugins".into()
}

/// Serde default for `PluginToml::memory_limit_mb`.
fn default_plugin_memory_mb() -> usize {
    64_usize
}

/// Serde default for `PluginToml::timeout_ms`.
fn default_plugin_timeout_ms() -> u64 {
    100_u64
}

/// Serde default for `PluginToml::max_plugins`.
fn default_plugin_max() -> usize {
    20_usize
}

/// Serde helper for boolean fields whose default is `true`.
fn default_true() -> bool {
    !false
}

/// Serde default for `PluginToml::fuel_limit` (one million).
fn default_plugin_fuel() -> u64 {
    1000 * 1000
}
impl Default for PluginToml {
fn default() -> Self {
Self {
enabled: false,
plugin_dir: default_plugin_dir(),
hot_reload: false,
memory_limit_mb: default_plugin_memory_mb(),
timeout_ms: default_plugin_timeout_ms(),
max_plugins: default_plugin_max(),
fuel_metering: true,
fuel_limit: default_plugin_fuel(),
trust_root: None,
}
}
}
impl ProxyConfig {
    /// Returns `write_timeout_secs` as a [`Duration`].
    pub fn write_timeout(&self) -> Duration {
        Duration::from_secs(self.write_timeout_secs)
    }

    /// Loads a configuration from a TOML file and validates it.
    ///
    /// # Errors
    ///
    /// Returns `ProxyError::Config` in any of these cases:
    /// - the file does not exist;
    /// - the file cannot be read;
    /// - the file does not parse as a `ProxyConfig`;
    /// - the result fails [`ProxyConfig::validate`].
    pub fn from_file(path: &str) -> Result<Self> {
        let path = Path::new(path);
        // Let the read report a missing file instead of pre-checking
        // `path.exists()`. The pre-check races with the read, and `exists()`
        // also returns false on permission errors, which were then misreported
        // as "not found".
        let contents = std::fs::read_to_string(path).map_err(|e| {
            if e.kind() == std::io::ErrorKind::NotFound {
                ProxyError::Config(format!(
                    "Configuration file not found: {}",
                    path.display()
                ))
            } else {
                ProxyError::Config(format!("Failed to read config: {}", e))
            }
        })?;
        let config: Self = toml::from_str(&contents)
            .map_err(|e| ProxyError::Config(format!("Failed to parse config: {}", e)))?;
        config.validate()?;
        Ok(config)
    }

    /// Appends an enabled node parsed from `host:port` with the given role.
    ///
    /// `role` must be `"primary"`, `"standby"` or `"replica"`. The new node gets
    /// the default HTTP port, weight 100 and no name.
    ///
    /// # Errors
    ///
    /// Returns `ProxyError::Config` in any of these cases:
    /// - `host_port` has no `:`;
    /// - the host part is empty;
    /// - the port is not a valid `u16`;
    /// - the role is unknown.
    pub fn add_node(&mut self, host_port: &str, role: &str) -> Result<()> {
        // Split on the *last* colon so a bracketed IPv6 host such as
        // `[::1]:5432` keeps its internal colons in the host part.
        let invalid_format =
            || ProxyError::Config(format!("Invalid host:port format: {}", host_port));
        let (host, port_str) = host_port.rsplit_once(':').ok_or_else(invalid_format)?;
        if host.is_empty() {
            return Err(invalid_format());
        }
        let port: u16 = port_str
            .parse()
            .map_err(|_| ProxyError::Config(format!("Invalid port: {}", port_str)))?;
        let role = match role {
            "primary" => NodeRole::Primary,
            "standby" => NodeRole::Standby,
            "replica" => NodeRole::ReadReplica,
            _ => return Err(ProxyError::Config(format!("Unknown role: {}", role))),
        };
        self.nodes.push(NodeConfig {
            host: host.to_string(),
            port,
            http_port: default_http_port(),
            role,
            weight: 100,
            enabled: true,
            name: None,
        });
        Ok(())
    }

    /// Checks cross-field invariants that serde cannot express.
    ///
    /// # Errors
    ///
    /// Returns `ProxyError::Config` in any of these cases:
    /// - no nodes are configured;
    /// - no node has the primary role;
    /// - `pool.max_connections < pool.min_connections`;
    /// - `pool_mode.max_pool_size < pool_mode.min_idle`.
    pub fn validate(&self) -> Result<()> {
        if self.nodes.is_empty() {
            return Err(ProxyError::Config("No backend nodes configured".to_string()));
        }
        // NOTE(review): this accepts a config whose only primary is disabled,
        // in which case `primary_node()` returns `None`. Confirm this is
        // intended (e.g. for maintenance).
        let has_primary = self.nodes.iter().any(|n| n.role == NodeRole::Primary);
        if !has_primary {
            return Err(ProxyError::Config("No primary node configured".to_string()));
        }
        if self.pool.max_connections < self.pool.min_connections {
            return Err(ProxyError::Config(
                "max_connections must be >= min_connections".to_string(),
            ));
        }
        // Same invariant as above for the pool-mode settings. Requesting more
        // idle connections than the pool's maximum size is presumably
        // unsatisfiable.
        if self.pool_mode.max_pool_size < self.pool_mode.min_idle {
            return Err(ProxyError::Config(
                "pool_mode.max_pool_size must be >= pool_mode.min_idle".to_string(),
            ));
        }
        Ok(())
    }

    /// Returns the first node that is both a primary and enabled.
    pub fn primary_node(&self) -> Option<&NodeConfig> {
        self.nodes
            .iter()
            .find(|n| n.role == NodeRole::Primary && n.enabled)
    }

    /// Returns all enabled standby nodes, in configuration order.
    pub fn standby_nodes(&self) -> Vec<&NodeConfig> {
        self.nodes
            .iter()
            .filter(|n| n.role == NodeRole::Standby && n.enabled)
            .collect()
    }

    /// Returns all enabled nodes regardless of role, in configuration order.
    pub fn enabled_nodes(&self) -> Vec<&NodeConfig> {
        self.nodes.iter().filter(|n| n.enabled).collect()
    }
}
/// Mode for the `tr` feature toggled by `ProxyConfig::tr_enabled`.
///
/// Serialized in lowercase (`"none"`, `"session"`, `"select"`,
/// `"transaction"`); defaults to `Session`. The default is derived with
/// `#[default]` rather than a hand-written `impl Default`, matching the other
/// enums in this module.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum TrMode {
    None,
    /// The default variant.
    #[default]
    Session,
    Select,
    Transaction,
}
/// Connection pool limits and timeouts from the `[pool]` table.
///
/// All fields are required when deserializing. `PoolConfig::default()` only
/// supplies programmatic defaults.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct PoolConfig {
    /// Minimum connections; `ProxyConfig::validate` requires it not exceed
    /// `max_connections`.
    pub min_connections: usize,
    /// Maximum connections.
    pub max_connections: usize,
    /// Idle timeout in seconds; see `idle_timeout()`.
    pub idle_timeout_secs: u64,
    /// Maximum lifetime in seconds; see `max_lifetime()`.
    pub max_lifetime_secs: u64,
    /// Acquire timeout in seconds; see `acquire_timeout()`.
    pub acquire_timeout_secs: u64,
    /// Test-on-acquire switch.
    pub test_on_acquire: bool,
}

impl Default for PoolConfig {
    /// Sets 2..=100 connections, a 5 min idle timeout, a 30 min lifetime and a
    /// 30 s acquire timeout, with test-on-acquire enabled.
    fn default() -> Self {
        PoolConfig {
            min_connections: 2,
            max_connections: 100,
            idle_timeout_secs: 5 * 60,
            max_lifetime_secs: 30 * 60,
            acquire_timeout_secs: 30,
            test_on_acquire: true,
        }
    }
}
impl PoolConfig {
    /// `idle_timeout_secs` as a [`Duration`].
    pub fn idle_timeout(&self) -> Duration {
        let secs = self.idle_timeout_secs;
        Duration::from_secs(secs)
    }

    /// `max_lifetime_secs` as a [`Duration`].
    pub fn max_lifetime(&self) -> Duration {
        let secs = self.max_lifetime_secs;
        Duration::from_secs(secs)
    }

    /// `acquire_timeout_secs` as a [`Duration`].
    pub fn acquire_timeout(&self) -> Duration {
        let secs = self.acquire_timeout_secs;
        Duration::from_secs(secs)
    }
}
/// Load-balancer settings from the `[load_balancer]` table (all fields required).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct LoadBalancerConfig {
    /// Strategy for reads (default `Strategy::RoundRobin`).
    pub read_strategy: Strategy,
    /// Read/write split switch (default `true`).
    pub read_write_split: bool,
    /// Latency threshold in milliseconds (default 100); presumably consulted by
    /// `Strategy::LatencyBased` — confirm.
    pub latency_threshold_ms: u64,
}

impl Default for LoadBalancerConfig {
    /// Sets round-robin reads, read/write split on and a 100 ms latency threshold.
    fn default() -> Self {
        LoadBalancerConfig {
            latency_threshold_ms: 100,
            read_write_split: true,
            read_strategy: Strategy::RoundRobin,
        }
    }
}
/// Read load-balancing strategy.
///
/// Serialized in snake_case: `"round_robin"`, `"weighted_round_robin"`,
/// `"least_connections"`, `"latency_based"`, `"random"`.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Strategy {
    RoundRobin,
    WeightedRoundRobin,
    LeastConnections,
    LatencyBased,
    Random,
}
/// Health-check settings from the `[health]` table (all fields required).
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct HealthConfig {
    /// Interval between checks, in seconds; see `check_interval()`.
    pub check_interval_secs: u64,
    /// Per-check timeout, in seconds; see `check_timeout()`.
    pub check_timeout_secs: u64,
    /// Failure threshold count.
    pub failure_threshold: u32,
    /// Success threshold count.
    pub success_threshold: u32,
    /// Query text used for the check.
    pub check_query: String,
}

impl Default for HealthConfig {
    /// Sets checks every 5 s with a 3 s timeout, thresholds of 3 failures and
    /// 2 successes, and `SELECT 1` as the query.
    fn default() -> Self {
        let check_query = String::from("SELECT 1");
        HealthConfig {
            check_interval_secs: 5,
            check_timeout_secs: 3,
            failure_threshold: 3,
            success_threshold: 2,
            check_query,
        }
    }
}
impl HealthConfig {
    /// `check_interval_secs` as a [`Duration`].
    pub fn check_interval(&self) -> Duration {
        let secs = self.check_interval_secs;
        Duration::from_secs(secs)
    }

    /// `check_timeout_secs` as a [`Duration`].
    pub fn check_timeout(&self) -> Duration {
        let secs = self.check_timeout_secs;
        Duration::from_secs(secs)
    }
}
/// One backend node, as listed under `nodes` in the proxy configuration.
///
/// Only `host`, `port` and `role` are required when deserializing. The
/// optional fields default as follows:
/// - `http_port` to 8080;
/// - `weight` to 100;
/// - `enabled` to `true`;
/// - `name` to `None`.
///
/// These match the values `ProxyConfig::add_node` uses, so hand-written node
/// entries and programmatically added nodes agree.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeConfig {
    /// Host name or address; combined with `port` by `address()`.
    pub host: String,
    /// Node port.
    pub port: u16,
    /// HTTP port (default 8080).
    #[serde(default = "default_http_port")]
    pub http_port: u16,
    /// Node role.
    pub role: NodeRole,
    /// Node weight (default 100); presumably used by weighted balancing — confirm.
    #[serde(default = "default_node_weight")]
    pub weight: u32,
    /// Whether the node is used (default `true`). Disabled nodes are skipped by
    /// `primary_node`, `standby_nodes` and `enabled_nodes`.
    #[serde(default = "default_true")]
    pub enabled: bool,
    /// Optional display name; see `display_name()`.
    pub name: Option<String>,
}

/// Serde default for `NodeConfig::weight`, matching `ProxyConfig::add_node`.
fn default_node_weight() -> u32 {
    100
}
/// Serde default for `NodeConfig::http_port`; also used by `ProxyConfig::add_node`.
fn default_http_port() -> u16 {
    8080_u16
}
impl NodeConfig {
pub fn address(&self) -> String {
format!("{}:{}", self.host, self.port)
}
pub fn display_name(&self) -> &str {
self.name.as_deref().unwrap_or(&self.host)
}
}
/// Role of a backend node.
///
/// Serialized as `"primary"`, `"standby"` or `"replica"`; the last maps to
/// `ReadReplica`. The same three strings are accepted by
/// `ProxyConfig::add_node`.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum NodeRole {
    Primary,
    Standby,
    /// Serialized as `"replica"` instead of the lowercase `"readreplica"`.
    #[serde(rename = "replica")]
    ReadReplica,
}
/// Optional TLS settings (`ProxyConfig::tls`).
///
/// Every field except `ca_path` is required when the table is present.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct TlsConfig {
    /// TLS switch.
    pub enabled: bool,
    /// Certificate file path.
    pub cert_path: String,
    /// Private key file path.
    pub key_path: String,
    /// Optional CA file path.
    pub ca_path: Option<String>,
    /// Client-certificate requirement flag; enforcement lives outside this module.
    pub require_client_cert: bool,
}
#[cfg(test)]
mod tests {
// Unit tests for the configuration types, serde defaults, and validation rules above.
use super::*;
// The default config listens on 0.0.0.0:5432 and has `tr` enabled.
#[test]
fn test_default_config() {
let config = ProxyConfig::default();
assert_eq!(config.listen_address, "0.0.0.0:5432");
assert!(config.tr_enabled);
}
// `add_node` parses `host:port` plus a role string; the node accessors see the result.
#[test]
fn test_add_node() {
let mut config = ProxyConfig::default();
config.add_node("localhost:5432", "primary").unwrap();
config.add_node("localhost:5433", "standby").unwrap();
assert_eq!(config.nodes.len(), 2);
assert!(config.primary_node().is_some());
assert_eq!(config.standby_nodes().len(), 1);
}
// A config with no nodes is rejected by `validate`.
#[test]
fn test_validate_no_nodes() {
let config = ProxyConfig::default();
assert!(config.validate().is_err());
}
// Nodes exist but none is a primary: rejected.
#[test]
fn test_validate_no_primary() {
let mut config = ProxyConfig::default();
config.add_node("localhost:5432", "standby").unwrap();
assert!(config.validate().is_err());
}
// A single primary with default pool settings validates.
#[test]
fn test_validate_success() {
let mut config = ProxyConfig::default();
config.add_node("localhost:5432", "primary").unwrap();
assert!(config.validate().is_ok());
}
// `PoolConfig` duration accessors convert its default second counts.
#[test]
fn test_pool_config_durations() {
let config = PoolConfig::default();
assert_eq!(config.idle_timeout(), Duration::from_secs(300));
assert_eq!(config.max_lifetime(), Duration::from_secs(1800));
}
// `PoolModeConfig::default()` matches the serde default helpers.
#[test]
fn test_pool_mode_default() {
let config = PoolModeConfig::default();
assert_eq!(config.mode, PoolingMode::Session);
assert_eq!(config.max_pool_size, 100);
assert_eq!(config.min_idle, 10);
assert_eq!(config.reset_query, "DISCARD ALL");
}
// Session preset pairs session pooling with named prepared statements.
#[test]
fn test_pool_mode_session() {
let config = PoolModeConfig::session_mode();
assert_eq!(config.mode, PoolingMode::Session);
assert_eq!(config.prepared_statement_mode, PreparedStatementMode::Named);
}
// Transaction preset pairs transaction pooling with tracked prepared statements.
#[test]
fn test_pool_mode_transaction() {
let config = PoolModeConfig::transaction_mode();
assert_eq!(config.mode, PoolingMode::Transaction);
assert_eq!(config.prepared_statement_mode, PreparedStatementMode::Track);
}
// Statement preset disables prepared statements.
#[test]
fn test_pool_mode_statement() {
let config = PoolModeConfig::statement_mode();
assert_eq!(config.mode, PoolingMode::Statement);
assert_eq!(config.prepared_statement_mode, PreparedStatementMode::Disable);
}
// Duration accessors on the pool-mode defaults.
#[test]
fn test_pool_mode_durations() {
let config = PoolModeConfig::default();
assert_eq!(config.idle_timeout(), Duration::from_secs(600));
assert_eq!(config.max_lifetime(), Duration::from_secs(3600));
assert_eq!(config.acquire_timeout(), Duration::from_secs(5));
}
// `ProxyConfig::default()` embeds the default (session) pool mode.
#[test]
fn test_proxy_config_has_pool_mode() {
let config = ProxyConfig::default();
assert_eq!(config.pool_mode.mode, PoolingMode::Session);
}
// Plugins are disabled by default and use the `default_plugin_*` values.
#[test]
fn test_plugin_toml_default_is_disabled() {
let config = ProxyConfig::default();
assert!(!config.plugins.enabled);
assert_eq!(config.plugins.plugin_dir, "/etc/heliosproxy/plugins");
assert_eq!(config.plugins.memory_limit_mb, 64);
assert_eq!(config.plugins.timeout_ms, 100);
}
// Omitting `[plugins]` (and `[pool_mode]`) still parses, thanks to
// `#[serde(default)]`, and leaves plugins disabled.
#[test]
fn test_proxy_config_toml_without_plugins_section_still_parses() {
let toml_text = r#"
listen_address = "0.0.0.0:5432"
admin_address = "0.0.0.0:9090"
tr_enabled = true
tr_mode = "session"
nodes = []
[pool]
min_connections = 2
max_connections = 10
idle_timeout_secs = 300
max_lifetime_secs = 1800
acquire_timeout_secs = 30
test_on_acquire = true
[load_balancer]
read_strategy = "round_robin"
read_write_split = true
latency_threshold_ms = 100
[health]
check_interval_secs = 5
check_timeout_secs = 3
failure_threshold = 3
success_threshold = 2
check_query = "SELECT 1"
"#;
let config: ProxyConfig = toml::from_str(toml_text).expect("parse");
assert!(!config.plugins.enabled);
}
// Keys present in `[plugins]` override the defaults; omitted keys
// (`max_plugins`, `fuel_metering`) keep their default values.
#[test]
fn test_plugin_toml_overrides_parse() {
let toml_text = r#"
listen_address = "0.0.0.0:5432"
admin_address = "0.0.0.0:9090"
tr_enabled = true
tr_mode = "session"
nodes = []
[pool]
min_connections = 2
max_connections = 10
idle_timeout_secs = 300
max_lifetime_secs = 1800
acquire_timeout_secs = 30
test_on_acquire = true
[load_balancer]
read_strategy = "round_robin"
read_write_split = true
latency_threshold_ms = 100
[health]
check_interval_secs = 5
check_timeout_secs = 3
failure_threshold = 3
success_threshold = 2
check_query = "SELECT 1"
[plugins]
enabled = true
plugin_dir = "/tmp/helios-plugins"
hot_reload = true
memory_limit_mb = 128
timeout_ms = 250
"#;
let config: ProxyConfig = toml::from_str(toml_text).expect("parse");
assert!(config.plugins.enabled);
assert_eq!(config.plugins.plugin_dir, "/tmp/helios-plugins");
assert!(config.plugins.hot_reload);
assert_eq!(config.plugins.memory_limit_mb, 128);
assert_eq!(config.plugins.timeout_ms, 250);
assert_eq!(config.plugins.max_plugins, 20);
assert!(config.plugins.fuel_metering);
}
}