1use super::transaction_journal::{JournalEntry, JournalValue, StatementType, TransactionJournalEntry};
7use super::{NodeEndpoint, NodeId, ProxyError, Result};
8use crate::backend::{BackendClient, BackendConfig, ParamValue};
9use std::collections::HashMap;
10use std::sync::Arc;
11use tokio::sync::RwLock;
12use uuid::Uuid;
13
/// Tunables controlling how a journaled transaction is replayed on a
/// failover target node.
#[derive(Debug, Clone)]
pub struct ReplayConfig {
    /// When set, a replayed statement whose checksum comparison came back
    /// `Some(false)` is counted as a verification failure, and per-statement
    /// results report a checksum verdict for entries that carry a checksum.
    pub verify_results: bool,
    /// Timeout in milliseconds. In the code visible in this file it bounds
    /// the WAL-sync wait performed before replay starts.
    pub statement_timeout_ms: u64,
    /// Whether a statement that failed is executed again.
    pub retry_on_error: bool,
    /// Upper bound on the number of re-executions of one failed statement.
    pub max_retries: u32,
    /// When set, statements whose type reports itself as read-only are
    /// skipped instead of being replayed.
    pub skip_read_only: bool,
    /// When set, replay first waits until the target's WAL replay position
    /// has reached the journal's start LSN.
    pub wait_for_wal_sync: bool,
    /// Maximum tolerated WAL lag in bytes.
    /// NOTE(review): not read anywhere in the code visible here — confirm usage.
    pub max_wal_lag_bytes: u64,
}
32
33impl Default for ReplayConfig {
34 fn default() -> Self {
35 Self {
36 verify_results: true,
37 statement_timeout_ms: 30000,
38 retry_on_error: true,
39 max_retries: 3,
40 skip_read_only: false,
41 wait_for_wal_sync: true,
42 max_wal_lag_bytes: 0, }
44 }
45}
46
/// Summary of one finished transaction replay, as returned by
/// `FailoverReplay::execute_replay` and kept in the replay history.
#[derive(Debug, Clone)]
pub struct ReplayResult {
    /// Identifier of the replayed transaction.
    pub tx_id: Uuid,
    /// `true` only when no statement failed and no verification failure was
    /// recorded.
    pub success: bool,
    /// Number of statements that were executed successfully.
    pub statements_replayed: usize,
    /// Number of journal entries that were skipped rather than executed.
    pub statements_skipped: usize,
    /// Number of statements whose execution failed.
    pub statements_failed: usize,
    /// Number of successfully executed statements whose checksum comparison
    /// reported a mismatch.
    pub verification_failures: usize,
    /// Wall-clock duration of the replay in milliseconds.
    pub duration_ms: u64,
    /// Human-readable summary of what went wrong; `None` on success.
    pub error: Option<String>,
    /// Per-statement outcomes collected during the replay.
    pub statement_results: Vec<StatementReplayResult>,
}
69
/// Outcome of replaying (or skipping) a single journal entry.
#[derive(Debug, Clone)]
pub struct StatementReplayResult {
    /// Sequence number copied from the journal entry.
    pub sequence: u64,
    /// Whether the statement is considered successful.
    pub success: bool,
    /// Checksum verdict; `None` when no checksum comparison applies to this
    /// entry (verification disabled, no recorded checksum, or entry skipped).
    pub checksum_matched: Option<bool>,
    /// Row-count verdict; `None` when the journal entry recorded no
    /// affected-row count or the entry was not executed.
    pub rows_matched: Option<bool>,
    /// Time spent on this statement in milliseconds.
    pub duration_ms: u64,
    /// Failure description; `None` when the statement succeeded.
    pub error: Option<String>,
    /// Number of retries performed for this statement.
    pub retries: u32,
}
88
/// Lifecycle phase of an active replay, as reported by
/// `FailoverReplay::get_state`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReplayState {
    /// Registered via `start_replay` but not yet executing.
    Pending,
    /// Waiting for the target node's WAL replay position to catch up.
    WaitingForWal,
    /// Journal entries are being executed on the target node.
    Replaying,
    /// NOTE(review): never assigned in the code visible in this file — confirm
    /// whether a separate verification phase exists elsewhere.
    Verifying,
    /// All statements were executed without a statement failure.
    Completed,
    /// At least one statement failed.
    Failed,
}
105
/// Bookkeeping for a replay that has been registered and not yet finished
/// or cancelled.
#[derive(Debug)]
struct ActiveReplay {
    /// Identifier of the transaction being replayed (also the map key).
    tx_id: Uuid,
    /// Node the statements are executed against.
    target_node: NodeId,
    /// The journal whose entries are replayed.
    journal: TransactionJournalEntry,
    /// Current lifecycle phase.
    state: ReplayState,
    /// Progress counter advanced by `execute_replay` as entries are
    /// processed; reported by `get_progress`.
    position: usize,
    /// Timestamp taken when the replay was registered.
    started_at: chrono::DateTime<chrono::Utc>,
    /// Per-statement outcomes accumulated so far.
    results: Vec<StatementReplayResult>,
}
124
/// Coordinates replaying journaled transactions on a failover target and
/// keeps a bounded history of finished replays.
pub struct FailoverReplay {
    /// Replay behaviour switches and limits.
    config: ReplayConfig,
    /// Replays that are registered or currently executing, keyed by
    /// transaction id.
    active_replays: Arc<RwLock<HashMap<Uuid, ActiveReplay>>>,
    /// Results of finished replays, oldest first.
    completed_replays: Arc<RwLock<Vec<ReplayResult>>>,
    /// Maximum number of results retained in `completed_replays`.
    max_history: usize,
    /// Template from which per-node backend configurations are derived by
    /// overriding host and port. While this is `None`, statement execution
    /// and the WAL-sync wait do not contact any backend.
    backend_template: Option<BackendConfig>,
    /// Known network endpoints of nodes, keyed by node id.
    endpoints: Arc<RwLock<HashMap<NodeId, NodeEndpoint>>>,
}
145
146impl FailoverReplay {
147 pub fn new(config: ReplayConfig) -> Self {
149 Self {
150 config,
151 active_replays: Arc::new(RwLock::new(HashMap::new())),
152 completed_replays: Arc::new(RwLock::new(Vec::new())),
153 max_history: 100,
154 backend_template: None,
155 endpoints: Arc::new(RwLock::new(HashMap::new())),
156 }
157 }
158
159 pub fn with_backend_template(mut self, template: BackendConfig) -> Self {
162 self.backend_template = Some(template);
163 self
164 }
165
166 pub async fn register_endpoint(&self, node_id: NodeId, endpoint: NodeEndpoint) {
169 self.endpoints.write().await.insert(node_id, endpoint);
170 }
171
172 fn build_config(&self, endpoint: &NodeEndpoint) -> Option<BackendConfig> {
173 self.backend_template.as_ref().map(|t| {
174 let mut c = t.clone();
175 c.host = endpoint.host.clone();
176 c.port = endpoint.port;
177 c
178 })
179 }
180
181 pub async fn start_replay(
183 &self,
184 journal: TransactionJournalEntry,
185 target_node: NodeId,
186 ) -> Result<Uuid> {
187 let tx_id = journal.tx_id;
188
189 let replay = ActiveReplay {
190 tx_id,
191 target_node,
192 journal,
193 state: ReplayState::Pending,
194 position: 0,
195 started_at: chrono::Utc::now(),
196 results: Vec::new(),
197 };
198
199 self.active_replays.write().await.insert(tx_id, replay);
200
201 tracing::info!("Starting replay for transaction {:?} on node {:?}", tx_id, target_node);
202
203 Ok(tx_id)
204 }
205
206 pub async fn execute_replay(&self, tx_id: Uuid) -> Result<ReplayResult> {
208 let start = std::time::Instant::now();
209
210 let mut replays = self.active_replays.write().await;
212 let replay = replays.get_mut(&tx_id).ok_or_else(|| {
213 ProxyError::ReplayFailed(format!("No active replay for transaction {:?}", tx_id))
214 })?;
215
216 if self.config.wait_for_wal_sync {
218 replay.state = ReplayState::WaitingForWal;
219 self.wait_for_wal_sync(replay.target_node, replay.journal.start_lsn).await?;
220 }
221
222 replay.state = ReplayState::Replaying;
223
224 let entries = replay.journal.entries.clone();
225 let mut statements_replayed = 0;
226 let mut statements_skipped = 0;
227 let mut statements_failed = 0;
228 let mut verification_failures = 0;
229
230 for entry in &entries {
232 if self.config.skip_read_only && entry.statement_type.is_read_only() {
234 statements_skipped += 1;
235 replay.results.push(StatementReplayResult {
236 sequence: entry.sequence,
237 success: true,
238 checksum_matched: None,
239 rows_matched: None,
240 duration_ms: 0,
241 error: None,
242 retries: 0,
243 });
244 continue;
245 }
246
247 if entry.statement_type == StatementType::Transaction {
249 statements_skipped += 1;
250 continue;
251 }
252
253 let result = self.replay_statement(entry, replay.target_node).await;
254
255 match result {
256 Ok(stmt_result) => {
257 if stmt_result.success {
258 statements_replayed += 1;
259
260 if self.config.verify_results {
262 if let Some(false) = stmt_result.checksum_matched {
263 verification_failures += 1;
264 }
265 }
266 } else {
267 statements_failed += 1;
268 }
269 replay.results.push(stmt_result);
270 }
271 Err(e) => {
272 statements_failed += 1;
273 replay.results.push(StatementReplayResult {
274 sequence: entry.sequence,
275 success: false,
276 checksum_matched: None,
277 rows_matched: None,
278 duration_ms: 0,
279 error: Some(e.to_string()),
280 retries: 0,
281 });
282 }
283 }
284
285 replay.position += 1;
286 }
287
288 replay.state = if statements_failed > 0 {
289 ReplayState::Failed
290 } else {
291 ReplayState::Completed
292 };
293
294 let duration_ms = start.elapsed().as_millis() as u64;
295
296 let result = ReplayResult {
297 tx_id,
298 success: statements_failed == 0 && verification_failures == 0,
299 statements_replayed,
300 statements_skipped,
301 statements_failed,
302 verification_failures,
303 duration_ms,
304 error: if statements_failed > 0 {
305 Some("Some statements failed during replay".to_string())
306 } else if verification_failures > 0 {
307 Some("Result verification failed".to_string())
308 } else {
309 None
310 },
311 statement_results: replay.results.clone(),
312 };
313
314 drop(replays);
316 self.active_replays.write().await.remove(&tx_id);
317 self.add_to_history(result.clone()).await;
318
319 tracing::info!(
320 "Replay completed for {:?}: {} replayed, {} failed, {}ms",
321 tx_id,
322 statements_replayed,
323 statements_failed,
324 duration_ms
325 );
326
327 Ok(result)
328 }
329
330 async fn replay_statement(
332 &self,
333 entry: &JournalEntry,
334 target_node: NodeId,
335 ) -> Result<StatementReplayResult> {
336 let start = std::time::Instant::now();
337 let mut retries = 0;
338
339 loop {
340 let (success, checksum_matched, rows_matched, error_msg) =
341 self.execute_statement(entry, target_node).await;
342
343 if success || !self.config.retry_on_error || retries >= self.config.max_retries {
344 return Ok(StatementReplayResult {
345 sequence: entry.sequence,
346 success,
347 checksum_matched: if self.config.verify_results && entry.result_checksum.is_some() {
348 Some(checksum_matched)
349 } else {
350 None
351 },
352 rows_matched: if entry.rows_affected.is_some() {
353 Some(rows_matched)
354 } else {
355 None
356 },
357 duration_ms: start.elapsed().as_millis() as u64,
358 error: if success {
359 None
360 } else {
361 Some(error_msg.unwrap_or_else(|| {
362 "statement execution failed".to_string()
363 }))
364 },
365 retries,
366 });
367 }
368
369 retries += 1;
370 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
371 }
372 }
373
374 async fn execute_statement(
381 &self,
382 entry: &JournalEntry,
383 target_node: NodeId,
384 ) -> (bool, bool, bool, Option<String>) {
385 let endpoint = self.endpoints.read().await.get(&target_node).cloned();
386 let cfg = match endpoint.as_ref().and_then(|e| self.build_config(e)) {
387 Some(c) => c,
388 None => return (true, true, true, None),
389 };
390
391 let mut client = match BackendClient::connect(&cfg).await {
392 Ok(c) => c,
393 Err(e) => return (false, false, false, Some(format!("connect: {}", e))),
394 };
395
396 let params: Vec<ParamValue> =
397 entry.parameters.iter().map(journal_value_to_param).collect();
398
399 let result = if params.is_empty() {
400 client.simple_query(&entry.statement).await
401 } else {
402 client.query_with_params(&entry.statement, ¶ms).await
403 };
404
405 let outcome = match result {
406 Ok(qr) => {
407 let rows_matched = match entry.rows_affected {
408 Some(expected) => qr.rows_affected() == Some(expected),
409 None => true,
410 };
411 let checksum_matched = entry.result_checksum.is_none();
416 (true, checksum_matched, rows_matched, None)
417 }
418 Err(e) => (false, false, false, Some(e.to_string())),
419 };
420 client.close().await;
421 outcome
422 }
423
424 async fn wait_for_wal_sync(&self, node: NodeId, start_lsn: u64) -> Result<()> {
431 let endpoint = self.endpoints.read().await.get(&node).cloned();
432 let cfg = match endpoint.as_ref().and_then(|e| self.build_config(e)) {
433 Some(c) => c,
434 None => {
435 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
437 return Ok(());
438 }
439 };
440
441 let timeout = std::time::Duration::from_millis(self.config.statement_timeout_ms);
442 tokio::time::timeout(timeout, Self::poll_wal_lsn(cfg, start_lsn))
443 .await
444 .map_err(|_| ProxyError::Timeout("WAL sync wait timeout".into()))??;
445 Ok(())
446 }
447
448 async fn poll_wal_lsn(cfg: BackendConfig, target: u64) -> Result<()> {
449 let mut client = BackendClient::connect(&cfg)
450 .await
451 .map_err(|e| ProxyError::ReplayFailed(format!("connect: {}", e)))?;
452 loop {
453 let value = client
454 .query_scalar("SELECT pg_last_wal_replay_lsn()::text")
455 .await
456 .map_err(|e| ProxyError::ReplayFailed(format!("lsn probe: {}", e)))?;
457 if let Some(s) = value.into_string() {
458 if let Some(current) = pg_lsn_to_u64(&s) {
459 if current >= target {
460 client.close().await;
461 return Ok(());
462 }
463 }
464 }
465 tokio::time::sleep(std::time::Duration::from_millis(200)).await;
466 }
467 }
468
469 async fn add_to_history(&self, result: ReplayResult) {
471 let mut history = self.completed_replays.write().await;
472 history.push(result);
473
474 if history.len() > self.max_history {
476 history.remove(0);
477 }
478 }
479
480 pub async fn get_state(&self, tx_id: &Uuid) -> Option<ReplayState> {
482 self.active_replays
483 .read()
484 .await
485 .get(tx_id)
486 .map(|r| r.state)
487 }
488
489 pub async fn get_progress(&self, tx_id: &Uuid) -> Option<(usize, usize)> {
491 self.active_replays.read().await.get(tx_id).map(|r| {
492 (r.position, r.journal.entries.len())
493 })
494 }
495
496 pub async fn cancel_replay(&self, tx_id: &Uuid) -> Result<()> {
498 self.active_replays.write().await.remove(tx_id);
499 tracing::info!("Cancelled replay for transaction {:?}", tx_id);
500 Ok(())
501 }
502
503 pub async fn history(&self) -> Vec<ReplayResult> {
505 self.completed_replays.read().await.clone()
506 }
507
508 pub async fn stats(&self) -> ReplayStats {
510 let history = self.completed_replays.read().await;
511 let successful = history.iter().filter(|r| r.success).count();
512 let total_statements: usize = history.iter().map(|r| r.statements_replayed).sum();
513
514 ReplayStats {
515 active_replays: self.active_replays.read().await.len(),
516 completed_replays: history.len(),
517 successful_replays: successful,
518 total_statements_replayed: total_statements,
519 }
520 }
521}
522
523fn journal_value_to_param(v: &JournalValue) -> ParamValue {
526 match v {
527 JournalValue::Null => ParamValue::Null,
528 JournalValue::Bool(b) => ParamValue::Bool(*b),
529 JournalValue::Int64(i) => ParamValue::Int(*i),
530 JournalValue::Float64(f) => ParamValue::Float(*f),
531 JournalValue::Text(s) => ParamValue::Text(s.clone()),
532 JournalValue::Bytes(b) => {
533 let mut s = String::with_capacity(2 + b.len() * 2);
535 s.push_str("\\x");
536 for byte in b {
537 s.push_str(&format!("{:02x}", byte));
538 }
539 ParamValue::Text(s)
540 }
541 JournalValue::Array(_) => {
542 ParamValue::Null
546 }
547 }
548}
549
/// Parses a PostgreSQL LSN in its textual `HI/LO` form (two hexadecimal
/// numbers separated by a slash) into a single 64-bit position, with `HI` in
/// the upper 32 bits and `LO` in the lower 32 bits.
///
/// Surrounding whitespace around each half is ignored. Returns `None` when
/// the slash is missing, a half is empty, a half contains anything other
/// than hex digits (including a sign), or a half does not fit in 32 bits.
/// Range-checking the high half matters: shifting an oversized value left by
/// 32 would silently discard its top bits and yield a wrong position.
fn pg_lsn_to_u64(s: &str) -> Option<u64> {
    /// Parses one half of the LSN as an unsigned 32-bit hex number.
    fn parse_half(part: &str) -> Option<u64> {
        let part = part.trim();
        // `from_str_radix` tolerates a leading '+'; an LSN never has one.
        if !part.bytes().all(|b| b.is_ascii_hexdigit()) {
            return None;
        }
        u32::from_str_radix(part, 16).ok().map(u64::from)
    }

    let (hi, lo) = s.split_once('/')?;
    let hi = parse_half(hi)?;
    let lo = parse_half(lo)?;
    Some((hi << 32) | lo)
}
562
/// Aggregate counters returned by `FailoverReplay::stats`.
#[derive(Debug, Clone)]
pub struct ReplayStats {
    /// Number of replays currently in the active set.
    pub active_replays: usize,
    /// Number of results currently retained in the history.
    pub completed_replays: usize,
    /// Number of retained results marked successful.
    pub successful_replays: usize,
    /// Sum of `statements_replayed` over the retained results.
    pub total_statements_replayed: usize,
}
575
#[cfg(test)]
mod tests {
    use super::super::transaction_journal::TransactionJournalEntry;
    use super::*;

    /// Builds a journal with two entries: an INSERT (sequence 1) and a
    /// SELECT (sequence 2), both carrying a recorded checksum.
    fn make_journal() -> TransactionJournalEntry {
        let tx_id = Uuid::new_v4();
        let session_id = Uuid::new_v4();
        let node_id = NodeId::new();

        let mut journal = TransactionJournalEntry::new(tx_id, session_id, node_id, 0);

        journal.add_entry(JournalEntry {
            sequence: 1,
            statement: "INSERT INTO users (name) VALUES ('test')".to_string(),
            parameters: vec![],
            result_checksum: Some(12345),
            rows_affected: Some(1),
            timestamp: chrono::Utc::now(),
            statement_type: StatementType::Insert,
            duration_ms: 10,
        });

        journal.add_entry(JournalEntry {
            sequence: 2,
            statement: "SELECT * FROM users".to_string(),
            parameters: vec![],
            result_checksum: Some(67890),
            rows_affected: None,
            timestamp: chrono::Utc::now(),
            statement_type: StatementType::Select,
            duration_ms: 5,
        });

        journal
    }

    #[test]
    fn test_config_default() {
        let config = ReplayConfig::default();
        assert!(config.verify_results);
        assert!(config.retry_on_error);
        assert!(config.wait_for_wal_sync);
    }

    #[test]
    fn test_pg_lsn_to_u64_roundtrip() {
        assert_eq!(pg_lsn_to_u64("0/0"), Some(0));
        assert_eq!(pg_lsn_to_u64("0/1"), Some(1));
        assert_eq!(pg_lsn_to_u64("0/FFFFFFFF"), Some(0xFFFFFFFF));
        assert_eq!(pg_lsn_to_u64("1/0"), Some(1u64 << 32));
        assert_eq!(
            pg_lsn_to_u64("16/B3780A90"),
            Some((0x16u64 << 32) | 0xB3780A90u64)
        );
        // Ordering must follow the numeric LSN order, also across the
        // boundary between the low and the high half.
        assert!(pg_lsn_to_u64("0/A").unwrap() < pg_lsn_to_u64("0/B").unwrap());
        assert!(pg_lsn_to_u64("0/FFFFFFFF").unwrap() < pg_lsn_to_u64("1/0").unwrap());
    }

    #[test]
    fn test_pg_lsn_to_u64_rejects_malformed() {
        assert!(pg_lsn_to_u64("no-slash").is_none());
        assert!(pg_lsn_to_u64("/lo-only").is_none());
        assert!(pg_lsn_to_u64("hi-only/").is_none());
        assert!(pg_lsn_to_u64("zz/zz").is_none());
        // The low half must fit in 32 bits.
        assert!(pg_lsn_to_u64("0/100000000").is_none());
    }

    #[test]
    fn test_journal_value_to_param_basic_types() {
        use crate::backend::ParamValue;

        assert!(matches!(
            journal_value_to_param(&JournalValue::Null),
            ParamValue::Null
        ));
        assert!(matches!(
            journal_value_to_param(&JournalValue::Bool(true)),
            ParamValue::Bool(true)
        ));
        assert!(matches!(
            journal_value_to_param(&JournalValue::Int64(42)),
            ParamValue::Int(42)
        ));
        // 2.5 is exactly representable and, unlike 3.14, does not trigger
        // clippy's `approx_constant` lint.
        match journal_value_to_param(&JournalValue::Float64(2.5)) {
            ParamValue::Float(f) => assert!((f - 2.5).abs() < 1e-9),
            other => panic!("expected Float, got {:?}", other),
        }
        match journal_value_to_param(&JournalValue::Text("hi".into())) {
            ParamValue::Text(s) => assert_eq!(s, "hi"),
            other => panic!("expected Text, got {:?}", other),
        }
    }

    #[test]
    fn test_journal_value_bytes_to_hex_escape() {
        use crate::backend::ParamValue;

        let v = journal_value_to_param(&JournalValue::Bytes(vec![0xDE, 0xAD, 0xBE, 0xEF]));
        match v {
            ParamValue::Text(s) => assert_eq!(s, "\\xdeadbeef"),
            other => panic!("expected Text, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn test_start_replay() {
        let replay = FailoverReplay::new(ReplayConfig::default());
        let journal = make_journal();
        let tx_id = journal.tx_id;
        let target = NodeId::new();

        let result_tx_id = replay.start_replay(journal, target).await.unwrap();
        assert_eq!(result_tx_id, tx_id);

        let state = replay.get_state(&tx_id).await;
        assert_eq!(state, Some(ReplayState::Pending));
    }

    #[tokio::test]
    async fn test_execute_replay() {
        // No backend template is installed, so no backend is contacted and
        // every statement is reported as successful.
        let replay = FailoverReplay::new(ReplayConfig::default());
        let journal = make_journal();
        let tx_id = journal.tx_id;
        let target = NodeId::new();

        replay.start_replay(journal, target).await.unwrap();
        let result = replay.execute_replay(tx_id).await.unwrap();

        assert!(result.success);
        assert_eq!(result.statements_replayed, 2);
        assert_eq!(result.statements_skipped, 0);
        assert_eq!(result.statements_failed, 0);
        assert_eq!(result.verification_failures, 0);
        assert_eq!(result.statement_results.len(), 2);

        // A finished replay leaves the active set and lands in the history.
        assert!(replay.get_state(&tx_id).await.is_none());
        assert_eq!(replay.history().await.len(), 1);
    }

    #[tokio::test]
    async fn test_cancel_replay() {
        let replay = FailoverReplay::new(ReplayConfig::default());
        let journal = make_journal();
        let tx_id = journal.tx_id;
        let target = NodeId::new();

        replay.start_replay(journal, target).await.unwrap();
        replay.cancel_replay(&tx_id).await.unwrap();

        assert!(replay.get_state(&tx_id).await.is_none());
    }

    #[tokio::test]
    async fn test_stats() {
        let replay = FailoverReplay::new(ReplayConfig::default());

        let stats = replay.stats().await;
        assert_eq!(stats.active_replays, 0);
        assert_eq!(stats.completed_replays, 0);
    }
}