1use super::transaction_journal::{
7 JournalEntry, JournalValue, StatementType, TransactionJournalEntry,
8};
9use super::{NodeEndpoint, NodeId, ProxyError, Result};
10use crate::backend::{BackendClient, BackendConfig, ParamValue};
11use std::collections::HashMap;
12use std::sync::Arc;
13use tokio::sync::RwLock;
14use uuid::Uuid;
15
/// Settings controlling how `FailoverReplay` re-executes a journaled transaction.
#[derive(Debug, Clone)]
pub struct ReplayConfig {
    /// Compare replayed statement outcomes with what the journal recorded.
    pub verify_results: bool,
    /// Timeout, in milliseconds, for bounded waits during replay (e.g. the WAL
    /// catch-up wait).
    pub statement_timeout_ms: u64,
    /// Re-run a statement that failed before reporting it as failed.
    pub retry_on_error: bool,
    /// Maximum number of re-runs per statement when `retry_on_error` is set.
    pub max_retries: u32,
    /// Skip statements whose type reports itself as read-only instead of
    /// executing them.
    pub skip_read_only: bool,
    /// Wait for the target node's WAL replay position to reach the journal's
    /// start LSN before replaying.
    pub wait_for_wal_sync: bool,
    /// Tolerated WAL lag in bytes.
    /// NOTE(review): confirm how callers interpret this value; 0 is the default.
    pub max_wal_lag_bytes: u64,
}

impl Default for ReplayConfig {
    /// Verification, retries (3) and WAL sync enabled; read-only statements
    /// are replayed; 30 s timeout; zero tolerated WAL lag.
    fn default() -> Self {
        let statement_timeout_ms = 30_000;
        let max_retries = 3;
        Self {
            max_wal_lag_bytes: 0,
            wait_for_wal_sync: true,
            skip_read_only: false,
            retry_on_error: true,
            verify_results: true,
            statement_timeout_ms,
            max_retries,
        }
    }
}
48
/// Summary of one `execute_replay` run for a single transaction.
#[derive(Debug, Clone)]
pub struct ReplayResult {
    /// Transaction whose journal was replayed.
    pub tx_id: Uuid,
    /// `true` when no statement failed and no verification failure was counted.
    pub success: bool,
    /// Number of statements that executed successfully.
    pub statements_replayed: usize,
    /// Number of journal entries not executed (read-only statements skipped by
    /// configuration, and transaction-control statements).
    pub statements_skipped: usize,
    /// Number of statements that failed (after any retries).
    pub statements_failed: usize,
    /// Number of successfully executed statements whose outcome did not match
    /// the journal.
    pub verification_failures: usize,
    /// Elapsed time of the replay, in milliseconds.
    pub duration_ms: u64,
    /// Summary error message; `None` when `success` is `true`.
    pub error: Option<String>,
    /// Per-statement outcomes recorded during the replay, in replay order.
    pub statement_results: Vec<StatementReplayResult>,
}
71
/// Outcome of replaying (or skipping) a single journal entry.
#[derive(Debug, Clone)]
pub struct StatementReplayResult {
    /// Journal sequence number of the entry.
    pub sequence: u64,
    /// Whether the statement executed successfully (skipped entries report `true`).
    pub success: bool,
    /// Result-checksum comparison; `None` when no comparison was made.
    pub checksum_matched: Option<bool>,
    /// Affected-row-count comparison; `None` when no comparison was made.
    pub rows_matched: Option<bool>,
    /// Time spent on the statement, in milliseconds, as measured by the replayer.
    pub duration_ms: u64,
    /// Error message when `success` is `false`.
    pub error: Option<String>,
    /// Number of re-runs performed after the first attempt.
    pub retries: u32,
}
90
/// Lifecycle phase of an active replay.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReplayState {
    /// Registered by `start_replay`; execution has not started.
    Pending,
    /// Waiting for the target node's WAL position to catch up.
    WaitingForWal,
    /// Journal statements are being re-executed.
    Replaying,
    /// Result verification phase.
    /// NOTE(review): not entered by the replay logic in this file.
    Verifying,
    /// Replay finished without failed statements.
    Completed,
    /// Replay finished (or stopped) with a failure.
    Failed,
}
107
/// Bookkeeping for a replay registered by `start_replay` and not yet removed.
#[derive(Debug)]
struct ActiveReplay {
    /// Transaction being replayed (also the key in the active-replay map).
    #[allow(dead_code)]
    tx_id: Uuid,
    /// Node the journal is replayed on.
    target_node: NodeId,
    /// Journal whose entries are replayed.
    journal: TransactionJournalEntry,
    /// Current lifecycle phase.
    state: ReplayState,
    /// Progress counter reported by `get_progress` alongside the entry count.
    position: usize,
    /// Wall-clock time at which `start_replay` registered this entry.
    #[allow(dead_code)]
    started_at: chrono::DateTime<chrono::Utc>,
    /// Per-statement outcomes collected so far.
    results: Vec<StatementReplayResult>,
}
128
/// Replays journaled transactions onto a target node and keeps a bounded
/// history of the results.
pub struct FailoverReplay {
    /// Replay behaviour settings.
    config: ReplayConfig,
    /// Replays registered via `start_replay` that have not yet been removed,
    /// keyed by transaction id.
    active_replays: Arc<RwLock<HashMap<Uuid, ActiveReplay>>>,
    /// Results of finished replays, oldest first, trimmed to `max_history`.
    completed_replays: Arc<RwLock<Vec<ReplayResult>>>,
    /// Maximum number of results kept in `completed_replays`.
    max_history: usize,
    /// Connection settings cloned per target node, with host/port taken from
    /// the node's endpoint. When `None`, statements are not sent to any
    /// backend and are reported as successful.
    backend_template: Option<BackendConfig>,
    /// Connection endpoints per node, registered via `register_endpoint`.
    endpoints: Arc<RwLock<HashMap<NodeId, NodeEndpoint>>>,
}
149
150impl FailoverReplay {
151 pub fn new(config: ReplayConfig) -> Self {
153 Self {
154 config,
155 active_replays: Arc::new(RwLock::new(HashMap::new())),
156 completed_replays: Arc::new(RwLock::new(Vec::new())),
157 max_history: 100,
158 backend_template: None,
159 endpoints: Arc::new(RwLock::new(HashMap::new())),
160 }
161 }
162
163 pub fn with_backend_template(mut self, template: BackendConfig) -> Self {
166 self.backend_template = Some(template);
167 self
168 }
169
170 pub async fn register_endpoint(&self, node_id: NodeId, endpoint: NodeEndpoint) {
173 self.endpoints.write().await.insert(node_id, endpoint);
174 }
175
176 fn build_config(&self, endpoint: &NodeEndpoint) -> Option<BackendConfig> {
177 self.backend_template.as_ref().map(|t| {
178 let mut c = t.clone();
179 c.host = endpoint.host.clone();
180 c.port = endpoint.port;
181 c
182 })
183 }
184
185 pub async fn start_replay(
187 &self,
188 journal: TransactionJournalEntry,
189 target_node: NodeId,
190 ) -> Result<Uuid> {
191 let tx_id = journal.tx_id;
192
193 let replay = ActiveReplay {
194 tx_id,
195 target_node,
196 journal,
197 state: ReplayState::Pending,
198 position: 0,
199 started_at: chrono::Utc::now(),
200 results: Vec::new(),
201 };
202
203 self.active_replays.write().await.insert(tx_id, replay);
204
205 tracing::info!(
206 "Starting replay for transaction {:?} on node {:?}",
207 tx_id,
208 target_node
209 );
210
211 Ok(tx_id)
212 }
213
214 pub async fn execute_replay(&self, tx_id: Uuid) -> Result<ReplayResult> {
216 let start = std::time::Instant::now();
217
218 let mut replays = self.active_replays.write().await;
220 let replay = replays.get_mut(&tx_id).ok_or_else(|| {
221 ProxyError::ReplayFailed(format!("No active replay for transaction {:?}", tx_id))
222 })?;
223
224 if self.config.wait_for_wal_sync {
226 replay.state = ReplayState::WaitingForWal;
227 self.wait_for_wal_sync(replay.target_node, replay.journal.start_lsn)
228 .await?;
229 }
230
231 replay.state = ReplayState::Replaying;
232
233 let entries = replay.journal.entries.clone();
234 let mut statements_replayed = 0;
235 let mut statements_skipped = 0;
236 let mut statements_failed = 0;
237 let mut verification_failures = 0;
238
239 for entry in &entries {
241 if self.config.skip_read_only && entry.statement_type.is_read_only() {
243 statements_skipped += 1;
244 replay.results.push(StatementReplayResult {
245 sequence: entry.sequence,
246 success: true,
247 checksum_matched: None,
248 rows_matched: None,
249 duration_ms: 0,
250 error: None,
251 retries: 0,
252 });
253 continue;
254 }
255
256 if entry.statement_type == StatementType::Transaction {
258 statements_skipped += 1;
259 continue;
260 }
261
262 let result = self.replay_statement(entry, replay.target_node).await;
263
264 match result {
265 Ok(stmt_result) => {
266 if stmt_result.success {
267 statements_replayed += 1;
268
269 if self.config.verify_results {
271 if let Some(false) = stmt_result.checksum_matched {
272 verification_failures += 1;
273 }
274 }
275 } else {
276 statements_failed += 1;
277 }
278 replay.results.push(stmt_result);
279 }
280 Err(e) => {
281 statements_failed += 1;
282 replay.results.push(StatementReplayResult {
283 sequence: entry.sequence,
284 success: false,
285 checksum_matched: None,
286 rows_matched: None,
287 duration_ms: 0,
288 error: Some(e.to_string()),
289 retries: 0,
290 });
291 }
292 }
293
294 replay.position += 1;
295 }
296
297 replay.state = if statements_failed > 0 {
298 ReplayState::Failed
299 } else {
300 ReplayState::Completed
301 };
302
303 let duration_ms = start.elapsed().as_millis() as u64;
304
305 let result = ReplayResult {
306 tx_id,
307 success: statements_failed == 0 && verification_failures == 0,
308 statements_replayed,
309 statements_skipped,
310 statements_failed,
311 verification_failures,
312 duration_ms,
313 error: if statements_failed > 0 {
314 Some("Some statements failed during replay".to_string())
315 } else if verification_failures > 0 {
316 Some("Result verification failed".to_string())
317 } else {
318 None
319 },
320 statement_results: replay.results.clone(),
321 };
322
323 drop(replays);
325 self.active_replays.write().await.remove(&tx_id);
326 self.add_to_history(result.clone()).await;
327
328 tracing::info!(
329 "Replay completed for {:?}: {} replayed, {} failed, {}ms",
330 tx_id,
331 statements_replayed,
332 statements_failed,
333 duration_ms
334 );
335
336 Ok(result)
337 }
338
339 async fn replay_statement(
341 &self,
342 entry: &JournalEntry,
343 target_node: NodeId,
344 ) -> Result<StatementReplayResult> {
345 let start = std::time::Instant::now();
346 let mut retries = 0;
347
348 loop {
349 let (success, checksum_matched, rows_matched, error_msg) =
350 self.execute_statement(entry, target_node).await;
351
352 if success || !self.config.retry_on_error || retries >= self.config.max_retries {
353 return Ok(StatementReplayResult {
354 sequence: entry.sequence,
355 success,
356 checksum_matched: if self.config.verify_results
357 && entry.result_checksum.is_some()
358 {
359 Some(checksum_matched)
360 } else {
361 None
362 },
363 rows_matched: if entry.rows_affected.is_some() {
364 Some(rows_matched)
365 } else {
366 None
367 },
368 duration_ms: start.elapsed().as_millis() as u64,
369 error: if success {
370 None
371 } else {
372 Some(error_msg.unwrap_or_else(|| "statement execution failed".to_string()))
373 },
374 retries,
375 });
376 }
377
378 retries += 1;
379 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
380 }
381 }
382
383 async fn execute_statement(
390 &self,
391 entry: &JournalEntry,
392 target_node: NodeId,
393 ) -> (bool, bool, bool, Option<String>) {
394 let endpoint = self.endpoints.read().await.get(&target_node).cloned();
395 let cfg = match endpoint.as_ref().and_then(|e| self.build_config(e)) {
396 Some(c) => c,
397 None => return (true, true, true, None),
398 };
399
400 let mut client = match BackendClient::connect(&cfg).await {
401 Ok(c) => c,
402 Err(e) => return (false, false, false, Some(format!("connect: {}", e))),
403 };
404
405 let params: Vec<ParamValue> = entry
406 .parameters
407 .iter()
408 .map(journal_value_to_param)
409 .collect();
410
411 let result = if params.is_empty() {
412 client.simple_query(&entry.statement).await
413 } else {
414 client.query_with_params(&entry.statement, ¶ms).await
415 };
416
417 let outcome = match result {
418 Ok(qr) => {
419 let rows_matched = match entry.rows_affected {
420 Some(expected) => qr.rows_affected() == Some(expected),
421 None => true,
422 };
423 let checksum_matched = entry.result_checksum.is_none();
428 (true, checksum_matched, rows_matched, None)
429 }
430 Err(e) => (false, false, false, Some(e.to_string())),
431 };
432 client.close().await;
433 outcome
434 }
435
436 async fn wait_for_wal_sync(&self, node: NodeId, start_lsn: u64) -> Result<()> {
443 let endpoint = self.endpoints.read().await.get(&node).cloned();
444 let cfg = match endpoint.as_ref().and_then(|e| self.build_config(e)) {
445 Some(c) => c,
446 None => {
447 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
449 return Ok(());
450 }
451 };
452
453 let timeout = std::time::Duration::from_millis(self.config.statement_timeout_ms);
454 tokio::time::timeout(timeout, Self::poll_wal_lsn(cfg, start_lsn))
455 .await
456 .map_err(|_| ProxyError::Timeout("WAL sync wait timeout".into()))??;
457 Ok(())
458 }
459
460 async fn poll_wal_lsn(cfg: BackendConfig, target: u64) -> Result<()> {
461 let mut client = BackendClient::connect(&cfg)
462 .await
463 .map_err(|e| ProxyError::ReplayFailed(format!("connect: {}", e)))?;
464 loop {
465 let value = client
466 .query_scalar("SELECT pg_last_wal_replay_lsn()::text")
467 .await
468 .map_err(|e| ProxyError::ReplayFailed(format!("lsn probe: {}", e)))?;
469 if let Some(s) = value.into_string() {
470 if let Some(current) = pg_lsn_to_u64(&s) {
471 if current >= target {
472 client.close().await;
473 return Ok(());
474 }
475 }
476 }
477 tokio::time::sleep(std::time::Duration::from_millis(200)).await;
478 }
479 }
480
481 async fn add_to_history(&self, result: ReplayResult) {
483 let mut history = self.completed_replays.write().await;
484 history.push(result);
485
486 if history.len() > self.max_history {
488 history.remove(0);
489 }
490 }
491
492 pub async fn get_state(&self, tx_id: &Uuid) -> Option<ReplayState> {
494 self.active_replays.read().await.get(tx_id).map(|r| r.state)
495 }
496
497 pub async fn get_progress(&self, tx_id: &Uuid) -> Option<(usize, usize)> {
499 self.active_replays
500 .read()
501 .await
502 .get(tx_id)
503 .map(|r| (r.position, r.journal.entries.len()))
504 }
505
506 pub async fn cancel_replay(&self, tx_id: &Uuid) -> Result<()> {
508 self.active_replays.write().await.remove(tx_id);
509 tracing::info!("Cancelled replay for transaction {:?}", tx_id);
510 Ok(())
511 }
512
513 pub async fn history(&self) -> Vec<ReplayResult> {
515 self.completed_replays.read().await.clone()
516 }
517
518 pub async fn stats(&self) -> ReplayStats {
520 let history = self.completed_replays.read().await;
521 let successful = history.iter().filter(|r| r.success).count();
522 let total_statements: usize = history.iter().map(|r| r.statements_replayed).sum();
523
524 ReplayStats {
525 active_replays: self.active_replays.read().await.len(),
526 completed_replays: history.len(),
527 successful_replays: successful,
528 total_statements_replayed: total_statements,
529 }
530 }
531}
532
533fn journal_value_to_param(v: &JournalValue) -> ParamValue {
536 match v {
537 JournalValue::Null => ParamValue::Null,
538 JournalValue::Bool(b) => ParamValue::Bool(*b),
539 JournalValue::Int64(i) => ParamValue::Int(*i),
540 JournalValue::Float64(f) => ParamValue::Float(*f),
541 JournalValue::Text(s) => ParamValue::Text(s.clone()),
542 JournalValue::Bytes(b) => {
543 let mut s = String::with_capacity(2 + b.len() * 2);
545 s.push_str("\\x");
546 for byte in b {
547 s.push_str(&format!("{:02x}", byte));
548 }
549 ParamValue::Text(s)
550 }
551 JournalValue::Array(_) => {
552 ParamValue::Null
556 }
557 }
558}
559
/// Parses a PostgreSQL `pg_lsn` text value (`HI/LO`, two hexadecimal halves)
/// into a single `u64` that orders the same way LSNs do.
///
/// Returns `None` unless the input is two hex numbers separated by `/`, each
/// of which fits in 32 bits.
fn pg_lsn_to_u64(s: &str) -> Option<u64> {
    let (hi, lo) = s.split_once('/')?;
    // Parsing each half as `u32` rejects halves wider than 32 bits. An
    // oversized high half would otherwise be silently truncated by the shift.
    let parse_half = |half: &str| u32::from_str_radix(half.trim(), 16).ok();
    let hi = parse_half(hi)?;
    let lo = parse_half(lo)?;
    Some((u64::from(hi) << 32) | u64::from(lo))
}
572
/// Aggregate counters returned by `FailoverReplay::stats`.
#[derive(Debug, Clone)]
pub struct ReplayStats {
    /// Number of replays currently registered as active.
    pub active_replays: usize,
    /// Number of results in the (bounded) completed-replay history.
    pub completed_replays: usize,
    /// Number of retained results whose `success` flag is set.
    pub successful_replays: usize,
    /// Sum of `statements_replayed` over the retained history.
    pub total_statements_replayed: usize,
}
585
#[cfg(test)]
mod tests {
    use super::super::transaction_journal::TransactionJournalEntry;
    use super::*;

    /// Replayer with the default configuration and no backend template, so
    /// statements are not sent to any backend.
    fn default_replay() -> FailoverReplay {
        FailoverReplay::new(ReplayConfig::default())
    }

    /// Journal with an INSERT (1 row affected) followed by a SELECT; both
    /// entries carry a result checksum.
    fn make_journal() -> TransactionJournalEntry {
        let tx_id = Uuid::new_v4();
        let session_id = Uuid::new_v4();
        let node_id = NodeId::new();

        let mut journal = TransactionJournalEntry::new(tx_id, session_id, node_id, 0);

        journal.add_entry(JournalEntry {
            sequence: 1,
            statement: "INSERT INTO users (name) VALUES ('test')".to_string(),
            parameters: vec![],
            result_checksum: Some(12345),
            rows_affected: Some(1),
            timestamp: chrono::Utc::now(),
            statement_type: StatementType::Insert,
            duration_ms: 10,
        });

        journal.add_entry(JournalEntry {
            sequence: 2,
            statement: "SELECT * FROM users".to_string(),
            parameters: vec![],
            result_checksum: Some(67890),
            rows_affected: None,
            timestamp: chrono::Utc::now(),
            statement_type: StatementType::Select,
            duration_ms: 5,
        });

        journal
    }

    #[test]
    fn test_config_default() {
        let config = ReplayConfig::default();
        assert!(config.verify_results);
        assert!(config.retry_on_error);
        assert!(config.wait_for_wal_sync);
    }

    #[test]
    fn test_pg_lsn_to_u64_roundtrip() {
        assert_eq!(pg_lsn_to_u64("0/0"), Some(0));
        assert_eq!(pg_lsn_to_u64("0/1"), Some(1));
        assert_eq!(pg_lsn_to_u64("0/FFFFFFFF"), Some(0xFFFFFFFF));
        assert_eq!(pg_lsn_to_u64("1/0"), Some(1u64 << 32));
        assert_eq!(
            pg_lsn_to_u64("16/B3780A90"),
            Some((0x16u64 << 32) | 0xB3780A90u64)
        );
        // Ordering of the packed value must follow LSN ordering.
        assert!(pg_lsn_to_u64("0/A").unwrap() < pg_lsn_to_u64("0/B").unwrap());
        assert!(pg_lsn_to_u64("0/FFFFFFFF").unwrap() < pg_lsn_to_u64("1/0").unwrap());
    }

    #[test]
    fn test_pg_lsn_to_u64_rejects_malformed() {
        assert!(pg_lsn_to_u64("no-slash").is_none());
        assert!(pg_lsn_to_u64("/lo-only").is_none());
        assert!(pg_lsn_to_u64("hi-only/").is_none());
        assert!(pg_lsn_to_u64("zz/zz").is_none());
        assert!(pg_lsn_to_u64("0/100000000").is_none());
    }

    #[test]
    fn test_journal_value_to_param_basic_types() {
        use crate::backend::ParamValue;

        assert!(matches!(
            journal_value_to_param(&JournalValue::Null),
            ParamValue::Null
        ));
        assert!(matches!(
            journal_value_to_param(&JournalValue::Bool(true)),
            ParamValue::Bool(true)
        ));
        assert!(matches!(
            journal_value_to_param(&JournalValue::Int64(42)),
            ParamValue::Int(42)
        ));
        // 2.5 rather than 3.14: clippy's deny-by-default `approx_constant`
        // rejects literals that approximate PI.
        match journal_value_to_param(&JournalValue::Float64(2.5)) {
            ParamValue::Float(f) => assert!((f - 2.5).abs() < 1e-9),
            other => panic!("expected Float, got {:?}", other),
        }
        match journal_value_to_param(&JournalValue::Text("hi".into())) {
            ParamValue::Text(s) => assert_eq!(s, "hi"),
            other => panic!("expected Text, got {:?}", other),
        }
    }

    #[test]
    fn test_journal_value_bytes_to_hex_escape() {
        use crate::backend::ParamValue;
        let v = journal_value_to_param(&JournalValue::Bytes(vec![0xDE, 0xAD, 0xBE, 0xEF]));
        match v {
            ParamValue::Text(s) => assert_eq!(s, "\\xdeadbeef"),
            other => panic!("expected Text, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_start_replay() {
        let replay = default_replay();
        let journal = make_journal();
        let tx_id = journal.tx_id;

        let result_tx_id = replay.start_replay(journal, NodeId::new()).await.unwrap();
        assert_eq!(result_tx_id, tx_id);

        assert_eq!(replay.get_state(&tx_id).await, Some(ReplayState::Pending));
    }

    #[tokio::test]
    async fn test_execute_replay() {
        let replay = default_replay();
        let journal = make_journal();
        let tx_id = journal.tx_id;

        replay.start_replay(journal, NodeId::new()).await.unwrap();
        let result = replay.execute_replay(tx_id).await.unwrap();

        assert!(result.success);
        assert_eq!(result.statements_replayed, 2);
        assert_eq!(result.statements_failed, 0);
    }

    #[tokio::test]
    async fn test_cancel_replay() {
        let replay = default_replay();
        let journal = make_journal();
        let tx_id = journal.tx_id;

        replay.start_replay(journal, NodeId::new()).await.unwrap();
        replay.cancel_replay(&tx_id).await.unwrap();

        assert!(replay.get_state(&tx_id).await.is_none());
    }

    #[tokio::test]
    async fn test_stats() {
        let stats = default_replay().stats().await;
        assert_eq!(stats.active_replays, 0);
        assert_eq!(stats.completed_replays, 0);
    }
}