// punch_memory/backup.rs

1//! Database backup and restore for the SQLite-backed memory substrate.
2//!
3//! Supports hot backups (no downtime), optional gzip compression, automatic
4//! rotation of old backups, and integrity verification.
5
6use std::path::{Path, PathBuf};
7
8use chrono::{DateTime, Utc};
9use rusqlite::Connection;
10use tracing::{info, warn};
11
12use punch_types::{PunchError, PunchResult};
13
/// Metadata about a single backup.
#[derive(Debug, Clone)]
pub struct BackupInfo {
    /// Unique identifier derived from the filename: the file stem, e.g.
    /// `punch_backup_20260101_120000` (compressed backups keep a `.db` in
    /// the stem, e.g. `punch_backup_20260101_120000.db`).
    pub id: String,
    /// Full path to the backup file (`.db`, or `.db.gz` when compressed).
    pub path: PathBuf,
    /// Size of the backup file in bytes (the compressed size if gzipped).
    pub size_bytes: u64,
    /// When the backup was created.
    pub created_at: DateTime<Utc>,
    /// Database schema version (from the `user_version` pragma).
    pub db_version: String,
}
28
/// Manages creation, listing, rotation, and restoration of database backups.
///
/// Construct with [`BackupManager::new`], then optionally tune retention and
/// compression via the `with_*` builder methods.
pub struct BackupManager {
    /// Path to the live database file.
    db_path: PathBuf,
    /// Directory where backup files are written.
    backup_dir: PathBuf,
    /// Maximum number of backups to keep; oldest are pruned (default: 10).
    max_backups: usize,
    /// Whether to gzip-compress backup files (default: false).
    compress: bool,
}
40
41impl BackupManager {
42    /// Create a new backup manager.
43    ///
44    /// `db_path` is the live SQLite database. `backup_dir` is the directory
45    /// where backups will be written. The directory will be created if it does
46    /// not exist.
47    pub fn new(db_path: PathBuf, backup_dir: PathBuf) -> Self {
48        Self {
49            db_path,
50            backup_dir,
51            max_backups: 10,
52            compress: false,
53        }
54    }
55
56    /// Set the maximum number of backups to retain.
57    pub fn with_max_backups(mut self, max: usize) -> Self {
58        self.max_backups = max;
59        self
60    }
61
62    /// Enable or disable gzip compression for backups.
63    pub fn with_compression(mut self, compress: bool) -> Self {
64        self.compress = compress;
65        self
66    }
67
68    /// Create a new backup of the live database.
69    ///
70    /// Uses SQLite's `VACUUM INTO` to produce a consistent snapshot without
71    /// interrupting readers/writers on the live database.
72    pub async fn create_backup(&self) -> PunchResult<BackupInfo> {
73        std::fs::create_dir_all(&self.backup_dir).map_err(|e| {
74            PunchError::Memory(format!(
75                "failed to create backup directory {}: {e}",
76                self.backup_dir.display()
77            ))
78        })?;
79
80        let timestamp = Utc::now().format("%Y%m%d_%H%M%S").to_string();
81        let base_name = format!("punch_backup_{}.db", timestamp);
82        let backup_path = self.backup_dir.join(&base_name);
83
84        // Perform VACUUM INTO on a blocking thread (SQLite is not async).
85        let db_path = self.db_path.clone();
86        let dest = backup_path.clone();
87        let compress = self.compress;
88
89        let final_path = tokio::task::spawn_blocking(move || -> PunchResult<PathBuf> {
90            let conn = Connection::open(&db_path).map_err(|e| {
91                PunchError::Memory(format!("failed to open database for backup: {e}"))
92            })?;
93
94            let dest_str = dest.to_string_lossy().to_string();
95            conn.execute_batch(&format!("VACUUM INTO '{}'", dest_str.replace('\'', "''")))
96                .map_err(|e| PunchError::Memory(format!("VACUUM INTO failed: {e}")))?;
97
98            // Verify the backup.
99            verify_backup(&dest)?;
100
101            if compress {
102                let gz_path = compress_backup(&dest)?;
103                // Remove uncompressed file after successful compression.
104                let _ = std::fs::remove_file(&dest);
105                Ok(gz_path)
106            } else {
107                Ok(dest)
108            }
109        })
110        .await
111        .map_err(|e| PunchError::Memory(format!("backup task panicked: {e}")))??;
112
113        let metadata = std::fs::metadata(&final_path)
114            .map_err(|e| PunchError::Memory(format!("failed to stat backup file: {e}")))?;
115
116        let id = final_path
117            .file_stem()
118            .and_then(|s| s.to_str())
119            .unwrap_or(&base_name)
120            .to_string();
121
122        let info = BackupInfo {
123            id,
124            path: final_path.clone(),
125            size_bytes: metadata.len(),
126            created_at: Utc::now(),
127            db_version: read_db_version(&self.db_path)?,
128        };
129
130        info!(
131            path = %final_path.display(),
132            size_bytes = info.size_bytes,
133            "database backup created"
134        );
135
136        Ok(info)
137    }
138
139    /// List all existing backups, newest first.
140    pub async fn list_backups(&self) -> PunchResult<Vec<BackupInfo>> {
141        let backup_dir = self.backup_dir.clone();
142        let db_path = self.db_path.clone();
143
144        tokio::task::spawn_blocking(move || -> PunchResult<Vec<BackupInfo>> {
145            let mut backups = Vec::new();
146
147            let entries = match std::fs::read_dir(&backup_dir) {
148                Ok(entries) => entries,
149                Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(backups),
150                Err(e) => {
151                    return Err(PunchError::Memory(format!(
152                        "failed to read backup directory: {e}"
153                    )));
154                }
155            };
156
157            for entry in entries {
158                let entry = entry.map_err(|e| {
159                    PunchError::Memory(format!("failed to read directory entry: {e}"))
160                })?;
161
162                let path = entry.path();
163                let name = path
164                    .file_name()
165                    .and_then(|n| n.to_str())
166                    .unwrap_or_default();
167
168                if !name.starts_with("punch_backup_") {
169                    continue;
170                }
171
172                let metadata = entry
173                    .metadata()
174                    .map_err(|e| PunchError::Memory(format!("failed to stat backup file: {e}")))?;
175
176                let created_at = metadata
177                    .created()
178                    .or_else(|_| metadata.modified())
179                    .map(DateTime::<Utc>::from)
180                    .unwrap_or_else(|_| Utc::now());
181
182                let id = path
183                    .file_stem()
184                    .and_then(|s| s.to_str())
185                    .unwrap_or(name)
186                    .to_string();
187
188                let db_version =
189                    read_db_version(&db_path).unwrap_or_else(|_| "unknown".to_string());
190
191                backups.push(BackupInfo {
192                    id,
193                    path,
194                    size_bytes: metadata.len(),
195                    created_at,
196                    db_version,
197                });
198            }
199
200            // Sort newest first.
201            backups.sort_by(|a, b| b.created_at.cmp(&a.created_at));
202
203            Ok(backups)
204        })
205        .await
206        .map_err(|e| PunchError::Memory(format!("list backups task panicked: {e}")))?
207    }
208
209    /// Restore the live database from a backup identified by `backup_id`.
210    ///
211    /// The `backup_id` is the file stem (e.g. `punch_backup_20260101_120000`).
212    /// If the backup is gzip-compressed it will be decompressed first.
213    pub async fn restore_backup(&self, backup_id: &str) -> PunchResult<()> {
214        let backups = self.list_backups().await?;
215        let backup = backups
216            .iter()
217            .find(|b| b.id == backup_id)
218            .ok_or_else(|| PunchError::Memory(format!("backup not found: {backup_id}")))?;
219
220        let backup_path = backup.path.clone();
221        let db_path = self.db_path.clone();
222
223        tokio::task::spawn_blocking(move || -> PunchResult<()> {
224            let source = if backup_path.extension().and_then(|e| e.to_str()) == Some("gz") {
225                decompress_backup(&backup_path)?
226            } else {
227                backup_path.clone()
228            };
229
230            // Verify backup integrity before restoring.
231            verify_backup(&source)?;
232
233            // Copy the backup over the live database.
234            std::fs::copy(&source, &db_path)
235                .map_err(|e| PunchError::Memory(format!("failed to restore backup: {e}")))?;
236
237            // Clean up decompressed temp file if we made one.
238            if source != backup_path {
239                let _ = std::fs::remove_file(&source);
240            }
241
242            info!(
243                backup = %backup_path.display(),
244                target = %db_path.display(),
245                "database restored from backup"
246            );
247
248            Ok(())
249        })
250        .await
251        .map_err(|e| PunchError::Memory(format!("restore task panicked: {e}")))?
252    }
253
254    /// Remove old backups, keeping only the most recent `max_backups`.
255    ///
256    /// Returns the number of backups removed.
257    pub async fn cleanup_old_backups(&self) -> PunchResult<usize> {
258        let backups = self.list_backups().await?;
259
260        if backups.len() <= self.max_backups {
261            return Ok(0);
262        }
263
264        let to_remove = &backups[self.max_backups..];
265        let mut removed = 0;
266
267        for backup in to_remove {
268            match std::fs::remove_file(&backup.path) {
269                Ok(()) => {
270                    info!(path = %backup.path.display(), "removed old backup");
271                    removed += 1;
272                }
273                Err(e) => {
274                    warn!(
275                        path = %backup.path.display(),
276                        error = %e,
277                        "failed to remove old backup"
278                    );
279                }
280            }
281        }
282
283        Ok(removed)
284    }
285}
286
287// ---------------------------------------------------------------------------
288// Helpers
289// ---------------------------------------------------------------------------
290
291/// Read the `user_version` pragma from a SQLite database.
292fn read_db_version(path: &Path) -> PunchResult<String> {
293    let conn = Connection::open(path).map_err(|e| {
294        PunchError::Memory(format!("failed to open database for version check: {e}"))
295    })?;
296
297    let version: i64 = conn
298        .pragma_query_value(None, "user_version", |row| row.get(0))
299        .map_err(|e| PunchError::Memory(format!("failed to read user_version: {e}")))?;
300
301    Ok(version.to_string())
302}
303
304/// Verify a backup file by opening it and running `PRAGMA integrity_check`.
305fn verify_backup(path: &Path) -> PunchResult<()> {
306    let conn = Connection::open(path)
307        .map_err(|e| PunchError::Memory(format!("failed to open backup for verification: {e}")))?;
308
309    let result: String = conn
310        .pragma_query_value(None, "integrity_check", |row| row.get(0))
311        .map_err(|e| PunchError::Memory(format!("integrity check failed: {e}")))?;
312
313    if result != "ok" {
314        return Err(PunchError::Memory(format!(
315            "backup integrity check failed: {result}"
316        )));
317    }
318
319    Ok(())
320}
321
322/// Compress a backup file with gzip, returning the path to the `.gz` file.
323fn compress_backup(path: &Path) -> PunchResult<PathBuf> {
324    use flate2::Compression;
325    use flate2::write::GzEncoder;
326    use std::io::{Read, Write};
327
328    let gz_path = path.with_extension("db.gz");
329    let mut input = std::fs::File::open(path)
330        .map_err(|e| PunchError::Memory(format!("failed to open backup for compression: {e}")))?;
331
332    let output = std::fs::File::create(&gz_path)
333        .map_err(|e| PunchError::Memory(format!("failed to create compressed backup: {e}")))?;
334
335    let mut encoder = GzEncoder::new(output, Compression::default());
336    let mut buf = [0u8; 64 * 1024];
337
338    loop {
339        let n = input
340            .read(&mut buf)
341            .map_err(|e| PunchError::Memory(format!("read error during compression: {e}")))?;
342        if n == 0 {
343            break;
344        }
345        encoder
346            .write_all(&buf[..n])
347            .map_err(|e| PunchError::Memory(format!("write error during compression: {e}")))?;
348    }
349
350    encoder
351        .finish()
352        .map_err(|e| PunchError::Memory(format!("failed to finalize compressed backup: {e}")))?;
353
354    Ok(gz_path)
355}
356
357/// Decompress a `.gz` backup, returning the path to the decompressed file.
358fn decompress_backup(gz_path: &Path) -> PunchResult<PathBuf> {
359    use flate2::read::GzDecoder;
360    use std::io::{Read, Write};
361
362    let stem = gz_path
363        .file_stem()
364        .and_then(|s| s.to_str())
365        .unwrap_or("backup");
366    let out_path = gz_path.with_file_name(format!("{}_restored.db", stem));
367
368    let input = std::fs::File::open(gz_path)
369        .map_err(|e| PunchError::Memory(format!("failed to open compressed backup: {e}")))?;
370
371    let mut decoder = GzDecoder::new(input);
372    let mut output = std::fs::File::create(&out_path)
373        .map_err(|e| PunchError::Memory(format!("failed to create decompressed file: {e}")))?;
374
375    let mut buf = [0u8; 64 * 1024];
376    loop {
377        let n = decoder
378            .read(&mut buf)
379            .map_err(|e| PunchError::Memory(format!("read error during decompression: {e}")))?;
380        if n == 0 {
381            break;
382        }
383        output
384            .write_all(&buf[..n])
385            .map_err(|e| PunchError::Memory(format!("write error during decompression: {e}")))?;
386    }
387
388    Ok(out_path)
389}
390
391// ---------------------------------------------------------------------------
392// Tests
393// ---------------------------------------------------------------------------
394
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Create a minimal SQLite database (WAL mode, one table, one row).
    fn create_test_db(path: &Path) {
        let conn = Connection::open(path).expect("create test db");
        conn.execute_batch(
            "PRAGMA journal_mode = WAL;
             CREATE TABLE test (id INTEGER PRIMARY KEY, value TEXT);
             INSERT INTO test (value) VALUES ('hello');",
        )
        .expect("init test db");
    }

    #[tokio::test]
    async fn create_backup_produces_a_file() {
        let dir = tempfile::tempdir().expect("tempdir");
        let db_path = dir.path().join("test.db");
        let backup_dir = dir.path().join("backups");

        create_test_db(&db_path);

        let mgr = BackupManager::new(db_path, backup_dir.clone());
        let info = mgr.create_backup().await.expect("create backup");

        assert!(info.path.exists());
        assert!(info.size_bytes > 0);
        assert!(info.id.starts_with("punch_backup_"));
    }

    #[tokio::test]
    async fn backup_is_valid_sqlite() {
        let dir = tempfile::tempdir().expect("tempdir");
        let db_path = dir.path().join("test.db");
        let backup_dir = dir.path().join("backups");

        create_test_db(&db_path);

        let mgr = BackupManager::new(db_path, backup_dir);
        let info = mgr.create_backup().await.expect("create backup");

        // Should be openable and pass integrity check.
        verify_backup(&info.path).expect("integrity check should pass");

        // Should contain our test data.
        let conn = Connection::open(&info.path).expect("open backup");
        let value: String = conn
            .query_row("SELECT value FROM test WHERE id = 1", [], |row| row.get(0))
            .expect("query backup");
        assert_eq!(value, "hello");
    }

    #[tokio::test]
    async fn list_backups_returns_created_backups() {
        let dir = tempfile::tempdir().expect("tempdir");
        let db_path = dir.path().join("test.db");
        let backup_dir = dir.path().join("backups");

        create_test_db(&db_path);

        let mgr = BackupManager::new(db_path, backup_dir);

        // Empty initially (missing directory is not an error).
        let list = mgr.list_backups().await.expect("list");
        assert!(list.is_empty());

        mgr.create_backup().await.expect("backup 1");
        let list = mgr.list_backups().await.expect("list");
        assert_eq!(list.len(), 1);
    }

    #[tokio::test]
    async fn cleanup_removes_old_backups() {
        let dir = tempfile::tempdir().expect("tempdir");
        let db_path = dir.path().join("test.db");
        let backup_dir = dir.path().join("backups");

        create_test_db(&db_path);

        let mgr = BackupManager::new(db_path, backup_dir.clone()).with_max_backups(2);

        // Create 4 backups with distinct timestamps in their names.
        fs::create_dir_all(&backup_dir).expect("mkdir");
        for i in 0..4 {
            let name = format!("punch_backup_20260101_12000{}.db", i);
            let path = backup_dir.join(&name);
            // Each must be a valid SQLite file so listing/cleanup see it.
            let conn = Connection::open(&path).expect("create");
            conn.execute_batch("CREATE TABLE t (id INTEGER);")
                .expect("init");
        }

        let removed = mgr.cleanup_old_backups().await.expect("cleanup");
        assert_eq!(removed, 2);

        let remaining = mgr.list_backups().await.expect("list");
        assert_eq!(remaining.len(), 2);
    }

    #[tokio::test]
    async fn backup_naming_follows_pattern() {
        let dir = tempfile::tempdir().expect("tempdir");
        let db_path = dir.path().join("test.db");
        let backup_dir = dir.path().join("backups");

        create_test_db(&db_path);

        let mgr = BackupManager::new(db_path, backup_dir);
        let info = mgr.create_backup().await.expect("backup");

        let filename = info
            .path
            .file_name()
            .and_then(|n| n.to_str())
            .expect("filename");
        assert!(
            filename.starts_with("punch_backup_"),
            "expected punch_backup_ prefix, got: {filename}"
        );
        assert!(
            filename.ends_with(".db"),
            "expected .db suffix, got: {filename}"
        );
    }

    #[tokio::test]
    async fn compressed_backup_roundtrip() {
        let dir = tempfile::tempdir().expect("tempdir");
        let db_path = dir.path().join("test.db");
        let backup_dir = dir.path().join("backups");

        create_test_db(&db_path);

        let mgr = BackupManager::new(db_path.clone(), backup_dir).with_compression(true);
        let info = mgr.create_backup().await.expect("backup");

        let filename = info
            .path
            .file_name()
            .and_then(|n| n.to_str())
            .expect("filename");
        assert!(
            filename.ends_with(".db.gz"),
            "expected .db.gz suffix, got: {filename}"
        );

        // Decompress and verify.
        let restored = decompress_backup(&info.path).expect("decompress");
        verify_backup(&restored).expect("integrity check on decompressed");
    }
}