Skip to main content

heliosdb_proxy/
config.rs

1//! Proxy Configuration
2//!
3//! Configuration management for HeliosDB Proxy.
4
5use crate::{ProxyError, Result};
6use serde::{Deserialize, Serialize};
7use std::path::Path;
8use std::time::Duration;
9
10// =============================================================================
11// POOL MODE TYPES
12// =============================================================================
13
/// Connection pooling mode
///
/// Determines when connections are returned to the pool.
///
/// Deserialized from lowercase names (`"session"`, `"transaction"`,
/// `"statement"`); defaults to [`PoolingMode::Session`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum PoolingMode {
    /// Session mode: 1:1 client-to-backend mapping
    #[default]
    Session,
    /// Transaction mode: Return after COMMIT/ROLLBACK
    Transaction,
    /// Statement mode: Return after each statement
    Statement,
}
28
/// Prepared statement handling mode
///
/// Deserialized from lowercase names (`"disable"`, `"track"`, `"named"`);
/// defaults to [`PreparedStatementMode::Disable`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum PreparedStatementMode {
    /// Disable prepared statements
    #[default]
    Disable,
    /// Track and recreate on new connections
    Track,
    /// Use protocol-level named statements
    Named,
}
41
/// Pool mode configuration
///
/// Every field carries a serde default, so a partial table is accepted; the
/// default for each field is noted on it.
///
/// NOTE(review): nothing visible in this file checks `min_idle <= max_pool_size`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PoolModeConfig {
    /// Default pooling mode (default: `session`)
    #[serde(default)]
    pub mode: PoolingMode,
    /// Maximum connections per node (default: 100)
    #[serde(default = "default_pool_mode_max_size")]
    pub max_pool_size: u32,
    /// Minimum idle connections (default: 10)
    #[serde(default = "default_pool_mode_min_idle")]
    pub min_idle: u32,
    /// Idle timeout in seconds (default: 600)
    #[serde(default = "default_pool_mode_idle_timeout")]
    pub idle_timeout_secs: u64,
    /// Max connection lifetime in seconds (default: 3600)
    #[serde(default = "default_pool_mode_max_lifetime")]
    pub max_lifetime_secs: u64,
    /// Acquire timeout in seconds (default: 5)
    #[serde(default = "default_pool_mode_acquire_timeout")]
    pub acquire_timeout_secs: u64,
    /// Reset query to run when returning connection to pool
    /// (default: `DISCARD ALL`)
    #[serde(default = "default_reset_query")]
    pub reset_query: String,
    /// Prepared statement mode (default: `disable`)
    #[serde(default)]
    pub prepared_statement_mode: PreparedStatementMode,
}
70
/// Serde default for `PoolModeConfig::max_pool_size`: 100 connections per node.
fn default_pool_mode_max_size() -> u32 {
    100_u32
}

/// Serde default for `PoolModeConfig::min_idle`: 10 idle connections.
fn default_pool_mode_min_idle() -> u32 {
    10_u32
}

/// Serde default for `PoolModeConfig::idle_timeout_secs`: ten minutes.
fn default_pool_mode_idle_timeout() -> u64 {
    10 * 60
}

/// Serde default for `PoolModeConfig::max_lifetime_secs`: one hour.
fn default_pool_mode_max_lifetime() -> u64 {
    60 * 60
}

/// Serde default for `PoolModeConfig::acquire_timeout_secs`: five seconds.
fn default_pool_mode_acquire_timeout() -> u64 {
    5_u64
}

/// Serde default for `PoolModeConfig::reset_query`.
fn default_reset_query() -> String {
    String::from("DISCARD ALL")
}
94
95impl Default for PoolModeConfig {
96    fn default() -> Self {
97        Self {
98            mode: PoolingMode::default(),
99            max_pool_size: default_pool_mode_max_size(),
100            min_idle: default_pool_mode_min_idle(),
101            idle_timeout_secs: default_pool_mode_idle_timeout(),
102            max_lifetime_secs: default_pool_mode_max_lifetime(),
103            acquire_timeout_secs: default_pool_mode_acquire_timeout(),
104            reset_query: default_reset_query(),
105            prepared_statement_mode: PreparedStatementMode::default(),
106        }
107    }
108}
109
110impl PoolModeConfig {
111    /// Create config for session mode
112    pub fn session_mode() -> Self {
113        Self {
114            mode: PoolingMode::Session,
115            prepared_statement_mode: PreparedStatementMode::Named,
116            ..Default::default()
117        }
118    }
119
120    /// Create config for transaction mode
121    pub fn transaction_mode() -> Self {
122        Self {
123            mode: PoolingMode::Transaction,
124            prepared_statement_mode: PreparedStatementMode::Track,
125            ..Default::default()
126        }
127    }
128
129    /// Create config for statement mode
130    pub fn statement_mode() -> Self {
131        Self {
132            mode: PoolingMode::Statement,
133            prepared_statement_mode: PreparedStatementMode::Disable,
134            ..Default::default()
135        }
136    }
137
138    /// Get idle timeout as Duration
139    pub fn idle_timeout(&self) -> Duration {
140        Duration::from_secs(self.idle_timeout_secs)
141    }
142
143    /// Get max lifetime as Duration
144    pub fn max_lifetime(&self) -> Duration {
145        Duration::from_secs(self.max_lifetime_secs)
146    }
147
148    /// Get acquire timeout as Duration
149    pub fn acquire_timeout(&self) -> Duration {
150        Duration::from_secs(self.acquire_timeout_secs)
151    }
152}
153
154// =============================================================================
155// MAIN PROXY CONFIG
156// =============================================================================
157
/// Proxy configuration
///
/// Top-level shape of the TOML file loaded by [`ProxyConfig::from_file`].
/// Fields without a `#[serde(default ...)]` attribute (`listen_address`,
/// `admin_address`, `tr_enabled`, `tr_mode`, `pool`, `load_balancer`,
/// `health`, `nodes`) must be present in the file; `tls` is an `Option`
/// and may be omitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProxyConfig {
    /// Listen address for client connections
    pub listen_address: String,
    /// Admin API address
    pub admin_address: String,
    /// Bearer token required on admin API requests. When set, every admin
    /// endpoint except liveness probes (`/health*`, `/livez`, `/readyz`)
    /// requires `Authorization: Bearer <token>`. Absent (default) = open
    /// (current behaviour) — set this for any non-loopback deployment.
    #[serde(default)]
    pub admin_token: Option<String>,
    /// Enable TR (Transaction Replay)
    pub tr_enabled: bool,
    /// TR mode
    pub tr_mode: TrMode,
    /// Connection pool configuration
    pub pool: PoolConfig,
    /// Pool mode configuration (Session/Transaction/Statement)
    #[serde(default)]
    pub pool_mode: PoolModeConfig,
    /// Load balancer configuration
    pub load_balancer: LoadBalancerConfig,
    /// Health check configuration
    pub health: HealthConfig,
    /// Backend nodes
    pub nodes: Vec<NodeConfig>,
    /// TLS configuration
    pub tls: Option<TlsConfig>,
    /// Write timeout during failover (seconds, default 30).
    /// When primary is unavailable, wait this long for a new primary before returning error
    #[serde(default = "default_write_timeout_secs")]
    pub write_timeout_secs: u64,
    /// Plugin system configuration. Only consumed when the `wasm-plugins`
    /// feature is enabled; on a feature-off build, values are parsed and
    /// ignored so existing configs don't break.
    #[serde(default)]
    pub plugins: PluginToml,
    /// pg_hba-style connection admission rules, evaluated in order before any
    /// backend connection is opened. Empty (the default) means admit all
    /// (current behaviour preserved).
    #[serde(default)]
    pub hba: Vec<HbaRule>,
    /// Client authentication mode. Absent/default = pass-through (the proxy
    /// relays the client's auth to the backend, current behaviour).
    #[serde(default)]
    pub auth: AuthConfig,
    /// MCP (Model Context Protocol) agent gateway. Disabled by default.
    #[serde(default)]
    pub mcp: McpConfig,
    /// Per-agent SQL contracts (scoped grants). Referenced by id from the
    /// MCP gateway (`[mcp] contract`). Empty by default.
    #[serde(default)]
    pub agent_contracts: Vec<crate::agent_contract::AgentContract>,
    /// HTTP SQL gateway (Neon-serverless-driver compatible). Disabled by
    /// default — lets edge/serverless clients run SQL over HTTP.
    #[serde(default)]
    pub http_gateway: HttpGatewayConfig,
    /// Continuous traffic mirroring to a secondary backend. Disabled by
    /// default — the on-ramp to a PG->Nano migration mirror.
    #[serde(default)]
    pub mirror: MirrorConfig,
    /// Instant branch databases. Disabled by default — provisions
    /// CREATE DATABASE ... TEMPLATE clones through the proxy.
    #[serde(default)]
    pub branch: BranchConfig,
    /// Proxy-side unnamed-`Parse` promotion (Batch H). When a client re-sends an
    /// identical unnamed extended `Parse` (the dominant pgbench/ORM pattern),
    /// the proxy skips forwarding it to a backend that already holds that exact
    /// unnamed statement and synthesizes the `ParseComplete` locally — cutting
    /// the per-cycle re-`Parse` overhead. Default on; a kill-switch for drivers
    /// that somehow depend on the redundant round trip.
    #[serde(default = "default_true")]
    pub optimize_unnamed_parse: bool,
    /// How long a graceful binary-handoff drain (SIGUSR2) keeps serving
    /// in-flight connections before the old process exits (Batch H). After this
    /// many seconds (default 60), any still-open connections are dropped so the
    /// handoff completes in bounded time. Overridable at runtime via the
    /// `HELIOS_DRAIN_TIMEOUT_SECS` env var.
    #[serde(default = "default_drain_timeout_secs")]
    pub shutdown_drain_timeout_secs: u64,
}
241
/// Serde default for `ProxyConfig::shutdown_drain_timeout_secs`: one minute.
fn default_drain_timeout_secs() -> u64 {
    60_u64
}
245
/// Branch-database configuration: the maintenance connection the proxy uses
/// to provision `CREATE DATABASE <branch> TEMPLATE <base>` clones.
///
/// Disabled by default. Connection fields default to `127.0.0.1:5432` as
/// user `postgres`; `admin_password` is optional.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BranchConfig {
    #[serde(default)]
    pub enabled: bool,
    #[serde(default = "default_localhost")]
    pub backend_host: String,
    #[serde(default = "default_pg_port")]
    pub backend_port: u16,
    /// A role with CREATEDB privilege.
    #[serde(default = "default_pg_user")]
    pub admin_user: String,
    pub admin_password: Option<String>,
    /// Maintenance database to issue CREATE/DROP DATABASE against (not the
    /// branch itself). Defaults to "postgres".
    #[serde(default = "default_admin_db")]
    pub admin_database: String,
    /// Default template database to branch from when a request omits `base`.
    /// Defaults to "postgres".
    #[serde(default = "default_admin_db")]
    pub base_database: String,
}
268
269impl Default for BranchConfig {
270    fn default() -> Self {
271        Self {
272            enabled: false,
273            backend_host: default_localhost(),
274            backend_port: default_pg_port(),
275            admin_user: default_pg_user(),
276            admin_password: None,
277            admin_database: default_admin_db(),
278            base_database: default_admin_db(),
279        }
280    }
281}
282
/// Serde default for the branch maintenance and template databases.
fn default_admin_db() -> String {
    String::from("postgres")
}
286
/// Traffic-mirror configuration: replay a sampled share of live (simple-query)
/// writes to a secondary backend, asynchronously and off the client hot path.
///
/// Disabled by default. `backend_*` describes the secondary being mirrored
/// to; `source_*` describes the primary read by the snapshot endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MirrorConfig {
    #[serde(default)]
    pub enabled: bool,
    /// Fraction of eligible statements to mirror, 0.0..=1.0 (default 1.0).
    #[serde(default = "default_sample_rate")]
    pub sample_rate: f64,
    /// Mirror only write/DDL statements (default). When false, all simple
    /// queries are mirrored.
    #[serde(default = "default_true_bool")]
    pub writes_only: bool,
    /// Bounded queue depth (default 10,000); when full, statements are
    /// dropped (and counted) rather than blocking the client path.
    #[serde(default = "default_mirror_queue")]
    pub queue_size: usize,
    #[serde(default = "default_localhost")]
    pub backend_host: String,
    #[serde(default = "default_pg_port")]
    pub backend_port: u16,
    #[serde(default = "default_pg_user")]
    pub backend_user: String,
    pub backend_password: Option<String>,
    pub backend_database: Option<String>,
    /// Source (primary) connection used by `POST /api/migration/snapshot` to
    /// read existing data when bootstrapping the secondary. Defaults to
    /// `127.0.0.1:5432` as user `postgres`; set explicitly for a snapshot.
    #[serde(default = "default_localhost")]
    pub source_host: String,
    #[serde(default = "default_pg_port")]
    pub source_port: u16,
    #[serde(default = "default_pg_user")]
    pub source_user: String,
    pub source_password: Option<String>,
    pub source_database: Option<String>,
}
324
325impl Default for MirrorConfig {
326    fn default() -> Self {
327        Self {
328            enabled: false,
329            sample_rate: 1.0,
330            writes_only: true,
331            queue_size: 10_000,
332            backend_host: default_localhost(),
333            backend_port: default_pg_port(),
334            backend_user: default_pg_user(),
335            backend_password: None,
336            backend_database: None,
337            source_host: default_localhost(),
338            source_port: default_pg_port(),
339            source_user: default_pg_user(),
340            source_password: None,
341            source_database: None,
342        }
343    }
344}
345
/// Serde default for `MirrorConfig::sample_rate`: mirror every eligible statement.
fn default_sample_rate() -> f64 {
    1_f64
}
/// Serde default for `MirrorConfig::queue_size`.
fn default_mirror_queue() -> usize {
    10000
}
352
/// HTTP SQL gateway configuration. A Neon-`@neondatabase/serverless`-style
/// `POST /sql` endpoint that runs one statement over the backend PG-wire
/// client and returns `{ command, rowCount, rows, fields }`.
///
/// Disabled by default; listens on `127.0.0.1:9093` and targets
/// `postgres@127.0.0.1:5432` unless overridden.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HttpGatewayConfig {
    #[serde(default)]
    pub enabled: bool,
    #[serde(default = "default_http_gw_listen")]
    pub listen_address: String,
    #[serde(default = "default_localhost")]
    pub backend_host: String,
    #[serde(default = "default_pg_port")]
    pub backend_port: u16,
    #[serde(default = "default_pg_user")]
    pub backend_user: String,
    pub backend_password: Option<String>,
    pub backend_database: Option<String>,
    /// Optional Bearer token required on requests.
    #[serde(default)]
    pub auth_token: Option<String>,
}
374
375impl Default for HttpGatewayConfig {
376    fn default() -> Self {
377        Self {
378            enabled: false,
379            listen_address: default_http_gw_listen(),
380            backend_host: default_localhost(),
381            backend_port: default_pg_port(),
382            backend_user: default_pg_user(),
383            backend_password: None,
384            backend_database: None,
385            auth_token: None,
386        }
387    }
388}
389
/// Serde default for the HTTP SQL gateway listen address (loopback only).
fn default_http_gw_listen() -> String {
    String::from("127.0.0.1:9093")
}
393
/// MCP agent-gateway configuration. When enabled, the proxy exposes a native
/// MCP server so AI agents call `query`/`list_tables`/`explain` tools instead
/// of opening raw SQL connections — each call gated by the gateway's policy
/// (read-only by default) and logged.
///
/// Disabled by default; listens on `127.0.0.1:9092` unless overridden.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpConfig {
    #[serde(default)]
    pub enabled: bool,
    /// HTTP listen address for the MCP JSON-RPC endpoint.
    #[serde(default = "default_mcp_listen")]
    pub listen_address: String,
    /// Backend the gateway runs tool SQL against.
    #[serde(default = "default_localhost")]
    pub backend_host: String,
    #[serde(default = "default_pg_port")]
    pub backend_port: u16,
    #[serde(default = "default_pg_user")]
    pub backend_user: String,
    pub backend_password: Option<String>,
    pub backend_database: Option<String>,
    /// When true (default), the gateway refuses write/DDL statements — agents
    /// get a read-only database surface.
    #[serde(default = "default_true_bool")]
    pub read_only: bool,
    /// Name of an `[[agent_contracts]]` entry to enforce on every tool call
    /// (scoped grants + repair hints). None = only the `read_only` guardrail.
    #[serde(default)]
    pub contract: Option<String>,
}
423
424impl Default for McpConfig {
425    fn default() -> Self {
426        Self {
427            enabled: false,
428            listen_address: default_mcp_listen(),
429            backend_host: default_localhost(),
430            backend_port: default_pg_port(),
431            backend_user: default_pg_user(),
432            backend_password: None,
433            backend_database: None,
434            read_only: true,
435            contract: None,
436        }
437    }
438}
439
/// Serde default for the MCP JSON-RPC listen address (loopback only).
fn default_mcp_listen() -> String {
    String::from("127.0.0.1:9092")
}
/// Serde default host for the auxiliary backend connections (IPv4 loopback).
fn default_localhost() -> String {
    String::from("127.0.0.1")
}
/// Serde default PostgreSQL port.
fn default_pg_port() -> u16 {
    5432_u16
}
/// Serde default PostgreSQL role name.
fn default_pg_user() -> String {
    String::from("postgres")
}
/// Serde helper that yields `true` (same value as `default_true`).
fn default_true_bool() -> bool {
    true
}
455
/// Client-side authentication configuration.
///
/// The derived `Default` is pass-through mode with no `auth_file`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct AuthConfig {
    /// `passthrough` (default) relays client auth to the backend.
    /// `scram` makes the proxy terminate SCRAM-SHA-256 itself against
    /// `auth_file`, becoming the auth boundary (foundation for pooling).
    #[serde(default)]
    pub mode: AuthMode,
    /// Path to a pgbouncer-style user list (`user:secret`, secret = plaintext
    /// or a `SCRAM-SHA-256$...` verifier). Required when `mode = "scram"`.
    #[serde(default)]
    pub auth_file: Option<String>,
}
469
/// Proxy client-authentication mode.
///
/// Deserialized from lowercase names (`"passthrough"`, `"scram"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum AuthMode {
    /// Relay the client's auth exchange straight to the backend.
    #[default]
    Passthrough,
    /// Terminate SCRAM-SHA-256 at the proxy against `auth_file`.
    Scram,
}
480
/// A single pg_hba-style admission rule. The first rule whose `user`,
/// `database`, and `address` all match the incoming connection decides the
/// outcome (`allow`/`reject`). If no rule matches, the connection is
/// admitted (rules are an explicit deny/allow list, not default-deny — add a
/// trailing `{ action = "reject", user = "all", database = "all", address =
/// "all" }` for default-deny).
///
/// `action` is required; `user`, `database` and `address` default to "all".
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HbaRule {
    /// "allow" or "reject".
    pub action: HbaAction,
    /// Matching PostgreSQL user, or "all".
    #[serde(default = "hba_all")]
    pub user: String,
    /// Matching database, or "all".
    #[serde(default = "hba_all")]
    pub database: String,
    /// Matching client address: "all", a bare IP, or a CIDR (e.g.
    /// "10.0.0.0/8", "::1/128").
    #[serde(default = "hba_all")]
    pub address: String,
}
502
/// Serde default for the `HbaRule` match fields: the wildcard "all".
fn hba_all() -> String {
    String::from("all")
}
506
/// Admission action for an [`HbaRule`].
///
/// Deserialized from lowercase names (`"allow"`, `"reject"`). There is no
/// default, so every rule must state its action explicitly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum HbaAction {
    Allow,
    Reject,
}
514
/// Serde default for `ProxyConfig::write_timeout_secs`: how long a write
/// waits for a new primary during failover.
fn default_write_timeout_secs() -> u64 {
    30_u64
}
518
519impl Default for ProxyConfig {
520    fn default() -> Self {
521        Self {
522            listen_address: "0.0.0.0:5432".to_string(),
523            admin_address: "0.0.0.0:9090".to_string(),
524            admin_token: None,
525            tr_enabled: true,
526            tr_mode: TrMode::Session,
527            pool: PoolConfig::default(),
528            pool_mode: PoolModeConfig::default(),
529            load_balancer: LoadBalancerConfig::default(),
530            health: HealthConfig::default(),
531            nodes: Vec::new(),
532            tls: None,
533            write_timeout_secs: default_write_timeout_secs(),
534            plugins: PluginToml::default(),
535            hba: Vec::new(),
536            auth: AuthConfig::default(),
537            mcp: McpConfig::default(),
538            agent_contracts: Vec::new(),
539            http_gateway: HttpGatewayConfig::default(),
540            mirror: MirrorConfig::default(),
541            branch: BranchConfig::default(),
542            optimize_unnamed_parse: true,
543            shutdown_drain_timeout_secs: default_drain_timeout_secs(),
544        }
545    }
546}
547
548// =============================================================================
549// PLUGIN SYSTEM CONFIG (TOML-friendly shape)
550// =============================================================================
551
/// Plugin-system configuration, in a TOML-friendly shape.
///
/// Always present on `ProxyConfig` so existing configs round-trip, but only
/// consumed when the `wasm-plugins` feature is enabled. When
/// `plugins.enabled` is `false` (the default), plugin loading is skipped
/// entirely and every plugin-hook call site becomes a zero-cost no-op.
///
/// Converted to `crate::plugins::PluginRuntimeConfig` at startup via a
/// feature-gated `From` impl in `src/plugins/config.rs`.
///
/// Every field has a serde default, noted on each field below.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PluginToml {
    /// Enable the plugin subsystem. Defaults to `false` — plugins are
    /// strictly opt-in.
    #[serde(default)]
    pub enabled: bool,
    /// Directory to scan at startup for `.wasm` plugin files
    /// (default `/etc/heliosproxy/plugins`).
    #[serde(default = "default_plugin_dir")]
    pub plugin_dir: String,
    /// Watch `plugin_dir` for file changes and reload plugins hot
    /// (default off).
    #[serde(default)]
    pub hot_reload: bool,
    /// Memory limit per plugin instance, in megabytes (default 64).
    #[serde(default = "default_plugin_memory_mb")]
    pub memory_limit_mb: usize,
    /// Execution timeout per hook call, in milliseconds (default 100).
    #[serde(default = "default_plugin_timeout_ms")]
    pub timeout_ms: u64,
    /// Maximum number of concurrently-loaded plugins (default 20).
    #[serde(default = "default_plugin_max")]
    pub max_plugins: usize,
    /// Enable per-call CPU-cycle (fuel) metering to bound plugin runtime
    /// (default on).
    #[serde(default = "default_true")]
    pub fuel_metering: bool,
    /// Fuel units allowed per hook call when `fuel_metering = true`
    /// (default 1,000,000).
    #[serde(default = "default_plugin_fuel")]
    pub fuel_limit: u64,
    /// Optional Ed25519 trust-root directory. When set, every loaded
    /// .wasm requires a sidecar .sig that verifies against one of
    /// the *.pub files in this directory. When omitted, signatures
    /// are not checked (preserves the dev-loop ergonomic of dropping
    /// unsigned .wasm files in the plugin dir).
    #[serde(default)]
    pub trust_root: Option<String>,
}
596
/// Serde default for `PluginToml::plugin_dir`.
fn default_plugin_dir() -> String {
    String::from("/etc/heliosproxy/plugins")
}
/// Serde default for `PluginToml::memory_limit_mb`.
fn default_plugin_memory_mb() -> usize {
    64_usize
}
/// Serde default for `PluginToml::timeout_ms`.
fn default_plugin_timeout_ms() -> u64 {
    100_u64
}
/// Serde default for `PluginToml::max_plugins`.
fn default_plugin_max() -> usize {
    20_usize
}
/// Serde helper that yields `true`.
fn default_true() -> bool {
    true
}
/// Serde default for `PluginToml::fuel_limit`: one million fuel units.
fn default_plugin_fuel() -> u64 {
    1000 * 1000
}
615
616impl Default for PluginToml {
617    fn default() -> Self {
618        Self {
619            enabled: false,
620            plugin_dir: default_plugin_dir(),
621            hot_reload: false,
622            memory_limit_mb: default_plugin_memory_mb(),
623            timeout_ms: default_plugin_timeout_ms(),
624            max_plugins: default_plugin_max(),
625            fuel_metering: true,
626            fuel_limit: default_plugin_fuel(),
627            trust_root: None,
628        }
629    }
630}
631
632impl ProxyConfig {
633    /// Get write timeout as Duration
634    pub fn write_timeout(&self) -> Duration {
635        Duration::from_secs(self.write_timeout_secs)
636    }
637
638    /// Load configuration from file
639    pub fn from_file(path: &str) -> Result<Self> {
640        let path = Path::new(path);
641
642        if !path.exists() {
643            return Err(ProxyError::Config(format!(
644                "Configuration file not found: {}",
645                path.display()
646            )));
647        }
648
649        let contents = std::fs::read_to_string(path)
650            .map_err(|e| ProxyError::Config(format!("Failed to read config: {}", e)))?;
651
652        let config: Self = toml::from_str(&contents)
653            .map_err(|e| ProxyError::Config(format!("Failed to parse config: {}", e)))?;
654
655        config.validate()?;
656
657        Ok(config)
658    }
659
660    /// Add a node from host:port string
661    pub fn add_node(&mut self, host_port: &str, role: &str) -> Result<()> {
662        let parts: Vec<&str> = host_port.rsplitn(2, ':').collect();
663        if parts.len() != 2 {
664            return Err(ProxyError::Config(format!(
665                "Invalid host:port format: {}",
666                host_port
667            )));
668        }
669
670        let port: u16 = parts[0].parse()
671            .map_err(|_| ProxyError::Config(format!("Invalid port: {}", parts[0])))?;
672
673        let host = parts[1].to_string();
674
675        let role = match role {
676            "primary" => NodeRole::Primary,
677            "standby" => NodeRole::Standby,
678            "replica" => NodeRole::ReadReplica,
679            _ => return Err(ProxyError::Config(format!("Unknown role: {}", role))),
680        };
681
682        self.nodes.push(NodeConfig {
683            host,
684            port,
685            http_port: default_http_port(),
686            role,
687            weight: 100,
688            enabled: true,
689            name: None,
690        });
691
692        Ok(())
693    }
694
695    /// Validate configuration
696    pub fn validate(&self) -> Result<()> {
697        // Must have at least one node
698        if self.nodes.is_empty() {
699            return Err(ProxyError::Config("No backend nodes configured".to_string()));
700        }
701
702        // Must have a primary node
703        let has_primary = self.nodes.iter().any(|n| n.role == NodeRole::Primary);
704        if !has_primary {
705            return Err(ProxyError::Config("No primary node configured".to_string()));
706        }
707
708        // Validate pool config
709        if self.pool.max_connections < self.pool.min_connections {
710            return Err(ProxyError::Config(
711                "max_connections must be >= min_connections".to_string(),
712            ));
713        }
714
715        Ok(())
716    }
717
718    /// Get primary node
719    pub fn primary_node(&self) -> Option<&NodeConfig> {
720        self.nodes.iter().find(|n| n.role == NodeRole::Primary && n.enabled)
721    }
722
723    /// Get standby nodes
724    pub fn standby_nodes(&self) -> Vec<&NodeConfig> {
725        self.nodes.iter()
726            .filter(|n| n.role == NodeRole::Standby && n.enabled)
727            .collect()
728    }
729
730    /// Get all enabled nodes
731    pub fn enabled_nodes(&self) -> Vec<&NodeConfig> {
732        self.nodes.iter().filter(|n| n.enabled).collect()
733    }
734}
735
736/// TR (Transaction Replay) mode
737#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
738#[serde(rename_all = "lowercase")]
739pub enum TrMode {
740    /// No transaction replay
741    None,
742    /// Re-establish session only
743    Session,
744    /// Re-execute SELECT queries
745    Select,
746    /// Full transaction replay
747    Transaction,
748}
749
750impl Default for TrMode {
751    fn default() -> Self {
752        TrMode::Session
753    }
754}
755
/// Connection pool configuration
///
/// No field has a serde default, so a `[pool]` table must list every field.
/// `ProxyConfig::validate` requires `max_connections >= min_connections`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PoolConfig {
    /// Minimum connections per node
    pub min_connections: usize,
    /// Maximum connections per node
    pub max_connections: usize,
    /// Connection idle timeout (seconds)
    pub idle_timeout_secs: u64,
    /// Maximum connection lifetime (seconds)
    pub max_lifetime_secs: u64,
    /// Connection acquire timeout (seconds)
    pub acquire_timeout_secs: u64,
    /// Test connection before use
    pub test_on_acquire: bool,
}
772
773impl Default for PoolConfig {
774    fn default() -> Self {
775        Self {
776            min_connections: 2,
777            max_connections: 100,
778            idle_timeout_secs: 300,
779            max_lifetime_secs: 1800,
780            acquire_timeout_secs: 30,
781            test_on_acquire: true,
782        }
783    }
784}
785
786impl PoolConfig {
787    /// Get idle timeout as Duration
788    pub fn idle_timeout(&self) -> Duration {
789        Duration::from_secs(self.idle_timeout_secs)
790    }
791
792    /// Get max lifetime as Duration
793    pub fn max_lifetime(&self) -> Duration {
794        Duration::from_secs(self.max_lifetime_secs)
795    }
796
797    /// Get acquire timeout as Duration
798    pub fn acquire_timeout(&self) -> Duration {
799        Duration::from_secs(self.acquire_timeout_secs)
800    }
801}
802
/// Load balancer configuration
///
/// No field has a serde default, so a `[load_balancer]` table must list
/// every field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoadBalancerConfig {
    /// Routing strategy for read queries
    pub read_strategy: Strategy,
    /// Enable read/write splitting
    pub read_write_split: bool,
    /// Latency threshold for unhealthy marking (ms)
    pub latency_threshold_ms: u64,
}
813
814impl Default for LoadBalancerConfig {
815    fn default() -> Self {
816        Self {
817            read_strategy: Strategy::RoundRobin,
818            read_write_split: true,
819            latency_threshold_ms: 100,
820        }
821    }
822}
823
/// Load balancing strategy
///
/// Deserialized from snake_case names: `"round_robin"`,
/// `"weighted_round_robin"`, `"least_connections"`, `"latency_based"`,
/// `"random"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Strategy {
    /// Round-robin across nodes
    RoundRobin,
    /// Weighted round-robin
    WeightedRoundRobin,
    /// Route to least loaded node
    LeastConnections,
    /// Route to lowest latency node
    LatencyBased,
    /// Random selection
    Random,
}
839
/// Health check configuration
///
/// No field has a serde default, so a `[health]` table must list every field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HealthConfig {
    /// Check interval (seconds)
    pub check_interval_secs: u64,
    /// Check timeout (seconds)
    pub check_timeout_secs: u64,
    /// Failures before marking unhealthy
    pub failure_threshold: u32,
    /// Successes before marking healthy
    pub success_threshold: u32,
    /// Health check query
    pub check_query: String,
}
854
855impl Default for HealthConfig {
856    fn default() -> Self {
857        Self {
858            check_interval_secs: 5,
859            check_timeout_secs: 3,
860            failure_threshold: 3,
861            success_threshold: 2,
862            check_query: "SELECT 1".to_string(),
863        }
864    }
865}
866
867impl HealthConfig {
868    /// Get check interval as Duration
869    pub fn check_interval(&self) -> Duration {
870        Duration::from_secs(self.check_interval_secs)
871    }
872
873    /// Get check timeout as Duration
874    pub fn check_timeout(&self) -> Duration {
875        Duration::from_secs(self.check_timeout_secs)
876    }
877}
878
/// Backend node configuration
///
/// Only `http_port` has a serde default; `name` is an `Option` and may be
/// omitted. All other fields must be given for each `[[nodes]]` entry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeConfig {
    /// Node host
    pub host: String,
    /// Node port (PostgreSQL protocol)
    pub port: u16,
    /// Node HTTP API port (for SQL API forwarding)
    /// Defaults to 8080 if not specified
    #[serde(default = "default_http_port")]
    pub http_port: u16,
    /// Node role
    pub role: NodeRole,
    /// Weight for load balancing
    pub weight: u32,
    /// Whether node is enabled
    pub enabled: bool,
    /// Optional node name for logging
    pub name: Option<String>,
}
899
/// Serde default for `NodeConfig::http_port`.
fn default_http_port() -> u16 {
    8080_u16
}
903
904impl NodeConfig {
905    /// Get address string
906    pub fn address(&self) -> String {
907        format!("{}:{}", self.host, self.port)
908    }
909
910    /// Get display name
911    pub fn display_name(&self) -> &str {
912        self.name.as_deref().unwrap_or(&self.host)
913    }
914}
915
916/// Node role
917#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
918#[serde(rename_all = "lowercase")]
919pub enum NodeRole {
920    /// Primary node (accepts writes)
921    Primary,
922    /// Standby node (can be promoted)
923    Standby,
924    /// Read replica (read-only, cannot be promoted)
925    #[serde(rename = "replica")]
926    ReadReplica,
927}
928
/// TLS configuration
///
/// No serde defaults are declared on this struct. Every field except
/// `ca_path` (an `Option`) must be present when deserializing, including the
/// certificate and key paths when `enabled` is `false`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TlsConfig {
    /// Enable TLS for client connections
    pub enabled: bool,
    /// Path to certificate file
    // NOTE(review): the expected encoding (PEM vs DER) is decided by whatever
    // loads this path, not here — confirm.
    pub cert_path: String,
    /// Path to private key file
    pub key_path: String,
    /// Path to CA certificate (for client verification)
    pub ca_path: Option<String>,
    /// Require client certificates
    // NOTE(review): nothing in this struct ties this flag to `ca_path` being
    // set; confirm that validation happens elsewhere.
    pub require_client_cert: bool,
}
943
#[cfg(test)]
mod tests {
    use super::*;

    /// Every section a config file written before the `[plugins]` field
    /// existed contains. Shared by the deserialization tests below.
    const LEGACY_CONFIG_TOML: &str = r#"
        listen_address = "0.0.0.0:5432"
        admin_address = "0.0.0.0:9090"
        tr_enabled = true
        tr_mode = "session"
        nodes = []

        [pool]
        min_connections = 2
        max_connections = 10
        idle_timeout_secs = 300
        max_lifetime_secs = 1800
        acquire_timeout_secs = 30
        test_on_acquire = true

        [load_balancer]
        read_strategy = "round_robin"
        read_write_split = true
        latency_threshold_ms = 100

        [health]
        check_interval_secs = 5
        check_timeout_secs = 3
        failure_threshold = 3
        success_threshold = 2
        check_query = "SELECT 1"
    "#;

    /// A `[plugins]` section overriding a subset of the plugin settings;
    /// appended to `LEGACY_CONFIG_TOML`.
    const PLUGIN_OVERRIDES_TOML: &str = r#"
        [plugins]
        enabled = true
        plugin_dir = "/tmp/helios-plugins"
        hot_reload = true
        memory_limit_mb = 128
        timeout_ms = 250
    "#;

    /// Default config with each `(address, role)` pair registered via
    /// `add_node`, in order.
    fn config_with_nodes(nodes: &[(&str, &str)]) -> ProxyConfig {
        let mut config = ProxyConfig::default();
        for &(address, role) in nodes {
            config.add_node(address, role).unwrap();
        }
        config
    }

    #[test]
    fn test_default_config() {
        let config = ProxyConfig::default();
        assert_eq!(config.listen_address, "0.0.0.0:5432");
        assert!(config.tr_enabled);
    }

    #[test]
    fn test_add_node() {
        let config = config_with_nodes(&[
            ("localhost:5432", "primary"),
            ("localhost:5433", "standby"),
        ]);

        assert_eq!(config.nodes.len(), 2);
        assert!(config.primary_node().is_some());
        assert_eq!(config.standby_nodes().len(), 1);
    }

    #[test]
    fn test_validate_no_nodes() {
        assert!(ProxyConfig::default().validate().is_err());
    }

    #[test]
    fn test_validate_no_primary() {
        let config = config_with_nodes(&[("localhost:5432", "standby")]);
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_validate_success() {
        let config = config_with_nodes(&[("localhost:5432", "primary")]);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_pool_config_durations() {
        let pool = PoolConfig::default();
        assert_eq!(pool.idle_timeout(), Duration::from_secs(300));
        assert_eq!(pool.max_lifetime(), Duration::from_secs(1800));
    }

    #[test]
    fn test_pool_mode_default() {
        let pool_mode = PoolModeConfig::default();
        assert_eq!(pool_mode.mode, PoolingMode::Session);
        assert_eq!(pool_mode.max_pool_size, 100);
        assert_eq!(pool_mode.min_idle, 10);
        assert_eq!(pool_mode.reset_query, "DISCARD ALL");
    }

    #[test]
    fn test_pool_mode_session() {
        let pool_mode = PoolModeConfig::session_mode();
        assert_eq!(pool_mode.mode, PoolingMode::Session);
        assert_eq!(pool_mode.prepared_statement_mode, PreparedStatementMode::Named);
    }

    #[test]
    fn test_pool_mode_transaction() {
        let pool_mode = PoolModeConfig::transaction_mode();
        assert_eq!(pool_mode.mode, PoolingMode::Transaction);
        assert_eq!(pool_mode.prepared_statement_mode, PreparedStatementMode::Track);
    }

    #[test]
    fn test_pool_mode_statement() {
        let pool_mode = PoolModeConfig::statement_mode();
        assert_eq!(pool_mode.mode, PoolingMode::Statement);
        assert_eq!(pool_mode.prepared_statement_mode, PreparedStatementMode::Disable);
    }

    #[test]
    fn test_pool_mode_durations() {
        let pool_mode = PoolModeConfig::default();
        assert_eq!(pool_mode.idle_timeout(), Duration::from_secs(600));
        assert_eq!(pool_mode.max_lifetime(), Duration::from_secs(3600));
        assert_eq!(pool_mode.acquire_timeout(), Duration::from_secs(5));
    }

    #[test]
    fn test_proxy_config_has_pool_mode() {
        assert_eq!(ProxyConfig::default().pool_mode.mode, PoolingMode::Session);
    }

    /// `plugins` defaults to `enabled = false` so adding the field to
    /// `ProxyConfig` doesn't spontaneously turn on the plugin subsystem
    /// for existing deployments.
    #[test]
    fn test_plugin_toml_default_is_disabled() {
        let plugins = ProxyConfig::default().plugins;
        assert!(!plugins.enabled);
        assert_eq!(plugins.plugin_dir, "/etc/heliosproxy/plugins");
        assert_eq!(plugins.memory_limit_mb, 64);
        assert_eq!(plugins.timeout_ms, 100);
    }

    /// Existing TOML configs (written before this field existed) must
    /// round-trip through `Deserialize` without failing. The `plugins`
    /// section is `#[serde(default)]`, so omitting it yields the default.
    #[test]
    fn test_proxy_config_toml_without_plugins_section_still_parses() {
        let config: ProxyConfig = toml::from_str(LEGACY_CONFIG_TOML).expect("parse");
        assert!(!config.plugins.enabled);
    }

    /// A `[plugins]` section with overrides round-trips and populates the
    /// struct correctly.
    #[test]
    fn test_plugin_toml_overrides_parse() {
        let toml_text = [LEGACY_CONFIG_TOML, PLUGIN_OVERRIDES_TOML].concat();
        let config: ProxyConfig = toml::from_str(&toml_text).expect("parse");
        let plugins = &config.plugins;

        assert!(plugins.enabled);
        assert_eq!(plugins.plugin_dir, "/tmp/helios-plugins");
        assert!(plugins.hot_reload);
        assert_eq!(plugins.memory_limit_mb, 128);
        assert_eq!(plugins.timeout_ms, 250);
        // Un-specified fields retain their defaults.
        assert_eq!(plugins.max_plugins, 20);
        assert!(plugins.fuel_metering);
    }
}