1use crate::admin::{AdminServer, AdminState, ConfigSnapshot, NodeSnapshot};
7#[cfg(feature = "ha-tr")]
8use crate::backend::{tls::default_client_config, BackendConfig, TlsMode};
9use crate::client_tls::{build_tls_acceptor, ClientStream};
10use crate::config::{HbaAction, HbaRule, NodeConfig, NodeRole, ProxyConfig, TrMode};
11use crate::protocol::{
12 ErrorResponse, Message, MessageType, ProtocolCodec, QueryMessage,
13 StartupMessage, TransactionStatus,
14};
15use crate::{ProxyError, Result};
16use arc_swap::ArcSwap;
17use bytes::{BufMut, BytesMut};
18use dashmap::DashMap;
19use std::collections::{HashMap, HashSet};
20use std::net::SocketAddr;
21use std::sync::atomic::{AtomicU64, Ordering};
22use std::sync::Arc;
23use std::time::Duration;
24use tokio::io::{AsyncReadExt, AsyncWriteExt};
25use tokio::net::{TcpListener, TcpStream};
26use tokio::sync::{broadcast, RwLock};
27use uuid::Uuid;
28
29#[cfg(feature = "pool-modes")]
31use crate::pool::{
32 ConnectionPoolManager, PoolModeConfig, PoolingMode,
33};
34#[cfg(feature = "pool-modes")]
35use crate::pool::lease::ClientId;
36#[cfg(feature = "pool-modes")]
37use crate::NodeEndpoint;
38
39#[cfg(feature = "wasm-plugins")]
41use crate::plugins::{
42 AuthRequest as PluginAuthRequest, AuthResult, HookContext, HookType, Identity, PluginManager,
43 PostQueryOutcome, PreQueryResult, QueryContext, RouteResult,
44};
45
/// The proxy front door: owns the startup configuration, the state shared
/// with every spawned task, and the shutdown broadcast channel.
pub struct ProxyServer {
    /// Configuration the server was constructed with (used for the listen
    /// address and for deciding which background services to start).
    config: ProxyConfig,
    /// Runtime state shared with connection handlers and background tasks.
    state: Arc<ServerState>,
    /// Broadcast sender; the accept loop and the admin task subscribe to it
    /// to learn about shutdown.
    shutdown_tx: broadcast::Sender<()>,
    /// Path of the file the configuration was loaded from, if any. The
    /// SIGHUP reload re-reads this file; `None` means there is nothing to reload.
    config_path: Option<String>,
}
56
/// Stand-in returned by the signal-stream constructors on non-Unix targets,
/// where the Unix signal streams are not available.
#[cfg(not(unix))]
struct HangupNever;
#[cfg(not(unix))]
impl HangupNever {
    /// Never resolves: awaits a future that stays pending forever, so a
    /// `select!` arm built on this call never fires.
    async fn recv(&mut self) -> Option<()> {
        std::future::pending().await
    }
}
67
/// Builds the `BackendConfig` template that is handed to the replay engine.
///
/// Host and port are deliberate placeholders (`"placeholder"` / `0`); TLS is
/// disabled and the timeouts are fixed at 5 s (connect) and 30 s (query).
/// The proxy configuration argument is currently unused.
// NOTE(review): presumably the replay engine overrides host/port per target
// before connecting — confirm against `ReplayEngine`.
#[cfg(feature = "ha-tr")]
fn build_replay_backend_template(_config: &ProxyConfig) -> BackendConfig {
    let connect_timeout = Duration::from_secs(5);
    let query_timeout = Duration::from_secs(30);

    BackendConfig {
        host: String::from("placeholder"),
        port: 0,
        user: String::from("postgres"),
        password: None,
        database: None,
        application_name: Some(String::from("heliosdb-proxy-replay")),
        tls_mode: TlsMode::Disable,
        connect_timeout,
        query_timeout,
        tls_config: default_client_config(),
    }
}
96
97#[cfg(feature = "anomaly-detection")]
/// Reduces a SQL statement to a normalised fingerprint so that statements
/// differing only in literal values compare equal.
///
/// Normalisation rules:
/// * every single-quoted string literal becomes `?`; a doubled quote (`''`)
///   inside a literal is the SQL escape for one quote and does **not** end
///   the literal, so `'it''s'` is a single `?`;
/// * an unterminated literal swallows the rest of the input;
/// * every run of ASCII digits (with embedded `.`) becomes `?`, and a `?` is
///   not repeated if the output already ends with one — note this also
///   rewrites digits inside identifiers (`t1` becomes `t?`);
/// * runs of ASCII whitespace collapse to one space, with leading and
///   trailing whitespace removed;
/// * all other characters are ASCII-lowercased.
fn anomaly_fingerprint(sql: &str) -> String {
    let mut out = String::with_capacity(sql.len());
    let mut prev_space = false;
    let mut chars = sql.chars().peekable();

    while let Some(c) = chars.next() {
        if c == '\'' {
            // One placeholder per literal, however long it is.
            out.push('?');
            while let Some(inner) = chars.next() {
                if inner != '\'' {
                    continue;
                }
                if chars.peek() == Some(&'\'') {
                    // `''` is an escaped quote: consume the second quote and
                    // stay inside the literal.
                    chars.next();
                } else {
                    // A lone quote closes the literal.
                    break;
                }
            }
            prev_space = false;
            continue;
        }

        if c.is_ascii_digit() {
            // Avoid `??` when a number directly follows another placeholder.
            if !out.ends_with('?') {
                out.push('?');
            }
            while matches!(chars.peek(), Some(d) if d.is_ascii_digit() || *d == '.') {
                chars.next();
            }
            prev_space = false;
            continue;
        }

        if c.is_ascii_whitespace() {
            // Collapse whitespace runs; never emit a leading space.
            if !prev_space && !out.is_empty() {
                out.push(' ');
                prev_space = true;
            }
            continue;
        }

        out.push(c.to_ascii_lowercase());
        prev_space = false;
    }

    out.trim_end().to_string()
}
153
/// State shared (behind an `Arc`) between the accept loop, every client
/// handler, and the background tasks.
struct ServerState {
    /// Live client sessions keyed by session id. Handlers insert on connect
    /// and remove on disconnect; the graceful drain polls its length.
    sessions: RwLock<HashMap<Uuid, Arc<ClientSession>>>,
    /// Per-node health keyed by node address; replaced wholesale via `ArcSwap`.
    health: ArcSwap<HashMap<String, NodeHealth>>,
    /// Current configuration; swapped on SIGHUP reload and cloned for each
    /// newly accepted connection.
    live_config: ArcSwap<ProxyConfig>,
    /// Process-wide counters.
    metrics: ServerMetrics,
    // NOTE(review): presumably maps a backend (process id, secret key) pair to
    // the node address for cancel requests — confirm against the users of this map.
    cancel_map: Arc<DashMap<(u32, u32), String>>,
    /// Present only when client-facing TLS is enabled in the configuration.
    tls_acceptor: Option<tokio_rustls::TlsAcceptor>,
    /// Loaded auth file; present only when the auth mode is SCRAM.
    auth_file: Option<Arc<crate::auth_scram::AuthFile>>,
    /// Traffic-mirroring handle; present only when mirroring is enabled.
    mirror: Option<crate::mirror::MirrorHandle>,
    /// Swappable cutover target, shared with the admin server's migration info.
    cutover: Arc<ArcSwap<Option<Arc<crate::mirror::CutoverTarget>>>>,
    /// Load-balancer bookkeeping.
    lb_state: LoadBalancerState,
    /// Connection pool manager (feature `pool-modes`).
    #[cfg(feature = "pool-modes")]
    pool_manager: Option<Arc<ConnectionPoolManager>>,
    /// Plugin manager (feature `wasm-plugins`); `None` when plugins are
    /// disabled or the manager could not be created.
    #[cfg(feature = "wasm-plugins")]
    plugin_manager: Option<Arc<PluginManager>>,
    /// Transaction journal shared with the replay engine (feature `ha-tr`).
    #[cfg(feature = "ha-tr")]
    transaction_journal: Arc<crate::transaction_journal::TransactionJournal>,
    /// Anomaly detector (feature `anomaly-detection`).
    #[cfg(feature = "anomaly-detection")]
    anomaly_detector: Arc<crate::anomaly::AnomalyDetector>,
    /// Edge cache (feature `edge-proxy`).
    #[cfg(feature = "edge-proxy")]
    edge_cache: Arc<crate::edge::EdgeCache>,
    /// Edge registry (feature `edge-proxy`).
    #[cfg(feature = "edge-proxy")]
    edge_registry: Arc<crate::edge::EdgeRegistry>,
}
221
/// Health record for one backend node.
#[derive(Debug, Clone)]
pub struct NodeHealth {
    /// Node address; also the key under which this record is stored.
    pub address: String,
    /// Whether the node is considered healthy. Nodes are seeded as healthy
    /// when first added.
    pub healthy: bool,
    /// Timestamp (UTC) of the most recent check, or of seeding.
    pub last_check: chrono::DateTime<chrono::Utc>,
    /// Failure counter; seeded at 0.
    // NOTE(review): presumably counts consecutive failures — confirm in the health checker.
    pub failure_count: u32,
    /// Most recent error text, if any.
    pub last_error: Option<String>,
    /// Latency in milliseconds; seeded at 0.0.
    pub latency_ms: f64,
    /// Replication lag in bytes, when known.
    pub replication_lag_bytes: Option<u64>,
}
240
/// Lock-free process-wide counters; all start at zero via `Default`.
#[derive(Default)]
struct ServerMetrics {
    /// Incremented once per accepted client connection.
    connections_accepted: AtomicU64,
    /// Incremented once when a client handler finishes.
    connections_closed: AtomicU64,
    /// Incremented per simple query and per `Sync`-terminated extended batch
    /// (including queries blocked or answered from cache).
    queries_processed: AtomicU64,
    /// Bytes read from clients.
    bytes_received: AtomicU64,
    /// Bytes written to clients.
    bytes_sent: AtomicU64,
    /// Failover counter.
    failovers: AtomicU64,
}
257
/// Mutable load-balancer bookkeeping.
struct LoadBalancerState {
    /// Counter initialised to 0.
    // NOTE(review): presumably the round-robin cursor — confirm in the node-selection code.
    rr_counter: AtomicU64,
}
264
/// Per-connection session record, created when a client is accepted.
pub struct ClientSession {
    /// Random (v4) id generated at accept time; key in the session map.
    pub id: Uuid,
    /// Peer address of the client socket.
    pub client_addr: SocketAddr,
    /// Starts as `None`.
    // NOTE(review): presumably the backend node currently serving this session — confirm.
    pub current_node: RwLock<Option<String>>,
    /// Transaction tracking state; starts at its default.
    pub tx_state: RwLock<TransactionState>,
    /// Session variables; starts empty.
    pub variables: RwLock<HashMap<String, String>>,
    /// Accept timestamp (UTC).
    pub created_at: chrono::DateTime<chrono::Utc>,
    /// TR mode copied from the configuration in effect when the client connected.
    pub tr_mode: TrMode,
    /// Identity of this client towards the connection pool (feature `pool-modes`).
    #[cfg(feature = "pool-modes")]
    pub pool_client_id: ClientId,
    /// Plugin identity (feature `wasm-plugins`); starts as `None`.
    #[cfg(feature = "wasm-plugins")]
    pub plugin_identity: RwLock<Option<Identity>>,
}
291
/// Transaction tracking for one session. `Default` yields "not in a
/// transaction" with empty logs.
// NOTE(review): field meanings below are inferred from names; confirm against
// the code that mutates this state.
#[derive(Debug, Clone, Default)]
pub struct TransactionState {
    /// Whether a transaction is currently open.
    pub in_transaction: bool,
    /// Id of the open transaction, if any.
    pub tx_id: Option<Uuid>,
    /// Statements logged for the transaction.
    pub statements: Vec<StatementLog>,
    /// Whether the transaction is read-only.
    pub read_only: bool,
    /// Savepoint names.
    pub savepoints: Vec<String>,
}
306
/// One logged statement inside a tracked transaction.
#[derive(Debug, Clone)]
pub struct StatementLog {
    /// SQL text.
    pub sql: String,
    /// Statement parameters as strings.
    pub params: Vec<String>,
    /// Optional checksum of the statement's result.
    // NOTE(review): presumably used to verify replay determinism — confirm.
    pub result_checksum: Option<u64>,
    /// Execution timestamp (UTC).
    pub executed_at: chrono::DateTime<chrono::Utc>,
}
319
/// One backend TCP connection plus per-connection prepared-statement bookkeeping.
struct BackendConn {
    /// The socket to the backend node.
    stream: TcpStream,
    // NOTE(review): presumably the names of statements already prepared on
    // this connection — confirm in the extended-protocol forwarding code.
    prepared: HashSet<String>,
    // NOTE(review): presumably the signature of the unnamed statement last
    // parsed on this connection — confirm.
    unnamed_sig: Option<bytes::Bytes>,
}

impl BackendConn {
    /// Wraps a freshly opened backend socket with empty bookkeeping.
    fn new(stream: TcpStream) -> Self {
        Self { stream, prepared: HashSet::new(), unnamed_sig: None }
    }
}
348
349pub(crate) fn bind_reuseport(addr: &str) -> Result<TcpListener> {
355 use socket2::{Domain, Protocol, Socket, Type};
356 let sockaddr: SocketAddr = addr
357 .parse()
358 .map_err(|e| ProxyError::Config(format!("invalid listen address '{}': {}", addr, e)))?;
359 let domain = if sockaddr.is_ipv6() { Domain::IPV6 } else { Domain::IPV4 };
360 let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))
361 .map_err(|e| ProxyError::Network(format!("socket(): {}", e)))?;
362 socket
363 .set_reuse_address(true)
364 .map_err(|e| ProxyError::Network(format!("SO_REUSEADDR: {}", e)))?;
365 #[cfg(all(unix, not(target_os = "solaris")))]
366 socket
367 .set_reuse_port(true)
368 .map_err(|e| ProxyError::Network(format!("SO_REUSEPORT: {}", e)))?;
369 socket
370 .set_nonblocking(true)
371 .map_err(|e| ProxyError::Network(format!("set_nonblocking: {}", e)))?;
372 socket
373 .bind(&sockaddr.into())
374 .map_err(|e| ProxyError::Network(format!("Failed to bind {}: {}", addr, e)))?;
375 socket
376 .listen(1024)
377 .map_err(|e| ProxyError::Network(format!("listen(): {}", e)))?;
378 let std_listener: std::net::TcpListener = socket.into();
379 TcpListener::from_std(std_listener)
380 .map_err(|e| ProxyError::Network(format!("from_std listener: {}", e)))
381}
382
/// Outcome of the pre-query hook for a simple query.
#[derive(Debug)]
// Some variants are only constructed when optional features are enabled.
#[allow(dead_code)]
enum PreQueryAction {
    /// Forward the query to a backend as usual.
    Forward,
    /// Reject the query; the string is the reason reported to the client.
    Block(String),
    /// Answer from these bytes instead of contacting a backend (the client
    /// loop falls back to the backend if they cannot be turned into a reply).
    Cached(Vec<u8>),
}
403
/// Routing decision override.
// NOTE(review): variant meanings are inferred from their names; the code that
// produces and consumes this enum is outside this view — confirm.
#[derive(Debug)]
// Some variants are only constructed when optional features are enabled.
#[allow(dead_code)]
enum RouteOverride {
    /// No override; use the default routing.
    None,
    /// Route to the primary.
    Primary,
    /// Route to a standby.
    Standby,
    /// Route to the named node.
    Node(String),
    /// Reject the query with the given reason.
    Block(String),
}
428
429impl ProxyServer {
430 #[cfg(feature = "wasm-plugins")]
438 fn init_plugin_manager(
439 toml_cfg: &crate::config::PluginToml,
440 ) -> Option<Arc<crate::plugins::PluginManager>> {
441 if !toml_cfg.enabled {
442 return None;
443 }
444
445 let runtime_cfg = crate::plugins::PluginRuntimeConfig::from(toml_cfg);
446 let plugin_dir = runtime_cfg.plugin_dir.clone();
447
448 let pm = match crate::plugins::PluginManager::new(runtime_cfg) {
449 Ok(pm) => Arc::new(pm),
450 Err(e) => {
451 tracing::error!(error = %e, "Failed to create plugin manager; plugins disabled");
452 return None;
453 }
454 };
455
456 match std::fs::read_dir(&plugin_dir) {
457 Ok(entries) => {
458 let mut loaded = 0usize;
459 let mut failed = 0usize;
460 for entry in entries.flatten() {
461 let path = entry.path();
462 if path.extension().and_then(|s| s.to_str()) != Some("wasm") {
463 continue;
464 }
465 match pm.load_plugin(&path) {
466 Ok(()) => loaded += 1,
467 Err(e) => {
468 failed += 1;
469 tracing::warn!(
470 path = %path.display(),
471 error = %e,
472 "Failed to load plugin"
473 );
474 }
475 }
476 }
477 tracing::info!(
478 dir = %plugin_dir.display(),
479 loaded = loaded,
480 failed = failed,
481 "Plugin loading complete"
482 );
483 }
484 Err(e) => {
485 tracing::warn!(
486 dir = %plugin_dir.display(),
487 error = %e,
488 "Plugin directory not readable; no plugins loaded"
489 );
490 }
491 }
492
493 Some(pm)
494 }
495
    /// Builds a server from `config` without binding any socket.
    ///
    /// Seeds every configured node as healthy, then sets up the optional
    /// subsystems (pool manager, plugins, client TLS, SCRAM auth file, traffic
    /// mirror) and assembles the shared `ServerState`.
    ///
    /// # Errors
    ///
    /// Returns `ProxyError::Config` when client TLS is enabled but the acceptor
    /// cannot be built, when the auth mode is SCRAM but no `auth_file` is
    /// configured, or when the auth file fails to load.
    pub fn new(config: ProxyConfig) -> Result<Self> {
        // The initial receiver is discarded; tasks subscribe on demand.
        let (shutdown_tx, _) = broadcast::channel(1);

        // Every configured node starts out healthy with zeroed statistics.
        let mut health = HashMap::new();
        for node in &config.nodes {
            health.insert(
                node.address(),
                NodeHealth {
                    address: node.address(),
                    healthy: true, last_check: chrono::Utc::now(),
                    failure_count: 0,
                    last_error: None,
                    latency_ms: 0.0,
                    replication_lag_bytes: None,
                },
            );
        }

        // Translate the config-level pooling enums into the pool module's own types.
        #[cfg(feature = "pool-modes")]
        let pool_manager = {
            use crate::pool::PreparedStatementMode as PoolPreparedStatementMode;

            let pool_config = PoolModeConfig {
                default_mode: match config.pool_mode.mode {
                    crate::config::PoolingMode::Session => PoolingMode::Session,
                    crate::config::PoolingMode::Transaction => PoolingMode::Transaction,
                    crate::config::PoolingMode::Statement => PoolingMode::Statement,
                },
                max_pool_size: config.pool_mode.max_pool_size,
                min_idle: config.pool_mode.min_idle,
                idle_timeout_secs: config.pool_mode.idle_timeout_secs,
                max_lifetime_secs: config.pool_mode.max_lifetime_secs,
                acquire_timeout_secs: config.pool_mode.acquire_timeout_secs,
                reset_query: config.pool_mode.reset_query.clone(),
                prepared_statement_mode: match config.pool_mode.prepared_statement_mode {
                    crate::config::PreparedStatementMode::Disable => {
                        PoolPreparedStatementMode::Disable
                    }
                    crate::config::PreparedStatementMode::Track => {
                        PoolPreparedStatementMode::Track
                    }
                    crate::config::PreparedStatementMode::Named => {
                        PoolPreparedStatementMode::Named
                    }
                },
                test_on_acquire: config.pool.test_on_acquire,
                // The remaining settings are fixed here rather than read from config.
                validation_query: "SELECT 1".to_string(),
                queue_timeout_secs: 30,
                max_queue_size: 0,
            };
            Some(Arc::new(ConnectionPoolManager::new(pool_config)))
        };

        #[cfg(feature = "wasm-plugins")]
        let plugin_manager = Self::init_plugin_manager(&config.plugins);

        // Client-facing TLS: a build failure is fatal, an absent or disabled
        // section simply means plaintext.
        let tls_acceptor = match config.tls.as_ref() {
            Some(tls) if tls.enabled => match build_tls_acceptor(tls) {
                Ok(acc) => {
                    tracing::info!(
                        mtls = tls.require_client_cert,
                        "client TLS termination enabled"
                    );
                    Some(acc)
                }
                Err(e) => {
                    return Err(ProxyError::Config(format!("TLS init failed: {}", e)));
                }
            },
            _ => None,
        };

        // SCRAM mode requires an auth file; other modes load nothing.
        let auth_file = if config.auth.mode == crate::config::AuthMode::Scram {
            let path = config.auth.auth_file.as_ref().ok_or_else(|| {
                ProxyError::Config("auth mode 'scram' requires auth_file".to_string())
            })?;
            let af = crate::auth_scram::AuthFile::load(path)
                .map_err(|e| ProxyError::Config(format!("auth_file: {}", e)))?;
            // Logs only whether the file contains any entries, not the entries themselves.
            tracing::info!(users = %(!af.is_empty()), "proxy SCRAM auth enabled");
            Some(Arc::new(af))
        } else {
            None
        };

        let mirror = if config.mirror.enabled {
            tracing::info!(target = %format!("{}:{}", config.mirror.backend_host, config.mirror.backend_port),
                writes_only = config.mirror.writes_only, "traffic mirroring enabled");
            Some(crate::mirror::spawn(config.mirror.clone()))
        } else {
            None
        };

        let state = Arc::new(ServerState {
            sessions: RwLock::new(HashMap::new()),
            health: ArcSwap::from_pointee(health),
            live_config: ArcSwap::from_pointee(config.clone()),
            metrics: ServerMetrics::default(),
            cancel_map: Arc::new(DashMap::new()),
            tls_acceptor,
            auth_file,
            mirror,
            // No cutover target until one is installed at runtime.
            cutover: Arc::new(ArcSwap::from_pointee(None)),
            lb_state: LoadBalancerState {
                rr_counter: AtomicU64::new(0),
            },
            #[cfg(feature = "pool-modes")]
            pool_manager,
            #[cfg(feature = "wasm-plugins")]
            plugin_manager,
            #[cfg(feature = "ha-tr")]
            transaction_journal: Arc::new(
                crate::transaction_journal::TransactionJournal::new(),
            ),
            #[cfg(feature = "anomaly-detection")]
            anomaly_detector: Arc::new(
                crate::anomaly::AnomalyDetector::new(
                    crate::anomaly::AnomalyConfig::default(),
                ),
            ),
            #[cfg(feature = "edge-proxy")]
            edge_cache: Arc::new(crate::edge::EdgeCache::new(10_000)),
            #[cfg(feature = "edge-proxy")]
            edge_registry: Arc::new(crate::edge::EdgeRegistry::new(
                32,
                std::time::Duration::from_secs(120),
            )),
        });

        Ok(Self {
            config,
            state,
            shutdown_tx,
            // Set afterwards through `with_config_path` when loaded from a file.
            config_path: None,
        })
    }
646
    /// Builder-style setter recording the file the configuration was loaded
    /// from, so that a later SIGHUP can re-read it. Passing `None` leaves
    /// reload disabled.
    pub fn with_config_path(mut self, path: Option<String>) -> Self {
        self.config_path = path;
        self
    }
654
    /// Returns a stream of SIGHUP notifications (Unix).
    ///
    /// # Panics
    ///
    /// Panics if the signal handler cannot be installed.
    #[cfg(unix)]
    fn hangup_stream() -> tokio::signal::unix::Signal {
        tokio::signal::unix::signal(tokio::signal::unix::SignalKind::hangup())
            .expect("failed to install SIGHUP handler")
    }
    /// Non-Unix fallback: a stream whose `recv` never resolves.
    #[cfg(not(unix))]
    fn hangup_stream() -> HangupNever {
        HangupNever
    }
666
    /// Returns a stream of SIGUSR2 notifications (Unix).
    ///
    /// # Panics
    ///
    /// Panics if the signal handler cannot be installed.
    #[cfg(unix)]
    fn usr2_stream() -> tokio::signal::unix::Signal {
        tokio::signal::unix::signal(tokio::signal::unix::SignalKind::user_defined2())
            .expect("failed to install SIGUSR2 handler")
    }
    /// Non-Unix fallback: a stream whose `recv` never resolves.
    #[cfg(not(unix))]
    fn usr2_stream() -> HangupNever {
        HangupNever
    }
678
679 async fn drain_connections(state: &Arc<ServerState>, timeout: Duration) {
683 let deadline = tokio::time::Instant::now() + timeout;
684 loop {
685 let active = state.sessions.read().await.len();
686 if active == 0 {
687 tracing::info!("drain complete — all in-flight connections finished");
688 return;
689 }
690 if tokio::time::Instant::now() >= deadline {
691 tracing::warn!(
692 active,
693 "drain timeout reached — exiting with connections still open"
694 );
695 return;
696 }
697 tokio::time::sleep(Duration::from_millis(200)).await;
698 }
699 }
700
701 fn drain_timeout(config_secs: u64) -> Duration {
706 let secs = std::env::var("HELIOS_DRAIN_TIMEOUT_SECS")
707 .ok()
708 .and_then(|s| s.parse::<u64>().ok())
709 .unwrap_or(config_secs);
710 Duration::from_secs(secs)
711 }
712
713 async fn reload_config(&self) {
722 let Some(path) = self.config_path.as_deref() else {
723 tracing::warn!(
724 "SIGHUP received but config was not loaded from a file — nothing to reload"
725 );
726 return;
727 };
728 tracing::info!(path, "SIGHUP: reloading configuration");
729 let new_config = match ProxyConfig::from_file(path) {
730 Ok(c) => c,
731 Err(e) => {
732 tracing::error!(path, error = %e, "SIGHUP reload failed to parse — keeping current config");
733 return;
734 }
735 };
736 let old = self.state.live_config.load_full();
737 if new_config.listen_address != old.listen_address {
738 tracing::warn!(old = %old.listen_address, new = %new_config.listen_address,
739 "listen_address change needs a restart/handoff; the bound socket is kept");
740 }
741 if new_config.admin_address != old.admin_address {
742 tracing::warn!(old = %old.admin_address, new = %new_config.admin_address,
743 "admin_address change needs a restart; the bound socket is kept");
744 }
745 Self::reconcile_health(&self.state, &new_config);
748 let nodes = new_config.nodes.len();
749 let hba_rules = new_config.hba.len();
750 let pool_max = new_config.pool.max_connections;
751 self.state.live_config.store(Arc::new(new_config));
752 tracing::info!(
753 nodes,
754 hba_rules,
755 pool_max,
756 "SIGHUP: configuration reloaded — applies to new connections"
757 );
758 }
759
760 fn reconcile_health(state: &Arc<ServerState>, config: &ProxyConfig) {
764 let current = state.health.load_full();
765 let mut next: HashMap<String, NodeHealth> = HashMap::new();
766 for node in &config.nodes {
767 let addr = node.address();
768 match current.get(&addr) {
769 Some(existing) => {
770 next.insert(addr, existing.clone());
771 }
772 None => {
773 tracing::info!(node = %addr, "SIGHUP: new node added — seeding healthy");
774 next.insert(
775 addr.clone(),
776 NodeHealth {
777 address: addr,
778 healthy: true,
779 last_check: chrono::Utc::now(),
780 failure_count: 0,
781 last_error: None,
782 latency_ms: 0.0,
783 replication_lag_bytes: None,
784 },
785 );
786 }
787 }
788 }
789 for gone in current.keys().filter(|k| !next.contains_key(*k)) {
790 tracing::info!(node = %gone, "SIGHUP: node removed from config");
791 }
792 state.health.store(Arc::new(next));
793 }
794
    /// Runs the proxy until shutdown.
    ///
    /// Binds the listener, starts the background tasks (health checker, pool
    /// manager, admin server, and — when enabled — the MCP and HTTP gateways),
    /// then accepts clients in a loop that also reacts to:
    ///
    /// * SIGHUP — reload the configuration;
    /// * SIGUSR2 — stop accepting and drain in-flight connections;
    /// * the shutdown broadcast — stop immediately (no drain).
    ///
    /// Background tasks are aborted before returning.
    ///
    /// # Errors
    ///
    /// Returns an error only when the listener cannot be bound. Accept errors
    /// and per-client errors are logged and do not stop the loop.
    pub async fn run(&self) -> Result<()> {
        let listener = bind_reuseport(&self.config.listen_address)?;

        tracing::info!("Proxy listening on {} (SO_REUSEPORT)", self.config.listen_address);

        let health_task = self.spawn_health_checker();
        let pool_task = self.spawn_pool_manager();

        let admin_task = self.spawn_admin_server();

        let mcp_task = if self.config.mcp.enabled {
            let mcp_cfg = self.config.mcp.clone();
            // Resolve the contract id named by the MCP section; an unknown id
            // is only warned about and the gateway starts without a contract.
            let contract = mcp_cfg.contract.as_ref().and_then(|id| {
                let found = self.config.agent_contracts.iter().find(|c| &c.id == id).cloned();
                if found.is_none() {
                    tracing::warn!(%id, "mcp.contract names an unknown agent_contract; gateway runs with only the read-only guardrail");
                }
                found
            });
            Some(tokio::spawn(async move {
                if let Err(e) = crate::mcp::McpServer::new(mcp_cfg, contract).run().await {
                    tracing::error!("MCP gateway error: {}", e);
                }
            }))
        } else {
            None
        };

        let http_gw_task = if self.config.http_gateway.enabled {
            let gw_cfg = self.config.http_gateway.clone();
            Some(tokio::spawn(async move {
                if let Err(e) = crate::http_gateway::HttpGateway::new(gw_cfg).run().await {
                    tracing::error!("HTTP gateway error: {}", e);
                }
            }))
        } else {
            None
        };

        let mut shutdown_rx = self.shutdown_tx.subscribe();

        let mut sighup = Self::hangup_stream();
        let mut sigusr2 = Self::usr2_stream();
        // Set only by SIGUSR2; decides whether in-flight connections are drained.
        let mut graceful = false;

        loop {
            tokio::select! {
                _ = sighup.recv() => {
                    self.reload_config().await;
                }
                _ = sigusr2.recv() => {
                    tracing::info!(
                        "SIGUSR2: graceful binary-handoff drain — closing the listener so new \
                         connections route to the sibling process; finishing in-flight connections"
                    );
                    graceful = true;
                    break;
                }
                accept_result = listener.accept() => {
                    match accept_result {
                        Ok((stream, addr)) => {
                            // Best effort: failure to disable Nagle is not fatal.
                            let _ = stream.set_nodelay(true);
                            self.state.metrics.connections_accepted.fetch_add(1, Ordering::Relaxed);
                            let state = self.state.clone();
                            // Snapshot the live config so this connection keeps
                            // one consistent view even if a reload happens later.
                            let config = (*self.state.live_config.load_full()).clone();
                            let shutdown_tx = self.shutdown_tx.clone();

                            tokio::spawn(async move {
                                if let Err(e) = Self::handle_client(stream, addr, state, config, shutdown_tx).await {
                                    tracing::error!("Client handler error: {}", e);
                                }
                            });
                        }
                        Err(e) => {
                            // NOTE(review): there is no back-off here; a persistent
                            // accept failure would make this loop spin — confirm
                            // whether that is acceptable.
                            tracing::error!("Accept error: {}", e);
                        }
                    }
                }
                _ = shutdown_rx.recv() => {
                    tracing::info!("Shutdown signal received");
                    break;
                }
            }
        }

        // Close the listening socket before draining so no new clients arrive here.
        drop(listener);

        if graceful {
            let timeout =
                Self::drain_timeout(self.state.live_config.load().shutdown_drain_timeout_secs);
            tracing::info!(timeout_secs = timeout.as_secs(), "draining in-flight connections");
            Self::drain_connections(&self.state, timeout).await;
        }

        health_task.abort();
        pool_task.abort();
        admin_task.abort();
        if let Some(t) = mcp_task {
            t.abort();
        }
        if let Some(t) = http_gw_task {
            t.abort();
        }

        Ok(())
    }
927
928 fn spawn_admin_server(&self) -> tokio::task::JoinHandle<()> {
930 let config = self.config.clone();
931 let state = self.state.clone();
932 let mut shutdown_rx = self.shutdown_tx.subscribe();
933
934 tokio::spawn(async move {
935 let admin_state = Arc::new(AdminState::new());
937
938 {
940 let mut snapshot = admin_state.config_snapshot.write().await;
941 *snapshot = ConfigSnapshot {
942 listen_address: config.listen_address.clone(),
943 admin_address: config.admin_address.clone(),
944 tr_enabled: config.tr_enabled,
945 tr_mode: format!("{:?}", config.tr_mode),
946 pool_min_connections: config.pool.min_connections,
947 pool_max_connections: config.pool.max_connections,
948 nodes: config.nodes.iter().map(|n| NodeSnapshot {
949 address: n.address(),
950 role: format!("{:?}", n.role),
951 weight: n.weight,
952 enabled: n.enabled,
953 }).collect(),
954 };
955 }
956
957 admin_state.set_proxy_config(config.clone()).await;
959
960 admin_state.with_auth_token(config.admin_token.clone()).await;
962
963 if config.branch.enabled {
965 admin_state.with_branch(config.branch.clone()).await;
966 }
967
968 if let Some(ref mirror) = state.mirror {
970 admin_state
971 .with_migration(crate::admin::MigrationInfo {
972 target: mirror.target().to_string(),
973 writes_only: mirror.writes_only(),
974 metrics: mirror.metrics.clone(),
975 config: config.mirror.clone(),
976 cutover: state.cutover.clone(),
977 cutover_target: crate::mirror::CutoverTarget {
978 addr: format!("{}:{}", config.mirror.backend_host, config.mirror.backend_port),
979 user: config.mirror.backend_user.clone(),
980 password: config.mirror.backend_password.clone(),
981 database: config.mirror.backend_database.clone(),
982 },
983 })
984 .await;
985 }
986
987 #[cfg(feature = "wasm-plugins")]
992 if let Some(ref pm) = state.plugin_manager {
993 admin_state.with_plugin_manager(pm.clone()).await;
994 }
995
996 #[cfg(feature = "ha-tr")]
1003 {
1004 let template = build_replay_backend_template(&config);
1005 let engine = Arc::new(crate::replay::ReplayEngine::new(
1006 state.transaction_journal.clone(),
1007 template,
1008 ));
1009 admin_state.with_replay_engine(engine).await;
1010 }
1011
1012 #[cfg(feature = "anomaly-detection")]
1016 admin_state
1017 .with_anomaly_detector(state.anomaly_detector.clone())
1018 .await;
1019
1020 #[cfg(feature = "edge-proxy")]
1023 admin_state
1024 .with_edge(state.edge_cache.clone(), state.edge_registry.clone())
1025 .await;
1026
1027 let admin_server = AdminServer::new(config.admin_address.clone(), admin_state.clone());
1029
1030 let admin_state_sync = admin_state.clone();
1032 let server_state = state.clone();
1033 let sync_task = tokio::spawn(async move {
1034 let mut interval = tokio::time::interval(std::time::Duration::from_secs(1));
1035 loop {
1036 interval.tick().await;
1037
1038 {
1040 let health = server_state.health.load_full();
1041 let mut admin_health = admin_state_sync.node_health.write().await;
1042 *admin_health = (*health).clone();
1043 }
1044
1045 {
1047 let metrics = ServerMetricsSnapshot {
1048 connections_accepted: server_state.metrics.connections_accepted.load(Ordering::Relaxed),
1049 connections_closed: server_state.metrics.connections_closed.load(Ordering::Relaxed),
1050 queries_processed: server_state.metrics.queries_processed.load(Ordering::Relaxed),
1051 bytes_received: server_state.metrics.bytes_received.load(Ordering::Relaxed),
1052 bytes_sent: server_state.metrics.bytes_sent.load(Ordering::Relaxed),
1053 failovers: server_state.metrics.failovers.load(Ordering::Relaxed),
1054 };
1055 let mut admin_metrics = admin_state_sync.metrics.write().await;
1056 *admin_metrics = metrics;
1057 }
1058
1059 {
1061 let sessions = server_state.sessions.read().await;
1062 let mut admin_sessions = admin_state_sync.active_sessions.write().await;
1063 *admin_sessions = sessions.len() as u64;
1064 }
1065 }
1066 });
1067
1068 tokio::select! {
1070 result = admin_server.run() => {
1071 if let Err(e) = result {
1072 tracing::error!("Admin server error: {}", e);
1073 }
1074 }
1075 _ = shutdown_rx.recv() => {
1076 tracing::info!("Admin server shutting down");
1077 }
1078 }
1079
1080 sync_task.abort();
1081 })
1082 }
1083
1084 async fn handle_client(
1086 stream: TcpStream,
1087 addr: SocketAddr,
1088 state: Arc<ServerState>,
1089 config: ProxyConfig,
1090 _shutdown_tx: broadcast::Sender<()>,
1091 ) -> Result<()> {
1092 tracing::debug!("New client connection from {}", addr);
1093
1094 let session = Arc::new(ClientSession {
1096 id: Uuid::new_v4(),
1097 client_addr: addr,
1098 current_node: RwLock::new(None),
1099 tx_state: RwLock::new(TransactionState::default()),
1100 variables: RwLock::new(HashMap::new()),
1101 created_at: chrono::Utc::now(),
1102 tr_mode: config.tr_mode,
1103 #[cfg(feature = "pool-modes")]
1104 pool_client_id: ClientId::new(),
1105 #[cfg(feature = "wasm-plugins")]
1106 plugin_identity: RwLock::new(None),
1107 });
1108
1109 {
1111 let mut sessions = state.sessions.write().await;
1112 sessions.insert(session.id, session.clone());
1113 }
1114
1115 let result = match Self::negotiate_client_tls(stream, &state).await {
1120 Ok((mut client_stream, pre)) => {
1121 Self::client_loop(&mut client_stream, pre, &session, &state, &config).await
1122 }
1123 Err(e) => Err(e),
1124 };
1125
1126 {
1128 let mut sessions = state.sessions.write().await;
1129 sessions.remove(&session.id);
1130 }
1131
1132 #[cfg(feature = "pool-modes")]
1134 if let Some(ref pool_manager) = state.pool_manager {
1135 if pool_manager.has_active_lease(&session.pool_client_id) {
1137 tracing::debug!(
1138 "Releasing pool lease for disconnecting client {:?}",
1139 session.pool_client_id
1140 );
1141 }
1144 }
1145
1146 state
1147 .metrics
1148 .connections_closed
1149 .fetch_add(1, Ordering::Relaxed);
1150
1151 result
1152 }
1153
1154 async fn client_loop(
1156 stream: &mut ClientStream,
1157 pre: Option<StartupMessage>,
1158 session: &Arc<ClientSession>,
1159 state: &Arc<ServerState>,
1160 config: &ProxyConfig,
1161 ) -> Result<()> {
1162 let codec = ProtocolCodec::new();
1163 let mut buffer = BytesMut::with_capacity(8192);
1164
1165 let mut conns: HashMap<String, BackendConn> = HashMap::new();
1176 let mut current_node: Option<String> =
1177 match Self::handle_startup(stream, &mut buffer, &codec, pre, session, state, config).await {
1178 Ok((Some(stream_conn), node_addr)) => {
1179 conns.insert(node_addr.clone(), BackendConn::new(stream_conn));
1180 Some(node_addr)
1181 }
1182 Ok((None, _)) => {
1183 return Ok(());
1185 }
1186 Err(e) => {
1187 tracing::error!("Startup failed: {}", e);
1188 let err_msg = Self::create_error_response("08006", &format!("Startup failed: {}", e));
1190 let _ = stream.write_all(&err_msg).await;
1191 return Err(e);
1192 }
1193 };
1194
1195 let mut read_buf = vec![0u8; 16384];
1210 let mut pending = BytesMut::new();
1211 let mut pending_route_sql: Option<String> = None;
1212 let mut stmt_registry: HashMap<String, bytes::Bytes> = HashMap::new();
1220 let mut batch_defines: Vec<String> = Vec::new();
1221 let mut batch_refs: Vec<String> = Vec::new();
1222 let mut batch_closes: Vec<String> = Vec::new();
1223 let promote_unnamed = config.optimize_unnamed_parse;
1229 let mut held_unnamed: Option<(bytes::Bytes, bytes::Bytes)> = None;
1230 loop {
1231 let n = stream
1233 .read(&mut read_buf)
1234 .await
1235 .map_err(|e| ProxyError::Network(format!("Read error: {}", e)))?;
1236
1237 if n == 0 {
1238 break;
1240 }
1241
1242 buffer.extend_from_slice(&read_buf[..n]);
1243 state.metrics.bytes_received.fetch_add(n as u64, Ordering::Relaxed);
1244
1245 while let Some(msg) = codec.decode_message(&mut buffer)? {
1247 match msg.msg_type {
1248 MessageType::Terminate => return Ok(()),
1249
1250 MessageType::Query => {
1252 #[cfg(feature = "anomaly-detection")]
1256 Self::record_anomaly_observation(&msg, state, session);
1257
1258 let (msg, action) = Self::apply_pre_query_hook(msg, state, session);
1261
1262 if let PreQueryAction::Block(reason) = &action {
1263 tracing::info!(reason = %reason, "pre-query plugin blocked query");
1264 Self::send_block_response(stream, reason, state).await?;
1265 state.metrics.queries_processed.fetch_add(1, Ordering::Relaxed);
1266 continue;
1267 }
1268
1269 #[cfg(feature = "wasm-plugins")]
1270 if let PreQueryAction::Cached(bytes) = &action {
1271 match Self::synthesise_cached_response(bytes) {
1272 Ok(reply) => {
1273 stream.write_all(&reply).await.map_err(|e| {
1274 ProxyError::Network(format!("Write error: {}", e))
1275 })?;
1276 state.metrics.bytes_sent.fetch_add(reply.len() as u64, Ordering::Relaxed);
1277 state.metrics.queries_processed.fetch_add(1, Ordering::Relaxed);
1278 continue;
1279 }
1280 Err(e) => {
1281 tracing::warn!(error = %e, "failed to synthesise cached response; falling back to backend");
1282 }
1283 }
1284 }
1285
1286 if let Some(ref mirror) = state.mirror {
1290 if let Some(sql) = crate::protocol::query_text(&msg.payload) {
1291 mirror.offer(sql, Self::is_write_query(sql));
1292 }
1293 }
1294
1295 #[cfg(feature = "wasm-plugins")]
1296 let forward_start = std::time::Instant::now();
1297 let fr = Self::forward_simple_query(
1298 stream,
1299 &msg,
1300 &mut conns,
1301 current_node.as_deref(),
1302 session,
1303 state,
1304 config,
1305 )
1306 .await;
1307 #[cfg(feature = "wasm-plugins")]
1308 Self::fire_post_query_hook(&msg, session, state, &fr, forward_start.elapsed());
1309 let (used_node, sent) = fr?;
1310 if let Some(n) = used_node {
1311 current_node = Some(n);
1312 }
1313 state.metrics.bytes_sent.fetch_add(sent, Ordering::Relaxed);
1314 state.metrics.queries_processed.fetch_add(1, Ordering::Relaxed);
1315 }
1316
1317 MessageType::Parse
1319 | MessageType::Bind
1320 | MessageType::Describe
1321 | MessageType::Execute
1322 | MessageType::Close => {
1323 let mut add_to_pending = true;
1327 match msg.msg_type {
1328 MessageType::Parse => {
1329 let name = Self::parse_stmt_name(&msg.payload);
1333 let unnamed = name.is_empty();
1334 if !unnamed {
1335 let name = name.to_string();
1336 stmt_registry.insert(name.clone(), msg.encode().freeze());
1337 batch_defines.push(name);
1338 }
1339 if pending_route_sql.is_none() {
1340 if let Some(end) = msg.payload.iter().position(|&b| b == 0) {
1341 if let Some(q) =
1342 crate::protocol::query_text(&msg.payload[end + 1..])
1343 {
1344 if !q.is_empty() {
1345 pending_route_sql = Some(q.to_string());
1346 #[cfg(feature = "anomaly-detection")]
1347 Self::record_anomaly_sql(q, state, session);
1348 }
1349 }
1350 }
1351 }
1352 if promote_unnamed
1359 && unnamed
1360 && pending.is_empty()
1361 && held_unnamed.is_none()
1362 {
1363 let sig = bytes::Bytes::copy_from_slice(&msg.payload[1..]);
1364 held_unnamed = Some((msg.encode().freeze(), sig));
1365 add_to_pending = false;
1366 } else if let Some((held_msg, _)) = held_unnamed.take() {
1367 let mut combined = BytesMut::with_capacity(held_msg.len() + pending.len());
1368 combined.extend_from_slice(&held_msg);
1369 combined.extend_from_slice(&pending);
1370 pending = combined;
1371 }
1372 }
1373 MessageType::Bind => {
1374 if let Some(name) = Self::bind_stmt_ref(&msg.payload) {
1375 batch_refs.push(name.to_string());
1376 }
1377 }
1378 MessageType::Describe => {
1379 if let Some(name) = Self::stmt_kind_name(&msg.payload) {
1380 batch_refs.push(name.to_string());
1381 }
1382 }
1383 MessageType::Close => {
1384 if let Some(name) = Self::stmt_kind_name(&msg.payload) {
1385 batch_closes.push(name.to_string());
1386 }
1387 }
1388 _ => {}
1389 }
1390 if add_to_pending {
1391 pending.extend_from_slice(&msg.encode());
1392 }
1393 }
1394
1395 MessageType::Sync | MessageType::Flush => {
1397 let wait_ready = msg.msg_type == MessageType::Sync;
1398 pending.extend_from_slice(&msg.encode());
1399 let batch = pending.split().freeze();
1400 let reprepare: Vec<String> = batch_refs
1404 .iter()
1405 .filter(|r| !batch_defines.contains(r))
1406 .cloned()
1407 .collect();
1408 let (used_node, sent) = Self::forward_extended_batch(
1409 stream,
1410 &batch,
1411 pending_route_sql.as_deref(),
1412 wait_ready,
1413 &mut conns,
1414 current_node.as_deref(),
1415 &stmt_registry,
1416 &reprepare,
1417 &batch_defines,
1418 held_unnamed.take(),
1419 session,
1420 state,
1421 config,
1422 )
1423 .await?;
1424 if let Some(n) = used_node {
1425 current_node = Some(n);
1426 }
1427 state.metrics.bytes_sent.fetch_add(sent, Ordering::Relaxed);
1428 for name in batch_closes.drain(..) {
1431 stmt_registry.remove(&name);
1432 }
1433 if wait_ready {
1434 pending_route_sql = None;
1438 batch_defines.clear();
1439 batch_refs.clear();
1440 state.metrics.queries_processed.fetch_add(1, Ordering::Relaxed);
1441 }
1442 }
1443
1444 MessageType::CopyData | MessageType::CopyDone | MessageType::CopyFail => {
1446 if let Some(node) = current_node.clone() {
1447 if let Some(b) = conns.get_mut(&node) {
1448 b.stream.write_all(&msg.encode()).await.map_err(|e| {
1449 ProxyError::Network(format!("Backend copy write error: {}", e))
1450 })?;
1451 if matches!(msg.msg_type, MessageType::CopyDone | MessageType::CopyFail) {
1452 let r = Self::stream_until_ready(stream, &mut b.stream, session, state).await;
1453 match r {
1454 Ok(sent) => {
1455 state.metrics.bytes_sent.fetch_add(sent, Ordering::Relaxed);
1456 }
1457 Err(e) => {
1458 conns.remove(&node);
1459 return Err(e);
1460 }
1461 }
1462 }
1463 }
1464 }
1465 }
1466
1467 _ => {
1469 if let Some(ref node) = current_node {
1470 if let Some(b) = conns.get_mut(node) {
1471 let _ = b.stream.write_all(&msg.encode()).await;
1472 }
1473 }
1474 }
1475 }
1476 }
1477 }
1478
1479 Ok(())
1480 }
1481
1482 async fn negotiate_client_tls(
1489 mut tcp: TcpStream,
1490 state: &Arc<ServerState>,
1491 ) -> Result<(ClientStream, Option<StartupMessage>)> {
1492 let codec = ProtocolCodec::new();
1493 let mut buffer = BytesMut::with_capacity(1024);
1494 let mut read_buf = vec![0u8; 1024];
1495
1496 let first = loop {
1497 if let Some(msg) = codec.decode_startup(&mut buffer)? {
1498 break msg;
1499 }
1500 let n = tcp
1501 .read(&mut read_buf)
1502 .await
1503 .map_err(|e| ProxyError::Network(format!("Startup read error: {}", e)))?;
1504 if n == 0 {
1505 return Err(ProxyError::Connection("client closed before startup".to_string()));
1506 }
1507 buffer.extend_from_slice(&read_buf[..n]);
1508 };
1509
1510 match first {
1511 StartupMessage::SSLRequest => match state.tls_acceptor.as_ref() {
1512 Some(acceptor) => {
1513 tcp.write_all(&[b'S'])
1514 .await
1515 .map_err(|e| ProxyError::Network(format!("SSL accept write: {}", e)))?;
1516 let tls = acceptor
1517 .accept(tcp)
1518 .await
1519 .map_err(|e| ProxyError::Network(format!("TLS handshake failed: {}", e)))?;
1520 if tls.get_ref().1.peer_certificates().is_some() {
1521 tracing::debug!("client presented a certificate (mTLS)");
1522 }
1523 Ok((ClientStream::Tls(Box::new(tls)), None))
1524 }
1525 None => {
1526 tcp.write_all(&[b'N'])
1527 .await
1528 .map_err(|e| ProxyError::Network(format!("SSL reject write: {}", e)))?;
1529 Ok((ClientStream::Plain(tcp), None))
1530 }
1531 },
1532 other => Ok((ClientStream::Plain(tcp), Some(other))),
1533 }
1534 }
1535
1536 async fn handle_startup(
1540 client_stream: &mut ClientStream,
1541 buffer: &mut BytesMut,
1542 codec: &ProtocolCodec,
1543 pre: Option<StartupMessage>,
1544 session: &Arc<ClientSession>,
1545 state: &Arc<ServerState>,
1546 config: &ProxyConfig,
1547 ) -> Result<(Option<TcpStream>, String)> {
1548 let startup_msg = match pre {
1551 Some(msg) => Some(msg),
1552 None => {
1553 let mut read_buf = vec![0u8; 1024];
1554 loop {
1555 if let Some(msg) = codec.decode_startup(buffer)? {
1556 break Some(msg);
1557 }
1558 let n = client_stream
1559 .read(&mut read_buf)
1560 .await
1561 .map_err(|e| ProxyError::Network(format!("Startup read error: {}", e)))?;
1562 if n == 0 {
1563 return Ok((None, String::new()));
1564 }
1565 buffer.extend_from_slice(&read_buf[..n]);
1566 }
1567 }
1568 };
1569
1570 match startup_msg {
1571 Some(StartupMessage::SSLRequest) => {
1572 client_stream
1575 .write_all(&[b'N'])
1576 .await
1577 .map_err(|e| ProxyError::Network(format!("SSL reject error: {}", e)))?;
1578 Err(ProxyError::Protocol("unexpected SSLRequest after startup".to_string()))
1579 }
1580 Some(StartupMessage::CancelRequest { pid, key }) => {
1581 Self::forward_cancel_request(state, pid, key).await;
1584 Ok((None, String::new()))
1585 }
1586 Some(StartupMessage::Startup { params, .. }) => {
1587 Self::connect_and_authenticate(client_stream, ¶ms, session, state, config).await
1588 }
1589 None => Err(ProxyError::Protocol("Incomplete startup message".to_string())),
1590 }
1591 }
1592
1593 fn hba_admits(rules: &[HbaRule], ip: std::net::IpAddr, user: &str, database: &str) -> bool {
1596 for r in rules {
1597 let user_ok = r.user == "all" || r.user == user;
1598 let db_ok = r.database == "all" || r.database == database;
1599 if user_ok && db_ok && Self::hba_addr_matches(&r.address, ip) {
1600 return r.action == HbaAction::Allow;
1601 }
1602 }
1603 true
1604 }
1605
1606 fn hba_addr_matches(spec: &str, ip: std::net::IpAddr) -> bool {
1609 use std::net::IpAddr;
1610 if spec == "all" {
1611 return true;
1612 }
1613 if let Some((net, bits)) = spec.split_once('/') {
1614 let bits: u32 = match bits.parse() {
1615 Ok(b) => b,
1616 Err(_) => return false,
1617 };
1618 match (net.parse::<IpAddr>(), ip) {
1619 (Ok(IpAddr::V4(n)), IpAddr::V4(i)) if bits <= 32 => {
1620 let mask = if bits == 0 { 0 } else { u32::MAX << (32 - bits) };
1621 (u32::from(n) & mask) == (u32::from(i) & mask)
1622 }
1623 (Ok(IpAddr::V6(n)), IpAddr::V6(i)) if bits <= 128 => {
1624 let mask = if bits == 0 { 0 } else { u128::MAX << (128 - bits) };
1625 (u128::from(n) & mask) == (u128::from(i) & mask)
1626 }
1627 _ => false,
1628 }
1629 } else {
1630 spec.parse::<IpAddr>().map(|s| s == ip).unwrap_or(false)
1631 }
1632 }
1633
1634 async fn proxy_scram_auth(
1640 client: &mut ClientStream,
1641 user: &str,
1642 state: &Arc<ServerState>,
1643 ) -> std::result::Result<(), String> {
1644 use crate::auth_scram::ScramServer;
1645 let auth_file = state.auth_file.as_ref().ok_or("scram not configured")?;
1646
1647 let mut sasl = BytesMut::new();
1649 sasl.put_i32(10); sasl.extend_from_slice(b"SCRAM-SHA-256\0");
1651 sasl.put_u8(0); Self::write_auth_frame(client, &sasl).await?;
1653
1654 let init = Self::read_password_message(client).await?;
1656 let mech_end = init
1657 .iter()
1658 .position(|&b| b == 0)
1659 .ok_or("malformed SASLInitialResponse (no mechanism)")?;
1660 if init.len() < mech_end + 5 {
1661 return Err("short SASLInitialResponse".into());
1662 }
1663 let client_first = std::str::from_utf8(&init[mech_end + 5..])
1664 .map_err(|_| "client-first not UTF-8")?;
1665
1666 let verifier = auth_file
1668 .get(user)
1669 .ok_or("no such user")?
1670 .clone();
1671
1672 let server_nonce = Self::random_nonce();
1674 let (server, server_first) = ScramServer::start(verifier, client_first, &server_nonce)?;
1675
1676 let mut cont = BytesMut::new();
1678 cont.put_i32(11);
1679 cont.extend_from_slice(server_first.as_bytes());
1680 Self::write_auth_frame(client, &cont).await?;
1681
1682 let client_final_raw = Self::read_password_message(client).await?;
1684 let client_final = std::str::from_utf8(&client_final_raw)
1685 .map_err(|_| "client-final not UTF-8")?;
1686
1687 let server_final = server.finish(client_final)?;
1689
1690 let mut fin = BytesMut::new();
1692 fin.put_i32(12);
1693 fin.extend_from_slice(server_final.as_bytes());
1694 Self::write_auth_frame(client, &fin).await?;
1695 Ok(())
1696 }
1697
1698 async fn write_auth_frame(
1700 client: &mut ClientStream,
1701 payload: &[u8],
1702 ) -> std::result::Result<(), String> {
1703 let mut frame = BytesMut::with_capacity(payload.len() + 5);
1704 frame.put_u8(b'R');
1705 frame.put_u32((payload.len() + 4) as u32);
1706 frame.extend_from_slice(payload);
1707 client
1708 .write_all(&frame)
1709 .await
1710 .map_err(|e| format!("client write: {}", e))
1711 }
1712
1713 async fn read_password_message(
1716 client: &mut ClientStream,
1717 ) -> std::result::Result<BytesMut, String> {
1718 let codec = ProtocolCodec::new();
1719 let mut buffer = BytesMut::with_capacity(1024);
1720 let mut read_buf = vec![0u8; 1024];
1721 loop {
1722 if let Some(msg) = codec
1723 .decode_message(&mut buffer)
1724 .map_err(|e| format!("decode: {}", e))?
1725 {
1726 if msg.msg_type == MessageType::Password {
1727 return Ok(msg.payload);
1728 }
1729 return Err(format!("expected SASL response, got {:?}", msg.msg_type));
1730 }
1731 let n = client
1732 .read(&mut read_buf)
1733 .await
1734 .map_err(|e| format!("client read: {}", e))?;
1735 if n == 0 {
1736 return Err("client closed during SASL".into());
1737 }
1738 buffer.extend_from_slice(&read_buf[..n]);
1739 }
1740 }
1741
1742 fn random_nonce() -> String {
1744 use rand::Rng;
1745 const CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
1746 let mut rng = rand::thread_rng();
1747 (0..24).map(|_| CHARS[rng.gen_range(0..CHARS.len())] as char).collect()
1748 }
1749
1750 async fn connect_and_authenticate(
1752 client_stream: &mut ClientStream,
1753 params: &HashMap<String, String>,
1754 session: &Arc<ClientSession>,
1755 state: &Arc<ServerState>,
1756 config: &ProxyConfig,
1757 ) -> Result<(Option<TcpStream>, String)> {
1758 let user = params.get("user").map(String::as_str).unwrap_or("");
1761 let database = params.get("database").map(String::as_str).unwrap_or(user);
1762 if !Self::hba_admits(&config.hba, session.client_addr.ip(), user, database) {
1763 tracing::info!(%user, %database, client = %session.client_addr, "connection rejected by hba rule");
1764 let err = Self::create_error_response(
1765 "28000",
1766 "connection rejected by proxy admission rules",
1767 );
1768 let _ = client_stream.write_all(&err).await;
1769 return Ok((None, String::new()));
1770 }
1771
1772 if state.auth_file.is_some() {
1778 if let Err(e) = Self::proxy_scram_auth(client_stream, user, state).await {
1779 tracing::info!(%user, error = %e, "proxy SCRAM auth failed");
1780 let err = Self::create_error_response("28P01", &format!("authentication failed: {}", e));
1781 let _ = client_stream.write_all(&err).await;
1782 return Ok((None, String::new()));
1783 }
1784 tracing::debug!(%user, "client authenticated by proxy SCRAM");
1785 }
1786
1787 Self::apply_authenticate_hook(params, session, state).await?;
1793
1794 let cutover = state.cutover.load_full();
1798 let (node_addr, effective_params) = if let Some(t) = cutover.as_ref() {
1799 let mut p = params.clone();
1800 p.insert("user".to_string(), t.user.clone());
1801 if let Some(ref db) = t.database {
1802 p.insert("database".to_string(), db.clone());
1803 } else {
1804 p.remove("database");
1805 }
1806 tracing::debug!(target = %t.addr, "routing connection to cutover target");
1807 (t.addr.clone(), p)
1808 } else {
1809 (Self::select_node(session, state, config).await?, params.clone())
1810 };
1811
1812 let mut backend = tokio::time::timeout(
1814 config.pool.acquire_timeout(),
1815 TcpStream::connect(&node_addr),
1816 )
1817 .await
1818 .map_err(|_| ProxyError::Connection(format!("Connection timeout to {}", node_addr)))?
1819 .map_err(|e| ProxyError::Connection(format!("Failed to connect to {}: {}", node_addr, e)))?;
1820 let _ = backend.set_nodelay(true);
1821
1822 let params = &effective_params;
1824 let startup_bytes = Self::build_startup_message(params);
1825 backend
1826 .write_all(&startup_bytes)
1827 .await
1828 .map_err(|e| ProxyError::Network(format!("Backend startup write error: {}", e)))?;
1829
1830 Self::proxy_authentication(client_stream, &mut backend, state, &node_addr).await?;
1834
1835 {
1837 let mut vars = session.variables.write().await;
1838 for (k, v) in params {
1839 vars.insert(k.clone(), v.clone());
1840 }
1841 }
1842
1843 Ok((Some(backend), node_addr))
1844 }
1845
1846 fn build_startup_message(params: &HashMap<String, String>) -> Vec<u8> {
1848 let mut payload = BytesMut::new();
1849
1850 payload.put_u32(196608);
1852
1853 for (key, value) in params {
1855 payload.extend_from_slice(key.as_bytes());
1856 payload.put_u8(0);
1857 payload.extend_from_slice(value.as_bytes());
1858 payload.put_u8(0);
1859 }
1860 payload.put_u8(0); let mut msg = BytesMut::new();
1864 msg.put_u32((payload.len() + 4) as u32);
1865 msg.extend_from_slice(&payload);
1866
1867 msg.to_vec()
1868 }
1869
1870 const MAX_CANCEL_KEYS: usize = 100_000;
1873
1874 fn register_cancel_key(state: &Arc<ServerState>, pid: u32, key: u32, node_addr: &str) {
1876 if state.cancel_map.len() >= Self::MAX_CANCEL_KEYS {
1877 state.cancel_map.clear();
1878 }
1879 state.cancel_map.insert((pid, key), node_addr.to_string());
1880 }
1881
1882 async fn forward_cancel_request(state: &Arc<ServerState>, pid: u32, key: u32) {
1885 let Some(addr) = state.cancel_map.get(&(pid, key)).map(|e| e.clone()) else {
1886 tracing::debug!(pid, "cancel request for unknown key; ignoring");
1887 return;
1888 };
1889 let mut msg = BytesMut::with_capacity(16);
1891 msg.put_u32(16);
1892 msg.put_u32(80877102);
1893 msg.put_u32(pid);
1894 msg.put_u32(key);
1895 match tokio::time::timeout(Duration::from_secs(5), TcpStream::connect(&addr)).await {
1896 Ok(Ok(mut conn)) => {
1897 let _ = conn.set_nodelay(true);
1898 if let Err(e) = conn.write_all(&msg).await {
1899 tracing::warn!(node = %addr, error = %e, "failed to forward CancelRequest");
1900 }
1901 }
1903 other => tracing::warn!(node = %addr, ?other, "could not connect to forward CancelRequest"),
1904 }
1905 }
1906
1907 async fn proxy_authentication(
1909 client_stream: &mut ClientStream,
1910 backend_stream: &mut TcpStream,
1911 state: &Arc<ServerState>,
1912 node_addr: &str,
1913 ) -> Result<()> {
1914 let codec = ProtocolCodec::new();
1915 let mut backend_buffer = BytesMut::with_capacity(4096);
1916 let mut client_buffer = BytesMut::with_capacity(4096);
1917 let mut read_buf = vec![0u8; 4096];
1918
1919 loop {
1920 let n = backend_stream
1922 .read(&mut read_buf)
1923 .await
1924 .map_err(|e| ProxyError::Network(format!("Backend auth read error: {}", e)))?;
1925
1926 if n == 0 {
1927 return Err(ProxyError::Connection("Backend closed during auth".to_string()));
1928 }
1929
1930 backend_buffer.extend_from_slice(&read_buf[..n]);
1931
1932 client_stream
1934 .write_all(&read_buf[..n])
1935 .await
1936 .map_err(|e| ProxyError::Network(format!("Client auth write error: {}", e)))?;
1937
1938 while let Some(msg) = codec.decode_message(&mut backend_buffer)? {
1942 match msg.msg_type {
1943 MessageType::BackendKeyData => {
1944 if msg.payload.len() >= 8 {
1948 let pid = u32::from_be_bytes([
1949 msg.payload[0], msg.payload[1], msg.payload[2], msg.payload[3],
1950 ]);
1951 let key = u32::from_be_bytes([
1952 msg.payload[4], msg.payload[5], msg.payload[6], msg.payload[7],
1953 ]);
1954 Self::register_cancel_key(state, pid, key, node_addr);
1955 }
1956 }
1957 MessageType::AuthRequest => {
1958 if msg.payload.len() >= 4 {
1960 let auth_type =
1961 i32::from_be_bytes([msg.payload[0], msg.payload[1], msg.payload[2], msg.payload[3]]);
1962 if auth_type == 0 {
1963 }
1965 }
1966 }
1967 MessageType::ReadyForQuery => {
1968 return Ok(());
1970 }
1971 MessageType::ErrorResponse => {
1972 return Err(ProxyError::Auth("Authentication failed".to_string()));
1974 }
1975 _ => {
1976 }
1978 }
1979 }
1980
1981 let n = tokio::time::timeout(Duration::from_millis(100), client_stream.read(&mut read_buf))
1984 .await;
1985
1986 if let Ok(Ok(n)) = n {
1987 if n > 0 {
1988 client_buffer.extend_from_slice(&read_buf[..n]);
1989 backend_stream
1990 .write_all(&read_buf[..n])
1991 .await
1992 .map_err(|e| ProxyError::Network(format!("Backend password write error: {}", e)))?;
1993 }
1994 }
1995 }
1996 }
1997
1998 async fn choose_target_node(
2003 is_write: bool,
2004 forced_target: Option<String>,
2005 current_node: Option<&str>,
2006 session: &Arc<ClientSession>,
2007 state: &Arc<ServerState>,
2008 config: &ProxyConfig,
2009 ) -> Result<String> {
2010 if let Some(t) = state.cutover.load_full().as_ref() {
2013 return Ok(t.addr.clone());
2014 }
2015 let need_switch = if let Some(ref forced) = forced_target {
2016 let health = state.health.load_full();
2017 let reuse = current_node
2018 .map(|c| c == forced && health.get(c).map(|h| h.healthy).unwrap_or(false))
2019 .unwrap_or(false);
2020 !reuse
2021 } else if let Some(current) = current_node {
2022 let health = state.health.load_full();
2023 let current_healthy = health.get(current).map(|h| h.healthy).unwrap_or(false);
2024 if !current_healthy {
2025 true
2026 } else if is_write {
2027 let is_primary = config
2028 .nodes
2029 .iter()
2030 .find(|n| n.address() == *current)
2031 .map(|n| n.role == NodeRole::Primary)
2032 .unwrap_or(false);
2033 !is_primary
2034 } else {
2035 false
2036 }
2037 } else {
2038 true
2039 };
2040
2041 if let Some(forced) = forced_target {
2042 Ok(forced)
2043 } else if need_switch {
2044 if is_write {
2045 Self::select_primary_with_timeout(session, state, config).await
2046 } else {
2047 Self::select_read_node(session, state, config).await
2048 }
2049 } else {
2050 Ok(current_node.unwrap().to_string())
2051 }
2052 }
2053
2054 async fn ensure_conn(
2059 conns: &mut HashMap<String, BackendConn>,
2060 target: &str,
2061 session: &Arc<ClientSession>,
2062 config: &ProxyConfig,
2063 ) -> Result<()> {
2064 if conns.contains_key(target) {
2065 return Ok(());
2066 }
2067 let mut backend = tokio::time::timeout(
2068 config.pool.acquire_timeout(),
2069 TcpStream::connect(target),
2070 )
2071 .await
2072 .map_err(|_| ProxyError::Connection(format!("Connection timeout to {}", target)))?
2073 .map_err(|e| ProxyError::Connection(format!("Failed to connect to {}: {}", target, e)))?;
2074 let _ = backend.set_nodelay(true);
2075
2076 let params = session.variables.read().await.clone();
2077 let startup = Self::build_startup_message(¶ms);
2078 backend
2079 .write_all(&startup)
2080 .await
2081 .map_err(|e| ProxyError::Network(format!("Backend startup error: {}", e)))?;
2082 Self::complete_backend_auth(&mut backend).await?;
2083 tracing::debug!(node = %target, "opened backend connection");
2084 conns.insert(target.to_string(), BackendConn::new(backend));
2085 Ok(())
2086 }
2087
2088 async fn forward_simple_query(
2094 client: &mut ClientStream,
2095 msg: &Message,
2096 conns: &mut HashMap<String, BackendConn>,
2097 current_node: Option<&str>,
2098 session: &Arc<ClientSession>,
2099 state: &Arc<ServerState>,
2100 config: &ProxyConfig,
2101 ) -> Result<(Option<String>, u64)> {
2102 let default_is_write = Self::is_write_message(msg);
2103 let route_override = Self::apply_route_hook(msg, state, session);
2104
2105 if let RouteOverride::Block(reason) = route_override {
2107 let mut response = Vec::with_capacity(64 + reason.len());
2108 response.extend_from_slice(&Self::create_error_response(
2109 "42000",
2110 &format!("Query blocked by route plugin: {}", reason),
2111 ));
2112 response.extend_from_slice(&Self::create_ready_for_query(b'I'));
2113 client
2114 .write_all(&response)
2115 .await
2116 .map_err(|e| ProxyError::Network(format!("Client write error: {}", e)))?;
2117 return Ok((None, response.len() as u64));
2118 }
2119
2120 let (is_write, forced_target) = match route_override {
2121 RouteOverride::None => (default_is_write, None),
2122 RouteOverride::Primary => (true, None),
2123 RouteOverride::Standby => (false, None),
2124 RouteOverride::Node(name) => (default_is_write, Some(name)),
2125 RouteOverride::Block(_) => unreachable!("handled above"),
2126 };
2127
2128 let target =
2129 Self::choose_target_node(is_write, forced_target, current_node, session, state, config)
2130 .await?;
2131 Self::ensure_conn(conns, &target, session, config).await?;
2132 let backend = conns.get_mut(&target).expect("just ensured");
2133
2134 backend
2135 .stream
2136 .write_all(&msg.encode())
2137 .await
2138 .map_err(|e| ProxyError::Network(format!("Backend write error: {}", e)))?;
2139
2140 match Self::stream_until_ready(client, &mut backend.stream, session, state).await {
2141 Ok(sent) => Ok((Some(target), sent)),
2142 Err(e) => {
2143 conns.remove(&target);
2145 Err(e)
2146 }
2147 }
2148 }
2149
2150 #[allow(clippy::too_many_arguments)]
2163 async fn forward_extended_batch(
2164 client: &mut ClientStream,
2165 batch: &[u8],
2166 route_sql: Option<&str>,
2167 wait_ready: bool,
2168 conns: &mut HashMap<String, BackendConn>,
2169 current_node: Option<&str>,
2170 registry: &HashMap<String, bytes::Bytes>,
2171 reprepare: &[String],
2172 defines: &[String],
2173 unnamed: Option<(bytes::Bytes, bytes::Bytes)>,
2174 session: &Arc<ClientSession>,
2175 state: &Arc<ServerState>,
2176 config: &ProxyConfig,
2177 ) -> Result<(Option<String>, u64)> {
2178 let target = match route_sql {
2179 Some(sql) => {
2180 let is_write = Self::is_write_query(sql);
2181 Self::choose_target_node(is_write, None, current_node, session, state, config)
2182 .await?
2183 }
2184 None => match current_node {
2188 Some(c) => c.to_string(),
2189 None => Self::select_read_node(session, state, config).await?,
2190 },
2191 };
2192
2193 Self::ensure_conn(conns, &target, session, config).await?;
2194 let backend = conns.get_mut(&target).expect("just ensured");
2195
2196 for name in reprepare {
2201 if backend.prepared.contains(name) {
2202 continue;
2203 }
2204 let Some(parse_bytes) = registry.get(name) else {
2205 continue; };
2207 match Self::reprepare_statement(&mut backend.stream, parse_bytes).await {
2208 Ok(()) => {
2209 backend.prepared.insert(name.clone());
2210 }
2211 Err(e) => {
2212 conns.remove(&target);
2213 return Err(e);
2214 }
2215 }
2216 }
2217
2218 let mut inject_parse_complete = false;
2225 let mut new_unnamed_sig: Option<bytes::Bytes> = None;
2226 if let Some((parse_msg, sig)) = unnamed.as_ref() {
2227 if backend.unnamed_sig.as_deref() == Some(&sig[..]) {
2228 inject_parse_complete = true;
2229 } else {
2230 if let Err(e) = backend
2231 .stream
2232 .write_all(parse_msg)
2233 .await
2234 .map_err(|e| ProxyError::Network(format!("Backend write error: {}", e)))
2235 {
2236 conns.remove(&target);
2237 return Err(e);
2238 }
2239 new_unnamed_sig = Some(sig.clone());
2240 }
2241 }
2242
2243 if let Err(e) = backend
2244 .stream
2245 .write_all(batch)
2246 .await
2247 .map_err(|e| ProxyError::Network(format!("Backend write error: {}", e)))
2248 {
2249 conns.remove(&target);
2250 return Err(e);
2251 }
2252
2253 let mut injected: u64 = 0;
2256 if inject_parse_complete {
2257 if let Err(e) = client
2258 .write_all(&[b'1', 0, 0, 0, 4])
2259 .await
2260 .map_err(|e| ProxyError::Network(format!("Client write error: {}", e)))
2261 {
2262 conns.remove(&target);
2263 return Err(e);
2264 }
2265 injected = 5;
2266 }
2267
2268 let r = if wait_ready {
2269 Self::stream_until_ready(client, &mut backend.stream, session, state).await
2270 } else {
2271 Self::stream_flush(client, &mut backend.stream, session, state).await
2272 };
2273 match r {
2274 Ok(sent) => {
2275 for name in defines {
2277 backend.prepared.insert(name.clone());
2278 }
2279 if let Some(sig) = new_unnamed_sig {
2281 backend.unnamed_sig = Some(sig);
2282 }
2283 Ok((Some(target), sent + injected))
2284 }
2285 Err(e) => {
2286 conns.remove(&target);
2287 Err(e)
2288 }
2289 }
2290 }
2291
2292 async fn reprepare_statement<S: AsyncReadExt + AsyncWriteExt + Unpin>(
2298 backend: &mut S,
2299 parse_bytes: &[u8],
2300 ) -> Result<()> {
2301 backend
2302 .write_all(parse_bytes)
2303 .await
2304 .map_err(|e| ProxyError::Network(format!("re-prepare write error: {}", e)))?;
2305 backend
2307 .write_all(&[b'H', 0, 0, 0, 4])
2308 .await
2309 .map_err(|e| ProxyError::Network(format!("re-prepare flush error: {}", e)))?;
2310 let mtype = Self::read_one_frame_type(backend).await?;
2311 match mtype {
2312 b'1' => Ok(()), b'E' => Err(ProxyError::Protocol("re-prepare rejected by backend".to_string())),
2314 other => Err(ProxyError::Protocol(format!(
2315 "unexpected re-prepare reply: {}",
2316 other as char
2317 ))),
2318 }
2319 }
2320
2321 async fn read_one_frame_type<S: AsyncReadExt + Unpin>(backend: &mut S) -> Result<u8> {
2325 let mut header = [0u8; 5];
2326 backend
2327 .read_exact(&mut header)
2328 .await
2329 .map_err(|e| ProxyError::Network(format!("re-prepare read error: {}", e)))?;
2330 let len = u32::from_be_bytes([header[1], header[2], header[3], header[4]]) as usize;
2331 let body_len = len.saturating_sub(4);
2332 if body_len > 0 {
2333 let mut body = vec![0u8; body_len];
2334 backend
2335 .read_exact(&mut body)
2336 .await
2337 .map_err(|e| ProxyError::Network(format!("re-prepare body read error: {}", e)))?;
2338 }
2339 Ok(header[0])
2340 }
2341
2342 fn parse_stmt_name(payload: &[u8]) -> &str {
2345 let end = payload.iter().position(|&b| b == 0).unwrap_or(0);
2346 std::str::from_utf8(&payload[..end]).unwrap_or("")
2347 }
2348
2349 fn bind_stmt_ref(payload: &[u8]) -> Option<&str> {
2353 let portal_end = payload.iter().position(|&b| b == 0)?;
2354 let rest = &payload[portal_end + 1..];
2355 let stmt_end = rest.iter().position(|&b| b == 0)?;
2356 let name = std::str::from_utf8(&rest[..stmt_end]).ok()?;
2357 (!name.is_empty()).then_some(name)
2358 }
2359
2360 fn stmt_kind_name(payload: &[u8]) -> Option<&str> {
2363 if payload.first() != Some(&b'S') {
2364 return None;
2365 }
2366 let rest = &payload[1..];
2367 let end = rest.iter().position(|&b| b == 0)?;
2368 let name = std::str::from_utf8(&rest[..end]).ok()?;
2369 (!name.is_empty()).then_some(name)
2370 }
2371
2372 async fn stream_until_ready(
2380 client: &mut ClientStream,
2381 backend: &mut TcpStream,
2382 session: &Arc<ClientSession>,
2383 state: &Arc<ServerState>,
2384 ) -> Result<u64> {
2385 let _ = state;
2386 let mut buf = BytesMut::with_capacity(16384);
2387 let mut read_buf = vec![0u8; 16384];
2388 let mut sent: u64 = 0;
2389
2390 loop {
2391 let mut consumed = 0usize;
2393 let mut ready_status: Option<u8> = None;
2394 let mut yield_for_copy = false;
2395 loop {
2396 let rem = &buf[consumed..];
2397 if rem.len() < 5 {
2398 break;
2399 }
2400 let len = u32::from_be_bytes([rem[1], rem[2], rem[3], rem[4]]) as usize;
2401 if len < 4 || rem.len() < len + 1 {
2402 break; }
2404 let frame_total = len + 1;
2405 let mtype = rem[0];
2406 consumed += frame_total;
2407 if mtype == b'Z' {
2408 ready_status = Some(if frame_total >= 6 { rem[5] } else { b'I' });
2410 break;
2411 }
2412 if mtype == b'G' || mtype == b'W' {
2413 yield_for_copy = true;
2416 break;
2417 }
2418 }
2419
2420 if consumed > 0 {
2421 client
2422 .write_all(&buf[..consumed])
2423 .await
2424 .map_err(|e| ProxyError::Network(format!("Client write error: {}", e)))?;
2425 sent += consumed as u64;
2426 let _ = buf.split_to(consumed);
2427 }
2428
2429 if let Some(status) = ready_status {
2430 let st = TransactionStatus::from_byte(status);
2431 let mut tx = session.tx_state.write().await;
2432 tx.in_transaction = st != TransactionStatus::Idle;
2433 return Ok(sent);
2434 }
2435 if yield_for_copy {
2436 return Ok(sent);
2437 }
2438
2439 let n = tokio::time::timeout(Duration::from_secs(30), backend.read(&mut read_buf))
2440 .await
2441 .map_err(|_| ProxyError::Network("Backend read timeout".to_string()))?
2442 .map_err(|e| ProxyError::Network(format!("Backend read error: {}", e)))?;
2443 if n == 0 {
2444 return Err(ProxyError::Connection("Backend closed mid-response".to_string()));
2445 }
2446 buf.extend_from_slice(&read_buf[..n]);
2447 }
2448 }
2449
2450 async fn stream_flush(
2456 client: &mut ClientStream,
2457 backend: &mut TcpStream,
2458 session: &Arc<ClientSession>,
2459 state: &Arc<ServerState>,
2460 ) -> Result<u64> {
2461 let _ = (session, state);
2462 let mut read_buf = vec![0u8; 16384];
2463 let mut sent: u64 = 0;
2464 loop {
2465 match tokio::time::timeout(Duration::from_millis(200), backend.read(&mut read_buf)).await
2466 {
2467 Ok(Ok(0)) => return Err(ProxyError::Connection("Backend closed mid-flush".to_string())),
2468 Ok(Ok(n)) => {
2469 client
2470 .write_all(&read_buf[..n])
2471 .await
2472 .map_err(|e| ProxyError::Network(format!("Client write error: {}", e)))?;
2473 sent += n as u64;
2474 }
2475 Ok(Err(e)) => return Err(ProxyError::Network(format!("Backend read error: {}", e))),
2476 Err(_) => return Ok(sent), }
2478 }
2479 }
2480
2481 fn is_write_message(msg: &Message) -> bool {
2483 match msg.msg_type {
2484 MessageType::Query => {
2485 crate::protocol::query_text(&msg.payload)
2489 .map(Self::is_write_query)
2490 .unwrap_or(false)
2491 }
2492 MessageType::Parse => {
2493 msg.payload
2496 .iter()
2497 .position(|&b| b == 0)
2498 .and_then(|end| crate::protocol::query_text(&msg.payload[end + 1..]))
2499 .map(Self::is_write_query)
2500 .unwrap_or(false)
2501 }
2502 _ => false,
2504 }
2505 }
2506
2507 fn is_write_query(sql: &str) -> bool {
2509 use crate::protocol::starts_with_ci;
2510 let trimmed = sql.trim();
2511
2512 if starts_with_ci(trimmed, "INSERT")
2514 || starts_with_ci(trimmed, "UPDATE")
2515 || starts_with_ci(trimmed, "DELETE")
2516 || starts_with_ci(trimmed, "CREATE")
2517 || starts_with_ci(trimmed, "DROP")
2518 || starts_with_ci(trimmed, "ALTER")
2519 || starts_with_ci(trimmed, "TRUNCATE")
2520 || starts_with_ci(trimmed, "GRANT")
2521 || starts_with_ci(trimmed, "REVOKE")
2522 || starts_with_ci(trimmed, "VACUUM")
2523 || starts_with_ci(trimmed, "REINDEX")
2524 || starts_with_ci(trimmed, "CLUSTER")
2525 {
2526 return true;
2527 }
2528
2529 if starts_with_ci(trimmed, "BEGIN")
2531 || starts_with_ci(trimmed, "START")
2532 || starts_with_ci(trimmed, "COMMIT")
2533 || starts_with_ci(trimmed, "ROLLBACK")
2534 || starts_with_ci(trimmed, "SAVEPOINT")
2535 || starts_with_ci(trimmed, "RELEASE")
2536 {
2537 return true;
2538 }
2539
2540 if starts_with_ci(trimmed, "SET") && !starts_with_ci(trimmed, "SET TRANSACTION READ ONLY") {
2542 return true;
2543 }
2544
2545 false
2546 }
2547
2548 async fn select_primary_with_timeout(
2550 session: &Arc<ClientSession>,
2551 state: &Arc<ServerState>,
2552 config: &ProxyConfig,
2553 ) -> Result<String> {
2554 let timeout = config.write_timeout();
2555 let start = std::time::Instant::now();
2556 let check_interval = Duration::from_millis(100);
2559
2560 loop {
2561 let health = state.health.load_full();
2563 let primary = config
2564 .nodes
2565 .iter()
2566 .find(|n| n.role == NodeRole::Primary && n.enabled);
2567
2568 if let Some(primary_node) = primary {
2569 if let Some(node_health) = health.get(&primary_node.address()) {
2570 if node_health.healthy {
2571 let mut current = session.current_node.write().await;
2573 *current = Some(primary_node.address());
2574 return Ok(primary_node.address());
2575 }
2576 }
2577 }
2578 drop(health);
2579
2580 if start.elapsed() >= timeout {
2582 state.metrics.failovers.fetch_add(1, Ordering::Relaxed);
2583 return Err(ProxyError::NoHealthyNodes);
2584 }
2585
2586 tracing::warn!(
2587 "Primary unavailable, waiting for failover... ({:.1}s elapsed, {:.1}s timeout)",
2588 start.elapsed().as_secs_f64(),
2589 timeout.as_secs_f64()
2590 );
2591
2592 tokio::time::sleep(check_interval).await;
2594 }
2595 }
2596
2597 async fn select_read_node(
2599 session: &Arc<ClientSession>,
2600 state: &Arc<ServerState>,
2601 config: &ProxyConfig,
2602 ) -> Result<String> {
2603 {
2605 let tx_state = session.tx_state.read().await;
2606 if tx_state.in_transaction {
2607 if let Some(node) = session.current_node.read().await.clone() {
2608 return Ok(node);
2609 }
2610 }
2611 }
2612
2613 let health = state.health.load_full();
2615 let healthy_standbys: Vec<&NodeConfig> = config
2616 .nodes
2617 .iter()
2618 .filter(|n| {
2619 n.enabled
2620 && (n.role == NodeRole::Standby || n.role == NodeRole::ReadReplica)
2621 && health
2622 .get(&n.address())
2623 .map(|h| h.healthy)
2624 .unwrap_or(false)
2625 })
2626 .collect();
2627
2628 if !healthy_standbys.is_empty() {
2629 let ticket = state.lb_state.rr_counter.fetch_add(1, Ordering::Relaxed);
2631 let index = ticket as usize % healthy_standbys.len();
2632 let node_addr = healthy_standbys[index].address();
2633
2634 let mut current = session.current_node.write().await;
2635 *current = Some(node_addr.clone());
2636 return Ok(node_addr);
2637 }
2638
2639 Self::select_node(session, state, config).await
2641 }
2642
2643 async fn complete_backend_auth(backend: &mut TcpStream) -> Result<()> {
2646 let codec = ProtocolCodec::new();
2647 let mut buffer = BytesMut::with_capacity(4096);
2648 let mut read_buf = vec![0u8; 4096];
2649 let timeout = Duration::from_secs(10);
2650 let start = std::time::Instant::now();
2651
2652 loop {
2653 if start.elapsed() > timeout {
2654 return Err(ProxyError::Auth("Backend authentication timeout".to_string()));
2655 }
2656
2657 let n = tokio::time::timeout(Duration::from_secs(5), backend.read(&mut read_buf))
2658 .await
2659 .map_err(|_| ProxyError::Auth("Read timeout during backend auth".to_string()))?
2660 .map_err(|e| ProxyError::Network(format!("Backend auth read error: {}", e)))?;
2661
2662 if n == 0 {
2663 return Err(ProxyError::Connection("Backend closed during auth".to_string()));
2664 }
2665
2666 buffer.extend_from_slice(&read_buf[..n]);
2667
2668 while let Some(msg) = codec.decode_message(&mut buffer)? {
2671 match msg.msg_type {
2672 MessageType::ReadyForQuery => {
2673 return Ok(());
2675 }
2676 MessageType::ErrorResponse => {
2677 let err = ErrorResponse::parse(msg.payload)
2678 .map(|e| e.message().unwrap_or("Unknown error").to_string())
2679 .unwrap_or_else(|_| "Parse error".to_string());
2680 return Err(ProxyError::Auth(err));
2681 }
2682 _ => {
2683 }
2685 }
2686 }
2687 }
2688 }
2689
2690 fn create_error_response(code: &str, message: &str) -> Vec<u8> {
2692 let mut fields = HashMap::new();
2693 fields.insert('S', "ERROR".to_string());
2694 fields.insert('V', "ERROR".to_string());
2695 fields.insert('C', code.to_string());
2696 fields.insert('M', message.to_string());
2697
2698 let err = ErrorResponse { fields };
2699 err.encode().encode().to_vec()
2700 }
2701
2702 fn create_ready_for_query(status: u8) -> Vec<u8> {
2705 let mut payload = BytesMut::with_capacity(1);
2706 payload.put_u8(status);
2707 Message::new(MessageType::ReadyForQuery, payload)
2708 .encode()
2709 .to_vec()
2710 }
2711
/// Turns a plugin-supplied cached result (JSON) into a wire response.
///
/// Expected JSON shape:
///
/// ```json
/// { "columns": [{"name": "id", "oid": 23}], "rows": [["1"], [null]] }
/// ```
///
/// `oid` is optional and defaults to 25. Row values are sent in text format
/// and `null` becomes SQL NULL.
///
/// The returned buffer holds, in order: one `RowDescription`, one `DataRow`
/// per row, a `CommandComplete` tagged `SELECT <row count>`, and a
/// `ReadyForQuery` with status `I`.
///
/// # Errors
///
/// Returns `ProxyError::Protocol` when:
/// * the bytes are not valid JSON of the expected shape;
/// * no columns are declared;
/// * more columns are declared than the 16-bit field count can express;
/// * a column name contains a NUL byte (names are NUL-terminated on the wire);
/// * a row's width differs from the declared column count;
/// * a value is longer than the 32-bit length prefix can express.
#[cfg(feature = "wasm-plugins")]
fn synthesise_cached_response(bytes: &[u8]) -> Result<Vec<u8>> {
    use serde::Deserialize;

    #[derive(Deserialize)]
    struct CachedPayload {
        columns: Vec<ColumnDef>,
        rows: Vec<Vec<Option<String>>>,
    }

    #[derive(Deserialize)]
    struct ColumnDef {
        name: String,
        #[serde(default = "default_text_oid")]
        oid: u32,
    }

    // Type OID used when a column does not declare one (PostgreSQL `text`).
    fn default_text_oid() -> u32 {
        25
    }

    let payload: CachedPayload = serde_json::from_slice(bytes).map_err(|e| {
        ProxyError::Protocol(format!("invalid cached payload JSON: {}", e))
    })?;

    if payload.columns.is_empty() {
        return Err(ProxyError::Protocol(
            "cached payload must declare at least one column".to_string(),
        ));
    }

    // The field count is a 16-bit integer on the wire; reject instead of
    // silently truncating with `as`.
    let column_count = payload.columns.len();
    if column_count > usize::from(u16::MAX) {
        return Err(ProxyError::Protocol(format!(
            "cached payload declares {} columns but at most {} can be encoded",
            column_count,
            u16::MAX
        )));
    }
    let wire_column_count = column_count as u16;

    let mut reply = Vec::new();

    // RowDescription: field count, then one descriptor per column.
    let mut rd = BytesMut::new();
    rd.put_u16(wire_column_count);
    for col in &payload.columns {
        // The name is NUL-terminated, so an embedded NUL would shift every
        // following field of the descriptor.
        if col.name.as_bytes().contains(&0) {
            return Err(ProxyError::Protocol(
                "cached column name must not contain a NUL byte".to_string(),
            ));
        }
        rd.extend_from_slice(col.name.as_bytes());
        rd.put_u8(0); // name terminator
        rd.put_i32(0); // table OID: not tied to a table
        rd.put_i16(0); // column attribute number: none
        rd.put_u32(col.oid); // data type OID
        rd.put_i16(-1); // type size: variable
        rd.put_i32(-1); // type modifier: none
        rd.put_i16(0); // format code: text
    }
    reply.extend_from_slice(&Message::new(MessageType::RowDescription, rd).encode());

    // DataRow: value count, then a length-prefixed value per column.
    for row in &payload.rows {
        if row.len() != column_count {
            return Err(ProxyError::Protocol(format!(
                "cached row has {} values but {} columns are declared",
                row.len(),
                column_count
            )));
        }
        let mut dr = BytesMut::new();
        // Safe: row.len() == column_count, which was bounds-checked above.
        dr.put_u16(wire_column_count);
        for value in row {
            match value {
                Some(s) => {
                    if s.len() > i32::MAX as usize {
                        return Err(ProxyError::Protocol(format!(
                            "cached value of {} bytes exceeds the maximum encodable length",
                            s.len()
                        )));
                    }
                    dr.put_i32(s.len() as i32);
                    dr.extend_from_slice(s.as_bytes());
                }
                None => {
                    dr.put_i32(-1); // length -1 marks SQL NULL
                }
            }
        }
        reply.extend_from_slice(&Message::new(MessageType::DataRow, dr).encode());
    }

    // CommandComplete carries a NUL-terminated command tag.
    let tag = format!("SELECT {}", payload.rows.len());
    let mut cc = BytesMut::new();
    cc.extend_from_slice(tag.as_bytes());
    cc.put_u8(0);
    reply.extend_from_slice(&Message::new(MessageType::CommandComplete, cc).encode());

    reply.extend_from_slice(&Self::create_ready_for_query(b'I'));

    Ok(reply)
}
2835
2836 fn apply_pre_query_hook(
2846 msg: Message,
2847 state: &Arc<ServerState>,
2848 session: &Arc<ClientSession>,
2849 ) -> (Message, PreQueryAction) {
2850 #[cfg(feature = "wasm-plugins")]
2851 {
2852 let pm = match state.plugin_manager.as_ref() {
2853 Some(pm) => pm,
2854 None => return (msg, PreQueryAction::Forward),
2855 };
2856
2857 if msg.msg_type != MessageType::Query {
2858 return (msg, PreQueryAction::Forward);
2859 }
2860
2861 if !pm.has_hook(HookType::PreQuery) {
2864 return (msg, PreQueryAction::Forward);
2865 }
2866
2867 let query_msg = match QueryMessage::parse(msg.payload.clone()) {
2868 Ok(q) => q,
2869 Err(_) => return (msg, PreQueryAction::Forward),
2870 };
2871
2872 let ctx = Self::build_query_context(&query_msg.query, session);
2873
2874 match pm.execute_pre_query(&ctx) {
2875 PreQueryResult::Continue => (msg, PreQueryAction::Forward),
2876 PreQueryResult::Block(reason) => (msg, PreQueryAction::Block(reason)),
2877 PreQueryResult::Rewrite(new_sql) => {
2878 let rewritten = QueryMessage { query: new_sql }.encode();
2879 (rewritten, PreQueryAction::Forward)
2880 }
2881 PreQueryResult::Cached(bytes) => (msg, PreQueryAction::Cached(bytes)),
2882 }
2883 }
2884 #[cfg(not(feature = "wasm-plugins"))]
2885 {
2886 let _ = (state, session);
2887 (msg, PreQueryAction::Forward)
2888 }
2889 }
2890
/// Feeds the SQL text of a `Query` message to the anomaly detector.
///
/// Messages of any other type, and payloads from which
/// `crate::protocol::query_text` yields no text, are ignored.
#[cfg(feature = "anomaly-detection")]
fn record_anomaly_observation(
    msg: &Message,
    state: &Arc<ServerState>,
    session: &Arc<ClientSession>,
) {
    if msg.msg_type == MessageType::Query {
        match crate::protocol::query_text(&msg.payload) {
            Some(sql) => Self::record_anomaly_sql(sql, state, session),
            None => {}
        }
    }
}
2911
/// Records one SQL statement with the anomaly detector and logs any events.
///
/// The tenant key is the session variable `tenant_id`, else `user`, else
/// the client's IP address. The IP is also used when the variables lock
/// cannot be acquired without waiting.
#[cfg(feature = "anomaly-detection")]
fn record_anomaly_sql(query: &str, state: &Arc<ServerState>, session: &Arc<ClientSession>) {
    // `try_read` keeps this function synchronous and non-blocking.
    let from_session = session.variables.try_read().ok().and_then(|vars| {
        vars.get("tenant_id")
            .or_else(|| vars.get("user"))
            .cloned()
    });
    let tenant = from_session.unwrap_or_else(|| session.client_addr.ip().to_string());

    let observation = crate::anomaly::QueryObservation {
        tenant,
        fingerprint: anomaly_fingerprint(query),
        sql: query.to_string(),
        timestamp: std::time::Instant::now(),
    };

    let events = state.anomaly_detector.record_query(&observation);
    for ev in events {
        tracing::warn!(anomaly = ?ev, "anomaly detected");
    }
}
2942
2943 async fn send_block_response(
2947 stream: &mut ClientStream,
2948 reason: &str,
2949 state: &Arc<ServerState>,
2950 ) -> Result<()> {
2951 let err = Self::create_error_response(
2952 "42000",
2953 &format!("Query blocked by plugin: {}", reason),
2954 );
2955 stream
2956 .write_all(&err)
2957 .await
2958 .map_err(|e| ProxyError::Network(format!("Write error: {}", e)))?;
2959 let rfq = Self::create_ready_for_query(b'I');
2960 stream
2961 .write_all(&rfq)
2962 .await
2963 .map_err(|e| ProxyError::Network(format!("Write error: {}", e)))?;
2964 state
2965 .metrics
2966 .bytes_sent
2967 .fetch_add((err.len() + rfq.len()) as u64, Ordering::Relaxed);
2968 Ok(())
2969 }
2970
/// Builds the `QueryContext` handed to plugin hooks for `query`.
///
/// `normalized` is a plain copy of the query text and `tables` is left
/// empty. `is_read_only` is the negation of `Self::is_write_query`. The
/// hook context carries the session id as `client_id`.
#[cfg(feature = "wasm-plugins")]
fn build_query_context(query: &str, session: &Arc<ClientSession>) -> QueryContext {
    let mut hook_context = HookContext::default();
    hook_context.client_id = Some(session.id.to_string());

    let sql = query.to_string();
    QueryContext {
        normalized: sql.clone(),
        query: sql,
        tables: Vec::new(),
        is_read_only: !Self::is_write_query(query),
        hook_context,
    }
}
2989
2990 async fn apply_authenticate_hook(
3011 _params: &HashMap<String, String>,
3012 _session: &Arc<ClientSession>,
3013 _state: &Arc<ServerState>,
3014 ) -> Result<()> {
3015 #[cfg(feature = "wasm-plugins")]
3016 {
3017 let pm = match _state.plugin_manager.as_ref() {
3018 Some(pm) => pm,
3019 None => return Ok(()),
3020 };
3021
3022 let request = PluginAuthRequest {
3023 headers: HashMap::new(),
3024 username: _params.get("user").cloned(),
3025 password: None,
3026 client_ip: _session.client_addr.ip().to_string(),
3027 database: _params.get("database").cloned(),
3028 };
3029
3030 match pm.execute_authenticate(&request) {
3031 AuthResult::Defer => Ok(()),
3032 AuthResult::Success(identity) => {
3033 tracing::debug!(
3034 user = %identity.username,
3035 roles = ?identity.roles,
3036 "plugin authenticated user"
3037 );
3038 *_session.plugin_identity.write().await = Some(identity);
3039 Ok(())
3040 }
3041 AuthResult::Denied(reason) => {
3042 tracing::info!(
3043 reason = %reason,
3044 client = %_session.client_addr,
3045 user = ?_params.get("user"),
3046 "plugin denied authentication"
3047 );
3048 Err(ProxyError::Auth(format!(
3049 "authentication denied by plugin: {}",
3050 reason
3051 )))
3052 }
3053 }
3054 }
3055 #[cfg(not(feature = "wasm-plugins"))]
3056 {
3057 Ok(())
3058 }
3059 }
3060
3061 fn apply_route_hook(
3064 msg: &Message,
3065 state: &Arc<ServerState>,
3066 session: &Arc<ClientSession>,
3067 ) -> RouteOverride {
3068 #[cfg(feature = "wasm-plugins")]
3069 {
3070 let pm = match state.plugin_manager.as_ref() {
3071 Some(pm) => pm,
3072 None => return RouteOverride::None,
3073 };
3074 if msg.msg_type != MessageType::Query {
3075 return RouteOverride::None;
3076 }
3077 if !pm.has_hook(HookType::Route) {
3080 return RouteOverride::None;
3081 }
3082 let query_msg = match QueryMessage::parse(msg.payload.clone()) {
3083 Ok(q) => q,
3084 Err(_) => return RouteOverride::None,
3085 };
3086 let ctx = Self::build_query_context(&query_msg.query, session);
3087 match pm.execute_route(&ctx) {
3088 RouteResult::Default => RouteOverride::None,
3089 RouteResult::Primary => RouteOverride::Primary,
3090 RouteResult::Standby => RouteOverride::Standby,
3091 RouteResult::Node(name) => RouteOverride::Node(name),
3092 RouteResult::Block(reason) => RouteOverride::Block(reason),
3093 RouteResult::Branch(name) => {
3094 tracing::warn!(
3095 branch = %name,
3096 "Route hook returned Branch but branch routing is not yet wired — using default"
3097 );
3098 RouteOverride::None
3099 }
3100 }
3101 }
3102 #[cfg(not(feature = "wasm-plugins"))]
3103 {
3104 let _ = (msg, state, session);
3105 RouteOverride::None
3106 }
3107 }
3108
/// Reports the outcome of a simple `Query` message to the post-query hook.
///
/// Does nothing when no plugin manager is configured, the message is not a
/// `Query`, no post-query hook is registered, or the payload cannot be
/// parsed. `result` holds the target node and response byte count on
/// success; on failure the error's display text is reported instead.
#[cfg(feature = "wasm-plugins")]
fn fire_post_query_hook(
    msg: &Message,
    session: &Arc<ClientSession>,
    state: &Arc<ServerState>,
    result: &Result<(Option<String>, u64)>,
    elapsed: Duration,
) {
    if msg.msg_type != MessageType::Query {
        return;
    }
    let manager = match state.plugin_manager.as_ref() {
        Some(manager) if manager.has_hook(HookType::PostQuery) => manager,
        _ => return,
    };
    let parsed = match QueryMessage::parse(msg.payload.clone()) {
        Ok(parsed) => parsed,
        Err(_) => return,
    };

    let ctx = Self::build_query_context(&parsed.query, session);
    let elapsed_us = elapsed.as_micros() as u64;
    let outcome = match result {
        Err(e) => PostQueryOutcome {
            success: false,
            target_node: None,
            elapsed_us,
            response_bytes: 0,
            error: Some(e.to_string()),
        },
        Ok((node, bytes)) => PostQueryOutcome {
            success: true,
            target_node: node.clone(),
            elapsed_us,
            response_bytes: *bytes,
            error: None,
        },
    };
    manager.execute_post_query(&ctx, &outcome);
}
3155
3156 async fn select_node(
3160 session: &Arc<ClientSession>,
3161 state: &Arc<ServerState>,
3162 config: &ProxyConfig,
3163 ) -> Result<String> {
3164 {
3166 let tx_state = session.tx_state.read().await;
3167 if tx_state.in_transaction {
3168 if let Some(node) = session.current_node.read().await.clone() {
3169 return Ok(node);
3170 }
3171 }
3172 }
3173
3174 let health = state.health.load_full();
3176 let healthy_nodes: Vec<&NodeConfig> = config
3177 .nodes
3178 .iter()
3179 .filter(|n| {
3180 n.enabled
3181 && health
3182 .get(&n.address())
3183 .map(|h| h.healthy)
3184 .unwrap_or(false)
3185 })
3186 .collect();
3187
3188 if healthy_nodes.is_empty() {
3189 return Err(ProxyError::NoHealthyNodes);
3190 }
3191
3192 if let Some(primary) = healthy_nodes.iter().find(|n| n.role == NodeRole::Primary) {
3194 let node_addr = primary.address();
3195 let mut current = session.current_node.write().await;
3196 *current = Some(node_addr.clone());
3197 return Ok(node_addr);
3198 }
3199
3200 if let Some(standby) = healthy_nodes.iter().find(|n| n.role == NodeRole::Standby) {
3203 tracing::warn!("Primary unavailable, connecting to standby for initial session");
3204 let node_addr = standby.address();
3205 let mut current = session.current_node.write().await;
3206 *current = Some(node_addr.clone());
3207 return Ok(node_addr);
3208 }
3209
3210 Err(ProxyError::NoHealthyNodes)
3212 }
3213
3214 fn spawn_health_checker(&self) -> tokio::task::JoinHandle<()> {
3216 let state = self.state.clone();
3217 let mut shutdown_rx = self.shutdown_tx.subscribe();
3218
3219 tokio::spawn(async move {
3220 let mut interval = tokio::time::interval(std::time::Duration::from_secs(
3221 state.live_config.load().health.check_interval_secs,
3222 ));
3223
3224 loop {
3225 tokio::select! {
3226 _ = interval.tick() => {
3227 let config = state.live_config.load_full();
3230 Self::check_all_nodes(&state, &config).await;
3231 }
3232 _ = shutdown_rx.recv() => {
3233 break;
3234 }
3235 }
3236 }
3237 })
3238 }
3239
3240 async fn check_all_nodes(state: &Arc<ServerState>, config: &ProxyConfig) {
3247 let timeout = Duration::from_secs(config.health.check_timeout_secs);
3250 let mut set = tokio::task::JoinSet::new();
3251 for node in &config.nodes {
3252 let addr = node.address();
3253 set.spawn(async move {
3254 let r = Self::check_node_addr(&addr, timeout).await;
3255 (addr, r)
3256 });
3257 }
3258 let mut results = Vec::with_capacity(config.nodes.len());
3259 while let Some(joined) = set.join_next().await {
3260 if let Ok(pair) = joined {
3261 results.push(pair);
3262 }
3263 }
3264
3265 let mut next = (*state.health.load_full()).clone();
3267 for (addr, result) in results {
3268 if let Some(node_health) = next.get_mut(&addr) {
3269 match result {
3270 Ok(latency) => {
3271 node_health.healthy = true;
3272 node_health.failure_count = 0;
3273 node_health.latency_ms = latency;
3274 node_health.last_error = None;
3275 }
3276 Err(e) => {
3277 node_health.failure_count += 1;
3278 node_health.last_error = Some(e.to_string());
3279 if node_health.failure_count >= config.health.failure_threshold {
3280 node_health.healthy = false;
3281 tracing::warn!(
3282 "Node {} marked unhealthy after {} failures",
3283 addr,
3284 node_health.failure_count
3285 );
3286 }
3287 }
3288 }
3289 node_health.last_check = chrono::Utc::now();
3290 }
3291 }
3292 state.health.store(Arc::new(next));
3293 }
3294
3295 async fn check_node_addr(addr: &str, timeout: Duration) -> Result<f64> {
3298 let start = std::time::Instant::now();
3299 let _stream = tokio::time::timeout(timeout, TcpStream::connect(addr))
3300 .await
3301 .map_err(|_| ProxyError::HealthCheck(format!("Timeout connecting to {}", addr)))?
3302 .map_err(|e| ProxyError::HealthCheck(format!("Failed to connect to {}: {}", addr, e)))?;
3303 let latency = start.elapsed().as_secs_f64() * 1000.0;
3304 Ok(latency)
3305 }
3306
3307 fn spawn_pool_manager(&self) -> tokio::task::JoinHandle<()> {
3309 let state = self.state.clone();
3310 let mut shutdown_rx = self.shutdown_tx.subscribe();
3311
3312 tokio::spawn(async move {
3313 let mut interval = tokio::time::interval(std::time::Duration::from_secs(30));
3314
3315 loop {
3316 tokio::select! {
3317 _ = interval.tick() => {
3318 #[cfg(feature = "pool-modes")]
3320 if let Some(ref pool_manager) = state.pool_manager {
3321 pool_manager.evict_idle().await;
3322 tracing::trace!("Pool-modes idle eviction completed");
3323 }
3324 }
3325 _ = shutdown_rx.recv() => {
3326 #[cfg(feature = "pool-modes")]
3328 if let Some(ref pool_manager) = state.pool_manager {
3329 pool_manager.close_all().await;
3330 tracing::info!("Pool-modes manager closed all connections");
3331 }
3332 break;
3333 }
3334 }
3335 }
3336 })
3337 }
3338
3339 pub fn shutdown(&self) {
3341 let _ = self.shutdown_tx.send(());
3342 }
3343
/// Returns a snapshot of pool statistics, or `None` without a pool manager.
///
/// Connection counts come from the manager's stats and the counters from
/// its metrics snapshot. The average lease duration is the one recorded for
/// the manager's default mode, or 0 if that mode has no entry.
#[cfg(feature = "pool-modes")]
pub async fn pool_mode_stats(&self) -> Option<PoolModeStatsSnapshot> {
    let manager = self.state.pool_manager.as_ref()?;

    let stats = manager.get_stats().await;
    let metrics = manager.metrics().snapshot();
    let default_mode = manager.default_mode();

    let avg_lease_duration_ms = metrics
        .mode_stats
        .get(&default_mode)
        .map_or(0, |per_mode| per_mode.avg_lease_duration_ms as u64);

    let snapshot = PoolModeStatsSnapshot {
        mode: format!("{:?}", default_mode),
        total_connections: stats.total_connections,
        active_leases: stats.active_connections,
        idle_connections: stats.idle_connections,
        node_count: stats.node_count,
        acquires: metrics.acquires,
        releases: metrics.releases,
        acquire_failures: metrics.acquire_failures,
        acquire_timeouts: metrics.acquire_timeouts,
        transactions_completed: metrics.transactions_completed,
        statements_executed: metrics.statements_executed,
        avg_lease_duration_ms,
    };
    Some(snapshot)
}
3377
/// Registers `node` with the pool manager, if one is configured.
///
/// The node's host, port, role and weight are carried over into the
/// endpoint handed to the manager. Does nothing without a pool manager.
#[cfg(feature = "pool-modes")]
pub async fn add_node_to_pool(&self, node: &NodeConfig) {
    let manager = match self.state.pool_manager.as_ref() {
        Some(manager) => manager,
        None => return,
    };

    // Translate the config-level role into the crate-level role type.
    let role = match node.role {
        NodeRole::Primary => crate::NodeRole::Primary,
        NodeRole::Standby => crate::NodeRole::Standby,
        NodeRole::ReadReplica => crate::NodeRole::ReadReplica,
    };
    let endpoint = NodeEndpoint::new(&node.host, node.port)
        .with_role(role)
        .with_weight(node.weight);

    manager.add_node(&endpoint).await;
    tracing::info!("Added node {} to pool manager", node.address());
}
3393
3394 pub fn metrics(&self) -> ServerMetricsSnapshot {
3396 ServerMetricsSnapshot {
3397 connections_accepted: self.state.metrics.connections_accepted.load(Ordering::Relaxed),
3398 connections_closed: self.state.metrics.connections_closed.load(Ordering::Relaxed),
3399 queries_processed: self.state.metrics.queries_processed.load(Ordering::Relaxed),
3400 bytes_received: self.state.metrics.bytes_received.load(Ordering::Relaxed),
3401 bytes_sent: self.state.metrics.bytes_sent.load(Ordering::Relaxed),
3402 failovers: self.state.metrics.failovers.load(Ordering::Relaxed),
3403 }
3404 }
3405}
3406
/// Plain-value copy of the server's counters, as returned by
/// `ProxyServer::metrics`.
#[derive(Debug, Clone)]
pub struct ServerMetricsSnapshot {
    /// Value of the `connections_accepted` counter when the snapshot was taken.
    pub connections_accepted: u64,
    /// Value of the `connections_closed` counter when the snapshot was taken.
    pub connections_closed: u64,
    /// Value of the `queries_processed` counter when the snapshot was taken.
    pub queries_processed: u64,
    /// Value of the `bytes_received` counter when the snapshot was taken.
    pub bytes_received: u64,
    /// Value of the `bytes_sent` counter when the snapshot was taken.
    pub bytes_sent: u64,
    /// Value of the `failovers` counter when the snapshot was taken.
    pub failovers: u64,
}
3417
/// Pool statistics as returned by `ProxyServer::pool_mode_stats`.
#[cfg(feature = "pool-modes")]
#[derive(Debug, Clone)]
pub struct PoolModeStatsSnapshot {
    /// `Debug` rendering of the pool manager's default mode.
    pub mode: String,
    /// `total_connections` from the pool manager's stats.
    pub total_connections: usize,
    /// `active_connections` from the pool manager's stats.
    pub active_leases: usize,
    /// `idle_connections` from the pool manager's stats.
    pub idle_connections: usize,
    /// `node_count` from the pool manager's stats.
    pub node_count: usize,
    /// `acquires` from the pool manager's metrics snapshot.
    pub acquires: u64,
    /// `releases` from the pool manager's metrics snapshot.
    pub releases: u64,
    /// `acquire_failures` from the pool manager's metrics snapshot.
    pub acquire_failures: u64,
    /// `acquire_timeouts` from the pool manager's metrics snapshot.
    pub acquire_timeouts: u64,
    /// `transactions_completed` from the pool manager's metrics snapshot.
    pub transactions_completed: u64,
    /// `statements_executed` from the pool manager's metrics snapshot.
    pub statements_executed: u64,
    /// Average lease duration recorded for the default mode, in
    /// milliseconds; 0 when that mode has no entry.
    pub avg_lease_duration_ms: u64,
}
3447
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{HealthConfig, LoadBalancerConfig, PoolConfig};

    /// Default configuration bound to an ephemeral local port with a single
    /// node registered under the role string `"primary"`.
    fn test_config() -> ProxyConfig {
        let mut config = ProxyConfig::default();
        config.listen_address = "127.0.0.1:0".to_string();
        config
            .add_node("127.0.0.1:5432", "primary")
            .unwrap();
        config
    }

    /// A server can be constructed from the test configuration.
    #[test]
    fn test_server_creation() {
        let config = test_config();
        let server = ProxyServer::new(config);
        assert!(server.is_ok());
    }

    /// Address patterns: `all`, CIDR ranges, bare addresses and IPv6.
    #[test]
    fn test_hba_addr_matches() {
        use std::net::IpAddr;
        // Parses both IPv4 and IPv6 literals.
        let addr = |s: &str| s.parse::<IpAddr>().unwrap();
        assert!(ProxyServer::hba_addr_matches("all", addr("203.0.113.7")));
        assert!(ProxyServer::hba_addr_matches("10.0.0.0/8", addr("10.1.2.3")));
        assert!(!ProxyServer::hba_addr_matches("10.0.0.0/8", addr("11.1.2.3")));
        assert!(ProxyServer::hba_addr_matches("127.0.0.1/32", addr("127.0.0.1")));
        assert!(!ProxyServer::hba_addr_matches("127.0.0.1/32", addr("127.0.0.2")));
        assert!(ProxyServer::hba_addr_matches("192.168.1.1", addr("192.168.1.1")));
        assert!(!ProxyServer::hba_addr_matches("192.168.1.1", addr("192.168.1.2")));
        assert!(ProxyServer::hba_addr_matches("::1/128", addr("::1")));
        assert!(ProxyServer::hba_addr_matches("0.0.0.0/0", addr("8.8.8.8")));
    }

    /// Rule evaluation: an empty rule set admits, a matching reject denies,
    /// and an earlier allow wins over a later catch-all reject.
    #[test]
    fn test_hba_admits() {
        use crate::config::{HbaAction, HbaRule};
        use std::net::IpAddr;
        let ip: IpAddr = "10.0.0.5".parse().unwrap();
        assert!(ProxyServer::hba_admits(&[], ip, "bench", "benchdb"));
        let rules = vec![HbaRule {
            action: HbaAction::Reject,
            user: "bench".into(),
            database: "all".into(),
            address: "all".into(),
        }];
        assert!(!ProxyServer::hba_admits(&rules, ip, "bench", "benchdb"));
        assert!(ProxyServer::hba_admits(&rules, ip, "alice", "benchdb"));
        let rules = vec![
            HbaRule { action: HbaAction::Allow, user: "bench".into(), database: "all".into(), address: "10.0.0.0/8".into() },
            HbaRule { action: HbaAction::Reject, user: "all".into(), database: "all".into(), address: "all".into() },
        ];
        assert!(ProxyServer::hba_admits(&rules, ip, "bench", "benchdb"));
        assert!(!ProxyServer::hba_admits(&rules, "192.168.0.1".parse().unwrap(), "bench", "benchdb"));
        assert!(!ProxyServer::hba_admits(&rules, ip, "alice", "benchdb"));
    }

    /// A freshly built server reports zeroed counters.
    #[test]
    fn test_initial_metrics() {
        let config = test_config();
        let server = ProxyServer::new(config).unwrap();
        let metrics = server.metrics();
        assert_eq!(metrics.connections_accepted, 0);
        assert_eq!(metrics.queries_processed, 0);
    }

    /// A freshly built server has no sessions.
    #[tokio::test]
    async fn test_session_creation() {
        let config = test_config();
        let server = ProxyServer::new(config).unwrap();

        let sessions = server.state.sessions.read().await;
        assert!(sessions.is_empty());
    }

    /// Every configured node starts out healthy with no recorded failures.
    #[tokio::test]
    async fn test_node_health_initialization() {
        let config = test_config();
        let server = ProxyServer::new(config).unwrap();

        let health = server.state.health.load_full();
        assert!(!health.is_empty());

        for node_health in health.values() {
            assert!(node_health.healthy);
            assert_eq!(node_health.failure_count, 0);
        }
    }

    /// Minimal session with no current node, default transaction state and
    /// no variables.
    fn make_test_session() -> Arc<ClientSession> {
        Arc::new(ClientSession {
            id: Uuid::new_v4(),
            client_addr: "127.0.0.1:0".parse().unwrap(),
            current_node: RwLock::new(None),
            tx_state: RwLock::new(TransactionState::default()),
            variables: RwLock::new(HashMap::new()),
            created_at: chrono::Utc::now(),
            tr_mode: crate::config::TrMode::default(),
            #[cfg(feature = "pool-modes")]
            pool_client_id: crate::pool::lease::ClientId::default(),
            #[cfg(feature = "wasm-plugins")]
            plugin_identity: RwLock::new(None),
        })
    }

    /// Without a plugin manager the route hook never overrides routing.
    #[tokio::test]
    async fn test_apply_route_hook_no_plugin_manager_returns_none() {
        let config = test_config();
        let server = ProxyServer::new(config).unwrap();
        let session = make_test_session();

        let msg = QueryMessage {
            query: "SELECT * FROM users".to_string(),
        }
        .encode();

        let decision = ProxyServer::apply_route_hook(&msg, &server.state, &session);
        assert!(matches!(decision, RouteOverride::None));
    }

    /// Without a plugin manager the pre-query hook forwards the message
    /// byte-for-byte.
    #[tokio::test]
    async fn test_apply_pre_query_hook_no_plugin_manager_forwards() {
        let config = test_config();
        let server = ProxyServer::new(config).unwrap();
        let session = make_test_session();

        let original = QueryMessage {
            query: "SELECT 1".to_string(),
        }
        .encode();
        let original_bytes = original.encode().to_vec();

        let (msg_out, action) =
            ProxyServer::apply_pre_query_hook(original, &server.state, &session);

        assert!(matches!(action, PreQueryAction::Forward));
        assert_eq!(msg_out.encode().to_vec(), original_bytes);
    }

    /// Non-`Query` messages are never offered to the route hook.
    #[tokio::test]
    async fn test_apply_route_hook_skips_non_query_messages() {
        let config = test_config();
        let server = ProxyServer::new(config).unwrap();
        let session = make_test_session();

        let sync_msg = Message::empty(MessageType::Sync);
        let decision = ProxyServer::apply_route_hook(&sync_msg, &server.state, &session);
        assert!(matches!(decision, RouteOverride::None));
    }

    /// Plugins are disabled in the test configuration, so no manager is built.
    #[cfg(feature = "wasm-plugins")]
    #[test]
    fn test_init_plugin_manager_disabled_by_default_returns_none() {
        let config = test_config();
        assert!(!config.plugins.enabled);
        let pm = ProxyServer::init_plugin_manager(&config.plugins);
        assert!(pm.is_none());
    }

    /// A non-existent plugin directory does not prevent the manager from
    /// being created.
    #[cfg(feature = "wasm-plugins")]
    #[test]
    fn test_init_plugin_manager_missing_dir_logs_warning() {
        let mut config = test_config();
        config.plugins.enabled = true;
        config.plugins.plugin_dir = "/definitely/not/a/real/path".to_string();

        let pm = ProxyServer::init_plugin_manager(&config.plugins);
        assert!(pm.is_some());
    }

    /// Without a plugin manager the authenticate hook succeeds and leaves
    /// the session's plugin identity unset.
    #[tokio::test]
    async fn test_apply_authenticate_hook_no_plugin_manager_defers() {
        let config = test_config();
        let server = ProxyServer::new(config).unwrap();
        let session = make_test_session();

        let mut params = HashMap::new();
        params.insert("user".to_string(), "alice".to_string());
        params.insert("database".to_string(), "app".to_string());

        let result =
            ProxyServer::apply_authenticate_hook(&params, &session, &server.state).await;
        assert!(result.is_ok());

        #[cfg(feature = "wasm-plugins")]
        {
            let ident = session.plugin_identity.read().await;
            assert!(ident.is_none());
        }
    }

    /// A two-column, two-row payload produces frames `T D D C Z` with no
    /// trailing bytes and a final status byte of `I`.
    #[cfg(feature = "wasm-plugins")]
    #[test]
    fn test_synthesise_cached_response_roundtrip() {
        let payload = br#"{
            "columns": [
                {"name": "id", "oid": 23},
                {"name": "email", "oid": 25}
            ],
            "rows": [
                ["1", "alice@example.com"],
                ["2", null]
            ]
        }"#;
        let reply =
            ProxyServer::synthesise_cached_response(payload).expect("synthesis");

        // Walk the buffer frame by frame: 1 tag byte + 4-byte big-endian
        // length that counts itself but not the tag.
        let mut tags = Vec::new();
        let mut i = 0;
        while i < reply.len() {
            let tag = reply[i];
            let len = u32::from_be_bytes([
                reply[i + 1],
                reply[i + 2],
                reply[i + 3],
                reply[i + 4],
            ]) as usize;
            tags.push(tag);
            i += 1 + len;
        }
        assert_eq!(i, reply.len(), "no trailing bytes");
        assert_eq!(
            tags,
            vec![b'T', b'D', b'D', b'C', b'Z'],
            "wire frame order"
        );

        assert_eq!(*reply.last().unwrap(), b'I');
    }

    /// A row wider than the declared column list is rejected.
    #[cfg(feature = "wasm-plugins")]
    #[test]
    fn test_synthesise_cached_response_rejects_row_width_mismatch() {
        let payload = br#"{
            "columns": [{"name": "id", "oid": 23}, {"name": "name", "oid": 25}],
            "rows": [["1", "alice", "extra"]]
        }"#;
        let result = ProxyServer::synthesise_cached_response(payload);
        assert!(matches!(result, Err(ProxyError::Protocol(_))));
    }

    /// A payload without columns is rejected.
    #[cfg(feature = "wasm-plugins")]
    #[test]
    fn test_synthesise_cached_response_rejects_empty_columns() {
        let payload = br#"{ "columns": [], "rows": [] }"#;
        let result = ProxyServer::synthesise_cached_response(payload);
        assert!(matches!(result, Err(ProxyError::Protocol(_))));
    }

    /// Input that is not JSON is rejected.
    #[cfg(feature = "wasm-plugins")]
    #[test]
    fn test_synthesise_cached_response_rejects_bad_json() {
        let payload = b"not json at all";
        let result = ProxyServer::synthesise_cached_response(payload);
        assert!(matches!(result, Err(ProxyError::Protocol(_))));
    }

    /// With a plugin manager that has no plugins loaded, the authenticate
    /// hook succeeds and leaves the session's plugin identity unset.
    #[cfg(feature = "wasm-plugins")]
    #[tokio::test]
    async fn test_apply_authenticate_hook_with_manager_no_plugins_defers() {
        use crate::plugins::{PluginManager, PluginRuntimeConfig};

        let config = test_config();
        let server = ProxyServer::new(config).unwrap();
        let session = make_test_session();

        let pm = Arc::new(PluginManager::new(PluginRuntimeConfig::default()).unwrap());
        // Hand-built state so that a plugin manager is present.
        let augmented_state = Arc::new(ServerState {
            sessions: RwLock::new(HashMap::new()),
            health: ArcSwap::from_pointee(HashMap::new()),
            live_config: ArcSwap::from_pointee(ProxyConfig::default()),
            metrics: ServerMetrics::default(),
            cancel_map: Arc::new(DashMap::new()),
            tls_acceptor: None,
            auth_file: None,
            mirror: None,
            cutover: Arc::new(ArcSwap::from_pointee(None)),
            lb_state: LoadBalancerState {
                rr_counter: AtomicU64::new(0),
            },
            #[cfg(feature = "pool-modes")]
            pool_manager: None,
            plugin_manager: Some(pm),
            #[cfg(feature = "ha-tr")]
            transaction_journal: Arc::new(
                crate::transaction_journal::TransactionJournal::new(),
            ),
            #[cfg(feature = "anomaly-detection")]
            anomaly_detector: Arc::new(
                crate::anomaly::AnomalyDetector::new(
                    crate::anomaly::AnomalyConfig::default(),
                ),
            ),
            #[cfg(feature = "edge-proxy")]
            edge_cache: Arc::new(crate::edge::EdgeCache::new(10_000)),
            #[cfg(feature = "edge-proxy")]
            edge_registry: Arc::new(crate::edge::EdgeRegistry::new(
                32,
                std::time::Duration::from_secs(120),
            )),
        });

        let mut params = HashMap::new();
        params.insert("user".to_string(), "alice".to_string());

        let result =
            ProxyServer::apply_authenticate_hook(&params, &session, &augmented_state).await;
        assert!(result.is_ok());
        let ident = session.plugin_identity.read().await;
        assert!(ident.is_none());
        let _ = server;
    }

    /// Returns `s` as bytes followed by a NUL terminator.
    fn cstr(s: &str) -> Vec<u8> {
        let mut v = s.as_bytes().to_vec();
        v.push(0);
        v
    }

    /// The statement name is read from the start of the payload, whether it
    /// is named or empty.
    #[test]
    fn parse_stmt_name_extracts_named_and_unnamed() {
        let mut named = cstr("ps1");
        named.extend_from_slice(&cstr("SELECT 1"));
        named.extend_from_slice(&[0, 0]);
        assert_eq!(ProxyServer::parse_stmt_name(&named), "ps1");

        let mut unnamed = cstr("");
        unnamed.extend_from_slice(&cstr("SELECT 1"));
        unnamed.extend_from_slice(&[0, 0]);
        assert_eq!(ProxyServer::parse_stmt_name(&unnamed), "");
    }

    /// The statement reference is the second C string of the payload; an
    /// empty reference yields `None`.
    #[test]
    fn bind_stmt_ref_reads_second_cstring() {
        let mut named = cstr("portal_a");
        named.extend_from_slice(&cstr("ps1"));
        named.extend_from_slice(&[0, 0]);
        assert_eq!(ProxyServer::bind_stmt_ref(&named), Some("ps1"));

        let mut unnamed = cstr("");
        unnamed.extend_from_slice(&cstr(""));
        assert_eq!(ProxyServer::bind_stmt_ref(&unnamed), None);
    }

    /// Only payloads of kind `S` with a non-empty name yield a statement name.
    #[test]
    fn stmt_kind_name_only_matches_statement_kind() {
        let mut stmt = vec![b'S'];
        stmt.extend_from_slice(&cstr("ps1"));
        assert_eq!(ProxyServer::stmt_kind_name(&stmt), Some("ps1"));

        let mut portal = vec![b'P'];
        portal.extend_from_slice(&cstr("portal_a"));
        assert_eq!(ProxyServer::stmt_kind_name(&portal), None);

        let mut empty = vec![b'S'];
        empty.extend_from_slice(&cstr(""));
        assert_eq!(ProxyServer::stmt_kind_name(&empty), None);
    }

    /// Reading one frame consumes it entirely, so the next read starts at
    /// the following frame's type byte.
    #[tokio::test]
    async fn read_one_frame_type_consumes_full_frame() {
        let (mut a, mut b) = tokio::io::duplex(64);
        // Frame `1` with an empty body, then frame `Z` with body `I`.
        let bytes = [b'1', 0, 0, 0, 4, b'Z', 0, 0, 0, 5, b'I'];
        b.write_all(&bytes).await.unwrap();
        let t = ProxyServer::read_one_frame_type(&mut a).await.unwrap();
        assert_eq!(t, b'1');
        let t2 = ProxyServer::read_one_frame_type(&mut a).await.unwrap();
        assert_eq!(t2, b'Z');
    }

    /// Re-preparing succeeds when the backend answers with frame `1` and
    /// fails when it answers with frame `E`.
    #[tokio::test]
    async fn reprepare_statement_accepts_parse_complete_and_rejects_error() {
        let (mut client, mut backend) = tokio::io::duplex(64);
        backend.write_all(&[b'1', 0, 0, 0, 4]).await.unwrap();
        let parse = {
            let mut p = vec![b'P', 0, 0, 0, 0];
            p.extend_from_slice(&cstr("ps1"));
            p.extend_from_slice(&cstr("SELECT 1"));
            p.extend_from_slice(&[0, 0]);
            p
        };
        assert!(ProxyServer::reprepare_statement(&mut client, &parse).await.is_ok());

        let (mut client2, mut backend2) = tokio::io::duplex(64);
        backend2.write_all(&[b'E', 0, 0, 0, 4]).await.unwrap();
        assert!(ProxyServer::reprepare_statement(&mut client2, &parse).await.is_err());
    }
}
3909}