database_replicator/checkpoint.rs

use anyhow::{bail, Context, Result};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::BTreeSet;
use std::fs;
use std::path::{Path, PathBuf};

const INIT_CHECKPOINT_VERSION: u32 = 1;

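/// Identity of an init run: hashed source/target connection strings plus the
/// flags that shaped the copy. Stored in the checkpoint so a resumed run can
/// detect that it is being replayed against a different configuration.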
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct InitCheckpointMetadata {
    pub source_hash: String,
    pub target_hash: String,
    pub filter_hash: String,
    pub drop_existing: bool,
    pub enable_sync: bool,
}

impl InitCheckpointMetadata {
    pub fn new(
        source_url: &str,
        target_url: &str,
        filter_hash: String,
        drop_existing: bool,
        enable_sync: bool,
    ) -> Self {
        Self {
            source_hash: hash_string(source_url),
            target_hash: hash_string(target_url),
            filter_hash,
            drop_existing,
            enable_sync,
        }
    }
}

#[derive(Debug, Clone, Serialize, Deserialize)]
struct InitCheckpointData {
    version: u32,
    metadata: InitCheckpointMetadata,
    databases: Vec<String>,
    completed: BTreeSet<String>,
}

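/// Resumable state for the initial copy: the full database list, the subset
/// already completed, and the metadata the run was started with.
///
/// Illustrative lifecycle (a sketch; the URLs, filter hash, and database list
/// are placeholders supplied by the caller):
///
/// ```ignore
/// let metadata = InitCheckpointMetadata::new(source_url, target_url, filter_hash, false, true);
/// let path = checkpoint_path(source_url, target_url)?;
/// let mut checkpoint = match InitCheckpoint::load(&path)? {
///     Some(existing) => {
///         existing.validate(&metadata, &databases)?;
///         existing
///     }
///     None => InitCheckpoint::new(metadata, &databases),
/// };
/// for db in checkpoint.databases().to_vec() {
///     if checkpoint.is_completed(&db) {
///         continue;
///     }
///     // ... copy the database ...
///     checkpoint.mark_completed(&db);
///     checkpoint.save(&path)?;
/// }
/// remove_checkpoint(&path)?;
/// ```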
#[derive(Debug, Clone)]
pub struct InitCheckpoint {
    data: InitCheckpointData,
}

impl InitCheckpoint {
    pub fn new(metadata: InitCheckpointMetadata, databases: &[String]) -> Self {
        Self {
            data: InitCheckpointData {
                version: INIT_CHECKPOINT_VERSION,
                metadata,
                databases: databases.to_vec(),
                completed: BTreeSet::new(),
            },
        }
    }

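    /// Loads a checkpoint from `path`, returning `Ok(None)` when no file
    /// exists and an error when the file cannot be parsed or was written by a
    /// different checkpoint version.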
    pub fn load(path: &Path) -> Result<Option<Self>> {
        if !path.exists() {
            return Ok(None);
        }

        let content = fs::read_to_string(path)
            .with_context(|| format!("Failed to read checkpoint at {}", path.display()))?;
        let data: InitCheckpointData = serde_json::from_str(&content)
            .with_context(|| format!("Failed to parse checkpoint JSON at {}", path.display()))?;

        if data.version != INIT_CHECKPOINT_VERSION {
            bail!(
                "Checkpoint version mismatch (found {}, expected {}). Run with --no-resume to start fresh.",
                data.version,
                INIT_CHECKPOINT_VERSION
            );
        }

        Ok(Some(Self { data }))
    }

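    /// Saves the checkpoint atomically: the JSON is written to a temporary
    /// file in the same directory and then renamed over `path`, so a crash
    /// mid-write cannot leave a truncated checkpoint behind.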
    pub fn save(&self, path: &Path) -> Result<()> {
        if let Some(parent) = path.parent() {
            fs::create_dir_all(parent).with_context(|| {
                format!("Failed to create checkpoint directory {}", parent.display())
            })?;
        }

        let parent = path.parent().unwrap_or_else(|| Path::new("."));
        let mut tmp = tempfile::NamedTempFile::new_in(parent)
            .with_context(|| format!("Failed to create temp checkpoint in {}", parent.display()))?;

        serde_json::to_writer_pretty(tmp.as_file_mut(), &self.data)
            .with_context(|| format!("Failed to serialize checkpoint at {}", path.display()))?;

        tmp.persist(path)
            .with_context(|| format!("Failed to persist checkpoint at {}", path.display()))?;

        Ok(())
    }

    pub fn databases(&self) -> &[String] {
        &self.data.databases
    }

    pub fn metadata(&self) -> &InitCheckpointMetadata {
        &self.data.metadata
    }

    pub fn mark_completed(&mut self, db_name: &str) -> bool {
        self.data.completed.insert(db_name.to_string())
    }

    pub fn is_completed(&self, db_name: &str) -> bool {
        self.data.completed.contains(db_name)
    }

    pub fn completed_count(&self) -> usize {
        self.data.completed.len()
    }

    pub fn total_databases(&self) -> usize {
        self.data.databases.len()
    }

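    /// Checks that a loaded checkpoint belongs to the current run: both the
    /// metadata and the discovered database list must match exactly, otherwise
    /// the caller is told to restart with --no-resume.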
    pub fn validate(&self, metadata: &InitCheckpointMetadata, databases: &[String]) -> Result<()> {
        if self.data.metadata != *metadata {
            bail!(
                "Checkpoint metadata mismatch. Run with --no-resume to discard the previous state."
            );
        }

        if self.data.databases != databases {
            bail!(
                "Checkpoint database list differs from current discovery. Run with --no-resume to start fresh."
            );
        }

        Ok(())
    }
}

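/// Derives a deterministic checkpoint location under the system temp directory
/// from the source/target URL pair, so repeated runs against the same pair
/// resolve to the same file.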
pub fn checkpoint_path(source_url: &str, target_url: &str) -> Result<PathBuf> {
    let base = std::env::temp_dir().join("postgres-seren-replicator-checkpoints");
    fs::create_dir_all(&base).with_context(|| {
        format!(
            "Failed to create checkpoint base directory {}",
            base.display()
        )
    })?;

    let mut hasher = Sha256::new();
    hasher.update(source_url.as_bytes());
    hasher.update(b"::");
    hasher.update(target_url.as_bytes());
    let digest = format!("{:x}", hasher.finalize());
    let short = &digest[..16.min(digest.len())];

    Ok(base.join(format!("init-{}.json", short)))
}

pub fn remove_checkpoint(path: &Path) -> Result<()> {
    if path.exists() {
        fs::remove_file(path)
            .with_context(|| format!("Failed to remove checkpoint at {}", path.display()))?;
    }
    Ok(())
}

fn hash_string(input: &str) -> String {
    let mut hasher = Sha256::new();
    hasher.update(input.as_bytes());
    format!("{:x}", hasher.finalize())
}

#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[test]
    fn metadata_hash_changes_with_inputs() {
        let meta_a = InitCheckpointMetadata::new("src_a", "tgt", "filter".into(), true, false);
        let meta_b = InitCheckpointMetadata::new("src_b", "tgt", "filter".into(), true, false);
        assert_ne!(meta_a.source_hash, meta_b.source_hash);
    }

    #[test]
    fn checkpoint_roundtrip() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("cp.json");
        let metadata = InitCheckpointMetadata::new("src", "tgt", "filter".into(), false, true);
        let databases = vec!["db1".to_string(), "db2".to_string()];
        let mut checkpoint = InitCheckpoint::new(metadata.clone(), &databases);
        checkpoint.mark_completed("db1");
        checkpoint.save(&path).unwrap();

        let loaded = InitCheckpoint::load(&path).unwrap().unwrap();
        loaded.validate(&metadata, &databases).unwrap();
        assert!(loaded.is_completed("db1"));
        assert!(!loaded.is_completed("db2"));
    }

    #[test]
    fn checkpoint_path_is_deterministic() {
        let path_a = checkpoint_path("postgres://src/db", "postgres://tgt/db").unwrap();
        let path_b = checkpoint_path("postgres://src/db", "postgres://tgt/db").unwrap();
        assert_eq!(path_a, path_b);
    }
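
    // Additional illustrative tests covering paths the existing tests leave
    // untouched; they only exercise the API defined in this module.
    #[test]
    fn load_returns_none_for_missing_file() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("missing.json");
        assert!(InitCheckpoint::load(&path).unwrap().is_none());
    }

    #[test]
    fn validate_rejects_mismatched_metadata() {
        let databases = vec!["db1".to_string()];
        let metadata = InitCheckpointMetadata::new("src", "tgt", "filter".into(), false, true);
        // Same URLs and filter, but a different drop_existing flag.
        let other = InitCheckpointMetadata::new("src", "tgt", "filter".into(), true, true);
        let checkpoint = InitCheckpoint::new(metadata, &databases);
        assert!(checkpoint.validate(&other, &databases).is_err());
    }

    #[test]
    fn remove_checkpoint_deletes_saved_file() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("cp.json");
        // Removing a checkpoint that does not exist is a no-op.
        remove_checkpoint(&path).unwrap();
        let metadata = InitCheckpointMetadata::new("src", "tgt", "filter".into(), false, false);
        let checkpoint = InitCheckpoint::new(metadata, &["db1".to_string()]);
        checkpoint.save(&path).unwrap();
        assert!(path.exists());
        remove_checkpoint(&path).unwrap();
        assert!(!path.exists());
    }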
}