use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Instant;
use futures::StreamExt;
use object_store::limit::LimitStore;
use object_store::path::Path as StorePath;
use object_store::{ObjectStore, ObjectStoreExt};
use tokio::io::{AsyncWriteExt, BufWriter};
use tracing::{debug, info};
use crate::{BackupFile, Error, meta::BackupMeta};
pub async fn list_backup_ids(store: &dyn ObjectStore, prefix: &str) -> Result<Vec<u64>, Error> {
let meta_prefix = StorePath::from(prefix).join("meta");
let mut ids = Vec::new();
let mut stream = store.list(Some(&meta_prefix));
while let Some(item) = stream.next().await {
let item = item.map_err(|e| Error::List {
prefix: meta_prefix.clone(),
source: e,
})?;
if let Some(name) = item.location.filename()
&& !name.ends_with(".tmp")
{
if let Ok(id) = name.parse::<u64>() {
ids.push(id);
} else {
debug!(name, "ignoring invalid backup id");
}
}
}
ids.sort();
Ok(ids)
}
/// Default bound on concurrent downloads and in-flight object-store requests.
pub const DEFAULT_CONCURRENCY: usize = 64;

/// Options controlling a restore operation.
#[derive(Debug)]
pub struct RestoreOptions {
    // Backup to restore; `None` selects the latest available backup.
    pub backup_id: Option<u64>,
    // Maximum number of files downloaded concurrently.
    pub concurrency: usize,
    // When true, verify CRC32C checksums of downloaded files (where the
    // backup metadata provides one).
    pub verify: bool,
    // Directory for WAL (`*.log`) files; `None` places them in the target
    // directory alongside the other database files.
    pub wal_dir: Option<PathBuf>,
}

impl Default for RestoreOptions {
    // Defaults: latest backup, DEFAULT_CONCURRENCY downloads, verification on,
    // WALs co-located with the database files.
    fn default() -> Self {
        Self {
            backup_id: None,
            concurrency: DEFAULT_CONCURRENCY,
            verify: true,
            wal_dir: None,
        }
    }
}
async fn download_file(
store: Arc<dyn ObjectStore>,
store_prefix: StorePath,
file: &BackupFile,
target: PathBuf,
wal_dir: PathBuf,
verify: bool,
) -> Result<u64, Error> {
let name = db_filename(&file.path)?;
let mut key = store_prefix;
key.extend(&StorePath::from(file.path.as_str()));
let dest = if name.as_os_str() == "CURRENT" {
target.join("CURRENT.tmp")
} else if name.extension().is_some_and(|ext| ext == "log") {
wal_dir.join(&name)
} else {
target.join(&name)
};
let result = store.get(&key).await.map_err(|e| Error::Fetch {
key: key.clone(),
source: e,
})?;
let mut stream = result.into_stream();
let f = tokio::fs::File::create(&dest)
.await
.map_err(|e| Error::Io {
path: dest.clone(),
source: e,
})?;
let mut out = BufWriter::new(f);
let mut total_size = 0u64;
let mut crc = 0u32;
let do_crc = verify && file.crc32c.is_some();
while let Some(chunk) = stream.next().await {
let chunk = chunk.map_err(|e| Error::Fetch {
key: key.clone(),
source: e,
})?;
total_size += chunk.len() as u64;
if do_crc {
crc = crc32c::crc32c_append(crc, &chunk);
}
out.write_all(&chunk).await.map_err(|e| Error::Io {
path: dest.clone(),
source: e,
})?;
}
out.shutdown().await.map_err(|e| Error::Io {
path: dest.clone(),
source: e,
})?;
if let Some(expected) = file.size
&& total_size != expected
{
return Err(Error::SizeMismatch {
path: file.path.clone(),
expected,
actual: total_size,
});
}
if let Some(expected) = file.crc32c.filter(|_| verify)
&& crc != expected
{
return Err(Error::ChecksumMismatch {
path: file.path.clone(),
expected,
actual: crc,
});
}
Ok(total_size)
}
/// Restore a backup from `store` under `prefix` into the `target` directory.
///
/// Uses `opts.backup_id` when given, otherwise the latest available backup.
/// Database files land in `target`; WAL (`*.log`) files land in
/// `opts.wal_dir` when set, else in `target`. Downloads run concurrently,
/// bounded by `opts.concurrency`. The `CURRENT` file is downloaded as
/// `CURRENT.tmp` and renamed into place only after every other file has been
/// written, so an interrupted restore never looks like a complete database.
///
/// # Errors
///
/// Fails if there are no backups, the metadata cannot be fetched or parsed,
/// the backup contains excluded files (it is not self-contained), any
/// download fails, or a size/checksum mismatch is detected.
pub async fn restore(
    store: Arc<dyn ObjectStore>,
    prefix: &str,
    target: impl AsRef<Path>,
    opts: RestoreOptions,
) -> Result<(), Error> {
    let target = target.as_ref();
    let RestoreOptions {
        backup_id,
        concurrency,
        verify,
        wal_dir,
    } = opts;
    let wal_dir = wal_dir.as_deref().unwrap_or(target);
    // LimitStore caps in-flight requests against the underlying store,
    // matching the buffer_unordered bound below.
    let store: Arc<dyn ObjectStore> = Arc::new(LimitStore::new(store, concurrency));
    let id = match backup_id {
        Some(id) => id,
        None => {
            let ids = list_backup_ids(&*store, prefix).await?;
            *ids.last().ok_or(Error::NoBackups)?
        }
    };
    let meta = fetch_meta(&*store, prefix, id).await?;
    // A backup written with file exclusion enabled is missing data; refuse it.
    let excluded_count = meta.files.iter().filter(|f| f.excluded).count();
    if excluded_count > 0 {
        return Err(Error::ExcludedFiles {
            count: excluded_count,
        });
    }
    info!(
        backup_id = id,
        file_count = meta.files.len(),
        sequence_number = meta.sequence_number,
        "restoring backup"
    );
    tokio::fs::create_dir_all(target)
        .await
        .map_err(|e| Error::Io {
            path: target.to_path_buf(),
            source: e,
        })?;
    if wal_dir != target {
        tokio::fs::create_dir_all(wal_dir)
            .await
            .map_err(|e| Error::Io {
                path: wal_dir.to_path_buf(),
                source: e,
            })?;
    }
    let started = Instant::now();
    // Shared throughput math for progress and completion logs:
    // returns (MiB downloaded, elapsed seconds, MiB/s).
    let throughput = |bytes: u64| {
        let elapsed_secs = started.elapsed().as_secs_f64();
        let mb = bytes as f64 / 2_f64.powf(20.);
        let rate_mb_s = if elapsed_secs > 0.0 {
            mb / elapsed_secs
        } else {
            0.0
        };
        (mb, elapsed_secs, rate_mb_s)
    };
    let store_prefix = StorePath::from(prefix);
    let target = target.to_path_buf();
    let wal_dir = wal_dir.to_path_buf();
    let tasks = meta.files.iter().map(|f| {
        download_file(
            Arc::clone(&store),
            store_prefix.clone(),
            f,
            target.clone(),
            wal_dir.clone(),
            verify,
        )
    });
    let total = tasks.len();
    let mut completed = 0usize;
    let mut total_bytes = 0u64;
    let mut stream = futures::stream::iter(tasks).buffer_unordered(concurrency);
    while let Some(result) = stream.next().await {
        total_bytes += result?;
        completed += 1;
        // Log every 100 completions and once at the very end.
        if completed.is_multiple_of(100) || completed == total {
            let (mb, elapsed_secs, rate_mb_s) = throughput(total_bytes);
            info!(
                completed,
                total,
                downloaded_mb = format_args!("{mb:.1}"),
                elapsed_secs = format_args!("{elapsed_secs:.1}"),
                rate_mb_s = format_args!("{rate_mb_s:.1}"),
                "progress"
            );
        }
    }
    // Publish CURRENT last: its presence marks the restore as complete.
    // (Fixed: the rename arguments were corrupted to `¤t_tmp` /
    // `¤t_final` — an HTML-entity mangling of `&current_…` — which
    // does not compile.)
    let current_tmp = target.join("CURRENT.tmp");
    let current_final = target.join("CURRENT");
    tokio::fs::rename(&current_tmp, &current_final)
        .await
        .map_err(|e| Error::Io {
            path: current_final,
            source: e,
        })?;
    let (mb, elapsed_secs, rate_mb_s) = throughput(total_bytes);
    info!(
        total_files = total,
        total_mb = format_args!("{mb:.1}"),
        elapsed_secs = format_args!("{elapsed_secs:.1}"),
        rate_mb_s = format_args!("{rate_mb_s:.1}"),
        "restore complete"
    );
    Ok(())
}
pub async fn fetch_meta(
store: &dyn ObjectStore,
prefix: &str,
id: u64,
) -> Result<BackupMeta, Error> {
let key = StorePath::from(prefix).join("meta").join(id.to_string());
let data = store
.get(&key)
.await
.map_err(|e| Error::Fetch {
key: key.clone(),
source: e,
})?
.bytes()
.await
.map_err(|e| Error::Fetch { key, source: e })?;
let text = String::from_utf8(data.to_vec()).map_err(|e| Error::MetaEncoding {
backup_id: id,
source: e,
})?;
BackupMeta::parse(&text).map_err(|e| Error::MetaParse {
backup_id: id,
source: e,
})
}
pub(crate) fn db_filename(backup_path: &str) -> Result<PathBuf, Error> {
let sp = StorePath::from(backup_path);
let parts: Vec<_> = sp.parts().collect();
match parts.first().map(|p| p.as_ref()) {
Some("shared_checksum") => {
let filename = parts
.last()
.ok_or_else(|| Error::SharedChecksumNoExtension(backup_path.to_string()))?;
unmangle_shared_checksum(filename.as_ref())
}
Some("private") => {
parts
.get(2)
.map(|p| PathBuf::from(p.as_ref()))
.ok_or_else(|| Error::PrivatePathTooShort(backup_path.to_string()))
}
Some("shared") => parts
.last()
.map(|p| PathBuf::from(p.as_ref()))
.ok_or_else(|| Error::UnrecognizedPathPrefix(backup_path.to_string())),
_ => Err(Error::UnrecognizedPathPrefix(backup_path.to_string())),
}
}
/// Undo shared-checksum filename mangling: a name of the form
/// `<stem-prefix>_<suffix>.<ext>` becomes `<stem-prefix>.<ext>`, splitting at
/// the first underscore in the stem (e.g. `000007_123_456.sst` → `000007.sst`).
fn unmangle_shared_checksum(mangled: &str) -> Result<PathBuf, Error> {
    let path = Path::new(mangled);
    let ext = path
        .extension()
        .and_then(|e| e.to_str())
        .ok_or_else(|| Error::SharedChecksumNoExtension(mangled.to_string()))?;
    let stem = path
        .file_stem()
        .and_then(|s| s.to_str())
        .ok_or_else(|| Error::SharedChecksumNoExtension(mangled.to_string()))?;
    // Everything before the first underscore is the original file stem.
    let (original, _mangling) = stem
        .split_once('_')
        .ok_or_else(|| Error::SharedChecksumNoUnderscore(mangled.to_string()))?;
    Ok(PathBuf::from(format!("{original}.{ext}")))
}
#[cfg(test)]
mod tests {
    use super::*;
    use object_store::{PutPayload, memory::InMemory};

    // Store `data` at `path` in the in-memory object store.
    async fn put(store: &InMemory, path: &str, data: &[u8]) {
        use object_store::ObjectStoreExt;
        store
            .put(
                &StorePath::from(path),
                PutPayload::from_iter(data.iter().copied()),
            )
            .await
            .unwrap();
    }

    // Build a minimal metadata document: timestamp, sequence number, file
    // count, then one `<path> crc32 <crc>` line per file.
    fn build_meta(timestamp: u64, seq: u64, files: &[(&str, &[u8])]) -> String {
        let mut lines = vec![
            timestamp.to_string(),
            seq.to_string(),
            files.len().to_string(),
        ];
        for (path, data) in files {
            let crc = crc32c::crc32c(data);
            lines.push(format!("{path} crc32 {crc}"));
        }
        lines.join("\n")
    }

    // Upload a complete backup: its metadata plus every file payload.
    async fn add_backup(
        store: &InMemory,
        id: u64,
        timestamp: u64,
        seq: u64,
        files: &[(&str, &[u8])],
    ) {
        let meta = build_meta(timestamp, seq, files);
        put(store, &format!("meta/{id}"), meta.as_bytes()).await;
        for (path, data) in files {
            put(store, path, data).await;
        }
    }

    // Listing finds completed backups and skips in-progress `.tmp` metadata.
    #[tokio::test]
    async fn list_discovers_backups() {
        let store = InMemory::new();
        add_backup(&store, 1, 1000, 100, &[("private/1/CURRENT", b"M-1\n")]).await;
        add_backup(&store, 3, 3000, 300, &[("private/3/CURRENT", b"M-3\n")]).await;
        put(&store, "meta/5.tmp", b"in progress").await;
        let ids = list_backup_ids(&store, "").await.unwrap();
        assert_eq!(ids, vec![1, 3]);
    }

    // Listing honors a non-empty store prefix.
    #[tokio::test]
    async fn list_with_prefix() {
        let store = InMemory::new();
        let meta = build_meta(1000, 100, &[("pfx/private/1/CURRENT", b"M\n")]);
        put(&store, "pfx/meta/1", meta.as_bytes()).await;
        put(&store, "pfx/private/1/CURRENT", b"M\n").await;
        let ids = list_backup_ids(&store, "pfx").await.unwrap();
        assert_eq!(ids, vec![1]);
    }

    // An empty store yields an empty id list (not an error).
    #[tokio::test]
    async fn list_empty() {
        let store = InMemory::new();
        let ids = list_backup_ids(&store, "").await.unwrap();
        assert!(ids.is_empty());
    }

    // Restore routes each backup path kind to the right local filename,
    // including unmangling shared_checksum names and finalizing CURRENT.
    #[tokio::test]
    async fn places_files_correctly() {
        let store = InMemory::new();
        let current = b"MANIFEST-000008\n";
        let manifest = b"manifest-data-here";
        let options = b"options-data-here";
        let sst = b"sst-file-contents!";
        add_backup(
            &store,
            1,
            1000,
            100,
            &[
                ("private/1/CURRENT", current),
                ("private/1/MANIFEST-000008", manifest),
                ("private/1/OPTIONS-000009", options),
                ("shared_checksum/000007_123_456.sst", sst),
            ],
        )
        .await;
        let target = tempfile::tempdir().unwrap();
        let tp = target.path();
        restore(
            Arc::new(store),
            "",
            tp,
            RestoreOptions {
                backup_id: Some(1),
                ..Default::default()
            },
        )
        .await
        .unwrap();
        assert_eq!(std::fs::read(tp.join("CURRENT")).unwrap(), current);
        assert_eq!(std::fs::read(tp.join("MANIFEST-000008")).unwrap(), manifest);
        assert_eq!(std::fs::read(tp.join("OPTIONS-000009")).unwrap(), options);
        assert_eq!(std::fs::read(tp.join("000007.sst")).unwrap(), sst);
    }

    // `.log` files go to wal_dir when one is configured; others to target.
    #[tokio::test]
    async fn routes_wal_to_wal_dir() {
        let store = InMemory::new();
        add_backup(
            &store,
            1,
            1000,
            100,
            &[
                ("private/1/CURRENT", b"M-1\n"),
                ("private/1/000003.log", b"wal-data"),
            ],
        )
        .await;
        let target = tempfile::tempdir().unwrap();
        let wal = tempfile::tempdir().unwrap();
        restore(
            Arc::new(store),
            "",
            target.path(),
            RestoreOptions {
                backup_id: Some(1),
                wal_dir: Some(wal.path().to_path_buf()),
                ..Default::default()
            },
        )
        .await
        .unwrap();
        assert!(wal.path().join("000003.log").exists());
        assert!(!target.path().join("000003.log").exists());
        assert!(target.path().join("CURRENT").exists());
    }

    // With backup_id unset, the highest-numbered backup is restored.
    #[tokio::test]
    async fn defaults_to_latest() {
        let store = InMemory::new();
        add_backup(&store, 1, 1000, 100, &[("private/1/CURRENT", b"old\n")]).await;
        add_backup(&store, 5, 5000, 500, &[("private/5/CURRENT", b"new\n")]).await;
        let target = tempfile::tempdir().unwrap();
        restore(Arc::new(store), "", target.path(), Default::default())
            .await
            .unwrap();
        assert_eq!(
            std::fs::read(target.path().join("CURRENT")).unwrap(),
            b"new\n"
        );
    }

    // A backup whose metadata marks a file as excluded is not self-contained
    // and must be rejected before any download happens.
    #[tokio::test]
    async fn rejects_excluded_files() {
        let store = InMemory::new();
        let current = b"M-1\n";
        let crc = crc32c::crc32c(current);
        let meta = format!(
            "schema_version 2.1\n1000\n100\n2\n\
             private/1/CURRENT crc32 {crc}\n\
             shared_checksum/000099_123_456.sst crc32 999 ni::excluded true"
        );
        put(&store, "meta/1", meta.as_bytes()).await;
        put(&store, "private/1/CURRENT", current).await;
        let target = tempfile::tempdir().unwrap();
        let result = restore(
            Arc::new(store),
            "",
            target.path(),
            RestoreOptions {
                backup_id: Some(1),
                ..Default::default()
            },
        )
        .await;
        assert!(matches!(result, Err(Error::ExcludedFiles { count: 1 })));
    }

    // Verification succeeds when stored data matches the recorded checksum.
    #[tokio::test]
    async fn verify_passes_with_correct_crc() {
        let store = InMemory::new();
        add_backup(&store, 1, 1000, 100, &[("private/1/CURRENT", b"data\n")]).await;
        let target = tempfile::tempdir().unwrap();
        restore(
            Arc::new(store),
            "",
            target.path(),
            RestoreOptions {
                backup_id: Some(1),
                verify: true,
                ..Default::default()
            },
        )
        .await
        .unwrap();
    }

    // A recorded size that disagrees with the downloaded byte count fails.
    #[tokio::test]
    async fn size_mismatch_detected() {
        let store = InMemory::new();
        let content = b"hello";
        let crc = crc32c::crc32c(content);
        let meta = format!(
            "schema_version 2.1\n1000\n100\n1\n\
             private/1/CURRENT crc32 {crc} size 999"
        );
        put(&store, "meta/1", meta.as_bytes()).await;
        put(&store, "private/1/CURRENT", content).await;
        let target = tempfile::tempdir().unwrap();
        let result = restore(
            Arc::new(store),
            "",
            target.path(),
            RestoreOptions {
                backup_id: Some(1),
                ..Default::default()
            },
        )
        .await;
        assert!(matches!(
            result,
            Err(Error::SizeMismatch {
                expected: 999,
                actual: 5,
                ..
            })
        ));
    }

    // Tampered content is caught by CRC verification.
    #[tokio::test]
    async fn verify_catches_mismatch() {
        let store = InMemory::new();
        let meta = build_meta(1000, 100, &[("private/1/CURRENT", b"correct data")]);
        put(&store, "meta/1", meta.as_bytes()).await;
        put(&store, "private/1/CURRENT", b"tampered data").await;
        let target = tempfile::tempdir().unwrap();
        let result = restore(
            Arc::new(store),
            "",
            target.path(),
            RestoreOptions {
                backup_id: Some(1),
                verify: true,
                ..Default::default()
            },
        )
        .await;
        assert!(matches!(result, Err(Error::ChecksumMismatch { .. })));
    }

    // Restoring from an empty store reports NoBackups.
    #[tokio::test]
    async fn no_backups() {
        let store = InMemory::new();
        let target = tempfile::tempdir().unwrap();
        let result = restore(Arc::new(store), "", target.path(), Default::default()).await;
        assert!(matches!(result, Err(Error::NoBackups)));
    }

    // fetch_meta round-trips timestamp, sequence number, and file entries.
    #[tokio::test]
    async fn fetch_meta_parses() {
        let store = InMemory::new();
        add_backup(
            &store,
            1,
            1000,
            100,
            &[
                ("private/1/CURRENT", b"MANIFEST-1\n"),
                ("shared_checksum/000007_123_456.sst", b"sst-data"),
            ],
        )
        .await;
        let meta = fetch_meta(&store, "", 1).await.unwrap();
        assert_eq!(meta.timestamp, 1000);
        assert_eq!(meta.sequence_number, 100);
        assert_eq!(meta.files.len(), 2);
    }

    // shared_checksum names are unmangled to the original filename.
    #[test]
    fn shared_checksum() {
        assert_eq!(
            db_filename("shared_checksum/000007_2894567812_590.sst").unwrap(),
            Path::new("000007.sst")
        );
    }

    // private paths map to their trailing filename component.
    #[test]
    fn private() {
        assert_eq!(
            db_filename("private/1/MANIFEST-000008").unwrap(),
            Path::new("MANIFEST-000008")
        );
    }

    // shared paths map to their trailing filename component.
    #[test]
    fn shared() {
        assert_eq!(
            db_filename("shared/000007.sst").unwrap(),
            Path::new("000007.sst")
        );
    }

    // Unknown top-level prefixes are rejected.
    #[test]
    fn unrecognized_prefix() {
        assert!(matches!(
            db_filename("unknown/file.sst"),
            Err(Error::UnrecognizedPathPrefix(_))
        ));
    }

    // A private path missing its filename component is rejected.
    #[test]
    fn private_too_short() {
        assert!(matches!(
            db_filename("private/1"),
            Err(Error::PrivatePathTooShort(_))
        ));
    }
}