Skip to main content

convergio_backup/
snapshot.rs

1//! DB snapshot — atomic SQLite backup with WAL checkpoint.
2//!
3//! Creates a consistent copy of the database by issuing a WAL checkpoint
4//! then performing an atomic file copy. Tracks snapshots in backup_snapshots.
5
6use crate::types::{BackupError, BackupResult, SnapshotRecord};
7use convergio_db::pool::ConnPool;
8use rusqlite::params;
9use sha2::{Digest, Sha256};
10use std::path::{Path, PathBuf};
11use tracing::info;
12
/// Location of the default backup directory for a given data root.
///
/// Returns `data_root` with a trailing `backups` component appended; nothing
/// is created on disk by this call.
pub fn backup_dir(data_root: &Path) -> PathBuf {
    let mut dir = data_root.to_path_buf();
    dir.push("backups");
    dir
}
17
18/// Create an atomic snapshot of the SQLite database.
19///
20/// 1. Issue WAL checkpoint (TRUNCATE) for consistency
21/// 2. Copy the database file atomically (via temp + rename)
22/// 3. Compute SHA-256 checksum of the copy
23/// 4. Record the snapshot in backup_snapshots table
24pub fn create_snapshot(
25    pool: &ConnPool,
26    db_path: &Path,
27    dest_dir: &Path,
28    node: &str,
29) -> BackupResult<SnapshotRecord> {
30    // Path safety: dest_dir is system-constructed (backup_dir from data_root).
31    // User-supplied paths are validated at the HTTP boundary in routes.rs.
32    std::fs::create_dir_all(dest_dir)?;
33
34    // WAL checkpoint for consistency
35    let conn = pool.get()?;
36    conn.execute_batch("PRAGMA wal_checkpoint(TRUNCATE);")?;
37    drop(conn);
38
39    // Generate snapshot ID and paths
40    let snap_id = format!("snap-{}", uuid::Uuid::new_v4());
41    let timestamp = chrono::Utc::now().format("%Y%m%d-%H%M%S").to_string();
42    let filename = format!("convergio-{timestamp}.db");
43    let dest_path = dest_dir.join(&filename);
44    let tmp_path = dest_dir.join(format!(".{filename}.tmp"));
45
46    // Atomic copy: write to temp, then rename
47    std::fs::copy(db_path, &tmp_path)?;
48    std::fs::rename(&tmp_path, &dest_path)?;
49
50    // Compute checksum
51    let checksum = compute_file_checksum(&dest_path)?;
52    let size_bytes = std::fs::metadata(&dest_path)?.len() as i64;
53
54    let record = SnapshotRecord {
55        id: snap_id,
56        path: dest_path.to_string_lossy().into_owned(),
57        size_bytes,
58        checksum,
59        created_at: chrono::Utc::now().to_rfc3339(),
60        node: node.to_string(),
61    };
62
63    // Record in DB
64    let conn = pool.get()?;
65    conn.execute(
66        "INSERT INTO backup_snapshots (id, path, size_bytes, checksum, node) \
67         VALUES (?1, ?2, ?3, ?4, ?5)",
68        params![
69            record.id,
70            record.path,
71            record.size_bytes,
72            record.checksum,
73            record.node,
74        ],
75    )?;
76
77    info!(
78        snapshot = %record.id,
79        size = record.size_bytes,
80        path = %record.path,
81        "snapshot created"
82    );
83    Ok(record)
84}
85
86/// List all recorded snapshots, newest first.
87pub fn list_snapshots(pool: &ConnPool) -> BackupResult<Vec<SnapshotRecord>> {
88    let conn = pool.get()?;
89    let mut stmt = conn.prepare(
90        "SELECT id, path, size_bytes, checksum, created_at, node \
91         FROM backup_snapshots ORDER BY created_at DESC",
92    )?;
93    let records = stmt
94        .query_map([], |row| {
95            Ok(SnapshotRecord {
96                id: row.get(0)?,
97                path: row.get(1)?,
98                size_bytes: row.get(2)?,
99                checksum: row.get(3)?,
100                created_at: row.get(4)?,
101                node: row.get(5)?,
102            })
103        })?
104        .collect::<Result<Vec<_>, _>>()?;
105    Ok(records)
106}
107
108/// Find a snapshot by ID.
109pub fn get_snapshot(pool: &ConnPool, snap_id: &str) -> BackupResult<SnapshotRecord> {
110    let conn = pool.get()?;
111    conn.query_row(
112        "SELECT id, path, size_bytes, checksum, created_at, node \
113         FROM backup_snapshots WHERE id = ?1",
114        params![snap_id],
115        |row| {
116            Ok(SnapshotRecord {
117                id: row.get(0)?,
118                path: row.get(1)?,
119                size_bytes: row.get(2)?,
120                checksum: row.get(3)?,
121                created_at: row.get(4)?,
122                node: row.get(5)?,
123            })
124        },
125    )
126    .map_err(|_| BackupError::SnapshotNotFound(snap_id.to_string()))
127}
128
129/// Compute SHA-256 checksum of a file. Reads in 8 KiB chunks.
130fn compute_file_checksum(path: &Path) -> BackupResult<String> {
131    use std::io::Read;
132    let mut file = std::fs::File::open(path)?;
133    let mut hasher = Sha256::new();
134    let mut buf = [0u8; 8192];
135    loop {
136        let n = file.read(&mut buf)?;
137        if n == 0 {
138            break;
139        }
140        hasher.update(&buf[..n]);
141    }
142    let hash = hasher.finalize();
143    Ok(hash.iter().map(|b| format!("{b:02x}")).collect())
144}
145
146/// Verify a snapshot file matches its recorded checksum.
147pub fn verify_snapshot(record: &SnapshotRecord) -> BackupResult<bool> {
148    let path = Path::new(&record.path);
149    if !path.exists() {
150        return Err(BackupError::SnapshotNotFound(record.id.clone()));
151    }
152    let actual = compute_file_checksum(path)?;
153    Ok(actual == record.checksum)
154}
155
#[cfg(test)]
mod tests {
    use super::*;

    /// File name of the scratch database inside the temp directory.
    const DB_FILE: &str = "test.db";

    /// Build a pool over a fresh database in a temp directory, apply the
    /// crate's migrations, and seed one row of unrelated test data.
    ///
    /// The `TempDir` is returned so it outlives the test body.
    fn setup() -> (ConnPool, tempfile::TempDir) {
        let tmp = tempfile::tempdir().unwrap();
        let pool = convergio_db::pool::create_pool(&tmp.path().join(DB_FILE)).unwrap();
        {
            let conn = pool.get().unwrap();
            for m in crate::schema::migrations() {
                conn.execute_batch(m.up).unwrap();
            }
            conn.execute_batch("CREATE TABLE test_data (id INTEGER, val TEXT)")
                .unwrap();
            conn.execute("INSERT INTO test_data VALUES (1, 'hello')", [])
                .unwrap();
        }
        (pool, tmp)
    }

    /// Snapshot the scratch database into `<tmp>/backups` as node `test-node`.
    fn take_snapshot(pool: &ConnPool, tmp: &tempfile::TempDir) -> SnapshotRecord {
        let root = tmp.path();
        create_snapshot(pool, &root.join(DB_FILE), &root.join("backups"), "test-node").unwrap()
    }

    #[test]
    fn create_and_list_snapshot() {
        let (pool, tmp) = setup();
        let created = take_snapshot(&pool, &tmp);
        assert!(created.id.starts_with("snap-"));
        assert!(created.size_bytes > 0);
        assert!(!created.checksum.is_empty());

        let listed = list_snapshots(&pool).unwrap();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].id, created.id);
    }

    #[test]
    fn get_snapshot_by_id() {
        let (pool, tmp) = setup();
        let created = take_snapshot(&pool, &tmp);
        let fetched = get_snapshot(&pool, &created.id).unwrap();
        assert_eq!(fetched.checksum, created.checksum);
    }

    #[test]
    fn verify_snapshot_integrity() {
        let (pool, tmp) = setup();
        let created = take_snapshot(&pool, &tmp);
        assert!(verify_snapshot(&created).unwrap());
    }

    #[test]
    fn snapshot_not_found_error() {
        let (pool, _tmp) = setup();
        let lookup = get_snapshot(&pool, "snap-nonexistent");
        assert!(lookup.is_err());
    }
}