1use crate::admin::{AdminServer, AdminState, ConfigSnapshot, NodeSnapshot};
7#[cfg(feature = "ha-tr")]
8use crate::backend::{tls::default_client_config, BackendConfig, TlsMode};
9use crate::config::{NodeConfig, NodeRole, ProxyConfig, TrMode};
10use crate::protocol::{
11 ErrorResponse, Message, MessageType, ParseMessage, ProtocolCodec, QueryMessage,
12 StartupMessage, TransactionStatus,
13};
14use crate::{ProxyError, Result};
15use bytes::{BufMut, BytesMut};
16use std::collections::HashMap;
17use std::net::SocketAddr;
18use std::sync::atomic::{AtomicU64, Ordering};
19use std::sync::Arc;
20use std::time::Duration;
21use tokio::io::{AsyncReadExt, AsyncWriteExt};
22use tokio::net::{TcpListener, TcpStream};
23use tokio::sync::{broadcast, RwLock};
24use uuid::Uuid;
25
26#[cfg(feature = "pool-modes")]
28use crate::pool::{
29 ConnectionPoolManager, PoolModeConfig, PoolingMode,
30};
31#[cfg(feature = "pool-modes")]
32use crate::pool::lease::ClientId;
33#[cfg(feature = "pool-modes")]
34use crate::NodeEndpoint;
35
36#[cfg(feature = "wasm-plugins")]
38use crate::plugins::{
39 AuthRequest as PluginAuthRequest, AuthResult, HookContext, Identity, PluginManager,
40 PostQueryOutcome, PreQueryResult, QueryContext, RouteResult,
41};
42
/// Top-level proxy server: owns the configuration, the state shared with every
/// spawned task, and the shutdown broadcast channel.
pub struct ProxyServer {
    /// Configuration the server was built with; cloned into spawned tasks.
    config: ProxyConfig,
    /// State shared (via `Arc`) with client handlers and background tasks.
    state: Arc<ServerState>,
    /// Broadcast sender whose receivers (`run`, the admin server task) stop on
    /// the first message.
    shutdown_tx: broadcast::Sender<()>,
}
49
/// Builds the backend connection template handed to the transaction replay
/// engine.
///
/// `_config` is currently unused. Every field is a fixed default, and
/// `host`/`port` are placeholders. NOTE(review): presumably the replay engine
/// overwrites them per target node — confirm in `crate::replay`.
#[cfg(feature = "ha-tr")]
fn build_replay_backend_template(_config: &ProxyConfig) -> BackendConfig {
    let connect_timeout = Duration::from_secs(5);
    let query_timeout = Duration::from_secs(30);
    let application_name = Some(String::from("heliosdb-proxy-replay"));

    BackendConfig {
        host: String::from("placeholder"),
        port: 0,
        user: String::from("postgres"),
        password: None,
        database: None,
        application_name,
        tls_mode: TlsMode::Disable,
        connect_timeout,
        query_timeout,
        tls_config: default_client_config(),
    }
}
78
/// Normalises `sql` into a coarse "shape" fingerprint for anomaly detection,
/// so that queries differing only in literal values map to the same key.
///
/// Rules, applied left to right:
/// * a single-quoted string literal (SQL `''` escapes included) becomes `?`;
///   an unterminated literal swallows the rest of the input;
/// * a numeric literal (a digit run, optionally containing `.`) becomes `?`,
///   and a literal directly following another `?` adds nothing;
/// * digits that continue an identifier (`users2`, `col_1`) are kept verbatim,
///   so different tables or columns keep different fingerprints;
/// * whitespace runs collapse to a single space, and leading/trailing
///   whitespace is dropped;
/// * ASCII letters are lowercased; every other character is copied as is.
#[cfg(feature = "anomaly-detection")]
fn anomaly_fingerprint(sql: &str) -> String {
    let mut out = String::with_capacity(sql.len());
    let mut prev_space = false;
    let mut chars = sql.chars().peekable();
    while let Some(c) = chars.next() {
        if c == '\'' {
            out.push('?');
            // Skip to the closing quote. A doubled quote ('') is an escaped
            // quote and keeps us inside the same literal.
            while let Some(n) = chars.next() {
                if n == '\'' {
                    if chars.peek() == Some(&'\'') {
                        chars.next();
                    } else {
                        break;
                    }
                }
            }
            prev_space = false;
            continue;
        }
        // A digit only starts a numeric literal when it does not continue an
        // identifier (letters, digits or `_` immediately before it).
        if c.is_ascii_digit() && !out.ends_with(|p: char| p.is_alphanumeric() || p == '_') {
            if !out.ends_with('?') {
                out.push('?');
            }
            while matches!(chars.peek(), Some(c) if c.is_ascii_digit() || *c == '.') {
                chars.next();
            }
            prev_space = false;
            continue;
        }
        if c.is_ascii_whitespace() {
            if !prev_space && !out.is_empty() {
                out.push(' ');
                prev_space = true;
            }
            continue;
        }
        out.push(c.to_ascii_lowercase());
        prev_space = false;
    }
    out.trim_end().to_string()
}
135
/// State shared by the accept loop, every client handler and the background
/// tasks; always held behind an `Arc`.
struct ServerState {
    /// Live client sessions keyed by `ClientSession::id`; `handle_client`
    /// inserts on connect and removes on disconnect.
    sessions: RwLock<HashMap<Uuid, Arc<ClientSession>>>,
    /// Per-node health keyed by the node's `address()` string; every node is
    /// seeded as healthy in `ProxyServer::new`.
    health: RwLock<HashMap<String, NodeHealth>>,
    /// Atomic counters; copied into the admin state by the sync task in
    /// `spawn_admin_server`.
    metrics: ServerMetrics,
    /// Round-robin cursor used by `select_read_node`.
    lb_state: RwLock<LoadBalancerState>,
    /// Connection pool manager; `ProxyServer::new` always sets it to `Some`
    /// when the feature is compiled in.
    #[cfg(feature = "pool-modes")]
    pool_manager: Option<Arc<ConnectionPoolManager>>,
    /// WASM plugin manager; `None` when plugins are disabled in config or the
    /// manager could not be created (see `init_plugin_manager`).
    #[cfg(feature = "wasm-plugins")]
    plugin_manager: Option<Arc<PluginManager>>,
    /// Transaction journal shared with the replay engine registered on the
    /// admin server.
    #[cfg(feature = "ha-tr")]
    transaction_journal: Arc<crate::transaction_journal::TransactionJournal>,
    /// Anomaly detector (default config), also exposed via the admin server.
    #[cfg(feature = "anomaly-detection")]
    anomaly_detector: Arc<crate::anomaly::AnomalyDetector>,
    /// Edge cache built with `EdgeCache::new(10_000)`. NOTE(review): the unit
    /// of 10_000 (entries? bytes?) is defined in `crate::edge` — confirm.
    #[cfg(feature = "edge-proxy")]
    edge_cache: Arc<crate::edge::EdgeCache>,
    /// Edge registry; its constructor arguments (32, 120 s) are interpreted
    /// in `crate::edge`.
    #[cfg(feature = "edge-proxy")]
    edge_registry: Arc<crate::edge::EdgeRegistry>,
}
172
/// Last-known health of one backend node, stored in `ServerState::health`
/// and mirrored to the admin server.
#[derive(Debug, Clone)]
pub struct NodeHealth {
    /// Node address; identical to this entry's key in the health map.
    pub address: String,
    /// Whether routing may use this node; starts `true` for every node.
    pub healthy: bool,
    /// Time of the last health update (initially the construction time).
    pub last_check: chrono::DateTime<chrono::Utc>,
    /// Failure counter; starts at 0. NOTE(review): presumably counts
    /// consecutive failed checks — confirm in the health checker.
    pub failure_count: u32,
    /// Most recent health-check error message, if any.
    pub last_error: Option<String>,
    /// Measured latency in milliseconds (per the field name); starts at 0.0.
    pub latency_ms: f64,
    /// Replication lag in bytes when known; `None` until measured.
    pub replication_lag_bytes: Option<u64>,
}
191
/// Process-wide counters, updated with `Ordering::Relaxed` atomics.
#[derive(Default)]
struct ServerMetrics {
    /// Incremented by `run` for every accepted TCP connection.
    connections_accepted: AtomicU64,
    /// Incremented by `handle_client` when a client handler finishes.
    connections_closed: AtomicU64,
    /// Incremented once per forwarded, blocked or cache-served client message.
    queries_processed: AtomicU64,
    /// Bytes read from clients after startup.
    bytes_received: AtomicU64,
    /// Bytes written to clients.
    bytes_sent: AtomicU64,
    /// Incremented when `select_primary_with_timeout` gives up waiting for a
    /// healthy primary.
    failovers: AtomicU64,
}
208
/// Mutable load-balancing state shared by all sessions.
struct LoadBalancerState {
    /// Round-robin cursor over healthy read nodes; advanced with wrapping
    /// arithmetic in `select_read_node`.
    rr_counter: u64,
}
214
/// Per-connection state for one frontend client.
pub struct ClientSession {
    /// Random session id; key in `ServerState::sessions`.
    pub id: Uuid,
    /// Peer address of the client socket.
    pub client_addr: SocketAddr,
    /// Backend node address last chosen for this session by the node
    /// selection helpers.
    pub current_node: RwLock<Option<String>>,
    /// Transaction bookkeeping; `in_transaction` is refreshed from each
    /// backend ReadyForQuery.
    pub tx_state: RwLock<TransactionState>,
    /// Startup parameters received from the client, replayed whenever the
    /// proxy opens a new backend connection for this session.
    pub variables: RwLock<HashMap<String, String>>,
    /// When the session was created.
    pub created_at: chrono::DateTime<chrono::Utc>,
    /// Transaction-replay mode copied from the proxy configuration.
    pub tr_mode: TrMode,
    /// Identity of this client towards the connection pool manager.
    #[cfg(feature = "pool-modes")]
    pub pool_client_id: ClientId,
    /// Identity attached by the plugin layer; starts as `None`.
    /// NOTE(review): presumably set by the authenticate hook — confirm.
    #[cfg(feature = "wasm-plugins")]
    pub plugin_identity: RwLock<Option<Identity>>,
}
241
/// Transaction bookkeeping for one client session.
#[derive(Debug, Clone, Default)]
pub struct TransactionState {
    /// True while the backend reports a non-idle ReadyForQuery status.
    pub in_transaction: bool,
    /// Proxy-side transaction id. NOTE(review): not written by any code
    /// visible in this chunk — confirm where it is assigned.
    pub tx_id: Option<Uuid>,
    /// Statements recorded for the current transaction. NOTE(review):
    /// presumably used for transaction replay — confirm.
    pub statements: Vec<StatementLog>,
    /// Whether the transaction is read-only (not maintained in this chunk).
    pub read_only: bool,
    /// Savepoint names in creation order (not maintained in this chunk).
    pub savepoints: Vec<String>,
}
256
/// One statement recorded in `TransactionState::statements`.
#[derive(Debug, Clone)]
pub struct StatementLog {
    /// SQL text as sent by the client.
    pub sql: String,
    /// Bound parameter values rendered as strings.
    pub params: Vec<String>,
    /// Checksum of the statement's result, when computed. NOTE(review):
    /// presumably compared during replay — confirm.
    pub result_checksum: Option<u64>,
    /// When the statement was executed.
    pub executed_at: chrono::DateTime<chrono::Utc>,
}
269
/// What `client_loop` should do with a message after the pre-query hook ran.
#[derive(Debug)]
// Some variants are presumably only constructed when `wasm-plugins` is
// enabled, hence the dead-code allowance.
#[allow(dead_code)] enum PreQueryAction {
    /// Route and forward the message to a backend as usual.
    Forward,
    /// Reject the query; the string is the reason reported to the client.
    Block(String),
    /// Answer from a plugin-provided cached payload instead of a backend
    /// (honoured only with `wasm-plugins`).
    Cached(Vec<u8>),
}
290
/// Routing decision returned by the route hook and applied in
/// `route_and_forward`.
#[derive(Debug)]
// Variants other than `None` are presumably only produced by plugins, so
// they may be unused in some feature combinations.
#[allow(dead_code)] enum RouteOverride {
    /// Keep the default read/write classification.
    None,
    /// Treat the message as a write (route to the primary).
    Primary,
    /// Treat the message as a read (route to a standby).
    Standby,
    /// Send the message to this specific node address.
    Node(String),
    /// Reject the message with the given reason.
    Block(String),
}
315
316impl ProxyServer {
317 #[cfg(feature = "wasm-plugins")]
325 fn init_plugin_manager(
326 toml_cfg: &crate::config::PluginToml,
327 ) -> Option<Arc<crate::plugins::PluginManager>> {
328 if !toml_cfg.enabled {
329 return None;
330 }
331
332 let runtime_cfg = crate::plugins::PluginRuntimeConfig::from(toml_cfg);
333 let plugin_dir = runtime_cfg.plugin_dir.clone();
334
335 let pm = match crate::plugins::PluginManager::new(runtime_cfg) {
336 Ok(pm) => Arc::new(pm),
337 Err(e) => {
338 tracing::error!(error = %e, "Failed to create plugin manager; plugins disabled");
339 return None;
340 }
341 };
342
343 match std::fs::read_dir(&plugin_dir) {
344 Ok(entries) => {
345 let mut loaded = 0usize;
346 let mut failed = 0usize;
347 for entry in entries.flatten() {
348 let path = entry.path();
349 if path.extension().and_then(|s| s.to_str()) != Some("wasm") {
350 continue;
351 }
352 match pm.load_plugin(&path) {
353 Ok(()) => loaded += 1,
354 Err(e) => {
355 failed += 1;
356 tracing::warn!(
357 path = %path.display(),
358 error = %e,
359 "Failed to load plugin"
360 );
361 }
362 }
363 }
364 tracing::info!(
365 dir = %plugin_dir.display(),
366 loaded = loaded,
367 failed = failed,
368 "Plugin loading complete"
369 );
370 }
371 Err(e) => {
372 tracing::warn!(
373 dir = %plugin_dir.display(),
374 error = %e,
375 "Plugin directory not readable; no plugins loaded"
376 );
377 }
378 }
379
380 Some(pm)
381 }
382
383 pub fn new(config: ProxyConfig) -> Result<Self> {
385 let (shutdown_tx, _) = broadcast::channel(1);
386
387 let mut health = HashMap::new();
389 for node in &config.nodes {
390 health.insert(
391 node.address(),
392 NodeHealth {
393 address: node.address(),
394 healthy: true, last_check: chrono::Utc::now(),
396 failure_count: 0,
397 last_error: None,
398 latency_ms: 0.0,
399 replication_lag_bytes: None,
400 },
401 );
402 }
403
404 #[cfg(feature = "pool-modes")]
406 let pool_manager = {
407 use crate::pool::PreparedStatementMode as PoolPreparedStatementMode;
408
409 let pool_config = PoolModeConfig {
410 default_mode: match config.pool_mode.mode {
411 crate::config::PoolingMode::Session => PoolingMode::Session,
412 crate::config::PoolingMode::Transaction => PoolingMode::Transaction,
413 crate::config::PoolingMode::Statement => PoolingMode::Statement,
414 },
415 max_pool_size: config.pool_mode.max_pool_size,
416 min_idle: config.pool_mode.min_idle,
417 idle_timeout_secs: config.pool_mode.idle_timeout_secs,
418 max_lifetime_secs: config.pool_mode.max_lifetime_secs,
419 acquire_timeout_secs: config.pool_mode.acquire_timeout_secs,
420 reset_query: config.pool_mode.reset_query.clone(),
421 prepared_statement_mode: match config.pool_mode.prepared_statement_mode {
422 crate::config::PreparedStatementMode::Disable => {
423 PoolPreparedStatementMode::Disable
424 }
425 crate::config::PreparedStatementMode::Track => {
426 PoolPreparedStatementMode::Track
427 }
428 crate::config::PreparedStatementMode::Named => {
429 PoolPreparedStatementMode::Named
430 }
431 },
432 test_on_acquire: config.pool.test_on_acquire,
433 validation_query: "SELECT 1".to_string(),
434 queue_timeout_secs: 30,
435 max_queue_size: 0,
436 };
437 Some(Arc::new(ConnectionPoolManager::new(pool_config)))
438 };
439
440 #[cfg(feature = "wasm-plugins")]
445 let plugin_manager = Self::init_plugin_manager(&config.plugins);
446
447 let state = Arc::new(ServerState {
448 sessions: RwLock::new(HashMap::new()),
449 health: RwLock::new(health),
450 metrics: ServerMetrics::default(),
451 lb_state: RwLock::new(LoadBalancerState {
452 rr_counter: 0,
453 }),
454 #[cfg(feature = "pool-modes")]
455 pool_manager,
456 #[cfg(feature = "wasm-plugins")]
457 plugin_manager,
458 #[cfg(feature = "ha-tr")]
459 transaction_journal: Arc::new(
460 crate::transaction_journal::TransactionJournal::new(),
461 ),
462 #[cfg(feature = "anomaly-detection")]
463 anomaly_detector: Arc::new(
464 crate::anomaly::AnomalyDetector::new(
465 crate::anomaly::AnomalyConfig::default(),
466 ),
467 ),
468 #[cfg(feature = "edge-proxy")]
469 edge_cache: Arc::new(crate::edge::EdgeCache::new(10_000)),
470 #[cfg(feature = "edge-proxy")]
471 edge_registry: Arc::new(crate::edge::EdgeRegistry::new(
472 32,
473 std::time::Duration::from_secs(120),
474 )),
475 });
476
477 Ok(Self {
478 config,
479 state,
480 shutdown_tx,
481 })
482 }
483
484 pub async fn run(&self) -> Result<()> {
486 let listener = TcpListener::bind(&self.config.listen_address)
487 .await
488 .map_err(|e| ProxyError::Network(format!("Failed to bind: {}", e)))?;
489
490 tracing::info!("Proxy listening on {}", self.config.listen_address);
491
492 let health_task = self.spawn_health_checker();
494 let pool_task = self.spawn_pool_manager();
495
496 let admin_task = self.spawn_admin_server();
498
499 let mut shutdown_rx = self.shutdown_tx.subscribe();
500
501 loop {
502 tokio::select! {
503 accept_result = listener.accept() => {
504 match accept_result {
505 Ok((stream, addr)) => {
506 self.state.metrics.connections_accepted.fetch_add(1, Ordering::Relaxed);
507 let state = self.state.clone();
508 let config = self.config.clone();
509 let shutdown_tx = self.shutdown_tx.clone();
510
511 tokio::spawn(async move {
512 if let Err(e) = Self::handle_client(stream, addr, state, config, shutdown_tx).await {
513 tracing::error!("Client handler error: {}", e);
514 }
515 });
516 }
517 Err(e) => {
518 tracing::error!("Accept error: {}", e);
519 }
520 }
521 }
522 _ = shutdown_rx.recv() => {
523 tracing::info!("Shutdown signal received");
524 break;
525 }
526 }
527 }
528
529 health_task.abort();
531 pool_task.abort();
532 admin_task.abort();
533
534 Ok(())
535 }
536
537 fn spawn_admin_server(&self) -> tokio::task::JoinHandle<()> {
539 let config = self.config.clone();
540 let state = self.state.clone();
541 let mut shutdown_rx = self.shutdown_tx.subscribe();
542
543 tokio::spawn(async move {
544 let admin_state = Arc::new(AdminState::new());
546
547 {
549 let mut snapshot = admin_state.config_snapshot.write().await;
550 *snapshot = ConfigSnapshot {
551 listen_address: config.listen_address.clone(),
552 admin_address: config.admin_address.clone(),
553 tr_enabled: config.tr_enabled,
554 tr_mode: format!("{:?}", config.tr_mode),
555 pool_min_connections: config.pool.min_connections,
556 pool_max_connections: config.pool.max_connections,
557 nodes: config.nodes.iter().map(|n| NodeSnapshot {
558 address: n.address(),
559 role: format!("{:?}", n.role),
560 weight: n.weight,
561 enabled: n.enabled,
562 }).collect(),
563 };
564 }
565
566 admin_state.set_proxy_config(config.clone()).await;
568
569 #[cfg(feature = "wasm-plugins")]
574 if let Some(ref pm) = state.plugin_manager {
575 admin_state.with_plugin_manager(pm.clone()).await;
576 }
577
578 #[cfg(feature = "ha-tr")]
585 {
586 let template = build_replay_backend_template(&config);
587 let engine = Arc::new(crate::replay::ReplayEngine::new(
588 state.transaction_journal.clone(),
589 template,
590 ));
591 admin_state.with_replay_engine(engine).await;
592 }
593
594 #[cfg(feature = "anomaly-detection")]
598 admin_state
599 .with_anomaly_detector(state.anomaly_detector.clone())
600 .await;
601
602 #[cfg(feature = "edge-proxy")]
605 admin_state
606 .with_edge(state.edge_cache.clone(), state.edge_registry.clone())
607 .await;
608
609 let admin_server = AdminServer::new(config.admin_address.clone(), admin_state.clone());
611
612 let admin_state_sync = admin_state.clone();
614 let server_state = state.clone();
615 let sync_task = tokio::spawn(async move {
616 let mut interval = tokio::time::interval(std::time::Duration::from_secs(1));
617 loop {
618 interval.tick().await;
619
620 {
622 let health = server_state.health.read().await;
623 let mut admin_health = admin_state_sync.node_health.write().await;
624 *admin_health = health.clone();
625 }
626
627 {
629 let metrics = ServerMetricsSnapshot {
630 connections_accepted: server_state.metrics.connections_accepted.load(Ordering::Relaxed),
631 connections_closed: server_state.metrics.connections_closed.load(Ordering::Relaxed),
632 queries_processed: server_state.metrics.queries_processed.load(Ordering::Relaxed),
633 bytes_received: server_state.metrics.bytes_received.load(Ordering::Relaxed),
634 bytes_sent: server_state.metrics.bytes_sent.load(Ordering::Relaxed),
635 failovers: server_state.metrics.failovers.load(Ordering::Relaxed),
636 };
637 let mut admin_metrics = admin_state_sync.metrics.write().await;
638 *admin_metrics = metrics;
639 }
640
641 {
643 let sessions = server_state.sessions.read().await;
644 let mut admin_sessions = admin_state_sync.active_sessions.write().await;
645 *admin_sessions = sessions.len() as u64;
646 }
647 }
648 });
649
650 tokio::select! {
652 result = admin_server.run() => {
653 if let Err(e) = result {
654 tracing::error!("Admin server error: {}", e);
655 }
656 }
657 _ = shutdown_rx.recv() => {
658 tracing::info!("Admin server shutting down");
659 }
660 }
661
662 sync_task.abort();
663 })
664 }
665
666 async fn handle_client(
668 mut stream: TcpStream,
669 addr: SocketAddr,
670 state: Arc<ServerState>,
671 config: ProxyConfig,
672 _shutdown_tx: broadcast::Sender<()>,
673 ) -> Result<()> {
674 tracing::debug!("New client connection from {}", addr);
675
676 let session = Arc::new(ClientSession {
678 id: Uuid::new_v4(),
679 client_addr: addr,
680 current_node: RwLock::new(None),
681 tx_state: RwLock::new(TransactionState::default()),
682 variables: RwLock::new(HashMap::new()),
683 created_at: chrono::Utc::now(),
684 tr_mode: config.tr_mode,
685 #[cfg(feature = "pool-modes")]
686 pool_client_id: ClientId::new(),
687 #[cfg(feature = "wasm-plugins")]
688 plugin_identity: RwLock::new(None),
689 });
690
691 {
693 let mut sessions = state.sessions.write().await;
694 sessions.insert(session.id, session.clone());
695 }
696
697 let result = Self::client_loop(&mut stream, &session, &state, &config).await;
699
700 {
702 let mut sessions = state.sessions.write().await;
703 sessions.remove(&session.id);
704 }
705
706 #[cfg(feature = "pool-modes")]
708 if let Some(ref pool_manager) = state.pool_manager {
709 if pool_manager.has_active_lease(&session.pool_client_id) {
711 tracing::debug!(
712 "Releasing pool lease for disconnecting client {:?}",
713 session.pool_client_id
714 );
715 }
718 }
719
720 state
721 .metrics
722 .connections_closed
723 .fetch_add(1, Ordering::Relaxed);
724
725 result
726 }
727
728 async fn client_loop(
730 stream: &mut TcpStream,
731 session: &Arc<ClientSession>,
732 state: &Arc<ServerState>,
733 config: &ProxyConfig,
734 ) -> Result<()> {
735 let codec = ProtocolCodec::new();
736 let mut buffer = BytesMut::with_capacity(8192);
737
738 let (mut backend_stream, mut backend_node): (Option<TcpStream>, Option<String>) =
740 match Self::handle_startup(stream, &mut buffer, &codec, session, state, config).await {
741 Ok((Some(stream_conn), node_addr)) => (Some(stream_conn), Some(node_addr)),
742 Ok((None, _)) => {
743 return Ok(());
745 }
746 Err(e) => {
747 tracing::error!("Startup failed: {}", e);
748 let err_msg = Self::create_error_response("08006", &format!("Startup failed: {}", e));
750 let _ = stream.write_all(&err_msg).await;
751 return Err(e);
752 }
753 };
754
755 loop {
757 let mut read_buf = vec![0u8; 8192];
759 let n = stream
760 .read(&mut read_buf)
761 .await
762 .map_err(|e| ProxyError::Network(format!("Read error: {}", e)))?;
763
764 if n == 0 {
765 break;
767 }
768
769 buffer.extend_from_slice(&read_buf[..n]);
770 state.metrics.bytes_received.fetch_add(n as u64, Ordering::Relaxed);
771
772 while let Some(msg) = codec.decode_message(&mut buffer)? {
774 if msg.msg_type == MessageType::Terminate {
776 return Ok(());
777 }
778
779 #[cfg(feature = "anomaly-detection")]
785 Self::record_anomaly_observation(&msg, state, session);
786
787 let (msg, action) = Self::apply_pre_query_hook(msg, state, session);
790
791 if let PreQueryAction::Block(reason) = &action {
792 tracing::info!(reason = %reason, "pre-query plugin blocked query");
793 Self::send_block_response(stream, reason, state).await?;
794 state.metrics.queries_processed.fetch_add(1, Ordering::Relaxed);
795 continue;
796 }
797
798 #[cfg(feature = "wasm-plugins")]
805 if let PreQueryAction::Cached(bytes) = &action {
806 match Self::synthesise_cached_response(bytes) {
807 Ok(reply) => {
808 stream
809 .write_all(&reply)
810 .await
811 .map_err(|e| {
812 ProxyError::Network(format!("Write error: {}", e))
813 })?;
814 state
815 .metrics
816 .bytes_sent
817 .fetch_add(reply.len() as u64, Ordering::Relaxed);
818 state
819 .metrics
820 .queries_processed
821 .fetch_add(1, Ordering::Relaxed);
822 continue;
823 }
824 Err(e) => {
825 tracing::warn!(
826 error = %e,
827 "failed to synthesise cached response; falling back to backend"
828 );
829 }
831 }
832 }
833
834 #[cfg(feature = "wasm-plugins")]
836 let forward_start = std::time::Instant::now();
837 let forward_result = Self::route_and_forward(
838 &msg,
839 backend_stream.take(),
840 backend_node.take(),
841 session,
842 state,
843 config,
844 )
845 .await;
846 #[cfg(feature = "wasm-plugins")]
847 Self::fire_post_query_hook(
848 &msg,
849 session,
850 state,
851 &forward_result,
852 forward_start.elapsed(),
853 );
854 let (response, new_backend, new_node) = forward_result?;
855
856 backend_stream = new_backend;
857 backend_node = new_node;
858
859 if !response.is_empty() {
861 stream
862 .write_all(&response)
863 .await
864 .map_err(|e| ProxyError::Network(format!("Write error: {}", e)))?;
865
866 state
867 .metrics
868 .bytes_sent
869 .fetch_add(response.len() as u64, Ordering::Relaxed);
870 }
871
872 state.metrics.queries_processed.fetch_add(1, Ordering::Relaxed);
873 }
874 }
875
876 Ok(())
877 }
878
879 async fn handle_startup(
881 client_stream: &mut TcpStream,
882 buffer: &mut BytesMut,
883 codec: &ProtocolCodec,
884 session: &Arc<ClientSession>,
885 state: &Arc<ServerState>,
886 config: &ProxyConfig,
887 ) -> Result<(Option<TcpStream>, String)> {
888 let mut read_buf = vec![0u8; 1024];
890 let n = client_stream
891 .read(&mut read_buf)
892 .await
893 .map_err(|e| ProxyError::Network(format!("Startup read error: {}", e)))?;
894
895 if n == 0 {
896 return Ok((None, String::new()));
897 }
898
899 buffer.extend_from_slice(&read_buf[..n]);
900
901 let startup_msg = codec.decode_startup(buffer)?;
903
904 match startup_msg {
905 Some(StartupMessage::SSLRequest) => {
906 client_stream
908 .write_all(&[b'N'])
909 .await
910 .map_err(|e| ProxyError::Network(format!("SSL reject error: {}", e)))?;
911
912 buffer.clear();
914 let n = client_stream
915 .read(&mut read_buf)
916 .await
917 .map_err(|e| ProxyError::Network(format!("Post-SSL read error: {}", e)))?;
918
919 if n == 0 {
920 return Ok((None, String::new()));
921 }
922
923 buffer.extend_from_slice(&read_buf[..n]);
924
925 return Self::process_startup(
927 client_stream,
928 buffer,
929 codec,
930 session,
931 state,
932 config,
933 )
934 .await;
935 }
936 Some(StartupMessage::CancelRequest { .. }) => {
937 return Ok((None, String::new()));
939 }
940 Some(StartupMessage::Startup { params, .. }) => {
941 return Self::connect_and_authenticate(
943 client_stream,
944 ¶ms,
945 session,
946 state,
947 config,
948 )
949 .await;
950 }
951 None => {
952 return Err(ProxyError::Protocol("Incomplete startup message".to_string()));
953 }
954 }
955 }
956
957 async fn process_startup(
959 client_stream: &mut TcpStream,
960 buffer: &mut BytesMut,
961 codec: &ProtocolCodec,
962 session: &Arc<ClientSession>,
963 state: &Arc<ServerState>,
964 config: &ProxyConfig,
965 ) -> Result<(Option<TcpStream>, String)> {
966 let startup_msg = codec.decode_startup(buffer)?;
967
968 match startup_msg {
969 Some(StartupMessage::Startup { params, .. }) => {
970 Self::connect_and_authenticate(client_stream, ¶ms, session, state, config).await
971 }
972 _ => Err(ProxyError::Protocol("Expected startup message".to_string())),
973 }
974 }
975
976 async fn connect_and_authenticate(
978 client_stream: &mut TcpStream,
979 params: &HashMap<String, String>,
980 session: &Arc<ClientSession>,
981 state: &Arc<ServerState>,
982 config: &ProxyConfig,
983 ) -> Result<(Option<TcpStream>, String)> {
984 Self::apply_authenticate_hook(params, session, state).await?;
990
991 let node_addr = Self::select_node(session, state, config).await?;
993
994 let mut backend = tokio::time::timeout(
996 config.pool.acquire_timeout(),
997 TcpStream::connect(&node_addr),
998 )
999 .await
1000 .map_err(|_| ProxyError::Connection(format!("Connection timeout to {}", node_addr)))?
1001 .map_err(|e| ProxyError::Connection(format!("Failed to connect to {}: {}", node_addr, e)))?;
1002
1003 let startup_bytes = Self::build_startup_message(params);
1005 backend
1006 .write_all(&startup_bytes)
1007 .await
1008 .map_err(|e| ProxyError::Network(format!("Backend startup write error: {}", e)))?;
1009
1010 Self::proxy_authentication(client_stream, &mut backend).await?;
1012
1013 {
1015 let mut vars = session.variables.write().await;
1016 for (k, v) in params {
1017 vars.insert(k.clone(), v.clone());
1018 }
1019 }
1020
1021 Ok((Some(backend), node_addr))
1022 }
1023
1024 fn build_startup_message(params: &HashMap<String, String>) -> Vec<u8> {
1026 let mut payload = BytesMut::new();
1027
1028 payload.put_u32(196608);
1030
1031 for (key, value) in params {
1033 payload.extend_from_slice(key.as_bytes());
1034 payload.put_u8(0);
1035 payload.extend_from_slice(value.as_bytes());
1036 payload.put_u8(0);
1037 }
1038 payload.put_u8(0); let mut msg = BytesMut::new();
1042 msg.put_u32((payload.len() + 4) as u32);
1043 msg.extend_from_slice(&payload);
1044
1045 msg.to_vec()
1046 }
1047
1048 async fn proxy_authentication(
1050 client_stream: &mut TcpStream,
1051 backend_stream: &mut TcpStream,
1052 ) -> Result<()> {
1053 let codec = ProtocolCodec::new();
1054 let mut backend_buffer = BytesMut::with_capacity(4096);
1055 let mut client_buffer = BytesMut::with_capacity(4096);
1056
1057 loop {
1058 let mut read_buf = vec![0u8; 4096];
1060 let n = backend_stream
1061 .read(&mut read_buf)
1062 .await
1063 .map_err(|e| ProxyError::Network(format!("Backend auth read error: {}", e)))?;
1064
1065 if n == 0 {
1066 return Err(ProxyError::Connection("Backend closed during auth".to_string()));
1067 }
1068
1069 backend_buffer.extend_from_slice(&read_buf[..n]);
1070
1071 client_stream
1073 .write_all(&read_buf[..n])
1074 .await
1075 .map_err(|e| ProxyError::Network(format!("Client auth write error: {}", e)))?;
1076
1077 while let Some(msg) = codec.decode_message(&mut backend_buffer.clone())? {
1079 match msg.msg_type {
1080 MessageType::AuthRequest => {
1081 if msg.payload.len() >= 4 {
1083 let auth_type =
1084 i32::from_be_bytes([msg.payload[0], msg.payload[1], msg.payload[2], msg.payload[3]]);
1085 if auth_type == 0 {
1086 }
1088 }
1089 }
1090 MessageType::ReadyForQuery => {
1091 return Ok(());
1093 }
1094 MessageType::ErrorResponse => {
1095 return Err(ProxyError::Auth("Authentication failed".to_string()));
1097 }
1098 _ => {
1099 }
1101 }
1102 let _ = codec.decode_message(&mut backend_buffer)?;
1104 }
1105
1106 let n = tokio::time::timeout(Duration::from_millis(100), client_stream.read(&mut read_buf))
1109 .await;
1110
1111 if let Ok(Ok(n)) = n {
1112 if n > 0 {
1113 client_buffer.extend_from_slice(&read_buf[..n]);
1114 backend_stream
1115 .write_all(&read_buf[..n])
1116 .await
1117 .map_err(|e| ProxyError::Network(format!("Backend password write error: {}", e)))?;
1118 }
1119 }
1120 }
1121 }
1122
1123 async fn route_and_forward(
1125 msg: &Message,
1126 backend_stream: Option<TcpStream>,
1127 current_node: Option<String>,
1128 session: &Arc<ClientSession>,
1129 state: &Arc<ServerState>,
1130 config: &ProxyConfig,
1131 ) -> Result<(Vec<u8>, Option<TcpStream>, Option<String>)> {
1132 let default_is_write = Self::is_write_message(msg);
1134
1135 let route_override = Self::apply_route_hook(msg, state, session);
1138
1139 if let RouteOverride::Block(reason) = route_override {
1144 let mut response = Vec::with_capacity(64 + reason.len());
1145 response.extend_from_slice(&Self::create_error_response(
1146 "42000",
1147 &format!("Query blocked by route plugin: {}", reason),
1148 ));
1149 response.extend_from_slice(&Self::create_ready_for_query(b'I'));
1150 state
1151 .metrics
1152 .bytes_sent
1153 .fetch_add(response.len() as u64, Ordering::Relaxed);
1154 return Ok((response, backend_stream, current_node));
1155 }
1156
1157 let (is_write, forced_target) = match route_override {
1159 RouteOverride::None => (default_is_write, None),
1160 RouteOverride::Primary => (true, None),
1161 RouteOverride::Standby => (false, None),
1162 RouteOverride::Node(name) => (default_is_write, Some(name)),
1163 RouteOverride::Block(_) => unreachable!("handled above"),
1164 };
1165
1166 let need_switch = if let Some(ref forced) = forced_target {
1171 let health = state.health.read().await;
1172 let reuse = current_node
1173 .as_ref()
1174 .map(|c| c == forced && health.get(c).map(|h| h.healthy).unwrap_or(false))
1175 .unwrap_or(false);
1176 !reuse
1177 } else if let Some(ref current) = current_node {
1178 let health = state.health.read().await;
1179 let current_healthy = health.get(current).map(|h| h.healthy).unwrap_or(false);
1180
1181 if !current_healthy {
1182 true
1183 } else if is_write {
1184 let is_primary = config.nodes.iter()
1186 .find(|n| n.address() == *current)
1187 .map(|n| n.role == NodeRole::Primary)
1188 .unwrap_or(false);
1189 !is_primary
1190 } else {
1191 false
1192 }
1193 } else {
1194 true
1195 };
1196
1197 let target_node = if let Some(forced) = forced_target {
1198 forced
1199 } else if need_switch {
1200 if is_write {
1201 Self::select_primary_with_timeout(session, state, config).await?
1202 } else {
1203 Self::select_read_node(session, state, config).await?
1204 }
1205 } else {
1206 current_node.clone().unwrap()
1207 };
1208
1209 let mut backend = if need_switch {
1210 drop(backend_stream);
1212
1213 let new_backend = tokio::time::timeout(
1215 config.pool.acquire_timeout(),
1216 TcpStream::connect(&target_node),
1217 )
1218 .await
1219 .map_err(|_| ProxyError::Connection(format!("Connection timeout to {}", target_node)))?
1220 .map_err(|e| {
1221 ProxyError::Connection(format!("Failed to connect to {}: {}", target_node, e))
1222 })?;
1223
1224 let params = session.variables.read().await.clone();
1226 let startup = Self::build_startup_message(¶ms);
1227 let mut backend = new_backend;
1228 backend
1229 .write_all(&startup)
1230 .await
1231 .map_err(|e| ProxyError::Network(format!("Backend startup error: {}", e)))?;
1232
1233 Self::complete_backend_auth(&mut backend).await?;
1235
1236 tracing::debug!(
1237 "Switched backend from {:?} to {} for {} query",
1238 current_node,
1239 target_node,
1240 if is_write { "write" } else { "read" }
1241 );
1242
1243 backend
1244 } else {
1245 backend_stream.unwrap()
1246 };
1247
1248 let encoded = msg.encode();
1250 backend
1251 .write_all(&encoded)
1252 .await
1253 .map_err(|e| ProxyError::Network(format!("Backend write error: {}", e)))?;
1254
1255 let mut response = Vec::new();
1257 let mut response_buffer = BytesMut::with_capacity(8192);
1258 let codec = ProtocolCodec::new();
1259
1260 loop {
1261 let mut read_buf = vec![0u8; 8192];
1262 let n = tokio::time::timeout(Duration::from_secs(30), backend.read(&mut read_buf))
1263 .await
1264 .map_err(|_| ProxyError::Network("Backend read timeout".to_string()))?
1265 .map_err(|e| ProxyError::Network(format!("Backend read error: {}", e)))?;
1266
1267 if n == 0 {
1268 break;
1269 }
1270
1271 response.extend_from_slice(&read_buf[..n]);
1272 response_buffer.extend_from_slice(&read_buf[..n]);
1273
1274 while let Some(resp_msg) = codec.decode_message(&mut response_buffer.clone())? {
1276 if resp_msg.msg_type == MessageType::ReadyForQuery {
1277 if !resp_msg.payload.is_empty() {
1279 let status = TransactionStatus::from_byte(resp_msg.payload[0]);
1280 let mut tx_state = session.tx_state.write().await;
1281 tx_state.in_transaction = status != TransactionStatus::Idle;
1282 }
1283 return Ok((response, Some(backend), Some(target_node)));
1284 }
1285 let _ = codec.decode_message(&mut response_buffer)?;
1286 }
1287 }
1288
1289 Ok((response, Some(backend), Some(target_node)))
1290 }
1291
1292 fn is_write_message(msg: &Message) -> bool {
1294 match msg.msg_type {
1295 MessageType::Query => {
1296 if let Ok(query_msg) = QueryMessage::parse(msg.payload.clone()) {
1298 Self::is_write_query(&query_msg.query)
1299 } else {
1300 false
1301 }
1302 }
1303 MessageType::Parse => {
1304 if let Ok(parse_msg) = ParseMessage::parse(msg.payload.clone()) {
1306 Self::is_write_query(&parse_msg.query)
1307 } else {
1308 false
1309 }
1310 }
1311 _ => false,
1313 }
1314 }
1315
1316 fn is_write_query(sql: &str) -> bool {
1318 let upper = sql.trim().to_uppercase();
1319
1320 if upper.starts_with("INSERT")
1322 || upper.starts_with("UPDATE")
1323 || upper.starts_with("DELETE")
1324 || upper.starts_with("CREATE")
1325 || upper.starts_with("DROP")
1326 || upper.starts_with("ALTER")
1327 || upper.starts_with("TRUNCATE")
1328 || upper.starts_with("GRANT")
1329 || upper.starts_with("REVOKE")
1330 || upper.starts_with("VACUUM")
1331 || upper.starts_with("REINDEX")
1332 || upper.starts_with("CLUSTER")
1333 {
1334 return true;
1335 }
1336
1337 if upper.starts_with("BEGIN")
1339 || upper.starts_with("START")
1340 || upper.starts_with("COMMIT")
1341 || upper.starts_with("ROLLBACK")
1342 || upper.starts_with("SAVEPOINT")
1343 || upper.starts_with("RELEASE")
1344 {
1345 return true;
1346 }
1347
1348 if upper.starts_with("SET") && !upper.starts_with("SET TRANSACTION READ ONLY") {
1350 return true;
1351 }
1352
1353 false
1354 }
1355
1356 async fn select_primary_with_timeout(
1358 session: &Arc<ClientSession>,
1359 state: &Arc<ServerState>,
1360 config: &ProxyConfig,
1361 ) -> Result<String> {
1362 let timeout = config.write_timeout();
1363 let start = std::time::Instant::now();
1364 let check_interval = Duration::from_millis(500);
1365
1366 loop {
1367 let health = state.health.read().await;
1369 let primary = config
1370 .nodes
1371 .iter()
1372 .find(|n| n.role == NodeRole::Primary && n.enabled);
1373
1374 if let Some(primary_node) = primary {
1375 if let Some(node_health) = health.get(&primary_node.address()) {
1376 if node_health.healthy {
1377 let mut current = session.current_node.write().await;
1379 *current = Some(primary_node.address());
1380 return Ok(primary_node.address());
1381 }
1382 }
1383 }
1384 drop(health);
1385
1386 if start.elapsed() >= timeout {
1388 state.metrics.failovers.fetch_add(1, Ordering::Relaxed);
1389 return Err(ProxyError::NoHealthyNodes);
1390 }
1391
1392 tracing::warn!(
1393 "Primary unavailable, waiting for failover... ({:.1}s elapsed, {:.1}s timeout)",
1394 start.elapsed().as_secs_f64(),
1395 timeout.as_secs_f64()
1396 );
1397
1398 tokio::time::sleep(check_interval).await;
1400 }
1401 }
1402
1403 async fn select_read_node(
1405 session: &Arc<ClientSession>,
1406 state: &Arc<ServerState>,
1407 config: &ProxyConfig,
1408 ) -> Result<String> {
1409 {
1411 let tx_state = session.tx_state.read().await;
1412 if tx_state.in_transaction {
1413 if let Some(node) = session.current_node.read().await.clone() {
1414 return Ok(node);
1415 }
1416 }
1417 }
1418
1419 let health = state.health.read().await;
1421 let healthy_standbys: Vec<&NodeConfig> = config
1422 .nodes
1423 .iter()
1424 .filter(|n| {
1425 n.enabled
1426 && (n.role == NodeRole::Standby || n.role == NodeRole::ReadReplica)
1427 && health
1428 .get(&n.address())
1429 .map(|h| h.healthy)
1430 .unwrap_or(false)
1431 })
1432 .collect();
1433
1434 if !healthy_standbys.is_empty() {
1435 let mut lb_state = state.lb_state.write().await;
1437 let index = lb_state.rr_counter as usize % healthy_standbys.len();
1438 lb_state.rr_counter = lb_state.rr_counter.wrapping_add(1);
1439 let node_addr = healthy_standbys[index].address();
1440
1441 let mut current = session.current_node.write().await;
1442 *current = Some(node_addr.clone());
1443 return Ok(node_addr);
1444 }
1445
1446 Self::select_node(session, state, config).await
1448 }
1449
1450 async fn complete_backend_auth(backend: &mut TcpStream) -> Result<()> {
1453 let codec = ProtocolCodec::new();
1454 let mut buffer = BytesMut::with_capacity(4096);
1455 let timeout = Duration::from_secs(10);
1456 let start = std::time::Instant::now();
1457
1458 loop {
1459 if start.elapsed() > timeout {
1460 return Err(ProxyError::Auth("Backend authentication timeout".to_string()));
1461 }
1462
1463 let mut read_buf = vec![0u8; 4096];
1464 let n = tokio::time::timeout(Duration::from_secs(5), backend.read(&mut read_buf))
1465 .await
1466 .map_err(|_| ProxyError::Auth("Read timeout during backend auth".to_string()))?
1467 .map_err(|e| ProxyError::Network(format!("Backend auth read error: {}", e)))?;
1468
1469 if n == 0 {
1470 return Err(ProxyError::Connection("Backend closed during auth".to_string()));
1471 }
1472
1473 buffer.extend_from_slice(&read_buf[..n]);
1474
1475 loop {
1477 if buffer.len() < 5 {
1478 break;
1479 }
1480
1481 let mut temp_buffer = buffer.clone();
1483 match codec.decode_message(&mut temp_buffer)? {
1484 Some(msg) => {
1485 match msg.msg_type {
1486 MessageType::ReadyForQuery => {
1487 return Ok(());
1489 }
1490 MessageType::ErrorResponse => {
1491 let err = ErrorResponse::parse(msg.payload)
1492 .map(|e| e.message().unwrap_or("Unknown error").to_string())
1493 .unwrap_or_else(|_| "Parse error".to_string());
1494 return Err(ProxyError::Auth(err));
1495 }
1496 _ => {
1497 }
1499 }
1500 let _ = codec.decode_message(&mut buffer)?;
1502 }
1503 None => {
1504 break;
1506 }
1507 }
1508 }
1509 }
1510 }
1511
1512 fn create_error_response(code: &str, message: &str) -> Vec<u8> {
1514 let mut fields = HashMap::new();
1515 fields.insert('S', "ERROR".to_string());
1516 fields.insert('V', "ERROR".to_string());
1517 fields.insert('C', code.to_string());
1518 fields.insert('M', message.to_string());
1519
1520 let err = ErrorResponse { fields };
1521 err.encode().encode().to_vec()
1522 }
1523
1524 fn create_ready_for_query(status: u8) -> Vec<u8> {
1527 let mut payload = BytesMut::with_capacity(1);
1528 payload.put_u8(status);
1529 Message::new(MessageType::ReadyForQuery, payload)
1530 .encode()
1531 .to_vec()
1532 }
1533
1534 #[cfg(feature = "wasm-plugins")]
1572 fn synthesise_cached_response(bytes: &[u8]) -> Result<Vec<u8>> {
1573 use serde::Deserialize;
1574
1575 #[derive(Deserialize)]
1576 struct CachedPayload {
1577 columns: Vec<ColumnDef>,
1578 rows: Vec<Vec<Option<String>>>,
1579 }
1580
1581 #[derive(Deserialize)]
1582 struct ColumnDef {
1583 name: String,
1584 #[serde(default = "default_text_oid")]
1585 oid: u32,
1586 }
1587
1588 fn default_text_oid() -> u32 {
1589 25 }
1591
1592 let payload: CachedPayload = serde_json::from_slice(bytes).map_err(|e| {
1593 ProxyError::Protocol(format!("invalid cached payload JSON: {}", e))
1594 })?;
1595
1596 if payload.columns.is_empty() {
1597 return Err(ProxyError::Protocol(
1598 "cached payload must declare at least one column".to_string(),
1599 ));
1600 }
1601
1602 let mut reply = Vec::new();
1603
1604 let mut rd = BytesMut::new();
1606 rd.put_u16(payload.columns.len() as u16);
1607 for col in &payload.columns {
1608 rd.extend_from_slice(col.name.as_bytes());
1609 rd.put_u8(0); rd.put_i32(0); rd.put_i16(0); rd.put_u32(col.oid);
1613 rd.put_i16(-1); rd.put_i32(-1); rd.put_i16(0); }
1617 reply.extend_from_slice(&Message::new(MessageType::RowDescription, rd).encode());
1618
1619 let column_count = payload.columns.len();
1621 for row in &payload.rows {
1622 if row.len() != column_count {
1623 return Err(ProxyError::Protocol(format!(
1624 "cached row has {} values but {} columns are declared",
1625 row.len(),
1626 column_count
1627 )));
1628 }
1629 let mut dr = BytesMut::new();
1630 dr.put_u16(row.len() as u16);
1631 for value in row {
1632 match value {
1633 Some(s) => {
1634 dr.put_i32(s.len() as i32);
1635 dr.extend_from_slice(s.as_bytes());
1636 }
1637 None => {
1638 dr.put_i32(-1); }
1640 }
1641 }
1642 reply.extend_from_slice(&Message::new(MessageType::DataRow, dr).encode());
1643 }
1644
1645 let tag = format!("SELECT {}", payload.rows.len());
1647 let mut cc = BytesMut::new();
1648 cc.extend_from_slice(tag.as_bytes());
1649 cc.put_u8(0);
1650 reply.extend_from_slice(&Message::new(MessageType::CommandComplete, cc).encode());
1651
1652 reply.extend_from_slice(&Self::create_ready_for_query(b'I'));
1654
1655 Ok(reply)
1656 }
1657
1658 fn apply_pre_query_hook(
1668 msg: Message,
1669 state: &Arc<ServerState>,
1670 session: &Arc<ClientSession>,
1671 ) -> (Message, PreQueryAction) {
1672 #[cfg(feature = "wasm-plugins")]
1673 {
1674 let pm = match state.plugin_manager.as_ref() {
1675 Some(pm) => pm,
1676 None => return (msg, PreQueryAction::Forward),
1677 };
1678
1679 if msg.msg_type != MessageType::Query {
1680 return (msg, PreQueryAction::Forward);
1681 }
1682
1683 let query_msg = match QueryMessage::parse(msg.payload.clone()) {
1684 Ok(q) => q,
1685 Err(_) => return (msg, PreQueryAction::Forward),
1686 };
1687
1688 let ctx = Self::build_query_context(&query_msg.query, session);
1689
1690 match pm.execute_pre_query(&ctx) {
1691 PreQueryResult::Continue => (msg, PreQueryAction::Forward),
1692 PreQueryResult::Block(reason) => (msg, PreQueryAction::Block(reason)),
1693 PreQueryResult::Rewrite(new_sql) => {
1694 let rewritten = QueryMessage { query: new_sql }.encode();
1695 (rewritten, PreQueryAction::Forward)
1696 }
1697 PreQueryResult::Cached(bytes) => (msg, PreQueryAction::Cached(bytes)),
1698 }
1699 }
1700 #[cfg(not(feature = "wasm-plugins"))]
1701 {
1702 let _ = (state, session);
1703 (msg, PreQueryAction::Forward)
1704 }
1705 }
1706
1707 #[cfg(feature = "anomaly-detection")]
1713 fn record_anomaly_observation(
1714 msg: &Message,
1715 state: &Arc<ServerState>,
1716 session: &Arc<ClientSession>,
1717 ) {
1718 if msg.msg_type != MessageType::Query {
1719 return;
1720 }
1721 let query_msg = match QueryMessage::parse(msg.payload.clone()) {
1722 Ok(q) => q,
1723 Err(_) => return,
1724 };
1725 let tenant = match session.variables.try_read() {
1734 Ok(vars) => vars
1735 .get("tenant_id")
1736 .or_else(|| vars.get("user"))
1737 .cloned()
1738 .unwrap_or_else(|| session.client_addr.ip().to_string()),
1739 Err(_) => session.client_addr.ip().to_string(),
1740 };
1741 let fingerprint = anomaly_fingerprint(&query_msg.query);
1742 let now = std::time::Instant::now();
1743 let iso = chrono::Utc::now().to_rfc3339();
1744 let obs = crate::anomaly::QueryObservation {
1745 tenant,
1746 fingerprint,
1747 sql: query_msg.query,
1748 timestamp: now,
1749 iso_timestamp: iso,
1750 };
1751 for ev in state.anomaly_detector.record_query(&obs) {
1752 tracing::warn!(
1753 anomaly = ?ev,
1754 "anomaly detected"
1755 );
1756 }
1757 }
1758
1759 async fn send_block_response(
1763 stream: &mut TcpStream,
1764 reason: &str,
1765 state: &Arc<ServerState>,
1766 ) -> Result<()> {
1767 let err = Self::create_error_response(
1768 "42000",
1769 &format!("Query blocked by plugin: {}", reason),
1770 );
1771 stream
1772 .write_all(&err)
1773 .await
1774 .map_err(|e| ProxyError::Network(format!("Write error: {}", e)))?;
1775 let rfq = Self::create_ready_for_query(b'I');
1776 stream
1777 .write_all(&rfq)
1778 .await
1779 .map_err(|e| ProxyError::Network(format!("Write error: {}", e)))?;
1780 state
1781 .metrics
1782 .bytes_sent
1783 .fetch_add((err.len() + rfq.len()) as u64, Ordering::Relaxed);
1784 Ok(())
1785 }
1786
1787 #[cfg(feature = "wasm-plugins")]
1793 fn build_query_context(query: &str, session: &Arc<ClientSession>) -> QueryContext {
1794 let is_read_only = !Self::is_write_query(query);
1795 let mut hook_context = HookContext::default();
1796 hook_context.client_id = Some(session.id.to_string());
1797 QueryContext {
1798 query: query.to_string(),
1799 normalized: query.to_string(),
1800 tables: Vec::new(),
1801 is_read_only,
1802 hook_context,
1803 }
1804 }
1805
1806 async fn apply_authenticate_hook(
1827 _params: &HashMap<String, String>,
1828 _session: &Arc<ClientSession>,
1829 _state: &Arc<ServerState>,
1830 ) -> Result<()> {
1831 #[cfg(feature = "wasm-plugins")]
1832 {
1833 let pm = match _state.plugin_manager.as_ref() {
1834 Some(pm) => pm,
1835 None => return Ok(()),
1836 };
1837
1838 let request = PluginAuthRequest {
1839 headers: HashMap::new(),
1840 username: _params.get("user").cloned(),
1841 password: None,
1842 client_ip: _session.client_addr.ip().to_string(),
1843 database: _params.get("database").cloned(),
1844 };
1845
1846 match pm.execute_authenticate(&request) {
1847 AuthResult::Defer => Ok(()),
1848 AuthResult::Success(identity) => {
1849 tracing::debug!(
1850 user = %identity.username,
1851 roles = ?identity.roles,
1852 "plugin authenticated user"
1853 );
1854 *_session.plugin_identity.write().await = Some(identity);
1855 Ok(())
1856 }
1857 AuthResult::Denied(reason) => {
1858 tracing::info!(
1859 reason = %reason,
1860 client = %_session.client_addr,
1861 user = ?_params.get("user"),
1862 "plugin denied authentication"
1863 );
1864 Err(ProxyError::Auth(format!(
1865 "authentication denied by plugin: {}",
1866 reason
1867 )))
1868 }
1869 }
1870 }
1871 #[cfg(not(feature = "wasm-plugins"))]
1872 {
1873 Ok(())
1874 }
1875 }
1876
1877 fn apply_route_hook(
1880 msg: &Message,
1881 state: &Arc<ServerState>,
1882 session: &Arc<ClientSession>,
1883 ) -> RouteOverride {
1884 #[cfg(feature = "wasm-plugins")]
1885 {
1886 let pm = match state.plugin_manager.as_ref() {
1887 Some(pm) => pm,
1888 None => return RouteOverride::None,
1889 };
1890 if msg.msg_type != MessageType::Query {
1891 return RouteOverride::None;
1892 }
1893 let query_msg = match QueryMessage::parse(msg.payload.clone()) {
1894 Ok(q) => q,
1895 Err(_) => return RouteOverride::None,
1896 };
1897 let ctx = Self::build_query_context(&query_msg.query, session);
1898 match pm.execute_route(&ctx) {
1899 RouteResult::Default => RouteOverride::None,
1900 RouteResult::Primary => RouteOverride::Primary,
1901 RouteResult::Standby => RouteOverride::Standby,
1902 RouteResult::Node(name) => RouteOverride::Node(name),
1903 RouteResult::Block(reason) => RouteOverride::Block(reason),
1904 RouteResult::Branch(name) => {
1905 tracing::warn!(
1906 branch = %name,
1907 "Route hook returned Branch but branch routing is not yet wired — using default"
1908 );
1909 RouteOverride::None
1910 }
1911 }
1912 }
1913 #[cfg(not(feature = "wasm-plugins"))]
1914 {
1915 let _ = (msg, state, session);
1916 RouteOverride::None
1917 }
1918 }
1919
1920 #[cfg(feature = "wasm-plugins")]
1924 fn fire_post_query_hook(
1925 msg: &Message,
1926 session: &Arc<ClientSession>,
1927 state: &Arc<ServerState>,
1928 result: &Result<(Vec<u8>, Option<TcpStream>, Option<String>)>,
1929 elapsed: Duration,
1930 ) {
1931 let pm = match state.plugin_manager.as_ref() {
1932 Some(pm) => pm,
1933 None => return,
1934 };
1935 if msg.msg_type != MessageType::Query {
1936 return;
1937 }
1938 let query_msg = match QueryMessage::parse(msg.payload.clone()) {
1939 Ok(q) => q,
1940 Err(_) => return,
1941 };
1942 let ctx = Self::build_query_context(&query_msg.query, session);
1943 let outcome = match result {
1944 Ok((resp, _, node)) => PostQueryOutcome {
1945 success: true,
1946 target_node: node.clone(),
1947 elapsed_us: elapsed.as_micros() as u64,
1948 response_bytes: resp.len() as u64,
1949 error: None,
1950 },
1951 Err(e) => PostQueryOutcome {
1952 success: false,
1953 target_node: None,
1954 elapsed_us: elapsed.as_micros() as u64,
1955 response_bytes: 0,
1956 error: Some(e.to_string()),
1957 },
1958 };
1959 pm.execute_post_query(&ctx, &outcome);
1960 }
1961
1962 async fn select_node(
1966 session: &Arc<ClientSession>,
1967 state: &Arc<ServerState>,
1968 config: &ProxyConfig,
1969 ) -> Result<String> {
1970 {
1972 let tx_state = session.tx_state.read().await;
1973 if tx_state.in_transaction {
1974 if let Some(node) = session.current_node.read().await.clone() {
1975 return Ok(node);
1976 }
1977 }
1978 }
1979
1980 let health = state.health.read().await;
1982 let healthy_nodes: Vec<&NodeConfig> = config
1983 .nodes
1984 .iter()
1985 .filter(|n| {
1986 n.enabled
1987 && health
1988 .get(&n.address())
1989 .map(|h| h.healthy)
1990 .unwrap_or(false)
1991 })
1992 .collect();
1993
1994 if healthy_nodes.is_empty() {
1995 return Err(ProxyError::NoHealthyNodes);
1996 }
1997
1998 if let Some(primary) = healthy_nodes.iter().find(|n| n.role == NodeRole::Primary) {
2000 let node_addr = primary.address();
2001 let mut current = session.current_node.write().await;
2002 *current = Some(node_addr.clone());
2003 return Ok(node_addr);
2004 }
2005
2006 if let Some(standby) = healthy_nodes.iter().find(|n| n.role == NodeRole::Standby) {
2009 tracing::warn!("Primary unavailable, connecting to standby for initial session");
2010 let node_addr = standby.address();
2011 let mut current = session.current_node.write().await;
2012 *current = Some(node_addr.clone());
2013 return Ok(node_addr);
2014 }
2015
2016 Err(ProxyError::NoHealthyNodes)
2018 }
2019
2020 fn spawn_health_checker(&self) -> tokio::task::JoinHandle<()> {
2022 let state = self.state.clone();
2023 let config = self.config.clone();
2024 let mut shutdown_rx = self.shutdown_tx.subscribe();
2025
2026 tokio::spawn(async move {
2027 let mut interval =
2028 tokio::time::interval(std::time::Duration::from_secs(config.health.check_interval_secs));
2029
2030 loop {
2031 tokio::select! {
2032 _ = interval.tick() => {
2033 Self::check_all_nodes(&state, &config).await;
2034 }
2035 _ = shutdown_rx.recv() => {
2036 break;
2037 }
2038 }
2039 }
2040 })
2041 }
2042
2043 async fn check_all_nodes(state: &Arc<ServerState>, config: &ProxyConfig) {
2045 for node in &config.nodes {
2046 let result = Self::check_node_health(node, config).await;
2047 let mut health = state.health.write().await;
2048
2049 if let Some(node_health) = health.get_mut(&node.address()) {
2050 match result {
2051 Ok(latency) => {
2052 node_health.healthy = true;
2053 node_health.failure_count = 0;
2054 node_health.latency_ms = latency;
2055 node_health.last_error = None;
2056 }
2057 Err(e) => {
2058 node_health.failure_count += 1;
2059 node_health.last_error = Some(e.to_string());
2060
2061 if node_health.failure_count >= config.health.failure_threshold {
2062 node_health.healthy = false;
2063 tracing::warn!(
2064 "Node {} marked unhealthy after {} failures",
2065 node.address(),
2066 node_health.failure_count
2067 );
2068 }
2069 }
2070 }
2071 node_health.last_check = chrono::Utc::now();
2072 }
2073 }
2074 }
2075
2076 async fn check_node_health(node: &NodeConfig, config: &ProxyConfig) -> Result<f64> {
2078 let start = std::time::Instant::now();
2079
2080 let timeout = std::time::Duration::from_secs(config.health.check_timeout_secs);
2081 let _stream = tokio::time::timeout(timeout, TcpStream::connect(node.address()))
2082 .await
2083 .map_err(|_| ProxyError::HealthCheck(format!("Timeout connecting to {}", node.address())))?
2084 .map_err(|e| {
2085 ProxyError::HealthCheck(format!("Failed to connect to {}: {}", node.address(), e))
2086 })?;
2087
2088 let latency = start.elapsed().as_secs_f64() * 1000.0;
2090 Ok(latency)
2091 }
2092
2093 fn spawn_pool_manager(&self) -> tokio::task::JoinHandle<()> {
2095 let state = self.state.clone();
2096 let mut shutdown_rx = self.shutdown_tx.subscribe();
2097
2098 tokio::spawn(async move {
2099 let mut interval = tokio::time::interval(std::time::Duration::from_secs(30));
2100
2101 loop {
2102 tokio::select! {
2103 _ = interval.tick() => {
2104 #[cfg(feature = "pool-modes")]
2106 if let Some(ref pool_manager) = state.pool_manager {
2107 pool_manager.evict_idle().await;
2108 tracing::trace!("Pool-modes idle eviction completed");
2109 }
2110 }
2111 _ = shutdown_rx.recv() => {
2112 #[cfg(feature = "pool-modes")]
2114 if let Some(ref pool_manager) = state.pool_manager {
2115 pool_manager.close_all().await;
2116 tracing::info!("Pool-modes manager closed all connections");
2117 }
2118 break;
2119 }
2120 }
2121 }
2122 })
2123 }
2124
2125 pub fn shutdown(&self) {
2127 let _ = self.shutdown_tx.send(());
2128 }
2129
2130 #[cfg(feature = "pool-modes")]
2132 pub async fn pool_mode_stats(&self) -> Option<PoolModeStatsSnapshot> {
2133 if let Some(ref pool_manager) = self.state.pool_manager {
2134 let stats = pool_manager.get_stats().await;
2135 let metrics = pool_manager.metrics().snapshot();
2136 let default_mode = pool_manager.default_mode();
2137
2138 let avg_lease_duration_ms = metrics
2140 .mode_stats
2141 .get(&default_mode)
2142 .map(|s| s.avg_lease_duration_ms as u64)
2143 .unwrap_or(0);
2144
2145 Some(PoolModeStatsSnapshot {
2146 mode: format!("{:?}", default_mode),
2147 total_connections: stats.total_connections,
2148 active_leases: stats.active_connections,
2149 idle_connections: stats.idle_connections,
2150 node_count: stats.node_count,
2151 acquires: metrics.acquires,
2152 releases: metrics.releases,
2153 acquire_failures: metrics.acquire_failures,
2154 acquire_timeouts: metrics.acquire_timeouts,
2155 transactions_completed: metrics.transactions_completed,
2156 statements_executed: metrics.statements_executed,
2157 avg_lease_duration_ms,
2158 })
2159 } else {
2160 None
2161 }
2162 }
2163
2164 #[cfg(feature = "pool-modes")]
2166 pub async fn add_node_to_pool(&self, node: &NodeConfig) {
2167 if let Some(ref pool_manager) = self.state.pool_manager {
2168 let endpoint = NodeEndpoint::new(&node.host, node.port)
2169 .with_role(match node.role {
2170 NodeRole::Primary => crate::NodeRole::Primary,
2171 NodeRole::Standby => crate::NodeRole::Standby,
2172 NodeRole::ReadReplica => crate::NodeRole::ReadReplica,
2173 })
2174 .with_weight(node.weight);
2175 pool_manager.add_node(&endpoint).await;
2176 tracing::info!("Added node {} to pool manager", node.address());
2177 }
2178 }
2179
2180 pub fn metrics(&self) -> ServerMetricsSnapshot {
2182 ServerMetricsSnapshot {
2183 connections_accepted: self.state.metrics.connections_accepted.load(Ordering::Relaxed),
2184 connections_closed: self.state.metrics.connections_closed.load(Ordering::Relaxed),
2185 queries_processed: self.state.metrics.queries_processed.load(Ordering::Relaxed),
2186 bytes_received: self.state.metrics.bytes_received.load(Ordering::Relaxed),
2187 bytes_sent: self.state.metrics.bytes_sent.load(Ordering::Relaxed),
2188 failovers: self.state.metrics.failovers.load(Ordering::Relaxed),
2189 }
2190 }
2191}
2192
/// Point-in-time copy of the proxy's traffic counters, as returned by
/// `ProxyServer::metrics`.
///
/// `Default` yields the all-zero snapshot. `PartialEq`/`Eq` allow snapshots to be
/// compared directly, for example in tests or when checking for changes.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ServerMetricsSnapshot {
    /// Value of the `connections_accepted` counter.
    pub connections_accepted: u64,
    /// Value of the `connections_closed` counter.
    pub connections_closed: u64,
    /// Value of the `queries_processed` counter.
    pub queries_processed: u64,
    /// Value of the `bytes_received` counter.
    pub bytes_received: u64,
    /// Value of the `bytes_sent` counter (includes proxy-generated replies such as
    /// plugin block responses).
    pub bytes_sent: u64,
    /// Value of the `failovers` counter. Among other possible sites, it is
    /// incremented when waiting for a healthy primary times out.
    pub failovers: u64,
}
2203
/// Point-in-time statistics of the pool-modes connection manager, as returned by
/// `ProxyServer::pool_mode_stats`.
///
/// `Default`, `PartialEq` and `Eq` are derived so snapshots can be built and
/// compared in tests.
#[cfg(feature = "pool-modes")]
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct PoolModeStatsSnapshot {
    /// `Debug`-formatted name of the manager's default pooling mode.
    pub mode: String,
    /// Total pooled connections reported by the manager.
    pub total_connections: usize,
    /// Connections currently leased out (the manager's `active_connections`).
    pub active_leases: usize,
    /// Pooled connections currently idle.
    pub idle_connections: usize,
    /// Number of backend nodes known to the manager.
    pub node_count: usize,
    /// Lease acquisitions recorded by the manager's metrics.
    pub acquires: u64,
    /// Lease releases recorded by the manager's metrics.
    pub releases: u64,
    /// Failed acquisition attempts.
    pub acquire_failures: u64,
    /// Acquisition attempts that timed out.
    pub acquire_timeouts: u64,
    /// Transactions completed, as counted by the manager's metrics.
    pub transactions_completed: u64,
    /// Statements executed, as counted by the manager's metrics.
    pub statements_executed: u64,
    /// Average lease duration for the default mode in milliseconds, truncated to an
    /// integer; 0 when that mode has no recorded stats.
    pub avg_lease_duration_ms: u64,
}
2233
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal config: loopback listener on an ephemeral port plus a single primary
    /// node at 127.0.0.1:5432. Nothing needs to be listening there for these tests.
    fn test_config() -> ProxyConfig {
        let mut config = ProxyConfig::default();
        config.listen_address = "127.0.0.1:0".to_string();
        config.add_node("127.0.0.1:5432", "primary").unwrap();
        config
    }

    #[test]
    fn test_server_creation() {
        let config = test_config();
        let server = ProxyServer::new(config);
        assert!(server.is_ok());
    }

    #[test]
    fn test_initial_metrics() {
        let config = test_config();
        let server = ProxyServer::new(config).unwrap();
        let metrics = server.metrics();
        // Every counter starts at zero on a freshly constructed server.
        assert_eq!(metrics.connections_accepted, 0);
        assert_eq!(metrics.connections_closed, 0);
        assert_eq!(metrics.queries_processed, 0);
        assert_eq!(metrics.bytes_received, 0);
        assert_eq!(metrics.bytes_sent, 0);
        assert_eq!(metrics.failovers, 0);
    }

    #[tokio::test]
    async fn test_session_creation() {
        let config = test_config();
        let server = ProxyServer::new(config).unwrap();

        let sessions = server.state.sessions.read().await;
        assert!(sessions.is_empty());
    }

    #[tokio::test]
    async fn test_node_health_initialization() {
        let config = test_config();
        let server = ProxyServer::new(config).unwrap();

        let health = server.state.health.read().await;
        assert!(!health.is_empty());

        // Nodes start out optimistic: healthy with no recorded failures.
        for node_health in health.values() {
            assert!(node_health.healthy);
            assert_eq!(node_health.failure_count, 0);
        }
    }

    /// A session from 127.0.0.1:0 with no pinned node, default transaction state
    /// and no session variables.
    fn make_test_session() -> Arc<ClientSession> {
        Arc::new(ClientSession {
            id: Uuid::new_v4(),
            client_addr: "127.0.0.1:0".parse().unwrap(),
            current_node: RwLock::new(None),
            tx_state: RwLock::new(TransactionState::default()),
            variables: RwLock::new(HashMap::new()),
            created_at: chrono::Utc::now(),
            tr_mode: crate::config::TrMode::default(),
            #[cfg(feature = "pool-modes")]
            pool_client_id: crate::pool::lease::ClientId::default(),
            #[cfg(feature = "wasm-plugins")]
            plugin_identity: RwLock::new(None),
        })
    }

    /// With no plugin manager the route hook never overrides routing.
    #[tokio::test]
    async fn test_apply_route_hook_no_plugin_manager_returns_none() {
        let config = test_config();
        let server = ProxyServer::new(config).unwrap();
        let session = make_test_session();

        let msg = QueryMessage {
            query: "SELECT * FROM users".to_string(),
        }
        .encode();

        let decision = ProxyServer::apply_route_hook(&msg, &server.state, &session);
        assert!(matches!(decision, RouteOverride::None));
    }

    /// With no plugin manager the pre-query hook forwards the message byte-for-byte.
    #[tokio::test]
    async fn test_apply_pre_query_hook_no_plugin_manager_forwards() {
        let config = test_config();
        let server = ProxyServer::new(config).unwrap();
        let session = make_test_session();

        let original = QueryMessage {
            query: "SELECT 1".to_string(),
        }
        .encode();
        let original_bytes = original.encode().to_vec();

        let (msg_out, action) =
            ProxyServer::apply_pre_query_hook(original, &server.state, &session);

        assert!(matches!(action, PreQueryAction::Forward));
        assert_eq!(msg_out.encode().to_vec(), original_bytes);
    }

    /// Non-`Query` messages (here `Sync`) are never routed by plugins.
    #[tokio::test]
    async fn test_apply_route_hook_skips_non_query_messages() {
        let config = test_config();
        let server = ProxyServer::new(config).unwrap();
        let session = make_test_session();

        let sync_msg = Message::empty(MessageType::Sync);
        let decision = ProxyServer::apply_route_hook(&sync_msg, &server.state, &session);
        assert!(matches!(decision, RouteOverride::None));
    }

    #[cfg(feature = "wasm-plugins")]
    #[test]
    fn test_init_plugin_manager_disabled_by_default_returns_none() {
        let config = test_config();
        assert!(!config.plugins.enabled);
        let pm = ProxyServer::init_plugin_manager(&config.plugins);
        assert!(pm.is_none());
    }

    /// A missing plugin directory is not fatal: a manager is still created. The
    /// warning it is expected to log is not asserted here.
    #[cfg(feature = "wasm-plugins")]
    #[test]
    fn test_init_plugin_manager_missing_dir_logs_warning() {
        let mut config = test_config();
        config.plugins.enabled = true;
        config.plugins.plugin_dir = "/definitely/not/a/real/path".to_string();

        let pm = ProxyServer::init_plugin_manager(&config.plugins);
        assert!(pm.is_some());
    }

    /// With no plugin manager authentication is deferred and no identity is stored.
    #[tokio::test]
    async fn test_apply_authenticate_hook_no_plugin_manager_defers() {
        let config = test_config();
        let server = ProxyServer::new(config).unwrap();
        let session = make_test_session();

        let mut params = HashMap::new();
        params.insert("user".to_string(), "alice".to_string());
        params.insert("database".to_string(), "app".to_string());

        let result =
            ProxyServer::apply_authenticate_hook(&params, &session, &server.state).await;
        assert!(result.is_ok());

        #[cfg(feature = "wasm-plugins")]
        {
            let ident = session.plugin_identity.read().await;
            assert!(ident.is_none());
        }
    }

    /// A cached payload is turned into T, D..., C, Z frames that tile the reply
    /// exactly and end with an idle ReadyForQuery.
    #[cfg(feature = "wasm-plugins")]
    #[test]
    fn test_synthesise_cached_response_roundtrip() {
        let payload = br#"{
            "columns": [
                {"name": "id", "oid": 23},
                {"name": "email", "oid": 25}
            ],
            "rows": [
                ["1", "alice@example.com"],
                ["2", null]
            ]
        }"#;
        let reply =
            ProxyServer::synthesise_cached_response(payload).expect("synthesis");

        // Walk the frames: 1-byte tag + 4-byte big-endian length (which includes
        // itself but not the tag).
        let mut tags = Vec::new();
        let mut i = 0;
        while i < reply.len() {
            let tag = reply[i];
            let len = u32::from_be_bytes([
                reply[i + 1],
                reply[i + 2],
                reply[i + 3],
                reply[i + 4],
            ]) as usize;
            tags.push(tag);
            i += 1 + len;
        }
        assert_eq!(i, reply.len(), "no trailing bytes");
        assert_eq!(
            tags,
            vec![b'T', b'D', b'D', b'C', b'Z'],
            "wire frame order"
        );

        // The final byte is the ReadyForQuery status: idle.
        assert_eq!(*reply.last().unwrap(), b'I');
    }

    /// A row wider than the declared columns is rejected.
    #[cfg(feature = "wasm-plugins")]
    #[test]
    fn test_synthesise_cached_response_rejects_row_width_mismatch() {
        let payload = br#"{
            "columns": [{"name": "id", "oid": 23}, {"name": "name", "oid": 25}],
            "rows": [["1", "alice", "extra"]]
        }"#;
        let result = ProxyServer::synthesise_cached_response(payload);
        assert!(matches!(result, Err(ProxyError::Protocol(_))));
    }

    /// A payload without columns cannot describe a result set.
    #[cfg(feature = "wasm-plugins")]
    #[test]
    fn test_synthesise_cached_response_rejects_empty_columns() {
        let payload = br#"{ "columns": [], "rows": [] }"#;
        let result = ProxyServer::synthesise_cached_response(payload);
        assert!(matches!(result, Err(ProxyError::Protocol(_))));
    }

    #[cfg(feature = "wasm-plugins")]
    #[test]
    fn test_synthesise_cached_response_rejects_bad_json() {
        let payload = b"not json at all";
        let result = ProxyServer::synthesise_cached_response(payload);
        assert!(matches!(result, Err(ProxyError::Protocol(_))));
    }

    /// A plugin manager with no plugins loaded defers authentication and leaves
    /// the session identity unset.
    #[cfg(feature = "wasm-plugins")]
    #[tokio::test]
    async fn test_apply_authenticate_hook_with_manager_no_plugins_defers() {
        use crate::plugins::{PluginManager, PluginRuntimeConfig};

        let config = test_config();
        let server = ProxyServer::new(config).unwrap();
        let session = make_test_session();

        let pm = Arc::new(PluginManager::new(PluginRuntimeConfig::default()).unwrap());
        // Hand-built state so a plugin manager can be injected.
        let augmented_state = Arc::new(ServerState {
            sessions: RwLock::new(HashMap::new()),
            health: RwLock::new(HashMap::new()),
            metrics: ServerMetrics::default(),
            lb_state: RwLock::new(LoadBalancerState {
                rr_counter: 0,
            }),
            #[cfg(feature = "pool-modes")]
            pool_manager: None,
            plugin_manager: Some(pm),
            #[cfg(feature = "ha-tr")]
            transaction_journal: Arc::new(
                crate::transaction_journal::TransactionJournal::new(),
            ),
            #[cfg(feature = "anomaly-detection")]
            anomaly_detector: Arc::new(
                crate::anomaly::AnomalyDetector::new(
                    crate::anomaly::AnomalyConfig::default(),
                ),
            ),
            #[cfg(feature = "edge-proxy")]
            edge_cache: Arc::new(crate::edge::EdgeCache::new(10_000)),
            #[cfg(feature = "edge-proxy")]
            edge_registry: Arc::new(crate::edge::EdgeRegistry::new(
                32,
                std::time::Duration::from_secs(120),
            )),
        });

        let mut params = HashMap::new();
        params.insert("user".to_string(), "alice".to_string());

        let result =
            ProxyServer::apply_authenticate_hook(&params, &session, &augmented_state).await;
        assert!(result.is_ok());
        let ident = session.plugin_identity.read().await;
        assert!(ident.is_none());
        let _ = server;
    }
}