//! Database backup and restore for the SQLite-backed memory substrate.
//!
//! Supports hot backups (no downtime), optional gzip compression, automatic
//! rotation of old backups, and integrity verification.

use std::path::{Path, PathBuf};

use chrono::{DateTime, Utc};
use rusqlite::Connection;
use tracing::{info, warn};

use punch_types::{PunchError, PunchResult};

14/// Metadata about a single backup.
15#[derive(Debug, Clone)]
16pub struct BackupInfo {
17    /// Unique identifier derived from the filename.
18    pub id: String,
19    /// Full path to the backup file.
20    pub path: PathBuf,
21    /// Size in bytes.
22    pub size_bytes: u64,
23    /// When the backup was created.
24    pub created_at: DateTime<Utc>,
25    /// Database schema version (from `user_version` pragma).
26    pub db_version: String,
27}
28
/// Manages creation, listing, rotation, and restoration of database backups.
pub struct BackupManager {
    /// Location of the live database file.
    db_path: PathBuf,
    /// Directory that receives backup files.
    backup_dir: PathBuf,
    /// Retention limit; backups beyond this count are pruned oldest-first.
    max_backups: usize,
    /// When true, backup files are gzip-compressed after creation.
    compress: bool,
}

41impl BackupManager {
42    /// Create a new backup manager.
43    ///
44    /// `db_path` is the live SQLite database. `backup_dir` is the directory
45    /// where backups will be written. The directory will be created if it does
46    /// not exist.
47    pub fn new(db_path: PathBuf, backup_dir: PathBuf) -> Self {
48        Self {
49            db_path,
50            backup_dir,
51            max_backups: 10,
52            compress: false,
53        }
54    }
55
56    /// Set the maximum number of backups to retain.
57    pub fn with_max_backups(mut self, max: usize) -> Self {
58        self.max_backups = max;
59        self
60    }
61
62    /// Enable or disable gzip compression for backups.
63    pub fn with_compression(mut self, compress: bool) -> Self {
64        self.compress = compress;
65        self
66    }
67
68    /// Create a new backup of the live database.
69    ///
70    /// Uses SQLite's `VACUUM INTO` to produce a consistent snapshot without
71    /// interrupting readers/writers on the live database.
72    pub async fn create_backup(&self) -> PunchResult<BackupInfo> {
73        std::fs::create_dir_all(&self.backup_dir).map_err(|e| {
74            PunchError::Memory(format!(
75                "failed to create backup directory {}: {e}",
76                self.backup_dir.display()
77            ))
78        })?;
79
80        let timestamp = Utc::now().format("%Y%m%d_%H%M%S").to_string();
81        let base_name = format!("punch_backup_{}.db", timestamp);
82        let backup_path = self.backup_dir.join(&base_name);
83
84        // Perform VACUUM INTO on a blocking thread (SQLite is not async).
85        let db_path = self.db_path.clone();
86        let dest = backup_path.clone();
87        let compress = self.compress;
88
89        let final_path = tokio::task::spawn_blocking(move || -> PunchResult<PathBuf> {
90            let conn = Connection::open(&db_path).map_err(|e| {
91                PunchError::Memory(format!("failed to open database for backup: {e}"))
92            })?;
93
94            let dest_str = dest.to_string_lossy().to_string();
95            conn.execute_batch(&format!("VACUUM INTO '{}'", dest_str.replace('\'', "''")))
96                .map_err(|e| {
97                    PunchError::Memory(format!("VACUUM INTO failed: {e}"))
98                })?;
99
100            // Verify the backup.
101            verify_backup(&dest)?;
102
103            if compress {
104                let gz_path = compress_backup(&dest)?;
105                // Remove uncompressed file after successful compression.
106                let _ = std::fs::remove_file(&dest);
107                Ok(gz_path)
108            } else {
109                Ok(dest)
110            }
111        })
112        .await
113        .map_err(|e| PunchError::Memory(format!("backup task panicked: {e}")))??;
114
115        let metadata = std::fs::metadata(&final_path).map_err(|e| {
116            PunchError::Memory(format!("failed to stat backup file: {e}"))
117        })?;
118
119        let id = final_path
120            .file_stem()
121            .and_then(|s| s.to_str())
122            .unwrap_or(&base_name)
123            .to_string();
124
125        let info = BackupInfo {
126            id,
127            path: final_path.clone(),
128            size_bytes: metadata.len(),
129            created_at: Utc::now(),
130            db_version: read_db_version(&self.db_path)?,
131        };
132
133        info!(
134            path = %final_path.display(),
135            size_bytes = info.size_bytes,
136            "database backup created"
137        );
138
139        Ok(info)
140    }
141
142    /// List all existing backups, newest first.
143    pub async fn list_backups(&self) -> PunchResult<Vec<BackupInfo>> {
144        let backup_dir = self.backup_dir.clone();
145        let db_path = self.db_path.clone();
146
147        tokio::task::spawn_blocking(move || -> PunchResult<Vec<BackupInfo>> {
148            let mut backups = Vec::new();
149
150            let entries = match std::fs::read_dir(&backup_dir) {
151                Ok(entries) => entries,
152                Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(backups),
153                Err(e) => {
154                    return Err(PunchError::Memory(format!(
155                        "failed to read backup directory: {e}"
156                    )));
157                }
158            };
159
160            for entry in entries {
161                let entry = entry.map_err(|e| {
162                    PunchError::Memory(format!("failed to read directory entry: {e}"))
163                })?;
164
165                let path = entry.path();
166                let name = path
167                    .file_name()
168                    .and_then(|n| n.to_str())
169                    .unwrap_or_default();
170
171                if !name.starts_with("punch_backup_") {
172                    continue;
173                }
174
175                let metadata = entry.metadata().map_err(|e| {
176                    PunchError::Memory(format!("failed to stat backup file: {e}"))
177                })?;
178
179                let created_at = metadata
180                    .created()
181                    .or_else(|_| metadata.modified())
182                    .map(DateTime::<Utc>::from)
183                    .unwrap_or_else(|_| Utc::now());
184
185                let id = path
186                    .file_stem()
187                    .and_then(|s| s.to_str())
188                    .unwrap_or(name)
189                    .to_string();
190
191                let db_version = read_db_version(&db_path).unwrap_or_else(|_| "unknown".to_string());
192
193                backups.push(BackupInfo {
194                    id,
195                    path,
196                    size_bytes: metadata.len(),
197                    created_at,
198                    db_version,
199                });
200            }
201
202            // Sort newest first.
203            backups.sort_by(|a, b| b.created_at.cmp(&a.created_at));
204
205            Ok(backups)
206        })
207        .await
208        .map_err(|e| PunchError::Memory(format!("list backups task panicked: {e}")))?
209    }
210
211    /// Restore the live database from a backup identified by `backup_id`.
212    ///
213    /// The `backup_id` is the file stem (e.g. `punch_backup_20260101_120000`).
214    /// If the backup is gzip-compressed it will be decompressed first.
215    pub async fn restore_backup(&self, backup_id: &str) -> PunchResult<()> {
216        let backups = self.list_backups().await?;
217        let backup = backups
218            .iter()
219            .find(|b| b.id == backup_id)
220            .ok_or_else(|| PunchError::Memory(format!("backup not found: {backup_id}")))?;
221
222        let backup_path = backup.path.clone();
223        let db_path = self.db_path.clone();
224
225        tokio::task::spawn_blocking(move || -> PunchResult<()> {
226            let source = if backup_path.extension().and_then(|e| e.to_str()) == Some("gz") {
227                decompress_backup(&backup_path)?
228            } else {
229                backup_path.clone()
230            };
231
232            // Verify backup integrity before restoring.
233            verify_backup(&source)?;
234
235            // Copy the backup over the live database.
236            std::fs::copy(&source, &db_path).map_err(|e| {
237                PunchError::Memory(format!("failed to restore backup: {e}"))
238            })?;
239
240            // Clean up decompressed temp file if we made one.
241            if source != backup_path {
242                let _ = std::fs::remove_file(&source);
243            }
244
245            info!(
246                backup = %backup_path.display(),
247                target = %db_path.display(),
248                "database restored from backup"
249            );
250
251            Ok(())
252        })
253        .await
254        .map_err(|e| PunchError::Memory(format!("restore task panicked: {e}")))?
255    }
256
257    /// Remove old backups, keeping only the most recent `max_backups`.
258    ///
259    /// Returns the number of backups removed.
260    pub async fn cleanup_old_backups(&self) -> PunchResult<usize> {
261        let backups = self.list_backups().await?;
262
263        if backups.len() <= self.max_backups {
264            return Ok(0);
265        }
266
267        let to_remove = &backups[self.max_backups..];
268        let mut removed = 0;
269
270        for backup in to_remove {
271            match std::fs::remove_file(&backup.path) {
272                Ok(()) => {
273                    info!(path = %backup.path.display(), "removed old backup");
274                    removed += 1;
275                }
276                Err(e) => {
277                    warn!(
278                        path = %backup.path.display(),
279                        error = %e,
280                        "failed to remove old backup"
281                    );
282                }
283            }
284        }
285
286        Ok(removed)
287    }
288}

// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------

294/// Read the `user_version` pragma from a SQLite database.
295fn read_db_version(path: &Path) -> PunchResult<String> {
296    let conn = Connection::open(path).map_err(|e| {
297        PunchError::Memory(format!("failed to open database for version check: {e}"))
298    })?;
299
300    let version: i64 = conn
301        .pragma_query_value(None, "user_version", |row| row.get(0))
302        .map_err(|e| PunchError::Memory(format!("failed to read user_version: {e}")))?;
303
304    Ok(version.to_string())
305}
306
307/// Verify a backup file by opening it and running `PRAGMA integrity_check`.
308fn verify_backup(path: &Path) -> PunchResult<()> {
309    let conn = Connection::open(path).map_err(|e| {
310        PunchError::Memory(format!(
311            "failed to open backup for verification: {e}"
312        ))
313    })?;
314
315    let result: String = conn
316        .pragma_query_value(None, "integrity_check", |row| row.get(0))
317        .map_err(|e| {
318            PunchError::Memory(format!("integrity check failed: {e}"))
319        })?;
320
321    if result != "ok" {
322        return Err(PunchError::Memory(format!(
323            "backup integrity check failed: {result}"
324        )));
325    }
326
327    Ok(())
328}
329
330/// Compress a backup file with gzip, returning the path to the `.gz` file.
331fn compress_backup(path: &Path) -> PunchResult<PathBuf> {
332    use flate2::write::GzEncoder;
333    use flate2::Compression;
334    use std::io::{Read, Write};
335
336    let gz_path = path.with_extension("db.gz");
337    let mut input = std::fs::File::open(path).map_err(|e| {
338        PunchError::Memory(format!("failed to open backup for compression: {e}"))
339    })?;
340
341    let output = std::fs::File::create(&gz_path).map_err(|e| {
342        PunchError::Memory(format!("failed to create compressed backup: {e}"))
343    })?;
344
345    let mut encoder = GzEncoder::new(output, Compression::default());
346    let mut buf = [0u8; 64 * 1024];
347
348    loop {
349        let n = input.read(&mut buf).map_err(|e| {
350            PunchError::Memory(format!("read error during compression: {e}"))
351        })?;
352        if n == 0 {
353            break;
354        }
355        encoder.write_all(&buf[..n]).map_err(|e| {
356            PunchError::Memory(format!("write error during compression: {e}"))
357        })?;
358    }
359
360    encoder.finish().map_err(|e| {
361        PunchError::Memory(format!("failed to finalize compressed backup: {e}"))
362    })?;
363
364    Ok(gz_path)
365}
366
367/// Decompress a `.gz` backup, returning the path to the decompressed file.
368fn decompress_backup(gz_path: &Path) -> PunchResult<PathBuf> {
369    use flate2::read::GzDecoder;
370    use std::io::{Read, Write};
371
372    let stem = gz_path
373        .file_stem()
374        .and_then(|s| s.to_str())
375        .unwrap_or("backup");
376    let out_path = gz_path.with_file_name(format!("{}_restored.db", stem));
377
378    let input = std::fs::File::open(gz_path).map_err(|e| {
379        PunchError::Memory(format!("failed to open compressed backup: {e}"))
380    })?;
381
382    let mut decoder = GzDecoder::new(input);
383    let mut output = std::fs::File::create(&out_path).map_err(|e| {
384        PunchError::Memory(format!("failed to create decompressed file: {e}"))
385    })?;
386
387    let mut buf = [0u8; 64 * 1024];
388    loop {
389        let n = decoder.read(&mut buf).map_err(|e| {
390            PunchError::Memory(format!("read error during decompression: {e}"))
391        })?;
392        if n == 0 {
393            break;
394        }
395        output.write_all(&buf[..n]).map_err(|e| {
396            PunchError::Memory(format!("write error during decompression: {e}"))
397        })?;
398    }
399
400    Ok(out_path)
401}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Create a minimal SQLite database (WAL mode, one table, one row) for
    /// exercising backup/restore.
    fn create_test_db(path: &Path) {
        let conn = Connection::open(path).expect("create test db");
        conn.execute_batch(
            "PRAGMA journal_mode = WAL;
             CREATE TABLE test (id INTEGER PRIMARY KEY, value TEXT);
             INSERT INTO test (value) VALUES ('hello');",
        )
        .expect("init test db");
    }

    #[tokio::test]
    async fn create_backup_produces_a_file() {
        let dir = tempfile::tempdir().expect("tempdir");
        let db_path = dir.path().join("test.db");
        let backup_dir = dir.path().join("backups");

        create_test_db(&db_path);

        let mgr = BackupManager::new(db_path, backup_dir.clone());
        let info = mgr.create_backup().await.expect("create backup");

        assert!(info.path.exists());
        assert!(info.size_bytes > 0);
        assert!(info.id.starts_with("punch_backup_"));
    }

    #[tokio::test]
    async fn backup_is_valid_sqlite() {
        let dir = tempfile::tempdir().expect("tempdir");
        let db_path = dir.path().join("test.db");
        let backup_dir = dir.path().join("backups");

        create_test_db(&db_path);

        let mgr = BackupManager::new(db_path, backup_dir);
        let info = mgr.create_backup().await.expect("create backup");

        // Should be openable and pass integrity check.
        verify_backup(&info.path).expect("integrity check should pass");

        // Should contain our test data.
        let conn = Connection::open(&info.path).expect("open backup");
        let value: String = conn
            .query_row("SELECT value FROM test WHERE id = 1", [], |row| row.get(0))
            .expect("query backup");
        assert_eq!(value, "hello");
    }

    #[tokio::test]
    async fn list_backups_returns_created_backups() {
        let dir = tempfile::tempdir().expect("tempdir");
        let db_path = dir.path().join("test.db");
        let backup_dir = dir.path().join("backups");

        create_test_db(&db_path);

        let mgr = BackupManager::new(db_path, backup_dir);

        // Empty initially (directory does not even exist yet).
        let list = mgr.list_backups().await.expect("list");
        assert!(list.is_empty());

        mgr.create_backup().await.expect("backup 1");
        let list = mgr.list_backups().await.expect("list");
        assert_eq!(list.len(), 1);
    }

    #[tokio::test]
    async fn cleanup_removes_old_backups() {
        let dir = tempfile::tempdir().expect("tempdir");
        let db_path = dir.path().join("test.db");
        let backup_dir = dir.path().join("backups");

        create_test_db(&db_path);

        let mgr = BackupManager::new(db_path, backup_dir.clone()).with_max_backups(2);

        // Create 4 backup files directly, with distinct timestamped names.
        // The directory only needs to be created once, outside the loop.
        fs::create_dir_all(&backup_dir).expect("mkdir");
        for i in 0..4 {
            let name = format!("punch_backup_20260101_12000{}.db", i);
            let path = backup_dir.join(&name);
            // Each file must be a real SQLite database so listing/removal
            // operate on plausible backups.
            let conn = Connection::open(&path).expect("create");
            conn.execute_batch("CREATE TABLE t (id INTEGER);")
                .expect("init");
        }

        let removed = mgr.cleanup_old_backups().await.expect("cleanup");
        assert_eq!(removed, 2);

        let remaining = mgr.list_backups().await.expect("list");
        assert_eq!(remaining.len(), 2);
    }

    #[tokio::test]
    async fn backup_naming_follows_pattern() {
        let dir = tempfile::tempdir().expect("tempdir");
        let db_path = dir.path().join("test.db");
        let backup_dir = dir.path().join("backups");

        create_test_db(&db_path);

        let mgr = BackupManager::new(db_path, backup_dir);
        let info = mgr.create_backup().await.expect("backup");

        let filename = info
            .path
            .file_name()
            .and_then(|n| n.to_str())
            .expect("filename");
        assert!(
            filename.starts_with("punch_backup_"),
            "expected punch_backup_ prefix, got: {filename}"
        );
        assert!(
            filename.ends_with(".db"),
            "expected .db suffix, got: {filename}"
        );
    }

    #[tokio::test]
    async fn compressed_backup_roundtrip() {
        let dir = tempfile::tempdir().expect("tempdir");
        let db_path = dir.path().join("test.db");
        let backup_dir = dir.path().join("backups");

        create_test_db(&db_path);

        let mgr = BackupManager::new(db_path.clone(), backup_dir).with_compression(true);
        let info = mgr.create_backup().await.expect("backup");

        let filename = info
            .path
            .file_name()
            .and_then(|n| n.to_str())
            .expect("filename");
        assert!(
            filename.ends_with(".db.gz"),
            "expected .db.gz suffix, got: {filename}"
        );

        // Decompress and verify.
        let restored = decompress_backup(&info.path).expect("decompress");
        verify_backup(&restored).expect("integrity check on decompressed");
    }
}
559}