Skip to main content

heliosdb_proxy/
server.rs

1//! Proxy Server Implementation
2//!
3//! Main server that accepts client connections and routes them to backends.
4//! Implements PostgreSQL wire protocol forwarding with TWR (Transparent Write Routing).
5
6use crate::admin::{AdminServer, AdminState, ConfigSnapshot, NodeSnapshot};
7#[cfg(feature = "ha-tr")]
8use crate::backend::{tls::default_client_config, BackendConfig, TlsMode};
9use crate::config::{NodeConfig, NodeRole, ProxyConfig, TrMode};
10use crate::protocol::{
11    ErrorResponse, Message, MessageType, ParseMessage, ProtocolCodec, QueryMessage,
12    StartupMessage, TransactionStatus,
13};
14use crate::{ProxyError, Result};
15use bytes::{BufMut, BytesMut};
16use std::collections::HashMap;
17use std::net::SocketAddr;
18use std::sync::atomic::{AtomicU64, Ordering};
19use std::sync::Arc;
20use std::time::Duration;
21use tokio::io::{AsyncReadExt, AsyncWriteExt};
22use tokio::net::{TcpListener, TcpStream};
23use tokio::sync::{broadcast, RwLock};
24use uuid::Uuid;
25
26// Pool-modes feature imports
27#[cfg(feature = "pool-modes")]
28use crate::pool::{
29    ConnectionPoolManager, PoolModeConfig, PoolingMode,
30};
31#[cfg(feature = "pool-modes")]
32use crate::pool::lease::ClientId;
33#[cfg(feature = "pool-modes")]
34use crate::NodeEndpoint;
35
36// WASM plugin system imports
37#[cfg(feature = "wasm-plugins")]
38use crate::plugins::{
39    AuthRequest as PluginAuthRequest, AuthResult, HookContext, Identity, PluginManager,
40    PostQueryOutcome, PreQueryResult, QueryContext, RouteResult,
41};
42
43/// Proxy server
44pub struct ProxyServer {
45    config: ProxyConfig,
46    state: Arc<ServerState>,
47    shutdown_tx: broadcast::Sender<()>,
48}
49
/// Template `BackendConfig` for the time-travel replay engine's target
/// connection.
///
/// Only the endpoint is a placeholder: the replay handler substitutes
/// `target_host` / `target_port` on every request, and everything else
/// (auth, TLS policy, timeouts) is taken from this template as-is.
///
/// Auth defaults to the bare PostgreSQL `postgres` superuser with no
/// password. That suits local development against `trust` auth and must
/// never be relied on in production. Per-call credential overrides on
/// `ReplayRequestBody` land in FU-21.
///
/// `_config` is currently unused. It stays in the signature so shared TLS /
/// timeout settings can later be drawn from the proxy config without
/// changing the call site.
#[cfg(feature = "ha-tr")]
fn build_replay_backend_template(_config: &ProxyConfig) -> BackendConfig {
    // Placeholder endpoint; overwritten per replay request.
    let host = String::from("placeholder");
    let port = 0;

    BackendConfig {
        host,
        port,
        user: String::from("postgres"),
        password: None,
        database: None,
        application_name: Some(String::from("heliosdb-proxy-replay")),
        tls_mode: TlsMode::Disable,
        connect_timeout: Duration::from_secs(5),
        query_timeout: Duration::from_secs(30),
        tls_config: default_client_config(),
    }
}
78
/// Cheap query-shape fingerprint for the anomaly detector.
///
/// Replaces numeric and single-quoted string literals with `?`
/// placeholders, lower-cases ASCII letters (keywords and identifiers
/// alike), and collapses each run of whitespace to a single space, dropping
/// leading and trailing whitespace. The result has the same shape
/// regardless of literal values:
/// - `SELECT * FROM users WHERE id = 1` and
///   `SELECT * FROM users WHERE id = 99` map to the same fingerprint.
/// - `name = 'Smith'` and `name = 'O''Brien'` also share a fingerprint.
///
/// Digits that continue an identifier (`col1`, `table_2`) are kept, so
/// distinct tables and columns are not folded together. A numeric literal
/// is consumed together with any trailing alphanumerics, `_` or `.`
/// (`3.5`, `1e10`), so no part of its value leaks into the shape.
///
/// Not a parser: comments, double-quoted identifiers and `E'...'`
/// backslash escapes are not special-cased. The analytics module has the
/// canonical normaliser when query-analytics is on; this is a lightweight
/// standalone so the anomaly detector works even when analytics is off.
#[cfg(feature = "anomaly-detection")]
fn anomaly_fingerprint(sql: &str) -> String {
    let mut out = String::with_capacity(sql.len());
    let mut prev_space = false;
    let mut chars = sql.chars().peekable();
    while let Some(c) = chars.next() {
        if c == '\'' {
            // Replace the whole string literal (open quote, body, close
            // quote) with a single `?`. A doubled quote (`''`) is SQL's
            // escape for an embedded quote and does not end the literal.
            // An unterminated literal swallows the rest of the input.
            out.push('?');
            while let Some(n) = chars.next() {
                if n == '\'' {
                    if chars.peek() == Some(&'\'') {
                        chars.next();
                    } else {
                        break;
                    }
                }
            }
            prev_space = false;
            continue;
        }
        // A digit glued to a preceding identifier character belongs to that
        // identifier (`col1`), not to a numeric literal.
        let starts_number = c.is_ascii_digit()
            && !matches!(out.chars().next_back(), Some(p) if p.is_alphanumeric() || p == '_');
        if starts_number {
            if !out.ends_with('?') {
                out.push('?');
            }
            // Skip the rest of the literal: digits, decimal point and any
            // exponent/suffix characters.
            while matches!(
                chars.peek(),
                Some(n) if n.is_ascii_alphanumeric() || *n == '.' || *n == '_'
            ) {
                chars.next();
            }
            prev_space = false;
            continue;
        }
        if c.is_ascii_whitespace() {
            if !prev_space && !out.is_empty() {
                out.push(' ');
                prev_space = true;
            }
            continue;
        }
        out.push(c.to_ascii_lowercase());
        prev_space = false;
    }
    out.trim_end().to_string()
}
135
136/// Server runtime state
137struct ServerState {
138    /// Active client sessions
139    sessions: RwLock<HashMap<Uuid, Arc<ClientSession>>>,
140    /// Node health status
141    health: RwLock<HashMap<String, NodeHealth>>,
142    /// Metrics
143    metrics: ServerMetrics,
144    /// Load balancer state
145    lb_state: RwLock<LoadBalancerState>,
146    /// Pool manager for Session/Transaction/Statement modes
147    #[cfg(feature = "pool-modes")]
148    pool_manager: Option<Arc<ConnectionPoolManager>>,
149    /// WASM plugin manager. `None` means no plugins loaded — the per-query
150    /// hook path becomes a fast no-op. When `Some`, `PreQuery` / `PostQuery`
151    /// hooks fire on every simple-query message.
152    #[cfg(feature = "wasm-plugins")]
153    plugin_manager: Option<Arc<PluginManager>>,
154    /// Shared transaction journal — single sink for per-session
155    /// statement journaling. The replay engine reads windows from
156    /// this directly. Always present when the `ha-tr` feature is on;
157    /// journaling self-disables internally when not configured.
158    #[cfg(feature = "ha-tr")]
159    transaction_journal: Arc<crate::transaction_journal::TransactionJournal>,
160    /// Anomaly detector (T3.1). Records every query and every
161    /// auth outcome; surfaces detections via /api/anomalies.
162    #[cfg(feature = "anomaly-detection")]
163    anomaly_detector: Arc<crate::anomaly::AnomalyDetector>,
164    /// Edge cache + home registry (T3.2). Both always-present even
165    /// in Home mode (the cache is a no-op there); avoids an extra
166    /// Option in the hot path.
167    #[cfg(feature = "edge-proxy")]
168    edge_cache: Arc<crate::edge::EdgeCache>,
169    #[cfg(feature = "edge-proxy")]
170    edge_registry: Arc<crate::edge::EdgeRegistry>,
171}
172
173/// Node health status
174#[derive(Debug, Clone)]
175pub struct NodeHealth {
176    /// Node address
177    pub address: String,
178    /// Whether node is healthy
179    pub healthy: bool,
180    /// Last check time
181    pub last_check: chrono::DateTime<chrono::Utc>,
182    /// Consecutive failures
183    pub failure_count: u32,
184    /// Last error message
185    pub last_error: Option<String>,
186    /// Average latency (ms)
187    pub latency_ms: f64,
188    /// Replication lag (if applicable)
189    pub replication_lag_bytes: Option<u64>,
190}
191
/// Process-wide traffic counters.
///
/// Each field is an independent `AtomicU64`, so the hot paths can bump it
/// without taking a lock. Every counter starts at zero (`Default`).
#[derive(Default)]
struct ServerMetrics {
    /// Client connections accepted by the listener.
    connections_accepted: AtomicU64,
    /// Client connections that have been torn down.
    connections_closed: AtomicU64,
    /// Queries processed, including queries a plugin blocked.
    queries_processed: AtomicU64,
    /// Bytes read from clients.
    bytes_received: AtomicU64,
    /// Bytes written to clients.
    bytes_sent: AtomicU64,
    /// Backend failovers performed.
    failovers: AtomicU64,
}
208
/// Mutable state used by the load balancer.
struct LoadBalancerState {
    /// Round-robin cursor; initialised to 0 when the server is built.
    rr_counter: u64,
}
214
215/// Client session
216pub struct ClientSession {
217    /// Session ID
218    pub id: Uuid,
219    /// Client address
220    pub client_addr: SocketAddr,
221    /// Current backend node
222    pub current_node: RwLock<Option<String>>,
223    /// Transaction state
224    pub tx_state: RwLock<TransactionState>,
225    /// Session variables
226    pub variables: RwLock<HashMap<String, String>>,
227    /// Created at
228    pub created_at: chrono::DateTime<chrono::Utc>,
229    /// TR mode for this session
230    pub tr_mode: TrMode,
231    /// Client ID for pool-modes lease tracking
232    #[cfg(feature = "pool-modes")]
233    pub pool_client_id: ClientId,
234    /// Identity returned by an `Authenticate` plugin, if any. Downstream
235    /// plugins (masking, residency routing, cost governor) read this to
236    /// gate per-user policy. `None` when no plugin ran or every plugin
237    /// deferred to the default auth flow.
238    #[cfg(feature = "wasm-plugins")]
239    pub plugin_identity: RwLock<Option<Identity>>,
240}
241
242/// Transaction state
243#[derive(Debug, Clone, Default)]
244pub struct TransactionState {
245    /// Whether in a transaction
246    pub in_transaction: bool,
247    /// Transaction ID
248    pub tx_id: Option<Uuid>,
249    /// Statements executed in current transaction
250    pub statements: Vec<StatementLog>,
251    /// Read-only transaction
252    pub read_only: bool,
253    /// Savepoints
254    pub savepoints: Vec<String>,
255}
256
257/// Logged statement for TR replay
258#[derive(Debug, Clone)]
259pub struct StatementLog {
260    /// Statement SQL
261    pub sql: String,
262    /// Parameters
263    pub params: Vec<String>,
264    /// Result checksum
265    pub result_checksum: Option<u64>,
266    /// Execution time
267    pub executed_at: chrono::DateTime<chrono::Utc>,
268}
269
/// Disposition produced by the pre-query plugin hook stage.
///
/// When the `wasm-plugins` feature is off, only `Forward` is ever produced.
/// The hook dispatch is compiled out entirely, and the remaining variants
/// exist purely for pattern-match symmetry.
#[derive(Debug)]
#[allow(dead_code)] // Block/Cached only constructed under wasm-plugins
enum PreQueryAction {
    /// Send the message to the backend as usual.
    Forward,
    /// A plugin blocked the query; carries the reason it supplied. The
    /// caller sends an error + ReadyForQuery to the client and skips
    /// backend forwarding.
    Block(String),
    /// A plugin returned a cached response payload. The caller synthesises
    /// the full PG wire reply (RowDescription + DataRow(s) +
    /// CommandComplete + ReadyForQuery) from these bytes and writes it
    /// straight to the client, so the backend is never touched. If the
    /// payload is malformed, the caller logs it and falls back to normal
    /// forwarding, so a buggy plugin degrades gracefully.
    Cached(Vec<u8>),
}
290
/// Backend-selection override returned by the Route plugin hook and
/// consumed by `route_and_forward`.
///
/// With the `wasm-plugins` feature disabled the hook never runs, so the
/// only value ever produced is `RouteOverride::None`.
#[allow(dead_code)] // Primary/Standby/Node/Block only constructed under wasm-plugins
#[derive(Debug)]
enum RouteOverride {
    /// No override: keep the default routing, which is driven by the SQL
    /// verb.
    None,
    /// Force the write path (`select_primary_with_timeout`).
    Primary,
    /// Force the read path (`select_read_node`).
    Standby,
    /// Pin the query to this exact node address. This wins over the
    /// is-write heuristic, but the node's health is still verified before
    /// connecting, through the normal switch-vs-reuse flow.
    Node(String),
    /// Refuse the query with the plugin-supplied reason. The client
    /// receives a PG ErrorResponse + ReadyForQuery and nothing is
    /// forwarded. Outranks every other variant: the proxy short-circuits
    /// before any backend selection.
    Block(String),
}
315
316impl ProxyServer {
317    /// Build a `PluginManager` from config and preload plugins from disk.
318    ///
319    /// Returns `None` when plugins are disabled in config, when the
320    /// runtime fails to initialise, or when the plugin directory is
321    /// missing. Individual per-file load failures are logged but do not
322    /// abort startup — the remaining plugins load normally and the
323    /// proxy stays up.
324    #[cfg(feature = "wasm-plugins")]
325    fn init_plugin_manager(
326        toml_cfg: &crate::config::PluginToml,
327    ) -> Option<Arc<crate::plugins::PluginManager>> {
328        if !toml_cfg.enabled {
329            return None;
330        }
331
332        let runtime_cfg = crate::plugins::PluginRuntimeConfig::from(toml_cfg);
333        let plugin_dir = runtime_cfg.plugin_dir.clone();
334
335        let pm = match crate::plugins::PluginManager::new(runtime_cfg) {
336            Ok(pm) => Arc::new(pm),
337            Err(e) => {
338                tracing::error!(error = %e, "Failed to create plugin manager; plugins disabled");
339                return None;
340            }
341        };
342
343        match std::fs::read_dir(&plugin_dir) {
344            Ok(entries) => {
345                let mut loaded = 0usize;
346                let mut failed = 0usize;
347                for entry in entries.flatten() {
348                    let path = entry.path();
349                    if path.extension().and_then(|s| s.to_str()) != Some("wasm") {
350                        continue;
351                    }
352                    match pm.load_plugin(&path) {
353                        Ok(()) => loaded += 1,
354                        Err(e) => {
355                            failed += 1;
356                            tracing::warn!(
357                                path = %path.display(),
358                                error = %e,
359                                "Failed to load plugin"
360                            );
361                        }
362                    }
363                }
364                tracing::info!(
365                    dir = %plugin_dir.display(),
366                    loaded = loaded,
367                    failed = failed,
368                    "Plugin loading complete"
369                );
370            }
371            Err(e) => {
372                tracing::warn!(
373                    dir = %plugin_dir.display(),
374                    error = %e,
375                    "Plugin directory not readable; no plugins loaded"
376                );
377            }
378        }
379
380        Some(pm)
381    }
382
383    /// Create a new proxy server
384    pub fn new(config: ProxyConfig) -> Result<Self> {
385        let (shutdown_tx, _) = broadcast::channel(1);
386
387        // Initialize health status
388        let mut health = HashMap::new();
389        for node in &config.nodes {
390            health.insert(
391                node.address(),
392                NodeHealth {
393                    address: node.address(),
394                    healthy: true, // Assume healthy until proven otherwise
395                    last_check: chrono::Utc::now(),
396                    failure_count: 0,
397                    last_error: None,
398                    latency_ms: 0.0,
399                    replication_lag_bytes: None,
400                },
401            );
402        }
403
404        // Initialize pool manager if pool-modes feature is enabled
405        #[cfg(feature = "pool-modes")]
406        let pool_manager = {
407            use crate::pool::PreparedStatementMode as PoolPreparedStatementMode;
408
409            let pool_config = PoolModeConfig {
410                default_mode: match config.pool_mode.mode {
411                    crate::config::PoolingMode::Session => PoolingMode::Session,
412                    crate::config::PoolingMode::Transaction => PoolingMode::Transaction,
413                    crate::config::PoolingMode::Statement => PoolingMode::Statement,
414                },
415                max_pool_size: config.pool_mode.max_pool_size,
416                min_idle: config.pool_mode.min_idle,
417                idle_timeout_secs: config.pool_mode.idle_timeout_secs,
418                max_lifetime_secs: config.pool_mode.max_lifetime_secs,
419                acquire_timeout_secs: config.pool_mode.acquire_timeout_secs,
420                reset_query: config.pool_mode.reset_query.clone(),
421                prepared_statement_mode: match config.pool_mode.prepared_statement_mode {
422                    crate::config::PreparedStatementMode::Disable => {
423                        PoolPreparedStatementMode::Disable
424                    }
425                    crate::config::PreparedStatementMode::Track => {
426                        PoolPreparedStatementMode::Track
427                    }
428                    crate::config::PreparedStatementMode::Named => {
429                        PoolPreparedStatementMode::Named
430                    }
431                },
432                test_on_acquire: config.pool.test_on_acquire,
433                validation_query: "SELECT 1".to_string(),
434                queue_timeout_secs: 30,
435                max_queue_size: 0,
436            };
437            Some(Arc::new(ConnectionPoolManager::new(pool_config)))
438        };
439
440        // Initialize plugin manager if the wasm-plugins feature is enabled
441        // AND plugins are turned on in config. Scans plugin_dir for `.wasm`
442        // files and loads each; a missing directory is non-fatal and logs
443        // a warning so empty deployments don't fail startup.
444        #[cfg(feature = "wasm-plugins")]
445        let plugin_manager = Self::init_plugin_manager(&config.plugins);
446
447        let state = Arc::new(ServerState {
448            sessions: RwLock::new(HashMap::new()),
449            health: RwLock::new(health),
450            metrics: ServerMetrics::default(),
451            lb_state: RwLock::new(LoadBalancerState {
452                rr_counter: 0,
453            }),
454            #[cfg(feature = "pool-modes")]
455            pool_manager,
456            #[cfg(feature = "wasm-plugins")]
457            plugin_manager,
458            #[cfg(feature = "ha-tr")]
459            transaction_journal: Arc::new(
460                crate::transaction_journal::TransactionJournal::new(),
461            ),
462            #[cfg(feature = "anomaly-detection")]
463            anomaly_detector: Arc::new(
464                crate::anomaly::AnomalyDetector::new(
465                    crate::anomaly::AnomalyConfig::default(),
466                ),
467            ),
468            #[cfg(feature = "edge-proxy")]
469            edge_cache: Arc::new(crate::edge::EdgeCache::new(10_000)),
470            #[cfg(feature = "edge-proxy")]
471            edge_registry: Arc::new(crate::edge::EdgeRegistry::new(
472                32,
473                std::time::Duration::from_secs(120),
474            )),
475        });
476
477        Ok(Self {
478            config,
479            state,
480            shutdown_tx,
481        })
482    }
483
484    /// Run the proxy server
485    pub async fn run(&self) -> Result<()> {
486        let listener = TcpListener::bind(&self.config.listen_address)
487            .await
488            .map_err(|e| ProxyError::Network(format!("Failed to bind: {}", e)))?;
489
490        tracing::info!("Proxy listening on {}", self.config.listen_address);
491
492        // Start background tasks
493        let health_task = self.spawn_health_checker();
494        let pool_task = self.spawn_pool_manager();
495
496        // Start admin server
497        let admin_task = self.spawn_admin_server();
498
499        let mut shutdown_rx = self.shutdown_tx.subscribe();
500
501        loop {
502            tokio::select! {
503                accept_result = listener.accept() => {
504                    match accept_result {
505                        Ok((stream, addr)) => {
506                            self.state.metrics.connections_accepted.fetch_add(1, Ordering::Relaxed);
507                            let state = self.state.clone();
508                            let config = self.config.clone();
509                            let shutdown_tx = self.shutdown_tx.clone();
510
511                            tokio::spawn(async move {
512                                if let Err(e) = Self::handle_client(stream, addr, state, config, shutdown_tx).await {
513                                    tracing::error!("Client handler error: {}", e);
514                                }
515                            });
516                        }
517                        Err(e) => {
518                            tracing::error!("Accept error: {}", e);
519                        }
520                    }
521                }
522                _ = shutdown_rx.recv() => {
523                    tracing::info!("Shutdown signal received");
524                    break;
525                }
526            }
527        }
528
529        // Wait for background tasks
530        health_task.abort();
531        pool_task.abort();
532        admin_task.abort();
533
534        Ok(())
535    }
536
537    /// Spawn admin API server
538    fn spawn_admin_server(&self) -> tokio::task::JoinHandle<()> {
539        let config = self.config.clone();
540        let state = self.state.clone();
541        let mut shutdown_rx = self.shutdown_tx.subscribe();
542
543        tokio::spawn(async move {
544            // Create admin state
545            let admin_state = Arc::new(AdminState::new());
546
547            // Initialize config snapshot
548            {
549                let mut snapshot = admin_state.config_snapshot.write().await;
550                *snapshot = ConfigSnapshot {
551                    listen_address: config.listen_address.clone(),
552                    admin_address: config.admin_address.clone(),
553                    tr_enabled: config.tr_enabled,
554                    tr_mode: format!("{:?}", config.tr_mode),
555                    pool_min_connections: config.pool.min_connections,
556                    pool_max_connections: config.pool.max_connections,
557                    nodes: config.nodes.iter().map(|n| NodeSnapshot {
558                        address: n.address(),
559                        role: format!("{:?}", n.role),
560                        weight: n.weight,
561                        enabled: n.enabled,
562                    }).collect(),
563                };
564            }
565
566            // Set proxy config for SQL routing
567            admin_state.set_proxy_config(config.clone()).await;
568
569            // Attach the plugin manager so /plugins + the admin UI
570            // surface real loaded modules. Cheap Arc-clone — no
571            // duplicate state, both AdminState and ServerState hold
572            // the same manager.
573            #[cfg(feature = "wasm-plugins")]
574            if let Some(ref pm) = state.plugin_manager {
575                admin_state.with_plugin_manager(pm.clone()).await;
576            }
577
578            // Attach the time-travel replay engine. The engine reads
579            // windows from the shared TransactionJournal and replays
580            // statements against a target backend supplied per-request.
581            // Per-call credential overrides land via FU-21's
582            // ReplayRequestBody.target_user / target_password /
583            // target_database fields.
584            #[cfg(feature = "ha-tr")]
585            {
586                let template = build_replay_backend_template(&config);
587                let engine = Arc::new(crate::replay::ReplayEngine::new(
588                    state.transaction_journal.clone(),
589                    template,
590                ));
591                admin_state.with_replay_engine(engine).await;
592            }
593
594            // Attach the anomaly detector — same Arc the server
595            // populates from the query path. /api/anomalies polls
596            // this for surfaced detections.
597            #[cfg(feature = "anomaly-detection")]
598            admin_state
599                .with_anomaly_detector(state.anomaly_detector.clone())
600                .await;
601
602            // Attach the edge cache + registry. Both surfaced via
603            // /api/edge/* admin routes.
604            #[cfg(feature = "edge-proxy")]
605            admin_state
606                .with_edge(state.edge_cache.clone(), state.edge_registry.clone())
607                .await;
608
609            // Create admin server
610            let admin_server = AdminServer::new(config.admin_address.clone(), admin_state.clone());
611
612            // Spawn state sync task
613            let admin_state_sync = admin_state.clone();
614            let server_state = state.clone();
615            let sync_task = tokio::spawn(async move {
616                let mut interval = tokio::time::interval(std::time::Duration::from_secs(1));
617                loop {
618                    interval.tick().await;
619
620                    // Sync health status
621                    {
622                        let health = server_state.health.read().await;
623                        let mut admin_health = admin_state_sync.node_health.write().await;
624                        *admin_health = health.clone();
625                    }
626
627                    // Sync metrics
628                    {
629                        let metrics = ServerMetricsSnapshot {
630                            connections_accepted: server_state.metrics.connections_accepted.load(Ordering::Relaxed),
631                            connections_closed: server_state.metrics.connections_closed.load(Ordering::Relaxed),
632                            queries_processed: server_state.metrics.queries_processed.load(Ordering::Relaxed),
633                            bytes_received: server_state.metrics.bytes_received.load(Ordering::Relaxed),
634                            bytes_sent: server_state.metrics.bytes_sent.load(Ordering::Relaxed),
635                            failovers: server_state.metrics.failovers.load(Ordering::Relaxed),
636                        };
637                        let mut admin_metrics = admin_state_sync.metrics.write().await;
638                        *admin_metrics = metrics;
639                    }
640
641                    // Sync session count
642                    {
643                        let sessions = server_state.sessions.read().await;
644                        let mut admin_sessions = admin_state_sync.active_sessions.write().await;
645                        *admin_sessions = sessions.len() as u64;
646                    }
647                }
648            });
649
650            // Run admin server
651            tokio::select! {
652                result = admin_server.run() => {
653                    if let Err(e) = result {
654                        tracing::error!("Admin server error: {}", e);
655                    }
656                }
657                _ = shutdown_rx.recv() => {
658                    tracing::info!("Admin server shutting down");
659                }
660            }
661
662            sync_task.abort();
663        })
664    }
665
666    /// Handle a client connection
667    async fn handle_client(
668        mut stream: TcpStream,
669        addr: SocketAddr,
670        state: Arc<ServerState>,
671        config: ProxyConfig,
672        _shutdown_tx: broadcast::Sender<()>,
673    ) -> Result<()> {
674        tracing::debug!("New client connection from {}", addr);
675
676        // Create session
677        let session = Arc::new(ClientSession {
678            id: Uuid::new_v4(),
679            client_addr: addr,
680            current_node: RwLock::new(None),
681            tx_state: RwLock::new(TransactionState::default()),
682            variables: RwLock::new(HashMap::new()),
683            created_at: chrono::Utc::now(),
684            tr_mode: config.tr_mode,
685            #[cfg(feature = "pool-modes")]
686            pool_client_id: ClientId::new(),
687            #[cfg(feature = "wasm-plugins")]
688            plugin_identity: RwLock::new(None),
689        });
690
691        // Register session
692        {
693            let mut sessions = state.sessions.write().await;
694            sessions.insert(session.id, session.clone());
695        }
696
697        // Main client loop
698        let result = Self::client_loop(&mut stream, &session, &state, &config).await;
699
700        // Cleanup session
701        {
702            let mut sessions = state.sessions.write().await;
703            sessions.remove(&session.id);
704        }
705
706        // Release any active pool lease if pool-modes is enabled
707        #[cfg(feature = "pool-modes")]
708        if let Some(ref pool_manager) = state.pool_manager {
709            // Check if there's an active lease for this client and release it
710            if pool_manager.has_active_lease(&session.pool_client_id) {
711                tracing::debug!(
712                    "Releasing pool lease for disconnecting client {:?}",
713                    session.pool_client_id
714                );
715                // Note: The lease is released implicitly when the connection closes
716                // The pool manager will clean up any orphaned leases
717            }
718        }
719
720        state
721            .metrics
722            .connections_closed
723            .fetch_add(1, Ordering::Relaxed);
724
725        result
726    }
727
728    /// Main client processing loop with full PostgreSQL protocol handling
729    async fn client_loop(
730        stream: &mut TcpStream,
731        session: &Arc<ClientSession>,
732        state: &Arc<ServerState>,
733        config: &ProxyConfig,
734    ) -> Result<()> {
735        let codec = ProtocolCodec::new();
736        let mut buffer = BytesMut::with_capacity(8192);
737
738        // Handle startup phase
739        let (mut backend_stream, mut backend_node): (Option<TcpStream>, Option<String>) =
740            match Self::handle_startup(stream, &mut buffer, &codec, session, state, config).await {
741                Ok((Some(stream_conn), node_addr)) => (Some(stream_conn), Some(node_addr)),
742                Ok((None, _)) => {
743                    // SSL rejected or cancel request, connection should close
744                    return Ok(());
745                }
746                Err(e) => {
747                    tracing::error!("Startup failed: {}", e);
748                    // Send error to client
749                    let err_msg = Self::create_error_response("08006", &format!("Startup failed: {}", e));
750                    let _ = stream.write_all(&err_msg).await;
751                    return Err(e);
752                }
753            };
754
755        // Main query loop
756        loop {
757            // Read from client
758            let mut read_buf = vec![0u8; 8192];
759            let n = stream
760                .read(&mut read_buf)
761                .await
762                .map_err(|e| ProxyError::Network(format!("Read error: {}", e)))?;
763
764            if n == 0 {
765                // Client disconnected
766                break;
767            }
768
769            buffer.extend_from_slice(&read_buf[..n]);
770            state.metrics.bytes_received.fetch_add(n as u64, Ordering::Relaxed);
771
772            // Process all complete messages in buffer
773            while let Some(msg) = codec.decode_message(&mut buffer)? {
774                // Handle Terminate message
775                if msg.msg_type == MessageType::Terminate {
776                    return Ok(());
777                }
778
779                // Anomaly detector — record every Query message
780                // (rate window, novel-fingerprint detector, SQLi
781                // pattern scan). Fires before the plugin hook so a
782                // detection lands in the audit trail even if a
783                // plugin later blocks.
784                #[cfg(feature = "anomaly-detection")]
785                Self::record_anomaly_observation(&msg, state, session);
786
787                // Plugin pre-query hook — may rewrite the SQL, block the
788                // query with an error, or (future) return a cached response.
789                let (msg, action) = Self::apply_pre_query_hook(msg, state, session);
790
791                if let PreQueryAction::Block(reason) = &action {
792                    tracing::info!(reason = %reason, "pre-query plugin blocked query");
793                    Self::send_block_response(stream, reason, state).await?;
794                    state.metrics.queries_processed.fetch_add(1, Ordering::Relaxed);
795                    continue;
796                }
797
798                // Plugin returned a fully-formed cached response. Synthesise
799                // the PG wire reply (RowDescription + DataRows +
800                // CommandComplete + ReadyForQuery) and send it directly —
801                // the backend is never touched. On malformed payloads we
802                // log + fall through to normal forwarding so a buggy plugin
803                // degrades gracefully instead of taking the proxy down.
804                #[cfg(feature = "wasm-plugins")]
805                if let PreQueryAction::Cached(bytes) = &action {
806                    match Self::synthesise_cached_response(bytes) {
807                        Ok(reply) => {
808                            stream
809                                .write_all(&reply)
810                                .await
811                                .map_err(|e| {
812                                    ProxyError::Network(format!("Write error: {}", e))
813                                })?;
814                            state
815                                .metrics
816                                .bytes_sent
817                                .fetch_add(reply.len() as u64, Ordering::Relaxed);
818                            state
819                                .metrics
820                                .queries_processed
821                                .fetch_add(1, Ordering::Relaxed);
822                            continue;
823                        }
824                        Err(e) => {
825                            tracing::warn!(
826                                error = %e,
827                                "failed to synthesise cached response; falling back to backend"
828                            );
829                            // fall through to normal forwarding
830                        }
831                    }
832                }
833
834                // Route and process the message
835                #[cfg(feature = "wasm-plugins")]
836                let forward_start = std::time::Instant::now();
837                let forward_result = Self::route_and_forward(
838                    &msg,
839                    backend_stream.take(),
840                    backend_node.take(),
841                    session,
842                    state,
843                    config,
844                )
845                .await;
846                #[cfg(feature = "wasm-plugins")]
847                Self::fire_post_query_hook(
848                    &msg,
849                    session,
850                    state,
851                    &forward_result,
852                    forward_start.elapsed(),
853                );
854                let (response, new_backend, new_node) = forward_result?;
855
856                backend_stream = new_backend;
857                backend_node = new_node;
858
859                // Send response to client
860                if !response.is_empty() {
861                    stream
862                        .write_all(&response)
863                        .await
864                        .map_err(|e| ProxyError::Network(format!("Write error: {}", e)))?;
865
866                    state
867                        .metrics
868                        .bytes_sent
869                        .fetch_add(response.len() as u64, Ordering::Relaxed);
870                }
871
872                state.metrics.queries_processed.fetch_add(1, Ordering::Relaxed);
873            }
874        }
875
876        Ok(())
877    }
878
879    /// Handle PostgreSQL startup phase (SSL, authentication)
880    async fn handle_startup(
881        client_stream: &mut TcpStream,
882        buffer: &mut BytesMut,
883        codec: &ProtocolCodec,
884        session: &Arc<ClientSession>,
885        state: &Arc<ServerState>,
886        config: &ProxyConfig,
887    ) -> Result<(Option<TcpStream>, String)> {
888        // Read startup message
889        let mut read_buf = vec![0u8; 1024];
890        let n = client_stream
891            .read(&mut read_buf)
892            .await
893            .map_err(|e| ProxyError::Network(format!("Startup read error: {}", e)))?;
894
895        if n == 0 {
896            return Ok((None, String::new()));
897        }
898
899        buffer.extend_from_slice(&read_buf[..n]);
900
901        // Parse startup message
902        let startup_msg = codec.decode_startup(buffer)?;
903
904        match startup_msg {
905            Some(StartupMessage::SSLRequest) => {
906                // Reject SSL (send 'N')
907                client_stream
908                    .write_all(&[b'N'])
909                    .await
910                    .map_err(|e| ProxyError::Network(format!("SSL reject error: {}", e)))?;
911
912                // Read actual startup message
913                buffer.clear();
914                let n = client_stream
915                    .read(&mut read_buf)
916                    .await
917                    .map_err(|e| ProxyError::Network(format!("Post-SSL read error: {}", e)))?;
918
919                if n == 0 {
920                    return Ok((None, String::new()));
921                }
922
923                buffer.extend_from_slice(&read_buf[..n]);
924
925                // Parse the real startup message
926                return Self::process_startup(
927                    client_stream,
928                    buffer,
929                    codec,
930                    session,
931                    state,
932                    config,
933                )
934                .await;
935            }
936            Some(StartupMessage::CancelRequest { .. }) => {
937                // Cancel requests are handled separately, just close connection
938                return Ok((None, String::new()));
939            }
940            Some(StartupMessage::Startup { params, .. }) => {
941                // Connect to backend and forward startup
942                return Self::connect_and_authenticate(
943                    client_stream,
944                    &params,
945                    session,
946                    state,
947                    config,
948                )
949                .await;
950            }
951            None => {
952                return Err(ProxyError::Protocol("Incomplete startup message".to_string()));
953            }
954        }
955    }
956
957    /// Process startup message after SSL negotiation
958    async fn process_startup(
959        client_stream: &mut TcpStream,
960        buffer: &mut BytesMut,
961        codec: &ProtocolCodec,
962        session: &Arc<ClientSession>,
963        state: &Arc<ServerState>,
964        config: &ProxyConfig,
965    ) -> Result<(Option<TcpStream>, String)> {
966        let startup_msg = codec.decode_startup(buffer)?;
967
968        match startup_msg {
969            Some(StartupMessage::Startup { params, .. }) => {
970                Self::connect_and_authenticate(client_stream, &params, session, state, config).await
971            }
972            _ => Err(ProxyError::Protocol("Expected startup message".to_string())),
973        }
974    }
975
976    /// Connect to backend and handle authentication
977    async fn connect_and_authenticate(
978        client_stream: &mut TcpStream,
979        params: &HashMap<String, String>,
980        session: &Arc<ClientSession>,
981        state: &Arc<ServerState>,
982        config: &ProxyConfig,
983    ) -> Result<(Option<TcpStream>, String)> {
984        // Plugin Authenticate hook — may deny the connection outright or
985        // attach a richer identity (roles, tenant_id, claims) onto the
986        // session for downstream plugins to consume. Happens before any
987        // backend connection is opened so denials cost nothing on the
988        // backend side.
989        Self::apply_authenticate_hook(params, session, state).await?;
990
991        // Select initial backend node (primary for now)
992        let node_addr = Self::select_node(session, state, config).await?;
993
994        // Connect to backend
995        let mut backend = tokio::time::timeout(
996            config.pool.acquire_timeout(),
997            TcpStream::connect(&node_addr),
998        )
999        .await
1000        .map_err(|_| ProxyError::Connection(format!("Connection timeout to {}", node_addr)))?
1001        .map_err(|e| ProxyError::Connection(format!("Failed to connect to {}: {}", node_addr, e)))?;
1002
1003        // Build and send startup message to backend
1004        let startup_bytes = Self::build_startup_message(params);
1005        backend
1006            .write_all(&startup_bytes)
1007            .await
1008            .map_err(|e| ProxyError::Network(format!("Backend startup write error: {}", e)))?;
1009
1010        // Forward authentication messages between client and backend
1011        Self::proxy_authentication(client_stream, &mut backend).await?;
1012
1013        // Store session variables
1014        {
1015            let mut vars = session.variables.write().await;
1016            for (k, v) in params {
1017                vars.insert(k.clone(), v.clone());
1018            }
1019        }
1020
1021        Ok((Some(backend), node_addr))
1022    }
1023
1024    /// Build PostgreSQL startup message
1025    fn build_startup_message(params: &HashMap<String, String>) -> Vec<u8> {
1026        let mut payload = BytesMut::new();
1027
1028        // Protocol version 3.0
1029        payload.put_u32(196608);
1030
1031        // Parameters
1032        for (key, value) in params {
1033            payload.extend_from_slice(key.as_bytes());
1034            payload.put_u8(0);
1035            payload.extend_from_slice(value.as_bytes());
1036            payload.put_u8(0);
1037        }
1038        payload.put_u8(0); // Terminator
1039
1040        // Build complete message with length prefix
1041        let mut msg = BytesMut::new();
1042        msg.put_u32((payload.len() + 4) as u32);
1043        msg.extend_from_slice(&payload);
1044
1045        msg.to_vec()
1046    }
1047
1048    /// Proxy authentication messages between client and backend
1049    async fn proxy_authentication(
1050        client_stream: &mut TcpStream,
1051        backend_stream: &mut TcpStream,
1052    ) -> Result<()> {
1053        let codec = ProtocolCodec::new();
1054        let mut backend_buffer = BytesMut::with_capacity(4096);
1055        let mut client_buffer = BytesMut::with_capacity(4096);
1056
1057        loop {
1058            // Read from backend
1059            let mut read_buf = vec![0u8; 4096];
1060            let n = backend_stream
1061                .read(&mut read_buf)
1062                .await
1063                .map_err(|e| ProxyError::Network(format!("Backend auth read error: {}", e)))?;
1064
1065            if n == 0 {
1066                return Err(ProxyError::Connection("Backend closed during auth".to_string()));
1067            }
1068
1069            backend_buffer.extend_from_slice(&read_buf[..n]);
1070
1071            // Forward all data to client
1072            client_stream
1073                .write_all(&read_buf[..n])
1074                .await
1075                .map_err(|e| ProxyError::Network(format!("Client auth write error: {}", e)))?;
1076
1077            // Check for authentication complete or error
1078            while let Some(msg) = codec.decode_message(&mut backend_buffer.clone())? {
1079                match msg.msg_type {
1080                    MessageType::AuthRequest => {
1081                        // Check if auth OK
1082                        if msg.payload.len() >= 4 {
1083                            let auth_type =
1084                                i32::from_be_bytes([msg.payload[0], msg.payload[1], msg.payload[2], msg.payload[3]]);
1085                            if auth_type == 0 {
1086                                // AuthenticationOk - continue to read ReadyForQuery
1087                            }
1088                        }
1089                    }
1090                    MessageType::ReadyForQuery => {
1091                        // Authentication complete
1092                        return Ok(());
1093                    }
1094                    MessageType::ErrorResponse => {
1095                        // Authentication failed - error already sent to client
1096                        return Err(ProxyError::Auth("Authentication failed".to_string()));
1097                    }
1098                    _ => {
1099                        // Continue forwarding
1100                    }
1101                }
1102                // Advance the actual buffer
1103                let _ = codec.decode_message(&mut backend_buffer)?;
1104            }
1105
1106            // If backend requires password, forward client's response
1107            // Read password from client if needed
1108            let n = tokio::time::timeout(Duration::from_millis(100), client_stream.read(&mut read_buf))
1109                .await;
1110
1111            if let Ok(Ok(n)) = n {
1112                if n > 0 {
1113                    client_buffer.extend_from_slice(&read_buf[..n]);
1114                    backend_stream
1115                        .write_all(&read_buf[..n])
1116                        .await
1117                        .map_err(|e| ProxyError::Network(format!("Backend password write error: {}", e)))?;
1118                }
1119            }
1120        }
1121    }
1122
1123    /// Route message and forward to appropriate backend
1124    async fn route_and_forward(
1125        msg: &Message,
1126        backend_stream: Option<TcpStream>,
1127        current_node: Option<String>,
1128        session: &Arc<ClientSession>,
1129        state: &Arc<ServerState>,
1130        config: &ProxyConfig,
1131    ) -> Result<(Vec<u8>, Option<TcpStream>, Option<String>)> {
1132        // Determine if this is a write operation (from SQL verb).
1133        let default_is_write = Self::is_write_message(msg);
1134
1135        // Plugin Route hook may override the routing decision — force
1136        // primary/standby, or pin the query to a specific node.
1137        let route_override = Self::apply_route_hook(msg, state, session);
1138
1139        // Block short-circuits before any backend selection: synthesise
1140        // a PG ErrorResponse + ReadyForQuery, hand the existing backend
1141        // stream and current node back unchanged so the caller can
1142        // continue the session normally with the next message.
1143        if let RouteOverride::Block(reason) = route_override {
1144            let mut response = Vec::with_capacity(64 + reason.len());
1145            response.extend_from_slice(&Self::create_error_response(
1146                "42000",
1147                &format!("Query blocked by route plugin: {}", reason),
1148            ));
1149            response.extend_from_slice(&Self::create_ready_for_query(b'I'));
1150            state
1151                .metrics
1152                .bytes_sent
1153                .fetch_add(response.len() as u64, Ordering::Relaxed);
1154            return Ok((response, backend_stream, current_node));
1155        }
1156
1157        // Derive effective (is_write, forced_target) after override.
1158        let (is_write, forced_target) = match route_override {
1159            RouteOverride::None => (default_is_write, None),
1160            RouteOverride::Primary => (true, None),
1161            RouteOverride::Standby => (false, None),
1162            RouteOverride::Node(name) => (default_is_write, Some(name)),
1163            RouteOverride::Block(_) => unreachable!("handled above"),
1164        };
1165
1166        // Sticky session mode: stay on the same backend if healthy and
1167        // compatible with the routing decision. A forced target shortcuts
1168        // the usual write-needs-primary check: the only question is whether
1169        // the current connection already points at the forced node.
1170        let need_switch = if let Some(ref forced) = forced_target {
1171            let health = state.health.read().await;
1172            let reuse = current_node
1173                .as_ref()
1174                .map(|c| c == forced && health.get(c).map(|h| h.healthy).unwrap_or(false))
1175                .unwrap_or(false);
1176            !reuse
1177        } else if let Some(ref current) = current_node {
1178            let health = state.health.read().await;
1179            let current_healthy = health.get(current).map(|h| h.healthy).unwrap_or(false);
1180
1181            if !current_healthy {
1182                true
1183            } else if is_write {
1184                // Check if current is primary
1185                let is_primary = config.nodes.iter()
1186                    .find(|n| n.address() == *current)
1187                    .map(|n| n.role == NodeRole::Primary)
1188                    .unwrap_or(false);
1189                !is_primary
1190            } else {
1191                false
1192            }
1193        } else {
1194            true
1195        };
1196
1197        let target_node = if let Some(forced) = forced_target {
1198            forced
1199        } else if need_switch {
1200            if is_write {
1201                Self::select_primary_with_timeout(session, state, config).await?
1202            } else {
1203                Self::select_read_node(session, state, config).await?
1204            }
1205        } else {
1206            current_node.clone().unwrap()
1207        };
1208
1209        let mut backend = if need_switch {
1210            // Close old connection if any
1211            drop(backend_stream);
1212
1213            // Connect to new backend
1214            let new_backend = tokio::time::timeout(
1215                config.pool.acquire_timeout(),
1216                TcpStream::connect(&target_node),
1217            )
1218            .await
1219            .map_err(|_| ProxyError::Connection(format!("Connection timeout to {}", target_node)))?
1220            .map_err(|e| {
1221                ProxyError::Connection(format!("Failed to connect to {}: {}", target_node, e))
1222            })?;
1223
1224            // Re-authenticate to new backend (silently, without forwarding to client)
1225            let params = session.variables.read().await.clone();
1226            let startup = Self::build_startup_message(&params);
1227            let mut backend = new_backend;
1228            backend
1229                .write_all(&startup)
1230                .await
1231                .map_err(|e| ProxyError::Network(format!("Backend startup error: {}", e)))?;
1232
1233            // Complete authentication by reading until ReadyForQuery
1234            Self::complete_backend_auth(&mut backend).await?;
1235
1236            tracing::debug!(
1237                "Switched backend from {:?} to {} for {} query",
1238                current_node,
1239                target_node,
1240                if is_write { "write" } else { "read" }
1241            );
1242
1243            backend
1244        } else {
1245            backend_stream.unwrap()
1246        };
1247
1248        // Forward the message to backend
1249        let encoded = msg.encode();
1250        backend
1251            .write_all(&encoded)
1252            .await
1253            .map_err(|e| ProxyError::Network(format!("Backend write error: {}", e)))?;
1254
1255        // Read response from backend
1256        let mut response = Vec::new();
1257        let mut response_buffer = BytesMut::with_capacity(8192);
1258        let codec = ProtocolCodec::new();
1259
1260        loop {
1261            let mut read_buf = vec![0u8; 8192];
1262            let n = tokio::time::timeout(Duration::from_secs(30), backend.read(&mut read_buf))
1263                .await
1264                .map_err(|_| ProxyError::Network("Backend read timeout".to_string()))?
1265                .map_err(|e| ProxyError::Network(format!("Backend read error: {}", e)))?;
1266
1267            if n == 0 {
1268                break;
1269            }
1270
1271            response.extend_from_slice(&read_buf[..n]);
1272            response_buffer.extend_from_slice(&read_buf[..n]);
1273
1274            // Check if we've received ReadyForQuery (end of response)
1275            while let Some(resp_msg) = codec.decode_message(&mut response_buffer.clone())? {
1276                if resp_msg.msg_type == MessageType::ReadyForQuery {
1277                    // Update transaction state
1278                    if !resp_msg.payload.is_empty() {
1279                        let status = TransactionStatus::from_byte(resp_msg.payload[0]);
1280                        let mut tx_state = session.tx_state.write().await;
1281                        tx_state.in_transaction = status != TransactionStatus::Idle;
1282                    }
1283                    return Ok((response, Some(backend), Some(target_node)));
1284                }
1285                let _ = codec.decode_message(&mut response_buffer)?;
1286            }
1287        }
1288
1289        Ok((response, Some(backend), Some(target_node)))
1290    }
1291
1292    /// Check if a message is a write operation
1293    fn is_write_message(msg: &Message) -> bool {
1294        match msg.msg_type {
1295            MessageType::Query => {
1296                // Parse query and check if it's a write
1297                if let Ok(query_msg) = QueryMessage::parse(msg.payload.clone()) {
1298                    Self::is_write_query(&query_msg.query)
1299                } else {
1300                    false
1301                }
1302            }
1303            MessageType::Parse => {
1304                // Parse prepared statement
1305                if let Ok(parse_msg) = ParseMessage::parse(msg.payload.clone()) {
1306                    Self::is_write_query(&parse_msg.query)
1307                } else {
1308                    false
1309                }
1310            }
1311            // Execute, Bind, etc. maintain the current connection
1312            _ => false,
1313        }
1314    }
1315
1316    /// Check if SQL query is a write operation
1317    fn is_write_query(sql: &str) -> bool {
1318        let upper = sql.trim().to_uppercase();
1319
1320        // Write operations
1321        if upper.starts_with("INSERT")
1322            || upper.starts_with("UPDATE")
1323            || upper.starts_with("DELETE")
1324            || upper.starts_with("CREATE")
1325            || upper.starts_with("DROP")
1326            || upper.starts_with("ALTER")
1327            || upper.starts_with("TRUNCATE")
1328            || upper.starts_with("GRANT")
1329            || upper.starts_with("REVOKE")
1330            || upper.starts_with("VACUUM")
1331            || upper.starts_with("REINDEX")
1332            || upper.starts_with("CLUSTER")
1333        {
1334            return true;
1335        }
1336
1337        // Transaction control goes to current node
1338        if upper.starts_with("BEGIN")
1339            || upper.starts_with("START")
1340            || upper.starts_with("COMMIT")
1341            || upper.starts_with("ROLLBACK")
1342            || upper.starts_with("SAVEPOINT")
1343            || upper.starts_with("RELEASE")
1344        {
1345            return true;
1346        }
1347
1348        // SET commands go to primary to maintain session state
1349        if upper.starts_with("SET") && !upper.starts_with("SET TRANSACTION READ ONLY") {
1350            return true;
1351        }
1352
1353        false
1354    }
1355
1356    /// Select primary node with write timeout during failover
1357    async fn select_primary_with_timeout(
1358        session: &Arc<ClientSession>,
1359        state: &Arc<ServerState>,
1360        config: &ProxyConfig,
1361    ) -> Result<String> {
1362        let timeout = config.write_timeout();
1363        let start = std::time::Instant::now();
1364        let check_interval = Duration::from_millis(500);
1365
1366        loop {
1367            // Try to find healthy primary
1368            let health = state.health.read().await;
1369            let primary = config
1370                .nodes
1371                .iter()
1372                .find(|n| n.role == NodeRole::Primary && n.enabled);
1373
1374            if let Some(primary_node) = primary {
1375                if let Some(node_health) = health.get(&primary_node.address()) {
1376                    if node_health.healthy {
1377                        // Update session's current node
1378                        let mut current = session.current_node.write().await;
1379                        *current = Some(primary_node.address());
1380                        return Ok(primary_node.address());
1381                    }
1382                }
1383            }
1384            drop(health);
1385
1386            // Check if timeout exceeded
1387            if start.elapsed() >= timeout {
1388                state.metrics.failovers.fetch_add(1, Ordering::Relaxed);
1389                return Err(ProxyError::NoHealthyNodes);
1390            }
1391
1392            tracing::warn!(
1393                "Primary unavailable, waiting for failover... ({:.1}s elapsed, {:.1}s timeout)",
1394                start.elapsed().as_secs_f64(),
1395                timeout.as_secs_f64()
1396            );
1397
1398            // Wait before retry
1399            tokio::time::sleep(check_interval).await;
1400        }
1401    }
1402
1403    /// Select node for read operations with load balancing
1404    async fn select_read_node(
1405        session: &Arc<ClientSession>,
1406        state: &Arc<ServerState>,
1407        config: &ProxyConfig,
1408    ) -> Result<String> {
1409        // If in transaction, stick to current node
1410        {
1411            let tx_state = session.tx_state.read().await;
1412            if tx_state.in_transaction {
1413                if let Some(node) = session.current_node.read().await.clone() {
1414                    return Ok(node);
1415                }
1416            }
1417        }
1418
1419        // Get healthy nodes (prefer standbys for reads)
1420        let health = state.health.read().await;
1421        let healthy_standbys: Vec<&NodeConfig> = config
1422            .nodes
1423            .iter()
1424            .filter(|n| {
1425                n.enabled
1426                    && (n.role == NodeRole::Standby || n.role == NodeRole::ReadReplica)
1427                    && health
1428                        .get(&n.address())
1429                        .map(|h| h.healthy)
1430                        .unwrap_or(false)
1431            })
1432            .collect();
1433
1434        if !healthy_standbys.is_empty() {
1435            // Round-robin across healthy standbys
1436            let mut lb_state = state.lb_state.write().await;
1437            let index = lb_state.rr_counter as usize % healthy_standbys.len();
1438            lb_state.rr_counter = lb_state.rr_counter.wrapping_add(1);
1439            let node_addr = healthy_standbys[index].address();
1440
1441            let mut current = session.current_node.write().await;
1442            *current = Some(node_addr.clone());
1443            return Ok(node_addr);
1444        }
1445
1446        // Fall back to primary if no healthy standbys
1447        Self::select_node(session, state, config).await
1448    }
1449
1450    /// Complete backend authentication by reading until ReadyForQuery
1451    /// This is used when switching backends - we don't forward auth to client
1452    async fn complete_backend_auth(backend: &mut TcpStream) -> Result<()> {
1453        let codec = ProtocolCodec::new();
1454        let mut buffer = BytesMut::with_capacity(4096);
1455        let timeout = Duration::from_secs(10);
1456        let start = std::time::Instant::now();
1457
1458        loop {
1459            if start.elapsed() > timeout {
1460                return Err(ProxyError::Auth("Backend authentication timeout".to_string()));
1461            }
1462
1463            let mut read_buf = vec![0u8; 4096];
1464            let n = tokio::time::timeout(Duration::from_secs(5), backend.read(&mut read_buf))
1465                .await
1466                .map_err(|_| ProxyError::Auth("Read timeout during backend auth".to_string()))?
1467                .map_err(|e| ProxyError::Network(format!("Backend auth read error: {}", e)))?;
1468
1469            if n == 0 {
1470                return Err(ProxyError::Connection("Backend closed during auth".to_string()));
1471            }
1472
1473            buffer.extend_from_slice(&read_buf[..n]);
1474
1475            // Check for complete messages
1476            loop {
1477                if buffer.len() < 5 {
1478                    break;
1479                }
1480
1481                // Parse message
1482                let mut temp_buffer = buffer.clone();
1483                match codec.decode_message(&mut temp_buffer)? {
1484                    Some(msg) => {
1485                        match msg.msg_type {
1486                            MessageType::ReadyForQuery => {
1487                                // Authentication complete
1488                                return Ok(());
1489                            }
1490                            MessageType::ErrorResponse => {
1491                                let err = ErrorResponse::parse(msg.payload)
1492                                    .map(|e| e.message().unwrap_or("Unknown error").to_string())
1493                                    .unwrap_or_else(|_| "Parse error".to_string());
1494                                return Err(ProxyError::Auth(err));
1495                            }
1496                            _ => {
1497                                // Continue reading (AuthRequest, ParameterStatus, BackendKeyData, etc.)
1498                            }
1499                        }
1500                        // Consume the message from actual buffer
1501                        let _ = codec.decode_message(&mut buffer)?;
1502                    }
1503                    None => {
1504                        // Need more data
1505                        break;
1506                    }
1507                }
1508            }
1509        }
1510    }
1511
1512    /// Create PostgreSQL error response message
1513    fn create_error_response(code: &str, message: &str) -> Vec<u8> {
1514        let mut fields = HashMap::new();
1515        fields.insert('S', "ERROR".to_string());
1516        fields.insert('V', "ERROR".to_string());
1517        fields.insert('C', code.to_string());
1518        fields.insert('M', message.to_string());
1519
1520        let err = ErrorResponse { fields };
1521        err.encode().encode().to_vec()
1522    }
1523
1524    /// Create a `ReadyForQuery` frame with the given transaction-status byte
1525    /// (`b'I'` = idle, `b'T'` = in transaction, `b'E'` = failed transaction).
1526    fn create_ready_for_query(status: u8) -> Vec<u8> {
1527        let mut payload = BytesMut::with_capacity(1);
1528        payload.put_u8(status);
1529        Message::new(MessageType::ReadyForQuery, payload)
1530            .encode()
1531            .to_vec()
1532    }
1533
1534    /// Synthesise a full PostgreSQL simple-query response from a cached
1535    /// payload produced by a plugin's `PreQueryResult::Cached`.
1536    ///
1537    /// # Payload format
1538    ///
1539    /// The plugin is expected to serialise a JSON document of the form:
1540    ///
1541    /// ```json
1542    /// {
1543    ///   "columns": [
1544    ///     {"name": "id",    "oid": 23},
1545    ///     {"name": "email", "oid": 25}
1546    ///   ],
1547    ///   "rows": [
1548    ///     ["1", "alice@example.com"],
1549    ///     ["2", null]
1550    ///   ]
1551    /// }
1552    /// ```
1553    ///
1554    /// `oid` is the PostgreSQL type OID (`23` = int4, `25` = text,
1555    /// `20` = int8, `16` = bool, `1184` = timestamptz, etc.). Row values
1556    /// are strings in text format; `null` encodes a SQL NULL. The type
1557    /// OID is advisory — pgwire clients accept `25` (text) universally
1558    /// and cast as needed.
1559    ///
1560    /// # Returned bytes
1561    ///
1562    /// One concatenated PostgreSQL wire response:
1563    ///
1564    /// ```text
1565    /// RowDescription (T) + DataRow (D) × N + CommandComplete (C: "SELECT N")
1566    ///                    + ReadyForQuery (Z: idle)
1567    /// ```
1568    ///
1569    /// Returns an error on malformed JSON; the caller falls back to
1570    /// backend forwarding.
1571    #[cfg(feature = "wasm-plugins")]
1572    fn synthesise_cached_response(bytes: &[u8]) -> Result<Vec<u8>> {
1573        use serde::Deserialize;
1574
1575        #[derive(Deserialize)]
1576        struct CachedPayload {
1577            columns: Vec<ColumnDef>,
1578            rows: Vec<Vec<Option<String>>>,
1579        }
1580
1581        #[derive(Deserialize)]
1582        struct ColumnDef {
1583            name: String,
1584            #[serde(default = "default_text_oid")]
1585            oid: u32,
1586        }
1587
1588        fn default_text_oid() -> u32 {
1589            25 // text
1590        }
1591
1592        let payload: CachedPayload = serde_json::from_slice(bytes).map_err(|e| {
1593            ProxyError::Protocol(format!("invalid cached payload JSON: {}", e))
1594        })?;
1595
1596        if payload.columns.is_empty() {
1597            return Err(ProxyError::Protocol(
1598                "cached payload must declare at least one column".to_string(),
1599            ));
1600        }
1601
1602        let mut reply = Vec::new();
1603
1604        // RowDescription (tag 'T')
1605        let mut rd = BytesMut::new();
1606        rd.put_u16(payload.columns.len() as u16);
1607        for col in &payload.columns {
1608            rd.extend_from_slice(col.name.as_bytes());
1609            rd.put_u8(0); // cstring terminator
1610            rd.put_i32(0); // tableOID (unknown)
1611            rd.put_i16(0); // columnNumber (unknown)
1612            rd.put_u32(col.oid);
1613            rd.put_i16(-1); // typeLen (unspecified)
1614            rd.put_i32(-1); // typeMod (unspecified)
1615            rd.put_i16(0); // format code: text
1616        }
1617        reply.extend_from_slice(&Message::new(MessageType::RowDescription, rd).encode());
1618
1619        // DataRow (tag 'D') per row
1620        let column_count = payload.columns.len();
1621        for row in &payload.rows {
1622            if row.len() != column_count {
1623                return Err(ProxyError::Protocol(format!(
1624                    "cached row has {} values but {} columns are declared",
1625                    row.len(),
1626                    column_count
1627                )));
1628            }
1629            let mut dr = BytesMut::new();
1630            dr.put_u16(row.len() as u16);
1631            for value in row {
1632                match value {
1633                    Some(s) => {
1634                        dr.put_i32(s.len() as i32);
1635                        dr.extend_from_slice(s.as_bytes());
1636                    }
1637                    None => {
1638                        dr.put_i32(-1); // NULL sentinel
1639                    }
1640                }
1641            }
1642            reply.extend_from_slice(&Message::new(MessageType::DataRow, dr).encode());
1643        }
1644
1645        // CommandComplete (tag 'C')
1646        let tag = format!("SELECT {}", payload.rows.len());
1647        let mut cc = BytesMut::new();
1648        cc.extend_from_slice(tag.as_bytes());
1649        cc.put_u8(0);
1650        reply.extend_from_slice(&Message::new(MessageType::CommandComplete, cc).encode());
1651
1652        // ReadyForQuery (tag 'Z', status 'I' idle)
1653        reply.extend_from_slice(&Self::create_ready_for_query(b'I'));
1654
1655        Ok(reply)
1656    }
1657
1658    /// Run the pre-query plugin hook on a client message.
1659    ///
1660    /// When the `wasm-plugins` feature is off, or the plugin manager has no
1661    /// loaded plugins, this is a zero-cost passthrough that returns the
1662    /// message untouched with `PreQueryAction::Forward`.
1663    ///
1664    /// Only simple-query (`MessageType::Query`) messages are inspected today.
1665    /// Extended-protocol messages (`Parse`/`Bind`/`Execute`) are passed
1666    /// through unchanged — a future task wires them in.
1667    fn apply_pre_query_hook(
1668        msg: Message,
1669        state: &Arc<ServerState>,
1670        session: &Arc<ClientSession>,
1671    ) -> (Message, PreQueryAction) {
1672        #[cfg(feature = "wasm-plugins")]
1673        {
1674            let pm = match state.plugin_manager.as_ref() {
1675                Some(pm) => pm,
1676                None => return (msg, PreQueryAction::Forward),
1677            };
1678
1679            if msg.msg_type != MessageType::Query {
1680                return (msg, PreQueryAction::Forward);
1681            }
1682
1683            let query_msg = match QueryMessage::parse(msg.payload.clone()) {
1684                Ok(q) => q,
1685                Err(_) => return (msg, PreQueryAction::Forward),
1686            };
1687
1688            let ctx = Self::build_query_context(&query_msg.query, session);
1689
1690            match pm.execute_pre_query(&ctx) {
1691                PreQueryResult::Continue => (msg, PreQueryAction::Forward),
1692                PreQueryResult::Block(reason) => (msg, PreQueryAction::Block(reason)),
1693                PreQueryResult::Rewrite(new_sql) => {
1694                    let rewritten = QueryMessage { query: new_sql }.encode();
1695                    (rewritten, PreQueryAction::Forward)
1696                }
1697                PreQueryResult::Cached(bytes) => (msg, PreQueryAction::Cached(bytes)),
1698            }
1699        }
1700        #[cfg(not(feature = "wasm-plugins"))]
1701        {
1702            let _ = (state, session);
1703            (msg, PreQueryAction::Forward)
1704        }
1705    }
1706
1707    /// Feed the anomaly detector a per-query observation. Cheap —
1708    /// only the SQL-injection scan and the novel-fingerprint check
1709    /// are non-trivial, both well under a microsecond on
1710    /// representative queries. Returns nothing; detections land in
1711    /// the detector's ring buffer and are surfaced via /api/anomalies.
1712    #[cfg(feature = "anomaly-detection")]
1713    fn record_anomaly_observation(
1714        msg: &Message,
1715        state: &Arc<ServerState>,
1716        session: &Arc<ClientSession>,
1717    ) {
1718        if msg.msg_type != MessageType::Query {
1719            return;
1720        }
1721        let query_msg = match QueryMessage::parse(msg.payload.clone()) {
1722            Ok(q) => q,
1723            Err(_) => return,
1724        };
1725        // Tenant identifier is the most-specific known per-session
1726        // attribute the proxy can attribute traffic to. Multi-tenancy
1727        // sets `tenant_id` in `variables`; otherwise we fall back to
1728        // the client address (string-shaped per-client rate window).
1729        // session.variables is a tokio RwLock — but record_anomaly is
1730        // a sync helper. Use try_read so we don't add an await; on
1731        // contention we fall back to the client IP, which is still a
1732        // valid per-source identifier.
1733        let tenant = match session.variables.try_read() {
1734            Ok(vars) => vars
1735                .get("tenant_id")
1736                .or_else(|| vars.get("user"))
1737                .cloned()
1738                .unwrap_or_else(|| session.client_addr.ip().to_string()),
1739            Err(_) => session.client_addr.ip().to_string(),
1740        };
1741        let fingerprint = anomaly_fingerprint(&query_msg.query);
1742        let now = std::time::Instant::now();
1743        let iso = chrono::Utc::now().to_rfc3339();
1744        let obs = crate::anomaly::QueryObservation {
1745            tenant,
1746            fingerprint,
1747            sql: query_msg.query,
1748            timestamp: now,
1749            iso_timestamp: iso,
1750        };
1751        for ev in state.anomaly_detector.record_query(&obs) {
1752            tracing::warn!(
1753                anomaly = ?ev,
1754                "anomaly detected"
1755            );
1756        }
1757    }
1758
1759    /// Send the client a `Block`-outcome response: an error frame plus
1760    /// `ReadyForQuery` so the client's state machine returns to idle and
1761    /// the next query can be accepted.
1762    async fn send_block_response(
1763        stream: &mut TcpStream,
1764        reason: &str,
1765        state: &Arc<ServerState>,
1766    ) -> Result<()> {
1767        let err = Self::create_error_response(
1768            "42000",
1769            &format!("Query blocked by plugin: {}", reason),
1770        );
1771        stream
1772            .write_all(&err)
1773            .await
1774            .map_err(|e| ProxyError::Network(format!("Write error: {}", e)))?;
1775        let rfq = Self::create_ready_for_query(b'I');
1776        stream
1777            .write_all(&rfq)
1778            .await
1779            .map_err(|e| ProxyError::Network(format!("Write error: {}", e)))?;
1780        state
1781            .metrics
1782            .bytes_sent
1783            .fetch_add((err.len() + rfq.len()) as u64, Ordering::Relaxed);
1784        Ok(())
1785    }
1786
1787    /// Build a `QueryContext` for the plugin hook. Populated fields: `query`
1788    /// (verbatim), `is_read_only` (derived from SQL verb), and `hook_context`
1789    /// with the session id as `client_id`. `normalized` and `tables` are
1790    /// left as cheap stand-ins until the analytics normaliser is wired in
1791    /// (T0-d, unified context).
1792    #[cfg(feature = "wasm-plugins")]
1793    fn build_query_context(query: &str, session: &Arc<ClientSession>) -> QueryContext {
1794        let is_read_only = !Self::is_write_query(query);
1795        let mut hook_context = HookContext::default();
1796        hook_context.client_id = Some(session.id.to_string());
1797        QueryContext {
1798            query: query.to_string(),
1799            normalized: query.to_string(),
1800            tables: Vec::new(),
1801            is_read_only,
1802            hook_context,
1803        }
1804    }
1805
1806    /// Run the Authenticate plugin hook at startup. Called from
1807    /// `connect_and_authenticate` before any backend connection.
1808    ///
1809    /// Behaviour by `AuthResult`:
1810    /// * `Defer` — no plugin opinion; proceed with the default
1811    ///   PostgreSQL auth flow unchanged.
1812    /// * `Success(identity)` — store the identity on the session so
1813    ///   downstream plugins (masking, residency) can gate on roles /
1814    ///   tenant_id / claims. PostgreSQL backend auth still runs
1815    ///   normally afterwards (the plugin does not replace PG auth in
1816    ///   this iteration; that's a follow-up).
1817    /// * `Denied(reason)` — surfaces as `ProxyError::Auth`, which the
1818    ///   caller already handles by writing an ErrorResponse to the
1819    ///   client and closing the connection.
1820    ///
1821    /// The `AuthRequest` populated here carries username, database,
1822    /// and client IP from the PostgreSQL startup parameters. Password
1823    /// is deliberately `None` — PG protocol sends the password in
1824    /// response to the backend's challenge, not at startup, so
1825    /// password-aware plugin auth is a separate future task.
1826    async fn apply_authenticate_hook(
1827        _params: &HashMap<String, String>,
1828        _session: &Arc<ClientSession>,
1829        _state: &Arc<ServerState>,
1830    ) -> Result<()> {
1831        #[cfg(feature = "wasm-plugins")]
1832        {
1833            let pm = match _state.plugin_manager.as_ref() {
1834                Some(pm) => pm,
1835                None => return Ok(()),
1836            };
1837
1838            let request = PluginAuthRequest {
1839                headers: HashMap::new(),
1840                username: _params.get("user").cloned(),
1841                password: None,
1842                client_ip: _session.client_addr.ip().to_string(),
1843                database: _params.get("database").cloned(),
1844            };
1845
1846            match pm.execute_authenticate(&request) {
1847                AuthResult::Defer => Ok(()),
1848                AuthResult::Success(identity) => {
1849                    tracing::debug!(
1850                        user = %identity.username,
1851                        roles = ?identity.roles,
1852                        "plugin authenticated user"
1853                    );
1854                    *_session.plugin_identity.write().await = Some(identity);
1855                    Ok(())
1856                }
1857                AuthResult::Denied(reason) => {
1858                    tracing::info!(
1859                        reason = %reason,
1860                        client = %_session.client_addr,
1861                        user = ?_params.get("user"),
1862                        "plugin denied authentication"
1863                    );
1864                    Err(ProxyError::Auth(format!(
1865                        "authentication denied by plugin: {}",
1866                        reason
1867                    )))
1868                }
1869            }
1870        }
1871        #[cfg(not(feature = "wasm-plugins"))]
1872        {
1873            Ok(())
1874        }
1875    }
1876
1877    /// Run the Route plugin hook on a message. Only simple-query messages
1878    /// are inspected; other message types always return `None`.
1879    fn apply_route_hook(
1880        msg: &Message,
1881        state: &Arc<ServerState>,
1882        session: &Arc<ClientSession>,
1883    ) -> RouteOverride {
1884        #[cfg(feature = "wasm-plugins")]
1885        {
1886            let pm = match state.plugin_manager.as_ref() {
1887                Some(pm) => pm,
1888                None => return RouteOverride::None,
1889            };
1890            if msg.msg_type != MessageType::Query {
1891                return RouteOverride::None;
1892            }
1893            let query_msg = match QueryMessage::parse(msg.payload.clone()) {
1894                Ok(q) => q,
1895                Err(_) => return RouteOverride::None,
1896            };
1897            let ctx = Self::build_query_context(&query_msg.query, session);
1898            match pm.execute_route(&ctx) {
1899                RouteResult::Default => RouteOverride::None,
1900                RouteResult::Primary => RouteOverride::Primary,
1901                RouteResult::Standby => RouteOverride::Standby,
1902                RouteResult::Node(name) => RouteOverride::Node(name),
1903                RouteResult::Block(reason) => RouteOverride::Block(reason),
1904                RouteResult::Branch(name) => {
1905                    tracing::warn!(
1906                        branch = %name,
1907                        "Route hook returned Branch but branch routing is not yet wired — using default"
1908                    );
1909                    RouteOverride::None
1910                }
1911            }
1912        }
1913        #[cfg(not(feature = "wasm-plugins"))]
1914        {
1915            let _ = (msg, state, session);
1916            RouteOverride::None
1917        }
1918    }
1919
1920    /// Fire post-query hooks after a message has been forwarded (or failed
1921    /// to forward). Best-effort; errors from individual plugins are logged
1922    /// by the plugin manager and never surface here.
1923    #[cfg(feature = "wasm-plugins")]
1924    fn fire_post_query_hook(
1925        msg: &Message,
1926        session: &Arc<ClientSession>,
1927        state: &Arc<ServerState>,
1928        result: &Result<(Vec<u8>, Option<TcpStream>, Option<String>)>,
1929        elapsed: Duration,
1930    ) {
1931        let pm = match state.plugin_manager.as_ref() {
1932            Some(pm) => pm,
1933            None => return,
1934        };
1935        if msg.msg_type != MessageType::Query {
1936            return;
1937        }
1938        let query_msg = match QueryMessage::parse(msg.payload.clone()) {
1939            Ok(q) => q,
1940            Err(_) => return,
1941        };
1942        let ctx = Self::build_query_context(&query_msg.query, session);
1943        let outcome = match result {
1944            Ok((resp, _, node)) => PostQueryOutcome {
1945                success: true,
1946                target_node: node.clone(),
1947                elapsed_us: elapsed.as_micros() as u64,
1948                response_bytes: resp.len() as u64,
1949                error: None,
1950            },
1951            Err(e) => PostQueryOutcome {
1952                success: false,
1953                target_node: None,
1954                elapsed_us: elapsed.as_micros() as u64,
1955                response_bytes: 0,
1956                error: Some(e.to_string()),
1957            },
1958        };
1959        pm.execute_post_query(&ctx, &outcome);
1960    }
1961
1962    /// Select a backend node for the request
1963    /// Select a backend node for initial connection
1964    /// Prefers primary but falls back to standbys for read connections
1965    async fn select_node(
1966        session: &Arc<ClientSession>,
1967        state: &Arc<ServerState>,
1968        config: &ProxyConfig,
1969    ) -> Result<String> {
1970        // If in a transaction, stick to the current node
1971        {
1972            let tx_state = session.tx_state.read().await;
1973            if tx_state.in_transaction {
1974                if let Some(node) = session.current_node.read().await.clone() {
1975                    return Ok(node);
1976                }
1977            }
1978        }
1979
1980        // Get healthy nodes
1981        let health = state.health.read().await;
1982        let healthy_nodes: Vec<&NodeConfig> = config
1983            .nodes
1984            .iter()
1985            .filter(|n| {
1986                n.enabled
1987                    && health
1988                        .get(&n.address())
1989                        .map(|h| h.healthy)
1990                        .unwrap_or(false)
1991            })
1992            .collect();
1993
1994        if healthy_nodes.is_empty() {
1995            return Err(ProxyError::NoHealthyNodes);
1996        }
1997
1998        // Try to find healthy primary first
1999        if let Some(primary) = healthy_nodes.iter().find(|n| n.role == NodeRole::Primary) {
2000            let node_addr = primary.address();
2001            let mut current = session.current_node.write().await;
2002            *current = Some(node_addr.clone());
2003            return Ok(node_addr);
2004        }
2005
2006        // Fall back to standby if primary is unavailable
2007        // (Initial connection will work, writes will use write timeout to wait for primary)
2008        if let Some(standby) = healthy_nodes.iter().find(|n| n.role == NodeRole::Standby) {
2009            tracing::warn!("Primary unavailable, connecting to standby for initial session");
2010            let node_addr = standby.address();
2011            let mut current = session.current_node.write().await;
2012            *current = Some(node_addr.clone());
2013            return Ok(node_addr);
2014        }
2015
2016        // No nodes available
2017        Err(ProxyError::NoHealthyNodes)
2018    }
2019
2020    /// Spawn health checker background task
2021    fn spawn_health_checker(&self) -> tokio::task::JoinHandle<()> {
2022        let state = self.state.clone();
2023        let config = self.config.clone();
2024        let mut shutdown_rx = self.shutdown_tx.subscribe();
2025
2026        tokio::spawn(async move {
2027            let mut interval =
2028                tokio::time::interval(std::time::Duration::from_secs(config.health.check_interval_secs));
2029
2030            loop {
2031                tokio::select! {
2032                    _ = interval.tick() => {
2033                        Self::check_all_nodes(&state, &config).await;
2034                    }
2035                    _ = shutdown_rx.recv() => {
2036                        break;
2037                    }
2038                }
2039            }
2040        })
2041    }
2042
2043    /// Check health of all nodes
2044    async fn check_all_nodes(state: &Arc<ServerState>, config: &ProxyConfig) {
2045        for node in &config.nodes {
2046            let result = Self::check_node_health(node, config).await;
2047            let mut health = state.health.write().await;
2048
2049            if let Some(node_health) = health.get_mut(&node.address()) {
2050                match result {
2051                    Ok(latency) => {
2052                        node_health.healthy = true;
2053                        node_health.failure_count = 0;
2054                        node_health.latency_ms = latency;
2055                        node_health.last_error = None;
2056                    }
2057                    Err(e) => {
2058                        node_health.failure_count += 1;
2059                        node_health.last_error = Some(e.to_string());
2060
2061                        if node_health.failure_count >= config.health.failure_threshold {
2062                            node_health.healthy = false;
2063                            tracing::warn!(
2064                                "Node {} marked unhealthy after {} failures",
2065                                node.address(),
2066                                node_health.failure_count
2067                            );
2068                        }
2069                    }
2070                }
2071                node_health.last_check = chrono::Utc::now();
2072            }
2073        }
2074    }
2075
2076    /// Check health of a single node
2077    async fn check_node_health(node: &NodeConfig, config: &ProxyConfig) -> Result<f64> {
2078        let start = std::time::Instant::now();
2079
2080        let timeout = std::time::Duration::from_secs(config.health.check_timeout_secs);
2081        let _stream = tokio::time::timeout(timeout, TcpStream::connect(node.address()))
2082            .await
2083            .map_err(|_| ProxyError::HealthCheck(format!("Timeout connecting to {}", node.address())))?
2084            .map_err(|e| {
2085                ProxyError::HealthCheck(format!("Failed to connect to {}: {}", node.address(), e))
2086            })?;
2087
2088        // In a real implementation, we would execute the health check query here
2089        let latency = start.elapsed().as_secs_f64() * 1000.0;
2090        Ok(latency)
2091    }
2092
2093    /// Spawn pool manager background task
2094    fn spawn_pool_manager(&self) -> tokio::task::JoinHandle<()> {
2095        let state = self.state.clone();
2096        let mut shutdown_rx = self.shutdown_tx.subscribe();
2097
2098        tokio::spawn(async move {
2099            let mut interval = tokio::time::interval(std::time::Duration::from_secs(30));
2100
2101            loop {
2102                tokio::select! {
2103                    _ = interval.tick() => {
2104                        // Evict idle connections from pool-modes manager
2105                        #[cfg(feature = "pool-modes")]
2106                        if let Some(ref pool_manager) = state.pool_manager {
2107                            pool_manager.evict_idle().await;
2108                            tracing::trace!("Pool-modes idle eviction completed");
2109                        }
2110                    }
2111                    _ = shutdown_rx.recv() => {
2112                        // Cleanup on shutdown
2113                        #[cfg(feature = "pool-modes")]
2114                        if let Some(ref pool_manager) = state.pool_manager {
2115                            pool_manager.close_all().await;
2116                            tracing::info!("Pool-modes manager closed all connections");
2117                        }
2118                        break;
2119                    }
2120                }
2121            }
2122        })
2123    }
2124
2125    /// Shutdown the server
2126    pub fn shutdown(&self) {
2127        let _ = self.shutdown_tx.send(());
2128    }
2129
2130    /// Get pool mode statistics (if pool-modes feature enabled)
2131    #[cfg(feature = "pool-modes")]
2132    pub async fn pool_mode_stats(&self) -> Option<PoolModeStatsSnapshot> {
2133        if let Some(ref pool_manager) = self.state.pool_manager {
2134            let stats = pool_manager.get_stats().await;
2135            let metrics = pool_manager.metrics().snapshot();
2136            let default_mode = pool_manager.default_mode();
2137
2138            // Calculate average lease duration across all modes
2139            let avg_lease_duration_ms = metrics
2140                .mode_stats
2141                .get(&default_mode)
2142                .map(|s| s.avg_lease_duration_ms as u64)
2143                .unwrap_or(0);
2144
2145            Some(PoolModeStatsSnapshot {
2146                mode: format!("{:?}", default_mode),
2147                total_connections: stats.total_connections,
2148                active_leases: stats.active_connections,
2149                idle_connections: stats.idle_connections,
2150                node_count: stats.node_count,
2151                acquires: metrics.acquires,
2152                releases: metrics.releases,
2153                acquire_failures: metrics.acquire_failures,
2154                acquire_timeouts: metrics.acquire_timeouts,
2155                transactions_completed: metrics.transactions_completed,
2156                statements_executed: metrics.statements_executed,
2157                avg_lease_duration_ms,
2158            })
2159        } else {
2160            None
2161        }
2162    }
2163
2164    /// Add a node to the pool manager (if pool-modes feature enabled)
2165    #[cfg(feature = "pool-modes")]
2166    pub async fn add_node_to_pool(&self, node: &NodeConfig) {
2167        if let Some(ref pool_manager) = self.state.pool_manager {
2168            let endpoint = NodeEndpoint::new(&node.host, node.port)
2169                .with_role(match node.role {
2170                    NodeRole::Primary => crate::NodeRole::Primary,
2171                    NodeRole::Standby => crate::NodeRole::Standby,
2172                    NodeRole::ReadReplica => crate::NodeRole::ReadReplica,
2173                })
2174                .with_weight(node.weight);
2175            pool_manager.add_node(&endpoint).await;
2176            tracing::info!("Added node {} to pool manager", node.address());
2177        }
2178    }
2179
2180    /// Get server metrics
2181    pub fn metrics(&self) -> ServerMetricsSnapshot {
2182        ServerMetricsSnapshot {
2183            connections_accepted: self.state.metrics.connections_accepted.load(Ordering::Relaxed),
2184            connections_closed: self.state.metrics.connections_closed.load(Ordering::Relaxed),
2185            queries_processed: self.state.metrics.queries_processed.load(Ordering::Relaxed),
2186            bytes_received: self.state.metrics.bytes_received.load(Ordering::Relaxed),
2187            bytes_sent: self.state.metrics.bytes_sent.load(Ordering::Relaxed),
2188            failovers: self.state.metrics.failovers.load(Ordering::Relaxed),
2189        }
2190    }
2191}
2192
/// Metrics snapshot for external consumption.
///
/// Produced by `ProxyServer::metrics`. `Default` yields an all-zero snapshot
/// and `PartialEq`/`Eq` let callers and tests compare snapshots directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ServerMetricsSnapshot {
    /// Client connections accepted since start.
    pub connections_accepted: u64,
    /// Client connections closed since start.
    pub connections_closed: u64,
    /// Queries processed since start.
    pub queries_processed: u64,
    /// Bytes received since start.
    pub bytes_received: u64,
    /// Bytes sent since start.
    pub bytes_sent: u64,
    /// Failovers performed since start.
    pub failovers: u64,
}
2203
/// Pool mode statistics snapshot (when pool-modes feature is enabled).
///
/// Produced by `ProxyServer::pool_mode_stats`, which combines the pool
/// manager's connection stats with its metrics snapshot.
#[cfg(feature = "pool-modes")]
#[derive(Debug, Clone)]
pub struct PoolModeStatsSnapshot {
    /// Current pooling mode (the `Debug` rendering of the manager's default mode)
    pub mode: String,
    /// Total connections across all pools
    pub total_connections: usize,
    /// Active (leased) connections; filled from the manager's `active_connections`
    pub active_leases: usize,
    /// Idle connections
    pub idle_connections: usize,
    /// Number of nodes in the pool
    pub node_count: usize,
    /// Total connection acquires
    pub acquires: u64,
    /// Total connection releases
    pub releases: u64,
    /// Failed acquire attempts
    pub acquire_failures: u64,
    /// Acquire timeouts
    pub acquire_timeouts: u64,
    /// Completed transactions (Transaction mode)
    pub transactions_completed: u64,
    /// Total statements executed
    pub statements_executed: u64,
    /// Average lease duration in milliseconds for the default mode only
    /// (0 when that mode has no recorded stats)
    pub avg_lease_duration_ms: u64,
}
2233
2234#[cfg(test)]
2235mod tests {
2236    use super::*;
2237    use crate::config::{HealthConfig, LoadBalancerConfig, PoolConfig};
2238
2239    fn test_config() -> ProxyConfig {
2240        let mut config = ProxyConfig::default();
2241        config.listen_address = "127.0.0.1:0".to_string();
2242        config
2243            .add_node("127.0.0.1:5432", "primary")
2244            .unwrap();
2245        config
2246    }
2247
2248    #[test]
2249    fn test_server_creation() {
2250        let config = test_config();
2251        let server = ProxyServer::new(config);
2252        assert!(server.is_ok());
2253    }
2254
2255    #[test]
2256    fn test_initial_metrics() {
2257        let config = test_config();
2258        let server = ProxyServer::new(config).unwrap();
2259        let metrics = server.metrics();
2260        assert_eq!(metrics.connections_accepted, 0);
2261        assert_eq!(metrics.queries_processed, 0);
2262    }
2263
2264    #[tokio::test]
2265    async fn test_session_creation() {
2266        let config = test_config();
2267        let server = ProxyServer::new(config).unwrap();
2268
2269        let sessions = server.state.sessions.read().await;
2270        assert!(sessions.is_empty());
2271    }
2272
2273    #[tokio::test]
2274    async fn test_node_health_initialization() {
2275        let config = test_config();
2276        let server = ProxyServer::new(config).unwrap();
2277
2278        let health = server.state.health.read().await;
2279        assert!(!health.is_empty());
2280
2281        for node_health in health.values() {
2282            assert!(node_health.healthy);
2283            assert_eq!(node_health.failure_count, 0);
2284        }
2285    }
2286
2287    /// Build a minimal `ClientSession` for plugin-hook unit tests.
2288    fn make_test_session() -> Arc<ClientSession> {
2289        Arc::new(ClientSession {
2290            id: Uuid::new_v4(),
2291            client_addr: "127.0.0.1:0".parse().unwrap(),
2292            current_node: RwLock::new(None),
2293            tx_state: RwLock::new(TransactionState::default()),
2294            variables: RwLock::new(HashMap::new()),
2295            created_at: chrono::Utc::now(),
2296            tr_mode: crate::config::TrMode::default(),
2297            #[cfg(feature = "pool-modes")]
2298            pool_client_id: crate::pool::lease::ClientId::default(),
2299            #[cfg(feature = "wasm-plugins")]
2300            plugin_identity: RwLock::new(None),
2301        })
2302    }
2303
2304    /// With no plugin manager attached, `apply_route_hook` must be a
2305    /// zero-cost `None` return so the default SQL-verb routing applies.
2306    /// Verifies the feature-gated early-return path.
2307    #[tokio::test]
2308    async fn test_apply_route_hook_no_plugin_manager_returns_none() {
2309        let config = test_config();
2310        let server = ProxyServer::new(config).unwrap();
2311        let session = make_test_session();
2312
2313        let msg = QueryMessage {
2314            query: "SELECT * FROM users".to_string(),
2315        }
2316        .encode();
2317
2318        let decision = ProxyServer::apply_route_hook(&msg, &server.state, &session);
2319        assert!(matches!(decision, RouteOverride::None));
2320    }
2321
2322    /// Same invariant for the pre-query hook: without a plugin manager,
2323    /// `apply_pre_query_hook` must return the message unchanged with
2324    /// `PreQueryAction::Forward`.
2325    #[tokio::test]
2326    async fn test_apply_pre_query_hook_no_plugin_manager_forwards() {
2327        let config = test_config();
2328        let server = ProxyServer::new(config).unwrap();
2329        let session = make_test_session();
2330
2331        let original = QueryMessage {
2332            query: "SELECT 1".to_string(),
2333        }
2334        .encode();
2335        let original_bytes = original.encode().to_vec();
2336
2337        let (msg_out, action) =
2338            ProxyServer::apply_pre_query_hook(original, &server.state, &session);
2339
2340        assert!(matches!(action, PreQueryAction::Forward));
2341        // The message must survive the hook byte-for-byte when no plugins run.
2342        assert_eq!(msg_out.encode().to_vec(), original_bytes);
2343    }
2344
2345    /// Non-Query message types (e.g., extended-protocol Parse/Execute) must
2346    /// bypass the Route hook entirely regardless of plugin state, because
2347    /// we haven't wired SQL extraction for those variants yet.
2348    #[tokio::test]
2349    async fn test_apply_route_hook_skips_non_query_messages() {
2350        let config = test_config();
2351        let server = ProxyServer::new(config).unwrap();
2352        let session = make_test_session();
2353
2354        let sync_msg = Message::empty(MessageType::Sync);
2355        let decision = ProxyServer::apply_route_hook(&sync_msg, &server.state, &session);
2356        assert!(matches!(decision, RouteOverride::None));
2357    }
2358
2359    /// By default, `[plugins].enabled = false`, so `init_plugin_manager`
2360    /// short-circuits without touching the filesystem or wasmtime and
2361    /// returns `None`. The proxy starts normally whether or not a plugin
2362    /// directory exists on the host.
2363    #[cfg(feature = "wasm-plugins")]
2364    #[test]
2365    fn test_init_plugin_manager_disabled_by_default_returns_none() {
2366        let config = test_config();
2367        assert!(!config.plugins.enabled);
2368        let pm = ProxyServer::init_plugin_manager(&config.plugins);
2369        assert!(pm.is_none());
2370    }
2371
2372    /// Plugins enabled but pointing at a directory that doesn't exist
2373    /// must still initialise the manager (so new plugins can be hot-
2374    /// loaded later) and log a warning — it must NOT fail startup.
2375    #[cfg(feature = "wasm-plugins")]
2376    #[test]
2377    fn test_init_plugin_manager_missing_dir_logs_warning() {
2378        let mut config = test_config();
2379        config.plugins.enabled = true;
2380        config.plugins.plugin_dir = "/definitely/not/a/real/path".to_string();
2381
2382        // Manager is created; no panic; Some(pm) returned even with empty dir.
2383        let pm = ProxyServer::init_plugin_manager(&config.plugins);
2384        assert!(pm.is_some());
2385    }
2386
2387    /// With no plugin manager attached, `apply_authenticate_hook` is a
2388    /// zero-cost `Ok(())` that leaves session identity unset — the
2389    /// default PG auth flow applies.
2390    #[tokio::test]
2391    async fn test_apply_authenticate_hook_no_plugin_manager_defers() {
2392        let config = test_config();
2393        let server = ProxyServer::new(config).unwrap();
2394        let session = make_test_session();
2395
2396        let mut params = HashMap::new();
2397        params.insert("user".to_string(), "alice".to_string());
2398        params.insert("database".to_string(), "app".to_string());
2399
2400        let result =
2401            ProxyServer::apply_authenticate_hook(&params, &session, &server.state).await;
2402        assert!(result.is_ok());
2403
2404        // No plugin → no identity stored.
2405        #[cfg(feature = "wasm-plugins")]
2406        {
2407            let ident = session.plugin_identity.read().await;
2408            assert!(ident.is_none());
2409        }
2410    }
2411
2412    /// Cached-response synthesis round-trip: a well-formed plugin
2413    /// payload must produce concatenated wire frames in the order
2414    /// `T D D C Z`. We inspect the raw tag bytes directly because
2415    /// `MessageType::from_tag` conflates server→client DataRow (`'D'`)
2416    /// with client→server Describe (same byte) — a known quirk of the
2417    /// shared `MessageType` enum that the real proxy side-steps by
2418    /// knowing the direction at the call site.
2419    #[cfg(feature = "wasm-plugins")]
2420    #[test]
2421    fn test_synthesise_cached_response_roundtrip() {
2422        let payload = br#"{
2423            "columns": [
2424                {"name": "id",    "oid": 23},
2425                {"name": "email", "oid": 25}
2426            ],
2427            "rows": [
2428                ["1", "alice@example.com"],
2429                ["2", null]
2430            ]
2431        }"#;
2432        let reply =
2433            ProxyServer::synthesise_cached_response(payload).expect("synthesis");
2434
2435        // Walk the concatenation frame-by-frame via length prefixes.
2436        // Each PG message: tag(1) + length(4, big-endian, includes self) + payload.
2437        let mut tags = Vec::new();
2438        let mut i = 0;
2439        while i < reply.len() {
2440            let tag = reply[i];
2441            let len = u32::from_be_bytes([
2442                reply[i + 1],
2443                reply[i + 2],
2444                reply[i + 3],
2445                reply[i + 4],
2446            ]) as usize;
2447            tags.push(tag);
2448            i += 1 + len;
2449        }
2450        assert_eq!(i, reply.len(), "no trailing bytes");
2451        assert_eq!(
2452            tags,
2453            vec![b'T', b'D', b'D', b'C', b'Z'],
2454            "wire frame order"
2455        );
2456
2457        // Spot-check the final ReadyForQuery payload is 'I' (idle).
2458        assert_eq!(*reply.last().unwrap(), b'I');
2459    }
2460
2461    /// Row width mismatch between columns and row data is rejected so
2462    /// the plugin author can't produce ambiguous wire frames.
2463    #[cfg(feature = "wasm-plugins")]
2464    #[test]
2465    fn test_synthesise_cached_response_rejects_row_width_mismatch() {
2466        let payload = br#"{
2467            "columns": [{"name": "id", "oid": 23}, {"name": "name", "oid": 25}],
2468            "rows": [["1", "alice", "extra"]]
2469        }"#;
2470        let result = ProxyServer::synthesise_cached_response(payload);
2471        assert!(matches!(result, Err(ProxyError::Protocol(_))));
2472    }
2473
2474    /// Empty payload (no columns) is rejected — a RowDescription with
2475    /// zero columns is technically valid PG but useless and likely a
2476    /// plugin bug.
2477    #[cfg(feature = "wasm-plugins")]
2478    #[test]
2479    fn test_synthesise_cached_response_rejects_empty_columns() {
2480        let payload = br#"{ "columns": [], "rows": [] }"#;
2481        let result = ProxyServer::synthesise_cached_response(payload);
2482        assert!(matches!(result, Err(ProxyError::Protocol(_))));
2483    }
2484
2485    /// Malformed JSON must return a Protocol error, not panic. The
2486    /// caller treats this as "fall back to backend."
2487    #[cfg(feature = "wasm-plugins")]
2488    #[test]
2489    fn test_synthesise_cached_response_rejects_bad_json() {
2490        let payload = b"not json at all";
2491        let result = ProxyServer::synthesise_cached_response(payload);
2492        assert!(matches!(result, Err(ProxyError::Protocol(_))));
2493    }
2494
2495    /// Denied by plugin surfaces as `ProxyError::Auth` so the existing
2496    /// error-response path in `handle_client` writes an ErrorResponse
2497    /// and closes the connection. Here we prove the error variant
2498    /// when the plugin manager is present but denies. We build a
2499    /// PluginManager with no plugins loaded — so it defers — and
2500    /// verify the Ok path. (Denial path requires an actual
2501    /// auth-plugin `.wasm`; covered by the plugin unit tests in
2502    /// `plugins::tests`.)
2503    #[cfg(feature = "wasm-plugins")]
2504    #[tokio::test]
2505    async fn test_apply_authenticate_hook_with_manager_no_plugins_defers() {
2506        use crate::plugins::{PluginManager, PluginRuntimeConfig};
2507
2508        let config = test_config();
2509        let server = ProxyServer::new(config).unwrap();
2510        let session = make_test_session();
2511
2512        // Synthesise a state with a real PluginManager but zero
2513        // registered plugins — every hook must defer.
2514        let pm = Arc::new(PluginManager::new(PluginRuntimeConfig::default()).unwrap());
2515        let augmented_state = Arc::new(ServerState {
2516            sessions: RwLock::new(HashMap::new()),
2517            health: RwLock::new(HashMap::new()),
2518            metrics: ServerMetrics::default(),
2519            lb_state: RwLock::new(LoadBalancerState {
2520                rr_counter: 0,
2521            }),
2522            #[cfg(feature = "pool-modes")]
2523            pool_manager: None,
2524            plugin_manager: Some(pm),
2525            #[cfg(feature = "ha-tr")]
2526            transaction_journal: Arc::new(
2527                crate::transaction_journal::TransactionJournal::new(),
2528            ),
2529            #[cfg(feature = "anomaly-detection")]
2530            anomaly_detector: Arc::new(
2531                crate::anomaly::AnomalyDetector::new(
2532                    crate::anomaly::AnomalyConfig::default(),
2533                ),
2534            ),
2535            #[cfg(feature = "edge-proxy")]
2536            edge_cache: Arc::new(crate::edge::EdgeCache::new(10_000)),
2537            #[cfg(feature = "edge-proxy")]
2538            edge_registry: Arc::new(crate::edge::EdgeRegistry::new(
2539                32,
2540                std::time::Duration::from_secs(120),
2541            )),
2542        });
2543
2544        let mut params = HashMap::new();
2545        params.insert("user".to_string(), "alice".to_string());
2546
2547        let result =
2548            ProxyServer::apply_authenticate_hook(&params, &session, &augmented_state).await;
2549        assert!(result.is_ok());
2550        let ident = session.plugin_identity.read().await;
2551        assert!(ident.is_none());
2552        // Unused bindings for the sync-state build path.
2553        let _ = server;
2554    }
2555}