1use crate::admin::{AdminServer, AdminState, ConfigSnapshot, NodeSnapshot};
7#[cfg(feature = "ha-tr")]
8use crate::backend::{tls::default_client_config, BackendConfig, TlsMode};
9use crate::client_tls::{build_tls_acceptor, ClientStream};
10use crate::config::{HbaAction, HbaRule, NodeConfig, NodeRole, ProxyConfig, TrMode};
11#[cfg(feature = "wasm-plugins")]
12use crate::protocol::QueryMessage;
13use crate::protocol::{
14 ErrorResponse, Message, MessageType, ProtocolCodec, StartupMessage, TransactionStatus,
15};
16use crate::{ProxyError, Result};
17use arc_swap::ArcSwap;
18use bytes::{BufMut, BytesMut};
19use dashmap::DashMap;
20use std::collections::{HashMap, HashSet};
21use std::net::SocketAddr;
22use std::sync::atomic::{AtomicU64, Ordering};
23use std::sync::Arc;
24use std::time::Duration;
25use tokio::io::{AsyncReadExt, AsyncWriteExt};
26use tokio::net::{TcpListener, TcpStream};
27use tokio::sync::{broadcast, RwLock};
28use uuid::Uuid;
29
30#[cfg(feature = "pool-modes")]
32use crate::pool::lease::ClientId;
33#[cfg(feature = "pool-modes")]
34use crate::pool::{ConnectionPoolManager, PoolModeConfig, PoolingMode};
35#[cfg(feature = "pool-modes")]
36use crate::NodeEndpoint;
37
38#[cfg(feature = "wasm-plugins")]
40use crate::plugins::{
41 AuthRequest as PluginAuthRequest, AuthResult, HookContext, HookType, Identity, PluginManager,
42 PostQueryOutcome, PreQueryResult, QueryContext, RouteResult,
43};
44
45pub struct ProxyServer {
47 config: ProxyConfig,
48 state: Arc<ServerState>,
49 shutdown_tx: broadcast::Sender<()>,
50 config_path: Option<String>,
54}
55
#[cfg(not(unix))]
/// Placeholder signal stream for platforms without POSIX signals.
///
/// `recv` never resolves, so the `select!` arm that waits on it in `run` is
/// never taken.
struct HangupNever;

#[cfg(not(unix))]
impl HangupNever {
    /// Pends forever. Its signature matches the Unix signal stream's `recv`,
    /// so `run` can select on either type unconditionally.
    async fn recv(&mut self) -> Option<()> {
        std::future::pending::<Option<()>>().await
    }
}
66
#[cfg(feature = "ha-tr")]
/// Builds the `BackendConfig` template handed to the replay engine.
///
/// Host and port are placeholders (`"placeholder"`, `0`); presumably the
/// replay engine fills in the real target per node. TODO confirm.
/// The template uses user `postgres`, no password or database, TLS disabled,
/// a 5 s connect timeout and a 30 s query timeout. The proxy config argument
/// is currently unused.
fn build_replay_backend_template(_config: &ProxyConfig) -> BackendConfig {
    BackendConfig {
        host: String::from("placeholder"),
        port: 0,
        user: String::from("postgres"),
        password: None,
        database: None,
        application_name: Some(String::from("heliosdb-proxy-replay")),
        tls_mode: TlsMode::Disable,
        connect_timeout: Duration::from_secs(5),
        query_timeout: Duration::from_secs(30),
        tls_config: default_client_config(),
    }
}
95
96#[cfg(feature = "anomaly-detection")]
/// Normalizes a SQL statement into a fingerprint for anomaly detection.
///
/// - Each single-quoted string literal collapses to one `?`.
///   - A doubled quote inside a literal (`'it''s'`) is SQL's escape for a
///     quote character and does not end the literal.
///   - An unterminated literal swallows the rest of the input.
/// - A run of digits and dots collapses to `?`. A run that directly follows
///   an existing `?` adds nothing.
/// - Runs of ASCII whitespace collapse to one space. Leading and trailing
///   whitespace is dropped.
/// - Every other character is ASCII-lowercased.
fn anomaly_fingerprint(sql: &str) -> String {
    let mut out = String::with_capacity(sql.len());
    let mut prev_space = false;
    let mut chars = sql.chars().peekable();
    while let Some(c) = chars.next() {
        if c == '\'' {
            // The whole literal becomes a single placeholder.
            out.push('?');
            while let Some(n) = chars.next() {
                if n == '\'' {
                    // `''` is an escaped quote, not the end of the literal.
                    if chars.peek() == Some(&'\'') {
                        chars.next();
                        continue;
                    }
                    break;
                }
            }
            prev_space = false;
            continue;
        }
        if c.is_ascii_digit() {
            if !out.ends_with('?') {
                out.push('?');
            }
            // Swallow the rest of the numeric constant, including a decimal part.
            while matches!(chars.peek(), Some(d) if d.is_ascii_digit() || *d == '.') {
                chars.next();
            }
            prev_space = false;
            continue;
        }
        if c.is_ascii_whitespace() {
            if !prev_space && !out.is_empty() {
                out.push(' ');
                prev_space = true;
            }
            continue;
        }
        out.push(c.to_ascii_lowercase());
        prev_space = false;
    }
    out.trim_end().to_string()
}
152
153struct ServerState {
155 sessions: RwLock<HashMap<Uuid, Arc<ClientSession>>>,
157 health: ArcSwap<HashMap<String, NodeHealth>>,
163 health_write: parking_lot::Mutex<()>,
168 live_config: ArcSwap<ProxyConfig>,
175 metrics: ServerMetrics,
177 cancel_map: Arc<DashMap<(u32, u32), String>>,
183 cancel_order: Arc<parking_lot::Mutex<std::collections::VecDeque<(u32, u32)>>>,
187 tls_acceptor: Option<tokio_rustls::TlsAcceptor>,
190 auth_file: Option<Arc<crate::auth_scram::AuthFile>>,
194 mirror: Option<crate::mirror::MirrorHandle>,
197 cutover: Arc<ArcSwap<Option<Arc<crate::mirror::CutoverTarget>>>>,
201 lb_state: LoadBalancerState,
203 #[cfg(feature = "routing-hints")]
208 hint_parser: Option<crate::routing::HintParser>,
209 #[cfg(feature = "rate-limiting")]
212 rate_limiter: Option<Arc<crate::rate_limit::RateLimiter>>,
213 #[cfg(feature = "circuit-breaker")]
217 circuit_breaker: Option<Arc<crate::circuit_breaker::CircuitBreakerManager>>,
218 #[cfg(feature = "query-analytics")]
221 analytics: Option<Arc<crate::analytics::QueryAnalytics>>,
222 #[cfg(feature = "query-cache")]
225 query_cache: Option<Arc<crate::cache::QueryCache>>,
226 #[cfg(feature = "query-rewriting")]
229 rewriter: Option<Arc<crate::rewriter::QueryRewriter>>,
230 #[cfg(feature = "multi-tenancy")]
233 tenant_manager: Option<Arc<crate::multi_tenancy::TenantManager>>,
234 #[cfg(feature = "schema-routing")]
237 schema_analyzer: Option<Arc<crate::schema_routing::QueryAnalyzer>>,
238 #[cfg(feature = "pool-modes")]
240 pool_manager: Option<Arc<ConnectionPoolManager>>,
241 #[cfg(feature = "pool-modes")]
246 backend_pool: Option<Arc<crate::pool::BackendIdlePool>>,
247 #[cfg(feature = "wasm-plugins")]
251 plugin_manager: Option<Arc<PluginManager>>,
252 #[cfg(feature = "ha-tr")]
257 transaction_journal: Arc<crate::transaction_journal::TransactionJournal>,
258 #[cfg(feature = "anomaly-detection")]
261 anomaly_detector: Arc<crate::anomaly::AnomalyDetector>,
262 #[cfg(feature = "edge-proxy")]
266 edge_cache: Arc<crate::edge::EdgeCache>,
267 #[cfg(feature = "edge-proxy")]
268 edge_registry: Arc<crate::edge::EdgeRegistry>,
269}
270
271#[derive(Debug, Clone)]
273pub struct NodeHealth {
274 pub address: String,
276 pub healthy: bool,
278 pub last_check: chrono::DateTime<chrono::Utc>,
280 pub failure_count: u32,
282 pub last_error: Option<String>,
284 pub latency_ms: f64,
286 pub replication_lag_bytes: Option<u64>,
288}
289
/// Process-wide counters. All start at zero; the admin server copies them
/// into its own snapshot periodically.
#[derive(Default)]
struct ServerMetrics {
    /// Client connections accepted by the listener.
    connections_accepted: AtomicU64,
    /// Client connections closed.
    connections_closed: AtomicU64,
    /// Queries processed.
    queries_processed: AtomicU64,
    /// Bytes received.
    bytes_received: AtomicU64,
    /// Bytes sent.
    bytes_sent: AtomicU64,
    /// Failovers performed.
    failovers: AtomicU64,
}
306
/// Load-balancer bookkeeping shared across all connections.
struct LoadBalancerState {
    /// Counter driving round-robin node selection; `new` starts it at 0.
    rr_counter: AtomicU64,
}
313
314pub struct ClientSession {
316 pub id: Uuid,
318 pub client_addr: SocketAddr,
320 pub current_node: RwLock<Option<String>>,
322 pub tx_state: RwLock<TransactionState>,
324 pub variables: RwLock<HashMap<String, String>>,
326 pub created_at: chrono::DateTime<chrono::Utc>,
328 pub tr_mode: TrMode,
330 #[cfg(feature = "lag-routing")]
335 pub last_write_at: RwLock<Option<std::time::Instant>>,
336 #[cfg(feature = "pool-modes")]
338 pub pool_client_id: ClientId,
339 #[cfg(feature = "wasm-plugins")]
344 pub plugin_identity: RwLock<Option<Identity>>,
345}
346
347#[derive(Debug, Clone, Default)]
349pub struct TransactionState {
350 pub in_transaction: bool,
352 pub tx_id: Option<Uuid>,
354 pub statements: Vec<StatementLog>,
356 pub read_only: bool,
358 pub savepoints: Vec<String>,
360}
361
362#[derive(Debug, Clone)]
364pub struct StatementLog {
365 pub sql: String,
367 pub params: Vec<String>,
369 pub result_checksum: Option<u64>,
371 pub executed_at: chrono::DateTime<chrono::Utc>,
373}
374
375struct BackendConn {
388 stream: TcpStream,
389 prepared: HashSet<String>,
390 unnamed_sig: Option<bytes::Bytes>,
396}
397
398impl BackendConn {
399 fn new(stream: TcpStream) -> Self {
400 Self {
401 stream,
402 prepared: HashSet::new(),
403 unnamed_sig: None,
404 }
405 }
406}
407
408pub(crate) fn bind_reuseport(addr: &str) -> Result<TcpListener> {
414 use socket2::{Domain, Protocol, Socket, Type};
415 let sockaddr: SocketAddr = addr
416 .parse()
417 .map_err(|e| ProxyError::Config(format!("invalid listen address '{}': {}", addr, e)))?;
418 let domain = if sockaddr.is_ipv6() {
419 Domain::IPV6
420 } else {
421 Domain::IPV4
422 };
423 let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))
424 .map_err(|e| ProxyError::Network(format!("socket(): {}", e)))?;
425 socket
426 .set_reuse_address(true)
427 .map_err(|e| ProxyError::Network(format!("SO_REUSEADDR: {}", e)))?;
428 #[cfg(all(unix, not(target_os = "solaris")))]
429 socket
430 .set_reuse_port(true)
431 .map_err(|e| ProxyError::Network(format!("SO_REUSEPORT: {}", e)))?;
432 socket
433 .set_nonblocking(true)
434 .map_err(|e| ProxyError::Network(format!("set_nonblocking: {}", e)))?;
435 socket
436 .bind(&sockaddr.into())
437 .map_err(|e| ProxyError::Network(format!("Failed to bind {}: {}", addr, e)))?;
438 socket
439 .listen(1024)
440 .map_err(|e| ProxyError::Network(format!("listen(): {}", e)))?;
441 let std_listener: std::net::TcpListener = socket.into();
442 TcpListener::from_std(std_listener)
443 .map_err(|e| ProxyError::Network(format!("from_std listener: {}", e)))
444}
445
/// Result of running the pre-query pipeline on a single statement.
#[derive(Debug)]
// Not every variant is constructed in every feature combination, so
// dead-code warnings are silenced for the enum as a whole.
#[allow(dead_code)]
enum PreQueryAction {
    /// Pass the statement on to the backend.
    Forward,
    /// Do not forward the statement; carries the reason. Presumably reported
    /// to the client. TODO confirm.
    Block(String),
    /// Answer without the backend; carries the bytes to return. Presumably an
    /// already-encoded cached response. TODO confirm.
    Cached(Vec<u8>),
}
466
/// Routing decision that can override the default node selection for a statement.
#[derive(Debug)]
// Not every variant is constructed in every feature combination, so
// dead-code warnings are silenced for the enum as a whole.
#[allow(dead_code)]
enum RouteOverride {
    /// No override; use normal routing.
    None,
    /// Send to the primary.
    Primary,
    /// Send to a standby.
    Standby,
    /// Send to the named node.
    Node(String),
    /// Refuse to route; carries the reason.
    Block(String),
}
491
492impl ProxyServer {
493 #[cfg(feature = "wasm-plugins")]
501 fn init_plugin_manager(
502 toml_cfg: &crate::config::PluginToml,
503 ) -> Option<Arc<crate::plugins::PluginManager>> {
504 if !toml_cfg.enabled {
505 return None;
506 }
507
508 let runtime_cfg = crate::plugins::PluginRuntimeConfig::from(toml_cfg);
509 let plugin_dir = runtime_cfg.plugin_dir.clone();
510
511 let pm = match crate::plugins::PluginManager::new(runtime_cfg) {
512 Ok(pm) => Arc::new(pm),
513 Err(e) => {
514 tracing::error!(error = %e, "Failed to create plugin manager; plugins disabled");
515 return None;
516 }
517 };
518
519 match std::fs::read_dir(&plugin_dir) {
520 Ok(entries) => {
521 let mut loaded = 0usize;
522 let mut failed = 0usize;
523 for entry in entries.flatten() {
524 let path = entry.path();
525 if path.extension().and_then(|s| s.to_str()) != Some("wasm") {
526 continue;
527 }
528 match pm.load_plugin(&path) {
529 Ok(()) => loaded += 1,
530 Err(e) => {
531 failed += 1;
532 tracing::warn!(
533 path = %path.display(),
534 error = %e,
535 "Failed to load plugin"
536 );
537 }
538 }
539 }
540 tracing::info!(
541 dir = %plugin_dir.display(),
542 loaded = loaded,
543 failed = failed,
544 "Plugin loading complete"
545 );
546 }
547 Err(e) => {
548 tracing::warn!(
549 dir = %plugin_dir.display(),
550 error = %e,
551 "Plugin directory not readable; no plugins loaded"
552 );
553 }
554 }
555
556 Some(pm)
557 }
558
559 pub fn new(config: ProxyConfig) -> Result<Self> {
561 let (shutdown_tx, _) = broadcast::channel(1);
562
563 let mut health = HashMap::new();
565 for node in &config.nodes {
566 health.insert(
567 node.address(),
568 NodeHealth {
569 address: node.address(),
570 healthy: true, last_check: chrono::Utc::now(),
572 failure_count: 0,
573 last_error: None,
574 latency_ms: 0.0,
575 replication_lag_bytes: None,
576 },
577 );
578 }
579
580 #[cfg(feature = "pool-modes")]
582 let pool_manager = {
583 use crate::pool::PreparedStatementMode as PoolPreparedStatementMode;
584
585 let pool_config = PoolModeConfig {
586 default_mode: match config.pool_mode.mode {
587 crate::config::PoolingMode::Session => PoolingMode::Session,
588 crate::config::PoolingMode::Transaction => PoolingMode::Transaction,
589 crate::config::PoolingMode::Statement => PoolingMode::Statement,
590 },
591 max_pool_size: config.pool_mode.max_pool_size,
592 min_idle: config.pool_mode.min_idle,
593 idle_timeout_secs: config.pool_mode.idle_timeout_secs,
594 max_lifetime_secs: config.pool_mode.max_lifetime_secs,
595 acquire_timeout_secs: config.pool_mode.acquire_timeout_secs,
596 reset_query: config.pool_mode.reset_query.clone(),
597 prepared_statement_mode: match config.pool_mode.prepared_statement_mode {
598 crate::config::PreparedStatementMode::Disable => {
599 PoolPreparedStatementMode::Disable
600 }
601 crate::config::PreparedStatementMode::Track => PoolPreparedStatementMode::Track,
602 crate::config::PreparedStatementMode::Named => PoolPreparedStatementMode::Named,
603 },
604 test_on_acquire: config.pool.test_on_acquire,
605 validation_query: "SELECT 1".to_string(),
606 queue_timeout_secs: 30,
607 max_queue_size: 0,
608 };
609 Some(Arc::new(ConnectionPoolManager::new(pool_config)))
610 };
611
612 #[cfg(feature = "pool-modes")]
616 let backend_pool = match config.pool_mode.mode {
617 crate::config::PoolingMode::Transaction | crate::config::PoolingMode::Statement => {
618 tracing::info!(
619 mode = ?config.pool_mode.mode,
620 max_idle_per_identity = config.pool_mode.max_pool_size,
621 "pool-modes: data-path connection pooling enabled"
622 );
623 Some(Arc::new(crate::pool::BackendIdlePool::new(
624 config.pool_mode.max_pool_size as usize,
625 Self::MAX_TOTAL_IDLE_BACKEND_CONNS,
626 )))
627 }
628 crate::config::PoolingMode::Session => None,
629 };
630
631 #[cfg(feature = "wasm-plugins")]
636 let plugin_manager = Self::init_plugin_manager(&config.plugins);
637
638 let tls_acceptor = match config.tls.as_ref() {
642 Some(tls) if tls.enabled => match build_tls_acceptor(tls) {
643 Ok(acc) => {
644 tracing::info!(
645 mtls = tls.require_client_cert,
646 "client TLS termination enabled"
647 );
648 Some(acc)
649 }
650 Err(e) => {
651 return Err(ProxyError::Config(format!("TLS init failed: {}", e)));
652 }
653 },
654 _ => None,
655 };
656
657 let auth_file = if config.auth.mode == crate::config::AuthMode::Scram {
660 let path = config.auth.auth_file.as_ref().ok_or_else(|| {
661 ProxyError::Config("auth mode 'scram' requires auth_file".to_string())
662 })?;
663 let af = crate::auth_scram::AuthFile::load(path)
664 .map_err(|e| ProxyError::Config(format!("auth_file: {}", e)))?;
665 tracing::info!(users = %(!af.is_empty()), "proxy SCRAM auth enabled");
666 Some(Arc::new(af))
667 } else {
668 None
669 };
670
671 let mirror = if config.mirror.enabled {
674 tracing::info!(target = %format!("{}:{}", config.mirror.backend_host, config.mirror.backend_port),
675 writes_only = config.mirror.writes_only, "traffic mirroring enabled");
676 Some(crate::mirror::spawn(config.mirror.clone()))
677 } else {
678 None
679 };
680
681 #[cfg(feature = "rate-limiting")]
683 let rate_limiter = if config.rate_limit.enabled {
684 let rl = &config.rate_limit;
685 tracing::info!(
686 qps = rl.default_qps,
687 burst = rl.default_burst,
688 key_by = ?rl.key_by,
689 "rate limiting enabled"
690 );
691 let rlc = crate::rate_limit::RateLimitConfig {
692 enabled: true,
693 default_qps: rl.default_qps,
694 default_burst: rl.default_burst,
695 default_concurrency: if rl.max_concurrent > 0 {
696 rl.max_concurrent
697 } else {
698 crate::rate_limit::RateLimitConfig::default().default_concurrency
699 },
700 ..Default::default()
701 };
702 Some(Arc::new(crate::rate_limit::RateLimiter::new(rlc)))
703 } else {
704 None
705 };
706
707 #[cfg(feature = "circuit-breaker")]
709 let circuit_breaker = if config.circuit_breaker.enabled {
710 let cb = &config.circuit_breaker;
711 tracing::info!(
712 failure_threshold = cb.failure_threshold,
713 open_secs = cb.open_secs,
714 "circuit breaker enabled"
715 );
716 let cbc = crate::circuit_breaker::CircuitBreakerConfig {
717 failure_threshold: cb.failure_threshold,
718 cooldown: Duration::from_secs(cb.open_secs),
719 half_open_success_threshold: cb.success_threshold,
720 ..Default::default()
721 };
722 let mgr = crate::circuit_breaker::CircuitBreakerManager::new(
723 crate::circuit_breaker::ManagerConfig::new(cbc),
724 );
725 Some(Arc::new(mgr))
726 } else {
727 None
728 };
729
730 #[cfg(feature = "query-analytics")]
732 let analytics = if config.analytics.enabled {
733 let a = &config.analytics;
734 tracing::info!(
735 slow_query_ms = a.slow_query_ms,
736 max_fingerprints = a.max_fingerprints,
737 "query analytics enabled"
738 );
739 let ac = crate::analytics::AnalyticsConfig {
740 enabled: true,
741 max_fingerprints: a.max_fingerprints as usize,
742 slow_query: crate::analytics::SlowQueryConfig {
743 threshold: Duration::from_millis(a.slow_query_ms),
744 ..Default::default()
745 },
746 ..Default::default()
747 };
748 Some(Arc::new(crate::analytics::QueryAnalytics::new(ac)))
749 } else {
750 None
751 };
752
753 #[cfg(feature = "query-cache")]
755 let query_cache = if config.cache.enabled {
756 let c = &config.cache;
757 tracing::info!(
758 ttl_secs = c.ttl_secs,
759 max_result_bytes = c.max_result_bytes,
760 "query cache enabled (L1 hot + L2 warm)"
761 );
762 let ttl = Duration::from_secs(c.ttl_secs);
763 let cc = crate::cache::CacheConfig {
764 enabled: true,
765 default_ttl: ttl,
766 max_result_size: c.max_result_bytes,
767 l1: crate::cache::L1Config {
768 ttl,
769 ..Default::default()
770 },
771 l2: crate::cache::L2Config {
772 ttl,
773 ..Default::default()
774 },
775 ..Default::default()
776 };
777 Some(Arc::new(crate::cache::QueryCache::new(cc)))
778 } else {
779 None
780 };
781
782 #[cfg(feature = "query-rewriting")]
784 let rewriter = if config.query_rewrite.enabled && !config.query_rewrite.rules.is_empty() {
785 use crate::rewriter::{
786 QueryPattern, QueryRewriter, RewriteRule, RewriterConfig, Transformation,
787 };
788 let rw = QueryRewriter::new(RewriterConfig {
789 enabled: true,
790 ..Default::default()
791 });
792 let mut n = 0usize;
793 for (i, r) in config.query_rewrite.rules.iter().enumerate() {
794 let transformation =
795 if let (Some(from), Some(to)) = (&r.match_table, &r.replace_table_with) {
796 Transformation::ReplaceTable {
797 from: from.clone(),
798 to: to.clone(),
799 }
800 } else if let Some(w) = &r.append_where {
801 Transformation::AppendWhereAnd(w.clone())
802 } else if let Some(limit) = r.add_limit {
803 Transformation::AddLimit(limit)
804 } else {
805 continue; };
807 let pattern = if let Some(t) = &r.match_table {
808 QueryPattern::Table(t.clone())
809 } else if let Some(re) = &r.match_regex {
810 QueryPattern::regex(re.clone())
811 } else {
812 QueryPattern::All
813 };
814 rw.add_rule(
815 RewriteRule::build(format!("rule-{i}"))
816 .pattern(pattern)
817 .transform(transformation)
818 .build(),
819 );
820 n += 1;
821 }
822 tracing::info!(rules = n, "query rewriting enabled");
823 Some(Arc::new(rw))
824 } else {
825 None
826 };
827
828 #[cfg(feature = "multi-tenancy")]
830 let tenant_manager =
831 if config.multi_tenancy.enabled && !config.multi_tenancy.tenants.is_empty() {
832 use crate::multi_tenancy::{
833 IdentificationMethod, IsolationStrategy, MultiTenancyConfig, TenantConfig,
834 TenantId, TenantManagerBuilder, TenantQueryTransformer,
835 };
836 let mt = &config.multi_tenancy;
837 let identification = match mt.identify_by.as_str() {
838 "database" => IdentificationMethod::DatabaseName,
839 param => IdentificationMethod::Header {
840 header_name: param.to_string(),
841 },
842 };
843 let mtc = MultiTenancyConfig {
844 enabled: true,
845 identification,
846 ..Default::default()
847 };
848 let table_refs: Vec<&str> = mt.tenant_tables.iter().map(|s| s.as_str()).collect();
850 let transformer = TenantQueryTransformer::new()
851 .register_tables(&table_refs, mt.tenant_column.clone());
852 let tm = TenantManagerBuilder::new()
853 .config(mtc)
854 .query_transformer(transformer)
855 .build();
856 for id in &mt.tenants {
857 tm.register_tenant(TenantConfig::new(
858 TenantId::new(id.clone()),
859 IsolationStrategy::row("public", mt.tenant_column.clone()),
860 ));
861 }
862 tracing::info!(
863 tenants = mt.tenants.len(),
864 identify_by = %mt.identify_by,
865 "multi-tenancy enabled"
866 );
867 Some(Arc::new(tm))
868 } else {
869 None
870 };
871
872 #[cfg(feature = "schema-routing")]
874 let schema_analyzer =
875 if config.schema_routing.enabled && !config.schema_routing.analytics_node.is_empty() {
876 tracing::info!(
877 analytics_node = %config.schema_routing.analytics_node,
878 "schema/workload routing enabled (OLAP -> analytics node)"
879 );
880 let registry = Arc::new(crate::schema_routing::SchemaRegistry::new());
881 Some(Arc::new(crate::schema_routing::QueryAnalyzer::new(
882 registry,
883 )))
884 } else {
885 None
886 };
887
888 let state = Arc::new(ServerState {
889 sessions: RwLock::new(HashMap::new()),
890 health: ArcSwap::from_pointee(health),
891 health_write: parking_lot::Mutex::new(()),
892 live_config: ArcSwap::from_pointee(config.clone()),
893 metrics: ServerMetrics::default(),
894 cancel_map: Arc::new(DashMap::new()),
895 cancel_order: Arc::new(parking_lot::Mutex::new(std::collections::VecDeque::new())),
896 tls_acceptor,
897 auth_file,
898 mirror,
899 cutover: Arc::new(ArcSwap::from_pointee(None)),
900 lb_state: LoadBalancerState {
901 rr_counter: AtomicU64::new(0),
902 },
903 #[cfg(feature = "routing-hints")]
904 hint_parser: if config.routing_hints.enabled {
905 tracing::info!(
906 strip = config.routing_hints.strip_hints,
907 "SQL-comment routing hints enabled"
908 );
909 Some(if config.routing_hints.strip_hints {
910 crate::routing::HintParser::new()
911 } else {
912 crate::routing::HintParser::without_stripping()
913 })
914 } else {
915 None
916 },
917 #[cfg(feature = "rate-limiting")]
918 rate_limiter,
919 #[cfg(feature = "circuit-breaker")]
920 circuit_breaker,
921 #[cfg(feature = "query-analytics")]
922 analytics,
923 #[cfg(feature = "query-cache")]
924 query_cache,
925 #[cfg(feature = "query-rewriting")]
926 rewriter,
927 #[cfg(feature = "multi-tenancy")]
928 tenant_manager,
929 #[cfg(feature = "schema-routing")]
930 schema_analyzer,
931 #[cfg(feature = "pool-modes")]
932 pool_manager,
933 #[cfg(feature = "pool-modes")]
934 backend_pool,
935 #[cfg(feature = "wasm-plugins")]
936 plugin_manager,
937 #[cfg(feature = "ha-tr")]
938 transaction_journal: Arc::new(crate::transaction_journal::TransactionJournal::new()),
939 #[cfg(feature = "anomaly-detection")]
940 anomaly_detector: Arc::new(crate::anomaly::AnomalyDetector::new(
941 crate::anomaly::AnomalyConfig::default(),
942 )),
943 #[cfg(feature = "edge-proxy")]
944 edge_cache: Arc::new(crate::edge::EdgeCache::new(10_000)),
945 #[cfg(feature = "edge-proxy")]
946 edge_registry: Arc::new(crate::edge::EdgeRegistry::new(
947 32,
948 std::time::Duration::from_secs(120),
949 )),
950 });
951
952 Ok(Self {
953 config,
954 state,
955 shutdown_tx,
956 config_path: None,
957 })
958 }
959
960 pub fn with_config_path(mut self, path: Option<String>) -> Self {
964 self.config_path = path;
965 self
966 }
967
968 #[cfg(unix)]
971 fn hangup_stream() -> tokio::signal::unix::Signal {
972 tokio::signal::unix::signal(tokio::signal::unix::SignalKind::hangup())
973 .expect("failed to install SIGHUP handler")
974 }
975 #[cfg(not(unix))]
976 fn hangup_stream() -> HangupNever {
977 HangupNever
978 }
979
980 #[cfg(unix)]
983 fn usr2_stream() -> tokio::signal::unix::Signal {
984 tokio::signal::unix::signal(tokio::signal::unix::SignalKind::user_defined2())
985 .expect("failed to install SIGUSR2 handler")
986 }
987 #[cfg(not(unix))]
988 fn usr2_stream() -> HangupNever {
989 HangupNever
990 }
991
992 async fn drain_connections(state: &Arc<ServerState>, timeout: Duration) {
996 let deadline = tokio::time::Instant::now() + timeout;
997 loop {
998 let active = state.sessions.read().await.len();
999 if active == 0 {
1000 tracing::info!("drain complete — all in-flight connections finished");
1001 return;
1002 }
1003 if tokio::time::Instant::now() >= deadline {
1004 tracing::warn!(
1005 active,
1006 "drain timeout reached — exiting with connections still open"
1007 );
1008 return;
1009 }
1010 tokio::time::sleep(Duration::from_millis(200)).await;
1011 }
1012 }
1013
1014 fn drain_timeout(config_secs: u64) -> Duration {
1019 let secs = std::env::var("HELIOS_DRAIN_TIMEOUT_SECS")
1020 .ok()
1021 .and_then(|s| s.parse::<u64>().ok())
1022 .unwrap_or(config_secs);
1023 Duration::from_secs(secs)
1024 }
1025
1026 async fn reload_config(&self) {
1035 let Some(path) = self.config_path.as_deref() else {
1036 tracing::warn!(
1037 "SIGHUP received but config was not loaded from a file — nothing to reload"
1038 );
1039 return;
1040 };
1041 tracing::info!(path, "SIGHUP: reloading configuration");
1042 let new_config = match ProxyConfig::from_file(path) {
1043 Ok(c) => c,
1044 Err(e) => {
1045 tracing::error!(path, error = %e, "SIGHUP reload failed to parse — keeping current config");
1046 return;
1047 }
1048 };
1049 let old = self.state.live_config.load_full();
1050 if new_config.listen_address != old.listen_address {
1051 tracing::warn!(old = %old.listen_address, new = %new_config.listen_address,
1052 "listen_address change needs a restart/handoff; the bound socket is kept");
1053 }
1054 if new_config.admin_address != old.admin_address {
1055 tracing::warn!(old = %old.admin_address, new = %new_config.admin_address,
1056 "admin_address change needs a restart; the bound socket is kept");
1057 }
1058 Self::reconcile_health(&self.state, &new_config);
1061 let nodes = new_config.nodes.len();
1062 let hba_rules = new_config.hba.len();
1063 let pool_max = new_config.pool.max_connections;
1064 self.state.live_config.store(Arc::new(new_config));
1065 tracing::info!(
1066 nodes,
1067 hba_rules,
1068 pool_max,
1069 "SIGHUP: configuration reloaded — applies to new connections"
1070 );
1071 }
1072
1073 fn reconcile_health(state: &Arc<ServerState>, config: &ProxyConfig) {
1077 let _writers = state.health_write.lock();
1080 let current = state.health.load_full();
1081 let mut next: HashMap<String, NodeHealth> = HashMap::new();
1082 for node in &config.nodes {
1083 let addr = node.address();
1084 match current.get(&addr) {
1085 Some(existing) => {
1086 next.insert(addr, existing.clone());
1087 }
1088 None => {
1089 tracing::info!(node = %addr, "SIGHUP: new node added — seeding healthy");
1090 next.insert(
1091 addr.clone(),
1092 NodeHealth {
1093 address: addr,
1094 healthy: true,
1095 last_check: chrono::Utc::now(),
1096 failure_count: 0,
1097 last_error: None,
1098 latency_ms: 0.0,
1099 replication_lag_bytes: None,
1100 },
1101 );
1102 }
1103 }
1104 }
1105 for gone in current.keys().filter(|k| !next.contains_key(*k)) {
1106 tracing::info!(node = %gone, "SIGHUP: node removed from config");
1107 }
1108 state.health.store(Arc::new(next));
1109 }
1110
1111 pub async fn run(&self) -> Result<()> {
1113 let listener = bind_reuseport(&self.config.listen_address)?;
1119
1120 tracing::info!(
1121 "Proxy listening on {} (SO_REUSEPORT)",
1122 self.config.listen_address
1123 );
1124
1125 let health_task = self.spawn_health_checker();
1127 let pool_task = self.spawn_pool_manager();
1128
1129 let admin_task = self.spawn_admin_server();
1131
1132 let mcp_task = if self.config.mcp.enabled {
1134 let mcp_cfg = self.config.mcp.clone();
1135 let contract = mcp_cfg.contract.as_ref().and_then(|id| {
1137 let found = self.config.agent_contracts.iter().find(|c| &c.id == id).cloned();
1138 if found.is_none() {
1139 tracing::warn!(%id, "mcp.contract names an unknown agent_contract; gateway runs with only the read-only guardrail");
1140 }
1141 found
1142 });
1143 Some(tokio::spawn(async move {
1144 if let Err(e) = crate::mcp::McpServer::new(mcp_cfg, contract).run().await {
1145 tracing::error!("MCP gateway error: {}", e);
1146 }
1147 }))
1148 } else {
1149 None
1150 };
1151
1152 let http_gw_task = if self.config.http_gateway.enabled {
1154 let gw_cfg = self.config.http_gateway.clone();
1155 Some(tokio::spawn(async move {
1156 if let Err(e) = crate::http_gateway::HttpGateway::new(gw_cfg).run().await {
1157 tracing::error!("HTTP gateway error: {}", e);
1158 }
1159 }))
1160 } else {
1161 None
1162 };
1163
1164 #[cfg(feature = "graphql-gateway")]
1166 let _graphql_gw_task = if self.config.graphql_gateway.enabled {
1167 let gw_cfg = self.config.graphql_gateway.clone();
1168 Some(tokio::spawn(async move {
1169 if let Err(e) = crate::graphql_gateway::GraphqlGateway::new(gw_cfg)
1170 .run()
1171 .await
1172 {
1173 tracing::error!("GraphQL gateway error: {}", e);
1174 }
1175 }))
1176 } else {
1177 None
1178 };
1179
1180 let mut shutdown_rx = self.shutdown_tx.subscribe();
1181
1182 let mut sighup = Self::hangup_stream();
1186 let mut sigusr2 = Self::usr2_stream();
1187 let mut graceful = false;
1188
1189 loop {
1190 tokio::select! {
1191 _ = sighup.recv() => {
1192 self.reload_config().await;
1193 }
1194 _ = sigusr2.recv() => {
1195 tracing::info!(
1196 "SIGUSR2: graceful binary-handoff drain — closing the listener so new \
1197 connections route to the sibling process; finishing in-flight connections"
1198 );
1199 graceful = true;
1200 break;
1201 }
1202 accept_result = listener.accept() => {
1203 match accept_result {
1204 Ok((stream, addr)) => {
1205 let _ = stream.set_nodelay(true);
1209 self.state.metrics.connections_accepted.fetch_add(1, Ordering::Relaxed);
1210 let state = self.state.clone();
1211 let config = (*self.state.live_config.load_full()).clone();
1215 let shutdown_tx = self.shutdown_tx.clone();
1216
1217 tokio::spawn(async move {
1218 if let Err(e) = Self::handle_client(stream, addr, state, config, shutdown_tx).await {
1219 tracing::error!("Client handler error: {}", e);
1220 }
1221 });
1222 }
1223 Err(e) => {
1224 tracing::error!("Accept error: {}", e);
1225 }
1226 }
1227 }
1228 _ = shutdown_rx.recv() => {
1229 tracing::info!("Shutdown signal received");
1230 break;
1231 }
1232 }
1233 }
1234
1235 drop(listener);
1239
1240 if graceful {
1243 let timeout =
1244 Self::drain_timeout(self.state.live_config.load().shutdown_drain_timeout_secs);
1245 tracing::info!(
1246 timeout_secs = timeout.as_secs(),
1247 "draining in-flight connections"
1248 );
1249 Self::drain_connections(&self.state, timeout).await;
1250 }
1251
1252 health_task.abort();
1254 pool_task.abort();
1255 admin_task.abort();
1256 if let Some(t) = mcp_task {
1257 t.abort();
1258 }
1259 if let Some(t) = http_gw_task {
1260 t.abort();
1261 }
1262
1263 Ok(())
1264 }
1265
1266 fn spawn_admin_server(&self) -> tokio::task::JoinHandle<()> {
1268 let config = self.config.clone();
1269 let state = self.state.clone();
1270 let mut shutdown_rx = self.shutdown_tx.subscribe();
1271
1272 tokio::spawn(async move {
1273 let admin_state = Arc::new(AdminState::new());
1275
1276 {
1278 let mut snapshot = admin_state.config_snapshot.write().await;
1279 *snapshot = ConfigSnapshot {
1280 listen_address: config.listen_address.clone(),
1281 admin_address: config.admin_address.clone(),
1282 tr_enabled: config.tr_enabled,
1283 tr_mode: format!("{:?}", config.tr_mode),
1284 pool_min_connections: config.pool.min_connections,
1285 pool_max_connections: config.pool.max_connections,
1286 nodes: config
1287 .nodes
1288 .iter()
1289 .map(|n| NodeSnapshot {
1290 address: n.address(),
1291 role: format!("{:?}", n.role),
1292 weight: n.weight,
1293 enabled: n.enabled,
1294 })
1295 .collect(),
1296 };
1297 }
1298
1299 admin_state.set_proxy_config(config.clone()).await;
1301
1302 admin_state
1304 .with_auth_token(config.admin_token.clone())
1305 .await;
1306
1307 if config.branch.enabled {
1309 admin_state.with_branch(config.branch.clone()).await;
1310 }
1311
1312 if let Some(ref mirror) = state.mirror {
1314 admin_state
1315 .with_migration(crate::admin::MigrationInfo {
1316 target: mirror.target().to_string(),
1317 writes_only: mirror.writes_only(),
1318 metrics: mirror.metrics.clone(),
1319 config: config.mirror.clone(),
1320 cutover: state.cutover.clone(),
1321 cutover_target: crate::mirror::CutoverTarget {
1322 addr: format!(
1323 "{}:{}",
1324 config.mirror.backend_host, config.mirror.backend_port
1325 ),
1326 user: config.mirror.backend_user.clone(),
1327 password: config.mirror.backend_password.clone(),
1328 database: config.mirror.backend_database.clone(),
1329 },
1330 })
1331 .await;
1332 }
1333
1334 #[cfg(feature = "wasm-plugins")]
1339 if let Some(ref pm) = state.plugin_manager {
1340 admin_state.with_plugin_manager(pm.clone()).await;
1341 }
1342
1343 #[cfg(feature = "pool-modes")]
1346 if let Some(ref pm) = state.pool_manager {
1347 admin_state.with_pool_manager(pm.clone()).await;
1348 }
1349
1350 #[cfg(feature = "circuit-breaker")]
1353 if let Some(ref cb) = state.circuit_breaker {
1354 admin_state.with_circuit_breaker(cb.clone()).await;
1355 }
1356
1357 #[cfg(feature = "ha-tr")]
1364 {
1365 let template = build_replay_backend_template(&config);
1366 let engine = Arc::new(crate::replay::ReplayEngine::new(
1367 state.transaction_journal.clone(),
1368 template,
1369 ));
1370 admin_state.with_replay_engine(engine).await;
1371 }
1372
1373 #[cfg(feature = "anomaly-detection")]
1377 admin_state
1378 .with_anomaly_detector(state.anomaly_detector.clone())
1379 .await;
1380
1381 #[cfg(feature = "query-analytics")]
1383 if let Some(a) = state.analytics.as_ref() {
1384 admin_state.with_analytics(a.clone()).await;
1385 }
1386
1387 #[cfg(feature = "edge-proxy")]
1390 admin_state
1391 .with_edge(state.edge_cache.clone(), state.edge_registry.clone())
1392 .await;
1393
1394 let admin_server = AdminServer::new(config.admin_address.clone(), admin_state.clone());
1396
1397 let admin_state_sync = admin_state.clone();
1399 let server_state = state.clone();
1400 let sync_task = tokio::spawn(async move {
1401 let mut interval = tokio::time::interval(std::time::Duration::from_secs(1));
1402 loop {
1403 interval.tick().await;
1404
1405 {
1407 let health = server_state.health.load_full();
1408 let mut admin_health = admin_state_sync.node_health.write().await;
1409 *admin_health = (*health).clone();
1410 }
1411
1412 {
1414 let metrics = ServerMetricsSnapshot {
1415 connections_accepted: server_state
1416 .metrics
1417 .connections_accepted
1418 .load(Ordering::Relaxed),
1419 connections_closed: server_state
1420 .metrics
1421 .connections_closed
1422 .load(Ordering::Relaxed),
1423 queries_processed: server_state
1424 .metrics
1425 .queries_processed
1426 .load(Ordering::Relaxed),
1427 bytes_received: server_state
1428 .metrics
1429 .bytes_received
1430 .load(Ordering::Relaxed),
1431 bytes_sent: server_state.metrics.bytes_sent.load(Ordering::Relaxed),
1432 failovers: server_state.metrics.failovers.load(Ordering::Relaxed),
1433 };
1434 let mut admin_metrics = admin_state_sync.metrics.write().await;
1435 *admin_metrics = metrics;
1436 }
1437
1438 {
1440 let sessions = server_state.sessions.read().await;
1441 let mut admin_sessions = admin_state_sync.active_sessions.write().await;
1442 *admin_sessions = sessions.len() as u64;
1443 }
1444 }
1445 });
1446
1447 tokio::select! {
1449 result = admin_server.run() => {
1450 if let Err(e) = result {
1451 tracing::error!("Admin server error: {}", e);
1452 }
1453 }
1454 _ = shutdown_rx.recv() => {
1455 tracing::info!("Admin server shutting down");
1456 }
1457 }
1458
1459 sync_task.abort();
1460 })
1461 }
1462
1463 async fn handle_client(
1465 stream: TcpStream,
1466 addr: SocketAddr,
1467 state: Arc<ServerState>,
1468 config: ProxyConfig,
1469 _shutdown_tx: broadcast::Sender<()>,
1470 ) -> Result<()> {
1471 tracing::debug!("New client connection from {}", addr);
1472
1473 let session = Arc::new(ClientSession {
1475 id: Uuid::new_v4(),
1476 client_addr: addr,
1477 current_node: RwLock::new(None),
1478 tx_state: RwLock::new(TransactionState::default()),
1479 variables: RwLock::new(HashMap::new()),
1480 created_at: chrono::Utc::now(),
1481 tr_mode: config.tr_mode,
1482 #[cfg(feature = "lag-routing")]
1483 last_write_at: RwLock::new(None),
1484 #[cfg(feature = "pool-modes")]
1485 pool_client_id: ClientId::new(),
1486 #[cfg(feature = "wasm-plugins")]
1487 plugin_identity: RwLock::new(None),
1488 });
1489
1490 {
1492 let mut sessions = state.sessions.write().await;
1493 sessions.insert(session.id, session.clone());
1494 }
1495
1496 let result = match Self::negotiate_client_tls(stream, &state).await {
1501 Ok((mut client_stream, pre)) => {
1502 Self::client_loop(&mut client_stream, pre, &session, &state, &config).await
1503 }
1504 Err(e) => Err(e),
1505 };
1506
1507 {
1509 let mut sessions = state.sessions.write().await;
1510 sessions.remove(&session.id);
1511 }
1512
1513 #[cfg(feature = "pool-modes")]
1515 if let Some(ref pool_manager) = state.pool_manager {
1516 if pool_manager.has_active_lease(&session.pool_client_id) {
1518 tracing::debug!(
1519 "Releasing pool lease for disconnecting client {:?}",
1520 session.pool_client_id
1521 );
1522 }
1525 }
1526
1527 state
1528 .metrics
1529 .connections_closed
1530 .fetch_add(1, Ordering::Relaxed);
1531
1532 result
1533 }
1534
1535 async fn client_loop(
1537 stream: &mut ClientStream,
1538 pre: Option<StartupMessage>,
1539 session: &Arc<ClientSession>,
1540 state: &Arc<ServerState>,
1541 config: &ProxyConfig,
1542 ) -> Result<()> {
1543 let codec = ProtocolCodec::new();
1544 let mut buffer = BytesMut::with_capacity(8192);
1545
1546 let mut conns: HashMap<String, BackendConn> = HashMap::new();
1563 let mut current_node: Option<String> =
1564 match Self::handle_startup(stream, &mut buffer, &codec, pre, session, state, config)
1565 .await
1566 {
1567 Ok((Some(stream_conn), node_addr)) => {
1568 conns.insert(node_addr.clone(), BackendConn::new(stream_conn));
1569 Some(node_addr)
1570 }
1571 Ok((None, _)) => {
1572 return Ok(());
1574 }
1575 Err(e) => {
1576 tracing::error!("Startup failed: {}", e);
1577 let err_msg =
1579 Self::create_error_response("08006", &format!("Startup failed: {}", e));
1580 let _ = stream.write_all(&err_msg).await;
1581 return Err(e);
1582 }
1583 };
1584
1585 let mut read_buf = vec![0u8; 16384];
1600 let mut pending = BytesMut::new();
1601 let mut pending_route_sql: Option<String> = None;
1602 let mut stmt_registry: HashMap<String, bytes::Bytes> = HashMap::new();
1610 let mut stmt_registry_bytes: usize = 0;
1613 let mut batch_defines: Vec<String> = Vec::new();
1614 let mut batch_refs: Vec<String> = Vec::new();
1615 let mut batch_closes: Vec<String> = Vec::new();
1616 let promote_unnamed = config.optimize_unnamed_parse;
1622 let mut held_unnamed: Option<(bytes::Bytes, bytes::Bytes)> = None;
1623 loop {
1624 let n = stream
1626 .read(&mut read_buf)
1627 .await
1628 .map_err(|e| ProxyError::Network(format!("Read error: {}", e)))?;
1629
1630 if n == 0 {
1631 break;
1633 }
1634
1635 buffer.extend_from_slice(&read_buf[..n]);
1636 state
1637 .metrics
1638 .bytes_received
1639 .fetch_add(n as u64, Ordering::Relaxed);
1640
1641 if buffer.len() > Self::MAX_PENDING_BYTES {
1645 let emsg = Self::create_error_response(
1646 "53400",
1647 "message exceeds per-session size limit",
1648 );
1649 let _ = stream.write_all(&emsg).await;
1650 let _ = stream.write_all(&Self::create_ready_for_query(b'I')).await;
1651 tracing::warn!(
1652 client = %session.client_addr,
1653 bytes = buffer.len(),
1654 "inbound message exceeds size cap; closing connection"
1655 );
1656 return Ok(());
1657 }
1658
1659 while let Some(msg) = codec.decode_message(&mut buffer)? {
1661 match msg.msg_type {
1662 MessageType::Terminate => return Ok(()),
1663
1664 MessageType::Query => {
1666 #[cfg(feature = "anomaly-detection")]
1670 Self::record_anomaly_observation(&msg, state, session);
1671
1672 let (msg, action) = Self::apply_pre_query_hook(msg, state, session);
1675
1676 if let PreQueryAction::Block(reason) = &action {
1677 tracing::info!(reason = %reason, "pre-query plugin blocked query");
1678 Self::send_block_response(stream, reason, state).await?;
1679 state
1680 .metrics
1681 .queries_processed
1682 .fetch_add(1, Ordering::Relaxed);
1683 continue;
1684 }
1685
1686 #[cfg(feature = "wasm-plugins")]
1687 if let PreQueryAction::Cached(bytes) = &action {
1688 match Self::synthesise_cached_response(bytes) {
1689 Ok(reply) => {
1690 stream.write_all(&reply).await.map_err(|e| {
1691 ProxyError::Network(format!("Write error: {}", e))
1692 })?;
1693 state
1694 .metrics
1695 .bytes_sent
1696 .fetch_add(reply.len() as u64, Ordering::Relaxed);
1697 state
1698 .metrics
1699 .queries_processed
1700 .fetch_add(1, Ordering::Relaxed);
1701 continue;
1702 }
1703 Err(e) => {
1704 tracing::warn!(error = %e, "failed to synthesise cached response; falling back to backend");
1705 }
1706 }
1707 }
1708
1709 if let Some(ref mirror) = state.mirror {
1713 if let Some(sql) = crate::protocol::query_text(&msg.payload) {
1714 mirror.offer(sql, Self::is_write_query(sql));
1715 }
1716 }
1717
1718 #[cfg(feature = "wasm-plugins")]
1719 let forward_start = std::time::Instant::now();
1720 let fr = Self::forward_simple_query(
1721 stream,
1722 &msg,
1723 &mut conns,
1724 current_node.as_deref(),
1725 session,
1726 state,
1727 config,
1728 )
1729 .await;
1730 #[cfg(feature = "wasm-plugins")]
1731 Self::fire_post_query_hook(
1732 &msg,
1733 session,
1734 state,
1735 &fr,
1736 forward_start.elapsed(),
1737 );
1738 let (used_node, sent) = fr?;
1739 if let Some(n) = used_node {
1740 current_node = Some(n);
1741 }
1742 #[cfg(feature = "pool-modes")]
1745 Self::release_to_pool_if_idle(
1746 &mut conns,
1747 current_node.as_deref(),
1748 session,
1749 state,
1750 config,
1751 )
1752 .await;
1753 state.metrics.bytes_sent.fetch_add(sent, Ordering::Relaxed);
1754 state
1755 .metrics
1756 .queries_processed
1757 .fetch_add(1, Ordering::Relaxed);
1758 }
1759
1760 MessageType::Parse
1762 | MessageType::Bind
1763 | MessageType::Describe
1764 | MessageType::Execute
1765 | MessageType::Close => {
1766 let mut add_to_pending = true;
1770 match msg.msg_type {
1771 MessageType::Parse => {
1772 let name = Self::parse_stmt_name(&msg.payload);
1776 let unnamed = name.is_empty();
1777 if !unnamed {
1778 let name = name.to_string();
1779 let existed = stmt_registry.contains_key(&name);
1780 if !existed
1784 && stmt_registry.len() >= Self::MAX_PREPARED_STATEMENTS
1785 {
1786 let emsg = Self::create_error_response(
1787 "54000",
1788 "too many prepared statements for this session",
1789 );
1790 let _ = stream.write_all(&emsg).await;
1791 let _ = stream
1792 .write_all(&Self::create_ready_for_query(b'I'))
1793 .await;
1794 tracing::warn!(
1795 client = %session.client_addr,
1796 limit = Self::MAX_PREPARED_STATEMENTS,
1797 "prepared-statement cap exceeded; closing connection"
1798 );
1799 return Ok(());
1800 }
1801 let encoded = msg.encode().freeze();
1802 let old_len = stmt_registry
1806 .get(&name)
1807 .map(|b| b.len())
1808 .unwrap_or(0);
1809 let projected = stmt_registry_bytes
1810 .saturating_sub(old_len)
1811 + encoded.len();
1812 if projected > Self::MAX_PREPARED_BYTES {
1813 let emsg = Self::create_error_response(
1814 "54000",
1815 "prepared-statement memory limit exceeded for this session",
1816 );
1817 let _ = stream.write_all(&emsg).await;
1818 let _ = stream
1819 .write_all(&Self::create_ready_for_query(b'I'))
1820 .await;
1821 tracing::warn!(
1822 client = %session.client_addr,
1823 limit = Self::MAX_PREPARED_BYTES,
1824 "prepared-statement byte cap exceeded; closing connection"
1825 );
1826 return Ok(());
1827 }
1828 stmt_registry.insert(name.clone(), encoded);
1829 stmt_registry_bytes = projected;
1830 batch_defines.push(name);
1831 }
1832 if pending_route_sql.is_none() {
1833 if let Some(end) = msg.payload.iter().position(|&b| b == 0) {
1834 if let Some(q) =
1835 crate::protocol::query_text(&msg.payload[end + 1..])
1836 {
1837 if !q.is_empty() {
1838 pending_route_sql = Some(q.to_string());
1839 #[cfg(feature = "anomaly-detection")]
1840 Self::record_anomaly_sql(q, state, session);
1841 }
1842 }
1843 }
1844 }
1845 if promote_unnamed
1852 && unnamed
1853 && pending.is_empty()
1854 && held_unnamed.is_none()
1855 {
1856 let sig = bytes::Bytes::copy_from_slice(&msg.payload[1..]);
1857 held_unnamed = Some((msg.encode().freeze(), sig));
1858 add_to_pending = false;
1859 } else if let Some((held_msg, _)) = held_unnamed.take() {
1860 let mut combined =
1861 BytesMut::with_capacity(held_msg.len() + pending.len());
1862 combined.extend_from_slice(&held_msg);
1863 combined.extend_from_slice(&pending);
1864 pending = combined;
1865 }
1866 }
1867 MessageType::Bind => {
1868 if let Some(name) = Self::bind_stmt_ref(&msg.payload) {
1869 batch_refs.push(name.to_string());
1870 }
1871 }
1872 MessageType::Describe => {
1873 if let Some(name) = Self::stmt_kind_name(&msg.payload) {
1874 batch_refs.push(name.to_string());
1875 }
1876 }
1877 MessageType::Close => {
1878 if let Some(name) = Self::stmt_kind_name(&msg.payload) {
1879 batch_closes.push(name.to_string());
1880 }
1881 }
1882 _ => {}
1883 }
1884 if add_to_pending {
1885 pending.extend_from_slice(&msg.encode());
1886 }
1887 }
1888
1889 MessageType::Sync | MessageType::Flush => {
1891 let wait_ready = msg.msg_type == MessageType::Sync;
1892 pending.extend_from_slice(&msg.encode());
1893 let batch = pending.split().freeze();
1894 let reprepare: Vec<String> = batch_refs
1898 .iter()
1899 .filter(|r| !batch_defines.contains(r))
1900 .cloned()
1901 .collect();
1902 let (used_node, sent) = Self::forward_extended_batch(
1903 stream,
1904 &batch,
1905 pending_route_sql.as_deref(),
1906 wait_ready,
1907 &mut conns,
1908 current_node.as_deref(),
1909 &stmt_registry,
1910 &reprepare,
1911 &batch_defines,
1912 held_unnamed.take(),
1913 session,
1914 state,
1915 config,
1916 )
1917 .await?;
1918 if let Some(n) = used_node {
1919 current_node = Some(n);
1920 }
1921 #[cfg(feature = "pool-modes")]
1925 if wait_ready {
1926 Self::release_to_pool_if_idle(
1927 &mut conns,
1928 current_node.as_deref(),
1929 session,
1930 state,
1931 config,
1932 )
1933 .await;
1934 }
1935 state.metrics.bytes_sent.fetch_add(sent, Ordering::Relaxed);
1936 for name in batch_closes.drain(..) {
1939 if let Some(removed) = stmt_registry.remove(&name) {
1940 stmt_registry_bytes =
1941 stmt_registry_bytes.saturating_sub(removed.len());
1942 }
1943 }
1944 if wait_ready {
1945 pending_route_sql = None;
1949 batch_defines.clear();
1950 batch_refs.clear();
1951 state
1952 .metrics
1953 .queries_processed
1954 .fetch_add(1, Ordering::Relaxed);
1955 }
1956 }
1957
1958 MessageType::CopyData | MessageType::CopyDone | MessageType::CopyFail => {
1960 if let Some(node) = current_node.clone() {
1961 if let Some(b) = conns.get_mut(&node) {
1962 b.stream.write_all(&msg.encode()).await.map_err(|e| {
1963 ProxyError::Network(format!("Backend copy write error: {}", e))
1964 })?;
1965 if matches!(
1966 msg.msg_type,
1967 MessageType::CopyDone | MessageType::CopyFail
1968 ) {
1969 let r = Self::stream_until_ready(
1970 stream,
1971 &mut b.stream,
1972 session,
1973 state,
1974 )
1975 .await;
1976 match r {
1977 Ok(sent) => {
1978 state
1979 .metrics
1980 .bytes_sent
1981 .fetch_add(sent, Ordering::Relaxed);
1982 }
1983 Err(e) => {
1984 conns.remove(&node);
1985 return Err(e);
1986 }
1987 }
1988 }
1989 }
1990 }
1991 }
1992
1993 _ => {
1995 if let Some(ref node) = current_node {
1996 if let Some(b) = conns.get_mut(node) {
1997 let _ = b.stream.write_all(&msg.encode()).await;
1998 }
1999 }
2000 }
2001 }
2002 }
2003
2004 if pending.len() > Self::MAX_PENDING_BYTES {
2008 let emsg = Self::create_error_response(
2009 "53400",
2010 "un-flushed extended-protocol buffer exceeds per-session limit",
2011 );
2012 let _ = stream.write_all(&emsg).await;
2013 let _ = stream.write_all(&Self::create_ready_for_query(b'I')).await;
2014 tracing::warn!(
2015 client = %session.client_addr,
2016 pending = pending.len(),
2017 "pending extended-protocol buffer cap exceeded; closing connection"
2018 );
2019 return Ok(());
2020 }
2021 }
2022
2023 #[cfg(feature = "pool-modes")]
2027 if state.backend_pool.is_some() {
2028 let nodes: Vec<String> = conns.keys().cloned().collect();
2029 for node in nodes {
2030 Self::release_to_pool_if_idle(
2031 &mut conns,
2032 Some(node.as_str()),
2033 session,
2034 state,
2035 config,
2036 )
2037 .await;
2038 }
2039 }
2040
2041 Ok(())
2042 }
2043
2044 async fn negotiate_client_tls(
2051 mut tcp: TcpStream,
2052 state: &Arc<ServerState>,
2053 ) -> Result<(ClientStream, Option<StartupMessage>)> {
2054 let codec = ProtocolCodec::new();
2055 let mut buffer = BytesMut::with_capacity(1024);
2056 let mut read_buf = vec![0u8; 1024];
2057
2058 let first = loop {
2059 if let Some(msg) = codec.decode_startup(&mut buffer)? {
2060 break msg;
2061 }
2062 let n = tcp
2063 .read(&mut read_buf)
2064 .await
2065 .map_err(|e| ProxyError::Network(format!("Startup read error: {}", e)))?;
2066 if n == 0 {
2067 return Err(ProxyError::Connection(
2068 "client closed before startup".to_string(),
2069 ));
2070 }
2071 buffer.extend_from_slice(&read_buf[..n]);
2072 };
2073
2074 match first {
2075 StartupMessage::SSLRequest => match state.tls_acceptor.as_ref() {
2076 Some(acceptor) => {
2077 tcp.write_all(b"S")
2078 .await
2079 .map_err(|e| ProxyError::Network(format!("SSL accept write: {}", e)))?;
2080 let tls = acceptor
2081 .accept(tcp)
2082 .await
2083 .map_err(|e| ProxyError::Network(format!("TLS handshake failed: {}", e)))?;
2084 if tls.get_ref().1.peer_certificates().is_some() {
2085 tracing::debug!("client presented a certificate (mTLS)");
2086 }
2087 Ok((ClientStream::Tls(Box::new(tls)), None))
2088 }
2089 None => {
2090 tcp.write_all(b"N")
2091 .await
2092 .map_err(|e| ProxyError::Network(format!("SSL reject write: {}", e)))?;
2093 Ok((ClientStream::Plain(tcp), None))
2094 }
2095 },
2096 other => Ok((ClientStream::Plain(tcp), Some(other))),
2097 }
2098 }
2099
2100 async fn handle_startup(
2104 client_stream: &mut ClientStream,
2105 buffer: &mut BytesMut,
2106 codec: &ProtocolCodec,
2107 pre: Option<StartupMessage>,
2108 session: &Arc<ClientSession>,
2109 state: &Arc<ServerState>,
2110 config: &ProxyConfig,
2111 ) -> Result<(Option<TcpStream>, String)> {
2112 let startup_msg = match pre {
2115 Some(msg) => Some(msg),
2116 None => {
2117 let mut read_buf = vec![0u8; 1024];
2118 loop {
2119 if let Some(msg) = codec.decode_startup(buffer)? {
2120 break Some(msg);
2121 }
2122 let n = client_stream
2123 .read(&mut read_buf)
2124 .await
2125 .map_err(|e| ProxyError::Network(format!("Startup read error: {}", e)))?;
2126 if n == 0 {
2127 return Ok((None, String::new()));
2128 }
2129 buffer.extend_from_slice(&read_buf[..n]);
2130 }
2131 }
2132 };
2133
2134 match startup_msg {
2135 Some(StartupMessage::SSLRequest) => {
2136 client_stream
2139 .write_all(b"N")
2140 .await
2141 .map_err(|e| ProxyError::Network(format!("SSL reject error: {}", e)))?;
2142 Err(ProxyError::Protocol(
2143 "unexpected SSLRequest after startup".to_string(),
2144 ))
2145 }
2146 Some(StartupMessage::CancelRequest { pid, key }) => {
2147 Self::forward_cancel_request(state, pid, key).await;
2150 Ok((None, String::new()))
2151 }
2152 Some(StartupMessage::Startup { params, .. }) => {
2153 Self::connect_and_authenticate(client_stream, ¶ms, session, state, config).await
2154 }
2155 None => Err(ProxyError::Protocol(
2156 "Incomplete startup message".to_string(),
2157 )),
2158 }
2159 }
2160
2161 fn hba_admits(rules: &[HbaRule], ip: std::net::IpAddr, user: &str, database: &str) -> bool {
2164 for r in rules {
2165 let user_ok = r.user == "all" || r.user == user;
2166 let db_ok = r.database == "all" || r.database == database;
2167 if user_ok && db_ok && Self::hba_addr_matches(&r.address, ip) {
2168 return r.action == HbaAction::Allow;
2169 }
2170 }
2171 true
2172 }
2173
2174 fn hba_addr_matches(spec: &str, ip: std::net::IpAddr) -> bool {
2177 use std::net::IpAddr;
2178 if spec == "all" {
2179 return true;
2180 }
2181 if let Some((net, bits)) = spec.split_once('/') {
2182 let bits: u32 = match bits.parse() {
2183 Ok(b) => b,
2184 Err(_) => return false,
2185 };
2186 match (net.parse::<IpAddr>(), ip) {
2187 (Ok(IpAddr::V4(n)), IpAddr::V4(i)) if bits <= 32 => {
2188 let mask = if bits == 0 {
2189 0
2190 } else {
2191 u32::MAX << (32 - bits)
2192 };
2193 (u32::from(n) & mask) == (u32::from(i) & mask)
2194 }
2195 (Ok(IpAddr::V6(n)), IpAddr::V6(i)) if bits <= 128 => {
2196 let mask = if bits == 0 {
2197 0
2198 } else {
2199 u128::MAX << (128 - bits)
2200 };
2201 (u128::from(n) & mask) == (u128::from(i) & mask)
2202 }
2203 _ => false,
2204 }
2205 } else {
2206 spec.parse::<IpAddr>().map(|s| s == ip).unwrap_or(false)
2207 }
2208 }
2209
2210 async fn proxy_scram_auth(
2216 client: &mut ClientStream,
2217 user: &str,
2218 state: &Arc<ServerState>,
2219 ) -> std::result::Result<(), String> {
2220 use crate::auth_scram::ScramServer;
2221 let auth_file = state.auth_file.as_ref().ok_or("scram not configured")?;
2222
2223 let mut sasl = BytesMut::new();
2225 sasl.put_i32(10); sasl.extend_from_slice(b"SCRAM-SHA-256\0");
2227 sasl.put_u8(0); Self::write_auth_frame(client, &sasl).await?;
2229
2230 let init = Self::read_password_message(client).await?;
2232 let mech_end = init
2233 .iter()
2234 .position(|&b| b == 0)
2235 .ok_or("malformed SASLInitialResponse (no mechanism)")?;
2236 if init.len() < mech_end + 5 {
2237 return Err("short SASLInitialResponse".into());
2238 }
2239 let client_first =
2240 std::str::from_utf8(&init[mech_end + 5..]).map_err(|_| "client-first not UTF-8")?;
2241
2242 let verifier = auth_file.get(user).ok_or("no such user")?.clone();
2244
2245 let server_nonce = Self::random_nonce();
2247 let (server, server_first) = ScramServer::start(verifier, client_first, &server_nonce)?;
2248
2249 let mut cont = BytesMut::new();
2251 cont.put_i32(11);
2252 cont.extend_from_slice(server_first.as_bytes());
2253 Self::write_auth_frame(client, &cont).await?;
2254
2255 let client_final_raw = Self::read_password_message(client).await?;
2257 let client_final =
2258 std::str::from_utf8(&client_final_raw).map_err(|_| "client-final not UTF-8")?;
2259
2260 let server_final = server.finish(client_final)?;
2262
2263 let mut fin = BytesMut::new();
2265 fin.put_i32(12);
2266 fin.extend_from_slice(server_final.as_bytes());
2267 Self::write_auth_frame(client, &fin).await?;
2268 Ok(())
2269 }
2270
2271 async fn write_auth_frame(
2273 client: &mut ClientStream,
2274 payload: &[u8],
2275 ) -> std::result::Result<(), String> {
2276 let mut frame = BytesMut::with_capacity(payload.len() + 5);
2277 frame.put_u8(b'R');
2278 frame.put_u32((payload.len() + 4) as u32);
2279 frame.extend_from_slice(payload);
2280 client
2281 .write_all(&frame)
2282 .await
2283 .map_err(|e| format!("client write: {}", e))
2284 }
2285
2286 async fn read_password_message(
2289 client: &mut ClientStream,
2290 ) -> std::result::Result<BytesMut, String> {
2291 let codec = ProtocolCodec::new();
2292 let mut buffer = BytesMut::with_capacity(1024);
2293 let mut read_buf = vec![0u8; 1024];
2294 loop {
2295 if let Some(msg) = codec
2296 .decode_message(&mut buffer)
2297 .map_err(|e| format!("decode: {}", e))?
2298 {
2299 if msg.msg_type == MessageType::Password {
2300 return Ok(msg.payload);
2301 }
2302 return Err(format!("expected SASL response, got {:?}", msg.msg_type));
2303 }
2304 let n = client
2305 .read(&mut read_buf)
2306 .await
2307 .map_err(|e| format!("client read: {}", e))?;
2308 if n == 0 {
2309 return Err("client closed during SASL".into());
2310 }
2311 buffer.extend_from_slice(&read_buf[..n]);
2312 }
2313 }
2314
2315 fn random_nonce() -> String {
2317 use rand::Rng;
2318 const CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
2319 let mut rng = rand::thread_rng();
2320 (0..24)
2321 .map(|_| CHARS[rng.gen_range(0..CHARS.len())] as char)
2322 .collect()
2323 }
2324
2325 async fn connect_and_authenticate(
2327 client_stream: &mut ClientStream,
2328 params: &HashMap<String, String>,
2329 session: &Arc<ClientSession>,
2330 state: &Arc<ServerState>,
2331 config: &ProxyConfig,
2332 ) -> Result<(Option<TcpStream>, String)> {
2333 let user = params.get("user").map(String::as_str).unwrap_or("");
2336 let database = params.get("database").map(String::as_str).unwrap_or(user);
2337 if !Self::hba_admits(&config.hba, session.client_addr.ip(), user, database) {
2338 tracing::info!(%user, %database, client = %session.client_addr, "connection rejected by hba rule");
2339 let err = Self::create_error_response(
2340 "28000",
2341 "connection rejected by proxy admission rules",
2342 );
2343 let _ = client_stream.write_all(&err).await;
2344 return Ok((None, String::new()));
2345 }
2346
2347 if state.auth_file.is_some() {
2353 if let Err(e) = Self::proxy_scram_auth(client_stream, user, state).await {
2354 tracing::info!(%user, error = %e, "proxy SCRAM auth failed");
2355 let err =
2356 Self::create_error_response("28P01", &format!("authentication failed: {}", e));
2357 let _ = client_stream.write_all(&err).await;
2358 return Ok((None, String::new()));
2359 }
2360 tracing::debug!(%user, "client authenticated by proxy SCRAM");
2361 }
2362
2363 Self::apply_authenticate_hook(params, session, state).await?;
2369
2370 let cutover = state.cutover.load_full();
2374 let (node_addr, effective_params) = if let Some(t) = cutover.as_ref() {
2375 let mut p = params.clone();
2376 p.insert("user".to_string(), t.user.clone());
2377 if let Some(ref db) = t.database {
2378 p.insert("database".to_string(), db.clone());
2379 } else {
2380 p.remove("database");
2381 }
2382 tracing::debug!(target = %t.addr, "routing connection to cutover target");
2383 (t.addr.clone(), p)
2384 } else {
2385 (
2386 Self::select_node(session, state, config).await?,
2387 params.clone(),
2388 )
2389 };
2390
2391 let mut backend = match tokio::time::timeout(
2396 config.pool.acquire_timeout(),
2397 TcpStream::connect(&node_addr),
2398 )
2399 .await
2400 {
2401 Ok(Ok(s)) => s,
2402 Ok(Err(e)) => {
2403 let msg = format!("Failed to connect to {}: {}", node_addr, e);
2404 Self::note_backend_failure(state, &node_addr, &msg);
2405 return Err(ProxyError::Connection(msg));
2406 }
2407 Err(_) => {
2408 let msg = format!("Connection timeout to {}", node_addr);
2409 Self::note_backend_failure(state, &node_addr, &msg);
2410 return Err(ProxyError::Connection(msg));
2411 }
2412 };
2413 let _ = backend.set_nodelay(true);
2414
2415 let params = &effective_params;
2417 let startup_bytes = Self::build_startup_message(params);
2418 backend
2419 .write_all(&startup_bytes)
2420 .await
2421 .map_err(|e| ProxyError::Network(format!("Backend startup write error: {}", e)))?;
2422
2423 Self::proxy_authentication(client_stream, &mut backend, state, &node_addr).await?;
2427
2428 {
2430 let mut vars = session.variables.write().await;
2431 for (k, v) in params {
2432 vars.insert(k.clone(), v.clone());
2433 }
2434 }
2435
2436 Ok((Some(backend), node_addr))
2437 }
2438
2439 fn build_startup_message(params: &HashMap<String, String>) -> Vec<u8> {
2441 let mut payload = BytesMut::new();
2442
2443 payload.put_u32(196608);
2445
2446 for (key, value) in params {
2448 payload.extend_from_slice(key.as_bytes());
2449 payload.put_u8(0);
2450 payload.extend_from_slice(value.as_bytes());
2451 payload.put_u8(0);
2452 }
2453 payload.put_u8(0); let mut msg = BytesMut::new();
2457 msg.put_u32((payload.len() + 4) as u32);
2458 msg.extend_from_slice(&payload);
2459
2460 msg.to_vec()
2461 }
2462
    /// Maximum number of backend cancel keys kept in `cancel_map`;
    /// `register_cancel_key` evicts the oldest entries once it is reached.
    const MAX_CANCEL_KEYS: usize = 100_000;

    /// Write timeout for backend sockets.
    /// NOTE(review): not referenced in this part of the file; meaning inferred
    /// from the name — confirm at the use site.
    const BACKEND_WRITE_TIMEOUT: Duration = Duration::from_secs(30);
    /// Write timeout for the client socket.
    /// NOTE(review): not referenced in this part of the file; meaning inferred
    /// from the name — confirm at the use site.
    const CLIENT_WRITE_TIMEOUT: Duration = Duration::from_secs(60);
    /// Time budget for re-preparing statements on a backend.
    /// NOTE(review): not referenced in this part of the file; meaning inferred
    /// from the name — confirm at the use site.
    const REPREPARE_TIMEOUT: Duration = Duration::from_secs(15);
    /// Maximum number of named prepared statements per session. `client_loop`
    /// answers SQLSTATE 54000 and closes the connection when a new name would
    /// exceed it.
    const MAX_PREPARED_STATEMENTS: usize = 8192;
    /// Maximum total encoded size of a session's prepared-statement registry.
    /// Exceeding it closes the connection with SQLSTATE 54000.
    const MAX_PREPARED_BYTES: usize = 64 * 1024 * 1024;
    /// Cap applied by `client_loop` to two buffers: the undecoded inbound
    /// buffer and the un-flushed extended-protocol buffer. Exceeding either
    /// closes the connection with SQLSTATE 53400.
    const MAX_PENDING_BYTES: usize = 64 * 1024 * 1024;
    /// Cap on idle pooled backend connections.
    /// NOTE(review): presumably across all nodes — confirm at the use site.
    #[cfg(feature = "pool-modes")]
    const MAX_TOTAL_IDLE_BACKEND_CONNS: usize = 8192;
    /// Interval between pool maintenance ("reap") passes.
    /// NOTE(review): not referenced in this part of the file — confirm at the use site.
    const POOL_REAP_INTERVAL: Duration = Duration::from_secs(30);
2497
2498 fn register_cancel_key(state: &Arc<ServerState>, pid: u32, key: u32, node_addr: &str) {
2500 {
2504 let mut order = state.cancel_order.lock();
2505 while state.cancel_map.len() >= Self::MAX_CANCEL_KEYS {
2506 match order.pop_front() {
2507 Some(old) => {
2508 state.cancel_map.remove(&old);
2509 }
2510 None => {
2511 state.cancel_map.clear();
2514 break;
2515 }
2516 }
2517 }
2518 order.push_back((pid, key));
2519 }
2520 state.cancel_map.insert((pid, key), node_addr.to_string());
2521 }
2522
2523 async fn forward_cancel_request(state: &Arc<ServerState>, pid: u32, key: u32) {
2526 let Some(addr) = state.cancel_map.get(&(pid, key)).map(|e| e.clone()) else {
2527 tracing::debug!(pid, "cancel request for unknown key; ignoring");
2528 return;
2529 };
2530 let mut msg = BytesMut::with_capacity(16);
2532 msg.put_u32(16);
2533 msg.put_u32(80877102);
2534 msg.put_u32(pid);
2535 msg.put_u32(key);
2536 match tokio::time::timeout(Duration::from_secs(5), TcpStream::connect(&addr)).await {
2537 Ok(Ok(mut conn)) => {
2538 let _ = conn.set_nodelay(true);
2539 if let Err(e) = conn.write_all(&msg).await {
2540 tracing::warn!(node = %addr, error = %e, "failed to forward CancelRequest");
2541 }
2542 }
2544 other => {
2545 tracing::warn!(node = %addr, ?other, "could not connect to forward CancelRequest")
2546 }
2547 }
2548 }
2549
2550 async fn proxy_authentication(
2552 client_stream: &mut ClientStream,
2553 backend_stream: &mut TcpStream,
2554 state: &Arc<ServerState>,
2555 node_addr: &str,
2556 ) -> Result<()> {
2557 let codec = ProtocolCodec::new();
2558 let mut backend_buffer = BytesMut::with_capacity(4096);
2559 let mut client_buffer = BytesMut::with_capacity(4096);
2560 let mut read_buf = vec![0u8; 4096];
2561
2562 loop {
2563 let n = backend_stream
2565 .read(&mut read_buf)
2566 .await
2567 .map_err(|e| ProxyError::Network(format!("Backend auth read error: {}", e)))?;
2568
2569 if n == 0 {
2570 return Err(ProxyError::Connection(
2571 "Backend closed during auth".to_string(),
2572 ));
2573 }
2574
2575 backend_buffer.extend_from_slice(&read_buf[..n]);
2576
2577 client_stream
2579 .write_all(&read_buf[..n])
2580 .await
2581 .map_err(|e| ProxyError::Network(format!("Client auth write error: {}", e)))?;
2582
2583 while let Some(msg) = codec.decode_message(&mut backend_buffer)? {
2587 match msg.msg_type {
2588 MessageType::BackendKeyData
2589 if msg.payload.len() >= 8 => {
2593 let pid = u32::from_be_bytes([
2594 msg.payload[0], msg.payload[1], msg.payload[2], msg.payload[3],
2595 ]);
2596 let key = u32::from_be_bytes([
2597 msg.payload[4], msg.payload[5], msg.payload[6], msg.payload[7],
2598 ]);
2599 Self::register_cancel_key(state, pid, key, node_addr);
2600 }
2601 MessageType::AuthRequest
2602 if msg.payload.len() >= 4 => {
2604 let auth_type =
2605 i32::from_be_bytes([msg.payload[0], msg.payload[1], msg.payload[2], msg.payload[3]]);
2606 if auth_type == 0 {
2607 }
2609 }
2610 MessageType::ReadyForQuery => {
2611 return Ok(());
2613 }
2614 MessageType::ErrorResponse => {
2615 return Err(ProxyError::Auth("Authentication failed".to_string()));
2617 }
2618 _ => {
2619 }
2621 }
2622 }
2623
2624 let n = tokio::time::timeout(
2627 Duration::from_millis(100),
2628 client_stream.read(&mut read_buf),
2629 )
2630 .await;
2631
2632 if let Ok(Ok(n)) = n {
2633 if n > 0 {
2634 client_buffer.extend_from_slice(&read_buf[..n]);
2635 backend_stream
2636 .write_all(&read_buf[..n])
2637 .await
2638 .map_err(|e| {
2639 ProxyError::Network(format!("Backend password write error: {}", e))
2640 })?;
2641 }
2642 }
2643 }
2644 }
2645
2646 async fn choose_target_node(
2651 is_write: bool,
2652 forced_target: Option<String>,
2653 current_node: Option<&str>,
2654 session: &Arc<ClientSession>,
2655 state: &Arc<ServerState>,
2656 config: &ProxyConfig,
2657 ) -> Result<String> {
2658 if let Some(t) = state.cutover.load_full().as_ref() {
2661 return Ok(t.addr.clone());
2662 }
2663
2664 #[cfg(feature = "lag-routing")]
2668 if !is_write && forced_target.is_none() && config.lag_routing.enabled {
2669 let last_write = *session.last_write_at.read().await;
2670 if Self::ryw_pins_primary(last_write, config.lag_routing.ryw_window_ms) {
2671 tracing::debug!(target: "helios::routing", "read-your-writes: pinning read to primary");
2672 return Self::select_primary_with_timeout(session, state, config).await;
2673 }
2674 }
2675
2676 let need_switch = if let Some(ref forced) = forced_target {
2677 let health = state.health.load_full();
2678 let reuse = current_node
2679 .map(|c| c == forced && health.get(c).map(|h| h.healthy).unwrap_or(false))
2680 .unwrap_or(false);
2681 !reuse
2682 } else if let Some(current) = current_node {
2683 let health = state.health.load_full();
2684 let current_healthy = health.get(current).map(|h| h.healthy).unwrap_or(false);
2685 if !current_healthy {
2686 true
2687 } else if is_write {
2688 let is_primary = config
2689 .nodes
2690 .iter()
2691 .find(|n| n.address() == *current)
2692 .map(|n| n.role == NodeRole::Primary)
2693 .unwrap_or(false);
2694 !is_primary
2695 } else {
2696 false
2697 }
2698 } else {
2699 true
2700 };
2701
2702 if let Some(forced) = forced_target {
2703 let resolved = config
2708 .nodes
2709 .iter()
2710 .find(|n| n.name.as_deref() == Some(forced.as_str()) || n.address() == forced)
2711 .map(|n| n.address())
2712 .unwrap_or(forced);
2713 Ok(resolved)
2714 } else if need_switch {
2715 if is_write {
2716 Self::select_primary_with_timeout(session, state, config).await
2717 } else {
2718 Self::select_read_node(session, state, config).await
2719 }
2720 } else {
2721 Ok(current_node.unwrap().to_string())
2722 }
2723 }
2724
2725 async fn ensure_conn(
2730 conns: &mut HashMap<String, BackendConn>,
2731 target: &str,
2732 session: &Arc<ClientSession>,
2733 config: &ProxyConfig,
2734 _state: &Arc<ServerState>,
2735 ) -> Result<()> {
2736 if conns.contains_key(target) {
2737 return Ok(());
2738 }
2739
2740 #[cfg(feature = "pool-modes")]
2745 if let Some(pool) = _state.backend_pool.as_ref() {
2746 let key = Self::pool_key_for(target, session).await;
2747 if let Some(stream) = pool.checkout(&key) {
2748 tracing::info!(
2749 target: "helios::pool",
2750 node = %target,
2751 "reused pooled backend connection"
2752 );
2753 conns.insert(target.to_string(), BackendConn::new(stream));
2754 return Ok(());
2755 }
2756 }
2757
2758 let mut backend =
2759 tokio::time::timeout(config.pool.acquire_timeout(), TcpStream::connect(target))
2760 .await
2761 .map_err(|_| ProxyError::Connection(format!("Connection timeout to {}", target)))?
2762 .map_err(|e| {
2763 ProxyError::Connection(format!("Failed to connect to {}: {}", target, e))
2764 })?;
2765 let _ = backend.set_nodelay(true);
2766
2767 let params = session.variables.read().await.clone();
2768 let startup = Self::build_startup_message(¶ms);
2769 backend
2770 .write_all(&startup)
2771 .await
2772 .map_err(|e| ProxyError::Network(format!("Backend startup error: {}", e)))?;
2773 Self::complete_backend_auth(&mut backend).await?;
2774 #[cfg(feature = "pool-modes")]
2775 if _state.backend_pool.is_some() {
2776 tracing::debug!(target: "helios::pool", node = %target, "dialed fresh backend connection (pool miss)");
2777 }
2778 tracing::debug!(node = %target, "opened backend connection");
2779 conns.insert(target.to_string(), BackendConn::new(backend));
2780 Ok(())
2781 }
2782
2783 #[cfg(feature = "pool-modes")]
2788 async fn pool_key_for(target: &str, session: &Arc<ClientSession>) -> String {
2789 let vars = session.variables.read().await;
2790 let user = vars.get("user").map(|s| s.as_str()).unwrap_or("");
2791 let database = vars.get("database").map(|s| s.as_str()).unwrap_or(user);
2793 crate::pool::pool_key(target, user, database)
2794 }
2795
2796 #[cfg(feature = "pool-modes")]
2803 async fn reset_backend(stream: &mut TcpStream, reset_sql: &str) -> Result<()> {
2804 let msg = crate::protocol::QueryMessage {
2805 query: reset_sql.to_string(),
2806 }
2807 .encode();
2808 stream
2809 .write_all(&msg.encode())
2810 .await
2811 .map_err(|e| ProxyError::Network(format!("reset write error: {}", e)))?;
2812
2813 let codec = ProtocolCodec::new();
2814 let mut buffer = BytesMut::with_capacity(1024);
2815 let mut read_buf = vec![0u8; 1024];
2816 loop {
2817 while let Some(m) = codec.decode_message(&mut buffer)? {
2818 if m.msg_type == MessageType::ReadyForQuery {
2819 return Ok(());
2820 }
2821 }
2822 let n = tokio::time::timeout(Duration::from_secs(5), stream.read(&mut read_buf))
2823 .await
2824 .map_err(|_| ProxyError::Network("reset drain timeout".to_string()))?
2825 .map_err(|e| ProxyError::Network(format!("reset drain read error: {}", e)))?;
2826 if n == 0 {
2827 return Err(ProxyError::Connection(
2828 "backend closed during reset".to_string(),
2829 ));
2830 }
2831 buffer.extend_from_slice(&read_buf[..n]);
2832 }
2833 }
2834
2835 #[cfg(feature = "pool-modes")]
2841 async fn release_to_pool_if_idle(
2842 conns: &mut HashMap<String, BackendConn>,
2843 node: Option<&str>,
2844 session: &Arc<ClientSession>,
2845 state: &Arc<ServerState>,
2846 config: &ProxyConfig,
2847 ) {
2848 let Some(pool) = state.backend_pool.as_ref() else {
2849 return;
2850 };
2851 let Some(node) = node else {
2852 return;
2853 };
2854 if session.tx_state.read().await.in_transaction {
2856 return;
2857 }
2858 let Some(mut bc) = conns.remove(node) else {
2859 return;
2860 };
2861 if Self::reset_backend(&mut bc.stream, &config.pool_mode.reset_query)
2862 .await
2863 .is_ok()
2864 {
2865 let key = Self::pool_key_for(node, session).await;
2866 if pool.checkin(&key, bc.stream) {
2867 tracing::debug!(target: "helios::pool", node = %node, "parked backend connection for reuse");
2868 }
2869 }
2870 }
2872
    /// Routes one simple-protocol query to a backend and relays the response to
    /// the client.
    ///
    /// Returns `(node that served it, bytes written to the client)`. The node is
    /// `None` whenever the reply was produced without a backend round-trip:
    /// rate-limit rejection, route-plugin block, cache hit, or open circuit.
    ///
    /// Pipeline, in order (most stages are feature-gated):
    /// 1. rate limiting;
    /// 2. the route plugin hook (which may block the query);
    /// 3. routing hints;
    /// 4. read-your-writes timestamp;
    /// 5. query rewriting;
    /// 6. tenant filter injection;
    /// 7. cache lookup;
    /// 8. schema (OLAP) routing;
    /// 9. `choose_target_node`;
    /// 10. circuit breaker;
    /// 11. `ensure_conn`;
    /// 12. the backend write;
    /// 13. streaming the response.
    ///
    /// Later stages forward the possibly-rewritten message (`forward_msg`),
    /// not necessarily `msg`.
    ///
    /// NOTE(review): the synthesized rate-limit, block and circuit replies
    /// always end with ReadyForQuery status `'I'`, even when the session is
    /// inside a transaction. Confirm that clients tolerate this.
    ///
    /// # Errors
    /// Client write failures and backend connect/write/stream failures.
    /// Backend failures call `record_backend_failure`. After a failed write or
    /// stream, the cached connection is also removed from `conns`.
    async fn forward_simple_query(
        client: &mut ClientStream,
        msg: &Message,
        conns: &mut HashMap<String, BackendConn>,
        current_node: Option<&str>,
        session: &Arc<ClientSession>,
        state: &Arc<ServerState>,
        config: &ProxyConfig,
    ) -> Result<(Option<String>, u64)> {
        // Rate limiting: a rejection is answered locally (error + RFQ 'I').
        #[cfg(feature = "rate-limiting")]
        if let Some(mut resp) = Self::rate_limit_check(session, state, config).await {
            resp.extend_from_slice(&Self::create_ready_for_query(b'I'));
            client
                .write_all(&resp)
                .await
                .map_err(|e| ProxyError::Network(format!("Client write error: {}", e)))?;
            return Ok((None, resp.len() as u64));
        }

        let default_is_write = Self::is_write_message(msg);
        let plugin_override = Self::apply_route_hook(msg, state, session);

        // A route plugin may veto the query outright (SQLSTATE 42000).
        if let RouteOverride::Block(reason) = plugin_override {
            let mut response = Vec::with_capacity(64 + reason.len());
            response.extend_from_slice(&Self::create_error_response(
                "42000",
                &format!("Query blocked by route plugin: {}", reason),
            ));
            response.extend_from_slice(&Self::create_ready_for_query(b'I'));
            client
                .write_all(&response)
                .await
                .map_err(|e| ProxyError::Network(format!("Client write error: {}", e)))?;
            return Ok((None, response.len() as u64));
        }

        // With routing hints, the override and write classification may be
        // refined, and a hint-stripped copy of the message may be produced.
        #[cfg(feature = "routing-hints")]
        let (route_override, default_is_write, stripped_msg) =
            Self::resolve_simple_route(msg, plugin_override, default_is_write, state);
        #[cfg(not(feature = "routing-hints"))]
        let (route_override, stripped_msg): (RouteOverride, Option<Message>) =
            (plugin_override, None);

        let (is_write, forced_target) = match route_override {
            RouteOverride::None => (default_is_write, None),
            RouteOverride::Primary => (true, None),
            RouteOverride::Standby => (false, None),
            RouteOverride::Node(name) => (default_is_write, Some(name)),
            RouteOverride::Block(_) => unreachable!("handled above"),
        };

        // Record the write time so later reads can be pinned to the primary.
        #[cfg(feature = "lag-routing")]
        if is_write && config.lag_routing.enabled {
            *session.last_write_at.write().await = Some(std::time::Instant::now());
        }

        // `forward_msg` is re-shadowed by each transforming stage below.
        let forward_msg = stripped_msg.as_ref().unwrap_or(msg);

        #[cfg(feature = "query-rewriting")]
        let rewritten_msg: Option<Message> = state.rewriter.as_ref().and_then(|rw| {
            let sql = crate::protocol::query_text(&forward_msg.payload)?;
            match rw.rewrite(sql) {
                Ok(res) if res.was_rewritten() => {
                    tracing::debug!(target: "helios::rewrite", rules = ?res.rules_applied, "query rewritten");
                    Some(crate::protocol::QueryMessage { query: res.query().to_string() }.encode())
                }
                _ => None,
            }
        });
        #[cfg(feature = "query-rewriting")]
        let forward_msg = rewritten_msg.as_ref().unwrap_or(forward_msg);

        // Tenant filter injection, only when a tenant is identified and the
        // manager actually transformed the SQL.
        #[cfg(feature = "multi-tenancy")]
        let tenant_msg: Option<Message> = if let Some(tm) = state.tenant_manager.as_ref() {
            match crate::protocol::query_text(&forward_msg.payload) {
                Some(sql) => {
                    let ctx = Self::tenant_request_ctx(session).await;
                    match tm.identify_tenant(&ctx) {
                        Some(tenant) => {
                            let res = tm.transform_query(sql, &tenant);
                            if res.transformed {
                                tracing::debug!(target: "helios::tenant", tenant = %tenant.0, "tenant filter injected");
                                Some(crate::protocol::QueryMessage { query: res.query }.encode())
                            } else {
                                None
                            }
                        }
                        None => None,
                    }
                }
                None => None,
            }
        } else {
            None
        };
        #[cfg(feature = "multi-tenancy")]
        let forward_msg = tenant_msg.as_ref().unwrap_or(forward_msg);

        // Reads only: serve a cache hit directly. On a miss, keep the context
        // so the response can be captured and stored after streaming.
        #[cfg(feature = "query-cache")]
        let cache_ctx: Option<crate::cache::CacheContext> = if is_write {
            None
        } else if let Some(qc) = state.query_cache.as_ref() {
            let sql = crate::protocol::query_text(&forward_msg.payload).unwrap_or("");
            match Self::cacheable_read_ctx(session, sql).await {
                Some(ctx) => {
                    if let crate::cache::CacheLookup::Hit { result, level } =
                        qc.get(sql, &ctx).await
                    {
                        tracing::debug!(target: "helios::cache", level = %level, "cache hit");
                        client.write_all(&result.data).await.map_err(|e| {
                            ProxyError::Network(format!("Client write error: {}", e))
                        })?;
                        return Ok((None, result.data.len() as u64));
                    }
                    Some(ctx)
                }
                None => None,
            }
        } else {
            None
        };

        // Unforced reads that the analyzer classifies as analytics go to the
        // configured analytics node.
        #[cfg(feature = "schema-routing")]
        let forced_target = match state.schema_analyzer.as_ref() {
            Some(analyzer)
                if forced_target.is_none()
                    && !is_write
                    && !config.schema_routing.analytics_node.is_empty() =>
            {
                match crate::protocol::query_text(&forward_msg.payload) {
                    Some(sql) if analyzer.analyze(sql).is_analytics() => {
                        tracing::debug!(target: "helios::schema", "OLAP query routed to analytics node");
                        Some(config.schema_routing.analytics_node.clone())
                    }
                    _ => forced_target,
                }
            }
            _ => forced_target,
        };

        #[cfg(feature = "query-analytics")]
        let analytics_sql =
            crate::protocol::query_text(&forward_msg.payload).map(|s| s.to_string());
        #[cfg(feature = "query-analytics")]
        let started = std::time::Instant::now();

        let target = Self::choose_target_node(
            is_write,
            forced_target,
            current_node,
            session,
            state,
            config,
        )
        .await?;
        tracing::debug!(target: "helios::routing", node = %target, is_write, "routed simple query");

        // An open circuit fails fast without touching the backend.
        #[cfg(feature = "circuit-breaker")]
        if let Some(mut resp) = Self::circuit_fast_fail(state, &target) {
            resp.extend_from_slice(&Self::create_ready_for_query(b'I'));
            client
                .write_all(&resp)
                .await
                .map_err(|e| ProxyError::Network(format!("Client write error: {}", e)))?;
            return Ok((None, resp.len() as u64));
        }

        if let Err(e) = Self::ensure_conn(conns, &target, session, config, state).await {
            Self::record_backend_failure(state, &target, &e.to_string());
            return Err(e);
        }
        let backend = conns.get_mut(&target).expect("just ensured");

        // A failed or timed-out write leaves the connection unusable: drop it.
        let backend_err = match tokio::time::timeout(
            Self::BACKEND_WRITE_TIMEOUT,
            backend.stream.write_all(&forward_msg.encode()),
        )
        .await
        {
            Ok(Ok(())) => None,
            Ok(Err(e)) => Some(format!("Backend write error: {}", e)),
            Err(_) => Some("Backend write timeout".to_string()),
        };
        if let Some(msg) = backend_err {
            let e = ProxyError::Network(msg);
            conns.remove(&target);
            Self::record_backend_failure(state, &target, &e.to_string());
            return Err(e);
        }

        // Cache-miss reads: stream while capturing, then store the response
        // when the capture reports it cacheable.
        // NOTE(review): unlike the plain path below, the error arm here does
        // not record analytics.
        #[cfg(feature = "query-cache")]
        if let (Some(ctx), Some(qc)) = (cache_ctx.as_ref(), state.query_cache.as_ref()) {
            return match Self::stream_until_ready_capture(client, &mut backend.stream, session)
                .await
            {
                Ok((sent, captured, cacheable, rows)) => {
                    #[cfg(feature = "circuit-breaker")]
                    Self::circuit_record(state, &target, true, "");
                    if cacheable && !captured.is_empty() {
                        let sql = crate::protocol::query_text(&forward_msg.payload).unwrap_or("");
                        qc.put(
                            sql,
                            ctx,
                            bytes::Bytes::from(captured),
                            rows,
                            std::time::Duration::ZERO,
                        )
                        .await;
                    }
                    #[cfg(feature = "query-analytics")]
                    if let Some(sql) = analytics_sql.as_deref() {
                        Self::record_analytics(
                            state,
                            session,
                            sql,
                            &target,
                            started.elapsed(),
                            None,
                        )
                        .await;
                    }
                    Ok((Some(target), sent))
                }
                Err(e) => {
                    conns.remove(&target);
                    Self::record_backend_failure(state, &target, &e.to_string());
                    Err(e)
                }
            };
        }

        match Self::stream_until_ready(client, &mut backend.stream, session, state).await {
            Ok(sent) => {
                #[cfg(feature = "circuit-breaker")]
                Self::circuit_record(state, &target, true, "");
                // Writes invalidate cached entries associated with this SQL.
                #[cfg(feature = "query-cache")]
                if is_write {
                    if let Some(qc) = state.query_cache.as_ref() {
                        let sql = crate::protocol::query_text(&forward_msg.payload).unwrap_or("");
                        qc.invalidate_query(sql).await;
                    }
                }
                // With `tr_enabled`, successful writes are journaled.
                #[cfg(feature = "ha-tr")]
                if is_write && config.tr_enabled {
                    if let Some(sql) = crate::protocol::query_text(&forward_msg.payload) {
                        Self::journal_write(state, session, sql).await;
                    }
                }
                #[cfg(feature = "query-analytics")]
                if let Some(sql) = analytics_sql.as_deref() {
                    Self::record_analytics(state, session, sql, &target, started.elapsed(), None)
                        .await;
                }
                Ok((Some(target), sent))
            }
            Err(e) => {
                conns.remove(&target);
                Self::record_backend_failure(state, &target, &e.to_string());
                #[cfg(feature = "query-analytics")]
                if let Some(sql) = analytics_sql.as_deref() {
                    Self::record_analytics(
                        state,
                        session,
                        sql,
                        &target,
                        started.elapsed(),
                        Some(e.to_string()),
                    )
                    .await;
                }
                Err(e)
            }
        }
    }
3179
    /// Forwards a batch of extended-protocol messages to one backend and relays
    /// the response. Returns `(node that served it, bytes written to client)`.
    ///
    /// - `route_sql`: when present, the batch is routed by classifying this SQL
    ///   (and by routing hints, with that feature). When absent, the batch stays
    ///   on `current_node`, or a read node is chosen if there is none.
    /// - `wait_ready`: when true, stream until ReadyForQuery. When false, use
    ///   `stream_flush`, which relays whatever the backend sends until it goes
    ///   quiet.
    /// - `reprepare`: statement names this backend must know before `batch`
    ///   runs. Each one missing from `backend.prepared` is re-sent from
    ///   `registry` via `reprepare_statement`; names absent from `registry`
    ///   are skipped.
    /// - `defines`: statement names recorded in `backend.prepared` once the
    ///   batch succeeds.
    /// - `unnamed`: `(parse message, signature)` for the unnamed statement.
    ///   - If the backend's `unnamed_sig` already equals the signature, the
    ///     Parse is not sent, and a synthetic ParseComplete
    ///     (`'1'`, length 4) is written to the client instead.
    ///   - Otherwise the Parse is written ahead of the batch, and the signature
    ///     is recorded on success.
    ///
    /// The returned byte count includes the injected ParseComplete (5 bytes).
    ///
    /// NOTE(review): the re-prepare and unnamed-Parse failure paths drop the
    /// connection without calling `record_backend_failure`, and the unnamed
    /// Parse write is not bounded by `BACKEND_WRITE_TIMEOUT`. Confirm this is
    /// intended.
    #[allow(clippy::too_many_arguments)]
    async fn forward_extended_batch(
        client: &mut ClientStream,
        batch: &[u8],
        route_sql: Option<&str>,
        wait_ready: bool,
        conns: &mut HashMap<String, BackendConn>,
        current_node: Option<&str>,
        registry: &HashMap<String, bytes::Bytes>,
        reprepare: &[String],
        defines: &[String],
        unnamed: Option<(bytes::Bytes, bytes::Bytes)>,
        session: &Arc<ClientSession>,
        state: &Arc<ServerState>,
        config: &ProxyConfig,
    ) -> Result<(Option<String>, u64)> {
        // Rate limiting: answered locally. ReadyForQuery is appended only when
        // the client is waiting for one (`wait_ready`).
        #[cfg(feature = "rate-limiting")]
        if let Some(mut resp) = Self::rate_limit_check(session, state, config).await {
            if wait_ready {
                resp.extend_from_slice(&Self::create_ready_for_query(b'I'));
            }
            client
                .write_all(&resp)
                .await
                .map_err(|e| ProxyError::Network(format!("Client write error: {}", e)))?;
            return Ok((None, resp.len() as u64));
        }

        #[cfg(feature = "query-analytics")]
        let analytics_sql = route_sql.map(|s| s.to_string());
        #[cfg(feature = "query-analytics")]
        let started = std::time::Instant::now();

        let target = match route_sql {
            Some(sql) => {
                #[cfg(feature = "routing-hints")]
                let (is_write, forced) = Self::extended_hint_route(state, sql)
                    .unwrap_or_else(|| (Self::is_write_query(sql), None));
                #[cfg(not(feature = "routing-hints"))]
                let (is_write, forced): (bool, Option<String>) = (Self::is_write_query(sql), None);
                #[cfg(feature = "lag-routing")]
                if is_write && config.lag_routing.enabled {
                    *session.last_write_at.write().await = Some(std::time::Instant::now());
                }
                Self::choose_target_node(is_write, forced, current_node, session, state, config)
                    .await?
            }
            // No SQL to classify (e.g. a Bind/Execute-only batch): stay put.
            None => match current_node {
                Some(c) => c.to_string(),
                None => Self::select_read_node(session, state, config).await?,
            },
        };

        #[cfg(feature = "circuit-breaker")]
        if let Some(mut resp) = Self::circuit_fast_fail(state, &target) {
            if wait_ready {
                resp.extend_from_slice(&Self::create_ready_for_query(b'I'));
            }
            client
                .write_all(&resp)
                .await
                .map_err(|e| ProxyError::Network(format!("Client write error: {}", e)))?;
            return Ok((None, resp.len() as u64));
        }

        if let Err(e) = Self::ensure_conn(conns, &target, session, config, state).await {
            Self::record_backend_failure(state, &target, &e.to_string());
            return Err(e);
        }
        let backend = conns.get_mut(&target).expect("just ensured");

        // Make sure every statement the batch references exists on this backend.
        for name in reprepare {
            if backend.prepared.contains(name) {
                continue;
            }
            let Some(parse_bytes) = registry.get(name) else {
                continue;
            };
            match Self::reprepare_statement(&mut backend.stream, parse_bytes).await {
                Ok(()) => {
                    backend.prepared.insert(name.clone());
                }
                Err(e) => {
                    conns.remove(&target);
                    return Err(e);
                }
            }
        }

        // Unnamed statement: skip a redundant Parse when the backend already
        // holds the same signature, and fake its ParseComplete later.
        let mut inject_parse_complete = false;
        let mut new_unnamed_sig: Option<bytes::Bytes> = None;
        if let Some((parse_msg, sig)) = unnamed.as_ref() {
            if backend.unnamed_sig.as_deref() == Some(&sig[..]) {
                inject_parse_complete = true;
            } else {
                if let Err(e) = backend
                    .stream
                    .write_all(parse_msg)
                    .await
                    .map_err(|e| ProxyError::Network(format!("Backend write error: {}", e)))
                {
                    conns.remove(&target);
                    return Err(e);
                }
                new_unnamed_sig = Some(sig.clone());
            }
        }

        let batch_err = match tokio::time::timeout(
            Self::BACKEND_WRITE_TIMEOUT,
            backend.stream.write_all(batch),
        )
        .await
        {
            Ok(Ok(())) => None,
            Ok(Err(e)) => Some(format!("Backend write error: {}", e)),
            Err(_) => Some("Backend write timeout".to_string()),
        };
        if let Some(msg) = batch_err {
            let e = ProxyError::Network(msg);
            conns.remove(&target);
            Self::record_backend_failure(state, &target, &e.to_string());
            return Err(e);
        }

        // The synthetic ParseComplete goes out before any backend reply. The
        // backend connection is dropped on failure, since its reply to the
        // batch is already in flight.
        let mut injected: u64 = 0;
        if inject_parse_complete {
            if let Err(e) = client
                .write_all(&[b'1', 0, 0, 0, 4])
                .await
                .map_err(|e| ProxyError::Network(format!("Client write error: {}", e)))
            {
                conns.remove(&target);
                return Err(e);
            }
            injected = 5;
        }

        let r = if wait_ready {
            Self::stream_until_ready(client, &mut backend.stream, session, state).await
        } else {
            Self::stream_flush(client, &mut backend.stream, session, state).await
        };
        match r {
            Ok(sent) => {
                #[cfg(feature = "circuit-breaker")]
                Self::circuit_record(state, &target, true, "");
                #[cfg(feature = "query-analytics")]
                if let Some(sql) = analytics_sql.as_deref() {
                    Self::record_analytics(state, session, sql, &target, started.elapsed(), None)
                        .await;
                }
                // Only after success does this backend "know" the new statements.
                for name in defines {
                    backend.prepared.insert(name.clone());
                }
                if let Some(sig) = new_unnamed_sig {
                    backend.unnamed_sig = Some(sig);
                }
                Ok((Some(target), sent + injected))
            }
            Err(e) => {
                conns.remove(&target);
                Self::record_backend_failure(state, &target, &e.to_string());
                #[cfg(feature = "query-analytics")]
                if let Some(sql) = analytics_sql.as_deref() {
                    Self::record_analytics(
                        state,
                        session,
                        sql,
                        &target,
                        started.elapsed(),
                        Some(e.to_string()),
                    )
                    .await;
                }
                Err(e)
            }
        }
    }
3395
3396 async fn reprepare_statement<S: AsyncReadExt + AsyncWriteExt + Unpin>(
3402 backend: &mut S,
3403 parse_bytes: &[u8],
3404 ) -> Result<()> {
3405 tokio::time::timeout(Self::REPREPARE_TIMEOUT, backend.write_all(parse_bytes))
3406 .await
3407 .map_err(|_| ProxyError::Network("re-prepare write timeout".to_string()))?
3408 .map_err(|e| ProxyError::Network(format!("re-prepare write error: {}", e)))?;
3409 tokio::time::timeout(Self::REPREPARE_TIMEOUT, backend.write_all(&[b'H', 0, 0, 0, 4]))
3411 .await
3412 .map_err(|_| ProxyError::Network("re-prepare flush timeout".to_string()))?
3413 .map_err(|e| ProxyError::Network(format!("re-prepare flush error: {}", e)))?;
3414 let mtype = tokio::time::timeout(Self::REPREPARE_TIMEOUT, Self::read_one_frame_type(backend))
3415 .await
3416 .map_err(|_| ProxyError::Network("re-prepare read timeout".to_string()))??;
3417 match mtype {
3418 b'1' => Ok(()), b'E' => Err(ProxyError::Protocol(
3420 "re-prepare rejected by backend".to_string(),
3421 )),
3422 other => Err(ProxyError::Protocol(format!(
3423 "unexpected re-prepare reply: {}",
3424 other as char
3425 ))),
3426 }
3427 }
3428
3429 async fn read_one_frame_type<S: AsyncReadExt + Unpin>(backend: &mut S) -> Result<u8> {
3433 let mut header = [0u8; 5];
3434 backend
3435 .read_exact(&mut header)
3436 .await
3437 .map_err(|e| ProxyError::Network(format!("re-prepare read error: {}", e)))?;
3438 let len = u32::from_be_bytes([header[1], header[2], header[3], header[4]]) as usize;
3439 let body_len = len.saturating_sub(4);
3440 if body_len > 0 {
3441 let mut body = vec![0u8; body_len];
3442 backend
3443 .read_exact(&mut body)
3444 .await
3445 .map_err(|e| ProxyError::Network(format!("re-prepare body read error: {}", e)))?;
3446 }
3447 Ok(header[0])
3448 }
3449
3450 fn parse_stmt_name(payload: &[u8]) -> &str {
3453 let end = payload.iter().position(|&b| b == 0).unwrap_or(0);
3454 std::str::from_utf8(&payload[..end]).unwrap_or("")
3455 }
3456
3457 fn bind_stmt_ref(payload: &[u8]) -> Option<&str> {
3461 let portal_end = payload.iter().position(|&b| b == 0)?;
3462 let rest = &payload[portal_end + 1..];
3463 let stmt_end = rest.iter().position(|&b| b == 0)?;
3464 let name = std::str::from_utf8(&rest[..stmt_end]).ok()?;
3465 (!name.is_empty()).then_some(name)
3466 }
3467
3468 fn stmt_kind_name(payload: &[u8]) -> Option<&str> {
3471 if payload.first() != Some(&b'S') {
3472 return None;
3473 }
3474 let rest = &payload[1..];
3475 let end = rest.iter().position(|&b| b == 0)?;
3476 let name = std::str::from_utf8(&rest[..end]).ok()?;
3477 (!name.is_empty()).then_some(name)
3478 }
3479
3480 async fn stream_until_ready(
3488 client: &mut ClientStream,
3489 backend: &mut TcpStream,
3490 session: &Arc<ClientSession>,
3491 state: &Arc<ServerState>,
3492 ) -> Result<u64> {
3493 let _ = state;
3494 let mut buf = BytesMut::with_capacity(16384);
3495 let mut read_buf = vec![0u8; 16384];
3496 let mut sent: u64 = 0;
3497
3498 loop {
3499 let mut consumed = 0usize;
3501 let mut ready_status: Option<u8> = None;
3502 let mut yield_for_copy = false;
3503 loop {
3504 let rem = &buf[consumed..];
3505 if rem.len() < 5 {
3506 break;
3507 }
3508 let len = u32::from_be_bytes([rem[1], rem[2], rem[3], rem[4]]) as usize;
3509 if len < 4 || rem.len() < len + 1 {
3510 break; }
3512 let frame_total = len + 1;
3513 let mtype = rem[0];
3514 consumed += frame_total;
3515 if mtype == b'Z' {
3516 ready_status = Some(if frame_total >= 6 { rem[5] } else { b'I' });
3518 break;
3519 }
3520 if mtype == b'G' || mtype == b'W' {
3521 yield_for_copy = true;
3524 break;
3525 }
3526 }
3527
3528 if consumed > 0 {
3529 tokio::time::timeout(Self::CLIENT_WRITE_TIMEOUT, client.write_all(&buf[..consumed]))
3530 .await
3531 .map_err(|_| ProxyError::Network("Client write timeout".to_string()))?
3532 .map_err(|e| ProxyError::Network(format!("Client write error: {}", e)))?;
3533 sent += consumed as u64;
3534 let _ = buf.split_to(consumed);
3535 }
3536
3537 if let Some(status) = ready_status {
3538 let st = TransactionStatus::from_byte(status);
3539 let mut tx = session.tx_state.write().await;
3540 tx.in_transaction = st != TransactionStatus::Idle;
3541 return Ok(sent);
3542 }
3543 if yield_for_copy {
3544 return Ok(sent);
3545 }
3546
3547 let n = tokio::time::timeout(Duration::from_secs(30), backend.read(&mut read_buf))
3548 .await
3549 .map_err(|_| ProxyError::Network("Backend read timeout".to_string()))?
3550 .map_err(|e| ProxyError::Network(format!("Backend read error: {}", e)))?;
3551 if n == 0 {
3552 return Err(ProxyError::Connection(
3553 "Backend closed mid-response".to_string(),
3554 ));
3555 }
3556 buf.extend_from_slice(&read_buf[..n]);
3557 }
3558 }
3559
3560 #[cfg(feature = "query-cache")]
3566 async fn stream_until_ready_capture(
3567 client: &mut ClientStream,
3568 backend: &mut TcpStream,
3569 session: &Arc<ClientSession>,
3570 ) -> Result<(u64, Vec<u8>, bool, usize)> {
3571 let mut buf = BytesMut::with_capacity(16384);
3572 let mut read_buf = vec![0u8; 16384];
3573 let mut sent: u64 = 0;
3574 let mut captured: Vec<u8> = Vec::with_capacity(4096);
3575 let mut had_error = false;
3576 let mut row_count: usize = 0;
3577
3578 loop {
3579 let mut consumed = 0usize;
3580 let mut ready_status: Option<u8> = None;
3581 let mut yield_for_copy = false;
3582 loop {
3583 let rem = &buf[consumed..];
3584 if rem.len() < 5 {
3585 break;
3586 }
3587 let len = u32::from_be_bytes([rem[1], rem[2], rem[3], rem[4]]) as usize;
3588 if len < 4 || rem.len() < len + 1 {
3589 break;
3590 }
3591 let frame_total = len + 1;
3592 let mtype = rem[0];
3593 if mtype == b'E' {
3594 had_error = true;
3595 }
3596 if mtype == b'C' {
3597 if let Some(tag) = rem.get(5..frame_total) {
3599 if let Some(end) = tag.iter().position(|&b| b == 0) {
3600 if let Ok(s) = std::str::from_utf8(&tag[..end]) {
3601 if let Some(n) =
3602 s.rsplit(' ').next().and_then(|x| x.parse::<usize>().ok())
3603 {
3604 row_count = n;
3605 }
3606 }
3607 }
3608 }
3609 }
3610 consumed += frame_total;
3611 if mtype == b'Z' {
3612 ready_status = Some(if frame_total >= 6 { rem[5] } else { b'I' });
3613 break;
3614 }
3615 if mtype == b'G' || mtype == b'W' {
3616 yield_for_copy = true;
3617 break;
3618 }
3619 }
3620
3621 if consumed > 0 {
3622 tokio::time::timeout(Self::CLIENT_WRITE_TIMEOUT, client.write_all(&buf[..consumed]))
3623 .await
3624 .map_err(|_| ProxyError::Network("Client write timeout".to_string()))?
3625 .map_err(|e| ProxyError::Network(format!("Client write error: {}", e)))?;
3626 captured.extend_from_slice(&buf[..consumed]);
3627 sent += consumed as u64;
3628 let _ = buf.split_to(consumed);
3629 }
3630
3631 if let Some(status) = ready_status {
3632 let st = TransactionStatus::from_byte(status);
3633 let mut tx = session.tx_state.write().await;
3634 tx.in_transaction = st != TransactionStatus::Idle;
3635 let cacheable = !had_error && status == b'I';
3636 return Ok((sent, captured, cacheable, row_count));
3637 }
3638 if yield_for_copy {
3639 return Ok((sent, captured, false, row_count));
3640 }
3641
3642 let n = tokio::time::timeout(Duration::from_secs(30), backend.read(&mut read_buf))
3643 .await
3644 .map_err(|_| ProxyError::Network("Backend read timeout".to_string()))?
3645 .map_err(|e| ProxyError::Network(format!("Backend read error: {}", e)))?;
3646 if n == 0 {
3647 return Err(ProxyError::Connection(
3648 "Backend closed mid-response".to_string(),
3649 ));
3650 }
3651 buf.extend_from_slice(&read_buf[..n]);
3652 }
3653 }
3654
3655 async fn stream_flush(
3661 client: &mut ClientStream,
3662 backend: &mut TcpStream,
3663 session: &Arc<ClientSession>,
3664 state: &Arc<ServerState>,
3665 ) -> Result<u64> {
3666 let _ = (session, state);
3667 let mut read_buf = vec![0u8; 16384];
3668 let mut sent: u64 = 0;
3669 loop {
3670 match tokio::time::timeout(Duration::from_millis(200), backend.read(&mut read_buf))
3671 .await
3672 {
3673 Ok(Ok(0)) => {
3674 return Err(ProxyError::Connection(
3675 "Backend closed mid-flush".to_string(),
3676 ))
3677 }
3678 Ok(Ok(n)) => {
3679 client
3680 .write_all(&read_buf[..n])
3681 .await
3682 .map_err(|e| ProxyError::Network(format!("Client write error: {}", e)))?;
3683 sent += n as u64;
3684 }
3685 Ok(Err(e)) => {
3686 return Err(ProxyError::Network(format!("Backend read error: {}", e)))
3687 }
3688 Err(_) => return Ok(sent), }
3690 }
3691 }
3692
3693 fn is_write_message(msg: &Message) -> bool {
3695 match msg.msg_type {
3696 MessageType::Query => {
3697 crate::protocol::query_text(&msg.payload)
3701 .map(Self::is_write_query)
3702 .unwrap_or(false)
3703 }
3704 MessageType::Parse => {
3705 msg.payload
3708 .iter()
3709 .position(|&b| b == 0)
3710 .and_then(|end| crate::protocol::query_text(&msg.payload[end + 1..]))
3711 .map(Self::is_write_query)
3712 .unwrap_or(false)
3713 }
3714 _ => false,
3716 }
3717 }
3718
3719 fn is_write_query(sql: &str) -> bool {
3721 use crate::protocol::starts_with_ci;
3722 let trimmed = sql.trim();
3723
3724 if starts_with_ci(trimmed, "INSERT")
3726 || starts_with_ci(trimmed, "UPDATE")
3727 || starts_with_ci(trimmed, "DELETE")
3728 || starts_with_ci(trimmed, "CREATE")
3729 || starts_with_ci(trimmed, "DROP")
3730 || starts_with_ci(trimmed, "ALTER")
3731 || starts_with_ci(trimmed, "TRUNCATE")
3732 || starts_with_ci(trimmed, "GRANT")
3733 || starts_with_ci(trimmed, "REVOKE")
3734 || starts_with_ci(trimmed, "VACUUM")
3735 || starts_with_ci(trimmed, "REINDEX")
3736 || starts_with_ci(trimmed, "CLUSTER")
3737 {
3738 return true;
3739 }
3740
3741 if starts_with_ci(trimmed, "BEGIN")
3743 || starts_with_ci(trimmed, "START")
3744 || starts_with_ci(trimmed, "COMMIT")
3745 || starts_with_ci(trimmed, "ROLLBACK")
3746 || starts_with_ci(trimmed, "SAVEPOINT")
3747 || starts_with_ci(trimmed, "RELEASE")
3748 {
3749 return true;
3750 }
3751
3752 if starts_with_ci(trimmed, "SET") && !starts_with_ci(trimmed, "SET TRANSACTION READ ONLY") {
3754 return true;
3755 }
3756
3757 false
3758 }
3759
3760 #[cfg(feature = "rate-limiting")]
3763 async fn rate_limit_key(
3764 session: &Arc<ClientSession>,
3765 config: &ProxyConfig,
3766 ) -> crate::rate_limit::LimiterKey {
3767 use crate::config::RateLimitKeyBy;
3768 use crate::rate_limit::LimiterKey;
3769 match config.rate_limit.key_by {
3770 RateLimitKeyBy::Global => LimiterKey::Global,
3771 RateLimitKeyBy::ClientIp => LimiterKey::ClientIp(session.client_addr.ip()),
3772 RateLimitKeyBy::Database => {
3773 let vars = session.variables.read().await;
3774 LimiterKey::Database(vars.get("database").cloned().unwrap_or_default())
3775 }
3776 RateLimitKeyBy::User => {
3777 let vars = session.variables.read().await;
3778 LimiterKey::User(vars.get("user").cloned().unwrap_or_default())
3779 }
3780 }
3781 }
3782
3783 #[cfg(feature = "rate-limiting")]
3789 async fn rate_limit_check(
3790 session: &Arc<ClientSession>,
3791 state: &Arc<ServerState>,
3792 config: &ProxyConfig,
3793 ) -> Option<Vec<u8>> {
3794 use crate::rate_limit::RateLimitResult;
3795 let limiter = state.rate_limiter.as_ref()?;
3796 let key = Self::rate_limit_key(session, config).await;
3797 match limiter.check(&key, 1) {
3798 RateLimitResult::Allowed => None,
3799 RateLimitResult::Warned(msg) => {
3800 tracing::warn!(key = %key, reason = %msg, "rate limit warning");
3801 None
3802 }
3803 RateLimitResult::Throttled(d) | RateLimitResult::Queued(d) => {
3804 tokio::time::sleep(d.min(Duration::from_secs(5))).await;
3807 None
3808 }
3809 RateLimitResult::Denied(exc) => {
3810 tracing::info!(key = %key, "rate limit exceeded");
3811 let msg = format!(
3812 "rate limit exceeded: {} (retry after {}ms)",
3813 exc.message,
3814 exc.retry_after.as_millis()
3815 );
3816 Some(Self::create_error_response("53400", &msg))
3817 }
3818 }
3819 }
3820
3821 fn is_backend_fault(err: &str) -> bool {
3847 !err.contains("Client") && !err.contains("Backend read timeout")
3848 }
3849
3850 fn note_backend_failure(state: &Arc<ServerState>, addr: &str, err: &str) {
3854 if !Self::is_backend_fault(err) {
3855 return;
3856 }
3857 let _writers = state.health_write.lock();
3866 let snapshot = state.health.load_full();
3867 if snapshot.get(addr).map(|h| h.healthy).unwrap_or(false) {
3870 let mut next = (*snapshot).clone();
3871 if let Some(nh) = next.get_mut(addr) {
3872 nh.healthy = false;
3873 nh.failure_count = nh.failure_count.saturating_add(1);
3874 nh.last_error = Some(format!("in-band failure: {}", err));
3875 tracing::warn!(
3876 node = %addr,
3877 error = %err,
3878 "in-band failure — node marked unhealthy for fast failover"
3879 );
3880 }
3881 state.health.store(Arc::new(next));
3882 }
3883 }
3884
3885 fn record_backend_failure(state: &Arc<ServerState>, node: &str, err: &str) {
3891 Self::note_backend_failure(state, node, err);
3892 #[cfg(feature = "circuit-breaker")]
3893 if Self::is_backend_fault(err) {
3894 Self::circuit_record(state, node, false, err);
3895 }
3896 }
3897
3898 #[cfg(feature = "circuit-breaker")]
3901 fn circuit_is_open(state: &Arc<ServerState>, node: &str) -> bool {
3902 state
3903 .circuit_breaker
3904 .as_ref()
3905 .map(|cb| {
3906 cb.get_breaker(node).get_state() == crate::circuit_breaker::CircuitState::Open
3907 })
3908 .unwrap_or(false)
3909 }
3910
3911 #[cfg(feature = "circuit-breaker")]
3913 fn circuit_record(state: &Arc<ServerState>, node: &str, success: bool, err: &str) {
3914 if let Some(cb) = state.circuit_breaker.as_ref() {
3915 let breaker = cb.get_breaker(node);
3916 if success {
3917 breaker.record_success();
3918 } else {
3919 breaker.record_failure(err);
3920 }
3921 }
3922 }
3923
3924 #[cfg(feature = "circuit-breaker")]
3928 fn circuit_fast_fail(state: &Arc<ServerState>, node: &str) -> Option<Vec<u8>> {
3929 if Self::circuit_is_open(state, node) {
3930 tracing::info!(node = %node, "circuit open — fast-failing");
3931 Some(Self::create_error_response(
3932 "08006",
3933 &format!("circuit open for node {node}: backend temporarily unavailable"),
3934 ))
3935 } else {
3936 None
3937 }
3938 }
3939
3940 #[cfg(feature = "lag-routing")]
3943 fn ryw_pins_primary(last_write: Option<std::time::Instant>, window_ms: u64) -> bool {
3944 window_ms > 0
3945 && last_write
3946 .map(|t| t.elapsed() < Duration::from_millis(window_ms))
3947 .unwrap_or(false)
3948 }
3949
3950 #[cfg(feature = "lag-routing")]
3954 fn lag_excludes_standby(lag_bytes: Option<u64>, max_lag_bytes: u64) -> bool {
3955 max_lag_bytes > 0 && lag_bytes.map(|l| l > max_lag_bytes).unwrap_or(false)
3956 }
3957
3958 #[cfg(feature = "query-cache")]
3961 fn is_cacheable_read_sql(sql: &str) -> bool {
3962 use crate::protocol::{contains_ci, starts_with_ci};
3963 let t = sql.trim_start();
3964 if !starts_with_ci(t, "SELECT") {
3965 return false;
3966 }
3967 if contains_ci(t, "FOR UPDATE") || contains_ci(t, "FOR SHARE") {
3968 return false;
3969 }
3970 const VOLATILE: [&str; 10] = [
3972 "now(",
3973 "current_timestamp",
3974 "current_date",
3975 "current_time",
3976 "clock_timestamp",
3977 "statement_timestamp",
3978 "random(",
3979 "nextval(",
3980 "uuid_generate",
3981 "gen_random_uuid",
3982 ];
3983 !VOLATILE.iter().any(|v| contains_ci(t, v))
3984 }
3985
3986 #[cfg(feature = "query-cache")]
3990 async fn cacheable_read_ctx(
3991 session: &Arc<ClientSession>,
3992 sql: &str,
3993 ) -> Option<crate::cache::CacheContext> {
3994 if !Self::is_cacheable_read_sql(sql) {
3995 return None;
3996 }
3997 if session.tx_state.read().await.in_transaction {
3999 return None;
4000 }
4001 let (user, database) = {
4002 let vars = session.variables.read().await;
4003 (
4004 vars.get("user").cloned(),
4005 vars.get("database")
4006 .cloned()
4007 .unwrap_or_else(|| "default".to_string()),
4008 )
4009 };
4010 Some(crate::cache::CacheContext {
4011 database,
4012 user,
4013 branch: None,
4014 connection_id: Some(session.id.as_u64_pair().0),
4015 })
4016 }
4017
4018 #[cfg(feature = "multi-tenancy")]
4022 async fn tenant_request_ctx(
4023 session: &Arc<ClientSession>,
4024 ) -> crate::multi_tenancy::RequestContext {
4025 let vars = session.variables.read().await;
4026 crate::multi_tenancy::RequestContext {
4027 headers: vars.clone(),
4028 username: vars.get("user").cloned(),
4029 database: vars.get("database").cloned(),
4030 auth_token: None,
4031 sql_context: HashMap::new(),
4032 client_ip: Some(session.client_addr.ip().to_string()),
4033 connection_id: Some(session.id.as_u64_pair().0),
4034 }
4035 }
4036
4037 #[cfg(feature = "ha-tr")]
4042 async fn journal_write(state: &Arc<ServerState>, session: &Arc<ClientSession>, sql: &str) {
4043 let tx_id = uuid::Uuid::new_v4();
4044 let j = &state.transaction_journal;
4045 if j.begin_transaction(tx_id, session.id, crate::NodeId::new(), 0)
4046 .await
4047 .is_ok()
4048 {
4049 let _ = j
4050 .log_statement(tx_id, sql.to_string(), Vec::new(), None, None, 0)
4051 .await;
4052 }
4053 }
4054
4055 #[cfg(feature = "query-analytics")]
4058 async fn record_analytics(
4059 state: &Arc<ServerState>,
4060 session: &Arc<ClientSession>,
4061 sql: &str,
4062 node: &str,
4063 duration: Duration,
4064 error: Option<String>,
4065 ) {
4066 let Some(analytics) = state.analytics.as_ref() else {
4067 return;
4068 };
4069 let (user, database) = {
4070 let vars = session.variables.read().await;
4071 (
4072 vars.get("user").cloned().unwrap_or_default(),
4073 vars.get("database").cloned().unwrap_or_default(),
4074 )
4075 };
4076 let mut exec = crate::analytics::QueryExecution::new(sql, duration);
4077 exec.user = user;
4078 exec.database = database;
4079 exec.client_ip = session.client_addr.ip().to_string();
4080 exec.node = node.to_string();
4081 exec.session_id = Some(session.id.to_string());
4082 exec.error = error;
4083 analytics.record(exec);
4084 }
4085
4086 async fn select_primary_with_timeout(
4088 session: &Arc<ClientSession>,
4089 state: &Arc<ServerState>,
4090 config: &ProxyConfig,
4091 ) -> Result<String> {
4092 let timeout = config.write_timeout();
4093 let start = std::time::Instant::now();
4094 let check_interval = Duration::from_millis(100);
4097
4098 loop {
4099 let health = state.health.load_full();
4101 let primary = config
4102 .nodes
4103 .iter()
4104 .find(|n| n.role == NodeRole::Primary && n.enabled);
4105
4106 if let Some(primary_node) = primary {
4107 if let Some(node_health) = health.get(&primary_node.address()) {
4108 if node_health.healthy {
4109 let mut current = session.current_node.write().await;
4111 *current = Some(primary_node.address());
4112 return Ok(primary_node.address());
4113 }
4114 }
4115 }
4116 drop(health);
4117
4118 if start.elapsed() >= timeout {
4120 state.metrics.failovers.fetch_add(1, Ordering::Relaxed);
4121 return Err(ProxyError::NoHealthyNodes);
4122 }
4123
4124 tracing::warn!(
4125 "Primary unavailable, waiting for failover... ({:.1}s elapsed, {:.1}s timeout)",
4126 start.elapsed().as_secs_f64(),
4127 timeout.as_secs_f64()
4128 );
4129
4130 tokio::time::sleep(check_interval).await;
4132 }
4133 }
4134
4135 async fn select_read_node(
4137 session: &Arc<ClientSession>,
4138 state: &Arc<ServerState>,
4139 config: &ProxyConfig,
4140 ) -> Result<String> {
4141 {
4143 let tx_state = session.tx_state.read().await;
4144 if tx_state.in_transaction {
4145 if let Some(node) = session.current_node.read().await.clone() {
4146 return Ok(node);
4147 }
4148 }
4149 }
4150
4151 let health = state.health.load_full();
4153 let healthy_standbys: Vec<&NodeConfig> = config
4154 .nodes
4155 .iter()
4156 .filter(|n| {
4157 let base = n.enabled
4158 && (n.role == NodeRole::Standby || n.role == NodeRole::ReadReplica)
4159 && health.get(&n.address()).map(|h| h.healthy).unwrap_or(false);
4160 #[cfg(feature = "circuit-breaker")]
4162 let base = base && !Self::circuit_is_open(state, &n.address());
4163 #[cfg(feature = "lag-routing")]
4165 let base = base
4166 && !Self::lag_excludes_standby(
4167 health
4168 .get(&n.address())
4169 .and_then(|h| h.replication_lag_bytes),
4170 config.lag_routing.max_lag_bytes,
4171 );
4172 base
4173 })
4174 .collect();
4175
4176 if !healthy_standbys.is_empty() {
4177 let ticket = state.lb_state.rr_counter.fetch_add(1, Ordering::Relaxed);
4179 let index = ticket as usize % healthy_standbys.len();
4180 let node_addr = healthy_standbys[index].address();
4181
4182 let mut current = session.current_node.write().await;
4183 *current = Some(node_addr.clone());
4184 return Ok(node_addr);
4185 }
4186
4187 Self::select_node(session, state, config).await
4189 }
4190
4191 async fn complete_backend_auth(backend: &mut TcpStream) -> Result<()> {
4194 let codec = ProtocolCodec::new();
4195 let mut buffer = BytesMut::with_capacity(4096);
4196 let mut read_buf = vec![0u8; 4096];
4197 let timeout = Duration::from_secs(10);
4198 let start = std::time::Instant::now();
4199
4200 loop {
4201 if start.elapsed() > timeout {
4202 return Err(ProxyError::Auth(
4203 "Backend authentication timeout".to_string(),
4204 ));
4205 }
4206
4207 let n = tokio::time::timeout(Duration::from_secs(5), backend.read(&mut read_buf))
4208 .await
4209 .map_err(|_| ProxyError::Auth("Read timeout during backend auth".to_string()))?
4210 .map_err(|e| ProxyError::Network(format!("Backend auth read error: {}", e)))?;
4211
4212 if n == 0 {
4213 return Err(ProxyError::Connection(
4214 "Backend closed during auth".to_string(),
4215 ));
4216 }
4217
4218 buffer.extend_from_slice(&read_buf[..n]);
4219
4220 while let Some(msg) = codec.decode_message(&mut buffer)? {
4223 match msg.msg_type {
4224 MessageType::ReadyForQuery => {
4225 return Ok(());
4227 }
4228 MessageType::ErrorResponse => {
4229 let err = ErrorResponse::parse(msg.payload)
4230 .map(|e| e.message().unwrap_or("Unknown error").to_string())
4231 .unwrap_or_else(|_| "Parse error".to_string());
4232 return Err(ProxyError::Auth(err));
4233 }
4234 _ => {
4235 }
4237 }
4238 }
4239 }
4240 }
4241
4242 fn create_error_response(code: &str, message: &str) -> Vec<u8> {
4244 let mut fields = HashMap::new();
4245 fields.insert('S', "ERROR".to_string());
4246 fields.insert('V', "ERROR".to_string());
4247 fields.insert('C', code.to_string());
4248 fields.insert('M', message.to_string());
4249
4250 let err = ErrorResponse { fields };
4251 err.encode().encode().to_vec()
4252 }
4253
4254 fn create_ready_for_query(status: u8) -> Vec<u8> {
4257 let mut payload = BytesMut::with_capacity(1);
4258 payload.put_u8(status);
4259 Message::new(MessageType::ReadyForQuery, payload)
4260 .encode()
4261 .to_vec()
4262 }
4263
4264 #[cfg(feature = "wasm-plugins")]
4302 fn synthesise_cached_response(bytes: &[u8]) -> Result<Vec<u8>> {
4303 use serde::Deserialize;
4304
4305 #[derive(Deserialize)]
4306 struct CachedPayload {
4307 columns: Vec<ColumnDef>,
4308 rows: Vec<Vec<Option<String>>>,
4309 }
4310
4311 #[derive(Deserialize)]
4312 struct ColumnDef {
4313 name: String,
4314 #[serde(default = "default_text_oid")]
4315 oid: u32,
4316 }
4317
4318 fn default_text_oid() -> u32 {
4319 25 }
4321
4322 let payload: CachedPayload = serde_json::from_slice(bytes)
4323 .map_err(|e| ProxyError::Protocol(format!("invalid cached payload JSON: {}", e)))?;
4324
4325 if payload.columns.is_empty() {
4326 return Err(ProxyError::Protocol(
4327 "cached payload must declare at least one column".to_string(),
4328 ));
4329 }
4330
4331 let mut reply = Vec::new();
4332
4333 let mut rd = BytesMut::new();
4335 rd.put_u16(payload.columns.len() as u16);
4336 for col in &payload.columns {
4337 rd.extend_from_slice(col.name.as_bytes());
4338 rd.put_u8(0); rd.put_i32(0); rd.put_i16(0); rd.put_u32(col.oid);
4342 rd.put_i16(-1); rd.put_i32(-1); rd.put_i16(0); }
4346 reply.extend_from_slice(&Message::new(MessageType::RowDescription, rd).encode());
4347
4348 let column_count = payload.columns.len();
4350 for row in &payload.rows {
4351 if row.len() != column_count {
4352 return Err(ProxyError::Protocol(format!(
4353 "cached row has {} values but {} columns are declared",
4354 row.len(),
4355 column_count
4356 )));
4357 }
4358 let mut dr = BytesMut::new();
4359 dr.put_u16(row.len() as u16);
4360 for value in row {
4361 match value {
4362 Some(s) => {
4363 dr.put_i32(s.len() as i32);
4364 dr.extend_from_slice(s.as_bytes());
4365 }
4366 None => {
4367 dr.put_i32(-1); }
4369 }
4370 }
4371 reply.extend_from_slice(&Message::new(MessageType::DataRow, dr).encode());
4372 }
4373
4374 let tag = format!("SELECT {}", payload.rows.len());
4376 let mut cc = BytesMut::new();
4377 cc.extend_from_slice(tag.as_bytes());
4378 cc.put_u8(0);
4379 reply.extend_from_slice(&Message::new(MessageType::CommandComplete, cc).encode());
4380
4381 reply.extend_from_slice(&Self::create_ready_for_query(b'I'));
4383
4384 Ok(reply)
4385 }
4386
4387 fn apply_pre_query_hook(
4397 msg: Message,
4398 state: &Arc<ServerState>,
4399 session: &Arc<ClientSession>,
4400 ) -> (Message, PreQueryAction) {
4401 #[cfg(feature = "wasm-plugins")]
4402 {
4403 let pm = match state.plugin_manager.as_ref() {
4404 Some(pm) => pm,
4405 None => return (msg, PreQueryAction::Forward),
4406 };
4407
4408 if msg.msg_type != MessageType::Query {
4409 return (msg, PreQueryAction::Forward);
4410 }
4411
4412 if !pm.has_hook(HookType::PreQuery) {
4415 return (msg, PreQueryAction::Forward);
4416 }
4417
4418 let query_msg = match QueryMessage::parse(msg.payload.clone()) {
4419 Ok(q) => q,
4420 Err(_) => return (msg, PreQueryAction::Forward),
4421 };
4422
4423 let ctx = Self::build_query_context(&query_msg.query, session);
4424
4425 match pm.execute_pre_query(&ctx) {
4426 PreQueryResult::Continue => (msg, PreQueryAction::Forward),
4427 PreQueryResult::Block(reason) => (msg, PreQueryAction::Block(reason)),
4428 PreQueryResult::Rewrite(new_sql) => {
4429 let rewritten = QueryMessage { query: new_sql }.encode();
4430 (rewritten, PreQueryAction::Forward)
4431 }
4432 PreQueryResult::Cached(bytes) => (msg, PreQueryAction::Cached(bytes)),
4433 }
4434 }
4435 #[cfg(not(feature = "wasm-plugins"))]
4436 {
4437 let _ = (state, session);
4438 (msg, PreQueryAction::Forward)
4439 }
4440 }
4441
4442 #[cfg(feature = "anomaly-detection")]
4448 fn record_anomaly_observation(
4449 msg: &Message,
4450 state: &Arc<ServerState>,
4451 session: &Arc<ClientSession>,
4452 ) {
4453 if msg.msg_type != MessageType::Query {
4454 return;
4455 }
4456 if let Some(query) = crate::protocol::query_text(&msg.payload) {
4459 Self::record_anomaly_sql(query, state, session);
4460 }
4461 }
4462
4463 #[cfg(feature = "anomaly-detection")]
4467 fn record_anomaly_sql(query: &str, state: &Arc<ServerState>, session: &Arc<ClientSession>) {
4468 let tenant = match session.variables.try_read() {
4475 Ok(vars) => vars
4476 .get("tenant_id")
4477 .or_else(|| vars.get("user"))
4478 .cloned()
4479 .unwrap_or_else(|| session.client_addr.ip().to_string()),
4480 Err(_) => session.client_addr.ip().to_string(),
4481 };
4482 let fingerprint = anomaly_fingerprint(query);
4483 let obs = crate::anomaly::QueryObservation {
4484 tenant,
4485 fingerprint,
4486 sql: query.to_string(),
4487 timestamp: std::time::Instant::now(),
4488 };
4489 for ev in state.anomaly_detector.record_query(&obs) {
4490 tracing::warn!(anomaly = ?ev, "anomaly detected");
4491 }
4492 }
4493
4494 async fn send_block_response(
4498 stream: &mut ClientStream,
4499 reason: &str,
4500 state: &Arc<ServerState>,
4501 ) -> Result<()> {
4502 let err =
4503 Self::create_error_response("42000", &format!("Query blocked by plugin: {}", reason));
4504 stream
4505 .write_all(&err)
4506 .await
4507 .map_err(|e| ProxyError::Network(format!("Write error: {}", e)))?;
4508 let rfq = Self::create_ready_for_query(b'I');
4509 stream
4510 .write_all(&rfq)
4511 .await
4512 .map_err(|e| ProxyError::Network(format!("Write error: {}", e)))?;
4513 state
4514 .metrics
4515 .bytes_sent
4516 .fetch_add((err.len() + rfq.len()) as u64, Ordering::Relaxed);
4517 Ok(())
4518 }
4519
4520 #[cfg(feature = "wasm-plugins")]
4526 fn build_query_context(query: &str, session: &Arc<ClientSession>) -> QueryContext {
4527 let is_read_only = !Self::is_write_query(query);
4528 let hook_context = HookContext {
4529 client_id: Some(session.id.to_string()),
4530 ..HookContext::default()
4531 };
4532 QueryContext {
4533 query: query.to_string(),
4534 normalized: query.to_string(),
4535 tables: Vec::new(),
4536 is_read_only,
4537 hook_context,
4538 }
4539 }
4540
4541 async fn apply_authenticate_hook(
4562 _params: &HashMap<String, String>,
4563 _session: &Arc<ClientSession>,
4564 _state: &Arc<ServerState>,
4565 ) -> Result<()> {
4566 #[cfg(feature = "wasm-plugins")]
4567 {
4568 let pm = match _state.plugin_manager.as_ref() {
4569 Some(pm) => pm,
4570 None => return Ok(()),
4571 };
4572
4573 let request = PluginAuthRequest {
4574 headers: HashMap::new(),
4575 username: _params.get("user").cloned(),
4576 password: None,
4577 client_ip: _session.client_addr.ip().to_string(),
4578 database: _params.get("database").cloned(),
4579 };
4580
4581 match pm.execute_authenticate(&request) {
4582 AuthResult::Defer => Ok(()),
4583 AuthResult::Success(identity) => {
4584 tracing::debug!(
4585 user = %identity.username,
4586 roles = ?identity.roles,
4587 "plugin authenticated user"
4588 );
4589 *_session.plugin_identity.write().await = Some(identity);
4590 Ok(())
4591 }
4592 AuthResult::Denied(reason) => {
4593 tracing::info!(
4594 reason = %reason,
4595 client = %_session.client_addr,
4596 user = ?_params.get("user"),
4597 "plugin denied authentication"
4598 );
4599 Err(ProxyError::Auth(format!(
4600 "authentication denied by plugin: {}",
4601 reason
4602 )))
4603 }
4604 }
4605 }
4606 #[cfg(not(feature = "wasm-plugins"))]
4607 {
4608 Ok(())
4609 }
4610 }
4611
4612 fn apply_route_hook(
4615 msg: &Message,
4616 state: &Arc<ServerState>,
4617 session: &Arc<ClientSession>,
4618 ) -> RouteOverride {
4619 #[cfg(feature = "wasm-plugins")]
4620 {
4621 let pm = match state.plugin_manager.as_ref() {
4622 Some(pm) => pm,
4623 None => return RouteOverride::None,
4624 };
4625 if msg.msg_type != MessageType::Query {
4626 return RouteOverride::None;
4627 }
4628 if !pm.has_hook(HookType::Route) {
4631 return RouteOverride::None;
4632 }
4633 let query_msg = match QueryMessage::parse(msg.payload.clone()) {
4634 Ok(q) => q,
4635 Err(_) => return RouteOverride::None,
4636 };
4637 let ctx = Self::build_query_context(&query_msg.query, session);
4638 match pm.execute_route(&ctx) {
4639 RouteResult::Default => RouteOverride::None,
4640 RouteResult::Primary => RouteOverride::Primary,
4641 RouteResult::Standby => RouteOverride::Standby,
4642 RouteResult::Node(name) => RouteOverride::Node(name),
4643 RouteResult::Block(reason) => RouteOverride::Block(reason),
4644 RouteResult::Branch(name) => {
4645 tracing::warn!(
4646 branch = %name,
4647 "Route hook returned Branch but branch routing is not yet wired — using default"
4648 );
4649 RouteOverride::None
4650 }
4651 }
4652 }
4653 #[cfg(not(feature = "wasm-plugins"))]
4654 {
4655 let _ = (msg, state, session);
4656 RouteOverride::None
4657 }
4658 }
4659
4660 #[cfg(feature = "routing-hints")]
4666 fn hint_to_override(hints: &crate::routing::ParsedHints) -> RouteOverride {
4667 use crate::routing::{ConsistencyLevel, RouteTarget};
4668 if let Some(node) = &hints.node {
4669 return RouteOverride::Node(node.clone());
4670 }
4671 if let Some(route) = hints.route {
4672 return match route {
4673 RouteTarget::Primary => RouteOverride::Primary,
4674 RouteTarget::Standby
4675 | RouteTarget::Sync
4676 | RouteTarget::SemiSync
4677 | RouteTarget::Async
4678 | RouteTarget::Local => RouteOverride::Standby,
4679 RouteTarget::Any | RouteTarget::Vector => RouteOverride::None,
4680 };
4681 }
4682 if hints.consistency == Some(ConsistencyLevel::Strong) {
4683 return RouteOverride::Primary;
4684 }
4685 RouteOverride::None
4686 }
4687
4688 #[cfg(feature = "routing-hints")]
4696 fn resolve_simple_route(
4697 msg: &Message,
4698 plugin_override: RouteOverride,
4699 default_is_write: bool,
4700 state: &Arc<ServerState>,
4701 ) -> (RouteOverride, bool, Option<Message>) {
4702 let parser = match state.hint_parser.as_ref() {
4703 Some(p) => p,
4704 None => return (plugin_override, default_is_write, None),
4705 };
4706 let sql = match crate::protocol::query_text(&msg.payload) {
4707 Some(s) => s,
4708 None => return (plugin_override, default_is_write, None),
4709 };
4710 let hints = parser.parse(sql);
4711 if hints.is_empty() {
4712 return (plugin_override, default_is_write, None);
4713 }
4714 let stripped = parser.strip(sql);
4715 let is_write = Self::is_write_query(&stripped);
4716 let effective = match Self::hint_to_override(&hints) {
4717 RouteOverride::None => plugin_override,
4718 hint_override => hint_override,
4719 };
4720 let forward = if parser.strip_hints {
4721 Some(crate::protocol::QueryMessage { query: stripped }.encode())
4722 } else {
4723 None
4724 };
4725 (effective, is_write, forward)
4726 }
4727
4728 #[cfg(feature = "routing-hints")]
4735 fn extended_hint_route(state: &Arc<ServerState>, sql: &str) -> Option<(bool, Option<String>)> {
4736 let parser = state.hint_parser.as_ref()?;
4737 let hints = parser.parse(sql);
4738 if hints.is_empty() {
4739 return None;
4740 }
4741 let stripped = parser.strip(sql);
4742 let is_write = Self::is_write_query(&stripped);
4743 match Self::hint_to_override(&hints) {
4744 RouteOverride::Primary => Some((true, None)),
4745 RouteOverride::Standby => Some((false, None)),
4746 RouteOverride::Node(n) => Some((is_write, Some(n))),
4747 _ => Some((is_write, None)),
4748 }
4749 }
4750
4751 #[cfg(feature = "wasm-plugins")]
4755 fn fire_post_query_hook(
4756 msg: &Message,
4757 session: &Arc<ClientSession>,
4758 state: &Arc<ServerState>,
4759 result: &Result<(Option<String>, u64)>,
4760 elapsed: Duration,
4761 ) {
4762 let pm = match state.plugin_manager.as_ref() {
4763 Some(pm) => pm,
4764 None => return,
4765 };
4766 if msg.msg_type != MessageType::Query {
4767 return;
4768 }
4769 if !pm.has_hook(HookType::PostQuery) {
4772 return;
4773 }
4774 let query_msg = match QueryMessage::parse(msg.payload.clone()) {
4775 Ok(q) => q,
4776 Err(_) => return,
4777 };
4778 let ctx = Self::build_query_context(&query_msg.query, session);
4779 let outcome = match result {
4780 Ok((node, bytes)) => PostQueryOutcome {
4781 success: true,
4782 target_node: node.clone(),
4783 elapsed_us: elapsed.as_micros() as u64,
4784 response_bytes: *bytes,
4785 error: None,
4786 },
4787 Err(e) => PostQueryOutcome {
4788 success: false,
4789 target_node: None,
4790 elapsed_us: elapsed.as_micros() as u64,
4791 response_bytes: 0,
4792 error: Some(e.to_string()),
4793 },
4794 };
4795 pm.execute_post_query(&ctx, &outcome);
4796 }
4797
4798 async fn select_node(
4802 session: &Arc<ClientSession>,
4803 state: &Arc<ServerState>,
4804 config: &ProxyConfig,
4805 ) -> Result<String> {
4806 {
4808 let tx_state = session.tx_state.read().await;
4809 if tx_state.in_transaction {
4810 if let Some(node) = session.current_node.read().await.clone() {
4811 return Ok(node);
4812 }
4813 }
4814 }
4815
4816 let health = state.health.load_full();
4818 let healthy_nodes: Vec<&NodeConfig> = config
4819 .nodes
4820 .iter()
4821 .filter(|n| n.enabled && health.get(&n.address()).map(|h| h.healthy).unwrap_or(false))
4822 .collect();
4823
4824 if healthy_nodes.is_empty() {
4825 return Err(ProxyError::NoHealthyNodes);
4826 }
4827
4828 if let Some(primary) = healthy_nodes.iter().find(|n| n.role == NodeRole::Primary) {
4830 let node_addr = primary.address();
4831 let mut current = session.current_node.write().await;
4832 *current = Some(node_addr.clone());
4833 return Ok(node_addr);
4834 }
4835
4836 if let Some(standby) = healthy_nodes.iter().find(|n| n.role == NodeRole::Standby) {
4839 tracing::warn!("Primary unavailable, connecting to standby for initial session");
4840 let node_addr = standby.address();
4841 let mut current = session.current_node.write().await;
4842 *current = Some(node_addr.clone());
4843 return Ok(node_addr);
4844 }
4845
4846 Err(ProxyError::NoHealthyNodes)
4848 }
4849
4850 fn spawn_health_checker(&self) -> tokio::task::JoinHandle<()> {
4852 let state = self.state.clone();
4853 let mut shutdown_rx = self.shutdown_tx.subscribe();
4854
4855 tokio::spawn(async move {
4856 let mut interval = tokio::time::interval(std::time::Duration::from_secs(
4857 state.live_config.load().health.check_interval_secs,
4858 ));
4859
4860 loop {
4861 tokio::select! {
4862 _ = interval.tick() => {
4863 let config = state.live_config.load_full();
4866 Self::check_all_nodes(&state, &config).await;
4867 }
4868 _ = shutdown_rx.recv() => {
4869 break;
4870 }
4871 }
4872 }
4873 })
4874 }
4875
4876 async fn check_all_nodes(state: &Arc<ServerState>, config: &ProxyConfig) {
4883 let timeout = Duration::from_secs(config.health.check_timeout_secs);
4886 let mut set = tokio::task::JoinSet::new();
4887 for node in &config.nodes {
4888 let addr = node.address();
4889 set.spawn(async move {
4890 let r = Self::check_node_addr(&addr, timeout).await;
4891 (addr, r)
4892 });
4893 }
4894 let mut results = Vec::with_capacity(config.nodes.len());
4895 while let Some(joined) = set.join_next().await {
4896 if let Ok(pair) = joined {
4897 results.push(pair);
4898 }
4899 }
4900
4901 let _writers = state.health_write.lock();
4907 let mut next = (*state.health.load_full()).clone();
4908 for (addr, result) in results {
4909 if let Some(node_health) = next.get_mut(&addr) {
4910 match result {
4911 Ok(latency) => {
4912 node_health.healthy = true;
4913 node_health.failure_count = 0;
4914 node_health.latency_ms = latency;
4915 node_health.last_error = None;
4916 }
4917 Err(e) => {
4918 node_health.failure_count += 1;
4919 node_health.last_error = Some(e.to_string());
4920 if node_health.failure_count >= config.health.failure_threshold {
4921 node_health.healthy = false;
4922 tracing::warn!(
4923 "Node {} marked unhealthy after {} failures",
4924 addr,
4925 node_health.failure_count
4926 );
4927 }
4928 }
4929 }
4930 node_health.last_check = chrono::Utc::now();
4931 }
4932 }
4933 state.health.store(Arc::new(next));
4934 }
4935
4936 async fn check_node_addr(addr: &str, timeout: Duration) -> Result<f64> {
4947 const SSL_REQUEST: [u8; 8] = [0, 0, 0, 8, 0x04, 0xD2, 0x16, 0x2F];
4949 let start = std::time::Instant::now();
4950 let mut stream = tokio::time::timeout(timeout, TcpStream::connect(addr))
4951 .await
4952 .map_err(|_| ProxyError::HealthCheck(format!("Timeout connecting to {}", addr)))?
4953 .map_err(|e| {
4954 ProxyError::HealthCheck(format!("Failed to connect to {}: {}", addr, e))
4955 })?;
4956
4957 let probe = async {
4958 stream.write_all(&SSL_REQUEST).await?;
4959 let mut resp = [0u8; 1];
4960 stream.read_exact(&mut resp).await?;
4961 Ok::<u8, std::io::Error>(resp[0])
4962 };
4963 let remaining = timeout
4965 .saturating_sub(start.elapsed())
4966 .max(Duration::from_millis(1));
4967 let byte = tokio::time::timeout(remaining, probe)
4968 .await
4969 .map_err(|_| {
4970 ProxyError::HealthCheck(format!("{} did not answer protocol probe in time", addr))
4971 })?
4972 .map_err(|e| ProxyError::HealthCheck(format!("{} protocol probe error: {}", addr, e)))?;
4973 if byte != b'S' && byte != b'N' {
4976 return Err(ProxyError::HealthCheck(format!(
4977 "{} sent unexpected probe reply {:#x}",
4978 addr, byte
4979 )));
4980 }
4981 let latency = start.elapsed().as_secs_f64() * 1000.0;
4982 Ok(latency)
4983 }
4984
    /// Spawns the background pool-maintenance task.
    ///
    /// With the `pool-modes` feature, each `Self::POOL_REAP_INTERVAL` tick:
    /// - evicts idle connections from the pool-mode manager, if one exists;
    /// - reaps backend-pool connections idle longer than
    ///   `pool_mode.idle_timeout_secs`.
    ///
    /// On shutdown it closes all pool-mode connections and exits. Without
    /// `pool-modes` the task only ticks and waits for shutdown.
    fn spawn_pool_manager(&self) -> tokio::task::JoinHandle<()> {
        // `state` is only captured when pool-modes is enabled; every use below
        // is gated on the same feature.
        #[cfg(feature = "pool-modes")]
        let state = self.state.clone();
        let mut shutdown_rx = self.shutdown_tx.subscribe();

        tokio::spawn(async move {
            let mut interval = tokio::time::interval(Self::POOL_REAP_INTERVAL);

            loop {
                tokio::select! {
                    _ = interval.tick() => {
                        #[cfg(feature = "pool-modes")]
                        if let Some(ref pool_manager) = state.pool_manager {
                            pool_manager.evict_idle().await;
                            tracing::trace!("Pool-modes idle eviction completed");
                        }
                        #[cfg(feature = "pool-modes")]
                        if let Some(ref backend_pool) = state.backend_pool {
                            // The TTL is re-read from the live config on every
                            // tick, so config reloads apply without a restart.
                            let ttl = std::time::Duration::from_secs(
                                state.live_config.load().pool_mode.idle_timeout_secs,
                            );
                            let n = if ttl.is_zero() {
                                // A zero TTL disables TTL-based reaping entirely.
                                0
                            } else {
                                backend_pool.reap_idle(ttl)
                            };
                            if n > 0 {
                                tracing::debug!(
                                    target: "helios::pool",
                                    reaped = n,
                                    idle_remaining = backend_pool.idle_count(),
                                    "reaped idle backend connections (TTL)"
                                );
                            }
                        }
                    }
                    _ = shutdown_rx.recv() => {
                        #[cfg(feature = "pool-modes")]
                        if let Some(ref pool_manager) = state.pool_manager {
                            pool_manager.close_all().await;
                            tracing::info!("Pool-modes manager closed all connections");
                        }
                        break;
                    }
                }
            }
        })
    }
5046
5047 pub fn shutdown(&self) {
5049 let _ = self.shutdown_tx.send(());
5050 }
5051
5052 #[cfg(feature = "pool-modes")]
5054 pub async fn pool_mode_stats(&self) -> Option<PoolModeStatsSnapshot> {
5055 if let Some(ref pool_manager) = self.state.pool_manager {
5056 let stats = pool_manager.get_stats().await;
5057 let metrics = pool_manager.metrics().snapshot();
5058 let default_mode = pool_manager.default_mode();
5059
5060 let avg_lease_duration_ms = metrics
5062 .mode_stats
5063 .get(&default_mode)
5064 .map(|s| s.avg_lease_duration_ms as u64)
5065 .unwrap_or(0);
5066
5067 Some(PoolModeStatsSnapshot {
5068 mode: format!("{:?}", default_mode),
5069 total_connections: stats.total_connections,
5070 active_leases: stats.active_connections,
5071 idle_connections: stats.idle_connections,
5072 node_count: stats.node_count,
5073 acquires: metrics.acquires,
5074 releases: metrics.releases,
5075 acquire_failures: metrics.acquire_failures,
5076 acquire_timeouts: metrics.acquire_timeouts,
5077 transactions_completed: metrics.transactions_completed,
5078 statements_executed: metrics.statements_executed,
5079 avg_lease_duration_ms,
5080 })
5081 } else {
5082 None
5083 }
5084 }
5085
5086 #[cfg(feature = "pool-modes")]
5088 pub async fn add_node_to_pool(&self, node: &NodeConfig) {
5089 if let Some(ref pool_manager) = self.state.pool_manager {
5090 let endpoint = NodeEndpoint::new(&node.host, node.port)
5091 .with_role(match node.role {
5092 NodeRole::Primary => crate::NodeRole::Primary,
5093 NodeRole::Standby => crate::NodeRole::Standby,
5094 NodeRole::ReadReplica => crate::NodeRole::ReadReplica,
5095 })
5096 .with_weight(node.weight);
5097 pool_manager.add_node(&endpoint).await;
5098 tracing::info!("Added node {} to pool manager", node.address());
5099 }
5100 }
5101
5102 pub fn metrics(&self) -> ServerMetricsSnapshot {
5104 ServerMetricsSnapshot {
5105 connections_accepted: self
5106 .state
5107 .metrics
5108 .connections_accepted
5109 .load(Ordering::Relaxed),
5110 connections_closed: self
5111 .state
5112 .metrics
5113 .connections_closed
5114 .load(Ordering::Relaxed),
5115 queries_processed: self.state.metrics.queries_processed.load(Ordering::Relaxed),
5116 bytes_received: self.state.metrics.bytes_received.load(Ordering::Relaxed),
5117 bytes_sent: self.state.metrics.bytes_sent.load(Ordering::Relaxed),
5118 failovers: self.state.metrics.failovers.load(Ordering::Relaxed),
5119 }
5120 }
5121}
5122
/// Point-in-time copy of the proxy's global counters, as returned by
/// `ProxyServer::metrics`.
///
/// `Default` yields an all-zero snapshot, and `PartialEq`/`Eq` allow
/// snapshots to be compared directly (e.g. in tests).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ServerMetricsSnapshot {
    /// Value of the `connections_accepted` counter.
    pub connections_accepted: u64,
    /// Value of the `connections_closed` counter.
    pub connections_closed: u64,
    /// Value of the `queries_processed` counter.
    pub queries_processed: u64,
    /// Value of the `bytes_received` counter.
    pub bytes_received: u64,
    /// Value of the `bytes_sent` counter.
    pub bytes_sent: u64,
    /// Value of the `failovers` counter. Also incremented when a write gives
    /// up waiting for a healthy primary.
    pub failovers: u64,
}
5133
#[cfg(feature = "pool-modes")]
/// Snapshot of pool-mode manager statistics, as returned by
/// `ProxyServer::pool_mode_stats`.
///
/// `Default`/`PartialEq`/`Eq` are derived so snapshots can be built
/// and compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct PoolModeStatsSnapshot {
    /// Debug-formatted name of the manager's default pooling mode.
    pub mode: String,
    /// Total pooled connections.
    pub total_connections: usize,
    /// Connections currently leased (the manager's `active_connections`).
    pub active_leases: usize,
    /// Idle pooled connections.
    pub idle_connections: usize,
    /// Number of backend nodes known to the pool.
    pub node_count: usize,
    /// Lease acquisitions.
    pub acquires: u64,
    /// Lease releases.
    pub releases: u64,
    /// Failed lease acquisitions.
    pub acquire_failures: u64,
    /// Lease acquisitions that timed out.
    pub acquire_timeouts: u64,
    /// Transactions completed through the pool.
    pub transactions_completed: u64,
    /// Statements executed through the pool.
    pub statements_executed: u64,
    /// Average lease duration (ms) for the default mode; 0 if no stats exist.
    pub avg_lease_duration_ms: u64,
}
5163
5164#[cfg(test)]
5165mod tests {
5166 use super::*;
5167 use crate::config::{HealthConfig, LoadBalancerConfig, PoolConfig};
5168 #[cfg(not(feature = "wasm-plugins"))]
5169 use crate::protocol::QueryMessage;
5170
5171 fn test_config() -> ProxyConfig {
5172 let mut config = ProxyConfig::default();
5173 config.listen_address = "127.0.0.1:0".to_string();
5174 config.add_node("127.0.0.1:5432", "primary").unwrap();
5175 config
5176 }
5177
5178 #[test]
5179 fn test_server_creation() {
5180 let config = test_config();
5181 let server = ProxyServer::new(config);
5182 assert!(server.is_ok());
5183 }
5184
5185 #[test]
5186 fn is_backend_fault_excludes_client_and_slow_query_errors() {
5187 assert!(ProxyServer::is_backend_fault("Backend read error: connection reset"));
5189 assert!(ProxyServer::is_backend_fault("Backend write error: broken pipe"));
5190 assert!(ProxyServer::is_backend_fault("Backend write timeout"));
5191 assert!(ProxyServer::is_backend_fault(
5192 "Failed to connect to 127.0.0.1:5432: Connection refused"
5193 ));
5194 assert!(!ProxyServer::is_backend_fault("Backend read timeout"));
5197 assert!(!ProxyServer::is_backend_fault("Client write timeout"));
5198 assert!(!ProxyServer::is_backend_fault("Client write error: broken pipe"));
5199 assert!(!ProxyServer::is_backend_fault("Backend read timeout"));
5201 assert!(ProxyServer::is_backend_fault("Backend read error: timed out"));
5202 }
5203
5204 #[test]
5205 fn test_hba_addr_matches() {
5206 use std::net::IpAddr;
5207 let v4 = |s: &str| s.parse::<IpAddr>().unwrap();
5208 assert!(ProxyServer::hba_addr_matches("all", v4("203.0.113.7")));
5210 assert!(ProxyServer::hba_addr_matches("10.0.0.0/8", v4("10.1.2.3")));
5212 assert!(!ProxyServer::hba_addr_matches("10.0.0.0/8", v4("11.1.2.3")));
5213 assert!(ProxyServer::hba_addr_matches(
5214 "127.0.0.1/32",
5215 v4("127.0.0.1")
5216 ));
5217 assert!(!ProxyServer::hba_addr_matches(
5218 "127.0.0.1/32",
5219 v4("127.0.0.2")
5220 ));
5221 assert!(ProxyServer::hba_addr_matches(
5223 "192.168.1.1",
5224 v4("192.168.1.1")
5225 ));
5226 assert!(!ProxyServer::hba_addr_matches(
5227 "192.168.1.1",
5228 v4("192.168.1.2")
5229 ));
5230 assert!(ProxyServer::hba_addr_matches("::1/128", v4("::1")));
5232 assert!(ProxyServer::hba_addr_matches("0.0.0.0/0", v4("8.8.8.8")));
5233 }
5234
5235 #[test]
5236 fn test_hba_admits() {
5237 use crate::config::{HbaAction, HbaRule};
5238 use std::net::IpAddr;
5239 let ip: IpAddr = "10.0.0.5".parse().unwrap();
5240 assert!(ProxyServer::hba_admits(&[], ip, "bench", "benchdb"));
5242 let rules = vec![HbaRule {
5244 action: HbaAction::Reject,
5245 user: "bench".into(),
5246 database: "all".into(),
5247 address: "all".into(),
5248 }];
5249 assert!(!ProxyServer::hba_admits(&rules, ip, "bench", "benchdb"));
5250 assert!(ProxyServer::hba_admits(&rules, ip, "alice", "benchdb"));
5251 let rules = vec![
5253 HbaRule {
5254 action: HbaAction::Allow,
5255 user: "bench".into(),
5256 database: "all".into(),
5257 address: "10.0.0.0/8".into(),
5258 },
5259 HbaRule {
5260 action: HbaAction::Reject,
5261 user: "all".into(),
5262 database: "all".into(),
5263 address: "all".into(),
5264 },
5265 ];
5266 assert!(ProxyServer::hba_admits(&rules, ip, "bench", "benchdb"));
5267 assert!(!ProxyServer::hba_admits(
5268 &rules,
5269 "192.168.0.1".parse().unwrap(),
5270 "bench",
5271 "benchdb"
5272 ));
5273 assert!(!ProxyServer::hba_admits(&rules, ip, "alice", "benchdb"));
5274 }
5275
5276 #[test]
5277 fn test_initial_metrics() {
5278 let config = test_config();
5279 let server = ProxyServer::new(config).unwrap();
5280 let metrics = server.metrics();
5281 assert_eq!(metrics.connections_accepted, 0);
5282 assert_eq!(metrics.queries_processed, 0);
5283 }
5284
5285 #[tokio::test]
5286 async fn test_session_creation() {
5287 let config = test_config();
5288 let server = ProxyServer::new(config).unwrap();
5289
5290 let sessions = server.state.sessions.read().await;
5291 assert!(sessions.is_empty());
5292 }
5293
5294 #[tokio::test]
5295 async fn test_node_health_initialization() {
5296 let config = test_config();
5297 let server = ProxyServer::new(config).unwrap();
5298
5299 let health = server.state.health.load_full();
5300 assert!(!health.is_empty());
5301
5302 for node_health in health.values() {
5303 assert!(node_health.healthy);
5304 assert_eq!(node_health.failure_count, 0);
5305 }
5306 }
5307
5308 fn make_test_session() -> Arc<ClientSession> {
5310 Arc::new(ClientSession {
5311 id: Uuid::new_v4(),
5312 client_addr: "127.0.0.1:0".parse().unwrap(),
5313 current_node: RwLock::new(None),
5314 tx_state: RwLock::new(TransactionState::default()),
5315 variables: RwLock::new(HashMap::new()),
5316 created_at: chrono::Utc::now(),
5317 tr_mode: crate::config::TrMode::default(),
5318 #[cfg(feature = "lag-routing")]
5319 last_write_at: RwLock::new(None),
5320 #[cfg(feature = "pool-modes")]
5321 pool_client_id: crate::pool::lease::ClientId::default(),
5322 #[cfg(feature = "wasm-plugins")]
5323 plugin_identity: RwLock::new(None),
5324 })
5325 }
5326
5327 #[tokio::test]
5331 async fn test_apply_route_hook_no_plugin_manager_returns_none() {
5332 let config = test_config();
5333 let server = ProxyServer::new(config).unwrap();
5334 let session = make_test_session();
5335
5336 let msg = QueryMessage {
5337 query: "SELECT * FROM users".to_string(),
5338 }
5339 .encode();
5340
5341 let decision = ProxyServer::apply_route_hook(&msg, &server.state, &session);
5342 assert!(matches!(decision, RouteOverride::None));
5343 }
5344
5345 #[tokio::test]
5349 async fn test_apply_pre_query_hook_no_plugin_manager_forwards() {
5350 let config = test_config();
5351 let server = ProxyServer::new(config).unwrap();
5352 let session = make_test_session();
5353
5354 let original = QueryMessage {
5355 query: "SELECT 1".to_string(),
5356 }
5357 .encode();
5358 let original_bytes = original.encode().to_vec();
5359
5360 let (msg_out, action) =
5361 ProxyServer::apply_pre_query_hook(original, &server.state, &session);
5362
5363 assert!(matches!(action, PreQueryAction::Forward));
5364 assert_eq!(msg_out.encode().to_vec(), original_bytes);
5366 }
5367
5368 #[tokio::test]
5372 async fn test_apply_route_hook_skips_non_query_messages() {
5373 let config = test_config();
5374 let server = ProxyServer::new(config).unwrap();
5375 let session = make_test_session();
5376
5377 let sync_msg = Message::empty(MessageType::Sync);
5378 let decision = ProxyServer::apply_route_hook(&sync_msg, &server.state, &session);
5379 assert!(matches!(decision, RouteOverride::None));
5380 }
5381
5382 #[cfg(feature = "wasm-plugins")]
5387 #[test]
5388 fn test_init_plugin_manager_disabled_by_default_returns_none() {
5389 let config = test_config();
5390 assert!(!config.plugins.enabled);
5391 let pm = ProxyServer::init_plugin_manager(&config.plugins);
5392 assert!(pm.is_none());
5393 }
5394
5395 #[cfg(feature = "wasm-plugins")]
5399 #[test]
5400 fn test_init_plugin_manager_missing_dir_logs_warning() {
5401 let mut config = test_config();
5402 config.plugins.enabled = true;
5403 config.plugins.plugin_dir = "/definitely/not/a/real/path".to_string();
5404
5405 let pm = ProxyServer::init_plugin_manager(&config.plugins);
5407 assert!(pm.is_some());
5408 }
5409
5410 #[tokio::test]
5414 async fn test_apply_authenticate_hook_no_plugin_manager_defers() {
5415 let config = test_config();
5416 let server = ProxyServer::new(config).unwrap();
5417 let session = make_test_session();
5418
5419 let mut params = HashMap::new();
5420 params.insert("user".to_string(), "alice".to_string());
5421 params.insert("database".to_string(), "app".to_string());
5422
5423 let result = ProxyServer::apply_authenticate_hook(¶ms, &session, &server.state).await;
5424 assert!(result.is_ok());
5425
5426 #[cfg(feature = "wasm-plugins")]
5428 {
5429 let ident = session.plugin_identity.read().await;
5430 assert!(ident.is_none());
5431 }
5432 }
5433
5434 #[cfg(feature = "wasm-plugins")]
5442 #[test]
5443 fn test_synthesise_cached_response_roundtrip() {
5444 let payload = br#"{
5445 "columns": [
5446 {"name": "id", "oid": 23},
5447 {"name": "email", "oid": 25}
5448 ],
5449 "rows": [
5450 ["1", "alice@example.com"],
5451 ["2", null]
5452 ]
5453 }"#;
5454 let reply = ProxyServer::synthesise_cached_response(payload).expect("synthesis");
5455
5456 let mut tags = Vec::new();
5459 let mut i = 0;
5460 while i < reply.len() {
5461 let tag = reply[i];
5462 let len = u32::from_be_bytes([reply[i + 1], reply[i + 2], reply[i + 3], reply[i + 4]])
5463 as usize;
5464 tags.push(tag);
5465 i += 1 + len;
5466 }
5467 assert_eq!(i, reply.len(), "no trailing bytes");
5468 assert_eq!(tags, vec![b'T', b'D', b'D', b'C', b'Z'], "wire frame order");
5469
5470 assert_eq!(*reply.last().unwrap(), b'I');
5472 }
5473
5474 #[cfg(feature = "wasm-plugins")]
5477 #[test]
5478 fn test_synthesise_cached_response_rejects_row_width_mismatch() {
5479 let payload = br#"{
5480 "columns": [{"name": "id", "oid": 23}, {"name": "name", "oid": 25}],
5481 "rows": [["1", "alice", "extra"]]
5482 }"#;
5483 let result = ProxyServer::synthesise_cached_response(payload);
5484 assert!(matches!(result, Err(ProxyError::Protocol(_))));
5485 }
5486
5487 #[cfg(feature = "wasm-plugins")]
5491 #[test]
5492 fn test_synthesise_cached_response_rejects_empty_columns() {
5493 let payload = br#"{ "columns": [], "rows": [] }"#;
5494 let result = ProxyServer::synthesise_cached_response(payload);
5495 assert!(matches!(result, Err(ProxyError::Protocol(_))));
5496 }
5497
5498 #[cfg(feature = "wasm-plugins")]
5501 #[test]
5502 fn test_synthesise_cached_response_rejects_bad_json() {
5503 let payload = b"not json at all";
5504 let result = ProxyServer::synthesise_cached_response(payload);
5505 assert!(matches!(result, Err(ProxyError::Protocol(_))));
5506 }
5507
5508 #[cfg(feature = "wasm-plugins")]
5517 #[tokio::test]
5518 async fn test_apply_authenticate_hook_with_manager_no_plugins_defers() {
5519 use crate::plugins::{PluginManager, PluginRuntimeConfig};
5520
5521 let config = test_config();
5522 let server = ProxyServer::new(config).unwrap();
5523 let session = make_test_session();
5524
5525 let pm = Arc::new(PluginManager::new(PluginRuntimeConfig::default()).unwrap());
5528 let augmented_state = Arc::new(ServerState {
5529 sessions: RwLock::new(HashMap::new()),
5530 health: ArcSwap::from_pointee(HashMap::new()),
5531 health_write: parking_lot::Mutex::new(()),
5532 live_config: ArcSwap::from_pointee(ProxyConfig::default()),
5533 metrics: ServerMetrics::default(),
5534 cancel_map: Arc::new(DashMap::new()),
5535 cancel_order: Arc::new(parking_lot::Mutex::new(std::collections::VecDeque::new())),
5536 tls_acceptor: None,
5537 auth_file: None,
5538 mirror: None,
5539 cutover: Arc::new(ArcSwap::from_pointee(None)),
5540 lb_state: LoadBalancerState {
5541 rr_counter: AtomicU64::new(0),
5542 },
5543 #[cfg(feature = "routing-hints")]
5544 hint_parser: None,
5545 #[cfg(feature = "rate-limiting")]
5546 rate_limiter: None,
5547 #[cfg(feature = "circuit-breaker")]
5548 circuit_breaker: None,
5549 #[cfg(feature = "query-analytics")]
5550 analytics: None,
5551 #[cfg(feature = "query-cache")]
5552 query_cache: None,
5553 #[cfg(feature = "query-rewriting")]
5554 rewriter: None,
5555 #[cfg(feature = "multi-tenancy")]
5556 tenant_manager: None,
5557 #[cfg(feature = "schema-routing")]
5558 schema_analyzer: None,
5559 #[cfg(feature = "pool-modes")]
5560 pool_manager: None,
5561 #[cfg(feature = "pool-modes")]
5562 backend_pool: None,
5563 plugin_manager: Some(pm),
5564 #[cfg(feature = "ha-tr")]
5565 transaction_journal: Arc::new(crate::transaction_journal::TransactionJournal::new()),
5566 #[cfg(feature = "anomaly-detection")]
5567 anomaly_detector: Arc::new(crate::anomaly::AnomalyDetector::new(
5568 crate::anomaly::AnomalyConfig::default(),
5569 )),
5570 #[cfg(feature = "edge-proxy")]
5571 edge_cache: Arc::new(crate::edge::EdgeCache::new(10_000)),
5572 #[cfg(feature = "edge-proxy")]
5573 edge_registry: Arc::new(crate::edge::EdgeRegistry::new(
5574 32,
5575 std::time::Duration::from_secs(120),
5576 )),
5577 });
5578
5579 let mut params = HashMap::new();
5580 params.insert("user".to_string(), "alice".to_string());
5581
5582 let result =
5583 ProxyServer::apply_authenticate_hook(¶ms, &session, &augmented_state).await;
5584 assert!(result.is_ok());
5585 let ident = session.plugin_identity.read().await;
5586 assert!(ident.is_none());
5587 let _ = server;
5589 }
5590
5591 fn cstr(s: &str) -> Vec<u8> {
5594 let mut v = s.as_bytes().to_vec();
5595 v.push(0);
5596 v
5597 }
5598
5599 #[test]
5600 fn parse_stmt_name_extracts_named_and_unnamed() {
5601 let mut named = cstr("ps1");
5603 named.extend_from_slice(&cstr("SELECT 1"));
5604 named.extend_from_slice(&[0, 0]);
5605 assert_eq!(ProxyServer::parse_stmt_name(&named), "ps1");
5606
5607 let mut unnamed = cstr("");
5608 unnamed.extend_from_slice(&cstr("SELECT 1"));
5609 unnamed.extend_from_slice(&[0, 0]);
5610 assert_eq!(ProxyServer::parse_stmt_name(&unnamed), "");
5611 }
5612
5613 #[test]
5614 fn bind_stmt_ref_reads_second_cstring() {
5615 let mut named = cstr("portal_a");
5617 named.extend_from_slice(&cstr("ps1"));
5618 named.extend_from_slice(&[0, 0]); assert_eq!(ProxyServer::bind_stmt_ref(&named), Some("ps1"));
5620
5621 let mut unnamed = cstr("");
5623 unnamed.extend_from_slice(&cstr(""));
5624 assert_eq!(ProxyServer::bind_stmt_ref(&unnamed), None);
5625 }
5626
5627 #[test]
5628 fn stmt_kind_name_only_matches_statement_kind() {
5629 let mut stmt = vec![b'S'];
5631 stmt.extend_from_slice(&cstr("ps1"));
5632 assert_eq!(ProxyServer::stmt_kind_name(&stmt), Some("ps1"));
5633
5634 let mut portal = vec![b'P'];
5636 portal.extend_from_slice(&cstr("portal_a"));
5637 assert_eq!(ProxyServer::stmt_kind_name(&portal), None);
5638
5639 let mut empty = vec![b'S'];
5641 empty.extend_from_slice(&cstr(""));
5642 assert_eq!(ProxyServer::stmt_kind_name(&empty), None);
5643 }
5644
5645 #[tokio::test]
5646 async fn read_one_frame_type_consumes_full_frame() {
5647 let (mut a, mut b) = tokio::io::duplex(64);
5650 let bytes = [b'1', 0, 0, 0, 4, b'Z', 0, 0, 0, 5, b'I'];
5652 b.write_all(&bytes).await.unwrap();
5653 let t = ProxyServer::read_one_frame_type(&mut a).await.unwrap();
5654 assert_eq!(t, b'1');
5655 let t2 = ProxyServer::read_one_frame_type(&mut a).await.unwrap();
5657 assert_eq!(t2, b'Z');
5658 }
5659
5660 #[tokio::test]
5661 async fn reprepare_statement_accepts_parse_complete_and_rejects_error() {
5662 let (mut client, mut backend) = tokio::io::duplex(64);
5664 backend.write_all(&[b'1', 0, 0, 0, 4]).await.unwrap();
5665 let parse = {
5666 let mut p = vec![b'P', 0, 0, 0, 0];
5667 p.extend_from_slice(&cstr("ps1"));
5668 p.extend_from_slice(&cstr("SELECT 1"));
5669 p.extend_from_slice(&[0, 0]);
5670 p
5671 };
5672 assert!(ProxyServer::reprepare_statement(&mut client, &parse)
5673 .await
5674 .is_ok());
5675
5676 let (mut client2, mut backend2) = tokio::io::duplex(64);
5678 backend2.write_all(&[b'E', 0, 0, 0, 4]).await.unwrap();
5679 assert!(ProxyServer::reprepare_statement(&mut client2, &parse)
5680 .await
5681 .is_err());
5682 }
5683
5684 #[cfg(feature = "routing-hints")]
5687 mod routing_hints {
5688 use super::*;
5689 use crate::routing::HintParser;
5690
5691 fn over(sql: &str) -> RouteOverride {
5692 let hints = HintParser::new().parse(sql);
5693 ProxyServer::hint_to_override(&hints)
5694 }
5695
5696 #[test]
5697 fn route_primary_maps_to_primary() {
5698 assert!(matches!(
5699 over("/*helios:route=primary*/ SELECT 1"),
5700 RouteOverride::Primary
5701 ));
5702 }
5703
5704 #[test]
5705 fn read_tier_targets_map_to_standby() {
5706 for t in ["standby", "sync", "semisync", "async", "local"] {
5707 assert!(
5708 matches!(
5709 over(&format!("/*helios:route={t}*/ SELECT 1")),
5710 RouteOverride::Standby
5711 ),
5712 "route={t} should map to Standby"
5713 );
5714 }
5715 }
5716
5717 #[test]
5718 fn any_and_vector_impose_no_constraint() {
5719 assert!(matches!(
5720 over("/*helios:route=any*/ SELECT 1"),
5721 RouteOverride::None
5722 ));
5723 assert!(matches!(
5724 over("/*helios:route=vector*/ SELECT 1"),
5725 RouteOverride::None
5726 ));
5727 }
5728
5729 #[test]
5730 fn node_hint_maps_to_node_and_wins_over_route() {
5731 match over("/*helios:node=pg-standby,route=primary*/ SELECT 1") {
5733 RouteOverride::Node(n) => assert_eq!(n, "pg-standby"),
5734 other => panic!("expected Node, got {other:?}"),
5735 }
5736 }
5737
5738 #[test]
5739 fn consistency_strong_forces_primary() {
5740 assert!(matches!(
5741 over("/*helios:consistency=strong*/ SELECT 1"),
5742 RouteOverride::Primary
5743 ));
5744 }
5745
5746 #[test]
5747 fn no_hint_yields_none() {
5748 assert!(matches!(over("SELECT 1"), RouteOverride::None));
5749 }
5750
5751 #[test]
5755 fn write_verb_classified_after_strip() {
5756 let parser = HintParser::new();
5757 let raw = "/*helios:route=primary*/ INSERT INTO t VALUES (1)";
5758 assert!(!ProxyServer::is_write_query(raw));
5761 assert!(ProxyServer::is_write_query(&parser.strip(raw)));
5763 }
5764
5765 #[test]
5766 fn strip_removes_hint_comment() {
5767 let parser = HintParser::new();
5768 assert_eq!(
5769 parser.strip("/*helios:route=standby*/ SELECT 42"),
5770 "SELECT 42"
5771 );
5772 }
5773 }
5774
5775 #[cfg(feature = "rate-limiting")]
5778 mod rate_limiting {
5779 use crate::rate_limit::{LimiterKey, RateLimitConfig, RateLimitResult, RateLimiter};
5780
5781 #[test]
5782 fn burst_allows_then_denies() {
5783 let cfg = RateLimitConfig {
5786 enabled: true,
5787 default_qps: 1,
5788 default_burst: 2,
5789 ..Default::default()
5790 };
5791 let limiter = RateLimiter::new(cfg);
5792 let key = LimiterKey::User("u".to_string());
5793
5794 assert!(matches!(limiter.check(&key, 1), RateLimitResult::Allowed));
5796 assert!(matches!(limiter.check(&key, 1), RateLimitResult::Allowed));
5797
5798 let mut denied = false;
5800 for _ in 0..5 {
5801 if matches!(limiter.check(&key, 1), RateLimitResult::Denied(_)) {
5802 denied = true;
5803 }
5804 }
5805 assert!(denied, "over-burst checks must yield a Denied verdict");
5806 }
5807
5808 #[test]
5809 fn distinct_keys_have_independent_buckets() {
5810 let cfg = RateLimitConfig {
5811 enabled: true,
5812 default_qps: 1,
5813 default_burst: 1,
5814 ..Default::default()
5815 };
5816 let limiter = RateLimiter::new(cfg);
5817 assert!(matches!(
5819 limiter.check(&LimiterKey::User("a".to_string()), 1),
5820 RateLimitResult::Allowed
5821 ));
5822 assert!(matches!(
5823 limiter.check(&LimiterKey::User("b".to_string()), 1),
5824 RateLimitResult::Allowed
5825 ));
5826 }
5827 }
5828
5829 #[cfg(feature = "circuit-breaker")]
5832 mod circuit_breaker {
5833 use crate::circuit_breaker::{
5834 CircuitBreakerConfig, CircuitBreakerManager, CircuitState, ManagerConfig,
5835 };
5836 use std::time::Duration;
5837
5838 fn mgr(threshold: u32) -> CircuitBreakerManager {
5839 let cfg = CircuitBreakerConfig {
5840 failure_threshold: threshold,
5841 cooldown: Duration::from_secs(10),
5842 ..Default::default()
5843 };
5844 CircuitBreakerManager::new(ManagerConfig::new(cfg))
5845 }
5846
5847 #[test]
5848 fn opens_after_threshold_failures() {
5849 let m = mgr(3);
5850 let b = m.get_breaker("n1");
5851 assert_eq!(b.get_state(), CircuitState::Closed);
5852 b.record_failure("boom");
5853 b.record_failure("boom");
5854 assert_eq!(b.get_state(), CircuitState::Closed);
5856 b.record_failure("boom");
5858 assert_eq!(b.get_state(), CircuitState::Open);
5859 }
5860
5861 #[test]
5862 fn healthy_node_stays_closed() {
5863 let m = mgr(3);
5864 let b = m.get_breaker("n2");
5865 b.record_success();
5866 b.record_success();
5867 assert_eq!(b.get_state(), CircuitState::Closed);
5868 }
5869 }
5870
5871 #[cfg(feature = "query-analytics")]
5874 mod query_analytics {
5875 use crate::analytics::{AnalyticsConfig, OrderBy, QueryAnalytics, QueryExecution};
5876 use std::time::Duration;
5877
5878 #[test]
5879 fn records_and_collapses_literals() {
5880 let a = QueryAnalytics::new(AnalyticsConfig::default());
5881 for n in [1, 2, 3] {
5882 a.record(QueryExecution::new(
5883 format!("select {n}"),
5884 Duration::from_millis(1),
5885 ));
5886 }
5887 let top = a.top_queries(OrderBy::Calls, 10);
5888 assert!(!top.is_empty(), "no fingerprints recorded");
5889 assert!(
5891 top.iter().any(|s| s.calls >= 3),
5892 "literals did not collapse: {:?}",
5893 top.iter()
5894 .map(|s| (s.normalized.clone(), s.calls))
5895 .collect::<Vec<_>>()
5896 );
5897 }
5898 }
5899
5900 #[cfg(feature = "lag-routing")]
5903 mod lag_routing {
5904 use super::ProxyServer;
5905
5906 #[test]
5907 fn ryw_pins_recent_write() {
5908 assert!(ProxyServer::ryw_pins_primary(
5910 Some(std::time::Instant::now()),
5911 1000
5912 ));
5913 }
5914
5915 #[test]
5916 fn ryw_releases_old_write() {
5917 let old = std::time::Instant::now()
5918 .checked_sub(std::time::Duration::from_secs(10))
5919 .unwrap();
5920 assert!(!ProxyServer::ryw_pins_primary(Some(old), 1000));
5921 }
5922
5923 #[test]
5924 fn ryw_no_write_or_disabled() {
5925 assert!(!ProxyServer::ryw_pins_primary(None, 1000));
5926 assert!(!ProxyServer::ryw_pins_primary(
5928 Some(std::time::Instant::now()),
5929 0
5930 ));
5931 }
5932
5933 #[test]
5934 fn lag_exclusion_thresholds() {
5935 assert!(!ProxyServer::lag_excludes_standby(Some(999_999), 0));
5937 assert!(!ProxyServer::lag_excludes_standby(None, 1000));
5939 assert!(!ProxyServer::lag_excludes_standby(Some(500), 1000));
5941 assert!(ProxyServer::lag_excludes_standby(Some(2000), 1000));
5943 }
5944 }
5945
5946 #[cfg(feature = "query-cache")]
5949 mod query_cache {
5950 use super::ProxyServer;
5951
5952 #[test]
5953 fn plain_selects_are_cacheable() {
5954 assert!(ProxyServer::is_cacheable_read_sql("select v from t"));
5955 assert!(ProxyServer::is_cacheable_read_sql(
5956 " SELECT a, b FROM users WHERE id = 5"
5957 ));
5958 }
5959
5960 #[test]
5961 fn writes_and_non_selects_are_not_cacheable() {
5962 assert!(!ProxyServer::is_cacheable_read_sql(
5963 "insert into t values (1)"
5964 ));
5965 assert!(!ProxyServer::is_cacheable_read_sql("update t set v = 1"));
5966 assert!(!ProxyServer::is_cacheable_read_sql("show search_path"));
5967 }
5968
5969 #[test]
5970 fn locking_and_volatile_selects_are_not_cacheable() {
5971 assert!(!ProxyServer::is_cacheable_read_sql(
5972 "select * from t for update"
5973 ));
5974 assert!(!ProxyServer::is_cacheable_read_sql("select now()"));
5975 assert!(!ProxyServer::is_cacheable_read_sql("select random()"));
5976 assert!(!ProxyServer::is_cacheable_read_sql("select nextval('s')"));
5977 }
5978 }
5979
5980 #[cfg(feature = "query-rewriting")]
5983 mod query_rewriting {
5984 use crate::rewriter::{
5985 QueryPattern, QueryRewriter, RewriteRule, RewriterConfig, Transformation,
5986 };
5987
5988 fn rw_with_table_replace() -> QueryRewriter {
5989 let rw = QueryRewriter::new(RewriterConfig {
5990 enabled: true,
5991 ..Default::default()
5992 });
5993 rw.add_rule(
5994 RewriteRule::build("t")
5995 .pattern(QueryPattern::Table("a".to_string()))
5996 .transform(Transformation::ReplaceTable {
5997 from: "a".to_string(),
5998 to: "b".to_string(),
5999 })
6000 .build(),
6001 );
6002 rw
6003 }
6004
6005 #[test]
6006 fn matching_query_is_rewritten() {
6007 let res = rw_with_table_replace().rewrite("select * from a").unwrap();
6008 assert!(res.was_rewritten(), "rule did not fire");
6009 assert!(res.query().contains('b'), "rewritten: {}", res.query());
6010 assert!(
6011 !res.query().contains("from a"),
6012 "still references a: {}",
6013 res.query()
6014 );
6015 }
6016
6017 #[test]
6018 fn unmatched_query_is_unchanged() {
6019 let res = rw_with_table_replace()
6020 .rewrite("select * from other")
6021 .unwrap();
6022 assert!(!res.was_rewritten());
6023 assert_eq!(res.query(), "select * from other");
6024 }
6025 }
6026
6027 #[cfg(feature = "multi-tenancy")]
6030 mod multi_tenancy {
6031 use crate::multi_tenancy::{
6032 IdentificationMethod, IsolationStrategy, MultiTenancyConfig, TenantConfig, TenantId,
6033 TenantManager, TenantManagerBuilder, TenantQueryTransformer,
6034 };
6035
6036 fn manager() -> TenantManager {
6037 let transformer = TenantQueryTransformer::new().register_tables(&["t"], "tid");
6038 let tm = TenantManagerBuilder::new()
6039 .config(MultiTenancyConfig {
6040 enabled: true,
6041 identification: IdentificationMethod::Header {
6042 header_name: "application_name".to_string(),
6043 },
6044 ..Default::default()
6045 })
6046 .query_transformer(transformer)
6047 .build();
6048 tm.register_tenant(TenantConfig::new(
6049 TenantId::new("acme"),
6050 IsolationStrategy::row("public", "tid"),
6051 ));
6052 tm
6053 }
6054
6055 #[test]
6056 fn tenant_table_gets_filter() {
6057 let res = manager().transform_query("select * from t", &TenantId::new("acme"));
6058 assert!(res.transformed, "expected a tenant filter to be injected");
6059 let q = res.query.to_lowercase();
6060 assert!(
6061 q.contains("tid") && q.contains("acme"),
6062 "filter missing: {}",
6063 res.query
6064 );
6065 }
6066
6067 #[test]
6068 fn non_tenant_table_passes_through() {
6069 let res = manager().transform_query("select * from other", &TenantId::new("acme"));
6070 assert!(!res.transformed);
6071 }
6072 }
6073
6074 #[cfg(feature = "ha-tr")]
6077 mod ha_tr {
6078 use crate::transaction_journal::TransactionJournal;
6079 use crate::NodeId;
6080
6081 #[tokio::test]
6082 async fn journal_records_and_windows_a_statement() {
6083 let j = TransactionJournal::new();
6084 let from = chrono::Utc::now() - chrono::Duration::seconds(60);
6085 let tx = uuid::Uuid::new_v4();
6086 j.begin_transaction(tx, uuid::Uuid::new_v4(), NodeId::new(), 0)
6087 .await
6088 .unwrap();
6089 j.log_statement(
6090 tx,
6091 "insert into t values (1)".to_string(),
6092 Vec::new(),
6093 None,
6094 None,
6095 0,
6096 )
6097 .await
6098 .unwrap();
6099 let to = chrono::Utc::now() + chrono::Duration::seconds(60);
6100 let entries = j.entries_in_window(from, to).await;
6101 assert_eq!(entries.len(), 1, "journaled statement should be in window");
6102 assert!(entries[0].1.statement.contains("insert"));
6103 }
6104 }
6105
6106 #[cfg(feature = "schema-routing")]
6109 mod schema_routing {
6110 use crate::schema_routing::{QueryAnalyzer, SchemaRegistry};
6111 use std::sync::Arc;
6112
6113 fn analyzer() -> QueryAnalyzer {
6114 QueryAnalyzer::new(Arc::new(SchemaRegistry::new()))
6115 }
6116
6117 #[test]
6118 fn aggregation_group_by_is_analytics() {
6119 let a = analyzer();
6120 assert!(a
6121 .analyze("select count(*) from orders group by region")
6122 .is_analytics());
6123 }
6124
6125 #[test]
6126 fn simple_point_query_is_not_analytics() {
6127 let a = analyzer();
6128 assert!(!a
6129 .analyze("select * from orders where id = 1")
6130 .is_analytics());
6131 }
6132 }
6133}