1use std::path::{Path, PathBuf};
7
8use chrono::{DateTime, Utc};
9use rusqlite::Connection;
10use tracing::{info, warn};
11
12use punch_types::{PunchError, PunchResult};
13
14#[derive(Debug, Clone)]
16pub struct BackupInfo {
17 pub id: String,
19 pub path: PathBuf,
21 pub size_bytes: u64,
23 pub created_at: DateTime<Utc>,
25 pub db_version: String,
27}
28
/// Creates, lists, restores and prunes snapshot backups of one SQLite database.
#[derive(Debug, Clone)]
pub struct BackupManager {
    /// Path of the live database that is backed up and restored into.
    db_path: PathBuf,
    /// Directory holding the `punch_backup_*` files.
    backup_dir: PathBuf,
    /// Number of newest backups kept by `cleanup_old_backups`.
    max_backups: usize,
    /// Whether new backups are gzip-compressed (`.db.gz`).
    compress: bool,
}
40
41impl BackupManager {
42 pub fn new(db_path: PathBuf, backup_dir: PathBuf) -> Self {
48 Self {
49 db_path,
50 backup_dir,
51 max_backups: 10,
52 compress: false,
53 }
54 }
55
56 pub fn with_max_backups(mut self, max: usize) -> Self {
58 self.max_backups = max;
59 self
60 }
61
62 pub fn with_compression(mut self, compress: bool) -> Self {
64 self.compress = compress;
65 self
66 }
67
68 pub async fn create_backup(&self) -> PunchResult<BackupInfo> {
73 std::fs::create_dir_all(&self.backup_dir).map_err(|e| {
74 PunchError::Memory(format!(
75 "failed to create backup directory {}: {e}",
76 self.backup_dir.display()
77 ))
78 })?;
79
80 let timestamp = Utc::now().format("%Y%m%d_%H%M%S").to_string();
81 let base_name = format!("punch_backup_{}.db", timestamp);
82 let backup_path = self.backup_dir.join(&base_name);
83
84 let db_path = self.db_path.clone();
86 let dest = backup_path.clone();
87 let compress = self.compress;
88
89 let final_path = tokio::task::spawn_blocking(move || -> PunchResult<PathBuf> {
90 let conn = Connection::open(&db_path).map_err(|e| {
91 PunchError::Memory(format!("failed to open database for backup: {e}"))
92 })?;
93
94 let dest_str = dest.to_string_lossy().to_string();
95 conn.execute_batch(&format!("VACUUM INTO '{}'", dest_str.replace('\'', "''")))
96 .map_err(|e| {
97 PunchError::Memory(format!("VACUUM INTO failed: {e}"))
98 })?;
99
100 verify_backup(&dest)?;
102
103 if compress {
104 let gz_path = compress_backup(&dest)?;
105 let _ = std::fs::remove_file(&dest);
107 Ok(gz_path)
108 } else {
109 Ok(dest)
110 }
111 })
112 .await
113 .map_err(|e| PunchError::Memory(format!("backup task panicked: {e}")))??;
114
115 let metadata = std::fs::metadata(&final_path).map_err(|e| {
116 PunchError::Memory(format!("failed to stat backup file: {e}"))
117 })?;
118
119 let id = final_path
120 .file_stem()
121 .and_then(|s| s.to_str())
122 .unwrap_or(&base_name)
123 .to_string();
124
125 let info = BackupInfo {
126 id,
127 path: final_path.clone(),
128 size_bytes: metadata.len(),
129 created_at: Utc::now(),
130 db_version: read_db_version(&self.db_path)?,
131 };
132
133 info!(
134 path = %final_path.display(),
135 size_bytes = info.size_bytes,
136 "database backup created"
137 );
138
139 Ok(info)
140 }
141
142 pub async fn list_backups(&self) -> PunchResult<Vec<BackupInfo>> {
144 let backup_dir = self.backup_dir.clone();
145 let db_path = self.db_path.clone();
146
147 tokio::task::spawn_blocking(move || -> PunchResult<Vec<BackupInfo>> {
148 let mut backups = Vec::new();
149
150 let entries = match std::fs::read_dir(&backup_dir) {
151 Ok(entries) => entries,
152 Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(backups),
153 Err(e) => {
154 return Err(PunchError::Memory(format!(
155 "failed to read backup directory: {e}"
156 )));
157 }
158 };
159
160 for entry in entries {
161 let entry = entry.map_err(|e| {
162 PunchError::Memory(format!("failed to read directory entry: {e}"))
163 })?;
164
165 let path = entry.path();
166 let name = path
167 .file_name()
168 .and_then(|n| n.to_str())
169 .unwrap_or_default();
170
171 if !name.starts_with("punch_backup_") {
172 continue;
173 }
174
175 let metadata = entry.metadata().map_err(|e| {
176 PunchError::Memory(format!("failed to stat backup file: {e}"))
177 })?;
178
179 let created_at = metadata
180 .created()
181 .or_else(|_| metadata.modified())
182 .map(DateTime::<Utc>::from)
183 .unwrap_or_else(|_| Utc::now());
184
185 let id = path
186 .file_stem()
187 .and_then(|s| s.to_str())
188 .unwrap_or(name)
189 .to_string();
190
191 let db_version = read_db_version(&db_path).unwrap_or_else(|_| "unknown".to_string());
192
193 backups.push(BackupInfo {
194 id,
195 path,
196 size_bytes: metadata.len(),
197 created_at,
198 db_version,
199 });
200 }
201
202 backups.sort_by(|a, b| b.created_at.cmp(&a.created_at));
204
205 Ok(backups)
206 })
207 .await
208 .map_err(|e| PunchError::Memory(format!("list backups task panicked: {e}")))?
209 }
210
211 pub async fn restore_backup(&self, backup_id: &str) -> PunchResult<()> {
216 let backups = self.list_backups().await?;
217 let backup = backups
218 .iter()
219 .find(|b| b.id == backup_id)
220 .ok_or_else(|| PunchError::Memory(format!("backup not found: {backup_id}")))?;
221
222 let backup_path = backup.path.clone();
223 let db_path = self.db_path.clone();
224
225 tokio::task::spawn_blocking(move || -> PunchResult<()> {
226 let source = if backup_path.extension().and_then(|e| e.to_str()) == Some("gz") {
227 decompress_backup(&backup_path)?
228 } else {
229 backup_path.clone()
230 };
231
232 verify_backup(&source)?;
234
235 std::fs::copy(&source, &db_path).map_err(|e| {
237 PunchError::Memory(format!("failed to restore backup: {e}"))
238 })?;
239
240 if source != backup_path {
242 let _ = std::fs::remove_file(&source);
243 }
244
245 info!(
246 backup = %backup_path.display(),
247 target = %db_path.display(),
248 "database restored from backup"
249 );
250
251 Ok(())
252 })
253 .await
254 .map_err(|e| PunchError::Memory(format!("restore task panicked: {e}")))?
255 }
256
257 pub async fn cleanup_old_backups(&self) -> PunchResult<usize> {
261 let backups = self.list_backups().await?;
262
263 if backups.len() <= self.max_backups {
264 return Ok(0);
265 }
266
267 let to_remove = &backups[self.max_backups..];
268 let mut removed = 0;
269
270 for backup in to_remove {
271 match std::fs::remove_file(&backup.path) {
272 Ok(()) => {
273 info!(path = %backup.path.display(), "removed old backup");
274 removed += 1;
275 }
276 Err(e) => {
277 warn!(
278 path = %backup.path.display(),
279 error = %e,
280 "failed to remove old backup"
281 );
282 }
283 }
284 }
285
286 Ok(removed)
287 }
288}
289
290fn read_db_version(path: &Path) -> PunchResult<String> {
296 let conn = Connection::open(path).map_err(|e| {
297 PunchError::Memory(format!("failed to open database for version check: {e}"))
298 })?;
299
300 let version: i64 = conn
301 .pragma_query_value(None, "user_version", |row| row.get(0))
302 .map_err(|e| PunchError::Memory(format!("failed to read user_version: {e}")))?;
303
304 Ok(version.to_string())
305}
306
307fn verify_backup(path: &Path) -> PunchResult<()> {
309 let conn = Connection::open(path).map_err(|e| {
310 PunchError::Memory(format!(
311 "failed to open backup for verification: {e}"
312 ))
313 })?;
314
315 let result: String = conn
316 .pragma_query_value(None, "integrity_check", |row| row.get(0))
317 .map_err(|e| {
318 PunchError::Memory(format!("integrity check failed: {e}"))
319 })?;
320
321 if result != "ok" {
322 return Err(PunchError::Memory(format!(
323 "backup integrity check failed: {result}"
324 )));
325 }
326
327 Ok(())
328}
329
330fn compress_backup(path: &Path) -> PunchResult<PathBuf> {
332 use flate2::write::GzEncoder;
333 use flate2::Compression;
334 use std::io::{Read, Write};
335
336 let gz_path = path.with_extension("db.gz");
337 let mut input = std::fs::File::open(path).map_err(|e| {
338 PunchError::Memory(format!("failed to open backup for compression: {e}"))
339 })?;
340
341 let output = std::fs::File::create(&gz_path).map_err(|e| {
342 PunchError::Memory(format!("failed to create compressed backup: {e}"))
343 })?;
344
345 let mut encoder = GzEncoder::new(output, Compression::default());
346 let mut buf = [0u8; 64 * 1024];
347
348 loop {
349 let n = input.read(&mut buf).map_err(|e| {
350 PunchError::Memory(format!("read error during compression: {e}"))
351 })?;
352 if n == 0 {
353 break;
354 }
355 encoder.write_all(&buf[..n]).map_err(|e| {
356 PunchError::Memory(format!("write error during compression: {e}"))
357 })?;
358 }
359
360 encoder.finish().map_err(|e| {
361 PunchError::Memory(format!("failed to finalize compressed backup: {e}"))
362 })?;
363
364 Ok(gz_path)
365}
366
367fn decompress_backup(gz_path: &Path) -> PunchResult<PathBuf> {
369 use flate2::read::GzDecoder;
370 use std::io::{Read, Write};
371
372 let stem = gz_path
373 .file_stem()
374 .and_then(|s| s.to_str())
375 .unwrap_or("backup");
376 let out_path = gz_path.with_file_name(format!("{}_restored.db", stem));
377
378 let input = std::fs::File::open(gz_path).map_err(|e| {
379 PunchError::Memory(format!("failed to open compressed backup: {e}"))
380 })?;
381
382 let mut decoder = GzDecoder::new(input);
383 let mut output = std::fs::File::create(&out_path).map_err(|e| {
384 PunchError::Memory(format!("failed to create decompressed file: {e}"))
385 })?;
386
387 let mut buf = [0u8; 64 * 1024];
388 loop {
389 let n = decoder.read(&mut buf).map_err(|e| {
390 PunchError::Memory(format!("read error during decompression: {e}"))
391 })?;
392 if n == 0 {
393 break;
394 }
395 output.write_all(&buf[..n]).map_err(|e| {
396 PunchError::Memory(format!("write error during decompression: {e}"))
397 })?;
398 }
399
400 Ok(out_path)
401}
402
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Creates a WAL-mode SQLite database at `path` whose `test` table holds
    /// one row (`id = 1`, `value = 'hello'`).
    fn create_test_db(path: &Path) {
        let conn = Connection::open(path).expect("create test db");
        conn.execute_batch(
            "PRAGMA journal_mode = WAL;
             CREATE TABLE test (id INTEGER PRIMARY KEY, value TEXT);
             INSERT INTO test (value) VALUES ('hello');",
        )
        .expect("init test db");
    }

    /// Reads `value` of row 1 in the `test` table of the database at `path`.
    fn read_test_value(path: &Path) -> String {
        let conn = Connection::open(path).expect("open db");
        conn.query_row("SELECT value FROM test WHERE id = 1", [], |row| row.get(0))
            .expect("query value")
    }

    /// Overwrites row 1 so that a restore has something to undo.
    fn change_test_value(path: &Path) {
        let conn = Connection::open(path).expect("open db");
        conn.execute("UPDATE test SET value = 'changed' WHERE id = 1", [])
            .expect("update value");
    }

    #[tokio::test]
    async fn create_backup_produces_a_file() {
        let dir = tempfile::tempdir().expect("tempdir");
        let db_path = dir.path().join("test.db");
        let backup_dir = dir.path().join("backups");

        create_test_db(&db_path);

        let mgr = BackupManager::new(db_path, backup_dir.clone());
        let info = mgr.create_backup().await.expect("create backup");

        assert!(info.path.exists());
        assert!(info.size_bytes > 0);
        assert!(info.id.starts_with("punch_backup_"));
    }

    #[tokio::test]
    async fn backup_is_valid_sqlite() {
        let dir = tempfile::tempdir().expect("tempdir");
        let db_path = dir.path().join("test.db");
        let backup_dir = dir.path().join("backups");

        create_test_db(&db_path);

        let mgr = BackupManager::new(db_path, backup_dir);
        let info = mgr.create_backup().await.expect("create backup");

        verify_backup(&info.path).expect("integrity check should pass");
        assert_eq!(read_test_value(&info.path), "hello");
    }

    #[tokio::test]
    async fn list_backups_returns_created_backups() {
        let dir = tempfile::tempdir().expect("tempdir");
        let db_path = dir.path().join("test.db");
        let backup_dir = dir.path().join("backups");

        create_test_db(&db_path);

        let mgr = BackupManager::new(db_path, backup_dir);

        // The backup directory does not exist yet: listing must not fail.
        let list = mgr.list_backups().await.expect("list");
        assert!(list.is_empty());

        mgr.create_backup().await.expect("backup 1");
        let list = mgr.list_backups().await.expect("list");
        assert_eq!(list.len(), 1);
    }

    #[tokio::test]
    async fn cleanup_removes_old_backups() {
        let dir = tempfile::tempdir().expect("tempdir");
        let db_path = dir.path().join("test.db");
        let backup_dir = dir.path().join("backups");

        create_test_db(&db_path);

        let mgr = BackupManager::new(db_path, backup_dir.clone()).with_max_backups(2);

        // Fabricate four backups directly so the test does not depend on the
        // one-second resolution of generated names.
        fs::create_dir_all(&backup_dir).expect("mkdir");
        for i in 0..4 {
            let path = backup_dir.join(format!("punch_backup_20260101_12000{}.db", i));
            let conn = Connection::open(&path).expect("create");
            conn.execute_batch("CREATE TABLE t (id INTEGER);")
                .expect("init");
        }

        let removed = mgr.cleanup_old_backups().await.expect("cleanup");
        assert_eq!(removed, 2);

        let remaining = mgr.list_backups().await.expect("list");
        assert_eq!(remaining.len(), 2);
    }

    #[tokio::test]
    async fn backup_naming_follows_pattern() {
        let dir = tempfile::tempdir().expect("tempdir");
        let db_path = dir.path().join("test.db");
        let backup_dir = dir.path().join("backups");

        create_test_db(&db_path);

        let mgr = BackupManager::new(db_path, backup_dir);
        let info = mgr.create_backup().await.expect("backup");

        let filename = info
            .path
            .file_name()
            .and_then(|n| n.to_str())
            .expect("filename");
        assert!(
            filename.starts_with("punch_backup_"),
            "expected punch_backup_ prefix, got: {filename}"
        );
        assert!(
            filename.ends_with(".db"),
            "expected .db suffix, got: {filename}"
        );
    }

    #[tokio::test]
    async fn compressed_backup_roundtrip() {
        let dir = tempfile::tempdir().expect("tempdir");
        let db_path = dir.path().join("test.db");
        let backup_dir = dir.path().join("backups");

        create_test_db(&db_path);

        let mgr = BackupManager::new(db_path.clone(), backup_dir).with_compression(true);
        let info = mgr.create_backup().await.expect("backup");

        let filename = info
            .path
            .file_name()
            .and_then(|n| n.to_str())
            .expect("filename");
        assert!(
            filename.ends_with(".db.gz"),
            "expected .db.gz suffix, got: {filename}"
        );

        let restored = decompress_backup(&info.path).expect("decompress");
        verify_backup(&restored).expect("integrity check on decompressed");
    }

    #[tokio::test]
    async fn restore_backup_roundtrip() {
        let dir = tempfile::tempdir().expect("tempdir");
        let db_path = dir.path().join("test.db");
        let backup_dir = dir.path().join("backups");

        create_test_db(&db_path);

        let mgr = BackupManager::new(db_path.clone(), backup_dir);
        let info = mgr.create_backup().await.expect("backup");

        change_test_value(&db_path);
        assert_eq!(read_test_value(&db_path), "changed");

        mgr.restore_backup(&info.id).await.expect("restore");
        assert_eq!(read_test_value(&db_path), "hello");
    }

    #[tokio::test]
    async fn restore_compressed_backup_roundtrip() {
        let dir = tempfile::tempdir().expect("tempdir");
        let db_path = dir.path().join("test.db");
        let backup_dir = dir.path().join("backups");

        create_test_db(&db_path);

        let mgr = BackupManager::new(db_path.clone(), backup_dir.clone()).with_compression(true);
        let info = mgr.create_backup().await.expect("backup");

        change_test_value(&db_path);
        mgr.restore_backup(&info.id).await.expect("restore");
        assert_eq!(read_test_value(&db_path), "hello");

        // The decompressed scratch copy must not outlive the restore.
        let leftovers: Vec<String> = fs::read_dir(&backup_dir)
            .expect("read backup dir")
            .filter_map(|entry| entry.ok())
            .filter_map(|entry| entry.file_name().into_string().ok())
            .filter(|name| name.ends_with("_restored.db"))
            .collect();
        assert!(leftovers.is_empty(), "scratch files left behind: {leftovers:?}");
    }

    #[tokio::test]
    async fn restore_unknown_backup_fails() {
        let dir = tempfile::tempdir().expect("tempdir");
        let db_path = dir.path().join("test.db");
        let backup_dir = dir.path().join("backups");

        create_test_db(&db_path);

        let mgr = BackupManager::new(db_path.clone(), backup_dir);
        assert!(mgr.restore_backup("punch_backup_missing").await.is_err());
        assert_eq!(read_test_value(&db_path), "hello");
    }
}