Skip to main content

heliosdb_proxy/
failover_controller.rs

//! Failover Controller - HeliosProxy
//!
//! Orchestrates failover operations including primary detection,
//! automatic rerouting, and transaction replay coordination.
5
6#[cfg(test)]
7use super::NodeRole;
8use super::{NodeEndpoint, NodeId, ProxyError, Result};
9use crate::backend::{BackendClient, BackendConfig};
10use std::collections::HashMap;
11use std::sync::atomic::{AtomicU64, Ordering};
12use std::sync::Arc;
13use std::time::Duration;
14use tokio::sync::{mpsc, RwLock};
15
16// TR (Transaction Replay) imports
17#[cfg(feature = "ha-tr")]
18use super::failover_replay::{FailoverReplay, ReplayConfig, ReplayResult};
19#[cfg(feature = "ha-tr")]
20use super::transaction_journal::TransactionJournal;
21
/// Tunable settings consumed by [`FailoverController`].
#[derive(Debug, Clone)]
pub struct FailoverConfig {
    /// Grace period to wait before a failover is initiated.
    pub detection_time: Duration,
    /// Upper bound on a failover. The controller uses it as the overall
    /// budget for the standby sync wait and clamps it (in whole seconds)
    /// into the wait argument passed to `pg_promote`.
    pub failover_timeout: Duration,
    /// When `true`, a reported primary failure immediately starts a
    /// failover; when `false`, the failure is only recorded and announced.
    pub auto_failover: bool,
    /// Rank synchronous standbys ahead of asynchronous ones when picking
    /// the promotion target.
    pub prefer_sync_standby: bool,
    /// Replication lag threshold in bytes. A chosen candidate whose lag
    /// exceeds this value is waited on before it is promoted.
    pub max_lag_bytes: u64,
    /// Whether a failed failover should be retried.
    pub retry_failed: bool,
    /// Maximum number of retry attempts.
    pub max_retries: u32,
}
40
41impl Default for FailoverConfig {
42    fn default() -> Self {
43        Self {
44            detection_time: Duration::from_secs(10),
45            failover_timeout: Duration::from_secs(60),
46            auto_failover: true,
47            prefer_sync_standby: true,
48            max_lag_bytes: 16 * 1024 * 1024, // 16MB
49            retry_failed: true,
50            max_retries: 3,
51        }
52    }
53}
54
/// How failover is expected to be triggered.
///
/// NOTE(review): no code visible in this part of the file reads this
/// enum; `FailoverConfig::auto_failover` is what the controller checks.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FailoverMode {
    /// Fail over without operator involvement when the primary fails.
    Automatic,
    /// Fail over only after an operator confirms.
    Manual,
    /// Never fail over.
    Disabled,
}
65
/// Phase of the controller's failover state machine.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FailoverState {
    /// Steady state; no failover is under way.
    Normal,
    /// The current primary has been reported as failed.
    PrimaryFailed,
    /// A promotion target has been chosen and failover is running.
    InProgress,
    /// The chosen standby lags too far behind and is being waited on.
    WaitingForSync,
    /// The standby was promoted and is now the primary.
    Completed,
    /// The failover attempt was aborted with an error.
    Failed,
}
82
83/// Failover event
84#[derive(Debug, Clone)]
85pub enum FailoverEvent {
86    /// Primary failure detected
87    PrimaryFailed { node_id: NodeId },
88    /// Failover started
89    FailoverStarted { from: NodeId, to: NodeId },
90    /// Waiting for standby sync
91    WaitingForSync { standby: NodeId, lag_bytes: u64 },
92    /// Standby promoted
93    StandbyPromoted { new_primary: NodeId },
94    /// Failover completed
95    FailoverCompleted { duration_ms: u64 },
96    /// Failover failed
97    FailoverFailed { reason: String },
98    /// Old primary recovered (split-brain prevention)
99    OldPrimaryRecovered { node_id: NodeId },
100}
101
102/// Failover candidate information
103#[derive(Debug, Clone)]
104pub struct FailoverCandidate {
105    /// Node ID
106    pub node_id: NodeId,
107    /// Node endpoint
108    pub endpoint: NodeEndpoint,
109    /// Is synchronous standby
110    pub is_sync: bool,
111    /// Replication lag (bytes)
112    pub lag_bytes: u64,
113    /// Priority (lower = better)
114    pub priority: u32,
115    /// Last heartbeat
116    pub last_heartbeat: Option<chrono::DateTime<chrono::Utc>>,
117}
118
119/// Failover history entry
120#[derive(Debug, Clone)]
121pub struct FailoverHistoryEntry {
122    /// Failover ID
123    pub id: uuid::Uuid,
124    /// Start time
125    pub started_at: chrono::DateTime<chrono::Utc>,
126    /// End time
127    pub ended_at: Option<chrono::DateTime<chrono::Utc>>,
128    /// Old primary
129    pub old_primary: NodeId,
130    /// New primary
131    pub new_primary: Option<NodeId>,
132    /// Result
133    pub success: bool,
134    /// Error message (if failed)
135    pub error: Option<String>,
136}
137
138/// Failover Controller
139pub struct FailoverController {
140    /// Configuration
141    config: FailoverConfig,
142    /// Current state
143    state: Arc<RwLock<FailoverState>>,
144    /// Current primary node
145    current_primary: Arc<RwLock<Option<NodeId>>>,
146    /// Failover candidates (standbys)
147    candidates: Arc<RwLock<HashMap<NodeId, FailoverCandidate>>>,
148    /// Event channel sender
149    event_tx: mpsc::Sender<FailoverEvent>,
150    /// Event channel receiver
151    event_rx: Option<mpsc::Receiver<FailoverEvent>>,
152    /// Failover count
153    failover_count: AtomicU64,
154    /// Failover history
155    history: Arc<RwLock<Vec<FailoverHistoryEntry>>>,
156    /// Optional backend-connection template. Host/port are swapped to
157    /// a candidate's endpoint when running `pg_promote()` or polling
158    /// `pg_last_wal_replay_lsn()`. When `None`, all backend-talking
159    /// paths become no-ops that log and succeed — preserving the
160    /// pre-T0-TR4 behaviour for unit tests.
161    backend_template: Option<BackendConfig>,
162}
163
164impl FailoverController {
165    /// Create a new failover controller
166    pub fn new(config: FailoverConfig) -> Self {
167        let (event_tx, event_rx) = mpsc::channel(100);
168
169        Self {
170            config,
171            state: Arc::new(RwLock::new(FailoverState::Normal)),
172            current_primary: Arc::new(RwLock::new(None)),
173            candidates: Arc::new(RwLock::new(HashMap::new())),
174            event_tx,
175            event_rx: Some(event_rx),
176            failover_count: AtomicU64::new(0),
177            history: Arc::new(RwLock::new(Vec::new())),
178            backend_template: None,
179        }
180    }
181
182    /// Attach a backend-connection template so sync-wait and promotion
183    /// can actually run SQL against the candidate.
184    pub fn with_backend_template(mut self, template: BackendConfig) -> Self {
185        self.backend_template = Some(template);
186        self
187    }
188
189    /// Build a BackendConfig for a specific node's endpoint. Returns
190    /// `None` when no template is configured (the no-op / test path).
191    fn backend_config_for(&self, endpoint: &NodeEndpoint) -> Option<BackendConfig> {
192        self.backend_template.as_ref().map(|t| {
193            let mut c = t.clone();
194            c.host = endpoint.host.clone();
195            c.port = endpoint.port;
196            c
197        })
198    }
199
200    /// Set the current primary
201    pub async fn set_primary(&self, node_id: NodeId) {
202        *self.current_primary.write().await = Some(node_id);
203        tracing::info!("Primary set to {:?}", node_id);
204    }
205
206    /// Get the current primary
207    pub async fn get_primary(&self) -> Option<NodeId> {
208        *self.current_primary.read().await
209    }
210
211    /// Register a failover candidate (standby)
212    pub async fn register_candidate(&self, candidate: FailoverCandidate) {
213        let node_id = candidate.node_id;
214        self.candidates.write().await.insert(node_id, candidate);
215        tracing::debug!("Registered failover candidate {:?}", node_id);
216    }
217
218    /// Remove a failover candidate
219    pub async fn remove_candidate(&self, node_id: &NodeId) {
220        self.candidates.write().await.remove(node_id);
221    }
222
223    /// Update candidate lag
224    pub async fn update_candidate_lag(&self, node_id: &NodeId, lag_bytes: u64) {
225        if let Some(candidate) = self.candidates.write().await.get_mut(node_id) {
226            candidate.lag_bytes = lag_bytes;
227            candidate.last_heartbeat = Some(chrono::Utc::now());
228        }
229    }
230
231    /// Get current state
232    pub async fn state(&self) -> FailoverState {
233        *self.state.read().await
234    }
235
236    /// Handle primary failure
237    pub async fn on_primary_failed(&self, node_id: NodeId) -> Result<()> {
238        let current_primary = self.current_primary.read().await;
239        if *current_primary != Some(node_id) {
240            return Ok(()); // Not the current primary
241        }
242        drop(current_primary);
243
244        *self.state.write().await = FailoverState::PrimaryFailed;
245
246        let _ = self
247            .event_tx
248            .send(FailoverEvent::PrimaryFailed { node_id })
249            .await;
250
251        tracing::warn!("Primary node {:?} failed", node_id);
252
253        if self.config.auto_failover {
254            self.initiate_failover().await?;
255        }
256
257        Ok(())
258    }
259
260    /// Initiate failover to best candidate
261    pub async fn initiate_failover(&self) -> Result<()> {
262        let old_primary =
263            self.current_primary.read().await.ok_or_else(|| {
264                ProxyError::FailoverFailed("No primary to failover from".to_string())
265            })?;
266
267        // Select best candidate
268        let candidate = self.select_best_candidate().await?;
269        let new_primary = candidate.node_id;
270
271        *self.state.write().await = FailoverState::InProgress;
272
273        let _ = self
274            .event_tx
275            .send(FailoverEvent::FailoverStarted {
276                from: old_primary,
277                to: new_primary,
278            })
279            .await;
280
281        let start = chrono::Utc::now();
282
283        // Record history entry
284        let history_entry = FailoverHistoryEntry {
285            id: uuid::Uuid::new_v4(),
286            started_at: start,
287            ended_at: None,
288            old_primary,
289            new_primary: Some(new_primary),
290            success: false,
291            error: None,
292        };
293        self.history.write().await.push(history_entry);
294
295        // Check lag
296        if candidate.lag_bytes > self.config.max_lag_bytes {
297            *self.state.write().await = FailoverState::WaitingForSync;
298
299            let _ = self
300                .event_tx
301                .send(FailoverEvent::WaitingForSync {
302                    standby: new_primary,
303                    lag_bytes: candidate.lag_bytes,
304                })
305                .await;
306
307            // Wait for sync (with timeout)
308            let sync_result = self.wait_for_sync(new_primary).await;
309            if let Err(e) = sync_result {
310                self.fail_failover(&e.to_string()).await;
311                return Err(e);
312            }
313        }
314
315        // Promote standby
316        self.promote_standby(new_primary).await?;
317
318        // Complete failover
319        *self.current_primary.write().await = Some(new_primary);
320        *self.state.write().await = FailoverState::Completed;
321        self.failover_count.fetch_add(1, Ordering::SeqCst);
322
323        let duration = chrono::Utc::now()
324            .signed_duration_since(start)
325            .num_milliseconds() as u64;
326
327        // Update history
328        if let Some(entry) = self.history.write().await.last_mut() {
329            entry.ended_at = Some(chrono::Utc::now());
330            entry.success = true;
331        }
332
333        let _ = self
334            .event_tx
335            .send(FailoverEvent::StandbyPromoted { new_primary })
336            .await;
337
338        let _ = self
339            .event_tx
340            .send(FailoverEvent::FailoverCompleted {
341                duration_ms: duration,
342            })
343            .await;
344
345        tracing::info!(
346            "Failover completed: {:?} -> {:?} in {}ms",
347            old_primary,
348            new_primary,
349            duration
350        );
351
352        // Reset state after a moment
353        tokio::spawn({
354            let state = self.state.clone();
355            async move {
356                tokio::time::sleep(Duration::from_secs(1)).await;
357                *state.write().await = FailoverState::Normal;
358            }
359        });
360
361        Ok(())
362    }
363
364    /// Select the best failover candidate
365    async fn select_best_candidate(&self) -> Result<FailoverCandidate> {
366        let candidates = self.candidates.read().await;
367
368        if candidates.is_empty() {
369            return Err(ProxyError::FailoverFailed(
370                "No failover candidates available".to_string(),
371            ));
372        }
373
374        // Sort by: sync status, lag, priority
375        let mut sorted: Vec<_> = candidates.values().cloned().collect();
376        sorted.sort_by(|a, b| {
377            // Prefer sync standbys
378            if self.config.prefer_sync_standby && a.is_sync != b.is_sync {
379                return b.is_sync.cmp(&a.is_sync);
380            }
381            // Then by lag
382            if a.lag_bytes != b.lag_bytes {
383                return a.lag_bytes.cmp(&b.lag_bytes);
384            }
385            // Then by priority
386            a.priority.cmp(&b.priority)
387        });
388
389        sorted
390            .first()
391            .cloned()
392            .ok_or_else(|| ProxyError::FailoverFailed("No eligible candidates".to_string()))
393    }
394
395    /// Wait for a standby to catch up before promotion.
396    ///
397    /// Polls `pg_last_wal_replay_lsn()` on the candidate at 200 ms
398    /// cadence. Two consecutive polls that return the same LSN are
399    /// treated as "caught up as far as it can go" (the primary is
400    /// presumed dead, so no new WAL is arriving). Bounded by
401    /// `config.failover_timeout`.
402    ///
403    /// When no backend template is attached, returns `Ok(())` after
404    /// a short delay — the pre-T0-TR4 skeleton behaviour for tests.
405    async fn wait_for_sync(&self, standby: NodeId) -> Result<()> {
406        let endpoint = self
407            .candidates
408            .read()
409            .await
410            .get(&standby)
411            .map(|c| c.endpoint.clone());
412        let cfg = match endpoint.as_ref().and_then(|e| self.backend_config_for(e)) {
413            Some(c) => c,
414            None => {
415                // Skeleton path: simulate a brief wait so the state
416                // machine test harness still sees WaitingForSync.
417                tokio::time::sleep(Duration::from_millis(50)).await;
418                return Ok(());
419            }
420        };
421
422        let overall = self.config.failover_timeout;
423        tokio::time::timeout(overall, Self::poll_until_caught_up(cfg))
424            .await
425            .map_err(|_| ProxyError::Timeout("standby sync timeout".to_string()))??;
426        Ok(())
427    }
428
429    /// Connect to the candidate and poll `pg_last_wal_replay_lsn()`
430    /// until it stabilises across two consecutive 200 ms polls.
431    async fn poll_until_caught_up(cfg: BackendConfig) -> Result<()> {
432        let mut client = BackendClient::connect(&cfg)
433            .await
434            .map_err(|e| ProxyError::Failover(format!("connect to candidate: {}", e)))?;
435
436        let mut last: Option<String> = None;
437        let mut stable_polls = 0u32;
438        loop {
439            let value = client
440                .query_scalar("SELECT pg_last_wal_replay_lsn()::text")
441                .await
442                .map_err(|e| ProxyError::Failover(format!("wal lsn probe: {}", e)))?;
443            let lsn = value
444                .into_string()
445                .ok_or_else(|| ProxyError::Failover("null WAL replay LSN".into()))?;
446
447            if last.as_ref() == Some(&lsn) {
448                stable_polls += 1;
449                if stable_polls >= 2 {
450                    tracing::info!(lsn = %lsn, "standby caught up");
451                    client.close().await;
452                    return Ok(());
453                }
454            } else {
455                stable_polls = 0;
456                last = Some(lsn);
457            }
458            tokio::time::sleep(Duration::from_millis(200)).await;
459        }
460    }
461
462    /// Promote a standby to primary via `pg_promote()`.
463    ///
464    /// Uses `pg_promote(wait => true, wait_seconds => N)` so the server
465    /// waits for promotion to complete before returning. Verifies
466    /// post-promotion by re-running `pg_is_in_recovery()` (must now be
467    /// `false`) on a fresh connection.
468    ///
469    /// When no backend template is attached, logs and returns `Ok(())`
470    /// — skeleton / test path.
471    async fn promote_standby(&self, standby: NodeId) -> Result<()> {
472        let endpoint = self
473            .candidates
474            .read()
475            .await
476            .get(&standby)
477            .map(|c| c.endpoint.clone());
478        let cfg = match endpoint.as_ref().and_then(|e| self.backend_config_for(e)) {
479            Some(c) => c,
480            None => {
481                tracing::info!(
482                    node = ?standby,
483                    "promote_standby: skeleton path (no backend template) — no-op"
484                );
485                return Ok(());
486            }
487        };
488
489        let wait_secs = self.config.failover_timeout.as_secs().clamp(10, 300);
490        let mut client = BackendClient::connect(&cfg)
491            .await
492            .map_err(|e| ProxyError::FailoverFailed(format!("connect to promote: {}", e)))?;
493
494        let sql = format!("SELECT pg_promote(true, {})", wait_secs);
495        let value = client
496            .query_scalar(&sql)
497            .await
498            .map_err(|e| ProxyError::FailoverFailed(format!("pg_promote: {}", e)))?;
499        let promoted = value
500            .as_bool("pg_promote")
501            .map_err(|e| ProxyError::FailoverFailed(format!("pg_promote result: {}", e)))?
502            .unwrap_or(false);
503        client.close().await;
504
505        if !promoted {
506            return Err(ProxyError::FailoverFailed(
507                "pg_promote returned false".to_string(),
508            ));
509        }
510
511        // Verify on a fresh connection that the node is no longer in recovery.
512        let mut verify = BackendClient::connect(&cfg)
513            .await
514            .map_err(|e| ProxyError::FailoverFailed(format!("connect to verify: {}", e)))?;
515        let in_recovery = verify
516            .query_scalar("SELECT pg_is_in_recovery()")
517            .await
518            .map_err(|e| ProxyError::FailoverFailed(format!("verify probe: {}", e)))?;
519        verify.close().await;
520        let still_standby = in_recovery
521            .as_bool("pg_is_in_recovery")
522            .map_err(|e| ProxyError::FailoverFailed(format!("verify bool: {}", e)))?
523            .unwrap_or(true);
524        if still_standby {
525            return Err(ProxyError::FailoverFailed(
526                "post-promote pg_is_in_recovery still true".to_string(),
527            ));
528        }
529
530        tracing::info!(node = ?standby, "standby promoted to primary");
531        Ok(())
532    }
533
534    /// Fail the failover
535    async fn fail_failover(&self, reason: &str) {
536        *self.state.write().await = FailoverState::Failed;
537
538        if let Some(entry) = self.history.write().await.last_mut() {
539            entry.ended_at = Some(chrono::Utc::now());
540            entry.success = false;
541            entry.error = Some(reason.to_string());
542        }
543
544        let _ = self
545            .event_tx
546            .send(FailoverEvent::FailoverFailed {
547                reason: reason.to_string(),
548            })
549            .await;
550
551        tracing::error!("Failover failed: {}", reason);
552    }
553
554    /// Handle old primary recovery (split-brain prevention).
555    ///
556    /// PostgreSQL has no built-in "demote the current primary" command —
557    /// re-joining as a standby requires stopping the process and
558    /// re-initialising (`pg_rewind` or `pg_basebackup`). This method
559    /// therefore cannot fully automate demotion. What it CAN do:
560    ///
561    /// 1. Connect to the recovered node and verify whether it still
562    ///    believes it is the primary (`pg_is_in_recovery() = false`).
563    /// 2. Emit `OldPrimaryRecovered` so operators (or an external
564    ///    orchestrator like Patroni / pg_auto_failover) can react.
565    ///
566    /// This is deliberately read-only. Rewriting WAL on a live cluster
567    /// without operator oversight is the canonical way to lose data;
568    /// the proxy refuses to do it.
569    pub async fn on_old_primary_recovered(&self, node_id: NodeId) {
570        let _ = self
571            .event_tx
572            .send(FailoverEvent::OldPrimaryRecovered { node_id })
573            .await;
574        tracing::warn!(
575            "old primary {:?} recovered — must be demoted out-of-band to prevent split-brain",
576            node_id
577        );
578
579        // Best-effort: if we have an endpoint and template, probe to
580        // confirm recovery state and shout extra-loud if it still
581        // thinks it's primary.
582        let endpoint = self
583            .candidates
584            .read()
585            .await
586            .get(&node_id)
587            .map(|c| c.endpoint.clone());
588        let cfg = match endpoint.as_ref().and_then(|e| self.backend_config_for(e)) {
589            Some(c) => c,
590            None => return, // no backend template → nothing more to do
591        };
592
593        match BackendClient::connect(&cfg).await {
594            Ok(mut client) => {
595                let in_recovery_result = client.query_scalar("SELECT pg_is_in_recovery()").await;
596                client.close().await;
597                if let Ok(tv) = in_recovery_result {
598                    if let Ok(Some(false)) = tv.as_bool("pg_is_in_recovery") {
599                        tracing::error!(
600                            "split-brain hazard: node {:?} recovered and still reports primary (pg_is_in_recovery=false). Shut it down or use pg_rewind before reintroducing.",
601                            node_id
602                        );
603                    }
604                }
605            }
606            Err(e) => {
607                tracing::debug!(
608                    error = %e,
609                    "could not connect to recovered node for split-brain probe"
610                );
611            }
612        }
613    }
614
615    /// Manual failover to specific node
616    pub async fn manual_failover(&self, target: NodeId) -> Result<()> {
617        // Verify target is a valid candidate
618        let candidates = self.candidates.read().await;
619        if !candidates.contains_key(&target) {
620            return Err(ProxyError::FailoverFailed(format!(
621                "Node {:?} is not a valid failover candidate",
622                target
623            )));
624        }
625        drop(candidates);
626
627        // Force failover to specific node
628        *self.state.write().await = FailoverState::InProgress;
629
630        let old_primary = self.current_primary.read().await.unwrap_or(NodeId::new());
631
632        let _ = self
633            .event_tx
634            .send(FailoverEvent::FailoverStarted {
635                from: old_primary,
636                to: target,
637            })
638            .await;
639
640        self.promote_standby(target).await?;
641
642        *self.current_primary.write().await = Some(target);
643        *self.state.write().await = FailoverState::Completed;
644        self.failover_count.fetch_add(1, Ordering::SeqCst);
645
646        Ok(())
647    }
648
649    /// Get failover count
650    pub fn failover_count(&self) -> u64 {
651        self.failover_count.load(Ordering::SeqCst)
652    }
653
654    /// Get failover history
655    pub async fn history(&self) -> Vec<FailoverHistoryEntry> {
656        self.history.read().await.clone()
657    }
658
659    /// Take the event receiver
660    pub fn take_event_receiver(&mut self) -> Option<mpsc::Receiver<FailoverEvent>> {
661        self.event_rx.take()
662    }
663
664    /// Coordinate transaction replay after failover (TR integration)
665    ///
666    /// This method orchestrates the replay of in-flight transactions on a new primary
667    /// after a failover event. It ensures transaction atomicity by:
668    /// 1. Getting all active transactions from the journal that were on the failed node
669    /// 2. Waiting for the new primary to catch up to the required LSN
670    /// 3. Replaying each transaction's statements on the new primary
671    /// 4. Verifying results match the original execution (via checksums)
672    #[cfg(feature = "ha-tr")]
673    pub async fn coordinate_failover_replay(
674        &self,
675        journal: &TransactionJournal,
676        failed_node: NodeId,
677        new_primary_endpoint: &NodeEndpoint,
678    ) -> Result<CoordinatedReplayResult> {
679        let start = std::time::Instant::now();
680
681        tracing::info!(
682            "Starting coordinated replay: failed_node={:?}, new_primary={:?}",
683            failed_node,
684            new_primary_endpoint.id
685        );
686
687        // 1. Get all active transactions that were on the failed node
688        let affected_txs = journal.get_transactions_for_node(failed_node).await;
689
690        if affected_txs.is_empty() {
691            tracing::info!("No active transactions to replay");
692            return Ok(CoordinatedReplayResult {
693                total_transactions: 0,
694                successful_replays: 0,
695                failed_replays: 0,
696                transaction_results: vec![],
697                duration_ms: start.elapsed().as_millis() as u64,
698                new_primary: new_primary_endpoint.id,
699            });
700        }
701
702        tracing::info!("Found {} active transactions to replay", affected_txs.len());
703
704        // 2. Get the maximum LSN we need to wait for
705        let max_lsn = affected_txs
706            .iter()
707            .map(|tx| tx.start_lsn)
708            .max()
709            .unwrap_or(0);
710
711        // 3. Wait for the new primary to catch up to this LSN
712        self.wait_for_lsn_catchup(new_primary_endpoint.id, max_lsn)
713            .await?;
714
715        // 4. Create replay manager and replay each transaction
716        let replay_manager = FailoverReplay::new(ReplayConfig {
717            verify_results: true,
718            statement_timeout_ms: 30000,
719            retry_on_error: true,
720            max_retries: 3,
721            skip_read_only: false,
722            wait_for_wal_sync: false, // Already waited above
723            max_wal_lag_bytes: 0,
724        });
725
726        let mut transaction_results = Vec::new();
727        let mut successful_replays = 0;
728        let mut failed_replays = 0;
729
730        for tx_journal in affected_txs {
731            let tx_id = tx_journal.tx_id;
732
733            tracing::debug!(
734                "Replaying transaction {:?} with {} entries",
735                tx_id,
736                tx_journal.entries.len()
737            );
738
739            // Start and execute replay
740            match replay_manager
741                .start_replay(tx_journal, new_primary_endpoint.id)
742                .await
743            {
744                Ok(_) => match replay_manager.execute_replay(tx_id).await {
745                    Ok(result) => {
746                        if result.success {
747                            successful_replays += 1;
748                            tracing::debug!("Transaction {:?} replayed successfully", tx_id);
749                        } else {
750                            failed_replays += 1;
751                            tracing::warn!(
752                                "Transaction {:?} replay failed: {:?}",
753                                tx_id,
754                                result.error
755                            );
756                        }
757                        transaction_results.push(result);
758                    }
759                    Err(e) => {
760                        failed_replays += 1;
761                        tracing::error!("Failed to execute replay for {:?}: {}", tx_id, e);
762                        transaction_results.push(ReplayResult {
763                            tx_id,
764                            success: false,
765                            statements_replayed: 0,
766                            statements_skipped: 0,
767                            statements_failed: 0,
768                            verification_failures: 0,
769                            duration_ms: 0,
770                            error: Some(e.to_string()),
771                            statement_results: vec![],
772                        });
773                    }
774                },
775                Err(e) => {
776                    failed_replays += 1;
777                    tracing::error!("Failed to start replay for {:?}: {}", tx_id, e);
778                    transaction_results.push(ReplayResult {
779                        tx_id,
780                        success: false,
781                        statements_replayed: 0,
782                        statements_skipped: 0,
783                        statements_failed: 0,
784                        verification_failures: 0,
785                        duration_ms: 0,
786                        error: Some(e.to_string()),
787                        statement_results: vec![],
788                    });
789                }
790            }
791        }
792
793        let duration_ms = start.elapsed().as_millis() as u64;
794
795        tracing::info!(
796            "Coordinated replay completed: {}/{} successful in {}ms",
797            successful_replays,
798            successful_replays + failed_replays,
799            duration_ms
800        );
801
802        Ok(CoordinatedReplayResult {
803            total_transactions: successful_replays + failed_replays,
804            successful_replays,
805            failed_replays,
806            transaction_results,
807            duration_ms,
808            new_primary: new_primary_endpoint.id,
809        })
810    }
811
812    /// Wait for a node to catch up to a specific LSN
813    #[cfg(feature = "ha-tr")]
814    async fn wait_for_lsn_catchup(&self, node: NodeId, target_lsn: u64) -> Result<()> {
815        if target_lsn == 0 {
816            return Ok(());
817        }
818
819        tracing::debug!(
820            "Waiting for node {:?} to catch up to LSN {}",
821            node,
822            target_lsn
823        );
824
825        // Use configured timeout
826        let timeout = self.config.failover_timeout;
827        let start = std::time::Instant::now();
828
829        loop {
830            if start.elapsed() >= timeout {
831                return Err(ProxyError::Timeout(format!(
832                    "Timeout waiting for node {:?} to catch up to LSN {}",
833                    node, target_lsn
834                )));
835            }
836
837            // Check if candidate has caught up
838            let candidates = self.candidates.read().await;
839            if let Some(candidate) = candidates.get(&node) {
840                // In a real implementation, we'd query the node's current LSN
841                // For now, we check if lag is acceptable
842                if candidate.lag_bytes == 0 {
843                    tracing::debug!("Node {:?} has caught up", node);
844                    return Ok(());
845                }
846            }
847            drop(candidates);
848
849            tokio::time::sleep(Duration::from_millis(100)).await;
850        }
851    }
852}
853
/// Summary of a coordinated transaction replay carried out after a failover.
///
/// Aggregates the per-transaction outcomes together with overall counters,
/// the elapsed time and the identity of the node that became primary.
#[cfg(feature = "ha-tr")]
#[derive(Clone, Debug)]
pub struct CoordinatedReplayResult {
    /// Number of transactions that were replayed, i.e. successes plus failures.
    pub total_transactions: usize,
    /// How many of those replays succeeded.
    pub successful_replays: usize,
    /// How many of those replays failed.
    pub failed_replays: usize,
    /// Individual replay result for each transaction.
    pub transaction_results: Vec<ReplayResult>,
    /// Elapsed time of the whole replay, in milliseconds.
    pub duration_ms: u64,
    /// ID of the node that is the new primary.
    pub new_primary: NodeId,
}
871
#[cfg(feature = "ha-tr")]
impl CoordinatedReplayResult {
    /// Returns `true` when no replay failed.
    pub fn all_successful(&self) -> bool {
        self.failed_replays == 0
    }

    /// Returns the share of successful replays as a percentage (`0.0..=100.0`
    /// for consistent counters).
    ///
    /// An empty replay (no transactions at all) counts as fully successful and
    /// yields `100.0`.
    pub fn success_rate(&self) -> f64 {
        match self.total_transactions {
            // Nothing to replay: avoid dividing by zero and report full success.
            0 => 100.0,
            total => {
                let ratio = self.successful_replays as f64 / total as f64;
                ratio * 100.0
            }
        }
    }
}
888
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration enables automatic failover, prefers
    /// synchronous standbys and allows three retries.
    #[test]
    fn test_config_default() {
        let config = FailoverConfig::default();
        assert!(config.auto_failover);
        assert!(config.prefer_sync_standby);
        assert_eq!(config.max_retries, 3);
    }

    /// A primary set on the controller is returned by `get_primary`.
    #[tokio::test]
    async fn test_set_get_primary() {
        let controller = FailoverController::new(FailoverConfig::default());
        let node_id = NodeId::new();

        controller.set_primary(node_id).await;
        assert_eq!(controller.get_primary().await, Some(node_id));
    }

    /// A registered candidate ends up in the controller's candidate map,
    /// keyed by its node ID.
    #[tokio::test]
    async fn test_register_candidate() {
        let controller = FailoverController::new(FailoverConfig::default());
        let node_id = NodeId::new();

        let candidate = FailoverCandidate {
            node_id,
            endpoint: NodeEndpoint::new("localhost", 5432).with_role(NodeRole::Standby),
            is_sync: true,
            lag_bytes: 0,
            priority: 1,
            last_heartbeat: None,
        };

        controller.register_candidate(candidate).await;

        let candidates = controller.candidates.read().await;
        assert!(candidates.contains_key(&node_id));
    }

    /// A fresh controller starts in `Normal`; a state written directly is
    /// reflected by the `state()` accessor.
    #[tokio::test]
    async fn test_state_transitions() {
        let controller = FailoverController::new(FailoverConfig::default());

        assert_eq!(controller.state().await, FailoverState::Normal);

        *controller.state.write().await = FailoverState::PrimaryFailed;
        assert_eq!(controller.state().await, FailoverState::PrimaryFailed);
    }

    /// With the default configuration (`prefer_sync_standby = true`) the
    /// synchronous standby is selected over the asynchronous one.
    #[tokio::test]
    async fn test_select_best_candidate() {
        let controller = FailoverController::new(FailoverConfig::default());

        let sync_node = NodeId::new();
        let async_node = NodeId::new();

        controller
            .register_candidate(FailoverCandidate {
                node_id: async_node,
                endpoint: NodeEndpoint::new("async", 5432),
                is_sync: false,
                lag_bytes: 100,
                priority: 1,
                last_heartbeat: None,
            })
            .await;

        controller
            .register_candidate(FailoverCandidate {
                node_id: sync_node,
                endpoint: NodeEndpoint::new("sync", 5432),
                is_sync: true,
                lag_bytes: 50,
                priority: 2,
                last_heartbeat: None,
            })
            .await;

        let best = controller.select_best_candidate().await.unwrap();
        // Sync standby should be preferred
        assert_eq!(best.node_id, sync_node);
    }

    /// Replaying against an empty journal succeeds and reports zero
    /// transactions with a 100% success rate.
    #[cfg(feature = "ha-tr")]
    #[tokio::test]
    async fn test_coordinate_failover_replay_empty() {
        use super::super::transaction_journal::TransactionJournal;

        let controller = FailoverController::new(FailoverConfig::default());
        let journal = TransactionJournal::new();
        let failed_node = NodeId::new();
        let new_primary = NodeEndpoint::new("new-primary", 5432).with_role(NodeRole::Primary);

        // With no transactions, should succeed immediately
        let result = controller
            .coordinate_failover_replay(&journal, failed_node, &new_primary)
            .await
            .unwrap();

        assert_eq!(result.total_transactions, 0);
        assert_eq!(result.successful_replays, 0);
        assert_eq!(result.failed_replays, 0);
        assert!(result.all_successful());
        assert_eq!(result.success_rate(), 100.0);
    }

    /// A single journaled transaction on the failed node is replayed
    /// successfully on the new primary.
    #[cfg(feature = "ha-tr")]
    #[tokio::test]
    async fn test_coordinate_failover_replay_with_transactions() {
        use super::super::transaction_journal::{JournalValue, TransactionJournal};
        use uuid::Uuid;

        let controller = FailoverController::new(FailoverConfig::default());
        let journal = TransactionJournal::new();
        let failed_node = NodeId::new();
        let new_primary = NodeEndpoint::new("new-primary", 5432).with_role(NodeRole::Primary);

        // Register the new primary as a candidate with zero lag so the
        // LSN catch-up wait can complete without timing out.
        controller
            .register_candidate(FailoverCandidate {
                node_id: new_primary.id,
                endpoint: new_primary.clone(),
                is_sync: true,
                lag_bytes: 0,
                priority: 1,
                last_heartbeat: None,
            })
            .await;

        // Create a transaction on the failed node
        let tx_id = Uuid::new_v4();
        let session_id = Uuid::new_v4();
        journal
            .begin_transaction(tx_id, session_id, failed_node, 100)
            .await
            .unwrap();
        journal
            .log_statement(
                tx_id,
                "INSERT INTO users (name) VALUES ('test')".to_string(),
                vec![JournalValue::Text("test".to_string())],
                Some(12345),
                Some(1),
                10,
            )
            .await
            .unwrap();

        // Coordinate replay
        let result = controller
            .coordinate_failover_replay(&journal, failed_node, &new_primary)
            .await
            .unwrap();

        assert_eq!(result.total_transactions, 1);
        assert_eq!(result.successful_replays, 1);
        assert_eq!(result.failed_replays, 0);
        assert!(result.all_successful());
    }

    /// `all_successful` and `success_rate` reflect the counters of a partially
    /// failed result and of a fully successful one.
    #[cfg(feature = "ha-tr")]
    #[test]
    fn test_coordinated_replay_result_methods() {
        let result = CoordinatedReplayResult {
            total_transactions: 10,
            successful_replays: 8,
            failed_replays: 2,
            transaction_results: vec![],
            duration_ms: 1000,
            new_primary: NodeId::new(),
        };

        assert!(!result.all_successful());
        assert_eq!(result.success_rate(), 80.0);

        let perfect = CoordinatedReplayResult {
            total_transactions: 5,
            successful_replays: 5,
            failed_replays: 0,
            transaction_results: vec![],
            duration_ms: 500,
            new_primary: NodeId::new(),
        };

        assert!(perfect.all_successful());
        assert_eq!(perfect.success_rate(), 100.0);
    }
}