use std::collections::{HashMap, HashSet};
use std::fs::{self, File, OpenOptions};
use std::io::{self, Write};
use std::path::{Path, PathBuf};
use crate::artifact::{ExportManifest, ExportStage, verify_and_stage_import};
use crate::kv::{KvEntry, KvUpdate, VersionToken, WatchCursor};
/// Magic bytes that open every snapshot log file.
const MAGIC: &[u8; 4] = b"PGSS";
/// On-disk format generation written after the magic; `replay_log` rejects any other value.
const FORMAT_VERSION: u16 = 2;
/// File header size: the magic followed by the little-endian `u16` format version.
const HEADER_LEN: usize = MAGIC.len() + std::mem::size_of::<u16>();
/// Record type byte for a key/value put.
const REC_PUT: u8 = 0x01;
/// Record type byte for a delete (also used for purges).
const REC_DELETE: u8 = 0x02;
/// Record type byte for a watch-cursor checkpoint.
const REC_CURSOR: u8 = 0x03;
/// Smallest possible record: CRC (4) + type (1) + cursor length byte (1), i.e. an empty cursor.
const MIN_CURSOR_RECORD: usize = std::mem::size_of::<u32>() + 1 + 1;
/// Errors produced while reading, writing, compacting or importing a snapshot log.
#[derive(Debug, thiserror::Error)]
pub enum SnapshotError {
    /// Underlying filesystem failure.
    #[error("snapshot I/O error: {0}")]
    Io(#[from] io::Error),
    /// Structurally invalid data (bad header, unknown record type, oversized field, ...).
    #[error("invalid snapshot format: {0}")]
    InvalidFormat(String),
    /// A record's CRC did not match and it could not be treated as a torn tail.
    #[error("snapshot corrupted (CRC mismatch)")]
    Corrupted,
    /// Error reported by a storage backend.
    #[error("snapshot backend error: {0}")]
    Backend(String),
    /// An export/import artifact failed validation.
    #[error("invalid snapshot artifact: {0}")]
    ArtifactInvalid(String),
}
/// A replayed snapshot: the checkpointed watch cursor plus the live key/value entries.
///
/// In this file it is built by [`load`], which keys `entries` by each entry's `key`.
#[derive(Debug)]
pub struct Snapshot {
    /// Cursor taken from the last cursor record replayed (`WatchCursor::none()` if none).
    pub cursor: WatchCursor,
    /// Live entries, keyed by `KvEntry::key`.
    pub entries: HashMap<String, KvEntry>,
}

impl Snapshot {
    /// Returns the snapshot keys that do not appear in `current_keys`.
    ///
    /// The returned slices borrow from `self` only. Their lifetime is independent of
    /// `current_keys`, so a short-lived key list (e.g. built from a temporary `Vec<String>`)
    /// does not constrain how long the result can be kept. The order of the result is
    /// unspecified (it follows `HashMap` iteration order).
    pub fn stale_keys<'k, I>(&self, current_keys: I) -> Vec<&str>
    where
        I: IntoIterator<Item = &'k str>,
    {
        let current: HashSet<&str> = current_keys.into_iter().collect();
        self.entries
            .keys()
            .map(String::as_str)
            .filter(|k| !current.contains(*k))
            .collect()
    }
}
/// Append-side handle for a snapshot log file.
pub struct SnapshotWriter {
    /// Location of the log; re-read and replaced by `compact`.
    path: PathBuf,
    /// Buffered append handle; `None` means a prior `compact` could not reopen the file.
    writer: Option<io::BufWriter<File>>,
    /// Record bytes appended since the log was last compacted (or opened).
    bytes_since_compact: u64,
    /// `checkpoint` reports that compaction is due once `bytes_since_compact` exceeds this.
    compact_threshold: u64,
}
impl SnapshotWriter {
    /// Opens the log at `path` for appending, creating it with a fresh header if needed.
    ///
    /// `compact_threshold` is the number of record bytes after which [`Self::checkpoint`]
    /// starts returning `true`.
    ///
    /// A pre-existing file shorter than the header (for example, a crash while the header
    /// was being written) is truncated and re-initialised. Appending a header after such a
    /// fragment would leave the file with bad magic and make it permanently unloadable.
    ///
    /// The header of a longer existing file is not validated, and a torn record tail is not
    /// repaired here. `AppendLogSnapshot::open` runs [`load`] first, which does both.
    ///
    /// # Errors
    /// I/O errors from opening or truncating the file, or from writing the header.
    pub fn open(path: &Path, compact_threshold: u64) -> Result<Self, SnapshotError> {
        let file = OpenOptions::new().create(true).append(true).open(path)?;
        let existing_len = file.metadata()?.len();
        let header_len = HEADER_LEN as u64;
        if existing_len > 0 && existing_len < header_len {
            // Torn header: drop the fragment so the header below lands at offset 0. Truncate
            // through a separate write handle, since an append-only handle is not guaranteed
            // to permit truncation on every platform. Appends always go to the current EOF.
            OpenOptions::new().write(true).open(path)?.set_len(0)?;
        }
        let mut writer = io::BufWriter::new(file);
        let bytes_since_compact = if existing_len >= header_len {
            existing_len - header_len
        } else {
            writer.write_all(MAGIC)?;
            writer.write_all(&FORMAT_VERSION.to_le_bytes())?;
            writer.flush()?;
            0
        };
        Ok(Self {
            path: path.to_path_buf(),
            writer: Some(writer),
            bytes_since_compact,
            compact_threshold,
        })
    }

    /// Returns the append handle, or an I/O error if a prior `compact` left it unset.
    fn writer(&mut self) -> Result<&mut io::BufWriter<File>, SnapshotError> {
        self.writer.as_mut().ok_or_else(|| {
            SnapshotError::Io(io::Error::other(
                "snapshot writer poisoned: a prior compact() failed to reopen the log for append",
            ))
        })
    }

    /// Appends one put or delete record (purges are logged as deletes) to the buffer.
    ///
    /// The record is not flushed; call [`Self::checkpoint`] or [`Self::flush`].
    #[must_use = "I/O errors mean the write was lost"]
    pub fn write_update(&mut self, update: &KvUpdate) -> Result<(), SnapshotError> {
        let w = self.writer()?;
        let bytes = match update {
            KvUpdate::Put(entry) => write_put_record(w, &entry.key, &entry.value, &entry.version)?,
            KvUpdate::Delete { key, version } | KvUpdate::Purge { key, version } => {
                write_delete_record(w, key, version)?
            }
        };
        self.bytes_since_compact += bytes as u64;
        Ok(())
    }

    /// Appends a cursor record and flushes the buffer to the file.
    ///
    /// Returns `true` once the bytes appended since the last compaction exceed the threshold.
    #[must_use = "returns true when compaction is needed"]
    pub fn checkpoint(&mut self, cursor: &WatchCursor) -> Result<bool, SnapshotError> {
        let w = self.writer()?;
        let bytes = write_cursor_record(w, cursor)?;
        w.flush()?;
        self.bytes_since_compact += bytes as u64;
        Ok(self.bytes_since_compact > self.compact_threshold)
    }

    /// Flushes buffered records to the file.
    #[must_use = "I/O errors mean the flush failed"]
    pub fn flush(&mut self) -> Result<(), SnapshotError> {
        self.writer()?.flush()?;
        Ok(())
    }

    /// Rewrites the log in compacted form and reopens it for appending.
    ///
    /// Buffered records are flushed first so uncheckpointed writes survive. If compaction
    /// fails, the old handle is kept. If the compacted file cannot be reopened, the writer is
    /// left poisoned and every later call returns an error.
    #[must_use = "compaction errors leave the log uncompacted"]
    pub fn compact(&mut self) -> Result<(), SnapshotError> {
        self.writer()?.flush()?;
        let data = fs::read(&self.path)?;
        let (entries, cursor, _already_compact) = replay_log(&data)?;
        compact_to_file(&self.path, &entries, &cursor)?;
        // The old handle points at the file that was just replaced; drop it before reopening.
        self.writer = None;
        let file = OpenOptions::new().append(true).open(&self.path)?;
        self.writer = Some(io::BufWriter::new(file));
        self.bytes_since_compact = 0;
        Ok(())
    }
}
/// Replays the snapshot log at `path`.
///
/// Returns `Ok(None)` when the file does not exist, or when it holds neither entries nor a
/// cursor. When replay reports the log as not already compact (redundant records or a torn
/// tail), the file is rewritten in compacted form before the snapshot is returned.
///
/// # Errors
/// I/O errors other than "not found", plus any format/corruption error from replay or
/// compaction.
pub fn load(path: &Path) -> Result<Option<Snapshot>, SnapshotError> {
    let data = match fs::read(path) {
        Err(err) if err.kind() == io::ErrorKind::NotFound => return Ok(None),
        other => other?,
    };
    let (entries, cursor, already_compact) = replay_log(&data)?;
    let has_nothing = cursor.is_none() && entries.is_empty();
    if has_nothing {
        return Ok(None);
    }
    if !already_compact {
        // Self-heal: drop redundant records and any torn tail.
        compact_to_file(path, &entries, &cursor)?;
    }
    Ok(Some(Snapshot { cursor, entries }))
}
/// A persistent store holding the folded KV state together with the watch cursor it reflects.
pub trait SnapshotStore: Sized + Send {
    /// Opens the store at `path`, returning its cursor alongside the store.
    fn load(path: &Path) -> Result<(WatchCursor, Self), SnapshotError>;

    /// Applies a batch of updates and records `cursor` as the position they bring the store to.
    fn apply(&mut self, batch: &[KvUpdate], cursor: &WatchCursor) -> Result<(), SnapshotError>;

    /// Looks up a single key.
    fn get(&self, key: &str) -> Result<Option<KvEntry>, SnapshotError>;

    /// Returns every entry whose key starts with `prefix`.
    fn range(&self, prefix: &str) -> Result<Vec<KvEntry>, SnapshotError>;

    /// Calls `f` for each entry returned by [`Self::range`], stopping at the first error.
    ///
    /// The default collects `range` up front; backends may override it to stream instead.
    fn for_each_in_range(
        &self,
        prefix: &str,
        f: impl FnMut(KvEntry) -> Result<(), SnapshotError>,
    ) -> Result<(), SnapshotError> {
        self.range(prefix)?.into_iter().try_for_each(f)
    }

    /// Returns the cursor the store currently reflects.
    fn cursor(&self) -> WatchCursor;

    /// Writes an export artifact into `dest_dir` and returns its manifest.
    fn export_to(&mut self, dest_dir: &Path) -> Result<ExportManifest, SnapshotError>;
}
/// Default compaction threshold for [`AppendLogSnapshot`]: 10 MiB of appended record bytes.
pub const DEFAULT_COMPACT_THRESHOLD: u64 = 10 << 20;
/// [`SnapshotStore`] backed by an append-only log with an in-memory copy of the folded state.
pub struct AppendLogSnapshot {
    /// Appends updates and checkpoints to the on-disk log.
    writer: SnapshotWriter,
    /// In-memory fold of the log, keyed by `KvEntry::key`.
    entries: HashMap<String, KvEntry>,
    /// Cursor of the most recent applied batch (or the cursor loaded from disk).
    cursor: WatchCursor,
}
impl AppendLogSnapshot {
    /// Backend identifier written to / expected in export manifests.
    pub(crate) const BACKEND: &'static str = "append-log";
    /// Backend format generation written to / expected in export manifests.
    pub(crate) const BACKEND_VERSION: &'static str = "2";
    /// File name of the single payload file inside an export artifact.
    const ARTIFACT_PAYLOAD: &'static str = "fold.snap";

    /// Loads (and, if needed, compacts) the log at `path`, then opens it for appending.
    ///
    /// Returns the loaded cursor (`WatchCursor::none()` for a missing or empty log) and the store.
    pub fn open(path: &Path, compact_threshold: u64) -> Result<(WatchCursor, Self), SnapshotError> {
        let (cursor, entries) = load(path)?.map_or_else(
            || (WatchCursor::none(), HashMap::new()),
            |snap| (snap.cursor, snap.entries),
        );
        let writer = SnapshotWriter::open(path, compact_threshold)?;
        let store = Self {
            writer,
            entries,
            cursor,
        };
        Ok((store.cursor.clone(), store))
    }

    /// Validates an export artifact, installs its payload at `dest_path`, and opens it.
    ///
    /// The artifact must come from this backend and format generation. It must contain
    /// exactly one payload file, and that file's replayed cursor must equal the manifest's
    /// cursor.
    pub fn import(
        artifact_dir: &Path,
        dest_path: &Path,
        compact_threshold: u64,
    ) -> Result<(WatchCursor, Self), SnapshotError> {
        let (manifest, stage) =
            verify_and_stage_import(artifact_dir, dest_path, Self::BACKEND, |v| {
                if v != Self::BACKEND_VERSION {
                    return Err(SnapshotError::ArtifactInvalid(format!(
                        "append-log artifact has format generation {v:?}, this build reads {:?}",
                        Self::BACKEND_VERSION
                    )));
                }
                Ok(())
            })?;

        let expected = format!(
            "{}/{}",
            crate::artifact::PAYLOAD_DIR,
            Self::ARTIFACT_PAYLOAD
        );
        let single_expected_payload =
            manifest.files.len() == 1 && manifest.files[0].path == expected;
        if !single_expected_payload {
            return Err(SnapshotError::ArtifactInvalid(format!(
                "append-log artifact must contain exactly {expected:?}"
            )));
        }

        // Replay the staged payload to cross-check its cursor against the manifest.
        let staged_file = stage.payload().join(Self::ARTIFACT_PAYLOAD);
        let staged_cursor = load(&staged_file)?.map_or_else(WatchCursor::none, |snap| snap.cursor);
        if staged_cursor != manifest.cursor {
            return Err(SnapshotError::ArtifactInvalid(format!(
                "payload cursor {staged_cursor:?} disagrees with manifest cursor {:?}",
                manifest.cursor
            )));
        }

        stage.finalize_file(Self::ARTIFACT_PAYLOAD)?;
        Self::open(dest_path, compact_threshold)
    }
}
impl SnapshotStore for AppendLogSnapshot {
    /// Opens the log at `path` with [`DEFAULT_COMPACT_THRESHOLD`].
    fn load(path: &Path) -> Result<(WatchCursor, Self), SnapshotError> {
        Self::open(path, DEFAULT_COMPACT_THRESHOLD)
    }

    /// Logs and folds `batch`, checkpoints `cursor`, and compacts if the threshold was crossed.
    ///
    /// The in-memory cursor is advanced as soon as the checkpoint record has been flushed,
    /// before compaction is attempted. A failing compaction therefore cannot leave `entries`
    /// reflecting the batch while `cursor` lags the cursor already recorded in the log.
    ///
    /// NOTE(review): if a write fails part-way through `batch`, the updates applied so far stay
    /// in `entries` while `cursor` keeps its previous value. Re-applying the whole batch should
    /// be harmless, since puts and deletes are idempotent on the map. Confirm this matches the
    /// caller's retry policy.
    fn apply(&mut self, batch: &[KvUpdate], cursor: &WatchCursor) -> Result<(), SnapshotError> {
        for update in batch {
            self.writer.write_update(update)?;
            match update {
                KvUpdate::Put(entry) => {
                    self.entries.insert(entry.key.clone(), entry.clone());
                }
                KvUpdate::Delete { key, .. } | KvUpdate::Purge { key, .. } => {
                    self.entries.remove(key);
                }
            }
        }
        let needs_compaction = self.writer.checkpoint(cursor)?;
        // The cursor record is in the log now; keep the in-memory cursor in step with it.
        self.cursor = cursor.clone();
        if needs_compaction {
            self.writer.compact()?;
        }
        Ok(())
    }

    /// Returns a clone of the entry for `key`, if present.
    fn get(&self, key: &str) -> Result<Option<KvEntry>, SnapshotError> {
        Ok(self.entries.get(key).cloned())
    }

    /// Returns clones of all entries whose key starts with `prefix`, sorted by key.
    fn range(&self, prefix: &str) -> Result<Vec<KvEntry>, SnapshotError> {
        let mut out: Vec<KvEntry> = self
            .entries
            .values()
            .filter(|e| e.key.starts_with(prefix))
            .cloned()
            .collect();
        out.sort_unstable_by(|a, b| a.key.cmp(&b.key));
        Ok(out)
    }

    /// Returns the cursor of the last applied batch (or the loaded cursor).
    fn cursor(&self) -> WatchCursor {
        self.cursor.clone()
    }

    /// Exports the current fold as a single compacted payload file and seals the artifact.
    ///
    /// The payload is re-read after writing. Its cursor and entry count must match the live
    /// fold before the stage is sealed.
    fn export_to(&mut self, dest_dir: &Path) -> Result<ExportManifest, SnapshotError> {
        let stage = ExportStage::new(dest_dir)?;
        fs::create_dir(stage.payload())?;
        let payload = stage.payload().join(Self::ARTIFACT_PAYLOAD);
        compact_to_file(&payload, &self.entries, &self.cursor)?;
        let (loaded_cursor, loaded_len) = match load(&payload)? {
            Some(snap) => (snap.cursor, snap.entries.len()),
            None => (WatchCursor::none(), 0),
        };
        if loaded_cursor != self.cursor || loaded_len != self.entries.len() {
            return Err(SnapshotError::ArtifactInvalid(format!(
                "exported payload disagrees with live fold (cursor {loaded_cursor:?} vs {:?}, \
                 {loaded_len} vs {} entries)",
                self.cursor,
                self.entries.len()
            )));
        }
        stage.seal_and_finalize(Self::BACKEND, Self::BACKEND_VERSION, &self.cursor)
    }
}
/// Replays a raw snapshot log into its live key/value map and its last cursor.
///
/// Returns `(entries, cursor, already_compact)`. `already_compact` is true only when no
/// record was redundant and the log ended cleanly. A record counts as redundant if it is a
/// put overwriting a live key, any delete, or a cursor record that supersedes an earlier
/// non-none cursor. `load` uses the flag to decide whether to rewrite the file.
///
/// Torn-tail handling: replay stops, and the result is marked not compact, in two cases:
/// - a record runs past the end of `data` (`Truncated`);
/// - a record has a CRC mismatch and is followed by fewer than `MIN_CURSOR_RECORD` bytes.
/// Either way the caller is expected to rewrite the file without the fragment.
///
/// NOTE(review): a corrupted length field in the middle of the log also surfaces as
/// `Truncated`, because the record appears to extend past EOF. Such corruption is
/// indistinguishable from a torn tail here and silently drops every later record. Confirm
/// this is acceptable.
///
/// # Errors
/// - `InvalidFormat` for a file shorter than the header, bad magic, an unsupported format
///   version, or a structurally invalid record.
/// - `Corrupted` for a CRC mismatch with enough bytes after it that it cannot be a torn tail.
fn replay_log(data: &[u8]) -> Result<(HashMap<String, KvEntry>, WatchCursor, bool), SnapshotError> {
    if data.len() < HEADER_LEN {
        return Err(SnapshotError::InvalidFormat("file too short".into()));
    }
    if &data[0..4] != MAGIC {
        return Err(SnapshotError::InvalidFormat("bad magic".into()));
    }
    let version = u16::from_le_bytes([data[4], data[5]]);
    if version != FORMAT_VERSION {
        return Err(SnapshotError::InvalidFormat(format!(
            "unsupported version {version}"
        )));
    }
    // Rough pre-size (about one live key per 30 log bytes), capped to avoid huge up-front
    // allocations.
    let estimated = (data.len() - HEADER_LEN) / 30;
    // Keys and values borrow from `data`; they are copied only once, after replay.
    let mut live: HashMap<&str, (&[u8], VersionToken)> =
        HashMap::with_capacity(estimated.min(4096));
    let mut cursor = WatchCursor::none();
    let mut pos = HEADER_LEN;
    // Set when any record is superseded, so a rewrite would shrink the file.
    let mut redundant = false;
    // Cleared when replay stops early on a torn tail.
    let mut clean_eof = true;
    while pos < data.len() {
        match parse_record(&data[pos..]) {
            Ok((record, consumed)) => {
                match record {
                    Record::Put {
                        key,
                        value,
                        version,
                    } => {
                        if live.insert(key, (value, version)).is_some() {
                            redundant = true;
                        }
                    }
                    Record::Delete { key } => {
                        live.remove(key);
                        // Deletes never appear in a compacted log, so any delete is redundant.
                        redundant = true;
                    }
                    Record::Cursor(c) => {
                        if !cursor.is_none() {
                            redundant = true;
                        }
                        cursor = c;
                    }
                }
                pos += consumed;
            }
            Err(RecordError::Truncated) => {
                // Incomplete final record: treat as an interrupted append.
                clean_eof = false;
                break;
            }
            Err(RecordError::CrcMismatch { consumed }) => {
                // A bad CRC on the last record, or with too few bytes after it to hold even
                // the smallest record, is treated as a torn tail rather than corruption.
                if pos + consumed >= data.len() || data.len() - (pos + consumed) < MIN_CURSOR_RECORD
                {
                    clean_eof = false;
                    break;
                }
                return Err(SnapshotError::Corrupted);
            }
            Err(RecordError::Invalid(msg)) => {
                return Err(SnapshotError::InvalidFormat(msg));
            }
        }
    }
    let mut entries: HashMap<String, KvEntry> = HashMap::with_capacity(live.len());
    for (key, (value, version)) in live {
        let key = key.to_string();
        entries.insert(
            key.clone(),
            KvEntry {
                key,
                value: value.to_vec(),
                version,
            },
        );
    }
    let already_compact = !redundant && clean_eof;
    Ok((entries, cursor, already_compact))
}
/// One decoded log record; key and value bytes borrow from the log buffer being replayed.
enum Record<'a> {
    /// Sets `key` to `value` at `version`.
    Put {
        key: &'a str,
        value: &'a [u8],
        version: VersionToken,
    },
    /// Removes `key` (the record's version is validated by the parser but not kept).
    Delete {
        key: &'a str,
    },
    /// Checkpoint: the watch cursor the preceding records bring the fold up to.
    Cursor(WatchCursor),
}
/// Why a record at the current log position could not be decoded.
enum RecordError {
    /// The buffer ends before the record does.
    Truncated,
    /// The record is complete but its CRC does not match; `consumed` is its full length.
    CrcMismatch {
        consumed: usize,
    },
    /// The record is structurally invalid (unknown type, oversized token, bad UTF-8, ...).
    Invalid(String),
}
/// Decodes the record at the start of `data`, returning it with its encoded length.
///
/// Every record starts with a little-endian CRC-32 (4 bytes) and a type byte; the rest is
/// handed to the type-specific parser.
fn parse_record(data: &[u8]) -> Result<(Record<'_>, usize), RecordError> {
    let &[c0, c1, c2, c3, kind, ..] = data else {
        return Err(RecordError::Truncated);
    };
    let stored_crc = u32::from_le_bytes([c0, c1, c2, c3]);
    match kind {
        REC_PUT => parse_put(data, stored_crc),
        REC_DELETE => parse_delete(data, stored_crc),
        REC_CURSOR => parse_cursor(data, stored_crc),
        other => Err(RecordError::Invalid(format!(
            "unknown record type: {other:#x}"
        ))),
    }
}
/// Decodes a put record at the start of `data`.
///
/// Layout: `crc32 (4) | type (1) | key_len u16 LE (2) | key | value_len u32 LE (4) | value |
/// ver_len (1) | version`. The CRC covers everything from the type byte to the end.
///
/// Length checks come first (`Truncated`), then the token-size limit (`Invalid`), then the
/// CRC (`CrcMismatch`). UTF-8 and version decoding happen only for CRC-valid records.
fn parse_put(data: &[u8], stored_crc: u32) -> Result<(Record<'_>, usize), RecordError> {
    if data.len() < 7 {
        return Err(RecordError::Truncated);
    }
    let key_len = u16::from_le_bytes([data[5], data[6]]) as usize;
    let vl_off = 7 + key_len;
    if data.len() < vl_off + 4 {
        return Err(RecordError::Truncated);
    }
    // u32 -> usize is lossless on the 32- and 64-bit targets this code runs on.
    let value_len = u32::from_le_bytes([
        data[vl_off],
        data[vl_off + 1],
        data[vl_off + 2],
        data[vl_off + 3],
    ]) as usize;
    let value_off = vl_off + 4;
    // `value_len` comes straight from the file, so on 32-bit targets `value_off + value_len`
    // can overflow. A record that large cannot be in `data` anyway. Report it as truncated,
    // exactly as the unchecked sum would on a 64-bit target.
    let ver_len_off = match value_off.checked_add(value_len) {
        Some(off) if off < data.len() => off,
        _ => return Err(RecordError::Truncated),
    };
    let ver_len = data[ver_len_off] as usize;
    if ver_len > 10 {
        return Err(RecordError::Invalid(format!(
            "version length {ver_len} exceeds max version token size (10)"
        )));
    }
    let total = ver_len_off + 1 + ver_len;
    if data.len() < total {
        return Err(RecordError::Truncated);
    }
    let computed = crc32fast::hash(&data[4..total]);
    if computed != stored_crc {
        return Err(RecordError::CrcMismatch { consumed: total });
    }
    let key = std::str::from_utf8(&data[7..vl_off])
        .map_err(|e| RecordError::Invalid(format!("invalid UTF-8 key: {e}")))?;
    let value = &data[value_off..ver_len_off];
    let version = VersionToken::from_raw(&data[ver_len_off + 1..total]).ok_or_else(|| {
        RecordError::Invalid(format!(
            "version length {ver_len} exceeds max version token size (10)"
        ))
    })?;
    Ok((
        Record::Put {
            key,
            value,
            version,
        },
        total,
    ))
}
/// Decodes a delete record at the start of `data`.
///
/// Layout: `crc32 (4) | type (1) | key_len u16 LE (2) | key | ver_len (1) | version`. The
/// version bytes are CRC-checked and length-limited but not returned.
fn parse_delete(data: &[u8], stored_crc: u32) -> Result<(Record<'_>, usize), RecordError> {
    let &[_, _, _, _, _, len_lo, len_hi, ..] = data else {
        return Err(RecordError::Truncated);
    };
    let key_end = 7 + usize::from(u16::from_le_bytes([len_lo, len_hi]));
    let Some(&ver_len_byte) = data.get(key_end) else {
        return Err(RecordError::Truncated);
    };
    let ver_len = usize::from(ver_len_byte);
    if ver_len > 10 {
        return Err(RecordError::Invalid(format!(
            "version length {ver_len} exceeds max version token size (10)"
        )));
    }
    let total = key_end + 1 + ver_len;
    // Bytes covered by the CRC: type byte through end of the version token.
    let Some(checked) = data.get(4..total) else {
        return Err(RecordError::Truncated);
    };
    if crc32fast::hash(checked) != stored_crc {
        return Err(RecordError::CrcMismatch { consumed: total });
    }
    let key = std::str::from_utf8(&data[7..key_end])
        .map_err(|e| RecordError::Invalid(format!("invalid UTF-8 key: {e}")))?;
    Ok((Record::Delete { key }, total))
}
/// Decodes a cursor record at the start of `data`.
///
/// Layout: `crc32 (4) | type (1) | cursor_len (1) | cursor version bytes`.
fn parse_cursor(data: &[u8], stored_crc: u32) -> Result<(Record<'_>, usize), RecordError> {
    let &[_, _, _, _, _, len_byte, ..] = data else {
        return Err(RecordError::Truncated);
    };
    let cursor_len = usize::from(len_byte);
    if cursor_len > 10 {
        return Err(RecordError::Invalid(format!(
            "cursor length {cursor_len} exceeds max version token size (10)"
        )));
    }
    let total = 6 + cursor_len;
    // Bytes covered by the CRC: type byte through end of the cursor token.
    let Some(checked) = data.get(4..total) else {
        return Err(RecordError::Truncated);
    };
    if crc32fast::hash(checked) != stored_crc {
        return Err(RecordError::CrcMismatch { consumed: total });
    }
    let token = &data[6..total];
    let version = VersionToken::from_raw(token).ok_or_else(|| {
        RecordError::Invalid(format!(
            "cursor length {cursor_len} exceeds max version token size (10)"
        ))
    })?;
    Ok((Record::Cursor(WatchCursor::from_version(version)), total))
}
/// Largest version/cursor token, in bytes, that `parse_put`, `parse_delete` and
/// `parse_cursor` accept. The writers enforce the same bound, so they never emit a record
/// that replay would reject as invalid and thereby make the whole log unloadable.
const MAX_TOKEN_LEN: usize = 10;

/// Encodes a key length as the on-disk `u16`, rejecting keys longer than `u16::MAX` bytes.
fn encode_key_len(kb: &[u8]) -> Result<u16, SnapshotError> {
    u16::try_from(kb.len()).map_err(|_| {
        SnapshotError::InvalidFormat(format!(
            "key too long: {} bytes (max {})",
            kb.len(),
            u16::MAX
        ))
    })
}

/// Encodes a version/cursor token length as the on-disk length byte.
///
/// `what` names the token in the error message (`"version"` or `"cursor"`).
fn encode_token_len(what: &str, token: &[u8]) -> Result<u8, SnapshotError> {
    match u8::try_from(token.len()) {
        Ok(len) if token.len() <= MAX_TOKEN_LEN => Ok(len),
        _ => Err(SnapshotError::InvalidFormat(format!(
            "{what} too long: {} bytes (max {MAX_TOKEN_LEN})",
            token.len()
        ))),
    }
}

/// Writes a put record (layout documented on `parse_put`) and returns its encoded size.
///
/// All fields are validated before any byte is written, so an error leaves `w` untouched.
fn write_put_record(
    w: &mut impl Write,
    key: &str,
    value: &[u8],
    version: &VersionToken,
) -> Result<usize, SnapshotError> {
    let kb = key.as_bytes();
    let vb = version.as_bytes();
    let key_len = encode_key_len(kb)?;
    let value_len = u32::try_from(value.len()).map_err(|_| {
        SnapshotError::InvalidFormat(format!(
            "value too long: {} bytes (max {})",
            value.len(),
            u32::MAX
        ))
    })?;
    let ver_len = encode_token_len("version", vb)?;
    let mut h = crc32fast::Hasher::new();
    h.update(&[REC_PUT]);
    h.update(&key_len.to_le_bytes());
    h.update(kb);
    h.update(&value_len.to_le_bytes());
    h.update(value);
    h.update(&[ver_len]);
    h.update(vb);
    let crc = h.finalize();
    w.write_all(&crc.to_le_bytes())?;
    w.write_all(&[REC_PUT])?;
    w.write_all(&key_len.to_le_bytes())?;
    w.write_all(kb)?;
    w.write_all(&value_len.to_le_bytes())?;
    w.write_all(value)?;
    w.write_all(&[ver_len])?;
    w.write_all(vb)?;
    Ok(4 + 1 + 2 + kb.len() + 4 + value.len() + 1 + vb.len())
}

/// Writes a delete record (layout documented on `parse_delete`) and returns its encoded size.
///
/// All fields are validated before any byte is written, so an error leaves `w` untouched.
fn write_delete_record(
    w: &mut impl Write,
    key: &str,
    version: &VersionToken,
) -> Result<usize, SnapshotError> {
    let kb = key.as_bytes();
    let vb = version.as_bytes();
    let key_len = encode_key_len(kb)?;
    let ver_len = encode_token_len("version", vb)?;
    let mut h = crc32fast::Hasher::new();
    h.update(&[REC_DELETE]);
    h.update(&key_len.to_le_bytes());
    h.update(kb);
    h.update(&[ver_len]);
    h.update(vb);
    let crc = h.finalize();
    w.write_all(&crc.to_le_bytes())?;
    w.write_all(&[REC_DELETE])?;
    w.write_all(&key_len.to_le_bytes())?;
    w.write_all(kb)?;
    w.write_all(&[ver_len])?;
    w.write_all(vb)?;
    Ok(4 + 1 + 2 + kb.len() + 1 + vb.len())
}

/// Writes a cursor record (layout documented on `parse_cursor`) and returns its encoded size.
///
/// The cursor token is validated before any byte is written, so an error leaves `w` untouched.
fn write_cursor_record(w: &mut impl Write, cursor: &WatchCursor) -> Result<usize, SnapshotError> {
    let cb = cursor.version().as_bytes();
    let cb_len = encode_token_len("cursor", cb)?;
    let mut h = crc32fast::Hasher::new();
    h.update(&[REC_CURSOR]);
    h.update(&[cb_len]);
    h.update(cb);
    let crc = h.finalize();
    w.write_all(&crc.to_le_bytes())?;
    w.write_all(&[REC_CURSOR])?;
    w.write_all(&[cb_len])?;
    w.write_all(cb)?;
    Ok(4 + 1 + 1 + cb.len())
}
fn compact_to_file(
path: &Path,
entries: &HashMap<String, KvEntry>,
cursor: &WatchCursor,
) -> Result<(), SnapshotError> {
let dir = path.parent().ok_or_else(|| {
SnapshotError::InvalidFormat(format!("snapshot path has no parent: {}", path.display()))
})?;
let mut sorted: Vec<&KvEntry> = entries.values().collect();
sorted.sort_unstable_by(|a, b| a.key.cmp(&b.key));
let estimated: usize = HEADER_LEN
+ sorted
.iter()
.map(|e| 4 + 1 + 2 + e.key.len() + 4 + e.value.len() + 1 + e.version.as_bytes().len())
.sum::<usize>()
+ if cursor.is_none() {
0
} else {
4 + 1 + 1 + cursor.version().as_bytes().len()
};
let capacity = estimated.clamp(8 * 1024, 1024 * 1024);
let mut buf = io::BufWriter::with_capacity(capacity, tempfile::NamedTempFile::new_in(dir)?);
buf.write_all(MAGIC)?;
buf.write_all(&FORMAT_VERSION.to_le_bytes())?;
for entry in sorted {
write_put_record(&mut buf, &entry.key, &entry.value, &entry.version)?;
}
if !cursor.is_none() {
write_cursor_record(&mut buf, cursor)?;
}
buf.flush()?;
let tmp = buf
.into_inner()
.map_err(|e| SnapshotError::Io(e.into_error()))?;
tmp.as_file().sync_all()?;
tmp.persist(path).map_err(|e| SnapshotError::Io(e.error))?;
Ok(())
}
// Unit tests for the snapshot log: round-trips, replay semantics (puts, deletes, purges,
// cursors), torn-tail recovery, corruption detection, header validation and compaction.
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    // Builds a `KvEntry` whose version is `VersionToken::from_u64(rev)`.
    fn entry(key: &str, value: &[u8], rev: u64) -> KvEntry {
        KvEntry {
            key: key.to_string(),
            value: value.to_vec(),
            version: VersionToken::from_u64(rev),
        }
    }

    // Watch cursor at numeric revision `rev`.
    fn cursor(rev: u64) -> WatchCursor {
        WatchCursor::from_u64(rev)
    }

    // `KvUpdate::Put` shorthand.
    fn put(key: &str, value: &[u8], rev: u64) -> KvUpdate {
        KvUpdate::Put(entry(key, value, rev))
    }

    // `KvUpdate::Delete` shorthand.
    fn delete(key: &str, rev: u64) -> KvUpdate {
        KvUpdate::Delete {
            key: key.to_string(),
            version: VersionToken::from_u64(rev),
        }
    }

    // Two puts plus a checkpoint are read back with their values and cursor.
    #[test]
    fn round_trip() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("test.snap");
        let mut w = SnapshotWriter::open(&path, u64::MAX).unwrap();
        w.write_update(&put("node.us-east-1", b"val1", 1)).unwrap();
        w.write_update(&put("node.eu-west-1", b"val2", 2)).unwrap();
        w.checkpoint(&cursor(2)).unwrap();
        drop(w);
        let snap = load(&path).unwrap().unwrap();
        assert_eq!(snap.entries.len(), 2);
        assert_eq!(snap.cursor.as_u64(), Some(2));
        assert_eq!(snap.entries["node.us-east-1"].value, b"val1");
        assert_eq!(snap.entries["node.eu-west-1"].value, b"val2");
    }

    // With several checkpoints, all entries survive and the last cursor wins.
    #[test]
    fn multiple_batches() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("test.snap");
        let mut w = SnapshotWriter::open(&path, u64::MAX).unwrap();
        w.write_update(&put("a", b"v1", 1)).unwrap();
        w.checkpoint(&cursor(1)).unwrap();
        w.write_update(&put("b", b"v2", 2)).unwrap();
        w.checkpoint(&cursor(2)).unwrap();
        drop(w);
        let snap = load(&path).unwrap().unwrap();
        assert_eq!(snap.entries.len(), 2);
        assert_eq!(snap.cursor.as_u64(), Some(2));
    }

    // A delete record removes the key during replay.
    #[test]
    fn delete_removes_entry() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("test.snap");
        let mut w = SnapshotWriter::open(&path, u64::MAX).unwrap();
        w.write_update(&put("a", b"v1", 1)).unwrap();
        w.write_update(&put("b", b"v2", 2)).unwrap();
        w.checkpoint(&cursor(2)).unwrap();
        w.write_update(&delete("a", 3)).unwrap();
        w.checkpoint(&cursor(3)).unwrap();
        drop(w);
        let snap = load(&path).unwrap().unwrap();
        assert_eq!(snap.entries.len(), 1);
        assert!(snap.entries.contains_key("b"));
    }

    // A purge is logged like a delete and removes the key.
    #[test]
    fn purge_removes_entry() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("test.snap");
        let mut w = SnapshotWriter::open(&path, u64::MAX).unwrap();
        w.write_update(&put("a", b"v1", 1)).unwrap();
        w.checkpoint(&cursor(1)).unwrap();
        w.write_update(&KvUpdate::Purge {
            key: "a".to_string(),
            version: VersionToken::from_u64(2),
        })
        .unwrap();
        w.checkpoint(&cursor(2)).unwrap();
        drop(w);
        let snap = load(&path).unwrap().unwrap();
        assert!(!snap.entries.contains_key("a"));
    }

    // `load` on a non-existent path is `Ok(None)`, not an error.
    #[test]
    fn missing_file_returns_none() {
        let dir = TempDir::new().unwrap();
        assert!(load(&dir.path().join("nope.snap")).unwrap().is_none());
    }

    // Flipping a byte well before the end must surface as `Corrupted`, not be treated as a
    // torn tail. The chosen offset is assumed to land inside a CRC-covered byte rather than
    // a length field. TODO confirm if the record layout or token sizes change.
    #[test]
    fn corrupted_mid_file() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("test.snap");
        let mut w = SnapshotWriter::open(&path, u64::MAX).unwrap();
        w.write_update(&put("a", b"aaaa-long-value-here", 1))
            .unwrap();
        w.checkpoint(&cursor(1)).unwrap();
        w.write_update(&put("b", b"bbbb-long-value-here", 2))
            .unwrap();
        w.checkpoint(&cursor(2)).unwrap();
        w.write_update(&put("c", b"cccc-long-value-here", 3))
            .unwrap();
        w.checkpoint(&cursor(3)).unwrap();
        drop(w);
        let mut data = fs::read(&path).unwrap();
        let target = HEADER_LEN + 40;
        assert!(
            target < data.len() - 60,
            "need enough room after corruption for valid records"
        );
        data[target] ^= 0xFF;
        fs::write(&path, &data).unwrap();
        match load(&path) {
            Err(SnapshotError::Corrupted) => {}
            other => panic!("expected Corrupted, got {other:?}"),
        }
    }

    // A partially written final record is dropped; earlier records still load.
    #[test]
    fn truncated_final_record_recovered() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("test.snap");
        let mut w = SnapshotWriter::open(&path, u64::MAX).unwrap();
        w.write_update(&put("a", b"v1", 1)).unwrap();
        w.checkpoint(&cursor(1)).unwrap();
        w.write_update(&put("b", b"v2", 2)).unwrap();
        w.checkpoint(&cursor(2)).unwrap();
        drop(w);
        let mut data = fs::read(&path).unwrap();
        data.truncate(data.len() - 3);
        fs::write(&path, &data).unwrap();
        let snap = load(&path).unwrap().unwrap();
        assert!(snap.entries.contains_key("a"));
    }

    // After `load` repairs a torn tail, a reopened writer can append and the result replays.
    #[test]
    fn truncated_tail_repaired_then_appendable() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("test.snap");
        let mut w = SnapshotWriter::open(&path, u64::MAX).unwrap();
        w.write_update(&put("a", b"v1", 1)).unwrap();
        w.write_update(&put("b", b"v2", 2)).unwrap();
        w.checkpoint(&cursor(2)).unwrap();
        drop(w);
        let mut data = fs::read(&path).unwrap();
        data.truncate(data.len() - 3);
        fs::write(&path, &data).unwrap();
        let snap = load(&path).unwrap().unwrap();
        assert!(snap.entries.contains_key("a"));
        assert!(snap.entries.contains_key("b"));
        let mut w = SnapshotWriter::open(&path, u64::MAX).unwrap();
        w.write_update(&put("c", b"v3", 3)).unwrap();
        w.checkpoint(&cursor(3)).unwrap();
        drop(w);
        let snap = load(&path).unwrap().unwrap();
        assert_eq!(snap.entries.len(), 3);
        assert!(snap.entries.contains_key("c"));
        assert_eq!(snap.cursor.as_u64(), Some(3));
    }

    // Superseded cursor records count as redundant: `load` compacts them away, and a second
    // `load` leaves the already-compact file byte-identical.
    #[test]
    fn repeated_cursor_records_trigger_compaction() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("test.snap");
        let mut w = SnapshotWriter::open(&path, u64::MAX).unwrap();
        w.write_update(&put("a", b"v1", 1)).unwrap();
        for i in 1..=10u64 {
            w.checkpoint(&cursor(i)).unwrap();
        }
        drop(w);
        let size_before = fs::metadata(&path).unwrap().len();
        let snap = load(&path).unwrap().unwrap();
        let size_after = fs::metadata(&path).unwrap().len();
        assert_eq!(snap.entries.len(), 1);
        assert_eq!(snap.cursor.as_u64(), Some(10));
        assert!(
            size_after < size_before,
            "stale cursor records should be compacted away: {size_before} -> {size_after}"
        );
        let after_first = fs::read(&path).unwrap();
        load(&path).unwrap().unwrap();
        let after_second = fs::read(&path).unwrap();
        assert_eq!(after_first, after_second);
    }

    // Loading an already-compact file must not rewrite it.
    #[test]
    fn already_compact_file_reloads_unchanged() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("test.snap");
        let mut w = SnapshotWriter::open(&path, u64::MAX).unwrap();
        w.write_update(&put("a", b"v1", 1)).unwrap();
        w.write_update(&put("b", b"v2", 2)).unwrap();
        w.checkpoint(&cursor(2)).unwrap();
        drop(w);
        load(&path).unwrap().unwrap();
        let after_first = fs::read(&path).unwrap();
        let snap = load(&path).unwrap().unwrap();
        let after_second = fs::read(&path).unwrap();
        assert_eq!(after_first, after_second);
        assert_eq!(snap.entries.len(), 2);
        assert_eq!(snap.cursor.as_u64(), Some(2));
    }

    // A wrong magic prefix is rejected with an `InvalidFormat` mentioning "magic".
    #[test]
    fn bad_magic() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("test.snap");
        fs::write(&path, b"XXXX\x01\x00").unwrap();
        match load(&path) {
            Err(SnapshotError::InvalidFormat(msg)) => assert!(msg.contains("magic")),
            other => panic!("expected InvalidFormat, got {other:?}"),
        }
    }

    // A header with a different format version is rejected.
    #[test]
    fn wrong_version_rejected() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("test.snap");
        let mut data = Vec::new();
        data.extend_from_slice(MAGIC);
        data.extend_from_slice(&(FORMAT_VERSION + 1).to_le_bytes());
        fs::write(&path, &data).unwrap();
        match load(&path) {
            Err(SnapshotError::InvalidFormat(msg)) => {
                assert!(
                    msg.contains("version"),
                    "message should mention version: {msg}"
                )
            }
            other => panic!("expected InvalidFormat, got {other:?}"),
        }
    }

    // A header-only log has neither entries nor cursor, so `load` returns `None`.
    #[test]
    fn empty_log_returns_none() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("test.snap");
        let mut f = File::create(&path).unwrap();
        f.write_all(MAGIC).unwrap();
        f.write_all(&FORMAT_VERSION.to_le_bytes()).unwrap();
        drop(f);
        assert!(load(&path).unwrap().is_none());
    }

    // Repeated overwrites of one key are folded to the latest value and the file shrinks.
    #[test]
    fn compaction_on_load_shrinks_file() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("test.snap");
        let mut w = SnapshotWriter::open(&path, u64::MAX).unwrap();
        for i in 0..10u64 {
            w.write_update(&put("same-key", format!("v{i}").as_bytes(), i))
                .unwrap();
            w.checkpoint(&cursor(i)).unwrap();
        }
        drop(w);
        let size_before = fs::metadata(&path).unwrap().len();
        let snap = load(&path).unwrap().unwrap();
        let size_after = fs::metadata(&path).unwrap().len();
        assert_eq!(snap.entries.len(), 1);
        assert_eq!(snap.entries["same-key"].value, b"v9");
        assert!(
            size_after < size_before,
            "compaction should shrink: {size_before} -> {size_after}"
        );
    }

    // With a tiny threshold, `checkpoint` requests compaction and `compact` keeps the log
    // appendable and correct across many cycles.
    #[test]
    fn compact_when_threshold_exceeded() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("test.snap");
        let mut w = SnapshotWriter::open(&path, 100).unwrap();
        for i in 0..20u64 {
            w.write_update(&put("key", format!("value-{i}").as_bytes(), i))
                .unwrap();
            if w.checkpoint(&cursor(i)).unwrap() {
                w.compact().unwrap();
            }
        }
        drop(w);
        let snap = load(&path).unwrap().unwrap();
        assert_eq!(snap.entries.len(), 1);
        assert_eq!(snap.entries["key"].value, b"value-19");
    }

    // Reopening an existing log appends after the existing records instead of overwriting.
    #[test]
    fn reopen_appends() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("test.snap");
        let mut w = SnapshotWriter::open(&path, u64::MAX).unwrap();
        w.write_update(&put("a", b"v1", 1)).unwrap();
        w.checkpoint(&cursor(1)).unwrap();
        drop(w);
        let mut w = SnapshotWriter::open(&path, u64::MAX).unwrap();
        w.write_update(&put("b", b"v2", 2)).unwrap();
        w.checkpoint(&cursor(2)).unwrap();
        drop(w);
        let snap = load(&path).unwrap().unwrap();
        assert_eq!(snap.entries.len(), 2);
        assert_eq!(snap.cursor.as_u64(), Some(2));
    }

    // A 100 KB value round-trips intact.
    #[test]
    fn large_values() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("test.snap");
        let big = vec![0xABu8; 100_000];
        let mut w = SnapshotWriter::open(&path, u64::MAX).unwrap();
        w.write_update(&put("big", &big, 1)).unwrap();
        w.checkpoint(&cursor(1)).unwrap();
        drop(w);
        let snap = load(&path).unwrap().unwrap();
        assert_eq!(snap.entries.len(), 1);
        assert_eq!(snap.entries["big"].value.len(), 100_000);
        assert!(snap.entries["big"].value.iter().all(|&b| b == 0xAB));
    }

    // A log holding only a cursor still loads as `Some` so the cursor is not lost.
    #[test]
    fn cursor_only_snapshot_returns_some_with_empty_entries() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("test.snap");
        let mut w = SnapshotWriter::open(&path, u64::MAX).unwrap();
        w.checkpoint(&cursor(42)).unwrap();
        drop(w);
        let snap = load(&path)
            .unwrap()
            .expect("cursor-only snapshot should return Some");
        assert!(snap.entries.is_empty(), "no entries written, none expected");
        assert_eq!(
            snap.cursor.as_u64(),
            Some(42),
            "cursor must survive round-trip"
        );
    }

    // `stale_keys` returns snapshot keys missing from the current set (sorted here because
    // the method's order is unspecified).
    #[test]
    fn stale_keys_detects_removed_entries() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("test.snap");
        let mut w = SnapshotWriter::open(&path, u64::MAX).unwrap();
        w.write_update(&put("node.a", b"v1", 1)).unwrap();
        w.write_update(&put("node.b", b"v2", 2)).unwrap();
        w.write_update(&put("node.c", b"v3", 3)).unwrap();
        w.checkpoint(&cursor(3)).unwrap();
        drop(w);
        let snap = load(&path).unwrap().unwrap();
        let mut stale = snap.stale_keys(["node.a", "node.c"]);
        stale.sort();
        assert_eq!(stale, vec!["node.b"]);
        let stale = snap.stale_keys(["node.a", "node.b", "node.c"]);
        assert!(stale.is_empty());
        let mut stale: Vec<&str> = snap.stale_keys(std::iter::empty::<&str>());
        stale.sort();
        assert_eq!(stale, vec!["node.a", "node.b", "node.c"]);
    }

    // A 10-byte (non-u64) version token is stored and read back byte-for-byte.
    #[test]
    fn non_u64_version_token_round_trips() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("test.snap");
        let stamp = [9u8, 8, 7, 6, 5, 4, 3, 2, 1, 0];
        let token = VersionToken::from_fdb_versionstamp(&stamp);
        assert!(token.as_u64().is_none(), "10-byte token is not a u64");
        let mut w = SnapshotWriter::open(&path, u64::MAX).unwrap();
        w.write_update(&KvUpdate::Put(KvEntry {
            key: "fdb.key".to_string(),
            value: b"v".to_vec(),
            version: token.clone(),
        }))
        .unwrap();
        w.checkpoint(&cursor(1)).unwrap();
        drop(w);
        let snap = load(&path).unwrap().unwrap();
        assert_eq!(
            snap.entries["fdb.key"].version.as_bytes(),
            &stamp,
            "versionstamp must survive the snapshot round-trip byte-for-byte"
        );
    }

    // `compact` flushes buffered writes before replaying, so uncheckpointed puts survive.
    #[test]
    fn compact_preserves_uncheckpointed_writes() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("test.snap");
        let mut w = SnapshotWriter::open(&path, u64::MAX).unwrap();
        w.write_update(&put("node.a", b"survives", 1)).unwrap();
        w.compact().unwrap();
        drop(w);
        let snap = load(&path).unwrap().unwrap();
        assert!(
            snap.entries.contains_key("node.a"),
            "compact() must not drop buffered-but-uncheckpointed writes"
        );
        assert_eq!(snap.entries["node.a"].value, b"survives");
    }
}