use crate::cli::CliOutput;
use crate::db;
use anyhow::{Context, Result};
use clap::Args;
use std::path::{Path, PathBuf};
/// Returns the file name of the manifest that accompanies a snapshot whose file stem is
/// `stem`, i.e. `stem` followed by `.manifest.json`.
fn manifest_file_name(stem: &str) -> String {
    const SUFFIX: &str = ".manifest.json";
    let mut name = String::with_capacity(stem.len() + SUFFIX.len());
    name.push_str(stem);
    name.push_str(SUFFIX);
    name
}
/// Format pattern handed to chrono's `format` to render the UTC timestamp that is embedded in
/// snapshot file names and in the name of the "pre-restore" copy of the live database
/// (renders like `2024-05-01T093000Z`).
// NOTE(review): the time part has no `:` separators — presumably so the result is a valid
// file name on every platform; confirm before changing the pattern.
const BACKUP_TS_FMT: &str = "%Y-%m-%dT%H%M%SZ";
// Command-line arguments of the backup command; consumed by `run_backup`.
// Plain `//` comments are used on purpose: clap's derive macros turn `///` doc comments into
// `--help` text, so doc comments here would change user-visible output.
#[derive(Args)]
pub struct BackupArgs {
// `--to`: directory that receives the snapshot and its manifest; `run_backup` creates it if
// it does not exist. Defaults to `./backups`.
#[arg(long, default_value = "./backups")]
pub to: PathBuf,
// `--keep`: how many of the most recently modified snapshots survive pruning after a backup.
// `0` disables pruning entirely. Defaults to 48.
#[arg(long, default_value_t = 48)]
pub keep: usize,
}
// Command-line arguments of the restore command; consumed by `run_restore`.
// Plain `//` comments are used on purpose: clap's derive macros turn `///` doc comments into
// `--help` text, so doc comments here would change user-visible output.
#[derive(Args)]
pub struct RestoreArgs {
// `--from` (no default): either a snapshot file, or a directory — in which case
// `run_restore` picks the most recently modified `ai-memory-*.db` file inside it.
#[arg(long)]
pub from: PathBuf,
// `--skip-verify`: restore without requiring a manifest and without comparing the
// snapshot's SHA-256 against it.
#[arg(long)]
pub skip_verify: bool,
}
/// Metadata describing one snapshot. `run_backup` writes it as JSON next to the snapshot and
/// `run_restore` reads it back to verify the snapshot before restoring.
#[derive(serde::Serialize, serde::Deserialize)]
pub struct BackupManifest {
/// File name (not the full path) of the snapshot this manifest belongs to.
pub snapshot: String,
/// Lowercase hexadecimal SHA-256 digest of the snapshot file's contents.
pub sha256: String,
/// Size of the snapshot file in bytes.
pub bytes: u64,
/// Path of the database the snapshot was taken from (lossily converted to UTF-8).
pub source_db: String,
/// Value of `crate::PKG_VERSION` at the time the backup was made.
pub version: String,
/// RFC 3339 UTC timestamp taken when the manifest was built.
pub created_at: String,
}
pub fn run_backup(
db_path: &Path,
args: &BackupArgs,
json_out: bool,
out: &mut CliOutput<'_>,
) -> Result<()> {
use std::io::Read;
std::fs::create_dir_all(&args.to)
.with_context(|| format!("creating backup dir {}", args.to.display()))?;
let conn = db::open(db_path).context("opening source DB for backup")?;
let ts = chrono::Utc::now().format(BACKUP_TS_FMT).to_string();
let snapshot_name = format!("ai-memory-{ts}.db");
let snapshot_path = args.to.join(&snapshot_name);
if snapshot_path.exists() {
anyhow::bail!(
"refusing to overwrite existing snapshot {}",
snapshot_path.display()
);
}
conn.execute(
"VACUUM INTO ?1",
rusqlite::params![snapshot_path.to_string_lossy()],
)
.context("VACUUM INTO failed")?;
drop(conn);
let bytes = std::fs::metadata(&snapshot_path)?.len();
let sha = {
use sha2::Digest;
let mut hasher = sha2::Sha256::new();
let mut f = std::fs::File::open(&snapshot_path)?;
let mut buf = vec![0u8; 64 * 1024];
loop {
let n = f.read(&mut buf)?;
if n == 0 {
break;
}
hasher.update(&buf[..n]);
}
format!("{:x}", hasher.finalize())
};
let manifest = BackupManifest {
snapshot: snapshot_name.clone(),
sha256: sha.clone(),
bytes,
source_db: db_path.to_string_lossy().into_owned(),
version: crate::PKG_VERSION.to_string(),
created_at: chrono::Utc::now().to_rfc3339(),
};
let manifest_path = args.to.join(format!("ai-memory-{ts}.manifest.json"));
let manifest_text = serde_json::to_string_pretty(&manifest)?;
std::fs::write(&manifest_path, manifest_text.as_bytes())?;
if args.keep > 0 {
prune_old_snapshots(&args.to, args.keep)?;
}
if json_out {
writeln!(out.stdout, "{}", serde_json::to_string(&manifest)?)?;
} else {
writeln!(out.stdout, "Snapshot: {}", snapshot_path.display())?;
writeln!(out.stdout, "Manifest: {}", manifest_path.display())?;
writeln!(out.stdout, "SHA-256 : {sha}")?;
writeln!(out.stdout, "Bytes : {bytes}")?;
}
Ok(())
}
fn prune_old_snapshots(dir: &Path, keep: usize) -> Result<()> {
let mut snaps: Vec<(std::time::SystemTime, PathBuf)> = std::fs::read_dir(dir)?
.filter_map(std::result::Result::ok)
.filter_map(|entry| {
let path = entry.path();
let name = path.file_name()?.to_str()?.to_owned();
let is_snapshot = name.starts_with("ai-memory-")
&& path
.extension()
.is_some_and(|ext| ext.eq_ignore_ascii_case("db"));
if is_snapshot {
let mtime = entry.metadata().ok()?.modified().ok()?;
Some((mtime, path))
} else {
None
}
})
.collect();
snaps.sort_by_key(|b| std::cmp::Reverse(b.0));
for (_, path) in snaps.into_iter().skip(keep) {
let _ = std::fs::remove_file(&path);
if let Some(stem) = path.file_stem().and_then(|s| s.to_str()) {
let manifest = dir.join(manifest_file_name(stem));
let _ = std::fs::remove_file(manifest);
}
}
Ok(())
}
pub fn run_restore(
db_path: &Path,
args: &RestoreArgs,
json_out: bool,
out: &mut CliOutput<'_>,
) -> Result<()> {
use std::io::Read;
let (snapshot_path, manifest_path) = if args.from.is_dir() {
let mut snaps: Vec<(std::time::SystemTime, PathBuf)> = std::fs::read_dir(&args.from)?
.filter_map(std::result::Result::ok)
.filter_map(|entry| {
let path = entry.path();
let name = path.file_name()?.to_str()?.to_owned();
let is_snapshot = name.starts_with("ai-memory-")
&& path
.extension()
.is_some_and(|ext| ext.eq_ignore_ascii_case("db"));
if is_snapshot {
let mtime = entry.metadata().ok()?.modified().ok()?;
Some((mtime, path))
} else {
None
}
})
.collect();
snaps.sort_by_key(|b| std::cmp::Reverse(b.0));
let snap = snaps
.into_iter()
.next()
.map(|(_, p)| p)
.ok_or_else(|| anyhow::anyhow!("no snapshots found in {}", args.from.display()))?;
let stem = snap.file_stem().and_then(|s| s.to_str()).unwrap_or("");
let manifest = args.from.join(manifest_file_name(stem));
(snap, manifest)
} else {
let snap = args.from.clone();
let stem = snap.file_stem().and_then(|s| s.to_str()).unwrap_or("");
let parent = snap.parent().unwrap_or_else(|| Path::new("."));
let manifest = parent.join(manifest_file_name(stem));
(snap, manifest)
};
if !snapshot_path.exists() {
anyhow::bail!("snapshot {} does not exist", snapshot_path.display());
}
if !args.skip_verify {
if !manifest_path.exists() {
anyhow::bail!(
"manifest {} not found; pass --skip-verify to restore anyway",
manifest_path.display()
);
}
let manifest_text = std::fs::read_to_string(&manifest_path)?;
let manifest: BackupManifest = serde_json::from_str(&manifest_text)
.with_context(|| format!("parsing manifest {}", manifest_path.display()))?;
let observed = {
use sha2::Digest;
let mut hasher = sha2::Sha256::new();
let mut f = std::fs::File::open(&snapshot_path)?;
let mut buf = vec![0u8; 64 * 1024];
loop {
let n = f.read(&mut buf)?;
if n == 0 {
break;
}
hasher.update(&buf[..n]);
}
format!("{:x}", hasher.finalize())
};
if observed != manifest.sha256 {
anyhow::bail!(
"sha256 mismatch — manifest says {}, snapshot is {}",
manifest.sha256,
observed
);
}
}
if db_path.exists() {
let ts = chrono::Utc::now().format(BACKUP_TS_FMT).to_string();
let aside = db_path.with_extension(format!("pre-restore-{ts}.db"));
std::fs::rename(db_path, &aside)
.with_context(|| format!("moving current DB aside to {}", aside.display()))?;
if !json_out {
writeln!(out.stdout, "Previous DB moved to {}", aside.display())?;
}
}
std::fs::copy(&snapshot_path, db_path)
.with_context(|| format!("copying snapshot to {}", db_path.display()))?;
if json_out {
writeln!(
out.stdout,
"{}",
serde_json::json!({
"status": "restored",
"from": snapshot_path.to_string_lossy(),
"to": db_path.to_string_lossy(),
})
)?;
} else {
writeln!(
out.stdout,
"Restored {} → {}",
snapshot_path.display(),
db_path.display()
)?;
}
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::cli::test_utils::{TestEnv, seed_memory};

    /// Fresh environment with one seeded memory, plus the path of a (not yet created)
    /// backup directory named `dir_name` next to the test database.
    fn seeded_env(dir_name: &str, title: &str, content: &str) -> (TestEnv, std::path::PathBuf) {
        let env = TestEnv::fresh();
        seed_memory(&env.db_path, "ns", title, content);
        let backup_dir = env.db_path.parent().unwrap().join(dir_name);
        (env, backup_dir)
    }

    /// Runs `run_backup` against the environment's database and unwraps the result.
    fn backup(env: &mut TestEnv, dir: &std::path::Path, keep: usize, json: bool) {
        let db = env.db_path.clone();
        let args = BackupArgs {
            to: dir.to_path_buf(),
            keep,
        };
        let mut out = env.output();
        run_backup(&db, &args, json, &mut out).unwrap();
    }

    /// Runs `run_restore` against the environment's database and hands back its result.
    fn restore(
        env: &mut TestEnv,
        from: std::path::PathBuf,
        skip_verify: bool,
        json: bool,
    ) -> anyhow::Result<()> {
        let db = env.db_path.clone();
        let args = RestoreArgs { from, skip_verify };
        let mut out = env.output();
        run_restore(&db, &args, json, &mut out)
    }

    /// Discards everything captured so far on stdout and stderr.
    fn reset_output(env: &mut TestEnv) {
        env.stdout.clear();
        env.stderr.clear();
    }

    /// Parses the captured stdout as the manifest printed by a JSON-mode backup.
    fn manifest_from_stdout(env: &mut TestEnv) -> BackupManifest {
        serde_json::from_str(env.stdout_str().trim()).unwrap()
    }

    /// Manifest path for `snapshot` inside `dir`; the suffix is spelled out on purpose so
    /// the tests pin the on-disk naming.
    fn manifest_path_for(dir: &std::path::Path, snapshot: &str) -> std::path::PathBuf {
        let stem = std::path::Path::new(snapshot)
            .file_stem()
            .unwrap()
            .to_string_lossy();
        dir.join(format!("{stem}.manifest.json"))
    }

    /// Names of all entries in `dir`.
    fn file_names(dir: &std::path::Path) -> Vec<String> {
        std::fs::read_dir(dir)
            .unwrap()
            .flatten()
            .map(|entry| entry.file_name().to_string_lossy().into_owned())
            .collect()
    }

    /// True for names that look like a snapshot file.
    fn looks_like_snapshot(name: &str) -> bool {
        name.starts_with("ai-memory-") && name.ends_with(".db")
    }

    #[test]
    fn test_backup_happy_path_creates_snapshot_and_manifest() {
        let (mut env, backup_dir) = seeded_env("backups-x1", "t", "c");
        backup(&mut env, &backup_dir, 48, false);

        let names = file_names(&backup_dir);
        let snap_count = names.iter().filter(|n| looks_like_snapshot(n)).count();
        let manifest_count = names
            .iter()
            .filter(|n| n.ends_with(".manifest.json"))
            .count();
        assert!(snap_count >= 1, "expected at least one snapshot");
        assert!(manifest_count >= 1, "expected at least one manifest");
        assert!(env.stdout_str().contains("Snapshot:"));
    }

    #[test]
    fn test_backup_json_emits_manifest_with_sha256() {
        let (mut env, backup_dir) = seeded_env("backups-x2", "t", "c");
        backup(&mut env, &backup_dir, 48, true);

        let v: serde_json::Value = serde_json::from_str(env.stdout_str().trim()).unwrap();
        assert!(v["sha256"].is_string());
        // A hex-encoded SHA-256 digest is 64 characters long.
        assert_eq!(v["sha256"].as_str().unwrap().len(), 64);
    }

    #[test]
    fn test_restore_from_directory_picks_newest() {
        let (mut env, backup_dir) = seeded_env("backups-x3", "before-backup", "stuff");
        backup(&mut env, &backup_dir, 48, false);
        reset_output(&mut env);

        restore(&mut env, backup_dir, false, false).unwrap();
        assert!(env.stdout_str().contains("Restored"));
    }

    #[test]
    fn test_restore_from_explicit_file_path() {
        let (mut env, backup_dir) = seeded_env("backups-x4", "t", "c");
        backup(&mut env, &backup_dir, 48, true);
        let manifest = manifest_from_stdout(&mut env);
        let snap_path = backup_dir.join(&manifest.snapshot);
        reset_output(&mut env);

        restore(&mut env, snap_path, false, true).unwrap();
        let v: serde_json::Value = serde_json::from_str(env.stdout_str().trim()).unwrap();
        assert_eq!(v["status"].as_str().unwrap(), "restored");
    }

    #[test]
    fn test_restore_with_skip_verify_succeeds_without_manifest() {
        let (mut env, backup_dir) = seeded_env("backups-x5", "t", "c");
        backup(&mut env, &backup_dir, 48, true);
        let manifest = manifest_from_stdout(&mut env);
        let snap_path = backup_dir.join(&manifest.snapshot);

        // Without the manifest only `--skip-verify` lets the restore through.
        std::fs::remove_file(manifest_path_for(&backup_dir, &manifest.snapshot)).unwrap();
        reset_output(&mut env);

        restore(&mut env, snap_path, true, false).unwrap();
        assert!(env.stdout_str().contains("Restored"));
    }

    #[test]
    fn test_restore_bad_sha256_errors() {
        let (mut env, backup_dir) = seeded_env("backups-x6", "t", "c");
        backup(&mut env, &backup_dir, 48, true);

        // Overwrite the manifest with a digest that cannot match the snapshot.
        let mut bad = manifest_from_stdout(&mut env);
        let manifest_path = manifest_path_for(&backup_dir, &bad.snapshot);
        bad.sha256 = "0000000000000000000000000000000000000000000000000000000000000000".to_string();
        std::fs::write(&manifest_path, serde_json::to_string(&bad).unwrap()).unwrap();

        let snap_path = backup_dir.join(&bad.snapshot);
        let res = restore(&mut env, snap_path, false, false);
        assert!(res.is_err());
        assert!(res.unwrap_err().to_string().contains("sha256 mismatch"));
    }

    #[test]
    fn test_backup_retention_prunes_old_snapshots() {
        let (mut env, backup_dir) = seeded_env("backups-x7", "t", "c");
        for _ in 0..3 {
            // Snapshot names have one-second resolution; wait so each run gets its own name.
            std::thread::sleep(std::time::Duration::from_secs(1));
            backup(&mut env, &backup_dir, 1, true);
            reset_output(&mut env);
        }

        let remaining = file_names(&backup_dir)
            .iter()
            .filter(|n| looks_like_snapshot(n))
            .count();
        assert_eq!(remaining, 1, "retention should keep exactly 1 snapshot");
    }
}