Skip to main content

heliosdb_proxy/
server.rs

1//! Proxy Server Implementation
2//!
3//! Main server that accepts client connections and routes them to backends.
4//! Implements PostgreSQL wire protocol forwarding with TWR (Transparent Write Routing).
5
6use crate::admin::{AdminServer, AdminState, ConfigSnapshot, NodeSnapshot};
7#[cfg(feature = "ha-tr")]
8use crate::backend::{tls::default_client_config, BackendConfig, TlsMode};
9use crate::client_tls::{build_tls_acceptor, ClientStream};
10use crate::config::{HbaAction, HbaRule, NodeConfig, NodeRole, ProxyConfig, TrMode};
11use crate::protocol::{
12    ErrorResponse, Message, MessageType, ProtocolCodec, QueryMessage,
13    StartupMessage, TransactionStatus,
14};
15use crate::{ProxyError, Result};
16use arc_swap::ArcSwap;
17use bytes::{BufMut, BytesMut};
18use dashmap::DashMap;
19use std::collections::{HashMap, HashSet};
20use std::net::SocketAddr;
21use std::sync::atomic::{AtomicU64, Ordering};
22use std::sync::Arc;
23use std::time::Duration;
24use tokio::io::{AsyncReadExt, AsyncWriteExt};
25use tokio::net::{TcpListener, TcpStream};
26use tokio::sync::{broadcast, RwLock};
27use uuid::Uuid;
28
29// Pool-modes feature imports
30#[cfg(feature = "pool-modes")]
31use crate::pool::{
32    ConnectionPoolManager, PoolModeConfig, PoolingMode,
33};
34#[cfg(feature = "pool-modes")]
35use crate::pool::lease::ClientId;
36#[cfg(feature = "pool-modes")]
37use crate::NodeEndpoint;
38
39// WASM plugin system imports
40#[cfg(feature = "wasm-plugins")]
41use crate::plugins::{
42    AuthRequest as PluginAuthRequest, AuthResult, HookContext, HookType, Identity, PluginManager,
43    PostQueryOutcome, PreQueryResult, QueryContext, RouteResult,
44};
45
/// Proxy server
///
/// Bundles the configuration the server was built with, the shared runtime
/// state handed to connection handlers, and the shutdown broadcast channel.
pub struct ProxyServer {
    /// Configuration the server was constructed with. `new` also stores a
    /// clone of it as the initial `ServerState::live_config`.
    config: ProxyConfig,
    /// Shared runtime state (sessions, node health, live config, metrics, …).
    state: Arc<ServerState>,
    /// Sending half of the broadcast channel created in `new`.
    // NOTE(review): presumably used to fan a shutdown signal out to background
    // tasks; the subscribers are not visible in this part of the file.
    shutdown_tx: broadcast::Sender<()>,
    /// Path the config was loaded from, retained so `SIGHUP` can re-read it
    /// for a zero-downtime reload (Batch H). `None` when the config was built
    /// from CLI flags/defaults rather than a file.
    config_path: Option<String>,
}
56
/// Placeholder signal source for targets without Unix signals. Awaiting
/// `recv()` never completes, so any `select!` arm driven by it stays idle
/// on those platforms.
#[cfg(not(unix))]
struct HangupNever;

#[cfg(not(unix))]
impl HangupNever {
    /// Never resolves. Has the same `Option<()>` output as the Unix signal
    /// stream's `recv()` so call sites are identical on every platform.
    async fn recv(&mut self) -> Option<()> {
        let never = std::future::pending::<Option<()>>();
        never.await
    }
}
67
/// Produce the `BackendConfig` template used by the time-travel replay
/// engine for its target connection.
///
/// The replay handler overwrites `host` / `port` with the request's
/// `target_host` / `target_port`; the remaining fields (auth, TLS policy,
/// timeouts) are taken from this template as-is.
///
/// Authentication defaults to the `postgres` superuser with no password.
/// That suits local development against `trust` auth and is never
/// appropriate for production. Per-call credential overrides on
/// ReplayRequestBody land in FU-21.
///
/// `_config` is currently unused; it stays in the signature so shared
/// TLS / timeout settings can later be read from the proxy config without
/// touching call sites.
#[cfg(feature = "ha-tr")]
fn build_replay_backend_template(_config: &ProxyConfig) -> BackendConfig {
    // Fixed defaults for replay connections.
    const CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
    const QUERY_TIMEOUT: Duration = Duration::from_secs(30);

    BackendConfig {
        // Overwritten per replay request.
        host: String::from("placeholder"),
        port: 0,
        user: String::from("postgres"),
        password: None,
        database: None,
        application_name: Some(String::from("heliosdb-proxy-replay")),
        tls_mode: TlsMode::Disable,
        connect_timeout: CONNECT_TIMEOUT,
        query_timeout: QUERY_TIMEOUT,
        tls_config: default_client_config(),
    }
}
96
/// Cheap query-shape fingerprint for the anomaly detector.
///
/// Produces the same string for queries that differ only in literal values:
/// `SELECT * FROM users WHERE id = 1` and `SELECT * FROM users WHERE id = 99`
/// both map to `select * from users where id = ?`.
///
/// Normalisation rules:
/// * a single-quoted string literal (open quote, body, close quote) becomes
///   one `?`; a doubled quote (`''`) inside the body is treated as an escaped
///   quote, not as the end of the literal;
/// * a run of ASCII digits, including a decimal point that is followed by
///   another digit, becomes one `?` (adjacent `?`s are not duplicated);
/// * runs of ASCII whitespace collapse to one space; leading and trailing
///   whitespace is dropped;
/// * every other character is ASCII-lower-cased (identifiers as well as
///   keywords).
///
/// Not a parser. Known approximations: digits inside identifiers are masked
/// too (`t1` becomes `t?`), and backslash escapes, comments and dollar-quoted
/// strings are not recognised. The analytics module has the canonical
/// normaliser when query-analytics is on; this is a lightweight standalone so
/// the anomaly detector works even when analytics is off.
#[cfg(feature = "anomaly-detection")]
fn anomaly_fingerprint(sql: &str) -> String {
    let mut out = String::with_capacity(sql.len());
    let mut prev_space = false;
    let mut chars = sql.chars().peekable();
    while let Some(c) = chars.next() {
        if c == '\'' {
            // Replace the entire string literal with a single `?`.
            out.push('?');
            while let Some(n) = chars.next() {
                if n != '\'' {
                    continue;
                }
                if chars.peek() == Some(&'\'') {
                    // `''` is an escaped quote: still inside the literal.
                    chars.next();
                } else {
                    break;
                }
            }
            prev_space = false;
            continue;
        }
        if c.is_ascii_digit() {
            if !out.ends_with('?') {
                out.push('?');
            }
            // Skip the rest of the number. A '.' only belongs to the number
            // when a digit follows it; otherwise it is a qualifier dot such
            // as the one in `t1.col` and must be kept.
            while let Some(&next) = chars.peek() {
                let part_of_number = next.is_ascii_digit()
                    || (next == '.' && {
                        let mut ahead = chars.clone();
                        ahead.next();
                        matches!(ahead.peek(), Some(d) if d.is_ascii_digit())
                    });
                if !part_of_number {
                    break;
                }
                chars.next();
            }
            prev_space = false;
            continue;
        }
        if c.is_ascii_whitespace() {
            // Collapse whitespace runs and never emit leading whitespace.
            if !prev_space && !out.is_empty() {
                out.push(' ');
                prev_space = true;
            }
            continue;
        }
        out.push(c.to_ascii_lowercase());
        prev_space = false;
    }
    out.trim_end().to_string()
}
153
/// Server runtime state
///
/// One instance is built in `ProxyServer::new` and shared behind an `Arc`.
/// Feature-gated fields only exist when the corresponding Cargo feature is
/// enabled.
struct ServerState {
    /// Active client sessions, keyed by session id. The map's length is also
    /// what the graceful drain polls as its active-connection count.
    sessions: RwLock<HashMap<Uuid, Arc<ClientSession>>>,
    /// Node health status, keyed by node address (seeded in `new` with one
    /// healthy entry per configured node).
    // Read-mostly: only the periodic health checker writes (a full-map
    // swap), every query reads. ArcSwap makes the per-query read a single
    // lock-free atomic load with no await, no semaphore, no guard held
    // across the routing awaits.
    health: ArcSwap<HashMap<String, NodeHealth>>,
    /// Live, reloadable proxy configuration (Batch H). The accept loop snapshots
    /// this per new connection and the health checker reads it each tick, so a
    /// SIGHUP that swaps it takes effect for new connections and node health
    /// without dropping any in-flight session. The fields that can only be
    /// applied at startup (listen/admin socket addresses) are ignored on reload
    /// with a warning. Existing connections keep the snapshot they started with.
    live_config: ArcSwap<ProxyConfig>,
    /// Metrics (monotonic atomic counters; see `ServerMetrics`).
    metrics: ServerMetrics,
    /// Query-cancellation routing. Maps the BackendKeyData (pid, secret)
    /// the backend handed to the client onto the backend address that
    /// issued it, so a later out-of-band CancelRequest (which arrives on a
    /// fresh connection) can be forwarded to the right backend instead of
    /// being dropped. Bounded; best-effort.
    cancel_map: Arc<DashMap<(u32, u32), String>>,
    /// Client-facing TLS acceptor, built from `[tls]` config when enabled.
    /// `None` => the proxy rejects SSLRequests with `N` (plaintext only).
    tls_acceptor: Option<tokio_rustls::TlsAcceptor>,
    /// Proxy-terminated SCRAM auth state. `Some` when `[auth] mode = "scram"`:
    /// the proxy authenticates clients itself against this user list instead
    /// of relaying their credentials to the backend.
    auth_file: Option<Arc<crate::auth_scram::AuthFile>>,
    /// Traffic-mirror handle. `Some` when `[mirror] enabled`: the data path
    /// offers write statements to a background mirror worker.
    mirror: Option<crate::mirror::MirrorHandle>,
    /// Migration cutover switch. When `Some`, NEW client connections are
    /// transparently redirected to the promoted target (the former mirror)
    /// instead of the configured primary. Set via POST /api/migration/cutover.
    /// Starts out as `None` in `new`.
    cutover: Arc<ArcSwap<Option<Arc<crate::mirror::CutoverTarget>>>>,
    /// Load balancer state
    lb_state: LoadBalancerState,
    /// Pool manager for Session/Transaction/Statement modes
    #[cfg(feature = "pool-modes")]
    pool_manager: Option<Arc<ConnectionPoolManager>>,
    /// WASM plugin manager. `None` means no plugins loaded — the per-query
    /// hook path becomes a fast no-op. When `Some`, `PreQuery` / `PostQuery`
    /// hooks fire on every simple-query message.
    #[cfg(feature = "wasm-plugins")]
    plugin_manager: Option<Arc<PluginManager>>,
    /// Shared transaction journal — single sink for per-session
    /// statement journaling. The replay engine reads windows from
    /// this directly. Always present when the `ha-tr` feature is on;
    /// journaling self-disables internally when not configured.
    #[cfg(feature = "ha-tr")]
    transaction_journal: Arc<crate::transaction_journal::TransactionJournal>,
    /// Anomaly detector (T3.1). Records every query and every
    /// auth outcome; surfaces detections via /api/anomalies.
    #[cfg(feature = "anomaly-detection")]
    anomaly_detector: Arc<crate::anomaly::AnomalyDetector>,
    /// Edge cache + home registry (T3.2). Both always-present even
    /// in Home mode (the cache is a no-op there); avoids an extra
    /// Option in the hot path.
    #[cfg(feature = "edge-proxy")]
    edge_cache: Arc<crate::edge::EdgeCache>,
    /// Registry half of the edge cache + registry pair (T3.2) described on
    /// `edge_cache`.
    #[cfg(feature = "edge-proxy")]
    edge_registry: Arc<crate::edge::EdgeRegistry>,
}
221
/// Node health status
///
/// One entry per configured backend node, stored in the server's health map
/// under the node's address. `ProxyServer::new` seeds every node as healthy
/// with zero failures and zero latency.
#[derive(Debug, Clone)]
pub struct NodeHealth {
    /// Node address (same string that keys the health map)
    pub address: String,
    /// Whether node is healthy
    pub healthy: bool,
    /// Last check time (UTC)
    pub last_check: chrono::DateTime<chrono::Utc>,
    /// Consecutive failures
    pub failure_count: u32,
    /// Last error message, if any
    pub last_error: Option<String>,
    /// Average latency (ms)
    pub latency_ms: f64,
    /// Replication lag in bytes (if applicable)
    pub replication_lag_bytes: Option<u64>,
}
240
/// Server metrics
///
/// Plain atomic counters, so they can be bumped through a shared reference
/// without taking a lock. `Default` starts every counter at zero.
#[derive(Default)]
struct ServerMetrics {
    /// Total connections accepted
    connections_accepted: AtomicU64,
    /// Total connections closed
    connections_closed: AtomicU64,
    /// Total queries processed
    queries_processed: AtomicU64,
    /// Total bytes received from clients
    bytes_received: AtomicU64,
    /// Total bytes sent to clients
    bytes_sent: AtomicU64,
    /// Failover count
    failovers: AtomicU64,
}
257
/// Load balancer state
///
/// Currently just the round-robin cursor; initialised to 0 in
/// `ProxyServer::new`.
struct LoadBalancerState {
    /// Round-robin counter. Atomic so the read-routing path never
    /// takes a write lock just to advance the rotation.
    rr_counter: AtomicU64,
}
264
/// Client session
///
/// Per-connection state. Sessions are stored as `Arc<ClientSession>` in the
/// server's session map, so the fields that change during the session's
/// lifetime are each wrapped in their own `RwLock`.
pub struct ClientSession {
    /// Session ID (also the key in the server's session map)
    pub id: Uuid,
    /// Client address
    pub client_addr: SocketAddr,
    /// Current backend node; `None` when no backend is currently selected
    pub current_node: RwLock<Option<String>>,
    /// Transaction state
    pub tx_state: RwLock<TransactionState>,
    /// Session variables (name -> value)
    pub variables: RwLock<HashMap<String, String>>,
    /// Created at (UTC)
    pub created_at: chrono::DateTime<chrono::Utc>,
    /// TR mode for this session
    pub tr_mode: TrMode,
    /// Client ID for pool-modes lease tracking
    #[cfg(feature = "pool-modes")]
    pub pool_client_id: ClientId,
    /// Identity returned by an `Authenticate` plugin, if any. Downstream
    /// plugins (masking, residency routing, cost governor) read this to
    /// gate per-user policy. `None` when no plugin ran or every plugin
    /// deferred to the default auth flow.
    #[cfg(feature = "wasm-plugins")]
    pub plugin_identity: RwLock<Option<Identity>>,
}
291
/// Transaction state
///
/// Tracked per client session. The derived `Default` is the idle state: not
/// in a transaction, no id, no logged statements, not read-only, no
/// savepoints.
#[derive(Debug, Clone, Default)]
pub struct TransactionState {
    /// Whether in a transaction
    pub in_transaction: bool,
    /// Transaction ID
    pub tx_id: Option<Uuid>,
    /// Statements executed in current transaction
    pub statements: Vec<StatementLog>,
    /// Read-only transaction
    pub read_only: bool,
    /// Savepoints
    // NOTE(review): presumably savepoint names in creation order — confirm
    // against the code that pushes to this list.
    pub savepoints: Vec<String>,
}
306
/// Logged statement for TR replay
///
/// One entry per statement recorded in `TransactionState::statements`.
#[derive(Debug, Clone)]
pub struct StatementLog {
    /// Statement SQL
    pub sql: String,
    /// Parameters, stored as strings
    pub params: Vec<String>,
    /// Result checksum, when one was recorded
    pub result_checksum: Option<u64>,
    /// Execution time (UTC timestamp of when the statement ran)
    pub executed_at: chrono::DateTime<chrono::Utc>,
}
319
/// A cached per-session backend connection plus the set of *named* prepared
/// statements known to be live on **this** socket.
///
/// Tying the prepared-statement set to the socket (rather than to the node
/// address) is what makes prepared statements survive a backend switch: when a
/// connection is dropped and redialed, or when a session is routed to a
/// different node, the fresh `BackendConn` starts with an empty set, so the
/// proxy transparently re-issues the original `Parse` for any named statement
/// the target connection is missing before forwarding a `Bind`/`Describe` that
/// references it (Batch F.4). The session keeps the canonical `Parse` bytes in
/// a separate registry; this set is just "what does *this* socket already
/// know".
struct BackendConn {
    /// The backend TCP socket itself.
    stream: TcpStream,
    /// Names of the prepared statements already parsed on `stream`. Empty
    /// for a freshly created connection.
    prepared: HashSet<String>,
    /// Signature (query text + parameter-type OIDs) of the *unnamed* prepared
    /// statement currently established on this socket, if any. When the client
    /// re-sends an identical unnamed `Parse`, the proxy can skip forwarding it
    /// (the backend's unnamed statement already holds that SQL) and synthesize
    /// the `ParseComplete` locally — the unnamed-Parse promotion (Batch H).
    unnamed_sig: Option<bytes::Bytes>,
}
342
343impl BackendConn {
344    fn new(stream: TcpStream) -> Self {
345        Self { stream, prepared: HashSet::new(), unnamed_sig: None }
346    }
347}
348
349/// Bind a TCP listener with `SO_REUSEADDR` + `SO_REUSEPORT` so a second process
350/// can bind the same address concurrently (the kernel then load-balances new
351/// connections across both). This is what lets a new binary take over new
352/// connections while the old one drains — used for both the client and admin
353/// listeners so a binary handoff can re-bind every address (Batch H).
354pub(crate) fn bind_reuseport(addr: &str) -> Result<TcpListener> {
355    use socket2::{Domain, Protocol, Socket, Type};
356    let sockaddr: SocketAddr = addr
357        .parse()
358        .map_err(|e| ProxyError::Config(format!("invalid listen address '{}': {}", addr, e)))?;
359    let domain = if sockaddr.is_ipv6() { Domain::IPV6 } else { Domain::IPV4 };
360    let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))
361        .map_err(|e| ProxyError::Network(format!("socket(): {}", e)))?;
362    socket
363        .set_reuse_address(true)
364        .map_err(|e| ProxyError::Network(format!("SO_REUSEADDR: {}", e)))?;
365    #[cfg(all(unix, not(target_os = "solaris")))]
366    socket
367        .set_reuse_port(true)
368        .map_err(|e| ProxyError::Network(format!("SO_REUSEPORT: {}", e)))?;
369    socket
370        .set_nonblocking(true)
371        .map_err(|e| ProxyError::Network(format!("set_nonblocking: {}", e)))?;
372    socket
373        .bind(&sockaddr.into())
374        .map_err(|e| ProxyError::Network(format!("Failed to bind {}: {}", addr, e)))?;
375    socket
376        .listen(1024)
377        .map_err(|e| ProxyError::Network(format!("listen(): {}", e)))?;
378    let std_listener: std::net::TcpListener = socket.into();
379    TcpListener::from_std(std_listener)
380        .map_err(|e| ProxyError::Network(format!("from_std listener: {}", e)))
381}
382
/// Disposition produced by the pre-query plugin hook stage.
///
/// Tells the caller what to do with a client query message after the
/// pre-query hooks have run.
///
/// When the `wasm-plugins` feature is off, only `Forward` is ever produced —
/// the hook dispatch is compiled out entirely and the variant list exists
/// purely for pattern-match symmetry.
#[derive(Debug)]
#[allow(dead_code)] // Block/Cached only constructed under wasm-plugins
enum PreQueryAction {
    /// Send the message to the backend as usual.
    Forward,
    /// A plugin blocked the query. The caller sends an error + ReadyForQuery
    /// to the client and skips backend forwarding. The payload is the
    /// message carried with the block.
    Block(String),
    /// A plugin returned a cached response. Not yet wired — response
    /// synthesis from raw bytes requires building a full protocol reply
    /// (RowDescription + DataRow(s) + CommandComplete + ReadyForQuery),
    /// which is the next step of T0-a. For now the caller falls back to
    /// `Forward` and logs a warning.
    Cached(Vec<u8>),
}
403
/// Override produced by the Route plugin hook. Consumed by `route_and_forward`
/// when deciding which backend to talk to.
///
/// As with `PreQueryAction`, only `None` is ever produced when the
/// `wasm-plugins` feature is off.
///
/// Note that the `None` variant here is `RouteOverride::None`, not
/// `Option::None`.
#[derive(Debug)]
#[allow(dead_code)] // Primary/Standby/Node/Block only constructed under wasm-plugins
enum RouteOverride {
    /// No override — use the default SQL-verb-based routing.
    None,
    /// Force the write path (use `select_primary_with_timeout`).
    Primary,
    /// Force the read path (use `select_read_node`).
    Standby,
    /// Use this exact node address. Takes precedence over the is_write
    /// heuristic; the proxy will still verify the node is healthy before
    /// connecting (via the normal switch-vs-reuse flow).
    Node(String),
    /// Reject the query: write a PG ErrorResponse + ReadyForQuery to
    /// the client and skip the forward. Carries the reason the plugin
    /// supplied. Takes precedence over every other field — the proxy
    /// short-circuits before any backend selection.
    Block(String),
}
428
429impl ProxyServer {
430    /// Build a `PluginManager` from config and preload plugins from disk.
431    ///
432    /// Returns `None` when plugins are disabled in config, when the
433    /// runtime fails to initialise, or when the plugin directory is
434    /// missing. Individual per-file load failures are logged but do not
435    /// abort startup — the remaining plugins load normally and the
436    /// proxy stays up.
437    #[cfg(feature = "wasm-plugins")]
438    fn init_plugin_manager(
439        toml_cfg: &crate::config::PluginToml,
440    ) -> Option<Arc<crate::plugins::PluginManager>> {
441        if !toml_cfg.enabled {
442            return None;
443        }
444
445        let runtime_cfg = crate::plugins::PluginRuntimeConfig::from(toml_cfg);
446        let plugin_dir = runtime_cfg.plugin_dir.clone();
447
448        let pm = match crate::plugins::PluginManager::new(runtime_cfg) {
449            Ok(pm) => Arc::new(pm),
450            Err(e) => {
451                tracing::error!(error = %e, "Failed to create plugin manager; plugins disabled");
452                return None;
453            }
454        };
455
456        match std::fs::read_dir(&plugin_dir) {
457            Ok(entries) => {
458                let mut loaded = 0usize;
459                let mut failed = 0usize;
460                for entry in entries.flatten() {
461                    let path = entry.path();
462                    if path.extension().and_then(|s| s.to_str()) != Some("wasm") {
463                        continue;
464                    }
465                    match pm.load_plugin(&path) {
466                        Ok(()) => loaded += 1,
467                        Err(e) => {
468                            failed += 1;
469                            tracing::warn!(
470                                path = %path.display(),
471                                error = %e,
472                                "Failed to load plugin"
473                            );
474                        }
475                    }
476                }
477                tracing::info!(
478                    dir = %plugin_dir.display(),
479                    loaded = loaded,
480                    failed = failed,
481                    "Plugin loading complete"
482                );
483            }
484            Err(e) => {
485                tracing::warn!(
486                    dir = %plugin_dir.display(),
487                    error = %e,
488                    "Plugin directory not readable; no plugins loaded"
489                );
490            }
491        }
492
493        Some(pm)
494    }
495
496    /// Create a new proxy server
497    pub fn new(config: ProxyConfig) -> Result<Self> {
498        let (shutdown_tx, _) = broadcast::channel(1);
499
500        // Initialize health status
501        let mut health = HashMap::new();
502        for node in &config.nodes {
503            health.insert(
504                node.address(),
505                NodeHealth {
506                    address: node.address(),
507                    healthy: true, // Assume healthy until proven otherwise
508                    last_check: chrono::Utc::now(),
509                    failure_count: 0,
510                    last_error: None,
511                    latency_ms: 0.0,
512                    replication_lag_bytes: None,
513                },
514            );
515        }
516
517        // Initialize pool manager if pool-modes feature is enabled
518        #[cfg(feature = "pool-modes")]
519        let pool_manager = {
520            use crate::pool::PreparedStatementMode as PoolPreparedStatementMode;
521
522            let pool_config = PoolModeConfig {
523                default_mode: match config.pool_mode.mode {
524                    crate::config::PoolingMode::Session => PoolingMode::Session,
525                    crate::config::PoolingMode::Transaction => PoolingMode::Transaction,
526                    crate::config::PoolingMode::Statement => PoolingMode::Statement,
527                },
528                max_pool_size: config.pool_mode.max_pool_size,
529                min_idle: config.pool_mode.min_idle,
530                idle_timeout_secs: config.pool_mode.idle_timeout_secs,
531                max_lifetime_secs: config.pool_mode.max_lifetime_secs,
532                acquire_timeout_secs: config.pool_mode.acquire_timeout_secs,
533                reset_query: config.pool_mode.reset_query.clone(),
534                prepared_statement_mode: match config.pool_mode.prepared_statement_mode {
535                    crate::config::PreparedStatementMode::Disable => {
536                        PoolPreparedStatementMode::Disable
537                    }
538                    crate::config::PreparedStatementMode::Track => {
539                        PoolPreparedStatementMode::Track
540                    }
541                    crate::config::PreparedStatementMode::Named => {
542                        PoolPreparedStatementMode::Named
543                    }
544                },
545                test_on_acquire: config.pool.test_on_acquire,
546                validation_query: "SELECT 1".to_string(),
547                queue_timeout_secs: 30,
548                max_queue_size: 0,
549            };
550            Some(Arc::new(ConnectionPoolManager::new(pool_config)))
551        };
552
553        // Initialize plugin manager if the wasm-plugins feature is enabled
554        // AND plugins are turned on in config. Scans plugin_dir for `.wasm`
555        // files and loads each; a missing directory is non-fatal and logs
556        // a warning so empty deployments don't fail startup.
557        #[cfg(feature = "wasm-plugins")]
558        let plugin_manager = Self::init_plugin_manager(&config.plugins);
559
560        // Build the client TLS acceptor if [tls] is configured + enabled.
561        // A bad cert/key is fatal at startup (fail fast, don't silently
562        // fall back to plaintext for a deployment that asked for TLS).
563        let tls_acceptor = match config.tls.as_ref() {
564            Some(tls) if tls.enabled => match build_tls_acceptor(tls) {
565                Ok(acc) => {
566                    tracing::info!(
567                        mtls = tls.require_client_cert,
568                        "client TLS termination enabled"
569                    );
570                    Some(acc)
571                }
572                Err(e) => {
573                    return Err(ProxyError::Config(format!("TLS init failed: {}", e)));
574                }
575            },
576            _ => None,
577        };
578
579        // Load the SCRAM auth_file when proxy-terminated auth is requested.
580        // Misconfiguration is fatal at startup (fail fast).
581        let auth_file = if config.auth.mode == crate::config::AuthMode::Scram {
582            let path = config.auth.auth_file.as_ref().ok_or_else(|| {
583                ProxyError::Config("auth mode 'scram' requires auth_file".to_string())
584            })?;
585            let af = crate::auth_scram::AuthFile::load(path)
586                .map_err(|e| ProxyError::Config(format!("auth_file: {}", e)))?;
587            tracing::info!(users = %(!af.is_empty()), "proxy SCRAM auth enabled");
588            Some(Arc::new(af))
589        } else {
590            None
591        };
592
593        // Spawn the traffic-mirror worker when enabled (we are inside the
594        // tokio runtime here — main is #[tokio::main]).
595        let mirror = if config.mirror.enabled {
596            tracing::info!(target = %format!("{}:{}", config.mirror.backend_host, config.mirror.backend_port),
597                writes_only = config.mirror.writes_only, "traffic mirroring enabled");
598            Some(crate::mirror::spawn(config.mirror.clone()))
599        } else {
600            None
601        };
602
603        let state = Arc::new(ServerState {
604            sessions: RwLock::new(HashMap::new()),
605            health: ArcSwap::from_pointee(health),
606            live_config: ArcSwap::from_pointee(config.clone()),
607            metrics: ServerMetrics::default(),
608            cancel_map: Arc::new(DashMap::new()),
609            tls_acceptor,
610            auth_file,
611            mirror,
612            cutover: Arc::new(ArcSwap::from_pointee(None)),
613            lb_state: LoadBalancerState {
614                rr_counter: AtomicU64::new(0),
615            },
616            #[cfg(feature = "pool-modes")]
617            pool_manager,
618            #[cfg(feature = "wasm-plugins")]
619            plugin_manager,
620            #[cfg(feature = "ha-tr")]
621            transaction_journal: Arc::new(
622                crate::transaction_journal::TransactionJournal::new(),
623            ),
624            #[cfg(feature = "anomaly-detection")]
625            anomaly_detector: Arc::new(
626                crate::anomaly::AnomalyDetector::new(
627                    crate::anomaly::AnomalyConfig::default(),
628                ),
629            ),
630            #[cfg(feature = "edge-proxy")]
631            edge_cache: Arc::new(crate::edge::EdgeCache::new(10_000)),
632            #[cfg(feature = "edge-proxy")]
633            edge_registry: Arc::new(crate::edge::EdgeRegistry::new(
634                32,
635                std::time::Duration::from_secs(120),
636            )),
637        });
638
639        Ok(Self {
640            config,
641            state,
642            shutdown_tx,
643            config_path: None,
644        })
645    }
646
647    /// Record the config file path so `SIGHUP` can re-read it for a live
648    /// reload (Batch H). Without a path (config built from CLI flags/defaults)
649    /// a `SIGHUP` is logged and ignored — there is nothing to re-read.
650    pub fn with_config_path(mut self, path: Option<String>) -> Self {
651        self.config_path = path;
652        self
653    }
654
655    /// A stream that yields once per `SIGHUP`. On non-Unix platforms it never
656    /// yields (config reload is Unix-signal driven).
657    #[cfg(unix)]
658    fn hangup_stream() -> tokio::signal::unix::Signal {
659        tokio::signal::unix::signal(tokio::signal::unix::SignalKind::hangup())
660            .expect("failed to install SIGHUP handler")
661    }
662    #[cfg(not(unix))]
663    fn hangup_stream() -> HangupNever {
664        HangupNever
665    }
666
667    /// A stream that yields once per `SIGUSR2` — the graceful binary-handoff
668    /// drain trigger. Never yields on non-Unix platforms.
669    #[cfg(unix)]
670    fn usr2_stream() -> tokio::signal::unix::Signal {
671        tokio::signal::unix::signal(tokio::signal::unix::SignalKind::user_defined2())
672            .expect("failed to install SIGUSR2 handler")
673    }
674    #[cfg(not(unix))]
675    fn usr2_stream() -> HangupNever {
676        HangupNever
677    }
678
679    /// Wait for in-flight client connections to finish, up to `timeout`. Used by
680    /// the graceful drain after the listener is closed — the session map is the
681    /// live active-connection gauge (one entry per accepted connection).
682    async fn drain_connections(state: &Arc<ServerState>, timeout: Duration) {
683        let deadline = tokio::time::Instant::now() + timeout;
684        loop {
685            let active = state.sessions.read().await.len();
686            if active == 0 {
687                tracing::info!("drain complete — all in-flight connections finished");
688                return;
689            }
690            if tokio::time::Instant::now() >= deadline {
691                tracing::warn!(
692                    active,
693                    "drain timeout reached — exiting with connections still open"
694                );
695                return;
696            }
697            tokio::time::sleep(Duration::from_millis(200)).await;
698        }
699    }
700
701    /// Graceful-drain timeout: how long to keep serving in-flight connections
702    /// after SIGUSR2 before exiting. Sourced from `shutdown_drain_timeout_secs`
703    /// in the live config, with the `HELIOS_DRAIN_TIMEOUT_SECS` env var as a
704    /// runtime override.
705    fn drain_timeout(config_secs: u64) -> Duration {
706        let secs = std::env::var("HELIOS_DRAIN_TIMEOUT_SECS")
707            .ok()
708            .and_then(|s| s.parse::<u64>().ok())
709            .unwrap_or(config_secs);
710        Duration::from_secs(secs)
711    }
712
713    /// Re-read the config file and hot-swap the live config (Batch H).
714    ///
715    /// New connections immediately use the reloaded config; in-flight sessions
716    /// keep the snapshot they began with, so nothing is dropped. A parse error
717    /// keeps the running config untouched. Socket-bound fields (listen/admin
718    /// address) cannot change on an already-bound listener and are reported but
719    /// not applied. The node set is reconciled into the health map so routing
720    /// sees additions/removals at once.
721    async fn reload_config(&self) {
722        let Some(path) = self.config_path.as_deref() else {
723            tracing::warn!(
724                "SIGHUP received but config was not loaded from a file — nothing to reload"
725            );
726            return;
727        };
728        tracing::info!(path, "SIGHUP: reloading configuration");
729        let new_config = match ProxyConfig::from_file(path) {
730            Ok(c) => c,
731            Err(e) => {
732                tracing::error!(path, error = %e, "SIGHUP reload failed to parse — keeping current config");
733                return;
734            }
735        };
736        let old = self.state.live_config.load_full();
737        if new_config.listen_address != old.listen_address {
738            tracing::warn!(old = %old.listen_address, new = %new_config.listen_address,
739                "listen_address change needs a restart/handoff; the bound socket is kept");
740        }
741        if new_config.admin_address != old.admin_address {
742            tracing::warn!(old = %old.admin_address, new = %new_config.admin_address,
743                "admin_address change needs a restart; the bound socket is kept");
744        }
745        // Reconcile node health to the new node set before publishing the
746        // config, so the first connection on the new config can route to it.
747        Self::reconcile_health(&self.state, &new_config);
748        let nodes = new_config.nodes.len();
749        let hba_rules = new_config.hba.len();
750        let pool_max = new_config.pool.max_connections;
751        self.state.live_config.store(Arc::new(new_config));
752        tracing::info!(
753            nodes,
754            hba_rules,
755            pool_max,
756            "SIGHUP: configuration reloaded — applies to new connections"
757        );
758    }
759
760    /// Rebuild the health map for `config`'s node set: surviving nodes keep
761    /// their current health; new nodes are seeded healthy (immediately
762    /// routable, the next check confirms); removed nodes are dropped.
763    fn reconcile_health(state: &Arc<ServerState>, config: &ProxyConfig) {
764        let current = state.health.load_full();
765        let mut next: HashMap<String, NodeHealth> = HashMap::new();
766        for node in &config.nodes {
767            let addr = node.address();
768            match current.get(&addr) {
769                Some(existing) => {
770                    next.insert(addr, existing.clone());
771                }
772                None => {
773                    tracing::info!(node = %addr, "SIGHUP: new node added — seeding healthy");
774                    next.insert(
775                        addr.clone(),
776                        NodeHealth {
777                            address: addr,
778                            healthy: true,
779                            last_check: chrono::Utc::now(),
780                            failure_count: 0,
781                            last_error: None,
782                            latency_ms: 0.0,
783                            replication_lag_bytes: None,
784                        },
785                    );
786                }
787            }
788        }
789        for gone in current.keys().filter(|k| !next.contains_key(*k)) {
790            tracing::info!(node = %gone, "SIGHUP: node removed from config");
791        }
792        state.health.store(Arc::new(next));
793    }
794
795    /// Run the proxy server
796    pub async fn run(&self) -> Result<()> {
797        // Bind with SO_REUSEPORT so a freshly-started binary can bind the SAME
798        // listen address concurrently — the kernel load-balances new
799        // connections across both processes. That is the mechanism behind the
800        // zero-downtime binary handoff: start the new binary, then SIGUSR2 the
801        // old one to close its listener and drain (Batch H, item 84).
802        let listener = bind_reuseport(&self.config.listen_address)?;
803
804        tracing::info!("Proxy listening on {} (SO_REUSEPORT)", self.config.listen_address);
805
806        // Start background tasks
807        let health_task = self.spawn_health_checker();
808        let pool_task = self.spawn_pool_manager();
809
810        // Start admin server
811        let admin_task = self.spawn_admin_server();
812
813        // Start the MCP agent gateway when enabled.
814        let mcp_task = if self.config.mcp.enabled {
815            let mcp_cfg = self.config.mcp.clone();
816            // Resolve the configured agent contract (scoped grants) by id.
817            let contract = mcp_cfg.contract.as_ref().and_then(|id| {
818                let found = self.config.agent_contracts.iter().find(|c| &c.id == id).cloned();
819                if found.is_none() {
820                    tracing::warn!(%id, "mcp.contract names an unknown agent_contract; gateway runs with only the read-only guardrail");
821                }
822                found
823            });
824            Some(tokio::spawn(async move {
825                if let Err(e) = crate::mcp::McpServer::new(mcp_cfg, contract).run().await {
826                    tracing::error!("MCP gateway error: {}", e);
827                }
828            }))
829        } else {
830            None
831        };
832
833        // Start the HTTP SQL gateway (Neon-serverless compatible) when enabled.
834        let http_gw_task = if self.config.http_gateway.enabled {
835            let gw_cfg = self.config.http_gateway.clone();
836            Some(tokio::spawn(async move {
837                if let Err(e) = crate::http_gateway::HttpGateway::new(gw_cfg).run().await {
838                    tracing::error!("HTTP gateway error: {}", e);
839                }
840            }))
841        } else {
842            None
843        };
844
845        let mut shutdown_rx = self.shutdown_tx.subscribe();
846
847        // SIGHUP -> zero-downtime config reload; SIGUSR2 -> graceful drain for
848        // binary handoff (Batch H). On platforms without Unix signals these are
849        // simply never readable.
850        let mut sighup = Self::hangup_stream();
851        let mut sigusr2 = Self::usr2_stream();
852        let mut graceful = false;
853
854        loop {
855            tokio::select! {
856                _ = sighup.recv() => {
857                    self.reload_config().await;
858                }
859                _ = sigusr2.recv() => {
860                    tracing::info!(
861                        "SIGUSR2: graceful binary-handoff drain — closing the listener so new \
862                         connections route to the sibling process; finishing in-flight connections"
863                    );
864                    graceful = true;
865                    break;
866                }
867                accept_result = listener.accept() => {
868                    match accept_result {
869                        Ok((stream, addr)) => {
870                            // PG wire traffic is small request/response
871                            // frames; Nagle + delayed-ACK costs tens of
872                            // ms per round-trip if left on.
873                            let _ = stream.set_nodelay(true);
874                            self.state.metrics.connections_accepted.fetch_add(1, Ordering::Relaxed);
875                            let state = self.state.clone();
876                            // Snapshot the *live* config so a SIGHUP reload
877                            // applies to new connections; in-flight sessions
878                            // keep the snapshot they began with (Batch H).
879                            let config = (*self.state.live_config.load_full()).clone();
880                            let shutdown_tx = self.shutdown_tx.clone();
881
882                            tokio::spawn(async move {
883                                if let Err(e) = Self::handle_client(stream, addr, state, config, shutdown_tx).await {
884                                    tracing::error!("Client handler error: {}", e);
885                                }
886                            });
887                        }
888                        Err(e) => {
889                            tracing::error!("Accept error: {}", e);
890                        }
891                    }
892                }
893                _ = shutdown_rx.recv() => {
894                    tracing::info!("Shutdown signal received");
895                    break;
896                }
897            }
898        }
899
900        // Close the listening socket so the kernel stops routing new connections
901        // to this process's accept queue (with SO_REUSEPORT they would otherwise
902        // sit unaccepted) — all new connections now go to the sibling listener.
903        drop(listener);
904
905        // On a graceful handoff, keep serving in-flight connections until they
906        // finish (or the drain deadline), so nothing in flight is dropped.
907        if graceful {
908            let timeout =
909                Self::drain_timeout(self.state.live_config.load().shutdown_drain_timeout_secs);
910            tracing::info!(timeout_secs = timeout.as_secs(), "draining in-flight connections");
911            Self::drain_connections(&self.state, timeout).await;
912        }
913
914        // Wait for background tasks
915        health_task.abort();
916        pool_task.abort();
917        admin_task.abort();
918        if let Some(t) = mcp_task {
919            t.abort();
920        }
921        if let Some(t) = http_gw_task {
922            t.abort();
923        }
924
925        Ok(())
926    }
927
928    /// Spawn admin API server
929    fn spawn_admin_server(&self) -> tokio::task::JoinHandle<()> {
930        let config = self.config.clone();
931        let state = self.state.clone();
932        let mut shutdown_rx = self.shutdown_tx.subscribe();
933
934        tokio::spawn(async move {
935            // Create admin state
936            let admin_state = Arc::new(AdminState::new());
937
938            // Initialize config snapshot
939            {
940                let mut snapshot = admin_state.config_snapshot.write().await;
941                *snapshot = ConfigSnapshot {
942                    listen_address: config.listen_address.clone(),
943                    admin_address: config.admin_address.clone(),
944                    tr_enabled: config.tr_enabled,
945                    tr_mode: format!("{:?}", config.tr_mode),
946                    pool_min_connections: config.pool.min_connections,
947                    pool_max_connections: config.pool.max_connections,
948                    nodes: config.nodes.iter().map(|n| NodeSnapshot {
949                        address: n.address(),
950                        role: format!("{:?}", n.role),
951                        weight: n.weight,
952                        enabled: n.enabled,
953                    }).collect(),
954                };
955            }
956
957            // Set proxy config for SQL routing
958            admin_state.set_proxy_config(config.clone()).await;
959
960            // Require a Bearer token on admin requests when configured.
961            admin_state.with_auth_token(config.admin_token.clone()).await;
962
963            // Branch-database provisioning surface.
964            if config.branch.enabled {
965                admin_state.with_branch(config.branch.clone()).await;
966            }
967
968            // Surface traffic-mirror / migration status when mirroring is on.
969            if let Some(ref mirror) = state.mirror {
970                admin_state
971                    .with_migration(crate::admin::MigrationInfo {
972                        target: mirror.target().to_string(),
973                        writes_only: mirror.writes_only(),
974                        metrics: mirror.metrics.clone(),
975                        config: config.mirror.clone(),
976                        cutover: state.cutover.clone(),
977                        cutover_target: crate::mirror::CutoverTarget {
978                            addr: format!("{}:{}", config.mirror.backend_host, config.mirror.backend_port),
979                            user: config.mirror.backend_user.clone(),
980                            password: config.mirror.backend_password.clone(),
981                            database: config.mirror.backend_database.clone(),
982                        },
983                    })
984                    .await;
985            }
986
987            // Attach the plugin manager so /plugins + the admin UI
988            // surface real loaded modules. Cheap Arc-clone — no
989            // duplicate state, both AdminState and ServerState hold
990            // the same manager.
991            #[cfg(feature = "wasm-plugins")]
992            if let Some(ref pm) = state.plugin_manager {
993                admin_state.with_plugin_manager(pm.clone()).await;
994            }
995
996            // Attach the time-travel replay engine. The engine reads
997            // windows from the shared TransactionJournal and replays
998            // statements against a target backend supplied per-request.
999            // Per-call credential overrides land via FU-21's
1000            // ReplayRequestBody.target_user / target_password /
1001            // target_database fields.
1002            #[cfg(feature = "ha-tr")]
1003            {
1004                let template = build_replay_backend_template(&config);
1005                let engine = Arc::new(crate::replay::ReplayEngine::new(
1006                    state.transaction_journal.clone(),
1007                    template,
1008                ));
1009                admin_state.with_replay_engine(engine).await;
1010            }
1011
1012            // Attach the anomaly detector — same Arc the server
1013            // populates from the query path. /api/anomalies polls
1014            // this for surfaced detections.
1015            #[cfg(feature = "anomaly-detection")]
1016            admin_state
1017                .with_anomaly_detector(state.anomaly_detector.clone())
1018                .await;
1019
1020            // Attach the edge cache + registry. Both surfaced via
1021            // /api/edge/* admin routes.
1022            #[cfg(feature = "edge-proxy")]
1023            admin_state
1024                .with_edge(state.edge_cache.clone(), state.edge_registry.clone())
1025                .await;
1026
1027            // Create admin server
1028            let admin_server = AdminServer::new(config.admin_address.clone(), admin_state.clone());
1029
1030            // Spawn state sync task
1031            let admin_state_sync = admin_state.clone();
1032            let server_state = state.clone();
1033            let sync_task = tokio::spawn(async move {
1034                let mut interval = tokio::time::interval(std::time::Duration::from_secs(1));
1035                loop {
1036                    interval.tick().await;
1037
1038                    // Sync health status
1039                    {
1040                        let health = server_state.health.load_full();
1041                        let mut admin_health = admin_state_sync.node_health.write().await;
1042                        *admin_health = (*health).clone();
1043                    }
1044
1045                    // Sync metrics
1046                    {
1047                        let metrics = ServerMetricsSnapshot {
1048                            connections_accepted: server_state.metrics.connections_accepted.load(Ordering::Relaxed),
1049                            connections_closed: server_state.metrics.connections_closed.load(Ordering::Relaxed),
1050                            queries_processed: server_state.metrics.queries_processed.load(Ordering::Relaxed),
1051                            bytes_received: server_state.metrics.bytes_received.load(Ordering::Relaxed),
1052                            bytes_sent: server_state.metrics.bytes_sent.load(Ordering::Relaxed),
1053                            failovers: server_state.metrics.failovers.load(Ordering::Relaxed),
1054                        };
1055                        let mut admin_metrics = admin_state_sync.metrics.write().await;
1056                        *admin_metrics = metrics;
1057                    }
1058
1059                    // Sync session count
1060                    {
1061                        let sessions = server_state.sessions.read().await;
1062                        let mut admin_sessions = admin_state_sync.active_sessions.write().await;
1063                        *admin_sessions = sessions.len() as u64;
1064                    }
1065                }
1066            });
1067
1068            // Run admin server
1069            tokio::select! {
1070                result = admin_server.run() => {
1071                    if let Err(e) = result {
1072                        tracing::error!("Admin server error: {}", e);
1073                    }
1074                }
1075                _ = shutdown_rx.recv() => {
1076                    tracing::info!("Admin server shutting down");
1077                }
1078            }
1079
1080            sync_task.abort();
1081        })
1082    }
1083
1084    /// Handle a client connection
1085    async fn handle_client(
1086        stream: TcpStream,
1087        addr: SocketAddr,
1088        state: Arc<ServerState>,
1089        config: ProxyConfig,
1090        _shutdown_tx: broadcast::Sender<()>,
1091    ) -> Result<()> {
1092        tracing::debug!("New client connection from {}", addr);
1093
1094        // Create session
1095        let session = Arc::new(ClientSession {
1096            id: Uuid::new_v4(),
1097            client_addr: addr,
1098            current_node: RwLock::new(None),
1099            tx_state: RwLock::new(TransactionState::default()),
1100            variables: RwLock::new(HashMap::new()),
1101            created_at: chrono::Utc::now(),
1102            tr_mode: config.tr_mode,
1103            #[cfg(feature = "pool-modes")]
1104            pool_client_id: ClientId::new(),
1105            #[cfg(feature = "wasm-plugins")]
1106            plugin_identity: RwLock::new(None),
1107        });
1108
1109        // Register session
1110        {
1111            let mut sessions = state.sessions.write().await;
1112            sessions.insert(session.id, session.clone());
1113        }
1114
1115        // Negotiate client TLS (if the client sent SSLRequest). Produces a
1116        // ClientStream that is plaintext or TLS-wrapped; the rest of the
1117        // session is written against that single stream type. `pre` carries
1118        // a first startup/cancel message already read while peeking.
1119        let result = match Self::negotiate_client_tls(stream, &state).await {
1120            Ok((mut client_stream, pre)) => {
1121                Self::client_loop(&mut client_stream, pre, &session, &state, &config).await
1122            }
1123            Err(e) => Err(e),
1124        };
1125
1126        // Cleanup session
1127        {
1128            let mut sessions = state.sessions.write().await;
1129            sessions.remove(&session.id);
1130        }
1131
1132        // Release any active pool lease if pool-modes is enabled
1133        #[cfg(feature = "pool-modes")]
1134        if let Some(ref pool_manager) = state.pool_manager {
1135            // Check if there's an active lease for this client and release it
1136            if pool_manager.has_active_lease(&session.pool_client_id) {
1137                tracing::debug!(
1138                    "Releasing pool lease for disconnecting client {:?}",
1139                    session.pool_client_id
1140                );
1141                // Note: The lease is released implicitly when the connection closes
1142                // The pool manager will clean up any orphaned leases
1143            }
1144        }
1145
1146        state
1147            .metrics
1148            .connections_closed
1149            .fetch_add(1, Ordering::Relaxed);
1150
1151        result
1152    }
1153
1154    /// Main client processing loop with full PostgreSQL protocol handling
1155    async fn client_loop(
1156        stream: &mut ClientStream,
1157        pre: Option<StartupMessage>,
1158        session: &Arc<ClientSession>,
1159        state: &Arc<ServerState>,
1160        config: &ProxyConfig,
1161    ) -> Result<()> {
1162        let codec = ProtocolCodec::new();
1163        let mut buffer = BytesMut::with_capacity(8192);
1164
1165        // Handle startup phase. The session keeps a per-node cache of
1166        // authenticated backend connections (`conns`) instead of a single
1167        // stream: when read/write routing moves a session between primary
1168        // and standby it now reuses the already-authenticated connection to
1169        // each node rather than dropping the socket and paying a fresh TCP
1170        // connect + startup + SCRAM handshake on every switch (Batch C).
1171        // Connections are authenticated with the client's own credentials
1172        // (auth is pass-through), so they are private to this session —
1173        // cross-client transaction pooling additionally needs proxy-side
1174        // backend auth and is deferred to the auth batch.
1175        let mut conns: HashMap<String, BackendConn> = HashMap::new();
1176        let mut current_node: Option<String> =
1177            match Self::handle_startup(stream, &mut buffer, &codec, pre, session, state, config).await {
1178                Ok((Some(stream_conn), node_addr)) => {
1179                    conns.insert(node_addr.clone(), BackendConn::new(stream_conn));
1180                    Some(node_addr)
1181                }
1182                Ok((None, _)) => {
1183                    // SSL rejected or cancel request, connection should close
1184                    return Ok(());
1185                }
1186                Err(e) => {
1187                    tracing::error!("Startup failed: {}", e);
1188                    // Send error to client
1189                    let err_msg = Self::create_error_response("08006", &format!("Startup failed: {}", e));
1190                    let _ = stream.write_all(&err_msg).await;
1191                    return Err(e);
1192                }
1193            };
1194
1195        // Main query loop.
1196        //
1197        // Two wire shapes are handled. Simple-query (`Query`) messages are
1198        // self-contained: route, forward, and stream the response back
1199        // frame-by-frame until ReadyForQuery. Extended-protocol messages
1200        // (`Parse`/`Bind`/`Describe`/`Execute`/`Close`) carry no response of
1201        // their own until the client sends `Sync` (or `Flush`), so they are
1202        // accumulated into `pending` and forwarded as one batch at that
1203        // boundary — this is what stops the per-message 30s backend-read
1204        // timeout that made every prepared-statement driver unusable. The
1205        // routing decision for an extended batch is taken from the SQL in its
1206        // first `Parse`; a batch with no `Parse` (a re-`Bind`/`Execute` of a
1207        // named prepared statement) stays on the connection the statement was
1208        // prepared on.
1209        let mut read_buf = vec![0u8; 16384];
1210        let mut pending = BytesMut::new();
1211        let mut pending_route_sql: Option<String> = None;
1212        // Prepared-statement tracking (Batch F.4). `stmt_registry` is the
1213        // session's canonical record of every *named* `Parse` the client has
1214        // issued (name -> full Parse message bytes) so the proxy can re-prepare
1215        // a statement on any backend connection that is missing it. `batch_*`
1216        // accumulate, for the in-flight extended batch, which named statements
1217        // it defines (Parse), references (Bind/Describe-S), and closes
1218        // (Close-S) — resolved at the Sync/Flush boundary.
1219        let mut stmt_registry: HashMap<String, bytes::Bytes> = HashMap::new();
1220        let mut batch_defines: Vec<String> = Vec::new();
1221        let mut batch_refs: Vec<String> = Vec::new();
1222        let mut batch_closes: Vec<String> = Vec::new();
1223        // Unnamed-`Parse` promotion (Batch H). `held_unnamed` parks an unnamed
1224        // Parse that is the FIRST message of a batch (so the batch stays the
1225        // clean Parse→Bind→…→Sync shape) — it is NOT appended to `pending`; the
1226        // decision to forward or skip it is taken at the batch boundary once the
1227        // target connection is known. Holds (full Parse message, signature).
1228        let promote_unnamed = config.optimize_unnamed_parse;
1229        let mut held_unnamed: Option<(bytes::Bytes, bytes::Bytes)> = None;
1230        loop {
1231            // Read from client
1232            let n = stream
1233                .read(&mut read_buf)
1234                .await
1235                .map_err(|e| ProxyError::Network(format!("Read error: {}", e)))?;
1236
1237            if n == 0 {
1238                // Client disconnected
1239                break;
1240            }
1241
1242            buffer.extend_from_slice(&read_buf[..n]);
1243            state.metrics.bytes_received.fetch_add(n as u64, Ordering::Relaxed);
1244
1245            // Process all complete messages in buffer
1246            while let Some(msg) = codec.decode_message(&mut buffer)? {
1247                match msg.msg_type {
1248                    MessageType::Terminate => return Ok(()),
1249
1250                    // ---- Simple query protocol ----
1251                    MessageType::Query => {
1252                        // Anomaly detector — record every Query message before
1253                        // the plugin hook so a detection lands in the audit
1254                        // trail even if a plugin later blocks.
1255                        #[cfg(feature = "anomaly-detection")]
1256                        Self::record_anomaly_observation(&msg, state, session);
1257
1258                        // Plugin pre-query hook — may rewrite the SQL, block,
1259                        // or return a cached response.
1260                        let (msg, action) = Self::apply_pre_query_hook(msg, state, session);
1261
1262                        if let PreQueryAction::Block(reason) = &action {
1263                            tracing::info!(reason = %reason, "pre-query plugin blocked query");
1264                            Self::send_block_response(stream, reason, state).await?;
1265                            state.metrics.queries_processed.fetch_add(1, Ordering::Relaxed);
1266                            continue;
1267                        }
1268
1269                        #[cfg(feature = "wasm-plugins")]
1270                        if let PreQueryAction::Cached(bytes) = &action {
1271                            match Self::synthesise_cached_response(bytes) {
1272                                Ok(reply) => {
1273                                    stream.write_all(&reply).await.map_err(|e| {
1274                                        ProxyError::Network(format!("Write error: {}", e))
1275                                    })?;
1276                                    state.metrics.bytes_sent.fetch_add(reply.len() as u64, Ordering::Relaxed);
1277                                    state.metrics.queries_processed.fetch_add(1, Ordering::Relaxed);
1278                                    continue;
1279                                }
1280                                Err(e) => {
1281                                    tracing::warn!(error = %e, "failed to synthesise cached response; falling back to backend");
1282                                }
1283                            }
1284                        }
1285
1286                        // Traffic mirror: offer the (final, post-rewrite)
1287                        // statement to the secondary backend. Non-blocking —
1288                        // never delays the client path.
1289                        if let Some(ref mirror) = state.mirror {
1290                            if let Some(sql) = crate::protocol::query_text(&msg.payload) {
1291                                mirror.offer(sql, Self::is_write_query(sql));
1292                            }
1293                        }
1294
1295                        #[cfg(feature = "wasm-plugins")]
1296                        let forward_start = std::time::Instant::now();
1297                        let fr = Self::forward_simple_query(
1298                            stream,
1299                            &msg,
1300                            &mut conns,
1301                            current_node.as_deref(),
1302                            session,
1303                            state,
1304                            config,
1305                        )
1306                        .await;
1307                        #[cfg(feature = "wasm-plugins")]
1308                        Self::fire_post_query_hook(&msg, session, state, &fr, forward_start.elapsed());
1309                        let (used_node, sent) = fr?;
1310                        if let Some(n) = used_node {
1311                            current_node = Some(n);
1312                        }
1313                        state.metrics.bytes_sent.fetch_add(sent, Ordering::Relaxed);
1314                        state.metrics.queries_processed.fetch_add(1, Ordering::Relaxed);
1315                    }
1316
1317                    // ---- Extended query protocol: accumulate until Sync/Flush ----
1318                    MessageType::Parse
1319                    | MessageType::Bind
1320                    | MessageType::Describe
1321                    | MessageType::Execute
1322                    | MessageType::Close => {
1323                        // Whether this message is appended to `pending`. An
1324                        // unnamed Parse held aside for promotion is the lone
1325                        // exception (resolved at the batch boundary).
1326                        let mut add_to_pending = true;
1327                        match msg.msg_type {
1328                            MessageType::Parse => {
1329                                // Register named statements so they can be
1330                                // re-prepared on a different backend later, and
1331                                // borrow the query (2nd cstring) for routing.
1332                                let name = Self::parse_stmt_name(&msg.payload);
1333                                let unnamed = name.is_empty();
1334                                if !unnamed {
1335                                    let name = name.to_string();
1336                                    stmt_registry.insert(name.clone(), msg.encode().freeze());
1337                                    batch_defines.push(name);
1338                                }
1339                                if pending_route_sql.is_none() {
1340                                    if let Some(end) = msg.payload.iter().position(|&b| b == 0) {
1341                                        if let Some(q) =
1342                                            crate::protocol::query_text(&msg.payload[end + 1..])
1343                                        {
1344                                            if !q.is_empty() {
1345                                                pending_route_sql = Some(q.to_string());
1346                                                #[cfg(feature = "anomaly-detection")]
1347                                                Self::record_anomaly_sql(q, state, session);
1348                                            }
1349                                        }
1350                                    }
1351                                }
1352                                // Promotion: park an unnamed Parse that opens a
1353                                // fresh batch. Its signature is the payload after
1354                                // the empty statement-name NUL (query + param
1355                                // types). Anything that breaks the clean shape
1356                                // (a second Parse, a non-empty `pending`) un-parks
1357                                // it back into `pending` to preserve wire order.
1358                                if promote_unnamed
1359                                    && unnamed
1360                                    && pending.is_empty()
1361                                    && held_unnamed.is_none()
1362                                {
1363                                    let sig = bytes::Bytes::copy_from_slice(&msg.payload[1..]);
1364                                    held_unnamed = Some((msg.encode().freeze(), sig));
1365                                    add_to_pending = false;
1366                                } else if let Some((held_msg, _)) = held_unnamed.take() {
1367                                    let mut combined = BytesMut::with_capacity(held_msg.len() + pending.len());
1368                                    combined.extend_from_slice(&held_msg);
1369                                    combined.extend_from_slice(&pending);
1370                                    pending = combined;
1371                                }
1372                            }
1373                            MessageType::Bind => {
1374                                if let Some(name) = Self::bind_stmt_ref(&msg.payload) {
1375                                    batch_refs.push(name.to_string());
1376                                }
1377                            }
1378                            MessageType::Describe => {
1379                                if let Some(name) = Self::stmt_kind_name(&msg.payload) {
1380                                    batch_refs.push(name.to_string());
1381                                }
1382                            }
1383                            MessageType::Close => {
1384                                if let Some(name) = Self::stmt_kind_name(&msg.payload) {
1385                                    batch_closes.push(name.to_string());
1386                                }
1387                            }
1388                            _ => {}
1389                        }
1390                        if add_to_pending {
1391                            pending.extend_from_slice(&msg.encode());
1392                        }
1393                    }
1394
1395                    // ---- Extended batch boundary ----
1396                    MessageType::Sync | MessageType::Flush => {
1397                        let wait_ready = msg.msg_type == MessageType::Sync;
1398                        pending.extend_from_slice(&msg.encode());
1399                        let batch = pending.split().freeze();
1400                        // Re-prepare any named statement this batch references
1401                        // but does not itself define, in case the target
1402                        // connection (after a switch/redial) is missing it.
1403                        let reprepare: Vec<String> = batch_refs
1404                            .iter()
1405                            .filter(|r| !batch_defines.contains(r))
1406                            .cloned()
1407                            .collect();
1408                        let (used_node, sent) = Self::forward_extended_batch(
1409                            stream,
1410                            &batch,
1411                            pending_route_sql.as_deref(),
1412                            wait_ready,
1413                            &mut conns,
1414                            current_node.as_deref(),
1415                            &stmt_registry,
1416                            &reprepare,
1417                            &batch_defines,
1418                            held_unnamed.take(),
1419                            session,
1420                            state,
1421                            config,
1422                        )
1423                        .await?;
1424                        if let Some(n) = used_node {
1425                            current_node = Some(n);
1426                        }
1427                        state.metrics.bytes_sent.fetch_add(sent, Ordering::Relaxed);
1428                        // Closed statements are deallocated everywhere — forget
1429                        // their canonical Parse so they are never re-prepared.
1430                        for name in batch_closes.drain(..) {
1431                            stmt_registry.remove(&name);
1432                        }
1433                        if wait_ready {
1434                            // Sync ends the extended cycle; reset routing so the
1435                            // next Parse can re-route. Flush leaves it intact so
1436                            // the rest of the in-flight sequence stays put.
1437                            pending_route_sql = None;
1438                            batch_defines.clear();
1439                            batch_refs.clear();
1440                            state.metrics.queries_processed.fetch_add(1, Ordering::Relaxed);
1441                        }
1442                    }
1443
1444                    // ---- COPY sub-protocol (client -> backend) ----
1445                    MessageType::CopyData | MessageType::CopyDone | MessageType::CopyFail => {
1446                        if let Some(node) = current_node.clone() {
1447                            if let Some(b) = conns.get_mut(&node) {
1448                                b.stream.write_all(&msg.encode()).await.map_err(|e| {
1449                                    ProxyError::Network(format!("Backend copy write error: {}", e))
1450                                })?;
1451                                if matches!(msg.msg_type, MessageType::CopyDone | MessageType::CopyFail) {
1452                                    let r = Self::stream_until_ready(stream, &mut b.stream, session, state).await;
1453                                    match r {
1454                                        Ok(sent) => {
1455                                            state.metrics.bytes_sent.fetch_add(sent, Ordering::Relaxed);
1456                                        }
1457                                        Err(e) => {
1458                                            conns.remove(&node);
1459                                            return Err(e);
1460                                        }
1461                                    }
1462                                }
1463                            }
1464                        }
1465                    }
1466
1467                    // ---- Anything else: forward to current backend best-effort ----
1468                    _ => {
1469                        if let Some(ref node) = current_node {
1470                            if let Some(b) = conns.get_mut(node) {
1471                                let _ = b.stream.write_all(&msg.encode()).await;
1472                            }
1473                        }
1474                    }
1475                }
1476            }
1477        }
1478
1479        Ok(())
1480    }
1481
1482    /// Peek the first startup-phase message and negotiate client TLS.
1483    ///
1484    /// On `SSLRequest` the proxy answers `S` and runs a rustls server
1485    /// handshake when a TLS acceptor is configured, otherwise `N`
1486    /// (plaintext). A `Startup`/`CancelRequest` arriving first (no
1487    /// SSLRequest) is returned in `pre` so the caller doesn't re-read it.
1488    async fn negotiate_client_tls(
1489        mut tcp: TcpStream,
1490        state: &Arc<ServerState>,
1491    ) -> Result<(ClientStream, Option<StartupMessage>)> {
1492        let codec = ProtocolCodec::new();
1493        let mut buffer = BytesMut::with_capacity(1024);
1494        let mut read_buf = vec![0u8; 1024];
1495
1496        let first = loop {
1497            if let Some(msg) = codec.decode_startup(&mut buffer)? {
1498                break msg;
1499            }
1500            let n = tcp
1501                .read(&mut read_buf)
1502                .await
1503                .map_err(|e| ProxyError::Network(format!("Startup read error: {}", e)))?;
1504            if n == 0 {
1505                return Err(ProxyError::Connection("client closed before startup".to_string()));
1506            }
1507            buffer.extend_from_slice(&read_buf[..n]);
1508        };
1509
1510        match first {
1511            StartupMessage::SSLRequest => match state.tls_acceptor.as_ref() {
1512                Some(acceptor) => {
1513                    tcp.write_all(&[b'S'])
1514                        .await
1515                        .map_err(|e| ProxyError::Network(format!("SSL accept write: {}", e)))?;
1516                    let tls = acceptor
1517                        .accept(tcp)
1518                        .await
1519                        .map_err(|e| ProxyError::Network(format!("TLS handshake failed: {}", e)))?;
1520                    if tls.get_ref().1.peer_certificates().is_some() {
1521                        tracing::debug!("client presented a certificate (mTLS)");
1522                    }
1523                    Ok((ClientStream::Tls(Box::new(tls)), None))
1524                }
1525                None => {
1526                    tcp.write_all(&[b'N'])
1527                        .await
1528                        .map_err(|e| ProxyError::Network(format!("SSL reject write: {}", e)))?;
1529                    Ok((ClientStream::Plain(tcp), None))
1530                }
1531            },
1532            other => Ok((ClientStream::Plain(tcp), Some(other))),
1533        }
1534    }
1535
1536    /// Handle PostgreSQL startup phase (authentication). TLS/SSLRequest is
1537    /// already handled upstream in `negotiate_client_tls`; `pre` carries the
1538    /// first startup/cancel message when it was read during negotiation.
1539    async fn handle_startup(
1540        client_stream: &mut ClientStream,
1541        buffer: &mut BytesMut,
1542        codec: &ProtocolCodec,
1543        pre: Option<StartupMessage>,
1544        session: &Arc<ClientSession>,
1545        state: &Arc<ServerState>,
1546        config: &ProxyConfig,
1547    ) -> Result<(Option<TcpStream>, String)> {
1548        // Use the message already read during TLS negotiation, or read one
1549        // now (the TLS case, where the real startup follows the handshake).
1550        let startup_msg = match pre {
1551            Some(msg) => Some(msg),
1552            None => {
1553                let mut read_buf = vec![0u8; 1024];
1554                loop {
1555                    if let Some(msg) = codec.decode_startup(buffer)? {
1556                        break Some(msg);
1557                    }
1558                    let n = client_stream
1559                        .read(&mut read_buf)
1560                        .await
1561                        .map_err(|e| ProxyError::Network(format!("Startup read error: {}", e)))?;
1562                    if n == 0 {
1563                        return Ok((None, String::new()));
1564                    }
1565                    buffer.extend_from_slice(&read_buf[..n]);
1566                }
1567            }
1568        };
1569
1570        match startup_msg {
1571            Some(StartupMessage::SSLRequest) => {
1572                // SSL is negotiated upstream; a second SSLRequest here is a
1573                // protocol error — reject defensively.
1574                client_stream
1575                    .write_all(&[b'N'])
1576                    .await
1577                    .map_err(|e| ProxyError::Network(format!("SSL reject error: {}", e)))?;
1578                Err(ProxyError::Protocol("unexpected SSLRequest after startup".to_string()))
1579            }
1580            Some(StartupMessage::CancelRequest { pid, key }) => {
1581                // Forward the cancel to the backend that owns this key, then
1582                // close (the client opened this connection only to cancel).
1583                Self::forward_cancel_request(state, pid, key).await;
1584                Ok((None, String::new()))
1585            }
1586            Some(StartupMessage::Startup { params, .. }) => {
1587                Self::connect_and_authenticate(client_stream, &params, session, state, config).await
1588            }
1589            None => Err(ProxyError::Protocol("Incomplete startup message".to_string())),
1590        }
1591    }
1592
1593    /// Evaluate pg_hba-style admission rules in order. The first rule whose
1594    /// user, database, and address all match decides; if none match, admit.
1595    fn hba_admits(rules: &[HbaRule], ip: std::net::IpAddr, user: &str, database: &str) -> bool {
1596        for r in rules {
1597            let user_ok = r.user == "all" || r.user == user;
1598            let db_ok = r.database == "all" || r.database == database;
1599            if user_ok && db_ok && Self::hba_addr_matches(&r.address, ip) {
1600                return r.action == HbaAction::Allow;
1601            }
1602        }
1603        true
1604    }
1605
1606    /// Match a client address against an hba `address` spec: "all", a bare
1607    /// IP, or a CIDR (`10.0.0.0/8`, `::1/128`).
1608    fn hba_addr_matches(spec: &str, ip: std::net::IpAddr) -> bool {
1609        use std::net::IpAddr;
1610        if spec == "all" {
1611            return true;
1612        }
1613        if let Some((net, bits)) = spec.split_once('/') {
1614            let bits: u32 = match bits.parse() {
1615                Ok(b) => b,
1616                Err(_) => return false,
1617            };
1618            match (net.parse::<IpAddr>(), ip) {
1619                (Ok(IpAddr::V4(n)), IpAddr::V4(i)) if bits <= 32 => {
1620                    let mask = if bits == 0 { 0 } else { u32::MAX << (32 - bits) };
1621                    (u32::from(n) & mask) == (u32::from(i) & mask)
1622                }
1623                (Ok(IpAddr::V6(n)), IpAddr::V6(i)) if bits <= 128 => {
1624                    let mask = if bits == 0 { 0 } else { u128::MAX << (128 - bits) };
1625                    (u128::from(n) & mask) == (u128::from(i) & mask)
1626                }
1627                _ => false,
1628            }
1629        } else {
1630            spec.parse::<IpAddr>().map(|s| s == ip).unwrap_or(false)
1631        }
1632    }
1633
1634    /// Run a proxy-terminated SCRAM-SHA-256 server exchange against the
1635    /// client, validating its password with the configured `auth_file`. On
1636    /// success the client is authenticated by the proxy (no AuthenticationOk
1637    /// is sent here — the backend's is forwarded later). On any failure
1638    /// returns Err; the caller emits an ErrorResponse and closes.
1639    async fn proxy_scram_auth(
1640        client: &mut ClientStream,
1641        user: &str,
1642        state: &Arc<ServerState>,
1643    ) -> std::result::Result<(), String> {
1644        use crate::auth_scram::ScramServer;
1645        let auth_file = state.auth_file.as_ref().ok_or("scram not configured")?;
1646
1647        // 1. AuthenticationSASL: advertise SCRAM-SHA-256.
1648        let mut sasl = BytesMut::new();
1649        sasl.put_i32(10); // SASL
1650        sasl.extend_from_slice(b"SCRAM-SHA-256\0");
1651        sasl.put_u8(0); // end of mechanism list
1652        Self::write_auth_frame(client, &sasl).await?;
1653
1654        // 2. Read SASLInitialResponse ('p'): mechanism cstring + i32 len + data.
1655        let init = Self::read_password_message(client).await?;
1656        let mech_end = init
1657            .iter()
1658            .position(|&b| b == 0)
1659            .ok_or("malformed SASLInitialResponse (no mechanism)")?;
1660        if init.len() < mech_end + 5 {
1661            return Err("short SASLInitialResponse".into());
1662        }
1663        let client_first = std::str::from_utf8(&init[mech_end + 5..])
1664            .map_err(|_| "client-first not UTF-8")?;
1665
1666        // 3. Look up the verifier (unknown user -> generic failure).
1667        let verifier = auth_file
1668            .get(user)
1669            .ok_or("no such user")?
1670            .clone();
1671
1672        // 4. server-first.
1673        let server_nonce = Self::random_nonce();
1674        let (server, server_first) = ScramServer::start(verifier, client_first, &server_nonce)?;
1675
1676        // 5. AuthenticationSASLContinue.
1677        let mut cont = BytesMut::new();
1678        cont.put_i32(11);
1679        cont.extend_from_slice(server_first.as_bytes());
1680        Self::write_auth_frame(client, &cont).await?;
1681
1682        // 6. Read SASLResponse ('p'): payload = client-final.
1683        let client_final_raw = Self::read_password_message(client).await?;
1684        let client_final = std::str::from_utf8(&client_final_raw)
1685            .map_err(|_| "client-final not UTF-8")?;
1686
1687        // 7. Verify -> server-final.
1688        let server_final = server.finish(client_final)?;
1689
1690        // 8. AuthenticationSASLFinal (no AuthenticationOk — backend's follows).
1691        let mut fin = BytesMut::new();
1692        fin.put_i32(12);
1693        fin.extend_from_slice(server_final.as_bytes());
1694        Self::write_auth_frame(client, &fin).await?;
1695        Ok(())
1696    }
1697
1698    /// Write an AuthenticationRequest ('R') frame with the given payload.
1699    async fn write_auth_frame(
1700        client: &mut ClientStream,
1701        payload: &[u8],
1702    ) -> std::result::Result<(), String> {
1703        let mut frame = BytesMut::with_capacity(payload.len() + 5);
1704        frame.put_u8(b'R');
1705        frame.put_u32((payload.len() + 4) as u32);
1706        frame.extend_from_slice(payload);
1707        client
1708            .write_all(&frame)
1709            .await
1710            .map_err(|e| format!("client write: {}", e))
1711    }
1712
1713    /// Read one Password/SASL ('p') message from the client, returning its
1714    /// payload. Errors on EOF or any non-'p' frame.
1715    async fn read_password_message(
1716        client: &mut ClientStream,
1717    ) -> std::result::Result<BytesMut, String> {
1718        let codec = ProtocolCodec::new();
1719        let mut buffer = BytesMut::with_capacity(1024);
1720        let mut read_buf = vec![0u8; 1024];
1721        loop {
1722            if let Some(msg) = codec
1723                .decode_message(&mut buffer)
1724                .map_err(|e| format!("decode: {}", e))?
1725            {
1726                if msg.msg_type == MessageType::Password {
1727                    return Ok(msg.payload);
1728                }
1729                return Err(format!("expected SASL response, got {:?}", msg.msg_type));
1730            }
1731            let n = client
1732                .read(&mut read_buf)
1733                .await
1734                .map_err(|e| format!("client read: {}", e))?;
1735            if n == 0 {
1736                return Err("client closed during SASL".into());
1737            }
1738            buffer.extend_from_slice(&read_buf[..n]);
1739        }
1740    }
1741
1742    /// A fresh random SCRAM server nonce (printable, no comma).
1743    fn random_nonce() -> String {
1744        use rand::Rng;
1745        const CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
1746        let mut rng = rand::thread_rng();
1747        (0..24).map(|_| CHARS[rng.gen_range(0..CHARS.len())] as char).collect()
1748    }
1749
1750    /// Connect to backend and handle authentication
1751    async fn connect_and_authenticate(
1752        client_stream: &mut ClientStream,
1753        params: &HashMap<String, String>,
1754        session: &Arc<ClientSession>,
1755        state: &Arc<ServerState>,
1756        config: &ProxyConfig,
1757    ) -> Result<(Option<TcpStream>, String)> {
1758        // pg_hba-style admission: reject disallowed (user, database, client
1759        // address) combinations before opening any backend connection.
1760        let user = params.get("user").map(String::as_str).unwrap_or("");
1761        let database = params.get("database").map(String::as_str).unwrap_or(user);
1762        if !Self::hba_admits(&config.hba, session.client_addr.ip(), user, database) {
1763            tracing::info!(%user, %database, client = %session.client_addr, "connection rejected by hba rule");
1764            let err = Self::create_error_response(
1765                "28000",
1766                "connection rejected by proxy admission rules",
1767            );
1768            let _ = client_stream.write_all(&err).await;
1769            return Ok((None, String::new()));
1770        }
1771
1772        // Proxy-terminated SCRAM-SHA-256: when an auth_file is configured the
1773        // proxy authenticates the client itself (becoming the auth boundary)
1774        // instead of relaying credentials to the backend. On success it falls
1775        // through to the normal backend connect, whose AuthenticationOk +
1776        // session messages are forwarded to the already-authenticated client.
1777        if state.auth_file.is_some() {
1778            if let Err(e) = Self::proxy_scram_auth(client_stream, user, state).await {
1779                tracing::info!(%user, error = %e, "proxy SCRAM auth failed");
1780                let err = Self::create_error_response("28P01", &format!("authentication failed: {}", e));
1781                let _ = client_stream.write_all(&err).await;
1782                return Ok((None, String::new()));
1783            }
1784            tracing::debug!(%user, "client authenticated by proxy SCRAM");
1785        }
1786
1787        // Plugin Authenticate hook — may deny the connection outright or
1788        // attach a richer identity (roles, tenant_id, claims) onto the
1789        // session for downstream plugins to consume. Happens before any
1790        // backend connection is opened so denials cost nothing on the
1791        // backend side.
1792        Self::apply_authenticate_hook(params, session, state).await?;
1793
1794        // Migration cutover: when active, redirect this connection to the
1795        // promoted target, substituting the target's credentials/database for
1796        // the client's so the cutover is transparent to the application.
1797        let cutover = state.cutover.load_full();
1798        let (node_addr, effective_params) = if let Some(t) = cutover.as_ref() {
1799            let mut p = params.clone();
1800            p.insert("user".to_string(), t.user.clone());
1801            if let Some(ref db) = t.database {
1802                p.insert("database".to_string(), db.clone());
1803            } else {
1804                p.remove("database");
1805            }
1806            tracing::debug!(target = %t.addr, "routing connection to cutover target");
1807            (t.addr.clone(), p)
1808        } else {
1809            (Self::select_node(session, state, config).await?, params.clone())
1810        };
1811
1812        // Connect to backend
1813        let mut backend = tokio::time::timeout(
1814            config.pool.acquire_timeout(),
1815            TcpStream::connect(&node_addr),
1816        )
1817        .await
1818        .map_err(|_| ProxyError::Connection(format!("Connection timeout to {}", node_addr)))?
1819        .map_err(|e| ProxyError::Connection(format!("Failed to connect to {}: {}", node_addr, e)))?;
1820        let _ = backend.set_nodelay(true);
1821
1822        // Build and send startup message to backend
1823        let params = &effective_params;
1824        let startup_bytes = Self::build_startup_message(params);
1825        backend
1826            .write_all(&startup_bytes)
1827            .await
1828            .map_err(|e| ProxyError::Network(format!("Backend startup write error: {}", e)))?;
1829
1830        // Forward authentication messages between client and backend.
1831        // Registers the backend's BackendKeyData so a later CancelRequest
1832        // can be routed back to this node.
1833        Self::proxy_authentication(client_stream, &mut backend, state, &node_addr).await?;
1834
1835        // Store session variables
1836        {
1837            let mut vars = session.variables.write().await;
1838            for (k, v) in params {
1839                vars.insert(k.clone(), v.clone());
1840            }
1841        }
1842
1843        Ok((Some(backend), node_addr))
1844    }
1845
1846    /// Build PostgreSQL startup message
1847    fn build_startup_message(params: &HashMap<String, String>) -> Vec<u8> {
1848        let mut payload = BytesMut::new();
1849
1850        // Protocol version 3.0
1851        payload.put_u32(196608);
1852
1853        // Parameters
1854        for (key, value) in params {
1855            payload.extend_from_slice(key.as_bytes());
1856            payload.put_u8(0);
1857            payload.extend_from_slice(value.as_bytes());
1858            payload.put_u8(0);
1859        }
1860        payload.put_u8(0); // Terminator
1861
1862        // Build complete message with length prefix
1863        let mut msg = BytesMut::new();
1864        msg.put_u32((payload.len() + 4) as u32);
1865        msg.extend_from_slice(&payload);
1866
1867        msg.to_vec()
1868    }
1869
1870    /// Cap on the cancel-key map; cleared on overflow (a dropped stale
1871    /// entry only means one best-effort cancel is not forwarded).
1872    const MAX_CANCEL_KEYS: usize = 100_000;
1873
1874    /// Record the backend that owns a BackendKeyData (pid, secret) pair.
1875    fn register_cancel_key(state: &Arc<ServerState>, pid: u32, key: u32, node_addr: &str) {
1876        if state.cancel_map.len() >= Self::MAX_CANCEL_KEYS {
1877            state.cancel_map.clear();
1878        }
1879        state.cancel_map.insert((pid, key), node_addr.to_string());
1880    }
1881
1882    /// Forward a client CancelRequest to the backend that issued the
1883    /// matching BackendKeyData. Best-effort: unknown keys are ignored.
1884    async fn forward_cancel_request(state: &Arc<ServerState>, pid: u32, key: u32) {
1885        let Some(addr) = state.cancel_map.get(&(pid, key)).map(|e| e.clone()) else {
1886            tracing::debug!(pid, "cancel request for unknown key; ignoring");
1887            return;
1888        };
1889        // CancelRequest: int32 len(16) + int32 code(80877102) + pid + key.
1890        let mut msg = BytesMut::with_capacity(16);
1891        msg.put_u32(16);
1892        msg.put_u32(80877102);
1893        msg.put_u32(pid);
1894        msg.put_u32(key);
1895        match tokio::time::timeout(Duration::from_secs(5), TcpStream::connect(&addr)).await {
1896            Ok(Ok(mut conn)) => {
1897                let _ = conn.set_nodelay(true);
1898                if let Err(e) = conn.write_all(&msg).await {
1899                    tracing::warn!(node = %addr, error = %e, "failed to forward CancelRequest");
1900                }
1901                // PG closes the connection after handling a CancelRequest.
1902            }
1903            other => tracing::warn!(node = %addr, ?other, "could not connect to forward CancelRequest"),
1904        }
1905    }
1906
1907    /// Proxy authentication messages between client and backend
1908    async fn proxy_authentication(
1909        client_stream: &mut ClientStream,
1910        backend_stream: &mut TcpStream,
1911        state: &Arc<ServerState>,
1912        node_addr: &str,
1913    ) -> Result<()> {
1914        let codec = ProtocolCodec::new();
1915        let mut backend_buffer = BytesMut::with_capacity(4096);
1916        let mut client_buffer = BytesMut::with_capacity(4096);
1917        let mut read_buf = vec![0u8; 4096];
1918
1919        loop {
1920            // Read from backend
1921            let n = backend_stream
1922                .read(&mut read_buf)
1923                .await
1924                .map_err(|e| ProxyError::Network(format!("Backend auth read error: {}", e)))?;
1925
1926            if n == 0 {
1927                return Err(ProxyError::Connection("Backend closed during auth".to_string()));
1928            }
1929
1930            backend_buffer.extend_from_slice(&read_buf[..n]);
1931
1932            // Forward all data to client
1933            client_stream
1934                .write_all(&read_buf[..n])
1935                .await
1936                .map_err(|e| ProxyError::Network(format!("Client auth write error: {}", e)))?;
1937
1938            // Check for authentication complete or error. Bytes were
1939            // already forwarded above, so frames are consumed (decoded
1940            // once) straight out of the buffer — no clone needed.
1941            while let Some(msg) = codec.decode_message(&mut backend_buffer)? {
1942                match msg.msg_type {
1943                    MessageType::BackendKeyData => {
1944                        // The backend told the client how to cancel its
1945                        // queries; remember which backend owns that key so
1946                        // an out-of-band CancelRequest can be forwarded.
1947                        if msg.payload.len() >= 8 {
1948                            let pid = u32::from_be_bytes([
1949                                msg.payload[0], msg.payload[1], msg.payload[2], msg.payload[3],
1950                            ]);
1951                            let key = u32::from_be_bytes([
1952                                msg.payload[4], msg.payload[5], msg.payload[6], msg.payload[7],
1953                            ]);
1954                            Self::register_cancel_key(state, pid, key, node_addr);
1955                        }
1956                    }
1957                    MessageType::AuthRequest => {
1958                        // Check if auth OK
1959                        if msg.payload.len() >= 4 {
1960                            let auth_type =
1961                                i32::from_be_bytes([msg.payload[0], msg.payload[1], msg.payload[2], msg.payload[3]]);
1962                            if auth_type == 0 {
1963                                // AuthenticationOk - continue to read ReadyForQuery
1964                            }
1965                        }
1966                    }
1967                    MessageType::ReadyForQuery => {
1968                        // Authentication complete
1969                        return Ok(());
1970                    }
1971                    MessageType::ErrorResponse => {
1972                        // Authentication failed - error already sent to client
1973                        return Err(ProxyError::Auth("Authentication failed".to_string()));
1974                    }
1975                    _ => {
1976                        // Continue forwarding
1977                    }
1978                }
1979            }
1980
1981            // If backend requires password, forward client's response
1982            // Read password from client if needed
1983            let n = tokio::time::timeout(Duration::from_millis(100), client_stream.read(&mut read_buf))
1984                .await;
1985
1986            if let Ok(Ok(n)) = n {
1987                if n > 0 {
1988                    client_buffer.extend_from_slice(&read_buf[..n]);
1989                    backend_stream
1990                        .write_all(&read_buf[..n])
1991                        .await
1992                        .map_err(|e| ProxyError::Network(format!("Backend password write error: {}", e)))?;
1993                }
1994            }
1995        }
1996    }
1997
1998    /// Decide which node a request should be routed to, without doing any
1999    /// I/O. Reuses `current_node` when it is healthy and role-compatible
2000    /// (sticky session), otherwise selects a fresh primary/read node. The
2001    /// returned address is the key into the per-session connection cache.
2002    async fn choose_target_node(
2003        is_write: bool,
2004        forced_target: Option<String>,
2005        current_node: Option<&str>,
2006        session: &Arc<ClientSession>,
2007        state: &Arc<ServerState>,
2008        config: &ProxyConfig,
2009    ) -> Result<String> {
2010        // After a migration cutover, every request stays on the promoted
2011        // target — never route back to the former primary.
2012        if let Some(t) = state.cutover.load_full().as_ref() {
2013            return Ok(t.addr.clone());
2014        }
2015        let need_switch = if let Some(ref forced) = forced_target {
2016            let health = state.health.load_full();
2017            let reuse = current_node
2018                .map(|c| c == forced && health.get(c).map(|h| h.healthy).unwrap_or(false))
2019                .unwrap_or(false);
2020            !reuse
2021        } else if let Some(current) = current_node {
2022            let health = state.health.load_full();
2023            let current_healthy = health.get(current).map(|h| h.healthy).unwrap_or(false);
2024            if !current_healthy {
2025                true
2026            } else if is_write {
2027                let is_primary = config
2028                    .nodes
2029                    .iter()
2030                    .find(|n| n.address() == *current)
2031                    .map(|n| n.role == NodeRole::Primary)
2032                    .unwrap_or(false);
2033                !is_primary
2034            } else {
2035                false
2036            }
2037        } else {
2038            true
2039        };
2040
2041        if let Some(forced) = forced_target {
2042            Ok(forced)
2043        } else if need_switch {
2044            if is_write {
2045                Self::select_primary_with_timeout(session, state, config).await
2046            } else {
2047                Self::select_read_node(session, state, config).await
2048            }
2049        } else {
2050            Ok(current_node.unwrap().to_string())
2051        }
2052    }
2053
2054    /// Ensure the per-session cache holds an authenticated backend connection
2055    /// to `target`, dialing + silently re-authenticating one (with the
2056    /// client's pass-through credentials) only if absent. The cached
2057    /// connection is then reused across read/write route switches.
2058    async fn ensure_conn(
2059        conns: &mut HashMap<String, BackendConn>,
2060        target: &str,
2061        session: &Arc<ClientSession>,
2062        config: &ProxyConfig,
2063    ) -> Result<()> {
2064        if conns.contains_key(target) {
2065            return Ok(());
2066        }
2067        let mut backend = tokio::time::timeout(
2068            config.pool.acquire_timeout(),
2069            TcpStream::connect(target),
2070        )
2071        .await
2072        .map_err(|_| ProxyError::Connection(format!("Connection timeout to {}", target)))?
2073        .map_err(|e| ProxyError::Connection(format!("Failed to connect to {}: {}", target, e)))?;
2074        let _ = backend.set_nodelay(true);
2075
2076        let params = session.variables.read().await.clone();
2077        let startup = Self::build_startup_message(&params);
2078        backend
2079            .write_all(&startup)
2080            .await
2081            .map_err(|e| ProxyError::Network(format!("Backend startup error: {}", e)))?;
2082        Self::complete_backend_auth(&mut backend).await?;
2083        tracing::debug!(node = %target, "opened backend connection");
2084        conns.insert(target.to_string(), BackendConn::new(backend));
2085        Ok(())
2086    }
2087
2088    /// Forward a simple-query (`Query`) message and stream its response back
2089    /// to the client frame-by-frame, ending at ReadyForQuery. Picks (and, if
2090    /// needed, opens) the target node's connection from the per-session
2091    /// cache. Returns `(Some(node_used), bytes)` — `None` node means the
2092    /// request was short-circuited (plugin block) without touching a backend.
2093    async fn forward_simple_query(
2094        client: &mut ClientStream,
2095        msg: &Message,
2096        conns: &mut HashMap<String, BackendConn>,
2097        current_node: Option<&str>,
2098        session: &Arc<ClientSession>,
2099        state: &Arc<ServerState>,
2100        config: &ProxyConfig,
2101    ) -> Result<(Option<String>, u64)> {
2102        let default_is_write = Self::is_write_message(msg);
2103        let route_override = Self::apply_route_hook(msg, state, session);
2104
2105        // Block short-circuits before any backend selection.
2106        if let RouteOverride::Block(reason) = route_override {
2107            let mut response = Vec::with_capacity(64 + reason.len());
2108            response.extend_from_slice(&Self::create_error_response(
2109                "42000",
2110                &format!("Query blocked by route plugin: {}", reason),
2111            ));
2112            response.extend_from_slice(&Self::create_ready_for_query(b'I'));
2113            client
2114                .write_all(&response)
2115                .await
2116                .map_err(|e| ProxyError::Network(format!("Client write error: {}", e)))?;
2117            return Ok((None, response.len() as u64));
2118        }
2119
2120        let (is_write, forced_target) = match route_override {
2121            RouteOverride::None => (default_is_write, None),
2122            RouteOverride::Primary => (true, None),
2123            RouteOverride::Standby => (false, None),
2124            RouteOverride::Node(name) => (default_is_write, Some(name)),
2125            RouteOverride::Block(_) => unreachable!("handled above"),
2126        };
2127
2128        let target =
2129            Self::choose_target_node(is_write, forced_target, current_node, session, state, config)
2130                .await?;
2131        Self::ensure_conn(conns, &target, session, config).await?;
2132        let backend = conns.get_mut(&target).expect("just ensured");
2133
2134        backend
2135            .stream
2136            .write_all(&msg.encode())
2137            .await
2138            .map_err(|e| ProxyError::Network(format!("Backend write error: {}", e)))?;
2139
2140        match Self::stream_until_ready(client, &mut backend.stream, session, state).await {
2141            Ok(sent) => Ok((Some(target), sent)),
2142            Err(e) => {
2143                // Drop the broken connection so the next use redials.
2144                conns.remove(&target);
2145                Err(e)
2146            }
2147        }
2148    }
2149
2150    /// Forward an accumulated extended-protocol batch (Parse/Bind/Describe/
2151    /// Execute/Close terminated by Sync or Flush) and stream the response.
2152    /// Routing is taken from `route_sql` (the first Parse's SQL); when it is
2153    /// `None` (a re-Bind/Execute of a named prepared statement) the request
2154    /// stays on the connection the statement was prepared on — no switch.
2155    ///
2156    /// `reprepare` lists named statements this batch references but does not
2157    /// itself define; any that the chosen connection has not seen are
2158    /// re-prepared from `registry` (their original `Parse`) before the batch is
2159    /// sent, so a named statement survives a backend switch/redial (Batch F.4).
2160    /// `defines` are the named statements this batch's own `Parse`s create —
2161    /// recorded against the connection once it accepts the batch.
2162    #[allow(clippy::too_many_arguments)]
2163    async fn forward_extended_batch(
2164        client: &mut ClientStream,
2165        batch: &[u8],
2166        route_sql: Option<&str>,
2167        wait_ready: bool,
2168        conns: &mut HashMap<String, BackendConn>,
2169        current_node: Option<&str>,
2170        registry: &HashMap<String, bytes::Bytes>,
2171        reprepare: &[String],
2172        defines: &[String],
2173        unnamed: Option<(bytes::Bytes, bytes::Bytes)>,
2174        session: &Arc<ClientSession>,
2175        state: &Arc<ServerState>,
2176        config: &ProxyConfig,
2177    ) -> Result<(Option<String>, u64)> {
2178        let target = match route_sql {
2179            Some(sql) => {
2180                let is_write = Self::is_write_query(sql);
2181                Self::choose_target_node(is_write, None, current_node, session, state, config)
2182                    .await?
2183            }
2184            // No Parse in this batch: stay on the prepared-statement /
2185            // portal connection. Fall back to a read node only if the
2186            // session has no current connection yet.
2187            None => match current_node {
2188                Some(c) => c.to_string(),
2189                None => Self::select_read_node(session, state, config).await?,
2190            },
2191        };
2192
2193        Self::ensure_conn(conns, &target, session, config).await?;
2194        let backend = conns.get_mut(&target).expect("just ensured");
2195
2196        // Transparently re-prepare any referenced named statement this socket
2197        // is missing. Each is sent as its original `Parse` + `Flush`; the
2198        // resulting `ParseComplete` is consumed here so the client never sees
2199        // the extra round trip. A re-prepare failure recycles the connection.
2200        for name in reprepare {
2201            if backend.prepared.contains(name) {
2202                continue;
2203            }
2204            let Some(parse_bytes) = registry.get(name) else {
2205                continue; // unknown statement — let the batch surface the error
2206            };
2207            match Self::reprepare_statement(&mut backend.stream, parse_bytes).await {
2208                Ok(()) => {
2209                    backend.prepared.insert(name.clone());
2210                }
2211                Err(e) => {
2212                    conns.remove(&target);
2213                    return Err(e);
2214                }
2215            }
2216        }
2217
2218        // Unnamed-`Parse` promotion: if the held unnamed Parse matches what this
2219        // connection's unnamed statement already holds, skip forwarding it and
2220        // synthesize its `ParseComplete` to the client; otherwise forward it
2221        // first (re-establishing the connection's unnamed statement) and record
2222        // its signature. A fresh/redialed connection has no signature, so the
2223        // Parse is always (re)forwarded there — correctness is preserved.
2224        let mut inject_parse_complete = false;
2225        let mut new_unnamed_sig: Option<bytes::Bytes> = None;
2226        if let Some((parse_msg, sig)) = unnamed.as_ref() {
2227            if backend.unnamed_sig.as_deref() == Some(&sig[..]) {
2228                inject_parse_complete = true;
2229            } else {
2230                if let Err(e) = backend
2231                    .stream
2232                    .write_all(parse_msg)
2233                    .await
2234                    .map_err(|e| ProxyError::Network(format!("Backend write error: {}", e)))
2235                {
2236                    conns.remove(&target);
2237                    return Err(e);
2238                }
2239                new_unnamed_sig = Some(sig.clone());
2240            }
2241        }
2242
2243        if let Err(e) = backend
2244            .stream
2245            .write_all(batch)
2246            .await
2247            .map_err(|e| ProxyError::Network(format!("Backend write error: {}", e)))
2248        {
2249            conns.remove(&target);
2250            return Err(e);
2251        }
2252
2253        // The client expects `ParseComplete` first; the backend won't send one
2254        // for a skipped Parse, so emit it here before relaying the response.
2255        let mut injected: u64 = 0;
2256        if inject_parse_complete {
2257            if let Err(e) = client
2258                .write_all(&[b'1', 0, 0, 0, 4])
2259                .await
2260                .map_err(|e| ProxyError::Network(format!("Client write error: {}", e)))
2261            {
2262                conns.remove(&target);
2263                return Err(e);
2264            }
2265            injected = 5;
2266        }
2267
2268        let r = if wait_ready {
2269            Self::stream_until_ready(client, &mut backend.stream, session, state).await
2270        } else {
2271            Self::stream_flush(client, &mut backend.stream, session, state).await
2272        };
2273        match r {
2274            Ok(sent) => {
2275                // The connection now holds these named statements.
2276                for name in defines {
2277                    backend.prepared.insert(name.clone());
2278                }
2279                // ...and the (re)forwarded unnamed statement.
2280                if let Some(sig) = new_unnamed_sig {
2281                    backend.unnamed_sig = Some(sig);
2282                }
2283                Ok((Some(target), sent + injected))
2284            }
2285            Err(e) => {
2286                conns.remove(&target);
2287                Err(e)
2288            }
2289        }
2290    }
2291
2292    /// Re-issue one named `Parse` on a backend socket out-of-band: send the
2293    /// original `Parse` bytes followed by a `Flush`, then read and discard the
2294    /// single `ParseComplete` the backend emits. The statement persists on the
2295    /// connection (the implicit transaction is closed later by the real
2296    /// batch's `Sync`). An `ErrorResponse` means the re-prepare failed.
2297    async fn reprepare_statement<S: AsyncReadExt + AsyncWriteExt + Unpin>(
2298        backend: &mut S,
2299        parse_bytes: &[u8],
2300    ) -> Result<()> {
2301        backend
2302            .write_all(parse_bytes)
2303            .await
2304            .map_err(|e| ProxyError::Network(format!("re-prepare write error: {}", e)))?;
2305        // Flush: 'H' + length 4.
2306        backend
2307            .write_all(&[b'H', 0, 0, 0, 4])
2308            .await
2309            .map_err(|e| ProxyError::Network(format!("re-prepare flush error: {}", e)))?;
2310        let mtype = Self::read_one_frame_type(backend).await?;
2311        match mtype {
2312            b'1' => Ok(()), // ParseComplete
2313            b'E' => Err(ProxyError::Protocol("re-prepare rejected by backend".to_string())),
2314            other => Err(ProxyError::Protocol(format!(
2315                "unexpected re-prepare reply: {}",
2316                other as char
2317            ))),
2318        }
2319    }
2320
2321    /// Read exactly one backend message frame (5-byte header + body) and return
2322    /// its type byte, discarding the body. Used to consume the `ParseComplete`
2323    /// produced by an out-of-band re-prepare.
2324    async fn read_one_frame_type<S: AsyncReadExt + Unpin>(backend: &mut S) -> Result<u8> {
2325        let mut header = [0u8; 5];
2326        backend
2327            .read_exact(&mut header)
2328            .await
2329            .map_err(|e| ProxyError::Network(format!("re-prepare read error: {}", e)))?;
2330        let len = u32::from_be_bytes([header[1], header[2], header[3], header[4]]) as usize;
2331        let body_len = len.saturating_sub(4);
2332        if body_len > 0 {
2333            let mut body = vec![0u8; body_len];
2334            backend
2335                .read_exact(&mut body)
2336                .await
2337                .map_err(|e| ProxyError::Network(format!("re-prepare body read error: {}", e)))?;
2338        }
2339        Ok(header[0])
2340    }
2341
2342    /// Name a `Parse` defines: its first cstring. `""` is the unnamed
2343    /// statement, which is per-protocol transient and never tracked.
2344    fn parse_stmt_name(payload: &[u8]) -> &str {
2345        let end = payload.iter().position(|&b| b == 0).unwrap_or(0);
2346        std::str::from_utf8(&payload[..end]).unwrap_or("")
2347    }
2348
2349    /// Prepared-statement name a `Bind` references: the *second* cstring
2350    /// (portal name first, then statement name). `None` for the unnamed
2351    /// statement.
2352    fn bind_stmt_ref(payload: &[u8]) -> Option<&str> {
2353        let portal_end = payload.iter().position(|&b| b == 0)?;
2354        let rest = &payload[portal_end + 1..];
2355        let stmt_end = rest.iter().position(|&b| b == 0)?;
2356        let name = std::str::from_utf8(&rest[..stmt_end]).ok()?;
2357        (!name.is_empty()).then_some(name)
2358    }
2359
2360    /// Statement name a `Describe`/`Close` targets — only when it is
2361    /// statement-kind (`'S'`, not portal `'P'`). `None` otherwise.
2362    fn stmt_kind_name(payload: &[u8]) -> Option<&str> {
2363        if payload.first() != Some(&b'S') {
2364            return None;
2365        }
2366        let rest = &payload[1..];
2367        let end = rest.iter().position(|&b| b == 0)?;
2368        let name = std::str::from_utf8(&rest[..end]).ok()?;
2369        (!name.is_empty()).then_some(name)
2370    }
2371
2372    /// Stream backend response frames to the client until ReadyForQuery (end
2373    /// of a Sync/simple-query response). Forwards bytes verbatim, coalescing
2374    /// all currently-complete frames into one write and keeping only a
2375    /// partial-frame tail buffered, so proxy memory stays O(frame) rather
2376    /// than O(result). Also yields on CopyInResponse/CopyBothResponse so the
2377    /// client can supply COPY data. Updates `tx_state` from the RFQ status.
2378    /// Returns bytes streamed to the client.
2379    async fn stream_until_ready(
2380        client: &mut ClientStream,
2381        backend: &mut TcpStream,
2382        session: &Arc<ClientSession>,
2383        state: &Arc<ServerState>,
2384    ) -> Result<u64> {
2385        let _ = state;
2386        let mut buf = BytesMut::with_capacity(16384);
2387        let mut read_buf = vec![0u8; 16384];
2388        let mut sent: u64 = 0;
2389
2390        loop {
2391            // Walk complete frames in `buf`, stopping at a boundary frame.
2392            let mut consumed = 0usize;
2393            let mut ready_status: Option<u8> = None;
2394            let mut yield_for_copy = false;
2395            loop {
2396                let rem = &buf[consumed..];
2397                if rem.len() < 5 {
2398                    break;
2399                }
2400                let len = u32::from_be_bytes([rem[1], rem[2], rem[3], rem[4]]) as usize;
2401                if len < 4 || rem.len() < len + 1 {
2402                    break; // incomplete or malformed length — need more bytes
2403                }
2404                let frame_total = len + 1;
2405                let mtype = rem[0];
2406                consumed += frame_total;
2407                if mtype == b'Z' {
2408                    // ReadyForQuery: payload is one status byte at rem[5].
2409                    ready_status = Some(if frame_total >= 6 { rem[5] } else { b'I' });
2410                    break;
2411                }
2412                if mtype == b'G' || mtype == b'W' {
2413                    // CopyInResponse / CopyBothResponse: the backend now wants
2414                    // CopyData from the client — forward up to here and yield.
2415                    yield_for_copy = true;
2416                    break;
2417                }
2418            }
2419
2420            if consumed > 0 {
2421                client
2422                    .write_all(&buf[..consumed])
2423                    .await
2424                    .map_err(|e| ProxyError::Network(format!("Client write error: {}", e)))?;
2425                sent += consumed as u64;
2426                let _ = buf.split_to(consumed);
2427            }
2428
2429            if let Some(status) = ready_status {
2430                let st = TransactionStatus::from_byte(status);
2431                let mut tx = session.tx_state.write().await;
2432                tx.in_transaction = st != TransactionStatus::Idle;
2433                return Ok(sent);
2434            }
2435            if yield_for_copy {
2436                return Ok(sent);
2437            }
2438
2439            let n = tokio::time::timeout(Duration::from_secs(30), backend.read(&mut read_buf))
2440                .await
2441                .map_err(|_| ProxyError::Network("Backend read timeout".to_string()))?
2442                .map_err(|e| ProxyError::Network(format!("Backend read error: {}", e)))?;
2443            if n == 0 {
2444                return Err(ProxyError::Connection("Backend closed mid-response".to_string()));
2445            }
2446            buf.extend_from_slice(&read_buf[..n]);
2447        }
2448    }
2449
2450    /// Stream whatever the backend has produced in response to a `Flush`
2451    /// (which, unlike `Sync`, produces no ReadyForQuery). Relays available
2452    /// bytes and returns once the backend goes briefly idle, so the loop can
2453    /// read the client's next frames — deadlock-free. The eventual `Sync`
2454    /// drains the final ReadyForQuery via `stream_until_ready`.
2455    async fn stream_flush(
2456        client: &mut ClientStream,
2457        backend: &mut TcpStream,
2458        session: &Arc<ClientSession>,
2459        state: &Arc<ServerState>,
2460    ) -> Result<u64> {
2461        let _ = (session, state);
2462        let mut read_buf = vec![0u8; 16384];
2463        let mut sent: u64 = 0;
2464        loop {
2465            match tokio::time::timeout(Duration::from_millis(200), backend.read(&mut read_buf)).await
2466            {
2467                Ok(Ok(0)) => return Err(ProxyError::Connection("Backend closed mid-flush".to_string())),
2468                Ok(Ok(n)) => {
2469                    client
2470                        .write_all(&read_buf[..n])
2471                        .await
2472                        .map_err(|e| ProxyError::Network(format!("Client write error: {}", e)))?;
2473                    sent += n as u64;
2474                }
2475                Ok(Err(e)) => return Err(ProxyError::Network(format!("Backend read error: {}", e))),
2476                Err(_) => return Ok(sent), // idle: backend has emitted all flush output
2477            }
2478        }
2479    }
2480
2481    /// Check if a message is a write operation
2482    fn is_write_message(msg: &Message) -> bool {
2483        match msg.msg_type {
2484            MessageType::Query => {
2485                // Borrow the SQL straight out of the payload — the
2486                // message is forwarded verbatim, so no copy is needed
2487                // just to inspect the leading keyword.
2488                crate::protocol::query_text(&msg.payload)
2489                    .map(Self::is_write_query)
2490                    .unwrap_or(false)
2491            }
2492            MessageType::Parse => {
2493                // Parse payload = statement-name cstring + query
2494                // cstring; skip the name and borrow the query.
2495                msg.payload
2496                    .iter()
2497                    .position(|&b| b == 0)
2498                    .and_then(|end| crate::protocol::query_text(&msg.payload[end + 1..]))
2499                    .map(Self::is_write_query)
2500                    .unwrap_or(false)
2501            }
2502            // Execute, Bind, etc. maintain the current connection
2503            _ => false,
2504        }
2505    }
2506
2507    /// Check if SQL query is a write operation
2508    fn is_write_query(sql: &str) -> bool {
2509        use crate::protocol::starts_with_ci;
2510        let trimmed = sql.trim();
2511
2512        // Write operations
2513        if starts_with_ci(trimmed, "INSERT")
2514            || starts_with_ci(trimmed, "UPDATE")
2515            || starts_with_ci(trimmed, "DELETE")
2516            || starts_with_ci(trimmed, "CREATE")
2517            || starts_with_ci(trimmed, "DROP")
2518            || starts_with_ci(trimmed, "ALTER")
2519            || starts_with_ci(trimmed, "TRUNCATE")
2520            || starts_with_ci(trimmed, "GRANT")
2521            || starts_with_ci(trimmed, "REVOKE")
2522            || starts_with_ci(trimmed, "VACUUM")
2523            || starts_with_ci(trimmed, "REINDEX")
2524            || starts_with_ci(trimmed, "CLUSTER")
2525        {
2526            return true;
2527        }
2528
2529        // Transaction control goes to current node
2530        if starts_with_ci(trimmed, "BEGIN")
2531            || starts_with_ci(trimmed, "START")
2532            || starts_with_ci(trimmed, "COMMIT")
2533            || starts_with_ci(trimmed, "ROLLBACK")
2534            || starts_with_ci(trimmed, "SAVEPOINT")
2535            || starts_with_ci(trimmed, "RELEASE")
2536        {
2537            return true;
2538        }
2539
2540        // SET commands go to primary to maintain session state
2541        if starts_with_ci(trimmed, "SET") && !starts_with_ci(trimmed, "SET TRANSACTION READ ONLY") {
2542            return true;
2543        }
2544
2545        false
2546    }
2547
2548    /// Select primary node with write timeout during failover
2549    async fn select_primary_with_timeout(
2550        session: &Arc<ClientSession>,
2551        state: &Arc<ServerState>,
2552        config: &ProxyConfig,
2553    ) -> Result<String> {
2554        let timeout = config.write_timeout();
2555        let start = std::time::Instant::now();
2556        // Poll for the promoted primary fairly tightly so writes resume
2557        // quickly after a failover (was 500ms — a needless recovery floor).
2558        let check_interval = Duration::from_millis(100);
2559
2560        loop {
2561            // Try to find healthy primary
2562            let health = state.health.load_full();
2563            let primary = config
2564                .nodes
2565                .iter()
2566                .find(|n| n.role == NodeRole::Primary && n.enabled);
2567
2568            if let Some(primary_node) = primary {
2569                if let Some(node_health) = health.get(&primary_node.address()) {
2570                    if node_health.healthy {
2571                        // Update session's current node
2572                        let mut current = session.current_node.write().await;
2573                        *current = Some(primary_node.address());
2574                        return Ok(primary_node.address());
2575                    }
2576                }
2577            }
2578            drop(health);
2579
2580            // Check if timeout exceeded
2581            if start.elapsed() >= timeout {
2582                state.metrics.failovers.fetch_add(1, Ordering::Relaxed);
2583                return Err(ProxyError::NoHealthyNodes);
2584            }
2585
2586            tracing::warn!(
2587                "Primary unavailable, waiting for failover... ({:.1}s elapsed, {:.1}s timeout)",
2588                start.elapsed().as_secs_f64(),
2589                timeout.as_secs_f64()
2590            );
2591
2592            // Wait before retry
2593            tokio::time::sleep(check_interval).await;
2594        }
2595    }
2596
2597    /// Select node for read operations with load balancing
2598    async fn select_read_node(
2599        session: &Arc<ClientSession>,
2600        state: &Arc<ServerState>,
2601        config: &ProxyConfig,
2602    ) -> Result<String> {
2603        // If in transaction, stick to current node
2604        {
2605            let tx_state = session.tx_state.read().await;
2606            if tx_state.in_transaction {
2607                if let Some(node) = session.current_node.read().await.clone() {
2608                    return Ok(node);
2609                }
2610            }
2611        }
2612
2613        // Get healthy nodes (prefer standbys for reads)
2614        let health = state.health.load_full();
2615        let healthy_standbys: Vec<&NodeConfig> = config
2616            .nodes
2617            .iter()
2618            .filter(|n| {
2619                n.enabled
2620                    && (n.role == NodeRole::Standby || n.role == NodeRole::ReadReplica)
2621                    && health
2622                        .get(&n.address())
2623                        .map(|h| h.healthy)
2624                        .unwrap_or(false)
2625            })
2626            .collect();
2627
2628        if !healthy_standbys.is_empty() {
2629            // Round-robin across healthy standbys
2630            let ticket = state.lb_state.rr_counter.fetch_add(1, Ordering::Relaxed);
2631            let index = ticket as usize % healthy_standbys.len();
2632            let node_addr = healthy_standbys[index].address();
2633
2634            let mut current = session.current_node.write().await;
2635            *current = Some(node_addr.clone());
2636            return Ok(node_addr);
2637        }
2638
2639        // Fall back to primary if no healthy standbys
2640        Self::select_node(session, state, config).await
2641    }
2642
2643    /// Complete backend authentication by reading until ReadyForQuery
2644    /// This is used when switching backends - we don't forward auth to client
2645    async fn complete_backend_auth(backend: &mut TcpStream) -> Result<()> {
2646        let codec = ProtocolCodec::new();
2647        let mut buffer = BytesMut::with_capacity(4096);
2648        let mut read_buf = vec![0u8; 4096];
2649        let timeout = Duration::from_secs(10);
2650        let start = std::time::Instant::now();
2651
2652        loop {
2653            if start.elapsed() > timeout {
2654                return Err(ProxyError::Auth("Backend authentication timeout".to_string()));
2655            }
2656
2657            let n = tokio::time::timeout(Duration::from_secs(5), backend.read(&mut read_buf))
2658                .await
2659                .map_err(|_| ProxyError::Auth("Read timeout during backend auth".to_string()))?
2660                .map_err(|e| ProxyError::Network(format!("Backend auth read error: {}", e)))?;
2661
2662            if n == 0 {
2663                return Err(ProxyError::Connection("Backend closed during auth".to_string()));
2664            }
2665
2666            buffer.extend_from_slice(&read_buf[..n]);
2667
2668            // Decode (and consume) complete frames directly; returns
2669            // None when more data is needed.
2670            while let Some(msg) = codec.decode_message(&mut buffer)? {
2671                match msg.msg_type {
2672                    MessageType::ReadyForQuery => {
2673                        // Authentication complete
2674                        return Ok(());
2675                    }
2676                    MessageType::ErrorResponse => {
2677                        let err = ErrorResponse::parse(msg.payload)
2678                            .map(|e| e.message().unwrap_or("Unknown error").to_string())
2679                            .unwrap_or_else(|_| "Parse error".to_string());
2680                        return Err(ProxyError::Auth(err));
2681                    }
2682                    _ => {
2683                        // Continue reading (AuthRequest, ParameterStatus, BackendKeyData, etc.)
2684                    }
2685                }
2686            }
2687        }
2688    }
2689
2690    /// Create PostgreSQL error response message
2691    fn create_error_response(code: &str, message: &str) -> Vec<u8> {
2692        let mut fields = HashMap::new();
2693        fields.insert('S', "ERROR".to_string());
2694        fields.insert('V', "ERROR".to_string());
2695        fields.insert('C', code.to_string());
2696        fields.insert('M', message.to_string());
2697
2698        let err = ErrorResponse { fields };
2699        err.encode().encode().to_vec()
2700    }
2701
2702    /// Create a `ReadyForQuery` frame with the given transaction-status byte
2703    /// (`b'I'` = idle, `b'T'` = in transaction, `b'E'` = failed transaction).
2704    fn create_ready_for_query(status: u8) -> Vec<u8> {
2705        let mut payload = BytesMut::with_capacity(1);
2706        payload.put_u8(status);
2707        Message::new(MessageType::ReadyForQuery, payload)
2708            .encode()
2709            .to_vec()
2710    }
2711
2712    /// Synthesise a full PostgreSQL simple-query response from a cached
2713    /// payload produced by a plugin's `PreQueryResult::Cached`.
2714    ///
2715    /// # Payload format
2716    ///
2717    /// The plugin is expected to serialise a JSON document of the form:
2718    ///
2719    /// ```json
2720    /// {
2721    ///   "columns": [
2722    ///     {"name": "id",    "oid": 23},
2723    ///     {"name": "email", "oid": 25}
2724    ///   ],
2725    ///   "rows": [
2726    ///     ["1", "alice@example.com"],
2727    ///     ["2", null]
2728    ///   ]
2729    /// }
2730    /// ```
2731    ///
2732    /// `oid` is the PostgreSQL type OID (`23` = int4, `25` = text,
2733    /// `20` = int8, `16` = bool, `1184` = timestamptz, etc.). Row values
2734    /// are strings in text format; `null` encodes a SQL NULL. The type
2735    /// OID is advisory — pgwire clients accept `25` (text) universally
2736    /// and cast as needed.
2737    ///
2738    /// # Returned bytes
2739    ///
2740    /// One concatenated PostgreSQL wire response:
2741    ///
2742    /// ```text
2743    /// RowDescription (T) + DataRow (D) × N + CommandComplete (C: "SELECT N")
2744    ///                    + ReadyForQuery (Z: idle)
2745    /// ```
2746    ///
2747    /// Returns an error on malformed JSON; the caller falls back to
2748    /// backend forwarding.
2749    #[cfg(feature = "wasm-plugins")]
2750    fn synthesise_cached_response(bytes: &[u8]) -> Result<Vec<u8>> {
2751        use serde::Deserialize;
2752
2753        #[derive(Deserialize)]
2754        struct CachedPayload {
2755            columns: Vec<ColumnDef>,
2756            rows: Vec<Vec<Option<String>>>,
2757        }
2758
2759        #[derive(Deserialize)]
2760        struct ColumnDef {
2761            name: String,
2762            #[serde(default = "default_text_oid")]
2763            oid: u32,
2764        }
2765
2766        fn default_text_oid() -> u32 {
2767            25 // text
2768        }
2769
2770        let payload: CachedPayload = serde_json::from_slice(bytes).map_err(|e| {
2771            ProxyError::Protocol(format!("invalid cached payload JSON: {}", e))
2772        })?;
2773
2774        if payload.columns.is_empty() {
2775            return Err(ProxyError::Protocol(
2776                "cached payload must declare at least one column".to_string(),
2777            ));
2778        }
2779
2780        let mut reply = Vec::new();
2781
2782        // RowDescription (tag 'T')
2783        let mut rd = BytesMut::new();
2784        rd.put_u16(payload.columns.len() as u16);
2785        for col in &payload.columns {
2786            rd.extend_from_slice(col.name.as_bytes());
2787            rd.put_u8(0); // cstring terminator
2788            rd.put_i32(0); // tableOID (unknown)
2789            rd.put_i16(0); // columnNumber (unknown)
2790            rd.put_u32(col.oid);
2791            rd.put_i16(-1); // typeLen (unspecified)
2792            rd.put_i32(-1); // typeMod (unspecified)
2793            rd.put_i16(0); // format code: text
2794        }
2795        reply.extend_from_slice(&Message::new(MessageType::RowDescription, rd).encode());
2796
2797        // DataRow (tag 'D') per row
2798        let column_count = payload.columns.len();
2799        for row in &payload.rows {
2800            if row.len() != column_count {
2801                return Err(ProxyError::Protocol(format!(
2802                    "cached row has {} values but {} columns are declared",
2803                    row.len(),
2804                    column_count
2805                )));
2806            }
2807            let mut dr = BytesMut::new();
2808            dr.put_u16(row.len() as u16);
2809            for value in row {
2810                match value {
2811                    Some(s) => {
2812                        dr.put_i32(s.len() as i32);
2813                        dr.extend_from_slice(s.as_bytes());
2814                    }
2815                    None => {
2816                        dr.put_i32(-1); // NULL sentinel
2817                    }
2818                }
2819            }
2820            reply.extend_from_slice(&Message::new(MessageType::DataRow, dr).encode());
2821        }
2822
2823        // CommandComplete (tag 'C')
2824        let tag = format!("SELECT {}", payload.rows.len());
2825        let mut cc = BytesMut::new();
2826        cc.extend_from_slice(tag.as_bytes());
2827        cc.put_u8(0);
2828        reply.extend_from_slice(&Message::new(MessageType::CommandComplete, cc).encode());
2829
2830        // ReadyForQuery (tag 'Z', status 'I' idle)
2831        reply.extend_from_slice(&Self::create_ready_for_query(b'I'));
2832
2833        Ok(reply)
2834    }
2835
2836    /// Run the pre-query plugin hook on a client message.
2837    ///
2838    /// When the `wasm-plugins` feature is off, or the plugin manager has no
2839    /// loaded plugins, this is a zero-cost passthrough that returns the
2840    /// message untouched with `PreQueryAction::Forward`.
2841    ///
2842    /// Only simple-query (`MessageType::Query`) messages are inspected today.
2843    /// Extended-protocol messages (`Parse`/`Bind`/`Execute`) are passed
2844    /// through unchanged — a future task wires them in.
2845    fn apply_pre_query_hook(
2846        msg: Message,
2847        state: &Arc<ServerState>,
2848        session: &Arc<ClientSession>,
2849    ) -> (Message, PreQueryAction) {
2850        #[cfg(feature = "wasm-plugins")]
2851        {
2852            let pm = match state.plugin_manager.as_ref() {
2853                Some(pm) => pm,
2854                None => return (msg, PreQueryAction::Forward),
2855            };
2856
2857            if msg.msg_type != MessageType::Query {
2858                return (msg, PreQueryAction::Forward);
2859            }
2860
2861            // Zero plugins registered for this hook — skip the payload
2862            // clone, SQL parse, and context construction entirely.
2863            if !pm.has_hook(HookType::PreQuery) {
2864                return (msg, PreQueryAction::Forward);
2865            }
2866
2867            let query_msg = match QueryMessage::parse(msg.payload.clone()) {
2868                Ok(q) => q,
2869                Err(_) => return (msg, PreQueryAction::Forward),
2870            };
2871
2872            let ctx = Self::build_query_context(&query_msg.query, session);
2873
2874            match pm.execute_pre_query(&ctx) {
2875                PreQueryResult::Continue => (msg, PreQueryAction::Forward),
2876                PreQueryResult::Block(reason) => (msg, PreQueryAction::Block(reason)),
2877                PreQueryResult::Rewrite(new_sql) => {
2878                    let rewritten = QueryMessage { query: new_sql }.encode();
2879                    (rewritten, PreQueryAction::Forward)
2880                }
2881                PreQueryResult::Cached(bytes) => (msg, PreQueryAction::Cached(bytes)),
2882            }
2883        }
2884        #[cfg(not(feature = "wasm-plugins"))]
2885        {
2886            let _ = (state, session);
2887            (msg, PreQueryAction::Forward)
2888        }
2889    }
2890
2891    /// Feed the anomaly detector a per-query observation. Cheap —
2892    /// only the SQL-injection scan and the novel-fingerprint check
2893    /// are non-trivial, both well under a microsecond on
2894    /// representative queries. Returns nothing; detections land in
2895    /// the detector's ring buffer and are surfaced via /api/anomalies.
2896    #[cfg(feature = "anomaly-detection")]
2897    fn record_anomaly_observation(
2898        msg: &Message,
2899        state: &Arc<ServerState>,
2900        session: &Arc<ClientSession>,
2901    ) {
2902        if msg.msg_type != MessageType::Query {
2903            return;
2904        }
2905        // Borrow the SQL straight out of the payload — the message is
2906        // forwarded verbatim, so no deep copy of the frame is needed.
2907        if let Some(query) = crate::protocol::query_text(&msg.payload) {
2908            Self::record_anomaly_sql(query, state, session);
2909        }
2910    }
2911
2912    /// Feed one SQL statement to the anomaly detector. Shared by the
2913    /// simple-query path and the extended-protocol `Parse` path so
2914    /// prepared-statement traffic is observed too.
2915    #[cfg(feature = "anomaly-detection")]
2916    fn record_anomaly_sql(query: &str, state: &Arc<ServerState>, session: &Arc<ClientSession>) {
2917        // Tenant identifier is the most-specific known per-session
2918        // attribute the proxy can attribute traffic to. Multi-tenancy
2919        // sets `tenant_id` in `variables`; otherwise we fall back to
2920        // the client address. session.variables is a tokio RwLock but this
2921        // is a sync helper — try_read avoids an await; on contention we
2922        // fall back to the client IP, still a valid per-source identifier.
2923        let tenant = match session.variables.try_read() {
2924            Ok(vars) => vars
2925                .get("tenant_id")
2926                .or_else(|| vars.get("user"))
2927                .cloned()
2928                .unwrap_or_else(|| session.client_addr.ip().to_string()),
2929            Err(_) => session.client_addr.ip().to_string(),
2930        };
2931        let fingerprint = anomaly_fingerprint(query);
2932        let obs = crate::anomaly::QueryObservation {
2933            tenant,
2934            fingerprint,
2935            sql: query.to_string(),
2936            timestamp: std::time::Instant::now(),
2937        };
2938        for ev in state.anomaly_detector.record_query(&obs) {
2939            tracing::warn!(anomaly = ?ev, "anomaly detected");
2940        }
2941    }
2942
2943    /// Send the client a `Block`-outcome response: an error frame plus
2944    /// `ReadyForQuery` so the client's state machine returns to idle and
2945    /// the next query can be accepted.
2946    async fn send_block_response(
2947        stream: &mut ClientStream,
2948        reason: &str,
2949        state: &Arc<ServerState>,
2950    ) -> Result<()> {
2951        let err = Self::create_error_response(
2952            "42000",
2953            &format!("Query blocked by plugin: {}", reason),
2954        );
2955        stream
2956            .write_all(&err)
2957            .await
2958            .map_err(|e| ProxyError::Network(format!("Write error: {}", e)))?;
2959        let rfq = Self::create_ready_for_query(b'I');
2960        stream
2961            .write_all(&rfq)
2962            .await
2963            .map_err(|e| ProxyError::Network(format!("Write error: {}", e)))?;
2964        state
2965            .metrics
2966            .bytes_sent
2967            .fetch_add((err.len() + rfq.len()) as u64, Ordering::Relaxed);
2968        Ok(())
2969    }
2970
2971    /// Build a `QueryContext` for the plugin hook. Populated fields: `query`
2972    /// (verbatim), `is_read_only` (derived from SQL verb), and `hook_context`
2973    /// with the session id as `client_id`. `normalized` and `tables` are
2974    /// left as cheap stand-ins until the analytics normaliser is wired in
2975    /// (T0-d, unified context).
2976    #[cfg(feature = "wasm-plugins")]
2977    fn build_query_context(query: &str, session: &Arc<ClientSession>) -> QueryContext {
2978        let is_read_only = !Self::is_write_query(query);
2979        let mut hook_context = HookContext::default();
2980        hook_context.client_id = Some(session.id.to_string());
2981        QueryContext {
2982            query: query.to_string(),
2983            normalized: query.to_string(),
2984            tables: Vec::new(),
2985            is_read_only,
2986            hook_context,
2987        }
2988    }
2989
2990    /// Run the Authenticate plugin hook at startup. Called from
2991    /// `connect_and_authenticate` before any backend connection.
2992    ///
2993    /// Behaviour by `AuthResult`:
2994    /// * `Defer` — no plugin opinion; proceed with the default
2995    ///   PostgreSQL auth flow unchanged.
2996    /// * `Success(identity)` — store the identity on the session so
2997    ///   downstream plugins (masking, residency) can gate on roles /
2998    ///   tenant_id / claims. PostgreSQL backend auth still runs
2999    ///   normally afterwards (the plugin does not replace PG auth in
3000    ///   this iteration; that's a follow-up).
3001    /// * `Denied(reason)` — surfaces as `ProxyError::Auth`, which the
3002    ///   caller already handles by writing an ErrorResponse to the
3003    ///   client and closing the connection.
3004    ///
3005    /// The `AuthRequest` populated here carries username, database,
3006    /// and client IP from the PostgreSQL startup parameters. Password
3007    /// is deliberately `None` — PG protocol sends the password in
3008    /// response to the backend's challenge, not at startup, so
3009    /// password-aware plugin auth is a separate future task.
3010    async fn apply_authenticate_hook(
3011        _params: &HashMap<String, String>,
3012        _session: &Arc<ClientSession>,
3013        _state: &Arc<ServerState>,
3014    ) -> Result<()> {
3015        #[cfg(feature = "wasm-plugins")]
3016        {
3017            let pm = match _state.plugin_manager.as_ref() {
3018                Some(pm) => pm,
3019                None => return Ok(()),
3020            };
3021
3022            let request = PluginAuthRequest {
3023                headers: HashMap::new(),
3024                username: _params.get("user").cloned(),
3025                password: None,
3026                client_ip: _session.client_addr.ip().to_string(),
3027                database: _params.get("database").cloned(),
3028            };
3029
3030            match pm.execute_authenticate(&request) {
3031                AuthResult::Defer => Ok(()),
3032                AuthResult::Success(identity) => {
3033                    tracing::debug!(
3034                        user = %identity.username,
3035                        roles = ?identity.roles,
3036                        "plugin authenticated user"
3037                    );
3038                    *_session.plugin_identity.write().await = Some(identity);
3039                    Ok(())
3040                }
3041                AuthResult::Denied(reason) => {
3042                    tracing::info!(
3043                        reason = %reason,
3044                        client = %_session.client_addr,
3045                        user = ?_params.get("user"),
3046                        "plugin denied authentication"
3047                    );
3048                    Err(ProxyError::Auth(format!(
3049                        "authentication denied by plugin: {}",
3050                        reason
3051                    )))
3052                }
3053            }
3054        }
3055        #[cfg(not(feature = "wasm-plugins"))]
3056        {
3057            Ok(())
3058        }
3059    }
3060
3061    /// Run the Route plugin hook on a message. Only simple-query messages
3062    /// are inspected; other message types always return `None`.
3063    fn apply_route_hook(
3064        msg: &Message,
3065        state: &Arc<ServerState>,
3066        session: &Arc<ClientSession>,
3067    ) -> RouteOverride {
3068        #[cfg(feature = "wasm-plugins")]
3069        {
3070            let pm = match state.plugin_manager.as_ref() {
3071                Some(pm) => pm,
3072                None => return RouteOverride::None,
3073            };
3074            if msg.msg_type != MessageType::Query {
3075                return RouteOverride::None;
3076            }
3077            // Zero plugins registered for this hook — skip the payload
3078            // clone, SQL parse, and context construction entirely.
3079            if !pm.has_hook(HookType::Route) {
3080                return RouteOverride::None;
3081            }
3082            let query_msg = match QueryMessage::parse(msg.payload.clone()) {
3083                Ok(q) => q,
3084                Err(_) => return RouteOverride::None,
3085            };
3086            let ctx = Self::build_query_context(&query_msg.query, session);
3087            match pm.execute_route(&ctx) {
3088                RouteResult::Default => RouteOverride::None,
3089                RouteResult::Primary => RouteOverride::Primary,
3090                RouteResult::Standby => RouteOverride::Standby,
3091                RouteResult::Node(name) => RouteOverride::Node(name),
3092                RouteResult::Block(reason) => RouteOverride::Block(reason),
3093                RouteResult::Branch(name) => {
3094                    tracing::warn!(
3095                        branch = %name,
3096                        "Route hook returned Branch but branch routing is not yet wired — using default"
3097                    );
3098                    RouteOverride::None
3099                }
3100            }
3101        }
3102        #[cfg(not(feature = "wasm-plugins"))]
3103        {
3104            let _ = (msg, state, session);
3105            RouteOverride::None
3106        }
3107    }
3108
3109    /// Fire post-query hooks after a message has been forwarded (or failed
3110    /// to forward). Best-effort; errors from individual plugins are logged
3111    /// by the plugin manager and never surface here.
3112    #[cfg(feature = "wasm-plugins")]
3113    fn fire_post_query_hook(
3114        msg: &Message,
3115        session: &Arc<ClientSession>,
3116        state: &Arc<ServerState>,
3117        result: &Result<(Option<String>, u64)>,
3118        elapsed: Duration,
3119    ) {
3120        let pm = match state.plugin_manager.as_ref() {
3121            Some(pm) => pm,
3122            None => return,
3123        };
3124        if msg.msg_type != MessageType::Query {
3125            return;
3126        }
3127        // Zero plugins registered for this hook — skip the payload
3128        // clone, SQL parse, and context construction entirely.
3129        if !pm.has_hook(HookType::PostQuery) {
3130            return;
3131        }
3132        let query_msg = match QueryMessage::parse(msg.payload.clone()) {
3133            Ok(q) => q,
3134            Err(_) => return,
3135        };
3136        let ctx = Self::build_query_context(&query_msg.query, session);
3137        let outcome = match result {
3138            Ok((node, bytes)) => PostQueryOutcome {
3139                success: true,
3140                target_node: node.clone(),
3141                elapsed_us: elapsed.as_micros() as u64,
3142                response_bytes: *bytes,
3143                error: None,
3144            },
3145            Err(e) => PostQueryOutcome {
3146                success: false,
3147                target_node: None,
3148                elapsed_us: elapsed.as_micros() as u64,
3149                response_bytes: 0,
3150                error: Some(e.to_string()),
3151            },
3152        };
3153        pm.execute_post_query(&ctx, &outcome);
3154    }
3155
3156    /// Select a backend node for the request
3157    /// Select a backend node for initial connection
3158    /// Prefers primary but falls back to standbys for read connections
3159    async fn select_node(
3160        session: &Arc<ClientSession>,
3161        state: &Arc<ServerState>,
3162        config: &ProxyConfig,
3163    ) -> Result<String> {
3164        // If in a transaction, stick to the current node
3165        {
3166            let tx_state = session.tx_state.read().await;
3167            if tx_state.in_transaction {
3168                if let Some(node) = session.current_node.read().await.clone() {
3169                    return Ok(node);
3170                }
3171            }
3172        }
3173
3174        // Get healthy nodes
3175        let health = state.health.load_full();
3176        let healthy_nodes: Vec<&NodeConfig> = config
3177            .nodes
3178            .iter()
3179            .filter(|n| {
3180                n.enabled
3181                    && health
3182                        .get(&n.address())
3183                        .map(|h| h.healthy)
3184                        .unwrap_or(false)
3185            })
3186            .collect();
3187
3188        if healthy_nodes.is_empty() {
3189            return Err(ProxyError::NoHealthyNodes);
3190        }
3191
3192        // Try to find healthy primary first
3193        if let Some(primary) = healthy_nodes.iter().find(|n| n.role == NodeRole::Primary) {
3194            let node_addr = primary.address();
3195            let mut current = session.current_node.write().await;
3196            *current = Some(node_addr.clone());
3197            return Ok(node_addr);
3198        }
3199
3200        // Fall back to standby if primary is unavailable
3201        // (Initial connection will work, writes will use write timeout to wait for primary)
3202        if let Some(standby) = healthy_nodes.iter().find(|n| n.role == NodeRole::Standby) {
3203            tracing::warn!("Primary unavailable, connecting to standby for initial session");
3204            let node_addr = standby.address();
3205            let mut current = session.current_node.write().await;
3206            *current = Some(node_addr.clone());
3207            return Ok(node_addr);
3208        }
3209
3210        // No nodes available
3211        Err(ProxyError::NoHealthyNodes)
3212    }
3213
3214    /// Spawn health checker background task
3215    fn spawn_health_checker(&self) -> tokio::task::JoinHandle<()> {
3216        let state = self.state.clone();
3217        let mut shutdown_rx = self.shutdown_tx.subscribe();
3218
3219        tokio::spawn(async move {
3220            let mut interval = tokio::time::interval(std::time::Duration::from_secs(
3221                state.live_config.load().health.check_interval_secs,
3222            ));
3223
3224            loop {
3225                tokio::select! {
3226                    _ = interval.tick() => {
3227                        // Read the live config each tick so a SIGHUP that
3228                        // adds/removes nodes is checked on the next sweep.
3229                        let config = state.live_config.load_full();
3230                        Self::check_all_nodes(&state, &config).await;
3231                    }
3232                    _ = shutdown_rx.recv() => {
3233                        break;
3234                    }
3235                }
3236            }
3237        })
3238    }
3239
3240    /// Check health of all nodes.
3241    ///
3242    /// Probes run concurrently (one slow/unreachable node no longer delays
3243    /// detection on the others — lowers the failover-detection latency
3244    /// floor), then a single new health snapshot is published via ArcSwap so
3245    /// readers on the query path never block.
3246    async fn check_all_nodes(state: &Arc<ServerState>, config: &ProxyConfig) {
3247        // Probe every node in parallel (owned address + timeout so each
3248        // probe is 'static and runs on its own task).
3249        let timeout = Duration::from_secs(config.health.check_timeout_secs);
3250        let mut set = tokio::task::JoinSet::new();
3251        for node in &config.nodes {
3252            let addr = node.address();
3253            set.spawn(async move {
3254                let r = Self::check_node_addr(&addr, timeout).await;
3255                (addr, r)
3256            });
3257        }
3258        let mut results = Vec::with_capacity(config.nodes.len());
3259        while let Some(joined) = set.join_next().await {
3260            if let Ok(pair) = joined {
3261                results.push(pair);
3262            }
3263        }
3264
3265        // Clone-and-modify the current snapshot, then atomically swap it in.
3266        let mut next = (*state.health.load_full()).clone();
3267        for (addr, result) in results {
3268            if let Some(node_health) = next.get_mut(&addr) {
3269                match result {
3270                    Ok(latency) => {
3271                        node_health.healthy = true;
3272                        node_health.failure_count = 0;
3273                        node_health.latency_ms = latency;
3274                        node_health.last_error = None;
3275                    }
3276                    Err(e) => {
3277                        node_health.failure_count += 1;
3278                        node_health.last_error = Some(e.to_string());
3279                        if node_health.failure_count >= config.health.failure_threshold {
3280                            node_health.healthy = false;
3281                            tracing::warn!(
3282                                "Node {} marked unhealthy after {} failures",
3283                                addr,
3284                                node_health.failure_count
3285                            );
3286                        }
3287                    }
3288                }
3289                node_health.last_check = chrono::Utc::now();
3290            }
3291        }
3292        state.health.store(Arc::new(next));
3293    }
3294
3295    /// Check health of a single node by TCP-connect probe. Returns the
3296    /// connect latency in milliseconds.
3297    async fn check_node_addr(addr: &str, timeout: Duration) -> Result<f64> {
3298        let start = std::time::Instant::now();
3299        let _stream = tokio::time::timeout(timeout, TcpStream::connect(addr))
3300            .await
3301            .map_err(|_| ProxyError::HealthCheck(format!("Timeout connecting to {}", addr)))?
3302            .map_err(|e| ProxyError::HealthCheck(format!("Failed to connect to {}: {}", addr, e)))?;
3303        let latency = start.elapsed().as_secs_f64() * 1000.0;
3304        Ok(latency)
3305    }
3306
3307    /// Spawn pool manager background task
3308    fn spawn_pool_manager(&self) -> tokio::task::JoinHandle<()> {
3309        let state = self.state.clone();
3310        let mut shutdown_rx = self.shutdown_tx.subscribe();
3311
3312        tokio::spawn(async move {
3313            let mut interval = tokio::time::interval(std::time::Duration::from_secs(30));
3314
3315            loop {
3316                tokio::select! {
3317                    _ = interval.tick() => {
3318                        // Evict idle connections from pool-modes manager
3319                        #[cfg(feature = "pool-modes")]
3320                        if let Some(ref pool_manager) = state.pool_manager {
3321                            pool_manager.evict_idle().await;
3322                            tracing::trace!("Pool-modes idle eviction completed");
3323                        }
3324                    }
3325                    _ = shutdown_rx.recv() => {
3326                        // Cleanup on shutdown
3327                        #[cfg(feature = "pool-modes")]
3328                        if let Some(ref pool_manager) = state.pool_manager {
3329                            pool_manager.close_all().await;
3330                            tracing::info!("Pool-modes manager closed all connections");
3331                        }
3332                        break;
3333                    }
3334                }
3335            }
3336        })
3337    }
3338
3339    /// Shutdown the server
3340    pub fn shutdown(&self) {
3341        let _ = self.shutdown_tx.send(());
3342    }
3343
3344    /// Get pool mode statistics (if pool-modes feature enabled)
3345    #[cfg(feature = "pool-modes")]
3346    pub async fn pool_mode_stats(&self) -> Option<PoolModeStatsSnapshot> {
3347        if let Some(ref pool_manager) = self.state.pool_manager {
3348            let stats = pool_manager.get_stats().await;
3349            let metrics = pool_manager.metrics().snapshot();
3350            let default_mode = pool_manager.default_mode();
3351
3352            // Calculate average lease duration across all modes
3353            let avg_lease_duration_ms = metrics
3354                .mode_stats
3355                .get(&default_mode)
3356                .map(|s| s.avg_lease_duration_ms as u64)
3357                .unwrap_or(0);
3358
3359            Some(PoolModeStatsSnapshot {
3360                mode: format!("{:?}", default_mode),
3361                total_connections: stats.total_connections,
3362                active_leases: stats.active_connections,
3363                idle_connections: stats.idle_connections,
3364                node_count: stats.node_count,
3365                acquires: metrics.acquires,
3366                releases: metrics.releases,
3367                acquire_failures: metrics.acquire_failures,
3368                acquire_timeouts: metrics.acquire_timeouts,
3369                transactions_completed: metrics.transactions_completed,
3370                statements_executed: metrics.statements_executed,
3371                avg_lease_duration_ms,
3372            })
3373        } else {
3374            None
3375        }
3376    }
3377
3378    /// Add a node to the pool manager (if pool-modes feature enabled)
3379    #[cfg(feature = "pool-modes")]
3380    pub async fn add_node_to_pool(&self, node: &NodeConfig) {
3381        if let Some(ref pool_manager) = self.state.pool_manager {
3382            let endpoint = NodeEndpoint::new(&node.host, node.port)
3383                .with_role(match node.role {
3384                    NodeRole::Primary => crate::NodeRole::Primary,
3385                    NodeRole::Standby => crate::NodeRole::Standby,
3386                    NodeRole::ReadReplica => crate::NodeRole::ReadReplica,
3387                })
3388                .with_weight(node.weight);
3389            pool_manager.add_node(&endpoint).await;
3390            tracing::info!("Added node {} to pool manager", node.address());
3391        }
3392    }
3393
3394    /// Get server metrics
3395    pub fn metrics(&self) -> ServerMetricsSnapshot {
3396        ServerMetricsSnapshot {
3397            connections_accepted: self.state.metrics.connections_accepted.load(Ordering::Relaxed),
3398            connections_closed: self.state.metrics.connections_closed.load(Ordering::Relaxed),
3399            queries_processed: self.state.metrics.queries_processed.load(Ordering::Relaxed),
3400            bytes_received: self.state.metrics.bytes_received.load(Ordering::Relaxed),
3401            bytes_sent: self.state.metrics.bytes_sent.load(Ordering::Relaxed),
3402            failovers: self.state.metrics.failovers.load(Ordering::Relaxed),
3403        }
3404    }
3405}
3406
/// Point-in-time copy of the proxy's server counters, for external
/// consumption.
///
/// Each field mirrors the counter of the same name at the moment the
/// snapshot was taken. The type is plain data, so it can be compared,
/// cloned, and default-constructed (all zeros), which is convenient for
/// computing deltas and for tests.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ServerMetricsSnapshot {
    /// Value of the `connections_accepted` counter.
    pub connections_accepted: u64,
    /// Value of the `connections_closed` counter.
    pub connections_closed: u64,
    /// Value of the `queries_processed` counter.
    pub queries_processed: u64,
    /// Value of the `bytes_received` counter.
    pub bytes_received: u64,
    /// Value of the `bytes_sent` counter. This includes frames the proxy
    /// generates itself, such as plugin block responses.
    pub bytes_sent: u64,
    /// Value of the `failovers` counter.
    pub failovers: u64,
}
3417
/// Pool mode statistics snapshot (when pool-modes feature is enabled).
///
/// Produced by `ProxyServer::pool_mode_stats`, which fills the connection
/// figures from the pool manager's stats and the counters from its
/// metrics snapshot.
#[cfg(feature = "pool-modes")]
#[derive(Debug, Clone)]
pub struct PoolModeStatsSnapshot {
    /// Debug rendering of the pool manager's default pooling mode
    pub mode: String,
    /// Total connections across all pools
    pub total_connections: usize,
    /// Active (leased) connections, taken from the pool stats'
    /// `active_connections`
    pub active_leases: usize,
    /// Idle connections
    pub idle_connections: usize,
    /// Number of nodes in the pool
    pub node_count: usize,
    /// Total connection acquires
    pub acquires: u64,
    /// Total connection releases
    pub releases: u64,
    /// Failed acquire attempts
    pub acquire_failures: u64,
    /// Acquire timeouts
    pub acquire_timeouts: u64,
    /// Completed transactions (Transaction mode)
    pub transactions_completed: u64,
    /// Total statements executed
    pub statements_executed: u64,
    /// Average lease duration in milliseconds for the default pooling
    /// mode only; `0` when the metrics hold no entry for that mode
    pub avg_lease_duration_ms: u64,
}
3447
3448#[cfg(test)]
3449mod tests {
3450    use super::*;
3451    use crate::config::{HealthConfig, LoadBalancerConfig, PoolConfig};
3452
3453    fn test_config() -> ProxyConfig {
3454        let mut config = ProxyConfig::default();
3455        config.listen_address = "127.0.0.1:0".to_string();
3456        config
3457            .add_node("127.0.0.1:5432", "primary")
3458            .unwrap();
3459        config
3460    }
3461
3462    #[test]
3463    fn test_server_creation() {
3464        let config = test_config();
3465        let server = ProxyServer::new(config);
3466        assert!(server.is_ok());
3467    }
3468
3469    #[test]
3470    fn test_hba_addr_matches() {
3471        use std::net::IpAddr;
3472        let v4 = |s: &str| s.parse::<IpAddr>().unwrap();
3473        // "all" matches everything
3474        assert!(ProxyServer::hba_addr_matches("all", v4("203.0.113.7")));
3475        // CIDR membership
3476        assert!(ProxyServer::hba_addr_matches("10.0.0.0/8", v4("10.1.2.3")));
3477        assert!(!ProxyServer::hba_addr_matches("10.0.0.0/8", v4("11.1.2.3")));
3478        assert!(ProxyServer::hba_addr_matches("127.0.0.1/32", v4("127.0.0.1")));
3479        assert!(!ProxyServer::hba_addr_matches("127.0.0.1/32", v4("127.0.0.2")));
3480        // bare IP exact match
3481        assert!(ProxyServer::hba_addr_matches("192.168.1.1", v4("192.168.1.1")));
3482        assert!(!ProxyServer::hba_addr_matches("192.168.1.1", v4("192.168.1.2")));
3483        // IPv6 CIDR + /0 catch-all
3484        assert!(ProxyServer::hba_addr_matches("::1/128", v4("::1")));
3485        assert!(ProxyServer::hba_addr_matches("0.0.0.0/0", v4("8.8.8.8")));
3486    }
3487
3488    #[test]
3489    fn test_hba_admits() {
3490        use crate::config::{HbaAction, HbaRule};
3491        use std::net::IpAddr;
3492        let ip: IpAddr = "10.0.0.5".parse().unwrap();
3493        // No rules -> admit all
3494        assert!(ProxyServer::hba_admits(&[], ip, "bench", "benchdb"));
3495        // Reject a specific user, allow others (default admit)
3496        let rules = vec![HbaRule {
3497            action: HbaAction::Reject,
3498            user: "bench".into(),
3499            database: "all".into(),
3500            address: "all".into(),
3501        }];
3502        assert!(!ProxyServer::hba_admits(&rules, ip, "bench", "benchdb"));
3503        assert!(ProxyServer::hba_admits(&rules, ip, "alice", "benchdb"));
3504        // First match wins: allow bench from 10/8, reject everything else
3505        let rules = vec![
3506            HbaRule { action: HbaAction::Allow, user: "bench".into(), database: "all".into(), address: "10.0.0.0/8".into() },
3507            HbaRule { action: HbaAction::Reject, user: "all".into(), database: "all".into(), address: "all".into() },
3508        ];
3509        assert!(ProxyServer::hba_admits(&rules, ip, "bench", "benchdb"));
3510        assert!(!ProxyServer::hba_admits(&rules, "192.168.0.1".parse().unwrap(), "bench", "benchdb"));
3511        assert!(!ProxyServer::hba_admits(&rules, ip, "alice", "benchdb"));
3512    }
3513
3514    #[test]
3515    fn test_initial_metrics() {
3516        let config = test_config();
3517        let server = ProxyServer::new(config).unwrap();
3518        let metrics = server.metrics();
3519        assert_eq!(metrics.connections_accepted, 0);
3520        assert_eq!(metrics.queries_processed, 0);
3521    }
3522
3523    #[tokio::test]
3524    async fn test_session_creation() {
3525        let config = test_config();
3526        let server = ProxyServer::new(config).unwrap();
3527
3528        let sessions = server.state.sessions.read().await;
3529        assert!(sessions.is_empty());
3530    }
3531
3532    #[tokio::test]
3533    async fn test_node_health_initialization() {
3534        let config = test_config();
3535        let server = ProxyServer::new(config).unwrap();
3536
3537        let health = server.state.health.load_full();
3538        assert!(!health.is_empty());
3539
3540        for node_health in health.values() {
3541            assert!(node_health.healthy);
3542            assert_eq!(node_health.failure_count, 0);
3543        }
3544    }
3545
3546    /// Build a minimal `ClientSession` for plugin-hook unit tests.
3547    fn make_test_session() -> Arc<ClientSession> {
3548        Arc::new(ClientSession {
3549            id: Uuid::new_v4(),
3550            client_addr: "127.0.0.1:0".parse().unwrap(),
3551            current_node: RwLock::new(None),
3552            tx_state: RwLock::new(TransactionState::default()),
3553            variables: RwLock::new(HashMap::new()),
3554            created_at: chrono::Utc::now(),
3555            tr_mode: crate::config::TrMode::default(),
3556            #[cfg(feature = "pool-modes")]
3557            pool_client_id: crate::pool::lease::ClientId::default(),
3558            #[cfg(feature = "wasm-plugins")]
3559            plugin_identity: RwLock::new(None),
3560        })
3561    }
3562
3563    /// With no plugin manager attached, `apply_route_hook` must be a
3564    /// zero-cost `None` return so the default SQL-verb routing applies.
3565    /// Verifies the feature-gated early-return path.
3566    #[tokio::test]
3567    async fn test_apply_route_hook_no_plugin_manager_returns_none() {
3568        let config = test_config();
3569        let server = ProxyServer::new(config).unwrap();
3570        let session = make_test_session();
3571
3572        let msg = QueryMessage {
3573            query: "SELECT * FROM users".to_string(),
3574        }
3575        .encode();
3576
3577        let decision = ProxyServer::apply_route_hook(&msg, &server.state, &session);
3578        assert!(matches!(decision, RouteOverride::None));
3579    }
3580
3581    /// Same invariant for the pre-query hook: without a plugin manager,
3582    /// `apply_pre_query_hook` must return the message unchanged with
3583    /// `PreQueryAction::Forward`.
3584    #[tokio::test]
3585    async fn test_apply_pre_query_hook_no_plugin_manager_forwards() {
3586        let config = test_config();
3587        let server = ProxyServer::new(config).unwrap();
3588        let session = make_test_session();
3589
3590        let original = QueryMessage {
3591            query: "SELECT 1".to_string(),
3592        }
3593        .encode();
3594        let original_bytes = original.encode().to_vec();
3595
3596        let (msg_out, action) =
3597            ProxyServer::apply_pre_query_hook(original, &server.state, &session);
3598
3599        assert!(matches!(action, PreQueryAction::Forward));
3600        // The message must survive the hook byte-for-byte when no plugins run.
3601        assert_eq!(msg_out.encode().to_vec(), original_bytes);
3602    }
3603
3604    /// Non-Query message types (e.g., extended-protocol Parse/Execute) must
3605    /// bypass the Route hook entirely regardless of plugin state, because
3606    /// we haven't wired SQL extraction for those variants yet.
3607    #[tokio::test]
3608    async fn test_apply_route_hook_skips_non_query_messages() {
3609        let config = test_config();
3610        let server = ProxyServer::new(config).unwrap();
3611        let session = make_test_session();
3612
3613        let sync_msg = Message::empty(MessageType::Sync);
3614        let decision = ProxyServer::apply_route_hook(&sync_msg, &server.state, &session);
3615        assert!(matches!(decision, RouteOverride::None));
3616    }
3617
3618    /// By default, `[plugins].enabled = false`, so `init_plugin_manager`
3619    /// short-circuits without touching the filesystem or wasmtime and
3620    /// returns `None`. The proxy starts normally whether or not a plugin
3621    /// directory exists on the host.
3622    #[cfg(feature = "wasm-plugins")]
3623    #[test]
3624    fn test_init_plugin_manager_disabled_by_default_returns_none() {
3625        let config = test_config();
3626        assert!(!config.plugins.enabled);
3627        let pm = ProxyServer::init_plugin_manager(&config.plugins);
3628        assert!(pm.is_none());
3629    }
3630
3631    /// Plugins enabled but pointing at a directory that doesn't exist
3632    /// must still initialise the manager (so new plugins can be hot-
3633    /// loaded later) and log a warning — it must NOT fail startup.
3634    #[cfg(feature = "wasm-plugins")]
3635    #[test]
3636    fn test_init_plugin_manager_missing_dir_logs_warning() {
3637        let mut config = test_config();
3638        config.plugins.enabled = true;
3639        config.plugins.plugin_dir = "/definitely/not/a/real/path".to_string();
3640
3641        // Manager is created; no panic; Some(pm) returned even with empty dir.
3642        let pm = ProxyServer::init_plugin_manager(&config.plugins);
3643        assert!(pm.is_some());
3644    }
3645
3646    /// With no plugin manager attached, `apply_authenticate_hook` is a
3647    /// zero-cost `Ok(())` that leaves session identity unset — the
3648    /// default PG auth flow applies.
3649    #[tokio::test]
3650    async fn test_apply_authenticate_hook_no_plugin_manager_defers() {
3651        let config = test_config();
3652        let server = ProxyServer::new(config).unwrap();
3653        let session = make_test_session();
3654
3655        let mut params = HashMap::new();
3656        params.insert("user".to_string(), "alice".to_string());
3657        params.insert("database".to_string(), "app".to_string());
3658
3659        let result =
3660            ProxyServer::apply_authenticate_hook(&params, &session, &server.state).await;
3661        assert!(result.is_ok());
3662
3663        // No plugin → no identity stored.
3664        #[cfg(feature = "wasm-plugins")]
3665        {
3666            let ident = session.plugin_identity.read().await;
3667            assert!(ident.is_none());
3668        }
3669    }
3670
3671    /// Cached-response synthesis round-trip: a well-formed plugin
3672    /// payload must produce concatenated wire frames in the order
3673    /// `T D D C Z`. We inspect the raw tag bytes directly because
3674    /// `MessageType::from_tag` conflates server→client DataRow (`'D'`)
3675    /// with client→server Describe (same byte) — a known quirk of the
3676    /// shared `MessageType` enum that the real proxy side-steps by
3677    /// knowing the direction at the call site.
3678    #[cfg(feature = "wasm-plugins")]
3679    #[test]
3680    fn test_synthesise_cached_response_roundtrip() {
3681        let payload = br#"{
3682            "columns": [
3683                {"name": "id",    "oid": 23},
3684                {"name": "email", "oid": 25}
3685            ],
3686            "rows": [
3687                ["1", "alice@example.com"],
3688                ["2", null]
3689            ]
3690        }"#;
3691        let reply =
3692            ProxyServer::synthesise_cached_response(payload).expect("synthesis");
3693
3694        // Walk the concatenation frame-by-frame via length prefixes.
3695        // Each PG message: tag(1) + length(4, big-endian, includes self) + payload.
3696        let mut tags = Vec::new();
3697        let mut i = 0;
3698        while i < reply.len() {
3699            let tag = reply[i];
3700            let len = u32::from_be_bytes([
3701                reply[i + 1],
3702                reply[i + 2],
3703                reply[i + 3],
3704                reply[i + 4],
3705            ]) as usize;
3706            tags.push(tag);
3707            i += 1 + len;
3708        }
3709        assert_eq!(i, reply.len(), "no trailing bytes");
3710        assert_eq!(
3711            tags,
3712            vec![b'T', b'D', b'D', b'C', b'Z'],
3713            "wire frame order"
3714        );
3715
3716        // Spot-check the final ReadyForQuery payload is 'I' (idle).
3717        assert_eq!(*reply.last().unwrap(), b'I');
3718    }
3719
3720    /// Row width mismatch between columns and row data is rejected so
3721    /// the plugin author can't produce ambiguous wire frames.
3722    #[cfg(feature = "wasm-plugins")]
3723    #[test]
3724    fn test_synthesise_cached_response_rejects_row_width_mismatch() {
3725        let payload = br#"{
3726            "columns": [{"name": "id", "oid": 23}, {"name": "name", "oid": 25}],
3727            "rows": [["1", "alice", "extra"]]
3728        }"#;
3729        let result = ProxyServer::synthesise_cached_response(payload);
3730        assert!(matches!(result, Err(ProxyError::Protocol(_))));
3731    }
3732
3733    /// Empty payload (no columns) is rejected — a RowDescription with
3734    /// zero columns is technically valid PG but useless and likely a
3735    /// plugin bug.
3736    #[cfg(feature = "wasm-plugins")]
3737    #[test]
3738    fn test_synthesise_cached_response_rejects_empty_columns() {
3739        let payload = br#"{ "columns": [], "rows": [] }"#;
3740        let result = ProxyServer::synthesise_cached_response(payload);
3741        assert!(matches!(result, Err(ProxyError::Protocol(_))));
3742    }
3743
3744    /// Malformed JSON must return a Protocol error, not panic. The
3745    /// caller treats this as "fall back to backend."
3746    #[cfg(feature = "wasm-plugins")]
3747    #[test]
3748    fn test_synthesise_cached_response_rejects_bad_json() {
3749        let payload = b"not json at all";
3750        let result = ProxyServer::synthesise_cached_response(payload);
3751        assert!(matches!(result, Err(ProxyError::Protocol(_))));
3752    }
3753
3754    /// Denied by plugin surfaces as `ProxyError::Auth` so the existing
3755    /// error-response path in `handle_client` writes an ErrorResponse
3756    /// and closes the connection. Here we prove the error variant
3757    /// when the plugin manager is present but denies. We build a
3758    /// PluginManager with no plugins loaded — so it defers — and
3759    /// verify the Ok path. (Denial path requires an actual
3760    /// auth-plugin `.wasm`; covered by the plugin unit tests in
3761    /// `plugins::tests`.)
3762    #[cfg(feature = "wasm-plugins")]
3763    #[tokio::test]
3764    async fn test_apply_authenticate_hook_with_manager_no_plugins_defers() {
3765        use crate::plugins::{PluginManager, PluginRuntimeConfig};
3766
3767        let config = test_config();
3768        let server = ProxyServer::new(config).unwrap();
3769        let session = make_test_session();
3770
3771        // Synthesise a state with a real PluginManager but zero
3772        // registered plugins — every hook must defer.
3773        let pm = Arc::new(PluginManager::new(PluginRuntimeConfig::default()).unwrap());
3774        let augmented_state = Arc::new(ServerState {
3775            sessions: RwLock::new(HashMap::new()),
3776            health: ArcSwap::from_pointee(HashMap::new()),
3777            live_config: ArcSwap::from_pointee(ProxyConfig::default()),
3778            metrics: ServerMetrics::default(),
3779            cancel_map: Arc::new(DashMap::new()),
3780            tls_acceptor: None,
3781            auth_file: None,
3782            mirror: None,
3783            cutover: Arc::new(ArcSwap::from_pointee(None)),
3784            lb_state: LoadBalancerState {
3785                rr_counter: AtomicU64::new(0),
3786            },
3787            #[cfg(feature = "pool-modes")]
3788            pool_manager: None,
3789            plugin_manager: Some(pm),
3790            #[cfg(feature = "ha-tr")]
3791            transaction_journal: Arc::new(
3792                crate::transaction_journal::TransactionJournal::new(),
3793            ),
3794            #[cfg(feature = "anomaly-detection")]
3795            anomaly_detector: Arc::new(
3796                crate::anomaly::AnomalyDetector::new(
3797                    crate::anomaly::AnomalyConfig::default(),
3798                ),
3799            ),
3800            #[cfg(feature = "edge-proxy")]
3801            edge_cache: Arc::new(crate::edge::EdgeCache::new(10_000)),
3802            #[cfg(feature = "edge-proxy")]
3803            edge_registry: Arc::new(crate::edge::EdgeRegistry::new(
3804                32,
3805                std::time::Duration::from_secs(120),
3806            )),
3807        });
3808
3809        let mut params = HashMap::new();
3810        params.insert("user".to_string(), "alice".to_string());
3811
3812        let result =
3813            ProxyServer::apply_authenticate_hook(&params, &session, &augmented_state).await;
3814        assert!(result.is_ok());
3815        let ident = session.plugin_identity.read().await;
3816        assert!(ident.is_none());
3817        // Unused bindings for the sync-state build path.
3818        let _ = server;
3819    }
3820
3821    // ---- Batch F.4: prepared-statement tracking across backend switches ----
3822
3823    fn cstr(s: &str) -> Vec<u8> {
3824        let mut v = s.as_bytes().to_vec();
3825        v.push(0);
3826        v
3827    }
3828
3829    #[test]
3830    fn parse_stmt_name_extracts_named_and_unnamed() {
3831        // Parse payload = stmt-name cstring + query cstring + int16 nparams.
3832        let mut named = cstr("ps1");
3833        named.extend_from_slice(&cstr("SELECT 1"));
3834        named.extend_from_slice(&[0, 0]);
3835        assert_eq!(ProxyServer::parse_stmt_name(&named), "ps1");
3836
3837        let mut unnamed = cstr("");
3838        unnamed.extend_from_slice(&cstr("SELECT 1"));
3839        unnamed.extend_from_slice(&[0, 0]);
3840        assert_eq!(ProxyServer::parse_stmt_name(&unnamed), "");
3841    }
3842
3843    #[test]
3844    fn bind_stmt_ref_reads_second_cstring() {
3845        // Bind payload = portal cstring + statement cstring + ...
3846        let mut named = cstr("portal_a");
3847        named.extend_from_slice(&cstr("ps1"));
3848        named.extend_from_slice(&[0, 0]); // 0 param-format codes, 0 params
3849        assert_eq!(ProxyServer::bind_stmt_ref(&named), Some("ps1"));
3850
3851        // Unnamed statement (empty second cstring) is not tracked.
3852        let mut unnamed = cstr("");
3853        unnamed.extend_from_slice(&cstr(""));
3854        assert_eq!(ProxyServer::bind_stmt_ref(&unnamed), None);
3855    }
3856
3857    #[test]
3858    fn stmt_kind_name_only_matches_statement_kind() {
3859        // Describe/Close 'S' (statement) carries a trackable name.
3860        let mut stmt = vec![b'S'];
3861        stmt.extend_from_slice(&cstr("ps1"));
3862        assert_eq!(ProxyServer::stmt_kind_name(&stmt), Some("ps1"));
3863
3864        // 'P' (portal) is not a statement reference.
3865        let mut portal = vec![b'P'];
3866        portal.extend_from_slice(&cstr("portal_a"));
3867        assert_eq!(ProxyServer::stmt_kind_name(&portal), None);
3868
3869        // Statement-kind but unnamed -> nothing to track.
3870        let mut empty = vec![b'S'];
3871        empty.extend_from_slice(&cstr(""));
3872        assert_eq!(ProxyServer::stmt_kind_name(&empty), None);
3873    }
3874
3875    #[tokio::test]
3876    async fn read_one_frame_type_consumes_full_frame() {
3877        // ParseComplete '1' with empty body, followed by a second frame to
3878        // prove only the first frame is consumed.
3879        let (mut a, mut b) = tokio::io::duplex(64);
3880        // frame 1: '1' + len(4) + no body; frame 2: 'Z' + len(5) + 'I'.
3881        let bytes = [b'1', 0, 0, 0, 4, b'Z', 0, 0, 0, 5, b'I'];
3882        b.write_all(&bytes).await.unwrap();
3883        let t = ProxyServer::read_one_frame_type(&mut a).await.unwrap();
3884        assert_eq!(t, b'1');
3885        // The next frame's type byte is still readable -> we stopped cleanly.
3886        let t2 = ProxyServer::read_one_frame_type(&mut a).await.unwrap();
3887        assert_eq!(t2, b'Z');
3888    }
3889
3890    #[tokio::test]
3891    async fn reprepare_statement_accepts_parse_complete_and_rejects_error() {
3892        // Backend answers ParseComplete -> Ok.
3893        let (mut client, mut backend) = tokio::io::duplex(64);
3894        backend.write_all(&[b'1', 0, 0, 0, 4]).await.unwrap();
3895        let parse = {
3896            let mut p = vec![b'P', 0, 0, 0, 0];
3897            p.extend_from_slice(&cstr("ps1"));
3898            p.extend_from_slice(&cstr("SELECT 1"));
3899            p.extend_from_slice(&[0, 0]);
3900            p
3901        };
3902        assert!(ProxyServer::reprepare_statement(&mut client, &parse).await.is_ok());
3903
3904        // Backend answers ErrorResponse -> Err.
3905        let (mut client2, mut backend2) = tokio::io::duplex(64);
3906        backend2.write_all(&[b'E', 0, 0, 0, 4]).await.unwrap();
3907        assert!(ProxyServer::reprepare_statement(&mut client2, &parse).await.is_err());
3908    }
3909}