1use crate::{ProxyError, Result};
6use serde::{Deserialize, Serialize};
7use std::path::Path;
8use std::time::Duration;
9
/// Pooling mode selected through `PoolModeConfig::mode`.
///
/// Serde maps each variant to its lowercase name (`"session"`,
/// `"transaction"`, `"statement"`). `Session` is the `Default`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum PoolingMode {
    /// Default variant.
    /// NOTE(review): presumably a backend connection stays bound to a client
    /// for the whole session — confirm against the pool implementation.
    #[default]
    Session,
    /// NOTE(review): presumably connections are handed back after each
    /// transaction — confirm against the pool implementation.
    Transaction,
    /// NOTE(review): presumably connections are handed back after each
    /// statement — confirm against the pool implementation.
    Statement,
}
28
/// Prepared-statement handling selected through
/// `PoolModeConfig::prepared_statement_mode`.
///
/// Serde maps each variant to its lowercase name (`"disable"`, `"track"`,
/// `"named"`). `Disable` is the `Default`. The `PoolModeConfig` preset
/// constructors pair `Named` with session mode, `Track` with transaction mode
/// and `Disable` with statement mode.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum PreparedStatementMode {
    /// Default variant; used by `PoolModeConfig::statement_mode()`.
    #[default]
    Disable,
    /// Used by `PoolModeConfig::transaction_mode()`.
    /// NOTE(review): exact tracking semantics are not visible in this file.
    Track,
    /// Used by `PoolModeConfig::session_mode()`.
    Named,
}
41
/// Settings for the pooling-mode layer (`ProxyConfig::pool_mode`).
///
/// Every field has a serde default, so an empty table deserializes to the
/// same values as `PoolModeConfig::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PoolModeConfig {
    /// Pooling mode; defaults to `PoolingMode::Session`.
    #[serde(default)]
    pub mode: PoolingMode,
    /// Upper bound on pool size; defaults to 100.
    #[serde(default = "default_pool_mode_max_size")]
    pub max_pool_size: u32,
    /// Idle-connection floor; defaults to 10.
    /// NOTE(review): nothing in this file checks `min_idle <= max_pool_size`.
    #[serde(default = "default_pool_mode_min_idle")]
    pub min_idle: u32,
    /// Idle timeout in seconds; defaults to 600. See `idle_timeout()`.
    #[serde(default = "default_pool_mode_idle_timeout")]
    pub idle_timeout_secs: u64,
    /// Maximum connection lifetime in seconds; defaults to 3600.
    /// See `max_lifetime()`.
    #[serde(default = "default_pool_mode_max_lifetime")]
    pub max_lifetime_secs: u64,
    /// Acquire timeout in seconds; defaults to 5. See `acquire_timeout()`.
    #[serde(default = "default_pool_mode_acquire_timeout")]
    pub acquire_timeout_secs: u64,
    /// Reset query text; defaults to `DISCARD ALL`.
    /// NOTE(review): presumably executed when a connection is returned to
    /// the pool — confirm against the pool implementation.
    #[serde(default = "default_reset_query")]
    pub reset_query: String,
    /// Prepared-statement handling; defaults to
    /// `PreparedStatementMode::Disable`.
    #[serde(default)]
    pub prepared_statement_mode: PreparedStatementMode,
}
70
/// Serde fallback for `PoolModeConfig::max_pool_size`.
fn default_pool_mode_max_size() -> u32 {
    const MAX_POOL_SIZE: u32 = 100;
    MAX_POOL_SIZE
}

/// Serde fallback for `PoolModeConfig::min_idle`.
fn default_pool_mode_min_idle() -> u32 {
    const MIN_IDLE: u32 = 10;
    MIN_IDLE
}

/// Serde fallback for `PoolModeConfig::idle_timeout_secs` (ten minutes).
fn default_pool_mode_idle_timeout() -> u64 {
    const IDLE_TIMEOUT_SECS: u64 = 10 * 60;
    IDLE_TIMEOUT_SECS
}

/// Serde fallback for `PoolModeConfig::max_lifetime_secs` (one hour).
fn default_pool_mode_max_lifetime() -> u64 {
    const MAX_LIFETIME_SECS: u64 = 60 * 60;
    MAX_LIFETIME_SECS
}

/// Serde fallback for `PoolModeConfig::acquire_timeout_secs`.
fn default_pool_mode_acquire_timeout() -> u64 {
    const ACQUIRE_TIMEOUT_SECS: u64 = 5;
    ACQUIRE_TIMEOUT_SECS
}

/// Serde fallback for `PoolModeConfig::reset_query`.
fn default_reset_query() -> String {
    String::from("DISCARD ALL")
}
94
95impl Default for PoolModeConfig {
96 fn default() -> Self {
97 Self {
98 mode: PoolingMode::default(),
99 max_pool_size: default_pool_mode_max_size(),
100 min_idle: default_pool_mode_min_idle(),
101 idle_timeout_secs: default_pool_mode_idle_timeout(),
102 max_lifetime_secs: default_pool_mode_max_lifetime(),
103 acquire_timeout_secs: default_pool_mode_acquire_timeout(),
104 reset_query: default_reset_query(),
105 prepared_statement_mode: PreparedStatementMode::default(),
106 }
107 }
108}
109
110impl PoolModeConfig {
111 pub fn session_mode() -> Self {
113 Self {
114 mode: PoolingMode::Session,
115 prepared_statement_mode: PreparedStatementMode::Named,
116 ..Default::default()
117 }
118 }
119
120 pub fn transaction_mode() -> Self {
122 Self {
123 mode: PoolingMode::Transaction,
124 prepared_statement_mode: PreparedStatementMode::Track,
125 ..Default::default()
126 }
127 }
128
129 pub fn statement_mode() -> Self {
131 Self {
132 mode: PoolingMode::Statement,
133 prepared_statement_mode: PreparedStatementMode::Disable,
134 ..Default::default()
135 }
136 }
137
138 pub fn idle_timeout(&self) -> Duration {
140 Duration::from_secs(self.idle_timeout_secs)
141 }
142
143 pub fn max_lifetime(&self) -> Duration {
145 Duration::from_secs(self.max_lifetime_secs)
146 }
147
148 pub fn acquire_timeout(&self) -> Duration {
150 Duration::from_secs(self.acquire_timeout_secs)
151 }
152}
153
/// Top-level proxy configuration, deserialized from TOML by
/// `ProxyConfig::from_file`.
///
/// Fields without a `#[serde(default…)]` attribute (and that are not
/// `Option`) must be present in the input. Every other section may be
/// omitted and then takes the value described on the field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProxyConfig {
    /// Listen address (required). `ProxyConfig::default()` uses
    /// `0.0.0.0:5432`.
    pub listen_address: String,
    /// Admin address (required). `ProxyConfig::default()` uses
    /// `0.0.0.0:9090`.
    pub admin_address: String,
    /// Optional admin token; `None` when absent.
    /// NOTE(review): how the token is enforced is not visible in this file.
    #[serde(default)]
    pub admin_token: Option<String>,
    /// Required flag. NOTE(review): the meaning of "TR" is not spelled out
    /// in this file — confirm and expand the abbreviation here.
    pub tr_enabled: bool,
    /// Required; see [`TrMode`].
    pub tr_mode: TrMode,
    /// Required connection-pool settings.
    pub pool: PoolConfig,
    /// Pooling-mode settings; `PoolModeConfig::default()` when absent.
    #[serde(default)]
    pub pool_mode: PoolModeConfig,
    /// Required load-balancer settings.
    pub load_balancer: LoadBalancerConfig,
    /// Required health-check settings.
    pub health: HealthConfig,
    /// Required list of backend nodes. May be empty in the input, but
    /// `validate()` rejects an empty list.
    pub nodes: Vec<NodeConfig>,
    /// Optional TLS settings; `None` when absent.
    pub tls: Option<TlsConfig>,
    /// Write timeout in seconds; defaults to 30. See `write_timeout()`.
    #[serde(default = "default_write_timeout_secs")]
    pub write_timeout_secs: u64,
    /// Plugin settings; `PluginToml::default()` (disabled) when absent.
    #[serde(default)]
    pub plugins: PluginToml,
    /// Host-based access rules; empty when absent.
    #[serde(default)]
    pub hba: Vec<HbaRule>,
    /// Authentication settings; `AuthConfig::default()` when absent.
    #[serde(default)]
    pub auth: AuthConfig,
    /// MCP settings; `McpConfig::default()` (disabled) when absent.
    #[serde(default)]
    pub mcp: McpConfig,
    /// Agent contracts; empty when absent.
    #[serde(default)]
    pub agent_contracts: Vec<crate::agent_contract::AgentContract>,
    /// HTTP gateway settings; `HttpGatewayConfig::default()` (disabled)
    /// when absent.
    #[serde(default)]
    pub http_gateway: HttpGatewayConfig,
    /// Mirroring settings; `MirrorConfig::default()` (disabled) when absent.
    #[serde(default)]
    pub mirror: MirrorConfig,
    /// Branch settings; `BranchConfig::default()` (disabled) when absent.
    #[serde(default)]
    pub branch: BranchConfig,
    /// Routing-hint settings; `RoutingHintsConfig::default()` (disabled)
    /// when absent.
    #[serde(default)]
    pub routing_hints: RoutingHintsConfig,
    /// Rate-limit settings; `RateLimitToml::default()` (disabled) when
    /// absent.
    #[serde(default)]
    pub rate_limit: RateLimitToml,
    /// Circuit-breaker settings; `CircuitBreakerToml::default()` (disabled)
    /// when absent.
    #[serde(default)]
    pub circuit_breaker: CircuitBreakerToml,
    /// Analytics settings; `AnalyticsToml::default()` (disabled) when absent.
    #[serde(default)]
    pub analytics: AnalyticsToml,
    /// Lag-routing settings; `LagRoutingToml::default()` (disabled) when
    /// absent.
    #[serde(default)]
    pub lag_routing: LagRoutingToml,
    /// Cache settings; `CacheToml::default()` (disabled) when absent.
    #[serde(default)]
    pub cache: CacheToml,
    /// Query-rewrite settings; `QueryRewriteToml::default()` (disabled)
    /// when absent.
    #[serde(default)]
    pub query_rewrite: QueryRewriteToml,
    /// Multi-tenancy settings; `MultiTenancyToml::default()` (disabled)
    /// when absent.
    #[serde(default)]
    pub multi_tenancy: MultiTenancyToml,
    /// Schema-routing settings; `SchemaRoutingToml::default()` (disabled)
    /// when absent.
    #[serde(default)]
    pub schema_routing: SchemaRoutingToml,
    /// GraphQL gateway settings; `GraphqlGatewayConfig::default()`
    /// (disabled) when absent.
    #[serde(default)]
    pub graphql_gateway: GraphqlGatewayConfig,
    /// Defaults to `true` when absent.
    /// NOTE(review): the optimization this toggles is implemented outside
    /// this file — document it here once confirmed.
    #[serde(default = "default_true")]
    pub optimize_unnamed_parse: bool,
    /// Shutdown drain timeout in seconds; defaults to 60.
    #[serde(default = "default_drain_timeout_secs")]
    pub shutdown_drain_timeout_secs: u64,
}
288
/// Serde fallback for `ProxyConfig::shutdown_drain_timeout_secs`.
fn default_drain_timeout_secs() -> u64 {
    const DRAIN_TIMEOUT_SECS: u64 = 60;
    DRAIN_TIMEOUT_SECS
}
292
/// Settings for the `branch` section of [`ProxyConfig`].
///
/// NOTE(review): what a "branch" is operationally is not visible in this
/// file; the comments below only describe defaults.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BranchConfig {
    /// Feature switch; `false` when absent.
    #[serde(default)]
    pub enabled: bool,
    /// Backend host; defaults to `127.0.0.1`.
    #[serde(default = "default_localhost")]
    pub backend_host: String,
    /// Backend port; defaults to 5432.
    #[serde(default = "default_pg_port")]
    pub backend_port: u16,
    /// Admin user name; defaults to `postgres`.
    #[serde(default = "default_pg_user")]
    pub admin_user: String,
    /// Optional admin password; `None` when absent.
    pub admin_password: Option<String>,
    /// Admin database name; defaults to `postgres`.
    #[serde(default = "default_admin_db")]
    pub admin_database: String,
    /// Base database name; defaults to `postgres`.
    #[serde(default = "default_admin_db")]
    pub base_database: String,
}
315
316impl Default for BranchConfig {
317 fn default() -> Self {
318 Self {
319 enabled: false,
320 backend_host: default_localhost(),
321 backend_port: default_pg_port(),
322 admin_user: default_pg_user(),
323 admin_password: None,
324 admin_database: default_admin_db(),
325 base_database: default_admin_db(),
326 }
327 }
328}
329
/// Serde fallback for `BranchConfig::admin_database` and
/// `BranchConfig::base_database`.
fn default_admin_db() -> String {
    String::from("postgres")
}
333
/// Settings for the `mirror` section of [`ProxyConfig`].
///
/// NOTE(review): the roles of the "backend" and "source" connection groups
/// are not visible in this file — confirm against the mirroring code.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MirrorConfig {
    /// Feature switch; `false` when absent.
    #[serde(default)]
    pub enabled: bool,
    /// Sampling rate; defaults to 1.0.
    /// NOTE(review): presumably a fraction in `0.0..=1.0`; no range check
    /// exists in this file.
    #[serde(default = "default_sample_rate")]
    pub sample_rate: f64,
    /// Defaults to `true`.
    #[serde(default = "default_true_bool")]
    pub writes_only: bool,
    /// Queue capacity; defaults to 10 000.
    #[serde(default = "default_mirror_queue")]
    pub queue_size: usize,
    /// Backend host; defaults to `127.0.0.1`.
    #[serde(default = "default_localhost")]
    pub backend_host: String,
    /// Backend port; defaults to 5432.
    #[serde(default = "default_pg_port")]
    pub backend_port: u16,
    /// Backend user; defaults to `postgres`.
    #[serde(default = "default_pg_user")]
    pub backend_user: String,
    /// Optional backend password; `None` when absent.
    pub backend_password: Option<String>,
    /// Optional backend database; `None` when absent.
    pub backend_database: Option<String>,
    /// Source host; defaults to `127.0.0.1`.
    #[serde(default = "default_localhost")]
    pub source_host: String,
    /// Source port; defaults to 5432.
    #[serde(default = "default_pg_port")]
    pub source_port: u16,
    /// Source user; defaults to `postgres`.
    #[serde(default = "default_pg_user")]
    pub source_user: String,
    /// Optional source password; `None` when absent.
    pub source_password: Option<String>,
    /// Optional source database; `None` when absent.
    pub source_database: Option<String>,
}
371
372impl Default for MirrorConfig {
373 fn default() -> Self {
374 Self {
375 enabled: false,
376 sample_rate: 1.0,
377 writes_only: true,
378 queue_size: 10_000,
379 backend_host: default_localhost(),
380 backend_port: default_pg_port(),
381 backend_user: default_pg_user(),
382 backend_password: None,
383 backend_database: None,
384 source_host: default_localhost(),
385 source_port: default_pg_port(),
386 source_user: default_pg_user(),
387 source_password: None,
388 source_database: None,
389 }
390 }
391}
392
/// Serde fallback for `MirrorConfig::sample_rate`.
fn default_sample_rate() -> f64 {
    1.0_f64
}

/// Serde fallback for `MirrorConfig::queue_size`.
fn default_mirror_queue() -> usize {
    const QUEUE_SIZE: usize = 10_000;
    QUEUE_SIZE
}
399
/// Settings for the `http_gateway` section of [`ProxyConfig`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HttpGatewayConfig {
    /// Feature switch; `false` when absent.
    #[serde(default)]
    pub enabled: bool,
    /// Listen address; defaults to `127.0.0.1:9093`.
    #[serde(default = "default_http_gw_listen")]
    pub listen_address: String,
    /// Backend host; defaults to `127.0.0.1`.
    #[serde(default = "default_localhost")]
    pub backend_host: String,
    /// Backend port; defaults to 5432.
    #[serde(default = "default_pg_port")]
    pub backend_port: u16,
    /// Backend user; defaults to `postgres`.
    #[serde(default = "default_pg_user")]
    pub backend_user: String,
    /// Optional backend password; `None` when absent.
    pub backend_password: Option<String>,
    /// Optional backend database; `None` when absent.
    pub backend_database: Option<String>,
    /// Optional auth token; `None` when absent.
    /// NOTE(review): enforcement happens outside this file — confirm what
    /// `None` means for request authentication.
    #[serde(default)]
    pub auth_token: Option<String>,
}
421
422impl Default for HttpGatewayConfig {
423 fn default() -> Self {
424 Self {
425 enabled: false,
426 listen_address: default_http_gw_listen(),
427 backend_host: default_localhost(),
428 backend_port: default_pg_port(),
429 backend_user: default_pg_user(),
430 backend_password: None,
431 backend_database: None,
432 auth_token: None,
433 }
434 }
435}
436
/// Serde fallback for `HttpGatewayConfig::listen_address`.
fn default_http_gw_listen() -> String {
    String::from("127.0.0.1:9093")
}
440
/// Settings for the `mcp` section of [`ProxyConfig`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpConfig {
    /// Feature switch; `false` when absent.
    #[serde(default)]
    pub enabled: bool,
    /// Listen address; defaults to `127.0.0.1:9092`.
    #[serde(default = "default_mcp_listen")]
    pub listen_address: String,
    /// Backend host; defaults to `127.0.0.1`.
    #[serde(default = "default_localhost")]
    pub backend_host: String,
    /// Backend port; defaults to 5432.
    #[serde(default = "default_pg_port")]
    pub backend_port: u16,
    /// Backend user; defaults to `postgres`.
    #[serde(default = "default_pg_user")]
    pub backend_user: String,
    /// Optional backend password; `None` when absent.
    pub backend_password: Option<String>,
    /// Optional backend database; `None` when absent.
    pub backend_database: Option<String>,
    /// Defaults to `true`.
    #[serde(default = "default_true_bool")]
    pub read_only: bool,
    /// Optional contract reference; `None` when absent.
    /// NOTE(review): presumably names an entry in
    /// `ProxyConfig::agent_contracts` — confirm against the MCP server code.
    #[serde(default)]
    pub contract: Option<String>,
}
470
471impl Default for McpConfig {
472 fn default() -> Self {
473 Self {
474 enabled: false,
475 listen_address: default_mcp_listen(),
476 backend_host: default_localhost(),
477 backend_port: default_pg_port(),
478 backend_user: default_pg_user(),
479 backend_password: None,
480 backend_database: None,
481 read_only: true,
482 contract: None,
483 }
484 }
485}
486
/// Serde fallback for `McpConfig::listen_address`.
fn default_mcp_listen() -> String {
    String::from("127.0.0.1:9092")
}

/// Shared serde fallback for host fields: the IPv4 loopback address.
fn default_localhost() -> String {
    String::from("127.0.0.1")
}

/// Shared serde fallback for backend port fields.
fn default_pg_port() -> u16 {
    const PG_PORT: u16 = 5432;
    PG_PORT
}

/// Shared serde fallback for backend user fields.
fn default_pg_user() -> String {
    String::from("postgres")
}

/// Shared serde fallback for boolean fields that default to enabled.
fn default_true_bool() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
502
/// Settings for the `auth` section of [`ProxyConfig`].
///
/// The derived `Default` yields `AuthMode::Passthrough` and no auth file.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct AuthConfig {
    /// Authentication mode; `AuthMode::Passthrough` when absent.
    #[serde(default)]
    pub mode: AuthMode,
    /// Optional auth file path; `None` when absent.
    /// NOTE(review): the file format and which modes read it are not
    /// visible in this file.
    #[serde(default)]
    pub auth_file: Option<String>,
}
516
/// Authentication mode for [`AuthConfig`].
///
/// Serde maps each variant to its lowercase name (`"passthrough"`,
/// `"scram"`). `Passthrough` is the `Default`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum AuthMode {
    /// Default variant.
    /// NOTE(review): presumably authentication is left to the backend —
    /// confirm against the auth implementation.
    #[default]
    Passthrough,
    /// NOTE(review): presumably SCRAM authentication handled by the proxy —
    /// confirm against the auth implementation.
    Scram,
}
527
/// One entry of `ProxyConfig::hba`.
///
/// Only `action` is required; each matcher defaults to the string `all`.
/// NOTE(review): matching order and the accepted `address` syntax are
/// implemented outside this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HbaRule {
    /// What to do when the rule matches (required).
    pub action: HbaAction,
    /// User matcher; defaults to `all`.
    #[serde(default = "hba_all")]
    pub user: String,
    /// Database matcher; defaults to `all`.
    #[serde(default = "hba_all")]
    pub database: String,
    /// Address matcher; defaults to `all`.
    #[serde(default = "hba_all")]
    pub address: String,
}
549
/// Serde fallback for the `user`, `database` and `address` matchers of
/// `HbaRule`.
fn hba_all() -> String {
    String::from("all")
}
553
/// Action carried by an [`HbaRule`].
///
/// Serde maps the variants to `"allow"` and `"reject"`. There is no
/// `Default`, so `action` must be given on every rule.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum HbaAction {
    /// Serialized as `"allow"`.
    Allow,
    /// Serialized as `"reject"`.
    Reject,
}
561
/// Serde fallback for `ProxyConfig::write_timeout_secs`.
fn default_write_timeout_secs() -> u64 {
    const WRITE_TIMEOUT_SECS: u64 = 30;
    WRITE_TIMEOUT_SECS
}
565
/// One entry of `GraphqlGatewayConfig::tables`.
///
/// With the container-level `#[serde(default)]`, missing keys fall back to
/// the derived `Default` (empty name, empty column list).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default)]
pub struct GqlTableToml {
    /// Table name; empty string when absent.
    pub name: String,
    /// Column names; empty when absent.
    pub columns: Vec<String>,
}
573
/// Settings for the `graphql_gateway` section of [`ProxyConfig`].
///
/// With the container-level `#[serde(default)]`, any missing key takes its
/// value from `GraphqlGatewayConfig::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct GraphqlGatewayConfig {
    /// Feature switch; `false` by default.
    pub enabled: bool,
    /// Listen address; `0.0.0.0:9091` by default.
    pub listen_address: String,
    /// Backend host; `127.0.0.1` by default.
    pub backend_host: String,
    /// Backend port; 5432 by default.
    pub backend_port: u16,
    /// Backend user; `postgres` by default.
    pub backend_user: String,
    /// Optional backend password; `None` by default.
    pub backend_password: Option<String>,
    /// Optional backend database; `None` by default.
    pub backend_database: Option<String>,
    /// Optional auth token; `None` by default.
    /// NOTE(review): enforcement happens outside this file.
    pub auth_token: Option<String>,
    /// Table entries; empty by default.
    pub tables: Vec<GqlTableToml>,
}
594
595impl Default for GraphqlGatewayConfig {
596 fn default() -> Self {
597 Self {
598 enabled: false,
599 listen_address: "0.0.0.0:9091".to_string(),
600 backend_host: "127.0.0.1".to_string(),
601 backend_port: 5432,
602 backend_user: "postgres".to_string(),
603 backend_password: None,
604 backend_database: None,
605 auth_token: None,
606 tables: Vec::new(),
607 }
608 }
609}
610
/// Settings for the `schema_routing` section of [`ProxyConfig`].
///
/// Missing keys fall back to the derived `Default` (disabled, empty node
/// name).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default)]
pub struct SchemaRoutingToml {
    /// Feature switch; `false` by default.
    pub enabled: bool,
    /// Analytics node; empty string by default.
    /// NOTE(review): presumably matched against a node's name or host —
    /// confirm against the routing code.
    pub analytics_node: String,
}
622
/// Settings for the `multi_tenancy` section of [`ProxyConfig`].
///
/// Missing keys take their value from `MultiTenancyToml::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct MultiTenancyToml {
    /// Feature switch; `false` by default.
    pub enabled: bool,
    /// Tenant identification source; `application_name` by default.
    /// NOTE(review): the set of accepted values is not visible in this file.
    pub identify_by: String,
    /// Tenant column name; `tenant_id` by default.
    pub tenant_column: String,
    /// Tenant-scoped tables; empty by default.
    pub tenant_tables: Vec<String>,
    /// Known tenants; empty by default.
    pub tenants: Vec<String>,
}
642
643impl Default for MultiTenancyToml {
644 fn default() -> Self {
645 Self {
646 enabled: false,
647 identify_by: "application_name".to_string(),
648 tenant_column: "tenant_id".to_string(),
649 tenant_tables: Vec::new(),
650 tenants: Vec::new(),
651 }
652 }
653}
654
/// One entry of `QueryRewriteToml::rules`.
///
/// Every field is optional and `None` by default.
/// NOTE(review): how the matchers and actions combine is implemented outside
/// this file.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(default)]
pub struct RewriteRuleToml {
    /// Optional table-name matcher.
    pub match_table: Option<String>,
    /// Optional regex matcher.
    pub match_regex: Option<String>,
    /// Optional replacement table name.
    pub replace_table_with: Option<String>,
    /// Optional text to append as a `WHERE` condition.
    pub append_where: Option<String>,
    /// Optional row limit to add.
    pub add_limit: Option<u32>,
}
672
/// Settings for the `query_rewrite` section of [`ProxyConfig`].
///
/// Missing keys fall back to the derived `Default` (disabled, no rules).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(default)]
pub struct QueryRewriteToml {
    /// Feature switch; `false` by default.
    pub enabled: bool,
    /// Rewrite rules; empty by default.
    pub rules: Vec<RewriteRuleToml>,
}
684
/// Settings for the `cache` section of [`ProxyConfig`].
///
/// Missing keys take their value from `CacheToml::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct CacheToml {
    /// Feature switch; `false` by default.
    pub enabled: bool,
    /// Entry time-to-live in seconds; 300 by default.
    pub ttl_secs: u64,
    /// Result size limit in bytes; 1 MiB (`1024 * 1024`) by default.
    pub max_result_bytes: usize,
}
698
699impl Default for CacheToml {
700 fn default() -> Self {
701 Self {
702 enabled: false,
703 ttl_secs: 300,
704 max_result_bytes: 1024 * 1024,
705 }
706 }
707}
708
/// Settings for the `lag_routing` section of [`ProxyConfig`].
///
/// Missing keys take their value from `LagRoutingToml::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct LagRoutingToml {
    /// Feature switch; `false` by default.
    pub enabled: bool,
    /// Window in milliseconds; 500 by default.
    /// NOTE(review): "ryw" presumably stands for read-your-writes — confirm.
    pub ryw_window_ms: u64,
    /// Lag limit in bytes; 0 by default.
    /// NOTE(review): whether 0 means "no limit" is not visible in this file.
    pub max_lag_bytes: u64,
}
725
726impl Default for LagRoutingToml {
727 fn default() -> Self {
728 Self {
729 enabled: false,
730 ryw_window_ms: 500,
731 max_lag_bytes: 0,
732 }
733 }
734}
735
/// Settings for the `analytics` section of [`ProxyConfig`].
///
/// Missing keys take their value from `AnalyticsToml::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct AnalyticsToml {
    /// Feature switch; `false` by default.
    pub enabled: bool,
    /// Slow-query threshold in milliseconds; 1000 by default.
    pub slow_query_ms: u64,
    /// Fingerprint limit; 10000 by default.
    pub max_fingerprints: u32,
}
750
751impl Default for AnalyticsToml {
752 fn default() -> Self {
753 Self {
754 enabled: false,
755 slow_query_ms: 1000,
756 max_fingerprints: 10000,
757 }
758 }
759}
760
/// Settings for the `circuit_breaker` section of [`ProxyConfig`].
///
/// Missing keys take their value from `CircuitBreakerToml::default()`.
/// NOTE(review): the state machine that consumes these thresholds lives
/// outside this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct CircuitBreakerToml {
    /// Feature switch; `false` by default.
    pub enabled: bool,
    /// Failure threshold; 5 by default.
    pub failure_threshold: u32,
    /// Open duration in seconds; 10 by default.
    pub open_secs: u64,
    /// Success threshold; 3 by default.
    pub success_threshold: u32,
}
777
778impl Default for CircuitBreakerToml {
779 fn default() -> Self {
780 Self {
781 enabled: false,
782 failure_threshold: 5,
783 open_secs: 10,
784 success_threshold: 3,
785 }
786 }
787}
788
/// Key selector for `RateLimitToml::key_by`.
///
/// Serde maps each variant to its snake_case name (`"user"`, `"client_ip"`,
/// `"database"`, `"global"`). `User` is the `Default`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum RateLimitKeyBy {
    /// Default variant; serialized as `"user"`.
    #[default]
    User,
    /// Serialized as `"client_ip"`.
    ClientIp,
    /// Serialized as `"database"`.
    Database,
    /// Serialized as `"global"`.
    Global,
}
803
/// Settings for the `rate_limit` section of [`ProxyConfig`].
///
/// Missing keys take their value from `RateLimitToml::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct RateLimitToml {
    /// Feature switch; `false` by default.
    pub enabled: bool,
    /// Default queries-per-second value; 1000 by default.
    pub default_qps: u32,
    /// Default burst value; 2000 by default.
    pub default_burst: u32,
    /// Concurrency limit; 0 by default.
    /// NOTE(review): whether 0 means "unlimited" is not visible in this file.
    pub max_concurrent: u32,
    /// Key selector; `RateLimitKeyBy::User` by default.
    pub key_by: RateLimitKeyBy,
}
822
823impl Default for RateLimitToml {
824 fn default() -> Self {
825 Self {
826 enabled: false,
827 default_qps: 1000,
828 default_burst: 2000,
829 max_concurrent: 0,
830 key_by: RateLimitKeyBy::User,
831 }
832 }
833}
834
/// Settings for the `routing_hints` section of [`ProxyConfig`].
///
/// Missing keys take their value from `RoutingHintsConfig::default()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct RoutingHintsConfig {
    /// Feature switch; `false` by default.
    pub enabled: bool,
    /// `true` by default.
    /// NOTE(review): presumably removes hint text from a query before it is
    /// forwarded — confirm against the hint parser.
    pub strip_hints: bool,
}
851
852impl Default for RoutingHintsConfig {
853 fn default() -> Self {
854 Self {
855 enabled: false,
856 strip_hints: true,
857 }
858 }
859}
860
861impl Default for ProxyConfig {
862 fn default() -> Self {
863 Self {
864 listen_address: "0.0.0.0:5432".to_string(),
865 admin_address: "0.0.0.0:9090".to_string(),
866 admin_token: None,
867 tr_enabled: true,
868 tr_mode: TrMode::Session,
869 pool: PoolConfig::default(),
870 pool_mode: PoolModeConfig::default(),
871 load_balancer: LoadBalancerConfig::default(),
872 health: HealthConfig::default(),
873 nodes: Vec::new(),
874 tls: None,
875 write_timeout_secs: default_write_timeout_secs(),
876 plugins: PluginToml::default(),
877 hba: Vec::new(),
878 auth: AuthConfig::default(),
879 mcp: McpConfig::default(),
880 agent_contracts: Vec::new(),
881 http_gateway: HttpGatewayConfig::default(),
882 mirror: MirrorConfig::default(),
883 branch: BranchConfig::default(),
884 routing_hints: RoutingHintsConfig::default(),
885 rate_limit: RateLimitToml::default(),
886 circuit_breaker: CircuitBreakerToml::default(),
887 analytics: AnalyticsToml::default(),
888 lag_routing: LagRoutingToml::default(),
889 cache: CacheToml::default(),
890 query_rewrite: QueryRewriteToml::default(),
891 multi_tenancy: MultiTenancyToml::default(),
892 schema_routing: SchemaRoutingToml::default(),
893 graphql_gateway: GraphqlGatewayConfig::default(),
894 optimize_unnamed_parse: true,
895 shutdown_drain_timeout_secs: default_drain_timeout_secs(),
896 }
897 }
898}
899
/// Settings for the `plugins` section of [`ProxyConfig`].
///
/// Every field has a serde default, so the section and each of its keys may
/// be omitted.
/// NOTE(review): the plugin runtime that consumes these limits lives outside
/// this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PluginToml {
    /// Feature switch; `false` when absent.
    #[serde(default)]
    pub enabled: bool,
    /// Plugin directory; defaults to `/etc/heliosproxy/plugins`.
    #[serde(default = "default_plugin_dir")]
    pub plugin_dir: String,
    /// Hot-reload switch; `false` when absent.
    #[serde(default)]
    pub hot_reload: bool,
    /// Memory limit in MB; defaults to 64.
    #[serde(default = "default_plugin_memory_mb")]
    pub memory_limit_mb: usize,
    /// Timeout in milliseconds; defaults to 100.
    #[serde(default = "default_plugin_timeout_ms")]
    pub timeout_ms: u64,
    /// Maximum number of plugins; defaults to 20.
    #[serde(default = "default_plugin_max")]
    pub max_plugins: usize,
    /// Fuel-metering switch; defaults to `true`.
    #[serde(default = "default_true")]
    pub fuel_metering: bool,
    /// Fuel limit; defaults to 1 000 000.
    #[serde(default = "default_plugin_fuel")]
    pub fuel_limit: u64,
    /// Optional trust root; `None` when absent.
    /// NOTE(review): presumably used to verify plugins before loading —
    /// confirm against the plugin loader.
    #[serde(default)]
    pub trust_root: Option<String>,
}
948
/// Serde fallback for `PluginToml::plugin_dir`.
fn default_plugin_dir() -> String {
    String::from("/etc/heliosproxy/plugins")
}

/// Serde fallback for `PluginToml::memory_limit_mb`.
fn default_plugin_memory_mb() -> usize {
    const MEMORY_LIMIT_MB: usize = 64;
    MEMORY_LIMIT_MB
}

/// Serde fallback for `PluginToml::timeout_ms`.
fn default_plugin_timeout_ms() -> u64 {
    const TIMEOUT_MS: u64 = 100;
    TIMEOUT_MS
}

/// Serde fallback for `PluginToml::max_plugins`.
fn default_plugin_max() -> usize {
    const MAX_PLUGINS: usize = 20;
    MAX_PLUGINS
}

/// Serde fallback for boolean fields that default to enabled
/// (`PluginToml::fuel_metering`, `ProxyConfig::optimize_unnamed_parse`).
fn default_true() -> bool {
    const ENABLED: bool = true;
    ENABLED
}

/// Serde fallback for `PluginToml::fuel_limit`.
fn default_plugin_fuel() -> u64 {
    const FUEL_LIMIT: u64 = 1_000_000;
    FUEL_LIMIT
}
967
968impl Default for PluginToml {
969 fn default() -> Self {
970 Self {
971 enabled: false,
972 plugin_dir: default_plugin_dir(),
973 hot_reload: false,
974 memory_limit_mb: default_plugin_memory_mb(),
975 timeout_ms: default_plugin_timeout_ms(),
976 max_plugins: default_plugin_max(),
977 fuel_metering: true,
978 fuel_limit: default_plugin_fuel(),
979 trust_root: None,
980 }
981 }
982}
983
984impl ProxyConfig {
985 pub fn write_timeout(&self) -> Duration {
987 Duration::from_secs(self.write_timeout_secs)
988 }
989
990 pub fn from_file(path: &str) -> Result<Self> {
992 let path = Path::new(path);
993
994 if !path.exists() {
995 return Err(ProxyError::Config(format!(
996 "Configuration file not found: {}",
997 path.display()
998 )));
999 }
1000
1001 let contents = std::fs::read_to_string(path)
1002 .map_err(|e| ProxyError::Config(format!("Failed to read config: {}", e)))?;
1003
1004 let config: Self = toml::from_str(&contents)
1005 .map_err(|e| ProxyError::Config(format!("Failed to parse config: {}", e)))?;
1006
1007 config.validate()?;
1008
1009 Ok(config)
1010 }
1011
1012 pub fn add_node(&mut self, host_port: &str, role: &str) -> Result<()> {
1014 let parts: Vec<&str> = host_port.rsplitn(2, ':').collect();
1015 if parts.len() != 2 {
1016 return Err(ProxyError::Config(format!(
1017 "Invalid host:port format: {}",
1018 host_port
1019 )));
1020 }
1021
1022 let port: u16 = parts[0]
1023 .parse()
1024 .map_err(|_| ProxyError::Config(format!("Invalid port: {}", parts[0])))?;
1025
1026 let host = parts[1].to_string();
1027
1028 let role = match role {
1029 "primary" => NodeRole::Primary,
1030 "standby" => NodeRole::Standby,
1031 "replica" => NodeRole::ReadReplica,
1032 _ => return Err(ProxyError::Config(format!("Unknown role: {}", role))),
1033 };
1034
1035 self.nodes.push(NodeConfig {
1036 host,
1037 port,
1038 http_port: default_http_port(),
1039 role,
1040 weight: 100,
1041 enabled: true,
1042 name: None,
1043 });
1044
1045 Ok(())
1046 }
1047
1048 pub fn validate(&self) -> Result<()> {
1050 if self.nodes.is_empty() {
1052 return Err(ProxyError::Config(
1053 "No backend nodes configured".to_string(),
1054 ));
1055 }
1056
1057 let has_primary = self.nodes.iter().any(|n| n.role == NodeRole::Primary);
1059 if !has_primary {
1060 return Err(ProxyError::Config("No primary node configured".to_string()));
1061 }
1062
1063 if self.pool.max_connections < self.pool.min_connections {
1065 return Err(ProxyError::Config(
1066 "max_connections must be >= min_connections".to_string(),
1067 ));
1068 }
1069
1070 Ok(())
1071 }
1072
1073 pub fn primary_node(&self) -> Option<&NodeConfig> {
1075 self.nodes
1076 .iter()
1077 .find(|n| n.role == NodeRole::Primary && n.enabled)
1078 }
1079
1080 pub fn standby_nodes(&self) -> Vec<&NodeConfig> {
1082 self.nodes
1083 .iter()
1084 .filter(|n| n.role == NodeRole::Standby && n.enabled)
1085 .collect()
1086 }
1087
1088 pub fn enabled_nodes(&self) -> Vec<&NodeConfig> {
1090 self.nodes.iter().filter(|n| n.enabled).collect()
1091 }
1092}
1093
1094#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
1096#[serde(rename_all = "lowercase")]
1097#[derive(Default)]
1098pub enum TrMode {
1099 None,
1101 #[default]
1103 Session,
1104 Select,
1106 Transaction,
1108}
1109
/// Settings for the required `pool` section of [`ProxyConfig`].
///
/// No field carries a serde default, so every key must be present when the
/// section is deserialized. `PoolConfig::default()` is only used when a
/// configuration is built in code.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PoolConfig {
    /// Minimum number of connections; `validate()` requires it to be no
    /// larger than `max_connections`.
    pub min_connections: usize,
    /// Maximum number of connections.
    pub max_connections: usize,
    /// Idle timeout in seconds. See `idle_timeout()`.
    pub idle_timeout_secs: u64,
    /// Maximum connection lifetime in seconds. See `max_lifetime()`.
    pub max_lifetime_secs: u64,
    /// Acquire timeout in seconds. See `acquire_timeout()`.
    pub acquire_timeout_secs: u64,
    /// NOTE(review): presumably checks a connection before handing it out —
    /// confirm against the pool implementation.
    pub test_on_acquire: bool,
}
1126
1127impl Default for PoolConfig {
1128 fn default() -> Self {
1129 Self {
1130 min_connections: 2,
1131 max_connections: 100,
1132 idle_timeout_secs: 300,
1133 max_lifetime_secs: 1800,
1134 acquire_timeout_secs: 30,
1135 test_on_acquire: true,
1136 }
1137 }
1138}
1139
1140impl PoolConfig {
1141 pub fn idle_timeout(&self) -> Duration {
1143 Duration::from_secs(self.idle_timeout_secs)
1144 }
1145
1146 pub fn max_lifetime(&self) -> Duration {
1148 Duration::from_secs(self.max_lifetime_secs)
1149 }
1150
1151 pub fn acquire_timeout(&self) -> Duration {
1153 Duration::from_secs(self.acquire_timeout_secs)
1154 }
1155}
1156
/// Settings for the required `load_balancer` section of [`ProxyConfig`].
///
/// No field carries a serde default, so every key must be present when the
/// section is deserialized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoadBalancerConfig {
    /// Strategy used for reads; see [`Strategy`].
    pub read_strategy: Strategy,
    /// NOTE(review): presumably routes reads and writes to different nodes
    /// when `true` — confirm against the load balancer.
    pub read_write_split: bool,
    /// Latency threshold in milliseconds.
    /// NOTE(review): how the threshold is applied is not visible here.
    pub latency_threshold_ms: u64,
}
1167
1168impl Default for LoadBalancerConfig {
1169 fn default() -> Self {
1170 Self {
1171 read_strategy: Strategy::RoundRobin,
1172 read_write_split: true,
1173 latency_threshold_ms: 100,
1174 }
1175 }
1176}
1177
/// Load-balancing strategy for `LoadBalancerConfig::read_strategy`.
///
/// Serde maps each variant to its snake_case name. There is no `Default`
/// on the enum itself.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Strategy {
    /// Serialized as `"round_robin"`.
    RoundRobin,
    /// Serialized as `"weighted_round_robin"`.
    WeightedRoundRobin,
    /// Serialized as `"least_connections"`.
    LeastConnections,
    /// Serialized as `"latency_based"`.
    LatencyBased,
    /// Serialized as `"random"`.
    Random,
}
1193
/// Settings for the required `health` section of [`ProxyConfig`].
///
/// No field carries a serde default, so every key must be present when the
/// section is deserialized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HealthConfig {
    /// Check interval in seconds. See `check_interval()`.
    pub check_interval_secs: u64,
    /// Check timeout in seconds. See `check_timeout()`.
    pub check_timeout_secs: u64,
    /// Failure threshold.
    /// NOTE(review): presumably consecutive failures before a node is marked
    /// unhealthy — confirm against the health checker.
    pub failure_threshold: u32,
    /// Success threshold.
    /// NOTE(review): presumably consecutive successes before a node is
    /// marked healthy again — confirm against the health checker.
    pub success_threshold: u32,
    /// Query text used for the check.
    pub check_query: String,
}
1208
1209impl Default for HealthConfig {
1210 fn default() -> Self {
1211 Self {
1212 check_interval_secs: 5,
1213 check_timeout_secs: 3,
1214 failure_threshold: 3,
1215 success_threshold: 2,
1216 check_query: "SELECT 1".to_string(),
1217 }
1218 }
1219}
1220
1221impl HealthConfig {
1222 pub fn check_interval(&self) -> Duration {
1224 Duration::from_secs(self.check_interval_secs)
1225 }
1226
1227 pub fn check_timeout(&self) -> Duration {
1229 Duration::from_secs(self.check_timeout_secs)
1230 }
1231}
1232
/// One backend node in `ProxyConfig::nodes`.
///
/// Only `http_port` has a serde default and `name` is optional; every other
/// key must be present when a node is deserialized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeConfig {
    /// Host, used verbatim by `address()`.
    pub host: String,
    /// Port, used by `address()`.
    pub port: u16,
    /// HTTP port; defaults to 8080.
    /// NOTE(review): what is served on this port is not visible here.
    #[serde(default = "default_http_port")]
    pub http_port: u16,
    /// Node role; see [`NodeRole`].
    pub role: NodeRole,
    /// Weight; `ProxyConfig::add_node` assigns 100.
    pub weight: u32,
    /// Disabled nodes are skipped by `primary_node()`, `standby_nodes()` and
    /// `enabled_nodes()`.
    pub enabled: bool,
    /// Optional name; `display_name()` falls back to `host` when `None`.
    pub name: Option<String>,
}
1253
/// Serde fallback for `NodeConfig::http_port`.
fn default_http_port() -> u16 {
    const HTTP_PORT: u16 = 8080;
    HTTP_PORT
}
1257
1258impl NodeConfig {
1259 pub fn address(&self) -> String {
1261 format!("{}:{}", self.host, self.port)
1262 }
1263
1264 pub fn display_name(&self) -> &str {
1266 self.name.as_deref().unwrap_or(&self.host)
1267 }
1268}
1269
/// Role of a backend node.
///
/// Serialized as `"primary"`, `"standby"` and `"replica"` — the same strings
/// `ProxyConfig::add_node` accepts.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum NodeRole {
    /// Serialized as `"primary"`. `validate()` requires at least one node
    /// with this role.
    Primary,
    /// Serialized as `"standby"`.
    Standby,
    /// Serialized as `"replica"` (explicit rename; the lowercase rule alone
    /// would produce `"readreplica"`).
    #[serde(rename = "replica")]
    ReadReplica,
}
1282
/// Settings for the optional `tls` section of [`ProxyConfig`].
///
/// When the section is present, every key except `ca_path` is required (no
/// serde defaults are declared).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TlsConfig {
    /// TLS switch.
    pub enabled: bool,
    /// Path to the certificate file.
    pub cert_path: String,
    /// Path to the private key file.
    pub key_path: String,
    /// Optional path to a CA file; `None` when absent.
    pub ca_path: Option<String>,
    /// NOTE(review): presumably makes client certificates mandatory —
    /// confirm against the TLS acceptor setup.
    pub require_client_cert: bool,
}
1297
// Unit tests for configuration defaults, node management, validation and
// TOML deserialization.
#[cfg(test)]
mod tests {
    use super::*;

    // The default config listens on 0.0.0.0:5432 and has `tr_enabled` set.
    #[test]
    fn test_default_config() {
        let config = ProxyConfig::default();
        assert_eq!(config.listen_address, "0.0.0.0:5432");
        assert!(config.tr_enabled);
    }

    // `add_node` appends nodes and the role-based accessors find them.
    #[test]
    fn test_add_node() {
        let mut config = ProxyConfig::default();
        config.add_node("localhost:5432", "primary").unwrap();
        config.add_node("localhost:5433", "standby").unwrap();

        assert_eq!(config.nodes.len(), 2);
        assert!(config.primary_node().is_some());
        assert_eq!(config.standby_nodes().len(), 1);
    }

    // A config without any nodes must be rejected.
    #[test]
    fn test_validate_no_nodes() {
        let config = ProxyConfig::default();
        assert!(config.validate().is_err());
    }

    // A config with nodes but no primary must be rejected.
    #[test]
    fn test_validate_no_primary() {
        let mut config = ProxyConfig::default();
        config.add_node("localhost:5432", "standby").unwrap();
        assert!(config.validate().is_err());
    }

    // A single primary node is enough to pass validation.
    #[test]
    fn test_validate_success() {
        let mut config = ProxyConfig::default();
        config.add_node("localhost:5432", "primary").unwrap();
        assert!(config.validate().is_ok());
    }

    // `PoolConfig` duration accessors mirror the `*_secs` defaults.
    #[test]
    fn test_pool_config_durations() {
        let config = PoolConfig::default();
        assert_eq!(config.idle_timeout(), Duration::from_secs(300));
        assert_eq!(config.max_lifetime(), Duration::from_secs(1800));
    }

    // `PoolModeConfig::default()` values.
    #[test]
    fn test_pool_mode_default() {
        let config = PoolModeConfig::default();
        assert_eq!(config.mode, PoolingMode::Session);
        assert_eq!(config.max_pool_size, 100);
        assert_eq!(config.min_idle, 10);
        assert_eq!(config.reset_query, "DISCARD ALL");
    }

    // Session preset pairs session pooling with named prepared statements.
    #[test]
    fn test_pool_mode_session() {
        let config = PoolModeConfig::session_mode();
        assert_eq!(config.mode, PoolingMode::Session);
        assert_eq!(config.prepared_statement_mode, PreparedStatementMode::Named);
    }

    // Transaction preset pairs transaction pooling with tracked statements.
    #[test]
    fn test_pool_mode_transaction() {
        let config = PoolModeConfig::transaction_mode();
        assert_eq!(config.mode, PoolingMode::Transaction);
        assert_eq!(config.prepared_statement_mode, PreparedStatementMode::Track);
    }

    // Statement preset disables prepared statements.
    #[test]
    fn test_pool_mode_statement() {
        let config = PoolModeConfig::statement_mode();
        assert_eq!(config.mode, PoolingMode::Statement);
        assert_eq!(
            config.prepared_statement_mode,
            PreparedStatementMode::Disable
        );
    }

    // `PoolModeConfig` duration accessors mirror the `*_secs` defaults.
    #[test]
    fn test_pool_mode_durations() {
        let config = PoolModeConfig::default();
        assert_eq!(config.idle_timeout(), Duration::from_secs(600));
        assert_eq!(config.max_lifetime(), Duration::from_secs(3600));
        assert_eq!(config.acquire_timeout(), Duration::from_secs(5));
    }

    // The default `ProxyConfig` carries the default pool mode.
    #[test]
    fn test_proxy_config_has_pool_mode() {
        let config = ProxyConfig::default();
        assert_eq!(config.pool_mode.mode, PoolingMode::Session);
    }

    // Plugins are off by default and use the stock directory and limits.
    #[test]
    fn test_plugin_toml_default_is_disabled() {
        let config = ProxyConfig::default();
        assert!(!config.plugins.enabled);
        assert_eq!(config.plugins.plugin_dir, "/etc/heliosproxy/plugins");
        assert_eq!(config.plugins.memory_limit_mb, 64);
        assert_eq!(config.plugins.timeout_ms, 100);
    }

    // A TOML document with only the required keys and sections still parses,
    // and the omitted `[plugins]` section falls back to its default.
    #[test]
    fn test_proxy_config_toml_without_plugins_section_still_parses() {
        let toml_text = r#"
            listen_address = "0.0.0.0:5432"
            admin_address = "0.0.0.0:9090"
            tr_enabled = true
            tr_mode = "session"
            nodes = []

            [pool]
            min_connections = 2
            max_connections = 10
            idle_timeout_secs = 300
            max_lifetime_secs = 1800
            acquire_timeout_secs = 30
            test_on_acquire = true

            [load_balancer]
            read_strategy = "round_robin"
            read_write_split = true
            latency_threshold_ms = 100

            [health]
            check_interval_secs = 5
            check_timeout_secs = 3
            failure_threshold = 3
            success_threshold = 2
            check_query = "SELECT 1"
        "#;
        let config: ProxyConfig = toml::from_str(toml_text).expect("parse");
        assert!(!config.plugins.enabled);
    }

    // Keys given in `[plugins]` override the defaults, while keys left out
    // of the section keep their per-field defaults.
    #[test]
    fn test_plugin_toml_overrides_parse() {
        let toml_text = r#"
            listen_address = "0.0.0.0:5432"
            admin_address = "0.0.0.0:9090"
            tr_enabled = true
            tr_mode = "session"
            nodes = []

            [pool]
            min_connections = 2
            max_connections = 10
            idle_timeout_secs = 300
            max_lifetime_secs = 1800
            acquire_timeout_secs = 30
            test_on_acquire = true

            [load_balancer]
            read_strategy = "round_robin"
            read_write_split = true
            latency_threshold_ms = 100

            [health]
            check_interval_secs = 5
            check_timeout_secs = 3
            failure_threshold = 3
            success_threshold = 2
            check_query = "SELECT 1"

            [plugins]
            enabled = true
            plugin_dir = "/tmp/helios-plugins"
            hot_reload = true
            memory_limit_mb = 128
            timeout_ms = 250
        "#;
        let config: ProxyConfig = toml::from_str(toml_text).expect("parse");
        assert!(config.plugins.enabled);
        assert_eq!(config.plugins.plugin_dir, "/tmp/helios-plugins");
        assert!(config.plugins.hot_reload);
        assert_eq!(config.plugins.memory_limit_mb, 128);
        assert_eq!(config.plugins.timeout_ms, 250);
        // Not set in the TOML above, so the per-field defaults apply.
        assert_eq!(config.plugins.max_plugins, 20);
        assert!(config.plugins.fuel_metering);
    }
}