1use super::{NodeEndpoint, NodeId, ProxyError, Result};
7#[cfg(test)]
8use super::NodeRole;
9use crate::backend::{BackendClient, BackendConfig};
10use std::collections::HashMap;
11use std::sync::atomic::{AtomicU64, Ordering};
12use std::sync::Arc;
13use std::time::Duration;
14use tokio::sync::{mpsc, RwLock};
15
16#[cfg(feature = "ha-tr")]
18use super::failover_replay::{FailoverReplay, ReplayConfig, ReplayResult};
19#[cfg(feature = "ha-tr")]
20use super::transaction_journal::TransactionJournal;
21
/// Tunables for [`FailoverController`].
#[derive(Debug, Clone)]
pub struct FailoverConfig {
    /// Failure-detection window.
    /// NOTE(review): not read by `FailoverController` itself; presumably consumed by a health monitor.
    pub detection_time: Duration,
    /// Upper bound for standby catch-up waits; also feeds the `pg_promote` wait (clamped).
    pub failover_timeout: Duration,
    /// When `true`, a reported primary failure immediately starts a failover.
    pub auto_failover: bool,
    /// When `true`, synchronous standbys are ranked ahead of asynchronous ones.
    pub prefer_sync_standby: bool,
    /// Replication lag (bytes) above which failover first waits for the standby to settle.
    pub max_lag_bytes: u64,
    /// Retry toggle. NOTE(review): not read in this file — confirm where it is consumed.
    pub retry_failed: bool,
    /// Retry budget. NOTE(review): not read in this file — confirm where it is consumed.
    pub max_retries: u32,
}

impl Default for FailoverConfig {
    /// 10 s detection, 60 s timeout, automatic failover preferring sync standbys,
    /// a 16 MiB lag threshold and up to 3 retries.
    fn default() -> Self {
        const MIB: u64 = 1024 * 1024;

        FailoverConfig {
            detection_time: Duration::from_secs(10),
            failover_timeout: Duration::from_secs(60),
            auto_failover: true,
            prefer_sync_standby: true,
            max_lag_bytes: 16 * MIB,
            retry_failed: true,
            max_retries: 3,
        }
    }
}
54
/// Operator-facing failover policy.
///
/// NOTE(review): not referenced by the controller in this file; `FailoverConfig::auto_failover`
/// is what actually gates automatic failover here.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FailoverMode {
    /// Failover is triggered without operator involvement.
    Automatic,
    /// Failover happens only on explicit request.
    Manual,
    /// Failover is turned off.
    Disabled,
}
65
/// Where the controller currently is in the failover lifecycle.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FailoverState {
    /// Steady state; no failover running.
    Normal,
    /// The current primary was reported as failed.
    PrimaryFailed,
    /// A standby has been chosen and failover is running.
    InProgress,
    /// Waiting for the chosen standby to catch up before promotion.
    WaitingForSync,
    /// The standby was promoted successfully.
    Completed,
    /// The failover attempt was abandoned.
    Failed,
}
82
83#[derive(Debug, Clone)]
85pub enum FailoverEvent {
86 PrimaryFailed { node_id: NodeId },
88 FailoverStarted { from: NodeId, to: NodeId },
90 WaitingForSync { standby: NodeId, lag_bytes: u64 },
92 StandbyPromoted { new_primary: NodeId },
94 FailoverCompleted { duration_ms: u64 },
96 FailoverFailed { reason: String },
98 OldPrimaryRecovered { node_id: NodeId },
100}
101
102#[derive(Debug, Clone)]
104pub struct FailoverCandidate {
105 pub node_id: NodeId,
107 pub endpoint: NodeEndpoint,
109 pub is_sync: bool,
111 pub lag_bytes: u64,
113 pub priority: u32,
115 pub last_heartbeat: Option<chrono::DateTime<chrono::Utc>>,
117}
118
119#[derive(Debug, Clone)]
121pub struct FailoverHistoryEntry {
122 pub id: uuid::Uuid,
124 pub started_at: chrono::DateTime<chrono::Utc>,
126 pub ended_at: Option<chrono::DateTime<chrono::Utc>>,
128 pub old_primary: NodeId,
130 pub new_primary: Option<NodeId>,
132 pub success: bool,
134 pub error: Option<String>,
136}
137
138pub struct FailoverController {
140 config: FailoverConfig,
142 state: Arc<RwLock<FailoverState>>,
144 current_primary: Arc<RwLock<Option<NodeId>>>,
146 candidates: Arc<RwLock<HashMap<NodeId, FailoverCandidate>>>,
148 event_tx: mpsc::Sender<FailoverEvent>,
150 event_rx: Option<mpsc::Receiver<FailoverEvent>>,
152 failover_count: AtomicU64,
154 history: Arc<RwLock<Vec<FailoverHistoryEntry>>>,
156 backend_template: Option<BackendConfig>,
162}
163
164impl FailoverController {
165 pub fn new(config: FailoverConfig) -> Self {
167 let (event_tx, event_rx) = mpsc::channel(100);
168
169 Self {
170 config,
171 state: Arc::new(RwLock::new(FailoverState::Normal)),
172 current_primary: Arc::new(RwLock::new(None)),
173 candidates: Arc::new(RwLock::new(HashMap::new())),
174 event_tx,
175 event_rx: Some(event_rx),
176 failover_count: AtomicU64::new(0),
177 history: Arc::new(RwLock::new(Vec::new())),
178 backend_template: None,
179 }
180 }
181
182 pub fn with_backend_template(mut self, template: BackendConfig) -> Self {
185 self.backend_template = Some(template);
186 self
187 }
188
189 fn backend_config_for(&self, endpoint: &NodeEndpoint) -> Option<BackendConfig> {
192 self.backend_template.as_ref().map(|t| {
193 let mut c = t.clone();
194 c.host = endpoint.host.clone();
195 c.port = endpoint.port;
196 c
197 })
198 }
199
200 pub async fn set_primary(&self, node_id: NodeId) {
202 *self.current_primary.write().await = Some(node_id);
203 tracing::info!("Primary set to {:?}", node_id);
204 }
205
206 pub async fn get_primary(&self) -> Option<NodeId> {
208 *self.current_primary.read().await
209 }
210
211 pub async fn register_candidate(&self, candidate: FailoverCandidate) {
213 let node_id = candidate.node_id;
214 self.candidates.write().await.insert(node_id, candidate);
215 tracing::debug!("Registered failover candidate {:?}", node_id);
216 }
217
218 pub async fn remove_candidate(&self, node_id: &NodeId) {
220 self.candidates.write().await.remove(node_id);
221 }
222
223 pub async fn update_candidate_lag(&self, node_id: &NodeId, lag_bytes: u64) {
225 if let Some(candidate) = self.candidates.write().await.get_mut(node_id) {
226 candidate.lag_bytes = lag_bytes;
227 candidate.last_heartbeat = Some(chrono::Utc::now());
228 }
229 }
230
231 pub async fn state(&self) -> FailoverState {
233 *self.state.read().await
234 }
235
236 pub async fn on_primary_failed(&self, node_id: NodeId) -> Result<()> {
238 let current_primary = self.current_primary.read().await;
239 if *current_primary != Some(node_id) {
240 return Ok(()); }
242 drop(current_primary);
243
244 *self.state.write().await = FailoverState::PrimaryFailed;
245
246 let _ = self
247 .event_tx
248 .send(FailoverEvent::PrimaryFailed { node_id })
249 .await;
250
251 tracing::warn!("Primary node {:?} failed", node_id);
252
253 if self.config.auto_failover {
254 self.initiate_failover().await?;
255 }
256
257 Ok(())
258 }
259
260 pub async fn initiate_failover(&self) -> Result<()> {
262 let old_primary = self
263 .current_primary
264 .read()
265 .await
266 .ok_or_else(|| ProxyError::FailoverFailed("No primary to failover from".to_string()))?;
267
268 let candidate = self.select_best_candidate().await?;
270 let new_primary = candidate.node_id;
271
272 *self.state.write().await = FailoverState::InProgress;
273
274 let _ = self
275 .event_tx
276 .send(FailoverEvent::FailoverStarted {
277 from: old_primary,
278 to: new_primary,
279 })
280 .await;
281
282 let start = chrono::Utc::now();
283
284 let history_entry = FailoverHistoryEntry {
286 id: uuid::Uuid::new_v4(),
287 started_at: start,
288 ended_at: None,
289 old_primary,
290 new_primary: Some(new_primary),
291 success: false,
292 error: None,
293 };
294 self.history.write().await.push(history_entry);
295
296 if candidate.lag_bytes > self.config.max_lag_bytes {
298 *self.state.write().await = FailoverState::WaitingForSync;
299
300 let _ = self
301 .event_tx
302 .send(FailoverEvent::WaitingForSync {
303 standby: new_primary,
304 lag_bytes: candidate.lag_bytes,
305 })
306 .await;
307
308 let sync_result = self.wait_for_sync(new_primary).await;
310 if let Err(e) = sync_result {
311 self.fail_failover(&e.to_string()).await;
312 return Err(e);
313 }
314 }
315
316 self.promote_standby(new_primary).await?;
318
319 *self.current_primary.write().await = Some(new_primary);
321 *self.state.write().await = FailoverState::Completed;
322 self.failover_count.fetch_add(1, Ordering::SeqCst);
323
324 let duration = chrono::Utc::now()
325 .signed_duration_since(start)
326 .num_milliseconds() as u64;
327
328 if let Some(entry) = self.history.write().await.last_mut() {
330 entry.ended_at = Some(chrono::Utc::now());
331 entry.success = true;
332 }
333
334 let _ = self
335 .event_tx
336 .send(FailoverEvent::StandbyPromoted {
337 new_primary,
338 })
339 .await;
340
341 let _ = self
342 .event_tx
343 .send(FailoverEvent::FailoverCompleted { duration_ms: duration })
344 .await;
345
346 tracing::info!(
347 "Failover completed: {:?} -> {:?} in {}ms",
348 old_primary,
349 new_primary,
350 duration
351 );
352
353 tokio::spawn({
355 let state = self.state.clone();
356 async move {
357 tokio::time::sleep(Duration::from_secs(1)).await;
358 *state.write().await = FailoverState::Normal;
359 }
360 });
361
362 Ok(())
363 }
364
365 async fn select_best_candidate(&self) -> Result<FailoverCandidate> {
367 let candidates = self.candidates.read().await;
368
369 if candidates.is_empty() {
370 return Err(ProxyError::FailoverFailed(
371 "No failover candidates available".to_string(),
372 ));
373 }
374
375 let mut sorted: Vec<_> = candidates.values().cloned().collect();
377 sorted.sort_by(|a, b| {
378 if self.config.prefer_sync_standby {
380 if a.is_sync != b.is_sync {
381 return b.is_sync.cmp(&a.is_sync);
382 }
383 }
384 if a.lag_bytes != b.lag_bytes {
386 return a.lag_bytes.cmp(&b.lag_bytes);
387 }
388 a.priority.cmp(&b.priority)
390 });
391
392 sorted
393 .first()
394 .cloned()
395 .ok_or_else(|| ProxyError::FailoverFailed("No eligible candidates".to_string()))
396 }
397
398 async fn wait_for_sync(&self, standby: NodeId) -> Result<()> {
409 let endpoint = self
410 .candidates
411 .read()
412 .await
413 .get(&standby)
414 .map(|c| c.endpoint.clone());
415 let cfg = match endpoint.as_ref().and_then(|e| self.backend_config_for(e)) {
416 Some(c) => c,
417 None => {
418 tokio::time::sleep(Duration::from_millis(50)).await;
421 return Ok(());
422 }
423 };
424
425 let overall = self.config.failover_timeout;
426 tokio::time::timeout(overall, Self::poll_until_caught_up(cfg))
427 .await
428 .map_err(|_| ProxyError::Timeout("standby sync timeout".to_string()))??;
429 Ok(())
430 }
431
432 async fn poll_until_caught_up(cfg: BackendConfig) -> Result<()> {
435 let mut client = BackendClient::connect(&cfg)
436 .await
437 .map_err(|e| ProxyError::Failover(format!("connect to candidate: {}", e)))?;
438
439 let mut last: Option<String> = None;
440 let mut stable_polls = 0u32;
441 loop {
442 let value = client
443 .query_scalar("SELECT pg_last_wal_replay_lsn()::text")
444 .await
445 .map_err(|e| ProxyError::Failover(format!("wal lsn probe: {}", e)))?;
446 let lsn = value
447 .into_string()
448 .ok_or_else(|| ProxyError::Failover("null WAL replay LSN".into()))?;
449
450 if last.as_ref() == Some(&lsn) {
451 stable_polls += 1;
452 if stable_polls >= 2 {
453 tracing::info!(lsn = %lsn, "standby caught up");
454 client.close().await;
455 return Ok(());
456 }
457 } else {
458 stable_polls = 0;
459 last = Some(lsn);
460 }
461 tokio::time::sleep(Duration::from_millis(200)).await;
462 }
463 }
464
465 async fn promote_standby(&self, standby: NodeId) -> Result<()> {
475 let endpoint = self
476 .candidates
477 .read()
478 .await
479 .get(&standby)
480 .map(|c| c.endpoint.clone());
481 let cfg = match endpoint.as_ref().and_then(|e| self.backend_config_for(e)) {
482 Some(c) => c,
483 None => {
484 tracing::info!(
485 node = ?standby,
486 "promote_standby: skeleton path (no backend template) — no-op"
487 );
488 return Ok(());
489 }
490 };
491
492 let wait_secs = self.config.failover_timeout.as_secs().max(10).min(300);
493 let mut client = BackendClient::connect(&cfg)
494 .await
495 .map_err(|e| ProxyError::FailoverFailed(format!("connect to promote: {}", e)))?;
496
497 let sql = format!("SELECT pg_promote(true, {})", wait_secs);
498 let value = client
499 .query_scalar(&sql)
500 .await
501 .map_err(|e| ProxyError::FailoverFailed(format!("pg_promote: {}", e)))?;
502 let promoted = value
503 .as_bool("pg_promote")
504 .map_err(|e| ProxyError::FailoverFailed(format!("pg_promote result: {}", e)))?
505 .unwrap_or(false);
506 client.close().await;
507
508 if !promoted {
509 return Err(ProxyError::FailoverFailed(
510 "pg_promote returned false".to_string(),
511 ));
512 }
513
514 let mut verify = BackendClient::connect(&cfg)
516 .await
517 .map_err(|e| ProxyError::FailoverFailed(format!("connect to verify: {}", e)))?;
518 let in_recovery = verify
519 .query_scalar("SELECT pg_is_in_recovery()")
520 .await
521 .map_err(|e| ProxyError::FailoverFailed(format!("verify probe: {}", e)))?;
522 verify.close().await;
523 let still_standby = in_recovery
524 .as_bool("pg_is_in_recovery")
525 .map_err(|e| ProxyError::FailoverFailed(format!("verify bool: {}", e)))?
526 .unwrap_or(true);
527 if still_standby {
528 return Err(ProxyError::FailoverFailed(
529 "post-promote pg_is_in_recovery still true".to_string(),
530 ));
531 }
532
533 tracing::info!(node = ?standby, "standby promoted to primary");
534 Ok(())
535 }
536
537 async fn fail_failover(&self, reason: &str) {
539 *self.state.write().await = FailoverState::Failed;
540
541 if let Some(entry) = self.history.write().await.last_mut() {
542 entry.ended_at = Some(chrono::Utc::now());
543 entry.success = false;
544 entry.error = Some(reason.to_string());
545 }
546
547 let _ = self
548 .event_tx
549 .send(FailoverEvent::FailoverFailed {
550 reason: reason.to_string(),
551 })
552 .await;
553
554 tracing::error!("Failover failed: {}", reason);
555 }
556
557 pub async fn on_old_primary_recovered(&self, node_id: NodeId) {
573 let _ = self
574 .event_tx
575 .send(FailoverEvent::OldPrimaryRecovered { node_id })
576 .await;
577 tracing::warn!(
578 "old primary {:?} recovered — must be demoted out-of-band to prevent split-brain",
579 node_id
580 );
581
582 let endpoint = self
586 .candidates
587 .read()
588 .await
589 .get(&node_id)
590 .map(|c| c.endpoint.clone());
591 let cfg = match endpoint.as_ref().and_then(|e| self.backend_config_for(e)) {
592 Some(c) => c,
593 None => return, };
595
596 match BackendClient::connect(&cfg).await {
597 Ok(mut client) => {
598 let in_recovery_result = client
599 .query_scalar("SELECT pg_is_in_recovery()")
600 .await;
601 client.close().await;
602 if let Ok(tv) = in_recovery_result {
603 if let Ok(Some(false)) = tv.as_bool("pg_is_in_recovery") {
604 tracing::error!(
605 "split-brain hazard: node {:?} recovered and still reports primary (pg_is_in_recovery=false). Shut it down or use pg_rewind before reintroducing.",
606 node_id
607 );
608 }
609 }
610 }
611 Err(e) => {
612 tracing::debug!(
613 error = %e,
614 "could not connect to recovered node for split-brain probe"
615 );
616 }
617 }
618 }
619
620 pub async fn manual_failover(&self, target: NodeId) -> Result<()> {
622 let candidates = self.candidates.read().await;
624 if !candidates.contains_key(&target) {
625 return Err(ProxyError::FailoverFailed(format!(
626 "Node {:?} is not a valid failover candidate",
627 target
628 )));
629 }
630 drop(candidates);
631
632 *self.state.write().await = FailoverState::InProgress;
634
635 let old_primary = self.current_primary.read().await.unwrap_or(NodeId::new());
636
637 let _ = self
638 .event_tx
639 .send(FailoverEvent::FailoverStarted {
640 from: old_primary,
641 to: target,
642 })
643 .await;
644
645 self.promote_standby(target).await?;
646
647 *self.current_primary.write().await = Some(target);
648 *self.state.write().await = FailoverState::Completed;
649 self.failover_count.fetch_add(1, Ordering::SeqCst);
650
651 Ok(())
652 }
653
654 pub fn failover_count(&self) -> u64 {
656 self.failover_count.load(Ordering::SeqCst)
657 }
658
659 pub async fn history(&self) -> Vec<FailoverHistoryEntry> {
661 self.history.read().await.clone()
662 }
663
664 pub fn take_event_receiver(&mut self) -> Option<mpsc::Receiver<FailoverEvent>> {
666 self.event_rx.take()
667 }
668
669 #[cfg(feature = "ha-tr")]
678 pub async fn coordinate_failover_replay(
679 &self,
680 journal: &TransactionJournal,
681 failed_node: NodeId,
682 new_primary_endpoint: &NodeEndpoint,
683 ) -> Result<CoordinatedReplayResult> {
684 let start = std::time::Instant::now();
685
686 tracing::info!(
687 "Starting coordinated replay: failed_node={:?}, new_primary={:?}",
688 failed_node,
689 new_primary_endpoint.id
690 );
691
692 let affected_txs = journal.get_transactions_for_node(failed_node).await;
694
695 if affected_txs.is_empty() {
696 tracing::info!("No active transactions to replay");
697 return Ok(CoordinatedReplayResult {
698 total_transactions: 0,
699 successful_replays: 0,
700 failed_replays: 0,
701 transaction_results: vec![],
702 duration_ms: start.elapsed().as_millis() as u64,
703 new_primary: new_primary_endpoint.id,
704 });
705 }
706
707 tracing::info!("Found {} active transactions to replay", affected_txs.len());
708
709 let max_lsn = affected_txs.iter().map(|tx| tx.start_lsn).max().unwrap_or(0);
711
712 self.wait_for_lsn_catchup(new_primary_endpoint.id, max_lsn).await?;
714
715 let replay_manager = FailoverReplay::new(ReplayConfig {
717 verify_results: true,
718 statement_timeout_ms: 30000,
719 retry_on_error: true,
720 max_retries: 3,
721 skip_read_only: false,
722 wait_for_wal_sync: false, max_wal_lag_bytes: 0,
724 });
725
726 let mut transaction_results = Vec::new();
727 let mut successful_replays = 0;
728 let mut failed_replays = 0;
729
730 for tx_journal in affected_txs {
731 let tx_id = tx_journal.tx_id;
732
733 tracing::debug!("Replaying transaction {:?} with {} entries", tx_id, tx_journal.entries.len());
734
735 match replay_manager.start_replay(tx_journal, new_primary_endpoint.id).await {
737 Ok(_) => {
738 match replay_manager.execute_replay(tx_id).await {
739 Ok(result) => {
740 if result.success {
741 successful_replays += 1;
742 tracing::debug!("Transaction {:?} replayed successfully", tx_id);
743 } else {
744 failed_replays += 1;
745 tracing::warn!(
746 "Transaction {:?} replay failed: {:?}",
747 tx_id,
748 result.error
749 );
750 }
751 transaction_results.push(result);
752 }
753 Err(e) => {
754 failed_replays += 1;
755 tracing::error!("Failed to execute replay for {:?}: {}", tx_id, e);
756 transaction_results.push(ReplayResult {
757 tx_id,
758 success: false,
759 statements_replayed: 0,
760 statements_skipped: 0,
761 statements_failed: 0,
762 verification_failures: 0,
763 duration_ms: 0,
764 error: Some(e.to_string()),
765 statement_results: vec![],
766 });
767 }
768 }
769 }
770 Err(e) => {
771 failed_replays += 1;
772 tracing::error!("Failed to start replay for {:?}: {}", tx_id, e);
773 transaction_results.push(ReplayResult {
774 tx_id,
775 success: false,
776 statements_replayed: 0,
777 statements_skipped: 0,
778 statements_failed: 0,
779 verification_failures: 0,
780 duration_ms: 0,
781 error: Some(e.to_string()),
782 statement_results: vec![],
783 });
784 }
785 }
786 }
787
788 let duration_ms = start.elapsed().as_millis() as u64;
789
790 tracing::info!(
791 "Coordinated replay completed: {}/{} successful in {}ms",
792 successful_replays,
793 successful_replays + failed_replays,
794 duration_ms
795 );
796
797 Ok(CoordinatedReplayResult {
798 total_transactions: successful_replays + failed_replays,
799 successful_replays,
800 failed_replays,
801 transaction_results,
802 duration_ms,
803 new_primary: new_primary_endpoint.id,
804 })
805 }
806
807 #[cfg(feature = "ha-tr")]
809 async fn wait_for_lsn_catchup(&self, node: NodeId, target_lsn: u64) -> Result<()> {
810 if target_lsn == 0 {
811 return Ok(());
812 }
813
814 tracing::debug!("Waiting for node {:?} to catch up to LSN {}", node, target_lsn);
815
816 let timeout = self.config.failover_timeout;
818 let start = std::time::Instant::now();
819
820 loop {
821 if start.elapsed() >= timeout {
822 return Err(ProxyError::Timeout(format!(
823 "Timeout waiting for node {:?} to catch up to LSN {}",
824 node, target_lsn
825 )));
826 }
827
828 let candidates = self.candidates.read().await;
830 if let Some(candidate) = candidates.get(&node) {
831 if candidate.lag_bytes == 0 {
834 tracing::debug!("Node {:?} has caught up", node);
835 return Ok(());
836 }
837 }
838 drop(candidates);
839
840 tokio::time::sleep(Duration::from_millis(100)).await;
841 }
842 }
843}
844
/// Summary of replaying a failed node's in-flight transactions on the new primary.
#[cfg(feature = "ha-tr")]
#[derive(Debug, Clone)]
pub struct CoordinatedReplayResult {
    /// Transactions attempted (successful plus failed).
    pub total_transactions: usize,
    /// Transactions whose replay reported success.
    pub successful_replays: usize,
    /// Transactions that could not be replayed.
    pub failed_replays: usize,
    /// One result per attempted transaction, in attempt order.
    pub transaction_results: Vec<ReplayResult>,
    /// Wall time spent on the whole replay, in milliseconds.
    pub duration_ms: u64,
    /// Node the transactions were replayed on.
    pub new_primary: NodeId,
}

#[cfg(feature = "ha-tr")]
impl CoordinatedReplayResult {
    /// `true` when no transaction failed (also `true` when nothing was replayed).
    pub fn all_successful(&self) -> bool {
        self.failed_replays == 0
    }

    /// Percentage of transactions replayed successfully; 100.0 when there were none.
    pub fn success_rate(&self) -> f64 {
        match self.total_transactions {
            0 => 100.0,
            total => (self.successful_replays as f64 / total as f64) * 100.0,
        }
    }
}
879
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_config_default() {
        let config = FailoverConfig::default();
        assert!(config.auto_failover);
        assert!(config.prefer_sync_standby);
        assert_eq!(config.max_retries, 3);
    }

    #[tokio::test]
    async fn test_set_get_primary() {
        let controller = FailoverController::new(FailoverConfig::default());
        let node_id = NodeId::new();

        controller.set_primary(node_id).await;
        assert_eq!(controller.get_primary().await, Some(node_id));
    }

    #[tokio::test]
    async fn test_register_candidate() {
        let controller = FailoverController::new(FailoverConfig::default());
        let node_id = NodeId::new();

        let candidate = FailoverCandidate {
            node_id,
            endpoint: NodeEndpoint::new("localhost", 5432).with_role(NodeRole::Standby),
            is_sync: true,
            lag_bytes: 0,
            priority: 1,
            last_heartbeat: None,
        };

        controller.register_candidate(candidate).await;

        let candidates = controller.candidates.read().await;
        assert!(candidates.contains_key(&node_id));
    }

    #[tokio::test]
    async fn test_state_transitions() {
        let controller = FailoverController::new(FailoverConfig::default());

        assert_eq!(controller.state().await, FailoverState::Normal);

        *controller.state.write().await = FailoverState::PrimaryFailed;
        assert_eq!(controller.state().await, FailoverState::PrimaryFailed);
    }

    #[tokio::test]
    async fn test_select_best_candidate() {
        let controller = FailoverController::new(FailoverConfig::default());

        let sync_node = NodeId::new();
        let async_node = NodeId::new();

        controller
            .register_candidate(FailoverCandidate {
                node_id: async_node,
                endpoint: NodeEndpoint::new("async", 5432),
                is_sync: false,
                lag_bytes: 100,
                priority: 1,
                last_heartbeat: None,
            })
            .await;

        controller
            .register_candidate(FailoverCandidate {
                node_id: sync_node,
                endpoint: NodeEndpoint::new("sync", 5432),
                is_sync: true,
                lag_bytes: 50,
                priority: 2,
                last_heartbeat: None,
            })
            .await;

        // With the default `prefer_sync_standby`, the sync standby wins.
        let best = controller.select_best_candidate().await.unwrap();
        assert_eq!(best.node_id, sync_node);
    }

    #[tokio::test]
    async fn test_initiate_failover_without_primary_fails() {
        let controller = FailoverController::new(FailoverConfig::default());
        assert!(controller.initiate_failover().await.is_err());
        assert_eq!(controller.failover_count(), 0);
    }

    #[tokio::test]
    async fn test_manual_failover_rejects_unknown_target() {
        let controller = FailoverController::new(FailoverConfig::default());

        let result = controller.manual_failover(NodeId::new()).await;

        assert!(result.is_err());
        // Rejection happens before any state transition.
        assert_eq!(controller.state().await, FailoverState::Normal);
        assert_eq!(controller.failover_count(), 0);
    }

    #[test]
    fn test_take_event_receiver_only_once() {
        let mut controller = FailoverController::new(FailoverConfig::default());
        assert!(controller.take_event_receiver().is_some());
        assert!(controller.take_event_receiver().is_none());
    }

    #[cfg(feature = "ha-tr")]
    #[tokio::test]
    async fn test_coordinate_failover_replay_empty() {
        use super::super::transaction_journal::TransactionJournal;

        let controller = FailoverController::new(FailoverConfig::default());
        let journal = TransactionJournal::new();
        let failed_node = NodeId::new();
        let new_primary = NodeEndpoint::new("new-primary", 5432).with_role(NodeRole::Primary);

        let result = controller
            .coordinate_failover_replay(&journal, failed_node, &new_primary)
            .await
            .unwrap();

        assert_eq!(result.total_transactions, 0);
        assert_eq!(result.successful_replays, 0);
        assert_eq!(result.failed_replays, 0);
        assert!(result.all_successful());
        assert_eq!(result.success_rate(), 100.0);
    }

    #[cfg(feature = "ha-tr")]
    #[tokio::test]
    async fn test_coordinate_failover_replay_with_transactions() {
        use super::super::transaction_journal::{JournalValue, TransactionJournal};
        use uuid::Uuid;

        let controller = FailoverController::new(FailoverConfig::default());
        let journal = TransactionJournal::new();
        let failed_node = NodeId::new();
        let new_primary = NodeEndpoint::new("new-primary", 5432).with_role(NodeRole::Primary);

        // Register the new primary with zero lag so the LSN catch-up wait passes at once.
        controller
            .register_candidate(FailoverCandidate {
                node_id: new_primary.id,
                endpoint: new_primary.clone(),
                is_sync: true,
                lag_bytes: 0,
                priority: 1,
                last_heartbeat: None,
            })
            .await;

        let tx_id = Uuid::new_v4();
        let session_id = Uuid::new_v4();
        journal
            .begin_transaction(tx_id, session_id, failed_node, 100)
            .await
            .unwrap();
        journal
            .log_statement(
                tx_id,
                "INSERT INTO users (name) VALUES ('test')".to_string(),
                vec![JournalValue::Text("test".to_string())],
                Some(12345),
                Some(1),
                10,
            )
            .await
            .unwrap();

        let result = controller
            .coordinate_failover_replay(&journal, failed_node, &new_primary)
            .await
            .unwrap();

        assert_eq!(result.total_transactions, 1);
        assert_eq!(result.successful_replays, 1);
        assert_eq!(result.failed_replays, 0);
        assert!(result.all_successful());
    }

    #[cfg(feature = "ha-tr")]
    #[test]
    fn test_coordinated_replay_result_methods() {
        let result = CoordinatedReplayResult {
            total_transactions: 10,
            successful_replays: 8,
            failed_replays: 2,
            transaction_results: vec![],
            duration_ms: 1000,
            new_primary: NodeId::new(),
        };

        assert!(!result.all_successful());
        assert_eq!(result.success_rate(), 80.0);

        let perfect = CoordinatedReplayResult {
            total_transactions: 5,
            successful_replays: 5,
            failed_replays: 0,
            transaction_results: vec![],
            duration_ms: 500,
            new_primary: NodeId::new(),
        };

        assert!(perfect.all_successful());
        assert_eq!(perfect.success_rate(), 100.0);
    }
}