Skip to main content

convergio_backup/
restore.rs

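//! Disaster recovery — restore the daemon's database from a snapshot.
//!
//! Restoring by snapshot ID validates the snapshot checksum before replacing
//! the live database; restoring from a raw file path performs no checksum
//! validation. The replacement itself is atomic: the snapshot is copied to a
//! temp file next to the target, then renamed over the live database.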
5
6use crate::snapshot::{get_snapshot, verify_snapshot};
7use crate::types::{BackupError, BackupResult};
8use convergio_db::pool::ConnPool;
9use std::path::Path;
10use tracing::{info, warn};
11
12/// Restore the database from a snapshot file.
13///
14/// 1. Find the snapshot record by ID
15/// 2. Verify the snapshot file checksum
16/// 3. Copy snapshot to temp location next to target
17/// 4. Rename temp over the live database (atomic on same filesystem)
18///
19/// The caller must stop the daemon or close all pool connections
20/// before calling this. Returns the path of the restored file.
21pub fn restore_from_snapshot(
22    pool: &ConnPool,
23    snap_id: &str,
24    target_db_path: &Path,
25) -> BackupResult<String> {
26    // Path safety: target_db_path is system-constructed (data_dir + "convergio.db").
27    let record = get_snapshot(pool, snap_id)?;
28
29    let snap_path = Path::new(&record.path);
30    if !snap_path.exists() {
31        return Err(BackupError::SnapshotNotFound(format!(
32            "file missing: {}",
33            record.path
34        )));
35    }
36
37    // Verify integrity
38    if !verify_snapshot(&record)? {
39        return Err(BackupError::RestoreFailed(
40            "checksum mismatch — snapshot may be corrupted".into(),
41        ));
42    }
43
44    info!(snapshot = %snap_id, "verified snapshot integrity, starting restore");
45
46    // Atomic restore: copy to temp, then rename
47    let tmp_path = target_db_path.with_extension("db.restoring");
48    std::fs::copy(snap_path, &tmp_path)?;
49
50    // Remove WAL and SHM files from target (stale after restore)
51    let wal = target_db_path.with_extension("db-wal");
52    let shm = target_db_path.with_extension("db-shm");
53    remove_if_exists(&wal);
54    remove_if_exists(&shm);
55
56    // Rename temp over live DB
57    std::fs::rename(&tmp_path, target_db_path)?;
58
59    info!(
60        snapshot = %snap_id,
61        target = %target_db_path.display(),
62        "database restored from snapshot"
63    );
64    Ok(record.path)
65}
66
67/// Restore from a raw snapshot file path (no pool lookup).
68/// Used by `cvg backup restore <file>` when the DB may not be running.
69pub fn restore_from_file(snapshot_path: &Path, target_db_path: &Path) -> BackupResult<()> {
70    if !snapshot_path.exists() {
71        return Err(BackupError::SnapshotNotFound(
72            snapshot_path.to_string_lossy().into_owned(),
73        ));
74    }
75
76    let tmp_path = target_db_path.with_extension("db.restoring");
77    std::fs::copy(snapshot_path, &tmp_path)?;
78
79    // Remove stale WAL/SHM
80    remove_if_exists(&target_db_path.with_extension("db-wal"));
81    remove_if_exists(&target_db_path.with_extension("db-shm"));
82
83    std::fs::rename(&tmp_path, target_db_path)?;
84
85    info!(
86        source = %snapshot_path.display(),
87        target = %target_db_path.display(),
88        "database restored from file"
89    );
90    Ok(())
91}
92
93fn remove_if_exists(path: &Path) {
94    if path.exists() {
95        if let Err(e) = std::fs::remove_file(path) {
96            warn!(path = %path.display(), err = %e, "failed to remove stale file");
97        }
98    }
99}
100
#[cfg(test)]
mod tests {
    use super::*;

    /// The target ends up holding exactly the bytes of the source file.
    #[test]
    fn restore_from_file_copies_correctly() {
        let dir = tempfile::tempdir().unwrap();
        let src = dir.path().join("source.db");
        let dst = dir.path().join("target.db");
        std::fs::write(&src, b"fake-sqlite-data").unwrap();

        restore_from_file(&src, &dst).unwrap();

        assert!(dst.exists());
        assert_eq!(std::fs::read(&dst).unwrap(), b"fake-sqlite-data");
    }

    /// Pre-existing `-wal` / `-shm` files next to the target are deleted.
    #[test]
    fn restore_from_file_removes_stale_wal() {
        let dir = tempfile::tempdir().unwrap();
        let in_dir = |name: &str| dir.path().join(name);
        let src = in_dir("source.db");
        let dst = in_dir("target.db");
        let stale_wal = in_dir("target.db-wal");
        let stale_shm = in_dir("target.db-shm");

        std::fs::write(&src, b"db-data").unwrap();
        std::fs::write(&stale_wal, b"stale-wal").unwrap();
        std::fs::write(&stale_shm, b"stale-shm").unwrap();

        restore_from_file(&src, &dst).unwrap();

        assert!(!stale_wal.exists());
        assert!(!stale_shm.exists());
    }

    /// A source path that does not exist yields an error.
    #[test]
    fn restore_from_missing_file_errors() {
        let dir = tempfile::tempdir().unwrap();
        let missing = dir.path().join("nonexistent.db");
        let dst = dir.path().join("target.db");

        assert!(restore_from_file(&missing, &dst).is_err());
    }

    /// Snapshot a live database, delete its rows, restore the snapshot file
    /// via `restore_from_file`, and check the original row is back.
    #[test]
    fn restore_from_snapshot_round_trip() {
        let dir = tempfile::tempdir().unwrap();
        let live_db = dir.path().join("live.db");
        let pool = convergio_db::pool::create_pool(&live_db).unwrap();

        // Seed: apply migrations, then add one marker row.
        {
            let seed = pool.get().unwrap();
            for migration in crate::schema::migrations() {
                seed.execute_batch(migration.up).unwrap();
            }
            seed.execute_batch("CREATE TABLE test_rt (v TEXT)").unwrap();
            seed.execute("INSERT INTO test_rt VALUES ('original')", [])
                .unwrap();
        }

        // Create snapshot
        let backup_dir = dir.path().join("backups");
        let snapshot =
            crate::snapshot::create_snapshot(&pool, &live_db, &backup_dir, "test-node").unwrap();

        // Modify live DB
        {
            let wipe = pool.get().unwrap();
            wipe.execute("DELETE FROM test_rt", []).unwrap();
        }

        // Restore
        restore_from_file(std::path::Path::new(&snapshot.path), &live_db).unwrap();

        // Verify restoration
        let check = rusqlite::Connection::open(&live_db).unwrap();
        let restored: String = check
            .query_row("SELECT v FROM test_rt", [], |row| row.get(0))
            .unwrap();
        assert_eq!(restored, "original");
    }
}