1use std::path::{Path, PathBuf};
7
8use chrono::{DateTime, Utc};
9use rusqlite::Connection;
10use tracing::{info, warn};
11
12use punch_types::{PunchError, PunchResult};
13
14#[derive(Debug, Clone)]
16pub struct BackupInfo {
17 pub id: String,
19 pub path: PathBuf,
21 pub size_bytes: u64,
23 pub created_at: DateTime<Utc>,
25 pub db_version: String,
27}
28
/// Creates, lists, restores and prunes backups of a single SQLite database file.
///
/// Construct with [`BackupManager::new`] and adjust with the `with_*` builder
/// methods.
#[derive(Debug, Clone)]
pub struct BackupManager {
    /// Database file that is backed up and restored over.
    db_path: PathBuf,
    /// Directory holding the `punch_backup_*` files.
    backup_dir: PathBuf,
    /// How many of the newest backups `cleanup_old_backups` keeps.
    max_backups: usize,
    /// Whether newly created backups are gzip-compressed.
    compress: bool,
}
40
41impl BackupManager {
42 pub fn new(db_path: PathBuf, backup_dir: PathBuf) -> Self {
48 Self {
49 db_path,
50 backup_dir,
51 max_backups: 10,
52 compress: false,
53 }
54 }
55
56 pub fn with_max_backups(mut self, max: usize) -> Self {
58 self.max_backups = max;
59 self
60 }
61
62 pub fn with_compression(mut self, compress: bool) -> Self {
64 self.compress = compress;
65 self
66 }
67
68 pub async fn create_backup(&self) -> PunchResult<BackupInfo> {
73 std::fs::create_dir_all(&self.backup_dir).map_err(|e| {
74 PunchError::Memory(format!(
75 "failed to create backup directory {}: {e}",
76 self.backup_dir.display()
77 ))
78 })?;
79
80 let timestamp = Utc::now().format("%Y%m%d_%H%M%S").to_string();
81 let base_name = format!("punch_backup_{}.db", timestamp);
82 let backup_path = self.backup_dir.join(&base_name);
83
84 let db_path = self.db_path.clone();
86 let dest = backup_path.clone();
87 let compress = self.compress;
88
89 let final_path = tokio::task::spawn_blocking(move || -> PunchResult<PathBuf> {
90 let conn = Connection::open(&db_path).map_err(|e| {
91 PunchError::Memory(format!("failed to open database for backup: {e}"))
92 })?;
93
94 let dest_str = dest.to_string_lossy().to_string();
95 conn.execute_batch(&format!("VACUUM INTO '{}'", dest_str.replace('\'', "''")))
96 .map_err(|e| PunchError::Memory(format!("VACUUM INTO failed: {e}")))?;
97
98 verify_backup(&dest)?;
100
101 if compress {
102 let gz_path = compress_backup(&dest)?;
103 let _ = std::fs::remove_file(&dest);
105 Ok(gz_path)
106 } else {
107 Ok(dest)
108 }
109 })
110 .await
111 .map_err(|e| PunchError::Memory(format!("backup task panicked: {e}")))??;
112
113 let metadata = std::fs::metadata(&final_path)
114 .map_err(|e| PunchError::Memory(format!("failed to stat backup file: {e}")))?;
115
116 let id = final_path
117 .file_stem()
118 .and_then(|s| s.to_str())
119 .unwrap_or(&base_name)
120 .to_string();
121
122 let info = BackupInfo {
123 id,
124 path: final_path.clone(),
125 size_bytes: metadata.len(),
126 created_at: Utc::now(),
127 db_version: read_db_version(&self.db_path)?,
128 };
129
130 info!(
131 path = %final_path.display(),
132 size_bytes = info.size_bytes,
133 "database backup created"
134 );
135
136 Ok(info)
137 }
138
139 pub async fn list_backups(&self) -> PunchResult<Vec<BackupInfo>> {
141 let backup_dir = self.backup_dir.clone();
142 let db_path = self.db_path.clone();
143
144 tokio::task::spawn_blocking(move || -> PunchResult<Vec<BackupInfo>> {
145 let mut backups = Vec::new();
146
147 let entries = match std::fs::read_dir(&backup_dir) {
148 Ok(entries) => entries,
149 Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(backups),
150 Err(e) => {
151 return Err(PunchError::Memory(format!(
152 "failed to read backup directory: {e}"
153 )));
154 }
155 };
156
157 for entry in entries {
158 let entry = entry.map_err(|e| {
159 PunchError::Memory(format!("failed to read directory entry: {e}"))
160 })?;
161
162 let path = entry.path();
163 let name = path
164 .file_name()
165 .and_then(|n| n.to_str())
166 .unwrap_or_default();
167
168 if !name.starts_with("punch_backup_") {
169 continue;
170 }
171
172 let metadata = entry
173 .metadata()
174 .map_err(|e| PunchError::Memory(format!("failed to stat backup file: {e}")))?;
175
176 let created_at = metadata
177 .created()
178 .or_else(|_| metadata.modified())
179 .map(DateTime::<Utc>::from)
180 .unwrap_or_else(|_| Utc::now());
181
182 let id = path
183 .file_stem()
184 .and_then(|s| s.to_str())
185 .unwrap_or(name)
186 .to_string();
187
188 let db_version =
189 read_db_version(&db_path).unwrap_or_else(|_| "unknown".to_string());
190
191 backups.push(BackupInfo {
192 id,
193 path,
194 size_bytes: metadata.len(),
195 created_at,
196 db_version,
197 });
198 }
199
200 backups.sort_by(|a, b| b.created_at.cmp(&a.created_at));
202
203 Ok(backups)
204 })
205 .await
206 .map_err(|e| PunchError::Memory(format!("list backups task panicked: {e}")))?
207 }
208
209 pub async fn restore_backup(&self, backup_id: &str) -> PunchResult<()> {
214 let backups = self.list_backups().await?;
215 let backup = backups
216 .iter()
217 .find(|b| b.id == backup_id)
218 .ok_or_else(|| PunchError::Memory(format!("backup not found: {backup_id}")))?;
219
220 let backup_path = backup.path.clone();
221 let db_path = self.db_path.clone();
222
223 tokio::task::spawn_blocking(move || -> PunchResult<()> {
224 let source = if backup_path.extension().and_then(|e| e.to_str()) == Some("gz") {
225 decompress_backup(&backup_path)?
226 } else {
227 backup_path.clone()
228 };
229
230 verify_backup(&source)?;
232
233 std::fs::copy(&source, &db_path)
235 .map_err(|e| PunchError::Memory(format!("failed to restore backup: {e}")))?;
236
237 if source != backup_path {
239 let _ = std::fs::remove_file(&source);
240 }
241
242 info!(
243 backup = %backup_path.display(),
244 target = %db_path.display(),
245 "database restored from backup"
246 );
247
248 Ok(())
249 })
250 .await
251 .map_err(|e| PunchError::Memory(format!("restore task panicked: {e}")))?
252 }
253
254 pub async fn cleanup_old_backups(&self) -> PunchResult<usize> {
258 let backups = self.list_backups().await?;
259
260 if backups.len() <= self.max_backups {
261 return Ok(0);
262 }
263
264 let to_remove = &backups[self.max_backups..];
265 let mut removed = 0;
266
267 for backup in to_remove {
268 match std::fs::remove_file(&backup.path) {
269 Ok(()) => {
270 info!(path = %backup.path.display(), "removed old backup");
271 removed += 1;
272 }
273 Err(e) => {
274 warn!(
275 path = %backup.path.display(),
276 error = %e,
277 "failed to remove old backup"
278 );
279 }
280 }
281 }
282
283 Ok(removed)
284 }
285}
286
287fn read_db_version(path: &Path) -> PunchResult<String> {
293 let conn = Connection::open(path).map_err(|e| {
294 PunchError::Memory(format!("failed to open database for version check: {e}"))
295 })?;
296
297 let version: i64 = conn
298 .pragma_query_value(None, "user_version", |row| row.get(0))
299 .map_err(|e| PunchError::Memory(format!("failed to read user_version: {e}")))?;
300
301 Ok(version.to_string())
302}
303
304fn verify_backup(path: &Path) -> PunchResult<()> {
306 let conn = Connection::open(path)
307 .map_err(|e| PunchError::Memory(format!("failed to open backup for verification: {e}")))?;
308
309 let result: String = conn
310 .pragma_query_value(None, "integrity_check", |row| row.get(0))
311 .map_err(|e| PunchError::Memory(format!("integrity check failed: {e}")))?;
312
313 if result != "ok" {
314 return Err(PunchError::Memory(format!(
315 "backup integrity check failed: {result}"
316 )));
317 }
318
319 Ok(())
320}
321
322fn compress_backup(path: &Path) -> PunchResult<PathBuf> {
324 use flate2::Compression;
325 use flate2::write::GzEncoder;
326 use std::io::{Read, Write};
327
328 let gz_path = path.with_extension("db.gz");
329 let mut input = std::fs::File::open(path)
330 .map_err(|e| PunchError::Memory(format!("failed to open backup for compression: {e}")))?;
331
332 let output = std::fs::File::create(&gz_path)
333 .map_err(|e| PunchError::Memory(format!("failed to create compressed backup: {e}")))?;
334
335 let mut encoder = GzEncoder::new(output, Compression::default());
336 let mut buf = [0u8; 64 * 1024];
337
338 loop {
339 let n = input
340 .read(&mut buf)
341 .map_err(|e| PunchError::Memory(format!("read error during compression: {e}")))?;
342 if n == 0 {
343 break;
344 }
345 encoder
346 .write_all(&buf[..n])
347 .map_err(|e| PunchError::Memory(format!("write error during compression: {e}")))?;
348 }
349
350 encoder
351 .finish()
352 .map_err(|e| PunchError::Memory(format!("failed to finalize compressed backup: {e}")))?;
353
354 Ok(gz_path)
355}
356
357fn decompress_backup(gz_path: &Path) -> PunchResult<PathBuf> {
359 use flate2::read::GzDecoder;
360 use std::io::{Read, Write};
361
362 let stem = gz_path
363 .file_stem()
364 .and_then(|s| s.to_str())
365 .unwrap_or("backup");
366 let out_path = gz_path.with_file_name(format!("{}_restored.db", stem));
367
368 let input = std::fs::File::open(gz_path)
369 .map_err(|e| PunchError::Memory(format!("failed to open compressed backup: {e}")))?;
370
371 let mut decoder = GzDecoder::new(input);
372 let mut output = std::fs::File::create(&out_path)
373 .map_err(|e| PunchError::Memory(format!("failed to create decompressed file: {e}")))?;
374
375 let mut buf = [0u8; 64 * 1024];
376 loop {
377 let n = decoder
378 .read(&mut buf)
379 .map_err(|e| PunchError::Memory(format!("read error during decompression: {e}")))?;
380 if n == 0 {
381 break;
382 }
383 output
384 .write_all(&buf[..n])
385 .map_err(|e| PunchError::Memory(format!("write error during decompression: {e}")))?;
386 }
387
388 Ok(out_path)
389}
390
#[cfg(test)]
mod tests {
    // Every test works inside its own temporary directory, so tests are
    // independent of each other and of the host filesystem.
    use super::*;
    use std::fs;

    /// Creates a WAL-mode database with table `test` holding one row
    /// (`id = 1`, `value = 'hello'`).
    fn create_test_db(path: &Path) {
        let conn = Connection::open(path).expect("create test db");
        conn.execute_batch(
            "PRAGMA journal_mode = WAL;
             CREATE TABLE test (id INTEGER PRIMARY KEY, value TEXT);
             INSERT INTO test (value) VALUES ('hello');",
        )
        .expect("init test db");
    }

    /// Returns the number of rows in table `test` of the database at `path`.
    fn row_count(path: &Path) -> i64 {
        let conn = Connection::open(path).expect("open db");
        conn.query_row("SELECT COUNT(*) FROM test", [], |row| row.get(0))
            .expect("count rows")
    }

    /// Inserts one extra row into table `test`, closing the connection afterwards.
    fn insert_extra_row(path: &Path) {
        let conn = Connection::open(path).expect("open db");
        conn.execute("INSERT INTO test (value) VALUES ('after backup')", [])
            .expect("insert");
    }

    #[tokio::test]
    async fn create_backup_produces_a_file() {
        let dir = tempfile::tempdir().expect("tempdir");
        let db_path = dir.path().join("test.db");
        let backup_dir = dir.path().join("backups");

        create_test_db(&db_path);

        let mgr = BackupManager::new(db_path, backup_dir.clone());
        let info = mgr.create_backup().await.expect("create backup");

        assert!(info.path.exists());
        assert!(info.size_bytes > 0);
        assert!(info.id.starts_with("punch_backup_"));
    }

    #[tokio::test]
    async fn backup_is_valid_sqlite() {
        let dir = tempfile::tempdir().expect("tempdir");
        let db_path = dir.path().join("test.db");
        let backup_dir = dir.path().join("backups");

        create_test_db(&db_path);

        let mgr = BackupManager::new(db_path, backup_dir);
        let info = mgr.create_backup().await.expect("create backup");

        verify_backup(&info.path).expect("integrity check should pass");

        let conn = Connection::open(&info.path).expect("open backup");
        let value: String = conn
            .query_row("SELECT value FROM test WHERE id = 1", [], |row| row.get(0))
            .expect("query backup");
        assert_eq!(value, "hello");
    }

    #[tokio::test]
    async fn list_backups_returns_created_backups() {
        let dir = tempfile::tempdir().expect("tempdir");
        let db_path = dir.path().join("test.db");
        let backup_dir = dir.path().join("backups");

        create_test_db(&db_path);

        let mgr = BackupManager::new(db_path, backup_dir);

        // The backup directory does not exist yet: listing must not fail.
        let list = mgr.list_backups().await.expect("list");
        assert!(list.is_empty());

        mgr.create_backup().await.expect("backup 1");
        let list = mgr.list_backups().await.expect("list");
        assert_eq!(list.len(), 1);
    }

    #[tokio::test]
    async fn cleanup_removes_old_backups() {
        let dir = tempfile::tempdir().expect("tempdir");
        let db_path = dir.path().join("test.db");
        let backup_dir = dir.path().join("backups");

        create_test_db(&db_path);

        let mgr = BackupManager::new(db_path, backup_dir.clone()).with_max_backups(2);

        // Hand-craft four backups so the test does not depend on wall-clock seconds.
        fs::create_dir_all(&backup_dir).expect("mkdir");
        for i in 0..4 {
            let name = format!("punch_backup_20260101_12000{}.db", i);
            let path = backup_dir.join(&name);
            let conn = Connection::open(&path).expect("create");
            conn.execute_batch("CREATE TABLE t (id INTEGER);")
                .expect("init");
        }

        let removed = mgr.cleanup_old_backups().await.expect("cleanup");
        assert_eq!(removed, 2);

        let remaining = mgr.list_backups().await.expect("list");
        assert_eq!(remaining.len(), 2);
    }

    #[tokio::test]
    async fn backup_naming_follows_pattern() {
        let dir = tempfile::tempdir().expect("tempdir");
        let db_path = dir.path().join("test.db");
        let backup_dir = dir.path().join("backups");

        create_test_db(&db_path);

        let mgr = BackupManager::new(db_path, backup_dir);
        let info = mgr.create_backup().await.expect("backup");

        let filename = info
            .path
            .file_name()
            .and_then(|n| n.to_str())
            .expect("filename");
        assert!(
            filename.starts_with("punch_backup_"),
            "expected punch_backup_ prefix, got: {filename}"
        );
        assert!(
            filename.ends_with(".db"),
            "expected .db suffix, got: {filename}"
        );
    }

    #[tokio::test]
    async fn compressed_backup_roundtrip() {
        let dir = tempfile::tempdir().expect("tempdir");
        let db_path = dir.path().join("test.db");
        let backup_dir = dir.path().join("backups");

        create_test_db(&db_path);

        let mgr = BackupManager::new(db_path.clone(), backup_dir).with_compression(true);
        let info = mgr.create_backup().await.expect("backup");

        let filename = info
            .path
            .file_name()
            .and_then(|n| n.to_str())
            .expect("filename");
        assert!(
            filename.ends_with(".db.gz"),
            "expected .db.gz suffix, got: {filename}"
        );

        let restored = decompress_backup(&info.path).expect("decompress");
        verify_backup(&restored).expect("integrity check on decompressed");
    }

    #[tokio::test]
    async fn restore_backup_rolls_back_later_changes() {
        let dir = tempfile::tempdir().expect("tempdir");
        let db_path = dir.path().join("test.db");
        let backup_dir = dir.path().join("backups");

        create_test_db(&db_path);

        let mgr = BackupManager::new(db_path.clone(), backup_dir);
        let info = mgr.create_backup().await.expect("backup");

        insert_extra_row(&db_path);
        assert_eq!(row_count(&db_path), 2);

        mgr.restore_backup(&info.id).await.expect("restore");
        assert_eq!(row_count(&db_path), 1);
    }

    #[tokio::test]
    async fn restore_from_compressed_backup_removes_temp_file() {
        let dir = tempfile::tempdir().expect("tempdir");
        let db_path = dir.path().join("test.db");
        let backup_dir = dir.path().join("backups");

        create_test_db(&db_path);

        let mgr = BackupManager::new(db_path.clone(), backup_dir.clone()).with_compression(true);
        let info = mgr.create_backup().await.expect("backup");

        insert_extra_row(&db_path);
        mgr.restore_backup(&info.id).await.expect("restore");

        assert_eq!(row_count(&db_path), 1);
        // The id of `X.db.gz` is `X.db`; its decompressed temporary is `X.db_restored.db`.
        let temp = backup_dir.join(format!("{}_restored.db", info.id));
        assert!(!temp.exists(), "decompressed temp file should be removed");
    }

    #[tokio::test]
    async fn restore_unknown_backup_is_an_error() {
        let dir = tempfile::tempdir().expect("tempdir");
        let db_path = dir.path().join("test.db");
        let backup_dir = dir.path().join("backups");

        create_test_db(&db_path);

        let mgr = BackupManager::new(db_path.clone(), backup_dir);
        assert!(mgr.restore_backup("punch_backup_missing").await.is_err());

        // The live database must be untouched.
        assert_eq!(row_count(&db_path), 1);
    }

    #[test]
    fn verify_backup_rejects_non_database_file() {
        let dir = tempfile::tempdir().expect("tempdir");
        let path = dir.path().join("punch_backup_garbage.db");
        fs::write(&path, vec![b'x'; 4096]).expect("write garbage");

        assert!(verify_backup(&path).is_err());
    }
}