database_replicator/
checkpoint.rs

1// ABOUTME: Persistent checkpoint tracking for long-running operations
2// ABOUTME: Provides init command resume support with hashed identities
3
4use anyhow::{bail, Context, Result};
5use serde::{Deserialize, Serialize};
6use sha2::{Digest, Sha256};
7use std::collections::BTreeSet;
8use std::fs;
9use std::path::{Path, PathBuf};
10
/// Format version stamped into every newly created init checkpoint.
///
/// `InitCheckpoint::load` refuses checkpoint files whose stored version differs from this value.
const INIT_CHECKPOINT_VERSION: u32 = 1;
12
13#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
14pub struct InitCheckpointMetadata {
15    pub source_hash: String,
16    pub target_hash: String,
17    pub filter_hash: String,
18    pub drop_existing: bool,
19    pub enable_sync: bool,
20}
21
22impl InitCheckpointMetadata {
23    pub fn new(
24        source_url: &str,
25        target_url: &str,
26        filter_hash: String,
27        drop_existing: bool,
28        enable_sync: bool,
29    ) -> Self {
30        Self {
31            source_hash: hash_string(source_url),
32            target_hash: hash_string(target_url),
33            filter_hash,
34            drop_existing,
35            enable_sync,
36        }
37    }
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
41struct InitCheckpointData {
42    version: u32,
43    metadata: InitCheckpointMetadata,
44    databases: Vec<String>,
45    completed: BTreeSet<String>,
46}
47
48#[derive(Debug, Clone)]
49pub struct InitCheckpoint {
50    data: InitCheckpointData,
51}
52
53impl InitCheckpoint {
54    pub fn new(metadata: InitCheckpointMetadata, databases: &[String]) -> Self {
55        Self {
56            data: InitCheckpointData {
57                version: INIT_CHECKPOINT_VERSION,
58                metadata,
59                databases: databases.to_vec(),
60                completed: BTreeSet::new(),
61            },
62        }
63    }
64
65    pub fn load(path: &Path) -> Result<Option<Self>> {
66        if !path.exists() {
67            return Ok(None);
68        }
69
70        let content = fs::read_to_string(path)
71            .with_context(|| format!("Failed to read checkpoint at {}", path.display()))?;
72        let data: InitCheckpointData = serde_json::from_str(&content)
73            .with_context(|| format!("Failed to parse checkpoint JSON at {}", path.display()))?;
74
75        if data.version != INIT_CHECKPOINT_VERSION {
76            bail!(
77                "Checkpoint version mismatch (found {}, expected {}). Run with --no-resume to start fresh.",
78                data.version,
79                INIT_CHECKPOINT_VERSION
80            );
81        }
82
83        Ok(Some(Self { data }))
84    }
85
86    pub fn save(&self, path: &Path) -> Result<()> {
87        if let Some(parent) = path.parent() {
88            fs::create_dir_all(parent).with_context(|| {
89                format!("Failed to create checkpoint directory {}", parent.display())
90            })?;
91        }
92
93        let parent = path.parent().unwrap_or_else(|| Path::new("."));
94        let mut tmp = tempfile::NamedTempFile::new_in(parent)
95            .with_context(|| format!("Failed to create temp checkpoint in {}", parent.display()))?;
96
97        serde_json::to_writer_pretty(tmp.as_file_mut(), &self.data)
98            .with_context(|| format!("Failed to serialize checkpoint at {}", path.display()))?;
99
100        tmp.persist(path)
101            .with_context(|| format!("Failed to persist checkpoint at {}", path.display()))?;
102
103        Ok(())
104    }
105
106    pub fn databases(&self) -> &[String] {
107        &self.data.databases
108    }
109
110    pub fn metadata(&self) -> &InitCheckpointMetadata {
111        &self.data.metadata
112    }
113
114    pub fn mark_completed(&mut self, db_name: &str) -> bool {
115        self.data.completed.insert(db_name.to_string())
116    }
117
118    pub fn is_completed(&self, db_name: &str) -> bool {
119        self.data.completed.contains(db_name)
120    }
121
122    pub fn completed_count(&self) -> usize {
123        self.data.completed.len()
124    }
125
126    pub fn total_databases(&self) -> usize {
127        self.data.databases.len()
128    }
129
130    pub fn validate(&self, metadata: &InitCheckpointMetadata, databases: &[String]) -> Result<()> {
131        if self.data.metadata != *metadata {
132            bail!(
133                "Checkpoint metadata mismatch. Run with --no-resume to discard the previous state."
134            );
135        }
136
137        if self.data.databases != databases {
138            bail!(
139                "Checkpoint database list differs from current discovery. Run with --no-resume to start fresh."
140            );
141        }
142
143        Ok(())
144    }
145}
146
147pub fn checkpoint_path(source_url: &str, target_url: &str) -> Result<PathBuf> {
148    let base = std::env::temp_dir().join("postgres-seren-replicator-checkpoints");
149    fs::create_dir_all(&base).with_context(|| {
150        format!(
151            "Failed to create checkpoint base directory {}",
152            base.display()
153        )
154    })?;
155
156    let mut hasher = Sha256::new();
157    hasher.update(source_url.as_bytes());
158    hasher.update(b"::");
159    hasher.update(target_url.as_bytes());
160    let digest = format!("{:x}", hasher.finalize());
161    let short = &digest[..16.min(digest.len())];
162
163    Ok(base.join(format!("init-{}.json", short)))
164}
165
166pub fn remove_checkpoint(path: &Path) -> Result<()> {
167    if path.exists() {
168        fs::remove_file(path)
169            .with_context(|| format!("Failed to remove checkpoint at {}", path.display()))?;
170    }
171    Ok(())
172}
173
174fn hash_string(input: &str) -> String {
175    let mut hasher = Sha256::new();
176    hasher.update(input.as_bytes());
177    format!("{:x}", hasher.finalize())
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Metadata built from different source URLs must carry different source hashes.
    #[test]
    fn metadata_hash_changes_with_inputs() {
        let build =
            |source: &str| InitCheckpointMetadata::new(source, "tgt", "filter".into(), true, false);

        assert_ne!(build("src_a").source_hash, build("src_b").source_hash);
    }

    /// A saved checkpoint loads back, validates, and keeps its completion marks.
    #[test]
    fn checkpoint_roundtrip() {
        let scratch = tempdir().unwrap();
        let file = scratch.path().join("cp.json");

        let databases = vec![String::from("db1"), String::from("db2")];
        let metadata = InitCheckpointMetadata::new("src", "tgt", "filter".into(), false, true);

        let mut written = InitCheckpoint::new(metadata.clone(), &databases);
        written.mark_completed("db1");
        written.save(&file).unwrap();

        let restored = InitCheckpoint::load(&file).unwrap().unwrap();
        restored.validate(&metadata, &databases).unwrap();
        assert!(restored.is_completed("db1"));
        assert!(!restored.is_completed("db2"));
    }

    /// The same URL pair always maps to the same checkpoint path.
    #[test]
    fn checkpoint_path_is_deterministic() {
        let (source, target) = ("postgres://src/db", "postgres://tgt/db");

        let first = checkpoint_path(source, target).unwrap();
        let second = checkpoint_path(source, target).unwrap();

        assert_eq!(first, second);
    }
}