1#[cfg(test)]
7use super::NodeRole;
8use super::{NodeEndpoint, NodeId, ProxyError, Result};
9use crate::backend::{BackendClient, BackendConfig};
10use std::collections::HashMap;
11use std::sync::atomic::{AtomicU64, Ordering};
12use std::sync::Arc;
13use std::time::Duration;
14use tokio::sync::{mpsc, RwLock};
15
16#[cfg(feature = "ha-tr")]
18use super::failover_replay::{FailoverReplay, ReplayConfig, ReplayResult};
19#[cfg(feature = "ha-tr")]
20use super::transaction_journal::TransactionJournal;
21
/// Tunables that govern how the controller reacts to a primary outage.
#[derive(Debug, Clone)]
pub struct FailoverConfig {
    /// Failure-detection time budget.
    /// NOTE(review): not read anywhere in the visible part of this file.
    pub detection_time: Duration,
    /// Upper bound for the standby-sync wait and the LSN catch-up wait; also
    /// passed (clamped to 10..=300 seconds) to `pg_promote`.
    pub failover_timeout: Duration,
    /// When `true`, a primary-failure report immediately starts a failover.
    pub auto_failover: bool,
    /// When `true`, synchronous standbys outrank asynchronous ones during
    /// candidate selection.
    pub prefer_sync_standby: bool,
    /// Lag above which the chosen standby is given time to catch up before
    /// it is promoted.
    pub max_lag_bytes: u64,
    /// NOTE(review): not read anywhere in the visible part of this file.
    pub retry_failed: bool,
    /// NOTE(review): not read anywhere in the visible part of this file.
    pub max_retries: u32,
}

impl Default for FailoverConfig {
    /// 10 s detection, 60 s failover timeout, automatic failover that prefers
    /// synchronous standbys, 16 MiB tolerated lag, and up to 3 retries.
    fn default() -> Self {
        const MIB: u64 = 1024 * 1024;

        FailoverConfig {
            detection_time: Duration::from_secs(10),
            failover_timeout: Duration::from_secs(60),
            auto_failover: true,
            prefer_sync_standby: true,
            max_lag_bytes: 16 * MIB,
            retry_failed: true,
            max_retries: 3,
        }
    }
}
54
/// How a failover may be triggered.
///
/// NOTE(review): this enum is not referenced in the visible part of this
/// file; the variant descriptions below are inferred from their names.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FailoverMode {
    /// Presumably: the controller starts failovers on its own.
    Automatic,
    /// Presumably: failovers only happen on operator request.
    Manual,
    /// Presumably: failovers are never performed.
    Disabled,
}
65
/// Phase the controller is currently in.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FailoverState {
    /// Initial state; also restored about one second after an automatic
    /// failover completes.
    Normal,
    /// The current primary was reported as failed.
    PrimaryFailed,
    /// A failover (automatic or manual) has started.
    InProgress,
    /// The chosen standby lags beyond `max_lag_bytes` and is being waited on.
    WaitingForSync,
    /// The standby was promoted and recorded as the new primary.
    Completed,
    /// The failover was aborted; see the matching `FailoverFailed` event.
    Failed,
}
82
83#[derive(Debug, Clone)]
85pub enum FailoverEvent {
86 PrimaryFailed { node_id: NodeId },
88 FailoverStarted { from: NodeId, to: NodeId },
90 WaitingForSync { standby: NodeId, lag_bytes: u64 },
92 StandbyPromoted { new_primary: NodeId },
94 FailoverCompleted { duration_ms: u64 },
96 FailoverFailed { reason: String },
98 OldPrimaryRecovered { node_id: NodeId },
100}
101
102#[derive(Debug, Clone)]
104pub struct FailoverCandidate {
105 pub node_id: NodeId,
107 pub endpoint: NodeEndpoint,
109 pub is_sync: bool,
111 pub lag_bytes: u64,
113 pub priority: u32,
115 pub last_heartbeat: Option<chrono::DateTime<chrono::Utc>>,
117}
118
119#[derive(Debug, Clone)]
121pub struct FailoverHistoryEntry {
122 pub id: uuid::Uuid,
124 pub started_at: chrono::DateTime<chrono::Utc>,
126 pub ended_at: Option<chrono::DateTime<chrono::Utc>>,
128 pub old_primary: NodeId,
130 pub new_primary: Option<NodeId>,
132 pub success: bool,
134 pub error: Option<String>,
136}
137
138pub struct FailoverController {
140 config: FailoverConfig,
142 state: Arc<RwLock<FailoverState>>,
144 current_primary: Arc<RwLock<Option<NodeId>>>,
146 candidates: Arc<RwLock<HashMap<NodeId, FailoverCandidate>>>,
148 event_tx: mpsc::Sender<FailoverEvent>,
150 event_rx: Option<mpsc::Receiver<FailoverEvent>>,
152 failover_count: AtomicU64,
154 history: Arc<RwLock<Vec<FailoverHistoryEntry>>>,
156 backend_template: Option<BackendConfig>,
162}
163
164impl FailoverController {
165 pub fn new(config: FailoverConfig) -> Self {
167 let (event_tx, event_rx) = mpsc::channel(100);
168
169 Self {
170 config,
171 state: Arc::new(RwLock::new(FailoverState::Normal)),
172 current_primary: Arc::new(RwLock::new(None)),
173 candidates: Arc::new(RwLock::new(HashMap::new())),
174 event_tx,
175 event_rx: Some(event_rx),
176 failover_count: AtomicU64::new(0),
177 history: Arc::new(RwLock::new(Vec::new())),
178 backend_template: None,
179 }
180 }
181
182 pub fn with_backend_template(mut self, template: BackendConfig) -> Self {
185 self.backend_template = Some(template);
186 self
187 }
188
189 fn backend_config_for(&self, endpoint: &NodeEndpoint) -> Option<BackendConfig> {
192 self.backend_template.as_ref().map(|t| {
193 let mut c = t.clone();
194 c.host = endpoint.host.clone();
195 c.port = endpoint.port;
196 c
197 })
198 }
199
200 pub async fn set_primary(&self, node_id: NodeId) {
202 *self.current_primary.write().await = Some(node_id);
203 tracing::info!("Primary set to {:?}", node_id);
204 }
205
206 pub async fn get_primary(&self) -> Option<NodeId> {
208 *self.current_primary.read().await
209 }
210
211 pub async fn register_candidate(&self, candidate: FailoverCandidate) {
213 let node_id = candidate.node_id;
214 self.candidates.write().await.insert(node_id, candidate);
215 tracing::debug!("Registered failover candidate {:?}", node_id);
216 }
217
218 pub async fn remove_candidate(&self, node_id: &NodeId) {
220 self.candidates.write().await.remove(node_id);
221 }
222
223 pub async fn update_candidate_lag(&self, node_id: &NodeId, lag_bytes: u64) {
225 if let Some(candidate) = self.candidates.write().await.get_mut(node_id) {
226 candidate.lag_bytes = lag_bytes;
227 candidate.last_heartbeat = Some(chrono::Utc::now());
228 }
229 }
230
231 pub async fn state(&self) -> FailoverState {
233 *self.state.read().await
234 }
235
236 pub async fn on_primary_failed(&self, node_id: NodeId) -> Result<()> {
238 let current_primary = self.current_primary.read().await;
239 if *current_primary != Some(node_id) {
240 return Ok(()); }
242 drop(current_primary);
243
244 *self.state.write().await = FailoverState::PrimaryFailed;
245
246 let _ = self
247 .event_tx
248 .send(FailoverEvent::PrimaryFailed { node_id })
249 .await;
250
251 tracing::warn!("Primary node {:?} failed", node_id);
252
253 if self.config.auto_failover {
254 self.initiate_failover().await?;
255 }
256
257 Ok(())
258 }
259
260 pub async fn initiate_failover(&self) -> Result<()> {
262 let old_primary =
263 self.current_primary.read().await.ok_or_else(|| {
264 ProxyError::FailoverFailed("No primary to failover from".to_string())
265 })?;
266
267 let candidate = self.select_best_candidate().await?;
269 let new_primary = candidate.node_id;
270
271 *self.state.write().await = FailoverState::InProgress;
272
273 let _ = self
274 .event_tx
275 .send(FailoverEvent::FailoverStarted {
276 from: old_primary,
277 to: new_primary,
278 })
279 .await;
280
281 let start = chrono::Utc::now();
282
283 let history_entry = FailoverHistoryEntry {
285 id: uuid::Uuid::new_v4(),
286 started_at: start,
287 ended_at: None,
288 old_primary,
289 new_primary: Some(new_primary),
290 success: false,
291 error: None,
292 };
293 self.history.write().await.push(history_entry);
294
295 if candidate.lag_bytes > self.config.max_lag_bytes {
297 *self.state.write().await = FailoverState::WaitingForSync;
298
299 let _ = self
300 .event_tx
301 .send(FailoverEvent::WaitingForSync {
302 standby: new_primary,
303 lag_bytes: candidate.lag_bytes,
304 })
305 .await;
306
307 let sync_result = self.wait_for_sync(new_primary).await;
309 if let Err(e) = sync_result {
310 self.fail_failover(&e.to_string()).await;
311 return Err(e);
312 }
313 }
314
315 self.promote_standby(new_primary).await?;
317
318 *self.current_primary.write().await = Some(new_primary);
320 *self.state.write().await = FailoverState::Completed;
321 self.failover_count.fetch_add(1, Ordering::SeqCst);
322
323 let duration = chrono::Utc::now()
324 .signed_duration_since(start)
325 .num_milliseconds() as u64;
326
327 if let Some(entry) = self.history.write().await.last_mut() {
329 entry.ended_at = Some(chrono::Utc::now());
330 entry.success = true;
331 }
332
333 let _ = self
334 .event_tx
335 .send(FailoverEvent::StandbyPromoted { new_primary })
336 .await;
337
338 let _ = self
339 .event_tx
340 .send(FailoverEvent::FailoverCompleted {
341 duration_ms: duration,
342 })
343 .await;
344
345 tracing::info!(
346 "Failover completed: {:?} -> {:?} in {}ms",
347 old_primary,
348 new_primary,
349 duration
350 );
351
352 tokio::spawn({
354 let state = self.state.clone();
355 async move {
356 tokio::time::sleep(Duration::from_secs(1)).await;
357 *state.write().await = FailoverState::Normal;
358 }
359 });
360
361 Ok(())
362 }
363
364 async fn select_best_candidate(&self) -> Result<FailoverCandidate> {
366 let candidates = self.candidates.read().await;
367
368 if candidates.is_empty() {
369 return Err(ProxyError::FailoverFailed(
370 "No failover candidates available".to_string(),
371 ));
372 }
373
374 let mut sorted: Vec<_> = candidates.values().cloned().collect();
376 sorted.sort_by(|a, b| {
377 if self.config.prefer_sync_standby && a.is_sync != b.is_sync {
379 return b.is_sync.cmp(&a.is_sync);
380 }
381 if a.lag_bytes != b.lag_bytes {
383 return a.lag_bytes.cmp(&b.lag_bytes);
384 }
385 a.priority.cmp(&b.priority)
387 });
388
389 sorted
390 .first()
391 .cloned()
392 .ok_or_else(|| ProxyError::FailoverFailed("No eligible candidates".to_string()))
393 }
394
395 async fn wait_for_sync(&self, standby: NodeId) -> Result<()> {
406 let endpoint = self
407 .candidates
408 .read()
409 .await
410 .get(&standby)
411 .map(|c| c.endpoint.clone());
412 let cfg = match endpoint.as_ref().and_then(|e| self.backend_config_for(e)) {
413 Some(c) => c,
414 None => {
415 tokio::time::sleep(Duration::from_millis(50)).await;
418 return Ok(());
419 }
420 };
421
422 let overall = self.config.failover_timeout;
423 tokio::time::timeout(overall, Self::poll_until_caught_up(cfg))
424 .await
425 .map_err(|_| ProxyError::Timeout("standby sync timeout".to_string()))??;
426 Ok(())
427 }
428
429 async fn poll_until_caught_up(cfg: BackendConfig) -> Result<()> {
432 let mut client = BackendClient::connect(&cfg)
433 .await
434 .map_err(|e| ProxyError::Failover(format!("connect to candidate: {}", e)))?;
435
436 let mut last: Option<String> = None;
437 let mut stable_polls = 0u32;
438 loop {
439 let value = client
440 .query_scalar("SELECT pg_last_wal_replay_lsn()::text")
441 .await
442 .map_err(|e| ProxyError::Failover(format!("wal lsn probe: {}", e)))?;
443 let lsn = value
444 .into_string()
445 .ok_or_else(|| ProxyError::Failover("null WAL replay LSN".into()))?;
446
447 if last.as_ref() == Some(&lsn) {
448 stable_polls += 1;
449 if stable_polls >= 2 {
450 tracing::info!(lsn = %lsn, "standby caught up");
451 client.close().await;
452 return Ok(());
453 }
454 } else {
455 stable_polls = 0;
456 last = Some(lsn);
457 }
458 tokio::time::sleep(Duration::from_millis(200)).await;
459 }
460 }
461
462 async fn promote_standby(&self, standby: NodeId) -> Result<()> {
472 let endpoint = self
473 .candidates
474 .read()
475 .await
476 .get(&standby)
477 .map(|c| c.endpoint.clone());
478 let cfg = match endpoint.as_ref().and_then(|e| self.backend_config_for(e)) {
479 Some(c) => c,
480 None => {
481 tracing::info!(
482 node = ?standby,
483 "promote_standby: skeleton path (no backend template) — no-op"
484 );
485 return Ok(());
486 }
487 };
488
489 let wait_secs = self.config.failover_timeout.as_secs().clamp(10, 300);
490 let mut client = BackendClient::connect(&cfg)
491 .await
492 .map_err(|e| ProxyError::FailoverFailed(format!("connect to promote: {}", e)))?;
493
494 let sql = format!("SELECT pg_promote(true, {})", wait_secs);
495 let value = client
496 .query_scalar(&sql)
497 .await
498 .map_err(|e| ProxyError::FailoverFailed(format!("pg_promote: {}", e)))?;
499 let promoted = value
500 .as_bool("pg_promote")
501 .map_err(|e| ProxyError::FailoverFailed(format!("pg_promote result: {}", e)))?
502 .unwrap_or(false);
503 client.close().await;
504
505 if !promoted {
506 return Err(ProxyError::FailoverFailed(
507 "pg_promote returned false".to_string(),
508 ));
509 }
510
511 let mut verify = BackendClient::connect(&cfg)
513 .await
514 .map_err(|e| ProxyError::FailoverFailed(format!("connect to verify: {}", e)))?;
515 let in_recovery = verify
516 .query_scalar("SELECT pg_is_in_recovery()")
517 .await
518 .map_err(|e| ProxyError::FailoverFailed(format!("verify probe: {}", e)))?;
519 verify.close().await;
520 let still_standby = in_recovery
521 .as_bool("pg_is_in_recovery")
522 .map_err(|e| ProxyError::FailoverFailed(format!("verify bool: {}", e)))?
523 .unwrap_or(true);
524 if still_standby {
525 return Err(ProxyError::FailoverFailed(
526 "post-promote pg_is_in_recovery still true".to_string(),
527 ));
528 }
529
530 tracing::info!(node = ?standby, "standby promoted to primary");
531 Ok(())
532 }
533
534 async fn fail_failover(&self, reason: &str) {
536 *self.state.write().await = FailoverState::Failed;
537
538 if let Some(entry) = self.history.write().await.last_mut() {
539 entry.ended_at = Some(chrono::Utc::now());
540 entry.success = false;
541 entry.error = Some(reason.to_string());
542 }
543
544 let _ = self
545 .event_tx
546 .send(FailoverEvent::FailoverFailed {
547 reason: reason.to_string(),
548 })
549 .await;
550
551 tracing::error!("Failover failed: {}", reason);
552 }
553
554 pub async fn on_old_primary_recovered(&self, node_id: NodeId) {
570 let _ = self
571 .event_tx
572 .send(FailoverEvent::OldPrimaryRecovered { node_id })
573 .await;
574 tracing::warn!(
575 "old primary {:?} recovered — must be demoted out-of-band to prevent split-brain",
576 node_id
577 );
578
579 let endpoint = self
583 .candidates
584 .read()
585 .await
586 .get(&node_id)
587 .map(|c| c.endpoint.clone());
588 let cfg = match endpoint.as_ref().and_then(|e| self.backend_config_for(e)) {
589 Some(c) => c,
590 None => return, };
592
593 match BackendClient::connect(&cfg).await {
594 Ok(mut client) => {
595 let in_recovery_result = client.query_scalar("SELECT pg_is_in_recovery()").await;
596 client.close().await;
597 if let Ok(tv) = in_recovery_result {
598 if let Ok(Some(false)) = tv.as_bool("pg_is_in_recovery") {
599 tracing::error!(
600 "split-brain hazard: node {:?} recovered and still reports primary (pg_is_in_recovery=false). Shut it down or use pg_rewind before reintroducing.",
601 node_id
602 );
603 }
604 }
605 }
606 Err(e) => {
607 tracing::debug!(
608 error = %e,
609 "could not connect to recovered node for split-brain probe"
610 );
611 }
612 }
613 }
614
615 pub async fn manual_failover(&self, target: NodeId) -> Result<()> {
617 let candidates = self.candidates.read().await;
619 if !candidates.contains_key(&target) {
620 return Err(ProxyError::FailoverFailed(format!(
621 "Node {:?} is not a valid failover candidate",
622 target
623 )));
624 }
625 drop(candidates);
626
627 *self.state.write().await = FailoverState::InProgress;
629
630 let old_primary = self.current_primary.read().await.unwrap_or(NodeId::new());
631
632 let _ = self
633 .event_tx
634 .send(FailoverEvent::FailoverStarted {
635 from: old_primary,
636 to: target,
637 })
638 .await;
639
640 self.promote_standby(target).await?;
641
642 *self.current_primary.write().await = Some(target);
643 *self.state.write().await = FailoverState::Completed;
644 self.failover_count.fetch_add(1, Ordering::SeqCst);
645
646 Ok(())
647 }
648
649 pub fn failover_count(&self) -> u64 {
651 self.failover_count.load(Ordering::SeqCst)
652 }
653
654 pub async fn history(&self) -> Vec<FailoverHistoryEntry> {
656 self.history.read().await.clone()
657 }
658
659 pub fn take_event_receiver(&mut self) -> Option<mpsc::Receiver<FailoverEvent>> {
661 self.event_rx.take()
662 }
663
/// Replays the transactions that were in flight on `failed_node` against
/// the new primary.
///
/// The journal is queried for the failed node's transactions. If there are
/// none an empty result is returned right away. Otherwise the call waits
/// for the new primary to catch up (keyed on the largest `start_lsn` among
/// those transactions) and then replays each transaction in turn. A
/// transaction that cannot be started or executed is recorded as a failed
/// `ReplayResult`; it does not abort the remaining replays.
///
/// # Errors
/// Only the catch-up wait can fail the call as a whole.
#[cfg(feature = "ha-tr")]
pub async fn coordinate_failover_replay(
    &self,
    journal: &TransactionJournal,
    failed_node: NodeId,
    new_primary_endpoint: &NodeEndpoint,
) -> Result<CoordinatedReplayResult> {
    let started = std::time::Instant::now();
    let target = new_primary_endpoint.id;

    tracing::info!(
        "Starting coordinated replay: failed_node={:?}, new_primary={:?}",
        failed_node,
        target
    );

    let pending = journal.get_transactions_for_node(failed_node).await;

    if pending.is_empty() {
        tracing::info!("No active transactions to replay");
        return Ok(CoordinatedReplayResult {
            total_transactions: 0,
            successful_replays: 0,
            failed_replays: 0,
            transaction_results: vec![],
            duration_ms: started.elapsed().as_millis() as u64,
            new_primary: target,
        });
    }

    tracing::info!("Found {} active transactions to replay", pending.len());

    let max_lsn = pending
        .iter()
        .map(|tx| tx.start_lsn)
        .fold(0, |highest, lsn| highest.max(lsn));

    self.wait_for_lsn_catchup(target, max_lsn).await?;

    let replay_manager = FailoverReplay::new(ReplayConfig {
        verify_results: true,
        statement_timeout_ms: 30000,
        retry_on_error: true,
        max_retries: 3,
        skip_read_only: false,
        wait_for_wal_sync: false,
        max_wal_lag_bytes: 0,
    });

    let mut transaction_results = Vec::new();
    let mut successful_replays = 0;
    let mut failed_replays = 0;

    for tx_journal in pending {
        let tx_id = tx_journal.tx_id;

        tracing::debug!(
            "Replaying transaction {:?} with {} entries",
            tx_id,
            tx_journal.entries.len()
        );

        // Collapse "could not start" and "could not execute" into one
        // error branch that carries the message for the synthetic result.
        let outcome = match replay_manager.start_replay(tx_journal, target).await {
            Ok(_) => match replay_manager.execute_replay(tx_id).await {
                Ok(result) => Ok(result),
                Err(e) => {
                    tracing::error!("Failed to execute replay for {:?}: {}", tx_id, e);
                    Err(e.to_string())
                }
            },
            Err(e) => {
                tracing::error!("Failed to start replay for {:?}: {}", tx_id, e);
                Err(e.to_string())
            }
        };

        let result = match outcome {
            Ok(result) => {
                if result.success {
                    successful_replays += 1;
                    tracing::debug!("Transaction {:?} replayed successfully", tx_id);
                } else {
                    failed_replays += 1;
                    tracing::warn!(
                        "Transaction {:?} replay failed: {:?}",
                        tx_id,
                        result.error
                    );
                }
                result
            }
            Err(message) => {
                failed_replays += 1;
                ReplayResult {
                    tx_id,
                    success: false,
                    statements_replayed: 0,
                    statements_skipped: 0,
                    statements_failed: 0,
                    verification_failures: 0,
                    duration_ms: 0,
                    error: Some(message),
                    statement_results: vec![],
                }
            }
        };
        transaction_results.push(result);
    }

    let duration_ms = started.elapsed().as_millis() as u64;

    tracing::info!(
        "Coordinated replay completed: {}/{} successful in {}ms",
        successful_replays,
        successful_replays + failed_replays,
        duration_ms
    );

    Ok(CoordinatedReplayResult {
        total_transactions: successful_replays + failed_replays,
        successful_replays,
        failed_replays,
        transaction_results,
        duration_ms,
        new_primary: target,
    })
}
811
/// Waits until `node` reports zero replication lag, checking every 100 ms
/// for at most `failover_timeout`.
///
/// A `target_lsn` of 0 short-circuits to success. The value is otherwise
/// used only in log and error messages: readiness is judged solely by the
/// `lag_bytes` stored for the node in the candidate registry.
/// NOTE(review): a node that is not registered can never satisfy the check
/// and will run into the timeout — confirm that is intended.
///
/// # Errors
/// `ProxyError::Timeout` when the deadline passes first.
#[cfg(feature = "ha-tr")]
async fn wait_for_lsn_catchup(&self, node: NodeId, target_lsn: u64) -> Result<()> {
    if target_lsn == 0 {
        return Ok(());
    }

    tracing::debug!(
        "Waiting for node {:?} to catch up to LSN {}",
        node,
        target_lsn
    );

    let deadline = self.config.failover_timeout;
    let started = std::time::Instant::now();

    while started.elapsed() < deadline {
        // The read guard is a temporary and is released before sleeping.
        let caught_up = self
            .candidates
            .read()
            .await
            .get(&node)
            .map_or(false, |c| c.lag_bytes == 0);

        if caught_up {
            tracing::debug!("Node {:?} has caught up", node);
            return Ok(());
        }

        tokio::time::sleep(Duration::from_millis(100)).await;
    }

    Err(ProxyError::Timeout(format!(
        "Timeout waiting for node {:?} to catch up to LSN {}",
        node, target_lsn
    )))
}
852}
853
/// Aggregate outcome of a coordinated transaction replay.
#[cfg(feature = "ha-tr")]
#[derive(Debug, Clone)]
pub struct CoordinatedReplayResult {
    /// Number of transactions that were attempted.
    pub total_transactions: usize,
    /// How many of them replayed successfully.
    pub successful_replays: usize,
    /// How many of them failed.
    pub failed_replays: usize,
    /// Per-transaction details, in replay order.
    pub transaction_results: Vec<ReplayResult>,
    /// Wall-clock duration of the whole operation in milliseconds.
    pub duration_ms: u64,
    /// Node the transactions were replayed against.
    pub new_primary: NodeId,
}
871
#[cfg(feature = "ha-tr")]
impl CoordinatedReplayResult {
    /// `true` when no replay failed (including when nothing was replayed).
    pub fn all_successful(&self) -> bool {
        self.failed_replays == 0
    }

    /// Percentage of successful replays in the range 0.0..=100.0.
    /// An empty run counts as 100 %.
    pub fn success_rate(&self) -> f64 {
        match self.total_transactions {
            0 => 100.0,
            total => (self.successful_replays as f64 / total as f64) * 100.0,
        }
    }
}
888
889#[cfg(test)]
890mod tests {
891 use super::*;
892
/// The default configuration enables automatic, sync-preferring failover
/// with three retries.
#[test]
fn test_config_default() {
    let defaults = FailoverConfig::default();

    assert!(defaults.auto_failover);
    assert!(defaults.prefer_sync_standby);
    assert_eq!(defaults.max_retries, 3);
}
900
901 #[tokio::test]
902 async fn test_set_get_primary() {
903 let controller = FailoverController::new(FailoverConfig::default());
904 let node_id = NodeId::new();
905
906 controller.set_primary(node_id).await;
907 assert_eq!(controller.get_primary().await, Some(node_id));
908 }
909
910 #[tokio::test]
911 async fn test_register_candidate() {
912 let controller = FailoverController::new(FailoverConfig::default());
913 let node_id = NodeId::new();
914
915 let candidate = FailoverCandidate {
916 node_id,
917 endpoint: NodeEndpoint::new("localhost", 5432).with_role(NodeRole::Standby),
918 is_sync: true,
919 lag_bytes: 0,
920 priority: 1,
921 last_heartbeat: None,
922 };
923
924 controller.register_candidate(candidate).await;
925
926 let candidates = controller.candidates.read().await;
927 assert!(candidates.contains_key(&node_id));
928 }
929
930 #[tokio::test]
931 async fn test_state_transitions() {
932 let controller = FailoverController::new(FailoverConfig::default());
933
934 assert_eq!(controller.state().await, FailoverState::Normal);
935
936 *controller.state.write().await = FailoverState::PrimaryFailed;
937 assert_eq!(controller.state().await, FailoverState::PrimaryFailed);
938 }
939
940 #[tokio::test]
941 async fn test_select_best_candidate() {
942 let controller = FailoverController::new(FailoverConfig::default());
943
944 let sync_node = NodeId::new();
945 let async_node = NodeId::new();
946
947 controller
948 .register_candidate(FailoverCandidate {
949 node_id: async_node,
950 endpoint: NodeEndpoint::new("async", 5432),
951 is_sync: false,
952 lag_bytes: 100,
953 priority: 1,
954 last_heartbeat: None,
955 })
956 .await;
957
958 controller
959 .register_candidate(FailoverCandidate {
960 node_id: sync_node,
961 endpoint: NodeEndpoint::new("sync", 5432),
962 is_sync: true,
963 lag_bytes: 50,
964 priority: 2,
965 last_heartbeat: None,
966 })
967 .await;
968
969 let best = controller.select_best_candidate().await.unwrap();
970 assert_eq!(best.node_id, sync_node);
972 }
973
/// Replaying against an empty journal yields an empty, fully successful
/// result.
#[cfg(feature = "ha-tr")]
#[tokio::test]
async fn test_coordinate_failover_replay_empty() {
    use super::super::transaction_journal::TransactionJournal;

    let controller = FailoverController::new(FailoverConfig::default());
    let journal = TransactionJournal::new();
    let failed_node = NodeId::new();
    let new_primary = NodeEndpoint::new("new-primary", 5432).with_role(NodeRole::Primary);

    let outcome = controller
        .coordinate_failover_replay(&journal, failed_node, &new_primary)
        .await
        .unwrap();

    assert_eq!(outcome.total_transactions, 0);
    assert_eq!(outcome.successful_replays, 0);
    assert_eq!(outcome.failed_replays, 0);
    assert!(outcome.all_successful());
    assert_eq!(outcome.success_rate(), 100.0);
}
996
/// One journaled transaction on the failed node is replayed against a
/// caught-up new primary and reported as successful.
#[cfg(feature = "ha-tr")]
#[tokio::test]
async fn test_coordinate_failover_replay_with_transactions() {
    use super::super::transaction_journal::{JournalValue, TransactionJournal};
    use uuid::Uuid;

    let controller = FailoverController::new(FailoverConfig::default());
    let journal = TransactionJournal::new();
    let failed_node = NodeId::new();
    let new_primary = NodeEndpoint::new("new-primary", 5432).with_role(NodeRole::Primary);

    // Zero lag lets the LSN catch-up wait return immediately.
    controller
        .register_candidate(FailoverCandidate {
            node_id: new_primary.id,
            endpoint: new_primary.clone(),
            is_sync: true,
            lag_bytes: 0,
            priority: 1,
            last_heartbeat: None,
        })
        .await;

    let tx_id = Uuid::new_v4();
    let session_id = Uuid::new_v4();
    journal
        .begin_transaction(tx_id, session_id, failed_node, 100)
        .await
        .unwrap();
    journal
        .log_statement(
            tx_id,
            "INSERT INTO users (name) VALUES ('test')".to_string(),
            vec![JournalValue::Text("test".to_string())],
            Some(12345),
            Some(1),
            10,
        )
        .await
        .unwrap();

    let result = controller
        .coordinate_failover_replay(&journal, failed_node, &new_primary)
        .await
        .unwrap();

    assert_eq!(result.total_transactions, 1);
    assert_eq!(result.successful_replays, 1);
    assert_eq!(result.failed_replays, 0);
    assert!(result.all_successful());
}
1053
/// `all_successful` and `success_rate` reflect the stored counters.
#[cfg(feature = "ha-tr")]
#[test]
fn test_coordinated_replay_result_methods() {
    let partial = CoordinatedReplayResult {
        total_transactions: 10,
        successful_replays: 8,
        failed_replays: 2,
        transaction_results: vec![],
        duration_ms: 1000,
        new_primary: NodeId::new(),
    };
    assert!(!partial.all_successful());
    assert_eq!(partial.success_rate(), 80.0);

    let flawless = CoordinatedReplayResult {
        total_transactions: 5,
        successful_replays: 5,
        failed_replays: 0,
        transaction_results: vec![],
        duration_ms: 500,
        new_primary: NodeId::new(),
    };
    assert!(flawless.all_successful());
    assert_eq!(flawless.success_rate(), 100.0);
}
1081}