Skip to main content

heliosdb_proxy/
failover_controller.rs

1//! Failover Controller - HeliosProxy
2//!
3//! Orchestrates failover operations including primary detection,
4//! automatic rerouting, and transaction replay coordination.
5
6use super::{NodeEndpoint, NodeId, ProxyError, Result};
7#[cfg(test)]
8use super::NodeRole;
9use crate::backend::{BackendClient, BackendConfig};
10use std::collections::HashMap;
11use std::sync::atomic::{AtomicU64, Ordering};
12use std::sync::Arc;
13use std::time::Duration;
14use tokio::sync::{mpsc, RwLock};
15
16// TR (Transaction Replay) imports
17#[cfg(feature = "ha-tr")]
18use super::failover_replay::{FailoverReplay, ReplayConfig, ReplayResult};
19#[cfg(feature = "ha-tr")]
20use super::transaction_journal::TransactionJournal;
21
/// Tunable parameters that govern how the controller reacts to a primary
/// failure.
#[derive(Debug, Clone)]
pub struct FailoverConfig {
    /// Time to wait before initiating failover.
    ///
    /// NOTE(review): not read anywhere in the visible part of this file —
    /// confirm which component enforces it.
    pub detection_time: Duration,
    /// Upper bound on how long a failover may take. Bounds the standby
    /// sync wait and is also used to derive the `pg_promote` wait time.
    pub failover_timeout: Duration,
    /// When `true`, a detected primary failure immediately triggers a
    /// failover; when `false`, the failure is only recorded and announced.
    pub auto_failover: bool,
    /// Rank synchronous standbys ahead of asynchronous ones when choosing
    /// a promotion target.
    pub prefer_sync_standby: bool,
    /// Replication lag (bytes) above which the controller waits for the
    /// chosen standby to catch up before promoting it.
    pub max_lag_bytes: u64,
    /// Retry failed failovers.
    ///
    /// NOTE(review): not consulted in the visible part of this file.
    pub retry_failed: bool,
    /// Maximum number of retry attempts.
    ///
    /// NOTE(review): not consulted in the visible part of this file.
    pub max_retries: u32,
}

impl Default for FailoverConfig {
    /// Defaults: 10 s detection, 60 s failover timeout, automatic failover
    /// preferring synchronous standbys, a 16 MiB lag threshold, and up to
    /// three retries.
    fn default() -> Self {
        const DEFAULT_MAX_LAG_BYTES: u64 = 16 * 1024 * 1024; // 16MB
        const DEFAULT_MAX_RETRIES: u32 = 3;

        let detection_time = Duration::from_secs(10);
        let failover_timeout = Duration::from_secs(60);

        Self {
            detection_time,
            failover_timeout,
            auto_failover: true,
            prefer_sync_standby: true,
            max_lag_bytes: DEFAULT_MAX_LAG_BYTES,
            retry_failed: true,
            max_retries: DEFAULT_MAX_RETRIES,
        }
    }
}
54
/// How the proxy is expected to react when the primary fails.
///
/// NOTE(review): nothing in the visible part of this file reads this enum;
/// automatic vs. manual behaviour is driven by
/// `FailoverConfig::auto_failover`. Confirm where `FailoverMode` is consumed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FailoverMode {
    /// Automatic failover on primary failure
    Automatic,
    /// Manual failover (require confirmation)
    Manual,
    /// Disabled (no failover)
    Disabled,
}
65
/// State of the controller's failover state machine.
///
/// A new controller starts in `Normal`. The variants below are set by the
/// controller as a failover progresses.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FailoverState {
    /// Normal operation (initial state; also restored shortly after an
    /// automatic failover completes)
    Normal,
    /// Primary failure detected, failover not yet started
    PrimaryFailed,
    /// Failover in progress (a promotion target has been chosen)
    InProgress,
    /// Waiting for standby to catch up (its lag exceeded
    /// `FailoverConfig::max_lag_bytes`)
    WaitingForSync,
    /// Failover completed; the new primary has been recorded
    Completed,
    /// Failover failed
    Failed,
}
82
/// Notification published on the controller's event channel.
///
/// Consumers obtain the receiving end via
/// `FailoverController::take_event_receiver`.
#[derive(Debug, Clone)]
pub enum FailoverEvent {
    /// Primary failure detected
    PrimaryFailed { node_id: NodeId },
    /// Failover started, moving the primary role from `from` to `to`
    FailoverStarted { from: NodeId, to: NodeId },
    /// Waiting for standby sync; `lag_bytes` is the lag recorded when the
    /// wait began
    WaitingForSync { standby: NodeId, lag_bytes: u64 },
    /// Standby promoted
    StandbyPromoted { new_primary: NodeId },
    /// Failover completed; `duration_ms` is measured from the start of the
    /// failover
    FailoverCompleted { duration_ms: u64 },
    /// Failover failed, with a human-readable reason
    FailoverFailed { reason: String },
    /// Old primary recovered (split-brain prevention)
    OldPrimaryRecovered { node_id: NodeId },
}
101
/// A standby that may be promoted to primary, together with the data used
/// to rank it against other candidates.
#[derive(Debug, Clone)]
pub struct FailoverCandidate {
    /// Node ID
    pub node_id: NodeId,
    /// Node endpoint (its host and port are used when connecting to the
    /// node for sync polling and promotion)
    pub endpoint: NodeEndpoint,
    /// Is synchronous standby
    pub is_sync: bool,
    /// Replication lag (bytes)
    pub lag_bytes: u64,
    /// Priority (lower = better); used as the final tie-breaker after sync
    /// status and lag
    pub priority: u32,
    /// Last heartbeat; refreshed whenever the candidate's lag is updated
    pub last_heartbeat: Option<chrono::DateTime<chrono::Utc>>,
}
118
/// Record of a single failover attempt.
///
/// An entry is created when a failover starts (with `success == false` and
/// `ended_at == None`) and is filled in when the attempt finishes.
#[derive(Debug, Clone)]
pub struct FailoverHistoryEntry {
    /// Failover ID (random v4 UUID)
    pub id: uuid::Uuid,
    /// Start time
    pub started_at: chrono::DateTime<chrono::Utc>,
    /// End time; `None` while the attempt is still running
    pub ended_at: Option<chrono::DateTime<chrono::Utc>>,
    /// Old primary
    pub old_primary: NodeId,
    /// New primary (the promotion target)
    pub new_primary: Option<NodeId>,
    /// Result; `true` only once the failover has completed
    pub success: bool,
    /// Error message (if failed)
    pub error: Option<String>,
}
137
/// Failover Controller
///
/// Tracks the current primary and the registered standbys, drives the
/// failover state machine, and publishes `FailoverEvent`s on a bounded
/// channel.
pub struct FailoverController {
    /// Configuration
    config: FailoverConfig,
    /// Current state
    state: Arc<RwLock<FailoverState>>,
    /// Current primary node (`None` until one is set)
    current_primary: Arc<RwLock<Option<NodeId>>>,
    /// Failover candidates (standbys), keyed by node ID
    candidates: Arc<RwLock<HashMap<NodeId, FailoverCandidate>>>,
    /// Event channel sender
    event_tx: mpsc::Sender<FailoverEvent>,
    /// Event channel receiver; held here until a consumer takes it
    event_rx: Option<mpsc::Receiver<FailoverEvent>>,
    /// Failover count (successful failovers only)
    failover_count: AtomicU64,
    /// Failover history, oldest entry first
    history: Arc<RwLock<Vec<FailoverHistoryEntry>>>,
    /// Optional backend-connection template. Host/port are swapped to
    /// a candidate's endpoint when running `pg_promote()` or polling
    /// `pg_last_wal_replay_lsn()`. When `None`, all backend-talking
    /// paths become no-ops that log and succeed — preserving the
    /// pre-T0-TR4 behaviour for unit tests.
    backend_template: Option<BackendConfig>,
}
163
164impl FailoverController {
165    /// Create a new failover controller
166    pub fn new(config: FailoverConfig) -> Self {
167        let (event_tx, event_rx) = mpsc::channel(100);
168
169        Self {
170            config,
171            state: Arc::new(RwLock::new(FailoverState::Normal)),
172            current_primary: Arc::new(RwLock::new(None)),
173            candidates: Arc::new(RwLock::new(HashMap::new())),
174            event_tx,
175            event_rx: Some(event_rx),
176            failover_count: AtomicU64::new(0),
177            history: Arc::new(RwLock::new(Vec::new())),
178            backend_template: None,
179        }
180    }
181
182    /// Attach a backend-connection template so sync-wait and promotion
183    /// can actually run SQL against the candidate.
184    pub fn with_backend_template(mut self, template: BackendConfig) -> Self {
185        self.backend_template = Some(template);
186        self
187    }
188
189    /// Build a BackendConfig for a specific node's endpoint. Returns
190    /// `None` when no template is configured (the no-op / test path).
191    fn backend_config_for(&self, endpoint: &NodeEndpoint) -> Option<BackendConfig> {
192        self.backend_template.as_ref().map(|t| {
193            let mut c = t.clone();
194            c.host = endpoint.host.clone();
195            c.port = endpoint.port;
196            c
197        })
198    }
199
200    /// Set the current primary
201    pub async fn set_primary(&self, node_id: NodeId) {
202        *self.current_primary.write().await = Some(node_id);
203        tracing::info!("Primary set to {:?}", node_id);
204    }
205
206    /// Get the current primary
207    pub async fn get_primary(&self) -> Option<NodeId> {
208        *self.current_primary.read().await
209    }
210
211    /// Register a failover candidate (standby)
212    pub async fn register_candidate(&self, candidate: FailoverCandidate) {
213        let node_id = candidate.node_id;
214        self.candidates.write().await.insert(node_id, candidate);
215        tracing::debug!("Registered failover candidate {:?}", node_id);
216    }
217
218    /// Remove a failover candidate
219    pub async fn remove_candidate(&self, node_id: &NodeId) {
220        self.candidates.write().await.remove(node_id);
221    }
222
223    /// Update candidate lag
224    pub async fn update_candidate_lag(&self, node_id: &NodeId, lag_bytes: u64) {
225        if let Some(candidate) = self.candidates.write().await.get_mut(node_id) {
226            candidate.lag_bytes = lag_bytes;
227            candidate.last_heartbeat = Some(chrono::Utc::now());
228        }
229    }
230
231    /// Get current state
232    pub async fn state(&self) -> FailoverState {
233        *self.state.read().await
234    }
235
236    /// Handle primary failure
237    pub async fn on_primary_failed(&self, node_id: NodeId) -> Result<()> {
238        let current_primary = self.current_primary.read().await;
239        if *current_primary != Some(node_id) {
240            return Ok(()); // Not the current primary
241        }
242        drop(current_primary);
243
244        *self.state.write().await = FailoverState::PrimaryFailed;
245
246        let _ = self
247            .event_tx
248            .send(FailoverEvent::PrimaryFailed { node_id })
249            .await;
250
251        tracing::warn!("Primary node {:?} failed", node_id);
252
253        if self.config.auto_failover {
254            self.initiate_failover().await?;
255        }
256
257        Ok(())
258    }
259
260    /// Initiate failover to best candidate
261    pub async fn initiate_failover(&self) -> Result<()> {
262        let old_primary = self
263            .current_primary
264            .read()
265            .await
266            .ok_or_else(|| ProxyError::FailoverFailed("No primary to failover from".to_string()))?;
267
268        // Select best candidate
269        let candidate = self.select_best_candidate().await?;
270        let new_primary = candidate.node_id;
271
272        *self.state.write().await = FailoverState::InProgress;
273
274        let _ = self
275            .event_tx
276            .send(FailoverEvent::FailoverStarted {
277                from: old_primary,
278                to: new_primary,
279            })
280            .await;
281
282        let start = chrono::Utc::now();
283
284        // Record history entry
285        let history_entry = FailoverHistoryEntry {
286            id: uuid::Uuid::new_v4(),
287            started_at: start,
288            ended_at: None,
289            old_primary,
290            new_primary: Some(new_primary),
291            success: false,
292            error: None,
293        };
294        self.history.write().await.push(history_entry);
295
296        // Check lag
297        if candidate.lag_bytes > self.config.max_lag_bytes {
298            *self.state.write().await = FailoverState::WaitingForSync;
299
300            let _ = self
301                .event_tx
302                .send(FailoverEvent::WaitingForSync {
303                    standby: new_primary,
304                    lag_bytes: candidate.lag_bytes,
305                })
306                .await;
307
308            // Wait for sync (with timeout)
309            let sync_result = self.wait_for_sync(new_primary).await;
310            if let Err(e) = sync_result {
311                self.fail_failover(&e.to_string()).await;
312                return Err(e);
313            }
314        }
315
316        // Promote standby
317        self.promote_standby(new_primary).await?;
318
319        // Complete failover
320        *self.current_primary.write().await = Some(new_primary);
321        *self.state.write().await = FailoverState::Completed;
322        self.failover_count.fetch_add(1, Ordering::SeqCst);
323
324        let duration = chrono::Utc::now()
325            .signed_duration_since(start)
326            .num_milliseconds() as u64;
327
328        // Update history
329        if let Some(entry) = self.history.write().await.last_mut() {
330            entry.ended_at = Some(chrono::Utc::now());
331            entry.success = true;
332        }
333
334        let _ = self
335            .event_tx
336            .send(FailoverEvent::StandbyPromoted {
337                new_primary,
338            })
339            .await;
340
341        let _ = self
342            .event_tx
343            .send(FailoverEvent::FailoverCompleted { duration_ms: duration })
344            .await;
345
346        tracing::info!(
347            "Failover completed: {:?} -> {:?} in {}ms",
348            old_primary,
349            new_primary,
350            duration
351        );
352
353        // Reset state after a moment
354        tokio::spawn({
355            let state = self.state.clone();
356            async move {
357                tokio::time::sleep(Duration::from_secs(1)).await;
358                *state.write().await = FailoverState::Normal;
359            }
360        });
361
362        Ok(())
363    }
364
365    /// Select the best failover candidate
366    async fn select_best_candidate(&self) -> Result<FailoverCandidate> {
367        let candidates = self.candidates.read().await;
368
369        if candidates.is_empty() {
370            return Err(ProxyError::FailoverFailed(
371                "No failover candidates available".to_string(),
372            ));
373        }
374
375        // Sort by: sync status, lag, priority
376        let mut sorted: Vec<_> = candidates.values().cloned().collect();
377        sorted.sort_by(|a, b| {
378            // Prefer sync standbys
379            if self.config.prefer_sync_standby {
380                if a.is_sync != b.is_sync {
381                    return b.is_sync.cmp(&a.is_sync);
382                }
383            }
384            // Then by lag
385            if a.lag_bytes != b.lag_bytes {
386                return a.lag_bytes.cmp(&b.lag_bytes);
387            }
388            // Then by priority
389            a.priority.cmp(&b.priority)
390        });
391
392        sorted
393            .first()
394            .cloned()
395            .ok_or_else(|| ProxyError::FailoverFailed("No eligible candidates".to_string()))
396    }
397
398    /// Wait for a standby to catch up before promotion.
399    ///
400    /// Polls `pg_last_wal_replay_lsn()` on the candidate at 200 ms
401    /// cadence. Two consecutive polls that return the same LSN are
402    /// treated as "caught up as far as it can go" (the primary is
403    /// presumed dead, so no new WAL is arriving). Bounded by
404    /// `config.failover_timeout`.
405    ///
406    /// When no backend template is attached, returns `Ok(())` after
407    /// a short delay — the pre-T0-TR4 skeleton behaviour for tests.
408    async fn wait_for_sync(&self, standby: NodeId) -> Result<()> {
409        let endpoint = self
410            .candidates
411            .read()
412            .await
413            .get(&standby)
414            .map(|c| c.endpoint.clone());
415        let cfg = match endpoint.as_ref().and_then(|e| self.backend_config_for(e)) {
416            Some(c) => c,
417            None => {
418                // Skeleton path: simulate a brief wait so the state
419                // machine test harness still sees WaitingForSync.
420                tokio::time::sleep(Duration::from_millis(50)).await;
421                return Ok(());
422            }
423        };
424
425        let overall = self.config.failover_timeout;
426        tokio::time::timeout(overall, Self::poll_until_caught_up(cfg))
427            .await
428            .map_err(|_| ProxyError::Timeout("standby sync timeout".to_string()))??;
429        Ok(())
430    }
431
432    /// Connect to the candidate and poll `pg_last_wal_replay_lsn()`
433    /// until it stabilises across two consecutive 200 ms polls.
434    async fn poll_until_caught_up(cfg: BackendConfig) -> Result<()> {
435        let mut client = BackendClient::connect(&cfg)
436            .await
437            .map_err(|e| ProxyError::Failover(format!("connect to candidate: {}", e)))?;
438
439        let mut last: Option<String> = None;
440        let mut stable_polls = 0u32;
441        loop {
442            let value = client
443                .query_scalar("SELECT pg_last_wal_replay_lsn()::text")
444                .await
445                .map_err(|e| ProxyError::Failover(format!("wal lsn probe: {}", e)))?;
446            let lsn = value
447                .into_string()
448                .ok_or_else(|| ProxyError::Failover("null WAL replay LSN".into()))?;
449
450            if last.as_ref() == Some(&lsn) {
451                stable_polls += 1;
452                if stable_polls >= 2 {
453                    tracing::info!(lsn = %lsn, "standby caught up");
454                    client.close().await;
455                    return Ok(());
456                }
457            } else {
458                stable_polls = 0;
459                last = Some(lsn);
460            }
461            tokio::time::sleep(Duration::from_millis(200)).await;
462        }
463    }
464
465    /// Promote a standby to primary via `pg_promote()`.
466    ///
467    /// Uses `pg_promote(wait => true, wait_seconds => N)` so the server
468    /// waits for promotion to complete before returning. Verifies
469    /// post-promotion by re-running `pg_is_in_recovery()` (must now be
470    /// `false`) on a fresh connection.
471    ///
472    /// When no backend template is attached, logs and returns `Ok(())`
473    /// — skeleton / test path.
474    async fn promote_standby(&self, standby: NodeId) -> Result<()> {
475        let endpoint = self
476            .candidates
477            .read()
478            .await
479            .get(&standby)
480            .map(|c| c.endpoint.clone());
481        let cfg = match endpoint.as_ref().and_then(|e| self.backend_config_for(e)) {
482            Some(c) => c,
483            None => {
484                tracing::info!(
485                    node = ?standby,
486                    "promote_standby: skeleton path (no backend template) — no-op"
487                );
488                return Ok(());
489            }
490        };
491
492        let wait_secs = self.config.failover_timeout.as_secs().max(10).min(300);
493        let mut client = BackendClient::connect(&cfg)
494            .await
495            .map_err(|e| ProxyError::FailoverFailed(format!("connect to promote: {}", e)))?;
496
497        let sql = format!("SELECT pg_promote(true, {})", wait_secs);
498        let value = client
499            .query_scalar(&sql)
500            .await
501            .map_err(|e| ProxyError::FailoverFailed(format!("pg_promote: {}", e)))?;
502        let promoted = value
503            .as_bool("pg_promote")
504            .map_err(|e| ProxyError::FailoverFailed(format!("pg_promote result: {}", e)))?
505            .unwrap_or(false);
506        client.close().await;
507
508        if !promoted {
509            return Err(ProxyError::FailoverFailed(
510                "pg_promote returned false".to_string(),
511            ));
512        }
513
514        // Verify on a fresh connection that the node is no longer in recovery.
515        let mut verify = BackendClient::connect(&cfg)
516            .await
517            .map_err(|e| ProxyError::FailoverFailed(format!("connect to verify: {}", e)))?;
518        let in_recovery = verify
519            .query_scalar("SELECT pg_is_in_recovery()")
520            .await
521            .map_err(|e| ProxyError::FailoverFailed(format!("verify probe: {}", e)))?;
522        verify.close().await;
523        let still_standby = in_recovery
524            .as_bool("pg_is_in_recovery")
525            .map_err(|e| ProxyError::FailoverFailed(format!("verify bool: {}", e)))?
526            .unwrap_or(true);
527        if still_standby {
528            return Err(ProxyError::FailoverFailed(
529                "post-promote pg_is_in_recovery still true".to_string(),
530            ));
531        }
532
533        tracing::info!(node = ?standby, "standby promoted to primary");
534        Ok(())
535    }
536
537    /// Fail the failover
538    async fn fail_failover(&self, reason: &str) {
539        *self.state.write().await = FailoverState::Failed;
540
541        if let Some(entry) = self.history.write().await.last_mut() {
542            entry.ended_at = Some(chrono::Utc::now());
543            entry.success = false;
544            entry.error = Some(reason.to_string());
545        }
546
547        let _ = self
548            .event_tx
549            .send(FailoverEvent::FailoverFailed {
550                reason: reason.to_string(),
551            })
552            .await;
553
554        tracing::error!("Failover failed: {}", reason);
555    }
556
557    /// Handle old primary recovery (split-brain prevention).
558    ///
559    /// PostgreSQL has no built-in "demote the current primary" command —
560    /// re-joining as a standby requires stopping the process and
561    /// re-initialising (`pg_rewind` or `pg_basebackup`). This method
562    /// therefore cannot fully automate demotion. What it CAN do:
563    ///
564    /// 1. Connect to the recovered node and verify whether it still
565    ///    believes it is the primary (`pg_is_in_recovery() = false`).
566    /// 2. Emit `OldPrimaryRecovered` so operators (or an external
567    ///    orchestrator like Patroni / pg_auto_failover) can react.
568    ///
569    /// This is deliberately read-only. Rewriting WAL on a live cluster
570    /// without operator oversight is the canonical way to lose data;
571    /// the proxy refuses to do it.
572    pub async fn on_old_primary_recovered(&self, node_id: NodeId) {
573        let _ = self
574            .event_tx
575            .send(FailoverEvent::OldPrimaryRecovered { node_id })
576            .await;
577        tracing::warn!(
578            "old primary {:?} recovered — must be demoted out-of-band to prevent split-brain",
579            node_id
580        );
581
582        // Best-effort: if we have an endpoint and template, probe to
583        // confirm recovery state and shout extra-loud if it still
584        // thinks it's primary.
585        let endpoint = self
586            .candidates
587            .read()
588            .await
589            .get(&node_id)
590            .map(|c| c.endpoint.clone());
591        let cfg = match endpoint.as_ref().and_then(|e| self.backend_config_for(e)) {
592            Some(c) => c,
593            None => return, // no backend template → nothing more to do
594        };
595
596        match BackendClient::connect(&cfg).await {
597            Ok(mut client) => {
598                let in_recovery_result = client
599                    .query_scalar("SELECT pg_is_in_recovery()")
600                    .await;
601                client.close().await;
602                if let Ok(tv) = in_recovery_result {
603                    if let Ok(Some(false)) = tv.as_bool("pg_is_in_recovery") {
604                        tracing::error!(
605                            "split-brain hazard: node {:?} recovered and still reports primary (pg_is_in_recovery=false). Shut it down or use pg_rewind before reintroducing.",
606                            node_id
607                        );
608                    }
609                }
610            }
611            Err(e) => {
612                tracing::debug!(
613                    error = %e,
614                    "could not connect to recovered node for split-brain probe"
615                );
616            }
617        }
618    }
619
620    /// Manual failover to specific node
621    pub async fn manual_failover(&self, target: NodeId) -> Result<()> {
622        // Verify target is a valid candidate
623        let candidates = self.candidates.read().await;
624        if !candidates.contains_key(&target) {
625            return Err(ProxyError::FailoverFailed(format!(
626                "Node {:?} is not a valid failover candidate",
627                target
628            )));
629        }
630        drop(candidates);
631
632        // Force failover to specific node
633        *self.state.write().await = FailoverState::InProgress;
634
635        let old_primary = self.current_primary.read().await.unwrap_or(NodeId::new());
636
637        let _ = self
638            .event_tx
639            .send(FailoverEvent::FailoverStarted {
640                from: old_primary,
641                to: target,
642            })
643            .await;
644
645        self.promote_standby(target).await?;
646
647        *self.current_primary.write().await = Some(target);
648        *self.state.write().await = FailoverState::Completed;
649        self.failover_count.fetch_add(1, Ordering::SeqCst);
650
651        Ok(())
652    }
653
654    /// Get failover count
655    pub fn failover_count(&self) -> u64 {
656        self.failover_count.load(Ordering::SeqCst)
657    }
658
659    /// Get failover history
660    pub async fn history(&self) -> Vec<FailoverHistoryEntry> {
661        self.history.read().await.clone()
662    }
663
664    /// Take the event receiver
665    pub fn take_event_receiver(&mut self) -> Option<mpsc::Receiver<FailoverEvent>> {
666        self.event_rx.take()
667    }
668
669    /// Coordinate transaction replay after failover (TR integration)
670    ///
671    /// This method orchestrates the replay of in-flight transactions on a new primary
672    /// after a failover event. It ensures transaction atomicity by:
673    /// 1. Getting all active transactions from the journal that were on the failed node
674    /// 2. Waiting for the new primary to catch up to the required LSN
675    /// 3. Replaying each transaction's statements on the new primary
676    /// 4. Verifying results match the original execution (via checksums)
677    #[cfg(feature = "ha-tr")]
678    pub async fn coordinate_failover_replay(
679        &self,
680        journal: &TransactionJournal,
681        failed_node: NodeId,
682        new_primary_endpoint: &NodeEndpoint,
683    ) -> Result<CoordinatedReplayResult> {
684        let start = std::time::Instant::now();
685
686        tracing::info!(
687            "Starting coordinated replay: failed_node={:?}, new_primary={:?}",
688            failed_node,
689            new_primary_endpoint.id
690        );
691
692        // 1. Get all active transactions that were on the failed node
693        let affected_txs = journal.get_transactions_for_node(failed_node).await;
694
695        if affected_txs.is_empty() {
696            tracing::info!("No active transactions to replay");
697            return Ok(CoordinatedReplayResult {
698                total_transactions: 0,
699                successful_replays: 0,
700                failed_replays: 0,
701                transaction_results: vec![],
702                duration_ms: start.elapsed().as_millis() as u64,
703                new_primary: new_primary_endpoint.id,
704            });
705        }
706
707        tracing::info!("Found {} active transactions to replay", affected_txs.len());
708
709        // 2. Get the maximum LSN we need to wait for
710        let max_lsn = affected_txs.iter().map(|tx| tx.start_lsn).max().unwrap_or(0);
711
712        // 3. Wait for the new primary to catch up to this LSN
713        self.wait_for_lsn_catchup(new_primary_endpoint.id, max_lsn).await?;
714
715        // 4. Create replay manager and replay each transaction
716        let replay_manager = FailoverReplay::new(ReplayConfig {
717            verify_results: true,
718            statement_timeout_ms: 30000,
719            retry_on_error: true,
720            max_retries: 3,
721            skip_read_only: false,
722            wait_for_wal_sync: false, // Already waited above
723            max_wal_lag_bytes: 0,
724        });
725
726        let mut transaction_results = Vec::new();
727        let mut successful_replays = 0;
728        let mut failed_replays = 0;
729
730        for tx_journal in affected_txs {
731            let tx_id = tx_journal.tx_id;
732
733            tracing::debug!("Replaying transaction {:?} with {} entries", tx_id, tx_journal.entries.len());
734
735            // Start and execute replay
736            match replay_manager.start_replay(tx_journal, new_primary_endpoint.id).await {
737                Ok(_) => {
738                    match replay_manager.execute_replay(tx_id).await {
739                        Ok(result) => {
740                            if result.success {
741                                successful_replays += 1;
742                                tracing::debug!("Transaction {:?} replayed successfully", tx_id);
743                            } else {
744                                failed_replays += 1;
745                                tracing::warn!(
746                                    "Transaction {:?} replay failed: {:?}",
747                                    tx_id,
748                                    result.error
749                                );
750                            }
751                            transaction_results.push(result);
752                        }
753                        Err(e) => {
754                            failed_replays += 1;
755                            tracing::error!("Failed to execute replay for {:?}: {}", tx_id, e);
756                            transaction_results.push(ReplayResult {
757                                tx_id,
758                                success: false,
759                                statements_replayed: 0,
760                                statements_skipped: 0,
761                                statements_failed: 0,
762                                verification_failures: 0,
763                                duration_ms: 0,
764                                error: Some(e.to_string()),
765                                statement_results: vec![],
766                            });
767                        }
768                    }
769                }
770                Err(e) => {
771                    failed_replays += 1;
772                    tracing::error!("Failed to start replay for {:?}: {}", tx_id, e);
773                    transaction_results.push(ReplayResult {
774                        tx_id,
775                        success: false,
776                        statements_replayed: 0,
777                        statements_skipped: 0,
778                        statements_failed: 0,
779                        verification_failures: 0,
780                        duration_ms: 0,
781                        error: Some(e.to_string()),
782                        statement_results: vec![],
783                    });
784                }
785            }
786        }
787
788        let duration_ms = start.elapsed().as_millis() as u64;
789
790        tracing::info!(
791            "Coordinated replay completed: {}/{} successful in {}ms",
792            successful_replays,
793            successful_replays + failed_replays,
794            duration_ms
795        );
796
797        Ok(CoordinatedReplayResult {
798            total_transactions: successful_replays + failed_replays,
799            successful_replays,
800            failed_replays,
801            transaction_results,
802            duration_ms,
803            new_primary: new_primary_endpoint.id,
804        })
805    }
806
807    /// Wait for a node to catch up to a specific LSN
808    #[cfg(feature = "ha-tr")]
809    async fn wait_for_lsn_catchup(&self, node: NodeId, target_lsn: u64) -> Result<()> {
810        if target_lsn == 0 {
811            return Ok(());
812        }
813
814        tracing::debug!("Waiting for node {:?} to catch up to LSN {}", node, target_lsn);
815
816        // Use configured timeout
817        let timeout = self.config.failover_timeout;
818        let start = std::time::Instant::now();
819
820        loop {
821            if start.elapsed() >= timeout {
822                return Err(ProxyError::Timeout(format!(
823                    "Timeout waiting for node {:?} to catch up to LSN {}",
824                    node, target_lsn
825                )));
826            }
827
828            // Check if candidate has caught up
829            let candidates = self.candidates.read().await;
830            if let Some(candidate) = candidates.get(&node) {
831                // In a real implementation, we'd query the node's current LSN
832                // For now, we check if lag is acceptable
833                if candidate.lag_bytes == 0 {
834                    tracing::debug!("Node {:?} has caught up", node);
835                    return Ok(());
836                }
837            }
838            drop(candidates);
839
840            tokio::time::sleep(Duration::from_millis(100)).await;
841        }
842    }
843}
844
/// Result of coordinated transaction replay after failover.
///
/// Aggregates the outcome of replaying journaled transactions against the
/// newly promoted primary: overall counters, the individual per-transaction
/// results, and timing information.
#[cfg(feature = "ha-tr")]
#[derive(Debug, Clone)]
pub struct CoordinatedReplayResult {
    /// Total number of transactions replayed
    /// (successful and failed replays combined).
    pub total_transactions: usize,
    /// Number of successful replays
    pub successful_replays: usize,
    /// Number of failed replays
    pub failed_replays: usize,
    /// Per-transaction replay results
    pub transaction_results: Vec<ReplayResult>,
    /// Total duration of the coordinated replay, in milliseconds
    pub duration_ms: u64,
    /// ID of the node that acted as the new primary for the replay
    pub new_primary: NodeId,
}
862
#[cfg(feature = "ha-tr")]
impl CoordinatedReplayResult {
    /// Returns `true` when no transaction replay failed.
    ///
    /// An empty result (nothing replayed) also counts as fully successful.
    pub fn all_successful(&self) -> bool {
        self.failed_replays == 0
    }

    /// Returns the share of successful replays as a percentage in `0.0..=100.0`.
    ///
    /// When no transactions were replayed at all, the rate is reported as
    /// `100.0` instead of dividing by zero.
    pub fn success_rate(&self) -> f64 {
        match self.total_transactions {
            0 => 100.0,
            total => {
                let ratio = self.successful_replays as f64 / total as f64;
                ratio * 100.0
            }
        }
    }
}
879
/// Unit tests for the failover controller and the coordinated replay result.
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration enables automatic failover, prefers
    /// synchronous standbys and allows three retries.
    #[test]
    fn test_config_default() {
        let config = FailoverConfig::default();
        assert!(config.auto_failover);
        assert!(config.prefer_sync_standby);
        assert_eq!(config.max_retries, 3);
    }

    /// A primary set on the controller can be read back.
    #[tokio::test]
    async fn test_set_get_primary() {
        let controller = FailoverController::new(FailoverConfig::default());
        let node_id = NodeId::new();

        controller.set_primary(node_id).await;
        assert_eq!(controller.get_primary().await, Some(node_id));
    }

    /// Registering a candidate stores it under its node ID.
    #[tokio::test]
    async fn test_register_candidate() {
        let controller = FailoverController::new(FailoverConfig::default());
        let node_id = NodeId::new();

        let candidate = FailoverCandidate {
            node_id,
            endpoint: NodeEndpoint::new("localhost", 5432).with_role(NodeRole::Standby),
            is_sync: true,
            lag_bytes: 0,
            priority: 1,
            last_heartbeat: None,
        };

        controller.register_candidate(candidate).await;

        let candidates = controller.candidates.read().await;
        assert!(candidates.contains_key(&node_id));
    }

    /// The controller starts in `Normal` and reflects direct state writes.
    #[tokio::test]
    async fn test_state_transitions() {
        let controller = FailoverController::new(FailoverConfig::default());

        assert_eq!(controller.state().await, FailoverState::Normal);

        *controller.state.write().await = FailoverState::PrimaryFailed;
        assert_eq!(controller.state().await, FailoverState::PrimaryFailed);
    }

    /// With the default config, a synchronous standby wins over an
    /// asynchronous one.
    #[tokio::test]
    async fn test_select_best_candidate() {
        let controller = FailoverController::new(FailoverConfig::default());

        let sync_node = NodeId::new();
        let async_node = NodeId::new();

        controller
            .register_candidate(FailoverCandidate {
                node_id: async_node,
                endpoint: NodeEndpoint::new("async", 5432),
                is_sync: false,
                lag_bytes: 100,
                priority: 1,
                last_heartbeat: None,
            })
            .await;

        controller
            .register_candidate(FailoverCandidate {
                node_id: sync_node,
                endpoint: NodeEndpoint::new("sync", 5432),
                is_sync: true,
                lag_bytes: 50,
                priority: 2,
                last_heartbeat: None,
            })
            .await;

        let best = controller.select_best_candidate().await.unwrap();
        // Sync standby should be preferred
        assert_eq!(best.node_id, sync_node);
    }

    /// Replay coordination over an empty journal yields an empty,
    /// fully successful result.
    #[cfg(feature = "ha-tr")]
    #[tokio::test]
    async fn test_coordinate_failover_replay_empty() {
        use super::super::transaction_journal::TransactionJournal;

        let controller = FailoverController::new(FailoverConfig::default());
        let journal = TransactionJournal::new();
        let failed_node = NodeId::new();
        let new_primary = NodeEndpoint::new("new-primary", 5432).with_role(NodeRole::Primary);

        // With no transactions, should succeed immediately
        let result = controller
            .coordinate_failover_replay(&journal, failed_node, &new_primary)
            .await
            .unwrap();

        assert_eq!(result.total_transactions, 0);
        assert_eq!(result.successful_replays, 0);
        assert_eq!(result.failed_replays, 0);
        assert!(result.all_successful());
        assert_eq!(result.success_rate(), 100.0);
    }

    /// A single journaled transaction on the failed node is replayed
    /// successfully against a zero-lag new primary.
    #[cfg(feature = "ha-tr")]
    #[tokio::test]
    async fn test_coordinate_failover_replay_with_transactions() {
        use super::super::transaction_journal::{JournalValue, TransactionJournal};
        use uuid::Uuid;

        let controller = FailoverController::new(FailoverConfig::default());
        let journal = TransactionJournal::new();
        let failed_node = NodeId::new();
        let new_primary = NodeEndpoint::new("new-primary", 5432).with_role(NodeRole::Primary);

        // Register the new primary as a candidate with zero lag
        controller
            .register_candidate(FailoverCandidate {
                node_id: new_primary.id,
                endpoint: new_primary.clone(),
                is_sync: true,
                lag_bytes: 0,
                priority: 1,
                last_heartbeat: None,
            })
            .await;

        // Create a transaction on the failed node
        let tx_id = Uuid::new_v4();
        let session_id = Uuid::new_v4();
        journal
            .begin_transaction(tx_id, session_id, failed_node, 100)
            .await
            .unwrap();
        journal
            .log_statement(
                tx_id,
                "INSERT INTO users (name) VALUES ('test')".to_string(),
                vec![JournalValue::Text("test".to_string())],
                Some(12345),
                Some(1),
                10,
            )
            .await
            .unwrap();

        // Coordinate replay
        let result = controller
            .coordinate_failover_replay(&journal, failed_node, &new_primary)
            .await
            .unwrap();

        assert_eq!(result.total_transactions, 1);
        assert_eq!(result.successful_replays, 1);
        assert_eq!(result.failed_replays, 0);
        assert!(result.all_successful());
    }

    /// `all_successful` and `success_rate` reflect the stored counters.
    #[cfg(feature = "ha-tr")]
    #[test]
    fn test_coordinated_replay_result_methods() {
        let result = CoordinatedReplayResult {
            total_transactions: 10,
            successful_replays: 8,
            failed_replays: 2,
            transaction_results: vec![],
            duration_ms: 1000,
            new_primary: NodeId::new(),
        };

        assert!(!result.all_successful());
        assert_eq!(result.success_rate(), 80.0);

        let perfect = CoordinatedReplayResult {
            total_transactions: 5,
            successful_replays: 5,
            failed_replays: 0,
            transaction_results: vec![],
            duration_ms: 500,
            new_primary: NodeId::new(),
        };

        assert!(perfect.all_successful());
        assert_eq!(perfect.success_rate(), 100.0);
    }
}