1use crate::{ProxyError, Result};
6use serde::{Deserialize, Serialize};
7use std::path::Path;
8use std::time::Duration;
9
/// Connection pooling mode, selected by [`PoolModeConfig::mode`].
///
/// Serialized in lowercase: `"session"`, `"transaction"`, `"statement"`.
/// Keep the declaration order stable: serde's derived impls pass the variant
/// index to the serializer, which non-self-describing formats may encode.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum PoolingMode {
    /// Default mode; [`PoolModeConfig::session_mode`] pairs it with named
    /// prepared statements.
    /// NOTE(review): presumably a backend connection stays assigned for the
    /// whole client session — confirm against the pool implementation.
    #[default]
    Session,
    /// [`PoolModeConfig::transaction_mode`] pairs it with tracked prepared
    /// statements.
    /// NOTE(review): presumably the backend connection is released at
    /// transaction boundaries — confirm.
    Transaction,
    /// [`PoolModeConfig::statement_mode`] pairs it with prepared statements
    /// disabled.
    /// NOTE(review): presumably the backend connection is released after each
    /// statement — confirm.
    Statement,
}
28
/// How prepared statements are handled, selected by
/// [`PoolModeConfig::prepared_statement_mode`].
///
/// Serialized in lowercase: `"disable"`, `"track"`, `"named"`.
/// Keep the declaration order stable (serde passes the variant index to the
/// serializer).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum PreparedStatementMode {
    /// Default; used by the statement-mode preset.
    #[default]
    Disable,
    /// Used by the transaction-mode preset.
    /// NOTE(review): exact tracking semantics live in the pool code — confirm.
    Track,
    /// Used by the session-mode preset.
    Named,
}
41
/// Pooling-mode settings: the `pool_mode` table of [`ProxyConfig`].
///
/// Every field has a serde default, so any key, or the whole table, may be
/// omitted. The defaults match [`PoolModeConfig::default`]. Preset
/// constructors are [`PoolModeConfig::session_mode`],
/// [`PoolModeConfig::transaction_mode`] and [`PoolModeConfig::statement_mode`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PoolModeConfig {
    /// Pooling mode (default [`PoolingMode::Session`]).
    #[serde(default)]
    pub mode: PoolingMode,
    /// Maximum pool size (default 100).
    #[serde(default = "default_pool_mode_max_size")]
    pub max_pool_size: u32,
    /// Minimum idle connections (default 10).
    #[serde(default = "default_pool_mode_min_idle")]
    pub min_idle: u32,
    /// Idle timeout in seconds (default 600); see [`PoolModeConfig::idle_timeout`].
    #[serde(default = "default_pool_mode_idle_timeout")]
    pub idle_timeout_secs: u64,
    /// Maximum connection lifetime in seconds (default 3600); see
    /// [`PoolModeConfig::max_lifetime`].
    #[serde(default = "default_pool_mode_max_lifetime")]
    pub max_lifetime_secs: u64,
    /// Acquire timeout in seconds (default 5); see [`PoolModeConfig::acquire_timeout`].
    #[serde(default = "default_pool_mode_acquire_timeout")]
    pub acquire_timeout_secs: u64,
    /// Reset query (default `DISCARD ALL`).
    /// NOTE(review): presumably run when a backend connection is returned to
    /// the pool — confirm against the pool implementation.
    #[serde(default = "default_reset_query")]
    pub reset_query: String,
    /// Prepared-statement handling (default [`PreparedStatementMode::Disable`]).
    #[serde(default)]
    pub prepared_statement_mode: PreparedStatementMode,
}
70
// Default values for `PoolModeConfig`. They are referenced by its
// `#[serde(default = "...")]` attributes and by its `Default` impl.
const DEFAULT_POOL_MODE_MAX_SIZE: u32 = 100;
const DEFAULT_POOL_MODE_MIN_IDLE: u32 = 10;
const DEFAULT_POOL_MODE_IDLE_TIMEOUT_SECS: u64 = 600;
const DEFAULT_POOL_MODE_MAX_LIFETIME_SECS: u64 = 3600;
const DEFAULT_POOL_MODE_ACQUIRE_TIMEOUT_SECS: u64 = 5;
const DEFAULT_RESET_QUERY: &str = "DISCARD ALL";

/// Serde default for `PoolModeConfig::max_pool_size`.
fn default_pool_mode_max_size() -> u32 {
    DEFAULT_POOL_MODE_MAX_SIZE
}

/// Serde default for `PoolModeConfig::min_idle`.
fn default_pool_mode_min_idle() -> u32 {
    DEFAULT_POOL_MODE_MIN_IDLE
}

/// Serde default for `PoolModeConfig::idle_timeout_secs`.
fn default_pool_mode_idle_timeout() -> u64 {
    DEFAULT_POOL_MODE_IDLE_TIMEOUT_SECS
}

/// Serde default for `PoolModeConfig::max_lifetime_secs`.
fn default_pool_mode_max_lifetime() -> u64 {
    DEFAULT_POOL_MODE_MAX_LIFETIME_SECS
}

/// Serde default for `PoolModeConfig::acquire_timeout_secs`.
fn default_pool_mode_acquire_timeout() -> u64 {
    DEFAULT_POOL_MODE_ACQUIRE_TIMEOUT_SECS
}

/// Serde default for `PoolModeConfig::reset_query`.
fn default_reset_query() -> String {
    String::from(DEFAULT_RESET_QUERY)
}
94
95impl Default for PoolModeConfig {
96 fn default() -> Self {
97 Self {
98 mode: PoolingMode::default(),
99 max_pool_size: default_pool_mode_max_size(),
100 min_idle: default_pool_mode_min_idle(),
101 idle_timeout_secs: default_pool_mode_idle_timeout(),
102 max_lifetime_secs: default_pool_mode_max_lifetime(),
103 acquire_timeout_secs: default_pool_mode_acquire_timeout(),
104 reset_query: default_reset_query(),
105 prepared_statement_mode: PreparedStatementMode::default(),
106 }
107 }
108}
109
110impl PoolModeConfig {
111 pub fn session_mode() -> Self {
113 Self {
114 mode: PoolingMode::Session,
115 prepared_statement_mode: PreparedStatementMode::Named,
116 ..Default::default()
117 }
118 }
119
120 pub fn transaction_mode() -> Self {
122 Self {
123 mode: PoolingMode::Transaction,
124 prepared_statement_mode: PreparedStatementMode::Track,
125 ..Default::default()
126 }
127 }
128
129 pub fn statement_mode() -> Self {
131 Self {
132 mode: PoolingMode::Statement,
133 prepared_statement_mode: PreparedStatementMode::Disable,
134 ..Default::default()
135 }
136 }
137
138 pub fn idle_timeout(&self) -> Duration {
140 Duration::from_secs(self.idle_timeout_secs)
141 }
142
143 pub fn max_lifetime(&self) -> Duration {
145 Duration::from_secs(self.max_lifetime_secs)
146 }
147
148 pub fn acquire_timeout(&self) -> Duration {
150 Duration::from_secs(self.acquire_timeout_secs)
151 }
152}
153
/// Top-level proxy configuration, loaded from TOML by [`ProxyConfig::from_file`].
///
/// Fields without a `#[serde(default ...)]` attribute must appear in the TOML
/// document. The exception is `tls`: as an `Option`, it deserializes to `None`
/// when absent. The values listed below come from `ProxyConfig::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProxyConfig {
    /// Listen address (`Default`: `0.0.0.0:5432`).
    pub listen_address: String,
    /// Admin address (`Default`: `0.0.0.0:9090`).
    pub admin_address: String,
    /// `tr` toggle (`Default`: `true`).
    /// NOTE(review): what "tr" stands for is defined outside this file — confirm.
    pub tr_enabled: bool,
    /// `tr` mode (`Default`: [`TrMode::Session`]).
    pub tr_mode: TrMode,
    /// Connection pool limits and timeouts.
    pub pool: PoolConfig,
    /// Pooling-mode settings. The whole table may be omitted, in which case
    /// [`PoolModeConfig::default`] is used.
    #[serde(default)]
    pub pool_mode: PoolModeConfig,
    /// Load-balancer settings.
    pub load_balancer: LoadBalancerConfig,
    /// Health-check settings.
    pub health: HealthConfig,
    /// Backend nodes. [`ProxyConfig::validate`] requires at least one, and at
    /// least one with the primary role.
    pub nodes: Vec<NodeConfig>,
    /// Optional TLS settings (`None` when the table is absent).
    pub tls: Option<TlsConfig>,
    /// Write timeout in seconds (serde default 30); see [`ProxyConfig::write_timeout`].
    #[serde(default = "default_write_timeout_secs")]
    pub write_timeout_secs: u64,
    /// Plugin settings. The `[plugins]` table may be omitted, in which case
    /// plugins are disabled.
    #[serde(default)]
    pub plugins: PluginToml,
}
192
/// Serde default for `ProxyConfig::write_timeout_secs`, also used by
/// `ProxyConfig::default()`.
fn default_write_timeout_secs() -> u64 {
    const WRITE_TIMEOUT_SECS: u64 = 30;
    WRITE_TIMEOUT_SECS
}
196
197impl Default for ProxyConfig {
198 fn default() -> Self {
199 Self {
200 listen_address: "0.0.0.0:5432".to_string(),
201 admin_address: "0.0.0.0:9090".to_string(),
202 tr_enabled: true,
203 tr_mode: TrMode::Session,
204 pool: PoolConfig::default(),
205 pool_mode: PoolModeConfig::default(),
206 load_balancer: LoadBalancerConfig::default(),
207 health: HealthConfig::default(),
208 nodes: Vec::new(),
209 tls: None,
210 write_timeout_secs: default_write_timeout_secs(),
211 plugins: PluginToml::default(),
212 }
213 }
214}
215
/// Plugin settings: the optional `[plugins]` table of [`ProxyConfig`].
///
/// Every field has a serde default, and those defaults match
/// `PluginToml::default()`. Units follow the field-name suffixes (`_mb`, `_ms`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PluginToml {
    /// Master switch (default `false`).
    #[serde(default)]
    pub enabled: bool,
    /// Plugin directory (default `/etc/heliosproxy/plugins`).
    #[serde(default = "default_plugin_dir")]
    pub plugin_dir: String,
    /// Hot-reload toggle (default `false`).
    #[serde(default)]
    pub hot_reload: bool,
    /// Memory limit (default 64).
    /// NOTE(review): presumably applied per plugin — confirm.
    #[serde(default = "default_plugin_memory_mb")]
    pub memory_limit_mb: usize,
    /// Timeout (default 100).
    /// NOTE(review): presumably per plugin invocation — confirm.
    #[serde(default = "default_plugin_timeout_ms")]
    pub timeout_ms: u64,
    /// Maximum number of plugins (default 20).
    #[serde(default = "default_plugin_max")]
    pub max_plugins: usize,
    /// Fuel-metering toggle (default `true`).
    #[serde(default = "default_true")]
    pub fuel_metering: bool,
    /// Fuel limit (default 1,000,000).
    #[serde(default = "default_plugin_fuel")]
    pub fuel_limit: u64,
    /// Optional trust root (`None` when omitted).
    /// NOTE(review): presumably used to verify plugin signatures — confirm
    /// with the plugin loader.
    #[serde(default)]
    pub trust_root: Option<String>,
}
264
// Default values for `PluginToml`. They are referenced by its
// `#[serde(default = "...")]` attributes and by its `Default` impl.
const DEFAULT_PLUGIN_DIR: &str = "/etc/heliosproxy/plugins";
const DEFAULT_PLUGIN_MEMORY_MB: usize = 64;
const DEFAULT_PLUGIN_TIMEOUT_MS: u64 = 100;
const DEFAULT_PLUGIN_MAX: usize = 20;
const DEFAULT_PLUGIN_FUEL: u64 = 1_000_000;

/// Serde default for `PluginToml::plugin_dir`.
fn default_plugin_dir() -> String {
    String::from(DEFAULT_PLUGIN_DIR)
}

/// Serde default for `PluginToml::memory_limit_mb`.
fn default_plugin_memory_mb() -> usize {
    DEFAULT_PLUGIN_MEMORY_MB
}

/// Serde default for `PluginToml::timeout_ms`.
fn default_plugin_timeout_ms() -> u64 {
    DEFAULT_PLUGIN_TIMEOUT_MS
}

/// Serde default for `PluginToml::max_plugins`.
fn default_plugin_max() -> usize {
    DEFAULT_PLUGIN_MAX
}

/// Generic `true` serde default for boolean fields (e.g. `PluginToml::fuel_metering`).
fn default_true() -> bool {
    true
}

/// Serde default for `PluginToml::fuel_limit`.
fn default_plugin_fuel() -> u64 {
    DEFAULT_PLUGIN_FUEL
}
283
284impl Default for PluginToml {
285 fn default() -> Self {
286 Self {
287 enabled: false,
288 plugin_dir: default_plugin_dir(),
289 hot_reload: false,
290 memory_limit_mb: default_plugin_memory_mb(),
291 timeout_ms: default_plugin_timeout_ms(),
292 max_plugins: default_plugin_max(),
293 fuel_metering: true,
294 fuel_limit: default_plugin_fuel(),
295 trust_root: None,
296 }
297 }
298}
299
300impl ProxyConfig {
301 pub fn write_timeout(&self) -> Duration {
303 Duration::from_secs(self.write_timeout_secs)
304 }
305
306 pub fn from_file(path: &str) -> Result<Self> {
308 let path = Path::new(path);
309
310 if !path.exists() {
311 return Err(ProxyError::Config(format!(
312 "Configuration file not found: {}",
313 path.display()
314 )));
315 }
316
317 let contents = std::fs::read_to_string(path)
318 .map_err(|e| ProxyError::Config(format!("Failed to read config: {}", e)))?;
319
320 let config: Self = toml::from_str(&contents)
321 .map_err(|e| ProxyError::Config(format!("Failed to parse config: {}", e)))?;
322
323 config.validate()?;
324
325 Ok(config)
326 }
327
328 pub fn add_node(&mut self, host_port: &str, role: &str) -> Result<()> {
330 let parts: Vec<&str> = host_port.rsplitn(2, ':').collect();
331 if parts.len() != 2 {
332 return Err(ProxyError::Config(format!(
333 "Invalid host:port format: {}",
334 host_port
335 )));
336 }
337
338 let port: u16 = parts[0].parse()
339 .map_err(|_| ProxyError::Config(format!("Invalid port: {}", parts[0])))?;
340
341 let host = parts[1].to_string();
342
343 let role = match role {
344 "primary" => NodeRole::Primary,
345 "standby" => NodeRole::Standby,
346 "replica" => NodeRole::ReadReplica,
347 _ => return Err(ProxyError::Config(format!("Unknown role: {}", role))),
348 };
349
350 self.nodes.push(NodeConfig {
351 host,
352 port,
353 http_port: default_http_port(),
354 role,
355 weight: 100,
356 enabled: true,
357 name: None,
358 });
359
360 Ok(())
361 }
362
363 pub fn validate(&self) -> Result<()> {
365 if self.nodes.is_empty() {
367 return Err(ProxyError::Config("No backend nodes configured".to_string()));
368 }
369
370 let has_primary = self.nodes.iter().any(|n| n.role == NodeRole::Primary);
372 if !has_primary {
373 return Err(ProxyError::Config("No primary node configured".to_string()));
374 }
375
376 if self.pool.max_connections < self.pool.min_connections {
378 return Err(ProxyError::Config(
379 "max_connections must be >= min_connections".to_string(),
380 ));
381 }
382
383 Ok(())
384 }
385
386 pub fn primary_node(&self) -> Option<&NodeConfig> {
388 self.nodes.iter().find(|n| n.role == NodeRole::Primary && n.enabled)
389 }
390
391 pub fn standby_nodes(&self) -> Vec<&NodeConfig> {
393 self.nodes.iter()
394 .filter(|n| n.role == NodeRole::Standby && n.enabled)
395 .collect()
396 }
397
398 pub fn enabled_nodes(&self) -> Vec<&NodeConfig> {
400 self.nodes.iter().filter(|n| n.enabled).collect()
401 }
402}
403
404#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
406#[serde(rename_all = "lowercase")]
407pub enum TrMode {
408 None,
410 Session,
412 Select,
414 Transaction,
416}
417
418impl Default for TrMode {
419 fn default() -> Self {
420 TrMode::Session
421 }
422}
423
/// Connection pool limits and timeouts: the required `[pool]` table of
/// [`ProxyConfig`].
///
/// Every key must be present in the TOML; there are no serde defaults here.
/// [`ProxyConfig::validate`] requires `max_connections >= min_connections`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PoolConfig {
    /// Lower connection bound (`Default`: 2).
    pub min_connections: usize,
    /// Upper connection bound (`Default`: 100).
    pub max_connections: usize,
    /// Idle timeout in seconds (`Default`: 300); see [`PoolConfig::idle_timeout`].
    pub idle_timeout_secs: u64,
    /// Maximum lifetime in seconds (`Default`: 1800); see [`PoolConfig::max_lifetime`].
    pub max_lifetime_secs: u64,
    /// Acquire timeout in seconds (`Default`: 30); see [`PoolConfig::acquire_timeout`].
    pub acquire_timeout_secs: u64,
    /// Test-on-acquire toggle (`Default`: `true`).
    pub test_on_acquire: bool,
}
440
441impl Default for PoolConfig {
442 fn default() -> Self {
443 Self {
444 min_connections: 2,
445 max_connections: 100,
446 idle_timeout_secs: 300,
447 max_lifetime_secs: 1800,
448 acquire_timeout_secs: 30,
449 test_on_acquire: true,
450 }
451 }
452}
453
454impl PoolConfig {
455 pub fn idle_timeout(&self) -> Duration {
457 Duration::from_secs(self.idle_timeout_secs)
458 }
459
460 pub fn max_lifetime(&self) -> Duration {
462 Duration::from_secs(self.max_lifetime_secs)
463 }
464
465 pub fn acquire_timeout(&self) -> Duration {
467 Duration::from_secs(self.acquire_timeout_secs)
468 }
469}
470
/// Load-balancer settings: the required `[load_balancer]` table of
/// [`ProxyConfig`]. Every key must be present in the TOML.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoadBalancerConfig {
    /// Strategy used for reads (`Default`: [`Strategy::RoundRobin`]).
    pub read_strategy: Strategy,
    /// Read/write split toggle (`Default`: `true`).
    pub read_write_split: bool,
    /// Latency threshold in milliseconds (`Default`: 100).
    /// NOTE(review): presumably consulted by [`Strategy::LatencyBased`] — confirm.
    pub latency_threshold_ms: u64,
}
481
482impl Default for LoadBalancerConfig {
483 fn default() -> Self {
484 Self {
485 read_strategy: Strategy::RoundRobin,
486 read_write_split: true,
487 latency_threshold_ms: 100,
488 }
489 }
490}
491
/// Read load-balancing strategy, selected by [`LoadBalancerConfig::read_strategy`].
///
/// Serialized in snake_case: `"round_robin"`, `"weighted_round_robin"`,
/// `"least_connections"`, `"latency_based"`, `"random"`. Keep the declaration
/// order stable, because serde passes the variant index to the serializer.
/// NOTE(review): the selection logic for each variant lives in the balancer,
/// not here.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Strategy {
    /// Used by `LoadBalancerConfig::default()`.
    RoundRobin,
    /// NOTE(review): presumably uses [`NodeConfig::weight`] — confirm.
    WeightedRoundRobin,
    LeastConnections,
    /// NOTE(review): presumably uses [`LoadBalancerConfig::latency_threshold_ms`] — confirm.
    LatencyBased,
    Random,
}
507
/// Health-check settings: the required `[health]` table of [`ProxyConfig`].
/// Every key must be present in the TOML.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HealthConfig {
    /// Interval between checks in seconds (`Default`: 5); see
    /// [`HealthConfig::check_interval`].
    pub check_interval_secs: u64,
    /// Per-check timeout in seconds (`Default`: 3); see [`HealthConfig::check_timeout`].
    pub check_timeout_secs: u64,
    /// Failure threshold (`Default`: 3).
    /// NOTE(review): presumably counts consecutive failures — confirm.
    pub failure_threshold: u32,
    /// Success threshold (`Default`: 2).
    /// NOTE(review): presumably counts consecutive successes — confirm.
    pub success_threshold: u32,
    /// Query used for the check (`Default`: `SELECT 1`).
    pub check_query: String,
}
522
523impl Default for HealthConfig {
524 fn default() -> Self {
525 Self {
526 check_interval_secs: 5,
527 check_timeout_secs: 3,
528 failure_threshold: 3,
529 success_threshold: 2,
530 check_query: "SELECT 1".to_string(),
531 }
532 }
533}
534
535impl HealthConfig {
536 pub fn check_interval(&self) -> Duration {
538 Duration::from_secs(self.check_interval_secs)
539 }
540
541 pub fn check_timeout(&self) -> Duration {
543 Duration::from_secs(self.check_timeout_secs)
544 }
545}
546
/// One backend node: an entry of [`ProxyConfig::nodes`].
///
/// Only `http_port` (serde default 8080) and `name` (an `Option`, so `None`
/// when absent) may be omitted from the TOML.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeConfig {
    /// Host name or address; combined with `port` by [`NodeConfig::address`].
    pub host: String,
    /// Port used in [`NodeConfig::address`].
    pub port: u16,
    /// HTTP port (default 8080).
    /// NOTE(review): the service behind this port is not visible here — confirm.
    #[serde(default = "default_http_port")]
    pub http_port: u16,
    /// Node role; `"primary"`, `"standby"` or `"replica"` in TOML.
    pub role: NodeRole,
    /// Weight. [`ProxyConfig::add_node`] uses 100.
    pub weight: u32,
    /// Disabled nodes are skipped by [`ProxyConfig::primary_node`],
    /// [`ProxyConfig::standby_nodes`] and [`ProxyConfig::enabled_nodes`].
    pub enabled: bool,
    /// Optional display name; [`NodeConfig::display_name`] falls back to `host`.
    pub name: Option<String>,
}
567
/// Serde default for `NodeConfig::http_port`, also used by
/// `ProxyConfig::add_node`.
fn default_http_port() -> u16 {
    const HTTP_PORT: u16 = 8080;
    HTTP_PORT
}
571
572impl NodeConfig {
573 pub fn address(&self) -> String {
575 format!("{}:{}", self.host, self.port)
576 }
577
578 pub fn display_name(&self) -> &str {
580 self.name.as_deref().unwrap_or(&self.host)
581 }
582}
583
/// Role of a backend node, stored in [`NodeConfig::role`].
///
/// Serialized in lowercase (`"primary"`, `"standby"`). The `ReadReplica`
/// variant is explicitly renamed to `"replica"`, the same string
/// [`ProxyConfig::add_node`] accepts. Keep the declaration order stable,
/// because serde passes the variant index to the serializer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum NodeRole {
    /// [`ProxyConfig::validate`] requires at least one node with this role.
    Primary,
    /// Returned (when enabled) by [`ProxyConfig::standby_nodes`].
    Standby,
    /// Not included in [`ProxyConfig::standby_nodes`].
    #[serde(rename = "replica")]
    ReadReplica,
}
596
/// TLS settings: the optional `[tls]` table of [`ProxyConfig`].
///
/// When the table is present, every key except `ca_path` is required.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TlsConfig {
    /// TLS toggle.
    pub enabled: bool,
    /// Path to the certificate file.
    pub cert_path: String,
    /// Path to the private-key file.
    pub key_path: String,
    /// Optional CA file path (`None` when omitted).
    pub ca_path: Option<String>,
    /// Client-certificate requirement toggle.
    /// NOTE(review): presumably verified against `ca_path` — confirm.
    pub require_client_cert: bool,
}
611
#[cfg(test)]
mod tests {
    use super::*;

    /// A complete `ProxyConfig` TOML document with every required section and
    /// no `[plugins]` table.
    ///
    /// Shared by the parsing tests so each test spells out only what it adds,
    /// instead of duplicating the whole fixture.
    const BASE_TOML: &str = r#"
listen_address = "0.0.0.0:5432"
admin_address = "0.0.0.0:9090"
tr_enabled = true
tr_mode = "session"
nodes = []

[pool]
min_connections = 2
max_connections = 10
idle_timeout_secs = 300
max_lifetime_secs = 1800
acquire_timeout_secs = 30
test_on_acquire = true

[load_balancer]
read_strategy = "round_robin"
read_write_split = true
latency_threshold_ms = 100

[health]
check_interval_secs = 5
check_timeout_secs = 3
failure_threshold = 3
success_threshold = 2
check_query = "SELECT 1"
"#;

    #[test]
    fn test_default_config() {
        let config = ProxyConfig::default();
        assert_eq!(config.listen_address, "0.0.0.0:5432");
        assert!(config.tr_enabled);
    }

    #[test]
    fn test_add_node() {
        let mut config = ProxyConfig::default();
        config.add_node("localhost:5432", "primary").unwrap();
        config.add_node("localhost:5433", "standby").unwrap();

        assert_eq!(config.nodes.len(), 2);
        assert!(config.primary_node().is_some());
        assert_eq!(config.standby_nodes().len(), 1);
    }

    #[test]
    fn test_validate_no_nodes() {
        let config = ProxyConfig::default();
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_validate_no_primary() {
        let mut config = ProxyConfig::default();
        config.add_node("localhost:5432", "standby").unwrap();
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_validate_success() {
        let mut config = ProxyConfig::default();
        config.add_node("localhost:5432", "primary").unwrap();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_pool_config_durations() {
        let config = PoolConfig::default();
        assert_eq!(config.idle_timeout(), Duration::from_secs(300));
        assert_eq!(config.max_lifetime(), Duration::from_secs(1800));
    }

    #[test]
    fn test_pool_mode_default() {
        let config = PoolModeConfig::default();
        assert_eq!(config.mode, PoolingMode::Session);
        assert_eq!(config.max_pool_size, 100);
        assert_eq!(config.min_idle, 10);
        assert_eq!(config.reset_query, "DISCARD ALL");
    }

    #[test]
    fn test_pool_mode_session() {
        let config = PoolModeConfig::session_mode();
        assert_eq!(config.mode, PoolingMode::Session);
        assert_eq!(config.prepared_statement_mode, PreparedStatementMode::Named);
    }

    #[test]
    fn test_pool_mode_transaction() {
        let config = PoolModeConfig::transaction_mode();
        assert_eq!(config.mode, PoolingMode::Transaction);
        assert_eq!(config.prepared_statement_mode, PreparedStatementMode::Track);
    }

    #[test]
    fn test_pool_mode_statement() {
        let config = PoolModeConfig::statement_mode();
        assert_eq!(config.mode, PoolingMode::Statement);
        assert_eq!(config.prepared_statement_mode, PreparedStatementMode::Disable);
    }

    #[test]
    fn test_pool_mode_durations() {
        let config = PoolModeConfig::default();
        assert_eq!(config.idle_timeout(), Duration::from_secs(600));
        assert_eq!(config.max_lifetime(), Duration::from_secs(3600));
        assert_eq!(config.acquire_timeout(), Duration::from_secs(5));
    }

    #[test]
    fn test_proxy_config_has_pool_mode() {
        let config = ProxyConfig::default();
        assert_eq!(config.pool_mode.mode, PoolingMode::Session);
    }

    /// Plugins must be opt-in: the default config leaves them disabled.
    #[test]
    fn test_plugin_toml_default_is_disabled() {
        let config = ProxyConfig::default();
        assert!(!config.plugins.enabled);
        assert_eq!(config.plugins.plugin_dir, "/etc/heliosproxy/plugins");
        assert_eq!(config.plugins.memory_limit_mb, 64);
        assert_eq!(config.plugins.timeout_ms, 100);
    }

    /// Existing config files without a `[plugins]` table must keep parsing.
    #[test]
    fn test_proxy_config_toml_without_plugins_section_still_parses() {
        let config: ProxyConfig = toml::from_str(BASE_TOML).expect("parse");
        assert!(!config.plugins.enabled);
    }

    /// Keys given in `[plugins]` override defaults; omitted keys keep them.
    #[test]
    fn test_plugin_toml_overrides_parse() {
        let toml_text = format!(
            "{}{}",
            BASE_TOML,
            r#"
[plugins]
enabled = true
plugin_dir = "/tmp/helios-plugins"
hot_reload = true
memory_limit_mb = 128
timeout_ms = 250
"#
        );
        let config: ProxyConfig = toml::from_str(&toml_text).expect("parse");
        assert!(config.plugins.enabled);
        assert_eq!(config.plugins.plugin_dir, "/tmp/helios-plugins");
        assert!(config.plugins.hot_reload);
        assert_eq!(config.plugins.memory_limit_mb, 128);
        assert_eq!(config.plugins.timeout_ms, 250);
        // Not set in the TOML above, so these fall back to serde defaults.
        assert_eq!(config.plugins.max_plugins, 20);
        assert!(config.plugins.fuel_metering);
    }
}