1use crate::admin::{AdminServer, AdminState, ConfigSnapshot, NodeSnapshot};
7#[cfg(feature = "ha-tr")]
8use crate::backend::{tls::default_client_config, BackendConfig, TlsMode};
9use crate::client_tls::{build_tls_acceptor, ClientStream};
10use crate::config::{HbaAction, HbaRule, NodeConfig, NodeRole, ProxyConfig, TrMode};
11#[cfg(feature = "wasm-plugins")]
12use crate::protocol::QueryMessage;
13use crate::protocol::{
14 ErrorResponse, Message, MessageType, ProtocolCodec, StartupMessage, TransactionStatus,
15};
16use crate::{ProxyError, Result};
17use arc_swap::ArcSwap;
18use bytes::{BufMut, BytesMut};
19use dashmap::DashMap;
20use std::collections::{HashMap, HashSet};
21use std::net::SocketAddr;
22use std::sync::atomic::{AtomicU64, Ordering};
23use std::sync::Arc;
24use std::time::Duration;
25use tokio::io::{AsyncReadExt, AsyncWriteExt};
26use tokio::net::{TcpListener, TcpStream};
27use tokio::sync::{broadcast, RwLock};
28use uuid::Uuid;
29
30#[cfg(feature = "pool-modes")]
32use crate::pool::lease::ClientId;
33#[cfg(feature = "pool-modes")]
34use crate::pool::{ConnectionPoolManager, PoolModeConfig, PoolingMode};
35#[cfg(feature = "pool-modes")]
36use crate::NodeEndpoint;
37
38#[cfg(feature = "wasm-plugins")]
40use crate::plugins::{
41 AuthRequest as PluginAuthRequest, AuthResult, HookContext, HookType, Identity, PluginManager,
42 PostQueryOutcome, PreQueryResult, QueryContext, RouteResult,
43};
44
/// Top-level proxy server: owns the startup configuration, the runtime state
/// shared with spawned tasks, and the shutdown broadcast channel.
pub struct ProxyServer {
    /// Configuration passed to `ProxyServer::new`. `run` binds the client
    /// listener from this copy; per-connection config is instead read from
    /// `ServerState::live_config`, which SIGHUP can replace.
    config: ProxyConfig,
    /// State shared (via `Arc` clones) with connection handlers and background tasks.
    state: Arc<ServerState>,
    /// Broadcast sender; a message on it ends the accept loop in `run`.
    shutdown_tx: broadcast::Sender<()>,
    /// File the configuration was loaded from; `None` makes SIGHUP reload a no-op.
    config_path: Option<String>,
}
55
#[cfg(not(unix))]
/// Placeholder signal stream for platforms without Unix signals: it exposes
/// the same `recv` shape as `tokio::signal::unix::Signal`, but never fires.
struct HangupNever;

#[cfg(not(unix))]
impl HangupNever {
    /// Pends forever, so a `select!` arm waiting on it is never taken.
    async fn recv(&mut self) -> Option<()> {
        let never = std::future::pending::<Option<()>>();
        never.await
    }
}
66
/// Backend connection template handed to the replay engine.
///
/// Host and port are placeholders (`"placeholder"`, `0`); the config argument
/// is currently unused. Connects as `postgres` without TLS, with a 5 s connect
/// timeout and a 30 s query timeout.
#[cfg(feature = "ha-tr")]
fn build_replay_backend_template(_config: &ProxyConfig) -> BackendConfig {
    let connect_timeout = Duration::from_secs(5);
    let query_timeout = Duration::from_secs(30);
    BackendConfig {
        tls_mode: TlsMode::Disable,
        tls_config: default_client_config(),
        host: String::from("placeholder"),
        port: 0,
        user: String::from("postgres"),
        password: None,
        database: None,
        application_name: Some(String::from("heliosdb-proxy-replay")),
        connect_timeout,
        query_timeout,
    }
}
95
96#[cfg(feature = "anomaly-detection")]
/// Normalise a SQL statement into a coarse, literal-free fingerprint used to
/// group queries for anomaly detection.
///
/// Rules, applied in a single left-to-right pass:
/// - A single-quoted string literal becomes `?`.
///   - A doubled quote (`''`) inside it is the SQL escape for a quote and does
///     not end the literal.
///   - An unterminated literal swallows the rest of the input.
///   - Backslash escapes in `E'...'` strings are not recognised.
/// - A numeric literal becomes `?`, unless the output already ends in `?`.
///   - The literal extends over following ASCII letters, digits, `.` and `_`,
///     so `3.14`, `1e5`, `0x1F` and `1_000` each collapse to one `?`.
/// - A digit that continues an identifier (`col1`, `t_2024`) is kept verbatim,
///   so distinct tables/columns keep distinct fingerprints. Positional
///   parameters such as `$1` still collapse to `$?`.
/// - Runs of ASCII whitespace collapse to one space; leading and trailing
///   whitespace is dropped.
/// - Every other character is copied ASCII-lowercased.
fn anomaly_fingerprint(sql: &str) -> String {
    let mut out = String::with_capacity(sql.len());
    // The last thing emitted was a collapsed whitespace run.
    let mut prev_space = false;
    // The last character emitted continues an identifier/keyword, so a digit
    // that follows belongs to that word rather than starting a number.
    let mut in_word = false;
    let mut chars = sql.chars().peekable();
    while let Some(c) = chars.next() {
        if c == '\'' {
            out.push('?');
            while let Some(n) = chars.next() {
                if n == '\'' {
                    if chars.peek() == Some(&'\'') {
                        // `''` is an escaped quote; the literal continues.
                        chars.next();
                    } else {
                        break;
                    }
                }
            }
            prev_space = false;
            in_word = false;
            continue;
        }
        if c.is_ascii_digit() && !in_word {
            if !out.ends_with('?') {
                out.push('?');
            }
            while matches!(
                chars.peek(),
                Some(d) if d.is_ascii_alphanumeric() || *d == '.' || *d == '_'
            ) {
                chars.next();
            }
            prev_space = false;
            continue;
        }
        if c.is_ascii_whitespace() {
            if !prev_space && !out.is_empty() {
                out.push(' ');
                prev_space = true;
            }
            in_word = false;
            continue;
        }
        out.push(c.to_ascii_lowercase());
        prev_space = false;
        in_word = c.is_alphanumeric() || c == '_';
    }
    out.trim_end().to_string()
}
152
/// Runtime state shared (behind an `Arc`) by the accept loop, connection
/// handlers and background tasks.
struct ServerState {
    /// Live client sessions keyed by session id. `drain_connections` waits for
    /// this map to empty, and the admin sync loop reports its length.
    sessions: RwLock<HashMap<Uuid, Arc<ClientSession>>>,
    /// Per-node health keyed by `node.address()`. It is replaced wholesale with
    /// `store`, never mutated in place.
    health: ArcSwap<HashMap<String, NodeHealth>>,
    /// Serialises health writers: `reconcile_health` holds it across its
    /// load -> rebuild -> store sequence so concurrent stores are not lost.
    /// NOTE(review): presumably the health checker takes it too — confirm.
    health_write: parking_lot::Mutex<()>,
    /// Current configuration. `reload_config` swaps it on SIGHUP, and each
    /// accepted connection receives a clone of it.
    live_config: ArcSwap<ProxyConfig>,
    /// Counters copied into the admin server's metrics snapshot.
    metrics: ServerMetrics,
    /// NOTE(review): presumably maps backend (pid, secret key) pairs to a
    /// node address for routing CancelRequests — confirm.
    cancel_map: Arc<DashMap<(u32, u32), String>>,
    /// NOTE(review): presumably records `cancel_map` key insertion order for
    /// bounded eviction — confirm.
    cancel_order: Arc<parking_lot::Mutex<std::collections::VecDeque<(u32, u32)>>>,
    /// Client TLS acceptor; `Some` when `config.tls` is present and enabled.
    tls_acceptor: Option<tokio_rustls::TlsAcceptor>,
    /// SCRAM credential file; `Some` when `auth.mode` is SCRAM.
    auth_file: Option<Arc<crate::auth_scram::AuthFile>>,
    /// Traffic-mirroring handle; `Some` when `mirror.enabled`.
    mirror: Option<crate::mirror::MirrorHandle>,
    /// Cutover target; starts as `None` and is shared with the admin server's
    /// migration info.
    cutover: Arc<ArcSwap<Option<Arc<crate::mirror::CutoverTarget>>>>,
    /// Load-balancer counters.
    lb_state: LoadBalancerState,
    /// SQL-comment routing hint parser; `Some` when `routing_hints.enabled`.
    #[cfg(feature = "routing-hints")]
    hint_parser: Option<crate::routing::HintParser>,
    /// `Some` when `rate_limit.enabled`.
    #[cfg(feature = "rate-limiting")]
    rate_limiter: Option<Arc<crate::rate_limit::RateLimiter>>,
    /// `Some` when `circuit_breaker.enabled`.
    #[cfg(feature = "circuit-breaker")]
    circuit_breaker: Option<Arc<crate::circuit_breaker::CircuitBreakerManager>>,
    /// `Some` when `analytics.enabled`.
    #[cfg(feature = "query-analytics")]
    analytics: Option<Arc<crate::analytics::QueryAnalytics>>,
    /// `Some` when `cache.enabled`.
    #[cfg(feature = "query-cache")]
    query_cache: Option<Arc<crate::cache::QueryCache>>,
    /// `Some` when query rewriting is enabled and at least one rule is configured.
    #[cfg(feature = "query-rewriting")]
    rewriter: Option<Arc<crate::rewriter::QueryRewriter>>,
    /// `Some` when multi-tenancy is enabled and at least one tenant is configured.
    #[cfg(feature = "multi-tenancy")]
    tenant_manager: Option<Arc<crate::multi_tenancy::TenantManager>>,
    /// `Some` when schema routing is enabled and an analytics node is named.
    #[cfg(feature = "schema-routing")]
    schema_analyzer: Option<Arc<crate::schema_routing::QueryAnalyzer>>,
    /// Pool-mode manager; always `Some` when the feature is compiled in.
    #[cfg(feature = "pool-modes")]
    pool_manager: Option<Arc<ConnectionPoolManager>>,
    /// Idle backend pool; `Some` only in transaction or statement pooling mode.
    #[cfg(feature = "pool-modes")]
    backend_pool: Option<Arc<crate::pool::BackendIdlePool>>,
    /// WASM plugin manager; `Some` when plugins are enabled and the manager
    /// was created successfully.
    #[cfg(feature = "wasm-plugins")]
    plugin_manager: Option<Arc<PluginManager>>,
    /// Transaction journal, also handed to the admin replay engine.
    #[cfg(feature = "ha-tr")]
    transaction_journal: Arc<crate::transaction_journal::TransactionJournal>,
    /// Anomaly detector built with `AnomalyConfig::default()`.
    #[cfg(feature = "anomaly-detection")]
    anomaly_detector: Arc<crate::anomaly::AnomalyDetector>,
    /// Edge cache, constructed with `EdgeCache::new(10_000)`.
    #[cfg(feature = "edge-proxy")]
    edge_cache: Arc<crate::edge::EdgeCache>,
    /// Edge registry, constructed with `EdgeRegistry::new(32, 120 s)`.
    #[cfg(feature = "edge-proxy")]
    edge_registry: Arc<crate::edge::EdgeRegistry>,
}
270
/// Health snapshot for one backend node, as held in `ServerState::health`.
/// New entries are seeded healthy, with zero failures and no error.
#[derive(Debug, Clone)]
pub struct NodeHealth {
    /// Node address as produced by `node.address()`; also the map key.
    pub address: String,
    /// Whether the node is currently considered usable (seeded `true`).
    pub healthy: bool,
    /// Time of the last check, in UTC (seeded with the creation time).
    pub last_check: chrono::DateTime<chrono::Utc>,
    /// NOTE(review): presumably consecutive failed checks — confirm (seeded 0).
    pub failure_count: u32,
    /// Most recent check error, if any.
    pub last_error: Option<String>,
    /// Last measured latency; milliseconds per the field name (seeded 0.0).
    pub latency_ms: f64,
    /// Replication lag in bytes, when known (seeded `None`).
    pub replication_lag_bytes: Option<u64>,
}
289
/// Process-wide counters. The admin sync loop reads them with `Relaxed`
/// loads, so they are statistics only, not synchronisation points.
#[derive(Default)]
struct ServerMetrics {
    /// Client connections accepted by the listener (incremented in `run`).
    connections_accepted: AtomicU64,
    /// Client connections closed (presumably bumped by the connection handler).
    connections_closed: AtomicU64,
    /// Queries processed (presumably bumped by the connection handler).
    queries_processed: AtomicU64,
    /// Bytes received; NOTE(review): confirm whether this is client- or backend-side.
    bytes_received: AtomicU64,
    /// Bytes sent; NOTE(review): confirm whether this is client- or backend-side.
    bytes_sent: AtomicU64,
    /// Failover events observed.
    failovers: AtomicU64,
}
306
/// Mutable load-balancer state shared across connections.
struct LoadBalancerState {
    /// Counter starting at 0; presumably advanced per pick for round-robin
    /// node selection — confirm.
    rr_counter: AtomicU64,
}
313
/// Per-client session state. NOTE(review): presumably one per accepted
/// connection, registered in `ServerState::sessions` — confirm in `handle_client`.
pub struct ClientSession {
    /// Session identifier.
    pub id: Uuid,
    /// Remote address of the client socket.
    pub client_addr: SocketAddr,
    /// Backend node the session is currently using, if any.
    pub current_node: RwLock<Option<String>>,
    /// Transaction bookkeeping for this session.
    pub tx_state: RwLock<TransactionState>,
    /// Session variables by name (presumably from `SET`/startup params — confirm).
    pub variables: RwLock<HashMap<String, String>>,
    /// Session creation time (UTC).
    pub created_at: chrono::DateTime<chrono::Utc>,
    /// TR mode (see `config::TrMode`) in effect for this session.
    pub tr_mode: TrMode,
    /// Time of this session's last write. NOTE(review): presumably used for
    /// read-your-writes routing under lag-based routing — confirm.
    #[cfg(feature = "lag-routing")]
    pub last_write_at: RwLock<Option<std::time::Instant>>,
    /// Client identity used with the pool-modes connection pool.
    #[cfg(feature = "pool-modes")]
    pub pool_client_id: ClientId,
    /// Identity established by a WASM auth plugin, if any.
    #[cfg(feature = "wasm-plugins")]
    pub plugin_identity: RwLock<Option<Identity>>,
}
346
/// Transaction bookkeeping for a client session (default: no open transaction).
#[derive(Debug, Clone, Default)]
pub struct TransactionState {
    /// Whether a transaction block is currently open.
    pub in_transaction: bool,
    /// Identifier of the current transaction, if one has been assigned.
    pub tx_id: Option<Uuid>,
    /// Statements logged for the current transaction (presumably for replay — confirm).
    pub statements: Vec<StatementLog>,
    /// Whether the current transaction is read-only.
    pub read_only: bool,
    /// Names of active savepoints.
    pub savepoints: Vec<String>,
}
361
/// One statement recorded in a transaction's log.
#[derive(Debug, Clone)]
pub struct StatementLog {
    /// Statement text.
    pub sql: String,
    /// Bound parameters, in textual form.
    pub params: Vec<String>,
    /// NOTE(review): presumably a checksum of the statement's result, used to
    /// verify replays — confirm.
    pub result_checksum: Option<u64>,
    /// Execution time (UTC).
    pub executed_at: chrono::DateTime<chrono::Utc>,
}
374
/// A backend TCP connection plus per-connection protocol bookkeeping.
struct BackendConn {
    /// Plain TCP stream to the backend.
    stream: TcpStream,
    /// NOTE(review): presumably names of prepared statements already parsed on
    /// this backend connection — confirm.
    prepared: HashSet<String>,
    /// NOTE(review): presumably a signature of the last unnamed-statement Parse,
    /// used to skip redundant re-parses — confirm.
    unnamed_sig: Option<bytes::Bytes>,
}
397
398impl BackendConn {
399 fn new(stream: TcpStream) -> Self {
400 Self {
401 stream,
402 prepared: HashSet::new(),
403 unnamed_sig: None,
404 }
405 }
406}
407
408pub(crate) fn bind_reuseport(addr: &str) -> Result<TcpListener> {
414 use socket2::{Domain, Protocol, Socket, Type};
415 let sockaddr: SocketAddr = addr
416 .parse()
417 .map_err(|e| ProxyError::Config(format!("invalid listen address '{}': {}", addr, e)))?;
418 let domain = if sockaddr.is_ipv6() {
419 Domain::IPV6
420 } else {
421 Domain::IPV4
422 };
423 let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))
424 .map_err(|e| ProxyError::Network(format!("socket(): {}", e)))?;
425 socket
426 .set_reuse_address(true)
427 .map_err(|e| ProxyError::Network(format!("SO_REUSEADDR: {}", e)))?;
428 #[cfg(all(unix, not(target_os = "solaris")))]
429 socket
430 .set_reuse_port(true)
431 .map_err(|e| ProxyError::Network(format!("SO_REUSEPORT: {}", e)))?;
432 socket
433 .set_nonblocking(true)
434 .map_err(|e| ProxyError::Network(format!("set_nonblocking: {}", e)))?;
435 socket
436 .bind(&sockaddr.into())
437 .map_err(|e| ProxyError::Network(format!("Failed to bind {}: {}", addr, e)))?;
438 socket
439 .listen(1024)
440 .map_err(|e| ProxyError::Network(format!("listen(): {}", e)))?;
441 let std_listener: std::net::TcpListener = socket.into();
442 TcpListener::from_std(std_listener)
443 .map_err(|e| ProxyError::Network(format!("from_std listener: {}", e)))
444}
445
/// Decision for a client query before it is forwarded.
/// NOTE(review): semantics inferred from variant names — confirm.
#[derive(Debug)]
// NOTE(review): dead_code is presumably allowed because some variants are only
// constructed under optional cargo features — confirm, and narrow if possible.
#[allow(dead_code)]
enum PreQueryAction {
    /// Send the query to the backend.
    Forward,
    /// Reject the query; presumably the string is the error for the client.
    Block(String),
    /// Answer without the backend; presumably pre-encoded response bytes.
    Cached(Vec<u8>),
}
466
/// Routing override for a single query.
/// NOTE(review): semantics inferred from variant names — confirm.
#[derive(Debug)]
// NOTE(review): dead_code is presumably allowed because some variants are only
// constructed under optional cargo features — confirm, and narrow if possible.
#[allow(dead_code)]
enum RouteOverride {
    /// No override; use normal routing.
    None,
    /// Force the primary node.
    Primary,
    /// Force a standby node.
    Standby,
    /// Force a specific node, identified by the string.
    Node(String),
    /// Refuse to route the query; presumably the string is the reason.
    Block(String),
}
491
492impl ProxyServer {
493 #[cfg(feature = "wasm-plugins")]
501 fn init_plugin_manager(
502 toml_cfg: &crate::config::PluginToml,
503 ) -> Option<Arc<crate::plugins::PluginManager>> {
504 if !toml_cfg.enabled {
505 return None;
506 }
507
508 let runtime_cfg = crate::plugins::PluginRuntimeConfig::from(toml_cfg);
509 let plugin_dir = runtime_cfg.plugin_dir.clone();
510
511 let pm = match crate::plugins::PluginManager::new(runtime_cfg) {
512 Ok(pm) => Arc::new(pm),
513 Err(e) => {
514 tracing::error!(error = %e, "Failed to create plugin manager; plugins disabled");
515 return None;
516 }
517 };
518
519 match std::fs::read_dir(&plugin_dir) {
520 Ok(entries) => {
521 let mut loaded = 0usize;
522 let mut failed = 0usize;
523 for entry in entries.flatten() {
524 let path = entry.path();
525 if path.extension().and_then(|s| s.to_str()) != Some("wasm") {
526 continue;
527 }
528 match pm.load_plugin(&path) {
529 Ok(()) => loaded += 1,
530 Err(e) => {
531 failed += 1;
532 tracing::warn!(
533 path = %path.display(),
534 error = %e,
535 "Failed to load plugin"
536 );
537 }
538 }
539 }
540 tracing::info!(
541 dir = %plugin_dir.display(),
542 loaded = loaded,
543 failed = failed,
544 "Plugin loading complete"
545 );
546 }
547 Err(e) => {
548 tracing::warn!(
549 dir = %plugin_dir.display(),
550 error = %e,
551 "Plugin directory not readable; no plugins loaded"
552 );
553 }
554 }
555
556 Some(pm)
557 }
558
559 pub fn new(config: ProxyConfig) -> Result<Self> {
561 let (shutdown_tx, _) = broadcast::channel(1);
562
563 let mut health = HashMap::new();
565 for node in &config.nodes {
566 health.insert(
567 node.address(),
568 NodeHealth {
569 address: node.address(),
570 healthy: true, last_check: chrono::Utc::now(),
572 failure_count: 0,
573 last_error: None,
574 latency_ms: 0.0,
575 replication_lag_bytes: None,
576 },
577 );
578 }
579
580 #[cfg(feature = "pool-modes")]
582 let pool_manager = {
583 use crate::pool::PreparedStatementMode as PoolPreparedStatementMode;
584
585 let pool_config = PoolModeConfig {
586 default_mode: match config.pool_mode.mode {
587 crate::config::PoolingMode::Session => PoolingMode::Session,
588 crate::config::PoolingMode::Transaction => PoolingMode::Transaction,
589 crate::config::PoolingMode::Statement => PoolingMode::Statement,
590 },
591 max_pool_size: config.pool_mode.max_pool_size,
592 min_idle: config.pool_mode.min_idle,
593 idle_timeout_secs: config.pool_mode.idle_timeout_secs,
594 max_lifetime_secs: config.pool_mode.max_lifetime_secs,
595 acquire_timeout_secs: config.pool_mode.acquire_timeout_secs,
596 reset_query: config.pool_mode.reset_query.clone(),
597 prepared_statement_mode: match config.pool_mode.prepared_statement_mode {
598 crate::config::PreparedStatementMode::Disable => {
599 PoolPreparedStatementMode::Disable
600 }
601 crate::config::PreparedStatementMode::Track => PoolPreparedStatementMode::Track,
602 crate::config::PreparedStatementMode::Named => PoolPreparedStatementMode::Named,
603 },
604 test_on_acquire: config.pool.test_on_acquire,
605 validation_query: "SELECT 1".to_string(),
606 queue_timeout_secs: 30,
607 max_queue_size: 0,
608 };
609 Some(Arc::new(ConnectionPoolManager::new(pool_config)))
610 };
611
612 #[cfg(feature = "pool-modes")]
616 let backend_pool = match config.pool_mode.mode {
617 crate::config::PoolingMode::Transaction | crate::config::PoolingMode::Statement => {
618 tracing::info!(
619 mode = ?config.pool_mode.mode,
620 max_idle_per_identity = config.pool_mode.max_pool_size,
621 "pool-modes: data-path connection pooling enabled"
622 );
623 Some(Arc::new(crate::pool::BackendIdlePool::new(
624 config.pool_mode.max_pool_size as usize,
625 Self::MAX_TOTAL_IDLE_BACKEND_CONNS,
626 )))
627 }
628 crate::config::PoolingMode::Session => None,
629 };
630
631 #[cfg(feature = "wasm-plugins")]
636 let plugin_manager = Self::init_plugin_manager(&config.plugins);
637
638 let tls_acceptor = match config.tls.as_ref() {
642 Some(tls) if tls.enabled => match build_tls_acceptor(tls) {
643 Ok(acc) => {
644 tracing::info!(
645 mtls = tls.require_client_cert,
646 "client TLS termination enabled"
647 );
648 Some(acc)
649 }
650 Err(e) => {
651 return Err(ProxyError::Config(format!("TLS init failed: {}", e)));
652 }
653 },
654 _ => None,
655 };
656
657 let auth_file = if config.auth.mode == crate::config::AuthMode::Scram {
660 let path = config.auth.auth_file.as_ref().ok_or_else(|| {
661 ProxyError::Config("auth mode 'scram' requires auth_file".to_string())
662 })?;
663 let af = crate::auth_scram::AuthFile::load(path)
664 .map_err(|e| ProxyError::Config(format!("auth_file: {}", e)))?;
665 tracing::info!(users = %(!af.is_empty()), "proxy SCRAM auth enabled");
666 Some(Arc::new(af))
667 } else {
668 None
669 };
670
671 let mirror = if config.mirror.enabled {
674 tracing::info!(target = %format!("{}:{}", config.mirror.backend_host, config.mirror.backend_port),
675 writes_only = config.mirror.writes_only, "traffic mirroring enabled");
676 Some(crate::mirror::spawn(config.mirror.clone()))
677 } else {
678 None
679 };
680
681 #[cfg(feature = "rate-limiting")]
683 let rate_limiter = if config.rate_limit.enabled {
684 let rl = &config.rate_limit;
685 tracing::info!(
686 qps = rl.default_qps,
687 burst = rl.default_burst,
688 key_by = ?rl.key_by,
689 "rate limiting enabled"
690 );
691 let rlc = crate::rate_limit::RateLimitConfig {
692 enabled: true,
693 default_qps: rl.default_qps,
694 default_burst: rl.default_burst,
695 default_concurrency: if rl.max_concurrent > 0 {
696 rl.max_concurrent
697 } else {
698 crate::rate_limit::RateLimitConfig::default().default_concurrency
699 },
700 ..Default::default()
701 };
702 Some(Arc::new(crate::rate_limit::RateLimiter::new(rlc)))
703 } else {
704 None
705 };
706
707 #[cfg(feature = "circuit-breaker")]
709 let circuit_breaker = if config.circuit_breaker.enabled {
710 let cb = &config.circuit_breaker;
711 tracing::info!(
712 failure_threshold = cb.failure_threshold,
713 open_secs = cb.open_secs,
714 "circuit breaker enabled"
715 );
716 let cbc = crate::circuit_breaker::CircuitBreakerConfig {
717 failure_threshold: cb.failure_threshold,
718 cooldown: Duration::from_secs(cb.open_secs),
719 half_open_success_threshold: cb.success_threshold,
720 ..Default::default()
721 };
722 let mgr = crate::circuit_breaker::CircuitBreakerManager::new(
723 crate::circuit_breaker::ManagerConfig::new(cbc),
724 );
725 Some(Arc::new(mgr))
726 } else {
727 None
728 };
729
730 #[cfg(feature = "query-analytics")]
732 let analytics = if config.analytics.enabled {
733 let a = &config.analytics;
734 tracing::info!(
735 slow_query_ms = a.slow_query_ms,
736 max_fingerprints = a.max_fingerprints,
737 "query analytics enabled"
738 );
739 let ac = crate::analytics::AnalyticsConfig {
740 enabled: true,
741 max_fingerprints: a.max_fingerprints as usize,
742 slow_query: crate::analytics::SlowQueryConfig {
743 threshold: Duration::from_millis(a.slow_query_ms),
744 ..Default::default()
745 },
746 ..Default::default()
747 };
748 Some(Arc::new(crate::analytics::QueryAnalytics::new(ac)))
749 } else {
750 None
751 };
752
753 #[cfg(feature = "query-cache")]
755 let query_cache = if config.cache.enabled {
756 let c = &config.cache;
757 tracing::info!(
758 ttl_secs = c.ttl_secs,
759 max_result_bytes = c.max_result_bytes,
760 "query cache enabled (L1 hot + L2 warm)"
761 );
762 let ttl = Duration::from_secs(c.ttl_secs);
763 let cc = crate::cache::CacheConfig {
764 enabled: true,
765 default_ttl: ttl,
766 max_result_size: c.max_result_bytes,
767 l1: crate::cache::L1Config {
768 ttl,
769 ..Default::default()
770 },
771 l2: crate::cache::L2Config {
772 ttl,
773 ..Default::default()
774 },
775 ..Default::default()
776 };
777 Some(Arc::new(crate::cache::QueryCache::new(cc)))
778 } else {
779 None
780 };
781
782 #[cfg(feature = "query-rewriting")]
784 let rewriter = if config.query_rewrite.enabled && !config.query_rewrite.rules.is_empty() {
785 use crate::rewriter::{
786 QueryPattern, QueryRewriter, RewriteRule, RewriterConfig, Transformation,
787 };
788 let rw = QueryRewriter::new(RewriterConfig {
789 enabled: true,
790 ..Default::default()
791 });
792 let mut n = 0usize;
793 for (i, r) in config.query_rewrite.rules.iter().enumerate() {
794 let transformation =
795 if let (Some(from), Some(to)) = (&r.match_table, &r.replace_table_with) {
796 Transformation::ReplaceTable {
797 from: from.clone(),
798 to: to.clone(),
799 }
800 } else if let Some(w) = &r.append_where {
801 Transformation::AppendWhereAnd(w.clone())
802 } else if let Some(limit) = r.add_limit {
803 Transformation::AddLimit(limit)
804 } else {
805 continue; };
807 let pattern = if let Some(t) = &r.match_table {
808 QueryPattern::Table(t.clone())
809 } else if let Some(re) = &r.match_regex {
810 QueryPattern::regex(re.clone())
811 } else {
812 QueryPattern::All
813 };
814 rw.add_rule(
815 RewriteRule::build(format!("rule-{i}"))
816 .pattern(pattern)
817 .transform(transformation)
818 .build(),
819 );
820 n += 1;
821 }
822 tracing::info!(rules = n, "query rewriting enabled");
823 Some(Arc::new(rw))
824 } else {
825 None
826 };
827
828 #[cfg(feature = "multi-tenancy")]
830 let tenant_manager =
831 if config.multi_tenancy.enabled && !config.multi_tenancy.tenants.is_empty() {
832 use crate::multi_tenancy::{
833 IdentificationMethod, IsolationStrategy, MultiTenancyConfig, TenantConfig,
834 TenantId, TenantManagerBuilder, TenantQueryTransformer,
835 };
836 let mt = &config.multi_tenancy;
837 let identification = match mt.identify_by.as_str() {
838 "database" => IdentificationMethod::DatabaseName,
839 param => IdentificationMethod::Header {
840 header_name: param.to_string(),
841 },
842 };
843 let mtc = MultiTenancyConfig {
844 enabled: true,
845 identification,
846 ..Default::default()
847 };
848 let table_refs: Vec<&str> = mt.tenant_tables.iter().map(|s| s.as_str()).collect();
850 let transformer = TenantQueryTransformer::new()
851 .register_tables(&table_refs, mt.tenant_column.clone());
852 let tm = TenantManagerBuilder::new()
853 .config(mtc)
854 .query_transformer(transformer)
855 .build();
856 for id in &mt.tenants {
857 tm.register_tenant(TenantConfig::new(
858 TenantId::new(id.clone()),
859 IsolationStrategy::row("public", mt.tenant_column.clone()),
860 ));
861 }
862 tracing::info!(
863 tenants = mt.tenants.len(),
864 identify_by = %mt.identify_by,
865 "multi-tenancy enabled"
866 );
867 Some(Arc::new(tm))
868 } else {
869 None
870 };
871
872 #[cfg(feature = "schema-routing")]
874 let schema_analyzer =
875 if config.schema_routing.enabled && !config.schema_routing.analytics_node.is_empty() {
876 tracing::info!(
877 analytics_node = %config.schema_routing.analytics_node,
878 "schema/workload routing enabled (OLAP -> analytics node)"
879 );
880 let registry = Arc::new(crate::schema_routing::SchemaRegistry::new());
881 Some(Arc::new(crate::schema_routing::QueryAnalyzer::new(
882 registry,
883 )))
884 } else {
885 None
886 };
887
888 let state = Arc::new(ServerState {
889 sessions: RwLock::new(HashMap::new()),
890 health: ArcSwap::from_pointee(health),
891 health_write: parking_lot::Mutex::new(()),
892 live_config: ArcSwap::from_pointee(config.clone()),
893 metrics: ServerMetrics::default(),
894 cancel_map: Arc::new(DashMap::new()),
895 cancel_order: Arc::new(parking_lot::Mutex::new(std::collections::VecDeque::new())),
896 tls_acceptor,
897 auth_file,
898 mirror,
899 cutover: Arc::new(ArcSwap::from_pointee(None)),
900 lb_state: LoadBalancerState {
901 rr_counter: AtomicU64::new(0),
902 },
903 #[cfg(feature = "routing-hints")]
904 hint_parser: if config.routing_hints.enabled {
905 tracing::info!(
906 strip = config.routing_hints.strip_hints,
907 "SQL-comment routing hints enabled"
908 );
909 Some(if config.routing_hints.strip_hints {
910 crate::routing::HintParser::new()
911 } else {
912 crate::routing::HintParser::without_stripping()
913 })
914 } else {
915 None
916 },
917 #[cfg(feature = "rate-limiting")]
918 rate_limiter,
919 #[cfg(feature = "circuit-breaker")]
920 circuit_breaker,
921 #[cfg(feature = "query-analytics")]
922 analytics,
923 #[cfg(feature = "query-cache")]
924 query_cache,
925 #[cfg(feature = "query-rewriting")]
926 rewriter,
927 #[cfg(feature = "multi-tenancy")]
928 tenant_manager,
929 #[cfg(feature = "schema-routing")]
930 schema_analyzer,
931 #[cfg(feature = "pool-modes")]
932 pool_manager,
933 #[cfg(feature = "pool-modes")]
934 backend_pool,
935 #[cfg(feature = "wasm-plugins")]
936 plugin_manager,
937 #[cfg(feature = "ha-tr")]
938 transaction_journal: Arc::new(crate::transaction_journal::TransactionJournal::new()),
939 #[cfg(feature = "anomaly-detection")]
940 anomaly_detector: Arc::new(crate::anomaly::AnomalyDetector::new(
941 crate::anomaly::AnomalyConfig::default(),
942 )),
943 #[cfg(feature = "edge-proxy")]
944 edge_cache: Arc::new(crate::edge::EdgeCache::new(10_000)),
945 #[cfg(feature = "edge-proxy")]
946 edge_registry: Arc::new(crate::edge::EdgeRegistry::new(
947 32,
948 std::time::Duration::from_secs(120),
949 )),
950 });
951
952 Ok(Self {
953 config,
954 state,
955 shutdown_tx,
956 config_path: None,
957 })
958 }
959
    /// Builder-style setter: record the file the configuration was loaded from.
    ///
    /// SIGHUP reloads re-read this path. With `None`, `reload_config` logs a
    /// warning and keeps the current configuration.
    pub fn with_config_path(mut self, path: Option<String>) -> Self {
        self.config_path = path;
        self
    }
967
968 #[cfg(unix)]
971 fn hangup_stream() -> tokio::signal::unix::Signal {
972 tokio::signal::unix::signal(tokio::signal::unix::SignalKind::hangup())
973 .expect("failed to install SIGHUP handler")
974 }
975 #[cfg(not(unix))]
976 fn hangup_stream() -> HangupNever {
977 HangupNever
978 }
979
980 #[cfg(unix)]
983 fn usr2_stream() -> tokio::signal::unix::Signal {
984 tokio::signal::unix::signal(tokio::signal::unix::SignalKind::user_defined2())
985 .expect("failed to install SIGUSR2 handler")
986 }
987 #[cfg(not(unix))]
988 fn usr2_stream() -> HangupNever {
989 HangupNever
990 }
991
992 async fn drain_connections(state: &Arc<ServerState>, timeout: Duration) {
996 let deadline = tokio::time::Instant::now() + timeout;
997 loop {
998 let active = state.sessions.read().await.len();
999 if active == 0 {
1000 tracing::info!("drain complete — all in-flight connections finished");
1001 return;
1002 }
1003 if tokio::time::Instant::now() >= deadline {
1004 tracing::warn!(
1005 active,
1006 "drain timeout reached — exiting with connections still open"
1007 );
1008 return;
1009 }
1010 tokio::time::sleep(Duration::from_millis(200)).await;
1011 }
1012 }
1013
1014 fn drain_timeout(config_secs: u64) -> Duration {
1019 let secs = std::env::var("HELIOS_DRAIN_TIMEOUT_SECS")
1020 .ok()
1021 .and_then(|s| s.parse::<u64>().ok())
1022 .unwrap_or(config_secs);
1023 Duration::from_secs(secs)
1024 }
1025
1026 async fn reload_config(&self) {
1035 let Some(path) = self.config_path.as_deref() else {
1036 tracing::warn!(
1037 "SIGHUP received but config was not loaded from a file — nothing to reload"
1038 );
1039 return;
1040 };
1041 tracing::info!(path, "SIGHUP: reloading configuration");
1042 let new_config = match ProxyConfig::from_file(path) {
1043 Ok(c) => c,
1044 Err(e) => {
1045 tracing::error!(path, error = %e, "SIGHUP reload failed to parse — keeping current config");
1046 return;
1047 }
1048 };
1049 let old = self.state.live_config.load_full();
1050 if new_config.listen_address != old.listen_address {
1051 tracing::warn!(old = %old.listen_address, new = %new_config.listen_address,
1052 "listen_address change needs a restart/handoff; the bound socket is kept");
1053 }
1054 if new_config.admin_address != old.admin_address {
1055 tracing::warn!(old = %old.admin_address, new = %new_config.admin_address,
1056 "admin_address change needs a restart; the bound socket is kept");
1057 }
1058 Self::reconcile_health(&self.state, &new_config);
1061 let nodes = new_config.nodes.len();
1062 let hba_rules = new_config.hba.len();
1063 let pool_max = new_config.pool.max_connections;
1064 self.state.live_config.store(Arc::new(new_config));
1065 tracing::info!(
1066 nodes,
1067 hba_rules,
1068 pool_max,
1069 "SIGHUP: configuration reloaded — applies to new connections"
1070 );
1071 }
1072
1073 fn reconcile_health(state: &Arc<ServerState>, config: &ProxyConfig) {
1077 let _writers = state.health_write.lock();
1080 let current = state.health.load_full();
1081 let mut next: HashMap<String, NodeHealth> = HashMap::new();
1082 for node in &config.nodes {
1083 let addr = node.address();
1084 match current.get(&addr) {
1085 Some(existing) => {
1086 next.insert(addr, existing.clone());
1087 }
1088 None => {
1089 tracing::info!(node = %addr, "SIGHUP: new node added — seeding healthy");
1090 next.insert(
1091 addr.clone(),
1092 NodeHealth {
1093 address: addr,
1094 healthy: true,
1095 last_check: chrono::Utc::now(),
1096 failure_count: 0,
1097 last_error: None,
1098 latency_ms: 0.0,
1099 replication_lag_bytes: None,
1100 },
1101 );
1102 }
1103 }
1104 }
1105 for gone in current.keys().filter(|k| !next.contains_key(*k)) {
1106 tracing::info!(node = %gone, "SIGHUP: node removed from config");
1107 }
1108 state.health.store(Arc::new(next));
1109 }
1110
1111 pub async fn run(&self) -> Result<()> {
1113 let listener = bind_reuseport(&self.config.listen_address)?;
1119
1120 tracing::info!(
1121 "Proxy listening on {} (SO_REUSEPORT)",
1122 self.config.listen_address
1123 );
1124
1125 let health_task = self.spawn_health_checker();
1127 let pool_task = self.spawn_pool_manager();
1128
1129 let admin_task = self.spawn_admin_server();
1131
1132 let mcp_task = if self.config.mcp.enabled {
1134 let mcp_cfg = self.config.mcp.clone();
1135 let contract = mcp_cfg.contract.as_ref().and_then(|id| {
1137 let found = self.config.agent_contracts.iter().find(|c| &c.id == id).cloned();
1138 if found.is_none() {
1139 tracing::warn!(%id, "mcp.contract names an unknown agent_contract; gateway runs with only the read-only guardrail");
1140 }
1141 found
1142 });
1143 Some(tokio::spawn(async move {
1144 if let Err(e) = crate::mcp::McpServer::new(mcp_cfg, contract).run().await {
1145 tracing::error!("MCP gateway error: {}", e);
1146 }
1147 }))
1148 } else {
1149 None
1150 };
1151
1152 let http_gw_task = if self.config.http_gateway.enabled {
1154 let gw_cfg = self.config.http_gateway.clone();
1155 Some(tokio::spawn(async move {
1156 if let Err(e) = crate::http_gateway::HttpGateway::new(gw_cfg).run().await {
1157 tracing::error!("HTTP gateway error: {}", e);
1158 }
1159 }))
1160 } else {
1161 None
1162 };
1163
1164 #[cfg(feature = "graphql-gateway")]
1166 let _graphql_gw_task = if self.config.graphql_gateway.enabled {
1167 let gw_cfg = self.config.graphql_gateway.clone();
1168 Some(tokio::spawn(async move {
1169 if let Err(e) = crate::graphql_gateway::GraphqlGateway::new(gw_cfg)
1170 .run()
1171 .await
1172 {
1173 tracing::error!("GraphQL gateway error: {}", e);
1174 }
1175 }))
1176 } else {
1177 None
1178 };
1179
1180 let mut shutdown_rx = self.shutdown_tx.subscribe();
1181
1182 let mut sighup = Self::hangup_stream();
1186 let mut sigusr2 = Self::usr2_stream();
1187 let mut graceful = false;
1188
1189 loop {
1190 tokio::select! {
1191 _ = sighup.recv() => {
1192 self.reload_config().await;
1193 }
1194 _ = sigusr2.recv() => {
1195 tracing::info!(
1196 "SIGUSR2: graceful binary-handoff drain — closing the listener so new \
1197 connections route to the sibling process; finishing in-flight connections"
1198 );
1199 graceful = true;
1200 break;
1201 }
1202 accept_result = listener.accept() => {
1203 match accept_result {
1204 Ok((stream, addr)) => {
1205 let _ = stream.set_nodelay(true);
1209 self.state.metrics.connections_accepted.fetch_add(1, Ordering::Relaxed);
1210 let state = self.state.clone();
1211 let config = (*self.state.live_config.load_full()).clone();
1215 let shutdown_tx = self.shutdown_tx.clone();
1216
1217 tokio::spawn(async move {
1218 if let Err(e) = Self::handle_client(stream, addr, state, config, shutdown_tx).await {
1219 tracing::error!("Client handler error: {}", e);
1220 }
1221 });
1222 }
1223 Err(e) => {
1224 tracing::error!("Accept error: {}", e);
1225 }
1226 }
1227 }
1228 _ = shutdown_rx.recv() => {
1229 tracing::info!("Shutdown signal received");
1230 break;
1231 }
1232 }
1233 }
1234
1235 drop(listener);
1239
1240 if graceful {
1243 let timeout =
1244 Self::drain_timeout(self.state.live_config.load().shutdown_drain_timeout_secs);
1245 tracing::info!(
1246 timeout_secs = timeout.as_secs(),
1247 "draining in-flight connections"
1248 );
1249 Self::drain_connections(&self.state, timeout).await;
1250 }
1251
1252 health_task.abort();
1254 pool_task.abort();
1255 admin_task.abort();
1256 if let Some(t) = mcp_task {
1257 t.abort();
1258 }
1259 if let Some(t) = http_gw_task {
1260 t.abort();
1261 }
1262
1263 Ok(())
1264 }
1265
1266 fn spawn_admin_server(&self) -> tokio::task::JoinHandle<()> {
1268 let config = self.config.clone();
1269 let state = self.state.clone();
1270 let mut shutdown_rx = self.shutdown_tx.subscribe();
1271
1272 tokio::spawn(async move {
1273 let admin_state = Arc::new(AdminState::new());
1275
1276 {
1278 let mut snapshot = admin_state.config_snapshot.write().await;
1279 *snapshot = ConfigSnapshot {
1280 listen_address: config.listen_address.clone(),
1281 admin_address: config.admin_address.clone(),
1282 tr_enabled: config.tr_enabled,
1283 tr_mode: format!("{:?}", config.tr_mode),
1284 pool_min_connections: config.pool.min_connections,
1285 pool_max_connections: config.pool.max_connections,
1286 nodes: config
1287 .nodes
1288 .iter()
1289 .map(|n| NodeSnapshot {
1290 address: n.address(),
1291 role: format!("{:?}", n.role),
1292 weight: n.weight,
1293 enabled: n.enabled,
1294 })
1295 .collect(),
1296 };
1297 }
1298
1299 admin_state.set_proxy_config(config.clone()).await;
1301
1302 admin_state
1304 .with_auth_token(config.admin_token.clone())
1305 .await;
1306
1307 if config.branch.enabled {
1309 admin_state.with_branch(config.branch.clone()).await;
1310 }
1311
1312 if let Some(ref mirror) = state.mirror {
1314 admin_state
1315 .with_migration(crate::admin::MigrationInfo {
1316 target: mirror.target().to_string(),
1317 writes_only: mirror.writes_only(),
1318 metrics: mirror.metrics.clone(),
1319 config: config.mirror.clone(),
1320 cutover: state.cutover.clone(),
1321 cutover_target: crate::mirror::CutoverTarget {
1322 addr: format!(
1323 "{}:{}",
1324 config.mirror.backend_host, config.mirror.backend_port
1325 ),
1326 user: config.mirror.backend_user.clone(),
1327 password: config.mirror.backend_password.clone(),
1328 database: config.mirror.backend_database.clone(),
1329 },
1330 })
1331 .await;
1332 }
1333
1334 #[cfg(feature = "wasm-plugins")]
1339 if let Some(ref pm) = state.plugin_manager {
1340 admin_state.with_plugin_manager(pm.clone()).await;
1341 }
1342
1343 #[cfg(feature = "pool-modes")]
1346 if let Some(ref pm) = state.pool_manager {
1347 admin_state.with_pool_manager(pm.clone()).await;
1348 }
1349
1350 #[cfg(feature = "circuit-breaker")]
1353 if let Some(ref cb) = state.circuit_breaker {
1354 admin_state.with_circuit_breaker(cb.clone()).await;
1355 }
1356
1357 #[cfg(feature = "ha-tr")]
1364 {
1365 let template = build_replay_backend_template(&config);
1366 let engine = Arc::new(crate::replay::ReplayEngine::new(
1367 state.transaction_journal.clone(),
1368 template,
1369 ));
1370 admin_state.with_replay_engine(engine).await;
1371 }
1372
1373 #[cfg(feature = "anomaly-detection")]
1377 admin_state
1378 .with_anomaly_detector(state.anomaly_detector.clone())
1379 .await;
1380
1381 #[cfg(feature = "query-analytics")]
1383 if let Some(a) = state.analytics.as_ref() {
1384 admin_state.with_analytics(a.clone()).await;
1385 }
1386
1387 #[cfg(feature = "edge-proxy")]
1390 admin_state
1391 .with_edge(state.edge_cache.clone(), state.edge_registry.clone())
1392 .await;
1393
1394 let admin_server = AdminServer::new(config.admin_address.clone(), admin_state.clone());
1396
1397 let admin_state_sync = admin_state.clone();
1399 let server_state = state.clone();
1400 let sync_task = tokio::spawn(async move {
1401 let mut interval = tokio::time::interval(std::time::Duration::from_secs(1));
1402 loop {
1403 interval.tick().await;
1404
1405 {
1407 let health = server_state.health.load_full();
1408 let mut admin_health = admin_state_sync.node_health.write().await;
1409 *admin_health = (*health).clone();
1410 }
1411
1412 {
1414 let metrics = ServerMetricsSnapshot {
1415 connections_accepted: server_state
1416 .metrics
1417 .connections_accepted
1418 .load(Ordering::Relaxed),
1419 connections_closed: server_state
1420 .metrics
1421 .connections_closed
1422 .load(Ordering::Relaxed),
1423 queries_processed: server_state
1424 .metrics
1425 .queries_processed
1426 .load(Ordering::Relaxed),
1427 bytes_received: server_state
1428 .metrics
1429 .bytes_received
1430 .load(Ordering::Relaxed),
1431 bytes_sent: server_state.metrics.bytes_sent.load(Ordering::Relaxed),
1432 failovers: server_state.metrics.failovers.load(Ordering::Relaxed),
1433 };
1434 let mut admin_metrics = admin_state_sync.metrics.write().await;
1435 *admin_metrics = metrics;
1436 }
1437
1438 {
1440 let sessions = server_state.sessions.read().await;
1441 let mut admin_sessions = admin_state_sync.active_sessions.write().await;
1442 *admin_sessions = sessions.len() as u64;
1443 }
1444 }
1445 });
1446
1447 tokio::select! {
1449 result = admin_server.run() => {
1450 if let Err(e) = result {
1451 tracing::error!("Admin server error: {}", e);
1452 }
1453 }
1454 _ = shutdown_rx.recv() => {
1455 tracing::info!("Admin server shutting down");
1456 }
1457 }
1458
1459 sync_task.abort();
1460 })
1461 }
1462
1463 async fn handle_client(
1465 stream: TcpStream,
1466 addr: SocketAddr,
1467 state: Arc<ServerState>,
1468 config: ProxyConfig,
1469 _shutdown_tx: broadcast::Sender<()>,
1470 ) -> Result<()> {
1471 tracing::debug!("New client connection from {}", addr);
1472
1473 let session = Arc::new(ClientSession {
1475 id: Uuid::new_v4(),
1476 client_addr: addr,
1477 current_node: RwLock::new(None),
1478 tx_state: RwLock::new(TransactionState::default()),
1479 variables: RwLock::new(HashMap::new()),
1480 created_at: chrono::Utc::now(),
1481 tr_mode: config.tr_mode,
1482 #[cfg(feature = "lag-routing")]
1483 last_write_at: RwLock::new(None),
1484 #[cfg(feature = "pool-modes")]
1485 pool_client_id: ClientId::new(),
1486 #[cfg(feature = "wasm-plugins")]
1487 plugin_identity: RwLock::new(None),
1488 });
1489
1490 {
1492 let mut sessions = state.sessions.write().await;
1493 sessions.insert(session.id, session.clone());
1494 }
1495
1496 let result = match Self::negotiate_client_tls(stream, &state).await {
1501 Ok((mut client_stream, pre)) => {
1502 Self::client_loop(&mut client_stream, pre, &session, &state, &config).await
1503 }
1504 Err(e) => Err(e),
1505 };
1506
1507 {
1509 let mut sessions = state.sessions.write().await;
1510 sessions.remove(&session.id);
1511 }
1512
1513 #[cfg(feature = "pool-modes")]
1515 if let Some(ref pool_manager) = state.pool_manager {
1516 if pool_manager.has_active_lease(&session.pool_client_id) {
1518 tracing::debug!(
1519 "Releasing pool lease for disconnecting client {:?}",
1520 session.pool_client_id
1521 );
1522 }
1525 }
1526
1527 state
1528 .metrics
1529 .connections_closed
1530 .fetch_add(1, Ordering::Relaxed);
1531
1532 result
1533 }
1534
1535 async fn client_loop(
1537 stream: &mut ClientStream,
1538 pre: Option<StartupMessage>,
1539 session: &Arc<ClientSession>,
1540 state: &Arc<ServerState>,
1541 config: &ProxyConfig,
1542 ) -> Result<()> {
1543 let codec = ProtocolCodec::new();
1544 let mut buffer = BytesMut::with_capacity(8192);
1545
1546 let mut conns: HashMap<String, BackendConn> = HashMap::new();
1563 let mut current_node: Option<String> =
1564 match Self::handle_startup(stream, &mut buffer, &codec, pre, session, state, config)
1565 .await
1566 {
1567 Ok((Some(stream_conn), node_addr)) => {
1568 conns.insert(node_addr.clone(), BackendConn::new(stream_conn));
1569 Some(node_addr)
1570 }
1571 Ok((None, _)) => {
1572 return Ok(());
1574 }
1575 Err(e) => {
1576 tracing::error!("Startup failed: {}", e);
1577 let err_msg =
1579 Self::create_error_response("08006", &format!("Startup failed: {}", e));
1580 let _ = stream.write_all(&err_msg).await;
1581 return Err(e);
1582 }
1583 };
1584
1585 let mut read_buf = vec![0u8; 16384];
1600 let mut pending = BytesMut::new();
1601 let mut pending_route_sql: Option<String> = None;
1602 let mut stmt_registry: HashMap<String, bytes::Bytes> = HashMap::new();
1610 let mut stmt_registry_bytes: usize = 0;
1613 let mut batch_defines: Vec<String> = Vec::new();
1614 let mut batch_refs: Vec<String> = Vec::new();
1615 let mut batch_closes: Vec<String> = Vec::new();
1616 let promote_unnamed = config.optimize_unnamed_parse;
1622 let mut held_unnamed: Option<(bytes::Bytes, bytes::Bytes)> = None;
1623 loop {
1624 let n = stream
1626 .read(&mut read_buf)
1627 .await
1628 .map_err(|e| ProxyError::Network(format!("Read error: {}", e)))?;
1629
1630 if n == 0 {
1631 break;
1633 }
1634
1635 buffer.extend_from_slice(&read_buf[..n]);
1636 state
1637 .metrics
1638 .bytes_received
1639 .fetch_add(n as u64, Ordering::Relaxed);
1640
1641 if buffer.len() > Self::MAX_PENDING_BYTES {
1645 let emsg =
1646 Self::create_error_response("53400", "message exceeds per-session size limit");
1647 let _ = stream.write_all(&emsg).await;
1648 let _ = stream.write_all(&Self::create_ready_for_query(b'I')).await;
1649 tracing::warn!(
1650 client = %session.client_addr,
1651 bytes = buffer.len(),
1652 "inbound message exceeds size cap; closing connection"
1653 );
1654 return Ok(());
1655 }
1656
1657 while let Some(msg) = codec.decode_message(&mut buffer)? {
1659 match msg.msg_type {
1660 MessageType::Terminate => return Ok(()),
1661
1662 MessageType::Query => {
1664 #[cfg(feature = "anomaly-detection")]
1668 Self::record_anomaly_observation(&msg, state, session);
1669
1670 let (msg, action) = Self::apply_pre_query_hook(msg, state, session);
1673
1674 if let PreQueryAction::Block(reason) = &action {
1675 tracing::info!(reason = %reason, "pre-query plugin blocked query");
1676 Self::send_block_response(stream, reason, state).await?;
1677 state
1678 .metrics
1679 .queries_processed
1680 .fetch_add(1, Ordering::Relaxed);
1681 continue;
1682 }
1683
1684 #[cfg(feature = "wasm-plugins")]
1685 if let PreQueryAction::Cached(bytes) = &action {
1686 match Self::synthesise_cached_response(bytes) {
1687 Ok(reply) => {
1688 stream.write_all(&reply).await.map_err(|e| {
1689 ProxyError::Network(format!("Write error: {}", e))
1690 })?;
1691 state
1692 .metrics
1693 .bytes_sent
1694 .fetch_add(reply.len() as u64, Ordering::Relaxed);
1695 state
1696 .metrics
1697 .queries_processed
1698 .fetch_add(1, Ordering::Relaxed);
1699 continue;
1700 }
1701 Err(e) => {
1702 tracing::warn!(error = %e, "failed to synthesise cached response; falling back to backend");
1703 }
1704 }
1705 }
1706
1707 if let Some(ref mirror) = state.mirror {
1711 if let Some(sql) = crate::protocol::query_text(&msg.payload) {
1712 mirror.offer(sql, Self::is_write_query(sql));
1713 }
1714 }
1715
1716 #[cfg(feature = "wasm-plugins")]
1717 let forward_start = std::time::Instant::now();
1718 let fr = Self::forward_simple_query(
1719 stream,
1720 &msg,
1721 &mut conns,
1722 current_node.as_deref(),
1723 session,
1724 state,
1725 config,
1726 )
1727 .await;
1728 #[cfg(feature = "wasm-plugins")]
1729 Self::fire_post_query_hook(
1730 &msg,
1731 session,
1732 state,
1733 &fr,
1734 forward_start.elapsed(),
1735 );
1736 let (used_node, sent) = fr?;
1737 if let Some(n) = used_node {
1738 current_node = Some(n);
1739 }
1740 #[cfg(feature = "pool-modes")]
1743 Self::release_to_pool_if_idle(
1744 &mut conns,
1745 current_node.as_deref(),
1746 session,
1747 state,
1748 config,
1749 )
1750 .await;
1751 state.metrics.bytes_sent.fetch_add(sent, Ordering::Relaxed);
1752 state
1753 .metrics
1754 .queries_processed
1755 .fetch_add(1, Ordering::Relaxed);
1756 }
1757
1758 MessageType::Parse
1760 | MessageType::Bind
1761 | MessageType::Describe
1762 | MessageType::Execute
1763 | MessageType::Close => {
1764 let mut add_to_pending = true;
1768 match msg.msg_type {
1769 MessageType::Parse => {
1770 let name = Self::parse_stmt_name(&msg.payload);
1774 let unnamed = name.is_empty();
1775 if !unnamed {
1776 let name = name.to_string();
1777 let existed = stmt_registry.contains_key(&name);
1778 if !existed
1782 && stmt_registry.len() >= Self::MAX_PREPARED_STATEMENTS
1783 {
1784 let emsg = Self::create_error_response(
1785 "54000",
1786 "too many prepared statements for this session",
1787 );
1788 let _ = stream.write_all(&emsg).await;
1789 let _ = stream
1790 .write_all(&Self::create_ready_for_query(b'I'))
1791 .await;
1792 tracing::warn!(
1793 client = %session.client_addr,
1794 limit = Self::MAX_PREPARED_STATEMENTS,
1795 "prepared-statement cap exceeded; closing connection"
1796 );
1797 return Ok(());
1798 }
1799 let encoded = msg.encode().freeze();
1800 let old_len =
1804 stmt_registry.get(&name).map(|b| b.len()).unwrap_or(0);
1805 let projected =
1806 stmt_registry_bytes.saturating_sub(old_len) + encoded.len();
1807 if projected > Self::MAX_PREPARED_BYTES {
1808 let emsg = Self::create_error_response(
1809 "54000",
1810 "prepared-statement memory limit exceeded for this session",
1811 );
1812 let _ = stream.write_all(&emsg).await;
1813 let _ = stream
1814 .write_all(&Self::create_ready_for_query(b'I'))
1815 .await;
1816 tracing::warn!(
1817 client = %session.client_addr,
1818 limit = Self::MAX_PREPARED_BYTES,
1819 "prepared-statement byte cap exceeded; closing connection"
1820 );
1821 return Ok(());
1822 }
1823 stmt_registry.insert(name.clone(), encoded);
1824 stmt_registry_bytes = projected;
1825 batch_defines.push(name);
1826 }
1827 if pending_route_sql.is_none() {
1828 if let Some(end) = msg.payload.iter().position(|&b| b == 0) {
1829 if let Some(q) =
1830 crate::protocol::query_text(&msg.payload[end + 1..])
1831 {
1832 if !q.is_empty() {
1833 pending_route_sql = Some(q.to_string());
1834 #[cfg(feature = "anomaly-detection")]
1835 Self::record_anomaly_sql(q, state, session);
1836 }
1837 }
1838 }
1839 }
1840 if promote_unnamed
1847 && unnamed
1848 && pending.is_empty()
1849 && held_unnamed.is_none()
1850 {
1851 let sig = bytes::Bytes::copy_from_slice(&msg.payload[1..]);
1852 held_unnamed = Some((msg.encode().freeze(), sig));
1853 add_to_pending = false;
1854 } else if let Some((held_msg, _)) = held_unnamed.take() {
1855 let mut combined =
1856 BytesMut::with_capacity(held_msg.len() + pending.len());
1857 combined.extend_from_slice(&held_msg);
1858 combined.extend_from_slice(&pending);
1859 pending = combined;
1860 }
1861 }
1862 MessageType::Bind => {
1863 if let Some(name) = Self::bind_stmt_ref(&msg.payload) {
1864 batch_refs.push(name.to_string());
1865 }
1866 }
1867 MessageType::Describe => {
1868 if let Some(name) = Self::stmt_kind_name(&msg.payload) {
1869 batch_refs.push(name.to_string());
1870 }
1871 }
1872 MessageType::Close => {
1873 if let Some(name) = Self::stmt_kind_name(&msg.payload) {
1874 batch_closes.push(name.to_string());
1875 }
1876 }
1877 _ => {}
1878 }
1879 if add_to_pending {
1880 pending.extend_from_slice(&msg.encode());
1881 }
1882 }
1883
1884 MessageType::Sync | MessageType::Flush => {
1886 let wait_ready = msg.msg_type == MessageType::Sync;
1887 pending.extend_from_slice(&msg.encode());
1888 let batch = pending.split().freeze();
1889 let reprepare: Vec<String> = batch_refs
1893 .iter()
1894 .filter(|r| !batch_defines.contains(r))
1895 .cloned()
1896 .collect();
1897 let (used_node, sent) = Self::forward_extended_batch(
1898 stream,
1899 &batch,
1900 pending_route_sql.as_deref(),
1901 wait_ready,
1902 &mut conns,
1903 current_node.as_deref(),
1904 &stmt_registry,
1905 &reprepare,
1906 &batch_defines,
1907 held_unnamed.take(),
1908 session,
1909 state,
1910 config,
1911 )
1912 .await?;
1913 if let Some(n) = used_node {
1914 current_node = Some(n);
1915 }
1916 #[cfg(feature = "pool-modes")]
1920 if wait_ready {
1921 Self::release_to_pool_if_idle(
1922 &mut conns,
1923 current_node.as_deref(),
1924 session,
1925 state,
1926 config,
1927 )
1928 .await;
1929 }
1930 state.metrics.bytes_sent.fetch_add(sent, Ordering::Relaxed);
1931 for name in batch_closes.drain(..) {
1934 if let Some(removed) = stmt_registry.remove(&name) {
1935 stmt_registry_bytes =
1936 stmt_registry_bytes.saturating_sub(removed.len());
1937 }
1938 }
1939 if wait_ready {
1940 pending_route_sql = None;
1944 batch_defines.clear();
1945 batch_refs.clear();
1946 state
1947 .metrics
1948 .queries_processed
1949 .fetch_add(1, Ordering::Relaxed);
1950 }
1951 }
1952
1953 MessageType::CopyData | MessageType::CopyDone | MessageType::CopyFail => {
1955 if let Some(node) = current_node.clone() {
1956 if let Some(b) = conns.get_mut(&node) {
1957 b.stream.write_all(&msg.encode()).await.map_err(|e| {
1958 ProxyError::Network(format!("Backend copy write error: {}", e))
1959 })?;
1960 if matches!(
1961 msg.msg_type,
1962 MessageType::CopyDone | MessageType::CopyFail
1963 ) {
1964 let r = Self::stream_until_ready(
1965 stream,
1966 &mut b.stream,
1967 session,
1968 state,
1969 )
1970 .await;
1971 match r {
1972 Ok(sent) => {
1973 state
1974 .metrics
1975 .bytes_sent
1976 .fetch_add(sent, Ordering::Relaxed);
1977 }
1978 Err(e) => {
1979 conns.remove(&node);
1980 return Err(e);
1981 }
1982 }
1983 }
1984 }
1985 }
1986 }
1987
1988 _ => {
1990 if let Some(ref node) = current_node {
1991 if let Some(b) = conns.get_mut(node) {
1992 let _ = b.stream.write_all(&msg.encode()).await;
1993 }
1994 }
1995 }
1996 }
1997 }
1998
1999 if pending.len() > Self::MAX_PENDING_BYTES {
2003 let emsg = Self::create_error_response(
2004 "53400",
2005 "un-flushed extended-protocol buffer exceeds per-session limit",
2006 );
2007 let _ = stream.write_all(&emsg).await;
2008 let _ = stream.write_all(&Self::create_ready_for_query(b'I')).await;
2009 tracing::warn!(
2010 client = %session.client_addr,
2011 pending = pending.len(),
2012 "pending extended-protocol buffer cap exceeded; closing connection"
2013 );
2014 return Ok(());
2015 }
2016 }
2017
2018 #[cfg(feature = "pool-modes")]
2022 if state.backend_pool.is_some() {
2023 let nodes: Vec<String> = conns.keys().cloned().collect();
2024 for node in nodes {
2025 Self::release_to_pool_if_idle(
2026 &mut conns,
2027 Some(node.as_str()),
2028 session,
2029 state,
2030 config,
2031 )
2032 .await;
2033 }
2034 }
2035
2036 Ok(())
2037 }
2038
2039 async fn negotiate_client_tls(
2046 mut tcp: TcpStream,
2047 state: &Arc<ServerState>,
2048 ) -> Result<(ClientStream, Option<StartupMessage>)> {
2049 let codec = ProtocolCodec::new();
2050 let mut buffer = BytesMut::with_capacity(1024);
2051 let mut read_buf = vec![0u8; 1024];
2052
2053 let first = loop {
2054 if let Some(msg) = codec.decode_startup(&mut buffer)? {
2055 break msg;
2056 }
2057 let n = tcp
2058 .read(&mut read_buf)
2059 .await
2060 .map_err(|e| ProxyError::Network(format!("Startup read error: {}", e)))?;
2061 if n == 0 {
2062 return Err(ProxyError::Connection(
2063 "client closed before startup".to_string(),
2064 ));
2065 }
2066 buffer.extend_from_slice(&read_buf[..n]);
2067 };
2068
2069 match first {
2070 StartupMessage::SSLRequest => match state.tls_acceptor.as_ref() {
2071 Some(acceptor) => {
2072 tcp.write_all(b"S")
2073 .await
2074 .map_err(|e| ProxyError::Network(format!("SSL accept write: {}", e)))?;
2075 let tls = acceptor
2076 .accept(tcp)
2077 .await
2078 .map_err(|e| ProxyError::Network(format!("TLS handshake failed: {}", e)))?;
2079 if tls.get_ref().1.peer_certificates().is_some() {
2080 tracing::debug!("client presented a certificate (mTLS)");
2081 }
2082 Ok((ClientStream::Tls(Box::new(tls)), None))
2083 }
2084 None => {
2085 tcp.write_all(b"N")
2086 .await
2087 .map_err(|e| ProxyError::Network(format!("SSL reject write: {}", e)))?;
2088 Ok((ClientStream::Plain(tcp), None))
2089 }
2090 },
2091 other => Ok((ClientStream::Plain(tcp), Some(other))),
2092 }
2093 }
2094
2095 async fn handle_startup(
2099 client_stream: &mut ClientStream,
2100 buffer: &mut BytesMut,
2101 codec: &ProtocolCodec,
2102 pre: Option<StartupMessage>,
2103 session: &Arc<ClientSession>,
2104 state: &Arc<ServerState>,
2105 config: &ProxyConfig,
2106 ) -> Result<(Option<TcpStream>, String)> {
2107 let startup_msg = match pre {
2110 Some(msg) => Some(msg),
2111 None => {
2112 let mut read_buf = vec![0u8; 1024];
2113 loop {
2114 if let Some(msg) = codec.decode_startup(buffer)? {
2115 break Some(msg);
2116 }
2117 let n = client_stream
2118 .read(&mut read_buf)
2119 .await
2120 .map_err(|e| ProxyError::Network(format!("Startup read error: {}", e)))?;
2121 if n == 0 {
2122 return Ok((None, String::new()));
2123 }
2124 buffer.extend_from_slice(&read_buf[..n]);
2125 }
2126 }
2127 };
2128
2129 match startup_msg {
2130 Some(StartupMessage::SSLRequest) => {
2131 client_stream
2134 .write_all(b"N")
2135 .await
2136 .map_err(|e| ProxyError::Network(format!("SSL reject error: {}", e)))?;
2137 Err(ProxyError::Protocol(
2138 "unexpected SSLRequest after startup".to_string(),
2139 ))
2140 }
2141 Some(StartupMessage::CancelRequest { pid, key }) => {
2142 Self::forward_cancel_request(state, pid, key).await;
2145 Ok((None, String::new()))
2146 }
2147 Some(StartupMessage::Startup { params, .. }) => {
2148 Self::connect_and_authenticate(client_stream, ¶ms, session, state, config).await
2149 }
2150 None => Err(ProxyError::Protocol(
2151 "Incomplete startup message".to_string(),
2152 )),
2153 }
2154 }
2155
2156 fn hba_admits(rules: &[HbaRule], ip: std::net::IpAddr, user: &str, database: &str) -> bool {
2159 for r in rules {
2160 let user_ok = r.user == "all" || r.user == user;
2161 let db_ok = r.database == "all" || r.database == database;
2162 if user_ok && db_ok && Self::hba_addr_matches(&r.address, ip) {
2163 return r.action == HbaAction::Allow;
2164 }
2165 }
2166 true
2167 }
2168
/// Returns `true` when `ip` falls within the HBA address `spec`.
///
/// `spec` is one of:
/// * `"all"` — matches every address;
/// * a CIDR block such as `"10.0.0.0/8"` or `"fd00::/8"`. The prefix length
///   must fit the address family, otherwise the rule never matches;
/// * a single literal address such as `"127.0.0.1"`.
///
/// An unparsable spec never matches, and address families never cross-match.
/// The one exception: IPv4 peers reported as IPv4-mapped IPv6 addresses
/// (`::ffff:a.b.c.d`, as a dual-stack `[::]` listener does) are also checked
/// in their IPv4 form. This keeps IPv4 rules, especially deny rules, from
/// being silently bypassed.
fn hba_addr_matches(spec: &str, ip: std::net::IpAddr) -> bool {
    use std::net::IpAddr;

    /// Family-exact comparison of `ip` against `spec`.
    fn matches_exact(spec: &str, ip: IpAddr) -> bool {
        if spec == "all" {
            return true;
        }
        let Some((net, bits)) = spec.split_once('/') else {
            return spec.parse::<IpAddr>().map(|s| s == ip).unwrap_or(false);
        };
        let Ok(bits) = bits.parse::<u32>() else {
            return false;
        };
        match (net.parse::<IpAddr>(), ip) {
            (Ok(IpAddr::V4(n)), IpAddr::V4(i)) if bits <= 32 => {
                // A shift by the full width (bits == 0) yields None -> empty mask.
                let mask = u32::MAX.checked_shl(32 - bits).unwrap_or(0);
                (u32::from(n) & mask) == (u32::from(i) & mask)
            }
            (Ok(IpAddr::V6(n)), IpAddr::V6(i)) if bits <= 128 => {
                let mask = u128::MAX.checked_shl(128 - bits).unwrap_or(0);
                (u128::from(n) & mask) == (u128::from(i) & mask)
            }
            _ => false,
        }
    }

    if matches_exact(spec, ip) {
        return true;
    }
    match ip {
        IpAddr::V6(v6) => v6
            .to_ipv4_mapped()
            .map_or(false, |v4| matches_exact(spec, IpAddr::V4(v4))),
        IpAddr::V4(_) => false,
    }
}
2204
2205 async fn proxy_scram_auth(
2211 client: &mut ClientStream,
2212 user: &str,
2213 state: &Arc<ServerState>,
2214 ) -> std::result::Result<(), String> {
2215 use crate::auth_scram::ScramServer;
2216 let auth_file = state.auth_file.as_ref().ok_or("scram not configured")?;
2217
2218 let mut sasl = BytesMut::new();
2220 sasl.put_i32(10); sasl.extend_from_slice(b"SCRAM-SHA-256\0");
2222 sasl.put_u8(0); Self::write_auth_frame(client, &sasl).await?;
2224
2225 let init = Self::read_password_message(client).await?;
2227 let mech_end = init
2228 .iter()
2229 .position(|&b| b == 0)
2230 .ok_or("malformed SASLInitialResponse (no mechanism)")?;
2231 if init.len() < mech_end + 5 {
2232 return Err("short SASLInitialResponse".into());
2233 }
2234 let client_first =
2235 std::str::from_utf8(&init[mech_end + 5..]).map_err(|_| "client-first not UTF-8")?;
2236
2237 let verifier = auth_file.get(user).ok_or("no such user")?.clone();
2239
2240 let server_nonce = Self::random_nonce();
2242 let (server, server_first) = ScramServer::start(verifier, client_first, &server_nonce)?;
2243
2244 let mut cont = BytesMut::new();
2246 cont.put_i32(11);
2247 cont.extend_from_slice(server_first.as_bytes());
2248 Self::write_auth_frame(client, &cont).await?;
2249
2250 let client_final_raw = Self::read_password_message(client).await?;
2252 let client_final =
2253 std::str::from_utf8(&client_final_raw).map_err(|_| "client-final not UTF-8")?;
2254
2255 let server_final = server.finish(client_final)?;
2257
2258 let mut fin = BytesMut::new();
2260 fin.put_i32(12);
2261 fin.extend_from_slice(server_final.as_bytes());
2262 Self::write_auth_frame(client, &fin).await?;
2263 Ok(())
2264 }
2265
2266 async fn write_auth_frame(
2268 client: &mut ClientStream,
2269 payload: &[u8],
2270 ) -> std::result::Result<(), String> {
2271 let mut frame = BytesMut::with_capacity(payload.len() + 5);
2272 frame.put_u8(b'R');
2273 frame.put_u32((payload.len() + 4) as u32);
2274 frame.extend_from_slice(payload);
2275 client
2276 .write_all(&frame)
2277 .await
2278 .map_err(|e| format!("client write: {}", e))
2279 }
2280
2281 async fn read_password_message(
2284 client: &mut ClientStream,
2285 ) -> std::result::Result<BytesMut, String> {
2286 let codec = ProtocolCodec::new();
2287 let mut buffer = BytesMut::with_capacity(1024);
2288 let mut read_buf = vec![0u8; 1024];
2289 loop {
2290 if let Some(msg) = codec
2291 .decode_message(&mut buffer)
2292 .map_err(|e| format!("decode: {}", e))?
2293 {
2294 if msg.msg_type == MessageType::Password {
2295 return Ok(msg.payload);
2296 }
2297 return Err(format!("expected SASL response, got {:?}", msg.msg_type));
2298 }
2299 let n = client
2300 .read(&mut read_buf)
2301 .await
2302 .map_err(|e| format!("client read: {}", e))?;
2303 if n == 0 {
2304 return Err("client closed during SASL".into());
2305 }
2306 buffer.extend_from_slice(&read_buf[..n]);
2307 }
2308 }
2309
2310 fn random_nonce() -> String {
2312 use rand::Rng;
2313 const CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
2314 let mut rng = rand::thread_rng();
2315 (0..24)
2316 .map(|_| CHARS[rng.gen_range(0..CHARS.len())] as char)
2317 .collect()
2318 }
2319
2320 async fn connect_and_authenticate(
2322 client_stream: &mut ClientStream,
2323 params: &HashMap<String, String>,
2324 session: &Arc<ClientSession>,
2325 state: &Arc<ServerState>,
2326 config: &ProxyConfig,
2327 ) -> Result<(Option<TcpStream>, String)> {
2328 let user = params.get("user").map(String::as_str).unwrap_or("");
2331 let database = params.get("database").map(String::as_str).unwrap_or(user);
2332 if !Self::hba_admits(&config.hba, session.client_addr.ip(), user, database) {
2333 tracing::info!(%user, %database, client = %session.client_addr, "connection rejected by hba rule");
2334 let err = Self::create_error_response(
2335 "28000",
2336 "connection rejected by proxy admission rules",
2337 );
2338 let _ = client_stream.write_all(&err).await;
2339 return Ok((None, String::new()));
2340 }
2341
2342 if state.auth_file.is_some() {
2348 if let Err(e) = Self::proxy_scram_auth(client_stream, user, state).await {
2349 tracing::info!(%user, error = %e, "proxy SCRAM auth failed");
2350 let err =
2351 Self::create_error_response("28P01", &format!("authentication failed: {}", e));
2352 let _ = client_stream.write_all(&err).await;
2353 return Ok((None, String::new()));
2354 }
2355 tracing::debug!(%user, "client authenticated by proxy SCRAM");
2356 }
2357
2358 Self::apply_authenticate_hook(params, session, state).await?;
2364
2365 let cutover = state.cutover.load_full();
2369 let (node_addr, effective_params) = if let Some(t) = cutover.as_ref() {
2370 let mut p = params.clone();
2371 p.insert("user".to_string(), t.user.clone());
2372 if let Some(ref db) = t.database {
2373 p.insert("database".to_string(), db.clone());
2374 } else {
2375 p.remove("database");
2376 }
2377 tracing::debug!(target = %t.addr, "routing connection to cutover target");
2378 (t.addr.clone(), p)
2379 } else {
2380 (
2381 Self::select_node(session, state, config).await?,
2382 params.clone(),
2383 )
2384 };
2385
2386 let mut backend = match tokio::time::timeout(
2391 config.pool.acquire_timeout(),
2392 TcpStream::connect(&node_addr),
2393 )
2394 .await
2395 {
2396 Ok(Ok(s)) => s,
2397 Ok(Err(e)) => {
2398 let msg = format!("Failed to connect to {}: {}", node_addr, e);
2399 Self::note_backend_failure(state, &node_addr, &msg);
2400 return Err(ProxyError::Connection(msg));
2401 }
2402 Err(_) => {
2403 let msg = format!("Connection timeout to {}", node_addr);
2404 Self::note_backend_failure(state, &node_addr, &msg);
2405 return Err(ProxyError::Connection(msg));
2406 }
2407 };
2408 let _ = backend.set_nodelay(true);
2409
2410 let params = &effective_params;
2412 let startup_bytes = Self::build_startup_message(params);
2413 backend
2414 .write_all(&startup_bytes)
2415 .await
2416 .map_err(|e| ProxyError::Network(format!("Backend startup write error: {}", e)))?;
2417
2418 Self::proxy_authentication(client_stream, &mut backend, state, &node_addr).await?;
2422
2423 {
2425 let mut vars = session.variables.write().await;
2426 for (k, v) in params {
2427 vars.insert(k.clone(), v.clone());
2428 }
2429 }
2430
2431 Ok((Some(backend), node_addr))
2432 }
2433
/// Encodes a protocol-3.0 StartupMessage carrying `params`.
///
/// Layout:
/// * Int32 total length, counting itself;
/// * Int32 protocol version 196608 (3 << 16 | 0);
/// * each key and value as a NUL-terminated string;
/// * a final NUL byte.
///
/// Pairs are emitted in the map's iteration order.
fn build_startup_message(params: &HashMap<String, String>) -> Vec<u8> {
    const PROTOCOL_3_0: u32 = 196608;

    let mut body: Vec<u8> = PROTOCOL_3_0.to_be_bytes().to_vec();
    for (key, value) in params {
        for text in [key, value] {
            body.extend_from_slice(text.as_bytes());
            body.push(0);
        }
    }
    body.push(0);

    let total_len = (body.len() + 4) as u32;
    let mut out = Vec::with_capacity(body.len() + 4);
    out.extend_from_slice(&total_len.to_be_bytes());
    out.extend(body);
    out
}
2457
/// Maximum number of (pid, secret) cancel keys kept in `cancel_map`; the
/// oldest are evicted first once this is reached.
const MAX_CANCEL_KEYS: usize = 10 * 10_000;

/// Write timeout for backend sockets (30 s). NOTE(review): not used in this
/// chunk — confirm where it is applied.
const BACKEND_WRITE_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Write timeout for client sockets (60 s). NOTE(review): not used in this
/// chunk — confirm where it is applied.
const CLIENT_WRITE_TIMEOUT: Duration = Duration::from_millis(60_000);
/// Time budget for re-preparing statements (15 s); presumably used by
/// `forward_extended_batch` — confirm.
const REPREPARE_TIMEOUT: Duration = Duration::from_millis(15_000);
/// Per-session cap on the number of named prepared statements.
const MAX_PREPARED_STATEMENTS: usize = 8 * 1024;
/// Per-session cap on the total encoded size of named prepared statements (64 MiB).
const MAX_PREPARED_BYTES: usize = 64 << 20;
/// Per-session cap on buffered inbound bytes and on the un-flushed
/// extended-protocol batch (64 MiB).
const MAX_PENDING_BYTES: usize = 64 << 20;
/// Upper bound on idle backend connections kept across all pools; presumably
/// enforced by the backend pool — confirm.
#[cfg(feature = "pool-modes")]
const MAX_TOTAL_IDLE_BACKEND_CONNS: usize = 8 * 1024;
/// Period of the idle-pool reaper (30 s); presumably used by the pool reaper task.
const POOL_REAP_INTERVAL: Duration = Duration::from_millis(30_000);
2492
2493 fn register_cancel_key(state: &Arc<ServerState>, pid: u32, key: u32, node_addr: &str) {
2495 {
2499 let mut order = state.cancel_order.lock();
2500 while state.cancel_map.len() >= Self::MAX_CANCEL_KEYS {
2501 match order.pop_front() {
2502 Some(old) => {
2503 state.cancel_map.remove(&old);
2504 }
2505 None => {
2506 state.cancel_map.clear();
2509 break;
2510 }
2511 }
2512 }
2513 order.push_back((pid, key));
2514 }
2515 state.cancel_map.insert((pid, key), node_addr.to_string());
2516 }
2517
2518 async fn forward_cancel_request(state: &Arc<ServerState>, pid: u32, key: u32) {
2521 let Some(addr) = state.cancel_map.get(&(pid, key)).map(|e| e.clone()) else {
2522 tracing::debug!(pid, "cancel request for unknown key; ignoring");
2523 return;
2524 };
2525 let mut msg = BytesMut::with_capacity(16);
2527 msg.put_u32(16);
2528 msg.put_u32(80877102);
2529 msg.put_u32(pid);
2530 msg.put_u32(key);
2531 match tokio::time::timeout(Duration::from_secs(5), TcpStream::connect(&addr)).await {
2532 Ok(Ok(mut conn)) => {
2533 let _ = conn.set_nodelay(true);
2534 if let Err(e) = conn.write_all(&msg).await {
2535 tracing::warn!(node = %addr, error = %e, "failed to forward CancelRequest");
2536 }
2537 }
2539 other => {
2540 tracing::warn!(node = %addr, ?other, "could not connect to forward CancelRequest")
2541 }
2542 }
2543 }
2544
2545 async fn proxy_authentication(
2547 client_stream: &mut ClientStream,
2548 backend_stream: &mut TcpStream,
2549 state: &Arc<ServerState>,
2550 node_addr: &str,
2551 ) -> Result<()> {
2552 let codec = ProtocolCodec::new();
2553 let mut backend_buffer = BytesMut::with_capacity(4096);
2554 let mut client_buffer = BytesMut::with_capacity(4096);
2555 let mut read_buf = vec![0u8; 4096];
2556
2557 loop {
2558 let n = backend_stream
2560 .read(&mut read_buf)
2561 .await
2562 .map_err(|e| ProxyError::Network(format!("Backend auth read error: {}", e)))?;
2563
2564 if n == 0 {
2565 return Err(ProxyError::Connection(
2566 "Backend closed during auth".to_string(),
2567 ));
2568 }
2569
2570 backend_buffer.extend_from_slice(&read_buf[..n]);
2571
2572 client_stream
2574 .write_all(&read_buf[..n])
2575 .await
2576 .map_err(|e| ProxyError::Network(format!("Client auth write error: {}", e)))?;
2577
2578 while let Some(msg) = codec.decode_message(&mut backend_buffer)? {
2582 match msg.msg_type {
2583 MessageType::BackendKeyData
2584 if msg.payload.len() >= 8 => {
2588 let pid = u32::from_be_bytes([
2589 msg.payload[0], msg.payload[1], msg.payload[2], msg.payload[3],
2590 ]);
2591 let key = u32::from_be_bytes([
2592 msg.payload[4], msg.payload[5], msg.payload[6], msg.payload[7],
2593 ]);
2594 Self::register_cancel_key(state, pid, key, node_addr);
2595 }
2596 MessageType::AuthRequest
2597 if msg.payload.len() >= 4 => {
2599 let auth_type =
2600 i32::from_be_bytes([msg.payload[0], msg.payload[1], msg.payload[2], msg.payload[3]]);
2601 if auth_type == 0 {
2602 }
2604 }
2605 MessageType::ReadyForQuery => {
2606 return Ok(());
2608 }
2609 MessageType::ErrorResponse => {
2610 return Err(ProxyError::Auth("Authentication failed".to_string()));
2612 }
2613 _ => {
2614 }
2616 }
2617 }
2618
2619 let n = tokio::time::timeout(
2622 Duration::from_millis(100),
2623 client_stream.read(&mut read_buf),
2624 )
2625 .await;
2626
2627 if let Ok(Ok(n)) = n {
2628 if n > 0 {
2629 client_buffer.extend_from_slice(&read_buf[..n]);
2630 backend_stream
2631 .write_all(&read_buf[..n])
2632 .await
2633 .map_err(|e| {
2634 ProxyError::Network(format!("Backend password write error: {}", e))
2635 })?;
2636 }
2637 }
2638 }
2639 }
2640
2641 async fn choose_target_node(
2646 is_write: bool,
2647 forced_target: Option<String>,
2648 current_node: Option<&str>,
2649 session: &Arc<ClientSession>,
2650 state: &Arc<ServerState>,
2651 config: &ProxyConfig,
2652 ) -> Result<String> {
2653 if let Some(t) = state.cutover.load_full().as_ref() {
2656 return Ok(t.addr.clone());
2657 }
2658
2659 #[cfg(feature = "lag-routing")]
2663 if !is_write && forced_target.is_none() && config.lag_routing.enabled {
2664 let last_write = *session.last_write_at.read().await;
2665 if Self::ryw_pins_primary(last_write, config.lag_routing.ryw_window_ms) {
2666 tracing::debug!(target: "helios::routing", "read-your-writes: pinning read to primary");
2667 return Self::select_primary_with_timeout(session, state, config).await;
2668 }
2669 }
2670
2671 let need_switch = if let Some(ref forced) = forced_target {
2672 let health = state.health.load_full();
2673 let reuse = current_node
2674 .map(|c| c == forced && health.get(c).map(|h| h.healthy).unwrap_or(false))
2675 .unwrap_or(false);
2676 !reuse
2677 } else if let Some(current) = current_node {
2678 let health = state.health.load_full();
2679 let current_healthy = health.get(current).map(|h| h.healthy).unwrap_or(false);
2680 if !current_healthy {
2681 true
2682 } else if is_write {
2683 let is_primary = config
2684 .nodes
2685 .iter()
2686 .find(|n| n.address() == *current)
2687 .map(|n| n.role == NodeRole::Primary)
2688 .unwrap_or(false);
2689 !is_primary
2690 } else {
2691 false
2692 }
2693 } else {
2694 true
2695 };
2696
2697 if let Some(forced) = forced_target {
2698 let resolved = config
2703 .nodes
2704 .iter()
2705 .find(|n| n.name.as_deref() == Some(forced.as_str()) || n.address() == forced)
2706 .map(|n| n.address())
2707 .unwrap_or(forced);
2708 Ok(resolved)
2709 } else if need_switch {
2710 if is_write {
2711 Self::select_primary_with_timeout(session, state, config).await
2712 } else {
2713 Self::select_read_node(session, state, config).await
2714 }
2715 } else {
2716 Ok(current_node.unwrap().to_string())
2717 }
2718 }
2719
2720 async fn ensure_conn(
2725 conns: &mut HashMap<String, BackendConn>,
2726 target: &str,
2727 session: &Arc<ClientSession>,
2728 config: &ProxyConfig,
2729 _state: &Arc<ServerState>,
2730 ) -> Result<()> {
2731 if conns.contains_key(target) {
2732 return Ok(());
2733 }
2734
2735 #[cfg(feature = "pool-modes")]
2740 if let Some(pool) = _state.backend_pool.as_ref() {
2741 let key = Self::pool_key_for(target, session).await;
2742 if let Some(stream) = pool.checkout(&key) {
2743 tracing::info!(
2744 target: "helios::pool",
2745 node = %target,
2746 "reused pooled backend connection"
2747 );
2748 conns.insert(target.to_string(), BackendConn::new(stream));
2749 return Ok(());
2750 }
2751 }
2752
2753 let mut backend =
2754 tokio::time::timeout(config.pool.acquire_timeout(), TcpStream::connect(target))
2755 .await
2756 .map_err(|_| ProxyError::Connection(format!("Connection timeout to {}", target)))?
2757 .map_err(|e| {
2758 ProxyError::Connection(format!("Failed to connect to {}: {}", target, e))
2759 })?;
2760 let _ = backend.set_nodelay(true);
2761
2762 let params = session.variables.read().await.clone();
2763 let startup = Self::build_startup_message(¶ms);
2764 backend
2765 .write_all(&startup)
2766 .await
2767 .map_err(|e| ProxyError::Network(format!("Backend startup error: {}", e)))?;
2768 Self::complete_backend_auth(&mut backend).await?;
2769 #[cfg(feature = "pool-modes")]
2770 if _state.backend_pool.is_some() {
2771 tracing::debug!(target: "helios::pool", node = %target, "dialed fresh backend connection (pool miss)");
2772 }
2773 tracing::debug!(node = %target, "opened backend connection");
2774 conns.insert(target.to_string(), BackendConn::new(backend));
2775 Ok(())
2776 }
2777
/// Computes the pool key for a backend connection to `target` on behalf of
/// `session`, from the session's `user` and `database` variables.
///
/// A missing `user` is treated as the empty string; a missing `database`
/// falls back to the user name.
#[cfg(feature = "pool-modes")]
async fn pool_key_for(target: &str, session: &Arc<ClientSession>) -> String {
    let vars = session.variables.read().await;
    let user = match vars.get("user") {
        Some(u) => u.as_str(),
        None => "",
    };
    let database = match vars.get("database") {
        Some(db) => db.as_str(),
        None => user,
    };
    crate::pool::pool_key(target, user, database)
}
2790
/// Runs `reset_sql` as a simple query on an idle backend connection before
/// it is parked in the pool, then drains the reply up to and including the
/// next `ReadyForQuery`.
///
/// The query write is bounded by `BACKEND_WRITE_TIMEOUT`, consistent with
/// the other backend writes in this file. Each drain read is bounded by 5 s,
/// so a wedged backend cannot stall the caller indefinitely.
///
/// NOTE(review): an `ErrorResponse` to the reset query is not treated as a
/// failure. The function still returns `Ok(())` once `ReadyForQuery` arrives.
/// Confirm whether a failed reset should keep the connection out of the pool.
///
/// # Errors
/// Write/read errors or timeouts, decode errors from `ProtocolCodec`, or the
/// backend closing the socket before `ReadyForQuery`.
#[cfg(feature = "pool-modes")]
async fn reset_backend(stream: &mut TcpStream, reset_sql: &str) -> Result<()> {
    // `QueryMessage::encode` yields a `Message`; `Message::encode` yields
    // the wire bytes.
    let query = crate::protocol::QueryMessage {
        query: reset_sql.to_string(),
    }
    .encode();
    tokio::time::timeout(Self::BACKEND_WRITE_TIMEOUT, stream.write_all(&query.encode()))
        .await
        .map_err(|_| ProxyError::Network("reset write timeout".to_string()))?
        .map_err(|e| ProxyError::Network(format!("reset write error: {}", e)))?;

    let codec = ProtocolCodec::new();
    let mut buffer = BytesMut::with_capacity(1024);
    let mut read_buf = vec![0u8; 1024];
    loop {
        while let Some(m) = codec.decode_message(&mut buffer)? {
            if m.msg_type == MessageType::ReadyForQuery {
                return Ok(());
            }
        }
        let n = tokio::time::timeout(Duration::from_secs(5), stream.read(&mut read_buf))
            .await
            .map_err(|_| ProxyError::Network("reset drain timeout".to_string()))?
            .map_err(|e| ProxyError::Network(format!("reset drain read error: {}", e)))?;
        if n == 0 {
            return Err(ProxyError::Connection(
                "backend closed during reset".to_string(),
            ));
        }
        buffer.extend_from_slice(&read_buf[..n]);
    }
}
2829
#[cfg(feature = "pool-modes")]
/// Hands the session's connection to `node` back to the shared backend pool,
/// when a pool is configured and the session is not inside a transaction.
///
/// The connection is removed from `conns` first. It is offered to the pool
/// only if `reset_backend` (running `config.pool_mode.reset_query`) succeeds;
/// otherwise it is dropped here. `pool.checkin` takes ownership of the stream
/// either way. The debug log fires only when it returns `true`. Presumably
/// `false` means the pool declined it (e.g. full); confirm in the pool module.
async fn release_to_pool_if_idle(
    conns: &mut HashMap<String, BackendConn>,
    node: Option<&str>,
    session: &Arc<ClientSession>,
    state: &Arc<ServerState>,
    config: &ProxyConfig,
) {
    let Some(pool) = state.backend_pool.as_ref() else {
        return;
    };
    let Some(node) = node else {
        return;
    };
    // An open transaction pins the connection to this session; it must not
    // be reset or shared mid-transaction.
    if session.tx_state.read().await.in_transaction {
        return;
    }
    let Some(mut bc) = conns.remove(node) else {
        return;
    };
    if Self::reset_backend(&mut bc.stream, &config.pool_mode.reset_query)
        .await
        .is_ok()
    {
        let key = Self::pool_key_for(node, session).await;
        if pool.checkin(&key, bc.stream) {
            tracing::debug!(target: "helios::pool", node = %node, "parked backend connection for reuse");
        }
    }
}
2867
2868 async fn forward_simple_query(
2874 client: &mut ClientStream,
2875 msg: &Message,
2876 conns: &mut HashMap<String, BackendConn>,
2877 current_node: Option<&str>,
2878 session: &Arc<ClientSession>,
2879 state: &Arc<ServerState>,
2880 config: &ProxyConfig,
2881 ) -> Result<(Option<String>, u64)> {
2882 #[cfg(feature = "rate-limiting")]
2884 if let Some(mut resp) = Self::rate_limit_check(session, state, config).await {
2885 resp.extend_from_slice(&Self::create_ready_for_query(b'I'));
2886 client
2887 .write_all(&resp)
2888 .await
2889 .map_err(|e| ProxyError::Network(format!("Client write error: {}", e)))?;
2890 return Ok((None, resp.len() as u64));
2891 }
2892
2893 let default_is_write = Self::is_write_message(msg);
2894 let plugin_override = Self::apply_route_hook(msg, state, session);
2895
2896 if let RouteOverride::Block(reason) = plugin_override {
2898 let mut response = Vec::with_capacity(64 + reason.len());
2899 response.extend_from_slice(&Self::create_error_response(
2900 "42000",
2901 &format!("Query blocked by route plugin: {}", reason),
2902 ));
2903 response.extend_from_slice(&Self::create_ready_for_query(b'I'));
2904 client
2905 .write_all(&response)
2906 .await
2907 .map_err(|e| ProxyError::Network(format!("Client write error: {}", e)))?;
2908 return Ok((None, response.len() as u64));
2909 }
2910
2911 #[cfg(feature = "routing-hints")]
2915 let (route_override, default_is_write, stripped_msg) =
2916 Self::resolve_simple_route(msg, plugin_override, default_is_write, state);
2917 #[cfg(not(feature = "routing-hints"))]
2918 let (route_override, stripped_msg): (RouteOverride, Option<Message>) =
2919 (plugin_override, None);
2920
2921 let (is_write, forced_target) = match route_override {
2922 RouteOverride::None => (default_is_write, None),
2923 RouteOverride::Primary => (true, None),
2924 RouteOverride::Standby => (false, None),
2925 RouteOverride::Node(name) => (default_is_write, Some(name)),
2926 RouteOverride::Block(_) => unreachable!("handled above"),
2927 };
2928
2929 #[cfg(feature = "lag-routing")]
2932 if is_write && config.lag_routing.enabled {
2933 *session.last_write_at.write().await = Some(std::time::Instant::now());
2934 }
2935
2936 let forward_msg = stripped_msg.as_ref().unwrap_or(msg);
2939
2940 #[cfg(feature = "query-rewriting")]
2944 let rewritten_msg: Option<Message> = state.rewriter.as_ref().and_then(|rw| {
2945 let sql = crate::protocol::query_text(&forward_msg.payload)?;
2946 match rw.rewrite(sql) {
2947 Ok(res) if res.was_rewritten() => {
2948 tracing::debug!(target: "helios::rewrite", rules = ?res.rules_applied, "query rewritten");
2949 Some(crate::protocol::QueryMessage { query: res.query().to_string() }.encode())
2950 }
2951 _ => None,
2952 }
2953 });
2954 #[cfg(feature = "query-rewriting")]
2955 let forward_msg = rewritten_msg.as_ref().unwrap_or(forward_msg);
2956
2957 #[cfg(feature = "multi-tenancy")]
2961 let tenant_msg: Option<Message> = if let Some(tm) = state.tenant_manager.as_ref() {
2962 match crate::protocol::query_text(&forward_msg.payload) {
2963 Some(sql) => {
2964 let ctx = Self::tenant_request_ctx(session).await;
2965 match tm.identify_tenant(&ctx) {
2966 Some(tenant) => {
2967 let res = tm.transform_query(sql, &tenant);
2968 if res.transformed {
2969 tracing::debug!(target: "helios::tenant", tenant = %tenant.0, "tenant filter injected");
2970 Some(crate::protocol::QueryMessage { query: res.query }.encode())
2971 } else {
2972 None
2973 }
2974 }
2975 None => None,
2976 }
2977 }
2978 None => None,
2979 }
2980 } else {
2981 None
2982 };
2983 #[cfg(feature = "multi-tenancy")]
2984 let forward_msg = tenant_msg.as_ref().unwrap_or(forward_msg);
2985
2986 #[cfg(feature = "query-cache")]
2989 let cache_ctx: Option<crate::cache::CacheContext> = if is_write {
2990 None
2991 } else if let Some(qc) = state.query_cache.as_ref() {
2992 let sql = crate::protocol::query_text(&forward_msg.payload).unwrap_or("");
2993 match Self::cacheable_read_ctx(session, sql).await {
2994 Some(ctx) => {
2995 if let crate::cache::CacheLookup::Hit { result, level } =
2996 qc.get(sql, &ctx).await
2997 {
2998 tracing::debug!(target: "helios::cache", level = %level, "cache hit");
2999 client.write_all(&result.data).await.map_err(|e| {
3000 ProxyError::Network(format!("Client write error: {}", e))
3001 })?;
3002 return Ok((None, result.data.len() as u64));
3003 }
3004 Some(ctx)
3005 }
3006 None => None,
3007 }
3008 } else {
3009 None
3010 };
3011
3012 #[cfg(feature = "schema-routing")]
3015 let forced_target = match state.schema_analyzer.as_ref() {
3016 Some(analyzer)
3017 if forced_target.is_none()
3018 && !is_write
3019 && !config.schema_routing.analytics_node.is_empty() =>
3020 {
3021 match crate::protocol::query_text(&forward_msg.payload) {
3022 Some(sql) if analyzer.analyze(sql).is_analytics() => {
3023 tracing::debug!(target: "helios::schema", "OLAP query routed to analytics node");
3024 Some(config.schema_routing.analytics_node.clone())
3025 }
3026 _ => forced_target,
3027 }
3028 }
3029 _ => forced_target,
3030 };
3031
3032 #[cfg(feature = "query-analytics")]
3034 let analytics_sql =
3035 crate::protocol::query_text(&forward_msg.payload).map(|s| s.to_string());
3036 #[cfg(feature = "query-analytics")]
3037 let started = std::time::Instant::now();
3038
3039 let target = Self::choose_target_node(
3040 is_write,
3041 forced_target,
3042 current_node,
3043 session,
3044 state,
3045 config,
3046 )
3047 .await?;
3048 tracing::debug!(target: "helios::routing", node = %target, is_write, "routed simple query");
3049
3050 #[cfg(feature = "circuit-breaker")]
3052 if let Some(mut resp) = Self::circuit_fast_fail(state, &target) {
3053 resp.extend_from_slice(&Self::create_ready_for_query(b'I'));
3054 client
3055 .write_all(&resp)
3056 .await
3057 .map_err(|e| ProxyError::Network(format!("Client write error: {}", e)))?;
3058 return Ok((None, resp.len() as u64));
3059 }
3060
3061 if let Err(e) = Self::ensure_conn(conns, &target, session, config, state).await {
3063 Self::record_backend_failure(state, &target, &e.to_string());
3064 return Err(e);
3065 }
3066 let backend = conns.get_mut(&target).expect("just ensured");
3067
3068 let backend_err = match tokio::time::timeout(
3069 Self::BACKEND_WRITE_TIMEOUT,
3070 backend.stream.write_all(&forward_msg.encode()),
3071 )
3072 .await
3073 {
3074 Ok(Ok(())) => None,
3075 Ok(Err(e)) => Some(format!("Backend write error: {}", e)),
3076 Err(_) => Some("Backend write timeout".to_string()),
3077 };
3078 if let Some(msg) = backend_err {
3079 let e = ProxyError::Network(msg);
3080 conns.remove(&target);
3081 Self::record_backend_failure(state, &target, &e.to_string());
3082 return Err(e);
3083 }
3084
3085 #[cfg(feature = "query-cache")]
3088 if let (Some(ctx), Some(qc)) = (cache_ctx.as_ref(), state.query_cache.as_ref()) {
3089 return match Self::stream_until_ready_capture(client, &mut backend.stream, session)
3090 .await
3091 {
3092 Ok((sent, captured, cacheable, rows)) => {
3093 #[cfg(feature = "circuit-breaker")]
3094 Self::circuit_record(state, &target, true, "");
3095 if cacheable && !captured.is_empty() {
3096 let sql = crate::protocol::query_text(&forward_msg.payload).unwrap_or("");
3097 qc.put(
3098 sql,
3099 ctx,
3100 bytes::Bytes::from(captured),
3101 rows,
3102 std::time::Duration::ZERO,
3103 )
3104 .await;
3105 }
3106 #[cfg(feature = "query-analytics")]
3107 if let Some(sql) = analytics_sql.as_deref() {
3108 Self::record_analytics(
3109 state,
3110 session,
3111 sql,
3112 &target,
3113 started.elapsed(),
3114 None,
3115 )
3116 .await;
3117 }
3118 Ok((Some(target), sent))
3119 }
3120 Err(e) => {
3121 conns.remove(&target);
3122 Self::record_backend_failure(state, &target, &e.to_string());
3123 Err(e)
3124 }
3125 };
3126 }
3127
3128 match Self::stream_until_ready(client, &mut backend.stream, session, state).await {
3129 Ok(sent) => {
3130 #[cfg(feature = "circuit-breaker")]
3131 Self::circuit_record(state, &target, true, "");
3132 #[cfg(feature = "query-cache")]
3134 if is_write {
3135 if let Some(qc) = state.query_cache.as_ref() {
3136 let sql = crate::protocol::query_text(&forward_msg.payload).unwrap_or("");
3137 qc.invalidate_query(sql).await;
3138 }
3139 }
3140 #[cfg(feature = "ha-tr")]
3142 if is_write && config.tr_enabled {
3143 if let Some(sql) = crate::protocol::query_text(&forward_msg.payload) {
3144 Self::journal_write(state, session, sql).await;
3145 }
3146 }
3147 #[cfg(feature = "query-analytics")]
3148 if let Some(sql) = analytics_sql.as_deref() {
3149 Self::record_analytics(state, session, sql, &target, started.elapsed(), None)
3150 .await;
3151 }
3152 Ok((Some(target), sent))
3153 }
3154 Err(e) => {
3155 conns.remove(&target);
3157 Self::record_backend_failure(state, &target, &e.to_string());
3158 #[cfg(feature = "query-analytics")]
3159 if let Some(sql) = analytics_sql.as_deref() {
3160 Self::record_analytics(
3161 state,
3162 session,
3163 sql,
3164 &target,
3165 started.elapsed(),
3166 Some(e.to_string()),
3167 )
3168 .await;
3169 }
3170 Err(e)
3171 }
3172 }
3173 }
3174
3175 #[allow(clippy::too_many_arguments)]
3188 async fn forward_extended_batch(
3189 client: &mut ClientStream,
3190 batch: &[u8],
3191 route_sql: Option<&str>,
3192 wait_ready: bool,
3193 conns: &mut HashMap<String, BackendConn>,
3194 current_node: Option<&str>,
3195 registry: &HashMap<String, bytes::Bytes>,
3196 reprepare: &[String],
3197 defines: &[String],
3198 unnamed: Option<(bytes::Bytes, bytes::Bytes)>,
3199 session: &Arc<ClientSession>,
3200 state: &Arc<ServerState>,
3201 config: &ProxyConfig,
3202 ) -> Result<(Option<String>, u64)> {
3203 #[cfg(feature = "rate-limiting")]
3207 if let Some(mut resp) = Self::rate_limit_check(session, state, config).await {
3208 if wait_ready {
3209 resp.extend_from_slice(&Self::create_ready_for_query(b'I'));
3210 }
3211 client
3212 .write_all(&resp)
3213 .await
3214 .map_err(|e| ProxyError::Network(format!("Client write error: {}", e)))?;
3215 return Ok((None, resp.len() as u64));
3216 }
3217
3218 #[cfg(feature = "query-analytics")]
3220 let analytics_sql = route_sql.map(|s| s.to_string());
3221 #[cfg(feature = "query-analytics")]
3222 let started = std::time::Instant::now();
3223
3224 let target = match route_sql {
3225 Some(sql) => {
3226 #[cfg(feature = "routing-hints")]
3229 let (is_write, forced) = Self::extended_hint_route(state, sql)
3230 .unwrap_or_else(|| (Self::is_write_query(sql), None));
3231 #[cfg(not(feature = "routing-hints"))]
3232 let (is_write, forced): (bool, Option<String>) = (Self::is_write_query(sql), None);
3233 #[cfg(feature = "lag-routing")]
3234 if is_write && config.lag_routing.enabled {
3235 *session.last_write_at.write().await = Some(std::time::Instant::now());
3236 }
3237 Self::choose_target_node(is_write, forced, current_node, session, state, config)
3238 .await?
3239 }
3240 None => match current_node {
3244 Some(c) => c.to_string(),
3245 None => Self::select_read_node(session, state, config).await?,
3246 },
3247 };
3248
3249 #[cfg(feature = "circuit-breaker")]
3251 if let Some(mut resp) = Self::circuit_fast_fail(state, &target) {
3252 if wait_ready {
3253 resp.extend_from_slice(&Self::create_ready_for_query(b'I'));
3254 }
3255 client
3256 .write_all(&resp)
3257 .await
3258 .map_err(|e| ProxyError::Network(format!("Client write error: {}", e)))?;
3259 return Ok((None, resp.len() as u64));
3260 }
3261
3262 if let Err(e) = Self::ensure_conn(conns, &target, session, config, state).await {
3263 Self::record_backend_failure(state, &target, &e.to_string());
3264 return Err(e);
3265 }
3266 let backend = conns.get_mut(&target).expect("just ensured");
3267
3268 for name in reprepare {
3273 if backend.prepared.contains(name) {
3274 continue;
3275 }
3276 let Some(parse_bytes) = registry.get(name) else {
3277 continue; };
3279 match Self::reprepare_statement(&mut backend.stream, parse_bytes).await {
3280 Ok(()) => {
3281 backend.prepared.insert(name.clone());
3282 }
3283 Err(e) => {
3284 conns.remove(&target);
3285 return Err(e);
3286 }
3287 }
3288 }
3289
3290 let mut inject_parse_complete = false;
3297 let mut new_unnamed_sig: Option<bytes::Bytes> = None;
3298 if let Some((parse_msg, sig)) = unnamed.as_ref() {
3299 if backend.unnamed_sig.as_deref() == Some(&sig[..]) {
3300 inject_parse_complete = true;
3301 } else {
3302 if let Err(e) = backend
3303 .stream
3304 .write_all(parse_msg)
3305 .await
3306 .map_err(|e| ProxyError::Network(format!("Backend write error: {}", e)))
3307 {
3308 conns.remove(&target);
3309 return Err(e);
3310 }
3311 new_unnamed_sig = Some(sig.clone());
3312 }
3313 }
3314
3315 let batch_err = match tokio::time::timeout(
3316 Self::BACKEND_WRITE_TIMEOUT,
3317 backend.stream.write_all(batch),
3318 )
3319 .await
3320 {
3321 Ok(Ok(())) => None,
3322 Ok(Err(e)) => Some(format!("Backend write error: {}", e)),
3323 Err(_) => Some("Backend write timeout".to_string()),
3324 };
3325 if let Some(msg) = batch_err {
3326 let e = ProxyError::Network(msg);
3327 conns.remove(&target);
3328 Self::record_backend_failure(state, &target, &e.to_string());
3329 return Err(e);
3330 }
3331
3332 let mut injected: u64 = 0;
3335 if inject_parse_complete {
3336 if let Err(e) = client
3337 .write_all(&[b'1', 0, 0, 0, 4])
3338 .await
3339 .map_err(|e| ProxyError::Network(format!("Client write error: {}", e)))
3340 {
3341 conns.remove(&target);
3342 return Err(e);
3343 }
3344 injected = 5;
3345 }
3346
3347 let r = if wait_ready {
3348 Self::stream_until_ready(client, &mut backend.stream, session, state).await
3349 } else {
3350 Self::stream_flush(client, &mut backend.stream, session, state).await
3351 };
3352 match r {
3353 Ok(sent) => {
3354 #[cfg(feature = "circuit-breaker")]
3355 Self::circuit_record(state, &target, true, "");
3356 #[cfg(feature = "query-analytics")]
3357 if let Some(sql) = analytics_sql.as_deref() {
3358 Self::record_analytics(state, session, sql, &target, started.elapsed(), None)
3359 .await;
3360 }
3361 for name in defines {
3363 backend.prepared.insert(name.clone());
3364 }
3365 if let Some(sig) = new_unnamed_sig {
3367 backend.unnamed_sig = Some(sig);
3368 }
3369 Ok((Some(target), sent + injected))
3370 }
3371 Err(e) => {
3372 conns.remove(&target);
3373 Self::record_backend_failure(state, &target, &e.to_string());
3374 #[cfg(feature = "query-analytics")]
3375 if let Some(sql) = analytics_sql.as_deref() {
3376 Self::record_analytics(
3377 state,
3378 session,
3379 sql,
3380 &target,
3381 started.elapsed(),
3382 Some(e.to_string()),
3383 )
3384 .await;
3385 }
3386 Err(e)
3387 }
3388 }
3389 }
3390
/// Re-issues a saved `Parse` message on `backend`, so that a named prepared
/// statement exists on this connection before a batch that references it.
///
/// Sends `parse_bytes`, then a `Flush` (`'H'`) so the backend replies without
/// waiting for `Sync`, and reads exactly one reply frame. Each of the three
/// steps is bounded by `REPREPARE_TIMEOUT`.
///
/// # Errors
/// - `ProxyError::Network` on write/read failure or timeout;
/// - `ProxyError::Protocol` if the backend answers with `ErrorResponse`
///   (`'E'`) or any frame other than `ParseComplete` (`'1'`).
async fn reprepare_statement<S: AsyncReadExt + AsyncWriteExt + Unpin>(
    backend: &mut S,
    parse_bytes: &[u8],
) -> Result<()> {
    tokio::time::timeout(Self::REPREPARE_TIMEOUT, backend.write_all(parse_bytes))
        .await
        .map_err(|_| ProxyError::Network("re-prepare write timeout".to_string()))?
        .map_err(|e| ProxyError::Network(format!("re-prepare write error: {}", e)))?;
    // Flush message: type 'H' followed by Int32 length 4 (no body).
    tokio::time::timeout(
        Self::REPREPARE_TIMEOUT,
        backend.write_all(&[b'H', 0, 0, 0, 4]),
    )
    .await
    .map_err(|_| ProxyError::Network("re-prepare flush timeout".to_string()))?
    .map_err(|e| ProxyError::Network(format!("re-prepare flush error: {}", e)))?;
    let mtype =
        tokio::time::timeout(Self::REPREPARE_TIMEOUT, Self::read_one_frame_type(backend))
            .await
            .map_err(|_| ProxyError::Network("re-prepare read timeout".to_string()))??;
    match mtype {
        // ParseComplete
        b'1' => Ok(()),
        // ErrorResponse
        b'E' => Err(ProxyError::Protocol(
            "re-prepare rejected by backend".to_string(),
        )),
        other => Err(ProxyError::Protocol(format!(
            "unexpected re-prepare reply: {}",
            other as char
        ))),
    }
}
3427
3428 async fn read_one_frame_type<S: AsyncReadExt + Unpin>(backend: &mut S) -> Result<u8> {
3432 let mut header = [0u8; 5];
3433 backend
3434 .read_exact(&mut header)
3435 .await
3436 .map_err(|e| ProxyError::Network(format!("re-prepare read error: {}", e)))?;
3437 let len = u32::from_be_bytes([header[1], header[2], header[3], header[4]]) as usize;
3438 let body_len = len.saturating_sub(4);
3439 if body_len > 0 {
3440 let mut body = vec![0u8; body_len];
3441 backend
3442 .read_exact(&mut body)
3443 .await
3444 .map_err(|e| ProxyError::Network(format!("re-prepare body read error: {}", e)))?;
3445 }
3446 Ok(header[0])
3447 }
3448
/// Extracts the statement name from a `Parse` payload: the bytes before the
/// first NUL.
///
/// Returns `""` (the unnamed statement) when the name is empty, when no NUL
/// terminator is present, or when the bytes are not valid UTF-8.
fn parse_stmt_name(payload: &[u8]) -> &str {
    let name_bytes = match payload.iter().position(|&b| b == 0) {
        Some(nul) => &payload[..nul],
        None => &payload[..0],
    };
    std::str::from_utf8(name_bytes).unwrap_or_default()
}
3455
/// Returns the prepared-statement name referenced by a `Bind` payload.
///
/// The payload starts with the NUL-terminated portal name, followed by the
/// NUL-terminated statement name. Yields `None` when either terminator is
/// missing, the name is not UTF-8, or the name is empty (unnamed statement).
fn bind_stmt_ref(payload: &[u8]) -> Option<&str> {
    let mut fields = payload.splitn(3, |&b| b == 0);
    let _portal = fields.next()?;
    let stmt = fields.next()?;
    // A third piece exists only if the statement name was NUL-terminated.
    fields.next()?;
    std::str::from_utf8(stmt).ok().filter(|name| !name.is_empty())
}
3466
/// For a `Describe`/`Close`-style payload (kind byte + NUL-terminated name),
/// returns the name when the kind is `'S'` (prepared statement).
///
/// Yields `None` for portal (`'P'`) or other kinds, a missing terminator,
/// non-UTF-8 bytes, or an empty (unnamed) name.
fn stmt_kind_name(payload: &[u8]) -> Option<&str> {
    let rest = payload.strip_prefix(b"S")?;
    let nul = rest.iter().position(|&b| b == 0)?;
    std::str::from_utf8(&rest[..nul])
        .ok()
        .filter(|name| !name.is_empty())
}
3478
/// Relays backend frames to the client until the backend sends
/// `ReadyForQuery` (`'Z'`), or until it enters COPY-in / COPY-both mode
/// (`'G'` / `'W'`), at which point the caller must take over the data flow.
///
/// Only complete frames are forwarded. A partial trailing frame stays
/// buffered until more bytes arrive. On `'Z'`, the session's
/// `tx_state.in_transaction` is set from the frame's status byte.
///
/// NOTE(review): bytes already buffered after the terminating `'Z'`/`'G'`/`'W'`
/// frame are discarded when this function returns. That is harmless while at
/// most one `Sync` is outstanding per call; confirm no caller sends more.
/// A frame whose length field is < 4 is never consumed: the loop keeps
/// reading until the 30 s read timeout fires or the backend closes.
///
/// # Returns
/// Number of bytes written to the client.
///
/// # Errors
/// Client write failure or `CLIENT_WRITE_TIMEOUT`, backend read failure,
/// the 30 s backend read timeout, or the backend closing mid-response.
async fn stream_until_ready(
    client: &mut ClientStream,
    backend: &mut TcpStream,
    session: &Arc<ClientSession>,
    state: &Arc<ServerState>,
) -> Result<u64> {
    let _ = state;
    let mut buf = BytesMut::with_capacity(16384);
    let mut read_buf = vec![0u8; 16384];
    let mut sent: u64 = 0;

    loop {
        // Scan complete frames: type byte + Int32 length (which counts itself).
        let mut consumed = 0usize;
        let mut ready_status: Option<u8> = None;
        let mut yield_for_copy = false;
        loop {
            let rem = &buf[consumed..];
            if rem.len() < 5 {
                break;
            }
            let len = u32::from_be_bytes([rem[1], rem[2], rem[3], rem[4]]) as usize;
            if len < 4 || rem.len() < len + 1 {
                // Incomplete (or malformed) frame: wait for more bytes.
                break;
            }
            let frame_total = len + 1;
            let mtype = rem[0];
            consumed += frame_total;
            if mtype == b'Z' {
                // Status byte follows the length; assume idle if absent.
                ready_status = Some(if frame_total >= 6 { rem[5] } else { b'I' });
                break;
            }
            if mtype == b'G' || mtype == b'W' {
                // CopyInResponse / CopyBothResponse: the client must now send
                // data, so hand control back to the caller.
                yield_for_copy = true;
                break;
            }
        }

        if consumed > 0 {
            tokio::time::timeout(
                Self::CLIENT_WRITE_TIMEOUT,
                client.write_all(&buf[..consumed]),
            )
            .await
            .map_err(|_| ProxyError::Network("Client write timeout".to_string()))?
            .map_err(|e| ProxyError::Network(format!("Client write error: {}", e)))?;
            sent += consumed as u64;
            let _ = buf.split_to(consumed);
        }

        if let Some(status) = ready_status {
            let st = TransactionStatus::from_byte(status);
            let mut tx = session.tx_state.write().await;
            tx.in_transaction = st != TransactionStatus::Idle;
            return Ok(sent);
        }
        if yield_for_copy {
            return Ok(sent);
        }

        let n = tokio::time::timeout(Duration::from_secs(30), backend.read(&mut read_buf))
            .await
            .map_err(|_| ProxyError::Network("Backend read timeout".to_string()))?
            .map_err(|e| ProxyError::Network(format!("Backend read error: {}", e)))?;
        if n == 0 {
            return Err(ProxyError::Connection(
                "Backend closed mid-response".to_string(),
            ));
        }
        buf.extend_from_slice(&read_buf[..n]);
    }
}
3561
#[cfg(feature = "query-cache")]
/// Same relay as `stream_until_ready`, but also keeps a copy of every byte
/// sent to the client so the response can be stored in the query cache.
///
/// # Returns
/// `(sent, captured, cacheable, row_count)`:
/// - `sent`: bytes written to the client;
/// - `captured`: those same bytes;
/// - `cacheable`: `true` only if no `ErrorResponse` (`'E'`) was seen and the
///   final `ReadyForQuery` status is `'I'` (idle, outside a transaction).
///   Always `false` when relaying stops for COPY-in/COPY-both;
/// - `row_count`: the trailing number of the last `CommandComplete` tag
///   (e.g. `SELECT 5` → 5), or 0 if none parsed.
///
/// NOTE(review): `captured` grows with the full response size, with no cap.
/// Large result sets are held in memory until the caller decides whether to
/// cache them.
///
/// # Errors
/// As for `stream_until_ready`.
async fn stream_until_ready_capture(
    client: &mut ClientStream,
    backend: &mut TcpStream,
    session: &Arc<ClientSession>,
) -> Result<(u64, Vec<u8>, bool, usize)> {
    let mut buf = BytesMut::with_capacity(16384);
    let mut read_buf = vec![0u8; 16384];
    let mut sent: u64 = 0;
    let mut captured: Vec<u8> = Vec::with_capacity(4096);
    let mut had_error = false;
    let mut row_count: usize = 0;

    loop {
        let mut consumed = 0usize;
        let mut ready_status: Option<u8> = None;
        let mut yield_for_copy = false;
        loop {
            let rem = &buf[consumed..];
            if rem.len() < 5 {
                break;
            }
            let len = u32::from_be_bytes([rem[1], rem[2], rem[3], rem[4]]) as usize;
            if len < 4 || rem.len() < len + 1 {
                break;
            }
            let frame_total = len + 1;
            let mtype = rem[0];
            if mtype == b'E' {
                had_error = true;
            }
            if mtype == b'C' {
                // CommandComplete body is a NUL-terminated tag such as
                // "SELECT 5" or "INSERT 0 3"; its last token is the row count.
                if let Some(tag) = rem.get(5..frame_total) {
                    if let Some(end) = tag.iter().position(|&b| b == 0) {
                        if let Ok(s) = std::str::from_utf8(&tag[..end]) {
                            if let Some(n) =
                                s.rsplit(' ').next().and_then(|x| x.parse::<usize>().ok())
                            {
                                row_count = n;
                            }
                        }
                    }
                }
            }
            consumed += frame_total;
            if mtype == b'Z' {
                ready_status = Some(if frame_total >= 6 { rem[5] } else { b'I' });
                break;
            }
            if mtype == b'G' || mtype == b'W' {
                yield_for_copy = true;
                break;
            }
        }

        if consumed > 0 {
            tokio::time::timeout(
                Self::CLIENT_WRITE_TIMEOUT,
                client.write_all(&buf[..consumed]),
            )
            .await
            .map_err(|_| ProxyError::Network("Client write timeout".to_string()))?
            .map_err(|e| ProxyError::Network(format!("Client write error: {}", e)))?;
            captured.extend_from_slice(&buf[..consumed]);
            sent += consumed as u64;
            let _ = buf.split_to(consumed);
        }

        if let Some(status) = ready_status {
            let st = TransactionStatus::from_byte(status);
            let mut tx = session.tx_state.write().await;
            tx.in_transaction = st != TransactionStatus::Idle;
            let cacheable = !had_error && status == b'I';
            return Ok((sent, captured, cacheable, row_count));
        }
        if yield_for_copy {
            return Ok((sent, captured, false, row_count));
        }

        let n = tokio::time::timeout(Duration::from_secs(30), backend.read(&mut read_buf))
            .await
            .map_err(|_| ProxyError::Network("Backend read timeout".to_string()))?
            .map_err(|e| ProxyError::Network(format!("Backend read error: {}", e)))?;
        if n == 0 {
            return Err(ProxyError::Connection(
                "Backend closed mid-response".to_string(),
            ));
        }
        buf.extend_from_slice(&read_buf[..n]);
    }
}
3659
3660 async fn stream_flush(
3666 client: &mut ClientStream,
3667 backend: &mut TcpStream,
3668 session: &Arc<ClientSession>,
3669 state: &Arc<ServerState>,
3670 ) -> Result<u64> {
3671 let _ = (session, state);
3672 let mut read_buf = vec![0u8; 16384];
3673 let mut sent: u64 = 0;
3674 loop {
3675 match tokio::time::timeout(Duration::from_millis(200), backend.read(&mut read_buf))
3676 .await
3677 {
3678 Ok(Ok(0)) => {
3679 return Err(ProxyError::Connection(
3680 "Backend closed mid-flush".to_string(),
3681 ))
3682 }
3683 Ok(Ok(n)) => {
3684 client
3685 .write_all(&read_buf[..n])
3686 .await
3687 .map_err(|e| ProxyError::Network(format!("Client write error: {}", e)))?;
3688 sent += n as u64;
3689 }
3690 Ok(Err(e)) => {
3691 return Err(ProxyError::Network(format!("Backend read error: {}", e)))
3692 }
3693 Err(_) => return Ok(sent), }
3695 }
3696 }
3697
3698 fn is_write_message(msg: &Message) -> bool {
3700 match msg.msg_type {
3701 MessageType::Query => {
3702 crate::protocol::query_text(&msg.payload)
3706 .map(Self::is_write_query)
3707 .unwrap_or(false)
3708 }
3709 MessageType::Parse => {
3710 msg.payload
3713 .iter()
3714 .position(|&b| b == 0)
3715 .and_then(|end| crate::protocol::query_text(&msg.payload[end + 1..]))
3716 .map(Self::is_write_query)
3717 .unwrap_or(false)
3718 }
3719 _ => false,
3721 }
3722 }
3723
/// Returns `true` when `sql` must be routed to the primary.
///
/// A statement is treated as a write when its first word is:
/// - a command that changes data, schema, locks or server state (DML/DDL,
///   `GRANT`/`REVOKE`, maintenance commands, `MERGE`, `CALL`, `DO`, `LOCK`,
///   `REFRESH`, `COMMENT`, `ANALYZE`);
/// - a transaction-control keyword, so the whole transaction stays on one
///   node;
/// - `SET`, except `SET TRANSACTION READ ONLY`.
///
/// Three read-looking shapes also go to the primary, because a hot standby
/// rejects them:
/// - a `WITH` query whose text contains `INSERT`/`UPDATE`/`DELETE`/`MERGE`
///   (a data-modifying CTE);
/// - `SELECT … INTO` (creates a table);
/// - a `SELECT` with a row-locking clause (`FOR UPDATE`, `FOR SHARE`,
///   `FOR NO KEY UPDATE`, `FOR KEY SHARE`).
///
/// Matching is ASCII case-insensitive on whole words. It is a heuristic:
/// - one of those words inside a string literal or quoted identifier sends a
///   read to the primary, which is safe but costs standby capacity;
/// - locking clauses split across lines (`FOR\nUPDATE`) are not detected.
fn is_write_query(sql: &str) -> bool {
    /// Bytes that may appear inside an unquoted SQL identifier/keyword.
    fn is_word_byte(b: u8) -> bool {
        b.is_ascii_alphanumeric() || b == b'_'
    }

    /// `s` starts with `word` (ASCII case-insensitive) as a whole word.
    fn starts_with_word(s: &str, word: &str) -> bool {
        let (s, w) = (s.as_bytes(), word.as_bytes());
        s.len() >= w.len()
            && s[..w.len()].eq_ignore_ascii_case(w)
            && s.get(w.len()).map_or(true, |&b| !is_word_byte(b))
    }

    /// `s` contains `word` (ASCII case-insensitive) as a whole word/phrase.
    fn contains_word(s: &str, word: &str) -> bool {
        let (s, w) = (s.as_bytes(), word.as_bytes());
        if w.is_empty() || s.len() < w.len() {
            return false;
        }
        (0..=s.len() - w.len()).any(|i| {
            s[i..i + w.len()].eq_ignore_ascii_case(w)
                && (i == 0 || !is_word_byte(s[i - 1]))
                && s.get(i + w.len()).map_or(true, |&b| !is_word_byte(b))
        })
    }

    const WRITE_LEADERS: &[&str] = &[
        // Data, schema and privilege changes.
        "INSERT", "UPDATE", "DELETE", "MERGE", "CREATE", "DROP", "ALTER", "TRUNCATE",
        "GRANT", "REVOKE", "COMMENT",
        // Maintenance / server-side execution that a standby refuses.
        "VACUUM", "REINDEX", "CLUSTER", "ANALYZE", "ANALYSE", "REFRESH", "LOCK", "CALL", "DO",
        // Transaction control: keep the transaction on the primary.
        "BEGIN", "START", "COMMIT", "END", "ROLLBACK", "ABORT", "SAVEPOINT", "RELEASE",
    ];
    const CTE_WRITE_VERBS: &[&str] = &["INSERT", "UPDATE", "DELETE", "MERGE"];
    const ROW_LOCK_CLAUSES: &[&str] =
        &["FOR UPDATE", "FOR NO KEY UPDATE", "FOR SHARE", "FOR KEY SHARE"];

    let trimmed = sql.trim();

    if WRITE_LEADERS.iter().any(|kw| starts_with_word(trimmed, kw)) {
        return true;
    }
    if starts_with_word(trimmed, "SET") {
        return !starts_with_word(trimmed, "SET TRANSACTION READ ONLY");
    }
    if starts_with_word(trimmed, "WITH") {
        return CTE_WRITE_VERBS.iter().any(|v| contains_word(trimmed, v));
    }
    if starts_with_word(trimmed, "SELECT") {
        return contains_word(trimmed, "INTO")
            || ROW_LOCK_CLAUSES.iter().any(|c| contains_word(trimmed, c));
    }
    false
}
3764
/// Builds the rate-limiter bucket key for this session according to
/// `config.rate_limit.key_by`.
///
/// Database/user keys come from the session's `database` / `user`
/// variables; a missing variable becomes the empty string.
#[cfg(feature = "rate-limiting")]
async fn rate_limit_key(
    session: &Arc<ClientSession>,
    config: &ProxyConfig,
) -> crate::rate_limit::LimiterKey {
    use crate::config::RateLimitKeyBy;
    use crate::rate_limit::LimiterKey;
    match config.rate_limit.key_by {
        RateLimitKeyBy::Global => LimiterKey::Global,
        RateLimitKeyBy::ClientIp => LimiterKey::ClientIp(session.client_addr.ip()),
        RateLimitKeyBy::Database => {
            let database = session.variables.read().await.get("database").cloned();
            LimiterKey::Database(database.unwrap_or_default())
        }
        RateLimitKeyBy::User => {
            let user = session.variables.read().await.get("user").cloned();
            LimiterKey::User(user.unwrap_or_default())
        }
    }
}
3787
/// Consults the configured rate limiter for this session.
///
/// Returns `None` when the request may proceed. That covers allowed,
/// warned, and throttled/queued requests; the last two are first delayed by
/// the limiter's suggested duration, capped at 5 s. Returns an encoded
/// `ErrorResponse` (SQLSTATE `53400`) when the request is denied. With no
/// limiter configured, always returns `None`.
#[cfg(feature = "rate-limiting")]
async fn rate_limit_check(
    session: &Arc<ClientSession>,
    state: &Arc<ServerState>,
    config: &ProxyConfig,
) -> Option<Vec<u8>> {
    use crate::rate_limit::RateLimitResult;
    let limiter = state.rate_limiter.as_ref()?;
    let key = Self::rate_limit_key(session, config).await;
    let exc = match limiter.check(&key, 1) {
        RateLimitResult::Allowed => return None,
        RateLimitResult::Warned(msg) => {
            tracing::warn!(key = %key, reason = %msg, "rate limit warning");
            return None;
        }
        RateLimitResult::Throttled(d) | RateLimitResult::Queued(d) => {
            // Back-pressure: slow the client down, then let it through.
            tokio::time::sleep(d.min(Duration::from_secs(5))).await;
            return None;
        }
        RateLimitResult::Denied(exc) => exc,
    };
    tracing::info!(key = %key, "rate limit exceeded");
    let detail = format!(
        "rate limit exceeded: {} (retry after {}ms)",
        exc.message,
        exc.retry_after.as_millis()
    );
    Some(Self::create_error_response("53400", &detail))
}
3825
/// Decides whether an error message should count against the backend node.
///
/// Messages mentioning `Client` (client-side I/O problems) and backend read
/// timeouts (possibly just a slow query) are not treated as node faults.
/// Everything else is.
fn is_backend_fault(err: &str) -> bool {
    const NOT_NODE_FAULTS: [&str; 2] = ["Client", "Backend read timeout"];
    NOT_NODE_FAULTS.iter().all(|&marker| !err.contains(marker))
}
3854
/// Marks `addr` unhealthy right after an in-band (query-path) failure, so
/// routing can fail over without waiting for the next health check.
///
/// Errors that `is_backend_fault` rejects (client-side errors, backend read
/// timeouts) are ignored. The health map is updated copy-on-write: the
/// current snapshot is cloned, modified and swapped in with
/// `state.health.store`. Only a healthy → unhealthy transition changes the
/// map; `failure_count` and `last_error` are updated on that transition only.
fn note_backend_failure(state: &Arc<ServerState>, addr: &str, err: &str) {
    if !Self::is_backend_fault(err) {
        return;
    }
    // Held until the end of the function. Presumably this serializes writers
    // of `state.health`, so concurrent load/clone/store cycles do not lose
    // updates; confirm against the other `health_write` users.
    let _writers = state.health_write.lock();
    let snapshot = state.health.load_full();
    if snapshot.get(addr).map(|h| h.healthy).unwrap_or(false) {
        let mut next = (*snapshot).clone();
        if let Some(nh) = next.get_mut(addr) {
            nh.healthy = false;
            nh.failure_count = nh.failure_count.saturating_add(1);
            nh.last_error = Some(format!("in-band failure: {}", err));
            tracing::warn!(
                node = %addr,
                error = %err,
                "in-band failure — node marked unhealthy for fast failover"
            );
        }
        state.health.store(Arc::new(next));
    }
}
3889
3890 fn record_backend_failure(state: &Arc<ServerState>, node: &str, err: &str) {
3896 Self::note_backend_failure(state, node, err);
3897 #[cfg(feature = "circuit-breaker")]
3898 if Self::is_backend_fault(err) {
3899 Self::circuit_record(state, node, false, err);
3900 }
3901 }
3902
/// Reports whether the circuit breaker for `node` is currently `Open`.
///
/// Always `false` when no circuit breaker is configured.
#[cfg(feature = "circuit-breaker")]
fn circuit_is_open(state: &Arc<ServerState>, node: &str) -> bool {
    match state.circuit_breaker.as_ref() {
        Some(breakers) => {
            breakers.get_breaker(node).get_state()
                == crate::circuit_breaker::CircuitState::Open
        }
        None => false,
    }
}
3915
/// Records the outcome of a request to `node` on its circuit breaker.
///
/// `err` is passed along only for failures. This does nothing when no
/// circuit breaker is configured.
#[cfg(feature = "circuit-breaker")]
fn circuit_record(state: &Arc<ServerState>, node: &str, success: bool, err: &str) {
    let Some(breakers) = state.circuit_breaker.as_ref() else {
        return;
    };
    let breaker = breakers.get_breaker(node);
    if !success {
        breaker.record_failure(err);
        return;
    }
    breaker.record_success();
}
3928
3929 #[cfg(feature = "circuit-breaker")]
3933 fn circuit_fast_fail(state: &Arc<ServerState>, node: &str) -> Option<Vec<u8>> {
3934 if Self::circuit_is_open(state, node) {
3935 tracing::info!(node = %node, "circuit open — fast-failing");
3936 Some(Self::create_error_response(
3937 "08006",
3938 &format!("circuit open for node {node}: backend temporarily unavailable"),
3939 ))
3940 } else {
3941 None
3942 }
3943 }
3944
3945 #[cfg(feature = "lag-routing")]
3948 fn ryw_pins_primary(last_write: Option<std::time::Instant>, window_ms: u64) -> bool {
3949 window_ms > 0
3950 && last_write
3951 .map(|t| t.elapsed() < Duration::from_millis(window_ms))
3952 .unwrap_or(false)
3953 }
3954
3955 #[cfg(feature = "lag-routing")]
3959 fn lag_excludes_standby(lag_bytes: Option<u64>, max_lag_bytes: u64) -> bool {
3960 max_lag_bytes > 0 && lag_bytes.map(|l| l > max_lag_bytes).unwrap_or(false)
3961 }
3962
3963 #[cfg(feature = "query-cache")]
3966 fn is_cacheable_read_sql(sql: &str) -> bool {
3967 use crate::protocol::{contains_ci, starts_with_ci};
3968 let t = sql.trim_start();
3969 if !starts_with_ci(t, "SELECT") {
3970 return false;
3971 }
3972 if contains_ci(t, "FOR UPDATE") || contains_ci(t, "FOR SHARE") {
3973 return false;
3974 }
3975 const VOLATILE: [&str; 10] = [
3977 "now(",
3978 "current_timestamp",
3979 "current_date",
3980 "current_time",
3981 "clock_timestamp",
3982 "statement_timestamp",
3983 "random(",
3984 "nextval(",
3985 "uuid_generate",
3986 "gen_random_uuid",
3987 ];
3988 !VOLATILE.iter().any(|v| contains_ci(t, v))
3989 }
3990
3991 #[cfg(feature = "query-cache")]
3995 async fn cacheable_read_ctx(
3996 session: &Arc<ClientSession>,
3997 sql: &str,
3998 ) -> Option<crate::cache::CacheContext> {
3999 if !Self::is_cacheable_read_sql(sql) {
4000 return None;
4001 }
4002 if session.tx_state.read().await.in_transaction {
4004 return None;
4005 }
4006 let (user, database) = {
4007 let vars = session.variables.read().await;
4008 (
4009 vars.get("user").cloned(),
4010 vars.get("database")
4011 .cloned()
4012 .unwrap_or_else(|| "default".to_string()),
4013 )
4014 };
4015 Some(crate::cache::CacheContext {
4016 database,
4017 user,
4018 branch: None,
4019 connection_id: Some(session.id.as_u64_pair().0),
4020 })
4021 }
4022
4023 #[cfg(feature = "multi-tenancy")]
4027 async fn tenant_request_ctx(
4028 session: &Arc<ClientSession>,
4029 ) -> crate::multi_tenancy::RequestContext {
4030 let vars = session.variables.read().await;
4031 crate::multi_tenancy::RequestContext {
4032 headers: vars.clone(),
4033 username: vars.get("user").cloned(),
4034 database: vars.get("database").cloned(),
4035 auth_token: None,
4036 sql_context: HashMap::new(),
4037 client_ip: Some(session.client_addr.ip().to_string()),
4038 connection_id: Some(session.id.as_u64_pair().0),
4039 }
4040 }
4041
4042 #[cfg(feature = "ha-tr")]
4047 async fn journal_write(state: &Arc<ServerState>, session: &Arc<ClientSession>, sql: &str) {
4048 let tx_id = uuid::Uuid::new_v4();
4049 let j = &state.transaction_journal;
4050 if j.begin_transaction(tx_id, session.id, crate::NodeId::new(), 0)
4051 .await
4052 .is_ok()
4053 {
4054 let _ = j
4055 .log_statement(tx_id, sql.to_string(), Vec::new(), None, None, 0)
4056 .await;
4057 }
4058 }
4059
4060 #[cfg(feature = "query-analytics")]
4063 async fn record_analytics(
4064 state: &Arc<ServerState>,
4065 session: &Arc<ClientSession>,
4066 sql: &str,
4067 node: &str,
4068 duration: Duration,
4069 error: Option<String>,
4070 ) {
4071 let Some(analytics) = state.analytics.as_ref() else {
4072 return;
4073 };
4074 let (user, database) = {
4075 let vars = session.variables.read().await;
4076 (
4077 vars.get("user").cloned().unwrap_or_default(),
4078 vars.get("database").cloned().unwrap_or_default(),
4079 )
4080 };
4081 let mut exec = crate::analytics::QueryExecution::new(sql, duration);
4082 exec.user = user;
4083 exec.database = database;
4084 exec.client_ip = session.client_addr.ip().to_string();
4085 exec.node = node.to_string();
4086 exec.session_id = Some(session.id.to_string());
4087 exec.error = error;
4088 analytics.record(exec);
4089 }
4090
4091 async fn select_primary_with_timeout(
4093 session: &Arc<ClientSession>,
4094 state: &Arc<ServerState>,
4095 config: &ProxyConfig,
4096 ) -> Result<String> {
4097 let timeout = config.write_timeout();
4098 let start = std::time::Instant::now();
4099 let check_interval = Duration::from_millis(100);
4102
4103 loop {
4104 let health = state.health.load_full();
4106 let primary = config
4107 .nodes
4108 .iter()
4109 .find(|n| n.role == NodeRole::Primary && n.enabled);
4110
4111 if let Some(primary_node) = primary {
4112 if let Some(node_health) = health.get(&primary_node.address()) {
4113 if node_health.healthy {
4114 let mut current = session.current_node.write().await;
4116 *current = Some(primary_node.address());
4117 return Ok(primary_node.address());
4118 }
4119 }
4120 }
4121 drop(health);
4122
4123 if start.elapsed() >= timeout {
4125 state.metrics.failovers.fetch_add(1, Ordering::Relaxed);
4126 return Err(ProxyError::NoHealthyNodes);
4127 }
4128
4129 tracing::warn!(
4130 "Primary unavailable, waiting for failover... ({:.1}s elapsed, {:.1}s timeout)",
4131 start.elapsed().as_secs_f64(),
4132 timeout.as_secs_f64()
4133 );
4134
4135 tokio::time::sleep(check_interval).await;
4137 }
4138 }
4139
4140 async fn select_read_node(
4142 session: &Arc<ClientSession>,
4143 state: &Arc<ServerState>,
4144 config: &ProxyConfig,
4145 ) -> Result<String> {
4146 {
4148 let tx_state = session.tx_state.read().await;
4149 if tx_state.in_transaction {
4150 if let Some(node) = session.current_node.read().await.clone() {
4151 return Ok(node);
4152 }
4153 }
4154 }
4155
4156 let health = state.health.load_full();
4158 let healthy_standbys: Vec<&NodeConfig> = config
4159 .nodes
4160 .iter()
4161 .filter(|n| {
4162 let base = n.enabled
4163 && (n.role == NodeRole::Standby || n.role == NodeRole::ReadReplica)
4164 && health.get(&n.address()).map(|h| h.healthy).unwrap_or(false);
4165 #[cfg(feature = "circuit-breaker")]
4167 let base = base && !Self::circuit_is_open(state, &n.address());
4168 #[cfg(feature = "lag-routing")]
4170 let base = base
4171 && !Self::lag_excludes_standby(
4172 health
4173 .get(&n.address())
4174 .and_then(|h| h.replication_lag_bytes),
4175 config.lag_routing.max_lag_bytes,
4176 );
4177 base
4178 })
4179 .collect();
4180
4181 if !healthy_standbys.is_empty() {
4182 let ticket = state.lb_state.rr_counter.fetch_add(1, Ordering::Relaxed);
4184 let index = ticket as usize % healthy_standbys.len();
4185 let node_addr = healthy_standbys[index].address();
4186
4187 let mut current = session.current_node.write().await;
4188 *current = Some(node_addr.clone());
4189 return Ok(node_addr);
4190 }
4191
4192 Self::select_node(session, state, config).await
4194 }
4195
4196 async fn complete_backend_auth(backend: &mut TcpStream) -> Result<()> {
4199 let codec = ProtocolCodec::new();
4200 let mut buffer = BytesMut::with_capacity(4096);
4201 let mut read_buf = vec![0u8; 4096];
4202 let timeout = Duration::from_secs(10);
4203 let start = std::time::Instant::now();
4204
4205 loop {
4206 if start.elapsed() > timeout {
4207 return Err(ProxyError::Auth(
4208 "Backend authentication timeout".to_string(),
4209 ));
4210 }
4211
4212 let n = tokio::time::timeout(Duration::from_secs(5), backend.read(&mut read_buf))
4213 .await
4214 .map_err(|_| ProxyError::Auth("Read timeout during backend auth".to_string()))?
4215 .map_err(|e| ProxyError::Network(format!("Backend auth read error: {}", e)))?;
4216
4217 if n == 0 {
4218 return Err(ProxyError::Connection(
4219 "Backend closed during auth".to_string(),
4220 ));
4221 }
4222
4223 buffer.extend_from_slice(&read_buf[..n]);
4224
4225 while let Some(msg) = codec.decode_message(&mut buffer)? {
4228 match msg.msg_type {
4229 MessageType::ReadyForQuery => {
4230 return Ok(());
4232 }
4233 MessageType::ErrorResponse => {
4234 let err = ErrorResponse::parse(msg.payload)
4235 .map(|e| e.message().unwrap_or("Unknown error").to_string())
4236 .unwrap_or_else(|_| "Parse error".to_string());
4237 return Err(ProxyError::Auth(err));
4238 }
4239 _ => {
4240 }
4242 }
4243 }
4244 }
4245 }
4246
4247 fn create_error_response(code: &str, message: &str) -> Vec<u8> {
4249 let mut fields = HashMap::new();
4250 fields.insert('S', "ERROR".to_string());
4251 fields.insert('V', "ERROR".to_string());
4252 fields.insert('C', code.to_string());
4253 fields.insert('M', message.to_string());
4254
4255 let err = ErrorResponse { fields };
4256 err.encode().encode().to_vec()
4257 }
4258
4259 fn create_ready_for_query(status: u8) -> Vec<u8> {
4262 let mut payload = BytesMut::with_capacity(1);
4263 payload.put_u8(status);
4264 Message::new(MessageType::ReadyForQuery, payload)
4265 .encode()
4266 .to_vec()
4267 }
4268
4269 #[cfg(feature = "wasm-plugins")]
4307 fn synthesise_cached_response(bytes: &[u8]) -> Result<Vec<u8>> {
4308 use serde::Deserialize;
4309
4310 #[derive(Deserialize)]
4311 struct CachedPayload {
4312 columns: Vec<ColumnDef>,
4313 rows: Vec<Vec<Option<String>>>,
4314 }
4315
4316 #[derive(Deserialize)]
4317 struct ColumnDef {
4318 name: String,
4319 #[serde(default = "default_text_oid")]
4320 oid: u32,
4321 }
4322
4323 fn default_text_oid() -> u32 {
4324 25 }
4326
4327 let payload: CachedPayload = serde_json::from_slice(bytes)
4328 .map_err(|e| ProxyError::Protocol(format!("invalid cached payload JSON: {}", e)))?;
4329
4330 if payload.columns.is_empty() {
4331 return Err(ProxyError::Protocol(
4332 "cached payload must declare at least one column".to_string(),
4333 ));
4334 }
4335
4336 let mut reply = Vec::new();
4337
4338 let mut rd = BytesMut::new();
4340 rd.put_u16(payload.columns.len() as u16);
4341 for col in &payload.columns {
4342 rd.extend_from_slice(col.name.as_bytes());
4343 rd.put_u8(0); rd.put_i32(0); rd.put_i16(0); rd.put_u32(col.oid);
4347 rd.put_i16(-1); rd.put_i32(-1); rd.put_i16(0); }
4351 reply.extend_from_slice(&Message::new(MessageType::RowDescription, rd).encode());
4352
4353 let column_count = payload.columns.len();
4355 for row in &payload.rows {
4356 if row.len() != column_count {
4357 return Err(ProxyError::Protocol(format!(
4358 "cached row has {} values but {} columns are declared",
4359 row.len(),
4360 column_count
4361 )));
4362 }
4363 let mut dr = BytesMut::new();
4364 dr.put_u16(row.len() as u16);
4365 for value in row {
4366 match value {
4367 Some(s) => {
4368 dr.put_i32(s.len() as i32);
4369 dr.extend_from_slice(s.as_bytes());
4370 }
4371 None => {
4372 dr.put_i32(-1); }
4374 }
4375 }
4376 reply.extend_from_slice(&Message::new(MessageType::DataRow, dr).encode());
4377 }
4378
4379 let tag = format!("SELECT {}", payload.rows.len());
4381 let mut cc = BytesMut::new();
4382 cc.extend_from_slice(tag.as_bytes());
4383 cc.put_u8(0);
4384 reply.extend_from_slice(&Message::new(MessageType::CommandComplete, cc).encode());
4385
4386 reply.extend_from_slice(&Self::create_ready_for_query(b'I'));
4388
4389 Ok(reply)
4390 }
4391
4392 fn apply_pre_query_hook(
4402 msg: Message,
4403 state: &Arc<ServerState>,
4404 session: &Arc<ClientSession>,
4405 ) -> (Message, PreQueryAction) {
4406 #[cfg(feature = "wasm-plugins")]
4407 {
4408 let pm = match state.plugin_manager.as_ref() {
4409 Some(pm) => pm,
4410 None => return (msg, PreQueryAction::Forward),
4411 };
4412
4413 if msg.msg_type != MessageType::Query {
4414 return (msg, PreQueryAction::Forward);
4415 }
4416
4417 if !pm.has_hook(HookType::PreQuery) {
4420 return (msg, PreQueryAction::Forward);
4421 }
4422
4423 let query_msg = match QueryMessage::parse(msg.payload.clone()) {
4424 Ok(q) => q,
4425 Err(_) => return (msg, PreQueryAction::Forward),
4426 };
4427
4428 let ctx = Self::build_query_context(&query_msg.query, session);
4429
4430 match pm.execute_pre_query(&ctx) {
4431 PreQueryResult::Continue => (msg, PreQueryAction::Forward),
4432 PreQueryResult::Block(reason) => (msg, PreQueryAction::Block(reason)),
4433 PreQueryResult::Rewrite(new_sql) => {
4434 let rewritten = QueryMessage { query: new_sql }.encode();
4435 (rewritten, PreQueryAction::Forward)
4436 }
4437 PreQueryResult::Cached(bytes) => (msg, PreQueryAction::Cached(bytes)),
4438 }
4439 }
4440 #[cfg(not(feature = "wasm-plugins"))]
4441 {
4442 let _ = (state, session);
4443 (msg, PreQueryAction::Forward)
4444 }
4445 }
4446
4447 #[cfg(feature = "anomaly-detection")]
4453 fn record_anomaly_observation(
4454 msg: &Message,
4455 state: &Arc<ServerState>,
4456 session: &Arc<ClientSession>,
4457 ) {
4458 if msg.msg_type != MessageType::Query {
4459 return;
4460 }
4461 if let Some(query) = crate::protocol::query_text(&msg.payload) {
4464 Self::record_anomaly_sql(query, state, session);
4465 }
4466 }
4467
4468 #[cfg(feature = "anomaly-detection")]
4472 fn record_anomaly_sql(query: &str, state: &Arc<ServerState>, session: &Arc<ClientSession>) {
4473 let tenant = match session.variables.try_read() {
4480 Ok(vars) => vars
4481 .get("tenant_id")
4482 .or_else(|| vars.get("user"))
4483 .cloned()
4484 .unwrap_or_else(|| session.client_addr.ip().to_string()),
4485 Err(_) => session.client_addr.ip().to_string(),
4486 };
4487 let fingerprint = anomaly_fingerprint(query);
4488 let obs = crate::anomaly::QueryObservation {
4489 tenant,
4490 fingerprint,
4491 sql: query.to_string(),
4492 timestamp: std::time::Instant::now(),
4493 };
4494 for ev in state.anomaly_detector.record_query(&obs) {
4495 tracing::warn!(anomaly = ?ev, "anomaly detected");
4496 }
4497 }
4498
4499 async fn send_block_response(
4503 stream: &mut ClientStream,
4504 reason: &str,
4505 state: &Arc<ServerState>,
4506 ) -> Result<()> {
4507 let err =
4508 Self::create_error_response("42000", &format!("Query blocked by plugin: {}", reason));
4509 stream
4510 .write_all(&err)
4511 .await
4512 .map_err(|e| ProxyError::Network(format!("Write error: {}", e)))?;
4513 let rfq = Self::create_ready_for_query(b'I');
4514 stream
4515 .write_all(&rfq)
4516 .await
4517 .map_err(|e| ProxyError::Network(format!("Write error: {}", e)))?;
4518 state
4519 .metrics
4520 .bytes_sent
4521 .fetch_add((err.len() + rfq.len()) as u64, Ordering::Relaxed);
4522 Ok(())
4523 }
4524
4525 #[cfg(feature = "wasm-plugins")]
4531 fn build_query_context(query: &str, session: &Arc<ClientSession>) -> QueryContext {
4532 let is_read_only = !Self::is_write_query(query);
4533 let hook_context = HookContext {
4534 client_id: Some(session.id.to_string()),
4535 ..HookContext::default()
4536 };
4537 QueryContext {
4538 query: query.to_string(),
4539 normalized: query.to_string(),
4540 tables: Vec::new(),
4541 is_read_only,
4542 hook_context,
4543 }
4544 }
4545
4546 async fn apply_authenticate_hook(
4567 _params: &HashMap<String, String>,
4568 _session: &Arc<ClientSession>,
4569 _state: &Arc<ServerState>,
4570 ) -> Result<()> {
4571 #[cfg(feature = "wasm-plugins")]
4572 {
4573 let pm = match _state.plugin_manager.as_ref() {
4574 Some(pm) => pm,
4575 None => return Ok(()),
4576 };
4577
4578 let request = PluginAuthRequest {
4579 headers: HashMap::new(),
4580 username: _params.get("user").cloned(),
4581 password: None,
4582 client_ip: _session.client_addr.ip().to_string(),
4583 database: _params.get("database").cloned(),
4584 };
4585
4586 match pm.execute_authenticate(&request) {
4587 AuthResult::Defer => Ok(()),
4588 AuthResult::Success(identity) => {
4589 tracing::debug!(
4590 user = %identity.username,
4591 roles = ?identity.roles,
4592 "plugin authenticated user"
4593 );
4594 *_session.plugin_identity.write().await = Some(identity);
4595 Ok(())
4596 }
4597 AuthResult::Denied(reason) => {
4598 tracing::info!(
4599 reason = %reason,
4600 client = %_session.client_addr,
4601 user = ?_params.get("user"),
4602 "plugin denied authentication"
4603 );
4604 Err(ProxyError::Auth(format!(
4605 "authentication denied by plugin: {}",
4606 reason
4607 )))
4608 }
4609 }
4610 }
4611 #[cfg(not(feature = "wasm-plugins"))]
4612 {
4613 Ok(())
4614 }
4615 }
4616
4617 fn apply_route_hook(
4620 msg: &Message,
4621 state: &Arc<ServerState>,
4622 session: &Arc<ClientSession>,
4623 ) -> RouteOverride {
4624 #[cfg(feature = "wasm-plugins")]
4625 {
4626 let pm = match state.plugin_manager.as_ref() {
4627 Some(pm) => pm,
4628 None => return RouteOverride::None,
4629 };
4630 if msg.msg_type != MessageType::Query {
4631 return RouteOverride::None;
4632 }
4633 if !pm.has_hook(HookType::Route) {
4636 return RouteOverride::None;
4637 }
4638 let query_msg = match QueryMessage::parse(msg.payload.clone()) {
4639 Ok(q) => q,
4640 Err(_) => return RouteOverride::None,
4641 };
4642 let ctx = Self::build_query_context(&query_msg.query, session);
4643 match pm.execute_route(&ctx) {
4644 RouteResult::Default => RouteOverride::None,
4645 RouteResult::Primary => RouteOverride::Primary,
4646 RouteResult::Standby => RouteOverride::Standby,
4647 RouteResult::Node(name) => RouteOverride::Node(name),
4648 RouteResult::Block(reason) => RouteOverride::Block(reason),
4649 RouteResult::Branch(name) => {
4650 tracing::warn!(
4651 branch = %name,
4652 "Route hook returned Branch but branch routing is not yet wired — using default"
4653 );
4654 RouteOverride::None
4655 }
4656 }
4657 }
4658 #[cfg(not(feature = "wasm-plugins"))]
4659 {
4660 let _ = (msg, state, session);
4661 RouteOverride::None
4662 }
4663 }
4664
4665 #[cfg(feature = "routing-hints")]
4671 fn hint_to_override(hints: &crate::routing::ParsedHints) -> RouteOverride {
4672 use crate::routing::{ConsistencyLevel, RouteTarget};
4673 if let Some(node) = &hints.node {
4674 return RouteOverride::Node(node.clone());
4675 }
4676 if let Some(route) = hints.route {
4677 return match route {
4678 RouteTarget::Primary => RouteOverride::Primary,
4679 RouteTarget::Standby
4680 | RouteTarget::Sync
4681 | RouteTarget::SemiSync
4682 | RouteTarget::Async
4683 | RouteTarget::Local => RouteOverride::Standby,
4684 RouteTarget::Any | RouteTarget::Vector => RouteOverride::None,
4685 };
4686 }
4687 if hints.consistency == Some(ConsistencyLevel::Strong) {
4688 return RouteOverride::Primary;
4689 }
4690 RouteOverride::None
4691 }
4692
4693 #[cfg(feature = "routing-hints")]
4701 fn resolve_simple_route(
4702 msg: &Message,
4703 plugin_override: RouteOverride,
4704 default_is_write: bool,
4705 state: &Arc<ServerState>,
4706 ) -> (RouteOverride, bool, Option<Message>) {
4707 let parser = match state.hint_parser.as_ref() {
4708 Some(p) => p,
4709 None => return (plugin_override, default_is_write, None),
4710 };
4711 let sql = match crate::protocol::query_text(&msg.payload) {
4712 Some(s) => s,
4713 None => return (plugin_override, default_is_write, None),
4714 };
4715 let hints = parser.parse(sql);
4716 if hints.is_empty() {
4717 return (plugin_override, default_is_write, None);
4718 }
4719 let stripped = parser.strip(sql);
4720 let is_write = Self::is_write_query(&stripped);
4721 let effective = match Self::hint_to_override(&hints) {
4722 RouteOverride::None => plugin_override,
4723 hint_override => hint_override,
4724 };
4725 let forward = if parser.strip_hints {
4726 Some(crate::protocol::QueryMessage { query: stripped }.encode())
4727 } else {
4728 None
4729 };
4730 (effective, is_write, forward)
4731 }
4732
4733 #[cfg(feature = "routing-hints")]
4740 fn extended_hint_route(state: &Arc<ServerState>, sql: &str) -> Option<(bool, Option<String>)> {
4741 let parser = state.hint_parser.as_ref()?;
4742 let hints = parser.parse(sql);
4743 if hints.is_empty() {
4744 return None;
4745 }
4746 let stripped = parser.strip(sql);
4747 let is_write = Self::is_write_query(&stripped);
4748 match Self::hint_to_override(&hints) {
4749 RouteOverride::Primary => Some((true, None)),
4750 RouteOverride::Standby => Some((false, None)),
4751 RouteOverride::Node(n) => Some((is_write, Some(n))),
4752 _ => Some((is_write, None)),
4753 }
4754 }
4755
4756 #[cfg(feature = "wasm-plugins")]
4760 fn fire_post_query_hook(
4761 msg: &Message,
4762 session: &Arc<ClientSession>,
4763 state: &Arc<ServerState>,
4764 result: &Result<(Option<String>, u64)>,
4765 elapsed: Duration,
4766 ) {
4767 let pm = match state.plugin_manager.as_ref() {
4768 Some(pm) => pm,
4769 None => return,
4770 };
4771 if msg.msg_type != MessageType::Query {
4772 return;
4773 }
4774 if !pm.has_hook(HookType::PostQuery) {
4777 return;
4778 }
4779 let query_msg = match QueryMessage::parse(msg.payload.clone()) {
4780 Ok(q) => q,
4781 Err(_) => return,
4782 };
4783 let ctx = Self::build_query_context(&query_msg.query, session);
4784 let outcome = match result {
4785 Ok((node, bytes)) => PostQueryOutcome {
4786 success: true,
4787 target_node: node.clone(),
4788 elapsed_us: elapsed.as_micros() as u64,
4789 response_bytes: *bytes,
4790 error: None,
4791 },
4792 Err(e) => PostQueryOutcome {
4793 success: false,
4794 target_node: None,
4795 elapsed_us: elapsed.as_micros() as u64,
4796 response_bytes: 0,
4797 error: Some(e.to_string()),
4798 },
4799 };
4800 pm.execute_post_query(&ctx, &outcome);
4801 }
4802
4803 async fn select_node(
4807 session: &Arc<ClientSession>,
4808 state: &Arc<ServerState>,
4809 config: &ProxyConfig,
4810 ) -> Result<String> {
4811 {
4813 let tx_state = session.tx_state.read().await;
4814 if tx_state.in_transaction {
4815 if let Some(node) = session.current_node.read().await.clone() {
4816 return Ok(node);
4817 }
4818 }
4819 }
4820
4821 let health = state.health.load_full();
4823 let healthy_nodes: Vec<&NodeConfig> = config
4824 .nodes
4825 .iter()
4826 .filter(|n| n.enabled && health.get(&n.address()).map(|h| h.healthy).unwrap_or(false))
4827 .collect();
4828
4829 if healthy_nodes.is_empty() {
4830 return Err(ProxyError::NoHealthyNodes);
4831 }
4832
4833 if let Some(primary) = healthy_nodes.iter().find(|n| n.role == NodeRole::Primary) {
4835 let node_addr = primary.address();
4836 let mut current = session.current_node.write().await;
4837 *current = Some(node_addr.clone());
4838 return Ok(node_addr);
4839 }
4840
4841 if let Some(standby) = healthy_nodes.iter().find(|n| n.role == NodeRole::Standby) {
4844 tracing::warn!("Primary unavailable, connecting to standby for initial session");
4845 let node_addr = standby.address();
4846 let mut current = session.current_node.write().await;
4847 *current = Some(node_addr.clone());
4848 return Ok(node_addr);
4849 }
4850
4851 Err(ProxyError::NoHealthyNodes)
4853 }
4854
4855 fn spawn_health_checker(&self) -> tokio::task::JoinHandle<()> {
4857 let state = self.state.clone();
4858 let mut shutdown_rx = self.shutdown_tx.subscribe();
4859
4860 tokio::spawn(async move {
4861 let mut interval = tokio::time::interval(std::time::Duration::from_secs(
4862 state.live_config.load().health.check_interval_secs,
4863 ));
4864
4865 loop {
4866 tokio::select! {
4867 _ = interval.tick() => {
4868 let config = state.live_config.load_full();
4871 Self::check_all_nodes(&state, &config).await;
4872 }
4873 _ = shutdown_rx.recv() => {
4874 break;
4875 }
4876 }
4877 }
4878 })
4879 }
4880
4881 async fn check_all_nodes(state: &Arc<ServerState>, config: &ProxyConfig) {
4888 let timeout = Duration::from_secs(config.health.check_timeout_secs);
4891 let mut set = tokio::task::JoinSet::new();
4892 for node in &config.nodes {
4893 let addr = node.address();
4894 set.spawn(async move {
4895 let r = Self::check_node_addr(&addr, timeout).await;
4896 (addr, r)
4897 });
4898 }
4899 let mut results = Vec::with_capacity(config.nodes.len());
4900 while let Some(joined) = set.join_next().await {
4901 if let Ok(pair) = joined {
4902 results.push(pair);
4903 }
4904 }
4905
4906 let _writers = state.health_write.lock();
4912 let mut next = (*state.health.load_full()).clone();
4913 for (addr, result) in results {
4914 if let Some(node_health) = next.get_mut(&addr) {
4915 match result {
4916 Ok(latency) => {
4917 node_health.healthy = true;
4918 node_health.failure_count = 0;
4919 node_health.latency_ms = latency;
4920 node_health.last_error = None;
4921 }
4922 Err(e) => {
4923 node_health.failure_count += 1;
4924 node_health.last_error = Some(e.to_string());
4925 if node_health.failure_count >= config.health.failure_threshold {
4926 node_health.healthy = false;
4927 tracing::warn!(
4928 "Node {} marked unhealthy after {} failures",
4929 addr,
4930 node_health.failure_count
4931 );
4932 }
4933 }
4934 }
4935 node_health.last_check = chrono::Utc::now();
4936 }
4937 }
4938 state.health.store(Arc::new(next));
4939 }
4940
4941 async fn check_node_addr(addr: &str, timeout: Duration) -> Result<f64> {
4952 const SSL_REQUEST: [u8; 8] = [0, 0, 0, 8, 0x04, 0xD2, 0x16, 0x2F];
4954 let start = std::time::Instant::now();
4955 let mut stream = tokio::time::timeout(timeout, TcpStream::connect(addr))
4956 .await
4957 .map_err(|_| ProxyError::HealthCheck(format!("Timeout connecting to {}", addr)))?
4958 .map_err(|e| {
4959 ProxyError::HealthCheck(format!("Failed to connect to {}: {}", addr, e))
4960 })?;
4961
4962 let probe = async {
4963 stream.write_all(&SSL_REQUEST).await?;
4964 let mut resp = [0u8; 1];
4965 stream.read_exact(&mut resp).await?;
4966 Ok::<u8, std::io::Error>(resp[0])
4967 };
4968 let remaining = timeout
4970 .saturating_sub(start.elapsed())
4971 .max(Duration::from_millis(1));
4972 let byte = tokio::time::timeout(remaining, probe)
4973 .await
4974 .map_err(|_| {
4975 ProxyError::HealthCheck(format!("{} did not answer protocol probe in time", addr))
4976 })?
4977 .map_err(|e| {
4978 ProxyError::HealthCheck(format!("{} protocol probe error: {}", addr, e))
4979 })?;
4980 if byte != b'S' && byte != b'N' {
4983 return Err(ProxyError::HealthCheck(format!(
4984 "{} sent unexpected probe reply {:#x}",
4985 addr, byte
4986 )));
4987 }
4988 let latency = start.elapsed().as_secs_f64() * 1000.0;
4989 Ok(latency)
4990 }
4991
    /// Spawns the background pool-maintenance task.
    ///
    /// The task ticks every `Self::POOL_REAP_INTERVAL`. With the `pool-modes`
    /// feature, each tick does two things:
    /// * asks the pool manager (if any) to evict idle connections;
    /// * reaps idle connections older than `pool_mode.idle_timeout_secs` from
    ///   the backend pool (if any). The TTL is re-read from the live config on
    ///   every tick, and 0 disables reaping.
    ///
    /// On shutdown it closes all pool-manager connections (with `pool-modes`)
    /// and exits. Without the feature, the task only waits for shutdown.
    fn spawn_pool_manager(&self) -> tokio::task::JoinHandle<()> {
        #[cfg(feature = "pool-modes")]
        let state = self.state.clone();
        let mut shutdown_rx = self.shutdown_tx.subscribe();

        tokio::spawn(async move {
            let mut interval = tokio::time::interval(Self::POOL_REAP_INTERVAL);

            loop {
                tokio::select! {
                    _ = interval.tick() => {
                        #[cfg(feature = "pool-modes")]
                        if let Some(ref pool_manager) = state.pool_manager {
                            pool_manager.evict_idle().await;
                            tracing::trace!("Pool-modes idle eviction completed");
                        }
                        #[cfg(feature = "pool-modes")]
                        if let Some(ref backend_pool) = state.backend_pool {
                            // Read each tick so config reloads change the TTL.
                            let ttl = std::time::Duration::from_secs(
                                state.live_config.load().pool_mode.idle_timeout_secs,
                            );
                            // A zero TTL means "never reap by age".
                            let n = if ttl.is_zero() {
                                0
                            } else {
                                backend_pool.reap_idle(ttl)
                            };
                            if n > 0 {
                                tracing::debug!(
                                    target: "helios::pool",
                                    reaped = n,
                                    idle_remaining = backend_pool.idle_count(),
                                    "reaped idle backend connections (TTL)"
                                );
                            }
                        }
                    }
                    _ = shutdown_rx.recv() => {
                        #[cfg(feature = "pool-modes")]
                        if let Some(ref pool_manager) = state.pool_manager {
                            pool_manager.close_all().await;
                            tracing::info!("Pool-modes manager closed all connections");
                        }
                        break;
                    }
                }
            }
        })
    }
5053
5054 pub fn shutdown(&self) {
5056 let _ = self.shutdown_tx.send(());
5057 }
5058
5059 #[cfg(feature = "pool-modes")]
5061 pub async fn pool_mode_stats(&self) -> Option<PoolModeStatsSnapshot> {
5062 if let Some(ref pool_manager) = self.state.pool_manager {
5063 let stats = pool_manager.get_stats().await;
5064 let metrics = pool_manager.metrics().snapshot();
5065 let default_mode = pool_manager.default_mode();
5066
5067 let avg_lease_duration_ms = metrics
5069 .mode_stats
5070 .get(&default_mode)
5071 .map(|s| s.avg_lease_duration_ms as u64)
5072 .unwrap_or(0);
5073
5074 Some(PoolModeStatsSnapshot {
5075 mode: format!("{:?}", default_mode),
5076 total_connections: stats.total_connections,
5077 active_leases: stats.active_connections,
5078 idle_connections: stats.idle_connections,
5079 node_count: stats.node_count,
5080 acquires: metrics.acquires,
5081 releases: metrics.releases,
5082 acquire_failures: metrics.acquire_failures,
5083 acquire_timeouts: metrics.acquire_timeouts,
5084 transactions_completed: metrics.transactions_completed,
5085 statements_executed: metrics.statements_executed,
5086 avg_lease_duration_ms,
5087 })
5088 } else {
5089 None
5090 }
5091 }
5092
5093 #[cfg(feature = "pool-modes")]
5095 pub async fn add_node_to_pool(&self, node: &NodeConfig) {
5096 if let Some(ref pool_manager) = self.state.pool_manager {
5097 let endpoint = NodeEndpoint::new(&node.host, node.port)
5098 .with_role(match node.role {
5099 NodeRole::Primary => crate::NodeRole::Primary,
5100 NodeRole::Standby => crate::NodeRole::Standby,
5101 NodeRole::ReadReplica => crate::NodeRole::ReadReplica,
5102 })
5103 .with_weight(node.weight);
5104 pool_manager.add_node(&endpoint).await;
5105 tracing::info!("Added node {} to pool manager", node.address());
5106 }
5107 }
5108
5109 pub fn metrics(&self) -> ServerMetricsSnapshot {
5111 ServerMetricsSnapshot {
5112 connections_accepted: self
5113 .state
5114 .metrics
5115 .connections_accepted
5116 .load(Ordering::Relaxed),
5117 connections_closed: self
5118 .state
5119 .metrics
5120 .connections_closed
5121 .load(Ordering::Relaxed),
5122 queries_processed: self.state.metrics.queries_processed.load(Ordering::Relaxed),
5123 bytes_received: self.state.metrics.bytes_received.load(Ordering::Relaxed),
5124 bytes_sent: self.state.metrics.bytes_sent.load(Ordering::Relaxed),
5125 failovers: self.state.metrics.failovers.load(Ordering::Relaxed),
5126 }
5127 }
5128}
5129
/// Plain copy of the server's counters, as returned by `ProxyServer::metrics`.
///
/// The counters are read one at a time, so the fields are not guaranteed to
/// be mutually consistent.
#[derive(Debug, Clone)]
pub struct ServerMetricsSnapshot {
    /// Client connections accepted.
    pub connections_accepted: u64,
    /// Client connections closed.
    pub connections_closed: u64,
    /// Queries processed.
    pub queries_processed: u64,
    /// Bytes received.
    pub bytes_received: u64,
    /// Bytes sent, including proxy-generated replies such as plugin block errors.
    pub bytes_sent: u64,
    /// Value of the `failovers` counter. It is also incremented when waiting for
    /// a primary times out.
    pub failovers: u64,
}
5140
#[cfg(feature = "pool-modes")]
/// Statistics snapshot for the pool-modes connection manager, as assembled by
/// `ProxyServer::pool_mode_stats`.
#[derive(Debug, Clone)]
pub struct PoolModeStatsSnapshot {
    /// The default pooling mode, rendered with its `Debug` formatting.
    pub mode: String,
    /// Total connections reported by the pool manager.
    pub total_connections: usize,
    /// Connections currently leased (the manager's `active_connections`).
    pub active_leases: usize,
    /// Connections currently idle.
    pub idle_connections: usize,
    /// Number of backend nodes known to the pool manager.
    pub node_count: usize,
    /// Lease acquisitions.
    pub acquires: u64,
    /// Lease releases.
    pub releases: u64,
    /// Failed lease acquisitions.
    pub acquire_failures: u64,
    /// Lease acquisitions that timed out.
    pub acquire_timeouts: u64,
    /// Completed transactions.
    pub transactions_completed: u64,
    /// Executed statements.
    pub statements_executed: u64,
    /// Average lease duration in milliseconds for the default mode; 0 when no
    /// statistics exist for that mode.
    pub avg_lease_duration_ms: u64,
}
5170
#[cfg(test)]
mod tests {
    // Unit tests for `ProxyServer` helpers plus smoke tests for the
    // feature-gated subsystems it wires together.

    use super::*;
    // Under `wasm-plugins` the parent module already imports `QueryMessage`
    // (picked up through `super::*`); otherwise import it here.
    #[cfg(not(feature = "wasm-plugins"))]
    use crate::protocol::QueryMessage;

    /// Minimal config: ephemeral listen port plus a single primary node.
    fn test_config() -> ProxyConfig {
        let mut config = ProxyConfig::default();
        config.listen_address = "127.0.0.1:0".to_string();
        config
            .add_node("127.0.0.1:5432", "primary")
            .expect("static test node address must be accepted");
        config
    }

    #[test]
    fn test_server_creation() {
        let config = test_config();
        let server = ProxyServer::new(config);
        assert!(server.is_ok());
    }

    #[test]
    fn is_backend_fault_excludes_client_and_slow_query_errors() {
        // Transport failures talking to the backend count as backend faults.
        assert!(ProxyServer::is_backend_fault(
            "Backend read error: connection reset"
        ));
        assert!(ProxyServer::is_backend_fault(
            "Backend write error: broken pipe"
        ));
        assert!(ProxyServer::is_backend_fault("Backend write timeout"));
        assert!(ProxyServer::is_backend_fault(
            "Failed to connect to 127.0.0.1:5432: Connection refused"
        ));

        // A backend read timeout and client-side failures are not.
        assert!(!ProxyServer::is_backend_fault("Backend read timeout"));
        assert!(!ProxyServer::is_backend_fault("Client write timeout"));
        assert!(!ProxyServer::is_backend_fault(
            "Client write error: broken pipe"
        ));

        // Contrast with "Backend read timeout" above: a read *error* whose
        // message happens to mention a timeout is still a backend fault.
        assert!(ProxyServer::is_backend_fault(
            "Backend read error: timed out"
        ));
    }

    #[test]
    fn test_hba_addr_matches() {
        use std::net::IpAddr;
        // Parses an IPv4 or IPv6 literal.
        let ip = |s: &str| s.parse::<IpAddr>().unwrap();

        // `all` matches any client address.
        assert!(ProxyServer::hba_addr_matches("all", ip("203.0.113.7")));

        // CIDR ranges.
        assert!(ProxyServer::hba_addr_matches("10.0.0.0/8", ip("10.1.2.3")));
        assert!(!ProxyServer::hba_addr_matches("10.0.0.0/8", ip("11.1.2.3")));
        assert!(ProxyServer::hba_addr_matches(
            "127.0.0.1/32",
            ip("127.0.0.1")
        ));
        assert!(!ProxyServer::hba_addr_matches(
            "127.0.0.1/32",
            ip("127.0.0.2")
        ));

        // A bare address without a prefix length matches only itself.
        assert!(ProxyServer::hba_addr_matches(
            "192.168.1.1",
            ip("192.168.1.1")
        ));
        assert!(!ProxyServer::hba_addr_matches(
            "192.168.1.1",
            ip("192.168.1.2")
        ));

        // IPv6 CIDR.
        assert!(ProxyServer::hba_addr_matches("::1/128", ip("::1")));

        // A zero-length prefix matches everything.
        assert!(ProxyServer::hba_addr_matches("0.0.0.0/0", ip("8.8.8.8")));
    }

    #[test]
    fn test_hba_admits() {
        use crate::config::{HbaAction, HbaRule};
        use std::net::IpAddr;
        let ip: IpAddr = "10.0.0.5".parse().unwrap();

        // With no rules at all, every client is admitted.
        assert!(ProxyServer::hba_admits(&[], ip, "bench", "benchdb"));

        // A single Reject rule only affects the user it names; a user that
        // matches no rule is admitted.
        let rules = vec![HbaRule {
            action: HbaAction::Reject,
            user: "bench".into(),
            database: "all".into(),
            address: "all".into(),
        }];
        assert!(!ProxyServer::hba_admits(&rules, ip, "bench", "benchdb"));
        assert!(ProxyServer::hba_admits(&rules, ip, "alice", "benchdb"));

        // An earlier Allow takes precedence over a trailing catch-all Reject
        // for clients it matches; everyone else falls through to the Reject.
        let rules = vec![
            HbaRule {
                action: HbaAction::Allow,
                user: "bench".into(),
                database: "all".into(),
                address: "10.0.0.0/8".into(),
            },
            HbaRule {
                action: HbaAction::Reject,
                user: "all".into(),
                database: "all".into(),
                address: "all".into(),
            },
        ];
        assert!(ProxyServer::hba_admits(&rules, ip, "bench", "benchdb"));
        assert!(!ProxyServer::hba_admits(
            &rules,
            "192.168.0.1".parse().unwrap(),
            "bench",
            "benchdb"
        ));
        assert!(!ProxyServer::hba_admits(&rules, ip, "alice", "benchdb"));
    }

    #[test]
    fn test_initial_metrics() {
        let config = test_config();
        let server = ProxyServer::new(config).unwrap();
        let metrics = server.metrics();
        assert_eq!(metrics.connections_accepted, 0);
        assert_eq!(metrics.queries_processed, 0);
    }

    /// A freshly built server has no client sessions registered.
    #[tokio::test]
    async fn test_session_creation() {
        let config = test_config();
        let server = ProxyServer::new(config).unwrap();

        let sessions = server.state.sessions.read().await;
        assert!(sessions.is_empty());
    }

    /// Every configured node starts out healthy with a zero failure count.
    #[tokio::test]
    async fn test_node_health_initialization() {
        let config = test_config();
        let server = ProxyServer::new(config).unwrap();

        let health = server.state.health.load_full();
        assert!(!health.is_empty());

        for node_health in health.values() {
            assert!(node_health.healthy);
            assert_eq!(node_health.failure_count, 0);
        }
    }

    /// Builds a standalone session for hook tests; it is not registered in
    /// any server's session map.
    fn make_test_session() -> Arc<ClientSession> {
        Arc::new(ClientSession {
            id: Uuid::new_v4(),
            client_addr: "127.0.0.1:0".parse().unwrap(),
            current_node: RwLock::new(None),
            tx_state: RwLock::new(TransactionState::default()),
            variables: RwLock::new(HashMap::new()),
            created_at: chrono::Utc::now(),
            tr_mode: crate::config::TrMode::default(),
            #[cfg(feature = "lag-routing")]
            last_write_at: RwLock::new(None),
            #[cfg(feature = "pool-modes")]
            pool_client_id: crate::pool::lease::ClientId::default(),
            #[cfg(feature = "wasm-plugins")]
            plugin_identity: RwLock::new(None),
        })
    }

    /// Without a plugin manager the route hook makes no routing decision.
    #[tokio::test]
    async fn test_apply_route_hook_no_plugin_manager_returns_none() {
        let config = test_config();
        let server = ProxyServer::new(config).unwrap();
        let session = make_test_session();

        let msg = QueryMessage {
            query: "SELECT * FROM users".to_string(),
        }
        .encode();

        let decision = ProxyServer::apply_route_hook(&msg, &server.state, &session);
        assert!(matches!(decision, RouteOverride::None));
    }

    /// Without a plugin manager the pre-query hook forwards the message
    /// unchanged, byte for byte.
    #[tokio::test]
    async fn test_apply_pre_query_hook_no_plugin_manager_forwards() {
        let config = test_config();
        let server = ProxyServer::new(config).unwrap();
        let session = make_test_session();

        let original = QueryMessage {
            query: "SELECT 1".to_string(),
        }
        .encode();
        let original_bytes = original.encode().to_vec();

        let (msg_out, action) =
            ProxyServer::apply_pre_query_hook(original, &server.state, &session);

        assert!(matches!(action, PreQueryAction::Forward));
        assert_eq!(msg_out.encode().to_vec(), original_bytes);
    }

    /// Non-query frames (here `Sync`) never produce a route override.
    #[tokio::test]
    async fn test_apply_route_hook_skips_non_query_messages() {
        let config = test_config();
        let server = ProxyServer::new(config).unwrap();
        let session = make_test_session();

        let sync_msg = Message::empty(MessageType::Sync);
        let decision = ProxyServer::apply_route_hook(&sync_msg, &server.state, &session);
        assert!(matches!(decision, RouteOverride::None));
    }

    /// Plugins are disabled in the default config, so no manager is built.
    #[cfg(feature = "wasm-plugins")]
    #[test]
    fn test_init_plugin_manager_disabled_by_default_returns_none() {
        let config = test_config();
        assert!(!config.plugins.enabled);
        let pm = ProxyServer::init_plugin_manager(&config.plugins);
        assert!(pm.is_none());
    }

    /// A missing plugin directory does not prevent the manager from being
    /// created when plugins are enabled. (Per the test name a warning is
    /// logged; that is not asserted here.)
    #[cfg(feature = "wasm-plugins")]
    #[test]
    fn test_init_plugin_manager_missing_dir_logs_warning() {
        let mut config = test_config();
        config.plugins.enabled = true;
        config.plugins.plugin_dir = "/definitely/not/a/real/path".to_string();

        let pm = ProxyServer::init_plugin_manager(&config.plugins);
        assert!(pm.is_some());
    }

    /// Without a plugin manager the authenticate hook succeeds and, when the
    /// plugin feature is compiled in, records no plugin identity.
    #[tokio::test]
    async fn test_apply_authenticate_hook_no_plugin_manager_defers() {
        let config = test_config();
        let server = ProxyServer::new(config).unwrap();
        let session = make_test_session();

        let mut params = HashMap::new();
        params.insert("user".to_string(), "alice".to_string());
        params.insert("database".to_string(), "app".to_string());

        let result = ProxyServer::apply_authenticate_hook(&params, &session, &server.state).await;
        assert!(result.is_ok());

        #[cfg(feature = "wasm-plugins")]
        {
            let ident = session.plugin_identity.read().await;
            assert!(ident.is_none());
        }
    }

    /// A cached JSON result set is synthesised into a well-formed backend
    /// reply: RowDescription, one DataRow per row, CommandComplete, then
    /// ReadyForQuery in the idle state.
    #[cfg(feature = "wasm-plugins")]
    #[test]
    fn test_synthesise_cached_response_roundtrip() {
        let payload = br#"{
            "columns": [
                {"name": "id", "oid": 23},
                {"name": "email", "oid": 25}
            ],
            "rows": [
                ["1", "alice@example.com"],
                ["2", null]
            ]
        }"#;
        let reply = ProxyServer::synthesise_cached_response(payload).expect("synthesis");

        // Walk the frames: each is a 1-byte tag followed by a big-endian u32
        // length that counts itself but not the tag. The walk must land
        // exactly on the end of the reply (no truncation, no trailing bytes).
        let mut tags = Vec::new();
        let mut rest: &[u8] = &reply;
        while let Some((&tag, after_tag)) = rest.split_first() {
            let header = after_tag.get(..4).expect("truncated frame length");
            let len = u32::from_be_bytes([header[0], header[1], header[2], header[3]]) as usize;
            tags.push(tag);
            rest = after_tag
                .get(len..)
                .expect("frame length overruns the reply");
        }
        assert_eq!(tags, vec![b'T', b'D', b'D', b'C', b'Z'], "wire frame order");

        // The final byte is the ReadyForQuery status: 'I' (idle).
        assert_eq!(*reply.last().unwrap(), b'I');
    }

    /// A row wider than the column list is rejected as a protocol error.
    #[cfg(feature = "wasm-plugins")]
    #[test]
    fn test_synthesise_cached_response_rejects_row_width_mismatch() {
        let payload = br#"{
            "columns": [{"name": "id", "oid": 23}, {"name": "name", "oid": 25}],
            "rows": [["1", "alice", "extra"]]
        }"#;
        let result = ProxyServer::synthesise_cached_response(payload);
        assert!(matches!(result, Err(ProxyError::Protocol(_))));
    }

    /// A result set with no columns is rejected as a protocol error.
    #[cfg(feature = "wasm-plugins")]
    #[test]
    fn test_synthesise_cached_response_rejects_empty_columns() {
        let payload = br#"{ "columns": [], "rows": [] }"#;
        let result = ProxyServer::synthesise_cached_response(payload);
        assert!(matches!(result, Err(ProxyError::Protocol(_))));
    }

    /// Unparseable JSON is rejected as a protocol error.
    #[cfg(feature = "wasm-plugins")]
    #[test]
    fn test_synthesise_cached_response_rejects_bad_json() {
        let payload = b"not json at all";
        let result = ProxyServer::synthesise_cached_response(payload);
        assert!(matches!(result, Err(ProxyError::Protocol(_))));
    }

    /// With a plugin manager that has no plugins loaded, the authenticate
    /// hook still succeeds and records no plugin identity.
    #[cfg(feature = "wasm-plugins")]
    #[tokio::test]
    async fn test_apply_authenticate_hook_with_manager_no_plugins_defers() {
        use crate::plugins::{PluginManager, PluginRuntimeConfig};

        let config = test_config();
        let server = ProxyServer::new(config).unwrap();
        let session = make_test_session();

        let pm = Arc::new(PluginManager::new(PluginRuntimeConfig::default()).unwrap());
        // Hand-built state identical in shape to a default server except that
        // a (plugin-less) plugin manager is installed.
        let augmented_state = Arc::new(ServerState {
            sessions: RwLock::new(HashMap::new()),
            health: ArcSwap::from_pointee(HashMap::new()),
            health_write: parking_lot::Mutex::new(()),
            live_config: ArcSwap::from_pointee(ProxyConfig::default()),
            metrics: ServerMetrics::default(),
            cancel_map: Arc::new(DashMap::new()),
            cancel_order: Arc::new(parking_lot::Mutex::new(std::collections::VecDeque::new())),
            tls_acceptor: None,
            auth_file: None,
            mirror: None,
            cutover: Arc::new(ArcSwap::from_pointee(None)),
            lb_state: LoadBalancerState {
                rr_counter: AtomicU64::new(0),
            },
            #[cfg(feature = "routing-hints")]
            hint_parser: None,
            #[cfg(feature = "rate-limiting")]
            rate_limiter: None,
            #[cfg(feature = "circuit-breaker")]
            circuit_breaker: None,
            #[cfg(feature = "query-analytics")]
            analytics: None,
            #[cfg(feature = "query-cache")]
            query_cache: None,
            #[cfg(feature = "query-rewriting")]
            rewriter: None,
            #[cfg(feature = "multi-tenancy")]
            tenant_manager: None,
            #[cfg(feature = "schema-routing")]
            schema_analyzer: None,
            #[cfg(feature = "pool-modes")]
            pool_manager: None,
            #[cfg(feature = "pool-modes")]
            backend_pool: None,
            plugin_manager: Some(pm),
            #[cfg(feature = "ha-tr")]
            transaction_journal: Arc::new(crate::transaction_journal::TransactionJournal::new()),
            #[cfg(feature = "anomaly-detection")]
            anomaly_detector: Arc::new(crate::anomaly::AnomalyDetector::new(
                crate::anomaly::AnomalyConfig::default(),
            )),
            #[cfg(feature = "edge-proxy")]
            edge_cache: Arc::new(crate::edge::EdgeCache::new(10_000)),
            #[cfg(feature = "edge-proxy")]
            edge_registry: Arc::new(crate::edge::EdgeRegistry::new(
                32,
                std::time::Duration::from_secs(120),
            )),
        });

        let mut params = HashMap::new();
        params.insert("user".to_string(), "alice".to_string());

        let result =
            ProxyServer::apply_authenticate_hook(&params, &session, &augmented_state).await;
        assert!(result.is_ok());
        let ident = session.plugin_identity.read().await;
        assert!(ident.is_none());
        // Keep the original server alive until the end of the test.
        let _ = server;
    }

    /// Encodes `s` as a NUL-terminated C string, as used in PostgreSQL
    /// protocol message bodies.
    fn cstr(s: &str) -> Vec<u8> {
        let mut v = s.as_bytes().to_vec();
        v.push(0);
        v
    }

    #[test]
    fn parse_stmt_name_extracts_named_and_unnamed() {
        // Parse body: statement name, query text, then a zero parameter-type count.
        let mut named = cstr("ps1");
        named.extend_from_slice(&cstr("SELECT 1"));
        named.extend_from_slice(&[0, 0]);
        assert_eq!(ProxyServer::parse_stmt_name(&named), "ps1");

        let mut unnamed = cstr("");
        unnamed.extend_from_slice(&cstr("SELECT 1"));
        unnamed.extend_from_slice(&[0, 0]);
        assert_eq!(ProxyServer::parse_stmt_name(&unnamed), "");
    }

    #[test]
    fn bind_stmt_ref_reads_second_cstring() {
        // Bind body: portal name first, then the statement name.
        let mut named = cstr("portal_a");
        named.extend_from_slice(&cstr("ps1"));
        named.extend_from_slice(&[0, 0]);
        assert_eq!(ProxyServer::bind_stmt_ref(&named), Some("ps1"));

        // The unnamed statement yields no reference.
        let mut unnamed = cstr("");
        unnamed.extend_from_slice(&cstr(""));
        assert_eq!(ProxyServer::bind_stmt_ref(&unnamed), None);
    }

    #[test]
    fn stmt_kind_name_only_matches_statement_kind() {
        // 'S' (statement) with a name is reported.
        let mut stmt = vec![b'S'];
        stmt.extend_from_slice(&cstr("ps1"));
        assert_eq!(ProxyServer::stmt_kind_name(&stmt), Some("ps1"));

        // 'P' (portal) is ignored.
        let mut portal = vec![b'P'];
        portal.extend_from_slice(&cstr("portal_a"));
        assert_eq!(ProxyServer::stmt_kind_name(&portal), None);

        // An unnamed statement is ignored.
        let mut empty = vec![b'S'];
        empty.extend_from_slice(&cstr(""));
        assert_eq!(ProxyServer::stmt_kind_name(&empty), None);
    }

    #[tokio::test]
    async fn read_one_frame_type_consumes_full_frame() {
        let (mut a, mut b) = tokio::io::duplex(64);
        // Two back-to-back frames: ParseComplete ('1', empty body) and
        // ReadyForQuery ('Z', one status byte).
        let bytes = [b'1', 0, 0, 0, 4, b'Z', 0, 0, 0, 5, b'I'];
        b.write_all(&bytes).await.unwrap();
        let t = ProxyServer::read_one_frame_type(&mut a).await.unwrap();
        assert_eq!(t, b'1');
        // The second read only lines up if the first consumed its whole frame.
        let t2 = ProxyServer::read_one_frame_type(&mut a).await.unwrap();
        assert_eq!(t2, b'Z');
    }

    #[tokio::test]
    async fn reprepare_statement_accepts_parse_complete_and_rejects_error() {
        let parse = {
            let mut p = vec![b'P', 0, 0, 0, 0];
            p.extend_from_slice(&cstr("ps1"));
            p.extend_from_slice(&cstr("SELECT 1"));
            p.extend_from_slice(&[0, 0]);
            p
        };

        // Backend answers ParseComplete: success.
        let (mut client, mut backend) = tokio::io::duplex(64);
        backend.write_all(&[b'1', 0, 0, 0, 4]).await.unwrap();
        assert!(ProxyServer::reprepare_statement(&mut client, &parse)
            .await
            .is_ok());

        // Backend answers ErrorResponse: failure.
        let (mut client2, mut backend2) = tokio::io::duplex(64);
        backend2.write_all(&[b'E', 0, 0, 0, 4]).await.unwrap();
        assert!(ProxyServer::reprepare_statement(&mut client2, &parse)
            .await
            .is_err());
    }

    /// Mapping from parsed `/*helios:...*/` hints to a `RouteOverride`.
    #[cfg(feature = "routing-hints")]
    mod routing_hints {
        use super::*;
        use crate::routing::HintParser;

        fn over(sql: &str) -> RouteOverride {
            let hints = HintParser::new().parse(sql);
            ProxyServer::hint_to_override(&hints)
        }

        #[test]
        fn route_primary_maps_to_primary() {
            assert!(matches!(
                over("/*helios:route=primary*/ SELECT 1"),
                RouteOverride::Primary
            ));
        }

        #[test]
        fn read_tier_targets_map_to_standby() {
            for t in ["standby", "sync", "semisync", "async", "local"] {
                assert!(
                    matches!(
                        over(&format!("/*helios:route={t}*/ SELECT 1")),
                        RouteOverride::Standby
                    ),
                    "route={t} should map to Standby"
                );
            }
        }

        #[test]
        fn any_and_vector_impose_no_constraint() {
            assert!(matches!(
                over("/*helios:route=any*/ SELECT 1"),
                RouteOverride::None
            ));
            assert!(matches!(
                over("/*helios:route=vector*/ SELECT 1"),
                RouteOverride::None
            ));
        }

        #[test]
        fn node_hint_maps_to_node_and_wins_over_route() {
            // Both `node` and `route` are present; `node` takes precedence.
            match over("/*helios:node=pg-standby,route=primary*/ SELECT 1") {
                RouteOverride::Node(n) => assert_eq!(n, "pg-standby"),
                other => panic!("expected Node, got {other:?}"),
            }
        }

        #[test]
        fn consistency_strong_forces_primary() {
            assert!(matches!(
                over("/*helios:consistency=strong*/ SELECT 1"),
                RouteOverride::Primary
            ));
        }

        #[test]
        fn no_hint_yields_none() {
            assert!(matches!(over("SELECT 1"), RouteOverride::None));
        }

        /// A leading hint comment hides the write verb from
        /// `is_write_query`; it is only detected once the hint is stripped.
        #[test]
        fn write_verb_classified_after_strip() {
            let parser = HintParser::new();
            let raw = "/*helios:route=primary*/ INSERT INTO t VALUES (1)";
            assert!(!ProxyServer::is_write_query(raw));
            assert!(ProxyServer::is_write_query(&parser.strip(raw)));
        }

        #[test]
        fn strip_removes_hint_comment() {
            let parser = HintParser::new();
            assert_eq!(
                parser.strip("/*helios:route=standby*/ SELECT 42"),
                "SELECT 42"
            );
        }
    }

    /// Token-bucket behavior of the rate limiter.
    #[cfg(feature = "rate-limiting")]
    mod rate_limiting {
        use crate::rate_limit::{LimiterKey, RateLimitConfig, RateLimitResult, RateLimiter};

        #[test]
        fn burst_allows_then_denies() {
            let cfg = RateLimitConfig {
                enabled: true,
                default_qps: 1,
                default_burst: 2,
                ..Default::default()
            };
            let limiter = RateLimiter::new(cfg);
            let key = LimiterKey::User("u".to_string());

            // The configured burst of 2 is admitted immediately.
            assert!(matches!(limiter.check(&key, 1), RateLimitResult::Allowed));
            assert!(matches!(limiter.check(&key, 1), RateLimitResult::Allowed));

            // Within a few more draws at 1 qps, at least one must be denied.
            let denied =
                (0..5).any(|_| matches!(limiter.check(&key, 1), RateLimitResult::Denied(_)));
            assert!(denied, "over-burst checks must yield a Denied verdict");
        }

        #[test]
        fn distinct_keys_have_independent_buckets() {
            let cfg = RateLimitConfig {
                enabled: true,
                default_qps: 1,
                default_burst: 1,
                ..Default::default()
            };
            let limiter = RateLimiter::new(cfg);
            // Each key gets its own burst of 1.
            assert!(matches!(
                limiter.check(&LimiterKey::User("a".to_string()), 1),
                RateLimitResult::Allowed
            ));
            assert!(matches!(
                limiter.check(&LimiterKey::User("b".to_string()), 1),
                RateLimitResult::Allowed
            ));
        }
    }

    /// Per-node circuit breaker state transitions.
    #[cfg(feature = "circuit-breaker")]
    mod circuit_breaker {
        use crate::circuit_breaker::{
            CircuitBreakerConfig, CircuitBreakerManager, CircuitState, ManagerConfig,
        };
        use std::time::Duration;

        fn mgr(threshold: u32) -> CircuitBreakerManager {
            let cfg = CircuitBreakerConfig {
                failure_threshold: threshold,
                cooldown: Duration::from_secs(10),
                ..Default::default()
            };
            CircuitBreakerManager::new(ManagerConfig::new(cfg))
        }

        #[test]
        fn opens_after_threshold_failures() {
            let m = mgr(3);
            let b = m.get_breaker("n1");
            assert_eq!(b.get_state(), CircuitState::Closed);
            b.record_failure("boom");
            b.record_failure("boom");
            // Still below the threshold of 3.
            assert_eq!(b.get_state(), CircuitState::Closed);
            b.record_failure("boom");
            // Third failure reaches the threshold and trips the breaker.
            assert_eq!(b.get_state(), CircuitState::Open);
        }

        #[test]
        fn healthy_node_stays_closed() {
            let m = mgr(3);
            let b = m.get_breaker("n2");
            b.record_success();
            b.record_success();
            assert_eq!(b.get_state(), CircuitState::Closed);
        }
    }

    /// Query fingerprinting in the analytics collector.
    #[cfg(feature = "query-analytics")]
    mod query_analytics {
        use crate::analytics::{AnalyticsConfig, OrderBy, QueryAnalytics, QueryExecution};
        use std::time::Duration;

        #[test]
        fn records_and_collapses_literals() {
            let a = QueryAnalytics::new(AnalyticsConfig::default());
            for n in [1, 2, 3] {
                a.record(QueryExecution::new(
                    format!("select {n}"),
                    Duration::from_millis(1),
                ));
            }
            let top = a.top_queries(OrderBy::Calls, 10);
            assert!(!top.is_empty(), "no fingerprints recorded");
            // Queries differing only in a literal share one fingerprint.
            assert!(
                top.iter().any(|s| s.calls >= 3),
                "literals did not collapse: {:?}",
                top.iter()
                    .map(|s| (s.normalized.clone(), s.calls))
                    .collect::<Vec<_>>()
            );
        }
    }

    /// Read-your-writes pinning and replica-lag exclusion thresholds.
    #[cfg(feature = "lag-routing")]
    mod lag_routing {
        use super::ProxyServer;

        #[test]
        fn ryw_pins_recent_write() {
            // A write that just happened, with a 1000 ms window, pins.
            assert!(ProxyServer::ryw_pins_primary(
                Some(std::time::Instant::now()),
                1000
            ));
        }

        #[test]
        fn ryw_releases_old_write() {
            let old = std::time::Instant::now()
                .checked_sub(std::time::Duration::from_secs(10))
                .expect("monotonic clock must be at least 10s past its origin");
            assert!(!ProxyServer::ryw_pins_primary(Some(old), 1000));
        }

        #[test]
        fn ryw_no_write_or_disabled() {
            assert!(!ProxyServer::ryw_pins_primary(None, 1000));
            // A zero-length window never pins, even for a fresh write.
            assert!(!ProxyServer::ryw_pins_primary(
                Some(std::time::Instant::now()),
                0
            ));
        }

        #[test]
        fn lag_exclusion_thresholds() {
            // A threshold of 0 never excludes, however large the lag.
            assert!(!ProxyServer::lag_excludes_standby(Some(999_999), 0));
            // Unknown lag does not exclude.
            assert!(!ProxyServer::lag_excludes_standby(None, 1000));
            // Lag under the threshold is tolerated...
            assert!(!ProxyServer::lag_excludes_standby(Some(500), 1000));
            // ...lag over it excludes the standby.
            assert!(ProxyServer::lag_excludes_standby(Some(2000), 1000));
        }
    }

    /// Which SQL statements the result cache is allowed to serve.
    #[cfg(feature = "query-cache")]
    mod query_cache {
        use super::ProxyServer;

        #[test]
        fn plain_selects_are_cacheable() {
            assert!(ProxyServer::is_cacheable_read_sql("select v from t"));
            assert!(ProxyServer::is_cacheable_read_sql(
                "  SELECT a, b FROM users WHERE id = 5"
            ));
        }

        #[test]
        fn writes_and_non_selects_are_not_cacheable() {
            assert!(!ProxyServer::is_cacheable_read_sql(
                "insert into t values (1)"
            ));
            assert!(!ProxyServer::is_cacheable_read_sql("update t set v = 1"));
            assert!(!ProxyServer::is_cacheable_read_sql("show search_path"));
        }

        #[test]
        fn locking_and_volatile_selects_are_not_cacheable() {
            assert!(!ProxyServer::is_cacheable_read_sql(
                "select * from t for update"
            ));
            assert!(!ProxyServer::is_cacheable_read_sql("select now()"));
            assert!(!ProxyServer::is_cacheable_read_sql("select random()"));
            assert!(!ProxyServer::is_cacheable_read_sql("select nextval('s')"));
        }
    }

    /// Rule-based query rewriting.
    #[cfg(feature = "query-rewriting")]
    mod query_rewriting {
        use crate::rewriter::{
            QueryPattern, QueryRewriter, RewriteRule, RewriterConfig, Transformation,
        };

        /// Rewriter with a single rule renaming table `a` to `b`.
        fn rw_with_table_replace() -> QueryRewriter {
            let rw = QueryRewriter::new(RewriterConfig {
                enabled: true,
                ..Default::default()
            });
            rw.add_rule(
                RewriteRule::build("t")
                    .pattern(QueryPattern::Table("a".to_string()))
                    .transform(Transformation::ReplaceTable {
                        from: "a".to_string(),
                        to: "b".to_string(),
                    })
                    .build(),
            );
            rw
        }

        #[test]
        fn matching_query_is_rewritten() {
            let res = rw_with_table_replace().rewrite("select * from a").unwrap();
            assert!(res.was_rewritten(), "rule did not fire");
            assert!(res.query().contains('b'), "rewritten: {}", res.query());
            assert!(
                !res.query().contains("from a"),
                "still references a: {}",
                res.query()
            );
        }

        #[test]
        fn unmatched_query_is_unchanged() {
            let res = rw_with_table_replace()
                .rewrite("select * from other")
                .unwrap();
            assert!(!res.was_rewritten());
            assert_eq!(res.query(), "select * from other");
        }
    }

    /// Row-level tenant filter injection.
    #[cfg(feature = "multi-tenancy")]
    mod multi_tenancy {
        use crate::multi_tenancy::{
            IdentificationMethod, IsolationStrategy, MultiTenancyConfig, TenantConfig, TenantId,
            TenantManager, TenantManagerBuilder, TenantQueryTransformer,
        };

        /// Manager with table `t` keyed by column `tid` and one tenant, `acme`.
        fn manager() -> TenantManager {
            let transformer = TenantQueryTransformer::new().register_tables(&["t"], "tid");
            let tm = TenantManagerBuilder::new()
                .config(MultiTenancyConfig {
                    enabled: true,
                    identification: IdentificationMethod::Header {
                        header_name: "application_name".to_string(),
                    },
                    ..Default::default()
                })
                .query_transformer(transformer)
                .build();
            tm.register_tenant(TenantConfig::new(
                TenantId::new("acme"),
                IsolationStrategy::row("public", "tid"),
            ));
            tm
        }

        #[test]
        fn tenant_table_gets_filter() {
            let res = manager().transform_query("select * from t", &TenantId::new("acme"));
            assert!(res.transformed, "expected a tenant filter to be injected");
            let q = res.query.to_lowercase();
            assert!(
                q.contains("tid") && q.contains("acme"),
                "filter missing: {}",
                res.query
            );
        }

        #[test]
        fn non_tenant_table_passes_through() {
            let res = manager().transform_query("select * from other", &TenantId::new("acme"));
            assert!(!res.transformed);
        }
    }

    /// Transaction journal used for transparent replay.
    #[cfg(feature = "ha-tr")]
    mod ha_tr {
        use crate::transaction_journal::TransactionJournal;
        use crate::NodeId;

        #[tokio::test]
        async fn journal_records_and_windows_a_statement() {
            let j = TransactionJournal::new();
            let from = chrono::Utc::now() - chrono::Duration::seconds(60);
            let tx = uuid::Uuid::new_v4();
            j.begin_transaction(tx, uuid::Uuid::new_v4(), NodeId::new(), 0)
                .await
                .unwrap();
            j.log_statement(
                tx,
                "insert into t values (1)".to_string(),
                Vec::new(),
                None,
                None,
                0,
            )
            .await
            .unwrap();
            let to = chrono::Utc::now() + chrono::Duration::seconds(60);
            let entries = j.entries_in_window(from, to).await;
            assert_eq!(entries.len(), 1, "journaled statement should be in window");
            assert!(entries[0].1.statement.contains("insert"));
        }
    }

    /// Analytics-vs-point-query classification.
    #[cfg(feature = "schema-routing")]
    mod schema_routing {
        use crate::schema_routing::{QueryAnalyzer, SchemaRegistry};
        use std::sync::Arc;

        fn analyzer() -> QueryAnalyzer {
            QueryAnalyzer::new(Arc::new(SchemaRegistry::new()))
        }

        #[test]
        fn aggregation_group_by_is_analytics() {
            let a = analyzer();
            assert!(a
                .analyze("select count(*) from orders group by region")
                .is_analytics());
        }

        #[test]
        fn simple_point_query_is_not_analytics() {
            let a = analyzer();
            assert!(!a
                .analyze("select * from orders where id = 1")
                .is_analytics());
        }
    }
}